Made the cost model cwiseMax and cwiseMin methods const to help the PowerPC cuda compiler compile this code.

This commit is contained in:
Benoit Steiner 2016-08-18 13:46:36 -07:00
parent 647a51b426
commit 7944d4431f

View File

@@ -91,21 +91,21 @@ class TensorOpCost {
} }
// TODO(rmlarsen): Define min in terms of total cost, not elementwise. // TODO(rmlarsen): Define min in terms of total cost, not elementwise.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost& cwiseMin( EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost cwiseMin(
const TensorOpCost& rhs) { const TensorOpCost& rhs) const {
bytes_loaded_ = numext::mini(bytes_loaded_, rhs.bytes_loaded()); double bytes_loaded = numext::mini(bytes_loaded_, rhs.bytes_loaded());
bytes_stored_ = numext::mini(bytes_stored_, rhs.bytes_stored()); double bytes_stored = numext::mini(bytes_stored_, rhs.bytes_stored());
compute_cycles_ = numext::mini(compute_cycles_, rhs.compute_cycles()); double compute_cycles = numext::mini(compute_cycles_, rhs.compute_cycles());
return *this; return TensorOpCost(bytes_loaded, bytes_stored, compute_cycles);
} }
// TODO(rmlarsen): Define max in terms of total cost, not elementwise. // TODO(rmlarsen): Define max in terms of total cost, not elementwise.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost& cwiseMax( EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost cwiseMax(
const TensorOpCost& rhs) { const TensorOpCost& rhs) const {
bytes_loaded_ = numext::maxi(bytes_loaded_, rhs.bytes_loaded()); double bytes_loaded = numext::maxi(bytes_loaded_, rhs.bytes_loaded());
bytes_stored_ = numext::maxi(bytes_stored_, rhs.bytes_stored()); double bytes_stored = numext::maxi(bytes_stored_, rhs.bytes_stored());
compute_cycles_ = numext::maxi(compute_cycles_, rhs.compute_cycles()); double compute_cycles = numext::maxi(compute_cycles_, rhs.compute_cycles());
return *this; return TensorOpCost(bytes_loaded, bytes_stored, compute_cycles);
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost& operator+=( EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost& operator+=(