题目大意:一棵树,3种操作:将结点a的值改为b,查询a到b路径间最大结点权值,查询a到b间结点权值总和
分析:树剖模板。。。将树链分段映射到线段树上统计维护最大值和总和。
代码
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<vector>
#include<math.h>
using namespace std;
vector<int>g[30010];
int A[30010],su[300010],ma[300010];
int fa[30010],siz[30010],h[30010],son[30010],cnt,top[30010],pos[30010];
char s[20];
void dfs1(int u){
int i,v;
siz[u]=1;
for(i=0;i<g[u].size();i++){
v=g[u][i];
if(v!=fa[u]){
h[v]=h[u]+1;
fa[v]=u;
dfs1(v);
siz[u]+=siz[v];
if(son[u]==-1||siz[son[u]]<siz[v]) son[u]=v;
}
}
}
void dfs2(int u,int f){
top[u]=f;
pos[u]=cnt++;
if(son[u]==-1) return;
dfs2(son[u],f);
int i,v;
for(i=0;i<g[u].size();i++){
v=g[u][i];
if(v!=fa[u]&&v!=son[u]){
dfs2(v,v);
}
}
}
void update(int root,int l,int r,int a,int b){
if(l==r){
ma[root]=b;
su[root]=b;
return;
}
int mid=(l+r)/2;
if(a<=mid) update(root*2,l,mid,a,b);
else update(root*2+1,mid+1,r,a,b);
su[root]=su[root*2]+su[root*2+1];
ma[root]=ma[root*2];
if(ma[root]<ma[root*2+1]) ma[root]=ma[root*2+1];
}
int qmax(int root,int l,int r,int a,int b){
if(l==a&&r==b) return ma[root];
int mid=(l+r)/2,m,x,y;
if(b<=mid) m=qmax(root*2,l,mid,a,b);
else if(a>mid) m=qmax(root*2+1,mid+1,r,a,b);
else{
x=qmax(root*2,l,mid,a,mid);
y=qmax(root*2+1,mid+1,r,mid+1,b);
if(x>y) m=x;
else m=y;
}
return m;
}
int qsum(int root,int l,int r,int a,int b){
if(l==a&&r==b) return su[root];
int mid=(l+r)/2,m;
if(b<=mid) m=qsum(root*2,l,mid,a,b);
else if(a>mid) m=qsum(root*2+1,mid+1,r,a,b);
else{
m=qsum(root*2,l,mid,a,mid)+qsum(root*2+1,mid+1,r,mid+1,b);
}
return m;
}
int fmax(int a,int b){
int m=-1000000,v;
while(top[a]!=top[b]){
if(h[top[a]]<h[top[b]]) swap(a,b);
v=qmax(1,0,30010,pos[top[a]],pos[a]);
if(v>m) m=v;
a=fa[top[a]];
}
if(h[a]>h[b]) swap(a,b);
v=qmax(1,0,30010,pos[a],pos[b]);
if(v>m) m=v;
return m;
}
int fsum(int a,int b){
int m=0;
while(top[a]!=top[b]){
if(h[top[a]]<h[top[b]]) swap(a,b);
m+=qsum(1,0,30010,pos[top[a]],pos[a]);
a=fa[top[a]];
}
if(h[a]>h[b]) swap(a,b);
m+=qsum(1,0,30010,pos[a],pos[b]);
return m;
}
int main(){
int i,n,a,b,q;
scanf("%d",&n);
memset(son,-1,sizeof(son));
cnt=h[1]=fa[1]=0;
for(i=1;i<n;i++){
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
dfs1(1);
dfs2(1,1);
for(i=1;i<=n;i++){
scanf("%d",&A[i]);
update(1,0,30010,pos[i],A[i]);
}
//cout<<ma[1]<<endl;
scanf("%d",&q);
while(q--){
scanf("%s%d%d",&s,&a,&b);
if(s[1]=='M') printf("%d\n",fmax(a,b));
else if(s[1]=='S') printf("%d\n",fsum(a,b));
else update(1,0,30010,pos[a],b);
}
return 0;
}