Annoying problem
Time Limit: 16000/8000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)
Total Submission(s): 1906 Accepted Submission(s): 632
Problem Description
Coco has a tree, whose nodes are conveniently labeled by 1,2,…,n, which has n-1 edge,each edge has a weight. An existing set S is initially empty.
Now there are two kinds of operation:
1 x: If the node x is not in the set S, add node x to the set S
2 x: If the node x is in the set S,delete node x from the set S
Now there is a annoying problem: In order to select a set of edges from tree after each operation which makes any two nodes in set S connected. What is the minimum of the sum of the selected edges’ weight ?
Input
one integer number T is described in the first line represents the group number of testcases.( T<=10 )
For each test:
The first line has 2 integer number n,q(0<n,q<=100000) describe the number of nodes and the number of operations.
The following n-1 lines each line has 3 integer number u,v,w describe that between node u and node v has an edge weight w.(1<=u,v<=n,1<=w<=100)
The following q lines each line has 2 integer number x,y describe one operation.(x=1 or 2,1<=y<=n)
Output
Each testcase outputs a line of "Case #x:" , x starts from 1.
The next q line represents the answer to each operation.
Sample Input
1 6 5 1 2 2 1 5 2 5 6 2 2 4 2 2 3 2 1 5 1 3 1 4 1 2 2 5
Sample Output
Case #1: 0 6 8 8 4
题意:给你一颗带权无向树,给你一个集合s,初始为空,有两种操作:
1. 1 x 表示把节点x加入到集合s中
2. 2 x 表示把节点x从集合s中删除
每一次操作后,你要求出集合s中构成树的最小边权和(边来自原树)。
思路:LCA+DFS序+思维。
考虑到s为空,然后:加一个点,结果为0,再加一个点,结果为两点间的距离,再加一个点,结果为上一个的结果+这个点到前两个点构成的链的最小距离。删除也是一样。
于是我们只需要求出当前要加入的点u到集合s中所有点构成的树的最短距离。
根据dfs序性质:1、比u编号大的节点不是在u的孩子中,就是在u的右侧。
2、比u编号小的节点不是在u的祖先中,就是在u的左侧。 即dfs序可以确定两点在书中的相对位置。
因此我们分多种情况讨论,最终可以得出u到集合s中所有点构成的树的最短距离,就等于u到集合中dfs序比他大的最小dfs序的点和dfs序比他小的最大dfs序的点。如果都比它小或者都比它大就取dfs序最小的和最大的点即可。考虑到插入删除以及排序,我们用set来存储集合s的所有节点的dfs序。
x到y的距离(即链xy的长度)=dis[x]+dis[y]-2*dis[lca(x,y)];
u到x的距离=dis[u]+dis[x]-2*dis[lca(u,x)];
u到y的距离=dis[u]+dis[y]-2*dis[lca(u,y)];
后两式相加减第一式即是u到链xy的最短距离的两倍。
因此u到链xy的最短距离=dis[u]-dis[lca(u,x)]-dis[lca(u,y)]+dis[lca(x,y)];
LCA模板不够熟练,有敲错过的地方标记出来,提醒自己。
代码:
#include<bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
using namespace std;
const int maxn=400010;
int n,m,q,tot,tot2,cnt,tmp,ans;
int f[maxn][20];
int head[maxn],dis[maxn];
int dfn[maxn],d[maxn];
int dep[maxn],e[maxn],pos[maxn];
bool vis[maxn];
struct node
{
int to,nex,w;
}a[maxn];
void add(int u,int v,int w)
{
a[cnt].to=v;
a[cnt].w=w;
a[cnt].nex=head[u];
head[u]=cnt++;
}
void init()
{
cnt=tot=tot2=0;
memset(head,-1,sizeof(head));
memset(pos,-1,sizeof(pos));
memset(vis,0,sizeof(vis));
dis[1]=0;
}
void dfs(int u,int deep)//1
{
if(pos[u]!=-1)return;
dfn[tot2]=u;d[u]=tot2++;
pos[u]=tot;e[tot]=u;dep[tot++]=deep;
for(int i=head[u];i!=-1;i=a[i].nex)
{
int v=a[i].to;
if(pos[v]==-1)
{
dis[v]=dis[u]+a[i].w;
dfs(v,deep+1);
e[tot]=u;dep[tot++]=deep;
}
}
}
void rmq(int n)//2
{
for(int i=1;i<=n;i++)f[i][0]=i;
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i+(1<<j)-1<=n;i++)
{
if(dep[f[i][j-1]]<dep[f[i+(1<<(j-1))][j-1]]) f[i][j]=f[i][j-1];
else f[i][j]=f[i+(1<<(j-1))][j-1];
}
}
int RMQ(int l,int r)
{
int k=(int)(log((double)(r-l+1))/log(2.0));
if(dep[f[l][k]]<dep[f[r-(1<<k)+1][k]]) return f[l][k];
else return f[r-(1<<k)+1][k];
}
int lca(int x,int y)//3
{
if(pos[x]<pos[y]) return e[RMQ(pos[x],pos[y])];
else return e[RMQ(pos[y],pos[x])];
}
set<int>st;
set<int>::iterator it;
int cal(int u)
{
if(st.empty())return 0;
int x,y;
it=st.lower_bound(d[u]);
if(it==st.begin()||it==st.end()){
x=dfn[*st.begin()];
y=dfn[*st.rbegin()];
//cout<<x<<y<<endl;
}
else
{
x=dfn[*it];
it--;
y=dfn[*it];
}
return dis[u]-dis[lca(u,x)]-dis[lca(u,y)]+dis[lca(x,y)];
}
int main()
{
int T,cas=1;
scanf("%d",&T);
while(T--)
{
init();
st.clear();//......
scanf("%d %d",&n,&q);
for(int i=0;i<n-1;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
printf("Case #%d:\n",cas++);
dfs(1,0);
rmq(2*n-1);//4
ans=0;
while(q--)
{
int id,u;
scanf("%d %d",&id,&u);
if(id==1&&!vis[u])
{
vis[u]=1;
ans+=cal(u);
st.insert(d[u]);
}
else if(id==2&&vis[u])
{
vis[u]=0;
st.erase(d[u]);
ans-=cal(u);
}
printf("%d\n",ans);
}
}
return 0;
}