diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index eea890c815c..f35a6ce2d93 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -296,6 +296,49 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return BinaryOperator::Create(Instruction::And, NewShift, NewMask); } +/// If we have a shift-by-constant of a bitwise logic op that itself has a +/// shift-by-constant operand with identical opcode, we may be able to convert +/// that into 2 independent shifts followed by the logic op. This eliminates a +/// a use of an intermediate value (reduces dependency chain). +static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.isShift() && "Expected a shift as input"); + auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0)); + if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse()) + return nullptr; + + const APInt *C0, *C1; + if (!match(I.getOperand(1), m_APInt(C1))) + return nullptr; + + Instruction::BinaryOps ShiftOpcode = I.getOpcode(); + Type *Ty = I.getType(); + + // Find a matching one-use shift by constant. The fold is not valid if the sum + // of the shift values equals or exceeds bitwidth. + // TODO: Remove the one-use check if the other logic operand (Y) is constant. + Value *X, *Y; + auto matchFirstShift = [&](Value *V) { + return match(V, m_OneUse(m_Shift(m_Value(X), m_APInt(C0)))) && + cast<BinaryOperator>(V)->getOpcode() == ShiftOpcode && + (*C0 + *C1).ult(Ty->getScalarSizeInBits()); + }; + + // Logic ops are commutative, so check each operand for a match. + if (matchFirstShift(LogicInst->getOperand(0))) + Y = LogicInst->getOperand(1); + else if (matchFirstShift(LogicInst->getOperand(1))) + Y = LogicInst->getOperand(0); + else + return nullptr; + + // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) + Constant *ShiftSumC = ConstantInt::get(Ty, *C0 + *C1); + Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); + Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1)); + return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); +} + Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); assert(Op0->getType() == Op1->getType()); @@ -348,6 +391,9 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { return &I; } + if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder)) + return Logic; + return nullptr; } |