(对二叉树有感觉的话,思路还是出得蛮快的)
题意:给定一个
思路:如果每次处理询问都要对树上节点遍历的话,肯定超时,于是想到预处理,由于这个树是很标准的二叉树,那么我们可以试试对于每个树节点,求出每个子节点距离它的值,并保存起来(就像线段树中的push_up操作一样)。然后对于每一个询问
时间复杂度:预处理+查询:
空间复杂度:预处理的结果:
(开始以为可能会MLE,然后发现题目中内存比较大。好久没写代码,写的时候调了半天(居然以为break语句能跳出它上面的if-else语句…)。)
#include <cstdio>
#include <algorithm>
#include <vector>
#define LL long long
using namespace std;
const int maxn = 1000050;
const int inf = 0x3f3f3f3f;
int len[maxn][2];
vector<int> vec[maxn];
vector<LL> sum[maxn];
void push_up(int p, int n) {
// merge sort
int l = p<<1, r = p<<1|1, i = 0, j = 0;
int l_len = len[p][0], r_len = r<=n? len[p][1]:0;
while(i < (int)vec[l].size()-1 && j < (int)vec[r].size()-1) {
if(vec[l][i]+l_len < vec[r][j]+r_len)
vec[p].push_back(vec[l][i++]+l_len);
else
vec[p].push_back(vec[r][j++]+r_len);
if(vec[p].back() >= inf) break ;
}
while(i < (int)vec[l].size()-1 && vec[p].back() < inf)
vec[p].push_back(vec[l][i++]+l_len);
while(j < (int)vec[r].size()-1 && vec[p].back() < inf)
vec[p].push_back(vec[r][j++]+r_len);
if(vec[p].back() >= inf) vec[p].pop_back();
// calculate prefix sum
sum[p].push_back(vec[p][0]);
for(int k=1; k<(int)vec[p].size(); k++) {
sum[p].push_back(vec[p][k]);
sum[p][k] += sum[p][k-1];
}
return ;
}
void build(int p, int n) {
int lson = p<<1, rson = p<<1|1;
if(lson <= n) build(lson, n);
if(rson <= n) build(rson, n);
vec[p].clear();
vec[p].push_back(0);
if(lson <= n || rson <= n)
push_up(p, n);
vec[p].push_back(inf);
return ;
}
LL get_sum(int n, int id, int h) {
LL ret = h;
int pos = upper_bound(vec[id].begin(), vec[id].end(), h) - vec[id].begin() - 1;
if(pos >= 1) ret += (LL)h*pos - sum[id][pos];
//printf("ret1 : %I64d\n",ret);
while(id != 1) {
h -= len[id/2][id&1];
if(h > 0) ret += h;
else break ;
//printf("ret2 : %I64d\n",ret);
int id_2 = id ^ 1, branch = len[id/2][id_2&1];
if(id_2 <= n && h-branch > 0) {
ret += h-branch;
int pos = upper_bound(vec[id_2].begin(), vec[id_2].end(), h-branch) - vec[id_2].begin() - 1;
if(pos >= 1) ret += (LL)(h-branch)*pos - sum[id_2][pos];
//printf("ret3 : %I64d\n",ret);
}
id /= 2;
}
return ret;
}
int main() {
int n, m;
scanf("%d%d",&n,&m);
for(int i=1; i<n; i++) {
int t, st = (i+1)/2;
scanf("%d",&t);
len[st][(i+1)&1] = t;
}
build(1, n);
while(m --) {
int id, h;
scanf("%d%d",&id,&h);
LL ans = get_sum(n, id, h);
printf("%I64d\n",ans);
}
return 0;
}