Optimize check_rows_cols_for_overflow

This commit is contained in:
Charles Schlosser 2023-07-10 17:40:17 +00:00
parent 9297aae66f
commit 21cd3fe209
3 changed files with 35 additions and 14 deletions

View File

@ -28,21 +28,40 @@ namespace Eigen {
namespace internal {
template<int MaxSizeAtCompileTime> struct check_rows_cols_for_overflow {
template <int MaxSizeAtCompileTime, int MaxRowsAtCompileTime, int MaxColsAtCompileTime>
struct check_rows_cols_for_overflow {
EIGEN_STATIC_ASSERT(MaxRowsAtCompileTime * MaxColsAtCompileTime == MaxSizeAtCompileTime,YOU MADE A PROGRAMMING MISTAKE)
template <typename Index>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index, Index) {}
};
template<> struct check_rows_cols_for_overflow<Dynamic> {
template <int MaxRowsAtCompileTime>
struct check_rows_cols_for_overflow<Dynamic, MaxRowsAtCompileTime, Dynamic> {
template <typename Index>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index, Index cols) {
constexpr Index MaxIndex = NumTraits<Index>::highest();
bool error = cols > MaxIndex / MaxRowsAtCompileTime;
if (error) throw_std_bad_alloc();
}
};
template <int MaxColsAtCompileTime>
struct check_rows_cols_for_overflow<Dynamic, Dynamic, MaxColsAtCompileTime> {
template <typename Index>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index rows, Index) {
constexpr Index MaxIndex = NumTraits<Index>::highest();
bool error = rows > MaxIndex / MaxColsAtCompileTime;
if (error) throw_std_bad_alloc();
}
};
template <>
struct check_rows_cols_for_overflow<Dynamic, Dynamic, Dynamic> {
template <typename Index>
EIGEN_DEVICE_FUNC static EIGEN_ALWAYS_INLINE constexpr void run(Index rows, Index cols) {
// http://hg.mozilla.org/mozilla-central/file/6c8a909977d3/xpcom/ds/CheckedInt.h#l242
// we assume Index is signed
Index max_index = (std::size_t(1) << (8 * sizeof(Index) - 1)) - 1; // assume Index is signed
bool error = (rows == 0 || cols == 0) ? false
: (rows > max_index / cols);
if (error)
throw_std_bad_alloc();
constexpr Index MaxIndex = NumTraits<Index>::highest();
bool error = cols == 0 ? false : (rows > MaxIndex / cols);
if (error) throw_std_bad_alloc();
}
};
@ -268,7 +287,7 @@ class PlainObjectBase : public internal::dense_xpr_base<Derived>::type
&& internal::check_implication(RowsAtCompileTime==Dynamic && MaxRowsAtCompileTime!=Dynamic, rows<=MaxRowsAtCompileTime)
&& internal::check_implication(ColsAtCompileTime==Dynamic && MaxColsAtCompileTime!=Dynamic, cols<=MaxColsAtCompileTime)
&& rows>=0 && cols>=0 && "Invalid sizes when resizing a matrix or array.");
internal::check_rows_cols_for_overflow<MaxSizeAtCompileTime>::run(rows, cols);
internal::check_rows_cols_for_overflow<MaxSizeAtCompileTime, MaxRowsAtCompileTime, MaxColsAtCompileTime>::run(rows, cols);
#ifdef EIGEN_INITIALIZE_COEFFS
Index size = rows*cols;
bool size_changed = size != this->size();
@ -340,7 +359,7 @@ class PlainObjectBase : public internal::dense_xpr_base<Derived>::type
EIGEN_STRONG_INLINE void resizeLike(const EigenBase<OtherDerived>& _other)
{
const OtherDerived& other = _other.derived();
internal::check_rows_cols_for_overflow<MaxSizeAtCompileTime>::run(other.rows(), other.cols());
internal::check_rows_cols_for_overflow<MaxSizeAtCompileTime, MaxRowsAtCompileTime, MaxColsAtCompileTime>::run(other.rows(), other.cols());
const Index othersize = other.rows()*other.cols();
if(RowsAtCompileTime == 1)
{
@ -965,7 +984,7 @@ struct conservative_resize_like_impl
&& (( Derived::IsRowMajor && _this.cols() == cols) || // row-major and we change only the number of rows
(!Derived::IsRowMajor && _this.rows() == rows) )) // column-major and we change only the number of columns
{
internal::check_rows_cols_for_overflow<Derived::MaxSizeAtCompileTime>::run(rows, cols);
internal::check_rows_cols_for_overflow<Derived::MaxSizeAtCompileTime, Derived::MaxRowsAtCompileTime, Derived::MaxColsAtCompileTime>::run(rows, cols);
_this.derived().m_storage.conservativeResize(rows*cols,rows,cols);
}
else

View File

@ -311,7 +311,9 @@ constexpr inline unsigned compute_matrix_flags(int Options) {
}
constexpr inline int size_at_compile_time(int rows, int cols) {
return (rows==Dynamic || cols==Dynamic) ? Dynamic : rows * cols;
if (rows == 0 || cols == 0) return 0;
if (rows == Dynamic || cols == Dynamic) return Dynamic;
return rows * cols;
}
template<typename XprType> struct size_of_xpr_at_compile_time

View File

@ -360,7 +360,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
int i;
Index size = Index(1);
for (i = 0; i < NumIndices; i++) {
internal::check_rows_cols_for_overflow<Dynamic>::run(size, dimensions[i]);
internal::check_rows_cols_for_overflow<Dynamic, Dynamic, Dynamic>::run(size, dimensions[i]);
size *= dimensions[i];
}
#ifdef EIGEN_INITIALIZE_COEFFS