6438. 【GDOI2020模拟01.16】树上的鼠(长链剖分)

题目描述

Description

Input

Output

Sample Input
3
1 2
1 3

Sample Output
2

Explanation
只有连通块为整棵树时或只有一个点时小筄会输,其余情况小筄会赢。

Data Constraint

题解

一个连通块先手必败,当且仅当1在直径的中点且直径长度为奇数

证明:

若长度为奇数且不在中点,则可以先手移到中点,对方无论怎么移都可以移到直径上的对称点

若长度为偶数,则可以先手移到较远的中点,因为对手下一步的移动距离>1,所以不能移到另一个中点上,所以类似奇数的情况

\(f[i][j]\)表示以点i为根,最大深度为j的子树个数,最后在1处合并

与深度有关的dp显然是长链剖分,把\(f[u][1\sim x]\)\(f[v][1\sim y]\)\(x>y\))合并时,对于\(1\sim y\)的部分暴力合并,然后在\(y+1\)处对后面的数打赏后缀乘标记

转移时可以直接维护\(f\)值的后缀,最后计算答案时用总数-不合法,再用一个类似的dp计算不合法数即可

至于状态的保存,因为一条链的状态数=链长,所以可以按深度从上到下存到树上,表示深度1,2,3...时的\(f\),在上传重儿子时也只需考虑新加的那一个点

code

#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define mod 998244353
#define file
using namespace std;

int a[2000001][2];
int b[1000002];
int c[1000002];
int ls[1000001];
int dp[1000001];
int fa[1000001];
bool bz[1000001];
int d[1000001];
int I[1000001];
int nx[1000001];
long long g[1000001];
long long f[1000001];
long long F[1000001];
int st[1000001];
long long Ans[1000001][3];
long long Ans2[1000001];
int n,i,j,k,l,len,h,t,tot;
long long ans;

void New(int x,int y)
{
    ++len;
    a[len][0]=y;
    a[len][1]=ls[x];
    ls[x]=len;
}

void bfs()
{
    int i,mx,mx2;
    
    h=0;t=1;
    d[1]=1;
    bz[1]=1;
    
    while (h<t)
    {
        for (i=ls[d[++h]]; i; i=a[i][1])
        if (!bz[a[i][0]])
        {
            fa[a[i][0]]=d[h];
            
            bz[a[i][0]]=1;
            d[++t]=a[i][0];
        }
    }
    
    fd(l,t,1)
    {
        mx=0;mx2=-1;
        g[d[l]]=1;
        
        for (i=ls[d[l]]; i; i=a[i][1])
        if (a[i][0]!=fa[d[l]])
        {
            g[d[l]]=g[d[l]]*(g[a[i][0]]+1)%mod;
            
            if (dp[a[i][0]]>mx)
            mx=dp[a[i][0]],mx2=a[i][0];
        }
        
        nx[d[l]]=mx2;
        dp[d[l]]=mx+1;
    }
    ans=g[1];
}

void down(int t)
{
    f[t]=f[t]*F[t]%mod;
    if (nx[t]!=-1)
    F[nx[t]]=F[nx[t]]*F[t]%mod;
    F[t]=1;
}

void dfs(int st)
{
    int T,i,j,k,l;
    long long sum,s1,s2;
    
    t=1;
    d[1]=st;
    
    while (t)
    {
        T=d[t];
        
        if (nx[T]!=-1 && bz[T])
        {
            bz[T]=0;
            d[++t]=nx[T];
        }
        else
        if (I[T])
        {
            if (a[I[T]][0]==fa[T] || a[I[T]][0]==nx[T])
            I[T]=a[I[T]][1];
            else
            {
                d[++t]=a[I[T]][0];
                I[T]=a[I[T]][1];
            }
        }
        else
        {
            if (fa[T]!=1)
            {
                if (nx[fa[T]]==T)
                f[fa[T]]=(f[T]*F[T]+1)%mod;
                else
                {
                    tot=0;
                    i=fa[T];j=T;
                    
                    while (j!=-1)
                    {
                        ++tot;
                        b[tot]=i;c[tot+1]=j;
                        
                        down(i);down(j);
                        i=nx[i];j=nx[j];
                    }
                    b[++tot]=i;
                    down(i);
                    
                    sum=s1=s2=0;
                    if (nx[i]!=-1)
                    {
                        s1=f[nx[i]]*F[nx[i]]%mod;
                        
                        sum=f[nx[i]]*F[nx[i]]%mod*f[T]%mod;
                        F[nx[i]]=F[nx[i]]*(f[T]+1)%mod;
                    }
                    
                    fd(i,tot,2)
                    {
                        sum=(sum+(f[b[1]]-s1)*(f[c[2]]+1-s2)-(f[b[1]]-f[b[i]])*(f[c[2]]+1-f[c[i]])-(f[b[i]]-s1))%mod;
                        
                        s1=f[b[i]];s2=f[c[i]];
                        f[b[i]]=(f[b[i]]+sum)%mod;
                    }
                    f[b[1]]=(f[b[1]]+sum)%mod;
                }
            }
            --t;
        }
    }
}

int main()
{
    freopen("tree.in","r",stdin);
    #ifdef file
    freopen("tree.out","w",stdout);
    #endif
    
    scanf("%d",&n);
    Ans2[1]=Ans[1][0]=f[1]=F[1]=1;
    fo(i,2,n)
    {
        scanf("%d%d",&j,&k);
        Ans2[i]=Ans[i][0]=f[i]=F[i]=1;
        
        New(j,k);
        New(k,j);
    }
    
    fo(i,1,n) I[i]=ls[i];
    
    bfs();
    for (i=ls[1]; i; i=a[i][1])
    dfs(a[i][0]);
    
    for (i=ls[1]; i; i=a[i][1])
    {
        j=a[i][0];
        tot=1;
        
        while (j!=-1)
        {
            down(j);
            
            b[++tot]=j;
            j=nx[j];
        }
        b[tot+1]=0;
        
        Ans2[tot+1]=Ans2[tot+1]*(f[b[2]]+1)%mod;
        
        fo(j,2,tot)
        {
            Ans[j][2]=(Ans[j][2]*(f[b[2]]-f[b[j+1]]+1))%mod;
            
            fd(k,1,0)
            {
                Ans[j][k+1]=(Ans[j][k+1]+Ans[j][k]*(f[b[j]]-f[b[j+1]]))%mod;
                Ans[j][k]=Ans[j][k]*(f[b[2]]-f[b[j]]+1)%mod;
            }
        }
    }
    
    --ans;
    fo(i,2,n)
    {
        ans=(ans-Ans[i][2]*Ans2[i])%mod;
        Ans2[i+1]=Ans2[i+1]*Ans2[i]%mod;
    }
    
    printf("%lld\n",(ans+mod)%mod);
}

猜你喜欢

转载自www.cnblogs.com/gmh77/p/12204071.html