题解 ABC207F【Tree Patrolling】

发布时间 2023-06-14 20:40:28作者: rui_er

挺简单的树上背包,就是有点难写。

\({dp}_{u,i,x,y}\) 表示仅考虑 \(u\) 的子树内,有 \(i\) 个节点被控制,\(x\) 为节点 \(u\) 是否有警卫,\(y\) 为节点 \(u\) 是否被控制。(其实所有 \(x=1,y=0\) 的状态都没用,但我懒得管了。)

每个点 \(u\) 的初始值为 \({dp}_{u,0,0,0}={dp}_{u,1,1,1}=1\),其他为 \(0\)。因为初始时这个点的背包大小为 \(1\)(只有自己),分别对应没有警卫和有警卫。

\(v\)\(u\) 的儿子,考虑合并 \(v\) 的背包到 \(u\) 上。设 \(u,v\) 的背包大小分别为 \(p,q\)。下面设 \(i\in[0,p],j\in[0,q]\),考虑三种状态分别如何转移:

(一)状态 \(x=0,y=0\)

此时 \(v\) 上一定没有警卫,不然就会覆盖到 \(u\),因此:

\[{dp}_{u,i+j,0,0}\stackrel{+}{\gets}{dp}_{u,i,0,0}\times({dp}_{v,j,0,0}+{dp}_{v,j,0,1}) \]

(二)状态 \(x=0,y=1\)

有两种情况:第一种是此前 \(u\) 尚未被控制,直到 \(v\) 有警卫后 \(u\) 才被控制;第二种是此前 \(u\) 已经被控制。因此:

\[{dp}_{u,i+j,0,1}\stackrel{+}{\gets}{dp}_{u,i-1,0,0}\times{dp}_{v,j,1,1}+{dp}_{u,i,0,1}\times({dp}_{v,j,0,0}+{dp}_{v,j,0,1}+{dp}_{v,j,1,1}) \]

注意第一种情况 \(i\) 处要减一,因为 \(u\) 点的贡献需要被算进来。

(三)状态 \(x=1,y=1\)

这种情况对 \(v\) 没有限制,因此:

\[{dp}_{u,i+j,1,1}\stackrel{+}{\gets}{dp}_{u,i,1,1}\times({dp}_{v,j-1,0,0}+{dp}_{v,j,0,1}+{dp}_{v,j,1,1}) \]

注意这里 \(j\) 处又有一个减一,跟上面同理,\(v\) 点的贡献需要被算进来。

在完成 \(v\) 的背包向 \(u\) 的合并后,我们将 \(p\stackrel{+}{\gets}q\) 即可继续合并下一棵子树。

上面减一处均有数组越界的风险,这里认为所有不合法(负数下标)的状态的值均为 \(0\)

正确实现的树上背包的复杂度为 \(\mathcal O(n^2)\),原因是两个点只在 LCA 处被合并一次。正确实现树上背包的方法是精细枚举背包大小,也就是上文对 \(p,q\) 的处理。这一点在代码中是通过 \(sz\) 数组实现的。

//By: OIer rui_er
#include <bits/stdc++.h>
#define rep(x,y,z) for(ll x=(y);x<=(z);x++)
#define per(x,y,z) for(ll x=(y);x>=(z);x--)
#define debug(format...) fprintf(stderr, format)
#define fileIO(s) do{freopen(s".in","r",stdin);freopen(s".out","w",stdout);}while(false)
using namespace std;
typedef long long ll;

mt19937 rnd(std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::system_clock::now().time_since_epoch()).count());
ll randint(ll L, ll R) {
    uniform_int_distribution<ll> dist(L, R);
    return dist(rnd);
}

template<typename T> void chkmin(T& x, T y) {if(x > y) x = y;}
template<typename T> void chkmax(T& x, T y) {if(x < y) x = y;}

const ll N = 2e3+5, mod = 1e9+7;

ll n, dp[N][N][2][2], tmp[N][2][2], sz[N];
vector<ll> e[N];

void dfs(ll u, ll f) {
    dp[u][0][0][0] = dp[u][1][1][1] = 1;
    sz[u] = 1;
    for(ll v : e[u]) {
        if(v != f) {
            dfs(v, u);
            // debug("%lld -> %lld\n", u, v);
            rep(i, 0, sz[u]+sz[v]) {
                tmp[i][0][0] = 0;
                tmp[i][0][1] = 0;
                tmp[i][1][1] = 0;
            }
            per(i, sz[u], 0) {
                per(j, sz[v], 0) {
                    tmp[i+j][0][0] += dp[u][i][0][0] * (dp[v][j][0][0] + dp[v][j][0][1]);
                    tmp[i+j][0][0] %= mod;
                    if(i >= 1) {
                        tmp[i+j][0][1] += dp[u][i-1][0][0] * dp[v][j][1][1] + dp[u][i][0][1] * (dp[v][j][0][0] + dp[v][j][0][1] + dp[v][j][1][1]);
                        tmp[i+j][0][1] %= mod;
                        tmp[i+j][1][1] += dp[u][i][1][1] * ((j >= 1 ? dp[v][j-1][0][0] : 0) + dp[v][j][0][1] + dp[v][j][1][1]);
                        tmp[i+j][1][1] %= mod;
                    }
                }
            }
            rep(i, 0, sz[u]+sz[v]) {
                dp[u][i][0][0] = tmp[i][0][0];
                dp[u][i][0][1] = tmp[i][0][1];
                dp[u][i][1][1] = tmp[i][1][1];
            }
            sz[u] += sz[v];
            // rep(i, 0, sz[u]) debug("DP[%lld][%lld] = {%lld, %lld, %lld}\n", u, i, dp[u][i][0][0], dp[u][i][0][1], dp[u][i][1][1]);
        }
    }
    // debug("@ %lld\n", u);
    // rep(i, 0, sz[u]) debug("DP[%lld][%lld] = {%lld, %lld, %lld}\n", u, i, dp[u][i][0][0], dp[u][i][0][1], dp[u][i][1][1]);
}

int main() {
    // freopen("debug.log", "w", stderr);
    scanf("%lld", &n);
    rep(i, 1, n-1) {
        ll u, v;
        scanf("%lld%lld", &u, &v);
        e[u].push_back(v);
        e[v].push_back(u);
    }
    dfs(1, 0);
    rep(i, 0, n) printf("%lld\n", (dp[1][i][0][0]+dp[1][i][0][1]+dp[1][i][1][1])%mod);
    return 0;
}