Added support for tensor concatenation as lvalue

This commit is contained in:
Benoit Steiner 2015-02-17 09:54:40 -08:00
parent 159fb181c2
commit 00f048d44f
2 changed files with 26 additions and 0 deletions

View File

@ -524,6 +524,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
swap_layout() const { swap_layout() const {
return TensorLayoutSwapOp<Derived>(derived()); return TensorLayoutSwapOp<Derived>(derived());
} }
template <typename Axis, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorConcatenationOp<const Axis, Derived, OtherDerived>
concatenate(const OtherDerived& other, const Axis& axis) const {
return TensorConcatenationOp<const Axis, Derived, OtherDerived>(derived(), other.derived(), axis);
}
template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorReshapingOp<const NewDimensions, Derived> TensorReshapingOp<const NewDimensions, Derived>
reshape(const NewDimensions& newDimensions) const { reshape(const NewDimensions& newDimensions) const {

View File

@ -103,6 +103,25 @@ static void test_simple_concatenation()
// TODO(phli): Add test once we have a real vectorized implementation. // TODO(phli): Add test once we have a real vectorized implementation.
// static void test_vectorized_concatenation() {} // static void test_vectorized_concatenation() {}
static void test_concatenation_as_lvalue()
{
Tensor<int, 2> t1(2, 3);
Tensor<int, 2> t2(2, 3);
t1.setRandom();
t2.setRandom();
Tensor<int, 2> result(4, 3);
result.setRandom();
t1.concatenate(t2, 0) = result;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
VERIFY_IS_EQUAL(t1(i, j), result(i, j));
VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
}
}
}
void test_cxx11_tensor_concatenation() void test_cxx11_tensor_concatenation()
{ {
@ -113,4 +132,6 @@ void test_cxx11_tensor_concatenation()
CALL_SUBTEST(test_simple_concatenation<ColMajor>()); CALL_SUBTEST(test_simple_concatenation<ColMajor>());
CALL_SUBTEST(test_simple_concatenation<RowMajor>()); CALL_SUBTEST(test_simple_concatenation<RowMajor>());
// CALL_SUBTEST(test_vectorized_concatenation()); // CALL_SUBTEST(test_vectorized_concatenation());
CALL_SUBTEST(test_concatenation_as_lvalue());
} }