diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp | 294 | ||||
| -rw-r--r-- | llvm/lib/Target/ARM/MVEVPTBlockPass.cpp | 18 | ||||
| -rw-r--r-- | llvm/lib/Target/ARM/Utils/ARMBaseInfo.h | 34 |
3 files changed, 270 insertions, 76 deletions
diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp index e3a05408343..ec62a6975f0 100644 --- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -22,6 +22,7 @@ #include "ARMBaseRegisterInfo.h" #include "ARMBasicBlockInfo.h" #include "ARMSubtarget.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineLoopInfo.h" #include "llvm/CodeGen/MachineLoopUtils.h" @@ -37,6 +38,65 @@ using namespace llvm; namespace { + struct PredicatedMI { + MachineInstr *MI = nullptr; + SetVector<MachineInstr*> Predicates; + + public: + PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) : + MI(I) { + Predicates.insert(Preds.begin(), Preds.end()); + } + }; + + // Represent a VPT block, a list of instructions that begins with a VPST and + // has a maximum of four following instructions. All instructions within the + // block are predicated upon the vpr and we allow instructions to define the + // vpr within the block too. + class VPTBlock { + std::unique_ptr<PredicatedMI> VPST; + PredicatedMI *Divergent = nullptr; + SmallVector<PredicatedMI, 4> Insts; + + public: + VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { + VPST = std::make_unique<PredicatedMI>(MI, Preds); + } + + void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { + LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI); + if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) { + Divergent = &Insts.back(); + LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI); + } + Insts.emplace_back(MI, Preds); + assert(Insts.size() <= 4 && "Too many instructions in VPT block!"); + } + + // Have we found an instruction within the block which defines the vpr? If + // so, not all the instructions in the block will have the same predicate. 
+ bool HasNonUniformPredicate() const { + return Divergent != nullptr; + } + + // Is the given instruction part of the predicate set controlling the entry + // to the block. + bool IsPredicatedOn(MachineInstr *MI) const { + return VPST->Predicates.count(MI); + } + + // Is the given instruction the only predicate which controls the entry to + // the block. + bool IsOnlyPredicatedOn(MachineInstr *MI) const { + return IsPredicatedOn(MI) && VPST->Predicates.size() == 1; + } + + unsigned size() const { return Insts.size(); } + SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; } + MachineInstr *getVPST() const { return VPST->MI; } + PredicatedMI *getDivergent() const { return Divergent; } + }; + struct LowOverheadLoop { MachineLoop *ML = nullptr; @@ -46,39 +106,17 @@ namespace { MachineInstr *Dec = nullptr; MachineInstr *End = nullptr; MachineInstr *VCTP = nullptr; - SmallVector<MachineInstr*, 4> VPTUsers; + VPTBlock *CurrentBlock = nullptr; + SetVector<MachineInstr*> CurrentPredicate; + SmallVector<VPTBlock, 4> VPTBlocks; bool Revert = false; - bool FoundOneVCTP = false; bool CannotTailPredicate = false; LowOverheadLoop(MachineLoop *ML) : ML(ML) { MF = ML->getHeader()->getParent(); } - // For now, only support one vctp instruction. If we find multiple then - // we shouldn't perform tail predication. - void addVCTP(MachineInstr *MI) { - if (!VCTP) { - VCTP = MI; - FoundOneVCTP = true; - } else - FoundOneVCTP = false; - } - - // Check that nothing else is writing to VPR and record any insts - // reading the VPR. - void ScanForVPR(MachineInstr *MI) { - for (auto &MO : MI->operands()) { - if (!MO.isReg() || MO.getReg() != ARM::VPR) - continue; - if (MO.isUse()) - VPTUsers.push_back(MI); - if (MO.isDef()) { - CannotTailPredicate = true; - break; - } - } - } + bool RecordVPTBlocks(MachineInstr *MI); // If this is an MVE instruction, check that we know how to use tail // predication with it. 
@@ -86,6 +124,11 @@ namespace { if (CannotTailPredicate) return; + if (!RecordVPTBlocks(MI)) { + CannotTailPredicate = true; + return; + } + const MCInstrDesc &MCID = MI->getDesc(); uint64_t Flags = MCID.TSFlags; if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) @@ -100,7 +143,7 @@ namespace { bool IsTailPredicationLegal() const { // For now, let's keep things really simple and only support a single // block for tail predication. - return !Revert && FoundAllComponents() && FoundOneVCTP && + return !Revert && FoundAllComponents() && VCTP && !CannotTailPredicate && ML->getNumBlocks() == 1; } @@ -118,6 +161,8 @@ namespace { return Start && Dec && End; } + SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; } + // Return the loop iteration count, or the number of elements if we're tail // predicating. MachineOperand &getCount() { @@ -191,7 +236,7 @@ namespace { void RemoveLoopUpdate(LowOverheadLoop &LoLoop); - void RemoveVPTBlocks(LowOverheadLoop &LoLoop); + void ConvertVPTBlocks(LowOverheadLoop &LoLoop); MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop); @@ -281,13 +326,39 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, } else LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt); + if (!IsTailPredicationLegal()) { + LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n"); + return; + } + + // All predication within the loop should be based on vctp. If the block + // isn't predicated on entry, check whether the vctp is within the block + // and that all other instructions are then predicated on it. 
+ for (auto &Block : VPTBlocks) { + if (Block.IsPredicatedOn(VCTP)) + continue; + if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) { + CannotTailPredicate = true; + return; + } + SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); + for (auto &PredMI : Insts) { + if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI)) + continue; + LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI + << " - which is predicated on:\n"; + for (auto *MI : PredMI.Predicates) + dbgs() << " - " << *MI; + ); + CannotTailPredicate = true; + return; + } + } + // For tail predication, we need to provide the number of elements, instead // of the iteration count, to the loop start instruction. The number of // elements is provided to the vctp instruction, so we need to check that // we can use this register at InsertPt. - if (!IsTailPredicationLegal()) - return; - Register NumElements = VCTP->getOperand(1).getReg(); // If the register is defined within loop, then we can't perform TP. @@ -338,9 +409,65 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, MBB = *MBB->pred_begin(); } - LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication to convert:\n"; - for (auto *MI : VPTUsers) - dbgs() << " - " << *MI;); + LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n"); +} + +bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) { + // Only support a single vctp. + if (isVCTP(MI) && VCTP) + return false; + + // Start a new vpt block when we discover a vpt. + if (MI->getOpcode() == ARM::MVE_VPST) { + VPTBlocks.emplace_back(MI, CurrentPredicate); + CurrentBlock = &VPTBlocks.back(); + return true; + } + + if (isVCTP(MI)) + VCTP = MI; + + unsigned VPROpNum = MI->getNumOperands() - 1; + bool IsUse = false; + if (MI->getOperand(VPROpNum).isReg() && + MI->getOperand(VPROpNum).getReg() == ARM::VPR && + MI->getOperand(VPROpNum).isUse()) { + // If this instruction is predicated by VPR, it will be its last + // operand. 
Also check that it's only 'Then' predicated. + if (!MI->getOperand(VPROpNum-1).isImm() || + MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) { + LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: " + << *MI); + return false; + } + CurrentBlock->addInst(MI, CurrentPredicate); + IsUse = true; + } + + bool IsDef = false; + for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) { + const MachineOperand &MO = MI->getOperand(i); + if (!MO.isReg() || MO.getReg() != ARM::VPR) + continue; + + if (MO.isDef()) { + CurrentPredicate.insert(MI); + IsDef = true; + } else { + LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI); + return false; + } + } + + // If we find a vpr def that is not already predicated on the vctp, we've + // got disjoint predicates that may not be equivalent when we do the + // conversion. + if (IsDef && !IsUse && VCTP && !isVCTP(MI)) { + LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI); + return false; + } + + return true; } bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) { @@ -422,8 +549,6 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { LoLoop.End = &MI; else if (isLoopStart(MI)) LoLoop.Start = &MI; - else if (isVCTP(&MI)) - LoLoop.addVCTP(&MI); else if (MI.getDesc().isCall()) { // TODO: Though the call will require LE to execute again, does this // mean we should revert? Always executing LE hopefully should be @@ -431,10 +556,7 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { LoLoop.Revert = true; LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n"); } else { - // Once we've found a vctp, record the users of vpr and check there's - // no more vpr defs. - if (LoLoop.FoundOneVCTP) - LoLoop.ScanForVPR(&MI); + // Record VPR defs and build up their corresponding vpt blocks. // Check we know how to tail predicate any mve instructions. 
LoLoop.CheckTPValidity(&MI); } @@ -669,27 +791,81 @@ void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) { } } -void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) { - LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP); - LoLoop.VCTP->eraseFromParent(); +void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { + auto RemovePredicate = [](MachineInstr *MI) { + LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI); + unsigned OpNum = MI->getNumOperands() - 1; + assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then && + "Expected Then predicate!"); + MI->getOperand(OpNum-1).setImm(ARMVCC::None); + MI->getOperand(OpNum).setReg(0); + }; - for (auto *MI : LoLoop.VPTUsers) { - if (MI->getOpcode() == ARM::MVE_VPST) { - LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI); - MI->eraseFromParent(); - } else { - unsigned OpNum = MI->getNumOperands() - 1; - assert((MI->getOperand(OpNum).isReg() && - MI->getOperand(OpNum).getReg() == ARM::VPR) && - "Expected VPR"); - assert((MI->getOperand(OpNum-1).isImm() && - MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) && - "Expected Then predicate"); - MI->getOperand(OpNum-1).setImm(ARMVCC::None); - MI->getOperand(OpNum).setReg(0); - LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI); + // There are a few scenarios which we have to fix up: + // 1) A VPT block which is only predicated by the vctp and has no internal vpr + // defs. + // 2) A VPT block which is only predicated by the vctp but has an internal + // vpr def. + // 3) A VPT block which is predicated upon the vctp as well as another vpr + // def. + // 4) A VPT block which is not predicated upon a vctp, but contains it and + // all instructions within the block are predicated upon it. 
+ + for (auto &Block : LoLoop.getVPTBlocks()) { + SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); + if (Block.HasNonUniformPredicate()) { + PredicatedMI *Divergent = Block.getDivergent(); + if (isVCTP(Divergent->MI)) { + // The vctp will be removed, so the size of the vpt block needs to be + // modified. + uint64_t Size = getARMVPTBlockMask(Block.size() - 1); + Block.getVPST()->getOperand(0).setImm(Size); + LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n"); + } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { + // The VPT block has a non-uniform predicate but its entry is guarded + // only by a vctp, which means we: + // - Need to remove the original vpst. + // - Then need to unpredicate any following instructions, until + // we come across the divergent vpr def. + // - Insert a new vpst to predicate the instruction(s) that follow + // the divergent vpr def. + // TODO: We could be producing more VPT blocks than necessary and could + // fold the newly created one into a proceeding one. + for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()), + E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I) + RemovePredicate(&*I); + + unsigned Size = 0; + auto E = MachineBasicBlock::reverse_iterator(Divergent->MI); + auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI); + MachineInstr *InsertAt = nullptr; + while (I != E) { + InsertAt = &*I; + ++Size; + ++I; + } + MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt, + InsertAt->getDebugLoc(), + TII->get(ARM::MVE_VPST)); + MIB.addImm(getARMVPTBlockMask(Size)); + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); + LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB); + Block.getVPST()->eraseFromParent(); + } + } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { + // A vpt block which is only predicated upon vctp and has no internal vpr + // defs: + // - Remove vpst. + // - Unpredicate the remaining instructions. 
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); + Block.getVPST()->eraseFromParent(); + for (auto &PredMI : Insts) + RemovePredicate(PredMI.MI); } } + + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP); + LoLoop.VCTP->eraseFromParent(); } void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) { @@ -743,7 +919,7 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) { RemoveDeadBranch(LoLoop.End); if (LoLoop.IsTailPredicationLegal()) { RemoveLoopUpdate(LoLoop); - RemoveVPTBlocks(LoLoop); + ConvertVPTBlocks(LoLoop); } } } diff --git a/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp b/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp index 39d90d0b6db..c8b725f339e 100644 --- a/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp +++ b/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp @@ -137,23 +137,7 @@ bool MVEVPTBlock::InsertVPTBlocks(MachineBasicBlock &Block) { ++MBIter; }; - unsigned BlockMask = 0; - switch (VPTInstCnt) { - case 1: - BlockMask = VPTMaskValue::T; - break; - case 2: - BlockMask = VPTMaskValue::TT; - break; - case 3: - BlockMask = VPTMaskValue::TTT; - break; - case 4: - BlockMask = VPTMaskValue::TTTT; - break; - default: - llvm_unreachable("Unexpected number of instruction in a VPT block"); - }; + unsigned BlockMask = getARMVPTBlockMask(VPTInstCnt); // Search back for a VCMP that can be folded to create a VPT, or else create // a VPST directly diff --git a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h index 11cb1a162e2..27605422983 100644 --- a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h +++ b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h @@ -91,6 +91,40 @@ namespace ARMVCC { Then, Else }; + + enum VPTMaskValue { + T = 8, // 0b1000 + TT = 4, // 0b0100 + TE = 12, // 0b1100 + TTT = 2, // 0b0010 + TTE = 6, // 0b0110 + TEE = 10, // 0b1010 + TET = 14, // 0b1110 + TTTT = 1, // 0b0001 + TTTE = 3, // 0b0011 + TTEE = 5, // 0b0101 + TTET = 7, // 0b0111 + TEEE = 9, // 0b1001 + TEET = 11, // 0b1011 + TETT = 13, // 0b1101 + 
TETE = 15 // 0b1111 + }; +} + +inline static unsigned getARMVPTBlockMask(unsigned NumInsts) { + switch (NumInsts) { + case 1: + return ARMVCC::T; + case 2: + return ARMVCC::TT; + case 3: + return ARMVCC::TTT; + case 4: + return ARMVCC::TTTT; + default: + break; + }; + llvm_unreachable("Unexpected number of instruction in a VPT block"); } inline static const char *ARMVPTPredToString(ARMVCC::VPTCodes CC) { |

