MatrixFunctions: Clean up StemFunction.h

This commit is contained in:
Jitse Niesen 2013-07-26 13:51:10 +01:00
parent 75edc7cc8b
commit e43934d60f
3 changed files with 96 additions and 84 deletions

View File

@ -547,7 +547,7 @@ const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::sin() const
{ {
eigen_assert(rows() == cols()); eigen_assert(rows() == cols());
typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar; typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::sin); return MatrixFunctionReturnValue<Derived>(derived(), internal::stem_function_sin<ComplexScalar>);
} }
template <typename Derived> template <typename Derived>
@ -555,7 +555,7 @@ const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::cos() const
{ {
eigen_assert(rows() == cols()); eigen_assert(rows() == cols());
typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar; typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::cos); return MatrixFunctionReturnValue<Derived>(derived(), internal::stem_function_cos<ComplexScalar>);
} }
template <typename Derived> template <typename Derived>
@ -563,7 +563,7 @@ const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::sinh() const
{ {
eigen_assert(rows() == cols()); eigen_assert(rows() == cols());
typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar; typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::sinh); return MatrixFunctionReturnValue<Derived>(derived(), internal::stem_function_sinh<ComplexScalar>);
} }
template <typename Derived> template <typename Derived>
@ -571,7 +571,7 @@ const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::cosh() const
{ {
eigen_assert(rows() == cols()); eigen_assert(rows() == cols());
typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar; typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::cosh); return MatrixFunctionReturnValue<Derived>(derived(), internal::stem_function_cosh<ComplexScalar>);
} }
} // end namespace Eigen } // end namespace Eigen

View File

@ -12,93 +12,105 @@
namespace Eigen { namespace Eigen {
/** \ingroup MatrixFunctions_Module namespace internal {
* \brief Stem functions corresponding to standard mathematical functions.
*/ /** \brief The exponential function (and its derivatives). */
template <typename Scalar> template <typename Scalar>
class StdStemFunctions : internal::noncopyable Scalar stem_function_exp(Scalar x, int)
{ {
public: using std::exp;
return exp(x);
}
/** \brief The exponential function (and its derivatives). */ /** \brief Cosine (and its derivatives). */
static Scalar exp(Scalar x, int) template <typename Scalar>
{ Scalar stem_function_cos(Scalar x, int n)
return std::exp(x); {
} using std::cos;
using std::sin;
Scalar res;
/** \brief Cosine (and its derivatives). */ switch (n % 4) {
static Scalar cos(Scalar x, int n) case 0:
{ res = std::cos(x);
Scalar res; break;
switch (n % 4) { case 1:
case 0: res = -std::sin(x);
res = std::cos(x); break;
break; case 2:
case 1: res = -std::cos(x);
res = -std::sin(x); break;
break; case 3:
case 2: res = std::sin(x);
res = -std::cos(x); break;
break; }
case 3: return res;
res = std::sin(x); }
break;
}
return res;
}
/** \brief Sine (and its derivatives). */ /** \brief Sine (and its derivatives). */
static Scalar sin(Scalar x, int n) template <typename Scalar>
{ Scalar stem_function_sin(Scalar x, int n)
Scalar res; {
switch (n % 4) { using std::cos;
case 0: using std::sin;
res = std::sin(x); Scalar res;
break;
case 1:
res = std::cos(x);
break;
case 2:
res = -std::sin(x);
break;
case 3:
res = -std::cos(x);
break;
}
return res;
}
/** \brief Hyperbolic cosine (and its derivatives). */ switch (n % 4) {
static Scalar cosh(Scalar x, int n) case 0:
{ res = std::sin(x);
Scalar res; break;
switch (n % 2) { case 1:
case 0: res = std::cos(x);
res = std::cosh(x); break;
break; case 2:
case 1: res = -std::sin(x);
res = std::sinh(x); break;
break; case 3:
} res = -std::cos(x);
return res; break;
} }
return res;
}
/** \brief Hyperbolic cosine (and its derivatives). */
template <typename Scalar>
Scalar stem_function_cosh(Scalar x, int n)
{
using std::cosh;
using std::sinh;
Scalar res;
switch (n % 2) {
case 0:
res = std::cosh(x);
break;
case 1:
res = std::sinh(x);
break;
}
return res;
}
/** \brief Hyperbolic sine (and its derivatives). */ /** \brief Hyperbolic sine (and its derivatives). */
static Scalar sinh(Scalar x, int n) template <typename Scalar>
{ Scalar stem_function_sinh(Scalar x, int n)
Scalar res; {
switch (n % 2) { using std::cosh;
case 0: using std::sinh;
res = std::sinh(x); Scalar res;
break;
case 1: switch (n % 2) {
res = std::cosh(x); case 0:
break; res = std::sinh(x);
} break;
return res; case 1:
} res = std::cosh(x);
break;
}
return res;
}
}; // end of class StdStemFunctions } // end namespace internal
} // end namespace Eigen } // end namespace Eigen

View File

@ -102,7 +102,7 @@ void testMatrixExponential(const MatrixType& A)
typedef typename NumTraits<Scalar>::Real RealScalar; typedef typename NumTraits<Scalar>::Real RealScalar;
typedef std::complex<RealScalar> ComplexScalar; typedef std::complex<RealScalar> ComplexScalar;
VERIFY_IS_APPROX(A.exp(), A.matrixFunction(StdStemFunctions<ComplexScalar>::exp)); VERIFY_IS_APPROX(A.exp(), A.matrixFunction(internal::stem_function_exp<ComplexScalar>));
} }
template<typename MatrixType> template<typename MatrixType>