diff options
| author | River Riddle <riverriddle@google.com> | 2019-09-01 20:06:42 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-09-01 20:07:08 -0700 |
| commit | 6563b1c4463472d5bdc83a3a62a1da1a3052ce18 (patch) | |
| tree | d18db308d1013ab66f2e59fc9806904605862556 /mlir/lib/Transforms/Utils | |
| parent | ce702fc8dafab9d8d0db4e61025c65db979cd701 (diff) | |
| download | bcm5719-llvm-6563b1c4463472d5bdc83a3a62a1da1a3052ce18.tar.gz bcm5719-llvm-6563b1c4463472d5bdc83a3a62a1da1a3052ce18.zip | |
Add a new dialect interface for the OperationFolder `OpFolderDialectInterface`.
This interface will allow for providing hooks to interrop with operation folding. The first hook, 'shouldMaterializeInto', will allow for controlling which region to insert materialized constants into. The folder will generally materialize constants into the top-level isolated region, this allows for materializing into a lower level ancestor region if it is more profitable/correct.
PiperOrigin-RevId: 266702972
Diffstat (limited to 'mlir/lib/Transforms/Utils')
| -rw-r--r-- | mlir/lib/Transforms/Utils/FoldUtils.cpp | 18 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 2 |
2 files changed, 14 insertions, 6 deletions
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index 6c313e20932..5faca1296a8 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -31,7 +31,9 @@ using namespace mlir; /// Given an operation, find the parent region that folded constants should be /// inserted into. -static Region *getInsertionRegion(Operation *op) { +static Region *getInsertionRegion( + DialectInterfaceCollection<OpFolderDialectInterface> &interfaces, + Operation *op) { while (Region *region = op->getParentRegion()) { // Insert in this region for any of the following scenarios: // * The parent is unregistered, or is known to be isolated from above. @@ -40,6 +42,12 @@ static Region *getInsertionRegion(Operation *op) { if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() || !parentOp->getBlock()) return region; + + // Otherwise, check if this region is a desired insertion region. + auto *interface = interfaces.getInterfaceFor(parentOp); + if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) + return region; + // Traverse up the parent looking for an insertion region. op = parentOp; } @@ -119,7 +127,7 @@ void OperationFolder::notifyRemoval(Operation *op) { assert(constValue); // Get the constant map that this operation was uniqued in. - auto &uniquedConstants = foldScopes[getInsertionRegion(op)]; + auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)]; // Erase all of the references to this operation. auto type = op->getResult(0)->getType(); @@ -161,12 +169,12 @@ LogicalResult OperationFolder::tryToFold( // Create a builder to insert new operations into the entry block of the // insertion region. - auto *insertionRegion = getInsertionRegion(op); - auto &entry = insertionRegion->front(); + auto *insertRegion = getInsertionRegion(interfaces, op); + auto &entry = insertRegion->front(); OpBuilder builder(&entry, entry.begin()); // Get the constant map for the insertion region of this operation. - auto &uniquedConstants = foldScopes[insertionRegion]; + auto &uniquedConstants = foldScopes[insertRegion]; // Create the result constants and replace the results. auto *dialect = op->getDialect(); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index ddb92a58113..86e8848d6ec 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -45,7 +45,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const OwningRewritePatternList &patterns) - : PatternRewriter(ctx), matcher(patterns) { + : PatternRewriter(ctx), matcher(patterns), folder(ctx) { worklist.reserve(64); } |

