mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Reduce the register pressure exerted by the tensor mappers whenever possible. This improves the performance of the contraction of a matrix with a vector by about 35%.
This commit is contained in:
parent
ebd3388ee6
commit
47076bf00e
@ -128,6 +128,7 @@ struct TensorContractionEvaluatorBase
|
||||
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
||||
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
RawAccess = true
|
||||
};
|
||||
|
||||
// Most of the code is assuming that both input tensors are ColMajor. If the
|
||||
@ -434,11 +435,11 @@ struct TensorContractionEvaluatorBase
|
||||
}
|
||||
|
||||
template<int LoadMode>
|
||||
EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
|
||||
return internal::ploadt<Packet, LoadMode>(m_result + index);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; }
|
||||
|
||||
protected:
|
||||
// Prevent assignment
|
||||
|
@ -22,6 +22,54 @@ enum {
|
||||
/*
|
||||
* Implementation of the Eigen blas_data_mapper class for tensors.
|
||||
*/
|
||||
|
||||
template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
|
||||
enum {
|
||||
DirectOffsets = false
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
|
||||
eigen_assert(false && "unsupported");
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }
|
||||
|
||||
template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
|
||||
{
|
||||
return m_tensor.template packet<LoadMode>(index);
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
const Tensor m_tensor;
|
||||
};
|
||||
|
||||
template <typename Tensor> struct CoeffLoader<Tensor, true> {
|
||||
enum {
|
||||
DirectOffsets = true
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
|
||||
m_data += offset;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }
|
||||
|
||||
template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
|
||||
{
|
||||
return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
|
||||
}
|
||||
private:
|
||||
typedef typename Tensor::Scalar Scalar;
|
||||
const Scalar* m_data;
|
||||
};
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
@ -40,6 +88,14 @@ class SimpleTensorContractionMapper {
|
||||
m_contract_strides(contract_strides),
|
||||
m_k_strides(k_strides) { }
|
||||
|
||||
enum {
|
||||
DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets
|
||||
};
|
||||
|
||||
// Forwards the requested offset to the coefficient loader held in m_tensor.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void offsetBuffer(typename Tensor::Index offset)
{
  m_tensor.offsetBuffer(offset);
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }
|
||||
|
||||
@ -148,7 +204,7 @@ class SimpleTensorContractionMapper {
|
||||
}
|
||||
|
||||
protected:
|
||||
const Tensor m_tensor;
|
||||
CoeffLoader<Tensor, Tensor::RawAccess> m_tensor;
|
||||
const nocontract_t m_nocontract_strides;
|
||||
const nocontract_t m_ij_strides;
|
||||
const contract_t m_contract_strides;
|
||||
@ -270,12 +326,6 @@ class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, con
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
typename nocontract_t, typename contract_t,
|
||||
int packet_size,
|
||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||
class TensorContractionInputMapper;
|
||||
|
||||
template<typename Scalar, typename Index, int side,
|
||||
typename Tensor,
|
||||
@ -287,36 +337,70 @@ class TensorContractionSubMapper {
|
||||
typedef typename packet_traits<Scalar>::type Packet;
|
||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||
|
||||
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
|
||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
|
||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
|
||||
typedef Self LinearMapper;
|
||||
|
||||
enum {
|
||||
// We can use direct offsets iff the parent mapper supports them and we can compute the strides.
|
||||
// TODO: we should also enable direct offsets for the Rhs case.
|
||||
UseDirectOffsets = (side == Lhs) && inner_dim_contiguous && ParentMapper::DirectOffsets
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
|
||||
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
|
||||
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
|
||||
// Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
|
||||
// this offset every time we attempt to access a coefficient.
|
||||
if (UseDirectOffsets) {
|
||||
Index stride = m_base_mapper.stride();
|
||||
m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
|
||||
}
|
||||
}
|
||||
|
||||
// Reads the coefficient at row `i` of the first column of this sub-block.
// With UseDirectOffsets the constructor has already shifted the base mapper's
// buffer, so the stored offsets must not be applied a second time.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
  if (!UseDirectOffsets) {
    return m_base_mapper(i + m_vert_offset, m_horiz_offset);
  }
  return m_base_mapper(i, 0);
}
|
||||
// Reads the coefficient at (i, j) relative to this sub-block. With
// UseDirectOffsets the offsets are already part of the base mapper's buffer
// (see the constructor), so the indices are passed through unchanged.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
  if (!UseDirectOffsets) {
    return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
  }
  return m_base_mapper(i, j);
}
|
||||
|
||||
// Loads a packet starting at row `i` of the first column of this sub-block.
// The stored offsets are only added when they have not been folded into the
// base mapper's buffer (i.e. when UseDirectOffsets is false).
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
  if (!UseDirectOffsets) {
    return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
  }
  return m_base_mapper.template loadPacket<Alignment>(i, 0);
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
|
||||
return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
|
||||
if (UseDirectOffsets) {
|
||||
return m_base_mapper.template loadPacket<Alignment>(i, j);
|
||||
}
|
||||
return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
|
||||
}
|
||||
|
||||
// Loads a half packet starting at row `i` of the first column of this
// sub-block. The stored offsets are skipped when UseDirectOffsets is set,
// because the constructor already applied them to the base mapper's buffer.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
  if (!UseDirectOffsets) {
    return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
  }
  return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
}
|
||||
|
||||
// Stores packet `p` at row `i` of the first column of this sub-block.
//
// Exactly one store must be issued. When UseDirectOffsets is set, the
// constructor has already folded (m_vert_offset, m_horiz_offset) into the base
// mapper's buffer, so the packet goes to (i, 0). Otherwise the offsets are
// applied here. Previously the direct-offset branch fell through and performed
// a second store with the offsets applied again, i.e. at the wrong location.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
  if (UseDirectOffsets) {
    m_base_mapper.storePacket(i, 0, p);
  } else {
    m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
  }
}
|
||||
|
||||
// Builds a mapper positioned at (i, j) relative to this sub-block. The new
// mapper is constructed from this one's base mapper; this mapper's own offsets
// are only added when they are not already folded into that base mapper's
// buffer (UseDirectOffsets false).
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
  if (!UseDirectOffsets) {
    return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
  }
  return LinearMapper(m_base_mapper, i, j);
}
|
||||
|
||||
@ -324,6 +408,9 @@ class TensorContractionSubMapper {
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
||||
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
|
||||
if (UseDirectOffsets) {
|
||||
return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
|
||||
}
|
||||
return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
|
||||
}
|
||||
|
||||
@ -333,7 +420,7 @@ class TensorContractionSubMapper {
|
||||
}
|
||||
|
||||
private:
|
||||
const ParentMapper& m_base_mapper;
|
||||
ParentMapper m_base_mapper;
|
||||
const Index m_vert_offset;
|
||||
const Index m_horiz_offset;
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user