[Codeforces1063F] String Journey

做法

为了方便描述,先把题目中给定串翻转一下,变成要找一个每个串是后一个串子串的序列。

下面我们认为,这个序列中的元素是一个(包含位置的)子串。同一字符串出现在不同位置被认为是不同的。

假设你有一个序列,你一定可以在不改变长度的情况下把它调整成第 \(i\) 个子串长为 \(i\)。所以我们只考虑这类序列。

称一个子串 \(s[l..r]\) 是可达的,当且仅当存在一个这类序列以 \(s[l..r]\) 结尾。题目就是要求最长可达的子串。

注意到如果 \(l < r\)\(s[l..r]\) 可达,那么 \(s[l..r-1]\)\(s[l+1..r]\) 一定可达。因此我们只需要对每个 \(l\) 求出最大的 \(r\) 使得 \(s[l..r]\) 可达,设对于 \(l=i\) 最大的 \(r\)\(f_i\)。对 \(f\) 作 dp,显然 \(f_i \ge f_{i-1}\),所以每次先令 \(f_i = f_{i-1}\),然后检查一下 \(f_i\) 能不能增大,如果能增大就一直增大到不能增大就行了。检验只需要在 SAM 上找到要检验的串在某个位置之前的最后出现位置就好,这很容易用 SAM + 线段树实现(也可以用其他方式)。由于总共只会增大 \(\mathcal O(n)\) 次,这个算法的时间复杂度为 \(\mathcal O(n \log n)\)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <bits/stdc++.h>

using namespace std;

const int maxn = 500010;

int n, ans, f[maxn];
char s[maxn];
int last, tot, ch[maxn*2][26], par[maxn*2], len[maxn*2], ind[maxn];
int tag[maxn*2], fa[maxn*2][20];
int T_tot;
int ls[maxn*40], rs[maxn*40], sum[maxn*40], trt[maxn*2];
vector<int> son[maxn*2];

void upd(int p, int l, int r, int &rt) {
if (!rt) rt = ++ T_tot;
++ sum[rt];
if (l == r) return;
int m = (l + r) >> 1;
if (p <= m) upd(p, l, m, ls[rt]);
else upd(p, m+1, r, rs[rt]);
}

int Merge(int x, int y) {
if (!x || !y) return x + y;
int ret = ++ T_tot;
sum[ret] = sum[x] + sum[y];
ls[ret] = Merge(ls[x], ls[y]);
rs[ret] = Merge(rs[x], rs[y]);
return ret;
}

void addchar(int c, int l) {
int np = ++ tot; len[np] = l;
while (last && !ch[last][c]) {ch[last][c] = np; last = par[last];}
if (!last) par[np] = 1;
else {
int q = ch[last][c];
if (len[q] == len[last] + 1) par[np] = q;
else {
int nq = ++ tot; par[nq] = par[q], len[nq] = len[last] + 1;
memcpy(ch[nq], ch[q], sizeof(ch[nq]));
par[q] = par[np] = nq;
while (last && ch[last][c] == q) {ch[last][c] = nq; last = par[last];}
}
}
tag[np] = l;
last = np;
}

int qrys(int p, int l, int r, int rt) {
if (!rt || p <= 0) return 0;
if (r <= p) return sum[rt];
int m = (l + r) >> 1;
int ret = 0;
ret += qrys(p, l, m, ls[rt]);
if (p > m) ret += qrys(p, m+1, r, rs[rt]);
return ret;
}

int qryk(int k, int l, int r, int rt) {
if (!rt || k <= 0 || k > sum[rt]) return 0;
if (l == r) return l;
int m = (l + r) >> 1;
if (sum[ls[rt]] >= k) return qryk(k, l, m, ls[rt]);
else return qryk(k-sum[ls[rt]], m+1, r, rs[rt]);
}

void dfs(int u) {
fa[u][0] = par[u];
for (int i = 1; i < 20; i++) fa[u][i] = fa[fa[u][i-1]][i-1];
if (tag[u]) upd(tag[u], 1, n, trt[u]);
for (int i = 0; i < son[u].size(); i++) {
int v = son[u][i];
dfs(v);
trt[u] = Merge(trt[u], trt[v]);
}
}

// 有没有右端点在 x 或之前的
int check(int l, int r, int x) {
int u = ind[r];
for (int i = 19; i >= 0; i--) {
if (len[fa[u][i]] >= r-l+1) {
u = fa[u][i];
}
}
int s = qrys(x, 1, n, trt[u]);
int p = qryk(s, 1, n, trt[u]);
if (!p) return 0;
int t = p - (r-l+1) + 1;
return f[t] >= p;
}

int main() {
scanf("%d", &n);
scanf("%s", s+1);
reverse(s + 1, s + n + 1);
last = tot = 1;
for (int i = 1; i <= n; i++) {addchar(s[i] - 'a', i); ind[i] = last;}
for (int i = 2; i <= tot; i++) son[par[i]].push_back(i);
dfs(1);
for (int l = 1; l <= n; l++) {
f[l] = max(l, f[l-1]);
while (f[l] + 1 <= n) {
if (check(l+1, f[l]+1, l-1) || check(l, f[l], l-1)) {
++ f[l];
} else break;
}
ans = max(ans, f[l] - l + 1);
}
printf("%d\n", ans);
return 0;
}