diff options
| -rw-r--r-- | mlir/include/mlir/IR/BlockSupport.h | 1 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Operation.h | 6 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 38 | ||||
| -rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 11 | ||||
| -rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp | 13 | ||||
| -rw-r--r-- | mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 25 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 12 | ||||
| -rw-r--r-- | mlir/lib/IR/Block.cpp | 5 | ||||
| -rw-r--r-- | mlir/lib/IR/PatternMatch.cpp | 17 | ||||
| -rw-r--r-- | mlir/lib/Transforms/DialectConversion.cpp | 158 |
10 files changed, 199 insertions, 87 deletions
diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h index 7cefe870c22..bc6a8245c45 100644 --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -61,6 +61,7 @@ class SuccessorRange final public: using RangeBaseT::RangeBaseT; SuccessorRange(Block *block); + SuccessorRange(Operation *term); private: /// See `detail::indexed_accessor_range_base` for details. diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 29227613468..47085f361ca 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -385,6 +385,12 @@ public: return {getTrailingObjects<BlockOperand>(), numSuccs}; } + // Successor iteration. + using succ_iterator = SuccessorRange::iterator; + succ_iterator successor_begin() { return getSuccessors().begin(); } + succ_iterator successor_end() { return getSuccessors().end(); } + SuccessorRange getSuccessors() { return SuccessorRange(this); } + /// Return the operands of this operation that are *not* successor arguments. operand_range getNonSuccessorOperands(); diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index e6b5e7a5eb7..db160e3bdb2 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -361,15 +361,31 @@ public: /// block into a new block, and return it. virtual Block *splitBlock(Block *block, Block::iterator before); - /// This method is used as the final notification hook for patterns that end - /// up modifying the pattern root in place, by changing its operands. This is - /// a minor efficiency win (it avoids creating a new operation and removing - /// the old one) but also often allows simpler code in the client. - /// - /// The valuesToRemoveIfDead list is an optional list of values that the - /// rewriter should remove if they are dead at this point. - /// - void updatedRootInPlace(Operation *op, ValueRange valuesToRemoveIfDead = {}); + /// This method is used to notify the rewriter that an in-place operation + /// modification is about to happen. A call to this function *must* be + /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`. + /// This is a minor efficiency win (it avoids creating a new operation and + /// removing the old one) but also often allows simpler code in the client. + virtual void startRootUpdate(Operation *op) {} + + /// This method is used to signal the end of a root update on the given + /// operation. This can only be called on operations that were provided to a + /// call to `startRootUpdate`. + virtual void finalizeRootUpdate(Operation *op) {} + + /// This method cancels a pending root update. This can only be called on + /// operations that were provided to a call to `startRootUpdate`. + virtual void cancelRootUpdate(Operation *op) {} + + /// This method is a utility wrapper around a root update of an operation. It + /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given + /// callable. + template <typename CallableT> + void updateRootInPlace(Operation *root, CallableT &&callable) { + startRootUpdate(root); + callable(); + finalizeRootUpdate(root); + } protected: explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {} @@ -378,10 +394,6 @@ protected: // These are the callback methods that subclasses can choose to implement if // they would like to be notified about certain types of mutations. - /// Notify the pattern rewriter that the specified operation has been mutated - /// in place. This is called after the mutation is done. - virtual void notifyRootUpdated(Operation *op) {} - /// Notify the pattern rewriter that the specified operation is about to be /// replaced with another set of operations. This is called before the uses /// of the operation have been changed. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index dca26348689..becb95f1f4e 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -365,7 +365,16 @@ public: Operation *insert(Operation *op) override; /// PatternRewriter hook for updating the root operation in-place. - void notifyRootUpdated(Operation *op) override; + /// Note: These methods only track updates to the top-level operation itself, + /// and not nested regions. Updates to regions will still require notification + /// through other more specific hooks above. + void startRootUpdate(Operation *op) override; + + /// PatternRewriter hook for updating the root operation in-place. + void finalizeRootUpdate(Operation *op) override; + + /// PatternRewriter hook for updating the root operation in-place. + void cancelRootUpdate(Operation *op) override; /// Return a reference to the internal implementation. detail::ConversionPatternRewriterImpl &getImpl(); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp index 41deec1f6ab..c3937358c47 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -54,13 +54,12 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands, signatureConverter.addInputs(argType.index(), convertedType); } } - auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); - rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), - newFuncOp.end()); - newFuncOp.setType(rewriter.getFunctionType( - signatureConverter.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); - rewriter.replaceOp(funcOp.getOperation(), llvm::None); + + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(rewriter.getFunctionType( + signatureConverter.getConvertedTypes(), llvm::None)); + rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); + }); return matchSuccess(); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 62d6a4b7ea4..422597fe90d 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -472,7 +472,8 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { PatternMatchResult matchAndRewrite(LaunchOp launchOp, PatternRewriter &rewriter) const override { - auto origInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.startRootUpdate(launchOp); + PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&launchOp.body().front()); // Traverse operands passed to kernel and check if some of them are known @@ -480,31 +481,29 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { // and use it instead of passing the value from the parent region. Perform // the traversal in the inverse order to simplify index arithmetics when // dropping arguments. - SmallVector<ValuePtr, 8> operands(launchOp.getKernelOperandValues().begin(), - launchOp.getKernelOperandValues().end()); - SmallVector<ValuePtr, 8> kernelArgs(launchOp.getKernelArguments().begin(), - launchOp.getKernelArguments().end()); + auto operands = launchOp.getKernelOperandValues(); + auto kernelArgs = launchOp.getKernelArguments(); bool found = false; for (unsigned i = operands.size(); i > 0; --i) { unsigned index = i - 1; - ValuePtr operand = operands[index]; - if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) { + Value operand = operands[index]; + if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) continue; - } found = true; - ValuePtr internalConstant = + Value internalConstant = rewriter.clone(*operand->getDefiningOp())->getResult(0); - ValuePtr kernelArg = kernelArgs[index]; + Value kernelArg = *std::next(kernelArgs.begin(), index); kernelArg->replaceAllUsesWith(internalConstant); launchOp.eraseKernelArgument(index); } - rewriter.restoreInsertionPoint(origInsertionPoint); - if (!found) + if (!found) { + rewriter.cancelRootUpdate(launchOp); return matchFailure(); + } - rewriter.updatedRootInPlace(launchOp); + rewriter.finalizeRootUpdate(launchOp); return matchSuccess(); } }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 76e1b9b716e..0be24bf169c 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -197,13 +197,11 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands, } // Creates a new function with the update signature. - auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); - rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), - newFuncOp.end()); - newFuncOp.setType(rewriter.getFunctionType( - signatureConverter.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); - rewriter.eraseOp(funcOp.getOperation()); + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(rewriter.getFunctionType( + signatureConverter.getConvertedTypes(), llvm::None)); + rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); + }); return matchSuccess(); } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 3abbe1027ce..751ceb1bfb4 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -267,3 +267,8 @@ SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) { if ((count = term->getNumSuccessors())) base = term->getBlockOperands().data(); } + +SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) { + if ((count = term->getNumSuccessors())) + base = term->getBlockOperands().data(); +} diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index d5749fabc07..50e6eeec982 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -170,23 +170,6 @@ void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) { cloneRegionBefore(region, *before->getParent(), before->getIterator()); } -/// This method is used as the final notification hook for patterns that end -/// up modifying the pattern root in place, by changing its operands. This is -/// a minor efficiency win (it avoids creating a new operation and removing -/// the old one) but also often allows simpler code in the client. -/// -/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter -/// should remove if they are dead at this point. -/// -void PatternRewriter::updatedRootInPlace(Operation *op, - ValueRange valuesToRemoveIfDead) { - // Notify the rewriter subclass that we're about to replace this root. - notifyRootUpdated(op); - - // TODO: Process the valuesToRemoveIfDead list, removing things and calling - // the notifyOperationRemoved hook in the process. -} - //===----------------------------------------------------------------------===// // PatternMatcher implementation //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index a19274acd1b..c9fcb670180 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -406,14 +406,16 @@ namespace { /// This class contains a snapshot of the current conversion rewriter state. /// This is useful when saving and undoing a set of rewrites. struct RewriterState { - RewriterState(unsigned numCreatedOperations, unsigned numReplacements, - unsigned numBlockActions, unsigned numIgnoredOperations) - : numCreatedOperations(numCreatedOperations), - numReplacements(numReplacements), numBlockActions(numBlockActions), - numIgnoredOperations(numIgnoredOperations) {} + RewriterState(unsigned numCreatedOps, unsigned numReplacements, + unsigned numBlockActions, unsigned numIgnoredOperations, + unsigned numRootUpdates) + : numCreatedOps(numCreatedOps), numReplacements(numReplacements), + numBlockActions(numBlockActions), + numIgnoredOperations(numIgnoredOperations), + numRootUpdates(numRootUpdates) {} /// The current number of created operations. - unsigned numCreatedOperations; + unsigned numCreatedOps; /// The current number of replacements queued. unsigned numReplacements; @@ -423,6 +425,41 @@ struct RewriterState { /// The current number of ignored operations. unsigned numIgnoredOperations; + + /// The current number of operations that were updated in place. + unsigned numRootUpdates; +}; + +/// The state of an operation that was updated by a pattern in-place. This +/// contains all of the necessary information to reconstruct an operation that +/// was updated in place. +class OperationTransactionState { +public: + OperationTransactionState() = default; + OperationTransactionState(Operation *op) + : op(op), loc(op->getLoc()), attrs(op->getAttrList()), + operands(op->operand_begin(), op->operand_end()), + successors(op->successor_begin(), op->successor_end()) {} + + /// Discard the transaction state and reset the state of the original + /// operation. + void resetOperation() const { + op->setLoc(loc); + op->setAttrs(attrs); + op->setOperands(operands); + for (auto it : llvm::enumerate(successors)) + op->setSuccessor(it.value(), it.index()); + } + + /// Return the original operation of this state. + Operation *getOperation() const { return op; } + +private: + Operation *op; + LocationAttr loc; + NamedAttributeList attrs; + SmallVector<Value, 8> operands; + SmallVector<Block *, 2> successors; }; } // end anonymous namespace @@ -567,16 +604,32 @@ struct ConversionPatternRewriterImpl { /// the others. This simplifies the amount of memory needed as we can query if /// the parent operation was ignored. llvm::SetVector<Operation *> ignoredOps; + + /// A transaction state for each of operations that were updated in-place. + SmallVector<OperationTransactionState, 4> rootUpdates; + +#ifndef NDEBUG + /// A set of operations that have pending updates. This tracking isn't + /// strictly necessary, and is thus only active during debug builds for extra + /// verification. + SmallPtrSet<Operation *, 1> pendingRootUpdates; +#endif }; } // end namespace detail } // end namespace mlir RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), replacements.size(), - blockActions.size(), ignoredOps.size()); + blockActions.size(), ignoredOps.size(), + rootUpdates.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { + // Reset any operations that were updated in place. + for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) + rootUpdates[i].resetOperation(); + rootUpdates.resize(state.numRootUpdates); + // Undo any block actions. undoBlockActions(state.numBlockActions); @@ -587,7 +640,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { replacements.resize(state.numReplacements); // Pop all of the newly created operations. - while (createdOps.size() != state.numCreatedOperations) { + while (createdOps.size() != state.numCreatedOps) { createdOps.back()->erase(); createdOps.pop_back(); } @@ -640,6 +693,10 @@ void ConversionPatternRewriterImpl::undoBlockActions( } void ConversionPatternRewriterImpl::discardRewrites() { + // Reset any operations that were updated in place. + for (auto &state : rootUpdates) + state.resetOperation(); + undoBlockActions(); // Remove any newly created ops. @@ -867,11 +924,34 @@ Operation *ConversionPatternRewriter::insert(Operation *op) { } /// PatternRewriter hook for updating the root operation in-place. -void ConversionPatternRewriter::notifyRootUpdated(Operation *op) { - // The rewriter caches changes to the IR to allow for operating in-place and - // backtracking. The rewriter is currently not capable of backtracking - // in-place modifications. - llvm_unreachable("in-place operation updates are not supported"); +void ConversionPatternRewriter::startRootUpdate(Operation *op) { +#ifndef NDEBUG + impl->pendingRootUpdates.insert(op); +#endif + impl->rootUpdates.emplace_back(op); +} + +/// PatternRewriter hook for updating the root operation in-place. +void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) { + // There is nothing to do here, we only need to track the operation at the + // start of the update. +#ifndef NDEBUG + assert(impl->pendingRootUpdates.erase(op) && + "operation did not have a pending in-place update"); +#endif +} + +/// PatternRewriter hook for updating the root operation in-place. +void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { +#ifndef NDEBUG + assert(impl->pendingRootUpdates.erase(op) && + "operation did not have a pending in-place update"); +#endif + // Erase the last update for this operation. + auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; + auto &rootUpdates = impl->rootUpdates; + auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); + rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it)); } /// Return a reference to the internal implementation. @@ -1059,8 +1139,7 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriter.replaceOp(op, replacementValues); // Recursively legalize any new constant operations. - for (unsigned i = curState.numCreatedOperations, - e = rewriterImpl.createdOps.size(); + for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); i != e; ++i) { Operation *cstOp = rewriterImpl.createdOps[i]; if (failed(legalize(cstOp, rewriter))) { @@ -1102,7 +1181,12 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, // Try to rewrite with the given pattern. rewriter.setInsertionPoint(op); - if (!pattern->matchAndRewrite(op, rewriter)) { + auto matchedPattern = pattern->matchAndRewrite(op, rewriter); +#ifndef NDEBUG + assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); +#endif + + if (!matchedPattern) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n"); return cleanupFailure(); } @@ -1139,12 +1223,32 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, else rewriterImpl.ignoredOps.insert(replacedOp); } - assert(replacedRoot && "expected pattern to replace the root operation"); + + // Check that the root was either updated or replace. + auto updatedRootInPlace = [&] { + return llvm::any_of( + llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates), + [op](auto &state) { return state.getOperation() == op; }); + }; (void)replacedRoot; + (void)updatedRootInPlace; + assert((replacedRoot || updatedRootInPlace()) && + "expected pattern to replace the root operation"); + + // Recursively legalize each of the operations updated in place. + for (unsigned i = curState.numRootUpdates, + e = rewriterImpl.rootUpdates.size(); + i != e; ++i) { + auto &state = rewriterImpl.rootUpdates[i]; + if (failed(legalize(state.getOperation(), rewriter))) { + LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '" + << op->getName() << "' was illegal.\n"); + return cleanupFailure(); + } + } // Recursively legalize each of the new operations. - for (unsigned i = curState.numCreatedOperations, - e = rewriterImpl.createdOps.size(); + for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); i != e; ++i) { Operation *op = rewriterImpl.createdOps[i]; if (failed(legalize(op, rewriter))) { @@ -1534,16 +1638,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> { if (failed(converter.convertTypes(type.getResults(), convertedResults))) return matchFailure(); - // Create a new function with an updated signature. - auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); - rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), - newFuncOp.end()); - newFuncOp.setType(FunctionType::get(result.getConvertedTypes(), - convertedResults, funcOp.getContext())); - - // Tell the rewriter to convert the region signature. - rewriter.applySignatureConversion(&newFuncOp.getBody(), result); - rewriter.eraseOp(funcOp); + // Update the function signature in-place. + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(FunctionType::get(result.getConvertedTypes(), + convertedResults, funcOp.getContext())); + rewriter.applySignatureConversion(&funcOp.getBody(), result); + }); return matchSuccess(); } |

