Allow specifying inner & outer stride for CWiseUnaryView - fixes #2398

This commit is contained in:
Andrew Johnson 2022-01-05 19:24:46 +00:00 committed by Rasmus Munk Larsen
parent 27a78e4f96
commit a491c7f898
6 changed files with 79 additions and 27 deletions

View File

@ -812,11 +812,11 @@ protected:
// -------------------- CwiseUnaryView -------------------- // -------------------- CwiseUnaryView --------------------
template<typename UnaryOp, typename ArgType> template<typename UnaryOp, typename ArgType, typename StrideType>
struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType>, IndexBased> struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType, StrideType>, IndexBased>
: evaluator_base<CwiseUnaryView<UnaryOp, ArgType> > : evaluator_base<CwiseUnaryView<UnaryOp, ArgType, StrideType> >
{ {
typedef CwiseUnaryView<UnaryOp, ArgType> XprType; typedef CwiseUnaryView<UnaryOp, ArgType, StrideType> XprType;
enum { enum {
CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<UnaryOp>::Cost), CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<UnaryOp>::Cost),

View File

@ -15,8 +15,8 @@
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
template<typename ViewOp, typename MatrixType> template<typename ViewOp, typename MatrixType, typename StrideType>
struct traits<CwiseUnaryView<ViewOp, MatrixType> > struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> >
: traits<MatrixType> : traits<MatrixType>
{ {
typedef typename result_of< typedef typename result_of<
@ -30,17 +30,22 @@ struct traits<CwiseUnaryView<ViewOp, MatrixType> >
MatrixTypeInnerStride = inner_stride_at_compile_time<MatrixType>::ret, MatrixTypeInnerStride = inner_stride_at_compile_time<MatrixType>::ret,
// need to cast the sizeof's from size_t to int explicitly, otherwise: // need to cast the sizeof's from size_t to int explicitly, otherwise:
// "error: no integral type can represent all of the enumerator values // "error: no integral type can represent all of the enumerator values
InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic InnerStrideAtCompileTime = StrideType::InnerStrideAtCompileTime == 0
? int(Dynamic) ? (MatrixTypeInnerStride == Dynamic
: int(MatrixTypeInnerStride) * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)), ? int(Dynamic)
OuterStrideAtCompileTime = outer_stride_at_compile_time<MatrixType>::ret == Dynamic : int(MatrixTypeInnerStride) * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)))
? int(Dynamic) : int(StrideType::InnerStrideAtCompileTime),
: outer_stride_at_compile_time<MatrixType>::ret * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar))
OuterStrideAtCompileTime = StrideType::OuterStrideAtCompileTime == 0
? (outer_stride_at_compile_time<MatrixType>::ret == Dynamic
? int(Dynamic)
: outer_stride_at_compile_time<MatrixType>::ret * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)))
: int(StrideType::OuterStrideAtCompileTime)
}; };
}; };
} }
template<typename ViewOp, typename MatrixType, typename StorageKind> template<typename ViewOp, typename MatrixType, typename StrideType, typename StorageKind>
class CwiseUnaryViewImpl; class CwiseUnaryViewImpl;
/** \class CwiseUnaryView /** \class CwiseUnaryView
@ -56,12 +61,12 @@ class CwiseUnaryViewImpl;
* *
* \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp * \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp
*/ */
template<typename ViewOp, typename MatrixType> template<typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename internal::traits<MatrixType>::StorageKind> class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>
{ {
public: public:
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base; typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>::Base Base;
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView) EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested; typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
typedef typename internal::remove_all<MatrixType>::type NestedExpression; typedef typename internal::remove_all<MatrixType>::type NestedExpression;
@ -93,22 +98,22 @@ class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename in
}; };
// Generic API dispatcher // Generic API dispatcher
template<typename ViewOp, typename XprType, typename StorageKind> template<typename ViewOp, typename XprType, typename StrideType, typename StorageKind>
class CwiseUnaryViewImpl class CwiseUnaryViewImpl
: public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType> >::type : public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type
{ {
public: public:
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType> >::type Base; typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type Base;
}; };
template<typename ViewOp, typename MatrixType> template<typename ViewOp, typename MatrixType, typename StrideType>
class CwiseUnaryViewImpl<ViewOp,MatrixType,Dense> class CwiseUnaryViewImpl<ViewOp,MatrixType,StrideType,Dense>
: public internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType> >::type : public internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type
{ {
public: public:
typedef CwiseUnaryView<ViewOp, MatrixType> Derived; typedef CwiseUnaryView<ViewOp, MatrixType,StrideType> Derived;
typedef typename internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType> >::type Base; typedef typename internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType,StrideType> >::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Derived) EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
@ -118,12 +123,16 @@ class CwiseUnaryViewImpl<ViewOp,MatrixType,Dense>
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const
{ {
return derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar); return StrideType::InnerStrideAtCompileTime != 0
? int(StrideType::InnerStrideAtCompileTime)
: derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
} }
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const
{ {
return derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar); return StrideType::OuterStrideAtCompileTime != 0
? int(StrideType::OuterStrideAtCompileTime)
: derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
} }
protected: protected:
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl) EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)

View File

@ -78,7 +78,6 @@ template<typename MatrixType> class Transpose;
template<typename MatrixType> class Conjugate; template<typename MatrixType> class Conjugate;
template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp; template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp; template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp; template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> class CwiseTernaryOp; template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> class CwiseTernaryOp;
template<typename Decomposition, typename Rhstype> class Solve; template<typename Decomposition, typename Rhstype> class Solve;
@ -108,6 +107,7 @@ template<typename MatrixType, int MapOptions=Unaligned, typename StrideType = St
template<typename Derived> class RefBase; template<typename Derived> class RefBase;
template<typename PlainObjectType, int Options = 0, template<typename PlainObjectType, int Options = 0,
typename StrideType = typename internal::conditional<PlainObjectType::IsVectorAtCompileTime,InnerStride<1>,OuterStride<> >::type > class Ref; typename StrideType = typename internal::conditional<PlainObjectType::IsVectorAtCompileTime,InnerStride<1>,OuterStride<> >::type > class Ref;
template<typename ViewOp, typename MatrixType, typename StrideType = Stride<0,0>> class CwiseUnaryView;
template<typename Derived> class TriangularBase; template<typename Derived> class TriangularBase;
template<typename MatrixType, unsigned int Mode> class TriangularView; template<typename MatrixType, unsigned int Mode> class TriangularView;

View File

@ -61,6 +61,9 @@ set(ei_smoke_test_list
mapped_matrix_1 mapped_matrix_1
mapstaticmethods_1 mapstaticmethods_1
mapstride_1 mapstride_1
unaryviewstride_1
unaryviewstride_2
unaryviewstride_3
matrix_square_root_1 matrix_square_root_1
meta meta
minres_2 minres_2

View File

@ -194,6 +194,7 @@ ei_add_test(commainitializer)
ei_add_test(smallvectors) ei_add_test(smallvectors)
ei_add_test(mapped_matrix) ei_add_test(mapped_matrix)
ei_add_test(mapstride) ei_add_test(mapstride)
ei_add_test(unaryviewstride)
ei_add_test(mapstaticmethods) ei_add_test(mapstaticmethods)
ei_add_test(array_cwise) ei_add_test(array_cwise)
ei_add_test(array_for_matrix) ei_add_test(array_for_matrix)

39
test/unaryviewstride.cpp Normal file
View File

@ -0,0 +1,39 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2021 Andrew Johnson <andrew.johnson@arjohnsonau.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
template<int OuterStride,int InnerStride,typename VectorType> void unaryview_stride(const VectorType& m)
{
typedef typename VectorType::Scalar Scalar;
Index rows = m.rows();
Index cols = m.cols();
VectorType vec = VectorType::Random(rows, cols);
struct view_op {
EIGEN_EMPTY_STRUCT_CTOR(view_op)
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar&
operator()(const Scalar& v) const { return v; }
};
CwiseUnaryView<view_op, VectorType, Stride<OuterStride,InnerStride>> vec_view(vec);
VERIFY(vec_view.outerStride() == (OuterStride == 0 ? 0 : OuterStride));
VERIFY(vec_view.innerStride() == (InnerStride == 0 ? 1 : InnerStride));
}
EIGEN_DECLARE_TEST(unaryviewstride)
{
CALL_SUBTEST_1(( unaryview_stride<1,2>(MatrixXf()) ));
CALL_SUBTEST_1(( unaryview_stride<0,0>(MatrixXf()) ));
CALL_SUBTEST_2(( unaryview_stride<1,2>(VectorXf()) ));
CALL_SUBTEST_2(( unaryview_stride<0,0>(VectorXf()) ));
CALL_SUBTEST_3(( unaryview_stride<1,2>(RowVectorXf()) ));
CALL_SUBTEST_3(( unaryview_stride<0,0>(RowVectorXf()) ));
}