diff options
Diffstat (limited to 'mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp')
-rw-r--r-- | mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp | 54 |
1 files changed, 27 insertions, 27 deletions
diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index cad2deda57e..0abcb4bb850 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -113,7 +113,7 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function *function; + mlir::Function function; std::string mangledName; SmallVector<mlir::Type, 4> argumentsType; }; @@ -121,7 +121,7 @@ public: void runOnModule() override { auto &module = getModule(); mlir::ModuleManager moduleManager(&module); - auto *main = moduleManager.getNamedFunction("main"); + auto main = moduleManager.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), "Shape inference failed: can't find a main function\n"); @@ -140,7 +140,7 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. - for (mlir::Function &function : llvm::make_early_inc_range(module)) { + for (mlir::Function function : llvm::make_early_inc_range(module)) { if (auto genericAttr = function.getAttrOfType<mlir::BoolAttr>("toy.generic")) { if (genericAttr.getValue()) @@ -155,7 +155,7 @@ public: specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist, mlir::ModuleManager &moduleManager) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function *f = functionToSpecialize.function; + mlir::Function f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -171,36 +171,36 @@ public: auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, {ToyArrayType::get(&getContext())}, &getContext()); - auto *newFunction = new mlir::Function( - f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); + auto newFunction = mlir::Function::create( + f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs()); moduleManager.insert(newFunction); // Clone the function body mlir::BlockAndValueMapping mapper; - f->cloneInto(newFunction, mapper); + f.cloneInto(newFunction, mapper); LLVM_DEBUG({ llvm::dbgs() << "====== Cloned : \n"; - f->dump(); + f.dump(); llvm::dbgs() << "====== Into : \n"; - newFunction->dump(); + newFunction.dump(); }); f = newFunction; - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // Remap the entry-block arguments // FIXME: this seems like a bug in `cloneInto()` above? - auto &entryBlock = f->getBlocks().front(); + auto &entryBlock = f.getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == static_cast<int>(f->getType().getInputs().size())); - entryBlock.addArguments(f->getType().getInputs()); + assert(blockArgSize == static_cast<int>(f.getType().getInputs().size())); + entryBlock.addArguments(f.getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { argList[0]->replaceAllUsesWith(argList[blockArgSize]); entryBlock.eraseArgument(0); } - assert(succeeded(f->verify())); + assert(succeeded(f.verify())); } LLVM_DEBUG(llvm::dbgs() - << "Run shape inference on : '" << f->getName() << "'\n"); + << "Run shape inference on : '" << f.getName() << "'\n"); auto *toyDialect = getContext().getRegisteredDialect("toy"); if (!toyDialect) { @@ -212,7 +212,7 @@ public: // Populate the worklist with the operations that need shape inference: // these are the Toy operations that return a generic array. llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist; - f->walk([&](mlir::Operation *op) { + f.walk([&](mlir::Operation *op) { if (op->getDialect() == toyDialect) { if (op->getNumResults() == 1 && op->getResult(0)->getType().cast<ToyArrayType>().isGeneric()) @@ -295,16 +295,16 @@ public: // restart after the callee is processed. if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) { auto calleeName = callOp.getCalleeName(); - auto *callee = moduleManager.getNamedFunction(calleeName); + auto callee = moduleManager.getNamedFunction(calleeName); if (!callee) { signalPassFailure(); - return f->emitError("Shape inference failed, call to unknown '") + return f.emitError("Shape inference failed, call to unknown '") << calleeName << "'"; } auto mangledName = mangle(calleeName, op->getOpOperands()); LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName << "', mangled: '" << mangledName << "'\n"); - auto *mangledCallee = moduleManager.getNamedFunction(mangledName); + auto mangledCallee = moduleManager.getNamedFunction(mangledName); if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. @@ -315,7 +315,7 @@ public: // Found a specialized callee! Let's turn this into a normal call // operation. SmallVector<mlir::Value *, 8> operands(op->getOperands()); - mlir::OpBuilder builder(f->getBody()); + mlir::OpBuilder builder(f.getBody()); builder.setInsertionPoint(op); auto newCall = builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands); @@ -330,12 +330,12 @@ public: // Done with inference on this function, removing it from the worklist. funcWorklist.pop_back(); // Mark the function as non-generic now that inference has succeeded - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { signalPassFailure(); - auto diag = f->emitError("Shape inference failed, ") + auto diag = f.emitError("Shape inference failed, ") << opWorklist.size() << " operations couldn't be inferred\n"; for (auto *ope : opWorklist) diag << " - " << *ope << "\n"; @@ -344,24 +344,24 @@ public: // Finally, update the return type of the function based on the argument to // the return operation. - for (auto &block : f->getBlocks()) { + for (auto &block : f.getBlocks()) { auto ret = llvm::cast<ReturnOp>(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && - f->getType().getResult(0) == ret.getOperand()->getType()) + f.getType().getResult(0) == ret.getOperand()->getType()) // type match, we're done break; SmallVector<mlir::Type, 1> retTy; if (ret.getNumOperands()) retTy.push_back(ret.getOperand()->getType()); std::vector<mlir::Type> argumentsType; - for (auto arg : f->getArguments()) + for (auto arg : f.getArguments()) argumentsType.push_back(arg->getType()); auto newType = mlir::FunctionType::get(argumentsType, retTy, &getContext()); - f->setType(newType); - assert(succeeded(f->verify())); + f.setType(newType); + assert(succeeded(f.verify())); break; } return mlir::success(); |