// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2024 Charlie Schlosser // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_INNER_PRODUCT_EVAL_H #define EIGEN_INNER_PRODUCT_EVAL_H // IWYU pragma: private #include "./InternalHeaderCheck.h" namespace Eigen { namespace internal { // recursively searches for the largest simd type that does not exceed Size, or the smallest if no such type exists template ::type, bool Stop = (unpacket_traits::size <= Size) || is_same::half>::value> struct find_inner_product_packet_helper; template struct find_inner_product_packet_helper { using type = typename find_inner_product_packet_helper::half>::type; }; template struct find_inner_product_packet_helper { using type = Packet; }; template struct find_inner_product_packet : find_inner_product_packet_helper {}; template struct find_inner_product_packet { using type = typename packet_traits::type; }; template struct inner_product_assert { EIGEN_STATIC_ASSERT_VECTOR_ONLY(Lhs) EIGEN_STATIC_ASSERT_VECTOR_ONLY(Rhs) EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Lhs, Rhs) #ifndef EIGEN_NO_DEBUG static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, const Rhs& rhs) { eigen_assert((lhs.size() == rhs.size()) && "Inner product: lhs and rhs vectors must have same size"); } #else static EIGEN_DEVICE_FUNC void run(const Lhs&, const Rhs&) {} #endif }; template struct inner_product_evaluator { static constexpr int LhsFlags = evaluator::Flags, RhsFlags = evaluator::Flags, SizeAtCompileTime = min_size_prefer_fixed(Lhs::SizeAtCompileTime, Rhs::SizeAtCompileTime), LhsAlignment = evaluator::Alignment, RhsAlignment = evaluator::Alignment; using Scalar = typename Func::result_type; using Packet = typename find_inner_product_packet::type; static constexpr bool Vectorize = bool(LhsFlags & RhsFlags & PacketAccessBit) && Func::PacketAccess && ((SizeAtCompileTime == Dynamic) || (unpacket_traits::size <= SizeAtCompileTime)); EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit inner_product_evaluator(const Lhs& lhs, const Rhs& rhs, Func func = Func()) : m_func(func), m_lhs(lhs), m_rhs(rhs), m_size(lhs.size()) { inner_product_assert::run(lhs, rhs); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_size.value(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const { return m_func.coeff(m_lhs.coeff(index), m_rhs.coeff(index)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& value, Index index) const { return m_func.coeff(value, m_lhs.coeff(index), m_rhs.coeff(index)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index index) const { return m_func.packet(m_lhs.template packet(index), m_rhs.template packet(index)); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(const PacketType& value, Index index) const { return m_func.packet(value, m_lhs.template packet(index), m_rhs.template packet(index)); } const Func m_func; const evaluator m_lhs; const evaluator m_rhs; const variable_if_dynamic m_size; }; template struct inner_product_impl; // scalar loop template struct inner_product_impl { using Scalar = typename Evaluator::Scalar; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) { const Index size = eval.size(); if (size == 0) return Scalar(0); Scalar result = eval.coeff(0); for (Index k = 1; k < size; k++) { result = eval.coeff(result, k); } return result; } }; // vector loop template struct inner_product_impl { using UnsignedIndex = std::make_unsigned_t; using Scalar = typename Evaluator::Scalar; using Packet = typename Evaluator::Packet; static constexpr int PacketSize = unpacket_traits::size; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) { const UnsignedIndex size = static_cast(eval.size()); if (size < PacketSize) return inner_product_impl::run(eval); const UnsignedIndex packetEnd = numext::round_down(size, PacketSize); const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize); const UnsignedIndex numPackets = size / PacketSize; const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize; Packet presult0, presult1, presult2, presult3; presult0 = eval.template packet(0 * PacketSize); if (numPackets >= 2) presult1 = eval.template packet(1 * PacketSize); if (numPackets >= 3) presult2 = eval.template packet(2 * PacketSize); if (numPackets >= 4) { presult3 = eval.template packet(3 * PacketSize); for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) { presult0 = eval.packet(presult0, k + 0 * PacketSize); presult1 = eval.packet(presult1, k + 1 * PacketSize); presult2 = eval.packet(presult2, k + 2 * PacketSize); presult3 = eval.packet(presult3, k + 3 * PacketSize); } if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize); if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize); if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize); presult2 = padd(presult2, presult3); } if (numPackets >= 3) presult1 = padd(presult1, presult2); if (numPackets >= 2) presult0 = padd(presult0, presult1); Scalar result = predux(presult0); for (UnsignedIndex k = packetEnd; k < size; k++) { result = eval.coeff(result, k); } return result; } }; template struct conditional_conj; template struct conditional_conj { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return numext::conj(a); } template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) { return pconj(a); } }; template struct conditional_conj { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return a; } template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) { return a; } }; template struct scalar_inner_product_op { using result_type = typename ScalarBinaryOpTraits::ReturnType; using conj_helper = conditional_conj; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const LhsScalar& a, const RhsScalar& b) const { return (conj_helper::coeff(a) * b); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const result_type& accum, const LhsScalar& a, const RhsScalar& b) const { return (conj_helper::coeff(a) * b) + accum; } static constexpr bool PacketAccess = false; }; template struct scalar_inner_product_op { using result_type = Scalar; using conj_helper = conditional_conj; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a, const Scalar& b) const { return pmul(conj_helper::coeff(a), b); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& accum, const Scalar& a, const Scalar& b) const { return pmadd(conj_helper::coeff(a), b, accum); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a, const Packet& b) const { return pmul(conj_helper::packet(a), b); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& accum, const Packet& a, const Packet& b) const { return pmadd(conj_helper::packet(a), b, accum); } static constexpr bool PacketAccess = packet_traits::HasMul && packet_traits::HasAdd; }; template struct default_inner_product_impl { using LhsScalar = typename traits::Scalar; using RhsScalar = typename traits::Scalar; using Op = scalar_inner_product_op; using Evaluator = inner_product_evaluator; using result_type = typename Evaluator::Scalar; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type run(const MatrixBase& a, const MatrixBase& b) { Evaluator eval(a.derived(), b.derived(), Op()); return inner_product_impl::run(eval); } }; template struct dot_impl : default_inner_product_impl {}; } // namespace internal } // namespace Eigen #endif // EIGEN_INNER_PRODUCT_EVAL_H