summaryrefslogtreecommitdiffstats
path: root/clang/lib
diff options
context:
space:
mode:
authorArtem Belevich <tra@google.com>2018-04-18 21:51:48 +0000
committerArtem Belevich <tra@google.com>2018-04-18 21:51:48 +0000
commit0ae8590354b8688e1ec9926abc909b896ea49038 (patch)
tree3c803aae33ad4fd575d7d3672138519c7b20e0fc /clang/lib
parentc310bfa19397e15903a8f5386b51366aade414b9 (diff)
downloadbcm5719-llvm-0ae8590354b8688e1ec9926abc909b896ea49038.tar.gz
bcm5719-llvm-0ae8590354b8688e1ec9926abc909b896ea49038.zip
[NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.
The new instructions were added added for sm_70+ GPUs in CUDA-9.1. Differential Revision: https://reviews.llvm.org/D45068 llvm-svn: 330296
Diffstat (limited to 'clang/lib')
-rw-r--r--clang/lib/CodeGen/CGBuiltin.cpp150
-rw-r--r--clang/lib/Driver/ToolChains/Cuda.cpp22
2 files changed, 146 insertions, 26 deletions
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index fffc2429fb1..6a2f2b0a4a1 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -10715,7 +10715,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
case NVPTX::BI__hmma_m16n16k16_ld_a:
case NVPTX::BI__hmma_m16n16k16_ld_b:
case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
- case NVPTX::BI__hmma_m16n16k16_ld_c_f32: {
+ case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
+ case NVPTX::BI__hmma_m32n8k16_ld_a:
+ case NVPTX::BI__hmma_m32n8k16_ld_b:
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
+ case NVPTX::BI__hmma_m8n32k16_ld_a:
+ case NVPTX::BI__hmma_m8n32k16_ld_b:
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f32: {
Address Dst = EmitPointerWithAlignment(E->getArg(0));
Value *Src = EmitScalarExpr(E->getArg(1));
Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -10746,6 +10754,46 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
: Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride;
NumResults = 8;
break;
+ case NVPTX::BI__hmma_m32n8k16_ld_a:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_ld_b:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride;
+ NumResults = 4;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_ld_a:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_ld_b:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride;
+ NumResults = 4;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride;
+ NumResults = 8;
+ break;
default:
llvm_unreachable("Unexpected builtin ID.");
}
@@ -10764,7 +10812,11 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
}
case NVPTX::BI__hmma_m16n16k16_st_c_f16:
- case NVPTX::BI__hmma_m16n16k16_st_c_f32: {
+ case NVPTX::BI__hmma_m16n16k16_st_c_f32:
+ case NVPTX::BI__hmma_m32n8k16_st_c_f16:
+ case NVPTX::BI__hmma_m32n8k16_st_c_f32:
+ case NVPTX::BI__hmma_m8n32k16_st_c_f16:
+ case NVPTX::BI__hmma_m8n32k16_st_c_f32: {
Value *Dst = EmitScalarExpr(E->getArg(0));
Address Src = EmitPointerWithAlignment(E->getArg(1));
Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -10786,6 +10838,24 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride
: Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride;
break;
+ case NVPTX::BI__hmma_m32n8k16_st_c_f16:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride;
+ NumResults = 4;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_st_c_f32:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_st_c_f16:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride;
+ NumResults = 4;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_st_c_f32:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride;
+ break;
default:
llvm_unreachable("Unexpected builtin ID.");
}
@@ -10808,7 +10878,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
- case NVPTX::BI__hmma_m16n16k16_mma_f16f32: {
+ case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f32: {
Address Dst = EmitPointerWithAlignment(E->getArg(0));
Address SrcA = EmitPointerWithAlignment(E->getArg(1));
Address SrcB = EmitPointerWithAlignment(E->getArg(2));
@@ -10825,15 +10903,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
bool Satf = SatfArg.getSExtValue();
// clang-format off
-#define MMA_VARIANTS(type) {{ \
- Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type##_satfinite, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type##_satfinite \
+#define MMA_VARIANTS(geom, type) {{ \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite \
}}
// clang-format on
@@ -10847,22 +10925,62 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
unsigned NumEltsD;
switch (BuiltinID) {
case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(f16_f16));
+ IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f16));
NumEltsC = 4;
NumEltsD = 4;
break;
case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(f32_f16));
+ IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f16));
NumEltsC = 4;
NumEltsD = 8;
break;
case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(f16_f32));
+ IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f32));
NumEltsC = 8;
NumEltsD = 4;
break;
case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(f32_f32));
+ IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f32));
+ NumEltsC = 8;
+ NumEltsD = 8;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f16));
+ NumEltsC = 4;
+ NumEltsD = 4;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f16));
+ NumEltsC = 4;
+ NumEltsD = 8;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f32));
+ NumEltsC = 8;
+ NumEltsD = 4;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f32));
+ NumEltsC = 8;
+ NumEltsD = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f16));
+ NumEltsC = 4;
+ NumEltsD = 4;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f16));
+ NumEltsC = 4;
+ NumEltsD = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f32));
+ NumEltsC = 8;
+ NumEltsD = 4;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f32));
NumEltsC = 8;
NumEltsD = 8;
break;
diff --git a/clang/lib/Driver/ToolChains/Cuda.cpp b/clang/lib/Driver/ToolChains/Cuda.cpp
index fc74a8ebbf0..fdd62fec390 100644
--- a/clang/lib/Driver/ToolChains/Cuda.cpp
+++ b/clang/lib/Driver/ToolChains/Cuda.cpp
@@ -622,17 +622,19 @@ void CudaToolChain::addClangTargetOptions(
CC1Args.push_back("-mlink-cuda-bitcode");
CC1Args.push_back(DriverArgs.MakeArgString(LibDeviceFile));
- if (CudaInstallation.version() >= CudaVersion::CUDA_90) {
- // CUDA-9 uses new instructions that are only available in PTX6.0
- CC1Args.push_back("-target-feature");
- CC1Args.push_back("+ptx60");
- } else {
- // Libdevice in CUDA-7.0 requires PTX version that's more recent
- // than LLVM defaults to. Use PTX4.2 which is the PTX version that
- // came with CUDA-7.0.
- CC1Args.push_back("-target-feature");
- CC1Args.push_back("+ptx42");
+ // Libdevice in CUDA-7.0 requires PTX version that's more recent than LLVM
+ // defaults to. Use PTX4.2 by default, which is the PTX version that came with
+ // CUDA-7.0.
+ const char *PtxFeature = "+ptx42";
+ if (CudaInstallation.version() >= CudaVersion::CUDA_91) {
+ // CUDA-9.1 uses new instructions that are only available in PTX6.1+
+ PtxFeature = "+ptx61";
+ } else if (CudaInstallation.version() >= CudaVersion::CUDA_90) {
+ // CUDA-9.0 uses new instructions that are only available in PTX6.0+
+ PtxFeature = "+ptx60";
}
+ CC1Args.push_back("-target-feature");
+ CC1Args.push_back(PtxFeature);
if (DeviceOffloadingKind == Action::OFK_OpenMP) {
SmallVector<StringRef, 8> LibraryPaths;
OpenPOWER on IntegriCloud