Provide DiagonalMatrix Product and Initializers

This commit is contained in:
Arthur 2022-06-06 21:43:22 +00:00 committed by Rasmus Munk Larsen
parent 76cf6204f3
commit 14aae29470
2 changed files with 37 additions and 0 deletions

View File

@ -65,6 +65,16 @@ class DiagonalBase : public EigenBase<Derived>
return Product<Derived, MatrixDerived, LazyProduct>(derived(),matrix.derived());
}
template <typename OtherDerived>
using DiagonalProductReturnType = DiagonalWrapper<const EIGEN_CWISE_BINARY_RETURN_TYPE(
DiagonalVectorType, typename OtherDerived::DiagonalVectorType, product)>;
template <typename OtherDerived>
EIGEN_DEVICE_FUNC const DiagonalProductReturnType<OtherDerived> operator*(
const DiagonalBase<OtherDerived>& other) const {
return (diagonal().cwiseProduct(other.diagonal())).asDiagonal();
}
typedef DiagonalWrapper<const CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const DiagonalVectorType> > InverseReturnType;
EIGEN_DEVICE_FUNC
inline const InverseReturnType
@ -241,6 +251,22 @@ class DiagonalMatrix
}
#endif
typedef DiagonalWrapper<const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, DiagonalVectorType>>
InitializeReturnType;
/** Initializes a diagonal matrix of size SizeAtCompileTime with coefficients set to zero */
EIGEN_DEVICE_FUNC
static const InitializeReturnType Zero() { return DiagonalVectorType::Zero().asDiagonal(); }
/** Initializes a diagonal matrix of size dim with coefficients set to zero */
EIGEN_DEVICE_FUNC
static const InitializeReturnType Zero(Index size) { return DiagonalVectorType::Zero(size).asDiagonal(); }
/** Initializes a identity matrix of size SizeAtCompileTime */
EIGEN_DEVICE_FUNC
static const InitializeReturnType Identity() { return DiagonalVectorType::Ones().asDiagonal(); }
/** Initializes a identity matrix of size dim */
EIGEN_DEVICE_FUNC
static const InitializeReturnType Identity(Index size) { return DiagonalVectorType::Ones(size).asDiagonal(); }
/** Resizes to given size. */
EIGEN_DEVICE_FUNC
inline void resize(Index size) { m_diagonal.resize(size); }

View File

@ -72,6 +72,9 @@ template<typename MatrixType> void diagonalmatrices(const MatrixType& m)
VERIFY_IS_APPROX( (((v1+v2).asDiagonal() * (m1+m2))(i,j)) , (v1+v2)(i) * (m1+m2)(i,j) );
VERIFY_IS_APPROX( ((m1 * (rv1+rv2).asDiagonal())(i,j)) , (rv1+rv2)(j) * m1(i,j) );
VERIFY_IS_APPROX( (((m1+m2) * (rv1+rv2).asDiagonal())(i,j)) , (rv1+rv2)(j) * (m1+m2)(i,j) );
VERIFY_IS_APPROX( (ldm1 * ldm1).diagonal()(i), ldm1.diagonal()(i) * ldm1.diagonal()(i) );
VERIFY_IS_APPROX( (ldm1 * ldm1 * m1)(i, j), ldm1.diagonal()(i) * ldm1.diagonal()(i) * m1(i, j) );
VERIFY_IS_APPROX( ((v1.asDiagonal() * v1.asDiagonal()).diagonal()(i)), v1(i) * v1(i) );
internal::set_is_malloc_allowed(true);
if(rows>1)
@ -99,6 +102,7 @@ template<typename MatrixType> void diagonalmatrices(const MatrixType& m)
res.noalias() = ldm1 * m1;
res.noalias() = m1 * rdm1;
res.noalias() = ldm1 * m1 * rdm1;
res.noalias() = LeftDiagonalMatrix::Identity(rows) * m1 * RightDiagonalMatrix::Zero(cols);
internal::set_is_malloc_allowed(true);
// scalar multiple
@ -127,6 +131,13 @@ template<typename MatrixType> void diagonalmatrices(const MatrixType& m)
VERIFY_IS_APPROX( sq_m3 = v1.asDiagonal() + v2.asDiagonal(), sq_m1 + sq_m2);
VERIFY_IS_APPROX( sq_m3 = v1.asDiagonal() - v2.asDiagonal(), sq_m1 - sq_m2);
VERIFY_IS_APPROX( sq_m3 = v1.asDiagonal() - 2*v2.asDiagonal() + v1.asDiagonal(), sq_m1 - 2*sq_m2 + sq_m1);
// Zero and Identity
LeftDiagonalMatrix zero = LeftDiagonalMatrix::Zero(rows);
LeftDiagonalMatrix identity = LeftDiagonalMatrix::Identity(rows);
VERIFY_IS_APPROX(identity.diagonal().sum(), Scalar(rows));
VERIFY_IS_APPROX(zero.diagonal().sum(), Scalar(0));
VERIFY_IS_APPROX((zero + 2 * LeftDiagonalMatrix::Identity(rows)).diagonal().sum(), Scalar(2 * rows));
}
template<typename MatrixType> void as_scalar_product(const MatrixType& m)