Fix bug #972: allow coeff-based products of depth 0 and remove a useless statement in coeff-based product.

This commit is contained in:
Gael Guennebaud 2015-02-28 15:25:39 +01:00
parent 0e38796e1c
commit 26234720bd

View File

@ -134,7 +134,7 @@ class CoeffBasedProduct
}; };
typedef internal::product_coeff_impl<CanVectorizeInner ? InnerVectorizedTraversal : DefaultTraversal, typedef internal::product_coeff_impl<CanVectorizeInner ? InnerVectorizedTraversal : DefaultTraversal,
Unroll ? (InnerSize==0 ? 0 : InnerSize-1) : Dynamic, Unroll ? InnerSize : Dynamic,
_LhsNested, _RhsNested, Scalar> ScalarCoeffImpl; _LhsNested, _RhsNested, Scalar> ScalarCoeffImpl;
typedef CoeffBasedProduct<LhsNested,RhsNested,NestByRefBit> LazyCoeffBasedProductType; typedef CoeffBasedProduct<LhsNested,RhsNested,NestByRefBit> LazyCoeffBasedProductType;
@ -185,7 +185,7 @@ class CoeffBasedProduct
{ {
PacketScalar res; PacketScalar res;
internal::product_packet_impl<Flags&RowMajorBit ? RowMajor : ColMajor, internal::product_packet_impl<Flags&RowMajorBit ? RowMajor : ColMajor,
Unroll ? (InnerSize==0 ? 0 : InnerSize-1) : Dynamic, Unroll ? InnerSize : Dynamic,
_LhsNested, _RhsNested, PacketScalar, LoadMode> _LhsNested, _RhsNested, PacketScalar, LoadMode>
::run(row, col, m_lhs, m_rhs, res); ::run(row, col, m_lhs, m_rhs, res);
return res; return res;
@ -243,7 +243,17 @@ struct product_coeff_impl<DefaultTraversal, UnrollingIndex, Lhs, Rhs, RetScalar>
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, RetScalar &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, RetScalar &res)
{ {
product_coeff_impl<DefaultTraversal, UnrollingIndex-1, Lhs, Rhs, RetScalar>::run(row, col, lhs, rhs, res); product_coeff_impl<DefaultTraversal, UnrollingIndex-1, Lhs, Rhs, RetScalar>::run(row, col, lhs, rhs, res);
res += lhs.coeff(row, UnrollingIndex) * rhs.coeff(UnrollingIndex, col); res += lhs.coeff(row, UnrollingIndex-1) * rhs.coeff(UnrollingIndex-1, col);
}
};
template<typename Lhs, typename Rhs, typename RetScalar>
struct product_coeff_impl<DefaultTraversal, 1, Lhs, Rhs, RetScalar>
{
typedef typename Lhs::Index Index;
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, RetScalar &res)
{
res = lhs.coeff(row, 0) * rhs.coeff(0, col);
} }
}; };
@ -251,9 +261,9 @@ template<typename Lhs, typename Rhs, typename RetScalar>
struct product_coeff_impl<DefaultTraversal, 0, Lhs, Rhs, RetScalar> struct product_coeff_impl<DefaultTraversal, 0, Lhs, Rhs, RetScalar>
{ {
typedef typename Lhs::Index Index; typedef typename Lhs::Index Index;
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, RetScalar &res) static EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, RetScalar &res)
{ {
res = lhs.coeff(row, 0) * rhs.coeff(0, col); res = RetScalar(0);
} }
}; };
@ -293,6 +303,16 @@ struct product_coeff_vectorized_unroller<0, Lhs, Rhs, Packet>
} }
}; };
template<typename Lhs, typename Rhs, typename RetScalar>
struct product_coeff_impl<InnerVectorizedTraversal, 0, Lhs, Rhs, RetScalar>
{
typedef typename Lhs::Index Index;
static EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, RetScalar &res)
{
res = 0;
}
};
template<int UnrollingIndex, typename Lhs, typename Rhs, typename RetScalar> template<int UnrollingIndex, typename Lhs, typename Rhs, typename RetScalar>
struct product_coeff_impl<InnerVectorizedTraversal, UnrollingIndex, Lhs, Rhs, RetScalar> struct product_coeff_impl<InnerVectorizedTraversal, UnrollingIndex, Lhs, Rhs, RetScalar>
{ {
@ -302,8 +322,7 @@ struct product_coeff_impl<InnerVectorizedTraversal, UnrollingIndex, Lhs, Rhs, Re
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, RetScalar &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, RetScalar &res)
{ {
Packet pres; Packet pres;
product_coeff_vectorized_unroller<UnrollingIndex+1-PacketSize, Lhs, Rhs, Packet>::run(row, col, lhs, rhs, pres); product_coeff_vectorized_unroller<UnrollingIndex-PacketSize, Lhs, Rhs, Packet>::run(row, col, lhs, rhs, pres);
product_coeff_impl<DefaultTraversal,UnrollingIndex,Lhs,Rhs,RetScalar>::run(row, col, lhs, rhs, res);
res = predux(pres); res = predux(pres);
} }
}; };
@ -371,7 +390,7 @@ struct product_packet_impl<RowMajor, UnrollingIndex, Lhs, Rhs, Packet, LoadMode>
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Packet &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Packet &res)
{ {
product_packet_impl<RowMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, res); product_packet_impl<RowMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, res);
res = pmadd(pset1<Packet>(lhs.coeff(row, UnrollingIndex)), rhs.template packet<LoadMode>(UnrollingIndex, col), res); res = pmadd(pset1<Packet>(lhs.coeff(row, UnrollingIndex-1)), rhs.template packet<LoadMode>(UnrollingIndex-1, col), res);
} }
}; };
@ -382,12 +401,12 @@ struct product_packet_impl<ColMajor, UnrollingIndex, Lhs, Rhs, Packet, LoadMode>
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Packet &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Packet &res)
{ {
product_packet_impl<ColMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, res); product_packet_impl<ColMajor, UnrollingIndex-1, Lhs, Rhs, Packet, LoadMode>::run(row, col, lhs, rhs, res);
res = pmadd(lhs.template packet<LoadMode>(row, UnrollingIndex), pset1<Packet>(rhs.coeff(UnrollingIndex, col)), res); res = pmadd(lhs.template packet<LoadMode>(row, UnrollingIndex-1), pset1<Packet>(rhs.coeff(UnrollingIndex-1, col)), res);
} }
}; };
template<typename Lhs, typename Rhs, typename Packet, int LoadMode> template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct product_packet_impl<RowMajor, 0, Lhs, Rhs, Packet, LoadMode> struct product_packet_impl<RowMajor, 1, Lhs, Rhs, Packet, LoadMode>
{ {
typedef typename Lhs::Index Index; typedef typename Lhs::Index Index;
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Packet &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Packet &res)
@ -397,7 +416,7 @@ struct product_packet_impl<RowMajor, 0, Lhs, Rhs, Packet, LoadMode>
}; };
template<typename Lhs, typename Rhs, typename Packet, int LoadMode> template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct product_packet_impl<ColMajor, 0, Lhs, Rhs, Packet, LoadMode> struct product_packet_impl<ColMajor, 1, Lhs, Rhs, Packet, LoadMode>
{ {
typedef typename Lhs::Index Index; typedef typename Lhs::Index Index;
static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Packet &res) static EIGEN_STRONG_INLINE void run(Index row, Index col, const Lhs& lhs, const Rhs& rhs, Packet &res)
@ -406,6 +425,16 @@ struct product_packet_impl<ColMajor, 0, Lhs, Rhs, Packet, LoadMode>
} }
}; };
template<int StorageOrder, typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct product_packet_impl<StorageOrder, 0, Lhs, Rhs, Packet, LoadMode>
{
typedef typename Lhs::Index Index;
static EIGEN_STRONG_INLINE void run(Index /*row*/, Index /*col*/, const Lhs& /*lhs*/, const Rhs& /*rhs*/, Packet &res)
{
res = pset1<Packet>(0);
}
};
template<typename Lhs, typename Rhs, typename Packet, int LoadMode> template<typename Lhs, typename Rhs, typename Packet, int LoadMode>
struct product_packet_impl<RowMajor, Dynamic, Lhs, Rhs, Packet, LoadMode> struct product_packet_impl<RowMajor, Dynamic, Lhs, Rhs, Packet, LoadMode>
{ {