diff options
Diffstat (limited to 'mlir/lib/IR/AffineExpr.cpp')
| -rw-r--r-- | mlir/lib/IR/AffineExpr.cpp | 43 |
1 files changed, 35 insertions, 8 deletions
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 95ebc0a1cbe..19599a8a62e 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -160,7 +160,7 @@ bool AffineExpr::isPureAffine() const { } // Returns the greatest known integral divisor of this affine expression. -uint64_t AffineExpr::getLargestKnownDivisor() const { +int64_t AffineExpr::getLargestKnownDivisor() const { AffineBinaryOpExpr binExpr(nullptr); switch (getKind()) { case AffineExprKind::SymbolId: @@ -444,6 +444,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); + // mlir floordiv by zero or negative numbers is undefined and preserved as is. if (!rhsConst || rhsConst.getValue() < 1) return nullptr; @@ -453,18 +454,32 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { // Fold floordiv of a multiply with a constant that is a multiple of the // divisor. Eg: (i * 128) floordiv 64 = i * 2. - if (rhsConst.getValue() == 1) + if (rhsConst == 1) return lhs; + // Simplify (expr * const) floordiv divConst when expr is known to be a + // multiple of divConst. auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); if (lBin && lBin.getKind() == AffineExprKind::Mul) { if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { - // rhsConst is known to be positive if a constant. + // rhsConst is known to be a positive constant. if (lrhs.getValue() % rhsConst.getValue() == 0) return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); } } + // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is + // known to be a multiple of divConst. + if (lBin && lBin.getKind() == AffineExprKind::Add) { + int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); + int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); + // rhsConst is known to be a positive constant. + if (llhsDiv % rhsConst.getValue() == 0 || + lrhsDiv % rhsConst.getValue() == 0) + return lBin.getLHS().floorDiv(rhsConst.getValue()) + + lBin.getRHS().floorDiv(rhsConst.getValue()); + } + return nullptr; } @@ -497,10 +512,12 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { if (rhsConst.getValue() == 1) return lhs; + // Simplify (expr * const) ceildiv divConst when const is known to be a + // multiple of divConst. auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); if (lBin && lBin.getKind() == AffineExprKind::Mul) { if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) { - // rhsConst is known to be positive if a constant. + // rhsConst is known to be a positive constant. if (lrhs.getValue() % rhsConst.getValue() == 0) return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue()); } @@ -526,6 +543,7 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { auto lhsConst = lhs.dyn_cast<AffineConstantExpr>(); auto rhsConst = rhs.dyn_cast<AffineConstantExpr>(); + // mod w.r.t zero or negative numbers is undefined and preserved as is. if (!rhsConst || rhsConst.getValue() < 1) return nullptr; @@ -539,11 +557,20 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0) return getAffineConstantExpr(0, lhs.getContext()); + // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is + // known to be a multiple of divConst. + auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>(); + if (lBin && lBin.getKind() == AffineExprKind::Add) { + int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor(); + int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor(); + // rhsConst is known to be a positive constant. + if (llhsDiv % rhsConst.getValue() == 0) + return lBin.getRHS() % rhsConst.getValue(); + if (lrhsDiv % rhsConst.getValue() == 0) + return lBin.getLHS() % rhsConst.getValue(); + } + return nullptr; - // TODO(bondhugula): In general, this can be simplified more by using the GCD - // test, or in general using quantifier elimination (add two new variables q - // and r, and eliminate all variables from the linear system other than r. All - // of this can be done through mlir/Analysis/'s FlatAffineConstraints. } AffineExpr AffineExpr::operator%(uint64_t v) const { |

