mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 02:29:33 +08:00
Add log1p support for CUDA and half floats
This commit is contained in:
parent
72096f3bd4
commit
aee693ac52
@ -391,6 +391,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp(const half& a) {
|
|||||||
// Natural logarithm for half: the argument is converted to float, passed to
// ::logf, and the float result is converted back to half.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log(const half& a) {
  const float widened = float(a);
  const float result = ::logf(widened);
  return half(result);
}
|
||||||
|
// log(1 + a) for half: the argument is converted to float, passed to
// ::log1pf, and the float result is converted back to half.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) {
  const float widened = float(a);
  const float result = ::log1pf(widened);
  return half(result);
}
|
||||||
// Base-10 logarithm for half: the argument is converted to float, passed to
// ::log10f, and the float result is converted back to half.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) {
  const float widened = float(a);
  const float result = ::log10f(widened);
  return half(result);
}
|
||||||
|
@ -31,6 +31,18 @@ double2 plog<double2>(const double2& a)
|
|||||||
return make_double2(log(a.x), log(a.y));
|
return make_double2(log(a.x), log(a.y));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// float4 specialization of plog1p: applies log1pf independently to the
// x, y, z and w components and packs the four results with make_float4.
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
float4 plog1p<float4>(const float4& a)
{
  const float rx = log1pf(a.x);
  const float ry = log1pf(a.y);
  const float rz = log1pf(a.z);
  const float rw = log1pf(a.w);
  return make_float4(rx, ry, rz, rw);
}
|
||||||
|
|
||||||
|
// double2 specialization of plog1p: applies log1p independently to the
// x and y components and packs the two results with make_double2.
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
double2 plog1p<double2>(const double2& a)
{
  const double rx = log1p(a.x);
  const double ry = log1p(a.y);
  return make_double2(rx, ry);
}
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
float4 pexp<float4>(const float4& a)
|
float4 pexp<float4>(const float4& a)
|
||||||
{
|
{
|
||||||
|
@ -34,7 +34,8 @@ template<> struct packet_traits<Eigen::half> : default_packet_traits
|
|||||||
HasSqrt = 1,
|
HasSqrt = 1,
|
||||||
HasRsqrt = 1,
|
HasRsqrt = 1,
|
||||||
HasExp = 1,
|
HasExp = 1,
|
||||||
HasLog = 1
|
HasLog = 1,
|
||||||
|
HasLog1p = 1
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -273,6 +274,11 @@ half2 plog<half2>(const half2& a) {
|
|||||||
return h2log(a);
|
return h2log(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// half2 specialization of plog1p: computes log(1 + x) for both lanes of `a`.
//
// CUDA's half2 math intrinsics provide h2log, h2log2 and h2log10 but no
// h2log1p, so unlike plog/pexp there is no native packed call to forward to.
// Each lane is therefore unpacked to float, evaluated with log1pf, and the
// two results are repacked (round-to-nearest) into a half2.
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
half2 plog1p<half2>(const half2& a) {
  const float lo = log1pf(__low2float(a));   // low lane
  const float hi = log1pf(__high2float(a));  // high lane
  return __floats2half2_rn(lo, hi);
}
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
half2 pexp<half2>(const half2& a) {
|
half2 pexp<half2>(const half2& a) {
|
||||||
return h2exp(a);
|
return h2exp(a);
|
||||||
@ -298,6 +304,14 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog<half2>(const half2&
|
|||||||
return __floats2half2_rn(r1, r2);
|
return __floats2half2_rn(r1, r2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// half2 specialization of plog1p using float arithmetic: each lane is
// extracted as a float, evaluated with log1pf, and the two results are
// repacked into a half2 with __floats2half2_rn.
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog1p<half2>(const half2& a) {
  const float lo_result = log1pf(__low2float(a));
  const float hi_result = log1pf(__high2float(a));
  return __floats2half2_rn(lo_result, hi_result);
}
|
||||||
|
|
||||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp<half2>(const half2& a) {
|
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp<half2>(const half2& a) {
|
||||||
float a1 = __low2float(a);
|
float a1 = __low2float(a);
|
||||||
float a2 = __high2float(a);
|
float a2 = __high2float(a);
|
||||||
|
@ -189,6 +189,11 @@ void test_basic_functions()
|
|||||||
VERIFY_IS_EQUAL(float(log(half(1.0f))), 0.0f);
|
VERIFY_IS_EQUAL(float(log(half(1.0f))), 0.0f);
|
||||||
VERIFY_IS_APPROX(float(numext::log(half(10.0f))), 2.30273f);
|
VERIFY_IS_APPROX(float(numext::log(half(10.0f))), 2.30273f);
|
||||||
VERIFY_IS_APPROX(float(log(half(10.0f))), 2.30273f);
|
VERIFY_IS_APPROX(float(log(half(10.0f))), 2.30273f);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(float(numext::log1p(half(0.0f))), 0.0f);
|
||||||
|
VERIFY_IS_EQUAL(float(log1p(half(0.0f))), 0.0f);
|
||||||
|
VERIFY_IS_APPROX(float(numext::log1p(half(10.0f))), 2.3978953f);
|
||||||
|
VERIFY_IS_APPROX(float(log1p(half(10.0f))), 2.3978953f);
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_trigonometric_functions()
|
void test_trigonometric_functions()
|
||||||
|
@ -402,6 +402,7 @@ template<typename Scalar> void packetmath_real()
|
|||||||
data1[internal::random<int>(0, PacketSize)] = 0;
|
data1[internal::random<int>(0, PacketSize)] = 0;
|
||||||
CHECK_CWISE1_IF(PacketTraits::HasSqrt, std::sqrt, internal::psqrt);
|
CHECK_CWISE1_IF(PacketTraits::HasSqrt, std::sqrt, internal::psqrt);
|
||||||
CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog);
|
CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog);
|
||||||
|
CHECK_CWISE1_IF(PacketTraits::HasLog1p, std::log1p, internal::plog1p);
|
||||||
#if EIGEN_HAS_C99_MATH && (__cplusplus > 199711L)
|
#if EIGEN_HAS_C99_MATH && (__cplusplus > 199711L)
|
||||||
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasLGamma, std::lgamma, internal::plgamma);
|
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasLGamma, std::lgamma, internal::plgamma);
|
||||||
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErf, std::erf, internal::perf);
|
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErf, std::erf, internal::perf);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user