CF1010F Tree

发布时间 2023-07-27 08:26:27作者: Ender_32k

题意

  • 给定一棵根为 \(1\) 的二叉树 \(T\),根上有 \(x\) 个水果。
  • 某些枝条(二叉树的边)会断掉,留下一个包含根节点的联通块 \(T'\)
  • 给剩下的 \(T'\) 中每个点 \(u\) 赋点权 \(a_u\) 表示这个点上的水果数量,满足 \(a_1=x\) 并且 \(a_u\ge \sum\limits_{v\in \text{son}(u)}a_v\)
  • 计算合法二元组 \((T',a)\) 的数量,对 \(998244353\) 取模。

题解

考虑差分,\(b_u=a_u-\sum\limits_{v\in \text{son}(u)}a_v\),特别地,\(u\) 是叶子时,\(\text{son}(u)=\varnothing\)

显然有 \(b_i\ge 0\),并且序列 \(a\)\(b\) 一一对应,考虑数合法 \(b\) 的个数。那么 \(a\) 相当于 \(b\) 的前缀和,则 \(a_1=\sum\limits_{u\in T'}b_u=x\),相当于把 \(x\) 拆分成 \(|T'|\) 个互相区分的数,方案数为 \(\dbinom{x+|T'|-1}{|T'|-1}\),于是问题转换为给定 \(k=|T'|\),求包含 \(1\) 节点的大小为 \(k\) 的连通块 \(T'\) 的个数。

不难想到树形 dp,令 \(f_{i,j}\) 表示 \(i\) 子树内包含 \(i\) 且大小为 \(j\) 的连通块个数,令 \(v_1,v_2\) 分别是 \(u\) 的左右儿子,转移是显然的背包:

\[f_{u,i}=\sum\limits_{j=0}^{i-1}f_{v_1,j}f_{v_2,i-j-1} \]

这是卷积的形式,考虑写出 \(f_{u}\) 的 OGF:

\[F_u(x)=\sum\limits_{i}f_{u,i}x^i \]

转移可以写成如下形式:

\[F_u(x)=xF_{v_1}(x)F_{v_2}(x)+1 \]

\(v\) 不存在时,视作 \(F_v(x)=1\)。于是我们就有了一个 \(O(n^2\log n)\) 的暴力 NTT 卷积做法。

考虑类似动态 dp,我们只关心 \(F_1(x)\) 的值,所以考虑只维护出所有重链顶端的点 \(u\)\(F_u(x)\)

假设一条重链从链顶到链底形如 \(u_1\to u_2\to \cdots\to u_m\)。考虑 \(u_i\) 的重儿子为 \(u_{i+1}\),轻儿子为 \(v_i\),显然轻儿子是另一条重链的链顶,它的 \(F_{v_i}(x)\) 可以递归地进行维护,所以可以假设我们现在知道了所有 \(F_{v_i}(x)\),根据刚才的转移:

  • 由于 \(u_m\) 为叶子节点,\(F_{u_m}(x)=x+1\)
  • \(F_{u_i}(x)(1\le i<m)=xF_{u_{i+1}}(x)F_{v_i}(x)+1\)

这就满足很好的序列递推性质:

\[\begin{aligned}F_{u_1}(x)&=xF_{u_2}(x)F_{v_1}(x)+1\\&=x(xF_{u_3}(x)F_{v_2}(x)+1)F_{v1}(x)+1\\&=x^2F_{u_3}(x)F_{v_2}(x)F_{v_1}(x)+xF_{v_1}(x)+1\\&=\cdots\\&=x^{m-1}F_{u_m}(x)F_{v_{m-1}}(x)F_{v_{m-2}}(x)\cdots F_{v_1}(x)+\\&\quad\ \ x^{m-2}F_{v_{m-2}}(x)F_{v_{m-3}}(x)\cdots F_{v_{1}}(x)+\\&\quad\ \ \cdots\\&\quad \ \ xF_{v_1}(x)+1\\&=\sum\limits_{i=0}^{m-1}x^i\prod\limits_{j=0}^iF_{v_j}(x)\end{aligned} \]

接下来的事情就很简单了,我们发现这个东西可以分治乘来维护:

\[G_{l,r}(x)=\prod\limits_{i=l}^rF_{v_i}(x) \]

\[\begin{aligned}H_{l,r}(x)&=1+xF_{v_1}(x)+x^2F_{v_l}(x)F_{v_{l+1}}(x)+\cdots \\&=\sum\limits_{i=0}^{r-l+1}x^i\prod\limits_{j=l}^{l+i-1}F_{v_j}(x)\end{aligned} \]

显然有:

\[G_{l,r}(x)=G_{l,mid}(x)G_{mid+1,r}(x) \]

\[H_{l,r}(x)=H_{l,mid}(x)+H_{mid+1,r}(x)(G_{mid+1,r}(x)-1) \]

维护二元组 \((H,G)\) 即可分治 NTT 乘法,答案即为:

\[\sum\limits_{i=1}^{n}[x^i]F_1(x)\dbinom{x+i-1}{i-1} \]

每个点往根走,只有经过轻边才会对复杂度贡献,所以一个点至多贡献 \(O(\log n)\) 次,加上分治乘法复杂度是 \(O(n\log^3 n)\) 的,但是众所周知,树剖卡不满。

很好写。

// Problem: F. Tree
// Contest: Codeforces - Codeforces Round 499 (Div. 1)
// URL: https://codeforces.com/problemset/problem/1010/F
// Memory Limit: 256 MB
// Time Limit: 7000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define ll long long
using namespace std;

namespace vbzIO {
    char ibuf[(1 << 20) + 1], *iS, *iT;
    #if ONLINE_JUDGE
    #define gh() (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, (1 << 20) + 1, stdin), (iS == iT ? EOF : *iS++) : *iS++)
    #else
    #define gh() getchar()
    #endif
    #define mt make_tuple
    #define mp make_pair
    #define fi first
    #define se second
    #define pc putchar
    #define pb emplace_back
    #define ins insert
    #define era erase
    typedef tuple<int, int, int> tu3;
    typedef pair<int, int> pi;
    inline int rd() {
        char ch = gh();
        int x = 0;
        bool t = 0;
        while (ch < '0' || ch > '9') t |= ch == '-', ch = gh();
        while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = gh();
        return t ? ~(x - 1) : x;
    }
    inline ll rdl() {
        char ch = gh();
        ll x = 0;
        bool t = 0;
        while (ch < '0' || ch > '9') t |= ch == '-', ch = gh();
        while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + (ch ^ 48), ch = gh();
        return t ? ~(x - 1) : x;
    }
    inline void wr(int x) {
        if (x < 0) x = ~(x - 1), pc('-');
        if (x > 9) wr(x / 10);
        pc(x % 10 + '0');
    }
}
using namespace vbzIO;

typedef vector<int> poly;
const int N = 1e5 + 100;
const int M = 3e5 + 300;
const int P = 998244353;
const int G = 114514;

poly dp[N];
vector<poly> p;
vector<int> g[N], cn[N];
int n, m, inv[N], sz[N], fa[N], son[N], ch[N][2];

int Add(int x, int y) { return x += y, (x >= P) ? (x - P) : x; }
int Mul(int x, int y) { return 1ll * x * y % P; }
int Sub(int x, int y) { return Add(x, P - y); }
void Addi(int &x, int y) { x = Add(x, y); }
void Muli(int &x, int y) { x = Mul(x, y); }
void Subi(int &x, int y) { x = Sub(x, y); }

int qpow(int p, int q) {
	int res = 1;
	for (; q; q >>= 1, p = 1ll * p * p % P)
		if (q & 1) res = 1ll * res * p % P;
	return res;
}

const int iG = qpow(G, P - 2);

poly Plus(poly x, int y) {
	if (!x.size()) x.resize(1);
	return Addi(x[0], y), x; 
}

poly Plus(poly x, poly y) {
	int szx = x.size(), szy = y.size();
	if (szx < szy) swap(x, y);
	for (int i = 0; i < min(szx, szy); i++) Addi(x[i], y[i]);
	return x;
}

void NTT(int *f, int len, int lim, int op) {
	static int tr[M];
	for (int i = 1; i < len; i++)
		tr[i] = (tr[i >> 1] >> 1) | ((i & 1) << (lim - 1));
	for (int i = 0; i < len; i++)
		if (i < tr[i]) swap(f[i], f[tr[i]]);
	for (int o = 2, k = 1; k < len; o <<= 1, k <<= 1) {
		int tg = qpow(~op ? G : iG, (P - 1) / o);
		for (int i = 0; i < len; i += o) {
			for (int j = 0, w = 1; j < k; j++, Muli(w, tg)) {
				int x = f[i + j], y = Mul(f[i + j + k], w);
				f[i + j] = Add(x, y);
				f[i + j + k] = Sub(x, y);
			}
		}
	}
	if (~op) return;
	int iv = qpow(len, P - 2);
	for (int i = 0; i < len; i++) Muli(f[i], iv);
}

poly Conv(poly x, poly y) {
	static int tx[M], ty[M];
	int len = 1, lim = 0, szx = x.size(), szy = y.size();
	while (len < szx + szy) len <<= 1, lim++;
	for (int i = 0; i < len; i++) tx[i] = ty[i] = 0;
	for (int i = 0; i < szx; i++) tx[i] = x[i];
	for (int i = 0; i < szy; i++) ty[i] = y[i];
	NTT(tx, len, lim, 1), NTT(ty, len, lim, 1);
	for (int i = 0; i < len; i++) Muli(tx[i], ty[i]);
	NTT(tx, len, lim, -1);
	poly res;
	for (int i = 0; i < szx + szy - 1; i++) res.pb(tx[i]);
	return res;
}

pair<poly, poly> conq(int l, int r) {
	if (l == r) return mp(Plus(p[l], 1), p[l]);
	int mid = (l + r) >> 1;
	auto lhs = conq(l, mid), rhs = conq(mid + 1, r);
	return mp(Plus(Conv(Plus(rhs.fi, P - 1), lhs.se), lhs.fi), Conv(lhs.se, rhs.se));
}

void init(int lim) {
	inv[1] = inv[0] = 1;
	for (int i = 2; i <= lim; i++)
		inv[i] = Mul(inv[P % i], P - P / i);
}

int hvy(int x) { return son[x] == ch[x][1]; }
int lht(int x) { return son[x] != ch[x][1]; }
#define lh(x) ch[x][lht(x)]

void dfs1(int u, int fat) {
	fa[u] = fat, sz[u] = 1;
	int ct = 0;
	for (int v : g[u]) {
		if (v == fat) continue;
		dfs1(v, u), sz[u] += sz[v], ch[u][ct++] = v;
		if (sz[v] > sz[son[u]]) son[u] = v;
	}
}

void dfs2(int u, int pr) {
	cn[pr].pb(u);
	if (son[u]) dfs2(son[u], pr);
	if (lh(u)) dfs2(lh(u), lh(u));
}

void dfs3(int u) {
	if (!son[u]) return dp[u].resize(2), dp[u][0] = dp[u][1] = 1, void();
	for (int v : cn[u]) 
		if (lh(v)) dfs3(lh(v));
	p.clear();
	for (int v : cn[u]) {
		int fl = 0;
		if (lh(v)) dp[lh(v)].ins(dp[lh(v)].begin(), 0), p.pb(dp[lh(v)]), fl = 1;
		if (!fl) { poly tp; tp.resize(2), tp[1] = 1, p.pb(tp); }
	}
	dp[u] = conq(0, p.size() - 1).fi;
}

int main() {
	n = rd(), m = rdl() % P, init(n);
	for (int i = 1, u, v; i < n; i++)
		u = rd(), v = rd(), g[u].pb(v), g[v].pb(u);
	dfs1(1, 0), dfs2(1, 1), dfs3(1);
	int ans = 0;
	for (int i = 1, fc = 1; i < (int)dp[1].size(); i++) 
		Addi(ans, Mul(fc, dp[1][i])), Muli(fc, Add(m, i)), Muli(fc, inv[i]);
	wr(ans);
    return 0;
}