[模板] NTT模板

一个普通的 NTT（数论变换）多项式乘法模板，模数为 998244353，原根为 3。代码非常丑。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

// NTT-friendly prime: 998244353 = 119 * 2^23 + 1, so every power-of-two
// transform length up to 2^23 divides mod - 1.
const int mod = 998244353;
// Base capacity; each work array below holds maxn * 4 entries.
const int maxn = 100010;
// Generator used to build the roots of unity in ntt()
// (3 is a primitive root of 998244353).
const int g = 3;

// Degrees of the two input polynomials (read in main).
int n;
int m;
// Bit-reversal permutation for the current transform size (filled by calrev).
int rev[maxn * 4];
// Working buffers for polynomial A, polynomial B and their product C.
int w_a[maxn * 4];
int w_b[maxn * 4];
int w_c[maxn * 4];

// Binary exponentiation: returns x^y reduced modulo `mod`.
// Intended for 0 <= x < mod and y >= 0 (a negative y never reaches 0 under
// arithmetic right shift, so the loop would not terminate).
int qpow(int x, int y) {
	long long result = 1;
	long long base = x;
	for (; y != 0; y >>= 1) {
		// Multiply in the current power of x whenever the low bit of y is set.
		if (y & 1) result = result * base % mod;
		base = base * base % mod;
	}
	return static_cast<int>(result);
}

// Fills rev[0 .. 2^l - 1] with the bit-reversal permutation of l-bit indices:
// rev[i] is i with its low l bits written in reverse order.
void calrev(int l) {
	rev[0] = 0;
	for (int idx = 1; idx < (1 << l); ++idx) {
		// Reversing i = reversing (i >> 1) and shifting it down one place,
		// then moving i's lowest bit up to position l-1.
		const int top = (idx & 1) ? (1 << (l - 1)) : 0;
		rev[idx] = (rev[idx >> 1] >> 1) | top;
	}
}

// In-place iterative radix-2 number-theoretic transform of a[0 .. 2^t - 1]
// modulo `mod`.
//   a  : array transformed in place; the modular arithmetic below assumes its
//        entries are already in [0, mod).
//   t  : log2 of the transform length; rev[] must already hold the
//        bit-reversal permutation for this t (see calrev).
//   ty : -1 computes the inverse transform (including the 1/len scaling);
//        any other value computes the forward transform.
void ntt(int *a, int t, int ty) {
	const int len = 1 << t;

	// Put the input in bit-reversed order so the butterflies can work in place.
	for (int i = 0; i < len; ++i) {
		const int j = rev[i];
		if (j > i) std::swap(a[i], a[j]);
	}

	// Combine pairs of size-`half` transforms into size-`step` transforms.
	for (int half = 1; half < len; half <<= 1) {
		const int step = half << 1;
		// Principal step-th root of unity.
		const int root = qpow(g, (mod - 1) / step);
		for (int base = 0; base < len; base += step) {
			int w = 1;
			for (int k = 0; k < half; ++k) {
				int *lo = a + base + k;
				int *hi = lo + half;
				const int u = *lo;
				const int v = 1LL * (*hi) * w % mod;
				*lo = (u + v) % mod;
				*hi = (u - v + mod) % mod;
				w = 1LL * w * root % mod;
			}
		}
	}

	if (ty == -1) {
		// The forward transform with outputs 1..len-1 reversed equals the
		// transform with the inverse root; finish by scaling by len^{-1}
		// (Fermat inverse, since mod is prime).
		std::reverse(a + 1, a + len);
		const int inv_len = qpow(len, mod - 2);
		for (int i = 0; i < len; ++i) {
			a[i] = 1LL * a[i] * inv_len % mod;
		}
	}
}

int main() {
scanf("%d%d", &n, &m);
for (int i = 0; i <= n; i++)
scanf("%d", &w_a[i]);
for (int i = 0; i <= m; i++)
scanf("%d", &w_b[i]);
int l = 0;
while ((1<<l) < n+m+1) l ++;
calrev(l);
ntt(w_a, l, 1); ntt(w_b, l, 1);
for (int i = 0; i < (1<<l); i++)
w_c[i] = 1LL*w_a[i]*w_b[i]%mod;
ntt(w_c, l, -1);
for (int i = 0; i <= n+m; i++) {
printf("%d ", w_c[i]);
}
printf("\n");
return 0;
}