summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/AffineExpr.h5
-rw-r--r--mlir/lib/IR/AffineExpr.cpp43
-rw-r--r--mlir/test/IR/affine-map.mlir8
-rw-r--r--mlir/test/Transforms/Vectorize/compose_maps.mlir4
-rw-r--r--mlir/test/Transforms/unroll.mlir2
5 files changed, 48 insertions, 14 deletions
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index cca7eac536f..928ced5289a 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -114,8 +114,9 @@ public:
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
bool isPureAffine() const;
- /// Returns the greatest known integral divisor of this affine expression.
- uint64_t getLargestKnownDivisor() const;
+ /// Returns the greatest known integral divisor of this affine expression. The
+ /// result is always positive.
+ int64_t getLargestKnownDivisor() const;
/// Return true if the affine expression is a multiple of 'factor'.
bool isMultipleOf(int64_t factor) const;
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 95ebc0a1cbe..19599a8a62e 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -160,7 +160,7 @@ bool AffineExpr::isPureAffine() const {
}
// Returns the greatest known integral divisor of this affine expression.
-uint64_t AffineExpr::getLargestKnownDivisor() const {
+int64_t AffineExpr::getLargestKnownDivisor() const {
AffineBinaryOpExpr binExpr(nullptr);
switch (getKind()) {
case AffineExprKind::SymbolId:
@@ -444,6 +444,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ // mlir floordiv by zero or negative numbers is undefined and preserved as is.
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
@@ -453,18 +454,32 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
// Fold floordiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
- if (rhsConst.getValue() == 1)
+ if (rhsConst == 1)
return lhs;
+ // Simplify (expr * const) floordiv divConst when expr is known to be a
+ // multiple of divConst.
auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
- // rhsConst is known to be positive if a constant.
+ // rhsConst is known to be a positive constant.
if (lrhs.getValue() % rhsConst.getValue() == 0)
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
}
}
+ // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
+ // known to be a multiple of divConst.
+ if (lBin && lBin.getKind() == AffineExprKind::Add) {
+ int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
+ int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
+ // rhsConst is known to be a positive constant.
+ if (llhsDiv % rhsConst.getValue() == 0 ||
+ lrhsDiv % rhsConst.getValue() == 0)
+ return lBin.getLHS().floorDiv(rhsConst.getValue()) +
+ lBin.getRHS().floorDiv(rhsConst.getValue());
+ }
+
return nullptr;
}
@@ -497,10 +512,12 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
if (rhsConst.getValue() == 1)
return lhs;
+ // Simplify (expr * const) ceildiv divConst when const is known to be a
+ // multiple of divConst.
auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
- // rhsConst is known to be positive if a constant.
+ // rhsConst is known to be a positive constant.
if (lrhs.getValue() % rhsConst.getValue() == 0)
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
}
@@ -526,6 +543,7 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+ // mod w.r.t zero or negative numbers is undefined and preserved as is.
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
@@ -539,11 +557,20 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
return getAffineConstantExpr(0, lhs.getContext());
+ // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
+ // known to be a multiple of divConst.
+ auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+ if (lBin && lBin.getKind() == AffineExprKind::Add) {
+ int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
+ int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
+ // rhsConst is known to be a positive constant.
+ if (llhsDiv % rhsConst.getValue() == 0)
+ return lBin.getRHS() % rhsConst.getValue();
+ if (lrhsDiv % rhsConst.getValue() == 0)
+ return lBin.getLHS() % rhsConst.getValue();
+ }
+
return nullptr;
- // TODO(bondhugula): In general, this can be simplified more by using the GCD
- // test, or in general using quantifier elimination (add two new variables q
- // and r, and eliminate all variables from the linear system other than r. All
- // of this can be done through mlir/Analysis/'s FlatAffineConstraints.
}
AffineExpr AffineExpr::operator%(uint64_t v) const {
diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir
index d867d5113b6..ebbd4735635 100644
--- a/mlir/test/IR/affine-map.mlir
+++ b/mlir/test/IR/affine-map.mlir
@@ -156,7 +156,7 @@
#map48 = (i, j, k) -> (i * 64 floordiv 64, i * 512 floordiv 128, 4 * j mod 4, 4*j*4 mod 8)
// Simplifications for mod using known GCD's of the LHS expr.
-// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (0, 0, 0, (d0 * 4 + 3) mod 2)
+// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (0, 0, 0, 1)
#map49 = (i, j)[s0] -> ( (i * 4 + 8) mod 4, 32 * j * s0 * 8 mod 256, (4*i + (j * (s0 * 2))) mod 2, (4*i + 3) mod 2)
// Floordiv, ceildiv divide by one.
@@ -180,6 +180,9 @@
// CHECK: #map{{[0-9]+}} = () -> ()
#map55 = () -> ()
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0, d0 * 2 + d1 * 4 + 2, 1, 2, (d0 * 4) mod 8)
+#map56 = (d0, d1) -> ((4*d0 + 2) floordiv 4, (4*d0 + 8*d1 + 5) floordiv 2, (2*d0 + 4*d1 + 3) mod 2, (3*d0 - 4) mod 3, (4*d0 + 8*d1) mod 8)
+
// Single identity maps are removed.
// CHECK: func @f0(memref<2x4xi8, 1>)
func @f0(memref<2x4xi8, #map0, 1>)
@@ -355,3 +358,6 @@ func @f54(memref<10xi32, #map54>)
// CHECK: "foo.op"() {map = #map{{[0-9]+}}} : () -> ()
"foo.op"() {map = #map55} : () -> ()
+
+// CHECK: func @f56(memref<1x1xi8, #map{{[0-9]+}}>)
+func @f56(memref<1x1xi8, #map56>)
diff --git a/mlir/test/Transforms/Vectorize/compose_maps.mlir b/mlir/test/Transforms/Vectorize/compose_maps.mlir
index a8afbec9eff..f1826f440f2 100644
--- a/mlir/test/Transforms/Vectorize/compose_maps.mlir
+++ b/mlir/test/Transforms/Vectorize/compose_maps.mlir
@@ -78,7 +78,7 @@ func @simple5c() {
}
func @simple5d() {
- // CHECK: Composed map: (d0) -> ((d0 * 4 + 24) floordiv 3)
+ // CHECK: Composed map: (d0) -> ((d0 * 4) floordiv 3 + 8)
"test_affine_map"() { affine_map = (d0) -> (d0 - 1) } : () -> ()
"test_affine_map"() { affine_map = (d0) -> (d0 + 7) } : () -> ()
"test_affine_map"() { affine_map = (d0) -> (d0 * 4) } : () -> ()
@@ -128,4 +128,4 @@ func @multi_symbols() {
"test_affine_map"() { affine_map = (d0)[s0] -> (d0 + s0, d0 - s0) } : () -> ()
"test_affine_map"() { affine_map = (d0, d1)[s0, s1] -> (d0 + 1 + s1, d1 - 1 - s0) } : () -> ()
return
-} \ No newline at end of file
+}
diff --git a/mlir/test/Transforms/unroll.mlir b/mlir/test/Transforms/unroll.mlir
index 208df58c84e..da2a5e59bc9 100644
--- a/mlir/test/Transforms/unroll.mlir
+++ b/mlir/test/Transforms/unroll.mlir
@@ -21,7 +21,7 @@
// UNROLL-BY-4-DAG: [[MAP5:#map[0-9]+]] = (d0)[s0] -> (d0 + s0 + 1)
// UNROLL-BY-4-DAG: [[MAP6:#map[0-9]+]] = (d0, d1) -> (d0 * 16 + d1)
// UNROLL-BY-4-DAG: [[MAP11:#map[0-9]+]] = (d0) -> (d0)
-// UNROLL-BY-4-DAG: [[MAP_TRIP_COUNT_MULTIPLE_FOUR:#map[0-9]+]] = ()[s0, s1, s2] -> (s0 + ((-s0 + s1) floordiv 4) * 4, s0 + ((-s0 + s2) floordiv 4) * 4, s0 + ((-s0 + 1024) floordiv 4) * 4)
+// UNROLL-BY-4-DAG: [[MAP_TRIP_COUNT_MULTIPLE_FOUR:#map[0-9]+]] = ()[s0, s1, s2] -> (s0 + ((-s0 + s1) floordiv 4) * 4, s0 + ((-s0 + s2) floordiv 4) * 4, s0 + ((-s0) floordiv 4) * 4 + 1024)
// UNROLL-FULL-LABEL: func @loop_nest_simplest() {
func @loop_nest_simplest() {
OpenPOWER on IntegriCloud