diff options
Diffstat (limited to 'mlir/lib/Dialect')
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 026b9757839..febaf28e4e9 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -113,6 +113,89 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, parser.getNameLoc(), result.operands)); } +// <operation> ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...` +// : signature_type +static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::OperandType, 12> ops; + Type type; + llvm::SMLoc typeLoc; + if (parser.parseOperandList(ops) || + parser.parseOptionalAttributeDict(result.attributes) || + parser.parseColon() || parser.getCurrentLocation(&typeLoc) || + parser.parseType(type)) { + return failure(); + } + + auto signature = type.dyn_cast<FunctionType>(); + if (!signature) { + return parser.emitError( + typeLoc, "expected the type to be the full list of input and output"); + } + + if (signature.getNumResults() != 1) { + return parser.emitError(typeLoc, "expected single result"); + } + + return failure(parser.addTypeToList(signature.getResult(0), result.types) || + parser.resolveOperands(ops, signature.getInputs(), + parser.getNameLoc(), result.operands)); +} + +static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) { + p << op.getOperationName() << " "; + p.printOperands(op.getOperands()); + p.printOptionalAttrDict(op.getAttrs()); + p << " : " + << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()), + op.getType(), op.getContext()); +} + +static LogicalResult verify(MmaOp op) { + auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); + auto f16Ty = LLVM::LLVMType::getHalfTy(dialect); + auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2); + auto f32Ty = LLVM::LLVMType::getFloatTy(dialect); + auto f16x2x4StructTy = LLVM::LLVMType::getStructTy( + dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + auto f32x8StructTy = LLVM::LLVMType::getStructTy( + dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); + + SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(), + op.getOperandTypes().end()); + if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) && + operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty}) { + return op.emitOpError( + "expected operands to be 4 <halfx2>s followed by either " + "4 <halfx2>s or 8 floats"); + } + if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) { + return op.emitOpError("expected result type to be a struct of either 4 " + "<halfx2>s or 8 floats"); + } + + auto alayout = op.getAttrOfType<StringAttr>("alayout"); + auto blayout = op.getAttrOfType<StringAttr>("blayout"); + + if (!(alayout && blayout) || + !(alayout.getValue() == "row" || alayout.getValue() == "col") || + !(blayout.getValue() == "row" || blayout.getValue() == "col")) { + return op.emitOpError( + "alayout and blayout attributes must be set to either " + "\"row\" or \"col\""); + } + + if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty} && + op.getType() == f32x8StructTy && alayout.getValue() == "row" && + blayout.getValue() == "row") { + return success(); + } + return op.emitOpError("unimplemented mma.sync variant"); +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// |

