测试地址:Minimax
做法:本题需要用到树形DP+线段树合并。
很快想到一种定义状态的方式:
表示点
的权值为
(离散化后)的概率,然后转移时,因为这是棵二叉树,令
为当前枚举的儿子,
为另一个儿子中权值小于
的概率,我们有以下状态转移方程:
因为题目中说明了叶子节点权值各不相同,所以不用考虑重复的问题。那么我们就可以这样转移,时间复杂度为
。
然而显然过不了,这怎么办呢?我们发现,上面的转移过程就是把两个序列拼在一起,然后再进行某些区间乘就是新的概率。我们知道求区间和和维护区间乘可以用线段树维护,那么这个“序列合并”的操作显然就可以用线段树合并来维护了,因为区间乘操作只会出现在两棵树中仅在一棵中存在的子树中,所以直接在线段树合并时打标记即可。于是我们就完成了这一题,时间复杂度为
。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
int n,first[300010]={0},tot=0,totp=0,rt[300010]={0},ch[6000010][2]={0};
ll p[6000010],seg[6000010]={0},tag[6000010]={0};
ll pxL,pyL,ans;
struct edge
{
int v,next;
}e[300010];
struct forsort
{
int id;
ll val;
}f[300010];
bool cmp(forsort a,forsort b)
{
return a.val<b.val;
}
void insert(int a,int b)
{
e[++tot].v=b;
e[tot].next=first[a];
first[a]=tot;
}
ll power(ll a,ll b)
{
ll s=1,ss=a;
while(b)
{
if (b&1) s=s*ss%mod;
ss=ss*ss%mod;b>>=1;
}
return s;
}
void update(int x,ll p)
{
seg[x]=seg[x]*p%mod;
tag[x]=tag[x]*p%mod;
}
void pushdown(int no)
{
if (tag[no]!=1)
{
if (ch[no][0]) update(ch[no][0],tag[no]);
if (ch[no][1]) update(ch[no][1],tag[no]);
tag[no]=1;
}
}
void pushup(int no)
{
seg[no]=(seg[ch[no][0]]+seg[ch[no][1]])%mod;
}
void seginsert(int &no,int l,int r,int x)
{
if (!no) no=++totp;
tag[no]=1;
if (l==r) {seg[no]=1;return;}
int mid=(l+r)>>1;
if (x<=mid) seginsert(ch[no][0],l,mid,x);
else seginsert(ch[no][1],mid+1,r,x);
pushup(no);
}
int merge(int x,int y,ll p)
{
if (!x)
{
if (!y) return y;
pyL=(pyL+seg[y])%mod;
update(y,(p*pxL%mod+(1-p+mod)*(1-pxL+mod)%mod)%mod);
return y;
}
if (!y)
{
pxL=(pxL+seg[x])%mod;
update(x,(p*pyL%mod+(1-p+mod)*(1-pyL+mod)%mod)%mod);
return x;
}
pushdown(x),pushdown(y);
ch[x][0]=merge(ch[x][0],ch[y][0],p);
ch[x][1]=merge(ch[x][1],ch[y][1],p);
pushup(x);
return x;
}
void solve(int v)
{
int lson=0,rson=0;
if (first[v])
{
lson=e[first[v]].v;
solve(lson);
}
if (e[first[v]].next)
{
rson=e[e[first[v]].next].v;
solve(rson);
}
if (!lson) return;
if (!rson) rt[v]=rt[lson];
else
{
pxL=pyL=0;
rt[v]=merge(rt[lson],rt[rson],p[v]);
}
}
void query(int no,int l,int r)
{
if (l==r)
{
ans=(ans+(ll)l*f[l].val%mod*seg[no]%mod*seg[no]%mod)%mod;
return;
}
int mid=(l+r)>>1;
pushdown(no);
query(ch[no][0],l,mid);
query(ch[no][1],mid+1,r);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
if (x) insert(x,i);
}
tot=0;
ll inv=power(10000,mod-2);
for(int i=1;i<=n;i++)
{
if (!first[i])
{
scanf("%lld",&f[++tot].val);
f[tot].id=i;
}
else
{
scanf("%lld",&p[i]);
p[i]=p[i]*inv%mod;
}
}
sort(f+1,f+tot+1,cmp);
for(int i=1;i<=tot;i++)
{
rt[f[i].id]=++totp;
seginsert(rt[f[i].id],1,tot,i);
}
solve(1);
ans=0;
query(rt[1],1,tot);
printf("%lld",ans);
return 0;
}