[NOI2021] 路径交点 题解

发布时间 2023-08-04 21:31:38作者: 霜木_Atomic

[NOI2021] 路径交点 题解

题意

给定一张 \(k\) 层的有向图,第 \(i\) 层有 \(n_i\)​ 个顶点,第 ​\(1\) 层与第 \(k\)​ 层顶点数相同。对于第 ​ ​\(j\) \((1 \leq j <k)\) 层的顶点,只会连向第 \(j+1\) 层的顶点。没有边连向第 \(1\) 层的顶点,第 \(k\) 层的顶点不会向其他顶点连边。

现在要选出 \(n_1\) 条路径,每条路径以第 \(1\) 层顶点为起点,第 \(k\) 层顶点为终点,并要求图中的每个顶点至多出现在一条路径中。

我们规定,第 \(i\) 层与第 \(i+1\) 层之间的两条路径 \(P\)\(Q\) 有交点,当且仅当 \(P_i-Q_i\)\(P_{i+1}-Q_{i+1}\) 异号。其中 \(P_i\)\(Q_i\) 表示第 \(i\) 层的顶点,\(P_{i+1}\)\(Q_{i+1}\) 表示第 \(i+1\) 层的顶点。

现让你求出有偶数个交点的路径方案数比有奇数个交点的路径方案数多多少个。

思路

首先,上面的“异号”这一条件完全可以转化为逆序对。奇偶,做差,逆序对……好像行列式欸。事实上,如果只有两层,答案就是行为左部点,列为右部点的行列式的值。我们从求行列式值的层面来考虑。一个 \(n\) 阶行列式 \(A\) 的值可以表示为

\[\sum_{p \in P} (-1)^{k} \prod_{i = 1}^{n} A_{i, p_i} \]

其中 \(p\) 为一个 \(1\)\(n\) 的排列,\(P\)\(1\)\(n\) 的全排列集合,\(k\) 为排列 \(p\) 中逆序对的个数。

我们来考虑这里连乘的含义,其实就是在用左部点去匹配右部点;而这里的逆序对,也就是题目中说的交点数量(这里可以自己画一个行列式理解一下)。那么,最后行列式的值,就是方案数之差。

现在来考虑 \(k\) 层的情况。我们发现,对于从左往右连续的三层 \(1, 2, 3\),有两条路径 \(P\)\(Q\),我们按总交点数分为两种情况。

  • 当总交点数为偶数时,则应有 \(P_1-Q_1\)\(P_2-Q_2\) 异号且 \(P_2-Q_2\)\(P_3-Q_3\) 异号;或 \(P_1-Q_1\)\(P_2-Q_2\) 同号且 \(P_2-Q_2\)\(P_3-Q_3\) 同号。而这两种情况下,\(P_1-Q_1\)\(P_3-Q_3\) 均为同号。
  • 当总交点数为奇数时,类似的,总有 \(P_1-Q_1\)\(P_3-Q_3\) 异号。
    也就是说交点的奇偶性并不会改变,这一性质同样可以推广到连续 \(k\) 层。

对于任意相邻的两层,我们都可以建立一个行为左部点,列为右部点的邻接矩阵,我们考虑将这 \(k-1\) 个邻接矩阵相乘,得到的新矩阵中的每个元素变为从 \(1\)\(k\) 层的路径方案数,而这时的逆序对数与交点数量奇偶性的关系仍然成立。我们对矩阵进行行列式求值,得到的就是最后答案。

代码:

#include<bits/stdc++.h>
using namespace std;
const int N = 205;
const int mod = 998244353;

inline int read(){
    int x = 0; char ch = getchar();
    while(ch<'0' || ch>'9') ch = getchar();
    while(ch>='0'&&ch<='9') x = x*10+ch-48, ch = getchar();
    return x;
}
inline int fpow(int a, int b){
    int ret = 1;
    a%=mod;
    while(b){
        if(b & 1){
            ret = (1ll*ret*a)%mod;
        }
        a = (1ll*a*a)%mod;
        b>>=1;
    }
    return ret;
}
struct mat{
    int f[N][N];
    int h, w;
    mat(){memset(f, 0, sizeof(f));}
    int *operator [](int x){return f[x];}
    void init(int th, int tw){
        h = th, w = tw;
    }
    void frs(){
        memset(f, 0, sizeof(f));
    }
    mat operator *(mat B){
        mat t, BT;
        int wb = B.w, hb = B.h;
        for(int i = 1; i<=hb; ++i){
            for(int j = 1; j<=wb; ++j){
                BT[j][i] = B[i][j];
            }
        }
        for(int i = 1; i<=h; ++i){
            for(int j = 1; j<=wb; ++j){
                for(int k = 1; k<=w; ++k){
                    t[i][j] = (1ll*t[i][j]+1ll*f[i][k]*BT[j][k]%mod)%mod;
                }
            }
        }
        t.init(h, wb);
        return t;
    }
    int calc_det(){
        int ret = 1;
        for(int i = 1; i<=h; ++i){
            if(!f[i][i]){
                for(int j = i+1; j<=h; ++j){
                    if(f[j][i]){
                        swap(f[i], f[j]);
                        ret = -ret;
                        break;
                    }
                }
            }
            int inv = fpow(f[i][i], mod-2);
            for(int j = i+1; j<=h; ++j){
                int tmp = 1ll*f[j][i]*inv%mod;
                for(int k = i; k<=h; ++k){
                    f[j][k] = (1ll*f[j][k]-1ll*f[i][k]*tmp%mod)%mod;
                    f[j][k] = (1ll*f[j][k]+mod)%mod;
                }
            }
        }
        for(int i = 1; i<=h; ++i){
            ret = (1ll*ret*f[i][i]%mod);
        }
        ret = (1ll*ret+mod)%mod;
        return ret;
    }
    void print(){
        for(int i = 1; i<=h; ++i, puts("")){
            for(int j = 1; j<=w; ++j){
                printf("%d ", f[i][j]);
            }
        }
    }
}a[N];

int K;
int m[N], n[N];
int T;
int main(){
    T = read();
    while(T--){
        K = read();
        for(int i = 1; i<=K; ++i){
            n[i] = read();
        }
        for(int i = 1; i<K; ++i){
            m[i] = read();
        }
        for(int i = 1; i<K; ++i){
            a[i].init(n[i], n[i+1]);
            a[i].frs();
            for(int j = 1; j<=m[i]; ++j){
                int u = read(), v = read();
                ++a[i][u][v];
            }
        }
        mat ans = a[1];
        for(int i = 2; i<K; ++i){
            ans = ans*a[i];
        }
        printf("%d\n", ans.calc_det());

    }
    return 0;
}