原题见洛谷。
分析
对线段树中每个节点,维护左儿子lc,右儿子rc,左端点值l_color,右端点值r_color,不同颜色段计数cnt。
线段树的基本操作需要注意的:
1,在上传操作时,如果左儿子的右端点颜色和右儿子的左端点颜色一致,cnt需要-1。
2,下传时,左右儿子的cnt改为1。
3,查询时,处理方法同1。
树剖操作需要注意的:
1,更新正常操作。
2,查询时,因为首先是一跳一条链,而两条链交接处可能颜色相同,所以要记录当前链的端点颜色和上一条链的端点颜色。
具体来说,由于是往上跳,当前链右端点与上一条链左端点相接,所以维护这两个即可。由于上跳时是交替进行,所以要维护两个上下链。
此外,当a,b跳到同一条链上,深度小的对应当前链的左端,深度大的对应右端,两端都要考虑是否颜色与上一条链端点颜色相同(它们分别从哪条链跳上来)。
一旦有一次端点颜色相同,答案减一。
#include<bits/stdc++.h>
using namespace std;
const int MAXN=100005;
int lc[MAXN*2],rc[MAXN*2],l_color[MAXN*2],r_color[MAXN*2],change[MAXN*2],cnt[MAXN*2];
int last[MAXN],fa[MAXN],dep[MAXN],top_node[MAXN],big_son[MAXN],size[MAXN],dfs_pos[MAXN],a[MAXN],b[MAXN];
int N,M,np=0,np2=0,rt=0,dfs_clock=0;
struct edge{int to,pre;}E[MAXN*2];
char c;
void scan(int &x)
{
for(c=getchar();c<'0'||c>'9';c=getchar());
for(x=0;c>='0'&&c<='9';c=getchar()) x=x*10+c-'0';
}
char num[20];int ct;
void print(int x)
{
ct=0;
if(!x) num[ct++]='0';
while(x) num[ct++]=x%10+'0',x/=10;
while(ct--) putchar(num[ct]);
putchar('\n');
}
//------------
void addedge(int u,int v)
{
E[++np2]=(edge){v,last[u]};
last[u]=np2;
}
void DFS1(int i,int f,int d)
{
dep[i]=d; fa[i]=f; size[i]=1;
for(int p=last[i];p;p=E[p].pre)
{
int j=E[p].to;
if(j==f) continue;
DFS1(j,i,d+1); size[i]+=size[j];
if(size[j]>size[big_son[i]]) big_son[i]=j;
}
}
void DFS2(int i,int top)
{
dfs_pos[i]=++dfs_clock;
top_node[i]=top; b[dfs_clock]=a[i];
if(!big_son[i]) return; DFS2(big_son[i],top);
for(int p=last[i];p;p=E[p].pre)
{
int j=E[p].to;
if(j==fa[i]||j==big_son[i]) continue;
DFS2(j,j);
}
}
//-----------------
void pushup(int now)
{
int l=lc[now],r=rc[now];
l_color[now]=l_color[l]; r_color[now]=r_color[r];
cnt[now]=cnt[l]+cnt[r];
if(r_color[l]==l_color[r]) cnt[now]-=1;
}
void build(int &now,int L,int R)
{
if(!now) now=++np;
if(L==R)
{
l_color[now]=r_color[now]=b[L];
cnt[now]=1; return;
}
int mid=(L+R)/2;
build(lc[now],L,mid);
build(rc[now],mid+1,R);
pushup(now);
}
void pushdown(int now)
{
if(!change[now]) return;
int l=lc[now],r=rc[now];
cnt[l]=cnt[r]=1;
change[l]=change[r]=change[now];
l_color[l]=r_color[l]=l_color[r]=r_color[r]=change[now];
change[now]=0;
}
void update(int now,int L,int R,int i,int j,int c)
{
if(i<=L&&R<=j)
{
change[now]=l_color[now]=r_color[now]=c;
cnt[now]=1; return;
}
pushdown(now); int mid=(L+R)/2;
if(i<=mid) update(lc[now],L,mid,i,j,c);
if(mid<j) update(rc[now],mid+1,R,i,j,c);
pushup(now);
}
void uprange(int u,int v,int color)
{
while(top_node[u]!=top_node[v])
{
if(dep[top_node[u]]<dep[top_node[v]]) swap(u,v);
update(rt,1,N,dfs_pos[top_node[u]],dfs_pos[u],color);
u=fa[top_node[u]];
}
if(dep[u]>dep[v]) swap(u,v);
update(rt,1,N,dfs_pos[u],dfs_pos[v],color);
}
int now_lc,now_rc;
int ques(int now,int L,int R,int i,int j)
{
if(i==L) now_lc=l_color[now];
if(R==j) now_rc=r_color[now];
if(i<=L&&R<=j) return cnt[now];
pushdown(now); int mid=(L+R)/2;
if(i<=mid&&mid<j)
{
int ans=ques(lc[now],L,mid,i,j)+ques(rc[now],mid+1,R,i,j);
if(r_color[lc[now]]==l_color[rc[now]]) ans--;
return ans;
}
if(i<=mid) return ques(lc[now],L,mid,i,j);
if(mid<j) return ques(rc[now],mid+1,R,i,j);
}
int qrange(int u,int v)
{
int ans=0,pre_u=-1,pre_v=-1; //u点上一条链的左端点颜色。v同理
while(top_node[u]!=top_node[v])
{
if(dep[top_node[u]]<dep[top_node[v]]) {swap(u,v);swap(pre_u,pre_v);}
ans+=ques(rt,1,N,dfs_pos[top_node[u]],dfs_pos[u]);
if(pre_u==now_rc) ans--;
pre_u=now_lc; u=fa[top_node[u]];
}
if(dep[u]>dep[v]) {swap(u,v);swap(pre_u,pre_v);}
ans+=ques(rt,1,N,dfs_pos[u],dfs_pos[v]);
if(now_rc==pre_v) ans--;
if(now_lc==pre_u) ans--;
return ans;
}
int main()
{
int i,u,v,w; char op;
scan(N);scan(M);
for(i=1;i<=N;i++) scan(a[i]);
for(i=1;i<N;i++)
{
scan(u);scan(v);
addedge(u,v);
addedge(v,u);
}
DFS1(1,0,1); DFS2(1,1); build(rt,1,N);
while(M--)
{
scanf("%s",&op);
if(op=='C')
{
scan(u);scan(v);scan(w);
uprange(u,v,w);
}
else
{
scan(u);scan(v);
print(qrange(u,v));
}
}
return 0;
}