mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
Get rid of GeneralProduct<> for GemvProduct
This commit is contained in:
parent
6c7ab50811
commit
d67548f345
@ -342,16 +342,19 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
|
||||
*/
|
||||
namespace internal {
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct traits<GeneralProduct<Lhs,Rhs,GemvProduct> >
|
||||
: traits<ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs> >
|
||||
{};
|
||||
#endif
|
||||
|
||||
template<int Side, int StorageOrder, bool BlasCompatible>
|
||||
struct gemv_selector;
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
template<typename Lhs, typename Rhs>
|
||||
class GeneralProduct<Lhs, Rhs, GemvProduct>
|
||||
: public ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs>
|
||||
@ -378,24 +381,10 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
|
||||
bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)>::run(*this, dst, alpha);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
namespace internal {
|
||||
|
||||
// The vector is on the left => transposition
|
||||
template<int StorageOrder, bool BlasCompatible>
|
||||
struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
|
||||
{
|
||||
template<typename ProductType, typename Dest>
|
||||
static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
|
||||
{
|
||||
Transpose<Dest> destT(dest);
|
||||
enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
|
||||
gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
|
||||
::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct>
|
||||
(prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar,int Size,int MaxSize,bool Cond> struct gemv_static_vector_if;
|
||||
|
||||
template<typename Scalar,int Size,int MaxSize>
|
||||
@ -432,6 +421,23 @@ struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
|
||||
// The vector is on the left => transposition
|
||||
template<int StorageOrder, bool BlasCompatible>
|
||||
struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
|
||||
{
|
||||
template<typename ProductType, typename Dest>
|
||||
static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
|
||||
{
|
||||
Transpose<Dest> destT(dest);
|
||||
enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
|
||||
gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
|
||||
::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct>
|
||||
(prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct gemv_selector<OnTheRight,ColMajor,true>
|
||||
{
|
||||
template<typename ProductType, typename Dest>
|
||||
@ -582,6 +588,178 @@ template<> struct gemv_selector<OnTheRight,RowMajor,false>
|
||||
}
|
||||
};
|
||||
|
||||
#else // EIGEN_TEST_EVALUATORS
|
||||
|
||||
// The vector is on the left => transposition
|
||||
template<int StorageOrder, bool BlasCompatible>
|
||||
struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
|
||||
{
|
||||
template<typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
Transpose<Dest> destT(dest);
|
||||
enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
|
||||
gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
|
||||
::run(rhs.transpose(), lhs.transpose(), destT, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct gemv_selector<OnTheRight,ColMajor,true>
|
||||
{
|
||||
template<typename Lhs, typename Rhs, typename Dest>
|
||||
static inline void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef typename Dest::Index Index;
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
typedef typename Dest::Scalar ResScalar;
|
||||
typedef typename Dest::RealScalar RealScalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
|
||||
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
|
||||
|
||||
ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
|
||||
ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
|
||||
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(rhs);
|
||||
|
||||
enum {
|
||||
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
|
||||
// on, the other hand it is good for the cache to pack the vector anyways...
|
||||
EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
|
||||
ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
|
||||
MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
|
||||
};
|
||||
|
||||
gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
|
||||
|
||||
bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
|
||||
bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
|
||||
|
||||
RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
|
||||
evalToDest ? dest.data() : static_dest.data());
|
||||
|
||||
if(!evalToDest)
|
||||
{
|
||||
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
int size = dest.size();
|
||||
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
#endif
|
||||
if(!alphaIsCompatible)
|
||||
{
|
||||
MappedDest(actualDestPtr, dest.size()).setZero();
|
||||
compatibleAlpha = RhsScalar(1);
|
||||
}
|
||||
else
|
||||
MappedDest(actualDestPtr, dest.size()) = dest;
|
||||
}
|
||||
|
||||
general_matrix_vector_product
|
||||
<Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
|
||||
actualLhs.rows(), actualLhs.cols(),
|
||||
actualLhs.data(), actualLhs.outerStride(),
|
||||
actualRhs.data(), actualRhs.innerStride(),
|
||||
actualDestPtr, 1,
|
||||
compatibleAlpha);
|
||||
|
||||
if (!evalToDest)
|
||||
{
|
||||
if(!alphaIsCompatible)
|
||||
dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
|
||||
else
|
||||
dest = MappedDest(actualDestPtr, dest.size());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct gemv_selector<OnTheRight,RowMajor,true>
|
||||
{
|
||||
template<typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef typename Dest::Index Index;
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
typedef typename Dest::Scalar ResScalar;
|
||||
typedef typename Dest::RealScalar RealScalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||
|
||||
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
|
||||
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
|
||||
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(rhs);
|
||||
|
||||
enum {
|
||||
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
|
||||
// on, the other hand it is good for the cache to pack the vector anyways...
|
||||
DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
|
||||
};
|
||||
|
||||
gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
|
||||
DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
|
||||
|
||||
if(!DirectlyUseRhs)
|
||||
{
|
||||
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
int size = actualRhs.size();
|
||||
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
#endif
|
||||
Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
|
||||
}
|
||||
|
||||
general_matrix_vector_product
|
||||
<Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
|
||||
actualLhs.rows(), actualLhs.cols(),
|
||||
actualLhs.data(), actualLhs.outerStride(),
|
||||
actualRhsPtr, 1,
|
||||
dest.data(), dest.innerStride(),
|
||||
actualAlpha);
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct gemv_selector<OnTheRight,ColMajor,false>
|
||||
{
|
||||
template<typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef typename Dest::Index Index;
|
||||
// TODO makes sure dest is sequentially stored in memory, otherwise use a temp
|
||||
const Index size = rhs.rows();
|
||||
for(Index k=0; k<size; ++k)
|
||||
dest += (alpha*rhs.coeff(k)) * lhs.col(k);
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct gemv_selector<OnTheRight,RowMajor,false>
|
||||
{
|
||||
template<typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef typename Dest::Index Index;
|
||||
// TODO makes sure rhs is sequentially stored in memory, otherwise use a temp
|
||||
const Index rows = dest.rows();
|
||||
for(Index i=0; i<rows; ++i)
|
||||
dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(rhs.transpose())).sum();
|
||||
}
|
||||
};
|
||||
|
||||
#endif // EIGEN_TEST_EVALUATORS
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
/***************************************************************************
|
||||
|
@ -11,7 +11,7 @@
|
||||
#define EIGEN_PRODUCTBASE_H
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
|
||||
/** \class ProductBase
|
||||
* \ingroup Core_Module
|
||||
*
|
||||
@ -175,6 +175,7 @@ class ProductBase : public MatrixBase<Derived>
|
||||
};
|
||||
|
||||
#ifndef EIGEN_TEST_EVALUATORS
|
||||
|
||||
// here we need to overload the nested rule for products
|
||||
// such that the nested type is a const reference to a plain matrix
|
||||
namespace internal {
|
||||
|
@ -307,7 +307,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemvProduct>
|
||||
internal::gemv_selector<Side,
|
||||
(int(MatrixType::Flags)&RowMajorBit) ? RowMajor : ColMajor,
|
||||
bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)
|
||||
>::run(GeneralProduct<Lhs,Rhs,GemvProduct>(lhs,rhs), dst, alpha);
|
||||
>::run(lhs, rhs, dst, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user