题目中要求的串,可以概括为两个回文串的组合:_____ i ____ j _____
以 i 和以 j 为中心的两个回文串拼接起来,注意 i 的回文串范围要覆盖 j,j 的回文串范围要覆盖 i
设p[i]是以 i 为中心的回文串覆盖半径(不包括 i )
那么要符合三个条件:j>i ; j<=i+p[i] ; i>=j-p[j] ;
先用Manacher算法算出以每个点为中心的最长回文串长度
再按照 i 从1~n维护满足( j-p[j]>=i )的 j 在树状数组里,然后查询[ i+1,i+p[i] ]范围内的 j 有多少个,加起来就是答案
水平太菜,比赛的时候只想到了马拉车,没想到怎么去优化。树状数组板子记一下吧。
看到题解都是用树状数组写的,我就试着用线段树写了一下,一直re,我觉得是线段树爆了,开不了那么的结构体数组,用指针写应该可以吧,太麻烦了,我就不试了。只能了解一下树状数组了。
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<iomanip>
#include<stdio.h>
#include<algorithm>
#include<map>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
# define inf 0x3f3f3f3f
# define maxn 500000+10
# define ll long long
ll c[maxn];
int n;
inline int lowbit(int x)
{
return x & -x;
}
void update(int idx, int v)
{
while (idx <= n)
{
c[idx] += v;
idx += lowbit(idx);
}
}
ll query(int x) // The sum of 1 to x
{
ll ans = 0;
while(x > 0)
{
ans += c[x];
x -= lowbit(x);
}
return ans;
}
struct code
{
int l,r;
ll sum;
} tree[maxn*6];
void buid(int t,int l,int r)
{
tree[t].l=l;
tree[t].r=r;
tree[t].sum=0;
if(l==r)
return ;
int mid=(l+r)/2;
buid(2*t,l,mid);
buid(2*t+1,mid+1,r);
}
void add(int t,int posi)
{
tree[t].sum++;
if(tree[t].l==tree[t].r)
return ;
int mid =(tree[t].l+tree[t].r)/2;
if(posi<=mid)
add(2*t,posi);
else
add(2*t+1,posi);
}
ll queuy(int t,int l,int r)
{
if(tree[t].l==l&&tree[t].r==r)
return tree[t].sum;
ll ans=0,mid=(tree[t].l+tree[t].r)/2;
if(r<=mid)
ans=queuy(2*t,l,r);
else if(l>mid)
ans=queuy(2*t+1,l,r);
else
ans=queuy(2*t,l,mid)+queuy(2*t+1,mid+1,r);
return ans;
}
char s[2*maxn],A[2*maxn];
int B[2*maxn];
int ans[2*maxn];
vector<int>dap[maxn];
void manacher(char s[],int len)
{
int l=0;
A[l++]='$';
A[l++]='#';
for(int i=0; i<len; i++)
{
A[l++]=s[i];
A[l++]='#';
}
A[l]=0;
int mx=0;
int id=0;
for(int i=0; i<l; i++)
{
B[i]=mx>i?min(B[2*id-i],mx-i):1;
while(A[i+B[i]]==A[i-B[i]])
{
B[i]++;
}
if(i+B[i]>mx)
{
mx=i+B[i];
id=i;
}
}
}
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
int num=0;
scanf("%s",s);
int len=strlen(s);
n=len;
manacher(s,len);
for(int i=0; i<2*len+2; i++)
{
if(A[i]>='a'&&A[i]<='z')
{
ans[++num]=B[i]/2-1;
}
}
// buid(1,0,len-1);
memset(c,0,sizeof(c));
for(int i=1; i<=len; i++)
dap[i].clear();
for(int i=1; i<=len; i++)
dap[i-ans[i]].push_back(i);
ll t=0;
for(int i=1; i<=len; i++)
{
int lax=dap[i].size();
for(int j=0; j<lax; j++)
update(dap[i][j],1);
t+=query(i+ans[i])-query(i);
}
printf("%lld\n",t);
}
return 0;
}