0%

点分治模板

模板题:https://www.luogu.org/problemnew/show/P3806

给定一棵有n个点的树,m次询问树上距离为k的点对是否存在。k很大,m较小。

1.暴力做法,可能变成n^2 (756ms)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include<iostream>
#include<cstdio>
using namespace std;
const int maxK=10000001;
int head[10005],vet[20005],nxt[20005],w[20005],num;
int sz[10005],mxs[10005],root;
int d[10005],top,found[10000005];
bool vis[10005];
int n,m;
void addedge(int x,int y,int c) {
num++;vet[num]=y;nxt[num]=head[x];w[num]=c;head[x]=num;
num++;vet[num]=x;nxt[num]=head[y];w[num]=c;head[y]=num;
}
void getroot(int x,int ff,int tt) {
sz[x]=1;mxs[x]=0;
for (int e=head[x];e;e=nxt[e]) {
int v=vet[e];
if (v==ff || vis[v]) continue;
getroot(v,x,tt);
sz[x]+=sz[v];
mxs[x]=max(mxs[x],sz[v]);
}
mxs[x]=max(mxs[x],tt-sz[x]);
if (mxs[x]<mxs[root]) root=x;
}
void getdepth(int x,int ff,int depth) {
d[++top]=depth;
for (int e=head[x];e;e=nxt[e]) {
if (vet[e]==ff || vis[vet[e]]) continue;
getdepth(vet[e],x,depth+w[e]);
}
}
void cal(int x,int typ,int val) {
top=0;
getdepth(x,0,0);
for (int i=1;i<=top;i++)
for (int j=1;j<=top;j++) {
if (typ && d[i]+d[j]<=maxK) found[d[i]+d[j]]++;
else if (d[i]+d[j]+val<=maxK) found[d[i]+d[j]+val]--;
}
}
void solve(int x,int tot) {
cal(x,1,0);
vis[x]=true;
for (int e=head[x];e;e=nxt[e]) {
if (!vis[vet[e]]) {
cal(vet[e],0,w[e]*2);//刷答案的地方视具体题目而定
root=0;
int tt=sz[vet[e]];
if (tt>sz[x]) tt=tot-sz[x];//主流写法没有这一句。这样写保证不退化,运行速度略微提升。
getroot(vet[e],x,tt);
solve(root,tt);
}
}
}
int main() {
scanf("%d%d",&n,&m);
for (int i=1;i<n;i++) {
int x,y,c;
scanf("%d%d%d",&x,&y,&c);
addedge(x,y,c);
}
mxs[0]=(n<<1);root=0;
getroot(1,0,n);
solve(root,n);
while (m--) {
int x;
scanf("%d",&x);
if (found[x]) printf("AYE\n");
else printf("NAY\n");
}
return 0;
}

其实嘛
我觉得常规写法子树的总节点数tot有问题
如果子树是上一次root的父亲
size[to]是不能直接用的,需要-size[oldroot]
实验证明主流解法的确分得不太均匀
不过因为众多子树中只有一个会出错,问题也不大

另附主流写法核心代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void getroot(int x,int ff) {
sz[x]=1;mxs[x]=0;
for (int e=head[x];e;e=nxt[e]) {
int v=vet[e];
if (v==ff || vis[v]) continue;
getroot(v,x);
sz[x]+=sz[v];
mxs[x]=max(mxs[x],sz[v]);
}
mxs[x]=max(mxs[x],tt-sz[x]);
if (mxs[x]<mxs[root]) root=x;
}
void solve(int x) {
cal(x,1,0);
vis[x]=true;
for (int e=head[x];e;e=nxt[e]) {
if (!vis[vet[e]]) {
cal(vet[e],0,w[e]*2);
root=0;
tt=sz[vet[e]];//?
getroot(vet[e],x);
solve(root);
}
}
}

2.优化,枚举询问 (16ms)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include<iostream>
#include<cstdio>
using namespace std;
const int maxK=10000001;
int head[10005],vet[20005],nxt[20005],w[20005],num;
int sz[10005],mxs[10005],root;
int d[10005],q[105],xg[10005];
bool vis[10005],found[10000005],ans[105];
int n,m;
inline int read() {
char ch=0;int sum=0;
while (ch>'9'||ch<'0') ch=getchar();
while (ch>='0'&&ch<='9') sum=sum*10+ch-'0',ch=getchar();
return sum;
}
inline void addedge(int x,int y,int c) {
num++;vet[num]=y;nxt[num]=head[x];w[num]=c;head[x]=num;
num++;vet[num]=x;nxt[num]=head[y];w[num]=c;head[y]=num;
}
void getroot(int x,int ff,int tt) {
sz[x]=1;mxs[x]=0;
for (int e=head[x];e;e=nxt[e]) {
int v=vet[e];
if (v==ff || vis[v]) continue;
getroot(v,x,tt);
sz[x]+=sz[v];
mxs[x]=max(mxs[x],sz[v]);
}
mxs[x]=max(mxs[x],tt-sz[x]);
if (mxs[x]<mxs[root]) root=x;
}
void getdepth(int x,int ff,int depth) {
d[++d[0]]=depth;
for (int e=head[x];e;e=nxt[e]) {
if (vet[e]==ff || vis[vet[e]]) continue;
getdepth(vet[e],x,depth+w[e]);
}
}
inline void calc(int x,int val) {
d[0]=0;
getdepth(x,0,val);
for (int i=1;i<=d[0];i++)
for (int j=1;j<=m;j++)
if (q[j]>=d[i]) ans[j]|=found[q[j]-d[i]];
for (int i=1;i<=d[0];i++) found[d[i]]=1,xg[++xg[0]]=d[i];
}
void solve(int x,int tot) {
vis[x]=true;xg[0]=0;found[0]=1;
for (int e=head[x];e;e=nxt[e]) {
if (!vis[vet[e]]) calc(vet[e],w[e]);
}
for (int i=1;i<=xg[0];i++) found[xg[i]]=0;
for (int e=head[x];e;e=nxt[e]) {
if (!vis[vet[e]]) {
root=0;
int tt=sz[vet[e]];
if (tt>sz[x]) tt=tot-sz[x];
getroot(vet[e],x,tt);
solve(root,tt);
}
}
}
int main() {
n=read();m=read();
for (int i=1;i<n;i++) {
int x=read(),y=read(),c=read();
addedge(x,y,c);
}
for (int i=1;i<=m;i++) q[i]=read();
mxs[0]=(n<<1);root=0;
getroot(1,0,n);
solve(root,n);
for (int i=1;i<=m;i++) {
if (ans[i]) printf("AYE\n");
else printf("NAY\n");
}
return 0;
}