题目大意
给出一个长度为 n ( 1 ≤ n ≤ 2 e 5 ) n ( 1 \leq n \leq 2e5) n(1≤n≤2e5)的序列 a a a,初始时每相邻两个数中间都含有一条权值为 p p p的边,对于一段区间 [ i , j ] [i,j] [i,j],若其满足 g c d ( { a k ∣ i ≤ k ≤ j } ) = m i n { a k ∣ i ≤ k ≤ j } gcd(\{a_k | i \leq k \leq j\}) = min\{a_k | i \leq k \leq j\} gcd({ ak∣i≤k≤j})=min{ ak∣i≤k≤j},那么 i , j i,j i,j之间有一条权值为 m i n { a k ∣ i ≤ k ≤ j } min\{a_k | i \leq k \leq j\} min{ ak∣i≤k≤j}的边。每个数看做一个节点,求出这棵树的最小生成树的权值。
解题思路
假设对于某个点 k k k来说,它能向左向右延伸的最大区间为 [ i , j ] [i,j] [i,j],那么大区间内含有 a k a_k ak的所有子区间的端点之间都可以连边,或者说,问题转化为 k k k能向区间内的所有点连边。
然后考虑一个点 j j j如果和它左边多个点连边,我们只需要计算最小边权的那个数,或者说假设左边的最小的数在位置 i i i,那么保证 a i a_i ai是这段区间内最小的数,且 ∀ k ∈ [ i , j ] , a [ i ] ∣ a [ k ] \forall k \in [i,j], a[i] ~| ~a[k] ∀k∈[i,j],a[i] ∣ a[k];同理右边的最小的那个数。
考虑上面的思路,我们可以从左向右扫,维护左端点的最小数的位置 L [ i ] L[i] L[i],具体过程为(从右向左扫时同理):
- 若 a [ L [ i − 1 ] ] ≤ a [ i ] a[L[i-1]] \leq a[i] a[L[i−1]]≤a[i]且 a [ i ] % a [ L [ i − 1 ] ] = = 0 a[i] \% a[L[i-1]]==0 a[i]%a[L[i−1]]==0,那么 L [ i ] = L [ i − 1 ] L[i] = L[i-1] L[i]=L[i−1];否则 L [ i ] = i L[i] = i L[i]=i。
使用克鲁斯卡尔求解最小生成树,初始的 n − 1 n - 1 n−1条边加入集合。根据已经维护好的两个数组 L , R L,R L,R,即代表每个数最左边的合法位置和最右边的合法位置,若某个方向的位置不是该点本身,将两点的位置作为边、端点的值作为权值加入集合,这样集合中最多只会有 3 ∗ n 3*n 3∗n条边,跑板子即可。
#include <bits/stdc++.h>
using namespace std;
#define ENDL "\n"
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int Mod = 1e9 + 7;
const int maxn = 2e5 + 10;
struct node {
int u, v, w;
bool operator<(const node &p) const {
return w < p.w; }
};
vector<node> edges;
int a[maxn], f[maxn], L[maxn], R[maxn];
int Find(int x) {
return f[x] == x ? x : f[x] = Find(f[x]); }
int main() {
// freopen(out.txt, stdout, ) srand(time(0));
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
int T, n, p;
cin >> T;
while (T--) {
cin >> n >> p;
edges.clear();
for (int i = 1; i <= n; i++) {
cin >> a[i];
f[i] = i;
if (i > 1) edges.push_back({
i - 1, i, p});
}
L[1] = 1, R[n] = n;
for (int i = 2; i <= n; i++) {
if (a[L[i - 1]] <= a[i] && a[i] % a[L[i - 1]] == 0)
L[i] = L[i - 1];
else
L[i] = i;
}
for (int i = n - 1; i >= 1; i--) {
if (a[R[i + 1]] <= a[i] && a[i] % a[R[i + 1]] == 0)
R[i] = R[i + 1];
else
R[i] = i;
}
// for (int i = 1; i <= n; i++) cout << L[i] << " ";
// cout << endl;
// for (int i = 1; i <= n; i++) cout << R[i] << " ";
// cout << endl;
for (int i = 1; i <= n; i++) {
if (L[i] < i) edges.push_back({
L[i], i, a[L[i]]});
if (R[i] > i) edges.push_back({
i, R[i], a[R[i]]});
}
sort(edges.begin(), edges.end());
int cnt = 0;
ll ans = 0;
for (auto p : edges) {
int fu = Find(p.u), fv = Find(p.v);
if (fu != fv) {
f[fv] = fu;
ans += p.w;
cnt++;
}
if (cnt == n - 1) break;
}
cout << ans << ENDL;
}
return 0;
}