diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h index 08a2c696a..9be00067d 100644 --- a/Eigen/src/Core/AssignEvaluator.h +++ b/Eigen/src/Core/AssignEvaluator.h @@ -616,7 +616,13 @@ struct copy_using_evaluator_impl::type tmpEvaluator(tmp); + srcEvaluator.evalTo(tmpEvaluator, tmp); + copy_using_evaluator(dst, tmp); } }; diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 768fa8950..808546ec1 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -3,7 +3,7 @@ // // Copyright (C) 2011 Benoit Jacob // Copyright (C) 2011 Gael Guennebaud -// Copyright (C) 2011 Jitse Niesen +// Copyright (C) 2011-2012 Jitse Niesen // // Eigen is free software; you can redistribute it and/or // modify it under the terms of the GNU Lesser General Public @@ -42,24 +42,46 @@ struct evaluator_traits static const int HasEvalTo = 0; }; +// expression class for evaluating nested expression to a temporary + +template +class EvalToTemp; + // evaluator::type is type of evaluator for T +// evaluator::nestedType is type of evaluator if T is nested inside another evaluator + +template +struct evaluator_impl +{ }; + +template::HasEvalTo> +struct evaluator_nested_type; template -struct evaluator_impl {}; +struct evaluator_nested_type +{ + typedef evaluator_impl type; +}; + +template +struct evaluator_nested_type +{ + typedef evaluator_impl > type; +}; template struct evaluator { typedef evaluator_impl type; + typedef typename evaluator_nested_type::type nestedType; }; // TODO: Think about const-correctness template struct evaluator -{ - typedef evaluator_impl type; -}; + : evaluator +{ }; // ---------- base class for all writable evaluators ---------- @@ -132,70 +154,6 @@ struct evaluator_impl_base } }; -// -------------------- Transpose -------------------- - -template -struct evaluator_impl > - : evaluator_impl_base > -{ - typedef Transpose XprType; - - evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {} - - typedef typename XprType::Index Index; - typedef typename XprType::Scalar Scalar; - typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename XprType::PacketScalar PacketScalar; - typedef typename XprType::PacketReturnType PacketReturnType; - - CoeffReturnType coeff(Index row, Index col) const - { - return m_argImpl.coeff(col, row); - } - - CoeffReturnType coeff(Index index) const - { - return m_argImpl.coeff(index); - } - - Scalar& coeffRef(Index row, Index col) - { - return m_argImpl.coeffRef(col, row); - } - - typename XprType::Scalar& coeffRef(Index index) - { - return m_argImpl.coeffRef(index); - } - - template - PacketReturnType packet(Index row, Index col) const - { - return m_argImpl.template packet(col, row); - } - - template - PacketReturnType packet(Index index) const - { - return m_argImpl.template packet(index); - } - - template - void writePacket(Index row, Index col, const PacketScalar& x) - { - m_argImpl.template writePacket(col, row, x); - } - - template - void writePacket(Index index, const PacketScalar& x) - { - m_argImpl.template writePacket(index, x); - } - -protected: - typename evaluator::type m_argImpl; -}; - // -------------------- Matrix and Array -------------------- // // evaluator_impl is a common base class for the @@ -285,6 +243,89 @@ struct evaluator_impl > { } }; +// -------------------- EvalToTemp -------------------- + +template +struct evaluator_impl > + : evaluator_impl +{ + typedef typename ArgType::PlainObject PlainObject; + typedef evaluator_impl BaseType; + + evaluator_impl(const ArgType& arg) + : BaseType(m_result) + { + copy_using_evaluator(m_result, arg); + }; + +protected: + PlainObject m_result; +}; + +// -------------------- Transpose -------------------- + +template +struct evaluator_impl > + : evaluator_impl_base > +{ + typedef Transpose XprType; + + evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {} + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketScalar PacketScalar; + typedef typename XprType::PacketReturnType PacketReturnType; + + CoeffReturnType coeff(Index row, Index col) const + { + return m_argImpl.coeff(col, row); + } + + CoeffReturnType coeff(Index index) const + { + return m_argImpl.coeff(index); + } + + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(col, row); + } + + typename XprType::Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template + PacketReturnType packet(Index row, Index col) const + { + return m_argImpl.template packet(col, row); + } + + template + PacketReturnType packet(Index index) const + { + return m_argImpl.template packet(index); + } + + template + void writePacket(Index row, Index col, const PacketScalar& x) + { + m_argImpl.template writePacket(col, row, x); + } + + template + void writePacket(Index index, const PacketScalar& x) + { + m_argImpl.template writePacket(index, x); + } + +protected: + typename evaluator::nestedType m_argImpl; +}; + // -------------------- CwiseNullaryOp -------------------- template @@ -366,7 +407,7 @@ struct evaluator_impl > protected: const UnaryOp m_functor; - typename evaluator::type m_argImpl; + typename evaluator::nestedType m_argImpl; }; // -------------------- CwiseBinaryOp -------------------- @@ -412,8 +453,8 @@ struct evaluator_impl > protected: const BinaryOp m_functor; - typename evaluator::type m_lhsImpl; - typename evaluator::type m_rhsImpl; + typename evaluator::nestedType m_lhsImpl; + typename evaluator::nestedType m_rhsImpl; }; // -------------------- CwiseUnaryView -------------------- @@ -455,7 +496,7 @@ struct evaluator_impl > protected: const UnaryOp m_unaryOp; - typename evaluator::type m_argImpl; + typename evaluator::nestedType m_argImpl; }; // -------------------- Map -------------------- @@ -626,7 +667,7 @@ struct evaluator_impl::type m_argImpl; + typename evaluator::nestedType m_argImpl; // TODO: Get rid of m_startRow, m_startCol if known at compile time Index m_startRow; @@ -681,9 +722,9 @@ struct evaluator_impl::type m_conditionImpl; - typename evaluator::type m_thenImpl; - typename evaluator::type m_elseImpl; + typename evaluator::nestedType m_conditionImpl; + typename evaluator::nestedType m_thenImpl; + typename evaluator::nestedType m_elseImpl; }; @@ -731,7 +772,7 @@ struct evaluator_impl > } protected: - typename evaluator::type m_argImpl; + typename evaluator::nestedType m_argImpl; Index m_rows; // TODO: Get rid of this if known at compile time Index m_cols; }; @@ -834,7 +875,7 @@ struct evaluator_impl_wrapper_base } protected: - typename evaluator::type m_argImpl; + typename evaluator::nestedType m_argImpl; }; template @@ -949,7 +990,7 @@ struct evaluator_impl > } protected: - typename evaluator::type m_argImpl; + typename evaluator::nestedType m_argImpl; Index m_rows; // TODO: Don't use if known at compile time or not needed Index m_cols; }; @@ -993,7 +1034,7 @@ struct evaluator_impl > } protected: - typename evaluator::type m_argImpl; + typename evaluator::nestedType m_argImpl; Index m_index; // TODO: Don't use if known at compile time private: @@ -1069,7 +1110,7 @@ struct evaluator_impl > } protected: - typename evaluator::type m_argImpl; + typename evaluator::nestedType m_argImpl; }; @@ -1133,7 +1174,7 @@ struct evaluator_impl > } protected: - typename evaluator::type m_argImpl; + typename evaluator::nestedType m_argImpl; const BinaryOp m_functor; }; diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index aadaa9303..e814a4710 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -50,12 +50,26 @@ struct evaluator_impl > { } }; +template +struct product_evaluator_traits_dispatcher; + +template +struct evaluator_traits > + : product_evaluator_traits_dispatcher, typename ProductReturnType::Type> +{ }; + // Case 1: Evaluate all at once // // We can view the GeneralProduct class as a part of the product evaluator. // Four sub-cases: InnerProduct, OuterProduct, GemmProduct and GemvProduct. // InnerProduct is special because GeneralProduct does not have an evalTo() method in this case. +template +struct product_evaluator_traits_dispatcher, GeneralProduct > +{ + static const int HasEvalTo = 0; +}; + template struct product_evaluator_dispatcher, GeneralProduct > : public evaluator::PlainObject>::type @@ -63,7 +77,8 @@ struct product_evaluator_dispatcher, GeneralProduct XprType; typedef typename XprType::PlainObject PlainObject; typedef typename evaluator::type evaluator_base; - + + // TODO: Computation is too early (?) product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result) { m_result.coeffRef(0,0) = (xpr.lhs().transpose().cwiseProduct(xpr.rhs())).sum(); @@ -76,22 +91,31 @@ protected: // For the other three subcases, simply call the evalTo() method of GeneralProduct // TODO: GeneralProduct should take evaluators, not expression objects. +template +struct product_evaluator_traits_dispatcher, GeneralProduct > +{ + static const int HasEvalTo = 1; +}; + template struct product_evaluator_dispatcher, GeneralProduct > - : public evaluator::PlainObject>::type { typedef Product XprType; typedef typename XprType::PlainObject PlainObject; typedef typename evaluator::type evaluator_base; - product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result) + product_evaluator_dispatcher(const XprType& xpr) : m_xpr(xpr) + { } + + template + void evalTo(DstEvaluatorType /* not used */, DstXprType& dst) { - m_result.resize(xpr.rows(), xpr.cols()); - GeneralProduct(xpr.lhs(), xpr.rhs()).evalTo(m_result); + dst.resize(m_xpr.rows(), m_xpr.cols()); + GeneralProduct(m_xpr.lhs(), m_xpr.rhs()).evalTo(dst); } -protected: - PlainObject m_result; +protected: + const XprType& m_xpr; }; // Case 2: Evaluate coeff by coeff @@ -106,6 +130,12 @@ struct etor_product_coeff_impl; template struct etor_product_packet_impl; +template +struct product_evaluator_traits_dispatcher, CoeffBasedProduct > +{ + static const int HasEvalTo = 0; +}; + template struct product_evaluator_dispatcher, CoeffBasedProduct > : evaluator_impl_base > diff --git a/test/evaluators.cpp b/test/evaluators.cpp index 62ba5b126..3081d7858 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -65,6 +65,11 @@ void test_evaluators() VERIFY_IS_APPROX_EVALUATOR2(d, s * prod(a,b), s * a*b); VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b).transpose(), (a*b).transpose()); VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b) + prod(b,c), a*b + b*c); + + // check that prod works even with aliasing present + c = a*a; + copy_using_evaluator(a, prod(a,a)); + VERIFY_IS_APPROX(a,c); } {