树上问题
题目链接:ybt金牌导航5-2-1 / luogu P4178
题目大意
有一个树,然后边有距离。
要你统计有多少个点对,使得他们树上距离不超过一个值。
思路
这道题我们考虑点分治。
假设你任意确定了一个根节点,然后从根节点开始,现在到了一个点。
那你构成的点对有三种情况。
分别是两个点都在它的子树中,一个点是它,还有经过它一个在子树中三种。
那两个都在它的子树中就可以递归求解。
一个点是它,那你可以在外层里们找个距离比它小的个数。
那一个是子树中的我们可以把到子树中点的距离和其他点的距离求出来,然后用双指针来求。
双指针就是一个是在最小,一个是在最大,然后如果加起来不大于要求,就统计确定这个最小时能有方案,把最小的变大,否则就是要把最大的变小。
其实你会发现第二第三个可以合并在一起。
然后就变成直接找出现在树中所有点到它的距离,然后用双指针跑。
但是你会发现它也包含了第一种,而且这个好像还不能用来算第一种。
那怎么办呢?我们考虑容斥,在用同样的方法,算你现在找的子树,然后把能算到的方案给去掉。(你可以以现在找的子树的根为起点,然后弄个初始距离为根节点到起点的距离)
然后你想,你这么搞,就要尽可能得减少树的深度。
那这个时候我们就不能乱选根节点,对于一个子树,我们应该选重心为根节点。
重心就是使得它的最大的子树大小最小的点,这样就可以是深度减小。
至于求重心,就以任意一个点为根一个 dfs 过去,它的子树是它为根时的一个子树,还有一个它为根时的子树就是这个树除了它和它的子树的点构成的子树。
然后就是这样了。
(我也不知道我的代码为什么时间那么长,可能是实现方法比较垃圾,但是能过)
代码
#include<queue>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
struct node {
int x, to, nxt;
}e[100001];
struct Tree {
int s, f;
}t[50001];
int n, x, y, z, le[50001], KK;
bool in[50001], inn[50001];
int d[50001], dis[50001];
int k, fa[50001];
ll ans;
bool cmp(int x, int y) {
return x < y;
}
void add(int x, int y, int z) {
e[++KK] = (node){
z, y, le[x]}; le[x] = KK;
}
void dfs_root(int now, int &root, int &minn, int size) {
//找树的重心
int maxn = 0;
t[now].s = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (!inn[e[i].to] && !in[e[i].to]) {
inn[e[i].to] = 1;
t[e[i].to].f = now;
dfs_root(e[i].to, root, minn, size);
t[now].s += t[e[i].to].s;
maxn = max(maxn, t[e[i].to].s);
}
maxn = max(maxn, size - t[now].s);
if (maxn < minn) {
minn = maxn;
root = now;
}
}
int get_size(int now, int father) {
//得到当前树的大小
int re = 1;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && !in[e[i].to]) {
re += get_size(e[i].to, now);
}
return re;
}
int find_root(int now) {
//寻找重心
int re = 1e9, tmp = 0;
memset(inn, 0, sizeof(inn));
inn[now] = 1;
dfs_root(now, tmp, re, get_size(now, 0));
return tmp;
}
void get_d(int now, int father, int dis) {
//得出树上每个点到重心的距离
d[++d[0]] = dis;
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father && !in[e[i].to])
get_d(e[i].to, now, dis + e[i].x);
}
ll count(int rt, int cs) {
memset(d, 0, sizeof(d));
get_d(rt, 0, cs);
ll re = 0;
sort(d + 1, d + d[0] + 1, cmp);//排序每个点到它的距离后双指针求合法对数
int l = 1, r = d[0];
while (l < r) {
if (d[l] + d[r] <= k) {
re += 1ll * (r - l);
l++;
}
else r--;
}
return re;
}
void dfs(int now) {
int root = find_root(now);//找到重心
in[root] = 1;
ans += count(root, 0);//算第二第三种情况
for (int i = le[root]; i; i = e[i].nxt)
if (!in[e[i].to])
ans -= count(e[i].to, e[i].x);//容斥,去掉第一种情况
for (int i = le[root]; i; i = e[i].nxt)
if (!in[e[i].to]) dfs(e[i].to);//算第一种情况
}
int main() {
scanf("%d", &n);
for (int i = 1; i < n; i++) {
scanf("%d %d %d", &x, &y, &z);
add(x, y, z);
add(y, x, z);
}
scanf("%d", &k);
dfs(1);
printf("%lld", ans);
return 0;
}