zoj 3188 Treeland Exhibition

本文从WordPress迁移而来, 查看全部WordPress迁移文章

树dp

题意:给一个n个点的无根树,给一个长度L,要找出一条长度不超过L的链(即点数不超过L+1),然后其他点要汇聚到这条链上(每个点选链上离它最近的点),求其余点走到链上的距离和。问要怎么选这条链,然后使得这个距离和最小。

这个题目挺难的,难的部分个人觉得是在实现部分,想到了也不好实现,另外,怎么设置状态也是关键的,状态设置好了,复杂度能降维

分析:定0为根,然后,既然是要我们找一条长度不超过L的链,那怎么找?可以枚举树中的每个点u,假设这个点在链上,并且这条链是向下垂的,也就是这条链在u子树下,那么这条链的形状可以是”|”或者”^”,即可以看成又1条组成的或者2条组成。这样枚举,可以做到不遗漏所有情况,也不做多余的枚举。

假设我们已经得到了一条链,要怎么计算出其余点到这条链上的距离和呢?这需要dp记录一些信息

downcost[u],u下面的子树全部回到u的花费

upcost[u],u上方的所有点全部回到u的花费

cnt[u],u下方有多少个点

up[u],u上方有多少个点

downcost[u] = sigma downcost[v] + (cnt[u]-1);

upcost[u] = upcost[fa] + up[fa] + downcost[fa] - (cnt[fa]-1) - downcost[u] + 2*(cnt[fa]-1-cnt[u]) + 1;

cnt[u] = sigma cnt[v] + 1;

up[u] = n - cnt[u];

得到这些信息是为了dp转移的时候可以用上

定义dp[u][m]:选了点u,且在u子树下找到一条长度不大于m的链,其他点汇聚过来的最小花费,注意这里的链一定是直线”|”而不是叉链”^”

dp的定义可以看出,它是可以继承上一次的答案的,因为是“不大于m的链”

dp[u][m] = min(dp[u][m-1] , downcost[u] - downcost[v] - cnt[v] + dp[v][i-1]);

然后就是怎么更新答案了

最后要找的链如果是”|”的,可以直接由dp[u][m]的信息得到

如果是”^”的话,需要两条儿子链组成

假设链上的点数不超过m,那么u用掉一个点,下面dp[v1][k],dp[v2][m-1-k],这样就形成了”^”

这里就需要枚举两个儿子v1和v2?不用了,我们限定了顺序的话,一直枚举一直更新,只需要枚举一个儿子即可

选了v1和v2的话,花费就是

upcost[u] + downcost[u] + dp[v1][i-1] - downcost[v1] - cnt[v1] + dp[v2][i-1] - downcost[v2] - cnt[v2]

这里的实现有些讲究,可以降低复杂度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define N 10010
#define INF 100000000
#define M 110
#define cl(xx,yy) memset((xx),(yy),sizeof((xx)))

int n,m,tot,head[N];
struct edge{
int u,v,next;
}e[2*N];
int cnt[N],up[N],downcost[N],upcost[N];
int dp[N][M],id[N][M];
int ANS;

inline void add(int u,int v){
e[tot].u = u; e[tot].v = v;
e[tot].next = head[u]; head[u] = tot++;
}

void updata(int u,int fa){
dp[u][1] = downcost[u];
for(int i = 2; i <= m; i++){
dp[u][i] = dp[u][i-1];
for(int k = head[u]; k != -1; k = e[k].next){
int v = e[k].v;
if(v == fa) continue;
int res = downcost[u] - downcost[v] - cnt[v] + dp[v][i-1];
dp[u][i] = min(dp[u][i] , res);
}
}
}

void dfs1(int u,int fa){
cnt[u] = 1; downcost[u] = upcost[u] = 0;
for(int k = head[u]; k != -1; k = e[k].next){
int v = e[k].v;
if(v == fa) continue;
dfs1(v,u);
cnt[u] += cnt[v];
downcost[u] += downcost[v];
}
downcost[u] += (cnt[u]-1);
up[u] = n - cnt[u];
updata(u,fa);
}

void dfs2(int u,int fa){
if(fa == -1) upcost[u] = 0;
else{
upcost[u] = upcost[fa] + up[fa] +
downcost[fa] - (cnt[fa]-1) - downcost[u] +
2*(cnt[fa]-1-cnt[u]) + 1;
}
int TEMP,Min[M]; //dp[v][i] - downcost[v] - cnt[v];
fill(Min,Min+m+5,INF);
TEMP = INF;
for(int k = head[u]; k != -1; k = e[k].next){
int v = e[k].v;
if(v == fa) continue;
dfs2(v,u);
for(int i = 2; i <= m; i++){
int j = m - i;
TEMP = min(TEMP , downcost[u] + Min[j] + dp[v][i-1] - downcost[v] - cnt[v]);
}

for(int i = 1; i <= m; i++){
Min[i] = min(Min[i] , dp[v][i] - downcost[v] - cnt[v]);
if(i > 1) Min[i] = min(Min[i] , Min[i-1]);
}
}
ANS = min(ANS , TEMP + upcost[u]);
ANS = min(ANS , dp[u][m] + upcost[u]);
}

int main(){
while(scanf("%d%d",&n,&m)!=EOF){
if(!n && !m) break;
tot = 0; cl(head,-1); m++;
for(int i = 1; i < n; i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v); add(v,u);
}
for(int i = 0; i < n; i++)
for(int j = 0; j <= m; j++){
dp[i][j] = dp[i][j] = INF;
id[i][j] = id[i][j] = -1;
}
ANS = INF;
dfs1(0,-1);
dfs2(0,-1);
printf("%d\n",ANS);

}
return 0;
}