:heavy_check_mark: KDTree (DataStructure/kd_tree.hpp)

KDTree

K 次元空間の点群から k-d tree を構築し、高速に最近傍検索や範囲検索を行うデータ構造です。
構築には平均で $O(N \log N)$ の時間がかかるため、点群の動的な追加・削除にはあまり向いていません。
一方、一度構築してしまえば、最近傍検索は平均 $O(\log N)$、範囲検索は平均 $O(R + \log N)$($R$ は結果点数)で実行できます。

テンプレートパラメータ

template <typename T, size_t K>
  • T: 座標の型(数値型)
  • K: 次元数

検索結果

PointWithID

struct PointWithID {
    array<T, K> point;
    size_t id;
};

kd木による検索結果は座標 point と一意な識別子 id をまとめた構造体によって返されます。 id は追加された順に割り振られます。

コンストラクタ

KDTree(const vector<array<T, K>> &points)

点群 points から k-d tree を構築します。各点には入力順に 0 から points.size()-1 までの id が自動付与されます。

計算量

  • 平均: $O(N \log N)$

最近傍検索 (nearest)

optional<PointWithID> nearest(
    const array<T, K> &target,
    function<bool(int)> check) const

フィルタ関数 check(id) が true を返す点の中から、target に最も近い点を返します。該当する点がない場合は nullopt が返されます。

計算量

  • 平均: $O(\log N)$
optional<PointWithID> nearest(
    const array<T, K> &target) const

check なし版。全点を対象に探索します。

計算量

  • 平均: $O(\log N)$
vector<PointWithID> range_search(
    const array<T, K> &lower,
    const array<T, K> &upper,
    function<bool(int)> check) const

矩形領域 [lower, upper] 内に含まれる点を返します。check(id) で追加フィルタが可能です。

計算量

  • 平均: $O(R + \log N)$
    • $R$: 結果点数
vector<PointWithID> range_search(
    const array<T, K> &lower,
    const array<T, K> &upper) const

check なし版。

k 個の最近傍検索 (nearest_k)

vector<PointWithID> nearest_k(
    const array<T, K> &target,
    size_t k,
    function<bool(int)> check) const

target に近い点を最大 k 個返します。check(id) でフィルタが可能です。

計算量

  • 平均: $O(k \log N)$
vector<PointWithID> nearest_k(
    const array<T, K> &target,
    size_t k) const

check なし版。

Verified with

Code

#include <bits/stdc++.h>
using namespace std;

/// Static k-d tree over a set of K-dimensional points.
///
/// The tree is built once from the points handed to the constructor; no
/// insertion or removal is offered afterwards. Every point receives an id equal
/// to its index in the constructor input, and every query returns points
/// together with that id. Distances are squared Euclidean distances computed in
/// T, so T must be able to represent the squared distances that occur.
///
/// @tparam T coordinate type (needs -, *, +=, <, comparison with 0 and
///           initialisation from 0)
/// @tparam K number of dimensions, at least 1
template <typename T, std::size_t K>
class KDTree
{
    // Every split axis is computed as depth % K, which needs K != 0.
    static_assert(K > 0, "KDTree needs at least one dimension");

public:
    /// Coordinates of a point bundled with its id.
    struct PointWithID
    {
        std::array<T, K> point;
        std::size_t id;
    };

    /// Tree node: one stored point plus pointers to the two subtrees.
    struct Node
    {
        PointWithID data;
        Node *left;
        Node *right;
        Node(const PointWithID &pwid)
            : data(pwid), left(nullptr), right(nullptr) {}
    };

    /// Builds the tree from `points`; point i gets id i.
    KDTree(const std::vector<std::array<T, K>> &points)
    {
        std::vector<PointWithID> pts;
        pts.reserve(points.size());
        for (std::size_t i = 0; i < points.size(); ++i)
        {
            pts.push_back({points[i], i});
        }
        root = build(pts.begin(), pts.end(), 0);
    }

    /// Deep copy: the new tree owns its own nodes, so destroying either tree
    /// never touches the other one's memory.
    KDTree(const KDTree &other) : root(clone(other.root)) {}

    /// Takes over the nodes of `other` and leaves `other` empty.
    KDTree(KDTree &&other) noexcept : root(other.root)
    {
        other.root = nullptr;
    }

    /// Copy and move assignment in one: the argument is already a private
    /// copy (or moved-in value), so swapping roots hands our old nodes to it
    /// and they are released when it goes out of scope.
    KDTree &operator=(KDTree other) noexcept
    {
        std::swap(root, other.root);
        return *this;
    }

    ~KDTree()
    {
        free_tree(root);
    }

    /// Returns the point closest to `target` among the points whose id passes
    /// `check`, or nullopt when no point passes (or the tree is empty).
    /// `check` receives the id converted to int.
    std::optional<PointWithID> nearest(const std::array<T, K> &target,
                                       std::function<bool(int)> check) const
    {
        const Node *best = nullptr;
        T bestDist{}; // only meaningful once `best` is non-null
        nearest_search(root, target, 0, best, bestDist, check);
        if (best == nullptr)
            return std::nullopt;
        return best->data;
    }

    /// Unfiltered variant of nearest(): every point is a candidate.
    std::optional<PointWithID> nearest(const std::array<T, K> &target) const
    {
        return nearest(target, [](int)
                       { return true; });
    }

    /// Returns all points p with lower[i] <= p[i] <= upper[i] on every axis
    /// whose id passes `check`. The order of the result follows the tree
    /// traversal, not the ids.
    std::vector<PointWithID> range_search(const std::array<T, K> &lower,
                                          const std::array<T, K> &upper,
                                          std::function<bool(int)> check) const
    {
        std::vector<PointWithID> results;
        range_search(root, 0, lower, upper, check, results);
        return results;
    }

    /// Unfiltered variant of range_search().
    std::vector<PointWithID> range_search(const std::array<T, K> &lower,
                                          const std::array<T, K> &upper) const
    {
        return range_search(lower, upper, [](int)
                            { return true; });
    }

    /// Returns up to `k` points closest to `target` among those whose id
    /// passes `check`, ordered from nearest to farthest.
    std::vector<PointWithID> nearest_k(const std::array<T, K> &target, std::size_t k,
                                       std::function<bool(int)> check) const
    {
        if (k == 0 || root == nullptr)
        {
            return {};
        }

        MaxHeap pq;
        nearest_k_search(root, target, 0, k, check, pq);

        std::vector<PointWithID> res;
        res.reserve(pq.size());
        while (!pq.empty())
        {
            res.push_back(pq.top().second->data);
            pq.pop();
        }

        // The heap yields the farthest candidate first; flip to nearest-first.
        std::reverse(res.begin(), res.end());
        return res;
    }

    /// Unfiltered variant of nearest_k().
    std::vector<PointWithID> nearest_k(const std::array<T, K> &target, std::size_t k) const
    {
        return nearest_k(target, k, [](int)
                         { return true; });
    }

private:
    using PointIter = typename std::vector<PointWithID>::iterator;
    Node *root = nullptr;

    /// Builds the subtree for [begin, end): the median along the current axis
    /// becomes the node, the two halves become its children.
    Node *build(PointIter begin, PointIter end, std::size_t depth)
    {
        if (begin >= end)
            return nullptr;
        const std::size_t axis = depth % K;
        const PointIter mid = begin + std::distance(begin, end) / 2;
        std::nth_element(begin, mid, end,
                         [axis](const PointWithID &a, const PointWithID &b)
                         { return a.point[axis] < b.point[axis]; });
        Node *node = new Node(*mid);
        node->left = build(begin, mid, depth + 1);
        node->right = build(mid + 1, end, depth + 1);
        return node;
    }

    /// Recursively duplicates the subtree rooted at `node`.
    static Node *clone(const Node *node)
    {
        if (node == nullptr)
            return nullptr;
        Node *copy = new Node(node->data);
        copy->left = clone(node->left);
        copy->right = clone(node->right);
        return copy;
    }

    /// Recursive nearest-neighbour search. `best` / `bestDist` carry the best
    /// accepted candidate so far; `best == nullptr` means none yet.
    void nearest_search(const Node *node, const std::array<T, K> &target, std::size_t depth,
                        const Node *&best, T &bestDist,
                        const std::function<bool(int)> &check) const
    {
        if (node == nullptr)
            return;
        const std::size_t axis = depth % K;
        const T d = distance_squared(node->data.point, target);

        // The filter is consulted only for points that would improve the result.
        if ((best == nullptr || d < bestDist) && check(static_cast<int>(node->data.id)))
        {
            bestDist = d;
            best = node;
        }

        const T diff = target[axis] - node->data.point[axis];
        const Node *nearChild = (diff < 0) ? node->left : node->right;
        const Node *farChild = (diff < 0) ? node->right : node->left;
        nearest_search(nearChild, target, depth + 1, best, bestDist, check);

        // The far side can only help if the splitting plane is closer than the
        // current best (or nothing has been accepted yet).
        if (best == nullptr || diff * diff < bestDist)
            nearest_search(farChild, target, depth + 1, best, bestDist, check);
    }

    /// Recursive box query; appends matching points to `results`.
    void range_search(const Node *node, std::size_t depth,
                      const std::array<T, K> &lower, const std::array<T, K> &upper,
                      const std::function<bool(int)> &check,
                      std::vector<PointWithID> &results) const
    {
        if (node == nullptr)
            return;

        bool inside = true;
        for (std::size_t i = 0; i < K; i++)
        {
            if (node->data.point[i] < lower[i] || node->data.point[i] > upper[i])
            {
                inside = false;
                break;
            }
        }
        if (inside && check(static_cast<int>(node->data.id)))
            results.push_back(node->data);

        // Both comparisons are inclusive: points equal to the split value may
        // sit in either subtree after nth_element.
        const std::size_t axis = depth % K;
        if (lower[axis] <= node->data.point[axis])
            range_search(node->left, depth + 1, lower, upper, check, results);
        if (upper[axis] >= node->data.point[axis])
            range_search(node->right, depth + 1, lower, upper, check, results);
    }

    // (squared distance, node) entries kept in a max-heap so the worst of the
    // current k candidates is always on top.
    using QElem = std::pair<T, const Node *>;
    struct WorseFirst
    {
        bool operator()(const QElem &lhs, const QElem &rhs) const
        {
            return lhs.first < rhs.first;
        }
    };

    using MaxHeap = std::priority_queue<QElem, std::vector<QElem>, WorseFirst>;

    /// Recursive k-nearest search; `pq` holds at most `k` accepted candidates.
    void nearest_k_search(
        const Node *node, const std::array<T, K> &target,
        std::size_t depth, std::size_t k,
        const std::function<bool(int)> &check, MaxHeap &pq) const
    {
        if (node == nullptr)
            return;

        const std::size_t axis = depth % K;
        const T dist2 = distance_squared(node->data.point, target);

        if (check(static_cast<int>(node->data.id)))
        {
            if (pq.size() < k)
            {
                pq.emplace(dist2, node);
            }
            else if (dist2 < pq.top().first)
            {
                // Replace the current worst candidate.
                pq.pop();
                pq.emplace(dist2, node);
            }
        }

        const T diff = target[axis] - node->data.point[axis];
        const Node *nearChild = (diff < 0) ? node->left : node->right;
        const Node *farChild = (diff < 0) ? node->right : node->left;

        nearest_k_search(nearChild, target, depth + 1, k, check, pq);

        // Visit the far side while the heap is not full or the splitting plane
        // is closer than the worst kept candidate.
        if (pq.size() < k || diff * diff < pq.top().first)
        {
            nearest_k_search(farChild, target, depth + 1, k, check, pq);
        }
    }

    /// Squared Euclidean distance between two points.
    T distance_squared(const std::array<T, K> &a, const std::array<T, K> &b) const
    {
        T dist = 0;
        for (std::size_t i = 0; i < K; i++)
        {
            const T d = a[i] - b[i];
            dist += d * d;
        }
        return dist;
    }

    /// Deletes the subtree rooted at `node` (children first).
    void free_tree(Node *node)
    {
        if (node == nullptr)
            return;
        free_tree(node->left);
        free_tree(node->right);
        delete node;
    }
};
#line 1 "DataStructure/kd_tree.hpp"
#include <bits/stdc++.h>
using namespace std;

/// Static k-d tree over a set of K-dimensional points.
///
/// The tree is built once from the points handed to the constructor; no
/// insertion or removal is offered afterwards. Every point receives an id equal
/// to its index in the constructor input, and every query returns points
/// together with that id. Distances are squared Euclidean distances computed in
/// T, so T must be able to represent the squared distances that occur.
///
/// @tparam T coordinate type (needs -, *, +=, <, comparison with 0 and
///           initialisation from 0)
/// @tparam K number of dimensions, at least 1
template <typename T, std::size_t K>
class KDTree
{
    // Every split axis is computed as depth % K, which needs K != 0.
    static_assert(K > 0, "KDTree needs at least one dimension");

public:
    /// Coordinates of a point bundled with its id.
    struct PointWithID
    {
        std::array<T, K> point;
        std::size_t id;
    };

    /// Tree node: one stored point plus pointers to the two subtrees.
    struct Node
    {
        PointWithID data;
        Node *left;
        Node *right;
        Node(const PointWithID &pwid)
            : data(pwid), left(nullptr), right(nullptr) {}
    };

    /// Builds the tree from `points`; point i gets id i.
    KDTree(const std::vector<std::array<T, K>> &points)
    {
        std::vector<PointWithID> pts;
        pts.reserve(points.size());
        for (std::size_t i = 0; i < points.size(); ++i)
        {
            pts.push_back({points[i], i});
        }
        root = build(pts.begin(), pts.end(), 0);
    }

    /// Deep copy: the new tree owns its own nodes, so destroying either tree
    /// never touches the other one's memory.
    KDTree(const KDTree &other) : root(clone(other.root)) {}

    /// Takes over the nodes of `other` and leaves `other` empty.
    KDTree(KDTree &&other) noexcept : root(other.root)
    {
        other.root = nullptr;
    }

    /// Copy and move assignment in one: the argument is already a private
    /// copy (or moved-in value), so swapping roots hands our old nodes to it
    /// and they are released when it goes out of scope.
    KDTree &operator=(KDTree other) noexcept
    {
        std::swap(root, other.root);
        return *this;
    }

    ~KDTree()
    {
        free_tree(root);
    }

    /// Returns the point closest to `target` among the points whose id passes
    /// `check`, or nullopt when no point passes (or the tree is empty).
    /// `check` receives the id converted to int.
    std::optional<PointWithID> nearest(const std::array<T, K> &target,
                                       std::function<bool(int)> check) const
    {
        const Node *best = nullptr;
        T bestDist{}; // only meaningful once `best` is non-null
        nearest_search(root, target, 0, best, bestDist, check);
        if (best == nullptr)
            return std::nullopt;
        return best->data;
    }

    /// Unfiltered variant of nearest(): every point is a candidate.
    std::optional<PointWithID> nearest(const std::array<T, K> &target) const
    {
        return nearest(target, [](int)
                       { return true; });
    }

    /// Returns all points p with lower[i] <= p[i] <= upper[i] on every axis
    /// whose id passes `check`. The order of the result follows the tree
    /// traversal, not the ids.
    std::vector<PointWithID> range_search(const std::array<T, K> &lower,
                                          const std::array<T, K> &upper,
                                          std::function<bool(int)> check) const
    {
        std::vector<PointWithID> results;
        range_search(root, 0, lower, upper, check, results);
        return results;
    }

    /// Unfiltered variant of range_search().
    std::vector<PointWithID> range_search(const std::array<T, K> &lower,
                                          const std::array<T, K> &upper) const
    {
        return range_search(lower, upper, [](int)
                            { return true; });
    }

    /// Returns up to `k` points closest to `target` among those whose id
    /// passes `check`, ordered from nearest to farthest.
    std::vector<PointWithID> nearest_k(const std::array<T, K> &target, std::size_t k,
                                       std::function<bool(int)> check) const
    {
        if (k == 0 || root == nullptr)
        {
            return {};
        }

        MaxHeap pq;
        nearest_k_search(root, target, 0, k, check, pq);

        std::vector<PointWithID> res;
        res.reserve(pq.size());
        while (!pq.empty())
        {
            res.push_back(pq.top().second->data);
            pq.pop();
        }

        // The heap yields the farthest candidate first; flip to nearest-first.
        std::reverse(res.begin(), res.end());
        return res;
    }

    /// Unfiltered variant of nearest_k().
    std::vector<PointWithID> nearest_k(const std::array<T, K> &target, std::size_t k) const
    {
        return nearest_k(target, k, [](int)
                         { return true; });
    }

private:
    using PointIter = typename std::vector<PointWithID>::iterator;
    Node *root = nullptr;

    /// Builds the subtree for [begin, end): the median along the current axis
    /// becomes the node, the two halves become its children.
    Node *build(PointIter begin, PointIter end, std::size_t depth)
    {
        if (begin >= end)
            return nullptr;
        const std::size_t axis = depth % K;
        const PointIter mid = begin + std::distance(begin, end) / 2;
        std::nth_element(begin, mid, end,
                         [axis](const PointWithID &a, const PointWithID &b)
                         { return a.point[axis] < b.point[axis]; });
        Node *node = new Node(*mid);
        node->left = build(begin, mid, depth + 1);
        node->right = build(mid + 1, end, depth + 1);
        return node;
    }

    /// Recursively duplicates the subtree rooted at `node`.
    static Node *clone(const Node *node)
    {
        if (node == nullptr)
            return nullptr;
        Node *copy = new Node(node->data);
        copy->left = clone(node->left);
        copy->right = clone(node->right);
        return copy;
    }

    /// Recursive nearest-neighbour search. `best` / `bestDist` carry the best
    /// accepted candidate so far; `best == nullptr` means none yet.
    void nearest_search(const Node *node, const std::array<T, K> &target, std::size_t depth,
                        const Node *&best, T &bestDist,
                        const std::function<bool(int)> &check) const
    {
        if (node == nullptr)
            return;
        const std::size_t axis = depth % K;
        const T d = distance_squared(node->data.point, target);

        // The filter is consulted only for points that would improve the result.
        if ((best == nullptr || d < bestDist) && check(static_cast<int>(node->data.id)))
        {
            bestDist = d;
            best = node;
        }

        const T diff = target[axis] - node->data.point[axis];
        const Node *nearChild = (diff < 0) ? node->left : node->right;
        const Node *farChild = (diff < 0) ? node->right : node->left;
        nearest_search(nearChild, target, depth + 1, best, bestDist, check);

        // The far side can only help if the splitting plane is closer than the
        // current best (or nothing has been accepted yet).
        if (best == nullptr || diff * diff < bestDist)
            nearest_search(farChild, target, depth + 1, best, bestDist, check);
    }

    /// Recursive box query; appends matching points to `results`.
    void range_search(const Node *node, std::size_t depth,
                      const std::array<T, K> &lower, const std::array<T, K> &upper,
                      const std::function<bool(int)> &check,
                      std::vector<PointWithID> &results) const
    {
        if (node == nullptr)
            return;

        bool inside = true;
        for (std::size_t i = 0; i < K; i++)
        {
            if (node->data.point[i] < lower[i] || node->data.point[i] > upper[i])
            {
                inside = false;
                break;
            }
        }
        if (inside && check(static_cast<int>(node->data.id)))
            results.push_back(node->data);

        // Both comparisons are inclusive: points equal to the split value may
        // sit in either subtree after nth_element.
        const std::size_t axis = depth % K;
        if (lower[axis] <= node->data.point[axis])
            range_search(node->left, depth + 1, lower, upper, check, results);
        if (upper[axis] >= node->data.point[axis])
            range_search(node->right, depth + 1, lower, upper, check, results);
    }

    // (squared distance, node) entries kept in a max-heap so the worst of the
    // current k candidates is always on top.
    using QElem = std::pair<T, const Node *>;
    struct WorseFirst
    {
        bool operator()(const QElem &lhs, const QElem &rhs) const
        {
            return lhs.first < rhs.first;
        }
    };

    using MaxHeap = std::priority_queue<QElem, std::vector<QElem>, WorseFirst>;

    /// Recursive k-nearest search; `pq` holds at most `k` accepted candidates.
    void nearest_k_search(
        const Node *node, const std::array<T, K> &target,
        std::size_t depth, std::size_t k,
        const std::function<bool(int)> &check, MaxHeap &pq) const
    {
        if (node == nullptr)
            return;

        const std::size_t axis = depth % K;
        const T dist2 = distance_squared(node->data.point, target);

        if (check(static_cast<int>(node->data.id)))
        {
            if (pq.size() < k)
            {
                pq.emplace(dist2, node);
            }
            else if (dist2 < pq.top().first)
            {
                // Replace the current worst candidate.
                pq.pop();
                pq.emplace(dist2, node);
            }
        }

        const T diff = target[axis] - node->data.point[axis];
        const Node *nearChild = (diff < 0) ? node->left : node->right;
        const Node *farChild = (diff < 0) ? node->right : node->left;

        nearest_k_search(nearChild, target, depth + 1, k, check, pq);

        // Visit the far side while the heap is not full or the splitting plane
        // is closer than the worst kept candidate.
        if (pq.size() < k || diff * diff < pq.top().first)
        {
            nearest_k_search(farChild, target, depth + 1, k, check, pq);
        }
    }

    /// Squared Euclidean distance between two points.
    T distance_squared(const std::array<T, K> &a, const std::array<T, K> &b) const
    {
        T dist = 0;
        for (std::size_t i = 0; i < K; i++)
        {
            const T d = a[i] - b[i];
            dist += d * d;
        }
        return dist;
    }

    /// Deletes the subtree rooted at `node` (children first).
    void free_tree(Node *node)
    {
        if (node == nullptr)
            return;
        free_tree(node->left);
        free_tree(node->right);
        delete node;
    }
};
Back to top page