简介
参考自@WNJXYK写的CSP 201909-5 城市规划,感谢。
-
城市中有N个公交站,公交站之间通过N-1条道路连接,每条道路有相应的长度。保证所有公交站两两之间能够通过一条唯一的通路互相达到。
这实际上告诉我们,输入是一棵有N个节点的树。
-
输出路径长度总和即可
要求解的是两两之间路径长度之和的最小值,而不是类似最小生成树那样的结果。
-
对于某一条边,假设在这条边两端的子树里分别选了
x (0 <= x <= K)
和K - x
个重要节点,那么,在计算两两之间路径长度总和时,这条边要累加x * (K - x)
次。简单的组合问题。 -
下面的写法是按照点分治来的。定义
dp[i][j]
为 在以节点i
为根的子树中选j
个重要节点时,在这个子树里的最小路径和 。即,仅计算在这棵子树里的每条边。 -
如果以节点
i
为根的子树里,没有足够的j
个重要节点,那么dp[i][j] = -1
,表示此情况无意义。同时对任意i
有dp[i][0] = 0
。 -
当这棵子树仅有一个节点u时,如果u是重要节点的话,
dp[u][1] = 0
;否则,dp[u][1] = -1
。 -
当这棵子树不止一个节点u时,我们先算出u的下一级节点(也即u的子树)的解,然后再合并成节点u的解。合并的方法,类似于背包问题,从
n
个子树里总共选x (0 < x <= K)
个重要节点,使得dp[u][x]
最小。这里又用到递归。假设已知前n-1
棵子树的合并结果,也就是说,已知从前n-1
棵子树里选x
个重要节点的dp[u][x]
,那么,从所有n
棵子树里选x
个重要节点的dp[u][x]
等于,从前n-1
棵子树里选y (0 <= y <= x)
个重要节点并且从第n
棵子树里选x - y
个节点,这x + 1
种情况的最小值。不会打漂亮的数学公式,建议直接阅读代码相关部分。
时间复杂度
solve只执行了N次,每个节点一次。第16行起的二重循环,复杂度为 。
总的时间复杂度为
代码
#include <iostream>
#include <vector>
using namespace std;
int N, M, K;
vector<bool> imp;
vector<vector<pair<int, int>>> tree;
vector<vector<long long>> dp; // dp[i][j]: 在以节点i为根的子树中选j个重要节点时,在这个子树里的最小路径和
void solve(int u, int fa) {
vector<long long>& cur = dp[u]; // 当前要求解的数组
cur[0] = 0; // j = 0(一个重要节点都不选)时,路径和显然为0
for (const auto& pr: tree[u]) if (pr.first != fa) {
solve(pr.first, u); // 先解出子问题dp[pr.first],然后合并,分治思想
vector<long long>& sub = dp[pr.first];
for (int x = K; x > 0; --x) { // 合并时,重用了数组cur。应该将此时的cur视为前n-1个子树合并后的结果
for (int y = x; y >= 0; --y) {
if (cur[y] != -1 && sub[x - y] != -1) {
long long t = cur[y] + sub[x - y] + (long long)pr.second * (x - y) * (K - x + y);
if (cur[x] == -1 || cur[x] > t) cur[x] = t;
}
}
}
}
if (imp[u]) {
for (int i = K; i > 0; --i)
if (cur[i] == -1 || cur[i] > cur[i - 1]) // cur[i] != -1的话,cur[i-1]不可能等于-1
cur[i] = cur[i - 1];
}
}
int main() {
cin.tie(nullptr);
ios_base::sync_with_stdio(false);
cin >> N >> M >> K;
imp.resize(N + 1);
for (int i = 0; i < M; ++i) {
int x;
cin >> x;
imp[x] = true;
}
tree.resize(N + 1);
for (int i = 1; i < N; ++i) {
int a, b, c;
cin >> a >> b >> c;
tree[a].emplace_back(b, c);
tree[b].emplace_back(a, c);
}
dp.resize(N + 1, vector<long long>(K + 1, -1));
solve(1, 0);
cout << dp[1][K] << endl;
return 0;
}