diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp | 15 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 18 |
2 files changed, 26 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 60512232641..74000212373 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -43,6 +43,7 @@ using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; using llvm::dbgs; +using llvm::SetVector; // Marker used as attribute name in generated Linalg rewriting transformations. const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = @@ -230,3 +231,17 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); return success(); } + +LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, + Operation *op) { + LinalgOp linOp = dyn_cast<LinalgOp>(op); + SetVector<Value *> subViews; + for (auto it : linOp.getInputsAndOutputs()) + if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) + subViews.insert(sv); + if (!subViews.empty()) { + auto resOp = promoteSubViewOperands(rewriter, linOp, subViews); + return success(resOp); + } + return failure(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 32b70346b97..c7fbebce383 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -160,11 +160,11 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, return res; } -static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews, - bool dynamicBuffers, - OperationFolder *folder) { +LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, + SetVector<Value *> subViews, + bool dynamicBuffers, + OperationFolder *folder) { // 1. Promote the specified views and use them in the new op. - OpBuilder b(op); ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); @@ -189,11 +189,12 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews, // extra scalars etc. auto operands = getAssumedNonViewOperands(op); opViews.append(operands.begin(), operands.end()); - op.clone(b, op.getLoc(), opViews); + LinalgOp res = op.clone(b, op.getLoc(), opViews); // 3. Emit write-back for the promoted output views: copy the partial view. for (auto viewAndPartialLocalView : writebackViews) { - // Note: use the old op to determine whether the operand view is an output. + // WARNING: MUST use the old op to determine whether the operand view is an + // output. bool isOutput = op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); if (isOutput) @@ -203,6 +204,8 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews, // 4. Dealloc local buffers. for (const auto &pi : promotedBufferAndViews) dealloc(pi.buffer); + + return res; } static void promoteSubViews(FuncOp f, bool dynamicBuffers) { @@ -212,11 +215,12 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. SetVector<Value *> subViews; + OpBuilder b(op); for (auto it : op.getInputsAndOutputs()) if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { - promoteSubViewOperands(op, subViews, dynamicBuffers, &folder); + promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder); toErase.push_back(op); } }); |

