【图论】B040_NK_“好序列”的个数(快速幂 + 求差)

一、描述

现在你面前有一棵n个节点的树(全连通无环图)。树上的边只有2种颜色,红色或者黑色。现在还给你一个整数k,考虑下面这个k个节点的序列[a1, a2, …, ak]。

[a1, a2, …, ak]如果是”好序列“当且仅当满足下面的条件:

  1. 我们要走一条从a1开始到ak结束的路径。
  2. 从a1开始,到a2走一条a1到a2的最短路。然后从a2开始,继续走一条到a3的最短路,以此类推,最终到a(k-1)和ak。
  3. 走的路径中至少包含一条黑色的边。
    在这里插入图片描述

我们看一下上面的图片中的树,如果k=3,那么下面的序列是“好序列”:[1,4,7], [5,5,3]。下面的序列不是好序列: [1,4,6], [5,5,5], [3,7,3]。

总共有 n k n^k (n的k次方种路径方案),那么有多少路径是“好序列”呢?这个值可能非常大,输出的结果对(10^9+7)取模就可以。

输入描述:

第一行是2个整数n和k,其中(2 <= n <= 10^5, 2 <= k <= 100),n表示树的节点个数,k表示序列的长度。

下面n-1行,每行包含3个整数,u[i], v[i], w[i],其中1 <= u[i], v[i] <= n, w[i] = 0或1。u[i], v[i]表示这两个节点之间有一条边,w[i]表示这条边的颜色,其中0表示红色,1表示黑色。

输出描述:

输出所有“好序列”的个数模(10^9+7)

输入
4 4
1 2 1
2 3 1
3 4 1

输出
252

说明
这个例子中,所有序列一共有4^4 = 256个,其中不是好序列的只有4个:
[1, 1, 1, 1]

[2, 2, 2, 2]

[3, 3, 3, 3]

[4, 4, 4, 4]

二、Solution

方法一:求差

  • 直接求好序列比较难,因为遍历的时候还要统计 黑边 的数量
  • 而坏序列只包红边,又由题意得一个坏子图的结点数为 sz 时,那么该子图就有 s z k sz^k 个坏序列,所以我们只需求出每个坏子图 i i 的结点数 s z i sz_i ,最后用 t o t s z i tot - sz_i t o t = n k (tot = n^k) 即为所求答案。

细节:一般涉及到取模的题都需要仔细观察能取模的地方,比如这里,我提交时没有写上 + mod,直接 WA 了。

System.out.println((tot - bad + mod) % mod);

原因是为了防止负数取模仍是负数的问题,例如 -3 % 4 在 Java 中会得到 -3,而某些题目结果不小于 0,所以要加上 mod 确保结果非负(由取模定理得结果是不会改变的)

import java.util.*;
import java.math.*;
import java.io.*;
public class Main{
    static class Solution {
        Set<Integer> vis, g[];
        int mod = (int) 1e9+7;
        
        long qPow(long b, long p) {
            long ans = 1;
            while (p > 0) {
                if ((p & 1) == 1)
                    ans = (ans * b) % mod;
                b = (b * b) % mod;
                p >>= 1;
            }
            return ans;
        }
        long dfs(int u) {
            long sz = 1;
            vis.add(u);
            for (int v : g[u]) if (!vis.contains(v)) {
                sz += dfs(v);
            }
            return sz;
        }
        void init() {
            Scanner sc = new Scanner(new BufferedInputStream(System.in));
            int n = sc.nextInt(), k = sc.nextInt();
            long tot = qPow(n, k);

            vis = new HashSet<>();
            g = new HashSet[n+1];
            for (int i = 1; i <= n; i++) g[i] = new HashSet<>();
            for (int i = 1; i < n; i++) {
                int a = sc.nextInt(), b = sc.nextInt(), w = sc.nextInt();
                if (w == 0) {
                    g[a].add(b);
                    g[b].add(a);
                }
            }
            long bad = 0;
            for (int i = 1; i <= n; i++) if (!vis.contains(i)) {
                long sz = dfs(i);
                bad = (bad + qPow(sz, k))  % mod;
            }
            System.out.println((tot - bad + mod) % mod);
        }
    }
    public static void main(String[] args) throws IOException {  
        Solution s = new Solution();
        s.init();
    }
}

复杂度分析

  • 时间复杂度: O ( n ) O(n)
  • 空间复杂度: O ( n ) O(n)

猜你喜欢

转载自blog.csdn.net/qq_43539599/article/details/106825265