Aho-Corasick (Tree/aho_corasick.hpp)
- View this file on GitHub
- View document part on GitHub
- Last update: 2026-05-19 21:35:42+09:00
- Include:
#include "Tree/aho_corasick.hpp"
Aho-Corasick
Aho-Corasick は、複数のパターン列をまとめて扱うためのオートマトンです。
Trie に suffix link を追加することで、現在までに読んだ列の末尾が、登録済みパターンのどの接頭辞に一致しているかを状態として管理できます。
このライブラリでは、T 型の値からなるパターン列を登録し、列を先頭から順に読んだときに、いずれかのパターンが末尾に出現しているかを判定できます。
文字列に対して使う場合は AhoCorasick<char> として利用できます。
主に以下の用途に使えます。
- 複数パターンの文字列検索
- 禁止パターンを含まない列の数え上げ
- Trie 上の suffix link の構築
- パターン集合に対する状態遷移オートマトンの構築
コンストラクタ
AhoCorasick()
空の Aho-Corasick オートマトンを作成します。
初期状態では、内部の Trie は root ノードのみを持ちます。
計算量
- $O(1)$
insert
template <typename Container>
int insert(const Container& pattern)
パターン列 pattern を追加します。
戻り値として、pattern に対応する終端ノード ID を返します。
Container::value_type は T と一致している必要があります。
insert を呼び出した後は、再度 build() を呼ぶ必要があります。
制約
-
Container::value_typeがTと一致すること -
Tはunordered_map<T, int>のキーとして利用できること
計算量
- $O(\lvert pattern \rvert)$
build
void build()
登録済みパターンから suffix link とマッチ情報を構築します。
各ノード v に対して、suffix_link[v] は v に対応する列の真の接尾辞のうち、Trie に存在する最長のものに対応するノード ID です。
また、matched[v] は、状態 v に到達した時点で、いずれかの登録済みパターンが末尾に出現しているかを表します。
move、is_matched、get_suffix_link を使う前に、この関数を呼ぶ必要があります。
計算量
- $O(M \alpha)$
ここで、M は Trie のノード数、\alpha は各ノードから出ている遷移数の合計に対する走査量です。
実装上は、Trie に存在する辺を BFS で 1 回ずつ処理します。
move
int move(int node_id, const T& value) const
現在状態 node_id から値 value を 1 つ読んだ後の状態を返します。
直接の遷移が存在しない場合は suffix link をたどり、遷移可能な最長の接尾辞状態へ移動します。
どの状態からも遷移できない場合は root に戻ります。
制約
-
build()が呼ばれていること - $0 \leq node_id < size()$
計算量
- 最悪 $O(\text{状態の深さ})$
is_matched
bool is_matched(int node_id) const
状態 node_id に到達した時点で、いずれかの登録済みパターンが末尾に出現しているかを返します。
例えば、パターンとして "bc" が登録されている場合、現在までに読んだ文字列の末尾が "bc" であれば true を返します。
また、現在状態そのものがパターン終端でなくても、suffix link をたどった先にパターン終端がある場合は true になります。
制約
-
build()が呼ばれていること - $0 \leq node_id < size()$
計算量
- $O(1)$
get_suffix_link
int get_suffix_link(int node_id) const
状態 node_id の suffix link 先のノード ID を返します。
suffix link は、現在ノードに対応する列の真の接尾辞のうち、Trie に存在する最長のものを表します。
制約
-
build()が呼ばれていること - $0 \leq node_id < size()$
計算量
- $O(1)$
parent
int parent(int node_id) const
Trie 上での親ノード ID を返します。
root の parent は -1 です。
制約
- $0 \leq node_id < size()$
計算量
- $O(1)$
transition_value
T transition_value(int node_id) const
Trie 上で、親ノードから node_id へ遷移するときに使った値を返します。
root では意味を持ちません。
制約
- $0 \leq node_id < size()$
計算量
- $O(1)$
size
int size() const
内部の Trie のノード数を返します。
計算量
- $O(1)$
Depends on
Verified with
Code
#include <bits/stdc++.h>
using namespace std;
#include "Tree/trie.hpp"
template <class T>
struct AhoCorasick
{
Trie<T> trie;
// suffix_link[v]:
// str(v) の真の接尾辞のうち、Trie に存在する最長のものに対応するノードID
vector<int> suffix_link;
// matched[v]:
// v に到達した時点で、何らかの登録済みパターンが末尾にマッチしているか
vector<bool> matched;
// build 後に true
bool built = false;
AhoCorasick() = default;
template <typename Container>
requires same_as<typename Container::value_type, T>
int insert(const Container &pattern)
{
built = false;
return trie.insert(pattern);
}
void build()
{
const int node_count = trie.nodes.size();
suffix_link.assign(node_count, 0);
matched.assign(node_count, false);
for (int node_id = 0; node_id < node_count; node_id++)
{
matched[node_id] = trie.nodes[node_id].end_count > 0;
}
queue<int> que;
// root の子の suffix link は root
for (const auto &[value, child] : trie.nodes[0].next_node_ids)
{
suffix_link[child] = 0;
que.push(child);
}
while (!que.empty())
{
int node_id = que.front();
que.pop();
// suffix link 先でマッチしているなら、このノードもマッチ状態
matched[node_id] = matched[node_id] || matched[suffix_link[node_id]];
for (const auto &[value, child] : trie.nodes[node_id].next_node_ids)
{
suffix_link[child] = find_next(suffix_link[node_id], value);
que.push(child);
}
}
built = true;
}
int move(int node_id, const T &value) const
{
assert(built);
return find_next(node_id, value);
}
bool is_matched(int node_id) const
{
assert(built);
return matched[node_id];
}
int get_suffix_link(int node_id) const
{
assert(built);
return suffix_link[node_id];
}
int parent(int node_id) const
{
return trie.nodes[node_id].parent;
}
T transition_value(int node_id) const
{
return trie.nodes[node_id].transition_value;
}
int size() const
{
return trie.nodes.size();
}
private:
int find_next(int node_id, const T &value) const
{
while (node_id != 0 && !trie.nodes[node_id].next_node_ids.contains(value))
{
node_id = suffix_link[node_id];
}
if (trie.nodes[node_id].next_node_ids.contains(value))
{
return trie.nodes[node_id].next_node_ids.at(value);
}
return 0;
}
};
#line 1 "Tree/aho_corasick.hpp"
#include <bits/stdc++.h>
using namespace std;
#line 2 "Tree/trie.hpp"
using namespace std;
template <class T>
struct Trie
{
struct Node
{
// 子ノードへの遷移
unordered_map<T, int> next_node_ids;
// 親ノードのID
// root の parent は -1
int parent = -1;
// parent からこのノードへ遷移するときに使った値
// root では意味を持たない
T transition_value{};
// このノードを通過する文字列の個数
int pass_count = 0;
// このノードで終端する文字列の個数
int end_count = 0;
Node() {};
Node(int parent, const T &transition_value)
: parent(parent), transition_value(transition_value) {};
};
vector<Node> nodes;
Trie() : nodes()
{
nodes.push_back(Node());
};
template <typename Container>
requires same_as<typename Container::value_type, T>
int insert(const Container &v)
{
int node_id = 0;
nodes[node_id].pass_count++;
for (const auto &c : v)
{
if (!nodes[node_id].next_node_ids.contains(c))
{
int next_node_id = nodes.size();
nodes[node_id].next_node_ids[c] = next_node_id;
// 新しく作るノードは、現在のノードを親として持つ
nodes.push_back(Node(node_id, c));
}
node_id = nodes[node_id].next_node_ids[c];
nodes[node_id].pass_count++;
}
nodes[node_id].end_count++;
// 挿入した列に対応する終端ノードIDを返す
return node_id;
}
template <typename Container, typename Func>
requires same_as<typename Container::value_type, T> && invocable<Func, int, bool>
void iterate_prefix(const Container &v, Func f)
{
int node_id = 0;
for (auto itr = v.begin(); itr != v.end(); itr++)
{
if (!nodes[node_id].next_node_ids.contains(*itr))
{
return;
}
node_id = nodes[node_id].next_node_ids[*itr];
// 第2引数は、現在のノードが入力列 v の終端かどうか
f(node_id, next(itr) == v.end());
}
}
template <typename Container>
requires same_as<typename Container::value_type, T>
void erase_subtrie(const Container &prefix)
{
int node_id = 0;
vector<int> ancestors;
for (const auto &c : prefix)
{
ancestors.push_back(node_id);
if (!nodes[node_id].next_node_ids.contains(c))
{
return;
}
node_id = nodes[node_id].next_node_ids[c];
}
int removed_pass_count = nodes[node_id].pass_count;
// prefix 以下の部分木を論理的に削除する
nodes[node_id].pass_count = 0;
nodes[node_id].end_count = 0;
nodes[node_id].next_node_ids.clear();
// 祖先の通過回数から、削除した部分木の通過回数を引く
for (const auto &ancestor : ancestors)
{
nodes[ancestor].pass_count -= removed_pass_count;
}
}
};
#line 5 "Tree/aho_corasick.hpp"
template <class T>
struct AhoCorasick
{
Trie<T> trie;
// suffix_link[v]:
// str(v) の真の接尾辞のうち、Trie に存在する最長のものに対応するノードID
vector<int> suffix_link;
// matched[v]:
// v に到達した時点で、何らかの登録済みパターンが末尾にマッチしているか
vector<bool> matched;
// build 後に true
bool built = false;
AhoCorasick() = default;
template <typename Container>
requires same_as<typename Container::value_type, T>
int insert(const Container &pattern)
{
built = false;
return trie.insert(pattern);
}
void build()
{
const int node_count = trie.nodes.size();
suffix_link.assign(node_count, 0);
matched.assign(node_count, false);
for (int node_id = 0; node_id < node_count; node_id++)
{
matched[node_id] = trie.nodes[node_id].end_count > 0;
}
queue<int> que;
// root の子の suffix link は root
for (const auto &[value, child] : trie.nodes[0].next_node_ids)
{
suffix_link[child] = 0;
que.push(child);
}
while (!que.empty())
{
int node_id = que.front();
que.pop();
// suffix link 先でマッチしているなら、このノードもマッチ状態
matched[node_id] = matched[node_id] || matched[suffix_link[node_id]];
for (const auto &[value, child] : trie.nodes[node_id].next_node_ids)
{
suffix_link[child] = find_next(suffix_link[node_id], value);
que.push(child);
}
}
built = true;
}
int move(int node_id, const T &value) const
{
assert(built);
return find_next(node_id, value);
}
bool is_matched(int node_id) const
{
assert(built);
return matched[node_id];
}
int get_suffix_link(int node_id) const
{
assert(built);
return suffix_link[node_id];
}
int parent(int node_id) const
{
return trie.nodes[node_id].parent;
}
T transition_value(int node_id) const
{
return trie.nodes[node_id].transition_value;
}
int size() const
{
return trie.nodes.size();
}
private:
int find_next(int node_id, const T &value) const
{
while (node_id != 0 && !trie.nodes[node_id].next_node_ids.contains(value))
{
node_id = suffix_link[node_id];
}
if (trie.nodes[node_id].next_node_ids.contains(value))
{
return trie.nodes[node_id].next_node_ids.at(value);
}
return 0;
}
};