Fix CwiseUnaryView const access (Attempt 2).

This commit is contained in:
Antonio Sánchez 2024-03-14 21:04:49 +00:00 committed by Rasmus Munk Larsen
parent 285da30ec3
commit 24f8fdeb46

View File

@ -20,9 +20,7 @@ template <typename ViewOp, typename MatrixType, typename StrideType>
struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> > : traits<MatrixType> { struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> > : traits<MatrixType> {
typedef typename result_of<ViewOp(typename traits<MatrixType>::Scalar&)>::type1 ScalarRef; typedef typename result_of<ViewOp(typename traits<MatrixType>::Scalar&)>::type1 ScalarRef;
static_assert(std::is_reference<ScalarRef>::value, "Views must return a reference type."); static_assert(std::is_reference<ScalarRef>::value, "Views must return a reference type.");
typedef remove_all_t<ScalarRef> MutableScalar; typedef remove_all_t<ScalarRef> Scalar;
// Ensure const matrices stay const.
typedef std::conditional_t<std::is_const<MatrixType>::value, const MutableScalar, MutableScalar> Scalar;
typedef typename MatrixType::Nested MatrixTypeNested; typedef typename MatrixType::Nested MatrixTypeNested;
typedef remove_all_t<MatrixTypeNested> MatrixTypeNested_; typedef remove_all_t<MatrixTypeNested> MatrixTypeNested_;
enum { enum {
@ -48,11 +46,13 @@ struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> > : traits<MatrixTy
: int(StrideType::OuterStrideAtCompileTime) : int(StrideType::OuterStrideAtCompileTime)
}; };
}; };
} // namespace internal
template <typename ViewOp, typename MatrixType, typename StrideType, typename StorageKind> template <typename ViewOp, typename MatrixType, typename StrideType, typename StorageKind,
bool Mutable = !std::is_const<MatrixType>::value>
class CwiseUnaryViewImpl; class CwiseUnaryViewImpl;
} // namespace internal
/** \class CwiseUnaryView /** \class CwiseUnaryView
* \ingroup Core_Module * \ingroup Core_Module
* *
@ -67,10 +67,10 @@ class CwiseUnaryViewImpl;
* \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp * \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp
*/ */
template <typename ViewOp, typename MatrixType, typename StrideType> template <typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryView class CwiseUnaryView : public internal::CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType,
: public CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind> { typename internal::traits<MatrixType>::StorageKind> {
public: public:
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typedef typename internal::CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType,
typename internal::traits<MatrixType>::StorageKind>::Base Base; typename internal::traits<MatrixType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView) EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested; typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
@ -98,40 +98,63 @@ class CwiseUnaryView
ViewOp m_functor; ViewOp m_functor;
}; };
namespace internal {
// Generic API dispatcher // Generic API dispatcher
template <typename ViewOp, typename XprType, typename StrideType, typename StorageKind> template <typename ViewOp, typename XprType, typename StrideType, typename StorageKind, bool Mutable>
class CwiseUnaryViewImpl : public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type { class CwiseUnaryViewImpl : public generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type {
public: public:
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type Base; typedef typename generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type Base;
}; };
template <typename ViewOp, typename MatrixType, typename StrideType> template <typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense> class CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense, false>
: public internal::dense_xpr_base<CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type { : public dense_xpr_base<CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type {
public: public:
typedef CwiseUnaryView<ViewOp, MatrixType, StrideType> Derived; typedef CwiseUnaryView<ViewOp, MatrixType, StrideType> Derived;
typedef typename internal::dense_xpr_base<CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type Base; typedef typename dense_xpr_base<CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Derived) EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); }
EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeffRef(0)); } EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeffRef(0)); }
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const { EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const {
return StrideType::InnerStrideAtCompileTime != 0 return StrideType::InnerStrideAtCompileTime != 0 ? int(StrideType::InnerStrideAtCompileTime)
? int(StrideType::InnerStrideAtCompileTime) : derived().nestedExpression().innerStride() *
: derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar);
sizeof(Scalar);
} }
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const { EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const {
return StrideType::OuterStrideAtCompileTime != 0 return StrideType::OuterStrideAtCompileTime != 0 ? int(StrideType::OuterStrideAtCompileTime)
? int(StrideType::OuterStrideAtCompileTime) : derived().nestedExpression().outerStride() *
: derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar);
sizeof(Scalar);
} }
protected:
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)
// Allow const access to coeffRef for the case of direct access being enabled.
EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const {
return internal::evaluator<Derived>(derived()).coeffRef(index);
}
EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index row, Index col) const {
return internal::evaluator<Derived>(derived()).coeffRef(row, col);
}
};
template <typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense, true>
: public CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense, false> {
public:
typedef CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense, false> Base;
typedef CwiseUnaryView<ViewOp, MatrixType, StrideType> Derived;
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
using Base::data;
EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
return internal::evaluator<Derived>(derived()).coeffRef(row, col); return internal::evaluator<Derived>(derived()).coeffRef(row, col);
} }
@ -142,17 +165,10 @@ class CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense>
protected: protected:
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)
// Allow const access to coeffRef for the case of direct access being enabled.
EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const {
return const_cast<CwiseUnaryViewImpl*>(this)->coeffRef(index);
}
EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index row, Index col) const {
return const_cast<CwiseUnaryViewImpl*>(this)->coeffRef(row, col);
}
}; };
} // end namespace Eigen } // namespace internal
} // namespace Eigen
#endif // EIGEN_CWISE_UNARY_VIEW_H #endif // EIGEN_CWISE_UNARY_VIEW_H