diff options
Diffstat (limited to 'llvm/lib/Transforms')
| -rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 121 | 
1 files changed, 100 insertions, 21 deletions
| diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index dca6a0c4bcd..96a169d7ed9 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -580,9 +580,10 @@ public:    LoopVectorizationLegality(Loop *L, ScalarEvolution *SE, const DataLayout *DL,                              DominatorTree *DT, TargetLibraryInfo *TLI, -                            AliasAnalysis *AA, Function *F) +                            AliasAnalysis *AA, Function *F, +                            const TargetTransformInfo *TTI)        : NumLoads(0), NumStores(0), NumPredStores(0), TheLoop(L), SE(SE), DL(DL), -        DT(DT), TLI(TLI), AA(AA), TheFunction(F), Induction(nullptr), +        DT(DT), TLI(TLI), AA(AA), TheFunction(F), TTI(TTI), Induction(nullptr),          WidestIndTy(nullptr), HasFunNoNaNAttr(false), MaxSafeDepDistBytes(-1U) {    } @@ -768,6 +769,21 @@ public:    }    SmallPtrSet<Value *, 8>::iterator strides_end() { return StrideSet.end(); } +  /// Returns true if the target machine supports masked store operation +  /// for the given \p DataType and kind of access to \p Ptr. +  bool isLegalMaskedStore(Type *DataType, Value *Ptr) { +    return TTI->isLegalMaskedStore(DataType, isConsecutivePtr(Ptr)); +  } +  /// Returns true if the target machine supports masked load operation +  /// for the given \p DataType and kind of access to \p Ptr. +  bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { +    return TTI->isLegalMaskedLoad(DataType, isConsecutivePtr(Ptr)); +  } +  /// Returns true if vector representation of the instruction \p I +  /// requires mask. +  bool isMaskRequired(const Instruction* I) { +    return (MaskedOp.count(I) != 0); +  }  private:    /// Check if a single basic block loop is vectorizable.    /// At this point we know that this is a loop with a constant trip count @@ -840,6 +856,8 @@ private:    AliasAnalysis *AA;    /// Parent function    Function *TheFunction; +  /// Target Transform Info +  const TargetTransformInfo *TTI;    //  ---  vectorization state --- // @@ -871,6 +889,10 @@ private:    ValueToValueMap Strides;    SmallPtrSet<Value *, 8> StrideSet; +   +  /// While vectorizing these instructions we have to generate a +  /// call to the appropriate masked intrinsic +  SmallPtrSet<const Instruction*, 8> MaskedOp;  };  /// LoopVectorizationCostModel - estimates the expected speedups due to @@ -1373,7 +1395,7 @@ struct LoopVectorize : public FunctionPass {      }      // Check if it is legal to vectorize the loop. -    LoopVectorizationLegality LVL(L, SE, DL, DT, TLI, AA, F); +    LoopVectorizationLegality LVL(L, SE, DL, DT, TLI, AA, F, TTI);      if (!LVL.canVectorize()) {        DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n");        emitMissedWarning(F, L, Hints); @@ -1761,7 +1783,8 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {    unsigned ScalarAllocatedSize = DL->getTypeAllocSize(ScalarDataTy);    unsigned VectorElementSize = DL->getTypeStoreSize(DataTy)/VF; -  if (SI && Legal->blockNeedsPredication(SI->getParent())) +  if (SI && Legal->blockNeedsPredication(SI->getParent()) && +      !Legal->isMaskRequired(SI))      return scalarizeInstruction(Instr, true);    if (ScalarAllocatedSize != VectorElementSize) @@ -1855,8 +1878,24 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {        Value *VecPtr = Builder.CreateBitCast(PartPtr,                                              DataTy->getPointerTo(AddressSpace)); -      StoreInst *NewSI = -        Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment); + +      Instruction *NewSI; +      if (Legal->isMaskRequired(SI)) { +        Type *I8PtrTy = +        Builder.getInt8PtrTy(PartPtr->getType()->getPointerAddressSpace()); + +        Value *I8Ptr = Builder.CreateBitCast(PartPtr, I8PtrTy); + +        VectorParts Cond = createBlockInMask(SI->getParent()); +        SmallVector <Value *, 8> Ops; +        Ops.push_back(I8Ptr); +        Ops.push_back(StoredVal[Part]); +        Ops.push_back(Builder.getInt32(Alignment)); +        Ops.push_back(Cond[Part]); +        NewSI = Builder.CreateMaskedStore(Ops); +      } +      else  +        NewSI = Builder.CreateAlignedStore(StoredVal[Part], VecPtr, Alignment);        propagateMetadata(NewSI, SI);      }      return; @@ -1876,9 +1915,26 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {        PartPtr = Builder.CreateGEP(PartPtr, Builder.getInt32(1 - VF));      } -    Value *VecPtr = Builder.CreateBitCast(PartPtr, -                                          DataTy->getPointerTo(AddressSpace)); -    LoadInst *NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); +    Instruction* NewLI; +    if (Legal->isMaskRequired(LI)) { +      Type *I8PtrTy = +        Builder.getInt8PtrTy(PartPtr->getType()->getPointerAddressSpace()); + +      Value *I8Ptr = Builder.CreateBitCast(PartPtr, I8PtrTy); + +      VectorParts SrcMask = createBlockInMask(LI->getParent()); +      SmallVector <Value *, 8> Ops; +      Ops.push_back(I8Ptr); +      Ops.push_back(UndefValue::get(DataTy)); +      Ops.push_back(Builder.getInt32(Alignment)); +      Ops.push_back(SrcMask[Part]); +      NewLI = Builder.CreateMaskedLoad(Ops); +    } +    else { +      Value *VecPtr = Builder.CreateBitCast(PartPtr, +                                            DataTy->getPointerTo(AddressSpace)); +      NewLI = Builder.CreateAlignedLoad(VecPtr, Alignment, "wide.load"); +    }      propagateMetadata(NewLI, LI);      Entry[Part] = Reverse ? reverseVector(NewLI) :  NewLI;    } @@ -5305,12 +5361,27 @@ bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB)  {  bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB,                                             SmallPtrSetImpl<Value *> &SafePtrs) { +      for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { +    // Check that we don't have a constant expression that can trap as operand. +    for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end(); +         OI != OE; ++OI) { +      if (Constant *C = dyn_cast<Constant>(*OI)) +        if (C->canTrap()) +          return false; +    }      // We might be able to hoist the load.      if (it->mayReadFromMemory()) {        LoadInst *LI = dyn_cast<LoadInst>(it); -      if (!LI || !SafePtrs.count(LI->getPointerOperand())) +      if (!LI) +        return false; +      if (!SafePtrs.count(LI->getPointerOperand())) { +        if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand())) { +          MaskedOp.insert(LI); +          continue; +        }          return false; +      }      }      // We don't predicate stores at the moment. @@ -5318,22 +5389,30 @@ bool LoopVectorizationLegality::blockCanBePredicated(BasicBlock *BB,        StoreInst *SI = dyn_cast<StoreInst>(it);        // We only support predication of stores in basic blocks with one        // predecessor. -      if (!SI || ++NumPredStores > NumberOfStoresToPredicate || -          !SafePtrs.count(SI->getPointerOperand()) || -          !SI->getParent()->getSinglePredecessor()) +      if (!SI) +        return false; + +      bool isSafePtr = (SafePtrs.count(SI->getPointerOperand()) != 0); +      bool isSinglePredecessor = SI->getParent()->getSinglePredecessor(); +       +      if (++NumPredStores > NumberOfStoresToPredicate || !isSafePtr || +          !isSinglePredecessor) { +        // Build a masked store if it is legal for the target, otherwise scalarize +        // the block. +        bool isLegalMaskedOp = +          isLegalMaskedStore(SI->getValueOperand()->getType(), +                             SI->getPointerOperand()); +        if (isLegalMaskedOp) { +          --NumPredStores; +          MaskedOp.insert(SI); +          continue; +        }          return false; +      }      }      if (it->mayThrow())        return false; -    // Check that we don't have a constant expression that can trap as operand. -    for (Instruction::op_iterator OI = it->op_begin(), OE = it->op_end(); -         OI != OE; ++OI) { -      if (Constant *C = dyn_cast<Constant>(*OI)) -        if (C->canTrap()) -          return false; -    } -      // The instructions below can trap.      switch (it->getOpcode()) {      default: continue; | 

