题意
给定一颗树,每个点有点权。求有多少对点对 \((x,y)\) 满足 \(x<y\) 且以 \(x\) 到 \(y\) 的简单路径上的所有点的点权作为边长,能围成一个凸多边形。
\(1 \leq n \leq 10^5\),\(1 \leq a_i \leq 10^9\)。
思路
遇到这种求合法路径条数的题一般都可以往点分治去向。
首先思考如何判断一堆边能否围成一个凸多边形。如果最长边已经大于了其他所有边之和,那么就算其他边连成一条直线都不可能与最长边收尾相连。极限情况便是取等,此时最长边和其他边围成的直线刚好重合。只要其他边再长一点,就可以让他们“凸”起来,从而满足条件。因此,一对 \((x,y)\) 是合法的,必须要满足: \(sum-mx>mx\),\(sum\) 即为 \((x,y)\) 的点权和,\(mx\) 为最大点权。移项可得:\(sum >2mx\)。
这时再来想如何统计答案。路径条数可以分为三类:
- 不同子树内两个点(路径经过重心)。
- 其中一个点为重心。
- 同一子树内两个点。
第三种交给递归处理。第二种可以看成一条过重心的路径和重心单独作为一个点的路径拼起来,也就变为了第一种路径。我们可以把从重心出发的路径记录下来,每条路径记录他的 \(sum,mx\)。然后任选两条拼起来,统计合法个数即可。然而这样可能会选到在同一子树内的两点,容斥一下,用 \(ans\) 减去每个子树中单独选的答案即可。
那么,如何判断两路径拼起来是否合法呢?根据上面推的式子可知,必须满足:\(sum_1+sum_2-a_u>2\max(mx1,mx2)\)。\(-a_u\) 是因为两路径端点的重心被多算了一次,需要减去。然而有 \(\max\) 存在,不太好做,考虑如何让 \(\max\) 去掉:可以按每条路径的 \(mx\) 从小到大排序,每次算当前路径时只考虑在他前面的路径,这样 \(\max(mx1,mx2)\) 就一定等于当前路径的 \(mx\) 了,\(\max\) 自然脱去。
设选的另一条路径点权和为 \(res\),当前为 \(now\),那么有 \(res>2mx-now+a_u\)。右边的一坨我们是知道的,所以所有满足上式的路径都成立,相当于算有多少个 \(res>2mx-now+a_u\)。可以用树状数组维护值域,单点修改,区间查询。注意 \(a_i\) 达到了 \(10^9\),离散化即可。
此题就愉快的结束了。
code
#include<bits/stdc++.h>
#define ll long long
#define mkp make_pair
using namespace std;
int n,m,k,cnt,head[200005];
int vis[200005];
ll a[200005],ans;
ll ls[200006],pl;
ll c[200005];
int t;
typedef pair<ll,ll>pii;
vector<pii>p;
struct edge{
int u,v,nxt;
}e[400005];
ll lowbit(ll x){
return x&-x;
}
void modify(int x,int lim,ll d){
for(int i=x;i<=lim;i+=lowbit(i))c[i]+=d;
}
ll query(int x){
ll res=0;
for(int i=x;i;i-=lowbit(i))res+=c[i];
return res;
}
void add(int u,int v){
e[++k].v=v;
e[k].nxt=head[u];
head[u]=k;
}
int get_size(int u,int fa){
int res=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa||vis[v])continue;
res+=get_size(v,u);
}
return res;
}
int get_wc(int u,int fa,int sz,int& wc){//求重心
int sum=1,mx=0;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa||vis[v])continue;
int rtt=get_wc(v,u,sz,wc);
sum+=rtt;
mx=max(mx,rtt);
}
mx=max(mx,sz-sum);
if(mx<=sz/2)wc=u;
return sum;
}
void get(int u,int fa,ll sum,ll mx){//记录每一条到重心的路径
mx=max(mx,a[u]);sum+=a[u];
p.push_back(mkp(mx,sum));
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa||vis[v])continue;
get(v,u,sum,mx);
}
}
ll calc(int u,int fa,ll ku){
while(!p.empty())p.pop_back();
get(u,fa,ku,ku);
ll res=0;pl=0;
for(auto tmp:p)ls[++pl]=tmp.second;
sort(ls+1,ls+1+pl);//离散化
pl=unique(ls+1,ls+1+pl)-ls-1;
sort(p.begin(),p.end());
ll nm=(ku?ku:a[u]);
for(auto now:p){
ll mx=now.first,sum=now.second;
int pos=upper_bound(ls+1,ls+1+pl,nm-sum+2*mx)-ls;
if(pos<=pl)res+=query(pl)-query(pos-1);//树状数组求答案
pos=lower_bound(ls+1,ls+1+pl,sum)-ls;
modify(pos,pl,1);
}
for(auto now:p){
ll sum=now.second;
int pos=lower_bound(ls+1,ls+1+pl,sum)-ls;
modify(pos,pl,-1);//树状数组清空
}
return res;
}
void dfs(int u){
get_wc(u,0,get_size(u,0),u);
vis[u]=1;
ans+=calc(u,0,0ll);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].v;
if(vis[v])continue;
ans-=calc(v,u,a[u]);//容斥
dfs(v);
}
}
int main(){
cin>>t;
while(t--){
cin>>n;
for(int i=1;i<=n;++i)scanf("%lld",&a[i]);
k=0;ans=0;
for(int i=1;i<=n;++i)vis[i]=head[i]=0;
for(int i=1;i<n;++i){
int a,b;
scanf("%d%d",&a,&b);
add(a,b);add(b,a);
}
dfs(1);
cout<<ans<<'\n';
}
return 0;
}