:heavy_check_mark: test/Tree/lowest_common_ancestor/yosupo-lca.cpp

Depends on

- Tree/lowest_common_ancestor.hpp
- Tree/doubling.hpp
- Other/binary_search.hpp

Code

// competitive-verifier: PROBLEM https://judge.yosupo.jp/problem/lca
// competitive-verifier: TLE 5

#include "Tree/lowest_common_ancestor.hpp"
#include <bits/stdc++.h>

using namespace std;

int main()
{
    // Many queries are answered one line each: untie cin from cout, drop
    // stdio synchronization, and write '\n' instead of endl so the output is
    // not flushed after every line.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, Q;
    cin >> N >> Q;
    LCA lca(N);

    // Vertex `child` (1..N-1) is given its parent p; the tree is rooted at 0.
    for (int child = 1; child < N; child++)
    {
        int p;
        cin >> p;
        lca.add_edge(p, child);
    }
    for (int i = 0; i < Q; i++)
    {
        int u, v;
        cin >> u >> v;
        cout << lca.find(u, v) << '\n';
    }

    return 0;
}
#line 1 "test/Tree/lowest_common_ancestor/yosupo-lca.cpp"
// competitive-verifier: PROBLEM https://judge.yosupo.jp/problem/lca
// competitive-verifier: TLE 5

#line 1 "Tree/lowest_common_ancestor.hpp"
#include <bits/stdc++.h>
using namespace std;

#line 2 "Tree/doubling.hpp"
using namespace std;

template <typename T = long long>
class Doubling
{
    // Binary-lifting (doubling) table over a graph in which every node has at
    // most one outgoing edge. step_forward jumps k edges in O(log k) while
    // summing the edge costs. A successor of -1 means "no outgoing edge"; a
    // walk that reaches such a node stops there and reports node -1.

    int N = 0; // number of nodes (0 .. N-1)
    int L = 0; // number of table levels; jumps of up to 2^L - 1 steps are supported

    vector<pair<int, T>> nexts;           // nexts[v] = (successor, edge cost), successor -1 if none
    vector<vector<pair<int, T>>> parents; // parents[v][i] = (node after 2^i steps, summed cost)
    bool done_initialized = false;

    // Number of bits needed to represent x (0 for x == 0), i.e. std::bit_width.
    static int bit_length(unsigned long long x)
    {
        int width = 0;
        for (; x != 0; x >>= 1)
        {
            ++width;
        }
        return width;
    }

public:
    Doubling() {}

    // N: number of nodes. max_k: the largest step count that will be queried.
    Doubling(int N, unsigned long long max_k) : N(N)
    {
        L = max(bit_length(max_k), 1);
        nexts.assign(N, pair<int, T>(-1, 0));
        parents.assign(N, vector<pair<int, T>>(L, pair<int, T>(-1, 0)));
    }

    // Registers the (single) outgoing edge from -> to with the given cost.
    // Adding an edge after the table was built marks it for a rebuild.
    void add_edge(int from, int to, T cost)
    {
        assert(0 <= from && from < N);
        assert(nexts[from].first == -1 && "each node may have only one outgoing edge");

        nexts[from] = pair<int, T>(to, cost);
        done_initialized = false; // the table must be rebuilt to include this edge
    }

    // Builds the table from the registered edges. Called lazily by the step functions.
    void init()
    {
        for (int v = 0; v < N; v++)
        {
            parents[v][0] = nexts[v];
        }

        for (int i = 0; i + 1 < L; i++)
        {
            for (int v = 0; v < N; v++)
            {
                const auto [mid, first_cost] = parents[v][i];
                if (mid == -1)
                {
                    // The walk ended within 2^i steps, so a 2^(i+1) jump ends there too.
                    parents[v][i + 1] = pair<int, T>(-1, first_cost);
                    continue;
                }
                const auto [dst, second_cost] = parents[mid][i];
                parents[v][i + 1] = pair<int, T>(dst, first_cost + second_cost);
            }
        }

        done_initialized = true;
    }

    // Returns (node reached after k steps from v, total cost of those steps).
    // If the walk hits a node without a successor, returns (-1, cost so far).
    // k must fit in L bits (guaranteed for any k <= max_k given to the constructor).
    pair<int, T> step_forward(int v, unsigned long long k)
    {
        if (!done_initialized)
        {
            init();
        }

        T sum_cost = 0;
        for (int bit = 0; k != 0 && v != -1; bit++, k >>= 1)
        {
            if ((k & 1) == 0)
            {
                continue;
            }
            assert(bit < L);
            sum_cost += parents[v][bit].second;
            v = parents[v][bit].first;
        }

        return pair<int, T>(v, sum_cost);
    }

    // step_forward applied to every start node in `indeces`; result[i] corresponds to indeces[i].
    vector<pair<int, T>> step_forwards(const vector<int> &indeces, unsigned long long k)
    {
        if (!done_initialized)
        {
            init();
        }

        vector<pair<int, T>> result;
        result.reserve(indeces.size());
        for (int start : indeces)
        {
            result.emplace_back(start, 0);
        }

        for (int bit = 0; k != 0; bit++, k >>= 1)
        {
            if ((k & 1) == 0)
            {
                continue;
            }
            assert(bit < L);
            for (auto &[v, cost] : result)
            {
                if (v == -1)
                {
                    continue; // this walk already ended
                }
                cost += parents[v][bit].second;
                v = parents[v][bit].first;
            }
        }

        return result;
    }

    // step_forward applied to every node 0 .. N-1.
    vector<pair<int, T>> step_forwards(unsigned long long k)
    {
        vector<int> indeces(N);
        iota(indeces.begin(), indeces.end(), 0);
        return step_forwards(indeces, k);
    }
};
#line 2 "Other/binary_search.hpp"
using namespace std;

template <typename T, typename F>
enable_if_t<is_integral_v<T>, T>
bin_search(T ok, T ng, F check)
{
    // Binary search for the boundary of a monotone predicate over integers.
    // ok is assumed to satisfy `check` and ng not to (neither endpoint is
    // evaluated); ok may lie on either side of ng. Returns the last value on
    // the ok side, which ends up at distance <= 1 from the final ng.
    // `check` is any callable taking T whose result converts to bool; taking it
    // as a template parameter (instead of std::function<bool(T)>) lets lambdas
    // be passed without spelling out T, while std::function still works.

    // Do the arithmetic in the unsigned image of T so that the gap |ok - ng|
    // cannot overflow, even for ranges like [INT_MIN, INT_MAX].
    using U = make_unsigned_t<T>;
    while (true)
    {
        const bool ascending = ok < ng;
        const U gap = ascending ? U(U(ng) - U(ok)) : U(U(ok) - U(ng));
        if (gap <= 1)
        {
            break;
        }
        // Move half the gap from ok towards ng: the same value, rounded towards
        // ok, that std::midpoint(ok, ng) produces. The result lies between ok
        // and ng, so converting back to T is value-preserving (modular).
        const T mid = ascending ? T(U(U(ok) + gap / 2)) : T(U(U(ok) - gap / 2));
        (check(mid) ? ok : ng) = mid;
    }
    return ok;
}

template <typename T, typename F>
enable_if_t<is_floating_point_v<T>, T>
bin_search(T ok, T ng, F check)
{
    // Binary search for the boundary of a monotone predicate over float/double.
    // ok is assumed to satisfy `check` and ng not to (neither endpoint is
    // evaluated); ok may lie on either side of ng. The search bisects the
    // ordered bit patterns, so it finishes with ok and ng at adjacent
    // representable values and returns ok.
    // `check` is any callable taking T whose result converts to bool; taking it
    // as a template parameter (instead of std::function<bool(T)>) lets lambdas
    // be passed without spelling out T, while std::function still works.

    // Maps a floating-point value to an unsigned key whose unsigned order
    // matches the numeric order of the values, so bisecting keys bisects reals.
    class OrderedBitcastFloat
    {
    public:
        using UInt = conditional_t<sizeof(T) == 4, uint32_t, conditional_t<sizeof(T) == 8, uint64_t, void>>;
        static_assert(!is_same_v<UInt, void>, "T must be float(4), double(8)");

        T real;       // the floating-point value
        UInt int_key; // its order-preserving key

        OrderedBitcastFloat(T x) : real(x), int_key(real_to_int_key(x)) {}
        OrderedBitcastFloat(UInt u) : real(int_key_to_real(u)), int_key(u) {}

    private:
        static constexpr UInt msb()
        {
            return UInt(1) << (8 * sizeof(UInt) - 1);
        }

        // value -> key
        static UInt real_to_int_key(T x)
        {
            UInt bits;
            memcpy(&bits, &x, sizeof(bits)); // bit-level copy, equivalent to std::bit_cast
            if (bits & msb())
            {
                // Sign bit set (negative region): invert all bits so that larger
                // magnitudes get smaller keys.
                return ~bits;
            }
            // Sign bit clear (non-negative region): set the MSB so these keys sort
            // above every negative key.
            return bits | msb();
        }

        // key -> value
        static T int_key_to_real(UInt u)
        {
            // MSB set: non-negative region, drop the MSB.
            // MSB clear: negative region, undo the inversion.
            const UInt bits = (u & msb()) ? UInt(u & ~msb()) : UInt(~u);
            T x;
            memcpy(&x, &bits, sizeof(x));
            return x;
        }
    };
    using UInt = typename OrderedBitcastFloat::UInt;

    OrderedBitcastFloat cur_ok(ok), cur_ng(ng);

    while (true)
    {
        const UInt a = cur_ok.int_key;
        const UInt b = cur_ng.int_key;
        const UInt gap = a < b ? UInt(b - a) : UInt(a - b);
        if (gap <= 1)
        {
            break;
        }
        // Key halfway from ok towards ng, rounded towards ok (as std::midpoint(a, b)).
        OrderedBitcastFloat mid(a < b ? UInt(a + gap / 2) : UInt(a - gap / 2));
        (check(mid.real) ? cur_ok : cur_ng) = mid;
    }

    return cur_ok.real;
}
#line 6 "Tree/lowest_common_ancestor.hpp"

class LCA
{
    int N;
    int bit_width;
    vector<vector<int>> edges;
    vector<int> depth;
    Doubling<int_fast8_t> doubling;
    bool done_initialized = false;

    void dfs(int now, int parent)
    {
        if (parent != -1)
        {
            doubling.add_edge(now, parent, 0);
            depth[now] = depth[parent] + 1;
        }

        for (int next : edges[now])
        {
            if (next == parent)
            {
                continue;
            }

            dfs(next, now);
        }
    }

public:
    LCA(int N) : N(N)
    {
        bit_width = std::bit_width((unsigned int)N);
        doubling = Doubling<int_fast8_t>(N, N);
        edges.resize(N);
        depth = vector<int>(N, -1);
    };

    void add_edge(int u, int v)
    {
        edges[u].push_back(v);
    }

    void init(int root = 0)
    {
        depth[root] = 0;
        doubling.add_edge(root, root, 0);
        dfs(root, -1);
        done_initialized = true;
    }

    int find(int u, int v)
    {
        if (!done_initialized)
        {
            init();
        }

        if (depth[u] < depth[v])
        {
            swap(u, v);
        }

        // 深さを同じにする
        u = doubling.step_forward(u, depth[u] - depth[v]).first;
        if (u == v)
        {
            return u;
        }

        for (int bit = bit_width - 1; bit >= 0; bit--)
        {
            int next_u = doubling.step_forward(u, 1 << bit).first, next_v = doubling.step_forward(v, 1 << bit).first;
            if (next_u != next_v)
            {
                u = next_u, v = next_v;
            }
        }

        return doubling.step_forward(u, 1).first;
    }

    int dist(int u, int v) { return depth[u] + depth[v] - 2 * depth[find(u, v)]; }
};
#line 6 "test/Tree/lowest_common_ancestor/yosupo-lca.cpp"

using namespace std;

int main()
{
    // Many queries are answered one line each: untie cin from cout, drop
    // stdio synchronization, and write '\n' instead of endl so the output is
    // not flushed after every line.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, Q;
    cin >> N >> Q;
    LCA lca(N);

    // Vertex `child` (1..N-1) is given its parent p; the tree is rooted at 0.
    for (int child = 1; child < N; child++)
    {
        int p;
        cin >> p;
        lca.add_edge(p, child);
    }
    for (int i = 0; i < Q; i++)
    {
        int u, v;
        cin >> u >> v;
        cout << lca.find(u, v) << '\n';
    }

    return 0;
}

Test cases

| Env | Name | Status | Elapsed | Memory |
|---|---|---|---|---|
| g++ | almost_line_00 | :heavy_check_mark: AC | 1547 ms | 128 MB |
| g++ | almost_line_01 | :heavy_check_mark: AC | 1504 ms | 128 MB |
| g++ | binary_00 | :heavy_check_mark: AC | 1156 ms | 119 MB |
| g++ | binary_01 | :heavy_check_mark: AC | 1129 ms | 118 MB |
| g++ | binary_02 | :heavy_check_mark: AC | 1106 ms | 119 MB |
| g++ | example_00 | :heavy_check_mark: AC | 5 ms | 3 MB |
| g++ | line_00 | :heavy_check_mark: AC | 1263 ms | 111 MB |
| g++ | line_01 | :heavy_check_mark: AC | 1360 ms | 132 MB |
| g++ | line_02 | :heavy_check_mark: AC | 516 ms | 17 MB |
| g++ | line_03 | :heavy_check_mark: AC | 297 ms | 122 MB |
| g++ | line_04 | :heavy_check_mark: AC | 378 ms | 80 MB |
| g++ | max_line_00 | :heavy_check_mark: AC | 1627 ms | 142 MB |
| g++ | max_line_01 | :heavy_check_mark: AC | 1560 ms | 142 MB |
| g++ | max_line_02 | :heavy_check_mark: AC | 1490 ms | 142 MB |
| g++ | max_random_00 | :heavy_check_mark: AC | 1074 ms | 119 MB |
| g++ | max_random_01 | :heavy_check_mark: AC | 1059 ms | 119 MB |
| g++ | max_random_02 | :heavy_check_mark: AC | 1063 ms | 119 MB |
| g++ | path_graph_root_centroid_00 | :heavy_check_mark: AC | 1924 ms | 134 MB |
| g++ | path_graph_root_centroid_01 | :heavy_check_mark: AC | 1904 ms | 134 MB |
| g++ | path_graph_root_centroid_02 | :heavy_check_mark: AC | 1909 ms | 134 MB |
| g++ | random_00 | :heavy_check_mark: AC | 871 ms | 93 MB |
| g++ | random_01 | :heavy_check_mark: AC | 897 ms | 110 MB |
| g++ | random_02 | :heavy_check_mark: AC | 516 ms | 15 MB |
| g++ | random_03 | :heavy_check_mark: AC | 289 ms | 102 MB |
| g++ | random_04 | :heavy_check_mark: AC | 303 ms | 67 MB |
Back to top page