atcorder 295 E

发布时间 2023-03-29 20:59:51作者: zuotihenkuai

题目链接:https://atcoder.jp/contests/abc295/tasks/abc295_e

题意:
给定一个长为N的数字序列,序列中每个数字都在[0, M]这个区间中。按顺序做两步操作:
第一步,对于数字序列中每个数字0,独立的并且等概率的从区间[1, M]中选择一个数, 把这个0代替成选出来的数。
第二步,把这个数字序列按照升序排列。
问第K位得到数字的期望是什么,输出答案mod998244353

Simple input

3 5 2
2 0 4

Simple output

3

Solution:
第K位的期望是:\(E_k = \displaystyle\sum^{M}_{i = 1}{i * p_i}\)
这样不是很好求,可以将其转化成\(E_k = \displaystyle\sum^{M}_{i = 1}{i * (b_i - b_{i + 1})}\)
其中\(b_i\)指的是第k位大于等于i的概率。
那么\(E_k = 1 * (b_1 - b_2) + 2 * (b_2 - b_3) + \dots + m - 1 * (b_{m - 1} - b_{m}) + b_m = b_1 + b_2 + \dots + b_m\)
接下来就要求出\(b_i\)
分类讨论:
1.:如果序列中大于等于i的数的个数是大于等于n - k + 1的那么\(b_i\)就是1,因为这种情况在不改变0的情况下已经使得在升序排序后第k位恒大于等于i。
2.:如果序列中大于等于i的个数不足n - k + 1,那么就需要把一些0变成大于等于i的数。
对于这种情况,我们事先统计好其中大于等于i的数的个数,设为cnt,并且统计出来其中0的个数,设为num。
如果num + cnt < n - k + 1, 意思就是说,把所有0都转换成大于等于i的数,仍然无法使得第k位的数大于等于i,因此\(b_i\) = 0。
而如果num + cnt >= n - k + 1, 我们就可以从num个0中选出来需要进行转化的0。
对于每一个0,它转换成大于等于i的数的概率是\(P = \frac{m - i + 1}{m}\)
那么,这种情况\(b_i = \displaystyle\sum^{cnt}_{i = n - k + 1 - cnt}\left({i\choose num} * P^i + (1 - P)^{num - i}\right)\)
就是加法原理和乘法原理。

Code:

#include <bits/stdc++.h>

using namespace std;
typedef long long LL;
#define int LL

const int mod = 998244353;
const int N = 5010;
int c[N][N];

int qmi(int a, int b, int c) {
	LL res = 1;
	while(b) {
		if(b & 1) res = res * a % c;
		a = (LL) a * a % c;
		b >>= 1;
	}
	return res;
}

signed main()
{
//	ios::sync_with_stdio(false);
//	cin.tie(0);
	for(int i = 0; i <= 5005; i ++) {
		for(int j = 0; j <= i; j ++) {
			if(!j) c[i][j] = 1;
			else c[i][j] = (c[i - 1][j - 1] + c[i - 1][j]) % mod;
		}
	}
	int n, m, k; cin >> n >> m >> k;
	vector<int> seq(n + 1);
	vector<int> stc(2010, 0);
	int num = 0;//统计序列中零的个数
	for(int i = 1; i <= n; i ++) {
		cin >> seq[i];
		if(!seq[i]) num ++;
		stc[seq[i]] ++;
	}
	int need = n - k + 1;
	int ans = 0;
	for(int i = 1; i <= m; i ++) {
		int cnt = 0;//统计序列中有多少个数字大于等于i
		for(int j = i; j <= m; j ++) cnt += stc[j];
		if(cnt >= need) ans = (ans + 1) % mod;
		else {
			if(cnt + num < need) continue;
			else {
				int p = (m - i + 1) * qmi(m, mod - 2, mod) % mod;
				for(int k = need - cnt; k <= num; k ++) {
					ans = (ans + c[num][k] * qmi(p, k, mod)  % mod * qmi((1 - p + mod) % mod, num - k, mod) % mod) % mod; 
				}
			}
		}
	}
	cout << ans;
}