diff options
Diffstat (limited to 'clang/lib')
| -rw-r--r-- | clang/lib/CodeGen/CGBuiltin.cpp | 150 | ||||
| -rw-r--r-- | clang/lib/Driver/ToolChains/Cuda.cpp | 22 |
2 files changed, 146 insertions, 26 deletions
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index fffc2429fb1..6a2f2b0a4a1 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -10715,7 +10715,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, case NVPTX::BI__hmma_m16n16k16_ld_a: case NVPTX::BI__hmma_m16n16k16_ld_b: case NVPTX::BI__hmma_m16n16k16_ld_c_f16: - case NVPTX::BI__hmma_m16n16k16_ld_c_f32: { + case NVPTX::BI__hmma_m16n16k16_ld_c_f32: + case NVPTX::BI__hmma_m32n8k16_ld_a: + case NVPTX::BI__hmma_m32n8k16_ld_b: + case NVPTX::BI__hmma_m32n8k16_ld_c_f16: + case NVPTX::BI__hmma_m32n8k16_ld_c_f32: + case NVPTX::BI__hmma_m8n32k16_ld_a: + case NVPTX::BI__hmma_m8n32k16_ld_b: + case NVPTX::BI__hmma_m8n32k16_ld_c_f16: + case NVPTX::BI__hmma_m8n32k16_ld_c_f32: { Address Dst = EmitPointerWithAlignment(E->getArg(0)); Value *Src = EmitScalarExpr(E->getArg(1)); Value *Ldm = EmitScalarExpr(E->getArg(2)); @@ -10746,6 +10754,46 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, : Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride; NumResults = 8; break; + case NVPTX::BI__hmma_m32n8k16_ld_a: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m32n8k16_ld_b: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m32n8k16_ld_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m32n8k16_ld_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m8n32k16_ld_a: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m8n32k16_ld_b: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride; + NumResults = 8; + break; + case NVPTX::BI__hmma_m8n32k16_ld_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m8n32k16_ld_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride; + NumResults = 8; + break; default: llvm_unreachable("Unexpected builtin ID."); } @@ -10764,7 +10812,11 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, } case NVPTX::BI__hmma_m16n16k16_st_c_f16: - case NVPTX::BI__hmma_m16n16k16_st_c_f32: { + case NVPTX::BI__hmma_m16n16k16_st_c_f32: + case NVPTX::BI__hmma_m32n8k16_st_c_f16: + case NVPTX::BI__hmma_m32n8k16_st_c_f32: + case NVPTX::BI__hmma_m8n32k16_st_c_f16: + case NVPTX::BI__hmma_m8n32k16_st_c_f32: { Value *Dst = EmitScalarExpr(E->getArg(0)); Address Src = EmitPointerWithAlignment(E->getArg(1)); Value *Ldm = EmitScalarExpr(E->getArg(2)); @@ -10786,6 +10838,24 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride : Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride; break; + case NVPTX::BI__hmma_m32n8k16_st_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m32n8k16_st_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride + : Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride; + break; + case NVPTX::BI__hmma_m8n32k16_st_c_f16: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride; + NumResults = 4; + break; + case NVPTX::BI__hmma_m8n32k16_st_c_f32: + IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride + : Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride; + break; default: llvm_unreachable("Unexpected builtin ID."); } @@ -10808,7 +10878,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, case NVPTX::BI__hmma_m16n16k16_mma_f16f16: case NVPTX::BI__hmma_m16n16k16_mma_f32f16: case NVPTX::BI__hmma_m16n16k16_mma_f32f32: - case NVPTX::BI__hmma_m16n16k16_mma_f16f32: { + case NVPTX::BI__hmma_m16n16k16_mma_f16f32: + case NVPTX::BI__hmma_m32n8k16_mma_f16f16: + case NVPTX::BI__hmma_m32n8k16_mma_f32f16: + case NVPTX::BI__hmma_m32n8k16_mma_f32f32: + case NVPTX::BI__hmma_m32n8k16_mma_f16f32: + case NVPTX::BI__hmma_m8n32k16_mma_f16f16: + case NVPTX::BI__hmma_m8n32k16_mma_f32f16: + case NVPTX::BI__hmma_m8n32k16_mma_f32f32: + case NVPTX::BI__hmma_m8n32k16_mma_f16f32: { Address Dst = EmitPointerWithAlignment(E->getArg(0)); Address SrcA = EmitPointerWithAlignment(E->getArg(1)); Address SrcB = EmitPointerWithAlignment(E->getArg(2)); @@ -10825,15 +10903,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, bool Satf = SatfArg.getSExtValue(); // clang-format off -#define MMA_VARIANTS(type) {{ \ - Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type##_satfinite, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type##_satfinite, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type##_satfinite, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type, \ - Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type##_satfinite \ +#define MMA_VARIANTS(geom, type) {{ \ + Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \ + Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type, \ + Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite \ }} // clang-format on @@ -10847,22 +10925,62 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, unsigned NumEltsD; switch (BuiltinID) { case NVPTX::BI__hmma_m16n16k16_mma_f16f16: - IID = getMMAIntrinsic(MMA_VARIANTS(f16_f16)); + IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f16)); NumEltsC = 4; NumEltsD = 4; break; case NVPTX::BI__hmma_m16n16k16_mma_f32f16: - IID = getMMAIntrinsic(MMA_VARIANTS(f32_f16)); + IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f16)); NumEltsC = 4; NumEltsD = 8; break; case NVPTX::BI__hmma_m16n16k16_mma_f16f32: - IID = getMMAIntrinsic(MMA_VARIANTS(f16_f32)); + IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f32)); NumEltsC = 8; NumEltsD = 4; break; case NVPTX::BI__hmma_m16n16k16_mma_f32f32: - IID = getMMAIntrinsic(MMA_VARIANTS(f32_f32)); + IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f32)); + NumEltsC = 8; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m32n8k16_mma_f16f16: + IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f16)); + NumEltsC = 4; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m32n8k16_mma_f32f16: + IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f16)); + NumEltsC = 4; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m32n8k16_mma_f16f32: + IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f32)); + NumEltsC = 8; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m32n8k16_mma_f32f32: + IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f32)); + NumEltsC = 8; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m8n32k16_mma_f16f16: + IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f16)); + NumEltsC = 4; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m8n32k16_mma_f32f16: + IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f16)); + NumEltsC = 4; + NumEltsD = 8; + break; + case NVPTX::BI__hmma_m8n32k16_mma_f16f32: + IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f32)); + NumEltsC = 8; + NumEltsD = 4; + break; + case NVPTX::BI__hmma_m8n32k16_mma_f32f32: + IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f32)); NumEltsC = 8; NumEltsD = 8; break; diff --git a/clang/lib/Driver/ToolChains/Cuda.cpp b/clang/lib/Driver/ToolChains/Cuda.cpp index fc74a8ebbf0..fdd62fec390 100644 --- a/clang/lib/Driver/ToolChains/Cuda.cpp +++ b/clang/lib/Driver/ToolChains/Cuda.cpp @@ -622,17 +622,19 @@ void CudaToolChain::addClangTargetOptions( CC1Args.push_back("-mlink-cuda-bitcode"); CC1Args.push_back(DriverArgs.MakeArgString(LibDeviceFile)); - if (CudaInstallation.version() >= CudaVersion::CUDA_90) { - // CUDA-9 uses new instructions that are only available in PTX6.0 - CC1Args.push_back("-target-feature"); - CC1Args.push_back("+ptx60"); - } else { - // Libdevice in CUDA-7.0 requires PTX version that's more recent - // than LLVM defaults to. Use PTX4.2 which is the PTX version that - // came with CUDA-7.0. - CC1Args.push_back("-target-feature"); - CC1Args.push_back("+ptx42"); + // Libdevice in CUDA-7.0 requires PTX version that's more recent than LLVM + // defaults to. Use PTX4.2 by default, which is the PTX version that came with + // CUDA-7.0. + const char *PtxFeature = "+ptx42"; + if (CudaInstallation.version() >= CudaVersion::CUDA_91) { + // CUDA-9.1 uses new instructions that are only available in PTX6.1+ + PtxFeature = "+ptx61"; + } else if (CudaInstallation.version() >= CudaVersion::CUDA_90) { + // CUDA-9.0 uses new instructions that are only available in PTX6.0+ + PtxFeature = "+ptx60"; } + CC1Args.push_back("-target-feature"); + CC1Args.push_back(PtxFeature); if (DeviceOffloadingKind == Action::OFK_OpenMP) { SmallVector<StringRef, 8> LibraryPaths; |

