mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-21 09:09:36 +08:00
Added support for tensor concatenation as lvalue
This commit is contained in:
parent
159fb181c2
commit
00f048d44f
@ -524,6 +524,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
|
|||||||
swap_layout() const {
|
swap_layout() const {
|
||||||
return TensorLayoutSwapOp<Derived>(derived());
|
return TensorLayoutSwapOp<Derived>(derived());
|
||||||
}
|
}
|
||||||
|
template <typename Axis, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
TensorConcatenationOp<const Axis, Derived, OtherDerived>
|
||||||
|
concatenate(const OtherDerived& other, const Axis& axis) const {
|
||||||
|
return TensorConcatenationOp<const Axis, Derived, OtherDerived>(derived(), other.derived(), axis);
|
||||||
|
}
|
||||||
template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
TensorReshapingOp<const NewDimensions, Derived>
|
TensorReshapingOp<const NewDimensions, Derived>
|
||||||
reshape(const NewDimensions& newDimensions) const {
|
reshape(const NewDimensions& newDimensions) const {
|
||||||
|
@ -103,6 +103,25 @@ static void test_simple_concatenation()
|
|||||||
// TODO(phli): Add test once we have a real vectorized implementation.
|
// TODO(phli): Add test once we have a real vectorized implementation.
|
||||||
// static void test_vectorized_concatenation() {}
|
// static void test_vectorized_concatenation() {}
|
||||||
|
|
||||||
|
static void test_concatenation_as_lvalue()
|
||||||
|
{
|
||||||
|
Tensor<int, 2> t1(2, 3);
|
||||||
|
Tensor<int, 2> t2(2, 3);
|
||||||
|
t1.setRandom();
|
||||||
|
t2.setRandom();
|
||||||
|
|
||||||
|
Tensor<int, 2> result(4, 3);
|
||||||
|
result.setRandom();
|
||||||
|
t1.concatenate(t2, 0) = result;
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
VERIFY_IS_EQUAL(t1(i, j), result(i, j));
|
||||||
|
VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void test_cxx11_tensor_concatenation()
|
void test_cxx11_tensor_concatenation()
|
||||||
{
|
{
|
||||||
@ -113,4 +132,6 @@ void test_cxx11_tensor_concatenation()
|
|||||||
CALL_SUBTEST(test_simple_concatenation<ColMajor>());
|
CALL_SUBTEST(test_simple_concatenation<ColMajor>());
|
||||||
CALL_SUBTEST(test_simple_concatenation<RowMajor>());
|
CALL_SUBTEST(test_simple_concatenation<RowMajor>());
|
||||||
// CALL_SUBTEST(test_vectorized_concatenation());
|
// CALL_SUBTEST(test_vectorized_concatenation());
|
||||||
|
CALL_SUBTEST(test_concatenation_as_lvalue());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user