mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-20 16:49:38 +08:00
Allow specifying inner & outer stride for CWiseUnaryView - fixes #2398
This commit is contained in:
parent
27a78e4f96
commit
a491c7f898
@ -812,11 +812,11 @@ protected:
|
|||||||
|
|
||||||
// -------------------- CwiseUnaryView --------------------
|
// -------------------- CwiseUnaryView --------------------
|
||||||
|
|
||||||
template<typename UnaryOp, typename ArgType>
|
template<typename UnaryOp, typename ArgType, typename StrideType>
|
||||||
struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType>, IndexBased>
|
struct unary_evaluator<CwiseUnaryView<UnaryOp, ArgType, StrideType>, IndexBased>
|
||||||
: evaluator_base<CwiseUnaryView<UnaryOp, ArgType> >
|
: evaluator_base<CwiseUnaryView<UnaryOp, ArgType, StrideType> >
|
||||||
{
|
{
|
||||||
typedef CwiseUnaryView<UnaryOp, ArgType> XprType;
|
typedef CwiseUnaryView<UnaryOp, ArgType, StrideType> XprType;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<UnaryOp>::Cost),
|
CoeffReadCost = int(evaluator<ArgType>::CoeffReadCost) + int(functor_traits<UnaryOp>::Cost),
|
||||||
|
@ -15,8 +15,8 @@
|
|||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template<typename ViewOp, typename MatrixType>
|
template<typename ViewOp, typename MatrixType, typename StrideType>
|
||||||
struct traits<CwiseUnaryView<ViewOp, MatrixType> >
|
struct traits<CwiseUnaryView<ViewOp, MatrixType, StrideType> >
|
||||||
: traits<MatrixType>
|
: traits<MatrixType>
|
||||||
{
|
{
|
||||||
typedef typename result_of<
|
typedef typename result_of<
|
||||||
@ -30,17 +30,22 @@ struct traits<CwiseUnaryView<ViewOp, MatrixType> >
|
|||||||
MatrixTypeInnerStride = inner_stride_at_compile_time<MatrixType>::ret,
|
MatrixTypeInnerStride = inner_stride_at_compile_time<MatrixType>::ret,
|
||||||
// need to cast the sizeof's from size_t to int explicitly, otherwise:
|
// need to cast the sizeof's from size_t to int explicitly, otherwise:
|
||||||
// "error: no integral type can represent all of the enumerator values
|
// "error: no integral type can represent all of the enumerator values
|
||||||
InnerStrideAtCompileTime = MatrixTypeInnerStride == Dynamic
|
InnerStrideAtCompileTime = StrideType::InnerStrideAtCompileTime == 0
|
||||||
? int(Dynamic)
|
? (MatrixTypeInnerStride == Dynamic
|
||||||
: int(MatrixTypeInnerStride) * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)),
|
? int(Dynamic)
|
||||||
OuterStrideAtCompileTime = outer_stride_at_compile_time<MatrixType>::ret == Dynamic
|
: int(MatrixTypeInnerStride) * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)))
|
||||||
? int(Dynamic)
|
: int(StrideType::InnerStrideAtCompileTime),
|
||||||
: outer_stride_at_compile_time<MatrixType>::ret * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar))
|
|
||||||
|
OuterStrideAtCompileTime = StrideType::OuterStrideAtCompileTime == 0
|
||||||
|
? (outer_stride_at_compile_time<MatrixType>::ret == Dynamic
|
||||||
|
? int(Dynamic)
|
||||||
|
: outer_stride_at_compile_time<MatrixType>::ret * int(sizeof(typename traits<MatrixType>::Scalar) / sizeof(Scalar)))
|
||||||
|
: int(StrideType::OuterStrideAtCompileTime)
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename ViewOp, typename MatrixType, typename StorageKind>
|
template<typename ViewOp, typename MatrixType, typename StrideType, typename StorageKind>
|
||||||
class CwiseUnaryViewImpl;
|
class CwiseUnaryViewImpl;
|
||||||
|
|
||||||
/** \class CwiseUnaryView
|
/** \class CwiseUnaryView
|
||||||
@ -56,12 +61,12 @@ class CwiseUnaryViewImpl;
|
|||||||
*
|
*
|
||||||
* \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp
|
* \sa MatrixBase::unaryViewExpr(const CustomUnaryOp &) const, class CwiseUnaryOp
|
||||||
*/
|
*/
|
||||||
template<typename ViewOp, typename MatrixType>
|
template<typename ViewOp, typename MatrixType, typename StrideType>
|
||||||
class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename internal::traits<MatrixType>::StorageKind>
|
class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType,typename internal::traits<MatrixType>::StorageKind>::Base Base;
|
typedef typename CwiseUnaryViewImpl<ViewOp, MatrixType, StrideType, typename internal::traits<MatrixType>::StorageKind>::Base Base;
|
||||||
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
|
EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseUnaryView)
|
||||||
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
|
typedef typename internal::ref_selector<MatrixType>::non_const_type MatrixTypeNested;
|
||||||
typedef typename internal::remove_all<MatrixType>::type NestedExpression;
|
typedef typename internal::remove_all<MatrixType>::type NestedExpression;
|
||||||
@ -93,22 +98,22 @@ class CwiseUnaryView : public CwiseUnaryViewImpl<ViewOp, MatrixType, typename in
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Generic API dispatcher
|
// Generic API dispatcher
|
||||||
template<typename ViewOp, typename XprType, typename StorageKind>
|
template<typename ViewOp, typename XprType, typename StrideType, typename StorageKind>
|
||||||
class CwiseUnaryViewImpl
|
class CwiseUnaryViewImpl
|
||||||
: public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType> >::type
|
: public internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType> >::type Base;
|
typedef typename internal::generic_xpr_base<CwiseUnaryView<ViewOp, XprType, StrideType> >::type Base;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename ViewOp, typename MatrixType>
|
template<typename ViewOp, typename MatrixType, typename StrideType>
|
||||||
class CwiseUnaryViewImpl<ViewOp,MatrixType,Dense>
|
class CwiseUnaryViewImpl<ViewOp,MatrixType,StrideType,Dense>
|
||||||
: public internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType> >::type
|
: public internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType, StrideType> >::type
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef CwiseUnaryView<ViewOp, MatrixType> Derived;
|
typedef CwiseUnaryView<ViewOp, MatrixType,StrideType> Derived;
|
||||||
typedef typename internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType> >::type Base;
|
typedef typename internal::dense_xpr_base< CwiseUnaryView<ViewOp, MatrixType,StrideType> >::type Base;
|
||||||
|
|
||||||
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
|
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
|
||||||
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
|
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CwiseUnaryViewImpl)
|
||||||
@ -118,12 +123,16 @@ class CwiseUnaryViewImpl<ViewOp,MatrixType,Dense>
|
|||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const
|
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index innerStride() const
|
||||||
{
|
{
|
||||||
return derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
|
return StrideType::InnerStrideAtCompileTime != 0
|
||||||
|
? int(StrideType::InnerStrideAtCompileTime)
|
||||||
|
: derived().nestedExpression().innerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const
|
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR inline Index outerStride() const
|
||||||
{
|
{
|
||||||
return derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
|
return StrideType::OuterStrideAtCompileTime != 0
|
||||||
|
? int(StrideType::OuterStrideAtCompileTime)
|
||||||
|
: derived().nestedExpression().outerStride() * sizeof(typename internal::traits<MatrixType>::Scalar) / sizeof(Scalar);
|
||||||
}
|
}
|
||||||
protected:
|
protected:
|
||||||
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)
|
EIGEN_DEFAULT_EMPTY_CONSTRUCTOR_AND_DESTRUCTOR(CwiseUnaryViewImpl)
|
||||||
|
@ -78,7 +78,6 @@ template<typename MatrixType> class Transpose;
|
|||||||
template<typename MatrixType> class Conjugate;
|
template<typename MatrixType> class Conjugate;
|
||||||
template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
|
template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
|
||||||
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
|
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
|
||||||
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
|
|
||||||
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
|
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
|
||||||
template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> class CwiseTernaryOp;
|
template<typename TernaryOp, typename Arg1, typename Arg2, typename Arg3> class CwiseTernaryOp;
|
||||||
template<typename Decomposition, typename Rhstype> class Solve;
|
template<typename Decomposition, typename Rhstype> class Solve;
|
||||||
@ -108,6 +107,7 @@ template<typename MatrixType, int MapOptions=Unaligned, typename StrideType = St
|
|||||||
template<typename Derived> class RefBase;
|
template<typename Derived> class RefBase;
|
||||||
template<typename PlainObjectType, int Options = 0,
|
template<typename PlainObjectType, int Options = 0,
|
||||||
typename StrideType = typename internal::conditional<PlainObjectType::IsVectorAtCompileTime,InnerStride<1>,OuterStride<> >::type > class Ref;
|
typename StrideType = typename internal::conditional<PlainObjectType::IsVectorAtCompileTime,InnerStride<1>,OuterStride<> >::type > class Ref;
|
||||||
|
template<typename ViewOp, typename MatrixType, typename StrideType = Stride<0,0>> class CwiseUnaryView;
|
||||||
|
|
||||||
template<typename Derived> class TriangularBase;
|
template<typename Derived> class TriangularBase;
|
||||||
template<typename MatrixType, unsigned int Mode> class TriangularView;
|
template<typename MatrixType, unsigned int Mode> class TriangularView;
|
||||||
|
@ -61,6 +61,9 @@ set(ei_smoke_test_list
|
|||||||
mapped_matrix_1
|
mapped_matrix_1
|
||||||
mapstaticmethods_1
|
mapstaticmethods_1
|
||||||
mapstride_1
|
mapstride_1
|
||||||
|
unaryviewstride_1
|
||||||
|
unaryviewstride_2
|
||||||
|
unaryviewstride_3
|
||||||
matrix_square_root_1
|
matrix_square_root_1
|
||||||
meta
|
meta
|
||||||
minres_2
|
minres_2
|
||||||
|
@ -194,6 +194,7 @@ ei_add_test(commainitializer)
|
|||||||
ei_add_test(smallvectors)
|
ei_add_test(smallvectors)
|
||||||
ei_add_test(mapped_matrix)
|
ei_add_test(mapped_matrix)
|
||||||
ei_add_test(mapstride)
|
ei_add_test(mapstride)
|
||||||
|
ei_add_test(unaryviewstride)
|
||||||
ei_add_test(mapstaticmethods)
|
ei_add_test(mapstaticmethods)
|
||||||
ei_add_test(array_cwise)
|
ei_add_test(array_cwise)
|
||||||
ei_add_test(array_for_matrix)
|
ei_add_test(array_for_matrix)
|
||||||
|
39
test/unaryviewstride.cpp
Normal file
39
test/unaryviewstride.cpp
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2021 Andrew Johnson <andrew.johnson@arjohnsonau.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#include "main.h"
|
||||||
|
|
||||||
|
template<int OuterStride,int InnerStride,typename VectorType> void unaryview_stride(const VectorType& m)
|
||||||
|
{
|
||||||
|
typedef typename VectorType::Scalar Scalar;
|
||||||
|
Index rows = m.rows();
|
||||||
|
Index cols = m.cols();
|
||||||
|
VectorType vec = VectorType::Random(rows, cols);
|
||||||
|
|
||||||
|
struct view_op {
|
||||||
|
EIGEN_EMPTY_STRUCT_CTOR(view_op)
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar&
|
||||||
|
operator()(const Scalar& v) const { return v; }
|
||||||
|
};
|
||||||
|
|
||||||
|
CwiseUnaryView<view_op, VectorType, Stride<OuterStride,InnerStride>> vec_view(vec);
|
||||||
|
VERIFY(vec_view.outerStride() == (OuterStride == 0 ? 0 : OuterStride));
|
||||||
|
VERIFY(vec_view.innerStride() == (InnerStride == 0 ? 1 : InnerStride));
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DECLARE_TEST(unaryviewstride)
|
||||||
|
{
|
||||||
|
CALL_SUBTEST_1(( unaryview_stride<1,2>(MatrixXf()) ));
|
||||||
|
CALL_SUBTEST_1(( unaryview_stride<0,0>(MatrixXf()) ));
|
||||||
|
CALL_SUBTEST_2(( unaryview_stride<1,2>(VectorXf()) ));
|
||||||
|
CALL_SUBTEST_2(( unaryview_stride<0,0>(VectorXf()) ));
|
||||||
|
CALL_SUBTEST_3(( unaryview_stride<1,2>(RowVectorXf()) ));
|
||||||
|
CALL_SUBTEST_3(( unaryview_stride<0,0>(RowVectorXf()) ));
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user