题解 P3248 [HNOI2016]树

发布时间 2023-07-17 21:39:42作者: HQJ2007

有意思的题,927ms 拿下最优解。

点数最多 \(10^{10}\) 个,没法暴力拼接,考虑简化大树。

每次拼接,我们记录 \(x\)\(to\)\(to\) 所在大树的根节点 \(rt\)。然后连两条边: \((rt,to)\)\((to,x)\)。本质上相当于把每次接上来的子树缩成一个点。

这样大树的点数最多就只有 \(2\times 10^5\) 个了。

考虑询问。我们定义 \(tu,tv\) 为两点在模板树上的编号,\(u,v\) 为大树的编号。然后对于一个节点 \(u\),我们可以通过二分+主席树算出 \(tu\),方法显然不再赘述。

每次询问给定 \(u,v\)。分三种情况讨论。

  1. \(u,v\) 在同一颗子树内:直接算出 \(tu,tv\) 在模板树中计算答案。

  2. \(u,v\) 中一个为另一个的祖先:令 \(u\) 为大树中深度更深的节点,求出其在对应子树中的深度,跳到离 \(v\) 最近的节点,然后转化到模板树上求距离。

  3. 其他情况:分别求出 \(u,v\) 在对应子树中的深度,定义 \(g\)\(u,v\) 在大树上的 LCA,然后让 \(u,v\) 分别跳到离 \(g\) 最近的节点,转化到模板树上求距离即可。

复杂度 \(O(n\log n)\)。注意编号不要搞混。

code:

//fa[i][j]模板树,adj[u]模板树,fa2[i][j]新树
//siz[u]模板树
//dep[u]模板树u深度,dep2[u]新树节点u的深度,dis[u]新树节点u在大树上的深度
//dfn[u],r[u]模板树的dfs序
//a[i]dfs序为i的节点编号
//rk[i]模板树节点i在子树中的排名
//rt[i]模板树上主席树,t[u]主席树
//c[i]模板树树桩数组
//cnt为dfn序计数变量
//tot为新树编号计数变量
//num为主席树编号计数变量
//sum[i]进行i次拼接操作后大树的大小
//tx,ty模板树编号
//x,y大树编号
//dy[x]新树节点x对应的模板树编号
//head[i]第i次拼接的根的新树编号
//g[i]计算模板树节点排名的过程数组
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e5+5;
int n,m,q,fa[N][22],siz[N],dep[N],dep2[N],r[N],fa2[N][22],a[N],rk[N],rt[N],dfn[N],num=0,dy[N],c[N];
int tot=1,cnt=0;
ll sum[N],head[N],dis[N];
vector<int>adj[N];
struct node{int val,op;node(int val=0,int op=0):val(val),op(op){}};
vector<node>g[N];
struct tree{int ch[2],d;}t[N*30];
void dfs(int u,int lst){
  dep[u]=dep[lst]+1;fa[u][0]=lst;siz[u]=1;dfn[u]=++cnt;a[cnt]=u;
  for(int i=1;i<=20;++i)fa[u][i]=fa[fa[u][i-1]][i-1];
  for(int i=0;i<adj[u].size();++i){
    int v=adj[u][i];if(v==lst)continue;
    dfs(v,u);siz[u]+=siz[v];
  }
  r[u]=cnt;
  g[dfn[u]-1].push_back(node(u,-1));
  g[r[u]].push_back(node(u,1));
}
int lca(int u,int v){
  if(dep[u]<dep[v])swap(u,v);
  for(int i=20;i>=0;--i)if(dep[fa[u][i]]>=dep[v])u=fa[u][i];
  if(u==v)return u;
  for(int i=20;i>=0;--i)if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
  return fa[u][0];
}
int lbt(int x){return x&(-x);}
void update(int i){for(;i<=n;i+=lbt(i))++c[i];}
int getsum(int i){
  int res=0;
  for(;i;i-=lbt(i))res+=c[i];
  return res;
}
void add(int v,int u,ll w){
  dis[v]=dis[u]+w;dep2[v]=dep2[u]+1;
  fa2[v][0]=u;
  for(int i=1;i<=20;++i)fa2[v][i]=fa2[fa2[v][i-1]][i-1];
}
void change(int x,int &y,int l,int r,int k){
  y=++num;t[y]=t[x];
  if(l==r){++t[y].d;return;}
  int mid=(l+r)>>1;
  if(k<=mid)change(t[x].ch[0],t[y].ch[0],l,mid,k);
  else change(t[x].ch[1],t[y].ch[1],mid+1,r,k);
  t[y].d=t[t[y].ch[0]].d+t[t[y].ch[1]].d;
}
int query(int x,int y,int l,int r,int k){
  if(l==r)return l;
  int tmp=t[t[y].ch[0]].d-t[t[x].ch[0]].d,mid=(l+r)>>1;
  if(tmp>=k)return query(t[x].ch[0],t[y].ch[0],l,mid,k);
  else return query(t[x].ch[1],t[y].ch[1],mid+1,r,k-tmp);
}
int sonv;
inline pair<int,int> lca2(int u,int v){
  bool flag=0;
  if(dep2[u]<dep2[v])swap(u,v),flag=1;
  for(int i=20;i>=0;--i)if(dep2[fa2[u][i]]>dep2[v])u=fa2[u][i];
  sonv=u;
  if(dep2[u]>dep2[v])u=fa2[u][0];
  if(u==v)return make_pair(u,v);
  for(int i=20;i>=0;--i)if(fa2[u][i]!=fa2[v][i])u=fa2[u][i],v=fa2[v][i];
  if(flag)swap(u,v);
  return make_pair(u,v);
}
int P;
inline int getreal(ll x,int i){
  P=lower_bound(sum,sum+i+1,x)-sum;
  int u=dy[head[P]];
  return query(rt[dfn[u]-1],rt[r[u]],1,n,x-(P?sum[P-1]:0));
}
inline ll read()
{
	ll x=0,f=1;char ch=getchar();
	while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
	while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
	return x*f;
}
int main(){
  n=read();m=read();q=read();
  for(int i=1;i<n;++i){
    int u,v;u=read();v=read();
    adj[u].push_back(v);adj[v].push_back(u);
  }
  dfs(1,0);
  for(int i=1;i<=n;++i){
    update(a[i]);
    for(int j=0;j<g[i].size();++j){
      node tmp=g[i][j];
      rk[tmp.val]+=tmp.op*getsum(tmp.val);
    }
  }
  for(int i=1;i<=n;++i)change(rt[i-1],rt[i],1,n,a[i]);
  sum[0]=n;dep2[1]=dis[1]=head[0]=dy[1]=1;
  for(int i=1;i<=m;++i){
    ll tx,x,ty,y;tx=read();y=read();sum[i]=sum[i-1]+siz[tx];
    x=rk[tx]+sum[i-1];
    ty=getreal(y,i);
    dy[++tot]=ty;
    add(tot,head[P],dep[ty]-dep[dy[head[P]]]);
    dy[++tot]=tx;head[i]=tot;
    add(tot,tot-1,1);
  }
  while(q--){
    ll u,v,ans=0;u=read();v=read();
    int tu=getreal(u,m),hu=head[P],gu=dy[hu];
    int tv=getreal(v,m),hv=head[P],gv=dy[hv];
    int Lca;
    if(hu==hv)ans=dep[tu]+dep[tv]-2*dep[lca(tu,tv)];
    else{
      pair<int,int>pr=lca2(hu,hv);
      if(pr.first==pr.second){
        if(pr.first==hu)swap(u,v),swap(tu,tv),swap(hu,hv),swap(gu,gv);
        ans+=dep[tu]-dep[gu]+dis[hu]-dis[sonv];
        u=dy[sonv];
        ans+=dep[u]+dep[tv]-2*dep[lca(u,tv)];
      }else{
        ans+=dep[tu]-dep[gu]+dep[tv]-dep[gv]+dis[hu]-dis[pr.first]+dis[hv]-dis[pr.second];
        u=dy[pr.first];v=dy[pr.second];
        ans+=dep[u]+dep[v]-2*dep[lca(u,v)];
      }
    }
    printf("%lld\n",ans);
  }
  return 0;
}