mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-01 00:04:14 +08:00
Refactor dense product evaluators
This commit is contained in:
parent
fc6ecebc69
commit
cc6dd878ee
@ -316,7 +316,6 @@ using std::ptrdiff_t;
|
|||||||
#include "src/Core/Product.h"
|
#include "src/Core/Product.h"
|
||||||
#include "src/Core/CoreEvaluators.h"
|
#include "src/Core/CoreEvaluators.h"
|
||||||
#include "src/Core/AssignEvaluator.h"
|
#include "src/Core/AssignEvaluator.h"
|
||||||
#include "src/Core/ProductEvaluators.h"
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifndef EIGEN_PARSED_BY_DOXYGEN // work around Doxygen bug triggered by Assign.h r814874
|
#ifndef EIGEN_PARSED_BY_DOXYGEN // work around Doxygen bug triggered by Assign.h r814874
|
||||||
@ -382,6 +381,10 @@ using std::ptrdiff_t;
|
|||||||
#include "src/Core/BandMatrix.h"
|
#include "src/Core/BandMatrix.h"
|
||||||
#include "src/Core/CoreIterators.h"
|
#include "src/Core/CoreIterators.h"
|
||||||
|
|
||||||
|
#ifdef EIGEN_ENABLE_EVALUATORS
|
||||||
|
#include "src/Core/ProductEvaluators.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "src/Core/BooleanRedux.h"
|
#include "src/Core/BooleanRedux.h"
|
||||||
#include "src/Core/Select.h"
|
#include "src/Core/Select.h"
|
||||||
#include "src/Core/VectorwiseOp.h"
|
#include "src/Core/VectorwiseOp.h"
|
||||||
|
@ -12,8 +12,7 @@
|
|||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs> class Product;
|
template<typename Lhs, typename Rhs, int Option, int ProductTag, typename StorageKind> class ProductImpl;
|
||||||
template<typename Lhs, typename Rhs, typename StorageKind> class ProductImpl;
|
|
||||||
|
|
||||||
/** \class Product
|
/** \class Product
|
||||||
* \ingroup Core_Module
|
* \ingroup Core_Module
|
||||||
@ -24,13 +23,17 @@ template<typename Lhs, typename Rhs, typename StorageKind> class ProductImpl;
|
|||||||
* \param Rhs the type of the right-hand side expression
|
* \param Rhs the type of the right-hand side expression
|
||||||
*
|
*
|
||||||
* This class represents an expression of the product of two arbitrary matrices.
|
* This class represents an expression of the product of two arbitrary matrices.
|
||||||
|
*
|
||||||
|
* The other template parameters are:
|
||||||
|
* \tparam Option can be DefaultProduct or LazyProduct
|
||||||
|
* \tparam ProductTag can be InnerProduct, OuterProduct, GemvProduct, GemmProduct. It is used to ease expression manipulations.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Use ProductReturnType to get correct traits, in particular vectorization flags
|
// Use ProductReturnType to get correct traits, in particular vectorization flags
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs, int Option, int ProductTag>
|
||||||
struct traits<Product<Lhs, Rhs> >
|
struct traits<Product<Lhs, Rhs, Option, ProductTag> >
|
||||||
: traits<typename ProductReturnType<Lhs, Rhs>::Type>
|
: traits<typename ProductReturnType<Lhs, Rhs>::Type>
|
||||||
{
|
{
|
||||||
// We want A+B*C to be of type Product<Matrix, Sum> and not Product<Matrix, Matrix>
|
// We want A+B*C to be of type Product<Matrix, Sum> and not Product<Matrix, Matrix>
|
||||||
@ -42,14 +45,15 @@ struct traits<Product<Lhs, Rhs> >
|
|||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs, int Option, int ProductTag>
|
||||||
class Product : public ProductImpl<Lhs,Rhs,typename internal::promote_storage_type<typename internal::traits<Lhs>::StorageKind,
|
class Product : public ProductImpl<Lhs,Rhs,Option,ProductTag,
|
||||||
typename internal::traits<Rhs>::StorageKind>::ret>
|
typename internal::promote_storage_type<typename internal::traits<Lhs>::StorageKind,
|
||||||
|
typename internal::traits<Rhs>::StorageKind>::ret>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef typename ProductImpl<
|
typedef typename ProductImpl<
|
||||||
Lhs, Rhs,
|
Lhs, Rhs, Option, ProductTag,
|
||||||
typename internal::promote_storage_type<typename Lhs::StorageKind,
|
typename internal::promote_storage_type<typename Lhs::StorageKind,
|
||||||
typename Rhs::StorageKind>::ret>::Base Base;
|
typename Rhs::StorageKind>::ret>::Base Base;
|
||||||
EIGEN_GENERIC_PUBLIC_INTERFACE(Product)
|
EIGEN_GENERIC_PUBLIC_INTERFACE(Product)
|
||||||
@ -78,13 +82,13 @@ class Product : public ProductImpl<Lhs,Rhs,typename internal::promote_storage_ty
|
|||||||
RhsNested m_rhs;
|
RhsNested m_rhs;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs, int Option, int ProductTag>
|
||||||
class ProductImpl<Lhs,Rhs,Dense> : public internal::dense_xpr_base<Product<Lhs,Rhs> >::type
|
class ProductImpl<Lhs,Rhs,Option,ProductTag,Dense> : public internal::dense_xpr_base<Product<Lhs,Rhs,Option,ProductTag> >::type
|
||||||
{
|
{
|
||||||
typedef Product<Lhs, Rhs> Derived;
|
typedef Product<Lhs, Rhs> Derived;
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef typename internal::dense_xpr_base<Product<Lhs, Rhs> >::type Base;
|
typedef typename internal::dense_xpr_base<Product<Lhs, Rhs, Option, ProductTag> >::type Base;
|
||||||
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
|
EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -102,6 +106,15 @@ prod(const Lhs& lhs, const Rhs& rhs)
|
|||||||
return Product<Lhs,Rhs>(lhs,rhs);
|
return Product<Lhs,Rhs>(lhs,rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \internal used to test the evaluator only
|
||||||
|
*/
|
||||||
|
template<typename Lhs,typename Rhs>
|
||||||
|
const Product<Lhs,Rhs,LazyProduct>
|
||||||
|
lazyprod(const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
return Product<Lhs,Rhs,LazyProduct>(lhs,rhs);
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
#endif // EIGEN_PRODUCT_H
|
#endif // EIGEN_PRODUCT_H
|
||||||
|
@ -17,94 +17,172 @@ namespace Eigen {
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
// We can evaluate the product either all at once, like GeneralProduct and its evalTo() function, or
|
|
||||||
// traverse the matrix coefficient by coefficient, like CoeffBasedProduct. Use the existing logic
|
// Helper class to perform a dense product with the destination at hand.
|
||||||
// in ProductReturnType to decide.
|
// Depending on the sizes of the factors, there are different evaluation strategies
|
||||||
|
// as controlled by internal::product_type.
|
||||||
|
template<typename Lhs, typename Rhs, int ProductType = internal::product_type<Lhs,Rhs>::value>
|
||||||
|
struct dense_product_impl;
|
||||||
|
|
||||||
template<typename XprType, typename ProductType>
|
|
||||||
struct product_evaluator_dispatcher;
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
// The evaluator for default dense products creates a temporary and calls dense_product_impl
|
||||||
struct evaluator_impl<Product<Lhs, Rhs> >
|
template<typename Lhs, typename Rhs, int ProductTag>
|
||||||
: product_evaluator_dispatcher<Product<Lhs, Rhs>, typename ProductReturnType<Lhs, Rhs>::Type>
|
struct evaluator_impl<Product<Lhs, Rhs, DefaultProduct, ProductTag> >
|
||||||
|
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct, ProductTag>::PlainObject>::type
|
||||||
{
|
{
|
||||||
typedef Product<Lhs, Rhs> XprType;
|
typedef Product<Lhs, Rhs, DefaultProduct, ProductTag> XprType;
|
||||||
typedef product_evaluator_dispatcher<XprType, typename ProductReturnType<Lhs, Rhs>::Type> Base;
|
|
||||||
|
|
||||||
evaluator_impl(const XprType& xpr) : Base(xpr)
|
|
||||||
{ }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename XprType, typename ProductType>
|
|
||||||
struct product_evaluator_traits_dispatcher;
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
|
||||||
struct evaluator_traits<Product<Lhs, Rhs> >
|
|
||||||
: product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, typename ProductReturnType<Lhs, Rhs>::Type>
|
|
||||||
{
|
|
||||||
static const int AssumeAliasing = 1;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Case 1: Evaluate all at once
|
|
||||||
//
|
|
||||||
// We can view the GeneralProduct class as a part of the product evaluator.
|
|
||||||
// Four sub-cases: InnerProduct, OuterProduct, GemmProduct and GemvProduct.
|
|
||||||
// InnerProduct is special because GeneralProduct does not have an evalTo() method in this case.
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
|
||||||
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> >
|
|
||||||
{
|
|
||||||
static const int HasEvalTo = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
|
||||||
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> >
|
|
||||||
: public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type
|
|
||||||
{
|
|
||||||
typedef Product<Lhs, Rhs> XprType;
|
|
||||||
typedef typename XprType::PlainObject PlainObject;
|
typedef typename XprType::PlainObject PlainObject;
|
||||||
typedef typename evaluator<PlainObject>::type evaluator_base;
|
typedef typename evaluator<PlainObject>::type Base;
|
||||||
|
|
||||||
// TODO: Computation is too early (?)
|
evaluator_impl(const XprType& xpr)
|
||||||
product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result)
|
: m_result(xpr.rows(), xpr.cols())
|
||||||
{
|
{
|
||||||
m_result.coeffRef(0,0) = (xpr.lhs().transpose().cwiseProduct(xpr.rhs())).sum();
|
::new (static_cast<Base*>(this)) Base(m_result);
|
||||||
|
dense_product_impl<Lhs, Rhs>::evalTo(m_result, xpr.lhs(), xpr.rhs());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
PlainObject m_result;
|
PlainObject m_result;
|
||||||
};
|
};
|
||||||
|
|
||||||
// For the other three subcases, simply call the evalTo() method of GeneralProduct
|
|
||||||
// TODO: GeneralProduct should take evaluators, not expression objects.
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, int ProductType>
|
template<typename Lhs, typename Rhs>
|
||||||
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> >
|
struct dense_product_impl<Lhs,Rhs,InnerProduct>
|
||||||
{
|
{
|
||||||
static const int HasEvalTo = 1;
|
template<typename Dst>
|
||||||
};
|
static inline void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, int ProductType>
|
|
||||||
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> >
|
|
||||||
{
|
|
||||||
typedef Product<Lhs, Rhs> XprType;
|
|
||||||
typedef typename XprType::PlainObject PlainObject;
|
|
||||||
typedef typename evaluator<PlainObject>::type evaluator_base;
|
|
||||||
|
|
||||||
product_evaluator_dispatcher(const XprType& xpr) : m_xpr(xpr)
|
|
||||||
{ }
|
|
||||||
|
|
||||||
template<typename DstEvaluatorType, typename DstXprType>
|
|
||||||
void evalTo(DstEvaluatorType /* not used */, DstXprType& dst) const
|
|
||||||
{
|
{
|
||||||
dst.resize(m_xpr.rows(), m_xpr.cols());
|
dst.coeffRef(0,0) = (lhs.transpose().cwiseProduct(rhs)).sum();
|
||||||
GeneralProduct<Lhs, Rhs, ProductType>(m_xpr.lhs(), m_xpr.rhs()).evalTo(dst);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
template<typename Dst>
|
||||||
const XprType& m_xpr;
|
static inline void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
dst.coeffRef(0,0) += (lhs.transpose().cwiseProduct(rhs)).sum();
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{ dst.coeffRef(0,0) -= (lhs.transpose().cwiseProduct(rhs)).sum(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs>
|
||||||
|
struct dense_product_impl<Lhs,Rhs,OuterProduct>
|
||||||
|
{
|
||||||
|
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static inline void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
// TODO bypass GeneralProduct class
|
||||||
|
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).evalTo(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static inline void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
// TODO bypass GeneralProduct class
|
||||||
|
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).addTo(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static inline void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{
|
||||||
|
// TODO bypass GeneralProduct class
|
||||||
|
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).subTo(dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||||
|
{
|
||||||
|
// TODO bypass GeneralProduct class
|
||||||
|
GeneralProduct<Lhs, Rhs, OuterProduct>(lhs,rhs).scaleAndAddTo(dst, alpha);
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// This base class provides default implementations for evalTo, addTo, subTo, in terms of scaleAndAddTo
|
||||||
|
template<typename Lhs, typename Rhs, typename Derived>
|
||||||
|
struct dense_product_impl_base
|
||||||
|
{
|
||||||
|
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{ dst.setZero(); scaleAndAddTo(dst, lhs, rhs, Scalar(1)); }
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{ scaleAndAddTo(dst,lhs, rhs, Scalar(1)); }
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{ scaleAndAddTo(dst, lhs, rhs, Scalar(-1)); }
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||||
|
{ Derived::scaleAndAddTo(dst,lhs,rhs,alpha); }
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs>
|
||||||
|
struct dense_product_impl<Lhs,Rhs,GemvProduct> : dense_product_impl_base<Lhs,Rhs,dense_product_impl<Lhs,Rhs,GemvProduct> >
|
||||||
|
{
|
||||||
|
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||||
|
enum { Side = Lhs::IsVectorAtCompileTime ? OnTheLeft : OnTheRight };
|
||||||
|
typedef typename internal::conditional<int(Side)==OnTheRight,Lhs,Rhs>::type MatrixType;
|
||||||
|
|
||||||
|
template<typename Dest>
|
||||||
|
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||||
|
{
|
||||||
|
internal::gemv_selector<Side,
|
||||||
|
(int(MatrixType::Flags)&RowMajorBit) ? RowMajor : ColMajor,
|
||||||
|
bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)
|
||||||
|
>::run(GeneralProduct<Lhs,Rhs,GemvProduct>(lhs,rhs), dst, alpha);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs>
|
||||||
|
struct dense_product_impl<Lhs,Rhs,GemmProduct> : dense_product_impl_base<Lhs,Rhs,dense_product_impl<Lhs,Rhs,GemmProduct> >
|
||||||
|
{
|
||||||
|
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||||
|
|
||||||
|
template<typename Dest>
|
||||||
|
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||||
|
{
|
||||||
|
// TODO bypass GeneralProduct class
|
||||||
|
GeneralProduct<Lhs, Rhs, GemmProduct>(lhs,rhs).scaleAndAddTo(dst, alpha);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs>
|
||||||
|
struct dense_product_impl<Lhs,Rhs,CoeffBasedProductMode>
|
||||||
|
{
|
||||||
|
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static inline void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{ dst = lazyprod(lhs,rhs); }
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static inline void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{ dst += lazyprod(lhs,rhs); }
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static inline void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
|
||||||
|
{ dst -= lazyprod(lhs,rhs); }
|
||||||
|
|
||||||
|
template<typename Dst>
|
||||||
|
static inline void scaleAndAddTo(Dst& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
||||||
|
{ dst += alpha * lazyprod(lhs,rhs); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs>
|
||||||
|
struct dense_product_impl<Lhs,Rhs,LazyCoeffBasedProductMode> : dense_product_impl<Lhs,Rhs,CoeffBasedProductMode> {};
|
||||||
|
|
||||||
// Case 2: Evaluate coeff by coeff
|
// Case 2: Evaluate coeff by coeff
|
||||||
//
|
//
|
||||||
// This is mostly taken from CoeffBasedProduct.h
|
// This is mostly taken from CoeffBasedProduct.h
|
||||||
@ -117,20 +195,14 @@ struct etor_product_coeff_impl;
|
|||||||
template<int StorageOrder, int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
|
template<int StorageOrder, int UnrollingIndex, typename Lhs, typename Rhs, typename Packet, int LoadMode>
|
||||||
struct etor_product_packet_impl;
|
struct etor_product_packet_impl;
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags>
|
template<typename Lhs, typename Rhs, int ProductTag>
|
||||||
struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> >
|
struct evaluator_impl<Product<Lhs, Rhs, LazyProduct, ProductTag> >
|
||||||
|
: evaluator_impl_base<Product<Lhs, Rhs, LazyProduct, ProductTag> >
|
||||||
{
|
{
|
||||||
static const int HasEvalTo = 0;
|
typedef Product<Lhs, Rhs, LazyProduct, ProductTag> XprType;
|
||||||
};
|
typedef CoeffBasedProduct<Lhs, Rhs, 0> CoeffBasedProductType;
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags>
|
evaluator_impl(const XprType& xpr)
|
||||||
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> >
|
|
||||||
: evaluator_impl_base<Product<Lhs, Rhs> >
|
|
||||||
{
|
|
||||||
typedef Product<Lhs, Rhs> XprType;
|
|
||||||
typedef CoeffBasedProduct<LhsNested, RhsNested, Flags> CoeffBasedProductType;
|
|
||||||
|
|
||||||
product_evaluator_dispatcher(const XprType& xpr)
|
|
||||||
: m_lhsImpl(xpr.lhs()),
|
: m_lhsImpl(xpr.lhs()),
|
||||||
m_rhsImpl(xpr.rhs()),
|
m_rhsImpl(xpr.rhs()),
|
||||||
m_innerDim(xpr.lhs().cols())
|
m_innerDim(xpr.lhs().cols())
|
||||||
@ -150,11 +222,13 @@ struct product_evaluator_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNest
|
|||||||
InnerSize = traits<CoeffBasedProductType>::InnerSize,
|
InnerSize = traits<CoeffBasedProductType>::InnerSize,
|
||||||
CoeffReadCost = traits<CoeffBasedProductType>::CoeffReadCost,
|
CoeffReadCost = traits<CoeffBasedProductType>::CoeffReadCost,
|
||||||
Unroll = CoeffReadCost != Dynamic && CoeffReadCost <= EIGEN_UNROLLING_LIMIT,
|
Unroll = CoeffReadCost != Dynamic && CoeffReadCost <= EIGEN_UNROLLING_LIMIT,
|
||||||
CanVectorizeInner = traits<CoeffBasedProductType>::CanVectorizeInner
|
CanVectorizeInner = traits<CoeffBasedProductType>::CanVectorizeInner,
|
||||||
|
Flags = CoeffBasedProductType::Flags
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef typename evaluator<Lhs>::type LhsEtorType;
|
typedef typename evaluator<Lhs>::type LhsEtorType;
|
||||||
typedef typename evaluator<Rhs>::type RhsEtorType;
|
typedef typename evaluator<Rhs>::type RhsEtorType;
|
||||||
|
|
||||||
typedef etor_product_coeff_impl<CanVectorizeInner ? InnerVectorizedTraversal : DefaultTraversal,
|
typedef etor_product_coeff_impl<CanVectorizeInner ? InnerVectorizedTraversal : DefaultTraversal,
|
||||||
Unroll ? InnerSize-1 : Dynamic,
|
Unroll ? InnerSize-1 : Dynamic,
|
||||||
LhsEtorType, RhsEtorType, Scalar> CoeffImpl;
|
LhsEtorType, RhsEtorType, Scalar> CoeffImpl;
|
||||||
@ -183,8 +257,8 @@ struct product_evaluator_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNest
|
|||||||
{
|
{
|
||||||
PacketScalar res;
|
PacketScalar res;
|
||||||
typedef etor_product_packet_impl<Flags&RowMajorBit ? RowMajor : ColMajor,
|
typedef etor_product_packet_impl<Flags&RowMajorBit ? RowMajor : ColMajor,
|
||||||
Unroll ? InnerSize-1 : Dynamic,
|
Unroll ? InnerSize-1 : Dynamic,
|
||||||
LhsEtorType, RhsEtorType, PacketScalar, LoadMode> PacketImpl;
|
LhsEtorType, RhsEtorType, PacketScalar, LoadMode> PacketImpl;
|
||||||
PacketImpl::run(row, col, m_lhsImpl, m_rhsImpl, m_innerDim, res);
|
PacketImpl::run(row, col, m_lhsImpl, m_rhsImpl, m_innerDim, res);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -197,6 +271,7 @@ protected:
|
|||||||
Index m_innerDim;
|
Index m_innerDim;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/***************************************************************************
|
/***************************************************************************
|
||||||
* Normal product .coeff() implementation (with meta-unrolling)
|
* Normal product .coeff() implementation (with meta-unrolling)
|
||||||
***************************************************************************/
|
***************************************************************************/
|
||||||
@ -275,7 +350,6 @@ struct etor_product_coeff_impl<InnerVectorizedTraversal, UnrollingIndex, Lhs, Rh
|
|||||||
{
|
{
|
||||||
Packet pres;
|
Packet pres;
|
||||||
etor_product_coeff_vectorized_unroller<UnrollingIndex+1-PacketSize, Lhs, Rhs, Packet>::run(row, col, lhs, rhs, innerDim, pres);
|
etor_product_coeff_vectorized_unroller<UnrollingIndex+1-PacketSize, Lhs, Rhs, Packet>::run(row, col, lhs, rhs, innerDim, pres);
|
||||||
etor_product_coeff_impl<DefaultTraversal,UnrollingIndex,Lhs,Rhs,RetScalar>::run(row, col, lhs, rhs, innerDim, res);
|
|
||||||
res = predux(pres);
|
res = predux(pres);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -425,7 +425,7 @@ namespace Architecture
|
|||||||
|
|
||||||
/** \internal \ingroup enums
|
/** \internal \ingroup enums
|
||||||
* Enum used as template parameter in GeneralProduct. */
|
* Enum used as template parameter in GeneralProduct. */
|
||||||
enum { CoeffBasedProductMode, LazyCoeffBasedProductMode, OuterProduct, InnerProduct, GemvProduct, GemmProduct };
|
enum { DefaultProduct=0, CoeffBasedProductMode, LazyCoeffBasedProductMode, LazyProduct, OuterProduct, InnerProduct, GemvProduct, GemmProduct };
|
||||||
|
|
||||||
/** \internal \ingroup enums
|
/** \internal \ingroup enums
|
||||||
* Enum used in experimental parallel implementation. */
|
* Enum used in experimental parallel implementation. */
|
||||||
|
@ -87,11 +87,20 @@ template<typename NullaryOp, typename MatrixType> class CwiseNullaryOp;
|
|||||||
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
|
template<typename UnaryOp, typename MatrixType> class CwiseUnaryOp;
|
||||||
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
|
template<typename ViewOp, typename MatrixType> class CwiseUnaryView;
|
||||||
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
|
template<typename BinaryOp, typename Lhs, typename Rhs> class CwiseBinaryOp;
|
||||||
template<typename BinOp, typename Lhs, typename Rhs> class SelfCwiseBinaryOp;
|
template<typename BinOp, typename Lhs, typename Rhs> class SelfCwiseBinaryOp; // TODO deprecated
|
||||||
template<typename Derived, typename Lhs, typename Rhs> class ProductBase;
|
template<typename Derived, typename Lhs, typename Rhs> class ProductBase;
|
||||||
template<typename Lhs, typename Rhs> class Product;
|
|
||||||
template<typename Lhs, typename Rhs, int Mode> class GeneralProduct;
|
namespace internal {
|
||||||
template<typename Lhs, typename Rhs, int NestingFlags> class CoeffBasedProduct;
|
template<typename Lhs, typename Rhs> struct product_tag;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs,
|
||||||
|
int Option = DefaultProduct,
|
||||||
|
int ProductTag = internal::product_tag<Lhs,Rhs>::ret
|
||||||
|
> class Product;
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int Mode> class GeneralProduct; // TODO deprecated
|
||||||
|
template<typename Lhs, typename Rhs, int NestingFlags> class CoeffBasedProduct; // TODO deprecated
|
||||||
|
|
||||||
template<typename Derived> class DiagonalBase;
|
template<typename Derived> class DiagonalBase;
|
||||||
template<typename _DiagonalVectorType> class DiagonalWrapper;
|
template<typename _DiagonalVectorType> class DiagonalWrapper;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user