diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-23 13:05:38 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-23 16:26:15 -0800 |
| commit | 5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4 (patch) | |
| tree | 7df1c8e31616dc8e59025def2de12c4327637428 /mlir/lib/Transforms | |
| parent | a5d5d2912506322b224eff0428de796a5ef7c1a4 (diff) | |
| download | bcm5719-llvm-5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4.tar.gz bcm5719-llvm-5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4.zip | |
Change the `notifyRootUpdated` API to be transaction based.
This means that in-place, or root, updates need to use explicit calls to `startRootUpdate`, `finalizeRootUpdate`, and `cancelRootUpdate`. The major benefit of this change is that it enables in-place updates in DialectConversion, which simplifies the FuncOp pattern for example. The major downside to this is that the cases that *may* modify an operation in-place will need an explicit cancel on the failure branches(assuming that they started an update before attempting the transformation).
PiperOrigin-RevId: 286933674
Diffstat (limited to 'mlir/lib/Transforms')
| -rw-r--r-- | mlir/lib/Transforms/DialectConversion.cpp | 158 |
1 files changed, 129 insertions, 29 deletions
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index a19274acd1b..c9fcb670180 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -406,14 +406,16 @@ namespace { /// This class contains a snapshot of the current conversion rewriter state. /// This is useful when saving and undoing a set of rewrites. struct RewriterState { - RewriterState(unsigned numCreatedOperations, unsigned numReplacements, - unsigned numBlockActions, unsigned numIgnoredOperations) - : numCreatedOperations(numCreatedOperations), - numReplacements(numReplacements), numBlockActions(numBlockActions), - numIgnoredOperations(numIgnoredOperations) {} + RewriterState(unsigned numCreatedOps, unsigned numReplacements, + unsigned numBlockActions, unsigned numIgnoredOperations, + unsigned numRootUpdates) + : numCreatedOps(numCreatedOps), numReplacements(numReplacements), + numBlockActions(numBlockActions), + numIgnoredOperations(numIgnoredOperations), + numRootUpdates(numRootUpdates) {} /// The current number of created operations. - unsigned numCreatedOperations; + unsigned numCreatedOps; /// The current number of replacements queued. unsigned numReplacements; @@ -423,6 +425,41 @@ struct RewriterState { /// The current number of ignored operations. unsigned numIgnoredOperations; + + /// The current number of operations that were updated in place. + unsigned numRootUpdates; +}; + +/// The state of an operation that was updated by a pattern in-place. This +/// contains all of the necessary information to reconstruct an operation that +/// was updated in place. +class OperationTransactionState { +public: + OperationTransactionState() = default; + OperationTransactionState(Operation *op) + : op(op), loc(op->getLoc()), attrs(op->getAttrList()), + operands(op->operand_begin(), op->operand_end()), + successors(op->successor_begin(), op->successor_end()) {} + + /// Discard the transaction state and reset the state of the original + /// operation. + void resetOperation() const { + op->setLoc(loc); + op->setAttrs(attrs); + op->setOperands(operands); + for (auto it : llvm::enumerate(successors)) + op->setSuccessor(it.value(), it.index()); + } + + /// Return the original operation of this state. + Operation *getOperation() const { return op; } + +private: + Operation *op; + LocationAttr loc; + NamedAttributeList attrs; + SmallVector<Value, 8> operands; + SmallVector<Block *, 2> successors; }; } // end anonymous namespace @@ -567,16 +604,32 @@ struct ConversionPatternRewriterImpl { /// the others. This simplifies the amount of memory needed as we can query if /// the parent operation was ignored. llvm::SetVector<Operation *> ignoredOps; + + /// A transaction state for each of operations that were updated in-place. + SmallVector<OperationTransactionState, 4> rootUpdates; + +#ifndef NDEBUG + /// A set of operations that have pending updates. This tracking isn't + /// strictly necessary, and is thus only active during debug builds for extra + /// verification. + SmallPtrSet<Operation *, 1> pendingRootUpdates; +#endif }; } // end namespace detail } // end namespace mlir RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), replacements.size(), - blockActions.size(), ignoredOps.size()); + blockActions.size(), ignoredOps.size(), + rootUpdates.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { + // Reset any operations that were updated in place. + for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) + rootUpdates[i].resetOperation(); + rootUpdates.resize(state.numRootUpdates); + // Undo any block actions. undoBlockActions(state.numBlockActions); @@ -587,7 +640,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { replacements.resize(state.numReplacements); // Pop all of the newly created operations. - while (createdOps.size() != state.numCreatedOperations) { + while (createdOps.size() != state.numCreatedOps) { createdOps.back()->erase(); createdOps.pop_back(); } @@ -640,6 +693,10 @@ void ConversionPatternRewriterImpl::undoBlockActions( } void ConversionPatternRewriterImpl::discardRewrites() { + // Reset any operations that were updated in place. + for (auto &state : rootUpdates) + state.resetOperation(); + undoBlockActions(); // Remove any newly created ops. @@ -867,11 +924,34 @@ Operation *ConversionPatternRewriter::insert(Operation *op) { } /// PatternRewriter hook for updating the root operation in-place. -void ConversionPatternRewriter::notifyRootUpdated(Operation *op) { - // The rewriter caches changes to the IR to allow for operating in-place and - // backtracking. The rewriter is currently not capable of backtracking - // in-place modifications. - llvm_unreachable("in-place operation updates are not supported"); +void ConversionPatternRewriter::startRootUpdate(Operation *op) { +#ifndef NDEBUG + impl->pendingRootUpdates.insert(op); +#endif + impl->rootUpdates.emplace_back(op); +} + +/// PatternRewriter hook for updating the root operation in-place. +void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) { + // There is nothing to do here, we only need to track the operation at the + // start of the update. +#ifndef NDEBUG + assert(impl->pendingRootUpdates.erase(op) && + "operation did not have a pending in-place update"); +#endif +} + +/// PatternRewriter hook for updating the root operation in-place. +void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { +#ifndef NDEBUG + assert(impl->pendingRootUpdates.erase(op) && + "operation did not have a pending in-place update"); +#endif + // Erase the last update for this operation. + auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; + auto &rootUpdates = impl->rootUpdates; + auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); + rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it)); } /// Return a reference to the internal implementation. @@ -1059,8 +1139,7 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriter.replaceOp(op, replacementValues); // Recursively legalize any new constant operations. - for (unsigned i = curState.numCreatedOperations, - e = rewriterImpl.createdOps.size(); + for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); i != e; ++i) { Operation *cstOp = rewriterImpl.createdOps[i]; if (failed(legalize(cstOp, rewriter))) { @@ -1102,7 +1181,12 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, // Try to rewrite with the given pattern. rewriter.setInsertionPoint(op); - if (!pattern->matchAndRewrite(op, rewriter)) { + auto matchedPattern = pattern->matchAndRewrite(op, rewriter); +#ifndef NDEBUG + assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); +#endif + + if (!matchedPattern) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n"); return cleanupFailure(); } @@ -1139,12 +1223,32 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, else rewriterImpl.ignoredOps.insert(replacedOp); } - assert(replacedRoot && "expected pattern to replace the root operation"); + + // Check that the root was either updated or replace. + auto updatedRootInPlace = [&] { + return llvm::any_of( + llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates), + [op](auto &state) { return state.getOperation() == op; }); + }; (void)replacedRoot; + (void)updatedRootInPlace; + assert((replacedRoot || updatedRootInPlace()) && + "expected pattern to replace the root operation"); + + // Recursively legalize each of the operations updated in place. + for (unsigned i = curState.numRootUpdates, + e = rewriterImpl.rootUpdates.size(); + i != e; ++i) { + auto &state = rewriterImpl.rootUpdates[i]; + if (failed(legalize(state.getOperation(), rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '" + << op->getName() << "' was illegal.\n"); + return cleanupFailure(); + } + } // Recursively legalize each of the new operations. - for (unsigned i = curState.numCreatedOperations, - e = rewriterImpl.createdOps.size(); + for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); i != e; ++i) { Operation *op = rewriterImpl.createdOps[i]; if (failed(legalize(op, rewriter))) { @@ -1534,16 +1638,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> { if (failed(converter.convertTypes(type.getResults(), convertedResults))) return matchFailure(); - // Create a new function with an updated signature. - auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); - rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), - newFuncOp.end()); - newFuncOp.setType(FunctionType::get(result.getConvertedTypes(), - convertedResults, funcOp.getContext())); - - // Tell the rewriter to convert the region signature. - rewriter.applySignatureConversion(&newFuncOp.getBody(), result); - rewriter.eraseOp(funcOp); + // Update the function signature in-place. + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(FunctionType::get(result.getConvertedTypes(), + convertedResults, funcOp.getContext())); + rewriter.applySignatureConversion(&funcOp.getBody(), result); + }); return matchSuccess(); } |

