题意
求以 1 1 1为根的 n n n个点的有根树上,满足下面条件的有序点对 ( x , y ) (x,y) (x,y)数目(点 i i i的权值为 v i v_i vi):
- 一个点不为另一个点的祖先
- v x + v y = v l c a ( x , y ) v_x+v_y=v_{lca(x,y)} vx+vy=vlca(x,y)
- x x x到 y y y的路径长度小于等于给定的值 k k k
n , k ≤ 1 e 5 , 0 ≤ v i ≤ n n,k\le 1e5, 0\le v_i\le n n,k≤1e5,0≤vi≤n
解题思路:
树上启发式合并,先统计轻儿子子树中的答案,对轻儿子处理完之后清除轻儿子的痕迹,然后统计重儿子子树中的答案。统计完之后保留重儿子的痕迹,依次遍历轻儿子的结点去询问答案,处理完一颗轻子树之后把这颗子树合并进来。
具体的,对每一个权值 x x x开一颗线段树,存放权值为 x x x,深度在 [ l , r ] [l,r] [l,r]范围的点有多少个,并以此为根据来询问。
自己踩的坑:
be up to k理解为必须为k,人傻了
算出来的深度范围如果上限超过n,应该以n为上限,而不是直接不询问(上一个坑的后遗症)
所以英语很重要
#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
using namespace std;
const int maxn = 1e5 + 50;
int n, k;
vector<int> g[maxn];
int val[maxn], son[maxn], dep[maxn], sz[maxn];
int T[maxn], lc[maxn*200], rc[maxn*200], tot = 0, sum[maxn*200];
void update(int &rt, int l, int r, int pos, int x){
if(!rt) rt = ++tot;
sum[rt] += x;
if(l == r) return;
if(pos <= mid) update(lc[rt], l, mid, pos, x);
else update(rc[rt], mid+1, r, pos, x);
}
int qry(int rt, int l, int r, int L, int R){
if(!rt) return 0;
if(L <= l && r <= R) return sum[rt];
int res = 0;
if(L <= mid) res += qry(lc[rt], l, mid, L, R);
if(R > mid) res += qry(rc[rt], mid+1, r, L, R);
return res;
}
ll ans = 0;
void init(){
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; ++i) scanf("%d", &val[i]);
for(int i = 2; i <= n; ++i){
int u; scanf("%d", &u); g[u].push_back(i);}
}
void dfs1(int u){
sz[u] = 1;
for(int i = 0; i < g[u].size(); ++i){
int v = g[u][i];
dep[v] = dep[u]+1;
dfs1(v); sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}return;
}
void del(int u){
update(T[val[u]], 1, n, dep[u], -1);
for(int i = 0; i < g[u].size(); ++i) del(g[u][i]);
}
void qry(int u, int td, int tv){
int d = k+2*td-dep[u];
d = min(d, n);//Attention!!
int t = 2*tv-val[u];
if(d >= 1 && t >= 0 && t <= n) ans = ans + 2LL*qry(T[t], 1, n, 1, d);
for(int i = 0; i < g[u].size(); ++i) qry(g[u][i], td, tv);
}
void add(int u){
update(T[val[u]], 1, n, dep[u], 1);
for(int i = 0; i < g[u].size(); ++i) add(g[u][i]);
}
void dfs2(int u){
for(int i = 0; i < g[u].size(); ++i){
int v = g[u][i];
if(v == son[u]) continue;
dfs2(v);del(v);
}
if(son[u]) dfs2(son[u]);
for(int i = 0; i < g[u].size(); ++i){
int v = g[u][i]; if(v == son[u]) continue;
qry(v, dep[u], val[u]); add(v);
}
update(T[val[u]], 1, n, dep[u], 1);
}
int main()
{
init();
dep[1] = 1;
dfs1(1);dfs2(1);
cout<<ans<<endl;
}