summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/AffineOps/AffineOps.cpp1764
-rw-r--r--mlir/lib/Dialect/AffineOps/CMakeLists.txt10
-rw-r--r--mlir/lib/Dialect/AffineOps/DialectRegistration.cpp22
-rw-r--r--mlir/lib/Dialect/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp12
5 files changed, 1803 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
new file mode 100644
index 00000000000..7db3fa07c52
--- /dev/null
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -0,0 +1,1764 @@
+//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Support/Debug.h"
+using namespace mlir;
+using llvm::dbgs;
+
+#define DEBUG_TYPE "affine-analysis"
+
+//===----------------------------------------------------------------------===//
+// AffineOpsDialect
+//===----------------------------------------------------------------------===//
+
+AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context) {
+ addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineLoadOp,
+ AffineStoreOp,
+#define GET_OP_LIST
+#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
+ >();
+}
+
+/// A utility function to check if a given region is attached to a function.
+static bool isFunctionRegion(Region *region) {
+ return llvm::isa<FuncOp>(region->getParentOp());
+}
+
+/// A utility function to check if a value is defined at the top level of a
+/// function. A value defined at the top level is always a valid symbol.
+bool mlir::isTopLevelSymbol(Value *value) {
+ if (auto *arg = dyn_cast<BlockArgument>(value))
+ return isFunctionRegion(arg->getOwner()->getParent());
+ return isFunctionRegion(value->getDefiningOp()->getParentRegion());
+}
+
+// Value can be used as a dimension id if it is valid as a symbol, or
+// it is an induction variable, or it is a result of affine apply operation
+// with dimension id arguments.
+bool mlir::isValidDim(Value *value) {
+ // The value must be an index type.
+ if (!value->getType().isIndex())
+ return false;
+
+ if (auto *op = value->getDefiningOp()) {
+ // Top level operation or constant operation is ok.
+ if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
+ return true;
+ // Affine apply operation is ok if all of its operands are ok.
+ if (auto applyOp = dyn_cast<AffineApplyOp>(op))
+ return applyOp.isValidDim();
+ // The dim op is okay if its operand memref/tensor is defined at the top
+ // level.
+ if (auto dimOp = dyn_cast<DimOp>(op))
+ return isTopLevelSymbol(dimOp.getOperand());
+ return false;
+ }
+ // This value is a block argument (which also includes 'affine.for' loop IVs).
+ return true;
+}
+
+// Value can be used as a symbol if it is a constant, or it is defined at
+// the top level, or it is a result of affine apply operation with symbol
+// arguments.
+bool mlir::isValidSymbol(Value *value) {
+ // The value must be an index type.
+ if (!value->getType().isIndex())
+ return false;
+
+ if (auto *op = value->getDefiningOp()) {
+ // Top level operation or constant operation is ok.
+ if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
+ return true;
+ // Affine apply operation is ok if all of its operands are ok.
+ if (auto applyOp = dyn_cast<AffineApplyOp>(op))
+ return applyOp.isValidSymbol();
+ // The dim op is okay if its operand memref/tensor is defined at the top
+ // level.
+ if (auto dimOp = dyn_cast<DimOp>(op))
+ return isTopLevelSymbol(dimOp.getOperand());
+ return false;
+ }
+ // Otherwise, check that the value is a top level symbol.
+ return isTopLevelSymbol(value);
+}
+
+// Returns true if 'value' is a valid index to an affine operation (e.g.
+// affine.load, affine.store, affine.dma_start, affine.dma_wait).
+// Returns false otherwise.
+static bool isValidAffineIndexOperand(Value *value) {
+ return isValidDim(value) || isValidSymbol(value);
+}
+
+/// Utility function to verify that a set of operands are valid dimension and
+/// symbol identifiers. The operands should be layed out such that the dimension
+/// operands are before the symbol operands. This function returns failure if
+/// there was an invalid operand. An operation is provided to emit any necessary
+/// errors.
+template <typename OpTy>
+static LogicalResult
+verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
+ unsigned numDims) {
+ unsigned opIt = 0;
+ for (auto *operand : operands) {
+ if (opIt++ < numDims) {
+ if (!isValidDim(operand))
+ return op.emitOpError("operand cannot be used as a dimension id");
+ } else if (!isValidSymbol(operand)) {
+ return op.emitOpError("operand cannot be used as a symbol");
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
+void AffineApplyOp::build(Builder *builder, OperationState *result,
+ AffineMap map, ArrayRef<Value *> operands) {
+ result->addOperands(operands);
+ result->types.append(map.getNumResults(), builder->getIndexType());
+ result->addAttribute("map", builder->getAffineMapAttr(map));
+}
+
+ParseResult AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto &builder = parser->getBuilder();
+ auto affineIntTy = builder.getIndexType();
+
+ AffineMapAttr mapAttr;
+ unsigned numDims;
+ if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
+ parseDimAndSymbolList(parser, result->operands, numDims) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return failure();
+ auto map = mapAttr.getValue();
+
+ if (map.getNumDims() != numDims ||
+ numDims + map.getNumSymbols() != result->operands.size()) {
+ return parser->emitError(parser->getNameLoc(),
+ "dimension or symbol index mismatch");
+ }
+
+ result->types.append(map.getNumResults(), affineIntTy);
+ return success();
+}
+
+void AffineApplyOp::print(OpAsmPrinter *p) {
+ *p << "affine.apply " << getAttr("map");
+ printDimAndSymbolList(operand_begin(), operand_end(),
+ getAffineMap().getNumDims(), p);
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"});
+}
+
+LogicalResult AffineApplyOp::verify() {
+ // Check that affine map attribute was specified.
+ auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
+ if (!affineMapAttr)
+ return emitOpError("requires an affine map");
+
+ // Check input and output dimensions match.
+ auto map = affineMapAttr.getValue();
+
+ // Verify that operand count matches affine map dimension and symbol count.
+ if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
+ return emitOpError(
+ "operand count and affine map dimension and symbol count must match");
+
+ // Verify that all operands are of `index` type.
+ for (Type t : getOperandTypes()) {
+ if (!t.isIndex())
+ return emitOpError("operands must be of type 'index'");
+ }
+
+ if (!getResult()->getType().isIndex())
+ return emitOpError("result must be of type 'index'");
+
+ // Verify that the operands are valid dimension and symbol identifiers.
+ if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
+ map.getNumDims())))
+ return failure();
+
+ // Verify that the map only produces one result.
+ if (map.getNumResults() != 1)
+ return emitOpError("mapping must produce one value");
+
+ return success();
+}
+
+// The result of the affine apply operation can be used as a dimension id if it
+// is a CFG value or if it is an Value, and all the operands are valid
+// dimension ids.
+bool AffineApplyOp::isValidDim() {
+ return llvm::all_of(getOperands(),
+ [](Value *op) { return mlir::isValidDim(op); });
+}
+
+// The result of the affine apply operation can be used as a symbol if it is
+// a CFG value or if it is an Value, and all the operands are symbols.
+bool AffineApplyOp::isValidSymbol() {
+ return llvm::all_of(getOperands(),
+ [](Value *op) { return mlir::isValidSymbol(op); });
+}
+
+OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
+ auto map = getAffineMap();
+
+ // Fold dims and symbols to existing values.
+ auto expr = map.getResult(0);
+ if (auto dim = expr.dyn_cast<AffineDimExpr>())
+ return getOperand(dim.getPosition());
+ if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
+ return getOperand(map.getNumDims() + sym.getPosition());
+
+ // Otherwise, default to folding the map.
+ SmallVector<Attribute, 1> result;
+ if (failed(map.constantFold(operands, result)))
+ return {};
+ return result[0];
+}
+
+namespace {
+/// An `AffineApplyNormalizer` is a helper class that is not visible to the user
+/// and supports renumbering operands of AffineApplyOp. This acts as a
+/// reindexing map of Value* to positional dims or symbols and allows
+/// simplifications such as:
+///
+/// ```mlir
+/// %1 = affine.apply (d0, d1) -> (d0 - d1) (%0, %0)
+/// ```
+///
+/// into:
+///
+/// ```mlir
+/// %1 = affine.apply () -> (0)
+/// ```
+struct AffineApplyNormalizer {
+ AffineApplyNormalizer(AffineMap map, ArrayRef<Value *> operands);
+
+ /// Returns the AffineMap resulting from normalization.
+ AffineMap getAffineMap() { return affineMap; }
+
+ SmallVector<Value *, 8> getOperands() {
+ SmallVector<Value *, 8> res(reorderedDims);
+ res.append(concatenatedSymbols.begin(), concatenatedSymbols.end());
+ return res;
+ }
+
+private:
+ /// Helper function to insert `v` into the coordinate system of the current
+ /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding
+ /// renumbered position.
+ AffineDimExpr renumberOneDim(Value *v);
+
+ /// Given an `other` normalizer, this rewrites `other.affineMap` in the
+ /// coordinate system of the current AffineApplyNormalizer.
+ /// Returns the rewritten AffineMap and updates the dims and symbols of
+ /// `this`.
+ AffineMap renumber(const AffineApplyNormalizer &other);
+
+ /// Maps of Value* to position in `affineMap`.
+ DenseMap<Value *, unsigned> dimValueToPosition;
+
+ /// Ordered dims and symbols matching positional dims and symbols in
+ /// `affineMap`.
+ SmallVector<Value *, 8> reorderedDims;
+ SmallVector<Value *, 8> concatenatedSymbols;
+
+ AffineMap affineMap;
+
+ /// Used with RAII to control the depth at which AffineApply are composed
+ /// recursively. Only accepts depth 1 for now to allow a behavior where a
+ /// newly composed AffineApplyOp does not increase the length of the chain of
+ /// AffineApplyOps. Full composition is implemented iteratively on top of
+ /// this behavior.
+ static unsigned &affineApplyDepth() {
+ static thread_local unsigned depth = 0;
+ return depth;
+ }
+ static constexpr unsigned kMaxAffineApplyDepth = 1;
+
+ AffineApplyNormalizer() { affineApplyDepth()++; }
+
+public:
+ ~AffineApplyNormalizer() { affineApplyDepth()--; }
+};
+} // end anonymous namespace.
+
+AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value *v) {
+ DenseMap<Value *, unsigned>::iterator iterPos;
+ bool inserted = false;
+ std::tie(iterPos, inserted) =
+ dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
+ if (inserted) {
+ reorderedDims.push_back(v);
+ }
+ return getAffineDimExpr(iterPos->second, v->getContext())
+ .cast<AffineDimExpr>();
+}
+
+AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
+ SmallVector<AffineExpr, 8> dimRemapping;
+ for (auto *v : other.reorderedDims) {
+ auto kvp = other.dimValueToPosition.find(v);
+ if (dimRemapping.size() <= kvp->second)
+ dimRemapping.resize(kvp->second + 1);
+ dimRemapping[kvp->second] = renumberOneDim(kvp->first);
+ }
+ unsigned numSymbols = concatenatedSymbols.size();
+ unsigned numOtherSymbols = other.concatenatedSymbols.size();
+ SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
+ for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
+ symRemapping[idx] =
+ getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
+ }
+ concatenatedSymbols.insert(concatenatedSymbols.end(),
+ other.concatenatedSymbols.begin(),
+ other.concatenatedSymbols.end());
+ auto map = other.affineMap;
+ return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
+ dimRemapping.size(), symRemapping.size());
+}
+
+// Gather the positions of the operands that are produced by an AffineApplyOp.
+static llvm::SetVector<unsigned>
+indicesFromAffineApplyOp(ArrayRef<Value *> operands) {
+ llvm::SetVector<unsigned> res;
+ for (auto en : llvm::enumerate(operands))
+ if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp()))
+ res.insert(en.index());
+ return res;
+}
+
+// Support the special case of a symbol coming from an AffineApplyOp that needs
+// to be composed into the current AffineApplyOp.
+// This case is handled by rewriting all such symbols into dims for the purpose
+// of allowing mathematical AffineMap composition.
+// Returns an AffineMap where symbols that come from an AffineApplyOp have been
+// rewritten as dims and are ordered after the original dims.
+// TODO(andydavis,ntv): This promotion makes AffineMap lose track of which
+// symbols are represented as dims. This loss is static but can still be
+// recovered dynamically (with `isValidSymbol`). Still this is annoying for the
+// semi-affine map case. A dynamic canonicalization of all dims that are valid
+// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
+// results in better simplifications and foldings. But we should evaluate
+// whether this behavior is what we really want after using more.
+static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
+ ArrayRef<Value *> symbols) {
+ if (symbols.empty()) {
+ return map;
+ }
+
+ // Sanity check on symbols.
+ for (auto *sym : symbols) {
+ assert(isValidSymbol(sym) && "Expected only valid symbols");
+ (void)sym;
+ }
+
+ // Extract the symbol positions that come from an AffineApplyOp and
+ // needs to be rewritten as dims.
+ auto symPositions = indicesFromAffineApplyOp(symbols);
+ if (symPositions.empty()) {
+ return map;
+ }
+
+ // Create the new map by replacing each symbol at pos by the next new dim.
+ unsigned numDims = map.getNumDims();
+ unsigned numSymbols = map.getNumSymbols();
+ unsigned numNewDims = 0;
+ unsigned numNewSymbols = 0;
+ SmallVector<AffineExpr, 8> symReplacements(numSymbols);
+ for (unsigned i = 0; i < numSymbols; ++i) {
+ symReplacements[i] =
+ symPositions.count(i) > 0
+ ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
+ : getAffineSymbolExpr(numNewSymbols++, map.getContext());
+ }
+ assert(numSymbols >= numNewDims);
+ AffineMap newMap = map.replaceDimsAndSymbols(
+ {}, symReplacements, numDims + numNewDims, numNewSymbols);
+
+ return newMap;
+}
+
+/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
+/// keep a correspondence between the mathematical `map` and the `operands` of
+/// a given AffineApplyOp. This correspondence is maintained by iterating over
+/// the operands and forming an `auxiliaryMap` that can be composed
+/// mathematically with `map`. To keep this correspondence in cases where
+/// symbols are produced by affine.apply operations, we perform a local rewrite
+/// of symbols as dims.
+///
+/// Rationale for locally rewriting symbols as dims:
+/// ================================================
+/// The mathematical composition of AffineMap must always concatenate symbols
+/// because it does not have enough information to do otherwise. For example,
+/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
+/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
+///
+/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
+/// applied to the same mlir::Value* for both s0 and s1.
+/// As a consequence mathematical composition of AffineMap always concatenates
+/// symbols.
+///
+/// When AffineMaps are used in AffineApplyOp however, they may specify
+/// composition via symbols, which is ambiguous mathematically. This corner case
+/// is handled by locally rewriting such symbols that come from AffineApplyOp
+/// into dims and composing through dims.
+/// TODO(andydavis, ntv): Composition via symbols comes at a significant code
+/// complexity. Alternatively we should investigate whether we want to
+/// explicitly disallow symbols coming from affine.apply and instead force the
+/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
+/// extra API calls for such uses, which haven't popped up until now) and the
+/// benefit potentially big: simpler and more maintainable code for a
+/// non-trivial, recursive, procedure.
+AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
+ ArrayRef<Value *> operands)
+ : AffineApplyNormalizer() {
+ static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
+ assert(map.getNumInputs() == operands.size() &&
+ "number of operands does not match the number of map inputs");
+
+ LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));
+
+ // Promote symbols that come from an AffineApplyOp to dims by rewriting the
+ // map to always refer to:
+ // (dims, symbols coming from AffineApplyOp, other symbols).
+ // The order of operands can remain unchanged.
+ // This is a simplification that relies on 2 ordering properties:
+ // 1. rewritten symbols always appear after the original dims in the map;
+ // 2. operands are traversed in order and either dispatched to:
+ // a. auxiliaryExprs (dims and symbols rewritten as dims);
+ // b. concatenatedSymbols (all other symbols)
+ // This allows operand order to remain unchanged.
+ unsigned numDimsBeforeRewrite = map.getNumDims();
+ map = promoteComposedSymbolsAsDims(map,
+ operands.take_back(map.getNumSymbols()));
+
+ LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));
+
+ SmallVector<AffineExpr, 8> auxiliaryExprs;
+ bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
+ // We fully spell out the 2 cases below. In this particular instance a little
+ // code duplication greatly improves readability.
+ // Note that the first branch would disappear if we only supported full
+ // composition (i.e. infinite kMaxAffineApplyDepth).
+ if (!furtherCompose) {
+ // 1. Only dispatch dims or symbols.
+ for (auto en : llvm::enumerate(operands)) {
+ auto *t = en.value();
+ assert(t->getType().isIndex());
+ bool isDim = (en.index() < map.getNumDims());
+ if (isDim) {
+ // a. The mathematical composition of AffineMap composes dims.
+ auxiliaryExprs.push_back(renumberOneDim(t));
+ } else {
+ // b. The mathematical composition of AffineMap concatenates symbols.
+ // We do the same for symbol operands.
+ concatenatedSymbols.push_back(t);
+ }
+ }
+ } else {
+ assert(numDimsBeforeRewrite <= operands.size());
+ // 2. Compose AffineApplyOps and dispatch dims or symbols.
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ auto *t = operands[i];
+ auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp());
+ if (affineApply) {
+ // a. Compose affine.apply operations.
+ LLVM_DEBUG(affineApply.getOperation()->print(
+ dbgs() << "\nCompose AffineApplyOp recursively: "));
+ AffineMap affineApplyMap = affineApply.getAffineMap();
+ SmallVector<Value *, 8> affineApplyOperands(
+ affineApply.getOperands().begin(), affineApply.getOperands().end());
+ AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);
+
+ LLVM_DEBUG(normalizer.affineMap.print(
+ dbgs() << "\nRenumber into current normalizer: "));
+
+ auto renumberedMap = renumber(normalizer);
+
+ LLVM_DEBUG(
+ renumberedMap.print(dbgs() << "\nRecursive composition yields: "));
+
+ auxiliaryExprs.push_back(renumberedMap.getResult(0));
+ } else {
+ if (i < numDimsBeforeRewrite) {
+ // b. The mathematical composition of AffineMap composes dims.
+ auxiliaryExprs.push_back(renumberOneDim(t));
+ } else {
+ // c. The mathematical composition of AffineMap concatenates symbols.
+ // We do the same for symbol operands.
+ concatenatedSymbols.push_back(t);
+ }
+ }
+ }
+ }
+
+ // Early exit if `map` is already composed.
+ if (auxiliaryExprs.empty()) {
+ affineMap = map;
+ return;
+ }
+
+ assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
+ "Unexpected number of concatenated symbols");
+ auto numDims = dimValueToPosition.size();
+ auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
+ auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs);
+
+ LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
+ LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
+ LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));
+
+ // TODO(andydavis,ntv): Disabling simplification results in major speed gains.
+ // Another option is to cache the results as it is expected a lot of redundant
+ // work is performed in practice.
+ affineMap = simplifyAffineMap(map.compose(auxiliaryMap));
+
+ LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
+ LLVM_DEBUG(dbgs() << "\n");
+}
+
+/// Implements `map` and `operands` composition and simplification to support
+/// `makeComposedAffineApply`. This can be called to achieve the same effects
+/// on `map` and `operands` without creating an AffineApplyOp that needs to be
+/// immediately deleted.
+static void composeAffineMapAndOperands(AffineMap *map,
+ SmallVectorImpl<Value *> *operands) {
+ AffineApplyNormalizer normalizer(*map, *operands);
+ auto normalizedMap = normalizer.getAffineMap();
+ auto normalizedOperands = normalizer.getOperands();
+ canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
+ *map = normalizedMap;
+ *operands = normalizedOperands;
+ assert(*map);
+}
+
+void mlir::fullyComposeAffineMapAndOperands(
+ AffineMap *map, SmallVectorImpl<Value *> *operands) {
+ while (llvm::any_of(*operands, [](Value *v) {
+ return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp());
+ })) {
+ composeAffineMapAndOperands(map, operands);
+ }
+}
+
+AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
+ AffineMap map,
+ ArrayRef<Value *> operands) {
+ AffineMap normalizedMap = map;
+ SmallVector<Value *, 8> normalizedOperands(operands.begin(), operands.end());
+ composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
+ assert(normalizedMap);
+ return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
+}
+
+// A symbol may appear as a dim in affine.apply operations. This function
+// canonicalizes dims that are valid symbols into actual symbols.
+static void
+canonicalizePromotedSymbols(AffineMap *map,
+ llvm::SmallVectorImpl<Value *> *operands) {
+ if (!map || operands->empty())
+ return;
+
+ assert(map->getNumInputs() == operands->size() &&
+ "map inputs must match number of operands");
+
+ auto *context = map->getContext();
+ SmallVector<Value *, 8> resultOperands;
+ resultOperands.reserve(operands->size());
+ SmallVector<Value *, 8> remappedSymbols;
+ remappedSymbols.reserve(operands->size());
+ unsigned nextDim = 0;
+ unsigned nextSym = 0;
+ unsigned oldNumSyms = map->getNumSymbols();
+ SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
+ for (unsigned i = 0, e = map->getNumInputs(); i != e; ++i) {
+ if (i < map->getNumDims()) {
+ if (isValidSymbol((*operands)[i])) {
+ // This is a valid symbols that appears as a dim, canonicalize it.
+ dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
+ remappedSymbols.push_back((*operands)[i]);
+ } else {
+ dimRemapping[i] = getAffineDimExpr(nextDim++, context);
+ resultOperands.push_back((*operands)[i]);
+ }
+ } else {
+ resultOperands.push_back((*operands)[i]);
+ }
+ }
+
+ resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
+ *operands = resultOperands;
+ *map = map->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
+ oldNumSyms + nextSym);
+
+ assert(map->getNumInputs() == operands->size() &&
+ "map inputs must match number of operands");
+}
+
+void mlir::canonicalizeMapAndOperands(
+ AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
+ if (!map || operands->empty())
+ return;
+
+ assert(map->getNumInputs() == operands->size() &&
+ "map inputs must match number of operands");
+
+ canonicalizePromotedSymbols(map, operands);
+
+ // Check to see what dims are used.
+ llvm::SmallBitVector usedDims(map->getNumDims());
+ llvm::SmallBitVector usedSyms(map->getNumSymbols());
+ map->walkExprs([&](AffineExpr expr) {
+ if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+ usedDims[dimExpr.getPosition()] = true;
+ else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+ usedSyms[symExpr.getPosition()] = true;
+ });
+
+ auto *context = map->getContext();
+
+ SmallVector<Value *, 8> resultOperands;
+ resultOperands.reserve(operands->size());
+
+ llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
+ SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
+ unsigned nextDim = 0;
+ for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
+ if (usedDims[i]) {
+ auto it = seenDims.find((*operands)[i]);
+ if (it == seenDims.end()) {
+ dimRemapping[i] = getAffineDimExpr(nextDim++, context);
+ resultOperands.push_back((*operands)[i]);
+ seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
+ } else {
+ dimRemapping[i] = it->second;
+ }
+ }
+ }
+ llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
+ SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
+ unsigned nextSym = 0;
+ for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
+ if (usedSyms[i]) {
+ auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
+ if (it == seenSymbols.end()) {
+ symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
+ resultOperands.push_back((*operands)[i + map->getNumDims()]);
+ seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()],
+ symRemapping[i]));
+ } else {
+ symRemapping[i] = it->second;
+ }
+ }
+ }
+ *map =
+ map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
+ *operands = resultOperands;
+}
+
+namespace {
+/// Simplify AffineApply operations.
+///
+struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
+ using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineApplyOp apply,
+ PatternRewriter &rewriter) const override {
+ auto map = apply.getAffineMap();
+
+ AffineMap oldMap = map;
+ SmallVector<Value *, 8> resultOperands(apply.getOperands());
+ composeAffineMapAndOperands(&map, &resultOperands);
+ if (map == oldMap)
+ return matchFailure();
+
+ rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
+ return matchSuccess();
+ }
+};
+} // end anonymous namespace.
+
+void AffineApplyOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<SimplifyAffineApply>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Common canonicalization pattern support logic
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This is a common class used for patterns of the form
+/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
+/// into the root operation directly.
+struct MemRefCastFolder : public RewritePattern {
+ /// The rootOpName is the name of the root operation to match against.
+ MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
+ : RewritePattern(rootOpName, 1, context) {}
+
+ PatternMatchResult match(Operation *op) const override {
+ for (auto *operand : op->getOperands())
+ if (matchPattern(operand, m_Op<MemRefCastOp>()))
+ return matchSuccess();
+
+ return matchFailure();
+ }
+
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
+ if (auto *memref = op->getOperand(i)->getDefiningOp())
+ if (auto cast = dyn_cast<MemRefCastOp>(memref))
+ op->setOperand(i, cast.getOperand());
+ rewriter.updatedRootInPlace(op);
+ }
+};
+
+} // end anonymous namespace.
+
+//===----------------------------------------------------------------------===//
+// AffineDmaStartOp
+//===----------------------------------------------------------------------===//
+
+// TODO(b/133776335) Check that map operands are loop IVs or symbols.
+void AffineDmaStartOp::build(Builder *builder, OperationState *result,
+ Value *srcMemRef, AffineMap srcMap,
+ ArrayRef<Value *> srcIndices, Value *destMemRef,
+ AffineMap dstMap, ArrayRef<Value *> destIndices,
+ Value *tagMemRef, AffineMap tagMap,
+ ArrayRef<Value *> tagIndices, Value *numElements,
+ Value *stride, Value *elementsPerStride) {
+ result->addOperands(srcMemRef);
+ result->addAttribute(getSrcMapAttrName(), builder->getAffineMapAttr(srcMap));
+ result->addOperands(srcIndices);
+ result->addOperands(destMemRef);
+ result->addAttribute(getDstMapAttrName(), builder->getAffineMapAttr(dstMap));
+ result->addOperands(destIndices);
+ result->addOperands(tagMemRef);
+ result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
+ result->addOperands(tagIndices);
+ result->addOperands(numElements);
+ if (stride) {
+ result->addOperands({stride, elementsPerStride});
+ }
+}
+
+void AffineDmaStartOp::print(OpAsmPrinter *p) {
+ *p << "affine.dma_start " << *getSrcMemRef() << '[';
+ SmallVector<Value *, 8> operands(getSrcIndices());
+ p->printAffineMapOfSSAIds(getSrcMapAttr(), operands);
+ *p << "], " << *getDstMemRef() << '[';
+ operands.assign(getDstIndices().begin(), getDstIndices().end());
+ p->printAffineMapOfSSAIds(getDstMapAttr(), operands);
+ *p << "], " << *getTagMemRef() << '[';
+ operands.assign(getTagIndices().begin(), getTagIndices().end());
+ p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
+ *p << "], " << *getNumElements();
+ if (isStrided()) {
+ *p << ", " << *getStride();
+ *p << ", " << *getNumElementsPerStride();
+ }
+ *p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
+ << getTagMemRefType();
+}
+
+// Parse AffineDmaStartOp.
+// Ex:
+// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
+// %stride, %num_elt_per_stride
+// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
+//
+ParseResult AffineDmaStartOp::parse(OpAsmParser *parser,
+ OperationState *result) {
+ OpAsmParser::OperandType srcMemRefInfo;
+ AffineMapAttr srcMapAttr;
+ SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
+ OpAsmParser::OperandType dstMemRefInfo;
+ AffineMapAttr dstMapAttr;
+ SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
+ OpAsmParser::OperandType tagMemRefInfo;
+ AffineMapAttr tagMapAttr;
+ SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
+ OpAsmParser::OperandType numElementsInfo;
+ SmallVector<OpAsmParser::OperandType, 2> strideInfo;
+
+ SmallVector<Type, 3> types;
+ auto indexType = parser->getBuilder().getIndexType();
+
+ // Parse and resolve the following list of operands:
+ // *) dst memref followed by its affine maps operands (in square brackets).
+ // *) src memref followed by its affine map operands (in square brackets).
+ // *) tag memref followed by its affine map operands (in square brackets).
+ // *) number of elements transferred by DMA operation.
+ if (parser->parseOperand(srcMemRefInfo) ||
+ parser->parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
+ getSrcMapAttrName(), result->attributes) ||
+ parser->parseComma() || parser->parseOperand(dstMemRefInfo) ||
+ parser->parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
+ getDstMapAttrName(), result->attributes) ||
+ parser->parseComma() || parser->parseOperand(tagMemRefInfo) ||
+ parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
+ getTagMapAttrName(), result->attributes) ||
+ parser->parseComma() || parser->parseOperand(numElementsInfo))
+ return failure();
+
+ // Parse optional stride and elements per stride.
+ if (parser->parseTrailingOperandList(strideInfo)) {
+ return failure();
+ }
+ if (!strideInfo.empty() && strideInfo.size() != 2) {
+ return parser->emitError(parser->getNameLoc(),
+ "expected two stride related operands");
+ }
+ bool isStrided = strideInfo.size() == 2;
+
+ if (parser->parseColonTypeList(types))
+ return failure();
+
+ if (types.size() != 3)
+ return parser->emitError(parser->getNameLoc(), "expected three types");
+
+ if (parser->resolveOperand(srcMemRefInfo, types[0], result->operands) ||
+ parser->resolveOperands(srcMapOperands, indexType, result->operands) ||
+ parser->resolveOperand(dstMemRefInfo, types[1], result->operands) ||
+ parser->resolveOperands(dstMapOperands, indexType, result->operands) ||
+ parser->resolveOperand(tagMemRefInfo, types[2], result->operands) ||
+ parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
+ parser->resolveOperand(numElementsInfo, indexType, result->operands))
+ return failure();
+
+ if (isStrided) {
+ if (parser->resolveOperands(strideInfo, indexType, result->operands))
+ return failure();
+ }
+
+ // Check that src/dst/tag operand counts match their map.numInputs.
+ if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
+ dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
+ tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
+ return parser->emitError(parser->getNameLoc(),
+ "memref operand count not equal to map.numInputs");
+ return success();
+}
+
+LogicalResult AffineDmaStartOp::verify() {
+ if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>())
+ return emitOpError("expected DMA source to be of memref type");
+ if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>())
+ return emitOpError("expected DMA destination to be of memref type");
+ if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>())
+ return emitOpError("expected DMA tag to be of memref type");
+
+ // DMAs from different memory spaces supported.
+ if (getSrcMemorySpace() == getDstMemorySpace()) {
+ return emitOpError("DMA should be between different memory spaces");
+ }
+ unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
+ getDstMap().getNumInputs() +
+ getTagMap().getNumInputs();
+ if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
+ getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
+ return emitOpError("incorrect number of operands");
+ }
+
+ for (auto *idx : getSrcIndices()) {
+ if (!idx->getType().isIndex())
+ return emitOpError("src index to dma_start must have 'index' type");
+ if (!isValidAffineIndexOperand(idx))
+ return emitOpError("src index must be a dimension or symbol identifier");
+ }
+ for (auto *idx : getDstIndices()) {
+ if (!idx->getType().isIndex())
+ return emitOpError("dst index to dma_start must have 'index' type");
+ if (!isValidAffineIndexOperand(idx))
+ return emitOpError("dst index must be a dimension or symbol identifier");
+ }
+ for (auto *idx : getTagIndices()) {
+ if (!idx->getType().isIndex())
+ return emitOpError("tag index to dma_start must have 'index' type");
+ if (!isValidAffineIndexOperand(idx))
+ return emitOpError("tag index must be a dimension or symbol identifier");
+ }
+ return success();
+}
+
+void AffineDmaStartOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ /// dma_start(memrefcast) -> dma_start
+ results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// AffineDmaWaitOp
+//===----------------------------------------------------------------------===//
+
+// TODO(b/133776335) Check that map operands are loop IVs or symbols.
+void AffineDmaWaitOp::build(Builder *builder, OperationState *result,
+ Value *tagMemRef, AffineMap tagMap,
+ ArrayRef<Value *> tagIndices, Value *numElements) {
+ result->addOperands(tagMemRef);
+ result->addAttribute(getTagMapAttrName(), builder->getAffineMapAttr(tagMap));
+ result->addOperands(tagIndices);
+ result->addOperands(numElements);
+}
+
+void AffineDmaWaitOp::print(OpAsmPrinter *p) {
+ *p << "affine.dma_wait " << *getTagMemRef() << '[';
+ SmallVector<Value *, 2> operands(getTagIndices());
+ p->printAffineMapOfSSAIds(getTagMapAttr(), operands);
+ *p << "], ";
+ p->printOperand(getNumElements());
+ *p << " : " << getTagMemRef()->getType();
+}
+
+// Parse AffineDmaWaitOp.
+// Eg:
+// affine.dma_wait %tag[%index], %num_elements
+// : memref<1 x i32, (d0) -> (d0), 4>
+//
+ParseResult AffineDmaWaitOp::parse(OpAsmParser *parser,
+ OperationState *result) {
+ OpAsmParser::OperandType tagMemRefInfo;
+ AffineMapAttr tagMapAttr;
+ SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
+ Type type;
+ auto indexType = parser->getBuilder().getIndexType();
+ OpAsmParser::OperandType numElementsInfo;
+
+ // Parse tag memref, its map operands, and dma size.
+ if (parser->parseOperand(tagMemRefInfo) ||
+ parser->parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
+ getTagMapAttrName(), result->attributes) ||
+ parser->parseComma() || parser->parseOperand(numElementsInfo) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(tagMemRefInfo, type, result->operands) ||
+ parser->resolveOperands(tagMapOperands, indexType, result->operands) ||
+ parser->resolveOperand(numElementsInfo, indexType, result->operands))
+ return failure();
+
+ if (!type.isa<MemRefType>())
+ return parser->emitError(parser->getNameLoc(),
+ "expected tag to be of memref type");
+
+ if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
+ return parser->emitError(parser->getNameLoc(),
+ "tag memref operand count != to map.numInputs");
+ return success();
+}
+
+LogicalResult AffineDmaWaitOp::verify() {
+ if (!getOperand(0)->getType().isa<MemRefType>())
+ return emitOpError("expected DMA tag to be of memref type");
+ for (auto *idx : getTagIndices()) {
+ if (!idx->getType().isIndex())
+ return emitOpError("index to dma_wait must have 'index' type");
+ if (!isValidAffineIndexOperand(idx))
+ return emitOpError("index must be a dimension or symbol identifier");
+ }
+ return success();
+}
+
+void AffineDmaWaitOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ /// dma_wait(memrefcast) -> dma_wait
+ results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// AffineForOp
+//===----------------------------------------------------------------------===//
+
+void AffineForOp::build(Builder *builder, OperationState *result,
+ ArrayRef<Value *> lbOperands, AffineMap lbMap,
+ ArrayRef<Value *> ubOperands, AffineMap ubMap,
+ int64_t step) {
+ assert(((!lbMap && lbOperands.empty()) ||
+ lbOperands.size() == lbMap.getNumInputs()) &&
+ "lower bound operand count does not match the affine map");
+ assert(((!ubMap && ubOperands.empty()) ||
+ ubOperands.size() == ubMap.getNumInputs()) &&
+ "upper bound operand count does not match the affine map");
+ assert(step > 0 && "step has to be a positive integer constant");
+
+ // Add an attribute for the step.
+ result->addAttribute(getStepAttrName(),
+ builder->getIntegerAttr(builder->getIndexType(), step));
+
+ // Add the lower bound.
+ result->addAttribute(getLowerBoundAttrName(),
+ builder->getAffineMapAttr(lbMap));
+ result->addOperands(lbOperands);
+
+ // Add the upper bound.
+ result->addAttribute(getUpperBoundAttrName(),
+ builder->getAffineMapAttr(ubMap));
+ result->addOperands(ubOperands);
+
+ // Create a region and a block for the body. The argument of the region is
+ // the loop induction variable.
+ Region *bodyRegion = result->addRegion();
+ Block *body = new Block();
+ body->addArgument(IndexType::get(builder->getContext()));
+ bodyRegion->push_back(body);
+ ensureTerminator(*bodyRegion, *builder, result->location);
+
+ // Set the operands list as resizable so that we can freely modify the bounds.
+ result->setOperandListToResizable();
+}
+
+void AffineForOp::build(Builder *builder, OperationState *result, int64_t lb,
+ int64_t ub, int64_t step) {
+ auto lbMap = AffineMap::getConstantMap(lb, builder->getContext());
+ auto ubMap = AffineMap::getConstantMap(ub, builder->getContext());
+ return build(builder, result, {}, lbMap, {}, ubMap, step);
+}
+
+static LogicalResult verify(AffineForOp op) {
+ // Check that the body defines as single block argument for the induction
+ // variable.
+ auto *body = op.getBody();
+ if (body->getNumArguments() != 1 ||
+ !body->getArgument(0)->getType().isIndex())
+ return op.emitOpError(
+ "expected body to have a single index argument for the "
+ "induction variable");
+
+ // Verify that there are enough operands for the bounds.
+ AffineMap lowerBoundMap = op.getLowerBoundMap(),
+ upperBoundMap = op.getUpperBoundMap();
+ if (op.getNumOperands() !=
+ (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs()))
+ return op.emitOpError(
+ "operand count must match with affine map dimension and symbol count");
+
+ // Verify that the bound operands are valid dimension/symbols.
+ /// Lower bound.
+ if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
+ op.getLowerBoundMap().getNumDims())))
+ return failure();
+ /// Upper bound.
+ if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
+ op.getUpperBoundMap().getNumDims())))
+ return failure();
+ return success();
+}
+
+/// Parse a for operation loop bounds.
+static ParseResult parseBound(bool isLower, OperationState *result,
+ OpAsmParser *p) {
+ // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
+ // the map has multiple results.
+ bool failedToParsedMinMax =
+ failed(p->parseOptionalKeyword(isLower ? "max" : "min"));
+
+ auto &builder = p->getBuilder();
+ auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
+ : AffineForOp::getUpperBoundAttrName();
+
+ // Parse ssa-id as identity map.
+ SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
+ if (p->parseOperandList(boundOpInfos))
+ return failure();
+
+ if (!boundOpInfos.empty()) {
+ // Check that only one operand was parsed.
+ if (boundOpInfos.size() > 1)
+ return p->emitError(p->getNameLoc(),
+ "expected only one loop bound operand");
+
+ // TODO: improve error message when SSA value is not an affine integer.
+ // Currently it is 'use of value ... expects different type than prior uses'
+ if (p->resolveOperand(boundOpInfos.front(), builder.getIndexType(),
+ result->operands))
+ return failure();
+
+ // Create an identity map using symbol id. This representation is optimized
+ // for storage. Analysis passes may expand it into a multi-dimensional map
+ // if desired.
+ AffineMap map = builder.getSymbolIdentityMap();
+ result->addAttribute(boundAttrName, builder.getAffineMapAttr(map));
+ return success();
+ }
+
+ // Get the attribute location.
+ llvm::SMLoc attrLoc = p->getCurrentLocation();
+
+ Attribute boundAttr;
+ if (p->parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
+ result->attributes))
+ return failure();
+
+ // Parse full form - affine map followed by dim and symbol list.
+ if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
+ unsigned currentNumOperands = result->operands.size();
+ unsigned numDims;
+ if (parseDimAndSymbolList(p, result->operands, numDims))
+ return failure();
+
+ auto map = affineMapAttr.getValue();
+ if (map.getNumDims() != numDims)
+ return p->emitError(
+ p->getNameLoc(),
+ "dim operand count and integer set dim count must match");
+
+ unsigned numDimAndSymbolOperands =
+ result->operands.size() - currentNumOperands;
+ if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
+ return p->emitError(
+ p->getNameLoc(),
+ "symbol operand count and integer set symbol count must match");
+
+ // If the map has multiple results, make sure that we parsed the min/max
+ // prefix.
+ if (map.getNumResults() > 1 && failedToParsedMinMax) {
+ if (isLower) {
+ return p->emitError(attrLoc, "lower loop bound affine map with "
+ "multiple results requires 'max' prefix");
+ }
+ return p->emitError(attrLoc, "upper loop bound affine map with multiple "
+ "results requires 'min' prefix");
+ }
+ return success();
+ }
+
+ // Parse custom assembly form.
+ if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
+ result->attributes.pop_back();
+ result->addAttribute(
+ boundAttrName, builder.getAffineMapAttr(
+ builder.getConstantAffineMap(integerAttr.getInt())));
+ return success();
+ }
+
+ return p->emitError(
+ p->getNameLoc(),
+ "expected valid affine map representation for loop bounds");
+}
+
+ParseResult parseAffineForOp(OpAsmParser *parser, OperationState *result) {
+ auto &builder = parser->getBuilder();
+ OpAsmParser::OperandType inductionVariable;
+ // Parse the induction variable followed by '='.
+ if (parser->parseRegionArgument(inductionVariable) || parser->parseEqual())
+ return failure();
+
+ // Parse loop bounds.
+ if (parseBound(/*isLower=*/true, result, parser) ||
+ parser->parseKeyword("to", " between bounds") ||
+ parseBound(/*isLower=*/false, result, parser))
+ return failure();
+
+ // Parse the optional loop step, we default to 1 if one is not present.
+ if (parser->parseOptionalKeyword("step")) {
+ result->addAttribute(
+ AffineForOp::getStepAttrName(),
+ builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
+ } else {
+ llvm::SMLoc stepLoc = parser->getCurrentLocation();
+ IntegerAttr stepAttr;
+ if (parser->parseAttribute(stepAttr, builder.getIndexType(),
+ AffineForOp::getStepAttrName().data(),
+ result->attributes))
+ return failure();
+
+ if (stepAttr.getValue().getSExtValue() < 0)
+ return parser->emitError(
+ stepLoc,
+ "expected step to be representable as a positive signed integer");
+ }
+
+ // Parse the body region.
+ Region *body = result->addRegion();
+ if (parser->parseRegion(*body, inductionVariable, builder.getIndexType()))
+ return failure();
+
+ AffineForOp::ensureTerminator(*body, builder, result->location);
+
+ // Parse the optional attribute list.
+ if (parser->parseOptionalAttributeDict(result->attributes))
+ return failure();
+
+ // Set the operands list as resizable so that we can freely modify the bounds.
+ result->setOperandListToResizable();
+ return success();
+}
+
+static void printBound(AffineMapAttr boundMap,
+ Operation::operand_range boundOperands,
+ const char *prefix, OpAsmPrinter *p) {
+ AffineMap map = boundMap.getValue();
+
+ // Check if this bound should be printed using custom assembly form.
+ // The decision to restrict printing custom assembly form to trivial cases
+ // comes from the will to roundtrip MLIR binary -> text -> binary in a
+ // lossless way.
+ // Therefore, custom assembly form parsing and printing is only supported for
+ // zero-operand constant maps and single symbol operand identity maps.
+ if (map.getNumResults() == 1) {
+ AffineExpr expr = map.getResult(0);
+
+ // Print constant bound.
+ if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
+ if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+ *p << constExpr.getValue();
+ return;
+ }
+ }
+
+ // Print bound that consists of a single SSA symbol if the map is over a
+ // single symbol.
+ if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
+ if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
+ p->printOperand(*boundOperands.begin());
+ return;
+ }
+ }
+ } else {
+ // Map has multiple results. Print 'min' or 'max' prefix.
+ *p << prefix << ' ';
+ }
+
+ // Print the map and its operands.
+ *p << boundMap;
+ printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
+ map.getNumDims(), p);
+}
+
+void print(OpAsmPrinter *p, AffineForOp op) {
+ *p << "affine.for ";
+ p->printOperand(op.getBody()->getArgument(0));
+ *p << " = ";
+ printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
+ *p << " to ";
+ printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
+
+ if (op.getStep() != 1)
+ *p << " step " << op.getStep();
+ p->printRegion(op.region(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+ p->printOptionalAttrDict(op.getAttrs(),
+ /*elidedAttrs=*/{op.getLowerBoundAttrName(),
+ op.getUpperBoundAttrName(),
+ op.getStepAttrName()});
+}
+
+namespace {
+/// This is a pattern to fold constant loop bounds.
+struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
+ using OpRewritePattern<AffineForOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(AffineForOp forOp,
+ PatternRewriter &rewriter) const override {
+ auto foldLowerOrUpperBound = [&forOp](bool lower) {
+ // Check to see if each of the operands is the result of a constant. If
+ // so, get the value. If not, ignore it.
+ SmallVector<Attribute, 8> operandConstants;
+ auto boundOperands =
+ lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
+ for (auto *operand : boundOperands) {
+ Attribute operandCst;
+ matchPattern(operand, m_Constant(&operandCst));
+ operandConstants.push_back(operandCst);
+ }
+
+ AffineMap boundMap =
+ lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
+ assert(boundMap.getNumResults() >= 1 &&
+ "bound maps should have at least one result");
+ SmallVector<Attribute, 4> foldedResults;
+ if (failed(boundMap.constantFold(operandConstants, foldedResults)))
+ return failure();
+
+ // Compute the max or min as applicable over the results.
+ assert(!foldedResults.empty() &&
+ "bounds should have at least one result");
+ auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
+ for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
+ auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
+ maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
+ : llvm::APIntOps::smin(maxOrMin, foldedResult);
+ }
+ lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
+ : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
+ return success();
+ };
+
+ // Try to fold the lower bound.
+ bool folded = false;
+ if (!forOp.hasConstantLowerBound())
+ folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
+
+ // Try to fold the upper bound.
+ if (!forOp.hasConstantUpperBound())
+ folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
+
+ // If any of the bounds were folded we return success.
+ if (!folded)
+ return matchFailure();
+ rewriter.updatedRootInPlace(forOp);
+ return matchSuccess();
+ }
+};
+} // end anonymous namespace
+
+void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context) {
+ results.insert<AffineForLoopBoundFolder>(context);
+}
+
+AffineBound AffineForOp::getLowerBound() {
+ auto lbMap = getLowerBoundMap();
+ return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
+}
+
+AffineBound AffineForOp::getUpperBound() {
+ auto lbMap = getLowerBoundMap();
+ auto ubMap = getUpperBoundMap();
+ return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(),
+ ubMap);
+}
+
+void AffineForOp::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
+ assert(lbOperands.size() == map.getNumInputs());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+
+ SmallVector<Value *, 4> newOperands(lbOperands.begin(), lbOperands.end());
+
+ auto ubOperands = getUpperBoundOperands();
+ newOperands.append(ubOperands.begin(), ubOperands.end());
+ getOperation()->setOperands(newOperands);
+
+ setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineForOp::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
+ assert(ubOperands.size() == map.getNumInputs());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+
+ SmallVector<Value *, 4> newOperands(getLowerBoundOperands());
+ newOperands.append(ubOperands.begin(), ubOperands.end());
+ getOperation()->setOperands(newOperands);
+
+ setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineForOp::setLowerBoundMap(AffineMap map) {
+ auto lbMap = getLowerBoundMap();
+ assert(lbMap.getNumDims() == map.getNumDims() &&
+ lbMap.getNumSymbols() == map.getNumSymbols());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+ (void)lbMap;
+ setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
+}
+
+void AffineForOp::setUpperBoundMap(AffineMap map) {
+ auto ubMap = getUpperBoundMap();
+ assert(ubMap.getNumDims() == map.getNumDims() &&
+ ubMap.getNumSymbols() == map.getNumSymbols());
+ assert(map.getNumResults() >= 1 && "bound map has at least one result");
+ (void)ubMap;
+ setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
+}
+
+bool AffineForOp::hasConstantLowerBound() {
+ return getLowerBoundMap().isSingleConstant();
+}
+
+bool AffineForOp::hasConstantUpperBound() {
+ return getUpperBoundMap().isSingleConstant();
+}
+
+int64_t AffineForOp::getConstantLowerBound() {
+ return getLowerBoundMap().getSingleConstantResult();
+}
+
+int64_t AffineForOp::getConstantUpperBound() {
+ return getUpperBoundMap().getSingleConstantResult();
+}
+
+void AffineForOp::setConstantLowerBound(int64_t value) {
+ setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
+}
+
+void AffineForOp::setConstantUpperBound(int64_t value) {
+ setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
+}
+
+AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
+ return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
+}
+
+AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
+ return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
+}
+
+bool AffineForOp::matchingBoundOperandList() {
+ auto lbMap = getLowerBoundMap();
+ auto ubMap = getUpperBoundMap();
+ if (lbMap.getNumDims() != ubMap.getNumDims() ||
+ lbMap.getNumSymbols() != ubMap.getNumSymbols())
+ return false;
+
+ unsigned numOperands = lbMap.getNumInputs();
+ for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
+ // Compare Value *'s.
+ if (getOperand(i) != getOperand(numOperands + i))
+ return false;
+ }
+ return true;
+}
+
+/// Returns if the provided value is the induction variable of a AffineForOp.
+bool mlir::isForInductionVar(Value *val) {
+ return getForInductionVarOwner(val) != AffineForOp();
+}
+
+/// Returns the loop parent of an induction variable. If the provided value is
+/// not an induction variable, then return nullptr.
+AffineForOp mlir::getForInductionVarOwner(Value *val) {
+ auto *ivArg = dyn_cast<BlockArgument>(val);
+ if (!ivArg || !ivArg->getOwner())
+ return AffineForOp();
+ auto *containingInst = ivArg->getOwner()->getParent()->getParentOp();
+ return dyn_cast<AffineForOp>(containingInst);
+}
+
+/// Extracts the induction variables from a list of AffineForOps and returns
+/// them.
+void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
+ SmallVectorImpl<Value *> *ivs) {
+ ivs->reserve(forInsts.size());
+ for (auto forInst : forInsts)
+ ivs->push_back(forInst.getInductionVar());
+}
+
+//===----------------------------------------------------------------------===//
+// AffineIfOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AffineIfOp op) {
+ // Verify that we have a condition attribute.
+ auto conditionAttr =
+ op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
+ if (!conditionAttr)
+ return op.emitOpError(
+ "requires an integer set attribute named 'condition'");
+
+ // Verify that there are enough operands for the condition.
+ IntegerSet condition = conditionAttr.getValue();
+ if (op.getNumOperands() != condition.getNumOperands())
+ return op.emitOpError(
+ "operand count and condition integer set dimension and "
+ "symbol count must match");
+
+ // Verify that the operands are valid dimension/symbols.
+ if (failed(verifyDimAndSymbolIdentifiers(
+ op, op.getOperation()->getNonSuccessorOperands(),
+ condition.getNumDims())))
+ return failure();
+
+ // Verify that the entry of each child region does not have arguments.
+ for (auto &region : op.getOperation()->getRegions()) {
+ for (auto &b : region)
+ if (b.getNumArguments() != 0)
+ return op.emitOpError(
+ "requires that child entry blocks have no arguments");
+ }
+ return success();
+}
+
+ParseResult parseAffineIfOp(OpAsmParser *parser, OperationState *result) {
+ // Parse the condition attribute set.
+ IntegerSetAttr conditionAttr;
+ unsigned numDims;
+ if (parser->parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
+ result->attributes) ||
+ parseDimAndSymbolList(parser, result->operands, numDims))
+ return failure();
+
+ // Verify the condition operands.
+ auto set = conditionAttr.getValue();
+ if (set.getNumDims() != numDims)
+ return parser->emitError(
+ parser->getNameLoc(),
+ "dim operand count and integer set dim count must match");
+ if (numDims + set.getNumSymbols() != result->operands.size())
+ return parser->emitError(
+ parser->getNameLoc(),
+ "symbol operand count and integer set symbol count must match");
+
+ // Create the regions for 'then' and 'else'. The latter must be created even
+ // if it remains empty for the validity of the operation.
+ result->regions.reserve(2);
+ Region *thenRegion = result->addRegion();
+ Region *elseRegion = result->addRegion();
+
+ // Parse the 'then' region.
+ if (parser->parseRegion(*thenRegion, {}, {}))
+ return failure();
+ AffineIfOp::ensureTerminator(*thenRegion, parser->getBuilder(),
+ result->location);
+
+ // If we find an 'else' keyword then parse the 'else' region.
+ if (!parser->parseOptionalKeyword("else")) {
+ if (parser->parseRegion(*elseRegion, {}, {}))
+ return failure();
+ AffineIfOp::ensureTerminator(*elseRegion, parser->getBuilder(),
+ result->location);
+ }
+
+ // Parse the optional attribute list.
+ if (parser->parseOptionalAttributeDict(result->attributes))
+ return failure();
+
+ return success();
+}
+
+void print(OpAsmPrinter *p, AffineIfOp op) {
+ auto conditionAttr =
+ op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
+ *p << "affine.if " << conditionAttr;
+ printDimAndSymbolList(op.operand_begin(), op.operand_end(),
+ conditionAttr.getValue().getNumDims(), p);
+ p->printRegion(op.thenRegion(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+
+ // Print the 'else' regions if it has any blocks.
+ auto &elseRegion = op.elseRegion();
+ if (!elseRegion.empty()) {
+ *p << " else";
+ p->printRegion(elseRegion,
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+ }
+
+ // Print the attribute list.
+ p->printOptionalAttrDict(op.getAttrs(),
+ /*elidedAttrs=*/op.getConditionAttrName());
+}
+
+IntegerSet AffineIfOp::getIntegerSet() {
+ return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
+}
+void AffineIfOp::setIntegerSet(IntegerSet newSet) {
+ setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
+}
+
+//===----------------------------------------------------------------------===//
+// AffineLoadOp
+//===----------------------------------------------------------------------===//
+
+void AffineLoadOp::build(Builder *builder, OperationState *result,
+ AffineMap map, ArrayRef<Value *> operands) {
+ result->addOperands(operands);
+ if (map)
+ result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
+ auto memrefType = operands[0]->getType().cast<MemRefType>();
+ result->types.push_back(memrefType.getElementType());
+}
+
+void AffineLoadOp::build(Builder *builder, OperationState *result,
+ Value *memref, ArrayRef<Value *> indices) {
+ result->addOperands(memref);
+ result->addOperands(indices);
+ auto memrefType = memref->getType().cast<MemRefType>();
+ auto rank = memrefType.getRank();
+ // Create identity map for memrefs with at least one dimension or () -> ()
+ // for zero-dimensional memrefs.
+ auto map = rank ? builder->getMultiDimIdentityMap(rank)
+ : builder->getEmptyAffineMap();
+ result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
+ result->types.push_back(memrefType.getElementType());
+}
+
+ParseResult AffineLoadOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto &builder = parser->getBuilder();
+ auto affineIntTy = builder.getIndexType();
+
+ MemRefType type;
+ OpAsmParser::OperandType memrefInfo;
+ AffineMapAttr mapAttr;
+ SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+ return failure(
+ parser->parseOperand(memrefInfo) ||
+ parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
+ result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(memrefInfo, type, result->operands) ||
+ parser->resolveOperands(mapOperands, affineIntTy, result->operands) ||
+ parser->addTypeToList(type.getElementType(), result->types));
+}
+
+void AffineLoadOp::print(OpAsmPrinter *p) {
+ *p << "affine.load " << *getMemRef() << '[';
+ AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+ if (mapAttr) {
+ SmallVector<Value *, 2> operands(getIndices());
+ p->printAffineMapOfSSAIds(mapAttr, operands);
+ }
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
+ *p << " : " << getMemRefType();
+}
+
+LogicalResult AffineLoadOp::verify() {
+ if (getType() != getMemRefType().getElementType())
+ return emitOpError("result type must match element type of memref");
+
+ auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+ if (mapAttr) {
+ AffineMap map = getAttrOfType<AffineMapAttr>(getMapAttrName()).getValue();
+ if (map.getNumResults() != getMemRefType().getRank())
+ return emitOpError("affine.load affine map num results must equal"
+ " memref rank");
+ if (map.getNumInputs() != getNumOperands() - 1)
+ return emitOpError("expects as many subscripts as affine map inputs");
+ } else {
+ if (getMemRefType().getRank() != getNumOperands() - 1)
+ return emitOpError(
+ "expects the number of subscripts to be equal to memref rank");
+ }
+
+ for (auto *idx : getIndices()) {
+ if (!idx->getType().isIndex())
+ return emitOpError("index to load must have 'index' type");
+ if (!isValidAffineIndexOperand(idx))
+ return emitOpError("index must be a dimension or symbol identifier");
+ }
+ return success();
+}
+
+void AffineLoadOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ /// load(memrefcast) -> load
+ results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+//===----------------------------------------------------------------------===//
+// AffineStoreOp
+//===----------------------------------------------------------------------===//
+
+void AffineStoreOp::build(Builder *builder, OperationState *result,
+ Value *valueToStore, AffineMap map,
+ ArrayRef<Value *> operands) {
+ result->addOperands(valueToStore);
+ result->addOperands(operands);
+ if (map)
+ result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
+}
+
+void AffineStoreOp::build(Builder *builder, OperationState *result,
+ Value *valueToStore, Value *memref,
+ ArrayRef<Value *> operands) {
+ result->addOperands(valueToStore);
+ result->addOperands(memref);
+ result->addOperands(operands);
+ auto memrefType = memref->getType().cast<MemRefType>();
+ auto rank = memrefType.getRank();
+ // Create identity map for memrefs with at least one dimension or () -> ()
+ // for zero-dimensional memrefs.
+ auto map = rank ? builder->getMultiDimIdentityMap(rank)
+ : builder->getEmptyAffineMap();
+ result->addAttribute(getMapAttrName(), builder->getAffineMapAttr(map));
+}
+
+ParseResult AffineStoreOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto affineIntTy = parser->getBuilder().getIndexType();
+
+ MemRefType type;
+ OpAsmParser::OperandType storeValueInfo;
+ OpAsmParser::OperandType memrefInfo;
+ AffineMapAttr mapAttr;
+ SmallVector<OpAsmParser::OperandType, 1> mapOperands;
+ return failure(
+ parser->parseOperand(storeValueInfo) || parser->parseComma() ||
+ parser->parseOperand(memrefInfo) ||
+ parser->parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
+ result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(storeValueInfo, type.getElementType(),
+ result->operands) ||
+ parser->resolveOperand(memrefInfo, type, result->operands) ||
+ parser->resolveOperands(mapOperands, affineIntTy, result->operands));
+}
+
+void AffineStoreOp::print(OpAsmPrinter *p) {
+ *p << "affine.store " << *getValueToStore();
+ *p << ", " << *getMemRef() << '[';
+ AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+ if (mapAttr) {
+ SmallVector<Value *, 2> operands(getIndices());
+ p->printAffineMapOfSSAIds(mapAttr, operands);
+ }
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
+ *p << " : " << getMemRefType();
+}
+
+LogicalResult AffineStoreOp::verify() {
+ // First operand must have same type as memref element type.
+ if (getValueToStore()->getType() != getMemRefType().getElementType())
+ return emitOpError("first operand must have same type memref element type");
+
+ auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
+ if (mapAttr) {
+ AffineMap map = mapAttr.getValue();
+ if (map.getNumResults() != getMemRefType().getRank())
+ return emitOpError("affine.store affine map num results must equal"
+ " memref rank");
+ if (map.getNumInputs() != getNumOperands() - 2)
+ return emitOpError("expects as many subscripts as affine map inputs");
+ } else {
+ if (getMemRefType().getRank() != getNumOperands() - 2)
+ return emitOpError(
+ "expects the number of subscripts to be equal to memref rank");
+ }
+
+ for (auto *idx : getIndices()) {
+ if (!idx->getType().isIndex())
+ return emitOpError("index to store must have 'index' type");
+ if (!isValidAffineIndexOperand(idx))
+ return emitOpError("index must be a dimension or symbol identifier");
+ }
+ return success();
+}
+
+void AffineStoreOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ /// load(memrefcast) -> load
+ results.insert<MemRefCastFolder>(getOperationName(), context);
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
diff --git a/mlir/lib/Dialect/AffineOps/CMakeLists.txt b/mlir/lib/Dialect/AffineOps/CMakeLists.txt
new file mode 100644
index 00000000000..dbe469369a3
--- /dev/null
+++ b/mlir/lib/Dialect/AffineOps/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_llvm_library(MLIRAffineOps
+ AffineOps.cpp
+ DialectRegistration.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AffineOps
+ )
+add_dependencies(MLIRAffineOps MLIRAffineOpsIncGen MLIRIR MLIRStandardOps)
+target_link_libraries(MLIRAffineOps MLIRIR MLIRStandardOps)
+
diff --git a/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp b/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp
new file mode 100644
index 00000000000..9197e3c619f
--- /dev/null
+++ b/mlir/lib/Dialect/AffineOps/DialectRegistration.cpp
@@ -0,0 +1,22 @@
+//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+using namespace mlir;
+
+// Static initialization for Affine op dialect registration.
+static DialectRegistration<AffineOpsDialect> StandardOps;
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 294041df4a5..b0641a9611f 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(AffineOps)
add_subdirectory(FxpMathOps)
add_subdirectory(GPU)
add_subdirectory(Linalg)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
index 1c5bb6e70c8..c48437f60db 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
@@ -15,7 +15,12 @@
// limitations under the License.
// =============================================================================
-#include "mlir/AffineOps/AffineOps.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/EDSC/Helpers.h"
@@ -23,11 +28,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/OpImplementation.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
OpenPOWER on IntegriCloud