diff options
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp')
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 187 |
1 files changed, 91 insertions, 96 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 384d24957f4..7850dd60be1 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -112,19 +112,20 @@ template <typename GenericOpType> static LogicalResult verifyBlockArgs(GenericOpType op, Block &block); template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (block.getNumArguments() != nViews) - return op.emitOpError( - "expected number of block arguments to match number of views"); + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands) + return op.emitOpError("expected number of block arguments to match number " + "of operands"); - for (unsigned i = 0; i < nViews; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + auto nInputViews = op.getNumInputs(); + for (unsigned i = 0; i < nOperands; ++i) { auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(i).getType()) return op.emitOpError("expected block argument ") - << i << " of the same type as elemental type of " + << (i + 1) << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -132,27 +133,28 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (block.getNumArguments() != nViews + nLoops) + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands + nLoops) return op.emitOpError( - "expected number of block arguments to match number of views + " + "expected number of block arguments to match number of operands + " "number of loops"); - for (unsigned i = 0; i < nLoops; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + for (unsigned i = 0; i < nLoops; ++i) if (!block.getArgument(i).getType().isIndex()) return op.emitOpError("expected block argument ") - << i << " to be of IndexType"; - } + << (i + 1) << " to be an index"; - for (unsigned i = 0; i < nViews; ++i) { + for (unsigned i = 0; i < nOperands; ++i) { unsigned memrefArgIndex = i + nLoops; auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(memrefArgIndex).getType()) return op.emitOpError("expected block argument ") - << memrefArgIndex << " of the same type as elemental type of " + << (memrefArgIndex + 1) + << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -160,70 +162,74 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { template <typename GenericOpType> static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType); +template <typename GenericOpType> +LogicalResult verifyFuncArgsGeneric(GenericOpType op, FunctionType funType) { + auto res = verifyFuncArgs(op, funType); + if (failed(res)) + return res; + + auto nInputs = op.getNumInputs(); + auto nOutputs = op.getNumOutputs(); + // linalg.generic output element types are exactly the function results. + for (unsigned idx = 0; idx < nOutputs; ++idx) { + ShapedType shapedType = op.getShapedType(nInputs + idx); + if (funType.getResult(idx) != shapedType.getElementType()) + return op.emitOpError("expected function result ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of output " << (idx + 1); + } + return success(); +} + template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (funType.getNumInputs() != nViews) - return op.emitOpError("expected fun arguments to match number of views"); - if (funType.getNumResults() != op.getNumOutputs()) + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands) return op.emitOpError( - "expected fun results to match number of output views"); - - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(idx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << idx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + "expected function arguments to match number of operands"); + if (funType.getNumResults() != op.getNumOutputs()) + return op.emitOpError("expected function results(") + << funType.getNumResults() << ") to match number of outputs(" + << op.getNumOutputs() << ")"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of operand " << (idx + 1); } + return success(); } template <> LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) { auto nLoops = op.getNumLoops(); - auto nInputViews = op.getNumInputs(); auto nOutputs = op.getNumOutputs(); - auto nViews = op.getNumInputsAndOutputs(); - if (funType.getNumInputs() != nViews + nLoops) - return op.emitOpError( - "expected fun arguments to match number of views + number of loops"); + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands + nLoops) + return op.emitOpError("expected function arguments to match number of " + "loops + number of operands"); if (funType.getNumResults() != nOutputs) return op.emitOpError( - "expected fun results to match number of output views"); - for (unsigned i = 0; i < nLoops; ++i) { + "expected function results to match number of outputs"); + for (unsigned i = 0; i < nLoops; ++i) if (!funType.getInput(i).isIndex()) - return op.emitOpError("expected fun argument ") - << i << " to be of IndexType"; - } - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto funIdx = nLoops + idx; - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(funIdx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << funIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + return op.emitOpError("expected function argument ") + << (i + 1) << " to be an index"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx + nLoops) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + nLoops + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of input " << (idx + 1); } + return success(); } @@ -231,9 +237,11 @@ template <typename GenericOpType> static LogicalResult verifyGenericOp(GenericOpType op) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (nViews != llvm::size(op.views())) - return op.emitOpError("expected exactly ") << nViews << " view operands"; + auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); + if (nInputsAndOutputBuffers != llvm::size(op.views())) + return op.emitOpError("expected exactly ") + << nInputsAndOutputBuffers + << " inputs (tensor or buffer) and output buffer operands"; auto ®ion = op.region(); auto funOp = op.getFunction(); @@ -246,8 +254,8 @@ static LogicalResult verifyGenericOp(GenericOpType op) { } else { if (!funOp || !funOp.getType()) return op.emitOpError( - "expected fun attribute to refer to a defined symbol"); - if (failed(verifyFuncArgs(op, funType))) + "expected function attribute to refer to a defined symbol"); + if (failed(verifyFuncArgsGeneric(op, funType))) return failure(); } @@ -287,22 +295,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) { return op.emitOpError("expected the concatenation of maps in indexing_map " "to be invertible"); - auto outputTensorTypes = op.getOutputTensorTypes(); - if (outputTensorTypes.size() != op.getNumResults()) - return op.emitOpError("expected #output tensor operands (") - << outputTensorTypes.size() << ") to match #results (" - << op.getNumResults() << ")"; - - unsigned index = 0; - for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) { - auto resTy = std::get<0>(it); - auto outOpTy = std::get<1>(it); - if (resTy != outOpTy) - return op.emitOpError("result #") - << index << " must be " << outOpTy << ", but got " << resTy; - ++index; - } - return success(); } @@ -731,17 +723,20 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { template <typename GenericOpType> static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { // The operand number and types must match the view element types. - auto nOutputViews = genericOp.getNumOutputs(); - if (op.getNumOperands() != nOutputViews) - return op.emitOpError("expected ") - << nOutputViews << " operand to match enclosing linalg.generic op"; + auto nOutputs = genericOp.getNumOutputs(); + if (op.getNumOperands() != nOutputs) + return op.emitOpError("expected number of yield values (") + << nOutputs << ") to match the number of operands of the enclosing " + << "linalg.generic op (" << op.getNumOperands() << ")"; - for (unsigned i = 0; i != nOutputViews; ++i) { + for (unsigned i = 0; i != nOutputs; ++i) { auto elementType = genericOp.getOutputShapedType(i).getElementType(); if (op.getOperand(i).getType() != elementType) - return op.emitOpError("type of return operand ") - << i << " (" << op.getOperand(i).getType() - << ") doesn't match view element type (" << elementType << ")"; + return op.emitOpError("type of yield operand ") + << (i + 1) << " (" << op.getOperand(i).getType() + << ") doesn't match " + << "the element type of the enclosing linalg.generic op (" + << elementType << ")"; } return success(); } |