AtCoder Regular Contest 119 F AtCoder Express 3

发布时间 2023-05-02 17:02:49作者: zltzlt

洛谷传送门

AtCoder 传送门

很厉害的题!

考虑所有车站已确定,如何求 \(0\)\(n+1\) 的最短路。设 \(g_{i,0}\) 为只考虑 \(0 \sim i\) 的点,到 \(i\) 和它左边第一个 \(\text{A}\) 的最短路,\(g_{i,1}\) 同理。有转移:

  • \(s_{i-1} = \text{A}, s_i = \text{A}, g_{i,0} \gets g_{i-1,0} + 1\)
  • \(s_{i-1} = \text{A}, s_i = \text{B}, g_{i,0} \gets \min(g_{i-1,0}, g_{i-1,1} + 2)\)
  • \(s_{i-1} = \text{B}, s_i = \text{A}, g_{i,0} \gets \min(g_{i-1,0} + 1, g_{i-1,1} + 1)\)
  • \(s_{i-1} = \text{B}, s_i = \text{A}, g_{i,0} \gets g_{i-1,0}\)

\(g_{i,1}\) 的转移是对称的。

\(f_{i,x,y,0/1}\) 表示当前考虑了 \(0 \sim i\) 的车站,\(g_{i,0} = x, g_{i,1} = y\)\(s_i\)\(\text{A}\)\(\text{B}\) 的方案数。这是 \(O(n^3)\) 的。

考虑压状态。显然遇到 \(\text{ABB...B}\)\(x,y\) 相差就会很大。但是要到达最后一个 \(\text{B}\),可以先跳一步 \(\text{A}\) 再往回走。这是我们的最终目标,我们不关心途中的最短路数值究竟是什么。因此可以做出如下优化:当 \(x \ge y + 2\) 时,强制让 \(x \gets y + 2\)\(y\) 同理。这样不会影响最终答案。这样是 \(O(n^2)\) 的。

实现时用 unordered_map 记录所有状态有效。

code
// Problem: F - AtCoder Express 3
// Contest: AtCoder - AtCoder Regular Contest 119
// URL: https://atcoder.jp/contests/arc119/tasks/arc119_f
// Memory Limit: 1024 MB
// Time Limit: 4000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<int, int> pii;

const int maxn = 4010;
const int mod = 1000000007;

int n, m;
char s[maxn];
unordered_map<int, int> f[2][maxn][2];

inline void upd(int o, int x, int y, int k, int val) {
	x = min(x, y + 2);
	y = min(y, x + 2);
	int &p = f[o][x][k][y];
	p += val;
	(p >= mod) && (p -= mod);
}

void solve() {
	scanf("%d%d%s", &n, &m, s + 1);
	--n;
	if (s[1] != 'B') {
		f[1][1][0][0] = 1;
	}
	if (s[1] != 'A') {
		f[1][0][1][1] = 1;
	}
	for (int i = 2, o = 0; i <= n; ++i, o ^= 1) {
		for (int x = 0; x <= n + 2; ++x) {
			for (int k = 0; k <= 1; ++k) {
				f[o][x][k].clear();
			}
		}
		for (int x = 0; x <= n + 2; ++x) {
			for (pii p : f[o ^ 1][x][0]) {
				int y = p.fst, val = p.scd;
				if (s[i] != 'B') {
					upd(o, x + 1, y, 0, val);
				}
				if (s[i] != 'A') {
					upd(o, min(x, y + 2), min(x + 1, y + 1), 1, val);
				}
			}
			for (pii p : f[o ^ 1][x][1]) {
				int y = p.fst, val = p.scd;
				if (s[i] != 'B') {
					upd(o, min(x + 1, y + 1), min(x + 2, y), 0, val);
				}
				if (s[i] != 'A') {
					upd(o, x, y + 1, 1, val);
				}
			}
		}
	}
	int ans = 0;
	for (int x = 0; x <= n; ++x) {
		for (int k = 0; k <= 1; ++k) {
			for (pii p : f[n & 1][x][k]) {
				int y = p.fst, val = p.scd;
				if (min(x, y) + 1 <= m) {
					ans = (ans + val) % mod;
				}
			}
		}
	}
	printf("%d\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}