[SDOI2011]消耗战(虚树优化树形dp)

[SDOI2011]消耗战

Solution

可以看出用树形dp,但传统方法复杂度为O(n*m),考虑优化

发现每一次树形dp中只有部分点有用,即资源丰富岛屿与它们之间的LCA

虚树构建方法

Code

#include <cstdio>
#include <cstdlib>
#include <vector>
#include <algorithm>
#define ll long long
using namespace std;
const int N=25e4+10;
int dfn[N],tot,si[N],top[N],son[N],dep[N],fa[N];
int n,m,st[N],u,v,ver[N*2],nxt[N*2],cnt,a[N],k,head[N];
ll w,edge[N*2],mi[N];
vector <int> link[N];
void add(int u,int v,ll w)
{
    ver[++cnt]=v,nxt[cnt]=head[u],edge[cnt]=w,head[u]=cnt;
}
void dfs1(int u,int fu)
{
    fa[u]=fu,si[u]=1;
    dep[u]=dep[fu]+1;
    for(int i=head[u],v;v=ver[i],i;i=nxt[i])
    {
        if(v==fu) continue;
        mi[v]=min(mi[u],edge[i]);
        dfs1(v,u);
        si[u]+=si[v];
        if(!son[u] || si[son[u]]<si[v]) son[u]=v;
    }
}
void dfs2(int u,int fu)
{
    dfn[u]=++tot;
    if(!son[u]) return;
    top[son[u]]=top[u],dfs2(son[u],u);
    for(int i=head[u],v;v=ver[i],i;i=nxt[i])
    {
        if(v==son[u] || v==fu) continue;
        top[v]=v,dfs2(v,u);
    }
}
int lca(int x,int y)
{
    int px=top[x],py=top[y];
    while(px!=py)
        if(dep[px]>=dep[py]) x=fa[px],px=top[x];
        else y=fa[py],py=top[y];
    return dep[x]<=dep[y]?x:y;
}
bool cmp(int a,int b)
{
    return dfn[a]<dfn[b];
}
void insert(int x)
{
    if(st[0]==1)
    {
        st[++st[0]]=x;
        return;
    }
    int lc=lca(x,st[st[0]]);
    if(lc==st[st[0]]) return;
    while(st[0]>1 && dfn[st[st[0]-1]]>=dfn[lc])
        link[st[st[0]-1]].push_back(st[st[0]]),st[0]--;
    if(lc!=st[st[0]]) link[lc].push_back(st[st[0]]),st[st[0]]=lc;
    st[++st[0]]=x;
}
ll dfs(int u)
{
    int size=link[u].size();
    if(!size) return mi[u];
    ll ans=0;
    for(int i=0;i<size;i++)
        ans+=dfs(link[u][i]);
    link[u].clear();
    return min(mi[u],ans);
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        scanf("%d%d%lld",&u,&v,&w);
        add(u,v,w),add(v,u,w);
    }
    mi[1]=1e16,dfs1(1,0);
    top[1]=1,dfs2(1,0);
    scanf("%d",&m);
    while(m--)
    {
        scanf("%d",&k);
        for(int i=1;i<=k;i++) scanf("%d",&a[i]);
        sort(a+1,a+1+k,cmp);
        st[++st[0]=1]=1;
        for(int i=1;i<=k;i++) insert(a[i]);
        while(st[0]>1) link[st[st[0]-1]].push_back(st[st[0]]),st[0]--;
        printf("%lld\n",dfs(1)); 
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/hsez-cyx/p/12453612.html