diff --git a/Eigen/src/Core/DiagonalMatrix.h b/Eigen/src/Core/DiagonalMatrix.h index f27ab798a..da0264b0e 100644 --- a/Eigen/src/Core/DiagonalMatrix.h +++ b/Eigen/src/Core/DiagonalMatrix.h @@ -20,6 +20,7 @@ class DiagonalBase : public EigenBase public: typedef typename internal::traits::DiagonalVectorType DiagonalVectorType; typedef typename DiagonalVectorType::Scalar Scalar; + typedef typename DiagonalVectorType::RealScalar RealScalar; typedef typename internal::traits::StorageKind StorageKind; typedef typename internal::traits::Index Index; @@ -65,6 +66,17 @@ class DiagonalBase : public EigenBase return diagonal().cwiseInverse(); } + inline const DiagonalWrapper, const DiagonalVectorType> > + operator*(const Scalar& scalar) const + { + return diagonal() * scalar; + } + friend inline const DiagonalWrapper, const DiagonalVectorType> > + operator*(const Scalar& scalar, const DiagonalBase& other) + { + return other.diagonal() * scalar; + } + #ifdef EIGEN2_SUPPORT template bool isApprox(const DiagonalBase& other, typename NumTraits::Real precision = NumTraits::dummy_precision()) const diff --git a/test/diagonalmatrices.cpp b/test/diagonalmatrices.cpp index 3f5776dfc..7e9c80d7b 100644 --- a/test/diagonalmatrices.cpp +++ b/test/diagonalmatrices.cpp @@ -32,6 +32,8 @@ template void diagonalmatrices(const MatrixType& m) rv2 = RowVectorType::Random(cols); LeftDiagonalMatrix ldm1(v1), ldm2(v2); RightDiagonalMatrix rdm1(rv1), rdm2(rv2); + + Scalar s1 = internal::random(); SquareMatrixType sq_m1 (v1.asDiagonal()); VERIFY_IS_APPROX(sq_m1, v1.asDiagonal().toDenseMatrix()); @@ -76,6 +78,13 @@ template void diagonalmatrices(const MatrixType& m) big.block(i,j,rows,cols) = big.block(i,j,rows,cols) * rv1.asDiagonal(); VERIFY_IS_APPROX((big.block(i,j,rows,cols)) , m1 * rv1.asDiagonal() ); + + // scalar multiple + VERIFY_IS_APPROX(LeftDiagonalMatrix(ldm1*s1).diagonal(), ldm1.diagonal() * s1); + VERIFY_IS_APPROX(LeftDiagonalMatrix(s1*ldm1).diagonal(), s1 * ldm1.diagonal()); + + VERIFY_IS_APPROX(m1 * (rdm1 * s1), (m1 * rdm1) * s1); + VERIFY_IS_APPROX(m1 * (s1 * rdm1), (m1 * rdm1) * s1); } void test_diagonalmatrices()