题意:
萨博有个方程式:
求有多少组x在满足条件的情况下使得等式成立,答案对
取模。
题解:
以下为邦邦老师的ppt,这一部分讲的挺清晰的:
dp状态的解释也很清晰:
这样的做法复杂度是
的。
我自己做的时候一开始不知道得到dp这个dp状态之后如何得到这一位不全选1的解对答案的贡献,后来想明白了:假设这一位有
个
,那么
对答案的贡献是
,pos指的是最高位的位数。能做贡献的前提是j的奇偶性和k的这一位是一样的。
为什么呢。因为
相当于是总的方案数,然后我们因为有一个k的限制,所以可以让其中一个高位本可选1但是选了0的数字去和其他的凑,也就是说其他的定了它也就定了,不能乱动。这和它原本可以选择
种方案相比,变成了只有一种选择。
想清楚这一点之后我们发现,实际上不用记录选择了几个1,而只要记录选择了奇数个1还是偶数个1,是否为全选即可得到贡献,这样优化了一层记录选择1的个数的循环。
表示考虑前i个1,用了偶/奇数个1,未全选/全选了1的方案数。
这样复杂度可以优化为
。这样数据范围的
就可以出到
啦~
未优化的代码:
#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const ll mod = 1e9 + 7;
ll dp[55][55];
int n;
ll k;
ll m[55];
ll qm(ll a, ll b){
ll res = 1;
while(b){
if(b&1) res = res*a%mod;
a = a*a%mod;
b >>= 1;
}return res;
}
ll sol(int pos){
//cout<<"pos:"<<pos<<endl;
if(pos < 0) return 1;
ll res = 0;
memset(dp, 0, sizeof dp);
dp[0][0] = 1;
int cur = 0;
for(int i = 1; i <= n; ++i){
if(m[i]>>pos&1){
cur++;
dp[cur][0] = dp[cur-1][0]*(1LL<<pos)%mod;
for(int j = 1; j <= cur; ++j){
dp[cur][j] = (dp[cur-1][j]*(1LL<<pos)%mod + dp[cur-1][j-1]*(m[i]-(1LL<<pos)+1)%mod)%mod;
}
}else{
for(int j = 0; j <= cur; ++j) dp[cur][j] = dp[cur][j]*(m[i]+1)%mod;
}
}
ll inv = qm(1LL<<pos, mod-2);
for(int i = (k>>pos&1); i < cur; i+=2){
res = (res + dp[cur][i]*inv)%mod;
}
//cout<<"res:"<<res<<endl;
if((cur&1) == (k>>pos&1)){
for(int i = 1; i <= n; ++i){
if(m[i]>>pos&1) m[i] ^= (1LL<<pos);
}
return (res + sol(pos-1))%mod;
}else return res;
}
int main()
{
while(scanf("%d%lld", &n, &k)!=EOF){
for(int i = 1; i <= n; ++i) scanf("%lld", &m[i]);
ll ans = sol(31);
ans = (ans + mod)%mod;
cout<<ans<<endl;
}
}
优化后的代码
#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const ll mod = 1e9 + 7;
ll dp[55][2][2];
int n;
ll k;
ll m[55];
ll qm(ll a, ll b){
ll res = 1;
while(b){
if(b&1) res = res*a%mod;
a = a*a%mod;
b >>= 1;
}return res;
}
ll sol(int pos){
//cout<<"pos:"<<pos<<endl;
if(pos < 0) return 1;
ll res = 0;
memset(dp, 0, sizeof dp);
dp[0][0][1] = 1;
int cur = 0;
for(int i = 1; i <= n; ++i){
if(m[i]>>pos&1){
cur++;
dp[cur][0][0] =( (m[i]-(1LL<<pos)+1)*dp[cur-1][1][0]%mod + (1LL<<pos)*(dp[cur-1][0][0]+dp[cur-1][0][1])%mod )%mod;
dp[cur][0][1] = (m[i]-(1LL<<pos)+1)*dp[cur-1][1][1]%mod;
dp[cur][1][0] = ( (m[i]-(1LL<<pos)+1)*dp[cur-1][0][0]%mod + (1LL<<pos)*(dp[cur-1][1][0] + dp[cur-1][1][1])%mod )%mod;
dp[cur][1][1] = (m[i]-(1LL<<pos)+1)*dp[cur-1][0][1]%mod;
}else{
dp[cur][0][0] = (dp[cur][0][0]*(m[i]+1))%mod;
dp[cur][1][0] = (dp[cur][1][0]*(m[i]+1))%mod;
dp[cur][0][1] = (dp[cur][0][1]*(m[i]+1))%mod;
dp[cur][1][1] = (dp[cur][1][1]*(m[i]+1))%mod;
}
}
ll inv = qm(1LL<<pos, mod-2);
res = dp[cur][k>>pos&1][0]*inv%mod;
if((cur&1) == (k>>pos&1)){
for(int i = 1; i <= n; ++i){
if(m[i]>>pos&1) m[i] ^= (1LL<<pos);
}
return (res + sol(pos-1))%mod;
}else return res;
}
int main()
{
while(scanf("%d%lld", &n, &k)!=EOF){
for(int i = 1; i <= n; ++i) scanf("%lld", &m[i]);
ll ans = sol(31);
ans = (ans + mod)%mod;
cout<<ans<<endl;
}
}