mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Improved the performance of the contraction of a 2D tensor with a 1D tensor by a factor of 3 or more. This helps speed up LSTM neural networks.
This commit is contained in:
parent
bd7d901da9
commit
d920d57f38
@ -32,7 +32,7 @@ enum {
|
|||||||
template<typename Scalar, typename Index, int side,
|
template<typename Scalar, typename Index, int side,
|
||||||
typename Tensor,
|
typename Tensor,
|
||||||
typename nocontract_t, typename contract_t,
|
typename nocontract_t, typename contract_t,
|
||||||
int packet_size, bool inner_dim_contiguous>
|
int packet_size, bool inner_dim_contiguous, int Alignment>
|
||||||
class SimpleTensorContractionMapper {
|
class SimpleTensorContractionMapper {
|
||||||
public:
|
public:
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -144,11 +144,11 @@ class SimpleTensorContractionMapper {
|
|||||||
return IndexPair<Index>(linidx[0], linidx[1]);
|
return IndexPair<Index>(linidx[0], linidx[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Index firstAligned(Index size) const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
|
||||||
return size;
|
return (Alignment == Aligned) ? 0 : size;
|
||||||
}
|
}
|
||||||
Index stride() const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
|
||||||
return 1;
|
return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -165,10 +165,10 @@ template<typename Scalar, typename Index, int side,
|
|||||||
typename nocontract_t, typename contract_t,
|
typename nocontract_t, typename contract_t,
|
||||||
int packet_size, bool inner_dim_contiguous,
|
int packet_size, bool inner_dim_contiguous,
|
||||||
bool inner_dim_reordered, int Alignment>
|
bool inner_dim_reordered, int Alignment>
|
||||||
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous>
|
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> ParentMapper;
|
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
BaseTensorContractionMapper(const Tensor& tensor,
|
BaseTensorContractionMapper(const Tensor& tensor,
|
||||||
@ -181,6 +181,7 @@ template<typename Scalar, typename Index, int side,
|
|||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||||
|
|
||||||
|
template <int AlignmentType = Alignment>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
|
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
|
||||||
// whole method makes column major assumption
|
// whole method makes column major assumption
|
||||||
@ -192,7 +193,7 @@ template<typename Scalar, typename Index, int side,
|
|||||||
if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
|
if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
|
||||||
const Index index = this->computeIndex(i, j);
|
const Index index = this->computeIndex(i, j);
|
||||||
eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
|
eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
|
||||||
return this->m_tensor.template packet<Alignment>(index);
|
return this->m_tensor.template packet<AlignmentType>(index);
|
||||||
}
|
}
|
||||||
|
|
||||||
const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
|
const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
|
||||||
@ -207,7 +208,7 @@ template<typename Scalar, typename Index, int side,
|
|||||||
(side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
|
(side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
|
||||||
(last - first) == (packet_size - 1)) {
|
(last - first) == (packet_size - 1)) {
|
||||||
|
|
||||||
return this->m_tensor.template packet<Alignment>(first);
|
return this->m_tensor.template packet<AlignmentType>(first);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_ALIGN_MAX Scalar data[packet_size];
|
EIGEN_ALIGN_MAX Scalar data[packet_size];
|
||||||
@ -223,6 +224,7 @@ template<typename Scalar, typename Index, int side,
|
|||||||
return pload<Packet>(data);
|
return pload<Packet>(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int AlignmentType = Alignment>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
|
EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
|
||||||
// whole method makes column major assumption
|
// whole method makes column major assumption
|
||||||
@ -230,7 +232,7 @@ template<typename Scalar, typename Index, int side,
|
|||||||
// don't need to add offsets for now (because operator handles that)
|
// don't need to add offsets for now (because operator handles that)
|
||||||
const Index half_packet_size = unpacket_traits<HalfPacket>::size;
|
const Index half_packet_size = unpacket_traits<HalfPacket>::size;
|
||||||
if (half_packet_size == packet_size) {
|
if (half_packet_size == packet_size) {
|
||||||
return loadPacket(i, j);
|
return loadPacket<AlignmentType>(i, j);
|
||||||
}
|
}
|
||||||
EIGEN_ALIGN_MAX Scalar data[half_packet_size];
|
EIGEN_ALIGN_MAX Scalar data[half_packet_size];
|
||||||
for (Index k = 0; k < half_packet_size; k++) {
|
for (Index k = 0; k < half_packet_size; k++) {
|
||||||
@ -246,10 +248,10 @@ template<typename Scalar, typename Index, int side,
|
|||||||
typename nocontract_t, typename contract_t,
|
typename nocontract_t, typename contract_t,
|
||||||
bool inner_dim_contiguous,
|
bool inner_dim_contiguous,
|
||||||
bool inner_dim_reordered, int Alignment>
|
bool inner_dim_reordered, int Alignment>
|
||||||
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous>
|
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> ParentMapper;
|
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
BaseTensorContractionMapper(const Tensor& tensor,
|
BaseTensorContractionMapper(const Tensor& tensor,
|
||||||
@ -260,13 +262,13 @@ class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, con
|
|||||||
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||||
|
|
||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
EIGEN_DEVICE_FUNC
|
template <int> EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
|
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
|
||||||
EIGEN_ALIGN_MAX Scalar data[1];
|
EIGEN_ALIGN_MAX Scalar data[1];
|
||||||
data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
|
data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
|
||||||
return pload<typename packet_traits<Scalar>::type>(data);
|
return pload<typename packet_traits<Scalar>::type>(data);
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC
|
template <int> EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
|
EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
|
||||||
return loadPacket(i, j);
|
return loadPacket(i, j);
|
||||||
}
|
}
|
||||||
@ -304,14 +306,14 @@ class TensorContractionSubMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
|
||||||
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
|
return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
|
||||||
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
|
return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
|
||||||
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
|
return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
|
||||||
@ -325,12 +327,12 @@ class TensorContractionSubMapper {
|
|||||||
template <typename PacketT, int AlignmentType>
|
template <typename PacketT, int AlignmentType>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
||||||
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
|
||||||
return loadPacket(i);
|
return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEVICE_FUNC bool aligned(Index) const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -741,17 +743,19 @@ struct TensorContractionEvaluatorBase
|
|||||||
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
|
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
|
||||||
const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size;
|
const Index lhs_packet_size = internal::packet_traits<LhsScalar>::size;
|
||||||
const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size;
|
const Index rhs_packet_size = internal::packet_traits<RhsScalar>::size;
|
||||||
|
const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
|
||||||
|
const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;
|
||||||
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
|
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
|
||||||
LeftEvaluator, left_nocontract_t,
|
LeftEvaluator, left_nocontract_t,
|
||||||
contract_t, lhs_packet_size,
|
contract_t, lhs_packet_size,
|
||||||
lhs_inner_dim_contiguous,
|
lhs_inner_dim_contiguous,
|
||||||
false, Unaligned> LhsMapper;
|
false, lhs_alignment> LhsMapper;
|
||||||
|
|
||||||
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
|
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
|
||||||
RightEvaluator, right_nocontract_t,
|
RightEvaluator, right_nocontract_t,
|
||||||
contract_t, rhs_packet_size,
|
contract_t, rhs_packet_size,
|
||||||
rhs_inner_dim_contiguous,
|
rhs_inner_dim_contiguous,
|
||||||
rhs_inner_dim_reordered, Unaligned> RhsMapper;
|
rhs_inner_dim_reordered, rhs_alignment> RhsMapper;
|
||||||
|
|
||||||
LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
|
LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
|
||||||
m_left_contracting_strides, m_k_strides);
|
m_left_contracting_strides, m_k_strides);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user