Codeforces 1799H - Tree Cutting(树形 dp)

发布时间 2023-04-29 22:30:22作者: tzc_wk

思考的时候一直卡在不会在低于 \(O(n)\) 的时间内储存一个连通块的 \(siz\) 有关的信息,看了洛谷题解之后才发现我真是个小丑。

树形 DP。对于一条我们需要操作的边 \((i,fa_i)\),我们将其分为保留子树和删除子树两种类型,对于删除子树,我们在判定其是否合法时候改为判定删除的连通块大小是否为 \(s_{k-1}-s_k\),好处是我们只用关心子树内那一部分的 \(siz\)

这样考虑 DP,\(dp_{u,S,k}\) 表示考虑 \(u\) 子树,有且仅有 \(S\) 操作集合中的操作对应的 \(i\)\(u\) 子树内,且这些操作中最早的保留子树操作是 \(k\) 的方案数(不存在则 \(k=c+1\))。考虑转移,先考虑怎么合并两个子树 \(u,v\),假设要合并 \(dp_{u,S,k_1}\)\(dp_{v,T,k_2}\),那么可以合并的充要条件是:

  • \(S\cap T=\varnothing\)
  • \(k_1=c+1\)\(k_2=c+1\)
  • 假设 \(k_1\ne c+1\),那么 \(T\) 中所有元素必须 \(<k_1\)

然后考虑加入 \((u,fa_u)\) 边的方案数,我们枚举这条边对应的操作 \(i\),那么:

  • 如果这个操作是删除子树,那么要求 \(S\) 中所有元素都 \(<i\),并且当中没有任何一个元素是保留子树的操作。
  • 如果这个操作是保留子树,那么要求其他操作中最早的保留子树的操作的时间 \(>i\)

最后考虑怎么计算保留/删除的连通块的大小,其实是 trivial 的,因为两种情况都保证了我们在 \(i\) 之前的所有操作都是删除子树的操作,所以就直接拿 \(siz_u\) 减去 \(S\) 中所有 \(<i\) 的元素的 \(s_{j-1}-s_j\) 即可。不需要额外记录 \(O(n)\) 的一维。

const int MAXN=5000;
const int MAXK=6;
const int MAXP=64;
const int MOD=998244353;
int n,k,hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec,s[MAXK+3];
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int dp[MAXN+5][MAXK+3][MAXP+5],siz[MAXN+5],mx[MAXP+5];
void dfs(int x,int f){
	dp[x][k+1][0]=1;siz[x]=1;
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==f)continue;dfs(y,x);siz[x]+=siz[y];
		static int tmp[MAXK+3][MAXP+5];memset(tmp,0,sizeof(tmp));
		for(int i=1;i<=k+1;i++)for(int S=0;S<(1<<k);S++){
			int rst=(1<<k)-1-S;
			for(int T=rst;;--T&=rst){
				if(mx[T]<=i)tmp[i][S|T]=(tmp[i][S|T]+1ll*dp[x][i][S]*dp[y][k+1][T])%MOD;
				if(i!=k+1&&mx[S]<=i)tmp[i][S|T]=(tmp[i][S|T]+1ll*dp[x][k+1][S]*dp[y][i][T])%MOD;
				if(!T)break;
			}
		}
		for(int i=1;i<=k+1;i++)for(int S=0;S<(1<<k);S++)dp[x][i][S]=tmp[i][S];
	}
	if(x!=1){
		static int tmp[MAXK+3][MAXP+5];
		for(int i=1;i<=k+1;i++)for(int S=0;S<(1<<k);S++)tmp[i][S]=dp[x][i][S];
		for(int i=1;i<=k+1;i++)for(int S=0;S<(1<<k);S++)if(dp[x][i][S]){
			for(int j=1;j<i;j++)if(~S>>(j-1)&1){
				int sumsiz=siz[x];bool flg=1;
				for(int l=1;l<j;l++)if(S>>(l-1)&1)sumsiz-=(s[l-1]-s[l]);
				if(sumsiz==s[j])tmp[j][S|(1<<j-1)]=(tmp[j][S|(1<<j-1)]+dp[x][i][S])%MOD;
				if(sumsiz==s[j-1]-s[j]&&mx[S]<j)tmp[i][S|(1<<j-1)]=(tmp[i][S|(1<<j-1)]+dp[x][i][S])%MOD;
			}
		}
		for(int i=1;i<=k+1;i++)for(int S=0;S<(1<<k);S++)dp[x][i][S]=tmp[i][S];
	}
//	for(int i=1;i<=k+1;i++)for(int S=0;S<(1<<k);S++)if(dp[x][i][S])
//		printf("dp[%d][%d][%d]=%d\n",x,i,S,dp[x][i][S]);
}
int main(){
	scanf("%d",&n);for(int i=1,u,v;i<n;i++)scanf("%d%d",&u,&v),adde(u,v),adde(v,u);
	scanf("%d",&k);for(int i=1;i<=k;i++)scanf("%d",&s[i]);s[0]=n;
	for(int i=1;i<(1<<k);i++)for(int j=1;j<=k;j++)if(i>>(j-1)&1)chkmax(mx[i],j);
	dfs(1,0);int res=0;
	for(int i=1;i<=k+1;i++)res=(res+dp[1][i][(1<<k)-1])%MOD;
	printf("%d\n",res);
	return 0;
}