mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 20:26:03 +08:00
Make Transpositions use evaluators
This commit is contained in:
parent
82b6ac0864
commit
3af4c6c1c9
@ -537,9 +537,6 @@ class PermutationWrapper : public PermutationBase<PermutationWrapper<_IndicesTyp
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// TODO: Do we need to define these operator* functions? Would it be better to have them inherited
|
|
||||||
// from MatrixBase?
|
|
||||||
|
|
||||||
/** \returns the matrix with the permutation applied to the columns.
|
/** \returns the matrix with the permutation applied to the columns.
|
||||||
*/
|
*/
|
||||||
template<typename MatrixDerived, typename PermutationDerived>
|
template<typename MatrixDerived, typename PermutationDerived>
|
||||||
|
@ -929,6 +929,79 @@ struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/***************************************************************************
|
||||||
|
* Products with transpositions matrices
|
||||||
|
***************************************************************************/
|
||||||
|
|
||||||
|
// FIXME could we unify Transpositions and Permutation into a single "shape"??
|
||||||
|
|
||||||
|
/** \internal
|
||||||
|
* \class transposition_matrix_product
|
||||||
|
* Internal helper class implementing the product between a permutation matrix and a matrix.
|
||||||
|
*/
|
||||||
|
template<typename MatrixType, int Side, bool Transposed, typename MatrixShape>
|
||||||
|
struct transposition_matrix_product
|
||||||
|
{
|
||||||
|
template<typename Dest, typename TranspositionType>
|
||||||
|
static inline void evalTo(Dest& dst, const TranspositionType& tr, const MatrixType& mat)
|
||||||
|
{
|
||||||
|
typedef typename TranspositionType::StorageIndex StorageIndex;
|
||||||
|
const Index size = tr.size();
|
||||||
|
StorageIndex j = 0;
|
||||||
|
|
||||||
|
if(!(is_same<MatrixType,Dest>::value && extract_data(dst) == extract_data(mat)))
|
||||||
|
dst = mat;
|
||||||
|
|
||||||
|
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
|
||||||
|
if(Index(j=tr.coeff(k))!=k)
|
||||||
|
{
|
||||||
|
if(Side==OnTheLeft) dst.row(k).swap(dst.row(j));
|
||||||
|
else if(Side==OnTheRight) dst.col(k).swap(dst.col(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||||
|
struct generic_product_impl<Lhs, Rhs, TranspositionsShape, MatrixShape, ProductTag>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
transposition_matrix_product<Rhs, OnTheLeft, false, MatrixShape>::run(dst, lhs, rhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||||
|
struct generic_product_impl<Lhs, Rhs, MatrixShape, TranspositionsShape, ProductTag>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
transposition_matrix_product<Lhs, OnTheRight, false, MatrixShape>::run(dst, rhs, lhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||||
|
struct generic_product_impl<Transpose<Lhs>, Rhs, TranspositionsShape, MatrixShape, ProductTag>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
transposition_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedPermutation(), rhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
|
||||||
|
struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, TranspositionsShape, ProductTag>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
|
||||||
|
{
|
||||||
|
transposition_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedPermutation(), lhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -41,10 +41,6 @@ namespace Eigen {
|
|||||||
* \sa class PermutationMatrix
|
* \sa class PermutationMatrix
|
||||||
*/
|
*/
|
||||||
|
|
||||||
namespace internal {
|
|
||||||
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed=false> struct transposition_matrix_product_retval;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
class TranspositionsBase
|
class TranspositionsBase
|
||||||
{
|
{
|
||||||
@ -325,77 +321,32 @@ class TranspositionsWrapper
|
|||||||
const typename IndicesType::Nested m_indices;
|
const typename IndicesType::Nested m_indices;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/** \returns the \a matrix with the \a transpositions applied to the columns.
|
/** \returns the \a matrix with the \a transpositions applied to the columns.
|
||||||
*/
|
*/
|
||||||
template<typename Derived, typename TranspositionsDerived>
|
template<typename MatrixDerived, typename TranspositionsDerived>
|
||||||
inline const internal::transposition_matrix_product_retval<TranspositionsDerived, Derived, OnTheRight>
|
EIGEN_DEVICE_FUNC
|
||||||
operator*(const MatrixBase<Derived>& matrix,
|
const Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
|
||||||
const TranspositionsBase<TranspositionsDerived> &transpositions)
|
operator*(const MatrixBase<MatrixDerived> &matrix,
|
||||||
|
const TranspositionsBase<TranspositionsDerived>& transpositions)
|
||||||
{
|
{
|
||||||
return internal::transposition_matrix_product_retval
|
return Product<MatrixDerived, TranspositionsDerived, DefaultProduct>
|
||||||
<TranspositionsDerived, Derived, OnTheRight>
|
(matrix.derived(), transpositions.derived());
|
||||||
(transpositions.derived(), matrix.derived());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \returns the \a matrix with the \a transpositions applied to the rows.
|
/** \returns the \a matrix with the \a transpositions applied to the rows.
|
||||||
*/
|
*/
|
||||||
template<typename Derived, typename TranspositionDerived>
|
template<typename TranspositionsDerived, typename MatrixDerived>
|
||||||
inline const internal::transposition_matrix_product_retval
|
EIGEN_DEVICE_FUNC
|
||||||
<TranspositionDerived, Derived, OnTheLeft>
|
const Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
|
||||||
operator*(const TranspositionsBase<TranspositionDerived> &transpositions,
|
operator*(const TranspositionsBase<TranspositionsDerived> &transpositions,
|
||||||
const MatrixBase<Derived>& matrix)
|
const MatrixBase<MatrixDerived>& matrix)
|
||||||
{
|
{
|
||||||
return internal::transposition_matrix_product_retval
|
return Product<TranspositionsDerived, MatrixDerived, DefaultProduct>
|
||||||
<TranspositionDerived, Derived, OnTheLeft>
|
|
||||||
(transpositions.derived(), matrix.derived());
|
(transpositions.derived(), matrix.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace internal {
|
|
||||||
|
|
||||||
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
|
|
||||||
struct traits<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
|
|
||||||
{
|
|
||||||
typedef typename MatrixType::PlainObject ReturnType;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename TranspositionType, typename MatrixType, int Side, bool Transposed>
|
|
||||||
struct transposition_matrix_product_retval
|
|
||||||
: public ReturnByValue<transposition_matrix_product_retval<TranspositionType, MatrixType, Side, Transposed> >
|
|
||||||
{
|
|
||||||
typedef typename remove_all<typename MatrixType::Nested>::type MatrixTypeNestedCleaned;
|
|
||||||
typedef typename TranspositionType::StorageIndex StorageIndex;
|
|
||||||
|
|
||||||
transposition_matrix_product_retval(const TranspositionType& tr, const MatrixType& matrix)
|
|
||||||
: m_transpositions(tr), m_matrix(matrix)
|
|
||||||
{}
|
|
||||||
|
|
||||||
inline Index rows() const { return m_matrix.rows(); }
|
|
||||||
inline Index cols() const { return m_matrix.cols(); }
|
|
||||||
|
|
||||||
template<typename Dest> inline void evalTo(Dest& dst) const
|
|
||||||
{
|
|
||||||
const Index size = m_transpositions.size();
|
|
||||||
StorageIndex j = 0;
|
|
||||||
|
|
||||||
if(!(is_same<MatrixTypeNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_matrix)))
|
|
||||||
dst = m_matrix;
|
|
||||||
|
|
||||||
for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k<size ; Transposed?--k:++k)
|
|
||||||
if(Index(j=m_transpositions.coeff(k))!=k)
|
|
||||||
{
|
|
||||||
if(Side==OnTheLeft)
|
|
||||||
dst.row(k).swap(dst.row(j));
|
|
||||||
else if(Side==OnTheRight)
|
|
||||||
dst.col(k).swap(dst.col(j));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
const TranspositionType& m_transpositions;
|
|
||||||
typename MatrixType::Nested m_matrix;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // end namespace internal
|
|
||||||
|
|
||||||
/* Template partial specialization for transposed/inverse transpositions */
|
/* Template partial specialization for transposed/inverse transpositions */
|
||||||
|
|
||||||
@ -412,26 +363,56 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
|
|||||||
|
|
||||||
/** \returns the \a matrix with the inverse transpositions applied to the columns.
|
/** \returns the \a matrix with the inverse transpositions applied to the columns.
|
||||||
*/
|
*/
|
||||||
template<typename Derived> friend
|
template<typename OtherDerived> friend
|
||||||
inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>
|
const Product<OtherDerived, Transpose, DefaultProduct>
|
||||||
operator*(const MatrixBase<Derived>& matrix, const Transpose& trt)
|
operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt)
|
||||||
{
|
{
|
||||||
return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheRight, true>(trt.m_transpositions, matrix.derived());
|
return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trt.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \returns the \a matrix with the inverse transpositions applied to the rows.
|
/** \returns the \a matrix with the inverse transpositions applied to the rows.
|
||||||
*/
|
*/
|
||||||
template<typename Derived>
|
template<typename OtherDerived>
|
||||||
inline const internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>
|
const Product<Transpose, OtherDerived, DefaultProduct>
|
||||||
operator*(const MatrixBase<Derived>& matrix) const
|
operator*(const MatrixBase<OtherDerived>& matrix) const
|
||||||
{
|
{
|
||||||
return internal::transposition_matrix_product_retval<TranspositionType, Derived, OnTheLeft, true>(m_transpositions, matrix.derived());
|
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const TranspositionType& m_transpositions;
|
const TranspositionType& m_transpositions;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// TODO currently a Transpositions expression has the form Transpositions or TranspositionsWrapper
|
||||||
|
// or their transpose; in the future shape should be defined by the expression traits
|
||||||
|
template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename IndexType>
|
||||||
|
struct evaluator_traits<Transpositions<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType> >
|
||||||
|
{
|
||||||
|
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||||
|
typedef TranspositionsShape Shape;
|
||||||
|
static const int AssumeAliasing = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename IndicesType>
|
||||||
|
struct evaluator_traits<TranspositionsWrapper<IndicesType> >
|
||||||
|
{
|
||||||
|
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||||
|
typedef TranspositionsShape Shape;
|
||||||
|
static const int AssumeAliasing = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Derived>
|
||||||
|
struct evaluator_traits<Transpose<TranspositionsBase<Derived> > >
|
||||||
|
{
|
||||||
|
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||||
|
typedef TranspositionsShape Shape;
|
||||||
|
static const int AssumeAliasing = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_TRANSPOSITIONS_H
|
#endif // EIGEN_TRANSPOSITIONS_H
|
||||||
|
@ -482,6 +482,7 @@ struct BandShape { static std::string debugName() { return "BandSha
|
|||||||
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
|
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
|
||||||
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
|
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
|
||||||
struct PermutationShape { static std::string debugName() { return "PermutationShape"; } };
|
struct PermutationShape { static std::string debugName() { return "PermutationShape"; } };
|
||||||
|
struct TranspositionsShape { static std::string debugName() { return "TranspositionsShape"; } };
|
||||||
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
|
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user