diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 40096919c4b..f1fc80b4cc1 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1070,6 +1070,73 @@ static LogicalResult verify(spirv::BranchConditionalOp branchOp) { } //===----------------------------------------------------------------------===// +// spv.CompositeConstruct +//===----------------------------------------------------------------------===// + +static ParseResult parseCompositeConstructOp(OpAsmParser &parser, + OperationState &state) { + SmallVector<OpAsmParser::OperandType, 4> operands; + Type type; + auto loc = parser.getCurrentLocation(); + + if (parser.parseOperandList(operands) || parser.parseColonType(type)) { + return failure(); + } + auto cType = type.dyn_cast<spirv::CompositeType>(); + if (!cType) { + return parser.emitError( + loc, "result type must be a composite type, but provided ") + << type; + } + + if (operands.size() != cType.getNumElements()) { + return parser.emitError(loc, "has incorrect number of operands: expected ") + << cType.getNumElements() << ", but provided " << operands.size(); + } + // TODO: Add support for constructing a vector type from the vector operands. + // According to the spec: "for constructing a vector, the operands may + // also be vectors with the same component type as the Result Type component + // type". + SmallVector<Type, 4> elementTypes; + elementTypes.reserve(cType.getNumElements()); + for (auto index : llvm::seq<uint32_t>(0, cType.getNumElements())) { + elementTypes.push_back(cType.getElementType(index)); + } + state.addTypes(type); + return parser.resolveOperands(operands, elementTypes, loc, state.operands); +} + +static void print(spirv::CompositeConstructOp compositeConstructOp, + OpAsmPrinter &printer) { + printer << spirv::CompositeConstructOp::getOperationName() << " "; + printer.printOperands(compositeConstructOp.constituents()); + printer << " : " << compositeConstructOp.getResult()->getType(); +} + +static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { + auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>(); + + SmallVector<Value *, 4> constituents(compositeConstructOp.constituents()); + if (constituents.size() != cType.getNumElements()) { + return compositeConstructOp.emitError( + "has incorrect number of operands: expected ") + << cType.getNumElements() << ", but provided " + << constituents.size(); + } + + for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { + if (constituents[index]->getType() != cType.getElementType(index)) { + return compositeConstructOp.emitError( + "operand type mismatch: expected operand type ") + << cType.getElementType(index) << ", but provided " + << constituents[index]->getType(); + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===// |

