Interval Set (DataStructure/interval_set.hpp)
- View this file on GitHub
- View document part on GitHub
- Last update: 2025-12-11 22:28:56+09:00
- Include:
#include "DataStructure/interval_set.hpp"
Interval Set
互いに交差しない半開区間 [L, R) の集合を管理するデータ構造です。
区間の追加時には自動的に隣接・重複する区間をマージし、
連結判定や区間長の取得、全体の長さの取得などができます。
- テンプレート引数
-
T: 区間端点の型 -
Compare:std::set用の比較関数(デフォルトstd::less<T>)
-
- 制約
-
CompareはTに対してstd::strict_weak_orderを満たす -
T同士の差T - Tとstd::absが利用できる(長さ計算用)
-
- 内部表現
- 区間はすべて 半開区間
[L, R)として保存されます - どの 2 区間も互いに交差・隣接しないように保たれます
- 区間はすべて 半開区間
コンストラクタ
IntervalSet<T, Compare>()
空の区間集合を構築します。
計算量
- $O(1)$
add_interval
std::pair<T, T> add_interval(const T& L, const T& R)
std::pair<T, T> add_interval(const std::pair<T, T>& LR)
半開区間 [L, R) を追加し、必要に応じて既存の区間とマージします。
戻り値は、追加後に [L, R) を含んでいる 1 つの大きな区間 [l, r) です。
簡単な例
- もともと
[1, 3) [5, 7)がある状態でadd_interval(2, 6)を呼ぶと、
区間は[1, 7)になります。
制約
-
L < RまたはLR.first < LR.second
計算量
- 変更される区間数を $k$、全体の区間数を $M$ とすると
$O(k \log M)$
erase_interval
void erase_interval(const T& L, const T& R)
void erase_interval(const std::pair<T, T>& LR)
区間集合から半開区間 [L, R) を 集合として差し引きます。
部分的に重なる区間は分割され、完全に含まれる区間は削除されます。
簡単な例
- もともと
[1, 3) [5, 8) [10, 13)がある状態で
erase_interval(2, 7)を呼ぶと、
区間は[1, 2) [7, 8) [10, 13)になります。
制約
-
L < RまたはLR.first < LR.second
計算量
- 変更される区間数を $k$、全体の区間数を $M$ とすると
$O(k \log M)$
same
bool same(T u, T v) const
u と v が同じ区間(同じ連結成分)に属するかを判定します。
どちらか一方でもどの区間にも属していなければ false を返します。
計算量
- $O(\log M)$
length
T length(const T& u) const
u が属する区間 [L, R) の長さ R - L を返します。
u がどの区間にも属していない場合は 0 を返します。
T が整数型のときは、「区間に含まれる要素数」として扱えます。
計算量
- $O(\log M)$
get_sum_lentgth
T get_sum_lentgth() const
すべての区間長の合計を返します。
計算量
- $O(1)$
print(デバッグ用途)
void print() const
保持している区間を [L, R) 形式で標準出力に出力します。
計算量
- 区間数を $M$ とすると $O(M)$
Verified with
test/DataStructure/interval-set/yukicoder-2292.cpp
test/DataStructure/interval-set/yukicoder-3017.cpp
test/DataStructure/interval-set/yukicoder-674.cpp
Code
#include <bits/stdc++.h>
using namespace std;
// contains
// find_intervalのpublic化
// 次の区間までの
template <typename T, typename Compare = std::less<T>>
requires std::strict_weak_order<Compare, T, T>
class IntervalSet
{
private:
// すべての区間長の合計
T sum_length = 0;
// 互いに交差しない半開区間 [first, second) の集合
set<pair<T, T>> intervals;
// 点 x を含む区間のイテレータ(なければ end)
typename set<pair<T, T>>::const_iterator
find_interval(const T &x) const
{
// first >= x となる最初の区間
auto itr = intervals.lower_bound({x, x});
// ひとつ左だけ、x を含みうる
if (itr != intervals.begin())
{
auto prev_itr = prev(itr);
if (prev_itr->first <= x && x < prev_itr->second)
return prev_itr;
}
// ぴったり x から始まる区間
if (itr != intervals.end() && itr->first <= x && x < itr->second)
return itr;
return intervals.end();
}
public:
IntervalSet() {}
// [L, R) を追加し、必要なら隣接区間とマージする
pair<T, T> add_interval(const T &L, const T &R)
{
return add_interval({L, R});
}
// [L, R) を追加し、必要なら隣接区間とマージする
pair<T, T> add_interval(const pair<T, T> &LR)
{
assert(LR.first < LR.second);
auto [l, r] = LR;
// first >= l となる最初の区間
auto itr = intervals.lower_bound(LR);
// 左側に 1 つだけ重なりうる区間を見る
if (itr != intervals.begin())
{
auto prev_itr = prev(itr);
if (prev_itr->second >= l)
{
l = min(l, prev_itr->first);
r = max(r, prev_itr->second);
sum_length -= abs(prev_itr->second - prev_itr->first);
itr = intervals.erase(prev_itr);
}
}
// 先頭が r 以下の区間をすべてマージ
while (itr != intervals.end() && itr->first <= r)
{
l = min(l, itr->first);
r = max(r, itr->second);
sum_length -= abs(itr->second - itr->first);
itr = intervals.erase(itr);
}
intervals.insert({l, r});
sum_length += abs(r - l);
return {l, r};
}
// 区間集合から [L, R) を削除(集合としての差)
void erase_interval(const T &L, const T &R)
{
erase_interval({L, R});
}
// 区間集合から [L, R) を削除(集合としての差)
void erase_interval(const std::pair<T, T> &LR)
{
assert(LR.first < LR.second);
const auto [l, r] = LR;
if (intervals.empty())
return;
// first >= l となる最初の区間
auto itr = intervals.lower_bound(LR);
// 左側に 1 つだけ、[l, r) と交差しうる区間を見る
if (itr != intervals.begin())
{
auto prev_itr = prev(itr);
const auto [a, b] = *prev_itr;
// [a, b) と [l, r) が交差
if (b > l && a < r)
{
itr = intervals.erase(prev_itr);
sum_length -= abs(b - a);
if (a < l)
{
sum_length += abs(l - a);
intervals.insert({a, l});
}
if (r < b)
{
sum_length += abs(b - r);
intervals.insert({r, b});
}
}
}
// ここから先は first >= l が保証されている
while (itr != intervals.end())
{
const auto [a, b] = *itr;
if (a >= r)
break; // これ以降は交差しない
if (b <= r)
{
// 完全に [l, r) に含まれる区間
itr = intervals.erase(itr);
sum_length -= abs(b - a);
}
else
{
// a < r < b →右側 [r, b) だけ残す
intervals.erase(itr);
intervals.insert({r, b});
sum_length -= abs(b - a);
sum_length += abs(b - r);
break;
}
}
}
// u と v が同じ区間に属するか
bool same(T u, T v) const
{
if (intervals.empty())
return false;
if (v < u)
swap(u, v);
auto itr = find_interval(u);
if (itr == intervals.end())
return false;
return v < itr->second;
}
// u が属する区間の長さ(なければ 0)
T length(const T &u) const
{
auto itr = find_interval(u);
if (itr == intervals.end())
return T(0);
return itr->second - itr->first;
}
// すべての区間長の合計
T get_sum_lentgth() const
{
return sum_length;
}
void print() const
{
for (auto [l, r] : intervals)
{
cout << "[" << l << ", " << r << ") ";
}
cout << '\n';
}
};
#line 1 "DataStructure/interval_set.hpp"
#include <bits/stdc++.h>
using namespace std;
// contains
// find_intervalのpublic化
// 次の区間までの
template <typename T, typename Compare = std::less<T>>
requires std::strict_weak_order<Compare, T, T>
class IntervalSet
{
private:
// すべての区間長の合計
T sum_length = 0;
// 互いに交差しない半開区間 [first, second) の集合
set<pair<T, T>> intervals;
// 点 x を含む区間のイテレータ(なければ end)
typename set<pair<T, T>>::const_iterator
find_interval(const T &x) const
{
// first >= x となる最初の区間
auto itr = intervals.lower_bound({x, x});
// ひとつ左だけ、x を含みうる
if (itr != intervals.begin())
{
auto prev_itr = prev(itr);
if (prev_itr->first <= x && x < prev_itr->second)
return prev_itr;
}
// ぴったり x から始まる区間
if (itr != intervals.end() && itr->first <= x && x < itr->second)
return itr;
return intervals.end();
}
public:
IntervalSet() {}
// [L, R) を追加し、必要なら隣接区間とマージする
pair<T, T> add_interval(const T &L, const T &R)
{
return add_interval({L, R});
}
// [L, R) を追加し、必要なら隣接区間とマージする
pair<T, T> add_interval(const pair<T, T> &LR)
{
assert(LR.first < LR.second);
auto [l, r] = LR;
// first >= l となる最初の区間
auto itr = intervals.lower_bound(LR);
// 左側に 1 つだけ重なりうる区間を見る
if (itr != intervals.begin())
{
auto prev_itr = prev(itr);
if (prev_itr->second >= l)
{
l = min(l, prev_itr->first);
r = max(r, prev_itr->second);
sum_length -= abs(prev_itr->second - prev_itr->first);
itr = intervals.erase(prev_itr);
}
}
// 先頭が r 以下の区間をすべてマージ
while (itr != intervals.end() && itr->first <= r)
{
l = min(l, itr->first);
r = max(r, itr->second);
sum_length -= abs(itr->second - itr->first);
itr = intervals.erase(itr);
}
intervals.insert({l, r});
sum_length += abs(r - l);
return {l, r};
}
// 区間集合から [L, R) を削除(集合としての差)
void erase_interval(const T &L, const T &R)
{
erase_interval({L, R});
}
// 区間集合から [L, R) を削除(集合としての差)
void erase_interval(const std::pair<T, T> &LR)
{
assert(LR.first < LR.second);
const auto [l, r] = LR;
if (intervals.empty())
return;
// first >= l となる最初の区間
auto itr = intervals.lower_bound(LR);
// 左側に 1 つだけ、[l, r) と交差しうる区間を見る
if (itr != intervals.begin())
{
auto prev_itr = prev(itr);
const auto [a, b] = *prev_itr;
// [a, b) と [l, r) が交差
if (b > l && a < r)
{
itr = intervals.erase(prev_itr);
sum_length -= abs(b - a);
if (a < l)
{
sum_length += abs(l - a);
intervals.insert({a, l});
}
if (r < b)
{
sum_length += abs(b - r);
intervals.insert({r, b});
}
}
}
// ここから先は first >= l が保証されている
while (itr != intervals.end())
{
const auto [a, b] = *itr;
if (a >= r)
break; // これ以降は交差しない
if (b <= r)
{
// 完全に [l, r) に含まれる区間
itr = intervals.erase(itr);
sum_length -= abs(b - a);
}
else
{
// a < r < b →右側 [r, b) だけ残す
intervals.erase(itr);
intervals.insert({r, b});
sum_length -= abs(b - a);
sum_length += abs(b - r);
break;
}
}
}
// u と v が同じ区間に属するか
bool same(T u, T v) const
{
if (intervals.empty())
return false;
if (v < u)
swap(u, v);
auto itr = find_interval(u);
if (itr == intervals.end())
return false;
return v < itr->second;
}
// u が属する区間の長さ(なければ 0)
T length(const T &u) const
{
auto itr = find_interval(u);
if (itr == intervals.end())
return T(0);
return itr->second - itr->first;
}
// すべての区間長の合計
T get_sum_lentgth() const
{
return sum_length;
}
void print() const
{
for (auto [l, r] : intervals)
{
cout << "[" << l << ", " << r << ") ";
}
cout << '\n';
}
};