MatrixFunctions: replace eval() by nested.

This eliminates an unnecessary copy in some situations, e.g. Map.
This commit is contained in:
Jitse Niesen 2013-07-31 14:57:20 +01:00
parent 43df1e707c
commit 68168e9eae
4 changed files with 29 additions and 24 deletions

View File

@ -392,17 +392,14 @@ template<typename Derived> struct MatrixExponentialReturnValue
template <typename ResultType> template <typename ResultType>
inline void evalTo(ResultType& result) const inline void evalTo(ResultType& result) const
{ {
const typename Derived::PlainObject srcEvaluated = m_src.eval(); internal::matrix_exp_compute(m_src, result);
internal::matrix_exp_compute(srcEvaluated, result);
} }
Index rows() const { return m_src.rows(); } Index rows() const { return m_src.rows(); }
Index cols() const { return m_src.cols(); } Index cols() const { return m_src.cols(); }
protected: protected:
const Derived& m_src; const typename internal::nested<Derived, 10>::type m_src;
private:
MatrixExponentialReturnValue& operator=(const MatrixExponentialReturnValue&);
}; };
namespace internal { namespace internal {

View File

@ -481,11 +481,15 @@ template<typename Derived> class MatrixFunctionReturnValue
: public ReturnByValue<MatrixFunctionReturnValue<Derived> > : public ReturnByValue<MatrixFunctionReturnValue<Derived> >
{ {
public: public:
typedef typename Derived::Scalar Scalar; typedef typename Derived::Scalar Scalar;
typedef typename Derived::Index Index; typedef typename Derived::Index Index;
typedef typename internal::stem_function<Scalar>::type StemFunction; typedef typename internal::stem_function<Scalar>::type StemFunction;
protected:
typedef typename internal::nested<Derived, 10>::type DerivedNested;
public:
/** \brief Constructor. /** \brief Constructor.
* *
* \param[in] A %Matrix (expression) forming the argument of the matrix function. * \param[in] A %Matrix (expression) forming the argument of the matrix function.
@ -500,26 +504,25 @@ template<typename Derived> class MatrixFunctionReturnValue
template <typename ResultType> template <typename ResultType>
inline void evalTo(ResultType& result) const inline void evalTo(ResultType& result) const
{ {
typedef typename Derived::PlainObject PlainObject; typedef typename internal::remove_all<DerivedNested>::type DerivedNestedClean;
typedef internal::traits<PlainObject> Traits; typedef internal::traits<DerivedNestedClean> Traits;
static const int RowsAtCompileTime = Traits::RowsAtCompileTime; static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
static const int ColsAtCompileTime = Traits::ColsAtCompileTime; static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
static const int Options = PlainObject::Options; static const int Options = DerivedNestedClean::Options;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar; typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType; typedef Matrix<ComplexScalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
typedef internal::MatrixFunctionAtomic<DynMatrixType> AtomicType; typedef internal::MatrixFunctionAtomic<DynMatrixType> AtomicType;
AtomicType atomic(m_f); AtomicType atomic(m_f);
const PlainObject Aevaluated = m_A.eval(); internal::matrix_function_compute<DerivedNestedClean>::run(m_A, atomic, result);
internal::matrix_function_compute<PlainObject>::run(Aevaluated, atomic, result);
} }
Index rows() const { return m_A.rows(); } Index rows() const { return m_A.rows(); }
Index cols() const { return m_A.cols(); } Index cols() const { return m_A.cols(); }
private: private:
typename internal::nested<Derived>::type m_A; const DerivedNested m_A;
StemFunction *m_f; StemFunction *m_f;
}; };

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library // This file is part of Eigen, a lightweight C++ template library
// for linear algebra. // for linear algebra.
// //
// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk> // Copyright (C) 2011, 2013 Jitse Niesen <jitse@maths.leeds.ac.uk>
// Copyright (C) 2011 Chen-Pang He <jdh8@ms63.hinet.net> // Copyright (C) 2011 Chen-Pang He <jdh8@ms63.hinet.net>
// //
// This Source Code Form is subject to the terms of the Mozilla // This Source Code Form is subject to the terms of the Mozilla
@ -306,10 +306,14 @@ template<typename Derived> class MatrixLogarithmReturnValue
: public ReturnByValue<MatrixLogarithmReturnValue<Derived> > : public ReturnByValue<MatrixLogarithmReturnValue<Derived> >
{ {
public: public:
typedef typename Derived::Scalar Scalar; typedef typename Derived::Scalar Scalar;
typedef typename Derived::Index Index; typedef typename Derived::Index Index;
protected:
typedef typename internal::nested<Derived, 10>::type DerivedNested;
public:
/** \brief Constructor. /** \brief Constructor.
* *
* \param[in] A %Matrix (expression) forming the argument of the matrix logarithm. * \param[in] A %Matrix (expression) forming the argument of the matrix logarithm.
@ -323,25 +327,24 @@ public:
template <typename ResultType> template <typename ResultType>
inline void evalTo(ResultType& result) const inline void evalTo(ResultType& result) const
{ {
typedef typename Derived::PlainObject PlainObject; typedef typename internal::remove_all<DerivedNested>::type DerivedNestedClean;
typedef internal::traits<PlainObject> Traits; typedef internal::traits<DerivedNestedClean> Traits;
static const int RowsAtCompileTime = Traits::RowsAtCompileTime; static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
static const int ColsAtCompileTime = Traits::ColsAtCompileTime; static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
static const int Options = PlainObject::Options; static const int Options = DerivedNestedClean::Options;
typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar; typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
typedef Matrix<ComplexScalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType; typedef Matrix<ComplexScalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
typedef internal::MatrixLogarithmAtomic<DynMatrixType> AtomicType; typedef internal::MatrixLogarithmAtomic<DynMatrixType> AtomicType;
AtomicType atomic; AtomicType atomic;
const PlainObject Aevaluated = m_A.eval(); internal::matrix_function_compute<DerivedNestedClean>::run(m_A, atomic, result);
internal::matrix_function_compute<PlainObject>::run(Aevaluated, atomic, result);
} }
Index rows() const { return m_A.rows(); } Index rows() const { return m_A.rows(); }
Index cols() const { return m_A.cols(); } Index cols() const { return m_A.cols(); }
private: private:
typename internal::nested<Derived>::type m_A; const DerivedNested m_A;
}; };
namespace internal { namespace internal {

View File

@ -318,7 +318,10 @@ struct matrix_sqrt_compute<MatrixType, 1>
template<typename Derived> class MatrixSquareRootReturnValue template<typename Derived> class MatrixSquareRootReturnValue
: public ReturnByValue<MatrixSquareRootReturnValue<Derived> > : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
{ {
protected:
typedef typename Derived::Index Index; typedef typename Derived::Index Index;
typedef typename internal::nested<Derived, 10>::type DerivedNested;
public: public:
/** \brief Constructor. /** \brief Constructor.
* *
@ -335,16 +338,15 @@ template<typename Derived> class MatrixSquareRootReturnValue
template <typename ResultType> template <typename ResultType>
inline void evalTo(ResultType& result) const inline void evalTo(ResultType& result) const
{ {
typedef typename Derived::PlainObject PlainObject; typedef typename internal::remove_all<DerivedNested>::type DerivedNestedClean;
const PlainObject srcEvaluated = m_src.eval(); internal::matrix_sqrt_compute<DerivedNestedClean>::run(m_src, result);
internal::matrix_sqrt_compute<PlainObject>::run(srcEvaluated, result);
} }
Index rows() const { return m_src.rows(); } Index rows() const { return m_src.rows(); }
Index cols() const { return m_src.cols(); } Index cols() const { return m_src.cols(); }
protected: protected:
const Derived& m_src; const DerivedNested m_src;
}; };
namespace internal { namespace internal {