Added support for fast reciprocal square root computation.

This commit is contained in:
Benoit Steiner 2015-02-26 09:42:41 -08:00
parent 8e817b65d0
commit f41b1f1666
8 changed files with 72 additions and 6 deletions

View File

@ -58,6 +58,7 @@ struct default_packet_traits
HasDiv = 0,
HasSqrt = 0,
HasRsqrt = 0,
HasExp = 0,
HasLog = 0,
HasPow = 0,
@ -352,6 +353,14 @@ Packet plog(const Packet& a) { using std::log; return log(a); }
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet psqrt(const Packet& a) {
  // Unqualified call on purpose: std::sqrt is made visible for built-in types,
  // while an overload found through argument-dependent lookup is still eligible.
  using std::sqrt;
  return sqrt(a);
}
/** \internal \returns 1/sqrt(\a a), i.e. the reciprocal square-root of \a a (coeff-wise).
  *
  * Generic fallback: builds a Packet holding the value 1 and divides it by the
  * result of an unqualified sqrt() call.
  */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet prsqrt(const Packet& a) {
  using std::sqrt;
  // The numerator is a Packet (not a literal) so the division is carried out
  // with the same operand types regardless of what Packet is.
  return Packet(1) / sqrt(a);
}
/***************************************************************************
* The following functions might not have to be overwritten for vectorized types
***************************************************************************/

View File

@ -300,16 +300,29 @@ psqrt<Packet8f>(const Packet8f& _x) {
return pmul(_x, x);
}
#else
template <>
EIGEN_STRONG_INLINE Packet8f psqrt<Packet8f>(const Packet8f& x) {
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f psqrt<Packet8f>(const Packet8f& x) {
return _mm256_sqrt_ps(x);
}
#endif
template <>
EIGEN_STRONG_INLINE Packet4d psqrt<Packet4d>(const Packet4d& x) {
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4d psqrt<Packet4d>(const Packet4d& x) {
return _mm256_sqrt_pd(x);
}
// Reciprocal square root of eight packed floats, computed as 1.0f / sqrt(x)
// with the AVX sqrt and div intrinsics.
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f prsqrt<Packet8f>(const Packet8f& x) {
  // Provides the constant p8f_one used as the numerator below.
  _EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
  const Packet8f root = _mm256_sqrt_ps(x);
  return _mm256_div_ps(p8f_one, root);
}
#endif
// Reciprocal square root of four packed doubles, computed as 1.0 / sqrt(x)
// with the AVX sqrt and div intrinsics.
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4d prsqrt<Packet4d>(const Packet4d& x) {
  // Provides the constant p4d_one used as the numerator below.
  _EIGEN_DECLARE_CONST_Packet4d(one, 1.0);
  const Packet4d root = _mm256_sqrt_pd(x);
  return _mm256_div_pd(p4d_one, root);
}
} // end namespace internal
} // end namespace Eigen

View File

@ -65,6 +65,7 @@ template<> struct packet_traits<float> : default_packet_traits
HasLog = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1
};
};
@ -81,6 +82,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasDiv = 1,
HasExp = 0,
HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1
};
};

View File

@ -462,11 +462,25 @@ Packet4f psqrt<Packet4f>(const Packet4f& _x)
#else
template<> EIGEN_STRONG_INLINE Packet4f psqrt<Packet4f>(const Packet4f& x) { return _mm_sqrt_ps(x); }
template<>EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psqrt<Packet4f>(const Packet4f& x) { return _mm_sqrt_ps(x); }
#endif
template<> EIGEN_STRONG_INLINE Packet2d psqrt<Packet2d>(const Packet2d& x) { return _mm_sqrt_pd(x); }
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d psqrt<Packet2d>(const Packet2d& x) { return _mm_sqrt_pd(x); }
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& x) {
  // The much faster _mm_rsqrt_ps intrinsic is deliberately not used here: it
  // only returns an approximation, so compute 1.0f / sqrt(x) explicitly.
  const Packet4f ones = pset1<Packet4f>(1.0f);
  const Packet4f root = _mm_sqrt_ps(x);
  return _mm_div_ps(ones, root);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d prsqrt<Packet2d>(const Packet2d& x) {
  // SSE2 offers no packed double-precision counterpart of _mm_rsqrt_ps (which
  // is itself only an approximation), so compute 1.0 / sqrt(x) explicitly.
  const Packet2d ones = pset1<Packet2d>(1.0);
  const Packet2d root = _mm_sqrt_pd(x);
  return _mm_div_pd(ones, root);
}
} // end namespace internal

View File

@ -108,6 +108,7 @@ template<> struct packet_traits<float> : default_packet_traits
HasLog = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1
};
};
@ -124,6 +125,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasDiv = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1
};
};

View File

@ -224,6 +224,25 @@ struct functor_traits<scalar_sqrt_op<Scalar> >
};
};
/** \internal
  * \brief Template functor to compute the reciprocal square root of a scalar
  *
  * The scalar path evaluates Scalar(1)/sqrt(a); the packet path forwards to
  * internal::prsqrt().
  *
  * \sa class CwiseUnaryOp, TensorBase::rsqrt()
  */
template<typename Scalar> struct scalar_rsqrt_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_op)

  typedef typename packet_traits<Scalar>::type Packet;

  // Scalar evaluation: 1 / sqrt(a), with sqrt resolved through an unqualified call.
  EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const {
    using std::sqrt;
    return Scalar(1)/sqrt(a);
  }

  // Vectorized evaluation over a whole packet.
  inline Packet packetOp(const Packet& a) const {
    return internal::prsqrt(a);
  }
};
/** \internal
  * Traits of scalar_rsqrt_op: the cost is estimated as five multiplications, and
  * packet evaluation is enabled exactly when packet_traits<Scalar>::HasRsqrt is set.
  */
template<typename Scalar>
struct functor_traits<scalar_rsqrt_op<Scalar> >
{
  enum {
    PacketAccess = packet_traits<Scalar>::HasRsqrt,
    Cost = 5 * NumTraits<Scalar>::MulCost
  };
};
/** \internal
* \brief Template functor to compute the cosine of a scalar
* \sa class CwiseUnaryOp, ArrayBase::cos()

View File

@ -189,6 +189,7 @@ template<typename Scalar> struct scalar_imag_op;
template<typename Scalar> struct scalar_abs_op;
template<typename Scalar> struct scalar_abs2_op;
template<typename Scalar> struct scalar_sqrt_op;
template<typename Scalar> struct scalar_rsqrt_op;
template<typename Scalar> struct scalar_exp_op;
template<typename Scalar> struct scalar_log_op;
template<typename Scalar> struct scalar_cos_op;

View File

@ -77,6 +77,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_sqrt_op<Scalar>());
}
/** \returns a TensorCwiseUnaryOp expression that applies internal::scalar_rsqrt_op
  * (reciprocal square root) to this tensor expression. */
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_rsqrt_op<Scalar>, const Derived>
rsqrt() const {
  typedef internal::scalar_rsqrt_op<Scalar> RsqrtFunctor;
  return unaryExpr(RsqrtFunctor());
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived>
square() const {