bug #1380: fix matrix exponential with Map<>

This commit is contained in:
Gael Guennebaud 2017-01-30 13:55:27 +01:00
parent c86911ac73
commit 63de19c000

View File

@ -204,7 +204,8 @@ struct matrix_exp_computeUV
template <typename MatrixType> template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, float> struct matrix_exp_computeUV<MatrixType, float>
{ {
static void run(const MatrixType& arg, MatrixType& U, MatrixType& V, int& squarings) template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings)
{ {
using std::frexp; using std::frexp;
using std::pow; using std::pow;
@ -227,7 +228,8 @@ struct matrix_exp_computeUV<MatrixType, float>
template <typename MatrixType> template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, double> struct matrix_exp_computeUV<MatrixType, double>
{ {
static void run(const MatrixType& arg, MatrixType& U, MatrixType& V, int& squarings) template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings)
{ {
using std::frexp; using std::frexp;
using std::pow; using std::pow;
@ -254,7 +256,8 @@ struct matrix_exp_computeUV<MatrixType, double>
template <typename MatrixType> template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, long double> struct matrix_exp_computeUV<MatrixType, long double>
{ {
static void run(const MatrixType& arg, MatrixType& U, MatrixType& V, int& squarings) template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings)
{ {
#if LDBL_MANT_DIG == 53 // double precision #if LDBL_MANT_DIG == 53 // double precision
matrix_exp_computeUV<MatrixType, double>::run(arg, U, V, squarings); matrix_exp_computeUV<MatrixType, double>::run(arg, U, V, squarings);
@ -351,11 +354,11 @@ void matrix_exp_compute(const MatrixType& arg, ResultType &result)
return; return;
} }
#endif #endif
MatrixType U, V; typename MatrixType::PlainObject U, V;
int squarings; int squarings;
matrix_exp_computeUV<MatrixType>::run(arg, U, V, squarings); // Pade approximant is (U+V) / (-U+V) matrix_exp_computeUV<MatrixType>::run(arg, U, V, squarings); // Pade approximant is (U+V) / (-U+V)
MatrixType numer = U + V; typename MatrixType::PlainObject numer = U + V;
MatrixType denom = -U + V; typename MatrixType::PlainObject denom = -U + V;
result = denom.partialPivLu().solve(numer); result = denom.partialPivLu().solve(numer);
for (int i=0; i<squarings; i++) for (int i=0; i<squarings; i++)
result *= result; // undo scaling by repeated squaring result *= result; // undo scaling by repeated squaring