diff options
Diffstat (limited to 'mlir/examples')
18 files changed, 221 insertions, 214 deletions
diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index ddd6df9fb89..1f129c6b283 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -57,15 +57,15 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context, } /// A basic function builder -inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name, - llvm::ArrayRef<mlir::Type> types, - llvm::ArrayRef<mlir::Type> resultTypes) { +inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name, + llvm::ArrayRef<mlir::Type> types, + llvm::ArrayRef<mlir::Type> resultTypes) { auto *context = module.getContext(); - auto *function = new mlir::Function( + auto function = mlir::Function::create( mlir::UnknownLoc::get(context), name, mlir::FunctionType::get({types}, resultTypes, context)); - function->addEntryBlock(); - module.getFunctions().push_back(function); + function.addEntryBlock(); + module.push_back(function); return function; } @@ -83,19 +83,19 @@ inline std::unique_ptr<mlir::PassManager> cleanupPassManager() { /// llvm::outs() for FileCheck'ing. /// If an error occurs, dump to llvm::errs() and do not print to llvm::outs() /// which will make the associated FileCheck test fail. -inline void cleanupAndPrintFunction(mlir::Function *f) { +inline void cleanupAndPrintFunction(mlir::Function f) { bool printToOuts = true; - auto check = [f, &printToOuts](mlir::LogicalResult result) { + auto check = [&f, &printToOuts](mlir::LogicalResult result) { if (failed(result)) { - f->emitError("Verification and cleanup passes failed"); + f.emitError("Verification and cleanup passes failed"); printToOuts = false; } }; auto pm = cleanupPassManager(); - check(f->getModule()->verify()); - check(pm->run(f->getModule())); + check(f.getModule()->verify()); + check(pm->run(f.getModule())); if (printToOuts) - f->print(llvm::outs()); + f.print(llvm::outs()); } /// Helper class to sugar building loop nests from indexings that appear in diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp index a415daebdf5..9534711f1f4 100644 --- a/mlir/examples/Linalg/Linalg2/Example.cpp +++ b/mlir/examples/Linalg/Linalg2/Example.cpp @@ -36,14 +36,14 @@ TEST_FUNC(linalg_ops) { MLIRContext context; Module module(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function *f = + mlir::Function f = makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)), + ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), @@ -75,14 +75,14 @@ TEST_FUNC(linalg_ops_folded_slices) { MLIRContext context; Module module(&context); auto indexType = mlir::IndexType::get(&context); - mlir::Function *f = makeFunction(module, "linalg_ops_folded_slices", - {indexType, indexType, indexType}, {}); + mlir::Function f = makeFunction(module, "linalg_ops_folded_slices", + {indexType, indexType, indexType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), K(f->getArgument(2)), + ValueHandle M(f.getArgument(0)), N(f.getArgument(1)), K(f.getArgument(2)), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), @@ -104,7 +104,7 @@ TEST_FUNC(linalg_ops_folded_slices) { // CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg.view<f32> // clang-format on - f->walk<SliceOp>([](SliceOp slice) { + f.walk<SliceOp>([](SliceOp slice) { auto *sliceResult = slice.getResult(); auto viewOp = emitAndReturnFullyComposedView(sliceResult); sliceResult->replaceAllUsesWith(viewOp.getResult()); diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp index 37d1b51f53e..23d1cfef5dc 100644 --- a/mlir/examples/Linalg/Linalg3/Conversion.cpp +++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp @@ -37,26 +37,26 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -67,7 +67,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(foo) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); convertLinalg3ToLLVM(module); diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index f02aef920e4..8b04344b19e 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -34,26 +34,26 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - mlir::OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + mlir::OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -64,7 +64,7 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(matmul_as_matvec) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); // clang-format off @@ -82,7 +82,7 @@ TEST_FUNC(matmul_as_matvec) { TEST_FUNC(matmul_as_dot) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot"); lowerToFinerGrainedTensorContraction(f); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); @@ -103,7 +103,7 @@ TEST_FUNC(matmul_as_dot) { TEST_FUNC(matmul_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); composeSliceOps(f); // clang-format off @@ -135,7 +135,7 @@ TEST_FUNC(matmul_as_loops) { TEST_FUNC(matmul_as_matvec_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops"); lowerToFinerGrainedTensorContraction(f); lowerToLoops(f); @@ -166,14 +166,14 @@ TEST_FUNC(matmul_as_matvec_as_loops) { TEST_FUNC(matmul_as_matvec_as_affine) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine"); lowerToFinerGrainedTensorContraction(f); composeSliceOps(f); lowerToLoops(f); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); - if (succeeded(pm.run(f->getModule()))) + if (succeeded(pm.run(f.getModule()))) cleanupAndPrintFunction(f); // clang-format off diff --git a/mlir/examples/Linalg/Linalg3/Execution.cpp b/mlir/examples/Linalg/Linalg3/Execution.cpp index 00d571cbc99..94b233a56b0 100644 --- a/mlir/examples/Linalg/Linalg3/Execution.cpp +++ b/mlir/examples/Linalg/Linalg3/Execution.cpp @@ -37,26 +37,26 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - mlir::OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + mlir::OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -110,7 +110,7 @@ TEST_FUNC(execution) { // dialect through partial conversions. MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops"); lowerToLoops(f); convertLinalg3ToLLVM(module); diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h index 9af528e8c51..6c0aec0b000 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h @@ -55,11 +55,11 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps, /// Traverses `f` and rewrites linalg.slice, and the operations it depends on, /// to only use linalg.view operations. -void composeSliceOps(mlir::Function *f); +void composeSliceOps(mlir::Function f); /// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec) /// as linalg.matvec(resp. linalg.dot). -void lowerToFinerGrainedTensorContraction(mlir::Function *f); +void lowerToFinerGrainedTensorContraction(mlir::Function f); /// Operation-wise writing of linalg operations to loop form. /// It is the caller's responsibility to erase the `op` if necessary. @@ -69,7 +69,7 @@ llvm::Optional<llvm::SmallVector<mlir::AffineForOp, 4>> writeAsLoops(mlir::Operation *op); /// Traverses `f` and rewrites linalg operations in loop form. -void lowerToLoops(mlir::Function *f); +void lowerToLoops(mlir::Function f); /// Creates a pass that rewrites linalg.load and linalg.store to affine.load and /// affine.store operations. diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 7b559bf2f21..96b0f371ef1 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -148,7 +148,7 @@ static void populateLinalg3ToLLVMConversionPatterns( void linalg::convertLinalg3ToLLVM(Module &module) { // Remove affine constructs. - for (auto &func : module) { + for (auto func : module) { auto rr = lowerAffineConstructs(func); (void)rr; assert(succeeded(rr) && "affine loop lowering failed"); diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index d5c8641acbe..7b9e5ffee96 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -35,8 +35,8 @@ using namespace mlir::edsc::intrinsics; using namespace linalg; using namespace linalg::intrinsics; -void linalg::composeSliceOps(mlir::Function *f) { - f->walk<SliceOp>([](SliceOp sliceOp) { +void linalg::composeSliceOps(mlir::Function f) { + f.walk<SliceOp>([](SliceOp sliceOp) { auto *sliceResult = sliceOp.getResult(); auto viewOp = emitAndReturnFullyComposedView(sliceResult); sliceResult->replaceAllUsesWith(viewOp.getResult()); @@ -44,8 +44,8 @@ void linalg::composeSliceOps(mlir::Function *f) { }); } -void linalg::lowerToFinerGrainedTensorContraction(mlir::Function *f) { - f->walk([](Operation *op) { +void linalg::lowerToFinerGrainedTensorContraction(mlir::Function f) { + f.walk([](Operation *op) { if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) { matmulOp.writeAsFinerGrainTensorContraction(); } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) { @@ -211,8 +211,8 @@ linalg::writeAsLoops(Operation *op) { return llvm::None; } -void linalg::lowerToLoops(mlir::Function *f) { - f->walk([](Operation *op) { +void linalg::lowerToLoops(mlir::Function f) { + f.walk([](Operation *op) { if (writeAsLoops(op)) op->erase(); }); diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index cdc05a1cc21..873e57e78f3 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -34,27 +34,27 @@ using namespace linalg; using namespace linalg::common; using namespace linalg::intrinsics; -Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { +Function makeFunctionWithAMatmulOp(Module &module, StringRef name) { MLIRContext *context = module.getContext(); auto dynamic2DMemRefType = floatMemRefType<2>(context); - mlir::Function *f = linalg::common::makeFunction( + mlir::Function f = linalg::common::makeFunction( module, name, {dynamic2DMemRefType, dynamic2DMemRefType, dynamic2DMemRefType}, {}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle - M = dim(f->getArgument(0), 0), - N = dim(f->getArgument(2), 1), - K = dim(f->getArgument(0), 1), + M = dim(f.getArgument(0), 0), + N = dim(f.getArgument(2), 1), + K = dim(f.getArgument(0), 1), rM = range(constant_index(0), M, constant_index(1)), rN = range(constant_index(0), N, constant_index(1)), rK = range(constant_index(0), K, constant_index(1)), - vA = view(f->getArgument(0), {rM, rK}), - vB = view(f->getArgument(1), {rK, rN}), - vC = view(f->getArgument(2), {rM, rN}); + vA = view(f.getArgument(0), {rM, rK}), + vB = view(f.getArgument(1), {rK, rN}), + vC = view(f.getArgument(2), {rM, rN}); matmul(vA, vB, vC); ret(); // clang-format on @@ -65,11 +65,11 @@ Function *makeFunctionWithAMatmulOp(Module &module, StringRef name) { TEST_FUNC(matmul_tiled_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops"); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops"); lowerToTiledLoops(f, {8, 9}); PassManager pm; pm.addPass(createLowerLinalgLoadStorePass()); - if (succeeded(pm.run(f->getModule()))) + if (succeeded(pm.run(f.getModule()))) cleanupAndPrintFunction(f); // clang-format off @@ -96,10 +96,10 @@ TEST_FUNC(matmul_tiled_loops) { TEST_FUNC(matmul_tiled_views) { MLIRContext context; Module module(&context); - mlir::Function *f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views"); - OpBuilder b(f->getBody()); - lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8), - b.create<ConstantIndexOp>(f->getLoc(), 9)}); + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views"); + OpBuilder b(f.getBody()); + lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8), + b.create<ConstantIndexOp>(f.getLoc(), 9)}); composeSliceOps(f); // clang-format off @@ -125,11 +125,11 @@ TEST_FUNC(matmul_tiled_views) { TEST_FUNC(matmul_tiled_views_as_loops) { MLIRContext context; Module module(&context); - mlir::Function *f = + mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops"); - OpBuilder b(f->getBody()); - lowerToTiledViews(f, {b.create<ConstantIndexOp>(f->getLoc(), 8), - b.create<ConstantIndexOp>(f->getLoc(), 9)}); + OpBuilder b(f.getBody()); + lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8), + b.create<ConstantIndexOp>(f.getLoc(), 9)}); composeSliceOps(f); lowerToLoops(f); // This cannot lower below linalg.load and linalg.store due to lost diff --git a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h index 2165cab6ac1..ba7273e409d 100644 --- a/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h +++ b/mlir/examples/Linalg/Linalg4/include/linalg4/Transforms.h @@ -34,12 +34,12 @@ writeAsTiledViews(mlir::Operation *op, llvm::ArrayRef<mlir::Value *> tileSizes); /// Apply `writeAsTiledLoops` on all linalg ops. This is a convenience function /// and is not exposed as a pass because a fixed set of tile sizes for all ops /// in a function can generally not be specified. -void lowerToTiledLoops(mlir::Function *f, llvm::ArrayRef<uint64_t> tileSizes); +void lowerToTiledLoops(mlir::Function f, llvm::ArrayRef<uint64_t> tileSizes); /// Apply `writeAsTiledViews` on all linalg ops. This is a convenience function /// and is not exposed as a pass because a fixed set of tile sizes for all ops /// in a function can generally not be specified. -void lowerToTiledViews(mlir::Function *f, +void lowerToTiledViews(mlir::Function f, llvm::ArrayRef<mlir::Value *> tileSizes); } // namespace linalg diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index 1a308df1313..16b395da506 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -43,9 +43,8 @@ linalg::writeAsTiledLoops(Operation *op, ArrayRef<uint64_t> tileSizes) { return llvm::None; } -void linalg::lowerToTiledLoops(mlir::Function *f, - ArrayRef<uint64_t> tileSizes) { - f->walk([tileSizes](Operation *op) { +void linalg::lowerToTiledLoops(mlir::Function f, ArrayRef<uint64_t> tileSizes) { + f.walk([tileSizes](Operation *op) { if (writeAsTiledLoops(op, tileSizes).hasValue()) op->erase(); }); @@ -185,8 +184,8 @@ linalg::writeAsTiledViews(Operation *op, ArrayRef<Value *> tileSizes) { return llvm::None; } -void linalg::lowerToTiledViews(mlir::Function *f, ArrayRef<Value *> tileSizes) { - f->walk([tileSizes](Operation *op) { +void linalg::lowerToTiledViews(mlir::Function f, ArrayRef<Value *> tileSizes) { + f.walk([tileSizes](Operation *op) { if (auto matmulOp = dyn_cast<linalg::MatmulOp>(op)) { writeAsTiledViews(matmulOp, tileSizes); } else if (auto matvecOp = dyn_cast<linalg::MatvecOp>(op)) { diff --git a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp index 842c7a1d0f8..73789fa41a4 100644 --- a/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch2/mlir/MLIRGen.cpp @@ -75,7 +75,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -129,40 +129,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector<mlir::Type, 4> ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -172,16 +172,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique<mlir::OpBuilder>(function->getBody()); + builder = llvm::make_unique<mlir::OpBuilder>(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emitted. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index e365f37f8c8..23cb85309c2 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector<mlir::Type, 4> ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique<mlir::OpBuilder>(function->getBody()); + builder = llvm::make_unique<mlir::OpBuilder>(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emitted. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index 032766a547f..f2132c29c33 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector<mlir::Type, 4> ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique<mlir::OpBuilder>(function->getBody()); + builder = llvm::make_unique<mlir::OpBuilder>(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emited. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp index 688c73645a5..f237fd9fb53 100644 --- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp @@ -113,14 +113,14 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function *function; + mlir::Function function; std::string mangledName; SmallVector<mlir::Type, 4> argumentsType; }; void runOnModule() override { auto &module = getModule(); - auto *main = module.getNamedFunction("main"); + auto main = module.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), "Shape inference failed: can't find a main function\n"); @@ -139,7 +139,7 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. - for (mlir::Function &function : llvm::make_early_inc_range(module)) { + for (mlir::Function function : llvm::make_early_inc_range(module)) { if (auto genericAttr = function.getAttrOfType<mlir::BoolAttr>("toy.generic")) { if (genericAttr.getValue()) @@ -153,7 +153,7 @@ public: mlir::LogicalResult specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function *f = functionToSpecialize.function; + mlir::Function f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -169,36 +169,36 @@ public: auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, {ToyArrayType::get(&getContext())}, &getContext()); - auto *newFunction = new mlir::Function( - f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); - getModule().getFunctions().push_back(newFunction); + auto newFunction = mlir::Function::create( + f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs()); + getModule().push_back(newFunction); // Clone the function body mlir::BlockAndValueMapping mapper; - f->cloneInto(newFunction, mapper); + f.cloneInto(newFunction, mapper); LLVM_DEBUG({ llvm::dbgs() << "====== Cloned : \n"; - f->dump(); + f.dump(); llvm::dbgs() << "====== Into : \n"; - newFunction->dump(); + newFunction.dump(); }); f = newFunction; - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // Remap the entry-block arguments // FIXME: this seems like a bug in `cloneInto()` above? - auto &entryBlock = f->getBlocks().front(); + auto &entryBlock = f.getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == static_cast<int>(f->getType().getInputs().size())); - entryBlock.addArguments(f->getType().getInputs()); + assert(blockArgSize == static_cast<int>(f.getType().getInputs().size())); + entryBlock.addArguments(f.getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { argList[0]->replaceAllUsesWith(argList[blockArgSize]); entryBlock.eraseArgument(0); } - assert(succeeded(f->verify())); + assert(succeeded(f.verify())); } LLVM_DEBUG(llvm::dbgs() - << "Run shape inference on : '" << f->getName() << "'\n"); + << "Run shape inference on : '" << f.getName() << "'\n"); auto *toyDialect = getContext().getRegisteredDialect("toy"); if (!toyDialect) { @@ -211,7 +211,7 @@ public: // Populate the worklist with the operations that need shape inference: // these are the Toy operations that return a generic array. llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist; - f->walk([&](mlir::Operation *op) { + f.walk([&](mlir::Operation *op) { if (op->getDialect() == toyDialect) { if (op->getNumResults() == 1 && op->getResult(0)->getType().cast<ToyArrayType>().isGeneric()) @@ -292,9 +292,9 @@ public: // restart after the callee is processed. if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) { auto calleeName = callOp.getCalleeName(); - auto *callee = getModule().getNamedFunction(calleeName); + auto callee = getModule().getNamedFunction(calleeName); if (!callee) { - f->emitError("Shape inference failed, call to unknown '") + f.emitError("Shape inference failed, call to unknown '") << calleeName << "'"; signalPassFailure(); return mlir::failure(); @@ -302,7 +302,7 @@ public: auto mangledName = mangle(calleeName, op->getOpOperands()); LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName << "', mangled: '" << mangledName << "'\n"); - auto *mangledCallee = getModule().getNamedFunction(mangledName); + auto mangledCallee = getModule().getNamedFunction(mangledName); if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. @@ -327,7 +327,7 @@ public: // Done with inference on this function, removing it from the worklist. funcWorklist.pop_back(); // Mark the function as non-generic now that inference has succeeded - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { @@ -337,31 +337,31 @@ public: << " operations couldn't be inferred\n"; for (auto *ope : opWorklist) errorMsg << " - " << *ope << "\n"; - f->emitError(errorMsg.str()); + f.emitError(errorMsg.str()); signalPassFailure(); return mlir::failure(); } // Finally, update the return type of the function based on the argument to // the return operation. - for (auto &block : f->getBlocks()) { + for (auto &block : f.getBlocks()) { auto ret = llvm::cast<ReturnOp>(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && - f->getType().getResult(0) == ret.getOperand()->getType()) + f.getType().getResult(0) == ret.getOperand()->getType()) // type match, we're done break; SmallVector<mlir::Type, 1> retTy; if (ret.getNumOperands()) retTy.push_back(ret.getOperand()->getType()); std::vector<mlir::Type> argumentsType; - for (auto arg : f->getArguments()) + for (auto arg : f.getArguments()) argumentsType.push_back(arg->getType()); auto newType = mlir::FunctionType::get(argumentsType, retTy, &getContext()); - f->setType(newType); - assert(succeeded(f->verify())); + f.setType(newType); + assert(succeeded(f.verify())); break; } return mlir::success(); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 8b2a3927d78..60a8b5a3b9a 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -136,14 +136,14 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. - Function *printfFunc = getPrintf(*op->getFunction()->getModule()); + Function printfFunc = getPrintf(*op->getFunction().getModule()); auto print = cast<toy::PrintOp>(op); auto loc = print.getLoc(); // We will operate on a MemRef abstraction, we use a type.cast to get one // if our operand is still a Toy array. Value *operand = memRefTypeCast(rewriter, operands[0]); - Type retTy = printfFunc->getType().getResult(0); + Type retTy = printfFunc.getType().getResult(0); // Create our loop nest now using namespace edsc; @@ -205,8 +205,8 @@ private: /// Return the prototype declaration for printf in the module, create it if /// necessary. - Function *getPrintf(Module &module) const { - auto *printfFunc = module.getNamedFunction("printf"); + Function getPrintf(Module &module) const { + auto printfFunc = module.getNamedFunction("printf"); if (printfFunc) return printfFunc; @@ -218,10 +218,10 @@ private: auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect); auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo(); auto printfTy = builder.getFunctionType({llvmI8PtrTy}, {llvmI32Ty}); - printfFunc = new Function(builder.getUnknownLoc(), "printf", printfTy); + printfFunc = Function::create(builder.getUnknownLoc(), "printf", printfTy); // It should be variadic, but we don't support it fully just yet. - printfFunc->setAttr("std.varargs", builder.getBoolAttr(true)); - module.getFunctions().push_back(printfFunc); + printfFunc.setAttr("std.varargs", builder.getBoolAttr(true)); + module.push_back(printfFunc); return printfFunc; } }; @@ -369,7 +369,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> { // affine dialect: they already include conversion to the LLVM dialect. // First patch calls type to return memref instead of ToyArray - for (auto &function : getModule()) { + for (auto function : getModule()) { function.walk([&](Operation *op) { auto callOp = dyn_cast<CallOp>(op); if (!callOp) @@ -384,7 +384,7 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> { }); } - for (auto &function : getModule()) { + for (auto function : getModule()) { function.walk([&](Operation *op) { // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). if (auto allocOp = dyn_cast<toy::AllocOp>(op)) { @@ -403,8 +403,8 @@ struct LateLoweringPass : public ModulePass<LateLoweringPass> { } // Lower Linalg to affine - for (auto &function : getModule()) - linalg::lowerToLoops(&function); + for (auto function : getModule()) + linalg::lowerToLoops(function); getModule().dump(); diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index f7e6fad568e..9ebfeb438ca 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -76,7 +76,7 @@ public: auto func = mlirGen(F); if (!func) return nullptr; - theModule->getFunctions().push_back(func.release()); + theModule->push_back(func); } // FIXME: (in the next chapter...) without registering a dialect in MLIR, @@ -130,40 +130,40 @@ private: /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. - mlir::Function *mlirGen(PrototypeAST &proto) { + mlir::Function mlirGen(PrototypeAST &proto) { // This is a generic function, the return type will be inferred later. llvm::SmallVector<mlir::Type, 4> ret_types; // Arguments type is uniformly a generic array. llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(), getType(VarType{})); auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto *function = new mlir::Function(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto function = mlir::Function::create(loc(proto.loc()), proto.getName(), + func_type, /* attrs = */ {}); // Mark the function as generic: it'll require type specialization for every // call site. - if (function->getNumArguments()) - function->setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); + if (function.getNumArguments()) + function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); return function; } /// Emit a new function and add it to the MLIR module. - std::unique_ptr<mlir::Function> mlirGen(FunctionAST &funcAST) { + mlir::Function mlirGen(FunctionAST &funcAST) { // Create a scope in the symbol table to hold variable declarations. ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable); // Create an MLIR function for the given prototype. - std::unique_ptr<mlir::Function> function(mlirGen(*funcAST.getProto())); + mlir::Function function(mlirGen(*funcAST.getProto())); if (!function) return nullptr; // Let's start the body of the function now! // In MLIR the entry block of the function is special: it must have the same // argument list as the function itself. - function->addEntryBlock(); + function.addEntryBlock(); - auto &entryBlock = function->front(); + auto &entryBlock = function.front(); auto &protoArgs = funcAST.getProto()->getArgs(); // Declare all the function arguments in the symbol table. for (const auto &name_value : @@ -173,16 +173,18 @@ private: // Create a builder for the function, it will be used throughout the codegen // to create operations in this function. - builder = llvm::make_unique<mlir::OpBuilder>(function->getBody()); + builder = llvm::make_unique<mlir::OpBuilder>(function.getBody()); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) + if (!mlirGen(*funcAST.getBody())) { + function.erase(); return nullptr; + } // Implicitly return void if no return statement was emited. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function->getBlocks().back().back().getName().getStringRef() != + if (function.getBlocks().back().back().getName().getStringRef() != "toy.return") { ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); mlirGen(fakeRet); diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index cad2deda57e..0abcb4bb850 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -113,7 +113,7 @@ public: // function to process, the mangled name for this specialization, and the // types of the arguments on which to specialize. struct FunctionToSpecialize { - mlir::Function *function; + mlir::Function function; std::string mangledName; SmallVector<mlir::Type, 4> argumentsType; }; @@ -121,7 +121,7 @@ public: void runOnModule() override { auto &module = getModule(); mlir::ModuleManager moduleManager(&module); - auto *main = moduleManager.getNamedFunction("main"); + auto main = moduleManager.getNamedFunction("main"); if (!main) { emitError(mlir::UnknownLoc::get(module.getContext()), "Shape inference failed: can't find a main function\n"); @@ -140,7 +140,7 @@ public: // Delete any generic function left // FIXME: we may want this as a separate pass. - for (mlir::Function &function : llvm::make_early_inc_range(module)) { + for (mlir::Function function : llvm::make_early_inc_range(module)) { if (auto genericAttr = function.getAttrOfType<mlir::BoolAttr>("toy.generic")) { if (genericAttr.getValue()) @@ -155,7 +155,7 @@ public: specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist, mlir::ModuleManager &moduleManager) { FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::Function *f = functionToSpecialize.function; + mlir::Function f = functionToSpecialize.function; // Check if cloning for specialization is needed (usually anything but main) // We will create a new function with the concrete types for the parameters @@ -171,36 +171,36 @@ public: auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, {ToyArrayType::get(&getContext())}, &getContext()); - auto *newFunction = new mlir::Function( - f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs()); + auto newFunction = mlir::Function::create( + f.getLoc(), functionToSpecialize.mangledName, type, f.getAttrs()); moduleManager.insert(newFunction); // Clone the function body mlir::BlockAndValueMapping mapper; - f->cloneInto(newFunction, mapper); + f.cloneInto(newFunction, mapper); LLVM_DEBUG({ llvm::dbgs() << "====== Cloned : \n"; - f->dump(); + f.dump(); llvm::dbgs() << "====== Into : \n"; - newFunction->dump(); + newFunction.dump(); }); f = newFunction; - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // Remap the entry-block arguments // FIXME: this seems like a bug in `cloneInto()` above? - auto &entryBlock = f->getBlocks().front(); + auto &entryBlock = f.getBlocks().front(); int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == static_cast<int>(f->getType().getInputs().size())); - entryBlock.addArguments(f->getType().getInputs()); + assert(blockArgSize == static_cast<int>(f.getType().getInputs().size())); + entryBlock.addArguments(f.getType().getInputs()); auto argList = entryBlock.getArguments(); for (int argNum = 0; argNum < blockArgSize; ++argNum) { argList[0]->replaceAllUsesWith(argList[blockArgSize]); entryBlock.eraseArgument(0); } - assert(succeeded(f->verify())); + assert(succeeded(f.verify())); } LLVM_DEBUG(llvm::dbgs() - << "Run shape inference on : '" << f->getName() << "'\n"); + << "Run shape inference on : '" << f.getName() << "'\n"); auto *toyDialect = getContext().getRegisteredDialect("toy"); if (!toyDialect) { @@ -212,7 +212,7 @@ public: // Populate the worklist with the operations that need shape inference: // these are the Toy operations that return a generic array. llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist; - f->walk([&](mlir::Operation *op) { + f.walk([&](mlir::Operation *op) { if (op->getDialect() == toyDialect) { if (op->getNumResults() == 1 && op->getResult(0)->getType().cast<ToyArrayType>().isGeneric()) @@ -295,16 +295,16 @@ public: // restart after the callee is processed. if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) { auto calleeName = callOp.getCalleeName(); - auto *callee = moduleManager.getNamedFunction(calleeName); + auto callee = moduleManager.getNamedFunction(calleeName); if (!callee) { signalPassFailure(); - return f->emitError("Shape inference failed, call to unknown '") + return f.emitError("Shape inference failed, call to unknown '") << calleeName << "'"; } auto mangledName = mangle(calleeName, op->getOpOperands()); LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName << "', mangled: '" << mangledName << "'\n"); - auto *mangledCallee = moduleManager.getNamedFunction(mangledName); + auto mangledCallee = moduleManager.getNamedFunction(mangledName); if (!mangledCallee) { // Can't find the target, this is where we queue the request for the // callee and stop the inference for the current function now. @@ -315,7 +315,7 @@ public: // Found a specialized callee! Let's turn this into a normal call // operation. SmallVector<mlir::Value *, 8> operands(op->getOperands()); - mlir::OpBuilder builder(f->getBody()); + mlir::OpBuilder builder(f.getBody()); builder.setInsertionPoint(op); auto newCall = builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands); @@ -330,12 +330,12 @@ public: // Done with inference on this function, removing it from the worklist. funcWorklist.pop_back(); // Mark the function as non-generic now that inference has succeeded - f->setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); + f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { signalPassFailure(); - auto diag = f->emitError("Shape inference failed, ") + auto diag = f.emitError("Shape inference failed, ") << opWorklist.size() << " operations couldn't be inferred\n"; for (auto *ope : opWorklist) diag << " - " << *ope << "\n"; @@ -344,24 +344,24 @@ public: // Finally, update the return type of the function based on the argument to // the return operation. - for (auto &block : f->getBlocks()) { + for (auto &block : f.getBlocks()) { auto ret = llvm::cast<ReturnOp>(block.getTerminator()); if (!ret) continue; if (ret.getNumOperands() && - f->getType().getResult(0) == ret.getOperand()->getType()) + f.getType().getResult(0) == ret.getOperand()->getType()) // type match, we're done break; SmallVector<mlir::Type, 1> retTy; if (ret.getNumOperands()) retTy.push_back(ret.getOperand()->getType()); std::vector<mlir::Type> argumentsType; - for (auto arg : f->getArguments()) + for (auto arg : f.getArguments()) argumentsType.push_back(arg->getType()); auto newType = mlir::FunctionType::get(argumentsType, retTy, &getContext()); - f->setType(newType); - assert(succeeded(f->verify())); + f.setType(newType); + assert(succeeded(f.verify())); break; } return mlir::success(); |