Fix bug #562: add vector-wise normalized and normalize functions

This commit is contained in:
Gael Guennebaud 2013-04-09 11:12:35 +02:00
parent d8f1035355
commit 3cb6e21f80
2 changed files with 63 additions and 0 deletions

View File

@ -233,6 +233,28 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
Direction==Vertical ? 1 : m_matrix.rows(),
Direction==Horizontal ? 1 : m_matrix.cols());
}
template<typename OtherDerived> struct OppositeExtendedType {
typedef Replicate<OtherDerived,
Direction==Horizontal ? 1 : ExpressionType::RowsAtCompileTime,
Direction==Vertical ? 1 : ExpressionType::ColsAtCompileTime> Type;
};
/** \internal
* Replicates a vector in the opposite direction to match the size of \c *this */
template<typename OtherDerived>
typename OppositeExtendedType<OtherDerived>::Type
extendedToOpposite(const DenseBase<OtherDerived>& other) const
{
EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(Direction==Horizontal, OtherDerived::MaxColsAtCompileTime==1),
YOU_PASSED_A_ROW_VECTOR_BUT_A_COLUMN_VECTOR_WAS_EXPECTED)
EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(Direction==Vertical, OtherDerived::MaxRowsAtCompileTime==1),
YOU_PASSED_A_COLUMN_VECTOR_BUT_A_ROW_VECTOR_WAS_EXPECTED)
return typename OppositeExtendedType<OtherDerived>::Type
(other.derived(),
Direction==Horizontal ? 1 : m_matrix.rows(),
Direction==Vertical ? 1 : m_matrix.cols());
}
public:
@ -504,6 +526,23 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
EIGEN_STATIC_ASSERT_SAME_XPR_KIND(ExpressionType, OtherDerived)
return m_matrix / extendedTo(other.derived());
}
/** \returns an expression where each column of row of the referenced matrix are normalized.
* The referenced matrix is \b not modified.
* \sa MatrixBase::normalized(), normalize()
*/
CwiseBinaryOp<internal::scalar_quotient_op<Scalar>,
const ExpressionTypeNestedCleaned,
const typename OppositeExtendedType<typename ReturnType<internal::member_norm,RealScalar>::Type>::Type>
normalized() const { return m_matrix.cwiseQuotient(extendedToOpposite(this->norm())); }
/** Normalize in-place each row or columns of the referenced matrix.
* \sa MatrixBase::normalize(), normalized()
*/
void normalize() {
m_matrix = this->normalized();
}
/////////// Geometry module ///////////

View File

@ -111,6 +111,8 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVectorType;
typedef Matrix<Scalar, 1, MatrixType::ColsAtCompileTime> RowVectorType;
typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealColVectorType;
typedef Matrix<RealScalar, 1, MatrixType::ColsAtCompileTime> RealRowVectorType;
Index rows = m.rows();
Index cols = m.cols();
@ -123,6 +125,8 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
ColVectorType colvec = ColVectorType::Random(rows);
RowVectorType rowvec = RowVectorType::Random(cols);
RealColVectorType rcres;
RealRowVectorType rrres;
// test addition
@ -159,6 +163,26 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
VERIFY_RAISES_ASSERT(m2.rowwise() -= rowvec.transpose());
VERIFY_RAISES_ASSERT(m1.rowwise() - rowvec.transpose());
// test norm
rrres = m1.colwise().norm();
VERIFY_IS_APPROX(rrres(c), m1.col(c).norm());
rcres = m1.rowwise().norm();
VERIFY_IS_APPROX(rcres(r), m1.row(r).norm());
// test normalized
m2 = m1.colwise().normalized();
VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
m2 = m1.rowwise().normalized();
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
// test normalize
m2 = m1;
m2.colwise().normalize();
VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
m2 = m1;
m2.rowwise().normalize();
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
}
void test_vectorwiseop()