Optimize check_rows_cols_for_overflow

This commit is contained in:
Charles Schlosser 2023-07-10 17:40:17 +00:00
parent 9297aae66f
commit 21cd3fe209
3 changed files with 35 additions and 14 deletions

View File

@ -28,21 +28,40 @@ namespace Eigen {
namespace internal { namespace internal {
template<int MaxSizeAtCompileTime> struct check_rows_cols_for_overflow { template <int MaxSizeAtCompileTime, int MaxRowsAtCompileTime, int MaxColsAtCompileTime>
struct check_rows_cols_for_overflow {
EIGEN_STATIC_ASSERT(MaxRowsAtCompileTime * MaxColsAtCompileTime == MaxSizeAtCompileTime,YOU MADE A PROGRAMMING MISTAKE)
template <typename Index> template <typename Index>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index, Index) {} EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index, Index) {}
}; };
template<> struct check_rows_cols_for_overflow<Dynamic> { template <int MaxRowsAtCompileTime>
struct check_rows_cols_for_overflow<Dynamic, MaxRowsAtCompileTime, Dynamic> {
template <typename Index>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index, Index cols) {
constexpr Index MaxIndex = NumTraits<Index>::highest();
bool error = cols > MaxIndex / MaxRowsAtCompileTime;
if (error) throw_std_bad_alloc();
}
};
template <int MaxColsAtCompileTime>
struct check_rows_cols_for_overflow<Dynamic, Dynamic, MaxColsAtCompileTime> {
template <typename Index>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index rows, Index) {
constexpr Index MaxIndex = NumTraits<Index>::highest();
bool error = rows > MaxIndex / MaxColsAtCompileTime;
if (error) throw_std_bad_alloc();
}
};
template <>
struct check_rows_cols_for_overflow<Dynamic, Dynamic, Dynamic> {
template <typename Index> template <typename Index>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index rows, Index cols) { EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index rows, Index cols) {
// http://hg.mozilla.org/mozilla-central/file/6c8a909977d3/xpcom/ds/CheckedInt.h#l242 constexpr Index MaxIndex = NumTraits<Index>::highest();
// we assume Index is signed bool error = cols == 0 ? false : (rows > MaxIndex / cols);
Index max_index = (std::size_t(1) << (8 * sizeof(Index) - 1)) - 1; // assume Index is signed if (error) throw_std_bad_alloc();
bool error = (rows == 0 || cols == 0) ? false
: (rows > max_index / cols);
if (error)
throw_std_bad_alloc();
} }
}; };
@ -268,7 +287,7 @@ class PlainObjectBase : public internal::dense_xpr_base<Derived>::type
&& internal::check_implication(RowsAtCompileTime==Dynamic && MaxRowsAtCompileTime!=Dynamic, rows<=MaxRowsAtCompileTime) && internal::check_implication(RowsAtCompileTime==Dynamic && MaxRowsAtCompileTime!=Dynamic, rows<=MaxRowsAtCompileTime)
&& internal::check_implication(ColsAtCompileTime==Dynamic && MaxColsAtCompileTime!=Dynamic, cols<=MaxColsAtCompileTime) && internal::check_implication(ColsAtCompileTime==Dynamic && MaxColsAtCompileTime!=Dynamic, cols<=MaxColsAtCompileTime)
&& rows>=0 && cols>=0 && "Invalid sizes when resizing a matrix or array."); && rows>=0 && cols>=0 && "Invalid sizes when resizing a matrix or array.");
internal::check_rows_cols_for_overflow<MaxSizeAtCompileTime>::run(rows, cols); internal::check_rows_cols_for_overflow<MaxSizeAtCompileTime, MaxRowsAtCompileTime, MaxColsAtCompileTime>::run(rows, cols);
#ifdef EIGEN_INITIALIZE_COEFFS #ifdef EIGEN_INITIALIZE_COEFFS
Index size = rows*cols; Index size = rows*cols;
bool size_changed = size != this->size(); bool size_changed = size != this->size();
@ -340,7 +359,7 @@ class PlainObjectBase : public internal::dense_xpr_base<Derived>::type
EIGEN_STRONG_INLINE void resizeLike(const EigenBase<OtherDerived>& _other) EIGEN_STRONG_INLINE void resizeLike(const EigenBase<OtherDerived>& _other)
{ {
const OtherDerived& other = _other.derived(); const OtherDerived& other = _other.derived();
internal::check_rows_cols_for_overflow<MaxSizeAtCompileTime>::run(other.rows(), other.cols()); internal::check_rows_cols_for_overflow<MaxSizeAtCompileTime, MaxRowsAtCompileTime, MaxColsAtCompileTime>::run(other.rows(), other.cols());
const Index othersize = other.rows()*other.cols(); const Index othersize = other.rows()*other.cols();
if(RowsAtCompileTime == 1) if(RowsAtCompileTime == 1)
{ {
@ -965,7 +984,7 @@ struct conservative_resize_like_impl
&& (( Derived::IsRowMajor && _this.cols() == cols) || // row-major and we change only the number of rows && (( Derived::IsRowMajor && _this.cols() == cols) || // row-major and we change only the number of rows
(!Derived::IsRowMajor && _this.rows() == rows) )) // column-major and we change only the number of columns (!Derived::IsRowMajor && _this.rows() == rows) )) // column-major and we change only the number of columns
{ {
internal::check_rows_cols_for_overflow<Derived::MaxSizeAtCompileTime>::run(rows, cols); internal::check_rows_cols_for_overflow<Derived::MaxSizeAtCompileTime, Derived::MaxRowsAtCompileTime, Derived::MaxColsAtCompileTime>::run(rows, cols);
_this.derived().m_storage.conservativeResize(rows*cols,rows,cols); _this.derived().m_storage.conservativeResize(rows*cols,rows,cols);
} }
else else

View File

@ -311,7 +311,9 @@ constexpr inline unsigned compute_matrix_flags(int Options) {
} }
constexpr inline int size_at_compile_time(int rows, int cols) { constexpr inline int size_at_compile_time(int rows, int cols) {
return (rows==Dynamic || cols==Dynamic) ? Dynamic : rows * cols; if (rows == 0 || cols == 0) return 0;
if (rows == Dynamic || cols == Dynamic) return Dynamic;
return rows * cols;
} }
template<typename XprType> struct size_of_xpr_at_compile_time template<typename XprType> struct size_of_xpr_at_compile_time

View File

@ -360,7 +360,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
int i; int i;
Index size = Index(1); Index size = Index(1);
for (i = 0; i < NumIndices; i++) { for (i = 0; i < NumIndices; i++) {
internal::check_rows_cols_for_overflow<Dynamic>::run(size, dimensions[i]); internal::check_rows_cols_for_overflow<Dynamic, Dynamic, Dynamic>::run(size, dimensions[i]);
size *= dimensions[i]; size *= dimensions[i];
} }
#ifdef EIGEN_INITIALIZE_COEFFS #ifdef EIGEN_INITIALIZE_COEFFS