mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Add custom formatting of complex numbers for Numpy/Native.
This commit is contained in:
parent
5570a27869
commit
9f77ce4f19
@ -999,8 +999,9 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Returns a formatted tensor ready for printing to a stream
|
// Returns a formatted tensor ready for printing to a stream
|
||||||
inline const TensorWithFormat<Derived,DerivedTraits::Layout,DerivedTraits::NumDimensions> format(const TensorIOFormat& fmt) const {
|
template<typename Format>
|
||||||
return TensorWithFormat<Derived,DerivedTraits::Layout,DerivedTraits::NumDimensions>(derived(), fmt);
|
inline const TensorWithFormat<Derived,DerivedTraits::Layout,DerivedTraits::NumDimensions, Format> format(const Format& fmt) const {
|
||||||
|
return TensorWithFormat<Derived,DerivedTraits::Layout,DerivedTraits::NumDimensions, Format>(derived(), fmt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef EIGEN_READONLY_TENSORBASE_PLUGIN
|
#ifdef EIGEN_READONLY_TENSORBASE_PLUGIN
|
||||||
|
@ -18,33 +18,24 @@ namespace Eigen {
|
|||||||
struct TensorIOFormat;
|
struct TensorIOFormat;
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template <typename Tensor, std::size_t rank>
|
template <typename Tensor, std::size_t rank, typename Format, typename EnableIf = void>
|
||||||
struct TensorPrinter;
|
struct TensorPrinter;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TensorIOFormat {
|
template <typename Derived_>
|
||||||
TensorIOFormat(const std::vector<std::string>& _separator, const std::vector<std::string>& _prefix,
|
struct TensorIOFormatBase {
|
||||||
const std::vector<std::string>& _suffix, int _precision = StreamPrecision, int _flags = 0,
|
using Derived = Derived_;
|
||||||
const std::string& _tenPrefix = "", const std::string& _tenSuffix = "", const char _fill = ' ')
|
TensorIOFormatBase(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
|
||||||
: tenPrefix(_tenPrefix),
|
const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
|
||||||
tenSuffix(_tenSuffix),
|
const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
|
||||||
prefix(_prefix),
|
: tenPrefix(tenPrefix),
|
||||||
suffix(_suffix),
|
tenSuffix(tenSuffix),
|
||||||
separator(_separator),
|
prefix(prefix),
|
||||||
fill(_fill),
|
suffix(suffix),
|
||||||
precision(_precision),
|
separator(separator),
|
||||||
flags(_flags) {
|
fill(fill),
|
||||||
init_spacer();
|
precision(precision),
|
||||||
}
|
flags(flags) {
|
||||||
|
|
||||||
TensorIOFormat(int _precision = StreamPrecision, int _flags = 0, const std::string& _tenPrefix = "",
|
|
||||||
const std::string& _tenSuffix = "", const char _fill = ' ')
|
|
||||||
: tenPrefix(_tenPrefix), tenSuffix(_tenSuffix), fill(_fill), precision(_precision), flags(_flags) {
|
|
||||||
// default values of prefix, suffix and separator
|
|
||||||
prefix = {"", "["};
|
|
||||||
suffix = {"", "]"};
|
|
||||||
separator = {", ", "\n"};
|
|
||||||
|
|
||||||
init_spacer();
|
init_spacer();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,33 +58,6 @@ struct TensorIOFormat {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline const TensorIOFormat Numpy() {
|
|
||||||
std::vector<std::string> prefix = {"", "["};
|
|
||||||
std::vector<std::string> suffix = {"", "]"};
|
|
||||||
std::vector<std::string> separator = {" ", "\n"};
|
|
||||||
return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "[", "]");
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline const TensorIOFormat Plain() {
|
|
||||||
std::vector<std::string> separator = {" ", "\n", "\n", ""};
|
|
||||||
std::vector<std::string> prefix = {""};
|
|
||||||
std::vector<std::string> suffix = {""};
|
|
||||||
return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "", "", ' ');
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline const TensorIOFormat Native() {
|
|
||||||
std::vector<std::string> separator = {", ", ",\n", "\n"};
|
|
||||||
std::vector<std::string> prefix = {"", "{"};
|
|
||||||
std::vector<std::string> suffix = {"", "}"};
|
|
||||||
return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "{", "}", ' ');
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline const TensorIOFormat Legacy() {
|
|
||||||
TensorIOFormat LegacyFormat(StreamPrecision, 0, "", "", ' ');
|
|
||||||
LegacyFormat.legacy_bit = true;
|
|
||||||
return LegacyFormat;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string tenPrefix;
|
std::string tenPrefix;
|
||||||
std::string tenSuffix;
|
std::string tenSuffix;
|
||||||
std::vector<std::string> prefix;
|
std::vector<std::string> prefix;
|
||||||
@ -103,24 +67,67 @@ struct TensorIOFormat {
|
|||||||
int precision;
|
int precision;
|
||||||
int flags;
|
int flags;
|
||||||
std::vector<std::string> spacer{};
|
std::vector<std::string> spacer{};
|
||||||
bool legacy_bit = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int Layout, int rank>
|
struct TensorIOFormatNumpy : public TensorIOFormatBase<TensorIOFormatNumpy> {
|
||||||
|
using Base = TensorIOFormatBase<TensorIOFormatNumpy>;
|
||||||
|
TensorIOFormatNumpy()
|
||||||
|
: Base(/*separator=*/{" ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
|
||||||
|
/*flags=*/0, /*tenPrefix=*/"[", /*tenSuffix=*/"]") {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TensorIOFormatNative : public TensorIOFormatBase<TensorIOFormatNative> {
|
||||||
|
using Base = TensorIOFormatBase<TensorIOFormatNative>;
|
||||||
|
TensorIOFormatNative()
|
||||||
|
: Base(/*separator=*/{", ", ",\n", "\n"}, /*prefix=*/{"", "{"}, /*suffix=*/{"", "}"},
|
||||||
|
/*precision=*/StreamPrecision, /*flags=*/0, /*tenPrefix=*/"{", /*tenSuffix=*/"}") {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TensorIOFormatPlain : public TensorIOFormatBase<TensorIOFormatPlain> {
|
||||||
|
using Base = TensorIOFormatBase<TensorIOFormatPlain>;
|
||||||
|
TensorIOFormatPlain()
|
||||||
|
: Base(/*separator=*/{" ", "\n", "\n", ""}, /*prefix=*/{""}, /*suffix=*/{""}, /*precision=*/StreamPrecision,
|
||||||
|
/*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TensorIOFormatLegacy : public TensorIOFormatBase<TensorIOFormatLegacy> {
|
||||||
|
using Base = TensorIOFormatBase<TensorIOFormatLegacy>;
|
||||||
|
TensorIOFormatLegacy()
|
||||||
|
: Base(/*separator=*/{", ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
|
||||||
|
/*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TensorIOFormat : public TensorIOFormatBase<TensorIOFormat> {
|
||||||
|
using Base = TensorIOFormatBase<TensorIOFormat>;
|
||||||
|
TensorIOFormat(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
|
||||||
|
const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
|
||||||
|
const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
|
||||||
|
: Base(separator, prefix, suffix, precision, flags, tenPrefix, tenSuffix, fill) {}
|
||||||
|
|
||||||
|
static inline const TensorIOFormatNumpy Numpy() { return TensorIOFormatNumpy{}; }
|
||||||
|
|
||||||
|
static inline const TensorIOFormatPlain Plain() { return TensorIOFormatPlain{}; }
|
||||||
|
|
||||||
|
static inline const TensorIOFormatNative Native() { return TensorIOFormatNative{}; }
|
||||||
|
|
||||||
|
static inline const TensorIOFormatLegacy Legacy() { return TensorIOFormatLegacy{}; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, int Layout, int rank, typename Format>
|
||||||
class TensorWithFormat;
|
class TensorWithFormat;
|
||||||
// specialize for Layout=ColMajor, Layout=RowMajor and rank=0.
|
// specialize for Layout=ColMajor, Layout=RowMajor and rank=0.
|
||||||
template <typename T, int rank>
|
template <typename T, int rank, typename Format>
|
||||||
class TensorWithFormat<T, RowMajor, rank> {
|
class TensorWithFormat<T, RowMajor, rank, Format> {
|
||||||
public:
|
public:
|
||||||
TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
|
TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank>& wf) {
|
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank, Format>& wf) {
|
||||||
// Evaluate the expression if needed
|
// Evaluate the expression if needed
|
||||||
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
|
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
|
||||||
TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
|
TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
|
||||||
Evaluator tensor(eval, DefaultDevice());
|
Evaluator tensor(eval, DefaultDevice());
|
||||||
tensor.evalSubExprsIfNeeded(NULL);
|
tensor.evalSubExprsIfNeeded(NULL);
|
||||||
internal::TensorPrinter<Evaluator, rank>::run(os, tensor, wf.t_format);
|
internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
|
||||||
// Cleanup.
|
// Cleanup.
|
||||||
tensor.cleanup();
|
tensor.cleanup();
|
||||||
return os;
|
return os;
|
||||||
@ -128,15 +135,15 @@ class TensorWithFormat<T, RowMajor, rank> {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
T t_tensor;
|
T t_tensor;
|
||||||
TensorIOFormat t_format;
|
Format t_format;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int rank>
|
template <typename T, int rank, typename Format>
|
||||||
class TensorWithFormat<T, ColMajor, rank> {
|
class TensorWithFormat<T, ColMajor, rank, Format> {
|
||||||
public:
|
public:
|
||||||
TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
|
TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank>& wf) {
|
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank, Format>& wf) {
|
||||||
// Switch to RowMajor storage and print afterwards
|
// Switch to RowMajor storage and print afterwards
|
||||||
typedef typename T::Index IndexType;
|
typedef typename T::Index IndexType;
|
||||||
std::array<IndexType, rank> shuffle;
|
std::array<IndexType, rank> shuffle;
|
||||||
@ -150,7 +157,7 @@ class TensorWithFormat<T, ColMajor, rank> {
|
|||||||
TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
|
TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
|
||||||
Evaluator tensor(eval, DefaultDevice());
|
Evaluator tensor(eval, DefaultDevice());
|
||||||
tensor.evalSubExprsIfNeeded(NULL);
|
tensor.evalSubExprsIfNeeded(NULL);
|
||||||
internal::TensorPrinter<Evaluator, rank>::run(os, tensor, wf.t_format);
|
internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
|
||||||
// Cleanup.
|
// Cleanup.
|
||||||
tensor.cleanup();
|
tensor.cleanup();
|
||||||
return os;
|
return os;
|
||||||
@ -158,21 +165,21 @@ class TensorWithFormat<T, ColMajor, rank> {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
T t_tensor;
|
T t_tensor;
|
||||||
TensorIOFormat t_format;
|
Format t_format;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename Format>
|
||||||
class TensorWithFormat<T, ColMajor, 0> {
|
class TensorWithFormat<T, ColMajor, 0, Format> {
|
||||||
public:
|
public:
|
||||||
TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {}
|
TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
|
||||||
|
|
||||||
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0>& wf) {
|
friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0, Format>& wf) {
|
||||||
// Evaluate the expression if needed
|
// Evaluate the expression if needed
|
||||||
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
|
typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
|
||||||
TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
|
TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
|
||||||
Evaluator tensor(eval, DefaultDevice());
|
Evaluator tensor(eval, DefaultDevice());
|
||||||
tensor.evalSubExprsIfNeeded(NULL);
|
tensor.evalSubExprsIfNeeded(NULL);
|
||||||
internal::TensorPrinter<Evaluator, 0>::run(os, tensor, wf.t_format);
|
internal::TensorPrinter<Evaluator, 0, Format>::run(os, tensor, wf.t_format);
|
||||||
// Cleanup.
|
// Cleanup.
|
||||||
tensor.cleanup();
|
tensor.cleanup();
|
||||||
return os;
|
return os;
|
||||||
@ -180,27 +187,39 @@ class TensorWithFormat<T, ColMajor, 0> {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
T t_tensor;
|
T t_tensor;
|
||||||
TensorIOFormat t_format;
|
Format t_format;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template <typename Tensor, std::size_t rank>
|
|
||||||
|
// Default scalar printer.
|
||||||
|
template <typename Scalar, typename Format, typename EnableIf = void>
|
||||||
|
struct ScalarPrinter {
|
||||||
|
static void run(std::ostream& stream, const Scalar& scalar, const Format& fmt) { stream << scalar; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct ScalarPrinter<Scalar, TensorIOFormatNumpy, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
|
||||||
|
static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNumpy& fmt) {
|
||||||
|
stream << numext::real(scalar) << "+" << numext::imag(scalar) << "j";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct ScalarPrinter<Scalar, TensorIOFormatNative, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
|
||||||
|
static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNative& fmt) {
|
||||||
|
stream << "{" << numext::real(scalar) << ", " << numext::imag(scalar) << "}";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Tensor, std::size_t rank, typename Format, typename EnableIf>
|
||||||
struct TensorPrinter {
|
struct TensorPrinter {
|
||||||
static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
|
using Scalar = std::remove_const_t<typename Tensor::Scalar>;
|
||||||
typedef std::remove_const_t<typename Tensor::Scalar> Scalar;
|
using ScalarPrinter = ScalarPrinter<Scalar, Format>;
|
||||||
|
|
||||||
|
static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
|
||||||
typedef typename Tensor::Index IndexType;
|
typedef typename Tensor::Index IndexType;
|
||||||
static const int layout = Tensor::Layout;
|
static const int layout = Tensor::Layout;
|
||||||
// backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
|
|
||||||
// (dim(1)*dim(2)*...*dim(rank-1)).
|
|
||||||
if (fmt.legacy_bit) {
|
|
||||||
const IndexType total_size = internal::array_prod(_t.dimensions());
|
|
||||||
if (total_size > 0) {
|
|
||||||
const IndexType first_dim = Eigen::internal::array_get<0>(_t.dimensions());
|
|
||||||
Map<const Array<Scalar, Dynamic, Dynamic, layout>> matrix(_t.data(), first_dim, total_size / first_dim);
|
|
||||||
s << matrix;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
eigen_assert(layout == RowMajor);
|
eigen_assert(layout == RowMajor);
|
||||||
typedef std::conditional_t<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
|
typedef std::conditional_t<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
|
||||||
@ -213,7 +232,7 @@ struct TensorPrinter {
|
|||||||
std::complex<int>, const Scalar&>>
|
std::complex<int>, const Scalar&>>
|
||||||
PrintType;
|
PrintType;
|
||||||
|
|
||||||
const IndexType total_size = array_prod(_t.dimensions());
|
const IndexType total_size = array_prod(tensor.dimensions());
|
||||||
|
|
||||||
std::streamsize explicit_precision;
|
std::streamsize explicit_precision;
|
||||||
if (fmt.precision == StreamPrecision) {
|
if (fmt.precision == StreamPrecision) {
|
||||||
@ -232,20 +251,16 @@ struct TensorPrinter {
|
|||||||
if (explicit_precision) old_precision = s.precision(explicit_precision);
|
if (explicit_precision) old_precision = s.precision(explicit_precision);
|
||||||
|
|
||||||
IndexType width = 0;
|
IndexType width = 0;
|
||||||
|
|
||||||
bool align_cols = !(fmt.flags & DontAlignCols);
|
bool align_cols = !(fmt.flags & DontAlignCols);
|
||||||
if (align_cols) {
|
if (align_cols) {
|
||||||
// compute the largest width
|
// compute the largest width
|
||||||
for (IndexType i = 0; i < total_size; i++) {
|
for (IndexType i = 0; i < total_size; i++) {
|
||||||
std::stringstream sstr;
|
std::stringstream sstr;
|
||||||
sstr.copyfmt(s);
|
sstr.copyfmt(s);
|
||||||
sstr << static_cast<PrintType>(_t.data()[i]);
|
ScalarPrinter::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
|
||||||
width = std::max<IndexType>(width, IndexType(sstr.str().length()));
|
width = std::max<IndexType>(width, IndexType(sstr.str().length()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::streamsize old_width = s.width();
|
|
||||||
char old_fill_character = s.fill();
|
|
||||||
|
|
||||||
s << fmt.tenPrefix;
|
s << fmt.tenPrefix;
|
||||||
for (IndexType i = 0; i < total_size; i++) {
|
for (IndexType i = 0; i < total_size; i++) {
|
||||||
std::array<bool, rank> is_at_end{};
|
std::array<bool, rank> is_at_end{};
|
||||||
@ -253,7 +268,7 @@ struct TensorPrinter {
|
|||||||
|
|
||||||
// is the ith element the end of an coeff (always true), of a row, of a matrix, ...?
|
// is the ith element the end of an coeff (always true), of a row, of a matrix, ...?
|
||||||
for (std::size_t k = 0; k < rank; k++) {
|
for (std::size_t k = 0; k < rank; k++) {
|
||||||
if ((i + 1) % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1,
|
if ((i + 1) % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
|
||||||
std::multiplies<IndexType>())) ==
|
std::multiplies<IndexType>())) ==
|
||||||
0) {
|
0) {
|
||||||
is_at_end[k] = true;
|
is_at_end[k] = true;
|
||||||
@ -262,7 +277,7 @@ struct TensorPrinter {
|
|||||||
|
|
||||||
// is the ith element the begin of an coeff (always true), of a row, of a matrix, ...?
|
// is the ith element the begin of an coeff (always true), of a row, of a matrix, ...?
|
||||||
for (std::size_t k = 0; k < rank; k++) {
|
for (std::size_t k = 0; k < rank; k++) {
|
||||||
if (i % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1,
|
if (i % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
|
||||||
std::multiplies<IndexType>())) ==
|
std::multiplies<IndexType>())) ==
|
||||||
0) {
|
0) {
|
||||||
is_at_begin[k] = true;
|
is_at_begin[k] = true;
|
||||||
@ -318,12 +333,20 @@ struct TensorPrinter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s << prefix.str();
|
s << prefix.str();
|
||||||
if (width) {
|
// So we don't mess around with formatting, output scalar to a string stream, and adjust the width/fill manually.
|
||||||
s.fill(fmt.fill);
|
std::stringstream sstr;
|
||||||
s.width(width);
|
sstr.copyfmt(s);
|
||||||
s << std::right;
|
ScalarPrinter::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
|
||||||
|
std::string scalar_str = sstr.str();
|
||||||
|
IndexType scalar_width = scalar_str.length();
|
||||||
|
if (width && scalar_width < width) {
|
||||||
|
std::string filler;
|
||||||
|
for (IndexType i = scalar_width; i < width; ++i) {
|
||||||
|
filler.push_back(fmt.fill);
|
||||||
|
}
|
||||||
|
s << filler;
|
||||||
}
|
}
|
||||||
s << _t.data()[i];
|
s << scalar_str;
|
||||||
s << suffix.str();
|
s << suffix.str();
|
||||||
if (i < total_size - 1) {
|
if (i < total_size - 1) {
|
||||||
s << separator.str();
|
s << separator.str();
|
||||||
@ -331,17 +354,35 @@ struct TensorPrinter {
|
|||||||
}
|
}
|
||||||
s << fmt.tenSuffix;
|
s << fmt.tenSuffix;
|
||||||
if (explicit_precision) s.precision(old_precision);
|
if (explicit_precision) s.precision(old_precision);
|
||||||
if (width) {
|
}
|
||||||
s.fill(old_fill_character);
|
};
|
||||||
s.width(old_width);
|
|
||||||
|
template <typename Tensor, std::size_t rank>
|
||||||
|
struct TensorPrinter<Tensor, rank, TensorIOFormatLegacy, std::enable_if_t<rank != 0>> {
|
||||||
|
using Format = TensorIOFormatLegacy;
|
||||||
|
using Scalar = std::remove_const_t<typename Tensor::Scalar>;
|
||||||
|
using ScalarPrinter = ScalarPrinter<Scalar, Format>;
|
||||||
|
|
||||||
|
static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
|
||||||
|
typedef typename Tensor::Index IndexType;
|
||||||
|
static const int layout = Tensor::Layout;
|
||||||
|
// backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
|
||||||
|
// (dim(1)*dim(2)*...*dim(rank-1)).
|
||||||
|
const IndexType total_size = internal::array_prod(tensor.dimensions());
|
||||||
|
if (total_size > 0) {
|
||||||
|
const IndexType first_dim = Eigen::internal::array_get<0>(tensor.dimensions());
|
||||||
|
Map<const Array<Scalar, Dynamic, Dynamic, layout>> matrix(tensor.data(), first_dim, total_size / first_dim);
|
||||||
|
s << matrix;
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Tensor>
|
template <typename Tensor, typename Format>
|
||||||
struct TensorPrinter<Tensor, 0> {
|
struct TensorPrinter<Tensor, 0, Format> {
|
||||||
static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) {
|
static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
|
||||||
typedef typename Tensor::Scalar Scalar;
|
using Scalar = std::remove_const_t<typename Tensor::Scalar>;
|
||||||
|
using ScalarPrinter = ScalarPrinter<Scalar, Format>;
|
||||||
|
|
||||||
std::streamsize explicit_precision;
|
std::streamsize explicit_precision;
|
||||||
if (fmt.precision == StreamPrecision) {
|
if (fmt.precision == StreamPrecision) {
|
||||||
@ -358,8 +399,9 @@ struct TensorPrinter<Tensor, 0> {
|
|||||||
|
|
||||||
std::streamsize old_precision = 0;
|
std::streamsize old_precision = 0;
|
||||||
if (explicit_precision) old_precision = s.precision(explicit_precision);
|
if (explicit_precision) old_precision = s.precision(explicit_precision);
|
||||||
|
s << fmt.tenPrefix;
|
||||||
s << fmt.tenPrefix << _t.coeff(0) << fmt.tenSuffix;
|
ScalarPrinter::run(s, tensor.coeff(0), fmt);
|
||||||
|
s << fmt.tenSuffix;
|
||||||
if (explicit_precision) s.precision(old_precision);
|
if (explicit_precision) s.precision(old_precision);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -82,6 +82,16 @@ struct test_tensor_ostream_impl<std::complex<Scalar>, 2, Layout> {
|
|||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os << t.format(Eigen::TensorIOFormat::Plain());
|
os << t.format(Eigen::TensorIOFormat::Plain());
|
||||||
VERIFY(os.str() == " (1,2) (12,3)\n(-4,2) (0,5)\n(-1,4) (5,27)");
|
VERIFY(os.str() == " (1,2) (12,3)\n(-4,2) (0,5)\n(-1,4) (5,27)");
|
||||||
|
|
||||||
|
os.str("");
|
||||||
|
os.clear();
|
||||||
|
os << t.format(Eigen::TensorIOFormat::Numpy());
|
||||||
|
VERIFY(os.str() == "[[ 1+2j 12+3j]\n [-4+2j 0+5j]\n [-1+4j 5+27j]]");
|
||||||
|
|
||||||
|
os.str("");
|
||||||
|
os.clear();
|
||||||
|
os << t.format(Eigen::TensorIOFormat::Native());
|
||||||
|
VERIFY(os.str() == "{{ {1, 2}, {12, 3}},\n {{-4, 2}, {0, 5}},\n {{-1, 4}, {5, 27}}}");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user