ICPC 2021 上海区域赛 F Kaiji!

標籤: icpcmathdp

题意

有一个序列 a,从小到大排序。Alice 每次在 a 中选择两个相邻的值,分别放到 Bob 的左右手。Bob 可以选择左手或者右手,然后 Alice 将会告诉 Bob 所选择的手上的值。最后 Bob 需要猜测左右手上球的大小关系。问 Bob 能获得的最大胜率 p,使得对于 Alice 的所有可能选择,Bob 获胜的概率都不小于 p。

解析

a 的具体值实际上不重要,可以先离散化。

首先,在 Bob 选择左右手这个环节,一定是左右手各一半概率,否则若 Alice 在 Bob 更大概率选到的手上放了较难的值,则获胜概率会变低。

然后就是 Bob 需要对每一个可能选到的值 i,确定 fif_{i}gig_{i}hih_{i},分别表示拿到 i 之后猜测已知值小于,等于和大于位置值的概率。满足 fi+gi+hi=1f_{i}+g_{i}+h_{i} =1

考虑 gig_{i}。为了方便,令 tit_{i} 表示值 i 是否出现了多于 1 次。假设答案为 w,一个显然的关系是:gitiwg_{i} \ge t_{i}w,即如果 i 只出现了一次,那么关系不可能是相等,否则,猜测 gig_{i} 的概率至少要为 w。不妨就让 gi=wg_{i} = w。然后假设双手的值分别是 i 和 i+1,则也要满足 fi+hi+12w\frac{f_{i}+h_{i+1}}{2} \ge w,即已知 i 猜小于和已知 i+1 猜大于的概率的平均值要不小于 w。整理一下可以得到 fi+1fi+1ti+1w2wf_{i+1} \le f_{i} +1-t_{i+1}w-2w。同时注意到 0fi1tiw0 \le f_{i} \le 1 - t_{i}w,故有:0fi+1min{1ti+1w,fi+1(ti+1+2)w}0 \le f_{i+1} \le \min\{1 - t_{i+1}w, f_{i}+1-(t_{i+1}+2)w\}

容易发现,fif_{i} 的上界由 tit_{i}fi1f_{i-1} 决定,且最大化 fi1f_{i-1} 不会让 fif_{i} 上界变小,所以我们不妨就令:

fi+1=min{1ti+1w,fi+1(ti+1+2)w}f_{i+1}=\min\{\textcolor{red}{1 - t_{i+1}w}, \textcolor{blue}{f_{i}+1-(t_{i+1}+2)w}\}

问题在于 w 并没有被确定。考虑继续简化这个转移。由于红色转移和之前的 f 都是没有关系的,所以 fif_{i} 一定是从某一个 fjf_{j} 用了红色转移,然后从 jjii 都用蓝色转移得到的最小值,可以推出式子(令 si=j=0itjs_{i}=\sum_{j=0}^{i} t_{j}):

fi=1tjw+(ij)[sisj+2(ij)w]=(ij+1)[sisj1+2(ij)w]\begin{aligned} f_{i} &= \textcolor{red}{1-t_{j}w}+\textcolor{blue}{(i-j)-[s_{i}-s_{j}+2(i-j)w]} \\ &= (i-j+1)-[s_{i}-s_{j-1}+2(i-j)w] \end{aligned}

转化一下可以得到 w 的限制:

wij+1sisj1+2(ij)w \le \frac{i - j + 1}{s_{i}-s_{j-1}+2(i-j)}

故就是要求右边那个分式的最小值。可以看作是点 (si+2i,i+1)(s_{i} + 2i, i+1) 到所有 (sj1+2j,j)(s_{j-1} + 2j, j) 的斜率最小值(iji \le j)。维护一个上凸包,由于 (sj1+2j)(s_{j-1} +2j) 单调递增,可以用单调栈维护。然后在凸包上二分即可。时间复杂度 O(nlogn)O(n\log n),但实际上这个凸包大小不会太大,所以容易通过。

#include <iostream>
#include <vector>
#include <algorithm>
#include <string>

void set_io(std::string name)
{
  std::cin.tie(nullptr);
  std::ios::sync_with_stdio(false);
}

constexpr int P = 998'244'353;

constexpr long long pow_mod(long long x, long long y)
{
  y %= P - 1;
  x %= P;
  long long r = 1;
  while (y) {
    if (y & 1) r = r * x % P;
    x = x * x % P;
    y >>= 1;
  }
  return r;
}

constexpr long long inv(long long x) { return pow_mod(x, P - 2); }

struct vec_t
{
  int x, y;
  vec_t() : x(0), y(0) {}
  vec_t(int x, int y) : x(x), y(y) {}
  vec_t &operator+=(const vec_t &v) 
  {
    x += v.x;
    y += v.y;
    return *this;
  }
  vec_t &operator-=(const vec_t &v) 
  {
    x -= v.x;
    y -= v.y;
    return *this;
  }
  vec_t operator+(const vec_t &v) const { return vec_t(*this) += v; }
  vec_t operator-(const vec_t &v) const { return vec_t(*this) -= v; }

#define slope_compare(op)                                  \
  friend bool operator op (const vec_t &a, const vec_t &b) \
  {                                                        \
    return (long long)b.x * a.y op (long long)a.x * b.y;   \
  }
  slope_compare(<)
  slope_compare(>)
  slope_compare(<=)
  slope_compare(>=)
#undef slope_compare
};

void solve()
{
  int n, a, A, B, C, M;
  std::cin >> n >> a >> A >> B >> C >> M;
  A %= M;
  B %= M;
  C %= M;
  a %= M;
  std::vector<int> cnt(M);
  for (int i = 1; i <= n; i++) {
    a = (((long long)A * a % M * a % M + (long long)B * a % M) % M + C) % M + 1;
    cnt[a - 1]++;
  }

  std::vector<vec_t> st;
  vec_t ans(1, 1);
  int s = 0;

  auto insert = [&st](const vec_t &p)
  {
    while (st.size() >= 2 && (st.end()[-1] - st.end()[-2]) <= (p - st.end()[-1])) st.pop_back();
    st.emplace_back(p);
  };
  auto find = [&st](const vec_t &p) 
  {
    int l = 0, r = st.size() - 1;
    while (l < r) {
      int mid = l + (r - l) / 2;
      if (st[mid + 1] - st[mid] < p - st[mid]) {
        r = mid;
      } else {
        l = mid + 1;
      }
    }
    return st[r];
  };

  int j = 0;
  for (int i = 0; i < M; i++) {
    if (cnt[i] == 0) continue;
    int t = cnt[i] >= 2;
    insert(vec_t(s + 2 * j, j));
    vec_t p(s + t + 2 * j, j + 1);
    auto q = find(p);
    auto d = p - q;
    if (d < ans) ans = d;
    s += t;
    j++;
  }

  std::cout << (long long)ans.y * inv(ans.x) % P << std::endl;
}

int main()
{
  set_io("game");

  int t;
  std::cin >> t;
  while (t--) solve();
}