diff options
Diffstat (limited to 'llvm/lib/Transforms')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 73 | ||||
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineInternal.h | 1 |
2 files changed, 73 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 3139cb33045..a3f56bc8d95 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -443,13 +443,81 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt)); } +/// Rotate left/right may occur in a wider type than necessary because of type +/// promotion rules. Try to narrow all of the component instructions. +Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { + assert((isa<VectorType>(Trunc.getSrcTy()) || + shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && + "Don't narrow to an illegal scalar type"); + + // First, find an or'd pair of opposite shifts with the same shifted operand: + // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)) + Value *Or0, *Or1; + if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) + return nullptr; + + Value *ShVal, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1))))) + return nullptr; + + auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); + auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); + if (ShiftOpcode0 == ShiftOpcode1) + return nullptr; + + // The shift amounts must add up to the narrow bit width. 
+ Value *ShAmt; + bool SubIsOnLHS; + Type *DestTy = Trunc.getType(); + unsigned NarrowWidth = DestTy->getScalarSizeInBits(); + if (match(ShAmt0, + m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) { + ShAmt = ShAmt1; + SubIsOnLHS = true; + } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), + m_Specific(ShAmt0))))) { + ShAmt = ShAmt0; + SubIsOnLHS = false; + } else { + return nullptr; + } + + // The shifted value must have high zeros in the wide type. Typically, this + // will be a zext, but it could also be the result of an 'and' or 'shift'. + unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits(); + APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth); + if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc)) + return nullptr; + + // We have an unnecessarily wide rotate! + // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt)) + // Narrow it down to eliminate the zext/trunc: + // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1') + Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); + Value *NegShAmt = Builder.CreateNeg(NarrowShAmt); + + // Mask both shift amounts to ensure there's no UB from oversized shifts. + Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1); + Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC); + Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC); + + // Truncate the original value and use narrow ops. + Value *X = Builder.CreateTrunc(ShVal, DestTy); + Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt; + Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt; + Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0); + Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1); + return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1); +} + /// Try to narrow the width of math or bitwise logic instructions by pulling a /// truncate ahead of binary operators. 
/// TODO: Transforms for truncated shifts should be moved into here. Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) { Type *SrcTy = Trunc.getSrcTy(); Type *DestTy = Trunc.getType(); - if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) + if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) return nullptr; BinaryOperator *BinOp; @@ -485,6 +553,9 @@ Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) { default: break; } + if (Instruction *NarrowOr = narrowRotate(Trunc)) + return NarrowOr; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index f550dd539fd..e6d5d1c0e4d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -440,6 +440,7 @@ private: Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); Instruction *narrowBinOp(TruncInst &Trunc); + Instruction *narrowRotate(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); /// Determine if a pair of casts can be replaced by a single cast. |

