From b8b6566f0f2cd925e7993d1d0a7f351d9f6cf963 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Thu, 25 Nov 2021 16:11:25 +0000 Subject: [PATCH] Currently, the binding of LLT to Lapacke is done using a large macro. This factors out a large part of the functionality of the macro and implements it explicitly. --- Eigen/src/Cholesky/LLT_LAPACKE.h | 155 ++++++++++++++++++++----------- 1 file changed, 103 insertions(+), 52 deletions(-) diff --git a/Eigen/src/Cholesky/LLT_LAPACKE.h b/Eigen/src/Cholesky/LLT_LAPACKE.h index 6b2bf28c2..d7d75bddc 100644 --- a/Eigen/src/Cholesky/LLT_LAPACKE.h +++ b/Eigen/src/Cholesky/LLT_LAPACKE.h @@ -39,60 +39,111 @@ namespace Eigen { namespace internal { -template struct lapacke_llt; +namespace lapacke_llt_helpers { -#define EIGEN_LAPACKE_LLT(EIGTYPE, BLASTYPE, LAPACKE_PREFIX) \ -template<> struct lapacke_llt \ -{ \ - template \ - static inline Index potrf(MatrixType& m, char uplo) \ - { \ - lapack_int matrix_order; \ - lapack_int size, lda, info, StorageOrder; \ - EIGTYPE* a; \ - eigen_assert(m.rows()==m.cols()); \ - /* Set up parameters for ?potrf */ \ - size = convert_index(m.rows()); \ - StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor; \ - matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; \ - a = &(m.coeffRef(0,0)); \ - lda = convert_index(m.outerStride()); \ -\ - info = LAPACKE_##LAPACKE_PREFIX##potrf( matrix_order, uplo, size, (BLASTYPE*)a, lda ); \ - info = (info==0) ? -1 : info>0 ? 
info-1 : size; \ - return info; \ - } \ -}; \ -template<> struct llt_inplace \ -{ \ - template \ - static Index blocked(MatrixType& m) \ - { \ - return lapacke_llt::potrf(m, 'L'); \ - } \ - template \ - static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \ - { return Eigen::internal::llt_rank_update_lower(mat, vec, sigma); } \ -}; \ -template<> struct llt_inplace \ -{ \ - template \ - static Index blocked(MatrixType& m) \ - { \ - return lapacke_llt::potrf(m, 'U'); \ - } \ - template \ - static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) \ - { \ - Transpose matt(mat); \ - return llt_inplace::rankUpdate(matt, vec.conjugate(), sigma); \ - } \ -}; + // ------------------------------------------------------------------------------------------------------------------- + // Translation from Eigen to Lapacke types + // ------------------------------------------------------------------------------------------------------------------- -EIGEN_LAPACKE_LLT(double, double, d) -EIGEN_LAPACKE_LLT(float, float, s) -EIGEN_LAPACKE_LLT(dcomplex, lapack_complex_double, z) -EIGEN_LAPACKE_LLT(scomplex, lapack_complex_float, c) + // For complex numbers, the types in Eigen and Lapacke are different, but layout compatible. 
+ template struct translate_type; + template<> struct translate_type { using type = float; }; + template<> struct translate_type { using type = double; }; + template<> struct translate_type { using type = lapack_complex_double; }; + template<> struct translate_type { using type = lapack_complex_float; }; + + // ------------------------------------------------------------------------------------------------------------------- + // Dispatch for potrf handling double, float, complex double, complex float types + // ------------------------------------------------------------------------------------------------------------------- + + inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, double* a, lapack_int lda) { + return LAPACKE_dpotrf( matrix_order, uplo, size, a, lda ); + } + + inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, float* a, lapack_int lda) { + return LAPACKE_spotrf( matrix_order, uplo, size, a, lda ); + } + + inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_double* a, lapack_int lda) { + return LAPACKE_zpotrf( matrix_order, uplo, size, a, lda ); + } + + inline lapack_int potrf(lapack_int matrix_order, char uplo, lapack_int size, lapack_complex_float* a, lapack_int lda) { + return LAPACKE_cpotrf( matrix_order, uplo, size, a, lda ); + } + + // ------------------------------------------------------------------------------------------------------------------- + // Dispatch for rank update handling upper and lower parts + // ------------------------------------------------------------------------------------------------------------------- + + template + struct rank_update {}; + + template<> + struct rank_update { + template + static Index run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) { + return Eigen::internal::llt_rank_update_lower(mat, vec, sigma); + } + }; + + template<> + struct rank_update { + template + static Index 
run(MatrixType &mat, const VectorType &vec, const typename MatrixType::RealScalar &sigma) { + Transpose matt(mat); + return Eigen::internal::llt_rank_update_lower(matt, vec.conjugate(), sigma); + } + }; + + // ------------------------------------------------------------------------------------------------------------------- + // Generic lapacke llt implementation that hands off to the dispatch functions + // ------------------------------------------------------------------------------------------------------------------- + + template + struct lapacke_llt { + using BlasType = typename translate_type::type; + template + static Index blocked(MatrixType& m) + { + eigen_assert(m.rows()==m.cols()); + /* Set up parameters for ?potrf */ + lapack_int size = convert_index(m.rows()); + lapack_int StorageOrder = MatrixType::Flags&RowMajorBit?RowMajor:ColMajor; + lapack_int matrix_order = StorageOrder==RowMajor ? LAPACK_ROW_MAJOR : LAPACK_COL_MAJOR; + Scalar* a = &(m.coeffRef(0,0)); + lapack_int lda = convert_index(m.outerStride()); + + lapack_int info = potrf( matrix_order, Mode == Lower ? 'L' : 'U', size, (BlasType*)a, lda ); + info = (info==0) ? -1 : info>0 ? info-1 : size; + return info; + } + + template + static Index rankUpdate(MatrixType& mat, const VectorType& vec, const typename MatrixType::RealScalar& sigma) + { + return rank_update::run(mat, vec, sigma); + } + }; +} +// end namespace lapacke_llt_helpers + +/* + * Here, we just put the generic implementation from lapacke_llt into a full specialization of the llt_inplace + * type. By being a full specialization, the versions defined here thus get precedence over the generic implementation + * in LLT.h for double, float and complex double, complex float types. 
+ */ + +#define EIGEN_LAPACKE_LLT(EIGTYPE) \ +template<> struct llt_inplace : public lapacke_llt_helpers::lapacke_llt {}; \ +template<> struct llt_inplace : public lapacke_llt_helpers::lapacke_llt {}; + +EIGEN_LAPACKE_LLT(double) +EIGEN_LAPACKE_LLT(float) +EIGEN_LAPACKE_LLT(dcomplex) +EIGEN_LAPACKE_LLT(scomplex) + +#undef EIGEN_LAPACKE_LLT } // end namespace internal