AtCoder Grand Contest 057 E RowCol/ColRow Sort

发布时间 2023-10-07 14:33:10作者: zltzlt

洛谷传送门

AtCoder 传送门

首先考虑一个经典的套路:转 \(01\)。具体而言,我们考虑若值域是 \([0, 1]\) 怎么做。

发现可以很容易地判定一个 \(A\) 是否合法。设矩阵第 \(i\) 行的和为 \(r_i\),第 \(j\) 列的和为 \(c_j\),那么合法当且仅当 \(A\)\(\{r_i\}\)\(\{c_j\}\)(可重集)分别与 \(B\)\(\{r_i\}\)\(\{c_j\}\) 相同。并且 \(r_i, c_j\) 的每一种不同的排列方案都恰好对应一个可以被操作成 \(B\)\(A\)

那么值域为 \([0, 1]\) 时答案就是 \(\{r_i\}\)\(\{c_j\}\) 的可重集排列数相乘。

考虑值域为 \([0, 9]\) 的情况,考虑枚举 \(k \in [0, 8]\),把 \(\le k\) 的值赋成 \(0\)\(> k\) 赋成 \(1\)。那么 \(A\) 合法等价于,对于每个 \(k \in [0, 8]\),都存在两个排列 \(p_k(i), q_k(j)\),使得 \(A_{i, j} \le k \Longleftrightarrow B_{p_k(i), q_k(j)} \le k\)。那么一堆排列 \((p_0, q_0, p_1, q_1, \ldots, p_8, q_8)\) 可以唯一确定一个 \(A\),但是因为 \(\{r_i\}, \{c_j\}\) 是可重集,所以一个 \(A\) 实际上会对应 \(\prod\limits_{k = 0}^m (\sum\limits_{i = 1}^n [r_i = k])! \times \prod\limits_{k = 0}^n (\sum\limits_{i = 1}^m [c_i = k])!\) 堆排列。最后除一下即可。

于是考虑对这堆排列 \((p_0, q_0, p_1, q_1, \ldots, p_8, q_8)\) 计数。条件 \(A_{i, j} \le k \Longleftrightarrow B_{p_k(i), q_k(j)} \le k\) 里面有 \(A\),不妨把 \(A\) 扔掉,根据 \(A_{i, j} \le k \Longrightarrow A_{i, j} \le k + 1\)\(B_{p_k(i), q_k(j)} \le k \Longrightarrow B_{p_{k + 1}(i), q_{k + 1}(j)} \le k + 1\)

发现我们实际上只关心 \(p_{k + 1} \circ p_k^{-1}\)\(q_{k + 1} \circ q_k^{-1}\)。于是条件可以被改写成 \(B_{i, j} \le k \Longrightarrow B_{p_k(i), q_k(j)} \le k + 1\)

考察 \([B_{i, j} \le k]\) 的杨表结构。设 \(a_i = \sum\limits_{j = 1}^m [B_{i, j} \le k], b_j = \sum\limits_{i = 1}^n [B_{i, j} \le k + 1]\),可以发现 \(a, b\) 单调不升。若把 \(B\) 旋转 \(180°\),那么条件可以转化为 \(j \le a_i \Longrightarrow p_k(i) \le b_{q_k(j)}\)。更进一步地,因为 \(b\) 单调不升,所以 \(p_k(i) \le b_{\max\limits_{j = 1}^{a_i} q_k(j)}\)

然后可以 dp 计数了。设 \(f_{i, j}\)\(\max\limits_{o = 1}^i q_k(o) = j\) 的方案数。我们 \(p_k, q_k\) 的方案数分别统计。\(f_{i - 1} \to f_i\) 时,计算 \(q_k(i)\) 的方案数,有 \(f_{i, j} \gets (j - i + 1) f_{i - 1, j} + \sum\limits_{k = 1}^{j - 1} f_{i - 1, k}\),分别表示 \(q_k(i)\) 是或不是最大值。然后我们再维护一个指针 \(t\) 从后往前扫,当 \(a_t = j\) 时,计算 \(p_k(t)\) 的方案,相当于 \(p_k(t)\) 有一个 \(b_j\) 的上界,于是有 \(f_{i, j} \gets f_{i, j} \times (b_j - t + 1)\),因为 \(p_k(t)\) 的上界随 \(t\) 减小而减小。

每次的 \(f_{m, m}\) 相乘,然后再除一下上面的 \(\prod\limits_{k = 0}^m (\sum\limits_{i = 1}^n [r_i = k])! \times \prod\limits_{k = 0}^n (\sum\limits_{i = 1}^m [c_i = k])!\) 就是最终答案。

时间复杂度 \(O(Vnm)\)

code
// Problem: E - RowCol/ColRow Sort
// Contest: AtCoder - AtCoder Grand Contest 057
// URL: https://atcoder.jp/contests/agc057/tasks/agc057_e
// Memory Limit: 1024 MB
// Time Limit: 3000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 1510;
const ll mod = 998244353;

inline ll qpow(ll b, ll p) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = res * b % mod;
		}
		b = b * b % mod;
		p >>= 1;
	}
	return res;
}

ll n, m, a[maxn][maxn], b[19][maxn], c[19][maxn], fac[maxn], ifac[maxn], f[maxn][maxn], d[maxn];

void solve() {
	scanf("%lld%lld", &n, &m);
	fac[0] = 1;
	for (int i = 1; i <= max(n, m); ++i) {
		fac[i] = fac[i - 1] * i % mod;
	}
	ifac[max(n, m)] = qpow(fac[max(n, m)], mod - 2);
	for (int i = max(n, m) - 1; ~i; --i) {
		ifac[i] = ifac[i + 1] * (i + 1) % mod;
	}
	for (int i = 1; i <= n; ++i) {
		for (int j = 1; j <= m; ++j) {
			scanf("%lld", &a[i][j]);
			++b[a[i][j]][i];
			++c[a[i][j]][j];
		}
	}
	for (int k = 1; k <= 9; ++k) {
		for (int i = 1; i <= n; ++i) {
			b[k][i] += b[k - 1][i];
		}
		for (int i = 1; i <= m; ++i) {
			c[k][i] += c[k - 1][i];
		}
	}
	ll ans = 1;
	for (int k = 0; k <= 8; ++k) {
		mems(f, 0);
		f[0][0] = 1;
		int p = n;
		while (p && b[k][p] == 0) {
			ans = ans * (n - p + 1) % mod;
			--p;
		}
		for (int i = 1; i <= m; ++i) {
			ll s = 0;
			for (int j = 1; j <= m; ++j) {
				s = (s + f[i - 1][j - 1]) % mod;
				f[i][j] = (f[i - 1][j] * (j - i + 1) % mod + s) % mod;
			}
			while (p && b[k][p] == i) {
				for (int j = 1; j <= m; ++j) {
					f[i][j] = f[i][j] * (c[k + 1][j] - p + 1) % mod;
				}
				--p;
			}
		}
		ans = ans * f[m][m] % mod;
		mems(d, 0);
		for (int i = 1; i <= n; ++i) {
			++d[b[k][i]];
		}
		for (int i = 0; i <= m; ++i) {
			ans = ans * ifac[d[i]] % mod;
		}
		mems(d, 0);
		for (int i = 1; i <= m; ++i) {
			++d[c[k][i]];
		}
		for (int i = 0; i <= n; ++i) {
			ans = ans * ifac[d[i]] % mod;
		}
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}