diff options
Diffstat (limited to 'mlir/lib/Analysis/AffineAnalysis.cpp')
| -rw-r--r-- | mlir/lib/Analysis/AffineAnalysis.cpp | 80 |
1 files changed, 60 insertions, 20 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 12af803fdad..f01735f26e1 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -247,22 +247,6 @@ public: eq[getConstantIndex()] = expr.getValue(); } - // Simplify the affine expression by flattening it and reconstructing it. - AffineExpr simplifyAffineExpr(AffineExpr expr) { - // TODO(bondhugula): only pure affine for now. The simplification here can - // be extended to semi-affine maps in the future. - if (!expr.isPureAffine()) - return expr; - - walkPostOrder(expr); - ArrayRef<int64_t> flattenedExpr = operandExprStack.back(); - auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols, - localExprs, expr.getContext()); - operandExprStack.pop_back(); - assert(operandExprStack.empty()); - return simplifiedExpr; - } - private: void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) { assert(operandExprStack.size() >= 2); @@ -356,10 +340,23 @@ private: } // end anonymous namespace +/// Simplify the affine expression by flattening it and reconstructing it. AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols) { + // TODO(bondhugula): only pure affine for now. The simplification here can + // be extended to semi-affine maps in the future. + if (!expr.isPureAffine()) + return expr; + AffineExprFlattener flattener(numDims, numSymbols, expr.getContext()); - return flattener.simplifyAffineExpr(expr); + flattener.walkPostOrder(expr); + ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back(); + auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols, + flattener.localExprs, expr.getContext()); + flattener.operandExprStack.pop_back(); + assert(flattener.operandExprStack.empty()); + + return simplifiedExpr; } /// Returns the AffineExpr that results from substituting `exprs[i]` into `e` @@ -416,6 +413,7 @@ static bool getFlattenedAffineExprs( return true; } + flattenedExprs->clear(); flattenedExprs->reserve(exprs.size()); AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext()); @@ -428,6 +426,7 @@ static bool getFlattenedAffineExprs( flattener.walkPostOrder(expr); } + assert(flattener.operandExprStack.size() == exprs.size()); flattenedExprs->insert(flattenedExprs->end(), flattener.operandExprStack.begin(), flattener.operandExprStack.end()); @@ -766,11 +765,15 @@ static void addDomainConstraints(const FlatAffineConstraints &srcDomain, // // Returns false if any AffineExpr cannot be flattened (due to it being // semi-affine). Returns true otherwise. +// TODO(bondhugula): assumes that dependenceDomain doesn't have local +// variables already. Fix this soon. static bool addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, const AffineValueMap &dstAccessMap, const ValuePositionMap &valuePosMap, FlatAffineConstraints *dependenceDomain) { + if (dependenceDomain->getNumLocalIds() != 0) + return false; AffineMap srcMap = srcAccessMap.getAffineMap(); AffineMap dstMap = dstAccessMap.getAffineMap(); assert(srcMap.getNumResults() == dstMap.getNumResults()); @@ -826,7 +829,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, // Local terms. for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) eq[numDims + numSymbols + numSrcLocalIds + j] = - destFlatExpr[dstNumIds + j]; + -destFlatExpr[dstNumIds + j]; // Set constant term. eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1]; @@ -856,8 +859,45 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, // Add equality constraints for any dst symbols defined by constant ops. addEqForConstOperands(dstOperands); - // TODO(b/122081337): add srcLocalVarCst, destLocalVarCst to the dependence - // domain. + // By construction (see flattener), local var constraints will not have any + // equalities. + assert(srcLocalVarCst.getNumEqualities() == 0 && + destLocalVarCst.getNumEqualities() == 0); + // Add inequalities from srcLocalVarCst and destLocalVarCst into the + // dependence domain. + SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols()); + for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) { + std::fill(ineq.begin(), ineq.end(), 0); + + // Set identifier coefficients from src local var constraints. + for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) + ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = + srcLocalVarCst.atIneq(r, j); + // Local terms. + for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) + ineq[numDims + numSymbols + j] = srcLocalVarCst.atIneq(r, srcNumIds + j); + // Set constant term. + ineq[ineq.size() - 1] = + srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1); + dependenceDomain->addInequality(ineq); + } + + for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) { + std::fill(ineq.begin(), ineq.end(), 0); + // Set identifier coefficients from dest local var constraints. + for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) + ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] = + destLocalVarCst.atIneq(r, j); + // Local terms. + for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) + ineq[numDims + numSymbols + numSrcLocalIds + j] = + destLocalVarCst.atIneq(r, dstNumIds + j); + // Set constant term. + ineq[ineq.size() - 1] = + destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1); + + dependenceDomain->addInequality(ineq); + } return true; } |

