GDCPC2023 L Classic Problem

发布时间 2023-10-13 17:15:57作者: zltzlt

洛谷传送门

CF 传送门

对于一个点 \(x\),若 \(\exists i, u_i = x \lor v_i = x\),则称 \(x\)特殊点,否则为一般点

首先发现,对于极长的一段 \([l, r]\) 满足 \(l \sim r\) 均为一般点,那么可以连边 \((l, l + 1), (l + 1, l + 2), \ldots, (r - 1, r)\),然后把 \([l, r]\) 缩成一个连续点。因为这些点通过别的点与外界连通显然不优。

对于一个特殊点 \(x\),我们把它变成区间为 \([x, x]\) 的连续点。然后把所有连续点按区间左端点排序后重编号。

然后现在相当于我们有至多 \(4m + 1\) 个点的完全图,有一些给定边,若对于之间不存在给定边的点对 \(u, v\ (u < v)\),它们之间的边权是 \(l_v - r_u\)。求这个完全图的最小生成树。

考虑 Boruvka 算法,其流程是每轮对每个连通块找到一条连向另一连通块的最短边,然后合并两端点。

考虑模拟流程。我们首先考虑给定边,然后考虑其他边。前者是容易的。至于后者,我们希望找到 \(u\) 左右侧最接近 \(u\) 且和 \(u\) 不在同一连通块且和 \(u\) 之间没有给定边的点 \(v\)。于是我们每次先处理出 \(pre_u\)\(nxt_u\) 表示 \(u\) 左侧(或右侧)最接近 \(u\) 且和 \(u\) 不在同一连通块的点,然后枚举 \(u\),暴力找 \(v\),以左侧为例,就是若 \(u, v\) 之间存在给定边就 \(v \gets v - 1\),否则 \(v \gets pre_v\)。因为给定边数量是 \(O(m)\) 的,所以这部分复杂度是对的。

若使用 set 存给定边,时间复杂度为 \(O(m \log^2 m)\)

code
// Problem: P9701 [GDCPC2023] Classic Problem
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P9701
// Memory Limit: 1 MB
// Time Limit: 8000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 500100;

ll n, m, lsh[maxn], tot, fa[maxn], pre[maxn], nxt[maxn];
pii f[maxn];
set<ll> S[maxn];

struct E {
	ll u, v, d;
	E(ll a = 0, ll b = 0, ll c = 0) : u(a), v(b), d(c) {}
} Es[maxn], G[maxn];

int find(int x) {
	return fa[x] == x ? x : fa[x] = find(fa[x]);
}

inline bool merge(int x, int y) {
	x = find(x);
	y = find(y);
	if (x != y) {
		fa[x] = y;
		return 1;
	} else {
		return 0;
	}
}

struct node {
	ll l, r, f;
	node(ll a = 0, ll b = 0, ll c = 0) : l(a), r(b), f(c) {}
} a[maxn];

inline bool operator < (const node &a, const node &b) {
	return a.l < b.l;
}

void solve() {
	scanf("%lld%lld", &n, &m);
	for (int i = 1; i <= m * 5; ++i) {
		S[i].clear();
	}
	tot = 0;
	for (int i = 1; i <= m; ++i) {
		scanf("%lld%lld%lld", &Es[i].u, &Es[i].v, &Es[i].d);
		lsh[++tot] = Es[i].u;
		lsh[++tot] = Es[i].v;
	}
	sort(lsh + 1, lsh + tot + 1);
	tot = unique(lsh + 1, lsh + tot + 1) - lsh - 1;
	lsh[0] = 0;
	lsh[tot + 1] = n + 1;
	ll K = tot, ans = 0;
	for (int i = 1; i <= tot; ++i) {
		a[i] = node(lsh[i], lsh[i], 1);
	}
	for (int i = 0; i <= tot; ++i) {
		if (lsh[i] + 1 != lsh[i + 1]) {
			a[++K] = node(lsh[i] + 1, lsh[i + 1] - 1, 0);
			ans += lsh[i + 1] - lsh[i] - 2;
		}
	}
	sort(a + 1, a + K + 1);
	int tt = 0;
	map<pii, ll> mp;
	for (int i = 1; i <= m; ++i) {
		ll u = Es[i].u, v = Es[i].v, d = Es[i].d;
		u = lower_bound(a + 1, a + K + 1, node(u, u, 0)) - a;
		v = lower_bound(a + 1, a + K + 1, node(v, v, 0)) - a;
		if (u > v) {
			swap(u, v);
		}
		G[++tt] = E(u, v, d);
		S[u].insert(v);
		S[v].insert(u);
		mp[mkp(u, v)] = mp[mkp(v, u)] = d;
	}
	for (int i = 1; i <= K; ++i) {
		fa[i] = i;
	}
	while (1) {
		bool flag = 1;
		for (int i = 1; i <= K; ++i) {
			if (find(i) != find(1)) {
				flag = 0;
				break;
			}
		}
		if (flag) {
			break;
		}
		for (int i = 1; i <= K; ++i) {
			f[i] = mkp(1e18, -1);
		}
		for (int u = 1; u <= K; ++u) {
			for (int v : S[u]) {
				if (find(u) == find(v)) {
					continue;
				}
				ll d = mp[mkp(u, v)];
				f[find(u)] = min(f[find(u)], mkp(d, 1LL * find(v)));
			}
		}
		for (int i = 1, j = 1; i <= K; i = (++j)) {
			while (j < K && find(j + 1) == find(i)) {
				++j;
			}
			for (int u = i; u <= j; ++u) {
				pre[u] = i - 1;
				nxt[u] = j + 1;
			}
		}
		for (int u = 1; u <= K; ++u) {
			int v = u;
			while (v >= 1 && (S[u].find(v) != S[u].end() || find(v) == find(u))) {
				if (find(v) == find(u)) {
					v = pre[v];
				} else {
					--v;
				}
			}
			if (v) {
				f[find(u)] = min(f[find(u)], mkp(a[u].l - a[v].r, 1LL * find(v)));
			}
			v = u;
			while (v <= K && (S[u].find(v) != S[u].end() || find(v) == find(u))) {
				if (find(v) == find(u)) {
					v = nxt[v];
				} else {
					++v;
				}
			}
			if (v <= K) {
				f[find(u)] = min(f[find(u)], mkp(a[v].l - a[u].r, 1LL * find(v)));
			}
		}
		for (int i = 1; i <= K; ++i) {
			if (fa[i] == i && merge(i, f[i].scd)) {
				ans += f[i].fst;
			}
		}
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}