[Codeforces375E] Red and Black Tree

做法

我们先不要考虑边权。

考虑直接 dp,\(dp[u][i][j]\) 表示 \(u\) 的子树中修改后有 \(i\) 个黑点,距离根最远的与子树中所有黑点距离都大于 \(x\) 的红点与根的距离为 \(j\),最小要修改几次。转移的时候需要决定距离根最近的红点,所以需要知道距离根最近的黑点的距离,但是如果我们再记一个最近黑点的话复杂度就炸了。注意到一个性质:如果根是红点,距离根最近的黑点所在的(根的儿子的)子树中一定所有红点都可以在这个子树中找到距离不超过 \(x\) 的黑点。所以只要对 \(j = 0\) 的情况额外记录下最近的黑点。这样就可以 dp 了。

现在来考虑一下边权,注意到我们只需要记一个点到根的距离,所以我们可以把每个点到根的距离先离散化一下,这样就能做带权的情况了。

内存需要卡一下。

题解竟然是对 \(500\) 级别的东西跑单纯形,不太能理解出题人的想法。

时间复杂度为 \(\mathcal O(n^3)\)。经过艰难的调试和卡常终于 A 了。(时限 1000 ms,开 Ofast 982 ms,不开 Ofast 998 ms,TAT)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int maxn = 510;
const int inf = 0x3f3f3f3f;

vector<ll> vt;
vector<int> lst[maxn];

int n, x, l[maxn], sz[maxn], col[maxn], e, cb, cr, ty;
ll dis[maxn];

struct Edge {
int v, w, x;
} E[maxn<<1];

inline void addEdge(int u, int v, int w) {
E[e].v = v, E[e].x = l[u], E[e].w = w, l[u] = e++;
}

struct dat {
vector<vector<int> > f, g;
int s;
// f : 有 i 个黑点,所有红点深度不超过 j
// g : 有 i 个黑点,没有未匹配红点,至少存在一个黑点深度不超过 j
dat(int s_) {
s = s_;
f = g = vector<vector<int> >(s + 1, vector<int>(n+2, inf));
}
};

int getIndex(ll x) {
return int (lower_bound(vt.begin(), vt.end(), x) - vt.begin() + 1);
}

int upb(ll x) {
return int (upper_bound(vt.begin(), vt.end(), x) - vt.begin() + 1);
}

inline int Min(int x, int y) {
return x < y ? x : y;
}

dat Merge(const dat &d1, const dat &d2, ll d, int u, int v) {
dat ret(d1.s + d2.s);
for (int i = 0; i < lst[v].size(); i++) lst[u].push_back(lst[v][i]);
lst[u].push_back(1);
for (int _ = 0; _ <= lst[u].size(); _++) {
int i = n+1;
if (_ < lst[u].size()) i = lst[u][_];
int rb = upb(x + 2 * d - vt[i-1]) - 1;
for (int s1 = 0; s1 <= d1.s; s1++) {
for (int s2 = 0; s2 <= d2.s; s2++) {
int s = s1 + s2;
ret.f[s][i] = Min(ret.f[s][i], d1.f[s1][i] + d2.f[s2][i]);
ret.g[s][i] = Min(ret.g[s][i], d1.g[s1][i] + d2.g[s2][n+1]);
ret.g[s][i] = Min(ret.g[s][i], d1.g[s1][n+1] + d2.g[s2][i]);
// ret.f[s][i] = Min(ret.f[s][i], d1.f[s1][i] + d2.g[s2][n+1]);
// ret.f[s][i] = Min(ret.f[s][i], d1.g[s1][n+1] + d2.f[s2][i]);
if (rb >= 0) {
ret.g[s][i] = Min(ret.g[s][i], d1.f[s1][rb] + d2.g[s2][i]);
ret.g[s][i] = Min(ret.g[s][i], d1.g[s1][i] + d2.f[s2][rb]);
}
}
}
}
for (int i = 0; i <= ret.s; i++) {
for (int j = 1; j <= n+1; j++) {
ret.g[i][j] = Min(ret.g[i][j], ret.g[i][j-1]);
}
ret.f[i][1] = Min(ret.f[i][1], ret.g[i][n+1]);
for (int j = 1; j <= n+1; j++) {
ret.f[i][j] = Min(ret.f[i][j], ret.f[i][j-1]);
}
}
lst[u].pop_back();
if (ty == 0) ret.s = min(ret.s, cb);
else ret.s = min(ret.s, cr);
return ret;
}

void dfs1(int u, int f) {
sz[u] = 1;
for (int p = l[u]; p >= 0; p = E[p].x) {
int v = E[p].v;
if (v != f) {
dis[v] = dis[u] + E[p].w;
dfs1(v, u);
sz[u] += sz[v];
}
}
}

dat dfs2(int u, int f) {
dat ret(1);
int t = getIndex(dis[u]);
lst[u].push_back(t);
if (ty == 0) {
for (int i = t; i <= n+1; i++) ret.f[0][i] = (col[u] != 0);
for (int i = t; i <= n+1; i++) ret.f[1][i] = (col[u] != 1);
for (int i = t; i <= n+1; i++) ret.g[1][i] = (col[u] != 1);
} else {
for (int i = t; i <= n+1; i++) ret.f[1][i] = (col[u] != 0);
for (int i = t; i <= n+1; i++) ret.f[0][i] = (col[u] != 1);
for (int i = t; i <= n+1; i++) ret.g[0][i] = (col[u] != 1);
}
for (int i = 0; i <= ret.s; i++) {
for (int j = 1; j <= n+1; j++) {
ret.g[i][j] = Min(ret.g[i][j], ret.g[i][j-1]);
}
ret.f[i][1] = Min(ret.f[i][1], ret.g[i][n+1]);
for (int j = 1; j <= n+1; j++) {
ret.f[i][j] = Min(ret.f[i][j], ret.f[i][j-1]);
}
}
for (int p = l[u]; p >= 0; p = E[p].x) {
int v = E[p].v;
if (v != f) {
ret = Merge(ret, dfs2(v, u), dis[u], u, v);
}
}
return ret;
}

int main() {
// freopen("data.in", "r", stdin);
memset(l, -1, sizeof(l));
scanf("%d%d", &n, &x);
for (int i = 1; i <= n; i++) {
scanf("%d", &col[i]);
if (col[i]) ++ cb;
}
for (int i = 1; i < n; i++) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
addEdge(u, v, w), addEdge(v, u, w);
}
dfs1(1, 0);
cr = n - cb;
if (cr < cb) ty = 1;
for (int i = 1; i <= n; i++) vt.push_back(dis[i]);
vt.push_back(ll(1e18));
sort(vt.begin(), vt.end());
dat res = dfs2(1, 0);
int ans = inf;
if (ty == 0) for (int i = 0; i <= n+1; i++) ans = min(ans, res.g[cb][i]);
else for (int i = 0; i <= n+1; i++) ans = min(ans, res.g[cr][i]);
if (ans < inf) printf("%d\n", ans / 2);
else puts("-1");
return 0;
}