Reduce the register pressure exerted by the tensor mappers whenever possible. This improves the performance of the contraction of a matrix with a vector by about 35%.

This commit is contained in:
Benoit Steiner 2016-01-20 14:51:48 -08:00
parent ebd3388ee6
commit 47076bf00e
2 changed files with 101 additions and 13 deletions

View File

@ -128,6 +128,7 @@ struct TensorContractionEvaluatorBase
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
CoordAccess = false, // to be implemented
RawAccess = true
};
// Most of the code is assuming that both input tensors are ColMajor. If the
@ -434,11 +435,11 @@ struct TensorContractionEvaluatorBase
}
template<int LoadMode>
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
return internal::ploadt<Packet, LoadMode>(m_result + index);
}
EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; }
protected:
// Prevent assignment

View File

@ -22,6 +22,54 @@ enum {
/*
* Implementation of the Eigen blas_data_mapper class for tensors.
*/
template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
enum {
DirectOffsets = false
};
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
eigen_assert(false && "unsupported");
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }
template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
{
return m_tensor.template packet<LoadMode>(index);
}
private:
const Tensor m_tensor;
};
template <typename Tensor> struct CoeffLoader<Tensor, true> {
enum {
DirectOffsets = true
};
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
m_data += offset;
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }
template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
{
return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
}
private:
typedef typename Tensor::Scalar Scalar;
const Scalar* m_data;
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
@ -40,6 +88,14 @@ class SimpleTensorContractionMapper {
m_contract_strides(contract_strides),
m_k_strides(k_strides) { }
enum {
DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets
};
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
m_tensor.offsetBuffer(offset);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }
@ -148,7 +204,7 @@ class SimpleTensorContractionMapper {
}
protected:
const Tensor m_tensor;
CoeffLoader<Tensor, Tensor::RawAccess> m_tensor;
const nocontract_t m_nocontract_strides;
const nocontract_t m_ij_strides;
const contract_t m_contract_strides;
@ -270,12 +326,6 @@ class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, con
}
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper;
template<typename Scalar, typename Index, int side,
typename Tensor,
@ -287,36 +337,70 @@ class TensorContractionSubMapper {
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
typedef Self LinearMapper;
enum {
// We can use direct offsets iff the parent mapper supports then and we can compute the strides.
// TODO: we should also enable direct offsets for the Rhs case.
UseDirectOffsets = (side == Lhs) && inner_dim_contiguous && ParentMapper::DirectOffsets
};
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
// Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
// this offset every time we attempt to access a coefficient.
if (UseDirectOffsets) {
Index stride = m_base_mapper.stride();
m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
}
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
if (UseDirectOffsets) {
return m_base_mapper(i, 0);
}
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
if (UseDirectOffsets) {
return m_base_mapper(i, j);
}
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
if (UseDirectOffsets) {
return m_base_mapper.template loadPacket<Alignment>(i, 0);
}
return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
if (UseDirectOffsets) {
return m_base_mapper.template loadPacket<Alignment>(i, j);
}
return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
if (UseDirectOffsets) {
return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
}
return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
if (UseDirectOffsets) {
m_base_mapper.storePacket(i, 0, p);
}
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
if (UseDirectOffsets) {
return LinearMapper(m_base_mapper, i, j);
}
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
@ -324,6 +408,9 @@ class TensorContractionSubMapper {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
if (UseDirectOffsets) {
return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
}
return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
}
@ -333,7 +420,7 @@ class TensorContractionSubMapper {
}
private:
const ParentMapper& m_base_mapper;
ParentMapper m_base_mapper;
const Index m_vert_offset;
const Index m_horiz_offset;
};