faster trsm kernel and fix a couple of issues

This commit is contained in:
Gael Guennebaud 2009-07-31 13:18:19 +02:00
parent ff20a2ba94
commit a156f5a869
2 changed files with 22 additions and 36 deletions

View File

@ -224,7 +224,7 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<RhsDerived>&
ei_assert(!(Mode & ZeroDiagBit));
ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit));
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit };
enum { copy = ei_traits<RhsDerived>::Flags & RowMajorBit && RhsDerived::IsVectorAtCompileTime };
typedef typename ei_meta_if<copy,
typename ei_plain_matrix_type_column_major<RhsDerived>::type, RhsDerived&>::ret RhsCopy;
RhsCopy rhsCopy(rhs);

View File

@ -25,7 +25,7 @@
#ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix
// if the rhs is row major, let's transpose the product
template <typename Scalar, int Side, int Mode, bool Conjugate, int TriStorageOrder>
struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
{
@ -34,22 +34,16 @@ struct ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,Row
const Scalar* tri, int triStride,
Scalar* _other, int otherStride)
{
ei_triangular_solve_matrix<
Scalar, Side==OnTheLeft?OnTheRight:OnTheLeft,
(Mode&UnitDiagBit) | (Mode&UpperTriangular) ? LowerTriangular : UpperTriangular,
!Conjugate, TriStorageOrder, ColMajor>
(Mode&UnitDiagBit) | ((Mode&UpperTriangular) ? LowerTriangular : UpperTriangular),
NumTraits<Scalar>::IsComplex && Conjugate,
TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>
::run(size, cols, tri, triStride, _other, otherStride);
// Map<Matrix<Scalar,Dynamic,Dynamic> > other(_other, otherStride, cols);
// Matrix<Scalar,Dynamic,Dynamic> aux = other.block(0,0,size,cols);
// ei_triangular_solve_matrix<Scalar,Side,Mode,Conjugate,TriStorageOrder,ColMajor>
// ::run(size, cols, tri, triStride, aux.data(), aux.stride());
// other.block(0,0,size,cols) = aux;
}
};
/* Optimized triangular solver with multiple right hand side (_TRSM)
/* Optimized triangular solver with multiple right hand side and the triangular matrix on the left
*/
template <typename Scalar, int Mode, bool Conjugate, int TriStorageOrder>
struct ei_triangular_solve_matrix<Scalar,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
@ -190,11 +184,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
Scalar* _other, int otherStride)
{
int rows = otherSize;
// ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
// ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
Map<Matrix<Scalar,Dynamic,Dynamic,TriStorageOrder> > rhs(_tri,size,size);
Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor> > lhs(_other,rows,size);
ei_const_blas_data_mapper<Scalar, TriStorageOrder> rhs(_tri,triStride);
ei_blas_data_mapper<Scalar, ColMajor> lhs(_other,otherStride);
typedef ei_product_blocking_traits<Scalar> Blocking;
enum {
@ -203,8 +194,8 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
IsLowerTriangular = (Mode&LowerTriangular) == LowerTriangular
};
int kc = std::min<int>(/*Blocking::Max_kc/4*/32,size); // cache block size along the K direction
int mc = std::min<int>(/*Blocking::Max_mc*/32,size); // cache block size along the M direction
int kc = std::min<int>(Blocking::Max_kc/4,size); // cache block size along the K direction
int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
@ -214,7 +205,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder> pack_rhs;
ei_gemm_pack_rhs<Scalar,Blocking::nr,RhsStorageOrder,true> pack_rhs_panel;
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false, true> pack_lhs_panel;
ei_gemm_pack_lhs<Scalar, Blocking::mr, ColMajor, false> pack_lhs;
for(int k2=IsLowerTriangular ? size : 0;
IsLowerTriangular ? k2>0 : k2<size;
@ -239,8 +229,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
int panelOffset = IsLowerTriangular ? j2+actualPanelWidth : 0;
int panelLength = IsLowerTriangular ? actual_kc-j2-actualPanelWidth : j2;
// std::cerr << "$ " << k2 << " " << j2 << " " << actual_j2 << " " << panelOffset << " " << panelLength << "\n";
if (panelLength>0)
pack_rhs_panel(blockB+j2*actual_kc*Blocking::PacketSize,
&rhs(actual_k2+panelOffset, actual_j2), triStride, -1,
@ -269,7 +257,6 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
int panelLength = IsLowerTriangular ? actual_kc - j2 - actualPanelWidth : j2;
// GEBP
//if (lengthTarget>0)
if(panelLength>0)
{
gebp_kernel(&lhs(i2,absolute_j2), otherStride,
@ -284,18 +271,17 @@ struct ei_triangular_solve_matrix<Scalar,OnTheRight,Mode,Conjugate,TriStorageOrd
{
int j = IsLowerTriangular ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
for (int i=0; i<actual_mc; ++i)
{
int absolute_i = i2+i;
Scalar b = 0;
Scalar* r = &lhs(i2,j);
for (int k3=0; k3<k; ++k3)
if(IsLowerTriangular)
b += lhs(absolute_i,j+1+k3) * conj(rhs(j+1+k3,j));
else
b += lhs(absolute_i,absolute_j2+k3) * conj(rhs(absolute_j2+k3,j));
lhs(absolute_i,j) = (lhs(absolute_i,j) - b)*a;
{
Scalar b = conj(rhs(IsLowerTriangular ? j+1+k3 : absolute_j2+k3,j));
Scalar* a = &lhs(i2,IsLowerTriangular ? j+1+k3 : absolute_j2+k3);
for (int i=0; i<actual_mc; ++i)
r[i] -= a[i] * b;
}
Scalar b = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
for (int i=0; i<actual_mc; ++i)
r[i] *= b;
}
// pack the just computed part of lhs to A