diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-23 13:05:38 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-23 16:26:15 -0800 |
| commit | 5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4 (patch) | |
| tree | 7df1c8e31616dc8e59025def2de12c4327637428 /mlir/lib/Dialect | |
| parent | a5d5d2912506322b224eff0428de796a5ef7c1a4 (diff) | |
| download | bcm5719-llvm-5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4.tar.gz bcm5719-llvm-5d5bd2e1da29d976cb125dbb3cd097a5e42b2be4.zip | |
Change the `notifyRootUpdated` API to be transaction based.
This means that in-place, or root, updates need to use explicit calls to `startRootUpdate`, `finalizeRootUpdate`, and `cancelRootUpdate`. The major benefit of this change is that it enables in-place updates in DialectConversion, which simplifies the FuncOp pattern for example. The major downside to this is that the cases that *may* modify an operation in-place will need an explicit cancel on the failure branches(assuming that they started an update before attempting the transformation).
PiperOrigin-RevId: 286933674
Diffstat (limited to 'mlir/lib/Dialect')
| -rw-r--r-- | mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 25 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 12 |
2 files changed, 17 insertions, 20 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 62d6a4b7ea4..422597fe90d 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -472,7 +472,8 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { PatternMatchResult matchAndRewrite(LaunchOp launchOp, PatternRewriter &rewriter) const override { - auto origInsertionPoint = rewriter.saveInsertionPoint(); + rewriter.startRootUpdate(launchOp); + PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&launchOp.body().front()); // Traverse operands passed to kernel and check if some of them are known @@ -480,31 +481,29 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> { // and use it instead of passing the value from the parent region. Perform // the traversal in the inverse order to simplify index arithmetics when // dropping arguments. - SmallVector<ValuePtr, 8> operands(launchOp.getKernelOperandValues().begin(), - launchOp.getKernelOperandValues().end()); - SmallVector<ValuePtr, 8> kernelArgs(launchOp.getKernelArguments().begin(), - launchOp.getKernelArguments().end()); + auto operands = launchOp.getKernelOperandValues(); + auto kernelArgs = launchOp.getKernelArguments(); bool found = false; for (unsigned i = operands.size(); i > 0; --i) { unsigned index = i - 1; - ValuePtr operand = operands[index]; - if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) { + Value operand = operands[index]; + if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp())) continue; - } found = true; - ValuePtr internalConstant = + Value internalConstant = rewriter.clone(*operand->getDefiningOp())->getResult(0); - ValuePtr kernelArg = kernelArgs[index]; + Value kernelArg = *std::next(kernelArgs.begin(), index); kernelArg->replaceAllUsesWith(internalConstant); launchOp.eraseKernelArgument(index); } - rewriter.restoreInsertionPoint(origInsertionPoint); - if (!found) + if (!found) { + rewriter.cancelRootUpdate(launchOp); return matchFailure(); + } - rewriter.updatedRootInPlace(launchOp); + rewriter.finalizeRootUpdate(launchOp); return matchSuccess(); } }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 76e1b9b716e..0be24bf169c 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -197,13 +197,11 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<ValuePtr> operands, } // Creates a new function with the update signature. - auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); - rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), - newFuncOp.end()); - newFuncOp.setType(rewriter.getFunctionType( - signatureConverter.getConvertedTypes(), llvm::None)); - rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); - rewriter.eraseOp(funcOp.getOperation()); + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(rewriter.getFunctionType( + signatureConverter.getConvertedTypes(), llvm::None)); + rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); + }); return matchSuccess(); } |

