summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
authorArtem Belevich <tra@google.com>2018-04-18 21:51:48 +0000
committerArtem Belevich <tra@google.com>2018-04-18 21:51:48 +0000
commit0ae8590354b8688e1ec9926abc909b896ea49038 (patch)
tree3c803aae33ad4fd575d7d3672138519c7b20e0fc /llvm/lib
parentc310bfa19397e15903a8f5386b51366aade414b9 (diff)
downloadbcm5719-llvm-0ae8590354b8688e1ec9926abc909b896ea49038.tar.gz
bcm5719-llvm-0ae8590354b8688e1ec9926abc909b896ea49038.zip
[NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.
The new instructions were added for sm_70+ GPUs in CUDA-9.1. Differential Revision: https://reviews.llvm.org/D45068 llvm-svn: 330296
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/NVPTX/NVPTX.td5
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp58
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td1
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td38
4 files changed, 85 insertions, 17 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTX.td b/llvm/lib/Target/NVPTX/NVPTX.td
index 0cafb0ffdba..6494c46f54a 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.td
+++ b/llvm/lib/Target/NVPTX/NVPTX.td
@@ -52,6 +52,8 @@ def SM62 : SubtargetFeature<"sm_62", "SmVersion", "62",
"Target SM 6.2">;
def SM70 : SubtargetFeature<"sm_70", "SmVersion", "70",
"Target SM 7.0">;
+def SM72 : SubtargetFeature<"sm_72", "SmVersion", "72",
+ "Target SM 7.2">;
// PTX Versions
def PTX32 : SubtargetFeature<"ptx32", "PTXVersion", "32",
@@ -68,6 +70,8 @@ def PTX50 : SubtargetFeature<"ptx50", "PTXVersion", "50",
"Use PTX version 5.0">;
def PTX60 : SubtargetFeature<"ptx60", "PTXVersion", "60",
"Use PTX version 6.0">;
+def PTX61 : SubtargetFeature<"ptx61", "PTXVersion", "61",
+ "Use PTX version 6.1">;
//===----------------------------------------------------------------------===//
// NVPTX supported processors.
@@ -89,6 +93,7 @@ def : Proc<"sm_60", [SM60, PTX50]>;
def : Proc<"sm_61", [SM61, PTX50]>;
def : Proc<"sm_62", [SM62, PTX50]>;
def : Proc<"sm_70", [SM70, PTX60]>;
+def : Proc<"sm_72", [SM72, PTX61]>;
def NVPTXInstrInfo : InstrInfo {
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 79d17334655..5928bb8df66 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3329,7 +3329,23 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f16;
Info.ptrVal = I.getArgOperand(0);
@@ -3342,7 +3358,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
@@ -3355,7 +3379,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
@@ -3368,7 +3400,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
@@ -3381,7 +3421,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index aed8266a1e4..7b2bf386d62 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -142,6 +142,7 @@ def true : Predicate<"true">;
def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
+def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index b46247f21c9..66419f034f6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7378,7 +7378,11 @@ class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass,
DAGOperand SrcOp, bit WithStride>
- : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ : EmptyNVPTXInst,
+ Requires<[!if(!eq(Geometry, "m16n16k16"),
+ hasPTX60,
+ hasPTX61),
+ hasSM70]> {
// Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
// for this function.
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
@@ -7420,10 +7424,10 @@ class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
let InOperandList = Ins;
let AsmString = "wmma.load."
# Abc
- # ".sync."
- # Layout
- # ".m16n16k16"
- # Space
+ # ".sync"
+ # "." # Layout
+ # "." # Geometry
+ # Space
# "." # Type # " \t"
# !if(!eq(Abc#Type, "cf16"),
"{{$r0, $r1, $r2, $r3}}",
@@ -7512,7 +7516,9 @@ multiclass WMMA_LOAD_G<string Geometry> {
defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
}
+defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">;
defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
+defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">;
//
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
@@ -7520,7 +7526,11 @@ defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride, DAGOperand DstOp>
- : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ : EmptyNVPTXInst,
+ Requires<[!if(!eq(Geometry, "m16n16k16"),
+ hasPTX60,
+ hasPTX61),
+ hasSM70]> {
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
# "_" # Geometry # "_store_d"
# "_" # Type
@@ -7641,11 +7651,9 @@ multiclass WMMA_STORE_D_G<string Geometry> {
defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
}
-// multiclass WMMA_STORE_D {
-// defm _m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
-// }
-
+defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">;
defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
+defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">;
// WMMA.MMA
class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
@@ -7653,7 +7661,11 @@ class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
string CType, NVPTXRegClass c_reg,
NVPTXRegClass ab_reg,
string Satfinite = "">
- : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
+ : EmptyNVPTXInst,
+ Requires<[!if(!eq(Geometry, "m16n16k16"),
+ hasPTX60,
+ hasPTX61),
+ hasSM70]> {
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
# Geometry
# "_mma"
@@ -7686,7 +7698,7 @@ class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
let AsmString = "wmma.mma.sync."
# ALayout
# "." # BLayout
- # ".m16n16k16"
+ # "." # Geometry
# "." # DType
# "." # CType
# Satfinite # "\n\t\t"
@@ -7734,4 +7746,6 @@ multiclass WMMA_MMA_G<string Geometry> {
defm _row: WMMA_MMA_GA<Geometry, "row">;
}
+defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">;
defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
+defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">;
OpenPOWER on IntegriCloud