Add a faster TRSM kernel and fix a couple of issues

This commit is contained in:
Gael Guennebaud 2009-07-31 13:18:19 +02:00
parent ff20a2ba94
commit a156f5a869
2 changed files with 22 additions and 36 deletions

View File

@ -49,7 +49,7 @@ struct ei_triangular_solver_selector<Lhs,Rhs,OnTheLeft,Mode,NoUnrolling,RowMajor
{ {
static const int PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH; static const int PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
ActualLhsType actualLhs = LhsProductTraits::extract(lhs); ActualLhsType actualLhs = LhsProductTraits::extract(lhs);
const int size = lhs.cols(); const int size = lhs.cols();
for(int pi=IsLowerTriangular ? 0 : size; for(int pi=IsLowerTriangular ? 0 : size;
IsLowerTriangular ? pi<size : pi>0; IsLowerTriangular ? pi<size : pi>0;
@ -224,7 +224,7 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
ei_assert(!(Mode & ZeroDiagBit)); ei_assert(!(Mode & ZeroDiagBit));
ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit)); ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit }; enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit && RhsDerived::IsVectorAtCompileTime };
typedef typename ei_meta_if<copy, typedef typename ei_meta_if<copy,
typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy; typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy;
RhsCopy rhsCopy(rhs); RhsCopy rhsCopy(rhs);

View File

@ -25,7 +25,7 @@
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H #ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix // if the rhs is row major, let's transpose the product
template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder> template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder>
struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor> struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
{ {
@ -34,22 +34,16 @@ struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,Row
const Scalar* tri, int triStride, const Scalar* tri, int triStride,
Scalar* _other, int otherStride) Scalar* _other, int otherStride)
{ {
ei_triangular_solve_matrix< ei_triangular_solve_matrix<
Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft, Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft,
(Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular, (Mode&UnitDiagBit) | ((Mode&UpperTriangular) ? LowerTriangular : UpperTriangular),
!Conjugate, TriStorageOrder, ColMajor> NumTraits<Scalar>::IsComplex && Conjugate,
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
::run(size, cols, tri, triStride, _other, otherStride); ::run(size, cols, tri, triStride, _other, otherStride);
// Map<Matrix<Scalar,Dynamic,Dynamic> > other(_other, otherStride, cols);
// Matrix<Scalar,Dynamic,Dynamic> aux = other.block(0,0,size,cols);
// ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,ColMajor>
// ::run(size, cols, tri, triStride, aux.data(), aux.stride());
// other.block(0,0,size,cols) = aux;
} }
}; };
/* Optimized triangular solver with multiple right hand side (_TRSM) /* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
*/ */
template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder> template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor> struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
@ -190,11 +184,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
Scalar* _other, int otherStride) Scalar* _other, int otherStride)
{ {
int rows = otherSize; int rows = otherSize;
// ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride); ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
// ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride); ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
Map<Matrix<Scalar,Dynamic,Dynamic,TriStorageOrder> > rhs(_tri,size,size);
Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor> > lhs(_other,rows,size);
typedef ei_product_blocking_traits<Scalar> Blocking; typedef ei_product_blocking_traits<Scalar> Blocking;
enum { enum {
@ -203,8 +194,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
}; };
int kc = std::min<int>(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
int mc = std::min<int>(/*Blocking::Max_mc*/32,size); // cache block size along the M direction int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize); Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
@ -214,7 +205,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs; ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel; ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel; ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel;
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false> pack_lhs;
for(int k2=IsLowerTriangular ? size : 0; for(int k2=IsLowerTriangular ? size : 0;
IsLowerTriangular ? k2>0 : k2<size; IsLowerTriangular ? k2>0 : k2<size;
@ -224,7 +214,7 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
int actual_k2 = IsLowerTriangular ? k2-actual_kc : k2 ; int actual_k2 = IsLowerTriangular ? k2-actual_kc : k2 ;
int startPanel = IsLowerTriangular ? 0 : k2+actual_kc; int startPanel = IsLowerTriangular ? 0 : k2+actual_kc;
int rs = IsLowerTriangular ? actual_k2 : size - actual_k2 - actual_kc; int rs = IsLowerTriangular ? actual_k2 : size - actual_k2 - actual_kc;
Scalar* geb = blockB+actual_kc*actual_kc*Blocking::PacketSize; Scalar* geb = blockB+actual_kc*actual_kc*Blocking::PacketSize;
if (rs>0) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, -1, actual_kc, rs); if (rs>0) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, -1, actual_kc, rs);
@ -239,8 +229,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0; int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2; int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n";
if (panelLength>0) if (panelLength>0)
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize, pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
&rhs(actual_k2+panelOffset, actual_j2), triStride, -1, &rhs(actual_k2+panelOffset, actual_j2), triStride, -1,
@ -269,7 +257,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2; int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2;
// GEBP // GEBP
//if (lengthTarget>0)
if(panelLength>0) if(panelLength>0)
{ {
gebp_kernel(&lhs(i2,absolute_j2), otherStride, gebp_kernel(&lhs(i2,absolute_j2), otherStride,
@ -284,18 +271,17 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
{ {
int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k; int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j)); Scalar* r = &lhs(i2,j);
for (int i=0; i<actual_mc; ++i) for (int k3=0; k3<k; ++k3)
{ {
int absolute_i = i2+i; Scalar b = conj(rhs(IsLowerTriangular ? j+1+k3 : absolute_j2+k3,j));
Scalar b = 0; Scalar* a = &lhs(i2,IsLowerTriangular ? j+1+k3 : absolute_j2+k3);
for (int k3=0; k3<k; ++k3) for (int i=0; i<actual_mc; ++i)
if(IsLowerTriangular) r[i] -= a[i] * b;
b += lhs(absolute_i,j+1+k3) * conj(rhs(j+1+k3,j));
else
b += lhs(absolute_i,absolute_j2+k3) * conj(rhs(absolute_j2+k3,j));
lhs(absolute_i,j) = (lhs(absolute_i,j) - b)*a;
} }
Scalar b = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
for (int i=0; i<actual_mc; ++i)
r[i] *= b;
} }
// pack the just computed part of lhs to A // pack the just computed part of lhs to A
@ -304,7 +290,7 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
actual_kc, j2); actual_kc, j2);
} }
} }
if (rs>0) if (rs>0)
gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb, gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
actual_mc, actual_kc, rs); actual_mc, actual_kc, rs);