D-Double Strings
题意
给两个字符串 A A A 和 B B B ( 1 ≤ ∣ A ∣ , ∣ B ∣ ≤ 5000 ) (1\le |A|,|B| \le 5000) (1≤∣A∣,∣B∣≤5000) , a a a 是 A A A 的子序列, b b b 是 B B B 的子序列,求有多少个子序列组合满足两个子序列的长度相同并且 ∃ i ∈ { 1 , 2 , … , ∣ a ∣ } , A a i < B b i , ∀ j ∈ { 1 , 2 , … , i − 1 } , A a j = B b j \exists i \in \{1, 2, \dots, |a|\},A_{ai} < B_{bi},\forall j \in \{1, 2, \dots, i - 1\},A_{aj}=B_{bj} ∃i∈{ 1,2,…,∣a∣},Aai<Bbi,∀j∈{ 1,2,…,i−1},Aaj=Bbj 。
题解
- 可以把子序列分成三段 ,第一段两个子序列完全相同,第二段长度为 1 1 1 满足 A i < B i A_i<B_i Ai<Bi ,第三段只需要长度相同即可。
- 二重循环遍历两个字符串,如果满足 A i < B j A_i<B_j Ai<Bj ,那么 i i i 把字符串 A A A 分成了前后两部分, j j j 把字符串 B B B 分成了前后两部分,分别求出前面部分公共子序列的数量和后面部分相同的数量即可。
- 前面部分可以通过二维dp O ( 1 ) O(1) O(1) 转移得到;
- 后面部分可以dp, A A A 此时剩余长度为 x x x , B B B 此时剩余长度为 y y y,不妨 x ≤ y x\le y x≤y , ∑ i = 0 x C x i ⋅ C y i = ∑ i = 0 x C x x − i ⋅ C y i = C x + y x \sum_{i = 0}^{x} C_x^i \cdot C_y^i = \sum_{i = 0}^{x} C_x^{x - i} \cdot C_y^i = C_{x + y} ^ x ∑i=0xCxi⋅Cyi=∑i=0xCxx−i⋅Cyi=Cx+yx 。
代码
#include <bits/stdc++.h>
#define rep(i, a, n) for (int i = a; i <= n; ++i)
#define per(i, a, n) for (int i = n; i >= a; --i)
#ifdef LOCAL
#include "Print.h"
#define de(...) W('[', #__VA_ARGS__,"] =", __VA_ARGS__)
#else
#define de(...)
#endif
using namespace std;
typedef long long ll;
const int maxn = 5e3 + 5;
const int mod = 1e9 + 7;
char s[maxn], t[maxn];
int n, m;
ll dp[maxn][maxn];
ll fac[maxn * 2], inv[maxn * 2];
void add(ll &x, ll y) {
if ((x += y) >= mod) x -= mod; }
void sub(ll &x, ll y) {
if ((x -= y) < 0) x += mod; }
ll powmod(ll a, ll b) {
ll ans = 1;
while (b) {
if (b & 1) ans = ans * a % mod;
a = a * a % mod;
b >>= 1;
}
return ans;
}
void init() {
fac[1] = fac[0] = 1;
for (int i = 2; i < maxn * 2; ++i) fac[i] = fac[i - 1] * i % mod;
inv[maxn * 2 - 1] = powmod(fac[maxn * 2 - 1], mod - 2);
for (int i = 2 * maxn - 2; i >= 0; --i) inv[i] = inv[i + 1] * (i + 1) % mod;
}
inline ll cal(ll a, ll b) {
return fac[a] * inv[b] % mod * inv[a - b] % mod;
}
void DP() {
rep(i, 0, maxn - 1) dp[i][0] = dp[0][i] = 1;
rep(i, 1, n) rep(j, 1, m) {
add(dp[i][j], dp[i - 1][j] + dp[i][j - 1]);
if (s[i] != t[j]) sub(dp[i][j], dp[i - 1][j - 1]);
}
}
int case_Test() {
scanf("%s%s", s + 1, t + 1);
n = strlen(s + 1), m = strlen(t + 1);
init(), DP();
ll ans = 0;
rep(i, 1, n) rep(j, 1, m) if (s[i] < t[j])
add(ans, dp[i - 1][j - 1] * cal(n + m - i - j, min(n - i, m - j)) % mod);
printf("%lld\n", ans);
return 0;
}
int main() {
#ifdef LOCAL
freopen("in.in", "r", stdin);
freopen("out.out", "w", stdout);
clock_t start = clock();
#endif
int _ = 1;
// scanf("%d", &_);
while (_--) case_Test();
#ifdef LOCAL
printf("Time used: %.3lfs\n", (double)(clock() - start) / CLOCKS_PER_SEC);
#endif
return 0;
}