diff options
| author | River Riddle <riverriddle@google.com> | 2019-02-04 16:15:13 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:12:40 -0700 |
| commit | c9ad4621ce2d68cad547da360aedeee733b73f32 (patch) | |
| tree | 294e4a0353f053a1015831adc3e56a2d9a4aac5f /mlir/lib/AffineOps/AffineOps.cpp | |
| parent | 0f50414fa4553b1277684cb1dded84b334b35d51 (diff) | |
| download | bcm5719-llvm-c9ad4621ce2d68cad547da360aedeee733b73f32.tar.gz bcm5719-llvm-c9ad4621ce2d68cad547da360aedeee733b73f32.zip | |
NFC: Move AffineApplyOp to the AffineOps dialect. This also moves the isValidDim/isValidSymbol methods from Value to the AffineOps dialect.
PiperOrigin-RevId: 232386632
Diffstat (limited to 'mlir/lib/AffineOps/AffineOps.cpp')
| -rw-r--r-- | mlir/lib/AffineOps/AffineOps.cpp | 242 |
1 files changed, 239 insertions, 3 deletions
diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 2ef96aa3d14..39345d7fc7a 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -22,6 +22,8 @@ #include "mlir/IR/InstVisitor.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/SmallBitVector.h" using namespace mlir; //===----------------------------------------------------------------------===// @@ -30,7 +32,241 @@ using namespace mlir; AffineOpsDialect::AffineOpsDialect(MLIRContext *context) : Dialect(/*namePrefix=*/"", context) { - addOperations<AffineForOp, AffineIfOp>(); + addOperations<AffineApplyOp, AffineForOp, AffineIfOp>(); +} + +// Value can be used as a dimension id if it is valid as a symbol, or +// it is an induction variable, or it is a result of affine apply operation +// with dimension id arguments. +bool mlir::isValidDim(const Value *value) { + if (auto *inst = value->getDefiningInst()) { + // Top level instruction or constant operation is ok. + if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>()) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto op = inst->dyn_cast<AffineApplyOp>()) + return op->isValidDim(); + return false; + } + // This value is a block argument. + return true; +} + +// Value can be used as a symbol if it is a constant, or it is defined at +// the top level, or it is a result of affine apply operation with symbol +// arguments. +bool mlir::isValidSymbol(const Value *value) { + if (auto *inst = value->getDefiningInst()) { + // Top level instruction or constant operation is ok. + if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>()) + return true; + // Affine apply operation is ok if all of its operands are ok. + if (auto op = inst->dyn_cast<AffineApplyOp>()) + return op->isValidSymbol(); + return false; + } + // Otherwise, the only valid symbol is a non induction variable block + // argument. + return !isForInductionVar(value); +} + +//===----------------------------------------------------------------------===// +// AffineApplyOp +//===----------------------------------------------------------------------===// + +void AffineApplyOp::build(Builder *builder, OperationState *result, + AffineMap map, ArrayRef<Value *> operands) { + result->addOperands(operands); + result->types.append(map.getNumResults(), builder->getIndexType()); + result->addAttribute("map", builder->getAffineMapAttr(map)); +} + +bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) { + auto &builder = parser->getBuilder(); + auto affineIntTy = builder.getIndexType(); + + AffineMapAttr mapAttr; + unsigned numDims; + if (parser->parseAttribute(mapAttr, "map", result->attributes) || + parseDimAndSymbolList(parser, result->operands, numDims) || + parser->parseOptionalAttributeDict(result->attributes)) + return true; + auto map = mapAttr.getValue(); + + if (map.getNumDims() != numDims || + numDims + map.getNumSymbols() != result->operands.size()) { + return parser->emitError(parser->getNameLoc(), + "dimension or symbol index mismatch"); + } + + result->types.append(map.getNumResults(), affineIntTy); + return false; +} + +void AffineApplyOp::print(OpAsmPrinter *p) const { + auto map = getAffineMap(); + *p << "affine_apply " << map; + printDimAndSymbolList(operand_begin(), operand_end(), map.getNumDims(), p); + p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map"); +} + +bool AffineApplyOp::verify() const { + // Check that affine map attribute was specified. + auto affineMapAttr = getAttrOfType<AffineMapAttr>("map"); + if (!affineMapAttr) + return emitOpError("requires an affine map"); + + // Check input and output dimensions match. + auto map = affineMapAttr.getValue(); + + // Verify that operand count matches affine map dimension and symbol count. + if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) + return emitOpError( + "operand count and affine map dimension and symbol count must match"); + + // Verify that result count matches affine map result count. + if (map.getNumResults() != 1) + return emitOpError("mapping must produce one value"); + + return false; +} + +// The result of the affine apply operation can be used as a dimension id if it +// is a CFG value or if it is an Value, and all the operands are valid +// dimension ids. +bool AffineApplyOp::isValidDim() const { + return llvm::all_of(getOperands(), + [](const Value *op) { return mlir::isValidDim(op); }); +} + +// The result of the affine apply operation can be used as a symbol if it is +// a CFG value or if it is an Value, and all the operands are symbols. +bool AffineApplyOp::isValidSymbol() const { + return llvm::all_of(getOperands(), + [](const Value *op) { return mlir::isValidSymbol(op); }); +} + +Attribute AffineApplyOp::constantFold(ArrayRef<Attribute> operands, + MLIRContext *context) const { + auto map = getAffineMap(); + SmallVector<Attribute, 1> result; + if (map.constantFold(operands, result)) + return Attribute(); + return result[0]; +} + +namespace { +/// SimplifyAffineApply operations. +/// +struct SimplifyAffineApply : public RewritePattern { + SimplifyAffineApply(MLIRContext *context) + : RewritePattern(AffineApplyOp::getOperationName(), 1, context) {} + + PatternMatchResult match(Instruction *op) const override; + void rewrite(Instruction *op, std::unique_ptr<PatternState> state, + PatternRewriter &rewriter) const override; +}; +} // end anonymous namespace. + +namespace { +/// FIXME: this is massive overkill for simple obviously always matching +/// canonicalizations. Fix the pattern rewriter to make this easy. +struct SimplifyAffineApplyState : public PatternState { + AffineMap map; + SmallVector<Value *, 8> operands; + + SimplifyAffineApplyState(AffineMap map, + const SmallVector<Value *, 8> &operands) + : map(map), operands(operands) {} +}; + +} // end anonymous namespace. + +void mlir::canonicalizeMapAndOperands( + AffineMap *map, llvm::SmallVectorImpl<Value *> *operands) { + if (!map || operands->empty()) + return; + + assert(map->getNumInputs() == operands->size() && + "map inputs must match number of operands"); + + // Check to see what dims are used. + llvm::SmallBitVector usedDims(map->getNumDims()); + llvm::SmallBitVector usedSyms(map->getNumSymbols()); + map->walkExprs([&](AffineExpr expr) { + if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) + usedDims[dimExpr.getPosition()] = true; + else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) + usedSyms[symExpr.getPosition()] = true; + }); + + auto *context = map->getContext(); + + SmallVector<Value *, 8> resultOperands; + resultOperands.reserve(operands->size()); + + llvm::SmallDenseMap<Value *, AffineExpr, 8> seenDims; + SmallVector<AffineExpr, 8> dimRemapping(map->getNumDims()); + unsigned nextDim = 0; + for (unsigned i = 0, e = map->getNumDims(); i != e; ++i) { + if (usedDims[i]) { + auto it = seenDims.find((*operands)[i]); + if (it == seenDims.end()) { + dimRemapping[i] = getAffineDimExpr(nextDim++, context); + resultOperands.push_back((*operands)[i]); + seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); + } else { + dimRemapping[i] = it->second; + } + } + } + llvm::SmallDenseMap<Value *, AffineExpr, 8> seenSymbols; + SmallVector<AffineExpr, 8> symRemapping(map->getNumSymbols()); + unsigned nextSym = 0; + for (unsigned i = 0, e = map->getNumSymbols(); i != e; ++i) { + if (usedSyms[i]) { + auto it = seenSymbols.find((*operands)[i + map->getNumDims()]); + if (it == seenSymbols.end()) { + symRemapping[i] = getAffineSymbolExpr(nextSym++, context); + resultOperands.push_back((*operands)[i + map->getNumDims()]); + seenSymbols.insert(std::make_pair((*operands)[i + map->getNumDims()], + symRemapping[i])); + } else { + symRemapping[i] = it->second; + } + } + } + *map = + map->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym); + *operands = resultOperands; +} + +PatternMatchResult SimplifyAffineApply::match(Instruction *op) const { + auto apply = op->cast<AffineApplyOp>(); + auto map = apply->getAffineMap(); + + AffineMap oldMap = map; + SmallVector<Value *, 8> resultOperands(apply->getOperands().begin(), + apply->getOperands().end()); + canonicalizeMapAndOperands(&map, &resultOperands); + if (map != oldMap) + return matchSuccess( + std::make_unique<SimplifyAffineApplyState>(map, resultOperands)); + + return matchFailure(); +} + +void SimplifyAffineApply::rewrite(Instruction *op, + std::unique_ptr<PatternState> state, + PatternRewriter &rewriter) const { + auto *applyState = static_cast<SimplifyAffineApplyState *>(state.get()); + rewriter.replaceOpWithNewOp<AffineApplyOp>(op, applyState->map, + applyState->operands); +} + +void AffineApplyOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.push_back(std::make_unique<SimplifyAffineApply>(context)); } //===----------------------------------------------------------------------===// @@ -493,9 +729,9 @@ bool AffineIfOp::verify() const { IntegerSet condition = conditionAttr.getValue(); for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { const Value *operand = getOperand(i); - if (i < condition.getNumDims() && !operand->isValidDim()) + if (i < condition.getNumDims() && !isValidDim(operand)) return emitOpError("operand cannot be used as a dimension id"); - if (i >= condition.getNumDims() && !operand->isValidSymbol()) + if (i >= condition.getNumDims() && !isValidSymbol(operand)) return emitOpError("operand cannot be used as a symbol"); } |

