summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td13
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h4
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Utils/Utils.h13
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp15
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp18
-rw-r--r--mlir/test/Dialect/Linalg/transform-patterns.mlir50
-rw-r--r--mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td8
7 files changed, 113 insertions, 8 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index 6e3ec889503..d92eb77107f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -43,6 +43,13 @@ class AffineMapDomainHasDim<int n> : CPred<[{
$0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
+class HasOperandsOfType<string type>: CPred<[{
+ llvm::any_of($0.getOperands(),
+ [](Value* v) {
+ return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
+ })
+}]>;
+
//===----------------------------------------------------------------------===//
// Linalg fusion patterns.
//===----------------------------------------------------------------------===//
@@ -101,4 +108,10 @@ class PermuteGenericLinalgOp<list<int> permutation, string value> :
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
" return matchFailure();">;
+//===----------------------------------------------------------------------===//
+// Linalg promote subview operands.
+//===----------------------------------------------------------------------===//
+class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall<
+ "if (failed(linalgOpPromoteSubviews($_builder, $0))) " #
+ " return matchFailure();">;
#endif // LINALG_TRANSFORMS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index b103625a8a4..9682948dbd7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -95,6 +95,10 @@ LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
ArrayRef<unsigned> permutation,
StringRef linalgMarker);
+
+/// Promote std.subviews feeding linalg operations
+LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 994b3c9f185..9f1a8342252 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -23,6 +23,8 @@
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/EDSC/Helpers.h"
+#include "llvm/ADT/SetVector.h"
+
namespace mlir {
class AffineExpr;
class AffineMap;
@@ -217,6 +219,17 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
auxVec[i] = inVec[permutation[i]];
inVec = auxVec;
}
+
+/// Prepares the SubView promotion later performed by `promoteSubViews`
+/// (where most of the transformation happens). It arranges the new
+/// operands for `LinalgOp op` and deallocates the new buffer(s)
+/// It is the entry point for declarative transformation
+/// Returns the cloned `LinalgOp` with the new operands
+LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
+ llvm::SetVector<Value *> subViews,
+ bool dynamicBuffers = false,
+ OperationFolder *folder = nullptr);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 60512232641..74000212373 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -43,6 +43,7 @@ using namespace mlir::linalg;
using namespace mlir::linalg::intrinsics;
using llvm::dbgs;
+using llvm::SetVector;
// Marker used as attribute name in generated Linalg rewriting transformations.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
@@ -230,3 +231,17 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
return success();
}
+
+LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
+ Operation *op) {
+ LinalgOp linOp = dyn_cast<LinalgOp>(op);
+ SetVector<Value *> subViews;
+ for (auto it : linOp.getInputsAndOutputs())
+ if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
+ subViews.insert(sv);
+ if (!subViews.empty()) {
+ auto resOp = promoteSubViewOperands(rewriter, linOp, subViews);
+ return success(resOp);
+ }
+ return failure();
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 32b70346b97..c7fbebce383 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -160,11 +160,11 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
return res;
}
-static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
- bool dynamicBuffers,
- OperationFolder *folder) {
+LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
+ SetVector<Value *> subViews,
+ bool dynamicBuffers,
+ OperationFolder *folder) {
// 1. Promote the specified views and use them in the new op.
- OpBuilder b(op);
ScopedContext scope(b, op.getLoc());
auto promotedBufferAndViews = promoteSubViews(
b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder);
@@ -189,11 +189,12 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
// extra scalars etc.
auto operands = getAssumedNonViewOperands(op);
opViews.append(operands.begin(), operands.end());
- op.clone(b, op.getLoc(), opViews);
+ LinalgOp res = op.clone(b, op.getLoc(), opViews);
// 3. Emit write-back for the promoted output views: copy the partial view.
for (auto viewAndPartialLocalView : writebackViews) {
- // Note: use the old op to determine whether the operand view is an output.
+ // WARNING: MUST use the old op to determine whether the operand view is an
+ // output.
bool isOutput =
op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue();
if (isOutput)
@@ -203,6 +204,8 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews,
// 4. Dealloc local buffers.
for (const auto &pi : promotedBufferAndViews)
dealloc(pi.buffer);
+
+ return res;
}
static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
@@ -212,11 +215,12 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
// TODO(ntv) some heuristic here to decide what to promote. Atm it is all or
// nothing.
SetVector<Value *> subViews;
+ OpBuilder b(op);
for (auto it : op.getInputsAndOutputs())
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
- promoteSubViewOperands(op, subViews, dynamicBuffers, &folder);
+ promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
toErase.push_back(op);
}
});
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 4a9d8bc9564..8a08bf850ff 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -315,3 +315,53 @@ func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
// CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+
+func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+ %c2000 = constant 2000 : index
+ %c3000 = constant 3000 : index
+ %c4000 = constant 4000 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %1 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %2 = dim %arg1, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+ loop.for %arg3 = %c0 to %0 step %c2000 {
+ loop.for %arg4 = %c0 to %2 step %c3000 {
+ loop.for %arg5 = %c0 to %1 step %c4000 {
+ %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
+ memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
+ memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
+ memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_views_"} :
+ memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>
+ }
+ }
+ }
+ return
+}
+// CHECK-LABEL: func @promote_subview_matmul
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] {
+// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] {
+// CHECK : %[[s0:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK : %[[s1:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK : %[[s2:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK : %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK : %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
+// CHECK : %[[l0:.*]] = linalg.slice %[[v0]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
+// CHECK : %[[a1:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK : %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
+// CHECK : %[[l1:.*]] = linalg.slice %[[v1]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
+// CHECK : %[[a2:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK : %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32>
+// CHECK : %[[l2:.*]] = linalg.slice %[[v2]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}>
+// CHECK : linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index 4d8c9282f2d..d2313927398 100644
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -115,7 +115,6 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
-
def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
(PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
[(Constraint<And<[HasNoLinalgTransformMarker,
@@ -126,4 +125,11 @@ def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
[(Constraint<And<[HasNoLinalgTransformMarker,
AffineMapDomainHasDim<3>]>> $op)]>;
+//===----------------------------------------------------------------------===//
+// Linalg subview operands promotion.
+//===----------------------------------------------------------------------===//
+def : Pat<(MatmulOp:$op $A, $B, $C),
+ (LinalgOpPromoteSubviews<"MatmulOp"> $op),
+ [(Constraint<HasOperandsOfType<"SubViewOp">> $op),
+ (Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
OpenPOWER on IntegriCloud