diff options
author | River Riddle <riverriddle@google.com> | 2019-07-01 10:29:09 -0700 |
---|---|---|
committer | jpienaar <jpienaar@google.com> | 2019-07-01 11:39:00 -0700 |
commit | 54cd6a7e97a226738e2c85b86559918dd9e3cd5d (patch) | |
tree | affa803347d6695be575137d1ad55a055a8021e3 /mlir/lib | |
parent | 84bd67fc4fd116e80f7a66bfadfe9a7fd6fd5e82 (diff) | |
download | bcm5719-llvm-54cd6a7e97a226738e2c85b86559918dd9e3cd5d.tar.gz bcm5719-llvm-54cd6a7e97a226738e2c85b86559918dd9e3cd5d.zip |
NFC: Refactor Function to be value typed.
Move the data members out of Function and into a new impl storage class 'FunctionStorage'. This allows for Function to become value typed, which will greatly simplify the transition of Function to FuncOp(given that FuncOp is also value typed).
PiperOrigin-RevId: 255983022
Diffstat (limited to 'mlir/lib')
57 files changed, 378 insertions, 383 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 016ef43a84a..d7650dcb127 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -303,7 +303,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) { if (inserted) { reorderedDims.push_back(v); } - return getAffineDimExpr(iterPos->second, v->getFunction()->getContext()) + return getAffineDimExpr(iterPos->second, v->getFunction().getContext()) .cast<AffineDimExpr>(); } diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index 954a01b4843..b4cdeb7d886 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -37,17 +37,16 @@ template class llvm::DomTreeNodeBase<Block>; /// Recalculate the dominance info. template <bool IsPostDom> -void DominanceInfoBase<IsPostDom>::recalculate(Function *function) { +void DominanceInfoBase<IsPostDom>::recalculate(Function function) { dominanceInfos.clear(); // Build the top level function dominance. auto functionDominance = llvm::make_unique<base>(); - functionDominance->recalculate(function->getBody()); - dominanceInfos.try_emplace(&function->getBody(), - std::move(functionDominance)); + functionDominance->recalculate(function.getBody()); + dominanceInfos.try_emplace(&function.getBody(), std::move(functionDominance)); /// Build the dominance for each of the operation regions. - function->walk([&](Operation *op) { + function.walk([&](Operation *op) { for (auto ®ion : op->getRegions()) { // Don't compute dominance if the region is empty. if (region.empty()) diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp index 5177afcee67..75a2fc1a5dc 100644 --- a/mlir/lib/Analysis/OpStats.cpp +++ b/mlir/lib/Analysis/OpStats.cpp @@ -45,7 +45,7 @@ void PrintOpStatsPass::runOnModule() { opCount.clear(); // Compute the operation statistics for each function in the module. - for (auto &fn : getModule()) + for (auto fn : getModule()) fn.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); printSummary(); } diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp index cbda6d40224..473d253cfa2 100644 --- a/mlir/lib/Analysis/TestParallelismDetection.cpp +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -43,7 +43,7 @@ FunctionPassBase *mlir::createParallelismDetectionTestPass() { // Walks the function and emits a note for all 'affine.for' ops detected as // parallel. void TestParallelismDetection::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); OpBuilder b(f.getBody()); f.walk<AffineForOp>([&](AffineForOp forOp) { if (isLoopParallel(forOp)) diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 1330fe0fb94..0d0525145ef 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -53,7 +53,7 @@ public: : ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {} /// Verify the body of the given function. - LogicalResult verify(Function &fn); + LogicalResult verify(Function fn); /// Verify the given operation. LogicalResult verify(Operation &op); @@ -104,7 +104,7 @@ private: } // end anonymous namespace /// Verify the body of the given function. -LogicalResult OperationVerifier::verify(Function &fn) { +LogicalResult OperationVerifier::verify(Function fn) { // Verify the body first. if (failed(verifyRegion(fn.getBody()))) return failure(); @@ -113,7 +113,7 @@ LogicalResult OperationVerifier::verify(Function &fn) { // check. We do this as a second pass since malformed CFG's can cause // dominator analysis constructure to crash and we want the verifier to be // resilient to malformed code. - DominanceInfo theDomInfo(&fn); + DominanceInfo theDomInfo(fn); domInfo = &theDomInfo; if (failed(verifyDominance(fn.getBody()))) return failure(); @@ -313,7 +313,7 @@ LogicalResult Function::verify() { // Verify this attribute with the defining dialect. if (auto *dialect = opVerifier.getDialectForAttribute(attr)) - if (failed(dialect->verifyFunctionAttribute(this, attr))) + if (failed(dialect->verifyFunctionAttribute(*this, attr))) return failure(); } @@ -331,7 +331,7 @@ LogicalResult Function::verify() { // Verify this attribute with the defining dialect. if (auto *dialect = opVerifier.getDialectForAttribute(attr)) - if (failed(dialect->verifyFunctionArgAttribute(this, i, attr))) + if (failed(dialect->verifyFunctionArgAttribute(*this, i, attr))) return failure(); } } @@ -369,7 +369,7 @@ LogicalResult Operation::verify() { LogicalResult Module::verify() { // Check that all functions are uniquely named. llvm::StringMap<Location> nameToOrigLoc; - for (auto &fn : *this) { + for (auto fn : *this) { auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc()); if (!it.second) return fn.emitError() @@ -379,7 +379,7 @@ LogicalResult Module::verify() { } // Check that each function is correct. - for (auto &fn : *this) + for (auto fn : *this) if (failed(fn.verify())) return failure(); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp index 9d7aeeb6321..022d8c70cc6 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp @@ -64,8 +64,8 @@ public: LLVMInitializeNVPTXTargetMC(); LLVMInitializeNVPTXAsmPrinter(); - for (auto &function : getModule()) { - if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) { + for (auto function : getModule()) { + if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) { continue; } if (failed(translateGpuKernelToCubinAnnotation(function))) @@ -142,7 +142,7 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) { std::unique_ptr<Module> module(builder.createModule()); // TODO(herhut): Also handle called functions. - module->getFunctions().push_back(function.clone()); + module->push_back(function.clone()); auto llvmModule = translateModuleToNVVMIR(*module); auto cubin = convertModuleToCubin(*llvmModule, function); diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp index bd96f396b22..f9d5899456a 100644 --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -118,7 +118,7 @@ private: void declareCudaFunctions(Location loc); Value *setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder); - Value *generateKernelNameConstant(Function *kernelFunction, Location &loc, + Value *generateKernelNameConstant(Function kernelFunction, Location &loc, OpBuilder &builder); void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp); @@ -130,7 +130,7 @@ public: // Cache the used LLVM types. initializeCachedTypes(); - for (auto &func : getModule()) { + for (auto func : getModule()) { func.walk<mlir::gpu::LaunchFuncOp>( [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); }); } @@ -155,66 +155,66 @@ void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { Module &module = getModule(); Builder builder(&module); if (!module.getNamedFunction(cuModuleLoadName)) { - module.getFunctions().push_back( - new Function(loc, cuModuleLoadName, - builder.getFunctionType( - { - getPointerPointerType(), /* CUmodule *module */ - getPointerType() /* void *cubin */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuModuleLoadName, + builder.getFunctionType( + { + getPointerPointerType(), /* CUmodule *module */ + getPointerType() /* void *cubin */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuModuleGetFunctionName)) { // The helper uses void* instead of CUDA's opaque CUmodule and // CUfunction. - module.getFunctions().push_back( - new Function(loc, cuModuleGetFunctionName, - builder.getFunctionType( - { - getPointerPointerType(), /* void **function */ - getPointerType(), /* void *module */ - getPointerType() /* char *name */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuModuleGetFunctionName, + builder.getFunctionType( + { + getPointerPointerType(), /* void **function */ + getPointerType(), /* void *module */ + getPointerType() /* char *name */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuLaunchKernelName)) { // Other than the CUDA api, the wrappers use uintptr_t to match the // LLVM type if MLIR's index type, which the GPU dialect uses. // Furthermore, they use void* instead of CUDA's opaque CUfunction and // CUstream. - module.getFunctions().push_back( - new Function(loc, cuLaunchKernelName, - builder.getFunctionType( - { - getPointerType(), /* void* f */ - getIntPtrType(), /* intptr_t gridXDim */ - getIntPtrType(), /* intptr_t gridyDim */ - getIntPtrType(), /* intptr_t gridZDim */ - getIntPtrType(), /* intptr_t blockXDim */ - getIntPtrType(), /* intptr_t blockYDim */ - getIntPtrType(), /* intptr_t blockZDim */ - getInt32Type(), /* unsigned int sharedMemBytes */ - getPointerType(), /* void *hstream */ - getPointerPointerType(), /* void **kernelParams */ - getPointerPointerType() /* void **extra */ - }, - getCUResultType()))); + module.push_back(Function::create( + loc, cuLaunchKernelName, + builder.getFunctionType( + { + getPointerType(), /* void* f */ + getIntPtrType(), /* intptr_t gridXDim */ + getIntPtrType(), /* intptr_t gridyDim */ + getIntPtrType(), /* intptr_t gridZDim */ + getIntPtrType(), /* intptr_t blockXDim */ + getIntPtrType(), /* intptr_t blockYDim */ + getIntPtrType(), /* intptr_t blockZDim */ + getInt32Type(), /* unsigned int sharedMemBytes */ + getPointerType(), /* void *hstream */ + getPointerPointerType(), /* void **kernelParams */ + getPointerPointerType() /* void **extra */ + }, + getCUResultType()))); } if (!module.getNamedFunction(cuGetStreamHelperName)) { // Helper function to get the current CUDA stream. Uses void* instead of // CUDAs opaque CUstream. - module.getFunctions().push_back(new Function( + module.push_back(Function::create( loc, cuGetStreamHelperName, builder.getFunctionType({}, getPointerType() /* void *stream */))); } if (!module.getNamedFunction(cuStreamSynchronizeName)) { - module.getFunctions().push_back( - new Function(loc, cuStreamSynchronizeName, - builder.getFunctionType( - { - getPointerType() /* CUstream stream */ - }, - getCUResultType()))); + module.push_back( + Function::create(loc, cuStreamSynchronizeName, + builder.getFunctionType( + { + getPointerType() /* CUstream stream */ + }, + getCUResultType()))); } } @@ -264,14 +264,14 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp, // %0[n] = constant name[n] // %0[n+1] = 0 Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( - Function *kernelFunction, Location &loc, OpBuilder &builder) { + Function kernelFunction, Location &loc, OpBuilder &builder) { // TODO(herhut): Make this a constant once this is supported. auto kernelNameSize = builder.create<LLVM::ConstantOp>( loc, getInt32Type(), - builder.getI32IntegerAttr(kernelFunction->getName().size() + 1)); + builder.getI32IntegerAttr(kernelFunction.getName().size() + 1)); auto kernelName = builder.create<LLVM::AllocaOp>(loc, getPointerType(), kernelNameSize); - for (auto byte : llvm::enumerate(kernelFunction->getName())) { + for (auto byte : llvm::enumerate(kernelFunction.getName())) { auto index = builder.create<LLVM::ConstantOp>( loc, getInt32Type(), builder.getI32IntegerAttr(byte.index())); auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName, @@ -284,7 +284,7 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant( // Add trailing zero to terminate string. auto index = builder.create<LLVM::ConstantOp>( loc, getInt32Type(), - builder.getI32IntegerAttr(kernelFunction->getName().size())); + builder.getI32IntegerAttr(kernelFunction.getName().size())); auto gep = builder.create<LLVM::GEPOp>(loc, getPointerType(), kernelName, ArrayRef<Value *>{index}); auto value = builder.create<LLVM::ConstantOp>( @@ -326,9 +326,9 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( // TODO(herhut): This should rather be a static global once supported. auto kernelFunction = getModule().getNamedFunction(launchOp.kernel()); auto cubinGetter = - kernelFunction->getAttrOfType<FunctionAttr>(kCubinGetterAnnotation); + kernelFunction.getAttrOfType<FunctionAttr>(kCubinGetterAnnotation); if (!cubinGetter) { - kernelFunction->emitError("Missing ") + kernelFunction.emitError("Missing ") << kCubinGetterAnnotation << " attribute."; return signalPassFailure(); } @@ -337,7 +337,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( // Emit the load module call to load the module data. Error checking is done // in the called helper function. auto cuModule = allocatePointer(builder, loc); - Function *cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); + Function cuModuleLoad = getModule().getNamedFunction(cuModuleLoadName); builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()}, builder.getFunctionAttr(cuModuleLoad), ArrayRef<Value *>{cuModule, data.getResult(0)}); @@ -347,14 +347,14 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls( builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule); auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder); auto cuFunction = allocatePointer(builder, loc); - Function *cuModuleGetFunction = + Function cuModuleGetFunction = getModule().getNamedFunction(cuModuleGetFunctionName); builder.create<LLVM::CallOp>( loc, ArrayRef<Type>{getCUResultType()}, builder.getFunctionAttr(cuModuleGetFunction), ArrayRef<Value *>{cuFunction, cuModuleRef, kernelName}); // Grab the global stream needed for execution. - Function *cuGetStreamHelper = + Function cuGetStreamHelper = getModule().getNamedFunction(cuGetStreamHelperName); auto cuStream = builder.create<LLVM::CallOp>( loc, ArrayRef<Type>{getPointerType()}, diff --git a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp index c1d4af380ce..97790a5afce 100644 --- a/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp @@ -53,15 +53,15 @@ constexpr const char *kMallocHelperName = "mcuMalloc"; class GpuGenerateCubinAccessorsPass : public ModulePass<GpuGenerateCubinAccessorsPass> { private: - Function *getMallocHelper(Location loc, Builder &builder) { - Function *result = getModule().getNamedFunction(kMallocHelperName); + Function getMallocHelper(Location loc, Builder &builder) { + Function result = getModule().getNamedFunction(kMallocHelperName); if (!result) { - result = new Function( + result = Function::create( loc, kMallocHelperName, builder.getFunctionType( ArrayRef<Type>{LLVM::LLVMType::getInt32Ty(llvmDialect)}, LLVM::LLVMType::getInt8PtrTy(llvmDialect))); - getModule().getFunctions().push_back(result); + getModule().push_back(result); } return result; } @@ -70,18 +70,18 @@ private: // data from blob. As there are currently no global constants, this uses a // sequence of store operations. // TODO(herhut): Use global constants instead. - Function *generateCubinAccessor(Builder &builder, Function &orig, - StringAttr blob) { + Function generateCubinAccessor(Builder &builder, Function &orig, + StringAttr blob) { Location loc = orig.getLoc(); SmallString<128> nameBuffer(orig.getName()); nameBuffer.append(kCubinGetterSuffix); // Generate a function that returns void*. - Function *result = new Function( + Function result = Function::create( loc, mlir::Identifier::get(nameBuffer, &getContext()), builder.getFunctionType(ArrayRef<Type>{}, LLVM::LLVMType::getInt8PtrTy(llvmDialect))); // Insert a body block that just returns the constant. - OpBuilder ob(result->getBody()); + OpBuilder ob(result.getBody()); ob.createBlock(); auto sizeConstant = ob.create<LLVM::ConstantOp>( loc, LLVM::LLVMType::getInt32Ty(llvmDialect), @@ -115,18 +115,18 @@ public: void runOnModule() override { llvmDialect = getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); - Builder builder(getModule().getContext()); + auto &module = getModule(); + Builder builder(&getContext()); - auto &functions = getModule().getFunctions(); + auto functions = module.getFunctions(); for (auto it = functions.begin(); it != functions.end();) { // Move iterator to after the current function so that potential insertion // of the accessor is after the kernel with cubin iself. - Function &orig = *it++; + Function orig = *it++; StringAttr cubinBlob = orig.getAttrOfType<StringAttr>(kCubinAnnotation); if (!cubinBlob) continue; - it = - functions.insert(it, generateCubinAccessor(builder, orig, cubinBlob)); + module.insert(it, generateCubinAccessor(builder, orig, cubinBlob)); } } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 872707842d7..e849f6fd023 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -441,13 +441,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { createIndexConstant(rewriter, op->getLoc(), elementSize)}); // Insert the `malloc` declaration if it is not already present. - Function *mallocFunc = - op->getFunction()->getModule()->getNamedFunction("malloc"); + Function mallocFunc = + op->getFunction().getModule()->getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(getIndexType(), getVoidPtrType()); - mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); - op->getFunction()->getModule()->getFunctions().push_back(mallocFunc); + mallocFunc = + Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + op->getFunction().getModule()->push_back(mallocFunc); } // Allocate the underlying buffer and store a pointer to it in the MemRef @@ -502,12 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { OperandAdaptor<DeallocOp> transformed(operands); // Insert the `free` declaration if it is not already present. - Function *freeFunc = - op->getFunction()->getModule()->getNamedFunction("free"); + Function freeFunc = op->getFunction().getModule()->getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); - freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); - op->getFunction()->getModule()->getFunctions().push_back(freeFunc); + freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + op->getFunction().getModule()->push_back(freeFunc); } auto type = transformed.memref()->getType().cast<LLVM::LLVMType>(); @@ -937,7 +937,7 @@ static void ensureDistinctSuccessors(Block &bb) { } void mlir::LLVM::ensureDistinctSuccessors(Module *m) { - for (auto &f : *m) { + for (auto f : *m) { for (auto &bb : f.getBlocks()) { ::ensureDistinctSuccessors(bb); } diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index ff198217bb7..dafc8e711f5 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -365,7 +365,7 @@ struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> { //===----------------------------------------------------------------------===// void LowerUniformRealMathPass::runOnFunction() { - auto &fn = getFunction(); + auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context)); @@ -386,7 +386,7 @@ static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass( //===----------------------------------------------------------------------===// void LowerUniformCastsPass::runOnFunction() { - auto &fn = getFunction(); + auto fn = getFunction(); OwningRewritePatternList patterns; auto *context = &getContext(); patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context)); diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 9dcc6df6bea..8469fa2ea70 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -106,7 +106,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, void ConvertConstPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context)); applyPatternsGreedily(func, std::move(patterns)); diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index ea8095b791c..0c93146a232 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -95,7 +95,7 @@ public: void ConvertSimulatedQuantPass::runOnFunction() { bool hadFailure = false; OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back( llvm::make_unique<ConstFakeQuantRewrite>(context, &hadFailure)); diff --git a/mlir/lib/ExecutionEngine/MemRefUtils.cpp b/mlir/lib/ExecutionEngine/MemRefUtils.cpp index 51636037382..f13b743de0c 100644 --- a/mlir/lib/ExecutionEngine/MemRefUtils.cpp +++ b/mlir/lib/ExecutionEngine/MemRefUtils.cpp @@ -67,10 +67,10 @@ allocMemRefDescriptor(Type type, bool allocateData = true, } llvm::Expected<SmallVector<void *, 8>> -mlir::allocateMemRefArguments(Function *func, float initialValue) { +mlir::allocateMemRefArguments(Function func, float initialValue) { SmallVector<void *, 8> args; - args.reserve(func->getNumArguments()); - for (const auto &arg : func->getArguments()) { + args.reserve(func.getNumArguments()); + for (const auto &arg : func.getArguments()) { auto descriptor = allocMemRefDescriptor(arg->getType(), /*allocateData=*/true, initialValue); @@ -79,10 +79,10 @@ mlir::allocateMemRefArguments(Function *func, float initialValue) { args.push_back(*descriptor); } - if (func->getType().getNumResults() > 1) + if (func.getType().getNumResults() > 1) return make_string_error("functions with more than 1 result not supported"); - for (Type resType : func->getType().getResults()) { + for (Type resType : func.getType().getResults()) { auto descriptor = allocMemRefDescriptor(resType, /*allocateData=*/false); if (!descriptor) return descriptor.takeError(); diff --git a/mlir/lib/GPU/IR/GPUDialect.cpp b/mlir/lib/GPU/IR/GPUDialect.cpp index e39860bddda..5e8090b42b4 100644 --- a/mlir/lib/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/GPU/IR/GPUDialect.cpp @@ -30,9 +30,9 @@ using namespace mlir::gpu; StringRef GPUDialect::getDialectName() { return "gpu"; } -bool GPUDialect::isKernel(Function *function) { +bool GPUDialect::isKernel(Function function) { UnitAttr isKernelAttr = - function->getAttrOfType<UnitAttr>(getKernelFuncAttrName()); + function.getAttrOfType<UnitAttr>(getKernelFuncAttrName()); return static_cast<bool>(isKernelAttr); } @@ -318,7 +318,7 @@ ParseResult LaunchOp::parse(OpAsmParser *parser, OperationState *result) { //===----------------------------------------------------------------------===// void LaunchFuncOp::build(Builder *builder, OperationState *result, - Function *kernelFunc, Value *gridSizeX, + Function kernelFunc, Value *gridSizeX, Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, Value *blockSizeZ, ArrayRef<Value *> kernelOperands) { @@ -331,7 +331,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result, } void LaunchFuncOp::build(Builder *builder, OperationState *result, - Function *kernelFunc, KernelDim3 gridSize, + Function kernelFunc, KernelDim3 gridSize, KernelDim3 blockSize, ArrayRef<Value *> kernelOperands) { build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, @@ -366,23 +366,23 @@ LogicalResult LaunchFuncOp::verify() { return emitOpError("attribute 'kernel' must be a function"); } - auto *module = getOperation()->getFunction()->getModule(); - Function *kernelFunc = module->getNamedFunction(kernel()); + auto *module = getOperation()->getFunction().getModule(); + Function kernelFunc = module->getNamedFunction(kernel()); if (!kernelFunc) return emitError() << "kernel function '" << kernelAttr << "' is undefined"; - if (!kernelFunc->getAttrOfType<mlir::UnitAttr>( + if (!kernelFunc.getAttrOfType<mlir::UnitAttr>( GPUDialect::getKernelFuncAttrName())) { return emitError("kernel function is missing the '") << GPUDialect::getKernelFuncAttrName() << "' attribute"; } - unsigned numKernelFuncArgs = kernelFunc->getNumArguments(); + unsigned numKernelFuncArgs = kernelFunc.getNumArguments(); if (getNumKernelOperands() != numKernelFuncArgs) { return emitOpError("got ") << getNumKernelOperands() << " kernel operands but expected " << numKernelFuncArgs; } - auto functionType = kernelFunc->getType(); + auto functionType = kernelFunc.getType(); for (unsigned i = 0; i < numKernelFuncArgs; ++i) { if (getKernelOperand(i)->getType() != functionType.getInput(i)) { return emitOpError("type of function argument ") diff --git a/mlir/lib/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/GPU/Transforms/KernelOutlining.cpp index 46363f06f72..f93febcf5da 100644 --- a/mlir/lib/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/GPU/Transforms/KernelOutlining.cpp @@ -40,7 +40,7 @@ static void createForAllDimensions(OpBuilder &builder, Location loc, // Add operations generating block/thread ids and gird/block dimensions at the // beginning of `kernelFunc` and replace uses of the respective function args. -static void injectGpuIndexOperations(Location loc, Function &kernelFunc) { +static void injectGpuIndexOperations(Location loc, Function kernelFunc) { OpBuilder OpBuilder(kernelFunc.getBody()); SmallVector<Value *, 12> indexOps; createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps); @@ -58,20 +58,20 @@ static void injectGpuIndexOperations(Location loc, Function &kernelFunc) { // Outline the `gpu.launch` operation body into a kernel function. Replace // `gpu.return` operations by `std.return` in the generated functions. -static Function *outlineKernelFunc(gpu::LaunchOp launchOp) { +static Function outlineKernelFunc(gpu::LaunchOp launchOp) { Location loc = launchOp.getLoc(); SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes()); FunctionType type = FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); std::string kernelFuncName = - Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str(); - Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type); - outlinedFunc->getBody().takeBody(launchOp.getBody()); + Twine(launchOp.getOperation()->getFunction().getName(), "_kernel").str(); + Function outlinedFunc = Function::create(loc, kernelFuncName, type); + outlinedFunc.getBody().takeBody(launchOp.getBody()); Builder builder(launchOp.getContext()); - outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), - builder.getUnitAttr()); - injectGpuIndexOperations(loc, *outlinedFunc); - outlinedFunc->walk<mlir::gpu::Return>([](mlir::gpu::Return op) { + outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + injectGpuIndexOperations(loc, outlinedFunc); + outlinedFunc.walk<mlir::gpu::Return>([](mlir::gpu::Return op) { OpBuilder replacer(op); replacer.create<ReturnOp>(op.getLoc()); op.erase(); @@ -82,12 +82,12 @@ static Function *outlineKernelFunc(gpu::LaunchOp launchOp) { // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching // `kernelFunc`. static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, - Function &kernelFunc) { + Function kernelFunc) { OpBuilder builder(launchOp); SmallVector<Value *, 4> kernelOperandValues( launchOp.getKernelOperandValues()); builder.create<gpu::LaunchFuncOp>( - launchOp.getLoc(), &kernelFunc, launchOp.getGridSizeOperandValues(), + launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(), launchOp.getBlockSizeOperandValues(), kernelOperandValues); launchOp.erase(); } @@ -98,11 +98,11 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> { public: void runOnModule() override { ModuleManager moduleManager(&getModule()); - for (auto &func : getModule()) { + for (auto func : getModule()) { func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) { - Function *outlinedFunc = outlineKernelFunc(op); + Function outlinedFunc = outlineKernelFunc(op); moduleManager.insert(outlinedFunc); - convertToLaunchFuncOp(op, *outlinedFunc); + convertToLaunchFuncOp(op, outlinedFunc); }); } } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 8e3d5788bb1..346d35af231 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -306,7 +306,7 @@ void ModuleState::initialize(Module *module) { initializeSymbolAliases(); // Walk the module and visit each operation. - for (auto &fn : *module) { + for (auto fn : *module) { visitType(fn.getType()); for (auto attr : fn.getAttrs()) ModuleState::visitAttribute(attr.second); @@ -342,7 +342,7 @@ public: void printAttribute(Attribute attr, bool mayElideType = false); void printType(Type type); - void print(Function *fn); + void print(Function fn); void printLocation(LocationAttr loc); void printAffineMap(AffineMap map); @@ -460,8 +460,8 @@ void ModulePrinter::print(Module *module) { state.printTypeAliases(os); // Print the module. - for (auto &fn : *module) - print(&fn); + for (auto fn : *module) + print(fn); } /// Print a floating point value in a way that the parser will be able to @@ -1186,7 +1186,7 @@ namespace { // CFG and ML functions. class FunctionPrinter : public ModulePrinter, private OpAsmPrinter { public: - FunctionPrinter(Function *function, ModulePrinter &other); + FunctionPrinter(Function function, ModulePrinter &other); // Prints the function as a whole. void print(); @@ -1275,7 +1275,7 @@ protected: void printValueID(Value *value, bool printResultNo = true) const; private: - Function *function; + Function function; /// This is the value ID for each SSA value in the current function. If this /// returns ~0, then the valueID has an entry in valueNames. @@ -1305,10 +1305,10 @@ private: }; } // end anonymous namespace -FunctionPrinter::FunctionPrinter(Function *function, ModulePrinter &other) +FunctionPrinter::FunctionPrinter(Function function, ModulePrinter &other) : ModulePrinter(other), function(function) { - for (auto &block : *function) + for (auto &block : function) numberValuesInBlock(block); } @@ -1419,17 +1419,17 @@ void FunctionPrinter::print() { printFunctionSignature(); // Print out function attributes, if present. - auto attrs = function->getAttrs(); + auto attrs = function.getAttrs(); if (!attrs.empty()) { os << "\n attributes "; printOptionalAttrDict(attrs); } // Print the trailing location. - printTrailingLocation(function->getLoc()); + printTrailingLocation(function.getLoc()); - if (!function->empty()) { - printRegion(function->getBody(), /*printEntryBlockArgs=*/false, + if (!function.empty()) { + printRegion(function.getBody(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); os << "\n"; } @@ -1437,24 +1437,24 @@ void FunctionPrinter::print() { } void FunctionPrinter::printFunctionSignature() { - os << "func @" << function->getName() << '('; + os << "func @" << function.getName() << '('; - auto fnType = function->getType(); - bool isExternal = function->isExternal(); - for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) { + auto fnType = function.getType(); + bool isExternal = function.isExternal(); + for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) { if (i > 0) os << ", "; // If this is an external function, don't print argument labels. if (!isExternal) { - printOperand(function->getArgument(i)); + printOperand(function.getArgument(i)); os << ": "; } printType(fnType.getInput(i)); // Print the attributes for this argument. - printOptionalAttrDict(function->getArgAttrs(i)); + printOptionalAttrDict(function.getArgAttrs(i)); } os << ')'; @@ -1662,7 +1662,7 @@ void FunctionPrinter::printSuccessorAndUseList(Operation *term, } // Prints function with initialized module state. -void ModulePrinter::print(Function *fn) { FunctionPrinter(fn, *this).print(); } +void ModulePrinter::print(Function fn) { FunctionPrinter(fn, *this).print(); } //===----------------------------------------------------------------------===// // print and dump methods @@ -1737,13 +1737,13 @@ void Value::print(raw_ostream &os) { void Value::dump() { print(llvm::errs()); } void Operation::print(raw_ostream &os) { - auto *function = getFunction(); + auto function = getFunction(); if (!function) { os << "<<UNLINKED INSTRUCTION>>\n"; return; } - ModuleState state(function->getContext()); + ModuleState state(function.getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(function, modulePrinter).print(this); } @@ -1754,13 +1754,13 @@ void Operation::dump() { } void Block::print(raw_ostream &os) { - auto *function = getFunction(); + auto function = getFunction(); if (!function) { os << "<<UNLINKED BLOCK>>\n"; return; } - ModuleState state(function->getContext()); + ModuleState state(function.getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(function, modulePrinter).print(this); } @@ -1773,14 +1773,14 @@ void Block::printAsOperand(raw_ostream &os, bool printType) { os << "<<UNLINKED BLOCK>>\n"; return; } - ModuleState state(getFunction()->getContext()); + ModuleState state(getFunction().getContext()); ModulePrinter modulePrinter(os, state); FunctionPrinter(getFunction(), modulePrinter).printBlockName(this); } void Function::print(raw_ostream &os) { ModuleState state(getContext()); - ModulePrinter(os, state).print(this); + ModulePrinter(os, state).print(*this); } void Function::dump() { print(llvm::errs()); } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 01f9a060bd9..9cbba0fe429 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -249,11 +249,6 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc, // FunctionAttr //===----------------------------------------------------------------------===// -FunctionAttr FunctionAttr::get(Function *value) { - assert(value && "Cannot get FunctionAttr for a null function"); - return get(value->getName(), value->getContext()); -} - FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) { return Base::get(ctx, StandardAttributes::Function, value, NoneType::get(ctx)); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index e7616f6d7d0..134f6e468a0 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -50,7 +50,7 @@ Operation *Block::getContainingOp() { return getParent() ? getParent()->getContainingOp() : nullptr; } -Function *Block::getFunction() { +Function Block::getFunction() { Block *block = this; while (auto *op = block->getContainingOp()) { block = op->getBlock(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 9b30205abdb..89df64260d3 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -177,8 +177,8 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); } -FunctionAttr Builder::getFunctionAttr(Function *value) { - return FunctionAttr::get(value); +FunctionAttr Builder::getFunctionAttr(Function value) { + return getFunctionAttr(value.getName()); } FunctionAttr Builder::getFunctionAttr(StringRef value) { return FunctionAttr::get(value, getContext()); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 4547452eb55..e38b95ff0f7 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectHooks.h" +#include "mlir/IR/Function.h" #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/ManagedStatic.h" @@ -68,6 +69,20 @@ Dialect::Dialect(StringRef name, MLIRContext *context) Dialect::~Dialect() {} +/// Verify an attribute from this dialect on the given function. Returns +/// failure if the verification failed, success otherwise. +LogicalResult Dialect::verifyFunctionAttribute(Function, NamedAttribute) { + return success(); +} + +/// Verify an attribute from this dialect on the argument at 'argIndex' for +/// the given function. Returns failure if the verification failed, success +/// otherwise. +LogicalResult Dialect::verifyFunctionArgAttribute(Function, unsigned argIndex, + NamedAttribute) { + return success(); +} + /// Parse an attribute registered to this dialect. Attribute Dialect::parseAttribute(StringRef attrData, Location loc) const { emitError(loc) << "dialect '" << getNamespace() diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 7d17ed1d705..f8835f02c26 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -27,45 +27,50 @@ #include "llvm/ADT/Twine.h" using namespace mlir; +using namespace mlir::detail; -Function::Function(Location location, StringRef name, FunctionType type, - ArrayRef<NamedAttribute> attrs) +FunctionStorage::FunctionStorage(Location location, StringRef name, + FunctionType type, + ArrayRef<NamedAttribute> attrs) : name(Identifier::get(name, type.getContext())), location(location), type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {} -Function::Function(Location location, StringRef name, FunctionType type, - ArrayRef<NamedAttribute> attrs, - ArrayRef<NamedAttributeList> argAttrs) +FunctionStorage::FunctionStorage(Location location, StringRef name, + FunctionType type, + ArrayRef<NamedAttribute> attrs, + ArrayRef<NamedAttributeList> argAttrs) : name(Identifier::get(name, type.getContext())), location(location), type(type), attrs(attrs), argAttrs(argAttrs), body(this) {} MLIRContext *Function::getContext() { return getType().getContext(); } -Module *llvm::ilist_traits<Function>::getContainingModule() { +Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() { size_t Offset( size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr)))); - iplist<Function> *Anchor(static_cast<iplist<Function> *>(this)); + iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this)); return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset); } /// This is a trait method invoked when a Function is added to a Module. We /// keep the module pointer and module symbol table up to date. -void llvm::ilist_traits<Function>::addNodeToList(Function *function) { - assert(!function->getModule() && "already in a module!"); +void llvm::ilist_traits<FunctionStorage>::addNodeToList( + FunctionStorage *function) { + assert(!function->module && "already in a module!"); function->module = getContainingModule(); } /// This is a trait method invoked when a Function is removed from a Module. /// We keep the module pointer up to date. -void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) { +void llvm::ilist_traits<FunctionStorage>::removeNodeFromList( + FunctionStorage *function) { assert(function->module && "not already in a module!"); function->module = nullptr; } /// This is a trait method invoked when an operation is moved from one block /// to another. We keep the block pointer up to date. -void llvm::ilist_traits<Function>::transferNodesFromList( - ilist_traits<Function> &otherList, function_iterator first, +void llvm::ilist_traits<FunctionStorage>::transferNodesFromList( + ilist_traits<FunctionStorage> &otherList, function_iterator first, function_iterator last) { // If we are transferring functions within the same module, the Module // pointer doesn't need to be updated. @@ -82,8 +87,10 @@ void llvm::ilist_traits<Function>::transferNodesFromList( /// Unlink this function from its Module and delete it. void Function::erase() { - assert(getModule() && "Function has no parent"); - getModule()->getFunctions().erase(this); + if (auto *module = getModule()) + getModule()->functions.erase(impl); + else + delete impl; } /// Emit an error about fatal conditions with this function, reporting up to @@ -111,10 +118,10 @@ InFlightDiagnostic Function::emitRemark(const Twine &message) { /// Clone the internal blocks from this function into dest and all attributes /// from this function to dest. -void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { +void Function::cloneInto(Function dest, BlockAndValueMapping &mapper) { // Add the attributes of this function to dest. llvm::MapVector<Identifier, Attribute> newAttrs; - for (auto &attr : dest->getAttrs()) + for (auto &attr : dest.getAttrs()) newAttrs.insert(attr); for (auto &attr : getAttrs()) { auto insertPair = newAttrs.insert(attr); @@ -125,10 +132,10 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { assert((insertPair.second || insertPair.first->second == attr.second) && "the two functions have incompatible attributes"); } - dest->setAttrs(newAttrs.takeVector()); + dest.setAttrs(newAttrs.takeVector()); // Clone the body. - body.cloneInto(&dest->body, mapper); + impl->body.cloneInto(&dest.impl->body, mapper); } /// Create a deep copy of this function and all of its blocks, remapping @@ -136,8 +143,8 @@ void Function::cloneInto(Function *dest, BlockAndValueMapping &mapper) { /// provided (leaving them alone if no entry is present). Replaces references /// to cloned sub-values with the corresponding value that is copied, and adds /// those mappings to the mapper. -Function *Function::clone(BlockAndValueMapping &mapper) { - FunctionType newType = type; +Function Function::clone(BlockAndValueMapping &mapper) { + FunctionType newType = impl->type; // If the function has a body, then the user might be deleting arguments to // the function by specifying them in the mapper. If so, we don't add the @@ -147,23 +154,23 @@ Function *Function::clone(BlockAndValueMapping &mapper) { SmallVector<Type, 4> inputTypes; for (unsigned i = 0, e = getNumArguments(); i != e; ++i) if (!mapper.contains(getArgument(i))) - inputTypes.push_back(type.getInput(i)); - newType = FunctionType::get(inputTypes, type.getResults(), getContext()); + inputTypes.push_back(newType.getInput(i)); + newType = FunctionType::get(inputTypes, newType.getResults(), getContext()); } // Create the new function. - Function *newFunc = new Function(getLoc(), getName(), newType); + Function newFunc = Function::create(getLoc(), getName(), newType); /// Set the argument attributes for arguments that aren't being replaced. for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i) if (isExternalFn || !mapper.contains(getArgument(i))) - newFunc->setArgAttrs(destI++, getArgAttrs(i)); + newFunc.setArgAttrs(destI++, getArgAttrs(i)); /// Clone the current function into the new one and return it. cloneInto(newFunc, mapper); return newFunc; } -Function *Function::clone() { +Function Function::clone() { BlockAndValueMapping mapper; return clone(mapper); } @@ -178,7 +185,7 @@ void Function::addEntryBlock() { assert(empty() && "function already has an entry block"); auto *entry = new Block(); push_back(entry); - entry->addArguments(type.getInputs()); + entry->addArguments(impl->type.getInputs()); } void Function::walk(const std::function<void(Operation *)> &callback) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 83171f12d1d..f953cd27a56 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ Operation *Operation::getParentOp() { return block ? block->getContainingOp() : nullptr; } -Function *Operation::getFunction() { +Function Operation::getFunction() { return block ? block->getFunction() : nullptr; } @@ -861,12 +861,13 @@ static LogicalResult verifyBBArguments(Operation::operand_range operands, } static LogicalResult verifyTerminatorSuccessors(Operation *op) { + auto *parent = op->getContainingRegion(); + // Verify that the operands lines up with the BB arguments in the successor. - Function *fn = op->getFunction(); for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { auto *succ = op->getSuccessor(i); - if (succ->getFunction() != fn) - return op->emitError("reference to block defined in another function"); + if (succ->getParent() != parent) + return op->emitError("reference to block defined in another region"); if (failed(verifyBBArguments(op->getSuccessorOperands(i), succ, op))) return failure(); } diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 992d9112beb..74c71b7aeac 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -21,7 +21,7 @@ #include "mlir/IR/Operation.h" using namespace mlir; -Region::Region(Function *container) : container(container) {} +Region::Region(Function container) : container(container.impl) {} Region::Region(Operation *container) : container(container) {} @@ -38,7 +38,7 @@ MLIRContext *Region::getContext() { assert(!container.isNull() && "region is not attached to a container"); if (auto *inst = getContainingOp()) return inst->getContext(); - return getContainingFunction()->getContext(); + return getContainingFunction().getContext(); } /// Return a location for this region. This is the location attached to the @@ -47,7 +47,7 @@ Location Region::getLoc() { assert(!container.isNull() && "region is not attached to a container"); if (auto *inst = getContainingOp()) return inst->getLoc(); - return getContainingFunction()->getLoc(); + return getContainingFunction().getLoc(); } Region *Region::getContainingRegion() { @@ -60,8 +60,8 @@ Operation *Region::getContainingOp() { return container.dyn_cast<Operation *>(); } -Function *Region::getContainingFunction() { - return container.dyn_cast<Function *>(); +Function Region::getContainingFunction() { + return container.dyn_cast<detail::FunctionStorage *>(); } bool Region::isProperAncestor(Region *other) { diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index a0819a78fc1..dafbd48f513 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -22,8 +22,8 @@ using namespace mlir; /// Build a symbol table with the symbols within the given module. SymbolTable::SymbolTable(Module *module) : context(module->getContext()) { - for (auto &func : *module) { - auto inserted = symbolTable.insert({func.getName(), &func}); + for (auto func : *module) { + auto inserted = symbolTable.insert({func.getName(), func}); (void)inserted; assert(inserted.second && "expected module to contain uniquely named functions"); @@ -32,34 +32,34 @@ SymbolTable::SymbolTable(Module *module) : context(module->getContext()) { /// Look up a symbol with the specified name, returning null if no such name /// exists. Names never include the @ on them. -Function *SymbolTable::lookup(StringRef name) const { +Function SymbolTable::lookup(StringRef name) const { return lookup(Identifier::get(name, context)); } /// Look up a symbol with the specified name, returning null if no such name /// exists. Names never include the @ on them. -Function *SymbolTable::lookup(Identifier name) const { +Function SymbolTable::lookup(Identifier name) const { return symbolTable.lookup(name); } /// Erase the given symbol from the table. -void SymbolTable::erase(Function *symbol) { - auto it = symbolTable.find(symbol->getName()); +void SymbolTable::erase(Function symbol) { + auto it = symbolTable.find(symbol.getName()); if (it != symbolTable.end() && it->second == symbol) symbolTable.erase(it); } /// Insert a new symbol into the table, and rename it as necessary to avoid /// collisions. -void SymbolTable::insert(Function *symbol) { +void SymbolTable::insert(Function symbol) { // Add this symbol to the symbol table, uniquing the name if a conflict is // detected. - if (symbolTable.insert({symbol->getName(), symbol}).second) + if (symbolTable.insert({symbol.getName(), symbol}).second) return; // If a conflict was detected, then the function will not have been added to // the symbol table. Try suffixes until we get to a unique name that works. - SmallString<128> nameBuffer(symbol->getName()); + SmallString<128> nameBuffer(symbol.getName()); unsigned originalLength = nameBuffer.size(); // Iteratively try suffixes until we find one that isn't used. We use a @@ -68,6 +68,6 @@ void SymbolTable::insert(Function *symbol) { nameBuffer.resize(originalLength); nameBuffer += '_'; nameBuffer += std::to_string(uniquingCounter++); - symbol->setName(Identifier::get(nameBuffer, context)); - } while (!symbolTable.insert({symbol->getName(), symbol}).second); + symbol.setName(Identifier::get(nameBuffer, context)); + } while (!symbolTable.insert({symbol.getName(), symbol}).second); } diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 073c3b369c6..65a98f7ee59 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -30,7 +30,7 @@ Operation *Value::getDefiningOp() { } /// Return the function that this Value is defined in. -Function *Value::getFunction() { +Function Value::getFunction() { switch (getKind()) { case Value::Kind::BlockArgument: return cast<BlockArgument>(this)->getFunction(); @@ -84,7 +84,7 @@ void IRObjectWithUseList::dropAllUses() { //===----------------------------------------------------------------------===// /// Return the function that this argument is defined in. -Function *BlockArgument::getFunction() { +Function BlockArgument::getFunction() { if (auto *owner = getOwner()) return owner->getFunction(); return nullptr; @@ -92,6 +92,6 @@ Function *BlockArgument::getFunction() { /// Returns if the current argument is a function argument. bool BlockArgument::isFunctionArgument() { - auto *containingFn = getFunction(); - return containingFn && &containingFn->front() == getOwner(); + auto containingFn = getFunction(); + return containingFn && &containingFn.front() == getOwner(); } diff --git a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp index 0d3a5ca2756..0dbf63a3ce7 100644 --- a/mlir/lib/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/LLVMIR/IR/LLVMDialect.cpp @@ -816,12 +816,12 @@ void LLVMDialect::printType(Type type, raw_ostream &os) const { } /// Verify LLVMIR function argument attributes. -LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function *func, +LogicalResult LLVMDialect::verifyFunctionArgAttribute(Function func, unsigned argIdx, NamedAttribute argAttr) { // Check that llvm.noalias is a boolean attribute. if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>()) - return func->emitError() + return func.emitError() << "llvm.noalias argument attribute of non boolean type"; return success(); } diff --git a/mlir/lib/Linalg/Transforms/Fusion.cpp b/mlir/lib/Linalg/Transforms/Fusion.cpp index 7ddb7b0c19f..5761cc637b7 100644 --- a/mlir/lib/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Linalg/Transforms/Fusion.cpp @@ -209,7 +209,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView, return true; } -static void fuseLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) { +static void fuseLinalgOps(Function f, ArrayRef<int64_t> tileSizes) { OperationFolder state; DenseSet<Operation *> eraseSet; diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index a8099aaff99..5fe4f07613a 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -170,12 +170,13 @@ public: LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. - auto *module = op->getFunction()->getModule(); - Function *mallocFunc = module->getNamedFunction("malloc"); + auto *module = op->getFunction().getModule(); + Function mallocFunc = module->getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); - mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); - module->getFunctions().push_back(mallocFunc); + mallocFunc = + Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + module->push_back(mallocFunc); } // Get MLIR types for injecting element pointer. @@ -230,12 +231,12 @@ public: auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. - auto *module = op->getFunction()->getModule(); - Function *freeFunc = module->getNamedFunction("free"); + auto *module = op->getFunction().getModule(); + Function freeFunc = module->getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(voidPtrTy, {}); - freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); - module->getFunctions().push_back(freeFunc); + freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + module->push_back(freeFunc); } // Get MLIR types for extracting element pointer. @@ -572,37 +573,37 @@ public: // Create a function definition which takes as argument pointers to the input // types and returns pointers to the output types. -static Function *getLLVMLibraryCallImplDefinition(Function *libFn) { - auto implFnName = (libFn->getName().str() + "_impl"); - auto module = libFn->getModule(); - if (auto *f = module->getNamedFunction(implFnName)) { +static Function getLLVMLibraryCallImplDefinition(Function libFn) { + auto implFnName = (libFn.getName().str() + "_impl"); + auto module = libFn.getModule(); + if (auto f = module->getNamedFunction(implFnName)) { return f; } SmallVector<Type, 4> fnArgTypes; - for (auto t : libFn->getType().getInputs()) { + for (auto t : libFn.getType().getInputs()) { assert(t.isa<LLVMType>() && "Expected LLVM Type for argument while generating library Call " "Implementation Definition"); fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo()); } - auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext()); + auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext()); // Insert the implementation function definition. - auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType); - module->getFunctions().push_back(implFnDefn); + auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType); + module->push_back(implFnDefn); return implFnDefn; } // Get function definition for the LinalgOp. If it doesn't exist, insert a // definition. template <typename LinalgOp> -static Function *getLLVMLibraryCallDeclaration(Operation *op, - LLVMTypeConverter &lowering, - PatternRewriter &rewriter) { +static Function getLLVMLibraryCallDeclaration(Operation *op, + LLVMTypeConverter &lowering, + PatternRewriter &rewriter) { assert(isa<LinalgOp>(op)); auto fnName = LinalgOp::getLibraryCallName(); - auto module = op->getFunction()->getModule(); - if (auto *f = module->getNamedFunction(fnName)) { + auto module = op->getFunction().getModule(); + if (auto f = module->getNamedFunction(fnName)) { return f; } @@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op, "Library call for linalg operation can be generated only for ops that " "have void return types"); auto libFnType = FunctionType::get(inputTypes, {}, op->getContext()); - auto libFn = new Function(op->getLoc(), fnName, libFnType); - module->getFunctions().push_back(libFn); + auto libFn = Function::create(op->getLoc(), fnName, libFnType); + module->push_back(libFn); // Return after creating the function definition. The body will be created // later. return libFn; } -static void getLLVMLibraryCallDefinition(Function *fn, +static void getLLVMLibraryCallDefinition(Function fn, LLVMTypeConverter &lowering) { // Generate the implementation function definition. auto implFn = getLLVMLibraryCallImplDefinition(fn); // Generate the function body. - fn->addEntryBlock(); + fn.addEntryBlock(); - OpBuilder builder(fn->getBody()); - edsc::ScopedContext scope(builder, fn->getLoc()); + OpBuilder builder(fn.getBody()); + edsc::ScopedContext scope(builder, fn.getLoc()); SmallVector<Value *, 4> implFnArgs; // Create a constant 1. auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()), - IntegerAttr::get(IndexType::get(fn->getContext()), 1)); - for (auto arg : fn->getArguments()) { + IntegerAttr::get(IndexType::get(fn.getContext()), 1)); + for (auto arg : fn.getArguments()) { // Allocate a stack for storing the argument value. The stack is passed to // the implementation function. auto alloca = @@ -665,17 +666,17 @@ public: return convertLinalgType(t, *this); } - void addLibraryFnDeclaration(Function *fn) { + void addLibraryFnDeclaration(Function fn) { libraryFnDeclarations.push_back(fn); } - ArrayRef<Function *> getLibraryFnDeclarations() { + ArrayRef<Function> getLibraryFnDeclarations() { return libraryFnDeclarations; } private: /// List of library functions declarations needed during dialect conversion - SmallVector<Function *, 2> libraryFnDeclarations; + SmallVector<Function, 2> libraryFnDeclarations; }; } // end anonymous namespace @@ -692,7 +693,7 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override { // Only emit library call declaration. Fill in the body later. - auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter); + auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter); static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f); auto fAttr = rewriter.getFunctionAttr(f); @@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) { void LowerLinalgToLLVMPass::runOnModule() { auto &module = getModule(); - for (auto &f : module.getFunctions()) { + for (auto f : module.getFunctions()) { lowerLinalgSubViewOps(f); lowerLinalgForToCFG(f); if (failed(lowerAffineConstructs(f))) diff --git a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp index d31ba5bf22d..2e616c35f1d 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLoops.cpp @@ -104,9 +104,8 @@ struct LowerLinalgToLoopsPass : public FunctionPass<LowerLinalgToLoopsPass> { } // namespace void LowerLinalgToLoopsPass::runOnFunction() { - auto &f = getFunction(); OperationFolder state; - f.walk<LinalgOp>([&state](LinalgOp linalgOp) { + getFunction().walk<LinalgOp>([&state](LinalgOp linalgOp) { emitLinalgOpAsLoops(linalgOp, state); linalgOp.getOperation()->erase(); }); diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index c63e1cf197d..2f752b2b637 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -259,7 +259,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<int64_t> tileSizes, return tileLinalgOp(op, tileSizeValues, state); } -static void tileLinalgOps(Function &f, ArrayRef<int64_t> tileSizes) { +static void tileLinalgOps(Function f, ArrayRef<int64_t> tileSizes) { OperationFolder state; f.walk<LinalgOp>([tileSizes, &state](LinalgOp op) { auto opLoopsPair = tileLinalgOp(op, tileSizes, state); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 44f05963727..4af2f093daf 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -254,7 +254,7 @@ public: /// trailing-location ::= location? /// template <typename Owner> - ParseResult parseOptionalTrailingLocation(Owner *owner) { + ParseResult parseOptionalTrailingLocation(Owner &owner) { // If there is a 'loc' we parse a trailing location. if (!getToken().is(Token::kw_loc)) return success(); @@ -263,7 +263,7 @@ public: LocationAttr directLoc; if (parseLocation(directLoc)) return failure(); - owner->setLoc(directLoc); + owner.setLoc(directLoc); return success(); } @@ -2472,8 +2472,8 @@ namespace { /// operations. class OperationParser : public Parser { public: - OperationParser(ParserState &state, Function *function) - : Parser(state), function(function), opBuilder(function->getBody()) {} + OperationParser(ParserState &state, Function function) + : Parser(state), function(function), opBuilder(function.getBody()) {} ~OperationParser(); @@ -2588,7 +2588,7 @@ public: Block *defineBlockNamed(StringRef name, SMLoc loc, Block *existing); private: - Function *function; + Function function; /// Returns the info for a block at the current scope for the given name. std::pair<Block *, SMLoc> &getBlockInfoByName(StringRef name) { @@ -2690,7 +2690,7 @@ ParseResult OperationParser::popSSANameScope() { for (auto entry : forwardRefInCurrentScope) { errors.push_back({entry.second.getPointer(), entry.first}); // Add this block to the top-level region to allow for automatic cleanup. - function->push_back(entry.first); + function.push_back(entry.first); } llvm::array_pod_sort(errors.begin(), errors.end()); @@ -2984,7 +2984,7 @@ ParseResult OperationParser::parseOperation() { } // Try to parse the optional trailing location. - if (parseOptionalTrailingLocation(op)) + if (parseOptionalTrailingLocation(*op)) return failure(); return success(); @@ -4049,17 +4049,17 @@ ParseResult ModuleParser::parseFunc(Module *module) { } // Okay, the function signature was parsed correctly, create the function now. - auto *function = - new Function(getEncodedSourceLocation(loc), name, type, attrs); - module->getFunctions().push_back(function); + auto function = + Function::create(getEncodedSourceLocation(loc), name, type, attrs); + module->push_back(function); // Parse an optional trailing location. if (parseOptionalTrailingLocation(function)) return failure(); // Add the attributes to the function arguments. - for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) - function->setArgAttrs(i, argAttrs[i]); + for (unsigned i = 0, e = function.getNumArguments(); i != e; ++i) + function.setArgAttrs(i, argAttrs[i]); // External functions have no body. if (getToken().isNot(Token::l_brace)) @@ -4076,11 +4076,11 @@ ParseResult ModuleParser::parseFunc(Module *module) { // Parse the function body. auto parser = OperationParser(getState(), function); - if (parser.parseRegion(function->getBody(), entryArgs)) + if (parser.parseRegion(function.getBody(), entryArgs)) return failure(); // Verify that a valid function body was parsed. - if (function->empty()) + if (function.empty()) return emitError(braceLoc, "function must have a body"); return parser.finalize(braceLoc); diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 868d492e094..057f2655207 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -61,12 +61,12 @@ private: static void printIR(const llvm::Any &ir, bool printModuleScope, raw_ostream &out) { // Check for printing at module scope. - if (printModuleScope && llvm::any_isa<Function *>(ir)) { - Function *function = llvm::any_cast<Function *>(ir); + if (printModuleScope && llvm::any_isa<Function>(ir)) { + Function function = llvm::any_cast<Function>(ir); // Print the function name and a newline before the Module. - out << " (function: " << function->getName() << ")\n"; - function->getModule()->print(out); + out << " (function: " << function.getName() << ")\n"; + function.getModule()->print(out); return; } @@ -74,8 +74,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope, out << "\n"; // Print the given function. - if (llvm::any_isa<Function *>(ir)) { - llvm::any_cast<Function *>(ir)->print(out); + if (llvm::any_isa<Function>(ir)) { + llvm::any_cast<Function>(ir).print(out); return; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 2f605b6690b..27ec74c23c2 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -46,8 +46,7 @@ static llvm::cl::opt<bool> void Pass::anchor() {} /// Forwarding function to execute this pass. -LogicalResult FunctionPassBase::run(Function *fn, - FunctionAnalysisManager &fam) { +LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) { // Initialize the pass state. passState.emplace(fn, fam); @@ -115,7 +114,7 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs) } /// Run all of the passes in this manager over the current function. -LogicalResult detail::FunctionPassExecutor::run(Function *function, +LogicalResult detail::FunctionPassExecutor::run(Function function, FunctionAnalysisManager &fam) { // Run each of the held passes. for (auto &pass : passes) @@ -141,7 +140,7 @@ LogicalResult detail::ModulePassExecutor::run(Module *module, /// Utility to run the given function and analysis manager on a provided /// function pass executor. static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, - Function *func, + Function func, FunctionAnalysisManager &fam) { // Run the function pipeline over the provided function. auto result = fpe.run(func, fam); @@ -158,14 +157,14 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, /// module. void ModuleToFunctionPassAdaptor::runOnModule() { ModuleAnalysisManager &mam = getAnalysisManager(); - for (auto &func : getModule()) { + for (auto func : getModule()) { // Skip external functions. if (func.isExternal()) continue; // Run the held function pipeline over the current function. - auto fam = mam.slice(&func); - if (failed(runFunctionPipeline(fpe, &func, fam))) + auto fam = mam.slice(func); + if (failed(runFunctionPipeline(fpe, func, fam))) return signalPassFailure(); // Clear out any computed function analyses. These analyses won't be used @@ -189,10 +188,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() { // Run a prepass over the module to collect the functions to execute a over. // This ensures that an analysis manager exists for each function, as well as // providing a queue of functions to execute over. - std::vector<std::pair<Function *, FunctionAnalysisManager>> funcAMPairs; - for (auto &func : getModule()) + std::vector<std::pair<Function, FunctionAnalysisManager>> funcAMPairs; + for (auto func : getModule()) if (!func.isExternal()) - funcAMPairs.emplace_back(&func, mam.slice(&func)); + funcAMPairs.emplace_back(func, mam.slice(func)); // A parallel diagnostic handler that provides deterministic diagnostic // ordering. @@ -340,8 +339,8 @@ PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const { } /// Create an analysis slice for the given child function. -FunctionAnalysisManager ModuleAnalysisManager::slice(Function *func) { - assert(func->getModule() == moduleAnalyses.getIRUnit() && +FunctionAnalysisManager ModuleAnalysisManager::slice(Function func) { + assert(func.getModule() == moduleAnalyses.getIRUnit() && "function has a different parent module"); auto it = functionAnalyses.find(func); if (it == functionAnalyses.end()) { diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index 46addfb8e9c..d2563fb62cd 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -48,7 +48,7 @@ public: FunctionPassExecutor(const FunctionPassExecutor &rhs); /// Run the executor on the given function. - LogicalResult run(Function *function, FunctionAnalysisManager &fam); + LogicalResult run(Function function, FunctionAnalysisManager &fam); /// Add a pass to the current executor. This takes ownership over the provided /// pass pointer. diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index 375a64d8f2d..3f26bf075af 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -71,7 +71,7 @@ void AddDefaultStatsPass::runOnFunction() { void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, const TargetConfiguration &config) { - auto &func = getFunction(); + auto func = getFunction(); // Insert stats for each argument. for (auto *arg : func.getArguments()) { diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp index dec4ea90db8..169fec3b39a 100644 --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -129,7 +129,7 @@ void InferQuantizedTypesPass::runOnModule() { void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext, const TargetConfiguration &config) { CAGSlice cag(solverContext); - for (auto &f : getModule()) { + for (auto f : getModule()) { f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); }); } config.finalizeAnchors(cag); diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp index ed3b0956a16..6b376db8516 100644 --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -58,7 +58,7 @@ public: void RemoveInstrumentationPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); auto *context = &getContext(); patterns.push_back( llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context)); diff --git a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp index 3add211fdd5..543b7300af0 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp @@ -36,11 +36,11 @@ using namespace mlir; // block. The created block will be terminated by `std.return`. Block *createOneBlockFunction(Builder builder, Module *module) { auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{}); - auto *fn = new Function(builder.getUnknownLoc(), "spirv_module", fnType); - module->getFunctions().push_back(fn); + auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType); + module->push_back(fn); auto *block = new Block(); - fn->push_back(block); + fn.push_back(block); OperationState state(builder.getUnknownLoc(), ReturnOp::getOperationName()); ReturnOp::build(&builder, &state); diff --git a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp index ebdcaf73717..33572d5adbe 100644 --- a/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp +++ b/mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp @@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) { // wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed // to take in the SPIR-V ModuleOp directly after module and function are // migrated to be general ops. - for (auto &fn : *module) { + for (auto fn : *module) { fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) { if (done) { spirvModule.emitError("found more than one 'spv.module' op"); diff --git a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp index 1a8d79c1790..1ce2b69f055 100644 --- a/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp +++ b/mlir/lib/SPIRV/Transforms/StdOpsToSPIRVConversion.cpp @@ -42,7 +42,7 @@ class StdOpsToSPIRVConversionPass void StdOpsToSPIRVConversionPass::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); populateWithGenerated(func.getContext(), &patterns); applyPatternsGreedily(func, std::move(patterns)); diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 6d5073f1c37..9fc216eef25 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -440,14 +440,14 @@ static LogicalResult verify(CallOp op) { auto fnAttr = op.getAttrOfType<FunctionAttr>("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' function attribute"); - auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction( + auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction( fnAttr.getValue()); if (!fn) return op.emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; // Verify that the operand and result types match the callee. - auto fnType = fn->getType(); + auto fnType = fn.getType(); if (fnType.getNumInputs() != op.getNumOperands()) return op.emitOpError("incorrect number of operands for callee"); @@ -1107,13 +1107,13 @@ static LogicalResult verify(ConstantOp &op) { return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. - auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction( + auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction( fnAttr.getValue()); if (!fn) return op.emitOpError("reference to undefined function 'bar'"); // Check that the referenced function has the correct type. - if (fn->getType() != type) + if (fn.getType() != type) return op.emitOpError("reference to function with mismatched type"); return success(); @@ -1876,10 +1876,10 @@ static void print(OpAsmPrinter *p, ReturnOp op) { } static LogicalResult verify(ReturnOp op) { - auto *function = op.getOperation()->getFunction(); + auto function = op.getOperation()->getFunction(); // The operand number and types must match the function signature. - const auto &results = function->getType().getResults(); + const auto &results = function.getType().getResults(); if (op.getNumOperands() != results.size()) return op.emitOpError("has ") << op.getNumOperands() diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 74ade942fc7..1e8409246ef 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -69,7 +69,7 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) { // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the // function as a kernel. - for (Function &func : m) { + for (Function func : m) { if (!func.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelFuncAttrName())) continue; @@ -89,20 +89,21 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) { return llvmModule; } -static TranslateFromMLIRRegistration registration( - "mlir-to-nvvmir", [](Module *module, llvm::StringRef outputFilename) { - if (!module) - return true; +static TranslateFromMLIRRegistration + registration("mlir-to-nvvmir", + [](Module *module, llvm::StringRef outputFilename) { + if (!module) + return true; - auto llvmModule = mlir::translateModuleToNVVMIR(*module); - if (!llvmModule) - return true; + auto llvmModule = mlir::translateModuleToNVVMIR(*module); + if (!llvmModule) + return true; - auto file = openOutputFile(outputFilename); - if (!file) - return true; + auto file = openOutputFile(outputFilename); + if (!file) + return true; - llvmModule->print(file->os(), nullptr); - file->keep(); - return false; - }); + llvmModule->print(file->os(), nullptr); + file->keep(); + return false; + }); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ef286cb64fd..4a68ac71ee0 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -375,7 +375,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) { bool ModuleTranslation::convertFunctions() { // Declare all functions first because there may be function calls that form a // call graph with cycles. - for (Function &function : mlirModule) { + for (Function function : mlirModule) { mlir::BoolAttr isVarArgsAttr = function.getAttrOfType<BoolAttr>("std.varargs"); bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue(); @@ -392,7 +392,7 @@ bool ModuleTranslation::convertFunctions() { } // Convert functions. - for (Function &function : mlirModule) { + for (Function function : mlirModule) { // Ignore external functions. if (function.isExternal()) continue; diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 8a2002ce368..394b3ef8db5 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -40,7 +40,7 @@ struct Canonicalizer : public FunctionPass<Canonicalizer> { void Canonicalizer::runOnFunction() { OwningRewritePatternList patterns; - auto &func = getFunction(); + auto func = getFunction(); // TODO: Instead of adding all known patterns from the whole system lazily add // and cache the canonicalization patterns for ops we see in practice when diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index be60ada6a43..84f00b97e38 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -849,7 +849,7 @@ struct FunctionConverter { /// error, success otherwise. If 'signatureConversion' is provided, the /// arguments of the entry block are updated accordingly. LogicalResult - convertFunction(Function *f, + convertFunction(Function f, TypeConverter::SignatureConversion *signatureConversion); /// Converts the given region starting from the entry block and following the @@ -957,22 +957,22 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, } LogicalResult FunctionConverter::convertFunction( - Function *f, TypeConverter::SignatureConversion *signatureConversion) { + Function f, TypeConverter::SignatureConversion *signatureConversion) { // If this is an external function, there is nothing else to do. - if (f->isExternal()) + if (f.isExternal()) return success(); - DialectConversionRewriter rewriter(f->getBody(), typeConverter); + DialectConversionRewriter rewriter(f.getBody(), typeConverter); // Update the signature of the entry block. if (signatureConversion) { rewriter.argConverter.convertSignature( - &f->getBody().front(), *signatureConversion, rewriter.mapping); + &f.getBody().front(), *signatureConversion, rewriter.mapping); } // Rewrite the function body. if (failed( - convertRegion(rewriter, f->getBody(), /*convertEntryTypes=*/false))) { + convertRegion(rewriter, f.getBody(), /*convertEntryTypes=*/false))) { // Reset any of the generated rewrites. rewriter.discardRewrites(); return failure(); @@ -1124,24 +1124,6 @@ auto ConversionTarget::getOpAction(OperationName op) const // applyConversionPatterns //===----------------------------------------------------------------------===// -namespace { -/// This class represents a function to be converted. It allows for converting -/// the body of functions and the signature in two phases. -struct ConvertedFunction { - ConvertedFunction(Function *fn, FunctionType newType, - ArrayRef<NamedAttributeList> newFunctionArgAttrs) - : fn(fn), newType(newType), - newFunctionArgAttrs(newFunctionArgAttrs.begin(), - newFunctionArgAttrs.end()) {} - - /// The function to convert. - Function *fn; - /// The new type and argument attributes for the function. - FunctionType newType; - SmallVector<NamedAttributeList, 4> newFunctionArgAttrs; -}; -} // end anonymous namespace - /// Convert the given module with the provided conversion patterns and type /// conversion object. If conversion fails for specific functions, those /// functions remains unmodified. @@ -1149,37 +1131,33 @@ LogicalResult mlir::applyConversionPatterns(Module &module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { - std::vector<Function *> allFunctions; - allFunctions.reserve(module.getFunctions().size()); - for (auto &func : module) - allFunctions.push_back(&func); + SmallVector<Function, 32> allFunctions(module.getFunctions()); return applyConversionPatterns(allFunctions, target, converter, std::move(patterns)); } /// Convert the given functions with the provided conversion patterns. LogicalResult mlir::applyConversionPatterns( - ArrayRef<Function *> fns, ConversionTarget &target, + MutableArrayRef<Function> fns, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns) { if (fns.empty()) return success(); // Build the function converter. - FunctionConverter funcConverter(fns.front()->getContext(), target, patterns, - &converter); + auto *ctx = fns.front().getContext(); + FunctionConverter funcConverter(ctx, target, patterns, &converter); // Try to convert each of the functions within the module. - auto *ctx = fns.front()->getContext(); - for (auto *func : fns) { + for (auto func : fns) { // Convert the function type using the type converter. auto conversion = - converter.convertSignature(func->getType(), func->getAllArgAttrs()); + converter.convertSignature(func.getType(), func.getAllArgAttrs()); if (!conversion) return failure(); // Update the function signature. - func->setType(conversion->getConvertedType(ctx)); - func->setAllArgAttrs(conversion->getConvertedArgAttrs()); + func.setType(conversion->getConvertedType(ctx)); + func.setAllArgAttrs(conversion->getConvertedArgAttrs()); // Convert the body of this function. if (failed(funcConverter.convertFunction(func, &*conversion))) @@ -1193,9 +1171,9 @@ LogicalResult mlir::applyConversionPatterns( /// convert as many of the operations within 'fn' as possible given the set of /// patterns. LogicalResult -mlir::applyConversionPatterns(Function &fn, ConversionTarget &target, +mlir::applyConversionPatterns(Function fn, ConversionTarget &target, OwningRewritePatternList &&patterns) { // Convert the body of this function. FunctionConverter converter(fn.getContext(), target, patterns); - return converter.convertFunction(&fn, /*signatureConversion=*/nullptr); + return converter.convertFunction(fn, /*signatureConversion=*/nullptr); } diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 5a926ceaa92..a3aa092b0ec 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -214,7 +214,7 @@ static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs, static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED emitRemarkForBlock(Block &block) { auto *op = block.getContainingOp(); - return op ? op->emitRemark() : block.getFunction()->emitRemark(); + return op ? op->emitRemark() : block.getFunction().emitRemark(); } /// Creates a buffer in the faster memory space for the specified region; @@ -246,8 +246,8 @@ bool DmaGeneration::generateDma(const MemRefRegion ®ion, Block *block, OpBuilder &b = region.isWrite() ? epilogue : prologue; // Builder to create constants at the top level. - auto *func = block->getFunction(); - OpBuilder top(func->getBody()); + auto func = block->getFunction(); + OpBuilder top(func.getBody()); auto loc = region.loc; auto *memref = region.memref; @@ -751,14 +751,14 @@ uint64_t DmaGeneration::runOnBlock(Block::iterator begin, Block::iterator end) { if (auto *op = block->getContainingOp()) op->emitError(str); else - block->getFunction()->emitError(str); + block->getFunction().emitError(str); } return totalDmaBuffersSizeInBytes; } void DmaGeneration::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); OpBuilder topBuilder(f.getBody()); zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 8d2e75b2dca..77b944f3e01 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -257,7 +257,7 @@ public: // Initializes the dependence graph based on operations in 'f'. // Returns true on success, false otherwise. - bool init(Function &f); + bool init(Function f); // Returns the graph node for 'id'. Node *getNode(unsigned id) { @@ -637,7 +637,7 @@ public: // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the // dependence graph at a different depth. -bool MemRefDependenceGraph::init(Function &f) { +bool MemRefDependenceGraph::init(Function f) { DenseMap<Value *, SetVector<unsigned>> memrefAccesses; // TODO: support multi-block functions. @@ -859,7 +859,7 @@ static Value *createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst, // Create builder to insert alloc op just before 'forOp'. OpBuilder b(forInst); // Builder to create constants at the top level. - OpBuilder top(forInst->getFunction()->getBody()); + OpBuilder top(forInst->getFunction().getBody()); // Create new memref type based on slice bounds. auto *oldMemRef = cast<StoreOp>(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef->getType().cast<MemRefType>(); @@ -1750,9 +1750,9 @@ public: }; // Search for siblings which load the same memref function argument. - auto *fn = dstNode->op->getFunction(); - for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { - for (auto *user : fn->getArgument(i)->getUsers()) { + auto fn = dstNode->op->getFunction(); + for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { + for (auto *user : fn.getArgument(i)->getUsers()) { if (auto loadOp = dyn_cast<LoadOp>(user)) { // Gather loops surrounding 'use'. SmallVector<AffineForOp, 4> loops; diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index c1be6e8f6b1..2744e5ca05c 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -261,7 +261,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band, // Identify valid and profitable bands of loops to tile. This is currently just // a temporary placeholder to test the mechanics of tiled code generation. // Returns all maximal outermost perfect loop nests to tile. -static void getTileableBands(Function &f, +static void getTileableBands(Function f, std::vector<SmallVector<AffineForOp, 6>> *bands) { // Get maximal perfect nest of 'affine.for' insts starting from root // (inclusive). diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 05953926376..6f13f623fe8 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -92,8 +92,8 @@ void LoopUnroll::runOnFunction() { // Store innermost loops as we walk. std::vector<AffineForOp> loops; - void walkPostOrder(Function *f) { - for (auto &b : *f) + void walkPostOrder(Function f) { + for (auto &b : f) walkPostOrder(b.begin(), b.end()); } @@ -142,10 +142,10 @@ void LoopUnroll::runOnFunction() { ? clUnrollNumRepetitions : 1; // If the call back is provided, we will recurse until no loops are found. - Function &func = getFunction(); + Function func = getFunction(); for (unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) { InnermostLoopGatherer ilg; - ilg.walkPostOrder(&func); + ilg.walkPostOrder(func); auto &loops = ilg.loops; if (loops.empty()) break; diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 77a23b156b0..df30e270fe6 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -726,7 +726,7 @@ public: } // end namespace -LogicalResult mlir::lowerAffineConstructs(Function &function) { +LogicalResult mlir::lowerAffineConstructs(Function function) { OwningRewritePatternList patterns; RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering, AffineLoadLowering, diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index cd92198ac04..f59f1006ec5 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state, } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG((*slice)[0]->getFunction()->print(dbgs())); + LLVM_DEBUG((*slice)[0]->getFunction().print(dbgs())); // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* @@ -667,7 +667,7 @@ static bool emitSlice(MaterializationState *state, /// because we currently disallow vectorization of defs that come from another /// scope. /// TODO(ntv): please document return value. -static bool materialize(Function *f, const SetVector<Operation *> &terminators, +static bool materialize(Function f, const SetVector<Operation *> &terminators, MaterializationState *state) { DenseSet<Operation *> seen; DominanceInfo domInfo(f); @@ -721,7 +721,7 @@ static bool materialize(Function *f, const SetVector<Operation *> &terminators, return true; } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); - LLVM_DEBUG(f->print(dbgs())); + LLVM_DEBUG(f.print(dbgs())); } return false; } @@ -731,13 +731,13 @@ void MaterializeVectorsPass::runOnFunction() { NestedPatternContext mlContext; // TODO(ntv): Check to see if this supports arbitrary top-level code. - Function *f = &getFunction(); - if (f->getBlocks().size() != 1) + Function f = getFunction(); + if (f.getBlocks().size() != 1) return; using matcher::Op; LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n"); - LLVM_DEBUG(f->print(dbgs())); + LLVM_DEBUG(f.print(dbgs())); MaterializationState state(hwVectorSize); // Get the hardware vector type. diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index c5676afaf63..1208e2fdd15 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -212,7 +212,7 @@ void MemRefDataFlowOpt::forwardStoreToLoad(LoadOp loadOp) { void MemRefDataFlowOpt::runOnFunction() { // Only supports single block functions at the moment. - Function &f = getFunction(); + Function f = getFunction(); if (f.getBlocks().size() != 1) { markAllAnalysesPreserved(); return; diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp index f97f549c93e..c7c3621781a 100644 --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -29,7 +29,7 @@ struct StripDebugInfo : public FunctionPass<StripDebugInfo> { } // end anonymous namespace void StripDebugInfo::runOnFunction() { - Function &func = getFunction(); + Function func = getFunction(); auto unknownLoc = UnknownLoc::get(&getContext()); // Strip the debug info from the function and its operations. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 47ca378f324..e185f702d27 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -44,7 +44,7 @@ namespace { /// applies the locally optimal patterns in a roughly "bottom up" way. class GreedyPatternRewriteDriver : public PatternRewriter { public: - explicit GreedyPatternRewriteDriver(Function &fn, + explicit GreedyPatternRewriteDriver(Function fn, OwningRewritePatternList &&patterns) : PatternRewriter(fn.getBody()), matcher(std::move(patterns)) { worklist.reserve(64); @@ -213,7 +213,7 @@ bool GreedyPatternRewriteDriver::simplifyFunction(int maxIterations) { /// patterns in a greedy work-list driven manner. Return true if no more /// patterns can be matched in the result function. /// -bool mlir::applyPatternsGreedily(Function &fn, +bool mlir::applyPatternsGreedily(Function fn, OwningRewritePatternList &&patterns) { GreedyPatternRewriteDriver driver(fn, std::move(patterns)); bool converged = driver.simplifyFunction(maxPatternMatchIterations); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index 728123f71a5..4ddf93c2232 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -125,7 +125,7 @@ LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) { Operation *op = forOp.getOperation(); if (!iv->use_empty()) { if (forOp.hasConstantLowerBound()) { - OpBuilder topBuilder(op->getFunction()->getBody()); + OpBuilder topBuilder(op->getFunction().getBody()); auto constOp = topBuilder.create<ConstantIndexOp>( forOp.getLoc(), forOp.getConstantLowerBound()); iv->replaceAllUsesWith(constOp); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 39a05d8c300..3fca26bef19 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -1194,7 +1194,7 @@ static LogicalResult vectorizeRootMatch(NestedMatch m, /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. void Vectorize::runOnFunction() { - Function &f = getFunction(); + Function f = getFunction(); if (!fastestVaryingPattern.empty() && fastestVaryingPattern.size() != vectorSizes.size()) { f.emitRemark("Fastest varying pattern specified with different size than " @@ -1220,7 +1220,7 @@ void Vectorize::runOnFunction() { unsigned patternDepth = pat.getDepth(); SmallVector<NestedMatch, 8> matches; - pat.match(&f, &matches); + pat.match(f, &matches); // Iterate over all the top-level matches and vectorize eagerly. // This automatically prunes intersecting matches. for (auto m : matches) { diff --git a/mlir/lib/Transforms/ViewFunctionGraph.cpp b/mlir/lib/Transforms/ViewFunctionGraph.cpp index 1f2ab69409e..3c1a1b3b481 100644 --- a/mlir/lib/Transforms/ViewFunctionGraph.cpp +++ b/mlir/lib/Transforms/ViewFunctionGraph.cpp @@ -53,13 +53,13 @@ std::string DOTGraphTraits<Function *>::getNodeLabel(Block *Block, Function *) { } // end namespace llvm -void mlir::viewGraph(Function &function, const llvm::Twine &name, +void mlir::viewGraph(Function function, const llvm::Twine &name, bool shortNames, const llvm::Twine &title, llvm::GraphProgram::Name program) { llvm::ViewGraph(&function, name, shortNames, title, program); } -llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function, +llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function function, bool shortNames, const llvm::Twine &title) { return llvm::WriteGraph(os, &function, shortNames, title); } |