mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-09 15:53:16 +08:00
Replace calls to numext::fma with numext:madd.
This commit is contained in:
parent
52f570a409
commit
2e8cc042a1
@ -1350,20 +1350,20 @@ struct pmadd_impl {
|
||||
template <typename Scalar>
|
||||
struct pmadd_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value && NumTraits<Scalar>::IsSigned>> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return numext::fma(a, b, c);
|
||||
return numext::madd<Scalar>(a, b, c);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return numext::fma(a, b, Scalar(-c));
|
||||
return numext::madd<Scalar>(a, b, Scalar(-c));
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return numext::fma(Scalar(-a), b, c);
|
||||
return numext::madd<Scalar>(Scalar(-a), b, c);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return -Scalar(numext::fma(a, b, c));
|
||||
return -Scalar(numext::madd<Scalar>(a, b, c));
|
||||
}
|
||||
};
|
||||
|
||||
// FMA instructions.
|
||||
// Multiply-add instructions.
|
||||
/** \internal \returns a * b + c (coeff-wise) */
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {
|
||||
|
@ -941,23 +941,44 @@ struct nearest_integer_impl<Scalar, true> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; }
|
||||
};
|
||||
|
||||
// Extra namespace to prevent leaking std::fma into Eigen::internal.
|
||||
namespace has_fma_detail {
|
||||
|
||||
template <typename T, typename EnableIf = void>
|
||||
struct has_fma_impl : public std::false_type {};
|
||||
|
||||
using std::fma;
|
||||
|
||||
template <typename T>
|
||||
struct has_fma_impl<
|
||||
T, std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>>
|
||||
: public std::true_type {};
|
||||
|
||||
} // namespace has_fma_detail
|
||||
|
||||
template <typename T>
|
||||
struct has_fma : public has_fma_detail::has_fma_impl<T> {};
|
||||
|
||||
// Default implementation.
|
||||
template <typename Scalar, typename Enable = void>
|
||||
template <typename T, typename Enable = void>
|
||||
struct fma_impl {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return a * b + c;
|
||||
static_assert(has_fma<T>::value, "No function fma(...) for type. Please provide an implementation.");
|
||||
};
|
||||
|
||||
// STD or ADL version if it exists.
|
||||
template <typename T>
|
||||
struct fma_impl<T, std::enable_if_t<has_fma<T>::value>> {
|
||||
static T run(const T& a, const T& b, const T& c) {
|
||||
using std::fma;
|
||||
return fma(a, b, c);
|
||||
}
|
||||
};
|
||||
|
||||
// ADL version if it exists.
|
||||
template <typename T>
|
||||
struct fma_impl<
|
||||
T,
|
||||
std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>> {
|
||||
static T run(const T& a, const T& b, const T& c) { return fma(a, b, c); }
|
||||
};
|
||||
|
||||
#if defined(EIGEN_GPUCC)
|
||||
template <>
|
||||
struct has_fma<float> : public true_type {
|
||||
}
|
||||
|
||||
template <>
|
||||
struct fma_impl<float, void> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) {
|
||||
@ -965,6 +986,10 @@ struct fma_impl<float, void> {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct has_fma<double> : public true_type {
|
||||
}
|
||||
|
||||
template <>
|
||||
struct fma_impl<double, void> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) {
|
||||
@ -973,6 +998,24 @@ struct fma_impl<double, void> {
|
||||
};
|
||||
#endif
|
||||
|
||||
// Basic multiply-add.
|
||||
template <typename Scalar, typename EnableIf = void>
|
||||
struct madd_impl {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
|
||||
return x * y + z;
|
||||
}
|
||||
};
|
||||
|
||||
// Use FMA if there is a single CPU instruction.
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
template <typename Scalar>
|
||||
struct madd_impl<Scalar, std::enable_if_t<has_fma<Scalar>::value>> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
|
||||
return fma_impl<Scalar>::run(x, y, z);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
/****************************************************************************
|
||||
@ -1886,15 +1929,18 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar arithmetic_shift_right(const Scalar
|
||||
return bit_cast<Scalar, SignedScalar>(bit_cast<SignedScalar, Scalar>(a) >> n);
|
||||
}
|
||||
|
||||
// Use std::fma if available.
|
||||
using std::fma;
|
||||
|
||||
// Otherwise, rely on template implementation.
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) {
|
||||
return internal::fma_impl<Scalar>::run(x, y, z);
|
||||
}
|
||||
|
||||
// Multiply-add.
|
||||
template <typename Scalar>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar madd(const Scalar& x, const Scalar& y, const Scalar& z) {
|
||||
return internal::madd_impl<Scalar>::run(x, y, z);
|
||||
}
|
||||
|
||||
} // end namespace numext
|
||||
|
||||
namespace internal {
|
||||
|
@ -2026,38 +2026,38 @@ EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d&
|
||||
}
|
||||
|
||||
// Scalar path for pmadd with FMA to ensure consistency with vectorized path.
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
#if defined(EIGEN_VECTORIZE_FMA)
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) {
|
||||
return ::fmaf(a, b, c);
|
||||
return std::fmaf(a, b, c);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
|
||||
return ::fma(a, b, c);
|
||||
return std::fma(a, b, c);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) {
|
||||
return ::fmaf(a, b, -c);
|
||||
return std::fmaf(a, b, -c);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) {
|
||||
return ::fma(a, b, -c);
|
||||
return std::fma(a, b, -c);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) {
|
||||
return ::fmaf(-a, b, c);
|
||||
return std::fmaf(-a, b, c);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) {
|
||||
return ::fma(-a, b, c);
|
||||
return std::fma(-a, b, c);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) {
|
||||
return ::fmaf(-a, b, -c);
|
||||
return std::fmaf(-a, b, -c);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) {
|
||||
return ::fma(-a, b, -c);
|
||||
return std::fma(-a, b, -c);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -36,10 +36,10 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
|
||||
Scalar res1(0);
|
||||
Scalar res2(0);
|
||||
for (; i; ++i) {
|
||||
res1 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res1);
|
||||
res1 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res1);
|
||||
++i;
|
||||
if (i) {
|
||||
res2 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res2);
|
||||
res2 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res2);
|
||||
}
|
||||
}
|
||||
return res1 + res2;
|
||||
@ -67,7 +67,7 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
|
||||
Scalar res(0);
|
||||
while (i && j) {
|
||||
if (i.index() == j.index()) {
|
||||
res = numext::fma(numext::conj(i.value()), j.value(), res);
|
||||
res = numext::madd<Scalar>(numext::conj(i.value()), j.value(), res);
|
||||
++i;
|
||||
++j;
|
||||
} else if (i.index() < j.index())
|
||||
|
@ -41,7 +41,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, RowMajor> {
|
||||
lastVal = it.value();
|
||||
lastIndex = it.index();
|
||||
if (lastIndex == i) break;
|
||||
tmp = numext::fma(-lastVal, other.coeff(lastIndex, col), tmp);
|
||||
tmp = numext::madd<Scalar>(-lastVal, other.coeff(lastIndex, col), tmp);
|
||||
}
|
||||
if (Mode & UnitDiag)
|
||||
other.coeffRef(i, col) = tmp;
|
||||
@ -75,7 +75,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, RowMajor> {
|
||||
} else if (it && it.index() == i)
|
||||
++it;
|
||||
for (; it; ++it) {
|
||||
tmp = numext::fma<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
|
||||
tmp = numext::madd<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
|
||||
}
|
||||
|
||||
if (Mode & UnitDiag)
|
||||
@ -108,7 +108,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, ColMajor> {
|
||||
}
|
||||
if (it && it.index() == i) ++it;
|
||||
for (; it; ++it) {
|
||||
other.coeffRef(it.index(), col) = numext::fma<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
|
||||
other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -138,7 +138,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, ColMajor> {
|
||||
}
|
||||
LhsIterator it(lhsEval, i);
|
||||
for (; it && it.index() < i; ++it) {
|
||||
other.coeffRef(it.index(), col) = numext::fma<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
|
||||
other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -220,11 +220,11 @@ struct sparse_solve_triangular_sparse_selector<Lhs, Rhs, Mode, UpLo, ColMajor> {
|
||||
if (IsLower) {
|
||||
if (it.index() == i) ++it;
|
||||
for (; it; ++it) {
|
||||
tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index()));
|
||||
tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
|
||||
}
|
||||
} else {
|
||||
for (; it && it.index() < i; ++it) {
|
||||
tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index()));
|
||||
tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -44,16 +44,16 @@ template <typename Scalar>
|
||||
struct madd_impl<Scalar,
|
||||
std::enable_if_t<Eigen::internal::is_scalar<Scalar>::value && Eigen::NumTraits<Scalar>::IsSigned>> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return numext::fma(a, b, c);
|
||||
return numext::madd(a, b, c);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return numext::fma(a, b, Scalar(-c));
|
||||
return numext::madd(a, b, Scalar(-c));
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return numext::fma(Scalar(-a), b, c);
|
||||
return numext::madd(Scalar(-a), b, c);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
|
||||
return -Scalar(numext::fma(a, b, c));
|
||||
return -Scalar(numext::madd(a, b, c));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -42,7 +42,7 @@ template <typename Scalar, typename V1, typename V2>
|
||||
Scalar ref_dot_product(const V1& v1, const V2& v2) {
|
||||
Scalar out = Scalar(0);
|
||||
for (Index i = 0; i < v1.size(); ++i) {
|
||||
out = Eigen::numext::fma(v1[i], v2[i], out);
|
||||
out = Eigen::numext::madd(v1[i], v2[i], out);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -254,8 +254,6 @@ void product(const MatrixType& m) {
|
||||
// inner product
|
||||
{
|
||||
Scalar x = square2.row(c) * square2.col(c2);
|
||||
// NOTE: FMA is necessary here in the reference to ensure accuracy for
|
||||
// large vector sizes and float16/bfloat16 types.
|
||||
Scalar y = ref_dot_product<Scalar>(square2.row(c), square2.col(c2));
|
||||
VERIFY_IS_APPROX(x, y);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user