mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
faster trsm kernel and fix a couple of issues
This commit is contained in:
parent
ff20a2ba94
commit
a156f5a869
@ -224,7 +224,7 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
|
|||||||
ei_assert(!(Mode & ZeroDiagBit));
|
ei_assert(!(Mode & ZeroDiagBit));
|
||||||
ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));
|
ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));
|
||||||
|
|
||||||
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit };
|
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit && RhsDerived::IsVectorAtCompileTime };
|
||||||
typedef typename ei_meta_if<copy,
|
typedef typename ei_meta_if<copy,
|
||||||
typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy;
|
typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy;
|
||||||
RhsCopy rhsCopy(rhs);
|
RhsCopy rhsCopy(rhs);
|
||||||
|
@ -25,7 +25,7 @@
|
|||||||
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
||||||
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
||||||
|
|
||||||
// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix
|
// if the rhs is row major, let's transpose the product
|
||||||
template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder>
|
template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder>
|
||||||
struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
|
struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
|
||||||
{
|
{
|
||||||
@ -34,22 +34,16 @@ struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,Row
|
|||||||
const Scalar* tri, int triStride,
|
const Scalar* tri, int triStride,
|
||||||
Scalar* _other, int otherStride)
|
Scalar* _other, int otherStride)
|
||||||
{
|
{
|
||||||
|
|
||||||
ei_triangular_solve_matrix<
|
ei_triangular_solve_matrix<
|
||||||
Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft,
|
Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft,
|
||||||
(Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular,
|
(Mode&UnitDiagBit) | ((Mode&UpperTriangular) ? LowerTriangular : UpperTriangular),
|
||||||
!Conjugate, TriStorageOrder, ColMajor>
|
NumTraits<Scalar>::IsComplex && Conjugate,
|
||||||
|
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
|
||||||
::run(size, cols, tri, triStride, _other, otherStride);
|
::run(size, cols, tri, triStride, _other, otherStride);
|
||||||
|
|
||||||
// Map<Matrix<Scalar,Dynamic,Dynamic> > other(_other, otherStride, cols);
|
|
||||||
// Matrix<Scalar,Dynamic,Dynamic> aux = other.block(0,0,size,cols);
|
|
||||||
// ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,ColMajor>
|
|
||||||
// ::run(size, cols, tri, triStride, aux.data(), aux.stride());
|
|
||||||
// other.block(0,0,size,cols) = aux;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Optimized triangular solver with multiple right hand side (_TRSM)
|
/* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
|
||||||
*/
|
*/
|
||||||
template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
|
template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
|
||||||
struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
|
struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
|
||||||
@ -190,11 +184,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
|||||||
Scalar* _other, int otherStride)
|
Scalar* _other, int otherStride)
|
||||||
{
|
{
|
||||||
int rows = otherSize;
|
int rows = otherSize;
|
||||||
// ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
|
ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
|
||||||
// ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
|
ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
|
||||||
|
|
||||||
Map<Matrix<Scalar,Dynamic,Dynamic,TriStorageOrder> > rhs(_tri,size,size);
|
|
||||||
Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor> > lhs(_other,rows,size);
|
|
||||||
|
|
||||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||||
enum {
|
enum {
|
||||||
@ -203,8 +194,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
|||||||
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
|
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
|
||||||
};
|
};
|
||||||
|
|
||||||
int kc = std::min<int>(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction
|
int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
|
||||||
int mc = std::min<int>(/*Blocking::Max_mc*/32,size); // cache block size along the M direction
|
int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction
|
||||||
|
|
||||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||||
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
|
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
|
||||||
@ -214,7 +205,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
|||||||
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
|
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
|
||||||
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
|
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
|
||||||
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel;
|
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel;
|
||||||
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false> pack_lhs;
|
|
||||||
|
|
||||||
for(int k2=IsLowerTriangular ? size : 0;
|
for(int k2=IsLowerTriangular ? size : 0;
|
||||||
IsLowerTriangular ? k2>0 : k2<size;
|
IsLowerTriangular ? k2>0 : k2<size;
|
||||||
@ -239,8 +229,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
|||||||
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
|
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
|
||||||
int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
|
int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
|
||||||
|
|
||||||
// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n";
|
|
||||||
|
|
||||||
if (panelLength>0)
|
if (panelLength>0)
|
||||||
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
|
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
|
||||||
&rhs(actual_k2+panelOffset, actual_j2), triStride, -1,
|
&rhs(actual_k2+panelOffset, actual_j2), triStride, -1,
|
||||||
@ -269,7 +257,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
|||||||
int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2;
|
int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2;
|
||||||
|
|
||||||
// GEBP
|
// GEBP
|
||||||
//if (lengthTarget>0)
|
|
||||||
if(panelLength>0)
|
if(panelLength>0)
|
||||||
{
|
{
|
||||||
gebp_kernel(&lhs(i2,absolute_j2), otherStride,
|
gebp_kernel(&lhs(i2,absolute_j2), otherStride,
|
||||||
@ -284,18 +271,17 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
|||||||
{
|
{
|
||||||
int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
|
int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
|
||||||
|
|
||||||
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
|
Scalar* r = &lhs(i2,j);
|
||||||
for (int i=0; i<actual_mc; ++i)
|
for (int k3=0; k3<k; ++k3)
|
||||||
{
|
{
|
||||||
int absolute_i = i2+i;
|
Scalar b = conj(rhs(IsLowerTriangular ? j+1+k3 : absolute_j2+k3,j));
|
||||||
Scalar b = 0;
|
Scalar* a = &lhs(i2,IsLowerTriangular ? j+1+k3 : absolute_j2+k3);
|
||||||
for (int k3=0; k3<k; ++k3)
|
for (int i=0; i<actual_mc; ++i)
|
||||||
if(IsLowerTriangular)
|
r[i] -= a[i] * b;
|
||||||
b += lhs(absolute_i,j+1+k3) * conj(rhs(j+1+k3,j));
|
|
||||||
else
|
|
||||||
b += lhs(absolute_i,absolute_j2+k3) * conj(rhs(absolute_j2+k3,j));
|
|
||||||
lhs(absolute_i,j) = (lhs(absolute_i,j) - b)*a;
|
|
||||||
}
|
}
|
||||||
|
Scalar b = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
|
||||||
|
for (int i=0; i<actual_mc; ++i)
|
||||||
|
r[i] *= b;
|
||||||
}
|
}
|
||||||
|
|
||||||
// pack the just computed part of lhs to A
|
// pack the just computed part of lhs to A
|
||||||
|
Loading…
x
Reference in New Issue
Block a user