重み付きUnion Find (DataStructure/weighted_union_find.hpp)
- View this file on GitHub
- View document part on GitHub
- Last update: 2024-10-01 20:26:27+09:00
- Include:
#include "DataStructure/weighted_union_find.hpp"
重み付きUnion Find
無向グラフに対して、
- 辺の追加
- $2$ 頂点が連結かの判定
- $2$ 頂点間の重み差の計算
をならし $O(\alpha(n))$ 時間で処理することが出来ます。
また、内部的に各連結成分ごとに代表となる頂点を $1$ つ持っています。辺の追加により連結成分がマージされる時、新たな代表元は元の連結成分の代表元のうちどちらかになります。
コンストラクタ
WeightedUnionFind<T> wuf(int n)
- 型
T
型T
の重みをもった頂点を作ります。 - $n$ 頂点 $0$ 辺の無向グラフを作ります。
計算量
- $O(n)$
merge
int wuf.merge(int a, int b, T w)
頂点$a$と頂点$b$の重みの差を$w$に設定します。 $a, b$ が連結だった場合はその代表元、非連結だった場合は新たな代表元を返します。
制約
- $0 \leq a < n$
- $0 \leq b < n$
- $ a $と$ b $が連結の場合、$weight(a) + w = weight(b)$
計算量
- ならし $O(\alpha(n))$
same
bool wuf.same(int a, int b)
頂点 $a, b$ が連結かどうかを返します。
制約
- $0 \leq a < n$
- $0 \leq b < n$
計算量
- ならし $O(\alpha(n))$
valid
bool wuf.valid(int a, int b, T w)
頂点$a$と頂点$b$の重みの差を$w$に設定した場合に、ここまでにmergeされた情報に矛盾が生じないかを返します。
制約
- $0 \leq a < n$
- $0 \leq b < n$
計算量
- ならし $O(\alpha(n))$
leader
int wuf.leader(int a)
頂点 $a$ の属する連結成分の代表元を返します。
制約
- $0 \leq a < n$
計算量
- ならし $O(\alpha(n))$
size
int wuf.size(int a)
頂点 $a$ の属する連結成分のサイズを返します。
制約
- $0 \leq a < n$
計算量
- ならし $O(\alpha(n))$
weight
T wuf.weight(int a)
頂点 $a$ とその代表元との重みの差を返します。
制約
- $0 \leq a < n$
計算量
- ならし $O(\alpha(n))$
weight
T wuf.diff(int a, int b)
頂点$a$と頂点$b$の重みの差を返します。
制約
- $0 \leq a < n$
- $0 \leq b < n$
- $a$ と $b$は連結
計算量
- ならし $O(\alpha(n))$
groups
vector<vector<int>> wuf.groups()
グラフを連結成分に分け、その情報を返します。
返り値は「「一つの連結成分の頂点番号のリスト」のリスト」です。 (内側外側限らず)vector内でどの順番で頂点が格納されているかは未定義です。
計算量
- $O(n)$
Verified with
test/DataStructure/weighted_union_find/aoj-DSL_1_B.cpp
test/DataStructure/weighted_union_find/yosupo-unionfind_with_potential.cpp
Code
#include <bits/stdc++.h>
using namespace std;
template <class T = long long>
struct WeightedUnionFind
{
public:
WeightedUnionFind() : _n(0) {}
explicit WeightedUnionFind(int n, T e = 0) : _n(n), parent_or_size(n, -1), diff_weight(n, e) {}
int merge(int a, int b, T w)
{
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
assert(valid(a, b, w));
int x = leader(a), y = leader(b);
if (same(a, b))
{
return x;
}
w += weight(a), w -= weight(b);
if (x == y)
{
return x;
}
if (-parent_or_size[x] < -parent_or_size[y])
{
swap(x, y), w *= -1;
}
parent_or_size[x] += parent_or_size[y];
parent_or_size[y] = x;
diff_weight[y] = w;
return x;
}
bool same(int a, int b)
{
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
return leader(a) == leader(b);
}
bool valid(int a, int b, T w)
{
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
return !same(a, b) || diff(a, b) == w;
}
int leader(int a)
{
assert(0 <= a && a < _n);
if (parent_or_size[a] < 0)
{
return a;
}
int r = leader(parent_or_size[a]);
diff_weight[a] += diff_weight[parent_or_size[a]];
return parent_or_size[a] = r;
}
int size(int a)
{
assert(0 <= a && a < _n);
return -parent_or_size[leader(a)];
}
T weight(int a)
{
assert(0 <= a && a < _n);
leader(a);
return diff_weight[a];
}
T diff(int a, int b)
{
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
assert(same(a, b));
return weight(b) - weight(a);
}
vector<vector<int>> groups()
{
vector<int> leader_buf(_n), group_size(_n);
for (int i = 0; i < _n; i++)
{
leader_buf[i] = leader(i);
group_size[leader_buf[i]]++;
}
vector<vector<int>> result(_n);
for (int i = 0; i < _n; i++)
{
result[i].reserve(group_size[i]);
}
for (int i = 0; i < _n; i++)
{
result[leader_buf[i]].push_back(i);
}
result.erase(
remove_if(result.begin(), result.end(),
[&](const vector<int> &v)
{ return v.empty(); }),
result.end());
return result;
}
private:
int _n;
vector<int> parent_or_size;
vector<T> diff_weight;
};
#line 1 "DataStructure/weighted_union_find.hpp"
#include <bits/stdc++.h>
using namespace std;
template <class T = long long>
struct WeightedUnionFind
{
public:
WeightedUnionFind() : _n(0) {}
explicit WeightedUnionFind(int n, T e = 0) : _n(n), parent_or_size(n, -1), diff_weight(n, e) {}
int merge(int a, int b, T w)
{
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
assert(valid(a, b, w));
int x = leader(a), y = leader(b);
if (same(a, b))
{
return x;
}
w += weight(a), w -= weight(b);
if (x == y)
{
return x;
}
if (-parent_or_size[x] < -parent_or_size[y])
{
swap(x, y), w *= -1;
}
parent_or_size[x] += parent_or_size[y];
parent_or_size[y] = x;
diff_weight[y] = w;
return x;
}
bool same(int a, int b)
{
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
return leader(a) == leader(b);
}
bool valid(int a, int b, T w)
{
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
return !same(a, b) || diff(a, b) == w;
}
int leader(int a)
{
assert(0 <= a && a < _n);
if (parent_or_size[a] < 0)
{
return a;
}
int r = leader(parent_or_size[a]);
diff_weight[a] += diff_weight[parent_or_size[a]];
return parent_or_size[a] = r;
}
int size(int a)
{
assert(0 <= a && a < _n);
return -parent_or_size[leader(a)];
}
T weight(int a)
{
assert(0 <= a && a < _n);
leader(a);
return diff_weight[a];
}
T diff(int a, int b)
{
assert(0 <= a && a < _n);
assert(0 <= b && b < _n);
assert(same(a, b));
return weight(b) - weight(a);
}
vector<vector<int>> groups()
{
vector<int> leader_buf(_n), group_size(_n);
for (int i = 0; i < _n; i++)
{
leader_buf[i] = leader(i);
group_size[leader_buf[i]]++;
}
vector<vector<int>> result(_n);
for (int i = 0; i < _n; i++)
{
result[i].reserve(group_size[i]);
}
for (int i = 0; i < _n; i++)
{
result[leader_buf[i]].push_back(i);
}
result.erase(
remove_if(result.begin(), result.end(),
[&](const vector<int> &v)
{ return v.empty(); }),
result.end());
return result;
}
private:
int _n;
vector<int> parent_or_size;
vector<T> diff_weight;
};