Made it possible to run the lgamma, erf, and erfc functors on a CUDA GPU.

This commit is contained in:
Benoit Steiner 2015-12-21 15:20:06 -08:00
parent 1c3e78319d
commit 3504ae47ca

View File

@@ -415,7 +415,7 @@ template<typename Scalar> struct scalar_lgamma_op {
using numext::lgamma; return lgamma(a); using numext::lgamma; return lgamma(a);
} }
typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::type Packet;
inline Packet packetOp(const Packet& a) const { return internal::plgamma(a); } EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::plgamma(a); }
}; };
template<typename Scalar> template<typename Scalar>
struct functor_traits<scalar_lgamma_op<Scalar> > struct functor_traits<scalar_lgamma_op<Scalar> >
@@ -438,7 +438,7 @@ template<typename Scalar> struct scalar_erf_op {
using numext::erf; return erf(a); using numext::erf; return erf(a);
} }
typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::type Packet;
inline Packet packetOp(const Packet& a) const { return internal::perf(a); } EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::perf(a); }
}; };
template<typename Scalar> template<typename Scalar>
struct functor_traits<scalar_erf_op<Scalar> > struct functor_traits<scalar_erf_op<Scalar> >
@@ -461,7 +461,7 @@ template<typename Scalar> struct scalar_erfc_op {
using numext::erfc; return erfc(a); using numext::erfc; return erfc(a);
} }
typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::type Packet;
inline Packet packetOp(const Packet& a) const { return internal::perfc(a); } EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::perfc(a); }
}; };
template<typename Scalar> template<typename Scalar>
struct functor_traits<scalar_erfc_op<Scalar> > struct functor_traits<scalar_erfc_op<Scalar> >
@@ -732,10 +732,10 @@ struct functor_traits<scalar_boolean_not_op<Scalar> > {
* \sa class CwiseUnaryOp, Cwise::sign() * \sa class CwiseUnaryOp, Cwise::sign()
*/ */
template<typename Scalar,bool iscpx=(NumTraits<Scalar>::IsComplex!=0) > struct scalar_sign_op; template<typename Scalar,bool iscpx=(NumTraits<Scalar>::IsComplex!=0) > struct scalar_sign_op;
template<typename Scalar> template<typename Scalar>
struct scalar_sign_op<Scalar,false> { struct scalar_sign_op<Scalar,false> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{ {
return Scalar( (a>Scalar(0)) - (a<Scalar(0)) ); return Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
} }
@@ -743,17 +743,17 @@ struct scalar_sign_op<Scalar,false> {
//template <typename Packet> //template <typename Packet>
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); } //EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
}; };
template<typename Scalar> template<typename Scalar>
struct scalar_sign_op<Scalar,true> { struct scalar_sign_op<Scalar,true> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{ {
using std::abs; using std::abs;
typedef typename NumTraits<Scalar>::Real real_type; typedef typename NumTraits<Scalar>::Real real_type;
real_type aa = abs(a); real_type aa = abs(a);
if (aa==0) if (aa==0)
return Scalar(0); return Scalar(0);
aa = 1./aa; aa = 1./aa;
return Scalar(real(a)*aa, imag(a)*aa ); return Scalar(real(a)*aa, imag(a)*aa );
} }
//TODO //TODO