版权声明:欢迎大家转载,转载请注明出处 https://blog.csdn.net/hao_zong_yin/article/details/82020688
先预处理出树的最长链(即直径)。若被断开的边不在最长链上,答案仍然是最长链的长度;否则需要计算断开该边之后剩余部分的最长链,这只需在最长链上做一次树形 DP 预处理即可得到。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
// The rest of the file calls standard names such as max() unqualified.
using namespace std;
// Size of every per-node array in this file: 100000 plus 10 slots of slack
// (presumably the node limit of the task -- confirm against the statement).
const int maxn = 100000 + 10;
// 64-bit signed integer used for distance sums and the final answer.
typedef long long ll;
// Last character consumed by input(); kept as a global for compatibility
// with any code that inspects it after a read.
char c;
/**
 * Reads the next non-negative decimal integer from stdin into x.
 *
 * Every character before the first digit is skipped (so a leading '-' is
 * ignored, not interpreted), then consecutive digits are accumulated.
 * The single character following the number is consumed as well.
 *
 * @param x receives the parsed value; set to 0 when stdin is exhausted
 *          before any digit is found.
 */
inline void input(int &x) {
    x = 0;
    // getchar() returns int precisely so that EOF differs from every byte;
    // narrowing it to char first made the skip loop spin forever at EOF.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while ('0' <= ch && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    c = static_cast<char>(ch);
}
// Number of test cases left to run (read and counted down in main).
int T;
// Node count of the current tree; nodes are numbered 1..n.
int n;
// Index of the next unused slot in edges[].
int mem;
// head[u]: index in edges[] of the most recently added edge leaving u,
// or -1 when u has no edges.
int head[maxn];
// One directed edge of the adjacency list.
struct Edge {
    int to;    // destination node
    int val;   // edge weight
    int next;  // index of the next edge leaving the same source, or -1
};
// Twice maxn slots because solve() stores each input edge in both directions.
Edge edges[maxn << 1];
// Empties the adjacency list for a tree of n nodes: every head[1..n] is
// marked as having no edges and the edge pool restarts at slot 0.
void init_edges() {
    for (int node_id = n; node_id >= 1; --node_id) {
        head[node_id] = -1;
    }
    mem = 0;
}
void addedge(int u, int v, int val) {
edges[mem].to = v; edges[mem].val = val; edges[mem].next = head[u]; head[u] = mem++;
}
int U, MAX;
void dfs1(int fa, int u, int maxv) {
bool hasson = false;
for (int i = head[u]; ~i; i = edges[i].next) {
int v = edges[i].to, val = edges[i].val;
if (v != fa) { hasson = true; dfs1(u, v, maxv+val); }
}
if (!hasson && maxv > MAX) {
MAX = maxv, U = u;
}
}
// dis[u]: longest weighted distance from u down into its subtree (set by dfs2).
ll dis[maxn];
// path[u]: child of u through which dis[u] is achieved, or -1 if there is none.
int path[maxn];
// Number of nodes currently stored in node[1..cnt].
int cnt;
// node[1..cnt]: the nodes of the longest chain, listed from U onwards.
int node[maxn];
// vis[u]: true when u lies on the longest chain.
bool vis[maxn];
void dfs2(int fa, int u) {
dis[u] = 0;
path[u] = -1;
for (int i = head[u]; ~i; i = edges[i].next) {
int v = edges[i].to, val = edges[i].val;
if (v == fa) continue;
dfs2(u, v);
if (dis[v] + val > dis[u]) {
dis[u] = dis[v] + val;
path[u] = v;
}
}
}
ll d[maxn];
void dfs3(int fa, int u) {
d[u] = 0;
for (int i = head[u]; ~i; i = edges[i].next) {
int v = edges[i].to, val = edges[i].val;
if (v == fa || vis[v]) continue;
dfs3(u, v);
d[u] = max(d[u], d[v] + val);
}
}
ll L[maxn], R[maxn];
void solve() {
input(n);
init_edges();
int u, v, w;
for (int i = 1; i <= n-1; i++) {
input(u); input(v); input(w);
addedge(u, v, w); addedge(v, u, w);
}
cnt = 0;
U = 0, MAX = 0;
dfs1(-1, 1, 0);
dfs2(-1, U);
for (int i = 1; i <= n; i++) vis[i] = 0;
for (int i = U; ~i; i = path[i]) {
node[++cnt] = i;
vis[i] = 1;
}
for (int i = 1; i <= cnt; i++) dfs3(-1, node[i]);
L[0] = 0;
for (int i = 1; i < cnt; i++) {
L[i] = max(L[i-1], dis[U] - dis[node[i]] + d[node[i]]);
}
R[cnt+1] = 0;
for (int i = cnt; i > 1; i--) {
R[i] = max(R[i+1], dis[node[i]] + d[node[i]]);
}
ll ans = 0;
for (int i = 1; i < cnt; i++) ans += max(L[i], R[i+1]);
ans += dis[U] * (n-cnt);
printf("%lld\n", ans);
}
// Entry point: reads the number of test cases into T and runs solve() once
// per case, counting T down as it goes.
int main() {
    input(T);
    for (;;) {
        if (T-- == 0) break;
        solve();
    }
    return 0;
}