Diffstat (limited to 'clang/lib/CodeGen')
-rw-r--r--   clang/lib/CodeGen/CGBuiltin.cpp   543
1 file changed, 333 insertions(+), 210 deletions(-)
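
As a quick orientation before the diff body: the change replaces the long per-builtin switch statements with small lookup helpers, so each load/store path reduces to roughly the following shape (a sketch lifted from the new code below, not additional patch content):

    NVPTXMmaLdstInfo II = getNVPTXMmaLdstInfo(BuiltinID);
    unsigned IID = isColMajor ? II.IID_col : II.IID_row; // 0 means this layout is unsupported
    if (IID == 0)
      return nullptr;
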
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 02969a1908b..cfdb69c0798 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -12925,8 +12925,252 @@ Value *CodeGenFunction::EmitSystemZBuiltinExpr(unsigned BuiltinID,
}
}
-Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
- const CallExpr *E) {
+namespace {
+// Helper classes for mapping MMA builtins to the corresponding LLVM intrinsic variant.
+struct NVPTXMmaLdstInfo {
+ unsigned NumResults; // Number of elements to load/store
+ // Intrinsic IDs for row/col variants. 0 if particular layout is unsupported.
+ unsigned IID_col;
+ unsigned IID_row;
+};
+
+#define MMA_INTR(geom_op_type, layout) \
+ Intrinsic::nvvm_wmma_##geom_op_type##_##layout##_stride
+#define MMA_LDST(n, geom_op_type) \
+ { n, MMA_INTR(geom_op_type, col), MMA_INTR(geom_op_type, row) }
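+// For example, MMA_LDST(8, m16n16k16_load_a_f16) expands to
+//   {8, Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride,
+//       Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride},
+// i.e. an 8-element fragment plus the col-major and row-major intrinsic IDs.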
+
+static NVPTXMmaLdstInfo getNVPTXMmaLdstInfo(unsigned BuiltinID) {
+ switch (BuiltinID) {
+ // FP MMA loads
+ case NVPTX::BI__hmma_m16n16k16_ld_a:
+ return MMA_LDST(8, m16n16k16_load_a_f16);
+ case NVPTX::BI__hmma_m16n16k16_ld_b:
+ return MMA_LDST(8, m16n16k16_load_b_f16);
+ case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
+ return MMA_LDST(4, m16n16k16_load_c_f16);
+ case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
+ return MMA_LDST(8, m16n16k16_load_c_f32);
+ case NVPTX::BI__hmma_m32n8k16_ld_a:
+ return MMA_LDST(8, m32n8k16_load_a_f16);
+ case NVPTX::BI__hmma_m32n8k16_ld_b:
+ return MMA_LDST(8, m32n8k16_load_b_f16);
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
+ return MMA_LDST(4, m32n8k16_load_c_f16);
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
+ return MMA_LDST(8, m32n8k16_load_c_f32);
+ case NVPTX::BI__hmma_m8n32k16_ld_a:
+ return MMA_LDST(8, m8n32k16_load_a_f16);
+ case NVPTX::BI__hmma_m8n32k16_ld_b:
+ return MMA_LDST(8, m8n32k16_load_b_f16);
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
+ return MMA_LDST(4, m8n32k16_load_c_f16);
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
+ return MMA_LDST(8, m8n32k16_load_c_f32);
+
+ // Integer MMA loads
+ case NVPTX::BI__imma_m16n16k16_ld_a_s8:
+ return MMA_LDST(2, m16n16k16_load_a_s8);
+ case NVPTX::BI__imma_m16n16k16_ld_a_u8:
+ return MMA_LDST(2, m16n16k16_load_a_u8);
+ case NVPTX::BI__imma_m16n16k16_ld_b_s8:
+ return MMA_LDST(2, m16n16k16_load_b_s8);
+ case NVPTX::BI__imma_m16n16k16_ld_b_u8:
+ return MMA_LDST(2, m16n16k16_load_b_u8);
+ case NVPTX::BI__imma_m16n16k16_ld_c:
+ return MMA_LDST(8, m16n16k16_load_c_s32);
+ case NVPTX::BI__imma_m32n8k16_ld_a_s8:
+ return MMA_LDST(4, m32n8k16_load_a_s8);
+ case NVPTX::BI__imma_m32n8k16_ld_a_u8:
+ return MMA_LDST(4, m32n8k16_load_a_u8);
+ case NVPTX::BI__imma_m32n8k16_ld_b_s8:
+ return MMA_LDST(1, m32n8k16_load_b_s8);
+ case NVPTX::BI__imma_m32n8k16_ld_b_u8:
+ return MMA_LDST(1, m32n8k16_load_b_u8);
+ case NVPTX::BI__imma_m32n8k16_ld_c:
+ return MMA_LDST(8, m32n8k16_load_c_s32);
+ case NVPTX::BI__imma_m8n32k16_ld_a_s8:
+ return MMA_LDST(1, m8n32k16_load_a_s8);
+ case NVPTX::BI__imma_m8n32k16_ld_a_u8:
+ return MMA_LDST(1, m8n32k16_load_a_u8);
+ case NVPTX::BI__imma_m8n32k16_ld_b_s8:
+ return MMA_LDST(4, m8n32k16_load_b_s8);
+ case NVPTX::BI__imma_m8n32k16_ld_b_u8:
+ return MMA_LDST(4, m8n32k16_load_b_u8);
+ case NVPTX::BI__imma_m8n32k16_ld_c:
+ return MMA_LDST(8, m8n32k16_load_c_s32);
+
+ // Sub-integer MMA loads.
+ // Only row/col layout is supported by A/B fragments.
+ case NVPTX::BI__imma_m8n8k32_ld_a_s4:
+ return {1, 0, MMA_INTR(m8n8k32_load_a_s4, row)};
+ case NVPTX::BI__imma_m8n8k32_ld_a_u4:
+ return {1, 0, MMA_INTR(m8n8k32_load_a_u4, row)};
+ case NVPTX::BI__imma_m8n8k32_ld_b_s4:
+ return {1, MMA_INTR(m8n8k32_load_b_s4, col), 0};
+ case NVPTX::BI__imma_m8n8k32_ld_b_u4:
+ return {1, MMA_INTR(m8n8k32_load_b_u4, col), 0};
+ case NVPTX::BI__imma_m8n8k32_ld_c:
+ return MMA_LDST(2, m8n8k32_load_c_s32);
+ case NVPTX::BI__bmma_m8n8k128_ld_a_b1:
+ return {1, 0, MMA_INTR(m8n8k128_load_a_b1, row)};
+ case NVPTX::BI__bmma_m8n8k128_ld_b_b1:
+ return {1, MMA_INTR(m8n8k128_load_b_b1, col), 0};
+ case NVPTX::BI__bmma_m8n8k128_ld_c:
+ return MMA_LDST(2, m8n8k128_load_c_s32);
+
+ // NOTE: We need to follow the inconsistent naming scheme used by NVCC. Unlike
+ // PTX and LLVM IR, where stores always use fragment D, NVCC builtins always
+ // use fragment C for both loads and stores.
+ // FP MMA stores.
+ case NVPTX::BI__hmma_m16n16k16_st_c_f16:
+ return MMA_LDST(4, m16n16k16_store_d_f16);
+ case NVPTX::BI__hmma_m16n16k16_st_c_f32:
+ return MMA_LDST(8, m16n16k16_store_d_f32);
+ case NVPTX::BI__hmma_m32n8k16_st_c_f16:
+ return MMA_LDST(4, m32n8k16_store_d_f16);
+ case NVPTX::BI__hmma_m32n8k16_st_c_f32:
+ return MMA_LDST(8, m32n8k16_store_d_f32);
+ case NVPTX::BI__hmma_m8n32k16_st_c_f16:
+ return MMA_LDST(4, m8n32k16_store_d_f16);
+ case NVPTX::BI__hmma_m8n32k16_st_c_f32:
+ return MMA_LDST(8, m8n32k16_store_d_f32);
+
+ // Integer and sub-integer MMA stores.
+ // Another naming quirk. Unlike other MMA builtins that use PTX types in the
+ // name, integer loads/stores use LLVM's i32.
+ case NVPTX::BI__imma_m16n16k16_st_c_i32:
+ return MMA_LDST(8, m16n16k16_store_d_s32);
+ case NVPTX::BI__imma_m32n8k16_st_c_i32:
+ return MMA_LDST(8, m32n8k16_store_d_s32);
+ case NVPTX::BI__imma_m8n32k16_st_c_i32:
+ return MMA_LDST(8, m8n32k16_store_d_s32);
+ case NVPTX::BI__imma_m8n8k32_st_c_i32:
+ return MMA_LDST(2, m8n8k32_store_d_s32);
+ case NVPTX::BI__bmma_m8n8k128_st_c_i32:
+ return MMA_LDST(2, m8n8k128_store_d_s32);
+
+ default:
+ llvm_unreachable("Unknown MMA builtin");
+ }
+}
+#undef MMA_LDST
+#undef MMA_INTR
+
+
+struct NVPTXMmaInfo {
+ unsigned NumEltsA;
+ unsigned NumEltsB;
+ unsigned NumEltsC;
+ unsigned NumEltsD;
+ std::array<unsigned, 8> Variants;
+
+ unsigned getMMAIntrinsic(int Layout, bool Satf) {
+ unsigned Index = Layout * 2 + Satf;
+ if (Index >= Variants.size())
+ return 0;
+ return Variants[Index];
+ }
+};
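+// For example, getMMAIntrinsic(/*Layout=*/1, /*Satf=*/true) computes Index 3 and
+// returns the row_col ..._satfinite entry of the MMA_VARIANTS tables below.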
+
+// Returns the MMA info for the given builtin; getMMAIntrinsic() on the result
+// yields the intrinsic that matches Layout and Satf, or 0 for invalid
+// combinations of Layout and Satf.
+static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
+ // clang-format off
+#define MMA_VARIANTS(geom, type) {{ \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite \
+ }}
+// Sub-integer MMA only supports row.col layout.
+#define MMA_VARIANTS_I4(geom, type) {{ \
+ 0, \
+ 0, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
+ 0, \
+ 0, \
+ 0, \
+ 0 \
+ }}
+// b1 MMA does not support .satfinite.
+#define MMA_VARIANTS_B1(geom, type) {{ \
+ 0, \
+ 0, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
+ 0, \
+ 0, \
+ 0, \
+ 0, \
+ 0 \
+ }}
+ // clang-format on
+ switch (BuiltinID) {
+ // FP MMA
+ // Note that the 'type' argument of MMA_VARIANTS uses D_C notation, while the
+ // NumEltsN fields of the return value are ordered as A,B,C,D.
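+ // For example, __hmma_m16n16k16_mma_f32f16 maps to {8, 8, 4, 8, ...}:
+ // 8-element A and B fragments, a 4-element f16 C fragment, and an 8-element
+ // f32 D fragment.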
+ case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
+ return {8, 8, 4, 4, MMA_VARIANTS(m16n16k16, f16_f16)};
+ case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
+ return {8, 8, 4, 8, MMA_VARIANTS(m16n16k16, f32_f16)};
+ case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
+ return {8, 8, 8, 4, MMA_VARIANTS(m16n16k16, f16_f32)};
+ case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
+ return {8, 8, 8, 8, MMA_VARIANTS(m16n16k16, f32_f32)};
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+ return {8, 8, 4, 4, MMA_VARIANTS(m32n8k16, f16_f16)};
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+ return {8, 8, 4, 8, MMA_VARIANTS(m32n8k16, f32_f16)};
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+ return {8, 8, 8, 4, MMA_VARIANTS(m32n8k16, f16_f32)};
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+ return {8, 8, 8, 8, MMA_VARIANTS(m32n8k16, f32_f32)};
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+ return {8, 8, 4, 4, MMA_VARIANTS(m8n32k16, f16_f16)};
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+ return {8, 8, 4, 8, MMA_VARIANTS(m8n32k16, f32_f16)};
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+ return {8, 8, 8, 4, MMA_VARIANTS(m8n32k16, f16_f32)};
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+ return {8, 8, 8, 8, MMA_VARIANTS(m8n32k16, f32_f32)};
+
+ // Integer MMA
+ case NVPTX::BI__imma_m16n16k16_mma_s8:
+ return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, s8)};
+ case NVPTX::BI__imma_m16n16k16_mma_u8:
+ return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, u8)};
+ case NVPTX::BI__imma_m32n8k16_mma_s8:
+ return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, s8)};
+ case NVPTX::BI__imma_m32n8k16_mma_u8:
+ return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, u8)};
+ case NVPTX::BI__imma_m8n32k16_mma_s8:
+ return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, s8)};
+ case NVPTX::BI__imma_m8n32k16_mma_u8:
+ return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, u8)};
+
+ // Sub-integer MMA
+ case NVPTX::BI__imma_m8n8k32_mma_s4:
+ return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, s4)};
+ case NVPTX::BI__imma_m8n8k32_mma_u4:
+ return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, u4)};
+ case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+ return {1, 1, 2, 2, MMA_VARIANTS_B1(m8n8k128, b1)};
+ default:
+ llvm_unreachable("Unexpected builtin ID.");
+ }
+#undef MMA_VARIANTS
+#undef MMA_VARIANTS_I4
+#undef MMA_VARIANTS_B1
+}
+
+} // namespace
+
+Value *
+CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
auto MakeLdg = [&](unsigned IntrinsicID) {
Value *Ptr = EmitScalarExpr(E->getArg(0));
clang::CharUnits Align =
@@ -13189,6 +13433,8 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
Builder.CreateStore(Pred, PredOutPtr);
return Builder.CreateExtractValue(ResultPair, 0);
}
+
+ // FP MMA loads
case NVPTX::BI__hmma_m16n16k16_ld_a:
case NVPTX::BI__hmma_m16n16k16_ld_b:
case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
@@ -13200,7 +13446,33 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
case NVPTX::BI__hmma_m8n32k16_ld_a:
case NVPTX::BI__hmma_m8n32k16_ld_b:
case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
- case NVPTX::BI__hmma_m8n32k16_ld_c_f32: {
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
+ // Integer MMA loads.
+ case NVPTX::BI__imma_m16n16k16_ld_a_s8:
+ case NVPTX::BI__imma_m16n16k16_ld_a_u8:
+ case NVPTX::BI__imma_m16n16k16_ld_b_s8:
+ case NVPTX::BI__imma_m16n16k16_ld_b_u8:
+ case NVPTX::BI__imma_m16n16k16_ld_c:
+ case NVPTX::BI__imma_m32n8k16_ld_a_s8:
+ case NVPTX::BI__imma_m32n8k16_ld_a_u8:
+ case NVPTX::BI__imma_m32n8k16_ld_b_s8:
+ case NVPTX::BI__imma_m32n8k16_ld_b_u8:
+ case NVPTX::BI__imma_m32n8k16_ld_c:
+ case NVPTX::BI__imma_m8n32k16_ld_a_s8:
+ case NVPTX::BI__imma_m8n32k16_ld_a_u8:
+ case NVPTX::BI__imma_m8n32k16_ld_b_s8:
+ case NVPTX::BI__imma_m8n32k16_ld_b_u8:
+ case NVPTX::BI__imma_m8n32k16_ld_c:
+ // Sub-integer MMA loads.
+ case NVPTX::BI__imma_m8n8k32_ld_a_s4:
+ case NVPTX::BI__imma_m8n8k32_ld_a_u4:
+ case NVPTX::BI__imma_m8n8k32_ld_b_s4:
+ case NVPTX::BI__imma_m8n8k32_ld_b_u4:
+ case NVPTX::BI__imma_m8n8k32_ld_c:
+ case NVPTX::BI__bmma_m8n8k128_ld_a_b1:
+ case NVPTX::BI__bmma_m8n8k128_ld_b_b1:
+ case NVPTX::BI__bmma_m8n8k128_ld_c:
+ {
Address Dst = EmitPointerWithAlignment(E->getArg(0));
Value *Src = EmitScalarExpr(E->getArg(1));
Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -13208,82 +13480,28 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
if (!E->getArg(3)->isIntegerConstantExpr(isColMajorArg, getContext()))
return nullptr;
bool isColMajor = isColMajorArg.getSExtValue();
- unsigned IID;
- unsigned NumResults;
- switch (BuiltinID) {
- case NVPTX::BI__hmma_m16n16k16_ld_a:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride
- : Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride;
- NumResults = 8;
- break;
- case NVPTX::BI__hmma_m16n16k16_ld_b:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride
- : Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride;
- NumResults = 8;
- break;
- case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride
- : Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride;
- NumResults = 4;
- break;
- case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride
- : Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride;
- NumResults = 8;
- break;
- case NVPTX::BI__hmma_m32n8k16_ld_a:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride
- : Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride;
- NumResults = 8;
- break;
- case NVPTX::BI__hmma_m32n8k16_ld_b:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride
- : Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride;
- NumResults = 8;
- break;
- case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride
- : Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride;
- NumResults = 4;
- break;
- case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride
- : Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride;
- NumResults = 8;
- break;
- case NVPTX::BI__hmma_m8n32k16_ld_a:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride
- : Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride;
- NumResults = 8;
- break;
- case NVPTX::BI__hmma_m8n32k16_ld_b:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride
- : Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride;
- NumResults = 8;
- break;
- case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride
- : Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride;
- NumResults = 4;
- break;
- case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride
- : Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride;
- NumResults = 8;
- break;
- default:
- llvm_unreachable("Unexpected builtin ID.");
- }
+ NVPTXMmaLdstInfo II = getNVPTXMmaLdstInfo(BuiltinID);
+ unsigned IID = isColMajor ? II.IID_col : II.IID_row;
+ if (IID == 0)
+ return nullptr;
+
Value *Result =
Builder.CreateCall(CGM.getIntrinsic(IID, Src->getType()), {Src, Ldm});
// Save returned values.
- for (unsigned i = 0; i < NumResults; ++i) {
- Builder.CreateAlignedStore(
- Builder.CreateBitCast(Builder.CreateExtractValue(Result, i),
- Dst.getElementType()),
- Builder.CreateGEP(Dst.getPointer(), llvm::ConstantInt::get(IntTy, i)),
- CharUnits::fromQuantity(4));
+ assert(II.NumResults);
+ if (II.NumResults == 1) {
+ Builder.CreateAlignedStore(Result, Dst.getPointer(),
+ CharUnits::fromQuantity(4));
+ } else {
+ for (unsigned i = 0; i < II.NumResults; ++i) {
+ Builder.CreateAlignedStore(
+ Builder.CreateBitCast(Builder.CreateExtractValue(Result, i),
+ Dst.getElementType()),
+ Builder.CreateGEP(Dst.getPointer(),
+ llvm::ConstantInt::get(IntTy, i)),
+ CharUnits::fromQuantity(4));
+ }
}
return Result;
}
@@ -13293,7 +13511,12 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
case NVPTX::BI__hmma_m32n8k16_st_c_f16:
case NVPTX::BI__hmma_m32n8k16_st_c_f32:
case NVPTX::BI__hmma_m8n32k16_st_c_f16:
- case NVPTX::BI__hmma_m8n32k16_st_c_f32: {
+ case NVPTX::BI__hmma_m8n32k16_st_c_f32:
+ case NVPTX::BI__imma_m16n16k16_st_c_i32:
+ case NVPTX::BI__imma_m32n8k16_st_c_i32:
+ case NVPTX::BI__imma_m8n32k16_st_c_i32:
+ case NVPTX::BI__imma_m8n8k32_st_c_i32:
+ case NVPTX::BI__bmma_m8n8k128_st_c_i32: {
Value *Dst = EmitScalarExpr(E->getArg(0));
Address Src = EmitPointerWithAlignment(E->getArg(1));
Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -13301,45 +13524,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
if (!E->getArg(3)->isIntegerConstantExpr(isColMajorArg, getContext()))
return nullptr;
bool isColMajor = isColMajorArg.getSExtValue();
- unsigned IID;
- unsigned NumResults = 8;
- // PTX Instructions (and LLVM intrinsics) are defined for slice _d_, yet
- // for some reason nvcc builtins use _c_.
- switch (BuiltinID) {
- case NVPTX::BI__hmma_m16n16k16_st_c_f16:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride
- : Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride;
- NumResults = 4;
- break;
- case NVPTX::BI__hmma_m16n16k16_st_c_f32:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride
- : Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride;
- break;
- case NVPTX::BI__hmma_m32n8k16_st_c_f16:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride
- : Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride;
- NumResults = 4;
- break;
- case NVPTX::BI__hmma_m32n8k16_st_c_f32:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride
- : Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride;
- break;
- case NVPTX::BI__hmma_m8n32k16_st_c_f16:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride
- : Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride;
- NumResults = 4;
- break;
- case NVPTX::BI__hmma_m8n32k16_st_c_f32:
- IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride
- : Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride;
- break;
- default:
- llvm_unreachable("Unexpected builtin ID.");
- }
- Function *Intrinsic = CGM.getIntrinsic(IID, Dst->getType());
+ NVPTXMmaLdstInfo II = getNVPTXMmaLdstInfo(BuiltinID);
+ unsigned IID = isColMajor ? II.IID_col : II.IID_row;
+ if (IID == 0)
+ return nullptr;
+ Function *Intrinsic =
+ CGM.getIntrinsic(IID, Dst->getType());
llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1);
SmallVector<Value *, 10> Values = {Dst};
- for (unsigned i = 0; i < NumResults; ++i) {
+ for (unsigned i = 0; i < II.NumResults; ++i) {
Value *V = Builder.CreateAlignedLoad(
Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)),
CharUnits::fromQuantity(4));
@@ -13363,7 +13556,16 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
- case NVPTX::BI__hmma_m8n32k16_mma_f16f32: {
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+ case NVPTX::BI__imma_m16n16k16_mma_s8:
+ case NVPTX::BI__imma_m16n16k16_mma_u8:
+ case NVPTX::BI__imma_m32n8k16_mma_s8:
+ case NVPTX::BI__imma_m32n8k16_mma_u8:
+ case NVPTX::BI__imma_m8n32k16_mma_s8:
+ case NVPTX::BI__imma_m8n32k16_mma_u8:
+ case NVPTX::BI__imma_m8n8k32_mma_s4:
+ case NVPTX::BI__imma_m8n8k32_mma_u4:
+ case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1: {
Address Dst = EmitPointerWithAlignment(E->getArg(0));
Address SrcA = EmitPointerWithAlignment(E->getArg(1));
Address SrcB = EmitPointerWithAlignment(E->getArg(2));
@@ -13375,119 +13577,40 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
if (Layout < 0 || Layout > 3)
return nullptr;
llvm::APSInt SatfArg;
- if (!E->getArg(5)->isIntegerConstantExpr(SatfArg, getContext()))
+ if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1)
+ SatfArg = 0; // .b1 does not have satf argument.
+ else if (!E->getArg(5)->isIntegerConstantExpr(SatfArg, getContext()))
return nullptr;
bool Satf = SatfArg.getSExtValue();
-
- // clang-format off
-#define MMA_VARIANTS(geom, type) {{ \
- Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type, \
- Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
- Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
- Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type, \
- Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type, \
- Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite \
- }}
- // clang-format on
-
- auto getMMAIntrinsic = [Layout, Satf](std::array<unsigned, 8> Variants) {
- unsigned Index = Layout * 2 + Satf;
- assert(Index < 8);
- return Variants[Index];
- };
- unsigned IID;
- unsigned NumEltsC;
- unsigned NumEltsD;
- switch (BuiltinID) {
- case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f16));
- NumEltsC = 4;
- NumEltsD = 4;
- break;
- case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f16));
- NumEltsC = 4;
- NumEltsD = 8;
- break;
- case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f32));
- NumEltsC = 8;
- NumEltsD = 4;
- break;
- case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f32));
- NumEltsC = 8;
- NumEltsD = 8;
- break;
- case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f16));
- NumEltsC = 4;
- NumEltsD = 4;
- break;
- case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f16));
- NumEltsC = 4;
- NumEltsD = 8;
- break;
- case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f32));
- NumEltsC = 8;
- NumEltsD = 4;
- break;
- case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f32));
- NumEltsC = 8;
- NumEltsD = 8;
- break;
- case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f16));
- NumEltsC = 4;
- NumEltsD = 4;
- break;
- case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f16));
- NumEltsC = 4;
- NumEltsD = 8;
- break;
- case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f32));
- NumEltsC = 8;
- NumEltsD = 4;
- break;
- case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f32));
- NumEltsC = 8;
- NumEltsD = 8;
- break;
- default:
- llvm_unreachable("Unexpected builtin ID.");
- }
-#undef MMA_VARIANTS
+ NVPTXMmaInfo MI = getNVPTXMmaInfo(BuiltinID);
+ unsigned IID = MI.getMMAIntrinsic(Layout, Satf);
+ if (IID == 0) // Unsupported combination of Layout/Satf.
+ return nullptr;
SmallVector<Value *, 24> Values;
Function *Intrinsic = CGM.getIntrinsic(IID);
- llvm::Type *ABType = Intrinsic->getFunctionType()->getParamType(0);
+ llvm::Type *AType = Intrinsic->getFunctionType()->getParamType(0);
// Load A
- for (unsigned i = 0; i < 8; ++i) {
+ for (unsigned i = 0; i < MI.NumEltsA; ++i) {
Value *V = Builder.CreateAlignedLoad(
Builder.CreateGEP(SrcA.getPointer(),
llvm::ConstantInt::get(IntTy, i)),
CharUnits::fromQuantity(4));
- Values.push_back(Builder.CreateBitCast(V, ABType));
+ Values.push_back(Builder.CreateBitCast(V, AType));
}
// Load B
- for (unsigned i = 0; i < 8; ++i) {
+ llvm::Type *BType = Intrinsic->getFunctionType()->getParamType(MI.NumEltsA);
+ for (unsigned i = 0; i < MI.NumEltsB; ++i) {
Value *V = Builder.CreateAlignedLoad(
Builder.CreateGEP(SrcB.getPointer(),
llvm::ConstantInt::get(IntTy, i)),
CharUnits::fromQuantity(4));
- Values.push_back(Builder.CreateBitCast(V, ABType));
+ Values.push_back(Builder.CreateBitCast(V, BType));
}
// Load C
- llvm::Type *CType = Intrinsic->getFunctionType()->getParamType(16);
- for (unsigned i = 0; i < NumEltsC; ++i) {
+ llvm::Type *CType =
+ Intrinsic->getFunctionType()->getParamType(MI.NumEltsA + MI.NumEltsB);
+ for (unsigned i = 0; i < MI.NumEltsC; ++i) {
Value *V = Builder.CreateAlignedLoad(
Builder.CreateGEP(SrcC.getPointer(),
llvm::ConstantInt::get(IntTy, i)),
@@ -13496,7 +13619,7 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
}
Value *Result = Builder.CreateCall(Intrinsic, Values);
llvm::Type *DType = Dst.getElementType();
- for (unsigned i = 0; i < NumEltsD; ++i)
+ for (unsigned i = 0; i < MI.NumEltsD; ++i)
Builder.CreateAlignedStore(
Builder.CreateBitCast(Builder.CreateExtractValue(Result, i), DType),
Builder.CreateGEP(Dst.getPointer(), llvm::ConstantInt::get(IntTy, i)),