From fa2fcb4895a4ae12cb28003e646c736d013e68e8 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Thu, 7 Feb 2019 16:07:08 +0100 Subject: [PATCH] Backed out changeset 4c0fa6ce0f81ce67dd6723528ddf72f66ae92ba2 --- Eigen/src/Core/products/GeneralMatrixMatrix.h | 174 ++++-------------- 1 file changed, 37 insertions(+), 137 deletions(-) diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 4bcccd326..f49abcad5 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -404,146 +404,26 @@ class gemm_blocking_space 1 || Dest::RowsAtCompileTime > 1), - bool MultipleColsAtCompileTime = - (Rhs::ColsAtCompileTime > 1 || Dest::ColsAtCompileTime > 1)> -struct gemm_selector { - typedef typename Product::Scalar Scalar; - - typedef internal::blas_traits LhsBlasTraits; - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; - typedef typename internal::remove_all::type ActualLhsTypeCleaned; - - typedef internal::blas_traits RhsBlasTraits; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; - typedef typename internal::remove_all::type ActualRhsTypeCleaned; - - static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) - { - if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { - gemm_selector::run(dst, a_lhs, a_rhs, alpha); - } else if (a_rhs.cols() == 1) { - // matrix * vector. - internal::gemv_dense_selector::HasUsableDirectAccess) - >::run(a_lhs, a_rhs.col(0), dst, alpha); - } else { - // vector * matrix. - internal::gemv_dense_selector::HasUsableDirectAccess) - >::run(a_lhs.row(0), a_rhs, dst, alpha); - } - } -}; - -template -struct gemm_selector { - typedef typename Product::Scalar Scalar; - - typedef internal::blas_traits LhsBlasTraits; - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; - typedef typename internal::remove_all::type ActualLhsTypeCleaned; - - static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) - { - if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { - gemm_selector::run(dst, a_lhs, a_rhs, alpha); - } else { - // matrix * vector. - internal::gemv_dense_selector::HasUsableDirectAccess) - >::run(a_lhs, a_rhs.col(0), dst, alpha); - } - } -}; - -template -struct gemm_selector { - typedef typename Product::Scalar Scalar; - - typedef internal::blas_traits RhsBlasTraits; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; - typedef typename internal::remove_all::type ActualRhsTypeCleaned; - - static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) - { - if (a_rhs.cols() != 1 && a_lhs.rows() != 1) { - gemm_selector::run(dst, a_lhs, a_rhs, alpha); - } else { - // vector * matrix. - internal::gemv_dense_selector::HasUsableDirectAccess) - >::run(a_lhs.row(0), a_rhs, dst, alpha); - } - } -}; - -template -struct gemm_selector { - typedef typename Product::Scalar Scalar; - typedef typename Lhs::Scalar LhsScalar; - typedef typename Rhs::Scalar RhsScalar; - - typedef internal::blas_traits LhsBlasTraits; - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; - typedef - typename internal::remove_all::type ActualLhsTypeCleaned; - - typedef internal::blas_traits RhsBlasTraits; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; - typedef - typename internal::remove_all::type ActualRhsTypeCleaned; - - enum { - MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED( - Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime) - }; - - static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, - const Scalar& alpha) { - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) * - RhsBlasTraits::extractScalarFactor(a_rhs); - typename internal::add_const_on_value_type::type lhs = - LhsBlasTraits::extract(a_lhs); - typename internal::add_const_on_value_type::type rhs = - RhsBlasTraits::extract(a_rhs); - typedef internal::gemm_blocking_space< - (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, LhsScalar, RhsScalar, - Dest::MaxRowsAtCompileTime, Dest::MaxColsAtCompileTime, - MaxDepthAtCompileTime> - BlockingType; - - typedef internal::gemm_functor< - Scalar, Index, - internal::general_matrix_matrix_product< - Index, LhsScalar, - (ActualLhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor, - bool(LhsBlasTraits::NeedToConjugate), RhsScalar, - (ActualRhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor, - bool(RhsBlasTraits::NeedToConjugate), - (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor>, - ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> - GemmFunctor; - - BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); - internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime > 32 || - Dest::MaxRowsAtCompileTime == Dynamic)>( - GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), - a_rhs.cols(), a_lhs.cols(), Dest::Flags & RowMajorBit); - } -}; - template struct generic_product_impl : generic_product_impl_base > { typedef typename Product::Scalar Scalar; + typedef typename Lhs::Scalar LhsScalar; + typedef typename Rhs::Scalar RhsScalar; + + typedef internal::blas_traits LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; + typedef typename internal::remove_all::type ActualLhsTypeCleaned; + + typedef internal::blas_traits RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; + typedef typename internal::remove_all::type ActualRhsTypeCleaned; + + enum { + MaxDepthAtCompileTime = EIGEN_SIZE_MIN_PREFER_FIXED(Lhs::MaxColsAtCompileTime,Rhs::MaxRowsAtCompileTime) + }; + typedef generic_product_impl lazyproduct; template @@ -570,7 +450,7 @@ struct generic_product_impl if((rhs.rows()+dst.rows()+dst.cols())0) lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op()); else - scaleAndAddTo(dst, lhs, rhs, Scalar(1)); + scaleAndAddTo(dst,lhs, rhs, Scalar(1)); } template @@ -589,7 +469,27 @@ struct generic_product_impl if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0) return; - gemm_selector::run(dst, a_lhs, a_rhs, alpha); + typename internal::add_const_on_value_type::type lhs = LhsBlasTraits::extract(a_lhs); + typename internal::add_const_on_value_type::type rhs = RhsBlasTraits::extract(a_rhs); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) + * RhsBlasTraits::extractScalarFactor(a_rhs); + + typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,LhsScalar,RhsScalar, + Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType; + + typedef internal::gemm_functor< + Scalar, Index, + internal::general_matrix_matrix_product< + Index, + LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), + RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), + (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, + ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor; + + BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); + internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==Dynamic)> + (GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(), Dest::Flags&RowMajorBit); } };