很明显的线段树+树链剖分,心态炸了调了1.5h,最后发现竟然是change(k<<1,tl,tr,x);写成change(k<<1,mid+1,tr,x);无语了。
#include<bits/stdc++.h> using namespace std; const int maxn=1e6+10; struct edge{ int to,next; }e[maxn*2]; struct list{ int l,r,sum,lsum,rsum; }tree[maxn*4]; int dep[maxn],top[maxn],size[maxn],pos[maxn],f[maxn],lazy[maxn*4],sz=0,cnt=0,n,m,a[maxn],root=1,head[maxn]; void add_edge(int s,int t){ e[++cnt].next=head[s];e[cnt].to=t;head[s]=cnt; e[++cnt].next=head[t];e[cnt].to=s;head[t]=cnt; } void pushup(int k){ tree[k].lsum=tree[k<<1].lsum;tree[k].rsum=tree[k<<1|1].rsum; int t=tree[k<<1].sum+tree[k<<1|1].sum; if(tree[k<<1].rsum==tree[k<<1|1].lsum)t--; tree[k].sum=t; } void pushlazy(int k){ if(lazy[k]){ tree[k<<1].sum=tree[k<<1|1].sum=1; tree[k<<1].rsum=tree[k<<1|1].lsum=tree[k<<1].lsum=tree[k<<1|1].rsum=lazy[k]; lazy[k<<1]=lazy[k<<1|1]=lazy[k]; lazy[k]=0; } } void build(int k,int l,int r){ tree[k].l=l;tree[k].r=r; if(l==r)return ; int mid=(l+r)>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); } void init(){ scanf("%d%d",&n,&m); build(1,1,n); for(int i=1;i<=n;++i)scanf("%d",&a[i]); for(int i=1;i<=n-1;++i){ int s,t; scanf("%d%d",&s,&t); add_edge(s,t); } } void dfs1(int u){ size[u]=1; for(int i=head[u];i;i=e[i].next){ if(f[u]!=e[i].to){ f[e[i].to]=u; dep[e[i].to]=dep[u]+1; dfs1(e[i].to); size[u]+=size[e[i].to]; } } } void change(int k,int tl,int tr,int x){ pushlazy(k); int l=tree[k].l,r=tree[k].r; if(tl<=l&&tr>=r){ tree[k].sum=1; tree[k].lsum=tree[k].rsum=x; lazy[k]=x; return ; } int mid=(l+r)>>1; if(mid>=tl)change(k<<1,tl,tr,x); if(mid<tr)change(k<<1|1,tl,tr,x); pushup(k); } int query(int k,int tl,int tr){ pushlazy(k); int l=tree[k].l,r=tree[k].r; if(tl<=l&&tr>=r){ return tree[k].sum; } int mid=(l+r)>>1,ret=0; if(mid>=tl)ret+=query(k<<1,tl,tr); if(mid<tr)ret+=query(k<<1|1,tl,tr); if(mid>=tl&&mid<tr&&tree[k<<1].rsum==tree[k<<1|1].lsum)ret--; return ret; } int get_c(int k,int p){ pushlazy(k); int l=tree[k].l,r=tree[k].r; if(l==r){ return tree[k].lsum; } int mid=(l+r)>>1; if(mid>=p)return get_c(k<<1,p); else return get_c(k<<1|1,p); } void dfs2(int u,int ance){ top[u]=ance;++sz; pos[u]=sz;int k=0; for(int i=head[u];i;i=e[i].next){ if(f[u]!=e[i].to&&size[e[i].to]>size[k]){ k=e[i].to; } } if(k==0)return ; dfs2(k,ance); for(int i=head[u];i;i=e[i].next){ if(f[u]!=e[i].to&&e[i].to!=k){ dfs2(e[i].to,e[i].to); } } } void solvesum(int a,int b,int x){ while(top[a]!=top[b]){ if(dep[top[a]]<dep[top[b]])swap(a,b); change(1,pos[top[a]],pos[a],x); a=f[top[a]]; } if(dep[a]>dep[b])swap(a,b); change(1,pos[a],pos[b],x); } int get_sum(int a,int b){ int ret=0; while(top[a]!=top[b]){ if(dep[top[a]]<dep[top[b]])swap(a,b); ret+=query(1,pos[top[a]],pos[a]); if(get_c(1,pos[top[a]])==get_c(1,pos[f[top[a]]]))--ret; a=f[top[a]]; } if(dep[a]>dep[b])swap(a,b); ret+=query(1,pos[a],pos[b]); return ret; } void solve(){ for(int i=1;i<=n;++i){ change(1,pos[i],pos[i],a[i]); } for(int i=1;i<=m;++i){ char opt[2]; int a,b; scanf("%s%d%d",opt,&a,&b); if(opt[0]=='C'){ int c; scanf("%d",&c); solvesum(a,b,c); } else{ printf("%d\n",get_sum(a,b)); } } } int main(){ init(); dfs1(root); dfs2(root,root); solve(); }