mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-23 10:09:36 +08:00
bug #1064: add support for Ref<SparseVector>
This commit is contained in:
parent
fe630c9873
commit
8961265889
@ -19,7 +19,7 @@ enum {
|
|||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename Derived> class SparseRefBase;
|
template<typename Derived> class SparseRefBase;
|
||||||
|
|
||||||
template<typename MatScalar, int MatOptions, typename MatIndex, int _Options, typename _StrideType>
|
template<typename MatScalar, int MatOptions, typename MatIndex, int _Options, typename _StrideType>
|
||||||
struct traits<Ref<SparseMatrix<MatScalar,MatOptions,MatIndex>, _Options, _StrideType> >
|
struct traits<Ref<SparseMatrix<MatScalar,MatOptions,MatIndex>, _Options, _StrideType> >
|
||||||
: public traits<SparseMatrix<MatScalar,MatOptions,MatIndex> >
|
: public traits<SparseMatrix<MatScalar,MatOptions,MatIndex> >
|
||||||
@ -27,7 +27,7 @@ struct traits<Ref<SparseMatrix<MatScalar,MatOptions,MatIndex>, _Options, _Stride
|
|||||||
typedef SparseMatrix<MatScalar,MatOptions,MatIndex> PlainObjectType;
|
typedef SparseMatrix<MatScalar,MatOptions,MatIndex> PlainObjectType;
|
||||||
enum {
|
enum {
|
||||||
Options = _Options,
|
Options = _Options,
|
||||||
Flags = traits<SparseMatrix<MatScalar,MatOptions,MatIndex> >::Flags | CompressedAccessBit | NestByRefBit
|
Flags = traits<PlainObjectType>::Flags | CompressedAccessBit | NestByRefBit
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Derived> struct match {
|
template<typename Derived> struct match {
|
||||||
@ -48,7 +48,35 @@ struct traits<Ref<const SparseMatrix<MatScalar,MatOptions,MatIndex>, _Options, _
|
|||||||
Flags = (traits<SparseMatrix<MatScalar,MatOptions,MatIndex> >::Flags | CompressedAccessBit | NestByRefBit) & ~LvalueBit
|
Flags = (traits<SparseMatrix<MatScalar,MatOptions,MatIndex> >::Flags | CompressedAccessBit | NestByRefBit) & ~LvalueBit
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename MatScalar, int MatOptions, typename MatIndex, int _Options, typename _StrideType>
|
||||||
|
struct traits<Ref<SparseVector<MatScalar,MatOptions,MatIndex>, _Options, _StrideType> >
|
||||||
|
: public traits<SparseVector<MatScalar,MatOptions,MatIndex> >
|
||||||
|
{
|
||||||
|
typedef SparseVector<MatScalar,MatOptions,MatIndex> PlainObjectType;
|
||||||
|
enum {
|
||||||
|
Options = _Options,
|
||||||
|
Flags = traits<PlainObjectType>::Flags | CompressedAccessBit | NestByRefBit
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Derived> struct match {
|
||||||
|
enum {
|
||||||
|
MatchAtCompileTime = (Derived::Flags&CompressedAccessBit) && Derived::IsVectorAtCompileTime
|
||||||
|
};
|
||||||
|
typedef typename internal::conditional<MatchAtCompileTime,internal::true_type,internal::false_type>::type type;
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename MatScalar, int MatOptions, typename MatIndex, int _Options, typename _StrideType>
|
||||||
|
struct traits<Ref<const SparseVector<MatScalar,MatOptions,MatIndex>, _Options, _StrideType> >
|
||||||
|
: public traits<Ref<SparseVector<MatScalar,MatOptions,MatIndex>, _Options, _StrideType> >
|
||||||
|
{
|
||||||
|
enum {
|
||||||
|
Flags = (traits<SparseVector<MatScalar,MatOptions,MatIndex> >::Flags | CompressedAccessBit | NestByRefBit) & ~LvalueBit
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
struct traits<SparseRefBase<Derived> > : public traits<Derived> {};
|
struct traits<SparseRefBase<Derived> > : public traits<Derived> {};
|
||||||
|
|
||||||
@ -195,6 +223,99 @@ class Ref<const SparseMatrix<MatScalar,MatOptions,MatIndex>, Options, StrideType
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \ingroup Sparse_Module
|
||||||
|
*
|
||||||
|
* \brief A sparse vector expression referencing an existing sparse vector expression
|
||||||
|
*
|
||||||
|
* \tparam PlainObjectType the equivalent sparse matrix type of the referenced data
|
||||||
|
* \tparam Options Not used for SparseVector.
|
||||||
|
* \tparam StrideType Only used for dense Ref
|
||||||
|
*
|
||||||
|
* \sa class Ref
|
||||||
|
*/
|
||||||
|
template<typename MatScalar, int MatOptions, typename MatIndex, int Options, typename StrideType>
|
||||||
|
class Ref<SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType >
|
||||||
|
: public internal::SparseRefBase<Ref<SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType > >
|
||||||
|
{
|
||||||
|
typedef SparseVector<MatScalar,MatOptions,MatIndex> PlainObjectType;
|
||||||
|
typedef internal::traits<Ref> Traits;
|
||||||
|
template<int OtherOptions>
|
||||||
|
inline Ref(const SparseVector<MatScalar,OtherOptions,MatIndex>& expr);
|
||||||
|
public:
|
||||||
|
|
||||||
|
typedef internal::SparseRefBase<Ref> Base;
|
||||||
|
EIGEN_SPARSE_PUBLIC_INTERFACE(Ref)
|
||||||
|
|
||||||
|
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
||||||
|
template<int OtherOptions>
|
||||||
|
inline Ref(SparseVector<MatScalar,OtherOptions,MatIndex>& expr)
|
||||||
|
{
|
||||||
|
EIGEN_STATIC_ASSERT(bool(Traits::template match<SparseVector<MatScalar,OtherOptions,MatIndex> >::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH);
|
||||||
|
Base::construct(expr.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Derived>
|
||||||
|
inline Ref(const SparseCompressedBase<Derived>& expr)
|
||||||
|
#else
|
||||||
|
template<typename Derived>
|
||||||
|
inline Ref(SparseCompressedBase<Derived>& expr)
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
EIGEN_STATIC_ASSERT(bool(internal::is_lvalue<Derived>::value), THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY);
|
||||||
|
EIGEN_STATIC_ASSERT(bool(Traits::template match<Derived>::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH);
|
||||||
|
Base::construct(expr.const_cast_derived());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// this is the const ref version
|
||||||
|
template<typename MatScalar, int MatOptions, typename MatIndex, int Options, typename StrideType>
|
||||||
|
class Ref<const SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType>
|
||||||
|
: public internal::SparseRefBase<Ref<const SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> >
|
||||||
|
{
|
||||||
|
typedef SparseVector<MatScalar,MatOptions,MatIndex> TPlainObjectType;
|
||||||
|
typedef internal::traits<Ref> Traits;
|
||||||
|
public:
|
||||||
|
|
||||||
|
typedef internal::SparseRefBase<Ref> Base;
|
||||||
|
EIGEN_SPARSE_PUBLIC_INTERFACE(Ref)
|
||||||
|
|
||||||
|
template<typename Derived>
|
||||||
|
inline Ref(const SparseMatrixBase<Derived>& expr)
|
||||||
|
{
|
||||||
|
construct(expr.derived(), typename Traits::template match<Derived>::type());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Ref(const Ref& other) : Base(other) {
|
||||||
|
// copy constructor shall not copy the m_object, to avoid unnecessary malloc and copy
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherRef>
|
||||||
|
inline Ref(const RefBase<OtherRef>& other) {
|
||||||
|
construct(other.derived(), typename Traits::template match<OtherRef>::type());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
|
||||||
|
template<typename Expression>
|
||||||
|
void construct(const Expression& expr,internal::true_type)
|
||||||
|
{
|
||||||
|
Base::construct(expr);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Expression>
|
||||||
|
void construct(const Expression& expr, internal::false_type)
|
||||||
|
{
|
||||||
|
TPlainObjectType* obj = reinterpret_cast<TPlainObjectType*>(m_object_bytes);
|
||||||
|
::new (obj) TPlainObjectType(expr);
|
||||||
|
Base::construct(*obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
char m_object_bytes[sizeof(TPlainObjectType)];
|
||||||
|
};
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename MatScalar, int MatOptions, typename MatIndex, int Options, typename StrideType>
|
template<typename MatScalar, int MatOptions, typename MatIndex, int Options, typename StrideType>
|
||||||
@ -217,6 +338,26 @@ struct evaluator<Ref<const SparseMatrix<MatScalar,MatOptions,MatIndex>, Options,
|
|||||||
explicit evaluator(const XprType &mat) : Base(mat) {}
|
explicit evaluator(const XprType &mat) : Base(mat) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename MatScalar, int MatOptions, typename MatIndex, int Options, typename StrideType>
|
||||||
|
struct evaluator<Ref<SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> >
|
||||||
|
: evaluator<SparseCompressedBase<Ref<SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> > >
|
||||||
|
{
|
||||||
|
typedef evaluator<SparseCompressedBase<Ref<SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> > > Base;
|
||||||
|
typedef Ref<SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> XprType;
|
||||||
|
evaluator() : Base() {}
|
||||||
|
explicit evaluator(const XprType &mat) : Base(mat) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename MatScalar, int MatOptions, typename MatIndex, int Options, typename StrideType>
|
||||||
|
struct evaluator<Ref<const SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> >
|
||||||
|
: evaluator<SparseCompressedBase<Ref<const SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> > >
|
||||||
|
{
|
||||||
|
typedef evaluator<SparseCompressedBase<Ref<const SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> > > Base;
|
||||||
|
typedef Ref<const SparseVector<MatScalar,MatOptions,MatIndex>, Options, StrideType> XprType;
|
||||||
|
evaluator() : Base() {}
|
||||||
|
explicit evaluator(const XprType &mat) : Base(mat) {}
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -235,6 +235,9 @@ class SparseVector
|
|||||||
inline SparseVector(const SparseMatrixBase<OtherDerived>& other)
|
inline SparseVector(const SparseMatrixBase<OtherDerived>& other)
|
||||||
: m_size(0)
|
: m_size(0)
|
||||||
{
|
{
|
||||||
|
#ifdef EIGEN_SPARSE_CREATE_TEMPORARY_PLUGIN
|
||||||
|
EIGEN_SPARSE_CREATE_TEMPORARY_PLUGIN
|
||||||
|
#endif
|
||||||
check_template_parameters();
|
check_template_parameters();
|
||||||
*this = other.derived();
|
*this = other.derived();
|
||||||
}
|
}
|
||||||
|
@ -53,10 +53,14 @@ EIGEN_DONT_INLINE void call_ref_3(const Ref<const SparseMatrix<float>, StandardC
|
|||||||
VERIFY_IS_EQUAL(a.toDense(),b.toDense());
|
VERIFY_IS_EQUAL(a.toDense(),b.toDense());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename B>
|
||||||
|
EIGEN_DONT_INLINE void call_ref_4(Ref<SparseVector<float> > a, const B &b) { VERIFY_IS_EQUAL(a.toDense(),b.toDense()); }
|
||||||
|
|
||||||
|
template<typename B>
|
||||||
|
EIGEN_DONT_INLINE void call_ref_5(const Ref<const SparseVector<float> >& a, const B &b) { VERIFY_IS_EQUAL(a.toDense(),b.toDense()); }
|
||||||
|
|
||||||
void call_ref()
|
void call_ref()
|
||||||
{
|
{
|
||||||
// SparseVector<std::complex<float> > ca = VectorXcf::Random(10).sparseView();
|
|
||||||
// SparseVector<float> a = VectorXf::Random(10).sparseView();
|
|
||||||
SparseMatrix<float> A = MatrixXf::Random(10,10).sparseView(0.5,1);
|
SparseMatrix<float> A = MatrixXf::Random(10,10).sparseView(0.5,1);
|
||||||
SparseMatrix<float,RowMajor> B = MatrixXf::Random(10,10).sparseView(0.5,1);
|
SparseMatrix<float,RowMajor> B = MatrixXf::Random(10,10).sparseView(0.5,1);
|
||||||
SparseMatrix<float> C = MatrixXf::Random(10,10).sparseView(0.5,1);
|
SparseMatrix<float> C = MatrixXf::Random(10,10).sparseView(0.5,1);
|
||||||
@ -111,6 +115,15 @@ void call_ref()
|
|||||||
VERIFY_EVALUATION_COUNT( call_ref_2(vr, vr.transpose()), 0);
|
VERIFY_EVALUATION_COUNT( call_ref_2(vr, vr.transpose()), 0);
|
||||||
|
|
||||||
VERIFY_EVALUATION_COUNT( call_ref_2(A.block(1,1,3,3), A.block(1,1,3,3)), 1); // should be 0 (allocate starts/nnz only)
|
VERIFY_EVALUATION_COUNT( call_ref_2(A.block(1,1,3,3), A.block(1,1,3,3)), 1); // should be 0 (allocate starts/nnz only)
|
||||||
|
|
||||||
|
VERIFY_EVALUATION_COUNT( call_ref_4(vc, vc), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT( call_ref_4(vr, vr.transpose()), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT( call_ref_5(vc, vc), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT( call_ref_5(vr, vr.transpose()), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT( call_ref_4(A.col(2), A.col(2)), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT( call_ref_5(A.col(2), A.col(2)), 0);
|
||||||
|
// VERIFY_EVALUATION_COUNT( call_ref_4(A.row(2), A.row(2).transpose()), 1); // does not compile on purpose
|
||||||
|
VERIFY_EVALUATION_COUNT( call_ref_5(A.row(2), A.row(2).transpose()), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_sparse_ref()
|
void test_sparse_ref()
|
||||||
@ -119,5 +132,8 @@ void test_sparse_ref()
|
|||||||
CALL_SUBTEST_1( check_const_correctness(SparseMatrix<float>()) );
|
CALL_SUBTEST_1( check_const_correctness(SparseMatrix<float>()) );
|
||||||
CALL_SUBTEST_1( check_const_correctness(SparseMatrix<double,RowMajor>()) );
|
CALL_SUBTEST_1( check_const_correctness(SparseMatrix<double,RowMajor>()) );
|
||||||
CALL_SUBTEST_2( call_ref() );
|
CALL_SUBTEST_2( call_ref() );
|
||||||
|
|
||||||
|
CALL_SUBTEST_3( check_const_correctness(SparseVector<float>()) );
|
||||||
|
CALL_SUBTEST_3( check_const_correctness(SparseVector<double,RowMajor>()) );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user