CF1442D Sum

发布时间 2024-01-02 17:18:05作者: cxqghzj

题意

给定 \(n\) 个递增数组。

\(k\) 次操作,每次你可以选择一个数组,使 \(ans\) 加上数组的第一个数,并删除。

问最大化的 \(ans\) 的值。

Sol

考虑当前选择的方案如何变得更优。

不难想到,如果当前有两个数组没有选满,则一定可以调整到其中一个变成空的方案,而使得答案不劣。

所以,不难想到结论就是最优方案只有一个数组没有选满。

枚举该数组,并对剩下的数组做 \(01\) 背包。

时间复杂度 \(O(n ^ 2k)\)

发现分块显然可做,时间复杂度 \(O(n \sqrt n k)\)

考虑更优的做法。

设计这样一种分治:\((l, r)\) 表示不选 \([l, r]\) 之间的数组的方案数。

每次分治先将 \([l, mid]\) 加入 \(dp\),然后递归 \((mid + 1, r)\)

\(dp\) 数组清空,再将 \([mid + 1, r]\) 加入 \(dp\),递归 \((l, mid)\)

做完了,时间复杂度 \(O(nk \log n)\)

Code

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <vector>
#define int long long
using namespace std;
#ifdef ONLINE_JUDGE

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif
int read() {
	int p = 0, flg = 1;
	char c = getchar();
	while (c < '0' || c > '9') {
		if (c == '-') flg = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		p = p * 10 + c - '0';
		c = getchar();
	}
	return p * flg;
}
void write(int x) {
	if (x < 0) {
		x = -x;
		putchar('-');
	}
	if (x > 9) {
		write(x / 10);
	}
	putchar(x % 10 + '0');
}
const int N = 1e6 + 5, M = 3005;

array <vector <int>, M> s;
array <int, M> f;

int ans;

array <array <int, N>, 35> isl;

void solve(int l, int r, int n, int k, int d) {
	if (l == r) {
		for (int i = 0; i <= min((int)s[l].size() - 1, k); i++)
			ans = max(ans, s[l][i] + f[k - i]);
		return;
	}
	int mid = (l + r) >> 1;
	for (int i = 0; i <= k; i++)
		isl[d][i] = f[i];
	for (int i = l; i <= mid; i++)
		for (int j = k; j >= (int)s[i].size() - 1; j--)
			f[j] = max(f[j - s[i].size() + 1] + s[i].back(), f[j]);
	solve(mid + 1, r, n, k, d + 1);
	for (int i = 0; i <= k; i++)
		f[i] = isl[d][i];
	for (int i = mid + 1; i <= r; i++)
		for (int j = k; j >= (int)s[i].size() - 1; j--)
			f[j] = max(f[j - s[i].size() + 1] + s[i].back(), f[j]);
	solve(l, mid, n, k, d + 1);
	for (int i = 0; i <= k; i++)
		f[i] = isl[d][i];
}

signed main() {
	int n = read(), k = read();
	for (int i = 1; i <= n; i++) {
		int x = read();
		s[i].push_back(0);
		while (x--) {
			int y = read();
			s[i].push_back(s[i].back() + y);
		}
	}
	solve(1, n, n, k, 0);
	write(ans), puts("");
	return 0;
}