Implement evaluators for Reverse.

This commit is contained in:
Jitse Niesen 2011-04-22 22:36:45 +01:00
parent bb2d70d211
commit f924722f3b
3 changed files with 109 additions and 1 deletions

View File

@ -666,7 +666,7 @@ struct evaluator_impl<Replicate<XprType, RowFactor, ColFactor> >
protected:
typename evaluator<XprType>::type m_argImpl;
Index m_rows;
Index m_rows; // TODO: Get rid of this if known at compile time
Index m_cols;
};
@ -787,6 +787,100 @@ struct evaluator_impl<ArrayWrapper<ArgType> >
};
// -------------------- Reverse --------------------
// defined in Reverse.h:
template<typename PacketScalar, bool ReversePacket> struct reverse_packet_cond;
template<typename ArgType, int Direction>
struct evaluator_impl<Reverse<ArgType, Direction> >
{
typedef Reverse<ArgType, Direction> ReverseType;
evaluator_impl(const ReverseType& reverse)
: m_argImpl(reverse.nestedExpression()),
m_rows(reverse.nestedExpression().rows()),
m_cols(reverse.nestedExpression().cols())
{ }
typedef typename ReverseType::Index Index;
typedef typename ReverseType::Scalar Scalar;
typedef typename ReverseType::CoeffReturnType CoeffReturnType;
typedef typename ReverseType::PacketScalar PacketScalar;
typedef typename ReverseType::PacketReturnType PacketReturnType;
enum {
PacketSize = internal::packet_traits<Scalar>::size,
IsRowMajor = ReverseType::IsRowMajor,
IsColMajor = !IsRowMajor,
ReverseRow = (Direction == Vertical) || (Direction == BothDirections),
ReverseCol = (Direction == Horizontal) || (Direction == BothDirections),
OffsetRow = ReverseRow && IsColMajor ? PacketSize : 1,
OffsetCol = ReverseCol && IsRowMajor ? PacketSize : 1,
ReversePacket = (Direction == BothDirections)
|| ((Direction == Vertical) && IsColMajor)
|| ((Direction == Horizontal) && IsRowMajor)
};
typedef internal::reverse_packet_cond<PacketScalar,ReversePacket> reverse_packet;
CoeffReturnType coeff(Index row, Index col) const
{
return m_argImpl.coeff(ReverseRow ? m_rows - row - 1 : row,
ReverseCol ? m_cols - col - 1 : col);
}
CoeffReturnType coeff(Index index) const
{
return m_argImpl.coeff(m_rows * m_cols - index - 1);
}
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(ReverseRow ? m_rows - row - 1 : row,
ReverseCol ? m_cols - col - 1 : col);
}
Scalar& coeffRef(Index index)
{
return m_argImpl.coeffRef(m_rows * m_cols - index - 1);
}
template<int LoadMode>
PacketScalar packet(Index row, Index col) const
{
return reverse_packet::run(m_argImpl.template packet<LoadMode>(
ReverseRow ? m_rows - row - OffsetRow : row,
ReverseCol ? m_cols - col - OffsetCol : col));
}
template<int LoadMode>
PacketScalar packet(Index index) const
{
return preverse(m_argImpl.template packet<LoadMode>(m_rows * m_cols - index - PacketSize));
}
template<int LoadMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
m_argImpl.template writePacket<LoadMode>(
ReverseRow ? m_rows - row - OffsetRow : row,
ReverseCol ? m_cols - col - OffsetCol : col,
reverse_packet::run(x));
}
template<int LoadMode>
void writePacket(Index index, const PacketScalar& x)
{
m_argImpl.template writePacket<LoadMode>(m_rows * m_cols - index - PacketSize, preverse(x));
}
protected:
typename evaluator<ArgType>::type m_argImpl;
Index m_rows; // TODO: Don't use if known at compile time or not needed
Index m_cols;
};
} // namespace internal
#endif // EIGEN_COREEVALUATORS_H

View File

@ -183,6 +183,12 @@ template<typename MatrixType, int Direction> class Reverse
m_matrix.const_cast_derived().template writePacket<LoadMode>(m_matrix.size() - index - PacketSize, internal::preverse(x));
}
const typename internal::remove_all<typename MatrixType::Nested>::type&
nestedExpression() const
{
return m_matrix;
}
protected:
const typename MatrixType::Nested m_matrix;
};

View File

@ -192,4 +192,12 @@ void test_evaluators()
VERIFY_IS_APPROX(mat2, (arr1 * arr1).matrix());
arr2.matrix() = MatrixXd::Identity(6,6);
VERIFY_IS_APPROX(arr2, MatrixXd::Identity(6,6).array());
// test Reverse
VERIFY_IS_APPROX_EVALUATOR(arr2, arr1.reverse());
VERIFY_IS_APPROX_EVALUATOR(arr2, arr1.colwise().reverse());
VERIFY_IS_APPROX_EVALUATOR(arr2, arr1.rowwise().reverse());
arr2.reverse() = arr1;
VERIFY_IS_APPROX(arr2, arr1.reverse());
}