summaryrefslogtreecommitdiffstats
path: root/mlir/include
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-23 13:05:38 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-23 16:26:15 -0800
commit5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4 (patch)
tree7df1c8e31616dc8e59025def2de12c4327637428 /mlir/include
parenta5d5d2912506322b224eff0428de796a5ef7c1a4 (diff)
downloadbcm5719-llvm-5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4.tar.gz
bcm5719-llvm-5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4.zip
Change the `notifyRootUpdated` API to be transaction based.
This means that in-place, or root, updates need to use explicit calls to `startRootUpdate`, `finalizeRootUpdate`, and `cancelRootUpdate`. The major benefit of this change is that it enables in-place updates in DialectConversion, which simplifies the FuncOp pattern for example. The major downside to this is that the cases that *may* modify an operation in-place will need an explicit cancel on the failure branches(assuming that they started an update before attempting the transformation). PiperOrigin-RevId: 286933674
Diffstat (limited to 'mlir/include')
-rw-r--r--mlir/include/mlir/IR/BlockSupport.h1
-rw-r--r--mlir/include/mlir/IR/Operation.h6
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h38
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h11
4 files changed, 42 insertions, 14 deletions
diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h
index 7cefe870c22..bc6a8245c45 100644
--- a/mlir/include/mlir/IR/BlockSupport.h
+++ b/mlir/include/mlir/IR/BlockSupport.h
@@ -61,6 +61,7 @@ class SuccessorRange final
public:
using RangeBaseT::RangeBaseT;
SuccessorRange(Block *block);
+ SuccessorRange(Operation *term);
private:
/// See `detail::indexed_accessor_range_base` for details.
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 29227613468..47085f361ca 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -385,6 +385,12 @@ public:
return {getTrailingObjects<BlockOperand>(), numSuccs};
}
+ // Successor iteration.
+ using succ_iterator = SuccessorRange::iterator;
+ succ_iterator successor_begin() { return getSuccessors().begin(); }
+ succ_iterator successor_end() { return getSuccessors().end(); }
+ SuccessorRange getSuccessors() { return SuccessorRange(this); }
+
/// Return the operands of this operation that are *not* successor arguments.
operand_range getNonSuccessorOperands();
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index e6b5e7a5eb7..db160e3bdb2 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -361,15 +361,31 @@ public:
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
- /// This method is used as the final notification hook for patterns that end
- /// up modifying the pattern root in place, by changing its operands. This is
- /// a minor efficiency win (it avoids creating a new operation and removing
- /// the old one) but also often allows simpler code in the client.
- ///
- /// The valuesToRemoveIfDead list is an optional list of values that the
- /// rewriter should remove if they are dead at this point.
- ///
- void updatedRootInPlace(Operation *op, ValueRange valuesToRemoveIfDead = {});
+ /// This method is used to notify the rewriter that an in-place operation
+ /// modification is about to happen. A call to this function *must* be
+ /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
+ /// This is a minor efficiency win (it avoids creating a new operation and
+ /// removing the old one) but also often allows simpler code in the client.
+ virtual void startRootUpdate(Operation *op) {}
+
+ /// This method is used to signal the end of a root update on the given
+ /// operation. This can only be called on operations that were provided to a
+ /// call to `startRootUpdate`.
+ virtual void finalizeRootUpdate(Operation *op) {}
+
+ /// This method cancels a pending root update. This can only be called on
+ /// operations that were provided to a call to `startRootUpdate`.
+ virtual void cancelRootUpdate(Operation *op) {}
+
+ /// This method is a utility wrapper around a root update of an operation. It
+ /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
+ /// callable.
+ template <typename CallableT>
+ void updateRootInPlace(Operation *root, CallableT &&callable) {
+ startRootUpdate(root);
+ callable();
+ finalizeRootUpdate(root);
+ }
protected:
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
@@ -378,10 +394,6 @@ protected:
// These are the callback methods that subclasses can choose to implement if
// they would like to be notified about certain types of mutations.
- /// Notify the pattern rewriter that the specified operation has been mutated
- /// in place. This is called after the mutation is done.
- virtual void notifyRootUpdated(Operation *op) {}
-
/// Notify the pattern rewriter that the specified operation is about to be
/// replaced with another set of operations. This is called before the uses
/// of the operation have been changed.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index dca26348689..becb95f1f4e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -365,7 +365,16 @@ public:
Operation *insert(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
- void notifyRootUpdated(Operation *op) override;
+ /// Note: These methods only track updates to the top-level operation itself,
+ /// and not nested regions. Updates to regions will still require notification
+ /// through other more specific hooks above.
+ void startRootUpdate(Operation *op) override;
+
+ /// PatternRewriter hook for updating the root operation in-place.
+ void finalizeRootUpdate(Operation *op) override;
+
+ /// PatternRewriter hook for updating the root operation in-place.
+ void cancelRootUpdate(Operation *op) override;
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
OpenPOWER on IntegriCloud