Fix mixing scalar types with evaluators

This commit is contained in:
Gael Guennebaud 2014-02-19 16:30:17 +01:00
parent 8af02d19b2
commit 2eee6eaf3c
7 changed files with 35 additions and 34 deletions

View File

@ -213,7 +213,7 @@ template<typename OtherDerived>
EIGEN_STRONG_INLINE Derived & EIGEN_STRONG_INLINE Derived &
ArrayBase<Derived>::operator*=(const ArrayBase<OtherDerived>& other) ArrayBase<Derived>::operator*=(const ArrayBase<OtherDerived>& other)
{ {
call_assignment(derived(), other.derived(), internal::mul_assign_op<Scalar>()); call_assignment(derived(), other.derived(), internal::mul_assign_op<Scalar,typename OtherDerived::Scalar>());
return derived(); return derived();
} }

View File

@ -729,6 +729,11 @@ void call_assignment_no_alias(Dst& dst, const Src& src, const Func& func)
typedef typename internal::conditional<NeedToTranspose, Transpose<Dst>, Dst&>::type ActualDstType; typedef typename internal::conditional<NeedToTranspose, Transpose<Dst>, Dst&>::type ActualDstType;
ActualDstType actualDst(dst); ActualDstType actualDst(dst);
// TODO check whether this is the right place to perform these checks:
EIGEN_STATIC_ASSERT_LVALUE(Dst)
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(ActualDstTypeCleaned,Src)
EIGEN_CHECK_BINARY_COMPATIBILIY(Func,typename ActualDstTypeCleaned::Scalar,typename Src::Scalar);
Assignment<ActualDstTypeCleaned,Src,Func>::run(actualDst, src, func); Assignment<ActualDstTypeCleaned,Src,Func>::run(actualDst, src, func);
} }
@ -739,15 +744,6 @@ struct Assignment<DstXprType, SrcXprType, Functor, Dense2Dense, Scalar>
{ {
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func) static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
{ {
// TODO check whether this is the right place to perform these checks:
enum{
SameType = internal::is_same<typename DstXprType::Scalar,typename SrcXprType::Scalar>::value
};
EIGEN_STATIC_ASSERT_LVALUE(DstXprType)
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(DstXprType,SrcXprType)
EIGEN_STATIC_ASSERT(SameType,YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols()); eigen_assert(dst.rows() == src.rows() && dst.cols() == src.cols());
#ifdef EIGEN_DEBUG_ASSIGN #ifdef EIGEN_DEBUG_ASSIGN

View File

@ -86,19 +86,6 @@ struct traits<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
}; };
} // end namespace internal } // end namespace internal
// we require Lhs and Rhs to have the same scalar type. Currently there is no example of a binary functor
// that would take two operands of different types. If there were such an example, then this check should be
// moved to the BinaryOp functors, on a per-case basis. This would however require a change in the BinaryOp functors, as
// currently they take only one typename Scalar template parameter.
// It is tempting to always allow mixing different types but remember that this is often impossible in the vectorized paths.
// So allowing mixing different types gives very unexpected errors when enabling vectorization, when the user tries to
// add together a float matrix and a double matrix.
#define EIGEN_CHECK_BINARY_COMPATIBILIY(BINOP,LHS,RHS) \
EIGEN_STATIC_ASSERT((internal::functor_is_product_like<BINOP>::ret \
? int(internal::scalar_product_traits<LHS, RHS>::Defined) \
: int(internal::is_same<LHS, RHS>::value)), \
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
template<typename BinaryOp, typename Lhs, typename Rhs, typename StorageKind> template<typename BinaryOp, typename Lhs, typename Rhs, typename StorageKind>
class CwiseBinaryOpImpl; class CwiseBinaryOpImpl;

View File

@ -647,8 +647,8 @@ struct diagonal_product_evaluator_base
: evaluator_base<Derived> : evaluator_base<Derived>
{ {
typedef typename MatrixType::Index Index; typedef typename MatrixType::Index Index;
typedef typename MatrixType::Scalar Scalar; typedef typename scalar_product_traits<typename MatrixType::Scalar, typename DiagonalType::Scalar>::ReturnType Scalar;
typedef typename MatrixType::PacketScalar PacketScalar; typedef typename internal::packet_traits<Scalar>::type PacketScalar;
public: public:
diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag) diagonal_product_evaluator_base(const MatrixType &mat, const DiagonalType &diag)
: m_diagImpl(diag), m_matImpl(mat) : m_diagImpl(diag), m_matImpl(mat)

View File

@ -81,22 +81,24 @@ struct functor_traits<sub_assign_op<Scalar> > {
* \brief Template functor for scalar/packet assignment with multiplication * \brief Template functor for scalar/packet assignment with multiplication
* *
*/ */
template<typename Scalar> struct mul_assign_op { template<typename DstScalar, typename SrcScalar=DstScalar>
struct mul_assign_op {
EIGEN_EMPTY_STRUCT_CTOR(mul_assign_op) EIGEN_EMPTY_STRUCT_CTOR(mul_assign_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(Scalar& a, const Scalar& b) const { a *= b; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void assignCoeff(DstScalar& a, const SrcScalar& b) const { a *= b; }
template<int Alignment, typename Packet> template<int Alignment, typename Packet>
EIGEN_STRONG_INLINE void assignPacket(Scalar* a, const Packet& b) const EIGEN_STRONG_INLINE void assignPacket(DstScalar* a, const Packet& b) const
{ internal::pstoret<Scalar,Packet,Alignment>(a,internal::pmul(internal::ploadt<Packet,Alignment>(a),b)); } { internal::pstoret<DstScalar,Packet,Alignment>(a,internal::pmul(internal::ploadt<Packet,Alignment>(a),b)); }
}; };
template<typename Scalar> template<typename DstScalar, typename SrcScalar>
struct functor_traits<mul_assign_op<Scalar> > { struct functor_traits<mul_assign_op<DstScalar,SrcScalar> > {
enum { enum {
Cost = NumTraits<Scalar>::ReadCost + NumTraits<Scalar>::MulCost, Cost = NumTraits<DstScalar>::ReadCost + NumTraits<DstScalar>::MulCost,
PacketAccess = packet_traits<Scalar>::HasMul PacketAccess = is_same<DstScalar,SrcScalar>::value && packet_traits<DstScalar>::HasMul
}; };
}; };
template<typename DstScalar,typename SrcScalar> struct functor_is_product_like<mul_assign_op<DstScalar,SrcScalar> > { enum { ret = 1 }; };
/** \internal /** \internal
* \brief Template functor for scalar/packet assignment with diviving * \brief Template functor for scalar/packet assignment with diviving

View File

@ -516,6 +516,19 @@ template<typename T, int S> struct is_diagonal<DiagonalMatrix<T,S> >
} // end namespace internal } // end namespace internal
// we require Lhs and Rhs to have the same scalar type. Currently there is no example of a binary functor
// that would take two operands of different types. If there were such an example, then this check should be
// moved to the BinaryOp functors, on a per-case basis. This would however require a change in the BinaryOp functors, as
// currently they take only one typename Scalar template parameter.
// It is tempting to always allow mixing different types but remember that this is often impossible in the vectorized paths.
// So allowing mixing different types gives very unexpected errors when enabling vectorization, when the user tries to
// add together a float matrix and a double matrix.
#define EIGEN_CHECK_BINARY_COMPATIBILIY(BINOP,LHS,RHS) \
EIGEN_STATIC_ASSERT((internal::functor_is_product_like<BINOP>::ret \
? int(internal::scalar_product_traits<LHS, RHS>::Defined) \
: int(internal::is_same<LHS, RHS>::value)), \
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_XPRHELPER_H #endif // EIGEN_XPRHELPER_H

View File

@ -53,10 +53,13 @@ template<int SizeAtCompileType> void mixingtypes(int size = SizeAtCompileType)
mf+mf; mf+mf;
VERIFY_RAISES_ASSERT(mf+md); VERIFY_RAISES_ASSERT(mf+md);
VERIFY_RAISES_ASSERT(mf+mcf); VERIFY_RAISES_ASSERT(mf+mcf);
#ifndef EIGEN_TEST_EVALUATORS
// they do not even compile when using evaluators
VERIFY_RAISES_ASSERT(vf=vd); VERIFY_RAISES_ASSERT(vf=vd);
VERIFY_RAISES_ASSERT(vf+=vd); VERIFY_RAISES_ASSERT(vf+=vd);
VERIFY_RAISES_ASSERT(mcd=md); VERIFY_RAISES_ASSERT(mcd=md);
#endif
// check scalar products // check scalar products
VERIFY_IS_APPROX(vcf * sf , vcf * complex<float>(sf)); VERIFY_IS_APPROX(vcf * sf , vcf * complex<float>(sf));
VERIFY_IS_APPROX(sd * vcd, complex<double>(sd) * vcd); VERIFY_IS_APPROX(sd * vcd, complex<double>(sd) * vcd);