diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index f9f07d41e..f88793ef2 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -999,8 +999,9 @@ class TensorBase } // Returns a formatted tensor ready for printing to a stream - inline const TensorWithFormat format(const TensorIOFormat& fmt) const { - return TensorWithFormat(derived(), fmt); + template + inline const TensorWithFormat format(const Format& fmt) const { + return TensorWithFormat(derived(), fmt); } #ifdef EIGEN_READONLY_TENSORBASE_PLUGIN diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h index 985e003e1..56c497cec 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h @@ -18,33 +18,24 @@ namespace Eigen { struct TensorIOFormat; namespace internal { -template +template struct TensorPrinter; } -struct TensorIOFormat { - TensorIOFormat(const std::vector& _separator, const std::vector& _prefix, - const std::vector& _suffix, int _precision = StreamPrecision, int _flags = 0, - const std::string& _tenPrefix = "", const std::string& _tenSuffix = "", const char _fill = ' ') - : tenPrefix(_tenPrefix), - tenSuffix(_tenSuffix), - prefix(_prefix), - suffix(_suffix), - separator(_separator), - fill(_fill), - precision(_precision), - flags(_flags) { - init_spacer(); - } - - TensorIOFormat(int _precision = StreamPrecision, int _flags = 0, const std::string& _tenPrefix = "", - const std::string& _tenSuffix = "", const char _fill = ' ') - : tenPrefix(_tenPrefix), tenSuffix(_tenSuffix), fill(_fill), precision(_precision), flags(_flags) { - // default values of prefix, suffix and separator - prefix = {"", "["}; - suffix = {"", "]"}; - separator = {", ", "\n"}; - +template +struct TensorIOFormatBase { + using Derived = Derived_; + TensorIOFormatBase(const std::vector& separator, const std::vector& prefix, + const std::vector& suffix, int precision = StreamPrecision, int flags = 0, + const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ') + : tenPrefix(tenPrefix), + tenSuffix(tenSuffix), + prefix(prefix), + suffix(suffix), + separator(separator), + fill(fill), + precision(precision), + flags(flags) { init_spacer(); } @@ -67,33 +58,6 @@ struct TensorIOFormat { } } - static inline const TensorIOFormat Numpy() { - std::vector prefix = {"", "["}; - std::vector suffix = {"", "]"}; - std::vector separator = {" ", "\n"}; - return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "[", "]"); - } - - static inline const TensorIOFormat Plain() { - std::vector separator = {" ", "\n", "\n", ""}; - std::vector prefix = {""}; - std::vector suffix = {""}; - return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "", "", ' '); - } - - static inline const TensorIOFormat Native() { - std::vector separator = {", ", ",\n", "\n"}; - std::vector prefix = {"", "{"}; - std::vector suffix = {"", "}"}; - return TensorIOFormat(separator, prefix, suffix, StreamPrecision, 0, "{", "}", ' '); - } - - static inline const TensorIOFormat Legacy() { - TensorIOFormat LegacyFormat(StreamPrecision, 0, "", "", ' '); - LegacyFormat.legacy_bit = true; - return LegacyFormat; - } - std::string tenPrefix; std::string tenSuffix; std::vector prefix; @@ -103,24 +67,67 @@ struct TensorIOFormat { int precision; int flags; std::vector spacer{}; - bool legacy_bit = false; }; -template +struct TensorIOFormatNumpy : public TensorIOFormatBase { + using Base = TensorIOFormatBase; + TensorIOFormatNumpy() + : Base(/*separator=*/{" ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision, + /*flags=*/0, /*tenPrefix=*/"[", /*tenSuffix=*/"]") {} +}; + +struct TensorIOFormatNative : public TensorIOFormatBase { + using Base = TensorIOFormatBase; + TensorIOFormatNative() + : Base(/*separator=*/{", ", ",\n", "\n"}, /*prefix=*/{"", "{"}, /*suffix=*/{"", "}"}, + /*precision=*/StreamPrecision, /*flags=*/0, /*tenPrefix=*/"{", /*tenSuffix=*/"}") {} +}; + +struct TensorIOFormatPlain : public TensorIOFormatBase { + using Base = TensorIOFormatBase; + TensorIOFormatPlain() + : Base(/*separator=*/{" ", "\n", "\n", ""}, /*prefix=*/{""}, /*suffix=*/{""}, /*precision=*/StreamPrecision, + /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {} +}; + +struct TensorIOFormatLegacy : public TensorIOFormatBase { + using Base = TensorIOFormatBase; + TensorIOFormatLegacy() + : Base(/*separator=*/{", ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision, + /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {} +}; + +struct TensorIOFormat : public TensorIOFormatBase { + using Base = TensorIOFormatBase; + TensorIOFormat(const std::vector& separator, const std::vector& prefix, + const std::vector& suffix, int precision = StreamPrecision, int flags = 0, + const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ') + : Base(separator, prefix, suffix, precision, flags, tenPrefix, tenSuffix, fill) {} + + static inline const TensorIOFormatNumpy Numpy() { return TensorIOFormatNumpy{}; } + + static inline const TensorIOFormatPlain Plain() { return TensorIOFormatPlain{}; } + + static inline const TensorIOFormatNative Native() { return TensorIOFormatNative{}; } + + static inline const TensorIOFormatLegacy Legacy() { return TensorIOFormatLegacy{}; } +}; + +template class TensorWithFormat; // specialize for Layout=ColMajor, Layout=RowMajor and rank=0. -template -class TensorWithFormat { +template +class TensorWithFormat { public: - TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {} + TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {} - friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat& wf) { + friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat& wf) { // Evaluate the expression if needed typedef TensorEvaluator, DefaultDevice> Evaluator; TensorForcedEvalOp eval = wf.t_tensor.eval(); Evaluator tensor(eval, DefaultDevice()); tensor.evalSubExprsIfNeeded(NULL); - internal::TensorPrinter::run(os, tensor, wf.t_format); + internal::TensorPrinter::run(os, tensor, wf.t_format); // Cleanup. tensor.cleanup(); return os; @@ -128,15 +135,15 @@ class TensorWithFormat { protected: T t_tensor; - TensorIOFormat t_format; + Format t_format; }; -template -class TensorWithFormat { +template +class TensorWithFormat { public: - TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {} + TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {} - friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat& wf) { + friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat& wf) { // Switch to RowMajor storage and print afterwards typedef typename T::Index IndexType; std::array shuffle; @@ -150,7 +157,7 @@ class TensorWithFormat { TensorForcedEvalOp eval = tensor_row_major.eval(); Evaluator tensor(eval, DefaultDevice()); tensor.evalSubExprsIfNeeded(NULL); - internal::TensorPrinter::run(os, tensor, wf.t_format); + internal::TensorPrinter::run(os, tensor, wf.t_format); // Cleanup. tensor.cleanup(); return os; @@ -158,21 +165,21 @@ class TensorWithFormat { protected: T t_tensor; - TensorIOFormat t_format; + Format t_format; }; -template -class TensorWithFormat { +template +class TensorWithFormat { public: - TensorWithFormat(const T& tensor, const TensorIOFormat& format) : t_tensor(tensor), t_format(format) {} + TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {} - friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat& wf) { + friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat& wf) { // Evaluate the expression if needed typedef TensorEvaluator, DefaultDevice> Evaluator; TensorForcedEvalOp eval = wf.t_tensor.eval(); Evaluator tensor(eval, DefaultDevice()); tensor.evalSubExprsIfNeeded(NULL); - internal::TensorPrinter::run(os, tensor, wf.t_format); + internal::TensorPrinter::run(os, tensor, wf.t_format); // Cleanup. tensor.cleanup(); return os; @@ -180,27 +187,39 @@ class TensorWithFormat { protected: T t_tensor; - TensorIOFormat t_format; + Format t_format; }; namespace internal { -template + +// Default scalar printer. +template +struct ScalarPrinter { + static void run(std::ostream& stream, const Scalar& scalar, const Format& fmt) { stream << scalar; } +}; + +template +struct ScalarPrinter::IsComplex>> { + static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNumpy& fmt) { + stream << numext::real(scalar) << "+" << numext::imag(scalar) << "j"; + } +}; + +template +struct ScalarPrinter::IsComplex>> { + static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNative& fmt) { + stream << "{" << numext::real(scalar) << ", " << numext::imag(scalar) << "}"; + } +}; + +template struct TensorPrinter { - static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) { - typedef std::remove_const_t Scalar; + using Scalar = std::remove_const_t; + using ScalarPrinter = ScalarPrinter; + + static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) { typedef typename Tensor::Index IndexType; static const int layout = Tensor::Layout; - // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x - // (dim(1)*dim(2)*...*dim(rank-1)). - if (fmt.legacy_bit) { - const IndexType total_size = internal::array_prod(_t.dimensions()); - if (total_size > 0) { - const IndexType first_dim = Eigen::internal::array_get<0>(_t.dimensions()); - Map> matrix(_t.data(), first_dim, total_size / first_dim); - s << matrix; - return; - } - } eigen_assert(layout == RowMajor); typedef std::conditional_t::value || is_same::value || @@ -213,7 +232,7 @@ struct TensorPrinter { std::complex, const Scalar&>> PrintType; - const IndexType total_size = array_prod(_t.dimensions()); + const IndexType total_size = array_prod(tensor.dimensions()); std::streamsize explicit_precision; if (fmt.precision == StreamPrecision) { @@ -232,20 +251,16 @@ struct TensorPrinter { if (explicit_precision) old_precision = s.precision(explicit_precision); IndexType width = 0; - bool align_cols = !(fmt.flags & DontAlignCols); if (align_cols) { // compute the largest width for (IndexType i = 0; i < total_size; i++) { std::stringstream sstr; sstr.copyfmt(s); - sstr << static_cast(_t.data()[i]); + ScalarPrinter::run(sstr, static_cast(tensor.data()[i]), fmt); width = std::max(width, IndexType(sstr.str().length())); } } - std::streamsize old_width = s.width(); - char old_fill_character = s.fill(); - s << fmt.tenPrefix; for (IndexType i = 0; i < total_size; i++) { std::array is_at_end{}; @@ -253,7 +268,7 @@ struct TensorPrinter { // is the ith element the end of an coeff (always true), of a row, of a matrix, ...? for (std::size_t k = 0; k < rank; k++) { - if ((i + 1) % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1, + if ((i + 1) % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1, std::multiplies())) == 0) { is_at_end[k] = true; @@ -262,7 +277,7 @@ struct TensorPrinter { // is the ith element the begin of an coeff (always true), of a row, of a matrix, ...? for (std::size_t k = 0; k < rank; k++) { - if (i % (std::accumulate(_t.dimensions().rbegin(), _t.dimensions().rbegin() + k, 1, + if (i % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1, std::multiplies())) == 0) { is_at_begin[k] = true; @@ -318,12 +333,20 @@ struct TensorPrinter { } s << prefix.str(); - if (width) { - s.fill(fmt.fill); - s.width(width); - s << std::right; + // So we don't mess around with formatting, output scalar to a string stream, and adjust the width/fill manually. + std::stringstream sstr; + sstr.copyfmt(s); + ScalarPrinter::run(sstr, static_cast(tensor.data()[i]), fmt); + std::string scalar_str = sstr.str(); + IndexType scalar_width = scalar_str.length(); + if (width && scalar_width < width) { + std::string filler; + for (IndexType i = scalar_width; i < width; ++i) { + filler.push_back(fmt.fill); + } + s << filler; } - s << _t.data()[i]; + s << scalar_str; s << suffix.str(); if (i < total_size - 1) { s << separator.str(); @@ -331,17 +354,35 @@ struct TensorPrinter { } s << fmt.tenSuffix; if (explicit_precision) s.precision(old_precision); - if (width) { - s.fill(old_fill_character); - s.width(old_width); + } +}; + +template +struct TensorPrinter> { + using Format = TensorIOFormatLegacy; + using Scalar = std::remove_const_t; + using ScalarPrinter = ScalarPrinter; + + static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) { + typedef typename Tensor::Index IndexType; + static const int layout = Tensor::Layout; + // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x + // (dim(1)*dim(2)*...*dim(rank-1)). + const IndexType total_size = internal::array_prod(tensor.dimensions()); + if (total_size > 0) { + const IndexType first_dim = Eigen::internal::array_get<0>(tensor.dimensions()); + Map> matrix(tensor.data(), first_dim, total_size / first_dim); + s << matrix; + return; } } }; -template -struct TensorPrinter { - static void run(std::ostream& s, const Tensor& _t, const TensorIOFormat& fmt) { - typedef typename Tensor::Scalar Scalar; +template +struct TensorPrinter { + static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) { + using Scalar = std::remove_const_t; + using ScalarPrinter = ScalarPrinter; std::streamsize explicit_precision; if (fmt.precision == StreamPrecision) { @@ -358,8 +399,9 @@ struct TensorPrinter { std::streamsize old_precision = 0; if (explicit_precision) old_precision = s.precision(explicit_precision); - - s << fmt.tenPrefix << _t.coeff(0) << fmt.tenSuffix; + s << fmt.tenPrefix; + ScalarPrinter::run(s, tensor.coeff(0), fmt); + s << fmt.tenSuffix; if (explicit_precision) s.precision(old_precision); } }; diff --git a/unsupported/test/cxx11_tensor_io.cpp b/unsupported/test/cxx11_tensor_io.cpp index 16285c13a..27b32305b 100644 --- a/unsupported/test/cxx11_tensor_io.cpp +++ b/unsupported/test/cxx11_tensor_io.cpp @@ -82,6 +82,16 @@ struct test_tensor_ostream_impl, 2, Layout> { std::ostringstream os; os << t.format(Eigen::TensorIOFormat::Plain()); VERIFY(os.str() == " (1,2) (12,3)\n(-4,2) (0,5)\n(-1,4) (5,27)"); + + os.str(""); + os.clear(); + os << t.format(Eigen::TensorIOFormat::Numpy()); + VERIFY(os.str() == "[[ 1+2j 12+3j]\n [-4+2j 0+5j]\n [-1+4j 5+27j]]"); + + os.str(""); + os.clear(); + os << t.format(Eigen::TensorIOFormat::Native()); + VERIFY(os.str() == "{{ {1, 2}, {12, 3}},\n {{-4, 2}, {0, 5}},\n {{-1, 4}, {5, 27}}}"); } };