diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 96 |
1 files changed, 92 insertions, 4 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 1ef4acfb058..bc79e4534e6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2434,6 +2434,77 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return nullptr; } +bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, + Value *&RHS, ConstantInt *&Less, + ConstantInt *&Equal, + ConstantInt *&Greater) { + // TODO: Generalize this to work with other comparison idioms or ensure + // they get canonicalized into this form. + + // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32 + // Greater), where Equal, Less and Greater are placeholders for any three + // constants. + ICmpInst::Predicate PredA, PredB; + if (match(SI->getTrueValue(), m_ConstantInt(Equal)) && + match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) && + PredA == ICmpInst::ICMP_EQ && + match(SI->getFalseValue(), + m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)), + m_ConstantInt(Less), m_ConstantInt(Greater))) && + PredB == ICmpInst::ICMP_SLT) { + return true; + } + return false; +} + +Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, + Instruction *Select, + ConstantInt *C) { + + assert(C && "Cmp RHS should be a constant int!"); + // If we're testing a constant value against the result of a three way + // comparison, the result can be expressed directly in terms of the + // original values being compared. Note: We could possibly be more + // aggressive here and remove the hasOneUse test. The original select is + // really likely to simplify or sink when we remove a test of the result. + Value *OrigLHS, *OrigRHS; + ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan; + if (Cmp.hasOneUse() && + matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS, + C1LessThan, C2Equal, C3GreaterThan)) { + assert(C1LessThan && C2Equal && C3GreaterThan); + + bool TrueWhenLessThan = + ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C) + ->isAllOnesValue(); + bool TrueWhenEqual = + ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C) + ->isAllOnesValue(); + bool TrueWhenGreaterThan = + ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C) + ->isAllOnesValue(); + + // This generates the new instruction that will replace the original Cmp + // Instruction. Instead of enumerating the various combinations when + // TrueWhenLessThan, TrueWhenEqual and TrueWhenGreaterThan are true versus + // false, we rely on chaining of ORs and future passes of InstCombine to + // simplify the OR further (i.e. a s< b || a == b becomes a s<= b). + + // When none of the three constants satisfy the predicate for the RHS (C), + // the entire original Cmp can be simplified to a false. + Value *Cond = Builder->getFalse(); + if (TrueWhenLessThan) + Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS)); + if (TrueWhenEqual) + Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS)); + if (TrueWhenGreaterThan) + Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS)); + + return replaceInstUsesWith(Cmp, Cond); + } + return nullptr; +} + /// Try to fold integer comparisons with a constant operand: icmp Pred X, C /// where X is some kind of instruction. Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { @@ -2493,11 +2564,28 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { return I; } + // Match against CmpInst LHS being instructions other than binary operators. Instruction *LHSI; - if (match(Cmp.getOperand(0), m_Instruction(LHSI)) && - LHSI->getOpcode() == Instruction::Trunc) - if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) - return I; + if (match(Cmp.getOperand(0), m_Instruction(LHSI))) { + switch (LHSI->getOpcode()) { + case Instruction::Select: + { + // For now, we only support constant integers while folding the + // ICMP(SELECT)) pattern. We can extend this to support vector of integers + // similar to the cases handled by binary ops above. + if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) + if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS)) + return I; + break; + } + case Instruction::Trunc: + if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + return I; + break; + default: + break; + } + } if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) return I; |