mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-01 17:50:40 +08:00
Optimized broadcasting
This commit is contained in:
parent
c2d1074932
commit
eeabf7975e
@ -24,11 +24,13 @@ template<typename Broadcast, typename XprType>
|
|||||||
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
|
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
|
||||||
{
|
{
|
||||||
typedef typename XprType::Scalar Scalar;
|
typedef typename XprType::Scalar Scalar;
|
||||||
typedef typename internal::packet_traits<Scalar>::type Packet;
|
typedef traits<XprType> XprTraits;
|
||||||
typedef typename traits<XprType>::StorageKind StorageKind;
|
typedef typename packet_traits<Scalar>::type Packet;
|
||||||
typedef typename traits<XprType>::Index Index;
|
typedef typename XprTraits::StorageKind StorageKind;
|
||||||
|
typedef typename XprTraits::Index Index;
|
||||||
typedef typename XprType::Nested Nested;
|
typedef typename XprType::Nested Nested;
|
||||||
typedef typename remove_reference<Nested>::type _Nested;
|
typedef typename remove_reference<Nested>::type _Nested;
|
||||||
|
static const int NumDimensions = XprTraits::NumDimensions;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Broadcast, typename XprType>
|
template<typename Broadcast, typename XprType>
|
||||||
@ -85,6 +87,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
|
|||||||
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
||||||
typedef DSizes<Index, NumDims> Dimensions;
|
typedef DSizes<Index, NumDims> Dimensions;
|
||||||
typedef typename XprType::Scalar Scalar;
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
IsAligned = false,
|
IsAligned = false,
|
||||||
@ -129,10 +132,19 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
|
|||||||
Index inputIndex = 0;
|
Index inputIndex = 0;
|
||||||
for (int i = NumDims - 1; i > 0; --i) {
|
for (int i = NumDims - 1; i > 0; --i) {
|
||||||
const Index idx = index / m_outputStrides[i];
|
const Index idx = index / m_outputStrides[i];
|
||||||
inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
|
if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
|
||||||
|
eigen_assert(idx % m_impl.dimensions()[i] == 0);
|
||||||
|
} else {
|
||||||
|
inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
|
||||||
|
}
|
||||||
index -= idx * m_outputStrides[i];
|
index -= idx * m_outputStrides[i];
|
||||||
}
|
}
|
||||||
inputIndex += (index % m_impl.dimensions()[0]);
|
if (internal::index_statically_eq<Broadcast>()(0, 1)) {
|
||||||
|
eigen_assert(index < m_impl.dimensions()[0]);
|
||||||
|
inputIndex += index;
|
||||||
|
} else {
|
||||||
|
inputIndex += (index % m_impl.dimensions()[0]);
|
||||||
|
}
|
||||||
return m_impl.coeff(inputIndex);
|
return m_impl.coeff(inputIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,10 +162,20 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
|
|||||||
Index inputIndex = 0;
|
Index inputIndex = 0;
|
||||||
for (int i = NumDims - 1; i > 0; --i) {
|
for (int i = NumDims - 1; i > 0; --i) {
|
||||||
const Index idx = index / m_outputStrides[i];
|
const Index idx = index / m_outputStrides[i];
|
||||||
inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
|
if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
|
||||||
|
eigen_assert(idx % m_impl.dimensions()[i] == 0);
|
||||||
|
} else {
|
||||||
|
inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
|
||||||
|
}
|
||||||
index -= idx * m_outputStrides[i];
|
index -= idx * m_outputStrides[i];
|
||||||
}
|
}
|
||||||
const Index innermostLoc = index % m_impl.dimensions()[0];
|
Index innermostLoc;
|
||||||
|
if (internal::index_statically_eq<Broadcast>()(0, 1)) {
|
||||||
|
eigen_assert(index < m_impl.dimensions()[0]);
|
||||||
|
innermostLoc = index;
|
||||||
|
} else {
|
||||||
|
innermostLoc = index % m_impl.dimensions()[0];
|
||||||
|
}
|
||||||
inputIndex += innermostLoc;
|
inputIndex += innermostLoc;
|
||||||
|
|
||||||
// Todo: this could be extended to the second dimension if we're not
|
// Todo: this could be extended to the second dimension if we're not
|
||||||
|
Loading…
x
Reference in New Issue
Block a user