diff options
-rw-r--r-- | llvm/lib/IR/Constants.cpp | 13 | ||||
-rw-r--r-- | llvm/unittests/IR/ConstantsTest.cpp | 39 |
2 files changed, 48 insertions, 4 deletions
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index cafb412b795..054375aab6c 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -280,12 +280,17 @@ bool Constant::isElementWiseEqual(Value *Y) const { // Are they fully identical? if (this == Y) return true; - // They may still be identical element-wise (if they have `undef`s). - auto *Cy = dyn_cast<Constant>(Y); - if (!Cy) + + // The input value must be a vector constant with the same type. + Type *Ty = getType(); + if (!isa<Constant>(Y) || !Ty->isVectorTy() || Ty != Y->getType()) return false; + + // They may still be identical element-wise (if they have `undef`s). + // FIXME: This crashes on FP vector constants. return match(ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_EQ, - const_cast<Constant *>(this), Cy), + const_cast<Constant *>(this), + cast<Constant>(Y)), m_One()); } diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp index 4a8fcaa48e3..8a3336210e2 100644 --- a/llvm/unittests/IR/ConstantsTest.cpp +++ b/llvm/unittests/IR/ConstantsTest.cpp @@ -585,5 +585,44 @@ TEST(ConstantsTest, FoldGlobalVariablePtr) { Instruction::And, TheConstantExpr, TheConstant)->isNullValue()); } +// Check that undefined elements in vector constants are matched +// correctly for both integer and floating-point types. + +TEST(ConstantsTest, isElementWiseEqual) { + LLVMContext Context; + + Type *Int32Ty = Type::getInt32Ty(Context); + Constant *CU = UndefValue::get(Int32Ty); + Constant *C1 = ConstantInt::get(Int32Ty, 1); + Constant *C2 = ConstantInt::get(Int32Ty, 2); + + Constant *C1211 = ConstantVector::get({C1, C2, C1, C1}); + Constant *C12U1 = ConstantVector::get({C1, C2, CU, C1}); + Constant *C12U2 = ConstantVector::get({C1, C2, CU, C2}); + Constant *C12U21 = ConstantVector::get({C1, C2, CU, C2, C1}); + + EXPECT_TRUE(C1211->isElementWiseEqual(C12U1)); + EXPECT_TRUE(C12U1->isElementWiseEqual(C1211)); + EXPECT_FALSE(C12U2->isElementWiseEqual(C12U1)); + EXPECT_FALSE(C12U1->isElementWiseEqual(C12U2)); + EXPECT_FALSE(C12U21->isElementWiseEqual(C12U2)); + +/* FIXME: This will crash. + Type *FltTy = Type::getFloatTy(Context); + Constant *CFU = UndefValue::get(FltTy); + Constant *CF1 = ConstantFP::get(FltTy, 1.0); + Constant *CF2 = ConstantFP::get(FltTy, 2.0); + + Constant *CF1211 = ConstantVector::get({CF1, CF2, CF1, CF1}); + Constant *CF12U1 = ConstantVector::get({CF1, CF2, CFU, CF1}); + Constant *CF12U2 = ConstantVector::get({CF1, CF2, CFU, CF2}); + + EXPECT_TRUE(CF1211->isElementWiseEqual(CF12U1)); + EXPECT_TRUE(CF12U1->isElementWiseEqual(CF1211)); + EXPECT_FALSE(CF12U2->isElementWiseEqual(CF12U1)); + EXPECT_FALSE(CF12U1->isElementWiseEqual(CF12U2)); +*/ +} + } // end anonymous namespace } // end namespace llvm |