diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86InstrInfo.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86InstrInfo.cpp | 175 |
1 files changed, 139 insertions, 36 deletions
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp index 3035ddc9001..2f7dcfbe245 100644 --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -3318,36 +3318,31 @@ X86InstrInfo::convertToThreeAddress(MachineFunction::iterator &MFI, return NewMI; } -unsigned X86InstrInfo::getFMA3OpcodeToCommuteOperands( - const MachineInstr &MI, unsigned SrcOpIdx1, unsigned SrcOpIdx2, - const X86InstrFMA3Group &FMA3Group) const { - - unsigned Opc = MI.getOpcode(); - +/// This determines which of three possible cases of a three source commute +/// the source indexes correspond to taking into account any mask operands. +/// Also prevents commuting a passthru operand. Returns -1 if the commute isn't +/// possible. +/// Case 0 - Possible to commute the first and second operands. +/// Case 1 - Possible to commute the first and third operands. +/// Case 2 - Possible to commute the second and third operands. +static int getThreeSrcCommuteCase(uint64_t TSFlags, unsigned SrcOpIdx1, + unsigned SrcOpIdx2) { // Put the lowest index to SrcOpIdx1 to simplify the checks below. if (SrcOpIdx1 > SrcOpIdx2) std::swap(SrcOpIdx1, SrcOpIdx2); - // TODO: Commuting the 1st operand of FMA*_Int requires some additional - // analysis. The commute optimization is legal only if all users of FMA*_Int - // use only the lowest element of the FMA*_Int instruction. Such analysis are - // not implemented yet. So, just return 0 in that case. - // When such analysis are available this place will be the right place for - // calling it. - if (FMA3Group.isIntrinsic() && SrcOpIdx1 == 1) - return 0; - - unsigned FMAOp1 = 1, FMAOp2 = 2, FMAOp3 = 3; - if (FMA3Group.isKMasked()) { + unsigned Op1 = 1, Op2 = 2, Op3 = 3; + if (X86II::isKMasked(TSFlags)) { // The k-mask operand cannot be commuted. if (SrcOpIdx1 == 2) - return 0; + return -1; // For k-zero-masked operations it is Ok to commute the first vector // operand. 
// For regular k-masked operations a conservative choice is done as the // elements of the first vector operand, for which the corresponding bit - // in the k-mask operand is set to 0, are copied to the result of FMA. + // in the k-mask operand is set to 0, are copied to the result of the + // instruction. // TODO/FIXME: The commute still may be legal if it is known that the // k-mask operand is set to either all ones or all zeroes. // It is also Ok to commute the 1st operand if all users of MI use only @@ -3356,20 +3351,43 @@ unsigned X86InstrInfo::getFMA3OpcodeToCommuteOperands( // : v1[i]; // VMOVAPSZmrk <mem_addr>, k, v4; // this is the ONLY user of v4 -> // // Ok, to commute v1 in FMADD213PSZrk. - if (FMA3Group.isKMergeMasked() && SrcOpIdx1 == FMAOp1) - return 0; - FMAOp2++; - FMAOp3++; + if (X86II::isKMergeMasked(TSFlags) && SrcOpIdx1 == Op1) + return -1; + Op2++; + Op3++; } - unsigned Case; - if (SrcOpIdx1 == FMAOp1 && SrcOpIdx2 == FMAOp2) - Case = 0; - else if (SrcOpIdx1 == FMAOp1 && SrcOpIdx2 == FMAOp3) - Case = 1; - else if (SrcOpIdx1 == FMAOp2 && SrcOpIdx2 == FMAOp3) - Case = 2; - else + if (SrcOpIdx1 == Op1 && SrcOpIdx2 == Op2) + return 0; + if (SrcOpIdx1 == Op1 && SrcOpIdx2 == Op3) + return 1; + if (SrcOpIdx1 == Op2 && SrcOpIdx2 == Op3) + return 2; + return -1; +} + +unsigned X86InstrInfo::getFMA3OpcodeToCommuteOperands( + const MachineInstr &MI, unsigned SrcOpIdx1, unsigned SrcOpIdx2, + const X86InstrFMA3Group &FMA3Group) const { + + unsigned Opc = MI.getOpcode(); + + // Put the lowest index to SrcOpIdx1 to simplify the checks below. + if (SrcOpIdx1 > SrcOpIdx2) + std::swap(SrcOpIdx1, SrcOpIdx2); + + // TODO: Commuting the 1st operand of FMA*_Int requires some additional + // analysis. The commute optimization is legal only if all users of FMA*_Int + // use only the lowest element of the FMA*_Int instruction. Such analysis are + // not implemented yet. So, just return 0 in that case. 
+ // When such analysis are available this place will be the right place for + // calling it. + if (FMA3Group.isIntrinsic() && SrcOpIdx1 == 1) + return 0; + + // Determine which case this commute is or if it can't be done. + int Case = getThreeSrcCommuteCase(MI.getDesc().TSFlags, SrcOpIdx1, SrcOpIdx2); + if (Case < 0) return 0; // Define the FMA forms mapping array that helps to map input FMA form @@ -3416,6 +3434,36 @@ unsigned X86InstrInfo::getFMA3OpcodeToCommuteOperands( return FMAForms[FormIndex]; } +static bool commuteVPTERNLOG(MachineInstr &MI, unsigned SrcOpIdx1, + unsigned SrcOpIdx2) { + uint64_t TSFlags = MI.getDesc().TSFlags; + + // Determine which case this commute is or if it can't be done. + int Case = getThreeSrcCommuteCase(TSFlags, SrcOpIdx1, SrcOpIdx2); + if (Case < 0) + return false; + + // For each case we need to swap two pairs of bits in the final immediate. + static const uint8_t SwapMasks[3][4] = { + { 0x04, 0x10, 0x08, 0x20 }, // Swap bits 2/4 and 3/5. + { 0x02, 0x10, 0x08, 0x40 }, // Swap bits 1/4 and 3/6. + { 0x02, 0x04, 0x20, 0x40 }, // Swap bits 1/2 and 5/6. + }; + + uint8_t Imm = MI.getOperand(MI.getNumOperands()-1).getImm(); + // Clear out the bits we are swapping. + uint8_t NewImm = Imm & ~(SwapMasks[Case][0] | SwapMasks[Case][1] | + SwapMasks[Case][2] | SwapMasks[Case][3]); + // If the immediate had a bit of the pair set, then set the opposite bit. 
+ if (Imm & SwapMasks[Case][0]) NewImm |= SwapMasks[Case][1]; + if (Imm & SwapMasks[Case][1]) NewImm |= SwapMasks[Case][0]; + if (Imm & SwapMasks[Case][2]) NewImm |= SwapMasks[Case][3]; + if (Imm & SwapMasks[Case][3]) NewImm |= SwapMasks[Case][2]; + MI.getOperand(MI.getNumOperands()-1).setImm(NewImm); + + return true; +} + MachineInstr *X86InstrInfo::commuteInstructionImpl(MachineInstr &MI, bool NewMI, unsigned OpIdx1, unsigned OpIdx2) const { @@ -3680,6 +3728,30 @@ MachineInstr *X86InstrInfo::commuteInstructionImpl(MachineInstr &MI, bool NewMI, return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI=*/false, OpIdx1, OpIdx2); } + case X86::VPTERNLOGDZrri: case X86::VPTERNLOGDZrmi: + case X86::VPTERNLOGDZ128rri: case X86::VPTERNLOGDZ128rmi: + case X86::VPTERNLOGDZ256rri: case X86::VPTERNLOGDZ256rmi: + case X86::VPTERNLOGQZrri: case X86::VPTERNLOGQZrmi: + case X86::VPTERNLOGQZ128rri: case X86::VPTERNLOGQZ128rmi: + case X86::VPTERNLOGQZ256rri: case X86::VPTERNLOGQZ256rmi: + case X86::VPTERNLOGDZrrik: case X86::VPTERNLOGDZrmik: + case X86::VPTERNLOGDZ128rrik: case X86::VPTERNLOGDZ128rmik: + case X86::VPTERNLOGDZ256rrik: case X86::VPTERNLOGDZ256rmik: + case X86::VPTERNLOGQZrrik: case X86::VPTERNLOGQZrmik: + case X86::VPTERNLOGQZ128rrik: case X86::VPTERNLOGQZ128rmik: + case X86::VPTERNLOGQZ256rrik: case X86::VPTERNLOGQZ256rmik: + case X86::VPTERNLOGDZrrikz: case X86::VPTERNLOGDZrmikz: + case X86::VPTERNLOGDZ128rrikz: case X86::VPTERNLOGDZ128rmikz: + case X86::VPTERNLOGDZ256rrikz: case X86::VPTERNLOGDZ256rmikz: + case X86::VPTERNLOGQZrrikz: case X86::VPTERNLOGQZrmikz: + case X86::VPTERNLOGQZ128rrikz: case X86::VPTERNLOGQZ128rmikz: + case X86::VPTERNLOGQZ256rrikz: case X86::VPTERNLOGQZ256rmikz: { + auto &WorkingMI = cloneIfNew(MI); + if (!commuteVPTERNLOG(WorkingMI, OpIdx1, OpIdx2)) + return nullptr; + return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI=*/false, + OpIdx1, OpIdx2); + } default: const X86InstrFMA3Group *FMA3Group = 
X86InstrFMA3Info::getFMA3Group(MI.getOpcode()); @@ -3701,16 +3773,30 @@ MachineInstr *X86InstrInfo::commuteInstructionImpl(MachineInstr &MI, bool NewMI, bool X86InstrInfo::findFMA3CommutedOpIndices( const MachineInstr &MI, unsigned &SrcOpIdx1, unsigned &SrcOpIdx2, const X86InstrFMA3Group &FMA3Group) const { + + if (!findThreeSrcCommutedOpIndices(MI, SrcOpIdx1, SrcOpIdx2)) + return false; + + // Check if we can adjust the opcode to preserve the semantics when + // commute the register operands. + return getFMA3OpcodeToCommuteOperands(MI, SrcOpIdx1, SrcOpIdx2, FMA3Group) != 0; +} + +bool X86InstrInfo::findThreeSrcCommutedOpIndices(const MachineInstr &MI, + unsigned &SrcOpIdx1, + unsigned &SrcOpIdx2) const { + uint64_t TSFlags = MI.getDesc().TSFlags; + unsigned FirstCommutableVecOp = 1; unsigned LastCommutableVecOp = 3; unsigned KMaskOp = 0; - if (FMA3Group.isKMasked()) { + if (X86II::isKMasked(TSFlags)) { // The k-mask operand has index = 2 for masked and zero-masked operations. KMaskOp = 2; // The operand with index = 1 is used as a source for those elements for // which the corresponding bit in the k-mask is set to 0. - if (FMA3Group.isKMergeMasked()) + if (X86II::isKMergeMasked(TSFlags)) FirstCommutableVecOp = 3; LastCommutableVecOp++; @@ -3775,9 +3861,7 @@ bool X86InstrInfo::findFMA3CommutedOpIndices( return false; } - // Check if we can adjust the opcode to preserve the semantics when - // commute the register operands. 
- return getFMA3OpcodeToCommuteOperands(MI, SrcOpIdx1, SrcOpIdx2, FMA3Group) != 0; + return true; } bool X86InstrInfo::findCommutedOpIndices(MachineInstr &MI, unsigned &SrcOpIdx1, @@ -3819,6 +3903,25 @@ bool X86InstrInfo::findCommutedOpIndices(MachineInstr &MI, unsigned &SrcOpIdx1, } return false; } + case X86::VPTERNLOGDZrri: case X86::VPTERNLOGDZrmi: + case X86::VPTERNLOGDZ128rri: case X86::VPTERNLOGDZ128rmi: + case X86::VPTERNLOGDZ256rri: case X86::VPTERNLOGDZ256rmi: + case X86::VPTERNLOGQZrri: case X86::VPTERNLOGQZrmi: + case X86::VPTERNLOGQZ128rri: case X86::VPTERNLOGQZ128rmi: + case X86::VPTERNLOGQZ256rri: case X86::VPTERNLOGQZ256rmi: + case X86::VPTERNLOGDZrrik: case X86::VPTERNLOGDZrmik: + case X86::VPTERNLOGDZ128rrik: case X86::VPTERNLOGDZ128rmik: + case X86::VPTERNLOGDZ256rrik: case X86::VPTERNLOGDZ256rmik: + case X86::VPTERNLOGQZrrik: case X86::VPTERNLOGQZrmik: + case X86::VPTERNLOGQZ128rrik: case X86::VPTERNLOGQZ128rmik: + case X86::VPTERNLOGQZ256rrik: case X86::VPTERNLOGQZ256rmik: + case X86::VPTERNLOGDZrrikz: case X86::VPTERNLOGDZrmikz: + case X86::VPTERNLOGDZ128rrikz: case X86::VPTERNLOGDZ128rmikz: + case X86::VPTERNLOGDZ256rrikz: case X86::VPTERNLOGDZ256rmikz: + case X86::VPTERNLOGQZrrikz: case X86::VPTERNLOGQZrmikz: + case X86::VPTERNLOGQZ128rrikz: case X86::VPTERNLOGQZ128rmikz: + case X86::VPTERNLOGQZ256rrikz: case X86::VPTERNLOGQZ256rmikz: + return findThreeSrcCommutedOpIndices(MI, SrcOpIdx1, SrcOpIdx2); default: const X86InstrFMA3Group *FMA3Group = X86InstrFMA3Info::getFMA3Group(MI.getOpcode()); |