Add special path for matrix<complex>/real.

This also fixes underflow issues when scaling complex matrices through complex/complex operator.
This commit is contained in:
Gael Guennebaud 2015-06-26 16:08:15 +02:00
parent e102ddbf1f
commit 98ff17eb9e
6 changed files with 33 additions and 1 deletions

View File

@ -49,6 +49,8 @@ template<typename Derived> class DenseBase
public: public:
using internal::special_scalar_op_base<Derived,typename internal::traits<Derived>::Scalar, using internal::special_scalar_op_base<Derived,typename internal::traits<Derived>::Scalar,
typename NumTraits<typename internal::traits<Derived>::Scalar>::Real>::operator*; typename NumTraits<typename internal::traits<Derived>::Scalar>::Real>::operator*;
using internal::special_scalar_op_base<Derived,typename internal::traits<Derived>::Scalar,
typename NumTraits<typename internal::traits<Derived>::Scalar>::Real>::operator/;
/** Inner iterator type to iterate over the coefficients of a row or column. /** Inner iterator type to iterate over the coefficients of a row or column.

View File

@ -81,6 +81,7 @@ template<typename Derived> class MatrixBase
using Base::operator*=; using Base::operator*=;
using Base::operator/=; using Base::operator/=;
using Base::operator*; using Base::operator*;
using Base::operator/;
typedef typename Base::CoeffReturnType CoeffReturnType; typedef typename Base::CoeffReturnType CoeffReturnType;
typedef typename Base::ConstTransposeReturnType ConstTransposeReturnType; typedef typename Base::ConstTransposeReturnType ConstTransposeReturnType;

View File

@ -392,6 +392,18 @@ template<typename Scalar>
struct functor_traits<scalar_quotient1_op<Scalar> > struct functor_traits<scalar_quotient1_op<Scalar> >
{ enum { Cost = 2 * NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasDiv }; }; { enum { Cost = 2 * NumTraits<Scalar>::MulCost, PacketAccess = packet_traits<Scalar>::HasDiv }; };
template<typename Scalar1, typename Scalar2>
struct scalar_quotient2_op {
typedef typename scalar_product_traits<Scalar1,Scalar2>::ReturnType result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_quotient2_op(const scalar_quotient2_op& other) : m_other(other.m_other) { }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_quotient2_op(const Scalar2& other) : m_other(other) { }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const Scalar1& a) const { return a / m_other; }
typename add_const_on_value_type<typename NumTraits<Scalar2>::Nested>::type m_other;
};
template<typename Scalar1,typename Scalar2>
struct functor_traits<scalar_quotient2_op<Scalar1,Scalar2> >
{ enum { Cost = 2 * NumTraits<Scalar1>::MulCost, PacketAccess = false }; };
// In Eigen, any binary op (Product, CwiseBinaryOp) require the Lhs and Rhs to have the same scalar type, except for multiplication // In Eigen, any binary op (Product, CwiseBinaryOp) require the Lhs and Rhs to have the same scalar type, except for multiplication
// where the mixing of different types is handled by scalar_product_traits // where the mixing of different types is handled by scalar_product_traits
// In particular, real * complex<real> is allowed. // In particular, real * complex<real> is allowed.

View File

@ -213,6 +213,7 @@ template<typename Scalar> struct scalar_identity_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op; template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
template<typename LhsScalar,typename RhsScalar> struct scalar_multiple2_op; template<typename LhsScalar,typename RhsScalar> struct scalar_multiple2_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_quotient_op; template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_quotient_op;
template<typename LhsScalar,typename RhsScalar> struct scalar_quotient2_op;
} // end namespace internal } // end namespace internal

View File

@ -427,7 +427,9 @@ struct special_scalar_op_base : public DenseCoeffsBase<Derived>
{ {
// dummy operator* so that the // dummy operator* so that the
// "using special_scalar_op_base::operator*" compiles // "using special_scalar_op_base::operator*" compiles
void operator*() const; struct dummy {};
void operator*(dummy) const;
void operator/(dummy) const;
}; };
template<typename Derived,typename Scalar,typename OtherScalar> template<typename Derived,typename Scalar,typename OtherScalar>
@ -451,6 +453,16 @@ struct special_scalar_op_base<Derived,Scalar,OtherScalar,true> : public DenseCo
#endif #endif
return static_cast<const special_scalar_op_base&>(matrix).operator*(scalar); return static_cast<const special_scalar_op_base&>(matrix).operator*(scalar);
} }
const CwiseUnaryOp<scalar_quotient2_op<Scalar,OtherScalar>, Derived>
operator/(const OtherScalar& scalar) const
{
#ifdef EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN
EIGEN_SPECIAL_SCALAR_MULTIPLE_PLUGIN
#endif
return CwiseUnaryOp<scalar_quotient2_op<Scalar,OtherScalar>, Derived>
(*static_cast<const Derived*>(this), scalar_quotient2_op<Scalar,OtherScalar>(scalar));
}
}; };
template<typename XprType, typename CastType> struct cast_return_type template<typename XprType, typename CastType> struct cast_return_type

View File

@ -88,6 +88,10 @@ template<typename MatrixType> void real_complex(DenseIndex rows = MatrixType::Ro
g_called = false; g_called = false;
VERIFY_IS_APPROX(m1*s, m1*Scalar(s)); VERIFY_IS_APPROX(m1*s, m1*Scalar(s));
VERIFY(g_called && "matrix<complex> * real not properly optimized"); VERIFY(g_called && "matrix<complex> * real not properly optimized");
g_called = false;
VERIFY_IS_APPROX(m1/s, m1/Scalar(s));
VERIFY(g_called && "matrix<complex> / real not properly optimized");
} }
void test_linearstructure() void test_linearstructure()