[AGC052C] Nondivisible Prefix Sums 题解

发布时间 2023-11-30 19:19:51作者: Farmer_D

题目链接

点击打开链接

题目解法

好题!
一个序列是不合法的,必定满足某些结论,我们不妨猜测一下
首先如果和为 \(P\) 的倍数,必定不合法
然后手玩几个可以发现,最极限的情况是 \(P-1\)\(1\;+\;\) \(b_i\; + \;\) \(P-b_i\)
如果在这个情况下再加一个 \(1\),就爆了
其中 \(1\) 可以替换为 \(P-1\) 个数,因为任何序列如果众数不为 \(1\),整体乘众数的逆元即可

考虑如何证明?

  1. 必要性。如果不满足,显然构造不出
  2. 充分性。
    考虑一种优秀的构造方案,每次如果能选众数填进去就填,不能就随便填一个数,这样一定会填成上面最极限的形式

对于和为 \(P\) 的倍数的序列,可以直接小小的容斥一下
如果不考虑最后一位不可填是 \((P-1)^{n-1}\),但倒数第二位的前缀和如果是 \(0\) 的话就不可填,一直往前推,不难得到容斥之后答案为 \(\sum\limits_{i=1}^{n-1}(-1)^{n-i}(P-1)^i\)

对于第二种情况如何求解
\(f_{i,j}\) 为选了 \(i\)\(b\)\(\sum{P-b_k}\)\(j\) 的方案数
不难用前缀和优化 \(dp\) 做到 \(O(n^2)\) 的复杂度
考虑合法的 \(i,j\) 的条件,即 \(n-i\ge P+j\)\(n-i\not\equiv j(\mod P)\)
直接组合数算一算即可

时间复杂度 \(O(n^2)\)

#include <bits/stdc++.h>
#define F(i,x,y) for(int i=(x);i<=(y);i++)
#define DF(i,x,y) for(int i=(x);i>=(y);i--)
#define ms(x,y) memset(x,y,sizeof(x))
#define SZ(x) (int)x.size()-1
#define pb push_back
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
typedef pair<int,int> pii;
template<typename T> void chkmax(T &x,T y){ x=max(x,y);}
template<typename T> void chkmin(T &x,T y){ x=min(x,y);}
inline int read(){
    int FF=0,RR=1;
    char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
    for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
    return FF*RR;
}
const int N=5100,mod=998244353;
int n,P,f[N][N],s[N][N];
int fac[N],ifac[N];
inline void inc(int &x,int y){ x+=y;if(x>=mod) x-=mod;}
int C(int x,int y){ return 1ll*fac[x]*ifac[y]%mod*ifac[x-y]%mod;}
int qmi(int a,int b){
    int res=1;
    for(;b;b>>=1){ if(b&1) res=1ll*res*a%mod;a=1ll*a*a%mod;}
    return res;
}
int main(){
    n=read(),P=read();
    fac[0]=1;
    F(i,1,n) fac[i]=1ll*fac[i-1]*i%mod;
    ifac[n]=qmi(fac[n],mod-2);
    DF(i,n-1,0) ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
    //sum % P = 0
    int res=1,tot1=0;
    F(i,1,n-1){
        res=1ll*res*(P-1)%mod;
        if((n-i)&1) tot1=(tot1+res)%mod;
        else tot1=(tot1-res+mod)%mod;
    }
    //otherwise
    int bound=0;
    F(i,0,n) if(n-i>1ll*(i+1)*(P-1)) bound=i;
    f[0][0]=1;
    F(j,0,n) s[0][j]=1; 
    F(i,1,n) F(j,1,n){
        f[i][j]=s[i-1][j-1];
        if(j-(P-1)>=0) f[i][j]=(f[i][j]-s[i-1][j-(P-1)]+mod)%mod;
        s[i][j]=(s[i][j-1]+f[i][j])%mod;
    }
    int tot2=0;
    F(i,0,n) F(j,0,n) if(n-i>=P+j&&(n-i-j)%P) inc(tot2,1ll*f[i][j]*C(n,i)%mod);
    tot2=1ll*tot2*(P-1)%mod;
    //tot
    int ans=1;
    F(i,1,n) ans=1ll*ans*(P-1)%mod;
    ans=(1ll*ans-tot1-tot2+2*mod)%mod;
    printf("%d\n",ans);
    fprintf(stderr,"%d ms\n",int(1e3*clock()/CLOCKS_PER_SEC));
    return 0;
}