min(a[u],a[v])*dis(u,v)
这个式子带min函数,dis函数,都比较麻烦,肯定需要化简的。
trick:
- dis(u,v)可以引入LCA,转化成
dis(1,u) + dis(1,v) - 2*dis(1,LCA)
- 对于min函数的处理,我们可以分类讨论,在rt子树中,把点权大于等于a[rt]的节点分为一类,小于a[rt]的节点分一类
- 于是可以把式子写成:
对于
a x 1 , a x 2 , . . . a x c n t 1 > = a u a_{x1},a_{x2},...a_{xcnt1} >= a_{u} ax1,ax2,...axcnt1>=au
贡献为:
a u ∗ ( d e p [ a u ] + d e p [ a x 1 ] − 2 ∗ d e p [ L C A ] ) a_{u}*(dep[a_{u}]+dep[a_{x1}]-2*dep[LCA]) au∗(dep[au]+dep[ax1]−2∗dep[LCA])
+ a u ∗ ( d e p [ a u ] + d e p [ a x 2 ] − 2 ∗ d e p [ L C A ] ) +a_{u}*(dep[a_{u}]+dep[a_{x2}]-2*dep[LCA]) +au∗(dep[au]+dep[ax2]−2∗dep[LCA])
…
+ a u ∗ ( d e p [ a u ] + d e p [ a x c n t 1 ] − 2 ∗ d e p [ L C A ] ) +a_{u}*(dep[a_{u}]+dep[a_{xcnt1}]-2*dep[LCA]) +au∗(dep[au]+dep[axcnt1]−2∗dep[LCA])
整理得到:
对于:
a y 1 , a y 2 , . . . a y c n t 2 < a u a_{y1},a_{y2},...a_{ycnt2} < a_{u} ay1,ay2,...aycnt2<au
贡献为:
整理得到:
所以我们可以对点权建权值线段树,维护4个变量:
①Σa[i] (点权为a[i]的点的点权和)
②Σcnt (点权为a[i]的点的个数)
③Σdep[i] (点权为a[i]的点的深度和)
④Σa[i]*dep[i] (点权为a[i]的点的点权 * 深度之和)
然后枚举LCA作为子树根节点,计算每个LCA的贡献即可。
用树状数组1000ms就行。
卑微线段树常数大,离散化+动态开点跑了1600ms…
#include<bits/stdc++.h>
using namespace std;
//#pragma GCC optimize(2)
#define ull unsigned long long
#define ll long long
#define pii pair<int, int>
#define pdd pair<double, double>
#define re register
const int maxn = 2e5 + 10;
const ll mod = 998244353;
const ll inf = (ll)4e17+5;
const int INF = 1e9 + 7;
const double pi = acos(-1.0);
ll inv(ll b){
if(b==1)return 1;return(mod-mod/b)*inv(mod%b)%mod;}
inline ll read()
{
ll x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){
if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){
x=x*10+ch-'0';ch=getchar();}
return x*f;
}
//给定带点权的有根树 求 Σmin(a[i],b[j])*dis(i,j)
//分类讨论 枚举根节点作为LCA 求贡献
vector<int> g[maxn];
int n,n0;
ll a[maxn],b[maxn],idx[maxn];
int in[maxn],pos[maxn],clk,son[maxn],siz[maxn],dep[maxn];
ll ret;
//权值线段树模板
struct node
{
int cnt;
ll sum_d,sum_a,sum_ad;
node operator +(const node &f)const
{
node t;
t.cnt=cnt+f.cnt;
t.sum_a=(sum_a+f.sum_a)%mod;
t.sum_d=(sum_d+f.sum_d)%mod;
t.sum_ad=(sum_ad+f.sum_ad)%mod;
return t;
}
}tree[maxn*40];
int rt_node=0,cnt=0,lc[maxn*40],rc[maxn*40];//动态开点 rt_node用于传引用
inline void pushup(int rt)
{
tree[rt]=tree[lc[rt]]+tree[rc[rt]];
}
inline void upd(int &rt,int l,int r,int pos,int v,int f) //加入顶点v f为1或-1 表示加入或删除
{
if(!rt) rt=++cnt;
if(l==r)
{
tree[rt].cnt+=f;
tree[rt].sum_a=((tree[rt].sum_a+f*a[v])%mod + mod)%mod;
tree[rt].sum_d=((tree[rt].sum_d+f*dep[v])%mod + mod)%mod;
tree[rt].sum_ad=((tree[rt].sum_ad+f*a[v]*dep[v]%mod)%mod + mod)%mod;
return ;
}
int mid=l+r>>1;
if(pos<=mid) upd(lc[rt],l,mid,pos,v,f);
else upd(rc[rt],mid+1,r,pos,v,f);
pushup(rt);
}
inline node qry(int rt,int l,int r,int vl,int vr)
{
if(!rt || l>r)
{
node t={
0,0,0,0};
return t;
}
if(vl<=l && r<=vr) return tree[rt];
int mid=l+r>>1;
if(vr<=mid) return qry(lc[rt],l,mid,vl,vr);
else if(vl>mid) return qry(rc[rt],mid+1,r,vl,vr);
return qry(lc[rt],l,mid,vl,vr)+qry(rc[rt],mid+1,r,vl,vr);
}
//求重儿子+dfs序
void dfs1(int rt,int fa)
{
dep[rt]=dep[fa]+1;
siz[rt]=1;
in[rt]=++clk;
pos[clk]=rt;
for(int i:g[rt])
{
if(i==fa) continue;
dfs1(i,rt);
siz[rt]+=siz[i];
if(siz[i] > siz[son[rt]]) son[rt]=i;
}
}
int LCA;
inline ll cal(int u) //分别计算4部分 注意取模可能出现负数
{
ll ret=0;
node t1=qry(1,1,n0,idx[u],n0),t2=qry(1,1,n0,1,idx[u]-1);
ret=(ret + (1ll*dep[u]-2*dep[LCA]+mod) % mod * t1.cnt % mod * a[u] % mod) % mod;
ret=(ret + a[u] * t1.sum_d % mod)%mod;
ret=(ret + (1ll*dep[u]-2*dep[LCA]+mod) % mod * t2.sum_a % mod)%mod;
ret=(ret+t2.sum_ad)%mod;
return ret;
}
inline void add(int rt)
{
for(int i=in[rt];i<in[rt]+siz[rt];i++)
{
int u=pos[i];
ret=(ret+cal(u))%mod;
}
}
inline void up(int rt,int v)//子树每个点都加入
{
for(int i=in[rt];i<in[rt]+siz[rt];i++)
{
int u=pos[i];
upd(rt_node,1,n0,idx[u],u,v);
}
}
void dfs2(int rt,int fa,bool ok)
{
for(int i:g[rt])
{
if(i==son[rt] || i==fa) continue;
dfs2(i,rt,0);
}
if(son[rt]) dfs2(son[rt],rt,1);
LCA=rt;
upd(rt_node,1,n0,idx[rt],rt,1);
ret=(ret+cal(rt))%mod;//根节点也会产生贡献 需要在upd之前加
for(int i:g[rt])
{
if(i==son[rt] || i==fa) continue;
add(i);
up(i,1);
}
if(!ok) up(rt,-1);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
a[i]=read();
b[i]=a[i];
}
sort(b+1,b+n+1);
n0=unique(b+1,b+n+1)-b-1;//离散化
for(int i=1;i<=n;i++)
{
idx[i]=lower_bound(b+1,b+n0+1,a[i])-b;
}
for(int i=1,u,v;i<n;i++)
{
u=read();
v=read();
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(1,0);
dfs2(1,0,1);
cout<<ret*2%mod<<'\n';
return 0;
}