题目大意:给一个长为 n 的序列,以及 m 条直线,从中任选三条线,定义ans 为 这三条线都覆盖的点的个数,求这个个数的期望。
首先选直线的方案总数是C[m][3],期望就是所有可行的方案的ans 之和 除以 C[m][3]。
一开始我是用差分做的,第一眼觉得完全没必要用到线段树,结果哇了,事实上确实可以用差分做,只不过第一次忽略了一些处理。
第二次是用线段树做的:考虑用 n 来建树,我们可以维护一个区间内被几条线覆盖这样的信息,然后查询的时候就是查询被多条线覆盖的区间,设有x条覆盖了这段区间,那么答案贡献就是C[x][3] * 区间长度。
但是我们只能维护哪一段区间被几条线覆盖,不能维护哪一段区间被几条线完全覆盖。
事实上,考虑三条线公共覆盖的一段区域 【l,r】,我们求解答案时能否将区间拆开来求?例如拆成 [l,mid],[mid + 1,r] ,求出分别的贡献然后合并在一起,简单的试一下发现是可行的。
这有什么用呢?因为叶子结点长度为1,一定是被完全覆盖的,所以我们只需要查找叶子结点就OK了。
然后有个小优化就是,如果当前查找的这段区间整段都没有被超过三条线覆盖,那么我们直接剪掉这后面的枝,也就是不在往下查找。
仔细一想发现我们通过查找叶子结点更新答案,那其实完全没有必要建线段树了,前面那些结点对答案完全没贡献,我们没必要维护那些结点。
考虑差分的做法,其实完全一样,当输入一条线的时候我们O(1)维护一下这条线覆盖的区域,最后O(n)求一下前缀和,然后再O(n)求解答案,复杂度为O(n + m)。比线段树更优,代码又短。
主要的思考是答案的区间可加性,用线段树一般会将查询区间拆分,合并答案的时候需要问题符合区间可加性原则。这题也只是用到了这个原则,而实际上可以不用线段树来求解(因为最后查的全是叶子结点),通过这题加深了对区间可加性的理解。
贴两份代码:
线段树:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e4+10;
int sum[maxn << 2],add[maxn << 2];
long long c[maxn][3];
long long gcd(long long a,long long b){
return !b ? a : gcd(b,a % b);
}
void build(int rt,int l,int r){
sum[rt] = 0;
add[rt] = 0;
if(l == r) return ;
int mid = l + r >> 1;
build(rt << 1,l,mid);
build(rt << 1 | 1,mid + 1,r);
}
void pushdown(int rt){
if(!add[rt]) return;
int lson = rt << 1,rson = rt << 1 | 1;
add[lson] += add[rt];
add[rson] += add[rt];
sum[lson] += add[rt];
sum[rson] += add[rt];
add[rt] = 0;
}
void update(int rt,int l,int r,int L,int R){
if(L <= l && r <= R){
sum[rt]++;
add[rt]++;
return ;
}
pushdown(rt);
int mid = l + r >> 1;
if(L <= mid) update(rt << 1,l,mid,L,R);
if(mid + 1 <= R) update(rt << 1 | 1,mid + 1,r,L,R);
sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}
long long query(int rt,int l,int r){
if(sum[rt] < 3) return 0;
if(l == r){
return c[sum[rt]][3];
}
pushdown(rt);
long long res = 0;
int mid = l + r >> 1;
res += query(rt << 1,l,mid);
res += query(rt << 1 | 1,mid + 1,r);
return res;
}
int n,m,x,y;
int main(){
c[1][3] = c[2][3] = c[0][3] = 0;
c[3][3] = 1;
for(long long i = 4; i <= maxn; i++)
c[i][3] = i * (i - 1) * (i - 2) / 6;
cin >> n >> m;
build(1,1,n);
for(int i = 1; i <= m; i++){
scanf("%d%d",&x,&y);
update(1,1,n,x,y);
}
long long res = query(1,1,n);
long long tmp = c[m][3];
long long g = gcd(res,tmp);
if(tmp == 0){
printf("%lld\n",0);
}
else{
if(tmp/g != 1 && res != 0)
printf("%lld/%lld\n",res/g,tmp/g);
else printf("%lld\n",res/g);
}
return 0;
}
差分:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5+10;
int sum[maxn + 1];
int a[maxn + 1];
long long c[maxn + 1][3];
long long gcd(long long a,long long b){
return !b?a:gcd(b,a%b);
}
int main(){
int n,m,x,y;
scanf("%d%d",&n,&m);
c[1][3] = c[2][3] = c[0][3] = 0;
c[3][3] = 1;
for(long long i = 4; i <= maxn; i++)
c[i][3] = i * (i - 1) * (i - 2) / 6;
for(int i = 1; i <= m; i++){
scanf("%d%d",&x,&y);
a[x]++;a[y + 1]--;
}
for(int i = 1; i <= n; i++){
a[i] += a[i - 1];
}
long long res = 0;
for(int i = 1; i <= n; i++){
res += c[a[i]][3];
}
long long t = c[m][3];
long long g = gcd(res,t);
if(t == 0){
printf("%lld\n",0);
}
else{
if(t/g != 1 && res != 0)
printf("%lld/%lld\n",res/g,t/g);
else printf("%lld\n",res/g);
}
return 0;
}