版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_37025443/article/details/83011285
题意:
给你两个矩阵A,B,
A是n*p,B是p*m,B是一个只有0,1组成的矩阵,Aij<65536
C=A*B,让你求出C的里面所有元素的异或和
解析:
官方的标解是分块,每8个分一组。
例如对于A,每行行每8个分成一组,对于B,每一列每8个分成一组,
定义组数为x=p/8+(p%8)1:0
那么现在A就变成了n*x,B变成x*m
现在我们需要解决的就是当分完块的A,B相乘时,对应组的乘积
那么对于这一个我们就可以预处理,因为B的每一列中,每8个一组,那么每一组的情况只有256种,
我们就可以把A的每一组都对应求在256种情况下,每一种情况的值。这个可以提前打表出来。
那么预处理的复杂度就是O(n*p*256)
然后最后相乘的复杂度就变成了O(n*m*p/8) ,会达到1e8
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int MAXN = 4096+100;
int n,p,m;
int a[MAXN][80];
int b[80][MAXN];
int d[MAXN][20][257];
int e[20][MAXN];
int main()
{
scanf("%d%d%d",&n,&p,&m);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=p;j++)
{
scanf("%x",&a[i][j]);
}
}
for(int i=1;i<=m;i++)
{
for(int j=1;j<=p;j++)
{
scanf("%01d",&b[j][i]);
}
}
int btr=p/8+(p%8?1:0);
for(int i=1;i<=n;i++)
{
for(int w=1;w<=btr;w++)
{
for(int j=0;j<256;j++)
{
int tmp=j;
for(int k=1;k<=8;k++)
{
if(tmp&1) d[i][w][j]+=a[i][(w-1)*8+k];
tmp>>=1;
}
}
}
}
for(int i=1;i<=m;i++)
{
for(int j=1;j<=btr;j++)
{
for(int k=1;k<=8;k++)
{
e[j][i]|=(b[(j-1)*8+k][i]<<(k-1));
}
}
}
int ans=0;
for(int i=1;i<=n;i++)
{
for(int j=1;j<=m;j++)
{
int res=0;
for(int k=1;k<=btr;k++)
{
res+=d[i][k][e[k][j]];
}
ans^=res;
}
}
printf("%d\n",ans);
return 0;
}
我自己那时候想的方法也过了。。复杂度是O(n*m*p/4)达到2e8,因为最暴力的O(n*m*p)都能过,1e9...
我就是把每一个数按照二进制拆出来。因为Aij最大只有2^16,那么一个A最多只能被分成16个矩阵
第一矩阵A1ij就表示,Aij的二进制第1位;第二个矩阵A2ij,表示Aij的二进制第二位..........
那么我们就把A1ij*B,然后把各个位置的进位保存在inc[][]里面,因为两个二进制矩阵相乘复杂度就是O(n*m)
对于进位,我们只需要他们相乘统计1的时候,加上去就可以了,然后在更新当前产生的新的进位
最后我们就需要把超过16位的进位进上去,只需要遍历一遍inc这个数组,复杂度O(n*m)
写的时候,找了一个BUG半天,发现存拆出来的数组里面的元素最大是有64位的,因为p<=64
所以需要用ull来存。。。。。。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int MAXN = 4096+10;
int n,p,m;
ull a[20][MAXN];
int tmp[80];
ull b[MAXN];
char str[80];
int inc[MAXN][MAXN];
int main()
{
scanf("%d%d%d",&n,&p,&m);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=p;j++)
{
scanf("%x",&tmp[j]);
}
for(int j=1;j<=16;j++)
{
for(int k=1;k<=p;k++)
{
a[j][i]=a[j][i]<<1;
if(tmp[k]&1) a[j][i]|=1;
tmp[k]=tmp[k]>>1;
}
}
}
for(int i=1;i<=m;i++)
{
getchar();
scanf("%s",str);
for(int j=0;j<p;j++)
{
b[i]=b[i]<<1;
if(str[j]=='1')
{
b[i]|=1;
}
}
}
ull w;
int num;
int coun=1;
int flag;
int ans=0;
for(int i=1;i<=16;i++)
{
flag=0;
for(int j=1;j<=n;j++)
{
for(int k=1;k<=m;k++)
{
w=a[i][j]&b[k];
num=__builtin_popcountll(w);
num+=inc[j][k];
inc[j][k]=num>>1;
flag^=(num&1);
}
}
ans^=(flag)?coun:0;
coun=coun<<1;
}
for(int i=1;i<=n;i++)
{
for(int j=1;j<=m;j++)
{
/*int ano=coun;
while(inc[i][j])
{
ans^=((inc[i][j]&1)?ano:0);
inc[i][j]=inc[i][j]>>1;
ano<<=1;
}*/
ans^=(inc[i][j]<<16);
}
}
printf("%d\n",ans);
return 0;
}