P8029 [COCI2021-2022#3] Akcija 题解

发布时间 2023-09-10 14:20:53作者: registerGen

:这篇题解中涉及到的所有概念均会在其第一次出现时用 斜体 标出,有些概念给出了定义,而有些概念的含义请自行意会。

定义 状态 为选了的物品数 \(a\) 与相应总价格 \(b\) 的二元组 \((a,b)\)。相应地定义状态之间的 大小关系最优状态 与状态和状态的 加法运算 \((a_1,b_1)+(a_2,b_2):=(a_1+a_2,b_1+b_2)\)

我们先来考虑 \(k=1\) 时的做法。首先我们将商品按 \(d_i\) 排序。设 \(f(i,j)\) 表示当考虑前 \(i\) 个物品,选了 \(j\) 个时,我们 从剩余物品中能获得的 最优状态。显然可以 dp 求出。(当然正常的思维路径是设 \(f(i,j)\) 表示前 \(i\) 个物品选 \(j\) 个时的最优状态,这么设的原因见下文。)转移方程为

\[f(i,j)=\begin{cases} f(i+1,j),&d_{i+1}\le j\\ \max\{f(i+1,j),f(i+1,j+1)+(1,w_{i+1})\},&d_{i+1}\ge j+1 \end{cases} \]

接下来,我们以状态 \((0,0)\)搜索树 的根,进行 Fracturing Search

\(d_i\) 从小到大的顺序枚举每个物品 \((w_i,d_i)\)\(1\le i\le n\))。设考虑第 \(i\) 个物品前的前 \(k\) 优的状态集合为 \(S\),考虑完第 \(i\) 个物品后的前 \(k\) 优的状态集合为 \(T\)。枚举 \((a,b)\in S\),其 后继状态\((a,b)\)(不选第 \(i\) 个物品)与 \((a+1,b+w_i)\)(选第 \(i\) 个物品,前提是 \(d_i\ge a+1\)),这些后继状态构成了集合 \(T\) 的一个超集 \(T'\)

对于 \(s=(a,b)\in T'\),搜索树上以 \(s\) 为根的子树中,最优状态为 \(s+f(i,a)\)。我们以它为关键字,取 \(T'\) 中前 \(k\) 大的元素即构成了 \(T\)。于是,我们令 \(S\gets T\),并枚举下一个物品。

我们需要枚举 \(n\) 个物品,枚举每个物品时我们以 \(\mathcal O(k)\) 的时间复杂度计算后继状态集合 \(T\),这样,我们就以 \(\mathcal O(nk)\) 的时间复杂度解决了本题。

#include <algorithm>
#include <cstdio>
using namespace std;

using ll = long long;

const int N = 2000;

struct Node { int w, d; };
struct State { int cnt; ll sum; };
inline bool operator<(const State &lhs, const State &rhs) {
  return lhs.cnt == rhs.cnt ? lhs.sum > rhs.sum : lhs.cnt < rhs.cnt;
}
inline State operator+(const State &lhs, const State &rhs) {
  return {lhs.cnt + rhs.cnt, lhs.sum + rhs.sum};
}

int n, k;
Node a[N + 10];
State f[N + 10][N + 10];
State ans[N + 10], nxt[N + 10];

int main() {
  scanf("%d%d", &n, &k);
  for (int i = 1; i <= n; i++)
    scanf("%d%d", &a[i].w, &a[i].d);
  sort(a + 1, a + n + 1, [](const Node &lhs, const Node &rhs) {
    return lhs.d < rhs.d;
  });
  for (int i = n - 1; i >= 0; i--)
    for (int j = 0; j <= i; j++) {
      f[i][j] = f[i + 1][j];
      if (a[i + 1].d > j) f[i][j] = max(f[i][j], f[i + 1][j + 1] + State({1, a[i + 1].w}));
    }
  int tota = 1;
  for (int i = 1; i <= n; i++) {
    int totn = 0;
    for (int j = 1; j <= tota; j++) {
      nxt[++totn] = {ans[j].cnt, ans[j].sum};
      if (a[i].d > ans[j].cnt) nxt[++totn] = {ans[j].cnt + 1, ans[j].sum + a[i].w};
    }
    tota = min(k, totn);
    nth_element(nxt + 1, nxt + tota + 1, nxt + totn + 1, [&](const State &lhs, const State &rhs) {
      return rhs + f[i][rhs.cnt] < lhs + f[i][lhs.cnt];
    });
    for (int j = 1; j <= tota; j++) ans[j] = nxt[j];
  }
  sort(ans + 1, ans + k + 1, [](const State &lhs, const State &rhs) { return rhs < lhs; });
  for (int i = 1; i <= k; i++)
    printf("%d %lld\n", ans[i].cnt, ans[i].sum);
  return 0;
}