转载博客:戳这里
树状数组的用途就是维护一个数组,重点不是这个数组,而是要维护的东西,最常用的求区间和问题,单点更新。但是某些大牛YY出很多神奇的东西,完成部分线段树能完成的功能,比如区间更新,区间求最值问题。
树状数组当然是跟树有关了,但是这个树是怎么构建的呐?这里就不得不感叹大牛们的脑洞之大了,竟然能想出来用二进制末尾零的个数多少来构建树以下图为例:
从上图能看出来每一个数的父节点就是右边比自己末尾零个数多的最近的一个,也就是x的父节点就是x+(x&(-x)),这里为什么可以参考计算机位运算,x&(-x)就能得出自己末尾0的个数例如10&(-10)=(0010)二进制。每一个节点保存的就是以他为根节点的数的和,这样就得出来了更新树状数组的函数:
int lowbit(int x) { return x&(-x); } void uodate(int x) { while(x<Max) { c[x]+=val; x+=lowbit; } }
树状数组虽然将数据用树形结构组织起来的但是还是很乱怎么办呐?实际上树状数组维护的是数组的前缀和,比如sum(x)就是a[x]的前缀和,想查询l~r区间的元素和只需要求出来sum(r)-sum(l-1),这里的sum函数十分的神奇:
{ int s=0; while(x>0) { s+=c[x]; x-=lowbit; } return s; }
x-(x&(-x))刚巧是前一个棵树的根节点,这样就能求出1到x和,以x=9为例,9&(-9)=1;这样x-(x&(-x))=8,刚巧是前一棵树的根节点。
例题:poj 2352 stars http://poj.org/problem?id=2352;
题意给你n个星星的坐标,每一个星星的等级为:在不在这个星星右边并且不比这个星星高的星星的个数 然后出处每个等级星星的个数 树状数组,一开始把上一个题的模板扒过来的......真是伤啊,这个题更新点的时候要把右边的点更新到MAXN要不然会漏掉条件的 */ #include<iostream> #include<stdio.h> #include<string.h> #include<string> #include<algorithm> #define N 32010 using namespace std; int n; int c[N]; int cur[N];//统计每个等级的星星 int lowbit(int x) { return x&(-x); } int getx(int x) { int ans=0; while(x>0) { ans+=c[x]; x-=lowbit(x); } return ans; } void update(int x) { while(x<=N) { c[x]++; x+=lowbit(x); } } int main() { //freopen("in.txt","r",stdin); int x,y; while(scanf("%d",&n)!=EOF&&n) { memset(cur,0,sizeof cur); memset(c,0,sizeof c); for(int i=1;i<=n;i++) { scanf("%d%d",&x,&y); update(x+1); //for(int j=1; j<=n; j++) // cout<<c[j]<<" "; //cout<<endl; //cout<<getx(x+1)-1<<endl; cur[getx(x+1)-1]++; } //cout<<endl; for(int i=0;i<n;i++) printf("%d\n",cur[i]); } return 0; }
然后就是更高层次的操作了,区间更新。区间更新这里引进了一个数组delta数组,delta[i]表示区间 [i, n] 的共同增量,每次你需要更新的时候只需要更新delta数组就行了,因为每段区间更新的数都记录在这个数组中,那怎么查询前缀和呐?
sum[i]=a[1]+a[2]+a[3]+......+a[i]+delta[1]*(i-0)+delta[2]*(i-1)+delta[3]*(i-2)+......+delta[i]*(i-i+1); = sigma( a[x] ) + sigma( delta[x] * (i + 1 - x) ) = sigma( a[x] ) + (i + 1) * sigma( delta[x] ) - sigma( delta[x] * x )
红字不难理解就是从当前位置到i区间共同的增量乘上当前位置到i有多少个数就是增加的总量。
例题 :poj 3468 A Simple Problem with Integers http://poj.org/problem?id=3468
#include<stdio.h> #include<string.h> #define N 100010 #define ll long long using namespace std; ll c1[N];//delta的前缀和 ll c2[N];//delta * i的前缀和 ll ans[N];//存放的前缀和 ll n,m; string op; ll lowbit(ll x) { return x&(-x); } void update(ll x,ll val,ll *c) { while(x<=n) { c[x]+=val; x+=lowbit(x); } } ll getsum(ll x,ll *c) { ll s=0; while(x>0) { s+=c[x]; x-=lowbit(x); } return s; } int main() { freopen("C:\\Users\\acer\\Desktop\\in.txt","r",stdin); while(scanf("%lld%lld",&n,&m)!=EOF) { memset(c1,0,sizeof c1); memset(c2,0,sizeof c2); memset(ans,0,sizeof ans); for(int i=1;i<=n;i++) { scanf("%lld",&ans[i]); ans[i]+=ans[i-1]; } getchar(); for(int i=1;i<=m;i++) { cin>>op; if(op=="C") { ll s1,s2,s3; scanf("%lld%lld%lld",&s1,&s2,&s3); update(s1,s3,c1);//delta的前缀和 更新 update(s2+1,-s3,c1); update(s1,s1*s3,c2);//delta * i的前缀和 更新 update(s2+1,-(s2+1)*s3,c2); } else if(op=="Q") { /* sigma( a[x] ) + (i + 1) * sigma( delta[x] ) - sigma( delta[x] * x ) */ ll s1,s2; scanf("%lld%lld",&s1,&s2); /*sigma( a[x] )*/ ll cur=ans[s2]-ans[s1-1];//首先等于s1~s2这个区间的基础值 /*(i + 1) * sigma( delta[x] )*/ cur+=getsum(s2,c1)*(s2+1)-getsum(s2,c2);//0~s2对前缀和的影响 /*sigma( delta[x] * x )*/ cur-=getsum(s1-1,c1)*(s1)-getsum(s1-1,c2);//0~s1对前缀和的影响 printf("%lld\n",cur); } } } }
接着就是RMQ算法,用来求区间最值,直接求当然是不现实的,因为数据很多的时候,复杂度太高,这样就要先进性预处理,dp[i][j]表示从i开始2^j范围内的最值,这样能推出状态转移方程 dp[i][j]=max(dp[i][j-1],dp[i+(1<<(j-1)][j-1])或者min(dp[i][j-1],dp[i+(1<<(j-1)][j-1])。怎么得出来这个方程的呐?就是以i为起点2^j的状态能由以i为起点到2^j这个范围的中点2^(j-1)左右两个部分的最值得到。
首先是预处理部分:
void RMQ_init(int n) { for(int j=1;j<20;j++) for(int i=1;(i+(1<<j)-1)<=n;i++) { dp1[i][j]=max(dp1[i][j-1],dp1[i+(1<<(j-1))][j-1]); dp2[i][j]=min(dp2[i][j-1],dp2[i+(1<<(j-1))][j-1]); } }
然后是查询
int RMQ(int L,int R) { int k=(int)(log(R-L+1.0)/log(2.0)); return max(dp1[L][k],dp1[R-(1<<k)+1][k]);或者return min(dp2[L][k],dp2[R-(1<<k)+1][k]); }
查询是什么原理呐?就是l到r的长度内取k的最大值使得2^k<(r-l+1);这样查询l到l+2^k内的最值和r-2^k到r内的最值,虽然中间有些元素有些重复但是不会影响正确结果,但是查询区间和的时候就不能这么重复了。
例题 士兵杀敌(三)http://acm.nyist.net/JudgeOnline/problem.php?pid=119
#include <bits/stdc++.h> #define N 100010 using namespace std; int dp1[N][20];//存放最大值 int dp2[N][20];//存放最小值 int n,m,a; int l,r; void RMQ_init(int n) { for(int j=1;j<20;j++) for(int i=1;(i+(1<<j)-1)<=n;i++) { dp1[i][j]=max(dp1[i][j-1],dp1[i+(1<<(j-1))][j-1]); dp2[i][j]=min(dp2[i][j-1],dp2[i+(1<<(j-1))][j-1]); //cout<<"dp1[i][j]="<<dp1[i][j]<<endl; //cout<<"dp2[i][j]="<<dp2[i][j]<<endl; } } int RMQ(int L,int R) { int k=(int)(log(R-L+1.0)/log(2.0)); //cout<<"max(dp1[L][k],dp1[R-(1<<k)+1][k])="<<max(dp1[L][k],dp1[R-(1<<k)+1][k])<<endl; //cout<<"min(dp2[L][k],dp2[R-(1<<k)+1][k])="<<min(dp2[L][k],dp2[R-(1<<k)+1][k])<<endl; return max(dp1[L][k],dp1[R-(1<<k)+1][k])-min(dp2[L][k],dp2[R-(1<<k)+1][k]); } int main() { //freopen("C:\\Users\\acer\\Desktop\\in.txt","r",stdin); //memset(dp1,0,sizeof dp1); //memset(dp2,0,sizeof dp2); scanf("%d %d",&n,&m); for(int i=1;i<=n;i++) { scanf("%d",&dp1[i][0]); dp2[i][0]=dp1[i][0]; } RMQ_init(n); for(int i=1;i<=m;i++) { scanf("%d%d",&l,&r); printf("%d\n",RMQ(l,r)); } return 0; }