Optimized broadcasting

This commit is contained in:
Benoit Steiner 2014-11-12 22:35:44 -08:00
parent c2d1074932
commit eeabf7975e

View File

@ -24,11 +24,13 @@ template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType> struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{ {
typedef typename XprType::Scalar Scalar; typedef typename XprType::Scalar Scalar;
typedef typename internal::packet_traits<Scalar>::type Packet; typedef traits<XprType> XprTraits;
typedef typename traits<XprType>::StorageKind StorageKind; typedef typename packet_traits<Scalar>::type Packet;
typedef typename traits<XprType>::Index Index; typedef typename XprTraits::StorageKind StorageKind;
typedef typename XprTraits::Index Index;
typedef typename XprType::Nested Nested; typedef typename XprType::Nested Nested;
typedef typename remove_reference<Nested>::type _Nested; typedef typename remove_reference<Nested>::type _Nested;
static const int NumDimensions = XprTraits::NumDimensions;
}; };
template<typename Broadcast, typename XprType> template<typename Broadcast, typename XprType>
@ -85,6 +87,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value; static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
typedef DSizes<Index, NumDims> Dimensions; typedef DSizes<Index, NumDims> Dimensions;
typedef typename XprType::Scalar Scalar; typedef typename XprType::Scalar Scalar;
typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
enum { enum {
IsAligned = false, IsAligned = false,
@ -129,10 +132,19 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
Index inputIndex = 0; Index inputIndex = 0;
for (int i = NumDims - 1; i > 0; --i) { for (int i = NumDims - 1; i > 0; --i) {
const Index idx = index / m_outputStrides[i]; const Index idx = index / m_outputStrides[i];
inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
eigen_assert(idx % m_impl.dimensions()[i] == 0);
} else {
inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
}
index -= idx * m_outputStrides[i]; index -= idx * m_outputStrides[i];
} }
inputIndex += (index % m_impl.dimensions()[0]); if (internal::index_statically_eq<Broadcast>()(0, 1)) {
eigen_assert(index < m_impl.dimensions()[0]);
inputIndex += index;
} else {
inputIndex += (index % m_impl.dimensions()[0]);
}
return m_impl.coeff(inputIndex); return m_impl.coeff(inputIndex);
} }
@ -150,10 +162,20 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
Index inputIndex = 0; Index inputIndex = 0;
for (int i = NumDims - 1; i > 0; --i) { for (int i = NumDims - 1; i > 0; --i) {
const Index idx = index / m_outputStrides[i]; const Index idx = index / m_outputStrides[i];
inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
eigen_assert(idx % m_impl.dimensions()[i] == 0);
} else {
inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
}
index -= idx * m_outputStrides[i]; index -= idx * m_outputStrides[i];
} }
const Index innermostLoc = index % m_impl.dimensions()[0]; Index innermostLoc;
if (internal::index_statically_eq<Broadcast>()(0, 1)) {
eigen_assert(index < m_impl.dimensions()[0]);
innermostLoc = index;
} else {
innermostLoc = index % m_impl.dimensions()[0];
}
inputIndex += innermostLoc; inputIndex += innermostLoc;
// Todo: this could be extended to the second dimension if we're not // Todo: this could be extended to the second dimension if we're not