感觉题目讲的很不清楚……
题目意思就是给出一个长度为\(n\)的字符串,求对于\(r=0,1,...,n-1\),求出\(LCP(suffix_p,suffix_q) \geq r\)的无序数对\((p,q)\)的数目,并令一对无序数对的价值为\(val_p \times val_q\),则还要求对于每一个\(r\),所有满足上述条件的无序数对中的最大价值
跟后缀\(LCP\)长度有关,直接上\(SA\)。求出\(sa\)数组和\(height\)数组,我们考虑如何实现对于每一个\(r\)的询问快速求出答案。不妨将\(r\)从大到小求解,那么对于某一个后缀\(sa_k\),满足\(LCP(suffix_{sa_p} , suffix_{sa_k}) \geq r\)的\(p\)一定是一段区间,而且这一段区间随着\(r\)的缩小不断增大。
然后我们考虑如何拓展区间。考虑对于\(height_k=q\),当\(r>q\)的时候\(k\)位置两端的区间不会越过\(k-1\)与\(k\),而当\(r \leq q\)时这两段区间就会合成一段区间。这个显然是可以使用并查集维护的,并且可以比较轻松地在并查集上维护最大价值。
#include<bits/stdc++.h>
#define mid ((l + r) >> 1)
#define lch Tree[x].l
#define rch Tree[x].r
//This code is written by Itst
using namespace std;
inline int read(){
int a = 0;
char c = getchar();
bool f = 0;
while(!isdigit(c) && c != EOF){
if(c == '-')
f = 1;
c = getchar();
}
if(c == EOF)
exit(0);
while(isdigit(c)){
a = (a << 3) + (a << 1) + (c ^ '0');
c = getchar();
}
return f ? -a : a;
}
const int MAXN = 3e5 + 10;
int fa[MAXN] , val[MAXN] , valMax[MAXN][2] , valMin[MAXN][2];
int sa[MAXN] , rk[MAXN] , pot[MAXN] , tp[MAXN << 1] , h[MAXN];
int ind[MAXN] , size[MAXN] , N , maxN = 26;
char s[MAXN];
long long Max , cnt , ans[MAXN][2];
int find(int x){
return fa[x] == x ? x : (fa[x] = find(fa[x]));
}
void Debug(){
for(int i = 1 ; i <= N ; ++i)
cout << sa[i] << ' ';
cout << endl;
for(int i = 1 ; i <= N ; ++i)
cout << ind[i] << ' ';
cout << endl << endl;
}
void input(){
N = read();
scanf("%s" , s + 1);
for(int i = 1 ; i <= N ; ++i){
val[i] = read();
if(val[i] < 0)
valMin[i][0] = val[i];
}
}
void sort(int p){
memset(pot , 0 , sizeof(pot));
for(int i = 1 ; i <= N ; ++i)
++pot[rk[i]];
for(int i = 1 ; i <= maxN ; ++i)
pot[i] += pot[i - 1];
for(int i = 1 ; i <= N ; ++i)
sa[++pot[rk[tp[i]] - 1]] = tp[i];
memcpy(tp , rk , sizeof(int) * (N + 1));
for(int i = 1 ; i <= N ; ++i)
rk[sa[i]] = rk[sa[i - 1]] + (tp[sa[i]] != tp[sa[i - 1]] || tp[sa[i] + p] != tp[sa[i - 1] + p]);
maxN = rk[sa[N]];
}
bool cmp(int a , int b){
return h[a] < h[b];
}
void init(){
memset(valMax , -0x3f , sizeof(valMax));
Max = -1ll * 0x3f3f3f3f * 0x3f3f3f3f;
for(int i = 1 ; i <= N ; ++i)
rk[tp[i] = i] = s[i] - 'a' + 1;
sort(0);
for(int i = 1 ; i <= N && maxN < N ; i <<= 1){
int cnt = 0;
for(int j = 1 ; j <= i ; ++j)
tp[++cnt] = N - i + j;
for(int j = 1 ; j <= N ; ++j)
if(sa[j] > i)
tp[++cnt] = sa[j] - i;
sort(i);
}
for(int i = 1 ; i <= N ; ++i){
if(rk[i] == 1)
continue;
int t = rk[i];
h[t] = max(0 , h[rk[i - 1]] - 1);
while(s[sa[t] + h[t]] == s[sa[t - 1] + h[t]])
++h[t];
ind[t] = t;
}
sort(ind + 2 , ind + N + 1 , cmp);
for(int i = 1 ; i <= N ; ++i){
fa[i] = i;
size[i] = 1;
valMax[i][0] = val[i];
}
}
inline void merge(int x , int y){
fa[x] = y;
int num[4] = {valMax[x][0] , valMax[x][1] , valMax[y][0] , valMax[y][1]};
sort(num , num + 4);
valMax[y][0] = num[3];
valMax[y][1] = num[2];
Max = max(Max , 1ll * valMax[y][0] * valMax[y][1]);
num[0] = valMin[x][0];
num[1] = valMin[x][1];
num[2] = valMin[y][0];
num[3] = valMin[y][1];
sort(num , num + 4);
valMin[y][0] = num[0];
valMin[y][1] = num[1];
if(1ll * valMin[y][0] * valMin[y][1])
Max = max(Max , 1ll * valMin[y][0] * valMin[y][1]);
cnt -= 1ll * size[x] * (size[x] - 1) / 2 + 1ll * size[y] * (size[y] - 1) / 2;
size[y] += size[x];
cnt += 1ll * size[y] * (size[y] - 1) / 2;
}
void work(){
int p = N;
for(int i = N - 1 ; i >= 0 ; --i){
while(p > 1 && h[ind[p]] == i){
merge(find(sa[ind[p]]) , find(sa[ind[p] - 1]));
--p;
}
if(cnt){
ans[i][0] = cnt;
ans[i][1] = Max;
}
}
}
void output(){
for(int i = 0 ; i <= N - 1 ; ++i)
cout << ans[i][0] << ' ' << ans[i][1] << '\n';
}
int main(){
input();
init();
work();
output();
return 0;
}