summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/GPU/IR/GPUDialect.cpp')
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp25
1 files changed, 12 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 62d6a4b7ea4..422597fe90d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -472,7 +472,8 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
PatternMatchResult matchAndRewrite(LaunchOp launchOp,
PatternRewriter &rewriter) const override {
- auto origInsertionPoint = rewriter.saveInsertionPoint();
+ rewriter.startRootUpdate(launchOp);
+ PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&launchOp.body().front());
// Traverse operands passed to kernel and check if some of them are known
@@ -480,31 +481,29 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
// and use it instead of passing the value from the parent region. Perform
// the traversal in the inverse order to simplify index arithmetics when
// dropping arguments.
- SmallVector<ValuePtr, 8> operands(launchOp.getKernelOperandValues().begin(),
- launchOp.getKernelOperandValues().end());
- SmallVector<ValuePtr, 8> kernelArgs(launchOp.getKernelArguments().begin(),
- launchOp.getKernelArguments().end());
+ auto operands = launchOp.getKernelOperandValues();
+ auto kernelArgs = launchOp.getKernelArguments();
bool found = false;
for (unsigned i = operands.size(); i > 0; --i) {
unsigned index = i - 1;
- ValuePtr operand = operands[index];
- if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) {
+ Value operand = operands[index];
+ if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp()))
continue;
- }
found = true;
- ValuePtr internalConstant =
+ Value internalConstant =
rewriter.clone(*operand->getDefiningOp())->getResult(0);
- ValuePtr kernelArg = kernelArgs[index];
+ Value kernelArg = *std::next(kernelArgs.begin(), index);
kernelArg->replaceAllUsesWith(internalConstant);
launchOp.eraseKernelArgument(index);
}
- rewriter.restoreInsertionPoint(origInsertionPoint);
- if (!found)
+ if (!found) {
+ rewriter.cancelRootUpdate(launchOp);
return matchFailure();
+ }
- rewriter.updatedRootInPlace(launchOp);
+ rewriter.finalizeRootUpdate(launchOp);
return matchSuccess();
}
};
OpenPOWER on IntegriCloud