版权声明:本文为博主原创文章,可以转载但是必须声明版权。 https://blog.csdn.net/forever_shi/article/details/84192924
题意:
给你一棵以1为根的树,边有边权,有m次询问,每次询问选出k个点,问这k个与1号点都不连通要割断的最小边权和。
量级,
是
量级的。
题解:
一道虚树模板题。
这种树上选若干个点的题基本就是往虚树方面想了。我们预处理出每个点割断它到1号点的最小代价,然后每次询问建出虚树。如果父节点要被割断,那么就没有必要建它的子树中被选中的节点了,于是我们在虚树里只需要考虑把每个叶子割断的最小代价就好了。然后进行树形dp,
表示割断
所在的所有叶子花费的最小代价,有
。这样我们就能做到
的复杂度了。
注意开long long。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,hed[1800010],cnt,f[800010][21],dep[1800010],z;
int sta[1800010],dfn[1800010],k,b[1800010],tp;
long long mn[800010];
struct node
{
int to,next;
long long dis;
}a[1800010],aa[1800010];
vector<int> v[1800010];
inline void add(int from,int to,long long dis)
{
a[++cnt].to=to;
a[cnt].dis=dis;
a[cnt].next=hed[from];
hed[from]=cnt;
}
inline void dfs(int x)
{
dfn[x]=++z;
for(int i=1;i<=20;++i)
f[x][i]=f[f[x][i-1]][i-1];
for(int i=hed[x];i;i=a[i].next)
{
int y=a[i].to;
if(y==f[x][0])
continue;
f[y][0]=x;
dep[y]=dep[x]+1;
mn[y]=min(mn[x],a[i].dis);
dfs(y);
}
}
inline int cmp(int x,int y)
{
return dfn[x]<dfn[y];
}
inline int lca(int x,int y)
{
if(dep[x]<dep[y])
swap(x,y);
for(int i=20;i>=0;--i)
{
if(dep[x]-dep[y]>=(1<<i))
x=f[x][i];
}
if(x==y)
return x;
for(int i=20;i>=0;--i)
{
if(f[x][i]!=f[y][i]&&dep[x]>=(1<<i))
{
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
inline void add2(int from,int to)
{
v[from].push_back(to);
}
inline void insert(int x)
{
if(tp==1)
{
sta[++tp]=x;
return;
}
int z=lca(sta[tp],x);
if(z==sta[tp])
return;
while(dfn[sta[tp-1]]>=dfn[z]&&tp>1)
{
add2(sta[tp-1],sta[tp]);
--tp;
}
if(z!=sta[tp])
{
add2(z,sta[tp]);
sta[tp]=z;
}
sta[++tp]=x;
}
inline long long dfs2(int x)
{
int sz=v[x].size();
if(sz==0)
return mn[x];
long long res=0;
for(int i=0;i<sz;++i)
{
int y=v[x][i];
res+=dfs2(y);
}
v[x].clear();
return min(res,mn[x]);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n-1;++i)
{
int x,y;
long long z;
scanf("%d%d%lld",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
mn[1]=2e16;
dep[1]=1;
dfs(1);
scanf("%d",&m);
for(int i=1;i<=m;++i)
{
scanf("%d",&k);
for(int j=1;j<=k;++j)
scanf("%d",&b[j]);
sort(b+1,b+k+1,cmp);
sta[++tp]=1;
for(int j=1;j<=k;++j)
insert(b[j]);
while(tp>0)
{
add2(sta[tp-1],sta[tp]);
--tp;
}
printf("%lld\n",dfs2(1));
}
return 0;
}