Add support for dense.cwiseProduct(sparse)

This also fixes a regression regarding (dense*sparse).diagonal()
This commit is contained in:
Gael Guennebaud 2015-11-04 17:42:07 +01:00
parent f6b1deebab
commit 902750826b
6 changed files with 26 additions and 16 deletions

View File

@ -438,6 +438,15 @@ template<typename Derived> class MatrixBase
template<typename OtherScalar>
void applyOnTheRight(Index p, Index q, const JacobiRotation<OtherScalar>& j);
///////// SparseCore module /////////
template<typename OtherDerived>
EIGEN_STRONG_INLINE const typename SparseMatrixBase<OtherDerived>::template CwiseProductDenseReturnType<Derived>::Type
cwiseProduct(const SparseMatrixBase<OtherDerived> &other) const
{
return other.cwiseProduct(derived());
}
///////// MatrixFunctions module /////////
typedef typename internal::stem_function<Scalar>::type StemFunction;

View File

@ -265,7 +265,6 @@ template<typename Scalar> class Rotation2D;
template<typename Scalar> class AngleAxis;
template<typename Scalar,int Dim> class Translation;
template<typename Scalar,int Dim> class AlignedBox;
template<typename Scalar, int Options = AutoAlign> class Quaternion;
template<typename Scalar,int Dim,int Mode,int _Options=AutoAlign> class Transform;
template <typename _Scalar, int _AmbientDim, int Options=AutoAlign> class ParametrizedLine;
@ -273,6 +272,9 @@ template <typename _Scalar, int _AmbientDim, int Options=AutoAlign> class Hyperp
template<typename Scalar> class UniformScaling;
template<typename MatrixType,int Direction> class Homogeneous;
// Sparse module:
template<typename Derived> class SparseMatrixBase;
// MatrixFunctions module
template<typename Derived> struct MatrixExponentialReturnValue;
template<typename Derived> class MatrixFunctionReturnValue;

View File

@ -423,10 +423,10 @@ Derived& SparseMatrixBase<Derived>::operator-=(const DiagonalBase<OtherDerived>&
template<typename Derived>
template<typename OtherDerived>
EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
EIGEN_STRONG_INLINE const typename SparseMatrixBase<Derived>::template CwiseProductDenseReturnType<OtherDerived>::Type
SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const
{
return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived());
return typename CwiseProductDenseReturnType<OtherDerived>::Type(derived(), other.derived());
}
} // end namespace Eigen

View File

@ -262,20 +262,18 @@ template<typename Derived> class SparseMatrixBase
Derived& operator*=(const Scalar& other);
Derived& operator/=(const Scalar& other);
#define EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE \
CwiseBinaryOp< \
internal::scalar_product_op< \
typename internal::scalar_product_traits< \
typename internal::traits<Derived>::Scalar, \
typename internal::traits<OtherDerived>::Scalar \
>::ReturnType \
>, \
const Derived, \
const OtherDerived \
>
template<typename OtherDerived> struct CwiseProductDenseReturnType {
typedef CwiseBinaryOp<internal::scalar_product_op<typename internal::scalar_product_traits<
typename internal::traits<Derived>::Scalar,
typename internal::traits<OtherDerived>::Scalar
>::ReturnType>,
const Derived,
const OtherDerived
> Type;
};
template<typename OtherDerived>
EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE
EIGEN_STRONG_INLINE const typename CwiseProductDenseReturnType<OtherDerived>::Type
cwiseProduct(const MatrixBase<OtherDerived> &other) const;
// sparse * diagonal

View File

@ -49,7 +49,6 @@ const int InnerRandomAccessPattern = 0x2 | CoherentAccessPattern;
const int OuterRandomAccessPattern = 0x4 | CoherentAccessPattern;
const int RandomAccessPattern = 0x8 | OuterRandomAccessPattern | InnerRandomAccessPattern;
template<typename Derived> class SparseMatrixBase;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class SparseMatrix;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class DynamicSparseMatrix;
template<typename _Scalar, int _Flags = 0, typename _StorageIndex = int> class SparseVector;

View File

@ -188,6 +188,8 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
refM4.setRandom();
// sparse cwise* dense
VERIFY_IS_APPROX(m3.cwiseProduct(refM4), refM3.cwiseProduct(refM4));
// dense cwise* sparse
VERIFY_IS_APPROX(refM4.cwiseProduct(m3), refM4.cwiseProduct(refM3));
// VERIFY_IS_APPROX(m3.cwise()/refM4, refM3.cwise()/refM4);
// test aliasing