题意:
给定一个由小写字母和'?'组成的字符串S,以及一个小写字母串T。把S中的每个'?'替换成任意小写字母后,T在S中最多能出现多少次?出现位置可以重叠。
题解:
这种肯定是要DP的。
怎么DP呢?
AC自动机上的DP问题很多,这个也可以用AC自动机。
dp[i][j]表示处理完S的前i个字符、当前位于AC自动机的状态j时,T被完整匹配的最大次数(状态不可达时记为-1)。每走到终止状态(即刚好匹配完整个T的状态)时,次数加1。
当i是问号时枚举26个字母转移,不是问号直接转移到这个字符。
由于dp[i]只依赖dp[i-1],用滚动数组即可把空间降到O(|T|);总时间复杂度为O(26·|S|·|T|)。
因为只有一个模式串,也可以用KMP:先求出T的next(失配)数组,再据此手动构造完整的状态转移表ch[j][c]。它的作用和AC自动机的转移表相同。
代码:
AC自动机版本:
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <bitset>
#include <map>
#include <vector>
#include <stack>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <cmath>
#include <ctime>
// debug(x) prints "x = value" only in local builds (-DLOCAL); otherwise it
// expands to a harmless expression statement.
#ifdef LOCAL
#define debug(x) cout<<#x<<" = "<<(x)<<endl;
#else
#define debug(x) 1;
#endif
// In-place update: x becomes max(x, y) / min(x, y). Fully parenthesized so the
// macro behaves like a single expression wherever it is used.
#define chmax(x,y) ((x)=max((x),(y)))
#define chmin(x,y) ((x)=min((x),(y)))
// Segment-tree child arguments: (node id, left bound, right bound).
#define lson id<<1,l,mid
#define rson id<<1|1,mid+1,r
// Lowest set bit of x. The argument and the result are parenthesized.
// Unparenthesized, lowbit(5 + 1) == 2 would parse as 6 & ((-5 + 1) == 2).
#define lowbit(x) ((x)&-(x))
#define mp make_pair
#define pb push_back
#define fir first
#define sec second
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, int> pii;
const int MOD = 1e9 + 7;
const double PI = acos (-1.);
const double eps = 1e-10;
const int INF = 0x3f3f3f3f;
const ll INFLL = 0x3f3f3f3f3f3f3f3f;
const int MAXN = 2e5 + 5;
char s[MAXN], t[MAXN];          // S and T, stored 1-indexed (read into s + 1, t + 1)
// Alphabet size of the automaton. ACautomata::idx maps 'a'..'z' to 0..25, and
// the DP only ever follows those 26 edges. This assumes S/T hold only
// lowercase letters (plus '?' in S). A larger value only costs memory and
// time: ch is MAXNODE x SIGMA_SIZE ints (~21 MB at 26 vs ~104 MB at 130), and
// getFail scans SIGMA_SIZE edges per node.
const int SIGMA_SIZE = 26;
const int MAXNODE = 2e5 + 100;  // trie nodes: the root plus one per character of T
const int MAXS = 150 + 10;
struct ACautomata {
int ch[MAXNODE][SIGMA_SIZE];
int f[MAXNODE]; // fail函数
int val[MAXNODE]; // 每个字符串的结尾结点都有一个非0的val
int last[MAXNODE]; // 输出链表的下一个结点
//int cnt[MAXS];
int sz;
void init() {
sz = 1;
memset(ch[0], 0, sizeof(ch[0]));
// memset(cnt, 0, sizeof(cnt));
}
// 字符c的编号
inline int idx(char c) {
return c - 'a';
}
// 插入字符串。v必须非0
void insert(char *s, int v) {
int u = 0, n = strlen(s);
for(int i = 0; i < n; i++) {
int c = idx(s[i]);
if(!ch[u][c]) {
memset(ch[sz], 0, sizeof(ch[sz]));
val[sz] = 0;
ch[u][c] = sz++;
}
u = ch[u][c];
}
val[u] = v;
}
// 计算fail函数
void getFail() {
queue<int> q;
f[0] = 0;
// 初始化队列
for(int c = 0; c < SIGMA_SIZE; c++) {
int u = ch[0][c];
if(u) {
f[u] = 0;
q.push(u);
last[u] = 0;
}
}
// 按BFS顺序计算fail
while(!q.empty()) {
int r = q.front();
q.pop();
for(int c = 0; c < SIGMA_SIZE; c++) {
int u = ch[r][c];
if(!u) {
ch[r][c]=ch[f[r]][c];
continue;
}
q.push(u);
int v = f[r];
while(v && !ch[v][c]) v = f[v];
f[u] = ch[v][c];
last[u] = val[f[u]] ? f[u] : last[f[u]];
}
}
}
} ac;
// Rolling DP rows: d[i & 1][state] = most complete occurrences of T within the
// first i characters of S when the automaton sits in `state`; -1 = unreachable.
int d[2][MAXN];
int main() {
#ifdef LOCAL
    freopen ("input.txt", "r", stdin);
#endif
    scanf("%s %s", s + 1, t + 1);
    ac.init();
    ac.insert(t + 1, 1);
    ac.getFail();
    const int n = strlen(s + 1);
    const int m = strlen(t + 1);
    // With a single inserted pattern, trie node k is the prefix of length k.
    // So landing in state m means T has just been matched completely.
    memset(d, -1, sizeof(d));
    d[0][0] = 0;
    for (int i = 1; i <= n; ++i) {
        int *cur = d[i & 1];
        const int *prev = d[(i & 1) ^ 1];
        for (int state = 0; state <= m; ++state) cur[state] = -1;
        // '?' may become any of the 26 letters; a fixed letter allows one edge only.
        const bool wild = (s[i] == '?');
        const int lo = wild ? 0 : s[i] - 'a';
        const int hi = wild ? 25 : lo;
        for (int state = 0; state <= m; ++state) {
            const int base = prev[state];
            if (base == -1) continue;
            for (int c = lo; c <= hi; ++c) {
                const int nxt = ac.ch[state][c];
                const int cand = base + (nxt == m ? 1 : 0);
                if (cand > cur[nxt]) cur[nxt] = cand;
            }
        }
    }
    int best = 0;
    for (int state = 0; state <= m; ++state) best = max(best, d[n & 1][state]);
    printf("%d\n", best);
    return 0;
}
KMP版本:
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <bitset>
#include <map>
#include <vector>
#include <stack>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <cmath>
#include <ctime>
// debug(x) prints "x = value" only in local builds (-DLOCAL); otherwise it
// expands to a harmless expression statement.
#ifdef LOCAL
#define debug(x) cout<<#x<<" = "<<(x)<<endl;
#else
#define debug(x) 1;
#endif
// In-place update: x becomes max(x, y) / min(x, y). Fully parenthesized so the
// macro behaves like a single expression wherever it is used.
#define chmax(x,y) ((x)=max((x),(y)))
#define chmin(x,y) ((x)=min((x),(y)))
// Segment-tree child arguments: (node id, left bound, right bound).
#define lson id<<1,l,mid
#define rson id<<1|1,mid+1,r
// Lowest set bit of x. The argument and the result are parenthesized.
// Unparenthesized, lowbit(5 + 1) == 2 would parse as 6 & ((-5 + 1) == 2).
#define lowbit(x) ((x)&-(x))
#define mp make_pair
#define pb push_back
#define fir first
#define sec second
using namespace std;
using ll = long long;
using ull = unsigned long long;
using pii = pair<ll, int>;          // note: the first member is ll, not int
const int MOD = 1000000007;
const double PI = acos(-1.0);
const double eps = 1e-10;
const int INF = 0x3f3f3f3f;
const ll INFLL = 0x3f3f3f3f3f3f3f3fLL;
const int MAXN = 200005;            // room for a 1-indexed string plus its terminator
char s[MAXN];                       // S, stored from s[1]
char t[MAXN];                       // T, stored from t[1]
int f[MAXN];                        // KMP failure table of T, filled by getFail
// Builds the KMP failure table for the 1-indexed pattern P[1..m] into the
// global f. f[i] is the length of the longest proper prefix of P[1..i] that
// is also a suffix of it (f[0] = f[1] = 0).
void getFail(char *P, int m) {
    f[0] = 0;
    f[1] = 0;
    for (int i = 2; i <= m; ++i) {
        // Start from the border of P[1..i-1] and shrink it until P[i] extends it.
        int k = f[i - 1];
        while (k > 0 && P[k + 1] != P[i]) k = f[k];
        f[i] = (P[k + 1] == P[i]) ? k + 1 : 0;
    }
}
int ch[MAXN][26];
int d[2][MAXN];
int main() {
#ifdef LOCAL
freopen ("input.txt", "r", stdin);
#endif
scanf("%s %s", s + 1, t + 1);
int n = strlen(s + 1), m = strlen(t + 1);
getFail(t, m);
memset(d, -1, sizeof(d));
d[0][0] = 0;
for (int i = 0; i <= m; i++) {
for (int j = 0; j < 26; j++) {
if(t[i + 1] == j + 'a') ch[i][j] = i + 1;
else ch[i][j] = ch[f[i]][j];
}
}
for (int i = 1; i <= n; i++) {
for (int j = 0; j <= m; j++) d[i & 1][j] = -1;
for (int j = 0; j <= m; j++) {
if (d[i & 1 ^ 1][j] == -1) continue;
if (s[i] == '?') {
for (int k = 0; k < 26; k++)
d[i & 1][ch[j][k]] = max(d[i & 1 ^ 1][j] + (ch[j][k] == m) , d[i & 1][ch[j][k]]);
} else
d[i & 1][ch[j][s[i] - 'a']] = max(d[i & 1 ^ 1][j] + (ch[j][s[i] - 'a'] == m) , d[i & 1][ch[j][s[i] - 'a']]);
}
}
int ans = 0;
for (int i = 0; i <= m; i++) ans = max(ans, d[n & 1][i]);
printf("%d\n", ans);
return 0;
}