:heavy_check_mark: test/DataStructure/kd-tree/aoj-DSL_2_C.cpp

Depends on: DataStructure/kd_tree.hpp

Code

// competitive-verifier: PROBLEM https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_2_C

#include "DataStructure/kd_tree.hpp"
#include <bits/stdc++.h>

using namespace std;

// Solves AOJ DSL_2_C (range search): reads n points, then answers q
// axis-aligned rectangle queries [sx, tx] x [sy, ty]. For each query it prints
// the IDs of the points inside, in ascending order, followed by a blank line.
int main()
{
    // The output can be very large; untie/unsync streams and avoid per-line
    // flushes (std::endl) so printing does not dominate the running time.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;

    vector<array<int, 2>> p(n);
    for (auto &pt : p)
    {
        cin >> pt[0] >> pt[1];
    }

    KDTree<int, 2> kd_tree(p);
    int q;
    cin >> q;

    for (int i = 0; i < q; i++)
    {
        int sx, tx, sy, ty;
        cin >> sx >> tx >> sy >> ty;

        auto result = kd_tree.range_search({sx, sy}, {tx, ty});
        // range_search returns points in tree order; IDs must be ascending.
        sort(result.begin(), result.end(), [](const auto &a, const auto &b) -> bool
             { return a.id < b.id; });
        for (const auto &pt : result)
        {
            cout << pt.id << '\n';
        }
        cout << '\n';
    }

    return 0;
}
#line 1 "test/DataStructure/kd-tree/aoj-DSL_2_C.cpp"
// competitive-verifier: PROBLEM https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_2_C

#line 1 "DataStructure/kd_tree.hpp"
#include <bits/stdc++.h>
using namespace std;

template <typename T, size_t K>
class KDTree
{
    static_assert(K > 0, "KDTree requires at least one dimension");

public:
    // A point together with its index in the vector passed to the constructor.
    struct PointWithID
    {
        array<T, K> point;
        size_t id;
    };

    // Tree node holding one PointWithID. Children are owned by the tree and
    // released in free_tree().
    struct Node
    {
        PointWithID data;
        Node *left;
        Node *right;
        Node(const PointWithID &pwid)
            : data(pwid), left(nullptr), right(nullptr) {}
    };

    /// Builds a kd-tree over `points` by recursive median splits, cycling the
    /// split axis with depth. Point i is assigned id i.
    KDTree(const vector<array<T, K>> &points)
    {
        vector<PointWithID> pts;
        pts.reserve(points.size());
        for (size_t i = 0; i < points.size(); ++i)
        {
            pts.push_back({points[i], i});
        }
        root = build(pts.begin(), pts.end(), 0);
    }

    ~KDTree()
    {
        free_tree(root);
    }

    // The tree owns its nodes through raw pointers, so an implicit copy would
    // share them and free them twice. Copying is disabled; ownership can be
    // transferred by move, leaving the source as an empty tree.
    KDTree(const KDTree &) = delete;
    KDTree &operator=(const KDTree &) = delete;

    KDTree(KDTree &&other) noexcept : root(exchange(other.root, nullptr)) {}

    KDTree &operator=(KDTree &&other) noexcept
    {
        if (this != &other)
        {
            free_tree(root);
            root = exchange(other.root, nullptr);
        }
        return *this;
    }

    /// Returns the point with the smallest squared Euclidean distance to
    /// `target` among points whose id satisfies `check`, or nullopt if no
    /// point qualifies (including an empty tree).
    optional<PointWithID> nearest(const array<T, K> &target,
                                  const function<bool(int)> &check) const
    {
        const Node *best = nullptr;
        Dist bestDist = numeric_limits<Dist>::max();
        nearest_search(root, target, 0, best, bestDist, check);
        if (best)
            return best->data;
        return nullopt;
    }

    /// Nearest point to `target` with no id filter.
    optional<PointWithID> nearest(const array<T, K> &target) const
    {
        return nearest(target, [](int)
                       { return true; });
    }

    /// Returns every point p with lower[i] <= p[i] <= upper[i] for all axes
    /// (bounds inclusive) whose id satisfies `check`. The order is
    /// unspecified (tree traversal order).
    vector<PointWithID> range_search(const array<T, K> &lower, const array<T, K> &upper,
                                     const function<bool(int)> &check) const
    {
        vector<PointWithID> results;
        range_search(root, 0, lower, upper, check, results);
        return results;
    }

    /// Inclusive box query with no id filter.
    vector<PointWithID> range_search(const array<T, K> &lower, const array<T, K> &upper) const
    {
        return range_search(lower, upper, [](int)
                            { return true; });
    }

    /// Returns up to k points closest to `target` whose id satisfies `check`,
    /// ordered from nearest to farthest.
    vector<PointWithID> nearest_k(const array<T, K> &target, size_t k,
                                  const function<bool(int)> &check) const
    {
        if (k == 0 || root == nullptr)
        {
            return {};
        }

        MaxHeap pq;
        nearest_k_search(root, target, 0, k, check, pq);

        // The heap pops the farthest first; reverse to get nearest-first.
        vector<PointWithID> res;
        res.reserve(pq.size());
        while (!pq.empty())
        {
            res.push_back(pq.top().second->data);
            pq.pop();
        }

        reverse(res.begin(), res.end());
        return res;
    }

    /// k nearest points with no id filter.
    vector<PointWithID> nearest_k(const array<T, K> &target, size_t k) const
    {
        return nearest_k(target, k, [](int)
                         { return true; });
    }

private:
    using iterator = typename vector<PointWithID>::iterator;

    // Arithmetic type for coordinate differences and squared distances.
    // Integer types narrower than long long are widened to long long, so that
    // e.g. int coordinates near +-1e9 (K = 2) do not overflow. Other types are
    // used as-is.
    using Dist = conditional_t<is_integral_v<T> && (sizeof(T) < sizeof(long long)),
                               long long, T>;

    Node *root = nullptr;

    // Builds the subtree for [begin, end): the median along `depth % K`
    // becomes the node; smaller-or-equal keys go left, larger-or-equal go right.
    Node *build(iterator begin, iterator end, size_t depth)
    {
        if (begin >= end)
            return nullptr;
        size_t axis = depth % K;
        iterator mid = begin + distance(begin, end) / 2;
        nth_element(begin, mid, end, [axis](const PointWithID &a, const PointWithID &b)
                    { return a.point[axis] < b.point[axis]; });
        Node *node = new Node(*mid);
        node->left = build(begin, mid, depth + 1);
        node->right = build(mid + 1, end, depth + 1);
        return node;
    }

    // Recursive nearest-neighbour search. Updates `best` and `bestDist` in
    // place. `check` is taken by reference so it is not copied at every node.
    void nearest_search(const Node *node, const array<T, K> &target, size_t depth,
                        const Node *&best, Dist &bestDist,
                        const function<bool(int)> &check) const
    {
        if (node == nullptr)
            return;
        const size_t axis = depth % K;
        const Dist d = distance_squared(node->data.point, target);

        if (d < bestDist && check(static_cast<int>(node->data.id)))
        {
            bestDist = d;
            best = node;
        }
        const Dist diff = static_cast<Dist>(target[axis]) - static_cast<Dist>(node->data.point[axis]);
        const Node *nearChild = (diff < 0) ? node->left : node->right;
        const Node *farChild = (diff < 0) ? node->right : node->left;
        nearest_search(nearChild, target, depth + 1, best, bestDist, check);
        // The far side can only hold a closer point if the splitting plane is
        // closer than the current best.
        if (diff * diff < bestDist)
            nearest_search(farChild, target, depth + 1, best, bestDist, check);
    }

    // Recursive inclusive box search that appends matches to `results`.
    // Equal keys may sit on either side of a split, so both comparisons are
    // non-strict.
    void range_search(const Node *node, size_t depth,
                      const array<T, K> &lower, const array<T, K> &upper,
                      const function<bool(int)> &check, vector<PointWithID> &results) const
    {
        if (node == nullptr)
            return;
        bool inside = true;
        for (size_t i = 0; i < K; i++)
        {
            if (node->data.point[i] < lower[i] || node->data.point[i] > upper[i])
            {
                inside = false;
                break;
            }
        }
        if (inside && check(static_cast<int>(node->data.id)))
            results.push_back(node->data);

        const size_t axis = depth % K;
        if (lower[axis] <= node->data.point[axis])
            range_search(node->left, depth + 1, lower, upper, check, results);
        if (upper[axis] >= node->data.point[axis])
            range_search(node->right, depth + 1, lower, upper, check, results);
    }

    // (squared distance, node); the heap keeps the farthest candidate on top.
    using QElem = std::pair<Dist, const Node *>;
    struct WorseFirst
    {
        bool operator()(const QElem &lhs, const QElem &rhs) const
        {
            return lhs.first < rhs.first;
        }
    };

    using MaxHeap = priority_queue<QElem, vector<QElem>, WorseFirst>;

    // Recursive k-nearest search. `pq` holds at most k best candidates so far.
    void nearest_k_search(
        const Node *node, const array<T, K> &target,
        size_t depth, size_t k,
        const function<bool(int)> &check, MaxHeap &pq) const
    {
        if (node == nullptr)
            return;

        const size_t axis = depth % K;
        const Dist dist2 = distance_squared(node->data.point, target);

        if (check(static_cast<int>(node->data.id)))
        {
            if (pq.size() < k)
            {
                pq.emplace(dist2, node);
            }
            else if (dist2 < pq.top().first)
            {
                pq.pop();
                pq.emplace(dist2, node);
            }
        }

        const Dist diff = static_cast<Dist>(target[axis]) - static_cast<Dist>(node->data.point[axis]);
        const Node *nearChild = (diff < 0) ? node->left : node->right;
        const Node *farChild = (diff < 0) ? node->right : node->left;

        nearest_k_search(nearChild, target, depth + 1, k, check, pq);

        // Visit the far side while fewer than k candidates are known, or when
        // the splitting plane is closer than the current k-th best.
        if (pq.size() < k || diff * diff < pq.top().first)
        {
            nearest_k_search(farChild, target, depth + 1, k, check, pq);
        }
    }

    // Squared Euclidean distance between two points, computed in Dist.
    Dist distance_squared(const array<T, K> &a, const array<T, K> &b) const
    {
        Dist dist = 0;
        for (size_t i = 0; i < K; i++)
        {
            const Dist d = static_cast<Dist>(a[i]) - static_cast<Dist>(b[i]);
            dist += d * d;
        }
        return dist;
    }

    // Releases a subtree (post-order).
    void free_tree(Node *node)
    {
        if (node == nullptr)
            return;
        free_tree(node->left);
        free_tree(node->right);
        delete node;
    }
};
#line 5 "test/DataStructure/kd-tree/aoj-DSL_2_C.cpp"

using namespace std;

// Solves AOJ DSL_2_C (range search): reads n points, then answers q
// axis-aligned rectangle queries [sx, tx] x [sy, ty]. For each query it prints
// the IDs of the points inside, in ascending order, followed by a blank line.
int main()
{
    // The output can be very large; untie/unsync streams and avoid per-line
    // flushes (std::endl) so printing does not dominate the running time.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;

    vector<array<int, 2>> p(n);
    for (auto &pt : p)
    {
        cin >> pt[0] >> pt[1];
    }

    KDTree<int, 2> kd_tree(p);
    int q;
    cin >> q;

    for (int i = 0; i < q; i++)
    {
        int sx, tx, sy, ty;
        cin >> sx >> tx >> sy >> ty;

        auto result = kd_tree.range_search({sx, sy}, {tx, ty});
        // range_search returns points in tree order; IDs must be ascending.
        sort(result.begin(), result.end(), [](const auto &a, const auto &b) -> bool
             { return a.id < b.id; });
        for (const auto &pt : result)
        {
            cout << pt.id << '\n';
        }
        cout << '\n';
    }

    return 0;
}

Test cases

| Env | Name | Status | Elapsed | Memory |
|-----|------|--------|---------|--------|
| g++ | 00_sample_00.in | :heavy_check_mark: AC | 5 ms | 3 MB |
| g++ | 01_small_00.in | :heavy_check_mark: AC | 4 ms | 3 MB |
| g++ | 02_corner_00.in | :heavy_check_mark: AC | 4 ms | 3 MB |
| g++ | 03_rand_00.in | :heavy_check_mark: AC | 4 ms | 3 MB |
| g++ | 03_rand_01.in | :heavy_check_mark: AC | 4 ms | 3 MB |
| g++ | 03_rand_02.in | :heavy_check_mark: AC | 5 ms | 3 MB |
| g++ | 03_rand_03.in | :heavy_check_mark: AC | 4 ms | 3 MB |
| g++ | 03_rand_04.in | :heavy_check_mark: AC | 5 ms | 3 MB |
| g++ | 03_rand_05.in | :heavy_check_mark: AC | 10 ms | 4 MB |
| g++ | 03_rand_06.in | :heavy_check_mark: AC | 36 ms | 7 MB |
| g++ | 03_rand_07.in | :heavy_check_mark: AC | 59 ms | 9 MB |
| g++ | 03_rand_08.in | :heavy_check_mark: AC | 95 ms | 10 MB |
| g++ | 03_rand_09.in | :heavy_check_mark: AC | 367 ms | 31 MB |
| g++ | 04_liner_01.in | :heavy_check_mark: AC | 900 ms | 10 MB |
| g++ | 04_liner_02.in | :heavy_check_mark: AC | 940 ms | 10 MB |
| g++ | 04_liner_03.in | :heavy_check_mark: AC | 70 ms | 10 MB |
| g++ | 04_liner_04.in | :heavy_check_mark: AC | 902 ms | 14 MB |
| g++ | 04_liner_05.in | :heavy_check_mark: AC | 915 ms | 14 MB |
| g++ | 04_liner_06.in | :heavy_check_mark: AC | 124 ms | 14 MB |
| g++ | 05_grid_00.in | :heavy_check_mark: AC | 543 ms | 10 MB |
| g++ | 05_grid_01.in | :heavy_check_mark: AC | 554 ms | 13 MB |
| g++ | 06_biased_00.in | :heavy_check_mark: AC | 970 ms | 31 MB |
| g++ | 06_biased_01.in | :heavy_check_mark: AC | 569 ms | 36 MB |
| g++ | 06_biased_02.in | :heavy_check_mark: AC | 389 ms | 37 MB |
Back to top page