1. 问题描述:
传说中的暗之连锁被人们称为 Dark。Dark 是人类内心的黑暗的产物,古今中外的勇者们都试图打倒它。经过研究,你发现 Dark 呈现无向图的结构,图中有 N 个节点和两类边,一类边被称为主要边,而另一类被称为附加边。Dark 有 N–1 条主要边,并且 Dark 的任意两个节点之间都存在一条只由主要边构成的路径。另外,Dark 还有 M 条附加边。你的任务是把 Dark 斩为不连通的两部分。一开始 Dark 的附加边都处于无敌状态,你只能选择一条主要边切断。一旦你切断了一条主要边,Dark 就会进入防御模式,主要边会变为无敌的而附加边可以被切断。但是你的能力只能再切断 Dark 的一条附加边。现在你想要知道,一共有多少种方案可以击败 Dark。注意,就算你第一步切断主要边之后就已经把 Dark 斩为两截,你也需要切断一条附加边才算击败了 Dark。
输入格式
第一行包含两个整数 N 和 M。之后 N–1 行,每行包括两个整数 A 和 B,表示 A 和 B 之间有一条主要边。之后 M 行以同样的格式给出附加边。
输出格式
输出一个整数表示答案。
数据范围
N ≤ 100000,M ≤ 200000,数据保证答案不超过2 ^ 31−1
输入样例:
4 1
1 2
2 3
1 4
3 4
输出样例:
3
来源:https://www.acwing.com/problem/content/description/354/
2. 思路分析:
分析题目可以知道我们需要先砍一条树边,然后再砍一条非树边,我们可以先画画图观察一下有什么规律,可以发现当我们加入一条非树边之后那么在树中就会形成一个环,对于环上的每一条树边来说,如果我们希望砍完当前的树边之后图是不连通的那么就需要砍掉当前对应的非树边的数量,我们可以给树中的每一条边都加上一个权重,表示砍完当前的树边之后还需要砍多少条非树边才可以使得图是不连通的,当我们再加入一条非树边之后那么每条边的权重也会相应增加,所以可以发现我们只需要统计一下每条树边,计算删除当前的树边之后还需要删除多少条非树边才可以使得图不联通,分为三种情况(c为边的权重,m为非树边的数量):
- c = 0,+m,说明删除掉任意一条非树边都可以
- c = 1,+1,说明需要删除掉对应的那条非树边
- c > 1,+0,说明需要删除至少两条非树边才可以使得图是不连通的,而题目中只能够删除一条所以不满足要求,+0
所以问题就转化为如何快速给树中的每一条边加上一个权重,其实这个过程类似于一维数组中给数组中的某一段加上一个数字的过程,这里其实是给树中的每一条边加上一个权重,也可以使用差分的思想(也即树上差分),如果在树上x-->y的路径中加上一个数字,那么具体的实现如下:
- d(x) += c
- d(y) += c
- d(p) -= 2c
其中p为x和y的最近公共祖先,c为边的权重,d(i)表示节点i的权重,我们相当于是把边的权重累加到树中的节点上,对数组d中的x,y,p位置操作只会影响x,y,p所在路径的节点,对于其余节点是没有什么影响的,有了d数组之后那么我们枚举树中的每一条边,计算删除当前的树边之后还需要删除多少条非树边,根据上面的三种情况分别累加方案数目即可。将边的权重累加到节点上的一个好处是方便后面dfs枚举每一条树边(dfs枚举邻接点的过程其实是枚举以当前根节点下所有连接的边的过程):
3. 代码如下:
import collections
from typing import List
class Solution:
# 记录删除每一条非树边的方案数目
ans = 0
def bfs(self, fa: List[List[int]], depth: List[int], g: List[List[int]]):
# 这里选择1号点作为根节点
depth[0], depth[1] = 0, 1
# 将1号节点加入到队列中
q = collections.deque([1])
while q:
p = q.popleft()
for next in g[p]:
if depth[next] > depth[p] + 1:
depth[next] = depth[p] + 1
q.append(next)
j = next
fa[j][0] = p
for k in range(1, 17):
fa[j][k] = fa[fa[j][k - 1]][k - 1]
# 求解a和b的最近公共祖先
def lca(self, a: int, b: int, fa: List[List[int]], depth: List[int]):
if depth[a] < depth[b]:
a, b = b, a
for k in range(16, -1, -1):
if depth[fa[a][k]] >= depth[b]:
a = fa[a][k]
if a == b: return a
for k in range(16, -1, -1):
if fa[a][k] != fa[b][k]:
a = fa[a][k]
b = fa[b][k]
return fa[a][0]
# dfs方法返回子树的节点之和, 可以发现在dfs过程中其实枚举删除的是节点往上的那条边, m为非树边的数目
def dfs(self, m: int, u: int, fa: int, d: List[int], g: List[List[int]]):
res = d[u]
for next in g[u]:
# 因为是无向边所以需要先判断一下是否是父节点, 如果是父节点那么跳过, 当不是父节点的时候才继续递归
if next != fa:
t = self.dfs(m, next, u, d, g)
# dfs返回的是当前节点下面子节点对应边的所有权重, 枚举的是删除连接当前节点的上一条边的情况
if t == 0:
self.ans += m
elif t == 1:
self.ans += 1
res += t
return res
def process(self):
# n个点, m条非树边
n, m = map(int, input().split())
g = [list() for i in range(n + 1)]
for i in range(n - 1):
a, b = map(int, input().split())
# 因为是无向边所以需要添加两个方向上的边
g[a].append(b)
g[b].append(a)
INF = 10 ** 10
fa, depth = [[0] * 17 for i in range(n + 1)], [INF] * (n + 1)
self.bfs(fa, depth, g)
# 记录每一个点上的权重
d = [0] * (n + 1)
for i in range(m):
a, b = map(int, input().split())
# 求解a, b的最近公共祖先
p = self.lca(a, b, fa, depth)
# 树上差分
d[a] += 1
d[b] += 1
d[p] -= 2
self.ans = 0
self.dfs(m, 1, -1, d, g)
return self.ans
if __name__ == "__main__":
print(Solution().process())