模板:(树剖\(LCA\)+建虚树)
#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int n,m,dp[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;
struct node{
int to,next;
}e[maxn<<1];
inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
}
inline void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
inline void addedge(int x,int y){
to[++cnt]=y;
nxt[cnt]=fir[x];
fir[x]=cnt;
}
void dfs1(int x,int f){
siz[x]=1;fa[x]=f;
dep[x]=dep[f]+1;
int maxson=-1;
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==f) continue;
dfs1(y,x);
siz[x]+=siz[y];
if(maxson<siz[y]){
maxson=siz[y];
son[x]=y;
}
}
}
void dfs2(int x,int topf){
id[x]=++tim;
top[x]=topf;
if(son[x]) dfs2(son[x],topf);
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
bool cmp(int a,int b){
return id[a]<id[b];
}
int main()
{
n=read();
int x,y,w,k,lca;
for(int i=1;i<n;i++){
x=read(),y=read();
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);
m=read();
for(int t=1;t<=m;t++){
k=read();
for(int i=1;i<=k;i++){
h[i]=read();
vis[h[i]]=1;
}
sort(h+1,h+k+1,cmp);
cnt=0;sta[Top=1]=1;
for(int i=1;i<=k;i++){
lca=LCA(sta[Top],h[i]);
while(dep[lca]<dep[sta[Top]]){
if(dep[sta[Top-1]]<=dep[lca]){
addedge(lca,sta[Top--]);
if(sta[Top]!=lca) sta[++Top]=lca;
break;
}
addedge(sta[Top-1],sta[Top]);
Top--;
}
if(sta[Top]!=h[i]) sta[++Top]=h[i];
}
while(--Top) addedge(sta[Top],sta[Top+1]);
for(int i=1;i<=k;i++) vis[h[i]]=0;
}
return 0;
}
具体建虚树怎么建可以看别人的博客……我讲的肯定没有它们好
1、[SDOI2011]消耗战
分析:人生第一道虚树题。
难的就是建一棵虚树,然后在虚树上树形 \(dp\)
首先,打出一个树上前缀最小值。因为无论怎样,选一条最小的边断掉一定是最优的。建一棵虚树,若遍历到选定的点 \(x\),那么 \(dp[x]=min(dis[x],\sum_{son\in x}val_{x->son})\),其中 \(val\) 为边权。
\(Code\ Below:\)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=250000+10;
const int inf=1e18;
int n,m,dp[maxn],dis[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;
struct node{
int to,next,val;
}e[maxn<<1];
inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
}
inline void add(int x,int y,int w){
e[++tot].to=y;
e[tot].val=w;
e[tot].next=head[x];
head[x]=tot;
}
inline void addedge(int x,int y){
to[++cnt]=y;
nxt[cnt]=fir[x];
fir[x]=cnt;
}
void dfs1(int x,int f){
siz[x]=1;fa[x]=f;
dep[x]=dep[f]+1;
int maxson=-1;
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==f) continue;
dis[y]=min(dis[x],e[i].val);
dfs1(y,x);
siz[x]+=siz[y];
if(maxson<siz[y]){
maxson=siz[y];
son[x]=y;
}
}
}
void dfs2(int x,int topf){
id[x]=++tim;
top[x]=topf;
if(son[x]) dfs2(son[x],topf);
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
bool cmp(int a,int b){
return id[a]<id[b];
}
void dfs(int x,int flag){
dp[x]=dis[x];
if(flag){
for(int i=fir[x];i;i=nxt[i])
dfs(to[i],flag);
fir[x]=vis[x]=0;
return ;
}
int val=0;
for(int i=fir[x],y;i;i=nxt[i]){
y=to[i];
dfs(y,vis[y]);
val+=dp[y];
}
if(!fir[x]||vis[x]) val=inf;
dp[x]=min(dp[x],val);
fir[x]=vis[x]=0;
}
signed main()
{
n=read();
int x,y,w,k,lca;
for(int i=1;i<n;i++){
x=read(),y=read(),w=read();
add(x,y,w);add(y,x,w);
}
dis[1]=inf;
dfs1(1,0);dfs2(1,1);
m=read();
for(int t=1;t<=m;t++){
k=read();
for(int i=1;i<=k;i++){
h[i]=read();
vis[h[i]]=1;
}
sort(h+1,h+k+1,cmp);
cnt=0;
sta[Top=1]=1;
for(int i=1;i<=k;i++){
lca=LCA(sta[Top],h[i]);
while(dep[lca]<dep[sta[Top]]){
if(dep[sta[Top-1]]<=dep[lca]){
addedge(lca,sta[Top]);
if(lca!=sta[--Top]) sta[++Top]=lca;
break;
}
addedge(sta[Top-1],sta[Top]);
Top--;
}
if(sta[Top]!=h[i]) sta[++Top]=h[i];
}
while(--Top) addedge(sta[Top],sta[Top+1]);
dfs(1,0);
printf("%lld\n",dp[1]);
}
return 0;
}
2、[HEOI2014]大工程
分析:这道题自己推的,很有成就感哈哈哈
方法与上题一样,不过多一点细节
\(sub[x]\) 表示在虚树上 \(x\) 的子树内有多少个选定点
这些点对 \((x,y)\) 对答案的贡献要分两类讨论:
1、\(x=lca(x,y)\) ,那么直接在 \(vis[x]=1\) 的时候算掉
2、\(x,y\) 在两棵不同的子树内,那就一边更新 \(sub[x]\) 一边算
void dfs(int x){
int now=0;
for(int i=fir[x];i;i=nxt[i]){
dfs(to[i]);
now+=sub[x]*sub[to[i]];
sub[x]+=sub[to[i]];
sub[to[i]]=0;
}
ans-=2*now*dep[x];
if(vis[x]) ans-=2*sub[x]*dep[x],sub[x]++;
}
找最小边权就是记录一下最小值和次小值,然后更新 \(ans\)
找最大边权同个道理
int dfs_min(int x){
int Min=inf,sec=inf;
for(int i=fir[x];i;i=nxt[i]){
sec=min(sec,dfs_min(to[i]));
if(Min>sec) swap(Min,sec);
}
if(vis[x]&&Min!=inf) ans=min(ans,Min-dep[x]);
if(Min!=inf&&sec!=inf) ans=min(ans,Min+sec-2*dep[x]);
if(vis[x]) Min=dep[x];
return Min;
}
int dfs_max(int x){
int Max=-inf,sec=-inf;
for(int i=fir[x];i;i=nxt[i]){
sec=max(sec,dfs_max(to[i]));
if(Max<sec) swap(Max,sec);
}
if(vis[x]&&Max!=-inf) ans=max(ans,Max-dep[x]);
if(Max!=-inf&&sec!=-inf) ans=max(ans,Max+sec-2*dep[x]);
if(vis[x]&&Max==-inf) Max=dep[x];
fir[x]=0;
return Max;
}
那个前式链向星数组 \(fir[x]\) 一定要在 \(dfsmax()\) 的时候清空!!!
\(Code\ Below:\)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=1000000+10;
const int inf=1e18;
int n,m,dp[maxn],vis[maxn],h[maxn],sub[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim,ans;
struct node{
int to,next;
}e[maxn<<1];
inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
}
inline void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
inline void addedge(int x,int y){
to[++cnt]=y;
nxt[cnt]=fir[x];
fir[x]=cnt;
}
void dfs1(int x,int f){
siz[x]=1;fa[x]=f;
dep[x]=dep[f]+1;
int maxson=-1;
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==f) continue;
dfs1(y,x);
siz[x]+=siz[y];
if(maxson<siz[y]){
maxson=siz[y];
son[x]=y;
}
}
}
void dfs2(int x,int topf){
id[x]=++tim;
top[x]=topf;
if(son[x]) dfs2(son[x],topf);
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
bool cmp(int a,int b){
return id[a]<id[b];
}
void dfs(int x){
int now=0;
for(int i=fir[x];i;i=nxt[i]){
dfs(to[i]);
now+=sub[x]*sub[to[i]];
sub[x]+=sub[to[i]];
sub[to[i]]=0;
}
ans-=2*now*dep[x];
if(vis[x]) ans-=2*sub[x]*dep[x],sub[x]++;
//printf("x=%lld,now=%lld,ans=%lld,sub[x]=%lld,dep[x]=%lld\n",x,now,ans,sub[x],dep[x]);
}
int dfs_min(int x){
int Min=inf,sec=inf;
for(int i=fir[x];i;i=nxt[i]){
sec=min(sec,dfs_min(to[i]));
if(Min>sec) swap(Min,sec);
}
if(vis[x]&&Min!=inf) ans=min(ans,Min-dep[x]);
if(Min!=inf&&sec!=inf) ans=min(ans,Min+sec-2*dep[x]);
if(vis[x]) Min=dep[x];
return Min;
}
int dfs_max(int x){
int Max=-inf,sec=-inf;
for(int i=fir[x];i;i=nxt[i]){
sec=max(sec,dfs_max(to[i]));
if(Max<sec) swap(Max,sec);
}
if(vis[x]&&Max!=-inf) ans=max(ans,Max-dep[x]);
if(Max!=-inf&&sec!=-inf) ans=max(ans,Max+sec-2*dep[x]);
if(vis[x]&&Max==-inf) Max=dep[x];
fir[x]=0;
return Max;
}
signed main()
{
n=read();
int x,y,w,k,lca;
for(int i=1;i<n;i++){
x=read(),y=read();
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);
m=read();
for(int t=1;t<=m;t++){
k=read();
for(int i=1;i<=k;i++){
h[i]=read();
vis[h[i]]=1;
}
if(k==1){
printf("0 0 0\n");
continue;
}
sort(h+1,h+k+1,cmp);
cnt=0;sta[Top=1]=1;
for(int i=1;i<=k;i++){
lca=LCA(sta[Top],h[i]);
while(dep[lca]<dep[sta[Top]]){
if(dep[sta[Top-1]]<=dep[lca]){
addedge(lca,sta[Top--]);
if(sta[Top]!=lca) sta[++Top]=lca;
break;
}
addedge(sta[Top-1],sta[Top]);
Top--;
}
if(sta[Top]!=h[i]) sta[++Top]=h[i];
}
while(--Top) addedge(sta[Top],sta[Top+1]);
ans=0;
for(int i=1;i<=k;i++) ans+=(k-1)*dep[h[i]];
dfs(1);sub[1]=0;
printf("%lld ",ans);
ans=inf;dfs_min(1);
printf("%lld ",ans);
ans=-inf;dfs_max(1);
printf("%lld\n",ans);
for(int i=1;i<=k;i++) vis[h[i]]=0;
}
return 0;
}
3、CF613D Kingdom and its Cities
分析:在虚树上树形 \(dp\) 的时候分三种情况:
\(P.S:sum\) 表示有多少个儿子已经选了
1、\(vis[x]=1\),那么不能选 \(x\) 来断掉儿子的退路,那么 \(dp[x]=\sum_{son\in x} dp[son]\)
2、\(vis[x]=0,sum>1\),那就直接选 \(x\),\(x\) 的子树已经被 \(x\) 封死了
3、\(vis[x]=0,sum\leq 1\),那就传到 \(x\) 的父亲上,让 \(x\) 的父亲解决好了
\(Code\ Below:\)
#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int n,m,dp[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;
struct node{
int to,next;
}e[maxn<<1];
inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
}
inline void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
inline void addedge(int x,int y){
to[++cnt]=y;
nxt[cnt]=fir[x];
fir[x]=cnt;
}
void dfs1(int x,int f){
siz[x]=1;fa[x]=f;
dep[x]=dep[f]+1;
int maxson=-1;
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==f) continue;
dfs1(y,x);
siz[x]+=siz[y];
if(maxson<siz[y]){
maxson=siz[y];
son[x]=y;
}
}
}
void dfs2(int x,int topf){
id[x]=++tim;
top[x]=topf;
if(son[x]) dfs2(son[x],topf);
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
bool cmp(int a,int b){
return id[a]<id[b];
}
int dfs(int x){
int ans=0,sum=0;
for(int i=fir[x];i;i=nxt[i])
ans+=dfs(to[i]),sum+=dp[to[i]];
if(vis[x]) dp[x]=1,ans+=sum;
else if(sum>1) dp[x]=0,ans++;
else dp[x]=sum;
fir[x]=0;
return ans;
}
int main()
{
n=read();
int x,y,w,k,lca,flag;
for(int i=1;i<n;i++){
x=read(),y=read();
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);
m=read();
for(int t=1;t<=m;t++){
k=read();
for(int i=1;i<=k;i++){
h[i]=read();
vis[h[i]]=1;
}
flag=0;
for(int i=1;i<=k;i++)
flag|=vis[fa[h[i]]];
if(flag){
printf("-1\n");
for(int i=1;i<=k;i++) vis[h[i]]=0;
continue;
}
sort(h+1,h+k+1,cmp);
cnt=0;sta[Top=1]=1;
for(int i=1;i<=k;i++){
lca=LCA(sta[Top],h[i]);
while(dep[lca]<dep[sta[Top]]){
if(dep[sta[Top-1]]<=dep[lca]){
addedge(lca,sta[Top--]);
if(sta[Top]!=lca) sta[++Top]=lca;
break;
}
addedge(sta[Top-1],sta[Top]);
Top--;
}
if(sta[Top]!=h[i]) sta[++Top]=h[i];
}
while(--Top) addedge(sta[Top],sta[Top+1]);
printf("%d\n",dfs(1));
for(int i=1;i<=k;i++) vis[h[i]]=0;
}
return 0;
}