diff options
-rw-r--r-- | llvm/include/llvm/ADT/APInt.h | 2 | ||||
-rw-r--r-- | llvm/lib/Support/APInt.cpp | 21 | ||||
-rw-r--r-- | llvm/unittests/ADT/APIntTest.cpp | 18 |
3 files changed, 41 insertions, 0 deletions
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 8dce5a621bb..ddf0e19ffde 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -1109,6 +1109,8 @@ public: APInt uadd_sat(const APInt &RHS) const; APInt ssub_sat(const APInt &RHS) const; APInt usub_sat(const APInt &RHS) const; + APInt smul_sat(const APInt &RHS) const; + APInt umul_sat(const APInt &RHS) const; /// Array-indexing support. /// diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index 758fe8b4f86..df2e6197b2a 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -2048,6 +2048,27 @@ APInt APInt::usub_sat(const APInt &RHS) const { return APInt(BitWidth, 0); } +APInt APInt::smul_sat(const APInt &RHS) const { + bool Overflow; + APInt Res = smul_ov(RHS, Overflow); + if (!Overflow) + return Res; + + // The result is negative if one and only one of inputs is negative. + bool ResIsNegative = isNegative() ^ RHS.isNegative(); + + return ResIsNegative ? APInt::getSignedMinValue(BitWidth) + : APInt::getSignedMaxValue(BitWidth); +} + +APInt APInt::umul_sat(const APInt &RHS) const { + bool Overflow; + APInt Res = umul_ov(RHS, Overflow); + if (!Overflow) + return Res; + + return APInt::getMaxValue(BitWidth); +} void APInt::fromString(unsigned numbits, StringRef str, uint8_t radix) { // Check our assumptions here diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index a58d31439e7..f45a4e6b52c 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -1197,6 +1197,24 @@ TEST(APIntTest, SaturatingMath) { EXPECT_EQ(APInt(8, 127), AP_100.ssub_sat(-AP_100)); EXPECT_EQ(APInt(8, -128), (-AP_100).ssub_sat(AP_100)); EXPECT_EQ(APInt(8, -128), APInt(8, -128).ssub_sat(APInt(8, 127))); + + EXPECT_EQ(APInt(8, 250), APInt(8, 50).umul_sat(APInt(8, 5))); + EXPECT_EQ(APInt(8, 255), APInt(8, 50).umul_sat(APInt(8, 6))); + EXPECT_EQ(APInt(8, 255), APInt(8, -128).umul_sat(APInt(8, 3))); + EXPECT_EQ(APInt(8, 255), APInt(8, 3).umul_sat(APInt(8, -128))); + EXPECT_EQ(APInt(8, 255), APInt(8, -128).umul_sat(APInt(8, -128))); + + EXPECT_EQ(APInt(8, 125), APInt(8, 25).smul_sat(APInt(8, 5))); + EXPECT_EQ(APInt(8, 127), APInt(8, 25).smul_sat(APInt(8, 6))); + EXPECT_EQ(APInt(8, 127), APInt(8, 127).smul_sat(APInt(8, 127))); + EXPECT_EQ(APInt(8, -125), APInt(8, -25).smul_sat(APInt(8, 5))); + EXPECT_EQ(APInt(8, -125), APInt(8, 25).smul_sat(APInt(8, -5))); + EXPECT_EQ(APInt(8, 125), APInt(8, -25).smul_sat(APInt(8, -5))); + EXPECT_EQ(APInt(8, 125), APInt(8, 25).smul_sat(APInt(8, 5))); + EXPECT_EQ(APInt(8, -128), APInt(8, -25).smul_sat(APInt(8, 6))); + EXPECT_EQ(APInt(8, -128), APInt(8, 25).smul_sat(APInt(8, -6))); + EXPECT_EQ(APInt(8, 127), APInt(8, -25).smul_sat(APInt(8, -6))); + EXPECT_EQ(APInt(8, 127), APInt(8, 25).smul_sat(APInt(8, 6))); } TEST(APIntTest, FromArray) { |