CF1762E Tree Sum 题解

发布时间 2023-08-21 19:03:40作者: User-Unauthorized

题意

对于一棵 \(n\) 个节点的树 \(T\),定义 \(\operatorname{good}(T)\) 为真当且仅当边权 \(w \in \left\{-1,1\right\}\) 且对于任意节点 \(u\),均有 \(\displaystyle f(u) = \prod\limits_{\left(u, v\right) \in E} w\left(u, v\right) = -1\)

\[\sum\limits_{\operatorname{good}(T)} \operatorname{dist}(1, n) \bmod 998244353 \]

题解

分析题目性质。

性质 \(1\)

\(n\) 为奇数,那么不存在符合性质的树。

因为题目中要求对于任意节点 \(u\) 均有 \(f(u) = -1\),所以当 \(n\) 为奇数时,\(\displaystyle \prod\limits_{i = 1}^{n} f(i) = \left(-1\right)^n = -1\)

考虑计算每条边对 \(\displaystyle \prod\limits_{i = 1}^{n} f(i)\) 的贡献,可以得出对于任意一条边 \(\left(u, v\right) \in E\),无论其边权如何,均会对 \(f(u)\)\(f(v)\) 产生共两次贡献,即产生的贡献为 \(w\left(u, v\right)^2 = 1\),进而可以得出 \(\displaystyle \prod\limits_{i = 1}^{n} f(i) = 1\),与 \(\displaystyle \prod\limits_{i = 1}^{n} f(i) = \left(-1\right)^n = -1\) 冲突。

性质 \(2\)

若已经确定了树的形态,那么仅有一种边权分配方式使其符合要求。

\(\operatorname{pw}(u) = w\left(u, fa_u\right)\),即 \(u\) 向其父亲节点相连的边的边权。考虑首先钦定任意一点为根,然后从叶子节点向上归纳,对于叶子节点 \(v\),显然有 \(\displaystyle f(v) = \operatorname{pw}(v) \Rightarrow \operatorname{pw}(v) = -1\)。对于任意非叶子节点,有 \(\displaystyle f(u) = -1 \Rightarrow \operatorname{pw}(u) = \left(-1\right)\prod\limits_{v \in \operatorname{son}_u} \operatorname{pw}(v)\)。由于树的形态唯一,所以每个节点的 \(\displaystyle \operatorname{pw}(u)\) 唯一,进而边权分配方式唯一。

性质 \(3\)

若边 \(\left(u, v\right)\) 一侧有 \(k\) 个节点,那么该边边权为 \(\left(-1\right)^k\)

该性质等价于 \(pw(u) = \left(-1\right)^{size_u}\),考虑使用数学归纳法证明。该结论对于叶子节点显然,对于非叶子节点 \(u\),有

\[\operatorname{pw}(u) = \left(-1\right)\prod\limits_{v \in \operatorname{son}_u} \operatorname{pw}(v) = \left(-1\right)^{1 + \sum\limits_{v \in \operatorname{son}_u} size_v} = \left(-1\right)^{size_u} \]


接下来计算答案,考虑按边枚举贡献。

对于一条边 \(\left(u, v\right)\),如果其能产生贡献,那么该边一定联通点 \(1\) 和 点 \(n\) 所属的联通块,我们设点 \(1\) 所属的连通块有 \(k\) 个点,那么根据上文推出的性质可得这条边的边权为 \(\displaystyle \left(-1\right)^k\)。下面考虑有多少种符合要求的树 \(T\) 种会出现这类边(这里这类边指的是联通点 \(1\) 和 点 \(n\) 所属的联通块且点 \(1\) 所属的连通块有 \(k\) 个点的边)。

首先,因为节点是标号的,所以将点划分为两个连通块的方案数为 \(\displaystyle {n - 2 \choose k - 1}\)。对于每个连通块考虑子树的形态,有 \(\displaystyle k^{k - 2} \left(n - k\right)^{n - k - 2}\) 种方案。考虑边联通的为哪个点对,有 \(\displaystyle k \times \left(n - k\right)\) 种方案。所以对于这类边,会包括他的树的种类数为

\[{n - 2 \choose k - 1}k^{k - 2} \left(n - k\right)^{n - k - 2}k \times \left(n - k\right) \]

通过枚举边的种类 \(k\),可以得出总的答案

\[\sum\limits_{k = 1}^{n - 1} {n - 2 \choose k - 1}k^{k - 1} \left(n - k\right)^{n - k - 1} \]

该式可以以 \(\displaystyle \mathcal{O}(n \log n)\) 的复杂度完成计算,可以通过本题。

Code

//Codeforces - 1762E
#include <bits/stdc++.h>

typedef long long valueType;
typedef std::vector<valueType> ValueVector;

constexpr valueType MOD = 998244353;

bool ModOperSafeModOption = false;

template<typename T1, typename T2, typename T3 = valueType>
void Inc(T1 &a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        b %= mod;

        if (a < 0)
            a += mod;

        if (b < 0)
            b += mod;
    }

    a = a + b;

    if (a >= mod)
        a -= mod;
}

template<typename T1, typename T2, typename T3 = valueType>
void Dec(T1 &a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        b %= mod;

        if (a < 0)
            a += mod;

        if (b < 0)
            b += mod;
    }

    a = a - b;

    if (a < 0)
        a += mod;
}

template<typename T1, typename T2, typename T3 = valueType>
T1 sum(T1 a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        b %= mod;

        if (a < 0)
            a += mod;

        if (b < 0)
            b += mod;
    }

    return a + b >= mod ? a + b - mod : a + b;
}

template<typename T1, typename T2, typename T3 = valueType>
T1 sub(T1 a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        b %= mod;

        if (a < 0)
            a += mod;

        if (b < 0)
            b += mod;
    }

    return a - b < 0 ? a - b + mod : a - b;
}

template<typename T1, typename T2, typename T3 = valueType>
T1 mul(T1 a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        b %= mod;

        if (a < 0)
            a += mod;

        if (b < 0)
            b += mod;
    }

    return (long long) a * b % mod;
}

template<typename T1, typename T2, typename T3 = valueType>
void Mul(T1 &a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        b %= mod;

        if (a < 0)
            a += mod;

        if (b < 0)
            b += mod;
    }

    a = (long long) a * b % mod;
}

template<typename T1, typename T2, typename T3 = valueType>
T1 pow(T1 a, T2 b, const T3 &mod = MOD) {
    if (ModOperSafeModOption) {
        a %= mod;
        b %= mod - 1;

        if (a < 0)
            a += mod;

        if (b < 0)
            b += mod - 1;
    }

    T1 result = 1;

    while (b > 0) {
        if (b & 1)
            Mul(result, a, mod);

        Mul(a, a, mod);
        b = b >> 1;
    }

    return result;
}

class Inverse {
public:
    typedef ValueVector container;

private:
    valueType size;
    container data;
public:
    explicit Inverse(valueType n) : size(n), data(size + 1, 0) {
        data[1] = 1;

        for (valueType i = 2; i <= size; ++i)
            data[i] = mul((MOD - MOD / i), data[MOD % i]);
    }

    valueType operator()(valueType n) const {
        return data[n];
    }
};

int main() {
    valueType N;

    std::cin >> N;

    Inverse Inv(N);

    ValueVector Fact(N + 1, 1), InvFact(N + 1, 1);

    Fact[0] = 1;
    InvFact[0] = 1;
    for (valueType i = 1; i <= N; ++i) {
        Fact[i] = mul(Fact[i - 1], i);
        InvFact[i] = mul(InvFact[i - 1], Inv(i));
    }

    typedef std::function<valueType(valueType, valueType)> CalcFunction;

    CalcFunction C = [&Fact, &InvFact](valueType n, valueType m) -> valueType {
        if (n < 0 || m < 0 || n < m)
            return 0;

        return mul(Fact[n], mul(InvFact[m], InvFact[n - m]));
    };

    valueType ans = 0;

    for(valueType i = 1; i < N; ++i) {
        valueType sum = 1;

        Mul(sum, C(N - 2, i - 1));
        Mul(sum, mul(i, N - i));
        Mul(sum, mul(pow(i, i - 2), pow(N - i, N - i - 2)));

        if (i & 1)
            Dec(ans, sum);
        else
            Inc(ans, sum);
    }

    std::cout << ans << std::endl;

    return 0;
}