mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Properly implement PartialReduxExpr on top of evaluators, and fix multiple evaluation of nested expression
This commit is contained in:
parent
5cc7251188
commit
aa6b1aebf3
@ -965,17 +965,16 @@ protected:
|
|||||||
|
|
||||||
|
|
||||||
// -------------------- PartialReduxExpr --------------------
|
// -------------------- PartialReduxExpr --------------------
|
||||||
//
|
|
||||||
// This is a wrapper around the expression object.
|
|
||||||
// TODO: Find out how to write a proper evaluator without duplicating
|
|
||||||
// the row() and col() member functions.
|
|
||||||
|
|
||||||
template< typename ArgType, typename MemberOp, int Direction>
|
template< typename ArgType, typename MemberOp, int Direction>
|
||||||
struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
|
struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
|
||||||
: evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> >
|
: evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> >
|
||||||
{
|
{
|
||||||
typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
|
typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
|
||||||
typedef typename XprType::Scalar InputScalar;
|
typedef typename internal::nested_eval<ArgType,1>::type ArgTypeNested;
|
||||||
|
typedef typename internal::remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
|
||||||
|
typedef typename ArgType::Scalar InputScalar;
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
enum {
|
enum {
|
||||||
TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(XprType::ColsAtCompileTime)
|
TraversalSize = Direction==int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(XprType::ColsAtCompileTime)
|
||||||
};
|
};
|
||||||
@ -986,27 +985,34 @@ struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
|
|||||||
|
|
||||||
Flags = (traits<XprType>::Flags&RowMajorBit) | (evaluator<ArgType>::Flags&HereditaryBits),
|
Flags = (traits<XprType>::Flags&RowMajorBit) | (evaluator<ArgType>::Flags&HereditaryBits),
|
||||||
|
|
||||||
Alignment = 0 // FIXME this could be improved
|
Alignment = 0 // FIXME this will need to be improved once PartialReduxExpr is vectorized
|
||||||
};
|
};
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC explicit evaluator(const XprType expr)
|
EIGEN_DEVICE_FUNC explicit evaluator(const XprType xpr)
|
||||||
: m_expr(expr)
|
: m_arg(xpr.nestedExpression()), m_functor(xpr.functor())
|
||||||
{}
|
{}
|
||||||
|
|
||||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index row, Index col) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const
|
||||||
{
|
{
|
||||||
return m_expr.coeff(row, col);
|
if (Direction==Vertical)
|
||||||
|
return m_functor(m_arg.col(j));
|
||||||
|
else
|
||||||
|
return m_functor(m_arg.row(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
|
||||||
{
|
{
|
||||||
return m_expr.coeff(index);
|
if (Direction==Vertical)
|
||||||
|
return m_functor(m_arg.col(index));
|
||||||
|
else
|
||||||
|
return m_functor(m_arg.row(index));
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const XprType m_expr;
|
const ArgTypeNested m_arg;
|
||||||
|
const MemberOp m_functor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,8 +41,6 @@ struct traits<PartialReduxExpr<MatrixType, MemberOp, Direction> >
|
|||||||
typedef typename traits<MatrixType>::StorageKind StorageKind;
|
typedef typename traits<MatrixType>::StorageKind StorageKind;
|
||||||
typedef typename traits<MatrixType>::XprKind XprKind;
|
typedef typename traits<MatrixType>::XprKind XprKind;
|
||||||
typedef typename MatrixType::Scalar InputScalar;
|
typedef typename MatrixType::Scalar InputScalar;
|
||||||
typedef typename ref_selector<MatrixType>::type MatrixTypeNested;
|
|
||||||
typedef typename remove_all<MatrixTypeNested>::type _MatrixTypeNested;
|
|
||||||
enum {
|
enum {
|
||||||
RowsAtCompileTime = Direction==Vertical ? 1 : MatrixType::RowsAtCompileTime,
|
RowsAtCompileTime = Direction==Vertical ? 1 : MatrixType::RowsAtCompileTime,
|
||||||
ColsAtCompileTime = Direction==Horizontal ? 1 : MatrixType::ColsAtCompileTime,
|
ColsAtCompileTime = Direction==Horizontal ? 1 : MatrixType::ColsAtCompileTime,
|
||||||
@ -62,8 +60,6 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri
|
|||||||
|
|
||||||
typedef typename internal::dense_xpr_base<PartialReduxExpr>::type Base;
|
typedef typename internal::dense_xpr_base<PartialReduxExpr>::type Base;
|
||||||
EIGEN_DENSE_PUBLIC_INTERFACE(PartialReduxExpr)
|
EIGEN_DENSE_PUBLIC_INTERFACE(PartialReduxExpr)
|
||||||
typedef typename internal::traits<PartialReduxExpr>::MatrixTypeNested MatrixTypeNested;
|
|
||||||
typedef typename internal::traits<PartialReduxExpr>::_MatrixTypeNested _MatrixTypeNested;
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
explicit PartialReduxExpr(const MatrixType& mat, const MemberOp& func = MemberOp())
|
explicit PartialReduxExpr(const MatrixType& mat, const MemberOp& func = MemberOp())
|
||||||
@ -74,24 +70,11 @@ class PartialReduxExpr : public internal::dense_xpr_base< PartialReduxExpr<Matri
|
|||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
Index cols() const { return (Direction==Horizontal ? 1 : m_matrix.cols()); }
|
Index cols() const { return (Direction==Horizontal ? 1 : m_matrix.cols()); }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index i, Index j) const
|
typename MatrixType::Nested nestedExpression() const { return m_matrix; }
|
||||||
{
|
const MemberOp& functor() const { return m_functor; }
|
||||||
if (Direction==Vertical)
|
|
||||||
return m_functor(m_matrix.col(j));
|
|
||||||
else
|
|
||||||
return m_functor(m_matrix.row(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
|
|
||||||
{
|
|
||||||
if (Direction==Vertical)
|
|
||||||
return m_functor(m_matrix.col(index));
|
|
||||||
else
|
|
||||||
return m_functor(m_matrix.row(index));
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
MatrixTypeNested m_matrix;
|
typename MatrixType::Nested m_matrix;
|
||||||
const MemberOp m_functor;
|
const MemberOp m_functor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -2,11 +2,13 @@
|
|||||||
// for linear algebra.
|
// for linear algebra.
|
||||||
//
|
//
|
||||||
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
|
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
|
||||||
|
// Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||||
//
|
//
|
||||||
// This Source Code Form is subject to the terms of the Mozilla
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#define TEST_ENABLE_TEMPORARY_TRACKING
|
||||||
#define EIGEN_NO_STATIC_ASSERT
|
#define EIGEN_NO_STATIC_ASSERT
|
||||||
|
|
||||||
#include "main.h"
|
#include "main.h"
|
||||||
@ -209,14 +211,20 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
|
|||||||
m2 = m1;
|
m2 = m1;
|
||||||
m2.rowwise().normalize();
|
m2.rowwise().normalize();
|
||||||
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
|
VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
|
||||||
|
|
||||||
|
// test with partial reduction of products
|
||||||
|
Matrix<Scalar,MatrixType::RowsAtCompileTime,MatrixType::RowsAtCompileTime> m1m1 = m1 * m1.transpose();
|
||||||
|
VERIFY_IS_APPROX( (m1 * m1.transpose()).colwise().sum(), m1m1.colwise().sum());
|
||||||
|
Matrix<Scalar,1,MatrixType::RowsAtCompileTime> tmp(rows);
|
||||||
|
VERIFY_EVALUATION_COUNT( tmp = (m1 * m1.transpose()).colwise().sum(), (MatrixType::RowsAtCompileTime==Dynamic ? 1 : 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_vectorwiseop()
|
void test_vectorwiseop()
|
||||||
{
|
{
|
||||||
CALL_SUBTEST_1(vectorwiseop_array(Array22cd()));
|
CALL_SUBTEST_1( vectorwiseop_array(Array22cd()) );
|
||||||
CALL_SUBTEST_2(vectorwiseop_array(Array<double, 3, 2>()));
|
CALL_SUBTEST_2( vectorwiseop_array(Array<double, 3, 2>()) );
|
||||||
CALL_SUBTEST_3(vectorwiseop_array(ArrayXXf(3, 4)));
|
CALL_SUBTEST_3( vectorwiseop_array(ArrayXXf(3, 4)) );
|
||||||
CALL_SUBTEST_4(vectorwiseop_matrix(Matrix4cf()));
|
CALL_SUBTEST_4( vectorwiseop_matrix(Matrix4cf()) );
|
||||||
CALL_SUBTEST_5(vectorwiseop_matrix(Matrix<float,4,5>()));
|
CALL_SUBTEST_5( vectorwiseop_matrix(Matrix<float,4,5>()) );
|
||||||
CALL_SUBTEST_6(vectorwiseop_matrix(MatrixXd(7,2)));
|
CALL_SUBTEST_6( vectorwiseop_matrix(MatrixXd(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user