Link
Difficulty
算法难度7,思维难度7,代码难度5
Description
给定一棵 个点的树,边带权值,要求你选出 条链,使得权值和最大。
Solution
前面的小部分分我就不说了,说一下和正解有极大联系的60分的树形dp吧。
首先我们考虑设计dp状态。
第一想法是 代表在 的子树中选了 条链的最大价值,看起来非常美好。
但是仔细想想发现没法写状态转移方程,因为不知道到底能不能和儿子连边,也不知道连边会发生什么事。
这样我们就发现我们还要记录一下每个点的连边状态。
代表在 的子树中完整地选了 条链的最大价值, 代表点 的度数。
首先初始状态:
- ,代表这个点可以不选。
- ,代表这个点可以作为链最下面的点向上连。
- ,代表这个点可以单独作为一条链,至于为什么要有这个状态,只需要想一下极端情况 时,合法答案是什么样子的就可以了。
- 其他都为负无穷,也就是不合法
考虑转移状态,将儿子 的状态合并到点 :( 代表 从 枚举到 ,下面不再描述)
-
,其中
代表不选或者选 条链。
-
,其中
代表不选,选 条链,或者选 条链并且选这条边。
-
,其中
代表不选,选 条链,或者选 条链并且选这条边增加一条链。
-
,其中
这个看起来跟上面的第二个转移的重复了,事实上并没有,因为这个转移既合法,第二个转移又转移不到。
-
代表从下面连上来,同样是第二个转移没有转移到的。
-
,其中
代表不选,在这里停止这条链并计入总数,或者把那两个度数去掉。
转移方程大概就是这些了,dp的顺序呀,细节呀,就看我的代码吧。
这样的话复杂度有些玄学(调循环边界的话),我不太会算,反正只能有 分,会TLE。
本来想把这个dp放到dfs序上说不定就可以到 了,后来发现我不会QAQ
这个dp必须先写一下,因为凸优化的代码就是在dp的基础上改的。
拿到45分之后,我们来看这题正解吧。
凸优化
凸优化就是针对凸函数求极值的优化。
我们这里不直接探究它的定义及一般情况,我们直接来看这个题,通过这个题来理解凸优化。
首先,通过打表可以发现,答案的函数是上凸的,对于样例来说画出来是这样的:
虽然图像有点儿尖,但是它确实是上凸的。
怎么直接判断一个题的答案是否上凸呢?
我们可以感性判断,比如对于这个题,假如只能选一条链的话,一定是选最长的,选两条的话,增长的就没有第一条那么多了,因为最长的已经选过了,这样来看,增长只会越来越慢,所以它是凸函数。
现在我们知道它是凸函数了,应该怎么做呢?
我们二分一个权值 ,代表选一条链需要付出的代价,然后我们去掉选多少条链那一维,还按照原来的dp做。
这样子相当于我们拿 的直线去切答案函数,在这个基础上求极值。
但是我们发现这样求得极值之后,无法判断下一次 变小还是变大。
我们同样可以发现,切了之后的可以取得极值的点是一段连续的区间。
因此,在此基础上我们再记录取得极值的最小的 是多少,也就是区间的左端点是多少。
假如题目中的 等于左端点的话,直接输出答案。
假如题目中的 一定不在这个区间内(左端点大于 ),则令 ,让选的代价变大,左端点减小。
假如题目中的 有可能在这个区间内(左端点小于 ),则令 ,让选的代价变小,左端点增大。
最后令 ,再做一次得到最终答案,并且把那个选的代价加回来,就好了。
感性理解一下这个过程,感觉挺对的QAQ
然后这个做法就叫凸优化啦,是不是感觉也没什么难的?
时间复杂度 ,还有树形dp常数挺大,所以跑得比较慢。
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#define LL long long
using namespace std;
inline int read(){
int x=0,f=1;char ch=' ';
while(ch<'0' || ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0' && ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return f==1?x:-x;
}
const int N=3e5+5,K=105;
const LL inf=1e18;
int n,k,tot;
int head[N],to[N<<1],Next[N<<1],val[N<<1];
struct data{
LL x,y;
data(){}
data(LL _x,LL _y):x(_x),y(_y){}
inline bool operator < (const data& b) const {
if(x==b.x)return y>b.y;
return x<b.x;
}
inline data operator + (const data& b) const {return data(x+b.x,y+b.y);}
inline data operator + (LL b) const {return data(x+b,y);}
}dp[N][3];
inline void addedge(int x,int y,int l){
to[++tot]=y;
Next[tot]=head[x];
head[x]=tot;
val[tot]=l;
}
LL mid;
inline void dfs(int x,int fa){
dp[x][0]=data(0,0);
dp[x][1]=data(0,0);
dp[x][2]=max(data(0,0),data(-mid,1));
for(int i=head[x];i;i=Next[i]){
int u=to[i];
if(u==fa)continue;
dfs(u,x);
dp[x][2]=max(dp[x][2],max(dp[x][2]+dp[u][0],dp[x][1]+dp[u][1]+val[i]+data(-mid,1)));
dp[x][1]=max(dp[x][1],max(dp[x][1]+dp[u][0],dp[x][0]+dp[u][1]+val[i]));
dp[x][0]=max(dp[x][0],dp[x][0]+dp[u][0]);
}
dp[x][0]=max(dp[x][0],data(0,0));
dp[x][0]=max(dp[x][0],max(dp[x][1]+data(-mid,1),dp[x][2]));
}
int main(){
n=read();k=read()+1;
for(int i=1;i<n;++i){
int x=read(),y=read(),l=read();
addedge(x,y,l);addedge(y,x,l);
}
LL l=-1e12,r=1e12;
while(l<r){
mid=(double)(l+r)/2-0.5;
dfs(1,0);
if(dp[1][0].y==k){
printf("%lld\n",dp[1][0].x+k*mid);
return 0;
}
else if(dp[1][0].y>k)l=mid+1;
else r=mid;
}
mid=l;
dfs(1,0);
printf("%lld\n",dp[1][0].x+k*mid);
return 0;
}