2020牛客暑期多校训练营(第二场)(A hash 计数 +kmp next数组 去重)

题目链接

All with Pairs

题意:n个字符串,对每个字符串i 求所有字符串j 的f(si,sj)*f(si,sj) 的和。 f(si,sj)代表 si的最长前缀与 sj的后缀相同。

比如 f( ab  aab ) =2 , f(ba,aab) =1

做法:很明显的对所有的后缀hash 保存,然后  枚举 对 每个字符串i的前缀hash  ans+=mp[hash]*len*len  len 是当前前缀hash 的长度。

当然这样计算是会有重复计算的。

例如  aba  和  aba    会有len=3  len =1 两种情况,而答案只需要保存 len=3  。那么怎么去掉重复的计算呢?

当前找到一个后缀 aba 与  map里面保存的 aba 计算答案的时候 ,就一定会出现  以当前位置 j  为最长后缀(map里的字符串),匹配到一个与 当前字符串i 的 最长前缀。这句话越看越像kmp里的next数组性质。

于是 当样例 “aba”   “aba“匹配到len=3 的时候 然后将 位置 next[3] 的字符 标志vis[next[3]]++;当枚举到 前缀 next[3] 位置的时候 将map值减去 vis[next[3]]即可。普通的map超时了,然后用了队友的秘制 hash_map 过了。

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=(b);++i)
#define mem(a,x) memset(a,x,sizeof(a))
#define pb push_back
using namespace std;
typedef long long ll;
ll gcd(ll a,ll b) { return b?gcd(b,a%b):a;}
const int N=1e6+1,M=1e6+1;
 
ll base[2]={43,47};
ll f[2][N],mod[2]={1000000007,998244353},h[3][N];
void init()
{
    for(int j=0;j<=1;++j) f[j][0]=1;
    for(int i=1;i<N;++i)
    for(int j=0;j<=1;++j) f[j][i]=f[j][i-1]*base[j]%mod[j];
}
ll getv(int l,int r,int j)
{
    return (h[j][r]-h[j][l-1]*f[j][r-l+1]%mod[j]+mod[j])%mod[j];
}
const ll modd=998244353;
 
const ll mood=1e9+7;
const int maxsz=3e6+7;
 
template<typename key,typename val>
class hash_map{public:
  struct node{key u;val v;int next;};
  vector<node> e;
  int head[maxsz],nume,numk,id[maxsz];
  int geths(pair<ll,ll> &u){
    int x=(1ll*u.first*mood+u.second)%maxsz;
    if(x<0) return x+maxsz;
    return x;
  }
  val& operator[](key u){
    int hs=geths(u);
    for(int i=head[hs];i;i=e[i].next)if(e[i].u==u) return e[i].v;
    if(!head[hs])id[++numk]=hs;
    if(++nume>=e.size())e.resize(nume<<1);
    return e[nume]=(node){u,0,head[hs]},head[hs]=nume,e[nume].v;
  }
  void clear(){
    for(int i=0;i<=numk;i++) head[id[i]]=0;
    numk=nume=0;
  }
};
hash_map<pair<ll,ll>,ll> mp;
 
// unordered_map<pair<ll,ll>,ll>mp;
//map<pair<ll,ll>,ll>mp;
string s[N];
ll vis[N];
int ne[N];
void get(string b)  //常规处理方法
{
    int len=b.size();
    ne[0]=-1;
    for(int i=0,j=-1;i<len;)
    {
        if(j==-1||b[i]==b[j]) ne[++i]=++j;
        else j=ne[j];
    }
//  for(int i=0;i<=len;++i){
//        printf("%d ",ne[i]);
//  }
//  puts("");
}
 
int n;
int main()
{
    std::ios::sync_with_stdio(false);
    //get("aaa");
    init();
    cin>>n;
    rep(i,1,n) cin>>s[i];
 
 
    rep(i,1,n)
    {
        int len=s[i].size();
        for(int j=0;j<len;++j){
            int x=s[i][j]-'a'+1;
            for(int k=0;k<=1;++k){
                h[k][j+1]=(h[k][j]*base[k]%mod[k]+x)%mod[k];
            }
        }
 
        pair<ll,ll>tmp;
        for(int j=1;j<=len;++j){
            tmp.first=getv(j,len,0);
            tmp.second=getv(j,len,1);
            //printf("l:%d r:%d tmp:%lld %lld\n",j,len,tmp.first,tmp.second);
            mp[tmp]++;
        }
//        puts("");
//        puts("");
//        puts("");
 
 
    }
    //puts("");
    //puts("");
    //puts("");
 
 
    ll ans=0,pre;
    rep(i,1,n)
    {
        int len=s[i].size();
        for(int j=0;j<len;++j){
            int x=s[i][j]-'a'+1;
            for(int k=0;k<=1;++k){
                h[k][j+1]=(h[k][j]*base[k]%mod[k]+x)%mod[k];
            }
        }
 
        //puts("");
        get(s[i]);
        for(int i=0;i<=len;++i) vis[i]=0;
 
        pair<ll,ll>tmp;
        pre=0;
        for(int j=len;j>=1;--j){
            tmp.first=getv(1,j,0);
            tmp.second=getv(1,j,1);
 
//            printf("l:%d r:%d tmp:%lld %lld\n",1,j,tmp.first,tmp.second);
//            printf("mp:%lld pre:%lld\n\n",mp[tmp],pre);
 
 
 
            ans=(ans+(mp[tmp]-vis[j]+modd)%modd*j%modd*j%modd)%modd;
            int nx=ne[j];
            vis[nx]=(vis[nx]+mp[tmp])%modd;
            pre=mp[tmp];
        }
    }
 
 
    cout<<ans<<endl;
    //printf("%lld\n",ans);
 
}
/*
3
abc
abc
abc
*/

猜你喜欢

转载自blog.csdn.net/qq_41286356/article/details/107327676