Description

题目链接

题目大意:求图的一棵最小生成树,且满足编号为 $s$ 的点恰好连了 $k$ 条边。$1 \le n \le 5 \times 10^4, 1 \le m \le 5 \times 10^5$。

Solution

一般这种求恰好 $k$ 个问题都是 wqs 二分的经典模型。我们考虑将 $s$ 的所有边同时加上 $\Delta$,然后求 MST 判断是否满足条件。

直接做是 $O(m (\log m + \log n) \log w)$ 的,过不了。

可以用归并排序 + 按秩合并优化,这样可以少几个 $\log$,当然我没有写按秩合并。

最后的时间复杂度是 $O(m \log m + m \log n \log w)$,其中 $w$ 是值域。这个看起来过不了,实际上最慢的点 $819$ ms,开 O2 能到 $400$ ms。

另外这道题要判断无解,具体有:

  • 图不连通
  • 跟 $s$ 相连的边不足 $k$ 条
  • 选了 $k$ 条跟 $s$ 相连的边后剩下的边不能让图连通
  • ......

Code

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

template <class T>
inline void read(T &x) {
    x = 0;
    int f = 0;
    char ch = getchar();
    while (!isdigit(ch))    { f |= ch == '-'; ch = getchar(); }
    while (isdigit(ch))     { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
    x = f ? -x : x;
    return ;
}

typedef unsigned long long uLL;
typedef long long LL;

struct Edge {
    int u, v, w;
    friend bool operator < (const Edge &a, const Edge &b) {
        return a.w < b.w;
    }
} e1[500010], e2[500010], e[500010];

int f[50010];
int n, m, s, k, ans, sum, cnt, m1, m2, m3, ntmp;

int find(int x) { return f[x] == x ? x : f[x] = find(f[x]); }

bool check(int mid) {
    cnt = sum = m3 = 0;
    for (int i = 1; i <= m1; ++i)    e1[i].w += mid;
    int p = 1, q = 1;
    while (p <= m1 && q <= m2) {
        if (e1[p].w <= e2[q].w) {
            e[++m3] = e1[p++];
        } else {
            e[++m3] = e2[q++];
        }
    }
    while (p <= m1)    e[++m3] = e1[p++];
    while (q <= m2)    e[++m3] = e2[q++];
    for (int i = 1; i <= n; ++i)    f[i] = i;
    for (int i = 1; i <= m3; ++i) {
        int fu = find(e[i].u), fv = find(e[i].v);
        if (fu != fv) {
            f[fu] = fv;
            if (e[i].u == s || e[i].v == s)    ++cnt;
            sum += e[i].w;
        }
    }
    for (int i = 1; i <= m1; ++i)    e1[i].w -= mid;
    return cnt >= k;
}

void validate(int mid) {
    cnt = sum = m3 = 0;
    for (int i = 1; i <= m1; ++i)    e1[i].w += mid;
    int p = 1, q = 1;
    while (p <= m1 && q <= m2) {
        if (e1[p].w <= e2[q].w) {
            e[++m3] = e1[p++];
        } else {
            e[++m3] = e2[q++];
        }
    }
    while (p <= m1)    e[++m3] = e1[p++];
    while (q <= m2)    e[++m3] = e2[q++];
    for (int i = 1; i <= n; ++i)    f[i] = i;
    for (int i = 1; i <= m3; ++i) {
        int fu = find(e[i].u), fv = find(e[i].v);
        if (fu != fv) {
            f[fu] = fv;
            if (e[i].u == s || e[i].v == s)    ++cnt;
            sum += e[i].w;
        }
    }
    if (cnt != k) {
        puts("Impossible");
        exit(0);
    }
    return ;
}

int main() {
    read(n), read(m), read(s), read(k);
    ntmp = n;
    for (int i = 1; i <= n; ++i)    f[i] = i;
    for (int i = 1, u, v, w; i <= m; ++i) {
        read(u), read(v), read(w);
        int fu = find(u), fv = find(v);
        if (fu != fv) {
            --ntmp;
            f[fu] = fv;
        }
        if (u == s || v == s)    e1[++m1].u = u, e1[m1].v = v, e1[m1].w = w;
        else    e2[++m2].u = u, e2[m2].v = v, e2[m2].w = w;    
    }
    if (m1 < k || ntmp != 1) {
        puts("Impossible");
        return 0;
    }
    std::sort(e1 + 1, e1 + m1 + 1), std::sort(e2 + 1, e2 + m2 + 1);
    int l = -1e9, r = 1e9;
    if (!check(l)) {
        puts("Impossible");
        return 0;
    }
    while (l <= r) {
        int mid = (l + r) >> 1;
        if (check(mid))    ans = mid, l = mid + 1;
        else    r = mid - 1;
    }
    validate(ans);
    printf("%d\n", sum - ans * k);
    return 0;
}
Last modification:September 2nd, 2021 at 02:14 am