mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 19:59:05 +08:00
Get rid of GeneralProduct<> for GemvProduct
This commit is contained in:
parent
6c7ab50811
commit
d67548f345
@ -342,16 +342,19 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
|
|||||||
*/
|
*/
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
struct traits<GeneralProduct<Lhs,Rhs,GemvProduct> >
|
struct traits<GeneralProduct<Lhs,Rhs,GemvProduct> >
|
||||||
: traits<ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs> >
|
: traits<ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs> >
|
||||||
{};
|
{};
|
||||||
|
#endif
|
||||||
|
|
||||||
template<int Side, int StorageOrder, bool BlasCompatible>
|
template<int Side, int StorageOrder, bool BlasCompatible>
|
||||||
struct gemv_selector;
|
struct gemv_selector;
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
class GeneralProduct<Lhs, Rhs, GemvProduct>
|
class GeneralProduct<Lhs, Rhs, GemvProduct>
|
||||||
: public ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs>
|
: public ProductBase<GeneralProduct<Lhs,Rhs,GemvProduct>, Lhs, Rhs>
|
||||||
@ -378,24 +381,10 @@ class GeneralProduct<Lhs, Rhs, GemvProduct>
|
|||||||
bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)>::run(*this, dst, alpha);
|
bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)>::run(*this, dst, alpha);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
// The vector is on the left => transposition
|
|
||||||
template<int StorageOrder, bool BlasCompatible>
|
|
||||||
struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
|
|
||||||
{
|
|
||||||
template<typename ProductType, typename Dest>
|
|
||||||
static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
|
|
||||||
{
|
|
||||||
Transpose<Dest> destT(dest);
|
|
||||||
enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
|
|
||||||
gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
|
|
||||||
::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct>
|
|
||||||
(prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Scalar,int Size,int MaxSize,bool Cond> struct gemv_static_vector_if;
|
template<typename Scalar,int Size,int MaxSize,bool Cond> struct gemv_static_vector_if;
|
||||||
|
|
||||||
template<typename Scalar,int Size,int MaxSize>
|
template<typename Scalar,int Size,int MaxSize>
|
||||||
@ -432,6 +421,23 @@ struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
|
|||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
// The vector is on the left => transposition
|
||||||
|
template<int StorageOrder, bool BlasCompatible>
|
||||||
|
struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
|
||||||
|
{
|
||||||
|
template<typename ProductType, typename Dest>
|
||||||
|
static void run(const ProductType& prod, Dest& dest, const typename ProductType::Scalar& alpha)
|
||||||
|
{
|
||||||
|
Transpose<Dest> destT(dest);
|
||||||
|
enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
|
||||||
|
gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
|
||||||
|
::run(GeneralProduct<Transpose<const typename ProductType::_RhsNested>,Transpose<const typename ProductType::_LhsNested>, GemvProduct>
|
||||||
|
(prod.rhs().transpose(), prod.lhs().transpose()), destT, alpha);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<> struct gemv_selector<OnTheRight,ColMajor,true>
|
template<> struct gemv_selector<OnTheRight,ColMajor,true>
|
||||||
{
|
{
|
||||||
template<typename ProductType, typename Dest>
|
template<typename ProductType, typename Dest>
|
||||||
@ -582,6 +588,178 @@ template<> struct gemv_selector<OnTheRight,RowMajor,false>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#else // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
// The vector is on the left => transposition
|
||||||
|
template<int StorageOrder, bool BlasCompatible>
|
||||||
|
struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
|
||||||
|
{
|
||||||
|
template<typename Lhs, typename Rhs, typename Dest>
|
||||||
|
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||||
|
{
|
||||||
|
Transpose<Dest> destT(dest);
|
||||||
|
enum { OtherStorageOrder = StorageOrder == RowMajor ? ColMajor : RowMajor };
|
||||||
|
gemv_selector<OnTheRight,OtherStorageOrder,BlasCompatible>
|
||||||
|
::run(rhs.transpose(), lhs.transpose(), destT, alpha);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct gemv_selector<OnTheRight,ColMajor,true>
|
||||||
|
{
|
||||||
|
template<typename Lhs, typename Rhs, typename Dest>
|
||||||
|
static inline void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||||
|
{
|
||||||
|
typedef typename Dest::Index Index;
|
||||||
|
typedef typename Lhs::Scalar LhsScalar;
|
||||||
|
typedef typename Rhs::Scalar RhsScalar;
|
||||||
|
typedef typename Dest::Scalar ResScalar;
|
||||||
|
typedef typename Dest::RealScalar RealScalar;
|
||||||
|
|
||||||
|
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||||
|
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||||
|
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||||
|
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||||
|
|
||||||
|
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
|
||||||
|
|
||||||
|
ActualLhsType actualLhs = LhsBlasTraits::extract(lhs);
|
||||||
|
ActualRhsType actualRhs = RhsBlasTraits::extract(rhs);
|
||||||
|
|
||||||
|
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
|
||||||
|
* RhsBlasTraits::extractScalarFactor(rhs);
|
||||||
|
|
||||||
|
enum {
|
||||||
|
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
|
||||||
|
// on, the other hand it is good for the cache to pack the vector anyways...
|
||||||
|
EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
|
||||||
|
ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
|
||||||
|
MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
|
||||||
|
};
|
||||||
|
|
||||||
|
gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
|
||||||
|
|
||||||
|
bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
|
||||||
|
bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
|
||||||
|
|
||||||
|
RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
|
||||||
|
|
||||||
|
ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
|
||||||
|
evalToDest ? dest.data() : static_dest.data());
|
||||||
|
|
||||||
|
if(!evalToDest)
|
||||||
|
{
|
||||||
|
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||||
|
int size = dest.size();
|
||||||
|
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||||
|
#endif
|
||||||
|
if(!alphaIsCompatible)
|
||||||
|
{
|
||||||
|
MappedDest(actualDestPtr, dest.size()).setZero();
|
||||||
|
compatibleAlpha = RhsScalar(1);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
MappedDest(actualDestPtr, dest.size()) = dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
general_matrix_vector_product
|
||||||
|
<Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
|
||||||
|
actualLhs.rows(), actualLhs.cols(),
|
||||||
|
actualLhs.data(), actualLhs.outerStride(),
|
||||||
|
actualRhs.data(), actualRhs.innerStride(),
|
||||||
|
actualDestPtr, 1,
|
||||||
|
compatibleAlpha);
|
||||||
|
|
||||||
|
if (!evalToDest)
|
||||||
|
{
|
||||||
|
if(!alphaIsCompatible)
|
||||||
|
dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
|
||||||
|
else
|
||||||
|
dest = MappedDest(actualDestPtr, dest.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct gemv_selector<OnTheRight,RowMajor,true>
|
||||||
|
{
|
||||||
|
template<typename Lhs, typename Rhs, typename Dest>
|
||||||
|
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||||
|
{
|
||||||
|
typedef typename Dest::Index Index;
|
||||||
|
typedef typename Lhs::Scalar LhsScalar;
|
||||||
|
typedef typename Rhs::Scalar RhsScalar;
|
||||||
|
typedef typename Dest::Scalar ResScalar;
|
||||||
|
typedef typename Dest::RealScalar RealScalar;
|
||||||
|
|
||||||
|
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||||
|
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||||
|
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||||
|
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||||
|
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||||
|
|
||||||
|
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
|
||||||
|
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
|
||||||
|
|
||||||
|
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
|
||||||
|
* RhsBlasTraits::extractScalarFactor(rhs);
|
||||||
|
|
||||||
|
enum {
|
||||||
|
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
|
||||||
|
// on, the other hand it is good for the cache to pack the vector anyways...
|
||||||
|
DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
|
||||||
|
};
|
||||||
|
|
||||||
|
gemv_static_vector_if<RhsScalar,ActualRhsTypeCleaned::SizeAtCompileTime,ActualRhsTypeCleaned::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
|
||||||
|
|
||||||
|
ei_declare_aligned_stack_constructed_variable(RhsScalar,actualRhsPtr,actualRhs.size(),
|
||||||
|
DirectlyUseRhs ? const_cast<RhsScalar*>(actualRhs.data()) : static_rhs.data());
|
||||||
|
|
||||||
|
if(!DirectlyUseRhs)
|
||||||
|
{
|
||||||
|
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||||
|
int size = actualRhs.size();
|
||||||
|
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||||
|
#endif
|
||||||
|
Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
general_matrix_vector_product
|
||||||
|
<Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
|
||||||
|
actualLhs.rows(), actualLhs.cols(),
|
||||||
|
actualLhs.data(), actualLhs.outerStride(),
|
||||||
|
actualRhsPtr, 1,
|
||||||
|
dest.data(), dest.innerStride(),
|
||||||
|
actualAlpha);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct gemv_selector<OnTheRight,ColMajor,false>
|
||||||
|
{
|
||||||
|
template<typename Lhs, typename Rhs, typename Dest>
|
||||||
|
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||||
|
{
|
||||||
|
typedef typename Dest::Index Index;
|
||||||
|
// TODO makes sure dest is sequentially stored in memory, otherwise use a temp
|
||||||
|
const Index size = rhs.rows();
|
||||||
|
for(Index k=0; k<size; ++k)
|
||||||
|
dest += (alpha*rhs.coeff(k)) * lhs.col(k);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct gemv_selector<OnTheRight,RowMajor,false>
|
||||||
|
{
|
||||||
|
template<typename Lhs, typename Rhs, typename Dest>
|
||||||
|
static void run(const Lhs &lhs, const Rhs &rhs, Dest& dest, const typename Dest::Scalar& alpha)
|
||||||
|
{
|
||||||
|
typedef typename Dest::Index Index;
|
||||||
|
// TODO makes sure rhs is sequentially stored in memory, otherwise use a temp
|
||||||
|
const Index rows = dest.rows();
|
||||||
|
for(Index i=0; i<rows; ++i)
|
||||||
|
dest.coeffRef(i) += alpha * (lhs.row(i).cwiseProduct(rhs.transpose())).sum();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
/***************************************************************************
|
/***************************************************************************
|
||||||
|
@ -175,6 +175,7 @@ class ProductBase : public MatrixBase<Derived>
|
|||||||
};
|
};
|
||||||
|
|
||||||
#ifndef EIGEN_TEST_EVALUATORS
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
// here we need to overload the nested rule for products
|
// here we need to overload the nested rule for products
|
||||||
// such that the nested type is a const reference to a plain matrix
|
// such that the nested type is a const reference to a plain matrix
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
@ -307,7 +307,7 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemvProduct>
|
|||||||
internal::gemv_selector<Side,
|
internal::gemv_selector<Side,
|
||||||
(int(MatrixType::Flags)&RowMajorBit) ? RowMajor : ColMajor,
|
(int(MatrixType::Flags)&RowMajorBit) ? RowMajor : ColMajor,
|
||||||
bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)
|
bool(internal::blas_traits<MatrixType>::HasUsableDirectAccess)
|
||||||
>::run(GeneralProduct<Lhs,Rhs,GemvProduct>(lhs,rhs), dst, alpha);
|
>::run(lhs, rhs, dst, alpha);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user