mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 19:59:05 +08:00
bug #1567: add optimized path for tensor broadcasting and 'Channel First' shape
This commit is contained in:
parent
ec323b7e66
commit
6190aa5632
@ -161,6 +161,22 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle special format like NCHW, its input shape is '[1, N..., 1]' and
|
||||||
|
// broadcast shape is '[N, 1..., N]'
|
||||||
|
if (!oneByN && !nByOne) {
|
||||||
|
if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
|
||||||
|
nByOne = true;
|
||||||
|
oneByN = true;
|
||||||
|
for (int i = 1; i < NumDims-1; ++i) {
|
||||||
|
if (broadcast[i] != 1) {
|
||||||
|
nByOne = false;
|
||||||
|
oneByN = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||||
@ -256,24 +272,70 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||||
if (oneByN) {
|
if (oneByN && !nByOne) {
|
||||||
return packetNByOne<LoadMode>(index);
|
return packetNByOne<LoadMode>(index);
|
||||||
} else if (nByOne) {
|
} else if (!oneByN && nByOne) {
|
||||||
return packetOneByN<LoadMode>(index);
|
return packetOneByN<LoadMode>(index);
|
||||||
|
} else if (oneByN && nByOne) {
|
||||||
|
return packetOneByNByOne<LoadMode>(index);
|
||||||
} else {
|
} else {
|
||||||
return packetColMajor<LoadMode>(index);
|
return packetColMajor<LoadMode>(index);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (oneByN) {
|
if (oneByN && !nByOne) {
|
||||||
return packetOneByN<LoadMode>(index);
|
return packetOneByN<LoadMode>(index);
|
||||||
} else if (nByOne) {
|
} else if (!oneByN && nByOne) {
|
||||||
return packetNByOne<LoadMode>(index);
|
return packetNByOne<LoadMode>(index);
|
||||||
|
} else if (oneByN && nByOne) {
|
||||||
|
return packetOneByNByOne<LoadMode>(index);
|
||||||
} else {
|
} else {
|
||||||
return packetRowMajor<LoadMode>(index);
|
return packetRowMajor<LoadMode>(index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne
|
||||||
|
(Index index) const
|
||||||
|
{
|
||||||
|
EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
|
eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
|
||||||
|
|
||||||
|
EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
|
||||||
|
Index startDim, endDim;
|
||||||
|
Index inputIndex, outputOffset, batchedIndex;
|
||||||
|
|
||||||
|
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||||
|
startDim = NumDims - 1;
|
||||||
|
endDim = 1;
|
||||||
|
} else {
|
||||||
|
startDim = 0;
|
||||||
|
endDim = NumDims - 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
batchedIndex = index % m_outputStrides[startDim];
|
||||||
|
inputIndex = batchedIndex / m_outputStrides[endDim];
|
||||||
|
outputOffset = batchedIndex % m_outputStrides[endDim];
|
||||||
|
|
||||||
|
if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
|
||||||
|
values[0] = m_impl.coeff(inputIndex);
|
||||||
|
return internal::pload1<PacketReturnType>(values);
|
||||||
|
} else {
|
||||||
|
for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
|
||||||
|
if (outputOffset + cur < m_outputStrides[endDim]) {
|
||||||
|
values[i] = m_impl.coeff(inputIndex);
|
||||||
|
} else {
|
||||||
|
++inputIndex;
|
||||||
|
inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
|
||||||
|
values[i] = m_impl.coeff(inputIndex);
|
||||||
|
outputOffset = 0;
|
||||||
|
cur = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return internal::pload<PacketReturnType>(values);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<int LoadMode>
|
template<int LoadMode>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
|
||||||
{
|
{
|
||||||
|
@ -238,6 +238,59 @@ static void test_simple_broadcasting_n_by_one()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_simple_broadcasting_one_by_n_by_one_1d()
|
||||||
|
{
|
||||||
|
Tensor<float, 3, DataLayout> tensor(1,7,1);
|
||||||
|
tensor.setRandom();
|
||||||
|
array<ptrdiff_t, 3> broadcasts;
|
||||||
|
broadcasts[0] = 5;
|
||||||
|
broadcasts[1] = 1;
|
||||||
|
broadcasts[2] = 13;
|
||||||
|
Tensor<float, 3, DataLayout> broadcasted;
|
||||||
|
broadcasted = tensor.broadcast(broadcasts);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(broadcasted.dimension(0), 5);
|
||||||
|
VERIFY_IS_EQUAL(broadcasted.dimension(1), 7);
|
||||||
|
VERIFY_IS_EQUAL(broadcasted.dimension(2), 13);
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
for (int j = 0; j < 7; ++j) {
|
||||||
|
for (int k = 0; k < 13; ++k) {
|
||||||
|
VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_simple_broadcasting_one_by_n_by_one_2d()
|
||||||
|
{
|
||||||
|
Tensor<float, 4, DataLayout> tensor(1,7,13,1);
|
||||||
|
tensor.setRandom();
|
||||||
|
array<ptrdiff_t, 4> broadcasts;
|
||||||
|
broadcasts[0] = 5;
|
||||||
|
broadcasts[1] = 1;
|
||||||
|
broadcasts[2] = 1;
|
||||||
|
broadcasts[3] = 19;
|
||||||
|
Tensor<float, 4, DataLayout> broadcast;
|
||||||
|
broadcast = tensor.broadcast(broadcasts);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(broadcast.dimension(0), 5);
|
||||||
|
VERIFY_IS_EQUAL(broadcast.dimension(1), 7);
|
||||||
|
VERIFY_IS_EQUAL(broadcast.dimension(2), 13);
|
||||||
|
VERIFY_IS_EQUAL(broadcast.dimension(3), 19);
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
for (int j = 0; j < 7; ++j) {
|
||||||
|
for (int k = 0; k < 13; ++k) {
|
||||||
|
for (int l = 0; l < 19; ++l) {
|
||||||
|
VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void test_cxx11_tensor_broadcasting()
|
void test_cxx11_tensor_broadcasting()
|
||||||
{
|
{
|
||||||
@ -253,4 +306,8 @@ void test_cxx11_tensor_broadcasting()
|
|||||||
CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
|
CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
|
||||||
CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
|
CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
|
||||||
CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
|
CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>());
|
||||||
|
CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user