bestcoder

本文从WordPress迁移而来, 查看全部WordPress迁移文章

题目链接:DZY Loves Connecting

题意 & 题解 & 代码:

比赛的时候想到了,对于每个节点,算出这个节点及其子树下的集合和;然后把每个节点算出来的答案相加,就是结果。

后来跑偏了。

后来才意识到,没有跑偏,思路是对的,但没时间了。

回到最初的问题,对于每个节点,算出这个节点及其子树下的集合和

需要两个信息

以v为根的子树,集合的种类个数,dpc[v],集合和dps[v]

u是v的父亲,u利用所有的儿子节点能推导出 dpc[u], dps[u]

dpc[u] = (dpc[v1]+1) (dpc[v2]+1) (dpc[v3]+1) ….. * (dpc[vn]+1); 不解释,你懂的

dps[u] = (dps[v1] (dpc[u]/(dpc[v1]+1))) + (dps[v2] (dpc[u]/(dpc[v2]+1))) + (dps[v3] (dpc[u]/(dpc[v3]+1))) …… (dps[vn] * (dpc[u]/(dps[vn]+1)))

还没完,还有u自己:dps[u] += dpc[u] * 1;不解释,你懂的

然后很快就会想到费马小定理,就去求逆元了,然后就可能出现这样的WA代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#pragma comment(linker, "/STACK:102400000,102400000")
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <utility>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <set>
#include <algorithm>

using namespace std;

#define pb push_back
#define mp make_pair

const int INF = 0x3f3f3f3f;
typedef long long LL;
typedef unsigned long long ULL;
typedef unsigned int uint;
typedef pair<int, int> pii;
typedef pair<LL, LL> pll;

const int MAXN = 200010;
const int MOD = 1000000007LL;

bool vis[MAXN];
int head[MAXN], tot, n;
struct edge {
int u, v, next;
edge() {}
edge(int uu, int vv, int nx): u(uu), v(vv), next(nx) {}
}e[MAXN<<1];
LL sum;

LL dps[MAXN], dpc[MAXN];

inline void add(int u, int v) {
e[tot] = edge(u, v, head[u]);
head[u] = tot++;
}

LL pow(LL a, LL m) {
if (m == 0LL) {
return 1LL;
} else if (m == 1LL) {
return a;
}
LL temp = pow(a, m >> 1LL);
temp = (temp * temp) % MOD;
if (m & 1LL) {
temp = (temp * a) % MOD;
}
return temp;
}

void dfs(int u, int p) {
vis[u] = true;

dpc[u] = 1LL;
for (int k = head[u]; k != -1; k = e[k].next) {
int v = e[k].v;
if (vis[v]) continue;
dfs(v, u);
dpc[u] = (dpc[u] * (dpc[v] + 1LL)) % MOD;
}

dps[u] = 0LL;
for (int k = head[u]; k != -1; k = e[k].next) {
int v = e[k].v;
if (v == p) continue;
LL re = (dpc[u] * pow(dpc[v]+1LL, MOD - 2LL)) % MOD;
dps[u] = (dps[u] + dps[v] * re) % MOD;
}
dps[u] = (dps[u] + dpc[u]) % MOD;
sum = (sum + dps[u]) % MOD;
//printf("u:%d, dps:%lld, dpc:%lld, sum:%lld\n", u, dps[u], dpc[u], sum);
}

int main() {
int T;
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
memset(head, -1, sizeof head);
tot = 0;
for (int i = 2; i <= n; i++) {
int v;
scanf("%d", &v);
add(i, v);
add(v, i);
}
memset(vis, false, sizeof vis);
LL ans = 0LL;
for (int i = 1; i <= n; i++) {
if (vis[i]) continue;
sum = 0LL;
dfs(i, -1);
ans = (ans + sum) % MOD;
}
cout << ans << endl;
}
return 0;
}

这样会WA的原因是,中间有+1的运算,+1之后,和10^9+7不一定再互质了

所以要避免逆元的计算,全部用乘法,然后就学了点新姿势,非常有用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#pragma comment(linker, "/STACK:102400000,102400000")
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <utility>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <set>
#include <algorithm>

using namespace std;

#define pb push_back
#define mp make_pair

const int INF = 0x3f3f3f3f;
typedef long long LL;
typedef unsigned long long ULL;
typedef unsigned int uint;
typedef pair<int, int> pii;
typedef pair<LL, LL> pll;

const int MAXN = 200010;
const int MOD = 1000000007LL;

bool vis[MAXN];
int head[MAXN], tot, n;
struct edge {
int u, v, next;
edge() {}
edge(int uu, int vv, int nx): u(uu), v(vv), next(nx) {}
}e[MAXN<<1];
LL sum;

LL dps[MAXN], dpc[MAXN];

inline void add(int u, int v) {
e[tot] = edge(u, v, head[u]);
head[u] = tot++;
}

void dfs(int u) {
vis[u] = true;
dpc[u] = 1LL;
dps[u] = 1LL;
for (int k = head[u]; k != -1; k = e[k].next) {
int v = e[k].v;
if (vis[v]) continue;
dfs(v);
dps[u] = (dps[u] * (dpc[v] + 1LL) + dps[v] * dpc[u]) % MOD;
dpc[u] = (dpc[u] * (dpc[v] + 1LL)) % MOD;
}

sum = (sum + dps[u]) % MOD;
//printf("u:%d, dps:%lld, dpc:%lld, sum:%lld\n", u, dps[u], dpc[u], sum);
}

int main() {
int T;
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
memset(head, -1, sizeof head);
tot = 0;
for (int i = 2; i <= n; i++) {
int v;
scanf("%d", &v);
add(i, v);
add(v, i);
}
memset(vis, false, sizeof vis);
LL ans = 0LL;
for (int i = 1; i <= n; i++) {
if (vis[i]) continue;
sum = 0LL;
dfs(i);
ans = (ans + sum) % MOD;
}
cout << ans << endl;
}
return 0;
}