Mirror of https://gitlab.com/libeigen/eigen.git (synced 2025-09-10 00:03:17 +08:00)
Replace calls to numext::fma with numext::madd.
commit 2e8cc042a1
parent 52f570a409
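In short, the change below adds a plain multiply-add entry point, numext::madd, and routes the scalar pmadd/pmsub/pnmadd/pnmsub fallbacks (plus several sparse and test call sites) through it instead of numext::fma. A rough sketch of the difference at a call site, assuming an Eigen build that includes this change (values are illustrative):

#include <Eigen/Core>

int main() {
  float a = 1.5f, b = 2.0f, c = 0.25f;
  // numext::fma is always a fused multiply-add: a*b is not rounded before the add.
  float fused = Eigen::numext::fma(a, b, c);
  // numext::madd (added by this commit) is a plain a*b + c by default and only
  // forwards to the fused implementation when EIGEN_VECTORIZE_FMA is defined
  // and an fma() overload exists for the scalar type.
  float plain = Eigen::numext::madd(a, b, c);
  return (fused == plain) ? 0 : 1;  // equal for these particular values either way
}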
@@ -1350,20 +1350,20 @@ struct pmadd_impl {
 template <typename Scalar>
 struct pmadd_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value && NumTraits<Scalar>::IsSigned>> {
   static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return numext::fma(a, b, c);
+    return numext::madd<Scalar>(a, b, c);
   }
   static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return numext::fma(a, b, Scalar(-c));
+    return numext::madd<Scalar>(a, b, Scalar(-c));
   }
   static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return numext::fma(Scalar(-a), b, c);
+    return numext::madd<Scalar>(Scalar(-a), b, c);
   }
   static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return -Scalar(numext::fma(a, b, c));
+    return -Scalar(numext::madd<Scalar>(a, b, c));
   }
 };

-// FMA instructions.
+// Multiply-add instructions.
 /** \internal \returns a * b + c (coeff-wise) */
 template <typename Packet>
 EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {
@@ -941,23 +941,44 @@ struct nearest_integer_impl<Scalar, true> {
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; }
 };

+// Extra namespace to prevent leaking std::fma into Eigen::internal.
+namespace has_fma_detail {
+
+template <typename T, typename EnableIf = void>
+struct has_fma_impl : public std::false_type {};
+
+using std::fma;
+
+template <typename T>
+struct has_fma_impl<
+    T, std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>>
+    : public std::true_type {};
+
+}  // namespace has_fma_detail
+
+template <typename T>
+struct has_fma : public has_fma_detail::has_fma_impl<T> {};
+
 // Default implementation.
-template <typename Scalar, typename Enable = void>
+template <typename T, typename Enable = void>
 struct fma_impl {
-  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return a * b + c;
-  }
-};
+  static_assert(has_fma<T>::value, "No function fma(...) for type. Please provide an implementation.");
+};
+
+// STD or ADL version if it exists.
+template <typename T>
+struct fma_impl<T, std::enable_if_t<has_fma<T>::value>> {
+  static T run(const T& a, const T& b, const T& c) {
+    using std::fma;
+    return fma(a, b, c);
+  }
+};

-// ADL version if it exists.
-template <typename T>
-struct fma_impl<
-    T,
-    std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>> {
-  static T run(const T& a, const T& b, const T& c) { return fma(a, b, c); }
-};
-
 #if defined(EIGEN_GPUCC)
+template <>
+struct has_fma<float> : public true_type {};
+
 template <>
 struct fma_impl<float, void> {
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) {
@@ -965,6 +986,10 @@ struct fma_impl<float, void> {
   }
 };

+template <>
+struct has_fma<double> : public true_type {};
+
 template <>
 struct fma_impl<double, void> {
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) {
@@ -973,6 +998,24 @@ struct fma_impl<double, void> {
 };
 #endif

+// Basic multiply-add.
+template <typename Scalar, typename EnableIf = void>
+struct madd_impl {
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
+    return x * y + z;
+  }
+};
+
+// Use FMA if there is a single CPU instruction.
+#ifdef EIGEN_VECTORIZE_FMA
+template <typename Scalar>
+struct madd_impl<Scalar, std::enable_if_t<has_fma<Scalar>::value>> {
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
+    return fma_impl<Scalar>::run(x, y, z);
+  }
+};
+#endif
+
 } // end namespace internal

 /****************************************************************************
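The has_fma trait introduced above relies on a standard detection idiom: bring std::fma into scope with a using-declaration, then SFINAE on whether an unqualified fma(T, T, T) call is well-formed and returns T, which also picks up user-supplied overloads through ADL. A minimal standalone sketch of the same idiom (the names below are illustrative, not Eigen's):

#include <cmath>
#include <type_traits>
#include <utility>

namespace detect {
// Primary template: no suitable fma() found.
template <typename T, typename EnableIf = void>
struct has_fma : std::false_type {};

using std::fma;  // make the standard overloads visible to the unqualified call below

// Specialization selected when fma(T, T, T) is well-formed and returns T.
template <typename T>
struct has_fma<T, std::enable_if_t<std::is_same<
                      T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>>
    : std::true_type {};
}  // namespace detect

namespace user {
struct Fixed { int v; };                // toy user-defined scalar
Fixed fma(Fixed a, Fixed b, Fixed c) {  // found via argument-dependent lookup
  return Fixed{a.v * b.v + c.v};
}
}  // namespace user

static_assert(detect::has_fma<double>::value, "std::fma handles double");
static_assert(detect::has_fma<user::Fixed>::value, "ADL finds user::fma");

int main() { return 0; }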
@@ -1886,15 +1929,18 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar arithmetic_shift_right(const Scalar
   return bit_cast<Scalar, SignedScalar>(bit_cast<SignedScalar, Scalar>(a) >> n);
 }

-// Use std::fma if available.
-using std::fma;
-
 // Otherwise, rely on template implementation.
 template <typename Scalar>
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) {
   return internal::fma_impl<Scalar>::run(x, y, z);
 }

+// Multiply-add.
+template <typename Scalar>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar madd(const Scalar& x, const Scalar& y, const Scalar& z) {
+  return internal::madd_impl<Scalar>::run(x, y, z);
+}
+
 } // end namespace numext

 namespace internal {
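The practical difference the new entry point exposes is the intermediate rounding: numext::fma keeps the product exact before the addition, while a plain a*b + c rounds the product first. A small self-contained illustration of values where the two disagree (not Eigen code; note that aggressive -ffp-contract settings may fuse the plain expression too):

#include <cmath>
#include <cstdio>

int main() {
  double a = 1.0 + 0x1p-27;  // 1 + 2^-27
  double b = 1.0 - 0x1p-27;  // 1 - 2^-27
  double c = -1.0;
  double fused   = std::fma(a, b, c);  // product kept exact: result is -2^-54
  double unfused = a * b + c;          // a*b rounds to 1.0 first: result is 0.0
  std::printf("fused = %a, unfused = %a\n", fused, unfused);
  return 0;
}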
|
@@ -2026,38 +2026,38 @@ EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d&
 }

 // Scalar path for pmadd with FMA to ensure consistency with vectorized path.
-#ifdef EIGEN_VECTORIZE_FMA
+#if defined(EIGEN_VECTORIZE_FMA)
 template <>
 EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) {
-  return ::fmaf(a, b, c);
+  return std::fmaf(a, b, c);
 }
 template <>
 EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
-  return ::fma(a, b, c);
+  return std::fma(a, b, c);
 }
 template <>
 EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) {
-  return ::fmaf(a, b, -c);
+  return std::fmaf(a, b, -c);
 }
 template <>
 EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) {
-  return ::fma(a, b, -c);
+  return std::fma(a, b, -c);
 }
 template <>
 EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) {
-  return ::fmaf(-a, b, c);
+  return std::fmaf(-a, b, c);
 }
 template <>
 EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) {
-  return ::fma(-a, b, c);
+  return std::fma(-a, b, c);
 }
 template <>
 EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) {
-  return ::fmaf(-a, b, -c);
+  return std::fmaf(-a, b, -c);
 }
 template <>
 EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) {
-  return ::fma(-a, b, -c);
+  return std::fma(-a, b, -c);
 }
 #endif

@@ -36,10 +36,10 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
   Scalar res1(0);
   Scalar res2(0);
   for (; i; ++i) {
-    res1 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res1);
+    res1 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res1);
     ++i;
     if (i) {
-      res2 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res2);
+      res2 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res2);
     }
   }
   return res1 + res2;
@@ -67,7 +67,7 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
   Scalar res(0);
   while (i && j) {
     if (i.index() == j.index()) {
-      res = numext::fma(numext::conj(i.value()), j.value(), res);
+      res = numext::madd<Scalar>(numext::conj(i.value()), j.value(), res);
       ++i;
       ++j;
     } else if (i.index() < j.index())

@@ -41,7 +41,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, RowMajor> {
           lastVal = it.value();
           lastIndex = it.index();
           if (lastIndex == i) break;
-          tmp = numext::fma(-lastVal, other.coeff(lastIndex, col), tmp);
+          tmp = numext::madd<Scalar>(-lastVal, other.coeff(lastIndex, col), tmp);
         }
         if (Mode & UnitDiag)
           other.coeffRef(i, col) = tmp;
@@ -75,7 +75,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, RowMajor> {
         } else if (it && it.index() == i)
           ++it;
         for (; it; ++it) {
-          tmp = numext::fma<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
+          tmp = numext::madd<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
         }

         if (Mode & UnitDiag)
@@ -108,7 +108,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, ColMajor> {
         }
         if (it && it.index() == i) ++it;
         for (; it; ++it) {
-          other.coeffRef(it.index(), col) = numext::fma<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
+          other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
         }
       }
     }
@@ -138,7 +138,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, ColMajor> {
         }
         LhsIterator it(lhsEval, i);
         for (; it && it.index() < i; ++it) {
-          other.coeffRef(it.index(), col) = numext::fma<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
+          other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
         }
       }
     }
@@ -220,11 +220,11 @@ struct sparse_solve_triangular_sparse_selector<Lhs, Rhs, Mode, UpLo, ColMajor> {
           if (IsLower) {
             if (it.index() == i) ++it;
             for (; it; ++it) {
-              tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index()));
+              tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
             }
           } else {
             for (; it && it.index() < i; ++it) {
-              tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index()));
+              tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
             }
           }
         }

@@ -44,16 +44,16 @@ template <typename Scalar>
 struct madd_impl<Scalar,
                  std::enable_if_t<Eigen::internal::is_scalar<Scalar>::value && Eigen::NumTraits<Scalar>::IsSigned>> {
   static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return numext::fma(a, b, c);
+    return numext::madd(a, b, c);
   }
   static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return numext::fma(a, b, Scalar(-c));
+    return numext::madd(a, b, Scalar(-c));
   }
   static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return numext::fma(Scalar(-a), b, c);
+    return numext::madd(Scalar(-a), b, c);
   }
   static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
-    return -Scalar(numext::fma(a, b, c));
+    return -Scalar(numext::madd(a, b, c));
   }
 };

@@ -42,7 +42,7 @@ template <typename Scalar, typename V1, typename V2>
 Scalar ref_dot_product(const V1& v1, const V2& v2) {
   Scalar out = Scalar(0);
   for (Index i = 0; i < v1.size(); ++i) {
-    out = Eigen::numext::fma(v1[i], v2[i], out);
+    out = Eigen::numext::madd(v1[i], v2[i], out);
   }
   return out;
 }
@@ -254,8 +254,6 @@ void product(const MatrixType& m) {
   // inner product
   {
     Scalar x = square2.row(c) * square2.col(c2);
-    // NOTE: FMA is necessary here in the reference to ensure accuracy for
-    // large vector sizes and float16/bfloat16 types.
     Scalar y = ref_dot_product<Scalar>(square2.row(c), square2.col(c2));
     VERIFY_IS_APPROX(x, y);
   }