mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
Improved handling of 1d tensors
This commit is contained in:
parent
2dde63499c
commit
b1789c112b
@ -48,7 +48,7 @@ class BaseTensorContractionMapper {
|
|||||||
m_k_strides(k_strides) { }
|
m_k_strides(k_strides) { }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE void prefetch(int /*i*/) { }
|
EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
|
EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
|
||||||
@ -142,6 +142,13 @@ class BaseTensorContractionMapper {
|
|||||||
return IndexPair<Index>(linidx[0], linidx[1]);
|
return IndexPair<Index>(linidx[0], linidx[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns `size` unchanged, whatever value is passed in.
// NOTE(review): presumably this tells the matrix-vector kernel that no element
// in [0, size) may be assumed aligned -- confirm against the caller.
Index firstAligned(Index size) const {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
// Always reports a stride of 1.
// NOTE(review): presumably only consulted by the matrix-vector kernel for its
// alignment bookkeeping -- verify against the caller.
Index stride() const {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
const Tensor m_tensor;
|
const Tensor m_tensor;
|
||||||
const nocontract_t m_nocontract_strides;
|
const nocontract_t m_nocontract_strides;
|
||||||
@ -202,6 +209,18 @@ class TensorContractionSubMapper {
|
|||||||
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
|
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Packet load entry point: validates the requested packet type and alignment
// at compile time, then forwards to loadPacket(i).
//   PacketT       - must be exactly this mapper's `Packet` type.
//   AlignmentType - alignment requested by the caller; must be `Aligned` or
//                   `Unaligned`. It is only validated here and is not passed
//                   on to loadPacket.
//   i             - index handed to loadPacket unchanged.
template <typename PacketT, int AlignmentType>
|
||||||
|
EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
|
||||||
|
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
// Check the function's own AlignmentType on both sides of the `||`; the
// previous form compared the class-level `Alignment` on the right-hand side.
EIGEN_STATIC_ASSERT((AlignmentType == Aligned || AlignmentType == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
return loadPacket(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alignment query: unconditionally answers `false`; the index argument is
// ignored (its name is commented out) and the `Packet` template parameter is
// not used in the body.
template <typename Packet>
|
||||||
|
bool aligned(Index /*i*/) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const ParentMapper& m_base_mapper;
|
const ParentMapper& m_base_mapper;
|
||||||
const Index m_vert_offset;
|
const Index m_vert_offset;
|
||||||
@ -220,6 +239,7 @@ class TensorContractionInputMapper
|
|||||||
public:
|
public:
|
||||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
|
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous> Base;
|
||||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||||
|
typedef SubMapper VectorMapper;
|
||||||
|
|
||||||
TensorContractionInputMapper(const Tensor& tensor,
|
TensorContractionInputMapper(const Tensor& tensor,
|
||||||
const nocontract_t& nocontract_strides,
|
const nocontract_t& nocontract_strides,
|
||||||
@ -233,6 +253,10 @@ class TensorContractionInputMapper
|
|||||||
return SubMapper(*this, i, j);
|
return SubMapper(*this, i, j);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a VectorMapper over this mapper from the pair (i, j).
// VectorMapper is a typedef of SubMapper in this class, so the object is
// constructed exactly like the SubMapper returned just above.
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
||||||
|
return VectorMapper(*this, i, j);
|
||||||
|
}
|
||||||
|
|
||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
typedef typename packet_traits<Scalar>::half HalfPacket;
|
typedef typename packet_traits<Scalar>::half HalfPacket;
|
||||||
|
|
||||||
@ -306,6 +330,7 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
|
|||||||
public:
|
public:
|
||||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
|
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous> Base;
|
||||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||||
|
typedef SubMapper VectorMapper;
|
||||||
|
|
||||||
TensorContractionInputMapper(const Tensor& tensor,
|
TensorContractionInputMapper(const Tensor& tensor,
|
||||||
const nocontract_t& nocontract_strides,
|
const nocontract_t& nocontract_strides,
|
||||||
@ -319,6 +344,10 @@ class TensorContractionInputMapper<Scalar, Index, side, Tensor, nocontract_t, co
|
|||||||
return SubMapper(*this, i, j);
|
return SubMapper(*this, i, j);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a VectorMapper over this mapper from the pair (i, j).
// VectorMapper is a typedef of SubMapper in this specialization, so the
// object is constructed exactly like the SubMapper returned just above.
EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
|
||||||
|
return VectorMapper(*this, i, j);
|
||||||
|
}
|
||||||
|
|
||||||
typedef typename packet_traits<Scalar>::type Packet;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
|
EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
|
||||||
@ -592,41 +621,80 @@ struct TensorContractionEvaluatorBase
|
|||||||
if (this->m_lhs_inner_dim_contiguous) {
|
if (this->m_lhs_inner_dim_contiguous) {
|
||||||
if (this->m_rhs_inner_dim_contiguous) {
|
if (this->m_rhs_inner_dim_contiguous) {
|
||||||
if (this->m_rhs_inner_dim_reordered) {
|
if (this->m_rhs_inner_dim_reordered) {
|
||||||
static_cast<const Derived*>(this)->template evalTyped<true, true, true, Unaligned>(buffer);
|
static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
static_cast<const Derived*>(this)->template evalTyped<true, true, false, Unaligned>(buffer);
|
static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (this->m_rhs_inner_dim_reordered) {
|
if (this->m_rhs_inner_dim_reordered) {
|
||||||
static_cast<const Derived*>(this)->template evalTyped<true, false, true, Unaligned>(buffer);
|
static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
static_cast<const Derived*>(this)->template evalTyped<true, false, false, Unaligned>(buffer);
|
static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (this->m_rhs_inner_dim_contiguous) {
|
if (this->m_rhs_inner_dim_contiguous) {
|
||||||
if (this->m_rhs_inner_dim_reordered) {
|
if (this->m_rhs_inner_dim_reordered) {
|
||||||
static_cast<const Derived*>(this)->template evalTyped<false, true, true, Unaligned>(buffer);
|
static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
static_cast<const Derived*>(this)->template evalTyped<false, true, false, Unaligned>(buffer);
|
static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (this->m_rhs_inner_dim_reordered) {
|
if (this->m_rhs_inner_dim_reordered) {
|
||||||
static_cast<const Derived*>(this)->template evalTyped<false, false, true, Unaligned>(buffer);
|
static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
static_cast<const Derived*>(this)->template evalTyped<false, false, false, Unaligned>(buffer);
|
static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Evaluates the contraction as a matrix-vector product written into `buffer`.
// The left operand is treated as m_i_size rows by m_k_size columns. The first
// `rows` scalars of `buffer` are zero-filled, then
// internal::general_matrix_vector_product<...>::run is invoked on input
// mappers built from the left and right evaluators.
//   lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered
//       are forwarded to the mapper types below.
//   Alignment is not referenced in the body; both mappers are declared
//       Unaligned.
//   buffer must have room for at least `rows` scalars.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||||
|
void evalGemv(Scalar* buffer) const {
|
||||||
|
const Index rows = m_i_size;
|
||||||
|
const Index cols = m_k_size;
|
||||||
|
|
||||||
|
typedef typename internal::remove_const<typename LeftArgType::Scalar>::type LhsScalar;
|
||||||
|
typedef typename internal::remove_const<typename RightArgType::Scalar>::type RhsScalar;
|
||||||
|
typedef TensorEvaluator<LeftArgType, Device> LeftEvaluator;
|
||||||
|
typedef TensorEvaluator<RightArgType, Device> RightEvaluator;
|
||||||
|
const int lhs_packet_size = internal::packet_traits<LhsScalar>::size;
|
||||||
|
const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
|
||||||
|
// Left-hand mapper: the second-to-last template argument is hard-coded to
// `false` and the last one to Unaligned.
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
|
||||||
|
LeftEvaluator, left_nocontract_t,
|
||||||
|
contract_t, lhs_packet_size,
|
||||||
|
lhs_inner_dim_contiguous,
|
||||||
|
false, Unaligned> LhsMapper;
|
||||||
|
|
||||||
|
// Right-hand mapper: forwards rhs_inner_dim_reordered and is also Unaligned.
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
|
||||||
|
RightEvaluator, right_nocontract_t,
|
||||||
|
contract_t, rhs_packet_size,
|
||||||
|
rhs_inner_dim_contiguous,
|
||||||
|
rhs_inner_dim_reordered, Unaligned> RhsMapper;
|
||||||
|
|
||||||
|
LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
|
||||||
|
m_left_contracting_strides, m_k_strides);
|
||||||
|
RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
|
||||||
|
m_right_contracting_strides, m_k_strides);
|
||||||
|
|
||||||
|
// Passed as the last two arguments of run() below.
// NOTE(review): presumably the scaling factor and the result increment -- confirm.
const Scalar alpha(1);
|
||||||
|
const Index resIncr(1);
|
||||||
|
|
||||||
|
// Zero out the result buffer, which must hold at least `rows` scalars
// (rows * sizeof(Scalar) bytes).
// NOTE(review): presumably required because the product below accumulates
// into `buffer` rather than overwriting it -- confirm.
|
||||||
|
m_device.memset(buffer, 0, rows * sizeof(Scalar));
|
||||||
|
|
||||||
|
internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
|
||||||
|
rows, cols, lhs, rhs,
|
||||||
|
buffer, resIncr, alpha);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||||
m_leftImpl.cleanup();
|
m_leftImpl.cleanup();
|
||||||
m_rightImpl.cleanup();
|
m_rightImpl.cleanup();
|
||||||
@ -707,7 +775,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
|||||||
Base(op, device) { }
|
Base(op, device) { }
|
||||||
|
|
||||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
// Entry point for evaluating the product into `buffer`: when m_j_size is 1 it
// calls evalGemv and returns, otherwise it calls evalGemm. All four template
// arguments are forwarded unchanged to whichever routine is chosen.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||||
EIGEN_DEVICE_FUNC void evalTyped(Scalar* buffer) const {
|
void evalProduct(Scalar* buffer) const {
|
||||||
|
// m_j_size == 1 selects the matrix-vector routine.
if (this->m_j_size == 1) {
|
||||||
|
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||||
|
EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
|
||||||
// columns in left side, rows in right side
|
// columns in left side, rows in right side
|
||||||
const Index k = this->m_k_size;
|
const Index k = this->m_k_size;
|
||||||
|
|
||||||
|
@ -93,7 +93,17 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
|||||||
Base(op, device) {}
|
Base(op, device) {}
|
||||||
|
|
||||||
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
// Entry point for evaluating the product into `buffer`: when m_j_size is 1 it
// calls evalGemv and returns, otherwise it calls evalGemm. All four template
// arguments are forwarded unchanged to whichever routine is chosen.
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||||
void evalTyped(Scalar* buffer) const {
|
void evalProduct(Scalar* buffer) const {
|
||||||
|
// m_j_size == 1 selects the matrix-vector routine.
if (this->m_j_size == 1) {
|
||||||
|
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
|
||||||
|
void evalGemm(Scalar* buffer) const {
|
||||||
// columns in left side, rows in right side
|
// columns in left side, rows in right side
|
||||||
const Index k = this->m_k_size;
|
const Index k = this->m_k_size;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user