summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Analysis/AffineAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Analysis/AffineAnalysis.cpp')
-rw-r--r--mlir/lib/Analysis/AffineAnalysis.cpp80
1 files changed, 60 insertions, 20 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 12af803fdad..f01735f26e1 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -247,22 +247,6 @@ public:
eq[getConstantIndex()] = expr.getValue();
}
- // Simplify the affine expression by flattening it and reconstructing it.
- AffineExpr simplifyAffineExpr(AffineExpr expr) {
- // TODO(bondhugula): only pure affine for now. The simplification here can
- // be extended to semi-affine maps in the future.
- if (!expr.isPureAffine())
- return expr;
-
- walkPostOrder(expr);
- ArrayRef<int64_t> flattenedExpr = operandExprStack.back();
- auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
- localExprs, expr.getContext());
- operandExprStack.pop_back();
- assert(operandExprStack.empty());
- return simplifiedExpr;
- }
-
private:
void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) {
assert(operandExprStack.size() >= 2);
@@ -356,10 +340,23 @@ private:
} // end anonymous namespace
+/// Simplify the affine expression by flattening it and reconstructing it.
AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
unsigned numSymbols) {
+ // TODO(bondhugula): only pure affine for now. The simplification here can
+ // be extended to semi-affine maps in the future.
+ if (!expr.isPureAffine())
+ return expr;
+
AffineExprFlattener flattener(numDims, numSymbols, expr.getContext());
- return flattener.simplifyAffineExpr(expr);
+ flattener.walkPostOrder(expr);
+ ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
+ auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
+ flattener.localExprs, expr.getContext());
+ flattener.operandExprStack.pop_back();
+ assert(flattener.operandExprStack.empty());
+
+ return simplifiedExpr;
}
/// Returns the AffineExpr that results from substituting `exprs[i]` into `e`
@@ -416,6 +413,7 @@ static bool getFlattenedAffineExprs(
return true;
}
+ flattenedExprs->clear();
flattenedExprs->reserve(exprs.size());
AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
@@ -428,6 +426,7 @@ static bool getFlattenedAffineExprs(
flattener.walkPostOrder(expr);
}
+ assert(flattener.operandExprStack.size() == exprs.size());
flattenedExprs->insert(flattenedExprs->end(),
flattener.operandExprStack.begin(),
flattener.operandExprStack.end());
@@ -766,11 +765,15 @@ static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
//
// Returns false if any AffineExpr cannot be flattened (due to it being
// semi-affine). Returns true otherwise.
+// TODO(bondhugula): assumes that dependenceDomain doesn't have local
+// variables already. Fix this soon.
static bool
addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
const AffineValueMap &dstAccessMap,
const ValuePositionMap &valuePosMap,
FlatAffineConstraints *dependenceDomain) {
+ if (dependenceDomain->getNumLocalIds() != 0)
+ return false;
AffineMap srcMap = srcAccessMap.getAffineMap();
AffineMap dstMap = dstAccessMap.getAffineMap();
assert(srcMap.getNumResults() == dstMap.getNumResults());
@@ -826,7 +829,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
// Local terms.
for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
eq[numDims + numSymbols + numSrcLocalIds + j] =
- destFlatExpr[dstNumIds + j];
+ -destFlatExpr[dstNumIds + j];
// Set constant term.
eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
@@ -856,8 +859,45 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
// Add equality constraints for any dst symbols defined by constant ops.
addEqForConstOperands(dstOperands);
- // TODO(b/122081337): add srcLocalVarCst, destLocalVarCst to the dependence
- // domain.
+ // By construction (see flattener), local var constraints will not have any
+ // equalities.
+ assert(srcLocalVarCst.getNumEqualities() == 0 &&
+ destLocalVarCst.getNumEqualities() == 0);
+ // Add inequalities from srcLocalVarCst and destLocalVarCst into the
+ // dependence domain.
+ SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
+ for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
+ std::fill(ineq.begin(), ineq.end(), 0);
+
+ // Set identifier coefficients from src local var constraints.
+ for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
+ ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
+ srcLocalVarCst.atIneq(r, j);
+ // Local terms.
+ for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
+ ineq[numDims + numSymbols + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
+ // Set constant term.
+ ineq[ineq.size() - 1] =
+ srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
+ dependenceDomain->addInequality(ineq);
+ }
+
+ for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
+ std::fill(ineq.begin(), ineq.end(), 0);
+ // Set identifier coefficients from dest local var constraints.
+ for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
+ ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
+ destLocalVarCst.atIneq(r, j);
+ // Local terms.
+ for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
+ ineq[numDims + numSymbols + numSrcLocalIds + j] =
+ destLocalVarCst.atIneq(r, dstNumIds + j);
+ // Set constant term.
+ ineq[ineq.size() - 1] =
+ destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);
+
+ dependenceDomain->addInequality(ineq);
+ }
return true;
}
OpenPOWER on IntegriCloud