Eliminate use of _res.

This commit is contained in:
Antonio Sánchez 2023-10-16 19:56:53 +00:00
parent a96545777b
commit 5bdf58b8df
7 changed files with 99 additions and 95 deletions

View File

@ -62,9 +62,9 @@ typedef gebp_traits<LhsScalar,RhsScalar> Traits;
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static void run(Index rows, Index cols, Index depth, static void run(Index rows, Index cols, Index depth,
const LhsScalar* _lhs, Index lhsStride, const LhsScalar* lhs_, Index lhsStride,
const RhsScalar* _rhs, Index rhsStride, const RhsScalar* rhs_, Index rhsStride,
ResScalar* _res, Index resIncr, Index resStride, ResScalar* res_, Index resIncr, Index resStride,
ResScalar alpha, ResScalar alpha,
level3_blocking<LhsScalar,RhsScalar>& blocking, level3_blocking<LhsScalar,RhsScalar>& blocking,
GemmParallelInfo<Index>* info = 0) GemmParallelInfo<Index>* info = 0)
@ -72,9 +72,9 @@ static void run(Index rows, Index cols, Index depth,
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper; typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper; typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor,Unaligned,ResInnerStride> ResMapper; typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor,Unaligned,ResInnerStride> ResMapper;
LhsMapper lhs(_lhs, lhsStride); LhsMapper lhs(lhs_, lhsStride);
RhsMapper rhs(_rhs, rhsStride); RhsMapper rhs(rhs_, rhsStride);
ResMapper res(_res, resStride, resIncr); ResMapper res(res_, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction

View File

@ -63,9 +63,9 @@ template <typename Index, typename LhsScalar, int LhsStorageOrder, bool Conjugat
struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,UpLo,Version> struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,ColMajor,ResInnerStride,UpLo,Version>
{ {
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride, static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs_, Index lhsStride,
const RhsScalar* _rhs, Index rhsStride, const RhsScalar* rhs_, Index rhsStride,
ResScalar* _res, Index resIncr, Index resStride, ResScalar* res_, Index resIncr, Index resStride,
const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking) const ResScalar& alpha, level3_blocking<LhsScalar,RhsScalar>& blocking)
{ {
typedef gebp_traits<LhsScalar,RhsScalar> Traits; typedef gebp_traits<LhsScalar,RhsScalar> Traits;
@ -73,9 +73,9 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper; typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper; typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper; typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride); LhsMapper lhs(lhs_,lhsStride);
RhsMapper rhs(_rhs,rhsStride); RhsMapper rhs(rhs_,rhsStride);
ResMapper res(_res, resStride, resIncr); ResMapper res(res_, resStride, resIncr);
Index kc = blocking.kc(); Index kc = blocking.kc();
Index mc = (std::min)(size,blocking.mc()); Index mc = (std::min)(size,blocking.mc());
@ -116,7 +116,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder,
gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc,
(std::min)(size,i2), alpha, -1, -1, 0, 0); (std::min)(size,i2), alpha, -1, -1, 0, 0);
sybb(_res+resStride*i2 + resIncr*i2, resIncr, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha); sybb(res_+resStride*i2 + resIncr*i2, resIncr, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha);
if (UpLo==Upper) if (UpLo==Upper)
{ {
@ -147,11 +147,11 @@ struct tribb_kernel
enum { enum {
BlockSize = meta_least_common_multiple<plain_enum_max(mr, nr), plain_enum_min(mr,nr)>::ret BlockSize = meta_least_common_multiple<plain_enum_max(mr, nr), plain_enum_min(mr,nr)>::ret
}; };
void operator()(ResScalar* _res, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha) void operator()(ResScalar* res_, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha)
{ {
typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper; typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper; typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper;
ResMapper res(_res, resStride, resIncr); ResMapper res(res_, resStride, resIncr);
gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1; gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1;
gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2; gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2;

View File

@ -46,7 +46,7 @@ struct symm_pack_lhs
for(Index w=0; w<BlockRows; w++) for(Index w=0; w<BlockRows; w++)
blockA[count++] = numext::conj(lhs(k, i+w)); // transposed blockA[count++] = numext::conj(lhs(k, i+w)); // transposed
} }
void operator()(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows) void operator()(Scalar* blockA, const Scalar* lhs_, Index lhsStride, Index cols, Index rows)
{ {
typedef typename unpacket_traits<typename packet_traits<Scalar>::type>::half HalfPacket; typedef typename unpacket_traits<typename packet_traits<Scalar>::type>::half HalfPacket;
typedef typename unpacket_traits<typename unpacket_traits<typename packet_traits<Scalar>::type>::half>::half QuarterPacket; typedef typename unpacket_traits<typename unpacket_traits<typename packet_traits<Scalar>::type>::half>::half QuarterPacket;
@ -56,7 +56,7 @@ struct symm_pack_lhs
HasHalf = (int)HalfPacketSize < (int)PacketSize, HasHalf = (int)HalfPacketSize < (int)PacketSize,
HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize}; HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize};
const_blas_data_mapper<Scalar,Index,StorageOrder> lhs(_lhs,lhsStride); const_blas_data_mapper<Scalar,Index,StorageOrder> lhs(lhs_,lhsStride);
Index count = 0; Index count = 0;
//Index peeled_mc3 = (rows/Pack1)*Pack1; //Index peeled_mc3 = (rows/Pack1)*Pack1;
@ -104,11 +104,11 @@ template<typename Scalar, typename Index, int nr, int StorageOrder>
struct symm_pack_rhs struct symm_pack_rhs
{ {
enum { PacketSize = packet_traits<Scalar>::size }; enum { PacketSize = packet_traits<Scalar>::size };
void operator()(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2) void operator()(Scalar* blockB, const Scalar* rhs_, Index rhsStride, Index rows, Index cols, Index k2)
{ {
Index end_k = k2 + rows; Index end_k = k2 + rows;
Index count = 0; Index count = 0;
const_blas_data_mapper<Scalar,Index,StorageOrder> rhs(_rhs,rhsStride); const_blas_data_mapper<Scalar,Index,StorageOrder> rhs(rhs_,rhsStride);
Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0; Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
@ -333,8 +333,8 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs
static EIGEN_DONT_INLINE void run( static EIGEN_DONT_INLINE void run(
Index rows, Index cols, Index rows, Index cols,
const Scalar* _lhs, Index lhsStride, const Scalar* lhs_, Index lhsStride,
const Scalar* _rhs, Index rhsStride, const Scalar* rhs_, Index rhsStride,
Scalar* res, Index resIncr, Index resStride, Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking); const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
}; };
@ -345,9 +345,9 @@ template <typename Scalar, typename Index,
int ResInnerStride> int ResInnerStride>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>::run( EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>::run(
Index rows, Index cols, Index rows, Index cols,
const Scalar* _lhs, Index lhsStride, const Scalar* lhs_, Index lhsStride,
const Scalar* _rhs, Index rhsStride, const Scalar* rhs_, Index rhsStride,
Scalar* _res, Index resIncr, Index resStride, Scalar* res_, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking) const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{ {
Index size = rows; Index size = rows;
@ -358,10 +358,10 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,t
typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper; typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper;
typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper; typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper; typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride); LhsMapper lhs(lhs_,lhsStride);
LhsTransposeMapper lhs_transpose(_lhs,lhsStride); LhsTransposeMapper lhs_transpose(lhs_,lhsStride);
RhsMapper rhs(_rhs,rhsStride); RhsMapper rhs(rhs_,rhsStride);
ResMapper res(_res, resStride, resIncr); ResMapper res(res_, resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
@ -428,8 +428,8 @@ struct product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLh
static EIGEN_DONT_INLINE void run( static EIGEN_DONT_INLINE void run(
Index rows, Index cols, Index rows, Index cols,
const Scalar* _lhs, Index lhsStride, const Scalar* lhs_, Index lhsStride,
const Scalar* _rhs, Index rhsStride, const Scalar* rhs_, Index rhsStride,
Scalar* res, Index resIncr, Index resStride, Scalar* res, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking); const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
}; };
@ -440,9 +440,9 @@ template <typename Scalar, typename Index,
int ResInnerStride> int ResInnerStride>
EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run( EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run(
Index rows, Index cols, Index rows, Index cols,
const Scalar* _lhs, Index lhsStride, const Scalar* lhs_, Index lhsStride,
const Scalar* _rhs, Index rhsStride, const Scalar* rhs_, Index rhsStride,
Scalar* _res, Index resIncr, Index resStride, Scalar* res_, Index resIncr, Index resStride,
const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking) const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
{ {
Index size = cols; Index size = cols;
@ -451,8 +451,8 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f
typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper; typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper; typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
LhsMapper lhs(_lhs,lhsStride); LhsMapper lhs(lhs_,lhsStride);
ResMapper res(_res,resStride, resIncr); ResMapper res(res_,resStride, resIncr);
Index kc = blocking.kc(); // cache block size along the K direction Index kc = blocking.kc(); // cache block size along the K direction
Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction Index mc = (std::min)(rows,blocking.mc()); // cache block size along the M direction
@ -469,7 +469,7 @@ EIGEN_DONT_INLINE void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,f
{ {
const Index actual_kc = (std::min)(k2+kc,size)-k2; const Index actual_kc = (std::min)(k2+kc,size)-k2;
pack_rhs(blockB, _rhs, rhsStride, actual_kc, cols, k2); pack_rhs(blockB, rhs_, rhsStride, actual_kc, cols, k2);
// => GEPP // => GEPP
for(Index i2=0; i2<rows; i2+=mc) for(Index i2=0; i2<rows; i2+=mc)

View File

@ -27,15 +27,15 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
static constexpr bool IsLower = ((Mode & Lower) == Lower); static constexpr bool IsLower = ((Mode & Lower) == Lower);
static constexpr bool HasUnitDiag = (Mode & UnitDiag) == UnitDiag; static constexpr bool HasUnitDiag = (Mode & UnitDiag) == UnitDiag;
static constexpr bool HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag; static constexpr bool HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag;
static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* lhs_, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar* rhs_, Index rhsIncr, ResScalar* res_, Index resIncr,
const RhsScalar& alpha); const RhsScalar& alpha);
}; };
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version> template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version> EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor,Version>
::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, ::run(Index _rows, Index _cols, const LhsScalar* lhs_, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar& alpha) const RhsScalar* rhs_, Index rhsIncr, ResScalar* res_, Index resIncr, const RhsScalar& alpha)
{ {
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
Index size = (std::min)(_rows,_cols); Index size = (std::min)(_rows,_cols);
@ -43,15 +43,15 @@ EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,Con
Index cols = IsLower ? (std::min)(_rows,_cols) : _cols; Index cols = IsLower ? (std::min)(_rows,_cols) : _cols;
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap; typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride)); const LhsMap lhs(lhs_,rows,cols,OuterStride<>(lhsStride));
typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs); typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap; typedef Map<const Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr)); const RhsMap rhs(rhs_,cols,InnerStride<>(rhsIncr));
typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs); typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap; typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
ResMap res(_res,rows); ResMap res(res_,rows);
typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper; typedef const_blas_data_mapper<LhsScalar,Index,ColMajor> LhsMapper;
typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper; typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;
@ -86,7 +86,7 @@ EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,Con
rows, cols-size, rows, cols-size,
LhsMapper(&lhs.coeffRef(0,size), lhsStride), LhsMapper(&lhs.coeffRef(0,size), lhsStride),
RhsMapper(&rhs.coeffRef(size), rhsIncr), RhsMapper(&rhs.coeffRef(size), rhsIncr),
_res, resIncr, alpha); res_, resIncr, alpha);
} }
} }
@ -97,15 +97,15 @@ struct triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,C
static constexpr bool IsLower = ((Mode & Lower) == Lower); static constexpr bool IsLower = ((Mode & Lower) == Lower);
static constexpr bool HasUnitDiag = (Mode & UnitDiag) == UnitDiag; static constexpr bool HasUnitDiag = (Mode & UnitDiag) == UnitDiag;
static constexpr bool HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag; static constexpr bool HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag;
static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* lhs_, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const RhsScalar* rhs_, Index rhsIncr, ResScalar* res_, Index resIncr,
const ResScalar& alpha); const ResScalar& alpha);
}; };
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version> template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version> EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor,Version>
::run(Index _rows, Index _cols, const LhsScalar* _lhs, Index lhsStride, ::run(Index _rows, Index _cols, const LhsScalar* lhs_, Index lhsStride,
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, const ResScalar& alpha) const RhsScalar* rhs_, Index rhsIncr, ResScalar* res_, Index resIncr, const ResScalar& alpha)
{ {
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
Index diagSize = (std::min)(_rows,_cols); Index diagSize = (std::min)(_rows,_cols);
@ -113,15 +113,15 @@ EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index,Mode,LhsScalar,Con
Index cols = IsLower ? diagSize : _cols; Index cols = IsLower ? diagSize : _cols;
typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap; typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride)); const LhsMap lhs(lhs_,rows,cols,OuterStride<>(lhsStride));
typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs); typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap; typedef Map<const Matrix<RhsScalar,Dynamic,1> > RhsMap;
const RhsMap rhs(_rhs,cols); const RhsMap rhs(rhs_,cols);
typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs); typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap; typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
ResMap res(_res,rows,InnerStride<>(resIncr)); ResMap res(res_,rows,InnerStride<>(resIncr));
typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper; typedef const_blas_data_mapper<LhsScalar,Index,RowMajor> LhsMapper;
typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper; typedef const_blas_data_mapper<RhsScalar,Index,RowMajor> RhsMapper;

View File

@ -53,18 +53,18 @@ struct triangular_matrix_vector_product_trmv :
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \ #define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \ struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor,Specialized> { \
static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \ static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, \
const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \ const Scalar* rhs_, Index rhsIncr, Scalar* res_, Index resIncr, Scalar alpha) { \
triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \ triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,ColMajor>::run( \
_rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \ rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
} \ } \
}; \ }; \
template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \ template<typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \ struct triangular_matrix_vector_product<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor,Specialized> { \
static void run(Index _rows, Index _cols, const Scalar* _lhs, Index lhsStride, \ static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, \
const Scalar* _rhs, Index rhsIncr, Scalar* _res, Index resIncr, Scalar alpha) { \ const Scalar* rhs_, Index rhsIncr, Scalar* res_, Index resIncr, Scalar alpha) { \
triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \ triangular_matrix_vector_product_trmv<Index,Mode,Scalar,ConjLhs,Scalar,ConjRhs,RowMajor>::run( \
_rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \ rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
} \ } \
}; };
@ -84,23 +84,23 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
LowUp = IsLower ? Lower : Upper \ LowUp = IsLower ? Lower : Upper \
}; \ }; \
static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \ static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \ const EIGTYPE* rhs_, Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) \
{ \ { \
if (ConjLhs || IsZeroDiag) { \ if (ConjLhs || IsZeroDiag) { \
triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \ triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,ColMajor,BuiltIn>::run( \
_rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \ rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
return; \ return; \
}\ }\
Index size = (std::min)(_rows,_cols); \ Index size = (std::min)(rows_,cols_); \
Index rows = IsLower ? _rows : size; \ Index rows = IsLower ? rows_ : size; \
Index cols = IsLower ? size : _cols; \ Index cols = IsLower ? size : cols_; \
\ \
typedef VectorX##EIGPREFIX VectorRhs; \ typedef VectorX##EIGPREFIX VectorRhs; \
EIGTYPE *x, *y;\ EIGTYPE *x, *y;\
\ \
/* Set x*/ \ /* Set x*/ \
Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \ Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_,cols,InnerStride<>(rhsIncr)); \
VectorRhs x_tmp; \ VectorRhs x_tmp; \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \ x = x_tmp.data(); \
@ -124,24 +124,24 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \ diag = IsUnitDiag ? 'U' : 'N'; \
\ \
/* call ?TRMV*/ \ /* call ?TRMV*/ \
BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \ BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)x, &incx); \
\ \
/* Add op(a_tr)rhs into res*/ \ /* Add op(a_tr)rhs into res*/ \
BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \ BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)res_, &incy); \
/* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \ /* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \ if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \ x = x_tmp.data(); \
if (size<rows) { \ if (size<rows) { \
y = _res + size*resIncr; \ y = res_ + size*resIncr; \
a = _lhs + size; \ a = lhs_ + size; \
m = convert_index<BlasIndex>(rows-size); \ m = convert_index<BlasIndex>(rows-size); \
n = convert_index<BlasIndex>(size); \ n = convert_index<BlasIndex>(size); \
} \ } \
else { \ else { \
x += size; \ x += size; \
y = _res; \ y = res_; \
a = _lhs + size*lda; \ a = lhs_ + size*lda; \
m = convert_index<BlasIndex>(size); \ m = convert_index<BlasIndex>(size); \
n = convert_index<BlasIndex>(cols-size); \ n = convert_index<BlasIndex>(cols-size); \
} \ } \
@ -173,23 +173,23 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \ IsZeroDiag = (Mode&ZeroDiag) ? 1 : 0, \
LowUp = IsLower ? Lower : Upper \ LowUp = IsLower ? Lower : Upper \
}; \ }; \
static void run(Index _rows, Index _cols, const EIGTYPE* _lhs, Index lhsStride, \ static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, \
const EIGTYPE* _rhs, Index rhsIncr, EIGTYPE* _res, Index resIncr, EIGTYPE alpha) \ const EIGTYPE* rhs_, Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) \
{ \ { \
if (IsZeroDiag) { \ if (IsZeroDiag) { \
triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \ triangular_matrix_vector_product<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,ConjRhs,RowMajor,BuiltIn>::run( \
_rows, _cols, _lhs, lhsStride, _rhs, rhsIncr, _res, resIncr, alpha); \ rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
return; \ return; \
}\ }\
Index size = (std::min)(_rows,_cols); \ Index size = (std::min)(rows_,cols_); \
Index rows = IsLower ? _rows : size; \ Index rows = IsLower ? rows_ : size; \
Index cols = IsLower ? size : _cols; \ Index cols = IsLower ? size : cols_; \
\ \
typedef VectorX##EIGPREFIX VectorRhs; \ typedef VectorX##EIGPREFIX VectorRhs; \
EIGTYPE *x, *y;\ EIGTYPE *x, *y;\
\ \
/* Set x*/ \ /* Set x*/ \
Map<const VectorRhs, 0, InnerStride<> > rhs(_rhs,cols,InnerStride<>(rhsIncr)); \ Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_,cols,InnerStride<>(rhsIncr)); \
VectorRhs x_tmp; \ VectorRhs x_tmp; \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \ x = x_tmp.data(); \
@ -213,24 +213,24 @@ struct triangular_matrix_vector_product_trmv<Index,Mode,EIGTYPE,ConjLhs,EIGTYPE,
diag = IsUnitDiag ? 'U' : 'N'; \ diag = IsUnitDiag ? 'U' : 'N'; \
\ \
/* call ?TRMV*/ \ /* call ?TRMV*/ \
BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)_lhs, &lda, (BLASTYPE*)x, &incx); \ BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)x, &incx); \
\ \
/* Add op(a_tr)rhs into res*/ \ /* Add op(a_tr)rhs into res*/ \
BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)_res, &incy); \ BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha),(const BLASTYPE*)x, &incx, (BLASTYPE*)res_, &incy); \
/* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \ /* Non-square case - doesn't fit to BLAS ?TRMV. Fall to default triangular product*/ \
if (size<(std::max)(rows,cols)) { \ if (size<(std::max)(rows,cols)) { \
if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \ if (ConjRhs) x_tmp = rhs.conjugate(); else x_tmp = rhs; \
x = x_tmp.data(); \ x = x_tmp.data(); \
if (size<rows) { \ if (size<rows) { \
y = _res + size*resIncr; \ y = res_ + size*resIncr; \
a = _lhs + size*lda; \ a = lhs_ + size*lda; \
m = convert_index<BlasIndex>(rows-size); \ m = convert_index<BlasIndex>(rows-size); \
n = convert_index<BlasIndex>(size); \ n = convert_index<BlasIndex>(size); \
} \ } \
else { \ else { \
x += size; \ x += size; \
y = _res; \ y = res_; \
a = _lhs + size; \ a = lhs_ + size; \
m = convert_index<BlasIndex>(size); \ m = convert_index<BlasIndex>(size); \
n = convert_index<BlasIndex>(cols-size); \ n = convert_index<BlasIndex>(cols-size); \
} \ } \

View File

@ -93,9 +93,9 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,C
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
{ {
remove_all_t<ResultType> _res(res.rows(), res.cols()); remove_all_t<ResultType> res_(res.rows(), res.cols());
internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, res_, tolerance);
res.swap(_res); res.swap(res_);
} }
}; };
@ -107,9 +107,9 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,C
{ {
// we need a col-major matrix to hold the result // we need a col-major matrix to hold the result
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType; typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType;
SparseTemporaryType _res(res.rows(), res.cols()); SparseTemporaryType res_(res.rows(), res.cols());
internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, res_, tolerance);
res = _res; res = res_;
} }
}; };
@ -120,9 +120,9 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
{ {
// let's transpose the product to get a column x column product // let's transpose the product to get a column x column product
remove_all_t<ResultType> _res(res.rows(), res.cols()); remove_all_t<ResultType> res_(res.rows(), res.cols());
internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, res_, tolerance);
res.swap(_res); res.swap(res_);
} }
}; };
@ -140,9 +140,9 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R
// let's transpose the product to get a column x column product // let's transpose the product to get a column x column product
// typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
// SparseTemporaryType _res(res.cols(), res.rows()); // SparseTemporaryType res_(res.cols(), res.rows());
// sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, res_);
// res = _res.transpose(); // res = res_.transpose();
} }
}; };

View File

@ -125,6 +125,10 @@ struct imag {};
// B0 is defined in POSIX header termios.h // B0 is defined in POSIX header termios.h
#define B0 FORBIDDEN_IDENTIFIER #define B0 FORBIDDEN_IDENTIFIER
#define I FORBIDDEN_IDENTIFIER #define I FORBIDDEN_IDENTIFIER
// _res is defined by resolv.h
#define _res FORBIDDEN_IDENTIFIER
// Unit tests calling Eigen's blas library must preserve the default blocking size // Unit tests calling Eigen's blas library must preserve the default blocking size
// to avoid troubles. // to avoid troubles.
#ifndef EIGEN_NO_DEBUG_SMALL_PRODUCT_BLOCKS #ifndef EIGEN_NO_DEBUG_SMALL_PRODUCT_BLOCKS