pass eigen2's triangular test

This commit is contained in:
Benoit Jacob 2011-01-23 21:53:28 -05:00
parent 5c82fd7f40
commit 1dabd133cc
4 changed files with 54 additions and 5 deletions

View File

@ -76,6 +76,19 @@ class DiagonalBase : public EigenBase<Derived>
{
return diagonal().cwiseInverse();
}
#ifdef EIGEN2_SUPPORT
template<typename OtherDerived>
bool isApprox(const DiagonalBase<OtherDerived>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const
{
return diagonal().isApprox(other.diagonal(), precision);
}
template<typename OtherDerived>
bool isApprox(const MatrixBase<OtherDerived>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const
{
return toDenseMatrix().isApprox(other, precision);
}
#endif
};
template<typename Derived>
@ -256,7 +269,7 @@ class DiagonalWrapper
* \sa class DiagonalWrapper, class DiagonalMatrix, diagonal(), isDiagonal()
**/
template<typename Derived>
inline const DiagonalWrapper<Derived>
inline const DiagonalWrapper<const Derived>
MatrixBase<Derived>::asDiagonal() const
{
return derived();

View File

@ -245,7 +245,13 @@ template<typename Derived> class MatrixBase
#ifdef EIGEN2_SUPPORT
template<unsigned int Mode> TriangularView<Derived, Mode> part();
template<unsigned int Mode> const TriangularView<Derived, Mode> part() const;
#endif
// huuuge hack. make Eigen2's matrix.part<Diagonal>() work in eigen3. Problem: Diagonal is now a class template instead
// of an integer constant. Solution: overload the part() method template wrt template parameters list.
template<template<typename T, int n> class U>
const DiagonalWrapper<ConstDiagonalReturnType> part() const
{ return diagonal().asDiagonal(); }
#endif // EIGEN2_SUPPORT
template<unsigned int Mode> struct TriangularViewReturnType { typedef TriangularView<Derived, Mode> Type; };
template<unsigned int Mode> struct ConstTriangularViewReturnType { typedef const TriangularView<const Derived, Mode> Type; };
@ -270,7 +276,7 @@ template<typename Derived> class MatrixBase
static const BasisReturnType UnitZ();
static const BasisReturnType UnitW();
const DiagonalWrapper<Derived> asDiagonal() const;
const DiagonalWrapper<const Derived> asDiagonal() const;
Derived& setIdentity();
Derived& setIdentity(Index rows, Index cols);

View File

@ -296,6 +296,36 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
(lhs.derived(),rhs.m_matrix);
}
#ifdef EIGEN2_SUPPORT
template<typename OtherMatrixType>
struct eigen2_product_return_type
{
typedef typename TriangularView<MatrixType,Mode>::DenseMatrixType DenseMatrixType;
typedef typename TriangularView<OtherMatrixType,Mode>::DenseMatrixType OtherDenseMatrixType;
typedef typename ProductReturnType<DenseMatrixType, OtherDenseMatrixType>::Type ProdRetType;
typedef typename ProdRetType::PlainObject type;
};
template<typename OtherMatrixType>
const typename eigen2_product_return_type<OtherMatrixType>::type
operator*(const TriangularView<OtherMatrixType, Mode>& rhs) const
{
return toDenseMatrix() * rhs.toDenseMatrix();
}
template<typename OtherMatrixType>
bool isApprox(const TriangularView<OtherMatrixType, Mode>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const
{
return toDenseMatrix().isApprox(other.toDenseMatrix(), precision);
}
template<typename OtherDerived>
bool isApprox(const MatrixBase<OtherDerived>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const
{
return toDenseMatrix().isApprox(other, precision);
}
#endif // EIGEN2_SUPPORT
template<int Side, typename OtherDerived>
typename internal::plain_matrix_type_column_major<OtherDerived>::type
solve(const MatrixBase<OtherDerived>& other) const;

View File

@ -40,14 +40,14 @@ template<typename OtherDerived>
typename ExpressionType::PlainObject
Flagged<ExpressionType,Added,Removed>::solveTriangular(const MatrixBase<OtherDerived>& other) const
{
return m_matrix.template triangularView<Added>.solve(other.derived());
return m_matrix.template triangularView<Added>().solve(other.derived());
}
template<typename ExpressionType, unsigned int Added, unsigned int Removed>
template<typename OtherDerived>
void Flagged<ExpressionType,Added,Removed>::solveTriangularInPlace(const MatrixBase<OtherDerived>& other) const
{
m_matrix.template triangularView<Added>.solveInPlace(other.derived());
m_matrix.template triangularView<Added>().solveInPlace(other.derived());
}
#endif // EIGEN_TRIANGULAR_SOLVER2_H