【THUSC2017】【LOJ2978】杜老师 高斯消元

题目大意

  给你 \(l,r\),求从 \(l\)\(r\)\(r-l+1\) 个数中能选出多少个不同的子集,满足子集中所有的数的乘积是一个完全平方数。

  对 \(998244353\) 取模。

  \(1\leq l,r\leq {10}^7\)

  有 \(100\) 组数据,\(\sum r-l+1\leq 6\times {10}^7\)

题解

  对于每个数,求出这个数中包含了哪些出现次数为奇数的质数。

  那么就可以直接高斯消元,记矩阵的秩为 \(r\),答案就是 \(2^{r-l+1-r}\)。可以用 bitset 优化。

  时间复杂度为 \(O(\frac{n\pi(n)^2}{w})\)

  可以发现,一个数最多有一个 \(>\sqrt r\) 的质因子。那么对于两个最大值因子相同的数,可以让第二个数的状态异或上第一个数的状态,这样第二个数的状态就只有 \(\leq \sqrt r\) 的质因子了。

  这样就可以让矩阵的列的数量降低到 \(\pi(\sqrt n)\)

  但是还是过不了这题。

  可以发现,当 \(r-l+1\) 足够大的时候就可以认为这个矩阵满秩了。在本题中,当 \(r-l+1>6000\) 的时候就可以不用高斯消元直接求出答案了。

  时间复杂度:\(O(\frac{T\times 6000\times \pi(\sqrt n)^2}{w})\)

  这个复杂度很松,实际跑起来非常快。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<functional>
#include<cmath>
#include<vector>
#include<assert.h>
#include<bitset>
using namespace std;
using std::min;
using std::max;
using std::swap;
using std::sort;
using std::reverse;
using std::random_shuffle;
using std::lower_bound;
using std::upper_bound;
using std::unique;
using std::vector;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef std::pair<int,int> pii;
typedef std::pair<ll,ll> pll;
void open(const char *s){
#ifndef ONLINE_JUDGE
    char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
void open2(const char *s){
#ifdef DEBUG
    char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
#endif
}
int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
const ll p=998244353;
const int N=10000010;
const int n=10000000;
const int sqrtn=3162;
const int size=446;
typedef bitset<500> arr;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
int c[N],d[N],b[N],pri[N],cnt;
void sieve()
{
    c[1]=1;
    for(int i=2;i<=n;i++)
    {
        if(!b[i])
        {
            pri[++cnt]=i;
            d[i]=cnt;
            c[i]=i;
        }
        for(int j=1;j<=cnt&&i*pri[j]<=n;j++)
        {
            int v=i*pri[j];
            b[v]=1;
            c[v]=c[i];
            if(i%pri[j]==0)
                break;
        }
    }
}
void init()
{
    sieve();
}
arr get(int x)
{
    while(c[x]>sqrtn)
        x/=c[x];
    arr res;
    while(x>1)
    {
        res.flip(d[c[x]]-1);
        x/=c[x];
    }
    return res;
}
int len,tot,tot2;
void solve2(int l,int r)
{
    tot=0;
    len=r-l+1;
    for(int i=1;i<=cnt;i++)
        if(r/pri[i]!=(l-1)/pri[i])
            tot++;
    ll ans=fp(2,len-tot);
    printf("%lld\n",ans);
}
arr e[size];
int insert(arr v)
{
    for(int i=0;i<size;i++)
        if(v[i])
        {
            if(e[i][i])
                v^=e[i];
            else
            {
                e[i]=v;
                return 1;
            }
        }
    return 0;
}
pii a[10000];
arr pre;
int cmp(pii a,pii b)
{
    return a.second<b.second;
}
void solve()
{
    int l,r;
    scanf("%d%d",&l,&r);
    if(r-l>6000)
    {
        solve2(l,r);
        return;
    }
    tot=0;
    tot2=0;
    len=r-l+1;
    if(l==1)
        l++;
    int m=0;
    for(int i=0;i<size;i++)
        e[i].reset();
    for(int i=l;i<=r;i++)
        a[++m]=pii(i,c[i]);
    sort(a+1,a+m+1,cmp);
    for(int i=1;i<=m;i++)
        if(a[i].second<=sqrtn)
        {
            if(tot<size)
                if(insert(get(a[i].first)))
                    tot++;
        }
        else if(i==1||a[i].second!=a[i-1].second)
        {
            tot2++;
            if(tot<size)
                pre=get(a[i].first);
        }
        else
        {
            if(tot<size)
                if(insert(get(a[i].first)^pre))
                    tot++;
        }
    ll ans=fp(2,len-tot-tot2);
    printf("%lld\n",ans);
}
int main()
{
    open("dls");
    init();
    int t;
    scanf("%d",&t);
    while(t--)
        solve();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/ywwyww/p/10294451.html