summaryrefslogtreecommitdiffstats
path: root/mlir/include
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/include')
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td178
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h167
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Utils/Utils.h2
3 files changed, 217 insertions, 130 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index d8c657c7209..44735a0b672 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -35,38 +35,9 @@ def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
// interface.
def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
let methods = [
- InterfaceMethod<
- "Query the number of inputs from the current operation.",
- "unsigned", "getNumInputs"
- >,
- InterfaceMethod<
- "Query the number of outputs from the current operation.",
- "unsigned", "getNumOutputs"
- >,
- InterfaceMethod<
- "Query the number of inputs and outputs from the current operation.",
- "unsigned", "getNumInputsAndOutputs"
- >,
- InterfaceMethod<
- "Query the input operands from the current operation.",
- "Operation::operand_range", "getInputs"
- >,
- InterfaceMethod<
- "Query the output operands from the current operation.",
- "Operation::operand_range", "getOutputs"
- >,
- InterfaceMethod<
- "Query the input and output operands from the current operation.",
- "Operation::operand_range", "getInputsAndOutputs"
- >,
- InterfaceMethod<
- "Query the iterator types attribute within the current operation.",
- "ArrayAttr", "iterator_types"
- >,
- InterfaceMethod<
- "Query the indexing maps attribute within the current operation.",
- "ArrayAttr", "indexing_maps"
- >,
+ //========================================================================//
+ // Loop types handling.
+ //========================================================================//
InterfaceMethod<
"Query the number of parallel loops within the current operation.",
"unsigned", "getNumParallelLoops"
@@ -82,40 +53,98 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
InterfaceMethod<
"Query the number of loops within the current operation.",
"unsigned", "getNumLoops">,
+
+ //========================================================================//
+ // Input arguments handling.
+ //========================================================================//
+ InterfaceMethod<
+ "Query the number of inputs from the current operation.",
+ "unsigned", "getNumInputs"
+ >,
InterfaceMethod<"Query the input view at the given index.",
"Value ", "getInput", (ins "unsigned":$i)
>,
- InterfaceMethod<"Query the output view at the given index.",
- "Value ", "getOutput", (ins "unsigned":$i)
- >,
InterfaceMethod<[{
Return the index of the given input value `v`, or `None` if the value is
not an input.
}],
"llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v)
>,
- InterfaceMethod<[{
- Query the index of the given view value, or `None` if the value is not
- a view.
- }],
- "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view)
+ InterfaceMethod<
+ "Query the input operands from the current operation.",
+ "Operation::operand_range", "getInputs"
>,
InterfaceMethod<[{
Query the type of the input shape at the given index.
}], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
InterfaceMethod<[{
- Query the type of the output view at the given index.
- }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
- InterfaceMethod<[{
- Query whether the op has only MemRef input and outputs.
- }], "bool", "hasBufferSemantics">,
- InterfaceMethod<[{
Query the subset of input operands that are of ranked tensor type.
}], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
+
+
+ //========================================================================//
+ // Output arguments handling.
+ //========================================================================//
+ InterfaceMethod<
+ "Query the number of outputs from the current operation.",
+ "unsigned", "getNumOutputs"
+ >,
+ InterfaceMethod<"Query the output buffer at the given index.",
+ "Value ", "getOutputBuffer", (ins "unsigned":$i)
+ >,
InterfaceMethod<[{
- Query the subset of output operands that are of ranked tensor type.
+ Query the index of the given buffer value, or `None` if the value is not
+ part of the output buffers.
+ }],
+ "llvm::Optional<unsigned>", "getIndexOfOutputBuffer", (ins "Value ":$view)
+ >,
+ InterfaceMethod<[{
+ Query the type of the output buffer at the given index.
+ }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>,
+ InterfaceMethod<[{
+ Query the results that are of ranked tensor type.
}], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,
+ InterfaceMethod<
+ "Query the output buffers (operands) from the current operation.",
+ "Operation::operand_range", "getOutputBuffers"
+ >,
+ //========================================================================//
+ // Input and Output arguments handling.
+ //========================================================================//
+ InterfaceMethod<
+ "Return the number of inputs and outputs, irrespective of their buffer "
+ "or tensor type.",
+ "unsigned", "getNumInputsAndOutputs"
+ >,
+ InterfaceMethod<
+ "Return the number of inputs, irrespective of their buffer or tensor "
+ "type, and output buffers",
+ "unsigned", "getNumInputsAndOutputBuffers"
+ >,
+ InterfaceMethod<
+ "Return the range over inputs (irrespective of type) and output buffers.",
+ "Operation::operand_range", "getInputsAndOutputBuffers"
+ >,
+
+ //========================================================================//
+ // Other interface methods.
+ //========================================================================//
+ InterfaceMethod<
+ "Query the iterator types attribute within the current operation.",
+ "ArrayAttr", "iterator_types"
+ >,
+ InterfaceMethod<
+ "Query the indexing maps attribute within the current operation.",
+ "ArrayAttr", "indexing_maps"
+ >,
+ InterfaceMethod<[{
+ Query whether the op has only MemRef input and outputs.
+ }], "bool", "hasBufferSemantics">,
+
+ //========================================================================//
+ // Other static interface methods.
+ //========================================================================//
StaticInterfaceMethod<[{
Create an operation of the current type with the given location,
operands, and attributes.
@@ -128,9 +157,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
attributes);
}]
>,
-
- /// Clone an operation with the given location and operands. This is used to
- /// abstract away the optional underlying region creation.
InterfaceMethod<[{
Clone the current operation with the given location and operands. This
is used to abstract away the optional underlying region creation.
@@ -536,22 +562,26 @@ def GenericOp : GenericOpBase<"generic"> {
mixing input and output ranked tensor values with input and output memrefs.
```mlir
- %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
+ %C = linalg.generic #trait_attribute %A, %B {other-attributes} :
tensor<?x?xf32>,
- memref<?x?xf32, stride_specification>,
- tensor<?x?xf32>
+ memref<?x?xf32, stride_specification>
-> (tensor<?x?xf32>)
```
- In this case, the number of return values must match the number of output
- tensor arguments. The semantics is that the `linalg.generic` op
- produces (i.e. allocates and fills) its return values.
+ In this case, the number of outputs (args_out) must match the sum of (1) the
+ number of output buffer operands and (2) the number of tensor return values.
+ The semantics is that the `linalg.indexed_generic` op produces (i.e.
+ allocates and fills) its tensor return values.
+
Tensor values must be legalized by a buffer allocation pass before most
- transformations can be applied. In particular, transformations that create
- control flow around linalg.generic operations are not expected to mix with
- tensors because SSA values do not escape naturally. Still, transformations
- and rewrites that take advantage of tensor SSA values are expected to be
- useful and will be added in the near future.
+ transformations can be applied. Such legalization moves tensor return values
+ into output buffer operands and updates the region arguments accordingly.
+
+ Transformations that create control-flow around linalg.indexed_generic
+ operations are not expected to work with tensors because SSA values do not
+ escape naturally. Still, transformations and rewrites that take advantage of
+ tensor SSA values are expected to be useful and will be added in the near
+ future.
}];
let verifier = [{ return ::verify(*this); }];
}
@@ -659,22 +689,26 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
memrefs.
```mlir
- %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes}
+ %C = linalg.indexed_generic #trait_attribute %A, %B {other-attributes}
: tensor<?x?xf32>,
- memref<?x?xf32, stride_specification>,
- tensor<?x?xf32>
+ memref<?x?xf32, stride_specification>
-> (tensor<?x?xf32>)
```
- In this case, the number of return values must match the number of output
- tensor arguments. The semantics is that the `linalg.indexed_generic` op
- produces (i.e. allocates and fills) its return values.
+ In this case, the number of outputs (args_out) must match the sum of (1) the
+ number of output buffer operands and (2) the number of tensor return values.
+ The semantics is that the `linalg.indexed_generic` op produces (i.e.
+ allocates and fills) its return values.
+
Tensor values must be legalized by a buffer allocation pass before most
- transformations can be applied. In particular, transformations that create
- control flow around linalg.generic operations are not expected to mix with
- tensors because SSA values do not escape naturally. Still, transformations
- and rewrites that take advantage of tensor SSA values are expected to be
- useful and will be added in the near future.
+ transformations can be applied. Such legalization moves tensor return values
+ into output buffer operands and updates the region argument accordingly.
+
+ Transformations that create control-flow around linalg.indexed_generic
+ operations are not expected to work with tensors because SSA values do not
+ escape naturally. Still, transformations and rewrites that take advantage of
+ tensor SSA values are expected to be useful and will be added in the near
+ future.
}];
let verifier = [{ return ::verify(*this); }];
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index 7275863c5b6..2284616e457 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -58,16 +58,47 @@ template <typename ConcreteType>
class StructuredOpTraits
: public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
private:
- /// Return the number of inputs. For internal use only.
+ /// Return the number of inputs, irrespective of their buffer or tensor type.
+ /// For internal use only.
unsigned nInputs() {
return cast<ConcreteType>(this->getOperation()).getNumInputs();
}
- /// Return the number of outputs. For internal use only.
+ /// Return the number of outputs, irrespective of their buffer or tensor type.
+ /// For internal use only.
unsigned nOutputs() {
return cast<ConcreteType>(this->getOperation()).getNumOutputs();
}
public:
+ //==========================================================================//
+ // Loop types handling.
+ //==========================================================================//
+ unsigned getNumParallelLoops() {
+ return getNumIterators(
+ getParallelIteratorTypeName(),
+ cast<ConcreteType>(this->getOperation()).iterator_types());
+ }
+ unsigned getNumReductionLoops() {
+ return getNumIterators(
+ getReductionIteratorTypeName(),
+ cast<ConcreteType>(this->getOperation()).iterator_types());
+ }
+ unsigned getNumWindowLoops() {
+ return getNumIterators(
+ getWindowIteratorTypeName(),
+ cast<ConcreteType>(this->getOperation()).iterator_types());
+ }
+ unsigned getNumLoops() {
+ return getNumIterators(
+ cast<ConcreteType>(this->getOperation()).iterator_types());
+ }
+
+ //==========================================================================//
+ // Input arguments handling.
+ //==========================================================================//
+ // The `i^th` input argument is always the `i^th` operand regardless of
+ // whether we have tensors or buffers.
+ //
/// Return the `i`-th input value.
Value getInput(unsigned i) {
assert(i < nInputs());
@@ -90,28 +121,6 @@ public:
auto range = this->getOperation()->getOperands();
return {range.begin(), range.begin() + nInputs()};
}
- /// Return the `i`-th output.
- Value getOutput(unsigned i) {
- return this->getOperation()->getOperand(nInputs() + i);
- }
- /// Return the index of `value` in the list of output values if found,
- /// llvm::None otherwise.
- Optional<unsigned> getIndexOfOutput(Value value) {
- auto it = llvm::find(getOutputs(), value);
- if (it != getOutputs().end())
- return it - getOutputs().begin();
- return llvm::None;
- }
- /// Return the `i`-th output buffer type.
- ShapedType getOutputShapedType(unsigned i) {
- return getOutput(i).getType().template cast<ShapedType>();
- }
- /// Query whether the op has only MemRef input and outputs.
- bool hasBufferSemantics() {
- return this->getOperation()->getNumResults() == 0 &&
- llvm::all_of(getInputsAndOutputs(),
- [](Value v) { return v.getType().isa<MemRefType>(); });
- }
/// Query the subset of input operands that are of ranked tensor type.
SmallVector<RankedTensorType, 4> getInputTensorTypes() {
SmallVector<RankedTensorType, 4> res;
@@ -120,53 +129,97 @@ public:
res.push_back(t);
return res;
}
- /// Query the subset of output operands that are of ranked tensor type.
+
+ //==========================================================================//
+ // Output arguments handling.
+ //==========================================================================//
+ // The `i^th` output argument is an operand (resp. a return value) iff it is
+ // a value of buffer type (resp. a return value of tensor type).
+
+ /// Return the `i`-th output, asserts that this is a buffer operand and not
+ /// a tensor result.
+ Value getOutputBuffer(unsigned i) {
+ assert(i + this->getOperation()->getNumResults() < nOutputs() &&
+ "overflowing output buffer index");
+ return this->getOperation()->getOperand(nInputs() + i);
+ }
+ /// Return the index of `value` in the list of output buffers if found,
+ /// llvm::None otherwise.
+ Optional<unsigned> getIndexOfOutputBuffer(Value value) {
+ auto it = llvm::find(getOutputBuffers(), value);
+ if (it != getOutputBuffers().end())
+ return it - getOutputBuffers().begin();
+ return llvm::None;
+ }
+ /// Return the `i`-th output buffer type.
+ MemRefType getOutputBufferType(unsigned i) {
+ return getOutputBuffer(i).getType().template cast<MemRefType>();
+ }
+ /// Return the `i`-th output shaped type, irrespective of buffer of tensor
+ /// type.
+ ShapedType getOutputShapedType(unsigned i) {
+ return getShapedType(i + nInputs());
+ }
+ /// Query the subset of results that are of ranked tensor type.
SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
SmallVector<RankedTensorType, 4> res;
- for (Type type : getOutputs().getTypes())
- if (auto t = type.template dyn_cast<RankedTensorType>())
- res.push_back(t);
+ for (Type type : this->getOperation()->getResults().getTypes())
+ res.push_back(type.template cast<RankedTensorType>());
return res;
}
/// Return the range over outputs.
- Operation::operand_range getOutputs() {
+ Operation::operand_range getOutputBuffers() {
auto range = this->getOperation()->getOperands();
return {range.begin() + nInputs(),
- range.begin() + getNumInputsAndOutputs()};
+ range.begin() + getNumInputsAndOutputBuffers()};
}
- /// Return the number of inputs and outputs.
+
+ //==========================================================================//
+ // Input and Output arguments handling.
+ //==========================================================================//
+ /// Return the number of inputs and outputs, irrespective of their buffer or
+ /// tensor type.
unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
- /// Return the `i`-th buffer type.
- ShapedType getShapedType(unsigned i) {
- return (i < nInputs()) ? getInputShapedType(i)
- : getOutputShapedType(i - nInputs());
- }
- /// Return the range over inputs and outputs.
- Operation::operand_range getInputsAndOutputs() {
+ /// Return the number of inputs, irrespective of their buffer or tensor type,
+ /// and output buffers.
+ unsigned getNumInputsAndOutputBuffers() {
+ assert(this->getOperation()->getNumResults() <= nOutputs());
+ return nInputs() + nOutputs() - this->getOperation()->getNumResults();
+ }
+ /// Return the range over inputs (irrespective of type) and output buffers.
+ Operation::operand_range getInputsAndOutputBuffers() {
auto range = this->getOperation()->getOperands();
- return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+ return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()};
}
- unsigned getNumParallelLoops() {
- return getNumIterators(
- getParallelIteratorTypeName(),
- cast<ConcreteType>(this->getOperation()).iterator_types());
- }
- unsigned getNumReductionLoops() {
- return getNumIterators(
- getReductionIteratorTypeName(),
- cast<ConcreteType>(this->getOperation()).iterator_types());
- }
- unsigned getNumWindowLoops() {
- return getNumIterators(
- getWindowIteratorTypeName(),
- cast<ConcreteType>(this->getOperation()).iterator_types());
+ /// Return the `i`-th shaped type, there are 3 cases:
+ /// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise
+ /// 2. if `i < getNumInputsAndOutputBuffers()` then return the
+ /// `getOutputBufferType(i - nInputs())`; otherwise
+ /// 3. return the `i - getNumInputsAndOutputBuffers()` result type.
+ ShapedType getShapedType(unsigned i) {
+ if (i < nInputs())
+ return getInputShapedType(i);
+ if (i < getNumInputsAndOutputBuffers())
+ return getOutputBufferType(i - nInputs()).template cast<ShapedType>();
+ return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]
+ .template cast<ShapedType>();
}
- unsigned getNumLoops() {
- return getNumIterators(
- cast<ConcreteType>(this->getOperation()).iterator_types());
+
+ //==========================================================================//
+ // Other interface methods.
+ //==========================================================================//
+ /// Query whether the op has only buffer inputs and no returns.
+ bool hasBufferSemantics() {
+ return this->getOperation()->getNumResults() == 0 &&
+ llvm::all_of(getInputs(),
+ [](Value v) { return v.getType().isa<MemRefType>(); });
}
+
+ //==========================================================================//
+ // Other static interface methods.
+ //==========================================================================//
static LogicalResult verifyTrait(Operation *op) {
- auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputs();
+ auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers();
if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
return failure();
return success();
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 2bd4a0d2c0b..f559ba4423d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -119,7 +119,7 @@ Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer,
template <typename ConcreteOp>
SmallVector<Value, 8> getViewSizes(ConcreteOp linalgOp) {
SmallVector<Value, 8> res;
- for (auto v : linalgOp.getInputsAndOutputs()) {
+ for (auto v : linalgOp.getInputsAndOutputBuffers()) {
MemRefType t = v.getType().template cast<MemRefType>();
for (unsigned i = 0; i < t.getRank(); ++i)
res.push_back(edsc::intrinsics::dim(v, i));
OpenPOWER on IntegriCloud