fix matrix product with EIGEN_DEFAULT_TO_ROW_MAJOR

This commit is contained in:
Gael Guennebaud 2010-01-25 21:56:01 +01:00
parent d209120180
commit 7852a48a2f
3 changed files with 69 additions and 19 deletions

View File

@ -66,11 +66,8 @@ struct ProductReturnType
template<typename Lhs, typename Rhs>
struct ProductReturnType<Lhs,Rhs,CacheFriendlyProduct>
{
typedef typename ei_nested<Lhs,Rhs::ColsAtCompileTime>::type LhsNested;
typedef typename ei_nested<Rhs,Lhs::RowsAtCompileTime,
typename ei_plain_matrix_type_column_major<Rhs>::type
>::type RhsNested;
typedef const Lhs& LhsNested;
typedef const Rhs& RhsNested;
typedef Product<LhsNested, RhsNested, CacheFriendlyProduct> Type;
};
@ -144,7 +141,7 @@ struct ei_traits<Product<LhsNested, RhsNested, ProductMode> >
EvalToRowMajor = RhsRowMajor && (ProductMode==(int)CacheFriendlyProduct ? LhsRowMajor : (!CanVectorizeLhs)),
RemovedBits = ~(EvalToRowMajor ? 0 : RowMajorBit),
RemovedBits = ~((EvalToRowMajor ? 0 : RowMajorBit)|DirectAccessBit),
Flags = ((unsigned int)(LhsFlags | RhsFlags) & HereditaryBits & RemovedBits)
| EvalBeforeAssigningBit
@ -571,7 +568,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
else
{
_res = ei_aligned_stack_new(Scalar,res.size());
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1> >(_res, res.size()) = res;
Map<Matrix<Scalar,DestDerived::RowsAtCompileTime,1,ColMajor> >(_res, res.size()) = res;
}
ei_cache_friendly_product_colmajor_times_vector(res.size(),
&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
@ -579,7 +576,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,ColMajor,HasDirect
if (!EvalToRes)
{
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size());
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size());
ei_aligned_stack_delete(Scalar, _res, res.size());
}
}
@ -617,7 +614,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
else
{
_res = ei_aligned_stack_new(Scalar, res.size());
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size()) = res;
Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size()) = res;
}
ei_cache_friendly_product_colmajor_times_vector(res.size(),
&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
@ -625,7 +622,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
if (!EvalToRes)
{
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1> >(_res, res.size());
res = Map<Matrix<Scalar,DestDerived::SizeAtCompileTime,1,ColMajor> >(_res, res.size());
ei_aligned_stack_delete(Scalar, _res, res.size());
}
}
@ -650,7 +647,7 @@ struct ei_cache_friendly_product_selector<ProductType,LhsRows,RowMajor,HasDirect
else
{
_rhs = ei_aligned_stack_new(Scalar, product.rhs().size());
Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1> >(_rhs, product.rhs().size()) = product.rhs();
Map<Matrix<Scalar,Rhs::SizeAtCompileTime,1,ColMajor> >(_rhs, product.rhs().size()) = product.rhs();
}
ei_cache_friendly_product_rowmajor_times_vector(&product.lhs().const_cast_derived().coeffRef(0,0), product.lhs().stride(),
_rhs, product.rhs().size(), res);
@ -678,7 +675,7 @@ struct ei_cache_friendly_product_selector<ProductType,1,LhsOrder,LhsAccess,RhsCo
else
{
_lhs = ei_aligned_stack_new(Scalar, product.lhs().size());
Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1> >(_lhs, product.lhs().size()) = product.lhs();
Map<Matrix<Scalar,Lhs::SizeAtCompileTime,1,ColMajor> >(_lhs, product.lhs().size()) = product.lhs();
}
ei_cache_friendly_product_rowmajor_times_vector(&product.rhs().const_cast_derived().coeffRef(0,0), product.rhs().stride(),
_lhs, product.lhs().size(), res);
@ -709,7 +706,17 @@ MatrixBase<Derived>::operator+=(const Flagged<Product<Lhs,Rhs,CacheFriendlyProdu
if (other._expression()._useCacheFriendlyProduct())
ei_cache_friendly_product_selector<Product<Lhs,Rhs,CacheFriendlyProduct> >::run(const_cast_derived(), other._expression());
else
lazyAssign(derived() + other._expression());
{
typedef typename ei_cleantype<Lhs>::type _Lhs;
typedef typename ei_cleantype<Rhs>::type _Rhs;
typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested;
typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested;
Product<LhsNested,RhsNested,NormalProduct> prod(other._expression().lhs(),other._expression().rhs());
lazyAssign(derived() + prod);
}
return derived();
}
@ -724,12 +731,21 @@ inline Derived& MatrixBase<Derived>::lazyAssign(const Product<Lhs,Rhs,CacheFrien
}
else
{
lazyAssign<Product<Lhs,Rhs,CacheFriendlyProduct> >(product);
typedef typename ei_cleantype<Lhs>::type _Lhs;
typedef typename ei_cleantype<Rhs>::type _Rhs;
typedef typename ei_nested<_Lhs,_Rhs::ColsAtCompileTime>::type LhsNested;
typedef typename ei_nested<_Rhs,_Lhs::RowsAtCompileTime>::type RhsNested;
typedef Product<LhsNested,RhsNested,NormalProduct> NormalProduct;
NormalProduct normal_prod(product.lhs(),product.rhs());
lazyAssign<NormalProduct>(normal_prod);
}
return derived();
}
template<typename T> struct ei_product_copy_rhs
template<typename T,int StorageOrder> struct ei_product_copy_rhs
{
typedef typename ei_meta_if<
(ei_traits<T>::Flags & RowMajorBit)
@ -739,11 +755,30 @@ template<typename T> struct ei_product_copy_rhs
>::ret type;
};
template<typename T> struct ei_product_copy_lhs
template<typename T> struct ei_product_copy_rhs<T,RowMajorBit>
{
typedef typename ei_meta_if<
(!(ei_traits<T>::Flags & DirectAccessBit)),
typename ei_plain_matrix_type<T>::type,
const T&
>::ret type;
};
template<typename T,int StorageOrder> struct ei_product_copy_lhs
{
typedef typename ei_meta_if<
(!(int(ei_traits<T>::Flags) & DirectAccessBit)),
typename ei_plain_matrix_type<T>::type,
typename ei_plain_matrix_type_row_major<T>::type,
const T&
>::ret type;
};
template<typename T> struct ei_product_copy_lhs<T,RowMajorBit>
{
typedef typename ei_meta_if<
((ei_traits<T>::Flags & RowMajorBit)==0)
|| (!(int(ei_traits<T>::Flags) & DirectAccessBit)),
typename ei_plain_matrix_type_row_major<T>::type,
const T&
>::ret type;
};
@ -752,9 +787,9 @@ template<typename Lhs, typename Rhs, int ProductMode>
template<typename DestDerived>
inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived& res) const
{
typedef typename ei_product_copy_lhs<_LhsNested>::type LhsCopy;
typedef typename ei_product_copy_lhs<_LhsNested,DestDerived::Flags&RowMajorBit>::type LhsCopy;
typedef typename ei_unref<LhsCopy>::type _LhsCopy;
typedef typename ei_product_copy_rhs<_RhsNested>::type RhsCopy;
typedef typename ei_product_copy_rhs<_RhsNested,DestDerived::Flags&RowMajorBit>::type RhsCopy;
typedef typename ei_unref<RhsCopy>::type _RhsCopy;
LhsCopy lhs(m_lhs);
RhsCopy rhs(m_rhs);
@ -764,6 +799,7 @@ inline void Product<Lhs,Rhs,ProductMode>::_cacheFriendlyEvalAndAdd(DestDerived&
_RhsCopy::Flags&RowMajorBit, (const Scalar*)&(rhs.const_cast_derived().coeffRef(0,0)), rhs.stride(),
DestDerived::Flags&RowMajorBit, (Scalar*)&(res.coeffRef(0,0)), res.stride()
);
}
#endif // EIGEN_PRODUCT_H

View File

@ -161,6 +161,19 @@ template<typename T> struct ei_plain_matrix_type_column_major
> type;
};
/* ei_plain_matrix_type_row_major : same as ei_plain_matrix_type but guaranteed to be row-major
*/
template<typename T> struct ei_plain_matrix_type_row_major
{
typedef Matrix<typename ei_traits<T>::Scalar,
ei_traits<T>::RowsAtCompileTime,
ei_traits<T>::ColsAtCompileTime,
AutoAlign | RowMajor,
ei_traits<T>::MaxRowsAtCompileTime,
ei_traits<T>::MaxColsAtCompileTime
> type;
};
template<typename T> struct ei_must_nest_by_value { enum { ret = false }; };
template<typename T> struct ei_must_nest_by_value<NestByValue<T> > { enum { ret = true }; };

View File

@ -77,6 +77,7 @@ template<typename MatrixType> void product(const MatrixType& m)
// begin testing Product.h: only associativity for now
// (we use Transpose.h but this doesn't count as a test for it)
VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2));
m3 = m1;
m3 *= m1.transpose() * m2;