diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 512 |
1 files changed, 183 insertions, 329 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 8f9251bb0df..a5f0d7908f2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -26,7 +26,17 @@ def immDouble1 : PatLeaf<(fpimm), [{ return (d==1.0); }]>; - +def AS_match { + code generic = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); + }]; + code shared = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); + }]; + code global = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); + }]; +} //----------------------------------- // Synchronization and shuffle functions @@ -1006,17 +1016,11 @@ def INT_FNS_iii : INT_FNS_MBO<(ins i32imm:$mask, i32imm:$base, i32imm:$ //----------------------------------- class ATOMIC_GLOBAL_CHK <dag ops, dag frag> - : PatFrag<ops, frag, [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); -}]>; + : PatFrag<ops, frag, AS_match.global>; class ATOMIC_SHARED_CHK <dag ops, dag frag> - : PatFrag<ops, frag, [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); -}]>; + : PatFrag<ops, frag, AS_match.shared>; class ATOMIC_GENERIC_CHK <dag ops, dag frag> - : PatFrag<ops, frag, [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); -}]>; + : PatFrag<ops, frag, AS_match.generic>; multiclass F_ATOMIC_2_imp<NVPTXRegClass ptrclass, NVPTXRegClass regclass, string SpaceStr, string TypeStr, string OpcStr, PatFrag IntOp, @@ -7380,36 +7384,60 @@ def INT_PTX_SREG_WARPSIZE : NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;", [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>; -// -// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] -// - class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>; +// Generates list of n sequential register names. +class RegSeq<int n, string prefix> { + list<string> ret = !if(n, !listconcat(RegSeq<!add(n,-1), prefix>.ret, + [prefix # !add(n, -1)]), + []); +} -class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout, - string Space, string Type, NVPTXRegClass regclass, - DAGOperand SrcOp, bit WithStride> - : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic - // for this function. - PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_" - # Geometry # "_load_" - # !subst("c", "c_" # Type, Abc) - # "_" # Layout - # !subst(".", "_", Space) - # !if(WithStride,"_stride", "") - # "_Intr"); - dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3); - dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7); - dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47)); - - dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins)); - dag Ins = !con((ins SrcOp:$src), StrideArg); +// Helper class that represents a 'fragment' of an NVPTX *MMA instruction. +// In addition to target-independent fields provided by WMMA_REGS, it adds +// the fields commonly used to implement specific PTX instruction -- register +// types and names, constraints, parts of assembly, etc. +class WMMA_REGINFO<string Geom, string Frag, string PtxEltType> + : WMMA_REGS<Geom, Frag, PtxEltType> { + // NVPTX register types used to carry fragment data. + NVPTXRegClass regclass = !cond( + !eq(PtxEltType, "f16") : Float16x2Regs, + !eq(PtxEltType, "f32") : Float32Regs); + + // Instruction input/output arguments for the fragment. + list<NVPTXRegClass> ptx_regs = !foreach(tmp, regs, regclass); + + // List of register names for the fragment -- ["ra0", "ra1",...] + list<string> reg_names = RegSeq<!size(ptx_regs), "r"#frag>.ret; + // Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction. + string regstring = "{{$" # !head(reg_names) + # !foldl("", !tail(reg_names), a, b, + !strconcat(a, ", $", b)) + # "}}"; + + // Predicates for particular fragment variant. Technically those are + // per-instruction predicates, but currently all fragments that can be used in + // a given instruction are subject to the same constraints, so an instruction + // can use predicates from any of its fragments. If/when this is no + // longer the case, we can concat all per-fragment predicates to enforce that + // all fragments of the instruction are viable. + list<Predicate> Predicates = !cond( + // fp16 -> fp16/fp32 @ m16n16k16 + !and(!eq(Geom, "m16n16k16"), + !or(!eq(PtxEltType, "f16"), + !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60], + + // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16 + !and(!or(!eq(Geom, "m8n32k16"), + !eq(Geom, "m32n8k16")), + !or(!eq(PtxEltType, "f16"), + !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]); + + // template DAGs for instruction inputs/output. + dag Outs = !dag(outs, ptx_regs, reg_names); + dag Ins = !dag(ins, ptx_regs, reg_names); +} +class BuildPattern<dag Outs, PatFrag IntrMatcher, dag Ins> { // Build a dag pattern that matches the intrinsic call. // We want a dag that looks like this: // (set <output args>, (intrinsic <input arguments>)) where input and @@ -7430,277 +7458,127 @@ class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout, !subst(ins, IntrMatcher, tmp))))); // Finally, consatenate both parts together. !con() requires both dags to have // the same operator, so we wrap PatArgs in a (set ...) dag. - let Pattern = [!con(PatOuts, (set PatArgs))]; - let OutOperandList = Outs; - let InOperandList = Ins; - let AsmString = "wmma.load." - # Abc - # ".sync" - # "." # Layout - # "." # Geometry - # Space - # "." # Type # " \t" - # !if(!eq(Abc#Type, "cf16"), - "{{$r0, $r1, $r2, $r3}}", - "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") - # ", [$src]" - # !if(WithStride, ", $ldm", "") - # ";"; + dag ret = !con(PatOuts, (set PatArgs)); } -class WMMA_LOAD_INTR_HELPER<string Geometry, string Abc, string Layout, - string Space, string Type, bit WithStride> +// +// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] +// + +class WMMA_LOAD_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space, + bit WithStride> : PatFrag <(ops),(ops)> { // Intrinsic that matches this instruction. - Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma" - # "_" # Geometry # "_load_" - # Abc # "_" # Type # "_" # Layout - # !if(WithStride,"_stride", "")); - code match_generic = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); - }]; - code match_shared = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); - }]; - code match_global = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); - }]; - + Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"load", Frag, Layout, + WithStride>.record); let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src)); let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; - let PredicateCode = !if(!eq(Space, ".shared"), match_shared, - !if(!eq(Space, ".global"), match_global, match_generic)); -} - -multiclass WMMA_LOAD_GALSTS<string Geometry, string Abc, string Layout, - string Space, string Type, NVPTXRegClass regclass, - bit WithStride> { - def _avar: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - imem, WithStride>; - def _areg: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - Int32Regs, WithStride>; - def _areg64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - Int64Regs, WithStride>; - def _ari: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - MEMri, WithStride>; - def _ari64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - MEMri64, WithStride>; + let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared, + !eq(Space, ".global"): AS_match.global, + 1: AS_match.generic); } -multiclass WMMA_LOAD_GALSTSh<string Geometry, string Abc, string Layout, - string Space, string Type, NVPTXRegClass regclass, - bit WithStride> { - // Define a PatFrag that matches appropriate intrinsic that loads from the - // given address space. - def _Intr: WMMA_LOAD_INTR_HELPER<Geometry, Abc, Layout, Space, Type, - WithStride>; - defm NAME: WMMA_LOAD_GALSTS<Geometry, Abc, Layout, Space, Type, regclass, - WithStride>; -} - -multiclass WMMA_LOAD_GALST<string Geometry, string Abc, string Layout, - string Space, string Type, NVPTXRegClass regclass> { - defm _stride: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 1>; - defm NAME: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 0>; -} - -multiclass WMMA_LOAD_GALT<string Geometry, string Abc, string Layout, - string Type, NVPTXRegClass regclass> { - defm _global: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".global", - Type, regclass>; - defm _shared: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".shared", - Type, regclass>; - defm NAME: WMMA_LOAD_GALST<Geometry, Abc, Layout, "", - Type, regclass>; -} - -multiclass WMMA_LOAD_GAT<string Geometry, string Abc, - string Type, NVPTXRegClass regclass> { - defm _row: WMMA_LOAD_GALT<Geometry, Abc, "row", Type, regclass>; - defm _col: WMMA_LOAD_GALT<Geometry, Abc, "col", Type, regclass>; -} +class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride, + DAGOperand SrcOp> + : EmptyNVPTXInst, + Requires<Frag.Predicates> { + // Pattern that matches the intrinsic for this instruction variant. + PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER<Frag, Layout, Space, WithStride>; + dag Ins = !con((ins SrcOp:$src), !if(WithStride, (ins Int32Regs:$ldm), (ins))); -multiclass WMMA_LOAD_G<string Geometry> { - defm _load_a: WMMA_LOAD_GAT<Geometry, "a", "f16", Float16x2Regs>; - defm _load_b: WMMA_LOAD_GAT<Geometry, "b", "f16", Float16x2Regs>; - defm _load_c_f16: WMMA_LOAD_GAT<Geometry, "c", "f16", Float16x2Regs>; - defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>; + let Pattern = [BuildPattern<Frag.Outs, IntrMatcher, Ins>.ret]; + let OutOperandList = Frag.Outs; + let InOperandList = Ins; + let AsmString = "wmma.load." + # Frag.frag + # ".sync" + # "." # Layout + # "." # Frag.geom + # Space + # "." # Frag.ptx_elt_type # " \t" + # Frag.regstring + # ", [$src]" + # !if(WithStride, ", $ldm", "") + # ";"; } -defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">; -defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">; -defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">; - // // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] // -class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space, - string Type, NVPTXRegClass regclass, - bit WithStride, DAGOperand DstOp> +class WMMA_STORE_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space, + bit WithStride> + : PatFrag <(ops),(ops)> { + // Intrinsic that matches this instruction. + Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"store", Frag, Layout, + WithStride>.record); + let Operands = !con((ops node:$dst), + !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names), + !if(WithStride, (ops node:$ldm), (ops))); + let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; + let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared, + !eq(Space, ".global"): AS_match.global, + 1: AS_match.generic); +} + +class WMMA_STORE<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride, + DAGOperand DstOp> : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA" - # "_" # Geometry # "_store_d" - # "_" # Type - # "_" # Layout - # !subst(".", "_", Space) - # !if(WithStride,"_stride", "") - # "_Intr"); - dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, - regclass:$r2, regclass:$r3); - dag InsR47 = (ins regclass:$r4, regclass:$r5, - regclass:$r6, regclass:$r7); - dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47)); - dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins)); - dag Ins = !con(InsR, StrideArg); - - // Construct the pattern to match corresponding intrinsic call. See the - // details in the comments in WMMA_LOAD_ALSTOS. - dag PatArgs = !foreach(tmp, Ins, - !subst(imem, ADDRvar, - !subst(MEMri64, ADDRri64, - !subst(MEMri, ADDRri, - !subst(ins, IntrMatcher, tmp))))); - let Pattern = [PatArgs]; + Requires<Frag.Predicates> { + PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER<Frag, Layout, Space, WithStride>; + dag Ins = !con((ins DstOp:$src), + Frag.Ins, + !if(WithStride, (ins Int32Regs:$ldm), (ins))); + let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret]; let OutOperandList = (outs); let InOperandList = Ins; let AsmString = "wmma.store.d.sync." # Layout - # "." # Geometry + # "." # Frag.geom # Space - # "." # Type + # "." # Frag.ptx_elt_type # " \t[$src]," - # !if(!eq(Type,"f16"), - "{{$r0, $r1, $r2, $r3}}", - "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") + # Frag.regstring # !if(WithStride, ", $ldm", "") # ";"; - } -class WMMA_STORE_INTR_HELPER<string Geometry, string Layout, string Space, - string Type, bit WithStride> - : PatFrag <(ops),(ops)> { - // Intrinsic that matches this instruction. - Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_" - # Geometry - # "_store_d" - # "_" # Type - # "_" # Layout - # !if(WithStride, "_stride", "")); - code match_generic = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); - }]; - code match_shared = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); - }]; - code match_global = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); - }]; - - dag Args = !if(!eq(Type,"f16"), - (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3), - (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3, - node:$r4, node:$r5, node:$r6, node:$r7)); - dag StrideArg = !if(WithStride, (ops node:$ldm), (ops)); - let Operands = !con(Args, StrideArg); - let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; - let PredicateCode = !if(!eq(Space, ".shared"), match_shared, - !if(!eq(Space, ".global"), match_global, match_generic)); -} - -multiclass WMMA_STORE_D_GLSTS<string Geometry, string Layout, string Space, - string Type, NVPTXRegClass regclass, - bit WithStride> { - def _avar: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, imem>; - def _areg: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, Int32Regs>; - def _areg64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, Int64Regs>; - def _ari: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, MEMri>; - def _ari64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, MEMri64>; -} - -multiclass WMMA_STORE_D_GLSTSh<string Geometry, string Layout, string Space, - string Type, NVPTXRegClass regclass, - bit WithStride> { - // Define a PatFrag that matches appropriate intrinsic that loads from the - // given address space. - def _Intr: WMMA_STORE_INTR_HELPER<Geometry, Layout, Space, Type, - WithStride>; - defm NAME: WMMA_STORE_D_GLSTS<Geometry, Layout, Space, Type, regclass, - WithStride>; -} - -multiclass WMMA_STORE_D_GLST<string Geometry, string Layout, string Space, - string Type, NVPTXRegClass regclass > { - defm _stride: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 1>; - defm NAME: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 0>; -} - -multiclass WMMA_STORE_D_GLT<string Geometry, string Layout, - string Type, NVPTXRegClass regclass> { - defm _global: WMMA_STORE_D_GLST<Geometry, Layout, ".global", Type, regclass>; - defm _shared: WMMA_STORE_D_GLST<Geometry, Layout, ".shared", Type, regclass>; - defm NAME: WMMA_STORE_D_GLST<Geometry, Layout, "", Type, regclass>; -} - -multiclass WMMA_STORE_D_GT<string Geometry, string Type, - NVPTXRegClass regclass> { - defm _row: WMMA_STORE_D_GLT<Geometry, "row", Type, regclass>; - defm _col: WMMA_STORE_D_GLT<Geometry, "col", Type, regclass>; -} - -multiclass WMMA_STORE_D_G<string Geometry> { - defm _store_d_f16: WMMA_STORE_D_GT<Geometry, "f16", Float16x2Regs>; - defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>; -} - -defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">; -defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">; -defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">; +// Create all load/store variants +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout = ["row", "col"] in { + foreach stride = [0, 1] in { + foreach space = [".global", ".shared", ""] in { + foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { + foreach frag = [WMMA_REGINFO<geom, "a", "f16">, + WMMA_REGINFO<geom, "b", "f16">, + WMMA_REGINFO<geom, "c", "f16">, + WMMA_REGINFO<geom, "c", "f32">] in { + def : WMMA_LOAD<frag, layout, space, stride, addr>; + } + foreach frag = [WMMA_REGINFO<geom, "d", "f16">, + WMMA_REGINFO<geom, "d", "f32">] in { + def : WMMA_STORE<frag, layout, space, stride, addr>; + } + } // addr + } // space + } // stride + } // layout +} // geom // WMMA.MMA -class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout, - string DType, NVPTXRegClass d_reg, - string CType, NVPTXRegClass c_reg, - NVPTXRegClass ab_reg, - string Satfinite = ""> +class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, + WMMA_REGINFO FragC, WMMA_REGINFO FragD, + string ALayout, string BLayout, int Satfinite> : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_" - # Geometry - # "_mma" - # "_" # ALayout - # "_" # BLayout - # "_" # DType - # "_" # CType - # !subst(".", "_", Satfinite)); - dag Outs = !if(!eq(DType,"f16"), - (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3), - (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3, - d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)); - dag InsExtraCArgs = !if(!eq(CType,"f16"), - (ins), - (ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)); - dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3, - ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7, - ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3, - ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7, - c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3), - InsExtraCArgs); - - // Construct the pattern to match corresponding intrinsic call. See the - // details in the comments in WMMA_LOAD_ALSTOS. + Requires<FragC.Predicates> { + //Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero; + Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_MMA<ALayout, BLayout, FragC, FragD, Satfinite>.record); + dag Outs = FragD.Outs; + dag Ins = !con(FragA.Ins, + FragB.Ins, + FragC.Ins); + + // Construct the pattern to match corresponding intrinsic call. + // mma does not load/store anything, so we don't need complex operand matching here. dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp)); dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp)); let Pattern = [!con(PatOuts, (set PatArgs))]; @@ -7709,54 +7587,30 @@ class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout, let AsmString = "wmma.mma.sync." # ALayout # "." # BLayout - # "." # Geometry - # "." # DType - # "." # CType - # Satfinite # "\n\t\t" - # !if(!eq(DType,"f16"), - "{{$d0, $d1, $d2, $d3}}, \n\t\t", - "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t") - # "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t" - # "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t" - # !if(!eq(CType,"f16"), - "{{$c0, $c1, $c2, $c3}};", - "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"); -} - -multiclass WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout, - string DType, NVPTXRegClass d_reg, - string CType, NVPTXRegClass c_reg> { - def _satfinite: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout, - DType, d_reg, CType, c_reg, - Float16x2Regs, ".satfinite">; - def NAME: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout, - DType, d_reg, CType, c_reg, - Float16x2Regs>; -} - -multiclass WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout, - string DType, NVPTXRegClass d_reg> { - defm _f16: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg, - "f16", Float16x2Regs>; - defm _f32: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg, - "f32", Float32Regs>; -} - -multiclass WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> { - defm _f16: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", Float16x2Regs>; - defm _f32: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", Float32Regs>; -} - -multiclass WMMA_MMA_GA<string Geometry, string ALayout> { - defm _col: WMMA_MMA_GAB<Geometry, ALayout, "col">; - defm _row: WMMA_MMA_GAB<Geometry, ALayout, "row">; -} - -multiclass WMMA_MMA_G<string Geometry> { - defm _col: WMMA_MMA_GA<Geometry, "col">; - defm _row: WMMA_MMA_GA<Geometry, "row">; + # "." # FragA.geom + # "." # FragD.ptx_elt_type + # "." # FragC.ptx_elt_type + # !if(Satfinite, ".satfinite", "") # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";"; } -defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">; -defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">; -defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">; +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach frag_c = [WMMA_REGINFO<geom, "c", "f16">, + WMMA_REGINFO<geom, "c", "f32">] in { + foreach frag_d = [WMMA_REGINFO<geom, "d", "f16">, + WMMA_REGINFO<geom, "d", "f32">] in { + foreach satf = [0, 1] in { + def : WMMA_MMA<WMMA_REGINFO<geom, "a", "f16">, + WMMA_REGINFO<geom, "b", "f16">, + frag_c, frag_d, layout_a, layout_b, satf>; + } // satf + } // frag_d + } // frag_c + } // layout_b + } // layout_a +} // geom |