From 2d11041e244685340c523351f841fe0b4f3b2a62 Mon Sep 17 00:00:00 2001 From: Jitse Niesen Date: Fri, 22 Apr 2011 22:36:45 +0100 Subject: [PATCH] Use copyCoeff/copyPacket in copy_using_evaluator. --- Eigen/src/Core/AssignEvaluator.h | 268 ++++++++++++--------------- Eigen/src/Core/CoreEvaluators.h | 302 ++++++++++++++++++++----------- test/evaluators.cpp | 2 +- 3 files changed, 314 insertions(+), 258 deletions(-) diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h index c49c2a50f..93ca2433a 100644 --- a/Eigen/src/Core/AssignEvaluator.h +++ b/Eigen/src/Core/AssignEvaluator.h @@ -147,165 +147,132 @@ public: * Part 2 : meta-unrollers ***************************************************************************/ -// TODO:`Ideally, we want to use only the evaluator objects here, not the expression objects -// However, we need to access .rowIndexByOuterInner() which is in the expression object - /************************ *** Default traversal *** ************************/ -template +template struct copy_using_evaluator_DefaultTraversal_CompleteUnrolling { + typedef typename DstEvaluatorType::XprType DstXprType; + enum { outer = Index / DstXprType::InnerSizeAtCompileTime, inner = Index % DstXprType::InnerSizeAtCompileTime }; - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, - SrcEvaluatorType &srcEvaluator, - const DstXprType &dst) + SrcEvaluatorType &srcEvaluator) { - // TODO: Use copyCoeffByOuterInner ? - typename DstXprType::Index row = dst.rowIndexByOuterInner(outer, inner); - typename DstXprType::Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.coeffRef(row, col) = srcEvaluator.coeff(row, col); - copy_using_evaluator_DefaultTraversal_CompleteUnrolling - ::run(dstEvaluator, srcEvaluator, dst); + dstEvaluator.copyCoeffByOuterInner(outer, inner, srcEvaluator); + copy_using_evaluator_DefaultTraversal_CompleteUnrolling + + ::run(dstEvaluator, srcEvaluator); } }; -template -struct copy_using_evaluator_DefaultTraversal_CompleteUnrolling +template +struct copy_using_evaluator_DefaultTraversal_CompleteUnrolling { - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&) { } }; -template +template struct copy_using_evaluator_DefaultTraversal_InnerUnrolling { - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, SrcEvaluatorType &srcEvaluator, - const DstXprType &dst, int outer) { - // TODO: Use copyCoeffByOuterInner ? - typename DstXprType::Index row = dst.rowIndexByOuterInner(outer, Index); - typename DstXprType::Index col = dst.colIndexByOuterInner(outer, Index); - dstEvaluator.coeffRef(row, col) = srcEvaluator.coeff(row, col); - copy_using_evaluator_DefaultTraversal_InnerUnrolling - ::run(dstEvaluator, srcEvaluator, dst, outer); + dstEvaluator.copyCoeffByOuterInner(outer, Index, srcEvaluator); + copy_using_evaluator_DefaultTraversal_InnerUnrolling + + ::run(dstEvaluator, srcEvaluator, outer); } }; -template -struct copy_using_evaluator_DefaultTraversal_InnerUnrolling +template +struct copy_using_evaluator_DefaultTraversal_InnerUnrolling { - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&, int) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, int) { } }; /*********************** *** Linear traversal *** ***********************/ -template +template struct copy_using_evaluator_LinearTraversal_CompleteUnrolling { - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, - SrcEvaluatorType &srcEvaluator, - const DstXprType &dst) + SrcEvaluatorType &srcEvaluator) { - // use copyCoeff ? - dstEvaluator.coeffRef(Index) = srcEvaluator.coeff(Index); - copy_using_evaluator_LinearTraversal_CompleteUnrolling - ::run(dstEvaluator, srcEvaluator, dst); + dstEvaluator.copyCoeff(Index, srcEvaluator); + copy_using_evaluator_LinearTraversal_CompleteUnrolling + + ::run(dstEvaluator, srcEvaluator); } }; -template -struct copy_using_evaluator_LinearTraversal_CompleteUnrolling +template +struct copy_using_evaluator_LinearTraversal_CompleteUnrolling { - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&) { } }; /************************** *** Inner vectorization *** **************************/ -template +template struct copy_using_evaluator_innervec_CompleteUnrolling { + typedef typename DstEvaluatorType::XprType DstXprType; + typedef typename SrcEvaluatorType::XprType SrcXprType; + enum { outer = Index / DstXprType::InnerSizeAtCompileTime, inner = Index % DstXprType::InnerSizeAtCompileTime, JointAlignment = copy_using_evaluator_traits::JointAlignment }; - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, - SrcEvaluatorType &srcEvaluator, - const DstXprType &dst) + SrcEvaluatorType &srcEvaluator) { - // TODO: Use copyPacketByOuterInner ? - typename DstXprType::Index row = dst.rowIndexByOuterInner(outer, inner); - typename DstXprType::Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.template writePacket(row, col, srcEvaluator.template packet(row, col)); - copy_using_evaluator_innervec_CompleteUnrolling::size, Stop>::run(dstEvaluator, srcEvaluator, dst); + dstEvaluator.template copyPacketByOuterInner(outer, inner, srcEvaluator); + enum { NextIndex = Index + packet_traits::size }; + copy_using_evaluator_innervec_CompleteUnrolling + + ::run(dstEvaluator, srcEvaluator); } }; -template -struct copy_using_evaluator_innervec_CompleteUnrolling +template +struct copy_using_evaluator_innervec_CompleteUnrolling { - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&) { } }; -template +template struct copy_using_evaluator_innervec_InnerUnrolling { - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType &dstEvaluator, SrcEvaluatorType &srcEvaluator, - const DstXprType &dst, int outer) { - // TODO: Use copyPacketByOuterInner ? - typename DstXprType::Index row = dst.rowIndexByOuterInner(outer, Index); - typename DstXprType::Index col = dst.colIndexByOuterInner(outer, Index); - dstEvaluator.template writePacket(row, col, srcEvaluator.template packet(row, col)); - copy_using_evaluator_innervec_InnerUnrolling::size, Stop>::run(dstEvaluator, srcEvaluator, dst, outer); + dstEvaluator.template copyPacketByOuterInner(outer, Index, srcEvaluator); + typedef typename DstEvaluatorType::XprType DstXprType; + enum { NextIndex = Index + packet_traits::size }; + copy_using_evaluator_innervec_InnerUnrolling + + ::run(dstEvaluator, srcEvaluator, outer); } }; -template -struct copy_using_evaluator_innervec_InnerUnrolling +template +struct copy_using_evaluator_innervec_InnerUnrolling { - typedef typename evaluator::type DstEvaluatorType; - typedef typename evaluator::type SrcEvaluatorType; - EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, const DstXprType&, int) { } + EIGEN_STRONG_INLINE static void run(DstEvaluatorType&, SrcEvaluatorType&, int) { } }; /*************************************************************************** @@ -326,20 +293,18 @@ struct copy_using_evaluator_impl; template struct copy_using_evaluator_impl { - static void run(const DstXprType& dst, const SrcXprType& src) + static void run(DstXprType& dst, const SrcXprType& src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); for(Index outer = 0; outer < dst.outerSize(); ++outer) { for(Index inner = 0; inner < dst.innerSize(); ++inner) { - Index row = dst.rowIndexByOuterInner(outer, inner); - Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.coeffRef(row, col) = srcEvaluator.coeff(row, col); // TODO: use copyCoeff ? + dstEvaluator.copyCoeffByOuterInner(outer, inner, srcEvaluator); } } } @@ -348,16 +313,17 @@ struct copy_using_evaluator_impl struct copy_using_evaluator_impl { - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); - copy_using_evaluator_DefaultTraversal_CompleteUnrolling - ::run(dstEvaluator, srcEvaluator, dst); + copy_using_evaluator_DefaultTraversal_CompleteUnrolling + + ::run(dstEvaluator, srcEvaluator); } }; @@ -365,18 +331,19 @@ template struct copy_using_evaluator_impl { typedef typename DstXprType::Index Index; - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index outerSize = dst.outerSize(); for(Index outer = 0; outer < outerSize; ++outer) - copy_using_evaluator_DefaultTraversal_InnerUnrolling - ::run(dstEvaluator, srcEvaluator, dst, outer); + copy_using_evaluator_DefaultTraversal_InnerUnrolling + + ::run(dstEvaluator, srcEvaluator, outer); } }; @@ -387,43 +354,46 @@ struct copy_using_evaluator_impl struct unaligned_copy_using_evaluator_impl { + // if IsAligned = true, then do nothing template static EIGEN_STRONG_INLINE void run(const SrcEvaluatorType&, DstEvaluatorType&, typename SrcEvaluatorType::Index, typename SrcEvaluatorType::Index) {} }; -// TODO: check why no ... ???? - template <> struct unaligned_copy_using_evaluator_impl { // MSVC must not inline this functions. If it does, it fails to optimize the // packet access path. #ifdef _MSC_VER - template - static EIGEN_DONT_INLINE void run(const SrcEvaluatorType& src, DstEvaluatorType& dst, - typename SrcEvaluatorType::Index start, typename SrcEvaluatorType::Index end) + template + static EIGEN_DONT_INLINE void run(DstEvaluatorType &dstEvaluator, + const SrcEvaluatorType &srcEvaluator, + typename DstEvaluatorType::Index start, + typename DstEvaluatorType::Index end) #else - template - static EIGEN_STRONG_INLINE void run(const SrcEvaluatorType& src, DstEvaluatorType& dst, - typename SrcEvaluatorType::Index start, typename SrcEvaluatorType::Index end) + template + static EIGEN_STRONG_INLINE void run(DstEvaluatorType &dstEvaluator, + const SrcEvaluatorType &srcEvaluator, + typename DstEvaluatorType::Index start, + typename DstEvaluatorType::Index end) #endif { - for (typename SrcEvaluatorType::Index index = start; index < end; ++index) - dst.copyCoeff(index, src); + for (typename DstEvaluatorType::Index index = start; index < end; ++index) + dstEvaluator.copyCoeff(index, srcEvaluator); } }; template struct copy_using_evaluator_impl { - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index size = dst.size(); @@ -437,14 +407,14 @@ struct copy_using_evaluator_impl::run(src,dst.const_cast_derived(),0,alignedStart); + unaligned_copy_using_evaluator_impl::run(dstEvaluator, srcEvaluator, 0, alignedStart); for(Index index = alignedStart; index < alignedEnd; index += packetSize) { - dstEvaluator.template writePacket(index, srcEvaluator.template packet(index)); + dstEvaluator.template copyPacket(index, srcEvaluator); } - unaligned_copy_using_evaluator_impl<>::run(src,dst.const_cast_derived(),alignedEnd,size); + unaligned_copy_using_evaluator_impl<>::run(dstEvaluator, srcEvaluator, alignedEnd, size); } }; @@ -452,22 +422,24 @@ template struct copy_using_evaluator_impl { typedef typename DstXprType::Index Index; - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); enum { size = DstXprType::SizeAtCompileTime, packetSize = packet_traits::size, alignedSize = (size/packetSize)*packetSize }; - copy_using_evaluator_innervec_CompleteUnrolling - ::run(dstEvaluator, srcEvaluator, dst); - copy_using_evaluator_DefaultTraversal_CompleteUnrolling - ::run(dstEvaluator, srcEvaluator, dst); + copy_using_evaluator_innervec_CompleteUnrolling + + ::run(dstEvaluator, srcEvaluator); + copy_using_evaluator_DefaultTraversal_CompleteUnrolling + + ::run(dstEvaluator, srcEvaluator); } }; @@ -478,13 +450,13 @@ struct copy_using_evaluator_impl struct copy_using_evaluator_impl { - inline static void run(const DstXprType &dst, const SrcXprType &src) + inline static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index innerSize = dst.innerSize(); @@ -492,10 +464,7 @@ struct copy_using_evaluator_impl::size; for(Index outer = 0; outer < outerSize; ++outer) for(Index inner = 0; inner < innerSize; inner+=packetSize) { - // TODO: Use copyPacketByOuterInner ? - Index row = dst.rowIndexByOuterInner(outer, inner); - Index col = dst.colIndexByOuterInner(outer, inner); - dstEvaluator.template writePacket(row, col, srcEvaluator.template packet(row, col)); + dstEvaluator.template copyPacketByOuterInner(outer, inner, srcEvaluator); } } }; @@ -503,16 +472,17 @@ struct copy_using_evaluator_impl struct copy_using_evaluator_impl { - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); - copy_using_evaluator_innervec_CompleteUnrolling - ::run(dstEvaluator, srcEvaluator, dst); + copy_using_evaluator_innervec_CompleteUnrolling + + ::run(dstEvaluator, srcEvaluator); } }; @@ -520,18 +490,19 @@ template struct copy_using_evaluator_impl { typedef typename DstXprType::Index Index; - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index outerSize = dst.outerSize(); for(Index outer = 0; outer < outerSize; ++outer) - copy_using_evaluator_innervec_InnerUnrolling - ::run(dstEvaluator, srcEvaluator, dst, outer); + copy_using_evaluator_innervec_InnerUnrolling + + ::run(dstEvaluator, srcEvaluator, outer); } }; @@ -542,34 +513,35 @@ struct copy_using_evaluator_impl struct copy_using_evaluator_impl { - inline static void run(const DstXprType &dst, const SrcXprType &src) + inline static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); const Index size = dst.size(); for(Index i = 0; i < size; ++i) - dstEvaluator.coeffRef(i) = srcEvaluator.coeff(i); // TODO: use copyCoeff ? + dstEvaluator.copyCoeff(i, srcEvaluator); } }; template struct copy_using_evaluator_impl { - EIGEN_STRONG_INLINE static void run(const DstXprType &dst, const SrcXprType &src) + EIGEN_STRONG_INLINE static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); - copy_using_evaluator_LinearTraversal_CompleteUnrolling - ::run(dstEvaluator, srcEvaluator, dst); + copy_using_evaluator_LinearTraversal_CompleteUnrolling + + ::run(dstEvaluator, srcEvaluator); } }; @@ -580,13 +552,13 @@ struct copy_using_evaluator_impl struct copy_using_evaluator_impl { - inline static void run(const DstXprType &dst, const SrcXprType &src) + inline static void run(DstXprType &dst, const SrcXprType &src) { typedef typename evaluator::type DstEvaluatorType; typedef typename evaluator::type SrcEvaluatorType; typedef typename DstXprType::Index Index; - DstEvaluatorType dstEvaluator(dst.const_cast_derived()); + DstEvaluatorType dstEvaluator(dst); SrcEvaluatorType srcEvaluator(src); typedef packet_traits PacketTraits; @@ -608,23 +580,17 @@ struct copy_using_evaluator_impl(row, col, srcEvaluator.template packet(row, col)); + dstEvaluator.template copyPacketByOuterInner(outer, inner, srcEvaluator); } // do the non-vectorizable part of the assignment for(Index inner = alignedEnd; inner((alignedStart+alignedStep)%packetSize, innerSize); @@ -644,7 +610,7 @@ const DstXprType& copy_using_evaluator(const DstXprType& dst, const SrcXprType& #ifdef EIGEN_DEBUG_ASSIGN internal::copy_using_evaluator_traits::debug(); #endif - copy_using_evaluator_impl::run(dst, src); + copy_using_evaluator_impl::run(const_cast(dst), src); return dst; } diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 1ef82e4be..187dc1c97 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -47,63 +47,140 @@ struct evaluator typedef evaluator_impl type; }; -// -------------------- Transpose -------------------- +// ---------- base class for all writable evaluators ---------- template -struct evaluator_impl > +struct evaluator_impl_base { - typedef Transpose TransposeType; - evaluator_impl(const TransposeType& t) : m_argImpl(t.nestedExpression()) {} + typedef typename ExpressionType::Index Index; - typedef typename TransposeType::Index Index; - - typename TransposeType::CoeffReturnType coeff(Index i, Index j) const + template + void copyCoeff(Index row, Index col, const OtherEvaluatorType& other) { - return m_argImpl.coeff(j, i); + derived().coeffRef(row, col) = other.coeff(row, col); } - typename TransposeType::CoeffReturnType coeff(Index index) const + template + void copyCoeffByOuterInner(Index outer, Index inner, const OtherEvaluatorType& other) + { + Index row = rowIndexByOuterInner(outer, inner); + Index col = colIndexByOuterInner(outer, inner); + derived().coeffRef(row, col) = other.coeff(row, col); + } + + template + void copyCoeff(Index index, const OtherEvaluatorType& other) + { + derived().coeffRef(index) = other.coeff(index); + } + + template + void copyPacket(Index row, Index col, const OtherEvaluatorType& other) + { + derived().template writePacket(row, col, + other.template packet(row, col)); + } + + template + void copyPacketByOuterInner(Index outer, Index inner, const OtherEvaluatorType& other) + { + Index row = rowIndexByOuterInner(outer, inner); + Index col = colIndexByOuterInner(outer, inner); + derived().template writePacket(row, col, + other.template packet(row, col)); + } + + template + void copyPacket(Index index, const OtherEvaluatorType& other) + { + derived().template writePacket(index, + other.template packet(index)); + } + + Index rowIndexByOuterInner(Index outer, Index inner) const + { + return int(ExpressionType::RowsAtCompileTime) == 1 ? 0 + : int(ExpressionType::ColsAtCompileTime) == 1 ? inner + : int(ExpressionType::Flags)&RowMajorBit ? outer + : inner; + } + + Index colIndexByOuterInner(Index outer, Index inner) const + { + return int(ExpressionType::ColsAtCompileTime) == 1 ? 0 + : int(ExpressionType::RowsAtCompileTime) == 1 ? inner + : int(ExpressionType::Flags)&RowMajorBit ? inner + : outer; + } + + evaluator_impl& derived() + { + return *static_cast*>(this); + } +}; + +// -------------------- Transpose -------------------- + +template +struct evaluator_impl > + : evaluator_impl_base > +{ + typedef Transpose XprType; + + evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {} + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; + + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(col, row); + } + + CoeffReturnType coeff(Index index) const { return m_argImpl.coeff(index); } - typename TransposeType::Scalar& coeffRef(Index i, Index j) + Scalar& coeffRef(Index row, Index col) { - return m_argImpl.coeffRef(j, i); + return m_argImpl.coeffRef(col, row); } - typename TransposeType::Scalar& coeffRef(Index index) + typename XprType::Scalar& coeffRef(Index index) { return m_argImpl.coeffRef(index); } - // TODO: Difference between PacketScalar and PacketReturnType? template - const typename ExpressionType::PacketScalar packet(Index row, Index col) const + PacketReturnType packet(Index row, Index col) const { return m_argImpl.template packet(col, row); } template - const typename ExpressionType::PacketScalar packet(Index index) const + PacketReturnType packet(Index index) const { return m_argImpl.template packet(index); } template - void writePacket(Index row, Index col, const typename ExpressionType::PacketScalar& x) + void writePacket(Index row, Index col, const PacketScalar& x) { m_argImpl.template writePacket(col, row, x); } template - void writePacket(Index index, const typename ExpressionType::PacketScalar& x) + void writePacket(Index index, const PacketScalar& x) { m_argImpl.template writePacket(index, x); } protected: - typename evaluator::type m_argImpl; + typename evaluator::type m_argImpl; }; // -------------------- Matrix and Array -------------------- @@ -113,6 +190,7 @@ protected: template struct evaluator_impl > + : evaluator_impl_base { typedef PlainObjectBase PlainObjectType; @@ -176,10 +254,10 @@ template > : evaluator_impl > > { - typedef Matrix MatrixType; + typedef Matrix XprType; - evaluator_impl(const MatrixType& m) - : evaluator_impl >(m) + evaluator_impl(const XprType& m) + : evaluator_impl >(m) { } }; @@ -187,10 +265,10 @@ template > : evaluator_impl > > { - typedef Array ArrayType; + typedef Array XprType; - evaluator_impl(const ArrayType& m) - : evaluator_impl >(m) + evaluator_impl(const XprType& m) + : evaluator_impl >(m) { } }; @@ -199,15 +277,15 @@ struct evaluator_impl > template struct evaluator_impl > { - typedef CwiseNullaryOp NullaryOpType; + typedef CwiseNullaryOp XprType; - evaluator_impl(const NullaryOpType& n) + evaluator_impl(const XprType& n) : m_functor(n.functor()) { } - typedef typename NullaryOpType::Index Index; - typedef typename NullaryOpType::CoeffReturnType CoeffReturnType; - typedef typename NullaryOpType::PacketScalar PacketScalar; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; CoeffReturnType coeff(Index row, Index col) const { @@ -240,16 +318,16 @@ protected: template struct evaluator_impl > { - typedef CwiseUnaryOp UnaryOpType; + typedef CwiseUnaryOp XprType; - evaluator_impl(const UnaryOpType& op) + evaluator_impl(const XprType& op) : m_functor(op.functor()), m_argImpl(op.nestedExpression()) { } - typedef typename UnaryOpType::Index Index; - typedef typename UnaryOpType::CoeffReturnType CoeffReturnType; - typedef typename UnaryOpType::PacketScalar PacketScalar; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; CoeffReturnType coeff(Index row, Index col) const { @@ -283,17 +361,17 @@ protected: template struct evaluator_impl > { - typedef CwiseBinaryOp BinaryOpType; + typedef CwiseBinaryOp XprType; - evaluator_impl(const BinaryOpType& xpr) + evaluator_impl(const XprType& xpr) : m_functor(xpr.functor()), m_lhsImpl(xpr.lhs()), m_rhsImpl(xpr.rhs()) { } - typedef typename BinaryOpType::Index Index; - typedef typename BinaryOpType::CoeffReturnType CoeffReturnType; - typedef typename BinaryOpType::PacketScalar PacketScalar; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; CoeffReturnType coeff(Index row, Index col) const { @@ -329,17 +407,18 @@ protected: template struct evaluator_impl > + : evaluator_impl_base > { - typedef CwiseUnaryView CwiseUnaryViewType; + typedef CwiseUnaryView XprType; - evaluator_impl(const CwiseUnaryViewType& op) + evaluator_impl(const XprType& op) : m_unaryOp(op.functor()), m_argImpl(op.nestedExpression()) { } - typedef typename CwiseUnaryViewType::Index Index; - typedef typename CwiseUnaryViewType::Scalar Scalar; - typedef typename CwiseUnaryViewType::CoeffReturnType CoeffReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; CoeffReturnType coeff(Index row, Index col) const { @@ -400,23 +479,26 @@ protected: template struct evaluator_impl > + : evaluator_impl_base { typedef MapBase MapType; - typedef typename MapType::PointerType PointerType; - typedef typename MapType::Index Index; - typedef typename MapType::Scalar Scalar; - typedef typename MapType::CoeffReturnType CoeffReturnType; - typedef typename MapType::PacketScalar PacketScalar; - typedef typename MapType::PacketReturnType PacketReturnType; + typedef Derived XprType; + + typedef typename XprType::PointerType PointerType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; - evaluator_impl(const MapType& map) + evaluator_impl(const XprType& map) : m_data(const_cast(map.data())), m_rowStride(map.rowStride()), m_colStride(map.colStride()) { } enum { - RowsAtCompileTime = MapType::RowsAtCompileTime + RowsAtCompileTime = XprType::RowsAtCompileTime }; CoeffReturnType coeff(Index row, Index col) const @@ -480,34 +562,35 @@ template struct evaluator_impl > : public evaluator_impl > > { - typedef Map MapType; + typedef Map XprType; - evaluator_impl(const MapType& map) - : evaluator_impl >(map) + evaluator_impl(const XprType& map) + : evaluator_impl >(map) { } }; // -------------------- Block -------------------- -template -struct evaluator_impl > +template +struct evaluator_impl > + : evaluator_impl_base > { - typedef Block BlockType; + typedef Block XprType; - evaluator_impl(const BlockType& block) + evaluator_impl(const XprType& block) : m_argImpl(block.nestedExpression()), m_startRow(block.startRow()), m_startCol(block.startCol()) { } - typedef typename BlockType::Index Index; - typedef typename BlockType::Scalar Scalar; - typedef typename BlockType::CoeffReturnType CoeffReturnType; - typedef typename BlockType::PacketScalar PacketScalar; - typedef typename BlockType::PacketReturnType PacketReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; enum { - RowsAtCompileTime = BlockType::RowsAtCompileTime + RowsAtCompileTime = XprType::RowsAtCompileTime }; CoeffReturnType coeff(Index row, Index col) const @@ -560,7 +643,7 @@ struct evaluator_impl::type m_argImpl; + typename evaluator::type m_argImpl; // TODO: Get rid of m_startRow, m_startCol if known at compile time Index m_startRow; @@ -570,14 +653,14 @@ protected: // TODO: This evaluator does not actually use the child evaluator; // all action is via the data() as returned by the Block expression. -template -struct evaluator_impl > - : evaluator_impl > > +template +struct evaluator_impl > + : evaluator_impl > > { - typedef Block BlockType; + typedef Block XprType; - evaluator_impl(const BlockType& block) - : evaluator_impl >(block) + evaluator_impl(const XprType& block) + : evaluator_impl >(block) { } }; @@ -587,16 +670,16 @@ struct evaluator_impl struct evaluator_impl > { - typedef Select SelectType; + typedef Select XprType; - evaluator_impl(const SelectType& select) + evaluator_impl(const XprType& select) : m_conditionImpl(select.conditionMatrix()), m_thenImpl(select.thenMatrix()), m_elseImpl(select.elseMatrix()) { } - typedef typename SelectType::Index Index; - typedef typename SelectType::CoeffReturnType CoeffReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; CoeffReturnType coeff(Index row, Index col) const { @@ -623,20 +706,20 @@ protected: // -------------------- Replicate -------------------- -template -struct evaluator_impl > +template +struct evaluator_impl > { - typedef Replicate ReplicateType; + typedef Replicate XprType; - evaluator_impl(const ReplicateType& replicate) + evaluator_impl(const XprType& replicate) : m_argImpl(replicate.nestedExpression()), m_rows(replicate.nestedExpression().rows()), m_cols(replicate.nestedExpression().cols()) { } - typedef typename ReplicateType::Index Index; - typedef typename ReplicateType::CoeffReturnType CoeffReturnType; - typedef typename ReplicateType::PacketReturnType PacketReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; CoeffReturnType coeff(Index row, Index col) const { @@ -665,7 +748,7 @@ struct evaluator_impl > } protected: - typename evaluator::type m_argImpl; + typename evaluator::type m_argImpl; Index m_rows; // TODO: Get rid of this if known at compile time Index m_cols; }; @@ -677,17 +760,17 @@ protected: // TODO: Find out how to write a proper evaluator without duplicating // the row() and col() member functions. -template< typename XprType, typename MemberOp, int Direction> -struct evaluator_impl > +template< typename ArgType, typename MemberOp, int Direction> +struct evaluator_impl > { - typedef PartialReduxExpr PartialReduxExprType; + typedef PartialReduxExpr XprType; - evaluator_impl(const PartialReduxExprType expr) + evaluator_impl(const XprType expr) : m_expr(expr) { } - typedef typename PartialReduxExprType::Index Index; - typedef typename PartialReduxExprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::CoeffReturnType CoeffReturnType; CoeffReturnType coeff(Index row, Index col) const { @@ -700,7 +783,7 @@ struct evaluator_impl > } protected: - const PartialReduxExprType& m_expr; + const XprType& m_expr; }; @@ -711,6 +794,7 @@ protected: template struct evaluator_impl_wrapper_base + : evaluator_impl_base { evaluator_impl_wrapper_base(const ArgType& arg) : m_argImpl(arg) {} @@ -772,7 +856,9 @@ template struct evaluator_impl > : evaluator_impl_wrapper_base { - evaluator_impl(const MatrixWrapper& wrapper) + typedef MatrixWrapper XprType; + + evaluator_impl(const XprType& wrapper) : evaluator_impl_wrapper_base(wrapper.nestedExpression()) { } }; @@ -781,7 +867,9 @@ template struct evaluator_impl > : evaluator_impl_wrapper_base { - evaluator_impl(const ArrayWrapper& wrapper) + typedef ArrayWrapper XprType; + + evaluator_impl(const XprType& wrapper) : evaluator_impl_wrapper_base(wrapper.nestedExpression()) { } }; @@ -794,24 +882,25 @@ template struct reverse_packet_cond; template struct evaluator_impl > + : evaluator_impl_base > { - typedef Reverse ReverseType; + typedef Reverse XprType; - evaluator_impl(const ReverseType& reverse) + evaluator_impl(const XprType& reverse) : m_argImpl(reverse.nestedExpression()), m_rows(reverse.nestedExpression().rows()), m_cols(reverse.nestedExpression().cols()) { } - typedef typename ReverseType::Index Index; - typedef typename ReverseType::Scalar Scalar; - typedef typename ReverseType::CoeffReturnType CoeffReturnType; - typedef typename ReverseType::PacketScalar PacketScalar; - typedef typename ReverseType::PacketReturnType PacketReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; enum { PacketSize = internal::packet_traits::size, - IsRowMajor = ReverseType::IsRowMajor, + IsRowMajor = XprType::IsRowMajor, IsColMajor = !IsRowMajor, ReverseRow = (Direction == Vertical) || (Direction == BothDirections), ReverseCol = (Direction == Horizontal) || (Direction == BothDirections), @@ -885,17 +974,18 @@ protected: template struct evaluator_impl > + : evaluator_impl_base > { - typedef Diagonal DiagonalType; + typedef Diagonal XprType; - evaluator_impl(const DiagonalType& diagonal) + evaluator_impl(const XprType& diagonal) : m_argImpl(diagonal.nestedExpression()), m_index(diagonal.index()) { } - typedef typename DiagonalType::Index Index; - typedef typename DiagonalType::Scalar Scalar; - typedef typename DiagonalType::CoeffReturnType CoeffReturnType; + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; CoeffReturnType coeff(Index row, Index) const { diff --git a/test/evaluators.cpp b/test/evaluators.cpp index 5a123f0ad..ea957cb1e 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -214,5 +214,5 @@ void test_evaluators() copy_using_evaluator(mat1.diagonal<-1>(), mat1.diagonal(1)); mat2.diagonal<-1>() = mat2.diagonal(1); - VERIFY_IS_APPROX(mat1, mat2); + VERIFY_IS_APPROX(mat1, mat2); }