summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-12-23 13:05:38 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-23 16:26:15 -0800
commit5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4 (patch)
tree7df1c8e31616dc8e59025def2de12c4327637428 /mlir/lib/Transforms
parenta5d5d2912506322b224eff0428de796a5ef7c1a4 (diff)
downloadbcm5719-llvm-5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4.tar.gz
bcm5719-llvm-5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4.zip
Change the `notifyRootUpdated` API to be transaction based.
This means that in-place, or root, updates need to use explicit calls to `startRootUpdate`, `finalizeRootUpdate`, and `cancelRootUpdate`. The major benefit of this change is that it enables in-place updates in DialectConversion, which simplifies the FuncOp pattern for example. The major downside to this is that the cases that *may* modify an operation in-place will need an explicit cancel on the failure branches(assuming that they started an update before attempting the transformation). PiperOrigin-RevId: 286933674
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp158
1 files changed, 129 insertions, 29 deletions
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index a19274acd1b..c9fcb670180 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -406,14 +406,16 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
- unsigned numBlockActions, unsigned numIgnoredOperations)
- : numCreatedOperations(numCreatedOperations),
- numReplacements(numReplacements), numBlockActions(numBlockActions),
- numIgnoredOperations(numIgnoredOperations) {}
+ RewriterState(unsigned numCreatedOps, unsigned numReplacements,
+ unsigned numBlockActions, unsigned numIgnoredOperations,
+ unsigned numRootUpdates)
+ : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+ numBlockActions(numBlockActions),
+ numIgnoredOperations(numIgnoredOperations),
+ numRootUpdates(numRootUpdates) {}
/// The current number of created operations.
- unsigned numCreatedOperations;
+ unsigned numCreatedOps;
/// The current number of replacements queued.
unsigned numReplacements;
@@ -423,6 +425,41 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
+
+ /// The current number of operations that were updated in place.
+ unsigned numRootUpdates;
+};
+
+/// The state of an operation that was updated by a pattern in-place. This
+/// contains all of the necessary information to reconstruct an operation that
+/// was updated in place.
+class OperationTransactionState {
+public:
+ OperationTransactionState() = default;
+ OperationTransactionState(Operation *op)
+ : op(op), loc(op->getLoc()), attrs(op->getAttrList()),
+ operands(op->operand_begin(), op->operand_end()),
+ successors(op->successor_begin(), op->successor_end()) {}
+
+ /// Discard the transaction state and reset the state of the original
+ /// operation.
+ void resetOperation() const {
+ op->setLoc(loc);
+ op->setAttrs(attrs);
+ op->setOperands(operands);
+ for (auto it : llvm::enumerate(successors))
+ op->setSuccessor(it.value(), it.index());
+ }
+
+ /// Return the original operation of this state.
+ Operation *getOperation() const { return op; }
+
+private:
+ Operation *op;
+ LocationAttr loc;
+ NamedAttributeList attrs;
+ SmallVector<Value, 8> operands;
+ SmallVector<Block *, 2> successors;
};
} // end anonymous namespace
@@ -567,16 +604,32 @@ struct ConversionPatternRewriterImpl {
/// the others. This simplifies the amount of memory needed as we can query if
/// the parent operation was ignored.
llvm::SetVector<Operation *> ignoredOps;
+
+ /// A transaction state for each of operations that were updated in-place.
+ SmallVector<OperationTransactionState, 4> rootUpdates;
+
+#ifndef NDEBUG
+ /// A set of operations that have pending updates. This tracking isn't
+ /// strictly necessary, and is thus only active during debug builds for extra
+ /// verification.
+ SmallPtrSet<Operation *, 1> pendingRootUpdates;
+#endif
};
} // end namespace detail
} // end namespace mlir
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
- blockActions.size(), ignoredOps.size());
+ blockActions.size(), ignoredOps.size(),
+ rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+ // Reset any operations that were updated in place.
+ for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
+ rootUpdates[i].resetOperation();
+ rootUpdates.resize(state.numRootUpdates);
+
// Undo any block actions.
undoBlockActions(state.numBlockActions);
@@ -587,7 +640,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
replacements.resize(state.numReplacements);
// Pop all of the newly created operations.
- while (createdOps.size() != state.numCreatedOperations) {
+ while (createdOps.size() != state.numCreatedOps) {
createdOps.back()->erase();
createdOps.pop_back();
}
@@ -640,6 +693,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
}
void ConversionPatternRewriterImpl::discardRewrites() {
+ // Reset any operations that were updated in place.
+ for (auto &state : rootUpdates)
+ state.resetOperation();
+
undoBlockActions();
// Remove any newly created ops.
@@ -867,11 +924,34 @@ Operation *ConversionPatternRewriter::insert(Operation *op) {
}
/// PatternRewriter hook for updating the root operation in-place.
-void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
- // The rewriter caches changes to the IR to allow for operating in-place and
- // backtracking. The rewriter is currently not capable of backtracking
- // in-place modifications.
- llvm_unreachable("in-place operation updates are not supported");
+void ConversionPatternRewriter::startRootUpdate(Operation *op) {
+#ifndef NDEBUG
+ impl->pendingRootUpdates.insert(op);
+#endif
+ impl->rootUpdates.emplace_back(op);
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
+ // There is nothing to do here, we only need to track the operation at the
+ // start of the update.
+#ifndef NDEBUG
+ assert(impl->pendingRootUpdates.erase(op) &&
+ "operation did not have a pending in-place update");
+#endif
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
+#ifndef NDEBUG
+ assert(impl->pendingRootUpdates.erase(op) &&
+ "operation did not have a pending in-place update");
+#endif
+ // Erase the last update for this operation.
+ auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
+ auto &rootUpdates = impl->rootUpdates;
+ auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
+ rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
}
/// Return a reference to the internal implementation.
@@ -1059,8 +1139,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numCreatedOperations,
- e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
i != e; ++i) {
Operation *cstOp = rewriterImpl.createdOps[i];
if (failed(legalize(cstOp, rewriter))) {
@@ -1102,7 +1181,12 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
// Try to rewrite with the given pattern.
rewriter.setInsertionPoint(op);
- if (!pattern->matchAndRewrite(op, rewriter)) {
+ auto matchedPattern = pattern->matchAndRewrite(op, rewriter);
+#ifndef NDEBUG
+ assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+#endif
+
+ if (!matchedPattern) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
return cleanupFailure();
}
@@ -1139,12 +1223,32 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
else
rewriterImpl.ignoredOps.insert(replacedOp);
}
- assert(replacedRoot && "expected pattern to replace the root operation");
+
+ // Check that the root was either updated or replace.
+ auto updatedRootInPlace = [&] {
+ return llvm::any_of(
+ llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates),
+ [op](auto &state) { return state.getOperation() == op; });
+ };
(void)replacedRoot;
+ (void)updatedRootInPlace;
+ assert((replacedRoot || updatedRootInPlace()) &&
+ "expected pattern to replace the root operation");
+
+ // Recursively legalize each of the operations updated in place.
+ for (unsigned i = curState.numRootUpdates,
+ e = rewriterImpl.rootUpdates.size();
+ i != e; ++i) {
+ auto &state = rewriterImpl.rootUpdates[i];
+ if (failed(legalize(state.getOperation(), rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '"
+ << op->getName() << "' was illegal.\n");
+ return cleanupFailure();
+ }
+ }
// Recursively legalize each of the new operations.
- for (unsigned i = curState.numCreatedOperations,
- e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
i != e; ++i) {
Operation *op = rewriterImpl.createdOps[i];
if (failed(legalize(op, rewriter))) {
@@ -1534,16 +1638,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
return matchFailure();
- // Create a new function with an updated signature.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
- convertedResults, funcOp.getContext()));
-
- // Tell the rewriter to convert the region signature.
- rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
- rewriter.eraseOp(funcOp);
+ // Update the function signature in-place.
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(FunctionType::get(result.getConvertedTypes(),
+ convertedResults, funcOp.getContext()));
+ rewriter.applySignatureConversion(&funcOp.getBody(), result);
+ });
return matchSuccess();
}
OpenPOWER on IntegriCloud