//===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/AffineMap.h" #include "AffineMapDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; namespace { // AffineExprConstantFolder evaluates an affine expression using constant // operands passed in 'operandConsts'. Returns an IntegerAttr attribute // representing the constant value of the affine expression evaluated on // constant 'operandConsts', or nullptr if it can't be folded. class AffineExprConstantFolder { public: AffineExprConstantFolder(unsigned numDims, ArrayRef operandConsts) : numDims(numDims), operandConsts(operandConsts) {} /// Attempt to constant fold the specified affine expr, or return null on /// failure. IntegerAttr constantFold(AffineExpr expr) { if (auto result = constantFoldImpl(expr)) return IntegerAttr::get(IndexType::get(expr.getContext()), *result); return nullptr; } private: Optional constantFoldImpl(AffineExpr expr) { switch (expr.getKind()) { case AffineExprKind::Add: return constantFoldBinExpr( expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; }); case AffineExprKind::Mul: return constantFoldBinExpr( expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; }); case AffineExprKind::Mod: return constantFoldBinExpr( expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); }); case AffineExprKind::FloorDiv: return constantFoldBinExpr( expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); }); case AffineExprKind::CeilDiv: return constantFoldBinExpr( expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); }); case AffineExprKind::Constant: return expr.cast().getValue(); case AffineExprKind::DimId: if (auto attr = operandConsts[expr.cast().getPosition()] .dyn_cast_or_null()) return attr.getInt(); return llvm::None; case AffineExprKind::SymbolId: if (auto attr = operandConsts[numDims + expr.cast().getPosition()] .dyn_cast_or_null()) return attr.getInt(); return llvm::None; } llvm_unreachable("Unknown AffineExpr"); } // TODO: Change these to operate on APInts too. Optional constantFoldBinExpr(AffineExpr expr, int64_t (*op)(int64_t, int64_t)) { auto binOpExpr = expr.cast(); if (auto lhs = constantFoldImpl(binOpExpr.getLHS())) if (auto rhs = constantFoldImpl(binOpExpr.getRHS())) return op(*lhs, *rhs); return llvm::None; } // The number of dimension operands in AffineMap containing this expression. unsigned numDims; // The constant valued operands used to evaluate this AffineExpr. ArrayRef operandConsts; }; } // end anonymous namespace /// Returns a single constant result affine map. AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) { return get(/*dimCount=*/0, /*symbolCount=*/0, {getAffineConstantExpr(val, context)}); } /// Returns an AffineMap representing a permutation. AffineMap AffineMap::getPermutationMap(ArrayRef permutation, MLIRContext *context) { assert(!permutation.empty() && "Cannot create permutation map from empty permutation vector"); SmallVector affExprs; for (auto index : permutation) affExprs.push_back(getAffineDimExpr(index, context)); auto m = std::max_element(permutation.begin(), permutation.end()); auto permutationMap = AffineMap::get(*m + 1, 0, affExprs); assert(permutationMap.isPermutation() && "Invalid permutation vector"); return permutationMap; } AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims, MLIRContext *context) { SmallVector dimExprs; dimExprs.reserve(numDims); for (unsigned i = 0; i < numDims; ++i) dimExprs.push_back(mlir::getAffineDimExpr(i, context)); return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs); } MLIRContext *AffineMap::getContext() const { return map->context; } bool AffineMap::isIdentity() const { if (getNumDims() != getNumResults()) return false; ArrayRef results = getResults(); for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) { auto expr = results[i].dyn_cast(); if (!expr || expr.getPosition() != i) return false; } return true; } bool AffineMap::isEmpty() const { return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0; } bool AffineMap::isSingleConstant() const { return getNumResults() == 1 && getResult(0).isa(); } int64_t AffineMap::getSingleConstantResult() const { assert(isSingleConstant() && "map must have a single constant result"); return getResult(0).cast().getValue(); } unsigned AffineMap::getNumDims() const { assert(map && "uninitialized map storage"); return map->numDims; } unsigned AffineMap::getNumSymbols() const { assert(map && "uninitialized map storage"); return map->numSymbols; } unsigned AffineMap::getNumResults() const { assert(map && "uninitialized map storage"); return map->results.size(); } unsigned AffineMap::getNumInputs() const { assert(map && "uninitialized map storage"); return map->numDims + map->numSymbols; } ArrayRef AffineMap::getResults() const { assert(map && "uninitialized map storage"); return map->results; } AffineExpr AffineMap::getResult(unsigned idx) const { assert(map && "uninitialized map storage"); return map->results[idx]; } /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, /// true otherwise. LogicalResult AffineMap::constantFold(ArrayRef operandConstants, SmallVectorImpl &results) const { assert(getNumInputs() == operandConstants.size()); // Fold each of the result expressions. AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); // Constant fold each AffineExpr in AffineMap and add to 'results'. for (auto expr : getResults()) { auto folded = exprFolder.constantFold(expr); // If we didn't fold to a constant, then folding fails. if (!folded) return failure(); results.push_back(folded); } assert(results.size() == getNumResults() && "constant folding produced the wrong number of results"); return success(); } /// Walk all of the AffineExpr's in this mapping. Each node in an expression /// tree is visited in postorder. void AffineMap::walkExprs(std::function callback) const { for (auto expr : getResults()) expr.walk(callback); } /// This method substitutes any uses of dimensions and symbols (e.g. /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified /// expression mapping. Because this can be used to eliminate dims and /// symbols, the client needs to specify the number of dims and symbols in /// the result. The returned map always has the same number of results. AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef dimReplacements, ArrayRef symReplacements, unsigned numResultDims, unsigned numResultSyms) { SmallVector results; results.reserve(getNumResults()); for (auto expr : getResults()) results.push_back( expr.replaceDimsAndSymbols(dimReplacements, symReplacements)); return get(numResultDims, numResultSyms, results); } AffineMap AffineMap::compose(AffineMap map) { assert(getNumDims() == map.getNumResults() && "Number of results mismatch"); // Prepare `map` by concatenating the symbols and rewriting its exprs. unsigned numDims = map.getNumDims(); unsigned numSymbolsThisMap = getNumSymbols(); unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols(); SmallVector newDims(numDims); for (unsigned idx = 0; idx < numDims; ++idx) { newDims[idx] = getAffineDimExpr(idx, getContext()); } SmallVector newSymbols(numSymbols); for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) { newSymbols[idx - numSymbolsThisMap] = getAffineSymbolExpr(idx, getContext()); } auto newMap = map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols); SmallVector exprs; exprs.reserve(getResults().size()); for (auto expr : getResults()) exprs.push_back(expr.compose(newMap)); return AffineMap::get(numDims, numSymbols, exprs); } bool AffineMap::isProjectedPermutation() { if (getNumSymbols() > 0) return false; SmallVector seen(getNumInputs(), false); for (auto expr : getResults()) { if (auto dim = expr.dyn_cast()) { if (seen[dim.getPosition()]) return false; seen[dim.getPosition()] = true; continue; } return false; } return true; } bool AffineMap::isPermutation() { if (getNumDims() != getNumResults()) return false; return isProjectedPermutation(); } AffineMap AffineMap::getSubMap(ArrayRef resultPos) { SmallVector exprs; exprs.reserve(resultPos.size()); for (auto idx : resultPos) { exprs.push_back(getResult(idx)); } return AffineMap::get(getNumDims(), getNumSymbols(), exprs); } AffineMap mlir::simplifyAffineMap(AffineMap map) { SmallVector exprs; for (auto e : map.getResults()) { exprs.push_back( simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols())); } return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs); } AffineMap mlir::inversePermutation(AffineMap map) { if (!map) return map; assert(map.getNumSymbols() == 0 && "expected map without symbols"); SmallVector exprs(map.getNumDims()); for (auto en : llvm::enumerate(map.getResults())) { auto expr = en.value(); // Skip non-permutations. if (auto d = expr.dyn_cast()) { if (exprs[d.getPosition()]) continue; exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext()); } } SmallVector seenExprs; seenExprs.reserve(map.getNumDims()); for (auto expr : exprs) if (expr) seenExprs.push_back(expr); if (seenExprs.size() != map.getNumInputs()) return AffineMap(); return AffineMap::get(map.getNumResults(), 0, seenExprs); } AffineMap mlir::concatAffineMaps(ArrayRef maps) { unsigned numResults = 0; for (auto m : maps) numResults += m ? m.getNumResults() : 0; unsigned numDims = 0; SmallVector results; results.reserve(numResults); for (auto m : maps) { if (!m) continue; assert(m.getNumSymbols() == 0 && "expected map without symbols"); results.append(m.getResults().begin(), m.getResults().end()); numDims = std::max(m.getNumDims(), numDims); } return numDims == 0 ? AffineMap() : AffineMap::get(numDims, 0, results); }