Added support for tensor concatenation as lvalue

This commit is contained in:
Benoit Steiner 2015-02-17 09:57:41 -08:00
parent 00f048d44f
commit 1d3b64d32b

View File

@ -81,7 +81,26 @@ class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsX
const typename internal::remove_all<typename RhsXprType::Nested>::type& const typename internal::remove_all<typename RhsXprType::Nested>::type&
rhsExpression() const { return m_rhs_xpr; } rhsExpression() const { return m_rhs_xpr; }
EIGEN_DEVICE_FUNC Axis axis() const { return m_axis; } EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other)
{
typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign;
Assign assign(*this, other);
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
return *this;
}
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other)
{
typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign;
Assign assign(*this, other);
internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
return *this;
}
protected: protected:
typename LhsXprType::Nested m_lhs_xpr; typename LhsXprType::Nested m_lhs_xpr;
@ -252,6 +271,73 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy
const Axis m_axis; const Axis m_axis;
}; };
// Eval as lvalue
template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
: public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
{
typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
typedef typename Base::Dimensions Dimensions;
enum {
IsAligned = false,
PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
Layout = TensorEvaluator<LeftArgType, Device>::Layout,
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
: Base(op, device)
{
EIGEN_STATIC_ASSERT((Layout == ColMajor), YOU_MADE_A_PROGRAMMING_MISTAKE);
}
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
{
// Collect dimension-wise indices (subs).
array<Index, Base::NumDims> subs;
for (int i = Base::NumDims - 1; i > 0; --i) {
subs[i] = index / this->m_outputStrides[i];
index -= subs[i] * this->m_outputStrides[i];
}
subs[0] = index;
const Dimensions& left_dims = this->m_leftImpl.dimensions();
if (subs[this->m_axis] < left_dims[this->m_axis]) {
Index left_index = subs[0];
for (int i = 1; i < Base::NumDims; ++i) {
left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
}
return this->m_leftImpl.coeffRef(left_index);
} else {
subs[this->m_axis] -= left_dims[this->m_axis];
const Dimensions& right_dims = this->m_rightImpl.dimensions();
Index right_index = subs[0];
for (int i = 1; i < Base::NumDims; ++i) {
right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
}
return this->m_rightImpl.coeffRef(right_index);
}
}
template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void writePacket(Index index, const PacketReturnType& x)
{
static const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
PacketReturnType rslt = internal::pstore<PacketReturnType>(values, x);
for (int i = 0; i < packetSize; ++i) {
coeffRef(index+i) = values[i];
}
}
};
} // end namespace Eigen } // end namespace Eigen