diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index 9ab6b3565..b35b36475 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -161,6 +161,22 @@ struct TensorEvaluator, Device> } } } + + // Handle special format like NCHW, its input shape is '[1, N..., 1]' and + // broadcast shape is '[N, 1..., N]' + if (!oneByN && !nByOne) { + if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) { + nByOne = true; + oneByN = true; + for (int i = 1; i < NumDims-1; ++i) { + if (broadcast[i] != 1) { + nByOne = false; + oneByN = false; + break; + } + } + } + } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -256,24 +272,70 @@ struct TensorEvaluator, Device> } if (static_cast(Layout) == static_cast(ColMajor)) { - if (oneByN) { + if (oneByN && !nByOne) { return packetNByOne(index); - } else if (nByOne) { + } else if (!oneByN && nByOne) { return packetOneByN(index); + } else if (oneByN && nByOne) { + return packetOneByNByOne(index); } else { return packetColMajor(index); } } else { - if (oneByN) { + if (oneByN && !nByOne) { return packetOneByN(index); - } else if (nByOne) { + } else if (!oneByN && nByOne) { return packetNByOne(index); + } else if (oneByN && nByOne) { + return packetOneByNByOne(index); } else { return packetRowMajor(index); } } } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne + (Index index) const + { + EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index+PacketSize-1 < dimensions().TotalSize()); + + EIGEN_ALIGN_MAX typename internal::remove_const::type values[PacketSize]; + Index startDim, endDim; + Index inputIndex, outputOffset, batchedIndex; + + if (static_cast(Layout) == static_cast(ColMajor)) { + startDim = NumDims - 1; + endDim = 1; + } else { + startDim = 0; + endDim = NumDims - 2; + } + + batchedIndex = index % m_outputStrides[startDim]; + inputIndex = batchedIndex / m_outputStrides[endDim]; + outputOffset = batchedIndex % m_outputStrides[endDim]; + + if (outputOffset + PacketSize <= m_outputStrides[endDim]) { + values[0] = m_impl.coeff(inputIndex); + return internal::pload1(values); + } else { + for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) { + if (outputOffset + cur < m_outputStrides[endDim]) { + values[i] = m_impl.coeff(inputIndex); + } else { + ++inputIndex; + inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex); + values[i] = m_impl.coeff(inputIndex); + outputOffset = 0; + cur = 0; + } + } + return internal::pload(values); + } + } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const { diff --git a/unsupported/test/cxx11_tensor_broadcasting.cpp b/unsupported/test/cxx11_tensor_broadcasting.cpp index a9d268ea6..f0ff03184 100644 --- a/unsupported/test/cxx11_tensor_broadcasting.cpp +++ b/unsupported/test/cxx11_tensor_broadcasting.cpp @@ -238,6 +238,59 @@ static void test_simple_broadcasting_n_by_one() } } +template +static void test_simple_broadcasting_one_by_n_by_one_1d() +{ + Tensor tensor(1,7,1); + tensor.setRandom(); + array broadcasts; + broadcasts[0] = 5; + broadcasts[1] = 1; + broadcasts[2] = 13; + Tensor broadcasted; + broadcasted = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcasted.dimension(0), 5); + VERIFY_IS_EQUAL(broadcasted.dimension(1), 7); + VERIFY_IS_EQUAL(broadcasted.dimension(2), 13); + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 13; ++k) { + VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k)); + } + } + } +} + +template +static void test_simple_broadcasting_one_by_n_by_one_2d() +{ + Tensor tensor(1,7,13,1); + tensor.setRandom(); + array broadcasts; + broadcasts[0] = 5; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 19; + Tensor broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 5); + VERIFY_IS_EQUAL(broadcast.dimension(1), 7); + VERIFY_IS_EQUAL(broadcast.dimension(2), 13); + VERIFY_IS_EQUAL(broadcast.dimension(3), 19); + + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 7; ++j) { + for (int k = 0; k < 13; ++k) { + for (int l = 0; l < 19; ++l) { + VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l)); + } + } + } + } +} void test_cxx11_tensor_broadcasting() { @@ -253,4 +306,8 @@ void test_cxx11_tensor_broadcasting() CALL_SUBTEST(test_simple_broadcasting_n_by_one()); CALL_SUBTEST(test_simple_broadcasting_one_by_n()); CALL_SUBTEST(test_simple_broadcasting_n_by_one()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d()); }