[Codeforces438E] The Child and Binary Tree

做法

设 \(C(x) = \sum_{i=1}^{n} x^{c_i}\),\(a_i\) 是权值和为 \(i\) 的二叉树数量(特别地,空二叉树算作一种权值和为 \(0\) 的二叉树),\(A(x) = \sum_{i \ge 0} a_ix^i\)。

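一棵非空二叉树由根(对应 \(C(x)\))和左右两棵子树(各对应 \(A(x)\))组成,故 \(A(x) = A^2(x)C(x) + 1\),即 \(C(x)A^2(x) - A(x) + 1 = 0\)。由一元二次方程求根公式得 \(A(x) = \frac{1 \pm \sqrt {1-4C(x)}}{2C(x)}\)。由于分母没有常数项,分子的常数项也必须为 \(0\)(否则 \(A(x)\) 不是形式幂级数);而 \(\sqrt{1-4C(x)}\) 的常数项取 \(1\),所以这里应该取负号:\(A(x) = \frac{1 - \sqrt {1-4C(x)}}{2C(x)}\)。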

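于是做一次多项式开根就好了。(因为 \([x^0] (1-4C(x)) = 1\),恰好存在两个多项式 \(F(x)\) 满足 \(F^2(x) \equiv 1-4C(x) \pmod {x^n}\),二者互为相反数,常数项分别为 \(1\) 和 \(-1\),取常数项为 \(1\) 的那个即可。)另外,\(C(x)\) 的最低次项是 \(x^k\)(\(k = \min c_i\)),分子 \(1-\sqrt{1-4C(x)} = 2C(x) + \cdots\) 的最低次项同样是 \(x^k\),所以求逆之前先把分子、分母同时除以 \(x^k\),分母的常数项就非零、可以求逆了。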

代码

#include <bits/stdc++.h>

using namespace std;

// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Generator used by ntt() to build roots of unity (3 is a primitive root of mod).
const int g = 3;
// Multiplicative inverse of 2 modulo mod: 2 * ((mod + 1) / 2) = mod + 1 == 1 (mod mod).
const int i2 = (mod + 1) >> 1;
// Base size used to dimension the global arrays below.
const int maxn = 200010;

// n: number of weights read from input; m: largest tree weight that is queried.
int n, m;

// Scratch buffers shared by ntt() and operator*; each holds up to maxn * 8 entries.
int wa[maxn * 8];
int wb[maxn * 8];
int wc[maxn * 8];
// Bit-reversal permutation table filled by ntt().
int rev[maxn * 8];

int qpow(int x, int y) {
int ret = 1;
while (y) {
if (y & 1) ret = 1LL * ret * x % mod;
x = 1LL * x * x % mod;
y >>= 1;
}
return ret;
}

// In-place number-theoretic transform of a[0 .. 2^_l - 1] modulo `mod`.
// ty == 1: forward transform using roots g^((mod - 1) / len).
// ty == -1: inverse transform, including the division by 2^_l.
// Overwrites the global `rev` table with the bit-reversal permutation.
void ntt(int *a, int _l, int ty) {
    const int size = 1 << _l;

    // Reorder the input into bit-reversed index order.
    for (int i = 1; i < size; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (_l - 1));
    }
    for (int i = 0; i < size; i++) {
        if (i < rev[i]) {
            swap(a[i], a[rev[i]]);
        }
    }

    // Iterative butterflies: merge pairs of blocks of length `half`
    // into blocks of length 2 * half.
    for (int half = 1; half < size; half <<= 1) {
        const int root = qpow(g, (mod - 1) / (half << 1));
        for (int base = 0; base < size; base += half << 1) {
            int *lo = a + base;
            int *hi = lo + half;
            int w = 1;
            for (int j = 0; j < half; j++) {
                const int u = lo[j];
                const int v = 1LL * hi[j] * w % mod;
                lo[j] = (u + v) % mod;
                hi[j] = (u + mod - v) % mod;
                w = 1LL * w * root % mod;
            }
        }
    }

    if (ty != -1) {
        return;
    }
    // Inverse: scale by 1/size, then reverse a[1 .. size-1], which is
    // equivalent to having used the inverse root of unity.
    const int inv = qpow(size, mod - 2);
    for (int i = 0; i < size; i++) {
        a[i] = 1LL * a[i] * inv % mod;
    }
    reverse(a + 1, a + size);
}

// Dense polynomial: a[i] is the coefficient of x^i for 0 <= i < len.
// Each object owns its heap buffer `a`, which is released in the destructor.
// Callers may lower `len` after construction to truncate. The allocation is
// still freed in full, and copies only carry over the first `len` coefficients.
struct poly {
    int *a, len;

    // Creates a polynomial with len_ zero coefficients.
    poly (int len_ = 0) {
        len = len_;
        a = new int [len];
        for (int i = 0; i < len; i++) {
            a[i] = 0;
        }
    }

    // Deep copy, so two polys never share (and double-free) a buffer.
    poly (const poly &o) {
        len = o.len;
        a = new int [len];
        for (int i = 0; i < len; i++) {
            a[i] = o.a[i];
        }
    }

    // Copy-and-swap assignment; safe under self-assignment.
    poly &operator=(const poly &o) {
        if (this != &o) {
            poly tmp(o);
            std::swap(a, tmp.a);
            std::swap(len, tmp.len);
        }
        return *this;
    }

    // Previously every buffer was leaked; free it when the object dies.
    ~poly() {
        delete [] a;
    }
};

// Returns p1 * p2 (p1.len + p2.len - 1 coefficients) computed with NTT.
// If either factor has no coefficients, the product is the empty polynomial;
// previously this asked for a buffer of length -1.
// Uses the global scratch buffers wa / wb / wc, so it is not reentrant. The
// padded transform length must fit in those buffers.
poly operator*(const poly &p1, const poly &p2) {
    if (p1.len <= 0 || p2.len <= 0) {
        return poly(0);
    }
    poly ret(p1.len + p2.len - 1);

    // Smallest power of two that can hold the full product.
    int t = 0;
    while ((1 << t) < ret.len) {
        ++ t;
    }
    const int l = (1 << t);

    // Load zero-padded operands. wc needs no clearing: it is fully written below.
    for (int i = 0; i < l; i++) {
        wa[i] = i < p1.len ? p1.a[i] : 0;
        wb[i] = i < p2.len ? p2.a[i] : 0;
    }
    ntt(wa, t, 1);
    ntt(wb, t, 1);
    for (int i = 0; i < l; i++) {
        wc[i] = 1LL * wa[i] * wb[i] % mod;
    }
    ntt(wc, t, -1);

    for (int i = 0; i < ret.len; i++) {
        ret.a[i] = wc[i];
    }
    return ret;
}

// Returns R with p * R == 1 (mod x^{p.len}), computed by Newton iteration:
// R = R0 * (2 - p * R0), where R0 is the inverse modulo x^{ceil(len/2)}.
// Requirements: p.len >= 1 (len 0 recurses forever), and p.a[0] must be
// invertible modulo `mod` (otherwise the base case yields 0).
poly polyInv(const poly &p) {
    const int n = p.len;
    if (n == 1) {
        poly base(1);
        base.a[0] = qpow(p.a[0], mod - 2);
        return base;
    }

    // Invert the low half first.
    poly low((n + 1) >> 1);
    for (int i = 0; i < low.len; i++) {
        low.a[i] = p.a[i];
    }
    poly r0 = polyInv(low);

    // corr = 2 - p * R0, truncated to n terms.
    poly pr = r0 * p;
    poly corr(n);
    for (int i = 0; i < n; i++) {
        corr.a[i] = ((i == 0 ? 2 : 0) + mod - pr.a[i]) % mod;
    }

    poly r = r0 * corr;
    r.len = n;
    return r;
}

// Returns R with R^2 == p (mod x^{p.len}) and R(0) == 1, by Newton iteration:
// R = (R0 + p / R0) / 2, where R0 is the square root modulo x^{ceil(len/2)}.
// The base case always returns 1, so this is only correct when p.a[0] == 1.
poly polySqrt(const poly &p) {
    const int n = p.len;
    if (n == 1) {
        poly one(1);
        one.a[0] = 1;
        return one;
    }

    poly low((n + 1) >> 1);
    for (int i = 0; i < low.len; i++) {
        low.a[i] = p.a[i];
    }
    poly r0 = polySqrt(low);

    // R0 zero-padded to n terms, so its inverse is taken to full precision.
    poly ext(n);
    for (int i = 0; i < r0.len; i++) {
        ext.a[i] = r0.a[i];
    }
    poly quot = polyInv(ext) * p;

    // ext still holds R0 (padded); overwrite it in place with (R0 + p/R0) / 2.
    for (int i = 0; i < n; i++) {
        ext.a[i] = 1LL * i2 * (ext.a[i] + quot.a[i]) % mod;
    }
    return ext;
}

int _c[maxn];

int main() {
int k = 0;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
int c; scanf("%d", &c);
++ _c[c];
}
for (int i = 1; i <= 100000; i++) {
if (_c[i]) {
k = i;
break;
}
}
if (k > m) {
for (int i = 1; i <= m; i++) {
puts("0");
}
return 0;
}
poly p(2); p.a[0] = p.a[1] = 1;
poly q = p * p;
poly C(m + k + 1);
for (int i = 0; i <= min(100000, m+k); i++) {
C.a[i] = _c[i];
}
poly v0(m + k + 1); v0.a[0] = 1;
for (int i = 0; i < v0.len; i++) {
v0.a[i] = (v0.a[i] + mod - 1LL * 4 * C.a[i] % mod) % mod;
}
poly v1 = polySqrt(v0); v1.a[0] = (v1.a[0] + mod - 1) % mod;
for (int i = 0; i < v1.len; i++) {
v1.a[i] = (mod - v1.a[i]) % mod;
}
poly v2(m + k + 1);
for (int i = 0; i < C.len; i++) {
v2.a[i] = 1LL * 2 * C.a[i] % mod;
}
poly X(m+1), Y(m+1);
for (int i = 0; i <= m; i++) {
X.a[i] = v1.a[i + k];
Y.a[i] = v2.a[i + k];
}
poly res = X * polyInv(Y);
for (int i = 1; i <= m; i++) {
printf("%d\n", res.a[i]);
}
return 0;
}