summaryrefslogtreecommitdiffstats
path: root/mlir/lib/AffineOps/AffineOps.cpp
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-02-04 16:15:13 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 16:12:40 -0700
commitc9ad4621ce2d68cad547da360aedeee733b73f32 (patch)
tree294e4a0353f053a1015831adc3e56a2d9a4aac5f /mlir/lib/AffineOps/AffineOps.cpp
parent0f50414fa4553b1277684cb1dded84b334b35d51 (diff)
downloadbcm5719-llvm-c9ad4621ce2d68cad547da360aedeee733b73f32.tar.gz
bcm5719-llvm-c9ad4621ce2d68cad547da360aedeee733b73f32.zip
NFC: Move AffineApplyOp to the AffineOps dialect. This also moves the isValidDim/isValidSymbol methods from Value to the AffineOps dialect.
PiperOrigin-RevId: 232386632
Diffstat (limited to 'mlir/lib/AffineOps/AffineOps.cpp')
-rw-r--r--mlir/lib/AffineOps/AffineOps.cpp242
1 files changed, 239 insertions, 3 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp
index 2ef96aa3d14..39345d7fc7a 100644
--- a/mlir/lib/AffineOps/AffineOps.cpp
+++ b/mlir/lib/AffineOps/AffineOps.cpp
@@ -22,6 +22,8 @@
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -30,7 +32,241 @@ using namespace mlir;
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
- addOperations<AffineForOp, AffineIfOp>();
+ addOperations<AffineApplyOp, AffineForOp, AffineIfOp>();
+}
+
+// Value can be used as a dimension id if it is valid as a symbol, or
+// it is an induction variable, or it is a result of affine apply operation
+// with dimension id arguments.
+bool mlir::isValidDim(const Value *value) {
+ if (auto *inst = value->getDefiningInst()) {
+ // Top level instruction or constant operation is ok.
+ if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
+ return true;
+ // Affine apply operation is ok if all of its operands are ok.
+ if (auto op = inst->dyn_cast<AffineApplyOp>())
+ return op->isValidDim();
+ return false;
+ }
+ // This value is a block argument.
+ return true;
+}
+
+// Value can be used as a symbol if it is a constant, or it is defined at
+// the top level, or it is a result of affine apply operation with symbol
+// arguments.
+bool mlir::isValidSymbol(const Value *value) {
+ if (auto *inst = value->getDefiningInst()) {
+ // Top level instruction or constant operation is ok.
+ if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
+ return true;
+ // Affine apply operation is ok if all of its operands are ok.
+ if (auto op = inst->dyn_cast<AffineApplyOp>())
+ return op->isValidSymbol();
+ return false;
+ }
+ // Otherwise, the only valid symbol is a non induction variable block
+ // argument.
+ return !isForInductionVar(value);
+}
+
+//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
+void AffineApplyOp::build(Builder *builder, OperationState *result,
+ AffineMap map, ArrayRef<Value *> operands) {
+ result->addOperands(operands);
+ result->types.append(map.getNumResults(), builder->getIndexType());
+ result->addAttribute("map", builder->getAffineMapAttr(map));
+}
+
+bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
+ auto &builder = parser->getBuilder();
+ auto affineIntTy = builder.getIndexType();
+
+ AffineMapAttr mapAttr;
+ unsigned numDims;
+ if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
+ parseDimAndSymbolList(parser, result->operands, numDims) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return true;
+ auto map = mapAttr.getValue();
+
+ if (map.getNumDims() != numDims ||
+ numDims + map.getNumSymbols() != result->operands.size()) {
+ return parser->emitError(parser->getNameLoc(),
+ "dimension or symbol index mismatch");
+ }
+
+ result->types.append(map.getNumResults(), affineIntTy);
+ return false;
+}
+
+void AffineApplyOp::print(OpAsmPrinter *p) const {
+ auto map = getAffineMap();
+ *p << "affine_apply " << map;
+ printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p);
+ p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
+}
+
+bool AffineApplyOp::verify() const {
+ // Check that affine map attribute was specified.
+ auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
+ if (!affineMapAttr)
+ return emitOpError("requires an affine map");
+
+ // Check input and output dimensions match.
+ auto map = affineMapAttr.getValue();
+
+ // Verify that operand count matches affine map dimension and symbol count.
+ if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
+ return emitOpError(
+ "operand count and affine map dimension and symbol count must match");
+
+ // Verify that result count matches affine map result count.
+ if (map.getNumResults() != 1)
+ return emitOpError("mapping must produce one value");
+
+ return false;
+}
+
+// The result of the affine apply operation can be used as a dimension id if it
+// is a CFG value or if it is an Value, and all the operands are valid
+// dimension ids.
+bool AffineApplyOp::isValidDim() const {
+ return llvm::all_of(getOperands(),
+ [](const Value *op) { return mlir::isValidDim(op); });
+}
+
+// The result of the affine apply operation can be used as a symbol if it is
+// a CFG value or if it is an Value, and all the operands are symbols.
+bool AffineApplyOp::isValidSymbol() const {
+ return llvm::all_of(getOperands(),
+ [](const Value *op) { return mlir::isValidSymbol(op); });
+}
+
+Attribute AffineApplyOp::constantFold(ArrayRef<Attribute> operands,
+ MLIRContext *context) const {
+ auto map = getAffineMap();
+ SmallVector<Attribute, 1> result;
+ if (map.constantFold(operands, result))
+ return Attribute();
+ return result[0];
+}
+
+namespace {
+/// SimplifyAffineApply operations.
+///
+struct SimplifyAffineApply : public RewritePattern {
+ SimplifyAffineApply(MLIRContext *context)
+ : RewritePattern(AffineApplyOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult match(Instruction *op) const override;
+ void rewrite(Instruction *op, std::unique_ptr<PatternState> state,
+ PatternRewriter &rewriter) const override;
+};
+} // end anonymous namespace.
+
+namespace {
+/// FIXME: this is massive overkill for simple obviously always matching
+/// canonicalizations. Fix the pattern rewriter to make this easy.
+struct SimplifyAffineApplyState : public PatternState {
+ AffineMap map;
+ SmallVector<Value *, 8> operands;
+
+ SimplifyAffineApplyState(AffineMap map,
+ const SmallVector<Value *, 8> &operands)
+ : map(map), operands(operands) {}
+};
+
+} // end anonymous namespace.
+
+void mlir::canonicalizeMapAndOperands(
+ AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) {
+ if (!map || operands->empty())
+ return;
+
+ assert(map->getNumInputs() == operands->size() &&
+ "map inputs must match number of operands");
+
+ // Check to see what dims are used.
+ llvm::SmallBitVector usedDims(map->getNumDims());
+ llvm::SmallBitVector usedSyms(map->getNumSymbols());
+ map->walkExprs([&](AffineExpr expr) {
+ if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+ usedDims[dimExpr.getPosition()] = true;
+ else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+ usedSyms[symExpr.getPosition()] = true;
+ });
+
+ auto *context = map->getContext();
+
+ SmallVector<Value *, 8> resultOperands;
+ resultOperands.reserve(operands->size());
+
+ llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims;
+ SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims());
+ unsigned nextDim = 0;
+ for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) {
+ if (usedDims[i]) {
+ auto it = seenDims.find((*operands)[i]);
+ if (it == seenDims.end()) {
+ dimRemapping[i] = getAffineDimExpr(nextDim++, context);
+ resultOperands.push_back((*operands)[i]);
+ seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
+ } else {
+ dimRemapping[i] = it->second;
+ }
+ }
+ }
+ llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols;
+ SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols());
+ unsigned nextSym = 0;
+ for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) {
+ if (usedSyms[i]) {
+ auto it = seenSymbols.find((*operands)[i + map->getNumDims()]);
+ if (it == seenSymbols.end()) {
+ symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
+ resultOperands.push_back((*operands)[i + map->getNumDims()]);
+ seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()],
+ symRemapping[i]));
+ } else {
+ symRemapping[i] = it->second;
+ }
+ }
+ }
+ *map =
+ map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym);
+ *operands = resultOperands;
+}
+
+PatternMatchResult SimplifyAffineApply::match(Instruction *op) const {
+ auto apply = op->cast<AffineApplyOp>();
+ auto map = apply->getAffineMap();
+
+ AffineMap oldMap = map;
+ SmallVector<Value *, 8> resultOperands(apply->getOperands().begin(),
+ apply->getOperands().end());
+ canonicalizeMapAndOperands(&map, &resultOperands);
+ if (map != oldMap)
+ return matchSuccess(
+ std::make_unique<SimplifyAffineApplyState>(map, resultOperands));
+
+ return matchFailure();
+}
+
+void SimplifyAffineApply::rewrite(Instruction *op,
+ std::unique_ptr<PatternState> state,
+ PatternRewriter &rewriter) const {
+ auto *applyState = static_cast<SimplifyAffineApplyState *>(state.get());
+ rewriter.replaceOpWithNewOp<AffineApplyOp>(op, applyState->map,
+ applyState->operands);
+}
+
+void AffineApplyOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.push_back(std::make_unique<SimplifyAffineApply>(context));
}
//===----------------------------------------------------------------------===//
@@ -493,9 +729,9 @@ bool AffineIfOp::verify() const {
IntegerSet condition = conditionAttr.getValue();
for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
const Value *operand = getOperand(i);
- if (i < condition.getNumDims() && !operand->isValidDim())
+ if (i < condition.getNumDims() && !isValidDim(operand))
return emitOpError("operand cannot be used as a dimension id");
- if (i >= condition.getNumDims() && !operand->isValidSymbol())
+ if (i >= condition.getNumDims() && !isValidSymbol(operand))
return emitOpError("operand cannot be used as a symbol");
}
OpenPOWER on IntegriCloud