diff options
Diffstat (limited to 'libcxx')
-rw-r--r-- | libcxx/include/numeric | 29 | ||||
-rw-r--r-- | libcxx/test/std/numerics/numeric.ops/numeric.ops.midpoint/midpoint.float.pass.cpp | 21 |
2 files changed, 34 insertions, 16 deletions
diff --git a/libcxx/include/numeric b/libcxx/include/numeric index ba2fe2696a9..2118704d57f 100644 --- a/libcxx/include/numeric +++ b/libcxx/include/numeric @@ -460,10 +460,10 @@ iota(_ForwardIterator __first, _ForwardIterator __last, _Tp __value_) #if _LIBCPP_STD_VER > 14 -template <typename _Result, typename _Source, bool _IsSigned = is_signed<_Source>::value> struct __abs; +template <typename _Result, typename _Source, bool _IsSigned = is_signed<_Source>::value> struct __ct_abs; template <typename _Result, typename _Source> -struct __abs<_Result, _Source, true> { +struct __ct_abs<_Result, _Source, true> { _LIBCPP_CONSTEXPR _LIBCPP_INLINE_VISIBILITY _Result operator()(_Source __t) const noexcept { @@ -474,7 +474,7 @@ struct __abs<_Result, _Source, true> { }; template <typename _Result, typename _Source> -struct __abs<_Result, _Source, false> { +struct __ct_abs<_Result, _Source, false> { _LIBCPP_CONSTEXPR _LIBCPP_INLINE_VISIBILITY _Result operator()(_Source __t) const noexcept { return __t; } }; @@ -500,8 +500,8 @@ gcd(_Tp __m, _Up __n) using _Rp = common_type_t<_Tp,_Up>; using _Wp = make_unsigned_t<_Rp>; return static_cast<_Rp>(_VSTD::__gcd( - static_cast<_Wp>(__abs<_Rp, _Tp>()(__m)), - static_cast<_Wp>(__abs<_Rp, _Up>()(__n)))); + static_cast<_Wp>(__ct_abs<_Rp, _Tp>()(__m)), + static_cast<_Wp>(__ct_abs<_Rp, _Up>()(__n)))); } template<class _Tp, class _Up> @@ -516,8 +516,8 @@ lcm(_Tp __m, _Up __n) return 0; using _Rp = common_type_t<_Tp,_Up>; - _Rp __val1 = __abs<_Rp, _Tp>()(__m) / _VSTD::gcd(__m, __n); - _Rp __val2 = __abs<_Rp, _Up>()(__n); + _Rp __val1 = __ct_abs<_Rp, _Tp>()(__m) / _VSTD::gcd(__m, __n); + _Rp __val2 = __ct_abs<_Rp, _Up>()(__n); _LIBCPP_ASSERT((numeric_limits<_Rp>::max() / __val1 > __val2), "Overflow in lcm"); return __val1 * __val2; } @@ -563,16 +563,23 @@ constexpr int __sign(_Tp __val) { return (_Tp(0) < __val) - (__val < _Tp(0)); } +template <typename _Fp> +constexpr _Fp __fp_abs(_Fp __f) { return __f >= 0 ? __f : -__f; } + template <class _Fp> _LIBCPP_INLINE_VISIBILITY constexpr enable_if_t<is_floating_point_v<_Fp>, _Fp> midpoint(_Fp __a, _Fp __b) noexcept { - return isnormal(__a) && isnormal(__b) - && ((__sign(__a) != __sign(__b)) || ((numeric_limits<_Fp>::max() - abs(__a)) < abs(__b))) - ? __a / 2 + __b / 2 - : (__a + __b) / 2; + constexpr _Fp __lo = numeric_limits<_Fp>::min()*2; + constexpr _Fp __hi = numeric_limits<_Fp>::max()/2; + return __fp_abs(__a) <= __hi && __fp_abs(__b) <= __hi ? // typical case: overflow is impossible + (__a + __b)/2 : // always correctly rounded + __fp_abs(__a) < __lo ? __a + __b/2 : // not safe to halve a + __fp_abs(__a) < __lo ? __a/2 + __b : // not safe to halve b + __a/2 + __b/2; // otherwise correctly rounded } + #endif // _LIBCPP_STD_VER > 17 _LIBCPP_END_NAMESPACE_STD diff --git a/libcxx/test/std/numerics/numeric.ops/numeric.ops.midpoint/midpoint.float.pass.cpp b/libcxx/test/std/numerics/numeric.ops/numeric.ops.midpoint/midpoint.float.pass.cpp index 1a967f7a26e..7d98348881b 100644 --- a/libcxx/test/std/numerics/numeric.ops/numeric.ops.midpoint/midpoint.float.pass.cpp +++ b/libcxx/test/std/numerics/numeric.ops/numeric.ops.midpoint/midpoint.float.pass.cpp @@ -43,11 +43,11 @@ void fp_test() constexpr T minV = std::numeric_limits<T>::min(); // Things that can be compared exactly - assert((std::midpoint(T(0), T(0)) == T(0))); - assert((std::midpoint(T(2), T(4)) == T(3))); - assert((std::midpoint(T(4), T(2)) == T(3))); - assert((std::midpoint(T(3), T(4)) == T(3.5))); - assert((std::midpoint(T(0), T(0.4)) == T(0.2))); + static_assert((std::midpoint(T(0), T(0)) == T(0)), ""); + static_assert((std::midpoint(T(2), T(4)) == T(3)), ""); + static_assert((std::midpoint(T(4), T(2)) == T(3)), ""); + static_assert((std::midpoint(T(3), T(4)) == T(3.5)), ""); + static_assert((std::midpoint(T(0), T(0.4)) == T(0.2)), ""); // Things that can't be compared exactly constexpr T pct = fp_error_pct<T>(); @@ -70,6 +70,17 @@ void fp_test() assert((fptest_close_pct(std::midpoint(T(0), minV), minV/2, pct))); assert((fptest_close_pct(std::midpoint(maxV, maxV), maxV, pct))); assert((fptest_close_pct(std::midpoint(minV, minV), minV, pct))); + assert((fptest_close_pct(std::midpoint(maxV, minV), maxV/2, pct))); + assert((fptest_close_pct(std::midpoint(minV, maxV), maxV/2, pct))); + +// Near the min and the max + assert((fptest_close_pct(std::midpoint(maxV*T(0.75), maxV*T(0.50)), maxV*T(0.625), pct))); + assert((fptest_close_pct(std::midpoint(maxV*T(0.50), maxV*T(0.75)), maxV*T(0.625), pct))); + assert((fptest_close_pct(std::midpoint(minV*T(2), minV*T(8)), minV*T(5), pct))); + +// Big numbers of different signs + assert((fptest_close_pct(std::midpoint(maxV*T( 0.75), maxV*T(-0.5)), maxV*T( 0.125), pct))); + assert((fptest_close_pct(std::midpoint(maxV*T(-0.75), maxV*T( 0.5)), maxV*T(-0.125), pct))); // Denormalized values // TODO |