summaryrefslogtreecommitdiffstats
path: root/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp')
-rw-r--r--mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp54
1 files changed, 27 insertions, 27 deletions
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
index cad2deda57e..0abcb4bb850 100644
--- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
@@ -113,7 +113,7 @@ public:
// function to process, the mangled name for this specialization, and the
// types of the arguments on which to specialize.
struct FunctionToSpecialize {
- mlir::Function *function;
+ mlir::Function function;
std::string mangledName;
SmallVector<mlir::Type, 4> argumentsType;
};
@@ -121,7 +121,7 @@ public:
void runOnModule() override {
auto &module = getModule();
mlir::ModuleManager moduleManager(&module);
- auto *main = moduleManager.getNamedFunction("main");
+ auto main = moduleManager.getNamedFunction("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n");
@@ -140,7 +140,7 @@ public:
// Delete any generic function left
// FIXME: we may want this as a separate pass.
- for (mlir::Function &function : llvm::make_early_inc_range(module)) {
+ for (mlir::Function function : llvm::make_early_inc_range(module)) {
if (auto genericAttr =
function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
if (genericAttr.getValue())
@@ -155,7 +155,7 @@ public:
specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
mlir::ModuleManager &moduleManager) {
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
- mlir::Function *f = functionToSpecialize.function;
+ mlir::Function f = functionToSpecialize.function;
// Check if cloning for specialization is needed (usually anything but main)
// We will create a new function with the concrete types for the parameters
@@ -171,36 +171,36 @@ public:
auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
{ToyArrayType::get(&getContext())},
&getContext());
- auto *newFunction = new mlir::Function(
- f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
+ auto newFunction = mlir::Function::create(
+ f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs());
moduleManager.insert(newFunction);
// Clone the function body
mlir::BlockAndValueMapping mapper;
- f->cloneInto(newFunction, mapper);
+ f.cloneInto(newFunction, mapper);
LLVM_DEBUG({
llvm::dbgs() << "====== Cloned : \n";
- f->dump();
+ f.dump();
llvm::dbgs() << "====== Into : \n";
- newFunction->dump();
+ newFunction.dump();
});
f = newFunction;
- f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+ f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// Remap the entry-block arguments
// FIXME: this seems like a bug in `cloneInto()` above?
- auto &entryBlock = f->getBlocks().front();
+ auto &entryBlock = f.getBlocks().front();
int blockArgSize = entryBlock.getArguments().size();
- assert(blockArgSize == static_cast<int>(f->getType().getInputs().size()));
- entryBlock.addArguments(f->getType().getInputs());
+ assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
+ entryBlock.addArguments(f.getType().getInputs());
auto argList = entryBlock.getArguments();
for (int argNum = 0; argNum < blockArgSize; ++argNum) {
argList[0]->replaceAllUsesWith(argList[blockArgSize]);
entryBlock.eraseArgument(0);
}
- assert(succeeded(f->verify()));
+ assert(succeeded(f.verify()));
}
LLVM_DEBUG(llvm::dbgs()
- << "Run shape inference on : '" << f->getName() << "'\n");
+ << "Run shape inference on : '" << f.getName() << "'\n");
auto *toyDialect = getContext().getRegisteredDialect("toy");
if (!toyDialect) {
@@ -212,7 +212,7 @@ public:
// Populate the worklist with the operations that need shape inference:
// these are the Toy operations that return a generic array.
llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
- f->walk([&](mlir::Operation *op) {
+ f.walk([&](mlir::Operation *op) {
if (op->getDialect() == toyDialect) {
if (op->getNumResults() == 1 &&
op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
@@ -295,16 +295,16 @@ public:
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
- auto *callee = moduleManager.getNamedFunction(calleeName);
+ auto callee = moduleManager.getNamedFunction(calleeName);
if (!callee) {
signalPassFailure();
- return f->emitError("Shape inference failed, call to unknown '")
+ return f.emitError("Shape inference failed, call to unknown '")
<< calleeName << "'";
}
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
- auto *mangledCallee = moduleManager.getNamedFunction(mangledName);
+ auto mangledCallee = moduleManager.getNamedFunction(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
@@ -315,7 +315,7 @@ public:
// Found a specialized callee! Let's turn this into a normal call
// operation.
SmallVector<mlir::Value *, 8> operands(op->getOperands());
- mlir::OpBuilder builder(f->getBody());
+ mlir::OpBuilder builder(f.getBody());
builder.setInsertionPoint(op);
auto newCall =
builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
@@ -330,12 +330,12 @@ public:
// Done with inference on this function, removing it from the worklist.
funcWorklist.pop_back();
// Mark the function as non-generic now that inference has succeeded
- f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
+ f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
// If the operation worklist isn't empty, this indicates a failure.
if (!opWorklist.empty()) {
signalPassFailure();
- auto diag = f->emitError("Shape inference failed, ")
+ auto diag = f.emitError("Shape inference failed, ")
<< opWorklist.size() << " operations couldn't be inferred\n";
for (auto *ope : opWorklist)
diag << " - " << *ope << "\n";
@@ -344,24 +344,24 @@ public:
// Finally, update the return type of the function based on the argument to
// the return operation.
- for (auto &block : f->getBlocks()) {
+ for (auto &block : f.getBlocks()) {
auto ret = llvm::cast<ReturnOp>(block.getTerminator());
if (!ret)
continue;
if (ret.getNumOperands() &&
- f->getType().getResult(0) == ret.getOperand()->getType())
+ f.getType().getResult(0) == ret.getOperand()->getType())
// type match, we're done
break;
SmallVector<mlir::Type, 1> retTy;
if (ret.getNumOperands())
retTy.push_back(ret.getOperand()->getType());
std::vector<mlir::Type> argumentsType;
- for (auto arg : f->getArguments())
+ for (auto arg : f.getArguments())
argumentsType.push_back(arg->getType());
auto newType =
mlir::FunctionType::get(argumentsType, retTy, &getContext());
- f->setType(newType);
- assert(succeeded(f->verify()));
+ f.setType(newType);
+ assert(succeeded(f.verify()));
break;
}
return mlir::success();
OpenPOWER on IntegriCloud