[CF1349D] Slime and Biscuits

做法

\(\sum a_i = s\)

显然在任何一个时刻只有至多一个人拥有全部饼干。

设从一个存在一个人拥有所有饼干的状态开始,假设游戏不会结束,\(k\) 步之后存在一个拥有所有饼干的人,且这是这个人第一次拥有所有饼干的概率为 \(q_k\) (\(q_0=1\)),设 \(Q(x) = \sum_{k \ge 0} q_k x^k\)

设第 \(k\) 步时第一次出现一个拥有全部饼干的人的概率为 \(f_k\) (即游戏 \(k\) 步结束的概率),设 \(F(x) = \sum_{k\ge 0} f_kx^k\)

\(p_{i,k}\) 表示第 \(i\) 个人 \(k\) 步后第一次拥有所有饼干的概率,设 \(P_i(x) = \sum_{k \ge 0} p_{i,k}x^k\)

那么 \(F(x) Q(x) = \sum P_i(x)\),所以 \(F(x) = \frac{\sum P_i(x)}{Q(x)}\)。我们要求的是 \(F'(1)\)。由分式求导公式问题被转化为求 \(P_i(1), P'_i(1), Q(1), Q'(1)\)

显然 \(P_i(1) = 1, Q(1) = n\),考虑 \(P'_i(x), Q'(x)\) 的实际意义,问题变成求一个初始有 \(x\) 个饼干的人第一次拥有全部饼干的期望步数,这这个值为 \(x_i\),那么有

\[ x_s = 0\\ x_i = 1 + \frac 1 {n(n-1)}[(n-i)(n-2)x_i + i(n-1)x_{i-1}+(n-i)x_{i+1}]\\ P'_i(1) = x_{a_i}, Q'(1) = (n-1) x_0 \]

解方程即可。

时间复杂度 \(\mathcal O(n + \sum a_i)\)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int mod = 998244353;
const int maxn = 300010;

int n, a[maxn], x[maxn], k[maxn], b[maxn], s;

int qpow(int x, int y) {
int ret = 1;
while (y) {
if (y & 1) {
ret = 1LL * ret * x % mod;
}
y >>= 1;
x = 1LL * x * x % mod;
}
return ret;
}

int inv(ll x) {
return qpow(x % mod, mod - 2);
}

int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
s += a[i];
}
k[0] = 1;
for (int i = 0; i < s; i++) {
int k1 = 1LL * (mod + 1 - 1LL * (s-i) * inv(s) % mod * (n-2) % mod * inv(n-1) % mod) % mod * inv(1LL * (s-i) * inv(s) % mod * inv(n-1) % mod) % mod;
int k2 = 1LL * (mod-i) * (n-1) % mod * inv(s - i) % mod;
int _b = 1LL * (n-1) * (mod - s) % mod * inv(s - i) % mod;
k[i+1] = (k[i+1] + 1LL * k1 * k[i] % mod) % mod;
b[i+1] = (b[i+1] + 1LL * k1 * b[i] % mod) % mod;
if (i > 0) {
k[i+1] = (k[i+1] + 1LL * k2 * k[i-1] % mod) % mod;
b[i+1] = (b[i+1] + 1LL * k2 * b[i-1] % mod) % mod;
}
b[i+1] = (b[i+1] + _b) % mod;
}
int t = 0;
x[0] = 1LL * (mod - b[s]) * qpow(k[s], mod - 2) % mod;
for (int i = 1; i <= s; i++) {
x[i] = (1LL * k[i] * x[0] + b[i]) % mod;
}
for (int i = 1; i <= n; i++) {
t = (t + x[a[i]]) % mod;
}
int ans = 1LL * (t - 1LL * (n-1) * x[0] % mod + mod) % mod * inv(n) % mod;
printf("%d\n", ans);
return 0;
}