1. 问题描述:
每一头牛的愿望就是变成一头最受欢迎的牛。现在有 N 头牛,编号从 1 到 N,给你 M 对整数 (A,B),表示牛 A 认为牛 B 受欢迎。这种关系是具有传递性的,如果 A 认为 B 受欢迎,B 认为 C 受欢迎,那么牛 A 也认为牛 C 受欢迎。你的任务是求出有多少头牛被除自己之外的所有牛认为是受欢迎的。
输入格式
第一行两个数 N,M;接下来 M 行,每行两个数 A,B,意思是 A 认为 B 是受欢迎的(给出的信息有可能重复,即有可能出现多个 A,B)。
输出格式
输出被除自己之外的所有牛认为是受欢迎的牛的数量。
数据范围
1 ≤ N ≤ 10 ^ 4,
1 ≤ M ≤ 5 × 10 ^ 4
输入样例:
3 3
1 2
2 1
2 3
输出样例:
1
样例解释
只有第三头牛被除自己之外的所有牛认为是受欢迎的。
来源:https://www.acwing.com/problem/content/description/1176/
2. 思路分析:
分析题目可以知道比较容易想到的是对于每一头牛,判断一下其余的牛是否可以到达,时间复杂度为O(n(n + m)),也即O(n ^ 2)的时间复杂度,肯定会超时的;若当前给出的图是有向无环图(DAG)也即拓扑图,那么问题就比较简单了,如果有两个终点那么至少存在一头牛不被另外一头牛欢迎,那么答案为0,这里终点的意思指的是出度为0的点,如果只存在一个出度为0的点说明答案是1,所以如果当前的图为拓扑图的时候只需要判断出度为0的点的个数即可。我们可以通过强连通分量算法将原图转换为一个拓扑图,这样时间复杂度是线性的,一般可以使用tarjan算法求解强连通分量,求解的过程基本上是固定的,在理解的基础上进行记忆,把模板记熟悉即可。具体的步骤如下:
- 求强连通分量,在求解的时候基于深度优先遍历(dfs)的顺序,首先需要两个数组dfn和low,其中dfn数组记录dfs遍历每一个节点的时间戳,我们可以使用一个全局timestamp来记录dfs遍历节点的先后顺序,在dfs的过程中给每一个遍历的节点一个编号,low[u]表示从节点u开始走能够遍历到的最小时间戳,在dfs的过程中更新dfn和low数组的值,dfn数组比较好处理,每一次递归调用之前赋值当前的时间戳即可,对于一开始的时候low[u]等于当前遍历节点的时间戳,表示当前的节点u能够到达的最早时间戳,也即dfn[u] = low[u] = ++timestamp,在回溯的时候(当前递归方法调用的后面位置)来更新当前的low[u],这里在更新的low[u]的时候需要借助于一个栈stk和记录当前的节点是否在栈中的数组in_stk,遍历到当前节点的时候那么stk就将当前的元素入栈,stk[++top] = u,并且标记当前的元素已经在栈中,in_stk[u] = 1;当前的节点还没有遍历到那么递归调用当前的元素,回溯的时候由子节点的low[next]来更新当前的low[u],如果当前的元素之前已经遍历过并且还在栈in_stk中那么:low[u] = min(low[u],dfn[next]),因为当前节点u的邻接点next可能是强连通分量的最高点(强连通分量最开始的那个点),所以需要使用dfn来更新low[u]取一个最小值,可以画一下图会更好理解,可以参照下面的图,可以发现当遍历完当前节点u的所有邻接点之后如果dfn[u] == low[u],说明当前的节点u是强连通分量的最高点,此时我们可以通过stk中栈中保存的元素找出当前强连通分量的所有点,可以使用do-while循环来找,在找当前强连通分量的过程中标记当前强连通分量中的点属于哪一个强连通分量,记录到Size中表示对应编号的强连通分量的点的个数加1,并且强连通分量的个数scc_cnt加1。
- 缩点,遍历所有点,然后遍历当前点的所有邻接点,如果idx[i] != idx[next]说明当前两个点不在同一个强连通分量中,那么需要加一条边,具体的体现是对应强连通分量编号的出度加1,所以在缩点的时候并不需要将图建出来,最终每一个强连通分量对应编号都对应着一个出度,我们只需要遍历所有强连通分量的出度,如果只有一个出度为0的点(将强连通分量看成是一个点),那么强连通分量中点的个数就是答案,如果大于1说明答案就是0
3. 代码如下:
c++:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 10010, M = 50010;
int n, m;
int h[N], e[M], ne[M], idx;
int dfn[N], low[N], timestamp;
int stk[N], top;
bool in_stk[N];
int id[N], scc_cnt, Size[N];
int dout[N];
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void tarjan(int u)
{
dfn[u] = low[u] = ++ timestamp;
stk[ ++ top] = u, in_stk[u] = true;
for (int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if (!dfn[j])
{
tarjan(j);
low[u] = min(low[u], low[j]);
}
else if (in_stk[j]) low[u] = min(low[u], dfn[j]);
}
if (dfn[u] == low[u])
{
++ scc_cnt;
int y;
do {
y = stk[top -- ];
in_stk[y] = false;
id[y] = scc_cnt;
Size[scc_cnt] ++ ;
} while (y != u);
}
}
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
while (m -- )
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b);
}
for (int i = 1; i <= n; i ++ )
if (!dfn[i])
tarjan(i);
for (int i = 1; i <= n; i ++ )
for (int j = h[i]; ~j; j = ne[j])
{
int k = e[j];
int a = id[i], b = id[k];
if (a != b) dout[a] ++ ;
}
int zeros = 0, sum = 0;
for (int i = 1; i <= scc_cnt; i ++ )
if (!dout[i])
{
zeros ++ ;
sum += Size[i];
if (zeros > 1)
{
sum = 0;
break;
}
}
printf("%d\n", sum);
return 0;
}
python:由于递归调用次数太大了导致堆栈溢出,只过了17个数据:
from typing import List
import sys
class Solution:
# timestamp用来记录节点访问的顺序, in_stack表示是否在栈中, stk保存正在遍历的元素, idx表示节点位于强连通分量的编号, Size记录每一个强联通分量中节点的数目, scc_cnt记录强连通分量的个数
timestamp, in_stack, stk, idx, Size, scc_cnt, top = None, None, None, None, None, None, None
def tarjan(self, u: int, dfn: List[int], low: List[int], g: List[List[int]]):
# 一开始的时候当前的dfn与low列表等于时间戳
dfn[u] = low[u] = self.timestamp + 1
self.timestamp += 1
self.in_stack[u] = 1
self.stk[self.top + 1] = u
self.top += 1
for next in g[u]:
if dfn[next] == 0:
self.tarjan(next, dfn, low, g)
# 由子节点的low[u]更新当前的low[u]
low[u] = min(low[u], low[next])
elif self.in_stack[next] == 1:
# 更新low[u], 当前的邻接点next可能是强连通分量的最高点
low[u] = min(low[u], dfn[next])
# 判断当前的节点u是否是强连通分量的最高点, 也即判断low[u] == dfn[u]
if dfn[u] == low[u]:
# 强连通分量的数量加1
self.scc_cnt += 1
# 找出强连通分量的各个点, 这里可以使用while-true循环, 一般c++语言可以使用do-while循环
while True:
# 栈顶元素
t = self.stk[self.top]
self.top -= 1
self.in_stack[t] = 0
# 标记当前点属于哪一个强连通分量
self.idx[t] = self.scc_cnt
# 对应强连通分量编号的点数加1
self.Size[self.scc_cnt] += 1
# 相等的时候说明当前属于同一个强连通分量的所有点已经找到了, 退出循环
if t == u: break
def process(self):
# m条边
n, m = map(int, input().split())
g = [list() for i in range(n + 10)]
for i in range(m):
a, b = map(int, input().split())
g[a].append(b)
self.timestamp = self.scc_cnt = self.top = 0
self.Size, self.stk, self.in_stack, self.idx = [0] * (n + 10), [0] * (n + 10), [0] * (n + 10), [0] * (n + 10)
# dfn列表记录每一个节点访问的时间戳, low记录每一个点能够遍历到的最小时间戳
dfn, low = [0] * (n + 10), [0] * (n + 10)
for i in range(1, n + 1):
if dfn[i] == 0:
self.tarjan(i, dfn, low, g)
# dout记录每一个强连通分量整体的出度
dout = [0] * (self.scc_cnt + 1)
# 缩点: 遍历所有节点的邻接点
for i in range(1, n + 1):
for next in g[i]:
if self.idx[i] != self.idx[next]:
# 计算每一个强连通分量的出度
dout[self.idx[i]] += 1
# count记录出度为0的强连通分量的个数
res, count = 0, 0
for i in range(1, self.scc_cnt + 1):
if dout[i] == 0:
count += 1
res += self.Size[i]
# 出度的个数大于0说明答案为0
if count > 1:
res = 0
break
return res
if __name__ == "__main__":
# 设置最大递归调用次数
sys.setrecursionlimit(50000)
print(Solution().process())