diff options
Diffstat (limited to 'llvm/lib/Transforms')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index a1ac64dda35..e082cc30196 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -480,6 +480,38 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc, return nullptr; } +/// Try to narrow the width of an insert element. This could be generalized for +/// any vector constant, but we limit the transform to insertion into undef to +/// avoid potential backend problems from unsupported insertion widths. This +/// could also be extended to handle the case of inserting a scalar constant +/// into a vector variable. +static Instruction *shrinkInsertElt(CastInst &Trunc, + InstCombiner::BuilderTy &Builder) { + Instruction::CastOps Opcode = Trunc.getOpcode(); + assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) && + "Unexpected instruction for shrinking"); + + auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0)); + if (!InsElt || !InsElt->hasOneUse()) + return nullptr; + + Type *DestTy = Trunc.getType(); + Type *DestScalarTy = DestTy->getScalarType(); + Value *VecOp = InsElt->getOperand(0); + Value *ScalarOp = InsElt->getOperand(1); + Value *Index = InsElt->getOperand(2); + + if (isa<UndefValue>(VecOp)) { + // trunc (inselt undef, X, Index) --> inselt undef, (trunc X), Index + // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index + UndefValue *NarrowUndef = UndefValue::get(DestTy); + Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy); + return InsertElementInst::Create(NarrowUndef, NarrowOp, Index); + } + + return nullptr; +} + Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (Instruction *Result = commonCastTransforms(CI)) return Result; @@ -574,6 +606,9 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (Instruction *I = shrinkSplatShuffle(CI, *Builder)) return I; + if (Instruction *I = shrinkInsertElt(CI, *Builder)) + return I; + if (Src->hasOneUse() && isa<IntegerType>(SrcTy) && shouldChangeType(SrcTy, DestTy)) { // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the @@ -1426,6 +1461,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { } } + if (Instruction *I = shrinkInsertElt(CI, *Builder)) + return I; + return nullptr; } |

