:heavy_check_mark: Aho-Corasick (Tree/aho_corasick.hpp)

Aho-Corasick

Aho-Corasick は、複数のパターン列をまとめて扱うためのオートマトンです。
Trie に suffix link を追加することで、現在までに読んだ列の末尾が、登録済みパターンのどの接頭辞に一致しているかを状態として管理できます。

このライブラリでは、T 型の値からなるパターン列を登録し、列を先頭から順に読んだときに、いずれかのパターンが末尾に出現しているかを判定できます。
文字列に対して使う場合は AhoCorasick<char> として利用できます。

主に以下の用途に使えます。

  • 複数パターンの文字列検索
  • 禁止パターンを含まない列の数え上げ
  • Trie 上の suffix link の構築
  • パターン集合に対する状態遷移オートマトンの構築

コンストラクタ

AhoCorasick()

空の Aho-Corasick オートマトンを作成します。
初期状態では、内部の Trie は root ノードのみを持ちます。

計算量

  • $O(1)$

insert

template <typename Container>
int insert(const Container& pattern)

パターン列 pattern を追加します。
戻り値として、pattern に対応する終端ノード ID を返します。

Container::value_typeT と一致している必要があります。

insert を呼び出した後は、再度 build() を呼ぶ必要があります。

制約

  • Container::value_typeT と一致すること
  • Tunordered_map<T, int> のキーとして利用できること

計算量

  • $O(\lvert pattern \rvert)$

build

void build()

登録済みパターンから suffix link とマッチ情報を構築します。

各ノード v に対して、suffix_link[v]v に対応する列の真の接尾辞のうち、Trie に存在する最長のものに対応するノード ID です。
また、matched[v] は、状態 v に到達した時点で、いずれかの登録済みパターンが末尾に出現しているかを表します。

moveis_matchedget_suffix_link を使う前に、この関数を呼ぶ必要があります。

計算量

  • $O(M \alpha)$

ここで、M は Trie のノード数、\alpha は各ノードから出ている遷移数の合計に対する走査量です。
実装上は、Trie に存在する辺を BFS で 1 回ずつ処理します。


move

int move(int node_id, const T& value) const

現在状態 node_id から値 value を 1 つ読んだ後の状態を返します。

直接の遷移が存在しない場合は suffix link をたどり、遷移可能な最長の接尾辞状態へ移動します。
どの状態からも遷移できない場合は root に戻ります。

制約

  • build() が呼ばれていること
  • $0 \leq node_id < size()$

計算量

  • 最悪 $O(\text{状態の深さ})$

is_matched

bool is_matched(int node_id) const

状態 node_id に到達した時点で、いずれかの登録済みパターンが末尾に出現しているかを返します。

例えば、パターンとして "bc" が登録されている場合、現在までに読んだ文字列の末尾が "bc" であれば true を返します。
また、現在状態そのものがパターン終端でなくても、suffix link をたどった先にパターン終端がある場合は true になります。

制約

  • build() が呼ばれていること
  • $0 \leq node_id < size()$

計算量

  • $O(1)$

int get_suffix_link(int node_id) const

状態 node_id の suffix link 先のノード ID を返します。

suffix link は、現在ノードに対応する列の真の接尾辞のうち、Trie に存在する最長のものを表します。

制約

  • build() が呼ばれていること
  • $0 \leq node_id < size()$

計算量

  • $O(1)$

parent

int parent(int node_id) const

Trie 上での親ノード ID を返します。
root の parent-1 です。

制約

  • $0 \leq node_id < size()$

計算量

  • $O(1)$

transition_value

T transition_value(int node_id) const

Trie 上で、親ノードから node_id へ遷移するときに使った値を返します。
root では意味を持ちません。

制約

  • $0 \leq node_id < size()$

計算量

  • $O(1)$

size

int size() const

内部の Trie のノード数を返します。

計算量

  • $O(1)$

Depends on

Verified with

Code

#include <bits/stdc++.h>
using namespace std;

#include "Tree/trie.hpp"

template <class T>
struct AhoCorasick
{
    Trie<T> trie;

    // suffix_link[v]:
    // str(v) の真の接尾辞のうち、Trie に存在する最長のものに対応するノードID
    vector<int> suffix_link;

    // matched[v]:
    // v に到達した時点で、何らかの登録済みパターンが末尾にマッチしているか
    vector<bool> matched;

    // build 後に true
    bool built = false;

    AhoCorasick() = default;

    template <typename Container>
        requires same_as<typename Container::value_type, T>
    int insert(const Container &pattern)
    {
        built = false;
        return trie.insert(pattern);
    }

    void build()
    {
        const int node_count = trie.nodes.size();

        suffix_link.assign(node_count, 0);
        matched.assign(node_count, false);

        for (int node_id = 0; node_id < node_count; node_id++)
        {
            matched[node_id] = trie.nodes[node_id].end_count > 0;
        }

        queue<int> que;

        // root の子の suffix link は root
        for (const auto &[value, child] : trie.nodes[0].next_node_ids)
        {
            suffix_link[child] = 0;
            que.push(child);
        }

        while (!que.empty())
        {
            int node_id = que.front();
            que.pop();

            // suffix link 先でマッチしているなら、このノードもマッチ状態
            matched[node_id] = matched[node_id] || matched[suffix_link[node_id]];

            for (const auto &[value, child] : trie.nodes[node_id].next_node_ids)
            {
                suffix_link[child] = find_next(suffix_link[node_id], value);
                que.push(child);
            }
        }

        built = true;
    }

    int move(int node_id, const T &value) const
    {
        assert(built);

        return find_next(node_id, value);
    }

    bool is_matched(int node_id) const
    {
        assert(built);

        return matched[node_id];
    }

    int get_suffix_link(int node_id) const
    {
        assert(built);

        return suffix_link[node_id];
    }

    int parent(int node_id) const
    {
        return trie.nodes[node_id].parent;
    }

    T transition_value(int node_id) const
    {
        return trie.nodes[node_id].transition_value;
    }

    int size() const
    {
        return trie.nodes.size();
    }

private:
    int find_next(int node_id, const T &value) const
    {
        while (node_id != 0 && !trie.nodes[node_id].next_node_ids.contains(value))
        {
            node_id = suffix_link[node_id];
        }

        if (trie.nodes[node_id].next_node_ids.contains(value))
        {
            return trie.nodes[node_id].next_node_ids.at(value);
        }

        return 0;
    }
};
#line 1 "Tree/aho_corasick.hpp"
#include <bits/stdc++.h>
using namespace std;

#line 2 "Tree/trie.hpp"
using namespace std;

template <class T>
struct Trie
{
    struct Node
    {
        // 子ノードへの遷移
        unordered_map<T, int> next_node_ids;

        // 親ノードのID
        // root の parent は -1
        int parent = -1;

        // parent からこのノードへ遷移するときに使った値
        // root では意味を持たない
        T transition_value{};

        // このノードを通過する文字列の個数
        int pass_count = 0;

        // このノードで終端する文字列の個数
        int end_count = 0;

        Node() {};

        Node(int parent, const T &transition_value)
            : parent(parent), transition_value(transition_value) {};
    };

    vector<Node> nodes;

    Trie() : nodes()
    {
        nodes.push_back(Node());
    };

    template <typename Container>
        requires same_as<typename Container::value_type, T>
    int insert(const Container &v)
    {
        int node_id = 0;
        nodes[node_id].pass_count++;

        for (const auto &c : v)
        {
            if (!nodes[node_id].next_node_ids.contains(c))
            {
                int next_node_id = nodes.size();
                nodes[node_id].next_node_ids[c] = next_node_id;

                // 新しく作るノードは、現在のノードを親として持つ
                nodes.push_back(Node(node_id, c));
            }

            node_id = nodes[node_id].next_node_ids[c];
            nodes[node_id].pass_count++;
        }

        nodes[node_id].end_count++;

        // 挿入した列に対応する終端ノードIDを返す
        return node_id;
    }

    template <typename Container, typename Func>
        requires same_as<typename Container::value_type, T> && invocable<Func, int, bool>
    void iterate_prefix(const Container &v, Func f)
    {
        int node_id = 0;

        for (auto itr = v.begin(); itr != v.end(); itr++)
        {
            if (!nodes[node_id].next_node_ids.contains(*itr))
            {
                return;
            }

            node_id = nodes[node_id].next_node_ids[*itr];

            // 第2引数は、現在のノードが入力列 v の終端かどうか
            f(node_id, next(itr) == v.end());
        }
    }

    template <typename Container>
        requires same_as<typename Container::value_type, T>
    void erase_subtrie(const Container &prefix)
    {
        int node_id = 0;
        vector<int> ancestors;

        for (const auto &c : prefix)
        {
            ancestors.push_back(node_id);

            if (!nodes[node_id].next_node_ids.contains(c))
            {
                return;
            }

            node_id = nodes[node_id].next_node_ids[c];
        }

        int removed_pass_count = nodes[node_id].pass_count;

        // prefix 以下の部分木を論理的に削除する
        nodes[node_id].pass_count = 0;
        nodes[node_id].end_count = 0;
        nodes[node_id].next_node_ids.clear();

        // 祖先の通過回数から、削除した部分木の通過回数を引く
        for (const auto &ancestor : ancestors)
        {
            nodes[ancestor].pass_count -= removed_pass_count;
        }
    }
};
#line 5 "Tree/aho_corasick.hpp"

template <class T>
struct AhoCorasick
{
    Trie<T> trie;

    // suffix_link[v]:
    // str(v) の真の接尾辞のうち、Trie に存在する最長のものに対応するノードID
    vector<int> suffix_link;

    // matched[v]:
    // v に到達した時点で、何らかの登録済みパターンが末尾にマッチしているか
    vector<bool> matched;

    // build 後に true
    bool built = false;

    AhoCorasick() = default;

    template <typename Container>
        requires same_as<typename Container::value_type, T>
    int insert(const Container &pattern)
    {
        built = false;
        return trie.insert(pattern);
    }

    void build()
    {
        const int node_count = trie.nodes.size();

        suffix_link.assign(node_count, 0);
        matched.assign(node_count, false);

        for (int node_id = 0; node_id < node_count; node_id++)
        {
            matched[node_id] = trie.nodes[node_id].end_count > 0;
        }

        queue<int> que;

        // root の子の suffix link は root
        for (const auto &[value, child] : trie.nodes[0].next_node_ids)
        {
            suffix_link[child] = 0;
            que.push(child);
        }

        while (!que.empty())
        {
            int node_id = que.front();
            que.pop();

            // suffix link 先でマッチしているなら、このノードもマッチ状態
            matched[node_id] = matched[node_id] || matched[suffix_link[node_id]];

            for (const auto &[value, child] : trie.nodes[node_id].next_node_ids)
            {
                suffix_link[child] = find_next(suffix_link[node_id], value);
                que.push(child);
            }
        }

        built = true;
    }

    int move(int node_id, const T &value) const
    {
        assert(built);

        return find_next(node_id, value);
    }

    bool is_matched(int node_id) const
    {
        assert(built);

        return matched[node_id];
    }

    int get_suffix_link(int node_id) const
    {
        assert(built);

        return suffix_link[node_id];
    }

    int parent(int node_id) const
    {
        return trie.nodes[node_id].parent;
    }

    T transition_value(int node_id) const
    {
        return trie.nodes[node_id].transition_value;
    }

    int size() const
    {
        return trie.nodes.size();
    }

private:
    int find_next(int node_id, const T &value) const
    {
        while (node_id != 0 && !trie.nodes[node_id].next_node_ids.contains(value))
        {
            node_id = suffix_link[node_id];
        }

        if (trie.nodes[node_id].next_node_ids.contains(value))
        {
            return trie.nodes[node_id].next_node_ids.at(value);
        }

        return 0;
    }
};
Back to top page

This site uses Just the Docs, a documentation theme for Jekyll.