From 2dde63499c4ef836a0d9dfd443494d863ad62b16 Mon Sep 17 00:00:00 2001
From: Benoit Steiner <benoit.steiner.goog@gmail.com>
Date: Fri, 31 Oct 2014 16:33:51 -0700
Subject: [PATCH] Generalized the matrix vector product code.

---
 Eigen/src/Core/GeneralProduct.h               |  32 ++-
 Eigen/src/Core/products/GeneralMatrixVector.h | 246 +++++++++---------
 .../Core/products/TriangularMatrixVector.h    |  46 ++--
 .../Core/products/TriangularSolverVector.h    |  24 +-
 Eigen/src/Core/util/BlasUtil.h                |  47 +++-
 5 files changed, 228 insertions(+), 167 deletions(-)

diff --git a/Eigen/src/Core/GeneralProduct.h b/Eigen/src/Core/GeneralProduct.h
index 7179eb124..9d3d5562c 100644
--- a/Eigen/src/Core/GeneralProduct.h
+++ b/Eigen/src/Core/GeneralProduct.h
@@ -11,7 +11,7 @@
 #ifndef EIGEN_GENERAL_PRODUCT_H
 #define EIGEN_GENERAL_PRODUCT_H
 
-namespace Eigen { 
+namespace Eigen {
 
 /** \class GeneralProduct
   * \ingroup Core_Module
@@ -257,7 +257,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
   : public ProductBase<GeneralProduct<Lhs,Rhs,OuterProduct>, Lhs, Rhs>
 {
     template<typename T> struct IsRowMajor : internal::conditional<(int(T::Flags)&RowMajorBit), internal::true_type, internal::false_type>::type {};
-    
+
   public:
     EIGEN_PRODUCT_PUBLIC_INTERFACE(GeneralProduct)
 
@@ -266,7 +266,7 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
      EIGEN_STATIC_ASSERT((internal::is_same<typename Lhs::Scalar, typename Rhs::Scalar>::value),
        YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
     }
-    
+
     struct set { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived()  = src; } };
     struct add { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() += src; } };
     struct sub { template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const { dst.const_cast_derived() -= src; } };
@@ -277,12 +277,12 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
        dst.const_cast_derived() += m_scale * src;
      }
    };
-    
+
    template<typename Dest>
    inline void evalTo(Dest& dest) const {
      internal::outer_product_selector_run(*this, dest, set(), IsRowMajor<Dest>());
    }
-    
+
    template<typename Dest>
    inline void addTo(Dest& dest) const {
      internal::outer_product_selector_run(*this, dest, add(), IsRowMajor<Dest>());
@@ -436,12 +436,12 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true>
    bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
    bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
-    
+
    RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
 
    ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
                                                  evalToDest ? dest.data() : static_dest.data());
-    
+
    if(!evalToDest)
    {
      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
@@ -457,11 +457,13 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true>
      else
        MappedDest(actualDestPtr, dest.size()) = dest;
    }
 
+    typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
+    typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
    general_matrix_vector_product
-      <Index,LhsScalar,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
+      <Index,LhsScalar,LhsMapper,ColMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
        actualLhs.rows(), actualLhs.cols(),
-        actualLhs.data(), actualLhs.outerStride(),
-        actualRhs.data(), actualRhs.innerStride(),
+        LhsMapper(actualLhs.data(), actualLhs.outerStride()),
+        RhsMapper(actualRhs.data(), actualRhs.innerStride()),
        actualDestPtr, 1,
        compatibleAlpha);
 
@@ -516,11 +518,13 @@ template<> struct gemv_selector<OnTheRight,RowMajor,true>
        Map<typename ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
      }
 
+      typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
+      typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
      general_matrix_vector_product
-        <Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
+        <Index,LhsScalar,LhsMapper,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsMapper,RhsBlasTraits::NeedToConjugate>::run(
          actualLhs.rows(), actualLhs.cols(),
-          actualLhs.data(), actualLhs.outerStride(),
-          actualRhsPtr, 1,
+          LhsMapper(actualLhs.data(), actualLhs.outerStride()),
+          RhsMapper(actualRhsPtr, 1),
          dest.data(), dest.innerStride(),
          actualAlpha);
      }
@@ -594,7 +598,7 @@ MatrixBase<Derived>::operator*(const MatrixBase<OtherDerived> &other) const
 #ifdef EIGEN_DEBUG_PRODUCT
  internal::product_type<Derived,OtherDerived>::debug();
 #endif
-  
+
  return Product<Derived, OtherDerived>(derived(), other.derived());
 }
 #else
diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h
index 340c51394..7dfa48bfb 100644
--- a/Eigen/src/Core/products/GeneralMatrixVector.h
+++ b/Eigen/src/Core/products/GeneralMatrixVector.h
@@ -10,7 +10,7 @@
 #ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
 #define EIGEN_GENERAL_MATRIX_VECTOR_H
 
-namespace Eigen { 
+namespace Eigen {
 
 namespace internal {
 
@@ -48,17 +48,17 @@ namespace internal {
  *  //    we currently fall back to the NoneAligned case
  *
  * The same reasoning applies for the transposed case.
- * 
+ *
  * The last case (PacketSize>4) could probably be improved by generalizing the FirstAligned case, but since we do not support AVX yet...
  * One might also wonder why in the EvenAligned case we perform unaligned loads instead of using the aligned-loads plus re-alignment
  * strategy as in the FirstAligned case. The reason is that we observed that unaligned loads on an 8 byte boundary are not too slow
  * compared to unaligned loads on a 4 byte boundary.
 *
 */
-template<typename Index, typename LhsScalar, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version>
-struct general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>
+template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
+struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
 {
-typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
+  typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
 
 enum {
  Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
@@ -78,17 +78,17 @@ typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
 
 EIGEN_DONT_INLINE static void run(
  Index rows, Index cols,
-  const LhsScalar* lhs, Index lhsStride,
-  const RhsScalar* rhs, Index rhsIncr,
+  const LhsMapper& lhs,
+  const RhsMapper& rhs,
  ResScalar* res, Index resIncr,
  RhsScalar alpha);
 };
 
-template<typename Index, typename LhsScalar, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version>
-EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
+template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
+EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
  Index rows, Index cols,
-  const LhsScalar* lhs, Index lhsStride,
-  const RhsScalar* rhs, Index rhsIncr,
+  const LhsMapper& lhs,
+  const RhsMapper& rhs,
  ResScalar* res, Index resIncr,
  RhsScalar alpha)
 {
@@ -97,14 +97,16 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
 #ifdef _EIGEN_ACCUMULATE_PACKETS
 #error _EIGEN_ACCUMULATE_PACKETS has already been defined
 #endif
 #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) \
  pstore(&res[j], \
    padd(pload<ResPacket>(&res[j]), \
      padd( \
-        padd(pcj.pmul(EIGEN_CAT(ploa , A0)(&lhs0[j]), ptmp0), \
-             pcj.pmul(EIGEN_CAT(ploa , A13)(&lhs1[j]), ptmp1)), \
-        padd(pcj.pmul(EIGEN_CAT(ploa , A2)(&lhs2[j]), ptmp2), \
-             pcj.pmul(EIGEN_CAT(ploa , A13)(&lhs3[j]), ptmp3)) )))
+        padd(pcj.pmul(lhs0.template load<LhsPacket, A0>(j), ptmp0), \
+             pcj.pmul(lhs1.template load<LhsPacket, A13>(j), ptmp1)), \
+        padd(pcj.pmul(lhs2.template load<LhsPacket, A2>(j), ptmp2), \
+             pcj.pmul(lhs3.template load<LhsPacket, A13>(j), ptmp3)) )))
+
+  typedef typename LhsMapper::VectorMapper LhsScalars;
 
 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
 conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
@@ -118,7 +120,9 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
  const Index ResPacketAlignedMask = ResPacketSize-1;
//  const Index PeelAlignedMask = ResPacketSize*peels-1;
  const Index size = rows;
+  const Index lhsStride = lhs.stride();
 
  // How many coeffs of the result do we have to skip to be aligned.
  // Here we assume data are at least aligned on the base scalar type
+  // if that's not the case then vectorization is discarded, see below.
  Index alignedStart = internal::first_aligned(res,size);
 
  // we cannot assume the first element is aligned because of sub-matrices
-  const Index lhsAlignmentOffset = internal::first_aligned(lhs,size);
+  const Index lhsAlignmentOffset = lhs.firstAligned(size);
 
  // find how many columns do we have to skip to be aligned with the result (if possible)
  Index skipColumns = 0;
  // if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
-  if( (size_t(lhs)%sizeof(LhsScalar)) || (size_t(res)%sizeof(ResScalar)) )
+  if( (lhsAlignmentOffset < 0) || (size_t(res)%sizeof(ResScalar)) )
  {
    alignedSize = 0;
    alignedStart = 0;
  }
  else if (LhsPacketSize>1)
  {
-    eigen_internal_assert(size_t(lhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || size<LhsPacketSize);
+    /*eigen_internal_assert(size_t(firstLhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || size<LhsPacketSize);*/
 
    while (skipColumns<LhsPacketSize &&
           alignedStart != ((lhsAlignmentOffset + alignmentStep*skipColumns)%LhsPacketSize))
      ++skipColumns;
    if (skipColumns==LhsPacketSize)
    {
      // nothing can be aligned, no need to skip any column
      alignmentPattern = NoneAligned;
      skipColumns = 0;
    }
    else
      skipColumns = (std::min)(skipColumns,cols);
 
-    eigen_internal_assert(  (alignmentPattern==NoneAligned)
+    /*eigen_internal_assert(  (alignmentPattern==NoneAligned)
                          || (skipColumns + columnsAtOnce >= cols)
                          || LhsPacketSize > size
-                          || (size_t(lhs+alignedStart+lhsStride*skipColumns)%sizeof(LhsPacket))==0);
+                          || (size_t(firstLhs+alignedStart+lhsStride*skipColumns)%sizeof(LhsPacket))==0);*/
  }
  else if(Vectorizable)
  {
@@ -178,20 +182,20 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
  for (Index i=skipColumns; i<columnBound; i+=columnsAtOnce)
  {
-    RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs[i*rhsIncr]),
-              ptmp1 = pset1<RhsPacket>(alpha*rhs[(i+offset1)*rhsIncr]),
-              ptmp2 = pset1<RhsPacket>(alpha*rhs[(i+2)*rhsIncr]),
-              ptmp3 = pset1<RhsPacket>(alpha*rhs[(i+offset3)*rhsIncr]);
+    RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(i, 0)),
+              ptmp1 = pset1<RhsPacket>(alpha*rhs(i+offset1, 0)),
+              ptmp2 = pset1<RhsPacket>(alpha*rhs(i+2, 0)),
+              ptmp3 = pset1<RhsPacket>(alpha*rhs(i+offset3, 0));
 
    // this helps a lot generating better binary code
-    const LhsScalar *lhs0 = lhs + i*lhsStride,     *lhs1 = lhs + (i+offset1)*lhsStride,
-                    *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride;
+    const LhsScalars lhs0 = lhs.getVectorMapper(0, i+0),   lhs1 = lhs.getVectorMapper(0, i+offset1),
+                     lhs2 = lhs.getVectorMapper(0, i+2),   lhs3 = lhs.getVectorMapper(0, i+offset3);
 
    if (Vectorizable)
    {
@@ -199,10 +203,10 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
      /* explicit vectorization */
      // process initial unaligned coeffs
      for (Index j=0; j<alignedStart; ++j)
      {
-        res[j] = cj.pmadd(lhs0[j], pfirst(ptmp0), res[j]);
-        res[j] = cj.pmadd(lhs1[j], pfirst(ptmp1), res[j]);
-        res[j] = cj.pmadd(lhs2[j], pfirst(ptmp2), res[j]);
-        res[j] = cj.pmadd(lhs3[j], pfirst(ptmp3), res[j]);
+        res[j] = cj.pmadd(lhs0(j), pfirst(ptmp0), res[j]);
+        res[j] = cj.pmadd(lhs1(j), pfirst(ptmp1), res[j]);
+        res[j] = cj.pmadd(lhs2(j), pfirst(ptmp2), res[j]);
+        res[j] = cj.pmadd(lhs3(j), pfirst(ptmp3), res[j]);
      }
 
      if (alignedSize>alignedStart)
@@ -211,11 +215,11 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
          // unaligned loads.
          LhsPacket A00, A01, A02, A03, A10, A11, A12, A13;
          ResPacket T0, T1;
-          A01 = pload<LhsPacket>(&lhs1[alignedStart-1]);
-          A02 = pload<LhsPacket>(&lhs2[alignedStart-2]);
-          A03 = pload<LhsPacket>(&lhs3[alignedStart-3]);
+          A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
+          A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
+          A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
 
          for (; j<peeledSize; j+=peels*ResPacketSize)
          {
-            A11 = pload<LhsPacket>(&lhs1[j-1+LhsPacketSize]); palign<1>(A01,A11);
-            A12 = pload<LhsPacket>(&lhs2[j-2+LhsPacketSize]); palign<2>(A02,A12);
-            A13 = pload<LhsPacket>(&lhs3[j-3+LhsPacketSize]); palign<3>(A03,A13);
+            A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11);
+            A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12);
+            A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13);
 
-            A00 = pload<LhsPacket>(&lhs0[j]);
-            A10 = pload<LhsPacket>(&lhs0[j+LhsPacketSize]);
+            A00 = lhs0.template load<LhsPacket, Aligned>(j);
+            A10 = lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize);
 
            T0 = pcj.pmadd(A00, ptmp0, pload<ResPacket>(&res[j]));
            T1 = pcj.pmadd(A10, ptmp0, pload<ResPacket>(&res[j+ResPacketSize]));
 
            T0 = pcj.pmadd(A01, ptmp1, T0);
-            A01 = pload<LhsPacket>(&lhs1[j-1+2*LhsPacketSize]); palign<1>(A11,A01);
+            A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01);
            T0 = pcj.pmadd(A02, ptmp2, T0);
-            A02 = pload<LhsPacket>(&lhs2[j-2+2*LhsPacketSize]); palign<2>(A12,A02);
+            A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02);
            T0 = pcj.pmadd(A03, ptmp3, T0);
            pstore(&res[j],T0);
-            A03 = pload<LhsPacket>(&lhs3[j-3+2*LhsPacketSize]); palign<3>(A13,A03);
+            A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03);
            T1 = pcj.pmadd(A11, ptmp1, T1);
            T1 = pcj.pmadd(A12, ptmp2, T1);
            T1 = pcj.pmadd(A13, ptmp3, T1);
@@ -254,12 +258,12 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
    for (Index k=start; k<end; ++k)
    {
-      RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs[k*rhsIncr]);
-      const LhsScalar* lhs0 = lhs + k*lhsStride;
+      RhsPacket ptmp0 = pset1<RhsPacket>(alpha*rhs(k, 0));
+      const LhsScalars lhs0 = lhs.getVectorMapper(0, k);
 
      if (Vectorizable)
      {
        /* explicit vectorization */
        // process first unaligned result's coeffs
        for (Index j=0; j<alignedStart; ++j)
-          res[j] += cj.pmul(lhs0[j], pfirst(ptmp0));
+          res[j] += cj.pmul(lhs0(j), pfirst(ptmp0));
        // process aligned result's coeffs
-        if ((size_t(lhs0+alignedStart)%sizeof(LhsPacket))==0)
+        if (lhs0.template aligned<LhsPacket>(alignedStart))
          for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
-            pstore(&res[i], pcj.pmadd(pload<LhsPacket>(&lhs0[i]), ptmp0, pload<ResPacket>(&res[i])));
+            pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(i), ptmp0, pload<ResPacket>(&res[i])));
        else
          for (Index i = alignedStart;i<alignedSize;i+=ResPacketSize)
-            pstore(&res[i], pcj.pmadd(ploadu<LhsPacket>(&lhs0[i]), ptmp0, pload<ResPacket>(&res[i])));
+            pstore(&res[i], pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(i), ptmp0, pload<ResPacket>(&res[i])));
      }
 
      // process remaining scalars (or all if no explicit vectorization)
      for (Index i=alignedSize; i<size; ++i)
-        res[i] += cj.pmul(lhs0[i], pfirst(ptmp0));
+        res[i] += cj.pmul(lhs0(i), pfirst(ptmp0));
    }
 
-template<typename Index, typename LhsScalar, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version>
-struct general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>
+template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
+struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
 {
 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
 
@@ -346,67 +350,69 @@ typedef typename packet_traits<ResScalar>::type _ResPacket;
 
 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
-  
+
 EIGEN_DONT_INLINE static void run(
  Index rows, Index cols,
-  const LhsScalar* lhs, Index lhsStride,
-  const RhsScalar* rhs, Index rhsIncr,
+  const LhsMapper& lhs,
+  const RhsMapper& rhs,
  ResScalar* res, Index resIncr,
  ResScalar alpha);
 };
 
-template<typename Index, typename LhsScalar, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version>
-EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
+template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
+EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>::run(
  Index rows, Index cols,
-  const LhsScalar* lhs, Index lhsStride,
-  const RhsScalar* rhs, Index rhsIncr,
+  const LhsMapper& lhs,
+  const RhsMapper& rhs,
  ResScalar* res, Index resIncr,
  ResScalar alpha)
 {
-  EIGEN_UNUSED_VARIABLE(rhsIncr);
-  eigen_internal_assert(rhsIncr==1);
-  
+  eigen_internal_assert(rhs.stride()==1);
+
 #ifdef _EIGEN_ACCUMULATE_PACKETS
 #error _EIGEN_ACCUMULATE_PACKETS has already been defined
 #endif
 
-  #define _EIGEN_ACCUMULATE_PACKETS(A0,A13,A2) {\
-    RhsPacket b = pload<RhsPacket>(&rhs[j]); \
-    ptmp0 = pcj.pmadd(EIGEN_CAT(ploa,A0) (&lhs0[j]), b, ptmp0); \
-    ptmp1 = pcj.pmadd(EIGEN_CAT(ploa,A13)(&lhs1[j]), b, ptmp1); \
-    ptmp2 = pcj.pmadd(EIGEN_CAT(ploa,A2) (&lhs2[j]), b, ptmp2); \
-    ptmp3 = pcj.pmadd(EIGEN_CAT(ploa,A13)(&lhs3[j]), b, ptmp3); }
+  #define _EIGEN_ACCUMULATE_PACKETS(Alignment0,Alignment13,Alignment2) {\
+    RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0); \
+    ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Alignment0>(j), b, ptmp0); \
+    ptmp1 = pcj.pmadd(lhs1.template load<LhsPacket, Alignment13>(j), b, ptmp1); \
+    ptmp2 = pcj.pmadd(lhs2.template load<LhsPacket, Alignment2>(j), b, ptmp2); \
+    ptmp3 = pcj.pmadd(lhs3.template load<LhsPacket, Alignment13>(j), b, ptmp3); }
 
 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
 conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
 
+typedef typename LhsMapper::VectorMapper LhsScalars;
+
 enum { AllAligned=0, EvenAligned=1, FirstAligned=2, NoneAligned=3 };
 const Index rowsAtOnce = 4;
 const Index peels = 2;
 const Index RhsPacketAlignedMask = RhsPacketSize-1;
 const Index LhsPacketAlignedMask = LhsPacketSize-1;
-//  const Index PeelAlignedMask = RhsPacketSize*peels-1;
 const Index depth = cols;
+const Index lhsStride = lhs.stride();
 
 // How many coeffs of the result do we have to skip to be aligned.
 // Here we assume data are at least aligned on the base scalar type
 // if that's not the case then vectorization is discarded, see below.
-Index alignedStart = internal::first_aligned(rhs, depth);
+Index alignedStart = rhs.firstAligned(depth);
 Index alignedSize = RhsPacketSize>1 ? alignedStart + ((depth-alignedStart) & ~RhsPacketAlignedMask) : 0;
 const Index peeledSize = alignedSize - RhsPacketSize*peels - RhsPacketSize + 1;
 
 const Index alignmentStep = LhsPacketSize>1 ? (LhsPacketSize - lhsStride % LhsPacketSize) & LhsPacketAlignedMask : 0;
 Index alignmentPattern = alignmentStep==0 ? AllAligned
-                       : alignmentStep==(LhsPacketSize/2) ? EvenAligned
-                       : FirstAligned;
+                                          : alignmentStep==(LhsPacketSize/2) ? EvenAligned
+                                          : FirstAligned;
 
 // we cannot assume the first element is aligned because of sub-matrices
-const Index lhsAlignmentOffset = internal::first_aligned(lhs,depth);
+const Index lhsAlignmentOffset = lhs.firstAligned(depth);
+const Index rhsAlignmentOffset = rhs.firstAligned(rows);
 
 // find how many rows do we have to skip to be aligned with rhs (if possible)
 Index skipRows = 0;
 // if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
-if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) || (size_t(lhs)%sizeof(LhsScalar)) || (size_t(rhs)%sizeof(RhsScalar)) )
+if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) || (lhsAlignmentOffset < 0) || (rhsAlignmentOffset < 0) )
 {
  alignedSize = 0;
  alignedStart = 0;
@@ -418,7 +424,7 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
 else if (LhsPacketSize>1)
 {
-  eigen_internal_assert(size_t(lhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || depth<LhsPacketSize);
+  /*eigen_internal_assert(size_t(firstLhs+lhsAlignmentOffset)%sizeof(LhsPacket)==0 || depth<LhsPacketSize);*/
 
  while (skipRows<LhsPacketSize &&
         alignedStart != ((lhsAlignmentOffset + alignmentStep*skipRows)%LhsPacketSize))
    ++skipRows;
  if (skipRows==LhsPacketSize)
  {
    // nothing can be aligned, no need to skip any row
    alignmentPattern = NoneAligned;
    skipRows = 0;
  }
  else
    skipRows = (std::min)(skipRows,rows);
 
-  eigen_internal_assert(  (alignmentPattern==NoneAligned)
+  /*eigen_internal_assert(  (alignmentPattern==NoneAligned)
                        || (skipRows + rowsAtOnce >= rows)
                        || LhsPacketSize > depth
-                        || (size_t(lhs+alignedStart+lhsStride*skipRows)%sizeof(LhsPacket))==0);
+                        || (size_t(firstLhs+alignedStart+lhsStride*skipRows)%sizeof(LhsPacket))==0);*/
 }
 else if(Vectorizable)
 {
@@ -447,8 +453,8 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
 
    // this helps the compiler generating good binary code
-    const LhsScalar *lhs0 = lhs + i*lhsStride,     *lhs1 = lhs + (i+offset1)*lhsStride,
-                    *lhs2 = lhs + (i+2)*lhsStride, *lhs3 = lhs + (i+offset3)*lhsStride;
+    const LhsScalars lhs0 = lhs.getVectorMapper(i+0, 0),   lhs1 = lhs.getVectorMapper(i+offset1, 0),
+                     lhs2 = lhs.getVectorMapper(i+2, 0),   lhs3 = lhs.getVectorMapper(i+offset3, 0);
 
    if (Vectorizable)
    {
@@ -481,11 +487,11 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjugateLhs,RhsScalar,ConjugateRhs,Version>::run(
            // unaligned loads.
            LhsPacket A01, A02, A03, A11, A12, A13;
-            A01 = pload<LhsPacket>(&lhs1[alignedStart-1]);
-            A02 = pload<LhsPacket>(&lhs2[alignedStart-2]);
-            A03 = pload<LhsPacket>(&lhs3[alignedStart-3]);
+            A01 = lhs1.template load<LhsPacket, Aligned>(alignedStart-1);
+            A02 = lhs2.template load<LhsPacket, Aligned>(alignedStart-2);
+            A03 = lhs3.template load<LhsPacket, Aligned>(alignedStart-3);
 
            for (; j<peeledSize; j+=peels*RhsPacketSize)
            {
-              RhsPacket b = pload<RhsPacket>(&rhs[j]);
-              A11 = pload<LhsPacket>(&lhs1[j-1+LhsPacketSize]); palign<1>(A01,A11);
-              A12 = pload<LhsPacket>(&lhs2[j-2+LhsPacketSize]); palign<2>(A02,A12);
-              A13 = pload<LhsPacket>(&lhs3[j-3+LhsPacketSize]); palign<3>(A03,A13);
+              RhsPacket b = rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0);
+              A11 = lhs1.template load<LhsPacket, Aligned>(j-1+LhsPacketSize); palign<1>(A01,A11);
+              A12 = lhs2.template load<LhsPacket, Aligned>(j-2+LhsPacketSize); palign<2>(A02,A12);
+              A13 = lhs3.template load<LhsPacket, Aligned>(j-3+LhsPacketSize); palign<3>(A03,A13);
 
-              ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j]), b, ptmp0);
+              ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), b, ptmp0);
              ptmp1 = pcj.pmadd(A01, b, ptmp1);
-              A01 = pload<LhsPacket>(&lhs1[j-1+2*LhsPacketSize]); palign<1>(A11,A01);
+              A01 = lhs1.template load<LhsPacket, Aligned>(j-1+2*LhsPacketSize); palign<1>(A11,A01);
              ptmp2 = pcj.pmadd(A02, b, ptmp2);
-              A02 = pload<LhsPacket>(&lhs2[j-2+2*LhsPacketSize]); palign<2>(A12,A02);
+              A02 = lhs2.template load<LhsPacket, Aligned>(j-2+2*LhsPacketSize); palign<2>(A12,A02);
              ptmp3 = pcj.pmadd(A03, b, ptmp3);
-              A03 = pload<LhsPacket>(&lhs3[j-3+2*LhsPacketSize]); palign<3>(A13,A03);
+              A03 = lhs3.template load<LhsPacket, Aligned>(j-3+2*LhsPacketSize); palign<3>(A13,A03);
 
-              b = pload<RhsPacket>(&rhs[j+RhsPacketSize]);
-              ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j+LhsPacketSize]), b, ptmp0);
+              b = rhs.getVectorMapper(j+RhsPacketSize, 0).template load<RhsPacket, Aligned>(0);
+              ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j+LhsPacketSize), b, ptmp0);
              ptmp1 = pcj.pmadd(A11, b, ptmp1);
              ptmp2 = pcj.pmadd(A12, b, ptmp2);
              ptmp3 = pcj.pmadd(A13, b, ptmp3);
            }
          }
          for (; j<alignedSize; j+=RhsPacketSize)
 
      EIGEN_ALIGN16 ResScalar tmp0 = ResScalar(0);
      ResPacket ptmp0 = pset1<ResPacket>(tmp0);
-      const LhsScalar* lhs0 = lhs + i*lhsStride;
+      const LhsScalars lhs0 = lhs.getVectorMapper(i, 0);
      // process first unaligned result's coeffs
      // FIXME this loop get vectorized by the compiler !
      for (Index j=0; j<alignedStart; ++j)
-        tmp0 += cj.pmul(lhs0[j], rhs[j]);
+        tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
 
      if (alignedSize>alignedStart)
      {
        // process aligned rhs coeffs
-        if ((size_t(lhs0+alignedStart)%sizeof(LhsPacket))==0)
+        if (lhs0.template aligned<LhsPacket>(alignedStart))
          for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
-            ptmp0 = pcj.pmadd(pload<LhsPacket>(&lhs0[j]), pload<RhsPacket>(&rhs[j]), ptmp0);
+            ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Aligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
        else
          for (Index j = alignedStart;j<alignedSize;j+=RhsPacketSize)
-            ptmp0 = pcj.pmadd(ploadu<LhsPacket>(&lhs0[j]), pload<RhsPacket>(&rhs[j]), ptmp0);
+            ptmp0 = pcj.pmadd(lhs0.template load<LhsPacket, Unaligned>(j), rhs.getVectorMapper(j, 0).template load<RhsPacket, Aligned>(0), ptmp0);
        tmp0 += predux(ptmp0);
      }
 
      // process remaining scalars
      // FIXME this loop get vectorized by the compiler !
      for (Index j=alignedSize; j<depth; ++j)
-        tmp0 += cj.pmul(lhs0[j], rhs[j]);
+        tmp0 += cj.pmul(lhs0(j), rhs(j, 0));
diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h
--- a/Eigen/src/Core/products/TriangularMatrixVector.h
+++ b/Eigen/src/Core/products/TriangularMatrixVector.h
  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
  typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
-  
+
  typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
  const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
  typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
@@ -51,6 +51,9 @@ EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>::run(
  typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
  ResMap res(_res,rows);
 
+  typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
+  typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
+
  for (Index pi=0; pi<size; pi+=PanelWidth)
 
    if (r>0)
    {
      Index s = IsLower ? pi+actualPanelWidth : 0;
-      general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
+      general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
          r, actualPanelWidth,
-          &lhs.coeffRef(s,pi), lhsStride,
-          &rhs.coeffRef(pi), rhsIncr,
+          LhsMapper(&lhs.coeffRef(s,pi), lhsStride),
+          RhsMapper(&rhs.coeffRef(pi), rhsIncr),
          &res.coeffRef(s), resIncr,
          alpha);
    }
  }
  if((!IsLower) && cols>size)
  {
-    general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
+    general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
        rows, cols-size,
-        &lhs.coeffRef(0,size), lhsStride,
-        &rhs.coeffRef(size), rhsIncr,
+        LhsMapper(&lhs.coeffRef(0,size), lhsStride),
+        RhsMapper(&rhs.coeffRef(size), rhsIncr),
        _res, resIncr, alpha);
  }
 }
@@ -118,7 +121,10 @@ EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>::run(
  typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
  ResMap res(_res,rows,InnerStride<>(resIncr));
-  
+
+  typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
+  typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
+
  for (Index pi=0; pi<diagSize; pi+=PanelWidth)
 
    if (r>0)
    {
      Index s = IsLower ? 0 : pi + actualPanelWidth;
-      general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs,BuiltIn>::run(
+      general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs,BuiltIn>::run(
          actualPanelWidth, r,
-          &lhs.coeffRef(pi,s), lhsStride,
-          &rhs.coeffRef(s), rhsIncr,
+          LhsMapper(&lhs.coeffRef(pi,s), lhsStride),
+          RhsMapper(&rhs.coeffRef(s), rhsIncr),
          &res.coeffRef(pi), resIncr,
          alpha);
    }
  }
  if(IsLower && rows>diagSize)
  {
-    general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
+    general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjLhs,RhsScalar,RhsMapper,ConjRhs>::run(
        rows-diagSize, cols,
-        &lhs.coeffRef(diagSize,0), lhsStride,
-        &rhs.coeffRef(0), rhsIncr,
+        LhsMapper(&lhs.coeffRef(diagSize,0), lhsStride),
+        RhsMapper(&rhs.coeffRef(0), rhsIncr),
        &res.coeffRef(diagSize), resIncr, alpha);
  }
 }
@@ -184,7 +190,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
  template<typename Dest> void scaleAndAddTo(Dest& dst, const Scalar& alpha) const
  {
    eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
-  
+
    internal::trmv_selector<(int(internal::traits<Dest>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
  }
 };
@@ -211,7 +217,7 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
 namespace internal {
 
 // TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
-  
+
 template<> struct trmv_selector<ColMajor>
 {
  template<int Mode, typename Lhs, typename Rhs, typename Dest>
@@ -247,7 +253,7 @@ template<> struct trmv_selector<ColMajor>
    bool alphaIsCompatible = (!ComplexByReal) || (numext::imag(actualAlpha)==RealScalar(0));
    bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
-    
+
    RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
 
    ei_declare_aligned_stack_constructed_variable(ResScalar,actualDestPtr,dest.size(),
@@ -267,7 +273,7 @@ template<> struct trmv_selector<ColMajor>
      else
        MappedDest(actualDestPtr, dest.size()) = dest;
    }
-    
+
    internal::triangular_matrix_vector_product
 #endif
      Map<typename ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
    }
-    
+
    internal::triangular_matrix_vector_product
diff --git a/Eigen/src/Core/products/TriangularSolverVector.h b/Eigen/src/Core/products/TriangularSolverVector.h
--- a/Eigen/src/Core/products/TriangularSolverVector.h
+++ b/Eigen/src/Core/products/TriangularSolverVector.h
      ::run(size, _lhs, lhsStride, rhs);
  }
 };
-    
+
 // forward and backward substitution, row-major, rhs is a vector
 template<typename LhsScalar, typename RhsScalar, typename Index, int Mode, bool Conjugate>
 struct triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,Mode,Conjugate,RowMajor>
@@ -37,6 +37,10 @@ struct triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,Mode,Conjugate,RowMajor>
  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
+
+  typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
+  typedef const_blas_data_mapper<RhsScalar,Index,ColMajor> RhsMapper;
+
  typename internal::conditional<
                      Conjugate,
                      const CwiseUnaryOp<internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
@@ -58,10 +62,10 @@ struct triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,Mode,Conjugate,RowMajor>
        Index startRow = IsLower ? pi : pi-actualPanelWidth;
        Index startCol = IsLower ? 0 : pi;
 
-        general_matrix_vector_product<Index,LhsScalar,RowMajor,Conjugate,RhsScalar,false>::run(
+        general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
            actualPanelWidth, r,
-            &lhs.coeffRef(startRow,startCol), lhsStride,
-            rhs + startCol, 1,
+            LhsMapper(&lhs.coeffRef(startRow,startCol), lhsStride),
+            RhsMapper(rhs + startCol, 1),
            rhs + startRow, 1,
            RhsScalar(-1));
      }
@@ -72,7 +76,7 @@ struct triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,Mode,Conjugate,RowMajor>
        if (k>0)
          rhs[i] -= (cjLhs.row(i).segment(s,k).transpose().cwiseProduct(Map<const Matrix<RhsScalar,Dynamic,1> >(rhs+s,k))).sum();
-        
+
        if(!(Mode & UnitDiag))
          rhs[i] /= cjLhs(i,i);
      }
@@ -91,6 +95,8 @@ struct triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,Mode,Conjugate,ColMajor>
  typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(_lhs,size,size,OuterStride<>(lhsStride));
 
+  typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
+  typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
  typename internal::conditional<Conjugate,
                                 const CwiseUnaryOp<internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
                                 const LhsMap&
@@ -122,10 +128,10 @@ struct triangular_solve_vector<LhsScalar,RhsScalar,Index,OnTheLeft,Mode,Conjugate,ColMajor>
-        general_matrix_vector_product<Index,LhsScalar,ColMajor,Conjugate,RhsScalar,false>::run(
+        general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,Conjugate,RhsScalar,RhsMapper,false>::run(
            r, actualPanelWidth,
-            &lhs.coeffRef(endBlock,startBlock), lhsStride,
-            rhs+startBlock, 1,
+            LhsMapper(&lhs.coeffRef(endBlock,startBlock), lhsStride),
+            RhsMapper(rhs+startBlock, 1),
            rhs+endBlock, 1, RhsScalar(-1));
      }
    }
diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h
index 25a62d528..c4881b8da 100644
--- a/Eigen/src/Core/util/BlasUtil.h
+++ b/Eigen/src/Core/util/BlasUtil.h
@@ -34,7 +34,9 @@
 template< typename Index,
          typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
          typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
          int ResStorageOrder>
 struct general_matrix_matrix_product;
 
-template<typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, bool ConjugateRhs, int Version=Specialized>
+template<typename Index, typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
+         typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
 struct general_matrix_vector_product;
 
@@ -118,13 +120,35 @@ template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
  static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
 };
 
+template <typename Scalar, typename Index>
+class BlasVectorMapper {
+  public:
+  EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}
+
+  EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
+    return m_data[i];
+  }
+
+  template <typename Packet, int AlignmentType>
+  EIGEN_ALWAYS_INLINE Packet load(Index i) const {
+    return ploadt<Packet, AlignmentType>(m_data + i);
+  }
+
+  template <typename Packet>
+  bool aligned(Index i) const {
+    return (size_t(m_data+i)%sizeof(Packet))==0;
+  }
+
+  protected:
+  Scalar* m_data;
+};
+
 template<typename Scalar, typename Index, int AlignmentType>
-class MatrixLinearMapper {
+class BlasLinearMapper {
  public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;
 
-  EIGEN_ALWAYS_INLINE MatrixLinearMapper(Scalar *data) : m_data(data) {}
+  EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
 
  EIGEN_ALWAYS_INLINE void prefetch(int i) const {
    internal::prefetch(&operator()(i));
@@ -157,7 +181,8 @@ class blas_data_mapper {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;
 
-  typedef MatrixLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
+  typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
+  typedef BlasVectorMapper<Scalar, Index> VectorMapper;
 
  EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
 
@@ -170,6 +195,11 @@ class blas_data_mapper {
    return LinearMapper(&operator()(i, j));
  }
 
+  EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
+    return VectorMapper(&operator()(i, j));
+  }
+
+
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
    return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
@@ -193,6 +223,15 @@ class blas_data_mapper {
    return pgather<Scalar, Packet>(&operator()(i, j), m_stride);
  }
 
+  const Index stride() const { return m_stride; }
+
+  Index firstAligned(Index size) const {
+    if (size_t(m_data)%sizeof(Scalar)) {
+      return -1;
+    }
+    return internal::first_aligned(m_data, size);
+  }
+
  protected:
  Scalar* EIGEN_RESTRICT m_data;
  const Index m_stride;
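
--
Usage sketch (not part of the diff above). The generalized kernel no longer
receives raw pointer + stride pairs for the lhs and rhs; callers wrap their
pointers in const_blas_data_mapper objects first, exactly as the gemv_selector
hunks do. Below is a minimal, self-contained illustration of that calling
convention, assuming float scalars; the names gemv_sketch, lhsPtr, rhsPtr and
resPtr are hypothetical placeholders, not identifiers from the patch:

#include <Eigen/Core>

// Computes res += alpha * lhs * rhs with the generalized kernel, where lhs is
// a column-major rows-by-cols matrix and rhs a unit-stride vector.
void gemv_sketch(Eigen::DenseIndex rows, Eigen::DenseIndex cols,
                 const float* lhsPtr, Eigen::DenseIndex lhsStride,
                 const float* rhsPtr, float* resPtr, float alpha)
{
  using namespace Eigen;
  // The mappers abstract how coefficients are stored: the stride passed here
  // is the outer (column) stride of the lhs, and the element increment of
  // the rhs, mirroring the gemv_selector hunks in GeneralProduct.h.
  typedef internal::const_blas_data_mapper<float, DenseIndex, ColMajor> LhsMapper;
  typedef internal::const_blas_data_mapper<float, DenseIndex, RowMajor> RhsMapper;

  internal::general_matrix_vector_product
      <DenseIndex, float, LhsMapper, ColMajor, false, float, RhsMapper, false>::run(
        rows, cols,
        LhsMapper(lhsPtr, lhsStride),
        RhsMapper(rhsPtr, 1),
        resPtr, 1,   // the result side still takes a raw pointer + increment
        alpha);
}

Since the kernel now only talks to the mapper interface (operator(), stride(),
firstAligned(), getVectorMapper() and the packet-level load<Packet, Alignment>()),
any type implementing that interface can feed it; presumably this is what lets
code whose operands are not plain strided arrays (such as the tensor
contraction paths this generalization appears written for) reuse the same
matrix-vector kernel.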