Diffstat (limited to 'clang/lib/CodeGen')
-rw-r--r-- | clang/lib/CodeGen/CGBuiltin.cpp | 543 |
1 file changed, 333 insertions, 210 deletions
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 02969a1908b..cfdb69c0798 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -12925,8 +12925,252 @@ Value *CodeGenFunction::EmitSystemZBuiltinExpr(unsigned BuiltinID,
   }
 }
 
-Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
-                                             const CallExpr *E) {
+namespace {
+// Helper classes for mapping MMA builtins to particular LLVM intrinsic variant.
+struct NVPTXMmaLdstInfo {
+  unsigned NumResults; // Number of elements to load/store
+  // Intrinsic IDs for row/col variants. 0 if particular layout is unsupported.
+  unsigned IID_col;
+  unsigned IID_row;
+};
+
+#define MMA_INTR(geom_op_type, layout) \
+  Intrinsic::nvvm_wmma_##geom_op_type##_##layout##_stride
+#define MMA_LDST(n, geom_op_type) \
+  { n, MMA_INTR(geom_op_type, col), MMA_INTR(geom_op_type, row) }
+
+static NVPTXMmaLdstInfo getNVPTXMmaLdstInfo(unsigned BuiltinID) {
+  switch (BuiltinID) {
+  // FP MMA loads
+  case NVPTX::BI__hmma_m16n16k16_ld_a:
+    return MMA_LDST(8, m16n16k16_load_a_f16);
+  case NVPTX::BI__hmma_m16n16k16_ld_b:
+    return MMA_LDST(8, m16n16k16_load_b_f16);
+  case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
+    return MMA_LDST(4, m16n16k16_load_c_f16);
+  case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
+    return MMA_LDST(8, m16n16k16_load_c_f32);
+  case NVPTX::BI__hmma_m32n8k16_ld_a:
+    return MMA_LDST(8, m32n8k16_load_a_f16);
+  case NVPTX::BI__hmma_m32n8k16_ld_b:
+    return MMA_LDST(8, m32n8k16_load_b_f16);
+  case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
+    return MMA_LDST(4, m32n8k16_load_c_f16);
+  case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
+    return MMA_LDST(8, m32n8k16_load_c_f32);
+  case NVPTX::BI__hmma_m8n32k16_ld_a:
+    return MMA_LDST(8, m8n32k16_load_a_f16);
+  case NVPTX::BI__hmma_m8n32k16_ld_b:
+    return MMA_LDST(8, m8n32k16_load_b_f16);
+  case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
+    return MMA_LDST(4, m8n32k16_load_c_f16);
+  case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
+    return MMA_LDST(8, m8n32k16_load_c_f32);
+
+  // Integer MMA loads
+  case NVPTX::BI__imma_m16n16k16_ld_a_s8:
+    return MMA_LDST(2, m16n16k16_load_a_s8);
+  case NVPTX::BI__imma_m16n16k16_ld_a_u8:
+    return MMA_LDST(2, m16n16k16_load_a_u8);
+  case NVPTX::BI__imma_m16n16k16_ld_b_s8:
+    return MMA_LDST(2, m16n16k16_load_b_s8);
+  case NVPTX::BI__imma_m16n16k16_ld_b_u8:
+    return MMA_LDST(2, m16n16k16_load_b_u8);
+  case NVPTX::BI__imma_m16n16k16_ld_c:
+    return MMA_LDST(8, m16n16k16_load_c_s32);
+  case NVPTX::BI__imma_m32n8k16_ld_a_s8:
+    return MMA_LDST(4, m32n8k16_load_a_s8);
+  case NVPTX::BI__imma_m32n8k16_ld_a_u8:
+    return MMA_LDST(4, m32n8k16_load_a_u8);
+  case NVPTX::BI__imma_m32n8k16_ld_b_s8:
+    return MMA_LDST(1, m32n8k16_load_b_s8);
+  case NVPTX::BI__imma_m32n8k16_ld_b_u8:
+    return MMA_LDST(1, m32n8k16_load_b_u8);
+  case NVPTX::BI__imma_m32n8k16_ld_c:
+    return MMA_LDST(8, m32n8k16_load_c_s32);
+  case NVPTX::BI__imma_m8n32k16_ld_a_s8:
+    return MMA_LDST(1, m8n32k16_load_a_s8);
+  case NVPTX::BI__imma_m8n32k16_ld_a_u8:
+    return MMA_LDST(1, m8n32k16_load_a_u8);
+  case NVPTX::BI__imma_m8n32k16_ld_b_s8:
+    return MMA_LDST(4, m8n32k16_load_b_s8);
+  case NVPTX::BI__imma_m8n32k16_ld_b_u8:
+    return MMA_LDST(4, m8n32k16_load_b_u8);
+  case NVPTX::BI__imma_m8n32k16_ld_c:
+    return MMA_LDST(8, m8n32k16_load_c_s32);
+
+  // Sub-integer MMA loads.
+  // Only row/col layout is supported by A/B fragments.
+  case NVPTX::BI__imma_m8n8k32_ld_a_s4:
+    return {1, 0, MMA_INTR(m8n8k32_load_a_s4, row)};
+  case NVPTX::BI__imma_m8n8k32_ld_a_u4:
+    return {1, 0, MMA_INTR(m8n8k32_load_a_u4, row)};
+  case NVPTX::BI__imma_m8n8k32_ld_b_s4:
+    return {1, MMA_INTR(m8n8k32_load_b_s4, col), 0};
+  case NVPTX::BI__imma_m8n8k32_ld_b_u4:
+    return {1, MMA_INTR(m8n8k32_load_b_u4, col), 0};
+  case NVPTX::BI__imma_m8n8k32_ld_c:
+    return MMA_LDST(2, m8n8k32_load_c_s32);
+  case NVPTX::BI__bmma_m8n8k128_ld_a_b1:
+    return {1, 0, MMA_INTR(m8n8k128_load_a_b1, row)};
+  case NVPTX::BI__bmma_m8n8k128_ld_b_b1:
+    return {1, MMA_INTR(m8n8k128_load_b_b1, col), 0};
+  case NVPTX::BI__bmma_m8n8k128_ld_c:
+    return MMA_LDST(2, m8n8k128_load_c_s32);
+
+  // NOTE: We need to follow the inconsistent naming scheme used by NVCC. Unlike
+  // PTX and LLVM IR where stores always use fragment D, NVCC builtins always
+  // use fragment C for both loads and stores.
+  // FP MMA stores.
+  case NVPTX::BI__hmma_m16n16k16_st_c_f16:
+    return MMA_LDST(4, m16n16k16_store_d_f16);
+  case NVPTX::BI__hmma_m16n16k16_st_c_f32:
+    return MMA_LDST(8, m16n16k16_store_d_f32);
+  case NVPTX::BI__hmma_m32n8k16_st_c_f16:
+    return MMA_LDST(4, m32n8k16_store_d_f16);
+  case NVPTX::BI__hmma_m32n8k16_st_c_f32:
+    return MMA_LDST(8, m32n8k16_store_d_f32);
+  case NVPTX::BI__hmma_m8n32k16_st_c_f16:
+    return MMA_LDST(4, m8n32k16_store_d_f16);
+  case NVPTX::BI__hmma_m8n32k16_st_c_f32:
+    return MMA_LDST(8, m8n32k16_store_d_f32);
+
+  // Integer and sub-integer MMA stores.
+  // Another naming quirk. Unlike other MMA builtins that use PTX types in the
+  // name, integer loads/stores use LLVM's i32.
+  case NVPTX::BI__imma_m16n16k16_st_c_i32:
+    return MMA_LDST(8, m16n16k16_store_d_s32);
+  case NVPTX::BI__imma_m32n8k16_st_c_i32:
+    return MMA_LDST(8, m32n8k16_store_d_s32);
+  case NVPTX::BI__imma_m8n32k16_st_c_i32:
+    return MMA_LDST(8, m8n32k16_store_d_s32);
+  case NVPTX::BI__imma_m8n8k32_st_c_i32:
+    return MMA_LDST(2, m8n8k32_store_d_s32);
+  case NVPTX::BI__bmma_m8n8k128_st_c_i32:
+    return MMA_LDST(2, m8n8k128_store_d_s32);
+
+  default:
+    llvm_unreachable("Unknown MMA builtin");
+  }
+}
+#undef MMA_LDST
+#undef MMA_INTR
+
+
+struct NVPTXMmaInfo {
+  unsigned NumEltsA;
+  unsigned NumEltsB;
+  unsigned NumEltsC;
+  unsigned NumEltsD;
+  std::array<unsigned, 8> Variants;
+
+  // Returns an intrinsic that matches Layout and Satf for valid combinations
+  // of Layout and Satf, 0 otherwise.
+  unsigned getMMAIntrinsic(int Layout, bool Satf) {
+    unsigned Index = Layout * 2 + Satf;
+    if (Index >= Variants.size())
+      return 0;
+    return Variants[Index];
+  }
+};
+
+static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
+  // clang-format off
+#define MMA_VARIANTS(geom, type) {{                                 \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite  \
+    }}
+// Sub-integer MMA only supports row.col layout.
+#define MMA_VARIANTS_I4(geom, type) {{                              \
+      0,                                                            \
+      0,                                                            \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
+      0,                                                            \
+      0,                                                            \
+      0,                                                            \
+      0                                                             \
+    }}
+// b1 MMA does not support .satfinite.
+#define MMA_VARIANTS_B1(geom, type) {{                              \
+      0,                                                            \
+      0,                                                            \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
+      0,                                                            \
+      0,                                                            \
+      0,                                                            \
+      0,                                                            \
+      0                                                             \
+    }}
+  // clang-format on
+  switch (BuiltinID) {
+  // FP MMA
+  // Note that 'type' argument of MMA_VARIANT uses D_C notation, while
+  // NumEltsN of return value are ordered as A,B,C,D.
+  case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
+    return {8, 8, 4, 4, MMA_VARIANTS(m16n16k16, f16_f16)};
+  case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
+    return {8, 8, 4, 8, MMA_VARIANTS(m16n16k16, f32_f16)};
+  case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
+    return {8, 8, 8, 4, MMA_VARIANTS(m16n16k16, f16_f32)};
+  case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
+    return {8, 8, 8, 8, MMA_VARIANTS(m16n16k16, f32_f32)};
+  case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+    return {8, 8, 4, 4, MMA_VARIANTS(m32n8k16, f16_f16)};
+  case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+    return {8, 8, 4, 8, MMA_VARIANTS(m32n8k16, f32_f16)};
+  case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+    return {8, 8, 8, 4, MMA_VARIANTS(m32n8k16, f16_f32)};
+  case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+    return {8, 8, 8, 8, MMA_VARIANTS(m32n8k16, f32_f32)};
+  case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+    return {8, 8, 4, 4, MMA_VARIANTS(m8n32k16, f16_f16)};
+  case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+    return {8, 8, 4, 8, MMA_VARIANTS(m8n32k16, f32_f16)};
+  case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+    return {8, 8, 8, 4, MMA_VARIANTS(m8n32k16, f16_f32)};
+  case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+    return {8, 8, 8, 8, MMA_VARIANTS(m8n32k16, f32_f32)};
+
+  // Integer MMA
+  case NVPTX::BI__imma_m16n16k16_mma_s8:
+    return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, s8)};
+  case NVPTX::BI__imma_m16n16k16_mma_u8:
+    return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, u8)};
+  case NVPTX::BI__imma_m32n8k16_mma_s8:
+    return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, s8)};
+  case NVPTX::BI__imma_m32n8k16_mma_u8:
+    return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, u8)};
+  case NVPTX::BI__imma_m8n32k16_mma_s8:
+    return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, s8)};
+  case NVPTX::BI__imma_m8n32k16_mma_u8:
+    return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, u8)};
+
+  // Sub-integer MMA
+  case NVPTX::BI__imma_m8n8k32_mma_s4:
+    return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, s4)};
+  case NVPTX::BI__imma_m8n8k32_mma_u4:
+    return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, u4)};
+  case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+    return {1, 1, 2, 2, MMA_VARIANTS_B1(m8n8k128, b1)};
+  default:
+    llvm_unreachable("Unexpected builtin ID.");
+  }
+#undef MMA_VARIANTS
+#undef MMA_VARIANTS_I4
+#undef MMA_VARIANTS_B1
+}
+
+} // namespace
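The helper tables above are indexed by the (Layout, Satf) pair that EmitNVPTXBuiltinExpr later reads from the builtin's constant arguments: the eight Variants entries are ordered row_row, row_row_satfinite, row_col, row_col_satfinite, col_row, col_row_satfinite, col_col, col_col_satfinite, so getMMAIntrinsic picks entry Layout * 2 + Satf and returns 0 for combinations a given geometry does not support. A minimal standalone C++ sketch of that lookup (the numeric IDs are made-up stand-ins for Intrinsic::nvvm_wmma_* values, not part of the patch):

// variants_sketch.cpp -- illustrates the Layout*2+Satf indexing used by
// getMMAIntrinsic(); placeholder IDs, not real intrinsic values.
#include <array>
#include <cassert>

static unsigned pickVariant(const std::array<unsigned, 8> &Variants,
                            int Layout, bool Satf) {
  // Layout: 0 = row_row, 1 = row_col, 2 = col_row, 3 = col_col.
  unsigned Index = Layout * 2 + Satf;
  return Index < Variants.size() ? Variants[Index] : 0;
}

int main() {
  std::array<unsigned, 8> FullSet = {10, 11, 12, 13, 14, 15, 16, 17};
  std::array<unsigned, 8> RowColOnly = {0, 0, 12, 13, 0, 0, 0, 0}; // s4/u4 style
  assert(pickVariant(FullSet, /*Layout=*/1, /*Satf=*/true) == 13);   // row_col + satf
  assert(pickVariant(RowColOnly, /*Layout=*/0, /*Satf=*/false) == 0); // unsupported
  return 0;
}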
+
+Value *
+CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
   auto MakeLdg = [&](unsigned IntrinsicID) {
     Value *Ptr = EmitScalarExpr(E->getArg(0));
     clang::CharUnits Align =
@@ -13189,6 +13433,8 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     Builder.CreateStore(Pred, PredOutPtr);
     return Builder.CreateExtractValue(ResultPair, 0);
   }
+
+  // FP MMA loads
   case NVPTX::BI__hmma_m16n16k16_ld_a:
   case NVPTX::BI__hmma_m16n16k16_ld_b:
   case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
@@ -13200,7 +13446,33 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
   case NVPTX::BI__hmma_m8n32k16_ld_a:
   case NVPTX::BI__hmma_m8n32k16_ld_b:
   case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
-  case NVPTX::BI__hmma_m8n32k16_ld_c_f32: {
+  case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
+  // Integer MMA loads.
+  case NVPTX::BI__imma_m16n16k16_ld_a_s8:
+  case NVPTX::BI__imma_m16n16k16_ld_a_u8:
+  case NVPTX::BI__imma_m16n16k16_ld_b_s8:
+  case NVPTX::BI__imma_m16n16k16_ld_b_u8:
+  case NVPTX::BI__imma_m16n16k16_ld_c:
+  case NVPTX::BI__imma_m32n8k16_ld_a_s8:
+  case NVPTX::BI__imma_m32n8k16_ld_a_u8:
+  case NVPTX::BI__imma_m32n8k16_ld_b_s8:
+  case NVPTX::BI__imma_m32n8k16_ld_b_u8:
+  case NVPTX::BI__imma_m32n8k16_ld_c:
+  case NVPTX::BI__imma_m8n32k16_ld_a_s8:
+  case NVPTX::BI__imma_m8n32k16_ld_a_u8:
+  case NVPTX::BI__imma_m8n32k16_ld_b_s8:
+  case NVPTX::BI__imma_m8n32k16_ld_b_u8:
+  case NVPTX::BI__imma_m8n32k16_ld_c:
+  // Sub-integer MMA loads.
+  case NVPTX::BI__imma_m8n8k32_ld_a_s4:
+  case NVPTX::BI__imma_m8n8k32_ld_a_u4:
+  case NVPTX::BI__imma_m8n8k32_ld_b_s4:
+  case NVPTX::BI__imma_m8n8k32_ld_b_u4:
+  case NVPTX::BI__imma_m8n8k32_ld_c:
+  case NVPTX::BI__bmma_m8n8k128_ld_a_b1:
+  case NVPTX::BI__bmma_m8n8k128_ld_b_b1:
+  case NVPTX::BI__bmma_m8n8k128_ld_c:
+  {
     Address Dst = EmitPointerWithAlignment(E->getArg(0));
     Value *Src = EmitScalarExpr(E->getArg(1));
     Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -13208,82 +13480,28 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     if (!E->getArg(3)->isIntegerConstantExpr(isColMajorArg, getContext()))
       return nullptr;
     bool isColMajor = isColMajorArg.getSExtValue();
-    unsigned IID;
-    unsigned NumResults;
-    switch (BuiltinID) {
-    case NVPTX::BI__hmma_m16n16k16_ld_a:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride;
-      NumResults = 8;
-      break;
-    case NVPTX::BI__hmma_m16n16k16_ld_b:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride;
-      NumResults = 8;
-      break;
-    case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride;
-      NumResults = 4;
-      break;
-    case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride
-                       : Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride;
-      NumResults = 8;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_ld_a:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride;
-      NumResults = 8;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_ld_b:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride;
-      NumResults = 8;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride;
-      NumResults = 4;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride
-                       : Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride;
-      NumResults = 8;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_ld_a:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride;
-      NumResults = 8;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_ld_b:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride;
-      NumResults = 8;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride;
-      NumResults = 4;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride
-                       : Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride;
-      NumResults = 8;
-      break;
-    default:
-      llvm_unreachable("Unexpected builtin ID.");
-    }
+    NVPTXMmaLdstInfo II = getNVPTXMmaLdstInfo(BuiltinID);
+    unsigned IID = isColMajor ? II.IID_col : II.IID_row;
+    if (IID == 0)
+      return nullptr;
+
     Value *Result =
         Builder.CreateCall(CGM.getIntrinsic(IID, Src->getType()), {Src, Ldm});
     // Save returned values.
-    for (unsigned i = 0; i < NumResults; ++i) {
-      Builder.CreateAlignedStore(
-          Builder.CreateBitCast(Builder.CreateExtractValue(Result, i),
-                                Dst.getElementType()),
-          Builder.CreateGEP(Dst.getPointer(), llvm::ConstantInt::get(IntTy, i)),
-          CharUnits::fromQuantity(4));
+    assert(II.NumResults);
+    if (II.NumResults == 1) {
+      Builder.CreateAlignedStore(Result, Dst.getPointer(),
+                                 CharUnits::fromQuantity(4));
+    } else {
+      for (unsigned i = 0; i < II.NumResults; ++i) {
+        Builder.CreateAlignedStore(
+            Builder.CreateBitCast(Builder.CreateExtractValue(Result, i),
+                                  Dst.getElementType()),
+            Builder.CreateGEP(Dst.getPointer(),
                              llvm::ConstantInt::get(IntTy, i)),
+            CharUnits::fromQuantity(4));
+      }
     }
     return Result;
   }
@@ -13293,7 +13511,12 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
   case NVPTX::BI__hmma_m32n8k16_st_c_f16:
   case NVPTX::BI__hmma_m32n8k16_st_c_f32:
   case NVPTX::BI__hmma_m8n32k16_st_c_f16:
-  case NVPTX::BI__hmma_m8n32k16_st_c_f32: {
+  case NVPTX::BI__hmma_m8n32k16_st_c_f32:
+  case NVPTX::BI__imma_m16n16k16_st_c_i32:
+  case NVPTX::BI__imma_m32n8k16_st_c_i32:
+  case NVPTX::BI__imma_m8n32k16_st_c_i32:
+  case NVPTX::BI__imma_m8n8k32_st_c_i32:
+  case NVPTX::BI__bmma_m8n8k128_st_c_i32: {
     Value *Dst = EmitScalarExpr(E->getArg(0));
     Address Src = EmitPointerWithAlignment(E->getArg(1));
     Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -13301,45 +13524,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     if (!E->getArg(3)->isIntegerConstantExpr(isColMajorArg, getContext()))
       return nullptr;
     bool isColMajor = isColMajorArg.getSExtValue();
-    unsigned IID;
-    unsigned NumResults = 8;
-    // PTX Instructions (and LLVM intrinsics) are defined for slice _d_, yet
-    // for some reason nvcc builtins use _c_.
-    switch (BuiltinID) {
-    case NVPTX::BI__hmma_m16n16k16_st_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride;
-      NumResults = 4;
-      break;
-    case NVPTX::BI__hmma_m16n16k16_st_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride
-                       : Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_st_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride;
-      NumResults = 4;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_st_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride
-                       : Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_st_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride
-                       : Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride;
-      NumResults = 4;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_st_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride
-                       : Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride;
-      break;
-    default:
-      llvm_unreachable("Unexpected builtin ID.");
-    }
-    Function *Intrinsic = CGM.getIntrinsic(IID, Dst->getType());
+    NVPTXMmaLdstInfo II = getNVPTXMmaLdstInfo(BuiltinID);
+    unsigned IID = isColMajor ? II.IID_col : II.IID_row;
+    if (IID == 0)
+      return nullptr;
+    Function *Intrinsic =
+        CGM.getIntrinsic(IID, Dst->getType());
     llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1);
     SmallVector<Value *, 10> Values = {Dst};
-    for (unsigned i = 0; i < NumResults; ++i) {
+    for (unsigned i = 0; i < II.NumResults; ++i) {
       Value *V = Builder.CreateAlignedLoad(
           Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)),
           CharUnits::fromQuantity(4));
@@ -13363,7 +13556,16 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
   case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
   case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
   case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
-  case NVPTX::BI__hmma_m8n32k16_mma_f16f32: {
+  case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+  case NVPTX::BI__imma_m16n16k16_mma_s8:
+  case NVPTX::BI__imma_m16n16k16_mma_u8:
+  case NVPTX::BI__imma_m32n8k16_mma_s8:
+  case NVPTX::BI__imma_m32n8k16_mma_u8:
+  case NVPTX::BI__imma_m8n32k16_mma_s8:
+  case NVPTX::BI__imma_m8n32k16_mma_u8:
+  case NVPTX::BI__imma_m8n8k32_mma_s4:
+  case NVPTX::BI__imma_m8n8k32_mma_u4:
+  case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1: {
     Address Dst = EmitPointerWithAlignment(E->getArg(0));
     Address SrcA = EmitPointerWithAlignment(E->getArg(1));
     Address SrcB = EmitPointerWithAlignment(E->getArg(2));
@@ -13375,119 +13577,40 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     if (Layout < 0 || Layout > 3)
       return nullptr;
     llvm::APSInt SatfArg;
-    if (!E->getArg(5)->isIntegerConstantExpr(SatfArg, getContext()))
+    if (BuiltinID == NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1)
+      SatfArg = 0; // .b1 does not have satf argument.
+    else if (!E->getArg(5)->isIntegerConstantExpr(SatfArg, getContext()))
       return nullptr;
     bool Satf = SatfArg.getSExtValue();
-
-    // clang-format off
-#define MMA_VARIANTS(geom, type) {{                                 \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type,             \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
-      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type,             \
-      Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type,             \
-      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite  \
-    }}
-    // clang-format on
-
-    auto getMMAIntrinsic = [Layout, Satf](std::array<unsigned, 8> Variants) {
-      unsigned Index = Layout * 2 + Satf;
-      assert(Index < 8);
-      return Variants[Index];
-    };
-    unsigned IID;
-    unsigned NumEltsC;
-    unsigned NumEltsD;
-    switch (BuiltinID) {
-    case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f16));
-      NumEltsC = 4;
-      NumEltsD = 4;
-      break;
-    case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f16));
-      NumEltsC = 4;
-      NumEltsD = 8;
-      break;
-    case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f32));
-      NumEltsC = 8;
-      NumEltsD = 4;
-      break;
-    case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f32));
-      NumEltsC = 8;
-      NumEltsD = 8;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f16));
-      NumEltsC = 4;
-      NumEltsD = 4;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f16));
-      NumEltsC = 4;
-      NumEltsD = 8;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f32));
-      NumEltsC = 8;
-      NumEltsD = 4;
-      break;
-    case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f32));
-      NumEltsC = 8;
-      NumEltsD = 8;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f16));
-      NumEltsC = 4;
-      NumEltsD = 4;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f16));
-      NumEltsC = 4;
-      NumEltsD = 8;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f32));
-      NumEltsC = 8;
-      NumEltsD = 4;
-      break;
-    case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
-      IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f32));
-      NumEltsC = 8;
-      NumEltsD = 8;
-      break;
-    default:
-      llvm_unreachable("Unexpected builtin ID.");
-    }
-#undef MMA_VARIANTS
+    NVPTXMmaInfo MI = getNVPTXMmaInfo(BuiltinID);
+    unsigned IID = MI.getMMAIntrinsic(Layout, Satf);
+    if (IID == 0)  // Unsupported combination of Layout/Satf.
+      return nullptr;
     SmallVector<Value *, 24> Values;
     Function *Intrinsic = CGM.getIntrinsic(IID);
-    llvm::Type *ABType = Intrinsic->getFunctionType()->getParamType(0);
+    llvm::Type *AType = Intrinsic->getFunctionType()->getParamType(0);
     // Load A
-    for (unsigned i = 0; i < 8; ++i) {
+    for (unsigned i = 0; i < MI.NumEltsA; ++i) {
       Value *V = Builder.CreateAlignedLoad(
           Builder.CreateGEP(SrcA.getPointer(), llvm::ConstantInt::get(IntTy, i)),
           CharUnits::fromQuantity(4));
-      Values.push_back(Builder.CreateBitCast(V, ABType));
+      Values.push_back(Builder.CreateBitCast(V, AType));
     }
     // Load B
-    for (unsigned i = 0; i < 8; ++i) {
+    llvm::Type *BType = Intrinsic->getFunctionType()->getParamType(MI.NumEltsA);
+    for (unsigned i = 0; i < MI.NumEltsB; ++i) {
       Value *V = Builder.CreateAlignedLoad(
          Builder.CreateGEP(SrcB.getPointer(), llvm::ConstantInt::get(IntTy, i)),
          CharUnits::fromQuantity(4));
-      Values.push_back(Builder.CreateBitCast(V, ABType));
+      Values.push_back(Builder.CreateBitCast(V, BType));
     }
     // Load C
-    llvm::Type *CType = Intrinsic->getFunctionType()->getParamType(16);
-    for (unsigned i = 0; i < NumEltsC; ++i) {
+    llvm::Type *CType =
+        Intrinsic->getFunctionType()->getParamType(MI.NumEltsA + MI.NumEltsB);
+    for (unsigned i = 0; i < MI.NumEltsC; ++i) {
       Value *V = Builder.CreateAlignedLoad(
           Builder.CreateGEP(SrcC.getPointer(),
                             llvm::ConstantInt::get(IntTy, i)),
@@ -13496,7 +13619,7 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     }
     Value *Result = Builder.CreateCall(Intrinsic, Values);
     llvm::Type *DType = Dst.getElementType();
-    for (unsigned i = 0; i < NumEltsD; ++i)
+    for (unsigned i = 0; i < MI.NumEltsD; ++i)
       Builder.CreateAlignedStore(
           Builder.CreateBitCast(Builder.CreateExtractValue(Result, i), DType),
          Builder.CreateGEP(Dst.getPointer(), llvm::ConstantInt::get(IntTy, i)),
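For context, the lowering above is what a direct call to one of the new builtins goes through; in practice the builtins are normally reached through the wmma fragment wrappers in CUDA's mma.h rather than called by hand. A hedged CUDA-side sketch of a direct call, with the argument order (dst, srcA, srcB, srcC, layout, satf) inferred from the argument indices read back in EmitNVPTXBuiltinExpr above; the authoritative prototypes live in BuiltinsNVPTX.def, so treat this as illustrative only:

// mma_sketch.cu -- illustrative only; assumes the builtin takes
// (dst, a, b, c, layout, satf) as suggested by the codegen above.
// layout 1 selects the row_col variant; satf 0 disables saturation.
// Integer MMA builtins are expected to need a recent GPU (sm_72 or newer).
__device__ void mma_s8_16x16x16(int *dst, const int *a, const int *b,
                                const int *c) {
  __imma_m16n16k16_mma_s8(dst, a, b, c, /*layout=*/1, /*satf=*/0);
}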