From ea9943352368b990d27ba22eb8670287cf96302d Mon Sep 17 00:00:00 2001 From: Christian Seiler Date: Wed, 4 Jun 2014 20:27:42 +0200 Subject: [PATCH] unsupported/TensorSymmetry: make symgroup construction autodetect number of indices When constructing a symmetry group, make the code automatically detect the number of indices required from the indices of the group's generators. Also, allow the symmetry group to be applied to lists of indices that are larger than the number of indices of the symmetry group. Before: SGroup<4, Symmetry<0, 1>, Symmetry<2,3>> group; group.apply(std::array{{0, 1, 2, 3}}, 0); After: SGroup, Symmetry<2,3>> group; group.apply(std::array{{0, 1, 2, 3}}, 0); group.apply(std::array{{0, 1, 2, 3, 4}}, 0); This should make the symmetry group easier to use - especially if one wants to reuse the same symmetry group for different tensors of maybe different rank. static/runtime asserts remain for the case where the length of the index list to which a symmetry group is to be applied is too small. --- .../src/TensorSymmetry/DynamicSymmetry.h | 42 +++++++++++----- .../CXX11/src/TensorSymmetry/StaticSymmetry.h | 50 +++++++++++-------- .../Eigen/CXX11/src/TensorSymmetry/Symmetry.h | 47 +++++++++++++---- unsupported/test/cxx11_tensor_symmetry.cpp | 26 +++++----- 4 files changed, 108 insertions(+), 57 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h index b5738b778..0329278a9 100644 --- a/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h +++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h @@ -15,7 +15,7 @@ namespace Eigen { class DynamicSGroup { public: - inline explicit DynamicSGroup(std::size_t numIndices) : m_numIndices(numIndices), m_elements(), m_generators(), m_globalFlags(0) { m_elements.push_back(ge(Generator(0, 0, 0))); } + inline explicit DynamicSGroup() : m_numIndices(1), m_elements(), m_generators(), m_globalFlags(0) { m_elements.push_back(ge(Generator(0, 0, 0))); } inline DynamicSGroup(const DynamicSGroup& o) : m_numIndices(o.m_numIndices), m_elements(o.m_elements), m_generators(o.m_generators), m_globalFlags(o.m_globalFlags) { } inline DynamicSGroup(DynamicSGroup&& o) : m_numIndices(o.m_numIndices), m_elements(), m_generators(o.m_generators), m_globalFlags(o.m_globalFlags) { std::swap(m_elements, o.m_elements); } inline DynamicSGroup& operator=(const DynamicSGroup& o) { m_numIndices = o.m_numIndices; m_elements = o.m_elements; m_generators = o.m_generators; m_globalFlags = o.m_globalFlags; return *this; } @@ -33,7 +33,7 @@ class DynamicSGroup template inline RV apply(const std::array& idx, RV initial, Args&&... args) const { - eigen_assert(N == m_numIndices); + eigen_assert(N >= m_numIndices && "Can only apply symmetry group to objects that have at least the required amount of indices."); for (std::size_t i = 0; i < size(); i++) initial = Op::run(h_permute(i, idx, typename internal::gen_numeric_list::type()), m_elements[i].flags, initial, std::forward(args)...); return initial; @@ -42,7 +42,7 @@ class DynamicSGroup template inline RV apply(const std::vector& idx, RV initial, Args&&... args) const { - eigen_assert(idx.size() == m_numIndices); + eigen_assert(idx.size() >= m_numIndices && "Can only apply symmetry group to objects that have at least the required amount of indices."); for (std::size_t i = 0; i < size(); i++) initial = Op::run(h_permute(i, idx), m_elements[i].flags, initial, std::forward(args)...); return initial; @@ -77,7 +77,7 @@ class DynamicSGroup template inline std::array h_permute(std::size_t which, const std::array& idx, internal::numeric_list) const { - return std::array{{ idx[m_elements[which].representation[n]]... }}; + return std::array{{ idx[n >= m_numIndices ? n : m_elements[which].representation[n]]... }}; } template @@ -87,6 +87,8 @@ class DynamicSGroup result.reserve(idx.size()); for (auto k : m_elements[which].representation) result.push_back(idx[k]); + for (std::size_t i = m_numIndices; i < idx.size(); i++) + result.push_back(idx[i]); return result; } @@ -135,18 +137,18 @@ class DynamicSGroup }; // dynamic symmetry group that auto-adds the template parameters in the constructor -template +template class DynamicSGroupFromTemplateArgs : public DynamicSGroup { public: - inline DynamicSGroupFromTemplateArgs() : DynamicSGroup(NumIndices) + inline DynamicSGroupFromTemplateArgs() : DynamicSGroup() { add_all(internal::type_list()); } inline DynamicSGroupFromTemplateArgs(DynamicSGroupFromTemplateArgs const& other) : DynamicSGroup(other) { } inline DynamicSGroupFromTemplateArgs(DynamicSGroupFromTemplateArgs&& other) : DynamicSGroup(other) { } - inline DynamicSGroupFromTemplateArgs& operator=(const DynamicSGroupFromTemplateArgs& o) { DynamicSGroup::operator=(o); return *this; } - inline DynamicSGroupFromTemplateArgs& operator=(DynamicSGroupFromTemplateArgs&& o) { DynamicSGroup::operator=(o); return *this; } + inline DynamicSGroupFromTemplateArgs& operator=(const DynamicSGroupFromTemplateArgs& o) { DynamicSGroup::operator=(o); return *this; } + inline DynamicSGroupFromTemplateArgs& operator=(DynamicSGroupFromTemplateArgs&& o) { DynamicSGroup::operator=(o); return *this; } private: template @@ -168,18 +170,32 @@ inline DynamicSGroup::GroupElement DynamicSGroup::mul(GroupElement g1, GroupElem GroupElement result; result.representation.reserve(m_numIndices); - for (std::size_t i = 0; i < m_numIndices; i++) - result.representation.push_back(g2.representation[g1.representation[i]]); + for (std::size_t i = 0; i < m_numIndices; i++) { + int v = g2.representation[g1.representation[i]]; + eigen_assert(v >= 0); + result.representation.push_back(v); + } result.flags = g1.flags ^ g2.flags; return result; } inline void DynamicSGroup::add(int one, int two, int flags) { - eigen_assert(one >= 0 && (std::size_t)one < m_numIndices); - eigen_assert(two >= 0 && (std::size_t)two < m_numIndices); + eigen_assert(one >= 0); + eigen_assert(two >= 0); eigen_assert(one != two); - Generator g{one, two ,flags}; + + if ((std::size_t)one >= m_numIndices || (std::size_t)two >= m_numIndices) { + std::size_t newNumIndices = (one > two) ? one : two + 1; + for (auto& gelem : m_elements) { + gelem.representation.reserve(newNumIndices); + for (std::size_t i = m_numIndices; i < newNumIndices; i++) + gelem.representation.push_back(i); + } + m_numIndices = newNumIndices; + } + + Generator g{one, two, flags}; GroupElement e = ge(g); /* special case for first generator */ diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h index c5a630105..0eb468fc0 100644 --- a/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h +++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h @@ -114,20 +114,24 @@ struct tensor_static_symgroup_equality template struct tensor_static_symgroup { - typedef StaticSGroup type; + typedef StaticSGroup type; constexpr static std::size_t size = type::static_size; }; -template -constexpr static inline std::array tensor_static_symgroup_index_permute(std::array idx, internal::numeric_list) +template +constexpr static inline std::array tensor_static_symgroup_index_permute(std::array idx, internal::numeric_list, internal::numeric_list) { - return {{ idx[ii]... }}; + return {{ idx[ii]..., idx[jj]... }}; } template static inline std::vector tensor_static_symgroup_index_permute(std::vector idx, internal::numeric_list) { - return {{ idx[ii]... }}; + std::vector result{{ idx[ii]... }}; + std::size_t target_size = idx.size(); + for (std::size_t i = result.size(); i < target_size; i++) + result.push_back(idx[i]); + return result; } template struct tensor_static_symgroup_do_apply; @@ -135,32 +139,35 @@ template struct tensor_static_symgroup_do_apply; template struct tensor_static_symgroup_do_apply> { - template - static inline RV run(const std::array& idx, RV initial, Args&&... args) + template + static inline RV run(const std::array& idx, RV initial, Args&&... args) { - initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward(args)...); - return tensor_static_symgroup_do_apply>::template run(idx, initial, args...); + static_assert(NumIndices >= SGNumIndices, "Can only apply symmetry group to objects that have at least the required amount of indices."); + typedef typename internal::gen_numeric_list::type remaining_indices; + initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices(), remaining_indices()), first::flags, initial, std::forward(args)...); + return tensor_static_symgroup_do_apply>::template run(idx, initial, args...); } - template + template static inline RV run(const std::vector& idx, RV initial, Args&&... args) { + eigen_assert(idx.size() >= SGNumIndices && "Can only apply symmetry group to objects that have at least the required amount of indices."); initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward(args)...); - return tensor_static_symgroup_do_apply>::template run(idx, initial, args...); + return tensor_static_symgroup_do_apply>::template run(idx, initial, args...); } }; template struct tensor_static_symgroup_do_apply> { - template - static inline RV run(const std::array&, RV initial, Args&&...) + template + static inline RV run(const std::array&, RV initial, Args&&...) { // do nothing return initial; } - template + template static inline RV run(const std::vector&, RV initial, Args&&...) { // do nothing @@ -170,9 +177,10 @@ struct tensor_static_symgroup_do_apply +template class StaticSGroup { + constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices::value; typedef internal::group_theory::enumerate_group_elements< internal::tensor_static_symgroup_multiply, internal::tensor_static_symgroup_equality, @@ -182,20 +190,20 @@ class StaticSGroup typedef typename group_elements::type ge; public: constexpr inline StaticSGroup() {} - constexpr inline StaticSGroup(const StaticSGroup&) {} - constexpr inline StaticSGroup(StaticSGroup&&) {} + constexpr inline StaticSGroup(const StaticSGroup&) {} + constexpr inline StaticSGroup(StaticSGroup&&) {} - template - static inline RV apply(const std::array& idx, RV initial, Args&&... args) + template + static inline RV apply(const std::array& idx, RV initial, Args&&... args) { - return internal::tensor_static_symgroup_do_apply::template run(idx, initial, args...); + return internal::tensor_static_symgroup_do_apply::template run(idx, initial, args...); } template static inline RV apply(const std::vector& idx, RV initial, Args&&... args) { eigen_assert(idx.size() == NumIndices); - return internal::tensor_static_symgroup_do_apply::template run(idx, initial, args...); + return internal::tensor_static_symgroup_do_apply::template run(idx, initial, args...); } constexpr static std::size_t static_size = ge::count; diff --git a/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h b/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h index f0813086a..f1ccc33ef 100644 --- a/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h +++ b/unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h @@ -30,6 +30,7 @@ template struct tenso template struct tensor_static_symgroup_if; template struct tensor_symmetry_calculate_flags; template struct tensor_symmetry_assign_value; +template struct tensor_symmetry_num_indices; } // end namespace internal @@ -94,7 +95,7 @@ class DynamicSGroup; * This class is a child class of DynamicSGroup. It uses the template arguments * specified to initialize itself. */ -template +template class DynamicSGroupFromTemplateArgs; /** \class StaticSGroup @@ -116,7 +117,7 @@ class DynamicSGroupFromTemplateArgs; * group becomes too large. (In that case, unrolling may not even be * beneficial.) */ -template +template class StaticSGroup; /** \class SGroup @@ -131,24 +132,50 @@ class StaticSGroup; * \sa StaticSGroup * \sa DynamicSGroup */ -template -class SGroup : public internal::tensor_symmetry_pre_analysis::root_type +template +class SGroup : public internal::tensor_symmetry_pre_analysis::value, Gen...>::root_type { public: + constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices::value; typedef typename internal::tensor_symmetry_pre_analysis::root_type Base; // make standard constructors + assignment operators public inline SGroup() : Base() { } - inline SGroup(const SGroup& other) : Base(other) { } - inline SGroup(SGroup&& other) : Base(other) { } - inline SGroup& operator=(const SGroup& other) { Base::operator=(other); return *this; } - inline SGroup& operator=(SGroup&& other) { Base::operator=(other); return *this; } + inline SGroup(const SGroup& other) : Base(other) { } + inline SGroup(SGroup&& other) : Base(other) { } + inline SGroup& operator=(const SGroup& other) { Base::operator=(other); return *this; } + inline SGroup& operator=(SGroup&& other) { Base::operator=(other); return *this; } // all else is defined in the base class }; namespace internal { +template struct tensor_symmetry_num_indices +{ + constexpr static std::size_t value = 1; +}; + +template struct tensor_symmetry_num_indices, Sym...> +{ +private: + constexpr static std::size_t One = static_cast(One_); + constexpr static std::size_t Two = static_cast(Two_); + constexpr static std::size_t Three = tensor_symmetry_num_indices::value; + + // don't use std::max, since it's not constexpr until C++14... + constexpr static std::size_t maxOneTwoPlusOne = ((One > Two) ? One : Two) + 1; +public: + constexpr static std::size_t value = (maxOneTwoPlusOne > Three) ? maxOneTwoPlusOne : Three; +}; + +template struct tensor_symmetry_num_indices, Sym...> + : public tensor_symmetry_num_indices, Sym...> {}; +template struct tensor_symmetry_num_indices, Sym...> + : public tensor_symmetry_num_indices, Sym...> {}; +template struct tensor_symmetry_num_indices, Sym...> + : public tensor_symmetry_num_indices, Sym...> {}; + /** \internal * * \class tensor_symmetry_pre_analysis @@ -199,7 +226,7 @@ namespace internal { template struct tensor_symmetry_pre_analysis { - typedef StaticSGroup root_type; + typedef StaticSGroup<> root_type; }; template @@ -212,7 +239,7 @@ struct tensor_symmetry_pre_analysis typedef typename conditional< possible_size == 0 || possible_size >= max_static_elements, - DynamicSGroupFromTemplateArgs, + DynamicSGroupFromTemplateArgs, typename helper::type >::type root_type; }; diff --git a/unsupported/test/cxx11_tensor_symmetry.cpp b/unsupported/test/cxx11_tensor_symmetry.cpp index e8dfffd92..2a1669995 100644 --- a/unsupported/test/cxx11_tensor_symmetry.cpp +++ b/unsupported/test/cxx11_tensor_symmetry.cpp @@ -32,8 +32,8 @@ using Eigen::GlobalImagFlag; // helper function to determine if the compiler intantiated a static // or dynamic symmetry group -template -bool isDynGroup(StaticSGroup const& dummy) +template +bool isDynGroup(StaticSGroup const& dummy) { (void)dummy; return false; @@ -86,7 +86,7 @@ static void test_symgroups_static() std::array identity{{0,1,2,3,4,5,6}}; // Simple static symmetry group - StaticSGroup<7, + StaticSGroup< AntiSymmetry<0,1>, Hermiticity<0,2> > group; @@ -113,7 +113,7 @@ static void test_symgroups_dynamic() identity.push_back(i); // Simple dynamic symmetry group - DynamicSGroup group(7); + DynamicSGroup group; group.add(0,1,NegationFlag); group.add(0,2,ConjugationFlag); @@ -143,7 +143,7 @@ static void test_symgroups_selection() { // Do the same test as in test_symgroups_static but // require selection via SGroup - SGroup<7, + SGroup< AntiSymmetry<0,1>, Hermiticity<0,2> > group; @@ -168,7 +168,7 @@ static void test_symgroups_selection() // simple factorizing group: 5 generators, 2^5 = 32 elements // selection should make this dynamic, although static group // can still be reasonably generated - SGroup<10, + SGroup< Symmetry<0,1>, Symmetry<2,3>, Symmetry<4,5>, @@ -196,7 +196,7 @@ static void test_symgroups_selection() // no verify that we could also generate a static group // with these generators found.clear(); - StaticSGroup<10, + StaticSGroup< Symmetry<0,1>, Symmetry<2,3>, Symmetry<4,5>, @@ -211,7 +211,7 @@ static void test_symgroups_selection() { // try to create a HUGE group - SGroup<7, + SGroup< Symmetry<0,1>, Symmetry<1,2>, Symmetry<2,3>, @@ -657,7 +657,7 @@ static void test_symgroups_selection() static void test_tensor_epsilon() { - SGroup<3, AntiSymmetry<0,1>, AntiSymmetry<1,2>> sym; + SGroup, AntiSymmetry<1,2>> sym; Tensor epsilon(3,3,3); epsilon.setZero(); @@ -674,7 +674,7 @@ static void test_tensor_epsilon() static void test_tensor_sym() { - SGroup<4, Symmetry<0,1>, Symmetry<2,3>> sym; + SGroup, Symmetry<2,3>> sym; Tensor t(10,10,10,10); t.setZero(); @@ -703,7 +703,7 @@ static void test_tensor_sym() static void test_tensor_asym() { - SGroup<4, AntiSymmetry<0,1>, AntiSymmetry<2,3>> sym; + SGroup, AntiSymmetry<2,3>> sym; Tensor t(10,10,10,10); t.setZero(); @@ -740,7 +740,7 @@ static void test_tensor_asym() static void test_tensor_dynsym() { - DynamicSGroup sym(4); + DynamicSGroup sym; sym.addSymmetry(0,1); sym.addSymmetry(2,3); Tensor t(10,10,10,10); @@ -770,7 +770,7 @@ static void test_tensor_dynsym() static void test_tensor_randacc() { - SGroup<4, Symmetry<0,1>, Symmetry<2,3>> sym; + SGroup, Symmetry<2,3>> sym; Tensor t(10,10,10,10); t.setZero();