diff options
Diffstat (limited to 'mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp')
-rw-r--r-- | mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 69 |
1 files changed, 35 insertions, 34 deletions
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index a8099aaff99..5fe4f07613a 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -170,12 +170,13 @@ public: LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. - auto *module = op->getFunction()->getModule(); - Function *mallocFunc = module->getNamedFunction("malloc"); + auto *module = op->getFunction().getModule(); + Function mallocFunc = module->getNamedFunction("malloc"); if (!mallocFunc) { auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy); - mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); - module->getFunctions().push_back(mallocFunc); + mallocFunc = + Function::create(rewriter.getUnknownLoc(), "malloc", mallocType); + module->push_back(mallocFunc); } // Get MLIR types for injecting element pointer. @@ -230,12 +231,12 @@ public: auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. - auto *module = op->getFunction()->getModule(); - Function *freeFunc = module->getNamedFunction("free"); + auto *module = op->getFunction().getModule(); + Function freeFunc = module->getNamedFunction("free"); if (!freeFunc) { auto freeType = rewriter.getFunctionType(voidPtrTy, {}); - freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); - module->getFunctions().push_back(freeFunc); + freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType); + module->push_back(freeFunc); } // Get MLIR types for extracting element pointer. @@ -572,37 +573,37 @@ public: // Create a function definition which takes as argument pointers to the input // types and returns pointers to the output types. -static Function *getLLVMLibraryCallImplDefinition(Function *libFn) { - auto implFnName = (libFn->getName().str() + "_impl"); - auto module = libFn->getModule(); - if (auto *f = module->getNamedFunction(implFnName)) { +static Function getLLVMLibraryCallImplDefinition(Function libFn) { + auto implFnName = (libFn.getName().str() + "_impl"); + auto module = libFn.getModule(); + if (auto f = module->getNamedFunction(implFnName)) { return f; } SmallVector<Type, 4> fnArgTypes; - for (auto t : libFn->getType().getInputs()) { + for (auto t : libFn.getType().getInputs()) { assert(t.isa<LLVMType>() && "Expected LLVM Type for argument while generating library Call " "Implementation Definition"); fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo()); } - auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext()); + auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext()); // Insert the implementation function definition. - auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType); - module->getFunctions().push_back(implFnDefn); + auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType); + module->push_back(implFnDefn); return implFnDefn; } // Get function definition for the LinalgOp. If it doesn't exist, insert a // definition. template <typename LinalgOp> -static Function *getLLVMLibraryCallDeclaration(Operation *op, - LLVMTypeConverter &lowering, - PatternRewriter &rewriter) { +static Function getLLVMLibraryCallDeclaration(Operation *op, + LLVMTypeConverter &lowering, + PatternRewriter &rewriter) { assert(isa<LinalgOp>(op)); auto fnName = LinalgOp::getLibraryCallName(); - auto module = op->getFunction()->getModule(); - if (auto *f = module->getNamedFunction(fnName)) { + auto module = op->getFunction().getModule(); + if (auto f = module->getNamedFunction(fnName)) { return f; } @@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op, "Library call for linalg operation can be generated only for ops that " "have void return types"); auto libFnType = FunctionType::get(inputTypes, {}, op->getContext()); - auto libFn = new Function(op->getLoc(), fnName, libFnType); - module->getFunctions().push_back(libFn); + auto libFn = Function::create(op->getLoc(), fnName, libFnType); + module->push_back(libFn); // Return after creating the function definition. The body will be created // later. return libFn; } -static void getLLVMLibraryCallDefinition(Function *fn, +static void getLLVMLibraryCallDefinition(Function fn, LLVMTypeConverter &lowering) { // Generate the implementation function definition. auto implFn = getLLVMLibraryCallImplDefinition(fn); // Generate the function body. - fn->addEntryBlock(); + fn.addEntryBlock(); - OpBuilder builder(fn->getBody()); - edsc::ScopedContext scope(builder, fn->getLoc()); + OpBuilder builder(fn.getBody()); + edsc::ScopedContext scope(builder, fn.getLoc()); SmallVector<Value *, 4> implFnArgs; // Create a constant 1. auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()), - IntegerAttr::get(IndexType::get(fn->getContext()), 1)); - for (auto arg : fn->getArguments()) { + IntegerAttr::get(IndexType::get(fn.getContext()), 1)); + for (auto arg : fn.getArguments()) { // Allocate a stack for storing the argument value. The stack is passed to // the implementation function. auto alloca = @@ -665,17 +666,17 @@ public: return convertLinalgType(t, *this); } - void addLibraryFnDeclaration(Function *fn) { + void addLibraryFnDeclaration(Function fn) { libraryFnDeclarations.push_back(fn); } - ArrayRef<Function *> getLibraryFnDeclarations() { + ArrayRef<Function> getLibraryFnDeclarations() { return libraryFnDeclarations; } private: /// List of library functions declarations needed during dialect conversion - SmallVector<Function *, 2> libraryFnDeclarations; + SmallVector<Function, 2> libraryFnDeclarations; }; } // end anonymous namespace @@ -692,7 +693,7 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override { // Only emit library call declaration. Fill in the body later. - auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter); + auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter); static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f); auto fAttr = rewriter.getFunctionAttr(f); @@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) { void LowerLinalgToLLVMPass::runOnModule() { auto &module = getModule(); - for (auto &f : module.getFunctions()) { + for (auto f : module.getFunctions()) { lowerLinalgSubViewOps(f); lowerLinalgForToCFG(f); if (failed(lowerAffineConstructs(f))) |