fix matrix product with EIGEN_DEFAULT_TO_ROW_MAJOR

This commit is contained in:
Gael Guennebaud 2010-01-25 21:56:01 +01:00
parent d209120180
commit 7852a48a2f
3 changed files with 69 additions and 19 deletions

View File

@ -66,11 +66,8 @@ struct ProductReturnType
template<typename Lhs, typename Rhs> template<typename Lhs, typename Rhs>
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct> struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
{ {
typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested; typedef const Lhs& LhsNested;
typedef const Rhs& RhsNested;
typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime,
typename ei_plain_matrix_type_column_major<Rhs>::type
>::type RhsNested;
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type; typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
}; };
@ -144,7 +141,7 @@ struct ei_traits<Product<LhsNested, RhsNested, ProductMode> >
EvalToRowMajor = RhsRowMajor && (ProductMode==(int)CacheFriendlyProduct ? LhsRowMajor : (!CanVectorizeLhs)), EvalToRowMajor = RhsRowMajor && (ProductMode==(int)CacheFriendlyProduct ? LhsRowMajor : (!CanVectorizeLhs)),
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit), RemovedBits = ~((EvalToRowMajor ? 0 : RowMajorBit)|DirectAccessBit),
Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits) Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
| EvalBeforeAssigningBit | EvalBeforeAssigningBit
@ -571,7 +568,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
else else
{ {
_res = ei_aligned_stack_new(Scalar,res.size()); _res = ei_aligned_stack_new(Scalar,res.size());
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res; Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1,ColMajor> >(_res, res.size()) = res;
} }
ei_cache_friendly_product_colmajor_times_vector(res.size(), ei_cache_friendly_product_colmajor_times_vector(res.size(),
&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), &product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
@ -579,7 +576,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
if (!EvalToRes) if (!EvalToRes)
{ {
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()); res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size());
ei_aligned_stack_delete(Scalar, _res, res.size()); ei_aligned_stack_delete(Scalar, _res, res.size());
} }
} }
@ -617,7 +614,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
else else
{ {
_res = ei_aligned_stack_new(Scalar, res.size()); _res = ei_aligned_stack_new(Scalar, res.size());
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res; Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size()) = res;
} }
ei_cache_friendly_product_colmajor_times_vector(res.size(), ei_cache_friendly_product_colmajor_times_vector(res.size(),
&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), &product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
@ -625,7 +622,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
if (!EvalToRes) if (!EvalToRes)
{ {
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()); res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size());
ei_aligned_stack_delete(Scalar, _res, res.size()); ei_aligned_stack_delete(Scalar, _res, res.size());
} }
} }
@ -650,7 +647,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirect
else else
{ {
_rhs = ei_aligned_stack_new(Scalar, product.rhs().size()); _rhs = ei_aligned_stack_new(Scalar, product.rhs().size());
Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1> >(_rhs, product.rhs().size()) = product.rhs(); Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1,ColMajor> >(_rhs, product.rhs().size()) = product.rhs();
} }
ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(), ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
_rhs, product.rhs().size(), res); _rhs, product.rhs().size(), res);
@ -678,7 +675,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
else else
{ {
_lhs = ei_aligned_stack_new(Scalar, product.lhs().size()); _lhs = ei_aligned_stack_new(Scalar, product.lhs().size());
Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1> >(_lhs, product.lhs().size()) = product.lhs(); Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1,ColMajor> >(_lhs, product.lhs().size()) = product.lhs();
} }
ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(), ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
_lhs, product.lhs().size(), res); _lhs, product.lhs().size(), res);
@ -709,7 +706,17 @@ MatrixBase<Derived>::operator+=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProdu
if (other._expression()._useCacheFriendlyProduct()) if (other._expression()._useCacheFriendlyProduct())
ei_cache_friendly_product_selector<Product<Lhs,Rhs,CacheFriendlyProduct> >::run(const_cast_derived(), other._expression()); ei_cache_friendly_product_selector<Product<Lhs,Rhs,CacheFriendlyProduct> >::run(const_cast_derived(), other._expression());
else else
lazyAssign(derived() + other._expression()); {
typedef typename ei_cleantype<Lhs>::type _Lhs;
typedef typename ei_cleantype<Rhs>::type _Rhs;
typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested;
typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested;
Product<LhsNested,RhsNested,NormalProduct> prod(other._expression().lhs(),other._expression().rhs());
lazyAssign(derived() + prod);
}
return derived(); return derived();
} }
@ -724,12 +731,21 @@ inline Derived& MatrixBase<Derived>::lazyAssign(const Product<Lhs,Rhs,CacheFrien
} }
else else
{ {
lazyAssign<Product<Lhs,Rhs,CacheFriendlyProduct> >(product); typedef typename ei_cleantype<Lhs>::type _Lhs;
typedef typename ei_cleantype<Rhs>::type _Rhs;
typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested;
typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested;
typedef Product<LhsNested,RhsNested,NormalProduct> NormalProduct;
NormalProduct normal_prod(product.lhs(),product.rhs());
lazyAssign<NormalProduct>(normal_prod);
} }
return derived(); return derived();
} }
template<typename T> struct ei_product_copy_rhs template<typename T,int StorageOrder> struct ei_product_copy_rhs
{ {
typedef typename ei_meta_if< typedef typename ei_meta_if<
(ei_traits<T>::Flags & RowMajorBit) (ei_traits<T>::Flags & RowMajorBit)
@ -739,11 +755,30 @@ template<typename T> struct ei_product_copy_rhs
>::ret type; >::ret type;
}; };
template<typename T> struct ei_product_copy_lhs template<typename T> struct ei_product_copy_rhs<T,RowMajorBit>
{
typedef typename ei_meta_if<
(!(ei_traits<T>::Flags & DirectAccessBit)),
typename ei_plain_matrix_type<T>::type,
const T&
>::ret type;
};
template<typename T,int StorageOrder> struct ei_product_copy_lhs
{ {
typedef typename ei_meta_if< typedef typename ei_meta_if<
(!(int(ei_traits<T>::Flags) & DirectAccessBit)), (!(int(ei_traits<T>::Flags) & DirectAccessBit)),
typename ei_plain_matrix_type<T>::type, typename ei_plain_matrix_type_row_major<T>::type,
const T&
>::ret type;
};
template<typename T> struct ei_product_copy_lhs<T,RowMajorBit>
{
typedef typename ei_meta_if<
((ei_traits<T>::Flags & RowMajorBit)==0)
|| (!(int(ei_traits<T>::Flags) & DirectAccessBit)),
typename ei_plain_matrix_type_row_major<T>::type,
const T& const T&
>::ret type; >::ret type;
}; };
@ -752,9 +787,9 @@ template<typename Lhs, typename Rhs, int ProductMode>
template<typename DestDerived> template<typename DestDerived>
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res) const inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res) const
{ {
typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy; typedef typename ei_product_copy_lhs<_LhsNested,DestDerived::Flags&RowMajorBit>::type LhsCopy;
typedef typename ei_unref<LhsCopy>::type _LhsCopy; typedef typename ei_unref<LhsCopy>::type _LhsCopy;
typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy; typedef typename ei_product_copy_rhs<_RhsNested,DestDerived::Flags&RowMajorBit>::type RhsCopy;
typedef typename ei_unref<RhsCopy>::type _RhsCopy; typedef typename ei_unref<RhsCopy>::type _RhsCopy;
LhsCopy lhs(m_lhs); LhsCopy lhs(m_lhs);
RhsCopy rhs(m_rhs); RhsCopy rhs(m_rhs);
@ -764,6 +799,7 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(), _RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
DestDerived::Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride() DestDerived::Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride()
); );
} }
#endif // EIGEN_PRODUCT_H #endif // EIGEN_PRODUCT_H

View File

@ -161,6 +161,19 @@ template<typename T> struct ei_plain_matrix_type_column_major
> type; > type;
}; };
/* ei_plain_matrix_type_row_major : same as ei_plain_matrix_type but guaranteed to be row-major
*/
template<typename T> struct ei_plain_matrix_type_row_major
{
typedef Matrix<typename ei_traits<T>::Scalar,
ei_traits<T>::RowsAtCompileTime,
ei_traits<T>::ColsAtCompileTime,
AutoAlign | RowMajor,
ei_traits<T>::MaxRowsAtCompileTime,
ei_traits<T>::MaxColsAtCompileTime
> type;
};
template<typename T> struct ei_must_nest_by_value { enum { ret = false }; }; template<typename T> struct ei_must_nest_by_value { enum { ret = false }; };
template<typename T> struct ei_must_nest_by_value<NestByValue<T> > { enum { ret = true }; }; template<typename T> struct ei_must_nest_by_value<NestByValue<T> > { enum { ret = true }; };

View File

@ -77,6 +77,7 @@ template<typename MatrixType> void product(const MatrixType& m)
// begin testing Product.h: only associativity for now // begin testing Product.h: only associativity for now
// (we use Transpose.h but this doesn't count as a test for it) // (we use Transpose.h but this doesn't count as a test for it)
VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2)); VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2));
m3 = m1; m3 = m1;
m3 *= m1.transpose() * m2; m3 *= m1.transpose() * m2;