summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp42
1 files changed, 38 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 9947c0254a9..9a7f3594551 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1042,10 +1042,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
//===----------------------------------------------------------------------===//
static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
- auto funcOp = llvm::dyn_cast<FuncOp>(returnOp.getOperation()->getParentOp());
- if (!funcOp)
- return returnOp.emitOpError("must appear in a 'func' op");
-
+ auto funcOp = llvm::cast<FuncOp>(returnOp.getParentOp());
auto numOutputs = funcOp.getType().getNumResults();
if (numOutputs != 0)
return returnOp.emitOpError("cannot be used in functions returning value")
@@ -1055,6 +1052,43 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) {
}
//===----------------------------------------------------------------------===//
+// spv.ReturnValue
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseReturnValueOp(OpAsmParser *parser,
+ OperationState *state) {
+ OpAsmParser::OperandType retValInfo;
+ Type retValType;
+ return failure(
+ parser->parseOperand(retValInfo) || parser->parseColonType(retValType) ||
+ parser->resolveOperand(retValInfo, retValType, state->operands));
+}
+
+static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) {
+ *printer << spirv::ReturnValueOp::getOperationName() << ' ';
+ printer->printOperand(retValOp.value());
+ *printer << " : " << retValOp.value()->getType();
+}
+
+static LogicalResult verify(spirv::ReturnValueOp retValOp) {
+ auto funcOp = llvm::cast<FuncOp>(retValOp.getParentOp());
+ auto numFnResults = funcOp.getType().getNumResults();
+ if (numFnResults != 1)
+ return retValOp.emitOpError(
+ "returns 1 value but enclosing function requires ")
+ << numFnResults << " results";
+
+ auto operandType = retValOp.value()->getType();
+ auto fnResultType = funcOp.getType().getResult(0);
+ if (operandType != fnResultType)
+ return retValOp.emitOpError(" return value's type (")
+ << operandType << ") mismatch with function's result type ("
+ << fnResultType << ")";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.StoreOp
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud