[AGC045D] Lamps and Buttons 题解

发布时间 2023-07-19 09:34:45作者: 霜木_Atomic

[AGC045D] Lamps and Buttons 题解

首先,由于排列生成随机,所以最优决策就是不决策(反正你也不知道),也就是,让 Snuke 从左往右依次按。
那么,什么情况下 Snuke 会输呢?我们可以把每个 \(p_i\)\(i\) 连边,我们发现,如果灭着的灯里面存在自环,也就是只能自己打开自己的,或者在打开所有灯之前,把亮着的灯中有自环的灭掉了,都会输。那么,我们可以考虑枚举 \([1, A]\) 中第一个自环的位置 \(t\),(还有可能没有自环,那么就设为 \(A+1\) ),获胜的条件就变为,在按到 \(t\) 之前,能够把其它的灯全打开。而这等效于,对于 $ \forall x \in [A+1, n] , \exists i \in [1, t-1] $,使得 \(x\)\(i\) 在一个环内。
现在我们就要求 \(t\) 内没有自环的情况,考虑二项式反演。我们钦定 \(t\) 内有 \(k\) 个自环,则整个序列被划分为五部分:\([1, t-1]\) 内的自环,\([1, t-1]\) 内其他的点,\(t\) 点,\([t+1, A]\)的亮灯部分,以及最初没有亮的 \([A+1, n]\) 部分。对于第一个部分是在 \(t-1\) 中取出 \(k\) 个点,第三个部分不用考虑,我们来讨论剩下的部分。
因为要成环,我们就通过插入来考虑。每次插入一个点 \(p\),都相当于是断开前面连接两个点 \(u, v\) 的一条边,然后加上 \(u\)\(p\)\(p\)\(v\) 的边。当然,也可以自己成为一个新环。那么,对于第二个部分内的点,可以随意插入,也可以自成环,所以每枚举到第 \(i\) 个点都有 \(i\) 种方案;而为了保证 “对于 $ \forall x \in [A+1, n] , \exists i \in [1, t-1] $,使得 \(x\)\(i\) 在一个环内。” 这个条件,连完第二部分后就要去考虑第五部分。这个部分中的点不能连自环,故每次枚举贡献 \(i-1\);然后第四部分也是随意插入,每次贡献 \(i\)。注意这里的 \(i\) 是连续枚举而非分别枚举。我们令这三部分的大小分别为 \(a, b, c\),整理一下,则有 \(g_k = {t-1 \choose k} \frac{a(a+b+c)!}{a+b}\)。直接套二项式反演即可。
如果按照我一开始每次单独求 \(a+b\) 逆元,总复杂度为 \(O(n+A^2 \log n)\)(机子快能跑过去)。当然这里可以直接利用阶乘和阶乘逆元求出来,这样复杂度就是 \(O(n+A^2)\)
代码:

#include<bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int N = 1e7+10, M = 5050;

int fac[N], inv[N];
inline int C(int n, int m){
	if(n<0 || m<0 || n<m) return 0;
	return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int n, m;

inline int fpow(int a, int b){
	a%=mod;
	int ret = 1;
	while(b){
		if(b & 1){
			ret = (1ll*ret*a)%mod;
		}
		b>>=1;
		a = (1ll*a*a)%mod;
	}
	return ret;
}
void prework(){
	fac[0] = 1;
	for(int i = 1; i<=n; ++i){
		fac[i] = (1ll*fac[i-1]*i)%mod;
	}
	inv[n] = fpow(fac[n], mod-2);
	for(int i = n-1; i>=0; --i){
		inv[i] = (1ll*inv[i+1]*(i+1))%mod;
	}
}

inline int calc(int a, int b, int c){
	return 1ll*a*fac[a+b+c]%mod*fac[a+b-1]%mod*inv[a+b]%mod;
}
int ans;
int main(){
	scanf("%d%d", &n, &m);
	prework();
	int a, b, c;
	for(int t = 1; t<=m+1; ++t){
		for(int i = 0; i<t; ++i){
			a = t-i-1, b = n-m, c = m-t;
			if(t == m+1) c = 0;
			int fu = (i&1)?-1:1; 
			ans = (1ll*ans+1ll*fu*C(t-1, a)*calc(a, b, c)%mod)%mod;
			ans = (ans+mod)%mod;
		}
	}	
	printf("%d\n", ans);
	return 0;
}