diff options
| -rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 30 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td | 52 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 67 | ||||
| -rw-r--r-- | mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir | 7 | ||||
| -rw-r--r-- | mlir/test/Dialect/SPIRV/composite-ops.mlir | 52 |
5 files changed, 193 insertions, 15 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 62095a518e9..8368a626ffc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -1075,6 +1075,7 @@ def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>; def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>; def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>; def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>; +def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>; def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; @@ -1171,20 +1172,21 @@ def SPV_OpcodeAttr : SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, - SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, - SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, - SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, - SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, - SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, - SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, - SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, - SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, - SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct, + SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU, + SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, + SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, + SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, + SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, + SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, + SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, + SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, + SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, + SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td index 71650504741..6392a1b52e8 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td @@ -25,6 +25,58 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td" +// ----- + +def SPV_CompositeConstructOp : SPV_Op<"CompositeConstruct", [NoSideEffect]> { + let summary = [{ + Construct a new composite object from a set of constituent objects that + will fully form it. + }]; + + let description = [{ + Result Type must be a composite type, whose top-level + members/elements/components/columns have the same type as the types of + the operands, with one exception. The exception is that for constructing + a vector, the operands may also be vectors with the same component type + as the Result Type component type. When constructing a vector, the total + number of components in all the operands must equal the number of + components in Result Type. + + Constituents will become members of a structure, or elements of an + array, or components of a vector, or columns of a matrix. There must be + exactly one Constituent for each top-level + member/element/component/column of the result, with one exception. The + exception is that for constructing a vector, a contiguous subset of the + scalars consumed can be represented by a vector operand instead. The + Constituents must appear in the order needed by the definition of the + type of the result. When constructing a vector, there must be at least + two Constituent operands. + + ### Custom assembly form + + ``` {.ebnf} + composite-construct-op ::= ssa-id `=` `spv.CompositeConstruct` + (ssa-use (`,` ssa-use)* )? `:` composite-type + ``` + + For example: + + ``` + %0 = spv.CompositeConstruct %1, %2, %3 : vector<3xf32> + ``` + }]; + + let arguments = (ins + Variadic<SPV_Type>:$constituents + ); + + let results = (outs + SPV_Composite:$result + ); +} + +// ----- + def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { let summary = "Extract a part of a composite object."; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 40096919c4b..f1fc80b4cc1 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1070,6 +1070,73 @@ static LogicalResult verify(spirv::BranchConditionalOp branchOp) { } //===----------------------------------------------------------------------===// +// spv.CompositeConstruct +//===----------------------------------------------------------------------===// + +static ParseResult parseCompositeConstructOp(OpAsmParser &parser, + OperationState &state) { + SmallVector<OpAsmParser::OperandType, 4> operands; + Type type; + auto loc = parser.getCurrentLocation(); + + if (parser.parseOperandList(operands) || parser.parseColonType(type)) { + return failure(); + } + auto cType = type.dyn_cast<spirv::CompositeType>(); + if (!cType) { + return parser.emitError( + loc, "result type must be a composite type, but provided ") + << type; + } + + if (operands.size() != cType.getNumElements()) { + return parser.emitError(loc, "has incorrect number of operands: expected ") + << cType.getNumElements() << ", but provided " << operands.size(); + } + // TODO: Add support for constructing a vector type from the vector operands. + // According to the spec: "for constructing a vector, the operands may + // also be vectors with the same component type as the Result Type component + // type". + SmallVector<Type, 4> elementTypes; + elementTypes.reserve(cType.getNumElements()); + for (auto index : llvm::seq<uint32_t>(0, cType.getNumElements())) { + elementTypes.push_back(cType.getElementType(index)); + } + state.addTypes(type); + return parser.resolveOperands(operands, elementTypes, loc, state.operands); +} + +static void print(spirv::CompositeConstructOp compositeConstructOp, + OpAsmPrinter &printer) { + printer << spirv::CompositeConstructOp::getOperationName() << " "; + printer.printOperands(compositeConstructOp.constituents()); + printer << " : " << compositeConstructOp.getResult()->getType(); +} + +static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { + auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>(); + + SmallVector<Value *, 4> constituents(compositeConstructOp.constituents()); + if (constituents.size() != cType.getNumElements()) { + return compositeConstructOp.emitError( + "has incorrect number of operands: expected ") + << cType.getNumElements() << ", but provided " + << constituents.size(); + } + + for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { + if (constituents[index]->getType() != cType.getElementType(index)) { + return compositeConstructOp.emitError( + "operand type mismatch: expected operand type ") + << cType.getElementType(index) << ", but provided " + << constituents[index]->getType(); + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir index a3f74ca02cd..ba01cc84f0c 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir @@ -2,8 +2,13 @@ spv.module "Logical" "GLSL450" { func @composite_insert(%arg0 : !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>>, %arg1: !spv.array<4xf32>) -> !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> { - // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<f32, !spv.struct<!spv.array<4 x f32>, f32>> + // CHECK: spv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32, 0 : i32] : !spv.array<4 x f32> into !spv.struct<f32, !spv.struct<!spv.array<4 x f32>, f32>> %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> spv.ReturnValue %0: !spv.struct<f32, !spv.struct<!spv.array<4xf32>, f32>> } + func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> { + // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32> + %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32> + spv.ReturnValue %0: vector<3xf32> + } } diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir index 353080c8cc2..4ce89748a09 100644 --- a/mlir/test/Dialect/SPIRV/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir @@ -1,6 +1,58 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// +// spv.CompositeConstruct +//===----------------------------------------------------------------------===// + +func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> { + // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32> + %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32> + return %0: vector<3xf32> +} + +// ----- + +func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct<f32>) -> !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>> { + // CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<vector<3xf32>, !spv.array<4 x f32>, !spv.struct<f32>> + %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>> + return %0: !spv.struct<vector<3xf32>, !spv.array<4xf32>, !spv.struct<f32>> +} + +// ----- + +func @composite_construct_empty_struct() -> !spv.struct<> { + // CHECK: spv.CompositeConstruct : !spv.struct<> + %0 = spv.CompositeConstruct : !spv.struct<> + return %0: !spv.struct<> +} + +// ----- + +func @composite_construct_invalid_num_of_elements(%arg0: f32) -> f32 { + // expected-error @+1 {{result type must be a composite type, but provided 'f32'}} + %0 = spv.CompositeConstruct %arg0 : f32 + return %0: f32 +} + +// ----- + +func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> { + // expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}} + %0 = spv.CompositeConstruct %arg0, %arg2 : vector<3xf32> + return %0: vector<3xf32> +} + +// ----- + +func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> { + // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}} + %0 = "spv.CompositeConstruct" (%arg0, %arg1, %arg2) : (f32, f32, f32) -> vector<3xi32> + return %0: vector<3xi32> +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===// |

