树状数组用法比较多,但是需要想清楚来龙去脉。
区间求和和逆序数是比较典型的2类应用!
下面的两个应用采用两种办法,应用一比较巧妙,应用二是常规模板
应用一:区间求和
题目链接:https://leetcode-cn.com/problems/range-sum-query-mutable/
动态更改数组内容
并在O(logn)时间内更改动态数组,返回区间[l, r]的数组的和。
初始化一颗类满二叉树
数组为[1, 7, 4, 2 , 3, 6].
树状数组,对于数组实现,如果索引 i 处的元素不是一个叶节点,那么其左子节点和右子节点分别存储在索引为 2 * i 和 2 * i + 1 的元素处。
所有值的点都是叶子节点,因此建二叉树,放在数组的最后n个点,然后倒序更新根节点。
2 * i 和 2 * i + 1 的和就是 i 点的值。
def __init__(self, nums: List[int]):
n = len(nums)
self.num = n
self.tree = [0 for i in range(2*n)]
for i in range(n):
self.tree[i+n] = nums[i]
for i in range(n-1,-1,-1):
self.tree[i] = self.tree[i * 2] + self.tree[i * 2 + 1]
更新某个节点
就要更新该节点所有的根节点,所有的根遍历一遍也就
的深度。
给一个pos,那么根为pos // 2,继续向上//2一步步更新。
def update(self, i: int, val: int) -> None:
i += self.num
self.tree[i] = val
while(i > 0):
left = i
right = i
if(i % 2 == 0):
right = i + 1
else:
left = i - 1
self.tree[i // 2] = self.tree[left] + self.tree[right]
i = i // 2
返回区间和
比如求[3, 4]区间和,先思考下下标应该如何移动。
左右界是奇是偶,,,,,,可以归纳一下。
左界是奇,右界是偶[7, 10]
左为奇数,根中包含着左-1点,因此不能加根,需要加7节点,然后往右界靠近,left ++
右为偶数,根中包含着右+1点,因此不能加根,需要加10节点,然后往左界靠近,right –
左界是奇,右界是奇[7, 9]
左为奇数,根中包含着左-1点,因此不能加根,需要加7节点,然后往右界靠近,left ++
8 + 9 = 右界根节点,所以可不加9节点,等着加根,right –
左界是偶,右界是偶[8, 10]
8 + 9 = 左界根节点,所以可不加8节点,等着加根,left++
右为偶数,根中包含着右+1点,因此不能加根,需要加10节点,然后往左界靠近,right –
左界是偶,右界是奇[8, 9]
8 + 9 = 左界根节点,所以可不加8节点,等着加根,left++
8 + 9 = 右界根节点,所以可不加9节点,等着加根,right –
总结后可发现左右对奇偶的情况讨论。
def sumRange(self, i: int, j: int) -> int:
i += self.num
j += self.num
ans = 0
while (i <= j):
if ((i % 2) == 1):
ans += self.tree[i]
i += 1
if ((j % 2) == 0):
ans += self.tree[j]
j -= 1
i = i // 2
j = j // 2
return ans
应用二:逆序数
题目链接:https://www.nowcoder.com/practice/96bd6684e04a44eb80e6a68efc0ec6c5?tpId=13&tqId=11188&tPage=1&rp=1&ru=/ta/coding-interviews&qru=/ta/coding-interviews/question-ranking
树状数组模板
和上面完全不是一个套路
def lowbit(self, x:int) -> int:
return x&(-x)
def update(self, i: int, val: int) -> None:
x = i
while(x <= self.n):
self.tree[x] += val
x += self.lowbit(x)
def query(self, x:int) -> int:
ans = 0
while(x > 0):
ans += self.tree[x]
x -= self.lowbit(x)
return ans
常用的三个函数:lowbit,query, update
结构图:
lowbit
如图可以知道
C[1]=A[1];
C[2]=A[1]+A[2];
C[3]=A[3];
C[4]=A[1]+A[2]+A[3]+A[4];
C[5]=A[5];
C[6]=A[5]+A[6];
C[7]=A[7];
C[8]=A[1]+A[2]+A[3]+A[4]+A[5]+A[6]+A[7]+A[8];
将C[]数组的结点序号转化为二进制
1=(001) | C[1]=A[1] |
2=(010) | C[2]=A[1]+A[2]; |
3=(011) | C[3]=A[3]; |
4=(100) | C[4]=A[1]+A[2]+A[3]+A[4]; |
5=(101) | C[5]=A[5]; |
6=(110) | C[6]=A[5]+A[6]; |
7=(111) | C[7]=A[7]; |
8=(1000) | C[8]=A[1]+A[2]+A[3]+A[4]+A[5]+A[6]+A[7]+A[8]; |
k为i的二进制中从最低位到高位连续零的长度,例如i=8时,k=3
query
i=7;
C[4]=A[1]+A[2]+A[3]+A[4];
C[6]=A[5]+A[6];
C[7]=A[7];
可以推出: sum[7]=C[4]+C[6]+C[7];
序号写为二进制: sum[(111)]=C[(100)]+C[(110)]+C[(111)];
i=5
C[4]=A[1]+A[2]+A[3]+A[4];
C[5]=A[5];
可以推出: sum[5]=C[4]+C[5];
序号写为二进制: sum[(101)]=C[(100)]+C[(101)];
update
当我们修改A[]数组中的某一个值时 应当如何更新C[]数组呢?更新过程是查询过程的逆过程
更新A[2]时 需要向上更新C[2], C[4], C[8]
C[2] | C[4] | C[8] |
C[(010)] | C[(100)] | C[(1000)] |
2 | C[2] += A[2] | |
lowbit(2)=010 | 2+lowbit(2)=4 | C[4] += A[2] |
lowbit(4)=100 | 4+lowbit(4)=8 | C[8] += A[2] |
数组中的逆序对:剑指offer 原题,牛客可交
class NumArray:
def __init__(self, num:int):
self.n = num
self.tree = [0 for i in range(self.n+1)]
def lowbit(self, x:int) -> int:
return x&(-x)
def update(self, i: int, val: int) -> None:
i += 1
while(i <= self.n):
self.tree[i] += 1
i += self.lowbit(i)
def query(self, x:int) -> int:
ans = 0
x += 1
if(x == 0):
return 0
while(x > 0):
ans += self.tree[x]
x -= self.lowbit(x)
return ans
def sumRange(self, i: int, j: int) -> int:
return self.query(j+1) - self.query(i)
class Solution:
def get_new(self, data):
pos, n = 0, len(data)
new_data, temp = [0], []
for i in range(n):
temp.append([data[i], i])
temp = sorted(temp,key=lambda x: (x[0], x[1]))
temp[0].append(0)
for i in range(1,n):
if(temp[i][0] == temp[i-1][0]):
temp[i].append(pos)
else:
pos += 1
temp[i].append(pos)
temp = sorted(temp,key=lambda x: (x[1], x[0], x[2]))
for i in range(n):
new_data.append(temp[i][2])
return new_data, n, temp
def InversePairs(self, data):
answ = 0
if(len(data) == 0):
return answ
new_data, n, cur = self.get_new(data)
narr = NumArray(n)
for i in range(1,1+n):
narr.update(new_data[i], 1)
answ += ((i - 1 - narr.query(new_data[i] - 1))%1000000007)
return answ
https://leetcode-cn.com/problems/count-of-smaller-numbers-after-self/
这也是一个逆序对应用的题目。
class NumArray:
def __init__(self, num:int):
self.n = num
self.tree = [0 for i in range(self.n+1)]
def lowbit(self, x:int) -> int:
return x&(-x)
def update(self, i: int, val: int) -> None:
i += 1
while(i <= self.n):
self.tree[i] += 1
i += self.lowbit(i)
def query(self, x:int) -> int:
ans = 0
x += 1
if(x == 0):
return 0
while(x > 0):
ans += self.tree[x]
x -= self.lowbit(x)
return ans
def sumRange(self, i: int, j: int) -> int:
return self.query(j+1) - self.query(i)
class Solution:
def get_new(self, data):
pos, n = 0, len(data)
new_data, temp = [0], []
for i in range(n):
temp.append([data[i], i])
temp = sorted(temp,key=lambda x: (x[0], x[1]))
temp[0].append(0)
for i in range(1,n):
if(temp[i][0] == temp[i-1][0]):
temp[i].append(pos)
else:
pos += 1
temp[i].append(pos)
temp = sorted(temp,key=lambda x: (x[1], x[0], x[2]))
for i in range(n):
new_data.append(temp[i][2])
return new_data, n, temp
def countSmaller(self, nums: List[int]) -> List[int]:
answ = []
if(len(nums) == 0):
return answ
nums = nums[::-1]
new_data, n, cur = self.get_new(nums)
narr = NumArray(n)
for i in range(1,1+n):
narr.update(new_data[i], 1)
answ.append(narr.query(new_data[i]-1))
answ = answ[::-1]
return answ
,
,
,
,
,
,
,
再用应用二的思路做一下第一题
,
,
,
,
,
,
,
,
class NumArray:
def __init__(self, nums: List[int]):
self.n = len(nums)
self.array = nums
self.tree = [0 for i in range(self.n+1)]
self.sum = [0 for i in range(self.n+1)]
for i in range(1, len(nums)+1):
self.sum[i] = self.sum[i-1] + nums[i-1]
del nums
def lowbit(self, x:int) -> int:
return x&(-x)
def update(self, i: int, val: int) -> None: # 每次更新的是与原来值的差值
x = i + 1
while(x <= self.n):
self.tree[x] += (val - self.array[i])
x += self.lowbit(x)
self.array[i] = val
def query(self, x:int) -> int:
ans = 0
while(x > 0):
ans += self.tree[x]
x -= self.lowbit(x)
return ans
def sumRange(self, i: int, j: int) -> int:
# 用原始区间和加上后来更改过的差值求和
return self.query(j+1) - self.query(i) + self.sum[j+1] - self.sum[i]