mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-18 02:44:27 +08:00
Remove the dot product's separate implementation and use cwiseProduct.sum instead.
Also take special care to get nicely working static assertions.
This commit is contained in:
parent
b5c79e7291
commit
d9f6380499
@ -1,7 +1,7 @@
|
|||||||
// This file is part of Eigen, a lightweight C++ template library
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
// for linear algebra.
|
// for linear algebra.
|
||||||
//
|
//
|
||||||
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
|
// Copyright (C) 2006-2008, 2010 Benoit Jacob <jacob.benoit.1@gmail.com>
|
||||||
//
|
//
|
||||||
// Eigen is free software; you can redistribute it and/or
|
// Eigen is free software; you can redistribute it and/or
|
||||||
// modify it under the terms of the GNU Lesser General Public
|
// modify it under the terms of the GNU Lesser General Public
|
||||||
@ -25,224 +25,28 @@
|
|||||||
#ifndef EIGEN_DOT_H
|
#ifndef EIGEN_DOT_H
|
||||||
#define EIGEN_DOT_H
|
#define EIGEN_DOT_H
|
||||||
|
|
||||||
/***************************************************************************
|
// helper function for dot(). The problem is that if we put that in the body of dot(), then upon calling dot
|
||||||
* Part 1 : the logic deciding a strategy for vectorization and unrolling
|
// with mismatched types, the compiler emits errors about failing to instantiate cwiseProduct BEFORE
|
||||||
***************************************************************************/
|
// looking at the static assertions. Thus this is a trick to get better compile errors.
|
||||||
|
template<typename T, typename U,
|
||||||
template<typename Derived1, typename Derived2>
|
bool IsSameType = ei_is_same_type<typename T::Scalar, typename U::Scalar>::ret>
|
||||||
struct ei_dot_traits
|
struct ei_dot_nocheck
|
||||||
{
|
{
|
||||||
public:
|
static inline typename ei_traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
|
||||||
enum {
|
|
||||||
Traversal = (int(Derived1::Flags)&int(Derived2::Flags)&ActualPacketAccessBit)
|
|
||||||
&& (int(Derived1::Flags)&int(Derived2::Flags)&LinearAccessBit)
|
|
||||||
? LinearVectorizedTraversal
|
|
||||||
: DefaultTraversal
|
|
||||||
};
|
|
||||||
|
|
||||||
private:
|
|
||||||
typedef typename Derived1::Scalar Scalar;
|
|
||||||
enum {
|
|
||||||
PacketSize = ei_packet_traits<Scalar>::size,
|
|
||||||
Cost = Derived1::SizeAtCompileTime * (Derived1::CoeffReadCost + Derived2::CoeffReadCost + NumTraits<Scalar>::MulCost)
|
|
||||||
+ (Derived1::SizeAtCompileTime-1) * NumTraits<Scalar>::AddCost,
|
|
||||||
UnrollingLimit = EIGEN_UNROLLING_LIMIT * (int(Traversal) == int(DefaultTraversal) ? 1 : int(PacketSize))
|
|
||||||
};
|
|
||||||
|
|
||||||
public:
|
|
||||||
enum {
|
|
||||||
Unrolling = Cost <= UnrollingLimit
|
|
||||||
? CompleteUnrolling
|
|
||||||
: NoUnrolling
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
/***************************************************************************
|
|
||||||
* Part 2 : unrollers
|
|
||||||
***************************************************************************/
|
|
||||||
|
|
||||||
/*** no vectorization ***/
|
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2, int Start, int Length>
|
|
||||||
struct ei_dot_novec_unroller
|
|
||||||
{
|
{
|
||||||
enum {
|
return a.conjugate().cwiseProduct(b).sum();
|
||||||
HalfLength = Length/2
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef typename Derived1::Scalar Scalar;
|
|
||||||
|
|
||||||
inline static Scalar run(const Derived1& v1, const Derived2& v2)
|
|
||||||
{
|
|
||||||
return ei_dot_novec_unroller<Derived1, Derived2, Start, HalfLength>::run(v1, v2)
|
|
||||||
+ ei_dot_novec_unroller<Derived1, Derived2, Start+HalfLength, Length-HalfLength>::run(v1, v2);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2, int Start>
|
template<typename T, typename U>
|
||||||
struct ei_dot_novec_unroller<Derived1, Derived2, Start, 1>
|
struct ei_dot_nocheck<T, U, false>
|
||||||
{
|
{
|
||||||
typedef typename Derived1::Scalar Scalar;
|
static inline typename ei_traits<T>::Scalar run(const MatrixBase<T>&, const MatrixBase<U>&)
|
||||||
|
|
||||||
inline static Scalar run(const Derived1& v1, const Derived2& v2)
|
|
||||||
{
|
{
|
||||||
return ei_conj(v1.coeff(Start)) * v2.coeff(Start);
|
return typename ei_traits<T>::Scalar(0);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*** vectorization ***/
|
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2, int Index, int Stop,
|
|
||||||
bool LastPacket = (Stop-Index == ei_packet_traits<typename Derived1::Scalar>::size)>
|
|
||||||
struct ei_dot_vec_unroller
|
|
||||||
{
|
|
||||||
typedef typename Derived1::Scalar Scalar;
|
|
||||||
typedef typename ei_packet_traits<Scalar>::type PacketScalar;
|
|
||||||
|
|
||||||
enum {
|
|
||||||
row1 = Derived1::RowsAtCompileTime == 1 ? 0 : Index,
|
|
||||||
col1 = Derived1::RowsAtCompileTime == 1 ? Index : 0,
|
|
||||||
row2 = Derived2::RowsAtCompileTime == 1 ? 0 : Index,
|
|
||||||
col2 = Derived2::RowsAtCompileTime == 1 ? Index : 0
|
|
||||||
};
|
|
||||||
|
|
||||||
inline static PacketScalar run(const Derived1& v1, const Derived2& v2)
|
|
||||||
{
|
|
||||||
return ei_pmadd(
|
|
||||||
v1.template packet<Aligned>(row1, col1),
|
|
||||||
v2.template packet<Aligned>(row2, col2),
|
|
||||||
ei_dot_vec_unroller<Derived1, Derived2, Index+ei_packet_traits<Scalar>::size, Stop>::run(v1, v2)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2, int Index, int Stop>
|
|
||||||
struct ei_dot_vec_unroller<Derived1, Derived2, Index, Stop, true>
|
|
||||||
{
|
|
||||||
enum {
|
|
||||||
row1 = Derived1::RowsAtCompileTime == 1 ? 0 : Index,
|
|
||||||
col1 = Derived1::RowsAtCompileTime == 1 ? Index : 0,
|
|
||||||
row2 = Derived2::RowsAtCompileTime == 1 ? 0 : Index,
|
|
||||||
col2 = Derived2::RowsAtCompileTime == 1 ? Index : 0,
|
|
||||||
alignment1 = (Derived1::Flags & AlignedBit) ? Aligned : Unaligned,
|
|
||||||
alignment2 = (Derived2::Flags & AlignedBit) ? Aligned : Unaligned
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef typename Derived1::Scalar Scalar;
|
|
||||||
typedef typename ei_packet_traits<Scalar>::type PacketScalar;
|
|
||||||
|
|
||||||
inline static PacketScalar run(const Derived1& v1, const Derived2& v2)
|
|
||||||
{
|
|
||||||
return ei_pmul(v1.template packet<alignment1>(row1, col1), v2.template packet<alignment2>(row2, col2));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/***************************************************************************
|
|
||||||
* Part 3 : implementation of all cases
|
|
||||||
***************************************************************************/
|
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2,
|
|
||||||
int Traversal = ei_dot_traits<Derived1, Derived2>::Traversal,
|
|
||||||
int Unrolling = ei_dot_traits<Derived1, Derived2>::Unrolling
|
|
||||||
>
|
|
||||||
struct ei_dot_impl;
|
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2>
|
|
||||||
struct ei_dot_impl<Derived1, Derived2, DefaultTraversal, NoUnrolling>
|
|
||||||
{
|
|
||||||
typedef typename Derived1::Scalar Scalar;
|
|
||||||
static Scalar run(const Derived1& v1, const Derived2& v2)
|
|
||||||
{
|
|
||||||
ei_assert(v1.size()>0 && "you are using a non initialized vector");
|
|
||||||
Scalar res;
|
|
||||||
res = ei_conj(v1.coeff(0)) * v2.coeff(0);
|
|
||||||
for(int i = 1; i < v1.size(); ++i)
|
|
||||||
res += ei_conj(v1.coeff(i)) * v2.coeff(i);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2>
|
|
||||||
struct ei_dot_impl<Derived1, Derived2, DefaultTraversal, CompleteUnrolling>
|
|
||||||
: public ei_dot_novec_unroller<Derived1, Derived2, 0, Derived1::SizeAtCompileTime>
|
|
||||||
{};
|
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2>
|
|
||||||
struct ei_dot_impl<Derived1, Derived2, LinearVectorizedTraversal, NoUnrolling>
|
|
||||||
{
|
|
||||||
typedef typename Derived1::Scalar Scalar;
|
|
||||||
typedef typename ei_packet_traits<Scalar>::type PacketScalar;
|
|
||||||
|
|
||||||
static Scalar run(const Derived1& v1, const Derived2& v2)
|
|
||||||
{
|
|
||||||
const int size = v1.size();
|
|
||||||
const int packetSize = ei_packet_traits<Scalar>::size;
|
|
||||||
const int alignedSize = (size/packetSize)*packetSize;
|
|
||||||
enum {
|
|
||||||
alignment1 = (Derived1::Flags & AlignedBit) ? Aligned : Unaligned,
|
|
||||||
alignment2 = (Derived2::Flags & AlignedBit) ? Aligned : Unaligned
|
|
||||||
};
|
|
||||||
Scalar res;
|
|
||||||
|
|
||||||
// do the vectorizable part of the sum
|
|
||||||
if(size >= packetSize)
|
|
||||||
{
|
|
||||||
PacketScalar packet_res = ei_pmul(
|
|
||||||
v1.template packet<alignment1>(0),
|
|
||||||
v2.template packet<alignment2>(0)
|
|
||||||
);
|
|
||||||
for(int index = packetSize; index<alignedSize; index += packetSize)
|
|
||||||
{
|
|
||||||
packet_res = ei_pmadd(
|
|
||||||
v1.template packet<alignment1>(index),
|
|
||||||
v2.template packet<alignment2>(index),
|
|
||||||
packet_res
|
|
||||||
);
|
|
||||||
}
|
|
||||||
res = ei_predux(packet_res);
|
|
||||||
|
|
||||||
// now we must do the rest without vectorization.
|
|
||||||
if(alignedSize == size) return res;
|
|
||||||
}
|
|
||||||
else // too small to vectorize anything.
|
|
||||||
// since this is dynamic-size hence inefficient anyway for such small sizes, don't try to optimize.
|
|
||||||
{
|
|
||||||
res = Scalar(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// do the remainder of the vector
|
|
||||||
for(int index = alignedSize; index < size; ++index)
|
|
||||||
{
|
|
||||||
res += v1.coeff(index) * v2.coeff(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Derived1, typename Derived2>
|
|
||||||
struct ei_dot_impl<Derived1, Derived2, LinearVectorizedTraversal, CompleteUnrolling>
|
|
||||||
{
|
|
||||||
typedef typename Derived1::Scalar Scalar;
|
|
||||||
typedef typename ei_packet_traits<Scalar>::type PacketScalar;
|
|
||||||
enum {
|
|
||||||
PacketSize = ei_packet_traits<Scalar>::size,
|
|
||||||
Size = Derived1::SizeAtCompileTime,
|
|
||||||
VectorizedSize = (Size / PacketSize) * PacketSize
|
|
||||||
};
|
|
||||||
static Scalar run(const Derived1& v1, const Derived2& v2)
|
|
||||||
{
|
|
||||||
Scalar res = ei_predux(ei_dot_vec_unroller<Derived1, Derived2, 0, VectorizedSize>::run(v1, v2));
|
|
||||||
if (VectorizedSize != Size)
|
|
||||||
res += ei_dot_novec_unroller<Derived1, Derived2, VectorizedSize, Size-VectorizedSize>::run(v1, v2);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/***************************************************************************
|
|
||||||
* Part 4 : implementation of MatrixBase methods
|
|
||||||
***************************************************************************/
|
|
||||||
|
|
||||||
/** \returns the dot product of *this with other.
|
/** \returns the dot product of *this with other.
|
||||||
*
|
*
|
||||||
* \only_for_vectors
|
* \only_for_vectors
|
||||||
@ -266,10 +70,7 @@ MatrixBase<Derived>::dot(const MatrixBase<OtherDerived>& other) const
|
|||||||
|
|
||||||
ei_assert(size() == other.size());
|
ei_assert(size() == other.size());
|
||||||
|
|
||||||
// dot() must honor EvalBeforeNestingBit (eg: v.dot(M*v) )
|
return ei_dot_nocheck<Derived,OtherDerived>::run(*this, other);
|
||||||
typedef typename ei_cleantype<typename Derived::Nested>::type ThisNested;
|
|
||||||
typedef typename ei_cleantype<typename OtherDerived::Nested>::type OtherNested;
|
|
||||||
return ei_dot_impl<ThisNested, OtherNested>::run(derived(), other.derived());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \returns the squared \em l2 norm of *this, i.e., for vectors, the dot product of *this with itself.
|
/** \returns the squared \em l2 norm of *this, i.e., for vectors, the dot product of *this with itself.
|
||||||
|
@ -305,10 +305,7 @@ struct ei_product_coeff_vectorized_dyn_selector
|
|||||||
{
|
{
|
||||||
EIGEN_STRONG_INLINE static void run(int row, int col, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
|
EIGEN_STRONG_INLINE static void run(int row, int col, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
|
||||||
{
|
{
|
||||||
res = ei_dot_impl<
|
res = lhs.row(row).cwiseProduct(rhs.col(col)).sum();
|
||||||
Block<Lhs, 1, ei_traits<Lhs>::ColsAtCompileTime>,
|
|
||||||
Block<Rhs, ei_traits<Rhs>::RowsAtCompileTime, 1>,
|
|
||||||
LinearVectorizedTraversal, NoUnrolling>::run(lhs.row(row), rhs.col(col));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -319,10 +316,7 @@ struct ei_product_coeff_vectorized_dyn_selector<Lhs,Rhs,1,RhsCols>
|
|||||||
{
|
{
|
||||||
EIGEN_STRONG_INLINE static void run(int /*row*/, int col, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
|
EIGEN_STRONG_INLINE static void run(int /*row*/, int col, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
|
||||||
{
|
{
|
||||||
res = ei_dot_impl<
|
res = lhs.cwiseProduct(rhs.col(col)).sum();
|
||||||
Lhs,
|
|
||||||
Block<Rhs, ei_traits<Rhs>::RowsAtCompileTime, 1>,
|
|
||||||
LinearVectorizedTraversal, NoUnrolling>::run(lhs, rhs.col(col));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -331,10 +325,7 @@ struct ei_product_coeff_vectorized_dyn_selector<Lhs,Rhs,LhsRows,1>
|
|||||||
{
|
{
|
||||||
EIGEN_STRONG_INLINE static void run(int row, int /*col*/, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
|
EIGEN_STRONG_INLINE static void run(int row, int /*col*/, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
|
||||||
{
|
{
|
||||||
res = ei_dot_impl<
|
res = lhs.row(row).cwiseProduct(rhs).sum();
|
||||||
Block<Lhs, 1, ei_traits<Lhs>::ColsAtCompileTime>,
|
|
||||||
Rhs,
|
|
||||||
LinearVectorizedTraversal, NoUnrolling>::run(lhs.row(row), rhs);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -343,10 +334,7 @@ struct ei_product_coeff_vectorized_dyn_selector<Lhs,Rhs,1,1>
|
|||||||
{
|
{
|
||||||
EIGEN_STRONG_INLINE static void run(int /*row*/, int /*col*/, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
|
EIGEN_STRONG_INLINE static void run(int /*row*/, int /*col*/, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
|
||||||
{
|
{
|
||||||
res = ei_dot_impl<
|
res = lhs.cwiseProduct(rhs).sum();
|
||||||
Lhs,
|
|
||||||
Rhs,
|
|
||||||
LinearVectorizedTraversal, NoUnrolling>::run(lhs, rhs);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user