(树形dp)hdu 5834 Magic boy Bi Luo with his excited tree

题目
hdu5834

题意:
在一棵树上,每个结点都有不同的价值,每条边都有不同的代价,每个结点的价值只能获取一次,但是每经过一条边都要扣除相应的代价。输出从结点1~n 开始获取价值的最大值。

思路:
从结点开始获取价值的最大值 = max(子树回来 + parent不回来,parent回来 + 子树不回来)
结点p 的子树s 回来和子树s 不回来的最大值好求:

0表示不回来,1表示回来

子树s 回来:遍历每个子树s(结点p 到结点s 的代价w),
dp [ p ] [ 1 ] = max (dp [ p ] [ 1 ], dp [ p ] [ 1 ] + dp [ s ] [ 1 ] - w * 2)

子树s 不回来:遍历每个子树s(结点p 到结点s 的代价w),
dp [ p ] [ 0 ] = max (dp [ p ] [ 0 ], dp [ p ] [ 0 ] + dp [ s ] [ 1 ] - w * 2, dp [ p ] [ 1 ] + dp [ s ] [ 0 ] - w)

结点p 的parent回来和parent不回来的最大值难求:

结点f 是结点p 的parent,结点f 到结点p 的代价d

  1. 如果结点f 到子树又回来时不经过子树p 的情况:
    fdp [ p ] [ 0 ] = max (0, fdp [ f ] [ 0 ] + dp [ f ] [ 1 ] - d)
    在这里插入图片描述

fdp [ p ] [ 1 ] = max (0, fdp [ f ] [ 1 ] + dp [ f ] [ 1 ] - d * 2)
在这里插入图片描述
2. 如果结点f 到子树又回来时经过子树p 的情况:
fdp [ p ] [ 0 ] = max (0, fdp [ f ] [ 0 ] + dp [ f ] [ 1 ] - (dp [ p ] [ 1 ] - d)
在这里插入图片描述
fdp [ p ] [ 1 ] = max (0, fdp [ f ] [ 1 ] + dp [ f ] [ 1 ] - dp [ p ] [ 1 ])
在这里插入图片描述
3. 结点f 到子树不回来时的情况(在子树p 上)的情况:
在这里插入图片描述

  1. 如果结点f 到子树不回来时的情况(在子树p 上) 并且结点f 到子树又回来时不经过子树p 的情况:
    fdp [ p ] [ 0 ] = max ( fdp [ p ] [ 0 ], fdp [ f ] [ 1 ] + dp [ f ] [ 0 ] - ( dp [ p ] [ 1 ] - d) )

  2. 如果结点f 到子树不回来时的情况(在子树p 上) 并且结点f 到子树又回来时经过子树p 的情况:
    fdp [ p ] [ 0 ] = max ( fdp [ p ] [ 0 ], fdp [ f ] [ 1 ] + dp [ f ] [ 0 ] - d)

代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector> 
#define DEBUG freopen("_in.txt", "r", stdin); freopen("_out1.txt", "w", stdout);
#define CASEE int t; cin >> t; for (int ti = 1; ti <= t; ti++)
using namespace std;
const int MAXN = 1e5 + 10;
int fa[MAXN], c[MAXN], idx[MAXN];  //fa[i]: i的parent结点,c[i]: i结点到parent结点的花费,idx[i]: i结点如果不回来的子树结点 
int dp[MAXN][2], fdp[MAXN][2];  //dp[i][0]: i到子树后不回来的最大值,dp[i][1]: i到子树后回来的最大值,fdp[i][0]:  i到parent后不回来的最大值,fdp[i][0]:  i到parent后回来的最大值
int v[MAXN];
bool vis[MAXN];
int tree[MAXN], ti;
struct node{
    
    
	int v, w, next;
}R[MAXN<<1];
void addroad(int u, int v, int w){
    
    
	R[ti].v = v;
	R[ti].w = w;
	R[ti].next = tree[u];
	tree[u] = ti++;
	R[ti].v = u;
	R[ti].w = w;
	R[ti].next = tree[v];
	tree[v] = ti++;
}
void dfs1(int p){
    
    
	vis[p] = true;
	dp[p][0] = dp[p][1] = v[p];
	for (int i = tree[p], s, w; i != -1; i = R[i].next){
    
    
		s = R[i].v;
		w = R[i].w;
		if (!vis[s]){
    
    
			fa[s] = p;
			c[s] = w;
			dfs1(s);
			dp[p][0] = max(dp[p][0], dp[p][0] + dp[s][1] - (w << 1));
			if (dp[p][0] < dp[p][1] + dp[s][0] - w){
    
    
				idx[p] = s;  //记录不回来的子树结点的坐标 
				dp[p][0] = dp[p][1] + dp[s][0] - w;
			}
			dp[p][1] = max(dp[p][1], dp[p][1] + dp[s][1] - (w << 1));
		}
	}
}
void dfs2(int p){
    
    
	int f = fa[p], d = c[p];
	fdp[p][0] = fdp[p][1] = 0;
	if (dp[p][1] - (d << 1) <= 0){
    
      //结点f 到子树又回来时不经过子树p 的情况 
		fdp[p][0] = max(0, fdp[f][0] + dp[f][1] - d);
		fdp[p][1] = max(0, fdp[f][1] + dp[f][1] - (d << 1));
	}
	else{
    
    
		fdp[p][0] = max(0, fdp[f][0] + dp[f][1] - (dp[p][1] - d));
		fdp[p][1] = max(0, fdp[f][1] + dp[f][1] - dp[p][1]);
	}
	if (idx[f] == p){
    
      //结点f 到子树不回来时的情况(在子树p 上) 
		int max1 = v[f], max2 = v[f];
		for (int i = tree[f], s, w; i != -1; i = R[i].next){
    
    
			s = R[i].v;
			w = R[i].w;
			if (s == fa[f] || s == p)
				continue;
			max1 = max(max1, max(max1 + dp[s][1] - (w << 1), max2 + dp[s][0] - w));
			max2 = max(max2, max2 + dp[s][1] - (w << 1));
		}
		fdp[p][0] = max(fdp[p][0], max1 + fdp[f][1] - d);
	}
	else{
    
    
		if (dp[p][1] - (d << 1) <= 0)  //结点f 到子树又回来时不经过子树p 的情况 
			fdp[p][0] = max(fdp[p][0], fdp[f][1] + dp[f][0] - (dp[p][1] - d));
		else
			fdp[p][0] = max(fdp[p][0], fdp[f][1] + dp[f][0] - d);
	}
	for (int i = tree[p], s; i != -1; i = R[i].next){
    
    
		s = R[i].v;
		if (s != f)
			dfs2(s);
	}
}
void solve(){
    
    
	int n;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++)
		scanf("%d", &v[i]);
	memset(tree, -1, sizeof(tree));
	ti = 0;
	for (int i = 1, p, s, w; i < n; i++){
    
    
		scanf("%d%d%d", &p, &s, &w);
		addroad(p, s, w);
	}
	memset(vis, false, sizeof(vis));
	dfs1(1);
	dfs2(1);
	for (int i = 1; i <= n; i++)
		printf("%d\n", max(dp[i][0] + fdp[i][1], dp[i][1] + fdp[i][0]));
}
int main(){
    
    
	int n;
	CASEE{
    
    
		printf("Case #%d:\n", ti);
		solve(); 
	} 
	return 0;
}

猜你喜欢

转载自blog.csdn.net/ymxyld/article/details/113850933