diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/VectorUtils.cpp | 130 | ||||
-rw-r--r-- | llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 191 |
2 files changed, 310 insertions, 11 deletions
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 93720857662..4153c843c40 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -11,9 +11,12 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/PatternMatch.h" @@ -434,3 +437,130 @@ llvm::Value *llvm::getSplatValue(Value *V) { return InsertEltInst->getOperand(1); } + +DenseMap<Instruction*, uint64_t> llvm::computeMinimumValueSizes( + ArrayRef<BasicBlock*> Blocks, DemandedBits &DB, + const TargetTransformInfo *TTI) { + + // DemandedBits will give us every value's live-out bits. But we want + // to ensure no extra casts would need to be inserted, so every DAG + // of connected values must have the same minimum bitwidth. + EquivalenceClasses<Value*> ECs; + SmallVector<Value*,16> Worklist; + SmallPtrSet<Value*,4> Roots; + SmallPtrSet<Value*,16> Visited; + DenseMap<Value*,uint64_t> DBits; + SmallPtrSet<Instruction*,4> InstructionSet; + DenseMap<Instruction*, uint64_t> MinBWs; + + // Determine the roots. We work bottom-up, from truncs or icmps. + bool SeenExtFromIllegalType = false; + for (auto *BB : Blocks) + for (auto &I : *BB) { + InstructionSet.insert(&I); + + if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) && + !TTI->isTypeLegal(I.getOperand(0)->getType())) + SeenExtFromIllegalType = true; + + // Only deal with non-vector integers up to 64-bits wide. + if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) && + !I.getType()->isVectorTy() && + I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { + // Don't make work for ourselves. If we know the loaded type is legal, + // don't add it to the worklist. + if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType())) + continue; + + Worklist.push_back(&I); + Roots.insert(&I); + } + } + // Early exit. + if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) + return MinBWs; + + // Now proceed breadth-first, unioning values together. + while (!Worklist.empty()) { + Value *Val = Worklist.pop_back_val(); + Value *Leader = ECs.getOrInsertLeaderValue(Val); + + if (Visited.count(Val)) + continue; + Visited.insert(Val); + + // Non-instructions terminate a chain successfully. + if (!isa<Instruction>(Val)) + continue; + Instruction *I = cast<Instruction>(Val); + + // If we encounter a type that is larger than 64 bits, we can't represent + // it so bail out. + if (DB.getDemandedBits(I).getBitWidth() > 64) + return DenseMap<Instruction*,uint64_t>(); + + uint64_t V = DB.getDemandedBits(I).getZExtValue(); + DBits[Leader] |= V; + + // Casts, loads and instructions outside of our range terminate a chain + // successfully. + if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) || + !InstructionSet.count(I)) + continue; + + // Unsafe casts terminate a chain unsuccessfully. We can't do anything + // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to + // transform anything that relies on them. + if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) || + !I->getType()->isIntegerTy()) { + DBits[Leader] |= ~0ULL; + continue; + } + + // We don't modify the types of PHIs. Reductions will already have been + // truncated if possible, and inductions' sizes will have been chosen by + // indvars. + if (isa<PHINode>(I)) + continue; + + if (DBits[Leader] == ~0ULL) + // All bits demanded, no point continuing. + continue; + + for (Value *O : cast<User>(I)->operands()) { + ECs.unionSets(Leader, O); + Worklist.push_back(O); + } + } + + // Now we've discovered all values, walk them to see if there are + // any users we didn't see. If there are, we can't optimize that + // chain. + for (auto &I : DBits) + for (auto *U : I.first->users()) + if (U->getType()->isIntegerTy() && DBits.count(U) == 0) + DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; + + for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { + uint64_t LeaderDemandedBits = 0; + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) + LeaderDemandedBits |= DBits[*MI]; + + uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - + llvm::countLeadingZeros(LeaderDemandedBits); + // Round up to a power of 2 + if (!isPowerOf2_64((uint64_t)MinBW)) + MinBW = NextPowerOf2(MinBW); + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { + if (!isa<Instruction>(*MI)) + continue; + Type *Ty = (*MI)->getType(); + if (Roots.count(*MI)) + Ty = cast<Instruction>(*MI)->getOperand(0)->getType(); + if (MinBW < Ty->getScalarSizeInBits()) + MinBWs[cast<Instruction>(*MI)] = MinBW; + } + } + + return MinBWs; +} diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index bd381d31de7..ec91e138e86 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -48,7 +48,6 @@ #include "llvm/Transforms/Vectorize.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" @@ -63,6 +62,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" @@ -101,6 +101,7 @@ #include "llvm/Analysis/VectorUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <algorithm> +#include <functional> #include <map> #include <tuple> @@ -280,7 +281,12 @@ public: AddedSafetyChecks(false) {} // Perform the actual loop widening (vectorization). - void vectorize(LoopVectorizationLegality *L) { + // MinimumBitWidths maps scalar integer values to the smallest bitwidth they + // can be validly truncated to. The cost model has assumed this truncation + // will happen when vectorizing. + void vectorize(LoopVectorizationLegality *L, + DenseMap<Instruction*,uint64_t> MinimumBitWidths) { + MinBWs = MinimumBitWidths; Legal = L; // Create a new empty loop. Unlink the old loop and connect the new one. createEmptyLoop(); @@ -329,6 +335,9 @@ protected: /// See PR14725. void fixLCSSAPHIs(); + /// Shrinks vector element sizes based on information in "MinBWs". + void truncateToMinimalBitwidths(); + /// A helper function that computes the predicate of the block BB, assuming /// that the header block of the loop is set to True. It returns the *entry* /// mask for the block BB. @@ -339,7 +348,7 @@ protected: /// A helper function to vectorize a single BB within the innermost loop. void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV); - + /// Vectorize a single PHINode in a block. This method handles the induction /// variable canonicalization. It supports both VF = 1 for unrolled loops and /// arbitrary length vectors. @@ -499,6 +508,10 @@ protected: /// Trip count of the widened loop (TripCount - TripCount % (VF*UF)) Value *VectorTripCount; + /// Map of scalar integer values to the smallest bitwidth they can be legally + /// represented as. The vector equivalents of these values should be truncated + /// to this type. + DenseMap<Instruction*,uint64_t> MinBWs; LoopVectorizationLegality *Legal; // Record whether runtime check is added. @@ -1346,10 +1359,11 @@ public: LoopVectorizationCostModel(Loop *L, ScalarEvolution *SE, LoopInfo *LI, LoopVectorizationLegality *Legal, const TargetTransformInfo &TTI, - const TargetLibraryInfo *TLI, AssumptionCache *AC, + const TargetLibraryInfo *TLI, DemandedBits *DB, + AssumptionCache *AC, const Function *F, const LoopVectorizeHints *Hints, SmallPtrSetImpl<const Value *> &ValuesToIgnore) - : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), + : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {} /// Information about vectorization costs @@ -1419,6 +1433,12 @@ private: emitAnalysisDiag(TheFunction, TheLoop, *Hints, Message); } +public: + /// Map of scalar integer values to the smallest bitwidth they can be legally + /// represented as. The vector equivalents of these values should be truncated + /// to this type. + DenseMap<Instruction*,uint64_t> MinBWs; + /// The loop that we evaluate. Loop *TheLoop; /// Scev analysis. @@ -1431,6 +1451,8 @@ private: const TargetTransformInfo &TTI; /// Target Library Info. const TargetLibraryInfo *TLI; + /// Demanded bits analysis + DemandedBits *DB; const Function *TheFunction; // Loop Vectorize Hint. const LoopVectorizeHints *Hints; @@ -1523,6 +1545,7 @@ struct LoopVectorize : public FunctionPass { DominatorTree *DT; BlockFrequencyInfo *BFI; TargetLibraryInfo *TLI; + DemandedBits *DB; AliasAnalysis *AA; AssumptionCache *AC; LoopAccessAnalysis *LAA; @@ -1542,6 +1565,7 @@ struct LoopVectorize : public FunctionPass { AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); LAA = &getAnalysis<LoopAccessAnalysis>(); + DB = &getAnalysis<DemandedBits>(); // Compute some weights outside of the loop over the loops. Compute this // using a BranchProbability to re-use its scaling math. @@ -1687,7 +1711,7 @@ struct LoopVectorize : public FunctionPass { } // Use the cost model. - LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints, + LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints, ValuesToIgnore); // Check the function attributes to find out if this function should be @@ -1800,7 +1824,7 @@ struct LoopVectorize : public FunctionPass { // If we decided that it is not legal to vectorize the loop then // interleave it. InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC); - Unroller.vectorize(&LVL); + Unroller.vectorize(&LVL, CM.MinBWs); emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), Twine("interleaved loop (interleaved count: ") + @@ -1808,7 +1832,7 @@ struct LoopVectorize : public FunctionPass { } else { // If we decided that it is *legal* to vectorize the loop then do it. InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC); - LB.vectorize(&LVL); + LB.vectorize(&LVL, CM.MinBWs); ++LoopsVectorized; // Add metadata to disable runtime unrolling scalar loop when there's no @@ -1842,6 +1866,7 @@ struct LoopVectorize : public FunctionPass { AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<LoopAccessAnalysis>(); + AU.addRequired<DemandedBits>(); AU.addPreserved<LoopInfoWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<BasicAAWrapperPass>(); @@ -2009,6 +2034,7 @@ InnerLoopVectorizer::getVectorValue(Value *V) { // If this scalar is unknown, assume that it is a constant or that it is // loop invariant. Broadcast V and save the value for future uses. Value *B = getBroadcastInstrs(V); + return WidenMap.splat(V, B); } @@ -3102,6 +3128,117 @@ static unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF, return TTI.getIntrinsicInstrCost(ID, RetTy, Tys); } +static Type *smallestIntegerVectorType(Type *T1, Type *T2) { + IntegerType *I1 = cast<IntegerType>(T1->getVectorElementType()); + IntegerType *I2 = cast<IntegerType>(T2->getVectorElementType()); + return I1->getBitWidth() < I2->getBitWidth() ? T1 : T2; +} +static Type *largestIntegerVectorType(Type *T1, Type *T2) { + IntegerType *I1 = cast<IntegerType>(T1->getVectorElementType()); + IntegerType *I2 = cast<IntegerType>(T2->getVectorElementType()); + return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2; +} + +void InnerLoopVectorizer::truncateToMinimalBitwidths() { + // For every instruction `I` in MinBWs, truncate the operands, create a + // truncated version of `I` and reextend its result. InstCombine runs + // later and will remove any ext/trunc pairs. + // + for (auto &KV : MinBWs) { + VectorParts &Parts = WidenMap.get(KV.first); + for (Value *&I : Parts) { + if (I->use_empty()) + continue; + Type *OriginalTy = I->getType(); + Type *ScalarTruncatedTy = IntegerType::get(OriginalTy->getContext(), + KV.second); + Type *TruncatedTy = VectorType::get(ScalarTruncatedTy, + OriginalTy->getVectorNumElements()); + if (TruncatedTy == OriginalTy) + continue; + + IRBuilder<> B(cast<Instruction>(I)); + auto ShrinkOperand = [&](Value *V) -> Value* { + if (auto *ZI = dyn_cast<ZExtInst>(V)) + if (ZI->getSrcTy() == TruncatedTy) + return ZI->getOperand(0); + return B.CreateZExtOrTrunc(V, TruncatedTy); + }; + + // The actual instruction modification depends on the instruction type, + // unfortunately. + Value *NewI = nullptr; + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) { + NewI = B.CreateBinOp(BO->getOpcode(), + ShrinkOperand(BO->getOperand(0)), + ShrinkOperand(BO->getOperand(1))); + cast<BinaryOperator>(NewI)->copyIRFlags(I); + } else if (ICmpInst *CI = dyn_cast<ICmpInst>(I)) { + NewI = B.CreateICmp(CI->getPredicate(), + ShrinkOperand(CI->getOperand(0)), + ShrinkOperand(CI->getOperand(1))); + } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { + NewI = B.CreateSelect(SI->getCondition(), + ShrinkOperand(SI->getTrueValue()), + ShrinkOperand(SI->getFalseValue())); + } else if (CastInst *CI = dyn_cast<CastInst>(I)) { + switch (CI->getOpcode()) { + default: llvm_unreachable("Unhandled cast!"); + case Instruction::Trunc: + NewI = ShrinkOperand(CI->getOperand(0)); + break; + case Instruction::SExt: + NewI = B.CreateSExtOrTrunc(CI->getOperand(0), + smallestIntegerVectorType(OriginalTy, + TruncatedTy)); + break; + case Instruction::ZExt: + NewI = B.CreateZExtOrTrunc(CI->getOperand(0), + smallestIntegerVectorType(OriginalTy, + TruncatedTy)); + break; + } + } else if (ShuffleVectorInst *SI = dyn_cast<ShuffleVectorInst>(I)) { + auto Elements0 = SI->getOperand(0)->getType()->getVectorNumElements(); + auto *O0 = + B.CreateZExtOrTrunc(SI->getOperand(0), + VectorType::get(ScalarTruncatedTy, Elements0)); + auto Elements1 = SI->getOperand(1)->getType()->getVectorNumElements(); + auto *O1 = + B.CreateZExtOrTrunc(SI->getOperand(1), + VectorType::get(ScalarTruncatedTy, Elements1)); + + NewI = B.CreateShuffleVector(O0, O1, SI->getMask()); + } else if (isa<LoadInst>(I)) { + // Don't do anything with the operands, just extend the result. + continue; + } else { + llvm_unreachable("Unhandled instruction type!"); + } + + // Lastly, extend the result. + NewI->takeName(cast<Instruction>(I)); + Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy); + I->replaceAllUsesWith(Res); + cast<Instruction>(I)->eraseFromParent(); + I = Res; + } + } + + // We'll have created a bunch of ZExts that are now parentless. Clean up. + for (auto &KV : MinBWs) { + VectorParts &Parts = WidenMap.get(KV.first); + for (Value *&I : Parts) { + ZExtInst *Inst = dyn_cast<ZExtInst>(I); + if (Inst && Inst->use_empty()) { + Value *NewI = Inst->getOperand(0); + Inst->eraseFromParent(); + I = NewI; + } + } + } +} + void InnerLoopVectorizer::vectorizeLoop() { //===------------------------------------------------===// // @@ -3132,6 +3269,11 @@ void InnerLoopVectorizer::vectorizeLoop() { be = DFS.endRPO(); bb != be; ++bb) vectorizeBlockInLoop(*bb, &RdxPHIsToFix); + // Insert truncates and extends for any truncated instructions as hints to + // InstCombine. + if (VF > 1) + truncateToMinimalBitwidths(); + // At this point every instruction in the original loop is widened to // a vector form. We are almost done. Now, we need to fix the PHI nodes // that we vectorized. The PHI nodes are currently empty because we did @@ -3565,6 +3707,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { // For each instruction in the old loop. for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { VectorParts &Entry = WidenMap.get(it); + switch (it->getOpcode()) { case Instruction::Br: // Nothing to do for PHIs and BR, since we already took care of the @@ -3628,7 +3771,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { VectorParts &Cond = getVectorValue(it->getOperand(0)); VectorParts &Op0 = getVectorValue(it->getOperand(1)); VectorParts &Op1 = getVectorValue(it->getOperand(2)); - + Value *ScalarCond = (VF == 1) ? Cond[0] : Builder.CreateExtractElement(Cond[0], Builder.getInt32(0)); @@ -4563,6 +4706,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { unsigned TC = SE->getSmallConstantTripCount(TheLoop); DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); + MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); unsigned WidestType = getWidestType(); unsigned WidestRegister = TTI.getRegisterBitWidth(true); unsigned MaxSafeDepDist = -1U; @@ -5086,6 +5230,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { VF = 1; Type *RetTy = I->getType(); + if (VF > 1 && MinBWs.count(I)) + RetTy = IntegerType::get(RetTy->getContext(), MinBWs[I]); Type *VectorTy = ToVectorTy(RetTy, VF); // TODO: We need to estimate the cost of intrinsic calls. @@ -5168,6 +5314,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { case Instruction::ICmp: case Instruction::FCmp: { Type *ValTy = I->getOperand(0)->getType(); + if (VF > 1 && MinBWs.count(dyn_cast<Instruction>(I->getOperand(0)))) + ValTy = IntegerType::get(ValTy->getContext(), MinBWs[I]); VectorTy = ToVectorTy(ValTy, VF); return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy); } @@ -5291,8 +5439,28 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { Legal->isInductionVariable(I->getOperand(0))) return TTI.getCastInstrCost(I->getOpcode(), I->getType(), I->getOperand(0)->getType()); - - Type *SrcVecTy = ToVectorTy(I->getOperand(0)->getType(), VF); + + Type *SrcScalarTy = I->getOperand(0)->getType(); + Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF); + if (VF > 1 && MinBWs.count(I)) { + // This cast is going to be shrunk. This may remove the cast or it might + // turn it into slightly different cast. For example, if MinBW == 16, + // "zext i8 %1 to i32" becomes "zext i8 %1 to i16". + // + // Calculate the modified src and dest types. + Type *MinVecTy = VectorTy; + if (I->getOpcode() == Instruction::Trunc) { + SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy); + VectorTy = largestIntegerVectorType(ToVectorTy(I->getType(), VF), + MinVecTy); + } else if (I->getOpcode() == Instruction::ZExt || + I->getOpcode() == Instruction::SExt) { + SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy); + VectorTy = smallestIntegerVectorType(ToVectorTy(I->getType(), VF), + MinVecTy); + } + } + return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy); } case Instruction::Call: { @@ -5343,6 +5511,7 @@ INITIALIZE_PASS_DEPENDENCY(LCSSA) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis) +INITIALIZE_PASS_DEPENDENCY(DemandedBits) INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false) namespace llvm { |