:heavy_check_mark: 重み付きUnion Find (DataStructure/weighted_union_find.hpp)

重み付きUnion Find

無向グラフに対して、

  • 辺の追加
  • $2$ 頂点が連結かの判定
  • $2$ 頂点間の重み差の計算

をならし $O(\alpha(n))$ 時間で処理することが出来ます。

また、内部的に各連結成分ごとに代表となる頂点を $1$ つ持っています。辺の追加により連結成分がマージされる時、新たな代表元は元の連結成分の代表元のうちどちらかになります。

コンストラクタ

WeightedUnionFind<T> wuf(int n)
  • TTの重みをもった頂点を作ります。
  • $n$ 頂点 $0$ 辺の無向グラフを作ります。

計算量

  • $O(n)$

merge

int wuf.merge(int a, int b, T w)

頂点$a$と頂点$b$の重みの差を$w$に設定します。 $a, b$ が連結だった場合はその代表元、非連結だった場合は新たな代表元を返します。

制約

  • $0 \leq a < n$
  • $0 \leq b < n$
  • $ a $と$ b $が連結の場合、$weight(a) + w = weight(b)$

計算量

  • ならし $O(\alpha(n))$

same

bool wuf.same(int a, int b)

頂点 $a, b$ が連結かどうかを返します。

制約

  • $0 \leq a < n$
  • $0 \leq b < n$

計算量

  • ならし $O(\alpha(n))$

valid

bool wuf.valid(int a, int b, T w)

頂点$a$と頂点$b$の重みの差を$w$に設定した場合に、ここまでにmergeされた情報に矛盾が生じないかを返します。

制約

  • $0 \leq a < n$
  • $0 \leq b < n$

計算量

  • ならし $O(\alpha(n))$

leader

int wuf.leader(int a)

頂点 $a$ の属する連結成分の代表元を返します。

制約

  • $0 \leq a < n$

計算量

  • ならし $O(\alpha(n))$

size

int wuf.size(int a)

頂点 $a$ の属する連結成分のサイズを返します。

制約

  • $0 \leq a < n$

計算量

  • ならし $O(\alpha(n))$

weight

T wuf.weight(int a)

頂点 $a$ とその代表元との重みの差を返します。

制約

  • $0 \leq a < n$

計算量

  • ならし $O(\alpha(n))$

weight

T wuf.diff(int a, int b)

頂点$a$と頂点$b$の重みの差を返します。

制約

  • $0 \leq a < n$
  • $0 \leq b < n$
  • $a$ と $b$は連結

計算量

  • ならし $O(\alpha(n))$

groups

vector<vector<int>> wuf.groups()

グラフを連結成分に分け、その情報を返します。

返り値は「「一つの連結成分の頂点番号のリスト」のリスト」です。 (内側外側限らず)vector内でどの順番で頂点が格納されているかは未定義です。

計算量

  • $O(n)$

Verified with

Code

#include <bits/stdc++.h>
using namespace std;

template <class T = long long>
struct WeightedUnionFind
{
public:
    WeightedUnionFind() : _n(0) {}
    explicit WeightedUnionFind(int n, T e = 0) : _n(n), parent_or_size(n, -1), diff_weight(n, e) {}

    int merge(int a, int b, T w)
    {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        assert(valid(a, b, w));

        int x = leader(a), y = leader(b);
        if (same(a, b))
        {
            return x;
        }

        w += weight(a), w -= weight(b);
        if (x == y)
        {
            return x;
        }
        if (-parent_or_size[x] < -parent_or_size[y])
        {
            swap(x, y), w *= -1;
        }
        parent_or_size[x] += parent_or_size[y];
        parent_or_size[y] = x;
        diff_weight[y] = w;
        return x;
    }

    bool same(int a, int b)
    {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        return leader(a) == leader(b);
    }

    bool valid(int a, int b, T w)
    {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        return !same(a, b) || diff(a, b) == w;
    }

    int leader(int a)
    {
        assert(0 <= a && a < _n);
        if (parent_or_size[a] < 0)
        {
            return a;
        }
        int r = leader(parent_or_size[a]);
        diff_weight[a] += diff_weight[parent_or_size[a]];
        return parent_or_size[a] = r;
    }

    int size(int a)
    {
        assert(0 <= a && a < _n);
        return -parent_or_size[leader(a)];
    }

    T weight(int a)
    {
        assert(0 <= a && a < _n);
        leader(a);
        return diff_weight[a];
    }

    T diff(int a, int b)
    {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        assert(same(a, b));
        return weight(b) - weight(a);
    }

    vector<vector<int>> groups()
    {
        vector<int> leader_buf(_n), group_size(_n);
        for (int i = 0; i < _n; i++)
        {
            leader_buf[i] = leader(i);
            group_size[leader_buf[i]]++;
        }
        vector<vector<int>> result(_n);
        for (int i = 0; i < _n; i++)
        {
            result[i].reserve(group_size[i]);
        }
        for (int i = 0; i < _n; i++)
        {
            result[leader_buf[i]].push_back(i);
        }
        result.erase(
            remove_if(result.begin(), result.end(),
                      [&](const vector<int> &v)
                      { return v.empty(); }),
            result.end());
        return result;
    }

private:
    int _n;
    vector<int> parent_or_size;
    vector<T> diff_weight;
};
#line 1 "DataStructure/weighted_union_find.hpp"
#include <bits/stdc++.h>
using namespace std;

template <class T = long long>
struct WeightedUnionFind
{
public:
    WeightedUnionFind() : _n(0) {}
    explicit WeightedUnionFind(int n, T e = 0) : _n(n), parent_or_size(n, -1), diff_weight(n, e) {}

    int merge(int a, int b, T w)
    {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        assert(valid(a, b, w));

        int x = leader(a), y = leader(b);
        if (same(a, b))
        {
            return x;
        }

        w += weight(a), w -= weight(b);
        if (x == y)
        {
            return x;
        }
        if (-parent_or_size[x] < -parent_or_size[y])
        {
            swap(x, y), w *= -1;
        }
        parent_or_size[x] += parent_or_size[y];
        parent_or_size[y] = x;
        diff_weight[y] = w;
        return x;
    }

    bool same(int a, int b)
    {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        return leader(a) == leader(b);
    }

    bool valid(int a, int b, T w)
    {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        return !same(a, b) || diff(a, b) == w;
    }

    int leader(int a)
    {
        assert(0 <= a && a < _n);
        if (parent_or_size[a] < 0)
        {
            return a;
        }
        int r = leader(parent_or_size[a]);
        diff_weight[a] += diff_weight[parent_or_size[a]];
        return parent_or_size[a] = r;
    }

    int size(int a)
    {
        assert(0 <= a && a < _n);
        return -parent_or_size[leader(a)];
    }

    T weight(int a)
    {
        assert(0 <= a && a < _n);
        leader(a);
        return diff_weight[a];
    }

    T diff(int a, int b)
    {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        assert(same(a, b));
        return weight(b) - weight(a);
    }

    vector<vector<int>> groups()
    {
        vector<int> leader_buf(_n), group_size(_n);
        for (int i = 0; i < _n; i++)
        {
            leader_buf[i] = leader(i);
            group_size[leader_buf[i]]++;
        }
        vector<vector<int>> result(_n);
        for (int i = 0; i < _n; i++)
        {
            result[i].reserve(group_size[i]);
        }
        for (int i = 0; i < _n; i++)
        {
            result[leader_buf[i]].push_back(i);
        }
        result.erase(
            remove_if(result.begin(), result.end(),
                      [&](const vector<int> &v)
                      { return v.empty(); }),
            result.end());
        return result;
    }

private:
    int _n;
    vector<int> parent_or_size;
    vector<T> diff_weight;
};
Back to top page