题解 P8338 [AHOI2022] 排列

发布时间 2023-07-17 21:44:46作者: HQJ2007

恶心题。

每次操作,相当与把第 \(i\) 个数置换到 \(p_i\),于是可以连边。

因为 \(i\)\(p_i\) 互不相同,所以对于每一个点,有且仅有一条出边和一条入边,即若干个简单环。

那么最少操作 \(\operatorname{lcm}(a_1,a_2,a_3...a_{x-2},a_{x-1},a_x)\) 次点会都回到原位。其中 \(a_i\) 表示第 \(i\) 个环的大小。

交换两个点 \(i,j\) 相当于把 \(i\) 所在的环和 \(j\) 所在的环合并。

我们只关注环的数量和大小,所以可以将相同大小的环合并。

因为点数之和为n,所以不同大小的环最多有 \(\sqrt{n}\) 个。

每次计算 \(lcm\) 时,直接暴力删数加数即可。

实现时只需要记录每个质因子所出现的最大、次大和次次大的指数,因为每次只合并两个环。

复杂度 \(O(n\log n)\)

code:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=5e5+5,mod=1e9+7;
int T,n,x[N],cnt[N],cir[N],fa[N],b[N][3],vis[N][3],vis2[N],p[N];
struct node{
  int val,num;
  node(int val=0,int num=0):val(val),num(num){}
};
vector<node>y[N];
vector<int>type,tmp;
int ff(int u){return fa[u]==u?u:fa[u]=ff(fa[u]);}
ll ksm(ll xx,ll yy){
  ll res=1;
  while(yy){
    if(yy&1)res=res*xx%mod;
    yy>>=1;
    xx=xx*xx%mod;
  }
  return res;
}
int main(){
  ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
  cin>>T;
  int tot=0;
  for(ll i=2;i<=5e5;++i){
    if(!x[i])p[++tot]=i,x[i]=i;
    for(int j=1;j<=tot&&p[j]*i<=5e5;++j){
      x[p[j]*i]=p[j];
      if(i%p[j]==0)break;
    }
  }
  for(int i=2;i<=5e5;++i){
    int t=i;
    while(x[t]){
      if(y[i].size()&&y[i].back().val==x[t])++y[i][y[i].size()-1].num;
      else y[i].push_back(node(x[t],1));
      t/=x[t];
    }
  }
  while(T--){
    cin>>n;
    type.clear();
    for(int i=1;i<=n;++i)fa[i]=i,cnt[i]=cir[i]=b[i][0]=b[i][1]=b[i][2]=0;
    for(int i=1;i<=n;++i){
      int a;cin>>a;
      fa[ff(i)]=ff(a);
    }
    tmp.clear();
    for(int i=1;i<=n;++i)++cnt[ff(i)];
    for(int i=1;i<=n;++i)if(cnt[i])++cir[cnt[i]],tmp.push_back(cnt[i]);
    for(int i=1;i<=n;++i)if(cir[i])type.push_back(i);
    for(int i=0;i<tmp.size();++i){
      int v=tmp[i];
      for(int j=0;j<y[v].size();++j){
        int w=y[v][j].val,t=y[v][j].num;
        if(t>=b[w][0])b[w][2]=b[w][1],b[w][1]=b[w][0],b[w][0]=t;
        else if(t>=b[w][1])b[w][2]=b[w][1],b[w][1]=t;
        else b[w][2]=max(b[w][2],t);
      }
    }
    ll z=1,res=0;
    for(int i=2;i<=n;++i)z=z*ksm(i,b[i][0])%mod;
    for(int i=0;i<type.size();++i){
      for(int j=i;j<type.size();++j){
        int t=type[i],t2=type[j];ll ans=z;
        tmp.clear();
        for(int l=0;l<y[t].size();++l){
          int w=y[t][l].val,num=y[t][l].num;
          if(!vis2[w])tmp.push_back(w),vis2[w]=1;
          if(b[w][0]==num&&!vis[w][0])vis[w][0]=1;
          else if(b[w][1]==num)vis[w][1]=1;
        }
        for(int l=0;l<y[t2].size();++l){
          int w=y[t2][l].val,num=y[t2][l].num;
          if(!vis2[w])tmp.push_back(w),vis2[w]=1;
          if(b[w][0]==num&&!vis[w][0])vis[w][0]=1;
          else if(b[w][1]==num)vis[w][1]=1;
        }
        for(int l=0;l<tmp.size();++l){
          int w=tmp[l];
          for(int k=0;k<2&&vis[w][k]==1;++k){
            ans=ans*ksm(ksm(w,b[w][k]-b[w][k+1]),mod-2)%mod;
          }
        }
        ll tt=t+t2;
        if(tt<=n){
          for(int l=0;l<y[tt].size();++l){
            int w=y[tt][l].val,num=y[tt][l].num,k=0;
            while(vis[w][k]==1)++k;
            if(num>b[w][k])ans=ans*ksm(w,num-b[w][k])%mod;
          }
        }
        if(t==t2)res=(res+(ll)cir[t]*t%mod*(cir[t]-1)%mod*t%mod*ans%mod)%mod;
        else res=(res+(ll)2*cir[t]*t%mod*cir[t2]%mod*t2%mod*ans%mod)%mod;
        for(int l=0;l<tmp.size();++l){
          int w=tmp[l];
          vis2[w]=vis[w][0]=vis[w][1]=0;
        }
      }
    }
    cout<<res<<endl;
  }
  return 0;
}