Improbe compatibility of Transpositions and evaluators

This commit is contained in:
Gael Guennebaud 2015-06-19 14:10:44 +02:00
parent 3af4c6c1c9
commit 4a8888dfbc
5 changed files with 50 additions and 15 deletions

View File

@ -397,12 +397,12 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename Other>
PermutationMatrix(const Transpose<PermutationBase<Other> >& other)
: m_indices(other.nestedPermutation().size())
: m_indices(other.nestedExpression().size())
{
eigen_internal_assert(m_indices.size() <= NumTraits<StorageIndex>::highest());
StorageIndex end = StorageIndex(m_indices.size());
for (StorageIndex i=0; i<end;++i)
m_indices.coeffRef(other.nestedPermutation().indices().coeff(i)) = i;
m_indices.coeffRef(other.nestedExpression().indices().coeff(i)) = i;
}
template<typename Lhs,typename Rhs>
PermutationMatrix(internal::PermPermProduct_t, const Lhs& lhs, const Rhs& rhs)
@ -635,7 +635,7 @@ class Transpose<PermutationBase<Derived> >
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
}
const PermutationType& nestedPermutation() const { return m_permutation; }
const PermutationType& nestedExpression() const { return m_permutation; }
protected:
const PermutationType& m_permutation;

View File

@ -53,6 +53,18 @@ template<typename Lhs, typename Rhs, typename LhsShape>
typedef typename Lhs::Scalar Scalar;
};
template<typename Lhs, typename Rhs, typename RhsShape>
struct product_result_scalar<Lhs, Rhs, TranspositionsShape, RhsShape>
{
typedef typename Rhs::Scalar Scalar;
};
template<typename Lhs, typename Rhs, typename LhsShape>
struct product_result_scalar<Lhs, Rhs, LhsShape, TranspositionsShape>
{
typedef typename Lhs::Scalar Scalar;
};
template<typename Lhs, typename Rhs, int Option>
struct traits<Product<Lhs, Rhs, Option> >
{

View File

@ -915,7 +915,7 @@ struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, MatrixShape,
template<typename Dest>
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
{
permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs);
permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
}
};
@ -925,7 +925,7 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape,
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
{
permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs);
permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
}
};
@ -944,7 +944,7 @@ template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
struct transposition_matrix_product
{
template<typename Dest, typename TranspositionType>
static inline void evalTo(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
static inline void run(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
{
typedef typename TranspositionType::StorageIndex StorageIndex;
const Index size = tr.size();
@ -982,13 +982,14 @@ struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductT
}
};
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
{
transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs);
transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
}
};
@ -998,7 +999,7 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShap
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
{
transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs);
transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
}
};

View File

@ -62,7 +62,7 @@ class TranspositionsBase
indices() = other.indices();
return derived();
}
#ifndef EIGEN_PARSED_BY_DOXYGEN
/** This is a special case of the templated operator=. Its purpose is to
* prevent a default operator= from hiding the templated operator=.
@ -75,7 +75,11 @@ class TranspositionsBase
#endif
/** \returns the number of transpositions */
inline Index size() const { return indices().size(); }
Index size() const { return indices().size(); }
/** \returns the number of rows of the equivalent permutation matrix */
Index rows() const { return indices().size(); }
/** \returns the number of columns of the equivalent permutation matrix */
Index cols() const { return indices().size(); }
/** Direct access to the underlying index vector */
inline const StorageIndex& coeff(Index i) const { return indices().coeff(i); }
@ -143,9 +147,10 @@ class TranspositionsBase
namespace internal {
template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename _StorageIndex>
struct traits<Transpositions<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex> >
: traits<PermutationMatrix<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex> >
{
typedef Matrix<_StorageIndex, SizeAtCompileTime, 1, 0, MaxSizeAtCompileTime, 1> IndicesType;
typedef _StorageIndex StorageIndex;
typedef TranspositionsStorage StorageKind;
};
}
@ -214,9 +219,11 @@ class Transpositions : public TranspositionsBase<Transpositions<SizeAtCompileTim
namespace internal {
template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename _StorageIndex, int _PacketAccess>
struct traits<Map<Transpositions<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex>,_PacketAccess> >
: traits<PermutationMatrix<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex> >
{
typedef Map<const Matrix<_StorageIndex,SizeAtCompileTime,1,0,MaxSizeAtCompileTime,1>, _PacketAccess> IndicesType;
typedef _StorageIndex StorageIndex;
typedef TranspositionsStorage StorageKind;
};
}
@ -271,9 +278,9 @@ class Map<Transpositions<SizeAtCompileTime,MaxSizeAtCompileTime,_StorageIndex>,P
namespace internal {
template<typename _IndicesType>
struct traits<TranspositionsWrapper<_IndicesType> >
: traits<PermutationWrapper<_IndicesType> >
{
typedef typename _IndicesType::Scalar StorageIndex;
typedef _IndicesType IndicesType;
typedef TranspositionsStorage StorageKind;
};
}
@ -347,8 +354,16 @@ operator*(const TranspositionsBase<TranspositionsDerived> &transpositions,
(transpositions.derived(), matrix.derived());
}
// Template partial specialization for transposed/inverse transpositions
/* Template partial specialization for transposed/inverse transpositions */
namespace internal {
template<typename Derived>
struct traits<Transpose<TranspositionsBase<Derived> > >
: traits<Derived>
{};
} // end namespace internal
template<typename TranspositionsDerived>
class Transpose<TranspositionsBase<TranspositionsDerived> >
@ -359,7 +374,9 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
explicit Transpose(const TranspositionType& t) : m_transpositions(t) {}
inline int size() const { return m_transpositions.size(); }
Index size() const { return m_transpositions.size(); }
Index rows() const { return m_transpositions.size(); }
Index cols() const { return m_transpositions.size(); }
/** \returns the \a matrix with the inverse transpositions applied to the columns.
*/
@ -378,6 +395,8 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
{
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
}
const TranspositionType& nestedExpression() const { return m_transpositions; }
protected:
const TranspositionType& m_transpositions;

View File

@ -468,6 +468,9 @@ struct Sparse {};
/** The type used to identify a permutation storage. */
struct PermutationStorage {};
/** The type used to identify a permutation storage. */
struct TranspositionsStorage {};
/** The type used to identify a matrix expression */
struct MatrixXpr {};