[Codeforces1120D] Power Tree

题目链接: https://codeforces.com/contest/1120/problem/D

题目大意

给定一棵 \(n\) 个点,以 \(1\) 为根的树,第 \(i\) 个点有一个价格 \(c_i\)
你要选一些点,然后另一个人会给每个叶节点设置一个数字。你只能对已经选择的点做子树加任意数字的操作。
你要通过若干次操作把所有叶子节点上的数变为 \(0\)
问:要确保另一个人无论怎么操作你都能把叶子上的数变成 \(0\),你选择的点价格总和最小是什么?有哪些点被至少一种最小价格的方案包含?
\(n \le 200000\)\(0\le c_i \le 10^9\)

解法

这个问题可以看成,你要选一些点作为未知数,每个叶子的限制构成了一个方程,你要让这个线性方程组满足,无论常数项取值是什么,这个线性方程组都有解。
因为价格是非负整数,如果系数矩阵的列向量组线性相关,必可以删除某一列,即可以少选一个点,这样一定不会变的更差,所以系数矩阵的列向量组线性无关。因为无论常数项的取值是什么,方程组都要有解,所以行数等于列数。
结论 设叶节点个数为 \(x\),一个包含恰好 \(x\) 个点的选取方案能保证可以将所有叶子上的数字变成 \(0\) 的充要条件是,任意两个叶子到根的路径上,存在被选取的点,且深度最大的被选取点不同。
证明
必要性显然。
充分性:把方程组的未知数按任意一种 \(dfs\) 序排列,系数矩阵中,每一行的第一个非 \(0\) 元素所在的列都不同,所以系数矩阵的列向量组线性无关,又因为行数等于列数,所以无论常数项取值是什么,方程组都有解。
状态 \(dp[i][j][k]\)
\(i\) 表示只考虑以 \(i\) 为根的子树。
\(j\) 表示 \(i\) 是否被选择。
\(k\) 表示是否存在一个叶子,他到 \(i\) 的路径中没有点被选择。
\(dp\) 数组中存的是,满足任意两个叶子到 \(i\) 的路径上深度最大的祖先不同(如果不存在则看作 \(0\))的最小价格和。
\(dp\) 一下,然后记录哪些转移可以取到最优值,最后从根开始 \(dfs\) 一遍求方案即可。

代码

(巨丑)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>

using namespace std;

typedef long long ll;

const int maxn = 200010;
const ll inf = 1e18;

int c[maxn], vis[4*maxn], n;
ll dp[maxn][2][2];
vector<int> tree[maxn], tran[maxn*4];

void dfs1(int u, int f) {
dp[u][0][0] = dp[u][0][1] = dp[u][1][0] = dp[u][1][1] = inf;
for (int i = 0; i < tree[u].size(); i++) {
int v = tree[u][i];
if (v != f) {
dfs1(v, u);
}
}
if (u == 1 || tree[u].size() > 1) {
ll s = 0, mn = inf, cm = inf;
for (int i = 0; i < tree[u].size(); i++) {
int v = tree[u][i];
if (v != f) {
s += min(dp[v][0][0], dp[v][1][0]);
ll t = dp[v][0][1] - min(dp[v][0][0], dp[v][1][0]);
if (t <= mn) {
cm = mn;
mn = t;
} else if (t < cm)
cm = t;
}
}
dp[u][0][0] = s;
dp[u][0][1] = s+mn;
dp[u][1][0] = min(s+c[u]+mn, s+c[u]);
for (int i = 0; i < tree[u].size(); i++) {
int v = tree[u][i];
if (v != f) {
if (dp[v][0][0] < dp[v][1][0]) {
tran[u*4+2*0+0].push_back(v*4+2*0+0);
if (mn >= 0) tran[u*4+2*1+0].push_back(v*4+2*0+0);
} else if (dp[v][0][0] > dp[v][1][0]) {
tran[u*4+2*0+0].push_back(v*4+2*1+0);
if (mn >= 0) tran[u*4+2*1+0].push_back(v*4+2*1+0);
} else {
tran[u*4+2*0+0].push_back(v*4+2*0+0);
tran[u*4+2*0+0].push_back(v*4+2*1+0);
if (mn >= 0) tran[u*4+2*1+0].push_back(v*4+2*0+0);
if (mn >= 0) tran[u*4+2*1+0].push_back(v*4+2*1+0);
}
ll t = dp[v][0][1] - min(dp[v][0][0], dp[v][1][0]);
if (t == mn) {
if (mn == cm) {
if (dp[v][0][0] < dp[v][1][0]) {
tran[u*4+2*0+1].push_back(v*4+2*0+0);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*0+0);
} else if (dp[v][0][0] > dp[v][1][0]) {
tran[u*4+2*0+1].push_back(v*4+2*1+0);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*1+0);
} else {
tran[u*4+2*0+1].push_back(v*4+2*0+0);
tran[u*4+2*0+1].push_back(v*4+2*1+0);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*0+0);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*1+0);
}
}
tran[u*4+2*0+1].push_back(v*4+2*0+1);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*0+1);
} else {
if (dp[v][0][0] < dp[v][1][0]) {
tran[u*4+2*0+1].push_back(v*4+2*0+0);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*0+0);
} else if (dp[v][0][0] > dp[v][1][0]) {
tran[u*4+2*0+1].push_back(v*4+2*1+0);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*1+0);
} else {
tran[u*4+2*0+1].push_back(v*4+2*0+0);
tran[u*4+2*0+1].push_back(v*4+2*1+0);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*0+0);
if (mn <= 0) tran[u*4+2*1+0].push_back(v*4+2*1+0);
}
}
}
}
} else {
dp[u][1][0] = c[u];
dp[u][0][1] = 0;
}
}

void dfs2(int u) {
vis[u] = 1;
for (int i = 0; i < tran[u].size(); i++) {
int v = tran[u][i];
if (!vis[v])
dfs2(v);
}
}

int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &c[i]);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
tree[u].push_back(v);
tree[v].push_back(u);
}
dfs1(1, 0);
int cnt = 0;
ll ans = min(dp[1][0][0], dp[1][1][0]);
printf("%lld ", ans);
if (dp[1][0][0] == ans) dfs2(4*1+2*0+0);
if (dp[1][1][0] == ans) dfs2(4*1+2*1+0);
for (int i = 1; i <= n; i++) {
if (vis[4*i+2*1+0] || vis[4*i+2*1+1]) {
cnt ++;
}
}
printf("%d\n", cnt);
for (int i = 1; i <= n; i++) {
if (vis[4*i+2*1+0] || vis[4*i+2*1+1]) {
printf("%d ", i);
}
}
printf("\n");
return 0;
}