diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index b318d0b2de9..5e0a3c37979 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -4026,7 +4026,8 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { return nullptr; } -static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp) { +static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, + InstCombiner::BuilderTy &Builder) { assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0"); auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0)); Value *X; @@ -4038,15 +4039,25 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp) { if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) { // If the signedness of the two casts doesn't agree (i.e. one is a sext // and the other is a zext), then we can't handle this. + // TODO: This is too strict. We can handle some predicates (equality?). if (CastOp0->getOpcode() != CastOp1->getOpcode()) return nullptr; // Not an extension from the same type? - // TODO: Handle this by extending the narrower operand to the type of - // the wider operand. Value *Y = CastOp1->getOperand(0); - if (X->getType() != Y->getType()) - return nullptr; + Type *XTy = X->getType(), *YTy = Y->getType(); + if (XTy != YTy) { + // One of the casts must have one use because we are creating a new cast. + if (!CastOp0->hasOneUse() && !CastOp1->hasOneUse()) + return nullptr; + // Extend the narrower operand to the type of the wider operand. + if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits()) + X = Builder.CreateCast(CastOp0->getOpcode(), X, YTy); + else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits()) + Y = Builder.CreateCast(CastOp0->getOpcode(), Y, XTy); + else + return nullptr; + } // (zext X) == (zext Y) --> X == Y // (sext X) == (sext Y) --> X == Y @@ -4148,7 +4159,7 @@ Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) { return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); } - return foldICmpWithZextOrSext(ICmp); + return foldICmpWithZextOrSext(ICmp, Builder); } static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { |