Improbe compatibility of Transpositions and evaluators

This commit is contained in:
Gael Guennebaud 2015-06-19 14:10:44 +02:00
parent 3af4c6c1c9
commit 4a8888dfbc
5 changed files with 50 additions and 15 deletions

View File

@ -397,12 +397,12 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename Other> template<typename Other>
PermutationMatrix(const Transpose<PermutationBase<Other> >& other) PermutationMatrix(const Transpose<PermutationBase<Other> >& other)
: m_indices(other.nestedPermutation().size()) : m_indices(other.nestedExpression().size())
{ {
eigen_internal_assert(m_indices.size() <= NumTraits<StorageIndex>::highest()); eigen_internal_assert(m_indices.size() <= NumTraits<StorageIndex>::highest());
StorageIndex end = StorageIndex(m_indices.size()); StorageIndex end = StorageIndex(m_indices.size());
for (StorageIndex i=0; i<end;++i) for (StorageIndex i=0; i<end;++i)
m_indices.coeffRef(other.nestedPermutation().indices().coeff(i)) = i; m_indices.coeffRef(other.nestedExpression().indices().coeff(i)) = i;
} }
template<typename Lhs,typename Rhs> template<typename Lhs,typename Rhs>
PermutationMatrix(internal::PermPermProduct_t, const Lhs& lhs, const Rhs& rhs) PermutationMatrix(internal::PermPermProduct_t, const Lhs& lhs, const Rhs& rhs)
@ -635,7 +635,7 @@ class Transpose<PermutationBase<Derived> >
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived()); return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
} }
const PermutationType& nestedPermutation() const { return m_permutation; } const PermutationType& nestedExpression() const { return m_permutation; }
protected: protected:
const PermutationType& m_permutation; const PermutationType& m_permutation;

View File

@ -53,6 +53,18 @@ template<typename Lhs, typename Rhs, typename LhsShape>
typedef typename Lhs::Scalar Scalar; typedef typename Lhs::Scalar Scalar;
}; };
template<typename Lhs, typename Rhs, typename RhsShape>
struct product_result_scalar<Lhs, Rhs, TranspositionsShape, RhsShape>
{
typedef typename Rhs::Scalar Scalar;
};
template<typename Lhs, typename Rhs, typename LhsShape>
struct product_result_scalar<Lhs, Rhs, LhsShape, TranspositionsShape>
{
typedef typename Lhs::Scalar Scalar;
};
template<typename Lhs, typename Rhs, int Option> template<typename Lhs, typename Rhs, int Option>
struct traits<Product<Lhs, Rhs, Option> > struct traits<Product<Lhs, Rhs, Option> >
{ {

View File

@ -915,7 +915,7 @@ struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, MatrixShape,
template<typename Dest> template<typename Dest>
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs) static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
{ {
permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs); permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
} }
}; };
@ -925,7 +925,7 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape,
template<typename Dest> template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs) static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
{ {
permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs); permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
} }
}; };
@ -944,7 +944,7 @@ template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
struct transposition_matrix_product struct transposition_matrix_product
{ {
template<typename Dest, typename TranspositionType> template<typename Dest, typename TranspositionType>
static inline void evalTo(Dest& dst, const TranspositionType& tr, const MatrixType& mat) static inline void run(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
{ {
typedef typename TranspositionType::StorageIndex StorageIndex; typedef typename TranspositionType::StorageIndex StorageIndex;
const Index size = tr.size(); const Index size = tr.size();
@ -982,13 +982,14 @@ struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductT
} }
}; };
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag> struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag>
{ {
template<typename Dest> template<typename Dest>
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs) static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
{ {
transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs); transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
} }
}; };
@ -998,7 +999,7 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShap
template<typename Dest> template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs) static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
{ {
transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs); transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
} }
}; };

View File

@ -75,7 +75,11 @@ class TranspositionsBase
#endif #endif
/** \returns the number of transpositions */ /** \returns the number of transpositions */
inline Index size() const { return indices().size(); } Index size() const { return indices().size(); }
/** \returns the number of rows of the equivalent permutation matrix */
Index rows() const { return indices().size(); }
/** \returns the number of columns of the equivalent permutation matrix */
Index cols() const { return indices().size(); }
/** Direct access to the underlying index vector */ /** Direct access to the underlying index vector */
inline const StorageIndex& coeff(Index i) const { return indices().coeff(i); } inline const StorageIndex& coeff(Index i) const { return indices().coeff(i); }
@ -143,9 +147,10 @@ class TranspositionsBase
namespace internal { namespace internal {
template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename _StorageIndex> template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename _StorageIndex>
struct traits<Transpositions<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex> > struct traits<Transpositions<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex> >
: traits<PermutationMatrix<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex> >
{ {
typedef Matrix<_StorageIndex, SizeAtCompileTime, 1, 0, MaxSizeAtCompileTime, 1> IndicesType; typedef Matrix<_StorageIndex, SizeAtCompileTime, 1, 0, MaxSizeAtCompileTime, 1> IndicesType;
typedef _StorageIndex StorageIndex; typedef TranspositionsStorage StorageKind;
}; };
} }
@ -214,9 +219,11 @@ class Transpositions : public TranspositionsBase<Transpositions<SizeAtCompileTim
namespace internal { namespace internal {
template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename _StorageIndex, int _PacketAccess> template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename _StorageIndex, int _PacketAccess>
struct traits<Map<Transpositions<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex>,_PacketAccess> > struct traits<Map<Transpositions<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex>,_PacketAccess> >
: traits<PermutationMatrix<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex> >
{ {
typedef Map<const Matrix<_StorageIndex,SizeAtCompileTime,1,0,MaxSizeAtCompileTime,1>, _PacketAccess> IndicesType; typedef Map<const Matrix<_StorageIndex,SizeAtCompileTime,1,0,MaxSizeAtCompileTime,1>, _PacketAccess> IndicesType;
typedef _StorageIndex StorageIndex; typedef _StorageIndex StorageIndex;
typedef TranspositionsStorage StorageKind;
}; };
} }
@ -271,9 +278,9 @@ class Map<Transpositions<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex>,P
namespace internal { namespace internal {
template<typename _IndicesType> template<typename _IndicesType>
struct traits<TranspositionsWrapper<_IndicesType> > struct traits<TranspositionsWrapper<_IndicesType> >
: traits<PermutationWrapper<_IndicesType> >
{ {
typedef typename _IndicesType::Scalar StorageIndex; typedef TranspositionsStorage StorageKind;
typedef _IndicesType IndicesType;
}; };
} }
@ -347,8 +354,16 @@ operator*(const TranspositionsBase<TranspositionsDerived> &transpositions,
(transpositions.derived(), matrix.derived()); (transpositions.derived(), matrix.derived());
} }
// Template partial specialization for transposed/inverse transpositions
/* Template partial specialization for transposed/inverse transpositions */ namespace internal {
template<typename Derived>
struct traits<Transpose<TranspositionsBase<Derived> > >
: traits<Derived>
{};
} // end namespace internal
template<typename TranspositionsDerived> template<typename TranspositionsDerived>
class Transpose<TranspositionsBase<TranspositionsDerived> > class Transpose<TranspositionsBase<TranspositionsDerived> >
@ -359,7 +374,9 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
explicit Transpose(const TranspositionType& t) : m_transpositions(t) {} explicit Transpose(const TranspositionType& t) : m_transpositions(t) {}
inline int size() const { return m_transpositions.size(); } Index size() const { return m_transpositions.size(); }
Index rows() const { return m_transpositions.size(); }
Index cols() const { return m_transpositions.size(); }
/** \returns the \a matrix with the inverse transpositions applied to the columns. /** \returns the \a matrix with the inverse transpositions applied to the columns.
*/ */
@ -379,6 +396,8 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived()); return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
} }
const TranspositionType& nestedExpression() const { return m_transpositions; }
protected: protected:
const TranspositionType& m_transpositions; const TranspositionType& m_transpositions;
}; };

View File

@ -468,6 +468,9 @@ struct Sparse {};
/** The type used to identify a permutation storage. */ /** The type used to identify a permutation storage. */
struct PermutationStorage {}; struct PermutationStorage {};
/** The type used to identify a permutation storage. */
struct TranspositionsStorage {};
/** The type used to identify a matrix expression */ /** The type used to identify a matrix expression */
struct MatrixXpr {}; struct MatrixXpr {};