P3574 [POI2014]FAR-FarmCraft(树形dp)

题目传送门

题意: 有一颗树,你在根节点1,从1出发经过每个点恰好一次,经过每条边的时间为1,每个点有一个权值a[i],每个点在第一次经过的时候就开始计时,a[i]秒之后结束。特殊的:点1的计时是最后开始。问:你最短需要多长时间让所有点都计时结束。

思路: f [ x ] f[x] f[x]表示以x为根节点的子树需要的最短时间。假设y是x的一颗子树,那么就有:

  • f [ x ] = a [ x ] f[x]=a[x] f[x]=a[x]
  • f [ x ] = m a x ( f [ x ] , f [ y ] + s i z [ x ] y + 1 ) f[x]=max(f[x],f[y]+siz[x]_y+1) f[x]=max(f[x],f[y]+siz[x]y+1)(1表示从x到y)

s i z [ x ] y siz[x]_y siz[x]y表示走到y之前,在x的子树内,走路花费了的时间。

但是这样的话,就需要考虑一个子树顺序问题,因为顺序不同可能导致更新出来的 f [ x ] f[x] f[x]不同。

我们假设有 y y y, z z z两颗子树,先遍历 y y y子树时:

  • f [ x ] = m a x ( f [ x ] , s i z [ x ] y + m a x ( f [ y ] , f [ z ] + s i z [ y ] + 2 ) + 1 ) f[x]=max(f[x],siz[x]_y+ max(f[y],f[z]+siz[y]+2)+1) f[x]=max(f[x],siz[x]y+max(f[y],f[z]+siz[y]+2)+1)

如果先遍历 z z z子树:

  • f [ x ] = m a x ( f [ x ] , s i z [ x ] y + m a x ( f [ z ] , f [ y ] + s i z [ z ] + 2 ) + 1 ) f[x]=max(f[x],siz[x]_y+max(f[z],f[y]+siz[z]+2)+1) f[x]=max(f[x],siz[x]y+max(f[z],f[y]+siz[z]+2)+1)

那么如果交换比不交换更优,则有:

  • m a x ( f [ y ] , f [ z ] + s i z [ y ] ) > m a x ( f [ z ] , f [ y ] + s i z [ z ] ) max(f[y],f[z]+siz[y])>max(f[z],f[y]+siz[z]) max(f[y],f[z]+siz[y])>max(f[z],f[y]+siz[z])

显然, f [ y ] < f [ y ] + s i z [ z ] , f [ z ] < f [ z ] + s i z [ y ] f[y]<f[y]+siz[z],f[z]<f[z]+siz[y] f[y]<f[y]+siz[z],f[z]<f[z]+siz[y]
即有, f [ z ] + s i z [ y ] > f [ y ] + s i z [ z ] f[z]+siz[y]>f[y]+siz[z] f[z]+siz[y]>f[y]+siz[z]
即, s i z [ z ] − f [ z ] < s i z [ y ] − f [ y ] siz[z]-f[z]<siz[y]-f[y] siz[z]f[z]<siz[y]f[y]
将这个作为依据对子树排序,再对根节点进行更新即可。
(注意处理子树根节点是1的情况)

注意: 考虑状态转移方程中,为何是 f [ x ] = m a x ( f [ x ] , f [ y ] + s i z [ x ] y + 1 ) f[x]=max(f[x],f[y]+siz[x]_y+1) f[x]=max(f[x],f[y]+siz[x]y+1)而不是 f [ x ] = m a x ( f [ x ] , f [ y ] + s i z [ x ] y + 2 ) f[x]=max(f[x],f[y]+siz[x]_y+2) f[x]=max(f[x],f[y]+siz[x]y+2) ?我们走完之前的子树,现在要走y子树,按道理应该是 s i z [ x ] y + f [ y ] + 2 siz[x]_y+f[y]+2 siz[x]y+f[y]+2,(2表示从x到y再从y到x)。其实,我们知道,f[y]>siz[y](无特殊情况),即在y子树内的走路时间一定是小于在y子树计时总时间的,那么从y走到x的时间其实是已经被包含在 f [ y ] f[y] f[y]里面了。

代码:

#include<bits/stdc++.h>
#define endl '\n'
#define mp make_pair
#define pb push_back
#define ll long long
#define int long long
#define pii pair<int,int>
#define sz(x) (int)(x).size()
#define all(x) (x).begin(),(x).end()
#define mem(a,b) memset(a,b,sizeof(a))
char *fs,*ft,buf[1<<20];
#define gc() (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<20,stdin),fs==ft))?0:*fs++;
inline int read()
{
    
    
    int x=0,f=1;
    char ch=gc();
    while(ch<'0'||ch>'9')
    {
    
    
        if(ch=='-')
            f=-1;
        ch=gc();
    }
    while(ch>='0'&&ch<='9')
    {
    
    
        x=x*10+ch-'0';
        ch=gc();
    }
    return x*f;
}
using namespace std;
const int N=5e5+10;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
const double eps=1e-7;

vector<int>e[N];
int a[N],f[N],siz[N],temp[N];

bool cmp(int x,int y)
{
    
    
    return siz[x]-f[x]<siz[y]-f[y];
}

void dfs(int fa,int x)
{
    
    
    if(x!=1)
        f[x] = a[x];
    for(auto i:e[x])
        if(i!=fa)
            dfs(x,i);
    int tot=0;
    for(auto i:e[x])
        if(i!=fa)
            temp[++tot] = i;
    sort(temp+1,temp+tot+1,cmp);
    for(int i=1;i<=tot;i++)
        f[x]=max(f[x],f[temp[i]]+siz[x]+1),siz[x]+=siz[temp[i]]+2;
    if(x==1)
        f[x] = max(f[x],siz[x]+a[x]);
}

void solve()
{
    
    
    int n;
    cin>>n;
    for(int i=1;i<=n;i++)
        cin>>a[i];
    for(int i=1;i<=n-1;i++)
    {
    
    
        int u,v;
        cin>>u>>v;
        e[u].pb(v);
        e[v].pb(u);
    }
    dfs(1,1);
    cout<<f[1]<<endl;
}

signed main()
{
    
    
    solve();

    return 0;
}

猜你喜欢

转载自blog.csdn.net/Joker_He/article/details/112974789