diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 9947c0254a9..9a7f3594551 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1042,10 +1042,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { //===----------------------------------------------------------------------===// static LogicalResult verifyReturn(spirv::ReturnOp returnOp) { - auto funcOp = llvm::dyn_cast<FuncOp>(returnOp.getOperation()->getParentOp()); - if (!funcOp) - return returnOp.emitOpError("must appear in a 'func' op"); - + auto funcOp = llvm::cast<FuncOp>(returnOp.getParentOp()); auto numOutputs = funcOp.getType().getNumResults(); if (numOutputs != 0) return returnOp.emitOpError("cannot be used in functions returning value") @@ -1055,6 +1052,43 @@ static LogicalResult verifyReturn(spirv::ReturnOp returnOp) { } //===----------------------------------------------------------------------===// +// spv.ReturnValue +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnValueOp(OpAsmParser *parser, + OperationState *state) { + OpAsmParser::OperandType retValInfo; + Type retValType; + return failure( + parser->parseOperand(retValInfo) || parser->parseColonType(retValType) || + parser->resolveOperand(retValInfo, retValType, state->operands)); +} + +static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter *printer) { + *printer << spirv::ReturnValueOp::getOperationName() << ' '; + printer->printOperand(retValOp.value()); + *printer << " : " << retValOp.value()->getType(); +} + +static LogicalResult verify(spirv::ReturnValueOp retValOp) { + auto funcOp = llvm::cast<FuncOp>(retValOp.getParentOp()); + auto numFnResults = funcOp.getType().getNumResults(); + if (numFnResults != 1) + return retValOp.emitOpError( + "returns 1 value but enclosing function requires ") + << numFnResults << " results"; + + auto operandType = retValOp.value()->getType(); + auto fnResultType = funcOp.getType().getResult(0); + if (operandType != fnResultType) + return retValOp.emitOpError(" return value's type (") + << operandType << ") mismatch with function's result type (" + << fnResultType << ")"; + + return success(); +} + +//===----------------------------------------------------------------------===// // spv.StoreOp //===----------------------------------------------------------------------===// |