diff --git a/Eigen/src/Core/Dot.h b/Eigen/src/Core/Dot.h index 82eb9c709..dd4a2c4dd 100644 --- a/Eigen/src/Core/Dot.h +++ b/Eigen/src/Core/Dot.h @@ -41,6 +41,20 @@ struct dot_nocheck { } }; +template ::Scalar> +struct squared_norm_impl { + using Real = typename NumTraits::Real; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Real run(const Derived& a) { + Scalar result = a.unaryExpr(squared_norm_functor()).sum(); + return numext::real(result) + numext::imag(result); + } +}; + +template +struct squared_norm_impl { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const Derived& a) { return a.any(); } +}; + } // end namespace internal /** \fn MatrixBase::dot @@ -85,7 +99,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename NumTraits::Scalar>::Real MatrixBase::squaredNorm() const { - return numext::real((*this).cwiseAbs2().sum()); + return internal::squared_norm_impl::run(derived()); } /** \returns, for vectors, the \em l2 norm of \c *this, and for matrices the Frobenius norm. diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 5059a5408..b3b7d79d6 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -103,6 +103,26 @@ struct functor_traits> { enum { Cost = NumTraits::MulCost, PacketAccess = packet_traits::HasAbs2 }; }; +template ::IsComplex> +struct squared_norm_functor { + typedef Scalar result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { + return Scalar(numext::real(a) * numext::real(a), numext::imag(a) * numext::imag(a)); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const { + return Packet(pmul(a.v, a.v)); + } +}; +template +struct squared_norm_functor : scalar_abs2_op {}; + +template +struct functor_traits> { + using Real = typename NumTraits::Real; + enum { Cost = NumTraits::MulCost, PacketAccess = packet_traits::HasMul }; +}; + /** \internal * \brief Template functor to compute the conjugate of a complex value *