You are given a permutation p_1, p_2, …, p_n. A permutation of length n is a sequence such that each integer between 1 and n occurs exactly once in the sequence.
Find the number of pairs of indices (l, r) (1 ≤ l ≤ r ≤ n) such that the value of the median of p_l, p_{l+1}, …, p_r is exactly the given number m.
The median of a sequence is the value of the element which is in the middle of the sequence after sorting it in non-decreasing order. If the length of the sequence is even, the left of two middle elements is used.
For example, if a = [4, 2, 7, 5] then its median is 4 since after sorting the sequence, it will look like [2, 4, 5, 7] and the left of two middle elements is equal to 4. The median of [7, 1, 2, 9, 6] equals 6 since after sorting, the value 6 will be in the middle of the sequence.
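The median rule above can be sketched as a small helper. This is only an illustration of the definition, assuming a non-empty sequence; the name leftMedian is hypothetical and does not come from the original solution.

```cpp
#include <algorithm>
#include <vector>

// Median as defined in the statement: sort in non-decreasing order and,
// for an even length, take the left of the two middle elements.
// Assumption: the sequence is non-empty. `leftMedian` is a hypothetical name.
int leftMedian(std::vector<int> a) {
    std::sort(a.begin(), a.end());
    // For length len the wanted 0-based index is (len - 1) / 2:
    // len = 4 -> index 1 (left middle), len = 5 -> index 2 (exact middle).
    return a[(a.size() - 1) / 2];
}
```

Calling leftMedian({4, 2, 7, 5}) and leftMedian({7, 1, 2, 9, 6}) reproduces the two examples from the statement.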
Write a program to find the number of pairs of indices (l, r) (1 ≤ l ≤ r ≤ n) such that the value of the median of p_l, p_{l+1}, …, p_r is exactly the given number m.
The first line contains integers n and m (1 ≤ n ≤ 2⋅10^5, 1 ≤ m ≤ n): the length of the given sequence and the required value of the median.
The second line contains a permutation p_1, p_2, …, p_n (1 ≤ p_i ≤ n). Each integer between 1 and n occurs in p exactly once.
Print the required number.
5 4
2 4 5 3 1
4
5 5
1 2 3 4 5
1
15 8
1 15 2 14 3 13 4 8 12 5 11 6 10 7 9
48
In the first example, the suitable pairs of indices are: (1, 3), (2, 2), (2, 3) and (2, 4).
Brute force (time limit exceeded)
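#include<bits/stdc++.h>
using namespace std;
int a[200005],b[200005]; // a: the permutation, b: scratch copy that gets sorted
int main()
{
    int n,m;
    long long sum=0; // the number of pairs can exceed the int range
    cin>>n>>m;
    int i;
    for(i=0;i<n;i++)
    {
        cin>>a[i];
        b[i]=a[i];
    }
    int j;
    // try every segment [i, j], sort it and look at the left middle element
    for(i=0;i<n;i++)
    {
        for(j=i;j<n;j++)
        {
            sort(b+i,b+j+1);
            if(b[(i+j)/2]==m) // (i+j)/2 is the left middle index of [i, j]
                sum++;
            // restore the original order before the next segment
            for(int k=i;k<=j;k++)
                b[k]=a[k];
        }
    }
    cout<<sum;
}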
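Standard solution (it feels a bit like magic). Let cnt be the balance of a prefix: the number of elements greater than m minus the number of elements less than m. A segment that contains m has median m exactly when, inside the segment, the elements greater than m outnumber the elements less than m by 0 (odd length) or by 1 (even length, because the left of the two middle elements is used). So for every right end at or after the position of m, a left part ending before m works if its balance equals the current cnt (difference 0), and it also works if its balance is cnt-1, i.e. the front part has one extra smaller element, so that the segment itself has one more larger element than smaller ones (difference 1).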
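#include<bits/stdc++.h>
using namespace std;
long long ans=0;
map<int,int>mp; // mp[x]: number of prefixes ending before m whose balance is x
int n,m,cnt=0;  // cnt: (elements > m) - (elements < m) in the current prefix
int main()
{
    scanf("%d%d",&n,&m);
    int i;
    mp[0]=1; // the empty prefix has balance 0
    int flag=0; // becomes 1 once m has been read
    for(i=0;i<n;i++)
    {
        int k;
        scanf("%d",&k);
        if(k==m)
            flag=1;
        else if(k>m)
            cnt++;
        else if(k<m)
            cnt--;
        // before m: remember this prefix as a possible left part
        if(!flag)
            mp[cnt]++;
        // at or after m: the segment balance must be 0 or 1
        if(flag)
            ans=ans+mp[cnt]+mp[cnt-1];
    }
    printf("%lld",ans);
}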
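As a possible variant of the solution above, the map can be swapped for a plain array indexed by the shifted balance, which brings the running time from O(n log n) to O(n). This is only a sketch, written as a function so it can be checked against the samples; the names countSegmentsWithMedian, seen, off and passed are hypothetical and not part of the original code.

```cpp
#include <vector>

// Counts segments whose median (left middle for even length) equals m.
// Same prefix-balance idea as the map-based solution, with an offset array.
// Assumption: p is a permutation of 1..n and 1 <= m <= n.
long long countSegmentsWithMedian(const std::vector<int>& p, int m) {
    const int n = static_cast<int>(p.size());
    // The balance stays in [-n, n] and we also read balance - 1,
    // so shifting by n + 1 keeps every index non-negative.
    const int off = n + 1;
    std::vector<long long> seen(2 * n + 3, 0);
    seen[off] = 1;        // the empty prefix has balance 0
    int bal = 0;          // (elements > m) - (elements < m) so far
    bool passed = false;  // true once m itself has been read
    long long ans = 0;
    for (int x : p) {
        if (x == m) passed = true;
        else if (x > m) ++bal;
        else --bal;
        if (!passed) {
            ++seen[bal + off];  // candidate left part ending before m
        } else {
            // segment balance 0 (odd length) or 1 (even length)
            ans += seen[bal + off] + seen[bal - 1 + off];
        }
    }
    return ans;
}
```

On the three samples from the statement this function returns 4, 1 and 48.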