mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-10 18:59:01 +08:00
Better dot products
This commit is contained in:
parent
134b526d61
commit
fb477b8be1
@ -324,6 +324,7 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/CwiseNullaryOp.h"
|
||||
#include "src/Core/CwiseUnaryView.h"
|
||||
#include "src/Core/SelfCwiseBinaryOp.h"
|
||||
#include "src/Core/InnerProduct.h"
|
||||
#include "src/Core/Dot.h"
|
||||
#include "src/Core/StableNorm.h"
|
||||
#include "src/Core/Stride.h"
|
||||
|
@ -17,30 +17,6 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// helper function for dot(). The problem is that if we put that in the body of dot(), then upon calling dot
|
||||
// with mismatched types, the compiler emits errors about failing to instantiate cwiseProduct BEFORE
|
||||
// looking at the static assertions. Thus this is a trick to get better compile errors.
|
||||
template <typename T, typename U,
|
||||
bool NeedToTranspose = T::IsVectorAtCompileTime && U::IsVectorAtCompileTime &&
|
||||
((int(T::RowsAtCompileTime) == 1 && int(U::ColsAtCompileTime) == 1) ||
|
||||
(int(T::ColsAtCompileTime) == 1 && int(U::RowsAtCompileTime) == 1))>
|
||||
struct dot_nocheck {
|
||||
typedef scalar_conj_product_op<typename traits<T>::Scalar, typename traits<U>::Scalar> conj_prod;
|
||||
typedef typename conj_prod::result_type ResScalar;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static ResScalar run(const MatrixBase<T>& a, const MatrixBase<U>& b) {
|
||||
return a.template binaryExpr<conj_prod>(b).sum();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
struct dot_nocheck<T, U, true> {
|
||||
typedef scalar_conj_product_op<typename traits<T>::Scalar, typename traits<U>::Scalar> conj_prod;
|
||||
typedef typename conj_prod::result_type ResScalar;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static ResScalar run(const MatrixBase<T>& a, const MatrixBase<U>& b) {
|
||||
return a.transpose().template binaryExpr<conj_prod>(b).sum();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Derived, typename Scalar = typename traits<Derived>::Scalar>
|
||||
struct squared_norm_impl {
|
||||
using Real = typename NumTraits<Scalar>::Real;
|
||||
@ -74,18 +50,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
typename ScalarBinaryOpTraits<typename internal::traits<Derived>::Scalar,
|
||||
typename internal::traits<OtherDerived>::Scalar>::ReturnType
|
||||
MatrixBase<Derived>::dot(const MatrixBase<OtherDerived>& other) const {
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
|
||||
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Derived, OtherDerived)
|
||||
#if !(defined(EIGEN_NO_STATIC_ASSERT) && defined(EIGEN_NO_DEBUG))
|
||||
EIGEN_CHECK_BINARY_COMPATIBILIY(
|
||||
Eigen::internal::scalar_conj_product_op<Scalar EIGEN_COMMA typename OtherDerived::Scalar>, Scalar,
|
||||
typename OtherDerived::Scalar);
|
||||
#endif
|
||||
|
||||
eigen_assert(size() == other.size());
|
||||
|
||||
return internal::dot_nocheck<Derived, OtherDerived>::run(*this, other);
|
||||
return internal::dot_impl<Derived, OtherDerived>::run(derived(), other.derived());
|
||||
}
|
||||
|
||||
//---------- implementation of L2 norm and related functions ----------
|
||||
|
253
Eigen/src/Core/InnerProduct.h
Normal file
253
Eigen/src/Core/InnerProduct.h
Normal file
@ -0,0 +1,253 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2024 Charlie Schlosser <cs.schlosser@gmail.com>
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_INNER_PRODUCT_EVAL_H
|
||||
#define EIGEN_INNER_PRODUCT_EVAL_H
|
||||
|
||||
// IWYU pragma: private
|
||||
#include "./InternalHeaderCheck.h"
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// recursively searches for the largest simd type that does not exceed Size, or the smallest if no such type exists
|
||||
template <typename Scalar, int Size, typename Packet = typename packet_traits<Scalar>::type,
|
||||
bool Stop =
|
||||
(unpacket_traits<Packet>::size <= Size) || is_same<Packet, typename unpacket_traits<Packet>::half>::value>
|
||||
struct find_inner_product_packet_helper;
|
||||
|
||||
template <typename Scalar, int Size, typename Packet>
|
||||
struct find_inner_product_packet_helper<Scalar, Size, Packet, false> {
|
||||
using type = typename find_inner_product_packet_helper<Scalar, Size, typename unpacket_traits<Packet>::half>::type;
|
||||
};
|
||||
|
||||
template <typename Scalar, int Size, typename Packet>
|
||||
struct find_inner_product_packet_helper<Scalar, Size, Packet, true> {
|
||||
using type = Packet;
|
||||
};
|
||||
|
||||
template <typename Scalar, int Size>
|
||||
struct find_inner_product_packet : find_inner_product_packet_helper<Scalar, Size> {};
|
||||
|
||||
template <typename Scalar>
|
||||
struct find_inner_product_packet<Scalar, Dynamic> {
|
||||
using type = typename packet_traits<Scalar>::type;
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs>
|
||||
struct inner_product_assert {
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Lhs)
|
||||
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Rhs)
|
||||
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Lhs, Rhs)
|
||||
#ifndef EIGEN_NO_DEBUG
|
||||
static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, const Rhs& rhs) {
|
||||
eigen_assert((lhs.size() == rhs.size()) && "Inner product: lhs and rhs vectors must have same size");
|
||||
}
|
||||
#else
|
||||
static EIGEN_DEVICE_FUNC void run(const Lhs&, const Rhs&) {}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename Func, typename Lhs, typename Rhs>
|
||||
struct inner_product_evaluator {
|
||||
static constexpr int LhsFlags = evaluator<Lhs>::Flags, RhsFlags = evaluator<Rhs>::Flags,
|
||||
SizeAtCompileTime = min_size_prefer_fixed(Lhs::SizeAtCompileTime, Rhs::SizeAtCompileTime),
|
||||
LhsAlignment = evaluator<Lhs>::Alignment, RhsAlignment = evaluator<Rhs>::Alignment;
|
||||
|
||||
using Scalar = typename Func::result_type;
|
||||
using Packet = typename find_inner_product_packet<Scalar, SizeAtCompileTime>::type;
|
||||
|
||||
static constexpr bool Vectorize =
|
||||
bool(LhsFlags & RhsFlags & PacketAccessBit) && Func::PacketAccess &&
|
||||
((SizeAtCompileTime == Dynamic) || (unpacket_traits<Packet>::size <= SizeAtCompileTime));
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit inner_product_evaluator(const Lhs& lhs, const Rhs& rhs,
|
||||
Func func = Func())
|
||||
: m_func(func), m_lhs(lhs), m_rhs(rhs), m_size(lhs.size()) {
|
||||
inner_product_assert<Lhs, Rhs>::run(lhs, rhs);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_size.value(); }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const {
|
||||
return m_func.coeff(m_lhs.coeff(index), m_rhs.coeff(index));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& value, Index index) const {
|
||||
return m_func.coeff(value, m_lhs.coeff(index), m_rhs.coeff(index));
|
||||
}
|
||||
|
||||
template <typename PacketType, int LhsMode = LhsAlignment, int RhsMode = RhsAlignment>
|
||||
EIGEN_STRONG_INLINE PacketType packet(Index index) const {
|
||||
return m_func.packet(m_lhs.template packet<LhsMode, PacketType>(index),
|
||||
m_rhs.template packet<RhsMode, PacketType>(index));
|
||||
}
|
||||
|
||||
template <typename PacketType, int LhsMode = LhsAlignment, int RhsMode = RhsAlignment>
|
||||
EIGEN_STRONG_INLINE PacketType packet(const PacketType& value, Index index) const {
|
||||
return m_func.packet(value, m_lhs.template packet<LhsMode, PacketType>(index),
|
||||
m_rhs.template packet<RhsMode, PacketType>(index));
|
||||
}
|
||||
|
||||
const Func m_func;
|
||||
const evaluator<Lhs> m_lhs;
|
||||
const evaluator<Rhs> m_rhs;
|
||||
const variable_if_dynamic<Index, SizeAtCompileTime> m_size;
|
||||
};
|
||||
|
||||
template <typename Evaluator, bool Vectorize = Evaluator::Vectorize>
|
||||
struct inner_product_impl;
|
||||
|
||||
// scalar loop
|
||||
template <typename Evaluator>
|
||||
struct inner_product_impl<Evaluator, false> {
|
||||
using Scalar = typename Evaluator::Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
|
||||
const Index size = eval.size();
|
||||
if (size == 0) return Scalar(0);
|
||||
|
||||
Scalar result = eval.coeff(0);
|
||||
for (Index k = 1; k < size; k++) {
|
||||
result = eval.coeff(result, k);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// vector loop
|
||||
template <typename Evaluator>
|
||||
struct inner_product_impl<Evaluator, true> {
|
||||
using UnsignedIndex = std::make_unsigned_t<Index>;
|
||||
using Scalar = typename Evaluator::Scalar;
|
||||
using Packet = typename Evaluator::Packet;
|
||||
static constexpr int PacketSize = unpacket_traits<Packet>::size;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
|
||||
const UnsignedIndex size = static_cast<UnsignedIndex>(eval.size());
|
||||
if (size < PacketSize) return inner_product_impl<Evaluator, false>::run(eval);
|
||||
const UnsignedIndex packetEnd = numext::round_down(size, PacketSize);
|
||||
const UnsignedIndex numPackets = size / PacketSize;
|
||||
|
||||
Packet presult0 = eval.template packet<Packet>(0 * PacketSize);
|
||||
Packet presult1 = pzero(Packet());
|
||||
Packet presult2 = pzero(Packet());
|
||||
Packet presult3 = pzero(Packet());
|
||||
|
||||
if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize);
|
||||
if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize);
|
||||
if (numPackets >= 4) {
|
||||
presult3 = eval.template packet<Packet>(3 * PacketSize);
|
||||
|
||||
const UnsignedIndex numRemPackets = (numPackets - 4) % 4;
|
||||
const UnsignedIndex quadStart = 4 * PacketSize;
|
||||
const UnsignedIndex quadEnd = (numPackets - numRemPackets) * PacketSize;
|
||||
|
||||
for (UnsignedIndex k = quadStart; k < quadEnd; k += 4 * PacketSize) {
|
||||
presult0 = eval.packet(presult0, k + 0 * PacketSize);
|
||||
presult1 = eval.packet(presult1, k + 1 * PacketSize);
|
||||
presult2 = eval.packet(presult2, k + 2 * PacketSize);
|
||||
presult3 = eval.packet(presult3, k + 3 * PacketSize);
|
||||
}
|
||||
|
||||
if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize);
|
||||
if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize);
|
||||
if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize);
|
||||
}
|
||||
|
||||
Scalar result = predux(padd(padd(presult0, presult1), padd(presult2, presult3)));
|
||||
|
||||
if (size > packetEnd) {
|
||||
Scalar scalarAccum = eval.coeff(packetEnd);
|
||||
for (UnsignedIndex k = packetEnd + 1; k < size; k++) {
|
||||
scalarAccum = eval.coeff(scalarAccum, k);
|
||||
}
|
||||
result += scalarAccum;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar, bool Conj>
|
||||
struct conditional_conj;
|
||||
|
||||
template <typename Scalar>
|
||||
struct conditional_conj<Scalar, true> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return numext::conj(a); }
|
||||
template <typename Packet>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) {
|
||||
return pconj(a);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct conditional_conj<Scalar, false> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return a; }
|
||||
template <typename Packet>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) {
|
||||
return a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LhsScalar, typename RhsScalar, bool Conj>
|
||||
struct scalar_inner_product_op {
|
||||
using result_type = typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType;
|
||||
using conj_helper = conditional_conj<LhsScalar, Conj>;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const LhsScalar& a, const RhsScalar& b) const {
|
||||
return (conj_helper::coeff(a) * b);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const result_type& accum, const LhsScalar& a,
|
||||
const RhsScalar& b) const {
|
||||
return (conj_helper::coeff(a) * b) + accum;
|
||||
}
|
||||
static constexpr bool PacketAccess = false;
|
||||
};
|
||||
|
||||
template <typename Scalar, bool Conj>
|
||||
struct scalar_inner_product_op<Scalar, Scalar, Conj> {
|
||||
using result_type = Scalar;
|
||||
using conj_helper = conditional_conj<Scalar, Conj>;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a, const Scalar& b) const {
|
||||
return pmul(conj_helper::coeff(a), b);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& accum, const Scalar& a, const Scalar& b) const {
|
||||
return pmadd(conj_helper::coeff(a), b, accum);
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a, const Packet& b) const {
|
||||
return pmul(conj_helper::packet(a), b);
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& accum, const Packet& a, const Packet& b) const {
|
||||
return pmadd(conj_helper::packet(a), b, accum);
|
||||
}
|
||||
static constexpr bool PacketAccess = packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasAdd;
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs, bool Conj>
|
||||
struct default_inner_product_impl {
|
||||
using LhsScalar = typename traits<Lhs>::Scalar;
|
||||
using RhsScalar = typename traits<Rhs>::Scalar;
|
||||
using Op = scalar_inner_product_op<LhsScalar, RhsScalar, Conj>;
|
||||
using Evaluator = inner_product_evaluator<Op, Lhs, Rhs>;
|
||||
using result_type = typename Evaluator::Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type run(const MatrixBase<Lhs>& a, const MatrixBase<Rhs>& b) {
|
||||
Evaluator eval(a.derived(), b.derived(), Op());
|
||||
return inner_product_impl<Evaluator>::run(eval);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs>
|
||||
struct dot_impl : default_inner_product_impl<Lhs, Rhs, true> {};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
#endif // EIGEN_INNER_PRODUCT_EVAL_H
|
@ -235,19 +235,20 @@ EIGEN_CATCH_ASSIGN_XPR_OP_PRODUCT(sub_assign_op, scalar_difference_op, add_assig
|
||||
|
||||
template <typename Lhs, typename Rhs>
|
||||
struct generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, InnerProduct> {
|
||||
using impl = default_inner_product_impl<Lhs, Rhs, false>;
|
||||
template <typename Dst>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
|
||||
dst.coeffRef(0, 0) = (lhs.transpose().cwiseProduct(rhs)).sum();
|
||||
dst.coeffRef(0, 0) = impl::run(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename Dst>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
|
||||
dst.coeffRef(0, 0) += (lhs.transpose().cwiseProduct(rhs)).sum();
|
||||
dst.coeffRef(0, 0) += impl::run(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename Dst>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
|
||||
dst.coeffRef(0, 0) -= (lhs.transpose().cwiseProduct(rhs)).sum();
|
||||
dst.coeffRef(0, 0) -= impl::run(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user