diff options
Diffstat (limited to 'mlir/include')
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 178 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h | 167 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 2 |
3 files changed, 217 insertions, 130 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index d8c657c7209..44735a0b672 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -35,38 +35,9 @@ def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">; // interface. def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let methods = [ - InterfaceMethod< - "Query the number of inputs from the current operation.", - "unsigned", "getNumInputs" - >, - InterfaceMethod< - "Query the number of outputs from the current operation.", - "unsigned", "getNumOutputs" - >, - InterfaceMethod< - "Query the number of inputs and outputs from the current operation.", - "unsigned", "getNumInputsAndOutputs" - >, - InterfaceMethod< - "Query the input operands from the current operation.", - "Operation::operand_range", "getInputs" - >, - InterfaceMethod< - "Query the output operands from the current operation.", - "Operation::operand_range", "getOutputs" - >, - InterfaceMethod< - "Query the input and output operands from the current operation.", - "Operation::operand_range", "getInputsAndOutputs" - >, - InterfaceMethod< - "Query the iterator types attribute within the current operation.", - "ArrayAttr", "iterator_types" - >, - InterfaceMethod< - "Query the indexing maps attribute within the current operation.", - "ArrayAttr", "indexing_maps" - >, + //========================================================================// + // Loop types handling. + //========================================================================// InterfaceMethod< "Query the number of parallel loops within the current operation.", "unsigned", "getNumParallelLoops" @@ -82,40 +53,98 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { InterfaceMethod< "Query the number of loops within the current operation.", "unsigned", "getNumLoops">, + + //========================================================================// + // Input arguments handling. + //========================================================================// + InterfaceMethod< + "Query the number of inputs from the current operation.", + "unsigned", "getNumInputs" + >, InterfaceMethod<"Query the input view at the given index.", "Value ", "getInput", (ins "unsigned":$i) >, - InterfaceMethod<"Query the output view at the given index.", - "Value ", "getOutput", (ins "unsigned":$i) - >, InterfaceMethod<[{ Return the index of the given input value `v`, or `None` if the value is not an input. }], "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v) >, - InterfaceMethod<[{ - Query the index of the given view value, or `None` if the value is not - a view. - }], - "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view) + InterfaceMethod< + "Query the input operands from the current operation.", + "Operation::operand_range", "getInputs" >, InterfaceMethod<[{ Query the type of the input shape at the given index. }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ - Query the type of the output view at the given index. - }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>, - InterfaceMethod<[{ - Query whether the op has only MemRef input and outputs. - }], "bool", "hasBufferSemantics">, - InterfaceMethod<[{ Query the subset of input operands that are of ranked tensor type. }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">, + + + //========================================================================// + // Output arguments handling. + //========================================================================// + InterfaceMethod< + "Query the number of outputs from the current operation.", + "unsigned", "getNumOutputs" + >, + InterfaceMethod<"Query the output buffer at the given index.", + "Value ", "getOutputBuffer", (ins "unsigned":$i) + >, InterfaceMethod<[{ - Query the subset of output operands that are of ranked tensor type. + Query the index of the given buffer value, or `None` if the value is not + part of the output buffers. + }], + "llvm::Optional<unsigned>", "getIndexOfOutputBuffer", (ins "Value ":$view) + >, + InterfaceMethod<[{ + Query the type of the output buffer at the given index. + }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Query the results that are of ranked tensor type. }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">, + InterfaceMethod< + "Query the output buffers (operands) from the current operation.", + "Operation::operand_range", "getOutputBuffers" + >, + //========================================================================// + // Input and Output arguments handling. + //========================================================================// + InterfaceMethod< + "Return the number of inputs and outputs, irrespective of their buffer " + "or tensor type.", + "unsigned", "getNumInputsAndOutputs" + >, + InterfaceMethod< + "Return the number of inputs, irrespective of their buffer or tensor " + "type, and output buffers", + "unsigned", "getNumInputsAndOutputBuffers" + >, + InterfaceMethod< + "Return the range over inputs (irrespective of type) and output buffers.", + "Operation::operand_range", "getInputsAndOutputBuffers" + >, + + //========================================================================// + // Other interface methods. + //========================================================================// + InterfaceMethod< + "Query the iterator types attribute within the current operation.", + "ArrayAttr", "iterator_types" + >, + InterfaceMethod< + "Query the indexing maps attribute within the current operation.", + "ArrayAttr", "indexing_maps" + >, + InterfaceMethod<[{ + Query whether the op has only MemRef input and outputs. + }], "bool", "hasBufferSemantics">, + + //========================================================================// + // Other static interface methods. + //========================================================================// StaticInterfaceMethod<[{ Create an operation of the current type with the given location, operands, and attributes. @@ -128,9 +157,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { attributes); }] >, - - /// Clone an operation with the given location and operands. This is used to - /// abstract away the optional underlying region creation. InterfaceMethod<[{ Clone the current operation with the given location and operands. This is used to abstract away the optional underlying region creation. @@ -536,22 +562,26 @@ def GenericOp : GenericOpBase<"generic"> { mixing input and output ranked tensor values with input and output memrefs. ```mlir - %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} : + %C = linalg.generic #trait_attribute %A, %B {other-attributes} : tensor<?x?xf32>, - memref<?x?xf32, stride_specification>, - tensor<?x?xf32> + memref<?x?xf32, stride_specification> -> (tensor<?x?xf32>) ``` - In this case, the number of return values must match the number of output - tensor arguments. The semantics is that the `linalg.generic` op - produces (i.e. allocates and fills) its return values. + In this case, the number of outputs (args_out) must match the sum of (1) the + number of output buffer operands and (2) the number of tensor return values. + The semantics is that the `linalg.indexed_generic` op produces (i.e. + allocates and fills) its tensor return values. + Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create - control flow around linalg.generic operations are not expected to mix with - tensors because SSA values do not escape naturally. Still, transformations - and rewrites that take advantage of tensor SSA values are expected to be - useful and will be added in the near future. + transformations can be applied. Such legalization moves tensor return values + into output buffer operands and updates the region arguments accordingly. + + Transformations that create control-flow around linalg.indexed_generic + operations are not expected to work with tensors because SSA values do not + escape naturally. Still, transformations and rewrites that take advantage of + tensor SSA values are expected to be useful and will be added in the near + future. }]; let verifier = [{ return ::verify(*this); }]; } @@ -659,22 +689,26 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { memrefs. ```mlir - %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} + %C = linalg.indexed_generic #trait_attribute %A, %B {other-attributes} : tensor<?x?xf32>, - memref<?x?xf32, stride_specification>, - tensor<?x?xf32> + memref<?x?xf32, stride_specification> -> (tensor<?x?xf32>) ``` - In this case, the number of return values must match the number of output - tensor arguments. The semantics is that the `linalg.indexed_generic` op - produces (i.e. allocates and fills) its return values. + In this case, the number of outputs (args_out) must match the sum of (1) the + number of output buffer operands and (2) the number of tensor return values. + The semantics is that the `linalg.indexed_generic` op produces (i.e. + allocates and fills) its return values. + Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create - control flow around linalg.generic operations are not expected to mix with - tensors because SSA values do not escape naturally. Still, transformations - and rewrites that take advantage of tensor SSA values are expected to be - useful and will be added in the near future. + transformations can be applied. Such legalization moves tensor return values + into output buffer operands and updates the region argument accordingly. + + Transformations that create control-flow around linalg.indexed_generic + operations are not expected to work with tensors because SSA values do not + escape naturally. Still, transformations and rewrites that take advantage of + tensor SSA values are expected to be useful and will be added in the near + future. }]; let verifier = [{ return ::verify(*this); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h index 7275863c5b6..2284616e457 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -58,16 +58,47 @@ template <typename ConcreteType> class StructuredOpTraits : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> { private: - /// Return the number of inputs. For internal use only. + /// Return the number of inputs, irrespective of their buffer or tensor type. + /// For internal use only. unsigned nInputs() { return cast<ConcreteType>(this->getOperation()).getNumInputs(); } - /// Return the number of outputs. For internal use only. + /// Return the number of outputs, irrespective of their buffer or tensor type. + /// For internal use only. unsigned nOutputs() { return cast<ConcreteType>(this->getOperation()).getNumOutputs(); } public: + //==========================================================================// + // Loop types handling. + //==========================================================================// + unsigned getNumParallelLoops() { + return getNumIterators( + getParallelIteratorTypeName(), + cast<ConcreteType>(this->getOperation()).iterator_types()); + } + unsigned getNumReductionLoops() { + return getNumIterators( + getReductionIteratorTypeName(), + cast<ConcreteType>(this->getOperation()).iterator_types()); + } + unsigned getNumWindowLoops() { + return getNumIterators( + getWindowIteratorTypeName(), + cast<ConcreteType>(this->getOperation()).iterator_types()); + } + unsigned getNumLoops() { + return getNumIterators( + cast<ConcreteType>(this->getOperation()).iterator_types()); + } + + //==========================================================================// + // Input arguments handling. + //==========================================================================// + // The `i^th` input argument is always the `i^th` operand regardless of + // whether we have tensors or buffers. + // /// Return the `i`-th input value. Value getInput(unsigned i) { assert(i < nInputs()); @@ -90,28 +121,6 @@ public: auto range = this->getOperation()->getOperands(); return {range.begin(), range.begin() + nInputs()}; } - /// Return the `i`-th output. - Value getOutput(unsigned i) { - return this->getOperation()->getOperand(nInputs() + i); - } - /// Return the index of `value` in the list of output values if found, - /// llvm::None otherwise. - Optional<unsigned> getIndexOfOutput(Value value) { - auto it = llvm::find(getOutputs(), value); - if (it != getOutputs().end()) - return it - getOutputs().begin(); - return llvm::None; - } - /// Return the `i`-th output buffer type. - ShapedType getOutputShapedType(unsigned i) { - return getOutput(i).getType().template cast<ShapedType>(); - } - /// Query whether the op has only MemRef input and outputs. - bool hasBufferSemantics() { - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(getInputsAndOutputs(), - [](Value v) { return v.getType().isa<MemRefType>(); }); - } /// Query the subset of input operands that are of ranked tensor type. SmallVector<RankedTensorType, 4> getInputTensorTypes() { SmallVector<RankedTensorType, 4> res; @@ -120,53 +129,97 @@ public: res.push_back(t); return res; } - /// Query the subset of output operands that are of ranked tensor type. + + //==========================================================================// + // Output arguments handling. + //==========================================================================// + // The `i^th` output argument is an operand (resp. a return value) iff it is + // a value of buffer type (resp. a return value of tensor type). + + /// Return the `i`-th output, asserts that this is a buffer operand and not + /// a tensor result. + Value getOutputBuffer(unsigned i) { + assert(i + this->getOperation()->getNumResults() < nOutputs() && + "overflowing output buffer index"); + return this->getOperation()->getOperand(nInputs() + i); + } + /// Return the index of `value` in the list of output buffers if found, + /// llvm::None otherwise. + Optional<unsigned> getIndexOfOutputBuffer(Value value) { + auto it = llvm::find(getOutputBuffers(), value); + if (it != getOutputBuffers().end()) + return it - getOutputBuffers().begin(); + return llvm::None; + } + /// Return the `i`-th output buffer type. + MemRefType getOutputBufferType(unsigned i) { + return getOutputBuffer(i).getType().template cast<MemRefType>(); + } + /// Return the `i`-th output shaped type, irrespective of buffer of tensor + /// type. + ShapedType getOutputShapedType(unsigned i) { + return getShapedType(i + nInputs()); + } + /// Query the subset of results that are of ranked tensor type. SmallVector<RankedTensorType, 4> getOutputTensorTypes() { SmallVector<RankedTensorType, 4> res; - for (Type type : getOutputs().getTypes()) - if (auto t = type.template dyn_cast<RankedTensorType>()) - res.push_back(t); + for (Type type : this->getOperation()->getResults().getTypes()) + res.push_back(type.template cast<RankedTensorType>()); return res; } /// Return the range over outputs. - Operation::operand_range getOutputs() { + Operation::operand_range getOutputBuffers() { auto range = this->getOperation()->getOperands(); return {range.begin() + nInputs(), - range.begin() + getNumInputsAndOutputs()}; + range.begin() + getNumInputsAndOutputBuffers()}; } - /// Return the number of inputs and outputs. + + //==========================================================================// + // Input and Output arguments handling. + //==========================================================================// + /// Return the number of inputs and outputs, irrespective of their buffer or + /// tensor type. unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } - /// Return the `i`-th buffer type. - ShapedType getShapedType(unsigned i) { - return (i < nInputs()) ? getInputShapedType(i) - : getOutputShapedType(i - nInputs()); - } - /// Return the range over inputs and outputs. - Operation::operand_range getInputsAndOutputs() { + /// Return the number of inputs, irrespective of their buffer or tensor type, + /// and output buffers. + unsigned getNumInputsAndOutputBuffers() { + assert(this->getOperation()->getNumResults() <= nOutputs()); + return nInputs() + nOutputs() - this->getOperation()->getNumResults(); + } + /// Return the range over inputs (irrespective of type) and output buffers. + Operation::operand_range getInputsAndOutputBuffers() { auto range = this->getOperation()->getOperands(); - return {range.begin(), range.begin() + getNumInputsAndOutputs()}; + return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()}; } - unsigned getNumParallelLoops() { - return getNumIterators( - getParallelIteratorTypeName(), - cast<ConcreteType>(this->getOperation()).iterator_types()); - } - unsigned getNumReductionLoops() { - return getNumIterators( - getReductionIteratorTypeName(), - cast<ConcreteType>(this->getOperation()).iterator_types()); - } - unsigned getNumWindowLoops() { - return getNumIterators( - getWindowIteratorTypeName(), - cast<ConcreteType>(this->getOperation()).iterator_types()); + /// Return the `i`-th shaped type, there are 3 cases: + /// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise + /// 2. if `i < getNumInputsAndOutputBuffers()` then return the + /// `getOutputBufferType(i - nInputs())`; otherwise + /// 3. return the `i - getNumInputsAndOutputBuffers()` result type. + ShapedType getShapedType(unsigned i) { + if (i < nInputs()) + return getInputShapedType(i); + if (i < getNumInputsAndOutputBuffers()) + return getOutputBufferType(i - nInputs()).template cast<ShapedType>(); + return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()] + .template cast<ShapedType>(); } - unsigned getNumLoops() { - return getNumIterators( - cast<ConcreteType>(this->getOperation()).iterator_types()); + + //==========================================================================// + // Other interface methods. + //==========================================================================// + /// Query whether the op has only buffer inputs and no returns. + bool hasBufferSemantics() { + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(getInputs(), + [](Value v) { return v.getType().isa<MemRefType>(); }); } + + //==========================================================================// + // Other static interface methods. + //==========================================================================// static LogicalResult verifyTrait(Operation *op) { - auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputs(); + auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers(); if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands))) return failure(); return success(); diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 2bd4a0d2c0b..f559ba4423d 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -119,7 +119,7 @@ Optional<FusionInfo> fuseProducerOf(OpBuilder &b, LinalgOp consumer, template <typename ConcreteOp> SmallVector<Value, 8> getViewSizes(ConcreteOp linalgOp) { SmallVector<Value, 8> res; - for (auto v : linalgOp.getInputsAndOutputs()) { + for (auto v : linalgOp.getInputsAndOutputBuffers()) { MemRefType t = v.getType().template cast<MemRefType>(); for (unsigned i = 0; i < t.getRank(); ++i) res.push_back(edsc::intrinsics::dim(v, i)); |