diff options
Diffstat (limited to 'clang/lib/CodeGen/CGBuiltin.cpp')
| -rw-r--r-- | clang/lib/CodeGen/CGBuiltin.cpp | 150 |
1 files changed, 134 insertions, 16 deletions
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index fffc2429fb1..6a2f2b0a4a1 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -10715,7 +10715,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, case NVPTX::BI__hmma_m16n16k16_ld_a: case NVPTX::BI__hmma_m16n16k16_ld_b: case NVPTX::BI__hmma_m16n16k16_ld_c_f16: - case NVPTX::BI__hmma_m16n16k16_ld_c_f32: { + case NVPTX::BI__hmma_m16n16k16_ld_c_f32: + case NVPTX::BI__hmma_m32n8k16_ld_a: + case NVPTX::BI__hmma_m32n8k16_ld_b: + case NVPTX::BI__hmma_m32n8k16_ld_c_f16: + case NVPTX::BI__hmma_m32n8k16_ld_c_f32: + case NVPTX::BI__hmma_m8n32k16_ld_a: + case NVPTX::BI__hmma_m8n32k16_ld_b: + case NVPTX::BI__hmma_m8n32k16_ld_c_f16: + case NVPTX::BI__hmma_m8n32k16_ld_c_f32: { Address Dst = EmitPointerWithAlignment(E->getArg(0)); Value *Src = EmitScalarExpr(E->getArg(1)); Value *Ldm = EmitScalarExpr(E->getArg(2)); @@ -10746,6 +10754,46 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, : Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride; NumResults = 8; break; + case NVPTX::BI__hmma_m32n8k16_ld_a: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m32n8k16_ld_b: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m32n8k16_ld_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m32n8k16_ld_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m8n32k16_ld_a: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m8n32k16_ld_b: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m8n32k16_ld_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m8n32k16_ld_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride; + NumResults = 8; + break; default: llvm_unreachable("Unexpected builtin ID."); } @@ -10764,7 +10812,11 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, } case NVPTX::BI__hmma_m16n16k16_st_c_f16: - case NVPTX::BI__hmma_m16n16k16_st_c_f32: { + case NVPTX::BI__hmma_m16n16k16_st_c_f32: + case NVPTX::BI__hmma_m32n8k16_st_c_f16: + case NVPTX::BI__hmma_m32n8k16_st_c_f32: + case NVPTX::BI__hmma_m8n32k16_st_c_f16: + case NVPTX::BI__hmma_m8n32k16_st_c_f32: { Value *Dst = EmitScalarExpr(E->getArg(0)); Address Src = EmitPointerWithAlignment(E->getArg(1)); Value *Ldm = EmitScalarExpr(E->getArg(2)); @@ -10786,6 +10838,24 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride : Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride; break; + case NVPTX::BI__hmma_m32n8k16_st_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m32n8k16_st_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride; + break; + case NVPTX::BI__hmma_m8n32k16_st_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m8n32k16_st_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride; + break; default: llvm_unreachable("Unexpected builtin ID."); } @@ -10808,7 +10878,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, case NVPTX::BI__hmma_m16n16k16_mma_f16f16: case NVPTX::BI__hmma_m16n16k16_mma_f32f16: case NVPTX::BI__hmma_m16n16k16_mma_f32f32: - case NVPTX::BI__hmma_m16n16k16_mma_f16f32: { + case NVPTX::BI__hmma_m16n16k16_mma_f16f32: + case NVPTX::BI__hmma_m32n8k16_mma_f16f16: + case NVPTX::BI__hmma_m32n8k16_mma_f32f16: + case NVPTX::BI__hmma_m32n8k16_mma_f32f32: + case NVPTX::BI__hmma_m32n8k16_mma_f16f32: + case NVPTX::BI__hmma_m8n32k16_mma_f16f16: + case NVPTX::BI__hmma_m8n32k16_mma_f32f16: + case NVPTX::BI__hmma_m8n32k16_mma_f32f32: + case NVPTX::BI__hmma_m8n32k16_mma_f16f32: { Address Dst = EmitPointerWithAlignment(E->getArg(0)); Address SrcA = EmitPointerWithAlignment(E->getArg(1)); Address SrcB = EmitPointerWithAlignment(E->getArg(2)); @@ -10825,15 +10903,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, bool Satf = SatfArg.getSExtValue(); // clang-format off -#define MMA_VARIANTS(type) {{ \ - Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type##_satfinite, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type##_satfinite, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type##_satfinite, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type##_satfinite \ +#define MMA_VARIANTS(geom, type) {{ \ + Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \ + Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite \ }} // clang-format on @@ -10847,22 +10925,62 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, unsigned NumEltsD; switch (BuiltinID) { case NVPTX::BI__hmma_m16n16k16_mma_f16f16: - IID = getMMAIntrinsic(MMA_VARIANTS(f16_f16)); + IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f16)); NumEltsC = 4; NumEltsD = 4; break; case NVPTX::BI__hmma_m16n16k16_mma_f32f16: - IID = getMMAIntrinsic(MMA_VARIANTS(f32_f16)); + IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f16)); NumEltsC = 4; NumEltsD = 8; break; case NVPTX::BI__hmma_m16n16k16_mma_f16f32: - IID = getMMAIntrinsic(MMA_VARIANTS(f16_f32)); + IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f32)); NumEltsC = 8; NumEltsD = 4; break; case NVPTX::BI__hmma_m16n16k16_mma_f32f32: - IID = getMMAIntrinsic(MMA_VARIANTS(f32_f32)); + IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f32)); + NumEltsC = 8; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m32n8k16_mma_f16f16: + IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f16)); + NumEltsC = 4; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m32n8k16_mma_f32f16: + IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f16)); + NumEltsC = 4; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m32n8k16_mma_f16f32: + IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f32)); + NumEltsC = 8; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m32n8k16_mma_f32f32: + IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f32)); + NumEltsC = 8; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m8n32k16_mma_f16f16: + IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f16)); + NumEltsC = 4; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m8n32k16_mma_f32f16: + IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f16)); + NumEltsC = 4; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m8n32k16_mma_f16f32: + IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f32)); + NumEltsC = 8; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m8n32k16_mma_f32f32: + IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f32)); NumEltsC = 8; NumEltsD = 8; break; |

