mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-30 15:54:13 +08:00
Added support for tensor concatenation as lvalue
This commit is contained in:
parent
00f048d44f
commit
1d3b64d32b
@ -81,7 +81,26 @@ class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsX
|
|||||||
const typename internal::remove_all<typename RhsXprType::Nested>::type&
|
const typename internal::remove_all<typename RhsXprType::Nested>::type&
|
||||||
rhsExpression() const { return m_rhs_xpr; }
|
rhsExpression() const { return m_rhs_xpr; }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC Axis axis() const { return m_axis; }
|
EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other)
|
||||||
|
{
|
||||||
|
typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign;
|
||||||
|
Assign assign(*this, other);
|
||||||
|
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherDerived>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other)
|
||||||
|
{
|
||||||
|
typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign;
|
||||||
|
Assign assign(*this, other);
|
||||||
|
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
typename LhsXprType::Nested m_lhs_xpr;
|
typename LhsXprType::Nested m_lhs_xpr;
|
||||||
@ -252,6 +271,73 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy
|
|||||||
const Axis m_axis;
|
const Axis m_axis;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Eval as lvalue
|
||||||
|
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
|
||||||
|
struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
|
||||||
|
: public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
|
||||||
|
{
|
||||||
|
typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
|
||||||
|
typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
|
||||||
|
typedef typename Base::Dimensions Dimensions;
|
||||||
|
enum {
|
||||||
|
IsAligned = false,
|
||||||
|
PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
|
||||||
|
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
|
||||||
|
};
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
|
||||||
|
: Base(op, device)
|
||||||
|
{
|
||||||
|
EIGEN_STATIC_ASSERT((Layout == ColMajor), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef typename XprType::Index Index;
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
|
||||||
|
{
|
||||||
|
// Collect dimension-wise indices (subs).
|
||||||
|
array<Index, Base::NumDims> subs;
|
||||||
|
for (int i = Base::NumDims - 1; i > 0; --i) {
|
||||||
|
subs[i] = index / this->m_outputStrides[i];
|
||||||
|
index -= subs[i] * this->m_outputStrides[i];
|
||||||
|
}
|
||||||
|
subs[0] = index;
|
||||||
|
|
||||||
|
const Dimensions& left_dims = this->m_leftImpl.dimensions();
|
||||||
|
if (subs[this->m_axis] < left_dims[this->m_axis]) {
|
||||||
|
Index left_index = subs[0];
|
||||||
|
for (int i = 1; i < Base::NumDims; ++i) {
|
||||||
|
left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
|
||||||
|
}
|
||||||
|
return this->m_leftImpl.coeffRef(left_index);
|
||||||
|
} else {
|
||||||
|
subs[this->m_axis] -= left_dims[this->m_axis];
|
||||||
|
const Dimensions& right_dims = this->m_rightImpl.dimensions();
|
||||||
|
Index right_index = subs[0];
|
||||||
|
for (int i = 1; i < Base::NumDims; ++i) {
|
||||||
|
right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
|
||||||
|
}
|
||||||
|
return this->m_rightImpl.coeffRef(right_index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
void writePacket(Index index, const PacketReturnType& x)
|
||||||
|
{
|
||||||
|
static const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||||
|
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
|
eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
|
||||||
|
|
||||||
|
EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
|
||||||
|
PacketReturnType rslt = internal::pstore<PacketReturnType>(values, x);
|
||||||
|
for (int i = 0; i < packetSize; ++i) {
|
||||||
|
coeffRef(index+i) = values[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user