diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/CodeGen/MachineCombiner.cpp | 12 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 11 | ||||
-rw-r--r-- | llvm/lib/CodeGen/TargetInstrInfo.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 580 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrInfo.h | 5 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h | 1 |
7 files changed, 583 insertions, 38 deletions
diff --git a/llvm/lib/CodeGen/MachineCombiner.cpp b/llvm/lib/CodeGen/MachineCombiner.cpp index 44601d5e462..6b5c6ba8250 100644 --- a/llvm/lib/CodeGen/MachineCombiner.cpp +++ b/llvm/lib/CodeGen/MachineCombiner.cpp @@ -40,6 +40,7 @@ class MachineCombiner : public MachineFunctionPass { const TargetRegisterInfo *TRI; MCSchedModel SchedModel; MachineRegisterInfo *MRI; + MachineLoopInfo *MLI; // Current MachineLoopInfo MachineTraceMetrics *Traces; MachineTraceMetrics::Ensemble *MinInstr; @@ -86,6 +87,7 @@ char &llvm::MachineCombinerID = MachineCombiner::ID; INITIALIZE_PASS_BEGIN(MachineCombiner, "machine-combiner", "Machine InstCombiner", false, false) +INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) INITIALIZE_PASS_DEPENDENCY(MachineTraceMetrics) INITIALIZE_PASS_END(MachineCombiner, "machine-combiner", "Machine InstCombiner", false, false) @@ -93,6 +95,7 @@ INITIALIZE_PASS_END(MachineCombiner, "machine-combiner", "Machine InstCombiner", void MachineCombiner::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesCFG(); AU.addPreserved<MachineDominatorTree>(); + AU.addRequired<MachineLoopInfo>(); AU.addPreserved<MachineLoopInfo>(); AU.addRequired<MachineTraceMetrics>(); AU.addPreserved<MachineTraceMetrics>(); @@ -354,6 +357,8 @@ bool MachineCombiner::combineInstructions(MachineBasicBlock *MBB) { DEBUG(dbgs() << "Combining MBB " << MBB->getName() << "\n"); auto BlockIter = MBB->begin(); + // Check if the block is in a loop. + const MachineLoop *ML = MLI->getLoopFor(MBB); while (BlockIter != MBB->end()) { auto &MI = *BlockIter++; @@ -406,11 +411,15 @@ bool MachineCombiner::combineInstructions(MachineBasicBlock *MBB) { if (!NewInstCount) continue; + bool SubstituteAlways = false; + if (ML && TII->isThroughputPattern(P)) + SubstituteAlways = true; + // Substitute when we optimize for codesize and the new sequence has // fewer instructions OR // the new sequence neither lengthens the critical path nor increases // resource pressure. - if (doSubstitute(NewInstCount, OldInstCount) || + if (SubstituteAlways || doSubstitute(NewInstCount, OldInstCount) || (improvesCriticalPathLen(MBB, &MI, BlockTrace, InsInstrs, InstrIdxForVirtReg, P) && preservesResourceLen(MBB, BlockTrace, InsInstrs, DelInstrs))) { @@ -447,6 +456,7 @@ bool MachineCombiner::runOnMachineFunction(MachineFunction &MF) { SchedModel = STI.getSchedModel(); TSchedModel.init(SchedModel, &STI, TII); MRI = &MF.getRegInfo(); + MLI = &getAnalysis<MachineLoopInfo>(); Traces = &getAnalysis<MachineTraceMetrics>(); MinInstr = nullptr; OptSize = MF.getFunction()->optForSize(); diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 703d33bff17..f740e59af96 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -24,6 +24,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/SelectionDAGTargetInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -7716,6 +7717,11 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { if (!HasFMAD && !HasFMA) return SDValue(); + const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); + ; + if (AllowFusion && STI && STI->GenerateFMAsInMachineCombiner(OptLevel)) + return SDValue(); + // Always prefer FMAD to FMA for precision. unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); @@ -7899,6 +7905,10 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { if (!HasFMAD && !HasFMA) return SDValue(); + const SelectionDAGTargetInfo *STI = DAG.getSubtarget().getSelectionDAGInfo(); + if (AllowFusion && STI && STI->GenerateFMAsInMachineCombiner(OptLevel)) + return SDValue(); + // Always prefer FMAD to FMA for precision. unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); @@ -8367,7 +8377,6 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { AddToWorklist(Fused.getNode()); return Fused; } - return SDValue(); } diff --git a/llvm/lib/CodeGen/TargetInstrInfo.cpp b/llvm/lib/CodeGen/TargetInstrInfo.cpp index 86517d9afbc..800ad6d1bb4 100644 --- a/llvm/lib/CodeGen/TargetInstrInfo.cpp +++ b/llvm/lib/CodeGen/TargetInstrInfo.cpp @@ -655,7 +655,11 @@ bool TargetInstrInfo::getMachineCombinerPatterns( return false; } - +/// Return true when a code sequence can improve loop throughput. +bool +TargetInstrInfo::isThroughputPattern(MachineCombinerPattern Pattern) const { + return false; +} /// Attempt the reassociation transformation to reduce critical path length. /// See the above comments before getMachineCombinerPatterns(). void TargetInstrInfo::reassociateOps( diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index eb0c5785d5d..5a189f40ab1 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -2787,37 +2787,75 @@ static bool isCombineInstrCandidate64(unsigned Opc) { return false; } // +// FP Opcodes that can be combined with a FMUL +static bool isCombineInstrCandidateFP(const MachineInstr &Inst) { + switch (Inst.getOpcode()) { + case AArch64::FADDSrr: + case AArch64::FADDDrr: + case AArch64::FADDv2f32: + case AArch64::FADDv2f64: + case AArch64::FADDv4f32: + case AArch64::FSUBSrr: + case AArch64::FSUBDrr: + case AArch64::FSUBv2f32: + case AArch64::FSUBv2f64: + case AArch64::FSUBv4f32: + return Inst.getParent()->getParent()->getTarget().Options.UnsafeFPMath; + default: + break; + } + return false; +} +// // Opcodes that can be combined with a MUL static bool isCombineInstrCandidate(unsigned Opc) { return (isCombineInstrCandidate32(Opc) || isCombineInstrCandidate64(Opc)); } -static bool canCombineWithMUL(MachineBasicBlock &MBB, MachineOperand &MO, - unsigned MulOpc, unsigned ZeroReg) { +// +// Utility routine that checks if \param MO is defined by an +// \param CombineOpc instruction in the basic block \param MBB +static bool canCombine(MachineBasicBlock &MBB, MachineOperand &MO, + unsigned CombineOpc, unsigned ZeroReg = 0, + bool CheckZeroReg = false) { MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); MachineInstr *MI = nullptr; - // We need a virtual register definition. + if (MO.isReg() && TargetRegisterInfo::isVirtualRegister(MO.getReg())) MI = MRI.getUniqueVRegDef(MO.getReg()); // And it needs to be in the trace (otherwise, it won't have a depth). - if (!MI || MI->getParent() != &MBB || (unsigned)MI->getOpcode() != MulOpc) - return false; - - assert(MI->getNumOperands() >= 4 && MI->getOperand(0).isReg() && - MI->getOperand(1).isReg() && MI->getOperand(2).isReg() && - MI->getOperand(3).isReg() && "MAdd/MSub must have a least 4 regs"); - - // The third input reg must be zero. - if (MI->getOperand(3).getReg() != ZeroReg) + if (!MI || MI->getParent() != &MBB || (unsigned)MI->getOpcode() != CombineOpc) return false; - // Must only used by the user we combine with. if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg())) return false; + if (CheckZeroReg) { + assert(MI->getNumOperands() >= 4 && MI->getOperand(0).isReg() && + MI->getOperand(1).isReg() && MI->getOperand(2).isReg() && + MI->getOperand(3).isReg() && "MAdd/MSub must have a least 4 regs"); + // The third input reg must be zero. + if (MI->getOperand(3).getReg() != ZeroReg) + return false; + } + return true; } +// +// Is \param MO defined by an integer multiply and can be combined? +static bool canCombineWithMUL(MachineBasicBlock &MBB, MachineOperand &MO, + unsigned MulOpc, unsigned ZeroReg) { + return canCombine(MBB, MO, MulOpc, ZeroReg, true); +} + +// +// Is \param MO defined by a floating-point multiply and can be combined? +static bool canCombineWithFMUL(MachineBasicBlock &MBB, MachineOperand &MO, + unsigned MulOpc) { + return canCombine(MBB, MO, MulOpc); +} + // TODO: There are many more machine instruction opcodes to match: // 1. Other data types (integer, vectors) // 2. Other math / logic operations (xor, or) @@ -2951,7 +2989,230 @@ static bool getMaddPatterns(MachineInstr &Root, } return Found; } +/// Floating-Point Support +/// Find instructions that can be turned into madd. +static bool getFMAPatterns(MachineInstr &Root, + SmallVectorImpl<MachineCombinerPattern> &Patterns) { + + if (!isCombineInstrCandidateFP(Root)) + return 0; + + MachineBasicBlock &MBB = *Root.getParent(); + bool Found = false; + + switch (Root.getOpcode()) { + default: + assert(false && "Unsupported FP instruction in combiner\n"); + break; + case AArch64::FADDSrr: + assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg() && + "FADDWrr does not have register operands"); + if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULSrr)) { + Patterns.push_back(MachineCombinerPattern::FMULADDS_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv1i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv1i32_indexed_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULSrr)) { + Patterns.push_back(MachineCombinerPattern::FMULADDS_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv1i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv1i32_indexed_OP2); + Found = true; + } + break; + case AArch64::FADDDrr: + if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULDrr)) { + Patterns.push_back(MachineCombinerPattern::FMULADDD_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv1i64_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv1i64_indexed_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULDrr)) { + Patterns.push_back(MachineCombinerPattern::FMULADDD_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv1i64_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv1i64_indexed_OP2); + Found = true; + } + break; + case AArch64::FADDv2f32: + if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv2i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv2i32_indexed_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv2f32)) { + Patterns.push_back(MachineCombinerPattern::FMLAv2f32_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv2i32_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2f32)) { + Patterns.push_back(MachineCombinerPattern::FMLAv2f32_OP2); + Found = true; + } + break; + case AArch64::FADDv2f64: + if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv2i64_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv2i64_indexed_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv2f64)) { + Patterns.push_back(MachineCombinerPattern::FMLAv2f64_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2i64_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv2i64_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2f64)) { + Patterns.push_back(MachineCombinerPattern::FMLAv2f64_OP2); + Found = true; + } + break; + case AArch64::FADDv4f32: + if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv4i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv4i32_indexed_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv4f32)) { + Patterns.push_back(MachineCombinerPattern::FMLAv4f32_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv4i32_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4f32)) { + Patterns.push_back(MachineCombinerPattern::FMLAv4f32_OP2); + Found = true; + } + break; + + case AArch64::FSUBSrr: + if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULSrr)) { + Patterns.push_back(MachineCombinerPattern::FMULSUBS_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULSrr)) { + Patterns.push_back(MachineCombinerPattern::FMULSUBS_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv1i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv1i32_indexed_OP2); + Found = true; + } + break; + case AArch64::FSUBDrr: + if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULDrr)) { + Patterns.push_back(MachineCombinerPattern::FMULSUBD_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULDrr)) { + Patterns.push_back(MachineCombinerPattern::FMULSUBD_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv1i64_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv1i64_indexed_OP2); + Found = true; + } + break; + case AArch64::FSUBv2f32: + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv2i32_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2f32)) { + Patterns.push_back(MachineCombinerPattern::FMLSv2f32_OP2); + Found = true; + } + break; + case AArch64::FSUBv2f64: + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2i64_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv2i64_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv2f64)) { + Patterns.push_back(MachineCombinerPattern::FMLSv2f64_OP2); + Found = true; + } + break; + case AArch64::FSUBv4f32: + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4i32_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv4i32_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4f32)) { + Patterns.push_back(MachineCombinerPattern::FMLSv4f32_OP2); + Found = true; + } + break; + } + return Found; +} + +/// Return true when a code sequence can improve throughput. It +/// should be called only for instructions in loops. +/// \param Pattern - combiner pattern +bool +AArch64InstrInfo::isThroughputPattern(MachineCombinerPattern Pattern) const { + switch (Pattern) { + default: + break; + case MachineCombinerPattern::FMULADDS_OP1: + case MachineCombinerPattern::FMULADDS_OP2: + case MachineCombinerPattern::FMULSUBS_OP1: + case MachineCombinerPattern::FMULSUBS_OP2: + case MachineCombinerPattern::FMULADDD_OP1: + case MachineCombinerPattern::FMULADDD_OP2: + case MachineCombinerPattern::FMULSUBD_OP1: + case MachineCombinerPattern::FMULSUBD_OP2: + case MachineCombinerPattern::FMLAv1i32_indexed_OP1: + case MachineCombinerPattern::FMLAv1i32_indexed_OP2: + case MachineCombinerPattern::FMLAv1i64_indexed_OP1: + case MachineCombinerPattern::FMLAv1i64_indexed_OP2: + case MachineCombinerPattern::FMLAv2f32_OP2: + case MachineCombinerPattern::FMLAv2f32_OP1: + case MachineCombinerPattern::FMLAv2f64_OP1: + case MachineCombinerPattern::FMLAv2f64_OP2: + case MachineCombinerPattern::FMLAv2i32_indexed_OP1: + case MachineCombinerPattern::FMLAv2i32_indexed_OP2: + case MachineCombinerPattern::FMLAv2i64_indexed_OP1: + case MachineCombinerPattern::FMLAv2i64_indexed_OP2: + case MachineCombinerPattern::FMLAv4f32_OP1: + case MachineCombinerPattern::FMLAv4f32_OP2: + case MachineCombinerPattern::FMLAv4i32_indexed_OP1: + case MachineCombinerPattern::FMLAv4i32_indexed_OP2: + case MachineCombinerPattern::FMLSv1i32_indexed_OP2: + case MachineCombinerPattern::FMLSv1i64_indexed_OP2: + case MachineCombinerPattern::FMLSv2i32_indexed_OP2: + case MachineCombinerPattern::FMLSv2i64_indexed_OP2: + case MachineCombinerPattern::FMLSv2f32_OP2: + case MachineCombinerPattern::FMLSv2f64_OP2: + case MachineCombinerPattern::FMLSv4i32_indexed_OP2: + case MachineCombinerPattern::FMLSv4f32_OP2: + return true; + } // end switch (Pattern) + return false; +} /// Return true when there is potentially a faster code sequence for an /// instruction chain ending in \p Root. All potential patterns are listed in /// the \p Pattern vector. Pattern should be sorted in priority order since the @@ -2960,28 +3221,35 @@ static bool getMaddPatterns(MachineInstr &Root, bool AArch64InstrInfo::getMachineCombinerPatterns( MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns) const { + // Integer patterns if (getMaddPatterns(Root, Patterns)) return true; + // Floating point patterns + if (getFMAPatterns(Root, Patterns)) + return true; return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns); } -/// genMadd - Generate madd instruction and combine mul and add. -/// Example: -/// MUL I=A,B,0 -/// ADD R,I,C -/// ==> MADD R,A,B,C -/// \param Root is the ADD instruction +enum class FMAInstKind { Default, Indexed, Accumulator }; +/// genFusedMultiply - Generate fused multiply instructions. +/// This function supports both integer and floating point instructions. +/// A typical example: +/// F|MUL I=A,B,0 +/// F|ADD R,I,C +/// ==> F|MADD R,A,B,C +/// \param Root is the F|ADD instruction /// \param [out] InsInstrs is a vector of machine instructions and will /// contain the generated madd instruction /// \param IdxMulOpd is index of operand in Root that is the result of -/// the MUL. In the example above IdxMulOpd is 1. -/// \param MaddOpc the opcode fo the madd instruction -static MachineInstr *genMadd(MachineFunction &MF, MachineRegisterInfo &MRI, - const TargetInstrInfo *TII, MachineInstr &Root, - SmallVectorImpl<MachineInstr *> &InsInstrs, - unsigned IdxMulOpd, unsigned MaddOpc, - const TargetRegisterClass *RC) { +/// the F|MUL. In the example above IdxMulOpd is 1. +/// \param MaddOpc the opcode fo the f|madd instruction +static MachineInstr * +genFusedMultiply(MachineFunction &MF, MachineRegisterInfo &MRI, + const TargetInstrInfo *TII, MachineInstr &Root, + SmallVectorImpl<MachineInstr *> &InsInstrs, unsigned IdxMulOpd, + unsigned MaddOpc, const TargetRegisterClass *RC, + FMAInstKind kind = FMAInstKind::Default) { assert(IdxMulOpd == 1 || IdxMulOpd == 2); unsigned IdxOtherOpd = IdxMulOpd == 1 ? 2 : 1; @@ -3003,12 +3271,26 @@ static MachineInstr *genMadd(MachineFunction &MF, MachineRegisterInfo &MRI, if (TargetRegisterInfo::isVirtualRegister(SrcReg2)) MRI.constrainRegClass(SrcReg2, RC); - MachineInstrBuilder MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc), - ResultReg) - .addReg(SrcReg0, getKillRegState(Src0IsKill)) - .addReg(SrcReg1, getKillRegState(Src1IsKill)) - .addReg(SrcReg2, getKillRegState(Src2IsKill)); - // Insert the MADD + MachineInstrBuilder MIB; + if (kind == FMAInstKind::Default) + MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc), ResultReg) + .addReg(SrcReg0, getKillRegState(Src0IsKill)) + .addReg(SrcReg1, getKillRegState(Src1IsKill)) + .addReg(SrcReg2, getKillRegState(Src2IsKill)); + else if (kind == FMAInstKind::Indexed) + MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc), ResultReg) + .addReg(SrcReg2, getKillRegState(Src2IsKill)) + .addReg(SrcReg0, getKillRegState(Src0IsKill)) + .addReg(SrcReg1, getKillRegState(Src1IsKill)) + .addImm(MUL->getOperand(3).getImm()); + else if (kind == FMAInstKind::Accumulator) + MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MaddOpc), ResultReg) + .addReg(SrcReg2, getKillRegState(Src2IsKill)) + .addReg(SrcReg0, getKillRegState(Src0IsKill)) + .addReg(SrcReg1, getKillRegState(Src1IsKill)); + else + assert(false && "Invalid FMA instruction kind \n"); + // Insert the MADD (MADD, FMA, FMS, FMLA, FMSL) InsInstrs.push_back(MIB); return MUL; } @@ -3096,7 +3378,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence( Opc = AArch64::MADDXrrr; RC = &AArch64::GPR64RegClass; } - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); break; case MachineCombinerPattern::MULADDW_OP2: case MachineCombinerPattern::MULADDX_OP2: @@ -3111,7 +3393,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence( Opc = AArch64::MADDXrrr; RC = &AArch64::GPR64RegClass; } - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); break; case MachineCombinerPattern::MULADDWI_OP1: case MachineCombinerPattern::MULADDXI_OP1: { @@ -3203,7 +3485,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence( Opc = AArch64::MSUBXrrr; RC = &AArch64::GPR64RegClass; } - MUL = genMadd(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); break; case MachineCombinerPattern::MULSUBWI_OP1: case MachineCombinerPattern::MULSUBXI_OP1: { @@ -3248,6 +3530,234 @@ void AArch64InstrInfo::genAlternativeCodeSequence( } break; } + // Floating Point Support + case MachineCombinerPattern::FMULADDS_OP1: + case MachineCombinerPattern::FMULADDD_OP1: + // MUL I=A,B,0 + // ADD R,I,C + // ==> MADD R,A,B,C + // --- Create(MADD); + if (Pattern == MachineCombinerPattern::FMULADDS_OP1) { + Opc = AArch64::FMADDSrrr; + RC = &AArch64::FPR32RegClass; + } else { + Opc = AArch64::FMADDDrrr; + RC = &AArch64::FPR64RegClass; + } + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::FMULADDS_OP2: + case MachineCombinerPattern::FMULADDD_OP2: + // FMUL I=A,B,0 + // FADD R,C,I + // ==> FMADD R,A,B,C + // --- Create(FMADD); + if (Pattern == MachineCombinerPattern::FMULADDS_OP2) { + Opc = AArch64::FMADDSrrr; + RC = &AArch64::FPR32RegClass; + } else { + Opc = AArch64::FMADDDrrr; + RC = &AArch64::FPR64RegClass; + } + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + + case MachineCombinerPattern::FMLAv1i32_indexed_OP1: + Opc = AArch64::FMLAv1i32_indexed; + RC = &AArch64::FPR32RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed); + break; + case MachineCombinerPattern::FMLAv1i32_indexed_OP2: + Opc = AArch64::FMLAv1i32_indexed; + RC = &AArch64::FPR32RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + break; + + case MachineCombinerPattern::FMLAv1i64_indexed_OP1: + Opc = AArch64::FMLAv1i64_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed); + break; + case MachineCombinerPattern::FMLAv1i64_indexed_OP2: + Opc = AArch64::FMLAv1i64_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + break; + + case MachineCombinerPattern::FMLAv2i32_indexed_OP1: + case MachineCombinerPattern::FMLAv2f32_OP1: + RC = &AArch64::FPR64RegClass; + if (Pattern == MachineCombinerPattern::FMLAv2i32_indexed_OP1) { + Opc = AArch64::FMLAv2i32_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLAv2f32; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Accumulator); + } + break; + case MachineCombinerPattern::FMLAv2i32_indexed_OP2: + case MachineCombinerPattern::FMLAv2f32_OP2: + RC = &AArch64::FPR64RegClass; + if (Pattern == MachineCombinerPattern::FMLAv2i32_indexed_OP2) { + Opc = AArch64::FMLAv2i32_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLAv2f32; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + } + break; + + case MachineCombinerPattern::FMLAv2i64_indexed_OP1: + case MachineCombinerPattern::FMLAv2f64_OP1: + RC = &AArch64::FPR128RegClass; + if (Pattern == MachineCombinerPattern::FMLAv2i64_indexed_OP1) { + Opc = AArch64::FMLAv2i64_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLAv2f64; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Accumulator); + } + break; + case MachineCombinerPattern::FMLAv2i64_indexed_OP2: + case MachineCombinerPattern::FMLAv2f64_OP2: + RC = &AArch64::FPR128RegClass; + if (Pattern == MachineCombinerPattern::FMLAv2i64_indexed_OP2) { + Opc = AArch64::FMLAv2i64_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLAv2f64; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + } + break; + + case MachineCombinerPattern::FMLAv4i32_indexed_OP1: + case MachineCombinerPattern::FMLAv4f32_OP1: + RC = &AArch64::FPR128RegClass; + if (Pattern == MachineCombinerPattern::FMLAv4i32_indexed_OP1) { + Opc = AArch64::FMLAv4i32_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLAv4f32; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Accumulator); + } + break; + + case MachineCombinerPattern::FMLAv4i32_indexed_OP2: + case MachineCombinerPattern::FMLAv4f32_OP2: + RC = &AArch64::FPR128RegClass; + if (Pattern == MachineCombinerPattern::FMLAv4i32_indexed_OP2) { + Opc = AArch64::FMLAv4i32_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLAv4f32; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + } + break; + + case MachineCombinerPattern::FMULSUBS_OP1: + case MachineCombinerPattern::FMULSUBD_OP1: { + // FMUL I=A,B,0 + // FSUB R,I,C + // ==> FNMSUB R,A,B,C // = -C + A*B + // --- Create(FNMSUB); + if (Pattern == MachineCombinerPattern::FMULSUBS_OP1) { + Opc = AArch64::FNMSUBSrrr; + RC = &AArch64::FPR32RegClass; + } else { + Opc = AArch64::FNMSUBDrrr; + RC = &AArch64::FPR64RegClass; + } + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + } + case MachineCombinerPattern::FMULSUBS_OP2: + case MachineCombinerPattern::FMULSUBD_OP2: { + // FMUL I=A,B,0 + // FSUB R,C,I + // ==> FMSUB R,A,B,C (computes C - A*B) + // --- Create(FMSUB); + if (Pattern == MachineCombinerPattern::FMULSUBS_OP2) { + Opc = AArch64::FMSUBSrrr; + RC = &AArch64::FPR32RegClass; + } else { + Opc = AArch64::FMSUBDrrr; + RC = &AArch64::FPR64RegClass; + } + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + + case MachineCombinerPattern::FMLSv1i32_indexed_OP2: + Opc = AArch64::FMLSv1i32_indexed; + RC = &AArch64::FPR32RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + break; + + case MachineCombinerPattern::FMLSv1i64_indexed_OP2: + Opc = AArch64::FMLSv1i64_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + break; + + case MachineCombinerPattern::FMLSv2f32_OP2: + case MachineCombinerPattern::FMLSv2i32_indexed_OP2: + RC = &AArch64::FPR64RegClass; + if (Pattern == MachineCombinerPattern::FMLSv2i32_indexed_OP2) { + Opc = AArch64::FMLSv2i32_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLSv2f32; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + } + break; + + case MachineCombinerPattern::FMLSv2f64_OP2: + case MachineCombinerPattern::FMLSv2i64_indexed_OP2: + RC = &AArch64::FPR128RegClass; + if (Pattern == MachineCombinerPattern::FMLSv2i64_indexed_OP2) { + Opc = AArch64::FMLSv2i64_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLSv2f64; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + } + break; + + case MachineCombinerPattern::FMLSv4f32_OP2: + case MachineCombinerPattern::FMLSv4i32_indexed_OP2: + RC = &AArch64::FPR128RegClass; + if (Pattern == MachineCombinerPattern::FMLSv4i32_indexed_OP2) { + Opc = AArch64::FMLSv4i32_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + } else { + Opc = AArch64::FMLSv4f32; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + } + break; + } } // end switch (Pattern) // Record MUL and ADD/SUB for deletion DelInstrs.push_back(MUL); diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h index a592f91dd4e..353ef735dac 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h @@ -174,6 +174,11 @@ public: unsigned SrcReg2, int CmpMask, int CmpValue, const MachineRegisterInfo *MRI) const override; bool optimizeCondBranch(MachineInstr *MI) const override; + + /// Return true when a code sequence can improve throughput. It + /// should be called only for instructions in loops. + /// \param Pattern - combiner pattern + bool isThroughputPattern(MachineCombinerPattern Pattern) const override; /// Return true when there is potentially a faster code sequence /// for an instruction chain ending in <Root>. All potential patterns are /// listed in the <Patterns> array. diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp index f40293021d7..4e4aaf8e553 100644 --- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp @@ -51,3 +51,9 @@ SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset( } return SDValue(); } +bool AArch64SelectionDAGInfo::GenerateFMAsInMachineCombiner( + CodeGenOpt::Level OptLevel) const { + if (OptLevel >= CodeGenOpt::Aggressive) + return true; + return false; +} diff --git a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h index 8adb030555a..e61f177f2ef 100644 --- a/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h +++ b/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.h @@ -25,6 +25,7 @@ public: SDValue Dst, SDValue Src, SDValue Size, unsigned Align, bool isVolatile, MachinePointerInfo DstPtrInfo) const override; + bool GenerateFMAsInMachineCombiner(CodeGenOpt::Level OptLevel) const override; }; } |