diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-13 12:21:42 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-13 16:47:26 -0800 |
| commit | b030e4a4ec5ef47549377cc0af71a95abcf28a98 (patch) | |
| tree | fa7dab40faf97bc094d1e9426d0f17e44afa0f44 /mlir/lib/Transforms | |
| parent | 7b19d736172789ce8e5ca10ae6276302004533f0 (diff) | |
| download | bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.tar.gz bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.zip | |
Try to fold operations in DialectConversion when trying to legalize.
This change allows for DialectConversion to attempt folding as a mechanism to legalize illegal operations. This also expands folding support in OpBuilder::createOrFold to generate new constants when folding, and also enables it to work in the context of a PatternRewriter.
PiperOrigin-RevId: 285448440
Diffstat (limited to 'mlir/lib/Transforms')
| -rw-r--r-- | mlir/lib/Transforms/DialectConversion.cpp | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index ea4ad681693..ac13bc2ba5b 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -25,7 +25,6 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::detail; @@ -938,6 +937,10 @@ public: ConversionTarget &getTarget() { return target; } private: + /// Attempt to legalize the given operation by folding it. + LogicalResult legalizeWithFold(Operation *op, + ConversionPatternRewriter &rewriter); + /// Attempt to legalize the given operation by applying the provided pattern. /// Returns success if the operation was legalized, failure otherwise. LogicalResult legalizePattern(Operation *op, RewritePattern *pattern, @@ -1003,6 +1006,14 @@ OperationLegalizer::legalize(Operation *op, return success(); } + // If the operation isn't legal, try to fold it in-place. + // TODO(riverriddle) Should we always try to do this, even if the op is + // already legal? + if (succeeded(legalizeWithFold(op, rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n"); + return success(); + } + // Otherwise, we need to apply a legalization pattern to this operation. auto it = legalizerPatterns.find(op->getName()); if (it == legalizerPatterns.end()) { @@ -1020,6 +1031,36 @@ OperationLegalizer::legalize(Operation *op, } LogicalResult +OperationLegalizer::legalizeWithFold(Operation *op, + ConversionPatternRewriter &rewriter) { + auto &rewriterImpl = rewriter.getImpl(); + RewriterState curState = rewriterImpl.getCurrentState(); + + // Try to fold the operation. + SmallVector<Value *, 2> replacementValues; + rewriter.setInsertionPoint(op); + if (failed(rewriter.tryFold(op, replacementValues))) + return failure(); + + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + + // Recursively legalize any new constant operations. + for (unsigned i = curState.numCreatedOperations, + e = rewriterImpl.createdOps.size(); + i != e; ++i) { + Operation *cstOp = rewriterImpl.createdOps[i]; + if (failed(legalize(cstOp, rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '" + << cstOp->getName() << "' was illegal.\n"); + rewriterImpl.resetState(curState); + return failure(); + } + } + return success(); +} + +LogicalResult OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, ConversionPatternRewriter &rewriter) { LLVM_DEBUG({ |

