diff options
author | Sam Parker <sam.parker@arm.com> | 2019-11-18 17:07:56 +0000 |
---|---|---|
committer | Sam Parker <sam.parker@arm.com> | 2019-11-19 08:22:18 +0000 |
commit | 8978c12b39f90194bb35860729ddca5e819f3b92 (patch) | |
tree | dd60953f653866f6508e4a083c4758d0de834b51 /llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp | |
parent | d593292f0465c9db1f2c3cdf719009bfdf942a5c (diff) | |
download | bcm5719-llvm-8978c12b39f90194bb35860729ddca5e819f3b92.tar.gz bcm5719-llvm-8978c12b39f90194bb35860729ddca5e819f3b92.zip |
[ARM][MVE] Tail predication conversion
This patch modifies ARMLowOverheadLoops to convert a predicated
vector low-overhead loop into a tail-predicated one. This is currently
a very basic conversion, with the following restrictions:
- Operates only on single block loops.
- The loop can only contain a single vctp instruction.
- No other instructions can write to the vpr.
- We only allow a subset of the mve instructions in the loop.
TODO: Pass the number of elements, not the number of iterations to
dlstp/wlstp.
Differential Revision: https://reviews.llvm.org/D69945
Diffstat (limited to 'llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp')
-rw-r--r-- | llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp | 428 |
1 files changed, 294 insertions, 134 deletions
diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp index e1c5a9c3e22..733a3f16606 100644 --- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -25,6 +25,7 @@ #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineLoopInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/MC/MCInstrDesc.h" using namespace llvm; @@ -33,6 +34,97 @@ using namespace llvm; namespace { + struct LowOverheadLoop { + + MachineLoop *ML = nullptr; + MachineFunction *MF = nullptr; + MachineInstr *InsertPt = nullptr; + MachineInstr *Start = nullptr; + MachineInstr *Dec = nullptr; + MachineInstr *End = nullptr; + MachineInstr *VCTP = nullptr; + SmallVector<MachineInstr*, 4> VPTUsers; + bool Revert = false; + bool FoundOneVCTP = false; + bool CannotTailPredicate = false; + + LowOverheadLoop(MachineLoop *ML) : ML(ML) { + MF = ML->getHeader()->getParent(); + } + + // For now, only support one vctp instruction. If we find multiple then + // we shouldn't perform tail predication. + void addVCTP(MachineInstr *MI) { + if (!VCTP) { + VCTP = MI; + FoundOneVCTP = true; + } else + FoundOneVCTP = false; + } + + // Check that nothing else is writing to VPR and record any insts + // reading the VPR. + void ScanForVPR(MachineInstr *MI) { + for (auto &MO : MI->operands()) { + if (!MO.isReg() || MO.getReg() != ARM::VPR) + continue; + if (MO.isUse()) + VPTUsers.push_back(MI); + if (MO.isDef()) { + CannotTailPredicate = true; + break; + } + } + } + + // If this is an MVE instruction, check that we know how to use tail + // predication with it. 
+ void CheckTPValidity(MachineInstr *MI) { + if (CannotTailPredicate) + return; + + const MCInstrDesc &MCID = MI->getDesc(); + uint64_t Flags = MCID.TSFlags; + if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) + return; + + if ((Flags & ARMII::ValidForTailPredication) == 0) { + LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI); + CannotTailPredicate = true; + } + } + + bool IsTailPredicationLegal() const { + // For now, let's keep things really simple and only support a single + // block for tail predication. + return !Revert && FoundAllComponents() && FoundOneVCTP && + !CannotTailPredicate && ML->getNumBlocks() == 1; + } + + // Is it safe to define LR with DLS/WLS? + // LR can be defined if it is the operand to start, because it's the same + // value, or if it's going to be equivalent to the operand to Start. + MachineInstr *IsSafeToDefineLR(); + + // Check the branch targets are within range and we satisfy our restructi + void CheckLegality(ARMBasicBlockUtils *BBUtils); + + bool FoundAllComponents() const { + return Start && Dec && End; + } + + void dump() const { + if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start; + if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec; + if (End) dbgs() << "ARM Loops: Found Loop End: " << *End; + if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP; + if (!FoundAllComponents()) + dbgs() << "ARM Loops: Not a low-overhead loop.\n"; + else if (!(Start && Dec && End)) + dbgs() << "ARM Loops: Failed to find all loop components.\n"; + } + }; + class ARMLowOverheadLoops : public MachineFunctionPass { MachineFunction *MF = nullptr; const ARMBaseInstrInfo *TII = nullptr; @@ -64,8 +156,6 @@ namespace { private: bool ProcessLoop(MachineLoop *ML); - MachineInstr * IsSafeToDefineLR(MachineInstr *MI); - bool RevertNonLoops(); void RevertWhile(MachineInstr *MI) const; @@ -74,9 +164,11 @@ namespace { void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const; - void Expand(MachineLoop *ML, 
MachineInstr *Start, - MachineInstr *InsertPt, MachineInstr *Dec, - MachineInstr *End, bool Revert); + void RemoveVPTBlocks(LowOverheadLoop &LoLoop); + + MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop); + + void Expand(LowOverheadLoop &LoLoop); }; } @@ -86,31 +178,6 @@ char ARMLowOverheadLoops::ID = 0; INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME, false, false) -bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) { - const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget()); - if (!ST.hasLOB()) - return false; - - MF = &mf; - LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n"); - - auto &MLI = getAnalysis<MachineLoopInfo>(); - MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness); - MRI = &MF->getRegInfo(); - TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo()); - BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF)); - BBUtils->computeAllBlockSizes(); - BBUtils->adjustBBOffsetsAfter(&MF->front()); - - bool Changed = false; - for (auto ML : MLI) { - if (!ML->getParentLoop()) - Changed |= ProcessLoop(ML); - } - Changed |= RevertNonLoops(); - return Changed; -} - static bool IsLoopStart(MachineInstr &MI) { return MI.getOpcode() == ARM::t2DoLoopStart || MI.getOpcode() == ARM::t2WhileLoopStart; @@ -141,10 +208,20 @@ static MachineInstr* SearchForUse(MachineInstr *Begin, return nullptr; } -// Is it safe to define LR with DLS/WLS? -// LR can defined if it is the operand to start, because it's the same value, -// or if it's going to be equivalent to the operand to Start. 
-MachineInstr *ARMLowOverheadLoops::IsSafeToDefineLR(MachineInstr *Start) { +static bool IsVCTP(MachineInstr *MI) { + switch (MI->getOpcode()) { + default: + break; + case ARM::MVE_VCTP8: + case ARM::MVE_VCTP16: + case ARM::MVE_VCTP32: + case ARM::MVE_VCTP64: + return true; + } + return false; +} + +MachineInstr *LowOverheadLoop::IsSafeToDefineLR() { auto IsMoveLR = [](MachineInstr *MI, unsigned Reg) { return MI->getOpcode() == ARM::tMOVr && @@ -210,6 +287,78 @@ MachineInstr *ARMLowOverheadLoops::IsSafeToDefineLR(MachineInstr *Start) { return nullptr; } +void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) { + if (Revert) + return; + + if (!End->getOperand(1).isMBB()) + report_fatal_error("Expected LoopEnd to target basic block"); + + // TODO Maybe there's cases where the target doesn't have to be the header, + // but for now be safe and revert. + if (End->getOperand(1).getMBB() != ML->getHeader()) { + LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n"); + Revert = true; + return; + } + + // The WLS and LE instructions have 12-bits for the label offset. WLS + // requires a positive offset, while LE uses negative. + if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) || + !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) { + LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n"); + Revert = true; + return; + } + + if (Start->getOpcode() == ARM::t2WhileLoopStart && + (BBUtils->getOffsetOf(Start) > + BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) || + !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) { + LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n"); + Revert = true; + return; + } + + InsertPt = Revert ? 
nullptr : IsSafeToDefineLR(); + if (!InsertPt) { + LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n"); + Revert = true; + } else + LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt); + + LLVM_DEBUG(if (IsTailPredicationLegal()) { + dbgs() << "ARM Loops: Will use tail predication to convert:\n"; + for (auto *MI : VPTUsers) + dbgs() << " - " << *MI; + }); +} + +bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) { + const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget()); + if (!ST.hasLOB()) + return false; + + MF = &mf; + LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n"); + + auto &MLI = getAnalysis<MachineLoopInfo>(); + MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness); + MRI = &MF->getRegInfo(); + TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo()); + BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF)); + BBUtils->computeAllBlockSizes(); + BBUtils->adjustBBOffsetsAfter(&MF->front()); + + bool Changed = false; + for (auto ML : MLI) { + if (!ML->getParentLoop()) + Changed |= ProcessLoop(ML); + } + Changed |= RevertNonLoops(); + return Changed; +} + bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { bool Changed = false; @@ -233,18 +382,14 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { return nullptr; }; - MachineInstr *Start = nullptr; - MachineInstr *Dec = nullptr; - MachineInstr *End = nullptr; - bool Revert = false; - + LowOverheadLoop LoLoop(ML); // Search the preheader for the start intrinsic, or look through the // predecessors of the header to find exactly one set.iterations intrinsic. // FIXME: I don't see why we shouldn't be supporting multiple predecessors // with potentially multiple set.loop.iterations, so we need to enable this. 
- if (auto *Preheader = ML->getLoopPreheader()) { - Start = SearchForStart(Preheader); - } else { + if (auto *Preheader = ML->getLoopPreheader()) + LoLoop.Start = SearchForStart(Preheader); + else { LLVM_DEBUG(dbgs() << "ARM Loops: Failed to find loop preheader!\n" << " - Performing manual predecessor search.\n"); MachineBasicBlock *Pred = nullptr; @@ -252,34 +397,46 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { if (!ML->contains(MBB)) { if (Pred) { LLVM_DEBUG(dbgs() << " - Found multiple out-of-loop preds.\n"); - Start = nullptr; + LoLoop.Start = nullptr; break; } Pred = MBB; - Start = SearchForStart(MBB); + LoLoop.Start = SearchForStart(MBB); } } } // Find the low-overhead loop components and decide whether or not to fall - // back to a normal loop. + // back to a normal loop. Also look for a vctp instructions and decide + // whether we can convert that predicate using tail predication. for (auto *MBB : reverse(ML->getBlocks())) { for (auto &MI : *MBB) { if (MI.getOpcode() == ARM::t2LoopDec) - Dec = &MI; + LoLoop.Dec = &MI; else if (MI.getOpcode() == ARM::t2LoopEnd) - End = &MI; + LoLoop.End = &MI; else if (IsLoopStart(MI)) - Start = &MI; + LoLoop.Start = &MI; + else if (IsVCTP(&MI)) + LoLoop.addVCTP(&MI); else if (MI.getDesc().isCall()) { // TODO: Though the call will require LE to execute again, does this // mean we should revert? Always executing LE hopefully should be // faster than performing a sub,cmp,br or even subs,br. - Revert = true; + LoLoop.Revert = true; LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n"); + } else { + // Once we've found a vctp, record the users of vpr and check there's + // no more vpr defs. + if (LoLoop.FoundOneVCTP) + LoLoop.ScanForVPR(&MI); + // Check we know how to tail predicate any mve instructions. + LoLoop.CheckTPValidity(&MI); } - if (!Dec || End) + // We need to ensure that LR is not used or defined inbetween LoopDec and + // LoopEnd. 
+ if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert) continue; // If we find that LR has been written or read between LoopDec and @@ -294,61 +451,19 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() && MO.getReg() == ARM::LR) { LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI); - Revert = true; + LoLoop.Revert = true; break; } } } - - if (Dec && End && Revert) - break; } - LLVM_DEBUG(if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start; - if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec; - if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;); - - if (!Start && !Dec && !End) { - LLVM_DEBUG(dbgs() << "ARM Loops: Not a low-overhead loop.\n"); - return Changed; - } else if (!(Start && Dec && End)) { - LLVM_DEBUG(dbgs() << "ARM Loops: Failed to find all loop components.\n"); + LLVM_DEBUG(LoLoop.dump()); + if (!LoLoop.FoundAllComponents()) return false; - } - - if (!End->getOperand(1).isMBB()) - report_fatal_error("Expected LoopEnd to target basic block"); - - // TODO Maybe there's cases where the target doesn't have to be the header, - // but for now be safe and revert. - if (End->getOperand(1).getMBB() != ML->getHeader()) { - LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n"); - Revert = true; - } - - // The WLS and LE instructions have 12-bits for the label offset. WLS - // requires a positive offset, while LE uses negative. 
- if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) || - !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) { - LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n"); - Revert = true; - } - if (Start->getOpcode() == ARM::t2WhileLoopStart && - (BBUtils->getOffsetOf(Start) > - BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) || - !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) { - LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n"); - Revert = true; - } - - MachineInstr *InsertPt = Revert ? nullptr : IsSafeToDefineLR(Start); - if (!InsertPt) { - LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n"); - Revert = true; - } else - LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt); - Expand(ML, Start, InsertPt, Dec, End, Revert); + LoLoop.CheckLegality(BBUtils.get()); + Expand(LoLoop); return true; } @@ -438,44 +553,87 @@ void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const { MI->eraseFromParent(); } -void ARMLowOverheadLoops::Expand(MachineLoop *ML, MachineInstr *Start, - MachineInstr *InsertPt, - MachineInstr *Dec, MachineInstr *End, - bool Revert) { +MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) { + MachineInstr *InsertPt = LoLoop.InsertPt; + MachineInstr *Start = LoLoop.Start; + MachineBasicBlock *MBB = InsertPt->getParent(); + bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; + unsigned Opc = 0; + + if (!LoLoop.IsTailPredicationLegal()) + Opc = IsDo ? ARM::t2DLS : ARM::t2WLS; + else { + switch (LoLoop.VCTP->getOpcode()) { + case ARM::MVE_VCTP8: + Opc = IsDo ? ARM::MVE_DLSTP_8 : ARM::MVE_WLSTP_8; + break; + case ARM::MVE_VCTP16: + Opc = IsDo ? ARM::MVE_DLSTP_16 : ARM::MVE_WLSTP_16; + break; + case ARM::MVE_VCTP32: + Opc = IsDo ? ARM::MVE_DLSTP_32 : ARM::MVE_WLSTP_32; + break; + case ARM::MVE_VCTP64: + Opc = IsDo ? 
ARM::MVE_DLSTP_64 : ARM::MVE_WLSTP_64; + break; + } + } - auto ExpandLoopStart = [this](MachineLoop *ML, MachineInstr *Start, - MachineInstr *InsertPt) { - MachineBasicBlock *MBB = InsertPt->getParent(); - unsigned Opc = Start->getOpcode() == ARM::t2DoLoopStart ? - ARM::t2DLS : ARM::t2WLS; - MachineInstrBuilder MIB = - BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc)); + MachineInstrBuilder MIB = + BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc)); - MIB.addDef(ARM::LR); - MIB.add(Start->getOperand(0)); - if (Opc == ARM::t2WLS) - MIB.add(Start->getOperand(1)); - - if (InsertPt != Start) - InsertPt->eraseFromParent(); - Start->eraseFromParent(); - LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB); - return &*MIB; - }; + MIB.addDef(ARM::LR); + MIB.add(Start->getOperand(0)); + if (!IsDo) + MIB.add(Start->getOperand(1)); + + if (InsertPt != Start) + InsertPt->eraseFromParent(); + Start->eraseFromParent(); + LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB); + return &*MIB; +} + +void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) { + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP); + LoLoop.VCTP->eraseFromParent(); + + for (auto *MI : LoLoop.VPTUsers) { + if (MI->getOpcode() == ARM::MVE_VPST) { + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI); + MI->eraseFromParent(); + } else { + unsigned OpNum = MI->getNumOperands() - 1; + assert((MI->getOperand(OpNum).isReg() && + MI->getOperand(OpNum).getReg() == ARM::VPR) && + "Expected VPR"); + assert((MI->getOperand(OpNum-1).isImm() && + MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) && + "Expected Then predicate"); + MI->getOperand(OpNum-1).setImm(ARMVCC::None); + MI->getOperand(OpNum).setReg(0); + LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI); + } + } +} + +void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) { // Combine the LoopDec and LoopEnd instructions into LE(TP). 
- auto ExpandLoopEnd = [this](MachineLoop *ML, MachineInstr *Dec, - MachineInstr *End) { + auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) { + MachineInstr *End = LoLoop.End; MachineBasicBlock *MBB = End->getParent(); + unsigned Opc = LoLoop.IsTailPredicationLegal() ? + ARM::MVE_LETP : ARM::t2LEUpdate; MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(), - TII->get(ARM::t2LEUpdate)); + TII->get(Opc)); MIB.addDef(ARM::LR); MIB.add(End->getOperand(0)); MIB.add(End->getOperand(1)); LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB); - End->eraseFromParent(); - Dec->eraseFromParent(); + LoLoop.End->eraseFromParent(); + LoLoop.Dec->eraseFromParent(); return &*MIB; }; @@ -496,18 +654,20 @@ void ARMLowOverheadLoops::Expand(MachineLoop *ML, MachineInstr *Start, } }; - if (Revert) { - if (Start->getOpcode() == ARM::t2WhileLoopStart) - RevertWhile(Start); + if (LoLoop.Revert) { + if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart) + RevertWhile(LoLoop.Start); else - Start->eraseFromParent(); - bool FlagsAlreadySet = RevertLoopDec(Dec, true); - RevertLoopEnd(End, FlagsAlreadySet); + LoLoop.Start->eraseFromParent(); + bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true); + RevertLoopEnd(LoLoop.End, FlagsAlreadySet); } else { - Start = ExpandLoopStart(ML, Start, InsertPt); - RemoveDeadBranch(Start); - End = ExpandLoopEnd(ML, Dec, End); - RemoveDeadBranch(End); + LoLoop.Start = ExpandLoopStart(LoLoop); + RemoveDeadBranch(LoLoop.Start); + LoLoop.End = ExpandLoopEnd(LoLoop); + RemoveDeadBranch(LoLoop.End); + if (LoLoop.IsTailPredicationLegal()) + RemoveVPTBlocks(LoLoop); } } |