mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-21 20:34:28 +08:00
fix matrix product with EIGEN_DEFAULT_TO_ROW_MAJOR
This commit is contained in:
parent
d209120180
commit
7852a48a2f
@ -66,11 +66,8 @@ struct ProductReturnType
|
||||
template<typename Lhs, typename Rhs>
|
||||
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
|
||||
{
|
||||
typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
|
||||
|
||||
typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime,
|
||||
typename ei_plain_matrix_type_column_major<Rhs>::type
|
||||
>::type RhsNested;
|
||||
typedef const Lhs& LhsNested;
|
||||
typedef const Rhs& RhsNested;
|
||||
|
||||
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
|
||||
};
|
||||
@ -144,7 +141,7 @@ struct ei_traits<Product<LhsNested, RhsNested, ProductMode> >
|
||||
|
||||
EvalToRowMajor = RhsRowMajor && (ProductMode==(int)CacheFriendlyProduct ? LhsRowMajor : (!CanVectorizeLhs)),
|
||||
|
||||
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
|
||||
RemovedBits = ~((EvalToRowMajor ? 0 : RowMajorBit)|DirectAccessBit),
|
||||
|
||||
Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
|
||||
| EvalBeforeAssigningBit
|
||||
@ -571,7 +568,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
|
||||
else
|
||||
{
|
||||
_res = ei_aligned_stack_new(Scalar,res.size());
|
||||
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res;
|
||||
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1,ColMajor> >(_res, res.size()) = res;
|
||||
}
|
||||
ei_cache_friendly_product_colmajor_times_vector(res.size(),
|
||||
&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
|
||||
@ -579,7 +576,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
|
||||
|
||||
if (!EvalToRes)
|
||||
{
|
||||
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size());
|
||||
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size());
|
||||
ei_aligned_stack_delete(Scalar, _res, res.size());
|
||||
}
|
||||
}
|
||||
@ -617,7 +614,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
|
||||
else
|
||||
{
|
||||
_res = ei_aligned_stack_new(Scalar, res.size());
|
||||
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res;
|
||||
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size()) = res;
|
||||
}
|
||||
ei_cache_friendly_product_colmajor_times_vector(res.size(),
|
||||
&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
|
||||
@ -625,7 +622,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
|
||||
|
||||
if (!EvalToRes)
|
||||
{
|
||||
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size());
|
||||
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size());
|
||||
ei_aligned_stack_delete(Scalar, _res, res.size());
|
||||
}
|
||||
}
|
||||
@ -650,7 +647,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirect
|
||||
else
|
||||
{
|
||||
_rhs = ei_aligned_stack_new(Scalar, product.rhs().size());
|
||||
Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1> >(_rhs, product.rhs().size()) = product.rhs();
|
||||
Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1,ColMajor> >(_rhs, product.rhs().size()) = product.rhs();
|
||||
}
|
||||
ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
|
||||
_rhs, product.rhs().size(), res);
|
||||
@ -678,7 +675,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
|
||||
else
|
||||
{
|
||||
_lhs = ei_aligned_stack_new(Scalar, product.lhs().size());
|
||||
Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1> >(_lhs, product.lhs().size()) = product.lhs();
|
||||
Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1,ColMajor> >(_lhs, product.lhs().size()) = product.lhs();
|
||||
}
|
||||
ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
|
||||
_lhs, product.lhs().size(), res);
|
||||
@ -709,7 +706,17 @@ MatrixBase<Derived>::operator+=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProdu
|
||||
if (other._expression()._useCacheFriendlyProduct())
|
||||
ei_cache_friendly_product_selector<Product<Lhs,Rhs,CacheFriendlyProduct> >::run(const_cast_derived(), other._expression());
|
||||
else
|
||||
lazyAssign(derived() + other._expression());
|
||||
{
|
||||
typedef typename ei_cleantype<Lhs>::type _Lhs;
|
||||
typedef typename ei_cleantype<Rhs>::type _Rhs;
|
||||
|
||||
typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested;
|
||||
typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested;
|
||||
|
||||
Product<LhsNested,RhsNested,NormalProduct> prod(other._expression().lhs(),other._expression().rhs());
|
||||
|
||||
lazyAssign(derived() + prod);
|
||||
}
|
||||
return derived();
|
||||
}
|
||||
|
||||
@ -724,12 +731,21 @@ inline Derived& MatrixBase<Derived>::lazyAssign(const Product<Lhs,Rhs,CacheFrien
|
||||
}
|
||||
else
|
||||
{
|
||||
lazyAssign<Product<Lhs,Rhs,CacheFriendlyProduct> >(product);
|
||||
typedef typename ei_cleantype<Lhs>::type _Lhs;
|
||||
typedef typename ei_cleantype<Rhs>::type _Rhs;
|
||||
|
||||
typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested;
|
||||
typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested;
|
||||
|
||||
typedef Product<LhsNested,RhsNested,NormalProduct> NormalProduct;
|
||||
NormalProduct normal_prod(product.lhs(),product.rhs());
|
||||
|
||||
lazyAssign<NormalProduct>(normal_prod);
|
||||
}
|
||||
return derived();
|
||||
}
|
||||
|
||||
template<typename T> struct ei_product_copy_rhs
|
||||
template<typename T,int StorageOrder> struct ei_product_copy_rhs
|
||||
{
|
||||
typedef typename ei_meta_if<
|
||||
(ei_traits<T>::Flags & RowMajorBit)
|
||||
@ -739,11 +755,30 @@ template<typename T> struct ei_product_copy_rhs
|
||||
>::ret type;
|
||||
};
|
||||
|
||||
template<typename T> struct ei_product_copy_lhs
|
||||
template<typename T> struct ei_product_copy_rhs<T,RowMajorBit>
|
||||
{
|
||||
typedef typename ei_meta_if<
|
||||
(!(ei_traits<T>::Flags & DirectAccessBit)),
|
||||
typename ei_plain_matrix_type<T>::type,
|
||||
const T&
|
||||
>::ret type;
|
||||
};
|
||||
|
||||
template<typename T,int StorageOrder> struct ei_product_copy_lhs
|
||||
{
|
||||
typedef typename ei_meta_if<
|
||||
(!(int(ei_traits<T>::Flags) & DirectAccessBit)),
|
||||
typename ei_plain_matrix_type<T>::type,
|
||||
typename ei_plain_matrix_type_row_major<T>::type,
|
||||
const T&
|
||||
>::ret type;
|
||||
};
|
||||
|
||||
template<typename T> struct ei_product_copy_lhs<T,RowMajorBit>
|
||||
{
|
||||
typedef typename ei_meta_if<
|
||||
((ei_traits<T>::Flags & RowMajorBit)==0)
|
||||
|| (!(int(ei_traits<T>::Flags) & DirectAccessBit)),
|
||||
typename ei_plain_matrix_type_row_major<T>::type,
|
||||
const T&
|
||||
>::ret type;
|
||||
};
|
||||
@ -752,9 +787,9 @@ template<typename Lhs, typename Rhs, int ProductMode>
|
||||
template<typename DestDerived>
|
||||
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res) const
|
||||
{
|
||||
typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy;
|
||||
typedef typename ei_product_copy_lhs<_LhsNested,DestDerived::Flags&RowMajorBit>::type LhsCopy;
|
||||
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
|
||||
typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy;
|
||||
typedef typename ei_product_copy_rhs<_RhsNested,DestDerived::Flags&RowMajorBit>::type RhsCopy;
|
||||
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
|
||||
LhsCopy lhs(m_lhs);
|
||||
RhsCopy rhs(m_rhs);
|
||||
@ -764,6 +799,7 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
|
||||
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
|
||||
DestDerived::Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride()
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
#endif // EIGEN_PRODUCT_H
|
||||
|
@ -161,6 +161,19 @@ template<typename T> struct ei_plain_matrix_type_column_major
|
||||
> type;
|
||||
};
|
||||
|
||||
/* ei_plain_matrix_type_row_major : same as ei_plain_matrix_type but guaranteed to be row-major
|
||||
*/
|
||||
template<typename T> struct ei_plain_matrix_type_row_major
|
||||
{
|
||||
typedef Matrix<typename ei_traits<T>::Scalar,
|
||||
ei_traits<T>::RowsAtCompileTime,
|
||||
ei_traits<T>::ColsAtCompileTime,
|
||||
AutoAlign | RowMajor,
|
||||
ei_traits<T>::MaxRowsAtCompileTime,
|
||||
ei_traits<T>::MaxColsAtCompileTime
|
||||
> type;
|
||||
};
|
||||
|
||||
template<typename T> struct ei_must_nest_by_value { enum { ret = false }; };
|
||||
template<typename T> struct ei_must_nest_by_value<NestByValue<T> > { enum { ret = true }; };
|
||||
|
||||
|
@ -77,6 +77,7 @@ template<typename MatrixType> void product(const MatrixType& m)
|
||||
|
||||
// begin testing Product.h: only associativity for now
|
||||
// (we use Transpose.h but this doesn't count as a test for it)
|
||||
|
||||
VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2));
|
||||
m3 = m1;
|
||||
m3 *= m1.transpose() * m2;
|
||||
|
Loading…
x
Reference in New Issue
Block a user