From 3b60d2dbc4688fa3216a0df56ecca699a9ff9ea2 Mon Sep 17 00:00:00 2001 From: Jitse Niesen Date: Thu, 28 Apr 2011 15:52:15 +0100 Subject: [PATCH] Implement swap using evaluators. --- Eigen/src/Core/AssignEvaluator.h | 12 +++++- Eigen/src/Core/CoreEvaluators.h | 74 ++++++++++++++++++++++++++++++-- Eigen/src/Core/Swap.h | 2 + test/evaluators.cpp | 24 ++++++++++- 4 files changed, 107 insertions(+), 5 deletions(-) diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h index 93ca2433a..c5f345a2f 100644 --- a/Eigen/src/Core/AssignEvaluator.h +++ b/Eigen/src/Core/AssignEvaluator.h @@ -404,7 +404,7 @@ struct copy_using_evaluator_impl::JointAlignment }; - const Index alignedStart = dstIsAligned ? 0 : first_aligned(&dst.coeffRef(0), size); + const Index alignedStart = dstIsAligned ? 0 : first_aligned(&dstEvaluator.coeffRef(0), size); const Index alignedEnd = alignedStart + ((size-alignedStart)/packetSize)*packetSize; unaligned_copy_using_evaluator_impl::run(dstEvaluator, srcEvaluator, 0, alignedStart); @@ -614,6 +614,16 @@ const DstXprType& copy_using_evaluator(const DstXprType& dst, const SrcXprType& return dst; } +// Based on DenseBase::swap() +// TODO: Chech whether we need to do something special for swapping two +// Arrays or Matrices. + +template +void swap_using_evaluator(const DstXprType& dst, const SrcXprType& src) +{ + copy_using_evaluator(SwapWrapper(const_cast(dst)), src); +} + } // namespace internal #endif // EIGEN_ASSIGN_EVALUATOR_H diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 187dc1c97..899aa04ea 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -65,7 +65,7 @@ struct evaluator_impl_base { Index row = rowIndexByOuterInner(outer, inner); Index col = colIndexByOuterInner(outer, inner); - derived().coeffRef(row, col) = other.coeff(row, col); + derived().copyCoeff(row, col, other); } template @@ -86,8 +86,7 @@ struct evaluator_impl_base { Index row = rowIndexByOuterInner(outer, inner); Index col = colIndexByOuterInner(outer, inner); - derived().template writePacket(row, col, - other.template packet(row, col)); + derived().template copyPacket(row, col, other); } template @@ -1017,6 +1016,75 @@ private: }; +// ---------- SwapWrapper ---------- + +template +struct evaluator_impl > + : evaluator_impl_base > +{ + typedef SwapWrapper XprType; + + evaluator_impl(const XprType& swapWrapper) + : m_argImpl(swapWrapper.expression()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::Packet Packet; + + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(row, col); + } + + inline Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template + void copyCoeff(Index row, Index col, const OtherEvaluatorType& other) + { + OtherEvaluatorType& nonconst_other = const_cast(other); + Scalar tmp = m_argImpl.coeff(row, col); + m_argImpl.coeffRef(row, col) = nonconst_other.coeff(row, col); + nonconst_other.coeffRef(row, col) = tmp; + } + + template + void copyCoeff(Index index, const OtherEvaluatorType& other) + { + OtherEvaluatorType& nonconst_other = const_cast(other); + Scalar tmp = m_argImpl.coeff(index); + m_argImpl.coeffRef(index) = nonconst_other.coeff(index); + nonconst_other.coeffRef(index) = tmp; + } + + template + void copyPacket(Index row, Index col, const OtherEvaluatorType& other) + { + OtherEvaluatorType& nonconst_other = const_cast(other); + Packet tmp = m_argImpl.template packet(row, col); + m_argImpl.template writePacket + (row, col, nonconst_other.template packet(row, col)); + nonconst_other.template writePacket(row, col, tmp); + } + + template + void copyPacket(Index index, const OtherEvaluatorType& other) + { + OtherEvaluatorType& nonconst_other = const_cast(other); + Packet tmp = m_argImpl.template packet(index); + m_argImpl.template writePacket + (index, nonconst_other.template packet(index)); + nonconst_other.template writePacket(index, tmp); + } + +protected: + typename evaluator::type m_argImpl; +}; + + } // namespace internal #endif // EIGEN_COREEVALUATORS_H diff --git a/Eigen/src/Core/Swap.h b/Eigen/src/Core/Swap.h index 5fb032866..5fdd36e3b 100644 --- a/Eigen/src/Core/Swap.h +++ b/Eigen/src/Core/Swap.h @@ -119,6 +119,8 @@ template class SwapWrapper _other.template writePacket(index, tmp); } + ExpressionType& expression() const { return m_expression; } + protected: ExpressionType& m_expression; }; diff --git a/test/evaluators.cpp b/test/evaluators.cpp index ea957cb1e..6e81ad5ef 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -214,5 +214,27 @@ void test_evaluators() copy_using_evaluator(mat1.diagonal<-1>(), mat1.diagonal(1)); mat2.diagonal<-1>() = mat2.diagonal(1); - VERIFY_IS_APPROX(mat1, mat2); + VERIFY_IS_APPROX(mat1, mat2); + + { + // test swapping + MatrixXd mat1, mat2, mat1ref, mat2ref; + mat1ref = mat1 = MatrixXd::Random(6, 6); + mat2ref = mat2 = 2 * mat1 + MatrixXd::Identity(6, 6); + swap_using_evaluator(mat1, mat2); + mat1ref.swap(mat2ref); + VERIFY_IS_APPROX(mat1, mat1ref); + VERIFY_IS_APPROX(mat2, mat2ref); + + swap_using_evaluator(mat1.block(0, 0, 3, 3), mat2.block(3, 3, 3, 3)); + mat1ref.block(0, 0, 3, 3).swap(mat2ref.block(3, 3, 3, 3)); + VERIFY_IS_APPROX(mat1, mat1ref); + VERIFY_IS_APPROX(mat2, mat2ref); + + swap_using_evaluator(mat1.row(2), mat2.col(3).transpose()); + mat1.row(2).swap(mat2.col(3).transpose()); + VERIFY_IS_APPROX(mat1, mat1ref); + VERIFY_IS_APPROX(mat2, mat2ref); + } + }