mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-30 07:44:10 +08:00
Added support for tensor concatenation as lvalue
This commit is contained in:
parent
00f048d44f
commit
1d3b64d32b
@ -81,7 +81,26 @@ class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsX
|
||||
const typename internal::remove_all<typename RhsXprType::Nested>::type&
|
||||
rhsExpression() const { return m_rhs_xpr; }
|
||||
|
||||
EIGEN_DEVICE_FUNC Axis axis() const { return m_axis; }
|
||||
EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other)
|
||||
{
|
||||
typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign;
|
||||
Assign assign(*this, other);
|
||||
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||
return *this;
|
||||
}
|
||||
|
||||
template<typename OtherDerived>
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other)
|
||||
{
|
||||
typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign;
|
||||
Assign assign(*this, other);
|
||||
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
|
||||
return *this;
|
||||
}
|
||||
|
||||
protected:
|
||||
typename LhsXprType::Nested m_lhs_xpr;
|
||||
@ -252,6 +271,73 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy
|
||||
const Axis m_axis;
|
||||
};
|
||||
|
||||
// Eval as lvalue
|
||||
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
|
||||
struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
|
||||
: public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
|
||||
{
|
||||
typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
|
||||
typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
|
||||
typedef typename Base::Dimensions Dimensions;
|
||||
enum {
|
||||
IsAligned = false,
|
||||
PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
|
||||
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
|
||||
: Base(op, device)
|
||||
{
|
||||
EIGEN_STATIC_ASSERT((Layout == ColMajor), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
}
|
||||
|
||||
typedef typename XprType::Index Index;
|
||||
typedef typename XprType::Scalar Scalar;
|
||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||
typedef typename XprType::PacketReturnType PacketReturnType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
|
||||
{
|
||||
// Collect dimension-wise indices (subs).
|
||||
array<Index, Base::NumDims> subs;
|
||||
for (int i = Base::NumDims - 1; i > 0; --i) {
|
||||
subs[i] = index / this->m_outputStrides[i];
|
||||
index -= subs[i] * this->m_outputStrides[i];
|
||||
}
|
||||
subs[0] = index;
|
||||
|
||||
const Dimensions& left_dims = this->m_leftImpl.dimensions();
|
||||
if (subs[this->m_axis] < left_dims[this->m_axis]) {
|
||||
Index left_index = subs[0];
|
||||
for (int i = 1; i < Base::NumDims; ++i) {
|
||||
left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
|
||||
}
|
||||
return this->m_leftImpl.coeffRef(left_index);
|
||||
} else {
|
||||
subs[this->m_axis] -= left_dims[this->m_axis];
|
||||
const Dimensions& right_dims = this->m_rightImpl.dimensions();
|
||||
Index right_index = subs[0];
|
||||
for (int i = 1; i < Base::NumDims; ++i) {
|
||||
right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
|
||||
}
|
||||
return this->m_rightImpl.coeffRef(right_index);
|
||||
}
|
||||
}
|
||||
|
||||
template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
void writePacket(Index index, const PacketReturnType& x)
|
||||
{
|
||||
static const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||
eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
|
||||
|
||||
EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
|
||||
PacketReturnType rslt = internal::pstore<PacketReturnType>(values, x);
|
||||
for (int i = 0; i < packetSize; ++i) {
|
||||
coeffRef(index+i) = values[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user