diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-13 12:21:42 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-13 16:47:26 -0800 |
| commit | b030e4a4ec5ef47549377cc0af71a95abcf28a98 (patch) | |
| tree | fa7dab40faf97bc094d1e9426d0f17e44afa0f44 /mlir/lib | |
| parent | 7b19d736172789ce8e5ca10ae6276302004533f0 (diff) | |
| download | bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.tar.gz bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.zip | |
Try to fold operations in DialectConversion when trying to legalize.
This change allows for DialectConversion to attempt folding as a mechanism to legalize illegal operations. This also expands folding support in OpBuilder::createOrFold to generate new constants when folding, and also enables it to work in the context of a PatternRewriter.
PiperOrigin-RevId: 285448440
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/AffineOps/AffineOps.cpp | 8 | ||||
| -rw-r--r-- | mlir/lib/Dialect/StandardOps/Ops.cpp | 8 | ||||
| -rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 80 | ||||
| -rw-r--r-- | mlir/lib/Transforms/DialectConversion.cpp | 43 |
5 files changed, 128 insertions, 21 deletions
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 96a1a68889c..22d4ec10dd0 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -99,6 +99,14 @@ AffineOpsDialect::AffineOpsDialect(MLIRContext *context) addInterfaces<AffineInlinerInterface, AffineSideEffectsInterface>(); } +/// Materialize a single constant operation from a given attribute value with +/// the desired resultant type. +Operation *AffineOpsDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return builder.create<ConstantOp>(loc, type, value); +} + /// A utility function to check if a given region is attached to a function. static bool isFunctionRegion(Region *region) { return llvm::isa<FuncOp>(region->getParentOp()); diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 531be29666a..713546fc40d 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -163,6 +163,14 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context) addInterfaces<StdInlinerInterface>(); } +/// Materialize a single constant operation from a given attribute value with +/// the desired resultant type. +Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return builder.create<ConstantOp>(loc, type, value); +} + void mlir::printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &p) { diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index a2345fe1c40..ae5579d9e3d 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -40,7 +40,7 @@ using namespace mlir::vector; // VectorOpsDialect //===----------------------------------------------------------------------===// -mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context) +VectorOpsDialect::VectorOpsDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST @@ -48,6 +48,14 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context) >(); } +/// Materialize a single constant operation from a given attribute value with +/// the desired resultant type. +Operation *VectorOpsDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return builder.create<ConstantOp>(loc, type, value); +} + //===----------------------------------------------------------------------===// // ContractionOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8c54df4d55b..691b2ad99c4 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -18,12 +18,13 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/Functional.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; Builder::Builder(ModuleOp module) : context(module.getContext()) {} @@ -339,27 +340,68 @@ Operation *OpBuilder::createOperation(const OperationState &state) { } /// Attempts to fold the given operation and places new results within -/// 'results'. -void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) { +/// 'results'. Returns success if the operation was folded, failure otherwise. +/// Note: This function does not erase the operation on a successful fold. +LogicalResult OpBuilder::tryFold(Operation *op, + SmallVectorImpl<Value *> &results) { results.reserve(op->getNumResults()); - SmallVector<OpFoldResult, 4> foldResults; - - // Returns if the given fold result corresponds to a valid existing value. - auto isValidValue = [](OpFoldResult result) { - return result.dyn_cast<Value *>(); + auto cleanupFailure = [&] { + results.assign(op->result_begin(), op->result_end()); + return failure(); }; - // Check if the fold failed, or did not result in only existing values. + // If this operation is already a constant, there is nothing to do. + Attribute unused; + if (matchPattern(op, m_Constant(&unused))) + return cleanupFailure(); + + // Check to see if any operands to the operation is constant and whether + // the operation knows how to constant fold itself. SmallVector<Attribute, 4> constOperands(op->getNumOperands()); - if (failed(op->fold(constOperands, foldResults)) || foldResults.empty() || - !llvm::all_of(foldResults, isValidValue)) { - // Simply return the existing operation results. - results.assign(op->result_begin(), op->result_end()); - return; + for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) + matchPattern(op->getOperand(i), m_Constant(&constOperands[i])); + + // Try to fold the operation. + SmallVector<OpFoldResult, 4> foldResults; + if (failed(op->fold(constOperands, foldResults)) || foldResults.empty()) + return cleanupFailure(); + + // A temporary builder used for creating constants during folding. + OpBuilder cstBuilder(context); + SmallVector<Operation *, 1> generatedConstants; + + // Populate the results with the folded results. + Dialect *dialect = op->getDialect(); + for (auto &it : llvm::enumerate(foldResults)) { + // Normal values get pushed back directly. + if (auto *value = it.value().dyn_cast<Value *>()) { + results.push_back(value); + continue; + } + + // Otherwise, try to materialize a constant operation. + if (!dialect) + return cleanupFailure(); + + // Ask the dialect to materialize a constant operation for this value. + Attribute attr = it.value().get<Attribute>(); + auto *constOp = dialect->materializeConstant( + cstBuilder, attr, op->getResult(it.index())->getType(), op->getLoc()); + if (!constOp) { + // Erase any generated constants. + for (Operation *cst : generatedConstants) + cst->erase(); + return cleanupFailure(); + } + assert(matchPattern(constOp, m_Constant(&attr))); + + generatedConstants.push_back(constOp); + results.push_back(constOp->getResult(0)); } - // Populate the results with the folded results and remove the original op. - llvm::transform(foldResults, std::back_inserter(results), - [](OpFoldResult result) { return result.get<Value *>(); }); - op->erase(); + // If we were successful, insert any generated constants. + for (Operation *cst : generatedConstants) + insert(cst); + + return success(); } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index ea4ad681693..ac13bc2ba5b 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -25,7 +25,6 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::detail; @@ -938,6 +937,10 @@ public: ConversionTarget &getTarget() { return target; } private: + /// Attempt to legalize the given operation by folding it. + LogicalResult legalizeWithFold(Operation *op, + ConversionPatternRewriter &rewriter); + /// Attempt to legalize the given operation by applying the provided pattern. /// Returns success if the operation was legalized, failure otherwise. LogicalResult legalizePattern(Operation *op, RewritePattern *pattern, @@ -1003,6 +1006,14 @@ OperationLegalizer::legalize(Operation *op, return success(); } + // If the operation isn't legal, try to fold it in-place. + // TODO(riverriddle) Should we always try to do this, even if the op is + // already legal? + if (succeeded(legalizeWithFold(op, rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n"); + return success(); + } + // Otherwise, we need to apply a legalization pattern to this operation. auto it = legalizerPatterns.find(op->getName()); if (it == legalizerPatterns.end()) { @@ -1020,6 +1031,36 @@ OperationLegalizer::legalize(Operation *op, } LogicalResult +OperationLegalizer::legalizeWithFold(Operation *op, + ConversionPatternRewriter &rewriter) { + auto &rewriterImpl = rewriter.getImpl(); + RewriterState curState = rewriterImpl.getCurrentState(); + + // Try to fold the operation. + SmallVector<Value *, 2> replacementValues; + rewriter.setInsertionPoint(op); + if (failed(rewriter.tryFold(op, replacementValues))) + return failure(); + + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + + // Recursively legalize any new constant operations. + for (unsigned i = curState.numCreatedOperations, + e = rewriterImpl.createdOps.size(); + i != e; ++i) { + Operation *cstOp = rewriterImpl.createdOps[i]; + if (failed(legalize(cstOp, rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '" + << cstOp->getName() << "' was illegal.\n"); + rewriterImpl.resetState(curState); + return failure(); + } + } + return success(); +} + +LogicalResult OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, ConversionPatternRewriter &rewriter) { LLVM_DEBUG({ |

