mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-23 14:53:13 +08:00
finalize trsm: works in all situations, and it is now used by solve() and solveInPlace()
This commit is contained in:
parent
282e18da49
commit
f3fde74695
@ -29,13 +29,14 @@ template<typename Lhs, typename Rhs,
|
||||
int Mode, // can be Upper/Lower | UnitDiag
|
||||
int Unrolling = Rhs::IsVectorAtCompileTime && Rhs::SizeAtCompileTime <= 8 // FIXME
|
||||
? CompleteUnrolling : NoUnrolling,
|
||||
int StorageOrder = int(Lhs::Flags) & RowMajorBit
|
||||
int StorageOrder = int(Lhs::Flags) & RowMajorBit,
|
||||
int RhsCols = Rhs::ColsAtCompileTime
|
||||
>
|
||||
struct ei_triangular_solver_selector;
|
||||
|
||||
// forward and backward substitution, row-major
|
||||
// forward and backward substitution, row-major, rhs is a vector
|
||||
template<typename Lhs, typename Rhs, int Mode>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor,1>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef ei_blas_traits<Lhs> LhsProductTraits;
|
||||
@ -89,9 +90,9 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,RowMajor>
|
||||
}
|
||||
};
|
||||
|
||||
// forward and backward substitution, column-major
|
||||
// forward and backward substitution, column-major, rhs is a vector
|
||||
template<typename Lhs, typename Rhs, int Mode>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor,1>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef typename ei_packet_traits<Scalar>::type Packet;
|
||||
@ -150,6 +151,24 @@ struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,ColMajor>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, int Mode>
|
||||
struct ei_triangular_solve_matrix;
|
||||
|
||||
// the rhs is a matrix
|
||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder, int RhsCols>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,NoUnrolling,StorageOrder,RhsCols>
|
||||
{
|
||||
typedef typename Rhs::Scalar Scalar;
|
||||
typedef ei_blas_traits<Lhs> LhsProductTraits;
|
||||
typedef typename LhsProductTraits::ActualXprType ActualLhsType;
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{
|
||||
const ActualLhsType& actualLhs = LhsProductTraits::extract(lhs);
|
||||
ei_triangular_solve_matrix<Scalar,StorageOrder,LhsProductTraits::NeedToConjugate,Rhs::Flags&RowMajorBit,Mode>
|
||||
::run(lhs.rows(), rhs.cols(), &actualLhs.coeff(0,0), actualLhs.stride(), &rhs.coeffRef(0,0), rhs.stride());
|
||||
}
|
||||
};
|
||||
|
||||
/***************************************************************************
|
||||
* meta-unrolling implementation
|
||||
***************************************************************************/
|
||||
@ -184,7 +203,7 @@ struct ei_triangular_solver_unroller<Lhs,Rhs,Mode,Index,Size,true> {
|
||||
};
|
||||
|
||||
template<typename Lhs, typename Rhs, int Mode, int StorageOrder>
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder> {
|
||||
struct ei_triangular_solver_selector<Lhs,Rhs,Mode,CompleteUnrolling,StorageOrder,1> {
|
||||
static void run(const Lhs& lhs, Rhs& rhs)
|
||||
{ ei_triangular_solver_unroller<Lhs,Rhs,Mode,0,Rhs::SizeAtCompileTime>::run(lhs,rhs); }
|
||||
};
|
||||
|
@ -26,63 +26,37 @@
|
||||
#define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
||||
|
||||
template<typename Scalar, int nr>
|
||||
struct ei_gemm_pack_rhs_panel
|
||||
struct ei_gemm_pack_rhs_panel;
|
||||
|
||||
// if the rhs is row major, we have to evaluate it in a temporary colmajor matrix
|
||||
template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
|
||||
struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,RowMajor,Mode>
|
||||
{
|
||||
enum { PacketSize = ei_packet_traits<Scalar>::size };
|
||||
void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset)
|
||||
static EIGEN_DONT_INLINE void run(
|
||||
int size, int cols,
|
||||
const Scalar* lhs, int lhsStride,
|
||||
Scalar* _rhs, int rhsStride)
|
||||
{
|
||||
int packet_cols = (cols/nr) * nr;
|
||||
int count = 0;
|
||||
for(int j2=0; j2<packet_cols; j2+=nr)
|
||||
{
|
||||
// skip what we have before
|
||||
count += PacketSize * nr * offset;
|
||||
const Scalar* b0 = &rhs[(j2+0)*rhsStride];
|
||||
const Scalar* b1 = &rhs[(j2+1)*rhsStride];
|
||||
const Scalar* b2 = &rhs[(j2+2)*rhsStride];
|
||||
const Scalar* b3 = &rhs[(j2+3)*rhsStride];
|
||||
for(int k=0; k<depth; k++)
|
||||
{
|
||||
ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[k]));
|
||||
ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b1[k]));
|
||||
if(nr==4) ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b2[k]));
|
||||
if(nr==4) ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b3[k]));
|
||||
count += nr*PacketSize;
|
||||
}
|
||||
// skip what we have after
|
||||
count += PacketSize * nr * (stride-offset-depth);
|
||||
}
|
||||
// copy the remaining columns one at a time (nr==1)
|
||||
for(int j2=packet_cols; j2<cols; ++j2)
|
||||
{
|
||||
count += PacketSize * offset;
|
||||
const Scalar* b0 = &rhs[(j2+0)*rhsStride];
|
||||
for(int k=0; k<depth; k++)
|
||||
{
|
||||
ei_pstore(&blockB[count], ei_pset1(alpha*b0[k]));
|
||||
count += PacketSize;
|
||||
}
|
||||
count += PacketSize * (stride-offset-depth);
|
||||
}
|
||||
Map<Matrix<Scalar,Dynamic,Dynamic> > rhs(_rhs, rhsStride, cols);
|
||||
Matrix<Scalar,Dynamic,Dynamic> aux = rhs.block(0,0,size,cols);
|
||||
ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
|
||||
::run(size, cols, lhs, lhsStride, aux.data(), aux.stride());
|
||||
rhs.block(0,0,size,cols) = aux;
|
||||
}
|
||||
};
|
||||
|
||||
/* Optimized triangular solver with multiple right hand side (_TRSM)
|
||||
*/
|
||||
template <typename Scalar,
|
||||
int LhsStorageOrder,
|
||||
int RhsStorageOrder,
|
||||
int Mode>
|
||||
struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
|
||||
template <typename Scalar, int LhsStorageOrder, bool ConjugateLhs, int Mode>
|
||||
struct ei_triangular_solve_matrix<Scalar,LhsStorageOrder,ConjugateLhs,ColMajor,Mode>
|
||||
{
|
||||
|
||||
static EIGEN_DONT_INLINE void run(
|
||||
int size, int cols,
|
||||
const Scalar* _lhs, int lhsStride,
|
||||
Scalar* _rhs, int rhsStride)
|
||||
{
|
||||
ei_const_blas_data_mapper<Scalar, LhsStorageOrder> lhs(_lhs,lhsStride);
|
||||
ei_blas_data_mapper <Scalar, RhsStorageOrder> rhs(_rhs,rhsStride);
|
||||
ei_blas_data_mapper <Scalar, ColMajor> rhs(_rhs,rhsStride);
|
||||
|
||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||
enum {
|
||||
@ -96,7 +70,8 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
|
||||
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
|
||||
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize);
|
||||
|
||||
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<false,false> > gebp_kernel;
|
||||
ei_conj_if<ConjugateLhs> conj;
|
||||
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, ei_conj_helper<ConjugateLhs,false> > gebp_kernel;
|
||||
ei_gemm_pack_lhs<Scalar,Blocking::mr,LhsStorageOrder> pack_lhs;
|
||||
|
||||
for(int k2=IsLowerTriangular ? 0 : size;
|
||||
@ -131,7 +106,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
|
||||
int s = IsLowerTriangular ? k2+k1 : i+1;
|
||||
int rs = actualPanelWidth - k - 1; // remaining size
|
||||
|
||||
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/lhs(i,i);
|
||||
Scalar a = (Mode & UnitDiagBit) ? Scalar(1) : Scalar(1)/conj(lhs(i,i));
|
||||
for (int j=0; j<cols; ++j)
|
||||
{
|
||||
if (LhsStorageOrder==RowMajor)
|
||||
@ -140,7 +115,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
|
||||
const Scalar* l = &lhs(i,s);
|
||||
Scalar* r = &rhs(s,j);
|
||||
for (int i3=0; i3<k; ++i3)
|
||||
b += l[i3] * r[i3];
|
||||
b += conj(l[i3]) * r[i3];
|
||||
|
||||
rhs(i,j) = (rhs(i,j) - b)*a;
|
||||
}
|
||||
@ -151,7 +126,7 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
|
||||
Scalar* r = &rhs(s,j);
|
||||
const Scalar* l = &lhs(s,i);
|
||||
for (int i3=0;i3<rs;++i3)
|
||||
r[i3] -= b * l[i3];
|
||||
r[i3] -= b * conj(l[i3]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -199,4 +174,46 @@ struct ei_triangular_solve_matrix//<Scalar,LhsStorageOrder,RhsStorageOrder>
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar, int nr>
|
||||
struct ei_gemm_pack_rhs_panel
|
||||
{
|
||||
enum { PacketSize = ei_packet_traits<Scalar>::size };
|
||||
void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int depth, int cols, int stride, int offset)
|
||||
{
|
||||
int packet_cols = (cols/nr) * nr;
|
||||
int count = 0;
|
||||
for(int j2=0; j2<packet_cols; j2+=nr)
|
||||
{
|
||||
// skip what we have before
|
||||
count += PacketSize * nr * offset;
|
||||
const Scalar* b0 = &rhs[(j2+0)*rhsStride];
|
||||
const Scalar* b1 = &rhs[(j2+1)*rhsStride];
|
||||
const Scalar* b2 = &rhs[(j2+2)*rhsStride];
|
||||
const Scalar* b3 = &rhs[(j2+3)*rhsStride];
|
||||
for(int k=0; k<depth; k++)
|
||||
{
|
||||
ei_pstore(&blockB[count+0*PacketSize], ei_pset1(alpha*b0[k]));
|
||||
ei_pstore(&blockB[count+1*PacketSize], ei_pset1(alpha*b1[k]));
|
||||
if(nr==4) ei_pstore(&blockB[count+2*PacketSize], ei_pset1(alpha*b2[k]));
|
||||
if(nr==4) ei_pstore(&blockB[count+3*PacketSize], ei_pset1(alpha*b3[k]));
|
||||
count += nr*PacketSize;
|
||||
}
|
||||
// skip what we have after
|
||||
count += PacketSize * nr * (stride-offset-depth);
|
||||
}
|
||||
// copy the remaining columns one at a time (nr==1)
|
||||
for(int j2=packet_cols; j2<cols; ++j2)
|
||||
{
|
||||
count += PacketSize * offset;
|
||||
const Scalar* b0 = &rhs[(j2+0)*rhsStride];
|
||||
for(int k=0; k<depth; k++)
|
||||
{
|
||||
ei_pstore(&blockB[count], ei_pset1(alpha*b0[k]));
|
||||
count += PacketSize;
|
||||
}
|
||||
count += PacketSize * (stride-offset-depth);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H
|
||||
|
@ -101,6 +101,7 @@ ei_add_test(product_extra ${EI_OFLAG})
|
||||
ei_add_test(product_selfadjoint ${EI_OFLAG})
|
||||
ei_add_test(product_symm ${EI_OFLAG})
|
||||
ei_add_test(product_syrk ${EI_OFLAG})
|
||||
ei_add_test(product_trsm ${EI_OFLAG})
|
||||
ei_add_test(diagonalmatrices)
|
||||
ei_add_test(adjoint)
|
||||
ei_add_test(submatrices)
|
||||
|
95
test/product_trsm.cpp
Normal file
95
test/product_trsm.cpp
Normal file
@ -0,0 +1,95 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@gmail.com>
|
||||
//
|
||||
// Eigen is free software; you can redistribute it and/or
|
||||
// modify it under the terms of the GNU Lesser General Public
|
||||
// License as published by the Free Software Foundation; either
|
||||
// version 3 of the License, or (at your option) any later version.
|
||||
//
|
||||
// Alternatively, you can redistribute it and/or
|
||||
// modify it under the terms of the GNU General Public License as
|
||||
// published by the Free Software Foundation; either version 2 of
|
||||
// the License, or (at your option) any later version.
|
||||
//
|
||||
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
|
||||
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public
|
||||
// License and a copy of the GNU General Public License along with
|
||||
// Eigen. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
#include "main.h"
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
void solve_ref(const Lhs& lhs, Rhs& rhs)
|
||||
{
|
||||
for (int j=0; j<rhs.cols(); ++j)
|
||||
lhs.solveInPlace(rhs.col(j));
|
||||
}
|
||||
|
||||
template<typename Scalar> void trsm(int size,int cols)
|
||||
{
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
|
||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmLhs(size,size);
|
||||
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmLhs(size,size);
|
||||
|
||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> cmRef(size,cols), cmRhs(size,cols);
|
||||
Matrix<Scalar,Dynamic,Dynamic,RowMajor> rmRef(size,cols), rmRhs(size,cols);
|
||||
|
||||
cmLhs.setRandom(); cmLhs.diagonal().cwise() += 10;
|
||||
rmLhs.setRandom(); rmLhs.diagonal().cwise() += 10;
|
||||
|
||||
cmRhs.setRandom(); cmRef = cmRhs;
|
||||
cmLhs.conjugate().template triangularView<LowerTriangular>().solveInPlace(cmRhs);
|
||||
solve_ref(cmLhs.conjugate().template triangularView<LowerTriangular>(),cmRef);
|
||||
VERIFY_IS_APPROX(cmRhs, cmRef);
|
||||
|
||||
cmRhs.setRandom(); cmRef = cmRhs;
|
||||
cmLhs.conjugate().template triangularView<UpperTriangular>().solveInPlace(cmRhs);
|
||||
solve_ref(cmLhs.conjugate().template triangularView<UpperTriangular>(),cmRef);
|
||||
VERIFY_IS_APPROX(cmRhs, cmRef);
|
||||
|
||||
rmRhs.setRandom(); rmRef = rmRhs;
|
||||
cmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs);
|
||||
solve_ref(cmLhs.template triangularView<LowerTriangular>(),rmRef);
|
||||
VERIFY_IS_APPROX(rmRhs, rmRef);
|
||||
|
||||
rmRhs.setRandom(); rmRef = rmRhs;
|
||||
cmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs);
|
||||
solve_ref(cmLhs.template triangularView<UpperTriangular>(),rmRef);
|
||||
VERIFY_IS_APPROX(rmRhs, rmRef);
|
||||
|
||||
|
||||
cmRhs.setRandom(); cmRef = cmRhs;
|
||||
rmLhs.template triangularView<UnitLowerTriangular>().solveInPlace(cmRhs);
|
||||
solve_ref(rmLhs.template triangularView<UnitLowerTriangular>(),cmRef);
|
||||
VERIFY_IS_APPROX(cmRhs, cmRef);
|
||||
|
||||
cmRhs.setRandom(); cmRef = cmRhs;
|
||||
rmLhs.template triangularView<UnitUpperTriangular>().solveInPlace(cmRhs);
|
||||
solve_ref(rmLhs.template triangularView<UnitUpperTriangular>(),cmRef);
|
||||
VERIFY_IS_APPROX(cmRhs, cmRef);
|
||||
|
||||
rmRhs.setRandom(); rmRef = rmRhs;
|
||||
rmLhs.template triangularView<LowerTriangular>().solveInPlace(rmRhs);
|
||||
solve_ref(rmLhs.template triangularView<LowerTriangular>(),rmRef);
|
||||
VERIFY_IS_APPROX(rmRhs, rmRef);
|
||||
|
||||
rmRhs.setRandom(); rmRef = rmRhs;
|
||||
rmLhs.template triangularView<UpperTriangular>().solveInPlace(rmRhs);
|
||||
solve_ref(rmLhs.template triangularView<UpperTriangular>(),rmRef);
|
||||
VERIFY_IS_APPROX(rmRhs, rmRef);
|
||||
}
|
||||
void test_product_trsm()
|
||||
{
|
||||
for(int i = 0; i < g_repeat ; i++)
|
||||
{
|
||||
trsm<float>(ei_random<int>(1,320),ei_random<int>(1,320));
|
||||
trsm<std::complex<double> >(ei_random<int>(1,320),ei_random<int>(1,320));
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user