diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 732a786ceb7..ada5bc9627d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -593,6 +593,58 @@ static bool isShuffleEquivalentToSelect(ShuffleVectorInst &Shuf) { return true; } +// Turn a chain of inserts that splats a value into a canonical insert + shuffle +// splat. That is: +// insertelt(insertelt(insertelt(insertelt X, %k, 0), %k, 1), %k, 2) ... -> +// shufflevector(insertelt(X, %k, 0), undef, zero) +static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { + // We are interested in the last insert in a chain. So, if this insert + // has a single user, and that user is an insert, bail. + if (InsElt.hasOneUse() && isa<InsertElementInst>(InsElt.user_back())) + return nullptr; + + VectorType *VT = cast<VectorType>(InsElt.getType()); + int NumElements = VT->getNumElements(); + + // Do not try to do this for a one-element vector, since that's a nop, + // and will cause an inf-loop. + if (NumElements == 1) + return nullptr; + + Value *SplatVal = InsElt.getOperand(1); + InsertElementInst *CurrIE = &InsElt; + SmallVector<bool, 16> ElementPresent(NumElements, false); + + // Walk the chain backwards, keeping track of which indices we inserted into, + // until we hit something that isn't an insert of the splatted value. + while (CurrIE) { + ConstantInt *Idx = dyn_cast<ConstantInt>(CurrIE->getOperand(2)); + if (!Idx || CurrIE->getOperand(1) != SplatVal) + return nullptr; + + // Check none of the intermediate steps have any additional uses. + if ((CurrIE != &InsElt) && !CurrIE->hasOneUse()) + return nullptr; + + ElementPresent[Idx->getZExtValue()] = true; + CurrIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0)); + } + + // Make sure we've seen an insert into every element. + if (llvm::any_of(ElementPresent, [](bool Present) { return !Present; })) + return nullptr; + + // All right, create the insert + shuffle. + Instruction *InsertFirst = InsertElementInst::Create( + UndefValue::get(VT), SplatVal, + ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), "", &InsElt); + + Constant *ZeroMask = ConstantAggregateZero::get( + VectorType::get(Type::getInt32Ty(InsElt.getContext()), NumElements)); + + return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask); +} + /// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex /// --> shufflevector X, CVec', Mask' static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { @@ -754,6 +806,11 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE)) return Shuf; + // Turn a sequence of inserts that broadcasts a scalar into a single + // insert + shufflevector. + if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE)) + return Broadcast; + return nullptr; } |