diff options
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index c61a496b747..899b639fd12 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -254,6 +254,19 @@ static Type* ToVectorTy(Type *Scalar, unsigned VF) { return VectorType::get(Scalar, VF); } +/// A helper function that returns GEP instruction and knows to skip +/// 'bitcast'. +static GetElementPtrInst *getGEPInstruction(Value *Ptr) { + + if (isa<GetElementPtrInst>(Ptr)) + return cast<GetElementPtrInst>(Ptr); + + if (isa<BitCastInst>(Ptr) && + isa<GetElementPtrInst>(cast<BitCastInst>(Ptr)->getOperand(0))) + return cast<GetElementPtrInst>(cast<BitCastInst>(Ptr)->getOperand(0)); + return nullptr; +} + /// InnerLoopVectorizer vectorizes loops which contain only one basic /// block to a specified vectorization factor (VF). /// This class performs the widening of scalars into vectors, or multiple @@ -1943,7 +1956,7 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { return II.getConsecutiveDirection(); } - GetElementPtrInst *Gep = dyn_cast_or_null<GetElementPtrInst>(Ptr); + GetElementPtrInst *Gep = getGEPInstruction(Ptr); if (!Gep) return 0; @@ -2337,7 +2350,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { VectorParts &Entry = WidenMap.get(Instr); // Handle consecutive loads/stores. - GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr); + GetElementPtrInst *Gep = getGEPInstruction(Ptr); if (Gep && Legal->isInductionVariable(Gep->getPointerOperand())) { setDebugLocFromInst(Builder, Gep); Value *PtrOperand = Gep->getPointerOperand(); @@ -2397,7 +2410,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { // We don't want to update the value in the map as it might be used in // another expression. So don't use a reference type for "StoredVal". VectorParts StoredVal = getVectorValue(SI->getValueOperand()); - + for (unsigned Part = 0; Part < UF; ++Part) { // Calculate the pointer for the specific unroll-part. Value *PartPtr = |