From 20048319416c4677c304d116253ceb7a6eb6c5c7 Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Tue, 13 Dec 2022 00:54:57 +0000 Subject: [PATCH] add EqualSpaced / setEqualSpaced --- Eigen/src/Core/CwiseNullaryOp.h | 27 +++++++++++++++++++ Eigen/src/Core/DenseBase.h | 9 +++++++ Eigen/src/Core/functors/NullaryFunctors.h | 33 +++++++++++++++++++++++ test/nullary.cpp | 15 ++++++----- 4 files changed, 78 insertions(+), 6 deletions(-) diff --git a/Eigen/src/Core/CwiseNullaryOp.h b/Eigen/src/Core/CwiseNullaryOp.h index a62f54d4c..b33c052c3 100644 --- a/Eigen/src/Core/CwiseNullaryOp.h +++ b/Eigen/src/Core/CwiseNullaryOp.h @@ -306,6 +306,20 @@ DenseBase::LinSpaced(const Scalar& low, const Scalar& high) return DenseBase::NullaryExpr(Derived::SizeAtCompileTime, internal::linspaced_op(low,high,Derived::SizeAtCompileTime)); } +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase::RandomAccessEqualSpacedReturnType +DenseBase::EqualSpaced(Index size, const Scalar& low, const Scalar& step) { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return DenseBase::NullaryExpr(size, internal::equalspaced_op(low, step)); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const typename DenseBase::RandomAccessEqualSpacedReturnType +DenseBase::EqualSpaced(const Scalar& low, const Scalar& step) { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return DenseBase::NullaryExpr(Derived::SizeAtCompileTime, internal::equalspaced_op(low, step)); +} + /** \returns true if all coefficients in this matrix are approximately equal to \a val, to within precision \a prec */ template EIGEN_DEVICE_FUNC bool DenseBase::isApproxToConstant @@ -455,6 +469,19 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase::setLinSpaced( return setLinSpaced(size(), low, high); } +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase::setEqualSpaced(Index newSize, const Scalar& low, + const Scalar& step) { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return derived() = Derived::NullaryExpr(newSize, internal::equalspaced_op(low, step)); +} +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase::setEqualSpaced(const Scalar& low, + const Scalar& step) { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) + return setEqualSpaced(size(), low, step); +} + // zero: /** \returns an expression of a zero matrix. diff --git a/Eigen/src/Core/DenseBase.h b/Eigen/src/Core/DenseBase.h index 6e177793f..24dc69703 100644 --- a/Eigen/src/Core/DenseBase.h +++ b/Eigen/src/Core/DenseBase.h @@ -258,6 +258,8 @@ template class DenseBase EIGEN_DEPRECATED typedef CwiseNullaryOp,PlainObject> SequentialLinSpacedReturnType; /** \internal Represents a vector with linearly spaced coefficients that allows random access. */ typedef CwiseNullaryOp,PlainObject> RandomAccessLinSpacedReturnType; + /** \internal Represents a vector with equally spaced coefficients that allows random access. */ + typedef CwiseNullaryOp, PlainObject> RandomAccessEqualSpacedReturnType; /** \internal the return type of MatrixBase::eigenvalues() */ typedef Matrix::Scalar>::Real, internal::traits::ColsAtCompileTime, 1> EigenvaluesReturnType; @@ -336,6 +338,11 @@ template class DenseBase EIGEN_DEVICE_FUNC static const RandomAccessLinSpacedReturnType LinSpaced(const Scalar& low, const Scalar& high); + EIGEN_DEVICE_FUNC static const RandomAccessEqualSpacedReturnType + EqualSpaced(Index size, const Scalar& low, const Scalar& step); + EIGEN_DEVICE_FUNC static const RandomAccessEqualSpacedReturnType + EqualSpaced(const Scalar& low, const Scalar& step); + template EIGEN_DEVICE_FUNC static const CwiseNullaryOp NullaryExpr(Index rows, Index cols, const CustomNullaryOp& func); @@ -357,6 +364,8 @@ template class DenseBase EIGEN_DEVICE_FUNC Derived& setConstant(const Scalar& value); EIGEN_DEVICE_FUNC Derived& setLinSpaced(Index size, const Scalar& low, const Scalar& high); EIGEN_DEVICE_FUNC Derived& setLinSpaced(const Scalar& low, const Scalar& high); + EIGEN_DEVICE_FUNC Derived& setEqualSpaced(Index size, const Scalar& low, const Scalar& step); + EIGEN_DEVICE_FUNC Derived& setEqualSpaced(const Scalar& low, const Scalar& step); EIGEN_DEVICE_FUNC Derived& setZero(); EIGEN_DEVICE_FUNC Derived& setOnes(); EIGEN_DEVICE_FUNC Derived& setRandom(); diff --git a/Eigen/src/Core/functors/NullaryFunctors.h b/Eigen/src/Core/functors/NullaryFunctors.h index e099d4ac1..f18974b27 100644 --- a/Eigen/src/Core/functors/NullaryFunctors.h +++ b/Eigen/src/Core/functors/NullaryFunctors.h @@ -145,6 +145,39 @@ template struct linspaced_op const linspaced_op_impl::IsInteger> impl; }; +template +struct equalspaced_op { + typedef typename NumTraits::Real RealScalar; + + EIGEN_DEVICE_FUNC equalspaced_op(const Scalar& start, const Scalar& step) : m_start(start), m_step(step) {} + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(IndexType i) const { + return m_start + m_step * static_cast(i); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(IndexType i) const { + const Packet cst_start = pset1(m_start); + const Packet cst_step = pset1(m_step); + const Packet cst_lin0 = plset(Scalar(0)); + const Packet cst_offset = pmadd(cst_lin0, cst_step, cst_start); + + Packet istep = pset1(static_cast(i) * m_step); + return padd(cst_offset, istep); + } + const Scalar m_start; + const Scalar m_step; +}; + +template +struct functor_traits > { + enum { + Cost = NumTraits::AddCost + NumTraits::MulCost, + PacketAccess = + packet_traits::HasSetLinear && packet_traits::HasMul && packet_traits::HasAdd, + IsRepeatable = true + }; +}; + // Linear access is automatically determined from the operator() prototypes available for the given functor. // If it exposes an operator()(i,j), then we assume the i and j coefficients are required independently // and linear access is not possible. In all other cases, linear access is enabled. diff --git a/test/nullary.cpp b/test/nullary.cpp index 2c4d93806..e524837b1 100644 --- a/test/nullary.cpp +++ b/test/nullary.cpp @@ -78,8 +78,9 @@ void testVectorType(const VectorType& base) const Scalar step = ((size == 1) ? 1 : (high-low)/RealScalar(size-1)); // check whether the result yields what we expect it to do - VectorType m(base); + VectorType m(base), o(base); m.setLinSpaced(size,low,high); + o.setEqualSpaced(size, low, step); if(!NumTraits::IsInteger) { @@ -87,6 +88,7 @@ void testVectorType(const VectorType& base) for (int i=0; i