Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input & Output
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面n - 1行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面m行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample
Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Output
3
1
2
Solution
维护查询树上路径,可以想到树链剖分,不过在维护答案的时候需要费一点是,如何使线段树满足区间加和性质呢?我们可以记录每个区间最左的颜色和最右的颜色,合并时,如果左儿子的右端点和右儿子的左端点相同,那么ans[x] = ans[lc] + ans[rc] - 1,否则就不减1,因为如果相同则一定会少一个颜色段。这一点在pushup和query的时候都要得到体现。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using std :: min;
using std :: max;
using std :: swap;
using std :: cin;
using std :: cout;
using std :: endl;
using std :: ios;
using std :: memset;
const int maxn = 100005;
struct node
{
int tot,lst,rst,tag;
}t[maxn << 2];
struct edge
{
int to,nxt;
}e[maxn << 1];
int n,m,c[maxn],f[maxn],dfn[maxn],son[maxn],top[maxn],size[maxn],dep[maxn];
int lnk[maxn],edgenum,u,v,cnt;
int w[maxn];
void add(int bgn,int end)
{
e[++edgenum].to = end;
e[edgenum].nxt = lnk[bgn];
lnk[bgn] = edgenum;
}
void dfs(int x,int fa,int d)
{
size[x] = 1;
f[x] = fa;
dep[x] = d;
for(int p = lnk[x]; p; p = e[p].nxt)
{
int y = e[p].to;
if(y == fa)continue;
dfs(y, x, d + 1);
size[x] += size[y];
if(size[y] > size[son[x]]) son[x] = y;
}
}
void dfs2(int x,int init)
{
dfn[x] = ++cnt;
w[cnt] = c[x];
top[x] = init;
if(!son[x])return;
dfs2(son[x],init);
for(int p = lnk[x]; p; p = e[p].nxt)
{
int y = e[p].to;
if(y == f[x]||y == son[x])continue;
dfs2(y,y);
}
}
void pushdown(int cur)
{
t[cur << 1].tot = t[cur << 1|1].tot = 1;
t[cur << 1].tag = t[cur << 1|1].tag = t[cur].tag;
t[cur << 1].lst = t[cur << 1].rst = t[cur].tag;
t[cur << 1|1].lst = t[cur << 1|1].rst = t[cur].tag;
t[cur].tag = 0;
}
void pushup(int cur)
{
t[cur].tot = t[cur << 1].tot + t[cur << 1|1].tot;
t[cur].lst = t[cur << 1].lst;
t[cur].rst = t[cur << 1|1].rst;
if(t[cur << 1].rst == t[cur << 1|1].lst) t[cur].tot--; //这里
}
void build(int cur,int l,int r)
{
if(l == r)
{
t[cur].tot = 1;
t[cur].lst = t[cur].rst = w[l];
t[cur].tag = 0;
return;
}
int mid = (l + r) >> 1;
build(cur << 1, l, mid);
build(cur << 1|1, mid+1, r);
pushup(cur);
}
int query(int cur,int l,int r,int L,int R)
{
int res = 0;
if(L <= l && r <= R) return t[cur].tot;
if(t[cur].tag) pushdown(cur);
int mid = (l + r) >> 1;
if(L <= mid) res += query(cur << 1, l, mid, L, R);
if(R > mid) res += query(cur << 1|1, mid+1, r, L, R);
if(mid >= L && mid < R && t[cur << 1].rst == t[cur << 1|1].lst) res--; //这里
return res;
}
void update(int cur,int l,int r,int L,int R,int x)
{
if(L <= l && r <= R)
{
t[cur].tot = 1;
t[cur].lst = t[cur].rst = x;
t[cur].tag = x;
return;
}
if(t[cur].tag) pushdown(cur);
int mid = (l + r) >> 1;
if(L <= mid)update(cur<<1,l,mid,L,R,x);
if(R > mid)update(cur<<1|1,mid+1,r,L,R,x);
pushup(cur);
}
int check(int cur,int l,int r,int pos)
{
if(l == r)return t[cur].lst;
if(t[cur].tag) pushdown(cur);
int mid = (l + r) >> 1;
if(pos > mid) return check(cur << 1|1, mid+1, r, pos);
else return check(cur << 1,l,mid,pos);
}
int queryt(int x,int y)
{
int ans = 0, b1,b2;
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]])swap(x,y);
ans += query(1,1,n,dfn[top[x]],dfn[x]);
b1 = check(1,1,n,dfn[top[x]]);
b2 = check(1,1,n,dfn[f[top[x]]]); //和这里
if(b1 == b2) ans--;
x = f[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans += query(1,1,n,dfn[x],dfn[y]);
return ans;
}
void updt(int x,int y,int c)
{
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]])swap(x,y);
update(1,1,n,dfn[top[x]],dfn[x],c);
x = f[top[x]];
}
if(dep[x] > dep[y]) swap(x,y);
update(1,1,n,dfn[x],dfn[y],c);
}
int main()
{
ios :: sync_with_stdio(false);
char opt;
int x,y,z;
cin >> n >> m;
for(int i = 1; i <= n; ++i)
cin >> c[i];
for(int i = 1; i < n; ++i)
{
cin >> u >> v;
add(u,v);
add(v,u);
}
dfs(1,0,1);
dfs2(1,1);
build(1,1,n);
for(int i = 1; i <= m; ++i)
{
cin >> opt;
if(opt == 'Q')
{
cin >> x >> y;
cout << queryt(x,y) << endl;
}
else
{
cin >> x >> y >> z;
updt(x, y, z);
}
}
return 0;
}