diff options
-rw-r--r-- | llvm/include/llvm/IR/IntrinsicsNVVM.td | 258 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 512 |
2 files changed, 295 insertions, 475 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index cf072c70eba..84499e68364 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -37,6 +37,69 @@ def llvm_anyi64ptr_ty : LLVMAnyPointerType<llvm_i64_ty>; // (space)i64* // MISC // +// Helper class for construction of n-element list<LLVMtype> [t,t,...,t] +class RepLLVMType<int N, LLVMType T> { + list<LLVMType> ret = !if(N, !listconcat(RepLLVMType<!add(N,-1), T>.ret, [T]), []); +} + +// Helper class that represents a 'fragment' of an NVPTX *MMA instruction. +// Geom: m<M>n<N>k<K>. E.g. m8n32k16 +// Frag: [abcd] +// PtxEltType: PTX type for the element. +class WMMA_REGS<string Geom, string Frag, string PtxEltType> { + string geom = Geom; + string frag = Frag; + string ptx_elt_type = PtxEltType; + string ft = frag#":"#ptx_elt_type; + list<LLVMType> regs = !cond( + // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16 + // All currently supported geometries use the same fragment format, + // so we only need to consider {fragment, type}. + !eq(ft,"a:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret, + !eq(ft,"b:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret, + !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret, + !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret, + !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret, + !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret); +} + +class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> { + string intr = "llvm.nvvm.wmma." + # Frag.geom + # "." # Op + # "." # Frag.frag + # "." # Layout + # !if(WithStride, ".stride", "") + # "." # Frag.ptx_elt_type + ; + // TODO(tra): record name should ideally use the same field order as the intrinsic. + // E.g. 
string record = !subst("llvm", "int", + // !subst(".", "_", llvm)); + string record = "int_nvvm_wmma_" + # Frag.geom + # "_" # Op + # "_" # Frag.frag + # "_" # Frag.ptx_elt_type + # "_" # Layout + # !if(WithStride, "_stride", ""); +} + +class WMMA_NAME_MMA<string ALayout, string BLayout, + WMMA_REGS C, WMMA_REGS D, + int Satfinite> { + string llvm = "llvm.nvvm.wmma." + # C.geom + # ".mma" + # "." # ALayout + # "." # BLayout + # "." # D.ptx_elt_type // Intrinsic encodes 'd' first. + # "." # C.ptx_elt_type + # !if(Satfinite, ".satfinite", ""); + + string record = !subst(".", "_", + !subst("llvm.", "int_", llvm)); +} + let TargetPrefix = "nvvm" in { def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">, Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], @@ -3889,166 +3952,69 @@ def int_nvvm_match_all_sync_i64p : // // WMMA instructions // - // WMMA.LOAD -class NVVM_WMMA_LD_GALSTS<string Geometry, string Abc, string Layout, - string Type, LLVMType regty, int WithStride> - : Intrinsic<!if(!eq(Abc#Type,"cf16"), - [regty, regty, regty, regty], - [regty, regty, regty, regty, - regty, regty, regty, regty]), +class NVVM_WMMA_LD<WMMA_REGS Frag, string Layout, int WithStride> + : Intrinsic<Frag.regs, !if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]), [IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>], - "llvm.nvvm.wmma." - # Geometry - # ".load" - # "." # Abc - # "." # Layout - # !if(WithStride, ".stride", "") - # "." 
# Type>; - -multiclass NVVM_WMMA_LD_GALT<string Geometry, string Abc, string Layout, - string Type, LLVMType regty> { - def _stride: NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 1>; - def NAME : NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 0>; -} - -multiclass NVVM_WMMA_LD_GAT<string Geometry, string Abc, - string Type, LLVMType regty> { - defm _row: NVVM_WMMA_LD_GALT<Geometry, Abc, "row", Type, regty>; - defm _col: NVVM_WMMA_LD_GALT<Geometry, Abc, "col", Type, regty>; -} - -multiclass NVVM_WMMA_LD_G<string Geometry> { - defm _a_f16: NVVM_WMMA_LD_GAT<Geometry, "a", "f16", llvm_v2f16_ty>; - defm _b_f16: NVVM_WMMA_LD_GAT<Geometry, "b", "f16", llvm_v2f16_ty>; - defm _c_f16: NVVM_WMMA_LD_GAT<Geometry, "c", "f16", llvm_v2f16_ty>; - defm _c_f32: NVVM_WMMA_LD_GAT<Geometry, "c", "f32", llvm_float_ty>; -} - -multiclass NVVM_WMMA_LD { - defm _m32n8k16_load: NVVM_WMMA_LD_G<"m32n8k16">; - defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">; - defm _m8n32k16_load: NVVM_WMMA_LD_G<"m8n32k16">; -} - -defm int_nvvm_wmma: NVVM_WMMA_LD; + WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.intr>; // WMMA.STORE.D -class NVVM_WMMA_STD_GLSTS<string Geometry, string Layout, - string Type, LLVMType regty, int WithStride, - // This is only used to create a typed empty array we - // need to pass to !if below. - list<LLVMType>Empty=[]> +class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride> : Intrinsic<[], !listconcat( [llvm_anyptr_ty], - !if(!eq(Type,"f16"), - [regty, regty, regty, regty], - [regty, regty, regty, regty, - regty, regty, regty, regty]), - !if(WithStride, [llvm_i32_ty], Empty)), + Frag.regs, + !if(WithStride, [llvm_i32_ty], [])), [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>], - "llvm.nvvm.wmma." - # Geometry - # ".store.d" - # "." # Layout - # !if(WithStride, ".stride", "") - # "." 
# Type>; - -multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout, - string Type, LLVMType regty> { - def _stride: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 1>; - def NAME: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 0>; -} - -multiclass NVVM_WMMA_STD_GT<string Geometry, string Type, LLVMType regty> { - defm _row: NVVM_WMMA_STD_GLT<Geometry, "row", Type, regty>; - defm _col: NVVM_WMMA_STD_GLT<Geometry, "col", Type, regty>; -} -multiclass NVVM_WMMA_STD_G<string Geometry> { - defm _d_f16: NVVM_WMMA_STD_GT<Geometry, "f16", llvm_v2f16_ty>; - defm _d_f32: NVVM_WMMA_STD_GT<Geometry, "f32", llvm_float_ty>; -} - -multiclass NVVM_WMMA_STD { - defm _m32n8k16_store: NVVM_WMMA_STD_G<"m32n8k16">; - defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">; - defm _m8n32k16_store: NVVM_WMMA_STD_G<"m8n32k16">; + WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>; + +// Create all load/store variants +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout = ["row", "col"] in { + foreach stride = [0, 1] in { + foreach frag = [WMMA_REGS<geom, "a", "f16">, + WMMA_REGS<geom, "b", "f16">, + WMMA_REGS<geom, "c", "f16">, + WMMA_REGS<geom, "c", "f32">] in { + def WMMA_NAME_LDST<"load", frag, layout, stride>.record + : NVVM_WMMA_LD<frag, layout, stride>; + } + foreach frag = [WMMA_REGS<geom, "d", "f16">, + WMMA_REGS<geom, "d", "f32">] in { + def WMMA_NAME_LDST<"store", frag, layout, stride>.record + : NVVM_WMMA_ST<frag, layout, stride>; + } + } + } } -defm int_nvvm_wmma: NVVM_WMMA_STD; - // WMMA.MMA -class NVVM_WMMA_MMA_GABDCS<string Geometry, - string ALayout, string BLayout, - string DType, LLVMType d_regty, - string CType, LLVMType c_regty, - string Satfinite = ""> - : Intrinsic<!if(!eq(DType,"f16"), - [d_regty, d_regty, d_regty, d_regty], - [d_regty, d_regty, d_regty, d_regty, - d_regty, d_regty, d_regty, d_regty]), +class NVVM_WMMA_MMA<string ALayout, string BLayout, + WMMA_REGS C, WMMA_REGS D, int Satfinite> + : Intrinsic<D.regs, !listconcat( 
- [// A - llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, - llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, - // B - llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, - llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty], - !if(!eq(CType,"f16"), - [c_regty, c_regty, c_regty, c_regty], - [c_regty, c_regty, c_regty, c_regty, - c_regty, c_regty, c_regty, c_regty])), + WMMA_REGS<C.geom, "a", "f16">.regs, + WMMA_REGS<C.geom, "b", "f16">.regs, + C.regs), [IntrNoMem], - "llvm.nvvm.wmma." - # Geometry - # ".mma" - # "." # ALayout - # "." # BLayout - # "." # DType - # "." # CType - # Satfinite> { -} - -multiclass NVVM_WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout, - string DType, LLVMType d_regty, - string CType, LLVMType c_regty> { - def NAME : NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout, - DType, d_regty, CType, c_regty>; - def _satfinite: NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout, - DType, d_regty, CType, c_regty,".satfinite">; -} - -multiclass NVVM_WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout, - string DType, LLVMType d_regty> { - defm _f16: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty, - "f16", llvm_v2f16_ty>; - defm _f32: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty, - "f32", llvm_float_ty>; -} - -multiclass NVVM_WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> { - defm _f16: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", llvm_v2f16_ty>; - defm _f32: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", llvm_float_ty>; -} - -multiclass NVVM_WMMA_MMA_GA<string Geometry, string ALayout> { - defm _col: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "col">; - defm _row: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "row">; -} - -multiclass NVVM_WMMA_MMA_G<string Geometry> { - defm _col: NVVM_WMMA_MMA_GA<Geometry, "col">; - defm _row: NVVM_WMMA_MMA_GA<Geometry, "row">; -} - -multiclass NVVM_WMMA_MMA { - defm _m32n8k16_mma : NVVM_WMMA_MMA_G<"m32n8k16">; 
- defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">; - defm _m8n32k16_mma : NVVM_WMMA_MMA_G<"m8n32k16">; + WMMA_NAME_MMA<ALayout, BLayout, C, D, Satfinite>.llvm>; + +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach frag_c = [WMMA_REGS<geom, "c", "f16">, + WMMA_REGS<geom, "c", "f32">] in { + foreach frag_d = [WMMA_REGS<geom, "d", "f16">, + WMMA_REGS<geom, "d", "f32">] in { + foreach satf = [0, 1] in { + def WMMA_NAME_MMA<layout_a, layout_b, frag_c, frag_d, satf>.record + : NVVM_WMMA_MMA<layout_a, layout_b, frag_c, frag_d, satf>; + } + } + } + } + } } -defm int_nvvm_wmma : NVVM_WMMA_MMA; - } // let TargetPrefix = "nvvm" diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 8f9251bb0df..a5f0d7908f2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -26,7 +26,17 @@ def immDouble1 : PatLeaf<(fpimm), [{ return (d==1.0); }]>; - +def AS_match { + code generic = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); + }]; + code shared = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); + }]; + code global = [{ + return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); + }]; +} //----------------------------------- // Synchronization and shuffle functions @@ -1006,17 +1016,11 @@ def INT_FNS_iii : INT_FNS_MBO<(ins i32imm:$mask, i32imm:$base, i32imm:$ //----------------------------------- class ATOMIC_GLOBAL_CHK <dag ops, dag frag> - : PatFrag<ops, frag, [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); -}]>; + : PatFrag<ops, frag, AS_match.global>; class ATOMIC_SHARED_CHK <dag ops, dag frag> - : PatFrag<ops, frag, [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); -}]>; + : PatFrag<ops, frag, AS_match.shared>; class ATOMIC_GENERIC_CHK <dag ops, dag frag> - : PatFrag<ops, frag, [{ - return 
ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); -}]>; + : PatFrag<ops, frag, AS_match.generic>; multiclass F_ATOMIC_2_imp<NVPTXRegClass ptrclass, NVPTXRegClass regclass, string SpaceStr, string TypeStr, string OpcStr, PatFrag IntOp, @@ -7380,36 +7384,60 @@ def INT_PTX_SREG_WARPSIZE : NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;", [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>; -// -// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] -// - class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>; +// Generates list of n sequential register names. +class RegSeq<int n, string prefix> { + list<string> ret = !if(n, !listconcat(RegSeq<!add(n,-1), prefix>.ret, + [prefix # !add(n, -1)]), + []); +} -class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout, - string Space, string Type, NVPTXRegClass regclass, - DAGOperand SrcOp, bit WithStride> - : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic - // for this function. - PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_" - # Geometry # "_load_" - # !subst("c", "c_" # Type, Abc) - # "_" # Layout - # !subst(".", "_", Space) - # !if(WithStride,"_stride", "") - # "_Intr"); - dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3); - dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7); - dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47)); - - dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins)); - dag Ins = !con((ins SrcOp:$src), StrideArg); +// Helper class that represents a 'fragment' of an NVPTX *MMA instruction. +// In addition to target-independent fields provided by WMMA_REGS, it adds +// the fields commonly used to implement specific PTX instruction -- register +// types and names, constraints, parts of assembly, etc. 
+class WMMA_REGINFO<string Geom, string Frag, string PtxEltType> + : WMMA_REGS<Geom, Frag, PtxEltType> { + // NVPTX register types used to carry fragment data. + NVPTXRegClass regclass = !cond( + !eq(PtxEltType, "f16") : Float16x2Regs, + !eq(PtxEltType, "f32") : Float32Regs); + + // Instruction input/output arguments for the fragment. + list<NVPTXRegClass> ptx_regs = !foreach(tmp, regs, regclass); + + // List of register names for the fragment -- ["ra0", "ra1",...] + list<string> reg_names = RegSeq<!size(ptx_regs), "r"#frag>.ret; + // Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction. + string regstring = "{{$" # !head(reg_names) + # !foldl("", !tail(reg_names), a, b, + !strconcat(a, ", $", b)) + # "}}"; + + // Predicates for particular fragment variant. Technically those are + // per-instruction predicates, but currently all fragments that can be used in + // a given instruction are subject to the same constraints, so an instruction + // can use predicates from any of its fragments. If/when this is no + // longer the case, we can concat all per-fragment predicates to enforce that + // all fragments of the instruction are viable. + list<Predicate> Predicates = !cond( + // fp16 -> fp16/fp32 @ m16n16k16 + !and(!eq(Geom, "m16n16k16"), + !or(!eq(PtxEltType, "f16"), + !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60], + + // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16 + !and(!or(!eq(Geom, "m8n32k16"), + !eq(Geom, "m32n8k16")), + !or(!eq(PtxEltType, "f16"), + !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]); + + // template DAGs for instruction inputs/output. + dag Outs = !dag(outs, ptx_regs, reg_names); + dag Ins = !dag(ins, ptx_regs, reg_names); +} +class BuildPattern<dag Outs, PatFrag IntrMatcher, dag Ins> { // Build a dag pattern that matches the intrinsic call. 
// We want a dag that looks like this: // (set <output args>, (intrinsic <input arguments>)) where input and @@ -7430,277 +7458,127 @@ class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout, !subst(ins, IntrMatcher, tmp))))); // Finally, concatenate both parts together. !con() requires both dags to have // the same operator, so we wrap PatArgs in a (set ...) dag. - let Pattern = [!con(PatOuts, (set PatArgs))]; - let OutOperandList = Outs; - let InOperandList = Ins; - let AsmString = "wmma.load." - # Abc - # ".sync" - # "." # Layout - # "." # Geometry - # Space - # "." # Type # " \t" - # !if(!eq(Abc#Type, "cf16"), - "{{$r0, $r1, $r2, $r3}}", - "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") - # ", [$src]" - # !if(WithStride, ", $ldm", "") - # ";"; + dag ret = !con(PatOuts, (set PatArgs)); } -class WMMA_LOAD_INTR_HELPER<string Geometry, string Abc, string Layout, - string Space, string Type, bit WithStride> +// +// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] +// + +class WMMA_LOAD_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space, + bit WithStride> : PatFrag <(ops),(ops)> { // Intrinsic that matches this instruction. 
- Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma" - # "_" # Geometry # "_load_" - # Abc # "_" # Type # "_" # Layout - # !if(WithStride,"_stride", "")); - code match_generic = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); - }]; - code match_shared = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); - }]; - code match_global = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); - }]; - + Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"load", Frag, Layout, + WithStride>.record); let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src)); let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; - let PredicateCode = !if(!eq(Space, ".shared"), match_shared, - !if(!eq(Space, ".global"), match_global, match_generic)); -} - -multiclass WMMA_LOAD_GALSTS<string Geometry, string Abc, string Layout, - string Space, string Type, NVPTXRegClass regclass, - bit WithStride> { - def _avar: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - imem, WithStride>; - def _areg: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - Int32Regs, WithStride>; - def _areg64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - Int64Regs, WithStride>; - def _ari: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - MEMri, WithStride>; - def _ari64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass, - MEMri64, WithStride>; + let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared, + !eq(Space, ".global"): AS_match.global, + 1: AS_match.generic); } -multiclass WMMA_LOAD_GALSTSh<string Geometry, string Abc, string Layout, - string Space, string Type, NVPTXRegClass regclass, - bit WithStride> { - // Define a PatFrag that matches appropriate intrinsic that loads from the - // given address space. 
- def _Intr: WMMA_LOAD_INTR_HELPER<Geometry, Abc, Layout, Space, Type, - WithStride>; - defm NAME: WMMA_LOAD_GALSTS<Geometry, Abc, Layout, Space, Type, regclass, - WithStride>; -} - -multiclass WMMA_LOAD_GALST<string Geometry, string Abc, string Layout, - string Space, string Type, NVPTXRegClass regclass> { - defm _stride: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 1>; - defm NAME: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 0>; -} - -multiclass WMMA_LOAD_GALT<string Geometry, string Abc, string Layout, - string Type, NVPTXRegClass regclass> { - defm _global: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".global", - Type, regclass>; - defm _shared: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".shared", - Type, regclass>; - defm NAME: WMMA_LOAD_GALST<Geometry, Abc, Layout, "", - Type, regclass>; -} - -multiclass WMMA_LOAD_GAT<string Geometry, string Abc, - string Type, NVPTXRegClass regclass> { - defm _row: WMMA_LOAD_GALT<Geometry, Abc, "row", Type, regclass>; - defm _col: WMMA_LOAD_GALT<Geometry, Abc, "col", Type, regclass>; -} +class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride, + DAGOperand SrcOp> + : EmptyNVPTXInst, + Requires<Frag.Predicates> { + // Pattern that matches the intrinsic for this instruction variant. + PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER<Frag, Layout, Space, WithStride>; + dag Ins = !con((ins SrcOp:$src), !if(WithStride, (ins Int32Regs:$ldm), (ins))); -multiclass WMMA_LOAD_G<string Geometry> { - defm _load_a: WMMA_LOAD_GAT<Geometry, "a", "f16", Float16x2Regs>; - defm _load_b: WMMA_LOAD_GAT<Geometry, "b", "f16", Float16x2Regs>; - defm _load_c_f16: WMMA_LOAD_GAT<Geometry, "c", "f16", Float16x2Regs>; - defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>; + let Pattern = [BuildPattern<Frag.Outs, IntrMatcher, Ins>.ret]; + let OutOperandList = Frag.Outs; + let InOperandList = Ins; + let AsmString = "wmma.load." + # Frag.frag + # ".sync" + # "." # Layout + # "." 
# Frag.geom + # Space + # "." # Frag.ptx_elt_type # " \t" + # Frag.regstring + # ", [$src]" + # !if(WithStride, ", $ldm", "") + # ";"; } -defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">; -defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">; -defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">; - // // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32] // -class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space, - string Type, NVPTXRegClass regclass, - bit WithStride, DAGOperand DstOp> +class WMMA_STORE_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space, + bit WithStride> + : PatFrag <(ops),(ops)> { + // Intrinsic that matches this instruction. + Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"store", Frag, Layout, + WithStride>.record); + let Operands = !con((ops node:$dst), + !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names), + !if(WithStride, (ops node:$ldm), (ops))); + let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; + let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared, + !eq(Space, ".global"): AS_match.global, + 1: AS_match.generic); +} + +class WMMA_STORE<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride, + DAGOperand DstOp> : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA" - # "_" # Geometry # "_store_d" - # "_" # Type - # "_" # Layout - # !subst(".", "_", Space) - # !if(WithStride,"_stride", "") - # "_Intr"); - dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, - regclass:$r2, regclass:$r3); - dag InsR47 = (ins regclass:$r4, regclass:$r5, - regclass:$r6, regclass:$r7); - dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47)); - dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins)); - dag Ins = !con(InsR, StrideArg); - - // Construct the pattern to match corresponding intrinsic call. See the - // details in the comments in WMMA_LOAD_ALSTOS. 
- dag PatArgs = !foreach(tmp, Ins, - !subst(imem, ADDRvar, - !subst(MEMri64, ADDRri64, - !subst(MEMri, ADDRri, - !subst(ins, IntrMatcher, tmp))))); - let Pattern = [PatArgs]; + Requires<Frag.Predicates> { + PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER<Frag, Layout, Space, WithStride>; + dag Ins = !con((ins DstOp:$src), + Frag.Ins, + !if(WithStride, (ins Int32Regs:$ldm), (ins))); + let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret]; let OutOperandList = (outs); let InOperandList = Ins; let AsmString = "wmma.store.d.sync." # Layout - # "." # Geometry + # "." # Frag.geom # Space - # "." # Type + # "." # Frag.ptx_elt_type # " \t[$src]," - # !if(!eq(Type,"f16"), - "{{$r0, $r1, $r2, $r3}}", - "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}") + # Frag.regstring # !if(WithStride, ", $ldm", "") # ";"; - } -class WMMA_STORE_INTR_HELPER<string Geometry, string Layout, string Space, - string Type, bit WithStride> - : PatFrag <(ops),(ops)> { - // Intrinsic that matches this instruction. - Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_" - # Geometry - # "_store_d" - # "_" # Type - # "_" # Layout - # !if(WithStride, "_stride", "")); - code match_generic = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC); - }]; - code match_shared = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED); - }]; - code match_global = [{ - return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL); - }]; - - dag Args = !if(!eq(Type,"f16"), - (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3), - (ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3, - node:$r4, node:$r5, node:$r6, node:$r7)); - dag StrideArg = !if(WithStride, (ops node:$ldm), (ops)); - let Operands = !con(Args, StrideArg); - let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))]; - let PredicateCode = !if(!eq(Space, ".shared"), match_shared, - !if(!eq(Space, ".global"), match_global, match_generic)); -} - -multiclass WMMA_STORE_D_GLSTS<string Geometry, string Layout, string 
Space, - string Type, NVPTXRegClass regclass, - bit WithStride> { - def _avar: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, imem>; - def _areg: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, Int32Regs>; - def _areg64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, Int64Regs>; - def _ari: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, MEMri>; - def _ari64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass, - WithStride, MEMri64>; -} - -multiclass WMMA_STORE_D_GLSTSh<string Geometry, string Layout, string Space, - string Type, NVPTXRegClass regclass, - bit WithStride> { - // Define a PatFrag that matches appropriate intrinsic that loads from the - // given address space. - def _Intr: WMMA_STORE_INTR_HELPER<Geometry, Layout, Space, Type, - WithStride>; - defm NAME: WMMA_STORE_D_GLSTS<Geometry, Layout, Space, Type, regclass, - WithStride>; -} - -multiclass WMMA_STORE_D_GLST<string Geometry, string Layout, string Space, - string Type, NVPTXRegClass regclass > { - defm _stride: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 1>; - defm NAME: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 0>; -} - -multiclass WMMA_STORE_D_GLT<string Geometry, string Layout, - string Type, NVPTXRegClass regclass> { - defm _global: WMMA_STORE_D_GLST<Geometry, Layout, ".global", Type, regclass>; - defm _shared: WMMA_STORE_D_GLST<Geometry, Layout, ".shared", Type, regclass>; - defm NAME: WMMA_STORE_D_GLST<Geometry, Layout, "", Type, regclass>; -} - -multiclass WMMA_STORE_D_GT<string Geometry, string Type, - NVPTXRegClass regclass> { - defm _row: WMMA_STORE_D_GLT<Geometry, "row", Type, regclass>; - defm _col: WMMA_STORE_D_GLT<Geometry, "col", Type, regclass>; -} - -multiclass WMMA_STORE_D_G<string Geometry> { - defm _store_d_f16: WMMA_STORE_D_GT<Geometry, "f16", Float16x2Regs>; - defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>; 
-} - -defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">; -defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">; -defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">; +// Create all load/store variants +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout = ["row", "col"] in { + foreach stride = [0, 1] in { + foreach space = [".global", ".shared", ""] in { + foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in { + foreach frag = [WMMA_REGINFO<geom, "a", "f16">, + WMMA_REGINFO<geom, "b", "f16">, + WMMA_REGINFO<geom, "c", "f16">, + WMMA_REGINFO<geom, "c", "f32">] in { + def : WMMA_LOAD<frag, layout, space, stride, addr>; + } + foreach frag = [WMMA_REGINFO<geom, "d", "f16">, + WMMA_REGINFO<geom, "d", "f32">] in { + def : WMMA_STORE<frag, layout, space, stride, addr>; + } + } // addr + } // space + } // stride + } // layout +} // geom // WMMA.MMA -class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout, - string DType, NVPTXRegClass d_reg, - string CType, NVPTXRegClass c_reg, - NVPTXRegClass ab_reg, - string Satfinite = ""> +class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB, + WMMA_REGINFO FragC, WMMA_REGINFO FragD, + string ALayout, string BLayout, int Satfinite> : EmptyNVPTXInst, - Requires<[!if(!eq(Geometry, "m16n16k16"), - hasPTX60, - hasPTX61), - hasSM70]> { - Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_" - # Geometry - # "_mma" - # "_" # ALayout - # "_" # BLayout - # "_" # DType - # "_" # CType - # !subst(".", "_", Satfinite)); - dag Outs = !if(!eq(DType,"f16"), - (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3), - (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3, - d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)); - dag InsExtraCArgs = !if(!eq(CType,"f16"), - (ins), - (ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)); - dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3, - ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7, - ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3, - ab_reg:$b4, ab_reg:$b5, 
ab_reg:$b6, ab_reg:$b7, - c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3), - InsExtraCArgs); - - // Construct the pattern to match corresponding intrinsic call. See the - // details in the comments in WMMA_LOAD_ALSTOS. + Requires<FragC.Predicates> { + //Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero; + Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_MMA<ALayout, BLayout, FragC, FragD, Satfinite>.record); + dag Outs = FragD.Outs; + dag Ins = !con(FragA.Ins, + FragB.Ins, + FragC.Ins); + + // Construct the pattern to match corresponding intrinsic call. + // mma does not load/store anything, so we don't need complex operand matching here. dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp)); dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp)); let Pattern = [!con(PatOuts, (set PatArgs))]; @@ -7709,54 +7587,30 @@ class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout, let AsmString = "wmma.mma.sync." # ALayout # "." # BLayout - # "." # Geometry - # "." # DType - # "." # CType - # Satfinite # "\n\t\t" - # !if(!eq(DType,"f16"), - "{{$d0, $d1, $d2, $d3}}, \n\t\t", - "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t") - # "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t" - # "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t" - # !if(!eq(CType,"f16"), - "{{$c0, $c1, $c2, $c3}};", - "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"); -} - -multiclass WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout, - string DType, NVPTXRegClass d_reg, - string CType, NVPTXRegClass c_reg> { - def _satfinite: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout, - DType, d_reg, CType, c_reg, - Float16x2Regs, ".satfinite">; - def NAME: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout, - DType, d_reg, CType, c_reg, - Float16x2Regs>; -} - -multiclass WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout, - string DType, NVPTXRegClass d_reg> { - defm _f16: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg, - "f16", Float16x2Regs>; - defm _f32: WMMA_MMA_GABDC<Geometry, 
ALayout, BLayout, DType, d_reg, - "f32", Float32Regs>; -} - -multiclass WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> { - defm _f16: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", Float16x2Regs>; - defm _f32: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", Float32Regs>; -} - -multiclass WMMA_MMA_GA<string Geometry, string ALayout> { - defm _col: WMMA_MMA_GAB<Geometry, ALayout, "col">; - defm _row: WMMA_MMA_GAB<Geometry, ALayout, "row">; -} - -multiclass WMMA_MMA_G<string Geometry> { - defm _col: WMMA_MMA_GA<Geometry, "col">; - defm _row: WMMA_MMA_GA<Geometry, "row">; + # "." # FragA.geom + # "." # FragD.ptx_elt_type + # "." # FragC.ptx_elt_type + # !if(Satfinite, ".satfinite", "") # "\n\t\t" + # FragD.regstring # ",\n\t\t" + # FragA.regstring # ",\n\t\t" + # FragB.regstring # ",\n\t\t" + # FragC.regstring # ";"; } -defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">; -defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">; -defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">; +foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in { + foreach layout_a = ["row", "col"] in { + foreach layout_b = ["row", "col"] in { + foreach frag_c = [WMMA_REGINFO<geom, "c", "f16">, + WMMA_REGINFO<geom, "c", "f32">] in { + foreach frag_d = [WMMA_REGINFO<geom, "d", "f16">, + WMMA_REGINFO<geom, "d", "f32">] in { + foreach satf = [0, 1] in { + def : WMMA_MMA<WMMA_REGINFO<geom, "a", "f16">, + WMMA_REGINFO<geom, "b", "f16">, + frag_c, frag_d, layout_a, layout_b, satf>; + } // satf + } // frag_d + } // frag_c + } // layout_b + } // layout_a +} // geom |