summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp187
1 files changed, 91 insertions, 96 deletions
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 384d24957f4..7850dd60be1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -112,19 +112,20 @@ template <typename GenericOpType>
static LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
- auto nViews = op.getNumInputsAndOutputs();
- auto nInputViews = op.getNumInputs();
- if (block.getNumArguments() != nViews)
- return op.emitOpError(
- "expected number of block arguments to match number of views");
+ auto nOperands = op.getNumOperands();
+ if (block.getNumArguments() != nOperands)
+ return op.emitOpError("expected number of block arguments to match number "
+ "of operands");
- for (unsigned i = 0; i < nViews; ++i) {
+ // Note: the number and type of yield values are checked in the YieldOp.
+ auto nInputViews = op.getNumInputs();
+ for (unsigned i = 0; i < nOperands; ++i) {
auto viewType = op.getShapedType(i);
if (viewType.getElementType() != block.getArgument(i).getType())
return op.emitOpError("expected block argument ")
- << i << " of the same type as elemental type of "
+ << (i + 1) << " of the same type as elemental type of "
<< ((i < nInputViews) ? "input " : "output ")
- << "view: " << viewType;
+ << "operand: " << viewType;
}
return success();
}
@@ -132,27 +133,28 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
auto nInputViews = op.getNumInputs();
auto nLoops = op.getNumLoops();
- auto nViews = op.getNumInputsAndOutputs();
- if (block.getNumArguments() != nViews + nLoops)
+ auto nOperands = op.getNumOperands();
+ if (block.getNumArguments() != nOperands + nLoops)
return op.emitOpError(
- "expected number of block arguments to match number of views + "
+ "expected number of block arguments to match number of operands + "
"number of loops");
- for (unsigned i = 0; i < nLoops; ++i) {
+ // Note: the number and type of yield values are checked in the YieldOp.
+ for (unsigned i = 0; i < nLoops; ++i)
if (!block.getArgument(i).getType().isIndex())
return op.emitOpError("expected block argument ")
- << i << " to be of IndexType";
- }
+ << (i + 1) << " to be an index";
- for (unsigned i = 0; i < nViews; ++i) {
+ for (unsigned i = 0; i < nOperands; ++i) {
unsigned memrefArgIndex = i + nLoops;
auto viewType = op.getShapedType(i);
if (viewType.getElementType() !=
block.getArgument(memrefArgIndex).getType())
return op.emitOpError("expected block argument ")
- << memrefArgIndex << " of the same type as elemental type of "
+ << (memrefArgIndex + 1)
+ << " of the same type as elemental type of "
<< ((i < nInputViews) ? "input " : "output ")
- << "view: " << viewType;
+ << "operand: " << viewType;
}
return success();
}
@@ -160,70 +162,74 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
template <typename GenericOpType>
static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
+template <typename GenericOpType>
+LogicalResult verifyFuncArgsGeneric(GenericOpType op, FunctionType funType) {
+ auto res = verifyFuncArgs(op, funType);
+ if (failed(res))
+ return res;
+
+ auto nInputs = op.getNumInputs();
+ auto nOutputs = op.getNumOutputs();
+ // linalg.generic output element types are exactly the function results.
+ for (unsigned idx = 0; idx < nOutputs; ++idx) {
+ ShapedType shapedType = op.getShapedType(nInputs + idx);
+ if (funType.getResult(idx) != shapedType.getElementType())
+ return op.emitOpError("expected function result ")
+ << (idx + 1) << " of the same type as elemental type "
+ << shapedType.getElementType() << " of output " << (idx + 1);
+ }
+ return success();
+}
+
template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
- auto nViews = op.getNumInputsAndOutputs();
- auto nInputViews = op.getNumInputs();
- if (funType.getNumInputs() != nViews)
- return op.emitOpError("expected fun arguments to match number of views");
- if (funType.getNumResults() != op.getNumOutputs())
+ auto nOperands = op.getNumOperands();
+ if (funType.getNumInputs() != nOperands)
return op.emitOpError(
- "expected fun results to match number of output views");
-
- for (auto en : llvm::enumerate(op.indexing_maps())) {
- auto idx = en.index();
- auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
- : op.getOutputShapedType(idx - nInputViews);
- if (funType.getInput(idx) != view.getElementType())
- return op.emitOpError("expected fun argument ")
- << idx << " of the same type as elemental type "
- << view.getElementType() << " of view " << idx;
-
- if (idx >= nInputViews) {
- auto resultIdx = idx - nInputViews;
- if (funType.getResult(resultIdx) != view.getElementType())
- return op.emitOpError("expected fun result ")
- << resultIdx << " of the same type as elemental type "
- << view.getElementType() << " of view " << idx;
- }
+ "expected function arguments to match number of operands");
+ if (funType.getNumResults() != op.getNumOutputs())
+ return op.emitOpError("expected function results(")
+ << funType.getNumResults() << ") to match number of outputs("
+ << op.getNumOutputs() << ")";
+
+ // linalg.generic operands element types are exactly the first function
+ // arguments.
+ for (unsigned idx = 0; idx < nOperands; ++idx) {
+ ShapedType shapedType = op.getShapedType(idx);
+ if (funType.getInput(idx) != shapedType.getElementType())
+ return op.emitOpError("expected function argument ")
+ << (idx + 1) << " of the same type as elemental type "
+ << shapedType.getElementType() << " of operand " << (idx + 1);
}
+
return success();
}
template <>
LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
auto nLoops = op.getNumLoops();
- auto nInputViews = op.getNumInputs();
auto nOutputs = op.getNumOutputs();
- auto nViews = op.getNumInputsAndOutputs();
- if (funType.getNumInputs() != nViews + nLoops)
- return op.emitOpError(
- "expected fun arguments to match number of views + number of loops");
+ auto nOperands = op.getNumOperands();
+ if (funType.getNumInputs() != nOperands + nLoops)
+ return op.emitOpError("expected function arguments to match number of "
+ "loops + number of operands");
if (funType.getNumResults() != nOutputs)
return op.emitOpError(
- "expected fun results to match number of output views");
- for (unsigned i = 0; i < nLoops; ++i) {
+ "expected function results to match number of outputs");
+ for (unsigned i = 0; i < nLoops; ++i)
if (!funType.getInput(i).isIndex())
- return op.emitOpError("expected fun argument ")
- << i << " to be of IndexType";
- }
- for (auto en : llvm::enumerate(op.indexing_maps())) {
- auto idx = en.index();
- auto funIdx = nLoops + idx;
- auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
- : op.getOutputShapedType(idx - nInputViews);
- if (funType.getInput(funIdx) != view.getElementType())
- return op.emitOpError("expected fun argument ")
- << funIdx << " of the same type as elemental type "
- << view.getElementType() << " of view " << idx;
-
- if (idx >= nInputViews) {
- auto resultIdx = idx - nInputViews;
- if (funType.getResult(resultIdx) != view.getElementType())
- return op.emitOpError("expected fun result ")
- << resultIdx << " of the same type as elemental type "
- << view.getElementType() << " of view " << idx;
- }
+ return op.emitOpError("expected function argument ")
+ << (i + 1) << " to be an index";
+
+ // linalg.generic operands element types are exactly the first function
+ // arguments.
+ for (unsigned idx = 0; idx < nOperands; ++idx) {
+ ShapedType shapedType = op.getShapedType(idx);
+ if (funType.getInput(idx + nLoops) != shapedType.getElementType())
+ return op.emitOpError("expected function argument ")
+ << (idx + nLoops + 1) << " of the same type as elemental type "
+ << shapedType.getElementType() << " of input " << (idx + 1);
}
+
return success();
}
@@ -231,9 +237,11 @@ template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
auto nInputViews = op.getNumInputs();
auto nLoops = op.getNumLoops();
- auto nViews = op.getNumInputsAndOutputs();
- if (nViews != llvm::size(op.views()))
- return op.emitOpError("expected exactly ") << nViews << " view operands";
+ auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+ if (nInputsAndOutputBuffers != llvm::size(op.views()))
+ return op.emitOpError("expected exactly ")
+ << nInputsAndOutputBuffers
+ << " inputs (tensor or buffer) and output buffer operands";
auto &region = op.region();
auto funOp = op.getFunction();
@@ -246,8 +254,8 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
} else {
if (!funOp || !funOp.getType())
return op.emitOpError(
- "expected fun attribute to refer to a defined symbol");
- if (failed(verifyFuncArgs(op, funType)))
+ "expected function attribute to refer to a defined symbol");
+ if (failed(verifyFuncArgsGeneric(op, funType)))
return failure();
}
@@ -287,22 +295,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
return op.emitOpError("expected the concatenation of maps in indexing_map "
"to be invertible");
- auto outputTensorTypes = op.getOutputTensorTypes();
- if (outputTensorTypes.size() != op.getNumResults())
- return op.emitOpError("expected #output tensor operands (")
- << outputTensorTypes.size() << ") to match #results ("
- << op.getNumResults() << ")";
-
- unsigned index = 0;
- for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) {
- auto resTy = std::get<0>(it);
- auto outOpTy = std::get<1>(it);
- if (resTy != outOpTy)
- return op.emitOpError("result #")
- << index << " must be " << outOpTy << ", but got " << resTy;
- ++index;
- }
-
return success();
}
@@ -731,17 +723,20 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
template <typename GenericOpType>
static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
// The operand number and types must match the view element types.
- auto nOutputViews = genericOp.getNumOutputs();
- if (op.getNumOperands() != nOutputViews)
- return op.emitOpError("expected ")
- << nOutputViews << " operand to match enclosing linalg.generic op";
+ auto nOutputs = genericOp.getNumOutputs();
+ if (op.getNumOperands() != nOutputs)
+ return op.emitOpError("expected number of yield values (")
+ << nOutputs << ") to match the number of operands of the enclosing "
+ << "linalg.generic op (" << op.getNumOperands() << ")";
- for (unsigned i = 0; i != nOutputViews; ++i) {
+ for (unsigned i = 0; i != nOutputs; ++i) {
auto elementType = genericOp.getOutputShapedType(i).getElementType();
if (op.getOperand(i).getType() != elementType)
- return op.emitOpError("type of return operand ")
- << i << " (" << op.getOperand(i).getType()
- << ") doesn't match view element type (" << elementType << ")";
+ return op.emitOpError("type of yield operand ")
+ << (i + 1) << " (" << op.getOperand(i).getType()
+ << ") doesn't match "
+ << "the element type of the enclosing linalg.generic op ("
+ << elementType << ")";
}
return success();
}
OpenPOWER on IntegriCloud