diff options
| author | Nicolas Vasilache <ntv@google.com> | 2020-01-11 02:22:00 -0500 |
|---|---|---|
| committer | Nicolas Vasilache <ntv@google.com> | 2020-01-14 17:25:28 -0500 |
| commit | f52d71736b10e87b1aa1880b777dc9462a0085ce (patch) | |
| tree | 3eaa824f59037e0b987abd0c39094ec999e04c3c /mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | |
| parent | 8d07f8d98c48ee0a9dca450aaf4e1cabc621ff68 (diff) | |
| download | bcm5719-llvm-f52d71736b10e87b1aa1880b777dc9462a0085ce.tar.gz bcm5719-llvm-f52d71736b10e87b1aa1880b777dc9462a0085ce.zip | |
[mlir][Linalg] Update the semantics, verifier and test for Linalg with tensors.
Summary:
This diff fixes issues with the semantics of linalg.generic on tensors that appeared when converting directly from HLO to linalg.generic.
The changes are self-contained within MLIR and can be captured and tested independently of XLA.
The linalg.generic and indexed_generic are updated to:
To allow progressive lowering from the value world (a.k.a tensor values) to
the buffer world (a.k.a memref values), a linalg.generic op accepts
mixing input and output ranked tensor values with input and output memrefs.
```
%1 = linalg.generic #trait_attribute %A, %B {other-attributes} :
tensor<?x?xf32>,
memref<?x?xf32, stride_specification>
-> (tensor<?x?xf32>)
```
In this case, the number of outputs (args_out) must match the sum of (1) the
number of output buffer operands and (2) the number of tensor return values.
The semantics is that the linalg.indexed_generic op produces (i.e.
allocates and fills) its return values.
Tensor values must be legalized by a buffer allocation pass before most
transformations can be applied. Such legalization moves tensor return values
into output buffer operands and updates the region argument accordingly.
Transformations that create control-flow around linalg.indexed_generic
operations are not expected to mix with tensors because SSA values do not
escape naturally. Still, transformations and rewrites that take advantage of
tensor SSA values are expected to be useful and will be added in the near
future.
Subscribers: bmahjour, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72555
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 187 |
1 files changed, 91 insertions, 96 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 384d24957f4..7850dd60be1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -112,19 +112,20 @@ template <typename GenericOpType> static LogicalResult verifyBlockArgs(GenericOpType op, Block &block); template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (block.getNumArguments() != nViews) - return op.emitOpError( - "expected number of block arguments to match number of views"); + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands) + return op.emitOpError("expected number of block arguments to match number " + "of operands"); - for (unsigned i = 0; i < nViews; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + auto nInputViews = op.getNumInputs(); + for (unsigned i = 0; i < nOperands; ++i) { auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(i).getType()) return op.emitOpError("expected block argument ") - << i << " of the same type as elemental type of " + << (i + 1) << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -132,27 +133,28 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (block.getNumArguments() != nViews + nLoops) + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands + nLoops) return op.emitOpError( - "expected number of block arguments to match number of views + " + "expected number of block arguments to match number of operands + " "number of loops"); - for (unsigned i = 0; i < nLoops; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + for (unsigned i = 0; i < nLoops; ++i) if (!block.getArgument(i).getType().isIndex()) return op.emitOpError("expected block argument ") - << i << " to be of IndexType"; - } + << (i + 1) << " to be an index"; - for (unsigned i = 0; i < nViews; ++i) { + for (unsigned i = 0; i < nOperands; ++i) { unsigned memrefArgIndex = i + nLoops; auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(memrefArgIndex).getType()) return op.emitOpError("expected block argument ") - << memrefArgIndex << " of the same type as elemental type of " + << (memrefArgIndex + 1) + << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -160,70 +162,74 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { template <typename GenericOpType> static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType); +template <typename GenericOpType> +LogicalResult verifyFuncArgsGeneric(GenericOpType op, FunctionType funType) { + auto res = verifyFuncArgs(op, funType); + if (failed(res)) + return res; + + auto nInputs = op.getNumInputs(); + auto nOutputs = op.getNumOutputs(); + // linalg.generic output element types are exactly the function results. + for (unsigned idx = 0; idx < nOutputs; ++idx) { + ShapedType shapedType = op.getShapedType(nInputs + idx); + if (funType.getResult(idx) != shapedType.getElementType()) + return op.emitOpError("expected function result ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of output " << (idx + 1); + } + return success(); +} + template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (funType.getNumInputs() != nViews) - return op.emitOpError("expected fun arguments to match number of views"); - if (funType.getNumResults() != op.getNumOutputs()) + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands) return op.emitOpError( - "expected fun results to match number of output views"); - - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(idx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << idx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + "expected function arguments to match number of operands"); + if (funType.getNumResults() != op.getNumOutputs()) + return op.emitOpError("expected function results(") + << funType.getNumResults() << ") to match number of outputs(" + << op.getNumOutputs() << ")"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of operand " << (idx + 1); } + return success(); } template <> LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) { auto nLoops = op.getNumLoops(); - auto nInputViews = op.getNumInputs(); auto nOutputs = op.getNumOutputs(); - auto nViews = op.getNumInputsAndOutputs(); - if (funType.getNumInputs() != nViews + nLoops) - return op.emitOpError( - "expected fun arguments to match number of views + number of loops"); + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands + nLoops) + return op.emitOpError("expected function arguments to match number of " + "loops + number of operands"); if (funType.getNumResults() != nOutputs) return op.emitOpError( - "expected fun results to match number of output views"); - for (unsigned i = 0; i < nLoops; ++i) { + "expected function results to match number of outputs"); + for (unsigned i = 0; i < nLoops; ++i) if (!funType.getInput(i).isIndex()) - return op.emitOpError("expected fun argument ") - << i << " to be of IndexType"; - } - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto funIdx = nLoops + idx; - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(funIdx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << funIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + return op.emitOpError("expected function argument ") + << (i + 1) << " to be an index"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx + nLoops) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + nLoops + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of input " << (idx + 1); } + return success(); } @@ -231,9 +237,11 @@ template <typename GenericOpType> static LogicalResult verifyGenericOp(GenericOpType op) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (nViews != llvm::size(op.views())) - return op.emitOpError("expected exactly ") << nViews << " view operands"; + auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); + if (nInputsAndOutputBuffers != llvm::size(op.views())) + return op.emitOpError("expected exactly ") + << nInputsAndOutputBuffers + << " inputs (tensor or buffer) and output buffer operands"; auto ®ion = op.region(); auto funOp = op.getFunction(); @@ -246,8 +254,8 @@ static LogicalResult verifyGenericOp(GenericOpType op) { } else { if (!funOp || !funOp.getType()) return op.emitOpError( - "expected fun attribute to refer to a defined symbol"); - if (failed(verifyFuncArgs(op, funType))) + "expected function attribute to refer to a defined symbol"); + if (failed(verifyFuncArgsGeneric(op, funType))) return failure(); } @@ -287,22 +295,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) { return op.emitOpError("expected the concatenation of maps in indexing_map " "to be invertible"); - auto outputTensorTypes = op.getOutputTensorTypes(); - if (outputTensorTypes.size() != op.getNumResults()) - return op.emitOpError("expected #output tensor operands (") - << outputTensorTypes.size() << ") to match #results (" - << op.getNumResults() << ")"; - - unsigned index = 0; - for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) { - auto resTy = std::get<0>(it); - auto outOpTy = std::get<1>(it); - if (resTy != outOpTy) - return op.emitOpError("result #") - << index << " must be " << outOpTy << ", but got " << resTy; - ++index; - } - return success(); } @@ -731,17 +723,20 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { template <typename GenericOpType> static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { // The operand number and types must match the view element types. - auto nOutputViews = genericOp.getNumOutputs(); - if (op.getNumOperands() != nOutputViews) - return op.emitOpError("expected ") - << nOutputViews << " operand to match enclosing linalg.generic op"; + auto nOutputs = genericOp.getNumOutputs(); + if (op.getNumOperands() != nOutputs) + return op.emitOpError("expected number of yield values (") + << nOutputs << ") to match the number of operands of the enclosing " + << "linalg.generic op (" << op.getNumOperands() << ")"; - for (unsigned i = 0; i != nOutputViews; ++i) { + for (unsigned i = 0; i != nOutputs; ++i) { auto elementType = genericOp.getOutputShapedType(i).getElementType(); if (op.getOperand(i).getType() != elementType) - return op.emitOpError("type of return operand ") - << i << " (" << op.getOperand(i).getType() - << ") doesn't match view element type (" << elementType << ")"; + return op.emitOpError("type of yield operand ") + << (i + 1) << " (" << op.getOperand(i).getType() + << ") doesn't match " + << "the element type of the enclosing linalg.generic op (" + << elementType << ")"; } return success(); } |

