diff options
Diffstat (limited to 'llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp')
-rw-r--r-- | llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp | 117 |
1 files changed, 101 insertions, 16 deletions
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp b/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp index 4d030409253..0983126fa18 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp @@ -83,8 +83,10 @@ private: const DataLayout *DL = nullptr; MemoryDependenceResults *MDA = nullptr; + bool checkArgumentUses(Value &Arg) const; bool isOutArgumentCandidate(Argument &Arg) const; + bool isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const; public: static char ID; @@ -110,27 +112,49 @@ INITIALIZE_PASS_END(AMDGPURewriteOutArguments, DEBUG_TYPE, char AMDGPURewriteOutArguments::ID = 0; -bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const { +bool AMDGPURewriteOutArguments::checkArgumentUses(Value &Arg) const { const int MaxUses = 10; - const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs; int UseCount = 0; - PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType()); - - // TODO: It might be useful for any out arguments, not just privates. - if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() && - !AnyAddressSpace) || - Arg.hasByValAttr() || Arg.hasStructRetAttr() || - DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) { - return false; - } - for (Use &U : Arg.uses()) { StoreInst *SI = dyn_cast<StoreInst>(U.getUser()); if (UseCount > MaxUses) return false; - if (!SI || !SI->isSimple() || + if (!SI) { + auto *BCI = dyn_cast<BitCastInst>(U.getUser()); + if (!BCI || !BCI->hasOneUse()) + return false; + + // We don't handle multiple stores currently, so stores to aggregate + // pointers aren't worth the trouble since they are canonically split up. + Type *DestEltTy = BCI->getType()->getPointerElementType(); + if (DestEltTy->isAggregateType()) + return false; + + // We could handle these if we had a convenient way to bitcast between + // them. 
+ Type *SrcEltTy = Arg.getType()->getPointerElementType(); + if (SrcEltTy->isArrayTy()) + return false; + + // Special case handle structs with single members. It is useful to handle + // some casts between structs and non-structs, but we can't bitcast + // directly between them. Blender uses + // some casts that look like { <3 x float> }* to <4 x float>* + if ((SrcEltTy->isStructTy() && (SrcEltTy->getNumContainedTypes() != 1))) + return false; + + // Clang emits OpenCL 3-vector type accesses with a bitcast to the + // equivalent 4-element vector and accesses that, and we're looking for + // this pointer cast. + if (DL->getTypeAllocSize(SrcEltTy) != DL->getTypeAllocSize(DestEltTy)) + return false; + + return checkArgumentUses(*BCI); + } + + if (!SI->isSimple() || U.getOperandNo() != StoreInst::getPointerOperandIndex()) return false; @@ -141,11 +165,40 @@ bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const { return UseCount > 0; } +bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const { + const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs; + PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType()); + + // TODO: It might be useful for any out arguments, not just privates. 
+ if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() && + !AnyAddressSpace) || + Arg.hasByValAttr() || Arg.hasStructRetAttr() || + DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) { + return false; + } + + return checkArgumentUses(Arg); +} + bool AMDGPURewriteOutArguments::doInitialization(Module &M) { DL = &M.getDataLayout(); return false; } +bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const { + VectorType *VT0 = dyn_cast<VectorType>(Ty0); + VectorType *VT1 = dyn_cast<VectorType>(Ty1); + if (!VT0 || !VT1) + return false; + + if (VT0->getNumElements() != 3 || + VT1->getNumElements() != 4) + return false; + + return DL->getTypeSizeInBits(VT0->getElementType()) == + DL->getTypeSizeInBits(VT1->getElementType()); +} + bool AMDGPURewriteOutArguments::runOnFunction(Function &F) { if (skipFunction(F)) return false; @@ -316,8 +369,33 @@ bool AMDGPURewriteOutArguments::runOnFunction(Function &F) { if (RetVal) NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++); + for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second) { - NewRetVal = B.CreateInsertValue(NewRetVal, ReturnPoint.second, RetIdx++); + Argument *Arg = ReturnPoint.first; + Value *Val = ReturnPoint.second; + Type *EltTy = Arg->getType()->getPointerElementType(); + if (Val->getType() != EltTy) { + Type *EffectiveEltTy = EltTy; + if (StructType *CT = dyn_cast<StructType>(EltTy)) { + assert(CT->getNumContainedTypes() == 1); + EffectiveEltTy = CT->getContainedType(0); + } + + if (DL->getTypeSizeInBits(EffectiveEltTy) != + DL->getTypeSizeInBits(Val->getType())) { + assert(isVec3ToVec4Shuffle(EffectiveEltTy, Val->getType())); + Val = B.CreateShuffleVector(Val, UndefValue::get(Val->getType()), + { 0, 1, 2 }); + } + + Val = B.CreateBitCast(Val, EffectiveEltTy); + + // Re-create single element composite. 
+ if (EltTy != EffectiveEltTy) + Val = B.CreateInsertValue(UndefValue::get(EltTy), Val, 0); + } + + NewRetVal = B.CreateInsertValue(NewRetVal, Val, RetIdx++); } if (RetVal) @@ -348,13 +426,20 @@ bool AMDGPURewriteOutArguments::runOnFunction(Function &F) { if (!OutArgIndexes.count(Arg.getArgNo())) continue; - auto *EltTy = Arg.getType()->getPointerElementType(); + PointerType *ArgType = cast<PointerType>(Arg.getType()); + + auto *EltTy = ArgType->getElementType(); unsigned Align = Arg.getParamAlignment(); if (Align == 0) Align = DL->getABITypeAlignment(EltTy); Value *Val = B.CreateExtractValue(StubCall, RetIdx++); - B.CreateAlignedStore(Val, &Arg, Align); + Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace()); + + // We can peek through bitcasts, so the type may not match. + Value *PtrVal = B.CreateBitCast(&Arg, PtrTy); + + B.CreateAlignedStore(Val, PtrVal, Align); } if (!RetTy->isVoidTy()) { |