题解 [NOI2020] 命运

发布时间 2023-08-01 12:09:34作者: Meatherm

Link

题意

给定一棵 \(n\) 个节点的有根树和 \(m\) 条祖先到后代的链。问有多少种把边权设置为 \(0\)\(1\) 的方案使得每条链上至少有一条边是 \(1\)

答案对 \(998244353\) 取模。

\(1 \leq n,m \leq 5 \times 10^5\)

题解

我们将链的下端称为限制的起点。容易发现,对于同一个起点,终点越深,限制越强,于是不妨只考虑起点。

\(f(u,x)\) 表示限制的起点在 \(u\) 的子树之内,不满足条件的限制最深深度为 \(x\) 的方案数。最终显然有 \(f(1,0)\) 为答案。

考虑树形 DP,不断合并子树来求解 \(f\)。首先遍历 \(u\) 为起点的限制,求出最深的深度 \(d_{max}\),则初始子树中只有 \(u\) 一个点,\(f(u,d_{max})=1\)

考虑合并子树 \(v\),根据 \((u,v)\) 这条边是 \(0\) 还是 \(1\) 转移:

\[f(u,d) \leftarrow \sum \limits_{i=0}^{d} f(v,i) f(u,d) +\sum \limits_{i=0}^{d-1}f(u,i)f(v,d) + \sum \limits_{i=0}^{dep_u} f(v,i)f(u,d) \]

考虑第二个和式的上界是 \(d-1\),因为 \(f(u,d)f(v,d)\) 会出现两次。

记前缀和 \(g(u,d) = \sum \limits_{i=0}^d f(u,d)\)

\[f(u,d) \leftarrow g(v,d)f(u,d)+g(u,d-1)f(v,d)+g(v,dep_u)f(u,d) \\ f(u,d) \leftarrow f(u,d)(g(v,d)+g(v,dep_u))+g(u,d-1)f(v,d) \]

考虑线段树合并转移。这里讲一下转移过程:\(g(v,dep_u)\) 在整个合并过程中都是常量,可以提前查询得到。合并到 \([l,r]\) 时维护 \(g(u,l-1)\)\(g(v,l-1)\),因为我们是从左到右合并,于是这是容易维护的。

如果 \(u,v\) 的树上都有 \([l,r]\) 的节点,直接递归合并。到叶子了就按照上面的式子直接转移。重点讲一下 \(u,v\) 有一个为空的情况:

  • \(u\) 为空,那么 \(f(u,l\cdots r)=0\),于是上面式子里面 \(f(u,d)\) 的项全是 \(0\),且区间内的 \(g(u,d-1) = g(u,l-1)\)。直接给 \(f(v,l\cdots r)\) 乘上 \(g(u,l-1)\) 即可。
  • \(v\) 为空,那么 \(f(v,l\cdots r) =0\),于是上面式子里面 \(f(v,d)\) 的项全是 \(0\),且区间内的 \(g(v,d) = g(v,l-1)\)。直接给 \(f(u,l\cdots r)\) 乘上 \(g(v,l-1)+g(v,dep_u)\) 即可。
# include <bits/stdc++.h>

const int N=500010,mod=998244353;

int n,m;

int dep[N],rt[N];

std::vector <int> G[N];
std::vector <int> lim[N]; 

struct Node{
	int sum,lc,rc,tag;
	Node(){
		tag=1;
		return;
	}
}tr[N*30];
int cnt;



inline int read(void){
	int res,f=1;
	char c;
	while((c=getchar())<'0'||c>'9')
		if(c=='-') f=-1;
	res=c-48;
	while((c=getchar())>='0'&&c<='9')
		res=res*10+c-48;
	return res*f;
}

inline void add(int &x,int v){
	x+=v;
	if(x>=mod) x-=mod;
	return;
}
inline int adc(int a,int b){
	return (a+b<mod)?(a+b):(a+b-mod); 
}
inline int mul(int a,int b){
	return 1ll*a*b%mod;
}
inline int& lc(int x){
	return tr[x].lc;
}
inline int& rc(int x){
	return tr[x].rc;
}

inline void pushup(int x){
	tr[x].sum=adc(tr[lc(x)].sum,tr[rc(x)].sum);
	return;
}
inline void mule(int x,int v){
	if(!x) return;
	tr[x].sum=mul(tr[x].sum,v);
	tr[x].tag=mul(tr[x].tag,v);
	return;
}
inline void pushdown(int x){
	if(tr[x].tag!=1)
		mule(lc(x),tr[x].tag),mule(rc(x),tr[x].tag),tr[x].tag=1;
	return;
}
void change(int &k,int l,int r,int x,int v){
	if(!k) k=++cnt;
//	printf("qwq = %d\n",tr[x].tag);
	if(l==r) return tr[k].sum=v,void();
	int mid=(l+r)>>1;
	if(x<=mid) change(lc(k),l,mid,x,v);
	else change(rc(k),mid+1,r,x,v);
	pushup(k);
	return;
}
void merge(int &k,int x,int y,int l,int r,int &gu,int &gv){ // gu[i-1] gv[i-1]
//	printf("exe\n");
	if(!x&&!y) return k=0,void();
	if(!x){
		add(gv,tr[y].sum),mule(y,gu),k=y;
		return;
	}
	if(!y){
		add(gu,tr[x].sum),mule(x,gv),k=x;
		return;
	}
	if(l==r){
		int fu=tr[x].sum,fv=tr[y].sum;
		add(gv,fv),
		tr[x].sum=adc(1ll*tr[x].sum*gv%mod,1ll*gu*fv%mod),add(gu,fu);
		k=x;
		return;
	}
	int mid=(l+r)>>1;
	pushdown(x),pushdown(y);
	merge(lc(k),lc(x),lc(y),l,mid,gu,gv);
	merge(rc(k),rc(x),rc(y),mid+1,r,gu,gv);
	pushup(k);
	return;
}
int query(int k,int l,int r,int L,int R){
	if(!k) return 0;
	if(L<=l&&r<=R) return tr[k].sum;
	pushdown(k);
	int mid=(l+r)>>1,res=0;
	if(L<=mid) add(res,query(lc(k),l,mid,L,R));
	if(mid<R) add(res,query(rc(k),mid+1,r,L,R));
	return res;
}

void dfs(int i,int fa){
	int md=0,su,sv;
	dep[i]=dep[fa]+1;
	for(auto v:lim[i]) md=std::max(md,dep[v]);
	change(rt[i],0,n,md,1);
	for(auto v:G[i]){
		if(v==fa) continue;
		dfs(v,i),su=0,sv=query(rt[v],0,n,0,dep[i]);
		merge(rt[i],rt[i],rt[v],0,n,su,sv);
	}
	return;
}

int main(void){
	n=read();
	for(int i=1;i<n;++i){
		int u=read(),v=read();
		G[u].push_back(v),G[v].push_back(u); 
	}
	m=read();
	while(m--){
		int u=read(),v=read();
		lim[v].push_back(u);
	}
	dfs(1,0);
	
	printf("%d",query(rt[1],0,n,0,0));
	
	return 0;
}