diff options
author | Sam Parker <sam.parker@arm.com> | 2019-12-20 08:42:11 +0000 |
---|---|---|
committer | Sam Parker <sam.parker@arm.com> | 2019-12-20 08:42:11 +0000 |
commit | 404251833521770732646c4348f774b94b40df72 (patch) | |
tree | 413e2f78e7da7a7efd42bd9812a1cb8025f7d9e1 /llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp | |
parent | 4f0fe6b97e4463d5c8571ac71b23c63387251444 (diff) | |
download | bcm5719-llvm-404251833521770732646c4348f774b94b40df72.tar.gz bcm5719-llvm-404251833521770732646c4348f774b94b40df72.zip |
[ARM][MVE] Tail predicate in the presence of vcmp
Record the discovered VPT blocks while checking for validity and, for
now, only handle blocks that begin with VPST and not VPT. We're now
allowing more than one instruction to define vpr, but each block must
somehow be predicated using the vctp. This leaves us with several
scenarios which need fixing up:
1) A VPT block which is only predicated by the vctp and has no
internal vpr defs.
2) A VPT block which is only predicated by the vctp but has an
internal vpr def.
3) A VPT block which is predicated upon the vctp as well as another
vpr def.
4) A VPT block which is not predicated upon a vctp, but contains it
and all instructions within the block are predicated upon it.
The changes needed are, for:
1) The easy one, just remove the vpst and unpredicate the
instructions in the block.
2) Remove the vpst and unpredicate the instructions up to the
internal vpr def. We need to insert a new vpst to predicate the
remaining instructions.
3) Do nothing.
4) The vctp will be inside a vpt and the instruction will be removed,
so adjust the size of the mask on the vpst.
Differential Revision: https://reviews.llvm.org/D71107
Diffstat (limited to 'llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp')
-rw-r--r-- | llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp | 294 |
1 files changed, 235 insertions, 59 deletions
diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp index e3a05408343..ec62a6975f0 100644 --- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -22,6 +22,7 @@ #include "ARMBaseRegisterInfo.h" #include "ARMBasicBlockInfo.h" #include "ARMSubtarget.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineLoopInfo.h" #include "llvm/CodeGen/MachineLoopUtils.h" @@ -37,6 +38,65 @@ using namespace llvm; namespace { + struct PredicatedMI { + MachineInstr *MI = nullptr; + SetVector<MachineInstr*> Predicates; + + public: + PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) : + MI(I) { + Predicates.insert(Preds.begin(), Preds.end()); + } + }; + + // Represent a VPT block, a list of instructions that begins with a VPST and + // has a maximum of four proceeding instructions. All instructions within the + // block are predicated upon the vpr and we allow instructions to define the + // vpr within in the block too. + class VPTBlock { + std::unique_ptr<PredicatedMI> VPST; + PredicatedMI *Divergent = nullptr; + SmallVector<PredicatedMI, 4> Insts; + + public: + VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { + VPST = std::make_unique<PredicatedMI>(MI, Preds); + } + + void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { + LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI); + if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) { + Divergent = &Insts.back(); + LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI); + } + Insts.emplace_back(MI, Preds); + assert(Insts.size() <= 4 && "Too many instructions in VPT block!"); + } + + // Have we found an instruction within the block which defines the vpr? If + // so, not all the instructions in the block will have the same predicate. 
+ bool HasNonUniformPredicate() const { + return Divergent != nullptr; + } + + // Is the given instruction part of the predicate set controlling the entry + // to the block. + bool IsPredicatedOn(MachineInstr *MI) const { + return VPST->Predicates.count(MI); + } + + // Is the given instruction the only predicate which controls the entry to + // the block. + bool IsOnlyPredicatedOn(MachineInstr *MI) const { + return IsPredicatedOn(MI) && VPST->Predicates.size() == 1; + } + + unsigned size() const { return Insts.size(); } + SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; } + MachineInstr *getVPST() const { return VPST->MI; } + PredicatedMI *getDivergent() const { return Divergent; } + }; + struct LowOverheadLoop { MachineLoop *ML = nullptr; @@ -46,39 +106,17 @@ namespace { MachineInstr *Dec = nullptr; MachineInstr *End = nullptr; MachineInstr *VCTP = nullptr; - SmallVector<MachineInstr*, 4> VPTUsers; + VPTBlock *CurrentBlock = nullptr; + SetVector<MachineInstr*> CurrentPredicate; + SmallVector<VPTBlock, 4> VPTBlocks; bool Revert = false; - bool FoundOneVCTP = false; bool CannotTailPredicate = false; LowOverheadLoop(MachineLoop *ML) : ML(ML) { MF = ML->getHeader()->getParent(); } - // For now, only support one vctp instruction. If we find multiple then - // we shouldn't perform tail predication. - void addVCTP(MachineInstr *MI) { - if (!VCTP) { - VCTP = MI; - FoundOneVCTP = true; - } else - FoundOneVCTP = false; - } - - // Check that nothing else is writing to VPR and record any insts - // reading the VPR. - void ScanForVPR(MachineInstr *MI) { - for (auto &MO : MI->operands()) { - if (!MO.isReg() || MO.getReg() != ARM::VPR) - continue; - if (MO.isUse()) - VPTUsers.push_back(MI); - if (MO.isDef()) { - CannotTailPredicate = true; - break; - } - } - } + bool RecordVPTBlocks(MachineInstr *MI); // If this is an MVE instruction, check that we know how to use tail // predication with it. 
@@ -86,6 +124,11 @@ namespace { if (CannotTailPredicate) return; + if (!RecordVPTBlocks(MI)) { + CannotTailPredicate = true; + return; + } + const MCInstrDesc &MCID = MI->getDesc(); uint64_t Flags = MCID.TSFlags; if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) @@ -100,7 +143,7 @@ namespace { bool IsTailPredicationLegal() const { // For now, let's keep things really simple and only support a single // block for tail predication. - return !Revert && FoundAllComponents() && FoundOneVCTP && + return !Revert && FoundAllComponents() && VCTP && !CannotTailPredicate && ML->getNumBlocks() == 1; } @@ -118,6 +161,8 @@ namespace { return Start && Dec && End; } + SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; } + // Return the loop iteration count, or the number of elements if we're tail // predicating. MachineOperand &getCount() { @@ -191,7 +236,7 @@ namespace { void RemoveLoopUpdate(LowOverheadLoop &LoLoop); - void RemoveVPTBlocks(LowOverheadLoop &LoLoop); + void ConvertVPTBlocks(LowOverheadLoop &LoLoop); MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop); @@ -281,13 +326,39 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, } else LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt); + if (!IsTailPredicationLegal()) { + LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n"); + return; + } + + // All predication within the loop should be based on vctp. If the block + // isn't predicated on entry, check whether the vctp is within the block + // and that all other instructions are then predicated on it. 
+ for (auto &Block : VPTBlocks) { + if (Block.IsPredicatedOn(VCTP)) + continue; + if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) { + CannotTailPredicate = true; + return; + } + SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); + for (auto &PredMI : Insts) { + if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI)) + continue; + LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI + << " - which is predicated on:\n"; + for (auto *MI : PredMI.Predicates) + dbgs() << " - " << *MI; + ); + CannotTailPredicate = true; + return; + } + } + // For tail predication, we need to provide the number of elements, instead // of the iteration count, to the loop start instruction. The number of // elements is provided to the vctp instruction, so we need to check that // we can use this register at InsertPt. - if (!IsTailPredicationLegal()) - return; - Register NumElements = VCTP->getOperand(1).getReg(); // If the register is defined within loop, then we can't perform TP. @@ -338,9 +409,65 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, MBB = *MBB->pred_begin(); } - LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication to convert:\n"; - for (auto *MI : VPTUsers) - dbgs() << " - " << *MI;); + LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n"); +} + +bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) { + // Only support a single vctp. + if (isVCTP(MI) && VCTP) + return false; + + // Start a new vpt block when we discover a vpt. + if (MI->getOpcode() == ARM::MVE_VPST) { + VPTBlocks.emplace_back(MI, CurrentPredicate); + CurrentBlock = &VPTBlocks.back(); + return true; + } + + if (isVCTP(MI)) + VCTP = MI; + + unsigned VPROpNum = MI->getNumOperands() - 1; + bool IsUse = false; + if (MI->getOperand(VPROpNum).isReg() && + MI->getOperand(VPROpNum).getReg() == ARM::VPR && + MI->getOperand(VPROpNum).isUse()) { + // If this instruction is predicated by VPR, it will be its last + // operand. 
Also check that it's only 'Then' predicated. + if (!MI->getOperand(VPROpNum-1).isImm() || + MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) { + LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: " + << *MI); + return false; + } + CurrentBlock->addInst(MI, CurrentPredicate); + IsUse = true; + } + + bool IsDef = false; + for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) { + const MachineOperand &MO = MI->getOperand(i); + if (!MO.isReg() || MO.getReg() != ARM::VPR) + continue; + + if (MO.isDef()) { + CurrentPredicate.insert(MI); + IsDef = true; + } else { + LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI); + return false; + } + } + + // If we find a vpr def that is not already predicated on the vctp, we've + // got disjoint predicates that may not be equivalent when we do the + // conversion. + if (IsDef && !IsUse && VCTP && !isVCTP(MI)) { + LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI); + return false; + } + + return true; } bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) { @@ -422,8 +549,6 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { LoLoop.End = &MI; else if (isLoopStart(MI)) LoLoop.Start = &MI; - else if (isVCTP(&MI)) - LoLoop.addVCTP(&MI); else if (MI.getDesc().isCall()) { // TODO: Though the call will require LE to execute again, does this // mean we should revert? Always executing LE hopefully should be @@ -431,10 +556,7 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { LoLoop.Revert = true; LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n"); } else { - // Once we've found a vctp, record the users of vpr and check there's - // no more vpr defs. - if (LoLoop.FoundOneVCTP) - LoLoop.ScanForVPR(&MI); + // Record VPR defs and build up their corresponding vpt blocks. // Check we know how to tail predicate any mve instructions. 
LoLoop.CheckTPValidity(&MI); } @@ -669,27 +791,81 @@ void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) { } } -void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) { - LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP); - LoLoop.VCTP->eraseFromParent(); +void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { + auto RemovePredicate = [](MachineInstr *MI) { + LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI); + unsigned OpNum = MI->getNumOperands() - 1; + assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then && + "Expected Then predicate!"); + MI->getOperand(OpNum-1).setImm(ARMVCC::None); + MI->getOperand(OpNum).setReg(0); + }; - for (auto *MI : LoLoop.VPTUsers) { - if (MI->getOpcode() == ARM::MVE_VPST) { - LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI); - MI->eraseFromParent(); - } else { - unsigned OpNum = MI->getNumOperands() - 1; - assert((MI->getOperand(OpNum).isReg() && - MI->getOperand(OpNum).getReg() == ARM::VPR) && - "Expected VPR"); - assert((MI->getOperand(OpNum-1).isImm() && - MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) && - "Expected Then predicate"); - MI->getOperand(OpNum-1).setImm(ARMVCC::None); - MI->getOperand(OpNum).setReg(0); - LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI); + // There are a few scenarios which we have to fix up: + // 1) A VPT block with is only predicated by the vctp and has no internal vpr + // defs. + // 2) A VPT block which is only predicated by the vctp but has an internal + // vpr def. + // 3) A VPT block which is predicated upon the vctp as well as another vpr + // def. + // 4) A VPT block which is not predicated upon a vctp, but contains it and + // all instructions within the block are predicated upon in. 
+ + for (auto &Block : LoLoop.getVPTBlocks()) { + SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); + if (Block.HasNonUniformPredicate()) { + PredicatedMI *Divergent = Block.getDivergent(); + if (isVCTP(Divergent->MI)) { + // The vctp will be removed, so the size of the vpt block needs to be + // modified. + uint64_t Size = getARMVPTBlockMask(Block.size() - 1); + Block.getVPST()->getOperand(0).setImm(Size); + LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n"); + } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { + // The VPT block has a non-uniform predicate but it's entry is guarded + // only by a vctp, which means we: + // - Need to remove the original vpst. + // - Then need to unpredicate any following instructions, until + // we come across the divergent vpr def. + // - Insert a new vpst to predicate the instruction(s) that following + // the divergent vpr def. + // TODO: We could be producing more VPT blocks than necessary and could + // fold the newly created one into a proceeding one. + for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()), + E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I) + RemovePredicate(&*I); + + unsigned Size = 0; + auto E = MachineBasicBlock::reverse_iterator(Divergent->MI); + auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI); + MachineInstr *InsertAt = nullptr; + while (I != E) { + InsertAt = &*I; + ++Size; + ++I; + } + MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt, + InsertAt->getDebugLoc(), + TII->get(ARM::MVE_VPST)); + MIB.addImm(getARMVPTBlockMask(Size)); + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); + LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB); + Block.getVPST()->eraseFromParent(); + } + } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { + // A vpt block which is only predicated upon vctp and has no internal vpr + // defs: + // - Remove vpst. + // - Unpredicate the remaining instructions. 
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); + Block.getVPST()->eraseFromParent(); + for (auto &PredMI : Insts) + RemovePredicate(PredMI.MI); } } + + LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP); + LoLoop.VCTP->eraseFromParent(); } void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) { @@ -743,7 +919,7 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) { RemoveDeadBranch(LoLoop.End); if (LoLoop.IsTailPredicationLegal()) { RemoveLoopUpdate(LoLoop); - RemoveVPTBlocks(LoLoop); + ConvertVPTBlocks(LoLoop); } } } |