diff options
7 files changed, 113 insertions, 8 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index 6e3ec889503..d92eb77107f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -43,6 +43,13 @@ class AffineMapDomainHasDim<int n> : CPred<[{ $0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0]. cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>; +class HasOperandsOfType<string type>: CPred<[{ + llvm::any_of($0.getOperands(), + [](Value* v) { + return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp()); + }) +}]>; + //===----------------------------------------------------------------------===// // Linalg fusion patterns. //===----------------------------------------------------------------------===// @@ -101,4 +108,10 @@ class PermuteGenericLinalgOp<list<int> permutation, string value> : StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " # " return matchFailure();">; +//===----------------------------------------------------------------------===// +// Linalg promote subview operands. 
+//===----------------------------------------------------------------------===// +class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall< + "if (failed(linalgOpPromoteSubviews($_builder, $0))) " # + " return matchFailure();">; #endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index b103625a8a4..9682948dbd7 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -95,6 +95,10 @@ LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op); LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, ArrayRef<unsigned> permutation, StringRef linalgMarker); + +/// Promote std.subviews feeding linalg operations. +LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 994b3c9f185..9f1a8342252 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -23,6 +23,8 @@ #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/EDSC/Helpers.h" +#include "llvm/ADT/SetVector.h" + namespace mlir { class AffineExpr; class AffineMap; @@ -217,6 +219,17 @@ void applyPermutationToVector(SmallVector<T, N> &inVec, auxVec[i] = inVec[permutation[i]]; inVec = auxVec; } + +/// Prepares the SubView promotion later performed by `promoteSubViews` +/// (where most of the transformation happens). 
It arranges the new +/// operands for `LinalgOp op` and deallocates the new buffer(s). +/// It is the entry point for declarative transformation. +/// Returns the cloned `LinalgOp` with the new operands. +LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op, + llvm::SetVector<Value *> subViews, + bool dynamicBuffers = false, + OperationFolder *folder = nullptr); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 60512232641..74000212373 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -43,6 +43,7 @@ using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; using llvm::dbgs; +using llvm::SetVector; // Marker used as attribute name in generated Linalg rewriting transformations. const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = @@ -230,3 +231,17 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); return success(); } + +LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, + Operation *op) { + LinalgOp linOp = dyn_cast<LinalgOp>(op); + SetVector<Value *> subViews; + for (auto it : linOp.getInputsAndOutputs()) + if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) + subViews.insert(sv); + if (!subViews.empty()) { + auto resOp = promoteSubViewOperands(rewriter, linOp, subViews); + return success(resOp); + } + return failure(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 32b70346b97..c7fbebce383 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -160,11 +160,11 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, return res; } -static void 
promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews, - bool dynamicBuffers, - OperationFolder *folder) { +LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, + SetVector<Value *> subViews, + bool dynamicBuffers, + OperationFolder *folder) { // 1. Promote the specified views and use them in the new op. - OpBuilder b(op); ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); @@ -189,11 +189,12 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews, // extra scalars etc. auto operands = getAssumedNonViewOperands(op); opViews.append(operands.begin(), operands.end()); - op.clone(b, op.getLoc(), opViews); + LinalgOp res = op.clone(b, op.getLoc(), opViews); // 3. Emit write-back for the promoted output views: copy the partial view. for (auto viewAndPartialLocalView : writebackViews) { - // Note: use the old op to determine whether the operand view is an output. + // WARNING: MUST use the old op to determine whether the operand view is an + // output. bool isOutput = op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); if (isOutput) @@ -203,6 +204,8 @@ static void promoteSubViewOperands(LinalgOp op, SetVector<Value *> subViews, // 4. Dealloc local buffers. for (const auto &pi : promotedBufferAndViews) dealloc(pi.buffer); + + return res; } static void promoteSubViews(FuncOp f, bool dynamicBuffers) { @@ -212,11 +215,12 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. 
SetVector<Value *> subViews; + OpBuilder b(op); for (auto it : op.getInputsAndOutputs()) if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { - promoteSubViewOperands(op, subViews, dynamicBuffers, &folder); + promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder); toErase.push_back(op); } }); diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 4a9d8bc9564..8a08bf850ff 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -315,3 +315,53 @@ func @matmul_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>, // CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { // CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { // CHECK : linalg.matmul({{.*}}, {{.*}}, {{.*}}) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> + +func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, + %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>, + %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) { + %c2000 = constant 2000 : index + %c3000 = constant 3000 : index + %c4000 = constant 4000 : index + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = dim %arg0, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]> + %1 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]> + %2 = dim %arg1, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]> + loop.for %arg3 = %c0 to %0 step %c2000 { + loop.for %arg4 = %c0 to %2 step %c3000 { + loop.for %arg5 = %c0 to %1 step %c4000 { + %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : + memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]> + %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : + memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]> + %5 = 
std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : + memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]> + linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_views_"} : + memref<?x?xf32, offset: ?, strides: [?, ?]>, + memref<?x?xf32, offset: ?, strides: [?, ?]>, + memref<?x?xf32, offset: ?, strides: [?, ?]> + } + } + } + return +} +// CHECK-LABEL: func @promote_subview_matmul +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { +// CHECK : loop.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { +// CHECK : %[[s0:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}> +// CHECK : %[[s1:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}> +// CHECK : %[[s2:.*]] = std.subview {{%.*}}[{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}][{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}> +// CHECK : %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8> +// CHECK : %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32> +// CHECK : %[[l0:.*]] = linalg.slice %[[v0]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}> +// CHECK : %[[a1:.*]] = alloc({{%.*}}) : memref<?xi8> +// CHECK : %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32> +// CHECK : %[[l1:.*]] = linalg.slice %[[v1]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}> +// CHECK : %[[a2:.*]] = alloc({{%.*}}) : memref<?xi8> +// CHECK : %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}]: memref<?xi8> to memref<?x?xf32> +// CHECK : %[[l2:.*]] = linalg.slice %[[v2]][{{%.*}}, {{%.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #map{{.*}}> +// 
CHECK : linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}> +// CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}> +// CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}> +// CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]> diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td index 4d8c9282f2d..d2313927398 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -115,7 +115,6 @@ def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), //===----------------------------------------------------------------------===// // Linalg generic permutation patterns. //===----------------------------------------------------------------------===// - def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), [(Constraint<And<[HasNoLinalgTransformMarker, @@ -126,4 +125,11 @@ def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), [(Constraint<And<[HasNoLinalgTransformMarker, AffineMapDomainHasDim<3>]>> $op)]>; +//===----------------------------------------------------------------------===// +// Linalg subview operands promotion. +//===----------------------------------------------------------------------===// +def : Pat<(MatmulOp:$op $A, $B, $C), + (LinalgOpPromoteSubviews<"MatmulOp"> $op), + [(Constraint<HasOperandsOfType<"SubViewOp">> $op), + (Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>; #endif // TEST_LINALG_TRANSFORMS_PATTERNS |

