diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 96b8b4ffac6..7f5fb926440 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1289,6 +1289,63 @@ static Instruction *foldSelectCmpXchg(SelectInst &SI) { return nullptr; } +/// Reduce a sequence of min/max with a common operand. +static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, + Value *RHS, + InstCombiner::BuilderTy &Builder) { + assert(SelectPatternResult::isMinOrMax(SPF) && "Expected a min/max"); + // TODO: Allow FP min/max with nnan/nsz. + if (!LHS->getType()->isIntOrIntVectorTy()) + return nullptr; + + // Match 3 of the same min/max ops. Example: umin(umin(), umin()). + Value *A, *B, *C, *D; + SelectPatternResult L = matchSelectPattern(LHS, A, B); + SelectPatternResult R = matchSelectPattern(RHS, C, D); + if (SPF != L.Flavor || L.Flavor != R.Flavor) + return nullptr; + + // Look for a common operand. The use checks are different than usual because + // a min/max pattern typically has 2 uses of each op: 1 by the cmp and 1 by + // the select. + Value *MinMaxOp = nullptr; + Value *ThirdOp = nullptr; + if (LHS->getNumUses() <= 2 && RHS->getNumUses() > 2) { + // If the LHS is only used in this chain and the RHS is used outside of it, + // reuse the RHS min/max because that will eliminate the LHS. + if (D == A || C == A) { + // min(min(a, b), min(c, a)) --> min(min(c, a), b) + // min(min(a, b), min(a, d)) --> min(min(a, d), b) + MinMaxOp = RHS; + ThirdOp = B; + } else if (D == B || C == B) { + // min(min(a, b), min(c, b)) --> min(min(c, b), a) + // min(min(a, b), min(b, d)) --> min(min(b, d), a) + MinMaxOp = RHS; + ThirdOp = A; + } + } else if (RHS->getNumUses() <= 2) { + // Reuse the LHS. This will eliminate the RHS. + if (D == A || D == B) { + // min(min(a, b), min(c, a)) --> min(min(a, b), c) + // min(min(a, b), min(c, b)) --> min(min(a, b), c) + MinMaxOp = LHS; + ThirdOp = C; + } else if (C == A || C == B) { + // min(min(a, b), min(b, d)) --> min(min(a, b), d) + // min(min(a, b), min(c, b)) --> min(min(a, b), d) + MinMaxOp = LHS; + ThirdOp = D; + } + } + if (!MinMaxOp || !ThirdOp) + return nullptr; + + CmpInst::Predicate P = getCmpPredicateForMinMax(SPF); + Value *CmpABC = Builder.CreateICmp(P, MinMaxOp, ThirdOp); + return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -1563,6 +1620,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NewSel = Builder.CreateSelect(InvertedCmp, A, B); return BinaryOperator::CreateNot(NewSel); } + + if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) + return I; } if (SPF) { |