finalize trsm: works in all situations, and it is now used by solve() and solveInPlace()

This commit is contained in:
Gael Guennebaud 2009-07-26 13:01:37 +02:00
parent 282e18da49
commit f3fde74695
4 changed files with 185 additions and 53 deletions

View File

@ -29,13 +29,14 @@ template<typename Lhs, typename Rhs,
int Mode, // can be Upper/Lower | UnitDiag int Mode, // can be Upper/Lower | UnitDiag
int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
? CompleteUnrolling : NoUnrolling, ? CompleteUnrolling : NoUnrolling,
int StorageOrder = int(Lhs::Flags) & RowMajorBit int StorageOrder = int(Lhs::Flags) & RowMajorBit,
int RhsCols = Rhs::ColsAtCompileTime
> >
struct ei_triangular_solver_selector; struct ei_triangular_solver_selector;
// forward and backward substitution, row-major // forward and backward substitution, row-major, rhs is a vector
template<typename Lhs, typename Rhs, int Mode> template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor> struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor,1>
{ {
typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Scalar Scalar;
typedef ei_blas_traits<Lhs> LhsProductTraits; typedef ei_blas_traits<Lhs> LhsProductTraits;
@ -89,9 +90,9 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
} }
}; };
// forward and backward substitution, column-major // forward and backward substitution, column-major, rhs is a vector
template<typename Lhs, typename Rhs, int Mode> template<typename Lhs, typename Rhs, int Mode>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor> struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor,1>
{ {
typedef typename Rhs::Scalar Scalar; typedef typename Rhs::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type Packet; typedef typename ei_packet_traits<Scalar>::type Packet;
@ -150,6 +151,24 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
} }
}; };
template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, int Mode>
struct ei_triangular_solve_matrix;
// the rhs is a matrix
template<typename Lhs, typename Rhs, int Mode, int StorageOrder, int RhsCols>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,StorageOrder,RhsCols>
{
typedef typename Rhs::Scalar Scalar;
typedef ei_blas_traits<Lhs> LhsProductTraits;
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
static void run(const Lhs& lhs, Rhs& rhs)
{
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
ei_triangular_solve_matrix<Scalar,StorageOrder,LhsProductTraits::NeedToConjugate,Rhs::Flags&RowMajorBit,Mode>
::run(lhs.rows(), rhs.cols(), &actualLhs.coeff(0,0), actualLhs.stride(), &rhs.coeffRef(0,0), rhs.stride());
}
};
/*************************************************************************** /***************************************************************************
* meta-unrolling implementation * meta-unrolling implementation
***************************************************************************/ ***************************************************************************/
@ -184,7 +203,7 @@ struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
}; };
template<typename Lhs, typename Rhs, int Mode, int StorageOrder> template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> { struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder,1> {
static void run(const Lhs& lhs, Rhs& rhs) static void run(const Lhs& lhs, Rhs& rhs)
{ ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); } { ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
}; };

View File

@ -26,63 +26,37 @@
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
template<typename Scalar, int nr> template<typename Scalar, int nr>
struct ei_gemm_pack_rhs_panel struct ei_gemm_pack_rhs_panel;
// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix
template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,RowMajor,Mode>
{ {
enum { PacketSize = ei_packet_traits<Scalar>::size }; static EIGEN_DONT_INLINE void run(
void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset) int size, int cols,
const Scalar* lhs, int lhsStride,
Scalar* _rhs, int rhsStride)
{ {
int packet_cols = (cols/nr) * nr; Map<Matrix<Scalar,Dynamic,Dynamic> > rhs(_rhs, rhsStride, cols);
int count = 0; Matrix<Scalar,Dynamic,Dynamic> aux = rhs.block(0,0,size,cols);
for(int j2=0; j2<packet_cols; j2+=nr) ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
{ ::run(size, cols, lhs, lhsStride, aux.data(), aux.stride());
// skip what we have before rhs.block(0,0,size,cols) = aux;
count += PacketSize * nr * offset;
const Scalar* b0 = &rhs[(j2+0)*rhsStride];
const Scalar* b1 = &rhs[(j2+1)*rhsStride];
const Scalar* b2 = &rhs[(j2+2)*rhsStride];
const Scalar* b3 = &rhs[(j2+3)*rhsStride];
for(int k=0; k<depth; k++)
{
ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[k]));
ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b1[k]));
if(nr==4) ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b2[k]));
if(nr==4) ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b3[k]));
count += nr*PacketSize;
}
// skip what we have after
count += PacketSize * nr * (stride-offset-depth);
}
// copy the remaining columns one at a time (nr==1)
for(int j2=packet_cols; j2<cols; ++j2)
{
count += PacketSize * offset;
const Scalar* b0 = &rhs[(j2+0)*rhsStride];
for(int k=0; k<depth; k++)
{
ei_pstore(&blockB[count], ei_pset1(alpha*b0[k]));
count += PacketSize;
}
count += PacketSize * (stride-offset-depth);
}
} }
}; };
/* Optimized triangular solver with multiple right hand side (_TRSM) /* Optimized triangular solver with multiple right hand side (_TRSM)
*/ */
template <typename Scalar, template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
int LhsStorageOrder, struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
int RhsStorageOrder,
int Mode>
struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
{ {
static EIGEN_DONT_INLINE void run( static EIGEN_DONT_INLINE void run(
int size, int cols, int size, int cols,
const Scalar* _lhs, int lhsStride, const Scalar* _lhs, int lhsStride,
Scalar* _rhs, int rhsStride) Scalar* _rhs, int rhsStride)
{ {
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
ei_blas_data_mapper <Scalar, RhsStorageOrder> rhs(_rhs,rhsStride); ei_blas_data_mapper <Scalar, ColMajor> rhs(_rhs,rhsStride);
typedef ei_product_blocking_traits<Scalar> Blocking; typedef ei_product_blocking_traits<Scalar> Blocking;
enum { enum {
@ -96,7 +70,8 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<false,false> > gebp_kernel; ei_conj_if<ConjugateLhs> conj;
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,false> > gebp_kernel;
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs; ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
for(int k2=IsLowerTriangular ? 0 : size; for(int k2=IsLowerTriangular ? 0 : size;
@ -131,7 +106,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
int s = IsLowerTriangular ? k2+k1 : i+1; int s = IsLowerTriangular ? k2+k1 : i+1;
int rs = actualPanelWidth - k - 1; // remaining size int rs = actualPanelWidth - k - 1; // remaining size
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/lhs(i,i); Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(lhs(i,i));
for (int j=0; j<cols; ++j) for (int j=0; j<cols; ++j)
{ {
if (LhsStorageOrder==RowMajor) if (LhsStorageOrder==RowMajor)
@ -140,7 +115,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
const Scalar* l = &lhs(i,s); const Scalar* l = &lhs(i,s);
Scalar* r = &rhs(s,j); Scalar* r = &rhs(s,j);
for (int i3=0; i3<k; ++i3) for (int i3=0; i3<k; ++i3)
b += l[i3] * r[i3]; b += conj(l[i3]) * r[i3];
rhs(i,j) = (rhs(i,j) - b)*a; rhs(i,j) = (rhs(i,j) - b)*a;
} }
@ -151,7 +126,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
Scalar* r = &rhs(s,j); Scalar* r = &rhs(s,j);
const Scalar* l = &lhs(s,i); const Scalar* l = &lhs(s,i);
for (int i3=0;i3<rs;++i3) for (int i3=0;i3<rs;++i3)
r[i3] -= b * l[i3]; r[i3] -= b * conj(l[i3]);
} }
} }
} }
@ -199,4 +174,46 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
} }
}; };
template<typename Scalar, int nr>
struct ei_gemm_pack_rhs_panel
{
enum { PacketSize = ei_packet_traits<Scalar>::size };
void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset)
{
int packet_cols = (cols/nr) * nr;
int count = 0;
for(int j2=0; j2<packet_cols; j2+=nr)
{
// skip what we have before
count += PacketSize * nr * offset;
const Scalar* b0 = &rhs[(j2+0)*rhsStride];
const Scalar* b1 = &rhs[(j2+1)*rhsStride];
const Scalar* b2 = &rhs[(j2+2)*rhsStride];
const Scalar* b3 = &rhs[(j2+3)*rhsStride];
for(int k=0; k<depth; k++)
{
ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[k]));
ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b1[k]));
if(nr==4) ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b2[k]));
if(nr==4) ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b3[k]));
count += nr*PacketSize;
}
// skip what we have after
count += PacketSize * nr * (stride-offset-depth);
}
// copy the remaining columns one at a time (nr==1)
for(int j2=packet_cols; j2<cols; ++j2)
{
count += PacketSize * offset;
const Scalar* b0 = &rhs[(j2+0)*rhsStride];
for(int k=0; k<depth; k++)
{
ei_pstore(&blockB[count], ei_pset1(alpha*b0[k]));
count += PacketSize;
}
count += PacketSize * (stride-offset-depth);
}
}
};
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H #endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H

View File

@ -101,6 +101,7 @@ ei_add_test(product_extra ${EI_OFLAG})
ei_add_test(product_selfadjoint ${EI_OFLAG}) ei_add_test(product_selfadjoint ${EI_OFLAG})
ei_add_test(product_symm ${EI_OFLAG}) ei_add_test(product_symm ${EI_OFLAG})
ei_add_test(product_syrk ${EI_OFLAG}) ei_add_test(product_syrk ${EI_OFLAG})
ei_add_test(product_trsm ${EI_OFLAG})
ei_add_test(diagonalmatrices) ei_add_test(diagonalmatrices)
ei_add_test(adjoint) ei_add_test(adjoint)
ei_add_test(submatrices) ei_add_test(submatrices)

95
test/product_trsm.cpp Normal file
View File

@ -0,0 +1,95 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@gmail.com>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 3 of the License, or (at your option) any later version.
//
// Alternatively, you can redistribute it and/or
// modify it under the terms of the GNU General Public License as
// published by the Free Software Foundation; either version 2 of
// the License, or (at your option) any later version.
//
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public
// License and a copy of the GNU General Public License along with
// Eigen. If not, see <http://www.gnu.org/licenses/>.
#include "main.h"
template<typename Lhs, typename Rhs>
void solve_ref(const Lhs& lhs, Rhs& rhs)
{
for (int j=0; j<rhs.cols(); ++j)
lhs.solveInPlace(rhs.col(j));
}
template<typename Scalar> void trsm(int size,int cols)
{
typedef typename NumTraits<Scalar>::Real RealScalar;
Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmLhs(size,size);
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmLhs(size,size);
Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmRef(size,cols), cmRhs(size,cols);
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmRef(size,cols), rmRhs(size,cols);
cmLhs.setRandom(); cmLhs.diagonal().cwise() += 10;
rmLhs.setRandom(); rmLhs.diagonal().cwise() += 10;
cmRhs.setRandom(); cmRef = cmRhs;
cmLhs.conjugate().template triangularView<LowerTriangular>().solveInPlace(cmRhs);
solve_ref(cmLhs.conjugate().template triangularView<LowerTriangular>(),cmRef);
VERIFY_IS_APPROX(cmRhs, cmRef);
cmRhs.setRandom(); cmRef = cmRhs;
cmLhs.conjugate().template triangularView<UpperTriangular>().solveInPlace(cmRhs);
solve_ref(cmLhs.conjugate().template triangularView<UpperTriangular>(),cmRef);
VERIFY_IS_APPROX(cmRhs, cmRef);
rmRhs.setRandom(); rmRef = rmRhs;
cmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs);
solve_ref(cmLhs.template triangularView<LowerTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
rmRhs.setRandom(); rmRef = rmRhs;
cmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs);
solve_ref(cmLhs.template triangularView<UpperTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
cmRhs.setRandom(); cmRef = cmRhs;
rmLhs.template triangularView<UnitLowerTriangular>().solveInPlace(cmRhs);
solve_ref(rmLhs.template triangularView<UnitLowerTriangular>(),cmRef);
VERIFY_IS_APPROX(cmRhs, cmRef);
cmRhs.setRandom(); cmRef = cmRhs;
rmLhs.template triangularView<UnitUpperTriangular>().solveInPlace(cmRhs);
solve_ref(rmLhs.template triangularView<UnitUpperTriangular>(),cmRef);
VERIFY_IS_APPROX(cmRhs, cmRef);
rmRhs.setRandom(); rmRef = rmRhs;
rmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs);
solve_ref(rmLhs.template triangularView<LowerTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
rmRhs.setRandom(); rmRef = rmRhs;
rmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs);
solve_ref(rmLhs.template triangularView<UpperTriangular>(),rmRef);
VERIFY_IS_APPROX(rmRhs, rmRef);
}
void test_product_trsm()
{
for(int i = 0; i < g_repeat ; i++)
{
trsm<float>(ei_random<int>(1,320),ei_random<int>(1,320));
trsm<std::complex<double> >(ei_random<int>(1,320),ei_random<int>(1,320));
}
}