From 5de0f2f89e5240df5eab223b9a8efc3a63322941 Mon Sep 17 00:00:00 2001 From: Kyle Macfarlan Date: Wed, 25 Oct 2023 03:06:13 +0000 Subject: [PATCH] Fixes #2735: Component-wise cbrt --- Eigen/src/Core/Assign_MKL.h | 1 + Eigen/src/Core/GenericPacketMath.h | 4 ++++ Eigen/src/Core/GlobalFunctions.h | 1 + Eigen/src/Core/MathFunctions.h | 8 ++++++++ Eigen/src/Core/functors/UnaryFunctors.h | 14 ++++++++++++++ Eigen/src/Core/util/ForwardDeclarations.h | 1 + Eigen/src/plugins/ArrayCwiseUnaryOps.inc | 19 ++++++++++++++++++- Eigen/src/plugins/MatrixCwiseUnaryOps.inc | 16 +++++++++++++++- doc/AsciiQuickReference.txt | 2 ++ doc/CoeffwiseMathFunctionsTable.dox | 13 +++++++++++++ doc/QuickReference.dox | 2 ++ doc/SparseQuickReference.dox | 1 + doc/snippets/Cwise_cbrt.cpp | 2 ++ test/array_cwise.cpp | 5 +++++ 14 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 doc/snippets/Cwise_cbrt.cpp diff --git a/Eigen/src/Core/Assign_MKL.h b/Eigen/src/Core/Assign_MKL.h index 0fda71cdf..448dae2b6 100644 --- a/Eigen/src/Core/Assign_MKL.h +++ b/Eigen/src/Core/Assign_MKL.h @@ -140,6 +140,7 @@ EIGEN_MKL_VML_DECLARE_UNARY_CALLS_CPLX(arg, Arg, _) EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(round, Round, _) EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(floor, Floor, _) EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(ceil, Ceil, _) +EIGEN_MKL_VML_DECLARE_UNARY_CALLS_REAL(cbrt, Cbrt, _) #define EIGEN_MKL_VML_DECLARE_POW_CALL(EIGENOP, VMLOP, EIGENTYPE, VMLTYPE, VMLMODE) \ template< typename DstXprType, typename SrcXprNested, typename Plain> \ diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 8e9902e1c..d04f7138a 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -1017,6 +1017,10 @@ Packet plog2(const Packet& a) { template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psqrt(const Packet& a) { return numext::sqrt(a); } +/** \internal \returns the cube-root of \a a (coeff-wise) */ +template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet pcbrt(const Packet& a) { return numext::cbrt(a); } + /** \internal \returns the rounded value of \a a (coeff-wise) */ template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pround(const Packet& a) { using numext::round; return round(a); } diff --git a/Eigen/src/Core/GlobalFunctions.h b/Eigen/src/Core/GlobalFunctions.h index 9818b863c..2d8bb8061 100644 --- a/Eigen/src/Core/GlobalFunctions.h +++ b/Eigen/src/Core/GlobalFunctions.h @@ -89,6 +89,7 @@ namespace Eigen EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(arg,scalar_arg_op,complex argument,\sa ArrayBase::arg DOXCOMMA MatrixBase::cwiseArg) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(carg, scalar_carg_op, complex argument, \sa ArrayBase::carg DOXCOMMA MatrixBase::cwiseCArg) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sqrt,scalar_sqrt_op,square root,\sa ArrayBase::sqrt DOXCOMMA MatrixBase::cwiseSqrt) + EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cbrt,scalar_cbrt_op,cube root,\sa ArrayBase::cbrt DOXCOMMA MatrixBase::cwiseCbrt) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rsqrt,scalar_rsqrt_op,reciprocal square root,\sa ArrayBase::rsqrt) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(square,scalar_square_op,square (power 2),\sa Eigen::abs2 DOXCOMMA Eigen::pow DOXCOMMA ArrayBase::square) EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(cube,scalar_cube_op,cube (power 3),\sa Eigen::pow DOXCOMMA ArrayBase::cube) diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 0f5a0fd5c..6f2fd6d05 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -1394,6 +1394,14 @@ bool sqrt(const bool &x) { return x; } SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(sqrt, sqrt) #endif +/** \returns the cube root of \a x. **/ +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +T cbrt(const T &x) { + EIGEN_USING_STD(cbrt); + return static_cast(cbrt(x)); +} + /** \returns the reciprocal square root of \a x. **/ template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 89cd6772d..d988eb1d7 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -471,6 +471,20 @@ struct functor_traits > { enum { Cost = 1, PacketAccess = packet_traits::Vectorizable }; }; +/** \internal + * \brief Template functor to compute the cube root of a scalar + * \sa class CwiseUnaryOp, Cwise::sqrt() + */ +template +struct scalar_cbrt_op { + EIGEN_DEVICE_FUNC inline const Scalar operator()(const Scalar& a) const { return numext::cbrt(a); } +}; + +template +struct functor_traits > { + enum { Cost = 5 * NumTraits::MulCost, PacketAccess = false }; +}; + /** \internal * \brief Template functor to compute the reciprocal square root of a scalar * \sa class CwiseUnaryOp, Cwise::rsqrt() diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 2a4126c45..8bff87dd2 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -185,6 +185,7 @@ template struct scalar_abs_op; template struct scalar_abs2_op; template struct scalar_absolute_difference_op; template struct scalar_sqrt_op; +template struct scalar_cbrt_op; template struct scalar_rsqrt_op; template struct scalar_exp_op; template struct scalar_log_op; diff --git a/Eigen/src/plugins/ArrayCwiseUnaryOps.inc b/Eigen/src/plugins/ArrayCwiseUnaryOps.inc index 301a900a1..b0123261d 100644 --- a/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +++ b/Eigen/src/plugins/ArrayCwiseUnaryOps.inc @@ -5,6 +5,7 @@ typedef CwiseUnaryOp, const Derived> ArgReturnTy typedef CwiseUnaryOp, const Derived> CArgReturnType; typedef CwiseUnaryOp, const Derived> Abs2ReturnType; typedef CwiseUnaryOp, const Derived> SqrtReturnType; +typedef CwiseUnaryOp, const Derived> CbrtReturnType; typedef CwiseUnaryOp, const Derived> RsqrtReturnType; typedef CwiseUnaryOp, const Derived> SignReturnType; typedef CwiseUnaryOp, const Derived> InverseReturnType; @@ -184,7 +185,7 @@ log2() const * Example: \include Cwise_sqrt.cpp * Output: \verbinclude Cwise_sqrt.out * - * \sa Math functions, pow(), square() + * \sa Math functions, pow(), square(), cbrt() */ EIGEN_DEVICE_FUNC inline const SqrtReturnType @@ -193,6 +194,22 @@ sqrt() const return SqrtReturnType(derived()); } +/** \returns an expression of the coefficient-wise cube root of *this. + * + * This function computes the coefficient-wise cube root. + * + * Example: \include Cwise_cbrt.cpp + * Output: \verbinclude Cwise_cbrt.out + * + * \sa Math functions, sqrt(), pow(), square() + */ +EIGEN_DEVICE_FUNC +inline const CbrtReturnType +cbrt() const +{ + return CbrtReturnType(derived()); +} + /** \returns an expression of the coefficient-wise inverse square root of *this. * * This function computes the coefficient-wise inverse square root. diff --git a/Eigen/src/plugins/MatrixCwiseUnaryOps.inc b/Eigen/src/plugins/MatrixCwiseUnaryOps.inc index 0222137ae..cb65e171f 100644 --- a/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +++ b/Eigen/src/plugins/MatrixCwiseUnaryOps.inc @@ -17,6 +17,7 @@ typedef CwiseUnaryOp, const Derived> CwiseAbs2R typedef CwiseUnaryOp, const Derived> CwiseArgReturnType; typedef CwiseUnaryOp, const Derived> CwiseCArgReturnType; typedef CwiseUnaryOp, const Derived> CwiseSqrtReturnType; +typedef CwiseUnaryOp, const Derived> CwiseCbrtReturnType; typedef CwiseUnaryOp, const Derived> CwiseSignReturnType; typedef CwiseUnaryOp, const Derived> CwiseInverseReturnType; @@ -53,12 +54,25 @@ cwiseAbs2() const { return CwiseAbs2ReturnType(derived()); } /// EIGEN_DOC_UNARY_ADDONS(cwiseSqrt,square-root) /// -/// \sa cwisePow(), cwiseSquare() +/// \sa cwisePow(), cwiseSquare(), cwiseCbrt() /// EIGEN_DEVICE_FUNC inline const CwiseSqrtReturnType cwiseSqrt() const { return CwiseSqrtReturnType(derived()); } +/// \returns an expression of the coefficient-wise cube root of *this. +/// +/// Example: \include MatrixBase_cwiseCbrt.cpp +/// Output: \verbinclude MatrixBase_cwiseCbrt.out +/// +EIGEN_DOC_UNARY_ADDONS(cwiseCbrt,cube-root) +/// +/// \sa cwiseSqrt(), cwiseSquare(), cwisePow() +/// +EIGEN_DEVICE_FUNC +inline const CwiseCbrtReturnType +cwiseCbrt() const { return CwiseSCbrtReturnType(derived()); } + /// \returns an expression of the coefficient-wise signum of *this. /// /// Example: \include MatrixBase_cwiseSign.cpp diff --git a/doc/AsciiQuickReference.txt b/doc/AsciiQuickReference.txt index 18b4446c6..0256c813a 100644 --- a/doc/AsciiQuickReference.txt +++ b/doc/AsciiQuickReference.txt @@ -140,6 +140,8 @@ R.array().square() // P .^ 2 R.array().cube() // P .^ 3 R.cwiseSqrt() // sqrt(P) R.array().sqrt() // sqrt(P) +R.cwiseCbrt() // cbrt(P) +R.array().cbrt() // cbrt(P) R.array().exp() // exp(P) R.array().log() // log(P) R.cwiseMax(P) // max(R, P) diff --git a/doc/CoeffwiseMathFunctionsTable.dox b/doc/CoeffwiseMathFunctionsTable.dox index 3f5c56446..934a95a11 100644 --- a/doc/CoeffwiseMathFunctionsTable.dox +++ b/doc/CoeffwiseMathFunctionsTable.dox @@ -170,6 +170,19 @@ This also means that, unless specified, if the function \c std::foo is available sqrt(a[i]); SSE2, AVX (f,d) + + + \anchor cwisetable_cbrt + a.\link ArrayBase::cbrt cbrt\endlink(); \n + \link Eigen::cbrt cbrt\endlink(a);\n + m.\link MatrixBase::cwiseCbrt cwiseCbrt\endlink(); + + computes cube root (\f$ \cbrt a_i \f$) + + using std::cbrt; \n + cbrt(a[i]); + + \anchor cwisetable_rsqrt diff --git a/doc/QuickReference.dox b/doc/QuickReference.dox index e96b61767..c61d47afa 100644 --- a/doc/QuickReference.dox +++ b/doc/QuickReference.dox @@ -458,6 +458,7 @@ mat1.cwiseMax(mat2) mat1.cwiseMax(scalar) mat1.cwiseAbs2() mat1.cwiseAbs() mat1.cwiseSqrt() +mat1.cwiseCbrt() mat1.cwiseInverse() mat1.cwiseProduct(mat2) mat1.cwiseQuotient(mat2) @@ -470,6 +471,7 @@ mat1.array().max(mat2.array()) mat1.array().max(scalar) mat1.array().abs2() mat1.array().abs() mat1.array().sqrt() +mat1.array().cbrt() mat1.array().inverse() mat1.array() * mat2.array() mat1.array() / mat2.array() diff --git a/doc/SparseQuickReference.dox b/doc/SparseQuickReference.dox index 68ac2dc8a..66288cd14 100644 --- a/doc/SparseQuickReference.dox +++ b/doc/SparseQuickReference.dox @@ -172,6 +172,7 @@ sm2 = perm * sm1; // Permute the columns sm1.cwiseMax(sm2); sm1.cwiseAbs(); sm1.cwiseSqrt(); + sm1.cwiseCbrt(); \endcode sm1 and sm2 should have the same storage order diff --git a/doc/snippets/Cwise_cbrt.cpp b/doc/snippets/Cwise_cbrt.cpp new file mode 100644 index 000000000..a58c76cc4 --- /dev/null +++ b/doc/snippets/Cwise_cbrt.cpp @@ -0,0 +1,2 @@ +Array3d v(1,2,4); +cout << v.cbrt() << endl; diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 058b721fa..bfea96ace 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -169,7 +169,9 @@ void unary_op_test(std::string name, Fn fun, RefFn ref) { template void unary_ops_test() { + unary_op_test(UNARY_FUNCTOR_TEST_ARGS(sqrt)); + unary_op_test(UNARY_FUNCTOR_TEST_ARGS(cbrt)); unary_op_test(UNARY_FUNCTOR_TEST_ARGS(exp)); unary_op_test(UNARY_FUNCTOR_TEST_ARGS(log)); unary_op_test(UNARY_FUNCTOR_TEST_ARGS(sin)); @@ -821,6 +823,7 @@ template void array_real(const ArrayType& m) m3 = m4.abs(); VERIFY_IS_APPROX(m3.sqrt(), sqrt(abs(m3))); + VERIFY_IS_APPROX(m3.cbrt(), cbrt(m3)); VERIFY_IS_APPROX(m3.rsqrt(), Scalar(1)/sqrt(abs(m3))); VERIFY_IS_APPROX(rsqrt(m3), Scalar(1)/sqrt(abs(m3))); VERIFY_IS_APPROX(m3.log(), log(m3)); @@ -882,6 +885,8 @@ template void array_real(const ArrayType& m) VERIFY_IS_APPROX(m3.pow(RealScalar(0.5)), m3.sqrt()); VERIFY_IS_APPROX(pow(m3,RealScalar(0.5)), m3.sqrt()); + VERIFY_IS_APPROX(m3.pow(RealScalar(1.0/3.0)), m3.cbrt()); + VERIFY_IS_APPROX(pow(m3,RealScalar(1.0/3.0)), m3.cbrt()); VERIFY_IS_APPROX(m3.pow(RealScalar(-0.5)), m3.rsqrt()); VERIFY_IS_APPROX(pow(m3,RealScalar(-0.5)), m3.rsqrt());