diff options
-rw-r--r-- | llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp | 179 | ||||
-rw-r--r-- | llvm/test/Transforms/RewriteStatepointsForGC/live-vector.ll | 44 |
2 files changed, 166 insertions, 57 deletions
diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index 21ff55c697e..ae2ae3af0c7 100644 --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -294,12 +294,17 @@ static void analyzeParsePointLiveness( static Value *findBaseDefiningValue(Value *I); -/// If we can trivially determine that the index specified in the given vector -/// is a base pointer, return it. In cases where the entire vector is known to -/// consist of base pointers, the entire vector will be returned. This -/// indicates that the relevant extractelement is a valid base pointer and -/// should be used directly. -static Value *findBaseOfVector(Value *I, Value *Index) { +/// Return a base defining value for the 'Index' element of the given vector +/// instruction 'I'. If Index is null, returns a BDV for the entire vector +/// 'I'. As an optimization, this method will try to determine when the +/// element is known to already be a base pointer. If this can be established, +/// the second value in the returned pair will be true. Note that either a +/// vector or a pointer typed value can be returned. For the former, the +/// vector returned is a BDV (and possibly a base) of the entire vector 'I'. +/// If the later, the return pointer is a BDV (or possibly a base) for the +/// particular element in 'I'. +static std::pair<Value *, bool> +findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) { assert(I->getType()->isVectorTy() && cast<VectorType>(I->getType())->getElementType()->isPointerTy() && "Illegal to ask for the base pointer of a non-pointer type"); @@ -309,7 +314,7 @@ static Value *findBaseOfVector(Value *I, Value *Index) { if (isa<Argument>(I)) // An incoming argument to the function is a base pointer - return I; + return std::make_pair(I, true); // We shouldn't see the address of a global as a vector value? assert(!isa<GlobalVariable>(I) && @@ -320,7 +325,7 @@ static Value *findBaseOfVector(Value *I, Value *Index) { if (isa<UndefValue>(I)) // utterly meaningless, but useful for dealing with partially optimized // code. - return I; + return std::make_pair(I, true); // Due to inheritance, this must be _after_ the global variable and undef // checks @@ -328,38 +333,56 @@ static Value *findBaseOfVector(Value *I, Value *Index) { assert(!isa<GlobalVariable>(I) && !isa<UndefValue>(I) && "order of checks wrong!"); assert(Con->isNullValue() && "null is the only case which makes sense"); - return Con; + return std::make_pair(Con, true); } - + if (isa<LoadInst>(I)) - return I; - + return std::make_pair(I, true); + // For an insert element, we might be able to look through it if we know - // something about the indexes, but if the indices are arbitrary values, we - // can't without much more extensive scalarization. + // something about the indexes. if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(I)) { - Value *InsertIndex = IEI->getOperand(2); - // This index is inserting the value, look for it's base - if (InsertIndex == Index) - return findBaseDefiningValue(IEI->getOperand(1)); - // Both constant, and can't be equal per above. This insert is definitely - // not relevant, look back at the rest of the vector and keep trying. - if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex)) - return findBaseOfVector(IEI->getOperand(0), Index); - } - - // Note: This code is currently rather incomplete. We are essentially only - // handling cases where the vector element is trivially a base pointer. We - // need to update the entire base pointer construction algorithm to know how - // to track vector elements and potentially scalarize, but the case which - // would motivate the work hasn't shown up in real workloads yet. - llvm_unreachable("no base found for vector element"); + if (Index) { + Value *InsertIndex = IEI->getOperand(2); + // This index is inserting the value, look for its BDV + if (InsertIndex == Index) + return std::make_pair(findBaseDefiningValue(IEI->getOperand(1)), false); + // Both constant, and can't be equal per above. This insert is definitely + // not relevant, look back at the rest of the vector and keep trying. + if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex)) + return findBaseDefiningValueOfVector(IEI->getOperand(0), Index); + } + + // We don't know whether this vector contains entirely base pointers or + // not. To be conservatively correct, we treat it as a BDV and will + // duplicate code as needed to construct a parallel vector of bases. + return std::make_pair(IEI, false); + } + + if (isa<ShuffleVectorInst>(I)) + // We don't know whether this vector contains entirely base pointers or + // not. To be conservatively correct, we treat it as a BDV and will + // duplicate code as needed to construct a parallel vector of bases. + // TODO: There a number of local optimizations which could be applied here + // for particular sufflevector patterns. + return std::make_pair(I, false); + + // A PHI or Select is a base defining value. The outer findBasePointer + // algorithm is responsible for constructing a base value for this BDV. + assert((isa<SelectInst>(I) || isa<PHINode>(I)) && + "unknown vector instruction - no base found for vector element"); + return std::make_pair(I, false); } +static bool isKnownBaseResult(Value *V); + /// Helper function for findBasePointer - Will return a value which either a) /// defines the base pointer for the input or b) blocks the simple search /// (i.e. a PHI or Select of two derived pointers) static Value *findBaseDefiningValue(Value *I) { + if (I->getType()->isVectorTy()) + return findBaseDefiningValueOfVector(I).first; + assert(I->getType()->isPointerTy() && "Illegal to ask for the base pointer of a non-pointer type"); @@ -370,16 +393,39 @@ static Value *findBaseDefiningValue(Value *I) { if (auto *EEI = dyn_cast<ExtractElementInst>(I)) { Value *VectorOperand = EEI->getVectorOperand(); Value *Index = EEI->getIndexOperand(); - Value *VectorBase = findBaseOfVector(VectorOperand, Index); - // If the result returned is a vector, we know the entire vector must - // contain base pointers. In that case, the extractelement is a valid base - // for this value. - if (VectorBase->getType()->isVectorTy()) - return EEI; - // Otherwise, we needed to look through the vector to find the base for - // this particular element. - assert(VectorBase->getType()->isPointerTy()); - return VectorBase; + std::pair<Value *, bool> pair = + findBaseDefiningValueOfVector(VectorOperand, Index); + Value *VectorBase = pair.first; + if (VectorBase->getType()->isPointerTy()) + // We found a BDV for this specific element with the vector. This is an + // optimization, but in practice it covers most of the useful cases + // created via scalarization. + return VectorBase; + else { + assert(VectorBase->getType()->isVectorTy()); + if (pair.second) + // If the entire vector returned is known to be entirely base pointers, + // then the extractelement is valid base for this value. + return EEI; + else { + // Otherwise, we have an instruction which potentially produces a + // derived pointer and we need findBasePointers to clone code for us + // such that we can create an instruction which produces the + // accompanying base pointer. + // Note: This code is currently rather incomplete. We don't currently + // support the general form of shufflevector of insertelement. + // Conceptually, these are just 'base defining values' of the same + // variety as phi or select instructions. We need to update the + // findBasePointers algorithm to insert new 'base-only' versions of the + // original instructions. This is relative straight forward to do, but + // the case which would motivate the work hasn't shown up in real + // workloads yet. + assert((isa<PHINode>(VectorBase) || isa<SelectInst>(VectorBase)) && + "need to extend findBasePointers for generic vector" + "instruction cases"); + return VectorBase; + } + } } if (isa<Argument>(I)) @@ -1712,7 +1758,9 @@ static void findLiveReferences( /// slightly non-trivial since it requires a format change. Given how rare /// such cases are (for the moment?) scalarizing is an acceptable comprimise. static void splitVectorValues(Instruction *StatepointInst, - StatepointLiveSetTy &LiveSet, DominatorTree &DT) { + StatepointLiveSetTy &LiveSet, + DenseMap<Value *, Value *>& PointerToBase, + DominatorTree &DT) { SmallVector<Value *, 16> ToSplit; for (Value *V : LiveSet) if (isa<VectorType>(V->getType())) @@ -1721,14 +1769,14 @@ static void splitVectorValues(Instruction *StatepointInst, if (ToSplit.empty()) return; + DenseMap<Value *, SmallVector<Value *, 16>> ElementMapping; + Function &F = *(StatepointInst->getParent()->getParent()); DenseMap<Value *, AllocaInst *> AllocaMap; // First is normal return, second is exceptional return (invoke only) DenseMap<Value *, std::pair<Value *, Value *>> Replacements; for (Value *V : ToSplit) { - LiveSet.erase(V); - AllocaInst *Alloca = new AllocaInst(V->getType(), "", F.getEntryBlock().getFirstNonPHI()); AllocaMap[V] = Alloca; @@ -1738,7 +1786,7 @@ static void splitVectorValues(Instruction *StatepointInst, SmallVector<Value *, 16> Elements; for (unsigned i = 0; i < VT->getNumElements(); i++) Elements.push_back(Builder.CreateExtractElement(V, Builder.getInt32(i))); - LiveSet.insert(Elements.begin(), Elements.end()); + ElementMapping[V] = Elements; auto InsertVectorReform = [&](Instruction *IP) { Builder.SetInsertPoint(IP); @@ -1771,6 +1819,7 @@ static void splitVectorValues(Instruction *StatepointInst, Replacements[V].second = InsertVectorReform(IP); } } + for (Value *V : ToSplit) { AllocaInst *Alloca = AllocaMap[V]; @@ -1814,6 +1863,25 @@ static void splitVectorValues(Instruction *StatepointInst, for (Value *V : ToSplit) Allocas.push_back(AllocaMap[V]); PromoteMemToReg(Allocas, DT); + + // Update our tracking of live pointers and base mappings to account for the + // changes we just made. + for (Value *V : ToSplit) { + auto &Elements = ElementMapping[V]; + + LiveSet.erase(V); + LiveSet.insert(Elements.begin(), Elements.end()); + // We need to update the base mapping as well. + assert(PointerToBase.count(V)); + Value *OldBase = PointerToBase[V]; + auto &BaseElements = ElementMapping[OldBase]; + PointerToBase.erase(V); + assert(Elements.size() == BaseElements.size()); + for (unsigned i = 0; i < Elements.size(); i++) { + Value *Elem = Elements[i]; + PointerToBase[Elem] = BaseElements[i]; + } + } } // Helper function for the "rematerializeLiveValues". It walks use chain @@ -2075,17 +2143,6 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P, // site. findLiveReferences(F, DT, P, toUpdate, records); - // Do a limited scalarization of any live at safepoint vector values which - // contain pointers. This enables this pass to run after vectorization at - // the cost of some possible performance loss. TODO: it would be nice to - // natively support vectors all the way through the backend so we don't need - // to scalarize here. - for (size_t i = 0; i < records.size(); i++) { - struct PartiallyConstructedSafepointRecord &info = records[i]; - Instruction *statepoint = toUpdate[i].getInstruction(); - splitVectorValues(cast<Instruction>(statepoint), info.liveset, DT); - } - // B) Find the base pointers for each live pointer /* scope for caching */ { // Cache the 'defining value' relation used in the computation and @@ -2146,6 +2203,18 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P, } holders.clear(); + // Do a limited scalarization of any live at safepoint vector values which + // contain pointers. This enables this pass to run after vectorization at + // the cost of some possible performance loss. TODO: it would be nice to + // natively support vectors all the way through the backend so we don't need + // to scalarize here. + for (size_t i = 0; i < records.size(); i++) { + struct PartiallyConstructedSafepointRecord &info = records[i]; + Instruction *statepoint = toUpdate[i].getInstruction(); + splitVectorValues(cast<Instruction>(statepoint), info.liveset, + info.PointerToBase, DT); + } + // In order to reduce live set of statepoint we might choose to rematerialize // some values instead of relocating them. This is purelly an optimization and // does not influence correctness. diff --git a/llvm/test/Transforms/RewriteStatepointsForGC/live-vector.ll b/llvm/test/Transforms/RewriteStatepointsForGC/live-vector.ll index 0a4456a6835..26ad73737ad 100644 --- a/llvm/test/Transforms/RewriteStatepointsForGC/live-vector.ll +++ b/llvm/test/Transforms/RewriteStatepointsForGC/live-vector.ll @@ -105,8 +105,6 @@ define <2 x i64 addrspace(1)*> @test5(i64 addrspace(1)* %p) ; CHECK-NEXT: bitcast ; CHECK-NEXT: gc.relocate ; CHECK-NEXT: bitcast -; CHECK-NEXT: gc.relocate -; CHECK-NEXT: bitcast ; CHECK-NEXT: insertelement ; CHECK-NEXT: insertelement ; CHECK-NEXT: ret <2 x i64 addrspace(1)*> %7 @@ -116,6 +114,48 @@ entry: ret <2 x i64 addrspace(1)*> %vec } + +; A base vector from a load +define <2 x i64 addrspace(1)*> @test6(i1 %cnd, <2 x i64 addrspace(1)*>* %ptr) + gc "statepoint-example" { +; CHECK-LABEL: test6 +; CHECK-LABEL: merge: +; CHECK-NEXT: = phi +; CHECK-NEXT: = phi +; CHECK-NEXT: extractelement +; CHECK-NEXT: extractelement +; CHECK-NEXT: extractelement +; CHECK-NEXT: extractelement +; CHECK-NEXT: gc.statepoint +; CHECK-NEXT: gc.relocate +; CHECK-NEXT: bitcast +; CHECK-NEXT: gc.relocate +; CHECK-NEXT: bitcast +; CHECK-NEXT: gc.relocate +; CHECK-NEXT: bitcast +; CHECK-NEXT: gc.relocate +; CHECK-NEXT: bitcast +; CHECK-NEXT: insertelement +; CHECK-NEXT: insertelement +; CHECK-NEXT: insertelement +; CHECK-NEXT: insertelement +; CHECK-NEXT: ret <2 x i64 addrspace(1)*> +entry: + br i1 %cnd, label %taken, label %untaken +taken: + %obja = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr + br label %merge +untaken: + %objb = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr + br label %merge + +merge: + %obj = phi <2 x i64 addrspace(1)*> [%obja, %taken], [%objb, %untaken] + %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0) + ret <2 x i64 addrspace(1)*> %obj +} + + declare void @do_safepoint() declare i32 @llvm.experimental.gc.statepoint.p0f_isVoidf(i64, i32, void ()*, i32, i32, ...) |