summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/BlockSupport.h1
-rw-r--r--mlir/include/mlir/IR/Operation.h6
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h38
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h11
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp13
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp25
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp12
-rw-r--r--mlir/lib/IR/Block.cpp5
-rw-r--r--mlir/lib/IR/PatternMatch.cpp17
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp158
10 files changed, 199 insertions, 87 deletions
diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h
index 7cefe870c22..bc6a8245c45 100644
--- a/mlir/include/mlir/IR/BlockSupport.h
+++ b/mlir/include/mlir/IR/BlockSupport.h
@@ -61,6 +61,7 @@ class SuccessorRange final
public:
using RangeBaseT::RangeBaseT;
SuccessorRange(Block *block);
+ SuccessorRange(Operation *term);
private:
/// See `detail::indexed_accessor_range_base` for details.
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 29227613468..47085f361ca 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -385,6 +385,12 @@ public:
return {getTrailingObjects<BlockOperand>(), numSuccs};
}
+ // Successor iteration.
+ using succ_iterator = SuccessorRange::iterator;
+ succ_iterator successor_begin() { return getSuccessors().begin(); }
+ succ_iterator successor_end() { return getSuccessors().end(); }
+ SuccessorRange getSuccessors() { return SuccessorRange(this); }
+
/// Return the operands of this operation that are *not* successor arguments.
operand_range getNonSuccessorOperands();
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index e6b5e7a5eb7..db160e3bdb2 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -361,15 +361,31 @@ public:
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
- /// This method is used as the final notification hook for patterns that end
- /// up modifying the pattern root in place, by changing its operands. This is
- /// a minor efficiency win (it avoids creating a new operation and removing
- /// the old one) but also often allows simpler code in the client.
- ///
- /// The valuesToRemoveIfDead list is an optional list of values that the
- /// rewriter should remove if they are dead at this point.
- ///
- void updatedRootInPlace(Operation *op, ValueRange valuesToRemoveIfDead = {});
+ /// This method is used to notify the rewriter that an in-place operation
+ /// modification is about to happen. A call to this function *must* be
+ /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
+ /// This is a minor efficiency win (it avoids creating a new operation and
+ /// removing the old one) but also often allows simpler code in the client.
+ virtual void startRootUpdate(Operation *op) {}
+
+ /// This method is used to signal the end of a root update on the given
+ /// operation. This can only be called on operations that were provided to a
+ /// call to `startRootUpdate`.
+ virtual void finalizeRootUpdate(Operation *op) {}
+
+ /// This method cancels a pending root update. This can only be called on
+ /// operations that were provided to a call to `startRootUpdate`.
+ virtual void cancelRootUpdate(Operation *op) {}
+
+ /// This method is a utility wrapper around a root update of an operation. It
+ /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
+ /// callable.
+ template <typename CallableT>
+ void updateRootInPlace(Operation *root, CallableT &&callable) {
+ startRootUpdate(root);
+ callable();
+ finalizeRootUpdate(root);
+ }
protected:
explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
@@ -378,10 +394,6 @@ protected:
// These are the callback methods that subclasses can choose to implement if
// they would like to be notified about certain types of mutations.
- /// Notify the pattern rewriter that the specified operation has been mutated
- /// in place. This is called after the mutation is done.
- virtual void notifyRootUpdated(Operation *op) {}
-
/// Notify the pattern rewriter that the specified operation is about to be
/// replaced with another set of operations. This is called before the uses
/// of the operation have been changed.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index dca26348689..becb95f1f4e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -365,7 +365,16 @@ public:
Operation *insert(Operation *op) override;
/// PatternRewriter hook for updating the root operation in-place.
- void notifyRootUpdated(Operation *op) override;
+ /// Note: These methods only track updates to the top-level operation itself,
+ /// and not nested regions. Updates to regions will still require notification
+ /// through other more specific hooks above.
+ void startRootUpdate(Operation *op) override;
+
+ /// PatternRewriter hook for updating the root operation in-place.
+ void finalizeRootUpdate(Operation *op) override;
+
+ /// PatternRewriter hook for updating the root operation in-place.
+ void cancelRootUpdate(Operation *op) override;
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
index 41deec1f6ab..c3937358c47 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -54,13 +54,12 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands,
signatureConverter.addInputs(argType.index(), convertedType);
}
}
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(rewriter.getFunctionType(
- signatureConverter.getConvertedTypes(), llvm::None));
- rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
- rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
+ });
return matchSuccess();
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 62d6a4b7ea4..422597fe90d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -472,7 +472,8 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
PatternMatchResult matchAndRewrite(LaunchOp launchOp,
PatternRewriter &rewriter) const override {
- auto origInsertionPoint = rewriter.saveInsertionPoint();
+ rewriter.startRootUpdate(launchOp);
+ PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&launchOp.body().front());
// Traverse operands passed to kernel and check if some of them are known
@@ -480,31 +481,29 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
// and use it instead of passing the value from the parent region. Perform
// the traversal in the inverse order to simplify index arithmetics when
// dropping arguments.
- SmallVector<ValuePtr, 8> operands(launchOp.getKernelOperandValues().begin(),
- launchOp.getKernelOperandValues().end());
- SmallVector<ValuePtr, 8> kernelArgs(launchOp.getKernelArguments().begin(),
- launchOp.getKernelArguments().end());
+ auto operands = launchOp.getKernelOperandValues();
+ auto kernelArgs = launchOp.getKernelArguments();
bool found = false;
for (unsigned i = operands.size(); i > 0; --i) {
unsigned index = i - 1;
- ValuePtr operand = operands[index];
- if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
+ Value operand = operands[index];
+ if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp()))
continue;
- }
found = true;
- ValuePtr internalConstant =
+ Value internalConstant =
rewriter.clone(*operand->getDefiningOp())->getResult(0);
- ValuePtr kernelArg = kernelArgs[index];
+ Value kernelArg = *std::next(kernelArgs.begin(), index);
kernelArg->replaceAllUsesWith(internalConstant);
launchOp.eraseKernelArgument(index);
}
- rewriter.restoreInsertionPoint(origInsertionPoint);
- if (!found)
+ if (!found) {
+ rewriter.cancelRootUpdate(launchOp);
return matchFailure();
+ }
- rewriter.updatedRootInPlace(launchOp);
+ rewriter.finalizeRootUpdate(launchOp);
return matchSuccess();
}
};
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 76e1b9b716e..0be24bf169c 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -197,13 +197,11 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands,
}
// Creates a new function with the update signature.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(rewriter.getFunctionType(
- signatureConverter.getConvertedTypes(), llvm::None));
- rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
- rewriter.eraseOp(funcOp.getOperation());
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), llvm::None));
+ rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter);
+ });
return matchSuccess();
}
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 3abbe1027ce..751ceb1bfb4 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -267,3 +267,8 @@ SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) {
if ((count = term->getNumSuccessors()))
base = term->getBlockOperands().data();
}
+
+SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) {
+ if ((count = term->getNumSuccessors()))
+ base = term->getBlockOperands().data();
+}
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index d5749fabc07..50e6eeec982 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -170,23 +170,6 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
-/// This method is used as the final notification hook for patterns that end
-/// up modifying the pattern root in place, by changing its operands. This is
-/// a minor efficiency win (it avoids creating a new operation and removing
-/// the old one) but also often allows simpler code in the client.
-///
-/// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
-/// should remove if they are dead at this point.
-///
-void PatternRewriter::updatedRootInPlace(Operation *op,
- ValueRange valuesToRemoveIfDead) {
- // Notify the rewriter subclass that we're about to replace this root.
- notifyRootUpdated(op);
-
- // TODO: Process the valuesToRemoveIfDead list, removing things and calling
- // the notifyOperationRemoved hook in the process.
-}
-
//===----------------------------------------------------------------------===//
// PatternMatcher implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index a19274acd1b..c9fcb670180 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -406,14 +406,16 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numCreatedOperations, unsigned numReplacements,
- unsigned numBlockActions, unsigned numIgnoredOperations)
- : numCreatedOperations(numCreatedOperations),
- numReplacements(numReplacements), numBlockActions(numBlockActions),
- numIgnoredOperations(numIgnoredOperations) {}
+ RewriterState(unsigned numCreatedOps, unsigned numReplacements,
+ unsigned numBlockActions, unsigned numIgnoredOperations,
+ unsigned numRootUpdates)
+ : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+ numBlockActions(numBlockActions),
+ numIgnoredOperations(numIgnoredOperations),
+ numRootUpdates(numRootUpdates) {}
/// The current number of created operations.
- unsigned numCreatedOperations;
+ unsigned numCreatedOps;
/// The current number of replacements queued.
unsigned numReplacements;
@@ -423,6 +425,41 @@ struct RewriterState {
/// The current number of ignored operations.
unsigned numIgnoredOperations;
+
+ /// The current number of operations that were updated in place.
+ unsigned numRootUpdates;
+};
+
+/// The state of an operation that was updated by a pattern in-place. This
+/// contains all of the necessary information to reconstruct an operation that
+/// was updated in place.
+class OperationTransactionState {
+public:
+ OperationTransactionState() = default;
+ OperationTransactionState(Operation *op)
+ : op(op), loc(op->getLoc()), attrs(op->getAttrList()),
+ operands(op->operand_begin(), op->operand_end()),
+ successors(op->successor_begin(), op->successor_end()) {}
+
+ /// Discard the transaction state and reset the state of the original
+ /// operation.
+ void resetOperation() const {
+ op->setLoc(loc);
+ op->setAttrs(attrs);
+ op->setOperands(operands);
+ for (auto it : llvm::enumerate(successors))
+ op->setSuccessor(it.value(), it.index());
+ }
+
+ /// Return the original operation of this state.
+ Operation *getOperation() const { return op; }
+
+private:
+ Operation *op;
+ LocationAttr loc;
+ NamedAttributeList attrs;
+ SmallVector<Value, 8> operands;
+ SmallVector<Block *, 2> successors;
};
} // end anonymous namespace
@@ -567,16 +604,32 @@ struct ConversionPatternRewriterImpl {
/// the others. This simplifies the amount of memory needed as we can query if
/// the parent operation was ignored.
llvm::SetVector<Operation *> ignoredOps;
+
+ /// A transaction state for each of operations that were updated in-place.
+ SmallVector<OperationTransactionState, 4> rootUpdates;
+
+#ifndef NDEBUG
+ /// A set of operations that have pending updates. This tracking isn't
+ /// strictly necessary, and is thus only active during debug builds for extra
+ /// verification.
+ SmallPtrSet<Operation *, 1> pendingRootUpdates;
+#endif
};
} // end namespace detail
} // end namespace mlir
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
- blockActions.size(), ignoredOps.size());
+ blockActions.size(), ignoredOps.size(),
+ rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+ // Reset any operations that were updated in place.
+ for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
+ rootUpdates[i].resetOperation();
+ rootUpdates.resize(state.numRootUpdates);
+
// Undo any block actions.
undoBlockActions(state.numBlockActions);
@@ -587,7 +640,7 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
replacements.resize(state.numReplacements);
// Pop all of the newly created operations.
- while (createdOps.size() != state.numCreatedOperations) {
+ while (createdOps.size() != state.numCreatedOps) {
createdOps.back()->erase();
createdOps.pop_back();
}
@@ -640,6 +693,10 @@ void ConversionPatternRewriterImpl::undoBlockActions(
}
void ConversionPatternRewriterImpl::discardRewrites() {
+ // Reset any operations that were updated in place.
+ for (auto &state : rootUpdates)
+ state.resetOperation();
+
undoBlockActions();
// Remove any newly created ops.
@@ -867,11 +924,34 @@ Operation *ConversionPatternRewriter::insert(Operation *op) {
}
/// PatternRewriter hook for updating the root operation in-place.
-void ConversionPatternRewriter::notifyRootUpdated(Operation *op) {
- // The rewriter caches changes to the IR to allow for operating in-place and
- // backtracking. The rewriter is currently not capable of backtracking
- // in-place modifications.
- llvm_unreachable("in-place operation updates are not supported");
+void ConversionPatternRewriter::startRootUpdate(Operation *op) {
+#ifndef NDEBUG
+ impl->pendingRootUpdates.insert(op);
+#endif
+ impl->rootUpdates.emplace_back(op);
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
+ // There is nothing to do here, we only need to track the operation at the
+ // start of the update.
+#ifndef NDEBUG
+ assert(impl->pendingRootUpdates.erase(op) &&
+ "operation did not have a pending in-place update");
+#endif
+}
+
+/// PatternRewriter hook for updating the root operation in-place.
+void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
+#ifndef NDEBUG
+ assert(impl->pendingRootUpdates.erase(op) &&
+ "operation did not have a pending in-place update");
+#endif
+ // Erase the last update for this operation.
+ auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
+ auto &rootUpdates = impl->rootUpdates;
+ auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
+ rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it));
}
/// Return a reference to the internal implementation.
@@ -1059,8 +1139,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numCreatedOperations,
- e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
i != e; ++i) {
Operation *cstOp = rewriterImpl.createdOps[i];
if (failed(legalize(cstOp, rewriter))) {
@@ -1102,7 +1181,12 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
// Try to rewrite with the given pattern.
rewriter.setInsertionPoint(op);
- if (!pattern->matchAndRewrite(op, rewriter)) {
+ auto matchedPattern = pattern->matchAndRewrite(op, rewriter);
+#ifndef NDEBUG
+ assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+#endif
+
+ if (!matchedPattern) {
LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n");
return cleanupFailure();
}
@@ -1139,12 +1223,32 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
else
rewriterImpl.ignoredOps.insert(replacedOp);
}
- assert(replacedRoot && "expected pattern to replace the root operation");
+
+ // Check that the root was either updated or replace.
+ auto updatedRootInPlace = [&] {
+ return llvm::any_of(
+ llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates),
+ [op](auto &state) { return state.getOperation() == op; });
+ };
(void)replacedRoot;
+ (void)updatedRootInPlace;
+ assert((replacedRoot || updatedRootInPlace()) &&
+ "expected pattern to replace the root operation");
+
+ // Recursively legalize each of the operations updated in place.
+ for (unsigned i = curState.numRootUpdates,
+ e = rewriterImpl.rootUpdates.size();
+ i != e; ++i) {
+ auto &state = rewriterImpl.rootUpdates[i];
+ if (failed(legalize(state.getOperation(), rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '"
+ << op->getName() << "' was illegal.\n");
+ return cleanupFailure();
+ }
+ }
// Recursively legalize each of the new operations.
- for (unsigned i = curState.numCreatedOperations,
- e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
i != e; ++i) {
Operation *op = rewriterImpl.createdOps[i];
if (failed(legalize(op, rewriter))) {
@@ -1534,16 +1638,12 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
return matchFailure();
- // Create a new function with an updated signature.
- auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- newFuncOp.setType(FunctionType::get(result.getConvertedTypes(),
- convertedResults, funcOp.getContext()));
-
- // Tell the rewriter to convert the region signature.
- rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
- rewriter.eraseOp(funcOp);
+ // Update the function signature in-place.
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(FunctionType::get(result.getConvertedTypes(),
+ convertedResults, funcOp.getContext()));
+ rewriter.applySignatureConversion(&funcOp.getBody(), result);
+ });
return matchSuccess();
}
OpenPOWER on IntegriCloud