Improved the contraction mapper to properly support tensor products

This commit is contained in:
Benoit Steiner 2016-07-11 13:43:41 -07:00
parent 0bc020be9d
commit 3a2dd352ae

View File

@ -130,13 +130,13 @@ class SimpleTensorContractionMapper {
} }
Index contract_val = left ? col : row; Index contract_val = left ? col : row;
if(array_size<contract_t>::value > 0) {
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) { for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
const Index idx = contract_val / m_k_strides[i]; const Index idx = contract_val / m_k_strides[i];
linidx += idx * m_contract_strides[i]; linidx += idx * m_contract_strides[i];
contract_val -= idx * m_k_strides[i]; contract_val -= idx * m_k_strides[i];
} }
if(array_size<contract_t>::value > 0) {
if (side == Rhs && inner_dim_contiguous) { if (side == Rhs && inner_dim_contiguous) {
eigen_assert(m_contract_strides[0] == 1); eigen_assert(m_contract_strides[0] == 1);
linidx += contract_val; linidx += contract_val;
@ -153,6 +153,7 @@ class SimpleTensorContractionMapper {
const bool left = (side == Lhs); const bool left = (side == Lhs);
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col}; Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
Index linidx[2] = {0, 0}; Index linidx[2] = {0, 0};
if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) { for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
const Index idx0 = nocontract_val[0] / m_ij_strides[i]; const Index idx0 = nocontract_val[0] / m_ij_strides[i];
const Index idx1 = nocontract_val[1] / m_ij_strides[i]; const Index idx1 = nocontract_val[1] / m_ij_strides[i];
@ -161,7 +162,6 @@ class SimpleTensorContractionMapper {
nocontract_val[0] -= idx0 * m_ij_strides[i]; nocontract_val[0] -= idx0 * m_ij_strides[i];
nocontract_val[1] -= idx1 * m_ij_strides[i]; nocontract_val[1] -= idx1 * m_ij_strides[i];
} }
if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
if (side == Lhs && inner_dim_contiguous) { if (side == Lhs && inner_dim_contiguous) {
eigen_assert(m_nocontract_strides[0] == 1); eigen_assert(m_nocontract_strides[0] == 1);
linidx[0] += nocontract_val[0]; linidx[0] += nocontract_val[0];
@ -173,6 +173,7 @@ class SimpleTensorContractionMapper {
} }
Index contract_val[2] = {left ? col : row, left ? col : row + distance}; Index contract_val[2] = {left ? col : row, left ? col : row + distance};
if (array_size<contract_t>::value> 0) {
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) { for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
const Index idx0 = contract_val[0] / m_k_strides[i]; const Index idx0 = contract_val[0] / m_k_strides[i];
const Index idx1 = contract_val[1] / m_k_strides[i]; const Index idx1 = contract_val[1] / m_k_strides[i];
@ -190,6 +191,7 @@ class SimpleTensorContractionMapper {
linidx[0] += contract_val[0] * m_contract_strides[0]; linidx[0] += contract_val[0] * m_contract_strides[0];
linidx[1] += contract_val[1] * m_contract_strides[0]; linidx[1] += contract_val[1] * m_contract_strides[0];
} }
}
return IndexPair<Index>(linidx[0], linidx[1]); return IndexPair<Index>(linidx[0], linidx[1]);
} }
@ -200,7 +202,7 @@ class SimpleTensorContractionMapper {
return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size; return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
} }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const { EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1; return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
} }
protected: protected: