summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp83
1 files changed, 83 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 026b9757839..febaf28e4e9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -113,6 +113,89 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
parser.getNameLoc(), result.operands));
}
+// <operation> ::= `llvm.nvvm.mma.sync %lhs... %rhs... %acc...`
+// : signature_type
+static ParseResult parseNVVMMmaOp(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 12> ops;
+ Type type;
+ llvm::SMLoc typeLoc;
+ if (parser.parseOperandList(ops) ||
+ parser.parseOptionalAttributeDict(result.attributes) ||
+ parser.parseColon() || parser.getCurrentLocation(&typeLoc) ||
+ parser.parseType(type)) {
+ return failure();
+ }
+
+ auto signature = type.dyn_cast<FunctionType>();
+ if (!signature) {
+ return parser.emitError(
+ typeLoc, "expected the type to be the full list of input and output");
+ }
+
+ if (signature.getNumResults() != 1) {
+ return parser.emitError(typeLoc, "expected single result");
+ }
+
+ return failure(parser.addTypeToList(signature.getResult(0), result.types) ||
+ parser.resolveOperands(ops, signature.getInputs(),
+ parser.getNameLoc(), result.operands));
+}
+
+static void printNVVMMmaOp(OpAsmPrinter &p, MmaOp &op) {
+ p << op.getOperationName() << " ";
+ p.printOperands(op.getOperands());
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : "
+ << FunctionType::get(llvm::to_vector<12>(op.getOperandTypes()),
+ op.getType(), op.getContext());
+}
+
+static LogicalResult verify(MmaOp op) {
+ auto dialect = op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+ auto f16Ty = LLVM::LLVMType::getHalfTy(dialect);
+ auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
+ auto f32Ty = LLVM::LLVMType::getFloatTy(dialect);
+ auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
+ dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
+ auto f32x8StructTy = LLVM::LLVMType::getStructTy(
+ dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
+
+ SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
+ op.getOperandTypes().end());
+ if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
+ operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty}) {
+ return op.emitOpError(
+ "expected operands to be 4 <halfx2>s followed by either "
+ "4 <halfx2>s or 8 floats");
+ }
+ if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
+ return op.emitOpError("expected result type to be a struct of either 4 "
+ "<halfx2>s or 8 floats");
+ }
+
+ auto alayout = op.getAttrOfType<StringAttr>("alayout");
+ auto blayout = op.getAttrOfType<StringAttr>("blayout");
+
+ if (!(alayout && blayout) ||
+ !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
+ !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
+ return op.emitOpError(
+ "alayout and blayout attributes must be set to either "
+ "\"row\" or \"col\"");
+ }
+
+ if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
+ f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
+ f32Ty, f32Ty, f32Ty} &&
+ op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
+ blayout.getValue() == "row") {
+ return success();
+ }
+ return op.emitOpError("unimplemented mma.sync variant");
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud