Added support for fast reciprocal square root computation.

This commit is contained in:
Benoit Steiner 2015-02-26 09:42:41 -08:00
parent 8e817b65d0
commit f41b1f1666
8 changed files with 72 additions and 6 deletions

View File

@ -58,6 +58,7 @@ struct default_packet_traits
HasDiv = 0, HasDiv = 0,
HasSqrt = 0, HasSqrt = 0,
HasRsqrt = 0,
HasExp = 0, HasExp = 0,
HasLog = 0, HasLog = 0,
HasPow = 0, HasPow = 0,
@ -352,6 +353,14 @@ Packet plog(const Packet& a) { using std::log; return log(a); }
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet psqrt(const Packet& a) { using std::sqrt; return sqrt(a); } Packet psqrt(const Packet& a) { using std::sqrt; return sqrt(a); }
/** \internal \returns the reciprocal square root of \a a (coeff-wise), i.e. 1/sqrt(a).
  *
  * Generic version: a packet constructed from the literal 1 is divided by sqrt(a),
  * where sqrt resolves to std::sqrt or to an overload found by argument-dependent lookup.
  */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet prsqrt(const Packet& a) {
  using std::sqrt;
  const Packet unit(1);
  return unit / sqrt(a);
}
/*************************************************************************** /***************************************************************************
* The following functions might not have to be overwritten for vectorized types * The following functions might not have to be overwritten for vectorized types
***************************************************************************/ ***************************************************************************/

View File

@ -300,16 +300,29 @@ psqrt<Packet8f>(const Packet8f& _x) {
return pmul(_x, x); return pmul(_x, x);
} }
#else #else
template <> template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
EIGEN_STRONG_INLINE Packet8f psqrt<Packet8f>(const Packet8f& x) { Packet8f psqrt<Packet8f>(const Packet8f& x) {
return _mm256_sqrt_ps(x); return _mm256_sqrt_ps(x);
} }
#endif #endif
template <> template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
EIGEN_STRONG_INLINE Packet4d psqrt<Packet4d>(const Packet4d& x) { Packet4d psqrt<Packet4d>(const Packet4d& x) {
return _mm256_sqrt_pd(x); return _mm256_sqrt_pd(x);
} }
/** \internal \returns the coefficient-wise reciprocal square root of \a x,
  * obtained by dividing a packet of 1.0f by _mm256_sqrt_ps(x). */
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f prsqrt<Packet8f>(const Packet8f& x) {
  // Declares the constant packet p8f_one used as the numerator below.
  _EIGEN_DECLARE_CONST_Packet8f(one, 1.0f);
  const Packet8f root = _mm256_sqrt_ps(x);
  return _mm256_div_ps(p8f_one, root);
}
#endif
/** \internal \returns the coefficient-wise reciprocal square root of \a x,
  * obtained by dividing a packet of 1.0 by _mm256_sqrt_pd(x). */
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4d prsqrt<Packet4d>(const Packet4d& x) {
  // Declares the constant packet p4d_one used as the numerator below.
  _EIGEN_DECLARE_CONST_Packet4d(one, 1.0);
  const Packet4d root = _mm256_sqrt_pd(x);
  return _mm256_div_pd(p4d_one, root);
}
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen

View File

@ -65,6 +65,7 @@ template<> struct packet_traits<float> : default_packet_traits
HasLog = 1, HasLog = 1,
HasExp = 1, HasExp = 1,
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1 HasBlend = 1
}; };
}; };
@ -81,6 +82,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasDiv = 1, HasDiv = 1,
HasExp = 0, HasExp = 0,
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1 HasBlend = 1
}; };
}; };

View File

@ -462,11 +462,25 @@ Packet4f psqrt<Packet4f>(const Packet4f& _x)
#else #else
template<> EIGEN_STRONG_INLINE Packet4f psqrt<Packet4f>(const Packet4f& x) { return _mm_sqrt_ps(x); } template<>EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psqrt<Packet4f>(const Packet4f& x) { return _mm_sqrt_ps(x); }
#endif #endif
template<> EIGEN_STRONG_INLINE Packet2d psqrt<Packet2d>(const Packet2d& x) { return _mm_sqrt_pd(x); } template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d psqrt<Packet2d>(const Packet2d& x) { return _mm_sqrt_pd(x); }
/** \internal \returns the coefficient-wise reciprocal square root of \a x. */
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& x) {
  // Unfortunately the much faster _mm_rsqrt_ps cannot be used since it only
  // provides an approximation; divide 1 by the square root explicitly instead.
  const Packet4f ones = pset1<Packet4f>(1.0f);
  return _mm_div_ps(ones, _mm_sqrt_ps(x));
}
/** \internal \returns the coefficient-wise reciprocal square root of \a x. */
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d prsqrt<Packet2d>(const Packet2d& x) {
  // An approximate reciprocal square root is not acceptable here, so divide 1
  // by the square root explicitly.
  // NOTE(review): the previous comment referred to "mm_rqsrt_pd"; SSE2 does not
  // appear to offer a packed-double rsqrt intrinsic at all -- confirm.
  const Packet2d ones = pset1<Packet2d>(1.0);
  return _mm_div_pd(ones, _mm_sqrt_pd(x));
}
} // end namespace internal } // end namespace internal

View File

@ -108,6 +108,7 @@ template<> struct packet_traits<float> : default_packet_traits
HasLog = 1, HasLog = 1,
HasExp = 1, HasExp = 1,
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1 HasBlend = 1
}; };
}; };
@ -124,6 +125,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasDiv = 1, HasDiv = 1,
HasExp = 1, HasExp = 1,
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1,
HasBlend = 1 HasBlend = 1
}; };
}; };

View File

@ -224,6 +224,25 @@ struct functor_traits<scalar_sqrt_op<Scalar> >
}; };
}; };
/** \internal
  * \brief Template functor to compute the reciprocal square root of a scalar, 1/sqrt(a)
  *
  * The scalar path evaluates Scalar(1)/sqrt(a); the packet path forwards to internal::prsqrt.
  *
  * \sa class CwiseUnaryOp, TensorBase::rsqrt()
  */
template<typename Scalar> struct scalar_rsqrt_op {
  EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_op)
  typedef typename packet_traits<Scalar>::type Packet;

  // Scalar evaluation; sqrt is resolved through std::sqrt or argument-dependent lookup.
  EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const {
    using std::sqrt;
    return Scalar(1) / sqrt(a);
  }

  // Vectorized evaluation on a whole packet.
  inline Packet packetOp(const Packet& a) const {
    return internal::prsqrt(a);
  }
};
/** \internal
  * \brief Traits of scalar_rsqrt_op: evaluation cost estimate and packet-access availability
  */
template<typename Scalar>
struct functor_traits<scalar_rsqrt_op<Scalar> >
{
  enum {
    // Cost estimate: five times the multiplication cost of Scalar.
    Cost = 5 * NumTraits<Scalar>::MulCost,
    // Packet evaluation is allowed only when the packet traits set HasRsqrt.
    PacketAccess = packet_traits<Scalar>::HasRsqrt
  };
};
/** \internal /** \internal
* \brief Template functor to compute the cosine of a scalar * \brief Template functor to compute the cosine of a scalar
* \sa class CwiseUnaryOp, ArrayBase::cos() * \sa class CwiseUnaryOp, ArrayBase::cos()

View File

@ -189,6 +189,7 @@ template<typename Scalar> struct scalar_imag_op;
template<typename Scalar> struct scalar_abs_op; template<typename Scalar> struct scalar_abs_op;
template<typename Scalar> struct scalar_abs2_op; template<typename Scalar> struct scalar_abs2_op;
template<typename Scalar> struct scalar_sqrt_op; template<typename Scalar> struct scalar_sqrt_op;
template<typename Scalar> struct scalar_rsqrt_op;
template<typename Scalar> struct scalar_exp_op; template<typename Scalar> struct scalar_exp_op;
template<typename Scalar> struct scalar_log_op; template<typename Scalar> struct scalar_log_op;
template<typename Scalar> struct scalar_cos_op; template<typename Scalar> struct scalar_cos_op;

View File

@ -77,6 +77,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
return unaryExpr(internal::scalar_sqrt_op<Scalar>()); return unaryExpr(internal::scalar_sqrt_op<Scalar>());
} }
// Returns an expression applying internal::scalar_rsqrt_op (reciprocal square root)
// to this expression through unaryExpr().
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_rsqrt_op<Scalar>, const Derived>
rsqrt() const {
  typedef internal::scalar_rsqrt_op<Scalar> RsqrtFunctor;
  return unaryExpr(RsqrtFunctor());
}
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived> EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_square_op<Scalar>, const Derived>
square() const { square() const {