本文只记录AC自动机的入门练习题,不再详解算法原理。
在掌握字典树(Trie树)和KMP思想的基础上,学习AC自动机算法原理,推荐阅读以下文章:
- 洛谷日报 强势图解AC自动机 https://www.luogu.com.cn/blog/3383669u/qiang-shi-tu-xie-ac-zi-dong-ji
- AC自动机 - 多模式匹配算法 https://blog.csdn.net/xaiojiang/article/details/52299258
- AC 自动机 - OI Wiki https://oi-wiki.org/string/ac-automaton/
一、HDU 2222 Keywords Search
AC自动机模板题,具体细节详见代码注释。代码参考:洛谷日报 强势图解AC自动机。
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10,M=26;
struct node
{
int ch[N][M];
int cnt[N];
int fail[N];
int tot;
queue<int>q;
void init() // 初始化
{
memset(cnt,0,sizeof(cnt));
memset(ch,0,sizeof(ch));
memset(fail,0,sizeof(fail));
tot=0;
}
void ins(char s[]) // insert代码同Trie树
{
int u=0; // 根节点为0,从根节点开始往下走
for(int i=0;s[i];i++)
{
int x=s[i]-'a'; // a~z -> 0~25
if(!ch[u][x])ch[u][x]=++tot; // 没有节点就造节点
u=ch[u][x]; // 向下遍历
}
// 此模式串终点对应编号为u,更新以u结尾的模式串个数
cnt[u]++;
}
void build_fail()
{
fail[0]=0; // 根节点0的fail指针指向自身
for(int i=0;i<M;i++)
{
if(ch[0][i])
q.push(ch[0][i]); // 与根节点直接相连的一层节点入队
}
while(!q.empty()) // 树的层次遍历
{
int u=q.front();q.pop(); // 父节点u
for(int i=0;i<M;i++)
{
int &v=ch[u][i]; // 子节点v,因为之后可能要修改,所以用&取引用
if(v)
{
fail[v]=ch[fail[u]][i];
q.push(v);
}
// else这短短一行代码,是算法优化,体现出路径压缩的思想
// 从而使得在之后的遍历过程中能更快得到fail指针
// 子节点不存在,造一个子节点,修改Trie树,从而将Trie树改造成Trie图
else v=ch[fail[u]][i];
}
}
}
int query(char s[])
{
int u=0,ans=0;
for(int i=0;s[i];i++)
{
int x=s[i]-'a';
u=ch[u][x];
// if(~cnt[j]) 等价于 if(cnt[j]!=-1)
// 不断向上找fail指针,找到cnt值存在的就更新答案
// 遍历到根或之前遍历过的点时停止
for(int j=u;j&&~cnt[j];j=fail[j])
{
ans+=cnt[j];
cnt[j]=-1;
// cnt在这里可以起到标记的作用,这样每个点至多被遍历一次
}
}
return ans;
}
}ac;
char t[N];
int n,T;
int main()
{
ios::sync_with_stdio(false);
cin>>T;
while(T--)
{
ac.init();
cin>>n;
for(int i=1;i<=n;i++)
{
cin>>t;
ac.ins(t);
}
ac.build_fail();
cin>>t;
printf("%d\n",ac.query(t));
}
return 0;
}
/*
1
2
she
h
she
ans:2
*/
二、HDU 2896 病毒侵袭
基本上也是模板题吧。
坑的地方在于,交了几次,数组开小了会RE,数组开大了会MLE,后来发现错误原因是用的二维vector记录模式串终点,而实际上题目说了不会出现两个相同的模式串(那么插入模式串后,每个终点都只唯一对应一个模式串),开二维的MLE,那就改成一维数组记录模式串终点对应的编号,大小开500*200就行了;
还有就是字符串中可能会有空格(空格ASCII码为32),应该用gets读入,gets之前记得写getchar()吸收回车。
#include <bits/stdc++.h>
using namespace std;
const int N=1e5+10,M=130,K=1e4+10;
int n,m,sum;
char t[K];
struct node
{
int ch[N][M];
int tot=0;
int cnt[N];
int fail[N];
queue<int>q;
void ins(int num,char s[]) // 插入模式串s[],编号为num
{
int u=0;
for(int i=0;s[i];i++)
{
int x=s[i]; // 直接用ASCII码
if(!ch[u][x])ch[u][x]=++tot;
u=ch[u][x];
}// 遍历到的终点是u
cnt[u]=num; // 以u结尾的模式串编号是num(唯一对应)
}
void build_fail()
{
for(int i=0;i<M;i++)
{
if(ch[0][i])
q.push(ch[0][i]);
}
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0;i<M;i++)
{
int &v=ch[u][i];
int f=ch[fail[u]][i];
if(v)
{
fail[v]=f;
q.push(v);
}
else v=f;
}
}
}
set<int>ans;
int tmp[N];
bool query(char s[])
{
int u=0;
ans.clear();
memcpy(tmp,cnt,sizeof(cnt)); // tmp是cnt的一个拷贝,之后要修改tmp
bool flag=0;
for(int i=0;s[i];i++)
{
int x=s[i];
u=ch[u][x];
for(int j=u;j&&tmp[j];j=fail[j])
{
flag=1;
ans.insert(tmp[j]);
tmp[j]=0;
}
}
return flag;
}
}ac;
int main()
{
cin>>n;
getchar();
for(int i=1;i<=n;i++)
{
gets(t); // 不是cin!
ac.ins(i,t);
}
ac.build_fail();
cin>>m;
getchar();
for(int i=1;i<=m;i++)
{
gets(t); // 不是cin!
if(ac.query(t))
{
sum++;
printf("web %d:",i);
for(auto j:ac.ans)
{
printf(" %d",j);
}
printf("\n");
}
}
printf("total: %d\n",sum);
return 0;
}
三、HDU 3065 病毒侵袭持续中
这题和其他题不同的地方在于,在query函数,fail指针不断向上跳的过程中,跳回根节点才停止,每个点是可以被多次遍历到的,所以不需要(不能)标记访问过的点,这样才满足题目要求。虽然不能保证每个点只被遍历一次,时间复杂度会大一些,但是必须这样做,计数才不会有遗漏。
看这个样例,就很好明白了:
Input
3
A
AA
AAA
AAAA
Output
A: 4
AA: 3
AAA: 2
为什么A被遍历了4次,原因就在于,每次从目标串的当前位置向上跳fail指针的时候,都会遍历到单个的A。目标串长度为4,单个的A总共就被计数了4次。
还有坑的地方就是题目不说清楚是多组输入(出题人出来挨打!),单组输入会给你评测WA;
如果把gets写成了cin,评测结果不是WA而是TLE,非常的误导人,我先还以为是fail指针向上跳的时候会重复遍历点,时间复杂度太大,后来才发现是cin错了…
#include <bits/stdc++.h>
using namespace std;
const int N=5e4+10,M=26,K=2e6+10;
char ans[1005][55];// 模式串
char t[K]; // 目标串
map<int,int>vis; // 模式串编号对应出现的次数
struct node
{
int ch[N][M];
int cnt[N];
int tot;
int fail[N];
queue<int>q;
void init()
{
memset(ch,0,sizeof(ch));
memset(fail,0,sizeof(fail));
memset(cnt,0,sizeof(cnt));
while(!q.empty())q.pop();
tot=0;
}
void ins(int num,char s[])
{
int u=0;
for(int i=0;s[i];i++)
{
int x=s[i]-'A'; // A~Z -> 0~25
if(!ch[u][x])ch[u][x]=++tot;
u=ch[u][x];
}
cnt[u]=num;
}
void build_fail()
{
for(int i=0;i<M;i++)
{
if(ch[0][i])
q.push(ch[0][i]);
}
while(!q.empty())
{
int u=q.front();q.pop();
for(int i=0;i<M;i++)
{
int &v=ch[u][i];
int f=ch[fail[u]][i];
if(v)
{
fail[v]=f;
q.push(v);
}
else v=f;
}
}
}
void query(char s[])
{
int u=0;
vis.clear();
for(int i=0;s[i];i++)
{
if(s[i]<'A'||s[i]>'Z') // 跳回根节点
{
u=0;
continue;
}
int x=s[i]-'A';
u=ch[u][x];
for(int j=u;j&&cnt[j];j=fail[j])
{
vis[cnt[j]]++; //j对应模式串编号为cnt[j],个数+1
// 此处不能修改cnt[j]=0
// 因为按照题目要求,之后被再次遍历到,需要再次计数
}
}
}
}ac;
int main()
{
int n;
while(cin>>n) // 多组输入(题目没说清楚是多组!)
{
ac.init();
for(int i=1;i<=n;i++)
{
cin>>ans[i];
ac.ins(i,ans[i]);
}
ac.build_fail();
getchar();
gets(t); // gets!含空格!
ac.query(t);
for(auto it:vis)
{
int a=it.first; // 模式串下标
int b=it.second; // 出现次数
printf("%s: %d\n",ans[a],b);
}
}
return 0;
}
/*
3
A
AA
AAA
AAAA
A: 4
AA: 3
AAA: 2
*/