蓝桥杯国赛训练营:蒜头君的玩具(线段树 or 差分)

题目大意:给一个长为 n 的序列,以及 m 条直线,从中任选三条线,定义ans 为 这三条线都覆盖的点的个数,求这个个数的期望。

首先选直线的方案总数是C[m][3],期望就是所有可行的方案的ans 之和 除以 C[m][3]。
一开始我是用差分做的,第一眼觉得完全没必要用到线段树,结果哇了,事实上确实可以用差分做,只不过第一次忽略了一些处理。

第二次是用线段树做的:考虑用 n 来建树,我们可以维护一个区间内被几条线覆盖这样的信息,然后查询的时候就是查询被多条线覆盖的区间,设有x条覆盖了这段区间,那么答案贡献就是C[x][3] * 区间长度。
但是我们只能维护哪一段区间被几条线覆盖,不能维护哪一段区间被几条线完全覆盖。
事实上,考虑三条线公共覆盖的一段区域 【l,r】,我们求解答案时能否将区间拆开来求?例如拆成 [l,mid],[mid + 1,r] ,求出分别的贡献然后合并在一起,简单的试一下发现是可行的。
这有什么用呢?因为叶子结点长度为1,一定是被完全覆盖的,所以我们只需要查找叶子结点就OK了。
然后有个小优化就是,如果当前查找的这段区间整段都没有被超过三条线覆盖,那么我们直接剪掉这后面的枝,也就是不在往下查找。

仔细一想发现我们通过查找叶子结点更新答案,那其实完全没有必要建线段树了,前面那些结点对答案完全没贡献,我们没必要维护那些结点。

考虑差分的做法,其实完全一样,当输入一条线的时候我们O(1)维护一下这条线覆盖的区域,最后O(n)求一下前缀和,然后再O(n)求解答案,复杂度为O(n + m)。比线段树更优,代码又短。

主要的思考是答案的区间可加性,用线段树一般会将查询区间拆分,合并答案的时候需要问题符合区间可加性原则。这题也只是用到了这个原则,而实际上可以不用线段树来求解(因为最后查的全是叶子结点),通过这题加深了对区间可加性的理解。

贴两份代码:
线段树:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e4+10;
int sum[maxn << 2],add[maxn << 2];
long long c[maxn][3];
long long gcd(long long a,long long b){
 	return !b ? a : gcd(b,a % b);
}
void build(int rt,int l,int r){
 	sum[rt] = 0;
 	add[rt] = 0;
 	if(l == r) return ;
 	int mid = l + r >> 1;
 	build(rt << 1,l,mid);
 	build(rt << 1 | 1,mid + 1,r);
}
void pushdown(int rt){
 	if(!add[rt]) return;
 	int lson = rt << 1,rson = rt << 1 | 1;
 	add[lson] += add[rt];
 	add[rson] += add[rt];
 	sum[lson] += add[rt];
 	sum[rson] += add[rt];
 	add[rt] = 0; 
}
void update(int rt,int l,int r,int L,int R){
	if(L <= l && r <= R){
		 sum[rt]++;
		 add[rt]++;
		 return ;
	}
	pushdown(rt);
	int mid = l + r >> 1;
	if(L <= mid) update(rt << 1,l,mid,L,R);
	if(mid + 1 <= R) update(rt << 1 | 1,mid + 1,r,L,R);
 	sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}
long long query(int rt,int l,int r){
	if(sum[rt] < 3) return 0;
	if(l == r){
	  	return c[sum[rt]][3];
	}
	pushdown(rt);
	long long res = 0;
	int mid = l + r >> 1;
	res += query(rt << 1,l,mid);
	res += query(rt << 1 | 1,mid + 1,r);
	return res;
}
int n,m,x,y;
int main(){
 	c[1][3] = c[2][3] = c[0][3] = 0;
 	c[3][3] = 1;
 	for(long long i = 4; i <= maxn; i++)
  		c[i][3] = i * (i - 1) * (i - 2) / 6;
 	cin >> n >> m;
 	build(1,1,n);
 	for(int i = 1; i <= m; i++){
  		scanf("%d%d",&x,&y);
  		update(1,1,n,x,y);
 	}
 	long long res = query(1,1,n);
 	long long tmp = c[m][3];
 	long long g = gcd(res,tmp);
 	if(tmp == 0){
 		printf("%lld\n",0);
 	}
 	else{
  		if(tmp/g != 1 && res != 0)
   			printf("%lld/%lld\n",res/g,tmp/g);
  		else printf("%lld\n",res/g);
 	}
 	return 0;
} 
 

差分:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5+10;
int sum[maxn + 1];
int a[maxn + 1];
long long c[maxn + 1][3];
long long gcd(long long a,long long b){
 	return !b?a:gcd(b,a%b);
}
int main(){
 	int n,m,x,y;
 	scanf("%d%d",&n,&m);
 	c[1][3] = c[2][3] = c[0][3] = 0;
 	c[3][3] = 1;
 	for(long long i = 4; i <= maxn; i++)
  		c[i][3] = i * (i - 1) * (i - 2) / 6; 
 	for(int i = 1; i <= m; i++){
  		scanf("%d%d",&x,&y);
  		a[x]++;a[y + 1]--;
 	}
 	for(int i = 1; i <= n; i++){
  		a[i] += a[i - 1];
 	}
 	long long res = 0;
 	for(int i = 1; i <= n; i++){
  		res += c[a[i]][3];
 	}
 	long long t = c[m][3];
 	long long g = gcd(res,t);
 	if(t == 0){
  		printf("%lld\n",0);
 	}
 	else{
  		if(t/g != 1 && res != 0)
  			printf("%lld/%lld\n",res/g,t/g);
  		else printf("%lld\n",res/g);
 	}
 	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_41997978/article/details/89374852