2019.4.10 一题——根号分治

  不收敛就是存在一个点,在一轮中有过贡献,但是没有被操作过,并且它的权值非 0 。枚举每条边判断是否一端被操作一端未被操作再看看权值即可。

  考虑计算一轮中每个点乘上 \( \frac{1}{2^i} \) 贡献了几次,那么无限轮就是一轮的总贡献乘上 \( \sum\limits_{i=0}^{\infty}\frac{1}{2^{i*c}} \) ,其中 c 是一轮中该点被操作的次数。

  然后就不会了。

  看到边数有限制,其实可以考虑根号分治!就是度数 <= 根号的点一种做法,度数 > 根号的点一种做法。

  把度数 <= 根号的点称为小点,度数 > 根号的点称为大点。

  1.小点要枚举出边,考虑它自己给它的 “出点” 的贡献;2.大点可以枚举所有 n 个点,考虑它们给它的贡献;3.同时还要考虑大点对小点的贡献。

  发现如果枚举 n 个点之类的话,每个点得维护它在操作序列上的位置之类的,很难做。那么考虑枚举操作序列上 k 个位置。

  1.

    小点给它的出点贡献的时候,得知道对方现在是 2 的几次方分之一的状态。所以记一个 ct 表示每个点当前是几次方状态。再记一个 sm 表示每个点在一轮中贡献了几倍权值。

    依次枚举 k 个位置,先给该位置的点的 ct[ ] ++ ,遇到大点就跳过;遇到小点就枚举它的出边,如果对方是小点就给对方的 sm 加上对方的 \( 2^{ct[ ]} \) 。

  2.

    枚举大点 cr ,再枚举 k 个位置。每次开始一个大点的时候,就把所有点的 ct 清空,然后枚举 k 个位置的时候维护起来,这样仍可以知道每个点当前的状态。

    k 个位置里遇到小点,如果是与 cr 有边的,就 sm[ cr ] += bin[ ct[ cr ] ] 即可。(bin[ i ] 表示 \( \frac{1}{2^i} \))

    可以预处理 \( n\sqrt{n} \) 的 bool 数组表示每个大点和各点之间有没有边。一开始枚举每条边就能做出。

  3.

    这个难以在枚举大点的时候做,所以考虑在枚举小点 cr 的出边的时候,“出点”是大点的话就 sm[ cr ] += ... 。

    考虑这里的 “ ... ” ,应该是 “在 cr 上次被操作到这次被操作之间该大点的出现次数 tp ” 乘上 bin[ ct[ cr ]-1 ] (减 1 因为 cr 上次操作到这次操作之间是 ct[ cr ]-1 的状态)。

    考虑 tp 怎么求。可以开一个 \( n\sqrt{n} \) 的数组表示每个大点在当前时刻(因为是在外围枚举着 k 个位置的,所以有当前时刻)出现了多少次。 位置每变一下,就枚举大点维护数组,也不过是 \( k\sqrt{n} \) 的复杂度。

  所以这道题就 \( O(k\sqrt{n}) \) 即可。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=1e5+5,M=2e5+5,B=340,mod=998244353;
int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,m,k,w[N],s[M],rd[N],hd[N],xnt,to[N<<1],nxt[N<<1];
int q[B],tot,dy[N],ct[N],sm[N],lj[B][M],lst[N],bin[M],iv2;
bool lx[N],vis[N],b[B][M];
struct Ed{
  int x,y;
  Ed(int x=0,int y=0):x(x),y(y) {}
}ed[N];
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;}
void init()
{
  memset(rd,0,sizeof rd);
  xnt=0;memset(hd,0,sizeof hd);
  tot=0;memset(lx,0,sizeof lx);
  memset(ct,0,sizeof ct); memset(sm,0,sizeof sm);
  memset(vis,0,sizeof vis); memset(b,0,sizeof b);
  memset(lj,0,sizeof lj); memset(lst,0,sizeof lst);
}
int main()
{
  while(scanf("%d",&n)==1)
    {
      m=rdn();k=rdn();
      init();
      for(int i=1;i<=n;i++)w[i]=rdn();
      for(int i=1;i<=k;i++)s[i]=rdn(),vis[s[i]]=1;
      iv2=pw(2,mod-2);
      bin[0]=1;for(int i=1;i<=k;i++)bin[i]=(ll)bin[i-1]*iv2%mod;
      bool fg=0;
      for(int i=1,u,v;i<=m;i++)
    {
      u=rdn();v=rdn(); ed[i]=Ed(u,v); add(u,v);add(v,u);
      rd[u]++; rd[v]++;
      if((vis[u]&&!vis[v]&&w[v])||(vis[v]&&!vis[u]&&w[u]))fg=1;
    }
      if(fg){puts("-1");continue;}
      int bs=sqrt(n);
      for(int i=1;i<=n;i++)
    if(rd[i]>bs&&vis[i]) lx[i]=1,q[++tot]=i,dy[i]=tot;
      for(int i=1;i<=m;i++)
    {
      int u=ed[i].x, v=ed[i].y;
      if(lx[u])b[dy[u]][v]=1; if(lx[v])b[dy[v]][u]=1;
    }
      for(int i=1;i<=k;i++)
    {
      int cr=s[i]; ct[cr]++;
      for(int j=1;j<=tot;j++)lj[j][i]=lj[j][i-1];
      if(lx[cr]){lj[dy[cr]][i]++;continue;}
      for(int j=hd[cr],v;j;j=nxt[j])
        {
          if(!lx[v=to[j]])
        {
          sm[v]=upt(sm[v]+bin[ct[v]]);
        }
          else
        {
          int tp=lj[dy[v]][i]-lj[dy[v]][lst[cr]];
          if(!tp)continue;
          sm[cr]=(sm[cr]+(ll)bin[ct[cr]-1]*tp)%mod;
        }
        }
      lst[cr]=i;
    }
      for(int i=1;i<=n;i++)
    {
      if(!vis[i]||lx[i])continue;
      for(int j=hd[i],v;j;j=nxt[j])
        {
          v=to[j]; if(!lx[v])continue;
          int tp=lj[dy[v]][k]-lj[dy[v]][lst[i]];
          if(!tp)continue;
          sm[i]=(sm[i]+(ll)bin[ct[i]]*tp)%mod;
        }
    }
      for(int i=1;i<=tot;i++)
    {
      int cr=q[i];
      for(int j=1;j<=k;j++)ct[s[j]]=0;//
      for(int j=1;j<=k;j++)
        {
          int v=s[j]; ct[v]++; if(v==cr)continue;
          if(!b[i][v])continue;
          sm[cr]=upt(sm[cr]+bin[ct[cr]]);
        }
    }
      int ans=0;
      for(int i=1;i<=n;i++)
    if(vis[i])
      {
        int tp=pw(upt(1-bin[ct[i]]),mod-2);
        tp=(ll)tp*sm[i]%mod*w[i]%mod;
        ans=upt(ans+tp);
      }
      printf("%d\n",ans);
    }
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/Narh/p/10682586.html