Commit 52a5f982 broke conjhelper functionality for HIP GPUs.

This commit addresses this.
This commit is contained in:
Rohit Santhanam 2021-06-25 19:28:00 +00:00
parent bffd267d17
commit 2d132d1736

View File

@ -45,16 +45,16 @@ template<bool Conjugate> struct conj_if;
template<> struct conj_if<true> { template<> struct conj_if<true> {
template<typename T> template<typename T>
inline T operator()(const T& x) const { return numext::conj(x); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return numext::conj(x); }
template<typename T> template<typename T>
inline T pconj(const T& x) const { return internal::pconj(x); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T pconj(const T& x) const { return internal::pconj(x); }
}; };
template<> struct conj_if<false> { template<> struct conj_if<false> {
template<typename T> template<typename T>
inline const T& operator()(const T& x) const { return x; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator()(const T& x) const { return x; }
template<typename T> template<typename T>
inline const T& pconj(const T& x) const { return x; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& pconj(const T& x) const { return x; }
}; };
// Generic implementation. // Generic implementation.
@ -63,10 +63,10 @@ struct conj_helper
{ {
typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType; typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType;
EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
{ return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); } { return Eigen::internal::pmadd(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y), c); }
EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const
{ return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); } { return Eigen::internal::pmul(conj_if<ConjLhs>().pconj(x), conj_if<ConjRhs>().pconj(y)); }
}; };
@ -75,10 +75,10 @@ struct conj_helper<LhsType, RhsType, true, true>
{ {
typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType; typedef typename ScalarBinaryOpTraits<LhsType,RhsType>::ReturnType ResultType;
EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const
{ return Eigen::internal::pmadd(pconj(x), pconj(y), c); } { return Eigen::internal::pmadd(pconj(x), pconj(y), c); }
// We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b). // We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b).
EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType pmul(const LhsType& x, const RhsType& y) const
{ return pconj(Eigen::internal::pmul(x, y)); } { return pconj(Eigen::internal::pmul(x, y)); }
}; };
@ -86,18 +86,18 @@ struct conj_helper<LhsType, RhsType, true, true>
template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false> template<typename RealScalar,bool Conj> struct conj_helper<std::complex<RealScalar>, RealScalar, Conj,false>
{ {
typedef std::complex<RealScalar> Scalar; typedef std::complex<RealScalar> Scalar;
EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
{ return c + conj_if<Conj>().pconj(x) * y; } { return c + conj_if<Conj>().pconj(x) * y; }
EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
{ return conj_if<Conj>().pconj(x) * y; } { return conj_if<Conj>().pconj(x) * y; }
}; };
template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj> template<typename RealScalar,bool Conj> struct conj_helper<RealScalar, std::complex<RealScalar>, false,Conj>
{ {
typedef std::complex<RealScalar> Scalar; typedef std::complex<RealScalar> Scalar;
EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
{ return c + pmul(x,y); } { return c + pmul(x,y); }
EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
{ return x * conj_if<Conj>().pconj(y); } { return x * conj_if<Conj>().pconj(y); }
}; };