summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp')
-rw-r--r--mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp69
1 files changed, 35 insertions, 34 deletions
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index a8099aaff99..5fe4f07613a 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -170,12 +170,13 @@ public:
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
- auto *module = op->getFunction()->getModule();
- Function *mallocFunc = module->getNamedFunction("malloc");
+ auto *module = op->getFunction().getModule();
+ Function mallocFunc = module->getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
- mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
- module->getFunctions().push_back(mallocFunc);
+ mallocFunc =
+ Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
+ module->push_back(mallocFunc);
}
// Get MLIR types for injecting element pointer.
@@ -230,12 +231,12 @@ public:
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
- auto *module = op->getFunction()->getModule();
- Function *freeFunc = module->getNamedFunction("free");
+ auto *module = op->getFunction().getModule();
+ Function freeFunc = module->getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
- freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
- module->getFunctions().push_back(freeFunc);
+ freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
+ module->push_back(freeFunc);
}
// Get MLIR types for extracting element pointer.
@@ -572,37 +573,37 @@ public:
// Create a function definition which takes as argument pointers to the input
// types and returns pointers to the output types.
-static Function *getLLVMLibraryCallImplDefinition(Function *libFn) {
- auto implFnName = (libFn->getName().str() + "_impl");
- auto module = libFn->getModule();
- if (auto *f = module->getNamedFunction(implFnName)) {
+static Function getLLVMLibraryCallImplDefinition(Function libFn) {
+ auto implFnName = (libFn.getName().str() + "_impl");
+ auto module = libFn.getModule();
+ if (auto f = module->getNamedFunction(implFnName)) {
return f;
}
SmallVector<Type, 4> fnArgTypes;
- for (auto t : libFn->getType().getInputs()) {
+ for (auto t : libFn.getType().getInputs()) {
assert(t.isa<LLVMType>() &&
"Expected LLVM Type for argument while generating library Call "
"Implementation Definition");
fnArgTypes.push_back(t.cast<LLVMType>().getPointerTo());
}
- auto implFnType = FunctionType::get(fnArgTypes, {}, libFn->getContext());
+ auto implFnType = FunctionType::get(fnArgTypes, {}, libFn.getContext());
// Insert the implementation function definition.
- auto implFnDefn = new Function(libFn->getLoc(), implFnName, implFnType);
- module->getFunctions().push_back(implFnDefn);
+ auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType);
+ module->push_back(implFnDefn);
return implFnDefn;
}
// Get function definition for the LinalgOp. If it doesn't exist, insert a
// definition.
template <typename LinalgOp>
-static Function *getLLVMLibraryCallDeclaration(Operation *op,
- LLVMTypeConverter &lowering,
- PatternRewriter &rewriter) {
+static Function getLLVMLibraryCallDeclaration(Operation *op,
+ LLVMTypeConverter &lowering,
+ PatternRewriter &rewriter) {
assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName();
- auto module = op->getFunction()->getModule();
- if (auto *f = module->getNamedFunction(fnName)) {
+ auto module = op->getFunction().getModule();
+ if (auto f = module->getNamedFunction(fnName)) {
return f;
}
@@ -618,29 +619,29 @@ static Function *getLLVMLibraryCallDeclaration(Operation *op,
"Library call for linalg operation can be generated only for ops that "
"have void return types");
auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
- auto libFn = new Function(op->getLoc(), fnName, libFnType);
- module->getFunctions().push_back(libFn);
+ auto libFn = Function::create(op->getLoc(), fnName, libFnType);
+ module->push_back(libFn);
// Return after creating the function definition. The body will be created
// later.
return libFn;
}
-static void getLLVMLibraryCallDefinition(Function *fn,
+static void getLLVMLibraryCallDefinition(Function fn,
LLVMTypeConverter &lowering) {
// Generate the implementation function definition.
auto implFn = getLLVMLibraryCallImplDefinition(fn);
// Generate the function body.
- fn->addEntryBlock();
+ fn.addEntryBlock();
- OpBuilder builder(fn->getBody());
- edsc::ScopedContext scope(builder, fn->getLoc());
+ OpBuilder builder(fn.getBody());
+ edsc::ScopedContext scope(builder, fn.getLoc());
SmallVector<Value *, 4> implFnArgs;
// Create a constant 1.
auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
- IntegerAttr::get(IndexType::get(fn->getContext()), 1));
- for (auto arg : fn->getArguments()) {
+ IntegerAttr::get(IndexType::get(fn.getContext()), 1));
+ for (auto arg : fn.getArguments()) {
// Allocate a stack for storing the argument value. The stack is passed to
// the implementation function.
auto alloca =
@@ -665,17 +666,17 @@ public:
return convertLinalgType(t, *this);
}
- void addLibraryFnDeclaration(Function *fn) {
+ void addLibraryFnDeclaration(Function fn) {
libraryFnDeclarations.push_back(fn);
}
- ArrayRef<Function *> getLibraryFnDeclarations() {
+ ArrayRef<Function> getLibraryFnDeclarations() {
return libraryFnDeclarations;
}
private:
/// List of library functions declarations needed during dialect conversion
- SmallVector<Function *, 2> libraryFnDeclarations;
+ SmallVector<Function, 2> libraryFnDeclarations;
};
} // end anonymous namespace
@@ -692,7 +693,7 @@ public:
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Only emit library call declaration. Fill in the body later.
- auto *f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
+ auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
auto fAttr = rewriter.getFunctionAttr(f);
@@ -803,7 +804,7 @@ static void lowerLinalgForToCFG(Function &f) {
void LowerLinalgToLLVMPass::runOnModule() {
auto &module = getModule();
- for (auto &f : module.getFunctions()) {
+ for (auto f : module.getFunctions()) {
lowerLinalgSubViewOps(f);
lowerLinalgForToCFG(f);
if (failed(lowerAffineConstructs(f)))
OpenPOWER on IntegriCloud