P3250 [HNOI2016]网络 (整体二分 + 树上差分)

题目链接: P3250 [HNOI2016]网络

大致题意

给定一棵有 n n n个节点的树, 有 m m m次如下操作:

0 a b c 表示在 ( a , b ) (a, b) (a,b)的最短路径上增加一条重要度 c c c的边.

1 t 表示删除第 t t t次操作所增加的边

2 x 表示节点 x x x出现故障. 此时需要回答, 所有不经过 x x x节点的边中最大的重要度.

解题思路

整体二分

我们考虑到二分答案, 假设当前二分值为 m i d mid mid, 我们把所有 a l l all all ≥ m i d \ge mid mid的边加入, 判断对于 x x x节点而言, 此时经过的边数 n u m num num a l l all all的关系.

如果 n u m = = a l l num == all num==all 此时答案出现在 [ l , m i d − 1 ] [l, mid - 1] [l,mid1], 否则答案在 [ m i d , r ] [mid, r] [mid,r].


我们考虑如何统计经过某个节点的所有边:
​ 比较直观的方法是, 我们可以通过树链剖分来对于 ( a , b ) (a, b) (a,b)路径进行修改. 此时修改复杂度为 O ( l o g 2 n ) O(log^2n) O(log2n).
​ 当然我们也可以通过树上差分的思路来维护, 查询时, 经过 x x x节点的总边数应当是 x x x所在子树的区间和. 此时修改复杂度为 O ( l o g n ) O(logn) O(logn).


于是我们可以通过整体二分 + 树上差分的思路, 做到 O ( m l o g 2 n ) O(mlog^2n) O(mlog2n)的复杂度.

AC代码

#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
using namespace std;
typedef long long ll;
const int N = 1E5 + 10, M = 2E5 + 10, INF = 0x3f3f3f3f;
int n, m;
/* 离散化模版 */
vector<int> v(1, -1); 
int find(int x) {
    
     return lower_bound(v.begin(), v.end(), x) - v.begin(); }
void discrete() {
    
     sort(v.begin(), v.end()); v.erase(unique(v.begin(), v.end()), v.end()); }

vector<int> edge[N];
/* 树链剖分求lca模板 */
int p[N], dep[N], sz[N], son[N];
void dfs1(int x = 1, int fa = 0) {
    
    
	p[x] = fa, dep[x] = dep[fa] + 1, sz[x] = 1; // son[x] = 0;
	for (auto& to : edge[x]) {
    
    
		if (to == fa) continue;
		dfs1(to, x);
		sz[x] += sz[to];
		if (sz[to] > sz[son[x]]) son[x] = to;
	}
}
int id[N], top[N], ind;
void dfs2(int x = 1, int tp = 1) {
    
    
	id[x] = ++ind, top[x] = tp;

	if (!son[x]) return;
	dfs2(son[x], tp);

	for (auto& to : edge[x]) {
    
    
		if (to == p[x] or to == son[x]) continue;
		dfs2(to, to);
	}
}
int lca(int a, int b) {
    
    
	while (top[a] != top[b]) {
    
    
		if (dep[top[a]] < dep[top[b]]) swap(a, b);
		a = p[top[a]];
	}
	return id[a] < id[b] ? a : b;
}


/* 树状数组模版 */
struct BIT {
    
    
	int t[N];
	static int lowbit(int x) {
    
     return x & -x; }
	void add(int x, int c) {
    
     for (int i = x; i <= n; i += lowbit(i)) t[i] += c; }
	void modify(int a, int b, int c) {
    
    
		int lca = ::lca(a, b), plca = p[lca];
		add(id[a], c), add(id[b], c);
		add(id[lca], -c);
		if (id[plca]) add(id[plca], -c);
	}
	int ask(int x) {
    
    
		int res = 0;
		for (int i = x; i; i -= lowbit(i)) res += t[i];
		return res;
	}
	int ask(int l, int r) {
    
     return ask(r) - ask(l - 1); }
}bit;


struct operation {
    
    
	int tp, a, b, v, id;
	//  0   a  b  v  NULL
	//  1   a  b  v  NULL
	//  2   x        id
}; vector<operation> area;
int res[M];
void fact(int l, int r, vector<operation>& q) {
    
    
	if (q.empty()) return;
	if (l == r) {
    
    
		for (auto& op : q) if (op.id) res[op.id] = v[l];
		return;
	}

	int mid = l + r >> 1;
	vector<operation> ql, qr;
	int all = 0;
	for (auto& op : q) {
    
    
		if (op.tp == 2) {
    
    
			int cou = bit.ask(id[op.a], id[op.a] + sz[op.a] - 1);
			if (cou == all) ql.push_back(op);
			else qr.push_back(op);
		}
		else {
    
    
			if (op.v > mid) {
    
    
				int flag = !op.tp ? 1 : -1; all += flag;
				bit.modify(op.a, op.b, flag);
				qr.push_back(op);
			}
			else ql.push_back(op);
		}
	}

	for (auto& op : qr) {
    
    
		if (op.tp != 2) {
    
    
			int flag = !op.tp ? 1 : -1; all += flag;
			all -= flag;
			bit.modify(op.a, op.b, -flag);
		}
	}

	fact(l, mid, ql), fact(mid + 1, r, qr);
}
int main()
{
    
    
	cin >> n >> m;
	rep(i, n - 1) {
    
    
		int a, b; scanf("%d %d", &a, &b);
		edge[a].push_back(b), edge[b].push_back(a);
	}
	dfs1(), dfs2();

	rep(i, m) {
    
    
		int tp; scanf("%d", &tp);
		if (!tp) {
    
    
			int a, b, val; scanf("%d %d %d", &a, &b, &val);
			area.push_back({
    
     0, a, b, val, NULL });
			v.push_back(val);
			res[i] = INF;
		}
		else if (tp == 1) {
    
    
			int x; scanf("%d", &x);
			area.push_back(area[x - 1]);
			area.back().tp = 1;
			res[i] = INF;
		}
		else {
    
    
			int x; scanf("%d", &x);
			area.push_back({
    
     2, x, NULL, NULL, i });
		}
	}
	discrete();

	for (auto& op : area) if (op.tp != 2) op.v = find(op.v);

	fact(0, v.size() - 1, area);

	rep(i, m) if (res[i] != INF) printf("%d\n", res[i]);

	return 0;
}

END

猜你喜欢

转载自blog.csdn.net/weixin_45799835/article/details/121480102