Fix CwiseUnaryView const access (Attempt 2).

This commit is contained in:
Antonio Sánchez 2024-03-14 21:04:49 +00:00 committed by Rasmus Munk Larsen
parent 285da30ec3
commit 24f8fdeb46

View File

@ -20,9 +20,7 @@ template <typename ViewOp, typename MatrixType, typename StrideType>
struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> > : traits<MatrixType> {
typedef typename result_of<ViewOp(typename traits<MatrixType>::Scalar&)>::type1 ScalarRef;
static_assert(std::is_reference<ScalarRef>::value, "Views must return a reference type.");
typedef remove_all_t<ScalarRef> MutableScalar;
// Ensure const matrices stay const.
typedef std::conditional_t<std::is_const<MatrixType>::value, const MutableScalar, MutableScalar> Scalar;
typedef remove_all_t<ScalarRef> Scalar;
typedef typename MatrixType::Nested MatrixTypeNested;
typedef remove_all_t<MatrixTypeNested> MatrixTypeNested_;
enum {
@ -48,11 +46,13 @@ struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> > : traits<MatrixTy
: int(StrideType::OuterStrideAtCompileTime)
};
};
} // namespace internal
template <typename ViewOp, typename MatrixType, typename StrideType, typename StorageKind>
template <typename ViewOp, typename MatrixType, typename StrideType, typename StorageKind,
bool Mutable = !std::is_const<MatrixType>::value>
class CwiseUnaryViewImpl;
} // namespace internal
/** \class CwiseUnaryView
* \ingroup Core_Module
*
@ -67,11 +67,11 @@ class CwiseUnaryViewImpl;
* \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp
*/
template <typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryView
: public CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind> {
class CwiseUnaryView : public internal::CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType,
typename internal::traits<MatrixType>::StorageKind> {
public:
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType,
typename internal::traits<MatrixType>::StorageKind>::Base Base;
typedef typename internal::CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType,
typename internal::traits<MatrixType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef internal::remove_all_t<MatrixType> NestedExpression;
@ -98,40 +98,63 @@ class CwiseUnaryView
ViewOp m_functor;
};
namespace internal {
// Generic API dispatcher
template <typename ViewOp, typename XprType, typename StrideType, typename StorageKind>
class CwiseUnaryViewImpl : public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type {
template <typename ViewOp, typename XprType, typename StrideType, typename StorageKind, bool Mutable>
class CwiseUnaryViewImpl : public generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type {
public:
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type Base;
typedef typename generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type Base;
};
template <typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense>
: public internal::dense_xpr_base<CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type {
class CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense, false>
: public dense_xpr_base<CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type {
public:
typedef CwiseUnaryView<ViewOp, MatrixType, StrideType> Derived;
typedef typename internal::dense_xpr_base<CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type Base;
typedef typename dense_xpr_base<CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); }
EIGEN_DEVICE_FUNC inline const Scalar* data() const { return &(this->coeffRef(0)); }
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const {
return StrideType::InnerStrideAtCompileTime != 0
? int(StrideType::InnerStrideAtCompileTime)
: derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) /
sizeof(Scalar);
return StrideType::InnerStrideAtCompileTime != 0 ? int(StrideType::InnerStrideAtCompileTime)
: derived().nestedExpression().innerStride() *
sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const {
return StrideType::OuterStrideAtCompileTime != 0
? int(StrideType::OuterStrideAtCompileTime)
: derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) /
sizeof(Scalar);
return StrideType::OuterStrideAtCompileTime != 0 ? int(StrideType::OuterStrideAtCompileTime)
: derived().nestedExpression().outerStride() *
sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar);
}
protected:
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)
// Allow const access to coeffRef for the case of direct access being enabled.
EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const {
return internal::evaluator<Derived>(derived()).coeffRef(index);
}
EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index row, Index col) const {
return internal::evaluator<Derived>(derived()).coeffRef(row, col);
}
};
template <typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense, true>
: public CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense, false> {
public:
typedef CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense, false> Base;
typedef CwiseUnaryView<ViewOp, MatrixType, StrideType> Derived;
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
using Base::data;
EIGEN_DEVICE_FUNC inline Scalar* data() { return &(this->coeffRef(0)); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
return internal::evaluator<Derived>(derived()).coeffRef(row, col);
}
@ -142,17 +165,10 @@ class CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, Dense>
protected:
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)
// Allow const access to coeffRef for the case of direct access being enabled.
EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const {
return const_cast<CwiseUnaryViewImpl*>(this)->coeffRef(index);
}
EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index row, Index col) const {
return const_cast<CwiseUnaryViewImpl*>(this)->coeffRef(row, col);
}
};
} // end namespace Eigen
} // namespace internal
} // namespace Eigen
#endif // EIGEN_CWISE_UNARY_VIEW_H