mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-22 04:27:36 +08:00
Ensured that contractions that can be reduced to a matrix vector product work correctly even when the input coefficients aren't aligned.
This commit is contained in:
parent
509e4ddc02
commit
9f98650d0a
@ -140,10 +140,11 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,C
|
|||||||
// find how many columns do we have to skip to be aligned with the result (if possible)
|
// find how many columns do we have to skip to be aligned with the result (if possible)
|
||||||
Index skipColumns = 0;
|
Index skipColumns = 0;
|
||||||
// if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
|
// if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
|
||||||
if( (lhsAlignmentOffset < 0) || (size_t(res)%sizeof(ResScalar)) )
|
if( (lhsAlignmentOffset < 0) || (lhsAlignmentOffset == size) || (size_t(res)%sizeof(ResScalar)) )
|
||||||
{
|
{
|
||||||
alignedSize = 0;
|
alignedSize = 0;
|
||||||
alignedStart = 0;
|
alignedStart = 0;
|
||||||
|
alignmentPattern = NoneAligned;
|
||||||
}
|
}
|
||||||
else if(LhsPacketSize > 4)
|
else if(LhsPacketSize > 4)
|
||||||
{
|
{
|
||||||
@ -412,10 +413,13 @@ EIGEN_DONT_INLINE void general_matrix_vector_product<Index,LhsScalar,LhsMapper,R
|
|||||||
// find how many rows do we have to skip to be aligned with rhs (if possible)
|
// find how many rows do we have to skip to be aligned with rhs (if possible)
|
||||||
Index skipRows = 0;
|
Index skipRows = 0;
|
||||||
// if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
|
// if the data cannot be aligned (TODO add some compile time tests when possible, e.g. for floats)
|
||||||
if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) || (lhsAlignmentOffset < 0) || (rhsAlignmentOffset < 0) )
|
if( (sizeof(LhsScalar)!=sizeof(RhsScalar)) ||
|
||||||
|
(lhsAlignmentOffset < 0) || (lhsAlignmentOffset == depth) ||
|
||||||
|
(rhsAlignmentOffset < 0) || (rhsAlignmentOffset == rows) )
|
||||||
{
|
{
|
||||||
alignedSize = 0;
|
alignedSize = 0;
|
||||||
alignedStart = 0;
|
alignedStart = 0;
|
||||||
|
alignmentPattern = NoneAligned;
|
||||||
}
|
}
|
||||||
else if(LhsPacketSize > 4)
|
else if(LhsPacketSize > 4)
|
||||||
{
|
{
|
||||||
|
@ -352,6 +352,52 @@ static void test_large_contraction()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_matrix_vector()
|
||||||
|
{
|
||||||
|
Tensor<float, 2> t_left(30, 50);
|
||||||
|
Tensor<float, 1> t_right(50);
|
||||||
|
Tensor<float, 1> t_result(30);
|
||||||
|
|
||||||
|
t_left.setRandom();
|
||||||
|
t_right.setRandom();
|
||||||
|
|
||||||
|
typedef Map<Eigen::Matrix<float, Dynamic, Dynamic>> MapXf;
|
||||||
|
MapXf m_left(t_left.data(), 30, 50);
|
||||||
|
MapXf m_right(t_right.data(), 50, 1);
|
||||||
|
Eigen::Matrix<float, Dynamic, Dynamic> m_result(30, 1);
|
||||||
|
|
||||||
|
// this contraction should be equivalent to a single matrix multiplication
|
||||||
|
Eigen::array<DimPair, 1> dims{{DimPair(1, 0)}};
|
||||||
|
|
||||||
|
// compute results by separate methods
|
||||||
|
t_result = t_left.contract(t_right, dims);
|
||||||
|
m_result = m_left * m_right;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
|
||||||
|
VERIFY_IS_APPROX(t_result(i), m_result(i, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_tensor_vector()
|
||||||
|
{
|
||||||
|
Tensor<float, 3> t_left(7, 13, 17);
|
||||||
|
Tensor<float, 2> t_right(1, 7);
|
||||||
|
typedef typename Tensor<float, 1>::DimensionPair DimensionPair;
|
||||||
|
Eigen::array<DimensionPair, 1> dim_pair01{{{0, 1}}};
|
||||||
|
Tensor<float, 3> t_result = t_left.contract(t_right, dim_pair01);
|
||||||
|
|
||||||
|
typedef Map<Eigen::Matrix<float, Dynamic, Dynamic>> MapXf;
|
||||||
|
MapXf m_left(t_left.data(), 7, 13*17);
|
||||||
|
MapXf m_right(t_right.data(), 1, 7);
|
||||||
|
Eigen::Matrix<float, Dynamic, Dynamic> m_result = m_left.transpose() * m_right.transpose();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
|
||||||
|
VERIFY_IS_APPROX(t_result(i), m_result(i, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void test_cxx11_tensor_contraction()
|
void test_cxx11_tensor_contraction()
|
||||||
{
|
{
|
||||||
CALL_SUBTEST(test_evals());
|
CALL_SUBTEST(test_evals());
|
||||||
@ -364,4 +410,6 @@ void test_cxx11_tensor_contraction()
|
|||||||
CALL_SUBTEST(test_out_of_order_contraction());
|
CALL_SUBTEST(test_out_of_order_contraction());
|
||||||
CALL_SUBTEST(test_consistency());
|
CALL_SUBTEST(test_consistency());
|
||||||
CALL_SUBTEST(test_large_contraction());
|
CALL_SUBTEST(test_large_contraction());
|
||||||
|
CALL_SUBTEST(test_matrix_vector());
|
||||||
|
CALL_SUBTEST(test_tensor_vector());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user