题意
求所有长度为n字符集大小为k的字符串中有多少个字符串可以由某个回文串移位若干次后得到。
分析
不难发现一个回文串移位若干位后一定可以得到一个合法的串,关键在于如何不重不漏的计数。
若一个回文串移位x次后得到的字符串也是一个回文串,则继续往后移位则必然会算重。
那移位多少次后会得到一个新的回文串呢?这显然跟原串的周期有关。
若周期T是偶数,则移位
次后就可以得到一个回文串,否则移位T次后可以得到回文串。
也就是说,一个周期为偶数的回文串的贡献为
,周期为奇数则贡献为T。
设
表示有多少个字符串的周期为
,
,
显然有
反演一下可以得到
那么
注意到当 是奇数且 是偶数时 ,因为 会和 相互抵消,其中 是奇数。对于其他情况不难得到
所以
那么直接用pollard_rho分解质因数就好了。
写完交上去莫名超时,调了半天发现是对long long进行abs的时候会炸掉。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
LL n,k,MOD,prime[9]={2,3,5,7,11,13,17,19,23},tot,ans,a[105],b[105];
/*LL mul(LL x,LL y,LL mo)
{
LL tmp=(x*y-(LL)((double)x*y/mo+0.1)*mo)%mo;
tmp+=tmp<0?mo:0;
return tmp;
}*/
LL mul(LL x,LL y,LL mo)
{
LL ans=0;
while (y)
{
if (y&1) (ans+=x)%=mo;
x=x*2%mo;y>>=1;
}
return ans;
}
LL ksm(LL x,LL y,LL mo)
{
LL ans=1;
while (y)
{
if (y&1) ans=mul(ans,x,mo);
x=mul(x,x,mo);y>>=1;
}
return ans;
}
LL gcd(LL x,LL y)
{
if (!y) return x;
else return gcd(y,x%y);
}
bool MR(LL n)
{
if (n==2) return 1;
if (n%2==0) return 0;
LL w=n-1;int lg=0;
while (w%2==0) w/=2,lg++;
for (int i=0;i<9;i++)
{
if (n==prime[i]) return 1;
LL x=ksm(prime[i],w,n);
for (int j=0;j<lg;j++)
{
LL y=mul(x,x,n);
if (y==1&&x!=1&&x!=n-1) return 0;
x=y;
}
if (x!=1) return 0;
}
return 1;
}
LL rho(LL n)
{
LL c=rand()%(n-1)+1,x1=rand()%n,x2=x1,p=1;int k=1;
for (int i=1;p==1;i++)
{
x1=(mul(x1,x1,n)+c)%n;
if (x1==x2) return p;
p=gcd(n,x1>x2?x1-x2:x2-x1);
if (i==k) x2=x1,k<<=1;
}
return p;
}
void divi(LL n)
{
if (n==1) return;
if (MR(n)) {a[++tot]=n;return;}
LL p=1;
while (p==1) p=rho(n);
divi(p);divi(n/p);
}
LL H(LL n)
{
return (n&1)?n:n/2;
}
LL ksm(LL x,LL y)
{
LL ans=1;
while (y)
{
if (y&1) ans=ans*x%MOD;
x=x*x%MOD;y>>=1;
}
return ans;
}
void dfs(int x,LL d,LL y)
{
if (x>tot)
{
if (!(d&1)&&((n/d)&1)) return;
(ans+=ksm(k,(n/d+1)/2)*(H(n/d)%MOD)%MOD*(y%MOD)%MOD)%=MOD;
return;
}
dfs(x+1,d,y);
y*=(1-a[x]);
for (int i=1;i<=b[x];i++)
{
d*=a[x];
dfs(x+1,d,y);
}
}
int main()
{
int T;scanf("%d",&T);
while (T--)
{
scanf("%lld%lld%lld",&n,&k,&MOD);
k%=MOD;
tot=0;
divi(n);
std::sort(a+1,a+tot+1);
int tmp=0;
for (int i=1;i<=tot;i++)
if (!tmp||a[i]!=a[tmp]) a[++tmp]=a[i],b[tmp]=1;
else b[tmp]++;
tot=tmp;
ans=0;
dfs(1,1,1);
printf("%lld\n",(ans+MOD)%MOD);
}
return 0;
}