✔ test/Graph/minimum_steiner_tree/yosupo-minimum_steiner_tree.cpp

Depends on: Graph/minimum_steiner_tree.hpp

Code

// competitive-verifier: PROBLEM https://judge.yosupo.jp/problem/minimum_steiner_tree

#include "Graph/minimum_steiner_tree.hpp"
#include <bits/stdc++.h>

using namespace std;

int main()
{
    // Unsync/untie the C++ streams for faster bulk I/O.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N, M;
    cin >> N >> M;

    MinimusSteinerTree<long long> graph(N);

    // For every vertex pair, remember {weight, input index} of the cheapest edge
    // between them (the earliest index wins ties). Both orientations are stored so
    // the edges returned by solve() can be looked up directly.
    map<array<int, 2>, array<int, 2>> mp;
    for (int i = 0; i < M; i++)
    {
        int U, V, W;
        cin >> U >> V >> W;
        graph.addEdge(U, V, W), graph.addEdge(V, U, W);
        auto it = mp.find({U, V});
        if (it == mp.end() || W < it->second[0])
        {
            mp[{U, V}] = {W, i};
            mp[{V, U}] = {W, i};
        }
    }

    int K;
    cin >> K;
    vector<int> X(K);
    for (int i = 0; i < K; i++)
    {
        cin >> X[i];
    }

    auto [Y, edges] = graph.solve(X);
    cout << Y << " " << edges.size() << '\n';
    for (auto [U, V] : edges)
    {
        // Every returned pair is an edge that was added above, so at() cannot miss.
        cout << mp.at({U, V})[1] << " ";
    }
    cout << '\n';

    return 0;
}
#line 1 "test/Graph/minimum_steiner_tree/yosupo-minimum_steiner_tree.cpp"
// competitive-verifier: PROBLEM https://judge.yosupo.jp/problem/minimum_steiner_tree

#line 1 "Graph/minimum_steiner_tree.hpp"
#include <bits/stdc++.h>
using namespace std;

/**
 * Minimum Steiner tree via the Dreyfus–Wagner subset DP.
 *
 * Edges are directed as added; for an undirected edge call addEdge(a, b, c) and
 * addEdge(b, a, c). Edge costs are assumed non-negative, because the growth step
 * is Dijkstra's algorithm.
 *
 * Complexity of solve() with K terminals: O(3^K * n) for the subset merges plus
 * O(2^K * (n + m) log(n + m)) for the Dijkstra passes; memory O(2^K * n).
 */
template <typename Cost>
struct MinimusSteinerTree
{
    /// Sentinel for an unreachable state. It is also the cost solve() returns when
    /// the terminals cannot all be connected.
    static constexpr Cost INF = std::numeric_limits<Cost>::max();
    using P = std::pair<Cost, int>;
    struct Edge
    {
        int to = 0;
        Cost cost = Cost();
        Edge() = default;
        Edge(int to, Cost cost) : to(to), cost(cost) {}
    };
    int n = 0;                        // number of vertices
    std::vector<std::vector<Edge>> g; // g[a] holds the edges leaving vertex a
    MinimusSteinerTree() {}
    MinimusSteinerTree(int n) : n(n), g(n) {}

    /// Adds a directed edge a -> b with the given cost.
    void addEdge(int a, int b, Cost cost)
    {
        g[a].emplace_back(b, cost);
    }

    /**
     * Computes a minimum-cost tree connecting all `terminals` (vertex ids in [0, n)).
     *
     * @return {cost, edges}. Each pair {a, b} in `edges` is a tree edge that was
     *         added as addEdge(b, a, ...). The tree is rooted at terminals[0].
     *         Returns {0, {}} when `terminals` is empty. Returns {INF, {}} when the
     *         terminals are not all connected.
     */
    std::pair<Cost, std::vector<std::array<int, 2>>> solve(const std::vector<int> &terminals)
    {
        const int K = static_cast<int>(terminals.size());
        if (K == 0)
        {
            // No terminals: the empty tree is optimal.
            return {Cost(0), {}};
        }
        const int full = (1 << K) - 1;

        // dist[S][v]: minimum cost of a tree containing terminal subset S and vertex v.
        std::vector<std::vector<Cost>> dist(1 << K, std::vector<Cost>(n, INF));
        // Back-pointers for reconstruction; at most one is set for each state:
        //   split[S][v] = T (> 0): union of states (T, v) and (S ^ T, v);
        //   from[S][v]  = u (>= 0): state (S, u) extended by an edge u -> v.
        std::vector<std::vector<int>> split(1 << K, std::vector<int>(n, 0));
        std::vector<std::vector<int>> from(1 << K, std::vector<int>(n, -1));

        for (int t = 0; t < K; t++)
        {
            dist[1 << t][terminals[t]] = 0;
        }

        for (int S = 1; S <= full; S++)
        {
            // Join two disjoint nonempty parts T and S ^ T at a common vertex v.
            // Submasks are visited in decreasing order and S ^ T == S - T, so
            // 2T > S holds exactly while T is the larger half of its pair. Stopping
            // there skips the mirrored half without changing which split wins ties.
            for (int T = (S - 1) & S; 2 * T > S; T = (T - 1) & S)
            {
                const std::vector<Cost> &lhs = dist[T];
                const std::vector<Cost> &rhs = dist[S ^ T];
                for (int v = 0; v < n; v++)
                {
                    if (lhs[v] == INF || rhs[v] == INF)
                    {
                        continue;
                    }
                    if (lhs[v] + rhs[v] < dist[S][v])
                    {
                        dist[S][v] = lhs[v] + rhs[v];
                        split[S][v] = T;
                        from[S][v] = -1;
                    }
                }
            }

            // Grow the trees for S along edges: Dijkstra seeded with every finite state.
            std::vector<Cost> &d = dist[S];
            std::priority_queue<P, std::vector<P>, std::greater<P>> pq;
            for (int u = 0; u < n; u++)
            {
                if (d[u] != INF)
                {
                    pq.push({d[u], u});
                }
            }

            while (!pq.empty())
            {
                auto [du, u] = pq.top();
                pq.pop();
                if (d[u] < du)
                {
                    continue; // stale queue entry
                }

                for (const Edge &e : g[u])
                {
                    if (du + e.cost < d[e.to])
                    {
                        d[e.to] = du + e.cost;
                        from[S][e.to] = u;
                        split[S][e.to] = 0;
                        pq.push({d[e.to], e.to});
                    }
                }
            }
        }

        const int root = terminals[0];
        if (dist[full][root] == INF)
        {
            return {INF, {}};
        }

        // Reconstruct the tree by following back-pointers breadth-first from the root.
        std::vector<std::array<int, 2>> edges;
        std::queue<std::pair<int, int>> qu;
        qu.push({full, root});
        while (!qu.empty())
        {
            auto [S, u] = qu.front();
            qu.pop();

            if (split[S][u] != 0)
            {
                qu.push({split[S][u], u});
                qu.push({S ^ split[S][u], u});
            }
            else if (from[S][u] != -1)
            {
                qu.push({S, from[S][u]});
                edges.push_back({u, from[S][u]});
            }
        }

        return {dist[full][root], std::move(edges)};
    }
};
#line 5 "test/Graph/minimum_steiner_tree/yosupo-minimum_steiner_tree.cpp"

using namespace std;

int main()
{
    // Unsync/untie the C++ streams for faster bulk I/O.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N, M;
    cin >> N >> M;

    MinimusSteinerTree<long long> graph(N);

    // For every vertex pair, remember {weight, input index} of the cheapest edge
    // between them (the earliest index wins ties). Both orientations are stored so
    // the edges returned by solve() can be looked up directly.
    map<array<int, 2>, array<int, 2>> mp;
    for (int i = 0; i < M; i++)
    {
        int U, V, W;
        cin >> U >> V >> W;
        graph.addEdge(U, V, W), graph.addEdge(V, U, W);
        auto it = mp.find({U, V});
        if (it == mp.end() || W < it->second[0])
        {
            mp[{U, V}] = {W, i};
            mp[{V, U}] = {W, i};
        }
    }

    int K;
    cin >> K;
    vector<int> X(K);
    for (int i = 0; i < K; i++)
    {
        cin >> X[i];
    }

    auto [Y, edges] = graph.solve(X);
    cout << Y << " " << edges.size() << '\n';
    for (auto [U, V] : edges)
    {
        // Every returned pair is an edge that was added above, so at() cannot miss.
        cout << mp.at({U, V})[1] << " ";
    }
    cout << '\n';

    return 0;
}

Test cases

| Env | Name | Status | Elapsed | Memory |
|---|---|---|---|---|
| g++ | example_00 | ✔ AC | 5 ms | 3 MB |
| g++ | example_01 | ✔ AC | 4 ms | 3 MB |
| g++ | example_02 | ✔ AC | 4 ms | 3 MB |
| g++ | overflow_killer_00 | ✔ AC | 171 ms | 29 MB |
| g++ | random_00 | ✔ AC | 5 ms | 4 MB |
| g++ | random_01 | ✔ AC | 5 ms | 4 MB |
| g++ | random_02 | ✔ AC | 7 ms | 4 MB |
| g++ | random_03 | ✔ AC | 4 ms | 4 MB |
| g++ | random_04 | ✔ AC | 5 ms | 4 MB |
| g++ | random_complete_00 | ✔ AC | 80 ms | 15 MB |
| g++ | random_complete_01 | ✔ AC | 80 ms | 15 MB |
| g++ | random_complete_02 | ✔ AC | 78 ms | 15 MB |
| g++ | random_complete_03 | ✔ AC | 82 ms | 15 MB |
| g++ | random_complete_04 | ✔ AC | 79 ms | 15 MB |
| g++ | random_max_00 | ✔ AC | 174 ms | 29 MB |
| g++ | random_max_01 | ✔ AC | 174 ms | 29 MB |
| g++ | random_max_02 | ✔ AC | 176 ms | 29 MB |
| g++ | random_max_03 | ✔ AC | 178 ms | 29 MB |
| g++ | random_max_04 | ✔ AC | 180 ms | 29 MB |
| g++ | random_small_00 | ✔ AC | 5 ms | 4 MB |
| g++ | random_small_01 | ✔ AC | 4 ms | 4 MB |
| g++ | random_small_02 | ✔ AC | 11 ms | 5 MB |
| g++ | random_small_03 | ✔ AC | 4 ms | 3 MB |
| g++ | random_small_04 | ✔ AC | 5 ms | 4 MB |
Back to top page