Tree MST 题解

发布时间 2023-11-15 20:45:32作者: Pengzt

洛谷 AT

完全图的最小生成树是不好求的,但是发现 \(\mathcal{O}(n^2)\) 级别的边中显然有很多都是没有用的,这种时候可以考虑分治。

显然如果对 \(E'(E'\in E)\) 求 MST,没有选择的边一定也不在最后的 MST 的边集中。于是就让选出的边集的并等于原图,然后再求一遍 MST 即可,考虑点分治。

点分治时,记当前的重心为 \(r'\),当前的分治树为 \(T'\)。若此时各个儿子的子树已经处理完,记为 \(T_1,T_2,\dots,T_i\),再在这个基础上添加 \(i - 1\) 条边就行了。考虑把这条 \(i\to j\) 的路径拆成 \(i\to p\)\(p\to j\) 这两段考虑,其中 \(p\)\(i, j\) 在点分树上的 LCA。令 \(val_i = d_i+ w_i\)\(d_i\) 表示点 \(i\)\(r'\) 的距离,按照 \(val_i\) 排序即可。正确性是显然的,因为点分树上两点的 LCA 一定在原树中这两个点的路径上。

此时的边数是 \(\mathcal{O}(n\log n)\) 的,故时间复杂度为 \(\mathcal{O}(n\log^2n)\),瓶颈源于排序。

代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector < int >
#define eb emplace_back
#define pii pair < ll, int >
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
int Mbe;
mt19937_64 rng(35);
constexpr int N = 2e5 + 10;
int n, rt, cnt;
ll ans;
int a[N], mx[N], sz[N], vis[N], fa[N];
int head[N], cnt_e;
struct edge {
	int to, w, nxt;
} e[N << 1];
struct graph {
	int u, v;
	ll w;
} g[N << 5];
void adde(int u, int v, int w) {
	++cnt_e, e[cnt_e].to = v, e[cnt_e].w = w, e[cnt_e].nxt = head[u], head[u] = cnt_e;
}
int find(int x) {
	if(fa[x] == x) return x;
	return fa[x] = find(fa[x]);
}
void findroot(int u, int fath, int num) {
	sz[u] = 1, mx[u] = 0;
	for(int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(v == fath || vis[v]) continue;
		findroot(v, u, num);
		sz[u] += sz[v];
		mx[u] = max(mx[u], sz[v]);
	}
	mx[u] = max(mx[u], num - sz[u]);
	if(mx[u] < mx[rt]) rt = u;
}
int tp;
pii stk[N];
void getdep(int u, int fath, ll val) {
	stk[++tp] = pii(a[u] + val, u);
	for(int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(v == fath || vis[v]) continue;
		getdep(v, u, val + e[i].w);
	}
}
void divide(int u) {
	vis[u] = 1;
	stk[++tp] = pii(a[u], u);
	for(int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(vis[v]) continue;
		getdep(v, u, e[i].w);
	}
	sort(stk + 1, stk + tp + 1);
	for(int i = 2; i <= tp; ++i)
		++cnt, g[cnt].u = stk[1].se, g[cnt].v = stk[i].se, g[cnt].w = stk[1].fi + stk[i].fi;
	tp = 0;
	for(int i = head[u]; i; i = e[i].nxt) {
		int v = e[i].to;
		if(vis[v]) continue;
		rt = 0;
		findroot(v, u, sz[v]);
		divide(rt);
	}
}
int Med;
int main() {
	fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
	ios :: sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	cin >> n;
	for(int i = 1; i <= n; ++i) cin >> a[i];
	for(int i = 1; i < n; ++i) {
		int u, v, w;
		cin >> u >> v >> w;
		adde(u, v, w);
		adde(v, u, w);
	}
	mx[0] = N;
	findroot(1, 0, n);
	divide(rt);
	for(int i = 1; i <= n; ++i) fa[i] = i;
	sort(g + 1, g + cnt + 1, [](graph a, graph b){
		return a.w < b.w;
	});
	for(int i = 1; i <= cnt; ++i) {
		int u = find(g[i].u), v = find(g[i].v);
		if(u != v) fa[u] = v, ans += g[i].w;
	}
	cout << ans << "\n";
	cerr << TIME << "ms\n";
	return 0;
}