mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-21 20:34:28 +08:00
Fix matrix product when EIGEN_DEFAULT_TO_ROW_MAJOR is defined
This commit is contained in:
parent
d209120180
commit
7852a48a2f
@ -66,11 +66,8 @@ struct ProductReturnType
|
|||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
|
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
|
||||||
{
|
{
|
||||||
typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
|
typedef const Lhs& LhsNested;
|
||||||
|
typedef const Rhs& RhsNested;
|
||||||
typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime,
|
|
||||||
typename ei_plain_matrix_type_column_major<Rhs>::type
|
|
||||||
>::type RhsNested;
|
|
||||||
|
|
||||||
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
|
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
|
||||||
};
|
};
|
||||||
@ -144,7 +141,7 @@ struct ei_traits<Product<LhsNested, RhsNested, ProductMode> >
|
|||||||
|
|
||||||
EvalToRowMajor = RhsRowMajor && (ProductMode==(int)CacheFriendlyProduct ? LhsRowMajor : (!CanVectorizeLhs)),
|
EvalToRowMajor = RhsRowMajor && (ProductMode==(int)CacheFriendlyProduct ? LhsRowMajor : (!CanVectorizeLhs)),
|
||||||
|
|
||||||
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
|
RemovedBits = ~((EvalToRowMajor ? 0 : RowMajorBit)|DirectAccessBit),
|
||||||
|
|
||||||
Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
|
Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
|
||||||
| EvalBeforeAssigningBit
|
| EvalBeforeAssigningBit
|
||||||
@ -571,7 +568,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
_res = ei_aligned_stack_new(Scalar,res.size());
|
_res = ei_aligned_stack_new(Scalar,res.size());
|
||||||
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res;
|
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1,ColMajor> >(_res, res.size()) = res;
|
||||||
}
|
}
|
||||||
ei_cache_friendly_product_colmajor_times_vector(res.size(),
|
ei_cache_friendly_product_colmajor_times_vector(res.size(),
|
||||||
&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
|
&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
|
||||||
@ -579,7 +576,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
|
|||||||
|
|
||||||
if (!EvalToRes)
|
if (!EvalToRes)
|
||||||
{
|
{
|
||||||
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size());
|
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size());
|
||||||
ei_aligned_stack_delete(Scalar, _res, res.size());
|
ei_aligned_stack_delete(Scalar, _res, res.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -617,7 +614,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
_res = ei_aligned_stack_new(Scalar, res.size());
|
_res = ei_aligned_stack_new(Scalar, res.size());
|
||||||
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res;
|
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size()) = res;
|
||||||
}
|
}
|
||||||
ei_cache_friendly_product_colmajor_times_vector(res.size(),
|
ei_cache_friendly_product_colmajor_times_vector(res.size(),
|
||||||
&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
|
&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
|
||||||
@ -625,7 +622,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
|
|||||||
|
|
||||||
if (!EvalToRes)
|
if (!EvalToRes)
|
||||||
{
|
{
|
||||||
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size());
|
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size());
|
||||||
ei_aligned_stack_delete(Scalar, _res, res.size());
|
ei_aligned_stack_delete(Scalar, _res, res.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -650,7 +647,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirect
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
_rhs = ei_aligned_stack_new(Scalar, product.rhs().size());
|
_rhs = ei_aligned_stack_new(Scalar, product.rhs().size());
|
||||||
Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1> >(_rhs, product.rhs().size()) = product.rhs();
|
Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1,ColMajor> >(_rhs, product.rhs().size()) = product.rhs();
|
||||||
}
|
}
|
||||||
ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
|
ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
|
||||||
_rhs, product.rhs().size(), res);
|
_rhs, product.rhs().size(), res);
|
||||||
@ -678,7 +675,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
|
|||||||
else
|
else
|
||||||
{
|
{
|
||||||
_lhs = ei_aligned_stack_new(Scalar, product.lhs().size());
|
_lhs = ei_aligned_stack_new(Scalar, product.lhs().size());
|
||||||
Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1> >(_lhs, product.lhs().size()) = product.lhs();
|
Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1,ColMajor> >(_lhs, product.lhs().size()) = product.lhs();
|
||||||
}
|
}
|
||||||
ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
|
ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
|
||||||
_lhs, product.lhs().size(), res);
|
_lhs, product.lhs().size(), res);
|
||||||
@ -709,7 +706,17 @@ MatrixBase<Derived>::operator+=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProdu
|
|||||||
if (other._expression()._useCacheFriendlyProduct())
|
if (other._expression()._useCacheFriendlyProduct())
|
||||||
ei_cache_friendly_product_selector<Product<Lhs,Rhs,CacheFriendlyProduct> >::run(const_cast_derived(), other._expression());
|
ei_cache_friendly_product_selector<Product<Lhs,Rhs,CacheFriendlyProduct> >::run(const_cast_derived(), other._expression());
|
||||||
else
|
else
|
||||||
lazyAssign(derived() + other._expression());
|
{
|
||||||
|
typedef typename ei_cleantype<Lhs>::type _Lhs;
|
||||||
|
typedef typename ei_cleantype<Rhs>::type _Rhs;
|
||||||
|
|
||||||
|
typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested;
|
||||||
|
typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested;
|
||||||
|
|
||||||
|
Product<LhsNested,RhsNested,NormalProduct> prod(other._expression().lhs(),other._expression().rhs());
|
||||||
|
|
||||||
|
lazyAssign(derived() + prod);
|
||||||
|
}
|
||||||
return derived();
|
return derived();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -724,12 +731,21 @@ inline Derived& MatrixBase<Derived>::lazyAssign(const Product<Lhs,Rhs,CacheFrien
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
lazyAssign<Product<Lhs,Rhs,CacheFriendlyProduct> >(product);
|
typedef typename ei_cleantype<Lhs>::type _Lhs;
|
||||||
|
typedef typename ei_cleantype<Rhs>::type _Rhs;
|
||||||
|
|
||||||
|
typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested;
|
||||||
|
typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested;
|
||||||
|
|
||||||
|
typedef Product<LhsNested,RhsNested,NormalProduct> NormalProduct;
|
||||||
|
NormalProduct normal_prod(product.lhs(),product.rhs());
|
||||||
|
|
||||||
|
lazyAssign<NormalProduct>(normal_prod);
|
||||||
}
|
}
|
||||||
return derived();
|
return derived();
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T> struct ei_product_copy_rhs
|
template<typename T,int StorageOrder> struct ei_product_copy_rhs
|
||||||
{
|
{
|
||||||
typedef typename ei_meta_if<
|
typedef typename ei_meta_if<
|
||||||
(ei_traits<T>::Flags & RowMajorBit)
|
(ei_traits<T>::Flags & RowMajorBit)
|
||||||
@ -739,11 +755,30 @@ template<typename T> struct ei_product_copy_rhs
|
|||||||
>::ret type;
|
>::ret type;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T> struct ei_product_copy_lhs
|
template<typename T> struct ei_product_copy_rhs<T,RowMajorBit>
|
||||||
|
{
|
||||||
|
typedef typename ei_meta_if<
|
||||||
|
(!(ei_traits<T>::Flags & DirectAccessBit)),
|
||||||
|
typename ei_plain_matrix_type<T>::type,
|
||||||
|
const T&
|
||||||
|
>::ret type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T,int StorageOrder> struct ei_product_copy_lhs
|
||||||
{
|
{
|
||||||
typedef typename ei_meta_if<
|
typedef typename ei_meta_if<
|
||||||
(!(int(ei_traits<T>::Flags) & DirectAccessBit)),
|
(!(int(ei_traits<T>::Flags) & DirectAccessBit)),
|
||||||
typename ei_plain_matrix_type<T>::type,
|
typename ei_plain_matrix_type_row_major<T>::type,
|
||||||
|
const T&
|
||||||
|
>::ret type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T> struct ei_product_copy_lhs<T,RowMajorBit>
|
||||||
|
{
|
||||||
|
typedef typename ei_meta_if<
|
||||||
|
((ei_traits<T>::Flags & RowMajorBit)==0)
|
||||||
|
|| (!(int(ei_traits<T>::Flags) & DirectAccessBit)),
|
||||||
|
typename ei_plain_matrix_type_row_major<T>::type,
|
||||||
const T&
|
const T&
|
||||||
>::ret type;
|
>::ret type;
|
||||||
};
|
};
|
||||||
@ -752,9 +787,9 @@ template<typename Lhs, typename Rhs, int ProductMode>
|
|||||||
template<typename DestDerived>
|
template<typename DestDerived>
|
||||||
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res) const
|
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res) const
|
||||||
{
|
{
|
||||||
typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy;
|
typedef typename ei_product_copy_lhs<_LhsNested,DestDerived::Flags&RowMajorBit>::type LhsCopy;
|
||||||
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
|
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
|
||||||
typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy;
|
typedef typename ei_product_copy_rhs<_RhsNested,DestDerived::Flags&RowMajorBit>::type RhsCopy;
|
||||||
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
|
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
|
||||||
LhsCopy lhs(m_lhs);
|
LhsCopy lhs(m_lhs);
|
||||||
RhsCopy rhs(m_rhs);
|
RhsCopy rhs(m_rhs);
|
||||||
@ -764,6 +799,7 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
|
|||||||
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
|
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
|
||||||
DestDerived::Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride()
|
DestDerived::Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride()
|
||||||
);
|
);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // EIGEN_PRODUCT_H
|
#endif // EIGEN_PRODUCT_H
|
||||||
|
@ -161,6 +161,19 @@ template<typename T> struct ei_plain_matrix_type_column_major
|
|||||||
> type;
|
> type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* ei_plain_matrix_type_row_major : same as ei_plain_matrix_type, but the resulting matrix type is guaranteed to be row-major
|
||||||
|
*/
|
||||||
|
template<typename T> struct ei_plain_matrix_type_row_major
|
||||||
|
{
|
||||||
|
typedef Matrix<typename ei_traits<T>::Scalar,
|
||||||
|
ei_traits<T>::RowsAtCompileTime,
|
||||||
|
ei_traits<T>::ColsAtCompileTime,
|
||||||
|
AutoAlign | RowMajor,
|
||||||
|
ei_traits<T>::MaxRowsAtCompileTime,
|
||||||
|
ei_traits<T>::MaxColsAtCompileTime
|
||||||
|
> type;
|
||||||
|
};
|
||||||
|
|
||||||
template<typename T> struct ei_must_nest_by_value { enum { ret = false }; };
|
template<typename T> struct ei_must_nest_by_value { enum { ret = false }; };
|
||||||
template<typename T> struct ei_must_nest_by_value<NestByValue<T> > { enum { ret = true }; };
|
template<typename T> struct ei_must_nest_by_value<NestByValue<T> > { enum { ret = true }; };
|
||||||
|
|
||||||
|
@ -77,6 +77,7 @@ template<typename MatrixType> void product(const MatrixType& m)
|
|||||||
|
|
||||||
// begin testing Product.h: only associativity for now
|
// begin testing Product.h: only associativity for now
|
||||||
// (we use Transpose.h but this doesn't count as a test for it)
|
// (we use Transpose.h but this doesn't count as a test for it)
|
||||||
|
|
||||||
VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2));
|
VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2));
|
||||||
m3 = m1;
|
m3 = m1;
|
||||||
m3 *= m1.transpose() * m2;
|
m3 *= m1.transpose() * m2;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user