版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/yzyyylx/article/details/87907279
题面
题意
给出一个长度为n的序列a,将它的所有子区间的gcd计算出来并存入数组b(长度为 ),并将b区间排序,然后将数组b的所有子区间的和计算出来存入数组c,问数组c的中位数是多少。
做法
首先因为数组a中的所有数都小于等于100000,因此gcd的数字种类也在这个范围内,对数组a可以用倍增求出每个数字作为gcd的次数。
然后考虑二分答案,可以只要统计出数组b中小于等于mid的区间和的个数即可,若数组b的长度很小,则可以直接尺取。可是现在数组b长度很大,但是一共只有至多100000种数字,因此仍然可以尺取。
首先考虑种数字内部的贡献,这部分比较容易计算。
然后考虑左端点为l,右端点为r的贡献,若
,则贡献显然为
,用前缀和维护一下即可,比较难处理的是
的情况。
这种情况的答案可以看作是
右边是一个常数,这里计作c,令
为
的非负整数解的数量。
则这部分贡献经过容斥之后可以看作是
现在考虑计算
可以发现
函数可以直接用类欧几里得算法求。
代码
#include<bits/stdc++.h>
#define ll long long
#define LG 16
#define N 100100
#define MN 100000
using namespace std;
ll n,m,num[N],cnt[N],qzc[N],qzs[N];
namespace Get
{
ll g[N][20];
inline ll gcd(ll u,ll v)
{
for(;u&&v&&u!=v;)
{
swap(u,v);
u%=v;
}
return max(u,v);
}
void work()
{
ll i,j,k,t,l;
for(i=1;i<=n;i++) g[i][0]=num[i];
for(i=1;i<=LG;i++)
{
for(j=1;j+(1 << (i-1))<=n;j++)
{
g[j][i]=gcd(g[j][i-1],g[j+(1 << (i-1))][i-1]);
}
}
for(i=1;i<=n;i++)
{
for(j=l=i,t=num[i];j<=n;)
{
t=gcd(t,num[j]);
for(k=LG;k>=0;k--)
{
if((j+(1 << k))>n+1) continue;
if(g[j][k]%t==0)
{
j+=(1 << k);
}
}
cnt[t]+=j-l;
l=j;
}
}
}
}
ll f(ll a,ll b,ll c,ll n)
{
if(n<0) return 0;
if(!a) return b/c*(n+1);
if(a>=c || b>=c) return f(a%c,b%c,c,n)+a/c*n*(n+1)/2+b/c*(n+1);
ll m=(a*n+b)/c;
return m*n-f(c,c-b-1,a,m-1);
}
inline ll ask(ll a,ll b,ll c)
{
if(c<0) return 0;
return c/a+1+f(a,c%a,b,c/a);
}
inline ll solve(ll a,ll ca,ll b,ll cb,ll c)
{
if(a*(ca-1)+b*(cb-1)<=c) return ca*cb;
if(c<0) return 0;
return ask(a,b,c)-ask(a,b,c-a*ca)-ask(a,b,c-b*cb)+ask(a,b,c-a*ca-b*cb);
}
inline ll calc(ll u)
{
ll i,j,l,r,t,res=0;
for(i=1;i<=MN;i++)
{
if(!cnt[i]) continue;
t=min(cnt[i],u/i);
res+=cnt[i]*t-t*(t-1)/2;
}
for(l=r=1;l<MN;l++)
{
if(!cnt[l]) continue;
if(l<r) res+=cnt[l]*(qzc[r-1]-qzc[l]);
for(;r<=MN&&qzs[r]-qzs[l]<=u;r++)
{
if(l!=r && cnt[r])
{
res+=solve(l,cnt[l],r,cnt[r],u-(qzs[r-1]-qzs[l])-l-r);
}
}
if(r<=MN)
{
if(l!=r && cnt[r])
{
res+=solve(l,cnt[l],r,cnt[r],u-(qzs[r-1]-qzs[l])-l-r);
}
}
}
return res;
}
int main()
{
ll i,j,l,r,mid;
cin>>n;
for(i=1;i<=n;i++)
{
scanf("%lld",&num[i]);
}
Get::work();
m=(n+1)*n/2;
m=(m+1)*m/2;
m=(m+1)/2;
for(i=1;i<=MN;i++)
{
qzc[i]=qzc[i-1]+cnt[i];
qzs[i]=qzs[i-1]+cnt[i]*i;
}
for(l=1,r=qzs[MN]+1;l<r;)
{
mid=((l+r)>>1);
if(calc(mid)<m) l=mid+1;
else r=mid;
}
cout<<l;
}