Added packet primitives to compute exp, log, sqrt and rsqrt on fp16. This improves the performance by 10 to 30%.

This commit is contained in:
Benoit Steiner 2016-05-10 11:05:33 -07:00
parent 6bf8273bc0
commit 0b9e3dcd06

View File

@ -34,7 +34,11 @@ template<> struct packet_traits<half> : default_packet_traits
AlignedOnScalar = 1,
size=2,
HasHalfPacket = 0,
HasDiv = 1
HasDiv = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasExp = 1,
HasLog = 1
};
};
@ -267,6 +271,38 @@ template<> EIGEN_DEVICE_FUNC inline half predux_mul<half2>(const half2& a) {
#endif
}
template<> EIGEN_DEVICE_FUNC inline half2 plog<half2>(const half2& a) {
float a1 = __low2float(a);
float a2 = __high2float(a);
float r1 = logf(a1);
float r2 = logf(a2);
return __floats2half2_rn(r1, r2);
}
template<> EIGEN_DEVICE_FUNC inline half2 pexp<half2>(const half2& a) {
float a1 = __low2float(a);
float a2 = __high2float(a);
float r1 = expf(a1);
float r2 = expf(a2);
return __floats2half2_rn(r1, r2);
}
template<> EIGEN_DEVICE_FUNC inline half2 psqrt<half2>(const half2& a) {
float a1 = __low2float(a);
float a2 = __high2float(a);
float r1 = sqrtf(a1);
float r2 = sqrtf(a2);
return __floats2half2_rn(r1, r2);
}
template<> EIGEN_DEVICE_FUNC inline half2 prsqrt<half2>(const half2& a) {
float a1 = __low2float(a);
float a2 = __high2float(a);
float r1 = rsqrtf(a1);
float r2 = rsqrtf(a2);
return __floats2half2_rn(r1, r2);
}
} // end namespace internal
} // end namespace Eigen