题解 Gym 102978F【Find the LCA】

发布时间 2023-08-10 23:08:57作者: caijianhong

problem

You are given an integer sequence \(A_1,A_2,\ldots,A_N\). You'll make a rooted tree with \(N\) vertices numbered from \(1\) through \(N\). The vertex \(1\) is the root, and for each vertex \(i\) (\(2 \leq i \leq N\)), its parent \(p_i\) must satisfy \(p_i>i\).

You define the score of a rooted tree as follows:

  • Let \(x\) be the lowest common ancestor of the vertex \(N-1\) and the vertex \(N\). Then, the score is

    \[\prod_{v \in (\text{subtree rooted at $x$})} A_v \]

  • Note that we consider \(x\) itself is in the subtree rooted at \(x\).

There are \((N-1)!\) ways to make a tree. Find the sum of scores of all possible trees, modulo \(998244353\).

solution

结论:满足 \(lca(n,n-1)=1\) 的满足 \(p_i<i\) 的树恰好有 \((n-1)!/2\) 个,且与不是这个的形成双射。证明不会。

\(lca(n,n-1)=1\) 的树和不是的分开统计,前者的答案是 \((n-1)!/2\prod_u a_u\)

对于后者,枚举 \(lca\),枚举子树中的点是哪些,记不计算 \(n,n-1\)\(i\) 的子树有 \(j\) 个点的乘积和(只算乘积)为

\[F_{i,j}=\sum_{i=i_1<i_2<\cdots<i_j<n-1}\prod_{k=1}^jA_{i_k}=A_i[x^{j-1}]\prod_{k=i+1}^{n-2}(1+A_k\cdot x). \]

那么答案是枚举 \(i,j\) 得到:

\[A_nA_{n-1}\sum_j \frac{(j+1)!}{2}(n-j-3)!\sum_i (i-1)f_{i,j} \]

前面一项枚举了子树形态,共 \(j+2\) 个点;后一项是其它的不在子树的点的方案数,然后枚举 lca,还有 lca 的父亲,然后是具体选法。

所以要算后面的,这样就能枚举 \(j\) 计算答案。

\[\sum_i (i-1)f_{i,j}=\sum_i(i-1)A_i[x^{j-1}]\prod_{k=i+1}^{n-2}(1+A_k\cdot x)=[x^{j-1}]\sum_i(i-1)A_i\prod_{k=i+1}^{n-2}(1+A_k\cdot x). \]

称后面的东西为 \(F\),我们要算出 \(F\) 的各项系数。

我们可以用这么这一东西,就是你考虑维护二元组 \((f,g)\),其中 \(f,g\) 是多项式,\(f\) 维护了 \(\prod_k(1+A_kx)\)\(g\) 维护了 \(\sum_i (i-1)A_i\cdot f\),就是考虑暴力,从右往右维护 \((f,g)\),先是 \(g:=g+(i-1)A_if\),然后 \(f:=f(1+A_ix)\)。这就相当于一个函数复合,我们可以分治维护,已经知道左边的 \((f_1,g_1)\) 和右边的 \((f_2,g_2)\),可以求出新的为 \((f_1f_2,g_1f_1+g_2)\),这样就能求出答案了。

code

有点细节

C++ 20 要开

点击查看代码

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr,##__VA_ARGS__)
#else
#define debug(...) void(0)
#endif
typedef long long LL;
typedef long long LL;
LL qpow(LL a,LL b,int p){LL r=1;for(a%=p;b;b>>=1,a=a*a%p) if(b&1) r=r*a%p; return r;}
const int P=998244353,G=3,G0=qpow(G,P-2,P),inv2=(P+1)/2;
LL mod(LL x){return (x%P+P)%P;}
void red(LL&x){x%=P;}
void ntt(vector<LL> &a,int op){
	int n=a.size(); vector<LL> w(n);
	for(int i=1;i<n;i++) w[i]=w[i>>1]>>1|(i&1?n>>1:0);
	for(int i=0;i<n;i++) if(i<w[i]) swap(a[i],a[w[i]]);
	for(int k=1,len=2;len<=n;k<<=1,len<<=1){
		LL wn=qpow(op==1?G:G0,(P-1)/len,P);
		for(int i=w[0]=1;i<k;i++) red(w[i]=w[i-1]*wn);
		for(int i=0;i<n;i+=len){
			for(int j=0;j<k;j++){
				LL x=a[i+j],y=a[i+j+k]*w[j]%P;
				a[i+j]=x+y,a[i+j+k]=x-y;
				if(a[i+j]>=P) a[i+j]-=P;
				if(a[i+j+k]<0) a[i+j+k]+=P;
			}
		}
	}
	for(LL&x:a) x=mod(x);
	if(op==-1){LL inv=qpow(n,P-2,P); for(LL&x:a) red(x*=inv);}
}
vector<LL> times(vector<LL> a,vector<LL> b){
	int len=1,n=(int)a.size()-1,m=(int)b.size()-1;
	for(;len<=n+m;len<<=1);
	a.resize(len),b.resize(len);
	//return ntt(multiple(ntt(a,1),ntt(b,1)),-1);
	ntt(a,1),ntt(b,1);
	for(int i=0;i<len;i++) red(a[i]*=b[i]);
	ntt(a,-1);
	return a;
}
vector<LL> pls(vector<LL> a,vector<LL> b){
	if(a.size()<b.size()) swap(a,b);
	for(int i=0;i<b.size();i++) red(a[i]+=b[i]);
	return a;
}	
struct node{
	vector<LL> f,g;
	friend node operator&(node a,node b){
		return {times(a.f,b.f),pls(times(a.g,b.f),b.g)};
	}
};
int n;
LL a[1<<18];
node solve(int l,int r){
	if(l>r) return {{},{}};
	if(l==r){
		return {
			{1ll,a[l]},
			{(l-1)*a[l]%P}
		};
	}
	int mid=(l+r)>>1;
	return solve(l,mid)&solve(mid+1,r);
}
LL fac[1<<18];
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	scanf("%d",&n);
	LL sum=1;
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]),red(sum*=a[i]);
	vector<LL> F=solve(2,n-2).g;
	for(int i=fac[0]=1;i<=n;i++) fac[i]=fac[i-1]*i%P;
	LL res=fac[n-2];
	for(int j=1;j<=n-3;j++){
		red(res+=fac[j+1]*inv2%P*fac[n-j-3]%P*F[j-1]%P);
	}
	printf("%lld\n",mod(res*a[n]%P*a[n-1]+sum*fac[n-1]%P*inv2));
	return 0;
}