今天花了一天的时间尝试学习SAM,于是有了以下的见解。
后缀自动机真的很强,也就是从根节点(起点),走的每一条路径都是一个子串,这样一来和AC自动机走前缀有点相似。其实能利用的性质还有很多:
- parent树:相当于是fail树,par指针指向的是当前状态对应的所有字符串的最长公共前缀对应的状态,fail树相当于字符串逆序之后形成的后缀树
- Max(s):大家喜欢写成len,是指状态S或者说是节点S对应的最长字串的长度
- Min(s): 一个状态对用的最小字符串长度(或者说是能走到的长度,上同),实际上等于par树上父节点的Max+1,相当于加了一个字符,怎么说都要多一个
- right:能够表示状态s的字符串一共出现了了right次,需要预 处理
- 最大能接受的字符数max - min +1:其实就是这个状态对应的字串集合的大小
有了这些工具就好办了
对于SAM初学,要深刻理解出现次数向父亲传递,接收串数从儿子获取这句话。
一定要时刻把握这几条性质:
1.走 子串
2.Parent Tree的祖先 Right集合变大,字符串变短(路径长度变短),并且是后代的后缀哦
3.出现次数向父亲(Parent边)传递,接收串数从儿子(仍然Parent边)获取、
4.拓扑排序/对val用基数排序 , 然后可以转移边/Parent边 DP ,可以倒着递推出|Right|
SPOJ LCS Longest Common Substring
这道题做法就是先给A建SAM,然后用B串在上面“walk“,相当于一个匹配过程,沿着next走,可以走的话相当于长度++,走不了我们就走fa指针回退,能走到我们就从这里走,走不到的话就重新从root走,每次走到一个地方更新下答案就好
#include <stdio.h>
#include <iostream>
#include <string.h>
#define ll long long
using namespace std;
const int CHAR = 26;
const int MAXN = 250010;
struct SAM_Node
{
SAM_Node *fa,*next[CHAR];
int len;
int id,pos;
SAM_Node() {}
SAM_Node(int _len)
{
fa = 0;
len = _len;
memset(next,0,sizeof(next));
}
};
SAM_Node SAM_node[MAXN*2], *SAM_root, *SAM_last;
int SAM_size;
SAM_Node *newSAM_Node(int len)
{
SAM_node[SAM_size] = SAM_Node(len);
SAM_node[SAM_size].id = SAM_size;
return &SAM_node[SAM_size++];
}
SAM_Node *newSAM_Node(SAM_Node *p)
{
SAM_node[SAM_size] = *p;
SAM_node[SAM_size].id = SAM_size;
return &SAM_node[SAM_size++];
}
void SAM_init()
{
SAM_size = 0;
SAM_root = SAM_last = newSAM_Node(0);
SAM_node[0].pos = 0;
}
void SAM_add(int x,int len)
{
SAM_Node *p = SAM_last, *np = newSAM_Node(p->len+1);
np->pos = len;
SAM_last = np;
for(; p && !p->next[x]; p = p->fa)
p->next[x] = np;
if(!p)
{
np->fa = SAM_root;
return;
}
SAM_Node *q = p->next[x];
if(q->len == p->len + 1)
{
np->fa = q;
return;
}
SAM_Node *nq = newSAM_Node(q);
nq->len = p->len + 1;
q->fa = nq;
np->fa = nq;
for(; p && p->next[x] == q; p = p->fa)
p->next[x] = nq;
}
void SAM_build(char *s)
{
SAM_init();
int len = strlen(s);
for(int i = 0; i < len; i++)
SAM_add(s[i] - 'a',i+1);
}
char a[MAXN],b[MAXN];
int main()
{
while(scanf("%s%s",a,b) != EOF)
{
SAM_build(a);
int len = strlen(b);
SAM_Node *p = SAM_root;
int ans = 0,t = 0;
for(int i = 0;i<len;i++)
{
int x = b[i] - 'a';
if(p->next[x])
{
p = p->next[x];
t++;
}
else
{
while(p && !p->next[x]) p = p->fa;
if(p == NULL) p = SAM_root,t = 0;
else
{
t = p->len+1;
p = p->next[x];
}
}
ans = max(ans,t);
}
printf("%d\n",ans);
}
return 0;
}
SPOJ LCS2 Longest Common Substring II
上一道题的增强版,多个串的话我们还是先对第一个串建立SAM,之后每个串在上面walk一下,反向更新出每个节点能承载的最大值。
解释一下:我们对每个串进行匹配(walk)之后,需要更新这个最大值,这个最大值相当于是每个字符串的最长公共子串子串,落到这个节点上的长度的最小值。更新这个就需要一并更新他的fa节点,因为fa连接的后缀同样会因为单的更新而更新。这就需要对节点在parent树上进行排序之后逆序更新了。
#include <stdio.h>
#include <iostream>
#include <string.h>
#define ll long long
using namespace std;
const int MAXN = 2e5 + 5;
struct node
{
int next[26],fa,len;
}t[MAXN<<1];
int sz,root,last;
inline int new_node(int len)
{
t[++sz].len = len;
return sz;
}
void init()
{
sz = 0;
root = last = new_node(0);
}
void extend(int c)
{
int p =last,np = new_node(t[p].len+1);
while(p && t[p].next[c] == 0) t[p].next[c] = np,p = t[p].fa;
if(p == 0) t[np].fa = root;
else
{
int q = t[p].next[c];
if(t[q].len == t[p].len+1) t[np].fa = q;
else
{
int nq = new_node(t[p].len+1);
memcpy(t[nq].next,t[q].next,sizeof(t[q].next));
t[nq].fa = t[q].fa;
t[q].fa = t[np].fa = nq;
while(p && t[p].next[c] == q) t[p].next[c] = nq,p = t[p].fa;
}
}
last = np;
}
int mx[MAXN<<1],mn[MAXN];
void walk(char s[])
{
int sum = 0,p = root,n = strlen(s+1);
for(int i = 1;i<=n;i++)
{
int c = s[i] - 'a';
if(t[p].next[c])
{
p = t[p].next[c];
mx[p] = max(mx[p],++sum);
}
else
{
while(p && !t[p].next[c]) p = t[p].fa;
if(!p) p = root,sum = 0;
else
{
sum = t[p].len+1;
p = t[p].next[c];
mx[p] = max(mx[p],sum);
}
}
}
}
int c[MAXN<<1],a[MAXN<<1];
int ans;
char buf[MAXN];
int main()
{
init();
scanf("%s",buf+1);
int len = strlen(buf+1);
for(int i = 1;i<=len;i++) extend(buf[i] - 'a');
for(int i = 1;i<=sz;i++)
{
c[t[i].len]++;
}
for(int i = 1;i<=len;i++)
{
c[i] += c[i-1];
}
for(int i = sz;i >= 1;i--)
{
a[c[t[i].len]--] = i;
}
for(int i = 1;i<=sz;i++)
{
mn[i] = t[i].len;
}
while(scanf("%s",buf+1) != EOF)
{
walk(buf);
for(int j = sz;j >=1;j--)
{
int u = a[j];
mn[u] = min(mn[u],mx[u]);
if(mx[u] && t[u].fa) mx[t[u].fa] = t[t[u].fa].len;
mx[u] = 0;
}
}
ans = 0;
for(int i =2;i<=sz;i++)
{
ans = max(ans,mn[i]);
}
printf("%d\n",ans);
return 0;
}
SPOJ NSUBSTR Substrings
求一个字符串长度为i的字串出现次数的最多是多少次。
从人一个节点走i步,就是一个字串。这个和我们上面提到的right数组有关系,我们先拓扑排序下,维护出每个状态的right,每个right都是1,我们从子节点向上更新,相当于接受串数从子节点获取。然后还要给每个后缀求一波最大,还是因为小的后缀出现的次数和比他长的后缀出现次数有关系。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
#define ll long long
using namespace std;
const int MAXN = 260000 ;
const int S = 26;
int f[MAXN<<1];
namespace SAM
{
int next[MAXN<<1][S],len[MAXN<<1],fa[MAXN<<1];
int in[MAXN<<1],right[MAXN<<1];
int tot,last;
queue<int>q;
int newnode()
{
tot++;
memset(next[tot],0,sizeof(int)*S);
fa[tot]=len[tot]=0;
return tot;
}
void init()
{
tot=0;
last=newnode();
}
void Insert(int x)
{
int p,np=newnode();
right[np]=1;
len[np]=len[last]+1;
for(p=last;p&&!next[p][x];p=fa[p])
next[p][x]=np;
if(!p)
fa[np]=1;
else
{
int q=next[p][x];
if(len[q]==len[p]+1)
fa[np]=q;
else
{
int nq=newnode();
fa[nq]=fa[q];
len[nq]=len[p]+1;
memcpy(next[nq],next[q],sizeof(int)*S);
fa[np]=fa[q]=nq;
for(;next[p][x]==q;p=fa[p])
next[p][x]=nq;
}
}
last=np;
}
void Build()
{
int x,i;
for(i=1;i<=tot;i++)
{
in[fa[i]]++;
}
for(i=1;i<=tot;i++)
{
if(!in[i])
q.push(i);
}
while(!q.empty())
{
x=q.front(),q.pop();
in[fa[x]]--;
right[fa[x]]+=right[x];
if(!in[fa[x]])
q.push(fa[x]);
}
}
void calc(int n)
{
for(int i=1;i<=tot;i++)
f[len[i]]=max(f[len[i]],right[i]);
for(int i=n-1;i>=1;i--)
f[i]=max(f[i],f[i+1]);
}
}
char buf[MAXN];
int main()
{
SAM::init();
scanf("%s",buf + 1);
int n = strlen(buf+1);
for(int i = 1;i<=n;i++) SAM::Insert(buf[i]-'a');
SAM::Build();
SAM::calc(n);
for(int i = 1;i<=n;i++)
{
printf("%d\n",f[i]);
}
return 0;
}
SPOJ SUBLEX Lexicographical Substring Search
获取第k大的子串,首先我们更新出每个状态能接受多少个不同的字串sum。说白了字典序第k大还是判断小与它的不同字串个数,基于这一点,我们在next上dfs,只要这个点的sum大于K,说明可以能承载k个,然后走下一个点,由于dfs按照字典序,相当于是走了最小的。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
#define ll long long
using namespace std;
const int MAXN = 360000;
const int S = 26;
int f[MAXN<<1];
namespace SAM
{
int next[MAXN<<1][S],len[MAXN<<1],fa[MAXN<<1];
int in[MAXN<<1],right[MAXN<<1],sum[MAXN<<1];
int tot,last;
queue<int>q;
int newnode()
{
tot++;
memset(next[tot],0,sizeof(int)*S);
fa[tot]=len[tot]=0;
return tot;
}
void init()
{
tot=0;
last=newnode();
}
void Insert(int x)
{
int p,np=newnode();
right[np]=1;
len[np]=len[last]+1;
for(p=last;p&&!next[p][x];p=fa[p])
next[p][x]=np;
if(!p)
fa[np]=1;
else
{
int q=next[p][x];
if(len[q]==len[p]+1)
fa[np]=q;
else
{
int nq=newnode();
fa[nq]=fa[q];
len[nq]=len[p]+1;
memcpy(next[nq],next[q],sizeof(int)*S);
fa[np]=fa[q]=nq;
for(;next[p][x]==q;p=fa[p])
next[p][x]=nq;
}
}
last=np;
}
int c[MAXN<<1],a[MAXN<<1];
void Topo(int n)
{
for(int i = 1;i<=tot;i++) c[len[i]]++;
for(int i = 1;i<=n;i++) c[i] += c[i-1];
for(int i = tot;i >= 1;i--) a[c[len[i]]--] = i;
for(int i = tot; i>=1; i--)
{
int p = a[i]; sum[p] = 1;
for(int j = 0;j<26;j++) sum[p] += sum[next[p][j]];
}
}
void Build()
{
int x,i;
for(i=1;i<=tot;i++)
{
in[fa[i]]++;
}
for(i=1;i<=tot;i++)
{
if(!in[i])
q.push(i);
}
while(!q.empty())
{
x=q.front(),q.pop();
in[fa[x]]--;
right[fa[x]]+=right[x];
if(!in[fa[x]])
q.push(fa[x]);
}
}
void solve(int k)
{
int p = 1;
while(k)
{
for(int i =0;i<26;i++)
{
if(sum[next[p][i]] >= k)
{
printf("%c",'a'+i);
k--;
p = next[p][i];
break;
}
else k -= sum[next[p][i]];
}
}
puts("");
}
}
char buf[MAXN];
int main()
{
SAM::init();
scanf("%s",buf + 1);
int n = strlen(buf+1);
for(int i = 1;i<=n;i++) SAM::Insert(buf[i]-'a');
SAM::Topo(n);
int q,k;
scanf("%d",&q);
while(q--)
{
scanf("%d",&k);
SAM::solve(k);
}
return 0;
}
HDU 4416 Good Article Good sentence
求没有在字符串集合B出现的A的字串的个数。我们先对A建SAM,设f[s]是对应节点与其他字符串匹配的最大长度,显然这个节点的len如果比最大匹配长度要大,说明剩下的一截字串是我们想要的答案,注意我们不仅要讲出现次数向父节点传递,还要注意父节点和自己的len差距,我们要更新这两个差距的最大值。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
#define ll long long
using namespace std;
const int MAXN = 1e6 ;
const int S = 26;
int f[MAXN<<1];
namespace SAM
{
int next[MAXN<<1][S],len[MAXN<<1],fa[MAXN<<1];
int in[MAXN<<1],right[MAXN<<1],f[MAXN<<1];
int tot,last;
int newnode()
{
tot++;
memset(next[tot],0,sizeof(int)*S);
fa[tot]=len[tot]=0;
return tot;
}
void init()
{
tot=0;
last=newnode();
memset(f,0,sizeof(f));
}
void Insert(int x)
{
int p,np=newnode();
right[np]=1;
len[np]=len[last]+1;
for(p=last; p&&!next[p][x]; p=fa[p])
next[p][x]=np;
if(!p)
fa[np]=1;
else
{
int q=next[p][x];
if(len[q]==len[p]+1)
fa[np]=q;
else
{
int nq=newnode();
fa[nq]=fa[q];
len[nq]=len[p]+1;
memcpy(next[nq],next[q],sizeof(int)*S);
fa[np]=fa[q]=nq;
for(; next[p][x]==q; p=fa[p])
next[p][x]=nq;
}
}
last=np;
}
int c[MAXN<<1],a[MAXN<<1];
void Topo(int n)
{
memset(c,0,sizeof(c));
for(int i = 1; i<=tot; i++)
c[len[i]]++;
for(int i = 1; i<=n; i++)
c[i] += c[i-1];
for(int i = tot; i >= 1; i--)
a[c[len[i]]--] = i;
}
void walk(char s[])
{
int l = 0,p = 1;
int L = strlen(s);
for(int i = 0; i < L; i++)
{
int c = s[i] - 'a';
if(next[p][c])
{
l++;
p = next[p][c];
}
else
{
while(p && !next[p][c])
p = fa[p];
if(p)
{
l = len[p]+1;
p = next[p][c];
}
else p = 1,l = 0;
}
f[p] = max(f[p],l);
}
}
ll solve(int n)
{
memset(c,0,sizeof(c));
for(int i = 1; i<=tot; i++)
c[len[i]]++;
for(int i = 1; i<=n; i++)
c[i] += c[i-1];
for(int i = tot; i >= 1; i--)
a[c[len[i]]--] = i;
ll ans = 0;
for(int i = tot;i >=1;i--)
{
int u = a[i];
if(f[u] < len[u])
{
int tmp = max(f[u],len[fa[u]]);
ans += len[u] - tmp;
}
f[fa[u]] = max(f[fa[u]],f[u]);
}
return ans;
}
}
char buf[MAXN];
int main()
{
int ca,cat = 1;
scanf("%d",&ca);
while(ca--)
{
int m;
scanf("%d",&m);
SAM::init();
scanf("%s",buf);
int n = strlen(buf);
for(int i = 0; i<n; i++)
SAM::Insert(buf[i]-'a');
while(m--)
{
scanf("%s",buf);
SAM::walk(buf);
}
printf("Case %d: %lld\n",cat++,SAM::solve(n));
}
return 0;
}
bzoj 3998 [TJOI2015]弦论
还是求第K大的字串,但是和上面的不同,我们这回要判断是不是有重叠的,重叠的时候就是上面的做法,每个状态的增量为1,重叠时,我们就要让父节点接受字串出现次数了。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <queue>
#define ll long long
using namespace std;
const int MAXN = 5e5+10;
const int S = 26;
int f[MAXN<<1];
namespace SAM
{
int next[MAXN<<1][S],len[MAXN<<1],fa[MAXN<<1];
int in[MAXN<<1],right[MAXN<<1],sum[MAXN<<1],val[MAXN<<1];
int tot,last;
queue<int>q;
int newnode()
{
tot++;
memset(next[tot],0,sizeof(int)*S);
fa[tot]=len[tot]=0;
return tot;
}
void init()
{
tot=0;
last=newnode();
}
void Insert(int x)
{
int p,np=newnode();
right[np]=1;
len[np]=len[last]+1;
val[np] = 1;
for(p=last; p&&!next[p][x]; p=fa[p])
next[p][x]=np;
if(!p)
fa[np]=1;
else
{
int q=next[p][x];
if(len[q]==len[p]+1)
fa[np]=q;
else
{
int nq=newnode();
fa[nq]=fa[q];
len[nq]=len[p]+1;
memcpy(next[nq],next[q],sizeof(int)*S);
fa[np]=fa[q]=nq;
for(; next[p][x]==q; p=fa[p])
next[p][x]=nq;
}
}
last=np;
}
int c[MAXN<<1],a[MAXN<<1];
void Topo(int n,int T)
{
for(int i = 1; i<=tot; i++)
c[len[i]]++;
for(int i = 1; i<=n; i++)
c[i] += c[i-1];
for(int i = tot; i >= 1; i--)
a[c[len[i]]--] = i;
for(int i = tot; i>=1; i--)
{
int p = a[i];
if(T == 1)
val[fa[p]] += val[p];
else
val[p] = 1;
}
val[1] = 0;
for(int i = tot; i >=1; i--)
{
int p = a[i];
sum[p] = val[p];
for(int j = 0; j<26; j++)
{
sum[p] += sum[next[p][j]];
}
}
}
void dfs(int x,int K)
{
if(K<=val[x])
return;
K-=val[x];
for(int i=0; i<26; i++)
if(int t=next[x][i])
{
if(K<=sum[t])
{
putchar(i+'a');
dfs(t,K);
return;
}
K-=sum[t];
}
}
}
char buf[MAXN];
int main()
{
SAM::init();
scanf("%s",buf + 1);
int n = strlen(buf+1);
for(int i = 1; i<=n; i++)
SAM::Insert(buf[i]-'a');
int T,k;
scanf("%d%d",&T,&k);
SAM::Topo(n,T);
if(k > SAM::sum[1])
puts("-1");
else
SAM::dfs(1,k);
return 0;
}