Fix KDTree search for closest points for coord_t

NOTE: Store squared distance into double
This commit is contained in:
Filip Sykala - NTB T15p 2025-01-10 13:45:39 +01:00 committed by Lukas Matena
parent 92e28d93ff
commit c66df2ce99
2 changed files with 32 additions and 16 deletions

View File

@ -216,44 +216,43 @@ std::array<size_t, K> find_closest_points(
const Tree &kdtree;
const PointType &point;
const FilterFn filter;
std::array<std::pair<size_t, CoordT>, K> results;
struct Result {
size_t index;
double distance_sq;
};
std::array<Result, K> results;
Visitor(const Tree &kdtree, const PointType &point, FilterFn filter)
: kdtree(kdtree), point(point), filter(filter)
{
results.fill(std::make_pair(Tree::npos,
std::numeric_limits<CoordT>::max()));
results.fill(Result{Tree::npos, std::numeric_limits<double>::max()});
}
unsigned int operator()(size_t idx, size_t dimension)
{
if (this->filter(idx)) {
auto dist = CoordT(0);
double distance_sq = 0.;
for (size_t i = 0; i < D; ++i) {
CoordT d = point[i] - kdtree.coordinate(idx, i);
dist += d * d;
distance_sq += double(d) * d;
}
auto res = std::make_pair(idx, dist);
auto it = std::lower_bound(results.begin(), results.end(),
res, [](auto &r1, auto &r2) {
return r1.second < r2.second;
});
Result res{idx, distance_sq};
auto lower_distance = [](const Result &r1, const Result &r2) {
return r1.distance_sq < r2.distance_sq; };
auto it = std::lower_bound(results.begin(), results.end(), res, lower_distance);
if (it != results.end()) {
std::rotate(it, std::prev(results.end()), results.end());
*it = res;
}
}
return kdtree.descent_mask(point[dimension],
results.front().second, idx,
dimension);
return kdtree.descent_mask(point[dimension], results.front().distance_sq, idx, dimension);
}
} visitor(kdtree, point, filter);
kdtree.visit(visitor);
std::array<size_t, K> ret;
for (size_t i = 0; i < K; i++) ret[i] = visitor.results[i].first;
for (size_t i = 0; i < K; i++)
ret[i] = visitor.results[i].index;
return ret;
}

View File

@ -86,6 +86,23 @@ TEST_CASE("Test kdtree query for a Box", "[KDTreeIndirect]")
REQUIRE(call_count < pgrid.point_count());
}
TEST_CASE("Test kdtree closests points", "[KDTreeIndirect]") {
Points pts{
Point{-9000000, 9000000},
Point{-9000000, -9000000},
Point{ 9000000, -9000000},
Point{ 9000000, 9000000},
Point{25, 25}
};
auto point_accessor = [&pts](size_t idx, size_t dim) -> coord_t & {
return pts[idx][dim];
};
KDTreeIndirect<2, coord_t, decltype(point_accessor)> tree(point_accessor, pts.size());
std::array<size_t, 5> closest = find_closest_points<5>(tree, Point{0, 0});
CHECK(closest[0] == 4);
}
//TEST_CASE("Test kdtree query for a Sphere", "[KDTreeIndirect]") {
// auto vol = BoundingBox3Base<Vec3f>{{0.f, 0.f, 0.f}, {10.f, 10.f, 10.f}};