转载自:http://blog.csdn.net/qq_18455665/article/details/50989113
前言
- 首先说说出处:
- 清华大学 张昆玮(zkw) - ppt 《统计的力量》
- 本文
(辣鸡)编辑:BeiYu - 写这篇博客的原因:
1.zkw线段树非递归,效率高,代码短
2.网上关于zkw线段树的讲解实在是太少了
3.个人感觉很实用
更新日志
- 20160327-Part 1(zkw线段树的建立)
- 20160329-Part 2(单点操作)
- 20160329-Part 3(区间操作)
Part 1
来说说它的构造
线段树的堆式储存
我们来转成二进制看看
小学生问题:找规律
规律是很显然的
- 一个节点的父节点是这个数左移1,这个位运算就是低位舍弃,所有数字左移一位
- 一个节点的子节点是这个数右移1,是左节点,右移1+1是右节点
- 同一层的节点是依次递增的,第n层有2^(n-1)个节点
- 最后一层有多少节点,值域就是多少(这个很重要)
有了这些规律就可以开始着手建树了
- 查询区间[1,n]
最后一层不是2的次幂怎么办?
开到2的次幂!后面的空间我不要了!就是这么任性!
Build函数就这么出来了!找到不小于n的2的次幂
直接输入叶节点的信息
-
int n,M,q;
int d[N<<
1];
-
inline void Build(int n){
-
for(M=
1;M<n;M<<=
1);
-
for(
int i=M+
1;i<=M+n;i++) d[i]=in();
-
}
建完了?当然没有!父节点还都是空的呢!
维护父节点信息?
倒叙访问,每个节点访问的时候它的子节点已经处理过辣!
- 维护区间和?
for(int i=M-1;i;--i) d[i]=d[i<<1]+d[i<<1|1];
- 维护最大值?
for(int i=M-1;i;--i) d[i]=max(d[i<<1],d[i<<1|1]);
- 维护最小值?
for(int i=M-1;i;--i) d[i]=min(d[i<<1],d[i<<1|1]);
这样就构造出了一颗二叉树,也就是zkw线段树了!
如果你是压行选手的话(比如我),建树的代码只需要两行。
是不是特别Easy!
新技能Get√
Part 2
单点操作
- 单点修改
-
void Change(int x,int v){
-
d[M+x]+=v;
-
}
只是这么简单?当然不是,跟线段树一样,我们要更新它的父节点!
-
void Change(int x,int v){
-
d[x=M+x]+=v;
-
while(x) d[x>>=
1]=d[x<<
1]+d[x<<
1|
1];
-
}
没了?没了。
- 单点查询(差分思想,后面会用到)
把d维护的值修改一下,变成维护它与父节点的差值(为后面的RMQ问题做准备)
建树的过程就要修改一下咯!
-
void Build(int n){
-
for(M=
1;M<=n+
1;M<<=
1);
for(
int i=M+
1;i<=M+n;i++) d[i]=in();
-
for(
int i=M
-1;i;--i) d[i]=min(d[i<<
1],d[i<<
1|
1]),d[i<<
1]-=d[i],d[i<<
1|
1]-=d[i];
-
}
在当前情况下的查询
-
void Sum(int x,int res=0){
-
while(x) res+=d[x],x>>=
1;
return res;
-
}
Part 3
区间操作
询问区间和,把[s,t]闭区间换成(s,t)开区间来计算
-
int Sum(int s,int t,int Ans=0){
-
for (s=s+M
-1,t=t+M+
1;s^t^
1;s>>=
1,t>>=
1){
-
if(~s&
1) Ans+=d[s^
1];
-
if( t&
1) Ans+=d[t^
1];
-
}
return Ans;
-
}
- 为什么
~s&1
? -
为什么
t&1
?
变成开区间了以后,如果s是左儿子,那么它的兄弟节点一定在区间内,同理,如果t是右儿子,那么它的兄弟节点也一定在区间内! -
这样计算不会重复吗?
答案是会的!所以注意迭代的出口s^t^1
如果s,t就是兄弟节点,那么也就迭代完成了。
代码简单,即使背过也不难QuQ
- 区间最小值
-
void Sum(int s,int t,int L=0,int R=0){
-
for(s=s+M
-1,t=t+M+
1;s^t^
1;s>>=
1,t>>=
1){
-
L+=d[s],R+=d[t];
-
if(~s&
1) L=min(L,d[s^
1]);
-
if(t&
1) R=min(R,d[t^
1]);
-
}
-
int res=min(L,R);
while(s) res+=d[s>>=
1];
-
}
差分!
不要忘记最后的统计!
还有就是建树的时候是用的最大值还是最小值,这个一定要注意,影响到差分。
- 区间最大值
-
void Sum(int s,int t,int L=0,int R=0){
-
for(s=s+M
-1,t=t+M+
1;s^t^
1;s>>=
1,t>>=
1){
-
L+=d[s],R+=d[t];
-
if(~s&
1) L=max(L,d[s^
1]);
-
if(t&
1) R=max(R,d[t^
1]);
-
}
-
int res=max(L,R);
while(s) res+=d[s>>=
1];
-
}
同理。
- 区间加法
-
void Add(int s,int t,int v,int A=0){
-
for(s=s+M
-1,t=t+M+
1;s^t^
1;s>>=
1,t>>=
1){
-
if(~s&
1) d[s^
1]+=v;
if(t&
1) d[t^
1]+=v;
-
A=min(d[s],d[s^
1]);d[s]-=A,d[s^
1]-=A,d[s>>
1]+=A;
-
A=min(d[t],d[t^
1]);d[t]-=A,d[t^
1]-=A,d[t>>
1]+=A;
-
}
-
while(s) A=min(d[s],d[s^
1]),d[s]-=A,d[s^
1]-=A,d[s>>=
1]+=A;
-
}
同样是差分!差分就是厉害QuQ
zkw线段树小试牛刀(code来自hzwer.com)
-
#include<cstdio>
-
#include<iostream>
-
#define M 261244
-
using
namespace
std;
-
int tr[
524289];
-
void query(int s,int t)
-
{
-
int ans=
0;
-
for(s=s+M
-1,t=t+M+
1;s^t^
1;s>>=
1,t>>=
1)
-
{
-
if(~s&
1)ans+=tr[s^
1];
-
if(t&
1)ans+=tr[t^
1];
-
}
-
printf(
"%d\n",ans);
-
}
-
void change(int x,int y)
-
{
-
for(tr[x+=M]+=y,x>>=
1;x;x>>=
1)
-
tr[x]=tr[x<<
1]+tr[x<<
1|
1];
-
}
-
int main()
-
{
-
int n,m,f,x,y;
-
scanf(
"%d",&n);
-
for(
int i=
1;i<=n;i++){
scanf(
"%d",&x);change(i,x);}
-
scanf(
"%d",&m);
-
for(
int i=
1;i<=m;i++)
-
{
-
scanf(
"%d%d%d",&f,&x,&y);
-
if(f==
1)change(x,y);
-
else query(x,y);
-
}
-
return
0;
-
}
poj3468(code来自网络)
-
#include <cstdio>
-
#include <cstring>
-
#include <cctype>
-
#define N ((131072 << 1) + 10) //表示节点个数->不小于区间长度+2的最小2的正整数次幂*2+10
-
typedef
long
long LL;
-
inline int getc() {
-
static
const
int L =
1 <<
15;
-
static
char buf[L] , *S = buf , *T = buf;
-
if (S == T) {
-
T = (S = buf) + fread(buf ,
1 , L ,
stdin);
-
if (S == T)
-
return EOF;
-
}
-
return *S++;
-
}
-
inline int getint() {
-
static
char c;
-
while(!
isdigit(c = getc()) && c !=
'-');
-
bool sign = (c ==
'-');
-
int tmp = sign ?
0 : c -
'0';
-
while(
isdigit(c = getc()))
-
tmp = (tmp <<
1) + (tmp <<
3) + c -
'0';
-
return sign ? -tmp : tmp;
-
}
-
inline char getch() {
-
char c;
-
while((c = getc()) !=
'Q' && c !=
'C');
-
return c;
-
}
-
int M;
//底层的节点数
-
int dl[N] , dr[N];
//节点的左右端点
-
LL sum[N];
//节点的区间和
-
LL add[N];
//节点的区间加上一个数的标记
-
#define l(x) (x<<1) //x的左儿子,利用堆的性质
-
#define r(x) ((x<<1)|1) //x的右儿子,利用堆的性质
-
void pushdown(int x) {
//下传标记
-
if (add[x]&&x<M) {
//如果是叶子节点,显然不用下传标记(别忘了)
-
add[l(x)] += add[x];
-
sum[l(x)] += add[x] * (dr[l(x)] - dl[l(x)] +
1);
-
add[r(x)] += add[x];
-
sum[r(x)] += add[x] * (dr[r(x)] - dl[r(x)] +
1);
-
add[x] =
0;
-
}
-
}
-
int
stack[
20] , top;
//栈
-
void upd(int x) {
//下传x至根节点路径上节点的标记(自上而下,用栈实现)
-
top =
0;
-
int tmp = x;
-
for(; tmp ; tmp >>=
1)
-
stack[++top] = tmp;
-
while(top--)
-
pushdown(
stack[top]);
-
}
-
LL query(int tl , int tr) {
//求和
-
LL res=
0;
-
int insl =
0, insr =
0;
//两侧第一个有用节点
-
for(tl=tl+M
-1,tr=tr+M+
1;tl^tr^
1;tl>>=
1,tr>>=
1) {
-
if (~tl&
1) {
-
if (!insl)
-
upd(insl=tl^
1);
-
res+=sum[tl^
1];
-
}
-
if (tr&
1) {
-
if(!insr)
-
upd(insr=tl^
1)
-
res+=sum[tr^
1];
-
}
-
}
-
return res;
-
}
-
void modify(int tl , int tr , int val) {
//修改
-
int insl =
0, insr =
0;
-
for(tl=tl+M
-1,tr=tr+M+
1;tl^tr^
1;tl>>=
1,tr>>=
1) {
-
if (~tl&
1) {
-
if (!insl)
-
upd(insl=tl^
1);
-
add[tl^
1]+=val;
-
sum[tl^
1]+=(LL)val*(dr[tl^
1]-dl[tl^
1]+
1);
-
}
-
if (tr&
1) {
-
if (!insr)
-
upd(insr=tr^
1);
-
add[tr^
1]+=val;
-
sum[tr^
1]+=(LL)val*(dr[tr^
1]-dl[tr^
1]+
1);
-
}
-
}
-
for(insl=insl>>
1;insl;insl>>=
1)
//一路update
-
sum[insl]=sum[l(insl)]+sum[r(insl)];
-
for(insr=insr>>
1;insr;insr>>=
1)
-
sum[insr]=sum[l(insr)]+sum[r(insr)];
-
-
-
}
-
inline void swap(int &a , int &b) {
-
int tmp = a;
-
a = b;
-
b = tmp;
-
}
-
int main() {
-
//freopen("tt.in" , "r" , stdin);
-
int n , ask;
-
n = getint();
-
ask = getint();
-
int i;
-
for(M =
1 ; M < (n +
2) ; M <<=
1);
-
for(i =
1 ; i <= n ; ++i)
-
sum[M + i] = getint() , dl[M + i] = dr[M + i] = i;
//建树
-
for(i = M -
1; i >=
1 ; --i) {
//预处理节点左右端点
-
sum[i] = sum[l(i)] + sum[r(i)];
-
dl[i] = dl[l(i)];
-
dr[i] = dr[r(i)];
-
}
-
char s;
-
int a , b , x;
-
while(ask--) {
-
s = getch();
-
if (s ==
'Q') {
-
a = getint();
-
b = getint();
-
if (a > b)
-
swap(a , b);
-
printf(
"%lld\n" , query(a , b));
-
}
-
else {
-
a = getint();
-
b = getint();
-
x = getint();
-
if (a > b)
-
swap(a , b);
-
modify(a , b , x);
-
}
-
}
-
return
0;
-
}
可持久化线段树版本?!(来自http://blog.csdn.net/forget311300/article/details/44306265)
-
#include <iostream>
-
#include <cstdio>
-
#include <cstring>
-
#include <cmath>
-
#include <algorithm>
-
#include <vector>
-
#define mp(x,y) make_pair(x,y)
-
-
using
namespace
std;
-
-
const
int N =
100000;
-
const
int inf =
0x3f3f3f3f;
-
-
int a[N +
10];
-
int b[N +
10];
-
int M;
-
int lq, rq;
-
vector<pair<
int,
int> > s[N *
22];
-
-
void add(int id, int cur)
-
{
-
cur += M;
-
int lat =
0;
-
if (s[cur].size())
-
lat = s[cur][s[cur].size() -
1].second;
-
s[cur].push_back(mp(id, ++lat));
-
for (cur >>=
1; cur; cur >>=
1)
-
{
-
int l =
0;
-
if (s[cur <<
1].size())
-
l = s[cur <<
1][s[cur <<
1].size() -
1].second;
-
int r =
0;
-
if (s[cur <<
1 |
1].size())
-
r = s[cur <<
1 |
1][s[cur <<
1 |
1].size() -
1].second;
-
s[cur].push_back(mp(id, l + r));
-
}
-
}
-
-
int Q(int id, int k)
-
{
-
if (id >= M)
return id - M;
-
int l = id <<
1, r = l ^
1;
-
int ll = lower_bound(s[l].begin(), s[l].end(), mp(lq, inf)) - s[l].begin() -
1;
-
int rr = lower_bound(s[l].begin(), s[l].end(), mp(rq, inf)) - s[l].begin() -
1;
-
int kk =
0;
-
if (rr >=
0)kk = s[l][rr].second;
-
if (ll >=
0)kk = s[l][rr].second - s[l][ll].second;
-
if (kk < k)
return Q(r, k - kk);
-
return Q(l, k);
-
}
-
-
int main()
-
{
-
int n, m;
-
while (~
scanf(
"%d%d", &n, &m))
-
{
-
for (
int i =
0; i < n; i++)
-
{
-
scanf(
"%d", a + i);
-
b[i] = a[i];
-
}
-
sort(b, b + n);
-
int nn = unique(b, b + n) - b;
-
for (M =
1; M < nn; M <<=
1);
-
for (
int i =
1; i < M + M; i++)
-
{
-
s[i].clear();
-
//s[i].push_back(mp(0, 0));
-
}
-
for (
int i =
0; i < n; i++)
-
{
-
int id = lower_bound(b, b + nn, a[i]) - b;
-
add(i +
1, id);
-
}
-
while (m--)
-
{
-
int k;
-
scanf(
"%d %d %d", &lq, &rq, &k);
-
lq--;
-
int x = Q(
1, k);
-
printf(
"%d\n", b[x]);
-
}
-
}
-
return
0;
-
}
完全模板?!(来自http://blog.csdn.net/forget311300/article/details/44306265)
-
const
int N =
1e5;
-
-
struct node
-
{
-
int sum, d, v;
-
int l, r;
-
void init()
-
{
-
d =
0;
-
v =
-1;
-
}
-
void cb(node ls, node rs)
-
{
-
sum = ls.sum + rs.sum;
-
l = ls.l, r = rs.r;
-
}
-
int len()
-
{
-
return r - l +
1;
-
}
-
void V(int x)
-
{
-
sum = len() * x;
-
d =
0;
-
v = x;
-
}
-
void D(int x)
-
{
-
sum += len() * x;
-
d += x;
-
}
-
};
-
-
struct tree
-
{
-
int m, h;
-
node g[N <<
2];
-
void init(int n)
-
{
-
for (m = h =
1; m < n +
2; m <<=
1, h++);
-
int i =
0;
-
for (; i <= m; i++)
-
{
-
g[i].init();
-
g[i].sum =
0;
-
}
-
for (; i <= m + n; i++)
-
{
-
g[i].init();
-
scanf(
"%d", &g[i].sum);
-
g[i].l = g[i].r = i - m;
-
}
-
for (; i < m + m; i++)
-
{
-
g[i].init();
-
g[i].sum =
0;
-
g[i].l = g[i].r = i - m;
-
}
-
for (i = m -
1; i >
0; i--)
-
g[i].cb(g[i <<
1], g[i <<
1 |
1]);
-
}
-
void dn(int x)
-
{
-
for (
int i = h -
1; i >
0; i--)
-
{
-
int f = x >> i;
-
if (g[f].v !=
-1)
-
{
-
g[f <<
1].V(g[f].v);
-
g[f <<
1 |
1].V(g[f].v);
-
}
-
if (g[f].d)
-
{
-
g[f <<
1].D(g[f].d);
-
g[f <<
1 |
1].D(g[f].d);
-
}
-
g[f].v =
-1;
-
g[f].d =
0;
-
}
-
}
-
void up(int x)
-
{
-
for (x >>=
1; x; x >>=
1)
-
{
-
if (g[x].v !=
-1)
continue;
-
int d = g[x].d;
-
g[x].d =
0;
-
g[x].cb(g[x <<
1], g[x <<
1 |
1]);
-
g[x].D(d);
-
}
-
}
-
void update(int l, int r, int x, int o)
-
{
-
l += m -
1, r += m +
1;
-
dn(l), dn(r);
-
for (
int s = l, t = r; s ^ t ^
1; s >>=
1, t >>=
1)
-
{
-
if (~s &
1)
-
{
-
if (o)
-
g[s ^
1].V(x);
-
else
-
g[s ^
1].D(x);
-
}
-
if (t &
1)
-
{
-
if (o)
-
g[t ^
1].V(x);
-
else
-
g[t ^
1].D(x);
-
}
-
}
-
up(l), up(r);
-
}
-
int Q(int l, int r)
-
{
-
int ans =
0;
-
l += m -
1, r += m +
1;
-
dn(l), dn(r);
-
for (
int s = l, t = r; s ^ t ^
1; s >>=
1, t >>=
1)
-
{
-
if (~s &
1)ans += g[s ^
1].sum;
-
if (t &
1)ans += g[t ^
1].sum;
-
}
-
return ans;
-
}
-
};
二维情况(来自http://blog.csdn.net/forget311300/article/details/44306265)
-
#include <cstdio>
-
#include <algorithm>
-
#include <cstring>
-
#include <cmath>
-
#include <vector>
-
#include <iostream>
-
-
using
namespace
std;
-
-
const
int W =
1000;
-
-
int m;
-
-
struct tree
-
{
-
int d[W <<
2];
-
void o()
-
{
-
for (
int i =
1; i < m + m; i++)d[i] =
0;
-
}
-
void Xor(int l, int r)
-
{
-
l += m -
1, r += m +
1;
-
for (
int s = l, t = r; s ^ t ^
1; s >>=
1, t >>=
1)
-
{
-
if (~s &
1)d[s ^
1] ^=
1;
-
if (t &
1)d[t ^
1] ^=
1;
-
}
-
}
-
-
} g[W <<
2];
-
-
void chu()
-
{
-
for (
int i =
1; i < m + m; i++)
-
g[i].o();
-
}
-
-
-
void Xor(int lx, int ly, int rx, int ry)
-
{
-
lx += m -
1, rx += m +
1;
-
for (
int s = lx, t = rx; s ^ t ^
1; s >>=
1, t >>=
1)
-
{
-
if (~s &
1)g[s ^
1].Xor(ly, ry);
-
if (t &
1)g[t ^
1].Xor(ly, ry);
-
}
-
}
-
-
int Q(int x, int y)
-
{
-
int ans =
0;
-
for (
int xx = x + m; xx; xx >>=
1)
-
{
-
for (
int yy = y + m; yy; yy >>=
1)
-
{
-
ans ^= g[xx].d[yy];
-
}
-
}
-
return ans;
-
}
-
-
int main()
-
{
-
int T;
-
cin >> T;
-
int fl =
0;
-
while (T--)
-
{
-
if (fl)
-
{
-
printf(
"\n");
-
}
-
fl =
1;
-
int N, M;
-
cin >> N >> M;
-
for (m =
1; m < N +
2; m <<=
1);
-
chu();
-
while (M--)
-
{
-
char o[
4];
-
scanf(
"%s", o);
-
if (*o ==
'Q')
-
{
-
int x, y;
-
scanf(
"%d%d", &x, &y);
-
printf(
"%d\n", Q(x, y));
-
}
-
else
-
{
-
int lx, ly, rx, ry;
-
scanf(
"%d%d%d%d", &lx, &ly, &rx, &ry);
-
Xor(lx, ly, rx, ry);
-
}
-
}
-
}
-
return
0;
-
}
非递归扫描线+离散化?!(来自http://blog.csdn.net/forget311300/article/details/44306265)
-
#include <algorithm>
-
#include <iostream>
-
#include <cstdio>
-
#include <cstring>
-
#include <vector>
-
#include <cmath>
-
-
using
namespace
std;
-
-
const
int N =
111;
-
-
int n;
-
vector<
double> y;
-
-
struct node
-
{
-
double s;
-
int c;
-
int l, r;
-
void chu(double ss, int cc, int ll, int rr)
-
{
-
s = ss;
-
c = cc;
-
l = ll, r = rr;
-
}
-
double len()
-
{
-
return y[r] - y[l -
1];
-
}
-
} g[N <<
4];
-
int M;
-
-
void init(int n)
-
{
-
for (M =
1; M < n +
2; M <<=
1);
-
g[M].chu(
0,
0,
1,
1);
-
for (
int i =
1; i <= n; i++)
-
g[i + M].chu(
0,
0, i, i);
-
for (
int i = n +
1; i < M; i++)
-
g[i + M].chu(
0,
0, n, n);
-
for (
int i = M -
1; i >
0; i--)
-
g[i].chu(
0,
0, g[i <<
1].l, g[i <<
1 |
1].r);
-
}
-
-
struct line
-
{
-
double x, yl, yr;
-
int d;
-
line() {}
-
line(
double x,
double yl,
double yr,
int dd): x(x), yl(yl), yr(yr), d(dd) {}
-
bool
operator < (
const line &cc)
const
-
{
-
return x < cc.x || (x == cc.x && d > cc.d);
-
}
-
};
-
-
vector<line>L;
-
-
void one(int x)
-
{
-
if (x >= M)
-
{
-
g[x].s = g[x].c ? g[x].len() :
0;
-
return;
-
}
-
g[x].s = g[x].c ? g[x].len() : g[x <<
1].s + g[x <<
1 |
1].s;
-
}
-
-
void up(int x)
-
{
-
for (; x; x >>=
1)
-
one(x);
-
}
-
-
void add(int l, int r, int d)
-
{
-
if (l > r)
return;
-
l += M -
1, r += M +
1;
-
for (
int s = l, t = r; s ^ t ^
1; s >>=
1, t >>=
1)
-
{
-
if (~s &
1)
-
{
-
g[s ^
1].c += d;
-
one(s ^
1);
-
}
-
if (t &
1)
-
{
-
g[t ^
1].c += d;
-
one(t ^
1);
-
}
-
}
-
up(l);
-
up(r);
-
}
-
-
double sol()
-
{
-
y.clear();
-
L.clear();
-
for (
int i =
0; i < n; i++)
-
{
-
double lx, ly, rx, ry;
-
scanf(
"%lf %lf %lf %lf", &lx, &ly, &rx, &ry);
-
L.push_back(line(lx, ly, ry,
1));
-
L.push_back(line(rx, ly, ry,
-1));
-
y.push_back(ly);
-
y.push_back(ry);
-
}
-
sort(y.begin(), y.end());
-
y.erase(unique(y.begin(), y.end()), y.end());
-
init(y.size());
-
sort(L.begin(), L.end());
-
n = L.size() -
1;
-
double ans =
0;
-
for (
int i =
0; i < n; i++)
-
{
-
int l = upper_bound(y.begin(), y.end(), L[i].yl +
1e-8) - y.begin();
-
int r = upper_bound(y.begin(), y.end(), L[i].yr +
1e-8) - y.begin() -
1;
-
add(l, r, L[i].d);
-
ans += g[
1].s * (L[i +
1].x - L[i].x);
-
}
-
return ans;
-
}
-
-
int main()
-
{
-
int ca =
1;
-
while (
cin >> n && n)
-
{
-
printf(
"Test case #%d\nTotal explored area: %.2f\n\n", ca++, sol());
-
}
-
return
0;
-
}