summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td18
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td17
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td88
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h90
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td28
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h23
-rw-r--r--mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h5
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp7
-rw-r--r--mlir/lib/Dialect/Linalg/EDSC/Builders.cpp1
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp59
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp85
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorTransforms.cpp6
-rw-r--r--mlir/test/Dialect/Linalg/invalid.mlir45
-rw-r--r--mlir/test/Dialect/Linalg/roundtrip.mlir23
-rw-r--r--mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td34
-rw-r--r--mlir/tools/mlir-tblgen/RewriterGen.cpp14
16 files changed, 408 insertions, 135 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index c1adc8b4d05..5edca25b93a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -19,9 +19,10 @@ def Linalg_Dialect : Dialect {
let name = "linalg";
let description = [{
The `linalg` dialect groups together a set of types, operations and
- transformations that are useful to implement a structured abstraction where
- ops can lower to scalar load/store and operations or to more general library
- calls.
+ transformations that are useful to implement a structured abstraction on
+ buffers and tensors. These abstractions are useful for transformations and
+ can lower to scalar load/store and other operations or to more general
+ library calls.
The `linalg` dialect manipulates the following types and operations:
@@ -67,12 +68,13 @@ def Linalg_Dialect : Dialect {
A set of payload carrying operations that implement the [structured ops](
https://docs.google.com/presentation/d/1P-j1GrH6Q5gLBjao0afQ-GfvcAeF-QU4GXXeSy0eJ9I/edit#slide=id.p
)
- abstraction on buffers. `linalg` has `2` generic operations `linalg.generic`
- and `linalg.indexed_generic` for expressing custom operations. This is
- subject to further evolution as transformations and analyses continue to be
- developed.
+ abstraction on tensors and buffers. `linalg` has `2` generic operations
+ `linalg.generic` and `linalg.indexed_generic` for expressing custom
+ operations.
+ This is subject to further evolution as transformations and analyses
+ continue to be developed.
- Additionally, `linalg` provides some common named operations:
+ Additionally, `linalg` provides some commonly used named operations:
* `linalg.copy`,
* `linalg.fill`,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 0445968ee80..d517c0a61aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -59,7 +59,8 @@ def Linalg_RangeOp :
}
def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
- Arguments<(ins AnyStridedMemRef:$view, Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
+ Arguments<(ins AnyStridedMemRef:$view,
+ Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
Results<(outs AnyStridedMemRef)> {
let summary = "Produce a rank-reduced `subview` of a base `view`.";
let description = [{
@@ -108,11 +109,11 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
let extraClassDeclaration = [{
enum { FirstIndexingOperand = 1 };
- unsigned getRank() { return getViewType().getRank(); }
- Type getElementType() { return getViewType().getElementType(); }
- MemRefType getViewType() { return getType().cast<MemRefType>(); }
+ unsigned getRank() { return getShapedType().getRank(); }
+ Type getElementType() { return getShapedType().getElementType(); }
+ ShapedType getShapedType() { return getType().cast<ShapedType>(); }
unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
- MemRefType getBaseViewType() { return view()->getType().cast<MemRefType>(); }
+ ShapedType getBaseViewType() { return view()->getType().cast<ShapedType>();}
// Get the underlying indexing at a given rank.
Value indexing(unsigned rank) { return *(indexings().begin() + rank); }
@@ -131,7 +132,7 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>,
Results<(outs AnyStridedMemRef)> {
- let summary = "transpose operation produces a new strided memref (metadata-only)";
+ let summary = "`transpose` produces a new strided memref (metadata-only)";
let description = [{
The `linalg.transpose` op produces a strided memref whose sizes and strides
are a permutation of the original `view`. This is a pure metadata
@@ -151,14 +152,14 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
let verifier = [{
if (!permutation().isPermutation())
return emitOpError("expected a permutation map");
- if (permutation().getNumDims() != getViewType().getRank())
+ if (permutation().getNumDims() != getShapedType().getRank())
return emitOpError("expected a permutation map of same rank as the view");
return success();
}];
let extraClassDeclaration = [{
static StringRef getPermutationAttrName() { return "permutation"; }
- MemRefType getViewType() { return view()->getType().cast<MemRefType>(); }
+ ShapedType getShapedType() { return view()->getType().cast<ShapedType>(); }
}];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index ed09272055b..95963338491 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -89,23 +89,32 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
"Value ", "getOutput", (ins "unsigned":$i)
>,
InterfaceMethod<[{
- Query the index of the given input value, or `None` if the value is not
- an input.
+ Return the index of the given input value `v`, or `None` if the value is
+ not an input.
}],
- "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$view)
+ "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v)
>,
InterfaceMethod<[{
Query the index of the given view value, or `None` if the value is not
- an view.
+ a view.
}],
"llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view)
>,
InterfaceMethod<[{
- Query the type of the input view at the given index.
- }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>,
+ Query the type of the input shape at the given index.
+ }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
InterfaceMethod<[{
Query the type of the output view at the given index.
- }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>,
+ }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
+ InterfaceMethod<[{
+ Query whether the op has only MemRef inputs and outputs.
+ }], "bool", "hasBufferSemantics">,
+ InterfaceMethod<[{
+ Query the subset of input operands that are of ranked tensor type.
+ }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
+ InterfaceMethod<[{
+ Query the subset of output operands that are of ranked tensor type.
+ }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,
StaticInterfaceMethod<[{
Create an operation of the current type with the given location,
@@ -340,7 +349,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
ArrayAttr iterator_types() {
// Outer parallel loops are always the number of output dimensions; i.e.
// [ b, xs, q] in the TF notation above.
- unsigned nPar = getOutputViewType(0).getRank();
+ unsigned nPar = getOutputShapedType(0).getRank();
unsigned nRed = getNumInputFeatureDimensions();
// Window loops are a special kind of reduction that is never tiled or
// parallelized across; i.e. [zs] in the TF notation above whose number
@@ -374,8 +383,17 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
let verifier = [{ return ::verify(*this); }];
}
+def LinalgOperand: Type<
+ Or<[AnyRankedTensor.predicate, AnyStridedMemRef.predicate]>>;
+
+class LinalgOperandOfRank<int rank>: Type<
+ And<[
+ LinalgOperand.predicate,
+ CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
+ >>;
+
class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
- let arguments = (ins Variadic<AnyStridedMemRef>:$views,
+ let arguments = (ins Variadic<LinalgOperand>:$views,
I64Attr:$args_in,
I64Attr:$args_out,
AffineMapArrayAttr:$indexing_maps,
@@ -383,6 +401,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
OptionalAttr<StrAttr>:$doc,
OptionalAttr<FlatSymbolRefAttr>:$fun,
OptionalAttr<StrAttr>:$library_call);
+ let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
@@ -511,6 +530,28 @@ def GenericOp : GenericOpBase<"generic"> {
}
}
```
+
+ To allow progressive lowering from the value world (a.k.a. tensor values) to
+ the buffer world (a.k.a. memref values), a `linalg.generic` op accepts
+ mixing input and output ranked tensor values with input and output memrefs.
+
+ ```mlir
+ %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
+ tensor<?x?xf32>,
+ memref<?x?xf32, stride_specification>,
+ tensor<?x?xf32>
+ -> (tensor<?x?xf32>)
+ ```
+
+ In this case, the number of return values must match the number of output
+ tensor arguments. The semantics is that the `linalg.generic` op
+ produces (i.e. allocates and fills) its return values.
+ Tensor values must be legalized by a buffer allocation pass before most
+ transformations can be applied. In particular, transformations that create
+ control flow around linalg.generic operations are not expected to mix with
+ tensors because SSA values do not escape naturally. Still, transformations
+ and rewrites that take advantage of tensor SSA values are expected to be
+ useful and will be added in the near future.
}];
let verifier = [{ return ::verify(*this); }];
}
@@ -555,9 +596,11 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
Example:
Defining a #matmul_trait attribute in MLIR can be done as follows:
```mlir
- func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32)
+ func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
+ %a: f32, %b: f32, %c: f32)
-> f32
{
+ "some_optional_condition"(%offset_m, %offset_n, %offset_k)
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
return %e: f32
@@ -587,7 +630,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
This may lower to either:
```mlir
- call @linalg_matmul(%A, %B, %C) :
+ call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
(memref<?x?xf32, stride_specification>,
memref<?x?xf32, stride_specification>,
memref<?x?xf32, stride_specification>)
@@ -609,6 +652,29 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
}
}
```
+
+ To allow progressive lowering from the value world (a.k.a. tensor values) to
+ the buffer world (a.k.a. memref values), a `linalg.indexed_generic` op
+ accepts mixing input and output ranked tensor values with input and output
+ memrefs.
+
+ ```mlir
+ %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes}
+ : tensor<?x?xf32>,
+ memref<?x?xf32, stride_specification>,
+ tensor<?x?xf32>
+ -> (tensor<?x?xf32>)
+ ```
+
+ In this case, the number of return values must match the number of output
+ tensor arguments. The semantics is that the `linalg.indexed_generic` op
+ produces (i.e. allocates and fills) its return values.
+ Tensor values must be legalized by a buffer allocation pass before most
+ transformations can be applied. In particular, transformations that create
+ control flow around linalg.indexed_generic operations are not expected to mix
+ tensors because SSA values do not escape naturally. Still, transformations
+ and rewrites that take advantage of tensor SSA values are expected to be
+ useful and will be added in the near future.
}];
let verifier = [{ return ::verify(*this); }];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index 0706f1fd363..58e4726bc35 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -20,7 +20,7 @@ namespace OpTrait {
namespace linalg {
/// This class provides the API for ops that are known to have a specified
-/// number of inputs, all passed as operands. This is used as a trait like this:
+/// number of inputs, all passed as operands. Use as a trait as follows:
///
/// class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> {
///
@@ -34,7 +34,7 @@ public:
};
/// This class provides the API for ops that are known to have a specified
-/// number of inputs, all passed as operands. This is used as a trait like this:
+/// number of outputs, all passed as operands. Use as a trait as follows:
///
/// class DotOp : public Op<DotOp, OpTrait::NOutputs<2>::Impl> {
///
@@ -47,79 +47,101 @@ public:
};
};
-/// This class provides the API for ops that are known to operate on views. This
-/// trait must be used in conjunction with an op definition or a trait that
-/// provides the methods `getNumInputs` and `getNumOutputs`. This is used as a
-/// trait like this:
+/// This class provides the API for structured ops that are known to operate on
+/// buffers or tensors. This trait must be used in conjunction with an op
+/// definition or a trait that provides the methods `getNumInputs` and
+/// `getNumOutputs`. Use as a trait as follows:
///
-/// class DotOp : public Op<DotOp, OpTrait::ViewTrait> {
+/// class DotOp : public Op<DotOp, OpTrait::StructuredOpTraits> {
///
template <typename ConcreteType>
class StructuredOpTraits
: public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
private:
- /// Return the number of input views. For internal use only.
+ /// Return the number of inputs. For internal use only.
unsigned nInputs() {
return cast<ConcreteType>(this->getOperation()).getNumInputs();
}
- /// Return the number of input views. For internal use only.
+ /// Return the number of outputs. For internal use only.
unsigned nOutputs() {
return cast<ConcreteType>(this->getOperation()).getNumOutputs();
}
public:
- /// Return the `i`-th input view.
+ /// Return the `i`-th input value.
Value getInput(unsigned i) {
assert(i < nInputs());
return this->getOperation()->getOperand(i);
}
- /// Return the index of `view` in the list of input views if found, llvm::None
+ /// Return the index of `value` in the list of inputs if found, llvm::None
/// otherwise.
- Optional<unsigned> getIndexOfInput(Value view) {
- auto it = llvm::find(getInputs(), view);
+ Optional<unsigned> getIndexOfInput(Value value) {
+ auto it = llvm::find(getInputs(), value);
if (it != getInputs().end())
return it - getInputs().begin();
return llvm::None;
}
- /// Return the `i`-th input view type.
- MemRefType getInputViewType(unsigned i) {
- return getInput(i)->getType().template cast<MemRefType>();
+ /// Return the `i`-th input shaped type.
+ ShapedType getInputShapedType(unsigned i) {
+ return getInput(i)->getType().template cast<ShapedType>();
}
- /// Return the range over input views.
+ /// Return the range over inputs.
Operation::operand_range getInputs() {
auto range = this->getOperation()->getOperands();
return {range.begin(), range.begin() + nInputs()};
}
- /// Return the `i`-th output view.
+ /// Return the `i`-th output.
Value getOutput(unsigned i) {
return this->getOperation()->getOperand(nInputs() + i);
}
- /// Return the index of `view` in the list of output views if found,
+ /// Return the index of `value` in the list of output values if found,
/// llvm::None otherwise.
- Optional<unsigned> getIndexOfOutput(Value view) {
- auto it = llvm::find(getOutputs(), view);
+ Optional<unsigned> getIndexOfOutput(Value value) {
+ auto it = llvm::find(getOutputs(), value);
if (it != getOutputs().end())
return it - getOutputs().begin();
return llvm::None;
}
- /// Return the `i`-th output view type.
- MemRefType getOutputViewType(unsigned i) {
- return getOutput(i)->getType().template cast<MemRefType>();
- }
- /// Return the range over output views.
+ /// Return the `i`-th output shaped type.
+ ShapedType getOutputShapedType(unsigned i) {
+ return getOutput(i)->getType().template cast<ShapedType>();
+ }
+ /// Query whether the op has only MemRef inputs and outputs.
+ bool hasBufferSemantics() {
+ return this->getOperation()->getNumResults() == 0 &&
+ llvm::all_of(getInputsAndOutputs(),
+ [](Value v) { return v.getType().isa<MemRefType>(); });
+ }
+ /// Query the subset of input operands that are of ranked tensor type.
+ SmallVector<RankedTensorType, 4> getInputTensorTypes() {
+ SmallVector<RankedTensorType, 4> res;
+ for (Type type : getInputs().getTypes())
+ if (auto t = type.template dyn_cast<RankedTensorType>())
+ res.push_back(t);
+ return res;
+ }
+ /// Query the subset of output operands that are of ranked tensor type.
+ SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
+ SmallVector<RankedTensorType, 4> res;
+ for (Type type : getOutputs().getTypes())
+ if (auto t = type.template dyn_cast<RankedTensorType>())
+ res.push_back(t);
+ return res;
+ }
+ /// Return the range over outputs.
Operation::operand_range getOutputs() {
auto range = this->getOperation()->getOperands();
return {range.begin() + nInputs(),
range.begin() + getNumInputsAndOutputs()};
}
- /// Return the number of input and output views.
+ /// Return the number of inputs and outputs.
unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
- /// Return the `i`-th view type.
- MemRefType getViewType(unsigned i) {
- return (i < nInputs()) ? getInputViewType(i)
- : getOutputViewType(i - nInputs());
+ /// Return the `i`-th shaped type.
+ ShapedType getShapedType(unsigned i) {
+ return (i < nInputs()) ? getInputShapedType(i)
+ : getOutputShapedType(i - nInputs());
}
- /// Return the range over input and output views.
+ /// Return the range over inputs and outputs.
Operation::operand_range getInputsAndOutputs() {
auto range = this->getOperation()->getOperands();
return {range.begin(), range.begin() + getNumInputsAndOutputs()};
@@ -144,8 +166,8 @@ public:
cast<ConcreteType>(this->getOperation()).iterator_types());
}
static LogicalResult verifyTrait(Operation *op) {
- auto nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
- if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
+ auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputs();
+ if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
return failure();
return success();
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index 08c6abedbe2..532fda2411d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -84,25 +84,29 @@ class LinalgOpToAffineLoops<string OpType> : NativeCodeCall<
" return matchFailure();">;
//===----------------------------------------------------------------------===//
-// Linalg to vector contraction patterns.
+// Linalg to vector patterns precondition and DRR.
//===----------------------------------------------------------------------===//
-class VectorizeGenericLinalgOp<string OpType> : NativeCodeCall<
- "if (failed(vectorizeGenericLinalgOp($_builder, op))) " #
- " return matchFailure();">;
+def PreconditionVectorizeGenericLinalgOp : CPred<
+ "succeeded(vectorizeGenericLinalgOpPrecondition(op))">;
+def VectorizeGenericLinalgOp : NativeCodeCall<
+ "vectorizeGenericLinalgOp($_builder, op)">;
//===----------------------------------------------------------------------===//
-// Linalg generic permutation patterns.
+// Linalg generic permutation patterns precondition and DRR.
//===----------------------------------------------------------------------===//
+class PreconditionPermuteGenericLinalgOp<list<int> permutation> : CPred<
+ "succeeded(permuteGenericLinalgOpPrecondition(op, {" #
+ StrJoinInt<permutation>.result # "}))">;
class PermuteGenericLinalgOp<list<int> permutation, string value> :
NativeCodeCall<
- "if (failed(permuteGenericLinalgOp($_builder, op, {" #
- StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
- " return matchFailure();">;
+ "permuteGenericLinalgOp($_builder, op, {" # StrJoinInt<permutation>.result #
+ "}, \"" # value # "\")">;
//===----------------------------------------------------------------------===//
-// Linalg promote subview operands.
+// Linalg promote subview operands precondition and DRR.
//===----------------------------------------------------------------------===//
-class PromoteSubviewsLinalgOp<string OpType> : NativeCodeCall<
- "if (failed(promoteSubviewsLinalgOp($_builder, op))) " #
- " return matchFailure();">;
+def PreconditionPromoteSubviewsLinalgOp : CPred<
+ "succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
+def PromoteSubviewsLinalgOp : NativeCodeCall<
+ "promoteSubviewsLinalgOp($_builder, op)">;
#endif // LINALG_TRANSFORMS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index a70921af5d6..135756358e0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -79,17 +79,24 @@ template <typename ConcreteOp>
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);
/// Rewrite a linalg.generic into a suitable vector.contraction op.
-LogicalResult vectorizeGenericLinalgOp(PatternRewriter &rewriter,
- Operation *op);
+LogicalResult vectorizeGenericLinalgOpPrecondition(Operation *op);
+SmallVector<Value, 0> vectorizeGenericLinalgOp(PatternRewriter &rewriter,
+ Operation *op);
/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
/// and `iterator_types` permutated according to `permutation`.
-LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
- ArrayRef<unsigned> permutation,
- StringRef linalgMarker);
-
-/// Promote std.subviews feeding linalg operations
-LogicalResult promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op);
+LogicalResult
+permuteGenericLinalgOpPrecondition(Operation *op,
+ ArrayRef<unsigned> permutation);
+SmallVector<Value, 0> permuteGenericLinalgOp(PatternRewriter &rewriter,
+ Operation *op,
+ ArrayRef<unsigned> permutation,
+ StringRef linalgMarker);
+
+/// Promote std.subviews feeding linalg operations.
+LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op);
+SmallVector<Value, 0> promoteSubviewsLinalgOp(PatternRewriter &rewriter,
+ Operation *op);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
index feb8bd60445..2922dfd5b82 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
@@ -64,8 +64,9 @@ namespace vector {
//
// This will be extended in the future to support more advanced use cases than
// simple pointwise ops.
-Value unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
- ArrayRef<int64_t> targetShape);
+SmallVector<Value, 1>
+unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
+ ArrayRef<int64_t> targetShape);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 2a034fd15c5..2dd36c94d31 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -186,7 +186,8 @@ public:
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
.cast<LLVM::LLVMType>();
- BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
+ BaseViewConversionHelper desc(
+ lowering.convertType(sliceOp.getShapedType()));
// TODO(ntv): extract sizes and emit asserts.
SmallVector<Value, 4> strides(memRefType.getRank());
@@ -215,7 +216,7 @@ public:
desc.setOffset(baseOffset);
// Corner case, no sizes or strides: early return the descriptor.
- if (sliceOp.getViewType().getRank() == 0)
+ if (sliceOp.getShapedType().getRank() == 0)
return rewriter.replaceOp(op, {desc}), matchSuccess();
Value zero =
@@ -279,7 +280,7 @@ public:
return rewriter.replaceOp(op, {baseDesc}), matchSuccess();
BaseViewConversionHelper desc(
- lowering.convertType(transposeOp.getViewType()));
+ lowering.convertType(transposeOp.getShapedType()));
// Copy the base and aligned pointers from the old descriptor to the new
// one.
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
index de5b1d1f631..b35a8ed0fd8 100644
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -68,6 +68,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
edsc::ScopedContext::getBuilder()
.create<linalg::GenericOp>(
edsc::ScopedContext::getLocation(),
+ ArrayRef<Type>{}, // TODO(ntv): support tensors
values,
IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 415f4181704..813bdb998aa 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -61,6 +61,10 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
p.printRegion(op.region());
p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << ": " << op.getOperandTypes();
+
+ auto outputTensorTypes = op.getResultTypes();
+ if (!outputTensorTypes.empty())
+ p << " -> " << outputTensorTypes;
}
static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
@@ -92,6 +96,13 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(operandTypes))
return failure();
+ // Generic ops may specify that a subset of its outputs are tensors. Such
+ // outputs are specified in the result type.
+ SmallVector<Type, 8> tensorResultTypes;
+ if (parser.parseOptionalArrowTypeList(tensorResultTypes))
+ return failure();
+ if (!tensorResultTypes.empty())
+ result.addTypes(tensorResultTypes);
return parser.resolveOperands(operandsInfo, operandTypes,
parser.getCurrentLocation(), result.operands);
}
@@ -107,7 +118,7 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
"expected number of block arguments to match number of views");
for (unsigned i = 0; i < nViews; ++i) {
- auto viewType = op.getViewType(i);
+ auto viewType = op.getShapedType(i);
if (viewType.getElementType() != block.getArgument(i)->getType())
return op.emitOpError("expected block argument ")
<< i << " of the same type as elemental type of "
@@ -134,7 +145,7 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
for (unsigned i = 0; i < nViews; ++i) {
unsigned memrefArgIndex = i + nLoops;
- auto viewType = op.getViewType(i);
+ auto viewType = op.getShapedType(i);
if (viewType.getElementType() !=
block.getArgument(memrefArgIndex)->getType())
return op.emitOpError("expected block argument ")
@@ -159,8 +170,8 @@ template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
for (auto en : llvm::enumerate(op.indexing_maps())) {
auto idx = en.index();
- auto view = (idx < nInputViews) ? op.getInputViewType(idx)
- : op.getOutputViewType(idx - nInputViews);
+ auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
+ : op.getOutputShapedType(idx - nInputViews);
if (funType.getInput(idx) != view.getElementType())
return op.emitOpError("expected fun argument ")
<< idx << " of the same type as elemental type "
@@ -197,8 +208,8 @@ LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
for (auto en : llvm::enumerate(op.indexing_maps())) {
auto idx = en.index();
auto funIdx = nLoops + idx;
- auto view = (idx < nInputViews) ? op.getInputViewType(idx)
- : op.getOutputViewType(idx - nInputViews);
+ auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
+ : op.getOutputShapedType(idx - nInputViews);
if (funType.getInput(funIdx) != view.getElementType())
return op.emitOpError("expected fun argument ")
<< funIdx << " of the same type as elemental type "
@@ -245,8 +256,8 @@ LogicalResult verifyGenericOp(GenericOpType op) {
auto idx = en.index();
auto m = en.value().template cast<AffineMapAttr>().getValue();
indexingMaps.push_back(m); // Save reference to map for further checks.
- auto view = (idx < nInputViews) ? op.getInputViewType(idx)
- : op.getOutputViewType(idx - nInputViews);
+ auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
+ : op.getOutputShapedType(idx - nInputViews);
if (m.getNumSymbols() != 0)
return op.emitOpError("expected indexing_map #")
@@ -275,6 +286,22 @@ LogicalResult verifyGenericOp(GenericOpType op) {
return op.emitOpError("expected the concatenation of maps in indexing_map "
"to be invertible");
+ auto outputTensorTypes = op.getOutputTensorTypes();
+ if (outputTensorTypes.size() != op.getNumResults())
+ return op.emitOpError("expected #output tensor operands (")
+ << outputTensorTypes.size() << ") to match #results ("
+ << op.getNumResults() << ")";
+
+ unsigned index = 0;
+ for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) {
+ auto resTy = std::get<0>(it);
+ auto outOpTy = std::get<1>(it);
+ if (resTy != outOpTy)
+ return op.emitOpError("result #")
+ << index << " must be " << outOpTy << ", but got " << resTy;
+ ++index;
+ }
+
return success();
}
@@ -465,11 +492,11 @@ LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
// The operand number and types must match the view element types.
auto nOutputViews = genericOp.getNumOutputs();
if (op.getNumOperands() != nOutputViews)
- return op.emitOpError("op expected ")
+ return op.emitOpError("expected ")
<< nOutputViews << " operand to match enclosing linalg.generic op";
for (unsigned i = 0; i != nOutputViews; ++i) {
- auto elementType = genericOp.getOutputViewType(i).getElementType();
+ auto elementType = genericOp.getOutputShapedType(i).getElementType();
if (op.getOperand(i)->getType() != elementType)
return op.emitOpError("type of return operand ")
<< i << " (" << op.getOperand(i)->getType()
@@ -481,7 +508,7 @@ LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
static LogicalResult verify(YieldOp op) {
auto *parentOp = op.getParentOp();
if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
- return op.emitOpError("op expected single non-empty parent region");
+ return op.emitOpError("expected single non-empty parent region");
auto genericOp = dyn_cast<GenericOp>(parentOp);
if (genericOp)
@@ -536,7 +563,7 @@ static ParseResult parseLinalgStructuredOp(OpAsmParser &parser,
}
static LogicalResult verify(FillOp op) {
- auto viewType = op.getOutputViewType(0);
+ auto viewType = op.getOutputShapedType(0);
auto fillType = op.value()->getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
@@ -544,8 +571,8 @@ static LogicalResult verify(FillOp op) {
}
static LogicalResult verify(CopyOp op) {
- auto outputViewType = op.getOutputViewType(0);
- auto inputViewType = op.getInputViewType(0);
+ auto outputViewType = op.getOutputShapedType(0);
+ auto inputViewType = op.getInputShapedType(0);
if (inputViewType.getElementType() != outputViewType.getElementType())
return op.emitOpError("expects views of the same type");
if (inputViewType.getRank() != outputViewType.getRank())
@@ -675,8 +702,8 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
// I(input_perm(ivs)) -> O(output_perm(ivs))
auto maybeInputMap = copyOp.inputPermutation();
auto maybeOutputMap = copyOp.outputPermutation();
- unsigned inputRank = copyOp.getInputViewType(0).getRank();
- unsigned outputRank = copyOp.getOutputViewType(0).getRank();
+ unsigned inputRank = copyOp.getInputShapedType(0).getRank();
+ unsigned outputRank = copyOp.getOutputShapedType(0).getRank();
return SmallVector<AffineMap, 4>{
extractOrIdentityMap(maybeInputMap, inputRank, context),
extractOrIdentityMap(maybeOutputMap, outputRank, context)};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index e9e44d7ba13..4bc452afa36 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -114,6 +114,9 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl(
return false;
}
+//============================================================================//
+// Precondition and transformation for vectorization of Linalg generic ops.
+//============================================================================//
static bool hasMultiplyAddBody(linalg::GenericOp op) {
auto &r = op.region();
if (r.empty())
@@ -153,12 +156,8 @@ static bool isMatmul(linalg::GenericOp genericOp) {
genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
}
-LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
- Operation *op) {
- LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
- "]: Rewrite linalg op as vector.contract: "
- << *op << ":\n");
-
+LogicalResult
+mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
// TODO(ntv): This is in fact much more general than just vectorization for
// matmul ops.
auto genericOp = dyn_cast<linalg::GenericOp>(op);
@@ -175,7 +174,20 @@ LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
if (!llvm::all_of(genericOp.getInputsAndOutputs(),
isStaticMemRefWithIdentityLayout))
return failure();
+ return success();
+}
+
+SmallVector<Value, 0>
+mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
+ Operation *op) {
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+ "]: Rewrite linalg op as vector.contract: "
+ << *op << ":\n");
+
+ assert(succeeded(vectorizeGenericLinalgOpPrecondition(op)) &&
+ "DRR failure case must be a precondition");
+ auto genericOp = cast<linalg::GenericOp>(op);
edsc::ScopedContext scope(rewriter, op->getLoc());
using edsc::intrinsics::std_load;
using edsc::intrinsics::std_store;
@@ -188,16 +200,35 @@ LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
genericOp.iterator_types());
std_store(vRes, vectorMemRefC);
+ return {};
+}
+
+//============================================================================//
+// Precondition and transformation for permutation of Linalg generic ops.
+//============================================================================//
+LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition(
+ Operation *op, ArrayRef<unsigned> permutation) {
+ if (permutation.empty())
+ return failure();
+ // Transformation applies to generic ops only.
+ if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op))
+ return failure();
+ LinalgOp linOp = cast<LinalgOp>(op);
+ // Transformation applies to buffers only.
+ if (!linOp.hasBufferSemantics())
+ return failure();
return success();
}
-LogicalResult
+SmallVector<Value, 0>
mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
ArrayRef<unsigned> permutation,
StringRef linalgMarker) {
- // If permutation is empty, there is nothing to be done.
- if (permutation.empty())
- return failure();
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op
+ << ":\n");
+
+ assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) &&
+ "DRR failure case must be a precondition");
auto linOp = cast<LinalgOp>(op);
auto permutationMap = inversePermutation(
@@ -220,19 +251,41 @@ mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
op->setAttr(LinalgTransforms::kLinalgTransformMarker,
rewriter.getStringAttr(linalgMarker));
linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
- return success();
+ return {};
}
-LogicalResult mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
- Operation *op) {
+//============================================================================//
+// Precondition and transformation for Linalg subview promotion.
+//============================================================================//
+LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) {
LinalgOp linOp = dyn_cast<LinalgOp>(op);
+ // Transformation applies to buffers only.
+ if (!linOp || !linOp.hasBufferSemantics())
+ return failure();
+ if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) {
+ return isa_and_nonnull<SubViewOp>(v.getDefiningOp());
+ }))
+ return failure();
+ return success();
+}
+
+SmallVector<Value, 0>
+mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
+ Operation *op) {
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
+ << *op << ":\n");
+
+ assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
+ "DRR failure case must be a precondition");
+
+ LinalgOp linOp = cast<LinalgOp>(op);
SetVector<Value> subViews;
for (auto it : linOp.getInputsAndOutputs())
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
subViews.insert(sv);
if (!subViews.empty()) {
- auto resOp = promoteSubViewOperands(rewriter, linOp, subViews);
- return success(resOp);
+ promoteSubViewOperands(rewriter, linOp, subViews);
+ return {};
}
- return failure();
+ llvm_unreachable("DRR failure case must be a precondition");
}
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 15dffe198df..9fcbd0cb921 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -462,7 +462,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
}
// Entry point for unrolling declarative pattern rewrites.
-Value mlir::vector::unrollSingleResultOpMatchingType(
+SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
assert(op->getNumResults() == 1 && "Expected single result operation");
@@ -482,8 +482,8 @@ Value mlir::vector::unrollSingleResultOpMatchingType(
}
// Unroll 'op' with 'iterationBounds' to 'targetShape'.
- return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
- resultIndex, targetShape, builder);
+ return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
+ op, iterationBounds, vectors, resultIndex, targetShape, builder)};
}
// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index f99ee74ceea..b81315bbf53 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -423,6 +423,51 @@ func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)
// -----
+func @generic_result_tensor_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
+ // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}}
+ %0 = linalg.generic {
+ args_in = 0,
+ args_out = 1,
+ indexing_maps = [ (i) -> (i) ],
+ iterator_types = ["parallel"]
+ } %arg0 {
+ ^bb(%i: f32):
+ linalg.yield %i: f32
+ }: memref<?xf32, (i)[off]->(off + i)> -> f32
+}
+
+// -----
+
+func @generic_result_tensor_count(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
+ // expected-error @+1 {{op expected #output tensor operands (0) to match #results (1)}}
+ %0 = linalg.generic {
+ args_in = 0,
+ args_out = 1,
+ indexing_maps = [ (i) -> (i) ],
+ iterator_types = ["parallel"]
+ } %arg0 {
+ ^bb(%i: f32):
+ linalg.yield %i: f32
+ }: memref<?xf32, (i)[off]->(off + i)> -> tensor<?xf32>
+}
+
+// -----
+
+func @generic_result_tensor_type(%arg0: tensor<?xf32>) {
+ // expected-error @+1 {{op result #0 must be 'tensor<?xf32>', but got 'tensor<?x?xf32>'}}
+ %0 = linalg.generic {
+ args_in = 0,
+ args_out = 1,
+ indexing_maps = [ (i) -> (i) ],
+ iterator_types = ["parallel"]
+ } %arg0 {
+ ^bb(%i: f32):
+ linalg.yield %i: f32
+ }: tensor<?xf32> -> tensor<?x?xf32>
+}
+
+// -----
+
func @generic_fun_result_0_element_type(%arg0: memref<?xf32>) {
// expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
linalg.dot(%arg0, %arg0): memref<?xf32>, memref<?xf32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 75d732d540d..871fc70b451 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -139,6 +139,29 @@ func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %ar
// CHECK-LABEL: func @generic
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
+func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+ linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+ return
+}
+// CHECK-LABEL: func @generic_with_tensor_input
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[strided3D]]>
+
+func @generic_with_tensor_output(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+ %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @generic_with_tensor_output
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK: return {{.*}} : tensor<?x?x?xf32>
+
+func @generic_with_tensor_input_and_output(%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+ %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @generic_with_tensor_input_and_output
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK: return {{.*}} : tensor<?x?x?xf32>
+
#trait2 = {
args_in = 1,
args_out = 1,
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index 6bad586cafa..3025db7b688 100644
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -100,27 +100,39 @@ def : Pattern<(DotOp:$op $_, $_, $_),
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
- [(VectorizeGenericLinalgOp<"GenericOp">)],
- [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">>)]>;
+ [(VectorizeGenericLinalgOp)],
+ [(Constraint<And<[
+ HasLinalgTransformMarker<"_marked_matmul_">,
+ PreconditionVectorizeGenericLinalgOp
+ ]>>)]>;
//===----------------------------------------------------------------------===//
// Linalg generic permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
- (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
- [(Constraint<And<[HasNoLinalgTransformMarker,
- AffineMapDomainHasDim<3>]>>)]>;
+ (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
+ [(Constraint<And<[
+ HasNoLinalgTransformMarker,
+ AffineMapDomainHasDim<3>,
+ PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
+ ]>>)]>;
def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
- (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
- [(Constraint<And<[HasNoLinalgTransformMarker,
- AffineMapDomainHasDim<3>]>>)]>;
+ (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
+ [(Constraint<And<[
+ HasNoLinalgTransformMarker,
+ AffineMapDomainHasDim<3>,
+ PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
+ ]>>)]>;
//===----------------------------------------------------------------------===//
// Linalg subview operands promotion.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $_, $_, $_),
- (PromoteSubviewsLinalgOp<"MatmulOp">),
- [(Constraint<HasOperandsOfType<"SubViewOp">>),
- (Constraint<HasLinalgTransformMarker<"_promote_views_">>)]>;
+ (PromoteSubviewsLinalgOp),
+ [(Constraint<And<[
+ PreconditionPromoteSubviewsLinalgOp,
+ HasOperandsOfType<"SubViewOp">,
+ HasLinalgTransformMarker<"_promote_views_">]>>
+ )]>;
#endif // TEST_LINALG_TRANSFORMS_PATTERNS
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index c84b56c0c72..1562562a20e 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -329,8 +329,9 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
os.indent(indent) << "{\n";
indent += 2;
os.indent(indent) << formatv(
- "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
- attr.getStorageType(), namedAttr->name);
+ "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");"
+ "(void)tblgen_attr;\n",
+ depth, attr.getStorageType(), namedAttr->name);
// TODO(antiagainst): This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
@@ -573,8 +574,15 @@ void PatternEmitter::emitRewriteLogic() {
auto val = handleResultPattern(resultTree, offsets[i], 0);
os.indent(4) << "\n";
// Resolve each symbol for all range use so that we can loop over them.
+ // We need an explicit cast to `SmallVector` to capture the cases where
+ // `{0}` resolves to an `Operation::result_range` as well as cases that
+ // are not iterable (e.g. vector that gets wrapped in additional braces by
+ // RewriterGen).
+ // TODO(b/147096809): Revisit the need for materializing a vector.
os << symbolInfoMap.getAllRangeUse(
- val, " for (auto v : {0}) {{ tblgen_repl_values.push_back(v); }",
+ val,
+ " for (auto v : SmallVector<Value, 4>{ {0} }) {{ "
+ "tblgen_repl_values.push_back(v); }",
"\n");
}
os.indent(4) << "\n";
OpenPOWER on IntegriCloud