hdu 4705 Y

本文从WordPress迁移而来, 查看全部WordPress迁移文章

树dp

3个点,如果不能同时被一条简单路径通过,则计数一次

统计树中有多少组这样的点集合

定1为根

一开始的做法是,对于一个点,

统计cnt[u],表示以u为根的子树有多少个点

统计back[u],u点上方的点数

那么以u为根看这棵树,这棵树由几个分支组成

v1,v2,v3…….fa(v是u的儿子)

每个分支的点数为cnt[v1],cnt[v2],cnt[v3]……back[fa]

那么由这个点产生的个数

就是任选3个分支,只能在每个分支里面拿1个点,拿出3个点后,这3个点一定是合法的答案

这样做,一定是不重不漏的,但这样做,复杂度太高了

1
2
3
4
sample 
点u四周的分支及其点数
k1,k2,k3,k4,
k1*k2*k3 + k1*k2*k4 + k1*k3*k4 + k2*k3*k4

这个问题,从正面求是不行的,但是从反面求,可以变为O(n)

在树中任选3个点,只有2种情况

  1. 3个点同时在一条简单路径上,3个点不同时在一条简单路径上,这是对立事件,不会出现第3种情况
  2. 3个点不同时在一条简单路径的方案数 = 任选3个点的方案数 - 3个点同时在一条简单路径上的方案数

要算算3个点在一个简单路径上是比较好算的

同样以点u来看,必选点u,然后在一个分支选一个点,在剩下的分支中,任选2个点(2个点可以来自于同一个分支),这3点必定共线
当然这样选会重复,所以答案要/2

这样选出来的时候3点共线的,C(n,3)- ans 才是3点不共线的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#pragma comment(linker, "/STACK:16777216")
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define N 100010
#define LL long long
#define cl(xx,yy) memset((xx),(yy),sizeof((xx)))
int n,tot,head[N];
struct edge{
int u,v,next;
edge(){}
edge(int __u,int __v){
u = __u; v = __v;
}
}e[2*N];
int cnt[N],back[N],p[N];
LL ans;

inline void add(int u,int v){
e[tot] = edge(u,v); e[tot].next = head[u]; head[u] = tot++;
}

void dfs1(int u ,int fa){
cnt[u] = 1; p[u] = fa;
for(int k = head[u]; k != -1; k = e[k].next){
int v = e[k].v;
if(v == fa) continue;
dfs1(v,u);
cnt[u] += cnt[v];
}
}

void dfs2(int u,int fa){
back[u] = cnt[1] - cnt[u];
for(int k = head[u]; k != -1; k = e[k].next){
int v = e[k].v;
if(v == fa) continue;
dfs2(v,u);
}
}

inline LL C(LL m){
return m * (m - 1) * (m - 2) / 6;
}

int main(){
while(scanf("%d",&n)!=EOF){
cl(head,-1); tot = 0;
for(int i = 1; i < n; i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v); add(v,u);
}
cl(cnt,0); cl(back,0);
dfs1(1,-1);
dfs2(1,-1);
ans = 0;
for(int u = 1; u <= n; u++){
LL res,__res;
for(int k = head[u]; k != -1; k = e[k].next){
int v = e[k].v;
if(v == p[u]) res = back[u];
else res = cnt[v];
__res = n - res - 1;
ans += res * __res;
}
}
ans /= 2;
ans = C(n) - ans;
printf("%I64d\n",ans);
}
return 0;
}