透彻理解树状数组

木桩涂涂看

通过这个题彻底理解了树状数组,其实getsum(i)就是最后i这个位置的值是多少,树状数组只不过是对区间操作进行了一次加强,使得不使用o(n)的时间,最后时间节省为o(lgn),刚开始理解老是觉得getsum(i)就是将所有之前的全部加起来,其实人家有个lowbit操作就很神奇,并非所想的那样。使用其实很简单,板子放main上面,下面就套用就想成一般的,只不过最后人家对应的是getsum(i)是最后这个位置上的值。


//透彻理解树状数组中的getsum到底是啥
#include<iostream>
#include<cstdio>

using namespace std;

const int maxn=1e5;

int tree[maxn];

int lowbit(int t)
{
    return t&(-t);
}

void change(int t,int v)
{
    for(;t<=maxn;t+=lowbit(t)){
        tree[t]+=v;
    }
}
int getsum(int t)
{
    int res=0;
    for(;t;t-=lowbit(t)){
        res+=tree[t];
    }
    return res;
}

int main()
{
    int n;
    scanf("%d",&n);
    int a,b;
    for(int i=0;i<n;i++){
        scanf("%d%d",&a,&b);
        change(a,1);
        change(b+1,-1);
    }
    for(int i=1;i<=n;i++){
        if(i!=1){
            printf(" ");
        }
        printf("%d",getsum(i));
    }
    return 0;
}

还有一种使用线段树Lazy操作

//线段树Lazy操作
#include<iostream>
#include<cstdio>

using namespace std;

typedef long long LL;

const int maxn=1e5+100;

struct node{
    LL l,r,sum;
};
node tree[maxn<<2];
LL lazy[maxn<<2];
LL a[maxn];

lazy核心操作
void push_down(LL p)
{
    lazy[p<<1]+=lazy[p];
    lazy[p<<1|1]+=lazy[p];
    tree[p].sum+=(tree[p].r-tree[p].l+1)*lazy[p];
    lazy[p]=0;
}

void build(LL p,LL l,LL r)
{
    tree[p].l=l;
    tree[p].r=r;
    tree[p].sum=0;
    if(l==r){
        return ;
    }
    int mid=(l+r)>>1;
    build(p<<1,l,mid);
    build(p<<1|1,mid+1,r);
}

void update(LL p,LL l,LL r,LL v)
{
    if(tree[p].l==l&&tree[p].r==r){
        lazy[p]+=v;
        return ;
    }
    tree[p].sum+=(r-l+1)*v;
    if(lazy[p]!=0){
        push_down(p);
    }
    LL mid=(tree[p].l+tree[p].r)>>1;
    if(r<=mid){
        update(p<<1,l,r,v);
    }else if(l>mid){
        update(p<<1|1,l,r,v);
    }else{
        update(p<<1,l,mid,v);
        update(p<<1|1,mid+1,r,v);
    }
}

LL find(LL p,LL l,LL r)
{
    if(tree[p].l==l&&tree[p].r==r){
        return tree[p].sum+(r-l+1)*lazy[p];
    }
    if(lazy[p]!=0){
        push_down(p);
    }
    LL mid=(tree[p].l+tree[p].r)>>1;
    if(r<=mid){
        return find(p<<1,l,r);
    }else if(l>mid){
        return find(p<<1|1,l,r);
    }else{
        return find(p<<1,l,mid)+find(p<<1|1,mid+1,r);
    }
}

int main()
{
    int n;
    scanf("%d",&n);
    build(1,1,n);
    int x,y;
    for(int i=1;i<=n;i++){
        scanf("%d%d",&x,&y);
        update(1,x,y,1);
    }
    for(int i=1;i<=n;i++){
        if(i!=1){
            printf(" ");
        }
        printf("%lld",find(1,i,i));
    }
    printf("\n");
    return 0;
}

猜你喜欢

转载自blog.csdn.net/zhouzi2018/article/details/81096060