前后端同时插入,增加前缀fail指针,发现对于每个节点,前缀fail和后缀fail相同。那么分情况直接更新就好了。并且在前面插入字符时不会影响原树的fail。因为每个节点的fail长度比他本身小。
并且要注意,当插入一个节点时,如果它新生产的回文串的长度等于此时总串的长度,要同时更新L_last,R_last
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
///len表示当前节点对应回文子串长度
///num表示前pam节点(即读入s[i]之后,s[1,…,i]的最长回文后缀对应节点u)对应回文子串的所有不同回文后缀的个数(包括自身)
///sz表示当前节点对应回文子串在主串中出现次数(需要处理完整个主串后倒着对failDP)
///别忘初始化!!!!
///内存问题
const int SZ = 26;///字符集
const int maxn = 2e5 + 6; ///开两倍
struct PAM {
struct PamNode{
int fail,trans[SZ],sz,len,num;}pam[maxn];
int tot,n,l,r;char s[maxn]; ll ans;
void init() {
tot=1; ans=0; l=100003,r=100002;
memset(s,-1,sizeof s); ///特别注意
pam[0].fail=1; pam[0].len=0;
pam[1].fail=1; pam[1].len=-1;
memset(pam[0].trans,0,SZ*sizeof (int));
memset(pam[1].trans,0,SZ*sizeof (int));
}
inline int newnode(int len) {
tot++;
memset(pam[tot].trans,0,SZ*sizeof (int));
pam[tot].len=len; pam[tot].fail=0;
pam[tot].sz=0; pam[tot].num=0;
return tot;
}
inline int L_getfail(int i,int u) {
while(s[i+pam[u].len+1]^s[i]) u=pam[u].fail;
return u;
}
inline int R_getfail(int i,int u) {
while(s[i-pam[u].len-1]^s[i]) u=pam[u].fail;
return u;
}
inline int L_append(char c,int u) {
s[--l]=c; c=c-'a';
int fa=L_getfail(l,u);
u=pam[fa].trans[c];
if(!u) {
int z=newnode(pam[fa].len+2);
int w=L_getfail(l,pam[fa].fail);
pam[z].fail=pam[w].trans[c];
u=pam[fa].trans[c]=z;///注意这里要后更新,否则上面getfail时可能导致死循环
pam[z].num=pam[pam[z].fail].num+1;
}
ans=ans+pam[u].num;
pam[u].sz++;return u;
}
inline int R_append(char c,int u) {
s[++r]=c; c=c-'a';
int fa=R_getfail(r,u);
u=pam[fa].trans[c];
if(!u) {
int z=newnode(pam[fa].len+2);
int w=R_getfail(r,pam[fa].fail);
pam[z].fail=pam[w].trans[c];
u=pam[fa].trans[c]=z;///注意这里要后更新,否则上面getfail时可能导致死循环
pam[z].num=pam[pam[z].fail].num+1;
}
ans=ans+pam[u].num;
pam[u].sz++;return u;
}
void calu() {
for(int i=tot,fail;i>1;i--) {
fail=pam[i].fail;
pam[fail].sz+=pam[i].sz;
}
}
}pa;
int n,k;
int main() {
while(~scanf("%d",&n)) {
pa.init(); char op[3];
for(int i=1,L_last=0,R_last=0;i<=n;i++) {
scanf("%d",&k);
if(k==1) {
scanf("%s",op);
L_last=pa.L_append(op[0],L_last);
if(pa.pam[L_last].len==pa.r-pa.l+1) R_last=L_last;
}
if(k==2) {
scanf("%s",op);
R_last=pa.R_append(op[0],R_last);
if(pa.pam[R_last].len==pa.r-pa.l+1) L_last=R_last;
}
if(k==3) printf("%d\n",pa.tot-1);
if(k==4) printf("%lld\n",pa.ans);
}
}
}