diff options
Diffstat (limited to 'llvm/lib/Target/ARM/MVETailPredication.cpp')
-rw-r--r-- | llvm/lib/Target/ARM/MVETailPredication.cpp | 68 |
1 files changed, 59 insertions, 9 deletions
diff --git a/llvm/lib/Target/ARM/MVETailPredication.cpp b/llvm/lib/Target/ARM/MVETailPredication.cpp index 00791841566..844eafbcb38 100644 --- a/llvm/lib/Target/ARM/MVETailPredication.cpp +++ b/llvm/lib/Target/ARM/MVETailPredication.cpp @@ -84,7 +84,7 @@ private: /// Is the icmp that generates an i1 vector, based upon a loop counter /// and a limit that is defined outside the loop. - bool isTailPredicate(Value *Predicate, Value *NumElements); + bool isTailPredicate(Instruction *Predicate, Value *NumElements); }; } // end namespace @@ -178,7 +178,7 @@ bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { return Changed; } -bool MVETailPredication::isTailPredicate(Value *V, Value *NumElements) { +bool MVETailPredication::isTailPredicate(Instruction *I, Value *NumElements) { // Look for the following: // %trip.count.minus.1 = add i32 %N, -1 @@ -206,7 +206,7 @@ bool MVETailPredication::isTailPredicate(Value *V, Value *NumElements) { Instruction *Induction = nullptr; // The vector icmp - if (!match(V, m_ICmp(Pred, m_Instruction(Induction), + if (!match(I, m_ICmp(Pred, m_Instruction(Induction), m_Instruction(Shuffle))) || Pred != ICmpInst::ICMP_ULE || !L->isLoopInvariant(Shuffle)) return false; @@ -390,6 +390,55 @@ Value* MVETailPredication::ComputeElements(Value *TripCount, return Expander.expandCodeFor(Elems, Elems->getType(), InsertPt); } +// Look through the exit block to see whether there's a duplicate predicate +// instruction. This can happen when we need to perform a select on values +// from the last and previous iteration. Instead of doing a straight +// replacement of that predicate with the vctp, clone the vctp and place it +// in the block. This means that the VPR doesn't have to be live into the +// exit block which should make it easier to convert this loop into a proper +// tail predicated loop. +static void Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates, + SetVector<Instruction*> &MaybeDead, Loop *L) { + if (BasicBlock *Exit = L->getUniqueExitBlock()) { + for (auto &Pair : NewPredicates) { + Instruction *OldPred = Pair.first; + Instruction *NewPred = Pair.second; + + for (auto &I : *Exit) { + if (I.isSameOperationAs(OldPred)) { + Instruction *PredClone = NewPred->clone(); + PredClone->insertBefore(&I); + I.replaceAllUsesWith(PredClone); + MaybeDead.insert(&I); + break; + } + } + } + } + + // Drop references and add operands to check for dead. + SmallPtrSet<Instruction*, 4> Dead; + while (!MaybeDead.empty()) { + auto *I = MaybeDead.front(); + MaybeDead.remove(I); + if (I->hasNUsesOrMore(1)) + continue; + + for (auto &U : I->operands()) { + if (auto *OpI = dyn_cast<Instruction>(U)) + MaybeDead.insert(OpI); + } + I->dropAllReferences(); + Dead.insert(I); + } + + for (auto *I : Dead) + I->eraseFromParent(); + + for (auto I : L->blocks()) + DeleteDeadPHIs(I); +} + bool MVETailPredication::TryConvert(Value *TripCount) { if (!IsPredicatedVectorLoop()) return false; @@ -400,13 +449,14 @@ bool MVETailPredication::TryConvert(Value *TripCount) { // operand is generated from an induction variable. Module *M = L->getHeader()->getModule(); Type *Ty = IntegerType::get(M->getContext(), 32); - SmallPtrSet<Value*, 4> Predicates; + SetVector<Instruction*> Predicates; + DenseMap<Instruction*, Instruction*> NewPredicates; for (auto *I : MaskedInsts) { Intrinsic::ID ID = I->getIntrinsicID(); unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3; - Value *Predicate = I->getArgOperand(PredOp); - if (Predicates.count(Predicate)) + auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp)); + if (!Predicate || Predicates.count(Predicate)) continue; VectorType *VecTy = getVectorType(I); @@ -445,6 +495,7 @@ bool MVETailPredication::TryConvert(Value *TripCount) { Value *Remaining = Builder.CreateSub(Processed, Factor); Value *TailPredicate = Builder.CreateCall(VCTP, Remaining); Predicate->replaceAllUsesWith(TailPredicate); + NewPredicates[Predicate] = cast<Instruction>(TailPredicate); // Add the incoming value to the new phi. Processed->addIncoming(Remaining, L->getLoopLatch()); @@ -453,9 +504,8 @@ bool MVETailPredication::TryConvert(Value *TripCount) { << "TP: Inserted VCTP: " << *TailPredicate << "\n"); } - for (auto I : L->blocks()) - DeleteDeadPHIs(I); - + // Now clean up. + Cleanup(NewPredicates, Predicates, L); return true; } |