题意
设L为字符串x和y的lcs长度,求x的不连续子串中能与y匹配的子串数。
思路
- lcs[i][j]: x[1 ~ i] 和 y[1 ~ j] 的最长公共子序列长度
- dp[i][j]: x[1 ~ i] 和 y[1 ~ j] 范围内,x[1 ~ i] 的长度为 lcs[i][j] 的不连续子串中能与 y[1 ~ j] 匹配的子串数。
- 当 x[i] != y[j] 时,对长度为 lcs[i][j] 的字串简单容斥:dp[i][j] = dp' [i][j - 1] + dp' [i - 1][j] - dp' [i - 1][j - 1],如果 lcs[a][b] < lcs[i][j],dp' [a][b] = 0。
- 当 x[i] = y[j] 时,dp[i][j] = dp[i - 1][j - 1] + dp' [i - 1][j]. 当 lcs[i - 1][j] = lcs[i][j] 时,dp' [i - 1][j] = dp[i - 1][j]. dp[i][j] 包括含有x[i] 的子串数(dp[i - 1][j - 1])和不含 x[i] 的子串数(dp' [i - 1][j])。 此处体现问题的不对称性。
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #define INF 0x3f3f3f3f #define rep0(i, n) for (int i = 0; i < n; i++) #define rep1(i, n) for (int i = 1; i <= n; i++) #define rep_0(i, n) for (int i = n - 1; i >= 0; i--) #define rep_1(i, n) for (int i = n; i > 0; i--) #define MAX(x, y) (((x) > (y)) ? (x) : (y)) #define MIN(x, y) (((x) < (y)) ? (x) : (y)) #define mem(x, y) memset(x, y, sizeof(x)) #define MOD 1000000007 #define MAXN 1010 using namespace std; typedef long long ll; ll dp[MAXN][MAXN]; int len1, len2, lcs[MAXN][MAXN]; char str[2][MAXN]; void build() { mem(lcs, 0); for (int i = 1; i <= len1; i++) { for (int j = 1; j <= len2; j++) { if (str[0][i] == str[1][j]) lcs[i][j] = lcs[i - 1][j - 1] + 1; else lcs[i][j] = MAX(lcs[i - 1][j], lcs[i][j - 1]); } } } void solve() { mem(dp, 0); for (int i = 0; i <= len1; i++) { for (int j = 0; j <= len2; j++) { if (lcs[i][j] == 0) { dp[i][j] = 1; continue; } if (str[0][i] != str[1][j]) { if (lcs[i][j] == lcs[i][j - 1]) dp[i][j] = (dp[i][j] + dp[i][j - 1]) % MOD; if (lcs[i][j] == lcs[i - 1][j]) dp[i][j] = (dp[i][j] + dp[i - 1][j]) % MOD; if (lcs[i][j] == lcs[i - 1][j - 1]) dp[i][j] = (dp[i][j] - dp[i - 1][j - 1] + MOD) % MOD; } else { dp[i][j] = dp[i - 1][j - 1]; if (lcs[i][j] == lcs[i - 1][j]) dp[i][j] = (dp[i][j] + dp[i - 1][j]) % MOD; } } } } int main() { #ifndef ONLINE_JUDGE freopen("in.txt", "r", stdin); #endif // ONLINE_JUDGE int t; scanf("%d", &t); while (t--) { scanf("%s %s", str[0] + 1, str[1] + 1); len1 = strlen(str[0] + 1); len2 = strlen(str[1] + 1); build(); solve(); printf("%lld\n", dp[len1][len2]); } return 0; }