diff options
4 files changed, 80 insertions, 79 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td index 8f6762f0048..08c6abedbe2 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -18,24 +18,24 @@ include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" include "mlir/Dialect/AffineOps/AffineOps.td" def HasNoLinalgTransformMarker : CPred<[{ - !$0.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker) + !op.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker) }]>; class HasLinalgTransformMarker<string str> : CPred<[{ - $0.getAttrOfType<StringAttr>( + op.getAttrOfType<StringAttr>( LinalgTransforms::kLinalgTransformMarker) && - $0.getAttrOfType<StringAttr>( + op.getAttrOfType<StringAttr>( LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>; class IsProducedByOpOfType<string str> : - CPred<"isProducedByOpOfType<" # str # ">($0, $1)">; + CPred<"isProducedByOpOfType<" # str # ">(op, $0)">; class AffineMapDomainHasDim<int n> : CPred<[{ - $0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0]. + op.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0]. cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>; class HasOperandsOfType<string type>: CPred<[{ - llvm::any_of($0.getOperands(), + llvm::any_of(op.getOperands(), [](Value v) { return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp()); }) @@ -50,7 +50,7 @@ class HasOperandsOfType<string type>: CPred<[{ // patterns. class TileAndFuseLinalgOp< list<int> sizes, list<int> operandIndices, string value> : NativeCodeCall< - "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" # + "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" # StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," # " \"" # value # "\")))" # " return matchFailure();">; @@ -67,7 +67,7 @@ class TileAndFuseLinalgOp< // of elements as `sizes`. class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> : NativeCodeCall< - "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" # + "if (failed(tileLinalgOpAndSetMarker($_builder, op, {" # StrJoinInt<sizes>.result # "}, \"" # value # "\", {" # StrJoinInt<permutation>.result # "})))" # " return matchFailure();">; @@ -76,18 +76,18 @@ class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> : // Linalg to loop patterns. //===----------------------------------------------------------------------===// class LinalgOpToLoops<string OpType> : NativeCodeCall< - "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " # + "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " # " return matchFailure();">; class LinalgOpToAffineLoops<string OpType> : NativeCodeCall< - "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " # + "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " # " return matchFailure();">; //===----------------------------------------------------------------------===// // Linalg to vector contraction patterns. //===----------------------------------------------------------------------===// -class LinalgOpToVectorContraction<string OpType> : NativeCodeCall< - "if (failed(vectorizeGenericOp($_builder, $0))) " # +class VectorizeGenericLinalgOp<string OpType> : NativeCodeCall< + "if (failed(vectorizeGenericLinalgOp($_builder, op))) " # " return matchFailure();">; //===----------------------------------------------------------------------===// @@ -95,14 +95,14 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall< //===----------------------------------------------------------------------===// class PermuteGenericLinalgOp<list<int> permutation, string value> : NativeCodeCall< - "if (failed(permuteGenericLinalgOp($_builder, $0, {" # + "if (failed(permuteGenericLinalgOp($_builder, op, {" # StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " # " return matchFailure();">; //===----------------------------------------------------------------------===// // Linalg promote subview operands. //===----------------------------------------------------------------------===// -class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall< - "if (failed(linalgOpPromoteSubviews($_builder, $0))) " # +class PromoteSubviewsLinalgOp<string OpType> : NativeCodeCall< + "if (failed(promoteSubviewsLinalgOp($_builder, op))) " # " return matchFailure();">; #endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h index 1bba2953273..a70921af5d6 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -79,7 +79,8 @@ template <typename ConcreteOp> LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); /// Rewrite a linalg.generic into a suitable vector.contraction op. -LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op); +LogicalResult vectorizeGenericLinalgOp(PatternRewriter &rewriter, + Operation *op); /// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` /// and `iterator_types` permutated according to `permutation`. @@ -88,7 +89,7 @@ LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, StringRef linalgMarker); /// Promote std.subviews feeding linalg operations -LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op); +LogicalResult promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index eb23a8ceb1a..e9e44d7ba13 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -153,8 +153,8 @@ static bool isMatmul(linalg::GenericOp genericOp) { genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp); } -LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter, - Operation *op) { +LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, + Operation *op) { LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Rewrite linalg op as vector.contract: " << *op << ":\n"); @@ -223,7 +223,7 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, return success(); } -LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, +LogicalResult mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op) { LinalgOp linOp = dyn_cast<LinalgOp>(op); SetVector<Value> subViews; diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td index d07f6060c3b..6bad586cafa 100644 --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -19,11 +19,11 @@ include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td" //===----------------------------------------------------------------------===// // Test Linalg fusion patterns. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$consumer $A, $B, $C), - (TileAndFuseLinalgOp<[100, 150], [0], "L1"> $consumer), +def : Pat<(MatmulOp:$op $A, $_, $_), + (TileAndFuseLinalgOp<[100, 150], [0], "L1">), [ - (Constraint<HasNoLinalgTransformMarker> $consumer), - (Constraint<IsProducedByOpOfType<"MatmulOp">> $consumer, $A), + (Constraint<HasNoLinalgTransformMarker>), + (Constraint<IsProducedByOpOfType<"MatmulOp">> $A), ], // In the buffer world there is no use-def chains or dags so benefits // cannot be computed automatically from the length of the matched @@ -36,91 +36,91 @@ def : Pat<(MatmulOp:$consumer $A, $B, $C), //===----------------------------------------------------------------------===// // Linalg tiling patterns. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2000, 3000, 4000], "L3"> $op), +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2000, 3000, 4000], "L3">), [(Constraint<Or<[HasNoLinalgTransformMarker, - HasLinalgTransformMarker<"MEM">]>> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[200, 300, 400], "L2"> $op), - [(Constraint<HasLinalgTransformMarker<"L3">> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[20, 30, 40], "L1"> $op), - [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2, 3, 4], "REG"> $op), - [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>; + HasLinalgTransformMarker<"MEM">]>>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[200, 300, 400], "L2">), + [(Constraint<HasLinalgTransformMarker<"L3">>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[20, 30, 40], "L1">), + [(Constraint<HasLinalgTransformMarker<"L2">>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2, 3, 4], "REG">), + [(Constraint<HasLinalgTransformMarker<"L1">>)]>; -def : Pattern<(MatvecOp:$op $A, $b, $c), - [(TileLinalgOp<[5, 6], "L1"> $op)], - [(Constraint<HasNoLinalgTransformMarker> $op)]>; +def : Pattern<(MatvecOp:$op $_, $_, $_), + [(TileLinalgOp<[5, 6], "L1">)], + [(Constraint<HasNoLinalgTransformMarker>)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8000], "L1"> $op)], +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8000], "L1">)], [(Constraint<Or<[HasNoLinalgTransformMarker, HasLinalgTransformMarker<"MEM">, HasLinalgTransformMarker<"L3">, - HasLinalgTransformMarker<"L2">]>> $op)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8], "REG"> $op)], - [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>; + HasLinalgTransformMarker<"L2">]>>)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8], "REG">)], + [(Constraint<HasLinalgTransformMarker<"L1">>)]>; //===----------------------------------------------------------------------===// // Linalg tiling and permutation patterns. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]> $op), - [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]> $op), - [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[20, 30, 40], "REG__with_perm__"> $op), - [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>), + [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>), + [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[20, 30, 40], "REG__with_perm__">), + [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>; -def : Pattern<(MatvecOp:$op $A, $b, $c), - [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]> $op)], - [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>; +def : Pattern<(MatvecOp:$op $_, $_, $_), + [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)], + [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8000], "L1__with_perm__"> $op)], - [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8], "REG__with_perm__"> $op)], - [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8000], "L1__with_perm__">)], + [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8], "REG__with_perm__">)], + [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>; //===----------------------------------------------------------------------===// // Linalg to loops patterns. //===----------------------------------------------------------------------===// -def : Pattern<(DotOp:$op $a, $b, $c), - [(LinalgOpToLoops<"DotOp"> $op)], - [(Constraint<HasLinalgTransformMarker<"REG">> $op)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(LinalgOpToLoops<"DotOp">)], + [(Constraint<HasLinalgTransformMarker<"REG">>)]>; //===----------------------------------------------------------------------===// // Linalg to vector contraction patterns. //===----------------------------------------------------------------------===// -def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - [(LinalgOpToVectorContraction<"GenericOp"> $op)], - [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>; +def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + [(VectorizeGenericLinalgOp<"GenericOp">)], + [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">>)]>; //===----------------------------------------------------------------------===// // Linalg generic permutation patterns. //===----------------------------------------------------------------------===// -def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), +def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">), [(Constraint<And<[HasNoLinalgTransformMarker, - AffineMapDomainHasDim<3>]>> $op)]>; + AffineMapDomainHasDim<3>]>>)]>; -def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), +def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">), [(Constraint<And<[HasNoLinalgTransformMarker, - AffineMapDomainHasDim<3>]>> $op)]>; + AffineMapDomainHasDim<3>]>>)]>; //===----------------------------------------------------------------------===// // Linalg subview operands promotion. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (LinalgOpPromoteSubviews<"MatmulOp"> $op), - [(Constraint<HasOperandsOfType<"SubViewOp">> $op), - (Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSubviewsLinalgOp<"MatmulOp">), + [(Constraint<HasOperandsOfType<"SubViewOp">>), + (Constraint<HasLinalgTransformMarker<"_promote_views_">>)]>; #endif // TEST_LINALG_TRANSFORMS_PATTERNS |