bzoj 4309 Maishroom & Mushroom - bitset

首先说一句:bzoj 的数据有问题,存在 l>r 的情况。另外,我用了将近 1 GB 的内存……虽然做了一些优化(比如把用不到的信息统统 delete 掉),但仍然被疯狂卡内存(最终也没能通过)。为了保证复杂度,需要手动实现一个 bitset;一些本题用不到的功能(例如位移和 [])就没有实现……

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<assert.h>
// rep(i,a,b): counting loop with i running over the half-open range [a,b).
#define rep(i,a,b) for(int i=(a);i<(b);i++)
// N: size of the per-node arrays; it also fixes how many 64-bit words My_bitset uses.
#define N 50010
// LOG: second dimension of the ancestor table `up`.
#define LOG 18
// ull: shorthand for the 64-bit word type used by the bitset and the mask tables.
#define ull unsigned long long
// gc: reads one character from stdin.
#define gc getchar()
// debug(x): streams "<expression text>=<value>" to cerr (no trailing separator).
#define debug(x) cerr<<#x<<"="<<x
// sp / ln: stream suffixes meant to be appended after an output expression;
// sp appends a space, ln appends endl (which also flushes the stream).
#define sp <<" "
#define ln <<endl
using namespace std;
inline int inn()
{
	int x,ch;while((ch=gc)<'0'||ch>'9');
	x=ch^'0';while((ch=gc)>='0'&&ch<='9')
		x=(x<<1)+(x<<3)+(ch^'0');return x;
}
// use_zs[v] / use_frt[v]: set in main when some query refers to zs[v] / frt[v];
// bitsets of nodes that are never referenced are deleted before queries run.
bool use_zs[N],use_frt[N];
// up[v][k]: ancestor table filled in dfs (up[v][0] is the parent, up[v][k] is
// up[up[v][k-1]][k-1]); Log[i]: built in main as Log[i>>1]+1; d[v]: depth, with
// d[child]=d[parent]+1 and d[0] left at 0.
int up[N][LOG],Log[N],d[N];
// One stored query: tp is the query type, x/y/z are its (up to three) arguments.
struct Query{int tp,x,y,z;}q[N<<1];
// bitpre[i]: 64-bit mask with bits 0..i set; bitsuf[i]: mask with bits i..63 set.
// Both are filled by prelude_bitinf for i in 0..63.
ull bitpre[70],bitsuf[70];
// Lookup tables over 16-bit values, filled by prelude_bitinf:
// bitcnt[v] is the number of set bits of v, bitlowbit[v] the index of its lowest set bit.
int bitcnt[(1<<16)+10],bitlowbit[(1<<16)+10];
// Adjacency lists as linked edge records: e[id].to is the endpoint, e[id].pre the
// previous edge id of the same source node; h[u] is the newest edge id of u
// (0 terminates a list) and etop is the number of edge records in use.
struct edges{int to,pre;}e[N<<1];int h[N],etop;
// Appends the directed edge u -> v to the front of u's adjacency list.
// Returns the id of the newly created edge record.
inline int add_edge(int u,int v)
{
	const int id=++etop;
	e[id].to=v;
	e[id].pre=h[u];
	h[u]=id;
	return id;
}
class My_bitset{
	private:
		ull a[790];int n;
	public:
		My_bitset() { n=(N-1)/64+1,memset(a,0,sizeof(ull)*n); }
		My_bitset(const My_bitset &b) { n=b.n,memcpy(a,b.a,sizeof(ull)*n); }
		inline My_bitset& operator=(const My_bitset &b) { n=b.n;return memcpy(a,b.a,sizeof(ull)*n),*this; }
		inline My_bitset operator&(const My_bitset &b)const{ My_bitset c(*this);rep(i,0,n) c.a[i]&=b.a[i];return c; }
		inline My_bitset& operator&=(const My_bitset &b) { rep(i,0,n) a[i]&=b.a[i];return *this; }
		inline My_bitset operator|(const My_bitset &b)const{ My_bitset c(*this);rep(i,0,n) c.a[i]|=b.a[i];return c; }
		inline My_bitset& operator|=(const My_bitset &b) { rep(i,0,n) a[i]|=b.a[i];return *this; }
		inline My_bitset operator^(const My_bitset &b)const{ My_bitset c(*this);rep(i,0,n) c.a[i]^=b.a[i];return c; }
		inline My_bitset& operator^=(const My_bitset &b) { rep(i,0,n) a[i]^=b.a[i];return *this; }
		inline My_bitset operator~()const{ My_bitset c(*this);rep(i,0,n) c.a[i]=~c.a[i];return c; }
		inline bool operator==(const My_bitset &b)const{ rep(i,0,n) if(a[i]^b.a[i]) return false;return true; }
		inline bool operator!=(const My_bitset &b)const{ rep(i,0,n) if(a[i]^b.a[i]) return true;return false; }
		inline int test(int x)const{ return (a[x>>6]>>(x&63))&1ll; }
		inline void set() { rep(i,0,n) a[i]=~0ull; }
		inline void set(int x) { a[x>>6]|=1ull<<(x&63); }
		inline void set(int l,int r)
		{
			int x=l>>6,y=r>>6,s=l&63,t=r&63;
			if(x==y) { a[x]|=(s?bitpre[s-1]:0)^bitpre[t];return; }
			rep(i,x+1,y) a[i]=~0ull;a[x]|=bitsuf[s],a[y]|=bitpre[t];
		}
		inline void reset() { rep(i,0,n) a[i]=0; }
		inline void reset(int x) { a[x>>6]&=(~0ull)^(1ull<<(x&63)); }
		inline void reset(int l,int r)
		{
			int x=l>>6,y=r>>6,s=l&63,t=r&63;
			if(x==y) { a[x]&=(s?bitpre[s-1]:0)|(t<63?bitsuf[t+1]:0);return; }
			rep(i,x+1,y) a[i]=0ull;a[x]&=(s?bitpre[s-1]:0),a[y]&=(t<63?bitsuf[t+1]:0);
		}
		inline void flip(int x) { a[x>>6]^=1ull<<(x&63); }
		inline int count()const
		{
			int ans=0;rep(i,0,n) ans+=
				bitcnt[a[i]&bitpre[15]]+bitcnt[(a[i]>>16)&bitpre[15]]+
				bitcnt[(a[i]>>32)&bitpre[15]]+bitcnt[a[i]>>48];
			return ans;
		}
		inline bool any(int p=0)const{ rep(i,p>>6,n) if(a[i]) return true;return false; }
		inline bool all(int p=0)const{ rep(i,p>>6,n) if(a[i]<(~0ull)) return false;return true; }
		inline bool none(int p=0)const{rep(i,p>>6,n) if(a[i]) return false;return true; }
		inline int lowbit(int p=0)const
		{
			ull b=0;
			rep(i,p>>6,n) if(a[i])
			{
				if((b=(a[i]&bitpre[15]))) return (i<<6)+bitlowbit[b];
				if((b=((a[i]>>16)&bitpre[15]))) return (i<<6)+16+bitlowbit[b];
				if((b=((a[i]>>32)&bitpre[15]))) return (i<<6)+32+bitlowbit[b];
				if((b=(a[i]>>48))) return (i<<6)+48+bitlowbit[b];
			}
			return -1;
		}
		inline void show()const { rep(i,0,N) cerr<<test(i);cerr ln; }
}*zs[N],*frt[N],*val[N<<1],tmp;
inline int prelude_bitinf()
{
	//bitpre,bitsuf,bitcnt,bitlowbit
	for(int i=0;i<63;i++) bitpre[i]=(1ull<<(i+1))-1;bitpre[63]=~0ull;
	for(int i=1;i<64;i++) bitsuf[i]=(~0ull)^bitpre[i-1];bitsuf[0]=~0ull;
	rep(i,0,1<<16) rep(j,0,16) bitcnt[i]+=((i>>j)&1);
	rep(i,1,1<<16) rep(j,0,16) if((i>>j)&1) { bitlowbit[i]=j;break; }
	return 0;
}
inline int getLCA(int x,int y)
{
	if(d[x]<d[y]) swap(x,y);
	for(int i=Log[d[x]];i>=0;i--)
		if(d[up[x][i]]>=d[y]) x=up[x][i];
	if(x==y) return x;
	for(int i=Log[d[x]];i>=0;i--)
		if(up[x][i]^up[y][i]) x=up[x][i],y=up[y][i];
	return up[x][0];
}
// Depth-first traversal from x (parent fa) that fills, for every visited node:
//   up[x][*] and d[x] - ancestor table and depth,
//   zs[x]             - bitset holding x and all nodes below it (bit index node-1),
//   frt[x]            - copy of frt[fa] with x added, i.e. the nodes on the way to x.
// frt[fa] must already exist when the function is entered. Always returns 0.
int dfs(int x,int fa=0)
{
	up[x][0]=fa;
	d[x]=d[fa]+1;
	const int top=Log[d[x]];
	for(int k=1;k<=top;++k) up[x][k]=up[up[x][k-1]][k-1];

	zs[x]=new My_bitset;
	frt[x]=new My_bitset(*frt[fa]);
	frt[x]->set(x-1);
	zs[x]->set(x-1);

	for(int id=h[x];id!=0;id=e[id].pre)
	{
		const int child=e[id].to;
		if(child==fa) continue;   // do not walk back up the edge we came from
		dfs(child,x);
		*zs[x]|=*zs[child];       // fold the child's subtree set into ours
	}
	return 0;
}
int main()
{
	prelude_bitinf();int n=inn(),cnt=1,u,v;
	frt[0]=new My_bitset,val[1]=new My_bitset,val[1]->set(0,n-1);
	for(int i=2;i<=n;i++)
		Log[i]=Log[i>>1]+1,u=inn(),v=inn(),add_edge(u,v),add_edge(v,u);
	dfs(1);int qc=inn();
	for(int i=1;i<=qc;i++)
		if((q[i].tp=inn())==1) q[i].x=inn();
		else if(q[i].tp==2) q[i].x=inn(),q[i].y=inn();
		else if(q[i].tp==3) q[i].x=inn(),q[i].y=inn(),use_zs[q[i].y]=1;
		else if(q[i].tp==4) q[i].x=inn(),q[i].y=inn(),use_zs[q[i].y]=1;
		else if(q[i].tp==5) q[i].x=inn(),q[i].y=inn(),q[i].z=inn(),use_frt[q[i].y]=use_frt[q[i].z]=1;
		else if(q[i].tp==6) q[i].x=inn(),q[i].y=inn(),q[i].z=inn(),use_frt[q[i].y]=use_frt[q[i].z]=1;
		else if(q[i].tp==7) q[i].x=inn(),q[i].y=inn(),q[i].z=inn();
		else if(q[i].tp==8) q[i].x=inn(),q[i].y=inn(),q[i].z=inn();
		else q[i].x=inn(),q[i].y=inn();
	for(int i=1;i<=n;i++) if(!use_zs[i]) delete zs[i];
	for(int i=1;i<=n;i++) if(!use_frt[i]) delete frt[i];
	for(int i=1;i<=qc;i++)
	{
		int tp=q[i].tp,x=q[i].x,y=q[i].y,z=q[i].z;
		if(tp==1) val[++cnt]=new My_bitset(*val[x]);
		else if(tp==2) (*val[x])|=*val[y];
		else if(tp==3) (*val[x])&=~(*zs[y]);
		else if(tp==4) (*val[x])&=*zs[y];
		else if(tp==5) tmp.reset(),tmp.set(getLCA(y,z)-1),(*val[x])&=~((*frt[y])^(*frt[z])^tmp);
		else if(tp==6) tmp.reset(),tmp.set(getLCA(y,z)-1),(*val[x])&=(*frt[y])^(*frt[z])^tmp;
		else if(tp==7) (y<=z?val[x]->reset(y-1,z-1),0:val[x]->reset();
		else if(tp==8) (y>1?val[x]->reset(0,y-2),0:0),(z<n?val[x]->reset(z,n-1),0:0);
		else{
			tmp=*val[x];int ans=0,t=0;
			while((t=tmp.lowbit(min(t,n-1)))>=0) tmp.reset(t,min(n-1,t+y)),t+=y,ans++;
			printf("%d\n",ans);
		}
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Mys_C_K/article/details/82807212