diff options
Diffstat (limited to 'mlir/lib/Dialect/GPU/IR/GPUDialect.cpp')
| -rw-r--r-- | mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 62d6a4b7ea4..422597fe90d 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -472,7 +472,8 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { PatternMatchResult matchAndRewrite(LaunchOp launchOp, PatternRewriter &rewriter) const override { - auto origInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.startRootUpdate(launchOp); + PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&launchOp.body().front()); // Traverse operands passed to kernel and check if some of them are known @@ -480,31 +481,29 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { // and use it instead of passing the value from the parent region. Perform // the traversal in the inverse order to simplify index arithmetics when // dropping arguments. - SmallVector<ValuePtr, 8> operands(launchOp.getKernelOperandValues().begin(), - launchOp.getKernelOperandValues().end()); - SmallVector<ValuePtr, 8> kernelArgs(launchOp.getKernelArguments().begin(), - launchOp.getKernelArguments().end()); + auto operands = launchOp.getKernelOperandValues(); + auto kernelArgs = launchOp.getKernelArguments(); bool found = false; for (unsigned i = operands.size(); i > 0; --i) { unsigned index = i - 1; - ValuePtr operand = operands[index]; - if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) { + Value operand = operands[index]; + if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) continue; - } found = true; - ValuePtr internalConstant = + Value internalConstant = rewriter.clone(*operand->getDefiningOp())->getResult(0); - ValuePtr kernelArg = kernelArgs[index]; + Value kernelArg = *std::next(kernelArgs.begin(), index); kernelArg->replaceAllUsesWith(internalConstant); launchOp.eraseKernelArgument(index); } - rewriter.restoreInsertionPoint(origInsertionPoint); - if (!found) + if (!found) { + rewriter.cancelRootUpdate(launchOp); return matchFailure(); + } - rewriter.updatedRootInPlace(launchOp); + rewriter.finalizeRootUpdate(launchOp); return matchSuccess(); } }; |

