Reworked the tensor contraction mapper code to make it compile on Android

This commit is contained in:
Benoit Steiner 2015-10-23 09:33:41 -07:00
parent 29c3b7513e
commit a586fdaa91

View File

@ -33,14 +33,14 @@ template<typename Scalar, typename Index, int side,
typename Tensor, typename Tensor,
typename nocontract_t, typename contract_t, typename nocontract_t, typename contract_t,
int packet_size, bool inner_dim_contiguous> int packet_size, bool inner_dim_contiguous>
class BaseTensorContractionMapper { class SimpleTensorContractionMapper {
public: public:
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
BaseTensorContractionMapper(const Tensor& tensor, SimpleTensorContractionMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides, const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides, const nocontract_t& ij_strides,
const contract_t& contract_strides, const contract_t& contract_strides,
const contract_t& k_strides) : const contract_t& k_strides) :
m_tensor(tensor), m_tensor(tensor),
m_nocontract_strides(nocontract_strides), m_nocontract_strides(nocontract_strides),
m_ij_strides(ij_strides), m_ij_strides(ij_strides),
@ -160,104 +160,23 @@ class BaseTensorContractionMapper {
}; };
template<typename Scalar, typename Index, int side, template<typename Scalar, typename Index, int side,
typename Tensor, typename Tensor,
typename nocontract_t, typename contract_t, typename nocontract_t, typename contract_t,
int packet_size, size_t packet_size, bool inner_dim_contiguous,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper; class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous>
{
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
public: public:
typedef typename packet_traits<Scalar>::type Packet; typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> ParentMapper;
typedef typename packet_traits<Scalar>::half HalfPacket;
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
typedef Self LinearMapper;
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
template <typename PacketT, int AlignmentType>
EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
return loadPacket(i);
}
template <typename Packet>
bool aligned(Index /*i*/) const {
return false;
}
private:
const ParentMapper& m_base_mapper;
const Index m_vert_offset;
const Index m_horiz_offset;
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
int packet_size = (Tensor::PacketAccess ? packet_traits<Scalar>::size : 1),
bool inner_dim_contiguous = false, bool inner_dim_reordered = (side != Lhs), int Alignment=Unaligned>
class TensorContractionInputMapper
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> {
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides)
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { BaseTensorContractionMapper(const Tensor& tensor,
return SubMapper(*this, i, j); const nocontract_t& nocontract_strides,
} const nocontract_t& ij_strides,
const contract_t& contract_strides,
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { const contract_t& k_strides) :
return VectorMapper(*this, i, j); ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
}
typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket; typedef typename packet_traits<Scalar>::half HalfPacket;
@ -322,35 +241,23 @@ class TensorContractionInputMapper
}; };
template<typename Scalar, typename Index, int side, template<typename Scalar, typename Index, int side,
typename Tensor, typename Tensor,
typename nocontract_t, typename contract_t, typename nocontract_t, typename contract_t,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> bool inner_dim_contiguous,
class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> bool inner_dim_reordered, int Alignment>
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> { class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous>
{
public: public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base; typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> ParentMapper;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides)
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { BaseTensorContractionMapper(const Tensor& tensor,
return SubMapper(*this, i, j); const nocontract_t& nocontract_strides,
} const nocontract_t& ij_strides,
const contract_t& contract_strides,
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { const contract_t& k_strides) :
return VectorMapper(*this, i, j); ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
}
typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::type Packet;
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -365,6 +272,106 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
} }
}; };
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
size_t packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper;
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
size_t packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
public:
typedef typename packet_traits<Scalar>::type Packet;
typedef typename packet_traits<Scalar>::half HalfPacket;
typedef TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
typedef Self LinearMapper;
EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
: m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
return m_base_mapper(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
return m_base_mapper.loadPacket(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
return m_base_mapper.loadPacket(i + m_vert_offset, j + m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
return m_base_mapper.loadHalfPacket(i + m_vert_offset, m_horiz_offset);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
template <typename PacketT, int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
return loadPacket(i);
}
template <typename Packet>
EIGEN_DEVICE_FUNC bool aligned(Index) const {
return false;
}
private:
const ParentMapper& m_base_mapper;
const Index m_vert_offset;
const Index m_horiz_offset;
};
template<typename Scalar, typename Index, int side,
typename Tensor,
typename nocontract_t, typename contract_t,
size_t packet_size,
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
public:
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
typedef SubMapper VectorMapper;
EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
const nocontract_t& ij_strides,
const contract_t& contract_strides,
const contract_t& k_strides)
: Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
return SubMapper(*this, i, j);
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(*this, i, j);
}
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType> template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >