树
解:
傻逼的我考试的时候去写这道题,没调出来,结果爆零了-_-
其实思想是很好懂的,一眼是一个树剖,然后陷入了无尽的推式子中。其实没写出来还有很大一部分原因是因为我数学太菜了,写到一半发现推错,又重新推,又推错……
不过写了好几道树剖,好歹有一道会的了。
首先我们可以很容易想到一种预处理的方法:记录每个点
为根到子树里所有路径的和
&平方和
。
我们用
表示
到
子树内所有点的平方和。
然后思考一个子树外点
到这棵子树内所有点的平方和:
如果
在
的子树内我们可以以一个较短的时间求答案,那是不是就可以做这道题了?
考虑容斥:我们把
一步一步往上跳求答案。(画个图)
我们现在已经求得
,要求
。
先假设
在
外算一遍,显然在
子树内点的贡献是错误的,我们需要把它更正一下,由于已经求了
,我们只需要在当前算的
中的
子树的贡献扣掉就好了。
具体式子就长成这样:
然后一步一步往上跳,递归求解就行。
然而这样做太慢了,考虑优化:跳链很容易想到树剖,更何况这里还需要维护链上信息。(似乎还可以做单点修改,不过可能太毒瘤了。)
现在需要考虑的是如何快速跳过一条重链?
假设我们跳过了一个点
,那它对答案的贡献是多少?
根据上面的式子,我们可以发现,点s的贡献在跳到s的时候加了一堆,跳到fa(s)的时候又减了一堆,我们把这两堆加起来(先画个图):
把所有含有
的式子提出来,然后我们预处理一下
和
。
把这两个东西在树上做一个前缀和,跳链就可以做到一次跳一条重链了。
如何统计答案?
我们维护一下跳链的数值,然后最后跳到顶由于没有了减掉的那一堆,我们把它加回来,最后就是
了。
回答询问时,讨论一下 是否在 中。
哇~5KB代码:
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
struct lxy{
int to,next;
long long len;
}b[200005];
long long const mod=1000000007;
int n,x,y,z,cnt,q,ui,vi;
int head[100005];
int wson[100005];
int size[100005];
int dep[100005];
long long fro[100005];
bool vis[100005];
int fa[100005];
int tp[100005];
long long qrt[100005];
long long unqrt[100005];
long long rc[100005];
long long xi[100005];
long long cooold[100005];
long long dis[100005];
void add(int op,int ed,int len)
{
b[++cnt].next=head[op];
b[cnt].len=len;
b[cnt].to=ed;
head[op]=cnt;
}
void dfs2(int u,int las)
{
tp[u]=las;
vis[u]=1;
if(wson[u]!=0) dfs2(wson[u],las);
for(int i=head[u];i!=-1;i=b[i].next)
if(vis[b[i].to]==0&&b[i].to!=wson[u])
dfs2(b[i].to,b[i].to);
vis[u]=0;
}
void dfs1(int u,int dp)
{
dep[u]=dp;
int weigh=0;
size[u]=1;vis[u]=1;
for(int i=head[u];i!=-1;i=b[i].next)
if(vis[b[i].to]==0)
{
fro[b[i].to]=fro[u]+b[i].len;
fa[b[i].to]=u;
xi[b[i].to]=b[i].len;
dfs1(b[i].to,dp+1);
size[u]+=size[b[i].to];
if(size[b[i].to]>weigh)
weigh=size[b[i].to],wson[u]=b[i].to;
}
vis[u]=0;
}
void dfs3(int u)
{
vis[u]=1;
for(int i=head[u];i!=-1;i=b[i].next)
if(vis[b[i].to]==0)
{
dfs3(b[i].to);
unqrt[u]=(unqrt[u]+unqrt[b[i].to]+(size[b[i].to]*b[i].len)%mod)%mod;
qrt[u]=(qrt[b[i].to]+qrt[u]+2*unqrt[b[i].to]*b[i].len%mod+size[b[i].to]*b[i].len%mod*b[i].len)%mod;
}
vis[u]=0;
}
long long dfs4(int u)
{
vis[u]=1;
long long ret=0;
for(int i=head[u];i!=-1;i=b[i].next)
if(vis[b[i].to]==0)
{
if(b[i].to==wson[u])
{
ret=dfs4(wson[u]);
ret=(ret+b[i].len)%mod;
rc[u]=(4*unqrt[u]*xi[u]%mod+size[u]*(4*ret*xi[u]%mod+4*xi[u]*xi[u]%mod)%mod+rc[wson[u]])%mod;
cooold[u]=(4*xi[u]*size[u]+cooold[wson[u]])%mod;
dis[u]=ret;
}
else dfs4(b[i].to);
}
if(wson[u]==0)
{
rc[u]=(4*xi[u]*xi[u])%mod;
cooold[u]=(4*xi[u]*size[u])%mod;
}
vis[u]=0;
return ret;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);add(y,x,z);
}
dfs1(1,1);dfs2(1,1);dfs3(1);dfs4(1);
scanf("%d",&q);
for(int i=1;i<=q;i++)
{
scanf("%d%d",&x,&y);
ui=x,vi=y;
long long ans=0,len=0,road=0,ret=0;int lca;
while(tp[x]!=tp[y])
{
if(dep[tp[x]]>=dep[tp[y]])
{
ans=(ans+rc[tp[x]]-rc[wson[x]]+(len-dis[x])*(cooold[tp[x]]-cooold[wson[x]]))%mod;
len=(len+fro[x]-fro[fa[tp[x]]])%mod;
x=tp[x];
x=fa[x];
}
else if(dep[tp[x]]<dep[tp[y]])
{
road=(road+fro[y]-fro[fa[tp[y]]])%mod;
y=fa[tp[y]];
}
}
if(dep[y]>=dep[x])
{
road=(road+fro[y]-fro[x])%mod;
ans=(ans+rc[x]-rc[wson[x]]+(len-dis[x])*(cooold[x]-cooold[wson[x]]))%mod;
lca=x,y=x;
}
else{
ans=(ans+rc[y]-rc[wson[x]]+(len-dis[x])*(cooold[y]-cooold[wson[x]]))%mod;
len=(len+fro[x]-fro[y])%mod;
x=y;lca=y;
}
if(lca!=vi)
{
ret=(qrt[vi]+2*unqrt[vi]*(len+road)%mod+size[vi]*(len+road)%mod*(len+road)%mod)%mod;
len=(len+fro[x]-fro[fa[x]]);
x=fa[x];
while(tp[x]!=0)
{
ans=(ans+rc[tp[x]]-rc[wson[x]]+(len-dis[x])*(cooold[tp[x]]-cooold[wson[x]]))%mod;
len=(len+fro[x]-fro[fa[tp[x]]])%mod;
x=fa[tp[x]];
}
ans=(qrt[1]+2*unqrt[1]%mod*len+size[1]*len%mod*len-ans)%mod;
ret=2*ret-ans;
ret=ret%mod;
if(ret<0) ret=(ret+mod)%mod;
printf("%lld\n",ret);
continue;
}
if(lca==vi)
{
ret=(qrt[lca]+2*unqrt[lca]*(len+2*xi[lca])%mod+size[lca]*(len+2*xi[lca])%mod*(len+2*xi[lca])-ans)%mod;
len=(len+fro[x]-fro[fa[x]]);
x=fa[x];
while(tp[x]!=0)
{
ans=(ans+rc[tp[x]]-rc[wson[x]]+(len-dis[x])*(cooold[tp[x]]-cooold[wson[x]]))%mod;
len=(len+fro[x]-fro[fa[tp[x]]])%mod;
x=fa[tp[x]];
}
ans=(qrt[1]+2*unqrt[1]%mod*len+size[1]*len%mod*len-ans)%mod;
ret=2*ret-ans;
ret=ret%mod;
if(ret<0) ret=(ret+mod)%mod;
printf("%lld\n",ret);
continue;
}
}
}