summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-13 12:21:42 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-13 16:47:26 -0800
commitb030e4a4ec5ef47549377cc0af71a95abcf28a98 (patch)
treefa7dab40faf97bc094d1e9426d0f17e44afa0f44 /mlir/lib/Transforms
parent7b19d736172789ce8e5ca10ae6276302004533f0 (diff)
downloadbcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.tar.gz
bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.zip
Try to fold operations in DialectConversion when trying to legalize.
This change allows for DialectConversion to attempt folding as a mechanism to legalize illegal operations. This also expands folding support in OpBuilder::createOrFold to generate new constants when folding, and also enables it to work in the context of a PatternRewriter. PiperOrigin-RevId: 285448440
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp43
1 files changed, 42 insertions, 1 deletions
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index ea4ad681693..ac13bc2ba5b 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -25,7 +25,6 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::detail;
@@ -938,6 +937,10 @@ public:
ConversionTarget &getTarget() { return target; }
private:
+ /// Attempt to legalize the given operation by folding it.
+ LogicalResult legalizeWithFold(Operation *op,
+ ConversionPatternRewriter &rewriter);
+
/// Attempt to legalize the given operation by applying the provided pattern.
/// Returns success if the operation was legalized, failure otherwise.
LogicalResult legalizePattern(Operation *op, RewritePattern *pattern,
@@ -1003,6 +1006,14 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
+ // If the operation isn't legal, try to fold it in-place.
+ // TODO(riverriddle) Should we always try to do this, even if the op is
+ // already legal?
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n");
+ return success();
+ }
+
// Otherwise, we need to apply a legalization pattern to this operation.
auto it = legalizerPatterns.find(op->getName());
if (it == legalizerPatterns.end()) {
@@ -1020,6 +1031,36 @@ OperationLegalizer::legalize(Operation *op,
}
LogicalResult
+OperationLegalizer::legalizeWithFold(Operation *op,
+ ConversionPatternRewriter &rewriter) {
+ auto &rewriterImpl = rewriter.getImpl();
+ RewriterState curState = rewriterImpl.getCurrentState();
+
+ // Try to fold the operation.
+ SmallVector<Value *, 2> replacementValues;
+ rewriter.setInsertionPoint(op);
+ if (failed(rewriter.tryFold(op, replacementValues)))
+ return failure();
+
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
+
+ // Recursively legalize any new constant operations.
+ for (unsigned i = curState.numCreatedOperations,
+ e = rewriterImpl.createdOps.size();
+ i != e; ++i) {
+ Operation *cstOp = rewriterImpl.createdOps[i];
+ if (failed(legalize(cstOp, rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '"
+ << cstOp->getName() << "' was illegal.\n");
+ rewriterImpl.resetState(curState);
+ return failure();
+ }
+ }
+ return success();
+}
+
+LogicalResult
OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
ConversionPatternRewriter &rewriter) {
LLVM_DEBUG({
OpenPOWER on IntegriCloud