Codeforces 1868C/1869E Travel Plan 题解 | 巧妙思路与 dp

发布时间 2023-09-13 22:24:39作者: bringlu

题目链接:Travel Plan

题目大意:\(n\) 个点的完全二叉树,每个点可以分配 \(1 \sim m\) 的点权,定义路径价值为路径中最大的点权,求所有路径的价值和。

对于任意长度(这里主要指包括几个节点)的路径 \(t\),最大点权不超过 \(k\) 的方案数有 \(k^t\) 个, 因此最大点权恰好为 \(k\) 的方案数有 \(k^t - (k-1)^t\)。所以,对于任意一条长度为 \(t\) 的路径,不考虑不在路径上其他点的影响时,其对于答案的贡献为:

\[\begin{aligned} \text{path contribution}_t &= \sum_{k=1}^m (k^t - (k-1)^t) \cdot k \\ &= \sum_{k=1}^m \left( k^{t+1} - (k-1)^{t+1} - (k-1)^t \right) \\ &= m^{t+1} - \sum_{k=1}^{m-1} k^t \end{aligned} \]

由于路径长度不会超过 \(2 \log n\),因此求出全部长度路径分别对于答案的贡献时间复杂度为 \(O(m \log \log n)\)

事实上,对于上面式子的第二项,可以用 Lagrange 插值、伯努利数、多项式等方法可以优化到 \(O(\log^2 n)\)

下一步,问题转化为求出路径长度为 \(t\) 的个数分别是多少,然后乘一下即可。

第一种方法是点分治,显然复杂度是不够的,因为有 \(O(n \log n)\)

第二种方法是题解做法。

首先,在这个完全二叉树中,不同形状的子二叉树共有 \(O(\log n)\) 个,设叶子个数为 \(leaf_i\),那么其中包括两种类型:

  1. \(leaf_i = 2^{p-1}\) 时(\(p\) 是这个子二叉树的最大深度),那么以 \(i\) 为根的子树是一个完全二叉树,显然有 \(O(\log n)\) 个。
  2. \(leaf_i \not = 2^{p-1}\) 时,节点 \(i\) 的左右儿子必有一个满足其为 \(2\) 的幂次,而另一个不满足,以这样的点为根的子树中的根可以脑补为一条链的形状,因此也有 \(O(\log n)\) 个。

不妨设 \(dp_{i,j}\) 表示以 \(i\) 为根的子树中长度为 \(j\) 的路径个数,\(f_{i,j}\) 表示以 \(i\) 为根的子树中,以 \(i\) 为结束端点长度为 \(j\) 的路径个数。满二叉树时,转移方程应该为:

\[\begin{aligned} f_{i,1} &= 1 \\ f_{i,j} &= f_{lson(i), j-1} + f_{rson(i), j-1} (j \geq 2) \\ dp_{i,1} &= size_i \\ dp_{i,j} &= dp_{lson(i),j} + dp_{rson(i), j} + \sum_{k=0}^{j-1} f_{lson(i), k} \times f_{rson(i), j - 1 - k} (j \geq 2) \\ \end{aligned} \]

具体实现的时候,事实上一共 \(O(\log n)\) 个点,因此第二部分的算法复杂度为 \(O(\log^3 n)\),这里也可以用 FFT 优化这个式子做到 \(O(\log^2 n)\)

不过官方题解说可以做到。

最后一步,由于第一步中没有考虑不在路径上的其他点的方案影响,因此需要乘上去。

\[ans = \sum_{t=1}^{\text{max path length}} dp_{1, t} \times \text{path contribution}_t \times m^{n-t} \]

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef double db;
typedef long double ld;

#define IL inline
#define fi first
#define se second
#define mk make_pair
#define pb push_back
#define SZ(x) (int)(x).size()
#define ALL(x) (x).begin(), (x).end()
#define dbg1(x) cout << #x << " = " << x << ", "
#define dbg2(x) cout << #x << " = " << x << endl

template<typename Tp> IL void read(Tp &x) {
    x=0; int f=1; char ch=getchar();
    while(!isdigit(ch)) {if(ch == '-') f=-1; ch=getchar();}
    while(isdigit(ch)) { x=x*10+ch-'0'; ch=getchar();}
    x *= f;
}
int buf[42];
template<typename Tp> IL void write(Tp x) {
    int p = 0;
    if(x < 0) { putchar('-'); x=-x;}
    if(x == 0) { putchar('0'); return;}
    while(x) {
        buf[++p] = x % 10;
        x /= 10;
    }
    for(int i=p;i;i--) putchar('0' + buf[i]);
}

const int LOGN = 65;
const int LOGNN = 150;
const int mod = 998244353;

ll n;
int m, dpid_cnt = 0;

int pathcon[LOGNN];
int f[LOGNN][LOGNN], dp[LOGNN][LOGNN];

ll ksm(ll a, ll b) {
    ll ret = 1;
    while (b) {
        if (b & 1ll) ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1ll;
    }
    return ret;
}

pair<int, ll> depl(ll u) {
    if ((u << 1ll) > n) {
        return mk(1, u);
    }
    auto p = depl(u << 1ll);
    return mk(p.fi + 1, p.se);
}

int depr(ll u) {
    if ((u << 1ll | 1ll) > n) {
        return 1;
    }
    return depr(u << 1ll | 1ll) + 1;
}

bool fulltree(ll u) {
    return ((u << 1ll) > n && (u << 1ll | 1ll) > n) || (depl(u << 1ll).fi == depr(u << 1ll | 1ll));
}

ll getsz(ll u) {
    if ((u << 1ll) > n) return 1;
    if ((u << 1ll | 1ll) > n) return 2;
    auto p = depl(u);
    int dr = depr(u);
    // dbg1(u); dbg1(p.fi); dbg1(p.se); dbg1(dr); dbg1((1ll << (1ll * dr)) - 1); dbg2((1ll << (1ll * dr)) - 1 + (n - p.se + 1));
    if (p.fi == dr) return (1ll << (1ll * dr)) - 1;
    else {
        return (1ll << (1ll * dr)) - 1 + (n - p.se + 1);
    }
}

unordered_map<ll, int> dpid, szcnt;

void dfs(ll u) {
    int uid;
    ll szu = getsz(u);
    if (dpid.count(szu) == 0) dpid[szu] = uid = ++dpid_cnt;
    else return;
    
    f[uid][0] = f[uid][1] = 1; dp[uid][0] = 1;
    dp[uid][1] = szu % mod;
    if (!fulltree(u)) szcnt[u] = 1;

    if ((u << 1ll) > n) return;
    else if((u << 1ll | 1ll) > n) {
        dfs(u << 1ll);
        f[uid][2] = dp[uid][2] = 1;
        return;
    }

    dfs(u << 1ll); dfs(u << 1ll | 1ll);

    int lid = dpid[getsz(u << 1ll)], rid = dpid[getsz(u << 1ll | 1ll)];
    for (int j = 2; j <= 2 * LOGN; j++) {
        f[uid][j] = (f[lid][j-1] + f[rid][j-1]) % mod;
        dp[uid][j] = (dp[lid][j] + dp[rid][j]) % mod;
        for (int k = 0; k < j; k++) {
            dp[uid][j] = (dp[uid][j] + 1ll * f[lid][k] * f[rid][j - 1 - k]) % mod;
        }
    }
}

void solve() {
    dpid_cnt = 0; dpid.clear(); szcnt.clear();
    memset(pathcon, 0, sizeof(pathcon));
    memset(f, 0, sizeof(f));
    memset(dp, 0, sizeof(dp));
    read(n); read(m);
    for (int t = 0; t <= (LOGN << 1); t++) {
        pathcon[t] = ksm(m, t + 1);
        for (int k = 1; k < m; k++) {
            pathcon[t] = (1ll * pathcon[t] - ksm(k, t) + mod) % mod;
        }
    }

    dfs(1);

    int ans = 0;
    for (int t = 1; t <= min(n, 2ll * LOGN); t++) {
        if (dp[1][t] == 0) break;
        ans = (ans + 1ll * dp[1][t] * pathcon[t] % mod * ksm(m, n - t)) % mod;
    }
    write(ans); putchar(10);
}

int main() {
#ifdef LOCAL
    freopen("test.in", "r", stdin);
    // freopen("test.out", "w", stdout);
#endif
    int T = 1;
    read(T);
    while(T--) solve();
    return 0;
}