From 1ac9124fac72c11eab3d831e142bba8927c140d0 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sat, 20 Nov 2010 23:29:20 +0100 Subject: [PATCH] implements TRMV level 2 blas routine --- .../Core/products/TriangularMatrixVector.h | 14 ++-- blas/level2_impl.h | 64 ++++++++++++------- 2 files changed, 48 insertions(+), 30 deletions(-) diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index baf5fc9fb..06307f9d4 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -28,10 +28,10 @@ namespace internal { template -struct product_triangular_vector_selector; +struct product_triangular_matrix_vector; template -struct product_triangular_vector_selector +struct product_triangular_matrix_vector { typedef typename scalar_product_traits::ReturnType ResScalar; enum { @@ -39,7 +39,7 @@ struct product_triangular_vector_selector -struct product_triangular_vector_selector +struct product_triangular_matrix_vector { typedef typename scalar_product_traits::ReturnType ResScalar; enum { @@ -93,7 +93,7 @@ struct product_triangular_vector_selector Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) * RhsBlasTraits::extractScalarFactor(m_rhs); - internal::product_triangular_vector_selector + internal::product_triangular_matrix_vector Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) * RhsBlasTraits::extractScalarFactor(m_rhs); - internal::product_triangular_vector_selector + internal::product_triangular_matrix_vector ::run); -// func[TR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// -// func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[TR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// -// func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[TR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// -// func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[TR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); -// func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[NOTR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[TR | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[ADJ | (UP << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + + func[NOTR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[TR | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[ADJ | (LO << 2) | (NUNIT << 3)] = (internal::product_triangular_matrix_vector::run); + + func[NOTR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[TR | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[ADJ | (UP << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + + func[NOTR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[TR | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); + func[ADJ | (LO << 2) | (UNIT << 3)] = (internal::product_triangular_matrix_vector::run); init = true; } @@ -173,11 +170,32 @@ int EIGEN_BLAS_FUNC(trmv)(char *uplo, char *opa, char *diag, int *n, RealScalar Scalar* a = reinterpret_cast(pa); Scalar* b = reinterpret_cast(pb); + int info = 0; + if(UPLO(*uplo)==INVALID) info = 1; + else if(OP(*opa)==INVALID) info = 2; + else if(DIAG(*diag)==INVALID) info = 3; + else if(*n<0) info = 4; + else if(*lda res(*n); + res.setZero(); + int code = OP(*opa) | (UPLO(*uplo) << 2) | (DIAG(*diag) << 3); if(code>=16 || func[code]==0) return 0; - func[code](*n, a, *lda, b, *incb, b, *incb); + func[code](*n, *n, a, *lda, actual_b, 1, res.data(), 1, Scalar(1)); + + copy_back(res.data(),b,*n,*incb); + if(actual_b!=b) delete[] actual_b; + return 0; } @@ -194,7 +212,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, { // typedef void (*functype)(int, const Scalar *, int, Scalar *, int, Scalar); -// functype func[2]; +// static functype func[2]; // static bool init = false; // if(!init) @@ -241,7 +259,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int EIGEN_BLAS_FUNC(syr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *py, int *incy, RealScalar *pc, int *ldc) { // typedef void (*functype)(int, const Scalar *, int, const Scalar *, int, Scalar *, int, Scalar); -// functype func[2]; +// static functype func[2]; // // static bool init = false; // if(!init)