| Field | Value | Date |
|---|---|---|
| author | Artem Belevich <tra@google.com> | 2017-10-12 18:27:55 +0000 |
| committer | Artem Belevich <tra@google.com> | 2017-10-12 18:27:55 +0000 |
| commit | 3bafc2f0d9a9b392d77c3e51daeb2a83cc10f761 (patch) | |
| tree | 7178d54b7f77001f4c93a9eb120816e5d139ccc5 /llvm/lib/Target | |
| parent | 1a7e3878494efbf8dbc1c0ead042a8273e5b8229 (diff) | |
[NVPTX] Implemented wmma intrinsics and instructions.
WMMA = "Warp Level Matrix Multiply-Accumulate".
These are new instructions, introduced in PTX 6.0 and available on sm_70 GPUs.
Differential Revision: https://reviews.llvm.org/D38645
llvm-svn: 315601
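
For context, the PTX `wmma.*` instructions targeted by this patch are what back the warp-level matrix operations exposed in CUDA 9 through `<mma.h>`. The kernel below is a minimal sketch of that CUDA-side API, shown only to illustrate the operation these intrinsics describe; it is not part of the patch, and it sticks to the m16n16k16 fp16/fp32 shape handled throughout the diff.

```cuda
// Illustrative only: CUDA 9 WMMA API for one 16x16x16 tile, C = A * B + C,
// with fp16 inputs and an fp32 accumulator. Compile for sm_70.
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

__global__ void wmma_tile_16x16x16(const half *a, const half *b, float *c) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

  // Conceptually, these correspond to the wmma.load / wmma.mma / wmma.store
  // instructions this patch teaches the NVPTX backend to emit.
  wmma::load_matrix_sync(a_frag, a, 16);                       // ldm = 16
  wmma::load_matrix_sync(b_frag, b, 16);
  wmma::load_matrix_sync(c_frag, c, 16, wmma::mem_row_major);
  wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  wmma::store_matrix_sync(c, c_frag, 16, wmma::mem_row_major);
}
```

The leading-dimension argument passed to each load/store is what the explicit-stride (`_stride`) flavor of the new intrinsics carries as an extra operand.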
Diffstat (limited to 'llvm/lib/Target')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 512 |
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h | 2 |
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 126 |
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 205 |
4 files changed, 845 insertions, 0 deletions
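
Most of the NVPTXISelDAGToDAG.cpp addition is a mechanical table: for every `nvvm_wmma_*` load/store intrinsic, selection picks one of four machine opcodes depending on how the address operand was matched (direct symbol vs. register+offset) and whether an explicit stride operand is present. The standalone sketch below mirrors that lookup pattern from the diff; it is plain host-side C++ (also valid CUDA), and the opcode numbers are hypothetical placeholders for the real `NVPTX::INT_WMMA_*` enumerators. The full diff follows.

```cuda
// Simplified, self-contained sketch of the addressing-variant lookup used by
// getWmmaLdStOpcode() in this patch. Opcode values are placeholders, not the
// real NVPTX::INT_WMMA_* enumerators.
#include <array>
#include <cassert>

enum WmmaVariant {
  WMMA_VARIANT_ARI64,         // [reg+imm] address
  WMMA_VARIANT_ARI64_STRIDE,  // [reg+imm] address, explicit stride operand
  WMMA_VARIANT_AVAR,          // direct-symbol address
  WMMA_VARIANT_AVAR_STRIDE,   // direct-symbol address, explicit stride operand
};

// One row per (intrinsic, layout, address space); columns are the variants above.
using VariantRow = std::array<unsigned, 4>;

// Promote the variant to its "_stride" twin when the intrinsic carries an
// explicit leading-dimension operand, then index into the row.
static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride,
                                 const VariantRow &Variants) {
  if (Stride) {
    if (Variant == WMMA_VARIANT_ARI64)
      Variant = WMMA_VARIANT_ARI64_STRIDE;
    else if (Variant == WMMA_VARIANT_AVAR)
      Variant = WMMA_VARIANT_AVAR_STRIDE;
  }
  return Variants[Variant];
}

int main() {
  // Hypothetical opcodes for one intrinsic, e.g. wmma.load.a.sync.col.m16n16k16.f16.
  const VariantRow LoadAColF16 = {{100, 101, 102, 103}};
  assert(getWmmaLdVariant(WMMA_VARIANT_ARI64, /*Stride=*/false, LoadAColF16) == 100);
  assert(getWmmaLdVariant(WMMA_VARIANT_AVAR, /*Stride=*/true, LoadAColF16) == 103);
  return 0;
}
```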
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 2f389860d14..a7e58fa9738 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -496,8 +496,318 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) { SelectCode(N); } +// Each instruction has four addressing variants. WMMA_VARIANTS() macro below +// constructs an array indexed by WmmaVariant which getWmmaLdVariant() uses to +// look up the intrinsic ID of particular variant. +enum WmmaVariant { + WMMA_VARIANT_ARI64, + WMMA_VARIANT_ARI64_STRIDE, + WMMA_VARIANT_AVAR, + WMMA_VARIANT_AVAR_STRIDE, +}; + +// clang-format off +#define WMMA_VARIANTS(base) \ + {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }} +// clang-format on + +static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride, + const std::array<unsigned, 4> Variants) { + if (Stride) { + if (Variant == WMMA_VARIANT_ARI64) + Variant = WMMA_VARIANT_ARI64_STRIDE; + else if (Variant == WMMA_VARIANT_AVAR) + Variant = WMMA_VARIANT_AVAR_STRIDE; + } + return Variants[Variant]; +} + +static Optional<unsigned> +getWmmaLdStOpcode(unsigned IntrinsicID, + WmmaVariant Variant = WMMA_VARIANT_ARI64) { + switch (IntrinsicID) { + default: + return None; + // + // WMMA_LOAD_A f16 + // + case Intrinsic::nvvm_wmma_load_a_f16_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col)); + case Intrinsic::nvvm_wmma_load_a_f16_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row)); + case Intrinsic::nvvm_wmma_load_a_f16_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col)); + case Intrinsic::nvvm_wmma_load_a_f16_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row)); + case Intrinsic::nvvm_wmma_load_a_f16_col_shared: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared)); + case Intrinsic::nvvm_wmma_load_a_f16_row_shared: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared)); + case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared)); + case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared)); + case Intrinsic::nvvm_wmma_load_a_f16_col_global: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global)); + case Intrinsic::nvvm_wmma_load_a_f16_row_global: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global)); + case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global)); + case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global)); + + // + // WMMA_LOAD_B f16 + // + case Intrinsic::nvvm_wmma_load_b_f16_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col)); + case Intrinsic::nvvm_wmma_load_b_f16_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row)); + case 
Intrinsic::nvvm_wmma_load_b_f16_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col)); + case Intrinsic::nvvm_wmma_load_b_f16_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row)); + case Intrinsic::nvvm_wmma_load_b_f16_col_shared: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared)); + case Intrinsic::nvvm_wmma_load_b_f16_row_shared: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared)); + case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared)); + case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared)); + case Intrinsic::nvvm_wmma_load_b_f16_col_global: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global)); + case Intrinsic::nvvm_wmma_load_b_f16_row_global: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global)); + case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global)); + case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global)); + + // + // WMMA_LOAD_C f16 + // + case Intrinsic::nvvm_wmma_load_c_f16_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col)); + case Intrinsic::nvvm_wmma_load_c_f16_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row)); + case Intrinsic::nvvm_wmma_load_c_f16_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col)); + case Intrinsic::nvvm_wmma_load_c_f16_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row)); + case Intrinsic::nvvm_wmma_load_c_f16_col_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared)); + case Intrinsic::nvvm_wmma_load_c_f16_row_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared)); + case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared)); + case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared)); + case Intrinsic::nvvm_wmma_load_c_f16_col_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global)); + case Intrinsic::nvvm_wmma_load_c_f16_row_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global)); + case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global)); + case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global)); + + // + // 
WMMA_LOAD_C f32 + // + case Intrinsic::nvvm_wmma_load_c_f32_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col)); + case Intrinsic::nvvm_wmma_load_c_f32_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row)); + case Intrinsic::nvvm_wmma_load_c_f32_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col)); + case Intrinsic::nvvm_wmma_load_c_f32_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row)); + case Intrinsic::nvvm_wmma_load_c_f32_col_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared)); + case Intrinsic::nvvm_wmma_load_c_f32_row_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared)); + case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared)); + case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared)); + case Intrinsic::nvvm_wmma_load_c_f32_col_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global)); + case Intrinsic::nvvm_wmma_load_c_f32_row_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global)); + case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global)); + case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global)); + + // + // WMMA_STORE_D f16 + // + case Intrinsic::nvvm_wmma_store_d_f16_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col)); + case Intrinsic::nvvm_wmma_store_d_f16_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row)); + case Intrinsic::nvvm_wmma_store_d_f16_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col)); + case Intrinsic::nvvm_wmma_store_d_f16_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row)); + case Intrinsic::nvvm_wmma_store_d_f16_col_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared)); + case Intrinsic::nvvm_wmma_store_d_f16_row_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared)); + case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared)); + case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared)); + case Intrinsic::nvvm_wmma_store_d_f16_col_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global)); + case Intrinsic::nvvm_wmma_store_d_f16_row_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + 
WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global)); + case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global)); + case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global)); + + // + // WMMA_STORE_D f32 + // + case Intrinsic::nvvm_wmma_store_d_f32_col: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col)); + case Intrinsic::nvvm_wmma_store_d_f32_row: + return getWmmaLdVariant(Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row)); + case Intrinsic::nvvm_wmma_store_d_f32_col_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col)); + case Intrinsic::nvvm_wmma_store_d_f32_row_stride: + return getWmmaLdVariant(Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row)); + case Intrinsic::nvvm_wmma_store_d_f32_col_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared)); + case Intrinsic::nvvm_wmma_store_d_f32_row_shared: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared)); + case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared)); + case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared)); + case Intrinsic::nvvm_wmma_store_d_f32_col_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global)); + case Intrinsic::nvvm_wmma_store_d_f32_row_global: + return getWmmaLdVariant( + Variant, /*Stride=*/false, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global)); + case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global)); + case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: + return getWmmaLdVariant( + Variant, /*Stride=*/true, + WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global)); + } +} +#undef WMMA_VARIANTS + bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) { unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); + if (getWmmaLdStOpcode(IID)) + return tryWMMA_LDST(N); + switch (IID) { default: return false; @@ -719,6 +1029,39 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) { case Intrinsic::nvvm_match_all_sync_i64p: SelectMatchAll(N); return true; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32: + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite: + case 
Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32: + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32: + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32: + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite: + return tryWMMA_MMA(N); } } @@ -3725,3 +4068,172 @@ unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy, } } } + +bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) { + SDValue Chain = N->getOperand(0); + unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); + SDValue Op1 = N->getOperand(2); + SDValue Addr, Offset, Base; + Optional<unsigned> Opcode; + SDLoc DL(N); + MemSDNode *MemSD = cast<MemIntrinsicSDNode>(N); + WmmaVariant Variant; + SmallVector<SDValue, 12> Ops; + bool isStore = N->getNumValues() == 1; // Store ops only return a chain. + + if (SelectDirectAddr(Op1, Addr)) { + Variant = WMMA_VARIANT_AVAR; + Ops.push_back(Addr); + } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) || + SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) { + Variant = WMMA_VARIANT_ARI64; + Ops.push_back(Base); + Ops.push_back(Offset); + } else { + Variant = WMMA_VARIANT_AVAR; + Ops.push_back(Op1); + } + unsigned NumOps = N->getNumOperands(); + // Pass through the rest of the operands to the machine node. 
+ for (unsigned i = 3; i < NumOps; ++i) + Ops.push_back(N->getOperand(i)); + Ops.push_back(Chain); + + Opcode = getWmmaLdStOpcode(IID, Variant); + if (!Opcode) { + llvm::errs() << "tryWMMALD - no Opcode.\n"; + return false; + } + + EVT MemVT = MemSD->getMemoryVT(); + assert(MemVT.isVector() && "Expected vector return type."); + + SDNode *MN; + if (isStore) { + MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops); + } else { + SmallVector<EVT, 9> InstVTs(MemVT.getVectorNumElements(), + MemSD->getValueType(0)); + InstVTs.push_back(MVT::Other); + MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops); + } + + ReplaceNode(N, MN); + return true; +} + +bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) { + unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue(); + SDLoc DL(N); + unsigned Opc; + + switch (IID) { + default: + return false; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16: + Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32: + Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16: + Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32: + Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16: + Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32: + Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16: + Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32: + Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16: + Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32: + Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16: + Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32: + Opc = 
NVPTX::INT_WMMA_MMA_row_col_f32_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16: + Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32: + Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16: + Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32: + Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32; + break; + case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite: + Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite; + break; + } + + SmallVector<SDValue, 24> Ops; + // Pass through operands and return value types to the machine node. + for (unsigned i = 1; i < N->getNumOperands(); ++i) + Ops.push_back(N->getOperand(i)); + SmallVector<EVT, 8> InstVTs(N->getNumValues(), N->getValueType(0)); + SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops); + ReplaceNode(N, MN); + return true; +} diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index 3ce7843b72f..b23c27581a1 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -74,6 +74,8 @@ private: bool tryConstantFP16(SDNode *N); bool SelectSETP_F16X2(SDNode *N); bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N); + bool tryWMMA_LDST(SDNode *N); + bool tryWMMA_MMA(SDNode *N); inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) { return CurDAG->getTargetConstant(Imm, DL, MVT::i32); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 150e67a833f..7b9acb20b75 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3321,6 +3321,132 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( switch (Intrinsic) { default: return false; + case Intrinsic::nvvm_wmma_load_a_f16_col: + case Intrinsic::nvvm_wmma_load_a_f16_row: + case Intrinsic::nvvm_wmma_load_a_f16_col_stride: + case Intrinsic::nvvm_wmma_load_a_f16_row_stride: + case Intrinsic::nvvm_wmma_load_a_f16_col_shared: + case Intrinsic::nvvm_wmma_load_a_f16_row_shared: + case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride: + case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride: + case Intrinsic::nvvm_wmma_load_a_f16_col_global: + case Intrinsic::nvvm_wmma_load_a_f16_row_global: + case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride: + case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride: + case Intrinsic::nvvm_wmma_load_b_f16_col: + case Intrinsic::nvvm_wmma_load_b_f16_row: + case Intrinsic::nvvm_wmma_load_b_f16_col_stride: + case Intrinsic::nvvm_wmma_load_b_f16_row_stride: + case Intrinsic::nvvm_wmma_load_b_f16_col_shared: + case Intrinsic::nvvm_wmma_load_b_f16_row_shared: + case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride: + case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride: + case Intrinsic::nvvm_wmma_load_b_f16_col_global: + case Intrinsic::nvvm_wmma_load_b_f16_row_global: + case 
Intrinsic::nvvm_wmma_load_b_f16_col_global_stride: + case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v8f16; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = true; + Info.writeMem = false; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_load_c_f16_col: + case Intrinsic::nvvm_wmma_load_c_f16_row: + case Intrinsic::nvvm_wmma_load_c_f16_col_stride: + case Intrinsic::nvvm_wmma_load_c_f16_row_stride: + case Intrinsic::nvvm_wmma_load_c_f16_col_shared: + case Intrinsic::nvvm_wmma_load_c_f16_row_shared: + case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride: + case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride: + case Intrinsic::nvvm_wmma_load_c_f16_col_global: + case Intrinsic::nvvm_wmma_load_c_f16_row_global: + case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride: + case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v4f16; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = true; + Info.writeMem = false; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_load_c_f32_col: + case Intrinsic::nvvm_wmma_load_c_f32_row: + case Intrinsic::nvvm_wmma_load_c_f32_col_stride: + case Intrinsic::nvvm_wmma_load_c_f32_row_stride: + case Intrinsic::nvvm_wmma_load_c_f32_col_shared: + case Intrinsic::nvvm_wmma_load_c_f32_row_shared: + case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride: + case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride: + case Intrinsic::nvvm_wmma_load_c_f32_col_global: + case Intrinsic::nvvm_wmma_load_c_f32_row_global: + case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride: + case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v8f32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = true; + Info.writeMem = false; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_store_d_f16_col: + case Intrinsic::nvvm_wmma_store_d_f16_row: + case Intrinsic::nvvm_wmma_store_d_f16_col_stride: + case Intrinsic::nvvm_wmma_store_d_f16_row_stride: + case Intrinsic::nvvm_wmma_store_d_f16_col_shared: + case Intrinsic::nvvm_wmma_store_d_f16_row_shared: + case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride: + case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride: + case Intrinsic::nvvm_wmma_store_d_f16_col_global: + case Intrinsic::nvvm_wmma_store_d_f16_row_global: + case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride: + case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: { + Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v4f16; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = false; + Info.writeMem = true; + Info.align = 16; + return true; + } + + case Intrinsic::nvvm_wmma_store_d_f32_col: + case Intrinsic::nvvm_wmma_store_d_f32_row: + case Intrinsic::nvvm_wmma_store_d_f32_col_stride: + case Intrinsic::nvvm_wmma_store_d_f32_row_stride: + case Intrinsic::nvvm_wmma_store_d_f32_col_shared: + case Intrinsic::nvvm_wmma_store_d_f32_row_shared: + case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride: + case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride: + case Intrinsic::nvvm_wmma_store_d_f32_col_global: + case Intrinsic::nvvm_wmma_store_d_f32_row_global: + case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride: + case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: { 
+ Info.opc = ISD::INTRINSIC_W_CHAIN; + Info.memVT = MVT::v8f32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.vol = false; + Info.readMem = false; + Info.writeMem = true; + Info.align = 16; + return true; + } case Intrinsic::nvvm_atomic_load_add_f32: case Intrinsic::nvvm_atomic_load_inc_32: diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 11ebaaa5407..f745b6f6635 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -7368,3 +7368,208 @@ def INT_PTX_SREG_PM3 : PTX_READ_SREG_R32<"pm3", int_nvvm_read_ptx_sreg_pm3>; def INT_PTX_SREG_WARPSIZE : NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;", [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>; + +// +// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] +// +class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space, + string Type, NVPTXRegClass regclass, + Operand SrcOp, int WithOffset, int WithStride> + : NVPTXInst<!if(!eq(Abc#Type,"cf16"), + (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3), + (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)), + !if(WithStride, + !if(WithOffset, + (ins SrcOp:$src, i32imm:$offset, Int32Regs:$ldm), + (ins SrcOp:$src, Int32Regs:$ldm)), + !if(WithOffset, + (ins SrcOp:$src, i32imm:$offset), + (ins SrcOp:$src))), + "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t" + #!if(!eq(Abc#Type,"cf16"), + "{{$r0, $r1, $r2, $r3}}", + "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") + #", " + #!if(WithOffset,"[$src+$offset]", "[$src]") + #!if(WithStride, ", $ldm", "") + #";", + []>, + Requires<[hasPTX60, hasSM70]>; + +multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space, + string Type, NVPTXRegClass regclass, + Operand SrcOp, int WithOffset = 0> { + def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, + WithOffset, 1>; + def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, + WithOffset, 0>; +} + +multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space, + string Type, NVPTXRegClass regclass> { + defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 0>; + defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 1>; +} + +multiclass WMMA_LOAD_ALT<string Abc, string Layout, + string Type, NVPTXRegClass regclass> { + defm _global: WMMA_LOAD_ALST<Abc, Layout, ".global", Type, regclass>; + defm _shared: WMMA_LOAD_ALST<Abc, Layout, ".shared", Type, regclass>; + defm NAME: WMMA_LOAD_ALST<Abc, Layout, "", Type, regclass>; +} + +multiclass WMMA_LOAD_AT<string Abc, string Type, NVPTXRegClass regclass> { + defm _row: WMMA_LOAD_ALT<Abc, "row", Type, regclass>; + defm _col: WMMA_LOAD_ALT<Abc, "col", Type, regclass>; +} + +defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>; +defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>; +defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>; +defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>; + +// +// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] +// +class WMMA_STORE_D_LSTOS<string Layout, string Space, + string Type, NVPTXRegClass regclass, + Operand DstOp, int WithOffset, int WithStride> + : NVPTXInst<(outs), + !if(!eq(Type,"f16"), + !if(WithStride, + !if(WithOffset, + (ins DstOp:$src, i32imm:$offset, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + 
Int32Regs:$ldm), + (ins DstOp:$src, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + Int32Regs:$ldm)), + !if(WithOffset, + (ins DstOp:$src, i32imm:$offset, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3), + (ins DstOp:$src, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))), + !if(WithStride, + !if(WithOffset, + (ins DstOp:$src, i32imm:$offset, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7, + Int32Regs:$ldm), + (ins DstOp:$src, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7, + Int32Regs:$ldm)), + !if(WithOffset, + (ins DstOp:$src, i32imm:$offset, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7), + (ins DstOp:$src, + regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3, + regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))), + "wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t" + #!if(WithOffset,"[$src+$offset], ", "[$src], ") + #!if(!eq(Type,"f16"), + "{{$r0, $r1, $r2, $r3}}", + "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") + #!if(WithStride, ", $ldm", "") + #";", + []>, + Requires<[hasPTX60, hasSM70]>; + +multiclass WMMA_STORE_D_LSTO<string Layout, string Space, + string Type, NVPTXRegClass regclass, + Operand DstOp, int WithOffset = 0> { + def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, + WithOffset, 1>; + def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, + WithOffset, 0>; +} + +multiclass WMMA_STORE_D_LST<string Layout, string Space, + string Type, NVPTXRegClass regclass> { + defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 0>; + defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 1>; +} + +multiclass WMMA_STORE_D_LT<string Layout, + string Type, NVPTXRegClass regclass> { + defm _global: WMMA_STORE_D_LST<Layout, ".global", Type, regclass>; + defm _shared: WMMA_STORE_D_LST<Layout, ".shared", Type, regclass>; + defm NAME: WMMA_STORE_D_LST<Layout, "", Type, regclass>; +} + +multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> { + defm _row: WMMA_STORE_D_LT<"row", Type, regclass>; + defm _col: WMMA_STORE_D_LT<"col", Type, regclass>; +} + +defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>; +defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>; + +// WMMA.MMA +class WMMA_MMA_ABDCS<string ALayout, string BLayout, + string DType, NVPTXRegClass d_reg, + string CType, NVPTXRegClass c_reg, + NVPTXRegClass ab_reg, + string Satfinite = ""> + : NVPTXInst<!if(!eq(DType,"f16"), + (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3), + (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3, + d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)), + !if(!eq(CType,"f16"), + (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3, + ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7, + ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3, + ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7, + c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3), + (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3, + ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7, + ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3, + ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7, + c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3, + c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)), + "wmma.mma.sync."#ALayout#"."#BLayout#".m16n16k16."# + #DType#"."#CType#Satfinite + #"\n\t\t" + #!if(!eq(DType,"f16"), + "{{$d0, $d1, $d2, $d3}}, 
\n\t\t", + "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t") + #"{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t" + #"{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t" + #!if(!eq(CType,"f16"), + "{{$c0, $c1, $c2, $c3}};", + "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"), + []>, + Requires<[hasPTX60, hasSM70]>; + +multiclass WMMA_MMA_ABDC<string ALayout, string BLayout, + string DType, NVPTXRegClass d_reg, + string CType, NVPTXRegClass c_reg> { + def _satfinite: WMMA_MMA_ABDCS<ALayout, BLayout, + DType, d_reg, CType, c_reg, + Float16x2Regs, ".satfinite">; + def NAME: WMMA_MMA_ABDCS<ALayout, BLayout, + DType, d_reg, CType, c_reg, + Float16x2Regs>; +} + +multiclass WMMA_MMA_ABD<string ALayout, string BLayout, + string DType, NVPTXRegClass d_reg> { + defm _f16: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f16", Float16x2Regs>; + defm _f32: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f32", Float32Regs>; +} + +multiclass WMMA_MMA_AB<string ALayout, string BLayout> { + defm _f16: WMMA_MMA_ABD<ALayout, BLayout, "f16", Float16x2Regs>; + defm _f32: WMMA_MMA_ABD<ALayout, BLayout, "f32", Float32Regs>; +} + +multiclass WMMA_MMA_A<string ALayout> { + defm _col: WMMA_MMA_AB<ALayout, "col">; + defm _row: WMMA_MMA_AB<ALayout, "row">; +} + +defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">; +defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">; + |

