正题
快来树上涂颜色吧!
其实就是普通的树链剖分+较复杂的线段树维护。
求一段颜色有多少个分块用线段树维护到底应该怎么求。
我们一棵线段子树应该怎么从左儿子和右儿子的关系推出来?
当然很简单,每一棵线段子树可以记录下最左节点的颜色和最右节点的颜色。
那么当且仅当左边的最右节点和右边的最左节点相同,一个区间才会被计算两次。
所以,我们应该这么更新我们的线段树。
sum[now]=sum[ls[now]]+sum[rs[now]];//sum为颜色总段数 if(right[ls[now]]==left[rs[now]]) sum[now]--;//如果重复计算,sum-- left[now]=left[ls[now]];//更新最左节点的颜色 right[now]=right[rs[now]];//更新最右节点的颜色
我们还要考虑一个问题,就是,我们只能让一条重链上的编号连续,但是不能让一条路径上的编号连续,所以我们向上跑lca计算答案的时候还要计算。
lx来记录x那边的当前最上面的颜色,rx来记录y那边的当前最上面的颜色。
ly来记录当前重链的顶端,ry记录当前重链的底端。
那么同时可以用一个bool来记录当前处理的是x边还是y边(false为x边,true为y边)
最后还要分情况讨论即可。
#include<cstdio> #include<cstring> #include<cstdlib> #define mid (x+y)/2 #define swap(x,y) {int t=x;x=y;y=t;} int n,m; int d[100010]; int top[100010],tot[100010],fa[100010],dep[100010],image[100010],fact[100010]; struct edge{ int y,next; }s[200010]; int len=0; int first[100010]; int left[200010],right[200010],ls[200010],rs[200010],sum[200010],son[200010]; int lazy[200010]; int root=0; int v,op; int ly,ry; void pushdown(int x){ if(lazy[x]==0) return ; lazy[ls[x]]=lazy[rs[x]]=lazy[x]; left[ls[x]]=right[ls[x]]=left[rs[x]]=right[rs[x]]=lazy[x]; sum[ls[x]]=sum[rs[x]]=1; lazy[x]=0; } void ins(int x,int y){ len++; s[len].y=y;s[len].next=first[x];first[x]=len; } void dfs_1(int x){ tot[x]=1; for(int i=first[x];i!=0;i=s[i].next){ int y=s[i].y; if(y!=fa[x]){ dep[y]=dep[x]+1; fa[y]=x; dfs_1(y); if(tot[y]>tot[son[x]]) son[x]=y; tot[x]+=tot[y]; } } } void dfs_2(int x,int tp){ top[x]=tp;image[x]=++len;fact[len]=x; if(son[x]!=0) dfs_2(son[x],tp); for(int i=first[x];i!=0;i=s[i].next){ int y=s[i].y; if(y!=fa[x] && y!=son[x]) dfs_2(y,y); } } void update(int &now,int x,int y){ if(now==0) now=++len; if(x==y){ left[now]=right[now]=op; sum[now]=1;lazy[now]=0; return ; } if(v<=mid) update(ls[now],x,mid); else update(rs[now],mid+1,y); sum[now]=sum[ls[now]]+sum[rs[now]]; if(right[ls[now]]==left[rs[now]]) sum[now]--; left[now]=left[ls[now]]; right[now]=right[rs[now]]; } int find(int now,int x,int y){ if(x==y) return right[now]; pushdown(now); if(v<=mid) return find(ls[now],x,mid); else return find(rs[now],mid+1,y); } int query_sum(int now,int l,int r,int x,int y){ if(l==x && r==y) return sum[now]; pushdown(now); if(r<=mid) return query_sum(ls[now],l,r,x,mid); else if(mid<l) return query_sum(rs[now],l,r,mid+1,y); else{ int temp=query_sum(ls[now],l,mid,x,mid)+query_sum(rs[now],mid+1,r,mid+1,y); if(right[ls[now]]==left[rs[now]]) temp--; return temp; } } int solve(){ int x,y; scanf("%d %d",&x,&y); int tx=top[x],ty=top[y]; int lx=0,rx=0; bool tf=true;//true为右,false为左 int ans=0; while(tx!=ty){ if(dep[tx]>dep[ty]){ swap(tx,ty); swap(x,y); if(tf==true) tf=false; else tf=true;//换x,y是要改一下tf } v=image[ty];ly=find(root,1,n);v=image[y];ry=find(root,1,n); ans+=query_sum(root,image[ty],image[y],1,n); if(tf==false){//处理x边 if(lx==ry) ans--;//相同说明计算重复 lx=ly;//更新lx } else{//处理y边 if(rx==ry) ans--;//相同说明计算重复 rx=ly;//更新rx } y=fa[ty];ty=top[y];//往上跳 } if(dep[x]>dep[y]) { swap(x,y); tf^=true; }//最后再做一遍 ans+=query_sum(root,image[x],image[y],1,n); v=image[x];ly=find(root,1,n);v=image[y];ry=find(root,1,n); if(tf){//这里最重要,最后的链在y边说明lx接ly,ry接rx if(lx==ly) ans--; if(rx==ry) ans--; } else{//否则就是反过来 if(lx==ry) ans--; if(rx==ly) ans--; } return ans; } void change_(int now,int l,int r,int x,int y,int c){ if(l==x && r==y){ left[now]=right[now]=c; sum[now]=1;lazy[now]=c; return ; } pushdown(now); if(r<=mid) change_(ls[now],l,r,x,mid,c); else if(mid<l) change_(rs[now],l,r,mid+1,y,c); else {change_(ls[now],l,mid,x,mid,c);change_(rs[now],mid+1,r,mid+1,y,c);} sum[now]=sum[ls[now]]+sum[rs[now]]; if(right[ls[now]]==left[rs[now]]) sum[now]--; left[now]=left[ls[now]]; right[now]=right[rs[now]]; } void change(){ int x,y,c; scanf("%d %d %d",&x,&y,&c); int tx=top[x],ty=top[y]; while(tx!=ty){ if(dep[tx]>dep[ty]){ swap(tx,ty); swap(x,y); } change_(root,image[ty],image[y],1,n,c); y=fa[ty];ty=top[y]; } if(dep[x]>dep[y]) swap(x,y); change_(root,image[x],image[y],1,n,c); } int main(){ scanf("%d %d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&d[i]); for(int i=1;i<=n-1;i++){ int x,y; scanf("%d %d",&x,&y); ins(x,y);ins(y,x); } dep[1]=1;dfs_1(1); len=0;dfs_2(1,1); len=0; for(int i=1;i<=n;i++){ v=image[i];op=d[i]; update(root,1,n); } char ch[2]; while(m--){ scanf("%s",ch); if(ch[0]=='Q') printf("%d\n",solve()); else change(); } }