不难。
第一步肯定是求出直径 \(d\)。
然后能发现 \(d\bmod 2 = 0\) 时很好求。
可以先任意找到一条直径,再找到这个直径的中点,则容易知道以这个中点为根,其中的每个子树的节点与中点经过的边数最大值为 \(\frac{d}{2}\)。
所以能够得到每个子树内选两个点距离最大值也为 \(d - 2\),所以合法的方案只能为每个子树内选一个 \(\frac{d}{2}\) 距离的点或者不选,这样两段拼成 \(d\)。
设一共有 \(k\) 个子树 ,于是可以对于 \(1\le i\le k\) 的每个子树于求出子树里面与中点距离为 \(\frac{d}{2}\) 的点的个数 \(cnt_i\)。
则很容易求出总方案数,即每个子树都可以选一个或不选:\(\prod\limits_{i = 1}^k (cnt_i + 1)\);不合法的方案数也很好求,即只选了一个点或不选:\(\prod\limits_{i = 1}^k cnt_i + 1\);所以合法方案数也很好求啦:\(\prod\limits_{i = 1}^k (cnt_i + 1) - \sum\limits_{i - 1}^k cnt_i - 1\)。
考虑 \(d\bmod 2 = 1\) 怎么求,因为这时候直径的中点在边上,刚刚找点就不行了。
直径中点在边上,那直接对每个边开一个虚点,既没改变树的形态距离也满足了 \(d\bmod 2 = 0\),且虚点也不会被算入答案。
// lhzawa(https://www.cnblogs.com/lhzawa/)
#include<bits/stdc++.h>
using namespace std;
const int N = 4e5 + 10;
const long long mod = 998244353;
int n;
vector<int> ev[N];
int dep[N];
void dfsdep(int u, int fa) {
for (int v : ev[u]) {
// printf("%d -> %d\n", u, v);
if (v == fa) {
continue;
}
dep[v] = dep[u] + 1;
dfsdep(v, u);
}
return ;
}
int stk[N], top;
int fd = 0, d;
void dfsd(int u, int fa, int t) {
if (! fd) {
stk[++top] = u;
}
if (u == t) {
fd = 1;
}
for (int v : ev[u]) {
if (v == fa) {
continue;
}
dfsd(v, u, t);
}
if (! fd) {
top--;
}
return ;
}
int cnt[N];
void dfsdpu(int u, int fa, int top) {
cnt[top] += (dep[u] == d / 2);
for (int v : ev[u]) {
if (v == fa) {
continue;
}
dep[v] = dep[u] + 1;
dfsdpu(v, u, top);
}
return ;
}
int main() {
scanf("%d", &n);
function<void (int, int)> add = [](int u, int v) -> void {
ev[u].push_back(v);
return ;
};
int m = n;
for (int i = 1; i < n; i++) {
int x, y;
scanf("%d%d", &x, &y);
m++, add(m, x), add(x, m), add(m, y), add(y, m);
}
dfsdep(1, 0);
int s = 0;
for (int i = 1; i <= m; i++) {
// printf("%d ", dep[i]);
s = (dep[i] > dep[s] ? i : s);
}
// printf("\n");
dep[s] = 0;
dfsdep(s, 0);
int t = 0;
for (int i = 1; i <= m; i++) {
t = (dep[i] > dep[t] ? i : t);
}
dfsd(s, 0, t);
d = dep[t];
// printf("%d <-> %d = %d\n", s, t, d);
int rt = stk[(top + 1) >> 1];
// printf("rt = %d\n", rt);
for (int u : ev[rt]) {
dep[u] = 1;
dfsdpu(u, rt, u);
}
long long c = 1, h = 1;
for (int u : ev[rt]) {
c = 1ll * c * (cnt[u] + 1) % mod, h += cnt[u];
}
printf("%lld\n", (c - h + mod) % mod);
return 0;
}