mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 20:26:03 +08:00
Reworked the tensor contraction mapper code to make it compile on Android
This commit is contained in:
parent
29c3b7513e
commit
a586fdaa91
@ -33,10 +33,10 @@ template<typename Scalar, typename Index, int side,
|
|||||||
typename Tensor,
|
typename Tensor,
|
||||||
typename nocontract_t, typename contract_t,
|
typename nocontract_t, typename contract_t,
|
||||||
int packet_size, bool inner_dim_contiguous>
|
int packet_size, bool inner_dim_contiguous>
|
||||||
class BaseTensorContractionMapper {
|
class SimpleTensorContractionMapper {
|
||||||
public:
|
public:
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
BaseTensorContractionMapper(const Tensor& tensor,
|
SimpleTensorContractionMapper(const Tensor& tensor,
|
||||||
const nocontract_t& nocontract_strides,
|
const nocontract_t& nocontract_strides,
|
||||||
const nocontract_t& ij_strides,
|
const nocontract_t& ij_strides,
|
||||||
const contract_t& contract_strides,
|
const contract_t& contract_strides,
|
||||||
@ -160,104 +160,23 @@ class BaseTensorContractionMapper {
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int side,
|
template<typename Scalar, typename Index, int side,
|
||||||
typename Tensor,
|
typename Tensor,
|
||||||
typename nocontract_t, typename contract_t,
|
typename nocontract_t, typename contract_t,
|
||||||
int packet_size,
|
size_t packet_size, bool inner_dim_contiguous,
|
||||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
bool inner_dim_reordered, int Alignment>
|
||||||
class TensorContractionInputMapper;
|
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous>
|
||||||
|
{
|
||||||
template<typename Scalar, typename Index, int side,
|
|
||||||
typename Tensor,
|
|
||||||
typename nocontract_t, typename contract_t,
|
|
||||||
int packet_size,
|
|
||||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
|
||||||
class TensorContractionSubMapper {
|
|
||||||
public:
|
public:
|
||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> ParentMapper;
|
||||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
|
||||||
|
|
||||||
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
|
EIGEN_DEVICE_FUNC
|
||||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
|
BaseTensorContractionMapper(const Tensor& tensor,
|
||||||
typedef Self LinearMapper;
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
|
|
||||||
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
|
|
||||||
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
|
|
||||||
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
|
|
||||||
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
|
|
||||||
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
|
|
||||||
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
|
|
||||||
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
|
|
||||||
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PacketT, int AlignmentType>
|
|
||||||
EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
|
||||||
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
|
||||||
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
|
||||||
return loadPacket(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Packet>
|
|
||||||
bool aligned(Index /*i*/) const {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const ParentMapper& m_base_mapper;
|
|
||||||
const Index m_vert_offset;
|
|
||||||
const Index m_horiz_offset;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int side,
|
|
||||||
typename Tensor,
|
|
||||||
typename nocontract_t, typename contract_t,
|
|
||||||
int packet_size = (Tensor::PacketAccess ? packet_traits<Scalar>::size : 1),
|
|
||||||
bool inner_dim_contiguous = false, bool inner_dim_reordered = (side != Lhs), int Alignment=Unaligned>
|
|
||||||
class TensorContractionInputMapper
|
|
||||||
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> {
|
|
||||||
|
|
||||||
public:
|
|
||||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
|
|
||||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
|
||||||
typedef SubMapper VectorMapper;
|
|
||||||
|
|
||||||
TensorContractionInputMapper(const Tensor& tensor,
|
|
||||||
const nocontract_t& nocontract_strides,
|
const nocontract_t& nocontract_strides,
|
||||||
const nocontract_t& ij_strides,
|
const nocontract_t& ij_strides,
|
||||||
const contract_t& contract_strides,
|
const contract_t& contract_strides,
|
||||||
const contract_t& k_strides)
|
const contract_t& k_strides) :
|
||||||
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
|
||||||
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
|
|
||||||
return SubMapper(*this, i, j);
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
|
||||||
return VectorMapper(*this, i, j);
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||||
@ -322,35 +241,23 @@ class TensorContractionInputMapper
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int side,
|
template<typename Scalar, typename Index, int side,
|
||||||
typename Tensor,
|
typename Tensor,
|
||||||
typename nocontract_t, typename contract_t,
|
typename nocontract_t, typename contract_t,
|
||||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
bool inner_dim_contiguous,
|
||||||
class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment>
|
bool inner_dim_reordered, int Alignment>
|
||||||
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> {
|
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous>
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
|
typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> ParentMapper;
|
||||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
|
||||||
typedef SubMapper VectorMapper;
|
|
||||||
|
|
||||||
TensorContractionInputMapper(const Tensor& tensor,
|
EIGEN_DEVICE_FUNC
|
||||||
|
BaseTensorContractionMapper(const Tensor& tensor,
|
||||||
const nocontract_t& nocontract_strides,
|
const nocontract_t& nocontract_strides,
|
||||||
const nocontract_t& ij_strides,
|
const nocontract_t& ij_strides,
|
||||||
const contract_t& contract_strides,
|
const contract_t& contract_strides,
|
||||||
const contract_t& k_strides)
|
const contract_t& k_strides) :
|
||||||
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
|
||||||
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
|
|
||||||
return SubMapper(*this, i, j);
|
|
||||||
}
|
|
||||||
|
|
||||||
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
|
||||||
return VectorMapper(*this, i, j);
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -365,6 +272,106 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename Scalar, typename Index, int side,
|
||||||
|
typename Tensor,
|
||||||
|
typename nocontract_t, typename contract_t,
|
||||||
|
size_t packet_size,
|
||||||
|
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||||
|
class TensorContractionInputMapper;
|
||||||
|
|
||||||
|
template<typename Scalar, typename Index, int side,
|
||||||
|
typename Tensor,
|
||||||
|
typename nocontract_t, typename contract_t,
|
||||||
|
size_t packet_size,
|
||||||
|
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||||
|
class TensorContractionSubMapper {
|
||||||
|
public:
|
||||||
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
|
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||||
|
|
||||||
|
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
|
||||||
|
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
|
||||||
|
typedef Self LinearMapper;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
|
||||||
|
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
|
||||||
|
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
|
||||||
|
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
|
||||||
|
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
|
||||||
|
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
|
||||||
|
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
|
||||||
|
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
|
||||||
|
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PacketT, int AlignmentType>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
||||||
|
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
return loadPacket(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC bool aligned(Index) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const ParentMapper& m_base_mapper;
|
||||||
|
const Index m_vert_offset;
|
||||||
|
const Index m_horiz_offset;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template<typename Scalar, typename Index, int side,
|
||||||
|
typename Tensor,
|
||||||
|
typename nocontract_t, typename contract_t,
|
||||||
|
size_t packet_size,
|
||||||
|
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||||
|
class TensorContractionInputMapper
|
||||||
|
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
|
||||||
|
|
||||||
|
public:
|
||||||
|
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
|
||||||
|
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||||
|
typedef SubMapper VectorMapper;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
|
||||||
|
const nocontract_t& nocontract_strides,
|
||||||
|
const nocontract_t& ij_strides,
|
||||||
|
const contract_t& contract_strides,
|
||||||
|
const contract_t& k_strides)
|
||||||
|
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
|
||||||
|
return SubMapper(*this, i, j);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
||||||
|
return VectorMapper(*this, i, j);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
|
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
|
||||||
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
|
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
|
||||||
|
Loading…
x
Reference in New Issue
Block a user