summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td30
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h5
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp6
-rw-r--r--mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td118
4 files changed, 80 insertions, 79 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index 8f6762f0048..08c6abedbe2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -18,24 +18,24 @@ include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td"
include "mlir/Dialect/AffineOps/AffineOps.td"
def HasNoLinalgTransformMarker : CPred<[{
- !$0.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
+ !op.getAttrOfType<StringAttr>(LinalgTransforms::kLinalgTransformMarker)
}]>;
class HasLinalgTransformMarker<string str> : CPred<[{
- $0.getAttrOfType<StringAttr>(
+ op.getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker) &&
- $0.getAttrOfType<StringAttr>(
+ op.getAttrOfType<StringAttr>(
LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>;
class IsProducedByOpOfType<string str> :
- CPred<"isProducedByOpOfType<" # str # ">($0, $1)">;
+ CPred<"isProducedByOpOfType<" # str # ">(op, $0)">;
class AffineMapDomainHasDim<int n> : CPred<[{
- $0.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
+ op.getAttrOfType<ArrayAttr>(getIndexingMapsAttrName()).getValue()[0].
cast<AffineMapAttr>().getValue().getNumDims() ==}] # n # [{}]>;
class HasOperandsOfType<string type>: CPred<[{
- llvm::any_of($0.getOperands(),
+ llvm::any_of(op.getOperands(),
[](Value v) {
return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
})
@@ -50,7 +50,7 @@ class HasOperandsOfType<string type>: CPred<[{
// patterns.
class TileAndFuseLinalgOp<
list<int> sizes, list<int> operandIndices, string value> : NativeCodeCall<
- "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" #
+ "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, {" # StrJoinInt<operandIndices>.result # "}," #
" \"" # value # "\")))" #
" return matchFailure();">;
@@ -67,7 +67,7 @@ class TileAndFuseLinalgOp<
// of elements as `sizes`.
class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
NativeCodeCall<
- "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" #
+ "if (failed(tileLinalgOpAndSetMarker($_builder, op, {" #
StrJoinInt<sizes>.result # "}, \"" # value # "\", {" #
StrJoinInt<permutation>.result # "})))" #
" return matchFailure();">;
@@ -76,18 +76,18 @@ class TileLinalgOp<list<int> sizes, string value, list<int> permutation=[]> :
// Linalg to loop patterns.
//===----------------------------------------------------------------------===//
class LinalgOpToLoops<string OpType> : NativeCodeCall<
- "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " #
+ "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " #
" return matchFailure();">;
class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
- "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " #
+ "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
-class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
- "if (failed(vectorizeGenericOp($_builder, $0))) " #
+class VectorizeGenericLinalgOp<string OpType> : NativeCodeCall<
+ "if (failed(vectorizeGenericLinalgOp($_builder, op))) " #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
@@ -95,14 +95,14 @@ class LinalgOpToVectorContraction<string OpType> : NativeCodeCall<
//===----------------------------------------------------------------------===//
class PermuteGenericLinalgOp<list<int> permutation, string value> :
NativeCodeCall<
- "if (failed(permuteGenericLinalgOp($_builder, $0, {" #
+ "if (failed(permuteGenericLinalgOp($_builder, op, {" #
StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
" return matchFailure();">;
//===----------------------------------------------------------------------===//
// Linalg promote subview operands.
//===----------------------------------------------------------------------===//
-class LinalgOpPromoteSubviews<string OpType> : NativeCodeCall<
- "if (failed(linalgOpPromoteSubviews($_builder, $0))) " #
+class PromoteSubviewsLinalgOp<string OpType> : NativeCodeCall<
+ "if (failed(promoteSubviewsLinalgOp($_builder, op))) " #
" return matchFailure();">;
#endif // LINALG_TRANSFORMS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index 1bba2953273..a70921af5d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -79,7 +79,8 @@ template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
/// Rewrite a linalg.generic into a suitable vector.contraction op.
-LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op);
+LogicalResult vectorizeGenericLinalgOp(PatternRewriter &rewriter,
+ Operation *op);
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permutated according to `permutation`.
@@ -88,7 +89,7 @@ LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
StringRef linalgMarker);
/// Promote std.subviews feeding linalg operations
-LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op);
+LogicalResult promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index eb23a8ceb1a..e9e44d7ba13 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -153,8 +153,8 @@ static bool isMatmul(linalg::GenericOp genericOp) {
genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
}
-LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter,
- Operation *op) {
+LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
+ Operation *op) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
"]: Rewrite linalg op as vector.contract: "
<< *op << ":\n");
@@ -223,7 +223,7 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
return success();
}
-LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter,
+LogicalResult mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
Operation *op) {
LinalgOp linOp = dyn_cast<LinalgOp>(op);
SetVector<Value> subViews;
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index d07f6060c3b..6bad586cafa 100644
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -19,11 +19,11 @@ include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
//===----------------------------------------------------------------------===//
// Test Linalg fusion patterns.
//===----------------------------------------------------------------------===//
-def : Pat<(MatmulOp:$consumer $A, $B, $C),
- (TileAndFuseLinalgOp<[100, 150], [0], "L1"> $consumer),
+def : Pat<(MatmulOp:$op $A, $_, $_),
+ (TileAndFuseLinalgOp<[100, 150], [0], "L1">),
[
- (Constraint<HasNoLinalgTransformMarker> $consumer),
- (Constraint<IsProducedByOpOfType<"MatmulOp">> $consumer, $A),
+ (Constraint<HasNoLinalgTransformMarker>),
+ (Constraint<IsProducedByOpOfType<"MatmulOp">> $A),
],
// In the buffer world there is no use-def chains or dags so benefits
// cannot be computed automatically from the length of the matched
@@ -36,91 +36,91 @@ def : Pat<(MatmulOp:$consumer $A, $B, $C),
//===----------------------------------------------------------------------===//
// Linalg tiling patterns.
//===----------------------------------------------------------------------===//
-def : Pat<(MatmulOp:$op $A, $B, $C),
- (TileLinalgOp<[2000, 3000, 4000], "L3"> $op),
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[2000, 3000, 4000], "L3">),
[(Constraint<Or<[HasNoLinalgTransformMarker,
- HasLinalgTransformMarker<"MEM">]>> $op)]>;
-def : Pat<(MatmulOp:$op $A, $B, $C),
- (TileLinalgOp<[200, 300, 400], "L2"> $op),
- [(Constraint<HasLinalgTransformMarker<"L3">> $op)]>;
-def : Pat<(MatmulOp:$op $A, $B, $C),
- (TileLinalgOp<[20, 30, 40], "L1"> $op),
- [(Constraint<HasLinalgTransformMarker<"L2">> $op)]>;
-def : Pat<(MatmulOp:$op $A, $B, $C),
- (TileLinalgOp<[2, 3, 4], "REG"> $op),
- [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+ HasLinalgTransformMarker<"MEM">]>>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[200, 300, 400], "L2">),
+ [(Constraint<HasLinalgTransformMarker<"L3">>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[20, 30, 40], "L1">),
+ [(Constraint<HasLinalgTransformMarker<"L2">>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[2, 3, 4], "REG">),
+ [(Constraint<HasLinalgTransformMarker<"L1">>)]>;
-def : Pattern<(MatvecOp:$op $A, $b, $c),
- [(TileLinalgOp<[5, 6], "L1"> $op)],
- [(Constraint<HasNoLinalgTransformMarker> $op)]>;
+def : Pattern<(MatvecOp:$op $_, $_, $_),
+ [(TileLinalgOp<[5, 6], "L1">)],
+ [(Constraint<HasNoLinalgTransformMarker>)]>;
-def : Pattern<(DotOp:$op $a, $b, $c),
- [(TileLinalgOp<[8000], "L1"> $op)],
+def : Pattern<(DotOp:$op $_, $_, $_),
+ [(TileLinalgOp<[8000], "L1">)],
[(Constraint<Or<[HasNoLinalgTransformMarker,
HasLinalgTransformMarker<"MEM">,
HasLinalgTransformMarker<"L3">,
- HasLinalgTransformMarker<"L2">]>> $op)]>;
-def : Pattern<(DotOp:$op $a, $b, $c),
- [(TileLinalgOp<[8], "REG"> $op)],
- [(Constraint<HasLinalgTransformMarker<"L1">> $op)]>;
+ HasLinalgTransformMarker<"L2">]>>)]>;
+def : Pattern<(DotOp:$op $_, $_, $_),
+ [(TileLinalgOp<[8], "REG">)],
+ [(Constraint<HasLinalgTransformMarker<"L1">>)]>;
//===----------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===----------------------------------------------------------------------===//
-def : Pat<(MatmulOp:$op $A, $B, $C),
- (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]> $op),
- [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
-def : Pat<(MatmulOp:$op $A, $B, $C),
- (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]> $op),
- [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">> $op)]>;
-def : Pat<(MatmulOp:$op $A, $B, $C),
- (TileLinalgOp<[20, 30, 40], "REG__with_perm__"> $op),
- [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>),
+ [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>),
+ [(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (TileLinalgOp<[20, 30, 40], "REG__with_perm__">),
+ [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
-def : Pattern<(MatvecOp:$op $A, $b, $c),
- [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]> $op)],
- [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
+def : Pattern<(MatvecOp:$op $_, $_, $_),
+ [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)],
+ [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
-def : Pattern<(DotOp:$op $a, $b, $c),
- [(TileLinalgOp<[8000], "L1__with_perm__"> $op)],
- [(Constraint<HasLinalgTransformMarker<"__with_perm__">> $op)]>;
-def : Pattern<(DotOp:$op $a, $b, $c),
- [(TileLinalgOp<[8], "REG__with_perm__"> $op)],
- [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">> $op)]>;
+def : Pattern<(DotOp:$op $_, $_, $_),
+ [(TileLinalgOp<[8000], "L1__with_perm__">)],
+ [(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
+def : Pattern<(DotOp:$op $_, $_, $_),
+ [(TileLinalgOp<[8], "REG__with_perm__">)],
+ [(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;
//===----------------------------------------------------------------------===//
// Linalg to loops patterns.
//===----------------------------------------------------------------------===//
-def : Pattern<(DotOp:$op $a, $b, $c),
- [(LinalgOpToLoops<"DotOp"> $op)],
- [(Constraint<HasLinalgTransformMarker<"REG">> $op)]>;
+def : Pattern<(DotOp:$op $_, $_, $_),
+ [(LinalgOpToLoops<"DotOp">)],
+ [(Constraint<HasLinalgTransformMarker<"REG">>)]>;
//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
-def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
- [(LinalgOpToVectorContraction<"GenericOp"> $op)],
- [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">> $op)]>;
+def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+ [(VectorizeGenericLinalgOp<"GenericOp">)],
+ [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">>)]>;
//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
-def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
- (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
+def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+ (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
[(Constraint<And<[HasNoLinalgTransformMarker,
- AffineMapDomainHasDim<3>]>> $op)]>;
+ AffineMapDomainHasDim<3>]>>)]>;
-def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8),
- (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op),
+def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
+ (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
[(Constraint<And<[HasNoLinalgTransformMarker,
- AffineMapDomainHasDim<3>]>> $op)]>;
+ AffineMapDomainHasDim<3>]>>)]>;
//===----------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===----------------------------------------------------------------------===//
-def : Pat<(MatmulOp:$op $A, $B, $C),
- (LinalgOpPromoteSubviews<"MatmulOp"> $op),
- [(Constraint<HasOperandsOfType<"SubViewOp">> $op),
- (Constraint<HasLinalgTransformMarker<"_promote_views_">> $op)]>;
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (PromoteSubviewsLinalgOp<"MatmulOp">),
+ [(Constraint<HasOperandsOfType<"SubViewOp">>),
+ (Constraint<HasLinalgTransformMarker<"_promote_views_">>)]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
OpenPOWER on IntegriCloud