summaryrefslogtreecommitdiffstats
path: root/mlir/test/EDSC/builder-api-test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/EDSC/builder-api-test.cpp')
-rw-r--r--mlir/test/EDSC/builder-api-test.cpp150
1 files changed, 80 insertions, 70 deletions
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 834e7c98228..a88312dba9b 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -43,13 +43,12 @@ static MLIRContext &globalContext() {
return context;
}
-static std::unique_ptr<Function> makeFunction(StringRef name,
- ArrayRef<Type> results = {},
- ArrayRef<Type> args = {}) {
+static Function makeFunction(StringRef name, ArrayRef<Type> results = {},
+ ArrayRef<Type> args = {}) {
auto &ctx = globalContext();
- auto function = llvm::make_unique<Function>(
- UnknownLoc::get(&ctx), name, FunctionType::get(args, results, &ctx));
- function->addEntryBlock();
+ auto function = Function::create(UnknownLoc::get(&ctx), name,
+ FunctionType::get(args, results, &ctx));
+ function.addEntryBlock();
return function;
}
@@ -62,10 +61,10 @@ TEST_FUNC(builder_dynamic_for_func_args) {
auto f =
makeFunction("builder_dynamic_for_func_args", {}, {indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle i(indexType), j(indexType), lb(f->getArgument(0)),
- ub(f->getArgument(1));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)),
+ ub(f.getArgument(1));
ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type));
ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type));
ValueHandle i7(constant_int(7, 32));
@@ -102,7 +101,8 @@ TEST_FUNC(builder_dynamic_for_func_args) {
// CHECK-DAG: [[ri4:%[0-9]+]] = muli {{.*}}, {{.*}} : i32
// CHECK: {{.*}} = subi [[ri3]], [[ri4]] : i32
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_dynamic_for) {
@@ -113,10 +113,10 @@ TEST_FUNC(builder_dynamic_for) {
auto f = makeFunction("builder_dynamic_for", {},
{indexType, indexType, indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle i(indexType), a(f->getArgument(0)), b(f->getArgument(1)),
- c(f->getArgument(2)), d(f->getArgument(3));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)),
+ c(f.getArgument(2)), d(f.getArgument(3));
LoopBuilder(&i, a - b, c + d, 2)();
// clang-format off
@@ -125,7 +125,8 @@ TEST_FUNC(builder_dynamic_for) {
// CHECK-DAG: [[r1:%[0-9]+]] = affine.apply ()[s0, s1] -> (s0 + s1)()[%arg2, %arg3]
// CHECK-NEXT: affine.for %i0 = (d0) -> (d0)([[r0]]) to (d0) -> (d0)([[r1]]) step 2 {
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_max_min_for) {
@@ -136,10 +137,10 @@ TEST_FUNC(builder_max_min_for) {
auto f = makeFunction("builder_max_min_for", {},
{indexType, indexType, indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle i(indexType), lb1(f->getArgument(0)), lb2(f->getArgument(1)),
- ub1(f->getArgument(2)), ub2(f->getArgument(3));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)),
+ ub1(f.getArgument(2)), ub2(f.getArgument(3));
LoopBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)();
ret();
@@ -148,7 +149,8 @@ TEST_FUNC(builder_max_min_for) {
// CHECK: affine.for %i0 = max (d0, d1) -> (d0, d1)(%arg0, %arg1) to min (d0, d1) -> (d0, d1)(%arg2, %arg3) {
// CHECK: return
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_blocks) {
@@ -157,14 +159,14 @@ TEST_FUNC(builder_blocks) {
using namespace edsc::op;
auto f = makeFunction("builder_blocks");
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
arg4(c1.getType()), r(c1.getType());
- BlockHandle b1, b2, functionBlock(&f->front());
+ BlockHandle b1, b2, functionBlock(&f.front());
BlockBuilder(&b1, {&arg1, &arg2})(
// b2 has not yet been constructed, need to come back later.
// This is a byproduct of non-structured control-flow.
@@ -192,7 +194,8 @@ TEST_FUNC(builder_blocks) {
// CHECK-NEXT: br ^bb1(%3, %4 : i32, i32)
// CHECK-NEXT: }
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_blocks_eager) {
@@ -201,8 +204,8 @@ TEST_FUNC(builder_blocks_eager) {
using namespace edsc::op;
auto f = makeFunction("builder_blocks_eager");
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
ValueHandle c1(ValueHandle::create<ConstantIntOp>(42, 32)),
c2(ValueHandle::create<ConstantIntOp>(1234, 32));
ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()),
@@ -235,7 +238,8 @@ TEST_FUNC(builder_blocks_eager) {
// CHECK-NEXT: br ^bb1(%3, %4 : i32, i32)
// CHECK-NEXT: }
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_cond_branch) {
@@ -244,15 +248,15 @@ TEST_FUNC(builder_cond_branch) {
auto f = makeFunction("builder_cond_branch", {},
{IntegerType::get(1, &globalContext())});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle funcArg(f->getArgument(0));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle funcArg(f.getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
c42(ValueHandle::create<ConstantIntOp>(42, 32));
ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType());
- BlockHandle b1, b2, functionBlock(&f->front());
+ BlockHandle b1, b2, functionBlock(&f.front());
BlockBuilder(&b1, {&arg1})([&] { ret(); });
BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); });
// Get back to entry block and add a conditional branch
@@ -271,7 +275,8 @@ TEST_FUNC(builder_cond_branch) {
// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
// CHECK-NEXT: return
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_cond_branch_eager) {
@@ -281,9 +286,9 @@ TEST_FUNC(builder_cond_branch_eager) {
auto f = makeFunction("builder_cond_branch_eager", {},
{IntegerType::get(1, &globalContext())});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
- ValueHandle funcArg(f->getArgument(0));
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
+ ValueHandle funcArg(f.getArgument(0));
ValueHandle c32(ValueHandle::create<ConstantIntOp>(32, 32)),
c64(ValueHandle::create<ConstantIntOp>(64, 64)),
c42(ValueHandle::create<ConstantIntOp>(42, 32));
@@ -309,7 +314,8 @@ TEST_FUNC(builder_cond_branch_eager) {
// CHECK-NEXT: ^bb2(%1: i64, %2: i32): // pred: ^bb0
// CHECK-NEXT: return
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(builder_helpers) {
@@ -321,14 +327,14 @@ TEST_FUNC(builder_helpers) {
auto f =
makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle f7(
ValueHandle::create<ConstantFloatOp>(llvm::APFloat(7.0f), f32Type));
- MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
- vC(f->getArgument(2));
- IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
+ MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)),
+ vC(f.getArgument(2));
+ IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2;
int64_t step0, step1, step2;
std::tie(lb0, ub0, step0) = vA.range(0);
@@ -363,7 +369,8 @@ TEST_FUNC(builder_helpers) {
// CHECK-DAG: [[e:%.*]] = addf [[d]], [[c]] : f32
// CHECK-NEXT: store [[e]], %arg2[%i0, %i1, %i3] : memref<?x?x?xf32>
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(custom_ops) {
@@ -373,8 +380,8 @@ TEST_FUNC(custom_ops) {
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("custom_ops", {}, {indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
CustomOperation<ValueHandle> MY_CUSTOM_OP("my_custom_op");
CustomOperation<OperationHandle> MY_CUSTOM_OP_0("my_custom_op_0");
CustomOperation<OperationHandle> MY_CUSTOM_OP_2("my_custom_op_2");
@@ -382,7 +389,7 @@ TEST_FUNC(custom_ops) {
// clang-format off
ValueHandle vh(indexType), vh20(indexType), vh21(indexType);
OperationHandle ih0, ih2;
- IndexHandle m, n, M(f->getArgument(0)), N(f->getArgument(1));
+ IndexHandle m, n, M(f.getArgument(0)), N(f.getArgument(1));
IndexHandle ten(index_t(10)), twenty(index_t(20));
LoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{
vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {});
@@ -402,7 +409,8 @@ TEST_FUNC(custom_ops) {
// CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index)
// CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(insertion_in_block) {
@@ -412,8 +420,8 @@ TEST_FUNC(insertion_in_block) {
auto indexType = IndexType::get(&globalContext());
auto f = makeFunction("insertion_in_block", {}, {indexType, indexType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
BlockHandle b1;
// clang-format off
ValueHandle::create<ConstantIntOp>(0, 32);
@@ -427,7 +435,8 @@ TEST_FUNC(insertion_in_block) {
// CHECK: ^bb1: // no predecessors
// CHECK: {{.*}} = constant 1 : i32
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
TEST_FUNC(select_op) {
@@ -438,12 +447,12 @@ TEST_FUNC(select_op) {
auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
// clang-format off
ValueHandle zero = constant_index(0), one = constant_index(1);
- MemRefView vA(f->getArgument(0));
- IndexedValue A(f->getArgument(0));
+ MemRefView vA(f.getArgument(0));
+ IndexedValue A(f.getArgument(0));
IndexHandle i, j;
LoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{
// This test exercises IndexedValue::operator Value*.
@@ -461,7 +470,8 @@ TEST_FUNC(select_op) {
// CHECK-DAG: {{.*}} = load
// CHECK-NEXT: {{.*}} = select
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
// Inject an EDSC-constructed computation to exercise imperfectly nested 2-d
@@ -474,12 +484,11 @@ TEST_FUNC(tile_2d) {
MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0);
auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType});
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
ValueHandle zero = constant_index(0);
- MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
- vC(f->getArgument(2));
- IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
+ MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
+ IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2));
// clang-format off
@@ -531,7 +540,8 @@ TEST_FUNC(tile_2d) {
// CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32
// CHECK-NEXT: store {{.*}}, {{.*}}[%i8, %i9, %i7] : memref<?x?x?xf32>
// clang-format on
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
// Inject an EDSC-constructed computation to exercise 2-d vectorization.
@@ -544,16 +554,15 @@ TEST_FUNC(vectorize_2d) {
auto owningF =
makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType});
- mlir::Function *f = owningF.release();
+ mlir::Function f = owningF;
mlir::Module module(&globalContext());
- module.getFunctions().push_back(f);
+ module.push_back(f);
- OpBuilder builder(f->getBody());
- ScopedContext scope(builder, f->getLoc());
+ OpBuilder builder(f.getBody());
+ ScopedContext scope(builder, f.getLoc());
ValueHandle zero = constant_index(0);
- MemRefView vA(f->getArgument(0)), vB(f->getArgument(1)),
- vC(f->getArgument(2));
- IndexedValue A(f->getArgument(0)), B(f->getArgument(1)), C(f->getArgument(2));
+ MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
+ IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
IndexHandle M(vA.ub(0)), N(vA.ub(1)), P(vA.ub(2));
// clang-format off
@@ -580,9 +589,10 @@ TEST_FUNC(vectorize_2d) {
pm.addPass(mlir::createCanonicalizerPass());
SmallVector<int64_t, 2> vectorSizes{4, 4};
pm.addPass(mlir::createVectorizePass(vectorSizes));
- auto result = pm.run(f->getModule());
+ auto result = pm.run(f.getModule());
if (succeeded(result))
- f->print(llvm::outs());
+ f.print(llvm::outs());
+ f.erase();
}
int main() {
OpenPOWER on IntegriCloud