链接: https://www.nowcoder.com/acm/contest/180/E
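思路: 前两个操作(子树加、路径加)都是基本的区间加操作,难点在于第三个操作。可以发现第三个操作要求的是路径上所有点权两两乘积之和,即 $\sum_{1\le i<j\le n} a_i a_j = a_1(a_2+a_3+\cdots+a_n) + a_2(a_3+\cdots+a_n) + \cdots + a_{n-1}a_n$。
由 $(a_1+\cdots+a_n)^2 = \sum_{i} a_i^2 + 2\sum_{i<j} a_i a_j$,可将其转化为 $\dfrac{(a_1+a_2+\cdots+a_n)^2-(a_1^2+a_2^2+\cdots+a_n^2)}{2}$,在模意义下除以 2 即乘以 2 的逆元。
因此只需用线段树维护区间和与区间平方和即可;对长度为 $len$ 的区间整体加 $v$ 时,平方和按 $\sum (x+v)^2 = \sum x^2 + 2v\sum x + len\cdot v^2$ 更新。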
代码:
#include<bits/stdc++.h>
// Child indices of segment-tree node `i` (heap layout, root = 1).
// NOTE: these expand in terms of a variable literally named `i`, which must be in scope.
#define lson (i << 1)
#define rson (i << 1 | 1)
using namespace std;
using ll = long long;
const int N = 100005;            // capacity of per-vertex arrays (vertices are 1..n, n assumed < N)
const ll mod = 1000000007;       // all sums are kept modulo this prime
const ll inv2 = (mod + 1) / 2;   // modular inverse of 2, i.e. 500000004
// Tree adjacency stored as one singly linked list of edges per vertex.
struct eee
{
    int v;      // vertex this edge points to
    int next;   // index of the next edge leaving the same vertex, or -1
};
eee edge[N * 2];   // every undirected tree edge is stored in both directions
int tot;           // number of edges stored so far (edge indices start at 1)
int head[N];       // head[u]: first edge index leaving u, or -1 if none
// Segment-tree node over vertex positions in heavy-light order.
struct node
{
    int l;      // first position covered (inclusive)
    int r;      // last position covered (inclusive)
    ll sum1;    // sum of the covered values, taken modulo mod
    ll sum2;    // sum of the squares of the covered values, taken modulo mod
    ll lz;      // increment applied here but not yet pushed to the children
};
node tr[N << 2];
int fat[N];   // fat[u]: parent of u (main roots the tree at 1 with itself as parent)
int dep[N];   // dep[u]: depth of u (main gives the root depth 0)
int siz[N];   // siz[u]: number of vertices in u's subtree, u included
int son[N];   // son[u]: heavy child of u, 0 when u has no children
int rak[N];   // rak[p]: vertex stored at segment-tree position p
int top[N];   // top[u]: first (topmost) vertex of the heavy chain containing u
int idd[N];   // idd[u]: segment-tree position assigned to u
int cnt;      // last position handed out by dfs2
int n, m;     // number of vertices and number of operations
ll a[N];      // a[u]: initial weight of vertex u (read by build)
void init()
{
tot=0;
cnt=0;
memset(head,-1,sizeof(head));
memset(son,0,sizeof(son));
memset(siz,0,sizeof(siz));
}
/// Prepends the directed edge u -> v to u's edge list.
void add(int u,int v)
{
    ++tot;
    edge[tot].v = v;
    edge[tot].next = head[u];   // old first edge becomes the successor
    head[u] = tot;
}
/// First HLD pass over the subtree of u.
/// Records parent, depth and subtree size for each vertex, and picks as heavy
/// child the child with the largest subtree. On ties, the first such child in
/// edge-list order wins.
/// @param u    current vertex
/// @param fa   parent of u
/// @param deep depth of u
void dfs1(int u,int fa,int deep)
{
    fat[u] = fa;
    dep[u] = deep;
    siz[u] = 1;
    for (int e = head[u]; e != -1; e = edge[e].next) {
        const int child = edge[e].v;
        if (child == fa) {
            continue;   // do not walk back up to the parent
        }
        dfs1(child, u, deep + 1);
        siz[u] += siz[child];
        const bool heavier = (son[u] == 0) || (siz[child] > siz[son[u]]);
        if (heavier) {
            son[u] = child;
        }
    }
}
/// Second HLD pass: hands out segment-tree positions so that each heavy chain
/// occupies a contiguous run. The heavy child is visited before the light ones.
/// @param u current vertex
/// @param t top vertex of the chain u belongs to
void dfs2(int u,int t)
{
    idd[u] = ++cnt;
    rak[cnt] = u;
    top[u] = t;
    if (son[u] == 0) {
        return;   // leaf
    }
    dfs2(son[u], t);   // the heavy child extends the current chain
    for (int e = head[u]; e != -1; e = edge[e].next) {
        const int child = edge[e].v;
        if (child == son[u] || child == fat[u]) {
            continue;
        }
        dfs2(child, child);   // each light child starts its own chain
    }
}
void push_up(int i)
{
tr[i].sum1=(tr[lson].sum1+tr[rson].sum1)%mod;
tr[i].sum2=(tr[lson].sum2+tr[rson].sum2)%mod;
}
void build(int i,int l,int r)
{
tr[i].l=l; tr[i].r=r; tr[i].sum1=tr[i].sum2=0;
if(l==r){
tr[i].sum1=a[rak[l]];
tr[i].sum2=(a[rak[l]]*a[rak[l]])%mod;
//cout<<"sum1 "<<tr[i].sum1<<" sum2 "<<tr[i]sum2<<endl;
return ;
}
int mid=(l+r)>>1;
build(lson,l,mid);
build(rson,mid+1,r);
push_up(i);
}
void solve(int i,ll val)
{
tr[i].lz=(tr[i].lz+val)%mod;
ll cnt=tr[i].r-tr[i].l+1;
ll tmp1=(tr[i].sum1*2%mod*val)%mod;
ll tmp2=(cnt*val%mod*val)%mod;
tr[i].sum2=(tr[i].sum2+tmp1+tmp2)%mod;
tr[i].sum1=(tr[i].sum1+cnt*val%mod)%mod;
}
void push_down(int i)
{
if(tr[i].lz){
ll &lz=tr[i].lz;
solve(lson,lz);
solve(rson,lz);
lz=0;
}
}
void update(int i,int l,int r,ll val)
{
if(tr[i].l==l&&tr[i].r==r){
solve(i,val);
return ;
}
push_down(i);
int mid=(tr[i].l+tr[i].r)>>1;
if(r<=mid) update(lson,l,r,val);
else if(l>mid) update(rson,l,r,val);
else{
update(lson,l,mid,val);
update(rson,mid+1,r,val);
}
push_up(i);
}
void query(int i,int l,int r,ll &sum1,ll &sum2)
{
if(tr[i].l==l&&tr[i].r==r){
sum1+=tr[i].sum1;
sum1%=mod;
sum2+=tr[i].sum2;
sum2%=mod;
return ;
}
push_down(i);
int mid=(tr[i].l+tr[i].r)>>1;
if(r<=mid) return query(lson,l,r,sum1,sum2);
else if(l>mid) return query(rson,l,r,sum1,sum2);
else{
query(lson,l,mid,sum1,sum2);
query(rson,mid+1,r,sum1,sum2);
}
}
/// Returns, modulo mod, the sum of w[p]*w[q] over all unordered pairs {p, q}
/// of distinct vertices on the tree path x–y.
/// Computed as ((sum of w)^2 - (sum of w^2)) / 2, with the division done
/// via inv2.
ll querys(int x,int y)
{
    ll s1 = 0;
    ll s2 = 0;
    while (top[x] != top[y]) {
        // Climb from the endpoint whose chain head is deeper (x on ties).
        if (dep[top[x]] < dep[top[y]]) {
            swap(x, y);
        }
        query(1, idd[top[x]], idd[x], s1, s2);
        x = fat[top[x]];
    }
    // Both endpoints are now on one chain, so their positions bound a contiguous range.
    if (idd[x] > idd[y]) {
        swap(x, y);
    }
    query(1, idd[x], idd[y], s1, s2);
    return (s1 * s1 % mod - s2 + mod) * inv2 % mod;
}
void updates(int x,int y,ll c)
{
int fx=top[x]; int fy=top[y];
while(fx!=fy)
{
if(dep[fx]>=dep[fy]){
update(1,idd[fx],idd[x],c);
x=fat[fx];
}
else{
update(1,idd[fy],idd[y],c);
y=fat[fy];
}
fx=top[x];
fy=top[y];
}
if(idd[x]<=idd[y]) update(1,idd[x],idd[y],c);
else update(1,idd[y],idd[x],c);
}
int main()
{
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
init();
int u,v;
for(int i=1;i<n;i++){
scanf("%d %d",&u,&v);
add(u,v);
add(v,u);
}
dfs1(1,1,0);
dfs2(1,1);
build(1,1,n);
int op;
ll val;
while(m--)
{
scanf("%d",&op);
if(op==1){
scanf("%d %lld",&u,&val);
update(1,idd[u],idd[u]+siz[u]-1,val);
}
else if(op==2){
scanf("%d %d %lld",&u,&v,&val);
updates(u,v,val);
}
else{
scanf("%d %d",&u,&v);
ll Ans=querys(u,v);
printf("%lld\n",Ans);
}
}
return 0;
}