diff options
author | Farhana Aleen <farhana.aleen@gmail.com> | 2018-09-18 16:59:48 +0000 |
---|---|---|
committer | Farhana Aleen <farhana.aleen@gmail.com> | 2018-09-18 16:59:48 +0000 |
commit | f5a2848376b488322e5573fef8a239befe110f3a (patch) | |
tree | 47cf7af80aee276a94711dc495aa6a46ad6a162f /llvm/lib | |
parent | b7471814cf6ba5f61c749a4b6e554808913b64ad (diff) | |
download | bcm5719-llvm-f5a2848376b488322e5573fef8a239befe110f3a.tar.gz bcm5719-llvm-f5a2848376b488322e5573fef8a239befe110f3a.zip |
[AMDGPU] Match udot8 pattern
Summary: Match the unsigned 4-bit, 8-element dot-product pattern and select V_DOT8_U32_U4 for it:
D.u32 = S0.u4[0] * S1.u4[0] +
S0.u4[1] * S1.u4[1] +
S0.u4[2] * S1.u4[2] +
S0.u4[3] * S1.u4[3] +
S0.u4[4] * S1.u4[4] +
S0.u4[5] * S1.u4[5] +
S0.u4[6] * S1.u4[6] +
S0.u4[7] * S1.u4[7] +
S2.u32
Author: FarhanaAleen
Reviewed By: arsenm, nhaehnle
Differential Revision: https://reviews.llvm.org/D51947
llvm-svn: 342497
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/AMDGPU/VOP3PInstructions.td | 69 |
1 files changed, 47 insertions, 22 deletions
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index 3154dcfdd45..83e95b55380 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -168,34 +168,53 @@ defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>; class Srl<int N> : PatFrag<(ops node:$src), (srl node:$src, (i32 N))>; -foreach Bits = [8, 16, 24] in { - def srl#Bits : Srl<Bits>; -} - -def and_255 : PatFrag< - (ops node:$src0), (and node:$src0, (i32 255)) ->; +foreach Bits = 1-7 in + def srl#!shl(Bits, 2) : Srl<!shl(Bits, 2)>; -class Extract_U8<int FromBitIndex> : PatFrag<( - ops node:$src), - !if (!eq (FromBitIndex, 24), // last element +class Extract_U<int FromBitIndex, int BitMask> : PatFrag< + (ops node:$src), + !if (!or (!and (!eq (BitMask, 255), !eq (FromBitIndex, 24)), + !and (!eq (BitMask, 15), !eq (FromBitIndex, 28))), // last element (!cast<Srl>("srl"#FromBitIndex) node:$src), !if (!eq (FromBitIndex, 0), // first element - (and_255 node:$src), - (and_255 (!cast<Srl>("srl"#FromBitIndex) node:$src))))>; + (and node:$src, (i32 BitMask)), + (and (!cast<Srl>("srl"#FromBitIndex) node:$src), (i32 BitMask))))>; -// Defines patterns that extract each Index'ed 8bit from a 32bit scalar value; -foreach Index = [1, 2, 3, 4] in { - def UElt#Index : Extract_U8<!shl(!add(Index, -1), 3)>; -} +foreach Index = 0-3 in { + // Defines patterns that extract each Index'ed 8bit from an unsigned + // 32bit scalar value; + def U#Index#"_8bit" : Extract_U<!shl(Index, 3), + 255>; -// Defines multiplication patterns where the multiplication is happening on each -// Index'ed 8bit of a 32bit scalar value. -foreach Index = [1, 2, 3, 4] in { + // Defines multiplication patterns where the multiplication is happening on each + // Index'ed 8bit of a 32bit scalar value. 
def MulU_Elt#Index : PatFrag< (ops node:$src0, node:$src1), - (AMDGPUmul_u24_oneuse (!cast<Extract_U8>("UElt"#Index) node:$src0), - (!cast<Extract_U8>("UElt"#Index) node:$src1))>; + (AMDGPUmul_u24_oneuse (!cast<Extract_U>("U"#Index#"_8bit") node:$src0), + (!cast<Extract_U>("U"#Index#"_8bit") node:$src1))>; +} + +// Different variants of dot8 patterns cause a huge increase in the compile time. +// Define non-associative/commutative add/mul to prevent permutation in the dot8 +// pattern. +def NonACAdd : SDNode<"ISD::ADD" , SDTIntBinOp>; +def NonACAdd_oneuse : HasOneUseBinOp<NonACAdd>; + +def NonACAMDGPUmul_u24 : SDNode<"AMDGPUISD::MUL_U24" , SDTIntBinOp>; +def NonACAMDGPUmul_u24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_u24>; + +foreach Index = 0-7 in { + // Defines patterns that extract each Index'ed 4bit from an unsigned + // 32bit scalar value; + def U#Index#"_4bit" : Extract_U<!shl(Index, 2), + 15>; + + // Defines multiplication patterns where the multiplication is happening on each + // Index'ed 4bit of a 32bit scalar value. 
+ def MulU#Index#"_4bit" : PatFrag< + (ops node:$src0, node:$src1), + (NonACAMDGPUmul_u24_oneuse (!cast<Extract_U>("U"#Index#"_4bit") node:$src0), + (!cast<Extract_U>("U"#Index#"_4bit") node:$src1))>; } class UDot2Pat<Instruction Inst> : GCNPat < @@ -246,11 +265,17 @@ def : UDot2Pat<V_DOT2_U32_U16>; def : SDot2Pat<V_DOT2_I32_I16>; def : GCNPat < - !cast<dag>(!foldl((i32 i32:$src2), [1, 2, 3, 4], lhs, y, + !cast<dag>(!foldl((i32 i32:$src2), [0, 1, 2, 3], lhs, y, (add_oneuse lhs, (!cast<PatFrag>("MulU_Elt"#y) i32:$src0, i32:$src1)))), (V_DOT4_U32_U8 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0)) >; +def : GCNPat < + !cast<dag>(!foldl((add_oneuse i32:$src2, (MulU0_4bit i32:$src0, i32:$src1)), [1, 2, 3, 4, 5, 6, 7], lhs, y, + (NonACAdd_oneuse lhs, (!cast<PatFrag>("MulU"#y#"_4bit") i32:$src0, i32:$src1)))), + (V_DOT8_U32_U4 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0)) +>; + } // End SubtargetPredicate = HasDLInsts multiclass VOP3P_Real_vi<bits<10> op> { |