diff options
author | Nicolas Vasilache <ntv@google.com> | 2020-01-11 02:22:00 -0500 |
---|---|---|
committer | Nicolas Vasilache <ntv@google.com> | 2020-01-14 17:25:28 -0500 |
commit | f52d71736b10e87b1aa1880b777dc9462a0085ce (patch) | |
tree | 3eaa824f59037e0b987abd0c39094ec999e04c3c /mlir/include | |
parent | 8d07f8d98c48ee0a9dca450aaf4e1cabc621ff68 (diff) | |
download | bcm5719-llvm-f52d71736b10e87b1aa1880b777dc9462a0085ce.tar.gz bcm5719-llvm-f52d71736b10e87b1aa1880b777dc9462a0085ce.zip |
[mlir][Linalg] Update the semantics, verifier and test for Linalg with tensors.
Summary:
This diff fixes issues with the semantics of linalg.generic on tensors that appeared when converting directly from HLO to linalg.generic.
The changes are self-contained within MLIR and can be captured and tested independently of XLA.
The linalg.generic and indexed_generic are updated to:
To allow progressive lowering from the value world (a.k.a tensor values) to
the buffer world (a.k.a memref values), a linalg.generic op accepts
mixing input and output ranked tensor values with input and output memrefs.
```
%1 = linalg.generic #trait_attribute %A, %B {other-attributes} :
tensor<?x?xf32>,
memref<?x?xf32, stride_specification>
-> (tensor<?x?xf32>)
```
In this case, the number of outputs (args_out) must match the sum of (1) the
number of output buffer operands and (2) the number of tensor return values.
The semantics is that the linalg.indexed_generic op produces (i.e.
allocates and fills) its return values.
Tensor values must be legalized by a buffer allocation pass before most
transformations can be applied. Such legalization moves tensor return values
into output buffer operands and updates the region argument accordingly.
Transformations that create control-flow around linalg.indexed_generic
operations are not expected to mix with tensors because SSA values do not
escape naturally. Still, transformations and rewrites that take advantage of
tensor SSA values are expected to be useful and will be added in the near
future.
Subscribers: bmahjour, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72555
Diffstat (limited to 'mlir/include')
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 178 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h | 167 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 2 |
3 files changed, 217 insertions, 130 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index d8c657c7209..44735a0b672 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -35,38 +35,9 @@ def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">; // interface. def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let methods = [ - InterfaceMethod< - "Query the number of inputs from the current operation.", - "unsigned", "getNumInputs" - >, - InterfaceMethod< - "Query the number of outputs from the current operation.", - "unsigned", "getNumOutputs" - >, - InterfaceMethod< - "Query the number of inputs and outputs from the current operation.", - "unsigned", "getNumInputsAndOutputs" - >, - InterfaceMethod< - "Query the input operands from the current operation.", - "Operation::operand_range", "getInputs" - >, - InterfaceMethod< - "Query the output operands from the current operation.", - "Operation::operand_range", "getOutputs" - >, - InterfaceMethod< - "Query the input and output operands from the current operation.", - "Operation::operand_range", "getInputsAndOutputs" - >, - InterfaceMethod< - "Query the iterator types attribute within the current operation.", - "ArrayAttr", "iterator_types" - >, - InterfaceMethod< - "Query the indexing maps attribute within the current operation.", - "ArrayAttr", "indexing_maps" - >, + //========================================================================// + // Loop types handling. + //========================================================================// InterfaceMethod< "Query the number of parallel loops within the current operation.", "unsigned", "getNumParallelLoops" @@ -82,40 +53,98 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { InterfaceMethod< "Query the number of loops within the current operation.", "unsigned", "getNumLoops">, + + //========================================================================// + // Input arguments handling. + //========================================================================// + InterfaceMethod< + "Query the number of inputs from the current operation.", + "unsigned", "getNumInputs" + >, InterfaceMethod<"Query the input view at the given index.", "Value ", "getInput", (ins "unsigned":$i) >, - InterfaceMethod<"Query the output view at the given index.", - "Value ", "getOutput", (ins "unsigned":$i) - >, InterfaceMethod<[{ Return the index of the given input value `v`, or `None` if the value is not an input. }], "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v) >, - InterfaceMethod<[{ - Query the index of the given view value, or `None` if the value is not - a view. - }], - "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view) + InterfaceMethod< + "Query the input operands from the current operation.", + "Operation::operand_range", "getInputs" >, InterfaceMethod<[{ Query the type of the input shape at the given index. }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ - Query the type of the output view at the given index. - }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>, - InterfaceMethod<[{ - Query whether the op has only MemRef input and outputs. - }], "bool", "hasBufferSemantics">, - InterfaceMethod<[{ Query the subset of input operands that are of ranked tensor type. }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">, + + + //========================================================================// + // Output arguments handling. + //========================================================================// + InterfaceMethod< + "Query the number of outputs from the current operation.", + "unsigned", "getNumOutputs" + >, + InterfaceMethod<"Query the output buffer at the given index.", + "Value ", "getOutputBuffer", (ins "unsigned":$i) + >, InterfaceMethod<[{ - Query the subset of output operands that are of ranked tensor type. + Query the index of the given buffer value, or `None` if the value is not + part of the output buffers. + }], + "llvm::Optional<unsigned>", "getIndexOfOutputBuffer", (ins "Value ":$view) + >, + InterfaceMethod<[{ + Query the type of the output buffer at the given index. + }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Query the results that are of ranked tensor type. }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">, + InterfaceMethod< + "Query the output buffers (operands) from the current operation.", + "Operation::operand_range", "getOutputBuffers" + >, + //========================================================================// + // Input and Output arguments handling. + //========================================================================// + InterfaceMethod< + "Return the number of inputs and outputs, irrespective of their buffer " + "or tensor type.", + "unsigned", "getNumInputsAndOutputs" + >, + InterfaceMethod< + "Return the number of inputs, irrespective of their buffer or tensor " + "type, and output buffers", + "unsigned", "getNumInputsAndOutputBuffers" + >, + InterfaceMethod< + "Return the range over inputs (irrespective of type) and output buffers.", + "Operation::operand_range", "getInputsAndOutputBuffers" + >, + + //========================================================================// + // Other interface methods. + //========================================================================// + InterfaceMethod< + "Query the iterator types attribute within the current operation.", + "ArrayAttr", "iterator_types" + >, + InterfaceMethod< + "Query the indexing maps attribute within the current operation.", + "ArrayAttr", "indexing_maps" + >, + InterfaceMethod<[{ + Query whether the op has only MemRef input and outputs. + }], "bool", "hasBufferSemantics">, + + //========================================================================// + // Other static interface methods. + //========================================================================// StaticInterfaceMethod<[{ Create an operation of the current type with the given location, operands, and attributes. @@ -128,9 +157,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { attributes); }] >, - - /// Clone an operation with the given location and operands. This is used to - /// abstract away the optional underlying region creation. InterfaceMethod<[{ Clone the current operation with the given location and operands. This is used to abstract away the optional underlying region creation. @@ -536,22 +562,26 @@ def GenericOp : GenericOpBase<"generic"> { mixing input and output ranked tensor values with input and output memrefs. ```mlir - %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} : + %C = linalg.generic #trait_attribute %A, %B {other-attributes} : tensor<?x?xf32>, - memref<?x?xf32, stride_specification>, - tensor<?x?xf32> + memref<?x?xf32, stride_specification> -> (tensor<?x?xf32>) ``` - In this case, the number of return values must match the number of output - tensor arguments. The semantics is that the `linalg.generic` op - produces (i.e. allocates and fills) its return values. + In this case, the number of outputs (args_out) must match the sum of (1) the + number of output buffer operands and (2) the number of tensor return values. + The semantics is that the `linalg.indexed_generic` op produces (i.e. + allocates and fills) its tensor return values. + Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create - control flow around linalg.generic operations are not expected to mix with - tensors because SSA values do not escape naturally. Still, transformations - and rewrites that take advantage of tensor SSA values are expected to be - useful and will be added in the near future. + transformations can be applied. Such legalization moves tensor return values + into output buffer operands and updates the region arguments accordingly. + + Transformations that create control-flow around linalg.indexed_generic + operations are not expected to work with tensors because SSA values do not + escape naturally. Still, transformations and rewrites that take advantage of + tensor SSA values are expected to be useful and will be added in the near + future. }]; let verifier = [{ return ::verify(*this); }]; } @@ -659,22 +689,26 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { memrefs. ```mlir - %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} + %C = linalg.indexed_generic #trait_attribute %A, %B {other-attributes} : tensor<?x?xf32>, - memref<?x?xf32, stride_specification>, - tensor<?x?xf32> + memref<?x?xf32, stride_specification> -> (tensor<?x?xf32>) ``` - In this case, the number of return values must match the number of output - tensor arguments. The semantics is that the `linalg.indexed_generic` op - produces (i.e. allocates and fills) its return values. + In this case, the number of outputs (args_out) must match the sum of (1) the + number of output buffer operands and (2) the number of tensor return values. + The semantics is that the `linalg.indexed_generic` op produces (i.e. + allocates and fills) its return values. + Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create - control flow around linalg.generic operations are not expected to mix with - tensors because SSA values do not escape naturally. Still, transformations - and rewrites that take advantage of tensor SSA values are expected to be - useful and will be added in the near future. + transformations can be applied. Such legalization moves tensor return values + into output buffer operands and updates the region argument accordingly. + + Transformations that create control-flow around linalg.indexed_generic + operations are not expected to work with tensors because SSA values do not + escape naturally. Still, transformations and rewrites that take advantage of + tensor SSA values are expected to be useful and will be added in the near + future. }]; let verifier = [{ return ::verify(*this); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h index 7275863c5b6..2284616e457 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -58,16 +58,47 @@ template <typename ConcreteType> class StructuredOpTraits : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> { private: - /// Return the number of inputs. For internal use only. + /// Return the number of inputs, irrespective of their buffer or tensor type. + /// For internal use only. unsigned nInputs() { return cast<ConcreteType>(this->getOperation()).getNumInputs(); } - /// Return the number of outputs. For internal use only. + /// Return the number of outputs, irrespective of their buffer or tensor type. + /// For internal use only. unsigned nOutputs() { return cast<ConcreteType>(this->getOperation()).getNumOutputs(); } public: + //==========================================================================// + // Loop types handling. + //==========================================================================// + unsigned getNumParallelLoops() { + return getNumIterators( + getParallelIteratorTypeName(), + cast<ConcreteType>(this->getOperation()).iterator_types()); + } + unsigned getNumReductionLoops() { + return getNumIterators( + getReductionIteratorTypeName(), + cast<ConcreteType>(this->getOperation()).iterator_types()); + } + unsigned getNumWindowLoops() { + return getNumIterators( + getWindowIteratorTypeName(), + cast<ConcreteType>(this->getOperation()).iterator_types()); + } + unsigned getNumLoops() { + return getNumIterators( + cast<ConcreteType>(this->getOperation()).iterator_types()); + } + + //==========================================================================// + // Input arguments handling. + //==========================================================================// + // The `i^th` input argument is always the `i^th` operand regardless of + // whether we have tensors or buffers. + // /// Return the `i`-th input value. Value getInput(unsigned i) { assert(i < nInputs()); @@ -90,28 +121,6 @@ public: auto range = this->getOperation()->getOperands(); return {range.begin(), range.begin() + nInputs()}; } - /// Return the `i`-th output. - Value getOutput(unsigned i) { - return this->getOperation()->getOperand(nInputs() + i); - } - /// Return the index of `value` in the list of output values if found, - /// llvm::None otherwise. - Optional<unsigned> getIndexOfOutput(Value value) { - auto it = llvm::find(getOutputs(), value); - if (it != getOutputs().end()) - return it - getOutputs().begin(); - return llvm::None; - } - /// Return the `i`-th output buffer type. - ShapedType getOutputShapedType(unsigned i) { - return getOutput(i).getType().template cast<ShapedType>(); - } - /// Query whether the op has only MemRef input and outputs. - bool hasBufferSemantics() { - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(getInputsAndOutputs(), - [](Value v) { return v.getType().isa<MemRefType>(); }); - } /// Query the subset of input operands that are of ranked tensor type. SmallVector<RankedTensorType, 4> getInputTensorTypes() { SmallVector<RankedTensorType, 4> res; @@ -120,53 +129,97 @@ public: res.push_back(t); return res; } - /// Query the subset of output operands that are of ranked tensor type. + + //==========================================================================// + // Output arguments handling. + //==========================================================================// + // The `i^th` output argument is an operand (resp. a return value) iff it is + // a value of buffer type (resp. a return value of tensor type). + + /// Return the `i`-th output, asserts that this is a buffer operand and not + /// a tensor result. + Value getOutputBuffer(unsigned i) { + assert(i + this->getOperation()->getNumResults() < nOutputs() && + "overflowing output buffer index"); + return this->getOperation()->getOperand(nInputs() + i); + } + /// Return the index of `value` in the list of output buffers if found, + /// llvm::None otherwise. + Optional<unsigned> getIndexOfOutputBuffer(Value value) { + auto it = llvm::find(getOutputBuffers(), value); + if (it != getOutputBuffers().end()) + return it - getOutputBuffers().begin(); + return llvm::None; + } + /// Return the `i`-th output buffer type. + MemRefType getOutputBufferType(unsigned i) { + return getOutputBuffer(i).getType().template cast<MemRefType>(); + } + /// Return the `i`-th output shaped type, irrespective of buffer of tensor + /// type. + ShapedType getOutputShapedType(unsigned i) { + return getShapedType(i + nInputs()); + } + /// Query the subset of results that are of ranked tensor type. SmallVector<RankedTensorType, 4> getOutputTensorTypes() { SmallVector<RankedTensorType, 4> res; - for (Type type : getOutputs().getTypes()) - if (auto t = type.template dyn_cast<RankedTensorType>()) - res.push_back(t); + for (Type type : this->getOperation()->getResults().getTypes()) + res.push_back(type.template cast<RankedTensorType>()); return res; } /// Return the range over outputs. - Operation::operand_range getOutputs() { + Operation::operand_range getOutputBuffers() { auto range = this->getOperation()->getOperands(); return {range.begin() + nInputs(), - range.begin() + getNumInputsAndOutputs()}; + range.begin() + getNumInputsAndOutputBuffers()}; } - /// Return the number of inputs and outputs. + + //==========================================================================// + // Input and Output arguments handling. + //==========================================================================// + /// Return the number of inputs and outputs, irrespective of their buffer or + /// tensor type. unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } - /// Return the `i`-th buffer type. - ShapedType getShapedType(unsigned i) { - return (i < nInputs()) ? getInputShapedType(i) - : getOutputShapedType(i - nInputs()); - } - /// Return the range over inputs and outputs. - Operation::operand_range getInputsAndOutputs() { + /// Return the number of inputs, irrespective of their buffer or tensor type, + /// and output buffers. + unsigned getNumInputsAndOutputBuffers() { + assert(this->getOperation()->getNumResults() <= nOutputs()); + return nInputs() + nOutputs() - this->getOperation()->getNumResults(); + } + /// Return the range over inputs (irrespective of type) and output buffers. + Operation::operand_range getInputsAndOutputBuffers() { auto range = this->getOperation()->getOperands(); - return {range.begin(), range.begin() + getNumInputsAndOutputs()}; + return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()}; } - unsigned getNumParallelLoops() { - return getNumIterators( - getParallelIteratorTypeName(), - cast<ConcreteType>(this->getOperation()).iterator_types()); - } - unsigned getNumReductionLoops() { - return getNumIterators( - getReductionIteratorTypeName(), - cast<ConcreteType>(this->getOperation()).iterator_types()); - } - unsigned getNumWindowLoops() { - return getNumIterators( - getWindowIteratorTypeName(), - cast<ConcreteType>(this->getOperation()).iterator_types()); + /// Return the `i`-th shaped type, there are 3 cases: + /// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise + /// 2. if `i < getNumInputsAndOutputBuffers()` then return the + /// `getOutputBufferType(i - nInputs())`; otherwise + /// 3. return the `i - getNumInputsAndOutputBuffers()` result type. + ShapedType getShapedType(unsigned i) { + if (i < nInputs()) + return getInputShapedType(i); + if (i < getNumInputsAndOutputBuffers()) + return getOutputBufferType(i - nInputs()).template cast<ShapedType>(); + return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()] + .template cast<ShapedType>(); } - unsigned getNumLoops() { - return getNumIterators( - cast<ConcreteType>(this->getOperation()).iterator_types()); + + //==========================================================================// + // Other interface methods. + //==========================================================================// + /// Query whether the op has only buffer inputs and no returns. + bool hasBufferSemantics() { + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(getInputs(), + [](Value v) { return v.getType().isa<MemRefType>(); }); } + + //==========================================================================// + // Other static interface methods. + //==========================================================================// static LogicalResult verifyTrait(Operation *op) { - auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputs(); + auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers(); if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands))) return failure(); return success(); diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 2bd4a0d2c0b..f559ba4423d 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -119,7 +119,7 @@ Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer, template <typename ConcreteOp> SmallVector<Value, 8> getViewSizes(ConcreteOp linalgOp) { SmallVector<Value, 8> res; - for (auto v : linalgOp.getInputsAndOutputs()) { + for (auto v : linalgOp.getInputsAndOutputBuffers()) { MemRefType t = v.getType().template cast<MemRefType>(); for (unsigned i = 0; i < t.getRank(); ++i) res.push_back(edsc::intrinsics::dim(v, i)); |