树上背包问题
树上背包问题
一些题目给定了树形结构,在这个树形结构中选取一定数量的点或边(也可能是其他属性),使得某种与点权或者边权相关的花费最大或者最小。解决这类问题,一般要考虑使用树上背包。
算法原理
树上背包,顾名思义,就是在树上做背包问题。一个节点的若干子树可以看作是若干组背包,也就是用树形dp的方式做分组背包问题。一般来说, f ( i , j ) f(i,j) f(i,j)表示以 i i i为根的子树中,在 j j j的容量范围内,最大或者最小可以获得多少收益。根据分组背包的思想,第一维枚举物品(在树上指的是子树),第二维枚举容量,第三维枚举决策(这里指的是给子树分配多少容量)。基本的代码框架如下:
void dfs(int u, int fa)
{
for(int i = h[u]; ~i; i = ne[i])
{
int son = e[i];
if(son == fa) continue;
dfs(son, u);
for(int j = m; j >= 0; j --)
for(int k = 0; k <= j; k ++)
f[u][j] = max(f[u][j], f[u][j-k] + f[son][k] + val);
}
}
例题一:有依赖的背包问题
题意
有 n n n个物品和一个容量是 m m m的背包。物品之间具有依赖关系,且依赖关系组成一棵树的形状。如果选择一个物品,则必须选择它的父节点。
求解将哪些物品装入背包,可使物品总体积不超过背包容量,且总价值最大。输出最大价值。
每件物品的编号是 i i i,体积是 v i v_i vi,价值是 w i w_i wi,依赖的父节点编号是 p i p_i pi。物品的下标范围是 1 … N 1 \dots N 1…N。
数据范围
1 ≤ n , m ≤ 100 1 \leq n,m \leq 100 1≤n,m≤100
1 ≤ v i , w i ≤ 100 1 \leq v_i,w_i \leq 100 1≤vi,wi≤100
思路
f ( i , j ) f(i,j) f(i,j)表示选择以 i i i为子树的物品,在容量不超过 j j j时所获得的最大价值。
由于只有选择了根节点,才会继续往下遍历,所以在遍历到 i i i节点时,先考虑一定选上它。
在分组背包部分, j j j的范围为 [ m , v [ i ] ] [m,v[i]] [m,v[i]],否则没有意义,因为连根节点也放不下; k k k的范围 [ 0 , j − v [ i ] ] [0,j-v[i]] [0,j−v[i]],当大于 j − v [ i ] j-v[i] j−v[i]时分给该子树的容量过多,剩余的容量连根节点的物品都放不下了。
递推式为: f ( i , j ) = m a x ( f ( i , j ) , f ( i , j − k ) + f ( s o n , k ) ) f(i,j) = max(f(i,j), f(i,j - k) + f(son,k)) f(i,j)=max(f(i,j),f(i,j−k)+f(son,k))。
代码
void dfs(int u)
{
for(int i = v[u]; i <= m; i ++) f[u][i] = w[u];
for(int i = h[u]; ~i; i = ne[i])
{
int son = e[i];
dfs(son);
for(int j = m; j >= v[u]; j --)
for(int k = 0; k <= j - v[u]; k ++)
f[u][j] = max(f[u][j], f[u][j - k] + f[son][k]);
}
}
例题二:二叉苹果树
题意
给定一棵二叉树,每条边有边权,保留一定数量的边(其他边删除),使得保留下来的边的边权和最大。
数据范围
1 ≤ n < m ≤ 100 1 \leq n < m \leq 100 1≤n<m≤100
w i ≤ 30000 w_i \leq 30000 wi≤30000
思路
f ( i , j ) f(i,j) f(i,j)表示以 i i i为根的子树中,恰好保留 j j j条边的最大边权和。
若需要选择该子树中的边,则根结点到子树的边一定要选,因此能用上的总边数一定减 1 1 1,总共可以选择 j j j条边时,当前子树son分配的最大边数是 j − 1 j - 1 j−1。
递推式为, f ( i , j ) = m a x ( f ( i , j ) , f ( i , j − k − 1 ) + f ( s o n , k ) + w [ i ] ) f(i,j) = max(f(i,j), f(i,j-k-1) + f(son, k) + w[i]) f(i,j)=max(f(i,j),f(i,j−k−1)+f(son,k)+w[i])。
代码
void dfs(int u, int fa)
{
for(int i = h[u]; ~i; i = ne[i])
{
int son = e[i];
if(son == fa) continue;
dfs(son, u);
for(int j = m; j >= 1; j -- )
for(int k = 0; k <= j - 1; k ++ )
f[u][j] = max(f[u][j], f[u][j - k - 1] + f[son][k] + w[i]);
}
}
例题三:Factories(2018icpc银川网络赛)
题意
给定一棵树,边有边权。每个叶子节点上最多可以布置一个工厂,总共要布置 k k k个工厂。问怎样布置工厂,使得工厂之间的距离和最小。
数据范围
10 s 10s 10s
2 ≤ n ≤ 1 0 5 2 \leq n \leq 10^5 2≤n≤105, 1 ≤ m ≤ 100 1 \leq m \leq 100 1≤m≤100
1 ≤ w i ≤ 1 0 5 1 \leq w_i \leq 10^5 1≤wi≤105
多组测试数据, n n n总数不超过 1 0 6 10^6 106
思路
直接考虑距离之和非常困难,所以可以考虑每条边被计算了几次(距离和等类似问题很多都是这么考虑的)。不妨设一条边为 i i i,与 i i i相连的子树中有 j j j个工厂,则这条边被计算的次数为 j ∗ ( m − j ) j*(m - j) j∗(m−j)。
f ( i , j ) f(i,j) f(i,j)表示以 i i i为根节点的子树中,选择恰好 j j j个叶子节点的距离总和。
递推式为, f ( i , j ) = m i n ( f ( i , j ) , f ( i , j − k ) + f ( s o n , k ) + w [ i ] ∗ j ∗ ( m − j ) ) f(i,j) = min(f(i,j), f(i,j - k) + f(son, k) + w[i] * j * (m - j)) f(i,j)=min(f(i,j),f(i,j−k)+f(son,k)+w[i]∗j∗(m−j))。
因为只能分布在叶子节点,因此初始化的时候要注意,如果点 i i i为叶子节点,那么 f ( i , 1 ) = 0 f(i,1) = 0 f(i,1)=0。
同时这道题要卡常数,所以要对状态做一个优化,即把无效状态去掉。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 100003, M = 103;
const ll inf = 1e18;
int n, m;
int h[N], e[2*N], ne[2*N], w[2*N], idx;
int s[N], deg[N];
ll f[N][M];
void add(int a,int b,int c)
{
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx ++;
}
void dfs(int u,int fa)
{
for(int i = h[u]; ~i; i = ne[i])
{
int son = e[i];
if(son == fa) continue;
dfs(son, u);
s[u] += s[son];
for(int j = min(m, s[u]); j >= 1; j --)
for(int k = 1; k <= min(j, s[son]); k ++)
f[u][j] = min(f[u][j], f[u][j-k] + f[son][k] + (ll)w[i] * k * (m - k));
}
}
int main()
{
int T;
scanf("%d", &T);
int cas = 0;
while(T --)
{
scanf("%d%d", &n,&m);
for(int i = 1; i <= n; i ++) h[i] = -1, deg[i] = 0;
idx = 0;
for(int i = 0; i < n - 1; i ++)
{
int a,b,c;
scanf("%d%d%d", &a,&b,&c);
add(a,b,c), add(b,a,c);
deg[a] ++, deg[b] ++;
}
for(int i = 1; i <= n; i ++)
{
s[i] = 0;
for(int j = 1; j <= m; j ++) f[i][j] = inf;
if(deg[i]==1) f[i][1] = 0, s[i] = 1;
}
dfs(1, -1);
printf("Case #%d: %lld\n",++cas,f[1][m]);
}
return 0;
}