完全图的最小生成树是不好求的,但是发现 \(\mathcal{O}(n^2)\) 级别的边中显然有很多都是没有用的,这种时候可以考虑分治。
显然如果对 \(E'(E'\in E)\) 求 MST,没有选择的边一定也不在最后的 MST 的边集中。于是就让选出的边集的并等于原图,然后再求一遍 MST 即可,考虑点分治。
点分治时,记当前的重心为 \(r'\),当前的分治树为 \(T'\)。若此时各个儿子的子树已经处理完,记为 \(T_1,T_2,\dots,T_i\),再在这个基础上添加 \(i - 1\) 条边就行了。考虑把这条 \(i\to j\) 的路径拆成 \(i\to p\) 和 \(p\to j\) 这两段考虑,其中 \(p\) 是 \(i, j\) 在点分树上的 LCA。令 \(val_i = d_i+ w_i\),\(d_i\) 表示点 \(i\) 到 \(r'\) 的距离,按照 \(val_i\) 排序即可。正确性是显然的,因为点分树上两点的 LCA 一定在原树中这两个点的路径上。
此时的边数是 \(\mathcal{O}(n\log n)\) 的,故时间复杂度为 \(\mathcal{O}(n\log^2n)\),瓶颈源于排序。
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define vi vector < int >
#define eb emplace_back
#define pii pair < ll, int >
#define fi first
#define se second
#define TIME 1e3 * clock() / CLOCKS_PER_SEC
int Mbe;
mt19937_64 rng(35);
constexpr int N = 2e5 + 10;
int n, rt, cnt;
ll ans;
int a[N], mx[N], sz[N], vis[N], fa[N];
int head[N], cnt_e;
struct edge {
int to, w, nxt;
} e[N << 1];
struct graph {
int u, v;
ll w;
} g[N << 5];
void adde(int u, int v, int w) {
++cnt_e, e[cnt_e].to = v, e[cnt_e].w = w, e[cnt_e].nxt = head[u], head[u] = cnt_e;
}
int find(int x) {
if(fa[x] == x) return x;
return fa[x] = find(fa[x]);
}
void findroot(int u, int fath, int num) {
sz[u] = 1, mx[u] = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fath || vis[v]) continue;
findroot(v, u, num);
sz[u] += sz[v];
mx[u] = max(mx[u], sz[v]);
}
mx[u] = max(mx[u], num - sz[u]);
if(mx[u] < mx[rt]) rt = u;
}
int tp;
pii stk[N];
void getdep(int u, int fath, ll val) {
stk[++tp] = pii(a[u] + val, u);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v == fath || vis[v]) continue;
getdep(v, u, val + e[i].w);
}
}
void divide(int u) {
vis[u] = 1;
stk[++tp] = pii(a[u], u);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
getdep(v, u, e[i].w);
}
sort(stk + 1, stk + tp + 1);
for(int i = 2; i <= tp; ++i)
++cnt, g[cnt].u = stk[1].se, g[cnt].v = stk[i].se, g[cnt].w = stk[1].fi + stk[i].fi;
tp = 0;
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(vis[v]) continue;
rt = 0;
findroot(v, u, sz[v]);
divide(rt);
}
}
int Med;
int main() {
fprintf(stderr, "%.3lf MB\n", (&Mbe - &Med) / 1048576.0);
ios :: sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n;
for(int i = 1; i <= n; ++i) cin >> a[i];
for(int i = 1; i < n; ++i) {
int u, v, w;
cin >> u >> v >> w;
adde(u, v, w);
adde(v, u, w);
}
mx[0] = N;
findroot(1, 0, n);
divide(rt);
for(int i = 1; i <= n; ++i) fa[i] = i;
sort(g + 1, g + cnt + 1, [](graph a, graph b){
return a.w < b.w;
});
for(int i = 1; i <= cnt; ++i) {
int u = find(g[i].u), v = find(g[i].v);
if(u != v) fa[u] = v, ans += g[i].w;
}
cout << ans << "\n";
cerr << TIME << "ms\n";
return 0;
}