首先有朴素的 DP:令 $f_i$ 为强制第 $i$ 位选 0 的合法串的方案数,枚举最后一段 1 的起始位置,有
$$f_i=\sum_{j=\max(0,\,i-m)}^{i-1} f_j,\qquad f_0=1$$
发现这是一个 $m$ 阶常系数线性递推,用 NTT 做多项式乘法取模、再做快速幂,于是有了一个复杂度为 $O(m\log m\log n)$ 的做法。
考虑到当 $m$ 变大的时候 $\frac{n}{m}$ 会比较小,可以换一种做法:
把上面的转移写成前缀和的形式(相邻两项作差),
于是有 $f_i=2f_{i-1}-f_{i-m-1}$($i\ge 2$,下标为负时视为 $0$)。
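考虑它的组合意义:
$i$ 向 $i+1$ 连一条权值为 $2$ 的边,
$i$ 向 $i+m+1$ 连一条权值为 $-1$ 的边。
问题转换为求 $0$ 到 $n$ 的每一条路径的边权乘积之和,记为 $g_n$。$g$ 满足同样的递推,但对所有 $i\ge 1$ 成立且 $g_0=1$,因此 $f_i=g_i-g_{i-1}$,最终答案为 $f_{n+1}=g_{n+1}-g_n$。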
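枚举权值为 $-1$ 的边的条数 $i$:这样的路径共走 $n-im$ 步,从中选出 $i$ 步走 $-1$ 边,方案数为 $\binom{n-im}{i}$,每条路径的权值为 $(-1)^i2^{\,n-(m+1)i}$,即
$$g_n=\sum_{i=0}^{\lfloor n/(m+1)\rfloor}(-1)^i\,2^{\,n-(m+1)i}\binom{n-im}{i}$$
用 Lucas 定理把 $\binom{n-im}{i}$ 求出来就可以了(共 $O(\frac{n}{m})$ 项)。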
这种对数据范围分段处理的题挺巧妙的
#include<bits/stdc++.h>
#define cs const
using namespace std;
// Capacity of the global work arrays: 2^16 plus a little slack.
constexpr int N = (1 << 16) | 5;
// Modulus 65537 = 2^16 + 1, so (Mod - 1) is divisible by 2^16.
constexpr int Mod = 65537;
using ll = long long;
int add(int a, int b){ return a + b >= Mod ? a + b - Mod : a + b; }
int mul(int a, int b){ return 1ll * a * b % Mod; }
int ksm(int a, ll b){ int ans = 1; for(;b;b>>=1,a=mul(a,a)) if(b&1) ans = mul(ans, a); return ans; }
void Add(int &a, int b){ a = add(a, b); }
int dec(int a, int b){ return a - b < 0 ? a - b + Mod : a - b; }
void Dec(int &a, int b){ a = dec(a, b); }
// Input values, read by main(). Poly::Solve consumes n (shifts it down to 0).
ll n = 0, m = 0;
namespace Poly{
cs int C = 16;
#define poly vector<int>
poly w[C+1];
void prework(){
for(int i = 1; i <= C; i++) w[i].resize(1<<i-1);
int wn = ksm(3, (Mod-1)/(1<<C)); w[C][0] = 1;
for(int i = 1; i < (1<<(C-1)); i++) w[C][i] = mul(w[C][i-1], wn);
for(int i = C-1; ~i; i--) for(int j = 0; j < (1<<i-1); j++) w[i][j] = w[i+1][j<<1];
} int F[N], bit, up, rev[N];
void init(int len){
bit = 0, up = 1; while(up < len) up <<= 1, ++bit;
for(int i = 0; i < up; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<bit-1);
}
void NTT(poly &a, int typ){
for(int i = 0; i < up; i++) if(i<rev[i]) swap(a[i], a[rev[i]]);
for(int l = 1, i = 1; i < up; i <<= 1, ++l)
for(int j = 0; j < up; j += (i<<1))
for(int k = 0; k < i; k++){
int x = a[k+j], y = mul(w[l][k], a[k+j+i]);
a[k+j] = add(x, y); a[k+j+i] = dec(x, y);
}
if(typ == -1){
reverse(a.begin()+1, a.end()); int iv = ksm(up, Mod-2);
for(int i = 0; i < up; i++) a[i] = mul(a[i], iv);
}
}
poly operator * (poly a, poly b){
int deg = a.size() + b.size() - 1;
init(deg); a.resize(up), b.resize(up);
NTT(a, 1); NTT(b, 1);
for(int i = 0; i < up; i++) a[i] = mul(a[i], b[i]);
NTT(a, -1); a.resize(deg);
if(deg < m) return a;
a.resize(m + m);
int sum = 0;
for(int i = deg-1; i >= m; i--) Add(a[i], sum), Add(sum, a[i]);
for(int i = m-1; i >= 0; i--){ a[i] = add(a[i], sum); sum = dec(sum, a[i+m]); }
a.resize(min(deg, (int)m));
return a;
}
void Solve(){
prework();
poly A, B; A.push_back(0); A.push_back(1); B.push_back(1);
for(;n;n>>=1,A=A*A) if(n&1) B=B*A;
int ans = 0;
F[0] = 1; for(int i = 1; i <= m; i++) F[i] = add(F[i-1], F[i-1]);
for(int i = 0; i < B.size(); i++) Add(ans, mul(B[i], F[i]));
cout << ans;
}
}
namespace Binom{
    // Large-m solver. It evaluates
    //     calc(n) = sum_{i=0}^{floor(n/(m+1))} (-1)^i * 2^{n-(m+1)i} * C(n - i*m, i)
    // modulo the prime Mod, computing each binomial with Lucas' theorem.
    // The number of terms is floor(n/(m+1)) + 1.
    int fac[Mod+1], ifac[Mod+1];   // factorials and inverse factorials of 0..Mod-1

    // Binomial coefficient C(n, m) for arguments below Mod.
    // Returns 0 if m > n or either argument is negative.
    int C(int n, int m){ if(n<0||m<0||n<m) return 0; return mul(fac[n], mul(ifac[n-m], ifac[m])); }

    // Lucas' theorem: C(n, m) mod p is the product of C(n_i, m_i) over the base-p digits.
    int Lucas(ll n, ll m){
        if(n < Mod && m < Mod) return C(n, m);
        return mul(Lucas(n/Mod, m/Mod), C(n%Mod, m%Mod));
    }

    // Weighted path sum from 0 to n, with steps +1 of weight 2 and +(m+1) of weight -1.
    // With i steps of the second kind there are n - i*m steps in total, which gives
    // C(n - i*m, i) orderings.
    int calc(ll n){
        int ans = 0;
        // Keep the bound in 64 bits. n / (m + 1) can exceed INT_MAX for large n,
        // and storing it in an int would silently truncate the sum.
        cs ll up = n / (m + 1);
        for(ll i = 0; i <= up; i++){
            int ret = mul(ksm(2, n - (m+1)*i), Lucas(n - i*m, i));
            if(i & 1) Dec(ans, ret);
            else Add(ans, ret);
        }
        return ans;
    }

    // Builds the factorial tables and prints calc(n+1) - calc(n) (mod Mod).
    void Solve(){
        fac[0] = fac[1] = ifac[0] = ifac[1] = 1;
        for(int i = 2; i < Mod; i++) fac[i] = mul(fac[i-1], i);
        ifac[Mod-1] = ksm(fac[Mod-1], Mod-2);   // Fermat inverse (Mod is prime)
        for(int i = Mod-2; i >= 2; i--) ifac[i] = mul(ifac[i+1], i+1);
        cout << dec(calc(n+1), calc(n));
    }
}
int main(){
    scanf("%lld%lld", &n, &m);
    // m == 1 is answered directly; the Poly reduction assumes m >= 2.
    if(m == 1){
        puts("1");
        return 0;
    }
    // Large m: the Binom sum has only n / (m + 1) + 1 terms.
    // Otherwise use polynomial exponentiation, whose NTT lengths stay within 2^16.
    const bool large_m = m >= (1 << 15);
    if(large_m){
        Binom::Solve();
    }else{
        Poly::Solve();
    }
    return 0;
}