试题 算法提高 着急的WYF(不同子串个数)
资源限制
时间限制:476ms 内存限制:256.0MB
问题描述
由于战网的密码是一串乱码,WYF巧妙地忘记了他的密码。(他就是作死,如同自掘坟墓。说到掘坟墓,问题就来了——挖掘机技术究竟哪家强?)他现在非常着急,走投无路,都快飞起来了。他只记得他的密码是某个字符串S的子串。现在问题来了,你要告诉他有多少种可能的密码,以帮助他确定他能在多少时间内完成枚举并尝试密码的工作。
输入格式
输入仅包含一行,为字符串S,不含空格。
输出格式
输出一个整数,表示可能的密码数量。
样例输入
ToTal
样例输出
14
数据规模和约定
对于70%的数据,S的长度不超过1000;(暴力)
对于100%的数据,S的长度不超过15000。(Suffix Array)
思路:
思路来自于博客
后缀数组,我们先用后缀排序求出
然后算出
子串就是后缀的前缀,所以可以枚举每个后缀,计算前缀总数,再减掉重复。
- 前缀总数其实就是子串个数 。
- 如果按后缀排序的顺序枚举后缀,每次新增的子串就是除了与上一个后缀的 LCP 剩下的前缀。这些前缀一定是新增的,否则会破坏 的性质。只有这些前缀是新增的,因为 LCP 部分在枚举上一个前缀时计算过了。所以就是
最终的式子就是 。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <vector>
#include <string>
#include <cmath>
#include <set>
#include <map>
#include <deque>
#include <stack>
using namespace std;
typedef long long ll;
typedef vector<int> veci;
typedef vector<ll> vecl;
typedef pair<int, int> pii;
template <class T>
inline void read(T &ret) {
char c;
int sgn;
if (c = getchar(), c == EOF) return ;
while (c != '-' && (c < '0' || c > '9')) c = getchar();
sgn = (c == '-') ? -1:1;
ret = (c == '-') ? 0:(c - '0');
while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
ret *= sgn;
return ;
}
inline void out(int x) {
if (x > 9) out(x / 10);
putchar(x % 10 + '0');
}
const int maxn = 2e4 + 10;
int rk[maxn << 1], oldrk[maxn << 1], px[maxn], sa[maxn], ht[maxn], id[maxn], cnt[maxn];
char s[maxn];
bool cmp(int x, int y, int w) {
return (oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w]);
}
void da() {
int n = strlen(s + 1), m = 300;
for (int i = 1; i <= n; i++) cnt[rk[i] = s[i]]++;
for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i--) sa[cnt[rk[i]]--] = i;
int p = 0;
for (int w = 1; w < n; w <<= 1, m = p) {
p = 0;
for (int i = n; i > n - w; i--) id[++p] = i;
for (int i = 1; i <= n; i++) {
if (sa[i] > w) id[++p] = sa[i] - w;
}
memset(cnt, 0, sizeof(cnt));
for (int i = 1; i <= n; i++) cnt[px[i] = rk[id[i]]]++;
for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i--) sa[cnt[px[i]]--] = id[i];
memcpy(oldrk, rk, sizeof(rk));
p = 0;
for (int i = 1; i <= n; i++) {
rk[sa[i]] = cmp(sa[i], sa[i - 1], w) ? p : ++p;
}
}
//for (int i = 1; i <= n; i++) printf("%d ", sa[i]);
int k = 0;
for (int i = 1; i <= n; i++) {
if (k) k--;
int j = sa[rk[i] - 1];
while (s[i + k] == s[j + k]) k++;
ht[rk[i]] = k;
}
}
int main() {
scanf("%s", s + 1);
ll n = strlen(s + 1);
da();
ll sum = 0;
for (int i = 2; i <= n; i++) sum += ht[i];
sum = n * (n + 1) / 2 - sum;
printf("%lld", sum);
return 0;
}