diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index c7f4d4f47..ed6234c37 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -20,8 +20,9 @@ template class level3_blocking; template< typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, - typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs> -struct general_matrix_matrix_product + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, + int ResInnerStride> +struct general_matrix_matrix_product { typedef gebp_traits Traits; @@ -30,7 +31,7 @@ struct general_matrix_matrix_product& blocking, GemmParallelInfo* info = 0) @@ -39,8 +40,8 @@ struct general_matrix_matrix_product - ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info); + ColMajor,ResInnerStride> + ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking,info); } }; @@ -49,8 +50,9 @@ struct general_matrix_matrix_product -struct general_matrix_matrix_product + typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, + int ResInnerStride> +struct general_matrix_matrix_product { typedef gebp_traits Traits; @@ -59,17 +61,17 @@ typedef typename ScalarBinaryOpTraits::ReturnType ResScala static void run(Index rows, Index cols, Index depth, const LhsScalar* _lhs, Index lhsStride, const RhsScalar* _rhs, Index rhsStride, - ResScalar* _res, Index resStride, + ResScalar* _res, Index resIncr, Index resStride, ResScalar alpha, level3_blocking& blocking, GemmParallelInfo* info = 0) { typedef const_blas_data_mapper LhsMapper; typedef const_blas_data_mapper RhsMapper; - typedef blas_data_mapper ResMapper; - LhsMapper lhs(_lhs,lhsStride); - RhsMapper rhs(_rhs,rhsStride); - ResMapper res(_res, resStride); + typedef blas_data_mapper ResMapper; + LhsMapper lhs(_lhs, lhsStride); + RhsMapper rhs(_rhs, rhsStride); + ResMapper res(_res, resStride, resIncr); Index kc = blocking.kc(); // cache block size along the K direction Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction @@ -226,7 +228,7 @@ struct gemm_functor Gemm::run(rows, cols, m_lhs.cols(), &m_lhs.coeffRef(row,0), m_lhs.outerStride(), &m_rhs.coeffRef(0,col), m_rhs.outerStride(), - (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(), + (Scalar*)&(m_dest.coeffRef(row,col)), m_dest.innerStride(), m_dest.outerStride(), m_actualAlpha, m_blocking, info); } @@ -476,7 +478,8 @@ struct generic_product_impl Index, LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate), RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate), - (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>, + (Dest::Flags&RowMajorBit) ? RowMajor : ColMajor, + Dest::InnerStrideAtCompileTime>, ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor; BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true); diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h index b0f6b0d5b..71abf4013 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h @@ -51,20 +51,22 @@ template< \ typename Index, \ int LhsStorageOrder, bool ConjugateLhs, \ int RhsStorageOrder, bool ConjugateRhs> \ -struct general_matrix_matrix_product \ +struct general_matrix_matrix_product \ { \ typedef gebp_traits Traits; \ \ static void run(Index rows, Index cols, Index depth, \ const EIGTYPE* _lhs, Index lhsStride, \ const EIGTYPE* _rhs, Index rhsStride, \ - EIGTYPE* res, Index resStride, \ + EIGTYPE* res, Index resIncr, Index resStride, \ EIGTYPE alpha, \ level3_blocking& /*blocking*/, \ GemmParallelInfo* /*info = 0*/) \ { \ using std::conj; \ \ + EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \ + eigen_assert(resIncr == 1); \ char transa, transb; \ BlasIndex m, n, k, lda, ldb, ldc; \ const EIGTYPE *a, *b; \ diff --git a/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h index a25197ab0..a01ac0588 100644 --- a/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +++ b/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h @@ -124,8 +124,8 @@ struct product_triangular_matrix_matrix_trmm(); \ BlasIndex aStride = convert_index(aa_tmp.outerStride()); \ gemm_blocking_space gemm_blocking(_rows,_cols,_depth, 1, true); \ - general_matrix_matrix_product::run( \ - rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ + general_matrix_matrix_product::run( \ + rows, cols, depth, aa_tmp.data(), aStride, _rhs, 1, rhsStride, res, resStride, alpha, gemm_blocking, 0); \ \ /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \ } \ @@ -241,8 +241,8 @@ struct product_triangular_matrix_matrix_trmm(); \ BlasIndex aStride = convert_index(aa_tmp.outerStride()); \ gemm_blocking_space gemm_blocking(_rows,_cols,_depth, 1, true); \ - general_matrix_matrix_product::run( \ - rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \ + general_matrix_matrix_product::run( \ + rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \ \ /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \ } \ diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index 6e6ee119b..3dff9bc9b 100755 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -31,7 +31,7 @@ template< typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, - int ResStorageOrder> + int ResStorageOrder, int ResInnerStride> struct general_matrix_matrix_product; template +class BlasLinearMapper; + template -class BlasLinearMapper { +class BlasLinearMapper { public: typedef typename packet_traits::type Packet; typedef typename packet_traits::half HalfPacket; - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {} + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data, Index incr=1) + : m_data(data) + { + EIGEN_ONLY_USED_FOR_DEBUG(incr); + eigen_assert(incr==1); + } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const { internal::prefetch(&operator()(i)); @@ -188,16 +196,25 @@ class BlasLinearMapper { }; // Lightweight helper class to access matrix coefficients. -template -class blas_data_mapper { - public: +template +class blas_data_mapper; + +template +class blas_data_mapper +{ +public: typedef typename packet_traits::type Packet; typedef typename packet_traits::half HalfPacket; typedef BlasLinearMapper LinearMapper; typedef BlasVectorMapper VectorMapper; - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {} + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1) + : m_data(data), m_stride(stride) + { + EIGEN_ONLY_USED_FOR_DEBUG(incr); + eigen_assert(incr==1); + } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper getSubMapper(Index i, Index j) const { @@ -251,6 +268,90 @@ class blas_data_mapper { const Index m_stride; }; +// Implementation of non-natural increment (i.e. inner-stride != 1) +// The exposed API is not complete yet compared to the Incr==1 case +// because some features makes less sense in this case. +template +class BlasLinearMapper +{ +public: + typedef typename packet_traits::type Packet; + typedef typename packet_traits::half HalfPacket; + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data,Index incr) : m_data(data), m_incr(incr) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const { + internal::prefetch(&operator()(i)); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const { + return m_data[i*m_incr.value()]; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { + return pgather(m_data + i*m_incr.value(), m_incr.value()); + } + + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType &p) const { + pscatter(m_data + i*m_incr.value(), p, m_incr.value()); + } + +protected: + Scalar *m_data; + const internal::variable_if_dynamic m_incr; +}; + +template +class blas_data_mapper +{ +public: + typedef typename packet_traits::type Packet; + typedef typename packet_traits::half HalfPacket; + + typedef BlasLinearMapper LinearMapper; + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper + getSubMapper(Index i, Index j) const { + return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value()); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { + return LinearMapper(&operator()(i, j), m_incr.value()); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const { + return m_data[StorageOrder==RowMajor ? j*m_incr.value() + i*m_stride : i*m_incr.value() + j*m_stride]; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const { + return pgather(&operator()(i, j),m_incr.value()); + } + + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const { + return pgather(&operator()(i, j),m_incr.value()); + } + + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const { + pscatter(&operator()(i, j), p, m_stride); + } + + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const { + return pgather(&operator()(i, j), m_stride); + } + +protected: + Scalar* EIGEN_RESTRICT m_data; + const Index m_stride; + const internal::variable_if_dynamic m_incr; +}; + // lightweight helper class to access matrix coefficients (const version) template class const_blas_data_mapper : public blas_data_mapper { diff --git a/blas/level3_impl.h b/blas/level3_impl.h index 6c802cd5f..383b6dbf3 100644 --- a/blas/level3_impl.h +++ b/blas/level3_impl.h @@ -13,28 +13,28 @@ int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc) { // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n"; - typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking&, Eigen::internal::GemmParallelInfo*); + typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex, Scalar, internal::level3_blocking&, Eigen::internal::GemmParallelInfo*); static const functype func[12] = { // array index: NOTR | (NOTR << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), // array index: TR | (NOTR << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), // array index: ADJ | (NOTR << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), 0, // array index: NOTR | (TR << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), // array index: TR | (TR << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), // array index: ADJ | (TR << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), 0, // array index: NOTR | (ADJ << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), // array index: TR | (ADJ << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), // array index: ADJ | (ADJ << 2) - (internal::general_matrix_matrix_product::run), + (internal::general_matrix_matrix_product::run), 0 }; @@ -71,7 +71,7 @@ int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const internal::gemm_blocking_space blocking(*m,*n,*k,1,true); int code = OP(*opa) | (OP(*opb) << 2); - func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0); + func[code](*m, *n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking, 0); return 0; } diff --git a/test/product.h b/test/product.h index 83b9ba098..a4e0c569e 100644 --- a/test/product.h +++ b/test/product.h @@ -239,4 +239,19 @@ template void product(const MatrixType& m) VERIFY_IS_APPROX(square * (square*square).conjugate(), square * square.conjugate() * square.conjugate()); } + // destination with a non-default inner-stride + // see bug 1741 + if(!MatrixType::IsRowMajor) + { + typedef Matrix MatrixX; + MatrixX buffer(2*rows,2*rows); + Map > map1(buffer.data(),rows,rows,Stride(2*rows,2)); + buffer.setZero(); + VERIFY_IS_APPROX(map1 = m1 * m2.transpose(), (m1 * m2.transpose()).eval()); + buffer.setZero(); + VERIFY_IS_APPROX(map1.noalias() = m1 * m2.transpose(), (m1 * m2.transpose()).eval()); + buffer.setZero(); + VERIFY_IS_APPROX(map1.noalias() += m1 * m2.transpose(), (m1 * m2.transpose()).eval()); + } + }