diff options
Diffstat (limited to 'mlir/test/EDSC/builder-api-test.cpp')
-rw-r--r-- | mlir/test/EDSC/builder-api-test.cpp | 150 |
1 files changed, 80 insertions, 70 deletions
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 834e7c98228..a88312dba9b 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -43,13 +43,12 @@ static MLIRContext &globalContext() { return context; } -static std::unique_ptr<Function> makeFunction(StringRef name, - ArrayRef<Type> results = {}, - ArrayRef<Type> args = {}) { +static Function makeFunction(StringRef name, ArrayRef<Type> results = {}, + ArrayRef<Type> args = {}) { auto &ctx = globalContext(); - auto function = llvm::make_unique<Function>( - UnknownLoc::get(&ctx), name, FunctionType::get(args, results, &ctx)); - function->addEntryBlock(); + auto function = Function::create(UnknownLoc::get(&ctx), name, + FunctionType::get(args, results, &ctx)); + function.addEntryBlock(); return function; } @@ -62,10 +61,10 @@ TEST_FUNC(builder_dynamic_for_func_args) { auto f = makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)), - ub(f->getArgument(1)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)), + ub(f.getArgument(1)); ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type)); ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type)); ValueHandle i7(constant_int(7, 32)); @@ -102,7 +101,8 @@ TEST_FUNC(builder_dynamic_for_func_args) { // CHECK-DAG: [[ri4:%[0-9]+]] = muli {{.*}}, {{.*}} : i32 // CHECK: {{.*}} = subi [[ri3]], [[ri4]] : i32 // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_dynamic_for) { @@ -113,10 +113,10 @@ TEST_FUNC(builder_dynamic_for) { auto f = makeFunction("builder_dynamic_for", {}, {indexType, indexType, indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)), - c(f->getArgument(2)), d(f->getArgument(3)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), + c(f.getArgument(2)), d(f.getArgument(3)); LoopBuilder(&i, a - b, c + d, 2)(); // clang-format off @@ -125,7 +125,8 @@ TEST_FUNC(builder_dynamic_for) { // CHECK-DAG: [[r1:%[0-9]+]] = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3] // CHECK-NEXT: affine.for %i0 = (d0) -> (d0)([[r0]]) to (d0) -> (d0)([[r1]]) step 2 { // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_max_min_for) { @@ -136,10 +137,10 @@ TEST_FUNC(builder_max_min_for) { auto f = makeFunction("builder_max_min_for", {}, {indexType, indexType, indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)), - ub1(f->getArgument(2)), ub2(f->getArgument(3)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)), + ub1(f.getArgument(2)), ub2(f.getArgument(3)); LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)(); ret(); @@ -148,7 +149,8 @@ TEST_FUNC(builder_max_min_for) { // CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) { // CHECK: return // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_blocks) { @@ -157,14 +159,14 @@ TEST_FUNC(builder_blocks) { using namespace edsc::op; auto f = makeFunction("builder_blocks"); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)), c2(ValueHandle::create<ConstantIntOp>(1234, 32)); ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), arg4(c1.getType()), r(c1.getType()); - BlockHandle b1, b2, functionBlock(&f->front()); + BlockHandle b1, b2, functionBlock(&f.front()); BlockBuilder(&b1, {&arg1, &arg2})( // b2 has not yet been constructed, need to come back later. // This is a byproduct of non-structured control-flow. @@ -192,7 +194,8 @@ TEST_FUNC(builder_blocks) { // CHECK-NEXT: br ^bb1(%3, %4 : i32, i32) // CHECK-NEXT: } // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_blocks_eager) { @@ -201,8 +204,8 @@ TEST_FUNC(builder_blocks_eager) { using namespace edsc::op; auto f = makeFunction("builder_blocks_eager"); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)), c2(ValueHandle::create<ConstantIntOp>(1234, 32)); ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), @@ -235,7 +238,8 @@ TEST_FUNC(builder_blocks_eager) { // CHECK-NEXT: br ^bb1(%3, %4 : i32, i32) // CHECK-NEXT: } // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_cond_branch) { @@ -244,15 +248,15 @@ TEST_FUNC(builder_cond_branch) { auto f = makeFunction("builder_cond_branch", {}, {IntegerType::get(1, &globalContext())}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle funcArg(f->getArgument(0)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle funcArg(f.getArgument(0)); ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)), c64(ValueHandle::create<ConstantIntOp>(64, 64)), c42(ValueHandle::create<ConstantIntOp>(42, 32)); ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); - BlockHandle b1, b2, functionBlock(&f->front()); + BlockHandle b1, b2, functionBlock(&f.front()); BlockBuilder(&b1, {&arg1})([&] { ret(); }); BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); }); // Get back to entry block and add a conditional branch @@ -271,7 +275,8 @@ TEST_FUNC(builder_cond_branch) { // CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0 // CHECK-NEXT: return // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_cond_branch_eager) { @@ -281,9 +286,9 @@ TEST_FUNC(builder_cond_branch_eager) { auto f = makeFunction("builder_cond_branch_eager", {}, {IntegerType::get(1, &globalContext())}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); - ValueHandle funcArg(f->getArgument(0)); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + ValueHandle funcArg(f.getArgument(0)); ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)), c64(ValueHandle::create<ConstantIntOp>(64, 64)), c42(ValueHandle::create<ConstantIntOp>(42, 32)); @@ -309,7 +314,8 @@ TEST_FUNC(builder_cond_branch_eager) { // CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0 // CHECK-NEXT: return // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(builder_helpers) { @@ -321,14 +327,14 @@ TEST_FUNC(builder_helpers) { auto f = makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle f7( ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type)); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), + vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2; int64_t step0, step1, step2; std::tie(lb0, ub0, step0) = vA.range(0); @@ -363,7 +369,8 @@ TEST_FUNC(builder_helpers) { // CHECK-DAG: [[e:%.*]] = addf [[d]], [[c]] : f32 // CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref<?x?x?xf32> // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(custom_ops) { @@ -373,8 +380,8 @@ TEST_FUNC(custom_ops) { auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("custom_ops", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op"); CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0"); CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2"); @@ -382,7 +389,7 @@ TEST_FUNC(custom_ops) { // clang-format off ValueHandle vh(indexType), vh20(indexType), vh21(indexType); OperationHandle ih0, ih2; - IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1)); + IndexHandle m, n, M(f.getArgument(0)), N(f.getArgument(1)); IndexHandle ten(index_t(10)), twenty(index_t(20)); LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{ vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}); @@ -402,7 +409,8 @@ TEST_FUNC(custom_ops) { // CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index) // CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(insertion_in_block) { @@ -412,8 +420,8 @@ TEST_FUNC(insertion_in_block) { auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("insertion_in_block", {}, {indexType, indexType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); BlockHandle b1; // clang-format off ValueHandle::create<ConstantIntOp>(0, 32); @@ -427,7 +435,8 @@ TEST_FUNC(insertion_in_block) { // CHECK: ^bb1: // no predecessors // CHECK: {{.*}} = constant 1 : i32 // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } TEST_FUNC(select_op) { @@ -438,12 +447,12 @@ TEST_FUNC(select_op) { auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0); auto f = makeFunction("select_op", {}, {memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); // clang-format off ValueHandle zero = constant_index(0), one = constant_index(1); - MemRefView vA(f->getArgument(0)); - IndexedValue A(f->getArgument(0)); + MemRefView vA(f.getArgument(0)); + IndexedValue A(f.getArgument(0)); IndexHandle i, j; LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{ // This test exercises IndexedValue::operator Value*. @@ -461,7 +470,8 @@ TEST_FUNC(select_op) { // CHECK-DAG: {{.*}} = load // CHECK-NEXT: {{.*}} = select // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } // Inject an EDSC-constructed computation to exercise imperfectly nested 2-d @@ -474,12 +484,11 @@ TEST_FUNC(tile_2d) { MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0); auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType}); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle zero = constant_index(0); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2)); // clang-format off @@ -531,7 +540,8 @@ TEST_FUNC(tile_2d) { // CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32 // CHECK-NEXT: store {{.*}}, {{.*}}[%i8, %i9, %i7] : memref<?x?x?xf32> // clang-format on - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } // Inject an EDSC-constructed computation to exercise 2-d vectorization. @@ -544,16 +554,15 @@ TEST_FUNC(vectorize_2d) { auto owningF = makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType}); - mlir::Function *f = owningF.release(); + mlir::Function f = owningF; mlir::Module module(&globalContext()); - module.getFunctions().push_back(f); + module.push_back(f); - OpBuilder builder(f->getBody()); - ScopedContext scope(builder, f->getLoc()); + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); ValueHandle zero = constant_index(0); - MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)), - vC(f->getArgument(2)); - IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2)); + MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); + IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); IndexHandle M(vA.ub(0)), N(vA.ub(1)), P(vA.ub(2)); // clang-format off @@ -580,9 +589,10 @@ TEST_FUNC(vectorize_2d) { pm.addPass(mlir::createCanonicalizerPass()); SmallVector<int64_t, 2> vectorSizes{4, 4}; pm.addPass(mlir::createVectorizePass(vectorSizes)); - auto result = pm.run(f->getModule()); + auto result = pm.run(f.getModule()); if (succeeded(result)) - f->print(llvm::outs()); + f.print(llvm::outs()); + f.erase(); } int main() { |