diff options
Diffstat (limited to 'mlir/lib/Dialect')
| -rw-r--r-- | mlir/lib/Dialect/AffineOps/AffineOps.cpp | 1764 | ||||
| -rw-r--r-- | mlir/lib/Dialect/AffineOps/CMakeLists.txt | 10 | ||||
| -rw-r--r-- | mlir/lib/Dialect/AffineOps/DialectRegistration.cpp | 22 | ||||
| -rw-r--r-- | mlir/lib/Dialect/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp | 12 |
5 files changed, 1803 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp new file mode 100644 index 00000000000..7db3fa07c52 --- /dev/null +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -0,0 +1,1764 @@ +//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/Support/Debug.h" +using namespace mlir; +using llvm::dbgs; + +#define DEBUG_TYPE "affine-analysis" + +//===----------------------------------------------------------------------===// +// AffineOpsDialect +//===----------------------------------------------------------------------===// + +AffineOpsDialect::AffineOpsDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineLoadOp, + AffineStoreOp, +#define GET_OP_LIST +#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc" + >(); +} + +/// A utility 
function to check if a given region is attached to a function. +static bool isFunctionRegion(Region *region) { + return llvm::isa<FuncOp>(region->getParentOp()); +} + +/// A utility function to check if a value is defined at the top level of a +/// function. A value defined at the top level is always a valid symbol. +bool mlir::isTopLevelSymbol(Value *value) { + if (auto *arg = dyn_cast<BlockArgument>(value)) + return isFunctionRegion(arg->getOwner()->getParent()); + return isFunctionRegion(value->getDefiningOp()->getParentRegion()); +} + +// Value can be used as a dimension id if it is valid as a symbol, or +// it is an induction variable, or it is a result of affine apply operation +// with dimension id arguments. +bool mlir::isValidDim(Value *value) { + // The value must be an index type. + if (!value->getType().isIndex()) + return false; + + if (auto *op = value->getDefiningOp()) { + // Top level operation or constant operation is ok. + if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op)) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto applyOp = dyn_cast<AffineApplyOp>(op)) + return applyOp.isValidDim(); + // The dim op is okay if its operand memref/tensor is defined at the top + // level. + if (auto dimOp = dyn_cast<DimOp>(op)) + return isTopLevelSymbol(dimOp.getOperand()); + return false; + } + // This value is a block argument (which also includes 'affine.for' loop IVs). + return true; +} + +// Value can be used as a symbol if it is a constant, or it is defined at +// the top level, or it is a result of affine apply operation with symbol +// arguments. +bool mlir::isValidSymbol(Value *value) { + // The value must be an index type. + if (!value->getType().isIndex()) + return false; + + if (auto *op = value->getDefiningOp()) { + // Top level operation or constant operation is ok. 
+ if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op)) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto applyOp = dyn_cast<AffineApplyOp>(op)) + return applyOp.isValidSymbol(); + // The dim op is okay if its operand memref/tensor is defined at the top + // level. + if (auto dimOp = dyn_cast<DimOp>(op)) + return isTopLevelSymbol(dimOp.getOperand()); + return false; + } + // Otherwise, check that the value is a top level symbol. + return isTopLevelSymbol(value); +} + +// Returns true if 'value' is a valid index to an affine operation (e.g. +// affine.load, affine.store, affine.dma_start, affine.dma_wait). +// Returns false otherwise. +static bool isValidAffineIndexOperand(Value *value) { + return isValidDim(value) || isValidSymbol(value); +} + +/// Utility function to verify that a set of operands are valid dimension and +/// symbol identifiers. The operands should be layed out such that the dimension +/// operands are before the symbol operands. This function returns failure if +/// there was an invalid operand. An operation is provided to emit any necessary +/// errors. 
template <typename OpTy>
static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
                              unsigned numDims) {
  unsigned opIt = 0;
  for (auto *operand : operands) {
    // The first `numDims` operands must be valid dims; the rest must be valid
    // symbols.
    if (opIt++ < numDims) {
      if (!isValidDim(operand))
        return op.emitOpError("operand cannot be used as a dimension id");
    } else if (!isValidSymbol(operand)) {
      return op.emitOpError("operand cannot be used as a symbol");
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//

/// Builds an affine.apply op applying `map` to `operands`; results are all of
/// index type, one per map result.
void AffineApplyOp::build(Builder *builder, OperationState *result,
                          AffineMap map, ArrayRef<Value *> operands) {
  result->addOperands(operands);
  result->types.append(map.getNumResults(), builder->getIndexType());
  result->addAttribute("map", builder->getAffineMapAttr(map));
}

/// Parses 'affine.apply <map> (dims)[symbols] {attrs}'.
ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
  auto &builder = parser->getBuilder();
  auto affineIntTy = builder.getIndexType();

  AffineMapAttr mapAttr;
  unsigned numDims;
  if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
      parseDimAndSymbolList(parser, result->operands, numDims) ||
      parser->parseOptionalAttributeDict(result->attributes))
    return failure();
  auto map = mapAttr.getValue();

  // Operand list must match the map's (dims, symbols) signature exactly.
  if (map.getNumDims() != numDims ||
      numDims + map.getNumSymbols() != result->operands.size()) {
    return parser->emitError(parser->getNameLoc(),
                             "dimension or symbol index mismatch");
  }

  result->types.append(map.getNumResults(), affineIntTy);
  return success();
}

/// Prints the op in the same form `parse` accepts.
void AffineApplyOp::print(OpAsmPrinter *p) {
  *p << "affine.apply " << getAttr("map");
  printDimAndSymbolList(operand_begin(), operand_end(),
                        getAffineMap().getNumDims(), p);
  // The map is printed inline above, so elide it from the attribute dict.
  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"});
}

LogicalResult AffineApplyOp::verify() {
  // Check that affine map attribute was specified.
  auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
  if (!affineMapAttr)
    return emitOpError("requires an affine map");

  // Check input and output dimensions match.
  auto map = affineMapAttr.getValue();

  // Verify that operand count matches affine map dimension and symbol count.
  if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
    return emitOpError(
        "operand count and affine map dimension and symbol count must match");

  // Verify that all operands are of `index` type.
  for (Type t : getOperandTypes()) {
    if (!t.isIndex())
      return emitOpError("operands must be of type 'index'");
  }

  if (!getResult()->getType().isIndex())
    return emitOpError("result must be of type 'index'");

  // Verify that the operands are valid dimension and symbol identifiers.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
                                           map.getNumDims())))
    return failure();

  // Verify that the map only produces one result.
  if (map.getNumResults() != 1)
    return emitOpError("mapping must produce one value");

  return success();
}

// The result of the affine apply operation can be used as a dimension id if
// all the operands are valid dimension ids.
bool AffineApplyOp::isValidDim() {
  return llvm::all_of(getOperands(),
                      [](Value *op) { return mlir::isValidDim(op); });
}

// The result of the affine apply operation can be used as a symbol if all the
// operands are symbols.
bool AffineApplyOp::isValidSymbol() {
  return llvm::all_of(getOperands(),
                      [](Value *op) { return mlir::isValidSymbol(op); });
}

OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
  auto map = getAffineMap();

  // Fold dims and symbols to existing values: an identity-like single dim or
  // symbol expression folds directly to the corresponding operand.
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast<AffineDimExpr>())
    return getOperand(dim.getPosition());
  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
    return getOperand(map.getNumDims() + sym.getPosition());

  // Otherwise, default to folding the map (succeeds only when all operands
  // are constant attributes).
  SmallVector<Attribute, 1> result;
  if (failed(map.constantFold(operands, result)))
    return {};
  return result[0];
}

namespace {
/// An `AffineApplyNormalizer` is a helper class that is not visible to the user
/// and supports renumbering operands of AffineApplyOp. This acts as a
/// reindexing map of Value* to positional dims or symbols and allows
/// simplifications such as:
///
/// ```mlir
/// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0)
/// ```
///
/// into:
///
/// ```mlir
/// %1 = affine.apply () -> (0)
/// ```
struct AffineApplyNormalizer {
  AffineApplyNormalizer(AffineMap map, ArrayRef<Value *> operands);

  /// Returns the AffineMap resulting from normalization.
  AffineMap getAffineMap() { return affineMap; }

  /// Returns the operands in the order expected by `affineMap`: renumbered
  /// dims first, then the concatenated symbols.
  SmallVector<Value *, 8> getOperands() {
    SmallVector<Value *, 8> res(reorderedDims);
    res.append(concatenatedSymbols.begin(), concatenatedSymbols.end());
    return res;
  }

private:
  /// Helper function to insert `v` into the coordinate system of the current
  /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding
  /// renumbered position.
  AffineDimExpr renumberOneDim(Value *v);

  /// Given an `other` normalizer, this rewrites `other.affineMap` in the
  /// coordinate system of the current AffineApplyNormalizer.
  /// Returns the rewritten AffineMap and updates the dims and symbols of
  /// `this`.
  AffineMap renumber(const AffineApplyNormalizer &other);

  /// Maps of Value* to position in `affineMap`.
  DenseMap<Value *, unsigned> dimValueToPosition;

  /// Ordered dims and symbols matching positional dims and symbols in
  /// `affineMap`.
  SmallVector<Value *, 8> reorderedDims;
  SmallVector<Value *, 8> concatenatedSymbols;

  AffineMap affineMap;

  /// Used with RAII to control the depth at which AffineApply are composed
  /// recursively. Only accepts depth 1 for now to allow a behavior where a
  /// newly composed AffineApplyOp does not increase the length of the chain of
  /// AffineApplyOps. Full composition is implemented iteratively on top of
  /// this behavior.
  static unsigned &affineApplyDepth() {
    static thread_local unsigned depth = 0;
    return depth;
  }
  static constexpr unsigned kMaxAffineApplyDepth = 1;

  // Private default ctor: only the public (map, operands) ctor may construct;
  // it bumps the depth counter, matched by the decrement in the dtor.
  AffineApplyNormalizer() { affineApplyDepth()++; }

public:
  ~AffineApplyNormalizer() { affineApplyDepth()--; }
};
} // end anonymous namespace.

AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
  DenseMap<Value *, unsigned>::iterator iterPos;
  bool inserted = false;
  // Assign the next free position iff `v` has not been seen before.
  std::tie(iterPos, inserted) =
      dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
  if (inserted) {
    reorderedDims.push_back(v);
  }
  return getAffineDimExpr(iterPos->second, v->getContext())
      .cast<AffineDimExpr>();
}

AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
  // Remap each of `other`'s dims into this normalizer's coordinate system.
  SmallVector<AffineExpr, 8> dimRemapping;
  for (auto *v : other.reorderedDims) {
    auto kvp = other.dimValueToPosition.find(v);
    if (dimRemapping.size() <= kvp->second)
      dimRemapping.resize(kvp->second + 1);
    dimRemapping[kvp->second] = renumberOneDim(kvp->first);
  }
  // Symbols are simply concatenated: shift `other`'s symbols past ours.
  unsigned numSymbols = concatenatedSymbols.size();
  unsigned numOtherSymbols = other.concatenatedSymbols.size();
  SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
  for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
    symRemapping[idx] =
        getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
  }
  concatenatedSymbols.insert(concatenatedSymbols.end(),
                             other.concatenatedSymbols.begin(),
                             other.concatenatedSymbols.end());
  auto map = other.affineMap;
  return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
                                   dimRemapping.size(), symRemapping.size());
}

// Gather the positions of the operands that are produced by an AffineApplyOp.
static llvm::SetVector<unsigned>
indicesFromAffineApplyOp(ArrayRef<Value *> operands) {
  llvm::SetVector<unsigned> res;
  for (auto en : llvm::enumerate(operands))
    if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp()))
      res.insert(en.index());
  return res;
}

// Support the special case of a symbol coming from an AffineApplyOp that needs
// to be composed into the current AffineApplyOp.
// This case is handled by rewriting all such symbols into dims for the purpose
// of allowing mathematical AffineMap composition.
// Returns an AffineMap where symbols that come from an AffineApplyOp have been
// rewritten as dims and are ordered after the original dims.
// TODO(andydavis,ntv): This promotion makes AffineMap lose track of which
// symbols are represented as dims. This loss is static but can still be
// recovered dynamically (with `isValidSymbol`). Still this is annoying for the
// semi-affine map case. A dynamic canonicalization of all dims that are valid
// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
// results in better simplifications and foldings. But we should evaluate
// whether this behavior is what we really want after using more.
static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
                                              ArrayRef<Value *> symbols) {
  if (symbols.empty()) {
    return map;
  }

  // Sanity check on symbols.
  for (auto *sym : symbols) {
    assert(isValidSymbol(sym) && "Expected only valid symbols");
    (void)sym;
  }

  // Extract the symbol positions that come from an AffineApplyOp and
  // needs to be rewritten as dims.
  auto symPositions = indicesFromAffineApplyOp(symbols);
  if (symPositions.empty()) {
    return map;
  }

  // Create the new map by replacing each symbol at pos by the next new dim.
  unsigned numDims = map.getNumDims();
  unsigned numSymbols = map.getNumSymbols();
  unsigned numNewDims = 0;
  unsigned numNewSymbols = 0;
  SmallVector<AffineExpr, 8> symReplacements(numSymbols);
  for (unsigned i = 0; i < numSymbols; ++i) {
    // Promoted symbols become new dims appended after the original dims;
    // remaining symbols are renumbered compactly from zero.
    symReplacements[i] =
        symPositions.count(i) > 0
            ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
            : getAffineSymbolExpr(numNewSymbols++, map.getContext());
  }
  assert(numSymbols >= numNewDims);
  AffineMap newMap = map.replaceDimsAndSymbols(
      {}, symReplacements, numDims + numNewDims, numNewSymbols);

  return newMap;
}

/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
/// keep a correspondence between the mathematical `map` and the `operands` of
/// a given AffineApplyOp. This correspondence is maintained by iterating over
/// the operands and forming an `auxiliaryMap` that can be composed
/// mathematically with `map`. To keep this correspondence in cases where
/// symbols are produced by affine.apply operations, we perform a local rewrite
/// of symbols as dims.
///
/// Rationale for locally rewriting symbols as dims:
/// ================================================
/// The mathematical composition of AffineMap must always concatenate symbols
/// because it does not have enough information to do otherwise. For example,
/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
///
/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
/// applied to the same mlir::Value* for both s0 and s1.
/// As a consequence mathematical composition of AffineMap always concatenates
/// symbols.
///
/// When AffineMaps are used in AffineApplyOp however, they may specify
/// composition via symbols, which is ambiguous mathematically. This corner case
/// is handled by locally rewriting such symbols that come from AffineApplyOp
/// into dims and composing through dims.
/// TODO(andydavis, ntv): Composition via symbols comes at a significant code
/// complexity. Alternatively we should investigate whether we want to
/// explicitly disallow symbols coming from affine.apply and instead force the
/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
/// extra API calls for such uses, which haven't popped up until now) and the
/// benefit potentially big: simpler and more maintainable code for a
/// non-trivial, recursive, procedure.
AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
                                             ArrayRef<Value *> operands)
    : AffineApplyNormalizer() {
  static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
  assert(map.getNumInputs() == operands.size() &&
         "number of operands does not match the number of map inputs");

  LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));

  // Promote symbols that come from an AffineApplyOp to dims by rewriting the
  // map to always refer to:
  // (dims, symbols coming from AffineApplyOp, other symbols).
  // The order of operands can remain unchanged.
  // This is a simplification that relies on 2 ordering properties:
  // 1. rewritten symbols always appear after the original dims in the map;
  // 2. operands are traversed in order and either dispatched to:
  // a. auxiliaryExprs (dims and symbols rewritten as dims);
  // b. concatenatedSymbols (all other symbols)
  // This allows operand order to remain unchanged.
  unsigned numDimsBeforeRewrite = map.getNumDims();
  map = promoteComposedSymbolsAsDims(map,
                                     operands.take_back(map.getNumSymbols()));

  LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));

  SmallVector<AffineExpr, 8> auxiliaryExprs;
  // Only compose one level deep: once the depth guard trips, fall back to
  // plain dim/symbol dispatch without recursing into producing applies.
  bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
  // We fully spell out the 2 cases below. In this particular instance a little
  // code duplication greatly improves readability.
  // Note that the first branch would disappear if we only supported full
  // composition (i.e. infinite kMaxAffineApplyDepth).
  if (!furtherCompose) {
    // 1. Only dispatch dims or symbols.
    for (auto en : llvm::enumerate(operands)) {
      auto *t = en.value();
      assert(t->getType().isIndex());
      bool isDim = (en.index() < map.getNumDims());
      if (isDim) {
        // a. The mathematical composition of AffineMap composes dims.
        auxiliaryExprs.push_back(renumberOneDim(t));
      } else {
        // b. The mathematical composition of AffineMap concatenates symbols.
        //    We do the same for symbol operands.
        concatenatedSymbols.push_back(t);
      }
    }
  } else {
    assert(numDimsBeforeRewrite <= operands.size());
    // 2. Compose AffineApplyOps and dispatch dims or symbols.
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      auto *t = operands[i];
      auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp());
      if (affineApply) {
        // a. Compose affine.apply operations.
        LLVM_DEBUG(affineApply.getOperation()->print(
            dbgs() << "\nCompose AffineApplyOp recursively: "));
        AffineMap affineApplyMap = affineApply.getAffineMap();
        SmallVector<Value *, 8> affineApplyOperands(
            affineApply.getOperands().begin(), affineApply.getOperands().end());
        AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);

        LLVM_DEBUG(normalizer.affineMap.print(
            dbgs() << "\nRenumber into current normalizer: "));

        auto renumberedMap = renumber(normalizer);

        LLVM_DEBUG(
            renumberedMap.print(dbgs() << "\nRecursive composition yields: "));

        auxiliaryExprs.push_back(renumberedMap.getResult(0));
      } else {
        if (i < numDimsBeforeRewrite) {
          // b. The mathematical composition of AffineMap composes dims.
          auxiliaryExprs.push_back(renumberOneDim(t));
        } else {
          // c. The mathematical composition of AffineMap concatenates symbols.
          //    We do the same for symbol operands.
          concatenatedSymbols.push_back(t);
        }
      }
    }
  }

  // Early exit if `map` is already composed.
  if (auxiliaryExprs.empty()) {
    affineMap = map;
    return;
  }

  assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
         "Unexpected number of concatenated symbols");
  auto numDims = dimValueToPosition.size();
  auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
  auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs);

  LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
  LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
  LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));

  // TODO(andydavis,ntv): Disabling simplification results in major speed gains.
  // Another option is to cache the results as it is expected a lot of redundant
  // work is performed in practice.
  affineMap = simplifyAffineMap(map.compose(auxiliaryMap));

  LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
  LLVM_DEBUG(dbgs() << "\n");
}

/// Implements `map` and `operands` composition and simplification to support
/// `makeComposedAffineApply`. This can be called to achieve the same effects
/// on `map` and `operands` without creating an AffineApplyOp that needs to be
/// immediately deleted.
+static void composeAffineMapAndOperands(AffineMap *map, + SmallVectorImpl<Value *> *operands) { + AffineApplyNormalizer normalizer(*map, *operands); + auto normalizedMap = normalizer.getAffineMap(); + auto normalizedOperands = normalizer.getOperands(); + canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); + *map = normalizedMap; + *operands = normalizedOperands; + assert(*map); +} + +void mlir::fullyComposeAffineMapAndOperands( + AffineMap *map, SmallVectorImpl<Value *> *operands) { + while (llvm::any_of(*operands, [](Value *v) { + return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp()); + })) { + composeAffineMapAndOperands(map, operands); + } +} + +AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, + AffineMap map, + ArrayRef<Value *> operands) { + AffineMap normalizedMap = map; + SmallVector<Value *, 8> normalizedOperands(operands.begin(), operands.end()); + composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); + assert(normalizedMap); + return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands); +} + +// A symbol may appear as a dim in affine.apply operations. This function +// canonicalizes dims that are valid symbols into actual symbols. 
static void
canonicalizePromotedSymbols(AffineMap *map,
                            llvm::SmallVectorImpl<Value *> *operands) {
  if (!map || operands->empty())
    return;

  assert(map->getNumInputs() == operands->size() &&
         "map inputs must match number of operands");

  auto *context = map->getContext();
  SmallVector<Value *, 8> resultOperands;
  resultOperands.reserve(operands->size());
  SmallVector<Value *, 8> remappedSymbols;
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = map->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
  for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) {
    if (i < map->getNumDims()) {
      if (isValidSymbol((*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it:
        // rewrite the dim as a new symbol appended after the existing ones.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back((*operands)[i]);
      } else {
        // Keep as a dim, compacting dim numbering.
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
      }
    } else {
      // Existing symbol operands are kept in place.
      resultOperands.push_back((*operands)[i]);
    }
  }

  // Promoted symbols go after the original symbols in the operand list.
  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                    oldNumSyms + nextSym);

  assert(map->getNumInputs() == operands->size() &&
         "map inputs must match number of operands");
}

void mlir::canonicalizeMapAndOperands(
    AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
  if (!map || operands->empty())
    return;

  assert(map->getNumInputs() == operands->size() &&
         "map inputs must match number of operands");

  canonicalizePromotedSymbols(map, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(map->getNumDims());
  llvm::SmallBitVector usedSyms(map->getNumSymbols());
  map->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = map->getContext();

  SmallVector<Value *, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Drop unused dims and deduplicate repeated dim operands, renumbering the
  // survivors compactly from zero.
  llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        // Repeated operand: reuse the previously assigned dim position.
        dimRemapping[i] = it->second;
      }
    }
  }
  // Same treatment for symbols (symbol operands follow the dim operands).
  llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
    if (usedSyms[i]) {
      auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
      if (it == seenSymbols.end()) {
        symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
        resultOperands.push_back((*operands)[i + map->getNumDims()]);
        seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()],
                                          symRemapping[i]));
      } else {
        symRemapping[i] = it->second;
      }
    }
  }
  *map =
      map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
  *operands = resultOperands;
}

namespace {
/// Simplify AffineApply operations.
///
struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
  using OpRewritePattern<AffineApplyOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(AffineApplyOp apply,
                                     PatternRewriter &rewriter) const override {
    auto map = apply.getAffineMap();

    AffineMap oldMap = map;
    SmallVector<Value *, 8> resultOperands(apply.getOperands());
    composeAffineMapAndOperands(&map, &resultOperands);
    // Only rewrite if composition actually changed the map; otherwise this
    // pattern would loop forever.
    if (map == oldMap)
      return matchFailure();

    rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
    return matchSuccess();
  }
};
} // end anonymous namespace.

void AffineApplyOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineApply>(context);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

namespace {
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// into the root operation directly.
struct MemRefCastFolder : public RewritePattern {
  /// The rootOpName is the name of the root operation to match against.
  MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
      : RewritePattern(rootOpName, 1, context) {}

  PatternMatchResult match(Operation *op) const override {
    // Match if any operand is produced by a memref_cast.
    for (auto *operand : op->getOperands())
      if (matchPattern(operand, m_Op<MemRefCastOp>()))
        return matchSuccess();

    return matchFailure();
  }

  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    // Replace each cast-produced operand with the cast's source in place.
    for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
      if (auto *memref = op->getOperand(i)->getDefiningOp())
        if (auto cast = dyn_cast<MemRefCastOp>(memref))
          op->setOperand(i, cast.getOperand());
    rewriter.updatedRootInPlace(op);
  }
};

} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// AffineDmaStartOp
//===----------------------------------------------------------------------===//

// TODO(b/133776335) Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_start op. Operand layout (matched by the accessors):
/// src memref, src map indices, dst memref, dst map indices, tag memref,
/// tag map indices, number of elements, then optionally stride and
/// elements-per-stride (added together only when `stride` is non-null).
void AffineDmaStartOp::build(Builder *builder, OperationState *result,
                             Value *srcMemRef, AffineMap srcMap,
                             ArrayRef<Value *> srcIndices, Value *destMemRef,
                             AffineMap dstMap, ArrayRef<Value *> destIndices,
                             Value *tagMemRef, AffineMap tagMap,
                             ArrayRef<Value *> tagIndices, Value *numElements,
                             Value *stride, Value *elementsPerStride) {
  result->addOperands(srcMemRef);
  result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap));
  result->addOperands(srcIndices);
  result->addOperands(destMemRef);
  result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap));
  result->addOperands(destIndices);
  result->addOperands(tagMemRef);
  result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
  result->addOperands(tagIndices);
  result->addOperands(numElements);
  if (stride) {
    result->addOperands({stride, elementsPerStride});
  }
}

/// Prints 'affine.dma_start src[map], dst[map], tag[map], numElements
/// (, stride, elementsPerStride)? : srcType, dstType, tagType'.
void AffineDmaStartOp::print(OpAsmPrinter *p) {
  *p << "affine.dma_start " << *getSrcMemRef() << '[';
  SmallVector<Value *, 8> operands(getSrcIndices());
  p->printAffineMapOfSSAIds(getSrcMapAttr(), operands);
  *p << "], " << *getDstMemRef() << '[';
  operands.assign(getDstIndices().begin(), getDstIndices().end());
  p->printAffineMapOfSSAIds(getDstMapAttr(), operands);
  *p << "], " << *getTagMemRef() << '[';
  operands.assign(getTagIndices().begin(), getTagIndices().end());
  p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
  *p << "], " << *getNumElements();
  if (isStrided()) {
    *p << ", " << *getStride();
    *p << ", " << *getNumElementsPerStride();
  }
  *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
     << getTagMemRefType();
}

// Parse AffineDmaStartOp.
+// Ex: +// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, +// %stride, %num_elt_per_stride +// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> +// +ParseResult AffineDmaStartOp::parse(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType srcMemRefInfo; + AffineMapAttr srcMapAttr; + SmallVector<OpAsmParser::OperandType, 4> srcMapOperands; + OpAsmParser::OperandType dstMemRefInfo; + AffineMapAttr dstMapAttr; + SmallVector<OpAsmParser::OperandType, 4> dstMapOperands; + OpAsmParser::OperandType tagMemRefInfo; + AffineMapAttr tagMapAttr; + SmallVector<OpAsmParser::OperandType, 4> tagMapOperands; + OpAsmParser::OperandType numElementsInfo; + SmallVector<OpAsmParser::OperandType, 2> strideInfo; + + SmallVector<Type, 3> types; + auto indexType = parser->getBuilder().getIndexType(); + + // Parse and resolve the following list of operands: + // *) dst memref followed by its affine maps operands (in square brackets). + // *) src memref followed by its affine map operands (in square brackets). + // *) tag memref followed by its affine map operands (in square brackets). + // *) number of elements transferred by DMA operation. + if (parser->parseOperand(srcMemRefInfo) || + parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, + getSrcMapAttrName(), result->attributes) || + parser->parseComma() || parser->parseOperand(dstMemRefInfo) || + parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, + getDstMapAttrName(), result->attributes) || + parser->parseComma() || parser->parseOperand(tagMemRefInfo) || + parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, + getTagMapAttrName(), result->attributes) || + parser->parseComma() || parser->parseOperand(numElementsInfo)) + return failure(); + + // Parse optional stride and elements per stride. 
+ if (parser->parseTrailingOperandList(strideInfo)) { + return failure(); + } + if (!strideInfo.empty() && strideInfo.size() != 2) { + return parser->emitError(parser->getNameLoc(), + "expected two stride related operands"); + } + bool isStrided = strideInfo.size() == 2; + + if (parser->parseColonTypeList(types)) + return failure(); + + if (types.size() != 3) + return parser->emitError(parser->getNameLoc(), "expected three types"); + + if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) || + parser->resolveOperands(srcMapOperands, indexType, result->operands) || + parser->resolveOperand(dstMemRefInfo, types[1], result->operands) || + parser->resolveOperands(dstMapOperands, indexType, result->operands) || + parser->resolveOperand(tagMemRefInfo, types[2], result->operands) || + parser->resolveOperands(tagMapOperands, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) + return failure(); + + if (isStrided) { + if (parser->resolveOperands(strideInfo, indexType, result->operands)) + return failure(); + } + + // Check that src/dst/tag operand counts match their map.numInputs. 
+ if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || + dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || + tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) + return parser->emitError(parser->getNameLoc(), + "memref operand count not equal to map.numInputs"); + return success(); +} + +LogicalResult AffineDmaStartOp::verify() { + if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>()) + return emitOpError("expected DMA source to be of memref type"); + if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>()) + return emitOpError("expected DMA destination to be of memref type"); + if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>()) + return emitOpError("expected DMA tag to be of memref type"); + + // DMAs from different memory spaces supported. + if (getSrcMemorySpace() == getDstMemorySpace()) { + return emitOpError("DMA should be between different memory spaces"); + } + unsigned numInputsAllMaps = getSrcMap().getNumInputs() + + getDstMap().getNumInputs() + + getTagMap().getNumInputs(); + if (getNumOperands() != numInputsAllMaps + 3 + 1 && + getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { + return emitOpError("incorrect number of operands"); + } + + for (auto *idx : getSrcIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("src index to dma_start must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("src index must be a dimension or symbol identifier"); + } + for (auto *idx : getDstIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("dst index to dma_start must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("dst index must be a dimension or symbol identifier"); + } + for (auto *idx : getTagIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("tag index to dma_start must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("tag 
index must be a dimension or symbol identifier"); + } + return success(); +} + +void AffineDmaStartOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// dma_start(memrefcast) -> dma_start + results.insert<MemRefCastFolder>(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// AffineDmaWaitOp +//===----------------------------------------------------------------------===// + +// TODO(b/133776335) Check that map operands are loop IVs or symbols. +void AffineDmaWaitOp::build(Builder *builder, OperationState *result, + Value *tagMemRef, AffineMap tagMap, + ArrayRef<Value *> tagIndices, Value *numElements) { + result->addOperands(tagMemRef); + result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap)); + result->addOperands(tagIndices); + result->addOperands(numElements); +} + +void AffineDmaWaitOp::print(OpAsmPrinter *p) { + *p << "affine.dma_wait " << *getTagMemRef() << '['; + SmallVector<Value *, 2> operands(getTagIndices()); + p->printAffineMapOfSSAIds(getTagMapAttr(), operands); + *p << "], "; + p->printOperand(getNumElements()); + *p << " : " << getTagMemRef()->getType(); +} + +// Parse AffineDmaWaitOp. +// Eg: +// affine.dma_wait %tag[%index], %num_elements +// : memref<1 x i32, (d0) -> (d0), 4> +// +ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType tagMemRefInfo; + AffineMapAttr tagMapAttr; + SmallVector<OpAsmParser::OperandType, 2> tagMapOperands; + Type type; + auto indexType = parser->getBuilder().getIndexType(); + OpAsmParser::OperandType numElementsInfo; + + // Parse tag memref, its map operands, and dma size. 
+ if (parser->parseOperand(tagMemRefInfo) || + parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, + getTagMapAttrName(), result->attributes) || + parser->parseComma() || parser->parseOperand(numElementsInfo) || + parser->parseColonType(type) || + parser->resolveOperand(tagMemRefInfo, type, result->operands) || + parser->resolveOperands(tagMapOperands, indexType, result->operands) || + parser->resolveOperand(numElementsInfo, indexType, result->operands)) + return failure(); + + if (!type.isa<MemRefType>()) + return parser->emitError(parser->getNameLoc(), + "expected tag to be of memref type"); + + if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) + return parser->emitError(parser->getNameLoc(), + "tag memref operand count != to map.numInputs"); + return success(); +} + +LogicalResult AffineDmaWaitOp::verify() { + if (!getOperand(0)->getType().isa<MemRefType>()) + return emitOpError("expected DMA tag to be of memref type"); + for (auto *idx : getTagIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("index to dma_wait must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("index must be a dimension or symbol identifier"); + } + return success(); +} + +void AffineDmaWaitOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// dma_wait(memrefcast) -> dma_wait + results.insert<MemRefCastFolder>(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// AffineForOp +//===----------------------------------------------------------------------===// + +void AffineForOp::build(Builder *builder, OperationState *result, + ArrayRef<Value *> lbOperands, AffineMap lbMap, + ArrayRef<Value *> ubOperands, AffineMap ubMap, + int64_t step) { + assert(((!lbMap && lbOperands.empty()) || + lbOperands.size() == lbMap.getNumInputs()) && + "lower bound operand count does not match the affine map"); + assert(((!ubMap && 
ubOperands.empty()) || + ubOperands.size() == ubMap.getNumInputs()) && + "upper bound operand count does not match the affine map"); + assert(step > 0 && "step has to be a positive integer constant"); + + // Add an attribute for the step. + result->addAttribute(getStepAttrName(), + builder->getIntegerAttr(builder->getIndexType(), step)); + + // Add the lower bound. + result->addAttribute(getLowerBoundAttrName(), + builder->getAffineMapAttr(lbMap)); + result->addOperands(lbOperands); + + // Add the upper bound. + result->addAttribute(getUpperBoundAttrName(), + builder->getAffineMapAttr(ubMap)); + result->addOperands(ubOperands); + + // Create a region and a block for the body. The argument of the region is + // the loop induction variable. + Region *bodyRegion = result->addRegion(); + Block *body = new Block(); + body->addArgument(IndexType::get(builder->getContext())); + bodyRegion->push_back(body); + ensureTerminator(*bodyRegion, *builder, result->location); + + // Set the operands list as resizable so that we can freely modify the bounds. + result->setOperandListToResizable(); +} + +void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb, + int64_t ub, int64_t step) { + auto lbMap = AffineMap::getConstantMap(lb, builder->getContext()); + auto ubMap = AffineMap::getConstantMap(ub, builder->getContext()); + return build(builder, result, {}, lbMap, {}, ubMap, step); +} + +static LogicalResult verify(AffineForOp op) { + // Check that the body defines as single block argument for the induction + // variable. + auto *body = op.getBody(); + if (body->getNumArguments() != 1 || + !body->getArgument(0)->getType().isIndex()) + return op.emitOpError( + "expected body to have a single index argument for the " + "induction variable"); + + // Verify that there are enough operands for the bounds. 
+ AffineMap lowerBoundMap = op.getLowerBoundMap(), + upperBoundMap = op.getUpperBoundMap(); + if (op.getNumOperands() != + (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs())) + return op.emitOpError( + "operand count must match with affine map dimension and symbol count"); + + // Verify that the bound operands are valid dimension/symbols. + /// Lower bound. + if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), + op.getLowerBoundMap().getNumDims()))) + return failure(); + /// Upper bound. + if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), + op.getUpperBoundMap().getNumDims()))) + return failure(); + return success(); +} + +/// Parse a for operation loop bounds. +static ParseResult parseBound(bool isLower, OperationState *result, + OpAsmParser *p) { + // 'min' / 'max' prefixes are generally syntactic sugar, but are required if + // the map has multiple results. + bool failedToParsedMinMax = + failed(p->parseOptionalKeyword(isLower ? "max" : "min")); + + auto &builder = p->getBuilder(); + auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() + : AffineForOp::getUpperBoundAttrName(); + + // Parse ssa-id as identity map. + SmallVector<OpAsmParser::OperandType, 1> boundOpInfos; + if (p->parseOperandList(boundOpInfos)) + return failure(); + + if (!boundOpInfos.empty()) { + // Check that only one operand was parsed. + if (boundOpInfos.size() > 1) + return p->emitError(p->getNameLoc(), + "expected only one loop bound operand"); + + // TODO: improve error message when SSA value is not an affine integer. + // Currently it is 'use of value ... expects different type than prior uses' + if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(), + result->operands)) + return failure(); + + // Create an identity map using symbol id. This representation is optimized + // for storage. Analysis passes may expand it into a multi-dimensional map + // if desired. 
+ AffineMap map = builder.getSymbolIdentityMap(); + result->addAttribute(boundAttrName, builder.getAffineMapAttr(map)); + return success(); + } + + // Get the attribute location. + llvm::SMLoc attrLoc = p->getCurrentLocation(); + + Attribute boundAttr; + if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, + result->attributes)) + return failure(); + + // Parse full form - affine map followed by dim and symbol list. + if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { + unsigned currentNumOperands = result->operands.size(); + unsigned numDims; + if (parseDimAndSymbolList(p, result->operands, numDims)) + return failure(); + + auto map = affineMapAttr.getValue(); + if (map.getNumDims() != numDims) + return p->emitError( + p->getNameLoc(), + "dim operand count and integer set dim count must match"); + + unsigned numDimAndSymbolOperands = + result->operands.size() - currentNumOperands; + if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) + return p->emitError( + p->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // If the map has multiple results, make sure that we parsed the min/max + // prefix. + if (map.getNumResults() > 1 && failedToParsedMinMax) { + if (isLower) { + return p->emitError(attrLoc, "lower loop bound affine map with " + "multiple results requires 'max' prefix"); + } + return p->emitError(attrLoc, "upper loop bound affine map with multiple " + "results requires 'min' prefix"); + } + return success(); + } + + // Parse custom assembly form. 
+ if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) { + result->attributes.pop_back(); + result->addAttribute( + boundAttrName, builder.getAffineMapAttr( + builder.getConstantAffineMap(integerAttr.getInt()))); + return success(); + } + + return p->emitError( + p->getNameLoc(), + "expected valid affine map representation for loop bounds"); +} + +ParseResult parseAffineForOp(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + OpAsmParser::OperandType inductionVariable; + // Parse the induction variable followed by '='. + if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual()) + return failure(); + + // Parse loop bounds. + if (parseBound(/*isLower=*/true, result, parser) || + parser->parseKeyword("to", " between bounds") || + parseBound(/*isLower=*/false, result, parser)) + return failure(); + + // Parse the optional loop step, we default to 1 if one is not present. + if (parser->parseOptionalKeyword("step")) { + result->addAttribute( + AffineForOp::getStepAttrName(), + builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); + } else { + llvm::SMLoc stepLoc = parser->getCurrentLocation(); + IntegerAttr stepAttr; + if (parser->parseAttribute(stepAttr, builder.getIndexType(), + AffineForOp::getStepAttrName().data(), + result->attributes)) + return failure(); + + if (stepAttr.getValue().getSExtValue() < 0) + return parser->emitError( + stepLoc, + "expected step to be representable as a positive signed integer"); + } + + // Parse the body region. + Region *body = result->addRegion(); + if (parser->parseRegion(*body, inductionVariable, builder.getIndexType())) + return failure(); + + AffineForOp::ensureTerminator(*body, builder, result->location); + + // Parse the optional attribute list. + if (parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + + // Set the operands list as resizable so that we can freely modify the bounds. 
+ result->setOperandListToResizable(); + return success(); +} + +static void printBound(AffineMapAttr boundMap, + Operation::operand_range boundOperands, + const char *prefix, OpAsmPrinter *p) { + AffineMap map = boundMap.getValue(); + + // Check if this bound should be printed using custom assembly form. + // The decision to restrict printing custom assembly form to trivial cases + // comes from the will to roundtrip MLIR binary -> text -> binary in a + // lossless way. + // Therefore, custom assembly form parsing and printing is only supported for + // zero-operand constant maps and single symbol operand identity maps. + if (map.getNumResults() == 1) { + AffineExpr expr = map.getResult(0); + + // Print constant bound. + if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { + if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { + *p << constExpr.getValue(); + return; + } + } + + // Print bound that consists of a single SSA symbol if the map is over a + // single symbol. + if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { + if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { + p->printOperand(*boundOperands.begin()); + return; + } + } + } else { + // Map has multiple results. Print 'min' or 'max' prefix. + *p << prefix << ' '; + } + + // Print the map and its operands. 
+ *p << boundMap; + printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), + map.getNumDims(), p); +} + +void print(OpAsmPrinter *p, AffineForOp op) { + *p << "affine.for "; + p->printOperand(op.getBody()->getArgument(0)); + *p << " = "; + printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); + *p << " to "; + printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); + + if (op.getStep() != 1) + *p << " step " << op.getStep(); + p->printRegion(op.region(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + p->printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{op.getLowerBoundAttrName(), + op.getUpperBoundAttrName(), + op.getStepAttrName()}); +} + +namespace { +/// This is a pattern to fold constant loop bounds. +struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> { + using OpRewritePattern<AffineForOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineForOp forOp, + PatternRewriter &rewriter) const override { + auto foldLowerOrUpperBound = [&forOp](bool lower) { + // Check to see if each of the operands is the result of a constant. If + // so, get the value. If not, ignore it. + SmallVector<Attribute, 8> operandConstants; + auto boundOperands = + lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); + for (auto *operand : boundOperands) { + Attribute operandCst; + matchPattern(operand, m_Constant(&operandCst)); + operandConstants.push_back(operandCst); + } + + AffineMap boundMap = + lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); + assert(boundMap.getNumResults() >= 1 && + "bound maps should have at least one result"); + SmallVector<Attribute, 4> foldedResults; + if (failed(boundMap.constantFold(operandConstants, foldedResults))) + return failure(); + + // Compute the max or min as applicable over the results. 
      assert(!foldedResults.empty() &&
             "bounds should have at least one result");
      // Reduce over all folded map results: a lower bound is the max of its
      // results, an upper bound the min. Signed comparisons, since loop
      // bounds are signed index values.
      auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
      for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
        auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
        maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                         : llvm::APIntOps::smin(maxOrMin, foldedResult);
      }
      lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
            : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
      return success();
    };

    // Try to fold the lower bound.
    bool folded = false;
    if (!forOp.hasConstantLowerBound())
      folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

    // Try to fold the upper bound.
    if (!forOp.hasConstantUpperBound())
      folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));

    // If any of the bounds were folded we return success.
    if (!folded)
      return matchFailure();
    rewriter.updatedRootInPlace(forOp);
    return matchSuccess();
  }
};
} // end anonymous namespace

void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<AffineForLoopBoundFolder>(context);
}

/// Returns the lower bound; its operands are the leading operands of this op.
AffineBound AffineForOp::getLowerBound() {
  auto lbMap = getLowerBoundMap();
  return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
}

/// Returns the upper bound. The lower bound map is queried here only to
/// compute the offset at which the upper bound operands start in the
/// operand list (lower bound operands come first).
AffineBound AffineForOp::getUpperBound() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(),
                     ubMap);
}

/// Replaces the lower bound (map + operands). The existing upper bound
/// operands are preserved after the new lower bound operands.
void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());

  auto ubOperands = getUpperBoundOperands();
  newOperands.append(ubOperands.begin(), ubOperands.end());
getOperation()->setOperands(newOperands);

  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}

/// Replaces the upper bound (map + operands). The existing lower bound
/// operands are preserved before the new upper bound operands.
void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
  newOperands.append(ubOperands.begin(), ubOperands.end());
  getOperation()->setOperands(newOperands);

  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}

/// Replaces only the lower bound map; the new map must have the same
/// dim/symbol counts so the existing operands remain valid.
void AffineForOp::setLowerBoundMap(AffineMap map) {
  auto lbMap = getLowerBoundMap();
  assert(lbMap.getNumDims() == map.getNumDims() &&
         lbMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)lbMap; // Only used in the assertion above.
  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}

/// Replaces only the upper bound map; the new map must have the same
/// dim/symbol counts so the existing operands remain valid.
void AffineForOp::setUpperBoundMap(AffineMap map) {
  auto ubMap = getUpperBoundMap();
  assert(ubMap.getNumDims() == map.getNumDims() &&
         ubMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)ubMap; // Only used in the assertion above.
  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}

/// Returns true if the lower bound map has exactly one constant result.
bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}

/// Returns true if the upper bound map has exactly one constant result.
bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}

int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}

int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}

void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}

void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}

// Operand layout: the lower bound operands occupy the leading positions,
// followed by the upper bound operands.
AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}

AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
  return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}

/// Returns true if both bound maps have identical dim/symbol counts and the
/// lower/upper bound operand lists are pairwise-identical SSA values.
bool AffineForOp::matchingBoundOperandList() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  if (lbMap.getNumDims() != ubMap.getNumDims() ||
      lbMap.getNumSymbols() != ubMap.getNumSymbols())
    return false;

  unsigned numOperands = lbMap.getNumInputs();
  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
    // Compare Value *'s.
    if (getOperand(i) != getOperand(numOperands + i))
      return false;
  }
  return true;
}

/// Returns if the provided value is the induction variable of a AffineForOp.
bool mlir::isForInductionVar(Value *val) {
  return getForInductionVarOwner(val) != AffineForOp();
}

/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
AffineForOp mlir::getForInductionVarOwner(Value *val) {
  // The IV is the entry block argument of the loop body region.
  auto *ivArg = dyn_cast<BlockArgument>(val);
  if (!ivArg || !ivArg->getOwner())
    return AffineForOp();
  auto *containingInst = ivArg->getOwner()->getParent()->getParentOp();
  return dyn_cast<AffineForOp>(containingInst);
}

/// Extracts the induction variables from a list of AffineForOps and returns
/// them.
void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
                                   SmallVectorImpl<Value *> *ivs) {
  ivs->reserve(forInsts.size());
  for (auto forInst : forInsts)
    ivs->push_back(forInst.getInductionVar());
}

//===----------------------------------------------------------------------===//
// AffineIfOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AffineIfOp op) {
  // Verify that we have a condition attribute.
+ auto conditionAttr = + op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName()); + if (!conditionAttr) + return op.emitOpError( + "requires an integer set attribute named 'condition'"); + + // Verify that there are enough operands for the condition. + IntegerSet condition = conditionAttr.getValue(); + if (op.getNumOperands() != condition.getNumOperands()) + return op.emitOpError( + "operand count and condition integer set dimension and " + "symbol count must match"); + + // Verify that the operands are valid dimension/symbols. + if (failed(verifyDimAndSymbolIdentifiers( + op, op.getOperation()->getNonSuccessorOperands(), + condition.getNumDims()))) + return failure(); + + // Verify that the entry of each child region does not have arguments. + for (auto ®ion : op.getOperation()->getRegions()) { + for (auto &b : region) + if (b.getNumArguments() != 0) + return op.emitOpError( + "requires that child entry blocks have no arguments"); + } + return success(); +} + +ParseResult parseAffineIfOp(OpAsmParser *parser, OperationState *result) { + // Parse the condition attribute set. + IntegerSetAttr conditionAttr; + unsigned numDims; + if (parser->parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(), + result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims)) + return failure(); + + // Verify the condition operands. + auto set = conditionAttr.getValue(); + if (set.getNumDims() != numDims) + return parser->emitError( + parser->getNameLoc(), + "dim operand count and integer set dim count must match"); + if (numDims + set.getNumSymbols() != result->operands.size()) + return parser->emitError( + parser->getNameLoc(), + "symbol operand count and integer set symbol count must match"); + + // Create the regions for 'then' and 'else'. The latter must be created even + // if it remains empty for the validity of the operation. 
+ result->regions.reserve(2); + Region *thenRegion = result->addRegion(); + Region *elseRegion = result->addRegion(); + + // Parse the 'then' region. + if (parser->parseRegion(*thenRegion, {}, {})) + return failure(); + AffineIfOp::ensureTerminator(*thenRegion, parser->getBuilder(), + result->location); + + // If we find an 'else' keyword then parse the 'else' region. + if (!parser->parseOptionalKeyword("else")) { + if (parser->parseRegion(*elseRegion, {}, {})) + return failure(); + AffineIfOp::ensureTerminator(*elseRegion, parser->getBuilder(), + result->location); + } + + // Parse the optional attribute list. + if (parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + + return success(); +} + +void print(OpAsmPrinter *p, AffineIfOp op) { + auto conditionAttr = + op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName()); + *p << "affine.if " << conditionAttr; + printDimAndSymbolList(op.operand_begin(), op.operand_end(), + conditionAttr.getValue().getNumDims(), p); + p->printRegion(op.thenRegion(), + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + + // Print the 'else' regions if it has any blocks. + auto &elseRegion = op.elseRegion(); + if (!elseRegion.empty()) { + *p << " else"; + p->printRegion(elseRegion, + /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + } + + // Print the attribute list. 
+ p->printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/op.getConditionAttrName()); +} + +IntegerSet AffineIfOp::getIntegerSet() { + return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue(); +} +void AffineIfOp::setIntegerSet(IntegerSet newSet) { + setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); +} + +//===----------------------------------------------------------------------===// +// AffineLoadOp +//===----------------------------------------------------------------------===// + +void AffineLoadOp::build(Builder *builder, OperationState *result, + AffineMap map, ArrayRef<Value *> operands) { + result->addOperands(operands); + if (map) + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + auto memrefType = operands[0]->getType().cast<MemRefType>(); + result->types.push_back(memrefType.getElementType()); +} + +void AffineLoadOp::build(Builder *builder, OperationState *result, + Value *memref, ArrayRef<Value *> indices) { + result->addOperands(memref); + result->addOperands(indices); + auto memrefType = memref->getType().cast<MemRefType>(); + auto rank = memrefType.getRank(); + // Create identity map for memrefs with at least one dimension or () -> () + // for zero-dimensional memrefs. + auto map = rank ? 
builder->getMultiDimIdentityMap(rank) + : builder->getEmptyAffineMap(); + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); + result->types.push_back(memrefType.getElementType()); +} + +ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + auto affineIntTy = builder.getIndexType(); + + MemRefType type; + OpAsmParser::OperandType memrefInfo; + AffineMapAttr mapAttr; + SmallVector<OpAsmParser::OperandType, 1> mapOperands; + return failure( + parser->parseOperand(memrefInfo) || + parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), + result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(mapOperands, affineIntTy, result->operands) || + parser->addTypeToList(type.getElementType(), result->types)); +} + +void AffineLoadOp::print(OpAsmPrinter *p) { + *p << "affine.load " << *getMemRef() << '['; + AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()); + if (mapAttr) { + SmallVector<Value *, 2> operands(getIndices()); + p->printAffineMapOfSSAIds(mapAttr, operands); + } + *p << ']'; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); + *p << " : " << getMemRefType(); +} + +LogicalResult AffineLoadOp::verify() { + if (getType() != getMemRefType().getElementType()) + return emitOpError("result type must match element type of memref"); + + auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()); + if (mapAttr) { + AffineMap map = getAttrOfType<AffineMapAttr>(getMapAttrName()).getValue(); + if (map.getNumResults() != getMemRefType().getRank()) + return emitOpError("affine.load affine map num results must equal" + " memref rank"); + if (map.getNumInputs() != getNumOperands() - 1) + return emitOpError("expects as many subscripts as affine map inputs"); + } else { + if 
(getMemRefType().getRank() != getNumOperands() - 1) + return emitOpError( + "expects the number of subscripts to be equal to memref rank"); + } + + for (auto *idx : getIndices()) { + if (!idx->getType().isIndex()) + return emitOpError("index to load must have 'index' type"); + if (!isValidAffineIndexOperand(idx)) + return emitOpError("index must be a dimension or symbol identifier"); + } + return success(); +} + +void AffineLoadOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + /// load(memrefcast) -> load + results.insert<MemRefCastFolder>(getOperationName(), context); +} + +//===----------------------------------------------------------------------===// +// AffineStoreOp +//===----------------------------------------------------------------------===// + +void AffineStoreOp::build(Builder *builder, OperationState *result, + Value *valueToStore, AffineMap map, + ArrayRef<Value *> operands) { + result->addOperands(valueToStore); + result->addOperands(operands); + if (map) + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); +} + +void AffineStoreOp::build(Builder *builder, OperationState *result, + Value *valueToStore, Value *memref, + ArrayRef<Value *> operands) { + result->addOperands(valueToStore); + result->addOperands(memref); + result->addOperands(operands); + auto memrefType = memref->getType().cast<MemRefType>(); + auto rank = memrefType.getRank(); + // Create identity map for memrefs with at least one dimension or () -> () + // for zero-dimensional memrefs. + auto map = rank ? 
builder->getMultiDimIdentityMap(rank) + : builder->getEmptyAffineMap(); + result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map)); +} + +ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) { + auto affineIntTy = parser->getBuilder().getIndexType(); + + MemRefType type; + OpAsmParser::OperandType storeValueInfo; + OpAsmParser::OperandType memrefInfo; + AffineMapAttr mapAttr; + SmallVector<OpAsmParser::OperandType, 1> mapOperands; + return failure( + parser->parseOperand(storeValueInfo) || parser->parseComma() || + parser->parseOperand(memrefInfo) || + parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(), + result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(storeValueInfo, type.getElementType(), + result->operands) || + parser->resolveOperand(memrefInfo, type, result->operands) || + parser->resolveOperands(mapOperands, affineIntTy, result->operands)); +} + +void AffineStoreOp::print(OpAsmPrinter *p) { + *p << "affine.store " << *getValueToStore(); + *p << ", " << *getMemRef() << '['; + AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()); + if (mapAttr) { + SmallVector<Value *, 2> operands(getIndices()); + p->printAffineMapOfSSAIds(mapAttr, operands); + } + *p << ']'; + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); + *p << " : " << getMemRefType(); +} + +LogicalResult AffineStoreOp::verify() { + // First operand must have same type as memref element type. 
+  if (getValueToStore()->getType() != getMemRefType().getElementType())
+    return emitOpError("first operand must have same type memref element type");
+
+  auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+  if (mapAttr) {
+    AffineMap map = mapAttr.getValue();
+    if (map.getNumResults() != getMemRefType().getRank())
+      return emitOpError("affine.store affine map num results must equal"
+                         " memref rank");
+    if (map.getNumInputs() != getNumOperands() - 2)
+      return emitOpError("expects as many subscripts as affine map inputs");
+  } else {
+    if (getMemRefType().getRank() != getNumOperands() - 2)
+      return emitOpError(
+          "expects the number of subscripts to be equal to memref rank");
+  }
+
+  for (auto *idx : getIndices()) {
+    if (!idx->getType().isIndex())
+      return emitOpError("index to store must have 'index' type");
+    if (!isValidAffineIndexOperand(idx))
+      return emitOpError("index must be a dimension or symbol identifier");
+  }
+  return success();
+}
+
+void AffineStoreOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  /// store(memrefcast) -> store
+  results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
diff --git a/mlir/lib/Dialect/AffineOps/CMakeLists.txt b/mlir/lib/Dialect/AffineOps/CMakeLists.txt
new file mode 100644
index 00000000000..dbe469369a3
--- /dev/null
+++ b/mlir/lib/Dialect/AffineOps/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRAffineOps
+  AffineOps.cpp
+  DialectRegistration.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AffineOps
+  )
+add_dependencies(MLIRAffineOps MLIRAffineOpsIncGen MLIRIR MLIRStandardOps)
+target_link_libraries(MLIRAffineOps MLIRIR MLIRStandardOps)
+
diff --git a/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp b/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp
new file mode 100644
index 00000000000..9197e3c619f
--- /dev/null
+++ 
b/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp @@ -0,0 +1,22 @@ +//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Dialect/AffineOps/AffineOps.h" +using namespace mlir; + +// Static initialization for Affine op dialect registration. +static DialectRegistration<AffineOpsDialect> StandardOps; diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 294041df4a5..b0641a9611f 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(AffineOps) add_subdirectory(FxpMathOps) add_subdirectory(GPU) add_subdirectory(Linalg) diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp index 1c5bb6e70c8..c48437f60db 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp @@ -15,7 +15,12 @@ // limitations under the License. 
// ============================================================================= -#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Helpers.h" @@ -23,11 +28,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Utils/Intrinsics.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/STLExtras.h" |

