mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-04 09:44:06 +08:00
Port products with permutation matrices to evaluators.
This commit is contained in:
parent
aceae8314b
commit
59f5f155c2
@ -288,6 +288,10 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
|
|||||||
typedef internal::traits<PermutationMatrix> Traits;
|
typedef internal::traits<PermutationMatrix> Traits;
|
||||||
public:
|
public:
|
||||||
|
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
typedef const PermutationMatrix& Nested;
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
#ifndef EIGEN_PARSED_BY_DOXYGEN
|
||||||
typedef typename Traits::IndicesType IndicesType;
|
typedef typename Traits::IndicesType IndicesType;
|
||||||
#endif
|
#endif
|
||||||
@ -461,6 +465,22 @@ class Map<PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType>,
|
|||||||
|
|
||||||
struct PermutationStorage {};
|
struct PermutationStorage {};
|
||||||
|
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
// storage type of product of permutation wrapper with dense
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<> struct promote_storage_type<Dense, PermutationStorage>
|
||||||
|
{ typedef Dense ret; };
|
||||||
|
|
||||||
|
template<> struct promote_storage_type<PermutationStorage, Dense>
|
||||||
|
{ typedef Dense ret; };
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
template<typename _IndicesType> class TranspositionsWrapper;
|
template<typename _IndicesType> class TranspositionsWrapper;
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template<typename _IndicesType>
|
template<typename _IndicesType>
|
||||||
@ -473,8 +493,13 @@ struct traits<PermutationWrapper<_IndicesType> >
|
|||||||
enum {
|
enum {
|
||||||
RowsAtCompileTime = _IndicesType::SizeAtCompileTime,
|
RowsAtCompileTime = _IndicesType::SizeAtCompileTime,
|
||||||
ColsAtCompileTime = _IndicesType::SizeAtCompileTime,
|
ColsAtCompileTime = _IndicesType::SizeAtCompileTime,
|
||||||
MaxRowsAtCompileTime = IndicesType::MaxRowsAtCompileTime,
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
MaxRowsAtCompileTime = IndicesType::MaxSizeAtCompileTime,
|
||||||
|
MaxColsAtCompileTime = IndicesType::MaxSizeAtCompileTime,
|
||||||
|
#else
|
||||||
|
MaxRowsAtCompileTime = IndicesType::MaxRowsAtCompileTime, // is this a bug in Eigen 2.2 ?
|
||||||
MaxColsAtCompileTime = IndicesType::MaxColsAtCompileTime,
|
MaxColsAtCompileTime = IndicesType::MaxColsAtCompileTime,
|
||||||
|
#endif
|
||||||
Flags = 0
|
Flags = 0
|
||||||
#ifndef EIGEN_TEST_EVALUATORS
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
,
|
,
|
||||||
@ -508,6 +533,37 @@ class PermutationWrapper : public PermutationBase<PermutationWrapper<_IndicesTyp
|
|||||||
typename IndicesType::Nested m_indices;
|
typename IndicesType::Nested m_indices;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
// TODO: Do we need to define these operator* functions? Would it be better to have them inherited
|
||||||
|
// from MatrixBase?
|
||||||
|
|
||||||
|
/** \returns the matrix with the permutation applied to the columns.
|
||||||
|
*/
|
||||||
|
template<typename MatrixDerived, typename PermutationDerived>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
const Product<MatrixDerived, PermutationDerived, DefaultProduct>
|
||||||
|
operator*(const MatrixBase<MatrixDerived> &matrix,
|
||||||
|
const PermutationBase<PermutationDerived>& permutation)
|
||||||
|
{
|
||||||
|
return Product<MatrixDerived, PermutationDerived, DefaultProduct>
|
||||||
|
(matrix.derived(), permutation.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \returns the matrix with the permutation applied to the rows.
|
||||||
|
*/
|
||||||
|
template<typename PermutationDerived, typename MatrixDerived>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
const Product<PermutationDerived, MatrixDerived, DefaultProduct>
|
||||||
|
operator*(const PermutationBase<PermutationDerived> &permutation,
|
||||||
|
const MatrixBase<MatrixDerived>& matrix)
|
||||||
|
{
|
||||||
|
return Product<PermutationDerived, MatrixDerived, DefaultProduct>
|
||||||
|
(permutation.derived(), matrix.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
/** \returns the matrix with the permutation applied to the columns.
|
/** \returns the matrix with the permutation applied to the columns.
|
||||||
*/
|
*/
|
||||||
template<typename Derived, typename PermutationDerived>
|
template<typename Derived, typename PermutationDerived>
|
||||||
@ -533,6 +589,8 @@ operator*(const PermutationBase<PermutationDerived> &permutation,
|
|||||||
(permutation.derived(), matrix.derived());
|
(permutation.derived(), matrix.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename PermutationType, typename MatrixType, int Side, bool Transposed>
|
template<typename PermutationType, typename MatrixType, int Side, bool Transposed>
|
||||||
@ -662,6 +720,28 @@ class Transpose<PermutationBase<Derived> >
|
|||||||
|
|
||||||
DenseMatrixType toDenseMatrix() const { return *this; }
|
DenseMatrixType toDenseMatrix() const { return *this; }
|
||||||
|
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
/** \returns the matrix with the inverse permutation applied to the columns.
|
||||||
|
*/
|
||||||
|
template<typename OtherDerived> friend
|
||||||
|
const Product<OtherDerived, Transpose, DefaultProduct>
|
||||||
|
operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trPerm)
|
||||||
|
{
|
||||||
|
return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trPerm.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \returns the matrix with the inverse permutation applied to the rows.
|
||||||
|
*/
|
||||||
|
template<typename OtherDerived>
|
||||||
|
const Product<Transpose, OtherDerived, DefaultProduct>
|
||||||
|
operator*(const MatrixBase<OtherDerived>& matrix) const
|
||||||
|
{
|
||||||
|
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
/** \returns the matrix with the inverse permutation applied to the columns.
|
/** \returns the matrix with the inverse permutation applied to the columns.
|
||||||
*/
|
*/
|
||||||
template<typename OtherDerived> friend
|
template<typename OtherDerived> friend
|
||||||
@ -680,6 +760,8 @@ class Transpose<PermutationBase<Derived> >
|
|||||||
return internal::permut_matrix_product_retval<PermutationType, OtherDerived, OnTheLeft, true>(m_permutation, matrix.derived());
|
return internal::permut_matrix_product_retval<PermutationType, OtherDerived, OnTheLeft, true>(m_permutation, matrix.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
const PermutationType& nestedPermutation() const { return m_permutation; }
|
const PermutationType& nestedPermutation() const { return m_permutation; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -692,6 +774,38 @@ const PermutationWrapper<const Derived> MatrixBase<Derived>::asPermutation() con
|
|||||||
return derived();
|
return derived();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// TODO currently a permutation matrix expression has the form PermutationMatrix or PermutationWrapper
|
||||||
|
// or their transpose; in the future shape should be defined by the expression traits
|
||||||
|
template<int SizeAtCompileTime, int MaxSizeAtCompileTime, typename IndexType>
|
||||||
|
struct evaluator_traits<PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime, IndexType> >
|
||||||
|
{
|
||||||
|
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||||
|
typedef PermutationShape Shape;
|
||||||
|
static const int AssumeAliasing = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename IndicesType>
|
||||||
|
struct evaluator_traits<PermutationWrapper<IndicesType> >
|
||||||
|
{
|
||||||
|
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||||
|
typedef PermutationShape Shape;
|
||||||
|
static const int AssumeAliasing = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Derived>
|
||||||
|
struct evaluator_traits<Transpose<PermutationBase<Derived> > >
|
||||||
|
{
|
||||||
|
typedef typename storage_kind_to_evaluator_kind<Dense>::Kind Kind;
|
||||||
|
typedef PermutationShape Shape;
|
||||||
|
static const int AssumeAliasing = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_PERMUTATIONMATRIX_H
|
#endif // EIGEN_PERMUTATIONMATRIX_H
|
||||||
|
@ -29,8 +29,30 @@ template<typename Lhs, typename Rhs, int Option, typename StorageKind> class Pro
|
|||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Use ProductReturnType to get correct traits, in particular vectorization flags
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
|
// Determine the scalar of Product<Lhs, Rhs>. This is normally the same as Lhs::Scalar times
|
||||||
|
// Rhs::Scalar, but product with permutation matrices inherit the scalar of the other factor.
|
||||||
|
template<typename Lhs, typename Rhs, typename LhsShape = typename evaluator_traits<Lhs>::Shape,
|
||||||
|
typename RhsShape = typename evaluator_traits<Rhs>::Shape >
|
||||||
|
struct product_result_scalar
|
||||||
|
{
|
||||||
|
typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename RhsShape>
|
||||||
|
struct product_result_scalar<Lhs, Rhs, PermutationShape, RhsShape>
|
||||||
|
{
|
||||||
|
typedef typename Rhs::Scalar Scalar;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename LhsShape>
|
||||||
|
struct product_result_scalar<Lhs, Rhs, LhsShape, PermutationShape>
|
||||||
|
{
|
||||||
|
typedef typename Lhs::Scalar Scalar;
|
||||||
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, int Option>
|
template<typename Lhs, typename Rhs, int Option>
|
||||||
struct traits<Product<Lhs, Rhs, Option> >
|
struct traits<Product<Lhs, Rhs, Option> >
|
||||||
{
|
{
|
||||||
@ -39,7 +61,7 @@ struct traits<Product<Lhs, Rhs, Option> >
|
|||||||
|
|
||||||
typedef MatrixXpr XprKind;
|
typedef MatrixXpr XprKind;
|
||||||
|
|
||||||
typedef typename scalar_product_traits<typename LhsCleaned::Scalar, typename RhsCleaned::Scalar>::ReturnType Scalar;
|
typedef typename product_result_scalar<LhsCleaned,RhsCleaned>::Scalar Scalar;
|
||||||
typedef typename promote_storage_type<typename traits<LhsCleaned>::StorageKind,
|
typedef typename promote_storage_type<typename traits<LhsCleaned>::StorageKind,
|
||||||
typename traits<RhsCleaned>::StorageKind>::ret StorageKind;
|
typename traits<RhsCleaned>::StorageKind>::ret StorageKind;
|
||||||
typedef typename promote_index_type<typename traits<LhsCleaned>::Index,
|
typedef typename promote_index_type<typename traits<LhsCleaned>::Index,
|
||||||
|
@ -885,6 +885,93 @@ struct product_evaluator<Product<Lhs, Rhs, ProductKind>, ProductTag, DenseShape,
|
|||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/***************************************************************************
|
||||||
|
* Products with permutation matrices
|
||||||
|
***************************************************************************/
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductType>
|
||||||
|
struct generic_product_impl<Lhs, Rhs, PermutationShape, DenseShape, ProductType>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
permut_matrix_product_retval<Lhs, Rhs, OnTheLeft, false> pmpr(lhs, rhs);
|
||||||
|
pmpr.evalTo(dst);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductType>
|
||||||
|
struct generic_product_impl<Lhs, Rhs, DenseShape, PermutationShape, ProductType>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
permut_matrix_product_retval<Rhs, Lhs, OnTheRight, false> pmpr(rhs, lhs);
|
||||||
|
pmpr.evalTo(dst);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductType>
|
||||||
|
struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, DenseShape, ProductType>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
permut_matrix_product_retval<Lhs, Rhs, OnTheLeft, true> pmpr(lhs.nestedPermutation(), rhs);
|
||||||
|
pmpr.evalTo(dst);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductType>
|
||||||
|
struct generic_product_impl<Lhs, Transpose<Rhs>, DenseShape, PermutationShape, ProductType>
|
||||||
|
{
|
||||||
|
template<typename Dest>
|
||||||
|
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
|
||||||
|
{
|
||||||
|
permut_matrix_product_retval<Rhs, Lhs, OnTheRight, true> pmpr(rhs.nestedPermutation(), lhs);
|
||||||
|
pmpr.evalTo(dst);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: left/right and self-adj/symmetric/permutation look the same ... Too much boilerplate?
|
||||||
|
template<typename Lhs, typename Rhs, int ProductTag>
|
||||||
|
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, PermutationShape, DenseShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||||
|
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
|
||||||
|
{
|
||||||
|
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
|
||||||
|
typedef typename XprType::PlainObject PlainObject;
|
||||||
|
typedef typename evaluator<PlainObject>::type Base;
|
||||||
|
|
||||||
|
product_evaluator(const XprType& xpr)
|
||||||
|
: m_result(xpr.rows(), xpr.cols())
|
||||||
|
{
|
||||||
|
::new (static_cast<Base*>(this)) Base(m_result);
|
||||||
|
generic_product_impl<Lhs, Rhs, PermutationShape, DenseShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
PlainObject m_result;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int ProductTag>
|
||||||
|
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, DenseShape, PermutationShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||||
|
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
|
||||||
|
{
|
||||||
|
typedef Product<Lhs, Rhs, DefaultProduct> XprType;
|
||||||
|
typedef typename XprType::PlainObject PlainObject;
|
||||||
|
typedef typename evaluator<PlainObject>::type Base;
|
||||||
|
|
||||||
|
product_evaluator(const XprType& xpr)
|
||||||
|
: m_result(xpr.rows(), xpr.cols())
|
||||||
|
{
|
||||||
|
::new (static_cast<Base*>(this)) Base(m_result);
|
||||||
|
generic_product_impl<Lhs, Rhs, DenseShape, PermutationShape, ProductTag>::evalTo(m_result, xpr.lhs(), xpr.rhs());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
PlainObject m_result;
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -447,6 +447,7 @@ struct DiagonalShape { static std::string debugName() { return "DiagonalShape
|
|||||||
struct BandShape { static std::string debugName() { return "BandShape"; } };
|
struct BandShape { static std::string debugName() { return "BandShape"; } };
|
||||||
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
|
struct TriangularShape { static std::string debugName() { return "TriangularShape"; } };
|
||||||
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
|
struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } };
|
||||||
|
struct PermutationShape { static std::string debugName() { return "PermutationShape"; } };
|
||||||
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
|
struct SparseShape { static std::string debugName() { return "SparseShape"; } };
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user