Fix bug #562: add vector-wise normalized and normalize functions

This commit is contained in:
Gael Guennebaud 2013-04-09 11:12:35 +02:00
parent d8f1035355
commit 3cb6e21f80
2 changed files with 63 additions and 0 deletions

View File

@ -234,6 +234,28 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
Direction==Horizontal ? 1 : m_matrix.cols()); Direction==Horizontal ? 1 : m_matrix.cols());
} }
template<typename OtherDerived> struct OppositeExtendedType {
typedef Replicate<OtherDerived,
Direction==Horizontal ? 1 : ExpressionType::RowsAtCompileTime,
Direction==Vertical ? 1 : ExpressionType::ColsAtCompileTime> Type;
};
/** \internal
* Replicates a vector in the opposite direction to match the size of \c *this */
template<typename OtherDerived>
typename OppositeExtendedType<OtherDerived>::Type
extendedToOpposite(const DenseBase<OtherDerived>& other) const
{
EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(Direction==Horizontal, OtherDerived::MaxColsAtCompileTime==1),
YOU_PASSED_A_ROW_VECTOR_BUT_A_COLUMN_VECTOR_WAS_EXPECTED)
EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(Direction==Vertical, OtherDerived::MaxRowsAtCompileTime==1),
YOU_PASSED_A_COLUMN_VECTOR_BUT_A_ROW_VECTOR_WAS_EXPECTED)
return typename OppositeExtendedType<OtherDerived>::Type
(other.derived(),
Direction==Horizontal ? 1 : m_matrix.rows(),
Direction==Vertical ? 1 : m_matrix.cols());
}
public: public:
inline VectorwiseOp(ExpressionType& matrix) : m_matrix(matrix) {} inline VectorwiseOp(ExpressionType& matrix) : m_matrix(matrix) {}
@ -505,6 +527,23 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
return m_matrix / extendedTo(other.derived()); return m_matrix / extendedTo(other.derived());
} }
/** \returns an expression where each column of row of the referenced matrix are normalized.
* The referenced matrix is \b not modified.
* \sa MatrixBase::normalized(), normalize()
*/
CwiseBinaryOp<internal::scalar_quotient_op<Scalar>,
const ExpressionTypeNestedCleaned,
const typename OppositeExtendedType<typename ReturnType<internal::member_norm,RealScalar>::Type>::Type>
normalized() const { return m_matrix.cwiseQuotient(extendedToOpposite(this->norm())); }
/** Normalize in-place each row or columns of the referenced matrix.
* \sa MatrixBase::normalize(), normalized()
*/
void normalize() {
m_matrix = this->normalized();
}
/////////// Geometry module /////////// /////////// Geometry module ///////////
#if EIGEN2_SUPPORT_STAGE > STAGE20_RESOLVE_API_CONFLICTS #if EIGEN2_SUPPORT_STAGE > STAGE20_RESOLVE_API_CONFLICTS

View File

@ -111,6 +111,8 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
typedef typename NumTraits<Scalar>::Real RealScalar; typedef typename NumTraits<Scalar>::Real RealScalar;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVectorType; typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVectorType;
typedef Matrix<Scalar, 1, MatrixType::ColsAtCompileTime> RowVectorType; typedef Matrix<Scalar, 1, MatrixType::ColsAtCompileTime> RowVectorType;
typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealColVectorType;
typedef Matrix<RealScalar, 1, MatrixType::ColsAtCompileTime> RealRowVectorType;
Index rows = m.rows(); Index rows = m.rows();
Index cols = m.cols(); Index cols = m.cols();
@ -123,6 +125,8 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
ColVectorType colvec = ColVectorType::Random(rows); ColVectorType colvec = ColVectorType::Random(rows);
RowVectorType rowvec = RowVectorType::Random(cols); RowVectorType rowvec = RowVectorType::Random(cols);
RealColVectorType rcres;
RealRowVectorType rrres;
// test addition // test addition
@ -159,6 +163,26 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
VERIFY_RAISES_ASSERT(m2.rowwise() -= rowvec.transpose()); VERIFY_RAISES_ASSERT(m2.rowwise() -= rowvec.transpose());
VERIFY_RAISES_ASSERT(m1.rowwise() - rowvec.transpose()); VERIFY_RAISES_ASSERT(m1.rowwise() - rowvec.transpose());
// test norm
rrres = m1.colwise().norm();
VERIFY_IS_APPROX(rrres(c), m1.col(c).norm());
rcres = m1.rowwise().norm();
VERIFY_IS_APPROX(rcres(r), m1.row(r).norm());
// test normalized
m2 = m1.colwise().normalized();
VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
m2 = m1.rowwise().normalized();
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
// test normalize
m2 = m1;
m2.colwise().normalize();
VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
m2 = m1;
m2.rowwise().normalize();
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
} }
void test_vectorwiseop() void test_vectorwiseop()