bug #178: get rid of some const_cast in SparseCore

This commit is contained in:
Gael Guennebaud 2016-01-28 22:11:18 +01:00
parent c1d900af61
commit b908e071a8
2 changed files with 23 additions and 26 deletions

View File

@ -100,11 +100,11 @@ protected:
enum { OuterSize = IsRowMajor ? BlockRows : BlockCols }; enum { OuterSize = IsRowMajor ? BlockRows : BlockCols };
public: public:
inline sparse_matrix_block_impl(const SparseMatrixType& xpr, Index i) inline sparse_matrix_block_impl(SparseMatrixType& xpr, Index i)
: m_matrix(xpr), m_outerStart(convert_index(i)), m_outerSize(OuterSize) : m_matrix(xpr), m_outerStart(convert_index(i)), m_outerSize(OuterSize)
{} {}
inline sparse_matrix_block_impl(const SparseMatrixType& xpr, Index startRow, Index startCol, Index blockRows, Index blockCols) inline sparse_matrix_block_impl(SparseMatrixType& xpr, Index startRow, Index startCol, Index blockRows, Index blockCols)
: m_matrix(xpr), m_outerStart(convert_index(IsRowMajor ? startRow : startCol)), m_outerSize(convert_index(IsRowMajor ? blockRows : blockCols)) : m_matrix(xpr), m_outerStart(convert_index(IsRowMajor ? startRow : startCol)), m_outerSize(convert_index(IsRowMajor ? blockRows : blockCols))
{} {}
@ -112,7 +112,7 @@ public:
inline BlockType& operator=(const SparseMatrixBase<OtherDerived>& other) inline BlockType& operator=(const SparseMatrixBase<OtherDerived>& other)
{ {
typedef typename internal::remove_all<typename SparseMatrixType::Nested>::type _NestedMatrixType; typedef typename internal::remove_all<typename SparseMatrixType::Nested>::type _NestedMatrixType;
_NestedMatrixType& matrix = const_cast<_NestedMatrixType&>(m_matrix);; _NestedMatrixType& matrix = m_matrix;
// This assignment is slow if this vector set is not empty // This assignment is slow if this vector set is not empty
// and/or it is not at the end of the nonzeros of the underlying matrix. // and/or it is not at the end of the nonzeros of the underlying matrix.
@ -209,28 +209,28 @@ public:
inline const Scalar* valuePtr() const inline const Scalar* valuePtr() const
{ return m_matrix.valuePtr(); } { return m_matrix.valuePtr(); }
inline Scalar* valuePtr() inline Scalar* valuePtr()
{ return m_matrix.const_cast_derived().valuePtr(); } { return m_matrix.valuePtr(); }
inline const StorageIndex* innerIndexPtr() const inline const StorageIndex* innerIndexPtr() const
{ return m_matrix.innerIndexPtr(); } { return m_matrix.innerIndexPtr(); }
inline StorageIndex* innerIndexPtr() inline StorageIndex* innerIndexPtr()
{ return m_matrix.const_cast_derived().innerIndexPtr(); } { return m_matrix.innerIndexPtr(); }
inline const StorageIndex* outerIndexPtr() const inline const StorageIndex* outerIndexPtr() const
{ return m_matrix.outerIndexPtr() + m_outerStart; } { return m_matrix.outerIndexPtr() + m_outerStart; }
inline StorageIndex* outerIndexPtr() inline StorageIndex* outerIndexPtr()
{ return m_matrix.const_cast_derived().outerIndexPtr() + m_outerStart; } { return m_matrix.outerIndexPtr() + m_outerStart; }
inline const StorageIndex* innerNonZeroPtr() const inline const StorageIndex* innerNonZeroPtr() const
{ return isCompressed() ? 0 : (m_matrix.innerNonZeroPtr()+m_outerStart); } { return isCompressed() ? 0 : (m_matrix.innerNonZeroPtr()+m_outerStart); }
inline StorageIndex* innerNonZeroPtr() inline StorageIndex* innerNonZeroPtr()
{ return isCompressed() ? 0 : (m_matrix.const_cast_derived().innerNonZeroPtr()+m_outerStart); } { return isCompressed() ? 0 : (m_matrix.innerNonZeroPtr()+m_outerStart); }
bool isCompressed() const { return m_matrix.innerNonZeroPtr()==0; } bool isCompressed() const { return m_matrix.innerNonZeroPtr()==0; }
inline Scalar& coeffRef(Index row, Index col) inline Scalar& coeffRef(Index row, Index col)
{ {
return m_matrix.const_cast_derived().coeffRef(row + (IsRowMajor ? m_outerStart : 0), col + (IsRowMajor ? 0 : m_outerStart)); return m_matrix.coeffRef(row + (IsRowMajor ? m_outerStart : 0), col + (IsRowMajor ? 0 : m_outerStart));
} }
inline const Scalar coeff(Index row, Index col) const inline const Scalar coeff(Index row, Index col) const
@ -264,7 +264,7 @@ public:
protected: protected:
typename SparseMatrixType::Nested m_matrix; typename internal::ref_selector<SparseMatrixType>::non_const_type m_matrix;
Index m_outerStart; Index m_outerStart;
const internal::variable_if_dynamic<Index, OuterSize> m_outerSize; const internal::variable_if_dynamic<Index, OuterSize> m_outerSize;
@ -373,7 +373,7 @@ public:
/** Column or Row constructor /** Column or Row constructor
*/ */
inline BlockImpl(const XprType& xpr, Index i) inline BlockImpl(XprType& xpr, Index i)
: m_matrix(xpr), : m_matrix(xpr),
m_startRow( (BlockRows==1) && (BlockCols==XprType::ColsAtCompileTime) ? convert_index(i) : 0), m_startRow( (BlockRows==1) && (BlockCols==XprType::ColsAtCompileTime) ? convert_index(i) : 0),
m_startCol( (BlockRows==XprType::RowsAtCompileTime) && (BlockCols==1) ? convert_index(i) : 0), m_startCol( (BlockRows==XprType::RowsAtCompileTime) && (BlockCols==1) ? convert_index(i) : 0),
@ -383,7 +383,7 @@ public:
/** Dynamic-size constructor /** Dynamic-size constructor
*/ */
inline BlockImpl(const XprType& xpr, Index startRow, Index startCol, Index blockRows, Index blockCols) inline BlockImpl(XprType& xpr, Index startRow, Index startCol, Index blockRows, Index blockCols)
: m_matrix(xpr), m_startRow(convert_index(startRow)), m_startCol(convert_index(startCol)), m_blockRows(convert_index(blockRows)), m_blockCols(convert_index(blockCols)) : m_matrix(xpr), m_startRow(convert_index(startRow)), m_startCol(convert_index(startCol)), m_blockRows(convert_index(blockRows)), m_blockCols(convert_index(blockCols))
{} {}
@ -392,8 +392,7 @@ public:
inline Scalar& coeffRef(Index row, Index col) inline Scalar& coeffRef(Index row, Index col)
{ {
return m_matrix.const_cast_derived() return m_matrix.coeffRef(row + m_startRow.value(), col + m_startCol.value());
.coeffRef(row + m_startRow.value(), col + m_startCol.value());
} }
inline const Scalar coeff(Index row, Index col) const inline const Scalar coeff(Index row, Index col) const
@ -403,16 +402,14 @@ public:
inline Scalar& coeffRef(Index index) inline Scalar& coeffRef(Index index)
{ {
return m_matrix.const_cast_derived() return m_matrix.coeffRef(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index),
.coeffRef(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index), m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
} }
inline const Scalar coeff(Index index) const inline const Scalar coeff(Index index) const
{ {
return m_matrix return m_matrix.coeff(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index),
.coeff(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index), m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
} }
inline const _MatrixTypeNested& nestedExpression() const { return m_matrix; } inline const _MatrixTypeNested& nestedExpression() const { return m_matrix; }
@ -430,7 +427,7 @@ public:
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(BlockImpl) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(BlockImpl)
typename XprType::Nested m_matrix; typename internal::ref_selector<XprType>::non_const_type m_matrix;
const internal::variable_if_dynamic<Index, XprType::RowsAtCompileTime == 1 ? 0 : Dynamic> m_startRow; const internal::variable_if_dynamic<Index, XprType::RowsAtCompileTime == 1 ? 0 : Dynamic> m_startRow;
const internal::variable_if_dynamic<Index, XprType::ColsAtCompileTime == 1 ? 0 : Dynamic> m_startCol; const internal::variable_if_dynamic<Index, XprType::ColsAtCompileTime == 1 ? 0 : Dynamic> m_startCol;
const internal::variable_if_dynamic<Index, RowsAtCompileTime> m_blockRows; const internal::variable_if_dynamic<Index, RowsAtCompileTime> m_blockRows;

View File

@ -55,10 +55,10 @@ template<typename MatrixType, unsigned int _Mode> class SparseSelfAdjointView
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::StorageIndex StorageIndex; typedef typename MatrixType::StorageIndex StorageIndex;
typedef Matrix<StorageIndex,Dynamic,1> VectorI; typedef Matrix<StorageIndex,Dynamic,1> VectorI;
typedef typename MatrixType::Nested MatrixTypeNested; typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef typename internal::remove_all<MatrixTypeNested>::type _MatrixTypeNested; typedef typename internal::remove_all<MatrixTypeNested>::type _MatrixTypeNested;
explicit inline SparseSelfAdjointView(const MatrixType& matrix) : m_matrix(matrix) explicit inline SparseSelfAdjointView(MatrixType& matrix) : m_matrix(matrix)
{ {
eigen_assert(rows()==cols() && "SelfAdjointView is only for squared matrices"); eigen_assert(rows()==cols() && "SelfAdjointView is only for squared matrices");
} }
@ -68,7 +68,7 @@ template<typename MatrixType, unsigned int _Mode> class SparseSelfAdjointView
/** \internal \returns a reference to the nested matrix */ /** \internal \returns a reference to the nested matrix */
const _MatrixTypeNested& matrix() const { return m_matrix; } const _MatrixTypeNested& matrix() const { return m_matrix; }
_MatrixTypeNested& matrix() { return m_matrix.const_cast_derived(); } typename internal::remove_reference<MatrixTypeNested>::type& matrix() { return m_matrix; }
/** \returns an expression of the matrix product between a sparse self-adjoint matrix \c *this and a sparse matrix \a rhs. /** \returns an expression of the matrix product between a sparse self-adjoint matrix \c *this and a sparse matrix \a rhs.
* *
@ -158,7 +158,7 @@ template<typename MatrixType, unsigned int _Mode> class SparseSelfAdjointView
protected: protected:
typename MatrixType::Nested m_matrix; MatrixTypeNested m_matrix;
//mutable VectorI m_countPerRow; //mutable VectorI m_countPerRow;
//mutable VectorI m_countPerCol; //mutable VectorI m_countPerCol;
private: private:
@ -194,9 +194,9 @@ SparseSelfAdjointView<MatrixType,Mode>::rankUpdate(const SparseMatrixBase<Derive
{ {
SparseMatrix<Scalar,(MatrixType::Flags&RowMajorBit)?RowMajor:ColMajor> tmp = u * u.adjoint(); SparseMatrix<Scalar,(MatrixType::Flags&RowMajorBit)?RowMajor:ColMajor> tmp = u * u.adjoint();
if(alpha==Scalar(0)) if(alpha==Scalar(0))
m_matrix.const_cast_derived() = tmp.template triangularView<Mode>(); m_matrix = tmp.template triangularView<Mode>();
else else
m_matrix.const_cast_derived() += alpha * tmp.template triangularView<Mode>(); m_matrix += alpha * tmp.template triangularView<Mode>();
return *this; return *this;
} }