Description
给定一个长度为n的颜色序列a
求四元组
的数量
Solution
以下为了方便说明,用大写字母表示一种颜色,
就是一个合法四元组
直接计算比较繁琐,考虑补集转化
如果我们能算出 和 的数量,那就可以求出 的数量了
考虑计算 ,这其实就相当于把 的限制去掉了。
枚举 ,计算两边有多少对 ,再减掉y,q相同颜色的,乘上左边x的数量就是答案
计算两边有多少对 ,可以维护增量,从左向右扫的时候,每向右一个位置,相当于将一个位置从右边移到了左边,加上多出来的减去删掉的贡献即可。
这样 就算完了
考虑如何计算
直接用数据结构好像没有什么好的办法,于是往平衡规划的方向来想。
我们设定一个阈值 ,出现次数小于等于K的颜色记作Q,大于K的颜色记作P
那么P的总数不超过
,Q的大小的平方和不超过
对于四元组分类讨论
-
:枚举颜色P,从左到右扫,设当前扫到位置 ,i之前P的出现次数为 ,它在颜色a[i]中是第t个,那么明显贡献就是 只需要用一个桶记一下i前面每个颜色的j的 和即可。时间复杂度
-
:同样枚举颜色P,从左到右扫,位置i的贡献就是 ,把式子拆开,发现只需要多维护s[]的平方和即可。时间复杂度
-
:此时所有二元组 的总数不超过nK,那么我们将所有二元组 看做二维平面上的点,现在就相当于求点 的对数,就变成二维数点问题,将所有点按照纵坐标从小到大扫一遍,树状数组查询即可。时间复杂度
总复杂度为 ,容易发现当 时总复杂度最小,为
Code
#pragma GCC optimize(2)
#include <cstdio>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define N 676676
#define mo 19260817
#define LL long long
using namespace std;
int n,a[N],d[N],n1,num;
vector<int> pt[N];
LL ans,c1[N],c2[N],cnt[N],le[N],ny2,c[N];
struct node
{
int x,y;
friend bool operator <(node x,node y)
{
return x.y<y.y;
}
}pr[N*20];
int lowbit(int k)
{
return k&(-k);
}
LL get(int k)
{
LL s=0;
while(k) s+=c[k],k-=lowbit(k);
return s%mo;
}
void ins(int k)
{
while(k<=n) c[k]++,k+=lowbit(k);
}
int main()
{
cin>>n;
fo(i,1,n)
{
scanf("%d",&a[i]),cnt[i]=++le[a[i]],pt[a[i]].push_back(i);
if(le[a[i]]==1) d[++d[0]]=a[i];
}
ans=0;
LL v=0;
fo(i,1,n)
{
v=(v-(cnt[i]-1)+mo)%mo;
ans=(ans+(v-(cnt[i]-1)*(le[a[i]]-cnt[i])%mo+mo)%mo*(cnt[i]-1)%mo+mo)%mo;
v=(v+(le[a[i]]-cnt[i])+mo)%mo;
}
n1=sqrt(n/log2(n)*3);
ny2=(mo+1)/2;
num=0;
fo(i,1,d[0])
{
int cl=d[i];
if(le[cl]>n1)
{
memset(c1,0,sizeof(c1));
memset(c2,0,sizeof(c2));
LL v=0;
fo(j,1,n)
{
if(a[j]==cl) v++;
else
{
ans=(ans-(le[cl]-v)*c1[a[j]]%mo+mo)%mo;
if(le[a[j]]<=n1)
{
LL vs=((c2[a[j]]+(cnt[j]-1)*v%mo*(v-1)%mo+c1[a[j]]+mo)%mo*ny2-c1[a[j]]*v%mo+mo)%mo;
ans=(ans-vs+mo)%mo;
}
c1[a[j]]=(c1[a[j]]+v)%mo;
c2[a[j]]=(c2[a[j]]+v*v)%mo;
}
}
}
}
fo(i,1,n)
{
if(le[a[i]]<=n1)
{
fo(u,0,le[a[i]]-1)
{
int p=pt[a[i]][u];
if(p==i) break;
ans=(ans+(cnt[p]-1)*(le[a[i]]-cnt[i])%mo);
pr[++num]=(node){p,i};
}
}
}
int j=1,sm=0;
fo(i,1,n)
{
int p=j;
while(j<=num&&pr[j].y<=i) ans=(ans-(sm-get(pr[j].x))%mo+mo)%mo,j++;
fo(k,p,j-1) ins(pr[k].x),sm=(sm+1)%mo;
}
printf("%lld\n",ans);
}