解题思路
这道题实际上与上题类似,不过是二维的。
我们肯定要先找正方形的对称轴,但这样不好处理,所以我们可以先把分别正方形左右对称和上下对称,这个要用二维Hash 来提效率。具体的处理就跟二维前缀和类似,行列分别做就行了,翻转正方形也是一样处理,判断对称的时候只要判断三个正方形是否相同。
注意在枚举中心点的时候要分类讨论奇偶情况:
-
对于长度为奇数的正方形,以格子(一个 1 ∗ 1 1∗1 1∗1的正方形)为中心二分最远符合条件的长度。。
-
对于长度为偶数的正方形,以格点(就是一个点)为中心二分最远符合条件的长度。。
PS:最后加上 n ∗ m n∗m n∗m 因为每个 1 ∗ 1 1 ∗ 1 1∗1 格子也算一个正方形。
代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <iomanip>
#include <cmath>
using namespace std;
typedef unsigned long long ull;
ull p1=131,p2=313;
ull g[1010][1010],xturn[1010][1010],yturn[1010][1010];
int ans,n,m,a[1010][1010];
struct c {
ull x,y;
} base[1010];
void yu()
{
for(int i=1; i<=n; i++) {
for(int j=1; j<=m; j++)
xturn[n-i+1][j]=yturn[i][m-j+1]=a[i][j];
}
}
void hash()//二维Hash
{
base[0].x=1,base[0].y=1;
for(int i=1; i<=max(n,m); i++) {
base[i].x=base[i-1].x*p1;
base[i].y=base[i-1].y*p2;
}
for(int i=1; i<=n; i++) {
for(int j=1; j<=m; j++) {
a[i][j]+=a[i-1][j]*p1;
xturn[i][j]+=xturn[i-1][j]*p1;
yturn[i][j]+=yturn[i-1][j]*p1;
}
}
for(int i=1; i<=n; i++) {
for(int j=1; j<=m; j++) {
a[i][j]+=a[i][j-1]*p2;
xturn[i][j]+=xturn[i][j-1]*p2;
yturn[i][j]+=yturn[i][j-1]*p2;
}
}
}
bool check(int x,int y,int len)
{
int v1,v2,v3,y1,x1;
if(x<len||x>n||y<len||y>m) return 0;
v1=a[x][y]-a[x-len][y]*base[len].x-a[x][y-len]*base[len].y+a[x-len][y-len]*base[len].y*base[len].x;
x1=n-(x-len);//上下翻转
v2=xturn[x1][y]-xturn[x1-len][y]*base[len].x-xturn[x1][y-len]*base[len].y+xturn[x1-len][y-len]*base[len].y*base[len].x;
y1=m-(y-len);//左右翻转
v3=yturn[x][y1]-yturn[x-len][y1]*base[len].x-yturn[x][y1-len]*base[len].y+yturn[x-len][y1-len]*base[len].y*base[len].x;
if(v1==v2&&v2==v3)
return 1;
else return 0;
}
void work()
{
int t=0,l=0,r=max(n,m)+1,mid=0,x,y;
//长度为奇数
for(int i=1; i<n; i++) {
//枚举中心点
for(int j=1; j<m; j++) {
t=0,l=0,r=max(n,m)+1,mid=0;
while(l<r) {
//二分边长
mid=(l+r+1)>>1;
x=mid+i,y=mid+j;//右下角
if(check(x,y,mid*2)) {
t=mid;
l=mid;
}
else r=mid-1;
}
ans+=t;
}
}
//长度为奇数
for(int i=1; i<=n; i++) {
//枚举中心点
for(int j=1; j<=m; j++) {
t=0,l=0,r=max(n,m)+1,mid=0;
while(l<r) {
mid=(l+r+1)>>1;//二分边长
x=mid+i,y=mid+j;//右下角
if(check(x,y,mid*2+1)) {
t=mid;
l=mid;
}
else r=mid-1;
}
ans+=t;
}
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1; i<=n; i++)
for(int j=1; j<=m; j++)
scanf("%d",&a[i][j]);
yu();
hash();
work();
ans+=m*n;
printf("%d",ans);
}
/*
5 5
4 2 4 4 4
3 1 4 4 3
3 5 3 3 3
3 1 5 3 3
4 2 1 2 4
*/