链接
题目描述
给定一个n 行m 列的矩阵。
求矩阵中上下对称且左右对称的正方形子矩阵的个数。
样例输入
5 5
4 2 4 4 4
3 1 4 4 3
3 5 3 3 3
3 1 5 3 3
4 2 1 2 4
样例输出
27
思路
将矩阵上下颠倒一次,左右颠倒一次,再求出每个矩阵的每个区间矩阵的hash值,二分正方形的个数,比较hash值就可以判断是否是上下左右对称。
一个合法正方形满足在该正方形下的正方形都是合法的
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define zhi1 1000000007ull
#define zhi2 1000000009ull
#define ull unsigned long long
using namespace std;
int n, m, tot, lx, ly, hash1, hash2, hash3, ans, l, r;
ull hasha[1005][1005], hashb[1005][1005], hashc[1005][1005];
ull a[1005][1005], b[1005][1005], c[1005][1005], t1[10005], t2[10005];
bool check(int rx, int ry, int dis) {
lx = rx - dis + 1;
ly = ry - dis + 1;
hash1 = hasha[rx][ry] - hasha[rx][ly - 1] * t1[dis] - hasha[lx - 1][ry] * t2[dis] + hasha[lx - 1][ly - 1] * t1[dis] * t2[dis];//求出原矩阵的hash值
int tmp = rx;
rx = n - (rx - dis);
lx = rx - dis + 1;
ly = ry - dis + 1;
hash2 = hashc[rx][ry] - hashc[rx][ly - 1] * t1[dis] - hashc[lx - 1][ry] * t2[dis] + hashc[lx - 1][ly - 1] * t1[dis] * t2[dis];//求出上下折叠后矩阵的hash值
rx = tmp;
ry = m - (ry - dis);
lx = rx - dis + 1;
ly = ry - dis + 1;
hash3 = hashb[rx][ry] - hashb[rx][ly - 1] * t1[dis] - hashb[lx - 1][ry] * t2[dis] + hashb[lx - 1][ly - 1] * t1[dis] * t2[dis];//求出左右折叠矩阵的hash值
if (hash1 == hash2 && hash1 == hash3) return 1;
return 0;
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; ++i)
{
for(int j = 1; j <= m; ++j)
{
scanf("%d", &a[i][j]);
b[i][m - j + 1] = a[i][j];
c[n - i + 1][j] = a[i][j];
}
}
t1[0] = 1ull;
for(int i = 1; i <= n; ++i)
t1[i] = t1[i - 1] * zhi1;
t2[0] = 1ull;
for(int i = 1; i <= m; ++i)
t2[i] = t2[i - 1] * zhi2;
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= m; ++j)
{
hasha[i][j] = hasha[i][j - 1] * zhi1 + a[i][j];
hashb[i][j] = hashb[i][j - 1] * zhi1 + b[i][j];
hashc[i][j] = hashc[i][j - 1] * zhi1 + c[i][j];
}
t1[i] = t1[i - 1] * zhi1;
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j)
{
hasha[i][j] += hasha[i - 1][j] * zhi2;
hashb[i][j] += hashb[i - 1][j] * zhi2;
hashc[i][j] += hashc[i - 1][j] * zhi2;
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j)
{
ans = 0;
l = 1;
r = min(min(i, n - i + 1), min(j, m - j + 1));
while(l <= r)
{
int mid = (l + r) >> 1;
if(i - mid + 1 < 1 || i + mid - 1 > n || j - mid + 1 < 1 || j + mid - 1 > m) {
r = mid - 1;
continue;
}
if(check(i + mid - 1, j + mid - 1, mid * 2 - 1)) {
ans = mid;
l = mid + 1;
}
else r = mid - 1;
}//正方形长度为奇数时
tot += ans;
ans = 0;
l = 1;
r = min(min(i, n - i), min(j, m - j));
while(l <= r)
{
int mid = (l + r) >> 1;
if(i - mid + 1 < 1 || i + mid > n || j - mid + 1 < 1 || j + mid > m) {
r = mid - 1;
continue;
}
if(check(i + mid, j + mid, mid * 2)) {
ans = mid;
l = mid + 1;
}
else r = mid - 1;
}//正方形长度为偶数时
tot += ans;
}
printf("%d", tot);
}