summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp15
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp18
2 files changed, 26 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 60512232641..74000212373 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -43,6 +43,7 @@ using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;
using llvm::dbgs;
+using llvm::SetVector;
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
@@ -230,3 +231,17 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
return success();
}
+
+LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
+ Operation *op) {
+ LinalgOp linOp = dyn_cast<LinalgOp>(op);
+ SetVector<Value *> subViews;
+ for (auto it : linOp.getInputsAndOutputs())
+ if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
+ subViews.insert(sv);
+ if (!subViews.empty()) {
+ auto resOp = promoteSubViewOperands(rewriter, linOp, subViews);
+ return success(resOp);
+ }
+ return failure();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 32b70346b97..c7fbebce383 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -160,11 +160,11 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
return res;
}
-static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
- bool dynamicBuffers,
- OperationFolder *folder) {
+LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
+ SetVector<Value *> subViews,
+ bool dynamicBuffers,
+ OperationFolder *folder) {
// 1. Promote the specified views and use them in the new op.
- OpBuilder b(op);
ScopedContext scope(b, op.getLoc());
auto promotedBufferAndViews = promoteSubViews(
b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
@@ -189,11 +189,12 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
// extra scalars etc.
auto operands = getAssumedNonViewOperands(op);
opViews.append(operands.begin(), operands.end());
- op.clone(b, op.getLoc(), opViews);
+ LinalgOp res = op.clone(b, op.getLoc(), opViews);
// 3. Emit write-back for the promoted output views: copy the partial view.
for (auto viewAndPartialLocalView : writebackViews) {
- // Note: use the old op to determine whether the operand view is an output.
+ // WARNING: MUST use the old op to determine whether the operand view is an
+ // output.
bool isOutput =
op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
if (isOutput)
@@ -203,6 +204,8 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
// 4. Dealloc local buffers.
for (const auto &pi : promotedBufferAndViews)
dealloc(pi.buffer);
+
+ return res;
}
static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
@@ -212,11 +215,12 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
// nothing.
SetVector<Value *> subViews;
+ OpBuilder b(op);
for (auto it : op.getInputsAndOutputs())
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
- promoteSubViewOperands(op, subViews, dynamicBuffers, &folder);
+ promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
toErase.push_back(op);
}
});
OpenPOWER on IntegriCloud