[LOJ2504] 「2018 集训队互测 Day 5」小 H 爱染色

做法

注意球的编号是 \(0 \ldots n-1\)。我一开始误把编号看成 \(1 \ldots n\),调了巨久,我大概没救了。

显然(编号从 \(0\) 开始)最小黑球编号为 \(k\) 的方案有 \(\binom{n-k}m^2 - \binom{n-k-1}m^2\) 种:用所有黑球编号都 \(\ge k\) 的方案数减去所有黑球编号都 \(\ge k+1\) 的方案数。

答案为

\[ \sum_{k = 0}^{n-1} (\binom{n-k}m^2 - \binom {n-k-1}m^2) F(k) \]

注意到 \((\binom{n-k}m^2 - \binom{n-k-1}m^2) F(k)\) 是一个关于 \(k\) 的、次数不超过 \(3m\) 的多项式。我们要求这个多项式的一个前缀和,而前缀和是一个次数不超过 \(3m+1\) 的多项式,设它为 \(G(k)\)。如果我们知道 \(F(0) \ldots F(3m+1)\),就可以很容易地得到 \(G(0) \ldots G(3m+1)\),然后用拉格朗日插值得到 \(G(n-1)\)(若 \(n-1 \le 3m+1\),直接取 \(G(n-1)\) 即可)。

但是我们只知道 \(F(0) \ldots F(m)\)。下面考虑如何计算 \(F(m+1) \ldots F(3m+1)\)。

\[ F(x) = \sum_{i=0}^m F(i) \prod_{0 \le j \le m,\ j \neq i} \frac{x-j}{i-j} \]

对于 \(m < x \le 3m+1\),有 \(\prod_{j \neq i} (x-j) = \frac{x^{\underline{m+1}}}{x-i}\),以及 \(\prod_{j \neq i} (i-j) = (-1)^{m-i}\, i!\,(m-i)!\),因此

\[ F(x) = \sum_{i=0}^m F(i) \frac{(-1)^{m-i} x^{\underline{m+1}}}{(x-i)\, i!\, (m-i)!} = x^{\underline{m+1}} \sum_{i=0}^m \frac{(-1)^{m-i} F(i)}{i!\,(m-i)!} \cdot \frac{1}{x-i} \]

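记 \(a_i = \frac{(-1)^{m-i} F(i)}{i!\,(m-i)!}\)、\(b_j = \frac 1 j\)(\(b_0 = 0\)),则右边的和式就是卷积 \(\sum_i a_i b_{x-i}\)。用一次 NTT 即可求出全部 \(F(m+1) \ldots F(3m+1)\),时间复杂度 \(\mathcal O(m \log m)\)。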

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#include <bits/stdc++.h>

using namespace std;

using ull = unsigned long long;

constexpr int mod = 998244353;  // NTT-friendly prime: 119 * 2^23 + 1
constexpr int g = 3;            // primitive root modulo `mod`
constexpr int maxm = 1000010;   // capacity bound for m; global arrays hold 3 * maxm entries

// Computes x^y modulo `mod` by binary exponentiation.
// The base may be any int (negative, or >= mod): it is first reduced into
// [0, mod), so the result is always in [0, mod). The exponent y is expected to
// be non-negative; y <= 0 returns 1 instead of looping forever.
int qpow(int x, int y) {
    long long base = x % mod;  // C++ % truncates toward zero: base in (-mod, mod)
    if (base < 0) {
        base += mod;
    }
    long long result = 1;
    while (y > 0) {
        if (y & 1) {
            result = result * base % mod;
        }
        base = base * base % mod;  // both factors < mod, product fits in long long
        y >>= 1;
    }
    return static_cast<int>(result);
}

// Single conditional subtraction: maps x in [0, 2 * mod) back into [0, mod).
inline int mo(int x) {
    return x < mod ? x : x - mod;
}

// Number-theoretic transform modulo `mod` with primitive root `g`, for
// transform lengths 2^_l with _l <= 22 (the range init() prepares).
namespace NTT {
// wn[i] = g^((mod - 1) / 2^i), a 2^i-th root of unity (filled by init()).
// w: twiddle powers for the current level; rev: bit-reversal permutation;
// wa/wb/wc: scratch arrays for the caller's convolution;
// _a: 64-bit working copy that is reduced lazily during the butterflies.
int wn[30], w[(1 << 22) + 10], rev[(1<<22) + 10], wa[(1 << 22) + 10], wb[(1 << 22) + 10], wc[(1 << 22) + 10];
ull _a[(1 << 22) + 10];
// Precomputes wn[0..22]; must be called before ntt().
void init() {
for (int i = 0; i <= 22; i++) {
wn[i] = qpow(g, (mod - 1) / (1 << i));
}
}
// In-place transform of a[0 .. 2^_l - 1] (entries expected in [0, mod)).
// ty == 1: forward transform. ty == -1: inverse, computed as the forward
// transform followed by scaling with 1 / 2^_l and reversing a[1 .. l-1].
void ntt(int *a, int _l, int ty) {
int l = (1 << _l);
// Bit-reversal table; relies on rev[0] == 0 (zero-initialised, never written).
for (int i = 1; i < l; i++) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (_l - 1));
}
for (int i = 0; i < l; i++) {
if (i < rev[i]) {
swap(a[i], a[rev[i]]);
}
}
for (int i = 0; i < l; i++) {
_a[i] = a[i];
}
// `_` is log2(len), i.e. the index of the root of unity used at this level.
int _ = 0;
for (int len = 2; len <= l; len <<= 1) {
// w[i] = wn[_]^i; only the first len / 2 entries are used below.
w[0] = 1; ++ _;
for (int i = 1; i < len; i++) {
w[i] = 1LL * w[i-1] * wn[_] % mod;
}
for (int s = 0; s < l; s += len) {
for (int i = s; i < s + (len >> 1); i++) {
// Lazy reduction: v2 < mod, so each level grows every entry by less
// than mod; after t levels entries are < (t + 1) * mod.
ull v1 = _a[i], v2 = w[i - s] * _a[i + (len >> 1)] % mod;
_a[i] = v1 + v2;
_a[i + (len >> 1)] = v1 + mod - v2;
}
}
// Reduce once after the 15th level so that w * _a stays below 2^64 for
// the remaining levels (at most 7 more when _l <= 22).
if (len == (1 << 15)) {
for (int i = 0; i < l; i++) {
_a[i] %= mod;
}
}
}
for (int i = 0; i < l; i++) {
a[i] = _a[i] % mod;
}
if (ty == -1) {
// Inverse: scale by 1/l, then reverse indices 1..l-1 (a[0] and a[l/2] stay).
int inv = qpow(l, mod - 2);
for (int i = 0; i < l; i++) {
a[i] = 1LL * a[i] * inv % mod;
}
for (int i = 1; i < l / 2; i++) {
swap(a[i], a[l-i]);
}
}
}
}

// Returns x squared, reduced modulo `mod`.
int cal(int x) {
    const long long square = static_cast<long long>(x) * x;
    return static_cast<int>(square % mod);
}

// n, m: read from input in main().
int n, m;
// F: values F(0..3m+1) (F(0..m) read from input, the rest extrapolated);
// _F: per-k summand; G: prefix sums of _F; fac / ifac: factorials and their
// inverses; inv[i]: modular inverse of i; rinv[i]: modular inverse of (n - i).
int F[maxm * 3], _F[maxm * 3], G[maxm * 3], fac[maxm * 3], ifac[maxm * 3], inv[maxm * 3], rinv[maxm * 3];
int pw[maxm * 3]; // pw[i]: falling factorial of (n - i) of order m, i.e. (n-i)(n-i-1)...(n-i-m+1) mod p

// Reads n, m and F(0..m), then prints
//   sum_{k=0}^{n-1} (C(n-k, m)^2 - C(n-k-1, m)^2) * F(k)   (mod `mod`).
// Steps: extend F to 0..3m+1 by Lagrange extrapolation (one NTT convolution),
// form the summand and its prefix sums G(0..3m+1), then either read G(n-1)
// directly or interpolate the degree-(3m+1) polynomial G at n - 1.
// NOTE(review): the divisions by n - i and n - 1 - i below assume n < mod, so
// those values are never 0 modulo mod -- confirm against the problem limits.
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= m; i++) {
        scanf("%d", &F[i]);
    }
    NTT::init();

    const int K = 3 * m + 1;  // G is sampled at 0..K

    // inv[i] = 1/i via inv[i] = -(mod / i) * inv[mod % i]; factorials up to K.
    fac[0] = ifac[0] = 1;
    inv[1] = 1;
    for (int i = 2; i <= K; i++) {
        inv[i] = mod - 1LL * inv[mod % i] * (mod / i) % mod;
    }
    for (int i = 1; i <= K; i++) {
        fac[i] = 1LL * fac[i - 1] * i % mod;
        ifac[i] = 1LL * ifac[i - 1] * inv[i] % mod;
    }
    // rinv[i] = 1/(n - i), only at the indices the pw recurrence below reads;
    // for those n - i > m >= 0, so qpow never receives a non-positive base.
    for (int i = 0; i < n - m && i <= K; i++) {
        rinv[i] = qpow(n - i, mod - 2);
    }

    // For m < x <= K:
    //   F(x) = x^{\underline{m+1}} * sum_i a_i / (x - i),
    //   a_i  = (-1)^{m-i} F(i) / (i! (m-i)!),
    // i.e. the convolution of a with b_j = 1/j (b_0 = 0).
    {
        using namespace NTT;
        const int lg = 22;
        const int l = 1 << lg;
        // A cyclic convolution of length l is exact at indices <= K as long as 3m < l.
        for (int i = 0; i <= m; i++) {
            wa[i] = 1LL * F[i] * ifac[i] % mod * ifac[m - i] % mod;
            if ((m - i) & 1) {
                wa[i] = mo(mod - wa[i]);
            }
        }
        for (int i = 1; i <= K; i++) {
            wb[i] = inv[i];
        }
        ntt(wa, lg, 1);
        ntt(wb, lg, 1);
        for (int i = 0; i < l; i++) {
            wc[i] = 1LL * wa[i] * wb[i] % mod;
        }
        ntt(wc, lg, -1);
        for (int i = m + 1; i <= K; i++) {
            // x^{\underline{m+1}} = x! / (x - m - 1)!
            F[i] = 1LL * fac[i] * ifac[i - m - 1] % mod * wc[i] % mod;
        }
    }

    // pw[i] = (n - i)^{\underline{m}}, so C(n - i, m) = pw[i] / m!.
    pw[0] = 1;
    for (int i = 0; i < m; i++) {
        pw[0] = 1LL * pw[0] * (n - i) % mod;
    }
    for (int i = 1; i <= n - m && i <= K + 1; i++) {
        // (n-i)^{\underline m} = (n-i+1)^{\underline m} * (n-i-m+1) / (n-i+1)
        pw[i] = 1LL * pw[i - 1] * rinv[i - 1] % mod * (n - i - m + 1) % mod;
    }
    // Entries with n - m < i <= n stay 0 (zero-initialised globals); that is the
    // correct falling factorial because it contains the factor 0. Entries with
    // i > n can only matter when n - 1 <= K, and then the printed G[n - 1]
    // reads pw[0..n] only. The former loop that filled them read fac[] up to
    // index 4m - n, beyond the range computed above, so it has been removed.

    // _F[k] = (C(n-k, m)^2 - C(n-k-1, m)^2) * F(k); G is its prefix sum.
    const int ifacm2 = cal(ifac[m]);
    for (int i = 0; i <= K; i++) {
        _F[i] = 1LL * ifacm2 * (cal(pw[i]) + mod - cal(pw[i + 1])) % mod * F[i] % mod;
    }
    G[0] = _F[0];
    for (int i = 1; i <= K; i++) {
        G[i] = mo(G[i - 1] + _F[i]);
    }

    if (n - 1 <= K) {
        printf("%d\n", G[n - 1]);
        return 0;
    }

    // Lagrange interpolation of G (degree <= K) at x = n - 1 > K:
    //   G(x) = sum_i G(i) * t / (x - i) * (-1)^{K-i} / (i! (K-i)!),
    //   t    = prod_{j=0}^{K} (x - j).
    const int x = n - 1;
    int t = 1;
    for (int i = 0; i <= K; i++) {
        t = 1LL * t * (x - i) % mod;
    }
    int ans = 0;
    for (int i = 0; i <= K; i++) {
        int v = 1LL * t * qpow(x - i, mod - 2) % mod * ifac[i] % mod * ifac[K - i] % mod;
        if ((K - i) & 1) {
            v = mo(mod - v);
        }
        ans = mo(static_cast<int>(ans + 1LL * G[i] * v % mod));
    }
    printf("%d\n", ans);
    return 0;
}