| author | Artem Belevich <tra@google.com> | 2018-04-18 21:51:48 +0000 |
|---|---|---|
| committer | Artem Belevich <tra@google.com> | 2018-04-18 21:51:48 +0000 |
| commit | 0ae8590354b8688e1ec9926abc909b896ea49038 (patch) | |
| tree | 3c803aae33ad4fd575d7d3672138519c7b20e0fc /llvm/lib | |
| parent | c310bfa19397e15903a8f5386b51366aade414b9 (diff) | |
[NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.
The new instructions were added for sm_70+ GPUs in CUDA-9.1.
Differential Revision: https://reviews.llvm.org/D45068
llvm-svn: 330296
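
For context, m32n8k16 and m8n32k16 are the additional WMMA fragment geometries that CUDA 9.1 exposes for Volta. The sketch below is not part of this commit and uses illustrative kernel and parameter names; it shows how one of the new shapes is reached from the CUDA-side WMMA API, which the front end lowers to the llvm.nvvm.wmma.m32n8k16.* intrinsics that this patch teaches the NVPTX backend to select.

```cuda
// A minimal sketch (not part of this commit): exercising the new m32n8k16
// geometry through the CUDA 9.1 WMMA API. Names are illustrative only.
// All 32 threads of a warp must execute these calls together.
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

__global__ void wmma_m32n8k16_sketch(const half *a, const half *b, float *d) {
  // Fragments for a 32x8x16 tile: A is 32x16, B is 16x8, C/D are 32x8.
  wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> fa;
  wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> fb;
  wmma::fragment<wmma::accumulator, 32, 8, 16, float> facc;

  wmma::fill_fragment(facc, 0.0f);     // start from a zero accumulator
  wmma::load_matrix_sync(fa, a, 16);   // -> wmma.load.a.sync.*.m32n8k16
  wmma::load_matrix_sync(fb, b, 16);   // -> wmma.load.b.sync.*.m32n8k16
  wmma::mma_sync(facc, fa, fb, facc);  // -> wmma.mma.sync.*.m32n8k16
  wmma::store_matrix_sync(d, facc, 8, wmma::mem_row_major);  // -> wmma.store.d.sync
}
```

Note that the patch gates the new geometries on PTX 6.1 (the PTX ISA version shipped with CUDA 9.1) via the hasPTX61 predicate, while m16n16k16 keeps its existing PTX 6.0 requirement; all geometries still require sm_70.
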
Diffstat (limited to 'llvm/lib')
| Mode | Path | Lines changed |
|---|---|---|
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTX.td | 5 |
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 58 |
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 1 |
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 38 |

4 files changed, 85 insertions, 17 deletions
```diff
diff --git a/llvm/lib/Target/NVPTX/NVPTX.td b/llvm/lib/Target/NVPTX/NVPTX.td
index 0cafb0ffdba..6494c46f54a 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.td
+++ b/llvm/lib/Target/NVPTX/NVPTX.td
@@ -52,6 +52,8 @@ def SM62 : SubtargetFeature<"sm_62", "SmVersion", "62",
                             "Target SM 6.2">;
 def SM70 : SubtargetFeature<"sm_70", "SmVersion", "70",
                             "Target SM 7.0">;
+def SM72 : SubtargetFeature<"sm_72", "SmVersion", "72",
+                            "Target SM 7.2">;
 
 // PTX Versions
 def PTX32 : SubtargetFeature<"ptx32", "PTXVersion", "32",
@@ -68,6 +70,8 @@ def PTX50 : SubtargetFeature<"ptx50", "PTXVersion", "50",
                              "Use PTX version 5.0">;
 def PTX60 : SubtargetFeature<"ptx60", "PTXVersion", "60",
                              "Use PTX version 6.0">;
+def PTX61 : SubtargetFeature<"ptx61", "PTXVersion", "61",
+                             "Use PTX version 6.1">;
 
 //===----------------------------------------------------------------------===//
 // NVPTX supported processors.
@@ -89,6 +93,7 @@ def : Proc<"sm_60", [SM60, PTX50]>;
 def : Proc<"sm_61", [SM61, PTX50]>;
 def : Proc<"sm_62", [SM62, PTX50]>;
 def : Proc<"sm_70", [SM70, PTX60]>;
+def : Proc<"sm_72", [SM72, PTX61]>;
 
 def NVPTXInstrInfo : InstrInfo {
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 79d17334655..5928bb8df66 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3329,7 +3329,23 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
   case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3342,7 +3358,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3355,7 +3379,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3368,7 +3400,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3381,7 +3421,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
-  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index aed8266a1e4..7b2bf386d62 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -142,6 +142,7 @@ def true : Predicate<"true">;
 
 def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
 def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
+def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
 
 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
 def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index b46247f21c9..66419f034f6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7378,7 +7378,11 @@ class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
 class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
                         string Space, string Type, NVPTXRegClass regclass,
                         DAGOperand SrcOp, bit WithStride>
-  : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+  : EmptyNVPTXInst,
+    Requires<[!if(!eq(Geometry, "m16n16k16"),
+                  hasPTX60,
+                  hasPTX61),
+              hasSM70]> {
   // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
   // for this function.
   PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
@@ -7420,10 +7424,10 @@ class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
   let InOperandList = Ins;
   let AsmString = "wmma.load."
                   # Abc
-                  # ".sync."
-                  # Layout
-                  # ".m16n16k16"
-                  # Space
+                  # ".sync"
+                  # "." # Layout
+                  # "." # Geometry
+                  # Space
                   # "." # Type # " \t"
                   # !if(!eq(Abc#Type, "cf16"),
                        "{{$r0, $r1, $r2, $r3}}",
@@ -7512,7 +7516,9 @@ multiclass WMMA_LOAD_G<string Geometry> {
   defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
 }
 
+defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">;
 defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
+defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">;
 
 //
 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
@@ -7520,7 +7526,11 @@ defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
 class WMMA_STORE_D_GLSTSO<string Geometry, string Layout,
                           string Space, string Type, NVPTXRegClass regclass,
                           bit WithStride, DAGOperand DstOp>
-  : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+  : EmptyNVPTXInst,
+    Requires<[!if(!eq(Geometry, "m16n16k16"),
+                  hasPTX60,
+                  hasPTX61),
+              hasSM70]> {
   PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
                                        # "_" # Geometry # "_store_d"
                                        # "_" # Type
@@ -7641,11 +7651,9 @@ multiclass WMMA_STORE_D_G<string Geometry> {
   defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
 }
 
-// multiclass WMMA_STORE_D {
-//   defm _m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
-// }
-
+defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">;
 defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
+defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">;
 
 // WMMA.MMA
 class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
@@ -7653,7 +7661,11 @@ class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
                       string DType, NVPTXRegClass d_reg,
                       string CType, NVPTXRegClass c_reg, NVPTXRegClass ab_reg,
                       string Satfinite = "">
-  : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+  : EmptyNVPTXInst,
+    Requires<[!if(!eq(Geometry, "m16n16k16"),
+                  hasPTX60,
+                  hasPTX61),
+              hasSM70]> {
   Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
                                     # Geometry
                                     # "_mma"
@@ -7686,7 +7698,7 @@ class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
   let AsmString = "wmma.mma.sync."
                   # ALayout
                   # "." # BLayout
-                  # ".m16n16k16"
+                  # "." # Geometry
                   # "." # DType
                   # "." # CType # Satfinite
                   # "\n\t\t"
@@ -7734,4 +7746,6 @@ multiclass WMMA_MMA_G<string Geometry> {
   defm _row: WMMA_MMA_GA<Geometry, "row">;
 }
 
+defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">;
 defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
+defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">;
```

