:heavy_check_mark: Trie (Tree/trie.hpp)

Trie

Trie は、複数の列を木構造として管理するデータ構造です。
各辺に値を対応させ、root からあるノードまでの経路が 1 つの列の接頭辞を表します。

このライブラリでは、T 型の値からなる列を追加し、接頭辞に対応するノードをたどることができます。
文字列に対して使う場合は Trie<char> として利用できます。

主に以下の用途に使えます。

  • 複数の文字列や列の接頭辞を管理する
  • ある列の接頭辞が Trie 上に存在するかを順に調べる
  • 挿入した列に対応する終端ノード ID を取得する
  • ある接頭辞以下の部分木を削除する

Node

struct Node

Trie の各ノードを表す構造体です。

各ノードは以下の情報を持ちます。

  • 子ノードへの遷移
  • 親ノードの ID
  • 親からこのノードへ遷移するときに使った値
  • このノードを通過する列の個数
  • このノードで終端する列の個数

コンストラクタ

Trie()

空の Trie を作成します。
初期状態では、root ノードのみを持ちます。

計算量

  • $O(1)$

insert

template <typename Container>
int insert(const Container& v)

v を Trie に追加します。
戻り値として、v に対応する終端ノード ID を返します。

Container::value_typeT と一致している必要があります。

制約

  • Container::value_typeT と一致すること
  • Tunordered_map<T, int> のキーとして利用できること

計算量

  • $O(\lvert v \rvert)$

iterate_prefix

template <typename Container, typename Func>
void iterate_prefix(const Container& v, Func f)

v の接頭辞に対応するノードを順にたどり、それぞれのノードに対して f を実行します。

f は次の形式で呼び出されます。

f(node_id, is_last)
  • node_id:現在の接頭辞に対応するノード ID
  • is_last:現在の接頭辞が v 全体と一致するかどうか

途中で対応する遷移が存在しなくなった場合、その時点で処理を終了します。

制約

  • Container::value_typeT と一致すること
  • Funcf(int, bool) として呼び出せること

計算量

  • $O(\lvert v \rvert)$

erase_subtrie

template <typename Container>
void erase_subtrie(const Container& prefix)

prefix に対応するノード以下の部分木を論理的に削除します。

具体的には、prefix に対応するノードについて、

pass_count = 0
end_count = 0
next_node_ids.clear()

を行います。
また、祖先ノードの pass_count から、削除した部分木の pass_count を引きます。

prefix が Trie に存在しない場合は何もしません。

この関数は、nodes 配列からノード自体を削除するわけではありません。
そのため、削除後も過去に作成されたノード ID が再利用されることはありません。

制約

  • Container::value_typeT と一致すること

計算量

  • $O(\lvert prefix \rvert)$

Required by

Verified with

Code

#include <bits/stdc++.h>
using namespace std;

template <class T>
struct Trie
{
    struct Node
    {
        // 子ノードへの遷移
        unordered_map<T, int> next_node_ids;

        // 親ノードのID
        // root の parent は -1
        int parent = -1;

        // parent からこのノードへ遷移するときに使った値
        // root では意味を持たない
        T transition_value{};

        // このノードを通過する文字列の個数
        int pass_count = 0;

        // このノードで終端する文字列の個数
        int end_count = 0;

        Node() {};

        Node(int parent, const T &transition_value)
            : parent(parent), transition_value(transition_value) {};
    };

    vector<Node> nodes;

    Trie() : nodes()
    {
        nodes.push_back(Node());
    };

    template <typename Container>
        requires same_as<typename Container::value_type, T>
    int insert(const Container &v)
    {
        int node_id = 0;
        nodes[node_id].pass_count++;

        for (const auto &c : v)
        {
            if (!nodes[node_id].next_node_ids.contains(c))
            {
                int next_node_id = nodes.size();
                nodes[node_id].next_node_ids[c] = next_node_id;

                // 新しく作るノードは、現在のノードを親として持つ
                nodes.push_back(Node(node_id, c));
            }

            node_id = nodes[node_id].next_node_ids[c];
            nodes[node_id].pass_count++;
        }

        nodes[node_id].end_count++;

        // 挿入した列に対応する終端ノードIDを返す
        return node_id;
    }

    template <typename Container, typename Func>
        requires same_as<typename Container::value_type, T> && invocable<Func, int, bool>
    void iterate_prefix(const Container &v, Func f)
    {
        int node_id = 0;

        for (auto itr = v.begin(); itr != v.end(); itr++)
        {
            if (!nodes[node_id].next_node_ids.contains(*itr))
            {
                return;
            }

            node_id = nodes[node_id].next_node_ids[*itr];

            // 第2引数は、現在のノードが入力列 v の終端かどうか
            f(node_id, next(itr) == v.end());
        }
    }

    template <typename Container>
        requires same_as<typename Container::value_type, T>
    void erase_subtrie(const Container &prefix)
    {
        int node_id = 0;
        vector<int> ancestors;

        for (const auto &c : prefix)
        {
            ancestors.push_back(node_id);

            if (!nodes[node_id].next_node_ids.contains(c))
            {
                return;
            }

            node_id = nodes[node_id].next_node_ids[c];
        }

        int removed_pass_count = nodes[node_id].pass_count;

        // prefix 以下の部分木を論理的に削除する
        nodes[node_id].pass_count = 0;
        nodes[node_id].end_count = 0;
        nodes[node_id].next_node_ids.clear();

        // 祖先の通過回数から、削除した部分木の通過回数を引く
        for (const auto &ancestor : ancestors)
        {
            nodes[ancestor].pass_count -= removed_pass_count;
        }
    }
};
#line 1 "Tree/trie.hpp"
#include <bits/stdc++.h>
using namespace std;

template <class T>
struct Trie
{
    struct Node
    {
        // 子ノードへの遷移
        unordered_map<T, int> next_node_ids;

        // 親ノードのID
        // root の parent は -1
        int parent = -1;

        // parent からこのノードへ遷移するときに使った値
        // root では意味を持たない
        T transition_value{};

        // このノードを通過する文字列の個数
        int pass_count = 0;

        // このノードで終端する文字列の個数
        int end_count = 0;

        Node() {};

        Node(int parent, const T &transition_value)
            : parent(parent), transition_value(transition_value) {};
    };

    vector<Node> nodes;

    Trie() : nodes()
    {
        nodes.push_back(Node());
    };

    template <typename Container>
        requires same_as<typename Container::value_type, T>
    int insert(const Container &v)
    {
        int node_id = 0;
        nodes[node_id].pass_count++;

        for (const auto &c : v)
        {
            if (!nodes[node_id].next_node_ids.contains(c))
            {
                int next_node_id = nodes.size();
                nodes[node_id].next_node_ids[c] = next_node_id;

                // 新しく作るノードは、現在のノードを親として持つ
                nodes.push_back(Node(node_id, c));
            }

            node_id = nodes[node_id].next_node_ids[c];
            nodes[node_id].pass_count++;
        }

        nodes[node_id].end_count++;

        // 挿入した列に対応する終端ノードIDを返す
        return node_id;
    }

    template <typename Container, typename Func>
        requires same_as<typename Container::value_type, T> && invocable<Func, int, bool>
    void iterate_prefix(const Container &v, Func f)
    {
        int node_id = 0;

        for (auto itr = v.begin(); itr != v.end(); itr++)
        {
            if (!nodes[node_id].next_node_ids.contains(*itr))
            {
                return;
            }

            node_id = nodes[node_id].next_node_ids[*itr];

            // 第2引数は、現在のノードが入力列 v の終端かどうか
            f(node_id, next(itr) == v.end());
        }
    }

    template <typename Container>
        requires same_as<typename Container::value_type, T>
    void erase_subtrie(const Container &prefix)
    {
        int node_id = 0;
        vector<int> ancestors;

        for (const auto &c : prefix)
        {
            ancestors.push_back(node_id);

            if (!nodes[node_id].next_node_ids.contains(c))
            {
                return;
            }

            node_id = nodes[node_id].next_node_ids[c];
        }

        int removed_pass_count = nodes[node_id].pass_count;

        // prefix 以下の部分木を論理的に削除する
        nodes[node_id].pass_count = 0;
        nodes[node_id].end_count = 0;
        nodes[node_id].next_node_ids.clear();

        // 祖先の通過回数から、削除した部分木の通過回数を引く
        for (const auto &ancestor : ancestors)
        {
            nodes[ancestor].pass_count -= removed_pass_count;
        }
    }
};
Back to top page

This site uses Just the Docs, a documentation theme for Jekyll.