summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SDBM/SDBMExpr.cpp')
-rw-r--r--mlir/lib/Dialect/SDBM/SDBMExpr.cpp647
1 files changed, 647 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
new file mode 100644
index 00000000000..a174c8c84f2
--- /dev/null
+++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
@@ -0,0 +1,647 @@
+//===- SDBMExpr.cpp - MLIR SDBM Expression implementation -----------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) expression is a constant expression,
+// an identifier, a binary expression with constant RHS and +, stripe operators
+// or a difference expression between two identifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SDBM/SDBMExpr.h"
+#include "SDBMExprDetail.h"
+#include "mlir/Dialect/SDBM/SDBMDialect.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+/// A simple compositional matcher for AffineExpr
+///
+/// Example usage:
+///
+/// ```c++
+/// AffineExprMatcher x, C, m;
+/// AffineExprMatcher pattern1 = ((x % C) * m) + x;
+/// AffineExprMatcher pattern2 = x + ((x % C) * m);
+/// if (pattern1.match(expr) || pattern2.match(expr)) {
+/// ...
+/// }
+/// ```
+class AffineExprMatcherStorage;
+class AffineExprMatcher {
+public:
+ AffineExprMatcher();
+ AffineExprMatcher(const AffineExprMatcher &other);
+
+ AffineExprMatcher operator+(AffineExprMatcher other) {
+ return AffineExprMatcher(AffineExprKind::Add, *this, other);
+ }
+ AffineExprMatcher operator*(AffineExprMatcher other) {
+ return AffineExprMatcher(AffineExprKind::Mul, *this, other);
+ }
+ AffineExprMatcher floorDiv(AffineExprMatcher other) {
+ return AffineExprMatcher(AffineExprKind::FloorDiv, *this, other);
+ }
+ AffineExprMatcher ceilDiv(AffineExprMatcher other) {
+ return AffineExprMatcher(AffineExprKind::CeilDiv, *this, other);
+ }
+ AffineExprMatcher operator%(AffineExprMatcher other) {
+ return AffineExprMatcher(AffineExprKind::Mod, *this, other);
+ }
+
+ AffineExpr match(AffineExpr expr);
+ AffineExpr matched();
+ Optional<int> getMatchedConstantValue();
+
+private:
+ AffineExprMatcher(AffineExprKind k, AffineExprMatcher a, AffineExprMatcher b);
+ AffineExprKind kind; // only used to match in binary op cases.
+ // A shared_ptr allows multiple references to same matcher storage without
+ // worrying about ownership or dealing with an arena. To be cleaned up if we
+ // go with this.
+ std::shared_ptr<AffineExprMatcherStorage> storage;
+};
+
+class AffineExprMatcherStorage {
+public:
+ AffineExprMatcherStorage() {}
+ AffineExprMatcherStorage(const AffineExprMatcherStorage &other)
+ : subExprs(other.subExprs.begin(), other.subExprs.end()),
+ matched(other.matched) {}
+ AffineExprMatcherStorage(ArrayRef<AffineExprMatcher> exprs)
+ : subExprs(exprs.begin(), exprs.end()) {}
+ AffineExprMatcherStorage(AffineExprMatcher &a, AffineExprMatcher &b)
+ : subExprs({a, b}) {}
+ llvm::SmallVector<AffineExprMatcher, 0> subExprs;
+ AffineExpr matched;
+};
+} // namespace
+
+AffineExprMatcher::AffineExprMatcher()
+ : kind(AffineExprKind::Constant), storage(new AffineExprMatcherStorage()) {}
+
+AffineExprMatcher::AffineExprMatcher(const AffineExprMatcher &other)
+ : kind(other.kind), storage(other.storage) {}
+
+Optional<int> AffineExprMatcher::getMatchedConstantValue() {
+ if (auto cst = storage->matched.dyn_cast<AffineConstantExpr>())
+ return cst.getValue();
+ return None;
+}
+
+AffineExpr AffineExprMatcher::match(AffineExpr expr) {
+ if (kind > AffineExprKind::LAST_AFFINE_BINARY_OP) {
+ if (storage->matched)
+ if (storage->matched != expr)
+ return AffineExpr();
+ storage->matched = expr;
+ return storage->matched;
+ }
+ if (kind != expr.getKind()) {
+ return AffineExpr();
+ }
+ if (auto bin = expr.dyn_cast<AffineBinaryOpExpr>()) {
+ if (!storage->subExprs.empty() &&
+ !storage->subExprs[0].match(bin.getLHS())) {
+ return AffineExpr();
+ }
+ if (!storage->subExprs.empty() &&
+ !storage->subExprs[1].match(bin.getRHS())) {
+ return AffineExpr();
+ }
+ if (storage->matched)
+ if (storage->matched != expr)
+ return AffineExpr();
+ storage->matched = expr;
+ return storage->matched;
+ }
+ llvm_unreachable("binary expected");
+}
+
+AffineExpr AffineExprMatcher::matched() { return storage->matched; }
+
+AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
+ AffineExprMatcher b)
+ : kind(k), storage(new AffineExprMatcherStorage(a, b)) {
+ storage->subExprs.push_back(a);
+ storage->subExprs.push_back(b);
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMExpr
+//===----------------------------------------------------------------------===//
+
+SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
+
+MLIRContext *SDBMExpr::getContext() const {
+ return impl->dialect->getContext();
+}
+
+SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
+
+void SDBMExpr::print(raw_ostream &os) const {
+ struct Printer : public SDBMVisitor<Printer> {
+ Printer(raw_ostream &ostream) : prn(ostream) {}
+
+ void visitSum(SDBMSumExpr expr) {
+ visitVarying(expr.getLHS());
+ prn << " + ";
+ visitConstant(expr.getRHS());
+ }
+ void visitDiff(SDBMDiffExpr expr) {
+ visitPositive(expr.getLHS());
+ prn << " - ";
+ visitPositive(expr.getRHS());
+ }
+ void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
+ void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
+ void visitStripe(SDBMStripeExpr expr) {
+ visitPositive(expr.getVar());
+ prn << " # ";
+ visitConstant(expr.getStripeFactor());
+ }
+ void visitNeg(SDBMNegExpr expr) {
+ prn << '-';
+ visitPositive(expr.getVar());
+ }
+ void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
+
+ raw_ostream &prn;
+ };
+ Printer printer(os);
+ printer.visit(*this);
+}
+
+void SDBMExpr::dump() const {
+ print(llvm::errs());
+ llvm::errs() << '\n';
+}
+
+namespace {
+// Helper class to perform negation of an SDBM expression.
+struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
+ // Any positive expression is wrapped into a negation expression.
+ // -(x) = -x
+ SDBMExpr visitPositive(SDBMPositiveExpr expr) {
+ return SDBMNegExpr::get(expr);
+ }
+ // A negation expression is unwrapped.
+ // -(-x) = x
+ SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
+ // The value of the constant is negated.
+ SDBMExpr visitConstant(SDBMConstantExpr expr) {
+ return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
+ }
+ // Both terms of the sum are negated recursively.
+ SDBMExpr visitSum(SDBMSumExpr expr) {
+ return SDBMSumExpr::get(visit(expr.getLHS()).cast<SDBMVaryingExpr>(),
+ visit(expr.getRHS()).cast<SDBMConstantExpr>());
+ }
+ // Terms of a difference are interchanged.
+ // -(x - y) = y - x
+ SDBMExpr visitDiff(SDBMDiffExpr expr) {
+ return SDBMDiffExpr::get(expr.getRHS(), expr.getLHS());
+ }
+};
+} // namespace
+
+SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
+
+//===----------------------------------------------------------------------===//
+// SDBMSumExpr
+//===----------------------------------------------------------------------===//
+
+SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
+ assert(lhs && "expected SDBM variable expression");
+ assert(rhs && "expected SDBM constant");
+
+ // If LHS of a sum is another sum, fold the constant RHS parts.
+ if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
+ lhs = lhsSum.getLHS();
+ rhs = SDBMConstantExpr::get(rhs.getDialect(),
+ rhs.getValue() + lhsSum.getRHS().getValue());
+ }
+
+ StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
+ return uniquer.get<detail::SDBMBinaryExprStorage>(
+ /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
+}
+
+SDBMVaryingExpr SDBMSumExpr::getLHS() const {
+ return static_cast<ImplType *>(impl)->lhs;
+}
+
+SDBMConstantExpr SDBMSumExpr::getRHS() const {
+ return static_cast<ImplType *>(impl)->rhs;
+}
+
+AffineExpr SDBMExpr::getAsAffineExpr() const {
+ struct Converter : public SDBMVisitor<Converter, AffineExpr> {
+ AffineExpr visitSum(SDBMSumExpr expr) {
+ AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ return lhs + rhs;
+ }
+
+ AffineExpr visitStripe(SDBMStripeExpr expr) {
+ AffineExpr lhs = visit(expr.getVar()),
+ rhs = visit(expr.getStripeFactor());
+ return lhs - (lhs % rhs);
+ }
+
+ AffineExpr visitDiff(SDBMDiffExpr expr) {
+ AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ return lhs - rhs;
+ }
+
+ AffineExpr visitDim(SDBMDimExpr expr) {
+ return getAffineDimExpr(expr.getPosition(), expr.getContext());
+ }
+
+ AffineExpr visitSymbol(SDBMSymbolExpr expr) {
+ return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
+ }
+
+ AffineExpr visitNeg(SDBMNegExpr expr) {
+ return getAffineBinaryOpExpr(AffineExprKind::Mul,
+ getAffineConstantExpr(-1, expr.getContext()),
+ visit(expr.getVar()));
+ }
+
+ AffineExpr visitConstant(SDBMConstantExpr expr) {
+ return getAffineConstantExpr(expr.getValue(), expr.getContext());
+ }
+ } converter;
+ return converter.visit(*this);
+}
+
+Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
+ struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
+ SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
+ // Attempt to recover a stripe expression. Because AffineExprs don't have
+ // a first-class difference kind, we check for both x + -1 * (x mod C) and
+ // -1 * (x mod C) + x cases.
+ AffineExprMatcher x, C, m;
+ AffineExprMatcher pattern1 = ((x % C) * m) + x;
+ AffineExprMatcher pattern2 = x + ((x % C) * m);
+ if ((pattern1.match(expr) && m.getMatchedConstantValue() == -1) ||
+ (pattern2.match(expr) && m.getMatchedConstantValue() == -1)) {
+ if (auto convertedLHS = visit(x.matched())) {
+ // TODO(ntv): return convertedLHS.stripe(C);
+ return SDBMStripeExpr::get(
+ convertedLHS.cast<SDBMPositiveExpr>(),
+ visit(C.matched()).cast<SDBMConstantExpr>());
+ }
+ }
+ auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ if (!lhs || !rhs)
+ return {};
+
+ // In a "add" AffineExpr, the constant always appears on the right. If
+ // there were two constants, they would have been folded away.
+ assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
+ auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+
+ // SDBM accepts LHS variables and RHS constants in a sum.
+ auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>();
+ auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>();
+ if (rhsConstant && lhsVar)
+ return SDBMSumExpr::get(lhsVar, rhsConstant);
+
+ // The sum of a negated variable and a non-negated variable is a
+ // difference, supported as a special kind in SDBM. Because AffineExprs
+ // don't have first-class difference kind, check both LHS and RHS for
+ // negation.
+ auto lhsPos = lhs.dyn_cast<SDBMPositiveExpr>();
+ auto rhsPos = rhs.dyn_cast<SDBMPositiveExpr>();
+ auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
+ auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
+ if (lhsNeg && rhsVar)
+ return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
+ if (rhsNeg && lhsVar)
+ return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
+
+ // Other cases don't fit into SDBM.
+ return {};
+ }
+
+ SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
+ // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
+ AffineExprMatcher x, C;
+ AffineExprMatcher pattern = (x.floorDiv(C)) * C;
+ if (pattern.match(expr)) {
+ if (SDBMExpr converted = visit(x.matched())) {
+ if (auto varConverted = converted.dyn_cast<SDBMPositiveExpr>())
+ // TODO(ntv): return varConverted.stripe(C.getConstantValue());
+ return SDBMStripeExpr::get(
+ varConverted,
+ SDBMConstantExpr::get(dialect,
+ C.getMatchedConstantValue().getValue()));
+ }
+ }
+
+ auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ if (!lhs || !rhs)
+ return {};
+
+ // In a "mul" AffineExpr, the constant always appears on the right. If
+ // there were two constants, they would have been folded away.
+ assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
+ auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+ if (!rhsConstant)
+ return {};
+
+ // The only supported "multiplication" expression is an SDBM is dimension
+ // negation, that is a product of dimension and constant -1.
+ auto lhsVar = lhs.dyn_cast<SDBMPositiveExpr>();
+ if (lhsVar && rhsConstant.getValue() == -1)
+ return SDBMNegExpr::get(lhsVar);
+
+ // Other multiplications are not allowed in SDBM.
+ return {};
+ }
+
+ SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
+ auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ if (!lhs || !rhs)
+ return {};
+
+ // 'mod' can only be converted to SDBM if its LHS is a variable
+ // and its RHS is a constant. Then it `x mod c = x - x stripe c`.
+ auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+ auto lhsVar = rhs.dyn_cast<SDBMPositiveExpr>();
+ if (!lhsVar || !rhsConstant)
+ return {};
+ return SDBMDiffExpr::get(lhsVar,
+ SDBMStripeExpr::get(lhsVar, rhsConstant));
+ }
+
+ // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
+ SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
+ SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
+
+ // Dimensions, symbols and constants are converted trivially.
+ SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
+ return SDBMConstantExpr::get(dialect, expr.getValue());
+ }
+ SDBMExpr visitDimExpr(AffineDimExpr expr) {
+ return SDBMDimExpr::get(dialect, expr.getPosition());
+ }
+ SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
+ return SDBMSymbolExpr::get(dialect, expr.getPosition());
+ }
+
+ SDBMDialect *dialect;
+ } converter;
+ converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
+
+ if (auto result = converter.visit(affine))
+ return result;
+ return None;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMDiffExpr
+//===----------------------------------------------------------------------===//
+
+SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
+ assert(lhs && "expected SDBM dimension");
+ assert(rhs && "expected SDBM dimension");
+
+ StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
+ return uniquer.get<detail::SDBMDiffExprStorage>(
+ /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
+}
+
+SDBMPositiveExpr SDBMDiffExpr::getLHS() const {
+ return static_cast<ImplType *>(impl)->lhs;
+}
+
+SDBMPositiveExpr SDBMDiffExpr::getRHS() const {
+ return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMStripeExpr
+//===----------------------------------------------------------------------===//
+
+SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
+ SDBMConstantExpr stripeFactor) {
+ assert(var && "expected SDBM variable expression");
+ assert(stripeFactor && "expected non-null stripe factor");
+ if (stripeFactor.getValue() <= 0)
+ llvm::report_fatal_error("non-positive stripe factor");
+
+ StorageUniquer &uniquer = var.getDialect()->getUniquer();
+ return uniquer.get<detail::SDBMBinaryExprStorage>(
+ /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
+ stripeFactor);
+}
+
+SDBMPositiveExpr SDBMStripeExpr::getVar() const {
+ if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
+ return lhs.cast<SDBMPositiveExpr>();
+ return {};
+}
+
+SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
+ return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMInputExpr
+//===----------------------------------------------------------------------===//
+
+unsigned SDBMInputExpr::getPosition() const {
+ return static_cast<ImplType *>(impl)->position;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMDimExpr
+//===----------------------------------------------------------------------===//
+
+SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
+ assert(dialect && "expected non-null dialect");
+
+ auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) {
+ storage->dialect = dialect;
+ };
+
+ StorageUniquer &uniquer = dialect->getUniquer();
+ return uniquer.get<detail::SDBMPositiveExprStorage>(
+ assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMSymbolExpr
+//===----------------------------------------------------------------------===//
+
+SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
+ assert(dialect && "expected non-null dialect");
+
+ auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) {
+ storage->dialect = dialect;
+ };
+
+ StorageUniquer &uniquer = dialect->getUniquer();
+ return uniquer.get<detail::SDBMPositiveExprStorage>(
+ assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMConstantExpr
+//===----------------------------------------------------------------------===//
+
+SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
+ assert(dialect && "expected non-null dialect");
+
+ auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
+ storage->dialect = dialect;
+ };
+
+ StorageUniquer &uniquer = dialect->getUniquer();
+ return uniquer.get<detail::SDBMConstantExprStorage>(
+ assignCtx, static_cast<unsigned>(SDBMExprKind::Constant), value);
+}
+
+int64_t SDBMConstantExpr::getValue() const {
+ return static_cast<ImplType *>(impl)->constant;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMNegExpr
+//===----------------------------------------------------------------------===//
+
+SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
+ assert(var && "expected non-null SDBM variable expression");
+
+ StorageUniquer &uniquer = var.getDialect()->getUniquer();
+ return uniquer.get<detail::SDBMNegExprStorage>(
+ /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
+}
+
+SDBMPositiveExpr SDBMNegExpr::getVar() const {
+ return static_cast<ImplType *>(impl)->dim;
+}
+
+namespace mlir {
+namespace ops_assertions {
+
+SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
+ // If one of the operands is a negation, take a difference rather than a sum.
+ auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
+ auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
+ assert(!(lhsNeg && rhsNeg) && "a sum of negated expressions is a negation of "
+ "a sum of variables and not a correct SDBM");
+ if (lhsNeg)
+ return rhs - lhsNeg.getVar();
+ if (rhsNeg)
+ return lhs - rhsNeg.getVar();
+
+ // If LHS is a constant and RHS is not, swap the order to get into a supported
+ // sum case. From now on, RHS must be a constant.
+ auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
+ auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+ if (!rhsConstant && lhsConstant) {
+ std::swap(lhs, rhs);
+ std::swap(lhsConstant, rhsConstant);
+ }
+ assert(rhsConstant && "at least one operand must be a constant");
+
+ // If LHS is another sum, first compute the sum of its variable
+ // part with the other argument and then add the constant part to enable
+ // constant folding (the variable part may, e.g., be a negation that requires
+ // to enter this function again).
+ auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
+ if (lhsSum)
+ return lhsSum.getLHS() +
+ (lhsSum.getRHS().getValue() + rhsConstant.getValue());
+
+ // Constant-fold if LHS is a constant.
+ if (lhsConstant)
+ return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
+ rhsConstant.getValue());
+
+ // Fold x + 0 == x.
+ if (rhsConstant.getValue() == 0)
+ return lhs;
+
+ return SDBMSumExpr::get(lhs.cast<SDBMVaryingExpr>(),
+ rhs.cast<SDBMConstantExpr>());
+}
+
+SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
+ // Fold x - x == 0.
+ if (lhs == rhs)
+ return SDBMConstantExpr::get(lhs.getDialect(), 0);
+
+ // LHS and RHS may be constants.
+ auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
+ auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+
+ // Constant fold if both LHS and RHS are constants.
+ if (lhsConstant && rhsConstant)
+ return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
+ rhsConstant.getValue());
+
+ // Replace a difference with a sum with a negated value if one of LHS and RHS
+ // is a constant:
+ // x - C == x + (-C);
+ // C - x == -x + C.
+ // This calls into operator+ for further simplification.
+ if (rhsConstant)
+ return lhs + (-rhsConstant);
+ if (lhsConstant)
+ return -rhs + lhsConstant;
+
+ // Hoist constant factors outside the difference if any of sides is a sum:
+ // (x + A) - (y - B) == x - y + (A - B).
+ // If either LHS or RHS is a sum, collect the constant values separately and
+ // update LHS and RHS to point to the variable part of the sum.
+ auto lhsSum = lhs.dyn_cast<SDBMSumExpr>();
+ auto rhsSum = rhs.dyn_cast<SDBMSumExpr>();
+ int64_t value = 0;
+ if (lhsSum) {
+ value += lhsSum.getRHS().getValue();
+ lhs = lhsSum.getLHS();
+ }
+ if (rhsSum) {
+ value -= rhsSum.getRHS().getValue();
+ rhs = rhsSum.getLHS();
+ }
+
+ // This calls into operator+ for futher simplification in case value == 0.
+ return SDBMDiffExpr::get(lhs.cast<SDBMPositiveExpr>(),
+ rhs.cast<SDBMPositiveExpr>()) +
+ value;
+}
+
+SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) {
+ auto constantFactor = factor.cast<SDBMConstantExpr>();
+ assert(constantFactor.getValue() > 0 && "non-positive stripe");
+
+ // Fold x # 1 = x.
+ if (constantFactor.getValue() == 1)
+ return expr;
+
+ return SDBMStripeExpr::get(expr.cast<SDBMPositiveExpr>(), constantFactor);
+}
+
+} // namespace ops_assertions
+} // namespace mlir
OpenPOWER on IntegriCloud