Specialize the unrollers for Size==0 in order to catch user bugs and not clutter
the compiler output with an infinite recursion. Also add a #define switch
(EIGEN_UNROLLED_LOOPS) for loop unrolling.
This commit is contained in:
Benoit Jacob 2007-12-11 10:04:39 +00:00
parent 9d51572cbe
commit fc924bc7d4
4 changed files with 62 additions and 13 deletions

View File

@ -40,6 +40,17 @@ template<int UnrollCount, int Rows> struct CopyHelperUnroller
} }
}; };
// Terminal case for Rows == 0: there is nothing to copy, so run() does nothing.
// Its purpose is to stop buggy user code from sending the compiler into an
// infinite recursion of template instantiations.
template<int UnrollCount>
struct CopyHelperUnroller<UnrollCount, 0>
{
  // Both arguments are deliberately ignored; leaving them unnamed keeps
  // unused-parameter warnings quiet.
  template <typename Derived1, typename Derived2>
  static void run(Derived1 & /* dst */, const Derived2 & /* src */) {}
};
template<int Rows> struct CopyHelperUnroller<1, Rows> template<int Rows> struct CopyHelperUnroller<1, Rows>
{ {
template <typename Derived1, typename Derived2> template <typename Derived1, typename Derived2>
@ -63,7 +74,7 @@ template<typename Scalar, typename Derived>
template<typename OtherDerived> template<typename OtherDerived>
void MatrixBase<Scalar, Derived>::_copy_helper(const MatrixBase<Scalar, OtherDerived>& other) void MatrixBase<Scalar, Derived>::_copy_helper(const MatrixBase<Scalar, OtherDerived>& other)
{ {
if(SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25) if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 25)
CopyHelperUnroller<SizeAtCompileTime, RowsAtCompileTime>::run(*this, other); CopyHelperUnroller<SizeAtCompileTime, RowsAtCompileTime>::run(*this, other);
else else
for(int i = 0; i < rows(); i++) for(int i = 0; i < rows(); i++)

View File

@ -32,7 +32,7 @@ struct DotUnroller
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot) static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
{ {
DotUnroller<Index-1, Size, Derived1, Derived2>::run(v1, v2, dot); DotUnroller<Index-1, Size, Derived1, Derived2>::run(v1, v2, dot);
dot += v1[Index] * conj(v2[Index]); dot += v1.read(Index) * conj(v2.read(Index));
} }
}; };
@ -41,7 +41,7 @@ struct DotUnroller<0, Size, Derived1, Derived2>
{ {
static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot) static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
{ {
dot = v1[0] * conj(v2[0]); dot = v1.read(0) * conj(v2.read(0));
} }
}; };
@ -56,20 +56,32 @@ struct DotUnroller<Index, Dynamic, Derived1, Derived2>
} }
}; };
// Terminal case for Size == 0: prevents buggy user code from causing an
// infinite recursion of template instantiations.
//
// v1, v2: the operands; ignored because there are no coefficients to read.
// dot:    output; set to the value-initialized Scalar (the empty sum).
template<int Index, typename Derived1, typename Derived2>
struct DotUnroller<Index, 0, Derived1, Derived2>
{
  static void run(const Derived1 &v1, const Derived2& v2, typename Derived1::Scalar &dot)
  {
    EIGEN_UNUSED(v1);
    EIGEN_UNUSED(v2);
    // Always write the output: the caller may hand in an uninitialized
    // accumulator and return it as-is, which would otherwise yield an
    // indeterminate value. Value-initialization gives zero for built-in
    // arithmetic types; NOTE(review): confirm that a default-constructed
    // Scalar is the additive identity for any custom scalar types.
    dot = typename Derived1::Scalar();
  }
};
template<typename Scalar, typename Derived> template<typename Scalar, typename Derived>
template<typename OtherDerived> template<typename OtherDerived>
Scalar MatrixBase<Scalar, Derived>::dot(const OtherDerived& other) const Scalar MatrixBase<Scalar, Derived>::dot(const OtherDerived& other) const
{ {
assert(IsVector && OtherDerived::IsVector && size() == other.size()); assert(IsVector && OtherDerived::IsVector && size() == other.size());
Scalar res; Scalar res;
if(SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 16) if(EIGEN_UNROLLED_LOOPS && SizeAtCompileTime != Dynamic && SizeAtCompileTime <= 16)
DotUnroller<SizeAtCompileTime-1, SizeAtCompileTime, Derived, OtherDerived> DotUnroller<SizeAtCompileTime-1, SizeAtCompileTime, Derived, OtherDerived>
::run(*static_cast<const Derived*>(this), other, res); ::run(*static_cast<const Derived*>(this), other, res);
else else
{ {
res = (*this)[0] * conj(other[0]); res = (*this).read(0) * conj(other.read(0));
for(int i = 1; i < size(); i++) for(int i = 1; i < size(); i++)
res += (*this)[i]* conj(other[i]); res += (*this).read(i)* conj(other.read(i));
} }
return res; return res;
} }

View File

@ -61,6 +61,21 @@ struct ProductUnroller<Index, Dynamic, Lhs, Rhs>
} }
}; };
// Terminal case for a compile-time size of 0: prevents buggy user code from
// causing an infinite recursion of template instantiations.
//
// row, col: position of the requested coefficient; ignored.
// lhs, rhs: the factors; ignored because there are no terms to accumulate.
// res:      output; set to the value-initialized Scalar (the empty sum).
template<int Index, typename Lhs, typename Rhs>
struct ProductUnroller<Index, 0, Lhs, Rhs>
{
  static void run(int row, int col, const Lhs& lhs, const Rhs& rhs,
                  typename Lhs::Scalar &res)
  {
    EIGEN_UNUSED(row);
    EIGEN_UNUSED(col);
    EIGEN_UNUSED(lhs);
    EIGEN_UNUSED(rhs);
    // Always write the output: the caller may hand in an uninitialized
    // accumulator and return it as-is, which would otherwise yield an
    // indeterminate value. Value-initialization gives zero for built-in
    // arithmetic types; NOTE(review): confirm that a default-constructed
    // Scalar is the additive identity for any custom scalar types.
    res = typename Lhs::Scalar();
  }
};
template<typename Lhs, typename Rhs> class Product template<typename Lhs, typename Rhs> class Product
: public MatrixBase<typename Lhs::Scalar, Product<Lhs, Rhs> > : public MatrixBase<typename Lhs::Scalar, Product<Lhs, Rhs> >
{ {
@ -93,14 +108,15 @@ template<typename Lhs, typename Rhs> class Product
Scalar _read(int row, int col) const Scalar _read(int row, int col) const
{ {
Scalar res; Scalar res;
if(Lhs::ColsAtCompileTime != Dynamic && Lhs::ColsAtCompileTime <= 16) if(EIGEN_UNROLLED_LOOPS
&& Lhs::ColsAtCompileTime != Dynamic && Lhs::ColsAtCompileTime <= 16)
ProductUnroller<Lhs::ColsAtCompileTime-1, Lhs::ColsAtCompileTime, LhsRef, RhsRef> ProductUnroller<Lhs::ColsAtCompileTime-1, Lhs::ColsAtCompileTime, LhsRef, RhsRef>
::run(row, col, m_lhs, m_rhs, res); ::run(row, col, m_lhs, m_rhs, res);
else else
{ {
res = m_lhs(row, 0) * m_rhs(0, col); res = m_lhs.read(row, 0) * m_rhs.read(0, col);
for(int i = 1; i < m_lhs.cols(); i++) for(int i = 1; i < m_lhs.cols(); i++)
res += m_lhs(row, i) * m_rhs(i, col); res += m_lhs.read(row, i) * m_rhs.read(i, col);
} }
return res; return res;
} }
@ -112,7 +128,7 @@ template<typename Lhs, typename Rhs> class Product
template<typename Scalar, typename Derived> template<typename Scalar, typename Derived>
template<typename OtherDerived> template<typename OtherDerived>
Product<Derived, OtherDerived> const Product<Derived, OtherDerived>
MatrixBase<Scalar, Derived>::lazyProduct(const MatrixBase<Scalar, OtherDerived> &other) const MatrixBase<Scalar, Derived>::lazyProduct(const MatrixBase<Scalar, OtherDerived> &other) const
{ {
return Product<Derived, OtherDerived>(ref(), other.ref()); return Product<Derived, OtherDerived>(ref(), other.ref());

View File

@ -31,7 +31,7 @@ template<int Index, int Rows, typename Derived> struct TraceUnroller
static void run(const Derived &mat, typename Derived::Scalar &trace) static void run(const Derived &mat, typename Derived::Scalar &trace)
{ {
TraceUnroller<Index-1, Rows, Derived>::run(mat, trace); TraceUnroller<Index-1, Rows, Derived>::run(mat, trace);
trace += mat(Index, Index); trace += mat.read(Index, Index);
} }
}; };
@ -39,7 +39,7 @@ template<int Rows, typename Derived> struct TraceUnroller<0, Rows, Derived>
{ {
static void run(const Derived &mat, typename Derived::Scalar &trace) static void run(const Derived &mat, typename Derived::Scalar &trace)
{ {
trace = mat(0, 0); trace = mat.read(0, 0);
} }
}; };
@ -52,12 +52,22 @@ template<int Index, typename Derived> struct TraceUnroller<Index, Dynamic, Deriv
} }
}; };
// Terminal case for Rows == 0: prevents buggy user code from causing an
// infinite recursion of template instantiations.
//
// mat:   the matrix; ignored because there are no diagonal coefficients.
// trace: output; set to the value-initialized Scalar (the empty sum).
template<int Index, typename Derived> struct TraceUnroller<Index, 0, Derived>
{
  static void run(const Derived &mat, typename Derived::Scalar &trace)
  {
    EIGEN_UNUSED(mat);
    // Always write the output: the caller may hand in an uninitialized
    // accumulator and return it as-is, which would otherwise yield an
    // indeterminate value. Value-initialization gives zero for built-in
    // arithmetic types; NOTE(review): confirm that a default-constructed
    // Scalar is the additive identity for any custom scalar types.
    trace = typename Derived::Scalar();
  }
};
template<typename Scalar, typename Derived> template<typename Scalar, typename Derived>
Scalar MatrixBase<Scalar, Derived>::trace() const Scalar MatrixBase<Scalar, Derived>::trace() const
{ {
assert(rows() == cols()); assert(rows() == cols());
Scalar res; Scalar res;
if(RowsAtCompileTime != Dynamic && RowsAtCompileTime <= 16) if(EIGEN_UNROLLED_LOOPS && RowsAtCompileTime != Dynamic && RowsAtCompileTime <= 16)
TraceUnroller<RowsAtCompileTime-1, RowsAtCompileTime, Derived> TraceUnroller<RowsAtCompileTime-1, RowsAtCompileTime, Derived>
::run(*static_cast<const Derived*>(this), res); ::run(*static_cast<const Derived*>(this), res);
else else