导言
淀粉质点分治是一种统计方法,具体来说,是对树上的点的一些值来进行统计,标准的点分治统计复杂度为 O(Nlog2N)O(Nlog2N) 是解决树上疑难杂症的不三之选(还有一个是树形dp)
算法思路&&具体问题
给你一棵树,有m个询问,每个询问包含一个k,问树中是否存在距离为k的点对
这是一道很好的模板题,在这道题里面,我们具体关注点分治的想法与实现,其中有些地方复杂度会爆炸(还是能过),这里我们不做讨论
很容易得到一个 O(N2)O(N2) 的算法,我们首先选一个点,然后再枚举第二个点。
但是别忘了,这是一棵树。
假设最上面的那个节点是根,那么所有经过它的所有路径有什么特点?
- 以它为端点。
- 经过它。
把这些路径统计起来后,对于每个子树也统计这样的路径(注意要把根删掉),不断进行这样的分治,路径就被不重不漏地统计出来了
基于分治的思想,我们需要用树的重心来作为根。
于是这一题就轻松地写出来啦
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
#include<bits/stdc++.h> using namespace std; const int MAXN=10003; int Head[MAXN],Nt[MAXN<<1],to[MAXN<<1],val[MAXN<<1],tot; int size[MAXN],d[MAXN],a[MAXN],b[MAXN]; bool w[MAXN],v[MAXN],ap[10000003];//ap用来保存出现过的路径长度,w数组用来表示删去的根节点 int n,m,k[MAXN],now_part,root; int all; void add(int x,int y, int z){ Nt[++tot]=Head[x]; to[tot]=y; val[tot]=z; Head[x]=tot; } void find_root(int x){//寻找树的重心 v[x]=1; size[x]=1; int max_part=0; for(int i=Head[x];i;i=Nt[i]){ int y=to[i]; if(v[y]||w[y])continue; find_root(y); size[x]+=size[y]; max_part=max(max_part,size[y]); } max_part=max(max_part,all-size[x]); if(max_part<now_part){ now_part=max_part; root=x; } } void dfs(int x,int deep){ v[x]=1; if(b[x]==0)b[x]=x; for(int i=Head[x];i;i=Nt[i]){ int y=to[i],z=val[i]; if(w[y]||v[y])continue; a[++a[0]]=y; if(deep>0)b[y]=b[x]; else b[y]=y; d[y]=d[x]+z; ap[d[y]]=1; dfs(y,deep+1); } } void work(int x){ memset(v,0,sizeof(v)); memset(d,0,sizeof(d)); memset(b,0,sizeof(b)); w[root]=1; a[0]=0; dfs(x,0);//统计路径(以根为端点) for(int i=1;i<=a[0];i++){//这里复杂度会爆炸 for(int j=i;j<=a[0];j++){ ap[d[a[i]]]=ap[d[a[j]]]=1; if(j!=i&&b[a[i]]!=b[a[j]])ap[d[a[i]]+d[a[j]]]=1;//统计路径(经过根),b相同说明在同一个子树,不统计 } } for(int i=Head[root];i;i=Nt[i]){//分治 int y=to[i]; if(w[y])continue; all=size[y]; root=0,now_part=1<<30; find_root(y); work(root); } } int main(){ scanf("%d%d",&n,&m);all=n; for(int i=1;i<n;i++){ int x,y,z; scanf("%d%d%d",&x,&y,&z); add(x,y,z); add(y,x,z); } for(int i=1;i<=m;i++){ scanf("%d",&k[i]); } now_part=1<<30; find_root(1); work(root); for(int i=1;i<=m;i++){ if(ap[k[i]])printf("AYE\n"); else printf("NAY\n"); } return 0; } |
轻松的A了上面一题之后,突然觉得点分治很简单(然而这只是我的错觉),这是因为上面的点分治是“假的”,下面这一题可以说明问题出在哪
给你一棵树,问你有多少个点对之间的距离能被3整除
有了上一题的经验,于是套上了点分治的模板,然后就TLE了QAQ,60Pt
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
#include<bits/stdc++.h> using namespace std; const int MAXN=20005; int Head[MAXN],Nt[MAXN<<1],to[MAXN<<1],val[MAXN<<1],tot; bool v[MAXN],w[MAXN]; int size[MAXN],d[MAXN],b[MAXN],a[MAXN]; int n,root,ans; int now_part=1<<30; void add(int x,int y,int z){ Nt[++tot]=Head[x]; to[tot]=y; val[tot]=z; Head[x]=tot; } void find_root(int S,int x){ size[x]=1; v[x]=1; int max_part=0; for(int i=Head[x];i;i=Nt[i]){ int y=to[i]; if(v[y]||w[y])continue; find_root(S,y); size[x]+=size[y]; max_part=max(max_part,size[y]); } max_part=max(max_part,S-size[x]); if(max_part<now_part){ now_part=max_part; root=x; } } void dfs(int x,int step){ v[x]=1; for(int i=Head[x];i;i=Nt[i]){ int y=to[i],z=val[i]; if(w[y]||v[y])continue; d[y]=d[x]+z; if(d[y]%3==0)ans++; a[++a[0]]=y; if(step==0)b[y]=y; else b[y]=b[x]; dfs(y,step+1); } } void work(int x){ w[x]=1; memset(d,0,sizeof(d)); memset(v,0,sizeof(v)); a[0]=0; dfs(x,0); for(int i=1;i<=a[0];i++){ for(int j=i+1;j<=a[0];j++){ if(b[a[i]]!=b[a[j]]&&(d[a[i]]+d[a[j]])%3==0)ans++;//问题出在这 } } for(int i=Head[x];i;i=Nt[i]){ int y=to[i]; if(w[y])continue; memset(v,0,sizeof(v)); find_root(size[y],y); work(y); } } int gcd(int a,int b){ return b?gcd(b,a%b):a; } int main(){ scanf("%d",&n); for(int i=1;i<n;i++){ int x,y,z;scanf("%d%d%d",&x,&y,&z); add(x,y,z);add(y,x,z); } find_root(n,1); work(root); ans*=2; ans+=n; int g=gcd(ans,n*n); printf("%d/%d",ans/g,n*n/g); return 0; } |
问题就出在我们统计路径的时候没有做到O(N)O(N),一般来说,题目中总会有些特殊的性质来帮助你在这一步可以做到O(N)O(N)
比如这一题,我们可以用以下策略来统计:
- 保存d数组时保存的是 mod 3 后的余数,并统计每种值的数量储存在yu数组里面
- 根节点的 yu[0]2+2yu[1]yu[2]yu[0]2+2yu[1]yu[2] 加入答案(0+0的路径还是3的倍数,1+2的路径也是3的倍数,后面的*2是排列顺序不同,前面不*2是因为两个距离根为0的点已经被yu[0]统计了,它们都在里面)
- 上面一个肯定有重复(在同一子树内的点对也被统计到了),于是我们对于每一棵子树,再减去 yu[0]2+2yu[1]yu[2]yu[0]2+2yu[1]yu[2] (这里不是分治)
- 再进行分治
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
#include<bits/stdc++.h> using namespace std; const int MAXN=20003; int Head[MAXN],to[MAXN<<1],val[MAXN<<1],Nt[MAXN<<1],tot; int d[MAXN],size[MAXN]; int now_part; bool w[MAXN],v[MAXN]; int n,root,ans; int yu[4]; int gcd(int a,int b){ return b?gcd(b,a%b):a; } void add(int x,int y,int z){ Nt[++tot]=Head[x]; to[tot]=y; val[tot]=z; Head[x]=tot; } void find_root(int S,int x){ size[x]=1; v[x]=1; int max_part=0; for(int i=Head[x];i;i=Nt[i]){ int y=to[i]; if(v[y]||w[y])continue; find_root(S,y); size[x]+=size[y]; max_part=max(max_part,size[y]); } max_part=max(max_part,S-size[x]); if(max_part<now_part){ now_part=max_part; root=x; } } void dfs(int x,int fa){ yu[d[x]]++; for(int i=Head[x];i;i=Nt[i]){ int y=to[i],z=val[i]; if(w[y]||fa==y)continue; d[y]=(d[x]+z)%3; dfs(y,x); } } int calc(int x,int now){ yu[0]=yu[1]=yu[2]=0; d[x]=now%3; dfs(x,0); return yu[0]*yu[0]+2*yu[1]*yu[2]; } void work(int x){ w[x]=1; ans+=calc(x,0); for(int i=Head[x];i;i=Nt[i]){ int y=to[i]; if(w[y])continue; ans-=calc(y,val[i]); root=0,now_part=1<<30; memset(v,0,sizeof(v)); find_root(size[y],y); work(root); } } int main(){ scanf("%d",&n); for(int i=1;i<n;i++){ int x,y,z;scanf("%d%d%d",&x,&y,&z); add(x,y,z);add(y,x,z); } now_part=1<<30; find_root(n,1); work(root); int g=gcd(n*n,ans); printf("%d/%d",ans/g,n*n/g); return 0; } |
给一棵树,每条边有权。求一条简单路径,权值和等于K,且边的数量最小。
这一题其实和一开始的模板题没两样,只不过数据范围很大,统计路径的时候不能再O(N2)O(N2)了
用一棵平衡树来维护就行(set足够)
由于这是存在性问题,所以不用在意统计子树时的重复统计
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
#include<bits/stdc++.h> using namespace std; const int MAXN=200003; int n,k; int Head[MAXN],to[MAXN<<1],val[MAXN<<1],Nt[MAXN<<1],tot; int root,now_part,cnt; int size[MAXN]; pair<int,int>a[MAXN]; bool v[MAXN],w[MAXN]; int minn=1<<30; int read(){ int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();} while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();} return x*f; } void add(int x,int y,int z){ Nt[++tot]=Head[x]; to[tot]=y; val[tot]=z; Head[x]=tot; } void find_root(int S,int x){ v[x]=1; size[x]=1; int max_part=0; for(int i=Head[x];i;i=Nt[i]){ int y=to[i]; if(v[y]||w[y])continue; find_root(S,y); size[x]+=size[y]; max_part=max(max_part,size[y]); } max_part=max(max_part,S-size[x]); if(max_part<now_part){ now_part=max_part; root=x; } } void dfs(int x,int vl,int step){ v[x]=1; if(step>minn||vl>k)return; a[++cnt]=make_pair(vl,step); for(int i=Head[x];i;i=Nt[i]){ if(v[to[i]]||w[to[i]])continue; dfs(to[i],vl+val[i],step+1); } } void work(int x){ w[x]=1; set<pair<int,int> >st; st.insert(make_pair(0,0)); for(int i=Head[x];i;i=Nt[i]){ int y=to[i],z=val[i]; if(w[y])continue; memset(v,0,sizeof(v)); cnt=0; dfs(y,z,1); set<pair<int,int> >::iterator it; for(int i=1;i<=cnt;i++){ it=st.lower_bound(make_pair(k-a[i].first,0)); if(it!=st.end()&&it->first+a[i].first==k)minn=min(minn,it->second+a[i].second); } for(int i=1;i<=cnt;i++)st.insert(a[i]); } st.clear(); for(int i=Head[x];i;i=Nt[i]){ int y=to[i]; if(w[y])continue; memset(v,0,sizeof(v)); root=0,now_part=1<<30; find_root(size[y],y); work(root); } } int main(){ n=read(),k=read(); for(int i=1;i<n;i++){ int x=read(),y=read(),z=read(); add(x+1,y+1,z);add(y+1,x+1,z); } now_part=1<<30; find_root(n,1); work(root); if(minn==1<<30)printf("-1\n"); else printf("%d\n",minn); return 0; } |