This commit is contained in:
Artem Bishev 2025-08-07 16:58:22 +00:00 committed by Rasmus Munk Larsen
parent 8b9dbcdaaf
commit ddce1d7d12
2 changed files with 49 additions and 13 deletions

View File

@ -146,6 +146,22 @@ struct member_redux {
const BinaryOp& binaryFunc() const { return m_functor; }
const BinaryOp m_functor;
};
template <typename Scalar>
struct scalar_replace_zero_with_one_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& x) const {
return numext::is_exactly_zero(x) ? Scalar(1) : x;
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
return pselect(pcmp_eq(x, pzero(x)), pset1<Packet>(Scalar(1)), x);
}
};
template <typename Scalar>
struct functor_traits<scalar_replace_zero_with_one_op<Scalar>> {
enum { Cost = 1, PacketAccess = packet_traits<Scalar>::HasCmp };
};
} // namespace internal
/** \class VectorwiseOp
@ -624,18 +640,28 @@ class VectorwiseOp {
return m_matrix / extendedTo(other.derived());
}
using Normalized_NonzeroNormType =
CwiseUnaryOp<internal::scalar_replace_zero_with_one_op<Scalar>, const NormReturnType>;
using NormalizedReturnType = CwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const ExpressionTypeNestedCleaned,
const typename OppositeExtendedType<Normalized_NonzeroNormType>::Type>;
/** \returns an expression where each column (or row) of the referenced matrix are normalized.
* The referenced matrix is \b not modified.
*
* \warning If the input columns (or rows) are too small (i.e., their norm equals to 0), they remain unchanged in the
* resulting expression.
*
* \sa MatrixBase::normalized(), normalize()
*/
EIGEN_DEVICE_FUNC CwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const ExpressionTypeNestedCleaned,
const typename OppositeExtendedType<NormReturnType>::Type>
normalized() const {
return m_matrix.cwiseQuotient(extendedToOpposite(this->norm()));
EIGEN_DEVICE_FUNC NormalizedReturnType normalized() const {
return m_matrix.cwiseQuotient(extendedToOpposite(Normalized_NonzeroNormType(this->norm())));
}
/** Normalize in-place each row or columns of the referenced matrix.
* \sa MatrixBase::normalize(), normalized()
*
* \warning If the input columns (or rows) are too small (i.e., their norm equals to 0), they are left unchanged.
*
* \sa MatrixBase::normalized(), normalize()
*/
EIGEN_DEVICE_FUNC void normalize() { m_matrix = this->normalized(); }

View File

@ -114,6 +114,8 @@ void vectorwiseop_matrix(const MatrixType& m) {
RealColVectorType rcres;
RealRowVectorType rrres;
Scalar small_scalar = (std::numeric_limits<RealScalar>::min)();
// test broadcast assignment
m2 = m1;
m2.colwise() = colvec;
@ -171,18 +173,26 @@ void vectorwiseop_matrix(const MatrixType& m) {
VERIFY_IS_APPROX(m1.cwiseAbs().colwise().sum().x(), m1.col(0).cwiseAbs().sum());
// test normalized
m2 = m1.colwise().normalized();
VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
m2 = m1.rowwise().normalized();
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
m2 = m1;
m2.col(c).fill(small_scalar);
m3 = m2.colwise().normalized();
for (Index k = 0; k < cols; ++k) VERIFY_IS_APPROX(m3.col(k), m2.col(k).normalized());
m2 = m1;
m2.row(r).setZero();
m3 = m2.rowwise().normalized();
for (Index k = 0; k < rows; ++k) VERIFY_IS_APPROX(m3.row(k), m2.row(k).normalized());
// test normalize
m2 = m1;
m2.colwise().normalize();
VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
m2.col(c).setZero();
m3 = m2;
m3.colwise().normalize();
for (Index k = 0; k < cols; ++k) VERIFY_IS_APPROX(m3.col(k), m2.col(k).normalized());
m2 = m1;
m2.rowwise().normalize();
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
m2.row(r).fill(small_scalar);
m3 = m2;
m3.rowwise().normalize();
for (Index k = 0; k < rows; ++k) VERIFY_IS_APPROX(m3.row(k), m2.row(k).normalized());
// test with partial reduction of products
Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> m1m1 = m1 * m1.transpose();