Added the ability to pad a tensor using a non-zero value

This commit is contained in:
Benoit Steiner 2016-03-07 14:45:37 -08:00
parent 7f87cc3a3b
commit 769685e74e
2 changed files with 31 additions and 21 deletions

View File

@ -643,7 +643,12 @@ class TensorBase<Derived, ReadOnlyAccessors>
template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorPaddingOp<const PaddingDimensions, const Derived> const TensorPaddingOp<const PaddingDimensions, const Derived>
pad(const PaddingDimensions& padding) const { pad(const PaddingDimensions& padding) const {
return TensorPaddingOp<const PaddingDimensions, const Derived>(derived(), padding); return TensorPaddingOp<const PaddingDimensions, const Derived>(derived(), padding, internal::scalar_cast_op<int, Scalar>()(0));
}
template <typename PaddingDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorPaddingOp<const PaddingDimensions, const Derived>
pad(const PaddingDimensions& padding, const Scalar padding_value) const {
return TensorPaddingOp<const PaddingDimensions, const Derived>(derived(), padding, padding_value);
} }
template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename Shuffle> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorShufflingOp<const Shuffle, const Derived> const TensorShufflingOp<const Shuffle, const Derived>

View File

@ -16,7 +16,7 @@ namespace Eigen {
* \ingroup CXX11_Tensor_Module * \ingroup CXX11_Tensor_Module
* *
* \brief Tensor padding class. * \brief Tensor padding class.
* At the moment only 0-padding is supported. * At the moment only padding with a constant value is supported.
* *
*/ */
namespace internal { namespace internal {
@ -63,11 +63,13 @@ class TensorPaddingOp : public TensorBase<TensorPaddingOp<PaddingDimensions, Xpr
typedef typename Eigen::internal::traits<TensorPaddingOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorPaddingOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorPaddingOp>::Index Index; typedef typename Eigen::internal::traits<TensorPaddingOp>::Index Index;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPaddingOp(const XprType& expr, const PaddingDimensions& padding_dims) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPaddingOp(const XprType& expr, const PaddingDimensions& padding_dims, const Scalar padding_value)
: m_xpr(expr), m_padding_dims(padding_dims) {} : m_xpr(expr), m_padding_dims(padding_dims), m_padding_value(padding_value) {}
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const PaddingDimensions& padding() const { return m_padding_dims; } const PaddingDimensions& padding() const { return m_padding_dims; }
EIGEN_DEVICE_FUNC
Scalar padding_value() const { return m_padding_value; }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename XprType::Nested>::type& const typename internal::remove_all<typename XprType::Nested>::type&
@ -76,6 +78,7 @@ class TensorPaddingOp : public TensorBase<TensorPaddingOp<PaddingDimensions, Xpr
protected: protected:
typename XprType::Nested m_xpr; typename XprType::Nested m_xpr;
const PaddingDimensions m_padding_dims; const PaddingDimensions m_padding_dims;
const Scalar m_padding_value;
}; };
@ -97,7 +100,7 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
}; };
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device), m_padding(op.padding()) : m_impl(op.expression(), device), m_padding(op.padding()), m_paddingValue(op.padding_value())
{ {
// The padding op doesn't change the rank of the tensor. Directly padding a scalar would lead // The padding op doesn't change the rank of the tensor. Directly padding a scalar would lead
// to a vector, which doesn't make sense. Instead one should reshape the scalar into a vector // to a vector, which doesn't make sense. Instead one should reshape the scalar into a vector
@ -151,27 +154,27 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
for (int i = NumDims - 1; i > 0; --i) { for (int i = NumDims - 1; i > 0; --i) {
const Index idx = index / m_outputStrides[i]; const Index idx = index / m_outputStrides[i];
if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) { if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
return internal::scalar_cast_op<int, Scalar>()(0); return m_paddingValue;
} }
inputIndex += (idx - m_padding[i].first) * m_inputStrides[i]; inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
index -= idx * m_outputStrides[i]; index -= idx * m_outputStrides[i];
} }
if (index < m_padding[0].first || index >= m_dimensions[0] - m_padding[0].second) { if (index < m_padding[0].first || index >= m_dimensions[0] - m_padding[0].second) {
return internal::scalar_cast_op<int, Scalar>()(0); return m_paddingValue;
} }
inputIndex += (index - m_padding[0].first); inputIndex += (index - m_padding[0].first);
} else { } else {
for (int i = 0; i < NumDims - 1; ++i) { for (int i = 0; i < NumDims - 1; ++i) {
const Index idx = index / m_outputStrides[i+1]; const Index idx = index / m_outputStrides[i+1];
if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) { if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
return internal::scalar_cast_op<int, Scalar>()(0); return m_paddingValue;
} }
inputIndex += (idx - m_padding[i].first) * m_inputStrides[i]; inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
index -= idx * m_outputStrides[i+1]; index -= idx * m_outputStrides[i+1];
} }
if (index < m_padding[NumDims-1].first || if (index < m_padding[NumDims-1].first ||
index >= m_dimensions[NumDims-1] - m_padding[NumDims-1].second) { index >= m_dimensions[NumDims-1] - m_padding[NumDims-1].second) {
return internal::scalar_cast_op<int, Scalar>()(0); return m_paddingValue;
} }
inputIndex += (index - m_padding[NumDims-1].first); inputIndex += (index - m_padding[NumDims-1].first);
} }
@ -194,14 +197,14 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
{ {
const Index idx = coords[0]; const Index idx = coords[0];
if (idx < m_padding[0].first || idx >= m_dimensions[0] - m_padding[0].second) { if (idx < m_padding[0].first || idx >= m_dimensions[0] - m_padding[0].second) {
return internal::scalar_cast_op<int, Scalar>()(0); return m_paddingValue;
} }
inputIndex = idx - m_padding[0].first; inputIndex = idx - m_padding[0].first;
} }
for (int i = 1; i < NumDims; ++i) { for (int i = 1; i < NumDims; ++i) {
const Index idx = coords[i]; const Index idx = coords[i];
if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) { if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
return internal::scalar_cast_op<int, Scalar>()(0); return m_paddingValue;
} }
inputIndex += (idx - m_padding[i].first) * m_inputStrides[i]; inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
} }
@ -209,14 +212,14 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
{ {
const Index idx = coords[NumDims-1]; const Index idx = coords[NumDims-1];
if (idx < m_padding[NumDims-1].first || idx >= m_dimensions[NumDims-1] - m_padding[NumDims-1].second) { if (idx < m_padding[NumDims-1].first || idx >= m_dimensions[NumDims-1] - m_padding[NumDims-1].second) {
return internal::scalar_cast_op<int, Scalar>()(0); return m_paddingValue;
} }
inputIndex = idx - m_padding[NumDims-1].first; inputIndex = idx - m_padding[NumDims-1].first;
} }
for (int i = NumDims - 2; i >= 0; --i) { for (int i = NumDims - 2; i >= 0; --i) {
const Index idx = coords[i]; const Index idx = coords[i];
if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) { if (idx < m_padding[i].first || idx >= m_dimensions[i] - m_padding[i].second) {
return internal::scalar_cast_op<int, Scalar>()(0); return m_paddingValue;
} }
inputIndex += (idx - m_padding[i].first) * m_inputStrides[i]; inputIndex += (idx - m_padding[i].first) * m_inputStrides[i];
} }
@ -245,11 +248,11 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
if (last < lastPaddedLeft) { if (last < lastPaddedLeft) {
// all the coefficient are in the padding zone. // all the coefficient are in the padding zone.
return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); return internal::pset1<PacketReturnType>(m_paddingValue);
} }
else if (first >= firstPaddedRight && last < lastPaddedRight) { else if (first >= firstPaddedRight && last < lastPaddedRight) {
// all the coefficient are in the padding zone. // all the coefficient are in the padding zone.
return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); return internal::pset1<PacketReturnType>(m_paddingValue);
} }
else if (first >= lastPaddedLeft && last < firstPaddedRight) { else if (first >= lastPaddedLeft && last < firstPaddedRight) {
// all the coefficient are between the 2 padding zones. // all the coefficient are between the 2 padding zones.
@ -271,11 +274,11 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
if (last < lastPaddedLeft) { if (last < lastPaddedLeft) {
// all the coefficient are in the padding zone. // all the coefficient are in the padding zone.
return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); return internal::pset1<PacketReturnType>(m_paddingValue);
} }
else if (first >= firstPaddedRight && last < lastPaddedRight) { else if (first >= firstPaddedRight && last < lastPaddedRight) {
// all the coefficient are in the padding zone. // all the coefficient are in the padding zone.
return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); return internal::pset1<PacketReturnType>(m_paddingValue);
} }
else if (first >= lastPaddedLeft && last < firstPaddedRight) { else if (first >= lastPaddedLeft && last < firstPaddedRight) {
// all the coefficient are between the 2 padding zones. // all the coefficient are between the 2 padding zones.
@ -304,11 +307,11 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
if (last < lastPaddedLeft) { if (last < lastPaddedLeft) {
// all the coefficient are in the padding zone. // all the coefficient are in the padding zone.
return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); return internal::pset1<PacketReturnType>(m_paddingValue);
} }
else if (first >= firstPaddedRight && last < lastPaddedRight) { else if (first >= firstPaddedRight && last < lastPaddedRight) {
// all the coefficient are in the padding zone. // all the coefficient are in the padding zone.
return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); return internal::pset1<PacketReturnType>(m_paddingValue);
} }
else if (first >= lastPaddedLeft && last < firstPaddedRight) { else if (first >= lastPaddedLeft && last < firstPaddedRight) {
// all the coefficient are between the 2 padding zones. // all the coefficient are between the 2 padding zones.
@ -330,11 +333,11 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
if (last < lastPaddedLeft) { if (last < lastPaddedLeft) {
// all the coefficient are in the padding zone. // all the coefficient are in the padding zone.
return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); return internal::pset1<PacketReturnType>(m_paddingValue);
} }
else if (first >= firstPaddedRight && last < lastPaddedRight) { else if (first >= firstPaddedRight && last < lastPaddedRight) {
// all the coefficient are in the padding zone. // all the coefficient are in the padding zone.
return internal::pset1<PacketReturnType>(internal::scalar_cast_op<int, Scalar>()(0)); return internal::pset1<PacketReturnType>(m_paddingValue);
} }
else if (first >= lastPaddedLeft && last < firstPaddedRight) { else if (first >= lastPaddedLeft && last < firstPaddedRight) {
// all the coefficient are between the 2 padding zones. // all the coefficient are between the 2 padding zones.
@ -361,6 +364,8 @@ struct TensorEvaluator<const TensorPaddingOp<PaddingDimensions, ArgType>, Device
array<Index, NumDims> m_inputStrides; array<Index, NumDims> m_inputStrides;
TensorEvaluator<ArgType, Device> m_impl; TensorEvaluator<ArgType, Device> m_impl;
PaddingDimensions m_padding; PaddingDimensions m_padding;
Scalar m_paddingValue;
}; };