csuoj 2289 Colorful Tree(lca+dfs序)

csuoj 2289 Colorful Tree

题目大意

给出一个树,每个节点都有颜色.每次给出一个操作,一为求某种颜色的最小生成树的边的数量,一为将某个节点的颜色改变

解题思路

对每种颜色都先维护出其生成树的大小,及这个颜色的点的集合.每次更改一个点的颜色都是将其原属于的生成树减去这个节点,为新的颜色的生成树增加这个节点.点到生成树的最近距离即与其最dfs序相邻的节点的距离.

AC代码

#include<bits/stdc++.h>
using namespace std;
#define int long long 
const int maxn=2e5+5;
int tot=1;
int id[maxn]; 
vector<int> G[maxn];
int dis[maxn],dep[maxn],pa[maxn][20];
int col[maxn];
int fds[maxn];
set<int> ss[maxn];
int pas[maxn];
int ans[maxn];
void dfs(int u,int F,int d)
{
	dep[u]=d;
	if(u==1) 
	{
		for(int i=0;i<20;i++) pa[u][i]=1;
	}
	else
	{
		pa[u][0]=F;
		for(int i=1;i<20;i++)
		{
			pa[u][i]=pa[pa[u][i-1]][i-1];
		}
	}
	for(auto x:G[u])
	{
		if(x==F) continue;
		dfs(x,u,d+1);
	}
}
int Jump(int u,int d)
{
	for(int j=19;j>=0;j--)
	{
		if((1<<j)&d)
		{
			u=pa[u][j];
		}
	}
	return u;
}
int lca(int u,int v)
{
	if(dep[u]<dep[v]) swap(u,v);
	u=Jump(u,dep[u]-dep[v]);
	if(u==v) return u;
	for(int i=19;i>=0;i--)
	{
		if(pa[u][i]!=pa[v][i])
		{
			u=pa[u][i];
			v=pa[v][i];
		}
	}
	return pa[u][0];
}
int diss(int x,int y)
{
	return dep[x]+dep[y]-2*dep[lca(x,y)];
}
void dfs0(int v)
{
	fds[tot]=v;
	id[v]=tot++;
	for(auto x:G[v])
	{
		if(id[x]!=-1) continue;
		dfs0(x);
	}
}	

void add(int v,int c)
{
	
	if(ss[c].size()){
		auto l=ss[c].upper_bound(id[v]);
		auto r=l;
		if(r==ss[c].end()) l=ss[c].begin(),--r;
		else if(l!=ss[c].begin()) --l;
		else r=--ss[c].end();
		int u1=fds[*l],u2=fds[*r];
		ans[c]+=(diss(u1,v)+diss(u2,v)-diss(u1,u2))/2;
//		cout<<ans[col[v]]<<endl;
	}ss[c].insert(id[v]);
	
}
void del(int v,int c)
{	
	ss[c].erase(ss[c].lower_bound(id[v]));
	if(ss[col[v]].size()){
		auto l=ss[c].upper_bound(id[v]);
		auto r=l;
		if(r==ss[c].end()) l=ss[c].begin(),--r;
		else if(l!=ss[c].begin()) --l;
		else r=--ss[c].end();
		int u1=fds[*l],u2=fds[*r];
		ans[c]-=(diss(u1,v)+diss(u2,v)-diss(u1,u2))/2;
	}
	
}
void update(int v,int c)
{
	del(v,col[v]);
	col[v]=c;
	add(v,c);
	
}	
int32_t main()
{
	memset(id,-1,sizeof(id));
	memset(ans,0,sizeof(ans));
	int n;
	scanf("%lld",&n);
	int u,v;
	for(int i=1;i<n;i++)
	{
		scanf("%lld%lld",&u,&v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	dfs0(1);
	dfs(1,1,0);
	for(int i=1;i<=n;i++)
	{
		scanf("%lld",&col[i]);
		add(i,col[i]);
	}	
	
	int m;
	scanf("%lld",&m);
	char s[5];
	int x,c;
	while(m--)
	{
		scanf("%s",&s);
		if(s[0]=='U')
		{
			scanf("%lld%lld",&x,&c);
			update(x,c);
		}
		if(s[0]=='Q')
		{ 
			scanf("%lld",&c);
			if(ss[c].size()==0) printf("-1\n");
			else printf("%lld\n",ans[c]);
		}
	}
}
//

猜你喜欢

转载自blog.csdn.net/baiyifeifei/article/details/88921851