mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Add a Solve expression for uniform treatment of solve() methods.
This commit is contained in:
parent
b3a07eecc5
commit
ccc41128fb
@ -361,6 +361,9 @@ using std::ptrdiff_t;
|
|||||||
#include "src/Core/Flagged.h"
|
#include "src/Core/Flagged.h"
|
||||||
#include "src/Core/ProductBase.h"
|
#include "src/Core/ProductBase.h"
|
||||||
#include "src/Core/GeneralProduct.h"
|
#include "src/Core/GeneralProduct.h"
|
||||||
|
#ifdef EIGEN_ENABLE_EVALUATORS
|
||||||
|
#include "src/Core/Solve.h"
|
||||||
|
#endif
|
||||||
#include "src/Core/TriangularMatrix.h"
|
#include "src/Core/TriangularMatrix.h"
|
||||||
#include "src/Core/SelfAdjointView.h"
|
#include "src/Core/SelfAdjointView.h"
|
||||||
#include "src/Core/products/GeneralBlockPanelKernel.h"
|
#include "src/Core/products/GeneralBlockPanelKernel.h"
|
||||||
|
@ -495,14 +495,14 @@ struct evaluator<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
|
|||||||
PacketScalar packet(Index row, Index col) const
|
PacketScalar packet(Index row, Index col) const
|
||||||
{
|
{
|
||||||
return m_functor.packetOp(m_lhsImpl.template packet<LoadMode>(row, col),
|
return m_functor.packetOp(m_lhsImpl.template packet<LoadMode>(row, col),
|
||||||
m_rhsImpl.template packet<LoadMode>(row, col));
|
m_rhsImpl.template packet<LoadMode>(row, col));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int LoadMode>
|
template<int LoadMode>
|
||||||
PacketScalar packet(Index index) const
|
PacketScalar packet(Index index) const
|
||||||
{
|
{
|
||||||
return m_functor.packetOp(m_lhsImpl.template packet<LoadMode>(index),
|
return m_functor.packetOp(m_lhsImpl.template packet<LoadMode>(index),
|
||||||
m_rhsImpl.template packet<LoadMode>(index));
|
m_rhsImpl.template packet<LoadMode>(index));
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
145
Eigen/src/Core/Solve.h
Normal file
145
Eigen/src/Core/Solve.h
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2014 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#ifndef EIGEN_INVERSE_H
|
||||||
|
#define EIGEN_INVERSE_H
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
|
||||||
|
template<typename Decomposition, typename RhsType, typename StorageKind> class SolveImpl;
|
||||||
|
|
||||||
|
/** \class Solve
|
||||||
|
* \ingroup Core_Module
|
||||||
|
*
|
||||||
|
* \brief Pseudo expression representing a solving operation
|
||||||
|
*
|
||||||
|
* \tparam Decomposition the type of the matrix or decomposion object
|
||||||
|
* \tparam Rhstype the type of the right-hand side
|
||||||
|
*
|
||||||
|
* This class represents an expression of A.solve(B)
|
||||||
|
* and most of the time this is the only way it is used.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// this solve_traits class permits to determine the evaluation type with respect to storage kind (Dense vs Sparse)
|
||||||
|
template<typename Decomposition, typename RhsType,typename StorageKind> struct solve_traits;
|
||||||
|
|
||||||
|
template<typename Decomposition, typename RhsType>
|
||||||
|
struct solve_traits<Decomposition,RhsType,Dense>
|
||||||
|
{
|
||||||
|
typedef typename Decomposition::MatrixType MatrixType;
|
||||||
|
typedef Matrix<typename RhsType::Scalar,
|
||||||
|
MatrixType::ColsAtCompileTime,
|
||||||
|
RhsType::ColsAtCompileTime,
|
||||||
|
RhsType::PlainObject::Options,
|
||||||
|
MatrixType::MaxColsAtCompileTime,
|
||||||
|
RhsType::MaxColsAtCompileTime> PlainObject;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Decomposition, typename RhsType>
|
||||||
|
struct traits<Solve<Decomposition, RhsType> >
|
||||||
|
: traits<typename solve_traits<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>::PlainObject>
|
||||||
|
{
|
||||||
|
typedef typename solve_traits<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>::PlainObject PlainObject;
|
||||||
|
typedef traits<PlainObject> BaseTraits;
|
||||||
|
enum {
|
||||||
|
Flags = BaseTraits::Flags & RowMajorBit,
|
||||||
|
CoeffReadCost = Dynamic
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename Decomposition, typename RhsType>
|
||||||
|
class Solve : public SolveImpl<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
typedef typename RhsType::Index Index;
|
||||||
|
typedef typename internal::traits<Solve>::PlainObject PlainObject;
|
||||||
|
|
||||||
|
Solve(const Decomposition &dec, const RhsType &rhs)
|
||||||
|
: m_dec(dec), m_rhs(rhs)
|
||||||
|
{}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC Index rows() const { return m_dec.rows(); }
|
||||||
|
EIGEN_DEVICE_FUNC Index cols() const { return m_rhs.cols(); }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC const Decomposition& dec() const { return m_dec; }
|
||||||
|
EIGEN_DEVICE_FUNC const RhsType& rhs() const { return m_rhs; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const Decomposition &m_dec;
|
||||||
|
const RhsType &m_rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Specilaization of the Solve expression for dense results
|
||||||
|
template<typename Decomposition, typename RhsType>
|
||||||
|
class SolveImpl<Decomposition,RhsType,Dense>
|
||||||
|
: public MatrixBase<Solve<Decomposition,RhsType> >
|
||||||
|
{
|
||||||
|
typedef Solve<Decomposition,RhsType> Derived;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
typedef MatrixBase<Solve<Decomposition,RhsType> > Base;
|
||||||
|
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
Scalar coeff(Index row, Index col) const;
|
||||||
|
Scalar coeff(Index i) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// Evaluator of Solve -> eval into a temporary
|
||||||
|
template<typename Decomposition, typename RhsType>
|
||||||
|
struct evaluator<Solve<Decomposition,RhsType> >
|
||||||
|
: public evaluator<typename Solve<Decomposition,RhsType>::PlainObject>::type
|
||||||
|
{
|
||||||
|
typedef Solve<Decomposition,RhsType> SolveType;
|
||||||
|
typedef typename SolveType::PlainObject PlainObject;
|
||||||
|
typedef typename evaluator<PlainObject>::type Base;
|
||||||
|
|
||||||
|
typedef evaluator type;
|
||||||
|
typedef evaluator nestedType;
|
||||||
|
|
||||||
|
evaluator(const SolveType& solve)
|
||||||
|
: m_result(solve.rows(), solve.cols())
|
||||||
|
{
|
||||||
|
::new (static_cast<Base*>(this)) Base(m_result);
|
||||||
|
solve.dec()._solve_impl(solve.rhs(), m_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
PlainObject m_result;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Specialization for "dst = dec.solve(rhs)"
|
||||||
|
// NOTE we need to specialize it for Dense2Dense to avoid ambiguous specialization error and a Sparse2Sparse specialization must exist somewhere
|
||||||
|
template<typename DstXprType, typename DecType, typename RhsType, typename Scalar>
|
||||||
|
struct Assignment<DstXprType, Solve<DecType,RhsType>, internal::assign_op<Scalar>, Dense2Dense, Scalar>
|
||||||
|
{
|
||||||
|
typedef Solve<DecType,RhsType> SrcXprType;
|
||||||
|
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &)
|
||||||
|
{
|
||||||
|
// FIXME shall we resize dst here?
|
||||||
|
src.dec()._solve_impl(src.rhs(), dst);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namepsace internal
|
||||||
|
|
||||||
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
#endif // EIGEN_SOLVE_H
|
@ -437,11 +437,19 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
|
|||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
void solveInPlace(const MatrixBase<OtherDerived>& other) const;
|
void solveInPlace(const MatrixBase<OtherDerived>& other) const;
|
||||||
|
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
template<typename Other>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
inline const Solve<TriangularView, Other>
|
||||||
|
solve(const MatrixBase<Other>& other) const
|
||||||
|
{ return Solve<TriangularView, Other>(*this, other.derived()); }
|
||||||
|
#else // EIGEN_TEST_EVALUATORS
|
||||||
template<typename Other>
|
template<typename Other>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
inline const internal::triangular_solve_retval<OnTheLeft,TriangularView, Other>
|
inline const internal::triangular_solve_retval<OnTheLeft,TriangularView, Other>
|
||||||
solve(const MatrixBase<Other>& other) const
|
solve(const MatrixBase<Other>& other) const
|
||||||
{ return solve<OnTheLeft>(other); }
|
{ return solve<OnTheLeft>(other); }
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -547,6 +555,15 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
|
|||||||
#endif // EIGEN_TEST_EVALUATORS
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
#ifdef EIGEN_TEST_EVALUATORS
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
template<typename RhsType, typename DstType>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const {
|
||||||
|
if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs)))
|
||||||
|
dst = rhs;
|
||||||
|
this->template solveInPlace(dst);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename ProductType>
|
template<typename ProductType>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE TriangularView& _assignProduct(const ProductType& prod, const Scalar& alpha);
|
EIGEN_STRONG_INLINE TriangularView& _assignProduct(const ProductType& prod, const Scalar& alpha);
|
||||||
|
@ -94,7 +94,8 @@ template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
|
|||||||
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
|
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
|
||||||
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
|
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
|
||||||
template<typename BinOp, typename Lhs, typename Rhs> class SelfCwiseBinaryOp; // TODO deprecated
|
template<typename BinOp, typename Lhs, typename Rhs> class SelfCwiseBinaryOp; // TODO deprecated
|
||||||
template<typename Derived, typename Lhs, typename Rhs> class ProductBase;
|
template<typename Derived, typename Lhs, typename Rhs> class ProductBase; // TODO deprecated
|
||||||
|
template<typename Decomposition, typename Rhstype> class Solve;
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template<typename Lhs, typename Rhs> struct product_tag;
|
template<typename Lhs, typename Rhs> struct product_tag;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user