bug #1286: automatically detect the available prototypes of functors passed to CwiseNullaryExpr such that functors have only to implement the operators that matters among:

operator()()
 operator()(i)
 operator()(i,j)
Linear access is also automatically detected based on the availability of operator()(i,j).
This commit is contained in:
Gael Guennebaud 2016-08-31 15:45:25 +02:00
parent efe2c225c9
commit 218c37beb4
7 changed files with 145 additions and 70 deletions

View File

@ -337,6 +337,56 @@ protected:
// Like Matrix and Array, this is not really a unary expression, so we directly specialize evaluator. // Like Matrix and Array, this is not really a unary expression, so we directly specialize evaluator.
// Likewise, there is not need to more sophisticated dispatching here. // Likewise, there is not need to more sophisticated dispatching here.
template<typename Scalar,typename NullaryOp,
bool has_nullary = has_nullary_operator<NullaryOp>::value,
bool has_unary = has_unary_operator<NullaryOp>::value,
bool has_binary = has_binary_operator<NullaryOp>::value>
struct nullary_wrapper
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j) const { return op(i,j); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i) const { return op(i); }
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j) const { return op.template packetOp<T>(i,j); }
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i) const { return op.template packetOp<T>(i); }
};
template<typename Scalar,typename NullaryOp>
struct nullary_wrapper<Scalar,NullaryOp,true,false,false>
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, ...) const { return op(); }
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, ...) const { return op.template packetOp<T>(); }
};
template<typename Scalar,typename NullaryOp>
struct nullary_wrapper<Scalar,NullaryOp,false,false,true>
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j=0) const { return op(i,j); }
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j=0) const { return op.template packetOp<T>(i,j); }
};
// We need the following specialization for vector-only functors assigned to a runtime vector,
// for instance, using linspace and assigning a RowVectorXd to a MatrixXd or even a row of a MatrixXd.
// In this case, i==0 and j is used for the actual iteration.
template<typename Scalar,typename NullaryOp>
struct nullary_wrapper<Scalar,NullaryOp,false,true,false>
: nullary_wrapper<Scalar,NullaryOp,false,true,true> // to get the identity wrapper
{
typedef nullary_wrapper<Scalar,NullaryOp,false,true,true> base;
using base::operator();
using base::packetOp;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const NullaryOp& op, Index i, Index j) const {
eigen_assert(i==0 || j==0);
return op(i+j);
}
template <typename T> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const NullaryOp& op, Index i, Index j) const {
eigen_assert(i==0 || j==0);
return op.template packetOp<T>(i+j);
}
};
template<typename Scalar,typename NullaryOp>
struct nullary_wrapper<Scalar,NullaryOp,false,false,false> {};
template<typename NullaryOp, typename PlainObjectType> template<typename NullaryOp, typename PlainObjectType>
struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> > struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> >
: evaluator_base<CwiseNullaryOp<NullaryOp,PlainObjectType> > : evaluator_base<CwiseNullaryOp<NullaryOp,PlainObjectType> >
@ -356,7 +406,7 @@ struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> >
}; };
EIGEN_DEVICE_FUNC explicit evaluator(const XprType& n) EIGEN_DEVICE_FUNC explicit evaluator(const XprType& n)
: m_functor(n.functor()) : m_functor(n.functor()), m_wrapper()
{ {
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
} }
@ -366,31 +416,32 @@ struct evaluator<CwiseNullaryOp<NullaryOp,PlainObjectType> >
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index row, Index col) const CoeffReturnType coeff(Index row, Index col) const
{ {
return m_functor(row, col); return m_wrapper(m_functor, row, col);
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const CoeffReturnType coeff(Index index) const
{ {
return m_functor(index); return m_wrapper(m_functor,index);
} }
template<int LoadMode, typename PacketType> template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE EIGEN_STRONG_INLINE
PacketType packet(Index row, Index col) const PacketType packet(Index row, Index col) const
{ {
return m_functor.template packetOp<Index,PacketType>(row, col); return m_wrapper.template packetOp<PacketType>(m_functor,row, col);
} }
template<int LoadMode, typename PacketType> template<int LoadMode, typename PacketType>
EIGEN_STRONG_INLINE EIGEN_STRONG_INLINE
PacketType packet(Index index) const PacketType packet(Index index) const
{ {
return m_functor.template packetOp<Index,PacketType>(index); return m_wrapper.template packetOp<PacketType>(m_functor,index);
} }
protected: protected:
const NullaryOp m_functor; const NullaryOp m_functor;
const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper;
}; };
// -------------------- CwiseUnaryOp -------------------- // -------------------- CwiseUnaryOp --------------------

View File

@ -20,7 +20,8 @@ struct traits<CwiseNullaryOp<NullaryOp, PlainObjectType> > : traits<PlainObjectT
Flags = traits<PlainObjectType>::Flags & RowMajorBit Flags = traits<PlainObjectType>::Flags & RowMajorBit
}; };
}; };
}
} // namespace internal
/** \class CwiseNullaryOp /** \class CwiseNullaryOp
* \ingroup Core_Module * \ingroup Core_Module
@ -70,30 +71,6 @@ class CwiseNullaryOp : public internal::dense_xpr_base< CwiseNullaryOp<NullaryOp
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index cols() const { return m_cols.value(); } EIGEN_STRONG_INLINE Index cols() const { return m_cols.value(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar coeff(Index rowId, Index colId) const
{
return m_functor(rowId, colId);
}
template<int LoadMode>
EIGEN_STRONG_INLINE PacketScalar packet(Index rowId, Index colId) const
{
return m_functor.packetOp(rowId, colId);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
{
return m_functor(index);
}
template<int LoadMode>
EIGEN_STRONG_INLINE PacketScalar packet(Index index) const
{
return m_functor.packetOp(index);
}
/** \returns the functor representing the nullary operation */ /** \returns the functor representing the nullary operation */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const NullaryOp& functor() const { return m_functor; } const NullaryOp& functor() const { return m_functor; }

View File

@ -16,8 +16,7 @@ namespace internal {
template<typename Scalar> struct scalar_random_op { template<typename Scalar> struct scalar_random_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_random_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_random_op)
template<typename Index> inline const Scalar operator() () const { return random<Scalar>(); }
inline const Scalar operator() (Index, Index = 0) const { return random<Scalar>(); }
}; };
template<typename Scalar> template<typename Scalar>

View File

@ -18,10 +18,9 @@ template<typename Scalar>
struct scalar_constant_op { struct scalar_constant_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_constant_op(const scalar_constant_op& other) : m_other(other.m_other) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_constant_op(const scalar_constant_op& other) : m_other(other.m_other) { }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_constant_op(const Scalar& other) : m_other(other) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_constant_op(const Scalar& other) : m_other(other) { }
template<typename Index> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() () const { return m_other; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (Index, Index = 0) const { return m_other; } template<typename PacketType>
template<typename Index, typename PacketType> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packetOp() const { return internal::pset1<PacketType>(m_other); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType packetOp(Index, Index = 0) const { return internal::pset1<PacketType>(m_other); }
const Scalar m_other; const Scalar m_other;
}; };
template<typename Scalar> template<typename Scalar>
@ -146,27 +145,9 @@ template <typename Scalar, typename PacketType, bool RandomAccess> struct linspa
template<typename Index> template<typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (Index i) const { return impl(i); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (Index i) const { return impl(i); }
// We need this function when assigning e.g. a RowVectorXd to a MatrixXd since template<typename Packet,typename Index>
// there row==0 and col is used for the actual iteration.
template<typename Index>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator() (Index row, Index col) const
{
eigen_assert(col==0 || row==0);
return impl(col + row);
}
template<typename Index, typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(Index i) const { return impl.packetOp(i); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(Index i) const { return impl.packetOp(i); }
// We need this function when assigning e.g. a RowVectorXd to a MatrixXd since
// there row==0 and col is used for the actual iteration.
template<typename Index, typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(Index row, Index col) const
{
eigen_assert(col==0 || row==0);
return impl.packetOp(col + row);
}
// This proxy object handles the actual required temporaries, the different // This proxy object handles the actual required temporaries, the different
// implementations (random vs. sequential access) as well as the // implementations (random vs. sequential access) as well as the
// correct piping to size 2/4 packet operations. // correct piping to size 2/4 packet operations.
@ -175,11 +156,11 @@ template <typename Scalar, typename PacketType, bool RandomAccess> struct linspa
const linspaced_op_impl<Scalar,PacketType,(NumTraits<Scalar>::IsInteger?true:RandomAccess),NumTraits<Scalar>::IsInteger> impl; const linspaced_op_impl<Scalar,PacketType,(NumTraits<Scalar>::IsInteger?true:RandomAccess),NumTraits<Scalar>::IsInteger> impl;
}; };
// all functors allow linear access, except scalar_identity_op. So we fix here a quick meta // Linear access is automatically determined from the operator() prototypes available for the given functor.
// to indicate whether a functor allows linear access, just always answering 'yes' except for // If it exposes an operator()(i,j), then we assume the i and j coefficients are required independently
// scalar_identity_op. // and linear access is not possible. In all other cases, linear access is enabled.
template<typename Functor> struct functor_has_linear_access { enum { ret = 1 }; }; // Users should not have to deal with this struture.
template<typename Scalar> struct functor_has_linear_access<scalar_identity_op<Scalar> > { enum { ret = 0 }; }; template<typename Functor> struct functor_has_linear_access { enum { ret = !has_binary_operator<Functor>::value }; };
} // end namespace internal } // end namespace internal

View File

@ -22,6 +22,16 @@
namespace Eigen { namespace Eigen {
typedef EIGEN_DEFAULT_DENSE_INDEX_TYPE DenseIndex;
/**
* \brief The Index type as used for the API.
* \details To change this, \c \#define the preprocessor symbol \c EIGEN_DEFAULT_DENSE_INDEX_TYPE.
* \sa \blank \ref TopicPreprocessorDirectives, StorageIndex.
*/
typedef EIGEN_DEFAULT_DENSE_INDEX_TYPE Index;
namespace internal { namespace internal {
/** \internal /** \internal
@ -371,6 +381,39 @@ struct has_ReturnType
enum { value = sizeof(testFunctor<T>(0)) == sizeof(yes) }; enum { value = sizeof(testFunctor<T>(0)) == sizeof(yes) };
}; };
template<int> struct any_int {};
template<typename T> const T& return_ref();
struct meta_yes { char data[1]; };
struct meta_no { char data[2]; };
template <typename T>
struct has_nullary_operator
{
template <typename C> static meta_yes testFunctor(C const *,any_int< sizeof(return_ref<C>()()) > * = 0);
static meta_no testFunctor(...);
enum { value = sizeof(testFunctor(static_cast<T*>(0))) == sizeof(meta_yes) };
};
template <typename T>
struct has_unary_operator
{
template <typename C> static meta_yes testFunctor(C const *,any_int< sizeof(return_ref<C>()(Index(0))) > * = 0);
static meta_no testFunctor(...);
enum { value = sizeof(testFunctor(static_cast<T*>(0))) == sizeof(meta_yes) };
};
template <typename T>
struct has_binary_operator
{
template <typename C> static meta_yes testFunctor(C const *,any_int< sizeof(return_ref<C>()(Index(0),Index(0))) > * = 0);
static meta_no testFunctor(...);
enum { value = sizeof(testFunctor(static_cast<T*>(0))) == sizeof(meta_yes) };
};
/** \internal In short, it computes int(sqrt(\a Y)) with \a Y an integer. /** \internal In short, it computes int(sqrt(\a Y)) with \a Y an integer.
* Usage example: \code meta_sqrt<1023>::ret \endcode * Usage example: \code meta_sqrt<1023>::ret \endcode
*/ */

View File

@ -24,16 +24,6 @@
namespace Eigen { namespace Eigen {
typedef EIGEN_DEFAULT_DENSE_INDEX_TYPE DenseIndex;
/**
* \brief The Index type as used for the API.
* \details To change this, \c \#define the preprocessor symbol \c EIGEN_DEFAULT_DENSE_INDEX_TYPE.
* \sa \blank \ref TopicPreprocessorDirectives, StorageIndex.
*/
typedef EIGEN_DEFAULT_DENSE_INDEX_TYPE Index;
namespace internal { namespace internal {
template<typename IndexDest, typename IndexSrc> template<typename IndexDest, typename IndexSrc>

View File

@ -104,13 +104,29 @@ void testVectorType(const VectorType& base)
template<typename MatrixType> template<typename MatrixType>
void testMatrixType(const MatrixType& m) void testMatrixType(const MatrixType& m)
{ {
using std::abs;
const Index rows = m.rows(); const Index rows = m.rows();
const Index cols = m.cols(); const Index cols = m.cols();
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar;
Scalar s1;
do {
s1 = internal::random<Scalar>();
} while(abs(s1)<RealScalar(1e-5) && (!NumTraits<Scalar>::IsInteger));
MatrixType A; MatrixType A;
A.setIdentity(rows, cols); A.setIdentity(rows, cols);
VERIFY(equalsIdentity(A)); VERIFY(equalsIdentity(A));
VERIFY(equalsIdentity(MatrixType::Identity(rows, cols))); VERIFY(equalsIdentity(MatrixType::Identity(rows, cols)));
A = MatrixType::Constant(rows,cols,s1);
Index i = internal::random<Index>(0,rows-1);
Index j = internal::random<Index>(0,cols-1);
VERIFY_IS_APPROX( MatrixType::Constant(rows,cols,s1)(i,j), s1 );
VERIFY_IS_APPROX( MatrixType::Constant(rows,cols,s1).coeff(i,j), s1 );
VERIFY_IS_APPROX( A(i,j), s1 );
} }
void test_nullary() void test_nullary()
@ -137,4 +153,22 @@ void test_nullary()
// Assignment of a RowVectorXd to a MatrixXd (regression test for bug #79). // Assignment of a RowVectorXd to a MatrixXd (regression test for bug #79).
VERIFY( (MatrixXd(RowVectorXd::LinSpaced(3, 0, 1)) - RowVector3d(0, 0.5, 1)).norm() < std::numeric_limits<double>::epsilon() ); VERIFY( (MatrixXd(RowVectorXd::LinSpaced(3, 0, 1)) - RowVector3d(0, 0.5, 1)).norm() < std::numeric_limits<double>::epsilon() );
#endif #endif
#ifdef EIGEN_TEST_PART_10
// check some internal logic
VERIFY(( internal::has_nullary_operator<internal::scalar_constant_op<double> >::value ));
VERIFY(( !internal::has_unary_operator<internal::scalar_constant_op<double> >::value ));
VERIFY(( !internal::has_binary_operator<internal::scalar_constant_op<double> >::value ));
VERIFY(( internal::functor_has_linear_access<internal::scalar_constant_op<double> >::ret ));
VERIFY(( !internal::has_nullary_operator<internal::scalar_identity_op<double> >::value ));
VERIFY(( !internal::has_unary_operator<internal::scalar_identity_op<double> >::value ));
VERIFY(( internal::has_binary_operator<internal::scalar_identity_op<double> >::value ));
VERIFY(( !internal::functor_has_linear_access<internal::scalar_identity_op<double> >::ret ));
VERIFY(( !internal::has_nullary_operator<internal::linspaced_op<float,float,false> >::value ));
VERIFY(( internal::has_unary_operator<internal::linspaced_op<float,float,false> >::value ));
VERIFY(( !internal::has_binary_operator<internal::linspaced_op<float,float,false> >::value ));
VERIFY(( internal::functor_has_linear_access<internal::linspaced_op<float,float,false> >::ret ));
#endif
} }