链接
题解
不带 的很好做
带 的,可以通过枚举 前面的一段,并且拿数据结构维护后面的一段即可
在后缀自动机上枚举 前面的一段,也就是在后缀自动机的每个节点上统计,我可以得到一系列的 ,接下来就是在这些 处统计本质不同前缀的个数。
这里可以拿后缀数组来做,一上来先把所有后缀的 求出来,每个后缀分别丢进一个 ,然后在 树上对 做启发式合并,不停的合并后缀集合,合并的同时维护当前集合有多少本质不同的子串。
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 100010
#define maxe 200010
#define maxk 17
#define cl(x) memset(x,0,sizeof(x))
#define rep(_,__) for(_=1;_<=(__);_++)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
int w[maxn<<1];
struct SuffixArray
{
int sa[maxn], rank[maxn], ws[maxn], wv[maxn], wa[maxn], wb[maxn], height[maxn], st[maxk+2][maxn], N;
bool cmp(int *r, int a, int b, int l){return r[a]==r[b] and r[a+l]==r[b+l];}
void clear()
{
cl(sa), cl(rank), cl(ws), cl(wv), cl(wa), cl(wb), cl(height);
}
void build(char *r, int n, int m)
{
N=n;
n++;
int i, j, k=0, p, *x=wa, *y=wb, *t;
for(i=0;i<m;i++)ws[i]=0;
for(i=0;i<n;i++)ws[x[i]=r[i]]++;
for(i=1;i<m;i++)ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--)sa[--ws[x[i]]]=i;
for(p=j=1;p<n;j<<=1,m=p)
{
for(p=0,i=n-j;i<n;i++)y[p++]=i;
for(i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j;
for(i=0;i<n;i++)wv[i]=x[y[i]];
for(i=0;i<m;i++)ws[i]=0;
for(i=0;i<n;i++)ws[wv[i]]++;
for(i=1;i<m;i++)ws[i]+=ws[i-1];
for(i=n-1;i>=0;i--)sa[--ws[wv[i]]]=y[i];
for(t=x,x=y,y=t,p=1,i=1,x[sa[0]]=0;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
for(i=0;i<n;i++)rank[sa[i]]=i;
for(i=0;i<n-1;height[rank[i++]]=k)
for(k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
}
void build_st() //st表
{
int i, k;
for(i=1;i<=N;i++)st[0][i]=height[i];
for(k=1;k<=maxk;k++)
for(i=1;i+(1<<k)-1<=N;i++)
st[k][i]=min(st[k-1][i],st[k-1][i+(1<<k-1)]);
}
int lcp(int x, int y) //最长公共前缀
{
int l=rank[x], r=rank[y];
if(l>r)swap(l,r);
if(l==r)return N-sa[l];
int t=log2(r-l);
return min(st[t][l+1],st[t][r-(1<<t)+1]);
}
}SA;
struct SAM
{
int tot, las, ch[maxn<<1][26], fa[maxn<<1], len[maxn<<1], pref[maxn<<1];
int* operator[](ll u){return ch[u];}
void init()
{
int i;
rep(i,tot)cl(ch[i]),fa[i]=len[i]=pref[i]=0;
tot=las=1;
}
void append(int c, int tag=1)
{
int p(las);
len[las=++tot]=len[p]+1;
pref[las]=tag;
for(;p and !ch[p][c];p=fa[p])ch[p][c]=las;
if(!p)fa[las]=1;
else
{
int q=ch[p][c];
if(len[q]==len[p]+1)fa[las]=q;
else
{
int qq=++tot;
memcpy(ch[qq],ch[q],sizeof(ch[q]));
fa[qq]=fa[q];
len[qq]=len[p]+1;
fa[q]=fa[las]=qq;
for(;ch[p][c]==q;p=fa[p])ch[p][c]=qq;
}
}
}
}sam;
struct Graph
{
int etot, head[maxn<<1], to[maxe], next[maxe], w[maxe];
void clear(int N)
{
for(int i=1;i<=N;i++)head[i]=0;
etot=0;
}
void adde(int a, int b, int c=0){to[++etot]=b;w[etot]=c;next[etot]=head[a];head[a]=etot;}
#define forp(pos,G) for(auto p=G.head[pos];p;p=G.next[p])
}G;
char s[maxn];
ll ans;
set<int> S[maxn];
ll cnt[maxn], n;
int deg[maxn<<1], Sid[maxn<<1];
int join(int x, int y)
{
if(!x or !y)return x|y;
if(S[x].size()<S[y].size())swap(x,y);
auto& s=S[x];
for(auto t:S[y])
{
if(abs(t)==iinf or !t)continue;
s.insert(t);
auto it1=s.find(t), it2=it1;
it1--, it2++;
if(abs(*it2)!=iinf)cnt[x] -= n-SA.sa[*it2]-SA.lcp(SA.sa[*it1],SA.sa[*it2]);
cnt[x] += n-SA.sa[t]-SA.lcp(SA.sa[*it1],SA.sa[t]);
if(abs(*it2)!=iinf)cnt[x] += n-SA.sa[*it2]-SA.lcp(SA.sa[*it2],SA.sa[t]);
}
return x;
}
void dp()
{
queue<int> q; int i;
rep(i,sam.tot)if(deg[i]==0)q.em(i);
while(!q.empty())
{
auto u=q.front(); q.pop();
auto f=sam.fa[u];
if(--deg[f]==0)q.em(f);
if(sam.pref[u])Sid[u] = join(Sid[u],sam.pref[u]+2);
ans += ( sam.len[u] - sam.len[f] ) * (cnt[Sid[u]]+1);
Sid[f] = join(Sid[f],Sid[u]);
}
}
int main()
{
int i;
scanf("%s",s+1); n=strlen(s+1);
if(n==1)
{
printf("3");
return 0;
}
sam.init();
rep(i,n-1)sam.append(s[i]-'a',i);
SA.build(s+1,n,300);
SA.build_st();
rep(i,n)
{
S[i].em(0);
S[i].em(iinf);
S[i].em(SA.rank[i-1]);
cnt[i]=n-i+1;
}
rep(i,sam.tot)if(i>1)deg[sam.fa[i]]++;
dp(); //(非空)*(可空)
auto t = join(Sid[1],2);
ans += cnt[t] + 1; //*开头的以及*本身
sam.append(s[n]-'a');
rep(i,sam.tot)ans+=sam.len[i]-sam.len[sam.fa[i]]; //不包含*的串
printf("%lld",ans+1); //最后加上空串
return 0;
}