Step1 Problem:
给你 n 个字符串,n 个字符串两两连接(n^2种连接方法),组成的所有字符串,有多少个回文串。
Step2 Ideas:
核心:s1 串和 s2 串连接起来是回文,s1 和 s2 中较短的串,是长串反转后的前缀,长串前缀后面的字符构成回文.
字典树是放一堆字符串很好的结构,每个节点代表着其中一个字符串的前缀,我们需要O(1)处理,前缀后面剩下的字符构成的串有多少个回文。这里可以用manacher, 扩展KMP都可以,我采用的是manacher算法。
再把每个串反转后去和字典树匹配,如果反串当前位置后的字符构成的串是回文(O(1)判断),就加上该节点有几个字符串。如果反串全匹配成功,加上前面预处理出来的前缀后面剩下的字符构成的串有多少个回文。
Step3 Code:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
const int N = 2e6+5;
struct node
{
int data, data1;
node *next[26];
};
node a[N];
char s[N], fzs[N], ts[N*2];
int len, pos[N], top, lent, tp[N*2];
int get_s(char s1[], int len)
{
int j = 0;
ts[j++] = '$'; ts[j++] = '#';
for(int i = 0; i < len; i++)
{
ts[j++] = s1[i];
ts[j++] = '#';
}
ts[j] = '\0';
return j;
}
void get_p(int len)
{
int mx = 0, id;
for(int i = 1; i <= len; i++)
{
if(i < mx)
tp[i] = min(tp[2*id - i], mx - i);
else tp[i] = 1;
while(ts[i-tp[i]] == ts[i+tp[i]])
tp[i]++;
if(mx < i+tp[i])
{
id = i;
mx = i+tp[i];
}
}
}
node *creat_kong()
{
node *root = &a[top++];
root->data = root->data1 = 0;
for(int i = 0; i < 26; i++)
root->next[i] = NULL;
return root;
}
node *Insert(node *root, char s[], int len, int len1)
{
node *p = root;
for(int i = 0; i < len; i++)
{
int tmp = s[i] - 'a';
if(!p->next[tmp]) p->next[tmp] = creat_kong();
p = p->next[tmp];
int t = (len1+(i+1)*2)/2;
if(tp[t]-1 == len-i-1)//如果剩下的字符构成回文串
{
p->data1++;
}
}
p->data++;
return root;
}
int Find(node *root, char s[], int len, int len1)
{
node *p = root;
int ans = 0, i;
for(i = 0; i < len; i++)
{
int tmp = s[i] - 'a';
if(!p->next[tmp]) break;
p = p->next[tmp];
int t = (len1+(i+1)*2)/2;
if(tp[t]-1 == len-i-1 && i!=len-1)//如果反串剩下的字符构成回文串 i!=len-1是为了不重复计算
ans += p->data;
}
if(i == len) {//反串全匹配完,加上前面预处理出来的
ans += p->data1;
}
return ans;
}
void get_fzs(char s[], int len)
{
for(int i = 0; i < len; i++)
fzs[len-i-1] = s[i];
fzs[len] = '\0';
}
int main()
{
int n;
while(~scanf("%d", &n))
{
len = 0, top = 0;
node *root = creat_kong();
for(int i = 0; i < n; i++)
{
scanf("%d %s", &pos[len], s+len);
int len1 = get_s(s+len, pos[len]);
get_p(len1);
root = Insert(root, s+len, pos[len], len1);
len = len+pos[len];
}
int tmp = 0;
ll ans = 0;
for(int i = 0; i < n; i++)
{
get_fzs(s+tmp, pos[tmp]);
int len1 = get_s(fzs, pos[tmp]);
get_p(len1);
ans += Find(root, fzs, pos[tmp], len1);
tmp = tmp + pos[tmp];
}
printf("%lld\n", ans);
}
return 0;
}