diff options
-rw-r--r-- | llvm/include/llvm/IR/PatternMatch.h | 78 | ||||
-rw-r--r-- | llvm/unittests/IR/PatternMatch.cpp | 108 |
2 files changed, 186 insertions, 0 deletions
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 5ac5aab4a0b..304b84bc85d 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -979,6 +979,84 @@ m_SelectCst(const Cond &C) { } //===----------------------------------------------------------------------===// +// Matchers for InsertElementInst classes +// + +template <typename Val_t, typename Elt_t, typename Idx_t> +struct InsertElementClass_match { + Val_t V; + Elt_t E; + Idx_t I; + + InsertElementClass_match(const Val_t &Val, const Elt_t &Elt, const Idx_t &Idx) + : V(Val), E(Elt), I(Idx) {} + + template <typename OpTy> bool match(OpTy *VV) { + if (auto *II = dyn_cast<InsertElementInst>(VV)) + return V.match(II->getOperand(0)) && E.match(II->getOperand(1)) && + I.match(II->getOperand(2)); + return false; + } +}; + +template <typename Val_t, typename Elt_t, typename Idx_t> +inline InsertElementClass_match<Val_t, Elt_t, Idx_t> +m_InsertElement(const Val_t &Val, const Elt_t &Elt, const Idx_t &Idx) { + return InsertElementClass_match<Val_t, Elt_t, Idx_t>(Val, Elt, Idx); +} + +//===----------------------------------------------------------------------===// +// Matchers for ExtractElementInst classes +// + +template <typename Val_t, typename Idx_t> struct ExtractElementClass_match { + Val_t V; + Idx_t I; + + ExtractElementClass_match(const Val_t &Val, const Idx_t &Idx) + : V(Val), I(Idx) {} + + template <typename OpTy> bool match(OpTy *VV) { + if (auto *II = dyn_cast<ExtractElementInst>(VV)) + return V.match(II->getOperand(0)) && I.match(II->getOperand(1)); + return false; + } +}; + +template <typename Val_t, typename Idx_t> +inline ExtractElementClass_match<Val_t, Idx_t> +m_ExtractElement(const Val_t &Val, const Idx_t &Idx) { + return ExtractElementClass_match<Val_t, Idx_t>(Val, Idx); +} + +//===----------------------------------------------------------------------===// +// Matchers for ShuffleVectorInst classes +// + +template <typename V1_t, typename V2_t, typename Mask_t> +struct ShuffleVectorClass_match { + V1_t V1; + V2_t V2; + Mask_t M; + + ShuffleVectorClass_match(const V1_t &v1, const V2_t &v2, const Mask_t &m) + : V1(v1), V2(v2), M(m) {} + + template <typename OpTy> bool match(OpTy *V) { + if (auto *SI = dyn_cast<ShuffleVectorInst>(V)) + return V1.match(SI->getOperand(0)) && V2.match(SI->getOperand(1)) && + M.match(SI->getOperand(2)); + return false; + } +}; + +template <typename V1_t, typename V2_t, typename Mask_t> +inline ShuffleVectorClass_match<V1_t, V2_t, Mask_t> +m_ShuffleVector(const V1_t &v1, const V2_t &v2, const Mask_t &m) { + return ShuffleVectorClass_match<V1_t, V2_t, Mask_t>(v1, v2, m); +} + +//===----------------------------------------------------------------------===// // Matchers for CastInst classes // diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp index 5c13ba6ecd9..6bdcf4de7e2 100644 --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -340,6 +340,114 @@ TEST_F(PatternMatchTest, OverflowingBinOps) { EXPECT_FALSE(m_NUWShl(m_Value(), m_Value()).match(IRB.CreateNUWAdd(L, R))); } +TEST_F(PatternMatchTest, VectorOps) { + // Build up small tree of vector operations + // + // Val = 0 + 1 + // Val2 = Val + 3 + // VI1 = insertelement <2 x i8> undef, i8 1, i32 0 = <1, undef> + // VI2 = insertelement <2 x i8> %VI1, i8 %Val2, i8 %Val = <1, 4> + // VI3 = insertelement <2 x i8> %VI1, i8 %Val2, i32 1 = <1, 4> + // VI4 = insertelement <2 x i8> %VI1, i8 2, i8 %Val = <1, 2> + // + // SI1 = shufflevector <2 x i8> %VI1, <2 x i8> undef, zeroinitializer + // SI2 = shufflevector <2 x i8> %VI3, <2 x i8> %VI4, <2 x i8> <i8 0, i8 2> + // SI3 = shufflevector <2 x i8> %VI3, <2 x i8> undef, zeroinitializer + // SI4 = shufflevector <2 x i8> %VI4, <2 x i8> undef, zeroinitializer + // + // SP1 = VectorSplat(2, i8 2) + // SP2 = VectorSplat(2, i8 %Val) + Type *VecTy = VectorType::get(IRB.getInt8Ty(), 2); + Type *i32 = IRB.getInt32Ty(); + Type *i32VecTy = VectorType::get(i32, 2); + + Value *Val = IRB.CreateAdd(IRB.getInt8(0), IRB.getInt8(1)); + Value *Val2 = IRB.CreateAdd(Val, IRB.getInt8(3)); + + SmallVector<Constant *, 2> VecElemIdxs; + VecElemIdxs.push_back(ConstantInt::get(i32, 0)); + VecElemIdxs.push_back(ConstantInt::get(i32, 2)); + auto *IdxVec = ConstantVector::get(VecElemIdxs); + + Value *UndefVec = UndefValue::get(VecTy); + Value *VI1 = IRB.CreateInsertElement(UndefVec, IRB.getInt8(1), (uint64_t)0); + Value *VI2 = IRB.CreateInsertElement(VI1, Val2, Val); + Value *VI3 = IRB.CreateInsertElement(VI1, Val2, (uint64_t)1); + Value *VI4 = IRB.CreateInsertElement(VI1, IRB.getInt8(2), Val); + + Value *EX1 = IRB.CreateExtractElement(VI4, Val); + Value *EX2 = IRB.CreateExtractElement(VI4, (uint64_t)0); + Value *EX3 = IRB.CreateExtractElement(IdxVec, (uint64_t)1); + + Value *Zero = ConstantAggregateZero::get(i32VecTy); + Value *SI1 = IRB.CreateShuffleVector(VI1, UndefVec, Zero); + Value *SI2 = IRB.CreateShuffleVector(VI3, VI4, IdxVec); + Value *SI3 = IRB.CreateShuffleVector(VI3, UndefVec, Zero); + Value *SI4 = IRB.CreateShuffleVector(VI4, UndefVec, Zero); + + Value *SP1 = IRB.CreateVectorSplat(2, IRB.getInt8(2)); + Value *SP2 = IRB.CreateVectorSplat(2, Val); + + Value *A = nullptr, *B = nullptr, *C = nullptr; + + // Test matching insertelement + EXPECT_TRUE(match(VI1, m_InsertElement(m_Value(), m_Value(), m_Value()))); + EXPECT_TRUE( + match(VI1, m_InsertElement(m_Undef(), m_ConstantInt(), m_ConstantInt()))); + EXPECT_TRUE( + match(VI1, m_InsertElement(m_Undef(), m_ConstantInt(), m_Zero()))); + EXPECT_TRUE( + match(VI1, m_InsertElement(m_Undef(), m_SpecificInt(1), m_Zero()))); + EXPECT_TRUE(match(VI2, m_InsertElement(m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE( + match(VI2, m_InsertElement(m_Value(), m_Value(), m_ConstantInt()))); + EXPECT_FALSE( + match(VI2, m_InsertElement(m_Value(), m_ConstantInt(), m_Value()))); + EXPECT_FALSE(match(VI2, m_InsertElement(m_Constant(), m_Value(), m_Value()))); + EXPECT_TRUE(match(VI3, m_InsertElement(m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_TRUE(A == VI1); + EXPECT_TRUE(B == Val2); + EXPECT_TRUE(isa<ConstantInt>(C)); + A = B = C = nullptr; // reset + + // Test matching extractelement + EXPECT_TRUE(match(EX1, m_ExtractElement(m_Value(A), m_Value(B)))); + EXPECT_TRUE(A == VI4); + EXPECT_TRUE(B == Val); + A = B = C = nullptr; // reset + EXPECT_FALSE(match(EX1, m_ExtractElement(m_Value(), m_ConstantInt()))); + EXPECT_TRUE(match(EX2, m_ExtractElement(m_Value(), m_ConstantInt()))); + EXPECT_TRUE(match(EX3, m_ExtractElement(m_Constant(), m_ConstantInt()))); + + // Test matching shufflevector + EXPECT_TRUE(match(SI1, m_ShuffleVector(m_Value(), m_Undef(), m_Zero()))); + EXPECT_TRUE(match(SI2, m_ShuffleVector(m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_TRUE(A == VI3); + EXPECT_TRUE(B == VI4); + EXPECT_TRUE(C == IdxVec); + A = B = C = nullptr; // reset + + // Test matching the vector splat pattern + EXPECT_TRUE(match( + SI1, + m_ShuffleVector(m_InsertElement(m_Undef(), m_SpecificInt(1), m_Zero()), + m_Undef(), m_Zero()))); + EXPECT_FALSE(match( + SI3, m_ShuffleVector(m_InsertElement(m_Undef(), m_Value(), m_Zero()), + m_Undef(), m_Zero()))); + EXPECT_FALSE(match( + SI4, m_ShuffleVector(m_InsertElement(m_Undef(), m_Value(), m_Zero()), + m_Undef(), m_Zero()))); + EXPECT_TRUE(match( + SP1, + m_ShuffleVector(m_InsertElement(m_Undef(), m_SpecificInt(2), m_Zero()), + m_Undef(), m_Zero()))); + EXPECT_TRUE(match( + SP2, m_ShuffleVector(m_InsertElement(m_Undef(), m_Value(A), m_Zero()), + m_Undef(), m_Zero()))); + EXPECT_TRUE(A == Val); +} + template <typename T> struct MutableConstTest : PatternMatchTest { }; typedef ::testing::Types<std::tuple<Value*, Instruction*>, |