实际上给的单向边而不是双向边,先处理出每个点到 b + 1 b + 1 b+1 的最短距离 d a d_a da, b + 1 b + 1 b+1 到每个点的最短距离 d b d_b db。 预处理权值 s = d a + d b s = d_a + d_b s=da+db,集合划分后的每一个点的贡献为: s [ i ] ∗ ( s i z e − 1 ) s[i] * (size - 1) s[i]∗(size−1),其中 s i z e size size 为划分到的集合的大小。 显然最优情况是值连续的划分到一组,对 s s s 进行排序,就可以在序列上做线性 dp。
考虑 d p [ i ] [ j ] dp[i][j] dp[i][j] 表示前 j j j 个分成 i i i 组的最小总距离。转移方程为 dp[k][i] = dp[k - 1][j] + (i - j - 1) * (sum[i] - sum[j])
,总复杂度为 O ( n 3 ) O(n^3) O(n3),考虑优化:
1、不考虑选择分 k 组这个限制,肯定尽可能多分组最优,给每个分组加上一个代价 x x x,每分一次组都要加上 x x x 的代价,显然 x x x 越大,最优解的分组数越少, x x x 越小, 最优解的分组数越大,满足单调性,考虑二分这个代价 x x x,然后做没有限制的 dp 的复杂度是 O ( n 2 ) O(n^2) O(n2),二分的右边界要大一点,大到最优解可能只分一次组。(从凸包的角度考虑可能不容易看出)
2、由于权值比较小的点划分到的 s i z e size size 肯定更大,决策具有单调性,利用这个单调性,当计算 dp[k][i] 时,转移范围只要枚举 [ i − ⌊ i k ⌋ , i − 1 ] [i - \lfloor\frac{i}{k}\rfloor, i - 1] [i−⌊ki⌋,i−1],因为最后这一块的大小肯定小于等于平均值, k k k 次计算后,对于每一个 i i i,计算所有的 d p [ k ] [ i ] dp[k][i] dp[k][i] 的复杂度是 i log i i \log i ilogi,最后总复杂度为 n 2 log n n^2\log n n2logn,这个 n 2 log n n^2\log n n2logn 没有跑满,因此跑得比较快。
由于决策转移点具有单调性,还可以实现到 O ( n 2 ) O(n^2) O(n2),且已经有 n log 2 n n\log^2n nlog2n 的做法 (都不会)
wqs二分优化代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e3 + 10;
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
const ll inf = 1e15;
int n,b,s,r,a[maxn],vis[maxn];
ll sum[maxn],d[maxn],t[maxn];
vector<pii> g[maxn],h[maxn];
ll dp[maxn],tp[maxn],lst[maxn];
void spfa1(int s) {
queue<int> q;
for (int i = 1; i <= n; i++)
d[i] = inf;
memset(vis,0,sizeof vis);
d[s] = 0;
q.push(s);
while (!q.empty()) {
int top = q.front();
q.pop();
vis[top] = 0;
for (auto it : g[top]) {
if (d[it.fir] > d[top] + it.sec) {
d[it.fir] = d[top] + it.sec;
if (!vis[it.fir]) {
q.push(it.fir);
vis[it.fir] = 1;
}
}
}
}
}
void spfa2(int s) {
queue<int> q;
for (int i = 1; i <= n; i++)
t[i] = inf;
memset(vis,0,sizeof vis);
t[s] = 0;
q.push(s);
while (!q.empty()) {
int top = q.front();
q.pop();
vis[top] = 0;
for (auto it : h[top]) {
if (t[it.fir] > t[top] + it.sec) {
t[it.fir] = t[top] + it.sec;
if (!vis[it.fir]) {
q.push(it.fir);
vis[it.fir] = 1;
}
}
}
}
}
ll solve(ll x) {
for (int i = 0; i <= b; i++)
dp[i] = inf, lst[i] = tp[i] = 0;
dp[0] = 0;
for (int i = 1; i <= b; i++) {
for (int j = lst[i]; j < i; j++) {
if (dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x < dp[i]) {
dp[i] = dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x;
tp[i] = tp[j] + 1;
} else if (dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x == dp[i]) {
if (tp[i] < tp[j] + 1)
tp[i] = tp[j] + 1;
}
}
}
return tp[b];
}
int main() {
scanf("%d%d%d%d",&n,&b,&s,&r);
for (int i = 1; i <= r; i++) {
int u,v,w; scanf("%d%d%d",&u,&v,&w);
g[u].push_back(pii(v,w));
h[v].push_back(pii(u,w));
}
spfa1(b + 1); spfa2(b + 1);
for (int i = 1; i <= b; i++)
sum[i] = d[i] + t[i];
sort(sum + 1,sum + b + 1);
for (int i = 1; i <= b; i++)
sum[i] += sum[i - 1];
ll l = 0, r = 1ll << 48;
while (l < r) {
ll mid = l + r >> 1;
if (solve(mid) < s) r = mid;
else l = mid + 1;
}
solve(l - 1);
printf("%lld\n",dp[b] - s * (l - 1));
return 0;
}
决策单调性优化:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e3 + 10;
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
const ll inf = 1e15;
int n,b,s,r,a[maxn],vis[maxn];
ll sum[maxn],d[maxn],t[maxn];
vector<pii> g[maxn],h[maxn];
ll dp[maxn],tp[maxn];
void spfa1(int s) {
queue<int> q;
for (int i = 1; i <= n; i++)
d[i] = inf;
memset(vis,0,sizeof vis);
d[s] = 0;
q.push(s);
while (!q.empty()) {
int top = q.front();
q.pop();
vis[top] = 0;
for (auto it : g[top]) {
if (d[it.fir] > d[top] + it.sec) {
d[it.fir] = d[top] + it.sec;
if (!vis[it.fir]) {
q.push(it.fir);
vis[it.fir] = 1;
}
}
}
}
}
void spfa2(int s) {
queue<int> q;
for (int i = 1; i <= n; i++)
t[i] = inf;
memset(vis,0,sizeof vis);
t[s] = 0;
q.push(s);
while (!q.empty()) {
int top = q.front();
q.pop();
vis[top] = 0;
for (auto it : h[top]) {
if (t[it.fir] > t[top] + it.sec) {
t[it.fir] = t[top] + it.sec;
if (!vis[it.fir]) {
q.push(it.fir);
vis[it.fir] = 1;
}
}
}
}
}
ll solve() {
for (int i = 0; i <= b; i++)
tp[i] = dp[i] = inf;
tp[0] = 0;
for (int k = 1; k <= s; k++) {
for (int i = 1; i <= b; i++) {
for (int j = i - i / k; j <= i - 1; j++) // i / k 是平均每个块的大小
dp[i] = min(dp[i],tp[j] + (i - j - 1) * (sum[i] - sum[j]));
}
for (int i = 0; i <= b; i++)
tp[i] = dp[i], dp[i] = inf;
}
return tp[b];
}
int main() {
scanf("%d%d%d%d",&n,&b,&s,&r);
for (int i = 1; i <= r; i++) {
int u,v,w; scanf("%d%d%d",&u,&v,&w);
g[u].push_back(pii(v,w));
h[v].push_back(pii(u,w));
}
spfa1(b + 1); spfa2(b + 1);
for (int i = 1; i <= b; i++)
sum[i] = d[i] + t[i];
sort(sum + 1,sum + b + 1);
for (int i = 1; i <= b; i++)
sum[i] += sum[i - 1];
printf("%lld\n",solve());
return 0;
}