Codeforces 1458F - Range Diameter Sum

发布时间 2023-06-29 14:57:45作者: tzc_wk

先考虑直径的一些求法:最普遍的想法肯定是从点集中任意一个点开始 DFS 找到距其最远的点,再一遍 DFS 找到距离你找到的那个点最远的点。但是放在这个题肯定是不太行的。因此考虑一种更常用的求法:合并。更直观地说:我们定义树上一个圆 \((x,r)\) 表示距离 \(x\)\(\le r\) 的所有点组成的集合(当然,这里的 \(x\) 不一定是整点,也可能是某条边的中点,\(r\) 也不一定是整数,也可能是某个整数 \(+0.5\),但是把每条边长度乘 \(2\) 后在边上新建一个点则可以避免这些问题)。那么显然一个点集的直径就是覆盖这个点集的最小圆的直径。现在考虑怎么合并两个圆:

  • 如果一个圆完全包含了另一个圆,则返回大圆。
  • 否则,圆心为 \(\text{go}(x_1,x_2,\dfrac{1}{2}(\text{dis}(x_1,x_2)-r_1+r_2))\),半径为 \(\dfrac{1}{2}(\text{dis}(x_1,x_2)+r_1+r_2)\)

现在考虑怎么计算所有区间的答案。采用枚举右端点维护左端点的做法不可取,因为 modify 方式很复杂。因此考虑另一种求法——分治。预处理 \([l,mid],[mid+1,r]\) 对应的圆,然后枚举左端点,由于随着右端点的增加圆肯定不断变大,所以右端点可以分成三段:\([l,mid]\) 对应的圆完全包含 \([mid+1,r]\) 对应的圆,不存在包含关系和 \([l,mid]\) 对应的圆完全被 \([mid+1,r]\) 对应的圆包含,用 two pointers 维护三段分界点,左右两段显然可以 \(O(1)\) 求得,中间那段相当于要你维护一个点集,支持加入删除一个点并查询点集中所有点到给定点的距离,点分树板子硬上。

时间复杂度 2log。

const int MAXN=2e5;
const int LOG_N=18;
const int MAXD=50;
const int INF=0x3f3f3f3f;
int n,hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int fa[MAXN+5][LOG_N+2],dep[MAXN+5];
void dfs0(int x,int f){
	fa[x][0]=f;
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==f)continue;
		dep[y]=dep[x]+1;dfs0(y,x);
	}
}
int getlca(int x,int y){
	if(dep[x]<dep[y])swap(x,y);
	for(int i=LOG_N;~i;i--)if(dep[x]-(1<<i)>=dep[y])x=fa[x][i];
	if(x==y)return x;
	for(int i=LOG_N;~i;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
int getdis(int x,int y){return dep[x]+dep[y]-2*dep[getlca(x,y)];}
namespace Centroid_Decomposition{
	int siz[MAXN+5],cent,mx[MAXN+5],vis[MAXN+5],dis[MAXD+5][MAXN+5],dfa[MAXN+5],D[MAXN+5];
	void findcent(int x,int f,int totsz){
		siz[x]=1;mx[x]=0;
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e];if(y==f||vis[y])continue;
			findcent(y,x,totsz);siz[x]+=siz[y];chkmax(mx[x],siz[y]);
		}chkmax(mx[x],totsz-siz[x]);
		if(mx[x]<mx[cent])cent=x;
	}
	void dfs_tree(int x,int f,int D){
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e];if(y==f||vis[y])continue;
			dis[D][y]=dis[D][x]+1;dfs_tree(y,x,D);
		}
	}
	void divcent(int x){
		vis[x]=1;dfs_tree(x,0,D[x]);
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e];if(vis[y])continue;cent=0;
			findcent(y,x,siz[y]);dfa[cent]=x;D[cent]=D[x]+1;
			divcent(cent);
		}
	}
	ll mark1[MAXN+5],mark2[MAXN+5],mark3[MAXN+5];
	void ins(int x,int y){
		for(int i=x;i;i=dfa[i]){
			mark1[i]+=y*dis[D[i]][x];mark3[i]+=y;
			if(dfa[i])mark2[i]+=y*dis[D[dfa[i]]][x];
		}
	}
	ll query(int x){
		ll res=0;
		for(int i=x,pre=0;i;pre=i,i=dfa[i])
			res+=mark1[i]-mark2[pre]+1ll*dis[D[i]][x]*(mark3[i]-mark3[pre]);
		return res;
	}
	void init_CD(){
		mx[0]=INF;findcent(1,0,n*2-1);divcent(cent);
//		for(int i=1;i<n*2;i++)printf("%d%c",dfa[i]," \n"[i==n*2-1]);
	}
}
using namespace Centroid_Decomposition;
ll res=0;
int get_kanc(int x,int k){for(int i=LOG_N;~i;i--)if(k>>i&1)x=fa[x][i];return x;}
int go(int x,int y,int k){
	int lc=getlca(x,y);
	if(k<=dep[x]-dep[lc])return get_kanc(x,k);
	else return get_kanc(y,(dep[x]+dep[y]-dep[lc]*2)-k);
}
struct circ{
	int x,r;
	circ(){x=r=0;}
	circ(int _x,int _r){x=_x;r=_r;}
	friend circ operator +(const circ &X,const circ &Y){
		int d=getdis(X.x,Y.x);
		if(X.r>=d+Y.r)return X;
		if(Y.r>=d+X.r)return Y;
		return circ(go(X.x,Y.x,(d-X.r+Y.r)/2),(d+X.r+Y.r)/2);
	}
}pre[MAXN+5],suf[MAXN+5];
void solve(int l,int r){
	if(l==r)return;int mid=l+r>>1;solve(l,mid);solve(mid+1,r);
	pre[mid+1]=circ(mid+1,0);suf[mid]=circ(mid,0);
	for(int i=mid+2;i<=r;i++)pre[i]=pre[i-1]+circ(i,0);
	for(int i=mid-1;i>=l;i--)suf[i]=suf[i+1]+circ(i,0);
	int p1=mid+1,p2=mid+1;ll c1=0,c2=0,s2=0,s3=0;
	for(int i=mid+1;i<=r;i++)s3+=pre[i].r;
	for(int i=mid;i>=l;i--){
		while(p2<=r&&pre[p2].r<getdis(suf[i].x,pre[p2].x)+suf[i].r){
			ins(pre[p2].x,1);c2++;s2+=pre[p2].r;s3-=pre[p2].r;++p2;
		}
		while(p1<p2&&suf[i].r>=getdis(suf[i].x,pre[p1].x)+pre[p1].r){
			ins(pre[p1].x,-1);c2--;s2-=pre[p1].r;c1++;++p1;
		}
		res+=2ll*c1*suf[i].r+2ll*s3+1ll*c2*suf[i].r+s2+query(suf[i].x);
	}
	for(int i=p1;i<p2;i++)ins(pre[i].x,-1);
}
int main(){
	scanf("%d",&n);
	for(int i=1,u,v;i<n;i++){
		scanf("%d%d",&u,&v);
		adde(u,i+n);adde(i+n,u);adde(v,i+n);adde(i+n,v);
	}
	dfs0(1,0);
	for(int i=1;i<=LOG_N;i++)for(int j=1;j<2*n;j++)
		fa[j][i]=fa[fa[j][i-1]][i-1];
	init_CD();solve(1,n);printf("%lld\n",res/2);
	return 0;
}