bug #178: remove additional const on nested expression, and remove several const_cast.

This commit is contained in:
Gael Guennebaud 2016-01-28 21:43:20 +01:00
parent 12f8bd12a2
commit c1d900af61
12 changed files with 83 additions and 70 deletions

View File

@ -52,7 +52,7 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
const Scalar const Scalar
>::type ScalarWithConstIfNotLvalue; >::type ScalarWithConstIfNotLvalue;
typedef typename internal::ref_selector<ExpressionType>::type NestedExpressionType; typedef typename internal::ref_selector<ExpressionType>::non_const_type NestedExpressionType;
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
explicit EIGEN_STRONG_INLINE ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {} explicit EIGEN_STRONG_INLINE ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {}
@ -67,7 +67,7 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
inline Index innerStride() const { return m_expression.innerStride(); } inline Index innerStride() const { return m_expression.innerStride(); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); } inline ScalarWithConstIfNotLvalue* data() { return m_expression.data(); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar* data() const { return m_expression.data(); } inline const Scalar* data() const { return m_expression.data(); }
@ -80,13 +80,13 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index rowId, Index colId) inline Scalar& coeffRef(Index rowId, Index colId)
{ {
return m_expression.const_cast_derived().coeffRef(rowId, colId); return m_expression.coeffRef(rowId, colId);
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index rowId, Index colId) const inline const Scalar& coeffRef(Index rowId, Index colId) const
{ {
return m_expression.const_cast_derived().coeffRef(rowId, colId); return m_expression.coeffRef(rowId, colId);
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -98,13 +98,13 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index index) inline Scalar& coeffRef(Index index)
{ {
return m_expression.const_cast_derived().coeffRef(index); return m_expression.coeffRef(index);
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index index) const inline const Scalar& coeffRef(Index index) const
{ {
return m_expression.const_cast_derived().coeffRef(index); return m_expression.coeffRef(index);
} }
template<int LoadMode> template<int LoadMode>
@ -116,7 +116,7 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
template<int LoadMode> template<int LoadMode>
inline void writePacket(Index rowId, Index colId, const PacketScalar& val) inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
{ {
m_expression.const_cast_derived().template writePacket<LoadMode>(rowId, colId, val); m_expression.template writePacket<LoadMode>(rowId, colId, val);
} }
template<int LoadMode> template<int LoadMode>
@ -128,7 +128,7 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
template<int LoadMode> template<int LoadMode>
inline void writePacket(Index index, const PacketScalar& val) inline void writePacket(Index index, const PacketScalar& val)
{ {
m_expression.const_cast_derived().template writePacket<LoadMode>(index, val); m_expression.template writePacket<LoadMode>(index, val);
} }
template<typename Dest> template<typename Dest>
@ -145,11 +145,11 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
/** Forwards the resizing request to the nested expression /** Forwards the resizing request to the nested expression
* \sa DenseBase::resize(Index) */ * \sa DenseBase::resize(Index) */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); } void resize(Index newSize) { m_expression.resize(newSize); }
/** Forwards the resizing request to the nested expression /** Forwards the resizing request to the nested expression
* \sa DenseBase::resize(Index,Index)*/ * \sa DenseBase::resize(Index,Index)*/
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
void resize(Index rows, Index cols) { m_expression.const_cast_derived().resize(rows,cols); } void resize(Index rows, Index cols) { m_expression.resize(rows,cols); }
protected: protected:
NestedExpressionType m_expression; NestedExpressionType m_expression;
@ -195,7 +195,7 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
const Scalar const Scalar
>::type ScalarWithConstIfNotLvalue; >::type ScalarWithConstIfNotLvalue;
typedef typename internal::ref_selector<ExpressionType>::type NestedExpressionType; typedef typename internal::ref_selector<ExpressionType>::non_const_type NestedExpressionType;
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
explicit inline MatrixWrapper(ExpressionType& matrix) : m_expression(matrix) {} explicit inline MatrixWrapper(ExpressionType& matrix) : m_expression(matrix) {}
@ -210,7 +210,7 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
inline Index innerStride() const { return m_expression.innerStride(); } inline Index innerStride() const { return m_expression.innerStride(); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); } inline ScalarWithConstIfNotLvalue* data() { return m_expression.data(); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar* data() const { return m_expression.data(); } inline const Scalar* data() const { return m_expression.data(); }
@ -223,7 +223,7 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index rowId, Index colId) inline Scalar& coeffRef(Index rowId, Index colId)
{ {
return m_expression.const_cast_derived().coeffRef(rowId, colId); return m_expression.coeffRef(rowId, colId);
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -241,13 +241,13 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index index) inline Scalar& coeffRef(Index index)
{ {
return m_expression.const_cast_derived().coeffRef(index); return m_expression.coeffRef(index);
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index index) const inline const Scalar& coeffRef(Index index) const
{ {
return m_expression.const_cast_derived().coeffRef(index); return m_expression.coeffRef(index);
} }
template<int LoadMode> template<int LoadMode>
@ -259,7 +259,7 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
template<int LoadMode> template<int LoadMode>
inline void writePacket(Index rowId, Index colId, const PacketScalar& val) inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
{ {
m_expression.const_cast_derived().template writePacket<LoadMode>(rowId, colId, val); m_expression.template writePacket<LoadMode>(rowId, colId, val);
} }
template<int LoadMode> template<int LoadMode>
@ -271,7 +271,7 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
template<int LoadMode> template<int LoadMode>
inline void writePacket(Index index, const PacketScalar& val) inline void writePacket(Index index, const PacketScalar& val)
{ {
m_expression.const_cast_derived().template writePacket<LoadMode>(index, val); m_expression.template writePacket<LoadMode>(index, val);
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -284,11 +284,11 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
/** Forwards the resizing request to the nested expression /** Forwards the resizing request to the nested expression
* \sa DenseBase::resize(Index) */ * \sa DenseBase::resize(Index) */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); } void resize(Index newSize) { m_expression.resize(newSize); }
/** Forwards the resizing request to the nested expression /** Forwards the resizing request to the nested expression
* \sa DenseBase::resize(Index,Index)*/ * \sa DenseBase::resize(Index,Index)*/
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
void resize(Index rows, Index cols) { m_expression.const_cast_derived().resize(rows,cols); } void resize(Index rows, Index cols) { m_expression.resize(rows,cols); }
protected: protected:
NestedExpressionType m_expression; NestedExpressionType m_expression;

View File

@ -221,15 +221,13 @@ template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel, bool H
inline Scalar& coeffRef(Index rowId, Index colId) inline Scalar& coeffRef(Index rowId, Index colId)
{ {
EIGEN_STATIC_ASSERT_LVALUE(XprType) EIGEN_STATIC_ASSERT_LVALUE(XprType)
return m_xpr.const_cast_derived() return m_xpr.coeffRef(rowId + m_startRow.value(), colId + m_startCol.value());
.coeffRef(rowId + m_startRow.value(), colId + m_startCol.value());
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index rowId, Index colId) const inline const Scalar& coeffRef(Index rowId, Index colId) const
{ {
return m_xpr.derived() return m_xpr.derived().coeffRef(rowId + m_startRow.value(), colId + m_startCol.value());
.coeffRef(rowId + m_startRow.value(), colId + m_startCol.value());
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -242,39 +240,34 @@ template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel, bool H
inline Scalar& coeffRef(Index index) inline Scalar& coeffRef(Index index)
{ {
EIGEN_STATIC_ASSERT_LVALUE(XprType) EIGEN_STATIC_ASSERT_LVALUE(XprType)
return m_xpr.const_cast_derived() return m_xpr.coeffRef(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index),
.coeffRef(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index), m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index index) const inline const Scalar& coeffRef(Index index) const
{ {
return m_xpr.const_cast_derived() return m_xpr.coeffRef(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index),
.coeffRef(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index), m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const CoeffReturnType coeff(Index index) const inline const CoeffReturnType coeff(Index index) const
{ {
return m_xpr return m_xpr.coeff(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index),
.coeff(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index), m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0));
} }
template<int LoadMode> template<int LoadMode>
inline PacketScalar packet(Index rowId, Index colId) const inline PacketScalar packet(Index rowId, Index colId) const
{ {
return m_xpr.template packet<Unaligned> return m_xpr.template packet<Unaligned>(rowId + m_startRow.value(), colId + m_startCol.value());
(rowId + m_startRow.value(), colId + m_startCol.value());
} }
template<int LoadMode> template<int LoadMode>
inline void writePacket(Index rowId, Index colId, const PacketScalar& val) inline void writePacket(Index rowId, Index colId, const PacketScalar& val)
{ {
m_xpr.const_cast_derived().template writePacket<Unaligned> m_xpr.template writePacket<Unaligned>(rowId + m_startRow.value(), colId + m_startCol.value(), val);
(rowId + m_startRow.value(), colId + m_startCol.value(), val);
} }
template<int LoadMode> template<int LoadMode>
@ -288,7 +281,7 @@ template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel, bool H
template<int LoadMode> template<int LoadMode>
inline void writePacket(Index index, const PacketScalar& val) inline void writePacket(Index index, const PacketScalar& val)
{ {
m_xpr.const_cast_derived().template writePacket<Unaligned> m_xpr.template writePacket<Unaligned>
(m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index), (m_startRow.value() + (RowsAtCompileTime == 1 ? 0 : index),
m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0), val); m_startCol.value() + (RowsAtCompileTime == 1 ? index : 0), val);
} }
@ -320,7 +313,7 @@ template<typename XprType, int BlockRows, int BlockCols, bool InnerPanel, bool H
protected: protected:
const typename XprType::Nested m_xpr; typename XprType::Nested m_xpr;
const internal::variable_if_dynamic<StorageIndex, XprType::RowsAtCompileTime == 1 ? 0 : Dynamic> m_startRow; const internal::variable_if_dynamic<StorageIndex, XprType::RowsAtCompileTime == 1 ? 0 : Dynamic> m_startRow;
const internal::variable_if_dynamic<StorageIndex, XprType::ColsAtCompileTime == 1 ? 0 : Dynamic> m_startCol; const internal::variable_if_dynamic<StorageIndex, XprType::ColsAtCompileTime == 1 ? 0 : Dynamic> m_startCol;
const internal::variable_if_dynamic<StorageIndex, RowsAtCompileTime> m_blockRows; const internal::variable_if_dynamic<StorageIndex, RowsAtCompileTime> m_blockRows;

View File

@ -58,6 +58,7 @@ class CwiseUnaryOp : public CwiseUnaryOpImpl<UnaryOp, XprType, typename internal
typedef typename CwiseUnaryOpImpl<UnaryOp, XprType,typename internal::traits<XprType>::StorageKind>::Base Base; typedef typename CwiseUnaryOpImpl<UnaryOp, XprType,typename internal::traits<XprType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryOp) EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryOp)
typedef typename internal::ref_selector<XprType>::type XprTypeNested;
typedef typename internal::remove_all<XprType>::type NestedExpression; typedef typename internal::remove_all<XprType>::type NestedExpression;
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -75,16 +76,16 @@ class CwiseUnaryOp : public CwiseUnaryOpImpl<UnaryOp, XprType, typename internal
/** \returns the nested expression */ /** \returns the nested expression */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename XprType::Nested>::type& const typename internal::remove_all<XprTypeNested>::type&
nestedExpression() const { return m_xpr; } nestedExpression() const { return m_xpr; }
/** \returns the nested expression */ /** \returns the nested expression */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
typename internal::remove_all<typename XprType::Nested>::type& typename internal::remove_all<XprTypeNested>::type&
nestedExpression() { return m_xpr.const_cast_derived(); } nestedExpression() { return m_xpr; }
protected: protected:
typename XprType::Nested m_xpr; XprTypeNested m_xpr;
const UnaryOp m_functor; const UnaryOp m_functor;
}; };

View File

@ -61,6 +61,7 @@ class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename in
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base; typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView) EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef typename internal::remove_all<MatrixType>::type NestedExpression; typedef typename internal::remove_all<MatrixType>::type NestedExpression;
explicit inline CwiseUnaryView(MatrixType& mat, const ViewOp& func = ViewOp()) explicit inline CwiseUnaryView(MatrixType& mat, const ViewOp& func = ViewOp())
@ -75,15 +76,15 @@ class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename in
const ViewOp& functor() const { return m_functor; } const ViewOp& functor() const { return m_functor; }
/** \returns the nested expression */ /** \returns the nested expression */
const typename internal::remove_all<typename MatrixType::Nested>::type& const typename internal::remove_all<MatrixTypeNested>::type&
nestedExpression() const { return m_matrix; } nestedExpression() const { return m_matrix; }
/** \returns the nested expression */ /** \returns the nested expression */
typename internal::remove_all<typename MatrixType::Nested>::type& typename internal::remove_reference<MatrixTypeNested>::type&
nestedExpression() { return m_matrix.const_cast_derived(); } nestedExpression() { return m_matrix.const_cast_derived(); }
protected: protected:
typename internal::ref_selector<MatrixType>::type m_matrix; MatrixTypeNested m_matrix;
ViewOp m_functor; ViewOp m_functor;
}; };

View File

@ -103,21 +103,21 @@ template<typename MatrixType, int _DiagIndex> class Diagonal
>::type ScalarWithConstIfNotLvalue; >::type ScalarWithConstIfNotLvalue;
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return &(m_matrix.const_cast_derived().coeffRef(rowOffset(), colOffset())); } inline ScalarWithConstIfNotLvalue* data() { return &(m_matrix.coeffRef(rowOffset(), colOffset())); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar* data() const { return &(m_matrix.const_cast_derived().coeffRef(rowOffset(), colOffset())); } inline const Scalar* data() const { return &(m_matrix.coeffRef(rowOffset(), colOffset())); }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index row, Index) inline Scalar& coeffRef(Index row, Index)
{ {
EIGEN_STATIC_ASSERT_LVALUE(MatrixType) EIGEN_STATIC_ASSERT_LVALUE(MatrixType)
return m_matrix.const_cast_derived().coeffRef(row+rowOffset(), row+colOffset()); return m_matrix.coeffRef(row+rowOffset(), row+colOffset());
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index row, Index) const inline const Scalar& coeffRef(Index row, Index) const
{ {
return m_matrix.const_cast_derived().coeffRef(row+rowOffset(), row+colOffset()); return m_matrix.coeffRef(row+rowOffset(), row+colOffset());
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -130,13 +130,13 @@ template<typename MatrixType, int _DiagIndex> class Diagonal
inline Scalar& coeffRef(Index idx) inline Scalar& coeffRef(Index idx)
{ {
EIGEN_STATIC_ASSERT_LVALUE(MatrixType) EIGEN_STATIC_ASSERT_LVALUE(MatrixType)
return m_matrix.const_cast_derived().coeffRef(idx+rowOffset(), idx+colOffset()); return m_matrix.coeffRef(idx+rowOffset(), idx+colOffset());
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index idx) const inline const Scalar& coeffRef(Index idx) const
{ {
return m_matrix.const_cast_derived().coeffRef(idx+rowOffset(), idx+colOffset()); return m_matrix.coeffRef(idx+rowOffset(), idx+colOffset());
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -159,7 +159,7 @@ template<typename MatrixType, int _DiagIndex> class Diagonal
} }
protected: protected:
typename MatrixType::Nested m_matrix; typename internal::ref_selector<MatrixType>::non_const_type m_matrix;
const internal::variable_if_dynamicindex<Index, DiagIndex> m_index; const internal::variable_if_dynamicindex<Index, DiagIndex> m_index;
private: private:

View File

@ -32,7 +32,7 @@ namespace internal {
template<typename MatrixType, unsigned int UpLo> template<typename MatrixType, unsigned int UpLo>
struct traits<SelfAdjointView<MatrixType, UpLo> > : traits<MatrixType> struct traits<SelfAdjointView<MatrixType, UpLo> > : traits<MatrixType>
{ {
typedef typename ref_selector<MatrixType>::type MatrixTypeNested; typedef typename ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef typename remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned; typedef typename remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
typedef MatrixType ExpressionType; typedef MatrixType ExpressionType;
typedef typename MatrixType::PlainObject FullMatrixType; typedef typename MatrixType::PlainObject FullMatrixType;
@ -97,7 +97,7 @@ template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView
{ {
EIGEN_STATIC_ASSERT_LVALUE(SelfAdjointView); EIGEN_STATIC_ASSERT_LVALUE(SelfAdjointView);
Base::check_coordinates_internal(row, col); Base::check_coordinates_internal(row, col);
return m_matrix.const_cast_derived().coeffRef(row, col); return m_matrix.coeffRef(row, col);
} }
/** \internal */ /** \internal */
@ -107,7 +107,7 @@ template<typename _MatrixType, unsigned int UpLo> class SelfAdjointView
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const MatrixTypeNestedCleaned& nestedExpression() const { return m_matrix; } const MatrixTypeNestedCleaned& nestedExpression() const { return m_matrix; }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
MatrixTypeNestedCleaned& nestedExpression() { return *const_cast<MatrixTypeNestedCleaned*>(&m_matrix); } MatrixTypeNestedCleaned& nestedExpression() { return m_matrix; }
/** Efficient triangular matrix times vector/matrix product */ /** Efficient triangular matrix times vector/matrix product */
template<typename OtherDerived> template<typename OtherDerived>

View File

@ -54,6 +54,8 @@ template<typename MatrixType> class Transpose
{ {
public: public:
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef typename TransposeImpl<MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base; typedef typename TransposeImpl<MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(Transpose) EIGEN_GENERIC_PUBLIC_INTERFACE(Transpose)
typedef typename internal::remove_all<MatrixType>::type NestedExpression; typedef typename internal::remove_all<MatrixType>::type NestedExpression;
@ -68,16 +70,16 @@ template<typename MatrixType> class Transpose
/** \returns the nested expression */ /** \returns the nested expression */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename MatrixType::Nested>::type& const typename internal::remove_all<MatrixTypeNested>::type&
nestedExpression() const { return m_matrix; } nestedExpression() const { return m_matrix; }
/** \returns the nested expression */ /** \returns the nested expression */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
typename internal::remove_all<typename MatrixType::Nested>::type& typename internal::remove_reference<MatrixTypeNested>::type&
nestedExpression() { return m_matrix.const_cast_derived(); } nestedExpression() { return m_matrix; }
protected: protected:
typename MatrixType::Nested m_matrix; typename internal::ref_selector<MatrixType>::non_const_type m_matrix;
}; };
namespace internal { namespace internal {

View File

@ -325,7 +325,7 @@ class TranspositionsWrapper
protected: protected:
const typename IndicesType::Nested m_indices; typename IndicesType::Nested m_indices;
}; };

View File

@ -168,7 +168,7 @@ namespace internal {
template<typename MatrixType, unsigned int _Mode> template<typename MatrixType, unsigned int _Mode>
struct traits<TriangularView<MatrixType, _Mode> > : traits<MatrixType> struct traits<TriangularView<MatrixType, _Mode> > : traits<MatrixType>
{ {
typedef typename ref_selector<MatrixType>::type MatrixTypeNested; typedef typename ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef; typedef typename remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
typedef typename remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned; typedef typename remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
typedef typename MatrixType::PlainObject FullMatrixType; typedef typename MatrixType::PlainObject FullMatrixType;
@ -213,7 +213,6 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
IsVectorAtCompileTime = false IsVectorAtCompileTime = false
}; };
// FIXME This, combined with const_cast_derived in transpose() leads to a const-correctness loophole
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
explicit inline TriangularView(MatrixType& matrix) : m_matrix(matrix) explicit inline TriangularView(MatrixType& matrix) : m_matrix(matrix)
{} {}
@ -235,7 +234,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
/** \returns a reference to the nested expression */ /** \returns a reference to the nested expression */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
NestedExpression& nestedExpression() { return *const_cast<NestedExpression*>(&m_matrix); } NestedExpression& nestedExpression() { return m_matrix; }
typedef TriangularView<const MatrixConjugateReturnType,Mode> ConjugateReturnType; typedef TriangularView<const MatrixConjugateReturnType,Mode> ConjugateReturnType;
/** \sa MatrixBase::conjugate() const */ /** \sa MatrixBase::conjugate() const */
@ -255,7 +254,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
inline TransposeReturnType transpose() inline TransposeReturnType transpose()
{ {
EIGEN_STATIC_ASSERT_LVALUE(MatrixType) EIGEN_STATIC_ASSERT_LVALUE(MatrixType)
typename MatrixType::TransposeReturnType tmp(m_matrix.const_cast_derived()); typename MatrixType::TransposeReturnType tmp(m_matrix);
return TransposeReturnType(tmp); return TransposeReturnType(tmp);
} }
@ -418,7 +417,7 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularViewImpl<_Mat
{ {
EIGEN_STATIC_ASSERT_LVALUE(TriangularViewType); EIGEN_STATIC_ASSERT_LVALUE(TriangularViewType);
Base::check_coordinates_internal(row, col); Base::check_coordinates_internal(row, col);
return derived().nestedExpression().const_cast_derived().coeffRef(row, col); return derived().nestedExpression().coeffRef(row, col);
} }
/** Assigns a triangular matrix to a triangular part of a dense matrix */ /** Assigns a triangular matrix to a triangular part of a dense matrix */

View File

@ -466,17 +466,17 @@ struct special_scalar_op_base : public BaseType
template<typename Derived,typename Scalar,typename OtherScalar, typename BaseType> template<typename Derived,typename Scalar,typename OtherScalar, typename BaseType>
struct special_scalar_op_base<Derived,Scalar,OtherScalar,BaseType,true> : public BaseType struct special_scalar_op_base<Derived,Scalar,OtherScalar,BaseType,true> : public BaseType
{ {
const CwiseUnaryOp<scalar_multiple2_op<Scalar,OtherScalar>, Derived> const CwiseUnaryOp<scalar_multiple2_op<Scalar,OtherScalar>, const Derived>
operator*(const OtherScalar& scalar) const operator*(const OtherScalar& scalar) const
{ {
#ifdef EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN #ifdef EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN
EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN
#endif #endif
return CwiseUnaryOp<scalar_multiple2_op<Scalar,OtherScalar>, Derived> return CwiseUnaryOp<scalar_multiple2_op<Scalar,OtherScalar>, const Derived>
(*static_cast<const Derived*>(this), scalar_multiple2_op<Scalar,OtherScalar>(scalar)); (*static_cast<const Derived*>(this), scalar_multiple2_op<Scalar,OtherScalar>(scalar));
} }
inline friend const CwiseUnaryOp<scalar_multiple2_op<Scalar,OtherScalar>, Derived> inline friend const CwiseUnaryOp<scalar_multiple2_op<Scalar,OtherScalar>, const Derived>
operator*(const OtherScalar& scalar, const Derived& matrix) operator*(const OtherScalar& scalar, const Derived& matrix)
{ {
#ifdef EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN #ifdef EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN
@ -485,13 +485,13 @@ struct special_scalar_op_base<Derived,Scalar,OtherScalar,BaseType,true> : publi
return static_cast<const special_scalar_op_base&>(matrix).operator*(scalar); return static_cast<const special_scalar_op_base&>(matrix).operator*(scalar);
} }
const CwiseUnaryOp<scalar_quotient2_op<Scalar,OtherScalar>, Derived> const CwiseUnaryOp<scalar_quotient2_op<Scalar,OtherScalar>, const Derived>
operator/(const OtherScalar& scalar) const operator/(const OtherScalar& scalar) const
{ {
#ifdef EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN #ifdef EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN
EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN
#endif #endif
return CwiseUnaryOp<scalar_quotient2_op<Scalar,OtherScalar>, Derived> return CwiseUnaryOp<scalar_quotient2_op<Scalar,OtherScalar>, const Derived>
(*static_cast<const Derived*>(this), scalar_quotient2_op<Scalar,OtherScalar>(scalar)); (*static_cast<const Derived*>(this), scalar_quotient2_op<Scalar,OtherScalar>(scalar));
} }
}; };

View File

@ -68,6 +68,16 @@ template<typename MatrixType> void array_for_matrix(const MatrixType& m)
const Scalar& ref_a2 = m.array().matrix().coeffRef(0,0); const Scalar& ref_a2 = m.array().matrix().coeffRef(0,0);
VERIFY(&ref_a1 == &ref_m1); VERIFY(&ref_a1 == &ref_m1);
VERIFY(&ref_a2 == &ref_m2); VERIFY(&ref_a2 == &ref_m2);
// Check write accessors:
m1.array().coeffRef(0,0) = 1;
VERIFY_IS_APPROX(m1(0,0),Scalar(1));
m1.array()(0,0) = 2;
VERIFY_IS_APPROX(m1(0,0),Scalar(2));
m1.array().matrix().coeffRef(0,0) = 3;
VERIFY_IS_APPROX(m1(0,0),Scalar(3));
m1.array().matrix()(0,0) = 4;
VERIFY_IS_APPROX(m1(0,0),Scalar(4));
} }
template<typename MatrixType> void comparisons(const MatrixType& m) template<typename MatrixType> void comparisons(const MatrixType& m)

View File

@ -20,6 +20,8 @@ template<typename MatrixType> void diagonal(const MatrixType& m)
MatrixType m1 = MatrixType::Random(rows, cols), MatrixType m1 = MatrixType::Random(rows, cols),
m2 = MatrixType::Random(rows, cols); m2 = MatrixType::Random(rows, cols);
Scalar s1 = internal::random<Scalar>();
//check diagonal() //check diagonal()
VERIFY_IS_APPROX(m1.diagonal(), m1.transpose().diagonal()); VERIFY_IS_APPROX(m1.diagonal(), m1.transpose().diagonal());
m2.diagonal() = 2 * m1.diagonal(); m2.diagonal() = 2 * m1.diagonal();
@ -58,6 +60,11 @@ template<typename MatrixType> void diagonal(const MatrixType& m)
VERIFY_IS_APPROX(m2.template diagonal<N2>(), static_cast<Scalar>(2) * m1.diagonal(N2)); VERIFY_IS_APPROX(m2.template diagonal<N2>(), static_cast<Scalar>(2) * m1.diagonal(N2));
m2.diagonal(N2)[0] *= 3; m2.diagonal(N2)[0] *= 3;
VERIFY_IS_APPROX(m2.diagonal(N2)[0], static_cast<Scalar>(6) * m1.diagonal(N2)[0]); VERIFY_IS_APPROX(m2.diagonal(N2)[0], static_cast<Scalar>(6) * m1.diagonal(N2)[0]);
m2.diagonal(N2).x() = s1;
VERIFY_IS_APPROX(m2.diagonal(N2).x(), s1);
m2.diagonal(N2).coeffRef(0) = Scalar(2)*s1;
VERIFY_IS_APPROX(m2.diagonal(N2).coeff(0), Scalar(2)*s1);
} }
} }