Diffstat (limited to 'llvm')
-rw-r--r--  llvm/include/llvm/IR/IntrinsicsNVVM.td      |  65
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  58
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td    | 148
-rw-r--r--  llvm/test/CodeGen/NVPTX/wmma.py             |  52
4 files changed, 171 insertions, 152 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 73622ce9303..e6734ed20fe 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -3884,30 +3884,22 @@ def int_nvvm_match_all_sync_i64p :
 //
 // WMMA.LOAD
-class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
-                         string Type, LLVMType regty, int WithStride>
+class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
+                         LLVMType regty, int WithStride>
   : Intrinsic<!if(!eq(Abc#Type,"cf16"),
                   [regty, regty, regty, regty],
                   [regty, regty, regty, regty,
                    regty, regty, regty, regty]),
-              !if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
-              [], // Properties must be set during instantiation.
+              !if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
+              [IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
               "llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
-              #Space
               #!if(WithStride,".stride","")
               #"."#Type>;

-multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
-                             string Type, LLVMType regty> {
-  def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
-  def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
-}
-
-multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
-                            string Type, LLVMType regty> {
-  defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
-  defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
-  defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
+multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
+                            LLVMType regty> {
+  def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
+  def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
 }

 multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
@@ -3915,47 +3907,33 @@ multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
   defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
 }

-// For some reason ReadOnly<N> and NoCapture<N> confuses tblgen if they are
-// passed to Intrinsic<> form inside of a multiclass. Setting them globally
-// outside of the multiclass works.
-let IntrProperties = [IntrReadMem, IntrArgMemOnly,
-                      ReadOnly<0>, NoCapture<0>] in {
-  defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
-  defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
-  defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
-  defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
-}
+defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
+defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
+defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
+defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;

 // WMMA.STORE.D
-class NVVM_WMMA_STD_LSTS<string Layout, string Space,
-                         string Type, LLVMType regty, int WithStride,
+class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
                          // This is only used to create a typed empty array we
                          // need to pass to !if below.
                          list<LLVMType>Empty=[]>
   : Intrinsic<[],
               !listconcat(
-                [llvm_ptr_ty],
+                [llvm_anyptr_ty],
                 !if(!eq(Type,"f16"),
                     [regty, regty, regty, regty],
                     [regty, regty, regty, regty,
                      regty, regty, regty, regty]),
                 !if(WithStride, [llvm_i32_ty], Empty)),
-              [], // Properties must be set during instantiation.
+              [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
               "llvm.nvvm.wmma.store.d.sync."#Layout
-              #".m16n16k16"#Space
+              #".m16n16k16"
               #!if(WithStride,".stride","")
               #"."#Type>;

-multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
-                             string Type, LLVMType regty> {
-  def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
-  def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
-}
-
 multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
-  defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
-  defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
-  defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
+  def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
+  def NAME: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
 }

 multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
@@ -3963,11 +3941,8 @@ multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
   defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
 }

-let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
-                      WriteOnly<0>, NoCapture<0>] in {
-  defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
-  defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
-}
+defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
+defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;

 // WMMA.MMA
 class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b03b0bbf69d..2fcfe95db63 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3327,26 +3327,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_load_a_f16_col:
   case Intrinsic::nvvm_wmma_load_a_f16_row:
   case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
   case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_global:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_global:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
   case Intrinsic::nvvm_wmma_load_b_f16_col:
   case Intrinsic::nvvm_wmma_load_b_f16_row:
   case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_global:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_global:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
+  case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3359,15 +3343,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_load_c_f16_col:
   case Intrinsic::nvvm_wmma_load_c_f16_row:
   case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_global:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_global:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
+  case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3380,15 +3356,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_load_c_f32_col:
   case Intrinsic::nvvm_wmma_load_c_f32_row:
   case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_global:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_global:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
+  case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3401,15 +3369,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_store_d_f16_col:
   case Intrinsic::nvvm_wmma_store_d_f16_row:
   case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_global:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_global:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
+  case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3422,15 +3382,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_store_d_f32_col:
   case Intrinsic::nvvm_wmma_store_d_f32_row:
   case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_global:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_global:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
+  case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index b2121952887..ba3f2e3e3b1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7379,13 +7379,16 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
                        string Type, NVPTXRegClass regclass,
                        DAGOperand SrcOp, bit WithStride>
   : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
-  // Intrinsic that matches this instruction.
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
-                                    # Abc
-                                    # "_" # Type
-                                    # "_" # Layout
-                                    # !subst(".","_",Space)
-                                    # !if(WithStride,"_stride", ""));
+  // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
+  // for this function.
+  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_LOAD_"
+                                       # !subst("a", "A",
+                                         !subst("b", "B",
+                                         !subst("c", "C_" # Type, Abc)))
+                                       # "_" # Layout
+                                       # !subst(".", "_", Space)
+                                       # !if(WithStride,"_stride", "")
+                                       # "_Intr");
   dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
   dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
   dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
@@ -7410,7 +7413,7 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
                          !subst(imem, ADDRvar,
                          !subst(MEMri64, ADDRri64,
                          !subst(MEMri, ADDRri,
-                         !subst(ins, Intr, tmp)))));
+                         !subst(ins, IntrMatcher, tmp)))));
   // Finally, consatenate both parts together. !con() requires both dags to have
   // the same operator, so we wrap PatArgs in a (set ...) dag.
   let Pattern = [!con(PatOuts, (set PatArgs))];
@@ -7425,20 +7428,52 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
           #";";
 }

-multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
-                           string Type, NVPTXRegClass regclass,
-                           DAGOperand SrcOp> {
-  def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 1>;
-  def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp, 0>;
+class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
+                            string Type, bit WithStride>
+  : PatFrag <(ops),(ops)> {
+  // Intrinsic that matches this instruction.
+  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
+                                    # Abc
+                                    # "_" # Type
+                                    # "_" # Layout
+                                    # !if(WithStride,"_stride", ""));
+  code match_generic = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
+  }];
+  code match_shared = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
+  }];
+  code match_global = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
+  }];
+
+  let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
+  let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
+  let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
+                          !if(!eq(Space, ".global"), match_global, match_generic));
+}
+
+multiclass WMMA_LOAD_ALSTS<string Abc, string Layout, string Space,
+                           string Type, NVPTXRegClass regclass, bit WithStride> {
+  def _avar: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, imem, WithStride>;
+  def _areg: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int32Regs, WithStride>;
+  def _areg64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int64Regs, WithStride>;
+  def _ari: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri, WithStride>;
+  def _ari64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri64, WithStride>;
+}
+
+multiclass WMMA_LOAD_ALSTSh<string Abc, string Layout, string Space,
+                            string Type, NVPTXRegClass regclass, bit WithStride> {
+  // Define a PatFrag that matches appropriate intrinsic that loads from the
+  // given address space.
+  def _Intr : WMMA_LOAD_INTR_HELPER<Abc, Layout, Space, Type, WithStride>;
+  defm NAME: WMMA_LOAD_ALSTS<Abc, Layout, Space, Type, regclass, WithStride>;
 }

 multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
-                          string Type, NVPTXRegClass regclass> {
-  defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imem>;
-  defm _areg: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int32Regs>;
-  defm _areg64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, Int64Regs>;
-  defm _ari: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri>;
-  defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, MEMri64>;
+                          string Type, NVPTXRegClass regclass> {
+  defm _stride: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 1>;
+  defm NAME: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 0>;
 }

 multiclass WMMA_LOAD_ALT<string Abc, string Layout,
@@ -7461,15 +7496,16 @@ defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
 //
 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
 //
-class WMMA_STORE_D_LSTOS<string Layout, string Space,
+class WMMA_STORE_D_LSTSO<string Layout, string Space,
                          string Type, NVPTXRegClass regclass,
-                         DAGOperand DstOp, bit WithStride>
+                         bit WithStride, DAGOperand DstOp>
   : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d_"
-                                    # Type
-                                    # "_" # Layout
-                                    # !subst(".","_",Space)
-                                    # !if(WithStride,"_stride", ""));
+  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_STORE_D"
+                                       # "_" # Type
+                                       # "_" # Layout
+                                       # !subst(".", "_", Space)
+                                       # !if(WithStride,"_stride", "")
+                                       # "_Intr");
   dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1,
                     regclass:$r2, regclass:$r3);
   dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
@@ -7483,7 +7519,7 @@ class WMMA_STORE_D_LSTOS<string Layout, string Space,
                          !subst(imem, ADDRvar,
                          !subst(MEMri64, ADDRri64,
                          !subst(MEMri, ADDRri,
-                         !subst(ins, Intr, tmp)))));
+                         !subst(ins, IntrMatcher, tmp)))));
   let Pattern = [PatArgs];
   let OutOperandList = (outs);
   let InOperandList = Ins;
@@ -7501,20 +7537,56 @@ class WMMA_STORE_D_LSTOS<string Layout, string Space,
 }

-multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
-                             string Type, NVPTXRegClass regclass,
-                             DAGOperand DstOp> {
-  def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 1>;
-  def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp, 0>;
+class WMMA_STORE_INTR_HELPER<string Layout, string Space,
+                             string Type, bit WithStride>
+  : PatFrag <(ops),(ops)> {
+  // Intrinsic that matches this instruction.
+  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d"
+                                    # "_" # Type
+                                    # "_" # Layout
+                                    # !if(WithStride, "_stride", ""));
+  code match_generic = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
+  }];
+  code match_shared = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
+  }];
+  code match_global = [{
+    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
+  }];
+
+  dag Args = !if(!eq(Type,"f16"),
+                 (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
+                 (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
+                      node:$r4, node:$r5, node:$r6, node:$r7));
+  dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
+  let Operands = !con(Args, StrideArg);
+  let Fragment = !foreach(tmp, Operands, !subst(ops, Intr, tmp));
+  let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
+                          !if(!eq(Space, ".global"), match_global, match_generic));
+}
+
+multiclass WMMA_STORE_D_LSTS<string Layout, string Space,
+                             string Type, NVPTXRegClass regclass, bit WithStride> {
+  def _avar: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, imem>;
+  def _areg: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int32Regs>;
+  def _areg64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int64Regs>;
+  def _ari: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri>;
+  def _ari64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri64>;
+}
+
+multiclass WMMA_STORE_D_LSTSh<string Layout, string Space,
+                              string Type, NVPTXRegClass regclass, bit WithStride> {
+  // Define a PatFrag that matches appropriate intrinsic that loads from the
+  // given address space.
+  def _Intr: WMMA_STORE_INTR_HELPER<Layout, Space, Type, WithStride>;
+  defm NAME: WMMA_STORE_D_LSTS<Layout, Space, Type, regclass, WithStride>;
 }

 multiclass WMMA_STORE_D_LST<string Layout, string Space,
-                            string Type, NVPTXRegClass regclass> {
-  defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imem>;
-  defm _areg: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int32Regs>;
-  defm _areg64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, Int64Regs>;
-  defm _ari: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri>;
-  defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, MEMri64>;
+                            string Type, NVPTXRegClass regclass > {
+  defm _stride: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 1>;
+  defm NAME: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 0>;
 }

 multiclass WMMA_STORE_D_LT<string Layout,
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index ad62b84f417..d0fa90cdef1 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -15,6 +15,22 @@ def make_wmma_slice_ty(abcd, itype):
 def make_wmma_ld_ret_ty(abc, itype):
   return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))

+# returns address space
+def get_aspace(space):
+  space_map = {
+      ".global" : 1,
+      ".shared" : 3,
+      ".const"  : 4,
+      ".local"  : 5,
+      ".param"  : 101,
+      ""        : 0,
+      ".generic": 0
+  }
+  return space_map[space];
+
+def get_pspace(space):
+  return "p%di8" % get_aspace(space);
+
 # Convenient test patterns.
 check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
 check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
@@ -22,28 +38,28 @@ check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)

 def gen_wmma_load_tests():
   load_template = """
-declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});
+declare ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});

 ; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
-define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
+define ${ret_ty} @test_wmma_load_${function_suffix}(i8 ${as}* %src ${extra_args}) {
 ; CHECK wmma.load.${intrinsic_suffix}
 ; CHECK: {${check_result}}
 ; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
-  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
+  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
   ret ${ret_ty} %v0;
 }

 ; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
-define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
+define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_args}) {
 ; CHECK wmma.load.${intrinsic_suffix}
 ; CHECK: {${check_result}}
 ; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
-  %src1 = getelementptr i8, i8* %src, i32 128;
-  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
+  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
+  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src1 ${extra_args});
   ret ${ret_ty} %v0;
 }
 """
-  suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
+  suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
   instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"

   for abc, layout, space, stride, itype in product(
@@ -58,7 +74,9 @@ define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
         "layout" : layout,
         "space" : space,
         "stride" : stride,
-        "itype" : itype
+        "itype" : itype,
+        "pspace" : get_pspace(space),
+        "as" : "addrspace(%d)" % get_aspace(space)
     }

     if itype == "f32" and abc != "c":
@@ -89,28 +107,28 @@ def make_wmma_slice_args(itype, abcd, prefix="v"):

 def gen_wmma_store_tests():
   store_template = """
-declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});
+declare void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args}${extra_args});

 ; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
-define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
+define void @test_wmma_store_${function_suffix}(i8 ${as}* %src, ${args}${extra_args}) {
 ; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
 ; CHECK: {${check_args}}
 ; CHECK: ${stride_pattern}
-  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
+  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args} ${extra_args});
   ret void
 }

 ; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
-define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
+define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra_args}) {
 ; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
 ; CHECK: ${check_args}
 ; CHECK: ${stride_pattern}
-  %src1 = getelementptr i8, i8* %src, i32 128;
-  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
+  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
+  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src1, ${args}${extra_args});
   ret void
 }
 """
-  suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
+  suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
   instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"

   for abc, layout, space, stride, itype in product(
@@ -125,7 +143,9 @@ define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}
         "layout" : layout,
         "space" : space,
         "stride" : stride,
-        "itype" : itype
+        "itype" : itype,
+        "pspace" : get_pspace(space),
+        "as" : "addrspace(%d)" % get_aspace(space)
     }

     test_params = params
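
Usage note (not part of the patch): the sketch below, derived from the updated wmma.py test generator, shows how the refactored intrinsics are called after this change. The dedicated .shared/.global intrinsic variants are gone; the address space now travels in the overloaded pointer argument, so the intrinsic name carries an overload suffix such as .p0i8 (generic), .p1i8 (global), or .p3i8 (shared). The function name and the %ldm parameter are illustrative only.

; Load the "a" fragment from shared memory with an explicit stride.
declare {<2 x half>, <2 x half>, <2 x half>, <2 x half>,
         <2 x half>, <2 x half>, <2 x half>, <2 x half>}
  @llvm.nvvm.wmma.load.a.sync.row.m16n16k16.stride.f16.p3i8(i8 addrspace(3)*, i32)

define void @test_wmma_load_a_from_shared(i8 addrspace(3)* %src, i32 %ldm) {
  ; Before this change the equivalent call used a per-space intrinsic,
  ; @llvm.nvvm.wmma.load.a.sync.row.m16n16k16.shared.stride.f16, taking a plain i8*.
  %frag = call {<2 x half>, <2 x half>, <2 x half>, <2 x half>,
                <2 x half>, <2 x half>, <2 x half>, <2 x half>}
    @llvm.nvvm.wmma.load.a.sync.row.m16n16k16.stride.f16.p3i8(i8 addrspace(3)* %src, i32 %ldm)
  ret void
}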

