From fb477b8be14558ba997c4cadf3667e6efef30646 Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Tue, 10 Sep 2024 21:02:31 +0000 Subject: [PATCH] Better dot products --- Eigen/Core | 1 + Eigen/src/Core/Dot.h | 37 +---- Eigen/src/Core/InnerProduct.h | 253 +++++++++++++++++++++++++++++ Eigen/src/Core/ProductEvaluators.h | 7 +- 4 files changed, 259 insertions(+), 39 deletions(-) create mode 100644 Eigen/src/Core/InnerProduct.h diff --git a/Eigen/Core b/Eigen/Core index 29dda3932..e6dbe3a8e 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -324,6 +324,7 @@ using std::ptrdiff_t; #include "src/Core/CwiseNullaryOp.h" #include "src/Core/CwiseUnaryView.h" #include "src/Core/SelfCwiseBinaryOp.h" +#include "src/Core/InnerProduct.h" #include "src/Core/Dot.h" #include "src/Core/StableNorm.h" #include "src/Core/Stride.h" diff --git a/Eigen/src/Core/Dot.h b/Eigen/src/Core/Dot.h index dd4a2c4dd..059527c85 100644 --- a/Eigen/src/Core/Dot.h +++ b/Eigen/src/Core/Dot.h @@ -17,30 +17,6 @@ namespace Eigen { namespace internal { -// helper function for dot(). The problem is that if we put that in the body of dot(), then upon calling dot -// with mismatched types, the compiler emits errors about failing to instantiate cwiseProduct BEFORE -// looking at the static assertions. Thus this is a trick to get better compile errors. -template -struct dot_nocheck { - typedef scalar_conj_product_op::Scalar, typename traits::Scalar> conj_prod; - typedef typename conj_prod::result_type ResScalar; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static ResScalar run(const MatrixBase& a, const MatrixBase& b) { - return a.template binaryExpr(b).sum(); - } -}; - -template -struct dot_nocheck { - typedef scalar_conj_product_op::Scalar, typename traits::Scalar> conj_prod; - typedef typename conj_prod::result_type ResScalar; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static ResScalar run(const MatrixBase& a, const MatrixBase& b) { - return a.transpose().template binaryExpr(b).sum(); - } -}; - template ::Scalar> struct squared_norm_impl { using Real = typename NumTraits::Real; @@ -74,18 +50,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ScalarBinaryOpTraits::Scalar, typename internal::traits::Scalar>::ReturnType MatrixBase::dot(const MatrixBase& other) const { - EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived) - EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived) - EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Derived, OtherDerived) -#if !(defined(EIGEN_NO_STATIC_ASSERT) && defined(EIGEN_NO_DEBUG)) - EIGEN_CHECK_BINARY_COMPATIBILIY( - Eigen::internal::scalar_conj_product_op, Scalar, - typename OtherDerived::Scalar); -#endif - - eigen_assert(size() == other.size()); - - return internal::dot_nocheck::run(*this, other); + return internal::dot_impl::run(derived(), other.derived()); } //---------- implementation of L2 norm and related functions ---------- diff --git a/Eigen/src/Core/InnerProduct.h b/Eigen/src/Core/InnerProduct.h new file mode 100644 index 000000000..c8b1c1d0d --- /dev/null +++ b/Eigen/src/Core/InnerProduct.h @@ -0,0 +1,253 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2024 Charlie Schlosser +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_INNER_PRODUCT_EVAL_H +#define EIGEN_INNER_PRODUCT_EVAL_H + +// IWYU pragma: private +#include "./InternalHeaderCheck.h" + +namespace Eigen { + +namespace internal { + +// recursively searches for the largest simd type that does not exceed Size, or the smallest if no such type exists +template ::type, + bool Stop = + (unpacket_traits::size <= Size) || is_same::half>::value> +struct find_inner_product_packet_helper; + +template +struct find_inner_product_packet_helper { + using type = typename find_inner_product_packet_helper::half>::type; +}; + +template +struct find_inner_product_packet_helper { + using type = Packet; +}; + +template +struct find_inner_product_packet : find_inner_product_packet_helper {}; + +template +struct find_inner_product_packet { + using type = typename packet_traits::type; +}; + +template +struct inner_product_assert { + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Lhs) + EIGEN_STATIC_ASSERT_VECTOR_ONLY(Rhs) + EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Lhs, Rhs) +#ifndef EIGEN_NO_DEBUG + static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, const Rhs& rhs) { + eigen_assert((lhs.size() == rhs.size()) && "Inner product: lhs and rhs vectors must have same size"); + } +#else + static EIGEN_DEVICE_FUNC void run(const Lhs&, const Rhs&) {} +#endif +}; + +template +struct inner_product_evaluator { + static constexpr int LhsFlags = evaluator::Flags, RhsFlags = evaluator::Flags, + SizeAtCompileTime = min_size_prefer_fixed(Lhs::SizeAtCompileTime, Rhs::SizeAtCompileTime), + LhsAlignment = evaluator::Alignment, RhsAlignment = evaluator::Alignment; + + using Scalar = typename Func::result_type; + using Packet = typename find_inner_product_packet::type; + + static constexpr bool Vectorize = + bool(LhsFlags & RhsFlags & PacketAccessBit) && Func::PacketAccess && + ((SizeAtCompileTime == Dynamic) || (unpacket_traits::size <= SizeAtCompileTime)); + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit inner_product_evaluator(const Lhs& lhs, const Rhs& rhs, + Func func = Func()) + : m_func(func), m_lhs(lhs), m_rhs(rhs), m_size(lhs.size()) { + inner_product_assert::run(lhs, rhs); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_size.value(); } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const { + return m_func.coeff(m_lhs.coeff(index), m_rhs.coeff(index)); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& value, Index index) const { + return m_func.coeff(value, m_lhs.coeff(index), m_rhs.coeff(index)); + } + + template + EIGEN_STRONG_INLINE PacketType packet(Index index) const { + return m_func.packet(m_lhs.template packet(index), + m_rhs.template packet(index)); + } + + template + EIGEN_STRONG_INLINE PacketType packet(const PacketType& value, Index index) const { + return m_func.packet(value, m_lhs.template packet(index), + m_rhs.template packet(index)); + } + + const Func m_func; + const evaluator m_lhs; + const evaluator m_rhs; + const variable_if_dynamic m_size; +}; + +template +struct inner_product_impl; + +// scalar loop +template +struct inner_product_impl { + using Scalar = typename Evaluator::Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) { + const Index size = eval.size(); + if (size == 0) return Scalar(0); + + Scalar result = eval.coeff(0); + for (Index k = 1; k < size; k++) { + result = eval.coeff(result, k); + } + + return result; + } +}; + +// vector loop +template +struct inner_product_impl { + using UnsignedIndex = std::make_unsigned_t; + using Scalar = typename Evaluator::Scalar; + using Packet = typename Evaluator::Packet; + static constexpr int PacketSize = unpacket_traits::size; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) { + const UnsignedIndex size = static_cast(eval.size()); + if (size < PacketSize) return inner_product_impl::run(eval); + const UnsignedIndex packetEnd = numext::round_down(size, PacketSize); + const UnsignedIndex numPackets = size / PacketSize; + + Packet presult0 = eval.template packet(0 * PacketSize); + Packet presult1 = pzero(Packet()); + Packet presult2 = pzero(Packet()); + Packet presult3 = pzero(Packet()); + + if (numPackets >= 2) presult1 = eval.template packet(1 * PacketSize); + if (numPackets >= 3) presult2 = eval.template packet(2 * PacketSize); + if (numPackets >= 4) { + presult3 = eval.template packet(3 * PacketSize); + + const UnsignedIndex numRemPackets = (numPackets - 4) % 4; + const UnsignedIndex quadStart = 4 * PacketSize; + const UnsignedIndex quadEnd = (numPackets - numRemPackets) * PacketSize; + + for (UnsignedIndex k = quadStart; k < quadEnd; k += 4 * PacketSize) { + presult0 = eval.packet(presult0, k + 0 * PacketSize); + presult1 = eval.packet(presult1, k + 1 * PacketSize); + presult2 = eval.packet(presult2, k + 2 * PacketSize); + presult3 = eval.packet(presult3, k + 3 * PacketSize); + } + + if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize); + if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize); + if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize); + } + + Scalar result = predux(padd(padd(presult0, presult1), padd(presult2, presult3))); + + if (size > packetEnd) { + Scalar scalarAccum = eval.coeff(packetEnd); + for (UnsignedIndex k = packetEnd + 1; k < size; k++) { + scalarAccum = eval.coeff(scalarAccum, k); + } + result += scalarAccum; + } + + return result; + } +}; + +template +struct conditional_conj; + +template +struct conditional_conj { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return numext::conj(a); } + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) { + return pconj(a); + } +}; + +template +struct conditional_conj { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return a; } + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) { + return a; + } +}; + +template +struct scalar_inner_product_op { + using result_type = typename ScalarBinaryOpTraits::ReturnType; + using conj_helper = conditional_conj; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const LhsScalar& a, const RhsScalar& b) const { + return (conj_helper::coeff(a) * b); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const result_type& accum, const LhsScalar& a, + const RhsScalar& b) const { + return (conj_helper::coeff(a) * b) + accum; + } + static constexpr bool PacketAccess = false; +}; + +template +struct scalar_inner_product_op { + using result_type = Scalar; + using conj_helper = conditional_conj; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a, const Scalar& b) const { + return pmul(conj_helper::coeff(a), b); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& accum, const Scalar& a, const Scalar& b) const { + return pmadd(conj_helper::coeff(a), b, accum); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a, const Packet& b) const { + return pmul(conj_helper::packet(a), b); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& accum, const Packet& a, const Packet& b) const { + return pmadd(conj_helper::packet(a), b, accum); + } + static constexpr bool PacketAccess = packet_traits::HasMul && packet_traits::HasAdd; +}; + +template +struct default_inner_product_impl { + using LhsScalar = typename traits::Scalar; + using RhsScalar = typename traits::Scalar; + using Op = scalar_inner_product_op; + using Evaluator = inner_product_evaluator; + using result_type = typename Evaluator::Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type run(const MatrixBase& a, const MatrixBase& b) { + Evaluator eval(a.derived(), b.derived(), Op()); + return inner_product_impl::run(eval); + } +}; + +template +struct dot_impl : default_inner_product_impl {}; + +} // namespace internal +} // namespace Eigen + +#endif // EIGEN_INNER_PRODUCT_EVAL_H diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index fa4d0384b..77a658a8e 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -235,19 +235,20 @@ EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op, scalar_difference_op, add_assig template struct generic_product_impl { + using impl = default_inner_product_impl; template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) { - dst.coeffRef(0, 0) = (lhs.transpose().cwiseProduct(rhs)).sum(); + dst.coeffRef(0, 0) = impl::run(lhs, rhs); } template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) { - dst.coeffRef(0, 0) += (lhs.transpose().cwiseProduct(rhs)).sum(); + dst.coeffRef(0, 0) += impl::run(lhs, rhs); } template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) { - dst.coeffRef(0, 0) -= (lhs.transpose().cwiseProduct(rhs)).sum(); + dst.coeffRef(0, 0) -= impl::run(lhs, rhs); } };