Introduce a AliasFreeProduct option for Permutations and Transpositions

This commit is contained in:
Gael Guennebaud 2015-06-19 15:38:19 +02:00
parent 3f6aa4cd5d
commit 0c8b0e007b
6 changed files with 46 additions and 35 deletions

View File

@ -541,11 +541,11 @@ class PermutationWrapper : public PermutationBase<PermutationWrapper<_IndicesTyp
*/ */
template<typename MatrixDerived, typename PermutationDerived> template<typename MatrixDerived, typename PermutationDerived>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const Product<MatrixDerived, PermutationDerived, DefaultProduct> const Product<MatrixDerived, PermutationDerived, AliasFreeProduct>
operator*(const MatrixBase<MatrixDerived> &matrix, operator*(const MatrixBase<MatrixDerived> &matrix,
const PermutationBase<PermutationDerived>& permutation) const PermutationBase<PermutationDerived>& permutation)
{ {
return Product<MatrixDerived, PermutationDerived, DefaultProduct> return Product<MatrixDerived, PermutationDerived, AliasFreeProduct>
(matrix.derived(), permutation.derived()); (matrix.derived(), permutation.derived());
} }
@ -553,11 +553,11 @@ operator*(const MatrixBase<MatrixDerived> &matrix,
*/ */
template<typename PermutationDerived, typename MatrixDerived> template<typename PermutationDerived, typename MatrixDerived>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const Product<PermutationDerived, MatrixDerived, DefaultProduct> const Product<PermutationDerived, MatrixDerived, AliasFreeProduct>
operator*(const PermutationBase<PermutationDerived> &permutation, operator*(const PermutationBase<PermutationDerived> &permutation,
const MatrixBase<MatrixDerived>& matrix) const MatrixBase<MatrixDerived>& matrix)
{ {
return Product<PermutationDerived, MatrixDerived, DefaultProduct> return Product<PermutationDerived, MatrixDerived, AliasFreeProduct>
(permutation.derived(), matrix.derived()); (permutation.derived(), matrix.derived());
} }
@ -620,19 +620,19 @@ class Transpose<PermutationBase<Derived> >
/** \returns the matrix with the inverse permutation applied to the columns. /** \returns the matrix with the inverse permutation applied to the columns.
*/ */
template<typename OtherDerived> friend template<typename OtherDerived> friend
const Product<OtherDerived, Transpose, DefaultProduct> const Product<OtherDerived, Transpose, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trPerm) operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trPerm)
{ {
return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trPerm.derived()); return Product<OtherDerived, Transpose, AliasFreeProduct>(matrix.derived(), trPerm.derived());
} }
/** \returns the matrix with the inverse permutation applied to the rows. /** \returns the matrix with the inverse permutation applied to the rows.
*/ */
template<typename OtherDerived> template<typename OtherDerived>
const Product<Transpose, OtherDerived, DefaultProduct> const Product<Transpose, OtherDerived, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix) const operator*(const MatrixBase<OtherDerived>& matrix) const
{ {
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived()); return Product<Transpose, OtherDerived, AliasFreeProduct>(*this, matrix.derived());
} }
const PermutationType& nestedExpression() const { return m_permutation; } const PermutationType& nestedExpression() const { return m_permutation; }

View File

@ -25,7 +25,7 @@ template<typename Lhs, typename Rhs, int Option, typename StorageKind> class Pro
* This class represents an expression of the product of two arbitrary matrices. * This class represents an expression of the product of two arbitrary matrices.
* *
* The other template parameters are: * The other template parameters are:
* \tparam Option can be DefaultProduct or LazyProduct * \tparam Option can be DefaultProduct, AliasFreeProduct, or LazyProduct
* *
*/ */

View File

@ -90,13 +90,21 @@ struct evaluator_traits<Product<Lhs, Rhs, DefaultProduct> >
enum { AssumeAliasing = 1 }; enum { AssumeAliasing = 1 };
}; };
template<typename Lhs, typename Rhs>
struct evaluator_traits<Product<Lhs, Rhs, AliasFreeProduct> >
: evaluator_traits_base<Product<Lhs, Rhs, AliasFreeProduct> >
{
enum { AssumeAliasing = 0 };
};
// This is the default evaluator implementation for products: // This is the default evaluator implementation for products:
// It creates a temporary and call generic_product_impl // It creates a temporary and call generic_product_impl
template<typename Lhs, typename Rhs, int ProductTag, typename LhsShape, typename RhsShape> template<typename Lhs, typename Rhs, int Options, int ProductTag, typename LhsShape, typename RhsShape>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, LhsShape, RhsShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, LhsShape, RhsShape, typename traits<Lhs>::Scalar,
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),typename traits<Rhs>::Scalar>::type>
: public evaluator<typename Product<Lhs, Rhs, Options>::PlainObject>::type
{ {
typedef Product<Lhs, Rhs, DefaultProduct> XprType; typedef Product<Lhs, Rhs, Options> XprType;
typedef typename XprType::PlainObject PlainObject; typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type Base; typedef typename evaluator<PlainObject>::type Base;
enum { enum {
@ -128,10 +136,11 @@ protected:
}; };
// Dense = Product // Dense = Product
template< typename DstXprType, typename Lhs, typename Rhs, typename Scalar> template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::assign_op<Scalar>, Dense2Dense, Scalar> struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::assign_op<Scalar>, Dense2Dense,
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type>
{ {
typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType; typedef Product<Lhs,Rhs,Options> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &) static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &)
{ {
// FIXME shall we handle nested_eval here? // FIXME shall we handle nested_eval here?
@ -140,10 +149,11 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::assign_
}; };
// Dense += Product // Dense += Product
template< typename DstXprType, typename Lhs, typename Rhs, typename Scalar> template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::add_assign_op<Scalar>, Dense2Dense, Scalar> struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::add_assign_op<Scalar>, Dense2Dense,
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type>
{ {
typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType; typedef Product<Lhs,Rhs,Options> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<Scalar> &) static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<Scalar> &)
{ {
// FIXME shall we handle nested_eval here? // FIXME shall we handle nested_eval here?
@ -152,10 +162,11 @@ struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::add_ass
}; };
// Dense -= Product // Dense -= Product
template< typename DstXprType, typename Lhs, typename Rhs, typename Scalar> template< typename DstXprType, typename Lhs, typename Rhs, int Options, typename Scalar>
struct Assignment<DstXprType, Product<Lhs,Rhs,DefaultProduct>, internal::sub_assign_op<Scalar>, Dense2Dense, Scalar> struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::sub_assign_op<Scalar>, Dense2Dense,
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type>
{ {
typedef Product<Lhs,Rhs,DefaultProduct> SrcXprType; typedef Product<Lhs,Rhs,Options> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<Scalar> &) static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<Scalar> &)
{ {
// FIXME shall we handle nested_eval here? // FIXME shall we handle nested_eval here?

View File

@ -334,11 +334,11 @@ class TranspositionsWrapper
*/ */
template<typename MatrixDerived, typename TranspositionsDerived> template<typename MatrixDerived, typename TranspositionsDerived>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const Product<MatrixDerived, TranspositionsDerived, DefaultProduct> const Product<MatrixDerived, TranspositionsDerived, AliasFreeProduct>
operator*(const MatrixBase<MatrixDerived> &matrix, operator*(const MatrixBase<MatrixDerived> &matrix,
const TranspositionsBase<TranspositionsDerived>& transpositions) const TranspositionsBase<TranspositionsDerived>& transpositions)
{ {
return Product<MatrixDerived, TranspositionsDerived, DefaultProduct> return Product<MatrixDerived, TranspositionsDerived, AliasFreeProduct>
(matrix.derived(), transpositions.derived()); (matrix.derived(), transpositions.derived());
} }
@ -346,11 +346,11 @@ operator*(const MatrixBase<MatrixDerived> &matrix,
*/ */
template<typename TranspositionsDerived, typename MatrixDerived> template<typename TranspositionsDerived, typename MatrixDerived>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const Product<TranspositionsDerived, MatrixDerived, DefaultProduct> const Product<TranspositionsDerived, MatrixDerived, AliasFreeProduct>
operator*(const TranspositionsBase<TranspositionsDerived> &transpositions, operator*(const TranspositionsBase<TranspositionsDerived> &transpositions,
const MatrixBase<MatrixDerived>& matrix) const MatrixBase<MatrixDerived>& matrix)
{ {
return Product<TranspositionsDerived, MatrixDerived, DefaultProduct> return Product<TranspositionsDerived, MatrixDerived, AliasFreeProduct>
(transpositions.derived(), matrix.derived()); (transpositions.derived(), matrix.derived());
} }
@ -381,19 +381,19 @@ class Transpose<TranspositionsBase<TranspositionsDerived> >
/** \returns the \a matrix with the inverse transpositions applied to the columns. /** \returns the \a matrix with the inverse transpositions applied to the columns.
*/ */
template<typename OtherDerived> friend template<typename OtherDerived> friend
const Product<OtherDerived, Transpose, DefaultProduct> const Product<OtherDerived, Transpose, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt) operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trt)
{ {
return Product<OtherDerived, Transpose, DefaultProduct>(matrix.derived(), trt.derived()); return Product<OtherDerived, Transpose, AliasFreeProduct>(matrix.derived(), trt.derived());
} }
/** \returns the \a matrix with the inverse transpositions applied to the rows. /** \returns the \a matrix with the inverse transpositions applied to the rows.
*/ */
template<typename OtherDerived> template<typename OtherDerived>
const Product<Transpose, OtherDerived, DefaultProduct> const Product<Transpose, OtherDerived, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix) const operator*(const MatrixBase<OtherDerived>& matrix) const
{ {
return Product<Transpose, OtherDerived, DefaultProduct>(*this, matrix.derived()); return Product<Transpose, OtherDerived, AliasFreeProduct>(*this, matrix.derived());
} }
const TranspositionType& nestedExpression() const { return m_transpositions; } const TranspositionType& nestedExpression() const { return m_transpositions; }

View File

@ -453,7 +453,7 @@ namespace Architecture
/** \internal \ingroup enums /** \internal \ingroup enums
* Enum used as template parameter in GeneralProduct. */ * Enum used as template parameter in GeneralProduct. */
enum { DefaultProduct=0, CoeffBasedProductMode, LazyCoeffBasedProductMode, LazyProduct, OuterProduct, InnerProduct, GemvProduct, GemmProduct }; enum { DefaultProduct=0, LazyProduct, AliasFreeProduct, CoeffBasedProductMode, LazyCoeffBasedProductMode, OuterProduct, InnerProduct, GemvProduct, GemmProduct };
/** \internal \ingroup enums /** \internal \ingroup enums
* Enum used in experimental parallel implementation. */ * Enum used in experimental parallel implementation. */

View File

@ -90,10 +90,10 @@ template <int ProductTag> struct product_promote_storage_type<PermutationStorage
// whereas it should be correctly handled by traits<Product<> >::PlainObject // whereas it should be correctly handled by traits<Product<> >::PlainObject
template<typename Lhs, typename Rhs, int ProductTag> template<typename Lhs, typename Rhs, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, PermutationShape, SparseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, PermutationShape, SparseShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar>
: public evaluator<typename permutation_matrix_product<Rhs,OnTheRight,false,SparseShape>::ReturnType>::type : public evaluator<typename permutation_matrix_product<Rhs,OnTheRight,false,SparseShape>::ReturnType>::type
{ {
typedef Product<Lhs, Rhs, DefaultProduct> XprType; typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
typedef typename permutation_matrix_product<Rhs,OnTheRight,false,SparseShape>::ReturnType PlainObject; typedef typename permutation_matrix_product<Rhs,OnTheRight,false,SparseShape>::ReturnType PlainObject;
typedef typename evaluator<PlainObject>::type Base; typedef typename evaluator<PlainObject>::type Base;
@ -109,10 +109,10 @@ protected:
}; };
template<typename Lhs, typename Rhs, int ProductTag> template<typename Lhs, typename Rhs, int ProductTag>
struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseShape, PermutationShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar> struct product_evaluator<Product<Lhs, Rhs, AliasFreeProduct>, ProductTag, SparseShape, PermutationShape, typename traits<Lhs>::Scalar, typename traits<Rhs>::Scalar>
: public evaluator<typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType>::type : public evaluator<typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType>::type
{ {
typedef Product<Lhs, Rhs, DefaultProduct> XprType; typedef Product<Lhs, Rhs, AliasFreeProduct> XprType;
typedef typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType PlainObject; typedef typename permutation_matrix_product<Lhs,OnTheRight,false,SparseShape>::ReturnType PlainObject;
typedef typename evaluator<PlainObject>::type Base; typedef typename evaluator<PlainObject>::type Base;