fix a couple a issue with blas (new TRMM api, and enforece column major)

This commit is contained in:
Gael Guennebaud 2010-07-16 23:30:06 +02:00
parent f59226e901
commit cbd6fe323c
2 changed files with 10 additions and 9 deletions

View File

@ -86,15 +86,15 @@ enum
Conj = IsComplex Conj = IsComplex
}; };
typedef Map<Matrix<Scalar,Dynamic,Dynamic>, 0, OuterStride<Dynamic> > MatrixType; typedef Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > MatrixType;
typedef Map<Matrix<Scalar,Dynamic,1>, 0, InnerStride<Dynamic> > StridedVectorType; typedef Map<Matrix<Scalar,Dynamic,1>, 0, InnerStride<Dynamic> > StridedVectorType;
typedef Map<Matrix<Scalar,Dynamic,1> > CompactVectorType; typedef Map<Matrix<Scalar,Dynamic,1> > CompactVectorType;
template<typename T> template<typename T>
Map<Matrix<T,Dynamic,Dynamic>, 0, OuterStride<Dynamic> > Map<Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >
matrix(T* data, int rows, int cols, int stride) matrix(T* data, int rows, int cols, int stride)
{ {
return Map<Matrix<T,Dynamic,Dynamic>, 0, OuterStride<Dynamic> >(data, rows, cols, OuterStride<Dynamic>(stride)); return Map<Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride));
} }
template<typename T> template<typename T>

View File

@ -63,10 +63,8 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
if(beta!=Scalar(1)) if(beta!=Scalar(1))
{ {
if(beta==Scalar(0)) if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
matrix(c, *m, *n, *ldc).setZero(); else matrix(c, *m, *n, *ldc) *= beta;
else
matrix(c, *m, *n, *ldc) *= beta;
} }
ei_gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k); ei_gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k);
@ -207,14 +205,17 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
return 0; return 0;
} }
if(*m==0 || *n==0)
return 1;
// FIXME find a way to avoid this copy // FIXME find a way to avoid this copy
Matrix<Scalar,Dynamic,Dynamic> tmp = matrix(b,*m,*n,*ldb); Matrix<Scalar,Dynamic,Dynamic,ColMajor> tmp = matrix(b,*m,*n,*ldb);
matrix(b,*m,*n,*ldb).setZero(); matrix(b,*m,*n,*ldb).setZero();
if(SIDE(*side)==LEFT) if(SIDE(*side)==LEFT)
func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha); func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha);
else else
func[code](*n, *m, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha); func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha);
return 1; return 1;
} }