问题描述
输入格式
第一行包含两个整数 N,Q,表示星球的数量和操作的数量。星球从 1 开始编号。
接下来的 Q 行,每行是如下两种格式之一:
A x y 表示在 x和 y之间连一条边。保证之前 x 和 y 是不联通的。
Q x y 表示询问 (x,y)这条边上的负载。保证 x和 y之间有一条边。
输出格式
对每个查询操作,输出被查询的边的负载。
样例输入1
8 6
A 2 3
A 3 4
A 3 8
A 8 7
A 6 5
Q 3 8
样例输出1
6
样例输入2
10 20
A 4 10
A 1 6
A 10 1
Q 4 10
A 9 8
A 7 9
Q 10 1
A 2 5
Q 9 8
Q 2 5
Q 4 10
Q 9 8
Q 2 5
A 8 3
Q 7 9
Q 2 5
A 5 7
Q 8 3
A 3 4
Q 7 9
样例输出2
3
4
2
1
3
2
1
3
1
5
21
数据范围
对于所有数据,1≤N,Q≤100000
题解
显然,根据乘法原理,一条边的“负载”等于边两端的连通点数的乘积。如样例一,【3,8】中,与3号点连通的点(在不走【3,8】能到达的点,包括它本身)有3个,与8号点连通的点有2个,所以“负载”是6 。
所以,求“负载”的时候,我们可以先把一端点x弄成根,设以x为根的子树节点数为Size[x],以y为根的子树节点数为Size[y],答案即为Size[x]*(Size[x]-Size[y])
维护动态树,LCT吧。。。
但是注意!当在维护Splay的时候,不要只记录Splay树中的节点个数。我是直接开了个v[]数组,记录以i 点为根且到i 点路径不经过i 点的preferred son的点数。但这个维护起来又有点恶心。。。
还有,1e5
显然是会爆int 的,注意使用long long。
代码
#include <cstdio>
#include <iostream>
#include <ctime>
#include <stack>
#include <cstdlib>
#include <algorithm>
#define ll long long
using namespace std;
const ll Q=100005;
ll ls[Q],rs[Q],si[Q],n,f[Q],an[Q],v[Q],lazy[Q];
void lx(ll x)
{
ll y=f[x],z=f[y];
if(z)if(ls[z]==y)ls[z]=x;
else rs[z]=x;
f[x]=z;
swap(an[x],an[y]);
rs[y]=ls[x];
f[rs[y]]=y;
f[y]=x;
ls[x]=y;
si[x]=si[y];
si[y]=si[ls[y]]+si[rs[y]]+v[y];
}
void rx(ll x)
{
ll y=f[x],z=f[y];
if(z)if(ls[z]==y)ls[z]=x;
else rs[z]=x;
f[x]=z;
swap(an[x],an[y]);
ls[y]=rs[x];
f[ls[y]]=y;
f[y]=x;
rs[x]=y;
si[x]=si[y];
si[y]=si[ls[y]]+si[rs[y]]+v[y];
}
void pd(ll x)
{
swap(ls[x],rs[x]);
if(ls[x])lazy[ls[x]]^=1;
if(rs[x])lazy[rs[x]]^=1;
lazy[x]=0;
}
int ding=0,st[Q];
void splay(ll x)
{
for(ll now=x;now;now=f[now])st[++ding]=now;
while(ding)
{
if(lazy[st[ding]])pd(st[ding]);
--ding;
}
while(f[x])
{
ll y=f[x],z=f[y];
if(z)if(ls[z]==y)
if(ls[y]==x)rx(y),rx(x);
else lx(x),rx(x);
else if(rs[y]==x)lx(y),lx(x);
else rx(x),lx(x);
else if(ls[y]==x)rx(x);
else lx(x);
}
}
void ac(ll x)
{
ll y=0;
while(x)
{
splay(x);
if(rs[x])
{
v[x]+=si[rs[x]];
f[rs[x]]=0;
an[rs[x]]=x;
}
rs[x]=y;
v[x]-=si[y];
if(y)f[y]=x;
y=x;
x=an[x];
}
}
void mr(ll x)
{
ac(x);
splay(x);
lazy[x]^=1;
}
int main()
{
char o[15];
ll i,x,y;
scanf("%lld%lld",&n,&i);
for(x=1;x<=n;x++)
si[x]=v[x]=1;
while(i--)
{
scanf("%s%lld%lld",o,&x,&y);
if(o[0]=='A'){
mr(x);
an[x]=y;
ll temp=si[x];
while(an[x])
{
splay(an[x]);
si[an[x]]+=temp;
v[an[x]]+=temp;
x=an[x];
}
ac(y);
}
else{
mr(x);
ac(y);
splay(y);
printf("%lld\n",si[x]*(si[y]-si[x]));
}
}
return 0;
}