Diffstat (limited to 'mlir/lib/Analysis')
-rw-r--r-- | mlir/lib/Analysis/AffineAnalysis.cpp | 886
-rw-r--r-- | mlir/lib/Analysis/AffineStructures.cpp | 2854
-rw-r--r-- | mlir/lib/Analysis/CMakeLists.txt | 29
-rw-r--r-- | mlir/lib/Analysis/CallGraph.cpp | 256
-rw-r--r-- | mlir/lib/Analysis/Dominance.cpp | 171
-rw-r--r-- | mlir/lib/Analysis/InferTypeOpInterface.cpp | 22
-rw-r--r-- | mlir/lib/Analysis/Liveness.cpp | 373
-rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 388
-rw-r--r-- | mlir/lib/Analysis/MemRefBoundCheck.cpp | 53
-rw-r--r-- | mlir/lib/Analysis/NestedMatcher.cpp | 152
-rw-r--r-- | mlir/lib/Analysis/OpStats.cpp | 84
-rw-r--r-- | mlir/lib/Analysis/SliceAnalysis.cpp | 213
-rw-r--r-- | mlir/lib/Analysis/TestMemRefDependenceCheck.cpp | 121
-rw-r--r-- | mlir/lib/Analysis/TestParallelismDetection.cpp | 48
-rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 1007
-rw-r--r-- | mlir/lib/Analysis/VectorAnalysis.cpp | 232
-rw-r--r-- | mlir/lib/Analysis/Verifier.cpp | 266
17 files changed, 7155 insertions, 0 deletions
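The diffs below add, among other things, the memref dependence analysis in AffineAnalysis.cpp; its public entry points are checkMemrefAccessDependence() and getDependenceComponents(), and TestMemRefDependenceCheck.cpp in the diffstat is the in-tree test driver for them. As orientation before reading the full diff, here is a minimal sketch (not part of this commit) of how such a driver invokes the check, modeled on the getDependenceComponents() walk shown further down. The function name checkAllDependences is invented for the example, and getNumCommonSurroundingLoops() is assumed to be declared in the Analysis/Utils.h header backing the Utils.cpp listed above.

// Illustrative sketch only -- not part of the commit shown on this page.
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/IR/Function.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Checks every affine load/store pair in 'func' for a dependence, printing
// the loop depths at which a dependence exists.
static void checkAllDependences(FuncOp func) {
  // Collect all affine load and store operations in the function.
  SmallVector<Operation *, 8> loadsAndStores;
  func.walk([&](Operation *op) {
    if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
      loadsAndStores.push_back(op);
  });

  for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
    MemRefAccess srcAccess(loadsAndStores[i]);
    for (unsigned j = 0; j < e; ++j) {
      MemRefAccess dstAccess(loadsAndStores[j]);
      // Depths range over the loops common to both accesses, plus one for
      // the intra-iteration case.
      unsigned numCommonLoops = getNumCommonSurroundingLoops(
          *srcAccess.opInst, *dstAccess.opInst);
      for (unsigned depth = 1; depth <= numCommonLoops + 1; ++depth) {
        FlatAffineConstraints dependenceConstraints;
        SmallVector<DependenceComponent, 2> depComps;
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, depth, &dependenceConstraints, &depComps);
        if (hasDependence(result))
          llvm::errs() << "dependence from op " << i << " to op " << j
                       << " at depth " << depth << "\n";
      }
    }
  }
}

For each load/store pair the check is repeated at every depth from 1 to the number of common surrounding loops plus one, which matches the loopDepth convention asserted in checkMemrefAccessDependence() and documented on addOrderingConstraints() in the diff below.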
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp new file mode 100644 index 00000000000..3358bb437ff --- /dev/null +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -0,0 +1,886 @@ +//===- AffineAnalysis.cpp - Affine structures analysis routines -----------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements miscellaneous analysis routines for affine structures +// (expressions, maps, sets), and other utilities relying on such analysis. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "affine-analysis" + +using namespace mlir; + +using llvm::dbgs; + +/// Returns the sequence of AffineApplyOp Operations operation in +/// 'affineApplyOps', which are reachable via a search starting from 'operands', +/// and ending at operands which are not defined by AffineApplyOps. +// TODO(andydavis) Add a method to AffineApplyOp which forward substitutes +// the AffineApplyOp into any user AffineApplyOps. +void mlir::getReachableAffineApplyOps( + ArrayRef<Value> operands, SmallVectorImpl<Operation *> &affineApplyOps) { + struct State { + // The ssa value for this node in the DFS traversal. + Value value; + // The operand index of 'value' to explore next during DFS traversal. + unsigned operandIndex; + }; + SmallVector<State, 4> worklist; + for (auto operand : operands) { + worklist.push_back({operand, 0}); + } + + while (!worklist.empty()) { + State &state = worklist.back(); + auto *opInst = state.value->getDefiningOp(); + // Note: getDefiningOp will return nullptr if the operand is not an + // Operation (i.e. block argument), which is a terminator for the search. + if (!isa_and_nonnull<AffineApplyOp>(opInst)) { + worklist.pop_back(); + continue; + } + + if (state.operandIndex == 0) { + // Pre-Visit: Add 'opInst' to reachable sequence. + affineApplyOps.push_back(opInst); + } + if (state.operandIndex < opInst->getNumOperands()) { + // Visit: Add next 'affineApplyOp' operand to worklist. + // Get next operand to visit at 'operandIndex'. + auto nextOperand = opInst->getOperand(state.operandIndex); + // Increment 'operandIndex' in 'state'. + ++state.operandIndex; + // Add 'nextOperand' to worklist. + worklist.push_back({nextOperand, 0}); + } else { + // Post-visit: done visiting operands AffineApplyOp, pop off stack. + worklist.pop_back(); + } + } +} + +// Builds a system of constraints with dimensional identifiers corresponding to +// the loop IVs of the forOps appearing in that order. Any symbols founds in +// the bound operands are added as symbols in the system. Returns failure for +// the yet unimplemented cases. 
+// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or +// stride information in FlatAffineConstraints. (For eg., by using iv - lb % +// step = 0 and/or by introducing a method in FlatAffineConstraints +// setExprStride(ArrayRef<int64_t> expr, int64_t stride) +LogicalResult mlir::getIndexSet(MutableArrayRef<AffineForOp> forOps, + FlatAffineConstraints *domain) { + SmallVector<Value, 4> indices; + extractForInductionVars(forOps, &indices); + // Reset while associated Values in 'indices' to the domain. + domain->reset(forOps.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); + for (auto forOp : forOps) { + // Add constraints from forOp's bounds. + if (failed(domain->addAffineForOpDomain(forOp))) + return failure(); + } + return success(); +} + +// Computes the iteration domain for 'opInst' and populates 'indexSet', which +// encapsulates the constraints involving loops surrounding 'opInst' and +// potentially involving any Function symbols. The dimensional identifiers in +// 'indexSet' correspond to the loops surrounding 'op' from outermost to +// innermost. +// TODO(andydavis) Add support to handle IfInsts surrounding 'op'. +static LogicalResult getInstIndexSet(Operation *op, + FlatAffineConstraints *indexSet) { + // TODO(andydavis) Extend this to gather enclosing IfInsts and consider + // factoring it out into a utility function. + SmallVector<AffineForOp, 4> loops; + getLoopIVs(*op, &loops); + return getIndexSet(loops, indexSet); +} + +// ValuePositionMap manages the mapping from Values which represent dimension +// and symbol identifiers from 'src' and 'dst' access functions to positions +// in new space where some Values are kept separate (using addSrc/DstValue) +// and some Values are merged (addSymbolValue). +// Position lookups return the absolute position in the new space which +// has the following format: +// +// [src-dim-identifiers] [dst-dim-identifiers] [symbol-identifiers] +// +// Note: access function non-IV dimension identifiers (that have 'dimension' +// positions in the access function position space) are assigned as symbols +// in the output position space. Convenience access functions which lookup +// an Value in multiple maps are provided (i.e. getSrcDimOrSymPos) to handle +// the common case of resolving positions for all access function operands. +// +// TODO(andydavis) Generalize this: could take a template parameter for +// the number of maps (3 in the current case), and lookups could take indices +// of maps to check. So getSrcDimOrSymPos would be "getPos(value, {0, 2})". 
+class ValuePositionMap { +public: + void addSrcValue(Value value) { + if (addValueAt(value, &srcDimPosMap, numSrcDims)) + ++numSrcDims; + } + void addDstValue(Value value) { + if (addValueAt(value, &dstDimPosMap, numDstDims)) + ++numDstDims; + } + void addSymbolValue(Value value) { + if (addValueAt(value, &symbolPosMap, numSymbols)) + ++numSymbols; + } + unsigned getSrcDimOrSymPos(Value value) const { + return getDimOrSymPos(value, srcDimPosMap, 0); + } + unsigned getDstDimOrSymPos(Value value) const { + return getDimOrSymPos(value, dstDimPosMap, numSrcDims); + } + unsigned getSymPos(Value value) const { + auto it = symbolPosMap.find(value); + assert(it != symbolPosMap.end()); + return numSrcDims + numDstDims + it->second; + } + + unsigned getNumSrcDims() const { return numSrcDims; } + unsigned getNumDstDims() const { return numDstDims; } + unsigned getNumDims() const { return numSrcDims + numDstDims; } + unsigned getNumSymbols() const { return numSymbols; } + +private: + bool addValueAt(Value value, DenseMap<Value, unsigned> *posMap, + unsigned position) { + auto it = posMap->find(value); + if (it == posMap->end()) { + (*posMap)[value] = position; + return true; + } + return false; + } + unsigned getDimOrSymPos(Value value, + const DenseMap<Value, unsigned> &dimPosMap, + unsigned dimPosOffset) const { + auto it = dimPosMap.find(value); + if (it != dimPosMap.end()) { + return dimPosOffset + it->second; + } + it = symbolPosMap.find(value); + assert(it != symbolPosMap.end()); + return numSrcDims + numDstDims + it->second; + } + + unsigned numSrcDims = 0; + unsigned numDstDims = 0; + unsigned numSymbols = 0; + DenseMap<Value, unsigned> srcDimPosMap; + DenseMap<Value, unsigned> dstDimPosMap; + DenseMap<Value, unsigned> symbolPosMap; +}; + +// Builds a map from Value to identifier position in a new merged identifier +// list, which is the result of merging dim/symbol lists from src/dst +// iteration domains, the format of which is as follows: +// +// [src-dim-identifiers, dst-dim-identifiers, symbol-identifiers, const_term] +// +// This method populates 'valuePosMap' with mappings from operand Values in +// 'srcAccessMap'/'dstAccessMap' (as well as those in 'srcDomain'/'dstDomain') +// to the position of these values in the merged list. +static void buildDimAndSymbolPositionMaps( + const FlatAffineConstraints &srcDomain, + const FlatAffineConstraints &dstDomain, const AffineValueMap &srcAccessMap, + const AffineValueMap &dstAccessMap, ValuePositionMap *valuePosMap, + FlatAffineConstraints *dependenceConstraints) { + auto updateValuePosMap = [&](ArrayRef<Value> values, bool isSrc) { + for (unsigned i = 0, e = values.size(); i < e; ++i) { + auto value = values[i]; + if (!isForInductionVar(values[i])) { + assert(isValidSymbol(values[i]) && + "access operand has to be either a loop IV or a symbol"); + valuePosMap->addSymbolValue(value); + } else if (isSrc) { + valuePosMap->addSrcValue(value); + } else { + valuePosMap->addDstValue(value); + } + } + }; + + SmallVector<Value, 4> srcValues, destValues; + srcDomain.getIdValues(0, srcDomain.getNumDimAndSymbolIds(), &srcValues); + dstDomain.getIdValues(0, dstDomain.getNumDimAndSymbolIds(), &destValues); + // Update value position map with identifiers from src iteration domain. + updateValuePosMap(srcValues, /*isSrc=*/true); + // Update value position map with identifiers from dst iteration domain. + updateValuePosMap(destValues, /*isSrc=*/false); + // Update value position map with identifiers from src access function. 
+ updateValuePosMap(srcAccessMap.getOperands(), /*isSrc=*/true); + // Update value position map with identifiers from dst access function. + updateValuePosMap(dstAccessMap.getOperands(), /*isSrc=*/false); +} + +// Sets up dependence constraints columns appropriately, in the format: +// [src-dim-ids, dst-dim-ids, symbol-ids, local-ids, const_term] +void initDependenceConstraints(const FlatAffineConstraints &srcDomain, + const FlatAffineConstraints &dstDomain, + const AffineValueMap &srcAccessMap, + const AffineValueMap &dstAccessMap, + const ValuePositionMap &valuePosMap, + FlatAffineConstraints *dependenceConstraints) { + // Calculate number of equalities/inequalities and columns required to + // initialize FlatAffineConstraints for 'dependenceDomain'. + unsigned numIneq = + srcDomain.getNumInequalities() + dstDomain.getNumInequalities(); + AffineMap srcMap = srcAccessMap.getAffineMap(); + assert(srcMap.getNumResults() == dstAccessMap.getAffineMap().getNumResults()); + unsigned numEq = srcMap.getNumResults(); + unsigned numDims = srcDomain.getNumDimIds() + dstDomain.getNumDimIds(); + unsigned numSymbols = valuePosMap.getNumSymbols(); + unsigned numLocals = srcDomain.getNumLocalIds() + dstDomain.getNumLocalIds(); + unsigned numIds = numDims + numSymbols + numLocals; + unsigned numCols = numIds + 1; + + // Set flat affine constraints sizes and reserving space for constraints. + dependenceConstraints->reset(numIneq, numEq, numCols, numDims, numSymbols, + numLocals); + + // Set values corresponding to dependence constraint identifiers. + SmallVector<Value, 4> srcLoopIVs, dstLoopIVs; + srcDomain.getIdValues(0, srcDomain.getNumDimIds(), &srcLoopIVs); + dstDomain.getIdValues(0, dstDomain.getNumDimIds(), &dstLoopIVs); + + dependenceConstraints->setIdValues(0, srcLoopIVs.size(), srcLoopIVs); + dependenceConstraints->setIdValues( + srcLoopIVs.size(), srcLoopIVs.size() + dstLoopIVs.size(), dstLoopIVs); + + // Set values for the symbolic identifier dimensions. + auto setSymbolIds = [&](ArrayRef<Value> values) { + for (auto value : values) { + if (!isForInductionVar(value)) { + assert(isValidSymbol(value) && "expected symbol"); + dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value); + } + } + }; + + setSymbolIds(srcAccessMap.getOperands()); + setSymbolIds(dstAccessMap.getOperands()); + + SmallVector<Value, 8> srcSymbolValues, dstSymbolValues; + srcDomain.getIdValues(srcDomain.getNumDimIds(), + srcDomain.getNumDimAndSymbolIds(), &srcSymbolValues); + dstDomain.getIdValues(dstDomain.getNumDimIds(), + dstDomain.getNumDimAndSymbolIds(), &dstSymbolValues); + setSymbolIds(srcSymbolValues); + setSymbolIds(dstSymbolValues); + + for (unsigned i = 0, e = dependenceConstraints->getNumDimAndSymbolIds(); + i < e; i++) + assert(dependenceConstraints->getIds()[i].hasValue()); +} + +// Adds iteration domain constraints from 'srcDomain' and 'dstDomain' into +// 'dependenceDomain'. +// Uses 'valuePosMap' to determine the position in 'dependenceDomain' to which a +// srcDomain/dstDomain Value maps. +static void addDomainConstraints(const FlatAffineConstraints &srcDomain, + const FlatAffineConstraints &dstDomain, + const ValuePositionMap &valuePosMap, + FlatAffineConstraints *dependenceDomain) { + unsigned depNumDimsAndSymbolIds = dependenceDomain->getNumDimAndSymbolIds(); + + SmallVector<int64_t, 4> cst(dependenceDomain->getNumCols()); + + auto addDomain = [&](bool isSrc, bool isEq, unsigned localOffset) { + const FlatAffineConstraints &domain = isSrc ? 
srcDomain : dstDomain; + unsigned numCsts = + isEq ? domain.getNumEqualities() : domain.getNumInequalities(); + unsigned numDimAndSymbolIds = domain.getNumDimAndSymbolIds(); + auto at = [&](unsigned i, unsigned j) -> int64_t { + return isEq ? domain.atEq(i, j) : domain.atIneq(i, j); + }; + auto map = [&](unsigned i) -> int64_t { + return isSrc ? valuePosMap.getSrcDimOrSymPos(domain.getIdValue(i)) + : valuePosMap.getDstDimOrSymPos(domain.getIdValue(i)); + }; + + for (unsigned i = 0; i < numCsts; ++i) { + // Zero fill. + std::fill(cst.begin(), cst.end(), 0); + // Set coefficients for identifiers corresponding to domain. + for (unsigned j = 0; j < numDimAndSymbolIds; ++j) + cst[map(j)] = at(i, j); + // Local terms. + for (unsigned j = 0, e = domain.getNumLocalIds(); j < e; j++) + cst[depNumDimsAndSymbolIds + localOffset + j] = + at(i, numDimAndSymbolIds + j); + // Set constant term. + cst[cst.size() - 1] = at(i, domain.getNumCols() - 1); + // Add constraint. + if (isEq) + dependenceDomain->addEquality(cst); + else + dependenceDomain->addInequality(cst); + } + }; + + // Add equalities from src domain. + addDomain(/*isSrc=*/true, /*isEq=*/true, /*localOffset=*/0); + // Add inequalities from src domain. + addDomain(/*isSrc=*/true, /*isEq=*/false, /*localOffset=*/0); + // Add equalities from dst domain. + addDomain(/*isSrc=*/false, /*isEq=*/true, + /*localOffset=*/srcDomain.getNumLocalIds()); + // Add inequalities from dst domain. + addDomain(/*isSrc=*/false, /*isEq=*/false, + /*localOffset=*/srcDomain.getNumLocalIds()); +} + +// Adds equality constraints that equate src and dst access functions +// represented by 'srcAccessMap' and 'dstAccessMap' for each result. +// Requires that 'srcAccessMap' and 'dstAccessMap' have the same results count. +// For example, given the following two accesses functions to a 2D memref: +// +// Source access function: +// (a0 * d0 + a1 * s0 + a2, b0 * d0 + b1 * s0 + b2) +// +// Destination access function: +// (c0 * d0 + c1 * s0 + c2, f0 * d0 + f1 * s0 + f2) +// +// This method constructs the following equality constraints in +// 'dependenceDomain', by equating the access functions for each result +// (i.e. each memref dim). Notice that 'd0' for the destination access function +// is mapped into 'd0' in the equality constraint: +// +// d0 d1 s0 c +// -- -- -- -- +// a0 -c0 (a1 - c1) (a1 - c2) = 0 +// b0 -f0 (b1 - f1) (b1 - f2) = 0 +// +// Returns failure if any AffineExpr cannot be flattened (due to it being +// semi-affine). Returns success otherwise. +static LogicalResult +addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, + const AffineValueMap &dstAccessMap, + const ValuePositionMap &valuePosMap, + FlatAffineConstraints *dependenceDomain) { + AffineMap srcMap = srcAccessMap.getAffineMap(); + AffineMap dstMap = dstAccessMap.getAffineMap(); + assert(srcMap.getNumResults() == dstMap.getNumResults()); + unsigned numResults = srcMap.getNumResults(); + + unsigned srcNumIds = srcMap.getNumDims() + srcMap.getNumSymbols(); + ArrayRef<Value> srcOperands = srcAccessMap.getOperands(); + + unsigned dstNumIds = dstMap.getNumDims() + dstMap.getNumSymbols(); + ArrayRef<Value> dstOperands = dstAccessMap.getOperands(); + + std::vector<SmallVector<int64_t, 8>> srcFlatExprs; + std::vector<SmallVector<int64_t, 8>> destFlatExprs; + FlatAffineConstraints srcLocalVarCst, destLocalVarCst; + // Get flattened expressions for the source destination maps. 
+ if (failed(getFlattenedAffineExprs(srcMap, &srcFlatExprs, &srcLocalVarCst)) || + failed(getFlattenedAffineExprs(dstMap, &destFlatExprs, &destLocalVarCst))) + return failure(); + + unsigned domNumLocalIds = dependenceDomain->getNumLocalIds(); + unsigned srcNumLocalIds = srcLocalVarCst.getNumLocalIds(); + unsigned dstNumLocalIds = destLocalVarCst.getNumLocalIds(); + unsigned numLocalIdsToAdd = srcNumLocalIds + dstNumLocalIds; + for (unsigned i = 0; i < numLocalIdsToAdd; i++) { + dependenceDomain->addLocalId(dependenceDomain->getNumLocalIds()); + } + + unsigned numDims = dependenceDomain->getNumDimIds(); + unsigned numSymbols = dependenceDomain->getNumSymbolIds(); + unsigned numSrcLocalIds = srcLocalVarCst.getNumLocalIds(); + unsigned newLocalIdOffset = numDims + numSymbols + domNumLocalIds; + + // Equality to add. + SmallVector<int64_t, 8> eq(dependenceDomain->getNumCols()); + for (unsigned i = 0; i < numResults; ++i) { + // Zero fill. + std::fill(eq.begin(), eq.end(), 0); + + // Flattened AffineExpr for src result 'i'. + const auto &srcFlatExpr = srcFlatExprs[i]; + // Set identifier coefficients from src access function. + for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) + eq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = srcFlatExpr[j]; + // Local terms. + for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) + eq[newLocalIdOffset + j] = srcFlatExpr[srcNumIds + j]; + // Set constant term. + eq[eq.size() - 1] = srcFlatExpr[srcFlatExpr.size() - 1]; + + // Flattened AffineExpr for dest result 'i'. + const auto &destFlatExpr = destFlatExprs[i]; + // Set identifier coefficients from dst access function. + for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) + eq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] -= destFlatExpr[j]; + // Local terms. + for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) + eq[newLocalIdOffset + numSrcLocalIds + j] = -destFlatExpr[dstNumIds + j]; + // Set constant term. + eq[eq.size() - 1] -= destFlatExpr[destFlatExpr.size() - 1]; + + // Add equality constraint. + dependenceDomain->addEquality(eq); + } + + // Add equality constraints for any operands that are defined by constant ops. + auto addEqForConstOperands = [&](ArrayRef<Value> operands) { + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + if (isForInductionVar(operands[i])) + continue; + auto symbol = operands[i]; + assert(isValidSymbol(symbol)); + // Check if the symbol is a constant. + if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol->getDefiningOp())) + dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol), + cOp.getValue()); + } + }; + + // Add equality constraints for any src symbols defined by constant ops. + addEqForConstOperands(srcOperands); + // Add equality constraints for any dst symbols defined by constant ops. + addEqForConstOperands(dstOperands); + + // By construction (see flattener), local var constraints will not have any + // equalities. + assert(srcLocalVarCst.getNumEqualities() == 0 && + destLocalVarCst.getNumEqualities() == 0); + // Add inequalities from srcLocalVarCst and destLocalVarCst into the + // dependence domain. + SmallVector<int64_t, 8> ineq(dependenceDomain->getNumCols()); + for (unsigned r = 0, e = srcLocalVarCst.getNumInequalities(); r < e; r++) { + std::fill(ineq.begin(), ineq.end(), 0); + + // Set identifier coefficients from src local var constraints. + for (unsigned j = 0, e = srcOperands.size(); j < e; ++j) + ineq[valuePosMap.getSrcDimOrSymPos(srcOperands[j])] = + srcLocalVarCst.atIneq(r, j); + // Local terms. 
+ for (unsigned j = 0, e = srcNumLocalIds; j < e; j++) + ineq[newLocalIdOffset + j] = srcLocalVarCst.atIneq(r, srcNumIds + j); + // Set constant term. + ineq[ineq.size() - 1] = + srcLocalVarCst.atIneq(r, srcLocalVarCst.getNumCols() - 1); + dependenceDomain->addInequality(ineq); + } + + for (unsigned r = 0, e = destLocalVarCst.getNumInequalities(); r < e; r++) { + std::fill(ineq.begin(), ineq.end(), 0); + // Set identifier coefficients from dest local var constraints. + for (unsigned j = 0, e = dstOperands.size(); j < e; ++j) + ineq[valuePosMap.getDstDimOrSymPos(dstOperands[j])] = + destLocalVarCst.atIneq(r, j); + // Local terms. + for (unsigned j = 0, e = dstNumLocalIds; j < e; j++) + ineq[newLocalIdOffset + numSrcLocalIds + j] = + destLocalVarCst.atIneq(r, dstNumIds + j); + // Set constant term. + ineq[ineq.size() - 1] = + destLocalVarCst.atIneq(r, destLocalVarCst.getNumCols() - 1); + + dependenceDomain->addInequality(ineq); + } + return success(); +} + +// Returns the number of outer loop common to 'src/dstDomain'. +// Loops common to 'src/dst' domains are added to 'commonLoops' if non-null. +static unsigned +getNumCommonLoops(const FlatAffineConstraints &srcDomain, + const FlatAffineConstraints &dstDomain, + SmallVectorImpl<AffineForOp> *commonLoops = nullptr) { + // Find the number of common loops shared by src and dst accesses. + unsigned minNumLoops = + std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds()); + unsigned numCommonLoops = 0; + for (unsigned i = 0; i < minNumLoops; ++i) { + if (!isForInductionVar(srcDomain.getIdValue(i)) || + !isForInductionVar(dstDomain.getIdValue(i)) || + srcDomain.getIdValue(i) != dstDomain.getIdValue(i)) + break; + if (commonLoops != nullptr) + commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i))); + ++numCommonLoops; + } + if (commonLoops != nullptr) + assert(commonLoops->size() == numCommonLoops); + return numCommonLoops; +} + +// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'. +static Block *getCommonBlock(const MemRefAccess &srcAccess, + const MemRefAccess &dstAccess, + const FlatAffineConstraints &srcDomain, + unsigned numCommonLoops) { + if (numCommonLoops == 0) { + auto *block = srcAccess.opInst->getBlock(); + while (!llvm::isa<FuncOp>(block->getParentOp())) { + block = block->getParentOp()->getBlock(); + } + return block; + } + auto commonForValue = srcDomain.getIdValue(numCommonLoops - 1); + auto forOp = getForInductionVarOwner(commonForValue); + assert(forOp && "commonForValue was not an induction variable"); + return forOp.getBody(); +} + +// Returns true if the ancestor operation of 'srcAccess' appears before the +// ancestor operation of 'dstAccess' in the common ancestral block. Returns +// false otherwise. +// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals, +// the function is named 'srcAppearsBeforeDstInCommonBlock'. Note that +// 'numCommonLoops' is the number of contiguous surrounding outer loops. +static bool srcAppearsBeforeDstInAncestralBlock( + const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) { + // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'. + auto *commonBlock = + getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops); + // Check the dominance relationship between the respective ancestors of the + // src and dst in the Block of the innermost among the common loops. 
+ auto *srcInst = commonBlock->findAncestorOpInBlock(*srcAccess.opInst); + assert(srcInst != nullptr); + auto *dstInst = commonBlock->findAncestorOpInBlock(*dstAccess.opInst); + assert(dstInst != nullptr); + + // Determine whether dstInst comes after srcInst. + return srcInst->isBeforeInBlock(dstInst); +} + +// Adds ordering constraints to 'dependenceDomain' based on number of loops +// common to 'src/dstDomain' and requested 'loopDepth'. +// Note that 'loopDepth' cannot exceed the number of common loops plus one. +// EX: Given a loop nest of depth 2 with IVs 'i' and 'j': +// *) If 'loopDepth == 1' then one constraint is added: i' >= i + 1 +// *) If 'loopDepth == 2' then two constraints are added: i == i' and j' > j + 1 +// *) If 'loopDepth == 3' then two constraints are added: i == i' and j == j' +static void addOrderingConstraints(const FlatAffineConstraints &srcDomain, + const FlatAffineConstraints &dstDomain, + unsigned loopDepth, + FlatAffineConstraints *dependenceDomain) { + unsigned numCols = dependenceDomain->getNumCols(); + SmallVector<int64_t, 4> eq(numCols); + unsigned numSrcDims = srcDomain.getNumDimIds(); + unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); + unsigned numCommonLoopConstraints = std::min(numCommonLoops, loopDepth); + for (unsigned i = 0; i < numCommonLoopConstraints; ++i) { + std::fill(eq.begin(), eq.end(), 0); + eq[i] = -1; + eq[i + numSrcDims] = 1; + if (i == loopDepth - 1) { + eq[numCols - 1] = -1; + dependenceDomain->addInequality(eq); + } else { + dependenceDomain->addEquality(eq); + } + } +} + +// Computes distance and direction vectors in 'dependences', by adding +// variables to 'dependenceDomain' which represent the difference of the IVs, +// eliminating all other variables, and reading off distance vectors from +// equality constraints (if possible), and direction vectors from inequalities. +static void computeDirectionVector( + const FlatAffineConstraints &srcDomain, + const FlatAffineConstraints &dstDomain, unsigned loopDepth, + FlatAffineConstraints *dependenceDomain, + SmallVector<DependenceComponent, 2> *dependenceComponents) { + // Find the number of common loops shared by src and dst accesses. + SmallVector<AffineForOp, 4> commonLoops; + unsigned numCommonLoops = + getNumCommonLoops(srcDomain, dstDomain, &commonLoops); + if (numCommonLoops == 0) + return; + // Compute direction vectors for requested loop depth. + unsigned numIdsToEliminate = dependenceDomain->getNumIds(); + // Add new variables to 'dependenceDomain' to represent the direction + // constraints for each shared loop. + for (unsigned j = 0; j < numCommonLoops; ++j) { + dependenceDomain->addDimId(j); + } + + // Add equality constraints for each common loop, setting newly introduced + // variable at column 'j' to the 'dst' IV minus the 'src IV. + SmallVector<int64_t, 4> eq; + eq.resize(dependenceDomain->getNumCols()); + unsigned numSrcDims = srcDomain.getNumDimIds(); + // Constraint variables format: + // [num-common-loops][num-src-dim-ids][num-dst-dim-ids][num-symbols][constant] + for (unsigned j = 0; j < numCommonLoops; ++j) { + std::fill(eq.begin(), eq.end(), 0); + eq[j] = 1; + eq[j + numCommonLoops] = 1; + eq[j + numCommonLoops + numSrcDims] = -1; + dependenceDomain->addEquality(eq); + } + + // Eliminate all variables other than the direction variables just added. + dependenceDomain->projectOut(numCommonLoops, numIdsToEliminate); + + // Scan each common loop variable column and set direction vectors based + // on eliminated constraint system. 
+ dependenceComponents->resize(numCommonLoops); + for (unsigned j = 0; j < numCommonLoops; ++j) { + (*dependenceComponents)[j].op = commonLoops[j].getOperation(); + auto lbConst = dependenceDomain->getConstantLowerBound(j); + (*dependenceComponents)[j].lb = + lbConst.getValueOr(std::numeric_limits<int64_t>::min()); + auto ubConst = dependenceDomain->getConstantUpperBound(j); + (*dependenceComponents)[j].ub = + ubConst.getValueOr(std::numeric_limits<int64_t>::max()); + } +} + +// Populates 'accessMap' with composition of AffineApplyOps reachable from +// indices of MemRefAccess. +void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const { + // Get affine map from AffineLoad/Store. + AffineMap map; + if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) + map = loadOp.getAffineMap(); + else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) + map = storeOp.getAffineMap(); + SmallVector<Value, 8> operands(indices.begin(), indices.end()); + fullyComposeAffineMapAndOperands(&map, &operands); + map = simplifyAffineMap(map); + canonicalizeMapAndOperands(&map, &operands); + accessMap->reset(map, operands); +} + +// Builds a flat affine constraint system to check if there exists a dependence +// between memref accesses 'srcAccess' and 'dstAccess'. +// Returns 'NoDependence' if the accesses can be definitively shown not to +// access the same element. +// Returns 'HasDependence' if the accesses do access the same element. +// Returns 'Failure' if an error or unsupported case was encountered. +// If a dependence exists, returns in 'dependenceComponents' a direction +// vector for the dependence, with a component for each loop IV in loops +// common to both accesses (see Dependence in AffineAnalysis.h for details). +// +// The memref access dependence check is comprised of the following steps: +// *) Compute access functions for each access. Access functions are computed +// using AffineValueMaps initialized with the indices from an access, then +// composed with AffineApplyOps reachable from operands of that access, +// until operands of the AffineValueMap are loop IVs or symbols. +// *) Build iteration domain constraints for each access. Iteration domain +// constraints are pairs of inequality constraints representing the +// upper/lower loop bounds for each AffineForOp in the loop nest associated +// with each access. +// *) Build dimension and symbol position maps for each access, which map +// Values from access functions and iteration domains to their position +// in the merged constraint system built by this method. +// +// This method builds a constraint system with the following column format: +// +// [src-dim-identifiers, dst-dim-identifiers, symbols, constant] +// +// For example, given the following MLIR code with "source" and "destination" +// accesses to the same memref label, and symbols %M, %N, %K: +// +// affine.for %i0 = 0 to 100 { +// affine.for %i1 = 0 to 50 { +// %a0 = affine.apply +// (d0, d1) -> (d0 * 2 - d1 * 4 + s1, d1 * 3 - s0) (%i0, %i1)[%M, %N] +// // Source memref access. +// store %v0, %m[%a0#0, %a0#1] : memref<4x4xf32> +// } +// } +// +// affine.for %i2 = 0 to 100 { +// affine.for %i3 = 0 to 50 { +// %a1 = affine.apply +// (d0, d1) -> (d0 * 7 + d1 * 9 - s1, d1 * 11 + s0) (%i2, %i3)[%K, %M] +// // Destination memref access. 
+// %v1 = load %m[%a1#0, %a1#1] : memref<4x4xf32> +// } +// } +// +// The access functions would be the following: +// +// src: (%i0 * 2 - %i1 * 4 + %N, %i1 * 3 - %M) +// dst: (%i2 * 7 + %i3 * 9 - %M, %i3 * 11 - %K) +// +// The iteration domains for the src/dst accesses would be the following: +// +// src: 0 <= %i0 <= 100, 0 <= %i1 <= 50 +// dst: 0 <= %i2 <= 100, 0 <= %i3 <= 50 +// +// The symbols by both accesses would be assigned to a canonical position order +// which will be used in the dependence constraint system: +// +// symbol name: %M %N %K +// symbol pos: 0 1 2 +// +// Equality constraints are built by equating each result of src/destination +// access functions. For this example, the following two equality constraints +// will be added to the dependence constraint system: +// +// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] +// 2 -4 -7 -9 1 1 0 0 = 0 +// 0 3 0 -11 -1 0 1 0 = 0 +// +// Inequality constraints from the iteration domain will be meged into +// the dependence constraint system +// +// [src_dim0, src_dim1, dst_dim0, dst_dim1, sym0, sym1, sym2, const] +// 1 0 0 0 0 0 0 0 >= 0 +// -1 0 0 0 0 0 0 100 >= 0 +// 0 1 0 0 0 0 0 0 >= 0 +// 0 -1 0 0 0 0 0 50 >= 0 +// 0 0 1 0 0 0 0 0 >= 0 +// 0 0 -1 0 0 0 0 100 >= 0 +// 0 0 0 1 0 0 0 0 >= 0 +// 0 0 0 -1 0 0 0 50 >= 0 +// +// +// TODO(andydavis) Support AffineExprs mod/floordiv/ceildiv. +DependenceResult mlir::checkMemrefAccessDependence( + const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, + SmallVector<DependenceComponent, 2> *dependenceComponents, bool allowRAR) { + LLVM_DEBUG(llvm::dbgs() << "Checking for dependence at depth: " + << Twine(loopDepth) << " between:\n";); + LLVM_DEBUG(srcAccess.opInst->dump();); + LLVM_DEBUG(dstAccess.opInst->dump();); + + // Return 'NoDependence' if these accesses do not access the same memref. + if (srcAccess.memref != dstAccess.memref) + return DependenceResult::NoDependence; + + // Return 'NoDependence' if one of these accesses is not an AffineStoreOp. + if (!allowRAR && !isa<AffineStoreOp>(srcAccess.opInst) && + !isa<AffineStoreOp>(dstAccess.opInst)) + return DependenceResult::NoDependence; + + // Get composed access function for 'srcAccess'. + AffineValueMap srcAccessMap; + srcAccess.getAccessMap(&srcAccessMap); + + // Get composed access function for 'dstAccess'. + AffineValueMap dstAccessMap; + dstAccess.getAccessMap(&dstAccessMap); + + // Get iteration domain for the 'srcAccess' operation. + FlatAffineConstraints srcDomain; + if (failed(getInstIndexSet(srcAccess.opInst, &srcDomain))) + return DependenceResult::Failure; + + // Get iteration domain for 'dstAccess' operation. + FlatAffineConstraints dstDomain; + if (failed(getInstIndexSet(dstAccess.opInst, &dstDomain))) + return DependenceResult::Failure; + + // Return 'NoDependence' if loopDepth > numCommonLoops and if the ancestor + // operation of 'srcAccess' does not properly dominate the ancestor + // operation of 'dstAccess' in the same common operation block. + // Note: this check is skipped if 'allowRAR' is true, because because RAR + // deps can exist irrespective of lexicographic ordering b/w src and dst. 
+ unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); + assert(loopDepth <= numCommonLoops + 1); + if (!allowRAR && loopDepth > numCommonLoops && + !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain, + numCommonLoops)) { + return DependenceResult::NoDependence; + } + // Build dim and symbol position maps for each access from access operand + // Value to position in merged constraint system. + ValuePositionMap valuePosMap; + buildDimAndSymbolPositionMaps(srcDomain, dstDomain, srcAccessMap, + dstAccessMap, &valuePosMap, + dependenceConstraints); + + initDependenceConstraints(srcDomain, dstDomain, srcAccessMap, dstAccessMap, + valuePosMap, dependenceConstraints); + + assert(valuePosMap.getNumDims() == + srcDomain.getNumDimIds() + dstDomain.getNumDimIds()); + + // Create memref access constraint by equating src/dst access functions. + // Note that this check is conservative, and will fail in the future when + // local variables for mod/div exprs are supported. + if (failed(addMemRefAccessConstraints(srcAccessMap, dstAccessMap, valuePosMap, + dependenceConstraints))) + return DependenceResult::Failure; + + // Add 'src' happens before 'dst' ordering constraints. + addOrderingConstraints(srcDomain, dstDomain, loopDepth, + dependenceConstraints); + // Add src and dst domain constraints. + addDomainConstraints(srcDomain, dstDomain, valuePosMap, + dependenceConstraints); + + // Return 'NoDependence' if the solution space is empty: no dependence. + if (dependenceConstraints->isEmpty()) { + return DependenceResult::NoDependence; + } + + // Compute dependence direction vector and return true. + if (dependenceComponents != nullptr) { + computeDirectionVector(srcDomain, dstDomain, loopDepth, + dependenceConstraints, dependenceComponents); + } + + LLVM_DEBUG(llvm::dbgs() << "Dependence polyhedron:\n"); + LLVM_DEBUG(dependenceConstraints->dump()); + return DependenceResult::HasDependence; +} + +/// Gathers dependence components for dependences between all ops in loop nest +/// rooted at 'forOp' at loop depths in range [1, maxLoopDepth]. +void mlir::getDependenceComponents( + AffineForOp forOp, unsigned maxLoopDepth, + std::vector<SmallVector<DependenceComponent, 2>> *depCompsVec) { + // Collect all load and store ops in loop nest rooted at 'forOp'. + SmallVector<Operation *, 8> loadAndStoreOpInsts; + forOp.getOperation()->walk([&](Operation *opInst) { + if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst)) + loadAndStoreOpInsts.push_back(opInst); + }); + + unsigned numOps = loadAndStoreOpInsts.size(); + for (unsigned d = 1; d <= maxLoopDepth; ++d) { + for (unsigned i = 0; i < numOps; ++i) { + auto *srcOpInst = loadAndStoreOpInsts[i]; + MemRefAccess srcAccess(srcOpInst); + for (unsigned j = 0; j < numOps; ++j) { + auto *dstOpInst = loadAndStoreOpInsts[j]; + MemRefAccess dstAccess(dstOpInst); + + FlatAffineConstraints dependenceConstraints; + SmallVector<DependenceComponent, 2> depComps; + // TODO(andydavis,bondhugula) Explore whether it would be profitable + // to pre-compute and store deps instead of repeatedly checking. 
+ DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, d, &dependenceConstraints, &depComps); + if (hasDependence(result)) + depCompsVec->push_back(depComps); + } + } + } +} diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp new file mode 100644 index 00000000000..78a869884ee --- /dev/null +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -0,0 +1,2854 @@ +//===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Structures for affine/polyhedral analysis of MLIR functions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/MathExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "affine-structures" + +using namespace mlir; +using llvm::SmallDenseMap; +using llvm::SmallDenseSet; + +namespace { + +// See comments for SimpleAffineExprFlattener. +// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording +// constraint information associated with mod's, floordiv's, and ceildiv's +// in FlatAffineConstraints 'localVarCst'. +struct AffineExprFlattener : public SimpleAffineExprFlattener { +public: + // Constraints connecting newly introduced local variables (for mod's and + // div's) to existing (dimensional and symbolic) ones. These are always + // inequalities. + FlatAffineConstraints localVarCst; + + AffineExprFlattener(unsigned nDims, unsigned nSymbols, MLIRContext *ctx) + : SimpleAffineExprFlattener(nDims, nSymbols) { + localVarCst.reset(nDims, nSymbols, /*numLocals=*/0); + } + +private: + // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). + // The local identifier added is always a floordiv of a pure add/mul affine + // function of other identifiers, coefficients of which are specified in + // `dividend' and with respect to the positive constant `divisor'. localExpr + // is the simplified tree expression (AffineExpr) corresponding to the + // quantifier. + void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor, + AffineExpr localExpr) override { + SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); + // Update localVarCst. + localVarCst.addLocalFloorDiv(dividend, divisor); + } +}; + +} // end anonymous namespace + +// Flattens the expressions in map. Returns failure if 'expr' was unable to be +// flattened (i.e., semi-affine expressions not handled yet). +static LogicalResult +getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, + unsigned numSymbols, + std::vector<SmallVector<int64_t, 8>> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (exprs.empty()) { + localVarCst->reset(numDims, numSymbols); + return success(); + } + + AffineExprFlattener flattener(numDims, numSymbols, exprs[0].getContext()); + // Use the same flattener to simplify each expression successively. 
This way + // local identifiers / expressions are shared. + for (auto expr : exprs) { + if (!expr.isPureAffine()) + return failure(); + + flattener.walkPostOrder(expr); + } + + assert(flattener.operandExprStack.size() == exprs.size()); + flattenedExprs->clear(); + flattenedExprs->assign(flattener.operandExprStack.begin(), + flattener.operandExprStack.end()); + + if (localVarCst) { + localVarCst->clearAndCopyFrom(flattener.localVarCst); + } + + return success(); +} + +// Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to +// be flattened (semi-affine expressions not handled yet). +LogicalResult +mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, + unsigned numSymbols, + SmallVectorImpl<int64_t> *flattenedExpr, + FlatAffineConstraints *localVarCst) { + std::vector<SmallVector<int64_t, 8>> flattenedExprs; + LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols, + &flattenedExprs, localVarCst); + *flattenedExpr = flattenedExprs[0]; + return ret; +} + +/// Flattens the expressions in map. Returns failure if 'expr' was unable to be +/// flattened (i.e., semi-affine expressions not handled yet). +LogicalResult mlir::getFlattenedAffineExprs( + AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (map.getNumResults() == 0) { + localVarCst->reset(map.getNumDims(), map.getNumSymbols()); + return success(); + } + return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(), + map.getNumSymbols(), flattenedExprs, + localVarCst); +} + +LogicalResult mlir::getFlattenedAffineExprs( + IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs, + FlatAffineConstraints *localVarCst) { + if (set.getNumConstraints() == 0) { + localVarCst->reset(set.getNumDims(), set.getNumSymbols()); + return success(); + } + return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(), + set.getNumSymbols(), flattenedExprs, + localVarCst); +} + +//===----------------------------------------------------------------------===// +// MutableAffineMap. +//===----------------------------------------------------------------------===// + +MutableAffineMap::MutableAffineMap(AffineMap map) + : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()), + // A map always has at least 1 result by construction + context(map.getResult(0).getContext()) { + for (auto result : map.getResults()) + results.push_back(result); +} + +void MutableAffineMap::reset(AffineMap map) { + results.clear(); + numDims = map.getNumDims(); + numSymbols = map.getNumSymbols(); + // A map always has at least 1 result by construction + context = map.getResult(0).getContext(); + for (auto result : map.getResults()) + results.push_back(result); +} + +bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { + if (results[idx].isMultipleOf(factor)) + return true; + + // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to + // complete this (for a more powerful analysis). + return false; +} + +// Simplifies the result affine expressions of this map. The expressions have to +// be pure for the simplification implemented. +void MutableAffineMap::simplify() { + // Simplify each of the results if possible. 
+ // TODO(ntv): functional-style map + for (unsigned i = 0, e = getNumResults(); i < e; i++) { + results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols); + } +} + +AffineMap MutableAffineMap::getAffineMap() const { + return AffineMap::get(numDims, numSymbols, results); +} + +MutableIntegerSet::MutableIntegerSet(IntegerSet set, MLIRContext *context) + : numDims(set.getNumDims()), numSymbols(set.getNumSymbols()) { + // TODO(bondhugula) +} + +// Universal set. +MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols, + MLIRContext *context) + : numDims(numDims), numSymbols(numSymbols) {} + +//===----------------------------------------------------------------------===// +// AffineValueMap. +//===----------------------------------------------------------------------===// + +AffineValueMap::AffineValueMap(AffineMap map, ArrayRef<Value> operands, + ArrayRef<Value> results) + : map(map), operands(operands.begin(), operands.end()), + results(results.begin(), results.end()) {} + +AffineValueMap::AffineValueMap(AffineApplyOp applyOp) + : map(applyOp.getAffineMap()), + operands(applyOp.operand_begin(), applyOp.operand_end()) { + results.push_back(applyOp.getResult()); +} + +AffineValueMap::AffineValueMap(AffineBound bound) + : map(bound.getMap()), + operands(bound.operand_begin(), bound.operand_end()) {} + +void AffineValueMap::reset(AffineMap map, ArrayRef<Value> operands, + ArrayRef<Value> results) { + this->map.reset(map); + this->operands.assign(operands.begin(), operands.end()); + this->results.assign(results.begin(), results.end()); +} + +void AffineValueMap::difference(const AffineValueMap &a, + const AffineValueMap &b, AffineValueMap *res) { + assert(a.getNumResults() == b.getNumResults() && "invalid inputs"); + + // Fully compose A's map + operands. + auto aMap = a.getAffineMap(); + SmallVector<Value, 4> aOperands(a.getOperands().begin(), + a.getOperands().end()); + fullyComposeAffineMapAndOperands(&aMap, &aOperands); + + // Use the affine apply normalizer to get B's map into A's coordinate space. + AffineApplyNormalizer normalizer(aMap, aOperands); + SmallVector<Value, 4> bOperands(b.getOperands().begin(), + b.getOperands().end()); + auto bMap = b.getAffineMap(); + normalizer.normalize(&bMap, &bOperands); + + assert(std::equal(bOperands.begin(), bOperands.end(), + normalizer.getOperands().begin()) && + "operands are expected to be the same after normalization"); + + // Construct the difference expressions. + SmallVector<AffineExpr, 4> diffExprs; + diffExprs.reserve(a.getNumResults()); + for (unsigned i = 0, e = bMap.getNumResults(); i < e; ++i) + diffExprs.push_back(normalizer.getAffineMap().getResult(i) - + bMap.getResult(i)); + + auto diffMap = AffineMap::get(normalizer.getNumDims(), + normalizer.getNumSymbols(), diffExprs); + canonicalizeMapAndOperands(&diffMap, &bOperands); + diffMap = simplifyAffineMap(diffMap); + res->reset(diffMap, bOperands); +} + +// Returns true and sets 'indexOfMatch' if 'valueToMatch' is found in +// 'valuesToSearch' beginning at 'indexStart'. Returns false otherwise. 
+static bool findIndex(Value valueToMatch, ArrayRef<Value> valuesToSearch, + unsigned indexStart, unsigned *indexOfMatch) { + unsigned size = valuesToSearch.size(); + for (unsigned i = indexStart; i < size; ++i) { + if (valueToMatch == valuesToSearch[i]) { + *indexOfMatch = i; + return true; + } + } + return false; +} + +inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const { + return map.isMultipleOf(idx, factor); +} + +/// This method uses the invariant that operands are always positionally aligned +/// with the AffineDimExpr in the underlying AffineMap. +bool AffineValueMap::isFunctionOf(unsigned idx, Value value) const { + unsigned index; + if (!findIndex(value, operands, /*indexStart=*/0, &index)) { + return false; + } + auto expr = const_cast<AffineValueMap *>(this)->getAffineMap().getResult(idx); + // TODO(ntv): this is better implemented on a flattened representation. + // At least for now it is conservative. + return expr.isFunctionOfDim(index); +} + +Value AffineValueMap::getOperand(unsigned i) const { + return static_cast<Value>(operands[i]); +} + +ArrayRef<Value> AffineValueMap::getOperands() const { + return ArrayRef<Value>(operands); +} + +AffineMap AffineValueMap::getAffineMap() const { return map.getAffineMap(); } + +AffineValueMap::~AffineValueMap() {} + +//===----------------------------------------------------------------------===// +// FlatAffineConstraints. +//===----------------------------------------------------------------------===// + +// Copy constructor. +FlatAffineConstraints::FlatAffineConstraints( + const FlatAffineConstraints &other) { + numReservedCols = other.numReservedCols; + numDims = other.getNumDimIds(); + numSymbols = other.getNumSymbolIds(); + numIds = other.getNumIds(); + + auto otherIds = other.getIds(); + ids.reserve(numReservedCols); + ids.append(otherIds.begin(), otherIds.end()); + + unsigned numReservedEqualities = other.getNumReservedEqualities(); + unsigned numReservedInequalities = other.getNumReservedInequalities(); + + equalities.reserve(numReservedEqualities * numReservedCols); + inequalities.reserve(numReservedInequalities * numReservedCols); + + for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { + addInequality(other.getInequality(r)); + } + for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { + addEquality(other.getEquality(r)); + } +} + +// Clones this object. +std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const { + return std::make_unique<FlatAffineConstraints>(*this); +} + +// Construct from an IntegerSet. +FlatAffineConstraints::FlatAffineConstraints(IntegerSet set) + : numReservedCols(set.getNumInputs() + 1), + numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()), + numSymbols(set.getNumSymbols()) { + equalities.reserve(set.getNumEqualities() * numReservedCols); + inequalities.reserve(set.getNumInequalities() * numReservedCols); + ids.resize(numIds, None); + + // Flatten expressions and add them to the constraint system. 
+ std::vector<SmallVector<int64_t, 8>> flatExprs; + FlatAffineConstraints localVarCst; + if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) { + assert(false && "flattening unimplemented for semi-affine integer sets"); + return; + } + assert(flatExprs.size() == set.getNumConstraints()); + for (unsigned l = 0, e = localVarCst.getNumLocalIds(); l < e; l++) { + addLocalId(getNumLocalIds()); + } + + for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { + const auto &flatExpr = flatExprs[i]; + assert(flatExpr.size() == getNumCols()); + if (set.getEqFlags()[i]) { + addEquality(flatExpr); + } else { + addInequality(flatExpr); + } + } + // Add the other constraints involving local id's from flattening. + append(localVarCst); +} + +void FlatAffineConstraints::reset(unsigned numReservedInequalities, + unsigned numReservedEqualities, + unsigned newNumReservedCols, + unsigned newNumDims, unsigned newNumSymbols, + unsigned newNumLocals, + ArrayRef<Value> idArgs) { + assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 && + "minimum 1 column"); + numReservedCols = newNumReservedCols; + numDims = newNumDims; + numSymbols = newNumSymbols; + numIds = numDims + numSymbols + newNumLocals; + assert(idArgs.empty() || idArgs.size() == numIds); + + clearConstraints(); + if (numReservedEqualities >= 1) + equalities.reserve(newNumReservedCols * numReservedEqualities); + if (numReservedInequalities >= 1) + inequalities.reserve(newNumReservedCols * numReservedInequalities); + if (idArgs.empty()) { + ids.resize(numIds, None); + } else { + ids.assign(idArgs.begin(), idArgs.end()); + } +} + +void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols, + unsigned newNumLocals, + ArrayRef<Value> idArgs) { + reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims, + newNumSymbols, newNumLocals, idArgs); +} + +void FlatAffineConstraints::append(const FlatAffineConstraints &other) { + assert(other.getNumCols() == getNumCols()); + assert(other.getNumDimIds() == getNumDimIds()); + assert(other.getNumSymbolIds() == getNumSymbolIds()); + + inequalities.reserve(inequalities.size() + + other.getNumInequalities() * numReservedCols); + equalities.reserve(equalities.size() + + other.getNumEqualities() * numReservedCols); + + for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) { + addInequality(other.getInequality(r)); + } + for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) { + addEquality(other.getEquality(r)); + } +} + +void FlatAffineConstraints::addLocalId(unsigned pos) { + addId(IdKind::Local, pos); +} + +void FlatAffineConstraints::addDimId(unsigned pos, Value id) { + addId(IdKind::Dimension, pos, id); +} + +void FlatAffineConstraints::addSymbolId(unsigned pos, Value id) { + addId(IdKind::Symbol, pos, id); +} + +/// Adds a dimensional identifier. The added column is initialized to +/// zero. +void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value id) { + if (kind == IdKind::Dimension) { + assert(pos <= getNumDimIds()); + } else if (kind == IdKind::Symbol) { + assert(pos <= getNumSymbolIds()); + } else { + assert(pos <= getNumLocalIds()); + } + + unsigned oldNumReservedCols = numReservedCols; + + // Check if a resize is necessary. 
+ if (getNumCols() + 1 > numReservedCols) { + equalities.resize(getNumEqualities() * (getNumCols() + 1)); + inequalities.resize(getNumInequalities() * (getNumCols() + 1)); + numReservedCols++; + } + + int absolutePos; + + if (kind == IdKind::Dimension) { + absolutePos = pos; + numDims++; + } else if (kind == IdKind::Symbol) { + absolutePos = pos + getNumDimIds(); + numSymbols++; + } else { + absolutePos = pos + getNumDimIds() + getNumSymbolIds(); + } + numIds++; + + // Note that getNumCols() now will already return the new size, which will be + // at least one. + int numInequalities = static_cast<int>(getNumInequalities()); + int numEqualities = static_cast<int>(getNumEqualities()); + int numCols = static_cast<int>(getNumCols()); + for (int r = numInequalities - 1; r >= 0; r--) { + for (int c = numCols - 2; c >= 0; c--) { + if (c < absolutePos) + atIneq(r, c) = inequalities[r * oldNumReservedCols + c]; + else + atIneq(r, c + 1) = inequalities[r * oldNumReservedCols + c]; + } + atIneq(r, absolutePos) = 0; + } + + for (int r = numEqualities - 1; r >= 0; r--) { + for (int c = numCols - 2; c >= 0; c--) { + // All values in column absolutePositions < absolutePos have the same + // coordinates in the 2-d view of the coefficient buffer. + if (c < absolutePos) + atEq(r, c) = equalities[r * oldNumReservedCols + c]; + else + // Those at absolutePosition >= absolutePos, get a shifted + // absolutePosition. + atEq(r, c + 1) = equalities[r * oldNumReservedCols + c]; + } + // Initialize added dimension to zero. + atEq(r, absolutePos) = 0; + } + + // If an 'id' is provided, insert it; otherwise use None. + if (id) { + ids.insert(ids.begin() + absolutePos, id); + } else { + ids.insert(ids.begin() + absolutePos, None); + } + assert(ids.size() == getNumIds()); +} + +/// Checks if two constraint systems are in the same space, i.e., if they are +/// associated with the same set of identifiers, appearing in the same order. +static bool areIdsAligned(const FlatAffineConstraints &A, + const FlatAffineConstraints &B) { + return A.getNumDimIds() == B.getNumDimIds() && + A.getNumSymbolIds() == B.getNumSymbolIds() && + A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds()); +} + +/// Calls areIdsAligned to check if two constraint systems have the same set +/// of identifiers in the same order. +bool FlatAffineConstraints::areIdsAlignedWithOther( + const FlatAffineConstraints &other) { + return areIdsAligned(*this, other); +} + +/// Checks if the SSA values associated with `cst''s identifiers are unique. +static bool LLVM_ATTRIBUTE_UNUSED +areIdsUnique(const FlatAffineConstraints &cst) { + SmallPtrSet<Value, 8> uniqueIds; + for (auto id : cst.getIds()) { + if (id.hasValue() && !uniqueIds.insert(id.getValue()).second) + return false; + } + return true; +} + +// Swap the posA^th identifier with the posB^th identifier. 
+static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) { + assert(posA < A->getNumIds() && "invalid position A"); + assert(posB < A->getNumIds() && "invalid position B"); + + if (posA == posB) + return; + + for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) { + std::swap(A->atIneq(r, posA), A->atIneq(r, posB)); + } + for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) { + std::swap(A->atEq(r, posA), A->atEq(r, posB)); + } + std::swap(A->getId(posA), A->getId(posB)); +} + +/// Merge and align the identifiers of A and B starting at 'offset', so that +/// both constraint systems get the union of the contained identifiers that is +/// dimension-wise and symbol-wise unique; both constraint systems are updated +/// so that they have the union of all identifiers, with A's original +/// identifiers appearing first followed by any of B's identifiers that didn't +/// appear in A. Local identifiers of each system are by design separate/local +/// and are placed one after other (A's followed by B's). +// Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M]) +// Output: both A, B have (%i, %j, %k) [%M, %N, %P] +// +static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A, + FlatAffineConstraints *B) { + assert(offset <= A->getNumDimIds() && offset <= B->getNumDimIds()); + // A merge/align isn't meaningful if a cst's ids aren't distinct. + assert(areIdsUnique(*A) && "A's id values aren't unique"); + assert(areIdsUnique(*B) && "B's id values aren't unique"); + + assert(std::all_of(A->getIds().begin() + offset, + A->getIds().begin() + A->getNumDimAndSymbolIds(), + [](Optional<Value> id) { return id.hasValue(); })); + + assert(std::all_of(B->getIds().begin() + offset, + B->getIds().begin() + B->getNumDimAndSymbolIds(), + [](Optional<Value> id) { return id.hasValue(); })); + + // Place local id's of A after local id's of B. + for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) { + B->addLocalId(0); + } + for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e; + t++) { + A->addLocalId(A->getNumLocalIds()); + } + + SmallVector<Value, 4> aDimValues, aSymValues; + A->getIdValues(offset, A->getNumDimIds(), &aDimValues); + A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues); + { + // Merge dims from A into B. + unsigned d = offset; + for (auto aDimValue : aDimValues) { + unsigned loc; + if (B->findId(*aDimValue, &loc)) { + assert(loc >= offset && "A's dim appears in B's aligned range"); + assert(loc < B->getNumDimIds() && + "A's dim appears in B's non-dim position"); + swapId(B, d, loc); + } else { + B->addDimId(d); + B->setIdValue(d, aDimValue); + } + d++; + } + + // Dimensions that are in B, but not in A, are added at the end. + for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) { + A->addDimId(A->getNumDimIds()); + A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t)); + } + } + { + // Merge symbols: merge A's symbols into B first. + unsigned s = B->getNumDimIds(); + for (auto aSymValue : aSymValues) { + unsigned loc; + if (B->findId(*aSymValue, &loc)) { + assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() && + "A's symbol appears in B's non-symbol position"); + swapId(B, s, loc); + } else { + B->addSymbolId(s - B->getNumDimIds()); + B->setIdValue(s, aSymValue); + } + s++; + } + // Symbols that are in B, but not in A, are added at the end. 
+ for (unsigned t = A->getNumDimAndSymbolIds(), + e = B->getNumDimAndSymbolIds(); + t < e; t++) { + A->addSymbolId(A->getNumSymbolIds()); + A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t)); + } + } + assert(areIdsAligned(*A, *B) && "IDs expected to be aligned"); +} + +// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'. +void FlatAffineConstraints::mergeAndAlignIdsWithOther( + unsigned offset, FlatAffineConstraints *other) { + mergeAndAlignIds(offset, this, other); +} + +// This routine may add additional local variables if the flattened expression +// corresponding to the map has such variables due to mod's, ceildiv's, and +// floordiv's in it. +LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) { + std::vector<SmallVector<int64_t, 8>> flatExprs; + FlatAffineConstraints localCst; + if (failed(getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, + &localCst))) { + LLVM_DEBUG(llvm::dbgs() + << "composition unimplemented for semi-affine maps\n"); + return failure(); + } + assert(flatExprs.size() == vMap->getNumResults()); + + // Add localCst information. + if (localCst.getNumLocalIds() > 0) { + localCst.setIdValues(0, /*end=*/localCst.getNumDimAndSymbolIds(), + /*values=*/vMap->getOperands()); + // Align localCst and this. + mergeAndAlignIds(/*offset=*/0, &localCst, this); + // Finally, append localCst to this constraint set. + append(localCst); + } + + // Add dimensions corresponding to the map's results. + for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) { + // TODO: Consider using a batched version to add a range of IDs. + addDimId(0); + } + + // We add one equality for each result connecting the result dim of the map to + // the other identifiers. + // For eg: if the expression is 16*i0 + i1, and this is the r^th + // iteration/result of the value map, we are adding the equality: + // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we + // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. + for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { + const auto &flatExpr = flatExprs[r]; + assert(flatExpr.size() >= vMap->getNumOperands() + 1); + + // eqToAdd is the equality corresponding to the flattened affine expression. + SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); + // Set the coefficient for this result to one. + eqToAdd[r] = 1; + + // Dims and symbols. + for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { + unsigned loc; + bool ret = findId(*vMap->getOperand(i), &loc); + assert(ret && "value map's id can't be found"); + (void)ret; + // Negate 'eq[r]' since the newly added dimension will be set to this one. + eqToAdd[loc] = -flatExpr[i]; + } + // Local vars common to eq and localCst are at the beginning. + unsigned j = getNumDimIds() + getNumSymbolIds(); + unsigned end = flatExpr.size() - 1; + for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) { + eqToAdd[j] = -flatExpr[i]; + } + + // Constant term. + eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; + + // Add the equality connecting the result of the map to this constraint set. + addEquality(eqToAdd); + } + + return success(); +} + +// Similar to composeMap except that no Value's need be associated with the +// constraint system nor are they looked at -- since the dimensions and +// symbols of 'other' are expected to correspond 1:1 to 'this' system. It +// is thus not convenient to share code with composeMap. 
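+// For example, composing the matching map (d0)[s0] -> (d0 + s0, d0 floordiv 2)
+// with a system over (d0)[s0] prepends two result dimensions r0, r1 and adds
+// the equalities r0 - d0 - s0 == 0 and r1 - q == 0, where q is the local
+// identifier introduced for the floordiv (i.e., 2*q <= d0 <= 2*q + 1).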
+LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) { + assert(other.getNumDims() == getNumDimIds() && "dim mismatch"); + assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch"); + + std::vector<SmallVector<int64_t, 8>> flatExprs; + FlatAffineConstraints localCst; + if (failed(getFlattenedAffineExprs(other, &flatExprs, &localCst))) { + LLVM_DEBUG(llvm::dbgs() + << "composition unimplemented for semi-affine maps\n"); + return failure(); + } + assert(flatExprs.size() == other.getNumResults()); + + // Add localCst information. + if (localCst.getNumLocalIds() > 0) { + // Place local id's of A after local id's of B. + for (unsigned l = 0, e = localCst.getNumLocalIds(); l < e; l++) { + addLocalId(0); + } + // Finally, append localCst to this constraint set. + append(localCst); + } + + // Add dimensions corresponding to the map's results. + for (unsigned t = 0, e = other.getNumResults(); t < e; t++) { + addDimId(0); + } + + // We add one equality for each result connecting the result dim of the map to + // the other identifiers. + // For eg: if the expression is 16*i0 + i1, and this is the r^th + // iteration/result of the value map, we are adding the equality: + // d_r - 16*i0 - i1 = 0. Hence, when flattening say (i0 + 1, i0 + 8*i2), we + // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. + for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { + const auto &flatExpr = flatExprs[r]; + assert(flatExpr.size() >= other.getNumInputs() + 1); + + // eqToAdd is the equality corresponding to the flattened affine expression. + SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0); + // Set the coefficient for this result to one. + eqToAdd[r] = 1; + + // Dims and symbols. + for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) { + // Negate 'eq[r]' since the newly added dimension will be set to this one. + eqToAdd[e + i] = -flatExpr[i]; + } + // Local vars common to eq and localCst are at the beginning. + unsigned j = getNumDimIds() + getNumSymbolIds(); + unsigned end = flatExpr.size() - 1; + for (unsigned i = other.getNumInputs(); i < end; i++, j++) { + eqToAdd[j] = -flatExpr[i]; + } + + // Constant term. + eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1]; + + // Add the equality connecting the result of the map to this constraint set. + addEquality(eqToAdd); + } + + return success(); +} + +// Turn a dimension into a symbol. +static void turnDimIntoSymbol(FlatAffineConstraints *cst, Value id) { + unsigned pos; + if (cst->findId(id, &pos) && pos < cst->getNumDimIds()) { + swapId(cst, pos, cst->getNumDimIds() - 1); + cst->setDimSymbolSeparation(cst->getNumSymbolIds() + 1); + } +} + +// Turn a symbol into a dimension. +static void turnSymbolIntoDim(FlatAffineConstraints *cst, Value id) { + unsigned pos; + if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() && + pos < cst->getNumDimAndSymbolIds()) { + swapId(cst, pos, cst->getNumDimIds()); + cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1); + } +} + +// Changes all symbol identifiers which are loop IVs to dim identifiers. +void FlatAffineConstraints::convertLoopIVSymbolsToDims() { + // Gather all symbols which are loop IVs. + SmallVector<Value, 4> loopIVs; + for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) { + if (ids[i].hasValue() && getForInductionVarOwner(ids[i].getValue())) + loopIVs.push_back(ids[i].getValue()); + } + // Turn each symbol in 'loopIVs' into a dim identifier. 
+ for (auto iv : loopIVs) { + turnSymbolIntoDim(this, *iv); + } +} + +void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) { + if (containsId(*id)) + return; + + // Caller is expected to fully compose map/operands if necessary. + assert((isTopLevelValue(id) || isForInductionVar(id)) && + "non-terminal symbol / loop IV expected"); + // Outer loop IVs could be used in forOp's bounds. + if (auto loop = getForInductionVarOwner(id)) { + addDimId(getNumDimIds(), id); + if (failed(this->addAffineForOpDomain(loop))) + LLVM_DEBUG( + loop.emitWarning("failed to add domain info to constraint system")); + return; + } + // Add top level symbol. + addSymbolId(getNumSymbolIds(), id); + // Check if the symbol is a constant. + if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id->getDefiningOp())) + setIdToConstant(*id, constOp.getValue()); +} + +LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) { + unsigned pos; + // Pre-condition for this method. + if (!findId(*forOp.getInductionVar(), &pos)) { + assert(false && "Value not found"); + return failure(); + } + + int64_t step = forOp.getStep(); + if (step != 1) { + if (!forOp.hasConstantLowerBound()) + forOp.emitWarning("domain conservatively approximated"); + else { + // Add constraints for the stride. + // (iv - lb) % step = 0 can be written as: + // (iv - lb) - step * q = 0 where q = (iv - lb) / step. + // Add local variable 'q' and add the above equality. + // The first constraint is q = (iv - lb) floordiv step + SmallVector<int64_t, 8> dividend(getNumCols(), 0); + int64_t lb = forOp.getConstantLowerBound(); + dividend[pos] = 1; + dividend.back() -= lb; + addLocalFloorDiv(dividend, step); + // Second constraint: (iv - lb) - step * q = 0. + SmallVector<int64_t, 8> eq(getNumCols(), 0); + eq[pos] = 1; + eq.back() -= lb; + // For the local var just added above. + eq[getNumCols() - 2] = -step; + addEquality(eq); + } + } + + if (forOp.hasConstantLowerBound()) { + addConstantLowerBound(pos, forOp.getConstantLowerBound()); + } else { + // Non-constant lower bound case. + SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands().begin(), + forOp.getLowerBoundOperands().end()); + if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(), lbOperands, + /*eq=*/false, /*lower=*/true))) + return failure(); + } + + if (forOp.hasConstantUpperBound()) { + addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1); + return success(); + } + // Non-constant upper bound case. + SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands().begin(), + forOp.getUpperBoundOperands().end()); + return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(), ubOperands, + /*eq=*/false, /*lower=*/false); +} + +// Searches for a constraint with a non-zero coefficient at 'colIdx' in +// equality (isEq=true) or inequality (isEq=false) constraints. +// Returns true and sets row found in search in 'rowIdx'. +// Returns false otherwise. +static bool +findConstraintWithNonZeroAt(const FlatAffineConstraints &constraints, + unsigned colIdx, bool isEq, unsigned *rowIdx) { + auto at = [&](unsigned rowIdx) -> int64_t { + return isEq ? constraints.atEq(rowIdx, colIdx) + : constraints.atIneq(rowIdx, colIdx); + }; + unsigned e = + isEq ? 
constraints.getNumEqualities() : constraints.getNumInequalities(); + for (*rowIdx = 0; *rowIdx < e; ++(*rowIdx)) { + if (at(*rowIdx) != 0) { + return true; + } + } + return false; +} + +// Normalizes the coefficient values across all columns in 'rowIDx' by their +// GCD in equality or inequality constraints as specified by 'isEq'. +template <bool isEq> +static void normalizeConstraintByGCD(FlatAffineConstraints *constraints, + unsigned rowIdx) { + auto at = [&](unsigned colIdx) -> int64_t { + return isEq ? constraints->atEq(rowIdx, colIdx) + : constraints->atIneq(rowIdx, colIdx); + }; + uint64_t gcd = std::abs(at(0)); + for (unsigned j = 1, e = constraints->getNumCols(); j < e; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(at(j))); + } + if (gcd > 0 && gcd != 1) { + for (unsigned j = 0, e = constraints->getNumCols(); j < e; ++j) { + int64_t v = at(j) / static_cast<int64_t>(gcd); + isEq ? constraints->atEq(rowIdx, j) = v + : constraints->atIneq(rowIdx, j) = v; + } + } +} + +void FlatAffineConstraints::normalizeConstraintsByGCD() { + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + normalizeConstraintByGCD</*isEq=*/true>(this, i); + } + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + normalizeConstraintByGCD</*isEq=*/false>(this, i); + } +} + +bool FlatAffineConstraints::hasConsistentState() const { + if (inequalities.size() != getNumInequalities() * numReservedCols) + return false; + if (equalities.size() != getNumEqualities() * numReservedCols) + return false; + if (ids.size() != getNumIds()) + return false; + + // Catches errors where numDims, numSymbols, numIds aren't consistent. + if (numDims > numIds || numSymbols > numIds || numDims + numSymbols > numIds) + return false; + + return true; +} + +/// Checks all rows of equality/inequality constraints for trivial +/// contradictions (for example: 1 == 0, 0 >= 1), which may have surfaced +/// after elimination. Returns 'true' if an invalid constraint is found; +/// 'false' otherwise. +bool FlatAffineConstraints::hasInvalidConstraint() const { + assert(hasConsistentState()); + auto check = [&](bool isEq) -> bool { + unsigned numCols = getNumCols(); + unsigned numRows = isEq ? getNumEqualities() : getNumInequalities(); + for (unsigned i = 0, e = numRows; i < e; ++i) { + unsigned j; + for (j = 0; j < numCols - 1; ++j) { + int64_t v = isEq ? atEq(i, j) : atIneq(i, j); + // Skip rows with non-zero variable coefficients. + if (v != 0) + break; + } + if (j < numCols - 1) { + continue; + } + // Check validity of constant term at 'numCols - 1' w.r.t 'isEq'. + // Example invalid constraints include: '1 == 0' or '-1 >= 0' + int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1); + if ((isEq && v != 0) || (!isEq && v < 0)) { + return true; + } + } + return false; + }; + if (check(/*isEq=*/true)) + return true; + return check(/*isEq=*/false); +} + +// Eliminate identifier from constraint at 'rowIdx' based on coefficient at +// pivotRow, pivotCol. Columns in range [elimColStart, pivotCol) will not be +// updated as they have already been eliminated. +static void eliminateFromConstraint(FlatAffineConstraints *constraints, + unsigned rowIdx, unsigned pivotRow, + unsigned pivotCol, unsigned elimColStart, + bool isEq) { + // Skip if equality 'rowIdx' if same as 'pivotRow'. + if (isEq && rowIdx == pivotRow) + return; + auto at = [&](unsigned i, unsigned j) -> int64_t { + return isEq ? 
constraints->atEq(i, j) : constraints->atIneq(i, j); + }; + int64_t leadCoeff = at(rowIdx, pivotCol); + // Skip if leading coefficient at 'rowIdx' is already zero. + if (leadCoeff == 0) + return; + int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol); + int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1; + int64_t lcm = mlir::lcm(pivotCoeff, leadCoeff); + int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff)); + int64_t rowMultiplier = lcm / std::abs(leadCoeff); + + unsigned numCols = constraints->getNumCols(); + for (unsigned j = 0; j < numCols; ++j) { + // Skip updating column 'j' if it was just eliminated. + if (j >= elimColStart && j < pivotCol) + continue; + int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) + + rowMultiplier * at(rowIdx, j); + isEq ? constraints->atEq(rowIdx, j) = v + : constraints->atIneq(rowIdx, j) = v; + } +} + +// Remove coefficients in column range [colStart, colLimit) in place. +// This removes in data in the specified column range, and copies any +// remaining valid data into place. +static void shiftColumnsToLeft(FlatAffineConstraints *constraints, + unsigned colStart, unsigned colLimit, + bool isEq) { + assert(colLimit <= constraints->getNumIds()); + if (colLimit <= colStart) + return; + + unsigned numCols = constraints->getNumCols(); + unsigned numRows = isEq ? constraints->getNumEqualities() + : constraints->getNumInequalities(); + unsigned numToEliminate = colLimit - colStart; + for (unsigned r = 0, e = numRows; r < e; ++r) { + for (unsigned c = colLimit; c < numCols; ++c) { + if (isEq) { + constraints->atEq(r, c - numToEliminate) = constraints->atEq(r, c); + } else { + constraints->atIneq(r, c - numToEliminate) = constraints->atIneq(r, c); + } + } + } +} + +// Removes identifiers in column range [idStart, idLimit), and copies any +// remaining valid data into place, and updates member variables. +void FlatAffineConstraints::removeIdRange(unsigned idStart, unsigned idLimit) { + assert(idLimit < getNumCols() && "invalid id limit"); + + if (idStart >= idLimit) + return; + + // We are going to be removing one or more identifiers from the range. + assert(idStart < numIds && "invalid idStart position"); + + // TODO(andydavis) Make 'removeIdRange' a lambda called from here. + // Remove eliminated identifiers from equalities. + shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/true); + + // Remove eliminated identifiers from inequalities. + shiftColumnsToLeft(this, idStart, idLimit, /*isEq=*/false); + + // Update members numDims, numSymbols and numIds. + unsigned numDimsEliminated = 0; + unsigned numLocalsEliminated = 0; + unsigned numColsEliminated = idLimit - idStart; + if (idStart < numDims) { + numDimsEliminated = std::min(numDims, idLimit) - idStart; + } + // Check how many local id's were removed. Note that our identifier order is + // [dims, symbols, locals]. Local id start at position numDims + numSymbols. + if (idLimit > numDims + numSymbols) { + numLocalsEliminated = std::min( + idLimit - std::max(idStart, numDims + numSymbols), getNumLocalIds()); + } + unsigned numSymbolsEliminated = + numColsEliminated - numDimsEliminated - numLocalsEliminated; + + numDims -= numDimsEliminated; + numSymbols -= numSymbolsEliminated; + numIds = numIds - numColsEliminated; + + ids.erase(ids.begin() + idStart, ids.begin() + idLimit); + + // No resize necessary. numReservedCols remains the same. 
+} + +/// Returns the position of the identifier that has the minimum <number of lower +/// bounds> times <number of upper bounds> from the specified range of +/// identifiers [start, end). It is often best to eliminate in the increasing +/// order of these counts when doing Fourier-Motzkin elimination since FM adds +/// that many new constraints. +static unsigned getBestIdToEliminate(const FlatAffineConstraints &cst, + unsigned start, unsigned end) { + assert(start < cst.getNumIds() && end < cst.getNumIds() + 1); + + auto getProductOfNumLowerUpperBounds = [&](unsigned pos) { + unsigned numLb = 0; + unsigned numUb = 0; + for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { + if (cst.atIneq(r, pos) > 0) { + ++numLb; + } else if (cst.atIneq(r, pos) < 0) { + ++numUb; + } + } + return numLb * numUb; + }; + + unsigned minLoc = start; + unsigned min = getProductOfNumLowerUpperBounds(start); + for (unsigned c = start + 1; c < end; c++) { + unsigned numLbUbProduct = getProductOfNumLowerUpperBounds(c); + if (numLbUbProduct < min) { + min = numLbUbProduct; + minLoc = c; + } + } + return minLoc; +} + +// Checks for emptiness of the set by eliminating identifiers successively and +// using the GCD test (on all equality constraints) and checking for trivially +// invalid constraints. Returns 'true' if the constraint system is found to be +// empty; false otherwise. +bool FlatAffineConstraints::isEmpty() const { + if (isEmptyByGCDTest() || hasInvalidConstraint()) + return true; + + // First, eliminate as many identifiers as possible using Gaussian + // elimination. + FlatAffineConstraints tmpCst(*this); + unsigned currentPos = 0; + while (currentPos < tmpCst.getNumIds()) { + tmpCst.gaussianEliminateIds(currentPos, tmpCst.getNumIds()); + ++currentPos; + // We check emptiness through trivial checks after eliminating each ID to + // detect emptiness early. Since the checks isEmptyByGCDTest() and + // hasInvalidConstraint() are linear time and single sweep on the constraint + // buffer, this appears reasonable - but can optimize in the future. + if (tmpCst.hasInvalidConstraint() || tmpCst.isEmptyByGCDTest()) + return true; + } + + // Eliminate the remaining using FM. + for (unsigned i = 0, e = tmpCst.getNumIds(); i < e; i++) { + tmpCst.FourierMotzkinEliminate( + getBestIdToEliminate(tmpCst, 0, tmpCst.getNumIds())); + // Check for a constraint explosion. This rarely happens in practice, but + // this check exists as a safeguard against improperly constructed + // constraint systems or artificially created arbitrarily complex systems + // that aren't the intended use case for FlatAffineConstraints. This is + // needed since FM has a worst case exponential complexity in theory. + if (tmpCst.getNumConstraints() >= kExplosionFactor * getNumIds()) { + LLVM_DEBUG(llvm::dbgs() << "FM constraint explosion detected\n"); + return false; + } + + // FM wouldn't have modified the equalities in any way. So no need to again + // run GCD test. Check for trivial invalid constraints. + if (tmpCst.hasInvalidConstraint()) + return true; + } + return false; +} + +// Runs the GCD test on all equality constraints. Returns 'true' if this test +// fails on any equality. Returns 'false' otherwise. +// This test can be used to disprove the existence of a solution. If it returns +// true, no integer solution to the equality constraints can exist. +// +// GCD test definition: +// +// The equality constraint: +// +// c_1*x_1 + c_2*x_2 + ... 
+ c_n*x_n = c_0 +// +// has an integer solution iff: +// +// GCD of c_1, c_2, ..., c_n divides c_0. +// +bool FlatAffineConstraints::isEmptyByGCDTest() const { + assert(hasConsistentState()); + unsigned numCols = getNumCols(); + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + uint64_t gcd = std::abs(atEq(i, 0)); + for (unsigned j = 1; j < numCols - 1; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atEq(i, j))); + } + int64_t v = std::abs(atEq(i, numCols - 1)); + if (gcd > 0 && (v % gcd != 0)) { + return true; + } + } + return false; +} + +/// Tightens inequalities given that we are dealing with integer spaces. This is +/// analogous to the GCD test but applied to inequalities. The constant term can +/// be reduced to the preceding multiple of the GCD of the coefficients, i.e., +/// 64*i - 100 >= 0 => 64*i - 128 >= 0 (since 'i' is an integer). This is a +/// fast method - linear in the number of coefficients. +// Example on how this affects practical cases: consider the scenario: +// 64*i >= 100, j = 64*i; without a tightening, elimination of i would yield +// j >= 100 instead of the tighter (exact) j >= 128. +void FlatAffineConstraints::GCDTightenInequalities() { + unsigned numCols = getNumCols(); + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + uint64_t gcd = std::abs(atIneq(i, 0)); + for (unsigned j = 1; j < numCols - 1; ++j) { + gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(atIneq(i, j))); + } + if (gcd > 0 && gcd != 1) { + int64_t gcdI = static_cast<int64_t>(gcd); + // Tighten the constant term and normalize the constraint by the GCD. + atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcdI); + for (unsigned j = 0, e = numCols - 1; j < e; ++j) + atIneq(i, j) /= gcdI; + } + } +} + +// Eliminates all identifier variables in column range [posStart, posLimit). +// Returns the number of variables eliminated. +unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart, + unsigned posLimit) { + // Return if identifier positions to eliminate are out of range. + assert(posLimit <= numIds); + assert(hasConsistentState()); + + if (posStart >= posLimit) + return 0; + + GCDTightenInequalities(); + + unsigned pivotCol = 0; + for (pivotCol = posStart; pivotCol < posLimit; ++pivotCol) { + // Find a row which has a non-zero coefficient in column 'j'. + unsigned pivotRow; + if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/true, + &pivotRow)) { + // No pivot row in equalities with non-zero at 'pivotCol'. + if (!findConstraintWithNonZeroAt(*this, pivotCol, /*isEq=*/false, + &pivotRow)) { + // If inequalities are also non-zero in 'pivotCol', it can be + // eliminated. + continue; + } + break; + } + + // Eliminate identifier at 'pivotCol' from each equality row. + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + /*isEq=*/true); + normalizeConstraintByGCD</*isEq=*/true>(this, i); + } + + // Eliminate identifier at 'pivotCol' from each inequality row. + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + eliminateFromConstraint(this, i, pivotRow, pivotCol, posStart, + /*isEq=*/false); + normalizeConstraintByGCD</*isEq=*/false>(this, i); + } + removeEquality(pivotRow); + GCDTightenInequalities(); + } + // Update position limit based on number eliminated. + posLimit = pivotCol; + // Remove eliminated columns from all constraints. 
+ removeIdRange(posStart, posLimit); + return posLimit - posStart; +} + +// Detect the identifier at 'pos' (say id_r) as modulo of another identifier +// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q) +// could be detected as the floordiv of n. For eg: +// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=> +// id_r = id_n mod 4, id_q = id_n floordiv 4. +// lbConst and ubConst are the constant lower and upper bounds for 'pos' - +// pre-detected at the caller. +static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos, + int64_t lbConst, int64_t ubConst, + SmallVectorImpl<AffineExpr> *memo) { + assert(pos < cst.getNumIds() && "invalid position"); + + // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to + // id_n - divisor * id_q. If these are true, then id_n becomes the dividend + // and id_q the quotient when dividing id_n by the divisor. + + if (lbConst != 0 || ubConst < 1) + return false; + + int64_t divisor = ubConst + 1; + + // Now check for: id_r = id_n - divisor * id_q. As an example, we + // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0. + unsigned seenQuotient = 0, seenDividend = 0; + int quotientPos = -1, dividendPos = -1; + for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { + // id_n should have coeff 1 or -1. + if (std::abs(cst.atEq(r, pos)) != 1) + continue; + // constant term should be 0. + if (cst.atEq(r, cst.getNumCols() - 1) != 0) + continue; + unsigned c, f; + int quotientSign = 1, dividendSign = 1; + for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) { + if (c == pos) + continue; + // The coefficient of the quotient should be +/-divisor. + // TODO(bondhugula): could be extended to detect an affine function for + // the quotient (i.e., the coeff could be a non-zero multiple of divisor). + int64_t v = cst.atEq(r, c) * cst.atEq(r, pos); + if (v == divisor || v == -divisor) { + seenQuotient++; + quotientPos = c; + quotientSign = v > 0 ? 1 : -1; + } + // The coefficient of the dividend should be +/-1. + // TODO(bondhugula): could be extended to detect an affine function of + // the other identifiers as the dividend. + else if (v == -1 || v == 1) { + seenDividend++; + dividendPos = c; + dividendSign = v < 0 ? 1 : -1; + } else if (cst.atEq(r, c) != 0) { + // Cannot be inferred as a mod since the constraint has a coefficient + // for an identifier that's neither a unit nor the divisor (see TODOs + // above). + break; + } + } + if (c < f) + // Cannot be inferred as a mod since the constraint has a coefficient for + // an identifier that's neither a unit nor the divisor (see TODOs above). + continue; + + // We are looking for exactly one identifier as the dividend. + if (seenDividend == 1 && seenQuotient >= 1) { + if (!(*memo)[dividendPos]) + return false; + // Successfully detected a mod. + (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign; + auto ub = cst.getConstantUpperBound(dividendPos); + if (ub.hasValue() && ub.getValue() < divisor) + // The mod can be optimized away. + (*memo)[pos] = (*memo)[dividendPos] * dividendSign; + else + (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign; + + if (seenQuotient == 1 && !(*memo)[quotientPos]) + // Successfully detected a floordiv as well. + (*memo)[quotientPos] = + (*memo)[dividendPos].floorDiv(divisor) * quotientSign; + return true; + } + } + return false; +} + +// Gather lower and upper bounds for the pos^th identifier. 
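+// For example, with 'pos' pointing at i, the row i - 3 >= 0 (coefficient +1 on
+// i) is collected as a lower bound, while -i + 10 >= 0 (coefficient -1 on i)
+// is collected as an upper bound.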
+static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst, + unsigned pos, + SmallVectorImpl<unsigned> *lbIndices, + SmallVectorImpl<unsigned> *ubIndices) { + assert(pos < cst.getNumIds() && "invalid position"); + + // Gather all lower bounds and upper bounds of the variable. Since the + // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower + // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) { + if (cst.atIneq(r, pos) >= 1) { + // Lower bound. + lbIndices->push_back(r); + } else if (cst.atIneq(r, pos) <= -1) { + // Upper bound. + ubIndices->push_back(r); + } + } +} + +// Check if the pos^th identifier can be expressed as a floordiv of an affine +// function of other identifiers (where the divisor is a positive constant). +// For eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. +bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos, + SmallVectorImpl<AffineExpr> *memo, MLIRContext *context) { + assert(pos < cst.getNumIds() && "invalid position"); + + SmallVector<unsigned, 4> lbIndices, ubIndices; + getLowerAndUpperBoundIndices(cst, pos, &lbIndices, &ubIndices); + + // Check if any lower bound, upper bound pair is of the form: + // divisor * id >= expr - (divisor - 1) <-- Lower bound for 'id' + // divisor * id <= expr <-- Upper bound for 'id' + // Then, 'id' is equivalent to 'expr floordiv divisor'. (where divisor > 1). + // + // For example, if -32*k + 16*i + j >= 0 + // 32*k - 16*i - j + 31 >= 0 <=> + // k = ( 16*i + j ) floordiv 32 + unsigned seenDividends = 0; + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + // Check if lower bound's constant term is 'divisor - 1'. The 'divisor' + // here is cst.atIneq(lbPos, pos) and we already know that it's positive + // (since cst.Ineq(lbPos, ...) is a lower bound expression for 'pos'. + if (cst.atIneq(lbPos, cst.getNumCols() - 1) != cst.atIneq(lbPos, pos) - 1) + continue; + // Check if upper bound's constant term is 0. + if (cst.atIneq(ubPos, cst.getNumCols() - 1) != 0) + continue; + // For the remaining part, check if the lower bound expr's coeff's are + // negations of corresponding upper bound ones'. + unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (cst.atIneq(lbPos, c) != -cst.atIneq(ubPos, c)) + break; + if (c != pos && cst.atIneq(lbPos, c) != 0) + seenDividends++; + } + // Lb coeff's aren't negative of ub coeff's (for the non constant term + // part). + if (c < f) + continue; + if (seenDividends >= 1) { + // The divisor is the constant term of the lower bound expression. + // We already know that cst.atIneq(lbPos, pos) > 0. + int64_t divisor = cst.atIneq(lbPos, pos); + // Construct the dividend expression. + auto dividendExpr = getAffineConstantExpr(0, context); + unsigned c, f; + for (c = 0, f = cst.getNumCols() - 1; c < f; c++) { + if (c == pos) + continue; + int64_t ubVal = cst.atIneq(ubPos, c); + if (ubVal == 0) + continue; + if (!(*memo)[c]) + break; + dividendExpr = dividendExpr + ubVal * (*memo)[c]; + } + // Expression can't be constructed as it depends on a yet unknown + // identifier. + // TODO(mlir-team): Visit/compute the identifiers in an order so that + // this doesn't happen. More complex but much more efficient. + if (c < f) + continue; + // Successfully detected the floordiv. + (*memo)[pos] = dividendExpr.floorDiv(divisor); + return true; + } + } + } + return false; +} + +// Fills an inequality row with the value 'val'. 
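+// Used by removeRedundantInequalities below to zero out rows found redundant
+// before they are compacted away in place.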
+static inline void fillInequality(FlatAffineConstraints *cst, unsigned r, + int64_t val) { + for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { + cst->atIneq(r, c) = val; + } +} + +// Negates an inequality. +static inline void negateInequality(FlatAffineConstraints *cst, unsigned r) { + for (unsigned c = 0, f = cst->getNumCols(); c < f; c++) { + cst->atIneq(r, c) = -cst->atIneq(r, c); + } +} + +// A more complex check to eliminate redundant inequalities. Uses FourierMotzkin +// to check if a constraint is redundant. +void FlatAffineConstraints::removeRedundantInequalities() { + SmallVector<bool, 32> redun(getNumInequalities(), false); + // To check if an inequality is redundant, we replace the inequality by its + // complement (for eg., i - 1 >= 0 by i <= 0), and check if the resulting + // system is empty. If it is, the inequality is redundant. + FlatAffineConstraints tmpCst(*this); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + // Change the inequality to its complement. + negateInequality(&tmpCst, r); + tmpCst.atIneq(r, tmpCst.getNumCols() - 1)--; + if (tmpCst.isEmpty()) { + redun[r] = true; + // Zero fill the redundant inequality. + fillInequality(this, r, /*val=*/0); + fillInequality(&tmpCst, r, /*val=*/0); + } else { + // Reverse the change (to avoid recreating tmpCst each time). + tmpCst.atIneq(r, tmpCst.getNumCols() - 1)++; + negateInequality(&tmpCst, r); + } + } + + // Scan to get rid of all rows marked redundant, in-place. + auto copyRow = [&](unsigned src, unsigned dest) { + if (src == dest) + return; + for (unsigned c = 0, e = getNumCols(); c < e; c++) { + atIneq(dest, c) = atIneq(src, c); + } + }; + unsigned pos = 0; + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (!redun[r]) + copyRow(r, pos++); + } + inequalities.resize(numReservedCols * pos); +} + +std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound( + unsigned pos, unsigned offset, unsigned num, unsigned symStartPos, + ArrayRef<AffineExpr> localExprs, MLIRContext *context) const { + assert(pos + offset < getNumDimIds() && "invalid dim start pos"); + assert(symStartPos >= (pos + offset) && "invalid sym start pos"); + assert(getNumLocalIds() == localExprs.size() && + "incorrect local exprs count"); + + SmallVector<unsigned, 4> lbIndices, ubIndices; + getLowerAndUpperBoundIndices(*this, pos + offset, &lbIndices, &ubIndices); + + /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos). + auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) { + b.clear(); + for (unsigned i = 0, e = a.size(); i < e; ++i) { + if (i < offset || i >= offset + num) + b.push_back(a[i]); + } + }; + + SmallVector<int64_t, 8> lb, ub; + SmallVector<AffineExpr, 4> exprs; + unsigned dimCount = symStartPos - num; + unsigned symCount = getNumDimAndSymbolIds() - symStartPos; + exprs.reserve(lbIndices.size()); + // Lower bound expressions. + for (auto idx : lbIndices) { + auto ineq = getInequality(idx); + // Extract the lower bound (in terms of other coeff's + const), i.e., if + // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j + // - 1. + addCoeffs(ineq, lb); + std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>()); + auto expr = mlir::toAffineExpr(lb, dimCount, symCount, localExprs, context); + exprs.push_back(expr); + } + auto lbMap = + exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs); + + exprs.clear(); + exprs.reserve(ubIndices.size()); + // Upper bound expressions. 
+ for (auto idx : ubIndices) { + auto ineq = getInequality(idx); + // Extract the upper bound (in terms of other coeff's + const). + addCoeffs(ineq, ub); + auto expr = mlir::toAffineExpr(ub, dimCount, symCount, localExprs, context); + // Upper bound is exclusive. + exprs.push_back(expr + 1); + } + auto ubMap = + exprs.empty() ? AffineMap() : AffineMap::get(dimCount, symCount, exprs); + + return {lbMap, ubMap}; +} + +/// Computes the lower and upper bounds of the first 'num' dimensional +/// identifiers (starting at 'offset') as affine maps of the remaining +/// identifiers (dimensional and symbolic identifiers). Local identifiers are +/// themselves explicitly computed as affine functions of other identifiers in +/// this process if needed. +void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num, + MLIRContext *context, + SmallVectorImpl<AffineMap> *lbMaps, + SmallVectorImpl<AffineMap> *ubMaps) { + assert(num < getNumDimIds() && "invalid range"); + + // Basic simplification. + normalizeConstraintsByGCD(); + + LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num + << " identifiers\n"); + LLVM_DEBUG(dump()); + + // Record computed/detected identifiers. + SmallVector<AffineExpr, 8> memo(getNumIds()); + // Initialize dimensional and symbolic identifiers. + for (unsigned i = 0, e = getNumDimIds(); i < e; i++) { + if (i < offset) + memo[i] = getAffineDimExpr(i, context); + else if (i >= offset + num) + memo[i] = getAffineDimExpr(i - num, context); + } + for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) + memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context); + + bool changed; + do { + changed = false; + // Identify yet unknown identifiers as constants or mod's / floordiv's of + // other identifiers if possible. + for (unsigned pos = 0; pos < getNumIds(); pos++) { + if (memo[pos]) + continue; + + auto lbConst = getConstantLowerBound(pos); + auto ubConst = getConstantUpperBound(pos); + if (lbConst.hasValue() && ubConst.hasValue()) { + // Detect equality to a constant. + if (lbConst.getValue() == ubConst.getValue()) { + memo[pos] = getAffineConstantExpr(lbConst.getValue(), context); + changed = true; + continue; + } + + // Detect an identifier as modulo of another identifier w.r.t a + // constant. + if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(), + &memo)) { + changed = true; + continue; + } + } + + // Detect an identifier as floordiv of another identifier w.r.t a + // constant. + if (detectAsFloorDiv(*this, pos, &memo, context)) { + changed = true; + continue; + } + + // Detect an identifier as an expression of other identifiers. + unsigned idx; + if (!findConstraintWithNonZeroAt(*this, pos, /*isEq=*/true, &idx)) { + continue; + } + + // Build AffineExpr solving for identifier 'pos' in terms of all others. + auto expr = getAffineConstantExpr(0, context); + unsigned j, e; + for (j = 0, e = getNumIds(); j < e; ++j) { + if (j == pos) + continue; + int64_t c = atEq(idx, j); + if (c == 0) + continue; + // If any of the involved IDs hasn't been found yet, we can't proceed. + if (!memo[j]) + break; + expr = expr + memo[j] * c; + } + if (j < e) + // Can't construct expression as it depends on a yet uncomputed + // identifier. + continue; + + // Add constant term to AffineExpr. + expr = expr + atEq(idx, getNumIds()); + int64_t vPos = atEq(idx, pos); + assert(vPos != 0 && "expected non-zero here"); + if (vPos > 0) + expr = (-expr).floorDiv(vPos); + else + // vPos < 0. 
+ expr = expr.floorDiv(-vPos); + // Successfully constructed expression. + memo[pos] = expr; + changed = true; + } + // This loop is guaranteed to reach a fixed point - since once an + // identifier's explicit form is computed (in memo[pos]), it's not updated + // again. + } while (changed); + + // Set the lower and upper bound maps for all the identifiers that were + // computed as affine expressions of the rest as the "detected expr" and + // "detected expr + 1" respectively; set the undetected ones to null. + Optional<FlatAffineConstraints> tmpClone; + for (unsigned pos = 0; pos < num; pos++) { + unsigned numMapDims = getNumDimIds() - num; + unsigned numMapSymbols = getNumSymbolIds(); + AffineExpr expr = memo[pos + offset]; + if (expr) + expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols); + + AffineMap &lbMap = (*lbMaps)[pos]; + AffineMap &ubMap = (*ubMaps)[pos]; + + if (expr) { + lbMap = AffineMap::get(numMapDims, numMapSymbols, expr); + ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1); + } else { + // TODO(bondhugula): Whenever there are local identifiers in the + // dependence constraints, we'll conservatively over-approximate, since we + // don't always explicitly compute them above (in the while loop). + if (getNumLocalIds() == 0) { + // Work on a copy so that we don't update this constraint system. + if (!tmpClone) { + tmpClone.emplace(FlatAffineConstraints(*this)); + // Removing redundant inequalities is necessary so that we don't get + // redundant loop bounds. + tmpClone->removeRedundantInequalities(); + } + std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound( + pos, offset, num, getNumDimIds(), {}, context); + } + + // If the above fails, we'll just use the constant lower bound and the + // constant upper bound (if they exist) as the slice bounds. + // TODO(b/126426796): being conservative for the moment in cases that + // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is + // fixed (b/126426796). + if (!lbMap || lbMap.getNumResults() > 1) { + LLVM_DEBUG(llvm::dbgs() + << "WARNING: Potentially over-approximating slice lb\n"); + auto lbConst = getConstantLowerBound(pos + offset); + if (lbConst.hasValue()) { + lbMap = AffineMap::get( + numMapDims, numMapSymbols, + getAffineConstantExpr(lbConst.getValue(), context)); + } + } + if (!ubMap || ubMap.getNumResults() > 1) { + LLVM_DEBUG(llvm::dbgs() + << "WARNING: Potentially over-approximating slice ub\n"); + auto ubConst = getConstantUpperBound(pos + offset); + if (ubConst.hasValue()) { + (ubMap) = AffineMap::get( + numMapDims, numMapSymbols, + getAffineConstantExpr(ubConst.getValue() + 1, context)); + } + } + } + LLVM_DEBUG(llvm::dbgs() + << "lb map for pos = " << Twine(pos + offset) << ", expr: "); + LLVM_DEBUG(lbMap.dump();); + LLVM_DEBUG(llvm::dbgs() + << "ub map for pos = " << Twine(pos + offset) << ", expr: "); + LLVM_DEBUG(ubMap.dump();); + } +} + +LogicalResult +FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap, + ArrayRef<Value> boundOperands, + bool eq, bool lower) { + assert(pos < getNumDimAndSymbolIds() && "invalid position"); + // Equality follows the logic of lower bound except that we add an equality + // instead of an inequality. + assert((!eq || boundMap.getNumResults() == 1) && "single result expected"); + if (eq) + lower = true; + + // Fully compose map and operands; canonicalize and simplify so that we + // transitively get to terminal symbols or loop IVs. 
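+  // For instance, a bound operand defined by an affine.apply of outer loop IVs
+  // gets folded into 'map' here, so that the operands used below are the loop
+  // IVs and terminal symbols themselves.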
+ auto map = boundMap; + SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end()); + fullyComposeAffineMapAndOperands(&map, &operands); + map = simplifyAffineMap(map); + canonicalizeMapAndOperands(&map, &operands); + for (auto operand : operands) + addInductionVarOrTerminalSymbol(operand); + + FlatAffineConstraints localVarCst; + std::vector<SmallVector<int64_t, 8>> flatExprs; + if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst))) { + LLVM_DEBUG(llvm::dbgs() << "semi-affine expressions not yet supported\n"); + return failure(); + } + + // Merge and align with localVarCst. + if (localVarCst.getNumLocalIds() > 0) { + // Set values for localVarCst. + localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands); + for (auto operand : operands) { + unsigned pos; + if (findId(*operand, &pos)) { + if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) { + // If the local var cst has this as a dim, turn it into its symbol. + turnDimIntoSymbol(&localVarCst, *operand); + } else if (pos < getNumDimIds()) { + // Or vice versa. + turnSymbolIntoDim(&localVarCst, *operand); + } + } + } + mergeAndAlignIds(/*offset=*/0, this, &localVarCst); + append(localVarCst); + } + + // Record positions of the operands in the constraint system. Need to do + // this here since the constraint system changes after a bound is added. + SmallVector<unsigned, 8> positions; + unsigned numOperands = operands.size(); + for (auto operand : operands) { + unsigned pos; + if (!findId(*operand, &pos)) + assert(0 && "expected to be found"); + positions.push_back(pos); + } + + for (const auto &flatExpr : flatExprs) { + SmallVector<int64_t, 4> ineq(getNumCols(), 0); + ineq[pos] = lower ? 1 : -1; + // Dims and symbols. + for (unsigned j = 0, e = map.getNumInputs(); j < e; j++) { + ineq[positions[j]] = lower ? -flatExpr[j] : flatExpr[j]; + } + // Copy over the local id coefficients. + unsigned numLocalIds = flatExpr.size() - 1 - numOperands; + for (unsigned jj = 0, j = getNumIds() - numLocalIds; jj < numLocalIds; + jj++, j++) { + ineq[j] = + lower ? -flatExpr[numOperands + jj] : flatExpr[numOperands + jj]; + } + // Constant term. + ineq[getNumCols() - 1] = + lower ? -flatExpr[flatExpr.size() - 1] + // Upper bound in flattenedExpr is an exclusive one. + : flatExpr[flatExpr.size() - 1] - 1; + eq ? addEquality(ineq) : addInequality(ineq); + } + return success(); +} + +// Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper +// bounds in 'ubMaps' to each value in `values' that appears in the constraint +// system. Note that both lower/upper bounds share the same operand list +// 'operands'. +// This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and +// skips any null AffineMaps in 'lbMaps' or 'ubMaps'. +// Note that both lower/upper bounds use operands from 'operands'. +// Returns failure for unimplemented cases such as semi-affine expressions or +// expressions with mod/floordiv. 
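+// For example, with lbMaps[i] = (d0) -> (d0), ubMaps[i] = (d0) -> (d0 + 1),
+// values[i] = %iv and operands = {%j} (illustrative names), the pair is
+// recognized as the single-iteration case below and added as the equality
+// %iv == %j.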
+LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values, + ArrayRef<AffineMap> lbMaps, + ArrayRef<AffineMap> ubMaps, + ArrayRef<Value> operands) { + assert(values.size() == lbMaps.size()); + assert(lbMaps.size() == ubMaps.size()); + + for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) { + unsigned pos; + if (!findId(*values[i], &pos)) + continue; + + AffineMap lbMap = lbMaps[i]; + AffineMap ubMap = ubMaps[i]; + assert(!lbMap || lbMap.getNumInputs() == operands.size()); + assert(!ubMap || ubMap.getNumInputs() == operands.size()); + + // Check if this slice is just an equality along this dimension. + if (lbMap && ubMap && lbMap.getNumResults() == 1 && + ubMap.getNumResults() == 1 && + lbMap.getResult(0) + 1 == ubMap.getResult(0)) { + if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true, + /*lower=*/true))) + return failure(); + continue; + } + + if (lbMap && failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false, + /*lower=*/true))) + return failure(); + + if (ubMap && failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false, + /*lower=*/false))) + return failure(); + } + return success(); +} + +void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) { + assert(eq.size() == getNumCols()); + unsigned offset = equalities.size(); + equalities.resize(equalities.size() + numReservedCols); + std::copy(eq.begin(), eq.end(), equalities.begin() + offset); +} + +void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) { + assert(inEq.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::copy(inEq.begin(), inEq.end(), inequalities.begin() + offset); +} + +void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) { + assert(pos < getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + inequalities[offset + pos] = 1; + inequalities[offset + getNumCols() - 1] = -lb; +} + +void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) { + assert(pos < getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + inequalities[offset + pos] = -1; + inequalities[offset + getNumCols() - 1] = ub; +} + +void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr, + int64_t lb) { + assert(expr.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + std::copy(expr.begin(), expr.end(), inequalities.begin() + offset); + inequalities[offset + getNumCols() - 1] += -lb; +} + +void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr, + int64_t ub) { + assert(expr.size() == getNumCols()); + unsigned offset = inequalities.size(); + inequalities.resize(inequalities.size() + numReservedCols); + std::fill(inequalities.begin() + offset, + inequalities.begin() + offset + getNumCols(), 0); + for (unsigned i = 0, e = getNumCols(); i < e; i++) { + inequalities[offset + i] = -expr[i]; + } + inequalities[offset + getNumCols() - 1] += ub; +} + +/// Adds a new local identifier as the floordiv of an affine function of other +/// identifiers, the coefficients of 
which are provided in 'dividend' and with +/// respect to a positive constant 'divisor'. Two constraints are added to the +/// system to capture equivalence with the floordiv. +/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. +void FlatAffineConstraints::addLocalFloorDiv(ArrayRef<int64_t> dividend, + int64_t divisor) { + assert(dividend.size() == getNumCols() && "incorrect dividend size"); + assert(divisor > 0 && "positive divisor expected"); + + addLocalId(getNumLocalIds()); + + // Add two constraints for this new identifier 'q'. + SmallVector<int64_t, 8> bound(dividend.size() + 1); + + // dividend - q * divisor >= 0 + std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1, + bound.begin()); + bound.back() = dividend.back(); + bound[getNumIds() - 1] = -divisor; + addInequality(bound); + + // -dividend +qdivisor * q + divisor - 1 >= 0 + std::transform(bound.begin(), bound.end(), bound.begin(), + std::negate<int64_t>()); + bound[bound.size() - 1] += divisor - 1; + addInequality(bound); +} + +bool FlatAffineConstraints::findId(Value id, unsigned *pos) const { + unsigned i = 0; + for (const auto &mayBeId : ids) { + if (mayBeId.hasValue() && mayBeId.getValue() == id) { + *pos = i; + return true; + } + i++; + } + return false; +} + +bool FlatAffineConstraints::containsId(Value id) const { + return llvm::any_of(ids, [&](const Optional<Value> &mayBeId) { + return mayBeId.hasValue() && mayBeId.getValue() == id; + }); +} + +void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { + assert(newSymbolCount <= numDims + numSymbols && + "invalid separation position"); + numDims = numDims + numSymbols - newSymbolCount; + numSymbols = newSymbolCount; +} + +/// Sets the specified identifier to a constant value. +void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) { + unsigned offset = equalities.size(); + equalities.resize(equalities.size() + numReservedCols); + std::fill(equalities.begin() + offset, + equalities.begin() + offset + getNumCols(), 0); + equalities[offset + pos] = 1; + equalities[offset + getNumCols() - 1] = -val; +} + +/// Sets the specified identifier to a constant value; asserts if the id is not +/// found. +void FlatAffineConstraints::setIdToConstant(Value id, int64_t val) { + unsigned pos; + if (!findId(id, &pos)) + // This is a pre-condition for this method. + assert(0 && "id not found"); + setIdToConstant(pos, val); +} + +void FlatAffineConstraints::removeEquality(unsigned pos) { + unsigned numEqualities = getNumEqualities(); + assert(pos < numEqualities); + unsigned outputIndex = pos * numReservedCols; + unsigned inputIndex = (pos + 1) * numReservedCols; + unsigned numElemsToCopy = (numEqualities - pos - 1) * numReservedCols; + std::copy(equalities.begin() + inputIndex, + equalities.begin() + inputIndex + numElemsToCopy, + equalities.begin() + outputIndex); + equalities.resize(equalities.size() - numReservedCols); +} + +/// Finds an equality that equates the specified identifier to a constant. +/// Returns the position of the equality row. If 'symbolic' is set to true, +/// symbols are also treated like a constant, i.e., an affine function of the +/// symbols is also treated like a constant. +static int findEqualityToConstant(const FlatAffineConstraints &cst, + unsigned pos, bool symbolic = false) { + assert(pos < cst.getNumIds() && "invalid position"); + for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) { + int64_t v = cst.atEq(r, pos); + if (v * v != 1) + continue; + unsigned c; + unsigned f = symbolic ? 
cst.getNumDimIds() : cst.getNumIds(); + // This checks for zeros in all positions other than 'pos' in [0, f) + for (c = 0; c < f; c++) { + if (c == pos) + continue; + if (cst.atEq(r, c) != 0) { + // Dependent on another identifier. + break; + } + } + if (c == f) + // Equality is free of other identifiers. + return r; + } + return -1; +} + +void FlatAffineConstraints::setAndEliminate(unsigned pos, int64_t constVal) { + assert(pos < getNumIds() && "invalid position"); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + atIneq(r, getNumCols() - 1) += atIneq(r, pos) * constVal; + } + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + atEq(r, getNumCols() - 1) += atEq(r, pos) * constVal; + } + removeId(pos); +} + +LogicalResult FlatAffineConstraints::constantFoldId(unsigned pos) { + assert(pos < getNumIds() && "invalid position"); + int rowIdx; + if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) + return failure(); + + // atEq(rowIdx, pos) is either -1 or 1. + assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); + int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); + setAndEliminate(pos, constVal); + return success(); +} + +void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) { + for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { + if (failed(constantFoldId(t))) + t++; + } +} + +/// Returns the extent (upper bound - lower bound) of the specified +/// identifier if it is found to be a constant; returns None if it's not a +/// constant. This methods treats symbolic identifiers specially, i.e., +/// it looks for constant differences between affine expressions involving +/// only the symbolic identifiers. See comments at function definition for +/// example. 'lb', if provided, is set to the lower bound associated with the +/// constant difference. Note that 'lb' is purely symbolic and thus will contain +/// the coefficients of the symbolic identifiers and the constant coefficient. +// Egs: 0 <= i <= 15, return 16. +// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) +// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. +// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb = +// ceil(s0 - 7 / 8) = floor(s0 / 8)). +Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize( + unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *lbFloorDivisor, + SmallVectorImpl<int64_t> *ub) const { + assert(pos < getNumDimIds() && "Invalid identifier position"); + assert(getNumLocalIds() == 0); + + // TODO(bondhugula): eliminate all remaining dimensional identifiers (other + // than the one at 'pos' to make this more powerful. Not needed for + // hyper-rectangular spaces. + + // Find an equality for 'pos'^th identifier that equates it to some function + // of the symbolic identifiers (+ constant). + int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true); + if (eqRow != -1) { + // This identifier can only take a single value. + if (lb) { + // Set lb to the symbolic value. + lb->resize(getNumSymbolIds() + 1); + if (ub) + ub->resize(getNumSymbolIds() + 1); + for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) { + int64_t v = atEq(eqRow, pos); + // atEq(eqRow, pos) is either -1 or 1. + assert(v * v == 1); + (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v + : -atEq(eqRow, getNumDimIds() + c) / v; + // Since this is an equality, ub = lb. 
+ if (ub) + (*ub)[c] = (*lb)[c]; + } + assert(lbFloorDivisor && + "both lb and divisor or none should be provided"); + *lbFloorDivisor = 1; + } + return 1; + } + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) != 0) + break; + } + if (r == e) + // If it doesn't, there isn't a bound on it. + return None; + + // Positions of constraints that are lower/upper bounds on the variable. + SmallVector<unsigned, 4> lbIndices, ubIndices; + + // Gather all symbolic lower bounds and upper bounds of the variable. Since + // the canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a + // lower bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + unsigned c, f; + for (c = 0, f = getNumDimIds(); c < f; c++) { + if (c != pos && atIneq(r, c) != 0) + break; + } + if (c < getNumDimIds()) + // Not a pure symbolic bound. + continue; + if (atIneq(r, pos) >= 1) + // Lower bound. + lbIndices.push_back(r); + else if (atIneq(r, pos) <= -1) + // Upper bound. + ubIndices.push_back(r); + } + + // TODO(bondhugula): eliminate other dimensional identifiers to make this more + // powerful. Not needed for hyper-rectangular iteration spaces. + + Optional<int64_t> minDiff = None; + unsigned minLbPosition, minUbPosition; + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + // Look for a lower bound and an upper bound that only differ by a + // constant, i.e., pairs of the form 0 <= c_pos - f(c_i's) <= diffConst. + // For example, if ii is the pos^th variable, we are looking for + // constraints like ii >= i, ii <= ii + 50, 50 being the difference. The + // minimum among all such constant differences is kept since that's the + // constant bounding the extent of the pos^th variable. + unsigned j, e; + for (j = 0, e = getNumCols() - 1; j < e; j++) + if (atIneq(ubPos, j) != -atIneq(lbPos, j)) { + break; + } + if (j < getNumCols() - 1) + continue; + int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) + + atIneq(lbPos, getNumCols() - 1) + 1, + atIneq(lbPos, pos)); + if (minDiff == None || diff < minDiff) { + minDiff = diff; + minLbPosition = lbPos; + minUbPosition = ubPos; + } + } + } + if (lb && minDiff.hasValue()) { + // Set lb to the symbolic lower bound. + lb->resize(getNumSymbolIds() + 1); + if (ub) + ub->resize(getNumSymbolIds() + 1); + // The lower bound is the ceildiv of the lb constraint over the coefficient + // of the variable at 'pos'. We express the ceildiv equivalently as a floor + // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + + // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). + *lbFloorDivisor = atIneq(minLbPosition, pos); + assert(*lbFloorDivisor == -atIneq(minUbPosition, pos)); + for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) { + (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c); + } + if (ub) { + for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) + (*ub)[c] = atIneq(minUbPosition, getNumDimIds() + c); + } + // The lower bound leads to a ceildiv while the upper bound is a floordiv + // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val + + // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to + // the constant term for the lower bound. 
+ (*lb)[getNumSymbolIds()] += atIneq(minLbPosition, pos) - 1; + } + return minDiff; +} + +template <bool isLower> +Optional<int64_t> +FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) { + assert(pos < getNumIds() && "invalid position"); + // Project to 'pos'. + projectOut(0, pos); + projectOut(1, getNumIds() - 1); + // Check if there's an equality equating the '0'^th identifier to a constant. + int eqRowIdx = findEqualityToConstant(*this, 0, /*symbolic=*/false); + if (eqRowIdx != -1) + // atEq(rowIdx, 0) is either -1 or 1. + return -atEq(eqRowIdx, getNumCols() - 1) / atEq(eqRowIdx, 0); + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, 0) != 0) + break; + } + if (r == e) + // If it doesn't, there isn't a bound on it. + return None; + + Optional<int64_t> minOrMaxConst = None; + + // Take the max across all const lower bounds (or min across all constant + // upper bounds). + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (isLower) { + if (atIneq(r, 0) <= 0) + // Not a lower bound. + continue; + } else if (atIneq(r, 0) >= 0) { + // Not an upper bound. + continue; + } + unsigned c, f; + for (c = 0, f = getNumCols() - 1; c < f; c++) + if (c != 0 && atIneq(r, c) != 0) + break; + if (c < getNumCols() - 1) + // Not a constant bound. + continue; + + int64_t boundConst = + isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0)) + : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0)); + if (isLower) { + if (minOrMaxConst == None || boundConst > minOrMaxConst) + minOrMaxConst = boundConst; + } else { + if (minOrMaxConst == None || boundConst < minOrMaxConst) + minOrMaxConst = boundConst; + } + } + return minOrMaxConst; +} + +Optional<int64_t> +FlatAffineConstraints::getConstantLowerBound(unsigned pos) const { + FlatAffineConstraints tmpCst(*this); + return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos); +} + +Optional<int64_t> +FlatAffineConstraints::getConstantUpperBound(unsigned pos) const { + FlatAffineConstraints tmpCst(*this); + return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos); +} + +// A simple (naive and conservative) check for hyper-rectangularity. +bool FlatAffineConstraints::isHyperRectangular(unsigned pos, + unsigned num) const { + assert(pos < getNumCols() - 1); + // Check for two non-zero coefficients in the range [pos, pos + sum). 
+ for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + unsigned sum = 0; + for (unsigned c = pos; c < pos + num; c++) { + if (atIneq(r, c) != 0) + sum++; + } + if (sum > 1) + return false; + } + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + unsigned sum = 0; + for (unsigned c = pos; c < pos + num; c++) { + if (atEq(r, c) != 0) + sum++; + } + if (sum > 1) + return false; + } + return true; +} + +void FlatAffineConstraints::print(raw_ostream &os) const { + assert(hasConsistentState()); + os << "\nConstraints (" << getNumDimIds() << " dims, " << getNumSymbolIds() + << " symbols, " << getNumLocalIds() << " locals), (" << getNumConstraints() + << " constraints)\n"; + os << "("; + for (unsigned i = 0, e = getNumIds(); i < e; i++) { + if (ids[i] == None) + os << "None "; + else + os << "Value "; + } + os << " const)\n"; + for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) { + for (unsigned j = 0, f = getNumCols(); j < f; ++j) { + os << atEq(i, j) << " "; + } + os << "= 0\n"; + } + for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) { + for (unsigned j = 0, f = getNumCols(); j < f; ++j) { + os << atIneq(i, j) << " "; + } + os << ">= 0\n"; + } + os << '\n'; +} + +void FlatAffineConstraints::dump() const { print(llvm::errs()); } + +/// Removes duplicate constraints, trivially true constraints, and constraints +/// that can be detected as redundant as a result of differing only in their +/// constant term part. A constraint of the form <non-negative constant> >= 0 is +/// considered trivially true. +// Uses a DenseSet to hash and detect duplicates followed by a linear scan to +// remove duplicates in place. +void FlatAffineConstraints::removeTrivialRedundancy() { + SmallDenseSet<ArrayRef<int64_t>, 8> rowSet; + + // A map used to detect redundancy stemming from constraints that only differ + // in their constant term. The value stored is <row position, const term> + // for a given row. + SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>> + rowsWithoutConstTerm; + + // Check if constraint is of the form <non-negative-constant> >= 0. + auto isTriviallyValid = [&](unsigned r) -> bool { + for (unsigned c = 0, e = getNumCols() - 1; c < e; c++) { + if (atIneq(r, c) != 0) + return false; + } + return atIneq(r, getNumCols() - 1) >= 0; + }; + + // Detect and mark redundant constraints. + SmallVector<bool, 256> redunIneq(getNumInequalities(), false); + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + int64_t *rowStart = inequalities.data() + numReservedCols * r; + auto row = ArrayRef<int64_t>(rowStart, getNumCols()); + if (isTriviallyValid(r) || !rowSet.insert(row).second) { + redunIneq[r] = true; + continue; + } + + // Among constraints that only differ in the constant term part, mark + // everything other than the one with the smallest constant term redundant. + // (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the + // former two are redundant). + int64_t constTerm = atIneq(r, getNumCols() - 1); + auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1); + const auto &ret = + rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}}); + if (!ret.second) { + // Check if the other constraint has a higher constant term. + auto &val = ret.first->second; + if (val.second > constTerm) { + // The stored row is redundant. Mark it so, and update with this one. + redunIneq[val.first] = true; + val = {r, constTerm}; + } else { + // The one stored makes this one redundant. 
+ redunIneq[r] = true; + } + } + } + + auto copyRow = [&](unsigned src, unsigned dest) { + if (src == dest) + return; + for (unsigned c = 0, e = getNumCols(); c < e; c++) { + atIneq(dest, c) = atIneq(src, c); + } + }; + + // Scan to get rid of all rows marked redundant, in-place. + unsigned pos = 0; + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (!redunIneq[r]) + copyRow(r, pos++); + } + inequalities.resize(numReservedCols * pos); + + // TODO(bondhugula): consider doing this for equalities as well, but probably + // not worth the savings. +} + +void FlatAffineConstraints::clearAndCopyFrom( + const FlatAffineConstraints &other) { + FlatAffineConstraints copy(other); + std::swap(*this, copy); + assert(copy.getNumIds() == copy.getIds().size()); +} + +void FlatAffineConstraints::removeId(unsigned pos) { + removeIdRange(pos, pos + 1); +} + +static std::pair<unsigned, unsigned> +getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) { + unsigned numDims = cst.getNumDimIds(); + unsigned numSymbols = cst.getNumSymbolIds(); + unsigned newNumDims, newNumSymbols; + if (pos < numDims) { + newNumDims = numDims - 1; + newNumSymbols = numSymbols; + } else if (pos < numDims + numSymbols) { + assert(numSymbols >= 1); + newNumDims = numDims; + newNumSymbols = numSymbols - 1; + } else { + newNumDims = numDims; + newNumSymbols = numSymbols; + } + return {newNumDims, newNumSymbols}; +} + +#undef DEBUG_TYPE +#define DEBUG_TYPE "fm" + +/// Eliminates identifier at the specified position using Fourier-Motzkin +/// variable elimination. This technique is exact for rational spaces but +/// conservative (in "rare" cases) for integer spaces. The operation corresponds +/// to a projection operation yielding the (convex) set of integer points +/// contained in the rational shadow of the set. An emptiness test that relies +/// on this method will guarantee emptiness, i.e., it disproves the existence of +/// a solution if it says it's empty. +/// If a non-null isResultIntegerExact is passed, it is set to true if the +/// result is also integer exact. If it's set to false, the obtained solution +/// *may* not be exact, i.e., it may contain integer points that do not have an +/// integer pre-image in the original set. +/// +/// Eg: +/// j >= 0, j <= i + 1 +/// i >= 0, i <= N + 1 +/// Eliminating i yields, +/// j >= 0, 0 <= N + 1, j - 1 <= N + 1 +/// +/// If darkShadow = true, this method computes the dark shadow on elimination; +/// the dark shadow is a convex integer subset of the exact integer shadow. A +/// non-empty dark shadow proves the existence of an integer solution. The +/// elimination in such a case could however be an under-approximation, and thus +/// should not be used for scanning sets or used by itself for dependence +/// checking. +/// +/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set. +/// ^ +/// | +/// | * * * * o o +/// i | * * o o o o +/// | o * * * * * +/// ---------------> +/// j -> +/// +/// Eliminating i from this system (projecting on the j dimension): +/// rational shadow / integer light shadow: 1 <= j <= 6 +/// dark shadow: 3 <= j <= 6 +/// exact integer shadow: j = 1 \union 3 <= j <= 6 +/// holes/splinters: j = 2 +/// +/// darkShadow = false, isResultIntegerExact = nullptr are default values. +// TODO(bondhugula): a slight modification to yield dark shadow version of FM +// (tightened), which can prove the existence of a solution if there is one. 
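// Before the implementation below, a minimal standalone sketch of the
// lower/upper bound pairing that Fourier-Motzkin performs (plain vector rows
// in the same canonical form, no dark shadow, no GCD tightening); this is an
// illustration, not the FlatAffineConstraints code.
#include <cstdint>
#include <numeric>
#include <vector>

static std::vector<std::vector<int64_t>>
fourierMotzkinSketch(const std::vector<std::vector<int64_t>> &rows,
                     unsigned pos) {
  std::vector<std::vector<int64_t>> out;
  if (rows.empty())
    return out;
  for (const std::vector<int64_t> &ub : rows) {
    if (ub[pos] >= 0)
      continue; // not an upper bound on the eliminated variable
    for (const std::vector<int64_t> &lb : rows) {
      if (lb[pos] <= 0)
        continue; // not a lower bound on the eliminated variable
      int64_t lbCoeff = lb[pos], ubCoeff = -ub[pos];
      int64_t l = std::lcm(lbCoeff, ubCoeff);
      std::vector<int64_t> ineq;
      for (size_t c = 0, e = rows.front().size(); c < e; ++c) {
        if (c == pos)
          continue;
        // Emit the x-free combination lcm/c_u * ub_row + lcm/c_l * lb_row.
        ineq.push_back(ub[c] * (l / ubCoeff) + lb[c] * (l / lbCoeff));
      }
      out.push_back(ineq);
    }
  }
  // Rows with a zero coefficient at 'pos' would simply be copied over.
  return out;
}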
+void FlatAffineConstraints::FourierMotzkinEliminate( + unsigned pos, bool darkShadow, bool *isResultIntegerExact) { + LLVM_DEBUG(llvm::dbgs() << "FM input (eliminate pos " << pos << "):\n"); + LLVM_DEBUG(dump()); + assert(pos < getNumIds() && "invalid position"); + assert(hasConsistentState()); + + // Check if this identifier can be eliminated through a substitution. + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + if (atEq(r, pos) != 0) { + // Use Gaussian elimination here (since we have an equality). + LogicalResult ret = gaussianEliminateId(pos); + (void)ret; + assert(succeeded(ret) && "Gaussian elimination guaranteed to succeed"); + LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n"); + LLVM_DEBUG(dump()); + return; + } + } + + // A fast linear time tightening. + GCDTightenInequalities(); + + // Check if the identifier appears at all in any of the inequalities. + unsigned r, e; + for (r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) != 0) + break; + } + if (r == getNumInequalities()) { + // If it doesn't appear, just remove the column and return. + // TODO(andydavis,bondhugula): refactor removeColumns to use it from here. + removeId(pos); + LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); + LLVM_DEBUG(dump()); + return; + } + + // Positions of constraints that are lower bounds on the variable. + SmallVector<unsigned, 4> lbIndices; + // Positions of constraints that are lower bounds on the variable. + SmallVector<unsigned, 4> ubIndices; + // Positions of constraints that do not involve the variable. + std::vector<unsigned> nbIndices; + nbIndices.reserve(getNumInequalities()); + + // Gather all lower bounds and upper bounds of the variable. Since the + // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower + // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1. + for (unsigned r = 0, e = getNumInequalities(); r < e; r++) { + if (atIneq(r, pos) == 0) { + // Id does not appear in bound. + nbIndices.push_back(r); + } else if (atIneq(r, pos) >= 1) { + // Lower bound. + lbIndices.push_back(r); + } else { + // Upper bound. + ubIndices.push_back(r); + } + } + + // Set the number of dimensions, symbols in the resulting system. + const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this); + unsigned newNumDims = dimsSymbols.first; + unsigned newNumSymbols = dimsSymbols.second; + + SmallVector<Optional<Value>, 8> newIds; + newIds.reserve(numIds - 1); + newIds.append(ids.begin(), ids.begin() + pos); + newIds.append(ids.begin() + pos + 1, ids.end()); + + /// Create the new system which has one identifier less. + FlatAffineConstraints newFac( + lbIndices.size() * ubIndices.size() + nbIndices.size(), + getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols, + /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols, newIds); + + assert(newFac.getIds().size() == newFac.getNumIds()); + + // This will be used to check if the elimination was integer exact. + unsigned lcmProducts = 1; + + // Let x be the variable we are eliminating. + // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note + // that c_l, c_u >= 1) we have: + // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u + // We thus generate a constraint: + // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub. + // Note if c_l = c_u = 1, all integer points captured by the resulting + // constraint correspond to integer points in the original system (i.e., they + // have integer pre-images). 
Hence, if the lcm's are all 1, the elimination is + // integer exact. + for (auto ubPos : ubIndices) { + for (auto lbPos : lbIndices) { + SmallVector<int64_t, 4> ineq; + ineq.reserve(newFac.getNumCols()); + int64_t lbCoeff = atIneq(lbPos, pos); + // Note that in the comments above, ubCoeff is the negation of the + // coefficient in the canonical form as the view taken here is that of the + // term being moved to the other size of '>='. + int64_t ubCoeff = -atIneq(ubPos, pos); + // TODO(bondhugula): refactor this loop to avoid all branches inside. + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified"); + int64_t lcm = mlir::lcm(lbCoeff, ubCoeff); + ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) + + atIneq(lbPos, l) * (lcm / lbCoeff)); + lcmProducts *= lcm; + } + if (darkShadow) { + // The dark shadow is a convex subset of the exact integer shadow. If + // there is a point here, it proves the existence of a solution. + ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1; + } + // TODO: we need to have a way to add inequalities in-place in + // FlatAffineConstraints instead of creating and copying over. + newFac.addInequality(ineq); + } + } + + LLVM_DEBUG(llvm::dbgs() << "FM isResultIntegerExact: " << (lcmProducts == 1) + << "\n"); + if (lcmProducts == 1 && isResultIntegerExact) + *isResultIntegerExact = true; + + // Copy over the constraints not involving this variable. + for (auto nbPos : nbIndices) { + SmallVector<int64_t, 4> ineq; + ineq.reserve(getNumCols() - 1); + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + ineq.push_back(atIneq(nbPos, l)); + } + newFac.addInequality(ineq); + } + + assert(newFac.getNumConstraints() == + lbIndices.size() * ubIndices.size() + nbIndices.size()); + + // Copy over the equalities. + for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { + SmallVector<int64_t, 4> eq; + eq.reserve(newFac.getNumCols()); + for (unsigned l = 0, e = getNumCols(); l < e; l++) { + if (l == pos) + continue; + eq.push_back(atEq(r, l)); + } + newFac.addEquality(eq); + } + + // GCD tightening and normalization allows detection of more trivially + // redundant constraints. + newFac.GCDTightenInequalities(); + newFac.normalizeConstraintsByGCD(); + newFac.removeTrivialRedundancy(); + clearAndCopyFrom(newFac); + LLVM_DEBUG(llvm::dbgs() << "FM output:\n"); + LLVM_DEBUG(dump()); +} + +#undef DEBUG_TYPE +#define DEBUG_TYPE "affine-structures" + +void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) { + if (num == 0) + return; + + // 'pos' can be at most getNumCols() - 2 if num > 0. + assert((getNumCols() < 2 || pos <= getNumCols() - 2) && "invalid position"); + assert(pos + num < getNumCols() && "invalid range"); + + // Eliminate as many identifiers as possible using Gaussian elimination. + unsigned currentPos = pos; + unsigned numToEliminate = num; + unsigned numGaussianEliminated = 0; + + while (currentPos < getNumIds()) { + unsigned curNumEliminated = + gaussianEliminateIds(currentPos, currentPos + numToEliminate); + ++currentPos; + numToEliminate -= curNumEliminated + 1; + numGaussianEliminated += curNumEliminated; + } + + // Eliminate the remaining using Fourier-Motzkin. + for (unsigned i = 0; i < num - numGaussianEliminated; i++) { + unsigned numToEliminate = num - numGaussianEliminated - i; + FourierMotzkinEliminate( + getBestIdToEliminate(*this, pos, pos + numToEliminate)); + } + + // Fast/trivial simplifications. 
+ GCDTightenInequalities(); + // Normalize constraints after tightening since the latter impacts this, but + // not the other way round. + normalizeConstraintsByGCD(); +} + +void FlatAffineConstraints::projectOut(Value id) { + unsigned pos; + bool ret = findId(*id, &pos); + assert(ret); + (void)ret; + FourierMotzkinEliminate(pos); +} + +void FlatAffineConstraints::clearConstraints() { + equalities.clear(); + inequalities.clear(); +} + +namespace { + +enum BoundCmpResult { Greater, Less, Equal, Unknown }; + +/// Compares two affine bounds whose coefficients are provided in 'first' and +/// 'second'. The last coefficient is the constant term. +static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { + assert(a.size() == b.size()); + + // For the bounds to be comparable, their corresponding identifier + // coefficients should be equal; the constant terms are then compared to + // determine less/greater/equal. + + if (!std::equal(a.begin(), a.end() - 1, b.begin())) + return Unknown; + + if (a.back() == b.back()) + return Equal; + + return a.back() < b.back() ? Less : Greater; +} +} // namespace + +// Computes the bounding box with respect to 'other' by finding the min of the +// lower bounds and the max of the upper bounds along each of the dimensions. +LogicalResult +FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) { + assert(otherCst.getNumDimIds() == numDims && "dims mismatch"); + assert(otherCst.getIds() + .slice(0, getNumDimIds()) + .equals(getIds().slice(0, getNumDimIds())) && + "dim values mismatch"); + assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here"); + assert(getNumLocalIds() == 0 && "local ids not supported yet here"); + + Optional<FlatAffineConstraints> otherCopy; + if (!areIdsAligned(*this, otherCst)) { + otherCopy.emplace(FlatAffineConstraints(otherCst)); + mergeAndAlignIds(/*offset=*/numDims, this, &otherCopy.getValue()); + } + + const auto &other = otherCopy ? *otherCopy : otherCst; + + std::vector<SmallVector<int64_t, 8>> boundingLbs; + std::vector<SmallVector<int64_t, 8>> boundingUbs; + boundingLbs.reserve(2 * getNumDimIds()); + boundingUbs.reserve(2 * getNumDimIds()); + + // To hold lower and upper bounds for each dimension. + SmallVector<int64_t, 4> lb, otherLb, ub, otherUb; + // To compute min of lower bounds and max of upper bounds for each dimension. + SmallVector<int64_t, 4> minLb(getNumSymbolIds() + 1); + SmallVector<int64_t, 4> maxUb(getNumSymbolIds() + 1); + // To compute final new lower and upper bounds for the union. + SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols()); + + int64_t lbFloorDivisor, otherLbFloorDivisor; + for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { + auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub); + if (!extent.hasValue()) + // TODO(bondhugula): symbolic extents when necessary. + // TODO(bondhugula): handle union if a dimension is unbounded. + return failure(); + + auto otherExtent = other.getConstantBoundOnDimSize( + d, &otherLb, &otherLbFloorDivisor, &otherUb); + if (!otherExtent.hasValue() || lbFloorDivisor != otherLbFloorDivisor) + // TODO(bondhugula): symbolic extents when necessary. + return failure(); + + assert(lbFloorDivisor > 0 && "divisor always expected to be positive"); + + auto res = compareBounds(lb, otherLb); + // Identify min. 
+ if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { + minLb = lb; + // Since the divisor is for a floordiv, we need to convert to ceildiv, + // i.e., i >= expr floordiv div <=> i >= (expr - div + 1) ceildiv div <=> + // div * i >= expr - div + 1. + minLb.back() -= lbFloorDivisor - 1; + } else if (res == BoundCmpResult::Greater) { + minLb = otherLb; + minLb.back() -= otherLbFloorDivisor - 1; + } else { + // Uncomparable - check for constant lower/upper bounds. + auto constLb = getConstantLowerBound(d); + auto constOtherLb = other.getConstantLowerBound(d); + if (!constLb.hasValue() || !constOtherLb.hasValue()) + return failure(); + std::fill(minLb.begin(), minLb.end(), 0); + minLb.back() = std::min(constLb.getValue(), constOtherLb.getValue()); + } + + // Do the same for ub's but max of upper bounds. Identify max. + auto uRes = compareBounds(ub, otherUb); + if (uRes == BoundCmpResult::Greater || uRes == BoundCmpResult::Equal) { + maxUb = ub; + } else if (uRes == BoundCmpResult::Less) { + maxUb = otherUb; + } else { + // Uncomparable - check for constant lower/upper bounds. + auto constUb = getConstantUpperBound(d); + auto constOtherUb = other.getConstantUpperBound(d); + if (!constUb.hasValue() || !constOtherUb.hasValue()) + return failure(); + std::fill(maxUb.begin(), maxUb.end(), 0); + maxUb.back() = std::max(constUb.getValue(), constOtherUb.getValue()); + } + + std::fill(newLb.begin(), newLb.end(), 0); + std::fill(newUb.begin(), newUb.end(), 0); + + // The divisor for lb, ub, otherLb, otherUb at this point is lbDivisor, + // and so it's the divisor for newLb and newUb as well. + newLb[d] = lbFloorDivisor; + newUb[d] = -lbFloorDivisor; + // Copy over the symbolic part + constant term. + std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimIds()); + std::transform(newLb.begin() + getNumDimIds(), newLb.end(), + newLb.begin() + getNumDimIds(), std::negate<int64_t>()); + std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimIds()); + + boundingLbs.push_back(newLb); + boundingUbs.push_back(newUb); + } + + // Clear all constraints and add the lower/upper bounds for the bounding box. + clearConstraints(); + for (unsigned d = 0, e = getNumDimIds(); d < e; ++d) { + addInequality(boundingLbs[d]); + addInequality(boundingUbs[d]); + } + // TODO(mlir-team): copy over pure symbolic constraints from this and 'other' + // over to the union (since the above are just the union along dimensions); we + // shouldn't be discarding any other constraints on the symbols. 
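// A small usage sketch of unionBoundingBox() on two constant 1-d boxes,
// assuming the convenience FlatAffineConstraints(numDims, numSymbols)
// constructor and the ArrayRef-based addInequality() used elsewhere in this
// file; columns are (d0, const). Illustrative only.
#include "mlir/Analysis/AffineStructures.h"

static void unionBoundingBoxSketch() {
  mlir::FlatAffineConstraints a(/*numDims=*/1, /*numSymbols=*/0);
  a.addInequality({1, 0});   // d0 >= 0
  a.addInequality({-1, 9});  // d0 <= 9
  mlir::FlatAffineConstraints b(/*numDims=*/1, /*numSymbols=*/0);
  b.addInequality({1, -5});  // d0 >= 5
  b.addInequality({-1, 15}); // d0 <= 15
  if (mlir::succeeded(a.unionBoundingBox(b)))
    a.dump(); // expected bounding box: 0 <= d0 <= 15
}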
+ + return success(); +} diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt new file mode 100644 index 00000000000..96c2928b17f --- /dev/null +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -0,0 +1,29 @@ +add_llvm_library(MLIRAnalysis STATIC + AffineAnalysis.cpp + AffineStructures.cpp + CallGraph.cpp + Dominance.cpp + InferTypeOpInterface.cpp + Liveness.cpp + LoopAnalysis.cpp + MemRefBoundCheck.cpp + NestedMatcher.cpp + OpStats.cpp + SliceAnalysis.cpp + TestMemRefDependenceCheck.cpp + TestParallelismDetection.cpp + Utils.cpp + VectorAnalysis.cpp + Verifier.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Analysis + ) +add_dependencies(MLIRAnalysis + MLIRAffineOps + MLIRCallOpInterfacesIncGen + MLIRTypeInferOpInterfaceIncGen + MLIRLoopOps + MLIRVectorOps + ) +target_link_libraries(MLIRAnalysis MLIRAffineOps MLIRLoopOps MLIRVectorOps) diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp new file mode 100644 index 00000000000..c35421d55eb --- /dev/null +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -0,0 +1,256 @@ +//===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains interfaces and analyses for defining a nested callgraph. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/CallGraph.h" +#include "mlir/Analysis/CallInterfaces.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// CallInterfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/CallInterfaces.cpp.inc" + +//===----------------------------------------------------------------------===// +// CallGraphNode +//===----------------------------------------------------------------------===// + +/// Returns if this node refers to the indirect/external node. +bool CallGraphNode::isExternal() const { return !callableRegion; } + +/// Return the callable region this node represents. This can only be called +/// on non-external nodes. +Region *CallGraphNode::getCallableRegion() const { + assert(!isExternal() && "the external node has no callable region"); + return callableRegion; +} + +/// Adds an reference edge to the given node. This is only valid on the +/// external node. +void CallGraphNode::addAbstractEdge(CallGraphNode *node) { + assert(isExternal() && "abstract edges are only valid on external nodes"); + addEdge(node, Edge::Kind::Abstract); +} + +/// Add an outgoing call edge from this node. +void CallGraphNode::addCallEdge(CallGraphNode *node) { + addEdge(node, Edge::Kind::Call); +} + +/// Adds a reference edge to the given child node. +void CallGraphNode::addChildEdge(CallGraphNode *child) { + addEdge(child, Edge::Kind::Child); +} + +/// Returns true if this node has any child edges. +bool CallGraphNode::hasChildren() const { + return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); }); +} + +/// Add an edge to 'node' with the given kind. 
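// A small sketch of walking a node's edges with the kinds added above; the
// iteration and predicates are the ones used by CallGraph::print() later in
// this file. Illustrative only.
#include "mlir/Analysis/CallGraph.h"
#include "llvm/Support/raw_ostream.h"

static void countEdgesSketch(const mlir::CallGraphNode *node) {
  unsigned callEdges = 0, childEdges = 0;
  for (auto &edge : *node) {
    if (edge.isCall())
      ++callEdges;
    else if (edge.isChild())
      ++childEdges;
  }
  llvm::errs() << callEdges << " call edge(s), " << childEdges
               << " child edge(s)\n";
}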
+void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) { + edges.insert({node, kind}); +} + +//===----------------------------------------------------------------------===// +// CallGraph +//===----------------------------------------------------------------------===// + +/// Recursively compute the callgraph edges for the given operation. Computed +/// edges are placed into the given callgraph object. +static void computeCallGraph(Operation *op, CallGraph &cg, + CallGraphNode *parentNode); + +/// Compute the set of callgraph nodes that are created by regions nested within +/// 'op'. +static void computeCallables(Operation *op, CallGraph &cg, + CallGraphNode *parentNode) { + if (op->getNumRegions() == 0) + return; + if (auto callableOp = dyn_cast<CallableOpInterface>(op)) { + SmallVector<Region *, 1> callables; + callableOp.getCallableRegions(callables); + for (auto *callableRegion : callables) + cg.getOrAddNode(callableRegion, parentNode); + } +} + +/// Recursively compute the callgraph edges within the given region. Computed +/// edges are placed into the given callgraph object. +static void computeCallGraph(Region ®ion, CallGraph &cg, + CallGraphNode *parentNode) { + // Iterate over the nested operations twice: + /// One to fully create nodes in the for each callable region of a nested + /// operation; + for (auto &block : region) + for (auto &nested : block) + computeCallables(&nested, cg, parentNode); + + /// And another to recursively compute the callgraph. + for (auto &block : region) + for (auto &nested : block) + computeCallGraph(&nested, cg, parentNode); +} + +/// Recursively compute the callgraph edges for the given operation. Computed +/// edges are placed into the given callgraph object. +static void computeCallGraph(Operation *op, CallGraph &cg, + CallGraphNode *parentNode) { + // Compute the callgraph nodes and edges for each of the nested operations. + auto isCallable = isa<CallableOpInterface>(op); + for (auto ®ion : op->getRegions()) { + // Check to see if this region is a callable node, if so this is the parent + // node of the nested region. + CallGraphNode *nestedParentNode; + if (!isCallable || !(nestedParentNode = cg.lookupNode(®ion))) + nestedParentNode = parentNode; + computeCallGraph(region, cg, nestedParentNode); + } + + // If there is no parent node, we ignore this operation. Even if this + // operation was a call, there would be no callgraph node to attribute it to. + if (!parentNode) + return; + + // If this is a call operation, resolve the callee. + if (auto call = dyn_cast<CallOpInterface>(op)) + parentNode->addCallEdge( + cg.resolveCallable(call.getCallableForCallee(), op)); +} + +CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { + computeCallGraph(op, *this, /*parentNode=*/nullptr); +} + +/// Get or add a call graph node for the given region. +CallGraphNode *CallGraph::getOrAddNode(Region *region, + CallGraphNode *parentNode) { + assert(region && isa<CallableOpInterface>(region->getParentOp()) && + "expected parent operation to be callable"); + std::unique_ptr<CallGraphNode> &node = nodes[region]; + if (!node) { + node.reset(new CallGraphNode(region)); + + // Add this node to the given parent node if necessary. + if (parentNode) + parentNode->addChildEdge(node.get()); + else + // Otherwise, connect all callable nodes to the external node, this allows + // for conservatively including all callable nodes within the graph. 
+ // FIXME(riverriddle) This isn't correct, this is only necessary for + // callable nodes that *could* be called from external sources. This + // requires extending the interface for callables to check if they may be + // referenced externally. + externalNode.addAbstractEdge(node.get()); + } + return node.get(); +} + +/// Lookup a call graph node for the given region, or nullptr if none is +/// registered. +CallGraphNode *CallGraph::lookupNode(Region *region) const { + auto it = nodes.find(region); + return it == nodes.end() ? nullptr : it->second.get(); +} + +/// Resolve the callable for given callee to a node in the callgraph, or the +/// external node if a valid node was not resolved. +CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, + Operation *from) const { + // Get the callee operation from the callable. + Operation *callee; + if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>()) + // TODO(riverriddle) Support nested references. + callee = SymbolTable::lookupNearestSymbolFrom(from, + symbolRef.getRootReference()); + else + callee = callable.get<Value>()->getDefiningOp(); + + // If the callee is non-null and is a valid callable object, try to get the + // called region from it. + if (callee && callee->getNumRegions()) { + if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) { + if (auto *node = lookupNode(callableOp.getCallableRegion(callable))) + return node; + } + } + + // If we don't have a valid direct region, this is an external call. + return getExternalNode(); +} + +//===----------------------------------------------------------------------===// +// Printing + +/// Dump the graph in a human readable format. +void CallGraph::dump() const { print(llvm::errs()); } +void CallGraph::print(raw_ostream &os) const { + os << "// ---- CallGraph ----\n"; + + // Functor used to output the name for the given node. + auto emitNodeName = [&](const CallGraphNode *node) { + if (node->isExternal()) { + os << "<External-Node>"; + return; + } + + auto *callableRegion = node->getCallableRegion(); + auto *parentOp = callableRegion->getParentOp(); + os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" + << callableRegion->getRegionNumber(); + if (auto attrs = parentOp->getAttrList().getDictionary()) + os << " : " << attrs; + }; + + for (auto &nodeIt : nodes) { + const CallGraphNode *node = nodeIt.second.get(); + + // Dump the header for this node. + os << "// - Node : "; + emitNodeName(node); + os << "\n"; + + // Emit each of the edges. + for (auto &edge : *node) { + os << "// -- "; + if (edge.isCall()) + os << "Call"; + else if (edge.isChild()) + os << "Child"; + + os << "-Edge : "; + emitNodeName(edge.getTarget()); + os << "\n"; + } + os << "//\n"; + } + + os << "// -- SCCs --\n"; + + for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) { + os << "// - SCC : \n"; + for (auto &node : scc) { + os << "// -- Node :"; + emitNodeName(node); + os << "\n"; + } + os << "\n"; + } + + os << "// -------------------\n"; +} diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp new file mode 100644 index 00000000000..e4af4c0d69b --- /dev/null +++ b/mlir/lib/Analysis/Dominance.cpp @@ -0,0 +1,171 @@ +//===- Dominance.cpp - Dominator analysis for CFGs ------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of dominance related classes and instantiations of extern +// templates. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Dominance.h" +#include "mlir/IR/Operation.h" +#include "llvm/Support/GenericDomTreeConstruction.h" + +using namespace mlir; +using namespace mlir::detail; + +template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/false>; +template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/true>; +template class llvm::DomTreeNodeBase<Block>; + +//===----------------------------------------------------------------------===// +// DominanceInfoBase +//===----------------------------------------------------------------------===// + +template <bool IsPostDom> +void DominanceInfoBase<IsPostDom>::recalculate(Operation *op) { + dominanceInfos.clear(); + + /// Build the dominance for each of the operation regions. + op->walk([&](Operation *op) { + for (auto ®ion : op->getRegions()) { + // Don't compute dominance if the region is empty. + if (region.empty()) + continue; + auto opDominance = std::make_unique<base>(); + opDominance->recalculate(region); + dominanceInfos.try_emplace(®ion, std::move(opDominance)); + } + }); +} + +/// Return true if the specified block A properly dominates block B. +template <bool IsPostDom> +bool DominanceInfoBase<IsPostDom>::properlyDominates(Block *a, Block *b) { + // A block dominates itself but does not properly dominate itself. + if (a == b) + return false; + + // If either a or b are null, then conservatively return false. + if (!a || !b) + return false; + + // If both blocks are not in the same region, 'a' properly dominates 'b' if + // 'b' is defined in an operation region that (recursively) ends up being + // dominated by 'a'. Walk up the list of containers enclosing B. + auto *regionA = a->getParent(), *regionB = b->getParent(); + if (regionA != regionB) { + Operation *bAncestor; + do { + bAncestor = regionB->getParentOp(); + // If 'bAncestor' is the top level region, then 'a' is a block that post + // dominates 'b'. + if (!bAncestor || !bAncestor->getBlock()) + return IsPostDom; + + regionB = bAncestor->getBlock()->getParent(); + } while (regionA != regionB); + + // Check to see if the ancestor of 'b' is the same block as 'a'. + b = bAncestor->getBlock(); + if (a == b) + return true; + } + + // Otherwise, use the standard dominance functionality. + + // If we don't have a dominance information for this region, assume that b is + // dominated by anything. + auto baseInfoIt = dominanceInfos.find(regionA); + if (baseInfoIt == dominanceInfos.end()) + return true; + return baseInfoIt->second->properlyDominates(a, b); +} + +template class mlir::detail::DominanceInfoBase</*IsPostDom=*/true>; +template class mlir::detail::DominanceInfoBase</*IsPostDom=*/false>; + +//===----------------------------------------------------------------------===// +// DominanceInfo +//===----------------------------------------------------------------------===// + +/// Return true if operation A properly dominates operation B. +bool DominanceInfo::properlyDominates(Operation *a, Operation *b) { + auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); + + // If a or b are not within a block, then a does not dominate b. + if (!aBlock || !bBlock) + return false; + + // If the blocks are the same, then check if b is before a in the block. 
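// A hypothetical use of the dominance analysis in this file: query whether
// one operation properly dominates another within a function. The
// DominanceInfo(Operation *) constructor and the FuncOp include path are
// assumptions; properlyDominates(Operation *, Operation *) is defined here.
#include "mlir/Analysis/Dominance.h"
#include "mlir/IR/Function.h"

static bool properlyDominatesSketch(mlir::FuncOp func, mlir::Operation *a,
                                    mlir::Operation *b) {
  mlir::DominanceInfo domInfo(func.getOperation());
  return domInfo.properlyDominates(a, b);
}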
+ if (aBlock == bBlock) + return a->isBeforeInBlock(b); + + // Traverse up b's hierarchy to check if b's block is contained in a's. + if (auto *bAncestor = aBlock->findAncestorOpInBlock(*b)) { + // Since we already know that aBlock != bBlock, here bAncestor != b. + // a and bAncestor are in the same block; check if 'a' dominates + // bAncestor. + return dominates(a, bAncestor); + } + + // If the blocks are different, check if a's block dominates b's. + return properlyDominates(aBlock, bBlock); +} + +/// Return true if value A properly dominates operation B. +bool DominanceInfo::properlyDominates(Value a, Operation *b) { + if (auto *aOp = a->getDefiningOp()) { + // The values defined by an operation do *not* dominate any nested + // operations. + if (aOp->getParentRegion() != b->getParentRegion() && aOp->isAncestor(b)) + return false; + return properlyDominates(aOp, b); + } + + // block arguments properly dominate all operations in their own block, so + // we use a dominates check here, not a properlyDominates check. + return dominates(a.cast<BlockArgument>()->getOwner(), b->getBlock()); +} + +DominanceInfoNode *DominanceInfo::getNode(Block *a) { + auto *region = a->getParent(); + assert(dominanceInfos.count(region) != 0); + return dominanceInfos[region]->getNode(a); +} + +void DominanceInfo::updateDFSNumbers() { + for (auto &iter : dominanceInfos) + iter.second->updateDFSNumbers(); +} + +//===----------------------------------------------------------------------===// +// PostDominanceInfo +//===----------------------------------------------------------------------===// + +/// Returns true if statement 'a' properly postdominates statement b. +bool PostDominanceInfo::properlyPostDominates(Operation *a, Operation *b) { + auto *aBlock = a->getBlock(), *bBlock = b->getBlock(); + + // If a or b are not within a block, then a does not post dominate b. + if (!aBlock || !bBlock) + return false; + + // If the blocks are the same, check if b is before a in the block. + if (aBlock == bBlock) + return b->isBeforeInBlock(a); + + // Traverse up b's hierarchy to check if b's block is contained in a's. + if (auto *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) + // Since we already know that aBlock != bBlock, here bAncestor != b. + // a and bAncestor are in the same block; check if 'a' postdominates + // bAncestor. + return postDominates(a, bAncestor); + + // If the blocks are different, check if a's block post dominates b's. + return properlyDominates(aBlock, bBlock); +} diff --git a/mlir/lib/Analysis/InferTypeOpInterface.cpp b/mlir/lib/Analysis/InferTypeOpInterface.cpp new file mode 100644 index 00000000000..2e52de2b3fa --- /dev/null +++ b/mlir/lib/Analysis/InferTypeOpInterface.cpp @@ -0,0 +1,22 @@ +//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the infer op interfaces defined in +// `InferTypeOpInterface.td`. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/InferTypeOpInterface.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +#include "mlir/Analysis/InferTypeOpInterface.cpp.inc" +} // namespace mlir diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp new file mode 100644 index 00000000000..7ba31365f1a --- /dev/null +++ b/mlir/lib/Analysis/Liveness.cpp @@ -0,0 +1,373 @@ +//===- Liveness.cpp - Liveness analysis for MLIR --------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of the liveness analysis. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Liveness.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +/// Builds and holds block information during the construction phase. +struct BlockInfoBuilder { + using ValueSetT = Liveness::ValueSetT; + + /// Constructs an empty block builder. + BlockInfoBuilder() : block(nullptr) {} + + /// Fills the block builder with initial liveness information. + BlockInfoBuilder(Block *block) : block(block) { + // Mark all block arguments (phis) as defined. + for (BlockArgument argument : block->getArguments()) + defValues.insert(argument); + + // Check all result values and whether their uses + // are inside this block or not (see outValues). + for (Operation &operation : *block) + for (Value result : operation.getResults()) { + defValues.insert(result); + + // Check whether this value will be in the outValues + // set (its uses escape this block). Due to the SSA + // properties of the program, the uses must occur after + // the definition. Therefore, we do not have to check + // additional conditions to detect an escaping value. + for (OpOperand &use : result->getUses()) + if (use.getOwner()->getBlock() != block) { + outValues.insert(result); + break; + } + } + + // Check all operations for used operands. + for (Operation &operation : block->getOperations()) + for (Value operand : operation.getOperands()) { + // If the operand is already defined in the scope of this + // block, we can skip the value in the use set. + if (!defValues.count(operand)) + useValues.insert(operand); + } + } + + /// Updates live-in information of the current block. + /// To do so it uses the default liveness-computation formula: + /// newIn = use union out \ def. + /// The methods returns true, if the set has changed (newIn != in), + /// false otherwise. + bool updateLiveIn() { + ValueSetT newIn = useValues; + llvm::set_union(newIn, outValues); + llvm::set_subtract(newIn, defValues); + + // It is sufficient to check the set sizes (instead of their contents) + // since the live-in set can only grow monotonically during all update + // operations. + if (newIn.size() == inValues.size()) + return false; + + inValues = newIn; + return true; + } + + /// Updates live-out information of the current block. 
+ /// It iterates over all successors and unifies their live-in + /// values with the current live-out values. + template <typename SourceT> void updateLiveOut(SourceT &source) { + for (Block *succ : block->getSuccessors()) { + BlockInfoBuilder &builder = source[succ]; + llvm::set_union(outValues, builder.inValues); + } + } + + /// The current block. + Block *block; + + /// The set of all live in values. + ValueSetT inValues; + + /// The set of all live out values. + ValueSetT outValues; + + /// The set of all defined values. + ValueSetT defValues; + + /// The set of all used values. + ValueSetT useValues; +}; + +/// Builds the internal liveness block mapping. +static void buildBlockMapping(MutableArrayRef<Region> regions, + DenseMap<Block *, BlockInfoBuilder> &builders) { + llvm::SetVector<Block *> toProcess; + + // Initialize all block structures + for (Region ®ion : regions) + for (Block &block : region) { + BlockInfoBuilder &builder = + builders.try_emplace(&block, &block).first->second; + + if (builder.updateLiveIn()) + toProcess.insert(block.pred_begin(), block.pred_end()); + } + + // Propagate the in and out-value sets (fixpoint iteration) + while (!toProcess.empty()) { + Block *current = toProcess.pop_back_val(); + BlockInfoBuilder &builder = builders[current]; + + // Update the current out values. + builder.updateLiveOut(builders); + + // Compute (potentially) updated live in values. + if (builder.updateLiveIn()) + toProcess.insert(current->pred_begin(), current->pred_end()); + } +} + +//===----------------------------------------------------------------------===// +// Liveness +//===----------------------------------------------------------------------===// + +/// Creates a new Liveness analysis that computes liveness +/// information for all associated regions. +Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); } + +/// Initializes the internal mappings. +void Liveness::build(MutableArrayRef<Region> regions) { + + // Build internal block mapping. + DenseMap<Block *, BlockInfoBuilder> builders; + buildBlockMapping(regions, builders); + + // Store internal block data. + for (auto &entry : builders) { + BlockInfoBuilder &builder = entry.second; + LivenessBlockInfo &info = blockMapping[entry.first]; + + info.block = builder.block; + info.inValues = std::move(builder.inValues); + info.outValues = std::move(builder.outValues); + } +} + +/// Gets liveness info (if any) for the given value. +Liveness::OperationListT Liveness::resolveLiveness(Value value) const { + OperationListT result; + SmallPtrSet<Block *, 32> visited; + SmallVector<Block *, 8> toProcess; + + // Start with the defining block + Block *currentBlock; + if (Operation *defOp = value->getDefiningOp()) + currentBlock = defOp->getBlock(); + else + currentBlock = value.cast<BlockArgument>()->getOwner(); + toProcess.push_back(currentBlock); + visited.insert(currentBlock); + + // Start with all associated blocks + for (OpOperand &use : value->getUses()) { + Block *useBlock = use.getOwner()->getBlock(); + if (visited.insert(useBlock).second) + toProcess.push_back(useBlock); + } + + while (!toProcess.empty()) { + // Get block and block liveness information. + Block *block = toProcess.back(); + toProcess.pop_back(); + const LivenessBlockInfo *blockInfo = getLiveness(block); + + // Note that start and end will be in the same block. 
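// The builders above iterate the classic dataflow equation
// in = use union (out \ def) to a fixpoint. A standalone sketch of one such
// live-in update over plain std::set, mirroring updateLiveIn():
#include <set>
#include <utility>

static bool updateLiveInSketch(std::set<unsigned> &in,
                               const std::set<unsigned> &use,
                               const std::set<unsigned> &out,
                               const std::set<unsigned> &def) {
  std::set<unsigned> newIn = use;
  for (unsigned v : out)
    if (!def.count(v))
      newIn.insert(v);
  // Live-in sets only grow, so comparing sizes suffices to detect a change.
  if (newIn.size() == in.size())
    return false;
  in = std::move(newIn);
  return true;
}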
+ Operation *start = blockInfo->getStartOperation(value); + Operation *end = blockInfo->getEndOperation(value, start); + + result.push_back(start); + while (start != end) { + start = start->getNextNode(); + result.push_back(start); + } + + for (Block *successor : block->getSuccessors()) { + if (getLiveness(successor)->isLiveIn(value) && + visited.insert(successor).second) + toProcess.push_back(successor); + } + } + + return result; +} + +/// Gets liveness info (if any) for the block. +const LivenessBlockInfo *Liveness::getLiveness(Block *block) const { + auto it = blockMapping.find(block); + return it == blockMapping.end() ? nullptr : &it->second; +} + +/// Returns a reference to a set containing live-in values. +const Liveness::ValueSetT &Liveness::getLiveIn(Block *block) const { + return getLiveness(block)->in(); +} + +/// Returns a reference to a set containing live-out values. +const Liveness::ValueSetT &Liveness::getLiveOut(Block *block) const { + return getLiveness(block)->out(); +} + +/// Returns true if the given operation represent the last use of the +/// given value. +bool Liveness::isLastUse(Value value, Operation *operation) const { + Block *block = operation->getBlock(); + const LivenessBlockInfo *blockInfo = getLiveness(block); + + // The given value escapes the associated block. + if (blockInfo->isLiveOut(value)) + return false; + + Operation *endOperation = blockInfo->getEndOperation(value, operation); + // If the operation is a real user of `value` the first check is sufficient. + // If not, we will have to test whether the end operation is executed before + // the given operation in the block. + return endOperation == operation || endOperation->isBeforeInBlock(operation); +} + +/// Dumps the liveness information in a human readable format. +void Liveness::dump() const { print(llvm::errs()); } + +/// Dumps the liveness information to the given stream. +void Liveness::print(raw_ostream &os) const { + os << "// ---- Liveness -----\n"; + + // Builds unique block/value mappings for testing purposes. + DenseMap<Block *, size_t> blockIds; + DenseMap<Operation *, size_t> operationIds; + DenseMap<Value, size_t> valueIds; + for (Region ®ion : operation->getRegions()) + for (Block &block : region) { + blockIds.insert({&block, blockIds.size()}); + for (BlockArgument argument : block.getArguments()) + valueIds.insert({argument, valueIds.size()}); + for (Operation &operation : block) { + operationIds.insert({&operation, operationIds.size()}); + for (Value result : operation.getResults()) + valueIds.insert({result, valueIds.size()}); + } + } + + // Local printing helpers + auto printValueRef = [&](Value value) { + if (Operation *defOp = value->getDefiningOp()) + os << "val_" << defOp->getName(); + else { + auto blockArg = value.cast<BlockArgument>(); + os << "arg" << blockArg->getArgNumber() << "@" + << blockIds[blockArg->getOwner()]; + } + os << " "; + }; + + auto printValueRefs = [&](const ValueSetT &values) { + std::vector<Value> orderedValues(values.begin(), values.end()); + std::sort(orderedValues.begin(), orderedValues.end(), + [&](Value left, Value right) { + return valueIds[left] < valueIds[right]; + }); + for (Value value : orderedValues) + printValueRef(value); + }; + + // Dump information about in and out values. 
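// A hypothetical driver over the Liveness analysis above: flag the last use
// of every op result in a function. FuncOp, the walk callback, and the
// include paths are assumptions; the Liveness constructor and isLastUse()
// are the APIs defined in this file.
#include "mlir/Analysis/Liveness.h"
#include "mlir/IR/Function.h"

static void reportLastUsesSketch(mlir::FuncOp func) {
  mlir::Liveness liveness(func.getOperation());
  func.walk([&](mlir::Operation *op) {
    for (mlir::Value result : op->getResults())
      for (mlir::OpOperand &use : result->getUses())
        if (liveness.isLastUse(result, use.getOwner()))
          use.getOwner()->emitRemark("last use of this operand's value");
  });
}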
+ for (Region ®ion : operation->getRegions()) + for (Block &block : region) { + os << "// - Block: " << blockIds[&block] << "\n"; + auto liveness = getLiveness(&block); + os << "// --- LiveIn: "; + printValueRefs(liveness->inValues); + os << "\n// --- LiveOut: "; + printValueRefs(liveness->outValues); + os << "\n"; + + // Print liveness intervals. + os << "// --- BeginLiveness"; + for (Operation &op : block) { + if (op.getNumResults() < 1) + continue; + os << "\n"; + for (Value result : op.getResults()) { + os << "// "; + printValueRef(result); + os << ":"; + auto liveOperations = resolveLiveness(result); + std::sort(liveOperations.begin(), liveOperations.end(), + [&](Operation *left, Operation *right) { + return operationIds[left] < operationIds[right]; + }); + for (Operation *operation : liveOperations) { + os << "\n// "; + operation->print(os); + } + } + } + os << "\n// --- EndLiveness\n"; + } + os << "// -------------------\n"; +} + +//===----------------------------------------------------------------------===// +// LivenessBlockInfo +//===----------------------------------------------------------------------===// + +/// Returns true if the given value is in the live-in set. +bool LivenessBlockInfo::isLiveIn(Value value) const { + return inValues.count(value); +} + +/// Returns true if the given value is in the live-out set. +bool LivenessBlockInfo::isLiveOut(Value value) const { + return outValues.count(value); +} + +/// Gets the start operation for the given value +/// (must be referenced in this block). +Operation *LivenessBlockInfo::getStartOperation(Value value) const { + Operation *definingOp = value->getDefiningOp(); + // The given value is either live-in or is defined + // in the scope of this block. + if (isLiveIn(value) || !definingOp) + return &block->front(); + return definingOp; +} + +/// Gets the end operation for the given value using the start operation +/// provided (must be referenced in this block). +Operation *LivenessBlockInfo::getEndOperation(Value value, + Operation *startOperation) const { + // The given value is either dying in this block or live-out. + if (isLiveOut(value)) + return &block->back(); + + // Resolve the last operation (must exist by definition). + Operation *endOperation = startOperation; + for (OpOperand &use : value->getUses()) { + Operation *useOperation = use.getOwner(); + // Check whether the use is in our block and after + // the current end operation. + if (useOperation->getBlock() == block && + endOperation->isBeforeInBlock(useOperation)) + endOperation = useOperation; + } + return endOperation; +} diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp new file mode 100644 index 00000000000..18c86dc63b4 --- /dev/null +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -0,0 +1,388 @@ +//===- LoopAnalysis.cpp - Misc loop analysis routines //-------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements miscellaneous loop analysis routines. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/LoopAnalysis.h" + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/NestedMatcher.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Support/MathExtras.h" + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallString.h" +#include <type_traits> + +using namespace mlir; + +/// Returns the trip count of the loop as an affine expression if the latter is +/// expressible as an affine expression, and nullptr otherwise. The trip count +/// expression is simplified before returning. This method only utilizes map +/// composition to construct lower and upper bounds before computing the trip +/// count expressions. +// TODO(mlir-team): this should be moved into 'Transforms/' and be replaced by a +// pure analysis method relying on FlatAffineConstraints; the latter will also +// be more powerful (since both inequalities and equalities will be considered). +void mlir::buildTripCountMapAndOperands( + AffineForOp forOp, AffineMap *tripCountMap, + SmallVectorImpl<Value> *tripCountOperands) { + int64_t loopSpan; + + int64_t step = forOp.getStep(); + OpBuilder b(forOp.getOperation()); + + if (forOp.hasConstantBounds()) { + int64_t lb = forOp.getConstantLowerBound(); + int64_t ub = forOp.getConstantUpperBound(); + loopSpan = ub - lb; + if (loopSpan < 0) + loopSpan = 0; + *tripCountMap = b.getConstantAffineMap(ceilDiv(loopSpan, step)); + tripCountOperands->clear(); + return; + } + auto lbMap = forOp.getLowerBoundMap(); + auto ubMap = forOp.getUpperBoundMap(); + if (lbMap.getNumResults() != 1) { + *tripCountMap = AffineMap(); + return; + } + SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands()); + SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands()); + + // Difference of each upper bound expression from the single lower bound + // expression (divided by the step) provides the expressions for the trip + // count map. + AffineValueMap ubValueMap(ubMap, ubOperands); + + SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(), + lbMap.getResult(0)); + auto lbMapSplat = + AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(), lbSplatExpr); + AffineValueMap lbSplatValueMap(lbMapSplat, lbOperands); + + AffineValueMap tripCountValueMap; + AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap); + for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i) + tripCountValueMap.setResult(i, + tripCountValueMap.getResult(i).ceilDiv(step)); + + *tripCountMap = tripCountValueMap.getAffineMap(); + tripCountOperands->assign(tripCountValueMap.getOperands().begin(), + tripCountValueMap.getOperands().end()); +} + +/// Returns the trip count of the loop if it's a constant, None otherwise. This +/// method uses affine expression analysis (in turn using getTripCount) and is +/// able to determine constant trip count in non-trivial cases. +// FIXME(mlir-team): this is really relying on buildTripCountMapAndOperands; +// being an analysis utility, it shouldn't. Replace with a version that just +// works with analysis structures (FlatAffineConstraints) and thus doesn't +// update the IR. +Optional<uint64_t> mlir::getConstantTripCount(AffineForOp forOp) { + SmallVector<Value, 4> operands; + AffineMap map; + buildTripCountMapAndOperands(forOp, &map, &operands); + + if (!map) + return None; + + // Take the min if all trip counts are constant. 
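// The constant-bound case above reduces to a single ceildiv. A standalone
// restatement of that arithmetic (step is assumed positive, as AffineForOp
// requires):
#include <algorithm>
#include <cstdint>

static uint64_t constantTripCountSketch(int64_t lb, int64_t ub, int64_t step) {
  int64_t span = std::max<int64_t>(ub - lb, 0);
  return static_cast<uint64_t>((span + step - 1) / step); // ceildiv(span, step)
}
// E.g. lb = 0, ub = 42, step = 4 gives ceildiv(42, 4) = 11 iterations.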
+ Optional<uint64_t> tripCount; + for (auto resultExpr : map.getResults()) { + if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) { + if (tripCount.hasValue()) + tripCount = std::min(tripCount.getValue(), + static_cast<uint64_t>(constExpr.getValue())); + else + tripCount = constExpr.getValue(); + } else + return None; + } + return tripCount; +} + +/// Returns the greatest known integral divisor of the trip count. Affine +/// expression analysis is used (indirectly through getTripCount), and +/// this method is thus able to determine non-trivial divisors. +uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) { + SmallVector<Value, 4> operands; + AffineMap map; + buildTripCountMapAndOperands(forOp, &map, &operands); + + if (!map) + return 1; + + // The largest divisor of the trip count is the GCD of the individual largest + // divisors. + assert(map.getNumResults() >= 1 && "expected one or more results"); + Optional<uint64_t> gcd; + for (auto resultExpr : map.getResults()) { + uint64_t thisGcd; + if (auto constExpr = resultExpr.dyn_cast<AffineConstantExpr>()) { + uint64_t tripCount = constExpr.getValue(); + // 0 iteration loops (greatest divisor is 2^64 - 1). + if (tripCount == 0) + thisGcd = std::numeric_limits<uint64_t>::max(); + else + // The greatest divisor is the trip count. + thisGcd = tripCount; + } else { + // Trip count is not a known constant; return its largest known divisor. + thisGcd = resultExpr.getLargestKnownDivisor(); + } + if (gcd.hasValue()) + gcd = llvm::GreatestCommonDivisor64(gcd.getValue(), thisGcd); + else + gcd = thisGcd; + } + assert(gcd.hasValue() && "value expected per above logic"); + return gcd.getValue(); +} + +/// Given an induction variable `iv` of type AffineForOp and an access `index` +/// of type index, returns `true` if `index` is independent of `iv` and +/// false otherwise. The determination supports composition with at most one +/// AffineApplyOp. The 'at most one AffineApplyOp' comes from the fact that +/// the composition of AffineApplyOp needs to be canonicalized by construction +/// to avoid writing code that composes arbitrary numbers of AffineApplyOps +/// everywhere. To achieve this, at the very least, the compose-affine-apply +/// pass must have been run. +/// +/// Prerequisites: +/// 1. `iv` and `index` of the proper type; +/// 2. at most one reachable AffineApplyOp from index; +/// +/// Returns false in cases with more than one AffineApplyOp, this is +/// conservative. +static bool isAccessIndexInvariant(Value iv, Value index) { + assert(isForInductionVar(iv) && "iv must be a AffineForOp"); + assert(index->getType().isa<IndexType>() && "index must be of IndexType"); + SmallVector<Operation *, 4> affineApplyOps; + getReachableAffineApplyOps({index}, affineApplyOps); + + if (affineApplyOps.empty()) { + // Pointer equality test because of Value pointer semantics. + return index != iv; + } + + if (affineApplyOps.size() > 1) { + affineApplyOps[0]->emitRemark( + "CompositionAffineMapsPass must have been run: there should be at most " + "one AffineApplyOp, returning false conservatively."); + return false; + } + + auto composeOp = cast<AffineApplyOp>(affineApplyOps[0]); + // We need yet another level of indirection because the `dim` index of the + // access may not correspond to the `dim` index of composeOp. 
+ return !(AffineValueMap(composeOp).isFunctionOf(0, iv)); +} + +DenseSet<Value> mlir::getInvariantAccesses(Value iv, ArrayRef<Value> indices) { + DenseSet<Value> res; + for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) { + auto val = indices[idx]; + if (isAccessIndexInvariant(iv, val)) { + res.insert(val); + } + } + return res; +} + +/// Given: +/// 1. an induction variable `iv` of type AffineForOp; +/// 2. a `memoryOp` of type const LoadOp& or const StoreOp&; +/// determines whether `memoryOp` has a contiguous access along `iv`. Contiguous +/// is defined as either invariant or varying only along a unique MemRef dim. +/// Upon success, the unique MemRef dim is written in `memRefDim` (or -1 to +/// convey the memRef access is invariant along `iv`). +/// +/// Prerequisites: +/// 1. `memRefDim` ~= nullptr; +/// 2. `iv` of the proper type; +/// 3. the MemRef accessed by `memoryOp` has no layout map or at most an +/// identity layout map. +/// +/// Currently only supports no layoutMap or identity layoutMap in the MemRef. +/// Returns false if the MemRef has a non-identity layoutMap or more than 1 +/// layoutMap. This is conservative. +/// +// TODO(ntv): check strides. +template <typename LoadOrStoreOp> +static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp, + int *memRefDim) { + static_assert(std::is_same<LoadOrStoreOp, AffineLoadOp>::value || + std::is_same<LoadOrStoreOp, AffineStoreOp>::value, + "Must be called on either const LoadOp & or const StoreOp &"); + assert(memRefDim && "memRefDim == nullptr"); + auto memRefType = memoryOp.getMemRefType(); + + auto layoutMap = memRefType.getAffineMaps(); + // TODO(ntv): remove dependence on Builder once we support non-identity + // layout map. + Builder b(memoryOp.getContext()); + if (layoutMap.size() >= 2 || + (layoutMap.size() == 1 && + !(layoutMap[0] == + b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) { + return memoryOp.emitError("NYI: non-trivial layoutMap"), false; + } + + int uniqueVaryingIndexAlongIv = -1; + auto accessMap = memoryOp.getAffineMap(); + SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands()); + unsigned numDims = accessMap.getNumDims(); + for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) { + // Gather map operands used result expr 'i' in 'exprOperands'. + SmallVector<Value, 4> exprOperands; + auto resultExpr = accessMap.getResult(i); + resultExpr.walk([&](AffineExpr expr) { + if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) + exprOperands.push_back(mapOperands[dimExpr.getPosition()]); + else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) + exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]); + }); + // Check access invariance of each operand in 'exprOperands'. + for (auto exprOperand : exprOperands) { + if (!isAccessIndexInvariant(iv, exprOperand)) { + if (uniqueVaryingIndexAlongIv != -1) { + // 2+ varying indices -> do not vectorize along iv. 
+          return false;
+        }
+        uniqueVaryingIndexAlongIv = i;
+      }
+    }
+  }
+
+  if (uniqueVaryingIndexAlongIv == -1)
+    *memRefDim = -1;
+  else
+    *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1);
+  return true;
+}
+
+template <typename LoadOrStoreOpPointer>
+static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
+  auto memRefType = memoryOp.getMemRefType();
+  return memRefType.getElementType().template isa<VectorType>();
+}
+
+using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
+
+static bool
+isVectorizableLoopBodyWithOpCond(AffineForOp loop,
+                                 VectorizableOpFun isVectorizableOp,
+                                 NestedPattern &vectorTransferMatcher) {
+  auto *forOp = loop.getOperation();
+
+  // No vectorization across conditionals for now.
+  auto conditionals = matcher::If();
+  SmallVector<NestedMatch, 8> conditionalsMatched;
+  conditionals.match(forOp, &conditionalsMatched);
+  if (!conditionalsMatched.empty()) {
+    return false;
+  }
+
+  // No vectorization across unknown regions.
+  auto regions = matcher::Op([](Operation &op) -> bool {
+    return op.getNumRegions() != 0 &&
+           !(isa<AffineIfOp>(op) || isa<AffineForOp>(op));
+  });
+  SmallVector<NestedMatch, 8> regionsMatched;
+  regions.match(forOp, &regionsMatched);
+  if (!regionsMatched.empty()) {
+    return false;
+  }
+
+  SmallVector<NestedMatch, 8> vectorTransfersMatched;
+  vectorTransferMatcher.match(forOp, &vectorTransfersMatched);
+  if (!vectorTransfersMatched.empty()) {
+    return false;
+  }
+
+  auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
+  SmallVector<NestedMatch, 8> loadAndStoresMatched;
+  loadAndStores.match(forOp, &loadAndStoresMatched);
+  for (auto ls : loadAndStoresMatched) {
+    auto *op = ls.getMatchedOperation();
+    auto load = dyn_cast<AffineLoadOp>(op);
+    auto store = dyn_cast<AffineStoreOp>(op);
+    // Only ops on scalar types are considered vectorizable; every load/store
+    // must be vectorizable for a loop to qualify as vectorizable.
+    // TODO(ntv): ponder whether we want to be more general here.
+    bool vector = load ? isVectorElement(load) : isVectorElement(store);
+    if (vector) {
+      return false;
+    }
+    if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool mlir::isVectorizableLoopBody(AffineForOp loop, int *memRefDim,
+                                  NestedPattern &vectorTransferMatcher) {
+  VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
+    auto load = dyn_cast<AffineLoadOp>(op);
+    auto store = dyn_cast<AffineStoreOp>(op);
+    return load ? isContiguousAccess(loop.getInductionVar(), load, memRefDim)
+                : isContiguousAccess(loop.getInductionVar(), store, memRefDim);
+  });
+  return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher);
+}
+
+bool mlir::isVectorizableLoopBody(AffineForOp loop,
+                                  NestedPattern &vectorTransferMatcher) {
+  return isVectorizableLoopBodyWithOpCond(loop, nullptr, vectorTransferMatcher);
+}
+
+/// Checks whether SSA dominance would be violated if a for op's body
+/// operations are shifted by the specified shifts. This method checks if a
+/// 'def' and all its uses have the same shift factor.
+// TODO(mlir-team): extend this to check for memory-based dependence violation
+// when we have the support.
+bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts) {
+  auto *forBody = forOp.getBody();
+  assert(shifts.size() == forBody->getOperations().size());
+
+  // Work backwards over the body of the block so that the shift of a use's
+  // ancestor operation in the block gets recorded before it's looked up.
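A hedged usage sketch of the two entry points above (not from this patch; `vectorTransferMatcher` is assumed to be a NestedPattern matching already-vectorized transfer ops, and its construction is not shown here):

// Returns true if `loop` is vectorizable, and reports along which memref
// dimension the analyzed accesses vary: -1 conveys an access invariant along
// the induction variable, 0 the innermost (fastest-varying) memref dimension.
static bool vectorizableAlongOneMemRefDim(AffineForOp loop,
                                          NestedPattern &vectorTransferMatcher,
                                          int &memRefDim) {
  return isVectorizableLoopBody(loop, &memRefDim, vectorTransferMatcher);
}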
+  DenseMap<Operation *, uint64_t> forBodyShift;
+  for (auto it : llvm::enumerate(llvm::reverse(forBody->getOperations()))) {
+    auto &op = it.value();
+
+    // Get the index of the current operation; note that we are iterating in
+    // reverse so we need to fix it up.
+    size_t index = shifts.size() - it.index() - 1;
+
+    // Remember the shift of this operation.
+    uint64_t shift = shifts[index];
+    forBodyShift.try_emplace(&op, shift);
+
+    // Validate the results of this operation if it were to be shifted.
+    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
+      Value result = op.getResult(i);
+      for (auto *user : result->getUsers()) {
+        // If an ancestor operation doesn't lie in the block of forOp,
+        // there is no shift to check.
+        if (auto *ancOp = forBody->findAncestorOpInBlock(*user)) {
+          assert(forBodyShift.count(ancOp) > 0 && "ancestor expected in map");
+          if (shift != forBodyShift[ancOp])
+            return false;
+        }
+      }
+    }
+  }
+  return true;
+}
diff --git a/mlir/lib/Analysis/MemRefBoundCheck.cpp b/mlir/lib/Analysis/MemRefBoundCheck.cpp
new file mode 100644
index 00000000000..1f7c1a1ae31
--- /dev/null
+++ b/mlir/lib/Analysis/MemRefBoundCheck.cpp
@@ -0,0 +1,53 @@
+//===- MemRefBoundCheck.cpp - MLIR memref bound checking pass -------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to check memref accesses for out-of-bound
+// accesses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ADT/TypeSwitch.h"
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "memref-bound-check"
+
+using namespace mlir;
+
+namespace {
+
+/// Checks for out-of-bound memref access subscripts.
+struct MemRefBoundCheck : public FunctionPass<MemRefBoundCheck> {
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefBoundCheckPass() {
+  return std::make_unique<MemRefBoundCheck>();
+}
+
+void MemRefBoundCheck::runOnFunction() {
+  getFunction().walk([](Operation *opInst) {
+    TypeSwitch<Operation *>(opInst).Case<AffineLoadOp, AffineStoreOp>(
+        [](auto op) { boundCheckLoadOrStoreOp(op); });
+
+    // TODO(bondhugula): do this for DMA ops as well.
+  });
+}
+
+static PassRegistration<MemRefBoundCheck>
+    memRefBoundCheck("memref-bound-check",
+                     "Check memref access bounds in a Function");
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
new file mode 100644
index 00000000000..97eaafd37ce
--- /dev/null
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -0,0 +1,152 @@
+//===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/NestedMatcher.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +llvm::BumpPtrAllocator *&NestedMatch::allocator() { + thread_local llvm::BumpPtrAllocator *allocator = nullptr; + return allocator; +} + +NestedMatch NestedMatch::build(Operation *operation, + ArrayRef<NestedMatch> nestedMatches) { + auto *result = allocator()->Allocate<NestedMatch>(); + auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size()); + std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children); + new (result) NestedMatch(); + result->matchedOperation = operation; + result->matchedChildren = + ArrayRef<NestedMatch>(children, nestedMatches.size()); + return *result; +} + +llvm::BumpPtrAllocator *&NestedPattern::allocator() { + thread_local llvm::BumpPtrAllocator *allocator = nullptr; + return allocator; +} + +NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested, + FilterFunctionType filter) + : nestedPatterns(), filter(filter), skip(nullptr) { + if (!nested.empty()) { + auto *newNested = allocator()->Allocate<NestedPattern>(nested.size()); + std::uninitialized_copy(nested.begin(), nested.end(), newNested); + nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size()); + } +} + +unsigned NestedPattern::getDepth() const { + if (nestedPatterns.empty()) { + return 1; + } + unsigned depth = 0; + for (auto &c : nestedPatterns) { + depth = std::max(depth, c.getDepth()); + } + return depth + 1; +} + +/// Matches a single operation in the following way: +/// 1. checks the kind of operation against the matcher, if different then +/// there is no match; +/// 2. calls the customizable filter function to refine the single operation +/// match with extra semantic constraints; +/// 3. if all is good, recursively matches the nested patterns; +/// 4. if all nested match then the single operation matches too and is +/// appended to the list of matches; +/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will +/// want to traverse in post-order DFS to avoid invalidating iterators. +void NestedPattern::matchOne(Operation *op, + SmallVectorImpl<NestedMatch> *matches) { + if (skip == op) { + return; + } + // Local custom filter function + if (!filter(*op)) { + return; + } + + if (nestedPatterns.empty()) { + SmallVector<NestedMatch, 8> nestedMatches; + matches->push_back(NestedMatch::build(op, nestedMatches)); + return; + } + // Take a copy of each nested pattern so we can match it. + for (auto nestedPattern : nestedPatterns) { + SmallVector<NestedMatch, 8> nestedMatches; + // Skip elem in the walk immediately following. Without this we would + // essentially need to reimplement walk here. + nestedPattern.skip = op; + nestedPattern.match(op, &nestedMatches); + // If we could not match even one of the specified nestedPattern, early exit + // as this whole branch is not a match. 
+ if (nestedMatches.empty()) { + return; + } + matches->push_back(NestedMatch::build(op, nestedMatches)); + } +} + +static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); } + +static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); } + +namespace mlir { +namespace matcher { + +NestedPattern Op(FilterFunctionType filter) { + return NestedPattern({}, filter); +} + +NestedPattern If(NestedPattern child) { + return NestedPattern(child, isAffineIfOp); +} +NestedPattern If(FilterFunctionType filter, NestedPattern child) { + return NestedPattern(child, [filter](Operation &op) { + return isAffineIfOp(op) && filter(op); + }); +} +NestedPattern If(ArrayRef<NestedPattern> nested) { + return NestedPattern(nested, isAffineIfOp); +} +NestedPattern If(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { + return NestedPattern(nested, [filter](Operation &op) { + return isAffineIfOp(op) && filter(op); + }); +} + +NestedPattern For(NestedPattern child) { + return NestedPattern(child, isAffineForOp); +} +NestedPattern For(FilterFunctionType filter, NestedPattern child) { + return NestedPattern( + child, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); +} +NestedPattern For(ArrayRef<NestedPattern> nested) { + return NestedPattern(nested, isAffineForOp); +} +NestedPattern For(FilterFunctionType filter, ArrayRef<NestedPattern> nested) { + return NestedPattern( + nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); +} + +bool isLoadOrStore(Operation &op) { + return isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op); +} + +} // end namespace matcher +} // end namespace mlir diff --git a/mlir/lib/Analysis/OpStats.cpp b/mlir/lib/Analysis/OpStats.cpp new file mode 100644 index 00000000000..dbd938710ef --- /dev/null +++ b/mlir/lib/Analysis/OpStats.cpp @@ -0,0 +1,84 @@ +//===- OpStats.cpp - Prints stats of operations in module -----------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { +struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> { + explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {} + + // Prints the resultant operation statistics post iterating over the module. + void runOnModule() override; + + // Print summary of op stats. + void printSummary(); + +private: + llvm::StringMap<int64_t> opCount; + raw_ostream &os; +}; +} // namespace + +void PrintOpStatsPass::runOnModule() { + opCount.clear(); + + // Compute the operation statistics for each function in the module. + for (auto &op : getModule()) + op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); + printSummary(); +} + +void PrintOpStatsPass::printSummary() { + os << "Operations encountered:\n"; + os << "-----------------------\n"; + SmallVector<StringRef, 64> sorted(opCount.keys()); + llvm::sort(sorted); + + // Split an operation name from its dialect prefix. + auto splitOperationName = [](StringRef opName) { + auto splitName = opName.split('.'); + return splitName.second.empty() ? 
std::make_pair("", splitName.first) + : splitName; + }; + + // Compute the largest dialect and operation name. + StringRef dialectName, opName; + size_t maxLenOpName = 0, maxLenDialect = 0; + for (const auto &key : sorted) { + std::tie(dialectName, opName) = splitOperationName(key); + maxLenDialect = std::max(maxLenDialect, dialectName.size()); + maxLenOpName = std::max(maxLenOpName, opName.size()); + } + + for (const auto &key : sorted) { + std::tie(dialectName, opName) = splitOperationName(key); + + // Left-align the names (aligning on the dialect) and right-align the count + // below. The alignment is for readability and does not affect CSV/FileCheck + // parsing. + if (dialectName.empty()) + os.indent(maxLenDialect + 3); + else + os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.'; + + // Left justify the operation name. + os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key] + << '\n'; + } +} + +static PassRegistration<PrintOpStatsPass> + pass("print-op-stats", "Print statistics of operations"); diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp new file mode 100644 index 00000000000..89ee613b370 --- /dev/null +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -0,0 +1,213 @@ +//===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements Analysis functions specific to slicing in Function. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/Functional.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/SetVector.h" + +/// +/// Implements Analysis functions specific to slicing in Function. +/// + +using namespace mlir; + +using llvm::SetVector; + +static void getForwardSliceImpl(Operation *op, + SetVector<Operation *> *forwardSlice, + TransitiveFilter filter) { + if (!op) { + return; + } + + // Evaluate whether we should keep this use. + // This is useful in particular to implement scoping; i.e. return the + // transitive forwardSlice in the current scope. 
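A hedged usage sketch (not from this patch): the filter is what lets a caller scope the slice, for instance to operations nested under some enclosing op; `scopeOp` below is hypothetical and Operation::isProperAncestor is used for the containment test.

static void getForwardSliceWithin(Operation *root, Operation *scopeOp,
                                  SetVector<Operation *> *slice) {
  getForwardSlice(root, slice, [scopeOp](Operation *op) {
    // Keep only transitive users nested under the scoping operation.
    return scopeOp->isProperAncestor(op);
  });
}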
+ if (!filter(op)) { + return; + } + + if (auto forOp = dyn_cast<AffineForOp>(op)) { + for (auto *ownerInst : forOp.getInductionVar()->getUsers()) + if (forwardSlice->count(ownerInst) == 0) + getForwardSliceImpl(ownerInst, forwardSlice, filter); + } else if (auto forOp = dyn_cast<loop::ForOp>(op)) { + for (auto *ownerInst : forOp.getInductionVar()->getUsers()) + if (forwardSlice->count(ownerInst) == 0) + getForwardSliceImpl(ownerInst, forwardSlice, filter); + } else { + assert(op->getNumRegions() == 0 && "unexpected generic op with regions"); + assert(op->getNumResults() <= 1 && "unexpected multiple results"); + if (op->getNumResults() > 0) { + for (auto *ownerInst : op->getResult(0)->getUsers()) + if (forwardSlice->count(ownerInst) == 0) + getForwardSliceImpl(ownerInst, forwardSlice, filter); + } + } + + forwardSlice->insert(op); +} + +void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice, + TransitiveFilter filter) { + getForwardSliceImpl(op, forwardSlice, filter); + // Don't insert the top level operation, we just queried on it and don't + // want it in the results. + forwardSlice->remove(op); + + // Reverse to get back the actual topological order. + // std::reverse does not work out of the box on SetVector and I want an + // in-place swap based thing (the real std::reverse, not the LLVM adapter). + std::vector<Operation *> v(forwardSlice->takeVector()); + forwardSlice->insert(v.rbegin(), v.rend()); +} + +static void getBackwardSliceImpl(Operation *op, + SetVector<Operation *> *backwardSlice, + TransitiveFilter filter) { + if (!op) + return; + + assert((op->getNumRegions() == 0 || isa<AffineForOp>(op) || + isa<loop::ForOp>(op)) && + "unexpected generic op with regions"); + + // Evaluate whether we should keep this def. + // This is useful in particular to implement scoping; i.e. return the + // transitive forwardSlice in the current scope. + if (!filter(op)) { + return; + } + + for (auto en : llvm::enumerate(op->getOperands())) { + auto operand = en.value(); + if (auto blockArg = operand.dyn_cast<BlockArgument>()) { + if (auto affIv = getForInductionVarOwner(operand)) { + auto *affOp = affIv.getOperation(); + if (backwardSlice->count(affOp) == 0) + getBackwardSliceImpl(affOp, backwardSlice, filter); + } else if (auto loopIv = loop::getForInductionVarOwner(operand)) { + auto *loopOp = loopIv.getOperation(); + if (backwardSlice->count(loopOp) == 0) + getBackwardSliceImpl(loopOp, backwardSlice, filter); + } else if (blockArg->getOwner() != + &op->getParentOfType<FuncOp>().getBody().front()) { + op->emitError("unsupported CF for operand ") << en.index(); + llvm_unreachable("Unsupported control flow"); + } + continue; + } + auto *op = operand->getDefiningOp(); + if (backwardSlice->count(op) == 0) { + getBackwardSliceImpl(op, backwardSlice, filter); + } + } + + backwardSlice->insert(op); +} + +void mlir::getBackwardSlice(Operation *op, + SetVector<Operation *> *backwardSlice, + TransitiveFilter filter) { + getBackwardSliceImpl(op, backwardSlice, filter); + + // Don't insert the top level operation, we just queried on it and don't + // want it in the results. 
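The reversal above (and the topologicalSort used by getSlice further below) is what gives callers a use-after-def ordering; a small check of that property, offered only as an illustrative sketch:

static bool isTopologicallyOrdered(const SetVector<Operation *> &slice) {
  DenseSet<Operation *> seen;
  for (Operation *op : slice) {
    for (Value operand : op->getOperands())
      if (Operation *def = operand->getDefiningOp())
        // A defining op that belongs to the slice must appear earlier.
        if (slice.count(def) && !seen.count(def))
          return false;
    seen.insert(op);
  }
  return true;
}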
+ backwardSlice->remove(op); +} + +SetVector<Operation *> mlir::getSlice(Operation *op, + TransitiveFilter backwardFilter, + TransitiveFilter forwardFilter) { + SetVector<Operation *> slice; + slice.insert(op); + + unsigned currentIndex = 0; + SetVector<Operation *> backwardSlice; + SetVector<Operation *> forwardSlice; + while (currentIndex != slice.size()) { + auto *currentInst = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentInst. + backwardSlice.clear(); + getBackwardSlice(currentInst, &backwardSlice, backwardFilter); + slice.insert(backwardSlice.begin(), backwardSlice.end()); + + // Compute and insert the forwardSlice starting from currentInst. + forwardSlice.clear(); + getForwardSlice(currentInst, &forwardSlice, forwardFilter); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + ++currentIndex; + } + return topologicalSort(slice); +} + +namespace { +/// DFS post-order implementation that maintains a global count to work across +/// multiple invocations, to help implement topological sort on multi-root DAGs. +/// We traverse all operations but only record the ones that appear in +/// `toSort` for the final result. +struct DFSState { + DFSState(const SetVector<Operation *> &set) + : toSort(set), topologicalCounts(), seen() {} + const SetVector<Operation *> &toSort; + SmallVector<Operation *, 16> topologicalCounts; + DenseSet<Operation *> seen; +}; +} // namespace + +static void DFSPostorder(Operation *current, DFSState *state) { + assert(current->getNumResults() <= 1 && "NYI: multi-result"); + if (current->getNumResults() > 0) { + for (auto &u : current->getResult(0)->getUses()) { + auto *op = u.getOwner(); + DFSPostorder(op, state); + } + } + bool inserted; + using IterTy = decltype(state->seen.begin()); + IterTy iter; + std::tie(iter, inserted) = state->seen.insert(current); + if (inserted) { + if (state->toSort.count(current) > 0) { + state->topologicalCounts.push_back(current); + } + } +} + +SetVector<Operation *> +mlir::topologicalSort(const SetVector<Operation *> &toSort) { + if (toSort.empty()) { + return toSort; + } + + // Run from each root with global count and `seen` set. + DFSState state(toSort); + for (auto *s : toSort) { + assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); + DFSPostorder(s, &state); + } + + // Reorder and return. + SetVector<Operation *> res; + for (auto it = state.topologicalCounts.rbegin(), + eit = state.topologicalCounts.rend(); + it != eit; ++it) { + res.insert(*it); + } + return res; +} diff --git a/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp new file mode 100644 index 00000000000..c6d7519740e --- /dev/null +++ b/mlir/lib/Analysis/TestMemRefDependenceCheck.cpp @@ -0,0 +1,121 @@ +//===- TestMemRefDependenceCheck.cpp - Test dep analysis ------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to run pair-wise memref access dependence checks. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/Passes.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "test-memref-dependence-check" + +using namespace mlir; + +namespace { + +// TODO(andydavis) Add common surrounding loop depth-wise dependence checks. +/// Checks dependences between all pairs of memref accesses in a Function. +struct TestMemRefDependenceCheck + : public FunctionPass<TestMemRefDependenceCheck> { + SmallVector<Operation *, 4> loadsAndStores; + void runOnFunction() override; +}; + +} // end anonymous namespace + +std::unique_ptr<OpPassBase<FuncOp>> +mlir::createTestMemRefDependenceCheckPass() { + return std::make_unique<TestMemRefDependenceCheck>(); +} + +// Returns a result string which represents the direction vector (if there was +// a dependence), returns the string "false" otherwise. +static std::string +getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth, + ArrayRef<DependenceComponent> dependenceComponents) { + if (!ret) + return "false"; + if (dependenceComponents.empty() || loopNestDepth > numCommonLoops) + return "true"; + std::string result; + for (unsigned i = 0, e = dependenceComponents.size(); i < e; ++i) { + std::string lbStr = "-inf"; + if (dependenceComponents[i].lb.hasValue() && + dependenceComponents[i].lb.getValue() != + std::numeric_limits<int64_t>::min()) + lbStr = std::to_string(dependenceComponents[i].lb.getValue()); + + std::string ubStr = "+inf"; + if (dependenceComponents[i].ub.hasValue() && + dependenceComponents[i].ub.getValue() != + std::numeric_limits<int64_t>::max()) + ubStr = std::to_string(dependenceComponents[i].ub.getValue()); + + result += "[" + lbStr + ", " + ubStr + "]"; + } + return result; +} + +// For each access in 'loadsAndStores', runs a dependence check between this +// "source" access and all subsequent "destination" accesses in +// 'loadsAndStores'. Emits the result of the dependence check as a note with +// the source access. +static void checkDependences(ArrayRef<Operation *> loadsAndStores) { + for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) { + auto *srcOpInst = loadsAndStores[i]; + MemRefAccess srcAccess(srcOpInst); + for (unsigned j = 0; j < e; ++j) { + auto *dstOpInst = loadsAndStores[j]; + MemRefAccess dstAccess(dstOpInst); + + unsigned numCommonLoops = + getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); + for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { + FlatAffineConstraints dependenceConstraints; + SmallVector<DependenceComponent, 2> dependenceComponents; + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, d, &dependenceConstraints, + &dependenceComponents); + assert(result.value != DependenceResult::Failure); + bool ret = hasDependence(result); + // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print + // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance + // vectors from ([1, 1], [3, 3]) to (1, 3). + srcOpInst->emitRemark("dependence from ") + << i << " to " << j << " at depth " << d << " = " + << getDirectionVectorStr(ret, numCommonLoops, d, + dependenceComponents); + } + } + } +} + +// Walks the Function 'f' adding load and store ops to 'loadsAndStores'. 
+// Runs pair-wise dependence checks. +void TestMemRefDependenceCheck::runOnFunction() { + // Collect the loads and stores within the function. + loadsAndStores.clear(); + getFunction().walk([&](Operation *op) { + if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) + loadsAndStores.push_back(op); + }); + + checkDependences(loadsAndStores); +} + +static PassRegistration<TestMemRefDependenceCheck> + pass("test-memref-dependence-check", + "Checks dependences between all pairs of memref accesses."); diff --git a/mlir/lib/Analysis/TestParallelismDetection.cpp b/mlir/lib/Analysis/TestParallelismDetection.cpp new file mode 100644 index 00000000000..6cfc5431df3 --- /dev/null +++ b/mlir/lib/Analysis/TestParallelismDetection.cpp @@ -0,0 +1,48 @@ +//===- ParallelismDetection.cpp - Parallelism Detection pass ------------*-===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to detect parallel affine 'affine.for' ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Passes.h" +#include "mlir/Analysis/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +struct TestParallelismDetection + : public FunctionPass<TestParallelismDetection> { + void runOnFunction() override; +}; + +} // end anonymous namespace + +std::unique_ptr<OpPassBase<FuncOp>> mlir::createParallelismDetectionTestPass() { + return std::make_unique<TestParallelismDetection>(); +} + +// Walks the function and emits a note for all 'affine.for' ops detected as +// parallel. +void TestParallelismDetection::runOnFunction() { + FuncOp f = getFunction(); + OpBuilder b(f.getBody()); + f.walk([&](AffineForOp forOp) { + if (isLoopParallel(forOp)) + forOp.emitRemark("parallel loop"); + else + forOp.emitRemark("sequential loop"); + }); +} + +static PassRegistration<TestParallelismDetection> + pass("test-detect-parallel", "Test parallelism detection "); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp new file mode 100644 index 00000000000..8ddf2e274eb --- /dev/null +++ b/mlir/lib/Analysis/Utils.cpp @@ -0,0 +1,1007 @@ +//===- Utils.cpp ---- Misc utilities for analysis -------------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements miscellaneous analysis routines for non-loop IR +// structures. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Utils.h" + +#include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "analysis-utils" + +using namespace mlir; + +using llvm::SmallDenseMap; + +/// Populates 'loops' with IVs of the loops surrounding 'op' ordered from +/// the outermost 'affine.for' operation to the innermost one. 
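As a hedged illustration of the utility below (not from this patch): collecting the IVs immediately gives the affine nesting depth, which matches getNestingDepth defined later in this file for purely affine nests.

static unsigned affineNestingDepth(Operation &op) {
  SmallVector<AffineForOp, 4> loops;
  getLoopIVs(op, &loops);
  // 'loops' is ordered outermost first; its size is the depth of 'op' within
  // the surrounding affine.for nest (affine.if ops are skipped, not counted).
  return loops.size();
}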
+void mlir::getLoopIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) { + auto *currOp = op.getParentOp(); + AffineForOp currAffineForOp; + // Traverse up the hierarchy collecting all 'affine.for' operation while + // skipping over 'affine.if' operations. + while (currOp && ((currAffineForOp = dyn_cast<AffineForOp>(currOp)) || + isa<AffineIfOp>(currOp))) { + if (currAffineForOp) + loops->push_back(currAffineForOp); + currOp = currOp->getParentOp(); + } + std::reverse(loops->begin(), loops->end()); +} + +// Populates 'cst' with FlatAffineConstraints which represent slice bounds. +LogicalResult +ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) { + assert(!lbOperands.empty()); + // Adds src 'ivs' as dimension identifiers in 'cst'. + unsigned numDims = ivs.size(); + // Adds operands (dst ivs and symbols) as symbols in 'cst'. + unsigned numSymbols = lbOperands[0].size(); + + SmallVector<Value, 4> values(ivs); + // Append 'ivs' then 'operands' to 'values'. + values.append(lbOperands[0].begin(), lbOperands[0].end()); + cst->reset(numDims, numSymbols, 0, values); + + // Add loop bound constraints for values which are loop IVs and equality + // constraints for symbols which are constants. + for (const auto &value : values) { + assert(cst->containsId(*value) && "value expected to be present"); + if (isValidSymbol(value)) { + // Check if the symbol is a constant. + + if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp())) + cst->setIdToConstant(*value, cOp.getValue()); + } else if (auto loop = getForInductionVarOwner(value)) { + if (failed(cst->addAffineForOpDomain(loop))) + return failure(); + } + } + + // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]' + LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]); + assert(succeeded(ret) && + "should not fail as we never have semi-affine slice maps"); + (void)ret; + return success(); +} + +// Clears state bounds and operand state. +void ComputationSliceState::clearBounds() { + lbs.clear(); + ubs.clear(); + lbOperands.clear(); + ubOperands.clear(); +} + +unsigned MemRefRegion::getRank() const { + return memref->getType().cast<MemRefType>().getRank(); +} + +Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape( + SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs, + SmallVectorImpl<int64_t> *lbDivisors) const { + auto memRefType = memref->getType().cast<MemRefType>(); + unsigned rank = memRefType.getRank(); + if (shape) + shape->reserve(rank); + + assert(rank == cst.getNumDimIds() && "inconsistent memref region"); + + // Find a constant upper bound on the extent of this memref region along each + // dimension. + int64_t numElements = 1; + int64_t diffConstant; + int64_t lbDivisor; + for (unsigned d = 0; d < rank; d++) { + SmallVector<int64_t, 4> lb; + Optional<int64_t> diff = cst.getConstantBoundOnDimSize(d, &lb, &lbDivisor); + if (diff.hasValue()) { + diffConstant = diff.getValue(); + assert(lbDivisor > 0); + } else { + // If no constant bound is found, then it can always be bound by the + // memref's dim size if the latter has a constant size along this dim. + auto dimSize = memRefType.getDimSize(d); + if (dimSize == -1) + return None; + diffConstant = dimSize; + // Lower bound becomes 0. 
+ lb.resize(cst.getNumSymbolIds() + 1, 0); + lbDivisor = 1; + } + numElements *= diffConstant; + if (lbs) { + lbs->push_back(lb); + assert(lbDivisors && "both lbs and lbDivisor or none"); + lbDivisors->push_back(lbDivisor); + } + if (shape) { + shape->push_back(diffConstant); + } + } + return numElements; +} + +LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) { + assert(memref == other.memref); + return cst.unionBoundingBox(*other.getConstraints()); +} + +/// Computes the memory region accessed by this memref with the region +/// represented as constraints symbolic/parametric in 'loopDepth' loops +/// surrounding opInst and any additional Function symbols. +// For example, the memref region for this load operation at loopDepth = 1 will +// be as below: +// +// affine.for %i = 0 to 32 { +// affine.for %ii = %i to (d0) -> (d0 + 8) (%i) { +// load %A[%ii] +// } +// } +// +// region: {memref = %A, write = false, {%i <= m0 <= %i + 7} } +// The last field is a 2-d FlatAffineConstraints symbolic in %i. +// +// TODO(bondhugula): extend this to any other memref dereferencing ops +// (dma_start, dma_wait). +LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, + ComputationSliceState *sliceState, + bool addMemRefDimBounds) { + assert((isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op)) && + "affine load/store op expected"); + + MemRefAccess access(op); + memref = access.memref; + write = access.isStore(); + + unsigned rank = access.getRank(); + + LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op + << "depth: " << loopDepth << "\n";); + + if (rank == 0) { + SmallVector<AffineForOp, 4> ivs; + getLoopIVs(*op, &ivs); + SmallVector<Value, 8> regionSymbols; + extractForInductionVars(ivs, ®ionSymbols); + // A rank 0 memref has a 0-d region. + cst.reset(rank, loopDepth, 0, regionSymbols); + return success(); + } + + // Build the constraints for this region. + AffineValueMap accessValueMap; + access.getAccessMap(&accessValueMap); + AffineMap accessMap = accessValueMap.getAffineMap(); + + unsigned numDims = accessMap.getNumDims(); + unsigned numSymbols = accessMap.getNumSymbols(); + unsigned numOperands = accessValueMap.getNumOperands(); + // Merge operands with slice operands. + SmallVector<Value, 4> operands; + operands.resize(numOperands); + for (unsigned i = 0; i < numOperands; ++i) + operands[i] = accessValueMap.getOperand(i); + + if (sliceState != nullptr) { + operands.reserve(operands.size() + sliceState->lbOperands[0].size()); + // Append slice operands to 'operands' as symbols. + for (auto extraOperand : sliceState->lbOperands[0]) { + if (!llvm::is_contained(operands, extraOperand)) { + operands.push_back(extraOperand); + numSymbols++; + } + } + } + // We'll first associate the dims and symbols of the access map to the dims + // and symbols resp. of cst. This will change below once cst is + // fully constructed out. + cst.reset(numDims, numSymbols, 0, operands); + + // Add equality constraints. + // Add inequalities for loop lower/upper bounds. + for (unsigned i = 0; i < numDims + numSymbols; ++i) { + auto operand = operands[i]; + if (auto loop = getForInductionVarOwner(operand)) { + // Note that cst can now have more dimensions than accessMap if the + // bounds expressions involve outer loops or other symbols. + // TODO(bondhugula): rewrite this to use getInstIndexSet; this way + // conditionals will be handled when the latter supports it. + if (failed(cst.addAffineForOpDomain(loop))) + return failure(); + } else { + // Has to be a valid symbol. 
+ auto symbol = operand; + assert(isValidSymbol(symbol)); + // Check if the symbol is a constant. + if (auto *op = symbol->getDefiningOp()) { + if (auto constOp = dyn_cast<ConstantIndexOp>(op)) { + cst.setIdToConstant(*symbol, constOp.getValue()); + } + } + } + } + + // Add lower/upper bounds on loop IVs using bounds from 'sliceState'. + if (sliceState != nullptr) { + // Add dim and symbol slice operands. + for (auto operand : sliceState->lbOperands[0]) { + cst.addInductionVarOrTerminalSymbol(operand); + } + // Add upper/lower bounds from 'sliceState' to 'cst'. + LogicalResult ret = + cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs, + sliceState->lbOperands[0]); + assert(succeeded(ret) && + "should not fail as we never have semi-affine slice maps"); + (void)ret; + } + + // Add access function equalities to connect loop IVs to data dimensions. + if (failed(cst.composeMap(&accessValueMap))) { + op->emitError("getMemRefRegion: compose affine map failed"); + LLVM_DEBUG(accessValueMap.getAffineMap().dump()); + return failure(); + } + + // Set all identifiers appearing after the first 'rank' identifiers as + // symbolic identifiers - so that the ones corresponding to the memref + // dimensions are the dimensional identifiers for the memref region. + cst.setDimSymbolSeparation(cst.getNumDimAndSymbolIds() - rank); + + // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which + // this memref region is symbolic. + SmallVector<AffineForOp, 4> enclosingIVs; + getLoopIVs(*op, &enclosingIVs); + assert(loopDepth <= enclosingIVs.size() && "invalid loop depth"); + enclosingIVs.resize(loopDepth); + SmallVector<Value, 4> ids; + cst.getIdValues(cst.getNumDimIds(), cst.getNumDimAndSymbolIds(), &ids); + for (auto id : ids) { + AffineForOp iv; + if ((iv = getForInductionVarOwner(id)) && + llvm::is_contained(enclosingIVs, iv) == false) { + cst.projectOut(id); + } + } + + // Project out any local variables (these would have been added for any + // mod/divs). + cst.projectOut(cst.getNumDimAndSymbolIds(), cst.getNumLocalIds()); + + // Constant fold any symbolic identifiers. + cst.constantFoldIdRange(/*pos=*/cst.getNumDimIds(), + /*num=*/cst.getNumSymbolIds()); + + assert(cst.getNumDimIds() == rank && "unexpected MemRefRegion format"); + + // Add upper/lower bounds for each memref dimension with static size + // to guard against potential over-approximation from projection. + // TODO(andydavis) Support dynamic memref dimensions. + if (addMemRefDimBounds) { + auto memRefType = memref->getType().cast<MemRefType>(); + for (unsigned r = 0; r < rank; r++) { + cst.addConstantLowerBound(r, 0); + int64_t dimSize = memRefType.getDimSize(r); + if (ShapedType::isDynamic(dimSize)) + continue; + cst.addConstantUpperBound(r, dimSize - 1); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Memory region:\n"); + LLVM_DEBUG(cst.dump()); + return success(); +} + +// TODO(mlir-team): improve/complete this when we have target data. +static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast<VectorType>(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); +} + +// Returns the size of the region. 
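To make the byte-size arithmetic above concrete, a standalone sketch (illustration only, not from this patch):

#include <cstdint>

// Mirrors getMemRefEltSizeInBytes for an element of known bit width: an f32
// element is ceil(32 / 8) == 4 bytes, a vector<4xf32> element is
// ceil(4 * 32 / 8) == 16 bytes, and an i1 element still occupies 1 byte.
inline uint64_t eltSizeInBytes(uint64_t elementBits, uint64_t numElements = 1) {
  return (elementBits * numElements + 7) / 8;
}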
+Optional<int64_t> MemRefRegion::getRegionSize() {
+  auto memRefType = memref->getType().cast<MemRefType>();
+
+  auto layoutMaps = memRefType.getAffineMaps();
+  if (layoutMaps.size() > 1 ||
+      (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
+    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
+    return None;
+  }
+
+  // Indices to use for the DmaStart op.
+  // Indices for the original memref being DMAed from/to.
+  SmallVector<Value, 4> memIndices;
+  // Indices for the faster buffer being DMAed into/from.
+  SmallVector<Value, 4> bufIndices;
+
+  // Compute the extents of the buffer.
+  Optional<int64_t> numElements = getConstantBoundingSizeAndShape();
+  if (!numElements.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
+    return None;
+  }
+  return getMemRefEltSizeInBytes(memRefType) * numElements.getValue();
+}
+
+/// Returns the size of memref data in bytes if it's statically shaped, None
+/// otherwise. If the element of the memref has vector type, takes into account
+/// size of the vector as well.
+// TODO(mlir-team): improve/complete this when we have target data.
+Optional<uint64_t> mlir::getMemRefSizeInBytes(MemRefType memRefType) {
+  if (!memRefType.hasStaticShape())
+    return None;
+  auto elementType = memRefType.getElementType();
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+    return None;
+
+  uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType);
+  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
+    sizeInBytes = sizeInBytes * memRefType.getDimSize(i);
+  }
+  return sizeInBytes;
+}
+
+template <typename LoadOrStoreOpPointer>
+LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
+                                            bool emitError) {
+  static_assert(std::is_same<LoadOrStoreOpPointer, AffineLoadOp>::value ||
+                    std::is_same<LoadOrStoreOpPointer, AffineStoreOp>::value,
+                "argument should be either an AffineLoadOp or an AffineStoreOp");
+
+  Operation *opInst = loadOrStoreOp.getOperation();
+  MemRefRegion region(opInst->getLoc());
+  if (failed(region.compute(opInst, /*loopDepth=*/0, /*sliceState=*/nullptr,
+                            /*addMemRefDimBounds=*/false)))
+    return success();
+
+  LLVM_DEBUG(llvm::dbgs() << "Memory region");
+  LLVM_DEBUG(region.getConstraints()->dump());
+
+  bool outOfBounds = false;
+  unsigned rank = loadOrStoreOp.getMemRefType().getRank();
+
+  // For each dimension, check for out of bounds.
+  for (unsigned r = 0; r < rank; r++) {
+    FlatAffineConstraints ucst(*region.getConstraints());
+
+    // Intersect memory region with constraint capturing out of bounds (both out
+    // of upper and out of lower), and check if the constraint system is
+    // feasible. If it is, there is at least one point out of bounds.
+    SmallVector<int64_t, 4> ineq(rank + 1, 0);
+    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
+    // TODO(bondhugula): handle dynamic dim sizes.
+    if (dimSize == -1)
+      continue;
+
+    // Check for overflow: d_i >= memref dim size.
+    ucst.addConstantLowerBound(r, dimSize);
+    outOfBounds = !ucst.isEmpty();
+    if (outOfBounds && emitError) {
+      loadOrStoreOp.emitOpError()
+          << "memref out of upper bound access along dimension #" << (r + 1);
+    }
+
+    // Check for a negative index.
+ FlatAffineConstraints lcst(*region.getConstraints()); + std::fill(ineq.begin(), ineq.end(), 0); + // d_i <= -1; + lcst.addConstantUpperBound(r, -1); + outOfBounds = !lcst.isEmpty(); + if (outOfBounds && emitError) { + loadOrStoreOp.emitOpError() + << "memref out of lower bound access along dimension #" << (r + 1); + } + } + return failure(outOfBounds); +} + +// Explicitly instantiate the template so that the compiler knows we need them! +template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineLoadOp loadOp, + bool emitError); +template LogicalResult mlir::boundCheckLoadOrStoreOp(AffineStoreOp storeOp, + bool emitError); + +// Returns in 'positions' the Block positions of 'op' in each ancestor +// Block from the Block containing operation, stopping at 'limitBlock'. +static void findInstPosition(Operation *op, Block *limitBlock, + SmallVectorImpl<unsigned> *positions) { + Block *block = op->getBlock(); + while (block != limitBlock) { + // FIXME: This algorithm is unnecessarily O(n) and should be improved to not + // rely on linear scans. + int instPosInBlock = std::distance(block->begin(), op->getIterator()); + positions->push_back(instPosInBlock); + op = block->getParentOp(); + block = op->getBlock(); + } + std::reverse(positions->begin(), positions->end()); +} + +// Returns the Operation in a possibly nested set of Blocks, where the +// position of the operation is represented by 'positions', which has a +// Block position for each level of nesting. +static Operation *getInstAtPosition(ArrayRef<unsigned> positions, + unsigned level, Block *block) { + unsigned i = 0; + for (auto &op : *block) { + if (i != positions[level]) { + ++i; + continue; + } + if (level == positions.size() - 1) + return &op; + if (auto childAffineForOp = dyn_cast<AffineForOp>(op)) + return getInstAtPosition(positions, level + 1, + childAffineForOp.getBody()); + + for (auto ®ion : op.getRegions()) { + for (auto &b : region) + if (auto *ret = getInstAtPosition(positions, level + 1, &b)) + return ret; + } + return nullptr; + } + return nullptr; +} + +// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'. +LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs, + FlatAffineConstraints *cst) { + for (unsigned i = 0, e = cst->getNumDimIds(); i < e; ++i) { + auto value = cst->getIdValue(i); + if (ivs.count(value) == 0) { + assert(isForInductionVar(value)); + auto loop = getForInductionVarOwner(value); + if (failed(cst->addAffineForOpDomain(loop))) + return failure(); + } + } + return success(); +} + +// Returns the innermost common loop depth for the set of operations in 'ops'. +// TODO(andydavis) Move this to LoopUtils. +static unsigned +getInnermostCommonLoopDepth(ArrayRef<Operation *> ops, + SmallVectorImpl<AffineForOp> &surroundingLoops) { + unsigned numOps = ops.size(); + assert(numOps > 0); + + std::vector<SmallVector<AffineForOp, 4>> loops(numOps); + unsigned loopDepthLimit = std::numeric_limits<unsigned>::max(); + for (unsigned i = 0; i < numOps; ++i) { + getLoopIVs(*ops[i], &loops[i]); + loopDepthLimit = + std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size())); + } + + unsigned loopDepth = 0; + for (unsigned d = 0; d < loopDepthLimit; ++d) { + unsigned i; + for (i = 1; i < numOps; ++i) { + if (loops[i - 1][d] != loops[i][d]) + return loopDepth; + } + surroundingLoops.push_back(loops[i - 1][d]); + ++loopDepth; + } + return loopDepth; +} + +/// Computes in 'sliceUnion' the union of all slice bounds computed at +/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB'. 
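/// A hedged usage sketch (not from this patch; all names are hypothetical):
/// a fusion client unioning the backward slices of a producer store into the
/// nest of a consumer load at destination loop depth 2 would do roughly:
///
///   ComputationSliceState sliceUnion;
///   if (succeeded(computeSliceUnion({srcStore}, {dstLoad}, /*loopDepth=*/2,
///                                   /*numCommonLoops=*/0,
///                                   /*isBackwardSlice=*/true, &sliceUnion)))
///     ... // sliceUnion now bounds the producer iterations needed at depth 2.
///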
+/// Returns 'Success' if union was computed, 'failure' otherwise. +LogicalResult mlir::computeSliceUnion(ArrayRef<Operation *> opsA, + ArrayRef<Operation *> opsB, + unsigned loopDepth, + unsigned numCommonLoops, + bool isBackwardSlice, + ComputationSliceState *sliceUnion) { + // Compute the union of slice bounds between all pairs in 'opsA' and + // 'opsB' in 'sliceUnionCst'. + FlatAffineConstraints sliceUnionCst; + assert(sliceUnionCst.getNumDimAndSymbolIds() == 0); + std::vector<std::pair<Operation *, Operation *>> dependentOpPairs; + for (unsigned i = 0, numOpsA = opsA.size(); i < numOpsA; ++i) { + MemRefAccess srcAccess(opsA[i]); + for (unsigned j = 0, numOpsB = opsB.size(); j < numOpsB; ++j) { + MemRefAccess dstAccess(opsB[j]); + if (srcAccess.memref != dstAccess.memref) + continue; + // Check if 'loopDepth' exceeds nesting depth of src/dst ops. + if ((!isBackwardSlice && loopDepth > getNestingDepth(*opsA[i])) || + (isBackwardSlice && loopDepth > getNestingDepth(*opsB[j]))) { + LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n."); + return failure(); + } + + bool readReadAccesses = isa<AffineLoadOp>(srcAccess.opInst) && + isa<AffineLoadOp>(dstAccess.opInst); + FlatAffineConstraints dependenceConstraints; + // Check dependence between 'srcAccess' and 'dstAccess'. + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1, + &dependenceConstraints, /*dependenceComponents=*/nullptr, + /*allowRAR=*/readReadAccesses); + if (result.value == DependenceResult::Failure) { + LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n."); + return failure(); + } + if (result.value == DependenceResult::NoDependence) + continue; + dependentOpPairs.push_back({opsA[i], opsB[j]}); + + // Compute slice bounds for 'srcAccess' and 'dstAccess'. + ComputationSliceState tmpSliceState; + mlir::getComputationSliceState(opsA[i], opsB[j], &dependenceConstraints, + loopDepth, isBackwardSlice, + &tmpSliceState); + + if (sliceUnionCst.getNumDimAndSymbolIds() == 0) { + // Initialize 'sliceUnionCst' with the bounds computed in previous step. + if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute slice bound constraints\n."); + return failure(); + } + assert(sliceUnionCst.getNumDimAndSymbolIds() > 0); + continue; + } + + // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. + FlatAffineConstraints tmpSliceCst; + if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute slice bound constraints\n."); + return failure(); + } + + // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed. + if (!sliceUnionCst.areIdsAlignedWithOther(tmpSliceCst)) { + + // Pre-constraint id alignment: record loop IVs used in each constraint + // system. + SmallPtrSet<Value, 8> sliceUnionIVs; + for (unsigned k = 0, l = sliceUnionCst.getNumDimIds(); k < l; ++k) + sliceUnionIVs.insert(sliceUnionCst.getIdValue(k)); + SmallPtrSet<Value, 8> tmpSliceIVs; + for (unsigned k = 0, l = tmpSliceCst.getNumDimIds(); k < l; ++k) + tmpSliceIVs.insert(tmpSliceCst.getIdValue(k)); + + sliceUnionCst.mergeAndAlignIdsWithOther(/*offset=*/0, &tmpSliceCst); + + // Post-constraint id alignment: add loop IV bounds missing after + // id alignment to constraint systems. This can occur if one constraint + // system uses an loop IV that is not used by the other. 
The call + // to unionBoundingBox below expects constraints for each Loop IV, even + // if they are the unsliced full loop bounds added here. + if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst))) + return failure(); + if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst))) + return failure(); + } + // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'. + if (sliceUnionCst.getNumLocalIds() > 0 || + tmpSliceCst.getNumLocalIds() > 0 || + failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { + LLVM_DEBUG(llvm::dbgs() + << "Unable to compute union bounding box of slice bounds." + "\n."); + return failure(); + } + } + } + + // Empty union. + if (sliceUnionCst.getNumDimAndSymbolIds() == 0) + return failure(); + + // Gather loops surrounding ops from loop nest where slice will be inserted. + SmallVector<Operation *, 4> ops; + for (auto &dep : dependentOpPairs) { + ops.push_back(isBackwardSlice ? dep.second : dep.first); + } + SmallVector<AffineForOp, 4> surroundingLoops; + unsigned innermostCommonLoopDepth = + getInnermostCommonLoopDepth(ops, surroundingLoops); + if (loopDepth > innermostCommonLoopDepth) { + LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n."); + return failure(); + } + + // Store 'numSliceLoopIVs' before converting dst loop IVs to dims. + unsigned numSliceLoopIVs = sliceUnionCst.getNumDimIds(); + + // Convert any dst loop IVs which are symbol identifiers to dim identifiers. + sliceUnionCst.convertLoopIVSymbolsToDims(); + sliceUnion->clearBounds(); + sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap()); + sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap()); + + // Get slice bounds from slice union constraints 'sliceUnionCst'. + sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs, + opsA[0]->getContext(), &sliceUnion->lbs, + &sliceUnion->ubs); + + // Add slice bound operands of union. + SmallVector<Value, 4> sliceBoundOperands; + sliceUnionCst.getIdValues(numSliceLoopIVs, + sliceUnionCst.getNumDimAndSymbolIds(), + &sliceBoundOperands); + + // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'. + sliceUnion->ivs.clear(); + sliceUnionCst.getIdValues(0, numSliceLoopIVs, &sliceUnion->ivs); + + // Set loop nest insertion point to block start at 'loopDepth'. + sliceUnion->insertPoint = + isBackwardSlice + ? surroundingLoops[loopDepth - 1].getBody()->begin() + : std::prev(surroundingLoops[loopDepth - 1].getBody()->end()); + + // Give each bound its own copy of 'sliceBoundOperands' for subsequent + // canonicalization. + sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands); + sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands); + return success(); +} + +const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier"; +// Computes slice bounds by projecting out any loop IVs from +// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice +// bounds in 'sliceState' which represent the one loop nest's IVs in terms of +// the other loop nest's IVs, symbols and constants (using 'isBackwardsSlice'). +void mlir::getComputationSliceState( + Operation *depSourceOp, Operation *depSinkOp, + FlatAffineConstraints *dependenceConstraints, unsigned loopDepth, + bool isBackwardSlice, ComputationSliceState *sliceState) { + // Get loop nest surrounding src operation. + SmallVector<AffineForOp, 4> srcLoopIVs; + getLoopIVs(*depSourceOp, &srcLoopIVs); + unsigned numSrcLoopIVs = srcLoopIVs.size(); + + // Get loop nest surrounding dst operation. 
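  // As a rough illustration (not from this patch) of what the computed slice
  // bounds look like: for a producer/consumer pair such as
  //   affine.for %i = 0 to 100 { affine.store %v, %A[%i] : memref<100xf32> }
  //   affine.for %j = 0 to 100 { %u = affine.load %A[%j] : memref<100xf32> }
  // the backward slice of the store at loopDepth = 1 is bounded, per consumer
  // iteration, by roughly lb = (d0) -> (d0) and ub = (d0) -> (d0 + 1) applied
  // to %j, i.e. only the single producing iteration %i == %j is required.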
+ SmallVector<AffineForOp, 4> dstLoopIVs; + getLoopIVs(*depSinkOp, &dstLoopIVs); + unsigned numDstLoopIVs = dstLoopIVs.size(); + + assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) || + (isBackwardSlice && loopDepth <= numDstLoopIVs)); + + // Project out dimensions other than those up to 'loopDepth'. + unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth; + unsigned num = + isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth; + dependenceConstraints->projectOut(pos, num); + + // Add slice loop IV values to 'sliceState'. + unsigned offset = isBackwardSlice ? 0 : loopDepth; + unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs; + dependenceConstraints->getIdValues(offset, offset + numSliceLoopIVs, + &sliceState->ivs); + + // Set up lower/upper bound affine maps for the slice. + sliceState->lbs.resize(numSliceLoopIVs, AffineMap()); + sliceState->ubs.resize(numSliceLoopIVs, AffineMap()); + + // Get bounds for slice IVs in terms of other IVs, symbols, and constants. + dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs, + depSourceOp->getContext(), + &sliceState->lbs, &sliceState->ubs); + + // Set up bound operands for the slice's lower and upper bounds. + SmallVector<Value, 4> sliceBoundOperands; + unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolIds(); + for (unsigned i = 0; i < numDimsAndSymbols; ++i) { + if (i < offset || i >= offset + numSliceLoopIVs) { + sliceBoundOperands.push_back(dependenceConstraints->getIdValue(i)); + } + } + + // Give each bound its own copy of 'sliceBoundOperands' for subsequent + // canonicalization. + sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands); + sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands); + + // Set destination loop nest insertion point to block start at 'dstLoopDepth'. + sliceState->insertPoint = + isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin() + : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end()); + + llvm::SmallDenseSet<Value, 8> sequentialLoops; + if (isa<AffineLoadOp>(depSourceOp) && isa<AffineLoadOp>(depSinkOp)) { + // For read-read access pairs, clear any slice bounds on sequential loops. + // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'. + getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0], + &sequentialLoops); + } + // Clear all sliced loop bounds beginning at the first sequential loop, or + // first loop with a slice fusion barrier attribute.. + // TODO(andydavis, bondhugula) Use MemRef read/write regions instead of + // using 'kSliceFusionBarrierAttrName'. + auto getSliceLoop = [&](unsigned i) { + return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i]; + }; + for (unsigned i = 0; i < numSliceLoopIVs; ++i) { + Value iv = getSliceLoop(i).getInductionVar(); + if (sequentialLoops.count(iv) == 0 && + getSliceLoop(i).getAttr(kSliceFusionBarrierAttrName) == nullptr) + continue; + for (unsigned j = i; j < numSliceLoopIVs; ++j) { + sliceState->lbs[j] = AffineMap(); + sliceState->ubs[j] = AffineMap(); + } + break; + } +} + +/// Creates a computation slice of the loop nest surrounding 'srcOpInst', +/// updates the slice loop bounds with any non-null bound maps specified in +/// 'sliceState', and inserts this slice into the loop nest surrounding +/// 'dstOpInst' at loop depth 'dstLoopDepth'. +// TODO(andydavis,bondhugula): extend the slicing utility to compute slices that +// aren't necessarily a one-to-one relation b/w the source and destination. 
The
+// relation between the source and destination could be many-to-many in general.
+// TODO(andydavis,bondhugula): the slice computation is incorrect in the cases
+// where the dependence from the source to the destination does not cover the
+// entire destination index set. Subtract out the dependent destination
+// iterations from the destination index set and check for emptiness --- this is
+// one solution.
+AffineForOp
+mlir::insertBackwardComputationSlice(Operation *srcOpInst, Operation *dstOpInst,
+                                     unsigned dstLoopDepth,
+                                     ComputationSliceState *sliceState) {
+  // Get loop nest surrounding src operation.
+  SmallVector<AffineForOp, 4> srcLoopIVs;
+  getLoopIVs(*srcOpInst, &srcLoopIVs);
+  unsigned numSrcLoopIVs = srcLoopIVs.size();
+
+  // Get loop nest surrounding dst operation.
+  SmallVector<AffineForOp, 4> dstLoopIVs;
+  getLoopIVs(*dstOpInst, &dstLoopIVs);
+  unsigned dstLoopIVsSize = dstLoopIVs.size();
+  if (dstLoopDepth > dstLoopIVsSize) {
+    dstOpInst->emitError("invalid destination loop depth");
+    return AffineForOp();
+  }
+
+  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
+  SmallVector<unsigned, 4> positions;
+  // TODO(andydavis): This code is incorrect since srcLoopIVs can be 0-d.
+  findInstPosition(srcOpInst, srcLoopIVs[0].getOperation()->getBlock(),
+                   &positions);
+
+  // Clone src loop nest and insert it at the beginning of the operation block
+  // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
+  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
+  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
+  auto sliceLoopNest =
+      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
+
+  Operation *sliceInst =
+      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
+  // Get loop nest surrounding 'sliceInst'.
+  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
+  getLoopIVs(*sliceInst, &sliceSurroundingLoops);
+
+  // Sanity check.
+  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
+  (void)sliceSurroundingLoopsSize;
+  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
+  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
+  (void)sliceLoopLimit;
+  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
+
+  // Update loop bounds for loops in 'sliceLoopNest'.
+  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
+    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
+    if (AffineMap lbMap = sliceState->lbs[i])
+      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
+    if (AffineMap ubMap = sliceState->ubs[i])
+      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
+  }
+  return sliceLoopNest;
+}
+
+// Constructs a MemRefAccess, populating it with the memref, its indices, and
+// the op from 'loadOrStoreOpInst'.
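+// For example (illustrative only): for an access such as
+//   affine.load %A[%i, %j] : memref<100x100xf32>
+// 'memref' is bound to %A, 'opInst' to the load operation itself, and
+// 'indices' to {%i, %j}.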
+MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) { + if (auto loadOp = dyn_cast<AffineLoadOp>(loadOrStoreOpInst)) { + memref = loadOp.getMemRef(); + opInst = loadOrStoreOpInst; + auto loadMemrefType = loadOp.getMemRefType(); + indices.reserve(loadMemrefType.getRank()); + for (auto index : loadOp.getMapOperands()) { + indices.push_back(index); + } + } else { + assert(isa<AffineStoreOp>(loadOrStoreOpInst) && "load/store op expected"); + auto storeOp = dyn_cast<AffineStoreOp>(loadOrStoreOpInst); + opInst = loadOrStoreOpInst; + memref = storeOp.getMemRef(); + auto storeMemrefType = storeOp.getMemRefType(); + indices.reserve(storeMemrefType.getRank()); + for (auto index : storeOp.getMapOperands()) { + indices.push_back(index); + } + } +} + +unsigned MemRefAccess::getRank() const { + return memref->getType().cast<MemRefType>().getRank(); +} + +bool MemRefAccess::isStore() const { return isa<AffineStoreOp>(opInst); } + +/// Returns the nesting depth of this statement, i.e., the number of loops +/// surrounding this statement. +unsigned mlir::getNestingDepth(Operation &op) { + Operation *currOp = &op; + unsigned depth = 0; + while ((currOp = currOp->getParentOp())) { + if (isa<AffineForOp>(currOp)) + depth++; + } + return depth; +} + +/// Equal if both affine accesses are provably equivalent (at compile +/// time) when considering the memref, the affine maps and their respective +/// operands. The equality of access functions + operands is checked by +/// subtracting fully composed value maps, and then simplifying the difference +/// using the expression flattener. +/// TODO: this does not account for aliasing of memrefs. +bool MemRefAccess::operator==(const MemRefAccess &rhs) const { + if (memref != rhs.memref) + return false; + + AffineValueMap diff, thisMap, rhsMap; + getAccessMap(&thisMap); + rhs.getAccessMap(&rhsMap); + AffineValueMap::difference(thisMap, rhsMap, &diff); + return llvm::all_of(diff.getAffineMap().getResults(), + [](AffineExpr e) { return e == 0; }); +} + +/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB', +/// where each lists loops from outer-most to inner-most in loop nest. +unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) { + SmallVector<AffineForOp, 4> loopsA, loopsB; + getLoopIVs(A, &loopsA); + getLoopIVs(B, &loopsB); + + unsigned minNumLoops = std::min(loopsA.size(), loopsB.size()); + unsigned numCommonLoops = 0; + for (unsigned i = 0; i < minNumLoops; ++i) { + if (loopsA[i].getOperation() != loopsB[i].getOperation()) + break; + ++numCommonLoops; + } + return numCommonLoops; +} + +static Optional<int64_t> getMemoryFootprintBytes(Block &block, + Block::iterator start, + Block::iterator end, + int memorySpace) { + SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions; + + // Walk this 'affine.for' operation to gather all memory regions. + auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult { + if (!isa<AffineLoadOp>(opInst) && !isa<AffineStoreOp>(opInst)) { + // Neither load nor a store op. + return WalkResult::advance(); + } + + // Compute the memref region symbolic in any IVs enclosing this block. 
+    auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
+    if (failed(
+            region->compute(opInst,
+                            /*loopDepth=*/getNestingDepth(*block.begin())))) {
+      return opInst->emitError("error obtaining memory region\n");
+    }
+
+    auto it = regions.find(region->memref);
+    if (it == regions.end()) {
+      regions[region->memref] = std::move(region);
+    } else if (failed(it->second->unionBoundingBox(*region))) {
+      return opInst->emitWarning(
+          "getMemoryFootprintBytes: unable to perform a union on a memory "
+          "region");
+    }
+    return WalkResult::advance();
+  });
+  if (result.wasInterrupted())
+    return None;
+
+  int64_t totalSizeInBytes = 0;
+  for (const auto &region : regions) {
+    Optional<int64_t> size = region.second->getRegionSize();
+    if (!size.hasValue())
+      return None;
+    totalSizeInBytes += size.getValue();
+  }
+  return totalSizeInBytes;
+}
+
+Optional<int64_t> mlir::getMemoryFootprintBytes(AffineForOp forOp,
+                                                int memorySpace) {
+  auto *forInst = forOp.getOperation();
+  return ::getMemoryFootprintBytes(
+      *forInst->getBlock(), Block::iterator(forInst),
+      std::next(Block::iterator(forInst)), memorySpace);
+}
+
+/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
+/// at 'forOp'.
+void mlir::getSequentialLoops(AffineForOp forOp,
+                              llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
+  forOp.getOperation()->walk([&](Operation *op) {
+    if (auto innerFor = dyn_cast<AffineForOp>(op))
+      if (!isLoopParallel(innerFor))
+        sequentialLoops->insert(innerFor.getInductionVar());
+  });
+}
+
+/// Returns true if 'forOp' is parallel.
+bool mlir::isLoopParallel(AffineForOp forOp) {
+  // Collect all load and store ops in loop nest rooted at 'forOp'.
+  SmallVector<Operation *, 8> loadAndStoreOpInsts;
+  auto walkResult = forOp.walk([&](Operation *opInst) {
+    if (isa<AffineLoadOp>(opInst) || isa<AffineStoreOp>(opInst))
+      loadAndStoreOpInsts.push_back(opInst);
+    else if (!isa<AffineForOp>(opInst) && !isa<AffineTerminatorOp>(opInst) &&
+             !isa<AffineIfOp>(opInst) && !opInst->hasNoSideEffect())
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+
+  // Stop early if the loop has unknown ops with side effects.
+  if (walkResult.wasInterrupted())
+    return false;
+
+  // Dep check depth would be number of enclosing loops + 1.
+  unsigned depth = getNestingDepth(*forOp.getOperation()) + 1;
+
+  // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'.
+  for (auto *srcOpInst : loadAndStoreOpInsts) {
+    MemRefAccess srcAccess(srcOpInst);
+    for (auto *dstOpInst : loadAndStoreOpInsts) {
+      MemRefAccess dstAccess(dstOpInst);
+      FlatAffineConstraints dependenceConstraints;
+      DependenceResult result = checkMemrefAccessDependence(
+          srcAccess, dstAccess, depth, &dependenceConstraints,
+          /*dependenceComponents=*/nullptr);
+      if (result.value != DependenceResult::NoDependence)
+        return false;
+    }
+  }
+  return true;
+}
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
new file mode 100644
index 00000000000..1c7dbed5fac
--- /dev/null
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -0,0 +1,232 @@
+//===- VectorAnalysis.cpp - Analysis for Vectorization --------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/Utils.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/STLExtras.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+
+///
+/// Implements analysis functions specific to vectors which support
+/// the vectorization and vectorization materialization passes.
+///
+
+using namespace mlir;
+
+using llvm::SetVector;
+
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
+                                                   ArrayRef<int64_t> subShape) {
+  if (superShape.size() < subShape.size()) {
+    return Optional<SmallVector<int64_t, 4>>();
+  }
+
+  // Starting from the end, compute the integer divisors.
+  // Clear the boolean `divides` if integral division is not possible.
+  std::vector<int64_t> result;
+  result.reserve(superShape.size());
+  bool divides = true;
+  auto divide = [&divides, &result](int superSize, int subSize) {
+    assert(superSize > 0 && "superSize must be > 0");
+    assert(subSize > 0 && "subSize must be > 0");
+    divides &= (superSize % subSize == 0);
+    result.push_back(superSize / subSize);
+  };
+  functional::zipApply(
+      divide, SmallVector<int64_t, 8>{superShape.rbegin(), superShape.rend()},
+      SmallVector<int64_t, 8>{subShape.rbegin(), subShape.rend()});
+
+  // If integral division does not occur, return and let the caller decide.
+  if (!divides) {
+    return None;
+  }
+
+  // At this point we computed the ratio (in reverse) for the common size.
+  // Fill with the remaining entries from the super-vector shape (still in
+  // reverse).
+  int commonSize = subShape.size();
+  std::copy(superShape.rbegin() + commonSize, superShape.rend(),
+            std::back_inserter(result));
+
+  assert(result.size() == superShape.size() &&
+         "super to sub shape ratio is not of the same size as the super rank");
+
+  // Reverse again to get it back in the proper order and return.
+  return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
+}
+
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
+                                                   VectorType subVectorType) {
+  assert(superVectorType.getElementType() == subVectorType.getElementType() &&
+         "vector types must be of the same elemental type");
+  return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
+}
+
+/// Constructs a permutation map from memref indices to vector dimension.
+///
+/// The implementation uses the knowledge of the mapping of enclosing loop to
+/// vector dimension. `enclosingLoopToVectorDim` carries this information as a
+/// map with:
+/// - keys representing "vectorized enclosing loops";
+/// - values representing the corresponding vector dimension.
+/// The algorithm traverses "vectorized enclosing loops" and extracts the
+/// at-most-one MemRef index that is invariant along said loop. This index is
+/// guaranteed to be at most one by construction: otherwise the MemRef is not
+/// vectorizable.
+/// If this invariant index is found, it is added to the permutation_map at the
+/// proper vector dimension.
+/// If no index is found to be invariant, 0 is added to the permutation_map and +/// corresponds to a vector broadcast along that dimension. +/// +/// Returns an empty AffineMap if `enclosingLoopToVectorDim` is empty, +/// signalling that no permutation map can be constructed given +/// `enclosingLoopToVectorDim`. +/// +/// Examples can be found in the documentation of `makePermutationMap`, in the +/// header file. +static AffineMap makePermutationMap( + ArrayRef<Value> indices, + const DenseMap<Operation *, unsigned> &enclosingLoopToVectorDim) { + if (enclosingLoopToVectorDim.empty()) + return AffineMap(); + MLIRContext *context = + enclosingLoopToVectorDim.begin()->getFirst()->getContext(); + using functional::makePtrDynCaster; + using functional::map; + SmallVector<AffineExpr, 4> perm(enclosingLoopToVectorDim.size(), + getAffineConstantExpr(0, context)); + + for (auto kvp : enclosingLoopToVectorDim) { + assert(kvp.second < perm.size()); + auto invariants = getInvariantAccesses( + cast<AffineForOp>(kvp.first).getInductionVar(), indices); + unsigned numIndices = indices.size(); + unsigned countInvariantIndices = 0; + for (unsigned dim = 0; dim < numIndices; ++dim) { + if (!invariants.count(indices[dim])) { + assert(perm[kvp.second] == getAffineConstantExpr(0, context) && + "permutationMap already has an entry along dim"); + perm[kvp.second] = getAffineDimExpr(dim, context); + } else { + ++countInvariantIndices; + } + } + assert((countInvariantIndices == numIndices || + countInvariantIndices == numIndices - 1) && + "Vectorization prerequisite violated: at most 1 index may be " + "invariant wrt a vectorized loop"); + } + return AffineMap::get(indices.size(), 0, perm); +} + +/// Implementation detail that walks up the parents and records the ones with +/// the specified type. +/// TODO(ntv): could also be implemented as a collect parents followed by a +/// filter and made available outside this file. +template <typename T> +static SetVector<Operation *> getParentsOfType(Operation *op) { + SetVector<Operation *> res; + auto *current = op; + while (auto *parent = current->getParentOp()) { + if (auto typedParent = dyn_cast<T>(parent)) { + assert(res.count(parent) == 0 && "Already inserted"); + res.insert(parent); + } + current = parent; + } + return res; +} + +/// Returns the enclosing AffineForOp, from closest to farthest. +static SetVector<Operation *> getEnclosingforOps(Operation *op) { + return getParentsOfType<AffineForOp>(op); +} + +AffineMap mlir::makePermutationMap( + Operation *op, ArrayRef<Value> indices, + const DenseMap<Operation *, unsigned> &loopToVectorDim) { + DenseMap<Operation *, unsigned> enclosingLoopToVectorDim; + auto enclosingLoops = getEnclosingforOps(op); + for (auto *forInst : enclosingLoops) { + auto it = loopToVectorDim.find(forInst); + if (it != loopToVectorDim.end()) { + enclosingLoopToVectorDim.insert(*it); + } + } + return ::makePermutationMap(indices, enclosingLoopToVectorDim); +} + +bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, + VectorType subVectorType) { + // First, extract the vector type and distinguish between: + // a. ops that *must* lower a super-vector (i.e. vector.transfer_read, + // vector.transfer_write); and + // b. ops that *may* lower a super-vector (all other ops). + // The ops that *may* lower a super-vector only do so if the super-vector to + // sub-vector ratio exists. The ops that *must* lower a super-vector are + // explicitly checked for this property. 
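+  // For example (illustrative only): an op yielding vector<4x8xf32> operates
+  // on super-vectors of vector<2x8xf32> because the shape ratio {2, 1}
+  // exists, but not on super-vectors of vector<3x8xf32>, for which no
+  // integral ratio exists.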
+  /// TODO(ntv): there should be a single function for all ops to do this so we
+  /// do not have to special case. Maybe a trait, or just a method, unclear atm.
+  bool mustDivide = false;
+  (void)mustDivide;
+  VectorType superVectorType;
+  if (auto read = dyn_cast<vector::TransferReadOp>(op)) {
+    superVectorType = read.getVectorType();
+    mustDivide = true;
+  } else if (auto write = dyn_cast<vector::TransferWriteOp>(op)) {
+    superVectorType = write.getVectorType();
+    mustDivide = true;
+  } else if (op.getNumResults() == 0) {
+    if (!isa<ReturnOp>(op)) {
+      op.emitError("NYI: assuming only return operations can have 0 "
+                   "results at this point");
+    }
+    return false;
+  } else if (op.getNumResults() == 1) {
+    if (auto v = op.getResult(0)->getType().dyn_cast<VectorType>()) {
+      superVectorType = v;
+    } else {
+      // Not a vector type.
+      return false;
+    }
+  } else {
+    // Not a vector.transfer and has more than 1 result; fail hard for now to
+    // wake us up when something changes.
+    op.emitError("NYI: operation has more than 1 result");
+    return false;
+  }
+
+  // Get the ratio.
+  auto ratio = shapeRatio(superVectorType, subVectorType);
+
+  // Sanity check.
+  assert((ratio.hasValue() || !mustDivide) &&
+         "vector.transfer operation in which super-vector size is not an"
+         " integer multiple of sub-vector size");
+
+  // This catches cases that do not strictly require multiplicity but still
+  // aren't divisible by the sub-vector shape.
+  // This could be useful information if we wanted to reshape at the level of
+  // the vector type (but we would have to look at the compute and distinguish
+  // between parallel, reduction and possibly other cases).
+  if (!ratio.hasValue()) {
+    return false;
+  }
+
+  return true;
+}
diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp
new file mode 100644
index 00000000000..d4861b1a2e7
--- /dev/null
+++ b/mlir/lib/Analysis/Verifier.cpp
@@ -0,0 +1,266 @@
+//===- Verifier.cpp - MLIR Verifier Implementation ------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the verify() methods on the various IR types, performing
+// (potentially expensive) checks on the holistic structure of the code. This
+// can be used for detecting bugs in compiler transformations and hand written
+// .mlir files.
+//
+// The checks in this file are only for things that can occur as part of IR
+// transformations: e.g. violation of dominance information, malformed operation
+// attributes, etc. MLIR supports transformations moving IR through locally
+// invalid states (e.g. unlinking an operation from a block before re-inserting
+// it in a new place), but each transformation must complete with the IR in a
+// valid form.
+//
+// This should not check for things that are always wrong by construction (e.g.
+// attributes or other immutable structures that are incorrect), because those
+// are not mutable and can be checked at time of construction.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/Analysis/Dominance.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/Regex.h"
+
+using namespace mlir;
+
+namespace {
+/// This class encapsulates all the state used to verify an operation region.
+class OperationVerifier {
+public:
+  explicit OperationVerifier(MLIRContext *ctx)
+      : ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
+
+  /// Verify the given operation.
+  LogicalResult verify(Operation &op);
+
+  /// Returns the registered dialect for a dialect-specific attribute.
+  Dialect *getDialectForAttribute(const NamedAttribute &attr) {
+    assert(attr.first.strref().contains('.') && "expected dialect attribute");
+    auto dialectNamePair = attr.first.strref().split('.');
+    return ctx->getRegisteredDialect(dialectNamePair.first);
+  }
+
+  /// Returns true if the given string is valid to use as an identifier name.
+  bool isValidName(StringRef name) { return identifierRegex.match(name); }
+
+private:
+  /// Verify the given potentially nested region or block.
+  LogicalResult verifyRegion(Region &region);
+  LogicalResult verifyBlock(Block &block);
+  LogicalResult verifyOperation(Operation &op);
+
+  /// Verify the dominance within the given IR unit.
+  LogicalResult verifyDominance(Region &region);
+  LogicalResult verifyDominance(Operation &op);
+
+  /// Emit an error for the given block.
+  InFlightDiagnostic emitError(Block &bb, const Twine &message) {
+    // Take the location information for the first operation in the block.
+    if (!bb.empty())
+      return bb.front().emitError(message);
+
+    // Worst case, fall back to using the parent's location.
+    return mlir::emitError(bb.getParent()->getLoc(), message);
+  }
+
+  /// The current context for the verifier.
+  MLIRContext *ctx;
+
+  /// Dominance information for this operation, when checking dominance.
+  DominanceInfo *domInfo = nullptr;
+
+  /// Regex checker for attribute names.
+  llvm::Regex identifierRegex;
+
+  /// Mapping from dialect namespace to whether that dialect supports
+  /// unregistered operations.
+  llvm::StringMap<bool> dialectAllowsUnknownOps;
+};
+} // end anonymous namespace
+
+/// Verify the given operation.
+LogicalResult OperationVerifier::verify(Operation &op) {
+  // Verify the operation first.
+  if (failed(verifyOperation(op)))
+    return failure();
+
+  // Since everything looks structurally ok up to this point, we do a dominance
+  // check for any nested regions. We do this as a second pass since malformed
+  // CFGs can cause dominator analysis construction to crash and we want the
+  // verifier to be resilient to malformed code.
+  DominanceInfo theDomInfo(&op);
+  domInfo = &theDomInfo;
+  for (auto &region : op.getRegions())
+    if (failed(verifyDominance(region)))
+      return failure();
+
+  domInfo = nullptr;
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyRegion(Region &region) {
+  if (region.empty())
+    return success();
+
+  // Verify the first block has no predecessors.
+  auto *firstBB = &region.front();
+  if (!firstBB->hasNoPredecessors())
+    return mlir::emitError(region.getLoc(),
+                           "entry block of region may not have predecessors");
+
+  // Verify each of the blocks within the region.
+  for (auto &block : region)
+    if (failed(verifyBlock(block)))
+      return failure();
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyBlock(Block &block) {
+  for (auto arg : block.getArguments())
+    if (arg->getOwner() != &block)
+      return emitError(block, "block argument not owned by block");
+
+  // Verify that this block has a terminator.
+  if (block.empty())
+    return emitError(block, "block with no terminator");
+
+  // Verify the non-terminator operations separately so that we can verify
+  // they have no successors.
+  for (auto &op : llvm::make_range(block.begin(), std::prev(block.end()))) {
+    if (op.getNumSuccessors() != 0)
+      return op.emitError(
+          "operation with block successors must terminate its parent block");
+
+    if (failed(verifyOperation(op)))
+      return failure();
+  }
+
+  // Verify the terminator.
+  if (failed(verifyOperation(block.back())))
+    return failure();
+  if (block.back().isKnownNonTerminator())
+    return emitError(block, "block with no terminator");
+
+  // Verify that this block is not branching to a block of a different
+  // region.
+  for (Block *successor : block.getSuccessors())
+    if (successor->getParent() != block.getParent())
+      return block.back().emitOpError(
+          "branching to block of a different region");
+
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyOperation(Operation &op) {
+  // Check that operands are non-nil and structurally ok.
+  for (auto operand : op.getOperands())
+    if (!operand)
+      return op.emitError("null operand found");
+
+  /// Verify that all of the attributes are okay.
+  for (auto attr : op.getAttrs()) {
+    if (!identifierRegex.match(attr.first))
+      return op.emitError("invalid attribute name '") << attr.first << "'";
+
+    // Check for any optional dialect-specific attributes.
+    if (!attr.first.strref().contains('.'))
+      continue;
+    if (auto *dialect = getDialectForAttribute(attr))
+      if (failed(dialect->verifyOperationAttribute(&op, attr)))
+        return failure();
+  }
+
+  // If we can get operation info for this, check the custom hook.
+  auto *opInfo = op.getAbstractOperation();
+  if (opInfo && failed(opInfo->verifyInvariants(&op)))
+    return failure();
+
+  // Verify that all child regions are ok.
+  for (auto &region : op.getRegions())
+    if (failed(verifyRegion(region)))
+      return failure();
+
+  // If this is a registered operation, there is nothing left to do.
+  if (opInfo)
+    return success();
+
+  // Otherwise, verify that the parent dialect allows un-registered operations.
+  auto dialectPrefix = op.getName().getDialect();
+
+  // Check for an existing answer for the operation dialect.
+  auto it = dialectAllowsUnknownOps.find(dialectPrefix);
+  if (it == dialectAllowsUnknownOps.end()) {
+    // If the operation dialect is registered, query it directly.
+    if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
+      it = dialectAllowsUnknownOps
+               .try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
+               .first;
+    // Otherwise, conservatively allow unknown operations.
+    else
+      it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first;
+  }
+
+  if (!it->second) {
+    return op.emitError("unregistered operation '")
+           << op.getName() << "' found in dialect ('" << dialectPrefix
+           << "') that does not allow unknown operations";
+  }
+
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyDominance(Region &region) {
+  // Verify the dominance of each of the held operations.
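+  // A failure here means an operation uses a value whose definition does not
+  // properly dominate the use, e.g. (illustrative) a use that appears before
+  // its definition in the same block, or in a block not dominated by the
+  // defining block.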
+  for (auto &block : region)
+    for (auto &op : block)
+      if (failed(verifyDominance(op)))
+        return failure();
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyDominance(Operation &op) {
+  // Check that operands properly dominate this use.
+  for (unsigned operandNo = 0, e = op.getNumOperands(); operandNo != e;
+       ++operandNo) {
+    auto operand = op.getOperand(operandNo);
+    if (domInfo->properlyDominates(operand, &op))
+      continue;
+
+    auto diag = op.emitError("operand #")
+                << operandNo << " does not dominate this use";
+    if (auto *useOp = operand->getDefiningOp())
+      diag.attachNote(useOp->getLoc()) << "operand defined here";
+    return failure();
+  }
+
+  // Verify the dominance of each of the nested blocks within this operation.
+  for (auto &region : op.getRegions())
+    if (failed(verifyDominance(region)))
+      return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Entrypoint
+//===----------------------------------------------------------------------===//
+
+/// Perform (potentially expensive) checks of invariants, used to detect
+/// compiler bugs. On error, this reports the error through the MLIRContext and
+/// returns failure.
+LogicalResult mlir::verify(Operation *op) {
+  return OperationVerifier(op->getContext()).verify(*op);
+}
|