Implement true compile-time "if" for apply_rotation_in_the_plane. This fixes a compilation issue for vectorized real type with missing vectorization for complexes, e.g. AVX512.

This commit is contained in:
Gael Guennebaud 2017-09-06 10:02:49 +02:00
parent 80142362ac
commit b35d1ce4a5

View File

@ -309,35 +309,40 @@ inline void MatrixBase<Derived>::applyOnTheRight(Index p, Index q, const JacobiR
} }
namespace internal { namespace internal {
template<typename VectorX, typename VectorY, typename OtherScalar>
void /*EIGEN_DONT_INLINE*/ apply_rotation_in_the_plane(DenseBase<VectorX>& xpr_x, DenseBase<VectorY>& xpr_y, const JacobiRotation<OtherScalar>& j) template<typename Scalar, typename OtherScalar,
int SizeAtCompileTime, int MinAlignment, bool Vectorizable>
struct apply_rotation_in_the_plane_selector
{
static inline void run(Scalar *x, Index incrx, Scalar *y, Index incry, Index size, OtherScalar c, OtherScalar s)
{
for(Index i=0; i<size; ++i)
{
Scalar xi = *x;
Scalar yi = *y;
*x = c * xi + numext::conj(s) * yi;
*y = -s * xi + numext::conj(c) * yi;
x += incrx;
y += incry;
}
}
};
template<typename Scalar, typename OtherScalar,
int SizeAtCompileTime, int MinAlignment>
struct apply_rotation_in_the_plane_selector<Scalar,OtherScalar,SizeAtCompileTime,MinAlignment,true /* vectorizable */>
{
static inline void run(Scalar *x, Index incrx, Scalar *y, Index incry, Index size, OtherScalar c, OtherScalar s)
{ {
typedef typename VectorX::Scalar Scalar;
enum { enum {
PacketSize = packet_traits<Scalar>::size, PacketSize = packet_traits<Scalar>::size,
OtherPacketSize = packet_traits<OtherScalar>::size OtherPacketSize = packet_traits<OtherScalar>::size
}; };
typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<OtherScalar>::type OtherPacket; typedef typename packet_traits<OtherScalar>::type OtherPacket;
eigen_assert(xpr_x.size() == xpr_y.size());
Index size = xpr_x.size();
Index incrx = xpr_x.derived().innerStride();
Index incry = xpr_y.derived().innerStride();
Scalar* EIGEN_RESTRICT x = &xpr_x.derived().coeffRef(0);
Scalar* EIGEN_RESTRICT y = &xpr_y.derived().coeffRef(0);
OtherScalar c = j.c();
OtherScalar s = j.s();
if (c==OtherScalar(1) && s==OtherScalar(0))
return;
/*** dynamic-size vectorized paths ***/ /*** dynamic-size vectorized paths ***/
if(SizeAtCompileTime == Dynamic && ((incrx==1 && incry==1) || PacketSize == 1))
if(VectorX::SizeAtCompileTime == Dynamic &&
(VectorX::Flags & VectorY::Flags & PacketAccessBit) &&
(PacketSize == OtherPacketSize) &&
((incrx==1 && incry==1) || PacketSize == 1))
{ {
// both vectors are sequentially stored in memory => vectorization // both vectors are sequentially stored in memory => vectorization
enum { Peeling = 2 }; enum { Peeling = 2 };
@ -408,10 +413,7 @@ void /*EIGEN_DONT_INLINE*/ apply_rotation_in_the_plane(DenseBase<VectorX>& xpr_x
} }
/*** fixed-size vectorized path ***/ /*** fixed-size vectorized path ***/
else if(VectorX::SizeAtCompileTime != Dynamic && else if(SizeAtCompileTime != Dynamic && MinAlignment>0) // FIXME should be compared to the required alignment
(VectorX::Flags & VectorY::Flags & PacketAccessBit) &&
(PacketSize == OtherPacketSize) &&
(EIGEN_PLAIN_ENUM_MIN(evaluator<VectorX>::Alignment, evaluator<VectorY>::Alignment)>0)) // FIXME should be compared to the required alignment
{ {
const OtherPacket pc = pset1<OtherPacket>(c); const OtherPacket pc = pset1<OtherPacket>(c);
const OtherPacket ps = pset1<OtherPacket>(s); const OtherPacket ps = pset1<OtherPacket>(s);
@ -433,16 +435,36 @@ void /*EIGEN_DONT_INLINE*/ apply_rotation_in_the_plane(DenseBase<VectorX>& xpr_x
/*** non-vectorized path ***/ /*** non-vectorized path ***/
else else
{ {
for(Index i=0; i<size; ++i) apply_rotation_in_the_plane_selector<Scalar,OtherScalar,SizeAtCompileTime,MinAlignment,false>::run(x,incrx,y,incry,size,c,s);
}
}
};
template<typename VectorX, typename VectorY, typename OtherScalar>
void /*EIGEN_DONT_INLINE*/ apply_rotation_in_the_plane(DenseBase<VectorX>& xpr_x, DenseBase<VectorY>& xpr_y, const JacobiRotation<OtherScalar>& j)
{ {
Scalar xi = *x; typedef typename VectorX::Scalar Scalar;
Scalar yi = *y; const bool Vectorizable = (VectorX::Flags & VectorY::Flags & PacketAccessBit)
*x = c * xi + numext::conj(s) * yi; && (int(packet_traits<Scalar>::size) == int(packet_traits<OtherScalar>::size));
*y = -s * xi + numext::conj(c) * yi;
x += incrx; eigen_assert(xpr_x.size() == xpr_y.size());
y += incry; Index size = xpr_x.size();
} Index incrx = xpr_x.derived().innerStride();
} Index incry = xpr_y.derived().innerStride();
Scalar* EIGEN_RESTRICT x = &xpr_x.derived().coeffRef(0);
Scalar* EIGEN_RESTRICT y = &xpr_y.derived().coeffRef(0);
OtherScalar c = j.c();
OtherScalar s = j.s();
if (c==OtherScalar(1) && s==OtherScalar(0))
return;
apply_rotation_in_the_plane_selector<
Scalar,OtherScalar,
VectorX::SizeAtCompileTime,
EIGEN_PLAIN_ENUM_MIN(evaluator<VectorX>::Alignment, evaluator<VectorY>::Alignment),
Vectorizable>::run(x,incrx,y,incry,size,c,s);
} }
} // end namespace internal } // end namespace internal