bug #1380: fix matrix exponential with Map<>

This commit is contained in:
Gael Guennebaud 2017-01-30 13:55:27 +01:00
parent c86911ac73
commit 63de19c000

View File

@ -204,7 +204,8 @@ struct matrix_exp_computeUV
template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, float>
{
static void run(const MatrixType& arg, MatrixType& U, MatrixType& V, int& squarings)
template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings)
{
using std::frexp;
using std::pow;
@ -227,7 +228,8 @@ struct matrix_exp_computeUV<MatrixType, float>
template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, double>
{
static void run(const MatrixType& arg, MatrixType& U, MatrixType& V, int& squarings)
template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings)
{
using std::frexp;
using std::pow;
@ -254,7 +256,8 @@ struct matrix_exp_computeUV<MatrixType, double>
template <typename MatrixType>
struct matrix_exp_computeUV<MatrixType, long double>
{
static void run(const MatrixType& arg, MatrixType& U, MatrixType& V, int& squarings)
template <typename ArgType>
static void run(const ArgType& arg, MatrixType& U, MatrixType& V, int& squarings)
{
#if LDBL_MANT_DIG == 53 // double precision
matrix_exp_computeUV<MatrixType, double>::run(arg, U, V, squarings);
@ -351,11 +354,11 @@ void matrix_exp_compute(const MatrixType& arg, ResultType &result)
return;
}
#endif
MatrixType U, V;
typename MatrixType::PlainObject U, V;
int squarings;
matrix_exp_computeUV<MatrixType>::run(arg, U, V, squarings); // Pade approximant is (U+V) / (-U+V)
MatrixType numer = U + V;
MatrixType denom = -U + V;
typename MatrixType::PlainObject numer = U + V;
typename MatrixType::PlainObject denom = -U + V;
result = denom.partialPivLu().solve(numer);
for (int i=0; i<squarings; i++)
result *= result; // undo scaling by repeated squaring