mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
faster trsm kernel and fix a couple of issues
This commit is contained in:
parent
ff20a2ba94
commit
a156f5a869
@ -224,7 +224,7 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
|
||||
ei_assert(!(Mode & ZeroDiagBit));
|
||||
ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));
|
||||
|
||||
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit };
|
||||
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit && RhsDerived::IsVectorAtCompileTime };
|
||||
typedef typename ei_meta_if<copy,
|
||||
typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy;
|
||||
RhsCopy rhsCopy(rhs);
|
||||
|
@ -25,7 +25,7 @@
|
||||
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
||||
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
||||
|
||||
// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix
|
||||
// if the rhs is row major, let's transpose the product
|
||||
template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder>
|
||||
struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
|
||||
{
|
||||
@ -34,22 +34,16 @@ struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,Row
|
||||
const Scalar* tri, int triStride,
|
||||
Scalar* _other, int otherStride)
|
||||
{
|
||||
|
||||
ei_triangular_solve_matrix<
|
||||
Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft,
|
||||
(Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular,
|
||||
!Conjugate, TriStorageOrder, ColMajor>
|
||||
(Mode&UnitDiagBit) | ((Mode&UpperTriangular) ? LowerTriangular : UpperTriangular),
|
||||
NumTraits<Scalar>::IsComplex && Conjugate,
|
||||
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
|
||||
::run(size, cols, tri, triStride, _other, otherStride);
|
||||
|
||||
// Map<Matrix<Scalar,Dynamic,Dynamic> > other(_other, otherStride, cols);
|
||||
// Matrix<Scalar,Dynamic,Dynamic> aux = other.block(0,0,size,cols);
|
||||
// ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,ColMajor>
|
||||
// ::run(size, cols, tri, triStride, aux.data(), aux.stride());
|
||||
// other.block(0,0,size,cols) = aux;
|
||||
}
|
||||
};
|
||||
|
||||
/* Optimized triangular solver with multiple right hand side (_TRSM)
|
||||
/* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
|
||||
*/
|
||||
template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
|
||||
struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
|
||||
@ -190,11 +184,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
||||
Scalar* _other, int otherStride)
|
||||
{
|
||||
int rows = otherSize;
|
||||
// ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
|
||||
// ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
|
||||
|
||||
Map<Matrix<Scalar,Dynamic,Dynamic,TriStorageOrder> > rhs(_tri,size,size);
|
||||
Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor> > lhs(_other,rows,size);
|
||||
ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
|
||||
ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
|
||||
|
||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||
enum {
|
||||
@ -203,8 +194,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
||||
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
|
||||
};
|
||||
|
||||
int kc = std::min<int>(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction
|
||||
int mc = std::min<int>(/*Blocking::Max_mc*/32,size); // cache block size along the M direction
|
||||
int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
|
||||
int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction
|
||||
|
||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
|
||||
@ -214,7 +205,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
||||
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
|
||||
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
|
||||
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel;
|
||||
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false> pack_lhs;
|
||||
|
||||
for(int k2=IsLowerTriangular ? size : 0;
|
||||
IsLowerTriangular ? k2>0 : k2<size;
|
||||
@ -239,8 +229,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
||||
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
|
||||
int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
|
||||
|
||||
// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n";
|
||||
|
||||
if (panelLength>0)
|
||||
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
|
||||
&rhs(actual_k2+panelOffset, actual_j2), triStride, -1,
|
||||
@ -269,7 +257,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
||||
int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2;
|
||||
|
||||
// GEBP
|
||||
//if (lengthTarget>0)
|
||||
if(panelLength>0)
|
||||
{
|
||||
gebp_kernel(&lhs(i2,absolute_j2), otherStride,
|
||||
@ -284,18 +271,17 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
|
||||
{
|
||||
int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
|
||||
|
||||
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
|
||||
for (int i=0; i<actual_mc; ++i)
|
||||
{
|
||||
int absolute_i = i2+i;
|
||||
Scalar b = 0;
|
||||
Scalar* r = &lhs(i2,j);
|
||||
for (int k3=0; k3<k; ++k3)
|
||||
if(IsLowerTriangular)
|
||||
b += lhs(absolute_i,j+1+k3) * conj(rhs(j+1+k3,j));
|
||||
else
|
||||
b += lhs(absolute_i,absolute_j2+k3) * conj(rhs(absolute_j2+k3,j));
|
||||
lhs(absolute_i,j) = (lhs(absolute_i,j) - b)*a;
|
||||
{
|
||||
Scalar b = conj(rhs(IsLowerTriangular ? j+1+k3 : absolute_j2+k3,j));
|
||||
Scalar* a = &lhs(i2,IsLowerTriangular ? j+1+k3 : absolute_j2+k3);
|
||||
for (int i=0; i<actual_mc; ++i)
|
||||
r[i] -= a[i] * b;
|
||||
}
|
||||
Scalar b = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
|
||||
for (int i=0; i<actual_mc; ++i)
|
||||
r[i] *= b;
|
||||
}
|
||||
|
||||
// pack the just computed part of lhs to A
|
||||
|
Loading…
x
Reference in New Issue
Block a user