mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 19:59:05 +08:00
Get rid of GeneralProduct<> for GemmProduct
This commit is contained in:
parent
728c3d2cb9
commit
6c7ab50811
@ -371,6 +371,9 @@ using std::ptrdiff_t;
|
|||||||
#include "src/Core/products/GeneralBlockPanelKernel.h"
|
#include "src/Core/products/GeneralBlockPanelKernel.h"
|
||||||
#include "src/Core/products/Parallelizer.h"
|
#include "src/Core/products/Parallelizer.h"
|
||||||
#include "src/Core/products/CoeffBasedProduct.h"
|
#include "src/Core/products/CoeffBasedProduct.h"
|
||||||
|
#ifdef EIGEN_ENABLE_EVALUATORS
|
||||||
|
#include "src/Core/ProductEvaluators.h"
|
||||||
|
#endif
|
||||||
#include "src/Core/products/GeneralMatrixVector.h"
|
#include "src/Core/products/GeneralMatrixVector.h"
|
||||||
#include "src/Core/products/GeneralMatrixMatrix.h"
|
#include "src/Core/products/GeneralMatrixMatrix.h"
|
||||||
#include "src/Core/SolveTriangular.h"
|
#include "src/Core/SolveTriangular.h"
|
||||||
@ -386,10 +389,6 @@ using std::ptrdiff_t;
|
|||||||
#include "src/Core/BandMatrix.h"
|
#include "src/Core/BandMatrix.h"
|
||||||
#include "src/Core/CoreIterators.h"
|
#include "src/Core/CoreIterators.h"
|
||||||
|
|
||||||
#ifdef EIGEN_ENABLE_EVALUATORS
|
|
||||||
#include "src/Core/ProductEvaluators.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include "src/Core/BooleanRedux.h"
|
#include "src/Core/BooleanRedux.h"
|
||||||
#include "src/Core/Select.h"
|
#include "src/Core/Select.h"
|
||||||
#include "src/Core/VectorwiseOp.h"
|
#include "src/Core/VectorwiseOp.h"
|
||||||
|
@ -311,20 +311,6 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemvProduct>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
|
||||||
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
|
|
||||||
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
|
|
||||||
{
|
|
||||||
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
|
||||||
|
|
||||||
template<typename Dest>
|
|
||||||
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha)
|
|
||||||
{
|
|
||||||
// TODO bypass GeneralProduct class
|
|
||||||
GeneralProduct<Lhs, Rhs, GemmProduct>(lhs,rhs).scaleAndAddTo(dst, alpha);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
|
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,CoeffBasedProductMode>
|
||||||
{
|
{
|
||||||
|
@ -374,6 +374,7 @@ class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, M
|
|||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
class GeneralProduct<Lhs, Rhs, GemmProduct>
|
class GeneralProduct<Lhs, Rhs, GemmProduct>
|
||||||
: public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
|
: public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
|
||||||
@ -421,6 +422,62 @@ class GeneralProduct<Lhs, Rhs, GemmProduct>
|
|||||||
internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
|
internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&RowMajorBit);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
#else // EIGEN_TEST_EVALUATORS
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs>
|
||||||
|
struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
|
||||||
|
: generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct> >
|
||||||
|
{
|
||||||
|
typedef typename Product<Lhs,Rhs>::Scalar Scalar;
|
||||||
|
typedef typename Product<Lhs,Rhs>::Index Index;
|
||||||
|
typedef typename Lhs::Scalar LhsScalar;
|
||||||
|
typedef typename Rhs::Scalar RhsScalar;
|
||||||
|
|
||||||
|
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||||
|
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||||
|
typedef typename internal::remove_all<ActualLhsType>::type ActualLhsTypeCleaned;
|
||||||
|
|
||||||
|
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||||
|
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||||
|
typedef typename internal::remove_all<ActualRhsType>::type ActualRhsTypeCleaned;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime)
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Dest>
|
||||||
|
static void scaleAndAddTo(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha)
|
||||||
|
{
|
||||||
|
eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
|
||||||
|
|
||||||
|
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
|
||||||
|
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
|
||||||
|
|
||||||
|
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
|
||||||
|
* RhsBlasTraits::extractScalarFactor(a_rhs);
|
||||||
|
|
||||||
|
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar,
|
||||||
|
Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
|
||||||
|
|
||||||
|
typedef internal::gemm_functor<
|
||||||
|
Scalar, Index,
|
||||||
|
internal::general_matrix_matrix_product<
|
||||||
|
Index,
|
||||||
|
LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
|
||||||
|
RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
|
||||||
|
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
|
||||||
|
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;
|
||||||
|
|
||||||
|
BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
|
||||||
|
|
||||||
|
internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)>
|
||||||
|
(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), Dest::Flags&RowMajorBit);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user