:heavy_check_mark: Interval Set (DataStructure/interval_set.hpp)

Interval Set

互いに交差しない半開区間 [L, R) の集合を管理するデータ構造です。
区間の追加時には自動的に隣接・重複する区間をマージし、
連結判定や区間長の取得、全体の長さの取得などができます。

  • テンプレート引数
    • T : 区間端点の型
    • Compare : std::set 用の比較関数(デフォルト std::less<T>
  • 制約
    • CompareT に対して std::strict_weak_order を満たす
    • T 同士の差 T - Tstd::abs が利用できる(長さ計算用)
  • 内部表現
    • 区間はすべて 半開区間 [L, R) として保存されます
    • どの 2 区間も互いに交差・隣接しないように保たれます

コンストラクタ

IntervalSet<T, Compare>()

空の区間集合を構築します。

計算量

  • $O(1)$

add_interval

std::pair<T, T> add_interval(const T& L, const T& R)
std::pair<T, T> add_interval(const std::pair<T, T>& LR)

半開区間 [L, R) を追加し、必要に応じて既存の区間とマージします。
戻り値は、追加後に [L, R) を含んでいる 1 つの大きな区間 [l, r) です。

簡単な例

  • もともと [1, 3) [5, 7) がある状態で add_interval(2, 6) を呼ぶと、
    区間は [1, 7) になります。

制約

  • L < R または LR.first < LR.second

計算量

  • 変更される区間数を $k$、全体の区間数を $M$ とすると
    $O(k \log M)$

erase_interval

void erase_interval(const T& L, const T& R)
void erase_interval(const std::pair<T, T>& LR)

区間集合から半開区間 [L, R) を 集合として差し引きます。
部分的に重なる区間は分割され、完全に含まれる区間は削除されます。

簡単な例

  • もともと [1, 3) [5, 8) [10, 13) がある状態で
    erase_interval(2, 7) を呼ぶと、
    区間は [1, 2) [7, 8) [10, 13) になります。

制約

  • L < R または LR.first < LR.second

計算量

  • 変更される区間数を $k$、全体の区間数を $M$ とすると
    $O(k \log M)$

same

bool same(T u, T v) const

uv が同じ区間(同じ連結成分)に属するかを判定します。
どちらか一方でもどの区間にも属していなければ false を返します。

計算量

  • $O(\log M)$

length

T length(const T& u) const

u が属する区間 [L, R) の長さ R - L を返します。
u がどの区間にも属していない場合は 0 を返します。
T が整数型のときは、「区間に含まれる要素数」として扱えます。

計算量

  • $O(\log M)$

get_sum_lentgth

T get_sum_lentgth() const

すべての区間長の合計を返します。

計算量

  • $O(1)$

print(デバッグ用途)

void print() const

保持している区間を [L, R) 形式で標準出力に出力します。

計算量

  • 区間数を $M$ とすると $O(M)$

Verified with

Code

#include <bits/stdc++.h>
using namespace std;

// contains
// find_intervalのpublic化
// 次の区間までの

template <typename T, typename Compare = std::less<T>>
    requires std::strict_weak_order<Compare, T, T>
class IntervalSet
{
private:
    // すべての区間長の合計
    T sum_length = 0;

    // 互いに交差しない半開区間 [first, second) の集合
    set<pair<T, T>> intervals;

    // 点 x を含む区間のイテレータ(なければ end)
    typename set<pair<T, T>>::const_iterator
    find_interval(const T &x) const
    {
        // first >= x となる最初の区間
        auto itr = intervals.lower_bound({x, x});

        // ひとつ左だけ、x を含みうる
        if (itr != intervals.begin())
        {
            auto prev_itr = prev(itr);
            if (prev_itr->first <= x && x < prev_itr->second)
                return prev_itr;
        }

        // ぴったり x から始まる区間
        if (itr != intervals.end() && itr->first <= x && x < itr->second)
            return itr;

        return intervals.end();
    }

public:
    IntervalSet() {}

    // [L, R) を追加し、必要なら隣接区間とマージする
    pair<T, T> add_interval(const T &L, const T &R)
    {
        return add_interval({L, R});
    }

    // [L, R) を追加し、必要なら隣接区間とマージする
    pair<T, T> add_interval(const pair<T, T> &LR)
    {
        assert(LR.first < LR.second);

        auto [l, r] = LR;

        // first >= l となる最初の区間
        auto itr = intervals.lower_bound(LR);

        // 左側に 1 つだけ重なりうる区間を見る
        if (itr != intervals.begin())
        {
            auto prev_itr = prev(itr);
            if (prev_itr->second >= l)
            {
                l = min(l, prev_itr->first);
                r = max(r, prev_itr->second);

                sum_length -= abs(prev_itr->second - prev_itr->first);
                itr = intervals.erase(prev_itr);
            }
        }

        // 先頭が r 以下の区間をすべてマージ
        while (itr != intervals.end() && itr->first <= r)
        {
            l = min(l, itr->first);
            r = max(r, itr->second);

            sum_length -= abs(itr->second - itr->first);
            itr = intervals.erase(itr);
        }

        intervals.insert({l, r});
        sum_length += abs(r - l);

        return {l, r};
    }

    // 区間集合から [L, R) を削除(集合としての差)
    void erase_interval(const T &L, const T &R)
    {
        erase_interval({L, R});
    }

    // 区間集合から [L, R) を削除(集合としての差)
    void erase_interval(const std::pair<T, T> &LR)
    {
        assert(LR.first < LR.second);

        const auto [l, r] = LR;
        if (intervals.empty())
            return;

        // first >= l となる最初の区間
        auto itr = intervals.lower_bound(LR);

        // 左側に 1 つだけ、[l, r) と交差しうる区間を見る
        if (itr != intervals.begin())
        {
            auto prev_itr = prev(itr);
            const auto [a, b] = *prev_itr;

            // [a, b) と [l, r) が交差
            if (b > l && a < r)
            {
                itr = intervals.erase(prev_itr);
                sum_length -= abs(b - a);

                if (a < l)
                {
                    sum_length += abs(l - a);
                    intervals.insert({a, l});
                }
                if (r < b)
                {
                    sum_length += abs(b - r);
                    intervals.insert({r, b});
                }
            }
        }

        // ここから先は first >= l が保証されている
        while (itr != intervals.end())
        {
            const auto [a, b] = *itr;

            if (a >= r)
                break; // これ以降は交差しない

            if (b <= r)
            {
                // 完全に [l, r) に含まれる区間
                itr = intervals.erase(itr);
                sum_length -= abs(b - a);
            }
            else
            {
                // a < r < b →右側 [r, b) だけ残す
                intervals.erase(itr);
                intervals.insert({r, b});
                sum_length -= abs(b - a);
                sum_length += abs(b - r);
                break;
            }
        }
    }

    // u と v が同じ区間に属するか
    bool same(T u, T v) const
    {
        if (intervals.empty())
            return false;

        if (v < u)
            swap(u, v);

        auto itr = find_interval(u);
        if (itr == intervals.end())
            return false;

        return v < itr->second;
    }

    // u が属する区間の長さ(なければ 0)
    T length(const T &u) const
    {
        auto itr = find_interval(u);
        if (itr == intervals.end())
            return T(0);
        return itr->second - itr->first;
    }

    // すべての区間長の合計
    T get_sum_lentgth() const
    {
        return sum_length;
    }

    void print() const
    {
        for (auto [l, r] : intervals)
        {
            cout << "[" << l << ", " << r << ") ";
        }
        cout << '\n';
    }
};
#line 1 "DataStructure/interval_set.hpp"
#include <bits/stdc++.h>
using namespace std;

// contains
// find_intervalのpublic化
// 次の区間までの

template <typename T, typename Compare = std::less<T>>
    requires std::strict_weak_order<Compare, T, T>
class IntervalSet
{
private:
    // すべての区間長の合計
    T sum_length = 0;

    // 互いに交差しない半開区間 [first, second) の集合
    set<pair<T, T>> intervals;

    // 点 x を含む区間のイテレータ(なければ end)
    typename set<pair<T, T>>::const_iterator
    find_interval(const T &x) const
    {
        // first >= x となる最初の区間
        auto itr = intervals.lower_bound({x, x});

        // ひとつ左だけ、x を含みうる
        if (itr != intervals.begin())
        {
            auto prev_itr = prev(itr);
            if (prev_itr->first <= x && x < prev_itr->second)
                return prev_itr;
        }

        // ぴったり x から始まる区間
        if (itr != intervals.end() && itr->first <= x && x < itr->second)
            return itr;

        return intervals.end();
    }

public:
    IntervalSet() {}

    // [L, R) を追加し、必要なら隣接区間とマージする
    pair<T, T> add_interval(const T &L, const T &R)
    {
        return add_interval({L, R});
    }

    // [L, R) を追加し、必要なら隣接区間とマージする
    pair<T, T> add_interval(const pair<T, T> &LR)
    {
        assert(LR.first < LR.second);

        auto [l, r] = LR;

        // first >= l となる最初の区間
        auto itr = intervals.lower_bound(LR);

        // 左側に 1 つだけ重なりうる区間を見る
        if (itr != intervals.begin())
        {
            auto prev_itr = prev(itr);
            if (prev_itr->second >= l)
            {
                l = min(l, prev_itr->first);
                r = max(r, prev_itr->second);

                sum_length -= abs(prev_itr->second - prev_itr->first);
                itr = intervals.erase(prev_itr);
            }
        }

        // 先頭が r 以下の区間をすべてマージ
        while (itr != intervals.end() && itr->first <= r)
        {
            l = min(l, itr->first);
            r = max(r, itr->second);

            sum_length -= abs(itr->second - itr->first);
            itr = intervals.erase(itr);
        }

        intervals.insert({l, r});
        sum_length += abs(r - l);

        return {l, r};
    }

    // 区間集合から [L, R) を削除(集合としての差)
    void erase_interval(const T &L, const T &R)
    {
        erase_interval({L, R});
    }

    // 区間集合から [L, R) を削除(集合としての差)
    void erase_interval(const std::pair<T, T> &LR)
    {
        assert(LR.first < LR.second);

        const auto [l, r] = LR;
        if (intervals.empty())
            return;

        // first >= l となる最初の区間
        auto itr = intervals.lower_bound(LR);

        // 左側に 1 つだけ、[l, r) と交差しうる区間を見る
        if (itr != intervals.begin())
        {
            auto prev_itr = prev(itr);
            const auto [a, b] = *prev_itr;

            // [a, b) と [l, r) が交差
            if (b > l && a < r)
            {
                itr = intervals.erase(prev_itr);
                sum_length -= abs(b - a);

                if (a < l)
                {
                    sum_length += abs(l - a);
                    intervals.insert({a, l});
                }
                if (r < b)
                {
                    sum_length += abs(b - r);
                    intervals.insert({r, b});
                }
            }
        }

        // ここから先は first >= l が保証されている
        while (itr != intervals.end())
        {
            const auto [a, b] = *itr;

            if (a >= r)
                break; // これ以降は交差しない

            if (b <= r)
            {
                // 完全に [l, r) に含まれる区間
                itr = intervals.erase(itr);
                sum_length -= abs(b - a);
            }
            else
            {
                // a < r < b →右側 [r, b) だけ残す
                intervals.erase(itr);
                intervals.insert({r, b});
                sum_length -= abs(b - a);
                sum_length += abs(b - r);
                break;
            }
        }
    }

    // u と v が同じ区間に属するか
    bool same(T u, T v) const
    {
        if (intervals.empty())
            return false;

        if (v < u)
            swap(u, v);

        auto itr = find_interval(u);
        if (itr == intervals.end())
            return false;

        return v < itr->second;
    }

    // u が属する区間の長さ(なければ 0)
    T length(const T &u) const
    {
        auto itr = find_interval(u);
        if (itr == intervals.end())
            return T(0);
        return itr->second - itr->first;
    }

    // すべての区間長の合計
    T get_sum_lentgth() const
    {
        return sum_length;
    }

    void print() const
    {
        for (auto [l, r] : intervals)
        {
            cout << "[" << l << ", " << r << ") ";
        }
        cout << '\n';
    }
};
Back to top page