in ScalarMultiple, make the factor type independent from the matrix scalar type.

This is an optimization for complex matrices, allowing to do only a real multiplication
when a complex multiplication is not needed, e.g. in normalized().
This commit is contained in:
Benoit Jacob 2007-12-26 08:30:21 +00:00
parent dad245af56
commit 05a49547e1
5 changed files with 40 additions and 22 deletions

View File

@ -89,7 +89,8 @@ typename NumTraits<Scalar>::Real MatrixBase<Scalar, Derived>::norm() const
} }
template<typename Scalar, typename Derived> template<typename Scalar, typename Derived>
ScalarMultiple<Derived> MatrixBase<Scalar, Derived>::normalized() const const ScalarMultiple<typename NumTraits<Scalar>::Real, Derived>
MatrixBase<Scalar, Derived>::normalized() const
{ {
return (*this) / norm(); return (*this) / norm();
} }

View File

@ -172,4 +172,21 @@ inline bool isApprox(const std::complex<double>& a, const std::complex<double>&
} }
// isApproxOrLessThan wouldn't make sense for complex numbers // isApproxOrLessThan wouldn't make sense for complex numbers
#define EIGEN_MAKE_MORE_OVERLOADED_COMPLEX_OPERATOR_STAR(T,U) \
inline std::complex<T> operator*(U a, const std::complex<T>& b) \
{ \
return std::complex<T>(static_cast<T>(a)*b.real(), \
static_cast<T>(a)*b.imag()); \
} \
inline std::complex<T> operator*(const std::complex<T>& b, U a) \
{ \
return std::complex<T>(static_cast<T>(a)*b.real(), \
static_cast<T>(a)*b.imag()); \
}
EIGEN_MAKE_MORE_OVERLOADED_COMPLEX_OPERATOR_STAR(int, float)
EIGEN_MAKE_MORE_OVERLOADED_COMPLEX_OPERATOR_STAR(int, double)
EIGEN_MAKE_MORE_OVERLOADED_COMPLEX_OPERATOR_STAR(float, double)
EIGEN_MAKE_MORE_OVERLOADED_COMPLEX_OPERATOR_STAR(double, float)
#endif // EIGEN_MATHFUNCTIONS_H #endif // EIGEN_MATHFUNCTIONS_H

View File

@ -149,7 +149,7 @@ template<typename Scalar, typename Derived> class MatrixBase
Scalar dot(const OtherDerived& other) const; Scalar dot(const OtherDerived& other) const;
RealScalar norm2() const; RealScalar norm2() const;
RealScalar norm() const; RealScalar norm() const;
ScalarMultiple<Derived> normalized() const; const ScalarMultiple<RealScalar, Derived> normalized() const;
static Eval<Random<Derived> > random(int rows, int cols); static Eval<Random<Derived> > random(int rows, int cols);
static Eval<Random<Derived> > random(int size); static Eval<Random<Derived> > random(int size);

View File

@ -26,19 +26,19 @@
#ifndef EIGEN_SCALARMULTIPLE_H #ifndef EIGEN_SCALARMULTIPLE_H
#define EIGEN_SCALARMULTIPLE_H #define EIGEN_SCALARMULTIPLE_H
template<typename MatrixType> class ScalarMultiple : NoOperatorEquals, template<typename FactorType, typename MatrixType> class ScalarMultiple : NoOperatorEquals,
public MatrixBase<typename MatrixType::Scalar, ScalarMultiple<MatrixType> > public MatrixBase<typename MatrixType::Scalar, ScalarMultiple<FactorType, MatrixType> >
{ {
public: public:
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Ref MatRef; typedef typename MatrixType::Ref MatRef;
friend class MatrixBase<typename MatrixType::Scalar, ScalarMultiple<MatrixType> >; friend class MatrixBase<Scalar, ScalarMultiple<FactorType, MatrixType> >;
ScalarMultiple(const MatRef& matrix, Scalar scalar) ScalarMultiple(const MatRef& matrix, FactorType factor)
: m_matrix(matrix), m_scalar(scalar) {} : m_matrix(matrix), m_factor(factor) {}
ScalarMultiple(const ScalarMultiple& other) ScalarMultiple(const ScalarMultiple& other)
: m_matrix(other.m_matrix), m_scalar(other.m_scalar) {} : m_matrix(other.m_matrix), m_factor(other.m_factor) {}
private: private:
static const int _RowsAtCompileTime = MatrixType::RowsAtCompileTime, static const int _RowsAtCompileTime = MatrixType::RowsAtCompileTime,
@ -50,35 +50,35 @@ template<typename MatrixType> class ScalarMultiple : NoOperatorEquals,
Scalar _coeff(int row, int col) const Scalar _coeff(int row, int col) const
{ {
return m_matrix.coeff(row, col) * m_scalar; return m_factor * m_matrix.coeff(row, col);
} }
protected: protected:
const MatRef m_matrix; const MatRef m_matrix;
const Scalar m_scalar; const FactorType m_factor;
}; };
#define EIGEN_MAKE_SCALAR_OPS(OtherScalar) \ #define EIGEN_MAKE_SCALAR_OPS(FactorType) \
template<typename Scalar, typename Derived> \ template<typename Scalar, typename Derived> \
const ScalarMultiple<Derived> \ const ScalarMultiple<FactorType, Derived> \
operator*(const MatrixBase<Scalar, Derived>& matrix, \ operator*(const MatrixBase<Scalar, Derived>& matrix, \
OtherScalar scalar) \ FactorType scalar) \
{ \ { \
return ScalarMultiple<Derived>(matrix.ref(), scalar); \ return ScalarMultiple<FactorType, Derived>(matrix.ref(), scalar); \
} \ } \
\ \
template<typename Scalar, typename Derived> \ template<typename Scalar, typename Derived> \
const ScalarMultiple<Derived> \ const ScalarMultiple<FactorType, Derived> \
operator*(OtherScalar scalar, \ operator*(FactorType scalar, \
const MatrixBase<Scalar, Derived>& matrix) \ const MatrixBase<Scalar, Derived>& matrix) \
{ \ { \
return ScalarMultiple<Derived>(matrix.ref(), scalar); \ return ScalarMultiple<FactorType, Derived>(matrix.ref(), scalar); \
} \ } \
\ \
template<typename Scalar, typename Derived> \ template<typename Scalar, typename Derived> \
const ScalarMultiple<Derived> \ const ScalarMultiple<FactorType, Derived> \
operator/(const MatrixBase<Scalar, Derived>& matrix, \ operator/(const MatrixBase<Scalar, Derived>& matrix, \
OtherScalar scalar) \ FactorType scalar) \
{ \ { \
assert(NumTraits<Scalar>::HasFloatingPoint); \ assert(NumTraits<Scalar>::HasFloatingPoint); \
return matrix * (static_cast<Scalar>(1) / scalar); \ return matrix * (static_cast<Scalar>(1) / scalar); \
@ -86,14 +86,14 @@ operator/(const MatrixBase<Scalar, Derived>& matrix, \
\ \
template<typename Scalar, typename Derived> \ template<typename Scalar, typename Derived> \
Derived & \ Derived & \
MatrixBase<Scalar, Derived>::operator*=(const OtherScalar &other) \ MatrixBase<Scalar, Derived>::operator*=(const FactorType &other) \
{ \ { \
return *this = *this * other; \ return *this = *this * other; \
} \ } \
\ \
template<typename Scalar, typename Derived> \ template<typename Scalar, typename Derived> \
Derived & \ Derived & \
MatrixBase<Scalar, Derived>::operator/=(const OtherScalar &other) \ MatrixBase<Scalar, Derived>::operator/=(const FactorType &other) \
{ \ { \
return *this = *this / other; \ return *this = *this / other; \
} }

View File

@ -97,7 +97,7 @@ template<typename MatrixType> class Opposite;
template<typename Lhs, typename Rhs> class Sum; template<typename Lhs, typename Rhs> class Sum;
template<typename Lhs, typename Rhs> class Difference; template<typename Lhs, typename Rhs> class Difference;
template<typename Lhs, typename Rhs> class Product; template<typename Lhs, typename Rhs> class Product;
template<typename MatrixType> class ScalarMultiple; template<typename FactorType, typename MatrixType> class ScalarMultiple;
template<typename MatrixType> class Random; template<typename MatrixType> class Random;
template<typename MatrixType> class Zero; template<typename MatrixType> class Zero;
template<typename MatrixType> class Ones; template<typename MatrixType> class Ones;