Mirror of https://gitlab.com/libeigen/eigen.git, synced 2025-08-03 02:30:38 +08:00

Commit b052ec6992: Merged eigen/eigen into default
@@ -163,6 +163,7 @@ using std::ptrdiff_t;
// Generic half float support
#include "src/Core/arch/Default/Half.h"
#include "src/Core/arch/Default/TypeCasting.h"
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"

#if defined EIGEN_VECTORIZE_AVX512
#include "src/Core/arch/SSE/PacketMath.h"

@@ -226,7 +227,10 @@ using std::ptrdiff_t;
#include "src/Core/arch/SYCL/TypeCasting.h"
#endif
#endif

#include "src/Core/arch/Default/Settings.h"
// This file provides generic implementations valid for scalar as well
#include "src/Core/arch/Default/GenericPacketMathFunctions.h"

#include "src/Core/functors/TernaryFunctors.h"
#include "src/Core/functors/BinaryFunctors.h"
@@ -1118,11 +1118,8 @@ struct unary_evaluator<Block<ArgType, BlockRows, BlockCols, InnerPanel>, IndexBa

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType coeff(Index index) const
{
if (ForwardLinearAccess)
return m_argImpl.coeff(m_linear_offset.value() + index);
else
return coeff(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0);
{
return linear_coeff_impl(index, bool_constant<ForwardLinearAccess>());
}

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE

@@ -1133,11 +1130,8 @@ struct unary_evaluator<Block<ArgType, BlockRows, BlockCols, InnerPanel>, IndexBa

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& coeffRef(Index index)
{
if (ForwardLinearAccess)
return m_argImpl.coeffRef(m_linear_offset.value() + index);
else
return coeffRef(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0);
{
return linear_coeffRef_impl(index, bool_constant<ForwardLinearAccess>());
}

template<int LoadMode, typename PacketType>

@@ -1178,6 +1172,28 @@ struct unary_evaluator<Block<ArgType, BlockRows, BlockCols, InnerPanel>, IndexBa
}

protected:
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType linear_coeff_impl(Index index, internal::true_type /* ForwardLinearAccess */) const
{
return m_argImpl.coeff(m_linear_offset.value() + index);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
CoeffReturnType linear_coeff_impl(Index index, internal::false_type /* not ForwardLinearAccess */) const
{
return coeff(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0);
}

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& linear_coeffRef_impl(Index index, internal::true_type /* ForwardLinearAccess */)
{
return m_argImpl.coeffRef(m_linear_offset.value() + index);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar& linear_coeffRef_impl(Index index, internal::false_type /* not ForwardLinearAccess */)
{
return coeffRef(RowsAtCompileTime == 1 ? 0 : index, RowsAtCompileTime == 1 ? index : 0);
}

evaluator<ArgType> m_argImpl;
const variable_if_dynamic<Index, (ArgType::RowsAtCompileTime == 1 && BlockRows==1) ? 0 : Dynamic> m_startRow;
const variable_if_dynamic<Index, (ArgType::ColsAtCompileTime == 1 && BlockCols==1) ? 0 : Dynamic> m_startCol;
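The change above replaces the runtime if (ForwardLinearAccess) branch with tag dispatch on bool_constant<ForwardLinearAccess>, so each Block evaluator only instantiates the access path it can actually use. A minimal standalone sketch of the same idiom, using illustrative names rather than Eigen's:

#include <type_traits>
#include <cassert>

// Illustrative container with two access strategies selected at compile time.
template <bool ForwardLinearAccess>
struct block_like {
  const double* data;
  long offset;

  // Public entry point: dispatch on a compile-time boolean tag.
  double coeff(long index) const {
    return coeff_impl(index, std::integral_constant<bool, ForwardLinearAccess>());
  }

  // Selected when ForwardLinearAccess is true: plain offset arithmetic.
  double coeff_impl(long index, std::true_type) const {
    return data[offset + index];
  }
  // Selected when ForwardLinearAccess is false: fall back to another scheme.
  double coeff_impl(long index, std::false_type) const {
    return data[2 * index];  // placeholder for the 2-D (row, col) fallback
  }
};

int main() {
  const double values[] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0};
  block_like<true>  a{values, 1};
  block_like<false> b{values, 0};
  assert(a.coeff(2) == 3.0);  // values[1 + 2]
  assert(b.coeff(2) == 4.0);  // values[2 * 2]
  return 0;
}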
@@ -542,7 +542,7 @@ Packet pexpm1(const Packet& a) { return numext::expm1(a); }

/** \internal \returns the log of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet plog(const Packet& a) { using std::log; return log(a); }
Packet plog(const Packet& a) { EIGEN_USING_STD(log); return log(a); }

/** \internal \returns the log1p of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS

@@ -554,7 +554,7 @@ Packet plog10(const Packet& a) { using std::log10; return log10(a); }

/** \internal \returns the square-root of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet psqrt(const Packet& a) { using std::sqrt; return sqrt(a); }
Packet psqrt(const Packet& a) { EIGEN_USING_STD(sqrt); return sqrt(a); }

/** \internal \returns the reciprocal square-root of \a a (coeff-wise) */
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
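The old/new pairs above swap a bare using std::log for the EIGEN_USING_STD(log) macro, but the underlying idiom is the same in both versions: import the std overload, then make an unqualified call so argument-dependent lookup can still find overloads for user-defined scalar types. A small sketch of that idiom, with a made-up custom scalar:

#include <cmath>
#include <iostream>

namespace myscalars {
  struct dual { double value, deriv; };
  // ADL finds this overload because dual lives in this namespace.
  dual sqrt(const dual& x) {
    double s = std::sqrt(x.value);
    return dual{s, x.deriv / (2.0 * s)};
  }
}

// Generic helper in the spirit of psqrt: works for double and for dual.
template <typename Scalar>
Scalar generic_sqrt(const Scalar& a) {
  using std::sqrt;   // bring in the std overloads...
  return sqrt(a);    // ...but keep the call unqualified so ADL can kick in
}

int main() {
  std::cout << generic_sqrt(2.0) << "\n";                             // uses std::sqrt
  std::cout << generic_sqrt(myscalars::dual{4.0, 1.0}).value << "\n"; // uses myscalars::sqrt
  return 0;
}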
@@ -19,7 +19,7 @@ namespace internal {
template<typename LhsScalar, typename RhsScalar, typename Index, int Side, int Mode, bool Conjugate, int StorageOrder>
struct triangular_solve_vector;

template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder>
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix;

// small helper struct extracting some traits on the underlying solver operation

@@ -98,8 +98,8 @@ struct triangular_solver_selector<Lhs,Rhs,Side,Mode,NoUnrolling,Dynamic>
BlockingType blocking(rhs.rows(), rhs.cols(), size, 1, false);

triangular_solve_matrix<Scalar,Index,Side,Mode,LhsProductTraits::NeedToConjugate,(int(Lhs::Flags) & RowMajorBit) ? RowMajor : ColMajor,
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor>
::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.outerStride(), blocking);
(Rhs::Flags&RowMajorBit) ? RowMajor : ColMajor, Rhs::InnerStrideAtCompileTime>
::run(size, othersize, &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &rhs.coeffRef(0,0), rhs.innerStride(), rhs.outerStride(), blocking);
}
};
@ -12,8 +12,6 @@
|
||||
#ifndef EIGEN_MATH_FUNCTIONS_ALTIVEC_H
|
||||
#define EIGEN_MATH_FUNCTIONS_ALTIVEC_H
|
||||
|
||||
#include "../Default/GenericPacketMathFunctions.h"
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
@@ -13,6 +13,9 @@
* Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/
*/

#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H

namespace Eigen {
namespace internal {

@@ -553,7 +556,7 @@ Packet pcos_float(const Packet& x)
*/
template <typename Packet, int N>
struct ppolevl {
static EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
return pmadd(ppolevl<Packet, N-1>::run(x, coeff), x, pset1<Packet>(coeff[N]));
}

@@ -561,7 +564,7 @@ struct ppolevl {

template <typename Packet>
struct ppolevl<Packet, 0> {
static EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits<Packet>::type coeff[]) {
EIGEN_UNUSED_VARIABLE(x);
return pset1<Packet>(coeff[0]);
}

@@ -569,3 +572,5 @@ struct ppolevl<Packet, 0> {

} // end namespace internal
} // end namespace Eigen

#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H
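The ppolevl struct touched above evaluates a degree-N polynomial by a Horner recursion unrolled at compile time, with coeff[0] holding the highest-degree coefficient. A plain-scalar sketch of the same recursion, using doubles instead of Eigen packets and illustrative names:

#include <cassert>

// Compile-time-unrolled Horner evaluation mirroring the ppolevl recursion:
// coeff[0] is the coefficient of the highest power, coeff[N] the constant term.
template <int N>
struct horner {
  static double run(double x, const double coeff[]) {
    return horner<N - 1>::run(x, coeff) * x + coeff[N];
  }
};

template <>
struct horner<0> {
  static double run(double /*x*/, const double coeff[]) {
    return coeff[0];
  }
};

int main() {
  // p(x) = 2*x^2 + 3*x + 1
  const double c[] = {2.0, 3.0, 1.0};
  assert(horner<2>::run(2.0, c) == 15.0);  // 2*4 + 3*2 + 1
  return 0;
}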
Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h (new file, 69 lines)

@@ -0,0 +1,69 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2019 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H

namespace Eigen {
namespace internal {

// Forward declarations of the generic math functions
// implemented in GenericPacketMathFunctions.h
// This is needed to workaround a circular dependency.

template<typename Packet> EIGEN_STRONG_INLINE Packet
pfrexp_float(const Packet& a, Packet& exponent);

template<typename Packet> EIGEN_STRONG_INLINE Packet
pldexp_float(Packet a, Packet exponent);

/** \internal \returns log(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet plog_float(const Packet _x);

/** \internal \returns log(1 + x) */
template<typename Packet>
Packet generic_plog1p(const Packet& x);

/** \internal \returns exp(x)-1 */
template<typename Packet>
Packet generic_expm1(const Packet& x);

/** \internal \returns exp(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pexp_float(const Packet _x);

/** \internal \returns exp(x) for double precision real numbers */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pexp_double(const Packet _x);

/** \internal \returns sin(x) for single precision float */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet psin_float(const Packet& x);

/** \internal \returns cos(x) for single precision float */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
EIGEN_UNUSED
Packet pcos_float(const Packet& x);

template <typename Packet, int N> struct ppolevl;

} // end namespace internal
} // end namespace Eigen

#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H
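The new header exists only to break a circular include, as its own comment says: other headers need the names of the generic math routines before the full implementations can be included. The same pattern in miniature, compressed into one translation unit with the conceptual file boundaries shown as comments (all names here are illustrative):

// --- generic_fwd.h: forward declarations only, no definitions -------------
template <typename T> T generic_twice(const T& x);

// --- specialized.h: can call the generic routine before it is defined -----
template <typename T>
T specialized_quadruple(const T& x) {
  return generic_twice(generic_twice(x));   // only needs the declaration
}

// --- generic_impl.h: the actual definition, possibly included much later --
template <typename T>
T generic_twice(const T& x) { return x + x; }

#include <cassert>
int main() {
  assert(specialized_quadruple(3) == 12);
  return 0;
}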
@@ -8,8 +8,6 @@
#ifndef EIGEN_MATH_FUNCTIONS_NEON_H
#define EIGEN_MATH_FUNCTIONS_NEON_H

#include "../Default/GenericPacketMathFunctions.h"

namespace Eigen {

namespace internal {

@@ -15,8 +15,6 @@
#ifndef EIGEN_MATH_FUNCTIONS_SSE_H
#define EIGEN_MATH_FUNCTIONS_SSE_H

#include "../Default/GenericPacketMathFunctions.h"

namespace Eigen {

namespace internal {
@@ -20,8 +20,9 @@ template<typename _LhsScalar, typename _RhsScalar> class level3_blocking;
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor>
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride>
{
typedef gebp_traits<RhsScalar,LhsScalar> Traits;

@@ -30,7 +31,7 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
Index rows, Index cols, Index depth,
const LhsScalar* lhs, Index lhsStride,
const RhsScalar* rhs, Index rhsStride,
ResScalar* res, Index resStride,
ResScalar* res, Index resIncr, Index resStride,
ResScalar alpha,
level3_blocking<RhsScalar,LhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0)

@@ -39,8 +40,8 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
general_matrix_matrix_product<Index,
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
ColMajor>
::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
ColMajor,ResInnerStride>
::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking,info);
}
};

@@ -49,8 +50,9 @@ struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLh
template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor>
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride>
struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride>
{

typedef gebp_traits<LhsScalar,RhsScalar> Traits;

@@ -59,17 +61,17 @@ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScala
static void run(Index rows, Index cols, Index depth,
const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsStride,
ResScalar* _res, Index resStride,
ResScalar* _res, Index resIncr, Index resStride,
ResScalar alpha,
level3_blocking<LhsScalar,RhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0)
{
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride);
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor,Unaligned,ResInnerStride> ResMapper;
LhsMapper lhs(_lhs, lhsStride);
RhsMapper rhs(_rhs, rhsStride);
ResMapper res(_res, resStride, resIncr);

Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction

@@ -228,7 +230,7 @@ struct gemm_functor
Gemm::run(rows, cols, m_lhs.cols(),
&m_lhs.coeffRef(row,0), m_lhs.outerStride(),
&m_rhs.coeffRef(0,col), m_rhs.outerStride(),
(Scalar*)&(m_dest.coeffRef(row,col)), m_dest.outerStride(),
(Scalar*)&(m_dest.coeffRef(row,col)), m_dest.innerStride(), m_dest.outerStride(),
m_actualAlpha, m_blocking, info);
}

@@ -498,7 +500,8 @@ struct generic_product_impl<Lhs,Rhs,DenseShape,DenseShape,GemmProduct>
Index,
LhsScalar, (ActualLhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(LhsBlasTraits::NeedToConjugate),
RhsScalar, (ActualRhsTypeCleaned::Flags&RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor>,
(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,
Dest::InnerStrideAtCompileTime>,
ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType> GemmFunctor;

BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);
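Every change in this file threads an extra result inner stride (resIncr, ResInnerStride) through to the result mapper, so the GEMM kernels can write into destinations whose consecutive coefficients are not contiguous in memory. A standalone sketch of the addressing scheme such a mapper implements, as an illustration rather than Eigen's actual blas_data_mapper:

#include <cassert>
#include <vector>

// Column-major view over raw storage with both an inner and an outer stride:
// element (i, j) lives at data[i*innerStride + j*outerStride].
struct strided_result_view {
  double* data;
  long innerStride;  // distance between rows within a column (1 == packed)
  long outerStride;  // distance between consecutive columns

  double& operator()(long i, long j) {
    return data[i * innerStride + j * outerStride];
  }
};

int main() {
  // A 4x3 buffer in which we only address every other row (inner stride 2).
  std::vector<double> buffer(4 * 3, 0.0);
  strided_result_view view{buffer.data(), /*innerStride=*/2, /*outerStride=*/4};

  view(0, 1) = 5.0;   // writes buffer[0*2 + 1*4] = buffer[4]
  view(1, 2) = 7.0;   // writes buffer[1*2 + 2*4] = buffer[10]

  assert(buffer[4] == 5.0 && buffer[10] == 7.0);
  return 0;
}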
@@ -25,51 +25,54 @@ namespace internal {
**********************************************************************/

// forward declarations (defined at the end of this file)
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo>
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo>
struct tribb_kernel;

/* Optimized matrix-matrix product evaluating only one triangular half */
template <typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
int ResStorageOrder, int UpLo, int Version = Specialized>
int ResStorageOrder, int ResInnerStride, int UpLo, int Version = Specialized>
struct general_matrix_matrix_triangular_product;

// as usual if the result is row major => we transpose the product
template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo, int Version>
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,UpLo,Version>
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int UpLo, int Version>
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride,UpLo,Version>
{
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride,
const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride,
const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resIncr, Index resStride,
const ResScalar& alpha, level3_blocking<RhsScalar,LhsScalar>& blocking)
{
general_matrix_matrix_triangular_product<Index,
RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs,
LhsScalar, LhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateLhs,
ColMajor, UpLo==Lower?Upper:Lower>
::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking);
ColMajor, ResInnerStride, UpLo==Lower?Upper:Lower>
::run(size,depth,rhs,rhsStride,lhs,lhsStride,res,resIncr,resStride,alpha,blocking);
}
};

template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs, int UpLo, int Version>
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Version>
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int UpLo, int Version>
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,UpLo,Version>
{
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride,
const RhsScalar* _rhs, Index rhsStride, ResScalar* _res, Index resStride,
const RhsScalar* _rhs, Index rhsStride,
ResScalar* _res, Index resIncr, Index resStride,
const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
{
typedef gebp_traits<LhsScalar,RhsScalar> Traits;

typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride);
ResMapper res(_res, resStride, resIncr);

Index kc = blocking.kc();
Index mc = (std::min)(size,blocking.mc());

@@ -87,7 +90,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, UpLo> sybb;
tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo> sybb;

for(Index k2=0; k2<depth; k2+=kc)
{

@@ -110,7 +113,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc,
(std::min)(size,i2), alpha, -1, -1, 0, 0);

sybb(_res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha);
sybb(_res+resStride*i2 + resIncr*i2, resIncr, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha);

if (UpLo==Upper)
{

@@ -132,7 +135,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
// while the triangular block overlapping the diagonal is evaluated into a
// small temporary buffer which is then accumulated into the result using a
// triangular traversal.
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int UpLo>
template<typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs, int ResInnerStride, int UpLo>
struct tribb_kernel
{
typedef gebp_traits<LhsScalar,RhsScalar,ConjLhs,ConjRhs> Traits;

@@ -141,11 +144,13 @@ struct tribb_kernel
enum {
BlockSize = meta_least_common_multiple<EIGEN_PLAIN_ENUM_MAX(mr,nr),EIGEN_PLAIN_ENUM_MIN(mr,nr)>::ret
};
void operator()(ResScalar* _res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
void operator()(ResScalar* _res, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
{
typedef blas_data_mapper<ResScalar, Index, ColMajor> ResMapper;
ResMapper res(_res, resStride);
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel;
typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper;
ResMapper res(_res, resStride, resIncr);
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1;
gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2;

Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer((internal::constructor_without_unaligned_array_assert()));

@@ -157,32 +162,32 @@ struct tribb_kernel
const RhsScalar* actual_b = blockB+j*depth;

if(UpLo==Upper)
gebp_kernel(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha,
-1, -1, 0, 0);
gebp_kernel1(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha,
-1, -1, 0, 0);

// selfadjoint micro block
{
Index i = j;
buffer.setZero();
// 1 - apply the kernel on the temporary buffer
gebp_kernel(ResMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
-1, -1, 0, 0);
gebp_kernel2(BufferMapper(buffer.data(), BlockSize), blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
-1, -1, 0, 0);

// 2 - triangular accumulation
for(Index j1=0; j1<actualBlockSize; ++j1)
{
ResScalar* r = &res(i, j + j1);
typename ResMapper::LinearMapper r = res.getLinearMapper(i,j+j1);
for(Index i1=UpLo==Lower ? j1 : 0;
UpLo==Lower ? i1<actualBlockSize : i1<=j1; ++i1)
r[i1] += buffer(i1,j1);
r(i1) += buffer(i1,j1);
}
}

if(UpLo==Lower)
{
Index i = j+actualBlockSize;
gebp_kernel(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i,
depth, actualBlockSize, alpha, -1, -1, 0, 0);
gebp_kernel1(res.getSubMapper(i, j), blockA+depth*i, actual_b, size-i,
depth, actualBlockSize, alpha, -1, -1, 0, 0);
}
}
}

@@ -286,11 +291,12 @@ struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
internal::general_matrix_matrix_triangular_product<Index,
typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
IsRowMajor ? RowMajor : ColMajor, UpLo&(Lower|Upper)>
IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo&(Lower|Upper)>
::run(size, depth,
&actualLhs.coeffRef(SkipDiag&&(UpLo&Lower)==Lower ? 1 : 0,0), actualLhs.outerStride(),
&actualRhs.coeffRef(0,SkipDiag&&(UpLo&Upper)==Upper ? 1 : 0), actualRhs.outerStride(),
mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? 1 : mat.outerStride() ) : 0), mat.outerStride(), actualAlpha, blocking);
mat.data() + (SkipDiag ? (bool(IsRowMajor) != ((UpLo&Lower)==Lower) ? mat.innerStride() : mat.outerStride() ) : 0),
mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
}
};
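tribb_kernel, updated above to carry ResInnerStride, computes each diagonal block into a small dense temporary and then accumulates only the relevant triangle into the result. A plain-scalar sketch of that "compute dense, accumulate triangular" step, with illustrative names and dimensions:

#include <cassert>
#include <vector>

// Accumulate only the lower triangle of a dense blockSize x blockSize buffer
// into a column-major result block, mirroring tribb_kernel's second phase.
void accumulate_lower_triangle(const std::vector<double>& buffer, // blockSize x blockSize, col-major
                               double* result, long resStride,    // column stride of the result
                               long blockSize) {
  for (long j = 0; j < blockSize; ++j)
    for (long i = j; i < blockSize; ++i)          // lower triangle only: i >= j
      result[i + j * resStride] += buffer[i + j * blockSize];
}

int main() {
  const long n = 2;
  std::vector<double> buffer = {1.0, 2.0, 3.0, 4.0};  // [[1,3],[2,4]] in column-major order
  std::vector<double> result(n * n, 0.0);
  accumulate_lower_triangle(buffer, result.data(), n, n);
  // Only (0,0), (1,0), (1,1) are touched; (0,1) stays zero.
  assert(result[0] == 1.0 && result[1] == 2.0 && result[2] == 0.0 && result[3] == 4.0);
  return 0;
}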
@@ -40,7 +40,7 @@ namespace internal {
template <typename Index, typename Scalar, int AStorageOrder, bool ConjugateA, int ResStorageOrder, int UpLo>
struct general_matrix_matrix_rankupdate :
general_matrix_matrix_triangular_product<
Index,Scalar,AStorageOrder,ConjugateA,Scalar,AStorageOrder,ConjugateA,ResStorageOrder,UpLo,BuiltIn> {};
Index,Scalar,AStorageOrder,ConjugateA,Scalar,AStorageOrder,ConjugateA,ResStorageOrder,1,UpLo,BuiltIn> {};


// try to go to BLAS specialization

@@ -48,9 +48,9 @@ struct general_matrix_matrix_rankupdate :
template <typename Index, int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs, int UpLo> \
struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,ConjugateLhs, \
Scalar,RhsStorageOrder,ConjugateRhs,ColMajor,UpLo,Specialized> { \
Scalar,RhsStorageOrder,ConjugateRhs,ColMajor,1,UpLo,Specialized> { \
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const Scalar* lhs, Index lhsStride, \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
const Scalar* rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) \
{ \
if ( lhs==rhs && ((UpLo&(Lower|Upper))==UpLo) ) { \
general_matrix_matrix_rankupdate<Index,Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,UpLo> \

@@ -59,8 +59,8 @@ struct general_matrix_matrix_triangular_product<Index,Scalar,LhsStorageOrder,Con
general_matrix_matrix_triangular_product<Index, \
Scalar, LhsStorageOrder, ConjugateLhs, \
Scalar, RhsStorageOrder, ConjugateRhs, \
ColMajor, UpLo, BuiltIn> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resStride,alpha,blocking); \
ColMajor, 1, UpLo, BuiltIn> \
::run(size,depth,lhs,lhsStride,rhs,rhsStride,res,resIncr,resStride,alpha,blocking); \
} \
} \
};
@@ -51,20 +51,22 @@ template< \
typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor> \
struct general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1> \
{ \
typedef gebp_traits<EIGTYPE,EIGTYPE> Traits; \
\
static void run(Index rows, Index cols, Index depth, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE* res, Index resIncr, Index resStride, \
EIGTYPE alpha, \
level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/, \
GemmParallelInfo<Index>* /*info = 0*/) \
{ \
using std::conj; \
\
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
eigen_assert(resIncr == 1); \
char transa, transb; \
BlasIndex m, n, k, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
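The BLAS-backed specializations above are only instantiated with a result inner stride of 1 and additionally assert resIncr == 1 at run time, because a BLAS GEMM only accepts a leading dimension (column stride), not an element increment within a column. A hedged sketch of that guard pattern; the gemm-like function below is a stand-in, not an actual BLAS binding:

#include <cassert>

// Stand-in for a BLAS-style routine: it only understands a leading dimension
// (column stride), so callers must guarantee the result is packed within columns.
void fake_gemm_colmajor(double* c, long ldc, long rows, long cols, double value) {
  for (long j = 0; j < cols; ++j)
    for (long i = 0; i < rows; ++i)
      c[i + j * ldc] += value;
}

// Wrapper mirroring the guard: accept the general (resIncr, resStride) interface
// but require resIncr == 1 before delegating to the packed-column routine.
void gemm_like(double* res, long resIncr, long resStride, long rows, long cols, double value) {
  assert(resIncr == 1 && "this path requires a packed inner stride");
  (void)resIncr;  // only used by the assertion when assertions are disabled
  fake_gemm_colmajor(res, resStride, rows, cols, value);
}

int main() {
  double c[6] = {0};                  // 2x3, column-major, leading dimension 2
  gemm_like(c, /*resIncr=*/1, /*resStride=*/2, 2, 3, 1.0);
  assert(c[0] == 1.0 && c[5] == 1.0);
  return 0;
}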
@@ -294,20 +294,21 @@ struct symm_pack_rhs
template <typename Scalar, typename Index,
int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
int ResStorageOrder>
int ResStorageOrder, int ResInnerStride>
struct product_selfadjoint_matrix;

template <typename Scalar, typename Index,
int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor>
int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,RowMajor,ResInnerStride>
{

static EIGEN_STRONG_INLINE void run(
Index rows, Index cols,
const Scalar* lhs, Index lhsStride,
const Scalar* rhs, Index rhsStride,
Scalar* res, Index resStride,
Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
product_selfadjoint_matrix<Scalar, Index,

@@ -315,33 +316,35 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,LhsSelfAdjoint,Co
RhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsSelfAdjoint,ConjugateRhs),
EIGEN_LOGICAL_XOR(LhsSelfAdjoint,LhsStorageOrder==RowMajor) ? ColMajor : RowMajor,
LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsSelfAdjoint,ConjugateLhs),
ColMajor>
::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha, blocking);
ColMajor,ResInnerStride>
::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
}
};

template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor>
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>
{

static EIGEN_DONT_INLINE void run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* res, Index resStride,
Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};

template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor>::run(
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>::run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* _res, Index resStride,
Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
Index size = rows;

@@ -351,11 +354,11 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,t
typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper;
typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
LhsTransposeMapper lhs_transpose(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride);
ResMapper res(_res, resStride, resIncr);

Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction

@@ -415,26 +418,28 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,t
// matrix * selfadjoint product
template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor>
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride>
struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>
{

static EIGEN_DONT_INLINE void run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* res, Index resStride,
Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};

template <typename Scalar, typename Index,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor>::run(
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run(
Index rows, Index cols,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* _res, Index resStride,
Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
Index size = cols;

@@ -442,9 +447,9 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f
typedef gebp_traits<Scalar,Scalar> Traits;

typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
ResMapper res(_res,resStride);
ResMapper res(_res,resStride, resIncr);

Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction

@@ -520,12 +525,13 @@ struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false>
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,bool(LhsBlasTraits::NeedToConjugate)),
EIGEN_LOGICAL_XOR(RhsIsUpper,internal::traits<Rhs>::Flags &RowMajorBit) ? RowMajor : ColMajor, RhsIsSelfAdjoint,
NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,bool(RhsBlasTraits::NeedToConjugate)),
internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor>
internal::traits<Dest>::Flags&RowMajorBit ? RowMajor : ColMajor,
Dest::InnerStrideAtCompileTime>
::run(
lhs.rows(), rhs.cols(), // sizes
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
&rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
&dst.coeffRef(0,0), dst.outerStride(), // result info
&dst.coeffRef(0,0), dst.innerStride(), dst.outerStride(), // result info
actualAlpha, blocking // alpha
);
}
@@ -44,16 +44,18 @@ namespace internal {
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor,1> \
{\
\
static void run( \
Index rows, Index cols, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE* res, Index resIncr, Index resStride, \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
eigen_assert(resIncr == 1); \
char side='L', uplo='L'; \
BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \

@@ -91,15 +93,17 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLh
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,true,ConjugateLhs,RhsStorageOrder,false,ConjugateRhs,ColMajor,1> \
{\
static void run( \
Index rows, Index cols, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE* res, Index resIncr, Index resStride, \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
eigen_assert(resIncr == 1); \
char side='L', uplo='L'; \
BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \

@@ -167,16 +171,18 @@ EIGEN_BLAS_HEMM_L(scomplex, float, cf, chemm_)
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor,1> \
{\
\
static void run( \
Index rows, Index cols, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE* res, Index resIncr, Index resStride, \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
eigen_assert(resIncr == 1); \
char side='R', uplo='L'; \
BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \

@@ -213,15 +219,17 @@ struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateL
template <typename Index, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor> \
struct product_selfadjoint_matrix<EIGTYPE,Index,LhsStorageOrder,false,ConjugateLhs,RhsStorageOrder,true,ConjugateRhs,ColMajor,1> \
{\
static void run( \
Index rows, Index cols, \
const EIGTYPE* _lhs, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsStride, \
EIGTYPE* res, Index resStride, \
EIGTYPE* res, Index resIncr, Index resStride, \
EIGTYPE alpha, level3_blocking<EIGTYPE, EIGTYPE>& /*blocking*/) \
{ \
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
eigen_assert(resIncr == 1); \
char side='R', uplo='L'; \
BlasIndex m, n, lda, ldb, ldc; \
const EIGTYPE *a, *b; \
@@ -109,10 +109,10 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,false>
internal::general_matrix_matrix_triangular_product<Index,
Scalar, OtherIsRowMajor ? RowMajor : ColMajor, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
Scalar, OtherIsRowMajor ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex,
IsRowMajor ? RowMajor : ColMajor, UpLo>
IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime, UpLo>
::run(size, depth,
&actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(),
mat.data(), mat.outerStride(), actualAlpha, blocking);
mat.data(), mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
}
};
@@ -45,22 +45,24 @@ template <typename Scalar, typename Index,
int Mode, bool LhsIsTriangular,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs,
int ResStorageOrder, int Version = Specialized>
int ResStorageOrder, int ResInnerStride,
int Version = Specialized>
struct product_triangular_matrix_matrix;

template <typename Scalar, typename Index,
int Mode, bool LhsIsTriangular,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int Version>
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular,
LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,RowMajor,Version>
RhsStorageOrder,ConjugateRhs,RowMajor,ResInnerStride,Version>
{
static EIGEN_STRONG_INLINE void run(
Index rows, Index cols, Index depth,
const Scalar* lhs, Index lhsStride,
const Scalar* rhs, Index rhsStride,
Scalar* res, Index resStride,
Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
product_triangular_matrix_matrix<Scalar, Index,

@@ -70,18 +72,19 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular,
ConjugateRhs,
LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
ConjugateLhs,
ColMajor>
::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha, blocking);
ColMajor, ResInnerStride>
::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
}
};

// implements col-major += alpha * op(triangular) * op(general)
template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int Version>
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,ColMajor,Version>
RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
{

typedef gebp_traits<Scalar,Scalar> Traits;

@@ -95,20 +98,21 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,true,
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* res, Index resStride,
Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};

template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int Version>
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,true,
LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,ColMajor,Version>::run(
RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>::run(
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* _res, Index resStride,
Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
// strip zeros

@@ -119,10 +123,10 @@ EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,true,

typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride);
ResMapper res(_res, resStride, resIncr);

Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction

@@ -235,10 +239,11 @@ EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,true,
// implements col-major += alpha * op(general) * op(triangular)
template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int Version>
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,ColMajor,Version>
RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>
{
typedef gebp_traits<Scalar,Scalar> Traits;
enum {

@@ -251,20 +256,21 @@ struct product_triangular_matrix_matrix<Scalar,Index,Mode,false,
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* res, Index resStride,
Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
};

template <typename Scalar, typename Index, int Mode,
int LhsStorageOrder, bool ConjugateLhs,
int RhsStorageOrder, bool ConjugateRhs, int Version>
int RhsStorageOrder, bool ConjugateRhs,
int ResInnerStride, int Version>
EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false,
LhsStorageOrder,ConjugateLhs,
RhsStorageOrder,ConjugateRhs,ColMajor,Version>::run(
RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,Version>::run(
Index _rows, Index _cols, Index _depth,
const Scalar* _lhs, Index lhsStride,
const Scalar* _rhs, Index rhsStride,
Scalar* _res, Index resStride,
Scalar* _res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{
const Index PacketBytes = packet_traits<Scalar>::size*sizeof(Scalar);

@@ -276,10 +282,10 @@ EIGEN_DONT_INLINE void product_triangular_matrix_matrix<Scalar,Index,Mode,false,

typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor> ResMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride);
RhsMapper rhs(_rhs,rhsStride);
ResMapper res(_res, resStride);
ResMapper res(_res, resStride, resIncr);

Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction

@@ -433,12 +439,12 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
Mode, LhsIsTriangular,
(internal::traits<ActualLhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
(internal::traits<ActualRhsTypeCleaned>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
(internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor>
(internal::traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor, Dest::InnerStrideAtCompileTime>
::run(
stripedRows, stripedCols, stripedDepth, // sizes
&lhs.coeffRef(0,0), lhs.outerStride(), // lhs info
&rhs.coeffRef(0,0), rhs.outerStride(), // rhs info
&dst.coeffRef(0,0), dst.outerStride(), // result info
&dst.coeffRef(0,0), dst.innerStride(), dst.outerStride(), // result info
actualAlpha, blocking
);
@@ -46,7 +46,7 @@ template <typename Scalar, typename Index,
struct product_triangular_matrix_matrix_trmm :
product_triangular_matrix_matrix<Scalar,Index,Mode,
LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
RhsStorageOrder, ConjugateRhs, ResStorageOrder, BuiltIn> {};
RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};


// try to go to BLAS specialization

@@ -55,13 +55,15 @@ template <typename Index, int Mode, \
int LhsStorageOrder, bool ConjugateLhs, \
int RhsStorageOrder, bool ConjugateRhs> \
struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,Specialized> { \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,1,Specialized> { \
static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
const Scalar* _rhs, Index rhsStride, Scalar* res, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
eigen_assert(resIncr == 1); \
product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
} \
};

@@ -115,8 +117,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
/* Most likely no benefit to call TRMM or GEMM from BLAS */ \
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
/*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \
} else { \
/* Make sense to call GEMM */ \

@@ -124,8 +126,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, resStride, alpha, gemm_blocking, 0); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \
\
/*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \

@@ -232,8 +234,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
/* Most likely no benefit to call TRMM or GEMM from BLAS*/ \
product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
/*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \
} else { \
/* Make sense to call GEMM */ \

@@ -241,8 +243,8 @@ struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor>::run( \
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, resStride, alpha, gemm_blocking, 0); \
general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \
\
/*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
} \
@ -15,48 +15,48 @@ namespace Eigen {
namespace internal {

// if the rhs is row major, let's transpose the product
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder>
struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor,OtherInnerStride>
{
static void run(
Index size, Index cols,
const Scalar* tri, Index triStride,
Scalar* _other, Index otherStride,
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking)
{
triangular_solve_matrix<
Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft,
(Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper),
NumTraits<Scalar>::IsComplex && Conjugate,
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
::run(size, cols, tri, triStride, _other, otherStride, blocking);
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor, OtherInnerStride>
::run(size, cols, tri, triStride, _other, otherIncr, otherStride, blocking);
}
};
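// Illustrative aside (not from this changeset): the specialization above handles a
// row-major destination by solving the transposed problem from the other side, since a
// row-major matrix reinterpreted as column-major is its transpose. A minimal,
// self-contained sketch of that identity using Eigen's public API:
//
//   #include <Eigen/Dense>
//   #include <iostream>
//
//   int main() {
//     Eigen::Matrix3d T = Eigen::Matrix3d::Random();
//     Eigen::Matrix3d B = Eigen::Matrix3d::Random();
//     // Solve L * X = B with L the lower triangle of T.
//     Eigen::Matrix3d X = T.triangularView<Eigen::Lower>().solve(B);
//     // Same system, transposed: Y * L^T = B^T, solved on the right with the
//     // opposite (Upper) triangle; then X == Y^T up to roundoff.
//     Eigen::Matrix3d Y = T.transpose().triangularView<Eigen::Upper>()
//                          .solve<Eigen::OnTheRight>(B.transpose());
//     std::cout << (X - Y.transpose()).norm() << "\n";  // ~0
//   }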

/* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
*/
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder,int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
{
static EIGEN_DONT_INLINE void run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherStride,
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>::run(
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherStride,
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking)
{
Index cols = otherSize;

typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> TriMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor> OtherMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> OtherMapper;
TriMapper tri(_tri, triStride);
OtherMapper other(_other, otherStride);
OtherMapper other(_other, otherStride, otherIncr);

typedef gebp_traits<Scalar,Scalar> Traits;

@ -128,19 +128,19 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju
{
Scalar b(0);
const Scalar* l = &tri(i,s);
Scalar* r = &other(s,j);
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
for (Index i3=0; i3<k; ++i3)
b += conj(l[i3]) * r[i3];
b += conj(l[i3]) * r(i3);

other(i,j) = (other(i,j) - b)*a;
}
else
{
Scalar b = (other(i,j) *= a);
Scalar* r = &other(s,j);
const Scalar* l = &tri(s,i);
typename OtherMapper::LinearMapper r = other.getLinearMapper(s,j);
typename TriMapper::LinearMapper l = tri.getLinearMapper(s,i);
for (Index i3=0;i3<rs;++i3)
r[i3] -= b * conj(l[i3]);
r(i3) -= b * conj(l(i3));
}
}
}
@ -185,28 +185,28 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conju

/* Optimized triangular solver with multiple left hand sides and the triangular matrix on the right
*/
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>
{
static EIGEN_DONT_INLINE void run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherStride,
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking);
};
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>::run(
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride>
EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,OtherInnerStride>::run(
Index size, Index otherSize,
const Scalar* _tri, Index triStride,
Scalar* _other, Index otherStride,
Scalar* _other, Index otherIncr, Index otherStride,
level3_blocking<Scalar,Scalar>& blocking)
{
Index rows = otherSize;
typedef typename NumTraits<Scalar>::Real RealScalar;

typedef blas_data_mapper<Scalar, Index, ColMajor> LhsMapper;
typedef blas_data_mapper<Scalar, Index, ColMajor, Unaligned, OtherInnerStride> LhsMapper;
typedef const_blas_data_mapper<Scalar, Index, TriStorageOrder> RhsMapper;
LhsMapper lhs(_other, otherStride);
LhsMapper lhs(_other, otherStride, otherIncr);
RhsMapper rhs(_tri, triStride);

typedef gebp_traits<Scalar,Scalar> Traits;
@ -297,24 +297,24 @@ EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conj
{
Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;

Scalar* r = &lhs(i2,j);
typename LhsMapper::LinearMapper r = lhs.getLinearMapper(i2,j);
for (Index k3=0; k3<k; ++k3)
{
Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j));
Scalar* a = &lhs(i2,IsLower ? j+1+k3 : absolute_j2+k3);
typename LhsMapper::LinearMapper a = lhs.getLinearMapper(i2,IsLower ? j+1+k3 : absolute_j2+k3);
for (Index i=0; i<actual_mc; ++i)
r[i] -= a[i] * b;
r(i) -= a(i) * b;
}
if((Mode & UnitDiag)==0)
{
Scalar inv_rjj = RealScalar(1)/conj(rhs(j,j));
for (Index i=0; i<actual_mc; ++i)
r[i] *= inv_rjj;
r(i) *= inv_rjj;
}
}

// pack the just computed part of lhs to A
pack_lhs_panel(blockA, LhsMapper(_other+absolute_j2*otherStride+i2, otherStride),
pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2),
actualPanelWidth, actual_mc,
actual_kc, j2);
}
@ -40,7 +40,7 @@ namespace internal {
// implements LeftSide op(triangular)^-1 * general
#define EIGEN_BLAS_TRSM_L(EIGTYPE, BLASTYPE, BLASFUNC) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor,1> \
{ \
enum { \
IsLower = (Mode&Lower) == Lower, \
@ -51,8 +51,10 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheLeft,Mode,Conjugate,TriStorage
static void run( \
Index size, Index otherSize, \
const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
EIGTYPE* _other, Index otherIncr, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \
EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \
eigen_assert(otherIncr == 1); \
BlasIndex m = convert_index<BlasIndex>(size), n = convert_index<BlasIndex>(otherSize), lda, ldb; \
char side = 'L', uplo, diag='N', transa; \
/* Set alpha_ */ \
@ -99,7 +101,7 @@ EIGEN_BLAS_TRSM_L(scomplex, float, ctrsm_)
// implements RightSide general * op(triangular)^-1
#define EIGEN_BLAS_TRSM_R(EIGTYPE, BLASTYPE, BLASFUNC) \
template <typename Index, int Mode, bool Conjugate, int TriStorageOrder> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor> \
struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor,1> \
{ \
enum { \
IsLower = (Mode&Lower) == Lower, \
@ -110,8 +112,10 @@ struct triangular_solve_matrix<EIGTYPE,Index,OnTheRight,Mode,Conjugate,TriStorag
static void run( \
Index size, Index otherSize, \
const EIGTYPE* _tri, Index triStride, \
EIGTYPE* _other, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
EIGTYPE* _other, Index otherIncr, Index otherStride, level3_blocking<EIGTYPE,EIGTYPE>& /*blocking*/) \
{ \
EIGEN_ONLY_USED_FOR_DEBUG(otherIncr); \
eigen_assert(otherIncr == 1); \
BlasIndex m = convert_index<BlasIndex>(otherSize), n = convert_index<BlasIndex>(size), lda, ldb; \
char side = 'R', uplo, diag='N', transa; \
/* Set alpha_ */ \
@ -31,7 +31,7 @@ template<
typename Index,
typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
int ResStorageOrder>
int ResStorageOrder, int ResInnerStride>
struct general_matrix_matrix_product;

template<typename Index,
@ -155,11 +155,19 @@ class BlasVectorMapper {
Scalar* m_data;
};

template<typename Scalar, typename Index, int AlignmentType, int Incr=1>
class BlasLinearMapper;

template<typename Scalar, typename Index, int AlignmentType>
class BlasLinearMapper
class BlasLinearMapper<Scalar,Index,AlignmentType>
{
public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data, Index incr=1)
: m_data(data)
{
EIGEN_ONLY_USED_FOR_DEBUG(incr);
eigen_assert(incr==1);
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
internal::prefetch(&operator()(i));
@ -184,14 +192,22 @@ protected:
};

// Lightweight helper class to access matrix coefficients.
template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
class blas_data_mapper
template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned, int Incr = 1>
class blas_data_mapper;

template<typename Scalar, typename Index, int StorageOrder, int AlignmentType>
class blas_data_mapper<Scalar,Index,StorageOrder,AlignmentType,1>
{
public:
typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
typedef BlasVectorMapper<Scalar, Index> VectorMapper;

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr=1)
: m_data(data), m_stride(stride)
{
EIGEN_ONLY_USED_FOR_DEBUG(incr);
eigen_assert(incr==1);
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
getSubMapper(Index i, Index j) const {
@ -247,6 +263,86 @@ protected:
const Index m_stride;
};

// Implementation of non-natural increment (i.e. inner-stride != 1)
// The exposed API is not complete yet compared to the Incr==1 case
// because some features make less sense in this case.
template<typename Scalar, typename Index, int AlignmentType, int Incr>
class BlasLinearMapper
{
public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data,Index incr) : m_data(data), m_incr(incr) {}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(int i) const {
internal::prefetch(&operator()(i));
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
return m_data[i*m_incr.value()];
}

template<typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i) const {
return pgather<Scalar,PacketType>(m_data + i*m_incr.value(), m_incr.value());
}

template<typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketType &p) const {
pscatter<Scalar, PacketType>(m_data + i*m_incr.value(), p, m_incr.value());
}

protected:
Scalar *m_data;
const internal::variable_if_dynamic<Index,Incr> m_incr;
};

template<typename Scalar, typename Index, int StorageOrder, int AlignmentType,int Incr>
class blas_data_mapper
{
public:
typedef BlasLinearMapper<Scalar, Index, AlignmentType,Incr> LinearMapper;

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride, Index incr) : m_data(data), m_stride(stride), m_incr(incr) {}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper
getSubMapper(Index i, Index j) const {
return blas_data_mapper(&operator()(i, j), m_stride, m_incr.value());
}

EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
return LinearMapper(&operator()(i, j), m_incr.value());
}

EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
return m_data[StorageOrder==RowMajor ? j*m_incr.value() + i*m_stride : i*m_incr.value() + j*m_stride];
}

template<typename PacketType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketType loadPacket(Index i, Index j) const {
return pgather<Scalar,PacketType>(&operator()(i, j),m_incr.value());
}

template <typename PacketT, int AlignmentT>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i, Index j) const {
return pgather<Scalar,PacketT>(&operator()(i, j),m_incr.value());
}

template<typename SubPacket>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
}

template<typename SubPacket>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
}

protected:
Scalar* EIGEN_RESTRICT m_data;
const Index m_stride;
const internal::variable_if_dynamic<Index,Incr> m_incr;
};
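// Illustrative aside (not from this changeset): the Incr!=1 mappers above address
// element i at m_data[i*incr] and therefore have to gather/scatter packets instead of
// doing contiguous loads. A standalone scalar model of the same idea, using hypothetical
// names, would look like this:
//
//   #include <cstddef>
//   #include <vector>
//
//   template <typename Scalar>
//   struct StridedVectorView {          // plays the role of BlasLinearMapper with Incr != 1
//     Scalar* data;
//     std::ptrdiff_t incr;              // distance, in elements, between logical neighbours
//     Scalar& operator()(std::ptrdiff_t i) const { return data[i * incr]; }
//   };
//
//   int main() {
//     std::vector<double> buffer(8, 0.0);
//     StridedVectorView<double> v{buffer.data(), 2};  // view every second element
//     for (int i = 0; i < 4; ++i) v(i) = i;           // writes buffer[0], [2], [4], [6]
//   }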

// lightweight helper class to access matrix coefficients (const version)
template<typename Scalar, typename Index, int StorageOrder>
class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
@ -180,7 +180,7 @@
#define EIGEN_COMP_ARM 0
#endif

/// \internal EIGEN_COMP_ARM set to 1 if the compiler is ARM Compiler
/// \internal EIGEN_COMP_EMSCRIPTEN set to 1 if the compiler is Emscripten Compiler
#if defined(__EMSCRIPTEN__)
#define EIGEN_COMP_EMSCRIPTEN 1
#else
@ -63,6 +63,15 @@ typedef std::size_t UIntPtr;
struct true_type { enum { value = 1 }; };
struct false_type { enum { value = 0 }; };

template<bool Condition>
struct bool_constant;

template<>
struct bool_constant<true> : true_type {};

template<>
struct bool_constant<false> : false_type {};

template<bool Condition, typename Then, typename Else>
struct conditional { typedef Then type; };

@ -101,7 +101,7 @@ public:
}
else
{
m_value = 0; // this is to avoid a compilation warning
m_value = Scalar(0); // this is to avoid a compilation warning
m_id = -1;
}
return *this;
@ -1341,7 +1341,7 @@ typename SparseMatrix<_Scalar,_Options,_StorageIndex>::Scalar& SparseMatrix<_Sca
}

m_data.index(p) = convert_index(inner);
return (m_data.value(p) = 0);
return (m_data.value(p) = Scalar(0));
}

if(m_data.size() != m_data.allocatedSize())
@ -13,28 +13,28 @@ int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const
const RealScalar *pa, const int *lda, const RealScalar *pb, const int *ldb, const RealScalar *pbeta, RealScalar *pc, const int *ldc)
{
// std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
static const functype func[12] = {
// array index: NOTR  | (NOTR << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,ColMajor,false,ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,ColMajor,false,ColMajor,1>::run),
// array index: TR   | (NOTR << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,false,ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,false,ColMajor,1>::run),
// array index: ADJ  | (NOTR << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,1>::run),
0,
// array index: NOTR | (TR << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,false,ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,false,ColMajor,1>::run),
// array index: TR | (TR << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,false,ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,false,ColMajor,1>::run),
// array index: ADJ | (TR << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,false,ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,false,ColMajor,1>::run),
0,
// array index: NOTR | (ADJ << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,1>::run),
// array index: TR | (ADJ << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,Conj, ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,Conj, ColMajor,1>::run),
// array index: ADJ | (ADJ << 2)
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,Conj, ColMajor>::run),
(internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,Conj, ColMajor,1>::run),
0
};

@ -71,7 +71,7 @@ int EIGEN_BLAS_FUNC(gemm)(const char *opa, const char *opb, const int *m, const
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k,1,true);

int code = OP(*opa) | (OP(*opb) << 2);
func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
func[code](*m, *n, *k, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking, 0);
return 0;
}
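// Illustrative aside (not from this changeset): the func[] table above is indexed by the
// small code computed as OP(*opa) | (OP(*opb) << 2). Assuming the usual wrapper mapping
// NOTR=0, TR=1, ADJ=2 (values taken from the wrapper's OP() helper; names here are only
// for illustration), the encoding can be sketched as:
//
//   enum { NOTR = 0, TR = 1, ADJ = 2 };
//
//   inline int gemm_dispatch_code(int opa, int opb) {
//     return opa | (opb << 2);   // 12 slots, one unused hole after each group of three
//   }
//   // e.g. gemm_dispatch_code(ADJ, TR) == 2 | (1 << 2) == 6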
|
||||
|
||||
@ -79,63 +79,63 @@ int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, c
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
|
||||
{
|
||||
// std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
|
||||
static const functype func[32] = {
|
||||
// array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,ColMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,ColMajor,ColMajor,1>::run),
|
||||
// array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,RowMajor,ColMajor,1>::run),
|
||||
// array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, Conj, RowMajor,ColMajor>::run),\
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, Conj, RowMajor,ColMajor,1>::run),\
|
||||
0,
|
||||
// array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,ColMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,ColMajor,ColMajor,1>::run),
|
||||
// array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,RowMajor,ColMajor,1>::run),
|
||||
// array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, Conj, RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, Conj, RowMajor,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,ColMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,ColMajor,ColMajor,1>::run),
|
||||
// array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,RowMajor,ColMajor,1>::run),
|
||||
// array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, Conj, RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, Conj, RowMajor,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,ColMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,ColMajor,ColMajor,1>::run),
|
||||
// array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,RowMajor,ColMajor,1>::run),
|
||||
// array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, Conj, RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, Conj, RowMajor,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,ColMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,ColMajor,ColMajor,1>::run),
|
||||
// array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,RowMajor,ColMajor,1>::run),
|
||||
// array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,Conj, RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,Conj, RowMajor,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,ColMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,ColMajor,ColMajor,1>::run),
|
||||
// array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,RowMajor,ColMajor,1>::run),
|
||||
// array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,Conj, RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,Conj, RowMajor,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,ColMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,ColMajor,ColMajor,1>::run),
|
||||
// array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,RowMajor,ColMajor,1>::run),
|
||||
// array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,Conj, RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,Conj, RowMajor,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,ColMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,ColMajor,ColMajor,1>::run),
|
||||
// array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,RowMajor,ColMajor,1>::run),
|
||||
// array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,Conj, RowMajor,ColMajor>::run),
|
||||
(internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,Conj, RowMajor,ColMajor,1>::run),
|
||||
0
|
||||
};
|
||||
|
||||
@ -163,12 +163,12 @@ int EIGEN_BLAS_FUNC(trsm)(const char *side, const char *uplo, const char *opa, c
|
||||
if(SIDE(*side)==LEFT)
|
||||
{
|
||||
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false);
|
||||
func[code](*m, *n, a, *lda, b, *ldb, blocking);
|
||||
func[code](*m, *n, a, *lda, b, 1, *ldb, blocking);
|
||||
}
|
||||
else
|
||||
{
|
||||
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false);
|
||||
func[code](*n, *m, a, *lda, b, *ldb, blocking);
|
||||
func[code](*n, *m, a, *lda, b, 1, *ldb, blocking);
|
||||
}
|
||||
|
||||
if(alpha!=Scalar(1))
|
||||
@ -184,63 +184,63 @@ int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, c
|
||||
const RealScalar *palpha, const RealScalar *pa, const int *lda, RealScalar *pb, const int *ldb)
|
||||
{
|
||||
// std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
|
||||
static const functype func[32] = {
|
||||
// array index: NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, ColMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, ColMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,Conj, ColMajor,false,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,false,ColMajor,1>::run),
|
||||
// array index: ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,Conj, ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, ColMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, ColMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,Conj, ColMajor,false,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,false,ColMajor,1>::run),
|
||||
// array index: ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,Conj, ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor,1>::run),
|
||||
// array index: ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor,1>::run),
|
||||
0,
|
||||
// array index: NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor,1>::run),
|
||||
// array index: TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor,1>::run),
|
||||
// array index: ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run),
|
||||
(internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor,1>::run),
|
||||
0
|
||||
};
|
||||
|
||||
@ -272,12 +272,12 @@ int EIGEN_BLAS_FUNC(trmm)(const char *side, const char *uplo, const char *opa, c
|
||||
if(SIDE(*side)==LEFT)
|
||||
{
|
||||
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m,1,false);
|
||||
func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha, blocking);
|
||||
func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, 1, *ldb, alpha, blocking);
|
||||
}
|
||||
else
|
||||
{
|
||||
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n,1,false);
|
||||
func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha, blocking);
|
||||
func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, 1, *ldb, alpha, blocking);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
@ -338,12 +338,12 @@ int EIGEN_BLAS_FUNC(symm)(const char *side, const char *uplo, const int *m, cons
|
||||
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,size,1,false);
|
||||
|
||||
if(SIDE(*side)==LEFT)
|
||||
if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, RowMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
|
||||
else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
|
||||
if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, RowMajor,true,false, ColMajor,false,false, ColMajor,1>::run(*m, *n, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking);
|
||||
else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,true,false, ColMajor,false,false, ColMajor,1>::run(*m, *n, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking);
|
||||
else return 0;
|
||||
else if(SIDE(*side)==RIGHT)
|
||||
if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, RowMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);
|
||||
else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);
|
||||
if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, RowMajor,true,false, ColMajor,1>::run(*m, *n, b, *ldb, a, *lda, c, 1, *ldc, alpha, blocking);
|
||||
else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, ColMajor,true,false, ColMajor,1>::run(*m, *n, b, *ldb, a, *lda, c, 1, *ldc, alpha, blocking);
|
||||
else return 0;
|
||||
else
|
||||
return 0;
|
||||
@ -359,21 +359,21 @@ int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const
|
||||
{
|
||||
// std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
|
||||
#if !ISCOMPLEX
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
|
||||
static const functype func[8] = {
|
||||
// array index: NOTR | (UP << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Upper>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, 1, Upper>::run),
|
||||
// array index: TR | (UP << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Upper>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, 1, Upper>::run),
|
||||
// array index: ADJ | (UP << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Upper>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,1, Upper>::run),
|
||||
0,
|
||||
// array index: NOTR | (LO << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Lower>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, 1, Lower>::run),
|
||||
// array index: TR | (LO << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Lower>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, 1, Lower>::run),
|
||||
// array index: ADJ | (LO << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Lower>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,1, Lower>::run),
|
||||
0
|
||||
};
|
||||
#endif
|
||||
@ -426,7 +426,7 @@ int EIGEN_BLAS_FUNC(syrk)(const char *uplo, const char *op, const int *n, const
|
||||
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*n,*n,*k,1,false);
|
||||
|
||||
int code = OP(*op) | (UPLO(*uplo) << 2);
|
||||
func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha, blocking);
|
||||
func[code](*n, *k, a, *lda, a, *lda, c, 1, *ldc, alpha, blocking);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
@ -537,18 +537,18 @@ int EIGEN_BLAS_FUNC(hemm)(const char *side, const char *uplo, const int *m, cons
|
||||
|
||||
if(SIDE(*side)==LEFT)
|
||||
{
|
||||
if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar,DenseIndex,RowMajor,true,Conj, ColMajor,false,false, ColMajor>
|
||||
::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
|
||||
else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,true,false, ColMajor,false,false, ColMajor>
|
||||
::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha, blocking);
|
||||
if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar,DenseIndex,RowMajor,true,Conj, ColMajor,false,false, ColMajor, 1>
|
||||
::run(*m, *n, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking);
|
||||
else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,true,false, ColMajor,false,false, ColMajor,1>
|
||||
::run(*m, *n, a, *lda, b, *ldb, c, 1, *ldc, alpha, blocking);
|
||||
else return 0;
|
||||
}
|
||||
else if(SIDE(*side)==RIGHT)
|
||||
{
|
||||
if(UPLO(*uplo)==UP) matrix(c,*m,*n,*ldc) += alpha * matrix(b,*m,*n,*ldb) * matrix(a,*n,*n,*lda).selfadjointView<Upper>();/*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, RowMajor,true,Conj, ColMajor>
|
||||
::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);*/
|
||||
else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, ColMajor,true,false, ColMajor>
|
||||
::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha, blocking);
|
||||
if(UPLO(*uplo)==UP) matrix(c,*m,*n,*ldc) += alpha * matrix(b,*m,*n,*ldb) * matrix(a,*n,*n,*lda).selfadjointView<Upper>();/*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, RowMajor,true,Conj, ColMajor, 1>
|
||||
::run(*m, *n, b, *ldb, a, *lda, c, 1, *ldc, alpha, blocking);*/
|
||||
else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, ColMajor,true,false, ColMajor,1>
|
||||
::run(*m, *n, b, *ldb, a, *lda, c, 1, *ldc, alpha, blocking);
|
||||
else return 0;
|
||||
}
|
||||
else
|
||||
@ -566,19 +566,19 @@ int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const
|
||||
{
|
||||
// std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
|
||||
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
|
||||
typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, DenseIndex, const Scalar&, internal::level3_blocking<Scalar,Scalar>&);
|
||||
static const functype func[8] = {
|
||||
// array index: NOTR | (UP << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Upper>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,1,Upper>::run),
|
||||
0,
|
||||
// array index: ADJ | (UP << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Upper>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,1,Upper>::run),
|
||||
0,
|
||||
// array index: NOTR | (LO << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Lower>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,1,Lower>::run),
|
||||
0,
|
||||
// array index: ADJ | (LO << 2)
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Lower>::run),
|
||||
(internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,1,Lower>::run),
|
||||
0
|
||||
};
|
||||
|
||||
@ -620,7 +620,7 @@ int EIGEN_BLAS_FUNC(herk)(const char *uplo, const char *op, const int *n, const
|
||||
if(*k>0 && alpha!=RealScalar(0))
|
||||
{
|
||||
internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*n,*n,*k,1,false);
|
||||
func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha, blocking);
|
||||
func[code](*n, *k, a, *lda, a, *lda, c, 1, *ldc, alpha, blocking);
|
||||
matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
|
||||
}
|
||||
return 0;
|
||||
|
@ -419,6 +419,12 @@ void check_indexed_view()
VERIFY_IS_EQUAL( A3(ind,ind).eval(), MatrixXi::Constant(5,5,A3(1,1)) );
}

// Regression for bug 1736
{
VERIFY_IS_APPROX(A(all, eii).col(0).eval(), A.col(eii(0)));
A(all, eii).col(0) = A.col(eii(0));
}

}

EIGEN_DECLARE_TEST(indexed_view)
@ -241,4 +241,19 @@ template<typename MatrixType> void product(const MatrixType& m)
|
||||
VERIFY_IS_APPROX(square * (square*square).conjugate(), square * square.conjugate() * square.conjugate());
|
||||
}
|
||||
|
||||
// destination with a non-default inner-stride
|
||||
// see bug 1741
|
||||
if(!MatrixType::IsRowMajor)
|
||||
{
|
||||
typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
|
||||
MatrixX buffer(2*rows,2*rows);
|
||||
Map<RowSquareMatrixType,0,Stride<Dynamic,2> > map1(buffer.data(),rows,rows,Stride<Dynamic,2>(2*rows,2));
|
||||
buffer.setZero();
|
||||
VERIFY_IS_APPROX(map1 = m1 * m2.transpose(), (m1 * m2.transpose()).eval());
|
||||
buffer.setZero();
|
||||
VERIFY_IS_APPROX(map1.noalias() = m1 * m2.transpose(), (m1 * m2.transpose()).eval());
|
||||
buffer.setZero();
|
||||
VERIFY_IS_APPROX(map1.noalias() += m1 * m2.transpose(), (m1 * m2.transpose()).eval());
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -82,6 +82,16 @@ template<typename Scalar> void mmtr(int size)
|
||||
ref2.template triangularView<Lower>() = ref1.template triangularView<Lower>();
|
||||
matc.template triangularView<Lower>() = sqc * matc * sqc.adjoint();
|
||||
VERIFY_IS_APPROX(matc, ref2);
|
||||
|
||||
// destination with a non-default inner-stride
|
||||
// see bug 1741
|
||||
{
|
||||
typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
|
||||
MatrixX buffer(2*size,2*size);
|
||||
Map<MatrixColMaj,0,Stride<Dynamic,Dynamic> > map1(buffer.data(),size,size,Stride<Dynamic,Dynamic>(2*size,2));
|
||||
buffer.setZero();
|
||||
CHECK_MMTR(map1, Lower, = s*soc*sor.adjoint());
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(product_mmtr)
|
||||
|
@ -75,12 +75,12 @@ template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, in
|
||||
rhs13 = (s1*m1.adjoint()) * (s2*rhs2.adjoint()));
|
||||
|
||||
// test row major = <...>
|
||||
m2 = m1.template triangularView<Lower>(); rhs12.setRandom(); rhs13 = rhs12;
|
||||
VERIFY_IS_APPROX(rhs12 -= (s1*m2).template selfadjointView<Lower>() * (s2*rhs3),
|
||||
m2 = m1.template triangularView<Lower>(); rhs32.setRandom(); rhs13 = rhs32;
|
||||
VERIFY_IS_APPROX(rhs32.noalias() -= (s1*m2).template selfadjointView<Lower>() * (s2*rhs3),
|
||||
rhs13 -= (s1*m1) * (s2 * rhs3));
|
||||
|
||||
m2 = m1.template triangularView<Upper>();
|
||||
VERIFY_IS_APPROX(rhs12 = (s1*m2.adjoint()).template selfadjointView<Lower>() * (s2*rhs3).conjugate(),
|
||||
VERIFY_IS_APPROX(rhs32.noalias() = (s1*m2.adjoint()).template selfadjointView<Lower>() * (s2*rhs3).conjugate(),
|
||||
rhs13 = (s1*m1.adjoint()) * (s2*rhs3).conjugate());
|
||||
|
||||
|
||||
@ -92,6 +92,20 @@ template<typename Scalar, int Size, int OtherSize> void symm(int size = Size, in
|
||||
VERIFY_IS_APPROX(rhs22 = (rhs2) * (m2).template selfadjointView<Lower>(), rhs23 = (rhs2) * (m1));
|
||||
VERIFY_IS_APPROX(rhs22 = (s2*rhs2) * (s1*m2).template selfadjointView<Lower>(), rhs23 = (s2*rhs2) * (s1*m1));
|
||||
|
||||
// destination with a non-default inner-stride
|
||||
// see bug 1741
|
||||
{
|
||||
typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
|
||||
MatrixX buffer(2*cols,2*othersize);
|
||||
Map<Rhs1,0,Stride<Dynamic,2> > map1(buffer.data(),cols,othersize,Stride<Dynamic,2>(2*rows,2));
|
||||
buffer.setZero();
|
||||
VERIFY_IS_APPROX( map1.noalias() = (s1*m2).template selfadjointView<Lower>() * (s2*rhs1),
|
||||
rhs13 = (s1*m1) * (s2*rhs1));
|
||||
|
||||
Map<Rhs2,0,Stride<Dynamic,2> > map2(buffer.data(),rhs22.rows(),rhs22.cols(),Stride<Dynamic,2>(2*rhs22.outerStride(),2));
|
||||
buffer.setZero();
|
||||
VERIFY_IS_APPROX(map2 = (rhs2) * (m2).template selfadjointView<Lower>(), rhs23 = (rhs2) * (m1));
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(product_symm)
|
||||
|
@ -115,6 +115,17 @@ template<typename MatrixType> void syrk(const MatrixType& m)
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.row(c).adjoint(),s1)._expression()),
|
||||
((s1 * m1.row(c).adjoint() * m1.row(c).adjoint().adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
|
||||
// destination with a non-default inner-stride
|
||||
// see bug 1741
|
||||
{
|
||||
typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
|
||||
MatrixX buffer(2*rows,2*cols);
|
||||
Map<MatrixType,0,Stride<Dynamic,2> > map1(buffer.data(),rows,cols,Stride<Dynamic,2>(2*rows,2));
|
||||
buffer.setZero();
|
||||
VERIFY_IS_APPROX((map1.template selfadjointView<Lower>().rankUpdate(rhs2,s1)._expression()),
|
||||
((s1 * rhs2 * rhs2.adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(product_syrk)
|
||||
|
@ -76,8 +76,18 @@ void trmm(int rows=get_random_size<Scalar>(),
|
||||
VERIFY_IS_APPROX( ge_xs = (s1*mat).adjoint().template triangularView<Mode>() * ge_left.adjoint(), numext::conj(s1) * triTr.conjugate() * ge_left.adjoint());
|
||||
VERIFY_IS_APPROX( ge_xs = (s1*mat).transpose().template triangularView<Mode>() * ge_left.adjoint(), s1triTr * ge_left.adjoint());
|
||||
|
||||
|
||||
// TODO check with sub-matrix expressions ?
|
||||
|
||||
// destination with a non-default inner-stride
|
||||
// see bug 1741
|
||||
{
|
||||
VERIFY_IS_APPROX( ge_xs.noalias() = mat.template triangularView<Mode>() * ge_right, tri * ge_right);
|
||||
typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
|
||||
MatrixX buffer(2*ge_xs.rows(),2*ge_xs.cols());
|
||||
Map<ResXS,0,Stride<Dynamic,2> > map1(buffer.data(),ge_xs.rows(),ge_xs.cols(),Stride<Dynamic,2>(2*ge_xs.outerStride(),2));
|
||||
buffer.setZero();
|
||||
VERIFY_IS_APPROX( map1.noalias() = mat.template triangularView<Mode>() * ge_right, tri * ge_right);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Scalar, int Mode, int TriOrder>
|
||||
|
@ -72,6 +72,19 @@ template<typename Scalar,int Size, int Cols> void trsolve(int size=Size,int cols
  VERIFY_TRSM(rmLhs.template triangularView<Lower>(), rmRhs.col(c));
  VERIFY_TRSM(cmLhs.template triangularView<Lower>(), rmRhs.col(c));

  // destination with a non-default inner-stride
  // see bug 1741
  {
    typedef Matrix<Scalar,Dynamic,Dynamic> MatrixX;
    MatrixX buffer(2*cmRhs.rows(),2*cmRhs.cols());
    Map<Matrix<Scalar,Size,Cols,colmajor>,0,Stride<Dynamic,2> > map1(buffer.data(),cmRhs.rows(),cmRhs.cols(),Stride<Dynamic,2>(2*cmRhs.outerStride(),2));
    Map<Matrix<Scalar,Size,Cols,rowmajor>,0,Stride<Dynamic,2> > map2(buffer.data(),rmRhs.rows(),rmRhs.cols(),Stride<Dynamic,2>(2*rmRhs.outerStride(),2));
    buffer.setZero();
    VERIFY_TRSM(cmLhs.conjugate().template triangularView<Lower>(), map1);
    buffer.setZero();
    VERIFY_TRSM(cmLhs.template triangularView<Lower>(), map2);
  }

  if(Size==Dynamic)
  {
    cmLhs.resize(0,0);
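For readers unfamiliar with the strided-destination pattern exercised by the bug-1741 tests above: the product is written through a Map whose inner stride is 2, so the result lands on every other coefficient of a buffer that is twice as large in each dimension. A minimal standalone sketch of that mapping, independent of the test harness (the names rows, cols, and buffer are illustrative, not taken from the patch):

#include <Eigen/Dense>
#include <iostream>

int main() {
  using Eigen::MatrixXd;
  using Eigen::Map;
  using Eigen::Stride;
  using Eigen::Dynamic;

  const int rows = 3, cols = 4;
  MatrixXd buffer = MatrixXd::Zero(2 * rows, 2 * cols);

  // Inner stride 2, outer stride 2*rows (== buffer.outerStride() here):
  // the mapped matrix touches only every other row of the first `cols`
  // columns of `buffer`, mirroring the Stride<Dynamic,2> used in the tests.
  Map<MatrixXd, 0, Stride<Dynamic, 2> >
      map(buffer.data(), rows, cols, Stride<Dynamic, 2>(2 * rows, 2));

  map.setConstant(1.0);          // writes go through the strided view
  std::cout << buffer << "\n";   // only the strided entries are non-zero
  return 0;
}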
@ -45,11 +45,7 @@
#include <functional>
#include <memory>
#include <utility>
#include "src/util/CXX11Meta.h"
#include "src/util/MaxSizeVector.h"

#include "src/ThreadPool/ThreadLocal.h"
#ifndef EIGEN_THREAD_LOCAL
// There are non-parenthesized calls to "max" in the <unordered_map> header,
// which trigger a check in test/main.h causing compilation to fail.
// We work around the check here by removing the check for max in
@ -58,7 +54,11 @@
#undef max
#endif
#include <unordered_map>
#endif

#include "src/util/CXX11Meta.h"
#include "src/util/MaxSizeVector.h"

#include "src/ThreadPool/ThreadLocal.h"
#include "src/ThreadPool/ThreadYield.h"
#include "src/ThreadPool/ThreadCancel.h"
#include "src/ThreadPool/EventCount.h"
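As background on the `max` workaround mentioned in the comment above: the test harness defines a function-like `max` macro to catch unqualified, non-parenthesized calls, and `<unordered_map>` contains such calls. A minimal standalone illustration (not part of the patch; the macro body is a made-up stand-in for the actual check in test/main.h):

#include <limits>

#define max(a, b) ((a) > (b) ? (a) : (b))   // stand-in for the guard macro

int main() {
  // int bad = std::numeric_limits<int>::max();  // would be macro-expanded: "max" needs two arguments -> error
  int good = (std::numeric_limits<int>::max)();  // parentheses suppress function-like macro expansion
  return good > 0 ? 0 : 1;
}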
@ -60,6 +60,242 @@
#endif
#endif  // defined(__ANDROID__) && defined(__clang__)

#endif  // EIGEN_AVOID_THREAD_LOCAL
#endif  // EIGEN_AVOID_THREAD_LOCAL

namespace Eigen {

namespace internal {
template <typename T>
struct ThreadLocalNoOpInitialize {
  void operator()(T&) const {}
};

template <typename T>
struct ThreadLocalNoOpRelease {
  void operator()(T&) const {}
};

}  // namespace internal

// Thread local container for elements of type T, that does not use thread local
// storage. As long as the number of unique threads accessing this storage
// is smaller than `capacity_`, it is lock-free and wait-free. Otherwise it will
// use a mutex for synchronization.
//
// Type `T` has to be default constructible, and by default each thread will get
// a default constructed value. It is possible to specify a custom `initialize`
// callable, which will be called lazily from each thread accessing this object,
// and will be passed a default initialized object of type `T`. It is also
// possible to pass a custom `release` callable, which will be invoked before
// calling ~T().
//
// Example:
//
//   struct Counter {
//     int value = 0;
//   };
//
//   Eigen::ThreadLocal<Counter> counter(10);
//
//   // Each thread will have access to its own counter object.
//   Counter& cnt = counter.local();
//   cnt.value++;
//
// WARNING: Eigen::ThreadLocal uses the OS-specific value returned by
// std::this_thread::get_id() to identify threads. This value is not guaranteed
// to be unique except for the life of the thread. A newly created thread may
// get an OS-specific ID equal to that of an already destroyed thread.
//
// Somewhat similar to TBB thread local storage, with similar restrictions:
// https://www.threadingbuildingblocks.org/docs/help/reference/thread_local_storage/enumerable_thread_specific_cls.html
//
template <typename T,
          typename Initialize = internal::ThreadLocalNoOpInitialize<T>,
          typename Release = internal::ThreadLocalNoOpRelease<T>>
class ThreadLocal {
  // We preallocate default constructed elements in MaxSizeVector.
  static_assert(std::is_default_constructible<T>::value,
                "ThreadLocal data type must be default constructible");

 public:
  explicit ThreadLocal(int capacity)
      : ThreadLocal(capacity, internal::ThreadLocalNoOpInitialize<T>(),
                    internal::ThreadLocalNoOpRelease<T>()) {}

  ThreadLocal(int capacity, Initialize initialize)
      : ThreadLocal(capacity, std::move(initialize),
                    internal::ThreadLocalNoOpRelease<T>()) {}

  ThreadLocal(int capacity, Initialize initialize, Release release)
      : initialize_(std::move(initialize)),
        release_(std::move(release)),
        capacity_(capacity),
        data_(capacity_),
        ptr_(capacity_),
        filled_records_(0) {
    eigen_assert(capacity_ >= 0);
    data_.resize(capacity_);
    for (int i = 0; i < capacity_; ++i) {
      ptr_.emplace_back(nullptr);
    }
  }

  T& local() {
    std::thread::id this_thread = std::this_thread::get_id();
    if (capacity_ == 0) return SpilledLocal(this_thread);

    std::size_t h = std::hash<std::thread::id>()(this_thread);
    const int start_idx = h % capacity_;

    // NOTE: From the definition of `std::this_thread::get_id()` it is
    // guaranteed that we never can have concurrent insertions with the same
    // key to our hash-map like data structure. If we didn't find an element
    // during the initial traversal, it's guaranteed that no one else could
    // have inserted it while we are in this function. This allows us to
    // massively simplify our lock-free insert-only hash map.

    // Check if we already have an element for `this_thread`.
    int idx = start_idx;
    while (ptr_[idx].load() != nullptr) {
      ThreadIdAndValue& record = *(ptr_[idx].load());
      if (record.thread_id == this_thread) return record.value;

      idx += 1;
      if (idx >= capacity_) idx -= capacity_;
      if (idx == start_idx) break;
    }

    // If we are here, it means that we found an insertion point in the lookup
    // table at `idx`, or we did a full traversal and the table is full.

    // If lock-free storage is full, fall back on the mutex.
    if (filled_records_.load() >= capacity_) return SpilledLocal(this_thread);

    // We double check that we still have space to insert an element into the
    // lock-free storage. If the old value in `filled_records_` is larger than
    // the records capacity, it means that some other thread added an element
    // while we were traversing the lookup table.
    int insertion_index =
        filled_records_.fetch_add(1, std::memory_order_relaxed);
    if (insertion_index >= capacity_) return SpilledLocal(this_thread);

    // At this point it's guaranteed that we can access
    // data_[insertion_index] without a data race.
    data_[insertion_index].thread_id = this_thread;
    initialize_(data_[insertion_index].value);

    // That's the pointer we'll put into the lookup table.
    ThreadIdAndValue* inserted = &data_[insertion_index];

    // We'll use a nullptr pointer to ThreadIdAndValue in a compare-and-swap loop.
    ThreadIdAndValue* empty = nullptr;

    // Now we have to find an insertion point into the lookup table. We start
    // from the `idx` that was identified as an insertion point above; it's
    // guaranteed that we will have an empty record somewhere in the lookup
    // table (because we created a record in `data_`).
    const int insertion_idx = idx;

    do {
      // Always start the search from the original insertion candidate.
      idx = insertion_idx;
      while (ptr_[idx].load() != nullptr) {
        idx += 1;
        if (idx >= capacity_) idx -= capacity_;
        // If we did a full loop, it means that we don't have any free entries
        // in the lookup table, and this means that something is terribly wrong.
        eigen_assert(idx != insertion_idx);
      }
      // Atomic CAS of the pointer guarantees that any other thread that will
      // follow this pointer will see all the mutations in `data_`.
    } while (!ptr_[idx].compare_exchange_weak(empty, inserted));

    return inserted->value;
  }

  // WARN: It's not thread safe to call it concurrently with `local()`.
  void ForEach(std::function<void(std::thread::id, T&)> f) {
    // Reading directly from `data_` is unsafe, because only CAS to the
    // record in `ptr_` makes all changes visible to other threads.
    for (auto& ptr : ptr_) {
      ThreadIdAndValue* record = ptr.load();
      if (record == nullptr) continue;
      f(record->thread_id, record->value);
    }

    // We did not spill into the map based storage.
    if (filled_records_.load(std::memory_order_relaxed) < capacity_) return;

    // Adds a happens before edge from the last call to SpilledLocal().
    std::unique_lock<std::mutex> lock(mu_);
    for (auto& kv : per_thread_map_) {
      f(kv.first, kv.second);
    }
  }

  // WARN: It's not thread safe to call it concurrently with `local()`.
  ~ThreadLocal() {
    // Reading directly from `data_` is unsafe, because only CAS to the record
    // in `ptr_` makes all changes visible to other threads.
    for (auto& ptr : ptr_) {
      ThreadIdAndValue* record = ptr.load();
      if (record == nullptr) continue;
      release_(record->value);
    }

    // We did not spill into the map based storage.
    if (filled_records_.load(std::memory_order_relaxed) < capacity_) return;

    // Adds a happens before edge from the last call to SpilledLocal().
    std::unique_lock<std::mutex> lock(mu_);
    for (auto& kv : per_thread_map_) {
      release_(kv.second);
    }
  }

 private:
  struct ThreadIdAndValue {
    std::thread::id thread_id;
    T value;
  };

  // Use an unordered map guarded by a mutex when the lock-free storage is full.
  T& SpilledLocal(std::thread::id this_thread) {
    std::unique_lock<std::mutex> lock(mu_);

    auto it = per_thread_map_.find(this_thread);
    if (it == per_thread_map_.end()) {
      auto result = per_thread_map_.emplace(this_thread, T());
      eigen_assert(result.second);
      initialize_((*result.first).second);
      return (*result.first).second;
    } else {
      return it->second;
    }
  }

  Initialize initialize_;
  Release release_;
  const int capacity_;

  // Storage that backs the lock-free lookup table `ptr_`. Records are stored
  // in this storage contiguously, starting from index 0.
  MaxSizeVector<ThreadIdAndValue> data_;

  // Atomic pointers to the data stored in `data_`. Used as a lookup table for
  // a linear probing hash map (https://en.wikipedia.org/wiki/Linear_probing).
  MaxSizeVector<std::atomic<ThreadIdAndValue*>> ptr_;

  // Number of records stored in `data_`.
  std::atomic<int> filled_records_;

  // We fall back on a per-thread map if the lock-free storage is full. In
  // practice this should never happen, if `capacity_` is a reasonable estimate
  // of the number of threads running in the system.
  std::mutex mu_;  // Protects per_thread_map_.
  std::unordered_map<std::thread::id, T> per_thread_map_;
};

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_THREAD_LOCAL_H
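A minimal usage sketch of the class above, showing the three-argument constructor with custom initialize and release callables (the `Scratch` type, the lambdas, and the pool size are illustrative assumptions, not taken from the patch):

#define EIGEN_USE_THREADS
#include <Eigen/CXX11/ThreadPool>
#include <cstdio>

struct Scratch { int hits = 0; };   // illustrative per-thread state

int main() {
  const int num_threads = 4;
  Eigen::ThreadPool pool(num_threads);

  auto init = [](Scratch& s) { s.hits = 0; };                          // called lazily, once per thread
  auto release = [](Scratch& s) { std::printf("hits=%d\n", s.hits); }; // called before ~Scratch()

  Eigen::ThreadLocal<Scratch, decltype(init), decltype(release)>
      scratch(num_threads, init, release);

  const int num_tasks = 8;
  Eigen::Barrier barrier(num_tasks);
  for (int i = 0; i < num_tasks; ++i) {
    pool.Schedule([&]() {
      scratch.local().hits++;   // each thread mutates only its own Scratch
      barrier.Notify();
    });
  }
  barrier.Wait();
  return 0;   // ~ThreadLocal invokes `release` for every per-thread Scratch
}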
@ -624,6 +624,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_lt_exp_neg_two(
}

template <typename T, typename ScalarType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T generic_ndtri(const T& a) {
  const T maxnum = pset1<T>(NumTraits<ScalarType>::infinity());
  const T neg_maxnum = pset1<T>(-NumTraits<ScalarType>::infinity());
@ -201,6 +201,7 @@ if(EIGEN_TEST_CXX11)
  ei_add_test(cxx11_tensor_shuffling)
  ei_add_test(cxx11_tensor_striding)
  ei_add_test(cxx11_tensor_notification "-pthread" "${CMAKE_THREAD_LIBS_INIT}")
  ei_add_test(cxx11_tensor_thread_local "-pthread" "${CMAKE_THREAD_LIBS_INIT}")
  ei_add_test(cxx11_tensor_thread_pool "-pthread" "${CMAKE_THREAD_LIBS_INIT}")
  ei_add_test(cxx11_tensor_executor "-pthread" "${CMAKE_THREAD_LIBS_INIT}")
  ei_add_test(cxx11_tensor_ref)
@ -1091,7 +1091,7 @@ void test_gpu_ndtri()
  expected_out(1) = -std::numeric_limits<Scalar>::infinity();
  expected_out(2) = Scalar(0.0);
  expected_out(3) = Scalar(-0.8416212335729142);
  expected_out(4) = Scalar(0.8416212335729142);j
  expected_out(4) = Scalar(0.8416212335729142);
  expected_out(5) = Scalar(1.2815515655446004);
  expected_out(6) = Scalar(-1.2815515655446004);
  expected_out(7) = Scalar(2.3263478740408408);
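For context, the expected values above are quantiles of the standard normal distribution: ndtri is the inverse of the CDF Phi, so for example Phi(1.2815515655446004) is approximately 0.9. A standalone sanity check, independent of Eigen (the probabilities in p are assumptions inferred from the values above):

#include <cmath>
#include <cstdio>

int main() {
  // Expected ndtri outputs from the test above, paired with the probabilities
  // they should correspond to under Phi(x) = 0.5 * erfc(-x / sqrt(2)).
  const double x[] = {-0.8416212335729142, 0.8416212335729142,
                      1.2815515655446004, 2.3263478740408408};
  const double p[] = {0.2, 0.8, 0.9, 0.99};
  for (int i = 0; i < 4; ++i) {
    double phi = 0.5 * std::erfc(-x[i] / std::sqrt(2.0));
    std::printf("Phi(%+.16f) = %.6f (expected %.2f)\n", x[i], phi, p[i]);
  }
  return 0;
}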
unsupported/test/cxx11_tensor_thread_local.cpp (new file, 149 lines)
@ -0,0 +1,149 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#define EIGEN_USE_THREADS

#include <iostream>
#include <unordered_set>

#include "main.h"
#include <Eigen/CXX11/ThreadPool>

struct Counter {
  Counter() = default;

  void inc() {
    // Check that mutation happens only in the thread that created this counter.
    VERIFY_IS_EQUAL(std::this_thread::get_id(), created_by);
    counter_value++;
  }
  int value() { return counter_value; }

  std::thread::id created_by;
  int counter_value = 0;
};

struct InitCounter {
  void operator()(Counter& counter) {
    counter.created_by = std::this_thread::get_id();
  }
};

void test_simple_thread_local() {
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<Counter, InitCounter> counter(num_threads, InitCounter());

  int num_tasks = 3 * num_threads;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();

      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      barrier.Notify();
    });
  }

  barrier.Wait();

  counter.ForEach(
      [](std::thread::id, Counter& cnt) { VERIFY_IS_EQUAL(cnt.value(), 3); });
}

void test_zero_sized_thread_local() {
  Eigen::ThreadLocal<Counter, InitCounter> counter(0, InitCounter());

  Counter& local = counter.local();
  local.inc();

  int total = 0;
  counter.ForEach([&total](std::thread::id, Counter& cnt) {
    total += cnt.value();
    VERIFY_IS_EQUAL(cnt.value(), 1);
  });

  VERIFY_IS_EQUAL(total, 1);
}

// All thread local values fit into the lock-free storage.
void test_large_number_of_tasks_no_spill() {
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<Counter, InitCounter> counter(num_threads, InitCounter());

  int num_tasks = 10000;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();
      barrier.Notify();
    });
  }

  barrier.Wait();

  int total = 0;
  std::unordered_set<std::thread::id> unique_threads;

  counter.ForEach([&](std::thread::id id, Counter& cnt) {
    total += cnt.value();
    unique_threads.insert(id);
  });

  VERIFY_IS_EQUAL(total, num_tasks);
  // Not all threads in a pool might be woken up to execute submitted tasks.
  // Also thread_pool.Schedule() might use the current thread if the queue is full.
  VERIFY_IS_EQUAL(
      unique_threads.size() <= (static_cast<size_t>(num_threads + 1)), true);
}

// Lock-free thread local storage is too small to fit all the unique threads,
// and it spills to a map guarded by a mutex.
void test_large_number_of_tasks_with_spill() {
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<Counter, InitCounter> counter(1, InitCounter());

  int num_tasks = 10000;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();
      barrier.Notify();
    });
  }

  barrier.Wait();

  int total = 0;
  std::unordered_set<std::thread::id> unique_threads;

  counter.ForEach([&](std::thread::id id, Counter& cnt) {
    total += cnt.value();
    unique_threads.insert(id);
  });

  VERIFY_IS_EQUAL(total, num_tasks);
  // Not all threads in a pool might be woken up to execute submitted tasks.
  // Also thread_pool.Schedule() might use the current thread if the queue is full.
  VERIFY_IS_EQUAL(
      unique_threads.size() <= (static_cast<size_t>(num_threads + 1)), true);
}

EIGEN_DECLARE_TEST(cxx11_tensor_thread_local) {
  CALL_SUBTEST(test_simple_thread_local());
  CALL_SUBTEST(test_zero_sized_thread_local());
  CALL_SUBTEST(test_large_number_of_tasks_no_spill());
  CALL_SUBTEST(test_large_number_of_tasks_with_spill());
}