Remove the dot product's separate implementation and use cwiseProduct.sum instead.

Also take special care to get nicely working static  assertions.
This commit is contained in:
Benoit Jacob 2010-02-27 10:03:27 -05:00
parent b5c79e7291
commit d9f6380499
2 changed files with 18 additions and 229 deletions

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2006-2008, 2010 Benoit Jacob <jacob.benoit.1@gmail.com>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
@ -25,224 +25,28 @@
#ifndef EIGEN_DOT_H
#define EIGEN_DOT_H
/***************************************************************************
* Part 1 : the logic deciding a strategy for vectorization and unrolling
***************************************************************************/
template<typename Derived1, typename Derived2>
struct ei_dot_traits
// helper function for dot(). The problem is that if we put that in the body of dot(), then upon calling dot
// with mismatched types, the compiler emits errors about failing to instantiate cwiseProduct BEFORE
// looking at the static assertions. Thus this is a trick to get better compile errors.
template<typename T, typename U,
bool IsSameType = ei_is_same_type<typename T::Scalar, typename U::Scalar>::ret>
struct ei_dot_nocheck
{
public:
enum {
Traversal = (int(Derived1::Flags)&int(Derived2::Flags)&ActualPacketAccessBit)
&& (int(Derived1::Flags)&int(Derived2::Flags)&LinearAccessBit)
? LinearVectorizedTraversal
: DefaultTraversal
};
private:
typedef typename Derived1::Scalar Scalar;
enum {
PacketSize = ei_packet_traits<Scalar>::size,
Cost = Derived1::SizeAtCompileTime * (Derived1::CoeffReadCost + Derived2::CoeffReadCost + NumTraits<Scalar>::MulCost)
+ (Derived1::SizeAtCompileTime-1) * NumTraits<Scalar>::AddCost,
UnrollingLimit = EIGEN_UNROLLING_LIMIT * (int(Traversal) == int(DefaultTraversal) ? 1 : int(PacketSize))
};
public:
enum {
Unrolling = Cost <= UnrollingLimit
? CompleteUnrolling
: NoUnrolling
};
};
/***************************************************************************
* Part 2 : unrollers
***************************************************************************/
/*** no vectorization ***/
template<typename Derived1, typename Derived2, int Start, int Length>
struct ei_dot_novec_unroller
static inline typename ei_traits<T>::Scalar run(const MatrixBase<T>& a, const MatrixBase<U>& b)
{
enum {
HalfLength = Length/2
};
typedef typename Derived1::Scalar Scalar;
inline static Scalar run(const Derived1& v1, const Derived2& v2)
{
return ei_dot_novec_unroller<Derived1, Derived2, Start, HalfLength>::run(v1, v2)
+ ei_dot_novec_unroller<Derived1, Derived2, Start+HalfLength, Length-HalfLength>::run(v1, v2);
return a.conjugate().cwiseProduct(b).sum();
}
};
template<typename Derived1, typename Derived2, int Start>
struct ei_dot_novec_unroller<Derived1, Derived2, Start, 1>
template<typename T, typename U>
struct ei_dot_nocheck<T, U, false>
{
typedef typename Derived1::Scalar Scalar;
inline static Scalar run(const Derived1& v1, const Derived2& v2)
static inline typename ei_traits<T>::Scalar run(const MatrixBase<T>&, const MatrixBase<U>&)
{
return ei_conj(v1.coeff(Start)) * v2.coeff(Start);
return typename ei_traits<T>::Scalar(0);
}
};
/*** vectorization ***/
template<typename Derived1, typename Derived2, int Index, int Stop,
bool LastPacket = (Stop-Index == ei_packet_traits<typename Derived1::Scalar>::size)>
struct ei_dot_vec_unroller
{
typedef typename Derived1::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type PacketScalar;
enum {
row1 = Derived1::RowsAtCompileTime == 1 ? 0 : Index,
col1 = Derived1::RowsAtCompileTime == 1 ? Index : 0,
row2 = Derived2::RowsAtCompileTime == 1 ? 0 : Index,
col2 = Derived2::RowsAtCompileTime == 1 ? Index : 0
};
inline static PacketScalar run(const Derived1& v1, const Derived2& v2)
{
return ei_pmadd(
v1.template packet<Aligned>(row1, col1),
v2.template packet<Aligned>(row2, col2),
ei_dot_vec_unroller<Derived1, Derived2, Index+ei_packet_traits<Scalar>::size, Stop>::run(v1, v2)
);
}
};
template<typename Derived1, typename Derived2, int Index, int Stop>
struct ei_dot_vec_unroller<Derived1, Derived2, Index, Stop, true>
{
enum {
row1 = Derived1::RowsAtCompileTime == 1 ? 0 : Index,
col1 = Derived1::RowsAtCompileTime == 1 ? Index : 0,
row2 = Derived2::RowsAtCompileTime == 1 ? 0 : Index,
col2 = Derived2::RowsAtCompileTime == 1 ? Index : 0,
alignment1 = (Derived1::Flags & AlignedBit) ? Aligned : Unaligned,
alignment2 = (Derived2::Flags & AlignedBit) ? Aligned : Unaligned
};
typedef typename Derived1::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type PacketScalar;
inline static PacketScalar run(const Derived1& v1, const Derived2& v2)
{
return ei_pmul(v1.template packet<alignment1>(row1, col1), v2.template packet<alignment2>(row2, col2));
}
};
/***************************************************************************
* Part 3 : implementation of all cases
***************************************************************************/
template<typename Derived1, typename Derived2,
int Traversal = ei_dot_traits<Derived1, Derived2>::Traversal,
int Unrolling = ei_dot_traits<Derived1, Derived2>::Unrolling
>
struct ei_dot_impl;
template<typename Derived1, typename Derived2>
struct ei_dot_impl<Derived1, Derived2, DefaultTraversal, NoUnrolling>
{
typedef typename Derived1::Scalar Scalar;
static Scalar run(const Derived1& v1, const Derived2& v2)
{
ei_assert(v1.size()>0 && "you are using a non initialized vector");
Scalar res;
res = ei_conj(v1.coeff(0)) * v2.coeff(0);
for(int i = 1; i < v1.size(); ++i)
res += ei_conj(v1.coeff(i)) * v2.coeff(i);
return res;
}
};
template<typename Derived1, typename Derived2>
struct ei_dot_impl<Derived1, Derived2, DefaultTraversal, CompleteUnrolling>
: public ei_dot_novec_unroller<Derived1, Derived2, 0, Derived1::SizeAtCompileTime>
{};
template<typename Derived1, typename Derived2>
struct ei_dot_impl<Derived1, Derived2, LinearVectorizedTraversal, NoUnrolling>
{
typedef typename Derived1::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type PacketScalar;
static Scalar run(const Derived1& v1, const Derived2& v2)
{
const int size = v1.size();
const int packetSize = ei_packet_traits<Scalar>::size;
const int alignedSize = (size/packetSize)*packetSize;
enum {
alignment1 = (Derived1::Flags & AlignedBit) ? Aligned : Unaligned,
alignment2 = (Derived2::Flags & AlignedBit) ? Aligned : Unaligned
};
Scalar res;
// do the vectorizable part of the sum
if(size >= packetSize)
{
PacketScalar packet_res = ei_pmul(
v1.template packet<alignment1>(0),
v2.template packet<alignment2>(0)
);
for(int index = packetSize; index<alignedSize; index += packetSize)
{
packet_res = ei_pmadd(
v1.template packet<alignment1>(index),
v2.template packet<alignment2>(index),
packet_res
);
}
res = ei_predux(packet_res);
// now we must do the rest without vectorization.
if(alignedSize == size) return res;
}
else // too small to vectorize anything.
// since this is dynamic-size hence inefficient anyway for such small sizes, don't try to optimize.
{
res = Scalar(0);
}
// do the remainder of the vector
for(int index = alignedSize; index < size; ++index)
{
res += v1.coeff(index) * v2.coeff(index);
}
return res;
}
};
template<typename Derived1, typename Derived2>
struct ei_dot_impl<Derived1, Derived2, LinearVectorizedTraversal, CompleteUnrolling>
{
typedef typename Derived1::Scalar Scalar;
typedef typename ei_packet_traits<Scalar>::type PacketScalar;
enum {
PacketSize = ei_packet_traits<Scalar>::size,
Size = Derived1::SizeAtCompileTime,
VectorizedSize = (Size / PacketSize) * PacketSize
};
static Scalar run(const Derived1& v1, const Derived2& v2)
{
Scalar res = ei_predux(ei_dot_vec_unroller<Derived1, Derived2, 0, VectorizedSize>::run(v1, v2));
if (VectorizedSize != Size)
res += ei_dot_novec_unroller<Derived1, Derived2, VectorizedSize, Size-VectorizedSize>::run(v1, v2);
return res;
}
};
/***************************************************************************
* Part 4 : implementation of MatrixBase methods
***************************************************************************/
/** \returns the dot product of *this with other.
*
* \only_for_vectors
@ -266,10 +70,7 @@ MatrixBase<Derived>::dot(const MatrixBase<OtherDerived>& other) const
ei_assert(size() == other.size());
// dot() must honor EvalBeforeNestingBit (eg: v.dot(M*v) )
typedef typename ei_cleantype<typename Derived::Nested>::type ThisNested;
typedef typename ei_cleantype<typename OtherDerived::Nested>::type OtherNested;
return ei_dot_impl<ThisNested, OtherNested>::run(derived(), other.derived());
return ei_dot_nocheck<Derived,OtherDerived>::run(*this, other);
}
/** \returns the squared \em l2 norm of *this, i.e., for vectors, the dot product of *this with itself.

View File

@ -305,10 +305,7 @@ struct ei_product_coeff_vectorized_dyn_selector
{
EIGEN_STRONG_INLINE static void run(int row, int col, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
{
res = ei_dot_impl<
Block<Lhs, 1, ei_traits<Lhs>::ColsAtCompileTime>,
Block<Rhs, ei_traits<Rhs>::RowsAtCompileTime, 1>,
LinearVectorizedTraversal, NoUnrolling>::run(lhs.row(row), rhs.col(col));
res = lhs.row(row).cwiseProduct(rhs.col(col)).sum();
}
};
@ -319,10 +316,7 @@ struct ei_product_coeff_vectorized_dyn_selector<Lhs,Rhs,1,RhsCols>
{
EIGEN_STRONG_INLINE static void run(int /*row*/, int col, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
{
res = ei_dot_impl<
Lhs,
Block<Rhs, ei_traits<Rhs>::RowsAtCompileTime, 1>,
LinearVectorizedTraversal, NoUnrolling>::run(lhs, rhs.col(col));
res = lhs.cwiseProduct(rhs.col(col)).sum();
}
};
@ -331,10 +325,7 @@ struct ei_product_coeff_vectorized_dyn_selector<Lhs,Rhs,LhsRows,1>
{
EIGEN_STRONG_INLINE static void run(int row, int /*col*/, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
{
res = ei_dot_impl<
Block<Lhs, 1, ei_traits<Lhs>::ColsAtCompileTime>,
Rhs,
LinearVectorizedTraversal, NoUnrolling>::run(lhs.row(row), rhs);
res = lhs.row(row).cwiseProduct(rhs).sum();
}
};
@ -343,10 +334,7 @@ struct ei_product_coeff_vectorized_dyn_selector<Lhs,Rhs,1,1>
{
EIGEN_STRONG_INLINE static void run(int /*row*/, int /*col*/, const Lhs& lhs, const Rhs& rhs, typename Lhs::Scalar &res)
{
res = ei_dot_impl<
Lhs,
Rhs,
LinearVectorizedTraversal, NoUnrolling>::run(lhs, rhs);
res = lhs.cwiseProduct(rhs).sum();
}
};