summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td16
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td36
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td8
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp42
-rw-r--r--mlir/test/Dialect/SPIRV/Serialization/terminator.mlir21
-rw-r--r--mlir/test/Dialect/SPIRV/ops.mlir41
-rw-r--r--mlir/test/Dialect/SPIRV/structure-ops.mlir14
7 files changed, 158 insertions, 20 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 8f07fecb9f0..cf87bfd90cd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -128,6 +128,7 @@ def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>;
def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
+def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -146,7 +147,7 @@ def SPV_OpcodeAttr :
SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
- SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn
+ SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn, SPV_OC_OpReturnValue
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
@@ -778,12 +779,15 @@ def SPV_SamplerUseAttr:
// SPIR-V OpTrait definitions
//===----------------------------------------------------------------------===//
-// Check that an op can only be used with SPIR-V ModuleOp
-def IsModuleOnlyPred :
- CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">;
+// Check that an op can only be used within the scope of a FuncOp.
+def InFunctionScope : PredOpTrait<
+ "op must appear in a 'func' block",
+ CPred<"llvm::isa_and_nonnull<FuncOp>($_op.getParentOp())">>;
-def ModuleOnly :
- PredOpTrait<"op can only be used in a 'spv.module' block", IsModuleOnlyPred>;
+// Check that an op can only be used within the scope of a SPIR-V ModuleOp.
+def InModuleScope : PredOpTrait<
+ "op must appear in a 'spv.module' block",
+ CPred<"llvm::isa_and_nonnull<spirv::ModuleOp>($_op.getParentOp())">>;
//===----------------------------------------------------------------------===//
// SPIR-V op definitions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index de496a76d26..76bffde38df 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -146,7 +146,7 @@ def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
// -----
-def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
+def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
let summary = "Declare an execution mode for an entry point.";
let description = [{
@@ -599,7 +599,7 @@ def SPV_LoadOp : SPV_Op<"Load", []> {
// -----
-def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
+def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
let summary = "Return with no value from a function with void return type.";
let description = [{
@@ -624,6 +624,38 @@ def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
// -----
+def SPV_ReturnValueOp : SPV_Op<"ReturnValue", [InFunctionScope, Terminator]> {
+ let summary = "Return a value from a function.";
+
+ let description = [{
+ Value is the value returned, by copy, and must match the Return Type
+ operand of the OpTypeFunction type of the OpFunction body this return
+ instruction is in.
+
+ This instruction must be the last instruction in a block.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ return-value-op ::= `spv.ReturnValue` ssa-use `:` spirv-type
+ ```
+
+ For example:
+
+ ```
+ spv.ReturnValue %0 : f32
+ ```
+ }];
+
+ let arguments = (ins
+ SPV_Type:$value
+ );
+
+ let results = (outs);
+}
+
+// -----
+
def SPV_SDivOp : SPV_ArithmeticOp<"SDiv", SPV_Integer> {
let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index d4756390742..292e148c86f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -30,7 +30,7 @@
include "mlir/SPIRV/SPIRVBase.td"
#endif // SPIRV_BASE
-def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> {
+def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
let summary = "Get the address of a global variable.";
let description = [{
@@ -66,7 +66,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> {
let hasOpcode = 0;
}
-def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
+def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
let summary = [{
Declare an entry point, its execution model, and its interface.
}];
@@ -122,7 +122,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
}
-def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [ModuleOnly]> {
+def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope]> {
let summary = [{
Allocate an object in memory at module scope. The object is
referenced using a symbol name.
@@ -264,7 +264,7 @@ def SPV_ModuleOp : SPV_Op<"module",
}];
}
-def SPV_ModuleEndOp : SPV_Op<"_module_end", [Terminator, ModuleOnly]> {
+def SPV_ModuleEndOp : SPV_Op<"_module_end", [InModuleScope, Terminator]> {
let summary = "The pseudo op that ends a SPIR-V module";
let description = [{
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 9947c0254a9..9a7f3594551 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1042,10 +1042,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
//===----------------------------------------------------------------------===//
static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
- auto funcOp = llvm::dyn_cast<FuncOp>(returnOp.getOperation()->getParentOp());
- if (!funcOp)
- return returnOp.emitOpError("must appear in a 'func' op");
-
+ auto funcOp = llvm::cast<FuncOp>(returnOp.getParentOp());
auto numOutputs = funcOp.getType().getNumResults();
if (numOutputs != 0)
return returnOp.emitOpError("cannot be used in functions returning value")
@@ -1055,6 +1052,43 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
}
//===----------------------------------------------------------------------===//
+// spv.ReturnValue
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseReturnValueOp(OpAsmParser *parser,
+ OperationState *state) {
+ OpAsmParser::OperandType retValInfo;
+ Type retValType;
+ return failure(
+ parser->parseOperand(retValInfo) || parser->parseColonType(retValType) ||
+ parser->resolveOperand(retValInfo, retValType, state->operands));
+}
+
+static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) {
+ *printer << spirv::ReturnValueOp::getOperationName() << ' ';
+ printer->printOperand(retValOp.value());
+ *printer << " : " << retValOp.value()->getType();
+}
+
+static LogicalResult verify(spirv::ReturnValueOp retValOp) {
+ auto funcOp = llvm::cast<FuncOp>(retValOp.getParentOp());
+ auto numFnResults = funcOp.getType().getNumResults();
+ if (numFnResults != 1)
+ return retValOp.emitOpError(
+ "returns 1 value but enclosing function requires ")
+ << numFnResults << " results";
+
+ auto operandType = retValOp.value()->getType();
+ auto fnResultType = funcOp.getType().getResult(0);
+ if (operandType != fnResultType)
+ return retValOp.emitOpError(" return value's type (")
+ << operandType << ") mismatch with function's result type ("
+ << fnResultType << ")";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.StoreOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir b/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir
new file mode 100644
index 00000000000..35d2f972b55
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Serialization/terminator.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+func @spirv_terminator() -> () {
+ spv.module "Logical" "GLSL450" {
+ // CHECK-LABEL: @ret
+ func @ret() -> () {
+ // CHECK: spv.Return
+ spv.Return
+ }
+
+ // CHECK-LABEL: @ret_val
+ func @ret_val() -> (i32) {
+ %0 = spv.Variable : !spv.ptr<i32, Function>
+ %1 = spv.Load "Function" %0 : i32
+ // CHECK: spv.ReturnValue {{.*}} : i32
+ spv.ReturnValue %1 : i32
+ }
+ }
+ return
+}
+
diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
index 052dc687167..167b6d81343 100644
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -327,7 +327,7 @@ spv.module "Logical" "VulkanKHR" {
spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
- // expected-error @+1 {{'spv.EntryPoint' op failed to verify that op can only be used in a 'spv.module' block}}
+ // expected-error @+1 {{op must appear in a 'spv.module' block}}
spv.EntryPoint "GLCompute" @do_something
}
}
@@ -451,7 +451,7 @@ spv.module "Logical" "VulkanKHR" {
spv.module "Logical" "VulkanKHR" {
func @foo() {
- // expected-error @+1 {{op failed to verify that op can only be used in a 'spv.module' block}}
+ // expected-error @+1 {{op must appear in a 'spv.module' block}}
spv.globalVariable !spv.ptr<f32, Input> @var0
spv.Return
}
@@ -767,7 +767,7 @@ spv.module "Logical" "VulkanKHR" {
//===----------------------------------------------------------------------===//
"foo.function"() ({
- // expected-error @+1 {{must appear in a 'func' op}}
+ // expected-error @+1 {{op must appear in a 'func' block}}
spv.Return
}) : () -> ()
@@ -784,6 +784,41 @@ spv.module "Logical" "VulkanKHR" {
// -----
//===----------------------------------------------------------------------===//
+// spv.ReturnValue
+//===----------------------------------------------------------------------===//
+
+func @ret_val() -> (i32) {
+ %0 = spv.constant 42 : i32
+ spv.ReturnValue %0 : i32
+}
+
+// -----
+
+"foo.function"() ({
+ %0 = spv.constant true
+ // expected-error @+1 {{op must appear in a 'func' block}}
+ spv.ReturnValue %0 : i1
+}) : () -> ()
+
+// -----
+
+func @value_count_mismatch() -> () {
+ %0 = spv.constant 42 : i32
+ // expected-error @+1 {{op returns 1 value but enclosing function requires 0 results}}
+ spv.ReturnValue %0 : i32
+}
+
+// -----
+
+func @value_type_mismatch() -> (f32) {
+ %0 = spv.constant 42 : i32
+ // expected-error @+1 {{return value's type ('i32') mismatch with function's result type ('f32')}}
+ spv.ReturnValue %0 : i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spv.SDiv
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index db51b175b03..e398be6656e 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -1,6 +1,18 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
+// spv._address_of
+//===----------------------------------------------------------------------===//
+
+spv.module "Logical" "GLSL450" {
+ spv.globalVariable !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input> @var
+ // expected-error @+1 {{op must appear in a 'func' block}}
+ %1 = spv._address_of @var : !spv.ptr<!spv.struct<f32, !spv.array<4xf32>>, Input>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
@@ -171,6 +183,6 @@ spv.module "Logical" "VulkanKHR" {
//===----------------------------------------------------------------------===//
func @module_end_not_in_module() -> () {
- // expected-error @+1 {{can only be used in a 'spv.module' block}}
+ // expected-error @+1 {{op must appear in a 'spv.module' block}}
spv._module_end
}
OpenPOWER on IntegriCloud