Diffstat (limited to 'mlir/lib/Analysis')
-rw-r--r--  mlir/lib/Analysis/AffineAnalysis.cpp             |  886
-rw-r--r--  mlir/lib/Analysis/AffineStructures.cpp           | 2854
-rw-r--r--  mlir/lib/Analysis/CMakeLists.txt                 |   29
-rw-r--r--  mlir/lib/Analysis/CallGraph.cpp                  |  256
-rw-r--r--  mlir/lib/Analysis/Dominance.cpp                  |  171
-rw-r--r--  mlir/lib/Analysis/InferTypeOpInterface.cpp       |   22
-rw-r--r--  mlir/lib/Analysis/Liveness.cpp                   |  373
-rw-r--r--  mlir/lib/Analysis/LoopAnalysis.cpp               |  388
-rw-r--r--  mlir/lib/Analysis/MemRefBoundCheck.cpp           |   53
-rw-r--r--  mlir/lib/Analysis/NestedMatcher.cpp              |  152
-rw-r--r--  mlir/lib/Analysis/OpStats.cpp                    |   84
-rw-r--r--  mlir/lib/Analysis/SliceAnalysis.cpp              |  213
-rw-r--r--  mlir/lib/Analysis/TestMemRefDependenceCheck.cpp  |  121
-rw-r--r--  mlir/lib/Analysis/TestParallelismDetection.cpp   |   48
-rw-r--r--  mlir/lib/Analysis/Utils.cpp                      | 1007
-rw-r--r--  mlir/lib/Analysis/VectorAnalysis.cpp             |  232
-rw-r--r--  mlir/lib/Analysis/Verifier.cpp                   |  266
17 files changed, 7155 insertions, 0 deletions
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
new file mode 100644
index 00000000000..3358bb437ff
--- /dev/null
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -0,0 +1,886 @@
+//===- AffineAnalysis.cpp - Affine structures analysis routines -----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous analysis routines for affine structures
+// (expressions, maps, sets), and other utilities relying on such analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "affine-analysis"
+
+using namespace mlir;
+
+using llvm::dbgs;
+
+/// Returns in 'affineApplyOps' the sequence of AffineApplyOp operations that
+/// are reachable via a search starting from 'operands' and ending at operands
+/// that are not defined by an AffineApplyOp.
+// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes
+// the AffineApplyOp into any user AffineApplyOps.
+void mlir::getReachableAffineApplyOps(
+ ArrayRef<Value> operands, SmallVectorImpl<Operation *> &affineApplyOps) {
+ struct State {
+ // The ssa value for this node in the DFS traversal.
+ Value value;
+ // The operand index of 'value' to explore next during DFS traversal.
+ unsigned operandIndex;
+ };
+ SmallVector<State, 4> worklist;
+ for (auto operand : operands) {
+ worklist.push_back({operand, 0});
+ }
+
+ while (!worklist.empty()) {
+ State &state = worklist.back();
+ auto *opInst = state.value->getDefiningOp();
+ // Note: getDefiningOp will return nullptr if the operand is not an
+ // Operation (i.e. block argument), which is a terminator for the search.
+ if (!isa_and_nonnull<AffineApplyOp>(opInst)) {
+ worklist.pop_back();
+ continue;
+ }
+
+ if (state.operandIndex == 0) {
+ // Pre-Visit: Add 'opInst' to reachable sequence.
+ affineApplyOps.push_back(opInst);
+ }
+ if (state.operandIndex < opInst->getNumOperands()) {
+ // Visit: Add next 'affineApplyOp' operand to worklist.
+ // Get next operand to visit at 'operandIndex'.
+ auto nextOperand = opInst->getOperand(state.operandIndex);
+ // Increment 'operandIndex' in 'state'.
+ ++state.operandIndex;
+ // Add 'nextOperand' to worklist.
+ worklist.push_back({nextOperand, 0});
+ } else {
+ // Post-visit: done visiting operands AffineApplyOp, pop off stack.
+ worklist.pop_back();
+ }
+ }
+}
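+
+// Illustrative example (hypothetical IR, for exposition only): given the chain
+//
+//   %a = affine.apply (d0) -> (d0 + 1) (%i)
+//   %b = affine.apply (d0, d1) -> (d0 * 2 + d1) (%a, %j)
+//
+// a call with operands {%b} pre-visits %b's apply and then %a's apply, and
+// stops at %i and %j (block arguments / loop IVs have no defining
+// AffineApplyOp), so 'affineApplyOps' ends up holding [%b's apply, %a's apply].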
+
+// Builds a system of constraints with dimensional identifiers corresponding to
+// the loop IVs of the forOps appearing in that order. Any symbols found in
+// the bound operands are added as symbols in the system. Returns failure for
+// the yet unimplemented cases.
+// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
+// stride information in FlatAffineConstraints (e.g., by using iv - lb %
+// step = 0 and/or by introducing a method in FlatAffineConstraints such as
+// setExprStride(ArrayRef<int64_t> expr, int64_t stride)).
+LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps,
+ FlatAffineConstraints *domain) {
+ SmallVector<Value, 4> indices;
+ extractForInductionVars(forOps, &indices);
+  // Reset the domain with dim identifiers corresponding to the Values in
+  // 'indices'.
+ domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
+ for (auto forOp : forOps) {
+ // Add constraints from forOp's bounds.
+ if (failed(domain->addAffineForOpDomain(forOp)))
+ return failure();
+ }
+ return success();
+}
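+
+// Illustrative example (hypothetical loop nest, for exposition only): for
+//
+//   affine.for %i = 0 to %N {
+//     affine.for %j = 0 to 10 {
+//       ...
+//     }
+//   }
+//
+// getIndexSet({%i-loop, %j-loop}, &domain) yields a system with dim ids
+// (%i, %j), symbol %N, and the inequalities
+//
+//   %i >= 0, %N - %i - 1 >= 0, %j >= 0, 10 - %j - 1 >= 0
+//
+// since affine.for upper bounds are exclusive.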
+
+// Computes the iteration domain for 'op' and populates 'indexSet', which
+// encapsulates the constraints involving loops surrounding 'op' and
+// potentially involving any Function symbols. The dimensional identifiers in
+// 'indexSet' correspond to the loops surrounding 'op' from outermost to
+// innermost.
+// TODO(andydavis) Add support to handle IfInsts surrounding 'op'.
+static LogicalResult getInstIndexSet(Operation *op,
+ FlatAffineConstraints *indexSet) {
+ // TODO(andydavis) Extend this to gather enclosing IfInsts and consider
+ // factoring it out into a utility function.
+ SmallVector<AffineForOp, 4> loops;
+ getLoopIVs(*op, &loops);
+ return getIndexSet(loops, indexSet);
+}
+
+// ValuePositionMap manages the mapping from Values which represent dimension
+// and symbol identifiers from 'src' and 'dst' access functions to positions
+// in new space where some Values are kept separate (using addSrc/DstValue)
+// and some Values are merged (addSymbolValue).
+// Position lookups return the absolute position in the new space which
+// has the following format:
+//
+// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers]
+//
+// Note: access function non-IV dimension identifiers (that have 'dimension'
+// positions in the access function position space) are assigned as symbols
+// in the output position space. Convenience functions which look up a Value
+// in multiple maps (e.g. getSrcDimOrSymPos) are provided to handle the
+// common case of resolving positions for all access function operands.
+//
+// TODO(andydavis) Generalize this: could take a template parameter for
+// the number of maps (3 in the current case), and lookups could take indices
+// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})".
+class ValuePositionMap {
+public:
+ void addSrcValue(Value value) {
+ if (addValueAt(value, &srcDimPosMap, numSrcDims))
+ ++numSrcDims;
+ }
+ void addDstValue(Value value) {
+ if (addValueAt(value, &dstDimPosMap, numDstDims))
+ ++numDstDims;
+ }
+ void addSymbolValue(Value value) {
+ if (addValueAt(value, &symbolPosMap, numSymbols))
+ ++numSymbols;
+ }
+ unsigned getSrcDimOrSymPos(Value value) const {
+ return getDimOrSymPos(value, srcDimPosMap, 0);
+ }
+ unsigned getDstDimOrSymPos(Value value) const {
+ return getDimOrSymPos(value, dstDimPosMap, numSrcDims);
+ }
+ unsigned getSymPos(Value value) const {
+ auto it = symbolPosMap.find(value);
+ assert(it != symbolPosMap.end());
+ return numSrcDims + numDstDims + it->second;
+ }
+
+ unsigned getNumSrcDims() const { return numSrcDims; }
+ unsigned getNumDstDims() const { return numDstDims; }
+ unsigned getNumDims() const { return numSrcDims + numDstDims; }
+ unsigned getNumSymbols() const { return numSymbols; }
+
+private:
+ bool addValueAt(Value value, DenseMap<Value, unsigned> *posMap,
+ unsigned position) {
+ auto it = posMap->find(value);
+ if (it == posMap->end()) {
+ (*posMap)[value] = position;
+ return true;
+ }
+ return false;
+ }
+ unsigned getDimOrSymPos(Value value,
+ const DenseMap<Value, unsigned> &dimPosMap,
+ unsigned dimPosOffset) const {
+ auto it = dimPosMap.find(value);
+ if (it != dimPosMap.end()) {
+ return dimPosOffset + it->second;
+ }
+ it = symbolPosMap.find(value);
+ assert(it != symbolPosMap.end());
+ return numSrcDims + numDstDims + it->second;
+ }
+
+ unsigned numSrcDims = 0;
+ unsigned numDstDims = 0;
+ unsigned numSymbols = 0;
+ DenseMap<Value, unsigned> srcDimPosMap;
+ DenseMap<Value, unsigned> dstDimPosMap;
+ DenseMap<Value, unsigned> symbolPosMap;
+};
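+
+// Example (hypothetical values, for illustration only): with src IVs
+// (%i0, %i1), dst IVs (%i2, %i3), and a symbol %M used by both accesses,
+// ValuePositionMap assigns the merged positions
+//
+//   %i0 -> 0, %i1 -> 1, %i2 -> 2, %i3 -> 3, %M -> 4
+//
+// i.e. [src-dim-identifiers][dst-dim-identifiers][symbol-identifiers], which
+// is the column order used by the dependence constraint system built below.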
+
+// Builds a map from Value to identifier position in a new merged identifier
+// list, which is the result of merging dim/symbol lists from src/dst
+// iteration domains, the format of which is as follows:
+//
+// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term]
+//
+// This method populates 'valuePosMap' with mappings from operand Values in
+// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain')
+// to the position of these values in the merged list.
+static void buildDimAndSymbolPositionMaps(
+ const FlatAffineConstraints &srcDomain,
+ const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap,
+ const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap,
+ FlatAffineConstraints *dependenceConstraints) {
+ auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc) {
+ for (unsigned i = 0, e = values.size(); i < e; ++i) {
+ auto value = values[i];
+ if (!isForInductionVar(values[i])) {
+ assert(isValidSymbol(values[i]) &&
+ "access operand has to be either a loop IV or a symbol");
+ valuePosMap->addSymbolValue(value);
+ } else if (isSrc) {
+ valuePosMap->addSrcValue(value);
+ } else {
+ valuePosMap->addDstValue(value);
+ }
+ }
+ };
+
+ SmallVector<Value, 4> srcValues, destValues;
+ srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues);
+ dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues);
+ // Update value position map with identifiers from src iteration domain.
+ updateValuePosMap(srcValues, /*isSrc=*/true);
+ // Update value position map with identifiers from dst iteration domain.
+ updateValuePosMap(destValues, /*isSrc=*/false);
+ // Update value position map with identifiers from src access function.
+ updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true);
+ // Update value position map with identifiers from dst access function.
+ updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false);
+}
+
+// Sets up dependence constraints columns appropriately, in the format:
+// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term]
+void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
+ const FlatAffineConstraints &dstDomain,
+ const AffineValueMap &srcAccessMap,
+ const AffineValueMap &dstAccessMap,
+ const ValuePositionMap &valuePosMap,
+ FlatAffineConstraints *dependenceConstraints) {
+ // Calculate number of equalities/inequalities and columns required to
+ // initialize FlatAffineConstraints for 'dependenceDomain'.
+ unsigned numIneq =
+ srcDomain.getNumInequalities() + dstDomain.getNumInequalities();
+ AffineMap srcMap = srcAccessMap.getAffineMap();
+ assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults());
+ unsigned numEq = srcMap.getNumResults();
+ unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds();
+ unsigned numSymbols = valuePosMap.getNumSymbols();
+ unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds();
+ unsigned numIds = numDims + numSymbols + numLocals;
+ unsigned numCols = numIds + 1;
+
+  // Set flat affine constraints sizes and reserve space for constraints.
+ dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols,
+ numLocals);
+
+ // Set values corresponding to dependence constraint identifiers.
+ SmallVector<Value, 4> srcLoopIVs, dstLoopIVs;
+ srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs);
+ dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs);
+
+ dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs);
+ dependenceConstraints->setIdValues(
+ srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs);
+
+ // Set values for the symbolic identifier dimensions.
+ auto setSymbolIds = [&](ArrayRef<Value> values) {
+ for (auto value : values) {
+ if (!isForInductionVar(value)) {
+ assert(isValidSymbol(value) && "expected symbol");
+ dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
+ }
+ }
+ };
+
+ setSymbolIds(srcAccessMap.getOperands());
+ setSymbolIds(dstAccessMap.getOperands());
+
+ SmallVector<Value, 8> srcSymbolValues, dstSymbolValues;
+ srcDomain.getIdValues(srcDomain.getNumDimIds(),
+ srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues);
+ dstDomain.getIdValues(dstDomain.getNumDimIds(),
+ dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues);
+ setSymbolIds(srcSymbolValues);
+ setSymbolIds(dstSymbolValues);
+
+ for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds();
+ i < e; i++)
+ assert(dependenceConstraints->getIds()[i].hasValue());
+}
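+
+// Illustrative sketch: continuing the hypothetical example above (two src IVs
+// %i0/%i1, two dst IVs %i2/%i3, one shared symbol %M, no locals), the
+// dependence system is reset with five identifier columns plus a constant:
+//
+//   [%i0 %i1 | %i2 %i3 | %M | const]
+//
+// and each identifier column is tagged with the Value it stands for, so later
+// steps can map constraints back to IVs and symbols.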
+
+// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into
+// 'dependenceDomain'.
+// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a
+// srcDomain/dstDomain Value maps.
+static void addDomainConstraints(const FlatAffineConstraints &srcDomain,
+ const FlatAffineConstraints &dstDomain,
+ const ValuePositionMap &valuePosMap,
+ FlatAffineConstraints *dependenceDomain) {
+ unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds();
+
+ SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols());
+
+ auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) {
+ const FlatAffineConstraints &domain = isSrc ? srcDomain : dstDomain;
+ unsigned numCsts =
+ isEq ? domain.getNumEqualities() : domain.getNumInequalities();
+ unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds();
+ auto at = [&](unsigned i, unsigned j) -> int64_t {
+ return isEq ? domain.atEq(i, j) : domain.atIneq(i, j);
+ };
+ auto map = [&](unsigned i) -> int64_t {
+ return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getIdValue(i))
+ : valuePosMap.getDstDimOrSymPos(domain.getIdValue(i));
+ };
+
+ for (unsigned i = 0; i < numCsts; ++i) {
+ // Zero fill.
+ std::fill(cst.begin(), cst.end(), 0);
+ // Set coefficients for identifiers corresponding to domain.
+ for (unsigned j = 0; j < numDimAndSymbolIds; ++j)
+ cst[map(j)] = at(i, j);
+ // Local terms.
+ for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++)
+ cst[depNumDimsAndSymbolIds + localOffset + j] =
+ at(i, numDimAndSymbolIds + j);
+ // Set constant term.
+ cst[cst.size() - 1] = at(i, domain.getNumCols() - 1);
+ // Add constraint.
+ if (isEq)
+ dependenceDomain->addEquality(cst);
+ else
+ dependenceDomain->addInequality(cst);
+ }
+ };
+
+ // Add equalities from src domain.
+ addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0);
+ // Add inequalities from src domain.
+ addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0);
+ // Add equalities from dst domain.
+ addDomain(/*isSrc=*/false, /*isEq=*/true,
+ /*localOffset=*/srcDomain.getNumLocalIds());
+ // Add inequalities from dst domain.
+ addDomain(/*isSrc=*/false, /*isEq=*/false,
+ /*localOffset=*/srcDomain.getNumLocalIds());
+}
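+
+// Illustrative sketch: a src-domain inequality such as 100 - %i0 >= 0, stored
+// over the src columns [%i0 %i1 const] as (-1, 0, 100), is re-emitted over
+// the merged columns [%i0 %i1 %i2 %i3 const] as
+//
+//   -1  0  0  0  100  >= 0
+//
+// i.e. each coefficient moves to the column 'valuePosMap' assigns to its
+// Value, and the dst columns stay zero.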
+
+// Adds equality constraints that equate src and dst access functions
+// represented by 'srcAccessMap' and 'dstAccessMap' for each result.
+// Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count.
+// For example, given the following two access functions to a 2D memref:
+//
+// Source access function:
+// (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2)
+//
+// Destination access function:
+// (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2)
+//
+// This method constructs the following equality constraints in
+// 'dependenceDomain', by equating the access functions for each result
+// (i.e. each memref dim). Notice that 'd0' for the destination access function
+// is mapped into 'd1' in the equality constraint:
+//
+//   d0      d1      s0         c
+//   --      --      --         --
+//   a0     -c0      (a1 - c1)  (a2 - c2) = 0
+//   b0     -f0      (b1 - f1)  (b2 - f2) = 0
+//
+// Returns failure if any AffineExpr cannot be flattened (due to it being
+// semi-affine). Returns success otherwise.
+static LogicalResult
+addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
+ const AffineValueMap &dstAccessMap,
+ const ValuePositionMap &valuePosMap,
+ FlatAffineConstraints *dependenceDomain) {
+ AffineMap srcMap = srcAccessMap.getAffineMap();
+ AffineMap dstMap = dstAccessMap.getAffineMap();
+ assert(srcMap.getNumResults() == dstMap.getNumResults());
+ unsigned numResults = srcMap.getNumResults();
+
+ unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols();
+ ArrayRef<Value> srcOperands = srcAccessMap.getOperands();
+
+ unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols();
+ ArrayRef<Value> dstOperands = dstAccessMap.getOperands();
+
+ std::vector<SmallVector<int64_t, 8>> srcFlatExprs;
+ std::vector<SmallVector<int64_t, 8>> destFlatExprs;
+ FlatAffineConstraints srcLocalVarCst, destLocalVarCst;
+  // Get flattened expressions for the source and destination maps.
+ if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) ||
+ failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst)))
+ return failure();
+
+ unsigned domNumLocalIds = dependenceDomain->getNumLocalIds();
+ unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds();
+ unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds();
+ unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds;
+ for (unsigned i = 0; i < numLocalIdsToAdd; i++) {
+ dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds());
+ }
+
+ unsigned numDims = dependenceDomain->getNumDimIds();
+ unsigned numSymbols = dependenceDomain->getNumSymbolIds();
+ unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds();
+ unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds;
+
+ // Equality to add.
+ SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols());
+ for (unsigned i = 0; i < numResults; ++i) {
+ // Zero fill.
+ std::fill(eq.begin(), eq.end(), 0);
+
+ // Flattened AffineExpr for src result 'i'.
+ const auto &srcFlatExpr = srcFlatExprs[i];
+ // Set identifier coefficients from src access function.
+ for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
+ eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j];
+ // Local terms.
+ for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
+ eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j];
+ // Set constant term.
+ eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1];
+
+ // Flattened AffineExpr for dest result 'i'.
+ const auto &destFlatExpr = destFlatExprs[i];
+ // Set identifier coefficients from dst access function.
+ for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
+ eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j];
+ // Local terms.
+ for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
+ eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j];
+ // Set constant term.
+ eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1];
+
+ // Add equality constraint.
+ dependenceDomain->addEquality(eq);
+ }
+
+ // Add equality constraints for any operands that are defined by constant ops.
+ auto addEqForConstOperands = [&](ArrayRef<Value> operands) {
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ if (isForInductionVar(operands[i]))
+ continue;
+ auto symbol = operands[i];
+ assert(isValidSymbol(symbol));
+ // Check if the symbol is a constant.
+ if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol->getDefiningOp()))
+ dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
+ cOp.getValue());
+ }
+ };
+
+ // Add equality constraints for any src symbols defined by constant ops.
+ addEqForConstOperands(srcOperands);
+ // Add equality constraints for any dst symbols defined by constant ops.
+ addEqForConstOperands(dstOperands);
+
+ // By construction (see flattener), local var constraints will not have any
+ // equalities.
+ assert(srcLocalVarCst.getNumEqualities() == 0 &&
+ destLocalVarCst.getNumEqualities() == 0);
+ // Add inequalities from srcLocalVarCst and destLocalVarCst into the
+ // dependence domain.
+ SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols());
+ for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) {
+ std::fill(ineq.begin(), ineq.end(), 0);
+
+ // Set identifier coefficients from src local var constraints.
+ for (unsigned j = 0, e = srcOperands.size(); j < e; ++j)
+ ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] =
+ srcLocalVarCst.atIneq(r, j);
+ // Local terms.
+ for (unsigned j = 0, e = srcNumLocalIds; j < e; j++)
+ ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j);
+ // Set constant term.
+ ineq[ineq.size() - 1] =
+ srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1);
+ dependenceDomain->addInequality(ineq);
+ }
+
+ for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) {
+ std::fill(ineq.begin(), ineq.end(), 0);
+ // Set identifier coefficients from dest local var constraints.
+ for (unsigned j = 0, e = dstOperands.size(); j < e; ++j)
+ ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] =
+ destLocalVarCst.atIneq(r, j);
+ // Local terms.
+ for (unsigned j = 0, e = dstNumLocalIds; j < e; j++)
+ ineq[newLocalIdOffset + numSrcLocalIds + j] =
+ destLocalVarCst.atIneq(r, dstNumIds + j);
+ // Set constant term.
+ ineq[ineq.size() - 1] =
+ destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1);
+
+ dependenceDomain->addInequality(ineq);
+ }
+ return success();
+}
+
+// Returns the number of outer loops common to 'src/dstDomain'.
+// Loops common to 'src/dst' domains are added to 'commonLoops' if non-null.
+static unsigned
+getNumCommonLoops(const FlatAffineConstraints &srcDomain,
+ const FlatAffineConstraints &dstDomain,
+ SmallVectorImpl<AffineForOp> *commonLoops = nullptr) {
+ // Find the number of common loops shared by src and dst accesses.
+ unsigned minNumLoops =
+ std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
+ unsigned numCommonLoops = 0;
+ for (unsigned i = 0; i < minNumLoops; ++i) {
+ if (!isForInductionVar(srcDomain.getIdValue(i)) ||
+ !isForInductionVar(dstDomain.getIdValue(i)) ||
+ srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
+ break;
+ if (commonLoops != nullptr)
+ commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i)));
+ ++numCommonLoops;
+ }
+ if (commonLoops != nullptr)
+ assert(commonLoops->size() == numCommonLoops);
+ return numCommonLoops;
+}
+
+// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
+static Block *getCommonBlock(const MemRefAccess &srcAccess,
+ const MemRefAccess &dstAccess,
+ const FlatAffineConstraints &srcDomain,
+ unsigned numCommonLoops) {
+ if (numCommonLoops == 0) {
+ auto *block = srcAccess.opInst->getBlock();
+ while (!llvm::isa<FuncOp>(block->getParentOp())) {
+ block = block->getParentOp()->getBlock();
+ }
+ return block;
+ }
+ auto commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
+ auto forOp = getForInductionVarOwner(commonForValue);
+ assert(forOp && "commonForValue was not an induction variable");
+ return forOp.getBody();
+}
+
+// Returns true if the ancestor operation of 'srcAccess' appears before the
+// ancestor operation of 'dstAccess' in the common ancestral block. Returns
+// false otherwise.
+// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
+// the function is named 'srcAppearsBeforeDstInAncestralBlock'. Note that
+// 'numCommonLoops' is the number of contiguous surrounding outer loops.
+static bool srcAppearsBeforeDstInAncestralBlock(
+ const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
+ const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) {
+ // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
+ auto *commonBlock =
+ getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
+ // Check the dominance relationship between the respective ancestors of the
+ // src and dst in the Block of the innermost among the common loops.
+ auto *srcInst = commonBlock->findAncestorOpInBlock(*srcAccess.opInst);
+ assert(srcInst != nullptr);
+ auto *dstInst = commonBlock->findAncestorOpInBlock(*dstAccess.opInst);
+ assert(dstInst != nullptr);
+
+ // Determine whether dstInst comes after srcInst.
+ return srcInst->isBeforeInBlock(dstInst);
+}
+
+// Adds ordering constraints to 'dependenceDomain' based on number of loops
+// common to 'src/dstDomain' and requested 'loopDepth'.
+// Note that 'loopDepth' cannot exceed the number of common loops plus one.
+// EX: Given a loop nest of depth 2 with IVs 'i' and 'j':
+// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1
+// *) If 'loopDepth == 2' then two constraints are added: i == i' and j' >= j + 1
+// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j'
+static void addOrderingConstraints(const FlatAffineConstraints &srcDomain,
+ const FlatAffineConstraints &dstDomain,
+ unsigned loopDepth,
+ FlatAffineConstraints *dependenceDomain) {
+ unsigned numCols = dependenceDomain->getNumCols();
+ SmallVector<int64_t, 4> eq(numCols);
+ unsigned numSrcDims = srcDomain.getNumDimIds();
+ unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
+ unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth);
+ for (unsigned i = 0; i < numCommonLoopConstraints; ++i) {
+ std::fill(eq.begin(), eq.end(), 0);
+ eq[i] = -1;
+ eq[i + numSrcDims] = 1;
+ if (i == loopDepth - 1) {
+ eq[numCols - 1] = -1;
+ dependenceDomain->addInequality(eq);
+ } else {
+ dependenceDomain->addEquality(eq);
+ }
+ }
+}
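+
+// Illustrative sketch: with one common loop, src IV %i, dst IV %i', and
+// columns [%i %i' const], 'loopDepth == 1' adds the single inequality
+//
+//   -1  1  -1  >= 0    (i.e. %i' >= %i + 1, the dst iteration comes later)
+//
+// whereas 'loopDepth == 2' (one past the common depth) instead adds the
+// equality %i' - %i == 0.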
+
+// Computes distance and direction vectors in 'dependenceComponents', by adding
+// variables to 'dependenceDomain' which represent the difference of the IVs,
+// eliminating all other variables, and reading off distance vectors from
+// equality constraints (if possible), and direction vectors from inequalities.
+static void computeDirectionVector(
+ const FlatAffineConstraints &srcDomain,
+ const FlatAffineConstraints &dstDomain, unsigned loopDepth,
+ FlatAffineConstraints *dependenceDomain,
+ SmallVector<DependenceComponent, 2> *dependenceComponents) {
+ // Find the number of common loops shared by src and dst accesses.
+ SmallVector<AffineForOp, 4> commonLoops;
+ unsigned numCommonLoops =
+ getNumCommonLoops(srcDomain, dstDomain, &commonLoops);
+ if (numCommonLoops == 0)
+ return;
+ // Compute direction vectors for requested loop depth.
+ unsigned numIdsToEliminate = dependenceDomain->getNumIds();
+ // Add new variables to 'dependenceDomain' to represent the direction
+ // constraints for each shared loop.
+ for (unsigned j = 0; j < numCommonLoops; ++j) {
+ dependenceDomain->addDimId(j);
+ }
+
+ // Add equality constraints for each common loop, setting newly introduced
+// variable at column 'j' to the 'dst' IV minus the 'src' IV.
+ SmallVector<int64_t, 4> eq;
+ eq.resize(dependenceDomain->getNumCols());
+ unsigned numSrcDims = srcDomain.getNumDimIds();
+ // Constraint variables format:
+ // [num-common-loops][num-src-dim-ids][num-dst-dim-ids][num-symbols][constant]
+ for (unsigned j = 0; j < numCommonLoops; ++j) {
+ std::fill(eq.begin(), eq.end(), 0);
+ eq[j] = 1;
+ eq[j + numCommonLoops] = 1;
+ eq[j + numCommonLoops + numSrcDims] = -1;
+ dependenceDomain->addEquality(eq);
+ }
+
+ // Eliminate all variables other than the direction variables just added.
+ dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate);
+
+ // Scan each common loop variable column and set direction vectors based
+ // on eliminated constraint system.
+ dependenceComponents->resize(numCommonLoops);
+ for (unsigned j = 0; j < numCommonLoops; ++j) {
+ (*dependenceComponents)[j].op = commonLoops[j].getOperation();
+ auto lbConst = dependenceDomain->getConstantLowerBound(j);
+ (*dependenceComponents)[j].lb =
+ lbConst.getValueOr(std::numeric_limits<int64_t>::min());
+ auto ubConst = dependenceDomain->getConstantUpperBound(j);
+ (*dependenceComponents)[j].ub =
+ ubConst.getValueOr(std::numeric_limits<int64_t>::max());
+ }
+}
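+
+// Illustrative sketch: if, after projection, the direction variable for a
+// common loop is constrained to exactly 2 (i.e. dst IV - src IV == 2), the
+// corresponding DependenceComponent gets lb == ub == 2, a constant dependence
+// distance; a direction left unbounded on either side falls back to
+// INT64_MIN / INT64_MAX for the missing bound.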
+
+// Populates 'accessMap' with composition of AffineApplyOps reachable from
+// indices of MemRefAccess.
+void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
+ // Get affine map from AffineLoad/Store.
+ AffineMap map;
+ if (auto loadOp = dyn_cast<AffineLoadOp>(opInst))
+ map = loadOp.getAffineMap();
+ else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst))
+ map = storeOp.getAffineMap();
+ SmallVector<Value, 8> operands(indices.begin(), indices.end());
+ fullyComposeAffineMapAndOperands(&map, &operands);
+ map = simplifyAffineMap(map);
+ canonicalizeMapAndOperands(&map, &operands);
+ accessMap->reset(map, operands);
+}
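+
+// Illustrative example (hypothetical IR, for exposition only): for
+//
+//   %a = affine.apply (d0) -> (d0 * 2) (%i)
+//   affine.load %m[%a] : memref<100xf32>
+//
+// getAccessMap composes the reachable affine.apply into the access function,
+// yielding the map (d0) -> (d0 * 2) with operand %i.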
+
+// Builds a flat affine constraint system to check if there exists a dependence
+// between memref accesses 'srcAccess' and 'dstAccess'.
+// Returns 'NoDependence' if the accesses can be definitively shown not to
+// access the same element.
+// Returns 'HasDependence' if the accesses do access the same element.
+// Returns 'Failure' if an error or unsupported case was encountered.
+// If a dependence exists, returns in 'dependenceComponents' a direction
+// vector for the dependence, with a component for each loop IV in loops
+// common to both accesses (see Dependence in AffineAnalysis.h for details).
+//
+// The memref access dependence check consists of the following steps:
+// *) Compute access functions for each access. Access functions are computed
+// using AffineValueMaps initialized with the indices from an access, then
+// composed with AffineApplyOps reachable from operands of that access,
+// until operands of the AffineValueMap are loop IVs or symbols.
+// *) Build iteration domain constraints for each access. Iteration domain
+// constraints are pairs of inequality constraints representing the
+// upper/lower loop bounds for each AffineForOp in the loop nest associated
+// with each access.
+// *) Build dimension and symbol position maps for each access, which map
+// Values from access functions and iteration domains to their position
+// in the merged constraint system built by this method.
+//
+// This method builds a constraint system with the following column format:
+//
+// [src-dim-identifiers, dst-dim-identifiers, symbols, constant]
+//
+// For example, given the following MLIR code with "source" and "destination"
+// accesses to the same memref label, and symbols %M, %N, %K:
+//
+// affine.for %i0 = 0 to 100 {
+// affine.for %i1 = 0 to 50 {
+// %a0 = affine.apply
+// (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N]
+// // Source memref access.
+// store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32>
+// }
+// }
+//
+// affine.for %i2 = 0 to 100 {
+// affine.for %i3 = 0 to 50 {
+// %a1 = affine.apply
+//       (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 - s0) (%i2, %i3)[%K, %M]
+// // Destination memref access.
+// %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32>
+// }
+// }
+//
+// The access functions would be the following:
+//
+// src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M)
+// dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K)
+//
+// The iteration domains for the src/dst accesses would be the following:
+//
+// src: 0 <= %i0 <= 100, 0 <= %i1 <= 50
+// dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50
+//
+// The symbols referenced by both accesses are assigned a canonical position
+// order which will be used in the dependence constraint system:
+//
+// symbol name: %M %N %K
+// symbol pos: 0 1 2
+//
+// Equality constraints are built by equating each result of src/destination
+// access functions. For this example, the following two equality constraints
+// will be added to the dependence constraint system:
+//
+// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
+// 2 -4 -7 -9 1 1 0 0 = 0
+// 0 3 0 -11 -1 0 1 0 = 0
+//
+// Inequality constraints from the iteration domains will be merged into
+// the dependence constraint system:
+//
+// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const]
+// 1 0 0 0 0 0 0 0 >= 0
+// -1 0 0 0 0 0 0 100 >= 0
+// 0 1 0 0 0 0 0 0 >= 0
+// 0 -1 0 0 0 0 0 50 >= 0
+// 0 0 1 0 0 0 0 0 >= 0
+// 0 0 -1 0 0 0 0 100 >= 0
+// 0 0 0 1 0 0 0 0 >= 0
+// 0 0 0 -1 0 0 0 50 >= 0
+//
+//
+// TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv.
+DependenceResult mlir::checkMemrefAccessDependence(
+ const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
+ unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
+ SmallVector<DependenceComponent, 2> *dependenceComponents, bool allowRAR) {
+ LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: "
+ << Twine(loopDepth) << " between:\n";);
+ LLVM_DEBUG(srcAccess.opInst->dump(););
+ LLVM_DEBUG(dstAccess.opInst->dump(););
+
+ // Return 'NoDependence' if these accesses do not access the same memref.
+ if (srcAccess.memref != dstAccess.memref)
+ return DependenceResult::NoDependence;
+
+  // Return 'NoDependence' if neither of these accesses is an AffineStoreOp
+  // (i.e., a read-after-read pair), unless RAR dependences are requested.
+ if (!allowRAR && !isa<AffineStoreOp>(srcAccess.opInst) &&
+ !isa<AffineStoreOp>(dstAccess.opInst))
+ return DependenceResult::NoDependence;
+
+ // Get composed access function for 'srcAccess'.
+ AffineValueMap srcAccessMap;
+ srcAccess.getAccessMap(&srcAccessMap);
+
+ // Get composed access function for 'dstAccess'.
+ AffineValueMap dstAccessMap;
+ dstAccess.getAccessMap(&dstAccessMap);
+
+ // Get iteration domain for the 'srcAccess' operation.
+ FlatAffineConstraints srcDomain;
+ if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain)))
+ return DependenceResult::Failure;
+
+ // Get iteration domain for 'dstAccess' operation.
+ FlatAffineConstraints dstDomain;
+ if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain)))
+ return DependenceResult::Failure;
+
+ // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor
+ // operation of 'srcAccess' does not properly dominate the ancestor
+ // operation of 'dstAccess' in the same common operation block.
+  // Note: this check is skipped if 'allowRAR' is true, because RAR deps can
+  // exist irrespective of lexicographic ordering between src and dst.
+ unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
+ assert(loopDepth <= numCommonLoops + 1);
+ if (!allowRAR && loopDepth > numCommonLoops &&
+ !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
+ numCommonLoops)) {
+ return DependenceResult::NoDependence;
+ }
+ // Build dim and symbol position maps for each access from access operand
+ // Value to position in merged constraint system.
+ ValuePositionMap valuePosMap;
+ buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap,
+ dstAccessMap, &valuePosMap,
+ dependenceConstraints);
+
+ initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap,
+ valuePosMap, dependenceConstraints);
+
+ assert(valuePosMap.getNumDims() ==
+ srcDomain.getNumDimIds() + dstDomain.getNumDimIds());
+
+ // Create memref access constraint by equating src/dst access functions.
+ // Note that this check is conservative, and will fail in the future when
+ // local variables for mod/div exprs are supported.
+ if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap,
+ dependenceConstraints)))
+ return DependenceResult::Failure;
+
+ // Add 'src' happens before 'dst' ordering constraints.
+ addOrderingConstraints(srcDomain, dstDomain, loopDepth,
+ dependenceConstraints);
+ // Add src and dst domain constraints.
+ addDomainConstraints(srcDomain, dstDomain, valuePosMap,
+ dependenceConstraints);
+
+ // Return 'NoDependence' if the solution space is empty: no dependence.
+ if (dependenceConstraints->isEmpty()) {
+ return DependenceResult::NoDependence;
+ }
+
+  // Compute the dependence direction vector if the caller requested it.
+ if (dependenceComponents != nullptr) {
+ computeDirectionVector(srcDomain, dstDomain, loopDepth,
+ dependenceConstraints, dependenceComponents);
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n");
+ LLVM_DEBUG(dependenceConstraints->dump());
+ return DependenceResult::HasDependence;
+}
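+
+// Minimal usage sketch (assuming 'srcOp' and 'dstOp' are affine load/store
+// operations gathered by the caller; shown for exposition only):
+//
+//   MemRefAccess srcAccess(srcOp), dstAccess(dstOp);
+//   FlatAffineConstraints depConstraints;
+//   SmallVector<DependenceComponent, 2> depComps;
+//   DependenceResult result = checkMemrefAccessDependence(
+//       srcAccess, dstAccess, /*loopDepth=*/1, &depConstraints, &depComps);
+//   if (hasDependence(result)) {
+//     // Inspect depComps[d].lb/ub for each common loop depth d.
+//   }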
+
+/// Gathers dependence components for dependences between all ops in loop nest
+/// rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
+void mlir::getDependenceComponents(
+ AffineForOp forOp, unsigned maxLoopDepth,
+ std::vector<SmallVector<DependenceComponent, 2>> *depCompsVec) {
+ // Collect all load and store ops in loop nest rooted at 'forOp'.
+ SmallVector<Operation *, 8> loadAndStoreOpInsts;
+ forOp.getOperation()->walk([&](Operation *opInst) {
+ if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+ loadAndStoreOpInsts.push_back(opInst);
+ });
+
+ unsigned numOps = loadAndStoreOpInsts.size();
+ for (unsigned d = 1; d <= maxLoopDepth; ++d) {
+ for (unsigned i = 0; i < numOps; ++i) {
+ auto *srcOpInst = loadAndStoreOpInsts[i];
+ MemRefAccess srcAccess(srcOpInst);
+ for (unsigned j = 0; j < numOps; ++j) {
+ auto *dstOpInst = loadAndStoreOpInsts[j];
+ MemRefAccess dstAccess(dstOpInst);
+
+ FlatAffineConstraints dependenceConstraints;
+ SmallVector<DependenceComponent, 2> depComps;
+ // TODO(andydavis,bondhugula) Explore whether it would be profitable
+ // to pre-compute and store deps instead of repeatedly checking.
+ DependenceResult result = checkMemrefAccessDependence(
+ srcAccess, dstAccess, d, &dependenceConstraints, &depComps);
+ if (hasDependence(result))
+ depCompsVec->push_back(depComps);
+ }
+ }
+ }
+}
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
new file mode 100644
index 00000000000..78a869884ee
--- /dev/null
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -0,0 +1,2854 @@
+//===- AffineStructures.cpp - MLIR Affine Structures Class ----------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Structures for affine/polyhedral analysis of MLIR functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "affine-structures"
+
+using namespace mlir;
+using llvm::SmallDenseMap;
+using llvm::SmallDenseSet;
+
+namespace {
+
+// See comments for SimpleAffineExprFlattener.
+// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
+// constraint information associated with mod's, floordiv's, and ceildiv's
+// in FlatAffineConstraints 'localVarCst'.
+struct AffineExprFlattener : public SimpleAffineExprFlattener {
+public:
+ // Constraints connecting newly introduced local variables (for mod's and
+ // div's) to existing (dimensional and symbolic) ones. These are always
+ // inequalities.
+ FlatAffineConstraints localVarCst;
+
+ AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx)
+ : SimpleAffineExprFlattener(nDims, nSymbols) {
+ localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
+ }
+
+private:
+ // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
+ // The local identifier added is always a floordiv of a pure add/mul affine
+ // function of other identifiers, coefficients of which are specified in
+ // `dividend' and with respect to the positive constant `divisor'. localExpr
+ // is the simplified tree expression (AffineExpr) corresponding to the
+ // quantifier.
+ void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
+ AffineExpr localExpr) override {
+ SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
+ // Update localVarCst.
+ localVarCst.addLocalFloorDiv(dividend, divisor);
+ }
+};
+
+} // end anonymous namespace
+
+// Flattens the expressions in 'exprs'. Returns failure if any expression could
+// not be flattened (i.e., semi-affine expressions not handled yet).
+static LogicalResult
+getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
+ unsigned numSymbols,
+ std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
+ FlatAffineConstraints *localVarCst) {
+ if (exprs.empty()) {
+ localVarCst->reset(numDims, numSymbols);
+ return success();
+ }
+
+ AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext());
+ // Use the same flattener to simplify each expression successively. This way
+ // local identifiers / expressions are shared.
+ for (auto expr : exprs) {
+ if (!expr.isPureAffine())
+ return failure();
+
+ flattener.walkPostOrder(expr);
+ }
+
+ assert(flattener.operandExprStack.size() == exprs.size());
+ flattenedExprs->clear();
+ flattenedExprs->assign(flattener.operandExprStack.begin(),
+ flattener.operandExprStack.end());
+
+ if (localVarCst) {
+ localVarCst->clearAndCopyFrom(flattener.localVarCst);
+ }
+
+ return success();
+}
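+
+// Illustrative sketch: flattening d0 + d1 floordiv 2 (two dims, no symbols)
+// introduces a local q = d1 floordiv 2, so the flattened coefficients over
+// [d0, d1, q, const] are (1, 0, 1, 0), and 'localVarCst' records the
+// inequalities 0 <= d1 - 2 * q <= 1 tying q to d1.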
+
+// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
+// be flattened (semi-affine expressions not handled yet).
+LogicalResult
+mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
+ unsigned numSymbols,
+ SmallVectorImpl<int64_t> *flattenedExpr,
+ FlatAffineConstraints *localVarCst) {
+ std::vector<SmallVector<int64_t, 8>> flattenedExprs;
+ LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
+ &flattenedExprs, localVarCst);
+  // Guard against indexing an empty result if flattening failed.
+  if (failed(ret))
+    return failure();
+  *flattenedExpr = flattenedExprs[0];
+  return success();
+}
+
+/// Flattens the expressions in map. Returns failure if 'expr' was unable to be
+/// flattened (i.e., semi-affine expressions not handled yet).
+LogicalResult mlir::getFlattenedAffineExprs(
+ AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
+ FlatAffineConstraints *localVarCst) {
+ if (map.getNumResults() == 0) {
+ localVarCst->reset(map.getNumDims(), map.getNumSymbols());
+ return success();
+ }
+ return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
+ map.getNumSymbols(), flattenedExprs,
+ localVarCst);
+}
+
+LogicalResult mlir::getFlattenedAffineExprs(
+ IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
+ FlatAffineConstraints *localVarCst) {
+ if (set.getNumConstraints() == 0) {
+ localVarCst->reset(set.getNumDims(), set.getNumSymbols());
+ return success();
+ }
+ return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
+ set.getNumSymbols(), flattenedExprs,
+ localVarCst);
+}
+
+//===----------------------------------------------------------------------===//
+// MutableAffineMap.
+//===----------------------------------------------------------------------===//
+
+MutableAffineMap::MutableAffineMap(AffineMap map)
+ : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
+ // A map always has at least 1 result by construction
+ context(map.getResult(0).getContext()) {
+ for (auto result : map.getResults())
+ results.push_back(result);
+}
+
+void MutableAffineMap::reset(AffineMap map) {
+ results.clear();
+ numDims = map.getNumDims();
+ numSymbols = map.getNumSymbols();
+ // A map always has at least 1 result by construction
+ context = map.getResult(0).getContext();
+ for (auto result : map.getResults())
+ results.push_back(result);
+}
+
+bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
+ if (results[idx].isMultipleOf(factor))
+ return true;
+
+ // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
+ // complete this (for a more powerful analysis).
+ return false;
+}
+
+// Simplifies the result affine expressions of this map. The expressions
+// have to be pure for the simplification implemented.
+void MutableAffineMap::simplify() {
+ // Simplify each of the results if possible.
+ // TODO(ntv): functional-style map
+ for (unsigned i = 0, e = getNumResults(); i < e; i++) {
+ results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
+ }
+}
+
+AffineMap MutableAffineMap::getAffineMap() const {
+ return AffineMap::get(numDims, numSymbols, results);
+}
+
+MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context)
+ : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()) {
+ // TODO(bondhugula)
+}
+
+// Universal set.
+MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+ MLIRContext *context)
+ : numDims(numDims), numSymbols(numSymbols) {}
+
+//===----------------------------------------------------------------------===//
+// AffineValueMap.
+//===----------------------------------------------------------------------===//
+
+AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value> operands,
+ ArrayRef<Value> results)
+ : map(map), operands(operands.begin(), operands.end()),
+ results(results.begin(), results.end()) {}
+
+AffineValueMap::AffineValueMap(AffineApplyOp applyOp)
+ : map(applyOp.getAffineMap()),
+ operands(applyOp.operand_begin(), applyOp.operand_end()) {
+ results.push_back(applyOp.getResult());
+}
+
+AffineValueMap::AffineValueMap(AffineBound bound)
+ : map(bound.getMap()),
+ operands(bound.operand_begin(), bound.operand_end()) {}
+
+void AffineValueMap::reset(AffineMap map, ArrayRef<Value> operands,
+ ArrayRef<Value> results) {
+ this->map.reset(map);
+ this->operands.assign(operands.begin(), operands.end());
+ this->results.assign(results.begin(), results.end());
+}
+
+void AffineValueMap::difference(const AffineValueMap &a,
+ const AffineValueMap &b, AffineValueMap *res) {
+ assert(a.getNumResults() == b.getNumResults() && "invalid inputs");
+
+ // Fully compose A's map + operands.
+ auto aMap = a.getAffineMap();
+ SmallVector<Value, 4> aOperands(a.getOperands().begin(),
+ a.getOperands().end());
+ fullyComposeAffineMapAndOperands(&aMap, &aOperands);
+
+ // Use the affine apply normalizer to get B's map into A's coordinate space.
+ AffineApplyNormalizer normalizer(aMap, aOperands);
+ SmallVector<Value, 4> bOperands(b.getOperands().begin(),
+ b.getOperands().end());
+ auto bMap = b.getAffineMap();
+ normalizer.normalize(&bMap, &bOperands);
+
+ assert(std::equal(bOperands.begin(), bOperands.end(),
+ normalizer.getOperands().begin()) &&
+ "operands are expected to be the same after normalization");
+
+ // Construct the difference expressions.
+ SmallVector<AffineExpr, 4> diffExprs;
+ diffExprs.reserve(a.getNumResults());
+ for (unsigned i = 0, e = bMap.getNumResults(); i < e; ++i)
+ diffExprs.push_back(normalizer.getAffineMap().getResult(i) -
+ bMap.getResult(i));
+
+ auto diffMap = AffineMap::get(normalizer.getNumDims(),
+ normalizer.getNumSymbols(), diffExprs);
+ canonicalizeMapAndOperands(&diffMap, &bOperands);
+ diffMap = simplifyAffineMap(diffMap);
+ res->reset(diffMap, bOperands);
+}
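+
+// Illustrative sketch: with a = ((d0) -> (d0 + 4))(%i) and
+// b = ((d0) -> (d0))(%i), difference(a, b, &res) aligns both maps on the
+// operand %i and produces the expression (d0 + 4) - d0, which simplifies to
+// the constant 4 (canonicalization may then drop the unused operand).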
+
+// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in
+// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise.
+static bool findIndex(Value valueToMatch, ArrayRef<Value> valuesToSearch,
+ unsigned indexStart, unsigned *indexOfMatch) {
+ unsigned size = valuesToSearch.size();
+ for (unsigned i = indexStart; i < size; ++i) {
+ if (valueToMatch == valuesToSearch[i]) {
+ *indexOfMatch = i;
+ return true;
+ }
+ }
+ return false;
+}
+
+inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
+ return map.isMultipleOf(idx, factor);
+}
+
+/// This method uses the invariant that operands are always positionally aligned
+/// with the AffineDimExpr in the underlying AffineMap.
+bool AffineValueMap::isFunctionOf(unsigned idx, Value value) const {
+ unsigned index;
+ if (!findIndex(value, operands, /*indexStart=*/0, &index)) {
+ return false;
+ }
+ auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx);
+ // TODO(ntv): this is better implemented on a flattened representation.
+ // At least for now it is conservative.
+ return expr.isFunctionOfDim(index);
+}
+
+Value AffineValueMap::getOperand(unsigned i) const {
+ return static_cast<Value>(operands[i]);
+}
+
+ArrayRef<Value> AffineValueMap::getOperands() const {
+ return ArrayRef<Value>(operands);
+}
+
+AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); }
+
+AffineValueMap::~AffineValueMap() {}
+
+//===----------------------------------------------------------------------===//
+// FlatAffineConstraints.
+//===----------------------------------------------------------------------===//
+
+// Copy constructor.
+FlatAffineConstraints::FlatAffineConstraints(
+ const FlatAffineConstraints &other) {
+ numReservedCols = other.numReservedCols;
+ numDims = other.getNumDimIds();
+ numSymbols = other.getNumSymbolIds();
+ numIds = other.getNumIds();
+
+ auto otherIds = other.getIds();
+ ids.reserve(numReservedCols);
+ ids.append(otherIds.begin(), otherIds.end());
+
+ unsigned numReservedEqualities = other.getNumReservedEqualities();
+ unsigned numReservedInequalities = other.getNumReservedInequalities();
+
+ equalities.reserve(numReservedEqualities * numReservedCols);
+ inequalities.reserve(numReservedInequalities * numReservedCols);
+
+ for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
+ addInequality(other.getInequality(r));
+ }
+ for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
+ addEquality(other.getEquality(r));
+ }
+}
+
+// Clones this object.
+std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
+ return std::make_unique<FlatAffineConstraints>(*this);
+}
+
+// Construct from an IntegerSet.
+FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
+ : numReservedCols(set.getNumInputs() + 1),
+ numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
+ numSymbols(set.getNumSymbols()) {
+ equalities.reserve(set.getNumEqualities() * numReservedCols);
+ inequalities.reserve(set.getNumInequalities() * numReservedCols);
+ ids.resize(numIds, None);
+
+ // Flatten expressions and add them to the constraint system.
+ std::vector<SmallVector<int64_t, 8>> flatExprs;
+ FlatAffineConstraints localVarCst;
+ if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
+ assert(false && "flattening unimplemented for semi-affine integer sets");
+ return;
+ }
+ assert(flatExprs.size() == set.getNumConstraints());
+ for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) {
+ addLocalId(getNumLocalIds());
+ }
+
+ for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
+ const auto &flatExpr = flatExprs[i];
+ assert(flatExpr.size() == getNumCols());
+ if (set.getEqFlags()[i]) {
+ addEquality(flatExpr);
+ } else {
+ addInequality(flatExpr);
+ }
+ }
+ // Add the other constraints involving local id's from flattening.
+ append(localVarCst);
+}
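+
+// Illustrative sketch: the integer set (d0)[s0] : (d0 >= 0, s0 - d0 - 1 >= 0)
+// constructs a system with one dimension, one symbol, and no locals, holding
+// the two inequality rows (1, 0, 0) and (-1, 1, -1) over the columns
+// [d0, s0, const].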
+
+void FlatAffineConstraints::reset(unsigned numReservedInequalities,
+ unsigned numReservedEqualities,
+ unsigned newNumReservedCols,
+ unsigned newNumDims, unsigned newNumSymbols,
+ unsigned newNumLocals,
+ ArrayRef<Value> idArgs) {
+ assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
+ "minimum 1 column");
+ numReservedCols = newNumReservedCols;
+ numDims = newNumDims;
+ numSymbols = newNumSymbols;
+ numIds = numDims + numSymbols + newNumLocals;
+ assert(idArgs.empty() || idArgs.size() == numIds);
+
+ clearConstraints();
+ if (numReservedEqualities >= 1)
+ equalities.reserve(newNumReservedCols * numReservedEqualities);
+ if (numReservedInequalities >= 1)
+ inequalities.reserve(newNumReservedCols * numReservedInequalities);
+ if (idArgs.empty()) {
+ ids.resize(numIds, None);
+ } else {
+ ids.assign(idArgs.begin(), idArgs.end());
+ }
+}
+
+void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
+ unsigned newNumLocals,
+ ArrayRef<Value> idArgs) {
+ reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
+ newNumSymbols, newNumLocals, idArgs);
+}
+
+void FlatAffineConstraints::append(const FlatAffineConstraints &other) {
+ assert(other.getNumCols() == getNumCols());
+ assert(other.getNumDimIds() == getNumDimIds());
+ assert(other.getNumSymbolIds() == getNumSymbolIds());
+
+ inequalities.reserve(inequalities.size() +
+ other.getNumInequalities() * numReservedCols);
+ equalities.reserve(equalities.size() +
+ other.getNumEqualities() * numReservedCols);
+
+ for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
+ addInequality(other.getInequality(r));
+ }
+ for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
+ addEquality(other.getEquality(r));
+ }
+}
+
+void FlatAffineConstraints::addLocalId(unsigned pos) {
+ addId(IdKind::Local, pos);
+}
+
+void FlatAffineConstraints::addDimId(unsigned pos, Value id) {
+ addId(IdKind::Dimension, pos, id);
+}
+
+void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) {
+ addId(IdKind::Symbol, pos, id);
+}
+
+/// Adds an identifier of the specified kind at position 'pos'. The added
+/// column is initialized to zero.
+void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) {
+ if (kind == IdKind::Dimension) {
+ assert(pos <= getNumDimIds());
+ } else if (kind == IdKind::Symbol) {
+ assert(pos <= getNumSymbolIds());
+ } else {
+ assert(pos <= getNumLocalIds());
+ }
+
+ unsigned oldNumReservedCols = numReservedCols;
+
+ // Check if a resize is necessary.
+ if (getNumCols() + 1 > numReservedCols) {
+ equalities.resize(getNumEqualities() * (getNumCols() + 1));
+ inequalities.resize(getNumInequalities() * (getNumCols() + 1));
+ numReservedCols++;
+ }
+
+ int absolutePos;
+
+ if (kind == IdKind::Dimension) {
+ absolutePos = pos;
+ numDims++;
+ } else if (kind == IdKind::Symbol) {
+ absolutePos = pos + getNumDimIds();
+ numSymbols++;
+ } else {
+ absolutePos = pos + getNumDimIds() + getNumSymbolIds();
+ }
+ numIds++;
+
+ // Note that getNumCols() now will already return the new size, which will be
+ // at least one.
+ int numInequalities = static_cast<int>(getNumInequalities());
+ int numEqualities = static_cast<int>(getNumEqualities());
+ int numCols = static_cast<int>(getNumCols());
+ for (int r = numInequalities - 1; r >= 0; r--) {
+ for (int c = numCols - 2; c >= 0; c--) {
+ if (c < absolutePos)
+ atIneq(r, c) = inequalities[r * oldNumReservedCols + c];
+ else
+ atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c];
+ }
+ atIneq(r, absolutePos) = 0;
+ }
+
+ for (int r = numEqualities - 1; r >= 0; r--) {
+ for (int c = numCols - 2; c >= 0; c--) {
+      // Columns at positions < absolutePos keep their original position in
+      // the resized coefficient buffer.
+ if (c < absolutePos)
+ atEq(r, c) = equalities[r * oldNumReservedCols + c];
+ else
+        // Columns at positions >= absolutePos are shifted right by one.
+ atEq(r, c + 1) = equalities[r * oldNumReservedCols + c];
+ }
+ // Initialize added dimension to zero.
+ atEq(r, absolutePos) = 0;
+ }
+
+ // If an 'id' is provided, insert it; otherwise use None.
+ if (id) {
+ ids.insert(ids.begin() + absolutePos, id);
+ } else {
+ ids.insert(ids.begin() + absolutePos, None);
+ }
+ assert(ids.size() == getNumIds());
+}
+
+/// Checks if two constraint systems are in the same space, i.e., if they are
+/// associated with the same set of identifiers, appearing in the same order.
+static bool areIdsAligned(const FlatAffineConstraints &A,
+ const FlatAffineConstraints &B) {
+ return A.getNumDimIds() == B.getNumDimIds() &&
+ A.getNumSymbolIds() == B.getNumSymbolIds() &&
+ A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
+}
+
+/// Calls areIdsAligned to check if two constraint systems have the same set
+/// of identifiers in the same order.
+bool FlatAffineConstraints::areIdsAlignedWithOther(
+ const FlatAffineConstraints &other) {
+ return areIdsAligned(*this, other);
+}
+
+/// Checks if the SSA values associated with `cst`'s identifiers are unique.
+static bool LLVM_ATTRIBUTE_UNUSED
+areIdsUnique(const FlatAffineConstraints &cst) {
+ SmallPtrSet<Value, 8> uniqueIds;
+ for (auto id : cst.getIds()) {
+ if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
+ return false;
+ }
+ return true;
+}
+
+// Swap the posA^th identifier with the posB^th identifier.
+static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) {
+ assert(posA < A->getNumIds() && "invalid position A");
+ assert(posB < A->getNumIds() && "invalid position B");
+
+ if (posA == posB)
+ return;
+
+ for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) {
+ std::swap(A->atIneq(r, posA), A->atIneq(r, posB));
+ }
+ for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) {
+ std::swap(A->atEq(r, posA), A->atEq(r, posB));
+ }
+ std::swap(A->getId(posA), A->getId(posB));
+}
+
+/// Merge and align the identifiers of A and B starting at 'offset', so that
+/// both constraint systems get the union of the contained identifiers that is
+/// dimension-wise and symbol-wise unique; both constraint systems are updated
+/// so that they have the union of all identifiers, with A's original
+/// identifiers appearing first followed by any of B's identifiers that didn't
+/// appear in A. Local identifiers of each system are by design separate/local
+/// and are placed one after other (A's followed by B's).
+// E.g.: Input: A has (%i, %j) [%M, %N] and B has (%k, %j) [%P, %N, %M].
+//       Output: both A and B have (%i, %j, %k) [%M, %N, %P].
+//
+static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
+ FlatAffineConstraints *B) {
+ assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds());
+ // A merge/align isn't meaningful if a cst's ids aren't distinct.
+ assert(areIdsUnique(*A) && "A's id values aren't unique");
+ assert(areIdsUnique(*B) && "B's id values aren't unique");
+
+ assert(std::all_of(A->getIds().begin() + offset,
+ A->getIds().begin() + A->getNumDimAndSymbolIds(),
+ [](Optional<Value> id) { return id.hasValue(); }));
+
+ assert(std::all_of(B->getIds().begin() + offset,
+ B->getIds().begin() + B->getNumDimAndSymbolIds(),
+ [](Optional<Value> id) { return id.hasValue(); }));
+
+  // Merge the local ids: placeholders for A's locals are inserted before B's
+  // existing locals in B, and placeholders for B's locals are appended after
+  // A's locals in A, so both systems end up with A's locals followed by B's.
+ for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
+ B->addLocalId(0);
+ }
+ for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
+ t++) {
+ A->addLocalId(A->getNumLocalIds());
+ }
+
+ SmallVector<Value, 4> aDimValues, aSymValues;
+ A->getIdValues(offset, A->getNumDimIds(), &aDimValues);
+ A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
+ {
+ // Merge dims from A into B.
+ unsigned d = offset;
+ for (auto aDimValue : aDimValues) {
+ unsigned loc;
+ if (B->findId(*aDimValue, &loc)) {
+ assert(loc >= offset && "A's dim appears in B's aligned range");
+ assert(loc < B->getNumDimIds() &&
+ "A's dim appears in B's non-dim position");
+ swapId(B, d, loc);
+ } else {
+ B->addDimId(d);
+ B->setIdValue(d, aDimValue);
+ }
+ d++;
+ }
+
+ // Dimensions that are in B, but not in A, are added at the end.
+ for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
+ A->addDimId(A->getNumDimIds());
+ A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
+ }
+ }
+ {
+ // Merge symbols: merge A's symbols into B first.
+ unsigned s = B->getNumDimIds();
+ for (auto aSymValue : aSymValues) {
+ unsigned loc;
+ if (B->findId(*aSymValue, &loc)) {
+ assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
+ "A's symbol appears in B's non-symbol position");
+ swapId(B, s, loc);
+ } else {
+ B->addSymbolId(s - B->getNumDimIds());
+ B->setIdValue(s, aSymValue);
+ }
+ s++;
+ }
+ // Symbols that are in B, but not in A, are added at the end.
+ for (unsigned t = A->getNumDimAndSymbolIds(),
+ e = B->getNumDimAndSymbolIds();
+ t < e; t++) {
+ A->addSymbolId(A->getNumSymbolIds());
+ A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
+ }
+ }
+ assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
+}
+
+// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
+void FlatAffineConstraints::mergeAndAlignIdsWithOther(
+ unsigned offset, FlatAffineConstraints *other) {
+ mergeAndAlignIds(offset, this, other);
+}
+
+// This routine may add additional local variables if the flattened expression
+// corresponding to the map has such variables due to mod's, ceildiv's, and
+// floordiv's in it.
+LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
+ std::vector<SmallVector<int64_t, 8>> flatExprs;
+ FlatAffineConstraints localCst;
+ if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs,
+ &localCst))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "composition unimplemented for semi-affine maps\n");
+ return failure();
+ }
+ assert(flatExprs.size() == vMap->getNumResults());
+
+ // Add localCst information.
+ if (localCst.getNumLocalIds() > 0) {
+ localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(),
+ /*values=*/vMap->getOperands());
+ // Align localCst and this.
+ mergeAndAlignIds(/*offset=*/0, &localCst, this);
+ // Finally, append localCst to this constraint set.
+ append(localCst);
+ }
+
+ // Add dimensions corresponding to the map's results.
+ for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
+ // TODO: Consider using a batched version to add a range of IDs.
+ addDimId(0);
+ }
+
+ // We add one equality for each result connecting the result dim of the map to
+ // the other identifiers.
+  // E.g.: if the expression is 16*i0 + i1, and this is the r^th result of the
+  // value map, we are adding the equality: d_r - 16*i0 - i1 = 0. Hence, when
+  // flattening, say, (i0 + 1, i0 + 8*i2), we add two equalities overall:
+  // d_0 - i0 - 1 == 0 and d_1 - i0 - 8*i2 == 0.
+ for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
+ const auto &flatExpr = flatExprs[r];
+ assert(flatExpr.size() >= vMap->getNumOperands() + 1);
+
+ // eqToAdd is the equality corresponding to the flattened affine expression.
+ SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
+ // Set the coefficient for this result to one.
+ eqToAdd[r] = 1;
+
+ // Dims and symbols.
+ for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
+ unsigned loc;
+ bool ret = findId(*vMap->getOperand(i), &loc);
+ assert(ret && "value map's id can't be found");
+ (void)ret;
+      // Negate the coefficient: the newly added result dim (with coefficient
+      // 1 at eqToAdd[r]) is equated to this expression.
+ eqToAdd[loc] = -flatExpr[i];
+ }
+ // Local vars common to eq and localCst are at the beginning.
+ unsigned j = getNumDimIds() + getNumSymbolIds();
+ unsigned end = flatExpr.size() - 1;
+ for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
+ eqToAdd[j] = -flatExpr[i];
+ }
+
+ // Constant term.
+ eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
+
+ // Add the equality connecting the result of the map to this constraint set.
+ addEquality(eqToAdd);
+ }
+
+ return success();
+}
+
+// Similar to composeMap except that no Value's need be associated with the
+// constraint system nor are they looked at -- since the dimensions and
+// symbols of 'other' are expected to correspond 1:1 to 'this' system. It
+// is thus not convenient to share code with composeMap.
+LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
+ assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
+ assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
+
+ std::vector<SmallVector<int64_t, 8>> flatExprs;
+ FlatAffineConstraints localCst;
+ if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "composition unimplemented for semi-affine maps\n");
+ return failure();
+ }
+ assert(flatExprs.size() == other.getNumResults());
+
+ // Add localCst information.
+ if (localCst.getNumLocalIds() > 0) {
+    // Place localCst's local ids before this system's existing locals so that
+    // the columns line up for the append below.
+ for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) {
+ addLocalId(0);
+ }
+ // Finally, append localCst to this constraint set.
+ append(localCst);
+ }
+
+ // Add dimensions corresponding to the map's results.
+ for (unsigned t = 0, e = other.getNumResults(); t < e; t++) {
+ addDimId(0);
+ }
+
+ // We add one equality for each result connecting the result dim of the map to
+ // the other identifiers.
+  // E.g.: if the expression is 16*i0 + i1, and this is the r^th result of the
+  // map, we are adding the equality: d_r - 16*i0 - i1 = 0. Hence, when
+  // flattening, say, (i0 + 1, i0 + 8*i2), we add two equalities overall:
+  // d_0 - i0 - 1 == 0 and d_1 - i0 - 8*i2 == 0.
+ for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
+ const auto &flatExpr = flatExprs[r];
+ assert(flatExpr.size() >= other.getNumInputs() + 1);
+
+ // eqToAdd is the equality corresponding to the flattened affine expression.
+ SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
+ // Set the coefficient for this result to one.
+ eqToAdd[r] = 1;
+
+ // Dims and symbols.
+ for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
+      // Negate the coefficient: the newly added result dim (with coefficient
+      // 1 at eqToAdd[r]) is equated to this expression.
+ eqToAdd[e + i] = -flatExpr[i];
+ }
+ // Local vars common to eq and localCst are at the beginning.
+ unsigned j = getNumDimIds() + getNumSymbolIds();
+ unsigned end = flatExpr.size() - 1;
+ for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
+ eqToAdd[j] = -flatExpr[i];
+ }
+
+ // Constant term.
+ eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
+
+ // Add the equality connecting the result of the map to this constraint set.
+ addEquality(eqToAdd);
+ }
+
+ return success();
+}
+
+// Turn a dimension into a symbol.
+static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) {
+ unsigned pos;
+ if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) {
+ swapId(cst, pos, cst->getNumDimIds() - 1);
+ cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1);
+ }
+}
+
+// Turn a symbol into a dimension.
+static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) {
+ unsigned pos;
+ if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
+ pos < cst->getNumDimAndSymbolIds()) {
+ swapId(cst, pos, cst->getNumDimIds());
+ cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
+ }
+}
+
+// Changes all symbol identifiers which are loop IVs to dim identifiers.
+void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
+ // Gather all symbols which are loop IVs.
+ SmallVector<Value, 4> loopIVs;
+ for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
+ if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue()))
+ loopIVs.push_back(ids[i].getValue());
+ }
+ // Turn each symbol in 'loopIVs' into a dim identifier.
+ for (auto iv : loopIVs) {
+ turnSymbolIntoDim(this, *iv);
+ }
+}
+
+void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
+ if (containsId(*id))
+ return;
+
+ // Caller is expected to fully compose map/operands if necessary.
+ assert((isTopLevelValue(id) || isForInductionVar(id)) &&
+ "non-terminal symbol / loop IV expected");
+ // Outer loop IVs could be used in forOp's bounds.
+ if (auto loop = getForInductionVarOwner(id)) {
+ addDimId(getNumDimIds(), id);
+ if (failed(this->addAffineForOpDomain(loop)))
+ LLVM_DEBUG(
+ loop.emitWarning("failed to add domain info to constraint system"));
+ return;
+ }
+ // Add top level symbol.
+ addSymbolId(getNumSymbolIds(), id);
+ // Check if the symbol is a constant.
+ if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id->getDefiningOp()))
+ setIdToConstant(*id, constOp.getValue());
+}
+
+LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
+ unsigned pos;
+ // Pre-condition for this method.
+ if (!findId(*forOp.getInductionVar(), &pos)) {
+ assert(false && "Value not found");
+ return failure();
+ }
+
+ int64_t step = forOp.getStep();
+ if (step != 1) {
+ if (!forOp.hasConstantLowerBound())
+ forOp.emitWarning("domain conservatively approximated");
+ else {
+ // Add constraints for the stride.
+ // (iv - lb) % step = 0 can be written as:
+ // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
+ // Add local variable 'q' and add the above equality.
+ // The first constraint is q = (iv - lb) floordiv step
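+      // E.g., for "affine.for %i = 0 to 128 step 4", this adds the local
+      // q = %i floordiv 4 together with the equality %i - 4 * q = 0.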
+ SmallVector<int64_t, 8> dividend(getNumCols(), 0);
+ int64_t lb = forOp.getConstantLowerBound();
+ dividend[pos] = 1;
+ dividend.back() -= lb;
+ addLocalFloorDiv(dividend, step);
+ // Second constraint: (iv - lb) - step * q = 0.
+ SmallVector<int64_t, 8> eq(getNumCols(), 0);
+ eq[pos] = 1;
+ eq.back() -= lb;
+ // For the local var just added above.
+ eq[getNumCols() - 2] = -step;
+ addEquality(eq);
+ }
+ }
+
+ if (forOp.hasConstantLowerBound()) {
+ addConstantLowerBound(pos, forOp.getConstantLowerBound());
+ } else {
+ // Non-constant lower bound case.
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands().begin(),
+ forOp.getLowerBoundOperands().end());
+ if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands,
+ /*eq=*/false, /*lower=*/true)))
+ return failure();
+ }
+
+ if (forOp.hasConstantUpperBound()) {
+ addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
+ return success();
+ }
+ // Non-constant upper bound case.
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands().begin(),
+ forOp.getUpperBoundOperands().end());
+ return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands,
+ /*eq=*/false, /*lower=*/false);
+}
+
+// Searches for a constraint with a non-zero coefficient at 'colIdx' in
+// equality (isEq=true) or inequality (isEq=false) constraints.
+// Returns true and sets row found in search in 'rowIdx'.
+// Returns false otherwise.
+static bool
+findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints,
+ unsigned colIdx, bool isEq, unsigned *rowIdx) {
+ auto at = [&](unsigned rowIdx) -> int64_t {
+ return isEq ? constraints.atEq(rowIdx, colIdx)
+ : constraints.atIneq(rowIdx, colIdx);
+ };
+ unsigned e =
+ isEq ? constraints.getNumEqualities() : constraints.getNumInequalities();
+ for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) {
+ if (at(*rowIdx) != 0) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Normalizes the coefficient values across all columns in 'rowIdx' by their
+// GCD in equality or inequality constraints as specified by 'isEq'.
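+// For example, the inequality row 2*x + 4*y - 6 >= 0 is normalized to
+// x + 2*y - 3 >= 0.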
+template <bool isEq>
+static void normalizeConstraintByGCD(FlatAffineConstraints *constraints,
+ unsigned rowIdx) {
+ auto at = [&](unsigned colIdx) -> int64_t {
+ return isEq ? constraints->atEq(rowIdx, colIdx)
+ : constraints->atIneq(rowIdx, colIdx);
+ };
+ uint64_t gcd = std::abs(at(0));
+ for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) {
+ gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j)));
+ }
+ if (gcd > 0 && gcd != 1) {
+ for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) {
+ int64_t v = at(j) / static_cast<int64_t>(gcd);
+ isEq ? constraints->atEq(rowIdx, j) = v
+ : constraints->atIneq(rowIdx, j) = v;
+ }
+ }
+}
+
+void FlatAffineConstraints::normalizeConstraintsByGCD() {
+ for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+ normalizeConstraintByGCD</*isEq=*/true>(this, i);
+ }
+ for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
+ normalizeConstraintByGCD</*isEq=*/false>(this, i);
+ }
+}
+
+bool FlatAffineConstraints::hasConsistentState() const {
+ if (inequalities.size() != getNumInequalities() * numReservedCols)
+ return false;
+ if (equalities.size() != getNumEqualities() * numReservedCols)
+ return false;
+ if (ids.size() != getNumIds())
+ return false;
+
+ // Catches errors where numDims, numSymbols, numIds aren't consistent.
+ if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds)
+ return false;
+
+ return true;
+}
+
+/// Checks all rows of equality/inequality constraints for trivial
+/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced
+/// after elimination. Returns 'true' if an invalid constraint is found;
+/// 'false' otherwise.
+bool FlatAffineConstraints::hasInvalidConstraint() const {
+ assert(hasConsistentState());
+ auto check = [&](bool isEq) -> bool {
+ unsigned numCols = getNumCols();
+ unsigned numRows = isEq ? getNumEqualities() : getNumInequalities();
+ for (unsigned i = 0, e = numRows; i < e; ++i) {
+ unsigned j;
+ for (j = 0; j < numCols - 1; ++j) {
+ int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
+ // Skip rows with non-zero variable coefficients.
+ if (v != 0)
+ break;
+ }
+ if (j < numCols - 1) {
+ continue;
+ }
+ // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
+ // Example invalid constraints include: '1 == 0' or '-1 >= 0'
+ int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
+ if ((isEq && v != 0) || (!isEq && v < 0)) {
+ return true;
+ }
+ }
+ return false;
+ };
+ if (check(/*isEq=*/true))
+ return true;
+ return check(/*isEq=*/false);
+}
+
+// Eliminate identifier from constraint at 'rowIdx' based on coefficient at
+// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be
+// updated as they have already been eliminated.
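+// For example, eliminating x with the pivot equality 2*x + y = 0 updates the
+// inequality row 3*x + z >= 0 to -3*y + 2*z >= 0, i.e., -3 times the pivot
+// row plus 2 times the inequality row.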
+static void eliminateFromConstraint(FlatAffineConstraints *constraints,
+ unsigned rowIdx, unsigned pivotRow,
+ unsigned pivotCol, unsigned elimColStart,
+ bool isEq) {
+  // Skip if equality 'rowIdx' is the same as 'pivotRow'.
+ if (isEq && rowIdx == pivotRow)
+ return;
+ auto at = [&](unsigned i, unsigned j) -> int64_t {
+ return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
+ };
+ int64_t leadCoeff = at(rowIdx, pivotCol);
+ // Skip if leading coefficient at 'rowIdx' is already zero.
+ if (leadCoeff == 0)
+ return;
+ int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
+ int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
+ int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff);
+ int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
+ int64_t rowMultiplier = lcm / std::abs(leadCoeff);
+
+ unsigned numCols = constraints->getNumCols();
+ for (unsigned j = 0; j < numCols; ++j) {
+ // Skip updating column 'j' if it was just eliminated.
+ if (j >= elimColStart && j < pivotCol)
+ continue;
+ int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
+ rowMultiplier * at(rowIdx, j);
+ isEq ? constraints->atEq(rowIdx, j) = v
+ : constraints->atIneq(rowIdx, j) = v;
+ }
+}
+
+// Remove coefficients in column range [colStart, colLimit) in place.
+// This removes the data in the specified column range, and copies any
+// remaining valid data into place.
+static void shiftColumnsToLeft(FlatAffineConstraints *constraints,
+ unsigned colStart, unsigned colLimit,
+ bool isEq) {
+ assert(colLimit <= constraints->getNumIds());
+ if (colLimit <= colStart)
+ return;
+
+ unsigned numCols = constraints->getNumCols();
+ unsigned numRows = isEq ? constraints->getNumEqualities()
+ : constraints->getNumInequalities();
+ unsigned numToEliminate = colLimit - colStart;
+ for (unsigned r = 0, e = numRows; r < e; ++r) {
+ for (unsigned c = colLimit; c < numCols; ++c) {
+ if (isEq) {
+ constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c);
+ } else {
+ constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c);
+ }
+ }
+ }
+}
+
+// Removes identifiers in column range [idStart, idLimit), and copies any
+// remaining valid data into place, and updates member variables.
+void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) {
+ assert(idLimit < getNumCols() && "invalid id limit");
+
+ if (idStart >= idLimit)
+ return;
+
+ // We are going to be removing one or more identifiers from the range.
+ assert(idStart < numIds && "invalid idStart position");
+
+ // TODO(andydavis) Make 'removeIdRange' a lambda called from here.
+ // Remove eliminated identifiers from equalities.
+ shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true);
+
+ // Remove eliminated identifiers from inequalities.
+ shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false);
+
+ // Update members numDims, numSymbols and numIds.
+ unsigned numDimsEliminated = 0;
+ unsigned numLocalsEliminated = 0;
+ unsigned numColsEliminated = idLimit - idStart;
+ if (idStart < numDims) {
+ numDimsEliminated = std::min(numDims, idLimit) - idStart;
+ }
+  // Check how many local ids were removed. Note that our identifier order is
+  // [dims, symbols, locals]. Local ids start at position numDims + numSymbols.
+ if (idLimit > numDims + numSymbols) {
+ numLocalsEliminated = std::min(
+ idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds());
+ }
+ unsigned numSymbolsEliminated =
+ numColsEliminated - numDimsEliminated - numLocalsEliminated;
+
+ numDims -= numDimsEliminated;
+ numSymbols -= numSymbolsEliminated;
+ numIds = numIds - numColsEliminated;
+
+ ids.erase(ids.begin() + idStart, ids.begin() + idLimit);
+
+ // No resize necessary. numReservedCols remains the same.
+}
+
+/// Returns the position of the identifier that has the minimum <number of lower
+/// bounds> times <number of upper bounds> from the specified range of
+/// identifiers [start, end). It is often best to eliminate in the increasing
+/// order of these counts when doing Fourier-Motzkin elimination since FM adds
+/// that many new constraints.
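+// For example, an identifier with 2 lower bounds and 3 upper bounds yields
+// 2 * 3 = 6 new inequalities when eliminated by FM, and thus scores 6 here.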
+static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst,
+ unsigned start, unsigned end) {
+ assert(start < cst.getNumIds() && end < cst.getNumIds() + 1);
+
+ auto getProductOfNumLowerUpperBounds = [&](unsigned pos) {
+ unsigned numLb = 0;
+ unsigned numUb = 0;
+ for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+ if (cst.atIneq(r, pos) > 0) {
+ ++numLb;
+ } else if (cst.atIneq(r, pos) < 0) {
+ ++numUb;
+ }
+ }
+ return numLb * numUb;
+ };
+
+ unsigned minLoc = start;
+ unsigned min = getProductOfNumLowerUpperBounds(start);
+ for (unsigned c = start + 1; c < end; c++) {
+ unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c);
+ if (numLbUbProduct < min) {
+ min = numLbUbProduct;
+ minLoc = c;
+ }
+ }
+ return minLoc;
+}
+
+// Checks for emptiness of the set by eliminating identifiers successively and
+// using the GCD test (on all equality constraints) and checking for trivially
+// invalid constraints. Returns 'true' if the constraint system is found to be
+// empty; false otherwise.
+bool FlatAffineConstraints::isEmpty() const {
+ if (isEmptyByGCDTest() || hasInvalidConstraint())
+ return true;
+
+ // First, eliminate as many identifiers as possible using Gaussian
+ // elimination.
+ FlatAffineConstraints tmpCst(*this);
+ unsigned currentPos = 0;
+ while (currentPos < tmpCst.getNumIds()) {
+ tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds());
+ ++currentPos;
+ // We check emptiness through trivial checks after eliminating each ID to
+ // detect emptiness early. Since the checks isEmptyByGCDTest() and
+ // hasInvalidConstraint() are linear time and single sweep on the constraint
+ // buffer, this appears reasonable - but can optimize in the future.
+ if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest())
+ return true;
+ }
+
+ // Eliminate the remaining using FM.
+ for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) {
+ tmpCst.FourierMotzkinEliminate(
+ getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds()));
+ // Check for a constraint explosion. This rarely happens in practice, but
+ // this check exists as a safeguard against improperly constructed
+ // constraint systems or artificially created arbitrarily complex systems
+ // that aren't the intended use case for FlatAffineConstraints. This is
+ // needed since FM has a worst case exponential complexity in theory.
+ if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) {
+ LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n");
+ return false;
+ }
+
+ // FM wouldn't have modified the equalities in any way. So no need to again
+ // run GCD test. Check for trivial invalid constraints.
+ if (tmpCst.hasInvalidConstraint())
+ return true;
+ }
+ return false;
+}
+
+// Runs the GCD test on all equality constraints. Returns 'true' if this test
+// fails on any equality. Returns 'false' otherwise.
+// This test can be used to disprove the existence of a solution. If it returns
+// true, no integer solution to the equality constraints can exist.
+//
+// GCD test definition:
+//
+// The equality constraint:
+//
+// c_1*x_1 + c_2*x_2 + ... + c_n*x_n = c_0
+//
+// has an integer solution iff:
+//
+// GCD of c_1, c_2, ..., c_n divides c_0.
+//
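+// For example, 2*x_1 + 4*x_2 = 3 has no integer solution since gcd(2, 4) = 2
+// does not divide 3.
+//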
+bool FlatAffineConstraints::isEmptyByGCDTest() const {
+ assert(hasConsistentState());
+ unsigned numCols = getNumCols();
+ for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+ uint64_t gcd = std::abs(atEq(i, 0));
+ for (unsigned j = 1; j < numCols - 1; ++j) {
+ gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j)));
+ }
+ int64_t v = std::abs(atEq(i, numCols - 1));
+ if (gcd > 0 && (v % gcd != 0)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Tightens inequalities given that we are dealing with integer spaces. This is
+/// analogous to the GCD test but applied to inequalities. The constant term can
+/// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,
+/// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a
+/// fast method - linear in the number of coefficients.
+// Example on how this affects practical cases: consider the scenario:
+// 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield
+// j >= 100 instead of the tighter (exact) j >= 128.
+void FlatAffineConstraints::GCDTightenInequalities() {
+ unsigned numCols = getNumCols();
+ for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
+ uint64_t gcd = std::abs(atIneq(i, 0));
+ for (unsigned j = 1; j < numCols - 1; ++j) {
+ gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j)));
+ }
+ if (gcd > 0 && gcd != 1) {
+ int64_t gcdI = static_cast<int64_t>(gcd);
+ // Tighten the constant term and normalize the constraint by the GCD.
+ atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI);
+ for (unsigned j = 0, e = numCols - 1; j < e; ++j)
+ atIneq(i, j) /= gcdI;
+ }
+ }
+}
+
+// Eliminates all identifier variables in column range [posStart, posLimit).
+// Returns the number of variables eliminated.
+unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
+ unsigned posLimit) {
+ // Return if identifier positions to eliminate are out of range.
+ assert(posLimit <= numIds);
+ assert(hasConsistentState());
+
+ if (posStart >= posLimit)
+ return 0;
+
+ GCDTightenInequalities();
+
+ unsigned pivotCol = 0;
+ for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) {
+ // Find a row which has a non-zero coefficient in column 'j'.
+ unsigned pivotRow;
+ if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true,
+ &pivotRow)) {
+ // No pivot row in equalities with non-zero at 'pivotCol'.
+ if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false,
+ &pivotRow)) {
+        // The identifier is zero in all inequalities as well; it can be
+        // trivially eliminated, so move on to the next column.
+ continue;
+ }
+ break;
+ }
+
+ // Eliminate identifier at 'pivotCol' from each equality row.
+ for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+ eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
+ /*isEq=*/true);
+ normalizeConstraintByGCD</*isEq=*/true>(this, i);
+ }
+
+ // Eliminate identifier at 'pivotCol' from each inequality row.
+ for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
+ eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart,
+ /*isEq=*/false);
+ normalizeConstraintByGCD</*isEq=*/false>(this, i);
+ }
+ removeEquality(pivotRow);
+ GCDTightenInequalities();
+ }
+ // Update position limit based on number eliminated.
+ posLimit = pivotCol;
+ // Remove eliminated columns from all constraints.
+ removeIdRange(posStart, posLimit);
+ return posLimit - posStart;
+}
+
+// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
+// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
+// could be detected as the corresponding floordiv of id_n. For example:
+// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=>
+// id_r = id_n mod 4, id_q = id_n floordiv 4.
+// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
+// pre-detected at the caller.
+static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
+ int64_t lbConst, int64_t ubConst,
+ SmallVectorImpl<AffineExpr> *memo) {
+ assert(pos < cst.getNumIds() && "invalid position");
+
+ // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
+ // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
+ // and id_q the quotient when dividing id_n by the divisor.
+
+ if (lbConst != 0 || ubConst < 1)
+ return false;
+
+ int64_t divisor = ubConst + 1;
+
+ // Now check for: id_r = id_n - divisor * id_q. As an example, we
+ // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
+ unsigned seenQuotient = 0, seenDividend = 0;
+ int quotientPos = -1, dividendPos = -1;
+ for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
+ // id_n should have coeff 1 or -1.
+ if (std::abs(cst.atEq(r, pos)) != 1)
+ continue;
+ // constant term should be 0.
+ if (cst.atEq(r, cst.getNumCols() - 1) != 0)
+ continue;
+ unsigned c, f;
+ int quotientSign = 1, dividendSign = 1;
+ for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
+ if (c == pos)
+ continue;
+ // The coefficient of the quotient should be +/-divisor.
+ // TODO(bondhugula): could be extended to detect an affine function for
+ // the quotient (i.e., the coeff could be a non-zero multiple of divisor).
+ int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
+ if (v == divisor || v == -divisor) {
+ seenQuotient++;
+ quotientPos = c;
+ quotientSign = v > 0 ? 1 : -1;
+ }
+ // The coefficient of the dividend should be +/-1.
+ // TODO(bondhugula): could be extended to detect an affine function of
+ // the other identifiers as the dividend.
+ else if (v == -1 || v == 1) {
+ seenDividend++;
+ dividendPos = c;
+ dividendSign = v < 0 ? 1 : -1;
+ } else if (cst.atEq(r, c) != 0) {
+ // Cannot be inferred as a mod since the constraint has a coefficient
+ // for an identifier that's neither a unit nor the divisor (see TODOs
+ // above).
+ break;
+ }
+ }
+ if (c < f)
+ // Cannot be inferred as a mod since the constraint has a coefficient for
+ // an identifier that's neither a unit nor the divisor (see TODOs above).
+ continue;
+
+ // We are looking for exactly one identifier as the dividend.
+ if (seenDividend == 1 && seenQuotient >= 1) {
+ if (!(*memo)[dividendPos])
+ return false;
+      // Successfully detected a mod.
+      auto ub = cst.getConstantUpperBound(dividendPos);
+      if (ub.hasValue() && ub.getValue() < divisor)
+        // The mod can be optimized away.
+        (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
+      else
+        (*memo)[pos] = ((*memo)[dividendPos] % divisor) * dividendSign;
+
+ if (seenQuotient == 1 && !(*memo)[quotientPos])
+ // Successfully detected a floordiv as well.
+ (*memo)[quotientPos] =
+ (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
+ return true;
+ }
+ }
+ return false;
+}
+
+// Gather lower and upper bounds for the pos^th identifier.
+static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
+ unsigned pos,
+ SmallVectorImpl<unsigned> *lbIndices,
+ SmallVectorImpl<unsigned> *ubIndices) {
+ assert(pos < cst.getNumIds() && "invalid position");
+
+ // Gather all lower bounds and upper bounds of the variable. Since the
+ // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
+ // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+ for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+ if (cst.atIneq(r, pos) >= 1) {
+ // Lower bound.
+ lbIndices->push_back(r);
+ } else if (cst.atIneq(r, pos) <= -1) {
+ // Upper bound.
+ ubIndices->push_back(r);
+ }
+ }
+}
+
+// Check if the pos^th identifier can be expressed as a floordiv of an affine
+// function of other identifiers (where the divisor is a positive constant).
+// For example: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4.
+bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
+ SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) {
+ assert(pos < cst.getNumIds() && "invalid position");
+
+ SmallVector<unsigned, 4> lbIndices, ubIndices;
+ getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices);
+
+ // Check if any lower bound, upper bound pair is of the form:
+ // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id'
+ // divisor * id <= expr <-- Upper bound for 'id'
+ // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1).
+ //
+ // For example, if -32*k + 16*i + j >= 0
+ // 32*k - 16*i - j + 31 >= 0 <=>
+ // k = ( 16*i + j ) floordiv 32
+  for (auto ubPos : ubIndices) {
+    for (auto lbPos : lbIndices) {
+      // Count of dividend coefficients seen for this (lbPos, ubPos) pair;
+      // reset for each pair so that counts from earlier pairs don't satisfy
+      // the check below.
+      unsigned seenDividends = 0;
+      // Check if the lower bound's constant term is 'divisor - 1'. The
+      // 'divisor' here is cst.atIneq(lbPos, pos), and we already know that
+      // it's positive (since row 'lbPos' is a lower bound expression for
+      // 'pos').
+ if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1)
+ continue;
+ // Check if upper bound's constant term is 0.
+ if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0)
+ continue;
+ // For the remaining part, check if the lower bound expr's coeff's are
+ // negations of corresponding upper bound ones'.
+ unsigned c, f;
+ for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
+ if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c))
+ break;
+ if (c != pos && cst.atIneq(lbPos, c) != 0)
+ seenDividends++;
+ }
+      // The lb coefficients aren't the negations of the ub coefficients (for
+      // the non-constant part), so this pair doesn't match.
+ if (c < f)
+ continue;
+ if (seenDividends >= 1) {
+ // The divisor is the constant term of the lower bound expression.
+ // We already know that cst.atIneq(lbPos, pos) > 0.
+ int64_t divisor = cst.atIneq(lbPos, pos);
+ // Construct the dividend expression.
+ auto dividendExpr = getAffineConstantExpr(0, context);
+ unsigned c, f;
+ for (c = 0, f = cst.getNumCols() - 1; c < f; c++) {
+ if (c == pos)
+ continue;
+ int64_t ubVal = cst.atIneq(ubPos, c);
+ if (ubVal == 0)
+ continue;
+ if (!(*memo)[c])
+ break;
+ dividendExpr = dividendExpr + ubVal * (*memo)[c];
+ }
+ // Expression can't be constructed as it depends on a yet unknown
+ // identifier.
+ // TODO(mlir-team): Visit/compute the identifiers in an order so that
+ // this doesn't happen. More complex but much more efficient.
+ if (c < f)
+ continue;
+ // Successfully detected the floordiv.
+ (*memo)[pos] = dividendExpr.floorDiv(divisor);
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+// Fills an inequality row with the value 'val'.
+static inline void fillInequality(FlatAffineConstraints *cst, unsigned r,
+ int64_t val) {
+ for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
+ cst->atIneq(r, c) = val;
+ }
+}
+
+// Negates an inequality.
+static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) {
+ for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) {
+ cst->atIneq(r, c) = -cst->atIneq(r, c);
+ }
+}
+
+// A more complex check to eliminate redundant inequalities. Uses
+// Fourier-Motzkin elimination to check if a constraint is redundant.
+void FlatAffineConstraints::removeRedundantInequalities() {
+ SmallVector<bool, 32> redun(getNumInequalities(), false);
+  // To check if an inequality is redundant, we replace the inequality by its
+  // complement (e.g., i - 1 >= 0 by i <= 0), and check if the resulting
+  // system is empty. If it is, the inequality is redundant.
+ FlatAffineConstraints tmpCst(*this);
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ // Change the inequality to its complement.
+ negateInequality(&tmpCst, r);
+ tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--;
+ if (tmpCst.isEmpty()) {
+ redun[r] = true;
+ // Zero fill the redundant inequality.
+ fillInequality(this, r, /*val=*/0);
+ fillInequality(&tmpCst, r, /*val=*/0);
+ } else {
+ // Reverse the change (to avoid recreating tmpCst each time).
+ tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++;
+ negateInequality(&tmpCst, r);
+ }
+ }
+
+ // Scan to get rid of all rows marked redundant, in-place.
+ auto copyRow = [&](unsigned src, unsigned dest) {
+ if (src == dest)
+ return;
+ for (unsigned c = 0, e = getNumCols(); c < e; c++) {
+ atIneq(dest, c) = atIneq(src, c);
+ }
+ };
+ unsigned pos = 0;
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ if (!redun[r])
+ copyRow(r, pos++);
+ }
+ inequalities.resize(numReservedCols * pos);
+}
+
+std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
+ unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
+ ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
+ assert(pos + offset < getNumDimIds() && "invalid dim start pos");
+ assert(symStartPos >= (pos + offset) && "invalid sym start pos");
+ assert(getNumLocalIds() == localExprs.size() &&
+ "incorrect local exprs count");
+
+ SmallVector<unsigned, 4> lbIndices, ubIndices;
+ getLowerAndUpperBoundIndices(*this, pos + offset, &lbIndices, &ubIndices);
+
+  /// Copy coefficients from 'a' to 'b', skipping the 'num' columns starting
+  /// at 'offset' (i.e., the identifiers whose bounds are being computed).
+ auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
+ b.clear();
+ for (unsigned i = 0, e = a.size(); i < e; ++i) {
+ if (i < offset || i >= offset + num)
+ b.push_back(a[i]);
+ }
+ };
+
+ SmallVector<int64_t, 8> lb, ub;
+ SmallVector<AffineExpr, 4> exprs;
+ unsigned dimCount = symStartPos - num;
+ unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
+ exprs.reserve(lbIndices.size());
+ // Lower bound expressions.
+ for (auto idx : lbIndices) {
+ auto ineq = getInequality(idx);
+    // Extract the lower bound (in terms of other coefficients + const), i.e.,
+    // if i - j + 1 >= 0 is the constraint and 'pos' refers to i, the lower
+    // bound is j - 1.
+ addCoeffs(ineq, lb);
+ std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
+ auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context);
+ exprs.push_back(expr);
+ }
+ auto lbMap =
+ exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs);
+
+ exprs.clear();
+ exprs.reserve(ubIndices.size());
+ // Upper bound expressions.
+ for (auto idx : ubIndices) {
+ auto ineq = getInequality(idx);
+ // Extract the upper bound (in terms of other coeff's + const).
+ addCoeffs(ineq, ub);
+ auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context);
+ // Upper bound is exclusive.
+ exprs.push_back(expr + 1);
+ }
+ auto ubMap =
+ exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs);
+
+ return {lbMap, ubMap};
+}
+
+/// Computes the lower and upper bounds of the first 'num' dimensional
+/// identifiers (starting at 'offset') as affine maps of the remaining
+/// identifiers (dimensional and symbolic identifiers). Local identifiers are
+/// themselves explicitly computed as affine functions of other identifiers in
+/// this process if needed.
+void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
+ MLIRContext *context,
+ SmallVectorImpl<AffineMap> *lbMaps,
+ SmallVectorImpl<AffineMap> *ubMaps) {
+ assert(num < getNumDimIds() && "invalid range");
+
+ // Basic simplification.
+ normalizeConstraintsByGCD();
+
+ LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
+ << " identifiers\n");
+ LLVM_DEBUG(dump());
+
+ // Record computed/detected identifiers.
+ SmallVector<AffineExpr, 8> memo(getNumIds());
+ // Initialize dimensional and symbolic identifiers.
+ for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
+ if (i < offset)
+ memo[i] = getAffineDimExpr(i, context);
+ else if (i >= offset + num)
+ memo[i] = getAffineDimExpr(i - num, context);
+ }
+ for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
+ memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
+
+ bool changed;
+ do {
+ changed = false;
+ // Identify yet unknown identifiers as constants or mod's / floordiv's of
+ // other identifiers if possible.
+ for (unsigned pos = 0; pos < getNumIds(); pos++) {
+ if (memo[pos])
+ continue;
+
+ auto lbConst = getConstantLowerBound(pos);
+ auto ubConst = getConstantUpperBound(pos);
+ if (lbConst.hasValue() && ubConst.hasValue()) {
+ // Detect equality to a constant.
+ if (lbConst.getValue() == ubConst.getValue()) {
+ memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
+ changed = true;
+ continue;
+ }
+
+ // Detect an identifier as modulo of another identifier w.r.t a
+ // constant.
+ if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
+ &memo)) {
+ changed = true;
+ continue;
+ }
+ }
+
+ // Detect an identifier as floordiv of another identifier w.r.t a
+ // constant.
+ if (detectAsFloorDiv(*this, pos, &memo, context)) {
+ changed = true;
+ continue;
+ }
+
+ // Detect an identifier as an expression of other identifiers.
+ unsigned idx;
+ if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) {
+ continue;
+ }
+
+ // Build AffineExpr solving for identifier 'pos' in terms of all others.
+ auto expr = getAffineConstantExpr(0, context);
+ unsigned j, e;
+ for (j = 0, e = getNumIds(); j < e; ++j) {
+ if (j == pos)
+ continue;
+ int64_t c = atEq(idx, j);
+ if (c == 0)
+ continue;
+ // If any of the involved IDs hasn't been found yet, we can't proceed.
+ if (!memo[j])
+ break;
+ expr = expr + memo[j] * c;
+ }
+ if (j < e)
+ // Can't construct expression as it depends on a yet uncomputed
+ // identifier.
+ continue;
+
+ // Add constant term to AffineExpr.
+ expr = expr + atEq(idx, getNumIds());
+ int64_t vPos = atEq(idx, pos);
+ assert(vPos != 0 && "expected non-zero here");
+ if (vPos > 0)
+ expr = (-expr).floorDiv(vPos);
+ else
+ // vPos < 0.
+ expr = expr.floorDiv(-vPos);
+ // Successfully constructed expression.
+ memo[pos] = expr;
+ changed = true;
+ }
+ // This loop is guaranteed to reach a fixed point - since once an
+ // identifier's explicit form is computed (in memo[pos]), it's not updated
+ // again.
+ } while (changed);
+
+ // Set the lower and upper bound maps for all the identifiers that were
+ // computed as affine expressions of the rest as the "detected expr" and
+ // "detected expr + 1" respectively; set the undetected ones to null.
+ Optional<FlatAffineConstraints> tmpClone;
+ for (unsigned pos = 0; pos < num; pos++) {
+ unsigned numMapDims = getNumDimIds() - num;
+ unsigned numMapSymbols = getNumSymbolIds();
+ AffineExpr expr = memo[pos + offset];
+ if (expr)
+ expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
+
+ AffineMap &lbMap = (*lbMaps)[pos];
+ AffineMap &ubMap = (*ubMaps)[pos];
+
+ if (expr) {
+ lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
+ ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
+ } else {
+ // TODO(bondhugula): Whenever there are local identifiers in the
+ // dependence constraints, we'll conservatively over-approximate, since we
+ // don't always explicitly compute them above (in the while loop).
+ if (getNumLocalIds() == 0) {
+ // Work on a copy so that we don't update this constraint system.
+ if (!tmpClone) {
+ tmpClone.emplace(FlatAffineConstraints(*this));
+ // Removing redundant inequalities is necessary so that we don't get
+ // redundant loop bounds.
+ tmpClone->removeRedundantInequalities();
+ }
+ std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
+ pos, offset, num, getNumDimIds(), {}, context);
+ }
+
+ // If the above fails, we'll just use the constant lower bound and the
+ // constant upper bound (if they exist) as the slice bounds.
+      // TODO(b/126426796): being conservative for the moment in cases that
+      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp
+      // is fixed.
+ if (!lbMap || lbMap.getNumResults() > 1) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "WARNING: Potentially over-approximating slice lb\n");
+ auto lbConst = getConstantLowerBound(pos + offset);
+ if (lbConst.hasValue()) {
+ lbMap = AffineMap::get(
+ numMapDims, numMapSymbols,
+ getAffineConstantExpr(lbConst.getValue(), context));
+ }
+ }
+ if (!ubMap || ubMap.getNumResults() > 1) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "WARNING: Potentially over-approximating slice ub\n");
+ auto ubConst = getConstantUpperBound(pos + offset);
+ if (ubConst.hasValue()) {
+          ubMap = AffineMap::get(
+ numMapDims, numMapSymbols,
+ getAffineConstantExpr(ubConst.getValue() + 1, context));
+ }
+ }
+ }
+ LLVM_DEBUG(llvm::dbgs()
+ << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
+ LLVM_DEBUG(lbMap.dump(););
+ LLVM_DEBUG(llvm::dbgs()
+ << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
+ LLVM_DEBUG(ubMap.dump(););
+ }
+}
+
+LogicalResult
+FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
+ ArrayRef<Value> boundOperands,
+ bool eq, bool lower) {
+ assert(pos < getNumDimAndSymbolIds() && "invalid position");
+ // Equality follows the logic of lower bound except that we add an equality
+ // instead of an inequality.
+ assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
+ if (eq)
+ lower = true;
+
+ // Fully compose map and operands; canonicalize and simplify so that we
+ // transitively get to terminal symbols or loop IVs.
+ auto map = boundMap;
+ SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
+ fullyComposeAffineMapAndOperands(&map, &operands);
+ map = simplifyAffineMap(map);
+ canonicalizeMapAndOperands(&map, &operands);
+ for (auto operand : operands)
+ addInductionVarOrTerminalSymbol(operand);
+
+ FlatAffineConstraints localVarCst;
+ std::vector<SmallVector<int64_t, 8>> flatExprs;
+ if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) {
+ LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n");
+ return failure();
+ }
+
+ // Merge and align with localVarCst.
+ if (localVarCst.getNumLocalIds() > 0) {
+ // Set values for localVarCst.
+ localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
+ for (auto operand : operands) {
+ unsigned pos;
+ if (findId(*operand, &pos)) {
+ if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
+ // If the local var cst has this as a dim, turn it into its symbol.
+ turnDimIntoSymbol(&localVarCst, *operand);
+ } else if (pos < getNumDimIds()) {
+ // Or vice versa.
+ turnSymbolIntoDim(&localVarCst, *operand);
+ }
+ }
+ }
+ mergeAndAlignIds(/*offset=*/0, this, &localVarCst);
+ append(localVarCst);
+ }
+
+ // Record positions of the operands in the constraint system. Need to do
+ // this here since the constraint system changes after a bound is added.
+ SmallVector<unsigned, 8> positions;
+ unsigned numOperands = operands.size();
+ for (auto operand : operands) {
+ unsigned pos;
+ if (!findId(*operand, &pos))
+ assert(0 && "expected to be found");
+ positions.push_back(pos);
+ }
+
+ for (const auto &flatExpr : flatExprs) {
+ SmallVector<int64_t, 4> ineq(getNumCols(), 0);
+ ineq[pos] = lower ? 1 : -1;
+ // Dims and symbols.
+ for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) {
+ ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j];
+ }
+ // Copy over the local id coefficients.
+ unsigned numLocalIds = flatExpr.size() - 1 - numOperands;
+ for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds;
+ jj++, j++) {
+ ineq[j] =
+ lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj];
+ }
+ // Constant term.
+ ineq[getNumCols() - 1] =
+ lower ? -flatExpr[flatExpr.size() - 1]
+ // Upper bound in flattenedExpr is an exclusive one.
+ : flatExpr[flatExpr.size() - 1] - 1;
+ eq ? addEquality(ineq) : addInequality(ineq);
+ }
+ return success();
+}
+
+// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
+// bounds in 'ubMaps' to each value in 'values' that appears in the constraint
+// system. Note that both the lower and upper bounds share the same operand
+// list 'operands'.
+// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
+// skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
+// Returns failure for unimplemented cases such as semi-affine expressions or
+// expressions with mod/floordiv.
+LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
+ ArrayRef<AffineMap> lbMaps,
+ ArrayRef<AffineMap> ubMaps,
+ ArrayRef<Value> operands) {
+ assert(values.size() == lbMaps.size());
+ assert(lbMaps.size() == ubMaps.size());
+
+ for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
+ unsigned pos;
+ if (!findId(*values[i], &pos))
+ continue;
+
+ AffineMap lbMap = lbMaps[i];
+ AffineMap ubMap = ubMaps[i];
+ assert(!lbMap || lbMap.getNumInputs() == operands.size());
+ assert(!ubMap || ubMap.getNumInputs() == operands.size());
+
+ // Check if this slice is just an equality along this dimension.
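+    // For example, lbMap = (d0) together with ubMap = (d0 + 1) pins this
+    // value to exactly d0, so a single equality suffices.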
+ if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
+ ubMap.getNumResults() == 1 &&
+ lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
+ if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
+ /*lower=*/true)))
+ return failure();
+ continue;
+ }
+
+ if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
+ /*lower=*/true)))
+ return failure();
+
+ if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
+ /*lower=*/false)))
+ return failure();
+ }
+ return success();
+}
+
+void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
+ assert(eq.size() == getNumCols());
+ unsigned offset = equalities.size();
+ equalities.resize(equalities.size() + numReservedCols);
+ std::copy(eq.begin(), eq.end(), equalities.begin() + offset);
+}
+
+void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
+ assert(inEq.size() == getNumCols());
+ unsigned offset = inequalities.size();
+ inequalities.resize(inequalities.size() + numReservedCols);
+ std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset);
+}
+
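+// Adds the bound 'id_pos >= lb', encoded as the inequality row
+// id_pos - lb >= 0 (coefficient 1 at 'pos', constant term -lb).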
+void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
+ assert(pos < getNumCols());
+ unsigned offset = inequalities.size();
+ inequalities.resize(inequalities.size() + numReservedCols);
+ std::fill(inequalities.begin() + offset,
+ inequalities.begin() + offset + getNumCols(), 0);
+ inequalities[offset + pos] = 1;
+ inequalities[offset + getNumCols() - 1] = -lb;
+}
+
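+// Adds the bound 'id_pos <= ub', encoded as the inequality row
+// -id_pos + ub >= 0 (coefficient -1 at 'pos', constant term ub).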
+void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
+ assert(pos < getNumCols());
+ unsigned offset = inequalities.size();
+ inequalities.resize(inequalities.size() + numReservedCols);
+ std::fill(inequalities.begin() + offset,
+ inequalities.begin() + offset + getNumCols(), 0);
+ inequalities[offset + pos] = -1;
+ inequalities[offset + getNumCols() - 1] = ub;
+}
+
+void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
+ int64_t lb) {
+ assert(expr.size() == getNumCols());
+ unsigned offset = inequalities.size();
+ inequalities.resize(inequalities.size() + numReservedCols);
+ std::fill(inequalities.begin() + offset,
+ inequalities.begin() + offset + getNumCols(), 0);
+ std::copy(expr.begin(), expr.end(), inequalities.begin() + offset);
+ inequalities[offset + getNumCols() - 1] += -lb;
+}
+
+void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
+ int64_t ub) {
+ assert(expr.size() == getNumCols());
+ unsigned offset = inequalities.size();
+ inequalities.resize(inequalities.size() + numReservedCols);
+ std::fill(inequalities.begin() + offset,
+ inequalities.begin() + offset + getNumCols(), 0);
+ for (unsigned i = 0, e = getNumCols(); i < e; i++) {
+ inequalities[offset + i] = -expr[i];
+ }
+ inequalities[offset + getNumCols() - 1] += ub;
+}
+
+/// Adds a new local identifier as the floordiv of an affine function of other
+/// identifiers, the coefficients of which are provided in 'dividend' and with
+/// respect to a positive constant 'divisor'. Two constraints are added to the
+/// system to capture equivalence with the floordiv.
+/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1.
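+/// For example, adding q = (i + 4*j) floordiv 3 introduces the inequalities
+/// i + 4*j - 3*q >= 0 and -i - 4*j + 3*q + 2 >= 0.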
+void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend,
+ int64_t divisor) {
+ assert(dividend.size() == getNumCols() && "incorrect dividend size");
+ assert(divisor > 0 && "positive divisor expected");
+
+ addLocalId(getNumLocalIds());
+
+ // Add two constraints for this new identifier 'q'.
+ SmallVector<int64_t, 8> bound(dividend.size() + 1);
+
+ // dividend - q * divisor >= 0
+ std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
+ bound.begin());
+ bound.back() = dividend.back();
+ bound[getNumIds() - 1] = -divisor;
+ addInequality(bound);
+
+  // -dividend + divisor * q + divisor - 1 >= 0
+ std::transform(bound.begin(), bound.end(), bound.begin(),
+ std::negate<int64_t>());
+ bound[bound.size() - 1] += divisor - 1;
+ addInequality(bound);
+}
+
+bool FlatAffineConstraints::findId(Value id, unsigned *pos) const {
+ unsigned i = 0;
+ for (const auto &mayBeId : ids) {
+ if (mayBeId.hasValue() && mayBeId.getValue() == id) {
+ *pos = i;
+ return true;
+ }
+ i++;
+ }
+ return false;
+}
+
+bool FlatAffineConstraints::containsId(Value id) const {
+ return llvm::any_of(ids, [&](const Optional<Value> &mayBeId) {
+ return mayBeId.hasValue() && mayBeId.getValue() == id;
+ });
+}
+
+void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
+ assert(newSymbolCount <= numDims + numSymbols &&
+ "invalid separation position");
+ numDims = numDims + numSymbols - newSymbolCount;
+ numSymbols = newSymbolCount;
+}
+
+/// Sets the specified identifier to a constant value.
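+/// This appends the equality row: id_pos - val = 0.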
+void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
+ unsigned offset = equalities.size();
+ equalities.resize(equalities.size() + numReservedCols);
+ std::fill(equalities.begin() + offset,
+ equalities.begin() + offset + getNumCols(), 0);
+ equalities[offset + pos] = 1;
+ equalities[offset + getNumCols() - 1] = -val;
+}
+
+/// Sets the specified identifier to a constant value; asserts if the id is not
+/// found.
+void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) {
+ unsigned pos;
+ if (!findId(id, &pos))
+ // This is a pre-condition for this method.
+ assert(0 && "id not found");
+ setIdToConstant(pos, val);
+}
+
+void FlatAffineConstraints::removeEquality(unsigned pos) {
+ unsigned numEqualities = getNumEqualities();
+ assert(pos < numEqualities);
+ unsigned outputIndex = pos * numReservedCols;
+ unsigned inputIndex = (pos + 1) * numReservedCols;
+ unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols;
+ std::copy(equalities.begin() + inputIndex,
+ equalities.begin() + inputIndex + numElemsToCopy,
+ equalities.begin() + outputIndex);
+ equalities.resize(equalities.size() - numReservedCols);
+}
+
+/// Finds an equality that equates the specified identifier to a constant.
+/// Returns the position of the equality row. If 'symbolic' is set to true,
+/// symbols are also treated like a constant, i.e., an affine function of the
+/// symbols is also treated like a constant.
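+// For example, the equality d0 - 5 = 0 qualifies for pos = 0; with 'symbolic'
+// set to true, d0 - s0 - 5 = 0 qualifies as well.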
+static int findEqualityToConstant(const FlatAffineConstraints &cst,
+ unsigned pos, bool symbolic = false) {
+ assert(pos < cst.getNumIds() && "invalid position");
+ for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
+ int64_t v = cst.atEq(r, pos);
+ if (v * v != 1)
+ continue;
+ unsigned c;
+ unsigned f = symbolic ? cst.getNumDimIds() : cst.getNumIds();
+ // This checks for zeros in all positions other than 'pos' in [0, f)
+ for (c = 0; c < f; c++) {
+ if (c == pos)
+ continue;
+ if (cst.atEq(r, c) != 0) {
+ // Dependent on another identifier.
+ break;
+ }
+ }
+ if (c == f)
+ // Equality is free of other identifiers.
+ return r;
+ }
+ return -1;
+}
+
+void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) {
+ assert(pos < getNumIds() && "invalid position");
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal;
+ }
+ for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+ atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal;
+ }
+ removeId(pos);
+}
+
+LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) {
+ assert(pos < getNumIds() && "invalid position");
+ int rowIdx;
+ if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
+ return failure();
+
+ // atEq(rowIdx, pos) is either -1 or 1.
+ assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
+ int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
+ setAndEliminate(pos, constVal);
+ return success();
+}
+
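+// Constant-folds each of the 'num' identifiers starting at 'pos' that is
+// found to be a constant. 't' advances only when a fold fails, since a
+// successful fold removes the identifier's column and shifts the remaining
+// identifiers left.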
+void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
+ for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
+ if (failed(constantFoldId(t)))
+ t++;
+ }
+}
+
+/// Returns the extent (upper bound - lower bound) of the specified
+/// identifier if it is found to be a constant; returns None if it's not a
+/// constant. This methods treats symbolic identifiers specially, i.e.,
+/// it looks for constant differences between affine expressions involving
+/// only the symbolic identifiers. See comments at function definition for
+/// example. 'lb', if provided, is set to the lower bound associated with the
+/// constant difference. Note that 'lb' is purely symbolic and thus will contain
+/// the coefficients of the symbolic identifiers and the constant coefficient.
+// Examples: 0 <= i <= 15, returns 16.
+//           s0 + 2 <= i <= s0 + 17, returns 16 (s0 has to be a symbol).
+//           s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
+//           s0 - 7 <= 8*j <= s0, returns 1 with lb = s0, lbDivisor = 8 (since
+//           lb = ceil((s0 - 7) / 8) = floor(s0 / 8)).
+Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
+ unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *lbFloorDivisor,
+ SmallVectorImpl<int64_t> *ub) const {
+ assert(pos < getNumDimIds() && "Invalid identifier position");
+ assert(getNumLocalIds() == 0);
+
+ // TODO(bondhugula): eliminate all remaining dimensional identifiers (other
+ // than the one at 'pos') to make this more powerful. Not needed for
+ // hyper-rectangular spaces.
+
+ // Find an equality for 'pos'^th identifier that equates it to some function
+ // of the symbolic identifiers (+ constant).
+ int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true);
+ if (eqRow != -1) {
+ // This identifier can only take a single value.
+ if (lb) {
+ // Set lb to the symbolic value.
+ lb->resize(getNumSymbolIds() + 1);
+ if (ub)
+ ub->resize(getNumSymbolIds() + 1);
+ for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
+ int64_t v = atEq(eqRow, pos);
+ // atEq(eqRow, pos) is either -1 or 1.
+ assert(v * v == 1);
+ (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v
+ : -atEq(eqRow, getNumDimIds() + c) / v;
+ // Since this is an equality, ub = lb.
+ if (ub)
+ (*ub)[c] = (*lb)[c];
+ }
+ assert(lbFloorDivisor &&
+ "both lb and divisor or none should be provided");
+ *lbFloorDivisor = 1;
+ }
+ return 1;
+ }
+
+ // Check if the identifier appears at all in any of the inequalities.
+ unsigned r, e;
+ for (r = 0, e = getNumInequalities(); r < e; r++) {
+ if (atIneq(r, pos) != 0)
+ break;
+ }
+ if (r == e)
+ // If it doesn't, there isn't a bound on it.
+ return None;
+
+ // Positions of constraints that are lower/upper bounds on the variable.
+ SmallVector<unsigned, 4> lbIndices, ubIndices;
+
+ // Gather all symbolic lower bounds and upper bounds of the variable. Since
+ // the canonical form is c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a
+ // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ unsigned c, f;
+ for (c = 0, f = getNumDimIds(); c < f; c++) {
+ if (c != pos && atIneq(r, c) != 0)
+ break;
+ }
+ if (c < getNumDimIds())
+ // Not a pure symbolic bound.
+ continue;
+ if (atIneq(r, pos) >= 1)
+ // Lower bound.
+ lbIndices.push_back(r);
+ else if (atIneq(r, pos) <= -1)
+ // Upper bound.
+ ubIndices.push_back(r);
+ }
+
+ // TODO(bondhugula): eliminate other dimensional identifiers to make this more
+ // powerful. Not needed for hyper-rectangular iteration spaces.
+
+ Optional<int64_t> minDiff = None;
+ unsigned minLbPosition, minUbPosition;
+ for (auto ubPos : ubIndices) {
+ for (auto lbPos : lbIndices) {
+ // Look for a lower bound and an upper bound that only differ by a
+ // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst.
+ // For example, if ii is the pos^th variable, we are looking for
+ // constraints like ii >= i, ii <= i + 50, 50 being the difference. The
+ // minimum among all such constant differences is kept since that's the
+ // constant bounding the extent of the pos^th variable.
+ unsigned j, e;
+ for (j = 0, e = getNumCols() - 1; j < e; j++)
+ if (atIneq(ubPos, j) != -atIneq(lbPos, j)) {
+ break;
+ }
+ if (j < getNumCols() - 1)
+ continue;
+ int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
+ atIneq(lbPos, getNumCols() - 1) + 1,
+ atIneq(lbPos, pos));
+ if (minDiff == None || diff < minDiff) {
+ minDiff = diff;
+ minLbPosition = lbPos;
+ minUbPosition = ubPos;
+ }
+ }
+ }
+ if (lb && minDiff.hasValue()) {
+ // Set lb to the symbolic lower bound.
+ lb->resize(getNumSymbolIds() + 1);
+ if (ub)
+ ub->resize(getNumSymbolIds() + 1);
+ // The lower bound is the ceildiv of the lb constraint over the coefficient
+ // of the variable at 'pos'. We express the ceildiv equivalently as a floor
+ // for uniformity. E.g., if the lower bound constraint was 32*d0 - N +
+ // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
+ *lbFloorDivisor = atIneq(minLbPosition, pos);
+ assert(*lbFloorDivisor == -atIneq(minUbPosition, pos));
+ for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
+ (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
+ }
+ if (ub) {
+ for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++)
+ (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c);
+ }
+ // The lower bound leads to a ceildiv while the upper bound is a floordiv
+ // whenever the coefficient at pos != 1. ceildiv(val, d) = floordiv(val +
+ // d - 1, d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
+ // the constant term for the lower bound.
+ (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1;
+ }
+ return minDiff;
+}
+
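+/// Computes a constant lower bound (isLower = true) or a constant upper bound
+/// on the identifier at 'pos' by projecting out every other identifier and
+/// then inspecting the remaining constraints on it; returns None if no
+/// constant bound can be determined. Note that this modifies the system in
+/// place, which is why the const wrappers below work on a temporary copy.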
+template <bool isLower>
+Optional<int64_t>
+FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
+ assert(pos < getNumIds() && "invalid position");
+ // Project out all identifiers other than the one at 'pos'; the latter ends
+ // up at position 0.
+ projectOut(0, pos);
+ projectOut(1, getNumIds() - 1);
+ // Check if there's an equality equating the '0'^th identifier to a constant.
+ int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false);
+ if (eqRowIdx != -1)
+ // atEq(rowIdx, 0) is either -1 or 1.
+ return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0);
+
+ // Check if the identifier appears at all in any of the inequalities.
+ unsigned r, e;
+ for (r = 0, e = getNumInequalities(); r < e; r++) {
+ if (atIneq(r, 0) != 0)
+ break;
+ }
+ if (r == e)
+ // If it doesn't, there isn't a bound on it.
+ return None;
+
+ Optional<int64_t> minOrMaxConst = None;
+
+ // Take the max across all const lower bounds (or min across all constant
+ // upper bounds).
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ if (isLower) {
+ if (atIneq(r, 0) <= 0)
+ // Not a lower bound.
+ continue;
+ } else if (atIneq(r, 0) >= 0) {
+ // Not an upper bound.
+ continue;
+ }
+ unsigned c, f;
+ for (c = 0, f = getNumCols() - 1; c < f; c++)
+ if (c != 0 && atIneq(r, c) != 0)
+ break;
+ if (c < getNumCols() - 1)
+ // Not a constant bound.
+ continue;
+
+ int64_t boundConst =
+ isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
+ : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
+ if (isLower) {
+ if (minOrMaxConst == None || boundConst > minOrMaxConst)
+ minOrMaxConst = boundConst;
+ } else {
+ if (minOrMaxConst == None || boundConst < minOrMaxConst)
+ minOrMaxConst = boundConst;
+ }
+ }
+ return minOrMaxConst;
+}
+
+Optional<int64_t>
+FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
+ FlatAffineConstraints tmpCst(*this);
+ return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+}
+
+Optional<int64_t>
+FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
+ FlatAffineConstraints tmpCst(*this);
+ return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+}
+
+// A simple (naive and conservative) check for hyper-rectangularity.
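+// Eg: {0 <= d0 <= 7, 0 <= d1 <= 15} is reported hyper-rectangular over
+// [0, 2), while a system containing d1 - d0 >= 0 is not, since that
+// constraint involves two identifiers from the range.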
+bool FlatAffineConstraints::isHyperRectangular(unsigned pos,
+ unsigned num) const {
+ assert(pos < getNumCols() - 1);
+ // Check for two non-zero coefficients in the range [pos, pos + num).
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ unsigned sum = 0;
+ for (unsigned c = pos; c < pos + num; c++) {
+ if (atIneq(r, c) != 0)
+ sum++;
+ }
+ if (sum > 1)
+ return false;
+ }
+ for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+ unsigned sum = 0;
+ for (unsigned c = pos; c < pos + num; c++) {
+ if (atEq(r, c) != 0)
+ sum++;
+ }
+ if (sum > 1)
+ return false;
+ }
+ return true;
+}
+
+void FlatAffineConstraints::print(raw_ostream &os) const {
+ assert(hasConsistentState());
+ os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds()
+ << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints()
+ << " constraints)\n";
+ os << "(";
+ for (unsigned i = 0, e = getNumIds(); i < e; i++) {
+ if (ids[i] == None)
+ os << "None ";
+ else
+ os << "Value ";
+ }
+ os << " const)\n";
+ for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
+ for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
+ os << atEq(i, j) << " ";
+ }
+ os << "= 0\n";
+ }
+ for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
+ for (unsigned j = 0, f = getNumCols(); j < f; ++j) {
+ os << atIneq(i, j) << " ";
+ }
+ os << ">= 0\n";
+ }
+ os << '\n';
+}
+
+void FlatAffineConstraints::dump() const { print(llvm::errs()); }
+
+/// Removes duplicate constraints, trivially true constraints, and constraints
+/// that can be detected as redundant as a result of differing only in their
+/// constant term part. A constraint of the form <non-negative constant> >= 0 is
+/// considered trivially true.
+// Uses a DenseSet to hash and detect duplicates followed by a linear scan to
+// remove duplicates in place.
+void FlatAffineConstraints::removeTrivialRedundancy() {
+ SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
+
+ // A map used to detect redundancy stemming from constraints that only differ
+ // in their constant term. The value stored is <row position, const term>
+ // for a given row.
+ SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
+ rowsWithoutConstTerm;
+
+ // Check if constraint is of the form <non-negative-constant> >= 0.
+ auto isTriviallyValid = [&](unsigned r) -> bool {
+ for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) {
+ if (atIneq(r, c) != 0)
+ return false;
+ }
+ return atIneq(r, getNumCols() - 1) >= 0;
+ };
+
+ // Detect and mark redundant constraints.
+ SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ int64_t *rowStart = inequalities.data() + numReservedCols * r;
+ auto row = ArrayRef<int64_t>(rowStart, getNumCols());
+ if (isTriviallyValid(r) || !rowSet.insert(row).second) {
+ redunIneq[r] = true;
+ continue;
+ }
+
+ // Among constraints that only differ in the constant term part, mark
+ // everything other than the one with the smallest constant term redundant.
+ // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
+ // former two are redundant).
+ int64_t constTerm = atIneq(r, getNumCols() - 1);
+ auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
+ const auto &ret =
+ rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
+ if (!ret.second) {
+ // Check if the other constraint has a higher constant term.
+ auto &val = ret.first->second;
+ if (val.second > constTerm) {
+ // The stored row is redundant. Mark it so, and update with this one.
+ redunIneq[val.first] = true;
+ val = {r, constTerm};
+ } else {
+ // The one stored makes this one redundant.
+ redunIneq[r] = true;
+ }
+ }
+ }
+
+ auto copyRow = [&](unsigned src, unsigned dest) {
+ if (src == dest)
+ return;
+ for (unsigned c = 0, e = getNumCols(); c < e; c++) {
+ atIneq(dest, c) = atIneq(src, c);
+ }
+ };
+
+ // Scan to get rid of all rows marked redundant, in-place.
+ unsigned pos = 0;
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ if (!redunIneq[r])
+ copyRow(r, pos++);
+ }
+ inequalities.resize(numReservedCols * pos);
+
+ // TODO(bondhugula): consider doing this for equalities as well, but probably
+ // not worth the savings.
+}
+
+void FlatAffineConstraints::clearAndCopyFrom(
+ const FlatAffineConstraints &other) {
+ FlatAffineConstraints copy(other);
+ std::swap(*this, copy);
+ assert(copy.getNumIds() == copy.getIds().size());
+}
+
+void FlatAffineConstraints::removeId(unsigned pos) {
+ removeIdRange(pos, pos + 1);
+}
+
+static std::pair<unsigned, unsigned>
+getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
+ unsigned numDims = cst.getNumDimIds();
+ unsigned numSymbols = cst.getNumSymbolIds();
+ unsigned newNumDims, newNumSymbols;
+ if (pos < numDims) {
+ newNumDims = numDims - 1;
+ newNumSymbols = numSymbols;
+ } else if (pos < numDims + numSymbols) {
+ assert(numSymbols >= 1);
+ newNumDims = numDims;
+ newNumSymbols = numSymbols - 1;
+ } else {
+ newNumDims = numDims;
+ newNumSymbols = numSymbols;
+ }
+ return {newNumDims, newNumSymbols};
+}
+
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "fm"
+
+/// Eliminates identifier at the specified position using Fourier-Motzkin
+/// variable elimination. This technique is exact for rational spaces but
+/// conservative (in "rare" cases) for integer spaces. The operation corresponds
+/// to a projection operation yielding the (convex) set of integer points
+/// contained in the rational shadow of the set. An emptiness test that relies
+/// on this method will guarantee emptiness, i.e., it disproves the existence of
+/// a solution if it says it's empty.
+/// If a non-null isResultIntegerExact is passed, it is set to true if the
+/// result is also integer exact. If it's set to false, the obtained solution
+/// *may* not be exact, i.e., it may contain integer points that do not have an
+/// integer pre-image in the original set.
+///
+/// Eg:
+/// j >= 0, j <= i + 1
+/// i >= 0, i <= N + 1
+/// Eliminating i yields,
+/// j >= 0, 0 <= N + 1, j - 1 <= N + 1
+///
+/// If darkShadow = true, this method computes the dark shadow on elimination;
+/// the dark shadow is a convex integer subset of the exact integer shadow. A
+/// non-empty dark shadow proves the existence of an integer solution. The
+/// elimination in such a case could however be an under-approximation, and thus
+/// should not be used for scanning sets or used by itself for dependence
+/// checking.
+///
+/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
+/// ^
+/// |
+/// | * * * * o o
+/// i | * * o o o o
+/// | o * * * * *
+/// --------------->
+/// j ->
+///
+/// Eliminating i from this system (projecting on the j dimension):
+/// rational shadow / integer light shadow: 1 <= j <= 6
+/// dark shadow: 3 <= j <= 6
+/// exact integer shadow: j = 1 \union 3 <= j <= 6
+/// holes/splinters: j = 2
+///
+/// darkShadow = false, isResultIntegerExact = nullptr are default values.
+// TODO(bondhugula): a slight modification to yield dark shadow version of FM
+// (tightened), which can prove the existence of a solution if there is one.
+void FlatAffineConstraints::FourierMotzkinEliminate(
+ unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
+ LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n");
+ LLVM_DEBUG(dump());
+ assert(pos < getNumIds() && "invalid position");
+ assert(hasConsistentState());
+
+ // Check if this identifier can be eliminated through a substitution.
+ for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+ if (atEq(r, pos) != 0) {
+ // Use Gaussian elimination here (since we have an equality).
+ LogicalResult ret = gaussianEliminateId(pos);
+ (void)ret;
+ assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed");
+ LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
+ LLVM_DEBUG(dump());
+ return;
+ }
+ }
+
+ // A fast linear time tightening.
+ GCDTightenInequalities();
+
+ // Check if the identifier appears at all in any of the inequalities.
+ unsigned r, e;
+ for (r = 0, e = getNumInequalities(); r < e; r++) {
+ if (atIneq(r, pos) != 0)
+ break;
+ }
+ if (r == getNumInequalities()) {
+ // If it doesn't appear, just remove the column and return.
+ // TODO(andydavis,bondhugula): refactor removeColumns to use it from here.
+ removeId(pos);
+ LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
+ LLVM_DEBUG(dump());
+ return;
+ }
+
+ // Positions of constraints that are lower bounds on the variable.
+ SmallVector<unsigned, 4> lbIndices;
+ // Positions of constraints that are upper bounds on the variable.
+ SmallVector<unsigned, 4> ubIndices;
+ // Positions of constraints that do not involve the variable.
+ std::vector<unsigned> nbIndices;
+ nbIndices.reserve(getNumInequalities());
+
+ // Gather all lower bounds and upper bounds of the variable. Since the
+ // canonical form is c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
+ // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+ if (atIneq(r, pos) == 0) {
+ // Id does not appear in bound.
+ nbIndices.push_back(r);
+ } else if (atIneq(r, pos) >= 1) {
+ // Lower bound.
+ lbIndices.push_back(r);
+ } else {
+ // Upper bound.
+ ubIndices.push_back(r);
+ }
+ }
+
+ // Set the number of dimensions, symbols in the resulting system.
+ const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
+ unsigned newNumDims = dimsSymbols.first;
+ unsigned newNumSymbols = dimsSymbols.second;
+
+ SmallVector<Optional<Value>, 8> newIds;
+ newIds.reserve(numIds - 1);
+ newIds.append(ids.begin(), ids.begin() + pos);
+ newIds.append(ids.begin() + pos + 1, ids.end());
+
+ /// Create the new system which has one identifier less.
+ FlatAffineConstraints newFac(
+ lbIndices.size() * ubIndices.size() + nbIndices.size(),
+ getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
+ /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds);
+
+ assert(newFac.getIds().size() == newFac.getNumIds());
+
+ // This will be used to check if the elimination was integer exact.
+ unsigned lcmProducts = 1;
+
+ // Let x be the variable we are eliminating.
+ // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
+ // that c_l, c_u >= 1) we have:
+ // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
+ // We thus generate a constraint:
+ // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
+ // Note if c_l = c_u = 1, all integer points captured by the resulting
+ // constraint correspond to integer points in the original system (i.e., they
+ // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
+ // integer exact.
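+ // Eg: combining a lower bound 2*x >= e1 (c_l = 2) with an upper bound
+ // 3*x <= e2 (c_u = 3) uses lcm(2, 3) = 6 and generates 3*e1 <= 2*e2, i.e.,
+ // 2*e2 - 3*e1 >= 0; since the lcm is not 1, such a pair does not guarantee
+ // an integer exact result.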
+ for (auto ubPos : ubIndices) {
+ for (auto lbPos : lbIndices) {
+ SmallVector<int64_t, 4> ineq;
+ ineq.reserve(newFac.getNumCols());
+ int64_t lbCoeff = atIneq(lbPos, pos);
+ // Note that in the comments above, ubCoeff is the negation of the
+ // coefficient in the canonical form as the view taken here is that of the
+ // term being moved to the other side of '>='.
+ int64_t ubCoeff = -atIneq(ubPos, pos);
+ // TODO(bondhugula): refactor this loop to avoid all branches inside.
+ for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+ if (l == pos)
+ continue;
+ assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
+ int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
+ ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
+ atIneq(lbPos, l) * (lcm / lbCoeff));
+ lcmProducts *= lcm;
+ }
+ if (darkShadow) {
+ // The dark shadow is a convex subset of the exact integer shadow. If
+ // there is a point here, it proves the existence of a solution.
+ ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
+ }
+ // TODO: we need to have a way to add inequalities in-place in
+ // FlatAffineConstraints instead of creating and copying over.
+ newFac.addInequality(ineq);
+ }
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1)
+ << "\n");
+ if (lcmProducts == 1 && isResultIntegerExact)
+ *isResultIntegerExact = true;
+
+ // Copy over the constraints not involving this variable.
+ for (auto nbPos : nbIndices) {
+ SmallVector<int64_t, 4> ineq;
+ ineq.reserve(getNumCols() - 1);
+ for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+ if (l == pos)
+ continue;
+ ineq.push_back(atIneq(nbPos, l));
+ }
+ newFac.addInequality(ineq);
+ }
+
+ assert(newFac.getNumConstraints() ==
+ lbIndices.size() * ubIndices.size() + nbIndices.size());
+
+ // Copy over the equalities.
+ for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+ SmallVector<int64_t, 4> eq;
+ eq.reserve(newFac.getNumCols());
+ for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+ if (l == pos)
+ continue;
+ eq.push_back(atEq(r, l));
+ }
+ newFac.addEquality(eq);
+ }
+
+ // GCD tightening and normalization allows detection of more trivially
+ // redundant constraints.
+ newFac.GCDTightenInequalities();
+ newFac.normalizeConstraintsByGCD();
+ newFac.removeTrivialRedundancy();
+ clearAndCopyFrom(newFac);
+ LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
+ LLVM_DEBUG(dump());
+}
+
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "affine-structures"
+
+void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
+ if (num == 0)
+ return;
+
+ // 'pos' can be at most getNumCols() - 2 if num > 0.
+ assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position");
+ assert(pos + num < getNumCols() && "invalid range");
+
+ // Eliminate as many identifiers as possible using Gaussian elimination.
+ unsigned currentPos = pos;
+ unsigned numToEliminate = num;
+ unsigned numGaussianEliminated = 0;
+
+ while (currentPos < getNumIds()) {
+ unsigned curNumEliminated =
+ gaussianEliminateIds(currentPos, currentPos + numToEliminate);
+ ++currentPos;
+ numToEliminate -= curNumEliminated + 1;
+ numGaussianEliminated += curNumEliminated;
+ }
+
+ // Eliminate the remaining using Fourier-Motzkin.
+ for (unsigned i = 0; i < num - numGaussianEliminated; i++) {
+ unsigned numToEliminate = num - numGaussianEliminated - i;
+ FourierMotzkinEliminate(
+ getBestIdToEliminate(*this, pos, pos + numToEliminate));
+ }
+
+ // Fast/trivial simplifications.
+ GCDTightenInequalities();
+ // Normalize constraints after tightening since the latter impacts this, but
+ // not the other way round.
+ normalizeConstraintsByGCD();
+}
+
+void FlatAffineConstraints::projectOut(Value id) {
+ unsigned pos;
+ bool ret = findId(*id, &pos);
+ assert(ret);
+ (void)ret;
+ FourierMotzkinEliminate(pos);
+}
+
+void FlatAffineConstraints::clearConstraints() {
+ equalities.clear();
+ inequalities.clear();
+}
+
+namespace {
+
+enum BoundCmpResult { Greater, Less, Equal, Unknown };
+
+/// Compares two affine bounds whose coefficients are provided in 'a' and 'b'.
+/// The last coefficient is the constant term.
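+/// Eg: (1, 0, 5) vs. (1, 0, 8) compares as Less, while (1, 0, 5) vs. (0, 1, 5)
+/// compares as Unknown since the identifier coefficients differ.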
+static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+ assert(a.size() == b.size());
+
+ // For the bounds to be comparable, their corresponding identifier
+ // coefficients should be equal; the constant terms are then compared to
+ // determine less/greater/equal.
+
+ if (!std::equal(a.begin(), a.end() - 1, b.begin()))
+ return Unknown;
+
+ if (a.back() == b.back())
+ return Equal;
+
+ return a.back() < b.back() ? Less : Greater;
+}
+} // namespace
+
+// Computes the bounding box with respect to 'other' by finding the min of the
+// lower bounds and the max of the upper bounds along each of the dimensions.
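+// Eg: the union of {0 <= d0 <= 127} and {16 <= d0 <= 192} along d0 is
+// 0 <= d0 <= 192.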
+LogicalResult
+FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
+ assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
+ assert(otherCst.getIds()
+ .slice(0, getNumDimIds())
+ .equals(getIds().slice(0, getNumDimIds())) &&
+ "dim values mismatch");
+ assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
+ assert(getNumLocalIds() == 0 && "local ids not supported yet here");
+
+ Optional<FlatAffineConstraints> otherCopy;
+ if (!areIdsAligned(*this, otherCst)) {
+ otherCopy.emplace(FlatAffineConstraints(otherCst));
+ mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue());
+ }
+
+ const auto &other = otherCopy ? *otherCopy : otherCst;
+
+ std::vector<SmallVector<int64_t, 8>> boundingLbs;
+ std::vector<SmallVector<int64_t, 8>> boundingUbs;
+ boundingLbs.reserve(2 * getNumDimIds());
+ boundingUbs.reserve(2 * getNumDimIds());
+
+ // To hold lower and upper bounds for each dimension.
+ SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
+ // To compute min of lower bounds and max of upper bounds for each dimension.
+ SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1);
+ SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1);
+ // To compute final new lower and upper bounds for the union.
+ SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
+
+ int64_t lbFloorDivisor, otherLbFloorDivisor;
+ for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
+ auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
+ if (!extent.hasValue())
+ // TODO(bondhugula): symbolic extents when necessary.
+ // TODO(bondhugula): handle union if a dimension is unbounded.
+ return failure();
+
+ auto otherExtent = other.getConstantBoundOnDimSize(
+ d, &otherLb, &otherLbFloorDivisor, &otherUb);
+ if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor)
+ // TODO(bondhugula): symbolic extents when necessary.
+ return failure();
+
+ assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
+
+ auto res = compareBounds(lb, otherLb);
+ // Identify min.
+ if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
+ minLb = lb;
+ // Since the divisor is for a floordiv, we need to convert to ceildiv,
+ // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=>
+ // div * i >= expr - div + 1.
+ minLb.back() -= lbFloorDivisor - 1;
+ } else if (res == BoundCmpResult::Greater) {
+ minLb = otherLb;
+ minLb.back() -= otherLbFloorDivisor - 1;
+ } else {
+ // Uncomparable - check for constant lower/upper bounds.
+ auto constLb = getConstantLowerBound(d);
+ auto constOtherLb = other.getConstantLowerBound(d);
+ if (!constLb.hasValue() || !constOtherLb.hasValue())
+ return failure();
+ std::fill(minLb.begin(), minLb.end(), 0);
+ minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue());
+ }
+
+ // Do the same for ub's but max of upper bounds. Identify max.
+ auto uRes = compareBounds(ub, otherUb);
+ if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) {
+ maxUb = ub;
+ } else if (uRes == BoundCmpResult::Less) {
+ maxUb = otherUb;
+ } else {
+ // Uncomparable - check for constant lower/upper bounds.
+ auto constUb = getConstantUpperBound(d);
+ auto constOtherUb = other.getConstantUpperBound(d);
+ if (!constUb.hasValue() || !constOtherUb.hasValue())
+ return failure();
+ std::fill(maxUb.begin(), maxUb.end(), 0);
+ maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue());
+ }
+
+ std::fill(newLb.begin(), newLb.end(), 0);
+ std::fill(newUb.begin(), newUb.end(), 0);
+
+ // The divisor for lb, ub, otherLb, otherUb at this point is lbFloorDivisor,
+ // and so it's the divisor for newLb and newUb as well.
+ newLb[d] = lbFloorDivisor;
+ newUb[d] = -lbFloorDivisor;
+ // Copy over the symbolic part + constant term.
+ std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds());
+ std::transform(newLb.begin() + getNumDimIds(), newLb.end(),
+ newLb.begin() + getNumDimIds(), std::negate<int64_t>());
+ std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds());
+
+ boundingLbs.push_back(newLb);
+ boundingUbs.push_back(newUb);
+ }
+
+ // Clear all constraints and add the lower/upper bounds for the bounding box.
+ clearConstraints();
+ for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) {
+ addInequality(boundingLbs[d]);
+ addInequality(boundingUbs[d]);
+ }
+ // TODO(mlir-team): copy over pure symbolic constraints from this and 'other'
+ // over to the union (since the above are just the union along dimensions); we
+ // shouldn't be discarding any other constraints on the symbols.
+
+ return success();
+}
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
new file mode 100644
index 00000000000..96c2928b17f
--- /dev/null
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -0,0 +1,29 @@
+add_llvm_library(MLIRAnalysis STATIC
+ AffineAnalysis.cpp
+ AffineStructures.cpp
+ CallGraph.cpp
+ Dominance.cpp
+ InferTypeOpInterface.cpp
+ Liveness.cpp
+ LoopAnalysis.cpp
+ MemRefBoundCheck.cpp
+ NestedMatcher.cpp
+ OpStats.cpp
+ SliceAnalysis.cpp
+ TestMemRefDependenceCheck.cpp
+ TestParallelismDetection.cpp
+ Utils.cpp
+ VectorAnalysis.cpp
+ Verifier.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis
+ )
+add_dependencies(MLIRAnalysis
+ MLIRAffineOps
+ MLIRCallOpInterfacesIncGen
+ MLIRTypeInferOpInterfaceIncGen
+ MLIRLoopOps
+ MLIRVectorOps
+ )
+target_link_libraries(MLIRAnalysis MLIRAffineOps MLIRLoopOps MLIRVectorOps)
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
new file mode 100644
index 00000000000..c35421d55eb
--- /dev/null
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -0,0 +1,256 @@
+//===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains interfaces and analyses for defining a nested callgraph.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/CallGraph.h"
+#include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// CallInterfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/CallInterfaces.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// CallGraphNode
+//===----------------------------------------------------------------------===//
+
+/// Returns true if this node refers to the indirect/external node.
+bool CallGraphNode::isExternal() const { return !callableRegion; }
+
+/// Return the callable region this node represents. This can only be called
+/// on non-external nodes.
+Region *CallGraphNode::getCallableRegion() const {
+ assert(!isExternal() && "the external node has no callable region");
+ return callableRegion;
+}
+
+/// Adds a reference edge to the given node. This is only valid on the
+/// external node.
+void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
+ assert(isExternal() && "abstract edges are only valid on external nodes");
+ addEdge(node, Edge::Kind::Abstract);
+}
+
+/// Add an outgoing call edge from this node.
+void CallGraphNode::addCallEdge(CallGraphNode *node) {
+ addEdge(node, Edge::Kind::Call);
+}
+
+/// Adds a reference edge to the given child node.
+void CallGraphNode::addChildEdge(CallGraphNode *child) {
+ addEdge(child, Edge::Kind::Child);
+}
+
+/// Returns true if this node has any child edges.
+bool CallGraphNode::hasChildren() const {
+ return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
+}
+
+/// Add an edge to 'node' with the given kind.
+void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
+ edges.insert({node, kind});
+}
+
+//===----------------------------------------------------------------------===//
+// CallGraph
+//===----------------------------------------------------------------------===//
+
+/// Recursively compute the callgraph edges for the given operation. Computed
+/// edges are placed into the given callgraph object.
+static void computeCallGraph(Operation *op, CallGraph &cg,
+ CallGraphNode *parentNode);
+
+/// Compute the set of callgraph nodes that are created by regions nested within
+/// 'op'.
+static void computeCallables(Operation *op, CallGraph &cg,
+ CallGraphNode *parentNode) {
+ if (op->getNumRegions() == 0)
+ return;
+ if (auto callableOp = dyn_cast<CallableOpInterface>(op)) {
+ SmallVector<Region *, 1> callables;
+ callableOp.getCallableRegions(callables);
+ for (auto *callableRegion : callables)
+ cg.getOrAddNode(callableRegion, parentNode);
+ }
+}
+
+/// Recursively compute the callgraph edges within the given region. Computed
+/// edges are placed into the given callgraph object.
+static void computeCallGraph(Region &region, CallGraph &cg,
+ CallGraphNode *parentNode) {
+ // Iterate over the nested operations twice:
+ // Once to fully create a node for each callable region of a nested
+ // operation;
+ for (auto &block : region)
+ for (auto &nested : block)
+ computeCallables(&nested, cg, parentNode);
+
+ // And another to recursively compute the callgraph.
+ for (auto &block : region)
+ for (auto &nested : block)
+ computeCallGraph(&nested, cg, parentNode);
+}
+
+/// Recursively compute the callgraph edges for the given operation. Computed
+/// edges are placed into the given callgraph object.
+static void computeCallGraph(Operation *op, CallGraph &cg,
+ CallGraphNode *parentNode) {
+ // Compute the callgraph nodes and edges for each of the nested operations.
+ auto isCallable = isa<CallableOpInterface>(op);
+ for (auto &region : op->getRegions()) {
+ // Check to see if this region is a callable node; if so, it is the parent
+ // node of the nested region.
+ CallGraphNode *nestedParentNode;
+ if (!isCallable || !(nestedParentNode = cg.lookupNode(&region)))
+ nestedParentNode = parentNode;
+ computeCallGraph(region, cg, nestedParentNode);
+ }
+
+ // If there is no parent node, we ignore this operation. Even if this
+ // operation was a call, there would be no callgraph node to attribute it to.
+ if (!parentNode)
+ return;
+
+ // If this is a call operation, resolve the callee.
+ if (auto call = dyn_cast<CallOpInterface>(op))
+ parentNode->addCallEdge(
+ cg.resolveCallable(call.getCallableForCallee(), op));
+}
+
+CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
+ computeCallGraph(op, *this, /*parentNode=*/nullptr);
+}
+
+/// Get or add a call graph node for the given region.
+CallGraphNode *CallGraph::getOrAddNode(Region *region,
+ CallGraphNode *parentNode) {
+ assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
+ "expected parent operation to be callable");
+ std::unique_ptr<CallGraphNode> &node = nodes[region];
+ if (!node) {
+ node.reset(new CallGraphNode(region));
+
+ // Add this node to the given parent node if necessary.
+ if (parentNode)
+ parentNode->addChildEdge(node.get());
+ else
+ // Otherwise, connect all callable nodes to the external node, this allows
+ // for conservatively including all callable nodes within the graph.
+ // FIXME(riverriddle) This isn't correct, this is only necessary for
+ // callable nodes that *could* be called from external sources. This
+ // requires extending the interface for callables to check if they may be
+ // referenced externally.
+ externalNode.addAbstractEdge(node.get());
+ }
+ return node.get();
+}
+
+/// Lookup a call graph node for the given region, or nullptr if none is
+/// registered.
+CallGraphNode *CallGraph::lookupNode(Region *region) const {
+ auto it = nodes.find(region);
+ return it == nodes.end() ? nullptr : it->second.get();
+}
+
+/// Resolve the callable for the given callee to a node in the callgraph, or
+/// the external node if a valid node was not resolved.
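+/// Eg: a call on a symbol reference @foo resolves to the node for @foo's
+/// callable region when @foo is defined in scope and callable; otherwise the
+/// external node is returned.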
+CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
+ Operation *from) const {
+ // Get the callee operation from the callable.
+ Operation *callee;
+ if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
+ // TODO(riverriddle) Support nested references.
+ callee = SymbolTable::lookupNearestSymbolFrom(from,
+ symbolRef.getRootReference());
+ else
+ callee = callable.get<Value>()->getDefiningOp();
+
+ // If the callee is non-null and is a valid callable object, try to get the
+ // called region from it.
+ if (callee && callee->getNumRegions()) {
+ if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
+ if (auto *node = lookupNode(callableOp.getCallableRegion(callable)))
+ return node;
+ }
+ }
+
+ // If we don't have a valid direct region, this is an external call.
+ return getExternalNode();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing
+
+/// Dump the graph in a human readable format.
+void CallGraph::dump() const { print(llvm::errs()); }
+void CallGraph::print(raw_ostream &os) const {
+ os << "// ---- CallGraph ----\n";
+
+ // Functor used to output the name for the given node.
+ auto emitNodeName = [&](const CallGraphNode *node) {
+ if (node->isExternal()) {
+ os << "<External-Node>";
+ return;
+ }
+
+ auto *callableRegion = node->getCallableRegion();
+ auto *parentOp = callableRegion->getParentOp();
+ os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
+ << callableRegion->getRegionNumber();
+ if (auto attrs = parentOp->getAttrList().getDictionary())
+ os << " : " << attrs;
+ };
+
+ for (auto &nodeIt : nodes) {
+ const CallGraphNode *node = nodeIt.second.get();
+
+ // Dump the header for this node.
+ os << "// - Node : ";
+ emitNodeName(node);
+ os << "\n";
+
+ // Emit each of the edges.
+ for (auto &edge : *node) {
+ os << "// -- ";
+ if (edge.isCall())
+ os << "Call";
+ else if (edge.isChild())
+ os << "Child";
+
+ os << "-Edge : ";
+ emitNodeName(edge.getTarget());
+ os << "\n";
+ }
+ os << "//\n";
+ }
+
+ os << "// -- SCCs --\n";
+
+ for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
+ os << "// - SCC : \n";
+ for (auto &node : scc) {
+ os << "// -- Node :";
+ emitNodeName(node);
+ os << "\n";
+ }
+ os << "\n";
+ }
+
+ os << "// -------------------\n";
+}
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
new file mode 100644
index 00000000000..e4af4c0d69b
--- /dev/null
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -0,0 +1,171 @@
+//===- Dominance.cpp - Dominator analysis for CFGs ------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of dominance related classes and instantiations of extern
+// templates.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/GenericDomTreeConstruction.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/false>;
+template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/true>;
+template class llvm::DomTreeNodeBase<Block>;
+
+//===----------------------------------------------------------------------===//
+// DominanceInfoBase
+//===----------------------------------------------------------------------===//
+
+template <bool IsPostDom>
+void DominanceInfoBase<IsPostDom>::recalculate(Operation *op) {
+ dominanceInfos.clear();
+
+ /// Build the dominance for each of the operation regions.
+ op->walk([&](Operation *op) {
+ for (auto &region : op->getRegions()) {
+ // Don't compute dominance if the region is empty.
+ if (region.empty())
+ continue;
+ auto opDominance = std::make_unique<base>();
+ opDominance->recalculate(region);
+ dominanceInfos.try_emplace(&region, std::move(opDominance));
+ }
+ });
+}
+
+/// Return true if the specified block A properly dominates block B.
+template <bool IsPostDom>
+bool DominanceInfoBase<IsPostDom>::properlyDominates(Block *a, Block *b) {
+ // A block dominates itself but does not properly dominate itself.
+ if (a == b)
+ return false;
+
+ // If either a or b are null, then conservatively return false.
+ if (!a || !b)
+ return false;
+
+ // If both blocks are not in the same region, 'a' properly dominates 'b' if
+ // 'b' is defined in an operation region that (recursively) ends up being
+ // dominated by 'a'. Walk up the list of containers enclosing B.
+ auto *regionA = a->getParent(), *regionB = b->getParent();
+ if (regionA != regionB) {
+ Operation *bAncestor;
+ do {
+ bAncestor = regionB->getParentOp();
+ // If 'bAncestor' is the top level region, we have run out of enclosing
+ // blocks; return IsPostDom (false for dominance, true for post-dominance).
+ if (!bAncestor || !bAncestor->getBlock())
+ return IsPostDom;
+
+ regionB = bAncestor->getBlock()->getParent();
+ } while (regionA != regionB);
+
+ // Check to see if the ancestor of 'b' is the same block as 'a'.
+ b = bAncestor->getBlock();
+ if (a == b)
+ return true;
+ }
+
+ // Otherwise, use the standard dominance functionality.
+
+ // If we don't have a dominance information for this region, assume that b is
+ // dominated by anything.
+ auto baseInfoIt = dominanceInfos.find(regionA);
+ if (baseInfoIt == dominanceInfos.end())
+ return true;
+ return baseInfoIt->second->properlyDominates(a, b);
+}
+
+template class mlir::detail::DominanceInfoBase</*IsPostDom=*/true>;
+template class mlir::detail::DominanceInfoBase</*IsPostDom=*/false>;
+
+//===----------------------------------------------------------------------===//
+// DominanceInfo
+//===----------------------------------------------------------------------===//
+
+/// Return true if operation A properly dominates operation B.
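+/// Eg: within a single block, an operation properly dominates every operation
+/// that appears after it in that block.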
+bool DominanceInfo::properlyDominates(Operation *a, Operation *b) {
+ auto *aBlock = a->getBlock(), *bBlock = b->getBlock();
+
+ // If a or b are not within a block, then a does not dominate b.
+ if (!aBlock || !bBlock)
+ return false;
+
+ // If the blocks are the same, then check if b is before a in the block.
+ if (aBlock == bBlock)
+ return a->isBeforeInBlock(b);
+
+ // Traverse up b's hierarchy to check if b's block is contained in a's.
+ if (auto *bAncestor = aBlock->findAncestorOpInBlock(*b)) {
+ // Since we already know that aBlock != bBlock, here bAncestor != b.
+ // a and bAncestor are in the same block; check if 'a' dominates
+ // bAncestor.
+ return dominates(a, bAncestor);
+ }
+
+ // If the blocks are different, check if a's block dominates b's.
+ return properlyDominates(aBlock, bBlock);
+}
+
+/// Return true if value A properly dominates operation B.
+bool DominanceInfo::properlyDominates(Value a, Operation *b) {
+ if (auto *aOp = a->getDefiningOp()) {
+ // The values defined by an operation do *not* dominate any nested
+ // operations.
+ if (aOp->getParentRegion() != b->getParentRegion() && aOp->isAncestor(b))
+ return false;
+ return properlyDominates(aOp, b);
+ }
+
+ // Block arguments properly dominate all operations in their own block, so
+ // we use a 'dominates' check here, not a 'properlyDominates' check.
+ return dominates(a.cast<BlockArgument>()->getOwner(), b->getBlock());
+}
+
+DominanceInfoNode *DominanceInfo::getNode(Block *a) {
+ auto *region = a->getParent();
+ assert(dominanceInfos.count(region) != 0);
+ return dominanceInfos[region]->getNode(a);
+}
+
+void DominanceInfo::updateDFSNumbers() {
+ for (auto &iter : dominanceInfos)
+ iter.second->updateDFSNumbers();
+}
+
+//===----------------------------------------------------------------------===//
+// PostDominanceInfo
+//===----------------------------------------------------------------------===//
+
+/// Returns true if operation 'a' properly postdominates operation 'b'.
+bool PostDominanceInfo::properlyPostDominates(Operation *a, Operation *b) {
+ auto *aBlock = a->getBlock(), *bBlock = b->getBlock();
+
+ // If a or b are not within a block, then a does not post dominate b.
+ if (!aBlock || !bBlock)
+ return false;
+
+ // If the blocks are the same, check if b is before a in the block.
+ if (aBlock == bBlock)
+ return b->isBeforeInBlock(a);
+
+ // Traverse up b's hierarchy to check if b's block is contained in a's.
+ if (auto *bAncestor = a->getBlock()->findAncestorOpInBlock(*b))
+ // Since we already know that aBlock != bBlock, here bAncestor != b.
+ // a and bAncestor are in the same block; check if 'a' postdominates
+ // bAncestor.
+ return postDominates(a, bAncestor);
+
+ // If the blocks are different, check if a's block post dominates b's.
+ return properlyDominates(aBlock, bBlock);
+}
diff --git a/mlir/lib/Analysis/InferTypeOpInterface.cpp b/mlir/lib/Analysis/InferTypeOpInterface.cpp
new file mode 100644
index 00000000000..2e52de2b3fa
--- /dev/null
+++ b/mlir/lib/Analysis/InferTypeOpInterface.cpp
@@ -0,0 +1,22 @@
+//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the infer op interfaces defined in
+// `InferTypeOpInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/InferTypeOpInterface.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+#include "mlir/Analysis/InferTypeOpInterface.cpp.inc"
+} // namespace mlir
diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp
new file mode 100644
index 00000000000..7ba31365f1a
--- /dev/null
+++ b/mlir/lib/Analysis/Liveness.cpp
@@ -0,0 +1,373 @@
+//===- Liveness.cpp - Liveness analysis for MLIR --------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of the liveness analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Liveness.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+/// Builds and holds block information during the construction phase.
+struct BlockInfoBuilder {
+ using ValueSetT = Liveness::ValueSetT;
+
+ /// Constructs an empty block builder.
+ BlockInfoBuilder() : block(nullptr) {}
+
+ /// Fills the block builder with initial liveness information.
+ BlockInfoBuilder(Block *block) : block(block) {
+ // Mark all block arguments (phis) as defined.
+ for (BlockArgument argument : block->getArguments())
+ defValues.insert(argument);
+
+ // Check all result values and whether their uses
+ // are inside this block or not (see outValues).
+ for (Operation &operation : *block)
+ for (Value result : operation.getResults()) {
+ defValues.insert(result);
+
+ // Check whether this value will be in the outValues
+ // set (its uses escape this block). Due to the SSA
+ // properties of the program, the uses must occur after
+ // the definition. Therefore, we do not have to check
+ // additional conditions to detect an escaping value.
+ for (OpOperand &use : result->getUses())
+ if (use.getOwner()->getBlock() != block) {
+ outValues.insert(result);
+ break;
+ }
+ }
+
+ // Check all operations for used operands.
+ for (Operation &operation : block->getOperations())
+ for (Value operand : operation.getOperands()) {
+ // If the operand is already defined in the scope of this
+ // block, we can skip the value in the use set.
+ if (!defValues.count(operand))
+ useValues.insert(operand);
+ }
+ }
+
+ /// Updates live-in information of the current block.
+ /// To do so it uses the default liveness-computation formula:
+ /// newIn = use union out \ def.
+ /// The method returns true if the set has changed (newIn != in),
+ /// false otherwise.
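+ /// Eg: use = {a, b}, out = {b, c}, def = {b} gives newIn = {a, c}.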
+ bool updateLiveIn() {
+ ValueSetT newIn = useValues;
+ llvm::set_union(newIn, outValues);
+ llvm::set_subtract(newIn, defValues);
+
+ // It is sufficient to check the set sizes (instead of their contents)
+ // since the live-in set can only grow monotonically during all update
+ // operations.
+ if (newIn.size() == inValues.size())
+ return false;
+
+ inValues = newIn;
+ return true;
+ }
+
+ /// Updates live-out information of the current block.
+ /// It iterates over all successors and unifies their live-in
+ /// values with the current live-out values.
+ template <typename SourceT> void updateLiveOut(SourceT &source) {
+ for (Block *succ : block->getSuccessors()) {
+ BlockInfoBuilder &builder = source[succ];
+ llvm::set_union(outValues, builder.inValues);
+ }
+ }
+
+ /// The current block.
+ Block *block;
+
+ /// The set of all live in values.
+ ValueSetT inValues;
+
+ /// The set of all live out values.
+ ValueSetT outValues;
+
+ /// The set of all defined values.
+ ValueSetT defValues;
+
+ /// The set of all used values.
+ ValueSetT useValues;
+};
+
+/// Builds the internal liveness block mapping.
+static void buildBlockMapping(MutableArrayRef<Region> regions,
+ DenseMap<Block *, BlockInfoBuilder> &builders) {
+ llvm::SetVector<Block *> toProcess;
+
+ // Initialize all block structures
+ for (Region &region : regions)
+ for (Block &block : region) {
+ BlockInfoBuilder &builder =
+ builders.try_emplace(&block, &block).first->second;
+
+ if (builder.updateLiveIn())
+ toProcess.insert(block.pred_begin(), block.pred_end());
+ }
+
+ // Propagate the in and out-value sets (fixpoint iteration)
+ while (!toProcess.empty()) {
+ Block *current = toProcess.pop_back_val();
+ BlockInfoBuilder &builder = builders[current];
+
+ // Update the current out values.
+ builder.updateLiveOut(builders);
+
+ // Compute (potentially) updated live in values.
+ if (builder.updateLiveIn())
+ toProcess.insert(current->pred_begin(), current->pred_end());
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Liveness
+//===----------------------------------------------------------------------===//
+
+/// Creates a new Liveness analysis that computes liveness
+/// information for all associated regions.
+Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); }
+
+/// Initializes the internal mappings.
+void Liveness::build(MutableArrayRef<Region> regions) {
+
+ // Build internal block mapping.
+ DenseMap<Block *, BlockInfoBuilder> builders;
+ buildBlockMapping(regions, builders);
+
+ // Store internal block data.
+ for (auto &entry : builders) {
+ BlockInfoBuilder &builder = entry.second;
+ LivenessBlockInfo &info = blockMapping[entry.first];
+
+ info.block = builder.block;
+ info.inValues = std::move(builder.inValues);
+ info.outValues = std::move(builder.outValues);
+ }
+}
+
+/// Resolves the list of operations for which the given value is live.
+Liveness::OperationListT Liveness::resolveLiveness(Value value) const {
+ OperationListT result;
+ SmallPtrSet<Block *, 32> visited;
+ SmallVector<Block *, 8> toProcess;
+
+ // Start with the defining block
+ Block *currentBlock;
+ if (Operation *defOp = value->getDefiningOp())
+ currentBlock = defOp->getBlock();
+ else
+ currentBlock = value.cast<BlockArgument>()->getOwner();
+ toProcess.push_back(currentBlock);
+ visited.insert(currentBlock);
+
+ // Start with all associated blocks
+ for (OpOperand &use : value->getUses()) {
+ Block *useBlock = use.getOwner()->getBlock();
+ if (visited.insert(useBlock).second)
+ toProcess.push_back(useBlock);
+ }
+
+ while (!toProcess.empty()) {
+ // Get block and block liveness information.
+ Block *block = toProcess.back();
+ toProcess.pop_back();
+ const LivenessBlockInfo *blockInfo = getLiveness(block);
+
+ // Note that start and end will be in the same block.
+ Operation *start = blockInfo->getStartOperation(value);
+ Operation *end = blockInfo->getEndOperation(value, start);
+
+ result.push_back(start);
+ while (start != end) {
+ start = start->getNextNode();
+ result.push_back(start);
+ }
+
+ for (Block *successor : block->getSuccessors()) {
+ if (getLiveness(successor)->isLiveIn(value) &&
+ visited.insert(successor).second)
+ toProcess.push_back(successor);
+ }
+ }
+
+ return result;
+}
+
+/// Gets liveness info (if any) for the block.
+const LivenessBlockInfo *Liveness::getLiveness(Block *block) const {
+ auto it = blockMapping.find(block);
+ return it == blockMapping.end() ? nullptr : &it->second;
+}
+
+/// Returns a reference to a set containing live-in values.
+const Liveness::ValueSetT &Liveness::getLiveIn(Block *block) const {
+ return getLiveness(block)->in();
+}
+
+/// Returns a reference to a set containing live-out values.
+const Liveness::ValueSetT &Liveness::getLiveOut(Block *block) const {
+ return getLiveness(block)->out();
+}
+
+/// Returns true if the given operation represent the last use of the
+/// given value.
+bool Liveness::isLastUse(Value value, Operation *operation) const {
+ Block *block = operation->getBlock();
+ const LivenessBlockInfo *blockInfo = getLiveness(block);
+
+ // The given value escapes the associated block.
+ if (blockInfo->isLiveOut(value))
+ return false;
+
+ Operation *endOperation = blockInfo->getEndOperation(value, operation);
+ // If the operation is a real user of `value` the first check is sufficient.
+ // If not, we will have to test whether the end operation is executed before
+ // the given operation in the block.
+ return endOperation == operation || endOperation->isBeforeInBlock(operation);
+}
+
+/// Dumps the liveness information in a human readable format.
+void Liveness::dump() const { print(llvm::errs()); }
+
+/// Dumps the liveness information to the given stream.
+void Liveness::print(raw_ostream &os) const {
+ os << "// ---- Liveness -----\n";
+
+ // Builds unique block/value mappings for testing purposes.
+ DenseMap<Block *, size_t> blockIds;
+ DenseMap<Operation *, size_t> operationIds;
+ DenseMap<Value, size_t> valueIds;
+ for (Region &region : operation->getRegions())
+ for (Block &block : region) {
+ blockIds.insert({&block, blockIds.size()});
+ for (BlockArgument argument : block.getArguments())
+ valueIds.insert({argument, valueIds.size()});
+ for (Operation &operation : block) {
+ operationIds.insert({&operation, operationIds.size()});
+ for (Value result : operation.getResults())
+ valueIds.insert({result, valueIds.size()});
+ }
+ }
+
+ // Local printing helpers
+ auto printValueRef = [&](Value value) {
+ if (Operation *defOp = value->getDefiningOp())
+ os << "val_" << defOp->getName();
+ else {
+ auto blockArg = value.cast<BlockArgument>();
+ os << "arg" << blockArg->getArgNumber() << "@"
+ << blockIds[blockArg->getOwner()];
+ }
+ os << " ";
+ };
+
+ auto printValueRefs = [&](const ValueSetT &values) {
+ std::vector<Value> orderedValues(values.begin(), values.end());
+ std::sort(orderedValues.begin(), orderedValues.end(),
+ [&](Value left, Value right) {
+ return valueIds[left] < valueIds[right];
+ });
+ for (Value value : orderedValues)
+ printValueRef(value);
+ };
+
+ // Dump information about in and out values.
+ for (Region &region : operation->getRegions())
+ for (Block &block : region) {
+ os << "// - Block: " << blockIds[&block] << "\n";
+ auto liveness = getLiveness(&block);
+ os << "// --- LiveIn: ";
+ printValueRefs(liveness->inValues);
+ os << "\n// --- LiveOut: ";
+ printValueRefs(liveness->outValues);
+ os << "\n";
+
+ // Print liveness intervals.
+ os << "// --- BeginLiveness";
+ for (Operation &op : block) {
+ if (op.getNumResults() < 1)
+ continue;
+ os << "\n";
+ for (Value result : op.getResults()) {
+ os << "// ";
+ printValueRef(result);
+ os << ":";
+ auto liveOperations = resolveLiveness(result);
+ std::sort(liveOperations.begin(), liveOperations.end(),
+ [&](Operation *left, Operation *right) {
+ return operationIds[left] < operationIds[right];
+ });
+ for (Operation *operation : liveOperations) {
+ os << "\n// ";
+ operation->print(os);
+ }
+ }
+ }
+ os << "\n// --- EndLiveness\n";
+ }
+ os << "// -------------------\n";
+}
+
+//===----------------------------------------------------------------------===//
+// LivenessBlockInfo
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the given value is in the live-in set.
+bool LivenessBlockInfo::isLiveIn(Value value) const {
+ return inValues.count(value);
+}
+
+/// Returns true if the given value is in the live-out set.
+bool LivenessBlockInfo::isLiveOut(Value value) const {
+ return outValues.count(value);
+}
+
+/// Gets the start operation for the given value
+/// (must be referenced in this block).
+Operation *LivenessBlockInfo::getStartOperation(Value value) const {
+ Operation *definingOp = value->getDefiningOp();
+ // The given value is either live-in or is defined
+ // in the scope of this block.
+ if (isLiveIn(value) || !definingOp)
+ return &block->front();
+ return definingOp;
+}
+
+/// Gets the end operation for the given value using the start operation
+/// provided (must be referenced in this block).
+Operation *LivenessBlockInfo::getEndOperation(Value value,
+ Operation *startOperation) const {
+ // The given value is either dying in this block or live-out.
+ if (isLiveOut(value))
+ return &block->back();
+
+ // Resolve the last operation (must exist by definition).
+ Operation *endOperation = startOperation;
+ for (OpOperand &use : value->getUses()) {
+ Operation *useOperation = use.getOwner();
+ // Check whether the use is in our block and after
+ // the current end operation.
+ if (useOperation->getBlock() == block &&
+ endOperation->isBeforeInBlock(useOperation))
+ endOperation = useOperation;
+ }
+ return endOperation;
+}
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
new file mode 100644
index 00000000000..18c86dc63b4
--- /dev/null
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -0,0 +1,388 @@
+//===- LoopAnalysis.cpp - Misc loop analysis routines ---------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous loop analysis routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/LoopAnalysis.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Support/MathExtras.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallString.h"
+#include <type_traits>
+
+using namespace mlir;
+
+/// Computes the trip count of the loop as an affine map, along with its
+/// operands, if the trip count is expressible as an affine expression;
+/// otherwise sets 'tripCountMap' to a null map. The trip count expression is
+/// simplified before returning. This method only utilizes map composition to
+/// construct lower and upper bounds before computing the trip count
+/// expressions.
+// TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a
+// pure analysis method relying on FlatAffineConstraints; the latter will also
+// be more powerful (since both inequalities and equalities will be considered).
+void mlir::buildTripCountMapAndOperands(
+ AffineForOp forOp, AffineMap *tripCountMap,
+ SmallVectorImpl<Value> *tripCountOperands) {
+ int64_t loopSpan;
+
+ int64_t step = forOp.getStep();
+ OpBuilder b(forOp.getOperation());
+
+ if (forOp.hasConstantBounds()) {
+ int64_t lb = forOp.getConstantLowerBound();
+ int64_t ub = forOp.getConstantUpperBound();
+ loopSpan = ub - lb;
+ if (loopSpan < 0)
+ loopSpan = 0;
+ *tripCountMap = b.getConstantAffineMap(ceilDiv(loopSpan, step));
+ tripCountOperands->clear();
+ return;
+ }
+ auto lbMap = forOp.getLowerBoundMap();
+ auto ubMap = forOp.getUpperBoundMap();
+ if (lbMap.getNumResults() != 1) {
+ *tripCountMap = AffineMap();
+ return;
+ }
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
+
+ // Difference of each upper bound expression from the single lower bound
+ // expression (divided by the step) provides the expressions for the trip
+ // count map.
+ AffineValueMap ubValueMap(ubMap, ubOperands);
+
+ SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
+ lbMap.getResult(0));
+ auto lbMapSplat =
+ AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), lbSplatExpr);
+ AffineValueMap lbSplatValueMap(lbMapSplat, lbOperands);
+
+ AffineValueMap tripCountValueMap;
+ AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
+ for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
+ tripCountValueMap.setResult(i,
+ tripCountValueMap.getResult(i).ceilDiv(step));
+
+ *tripCountMap = tripCountValueMap.getAffineMap();
+ tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
+ tripCountValueMap.getOperands().end());
+}
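+// For illustration (a sketch, not exercised in this file): for constant bounds
+// such as `affine.for %i = 0 to 10 step 4`, the resulting trip count map is
+// the constant map `() -> (3)` (ceildiv of the loop span by the step). For
+// symbolic bounds `%lb` to `%ub` with step 4, the map computes roughly
+// `(%ub - %lb) ceildiv 4` over the bound operands.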
+
+/// Returns the trip count of the loop if it's a constant, None otherwise. This
+/// method uses affine expression analysis (in turn using
+/// buildTripCountMapAndOperands) and is able to determine constant trip counts
+/// in non-trivial cases.
+// FIXME(mlir-team): this is really relying on buildTripCountMapAndOperands;
+// being an analysis utility, it shouldn't. Replace with a version that just
+// works with analysis structures (FlatAffineConstraints) and thus doesn't
+// update the IR.
+Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) {
+ SmallVector<Value, 4> operands;
+ AffineMap map;
+ buildTripCountMapAndOperands(forOp, &map, &operands);
+
+ if (!map)
+ return None;
+
+ // Take the min if all trip counts are constant.
+ Optional<uint64_t> tripCount;
+ for (auto resultExpr : map.getResults()) {
+ if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
+ if (tripCount.hasValue())
+ tripCount = std::min(tripCount.getValue(),
+ static_cast<uint64_t>(constExpr.getValue()));
+ else
+ tripCount = constExpr.getValue();
+ } else
+ return None;
+ }
+ return tripCount;
+}
+
+/// Returns the greatest known integral divisor of the trip count. Affine
+/// expression analysis is used (indirectly through
+/// buildTripCountMapAndOperands), and this method is thus able to determine
+/// non-trivial divisors.
+uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) {
+ SmallVector<Value, 4> operands;
+ AffineMap map;
+ buildTripCountMapAndOperands(forOp, &map, &operands);
+
+ if (!map)
+ return 1;
+
+ // The largest divisor of the trip count is the GCD of the individual largest
+ // divisors.
+ assert(map.getNumResults() >= 1 && "expected one or more results");
+ Optional<uint64_t> gcd;
+ for (auto resultExpr : map.getResults()) {
+ uint64_t thisGcd;
+ if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) {
+ uint64_t tripCount = constExpr.getValue();
+ // 0 iteration loops (greatest divisor is 2^64 - 1).
+ if (tripCount == 0)
+ thisGcd = std::numeric_limits<uint64_t>::max();
+ else
+ // The greatest divisor is the trip count.
+ thisGcd = tripCount;
+ } else {
+ // Trip count is not a known constant; return its largest known divisor.
+ thisGcd = resultExpr.getLargestKnownDivisor();
+ }
+ if (gcd.hasValue())
+ gcd = llvm::GreatestCommonDivisor64(gcd.getValue(), thisGcd);
+ else
+ gcd = thisGcd;
+ }
+ assert(gcd.hasValue() && "value expected per above logic");
+ return gcd.getValue();
+}
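+// For example: a constant trip count of 12 gives 12; two upper bound results
+// with trip counts 8 and 12 give gcd(8, 12) = 4; a non-constant trip count
+// expression such as `4 * d0` contributes its largest known divisor, 4.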
+
+/// Given an induction variable `iv` of type AffineForOp and an access `index`
+/// of type index, returns `true` if `index` is independent of `iv` and
+/// false otherwise. The determination supports composition with at most one
+/// AffineApplyOp. The 'at most one AffineApplyOp' comes from the fact that
+/// the composition of AffineApplyOp needs to be canonicalized by construction
+/// to avoid writing code that composes arbitrary numbers of AffineApplyOps
+/// everywhere. To achieve this, at the very least, the compose-affine-apply
+/// pass must have been run.
+///
+/// Prerequisites:
+/// 1. `iv` and `index` of the proper type;
+/// 2. at most one reachable AffineApplyOp from index;
+///
+/// Returns false in cases with more than one AffineApplyOp; this is
+/// conservative.
+static bool isAccessIndexInvariant(Value iv, Value index) {
+  assert(isForInductionVar(iv) &&
+         "iv must be an AffineForOp induction variable");
+ assert(index->getType().isa<IndexType>() && "index must be of IndexType");
+ SmallVector<Operation *, 4> affineApplyOps;
+ getReachableAffineApplyOps({index}, affineApplyOps);
+
+ if (affineApplyOps.empty()) {
+ // Pointer equality test because of Value pointer semantics.
+ return index != iv;
+ }
+
+ if (affineApplyOps.size() > 1) {
+ affineApplyOps[0]->emitRemark(
+ "CompositionAffineMapsPass must have been run: there should be at most "
+ "one AffineApplyOp, returning false conservatively.");
+ return false;
+ }
+
+ auto composeOp = cast<AffineApplyOp>(affineApplyOps[0]);
+ // We need yet another level of indirection because the `dim` index of the
+ // access may not correspond to the `dim` index of composeOp.
+ return !(AffineValueMap(composeOp).isFunctionOf(0, iv));
+}
+
+DenseSet<Value> mlir::getInvariantAccesses(Value iv, ArrayRef<Value> indices) {
+ DenseSet<Value> res;
+ for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
+ auto val = indices[idx];
+ if (isAccessIndexInvariant(iv, val)) {
+ res.insert(val);
+ }
+ }
+ return res;
+}
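+// Illustrative use (a sketch, assuming %j is defined outside the loop over
+// %i): for an access `affine.load %A[%i, %j]` with indices = {%i, %j}, the
+// returned set contains only %j, since %i trivially varies with the induction
+// variable.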
+
+/// Given:
+/// 1. an induction variable `iv` of type AffineForOp;
+/// 2. a `memoryOp` of type AffineLoadOp or AffineStoreOp;
+/// determines whether `memoryOp` has a contiguous access along `iv`. Contiguous
+/// is defined as either invariant or varying only along a unique MemRef dim.
+/// Upon success, the unique MemRef dim is written in `memRefDim` (or -1 to
+/// convey the memRef access is invariant along `iv`).
+///
+/// Prerequisites:
+/// 1. `memRefDim` != nullptr;
+/// 2. `iv` of the proper type;
+/// 3. the MemRef accessed by `memoryOp` has no layout map or at most an
+/// identity layout map.
+///
+/// Currently only supports no layoutMap or identity layoutMap in the MemRef.
+/// Returns false if the MemRef has a non-identity layoutMap or more than 1
+/// layoutMap. This is conservative.
+///
+// TODO(ntv): check strides.
+template <typename LoadOrStoreOp>
+static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
+ int *memRefDim) {
+ static_assert(std::is_same<LoadOrStoreOp, AffineLoadOp>::value ||
+ std::is_same<LoadOrStoreOp, AffineStoreOp>::value,
+                    "Must be called on either AffineLoadOp or AffineStoreOp");
+  assert(memRefDim && "memRefDim must be non-null");
+ auto memRefType = memoryOp.getMemRefType();
+
+ auto layoutMap = memRefType.getAffineMaps();
+ // TODO(ntv): remove dependence on Builder once we support non-identity
+ // layout map.
+ Builder b(memoryOp.getContext());
+ if (layoutMap.size() >= 2 ||
+ (layoutMap.size() == 1 &&
+ !(layoutMap[0] ==
+ b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) {
+ return memoryOp.emitError("NYI: non-trivial layoutMap"), false;
+ }
+
+ int uniqueVaryingIndexAlongIv = -1;
+ auto accessMap = memoryOp.getAffineMap();
+ SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands());
+ unsigned numDims = accessMap.getNumDims();
+ for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
+    // Gather map operands used by result expr 'i' in 'exprOperands'.
+ SmallVector<Value, 4> exprOperands;
+ auto resultExpr = accessMap.getResult(i);
+ resultExpr.walk([&](AffineExpr expr) {
+ if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+ exprOperands.push_back(mapOperands[dimExpr.getPosition()]);
+ else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+ exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
+ });
+ // Check access invariance of each operand in 'exprOperands'.
+ for (auto exprOperand : exprOperands) {
+ if (!isAccessIndexInvariant(iv, exprOperand)) {
+ if (uniqueVaryingIndexAlongIv != -1) {
+ // 2+ varying indices -> do not vectorize along iv.
+ return false;
+ }
+ uniqueVaryingIndexAlongIv = i;
+ }
+ }
+ }
+
+ if (uniqueVaryingIndexAlongIv == -1)
+ *memRefDim = -1;
+ else
+ *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1);
+ return true;
+}
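+// For illustration: an access `%A[%i, %j]` into a rank-2 memref, analyzed
+// along the loop over %j, varies only in its last result expression, so
+// `*memRefDim` is set to rank - (1 + 1) = 0; analyzed along the loop over %i,
+// it is set to 1. An access such as `%A[%j, %j]` varies along %j in two result
+// expressions, so the function returns false.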
+
+template <typename LoadOrStoreOpPointer>
+static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
+ auto memRefType = memoryOp.getMemRefType();
+ return memRefType.getElementType().template isa<VectorType>();
+}
+
+using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
+
+static bool
+isVectorizableLoopBodyWithOpCond(AffineForOp loop,
+ VectorizableOpFun isVectorizableOp,
+ NestedPattern &vectorTransferMatcher) {
+ auto *forOp = loop.getOperation();
+
+ // No vectorization across conditionals for now.
+ auto conditionals = matcher::If();
+ SmallVector<NestedMatch, 8> conditionalsMatched;
+ conditionals.match(forOp, &conditionalsMatched);
+ if (!conditionalsMatched.empty()) {
+ return false;
+ }
+
+ // No vectorization across unknown regions.
+ auto regions = matcher::Op([](Operation &op) -> bool {
+ return op.getNumRegions() != 0 &&
+ !(isa<AffineIfOp>(op) || isa<AffineForOp>(op));
+ });
+ SmallVector<NestedMatch, 8> regionsMatched;
+ regions.match(forOp, &regionsMatched);
+ if (!regionsMatched.empty()) {
+ return false;
+ }
+
+ SmallVector<NestedMatch, 8> vectorTransfersMatched;
+ vectorTransferMatcher.match(forOp, &vectorTransfersMatched);
+ if (!vectorTransfersMatched.empty()) {
+ return false;
+ }
+
+ auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
+ SmallVector<NestedMatch, 8> loadAndStoresMatched;
+ loadAndStores.match(forOp, &loadAndStoresMatched);
+ for (auto ls : loadAndStoresMatched) {
+ auto *op = ls.getMatchedOperation();
+ auto load = dyn_cast<AffineLoadOp>(op);
+ auto store = dyn_cast<AffineStoreOp>(op);
+    // Only scalar types are considered vectorizable; all loads/stores must be
+    // vectorizable for a loop to qualify as vectorizable.
+ // TODO(ntv): ponder whether we want to be more general here.
+ bool vector = load ? isVectorElement(load) : isVectorElement(store);
+ if (vector) {
+ return false;
+ }
+ if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim,
+ NestedPattern &vectorTransferMatcher) {
+ VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
+ auto load = dyn_cast<AffineLoadOp>(op);
+ auto store = dyn_cast<AffineStoreOp>(op);
+ return load ? isContiguousAccess(loop.getInductionVar(), load, memRefDim)
+ : isContiguousAccess(loop.getInductionVar(), store, memRefDim);
+ });
+ return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher);
+}
+
+bool mlir::isVectorizableLoopBody(AffineForOp loop,
+ NestedPattern &vectorTransferMatcher) {
+ return isVectorizableLoopBodyWithOpCond(loop, nullptr, vectorTransferMatcher);
+}
+
+/// Checks whether SSA dominance would be violated if a for op's body
+/// operations are shifted by the specified shifts. This method checks if a
+/// 'def' and all its uses have the same shift factor.
+// TODO(mlir-team): extend this to check for memory-based dependence violation
+// when we have the support.
+bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts) {
+ auto *forBody = forOp.getBody();
+ assert(shifts.size() == forBody->getOperations().size());
+
+ // Work backwards over the body of the block so that the shift of a use's
+ // ancestor operation in the block gets recorded before it's looked up.
+ DenseMap<Operation *, uint64_t> forBodyShift;
+ for (auto it : llvm::enumerate(llvm::reverse(forBody->getOperations()))) {
+ auto &op = it.value();
+
+ // Get the index of the current operation, note that we are iterating in
+ // reverse so we need to fix it up.
+ size_t index = shifts.size() - it.index() - 1;
+
+ // Remember the shift of this operation.
+ uint64_t shift = shifts[index];
+ forBodyShift.try_emplace(&op, shift);
+
+ // Validate the results of this operation if it were to be shifted.
+ for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
+ Value result = op.getResult(i);
+ for (auto *user : result->getUsers()) {
+ // If an ancestor operation doesn't lie in the block of forOp,
+ // there is no shift to check.
+ if (auto *ancOp = forBody->findAncestorOpInBlock(*user)) {
+ assert(forBodyShift.count(ancOp) > 0 && "ancestor expected in map");
+ if (shift != forBodyShift[ancOp])
+ return false;
+ }
+ }
+ }
+ }
+ return true;
+}
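+// For example, with a body consisting of `%v = affine.load ...` followed by
+// `affine.store %v, ...`, the shifts [1, 1] keep the def and its use at the
+// same shift and are valid, while [0, 1] separate them and make this function
+// return false.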
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
new file mode 100644
index 00000000000..1f7c1a1ae31
--- /dev/null
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -0,0 +1,53 @@
+//===- MemRefBoundCheck.cpp - MemRef out-of-bounds access checking --------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to check memref accesses for out-of-bounds
+// subscripts.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ADT/TypeSwitch.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "memref-bound-check"
+
+using namespace mlir;
+
+namespace {
+
+/// Checks for out-of-bounds memref access subscripts.
+struct MemRefBoundCheck : public FunctionPass<MemRefBoundCheck> {
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefBoundCheckPass() {
+ return std::make_unique<MemRefBoundCheck>();
+}
+
+void MemRefBoundCheck::runOnFunction() {
+ getFunction().walk([](Operation *opInst) {
+ TypeSwitch<Operation *>(opInst).Case<AffineLoadOp, AffineStoreOp>(
+ [](auto op) { boundCheckLoadOrStoreOp(op); });
+
+ // TODO(bondhugula): do this for DMA ops as well.
+ });
+}
+
+static PassRegistration<MemRefBoundCheck>
+ memRefBoundCheck("memref-bound-check",
+ "Check memref access bounds in a Function");
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
new file mode 100644
index 00000000000..97eaafd37ce
--- /dev/null
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -0,0 +1,152 @@
+//===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/NestedMatcher.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+llvm::BumpPtrAllocator *&NestedMatch::allocator() {
+ thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ return allocator;
+}
+
+NestedMatch NestedMatch::build(Operation *operation,
+ ArrayRef<NestedMatch> nestedMatches) {
+ auto *result = allocator()->Allocate<NestedMatch>();
+ auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
+ std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
+ new (result) NestedMatch();
+ result->matchedOperation = operation;
+ result->matchedChildren =
+ ArrayRef<NestedMatch>(children, nestedMatches.size());
+ return *result;
+}
+
+llvm::BumpPtrAllocator *&NestedPattern::allocator() {
+ thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ return allocator;
+}
+
+NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
+ FilterFunctionType filter)
+ : nestedPatterns(), filter(filter), skip(nullptr) {
+ if (!nested.empty()) {
+ auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
+ std::uninitialized_copy(nested.begin(), nested.end(), newNested);
+ nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
+ }
+}
+
+unsigned NestedPattern::getDepth() const {
+ if (nestedPatterns.empty()) {
+ return 1;
+ }
+ unsigned depth = 0;
+ for (auto &c : nestedPatterns) {
+ depth = std::max(depth, c.getDepth());
+ }
+ return depth + 1;
+}
+
+/// Matches a single operation in the following way:
+/// 1. checks the kind of operation against the matcher, if different then
+/// there is no match;
+/// 2. calls the customizable filter function to refine the single operation
+/// match with extra semantic constraints;
+/// 3. if all is good, recursively matches the nested patterns;
+///   4. if all nested patterns match, then the single operation matches too
+///      and is appended to the list of matches;
+/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
+/// want to traverse in post-order DFS to avoid invalidating iterators.
+void NestedPattern::matchOne(Operation *op,
+ SmallVectorImpl<NestedMatch> *matches) {
+ if (skip == op) {
+ return;
+ }
+ // Local custom filter function
+ if (!filter(*op)) {
+ return;
+ }
+
+ if (nestedPatterns.empty()) {
+ SmallVector<NestedMatch, 8> nestedMatches;
+ matches->push_back(NestedMatch::build(op, nestedMatches));
+ return;
+ }
+ // Take a copy of each nested pattern so we can match it.
+ for (auto nestedPattern : nestedPatterns) {
+ SmallVector<NestedMatch, 8> nestedMatches;
+ // Skip elem in the walk immediately following. Without this we would
+ // essentially need to reimplement walk here.
+ nestedPattern.skip = op;
+ nestedPattern.match(op, &nestedMatches);
+    // If we could not match even one of the specified nestedPatterns, exit
+    // early, as this whole branch is not a match.
+ if (nestedMatches.empty()) {
+ return;
+ }
+ matches->push_back(NestedMatch::build(op, nestedMatches));
+ }
+}
+
+static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
+
+static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
+
+namespace mlir {
+namespace matcher {
+
+NestedPattern Op(FilterFunctionType filter) {
+ return NestedPattern({}, filter);
+}
+
+NestedPattern If(NestedPattern child) {
+ return NestedPattern(child, isAffineIfOp);
+}
+NestedPattern If(FilterFunctionType filter, NestedPattern child) {
+ return NestedPattern(child, [filter](Operation &op) {
+ return isAffineIfOp(op) && filter(op);
+ });
+}
+NestedPattern If(ArrayRef<NestedPattern> nested) {
+ return NestedPattern(nested, isAffineIfOp);
+}
+NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
+ return NestedPattern(nested, [filter](Operation &op) {
+ return isAffineIfOp(op) && filter(op);
+ });
+}
+
+NestedPattern For(NestedPattern child) {
+ return NestedPattern(child, isAffineForOp);
+}
+NestedPattern For(FilterFunctionType filter, NestedPattern child) {
+ return NestedPattern(
+ child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
+}
+NestedPattern For(ArrayRef<NestedPattern> nested) {
+ return NestedPattern(nested, isAffineForOp);
+}
+NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) {
+ return NestedPattern(
+ nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
+}
+
+bool isLoadOrStore(Operation &op) {
+ return isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op);
+}
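+// A rough usage sketch: a pattern matching a two-deep 'affine.for' nest whose
+// innermost body contains an affine load or store can be built and run as
+//   auto pattern = matcher::For(matcher::For(matcher::Op(matcher::isLoadOrStore)));
+//   SmallVector<NestedMatch, 8> matches;
+//   pattern.match(funcOp.getOperation(), &matches);
+// where `funcOp` stands for any enclosing FuncOp of interest.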
+
+} // end namespace matcher
+} // end namespace mlir
diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp
new file mode 100644
index 00000000000..dbd938710ef
--- /dev/null
+++ b/mlir/lib/Analysis/OpStats.cpp
@@ -0,0 +1,84 @@
+//===- OpStats.cpp - Prints stats of operations in module -----------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
+ explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
+
+  // Prints the resulting operation statistics after iterating over the module.
+ void runOnModule() override;
+
+ // Print summary of op stats.
+ void printSummary();
+
+private:
+ llvm::StringMap<int64_t> opCount;
+ raw_ostream &os;
+};
+} // namespace
+
+void PrintOpStatsPass::runOnModule() {
+ opCount.clear();
+
+ // Compute the operation statistics for each function in the module.
+ for (auto &op : getModule())
+ op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
+ printSummary();
+}
+
+void PrintOpStatsPass::printSummary() {
+ os << "Operations encountered:\n";
+ os << "-----------------------\n";
+ SmallVector<StringRef, 64> sorted(opCount.keys());
+ llvm::sort(sorted);
+
+ // Split an operation name from its dialect prefix.
+ auto splitOperationName = [](StringRef opName) {
+ auto splitName = opName.split('.');
+ return splitName.second.empty() ? std::make_pair("", splitName.first)
+ : splitName;
+ };
+
+ // Compute the largest dialect and operation name.
+ StringRef dialectName, opName;
+ size_t maxLenOpName = 0, maxLenDialect = 0;
+ for (const auto &key : sorted) {
+ std::tie(dialectName, opName) = splitOperationName(key);
+ maxLenDialect = std::max(maxLenDialect, dialectName.size());
+ maxLenOpName = std::max(maxLenOpName, opName.size());
+ }
+
+ for (const auto &key : sorted) {
+ std::tie(dialectName, opName) = splitOperationName(key);
+
+ // Left-align the names (aligning on the dialect) and right-align the count
+ // below. The alignment is for readability and does not affect CSV/FileCheck
+ // parsing.
+ if (dialectName.empty())
+ os.indent(maxLenDialect + 3);
+ else
+ os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.';
+
+ // Left justify the operation name.
+ os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key]
+ << '\n';
+ }
+}
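+// Example output (illustrative, alignment approximate):
+//   Operations encountered:
+//   -----------------------
+//     affine.for  , 4
+//        std.addf  , 3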
+
+static PassRegistration<PrintOpStatsPass>
+ pass("print-op-stats", "Print statistics of operations");
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
new file mode 100644
index 00000000000..89ee613b370
--- /dev/null
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -0,0 +1,213 @@
+//===- SliceAnalysis.cpp - Analysis for Transitive UseDef chains ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements Analysis functions specific to slicing in Function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+
+///
+/// Implements Analysis functions specific to slicing in Function.
+///
+
+using namespace mlir;
+
+using llvm::SetVector;
+
+static void getForwardSliceImpl(Operation *op,
+ SetVector<Operation *> *forwardSlice,
+ TransitiveFilter filter) {
+ if (!op) {
+ return;
+ }
+
+ // Evaluate whether we should keep this use.
+ // This is useful in particular to implement scoping; i.e. return the
+ // transitive forwardSlice in the current scope.
+ if (!filter(op)) {
+ return;
+ }
+
+ if (auto forOp = dyn_cast<AffineForOp>(op)) {
+ for (auto *ownerInst : forOp.getInductionVar()->getUsers())
+ if (forwardSlice->count(ownerInst) == 0)
+ getForwardSliceImpl(ownerInst, forwardSlice, filter);
+ } else if (auto forOp = dyn_cast<loop::ForOp>(op)) {
+ for (auto *ownerInst : forOp.getInductionVar()->getUsers())
+ if (forwardSlice->count(ownerInst) == 0)
+ getForwardSliceImpl(ownerInst, forwardSlice, filter);
+ } else {
+ assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
+ assert(op->getNumResults() <= 1 && "unexpected multiple results");
+ if (op->getNumResults() > 0) {
+ for (auto *ownerInst : op->getResult(0)->getUsers())
+ if (forwardSlice->count(ownerInst) == 0)
+ getForwardSliceImpl(ownerInst, forwardSlice, filter);
+ }
+ }
+
+ forwardSlice->insert(op);
+}
+
+void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
+ TransitiveFilter filter) {
+ getForwardSliceImpl(op, forwardSlice, filter);
+ // Don't insert the top level operation, we just queried on it and don't
+ // want it in the results.
+ forwardSlice->remove(op);
+
+ // Reverse to get back the actual topological order.
+ // std::reverse does not work out of the box on SetVector and I want an
+ // in-place swap based thing (the real std::reverse, not the LLVM adapter).
+ std::vector<Operation *> v(forwardSlice->takeVector());
+ forwardSlice->insert(v.rbegin(), v.rend());
+}
+
+static void getBackwardSliceImpl(Operation *op,
+ SetVector<Operation *> *backwardSlice,
+ TransitiveFilter filter) {
+ if (!op)
+ return;
+
+ assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) ||
+ isa<loop::ForOp>(op)) &&
+ "unexpected generic op with regions");
+
+ // Evaluate whether we should keep this def.
+ // This is useful in particular to implement scoping; i.e. return the
+  // transitive backwardSlice in the current scope.
+ if (!filter(op)) {
+ return;
+ }
+
+ for (auto en : llvm::enumerate(op->getOperands())) {
+ auto operand = en.value();
+ if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
+ if (auto affIv = getForInductionVarOwner(operand)) {
+ auto *affOp = affIv.getOperation();
+ if (backwardSlice->count(affOp) == 0)
+ getBackwardSliceImpl(affOp, backwardSlice, filter);
+ } else if (auto loopIv = loop::getForInductionVarOwner(operand)) {
+ auto *loopOp = loopIv.getOperation();
+ if (backwardSlice->count(loopOp) == 0)
+ getBackwardSliceImpl(loopOp, backwardSlice, filter);
+ } else if (blockArg->getOwner() !=
+ &op->getParentOfType<FuncOp>().getBody().front()) {
+ op->emitError("unsupported CF for operand ") << en.index();
+ llvm_unreachable("Unsupported control flow");
+ }
+ continue;
+ }
+ auto *op = operand->getDefiningOp();
+ if (backwardSlice->count(op) == 0) {
+ getBackwardSliceImpl(op, backwardSlice, filter);
+ }
+ }
+
+ backwardSlice->insert(op);
+}
+
+void mlir::getBackwardSlice(Operation *op,
+ SetVector<Operation *> *backwardSlice,
+ TransitiveFilter filter) {
+ getBackwardSliceImpl(op, backwardSlice, filter);
+
+ // Don't insert the top level operation, we just queried on it and don't
+ // want it in the results.
+ backwardSlice->remove(op);
+}
+
+SetVector<Operation *> mlir::getSlice(Operation *op,
+ TransitiveFilter backwardFilter,
+ TransitiveFilter forwardFilter) {
+ SetVector<Operation *> slice;
+ slice.insert(op);
+
+ unsigned currentIndex = 0;
+ SetVector<Operation *> backwardSlice;
+ SetVector<Operation *> forwardSlice;
+ while (currentIndex != slice.size()) {
+ auto *currentInst = (slice)[currentIndex];
+ // Compute and insert the backwardSlice starting from currentInst.
+ backwardSlice.clear();
+ getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
+ slice.insert(backwardSlice.begin(), backwardSlice.end());
+
+ // Compute and insert the forwardSlice starting from currentInst.
+ forwardSlice.clear();
+ getForwardSlice(currentInst, &forwardSlice, forwardFilter);
+ slice.insert(forwardSlice.begin(), forwardSlice.end());
+ ++currentIndex;
+ }
+ return topologicalSort(slice);
+}
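+// As a small illustrative example, for a straight-line chain
+//   %0 = "op0"() : () -> f32
+//   %1 = "op1"(%0) : (f32) -> f32
+//   %2 = "op2"(%1) : (f32) -> f32
+// the backward slice of op2 is {op0, op1}, the forward slice of op0 is
+// {op1, op2} (the queried op itself is removed in both cases), and getSlice
+// on op1 with filters that accept every operation returns all three
+// operations in topological order.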
+
+namespace {
+/// DFS post-order implementation that maintains a global count to work across
+/// multiple invocations, to help implement topological sort on multi-root DAGs.
+/// We traverse all operations but only record the ones that appear in
+/// `toSort` for the final result.
+struct DFSState {
+ DFSState(const SetVector<Operation *> &set)
+ : toSort(set), topologicalCounts(), seen() {}
+ const SetVector<Operation *> &toSort;
+ SmallVector<Operation *, 16> topologicalCounts;
+ DenseSet<Operation *> seen;
+};
+} // namespace
+
+static void DFSPostorder(Operation *current, DFSState *state) {
+ assert(current->getNumResults() <= 1 && "NYI: multi-result");
+ if (current->getNumResults() > 0) {
+ for (auto &u : current->getResult(0)->getUses()) {
+ auto *op = u.getOwner();
+ DFSPostorder(op, state);
+ }
+ }
+ bool inserted;
+ using IterTy = decltype(state->seen.begin());
+ IterTy iter;
+ std::tie(iter, inserted) = state->seen.insert(current);
+ if (inserted) {
+ if (state->toSort.count(current) > 0) {
+ state->topologicalCounts.push_back(current);
+ }
+ }
+}
+
+SetVector<Operation *>
+mlir::topologicalSort(const SetVector<Operation *> &toSort) {
+ if (toSort.empty()) {
+ return toSort;
+ }
+
+ // Run from each root with global count and `seen` set.
+ DFSState state(toSort);
+ for (auto *s : toSort) {
+ assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
+ DFSPostorder(s, &state);
+ }
+
+ // Reorder and return.
+ SetVector<Operation *> res;
+ for (auto it = state.topologicalCounts.rbegin(),
+ eit = state.topologicalCounts.rend();
+ it != eit; ++it) {
+ res.insert(*it);
+ }
+ return res;
+}
diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp
new file mode 100644
index 00000000000..c6d7519740e
--- /dev/null
+++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp
@@ -0,0 +1,121 @@
+//===- TestMemRefDependenceCheck.cpp - Test dep analysis ------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to run pair-wise memref access dependence checks.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "test-memref-dependence-check"
+
+using namespace mlir;
+
+namespace {
+
+// TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
+/// Checks dependences between all pairs of memref accesses in a Function.
+struct TestMemRefDependenceCheck
+ : public FunctionPass<TestMemRefDependenceCheck> {
+ SmallVector<Operation *, 4> loadsAndStores;
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OpPassBase<FuncOp>>
+mlir::createTestMemRefDependenceCheckPass() {
+ return std::make_unique<TestMemRefDependenceCheck>();
+}
+
+// Returns a result string which represents the direction vector (if there was
+// a dependence); returns the string "false" otherwise.
+static std::string
+getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
+ ArrayRef<DependenceComponent> dependenceComponents) {
+ if (!ret)
+ return "false";
+ if (dependenceComponents.empty() || loopNestDepth > numCommonLoops)
+ return "true";
+ std::string result;
+ for (unsigned i = 0, e = dependenceComponents.size(); i < e; ++i) {
+ std::string lbStr = "-inf";
+ if (dependenceComponents[i].lb.hasValue() &&
+ dependenceComponents[i].lb.getValue() !=
+ std::numeric_limits<int64_t>::min())
+ lbStr = std::to_string(dependenceComponents[i].lb.getValue());
+
+ std::string ubStr = "+inf";
+ if (dependenceComponents[i].ub.hasValue() &&
+ dependenceComponents[i].ub.getValue() !=
+ std::numeric_limits<int64_t>::max())
+ ubStr = std::to_string(dependenceComponents[i].ub.getValue());
+
+ result += "[" + lbStr + ", " + ubStr + "]";
+ }
+ return result;
+}
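+// For instance, with two dependence components whose (lb, ub) ranges are
+// (2, 2) and (0, +inf), the returned string is "[2, 2][0, +inf]"; when there
+// is a dependence but 'loopNestDepth' exceeds 'numCommonLoops', "true" is
+// returned instead.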
+
+// For each access in 'loadsAndStores', runs a dependence check between this
+// "source" access and all subsequent "destination" accesses in
+// 'loadsAndStores'. Emits the result of the dependence check as a note with
+// the source access.
+static void checkDependences(ArrayRef<Operation *> loadsAndStores) {
+ for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
+ auto *srcOpInst = loadsAndStores[i];
+ MemRefAccess srcAccess(srcOpInst);
+ for (unsigned j = 0; j < e; ++j) {
+ auto *dstOpInst = loadsAndStores[j];
+ MemRefAccess dstAccess(dstOpInst);
+
+ unsigned numCommonLoops =
+ getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
+ for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
+ FlatAffineConstraints dependenceConstraints;
+ SmallVector<DependenceComponent, 2> dependenceComponents;
+ DependenceResult result = checkMemrefAccessDependence(
+ srcAccess, dstAccess, d, &dependenceConstraints,
+ &dependenceComponents);
+ assert(result.value != DependenceResult::Failure);
+ bool ret = hasDependence(result);
+ // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print
+ // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
+ // vectors from ([1, 1], [3, 3]) to (1, 3).
+ srcOpInst->emitRemark("dependence from ")
+ << i << " to " << j << " at depth " << d << " = "
+ << getDirectionVectorStr(ret, numCommonLoops, d,
+ dependenceComponents);
+ }
+ }
+ }
+}
+
+// Walks the Function 'f' adding load and store ops to 'loadsAndStores'.
+// Runs pair-wise dependence checks.
+void TestMemRefDependenceCheck::runOnFunction() {
+ // Collect the loads and stores within the function.
+ loadsAndStores.clear();
+ getFunction().walk([&](Operation *op) {
+ if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+ loadsAndStores.push_back(op);
+ });
+
+ checkDependences(loadsAndStores);
+}
+
+static PassRegistration<TestMemRefDependenceCheck>
+ pass("test-memref-dependence-check",
+ "Checks dependences between all pairs of memref accesses.");
diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp
new file mode 100644
index 00000000000..6cfc5431df3
--- /dev/null
+++ b/mlir/lib/Analysis/TestParallelismDetection.cpp
@@ -0,0 +1,48 @@
+//===- TestParallelismDetection.cpp - Parallelism Detection pass ----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to detect parallel 'affine.for' ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestParallelismDetection
+ : public FunctionPass<TestParallelismDetection> {
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createParallelismDetectionTestPass() {
+ return std::make_unique<TestParallelismDetection>();
+}
+
+// Walks the function and emits a note for all 'affine.for' ops detected as
+// parallel.
+void TestParallelismDetection::runOnFunction() {
+ FuncOp f = getFunction();
+ OpBuilder b(f.getBody());
+ f.walk([&](AffineForOp forOp) {
+ if (isLoopParallel(forOp))
+ forOp.emitRemark("parallel loop");
+ else
+ forOp.emitRemark("sequential loop");
+ });
+}
+
+static PassRegistration<TestParallelismDetection>
+ pass("test-detect-parallel", "Test parallelism detection ");
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
new file mode 100644
index 00000000000..8ddf2e274eb
--- /dev/null
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -0,0 +1,1007 @@
+//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous analysis routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Utils.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "analysis-utils"
+
+using namespace mlir;
+
+using llvm::SmallDenseMap;
+
+/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from
+/// the outermost 'affine.for' operation to the innermost one.
+void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
+ auto *currOp = op.getParentOp();
+ AffineForOp currAffineForOp;
+  // Traverse up the hierarchy collecting all 'affine.for' operations while
+ // skipping over 'affine.if' operations.
+ while (currOp && ((currAffineForOp = dyn_cast<AffineForOp>(currOp)) ||
+ isa<AffineIfOp>(currOp))) {
+ if (currAffineForOp)
+ loops->push_back(currAffineForOp);
+ currOp = currOp->getParentOp();
+ }
+ std::reverse(loops->begin(), loops->end());
+}
+
+// Populates 'cst' with FlatAffineConstraints which represent slice bounds.
+LogicalResult
+ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
+ assert(!lbOperands.empty());
+ // Adds src 'ivs' as dimension identifiers in 'cst'.
+ unsigned numDims = ivs.size();
+ // Adds operands (dst ivs and symbols) as symbols in 'cst'.
+ unsigned numSymbols = lbOperands[0].size();
+
+ SmallVector<Value, 4> values(ivs);
+  // Append 'operands' to 'values', which already holds 'ivs'.
+ values.append(lbOperands[0].begin(), lbOperands[0].end());
+ cst->reset(numDims, numSymbols, 0, values);
+
+ // Add loop bound constraints for values which are loop IVs and equality
+ // constraints for symbols which are constants.
+ for (const auto &value : values) {
+ assert(cst->containsId(*value) && "value expected to be present");
+ if (isValidSymbol(value)) {
+ // Check if the symbol is a constant.
+
+ if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
+ cst->setIdToConstant(*value, cOp.getValue());
+ } else if (auto loop = getForInductionVarOwner(value)) {
+ if (failed(cst->addAffineForOpDomain(loop)))
+ return failure();
+ }
+ }
+
+  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
+ LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
+ assert(succeeded(ret) &&
+ "should not fail as we never have semi-affine slice maps");
+ (void)ret;
+ return success();
+}
+
+// Clears state bounds and operand state.
+void ComputationSliceState::clearBounds() {
+ lbs.clear();
+ ubs.clear();
+ lbOperands.clear();
+ ubOperands.clear();
+}
+
+unsigned MemRefRegion::getRank() const {
+ return memref->getType().cast<MemRefType>().getRank();
+}
+
+Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
+ SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
+ SmallVectorImpl<int64_t> *lbDivisors) const {
+ auto memRefType = memref->getType().cast<MemRefType>();
+ unsigned rank = memRefType.getRank();
+ if (shape)
+ shape->reserve(rank);
+
+ assert(rank == cst.getNumDimIds() && "inconsistent memref region");
+
+ // Find a constant upper bound on the extent of this memref region along each
+ // dimension.
+ int64_t numElements = 1;
+ int64_t diffConstant;
+ int64_t lbDivisor;
+ for (unsigned d = 0; d < rank; d++) {
+ SmallVector<int64_t, 4> lb;
+ Optional<int64_t> diff = cst.getConstantBoundOnDimSize(d, &lb, &lbDivisor);
+ if (diff.hasValue()) {
+ diffConstant = diff.getValue();
+ assert(lbDivisor > 0);
+ } else {
+ // If no constant bound is found, then it can always be bound by the
+ // memref's dim size if the latter has a constant size along this dim.
+ auto dimSize = memRefType.getDimSize(d);
+ if (dimSize == -1)
+ return None;
+ diffConstant = dimSize;
+ // Lower bound becomes 0.
+ lb.resize(cst.getNumSymbolIds() + 1, 0);
+ lbDivisor = 1;
+ }
+ numElements *= diffConstant;
+ if (lbs) {
+ lbs->push_back(lb);
+ assert(lbDivisors && "both lbs and lbDivisor or none");
+ lbDivisors->push_back(lbDivisor);
+ }
+ if (shape) {
+ shape->push_back(diffConstant);
+ }
+ }
+ return numElements;
+}
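+// For example, for a rank-1 region constrained by {%i <= d0 <= %i + 7} (as in
+// the MemRefRegion::compute example below), the bounding shape is [8] and the
+// returned element count is 8, with %i recorded as the symbolic lower bound.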
+
+LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
+ assert(memref == other.memref);
+ return cst.unionBoundingBox(*other.getConstraints());
+}
+
+/// Computes the memory region accessed by this memref with the region
+/// represented as constraints symbolic/parametric in 'loopDepth' loops
+/// surrounding opInst and any additional Function symbols.
+// For example, the memref region for this load operation at loopDepth = 1 will
+// be as below:
+//
+// affine.for %i = 0 to 32 {
+// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
+// load %A[%ii]
+// }
+// }
+//
+// region: {memref = %A, write = false, {%i <= m0 <= %i + 7} }
+// The last field is a 2-d FlatAffineConstraints symbolic in %i.
+//
+// TODO(bondhugula): extend this to any other memref dereferencing ops
+// (dma_start, dma_wait).
+LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
+ ComputationSliceState *sliceState,
+ bool addMemRefDimBounds) {
+ assert((isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) &&
+ "affine load/store op expected");
+
+ MemRefAccess access(op);
+ memref = access.memref;
+ write = access.isStore();
+
+ unsigned rank = access.getRank();
+
+ LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
+                          << " depth: " << loopDepth << "\n";);
+
+ if (rank == 0) {
+ SmallVector<AffineForOp, 4> ivs;
+ getLoopIVs(*op, &ivs);
+ SmallVector<Value, 8> regionSymbols;
+ extractForInductionVars(ivs, &regionSymbols);
+ // A rank 0 memref has a 0-d region.
+ cst.reset(rank, loopDepth, 0, regionSymbols);
+ return success();
+ }
+
+ // Build the constraints for this region.
+ AffineValueMap accessValueMap;
+ access.getAccessMap(&accessValueMap);
+ AffineMap accessMap = accessValueMap.getAffineMap();
+
+ unsigned numDims = accessMap.getNumDims();
+ unsigned numSymbols = accessMap.getNumSymbols();
+ unsigned numOperands = accessValueMap.getNumOperands();
+ // Merge operands with slice operands.
+ SmallVector<Value, 4> operands;
+ operands.resize(numOperands);
+ for (unsigned i = 0; i < numOperands; ++i)
+ operands[i] = accessValueMap.getOperand(i);
+
+ if (sliceState != nullptr) {
+ operands.reserve(operands.size() + sliceState->lbOperands[0].size());
+ // Append slice operands to 'operands' as symbols.
+ for (auto extraOperand : sliceState->lbOperands[0]) {
+ if (!llvm::is_contained(operands, extraOperand)) {
+ operands.push_back(extraOperand);
+ numSymbols++;
+ }
+ }
+ }
+ // We'll first associate the dims and symbols of the access map to the dims
+  // and symbols, respectively, of cst. This will change below once cst is
+  // fully constructed.
+ cst.reset(numDims, numSymbols, 0, operands);
+
+  // Add inequalities for loop lower/upper bounds, and equality constraints for
+  // symbols that are constants.
+ for (unsigned i = 0; i < numDims + numSymbols; ++i) {
+ auto operand = operands[i];
+ if (auto loop = getForInductionVarOwner(operand)) {
+ // Note that cst can now have more dimensions than accessMap if the
+ // bounds expressions involve outer loops or other symbols.
+ // TODO(bondhugula): rewrite this to use getInstIndexSet; this way
+ // conditionals will be handled when the latter supports it.
+ if (failed(cst.addAffineForOpDomain(loop)))
+ return failure();
+ } else {
+ // Has to be a valid symbol.
+ auto symbol = operand;
+ assert(isValidSymbol(symbol));
+ // Check if the symbol is a constant.
+ if (auto *op = symbol->getDefiningOp()) {
+ if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
+ cst.setIdToConstant(*symbol, constOp.getValue());
+ }
+ }
+ }
+ }
+
+ // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
+ if (sliceState != nullptr) {
+ // Add dim and symbol slice operands.
+ for (auto operand : sliceState->lbOperands[0]) {
+ cst.addInductionVarOrTerminalSymbol(operand);
+ }
+ // Add upper/lower bounds from 'sliceState' to 'cst'.
+ LogicalResult ret =
+ cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
+ sliceState->lbOperands[0]);
+ assert(succeeded(ret) &&
+ "should not fail as we never have semi-affine slice maps");
+ (void)ret;
+ }
+
+ // Add access function equalities to connect loop IVs to data dimensions.
+ if (failed(cst.composeMap(&accessValueMap))) {
+ op->emitError("getMemRefRegion: compose affine map failed");
+ LLVM_DEBUG(accessValueMap.getAffineMap().dump());
+ return failure();
+ }
+
+ // Set all identifiers appearing after the first 'rank' identifiers as
+ // symbolic identifiers - so that the ones corresponding to the memref
+ // dimensions are the dimensional identifiers for the memref region.
+ cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank);
+
+ // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
+ // this memref region is symbolic.
+ SmallVector<AffineForOp, 4> enclosingIVs;
+ getLoopIVs(*op, &enclosingIVs);
+ assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
+ enclosingIVs.resize(loopDepth);
+ SmallVector<Value, 4> ids;
+ cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids);
+ for (auto id : ids) {
+ AffineForOp iv;
+ if ((iv = getForInductionVarOwner(id)) &&
+        !llvm::is_contained(enclosingIVs, iv)) {
+ cst.projectOut(id);
+ }
+ }
+
+ // Project out any local variables (these would have been added for any
+ // mod/divs).
+ cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds());
+
+ // Constant fold any symbolic identifiers.
+ cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(),
+ /*num=*/cst.getNumSymbolIds());
+
+ assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format");
+
+ // Add upper/lower bounds for each memref dimension with static size
+ // to guard against potential over-approximation from projection.
+ // TODO(andydavis) Support dynamic memref dimensions.
+ if (addMemRefDimBounds) {
+ auto memRefType = memref->getType().cast<MemRefType>();
+ for (unsigned r = 0; r < rank; r++) {
+ cst.addConstantLowerBound(r, 0);
+ int64_t dimSize = memRefType.getDimSize(r);
+ if (ShapedType::isDynamic(dimSize))
+ continue;
+ cst.addConstantUpperBound(r, dimSize - 1);
+ }
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
+ LLVM_DEBUG(cst.dump());
+ return success();
+}
+
+// TODO(mlir-team): improve/complete this when we have target data.
+static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+ auto elementType = memRefType.getElementType();
+
+ unsigned sizeInBits;
+ if (elementType.isIntOrFloat()) {
+ sizeInBits = elementType.getIntOrFloatBitWidth();
+ } else {
+ auto vectorType = elementType.cast<VectorType>();
+ sizeInBits =
+ vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+ }
+ return llvm::divideCeil(sizeInBits, 8);
+}
+
+// Returns the size of this region in bytes.
+Optional<int64_t> MemRefRegion::getRegionSize() {
+ auto memRefType = memref->getType().cast<MemRefType>();
+
+ auto layoutMaps = memRefType.getAffineMaps();
+ if (layoutMaps.size() > 1 ||
+ (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+ LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+    return None;
+ }
+
+ // Indices to use for the DmaStart op.
+ // Indices for the original memref being DMAed from/to.
+ SmallVector<Value, 4> memIndices;
+ // Indices for the faster buffer being DMAed into/from.
+ SmallVector<Value, 4> bufIndices;
+
+ // Compute the extents of the buffer.
+ Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
+ if (!numElements.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
+ return None;
+ }
+ return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
+}
+
+/// Returns the size of memref data in bytes if it's statically shaped, None
+/// otherwise. If the element of the memref has vector type, takes into account
+/// size of the vector as well.
+// TODO(mlir-team): improve/complete this when we have target data.
+Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
+ if (!memRefType.hasStaticShape())
+ return None;
+ auto elementType = memRefType.getElementType();
+ if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+ return None;
+
+ uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
+ for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
+ sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
+ }
+ return sizeInBytes;
+}
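+// For example, memref<4x8xf32> gives 4 * 8 * 4 = 128 bytes, and
+// memref<16xvector<4xf32>> gives 16 * 16 = 256 bytes (the vector element is
+// 4 x 32 bits = 16 bytes); any dynamically shaped memref returns None.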
+
+template <typename LoadOrStoreOpPointer>
+LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
+ bool emitError) {
+ static_assert(std::is_same<LoadOrStoreOpPointer, AffineLoadOp>::value ||
+ std::is_same<LoadOrStoreOpPointer, AffineStoreOp>::value,
+ "argument should be either a AffineLoadOp or a AffineStoreOp");
+
+ Operation *opInst = loadOrStoreOp.getOperation();
+ MemRefRegion region(opInst->getLoc());
+ if (failed(region.compute(opInst, /*loopDepth=*/0, /*sliceState=*/nullptr,
+ /*addMemRefDimBounds=*/false)))
+ return success();
+
+ LLVM_DEBUG(llvm::dbgs() << "Memory region");
+ LLVM_DEBUG(region.getConstraints()->dump());
+
+ bool outOfBounds = false;
+ unsigned rank = loadOrStoreOp.getMemRefType().getRank();
+
+ // For each dimension, check for out of bounds.
+ for (unsigned r = 0; r < rank; r++) {
+ FlatAffineConstraints ucst(*region.getConstraints());
+
+ // Intersect memory region with constraint capturing out of bounds (both out
+ // of upper and out of lower), and check if the constraint system is
+ // feasible. If it is, there is at least one point out of bounds.
+ SmallVector<int64_t, 4> ineq(rank + 1, 0);
+ int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
+ // TODO(bondhugula): handle dynamic dim sizes.
+ if (dimSize == -1)
+ continue;
+
+ // Check for overflow: d_i >= memref dim size.
+ ucst.addConstantLowerBound(r, dimSize);
+ outOfBounds = !ucst.isEmpty();
+ if (outOfBounds && emitError) {
+ loadOrStoreOp.emitOpError()
+ << "memref out of upper bound access along dimension #" << (r + 1);
+ }
+
+ // Check for a negative index.
+ FlatAffineConstraints lcst(*region.getConstraints());
+ std::fill(ineq.begin(), ineq.end(), 0);
+ // d_i <= -1;
+ lcst.addConstantUpperBound(r, -1);
+ outOfBounds = !lcst.isEmpty();
+ if (outOfBounds && emitError) {
+ loadOrStoreOp.emitOpError()
+ << "memref out of lower bound access along dimension #" << (r + 1);
+ }
+ }
+ return failure(outOfBounds);
+}
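+// A sketch of the expected behavior: for `affine.load %A[%i]` on a
+// memref<10xf32> nested inside `affine.for %i = 0 to 12`, the region gives
+// 0 <= d0 <= 11; intersecting it with d0 >= 10 is feasible, so an
+// "out of upper bound" error is emitted (when 'emitError' is set) and failure
+// is returned.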
+
+// Explicitly instantiate the template so that the compiler knows we need them!
+template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineLoadOp loadOp,
+ bool emitError);
+template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineStoreOp storeOp,
+ bool emitError);
+
+// Returns in 'positions' the Block positions of 'op' in each ancestor
+// Block from the Block containing 'op', stopping at 'limitBlock'.
+static void findInstPosition(Operation *op, Block *limitBlock,
+ SmallVectorImpl<unsigned> *positions) {
+ Block *block = op->getBlock();
+ while (block != limitBlock) {
+ // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
+ // rely on linear scans.
+ int instPosInBlock = std::distance(block->begin(), op->getIterator());
+ positions->push_back(instPosInBlock);
+ op = block->getParentOp();
+ block = op->getBlock();
+ }
+ std::reverse(positions->begin(), positions->end());
+}
+
+// Returns the Operation in a possibly nested set of Blocks, where the
+// position of the operation is represented by 'positions', which has a
+// Block position for each level of nesting.
+static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
+ unsigned level, Block *block) {
+ unsigned i = 0;
+ for (auto &op : *block) {
+ if (i != positions[level]) {
+ ++i;
+ continue;
+ }
+ if (level == positions.size() - 1)
+ return &op;
+ if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
+ return getInstAtPosition(positions, level + 1,
+ childAffineForOp.getBody());
+
+ for (auto &region : op.getRegions()) {
+ for (auto &b : region)
+ if (auto *ret = getInstAtPosition(positions, level + 1, &b))
+ return ret;
+ }
+ return nullptr;
+ }
+ return nullptr;
+}
+
+// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
+LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
+ FlatAffineConstraints *cst) {
+ for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) {
+ auto value = cst->getIdValue(i);
+ if (ivs.count(value) == 0) {
+ assert(isForInductionVar(value));
+ auto loop = getForInductionVarOwner(value);
+ if (failed(cst->addAffineForOpDomain(loop)))
+ return failure();
+ }
+ }
+ return success();
+}
+
+// Returns the innermost common loop depth for the set of operations in 'ops'.
+// TODO(andydavis) Move this to LoopUtils.
+static unsigned
+getInnermostCommonLoopDepth(ArrayRef<Operation *> ops,
+ SmallVectorImpl<AffineForOp> &surroundingLoops) {
+ unsigned numOps = ops.size();
+ assert(numOps > 0);
+
+ std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
+ unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
+ for (unsigned i = 0; i < numOps; ++i) {
+ getLoopIVs(*ops[i], &loops[i]);
+ loopDepthLimit =
+ std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
+ }
+
+ unsigned loopDepth = 0;
+ for (unsigned d = 0; d < loopDepthLimit; ++d) {
+ unsigned i;
+ for (i = 1; i < numOps; ++i) {
+ if (loops[i - 1][d] != loops[i][d])
+ return loopDepth;
+ }
+ surroundingLoops.push_back(loops[i - 1][d]);
+ ++loopDepth;
+ }
+ return loopDepth;
+}
+
+/// Computes in 'sliceUnion' the union of all slice bounds computed at
+/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'.
+/// Returns 'success' if the union was computed, 'failure' otherwise.
+LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA,
+ ArrayRef<Operation *> opsB,
+ unsigned loopDepth,
+ unsigned numCommonLoops,
+ bool isBackwardSlice,
+ ComputationSliceState *sliceUnion) {
+ // Compute the union of slice bounds between all pairs in 'opsA' and
+ // 'opsB' in 'sliceUnionCst'.
+ FlatAffineConstraints sliceUnionCst;
+ assert(sliceUnionCst.getNumDimAndSymbolIds() == 0);
+ std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
+ for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) {
+ MemRefAccess srcAccess(opsA[i]);
+ for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) {
+ MemRefAccess dstAccess(opsB[j]);
+ if (srcAccess.memref != dstAccess.memref)
+ continue;
+ // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
+ if ((!isBackwardSlice && loopDepth > getNestingDepth(*opsA[i])) ||
+ (isBackwardSlice && loopDepth > getNestingDepth(*opsB[j]))) {
+ LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n.");
+ return failure();
+ }
+
+ bool readReadAccesses = isa<AffineLoadOp>(srcAccess.opInst) &&
+ isa<AffineLoadOp>(dstAccess.opInst);
+ FlatAffineConstraints dependenceConstraints;
+ // Check dependence between 'srcAccess' and 'dstAccess'.
+ DependenceResult result = checkMemrefAccessDependence(
+ srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
+ &dependenceConstraints, /*dependenceComponents=*/nullptr,
+ /*allowRAR=*/readReadAccesses);
+ if (result.value == DependenceResult::Failure) {
+ LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n.");
+ return failure();
+ }
+ if (result.value == DependenceResult::NoDependence)
+ continue;
+ dependentOpPairs.push_back({opsA[i], opsB[j]});
+
+ // Compute slice bounds for 'srcAccess' and 'dstAccess'.
+ ComputationSliceState tmpSliceState;
+ mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints,
+ loopDepth, isBackwardSlice,
+ &tmpSliceState);
+
+ if (sliceUnionCst.getNumDimAndSymbolIds() == 0) {
+        // Initialize 'sliceUnionCst' with the bounds computed in the previous
+        // step.
+        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "Unable to compute slice bound constraints.\n");
+ return failure();
+ }
+ assert(sliceUnionCst.getNumDimAndSymbolIds() > 0);
+ continue;
+ }
+
+ // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
+ FlatAffineConstraints tmpSliceCst;
+ if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute slice bound constraints\n.");
+ return failure();
+ }
+
+ // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
+ if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) {
+
+ // Pre-constraint id alignment: record loop IVs used in each constraint
+ // system.
+ SmallPtrSet<Value, 8> sliceUnionIVs;
+ for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k)
+ sliceUnionIVs.insert(sliceUnionCst.getIdValue(k));
+ SmallPtrSet<Value, 8> tmpSliceIVs;
+ for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k)
+ tmpSliceIVs.insert(tmpSliceCst.getIdValue(k));
+
+ sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst);
+
+ // Post-constraint id alignment: add loop IV bounds missing after
+ // id alignment to constraint systems. This can occur if one constraint
+      // system uses a loop IV that is not used by the other. The call
+      // to unionBoundingBox below expects constraints for each loop IV, even
+ // if they are the unsliced full loop bounds added here.
+ if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
+ return failure();
+ if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
+ return failure();
+ }
+ // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
+ if (sliceUnionCst.getNumLocalIds() > 0 ||
+ tmpSliceCst.getNumLocalIds() > 0 ||
+ failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Unable to compute union bounding box of slice bounds."
+ "\n.");
+ return failure();
+ }
+ }
+ }
+
+ // Empty union.
+ if (sliceUnionCst.getNumDimAndSymbolIds() == 0)
+ return failure();
+
+ // Gather loops surrounding ops from loop nest where slice will be inserted.
+ SmallVector<Operation *, 4> ops;
+ for (auto &dep : dependentOpPairs) {
+ ops.push_back(isBackwardSlice ? dep.second : dep.first);
+ }
+ SmallVector<AffineForOp, 4> surroundingLoops;
+ unsigned innermostCommonLoopDepth =
+ getInnermostCommonLoopDepth(ops, surroundingLoops);
+ if (loopDepth > innermostCommonLoopDepth) {
+ LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n.");
+ return failure();
+ }
+
+ // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
+ unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds();
+
+ // Convert any dst loop IVs which are symbol identifiers to dim identifiers.
+ sliceUnionCst.convertLoopIVSymbolsToDims();
+ sliceUnion->clearBounds();
+ sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
+ sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());
+
+ // Get slice bounds from slice union constraints 'sliceUnionCst'.
+ sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
+ opsA[0]->getContext(), &sliceUnion->lbs,
+ &sliceUnion->ubs);
+
+ // Add slice bound operands of union.
+ SmallVector<Value, 4> sliceBoundOperands;
+ sliceUnionCst.getIdValues(numSliceLoopIVs,
+ sliceUnionCst.getNumDimAndSymbolIds(),
+ &sliceBoundOperands);
+
+ // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
+ sliceUnion->ivs.clear();
+ sliceUnionCst.getIdValues(0, numSliceLoopIVs, &sliceUnion->ivs);
+
+ // Set loop nest insertion point to block start at 'loopDepth'.
+ sliceUnion->insertPoint =
+ isBackwardSlice
+ ? surroundingLoops[loopDepth - 1].getBody()->begin()
+ : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
+
+ // Give each bound its own copy of 'sliceBoundOperands' for subsequent
+ // canonicalization.
+ sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+ sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+ return success();
+}
+
+const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
+// Computes slice bounds by projecting out any loop IVs from
+// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
+// bounds in 'sliceState' which represent one loop nest's IVs in terms of the
+// other loop nest's IVs, symbols, and constants (the direction is selected by
+// 'isBackwardSlice').
+void mlir::getComputationSliceState(
+ Operation *depSourceOp, Operation *depSinkOp,
+ FlatAffineConstraints *dependenceConstraints, unsigned loopDepth,
+ bool isBackwardSlice, ComputationSliceState *sliceState) {
+ // Get loop nest surrounding src operation.
+ SmallVector<AffineForOp, 4> srcLoopIVs;
+ getLoopIVs(*depSourceOp, &srcLoopIVs);
+ unsigned numSrcLoopIVs = srcLoopIVs.size();
+
+ // Get loop nest surrounding dst operation.
+ SmallVector<AffineForOp, 4> dstLoopIVs;
+ getLoopIVs(*depSinkOp, &dstLoopIVs);
+ unsigned numDstLoopIVs = dstLoopIVs.size();
+
+ assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
+ (isBackwardSlice && loopDepth <= numDstLoopIVs));
+
+ // Project out dimensions other than those up to 'loopDepth'.
+ unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
+ unsigned num =
+ isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
+ dependenceConstraints->projectOut(pos, num);
+
+ // Add slice loop IV values to 'sliceState'.
+ unsigned offset = isBackwardSlice ? 0 : loopDepth;
+ unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
+ dependenceConstraints->getIdValues(offset, offset + numSliceLoopIVs,
+ &sliceState->ivs);
+
+ // Set up lower/upper bound affine maps for the slice.
+ sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
+ sliceState->ubs.resize(numSliceLoopIVs, AffineMap());
+
+ // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
+ dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
+ depSourceOp->getContext(),
+ &sliceState->lbs, &sliceState->ubs);
+
+ // Set up bound operands for the slice's lower and upper bounds.
+ SmallVector<Value, 4> sliceBoundOperands;
+ unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds();
+ for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
+ if (i < offset || i >= offset + numSliceLoopIVs) {
+ sliceBoundOperands.push_back(dependenceConstraints->getIdValue(i));
+ }
+ }
+
+ // Give each bound its own copy of 'sliceBoundOperands' for subsequent
+ // canonicalization.
+ sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+ sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);
+
+  // Set the slice insertion point: the start of the dst loop body at
+  // 'loopDepth' for backward slices, or the end of the src loop body at
+  // 'loopDepth' for forward slices.
+ sliceState->insertPoint =
+ isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
+ : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
+
+ llvm::SmallDenseSet<Value, 8> sequentialLoops;
+ if (isa<AffineLoadOp>(depSourceOp) && isa<AffineLoadOp>(depSinkOp)) {
+ // For read-read access pairs, clear any slice bounds on sequential loops.
+ // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
+ getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
+ &sequentialLoops);
+ }
+ // Clear all sliced loop bounds beginning at the first sequential loop, or
+  // the first loop with a slice fusion barrier attribute.
+ // TODO(andydavis, bondhugula) Use MemRef read/write regions instead of
+ // using 'kSliceFusionBarrierAttrName'.
+ auto getSliceLoop = [&](unsigned i) {
+ return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
+ };
+ for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
+ Value iv = getSliceLoop(i).getInductionVar();
+ if (sequentialLoops.count(iv) == 0 &&
+ getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr)
+ continue;
+ for (unsigned j = i; j < numSliceLoopIVs; ++j) {
+ sliceState->lbs[j] = AffineMap();
+ sliceState->ubs[j] = AffineMap();
+ }
+ break;
+ }
+}
+
+/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
+/// updates the slice loop bounds with any non-null bound maps specified in
+/// 'sliceState', and inserts this slice into the loop nest surrounding
+/// 'dstOpInst' at loop depth 'dstLoopDepth'.
+// TODO(andydavis,bondhugula): extend the slicing utility to compute slices that
+// aren't necessarily a one-to-one relation between the source and destination.
+// The relation between the source and destination could be many-to-many in
+// general.
+// TODO(andydavis,bondhugula): the slice computation is incorrect in the cases
+// where the dependence from the source to the destination does not cover the
+// entire destination index set. Subtract out the dependent destination
+// iterations from destination index set and check for emptiness --- this is one
+// solution.
+AffineForOp
+mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
+ unsigned dstLoopDepth,
+ ComputationSliceState *sliceState) {
+ // Get loop nest surrounding src operation.
+ SmallVector<AffineForOp, 4> srcLoopIVs;
+ getLoopIVs(*srcOpInst, &srcLoopIVs);
+ unsigned numSrcLoopIVs = srcLoopIVs.size();
+
+ // Get loop nest surrounding dst operation.
+ SmallVector<AffineForOp, 4> dstLoopIVs;
+ getLoopIVs(*dstOpInst, &dstLoopIVs);
+ unsigned dstLoopIVsSize = dstLoopIVs.size();
+ if (dstLoopDepth > dstLoopIVsSize) {
+ dstOpInst->emitError("invalid destination loop depth");
+ return AffineForOp();
+ }
+
+ // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
+ SmallVector<unsigned, 4> positions;
+ // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
+ findInstPosition(srcOpInst, srcLoopIVs[0].getOperation()->getBlock(),
+ &positions);
+
+  // Clone the src loop nest and insert it at the beginning of the operation
+  // block of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
+ auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
+ OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
+ auto sliceLoopNest =
+ cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
+
+ Operation *sliceInst =
+ getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
+ // Get loop nest surrounding 'sliceInst'.
+ SmallVector<AffineForOp, 4> sliceSurroundingLoops;
+ getLoopIVs(*sliceInst, &sliceSurroundingLoops);
+
+ // Sanity check.
+ unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
+ (void)sliceSurroundingLoopsSize;
+ assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
+ unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
+ (void)sliceLoopLimit;
+ assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
+
+ // Update loop bounds for loops in 'sliceLoopNest'.
+ for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+ auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
+ if (AffineMap lbMap = sliceState->lbs[i])
+ forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
+ if (AffineMap ubMap = sliceState->ubs[i])
+ forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
+ }
+ return sliceLoopNest;
+}
+
+// Constructs a MemRefAccess, populating it with the memref, its indices, and
+// the opInst from 'loadOrStoreOpInst'.
+MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
+ if (auto loadOp = dyn_cast<AffineLoadOp>(loadOrStoreOpInst)) {
+ memref = loadOp.getMemRef();
+ opInst = loadOrStoreOpInst;
+ auto loadMemrefType = loadOp.getMemRefType();
+ indices.reserve(loadMemrefType.getRank());
+ for (auto index : loadOp.getMapOperands()) {
+ indices.push_back(index);
+ }
+ } else {
+ assert(isa<AffineStoreOp>(loadOrStoreOpInst) && "load/store op expected");
+    auto storeOp = cast<AffineStoreOp>(loadOrStoreOpInst);
+ opInst = loadOrStoreOpInst;
+ memref = storeOp.getMemRef();
+ auto storeMemrefType = storeOp.getMemRefType();
+ indices.reserve(storeMemrefType.getRank());
+ for (auto index : storeOp.getMapOperands()) {
+ indices.push_back(index);
+ }
+ }
+}
+
+unsigned MemRefAccess::getRank() const {
+ return memref->getType().cast<MemRefType>().getRank();
+}
+
+bool MemRefAccess::isStore() const { return isa<AffineStoreOp>(opInst); }
+
+/// Returns the nesting depth of this operation, i.e., the number of loops
+/// surrounding this operation.
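+/// For example (illustrative only), given
+///   affine.for %i ... { affine.for %j ... { %op } }
+/// the nesting depth of %op is 2, while an op at the top level of a function
+/// has depth 0.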
+unsigned mlir::getNestingDepth(Operation &op) {
+ Operation *currOp = &op;
+ unsigned depth = 0;
+ while ((currOp = currOp->getParentOp())) {
+ if (isa<AffineForOp>(currOp))
+ depth++;
+ }
+ return depth;
+}
+
+/// Equal if both affine accesses are provably equivalent (at compile
+/// time) when considering the memref, the affine maps and their respective
+/// operands. The equality of access functions + operands is checked by
+/// subtracting fully composed value maps, and then simplifying the difference
+/// using the expression flattener.
+/// TODO: this does not account for aliasing of memrefs.
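+/// For instance (illustrative only, with %A a hypothetical memref), the
+/// accesses %A[%i + 1][%j] and %A[1 + %i][%j] compare equal since the
+/// difference of their access maps simplifies to (0, 0), whereas
+/// %A[%i][%j] and %A[%j][%i] do not.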
+bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
+ if (memref != rhs.memref)
+ return false;
+
+ AffineValueMap diff, thisMap, rhsMap;
+ getAccessMap(&thisMap);
+ rhs.getAccessMap(&rhsMap);
+ AffineValueMap::difference(thisMap, rhsMap, &diff);
+ return llvm::all_of(diff.getAffineMap().getResults(),
+ [](AffineExpr e) { return e == 0; });
+}
+
+/// Returns the number of surrounding loops common to both 'A' and 'B',
+/// considering loops from outermost to innermost.
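+/// For example (illustrative only), if 'A' is nested under loops (%i, %j) and
+/// 'B' under (%i, %k), only the outermost loop is shared and the result is 1;
+/// ops in disjoint loop nests have 0 common surrounding loops.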
+unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
+ SmallVector<AffineForOp, 4> loopsA, loopsB;
+ getLoopIVs(A, &loopsA);
+ getLoopIVs(B, &loopsB);
+
+ unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
+ unsigned numCommonLoops = 0;
+ for (unsigned i = 0; i < minNumLoops; ++i) {
+ if (loopsA[i].getOperation() != loopsB[i].getOperation())
+ break;
+ ++numCommonLoops;
+ }
+ return numCommonLoops;
+}
+
+static Optional<int64_t> getMemoryFootprintBytes(Block &block,
+ Block::iterator start,
+ Block::iterator end,
+ int memorySpace) {
+ SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
+
+  // Walk this block range to gather all memory regions.
+ auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
+ if (!isa<AffineLoadOp>(opInst) && !isa<AffineStoreOp>(opInst)) {
+      // Neither a load nor a store op.
+ return WalkResult::advance();
+ }
+
+ // Compute the memref region symbolic in any IVs enclosing this block.
+ auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
+ if (failed(
+ region->compute(opInst,
+ /*loopDepth=*/getNestingDepth(*block.begin())))) {
+ return opInst->emitError("error obtaining memory region\n");
+ }
+
+ auto it = regions.find(region->memref);
+ if (it == regions.end()) {
+ regions[region->memref] = std::move(region);
+ } else if (failed(it->second->unionBoundingBox(*region))) {
+ return opInst->emitWarning(
+ "getMemoryFootprintBytes: unable to perform a union on a memory "
+ "region");
+ }
+ return WalkResult::advance();
+ });
+ if (result.wasInterrupted())
+ return None;
+
+ int64_t totalSizeInBytes = 0;
+ for (const auto &region : regions) {
+ Optional<int64_t> size = region.second->getRegionSize();
+ if (!size.hasValue())
+ return None;
+ totalSizeInBytes += size.getValue();
+ }
+ return totalSizeInBytes;
+}
+
+Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
+ int memorySpace) {
+ auto *forInst = forOp.getOperation();
+ return ::getMemoryFootprintBytes(
+ *forInst->getBlock(), Block::iterator(forInst),
+ std::next(Block::iterator(forInst)), memorySpace);
+}
+
+/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
+/// at 'forOp'.
+void mlir::getSequentialLoops(AffineForOp forOp,
+ llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
+ forOp.getOperation()->walk([&](Operation *op) {
+ if (auto innerFor = dyn_cast<AffineForOp>(op))
+ if (!isLoopParallel(innerFor))
+ sequentialLoops->insert(innerFor.getInductionVar());
+ });
+}
+
+/// Returns true if 'forOp' is parallel.
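+/// For instance (illustrative only, with %A and %v hypothetical values), a
+/// loop whose body only performs
+///   affine.store %v, %A[%i]
+/// for its own IV %i is parallel, while a loop in which every iteration
+/// executes affine.store %v, %A[0] carries an output dependence and is not.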
+bool mlir::isLoopParallel(AffineForOp forOp) {
+ // Collect all load and store ops in loop nest rooted at 'forOp'.
+ SmallVector<Operation *, 8> loadAndStoreOpInsts;
+ auto walkResult = forOp.walk([&](Operation *opInst) {
+ if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+ loadAndStoreOpInsts.push_back(opInst);
+ else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
+ !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect())
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ // Stop early if the loop has unknown ops with side effects.
+ if (walkResult.wasInterrupted())
+ return false;
+
+ // Dep check depth would be number of enclosing loops + 1.
+ unsigned depth = getNestingDepth(*forOp.getOperation()) + 1;
+
+ // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'.
+ for (auto *srcOpInst : loadAndStoreOpInsts) {
+ MemRefAccess srcAccess(srcOpInst);
+ for (auto *dstOpInst : loadAndStoreOpInsts) {
+ MemRefAccess dstAccess(dstOpInst);
+ FlatAffineConstraints dependenceConstraints;
+ DependenceResult result = checkMemrefAccessDependence(
+ srcAccess, dstAccess, depth, &dependenceConstraints,
+ /*dependenceComponents=*/nullptr);
+ if (result.value != DependenceResult::NoDependence)
+ return false;
+ }
+ }
+ return true;
+}
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
new file mode 100644
index 00000000000..1c7dbed5fac
--- /dev/null
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -0,0 +1,232 @@
+//===- VectorAnalysis.cpp - Analysis for Vectorization --------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/Utils.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/STLExtras.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+
+///
+/// Implements Analysis functions specific to vectors which support
+/// the vectorization and vectorization materialization passes.
+///
+
+using namespace mlir;
+
+using llvm::SetVector;
+
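+/// Worked example (illustrative only): shapeRatio({4, 8, 16}, {2, 8}) returns
+/// {4, 4, 2}, since matching trailing dimensions gives 16/8 = 2 and 8/2 = 4 and
+/// the leading super-vector dimension 4 is carried over unchanged, while
+/// shapeRatio({4, 8, 16}, {3, 8}) returns None because 8 is not a multiple
+/// of 3.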
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
+ ArrayRef<int64_t> subShape) {
+ if (superShape.size() < subShape.size()) {
+    return None;
+ }
+
+ // Starting from the end, compute the integer divisors.
+  // Set the boolean `divides` to false if integral division is not possible.
+ std::vector<int64_t> result;
+ result.reserve(superShape.size());
+ bool divides = true;
+  auto divide = [&divides, &result](int64_t superSize, int64_t subSize) {
+ assert(superSize > 0 && "superSize must be > 0");
+ assert(subSize > 0 && "subSize must be > 0");
+ divides &= (superSize % subSize == 0);
+ result.push_back(superSize / subSize);
+ };
+ functional::zipApply(
+ divide, SmallVector<int64_t, 8>{superShape.rbegin(), superShape.rend()},
+ SmallVector<int64_t, 8>{subShape.rbegin(), subShape.rend()});
+
+ // If integral division does not occur, return and let the caller decide.
+ if (!divides) {
+ return None;
+ }
+
+  // At this point we computed the ratio (in reverse) for the common trailing
+  // dimensions. Fill in the remaining entries from the super-vector shape
+  // (still in reverse).
+ int commonSize = subShape.size();
+ std::copy(superShape.rbegin() + commonSize, superShape.rend(),
+ std::back_inserter(result));
+
+ assert(result.size() == superShape.size() &&
+ "super to sub shape ratio is not of the same size as the super rank");
+
+ // Reverse again to get it back in the proper order and return.
+ return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
+}
+
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
+ VectorType subVectorType) {
+ assert(superVectorType.getElementType() == subVectorType.getElementType() &&
+ "vector types must be of the same elemental type");
+ return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
+}
+
+/// Constructs a permutation map from memref indices to vector dimensions.
+///
+/// The implementation uses the knowledge of the mapping of enclosing loops to
+/// vector dimensions. `enclosingLoopToVectorDim` carries this information as a
+/// map with:
+///   - keys representing "vectorized enclosing loops";
+///   - values representing the corresponding vector dimension.
+/// The algorithm traverses the "vectorized enclosing loops" and, for each one,
+/// extracts the at-most-one MemRef index that varies along that loop (i.e., is
+/// not invariant). Such an index is guaranteed to be at most one by
+/// construction: otherwise the MemRef is not vectorizable.
+/// If such a varying index is found, it is added to the permutation_map at the
+/// proper vector dimension.
+/// If no index varies along the loop, 0 is added to the permutation_map, which
+/// corresponds to a vector broadcast along that dimension.
+///
+/// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty,
+/// signalling that no permutation map can be constructed given
+/// `enclosingLoopToVectorDim`.
+///
+/// Examples can be found in the documentation of `makePermutationMap`, in the
+/// header file.
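+///
+/// Illustrative sketch (hypothetical values): for a load with indices
+/// (%c0, %i, %j), where loop %i maps to vector dimension 0 and loop %j maps to
+/// vector dimension 1, the resulting permutation map is
+///   (d0, d1, d2) -> (d1, d2).
+/// If no index varied with %j, dimension 1 would keep the constant 0 (a
+/// broadcast), yielding (d0, d1, d2) -> (d1, 0).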
+static AffineMap makePermutationMap(
+ ArrayRef<Value> indices,
+ const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) {
+ if (enclosingLoopToVectorDim.empty())
+ return AffineMap();
+ MLIRContext *context =
+ enclosingLoopToVectorDim.begin()->getFirst()->getContext();
+ using functional::makePtrDynCaster;
+ using functional::map;
+ SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(),
+ getAffineConstantExpr(0, context));
+
+ for (auto kvp : enclosingLoopToVectorDim) {
+ assert(kvp.second < perm.size());
+ auto invariants = getInvariantAccesses(
+ cast<AffineForOp>(kvp.first).getInductionVar(), indices);
+ unsigned numIndices = indices.size();
+ unsigned countInvariantIndices = 0;
+ for (unsigned dim = 0; dim < numIndices; ++dim) {
+ if (!invariants.count(indices[dim])) {
+ assert(perm[kvp.second] == getAffineConstantExpr(0, context) &&
+ "permutationMap already has an entry along dim");
+ perm[kvp.second] = getAffineDimExpr(dim, context);
+ } else {
+ ++countInvariantIndices;
+ }
+ }
+ assert((countInvariantIndices == numIndices ||
+ countInvariantIndices == numIndices - 1) &&
+ "Vectorization prerequisite violated: at most 1 index may be "
+ "invariant wrt a vectorized loop");
+ }
+ return AffineMap::get(indices.size(), 0, perm);
+}
+
+/// Implementation detail that walks up the parents and records the ones with
+/// the specified type.
+/// TODO(ntv): could also be implemented as a collect parents followed by a
+/// filter and made available outside this file.
+template <typename T>
+static SetVector<Operation *> getParentsOfType(Operation *op) {
+ SetVector<Operation *> res;
+ auto *current = op;
+ while (auto *parent = current->getParentOp()) {
+    if (isa<T>(parent)) {
+ assert(res.count(parent) == 0 && "Already inserted");
+ res.insert(parent);
+ }
+ current = parent;
+ }
+ return res;
+}
+
+/// Returns the enclosing AffineForOps, from closest to farthest.
+static SetVector<Operation *> getEnclosingforOps(Operation *op) {
+ return getParentsOfType<AffineForOp>(op);
+}
+
+AffineMap mlir::makePermutationMap(
+ Operation *op, ArrayRef<Value> indices,
+ const DenseMap<Operation *, unsigned> &loopToVectorDim) {
+ DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
+ auto enclosingLoops = getEnclosingforOps(op);
+ for (auto *forInst : enclosingLoops) {
+ auto it = loopToVectorDim.find(forInst);
+ if (it != loopToVectorDim.end()) {
+ enclosingLoopToVectorDim.insert(*it);
+ }
+ }
+ return ::makePermutationMap(indices, enclosingLoopToVectorDim);
+}
+
+bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op,
+ VectorType subVectorType) {
+ // First, extract the vector type and distinguish between:
+ // a. ops that *must* lower a super-vector (i.e. vector.transfer_read,
+ // vector.transfer_write); and
+ // b. ops that *may* lower a super-vector (all other ops).
+ // The ops that *may* lower a super-vector only do so if the super-vector to
+ // sub-vector ratio exists. The ops that *must* lower a super-vector are
+ // explicitly checked for this property.
+ /// TODO(ntv): there should be a single function for all ops to do this so we
+ /// do not have to special case. Maybe a trait, or just a method, unclear atm.
+ bool mustDivide = false;
+ (void)mustDivide;
+ VectorType superVectorType;
+ if (auto read = dyn_cast<vector::TransferReadOp>(op)) {
+ superVectorType = read.getVectorType();
+ mustDivide = true;
+ } else if (auto write = dyn_cast<vector::TransferWriteOp>(op)) {
+ superVectorType = write.getVectorType();
+ mustDivide = true;
+ } else if (op.getNumResults() == 0) {
+ if (!isa<ReturnOp>(op)) {
+ op.emitError("NYI: assuming only return operations can have 0 "
+ " results at this point");
+ }
+ return false;
+ } else if (op.getNumResults() == 1) {
+ if (auto v = op.getResult(0)->getType().dyn_cast<VectorType>()) {
+ superVectorType = v;
+ } else {
+ // Not a vector type.
+ return false;
+ }
+ } else {
+ // Not a vector.transfer and has more than 1 result, fail hard for now to
+ // wake us up when something changes.
+ op.emitError("NYI: operation has more than 1 result");
+ return false;
+ }
+
+ // Get the ratio.
+ auto ratio = shapeRatio(superVectorType, subVectorType);
+
+ // Sanity check.
+ assert((ratio.hasValue() || !mustDivide) &&
+ "vector.transfer operation in which super-vector size is not an"
+ " integer multiple of sub-vector size");
+
+  // This catches cases where multiplicity is not strictly required but the
+  // super-vector shape still isn't divisible by the sub-vector shape.
+  // This could be useful information if we wanted to reshape at the level of
+  // the vector type (but we would have to look at the compute and distinguish
+  // between parallel, reduction and possibly other cases).
+ if (!ratio.hasValue()) {
+ return false;
+ }
+
+ return true;
+}
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
new file mode 100644
index 00000000000..d4861b1a2e7
--- /dev/null
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -0,0 +1,266 @@
+//===- Verifier.cpp - MLIR Verifier Implementation ------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the verify() methods on the various IR types, performing
+// (potentially expensive) checks on the holistic structure of the code. This
+// can be used for detecting bugs in compiler transformations and hand-written
+// .mlir files.
+//
+// The checks in this file are only for things that can occur as part of IR
+// transformations: e.g. violation of dominance information, malformed operation
+// attributes, etc. MLIR supports transformations moving IR through locally
+// invalid states (e.g. unlinking an operation from a block before re-inserting
+// it in a new place), but each transformation must complete with the IR in a
+// valid form.
+//
+// This should not check for things that are always wrong by construction (e.g.
+// attributes or other immutable structures that are incorrect), because those
+// are not mutable and can be checked at time of construction.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Regex.h"
+
+using namespace mlir;
+
+namespace {
+/// This class encapsulates all the state used to verify an operation region.
+class OperationVerifier {
+public:
+ explicit OperationVerifier(MLIRContext *ctx)
+ : ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
+
+ /// Verify the given operation.
+ LogicalResult verify(Operation &op);
+
+ /// Returns the registered dialect for a dialect-specific attribute.
+ Dialect *getDialectForAttribute(const NamedAttribute &attr) {
+ assert(attr.first.strref().contains('.') && "expected dialect attribute");
+ auto dialectNamePair = attr.first.strref().split('.');
+ return ctx->getRegisteredDialect(dialectNamePair.first);
+ }
+
+ /// Returns if the given string is valid to use as an identifier name.
+ bool isValidName(StringRef name) { return identifierRegex.match(name); }
+
+private:
+ /// Verify the given potentially nested region or block.
+ LogicalResult verifyRegion(Region &region);
+ LogicalResult verifyBlock(Block &block);
+ LogicalResult verifyOperation(Operation &op);
+
+ /// Verify the dominance within the given IR unit.
+ LogicalResult verifyDominance(Region &region);
+ LogicalResult verifyDominance(Operation &op);
+
+ /// Emit an error for the given block.
+ InFlightDiagnostic emitError(Block &bb, const Twine &message) {
+ // Take the location information for the first operation in the block.
+ if (!bb.empty())
+ return bb.front().emitError(message);
+
+ // Worst case, fall back to using the parent's location.
+ return mlir::emitError(bb.getParent()->getLoc(), message);
+ }
+
+ /// The current context for the verifier.
+ MLIRContext *ctx;
+
+ /// Dominance information for this operation, when checking dominance.
+ DominanceInfo *domInfo = nullptr;
+
+ /// Regex checker for attribute names.
+ llvm::Regex identifierRegex;
+
+  /// Mapping from dialect namespace to whether that dialect supports
+  /// unregistered operations.
+ llvm::StringMap<bool> dialectAllowsUnknownOps;
+};
+} // end anonymous namespace
+
+/// Verify the given operation.
+LogicalResult OperationVerifier::verify(Operation &op) {
+ // Verify the operation first.
+ if (failed(verifyOperation(op)))
+ return failure();
+
+ // Since everything looks structurally ok to this point, we do a dominance
+ // check for any nested regions. We do this as a second pass since malformed
+  // CFGs can cause dominator analysis construction to crash and we want the
+ // verifier to be resilient to malformed code.
+ DominanceInfo theDomInfo(&op);
+ domInfo = &theDomInfo;
+ for (auto &region : op.getRegions())
+ if (failed(verifyDominance(region)))
+ return failure();
+
+ domInfo = nullptr;
+ return success();
+}
+
+LogicalResult OperationVerifier::verifyRegion(Region &region) {
+ if (region.empty())
+ return success();
+
+ // Verify the first block has no predecessors.
+ auto *firstBB = &region.front();
+ if (!firstBB->hasNoPredecessors())
+ return mlir::emitError(region.getLoc(),
+ "entry block of region may not have predecessors");
+
+ // Verify each of the blocks within the region.
+ for (auto &block : region)
+ if (failed(verifyBlock(block)))
+ return failure();
+ return success();
+}
+
+LogicalResult OperationVerifier::verifyBlock(Block &block) {
+ for (auto arg : block.getArguments())
+ if (arg->getOwner() != &block)
+ return emitError(block, "block argument not owned by block");
+
+ // Verify that this block has a terminator.
+ if (block.empty())
+ return emitError(block, "block with no terminator");
+
+ // Verify the non-terminator operations separately so that we can verify
+  // they have no successors.
+ for (auto &op : llvm::make_range(block.begin(), std::prev(block.end()))) {
+ if (op.getNumSuccessors() != 0)
+ return op.emitError(
+ "operation with block successors must terminate its parent block");
+
+ if (failed(verifyOperation(op)))
+ return failure();
+ }
+
+ // Verify the terminator.
+ if (failed(verifyOperation(block.back())))
+ return failure();
+ if (block.back().isKnownNonTerminator())
+ return emitError(block, "block with no terminator");
+
+ // Verify that this block is not branching to a block of a different
+ // region.
+ for (Block *successor : block.getSuccessors())
+ if (successor->getParent() != block.getParent())
+ return block.back().emitOpError(
+ "branching to block of a different region");
+
+ return success();
+}
+
+LogicalResult OperationVerifier::verifyOperation(Operation &op) {
+ // Check that operands are non-nil and structurally ok.
+ for (auto operand : op.getOperands())
+ if (!operand)
+ return op.emitError("null operand found");
+
+ /// Verify that all of the attributes are okay.
+ for (auto attr : op.getAttrs()) {
+ if (!identifierRegex.match(attr.first))
+ return op.emitError("invalid attribute name '") << attr.first << "'";
+
+ // Check for any optional dialect specific attributes.
+ if (!attr.first.strref().contains('.'))
+ continue;
+ if (auto *dialect = getDialectForAttribute(attr))
+ if (failed(dialect->verifyOperationAttribute(&op, attr)))
+ return failure();
+ }
+
+ // If we can get operation info for this, check the custom hook.
+ auto *opInfo = op.getAbstractOperation();
+ if (opInfo && failed(opInfo->verifyInvariants(&op)))
+ return failure();
+
+ // Verify that all child regions are ok.
+ for (auto &region : op.getRegions())
+ if (failed(verifyRegion(region)))
+ return failure();
+
+ // If this is a registered operation, there is nothing left to do.
+ if (opInfo)
+ return success();
+
+ // Otherwise, verify that the parent dialect allows un-registered operations.
+ auto dialectPrefix = op.getName().getDialect();
+
+ // Check for an existing answer for the operation dialect.
+ auto it = dialectAllowsUnknownOps.find(dialectPrefix);
+ if (it == dialectAllowsUnknownOps.end()) {
+ // If the operation dialect is registered, query it directly.
+ if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
+ it = dialectAllowsUnknownOps
+ .try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
+ .first;
+ // Otherwise, conservatively allow unknown operations.
+ else
+ it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first;
+ }
+
+ if (!it->second) {
+ return op.emitError("unregistered operation '")
+ << op.getName() << "' found in dialect ('" << dialectPrefix
+ << "') that does not allow unknown operations";
+ }
+
+ return success();
+}
+
+LogicalResult OperationVerifier::verifyDominance(Region &region) {
+ // Verify the dominance of each of the held operations.
+ for (auto &block : region)
+ for (auto &op : block)
+ if (failed(verifyDominance(op)))
+ return failure();
+ return success();
+}
+
+LogicalResult OperationVerifier::verifyDominance(Operation &op) {
+ // Check that operands properly dominate this use.
+ for (unsigned operandNo = 0, e = op.getNumOperands(); operandNo != e;
+ ++operandNo) {
+ auto operand = op.getOperand(operandNo);
+ if (domInfo->properlyDominates(operand, &op))
+ continue;
+
+ auto diag = op.emitError("operand #")
+ << operandNo << " does not dominate this use";
+ if (auto *useOp = operand->getDefiningOp())
+ diag.attachNote(useOp->getLoc()) << "operand defined here";
+ return failure();
+ }
+
+ // Verify the dominance of each of the nested blocks within this operation.
+ for (auto &region : op.getRegions())
+ if (failed(verifyDominance(region)))
+ return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Entrypoint
+//===----------------------------------------------------------------------===//
+
+/// Perform (potentially expensive) checks of invariants, used to detect
+/// compiler bugs. On error, this reports the error through the MLIRContext and
+/// returns failure.
+LogicalResult mlir::verify(Operation *op) {
+ return OperationVerifier(op->getContext()).verify(*op);
+}
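+
+// A minimal usage sketch (hypothetical caller, assuming a ModuleOp obtained
+// elsewhere, e.g. built or parsed by the caller):
+//
+//   ModuleOp module = ...;
+//   if (failed(verify(module.getOperation())))
+//     module.emitError("module failed verification");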