path: root/llvm/lib/Target
author    Artem Belevich <tra@google.com>    2017-10-12 18:27:55 +0000
committer Artem Belevich <tra@google.com>    2017-10-12 18:27:55 +0000
commit    3bafc2f0d9a9b392d77c3e51daeb2a83cc10f761 (patch)
tree      7178d54b7f77001f4c93a9eb120816e5d139ccc5 /llvm/lib/Target
parent    1a7e3878494efbf8dbc1c0ead042a8273e5b8229 (diff)
download  bcm5719-llvm-3bafc2f0d9a9b392d77c3e51daeb2a83cc10f761.tar.gz
          bcm5719-llvm-3bafc2f0d9a9b392d77c3e51daeb2a83cc10f761.zip
[NVPTX] Implemented wmma intrinsics and instructions.
WMMA = "Warp Level Matrix Multiply-Accumulate". These are the new instructions introduced in PTX 6.0 and available on sm_70 GPUs.

Differential Revision: https://reviews.llvm.org/D38645

llvm-svn: 315601
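For context, a rough sketch of how one of the new intrinsics might be exercised from LLVM IR. The mangled intrinsic name, the aggregate return type, and the operand list are assumptions inferred from this patch (v8f16 memVT, Float16x2Regs results, pointer plus i32 stride), not verified against IntrinsicsNVVM.td:

    ; Hypothetical usage sketch: load an f16 A fragment with an explicit
    ; leading-dimension stride. The selection code added below maps such a
    ; call onto one of the INT_WMMA_LOAD_A_* machine instructions, chosen by
    ; the addressing mode of the pointer operand.
    declare { <2 x half>, <2 x half>, <2 x half>, <2 x half>,
              <2 x half>, <2 x half>, <2 x half>, <2 x half> }
      @llvm.nvvm.wmma.load.a.f16.row.stride(i8*, i32)

    define void @use_wmma_load_a(i8* %src, i32 %ldm) {
      %frag.a = call { <2 x half>, <2 x half>, <2 x half>, <2 x half>,
                       <2 x half>, <2 x half>, <2 x half>, <2 x half> }
          @llvm.nvvm.wmma.load.a.f16.row.stride(i8* %src, i32 %ldm)
      ret void
    }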
Diffstat (limited to 'llvm/lib/Target')
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp  512
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h      2
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  126
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td     205
4 files changed, 845 insertions(+), 0 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 2f389860d14..a7e58fa9738 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -496,8 +496,318 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
SelectCode(N);
}
+// Each instruction has four addressing variants. The WMMA_VARIANTS() macro
+// below constructs an array indexed by WmmaVariant, which getWmmaLdVariant()
+// uses to look up the opcode of a particular variant.
+enum WmmaVariant {
+ WMMA_VARIANT_ARI64,
+ WMMA_VARIANT_ARI64_STRIDE,
+ WMMA_VARIANT_AVAR,
+ WMMA_VARIANT_AVAR_STRIDE,
+};
+
+// clang-format off
+#define WMMA_VARIANTS(base) \
+ {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }}
+// clang-format on
+
+static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride,
+ const std::array<unsigned, 4> Variants) {
+ if (Stride) {
+ if (Variant == WMMA_VARIANT_ARI64)
+ Variant = WMMA_VARIANT_ARI64_STRIDE;
+ else if (Variant == WMMA_VARIANT_AVAR)
+ Variant = WMMA_VARIANT_AVAR_STRIDE;
+ }
+ return Variants[Variant];
+}
+
+static Optional<unsigned>
+getWmmaLdStOpcode(unsigned IntrinsicID,
+ WmmaVariant Variant = WMMA_VARIANT_ARI64) {
+ switch (IntrinsicID) {
+ default:
+ return None;
+ //
+ // WMMA_LOAD_A f16
+ //
+ case Intrinsic::nvvm_wmma_load_a_f16_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
+ case Intrinsic::nvvm_wmma_load_a_f16_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_global:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_global:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
+
+ //
+ // WMMA_LOAD_B f16
+ //
+ case Intrinsic::nvvm_wmma_load_b_f16_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
+ case Intrinsic::nvvm_wmma_load_b_f16_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_global:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_global:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
+
+ //
+ // WMMA_LOAD_C f16
+ //
+ case Intrinsic::nvvm_wmma_load_c_f16_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
+ case Intrinsic::nvvm_wmma_load_c_f16_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
+
+ //
+ // WMMA_LOAD_C f32
+ //
+ case Intrinsic::nvvm_wmma_load_c_f32_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
+ case Intrinsic::nvvm_wmma_load_c_f32_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
+
+ //
+ // WMMA_STORE_D f16
+ //
+ case Intrinsic::nvvm_wmma_store_d_f16_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
+ case Intrinsic::nvvm_wmma_store_d_f16_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
+
+ //
+ // WMMA_STORE_D f32
+ //
+ case Intrinsic::nvvm_wmma_store_d_f32_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
+ case Intrinsic::nvvm_wmma_store_d_f32_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
+ }
+}
+#undef WMMA_VARIANTS
+
bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+ if (getWmmaLdStOpcode(IID))
+ return tryWMMA_LDST(N);
+
switch (IID) {
default:
return false;
@@ -719,6 +1029,39 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
case Intrinsic::nvvm_match_all_sync_i64p:
SelectMatchAll(N);
return true;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
+ return tryWMMA_MMA(N);
}
}
@@ -3725,3 +4068,172 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
}
}
}
+
+bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) {
+ SDValue Chain = N->getOperand(0);
+ unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+ SDValue Op1 = N->getOperand(2);
+ SDValue Addr, Offset, Base;
+ Optional<unsigned> Opcode;
+ SDLoc DL(N);
+ MemSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
+ WmmaVariant Variant;
+ SmallVector<SDValue, 12> Ops;
+ bool isStore = N->getNumValues() == 1; // Store ops only return a chain.
+
+ if (SelectDirectAddr(Op1, Addr)) {
+ Variant = WMMA_VARIANT_AVAR;
+ Ops.push_back(Addr);
+ } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) ||
+ SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) {
+ Variant = WMMA_VARIANT_ARI64;
+ Ops.push_back(Base);
+ Ops.push_back(Offset);
+ } else {
+ Variant = WMMA_VARIANT_AVAR;
+ Ops.push_back(Op1);
+ }
+ unsigned NumOps = N->getNumOperands();
+ // Pass through the rest of the operands to the machine node.
+ for (unsigned i = 3; i < NumOps; ++i)
+ Ops.push_back(N->getOperand(i));
+ Ops.push_back(Chain);
+
+ Opcode = getWmmaLdStOpcode(IID, Variant);
+ if (!Opcode) {
+ llvm::errs() << "tryWMMALD - no Opcode.\n";
+ return false;
+ }
+
+ EVT MemVT = MemSD->getMemoryVT();
+ assert(MemVT.isVector() && "Expected vector return type.");
+
+ SDNode *MN;
+ if (isStore) {
+ MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops);
+ } else {
+ SmallVector<EVT, 9> InstVTs(MemVT.getVectorNumElements(),
+ MemSD->getValueType(0));
+ InstVTs.push_back(MVT::Other);
+ MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops);
+ }
+
+ ReplaceNode(N, MN);
+ return true;
+}
+
+bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) {
+ unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
+ SDLoc DL(N);
+ unsigned Opc;
+
+ switch (IID) {
+ default:
+ return false;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite;
+ break;
+ }
+
+ SmallVector<SDValue, 24> Ops;
+ // Pass through operands and return value types to the machine node.
+ for (unsigned i = 1; i < N->getNumOperands(); ++i)
+ Ops.push_back(N->getOperand(i));
+ SmallVector<EVT, 8> InstVTs(N->getNumValues(), N->getValueType(0));
+ SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops);
+ ReplaceNode(N, MN);
+ return true;
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 3ce7843b72f..b23c27581a1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -74,6 +74,8 @@ private:
bool tryConstantFP16(SDNode *N);
bool SelectSETP_F16X2(SDNode *N);
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
+ bool tryWMMA_LDST(SDNode *N);
+ bool tryWMMA_MMA(SDNode *N);
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 150e67a833f..7b9acb20b75 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3321,6 +3321,132 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
switch (Intrinsic) {
default:
return false;
+ case Intrinsic::nvvm_wmma_load_a_f16_col:
+ case Intrinsic::nvvm_wmma_load_a_f16_row:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_global:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_global:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_col:
+ case Intrinsic::nvvm_wmma_load_b_f16_row:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_global:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_global:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v8f16;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = true;
+ Info.writeMem = false;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_load_c_f16_col:
+ case Intrinsic::nvvm_wmma_load_c_f16_row:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_global:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_global:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v4f16;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = true;
+ Info.writeMem = false;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_load_c_f32_col:
+ case Intrinsic::nvvm_wmma_load_c_f32_row:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_global:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_global:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v8f32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = true;
+ Info.writeMem = false;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_store_d_f16_col:
+ case Intrinsic::nvvm_wmma_store_d_f16_row:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_global:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_global:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v4f16;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = false;
+ Info.writeMem = true;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_store_d_f32_col:
+ case Intrinsic::nvvm_wmma_store_d_f32_row:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_global:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_global:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v8f32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = false;
+ Info.writeMem = true;
+ Info.align = 16;
+ return true;
+ }
case Intrinsic::nvvm_atomic_load_add_f32:
case Intrinsic::nvvm_atomic_load_inc_32:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 11ebaaa5407..f745b6f6635 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7368,3 +7368,208 @@ def INT_PTX_SREG_PM3 : PTX_READ_SREG_R32<"pm3", int_nvvm_read_ptx_sreg_pm3>;
def INT_PTX_SREG_WARPSIZE :
NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
[(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;
+
+//
+// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
+//
+class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
+ string Type, NVPTXRegClass regclass,
+ Operand SrcOp, int WithOffset, int WithStride>
+ : NVPTXInst<!if(!eq(Abc#Type,"cf16"),
+ (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
+ (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)),
+ !if(WithStride,
+ !if(WithOffset,
+ (ins SrcOp:$src, i32imm:$offset, Int32Regs:$ldm),
+ (ins SrcOp:$src, Int32Regs:$ldm)),
+ !if(WithOffset,
+ (ins SrcOp:$src, i32imm:$offset),
+ (ins SrcOp:$src))),
+ "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
+ #!if(!eq(Abc#Type,"cf16"),
+ "{{$r0, $r1, $r2, $r3}}",
+ "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+ #", "
+ #!if(WithOffset,"[$src+$offset]", "[$src]")
+ #!if(WithStride, ", $ldm", "")
+ #";",
+ []>,
+ Requires<[hasPTX60, hasSM70]>;
+
+multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
+ string Type, NVPTXRegClass regclass,
+ Operand SrcOp, int WithOffset = 0> {
+ def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
+ WithOffset, 1>;
+ def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
+ WithOffset, 0>;
+}
+
+multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
+ string Type, NVPTXRegClass regclass> {
+ defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 0>;
+ defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 1>;
+}
+
+multiclass WMMA_LOAD_ALT<string Abc, string Layout,
+ string Type, NVPTXRegClass regclass> {
+ defm _global: WMMA_LOAD_ALST<Abc, Layout, ".global", Type, regclass>;
+ defm _shared: WMMA_LOAD_ALST<Abc, Layout, ".shared", Type, regclass>;
+ defm NAME: WMMA_LOAD_ALST<Abc, Layout, "", Type, regclass>;
+}
+
+multiclass WMMA_LOAD_AT<string Abc, string Type, NVPTXRegClass regclass> {
+ defm _row: WMMA_LOAD_ALT<Abc, "row", Type, regclass>;
+ defm _col: WMMA_LOAD_ALT<Abc, "col", Type, regclass>;
+}
+
+defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
+
+//
+// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
+//
+class WMMA_STORE_D_LSTOS<string Layout, string Space,
+ string Type, NVPTXRegClass regclass,
+ Operand DstOp, int WithOffset, int WithStride>
+ : NVPTXInst<(outs),
+ !if(!eq(Type,"f16"),
+ !if(WithStride,
+ !if(WithOffset,
+ (ins DstOp:$src, i32imm:$offset,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ Int32Regs:$ldm),
+ (ins DstOp:$src,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ Int32Regs:$ldm)),
+ !if(WithOffset,
+ (ins DstOp:$src, i32imm:$offset,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
+ (ins DstOp:$src,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))),
+ !if(WithStride,
+ !if(WithOffset,
+ (ins DstOp:$src, i32imm:$offset,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
+ Int32Regs:$ldm),
+ (ins DstOp:$src,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
+ Int32Regs:$ldm)),
+ !if(WithOffset,
+ (ins DstOp:$src, i32imm:$offset,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7),
+ (ins DstOp:$src,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))),
+ "wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
+ #!if(WithOffset,"[$src+$offset], ", "[$src], ")
+ #!if(!eq(Type,"f16"),
+ "{{$r0, $r1, $r2, $r3}}",
+ "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+ #!if(WithStride, ", $ldm", "")
+ #";",
+ []>,
+ Requires<[hasPTX60, hasSM70]>;
+
+multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
+ string Type, NVPTXRegClass regclass,
+ Operand DstOp, int WithOffset = 0> {
+ def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
+ WithOffset, 1>;
+ def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
+ WithOffset, 0>;
+}
+
+multiclass WMMA_STORE_D_LST<string Layout, string Space,
+ string Type, NVPTXRegClass regclass> {
+ defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 0>;
+ defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 1>;
+}
+
+multiclass WMMA_STORE_D_LT<string Layout,
+ string Type, NVPTXRegClass regclass> {
+ defm _global: WMMA_STORE_D_LST<Layout, ".global", Type, regclass>;
+ defm _shared: WMMA_STORE_D_LST<Layout, ".shared", Type, regclass>;
+ defm NAME: WMMA_STORE_D_LST<Layout, "", Type, regclass>;
+}
+
+multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> {
+ defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
+ defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
+}
+
+defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
+defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>;
+
+// WMMA.MMA
+class WMMA_MMA_ABDCS<string ALayout, string BLayout,
+ string DType, NVPTXRegClass d_reg,
+ string CType, NVPTXRegClass c_reg,
+ NVPTXRegClass ab_reg,
+ string Satfinite = "">
+ : NVPTXInst<!if(!eq(DType,"f16"),
+ (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
+ (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
+ d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)),
+ !if(!eq(CType,"f16"),
+ (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
+ ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
+ ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
+ ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
+ c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
+ (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
+ ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
+ ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
+ ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
+ c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3,
+ c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)),
+ "wmma.mma.sync."#ALayout#"."#BLayout#".m16n16k16."#
+ #DType#"."#CType#Satfinite
+ #"\n\t\t"
+ #!if(!eq(DType,"f16"),
+ "{{$d0, $d1, $d2, $d3}}, \n\t\t",
+ "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
+ #"{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
+ #"{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
+ #!if(!eq(CType,"f16"),
+ "{{$c0, $c1, $c2, $c3}};",
+ "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"),
+ []>,
+ Requires<[hasPTX60, hasSM70]>;
+
+multiclass WMMA_MMA_ABDC<string ALayout, string BLayout,
+ string DType, NVPTXRegClass d_reg,
+ string CType, NVPTXRegClass c_reg> {
+ def _satfinite: WMMA_MMA_ABDCS<ALayout, BLayout,
+ DType, d_reg, CType, c_reg,
+ Float16x2Regs, ".satfinite">;
+ def NAME: WMMA_MMA_ABDCS<ALayout, BLayout,
+ DType, d_reg, CType, c_reg,
+ Float16x2Regs>;
+}
+
+multiclass WMMA_MMA_ABD<string ALayout, string BLayout,
+ string DType, NVPTXRegClass d_reg> {
+ defm _f16: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f16", Float16x2Regs>;
+ defm _f32: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f32", Float32Regs>;
+}
+
+multiclass WMMA_MMA_AB<string ALayout, string BLayout> {
+ defm _f16: WMMA_MMA_ABD<ALayout, BLayout, "f16", Float16x2Regs>;
+ defm _f32: WMMA_MMA_ABD<ALayout, BLayout, "f32", Float32Regs>;
+}
+
+multiclass WMMA_MMA_A<string ALayout> {
+ defm _col: WMMA_MMA_AB<ALayout, "col">;
+ defm _row: WMMA_MMA_AB<ALayout, "row">;
+}
+
+defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">;
+defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">;
+