Fully support complex types in SumReducer and MeanReducer when building for CUDA by using scalar_sum_op and scalar_product_op instead of operator+ and operator*.

This commit is contained in:
RJ Ryan 2016-10-06 10:49:48 -07:00
parent 80b5133789
commit e2e9cdd169

View File

@ -124,7 +124,8 @@ template <typename T> struct SumReducer
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return saccum + predux(vaccum);
internal::scalar_sum_op<T> sum_op;
return sum_op(saccum, predux(vaccum));
}
};
@ -173,7 +174,8 @@ template <typename T> struct MeanReducer
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return (saccum + predux(vaccum)) / (scalarCount_ + packetCount_ * unpacket_traits<Packet>::size);
internal::scalar_sum_op<T> sum_op;
return sum_op(saccum, predux(vaccum)) / (scalarCount_ + packetCount_ * unpacket_traits<Packet>::size);
}
protected:
@ -304,7 +306,8 @@ template <typename T> struct ProdReducer
static const bool IsStateful = false;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
(*accum) *= t;
internal::scalar_product_op<T> prod_op;
(*accum) = prod_op(*accum, t);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, Packet* accum) const {
@ -328,7 +331,8 @@ template <typename T> struct ProdReducer
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalizeBoth(const T saccum, const Packet& vaccum) const {
return saccum * predux_mul(vaccum);
internal::scalar_product_op<T> prod_op;
return prod_op(saccum, predux_mul(vaccum));
}
};