题解
最近学后缀数组学得有点晕,还是要多练啊。
题目就一句话:求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。
根据容斥原理,只要分别求出两个子串合并后的答案和两个子串单独的答案,最后的答案就是它们相减。
问题的关键就是如何在可以接受的复杂度内求这个答案。(我一开始连这个答案是什么都不知道)
实际上我们要求的是所有子串的lcp的和,可以把它想象成一个枚举起点,长度就是贡献的过程。
至于怎么求,其实首先要知道任意两个子串的lcp为它们之间Height数组的最小值,这样用ST表 O(1) 查询可以使得整体复杂度达到 O(n^2)
但这还不够,我们发现,随着枚举的子串字典序增大,Height数组的最小值只会越来越小,于是我们可以用单调栈来维护。
具体来说,我们设 f[i] 为字典序为 i 的字符串和它之前的字符串产生的贡献,如果它的Height是当前最小的,则 f[i] 为 Height[i]\times (i-1) 而如果不是最小的,那就是上一个比它小的贡献+ Height[i]\times (i-sk[top])
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int MAXN=2000005; int Height[MAXN],tax[MAXN],tp[MAXN],sa[MAXN],rk[MAXN]; int sk[MAXN],f[MAXN]; int n,m=127,len1; string s1,s2,str; ll ans1,ans2,ans3; int a[MAXN]; void RSort(){ for(int i=1;i<=m;i++)tax[i]=0; for(int i=1;i<=n;i++)tax[rk[i]]++; for(int i=1;i<=m;i++)tax[i]+=tax[i-1]; for(int i=n;i>=1;i--)sa[tax[rk[tp[i]]]--]=tp[i]; } bool cmp(int *f,int x,int y,int w){ return f[x]==f[y]&&f[x+w]==f[y+w]; } ll Suffix(){ ll ans=0; for(int i=1;i<=n;i++)rk[i]=a[i],tp[i]=i,f[i]=0;//f数组初始化不要用memset,会T RSort();int p=0; for(int w=1;p<n;w+=w,m=p){ p=0;for(int i=n-w+1;i<=n;i++)tp[++p]=i; for(int i=1;i<=n;i++)if(sa[i]>w)tp[++p]=sa[i]-w; RSort();swap(tp,rk);rk[sa[1]]=p=1; for(int i=2;i<=n;i++)rk[sa[i]]=cmp(tp,sa[i],sa[i-1],w)?p:++p; } int k=0,j=0; for(int i=1;i<=n;Height[rk[i++]]=k){ for(k=k?k-1:k,j=sa[rk[i]-1];a[i+k]==a[j+k];k++); } int top=0; for(int i=2;i<=n;i++){ while(top&&Height[sk[top]]>=Height[i])top--; if(!top)f[i]=Height[i]*(i-1); else f[i]=f[sk[top]]+Height[i]*(i-sk[top]); sk[++top]=i; ans+=f[i]; } return ans; } void clear(){ memset(sa,0,sizeof(sa)); memset(Height,0,sizeof(Height)); m=127; } int main(){ cin>>s1>>s2; len1=s1.length(); str=s1+"#"+s2; n=s1.length(); for(int i=1;i<=n;i++)a[i]=s1[i-1]; ans1=Suffix(); n=s2.length(); for(int i=1;i<=n;i++)a[i]=s2[i-1]; clear(); ans2=Suffix(); n=str.length(); for(int i=1;i<=n;i++)a[i]=str[i-1]; clear(); ans3=Suffix(); cout<<ans3-ans2-ans1<<endl; return 0; } |