[Codeforces960G][NTT][DP]Bandit Blues

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Rose_max/article/details/82917113

翻译

给你三个正整数 n,a,b,定义 A 为一个排列中是前缀最大值的数的个数,定义 B 为一个排列中是后缀最大值的数的个数,求长度为 nn 的排列中满足 A = a且 B = b 的排列个数。n≤10^5,答案对 998244353取模。

题解

很妙
我是膜beginend的!
开始想的是每次加入n+1
然后就凉了啊…
转换一下思路
每次加入最小的一个数
显然只有在加入到最前方的时候才会对前缀最大值产生贡献
f [ i ] [ j ] f[i][j] 表示加入1~i的数,有j个前缀最大值的方案
转移就是
f [ i ] [ j ] = f [ i 1 ] [ j 1 ] + ( i 1 ) f [ i 1 ] [ j ] f[i][j]=f[i-1][j-1]+(i-1)*f[i-1][j]
枚举最大的放哪里 可以得到
i = 1 n f [ i 1 ] [ a 1 ] f [ n i ] [ b 1 ] C n 1 i 1 \sum_{i=1}^nf[i-1][a-1]*f[n-i][b-1]*C_{n-1}^{i-1}
这是个优秀的 n 2 n^2 方程
设前缀最大值的位置为 p 1 , p 2 , p 3... p1,p2,p3...
可以把 [ p i , p i + 1 1 ] [p_i,p_{i+1}-1] 看成一组
总共有 a + b 2 a+b-2
选出 a 1 a-1 组放到n的前面
可以知道
f [ n 1 ] [ a + b 2 ] C a + b 2 a 1 f[n-1][a+b-2]*C_{a+b-2}^{a-1}
然后就不会做了…
其实f的转移是第一类斯特林数的递推式
第一类斯特林数 s ( n , m ) s(n,m) 就等于x的n次上升幂的第m项系数
x n = x ( x + 1 ) ( x + 2 ) ( x + n 1 ) = i = 0 n s ( n , k ) x k x^{n \uparrow} = x(x + 1)(x + 2) \cdots (x + n - 1) =\sum_{i = 0}^n s(n, k) x^k
这个可以分治fft求出
点了两个技能真舒服

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
#include<ctime>
#include<map>
#define mod 998244353
#define MAXN 100010
#define LL long long
#define mp(x,y) make_pair(x,y)
using namespace std;
inline int read()
{
	int f=1,x=0;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}
inline void write(int x)
{
	if(x<0)putchar('-'),x=-x;
	if(x>9)write(x/10);
	putchar(x%10+'0');
}
inline void print(int x){write(x);printf(" ");}
LL pow_mod(LL a,LL b)
{
	LL ret=1;
	while(b)
	{
		if(b&1)ret=ret*a%mod;
		a=a*a%mod;b>>=1;
	}
	return ret;
}
LL A[MAXN*4],B[MAXN*4];
int R[MAXN*4],L;
void NTT(LL *y,int len,int on)
{
	for(int i=0;i<len;i++)R[i]=((R[i>>1]>>1)|((i&1)*(len>>1)));
	for(int i=0;i<len;i++)if(i<R[i])swap(y[i],y[R[i]]);
	for(int i=1;i<len;i<<=1)
	{
		LL wn=pow_mod(3,(mod-1)/(i*2));if(on==-1)wn=pow_mod(wn,mod-2);
		for(int j=0;j<len;j+=(i<<1))
		{
			LL w=1;
			for(int k=0;k<i;k++)
			{
				LL u=y[j+k];
				LL v=y[j+k+i]*w%mod;
				y[j+k]=(u+v)%mod;
				y[j+k+i]=(u-v+mod)%mod;
				w=w*wn%mod;
			}
		}
	}
	if(on==-1)
	{
		LL tmp=pow_mod(len,mod-2);
		for(int i=0;i<len;i++)y[i]=(y[i]*tmp)%mod;
	}
}
void sol(LL *a,int len,int l,int r)
{
	if(l==r){a[0]=l;a[1]=1;return ;}
	int mid=(l+r)/2;LL g1[len+5],g2[len+5];
	memset(g1,0,sizeof(g1));memset(g2,0,sizeof(g2));
	sol(g1,len>>1,l,mid);sol(g2,len>>1,mid+1,r);
	
	NTT(g1,len,1);
	NTT(g2,len,1);
	for(int i=0;i<len;i++)a[i]=g1[i]*g2[i]%mod;
	NTT(a,len,-1);
}
LL pre[110000],inv[110000];
LL C(int n,int m){return pre[n]*inv[n-m]%mod*inv[m]%mod;}
int n,a,b;
int main()
{
	pre[0]=1;for(int i=1;i<=100000;i++)pre[i]=pre[i-1]*i%mod;
	inv[100000]=pow_mod(pre[100000],mod-2);
	for(int i=99999;i>=0;i--)inv[i]=inv[i+1]*(i+1)%mod;
	n=read();a=read();b=read();
	if(a+b-2>n-1||!a||!b){puts("0");return 0;}
	if(n==1){puts("1");return 0;}
	int ln;L=0;
	for(ln=1;ln<=2*(n+1);ln<<=1)L++;
	sol(A,ln,0,n-2);
	printf("%lld\n",A[a+b-2]*C(a+b-2,a-1)%mod);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Rose_max/article/details/82917113