summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/IR/MLIRContext.cpp15
-rw-r--r--mlir/test/IR/parser.mlir55
2 files changed, 49 insertions, 21 deletions
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 3ce41d7dab9..ba60cec4d95 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -723,13 +723,16 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
auto *context = elementType->getContext();
auto &impl = context->getImpl();
- // Drop the affine map composition if comprises a single unbounded identity
- // map (the absence of map composition is considered as implicit identity).
- if (affineMapComposition.size() == 1 &&
- affineMapComposition.front().isIdentity() &&
- !affineMapComposition.front().isBounded()) {
- affineMapComposition = affineMapComposition.drop_front();
+ // Drop the unbounded identity maps from the composition.
+ // This may lead to the composition becoming empty, which is interpreted as an
+ // implicit identity.
+ llvm::SmallVector<AffineMap, 2> cleanedAffineMapComposition;
+ for (const auto &map : affineMapComposition) {
+ if (map.isIdentity() && !map.isBounded())
+ continue;
+ cleanedAffineMapComposition.push_back(map);
}
+ affineMapComposition = cleanedAffineMapComposition;
// Look to see if we already have this memref type.
auto key =
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index e3e0046a8af..a557ac8caed 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -21,18 +21,21 @@
// CHECK-DAG: #map{{[0-9]+}} = ()[s0] -> (100, s0 + 1)
#inline_map_minmax_loop2 = ()[s0] -> (100, s0 + 1)
-// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1, s0 + 1)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1, s0 + 1)
#inline_map_loop_bounds1 = (d0, d1)[s0] -> (d0 + d1, s0 + 1)
-// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + s0)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1)[s0] -> (d0 + d1 + s0)
#bound_map1 = (i, j)[s] -> (i + j + s)
-// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + d1)
+// CHECK-DAG: #map{{[0-9]+}} = (d0, d1) -> (d0 + d1)
#inline_map_loop_bounds2 = (d0, d1) -> (d0 + d1)
-// CHECK: #map{{[0-9]+}} = (d0)[s0] -> (d0 + s0, d0 - s0)
+// CHECK-DAG: #map{{[0-9]+}} = (d0)[s0] -> (d0 + s0, d0 - s0)
#bound_map2 = (i)[s] -> (i + s, i - s)
+// All maps appear in arbitrary order before all sets, in arbitrary order.
+// CHECK-EMPTY
+
// CHECK-DAG: #set{{[0-9]+}} = (d0)[s0, s1] : (d0 >= 0, d0 * -1 + s0 >= 0, s0 - 5 == 0, d0 * -1 + s1 + 1 >= 0)
#set0 = (i)[N, M] : (i >= 0, -i + N >= 0, N - 5 == 0, -i + M + 1 >= 0)
@@ -67,7 +70,7 @@ extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
extfunc @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
tensor<1x?x4x?x?xi32>, tensor<i8>)
-// CHECK: extfunc @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
+// CHECK: extfunc @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>)
extfunc @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)
// Test memref affine map compositions.
@@ -75,18 +78,18 @@ extfunc @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)
// CHECK: extfunc @memrefs2(memref<2x4x8xi8, 1>)
extfunc @memrefs2(memref<2x4x8xi8, #map2, 1>)
-// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
+// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}>)
extfunc @memrefs23(memref<2x4x8xi8, #map2, #map3, 0>)
-// CHECK: extfunc @memrefs234(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, #map{{[0-9]+}}, 3>)
+// CHECK: extfunc @memrefs234(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 3>)
extfunc @memrefs234(memref<2x4x8xi8, #map2, #map3, #map4, 3>)
-// Test memref inline affine map compositions.
+// Test memref inline affine map compositions, minding that identity maps are removed.
// CHECK: extfunc @memrefs3(memref<2x4x8xi8>)
extfunc @memrefs3(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2)>)
-// CHECK: extfunc @memrefs33(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 1>)
+// CHECK: extfunc @memrefs33(memref<2x4x8xi8, #map{{[0-9]+}}, 1>)
extfunc @memrefs33(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
// CHECK: extfunc @memrefs_drop_triv_id_inline(memref<2xi8>)
@@ -98,9 +101,31 @@ extfunc @memrefs_drop_triv_id_inline0(memref<2xi8, (d0) -> (d0), 0>)
// CHECK: extfunc @memrefs_drop_triv_id_inline1(memref<2xi8, 1>)
extfunc @memrefs_drop_triv_id_inline1(memref<2xi8, (d0) -> (d0), 1>)
+// Identity maps should be dropped from the composition, but not the pair of
+// "interchange" maps that, if composed, would be also an identity.
+// CHECK: extfunc @memrefs_drop_triv_id_composition(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
+extfunc @memrefs_drop_triv_id_composition(memref<2x2xi8,
+ (d0, d1) -> (d1, d0),
+ (d0, d1) -> (d0, d1),
+ (d0, d1) -> (d1, d0),
+ (d0, d1) -> (d0, d1),
+ (d0, d1) -> (d0, d1)>)
+
+// CHECK: extfunc @memrefs_drop_triv_id_trailing(memref<2x2xi8, #map{{[0-9]+}}>)
+extfunc @memrefs_drop_triv_id_trailing(memref<2x2xi8, (d0, d1) -> (d1, d0),
+ (d0, d1) -> (d0, d1)>)
+
+// CHECK: extfunc @memrefs_drop_triv_id_middle(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
+extfunc @memrefs_drop_triv_id_middle(memref<2x2xi8, (d0, d1) -> (d0, d1 + 1),
+ (d0, d1) -> (d0, d1),
+ (d0, d1) -> (d0 + 1, d1)>)
+
+// CHECK: extfunc @memrefs_drop_triv_id_multiple(memref<2xi8>)
+extfunc @memrefs_drop_triv_id_multiple(memref<2xi8, (d0) -> (d0), (d0) -> (d0)>)
+
// These maps appeared before, so they must be uniqued and hoisted to the beginning.
-// Identity map should not be removed because of the composition.
-// CHECK: extfunc @memrefs_compose_with_id(memref<2x2xi8, #map{{[0-9]+}}, #map{{[0-9]+}}>)
+// Identity map should be removed.
+// CHECK: extfunc @memrefs_compose_with_id(memref<2x2xi8, #map{{[0-9]+}}>)
extfunc @memrefs_compose_with_id(memref<2x2xi8, (d0, d1) -> (d0, d1),
(d0, d1) -> (d1, d0)>)
@@ -212,7 +237,7 @@ mlfunc @complex_loops() {
mlfunc @triang_loop(%arg0 : index, %arg1 : memref<?x?xi32>) {
%c = constant 0 : i32 // CHECK: %c0_i32 = constant 0 : i32
for %i0 = 1 to %arg0 { // CHECK: for %i0 = 1 to %arg0 {
- for %i1 = %i0 to %arg0 { // CHECK: for %i1 = #map1(%i0) to %arg0 {
+ for %i1 = %i0 to %arg0 { // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to %arg0 {
store %c, %arg1[%i0, %i1] : memref<?x?xi32> // CHECK: store %c0_i32, %arg1[%i0, %i1]
} // CHECK: }
} // CHECK: }
@@ -235,7 +260,7 @@ mlfunc @loop_bounds(%N : index) {
%s = "foo"(%N) : (index) -> index
// CHECK: for %i0 = %0 to %arg0
for %i = %s to %N {
- // CHECK: for %i1 = #map1(%i0) to 0
+ // CHECK: for %i1 = #map{{[0-9]+}}(%i0) to 0
for %j = %i to 0 step 1 {
// CHECK: %1 = affine_apply #map{{.*}}(%i0, %i1)[%0]
%w = affine_apply(d0, d1)[s0] -> (d0+d1, s0+1) (%i, %j) [%s]
@@ -261,7 +286,7 @@ mlfunc @loop_bounds(%N : index) {
// CHECK-LABEL: mlfunc @ifstmt(%arg0 : index) {
mlfunc @ifstmt(%N: index) {
%c = constant 200 : index // CHECK %c200 = constant 200
- for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
+ for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] {
%x = constant 1 : i32
// CHECK: %c1_i32 = constant 1 : i32
@@ -282,7 +307,7 @@ mlfunc @ifstmt(%N: index) {
// CHECK-LABEL: mlfunc @simple_ifstmt(%arg0 : index) {
mlfunc @simple_ifstmt(%N: index) {
%c = constant 200 : index // CHECK %c200 = constant 200
- for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
+ for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] {
%x = constant 1 : i32
// CHECK: %c1_i32 = constant 1 : i32
OpenPOWER on IntegriCloud