summaryrefslogtreecommitdiffstats
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp294
-rw-r--r--llvm/lib/Target/ARM/MVEVPTBlockPass.cpp18
-rw-r--r--llvm/lib/Target/ARM/Utils/ARMBaseInfo.h34
3 files changed, 270 insertions, 76 deletions
diff --git a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
index e3a05408343..ec62a6975f0 100644
--- a/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
+++ b/llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
@@ -22,6 +22,7 @@
#include "ARMBaseRegisterInfo.h"
#include "ARMBasicBlockInfo.h"
#include "ARMSubtarget.h"
+#include "llvm/ADT/SetOperations.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineLoopUtils.h"
@@ -37,6 +38,65 @@ using namespace llvm;
namespace {
+ struct PredicatedMI {
+ MachineInstr *MI = nullptr;
+ SetVector<MachineInstr*> Predicates;
+
+ public:
+ PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) :
+ MI(I) {
+ Predicates.insert(Preds.begin(), Preds.end());
+ }
+ };
+
+ // Represent a VPT block, a list of instructions that begins with a VPST and
+ // has a maximum of four proceeding instructions. All instructions within the
+ // block are predicated upon the vpr and we allow instructions to define the
+ // vpr within in the block too.
+ class VPTBlock {
+ std::unique_ptr<PredicatedMI> VPST;
+ PredicatedMI *Divergent = nullptr;
+ SmallVector<PredicatedMI, 4> Insts;
+
+ public:
+ VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
+ VPST = std::make_unique<PredicatedMI>(MI, Preds);
+ }
+
+ void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI);
+ if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) {
+ Divergent = &Insts.back();
+ LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI);
+ }
+ Insts.emplace_back(MI, Preds);
+ assert(Insts.size() <= 4 && "Too many instructions in VPT block!");
+ }
+
+ // Have we found an instruction within the block which defines the vpr? If
+ // so, not all the instructions in the block will have the same predicate.
+ bool HasNonUniformPredicate() const {
+ return Divergent != nullptr;
+ }
+
+ // Is the given instruction part of the predicate set controlling the entry
+ // to the block.
+ bool IsPredicatedOn(MachineInstr *MI) const {
+ return VPST->Predicates.count(MI);
+ }
+
+ // Is the given instruction the only predicate which controls the entry to
+ // the block.
+ bool IsOnlyPredicatedOn(MachineInstr *MI) const {
+ return IsPredicatedOn(MI) && VPST->Predicates.size() == 1;
+ }
+
+ unsigned size() const { return Insts.size(); }
+ SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; }
+ MachineInstr *getVPST() const { return VPST->MI; }
+ PredicatedMI *getDivergent() const { return Divergent; }
+ };
+
struct LowOverheadLoop {
MachineLoop *ML = nullptr;
@@ -46,39 +106,17 @@ namespace {
MachineInstr *Dec = nullptr;
MachineInstr *End = nullptr;
MachineInstr *VCTP = nullptr;
- SmallVector<MachineInstr*, 4> VPTUsers;
+ VPTBlock *CurrentBlock = nullptr;
+ SetVector<MachineInstr*> CurrentPredicate;
+ SmallVector<VPTBlock, 4> VPTBlocks;
bool Revert = false;
- bool FoundOneVCTP = false;
bool CannotTailPredicate = false;
LowOverheadLoop(MachineLoop *ML) : ML(ML) {
MF = ML->getHeader()->getParent();
}
- // For now, only support one vctp instruction. If we find multiple then
- // we shouldn't perform tail predication.
- void addVCTP(MachineInstr *MI) {
- if (!VCTP) {
- VCTP = MI;
- FoundOneVCTP = true;
- } else
- FoundOneVCTP = false;
- }
-
- // Check that nothing else is writing to VPR and record any insts
- // reading the VPR.
- void ScanForVPR(MachineInstr *MI) {
- for (auto &MO : MI->operands()) {
- if (!MO.isReg() || MO.getReg() != ARM::VPR)
- continue;
- if (MO.isUse())
- VPTUsers.push_back(MI);
- if (MO.isDef()) {
- CannotTailPredicate = true;
- break;
- }
- }
- }
+ bool RecordVPTBlocks(MachineInstr *MI);
// If this is an MVE instruction, check that we know how to use tail
// predication with it.
@@ -86,6 +124,11 @@ namespace {
if (CannotTailPredicate)
return;
+ if (!RecordVPTBlocks(MI)) {
+ CannotTailPredicate = true;
+ return;
+ }
+
const MCInstrDesc &MCID = MI->getDesc();
uint64_t Flags = MCID.TSFlags;
if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE)
@@ -100,7 +143,7 @@ namespace {
bool IsTailPredicationLegal() const {
// For now, let's keep things really simple and only support a single
// block for tail predication.
- return !Revert && FoundAllComponents() && FoundOneVCTP &&
+ return !Revert && FoundAllComponents() && VCTP &&
!CannotTailPredicate && ML->getNumBlocks() == 1;
}
@@ -118,6 +161,8 @@ namespace {
return Start && Dec && End;
}
+ SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; }
+
// Return the loop iteration count, or the number of elements if we're tail
// predicating.
MachineOperand &getCount() {
@@ -191,7 +236,7 @@ namespace {
void RemoveLoopUpdate(LowOverheadLoop &LoLoop);
- void RemoveVPTBlocks(LowOverheadLoop &LoLoop);
+ void ConvertVPTBlocks(LowOverheadLoop &LoLoop);
MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);
@@ -281,13 +326,39 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
} else
LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);
+ if (!IsTailPredicationLegal()) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n");
+ return;
+ }
+
+ // All predication within the loop should be based on vctp. If the block
+ // isn't predicated on entry, check whether the vctp is within the block
+ // and that all other instructions are then predicated on it.
+ for (auto &Block : VPTBlocks) {
+ if (Block.IsPredicatedOn(VCTP))
+ continue;
+ if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) {
+ CannotTailPredicate = true;
+ return;
+ }
+ SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
+ for (auto &PredMI : Insts) {
+ if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI))
+ continue;
+ LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI
+ << " - which is predicated on:\n";
+ for (auto *MI : PredMI.Predicates)
+ dbgs() << " - " << *MI;
+ );
+ CannotTailPredicate = true;
+ return;
+ }
+ }
+
// For tail predication, we need to provide the number of elements, instead
// of the iteration count, to the loop start instruction. The number of
// elements is provided to the vctp instruction, so we need to check that
// we can use this register at InsertPt.
- if (!IsTailPredicationLegal())
- return;
-
Register NumElements = VCTP->getOperand(1).getReg();
// If the register is defined within loop, then we can't perform TP.
@@ -338,9 +409,65 @@ void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils,
MBB = *MBB->pred_begin();
}
- LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication to convert:\n";
- for (auto *MI : VPTUsers)
- dbgs() << " - " << *MI;);
+ LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n");
+}
+
+bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) {
+ // Only support a single vctp.
+ if (isVCTP(MI) && VCTP)
+ return false;
+
+ // Start a new vpt block when we discover a vpt.
+ if (MI->getOpcode() == ARM::MVE_VPST) {
+ VPTBlocks.emplace_back(MI, CurrentPredicate);
+ CurrentBlock = &VPTBlocks.back();
+ return true;
+ }
+
+ if (isVCTP(MI))
+ VCTP = MI;
+
+ unsigned VPROpNum = MI->getNumOperands() - 1;
+ bool IsUse = false;
+ if (MI->getOperand(VPROpNum).isReg() &&
+ MI->getOperand(VPROpNum).getReg() == ARM::VPR &&
+ MI->getOperand(VPROpNum).isUse()) {
+ // If this instruction is predicated by VPR, it will be its last
+ // operand. Also check that it's only 'Then' predicated.
+ if (!MI->getOperand(VPROpNum-1).isImm() ||
+ MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: "
+ << *MI);
+ return false;
+ }
+ CurrentBlock->addInst(MI, CurrentPredicate);
+ IsUse = true;
+ }
+
+ bool IsDef = false;
+ for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) {
+ const MachineOperand &MO = MI->getOperand(i);
+ if (!MO.isReg() || MO.getReg() != ARM::VPR)
+ continue;
+
+ if (MO.isDef()) {
+ CurrentPredicate.insert(MI);
+ IsDef = true;
+ } else {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
+ return false;
+ }
+ }
+
+ // If we find a vpr def that is not already predicated on the vctp, we've
+ // got disjoint predicates that may not be equivalent when we do the
+ // conversion.
+ if (IsDef && !IsUse && VCTP && !isVCTP(MI)) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI);
+ return false;
+ }
+
+ return true;
}
bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
@@ -422,8 +549,6 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
LoLoop.End = &MI;
else if (isLoopStart(MI))
LoLoop.Start = &MI;
- else if (isVCTP(&MI))
- LoLoop.addVCTP(&MI);
else if (MI.getDesc().isCall()) {
// TODO: Though the call will require LE to execute again, does this
// mean we should revert? Always executing LE hopefully should be
@@ -431,10 +556,7 @@ bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {
LoLoop.Revert = true;
LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
} else {
- // Once we've found a vctp, record the users of vpr and check there's
- // no more vpr defs.
- if (LoLoop.FoundOneVCTP)
- LoLoop.ScanForVPR(&MI);
+ // Record VPR defs and build up their corresponding vpt blocks.
// Check we know how to tail predicate any mve instructions.
LoLoop.CheckTPValidity(&MI);
}
@@ -669,27 +791,81 @@ void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) {
}
}
-void ARMLowOverheadLoops::RemoveVPTBlocks(LowOverheadLoop &LoLoop) {
- LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
- LoLoop.VCTP->eraseFromParent();
+void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
+ auto RemovePredicate = [](MachineInstr *MI) {
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
+ unsigned OpNum = MI->getNumOperands() - 1;
+ assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then &&
+ "Expected Then predicate!");
+ MI->getOperand(OpNum-1).setImm(ARMVCC::None);
+ MI->getOperand(OpNum).setReg(0);
+ };
- for (auto *MI : LoLoop.VPTUsers) {
- if (MI->getOpcode() == ARM::MVE_VPST) {
- LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *MI);
- MI->eraseFromParent();
- } else {
- unsigned OpNum = MI->getNumOperands() - 1;
- assert((MI->getOperand(OpNum).isReg() &&
- MI->getOperand(OpNum).getReg() == ARM::VPR) &&
- "Expected VPR");
- assert((MI->getOperand(OpNum-1).isImm() &&
- MI->getOperand(OpNum-1).getImm() == ARMVCC::Then) &&
- "Expected Then predicate");
- MI->getOperand(OpNum-1).setImm(ARMVCC::None);
- MI->getOperand(OpNum).setReg(0);
- LLVM_DEBUG(dbgs() << "ARM Loops: Removed predicate from: " << *MI);
+ // There are a few scenarios which we have to fix up:
+ // 1) A VPT block with is only predicated by the vctp and has no internal vpr
+ // defs.
+ // 2) A VPT block which is only predicated by the vctp but has an internal
+ // vpr def.
+ // 3) A VPT block which is predicated upon the vctp as well as another vpr
+ // def.
+ // 4) A VPT block which is not predicated upon a vctp, but contains it and
+ // all instructions within the block are predicated upon in.
+
+ for (auto &Block : LoLoop.getVPTBlocks()) {
+ SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts();
+ if (Block.HasNonUniformPredicate()) {
+ PredicatedMI *Divergent = Block.getDivergent();
+ if (isVCTP(Divergent->MI)) {
+ // The vctp will be removed, so the size of the vpt block needs to be
+ // modified.
+ uint64_t Size = getARMVPTBlockMask(Block.size() - 1);
+ Block.getVPST()->getOperand(0).setImm(Size);
+ LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n");
+ } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
+ // The VPT block has a non-uniform predicate but it's entry is guarded
+ // only by a vctp, which means we:
+ // - Need to remove the original vpst.
+ // - Then need to unpredicate any following instructions, until
+ // we come across the divergent vpr def.
+ // - Insert a new vpst to predicate the instruction(s) that following
+ // the divergent vpr def.
+ // TODO: We could be producing more VPT blocks than necessary and could
+ // fold the newly created one into a proceeding one.
+ for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()),
+ E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I)
+ RemovePredicate(&*I);
+
+ unsigned Size = 0;
+ auto E = MachineBasicBlock::reverse_iterator(Divergent->MI);
+ auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI);
+ MachineInstr *InsertAt = nullptr;
+ while (I != E) {
+ InsertAt = &*I;
+ ++Size;
+ ++I;
+ }
+ MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt,
+ InsertAt->getDebugLoc(),
+ TII->get(ARM::MVE_VPST));
+ MIB.addImm(getARMVPTBlockMask(Size));
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
+ LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
+ Block.getVPST()->eraseFromParent();
+ }
+ } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) {
+ // A vpt block which is only predicated upon vctp and has no internal vpr
+ // defs:
+ // - Remove vpst.
+ // - Unpredicate the remaining instructions.
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST());
+ Block.getVPST()->eraseFromParent();
+ for (auto &PredMI : Insts)
+ RemovePredicate(PredMI.MI);
}
}
+
+ LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP);
+ LoLoop.VCTP->eraseFromParent();
}
void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
@@ -743,7 +919,7 @@ void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {
RemoveDeadBranch(LoLoop.End);
if (LoLoop.IsTailPredicationLegal()) {
RemoveLoopUpdate(LoLoop);
- RemoveVPTBlocks(LoLoop);
+ ConvertVPTBlocks(LoLoop);
}
}
}
diff --git a/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp b/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp
index 39d90d0b6db..c8b725f339e 100644
--- a/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp
+++ b/llvm/lib/Target/ARM/MVEVPTBlockPass.cpp
@@ -137,23 +137,7 @@ bool MVEVPTBlock::InsertVPTBlocks(MachineBasicBlock &Block) {
++MBIter;
};
- unsigned BlockMask = 0;
- switch (VPTInstCnt) {
- case 1:
- BlockMask = VPTMaskValue::T;
- break;
- case 2:
- BlockMask = VPTMaskValue::TT;
- break;
- case 3:
- BlockMask = VPTMaskValue::TTT;
- break;
- case 4:
- BlockMask = VPTMaskValue::TTTT;
- break;
- default:
- llvm_unreachable("Unexpected number of instruction in a VPT block");
- };
+ unsigned BlockMask = getARMVPTBlockMask(VPTInstCnt);
// Search back for a VCMP that can be folded to create a VPT, or else create
// a VPST directly
diff --git a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h
index 11cb1a162e2..27605422983 100644
--- a/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h
+++ b/llvm/lib/Target/ARM/Utils/ARMBaseInfo.h
@@ -91,6 +91,40 @@ namespace ARMVCC {
Then,
Else
};
+
+ enum VPTMaskValue {
+ T = 8, // 0b1000
+ TT = 4, // 0b0100
+ TE = 12, // 0b1100
+ TTT = 2, // 0b0010
+ TTE = 6, // 0b0110
+ TEE = 10, // 0b1010
+ TET = 14, // 0b1110
+ TTTT = 1, // 0b0001
+ TTTE = 3, // 0b0011
+ TTEE = 5, // 0b0101
+ TTET = 7, // 0b0111
+ TEEE = 9, // 0b1001
+ TEET = 11, // 0b1011
+ TETT = 13, // 0b1101
+ TETE = 15 // 0b1111
+ };
+}
+
+inline static unsigned getARMVPTBlockMask(unsigned NumInsts) {
+ switch (NumInsts) {
+ case 1:
+ return ARMVCC::T;
+ case 2:
+ return ARMVCC::TT;
+ case 3:
+ return ARMVCC::TTT;
+ case 4:
+ return ARMVCC::TTTT;
+ default:
+ break;
+ };
+ llvm_unreachable("Unexpected number of instruction in a VPT block");
}
inline static const char *ARMVPTPredToString(ARMVCC::VPTCodes CC) {
OpenPOWER on IntegriCloud