mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 19:59:05 +08:00
* make the triangular matrix * matrix product works with trapezoidal matrices
* extend the trmm unit test for unit diagonal
This commit is contained in:
parent
134ca4acb3
commit
2e792d1f42
@ -49,7 +49,7 @@
|
|||||||
// }
|
// }
|
||||||
// };
|
// };
|
||||||
|
|
||||||
/* Optimized selfadjoint matrix * matrix (_SYMM) product built on top of
|
/* Optimized triangular matrix * matrix (_TRMM++) product built on top of
|
||||||
* the general matrix matrix product.
|
* the general matrix matrix product.
|
||||||
*/
|
*/
|
||||||
template <typename Scalar, typename Index,
|
template <typename Scalar, typename Index,
|
||||||
@ -68,7 +68,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular,
|
|||||||
RhsStorageOrder,ConjugateRhs,RowMajor>
|
RhsStorageOrder,ConjugateRhs,RowMajor>
|
||||||
{
|
{
|
||||||
static EIGEN_STRONG_INLINE void run(
|
static EIGEN_STRONG_INLINE void run(
|
||||||
Index size, Index otherSize,
|
Index rows, Index cols, Index depth,
|
||||||
const Scalar* lhs, Index lhsStride,
|
const Scalar* lhs, Index lhsStride,
|
||||||
const Scalar* rhs, Index rhsStride,
|
const Scalar* rhs, Index rhsStride,
|
||||||
Scalar* res, Index resStride,
|
Scalar* res, Index resStride,
|
||||||
@ -82,7 +82,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,LhsIsTriangular,
|
|||||||
LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
|
LhsStorageOrder==RowMajor ? ColMajor : RowMajor,
|
||||||
ConjugateLhs,
|
ConjugateLhs,
|
||||||
ColMajor>
|
ColMajor>
|
||||||
::run(size, otherSize, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
|
::run(rows, cols, depth, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -96,14 +96,12 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
|||||||
{
|
{
|
||||||
|
|
||||||
static EIGEN_DONT_INLINE void run(
|
static EIGEN_DONT_INLINE void run(
|
||||||
Index size, Index cols,
|
Index rows, Index cols, Index depth,
|
||||||
const Scalar* _lhs, Index lhsStride,
|
const Scalar* _lhs, Index lhsStride,
|
||||||
const Scalar* _rhs, Index rhsStride,
|
const Scalar* _rhs, Index rhsStride,
|
||||||
Scalar* res, Index resStride,
|
Scalar* res, Index resStride,
|
||||||
Scalar alpha)
|
Scalar alpha)
|
||||||
{
|
{
|
||||||
Index rows = size;
|
|
||||||
|
|
||||||
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||||
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||||
|
|
||||||
@ -116,8 +114,8 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
|||||||
IsLower = (Mode&Lower) == Lower
|
IsLower = (Mode&Lower) == Lower
|
||||||
};
|
};
|
||||||
|
|
||||||
Index kc = std::min<Index>(Blocking::Max_kc/4,size); // cache block size along the K direction
|
Index kc = std::min<Index>(Blocking::Max_kc/4,depth); // cache block size along the K direction
|
||||||
Index mc = std::min<Index>(Blocking::Max_mc,rows); // cache block size along the M direction
|
Index mc = std::min<Index>(Blocking::Max_mc,rows); // cache block size along the M direction
|
||||||
|
|
||||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||||
std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*cols;
|
std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*cols;
|
||||||
@ -133,20 +131,27 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
|||||||
ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder> pack_lhs;
|
ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder> pack_lhs;
|
||||||
ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs;
|
ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs;
|
||||||
|
|
||||||
for(Index k2=IsLower ? size : 0;
|
for(Index k2=IsLower ? depth : 0;
|
||||||
IsLower ? k2>0 : k2<size;
|
IsLower ? k2>0 : k2<depth;
|
||||||
IsLower ? k2-=kc : k2+=kc)
|
IsLower ? k2-=kc : k2+=kc)
|
||||||
{
|
{
|
||||||
const Index actual_kc = std::min(IsLower ? k2 : size-k2, kc);
|
Index actual_kc = std::min(IsLower ? k2 : depth-k2, kc);
|
||||||
Index actual_k2 = IsLower ? k2-actual_kc : k2;
|
Index actual_k2 = IsLower ? k2-actual_kc : k2;
|
||||||
|
|
||||||
|
if((!IsLower)&&(k2<rows)&&(k2+actual_kc>rows))
|
||||||
|
{
|
||||||
|
actual_kc = rows-k2;
|
||||||
|
k2 = k2+actual_kc-kc;
|
||||||
|
}
|
||||||
|
|
||||||
pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, alpha, actual_kc, cols);
|
pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, alpha, actual_kc, cols);
|
||||||
|
|
||||||
// the selected lhs's panel has to be split in three different parts:
|
// the selected lhs's panel has to be split in three different parts:
|
||||||
// 1 - the part which is above the diagonal block => skip it
|
// 1 - the part which is above the diagonal block => skip it
|
||||||
// 2 - the diagonal block => special kernel
|
// 2 - the diagonal block => special kernel
|
||||||
// 3 - the panel below the diagonal block => GEPP
|
// 3 - the panel below the diagonal block => GEPP
|
||||||
// the block diagonal
|
// the block diagonal, if any
|
||||||
|
if(IsLower || actual_k2<rows)
|
||||||
{
|
{
|
||||||
// for each small vertical panels of lhs
|
// for each small vertical panels of lhs
|
||||||
for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth)
|
for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth)
|
||||||
@ -186,7 +191,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
|
|||||||
// the part below the diagonal => GEPP
|
// the part below the diagonal => GEPP
|
||||||
{
|
{
|
||||||
Index start = IsLower ? k2 : 0;
|
Index start = IsLower ? k2 : 0;
|
||||||
Index end = IsLower ? size : actual_k2;
|
Index end = IsLower ? rows : actual_k2;
|
||||||
for(Index i2=start; i2<end; i2+=mc)
|
for(Index i2=start; i2<end; i2+=mc)
|
||||||
{
|
{
|
||||||
const Index actual_mc = std::min(i2+mc,end)-i2;
|
const Index actual_mc = std::min(i2+mc,end)-i2;
|
||||||
@ -214,14 +219,12 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
|
|||||||
{
|
{
|
||||||
|
|
||||||
static EIGEN_DONT_INLINE void run(
|
static EIGEN_DONT_INLINE void run(
|
||||||
Index size, Index rows,
|
Index rows, Index cols, Index depth,
|
||||||
const Scalar* _lhs, Index lhsStride,
|
const Scalar* _lhs, Index lhsStride,
|
||||||
const Scalar* _rhs, Index rhsStride,
|
const Scalar* _rhs, Index rhsStride,
|
||||||
Scalar* res, Index resStride,
|
Scalar* res, Index resStride,
|
||||||
Scalar alpha)
|
Scalar alpha)
|
||||||
{
|
{
|
||||||
Index cols = size;
|
|
||||||
|
|
||||||
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||||
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||||
|
|
||||||
@ -234,8 +237,8 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
|
|||||||
IsLower = (Mode&Lower) == Lower
|
IsLower = (Mode&Lower) == Lower
|
||||||
};
|
};
|
||||||
|
|
||||||
Index kc = std::min<Index>(Blocking::Max_kc/4,size); // cache block size along the K direction
|
Index kc = std::min<Index>(Blocking::Max_kc/4,depth); // cache block size along the K direction
|
||||||
Index mc = std::min<Index>(Blocking::Max_mc,rows); // cache block size along the M direction
|
Index mc = std::min<Index>(Blocking::Max_mc,rows); // cache block size along the M direction
|
||||||
|
|
||||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||||
std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*cols;
|
std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*cols;
|
||||||
@ -251,13 +254,13 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
|
|||||||
ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs;
|
ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder> pack_rhs;
|
||||||
ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
|
ei_gemm_pack_rhs<Scalar, Index, Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
|
||||||
|
|
||||||
for(Index k2=IsLower ? 0 : size;
|
for(Index k2=IsLower ? 0 : depth;
|
||||||
IsLower ? k2<size : k2>0;
|
IsLower ? k2<depth : k2>0;
|
||||||
IsLower ? k2+=kc : k2-=kc)
|
IsLower ? k2+=kc : k2-=kc)
|
||||||
{
|
{
|
||||||
const Index actual_kc = std::min(IsLower ? size-k2 : k2, kc);
|
const Index actual_kc = std::min(IsLower ? depth-k2 : k2, kc);
|
||||||
Index actual_k2 = IsLower ? k2 : k2-actual_kc;
|
Index actual_k2 = IsLower ? k2 : k2-actual_kc;
|
||||||
Index rs = IsLower ? actual_k2 : size - k2;
|
Index rs = IsLower ? actual_k2 : depth - k2;
|
||||||
Scalar* geb = blockB+actual_kc*actual_kc;
|
Scalar* geb = blockB+actual_kc*actual_kc;
|
||||||
|
|
||||||
pack_rhs(geb, &rhs(actual_k2,IsLower ? 0 : k2), rhsStride, alpha, actual_kc, rs);
|
pack_rhs(geb, &rhs(actual_k2,IsLower ? 0 : k2), rhsStride, alpha, actual_kc, rs);
|
||||||
@ -355,11 +358,11 @@ struct TriangularProduct<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
|||||||
(ei_traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
|
(ei_traits<_ActualRhsType>::Flags&RowMajorBit) ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
|
||||||
(ei_traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
(ei_traits<Dest >::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||||
::run(
|
::run(
|
||||||
lhs.rows(), LhsIsTriangular ? rhs.cols() : lhs.rows(), // sizes
|
lhs.rows(), rhs.cols(), lhs.cols(),// LhsIsTriangular ? rhs.cols() : lhs.rows(), // sizes
|
||||||
&lhs.coeff(0,0), lhs.outerStride(), // lhs info
|
&lhs.coeff(0,0), lhs.outerStride(), // lhs info
|
||||||
&rhs.coeff(0,0), rhs.outerStride(), // rhs info
|
&rhs.coeff(0,0), rhs.outerStride(), // rhs info
|
||||||
&dst.coeffRef(0,0), dst.outerStride(), // result info
|
&dst.coeffRef(0,0), dst.outerStride(), // result info
|
||||||
actualAlpha // alpha
|
actualAlpha // alpha
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -28,8 +28,11 @@ template<typename Scalar> void trmm(int size,int othersize)
|
|||||||
{
|
{
|
||||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||||
|
|
||||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> tri(size,size), upTri(size,size), loTri(size,size);
|
typedef Matrix<Scalar,Dynamic,Dynamic,ColMajor> MatrixType;
|
||||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> ge1(size,othersize), ge2(10,size), ge3;
|
|
||||||
|
MatrixType tri(size,size), upTri(size,size), loTri(size,size),
|
||||||
|
unitUpTri(size,size), unitLoTri(size,size);
|
||||||
|
MatrixType ge1(size,othersize), ge2(10,size), ge3;
|
||||||
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rge3;
|
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rge3;
|
||||||
|
|
||||||
Scalar s1 = ei_random<Scalar>(),
|
Scalar s1 = ei_random<Scalar>(),
|
||||||
@ -38,6 +41,8 @@ template<typename Scalar> void trmm(int size,int othersize)
|
|||||||
tri.setRandom();
|
tri.setRandom();
|
||||||
loTri = tri.template triangularView<Lower>();
|
loTri = tri.template triangularView<Lower>();
|
||||||
upTri = tri.template triangularView<Upper>();
|
upTri = tri.template triangularView<Upper>();
|
||||||
|
unitLoTri = tri.template triangularView<UnitLower>();
|
||||||
|
unitUpTri = tri.template triangularView<UnitUpper>();
|
||||||
ge1.setRandom();
|
ge1.setRandom();
|
||||||
ge2.setRandom();
|
ge2.setRandom();
|
||||||
|
|
||||||
@ -57,6 +62,10 @@ template<typename Scalar> void trmm(int size,int othersize)
|
|||||||
VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<Upper>() * ge2.adjoint(), loTri.adjoint() * ge2.adjoint());
|
VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<Upper>() * ge2.adjoint(), loTri.adjoint() * ge2.adjoint());
|
||||||
VERIFY_IS_APPROX( ge3 = tri.adjoint().template triangularView<Lower>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint());
|
VERIFY_IS_APPROX( ge3 = tri.adjoint().template triangularView<Lower>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint());
|
||||||
VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<Lower>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint());
|
VERIFY_IS_APPROX(rge3 = tri.adjoint().template triangularView<Lower>() * ge2.adjoint(), upTri.adjoint() * ge2.adjoint());
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX( ge3 = tri.template triangularView<UnitLower>() * ge1, unitLoTri * ge1);
|
||||||
|
VERIFY_IS_APPROX(rge3 = tri.template triangularView<UnitLower>() * ge1, unitLoTri * ge1);
|
||||||
|
VERIFY_IS_APPROX( ge3 = (s1*tri).adjoint().template triangularView<UnitUpper>() * ge2.adjoint(), ei_conj(s1) * unitLoTri.adjoint() * ge2.adjoint());
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_product_trmm()
|
void test_product_trmm()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user