mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
unsupported/TensorSymmetry: factor out completely from Tensor module
Remove the symCoeff() method of the the Tensor module and move the functionality into a new operator() of the symmetry classes. This makes the Tensor module now completely self-contained without symmetry support (even though previously it was only a forward declaration and a otherwise harmless trivial templated method) and also removes the inconsistency with the rest of eigen w.r.t. the method's naming scheme.
This commit is contained in:
parent
ea99433523
commit
96cb58fa3b
@ -91,9 +91,6 @@ struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
|
|||||||
return std_array_get<RowMajor ? 0 : NumIndices - 1>(indices);
|
return std_array_get<RowMajor ? 0 : NumIndices - 1>(indices);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Forward-declaration required for the symmetry support. */
|
|
||||||
template<typename Tensor_, typename Symmetry_, int Flags = 0> class tensor_symmetry_value_setter;
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
template<typename Scalar_, std::size_t NumIndices_, int Options_>
|
template<typename Scalar_, std::size_t NumIndices_, int Options_>
|
||||||
@ -285,18 +282,6 @@ class Tensor
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Symmetry_, typename... IndexTypes>
|
|
||||||
internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, Index firstIndex, IndexTypes... otherIndices)
|
|
||||||
{
|
|
||||||
return symCoeff(symmetry, std::array<Index, NumIndices>{{firstIndex, otherIndices...}});
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename Symmetry_, typename... IndexTypes>
|
|
||||||
internal::tensor_symmetry_value_setter<Self, Symmetry_> symCoeff(const Symmetry_& symmetry, std::array<Index, NumIndices> const& indices)
|
|
||||||
{
|
|
||||||
return internal::tensor_symmetry_value_setter<Self, Symmetry_>(*this, symmetry, indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool checkIndexRange(const std::array<Index, NumIndices>& indices) const
|
bool checkIndexRange(const std::array<Index, NumIndices>& indices) const
|
||||||
{
|
{
|
||||||
|
@ -50,6 +50,19 @@ class DynamicSGroup
|
|||||||
|
|
||||||
inline int globalFlags() const { return m_globalFlags; }
|
inline int globalFlags() const { return m_globalFlags; }
|
||||||
inline std::size_t size() const { return m_elements.size(); }
|
inline std::size_t size() const { return m_elements.size(); }
|
||||||
|
|
||||||
|
template<typename Tensor_, typename... IndexTypes>
|
||||||
|
inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
|
||||||
|
{
|
||||||
|
static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
|
||||||
|
return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Tensor_>
|
||||||
|
inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
|
||||||
|
{
|
||||||
|
return internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup>(tensor, *this, indices);
|
||||||
|
}
|
||||||
private:
|
private:
|
||||||
struct GroupElement {
|
struct GroupElement {
|
||||||
std::vector<int> representation;
|
std::vector<int> representation;
|
||||||
|
@ -212,6 +212,19 @@ class StaticSGroup
|
|||||||
return ge::count;
|
return ge::count;
|
||||||
}
|
}
|
||||||
constexpr static inline int globalFlags() { return group_elements::global_flags; }
|
constexpr static inline int globalFlags() { return group_elements::global_flags; }
|
||||||
|
|
||||||
|
template<typename Tensor_, typename... IndexTypes>
|
||||||
|
inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
|
||||||
|
{
|
||||||
|
static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
|
||||||
|
return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Tensor_>
|
||||||
|
inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
|
||||||
|
{
|
||||||
|
return internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>>(tensor, *this, indices);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -293,7 +293,7 @@ struct tensor_symmetry_calculate_flags
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Tensor_, typename Symmetry_, int Flags>
|
template<typename Tensor_, typename Symmetry_, int Flags = 0>
|
||||||
class tensor_symmetry_value_setter
|
class tensor_symmetry_value_setter
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
@ -661,7 +661,7 @@ static void test_tensor_epsilon()
|
|||||||
Tensor<int, 3> epsilon(3,3,3);
|
Tensor<int, 3> epsilon(3,3,3);
|
||||||
|
|
||||||
epsilon.setZero();
|
epsilon.setZero();
|
||||||
epsilon.symCoeff(sym, 0, 1, 2) = 1;
|
sym(epsilon, 0, 1, 2) = 1;
|
||||||
|
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
for (int j = 0; j < 3; j++) {
|
for (int j = 0; j < 3; j++) {
|
||||||
@ -683,7 +683,7 @@ static void test_tensor_sym()
|
|||||||
for (int k = l; k < 10; k++) {
|
for (int k = l; k < 10; k++) {
|
||||||
for (int j = 0; j < 10; j++) {
|
for (int j = 0; j < 10; j++) {
|
||||||
for (int i = j; i < 10; i++) {
|
for (int i = j; i < 10; i++) {
|
||||||
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
|
sym(t, i, j, k, l) = (i + j) * (k + l);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -712,7 +712,7 @@ static void test_tensor_asym()
|
|||||||
for (int k = l + 1; k < 10; k++) {
|
for (int k = l + 1; k < 10; k++) {
|
||||||
for (int j = 0; j < 10; j++) {
|
for (int j = 0; j < 10; j++) {
|
||||||
for (int i = j + 1; i < 10; i++) {
|
for (int i = j + 1; i < 10; i++) {
|
||||||
t.symCoeff(sym, i, j, k, l) = ((i * j) + (k * l));
|
sym(t, i, j, k, l) = ((i * j) + (k * l));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -751,7 +751,7 @@ static void test_tensor_dynsym()
|
|||||||
for (int k = l; k < 10; k++) {
|
for (int k = l; k < 10; k++) {
|
||||||
for (int j = 0; j < 10; j++) {
|
for (int j = 0; j < 10; j++) {
|
||||||
for (int i = j; i < 10; i++) {
|
for (int i = j; i < 10; i++) {
|
||||||
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
|
sym(t, i, j, k, l) = (i + j) * (k + l);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -787,7 +787,7 @@ static void test_tensor_randacc()
|
|||||||
std::swap(i, j);
|
std::swap(i, j);
|
||||||
if (k < l)
|
if (k < l)
|
||||||
std::swap(k, l);
|
std::swap(k, l);
|
||||||
t.symCoeff(sym, i, j, k, l) = (i + j) * (k + l);
|
sym(t, i, j, k, l) = (i + j) * (k + l);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int l = 0; l < 10; l++) {
|
for (int l = 0; l < 10; l++) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user