mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
implements TRMV level 2 blas routine
This commit is contained in:
parent
d72a8f1e50
commit
1ac9124fac
@ -28,10 +28,10 @@
|
|||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder>
|
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder>
|
||||||
struct product_triangular_vector_selector;
|
struct product_triangular_matrix_vector;
|
||||||
|
|
||||||
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
|
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
|
||||||
struct product_triangular_vector_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor>
|
struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor>
|
||||||
{
|
{
|
||||||
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||||
enum {
|
enum {
|
||||||
@ -39,7 +39,7 @@ struct product_triangular_vector_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar
|
|||||||
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
||||||
};
|
};
|
||||||
static EIGEN_DONT_INLINE void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
|
static EIGEN_DONT_INLINE void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
|
||||||
const RhsScalar* _rhs, Index rhsIncr, const ResScalar* _res, Index resIncr, ResScalar alpha)
|
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
|
||||||
{
|
{
|
||||||
EIGEN_UNUSED_VARIABLE(resIncr);
|
EIGEN_UNUSED_VARIABLE(resIncr);
|
||||||
eigen_assert(resIncr==1);
|
eigen_assert(resIncr==1);
|
||||||
@ -85,7 +85,7 @@ struct product_triangular_vector_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar
|
|||||||
};
|
};
|
||||||
|
|
||||||
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
|
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
|
||||||
struct product_triangular_vector_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor>
|
struct product_triangular_matrix_vector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor>
|
||||||
{
|
{
|
||||||
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||||
enum {
|
enum {
|
||||||
@ -93,7 +93,7 @@ struct product_triangular_vector_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar
|
|||||||
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
||||||
};
|
};
|
||||||
static void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
|
static void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
|
||||||
const RhsScalar* _rhs, Index rhsIncr, const ResScalar* _res, Index resIncr, ResScalar alpha)
|
const RhsScalar* _rhs, Index rhsIncr, ResScalar* _res, Index resIncr, ResScalar alpha)
|
||||||
{
|
{
|
||||||
eigen_assert(rhsIncr==1);
|
eigen_assert(rhsIncr==1);
|
||||||
EIGEN_UNUSED_VARIABLE(rhsIncr);
|
EIGEN_UNUSED_VARIABLE(rhsIncr);
|
||||||
@ -172,7 +172,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
|
|||||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||||
|
|
||||||
internal::product_triangular_vector_selector
|
internal::product_triangular_matrix_vector
|
||||||
<Index,Mode,
|
<Index,Mode,
|
||||||
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
||||||
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
||||||
@ -200,7 +200,7 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
|
|||||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||||
|
|
||||||
internal::product_triangular_vector_selector
|
internal::product_triangular_matrix_vector
|
||||||
<Index,(Mode & UnitDiag) | (Mode & Lower) ? Upper : Lower,
|
<Index,(Mode & UnitDiag) | (Mode & Lower) ? Upper : Lower,
|
||||||
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
||||||
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
||||||
|
@ -139,11 +139,8 @@ int EIGEN_BLAS_FUNC(trsv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
|||||||
|
|
||||||
int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
|
int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar *pa, int *lda, RealScalar *pb, int *incb)
|
||||||
{
|
{
|
||||||
return 0;
|
typedef void (*functype)(int, int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar);
|
||||||
// TODO
|
static functype func[16];
|
||||||
|
|
||||||
typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int);
|
|
||||||
functype func[16];
|
|
||||||
|
|
||||||
static bool init = false;
|
static bool init = false;
|
||||||
if(!init)
|
if(!init)
|
||||||
@ -151,21 +148,21 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
|||||||
for(int k=0; k<16; ++k)
|
for(int k=0; k<16; ++k)
|
||||||
func[k] = 0;
|
func[k] = 0;
|
||||||
|
|
||||||
// func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
|
func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|0, Scalar,false,Scalar,false,ColMajor>::run);
|
||||||
// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
|
func[TR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|0, Scalar,false,Scalar,false,RowMajor>::run);
|
||||||
// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
|
func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|0, Scalar,Conj, Scalar,false,RowMajor>::run);
|
||||||
//
|
|
||||||
// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
|
func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|0, Scalar,false,Scalar,false,ColMajor>::run);
|
||||||
// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
|
func[TR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|0, Scalar,false,Scalar,false,RowMajor>::run);
|
||||||
// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
|
func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|0, Scalar,Conj, Scalar,false,RowMajor>::run);
|
||||||
//
|
|
||||||
// func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|UnitDiagBit,true, ColMajor,false,ColMajor,false,ColMajor>::run);
|
func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|UnitDiag,Scalar,false,Scalar,false,ColMajor>::run);
|
||||||
// func[TR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|UnitDiagBit,true, RowMajor,false,ColMajor,false,ColMajor>::run);
|
func[TR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|UnitDiag,Scalar,false,Scalar,false,RowMajor>::run);
|
||||||
// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,UpperTriangular|UnitDiagBit,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
|
func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|UnitDiag,Scalar,Conj, Scalar,false,RowMajor>::run);
|
||||||
//
|
|
||||||
// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|UnitDiagBit,true, ColMajor,false,ColMajor,false,ColMajor>::run);
|
func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Lower|UnitDiag,Scalar,false,Scalar,false,ColMajor>::run);
|
||||||
// func[TR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|UnitDiagBit,true, RowMajor,false,ColMajor,false,ColMajor>::run);
|
func[TR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|UnitDiag,Scalar,false,Scalar,false,RowMajor>::run);
|
||||||
// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<Scalar,LowerTriangular|UnitDiagBit,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
|
func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector<int,Upper|UnitDiag,Scalar,Conj, Scalar,false,RowMajor>::run);
|
||||||
|
|
||||||
init = true;
|
init = true;
|
||||||
}
|
}
|
||||||
@ -173,11 +170,32 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar
|
|||||||
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
Scalar* a = reinterpret_cast<Scalar*>(pa);
|
||||||
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
Scalar* b = reinterpret_cast<Scalar*>(pb);
|
||||||
|
|
||||||
|
int info = 0;
|
||||||
|
if(UPLO(*uplo)==INVALID) info = 1;
|
||||||
|
else if(OP(*opa)==INVALID) info = 2;
|
||||||
|
else if(DIAG(*diag)==INVALID) info = 3;
|
||||||
|
else if(*n<0) info = 4;
|
||||||
|
else if(*lda<std::max(1,*n)) info = 6;
|
||||||
|
else if(*incb==0) info = 8;
|
||||||
|
if(info)
|
||||||
|
return xerbla_(SCALAR_SUFFIX_UP"TRMV ",&info,6);
|
||||||
|
|
||||||
|
if(*n==0)
|
||||||
|
return 1;
|
||||||
|
|
||||||
|
Scalar* actual_b = get_compact_vector(b,*n,*incb);
|
||||||
|
Matrix<Scalar,Dynamic,1> res(*n);
|
||||||
|
res.setZero();
|
||||||
|
|
||||||
int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3);
|
int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3);
|
||||||
if(code>=16 || func[code]==0)
|
if(code>=16 || func[code]==0)
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
func[code](*n, a, *lda, b, *incb, b, *incb);
|
func[code](*n, *n, a, *lda, actual_b, 1, res.data(), 1, Scalar(1));
|
||||||
|
|
||||||
|
copy_back(res.data(),b,*n,*incb);
|
||||||
|
if(actual_b!=b) delete[] actual_b;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,7 +212,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
|
|||||||
{
|
{
|
||||||
|
|
||||||
// typedef void (*functype)(int, const Scalar *, int, Scalar *, int, Scalar);
|
// typedef void (*functype)(int, const Scalar *, int, Scalar *, int, Scalar);
|
||||||
// functype func[2];
|
// static functype func[2];
|
||||||
|
|
||||||
// static bool init = false;
|
// static bool init = false;
|
||||||
// if(!init)
|
// if(!init)
|
||||||
@ -241,7 +259,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
|
|||||||
int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc)
|
int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc)
|
||||||
{
|
{
|
||||||
// typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar);
|
// typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar);
|
||||||
// functype func[2];
|
// static functype func[2];
|
||||||
//
|
//
|
||||||
// static bool init = false;
|
// static bool init = false;
|
||||||
// if(!init)
|
// if(!init)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user