我们观察 n n n, m m m大小,我们发现 n n n很小,考虑用 f l o r d flord flord去解决问题,我们可以发现并不可以直接解决问题,因为一条路径表示为前缀最小值的和,我们可以将题目转化为每一段只用第一条边的最长路径,即我们"加边",把新路径表示为第一条路径*两个点直接所经过最小边数,然后再用这写新边去跑一遍flord即可,时间复杂度 O ( n 3 + n m + n 3 ) O(n^3+nm+n^3) O(n3+nm+n3) 。
第一遍我们跑一遍 f l o r d flord flord表示两点间经过的最小边数,然后第二次我们跑用某条边创造的新路径的新边。第三次跑一遍朴素 f l o r d flord flord。
第二遍详细过程可解释为我们点1到点7经过了点1 、2、5、6、7,我们考虑直接用点1和点2之间的边去跑完所有路径,即dist[1][2] * 4是点1到点7的新路径,不难发现这是一个等价问题变换,因为我们正解后缀和一定要比我们当前路径的值小。
赛时没有想到优化 O ( m n 2 ) O(mn^2) O(mn2)的做法,补题时发现转化思考方式可以优化到 O ( m n ) O(mn) O(mn)。
代码如下
#include <iostream>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <bits/stdc++.h>
#define x first
#define y second
#define int long long
using namespace std ;
int read(){
int res = 0 , flag = 1 ;
char c = getchar() ;
while(!isdigit(c)){
if(c == '-') flag = -1 ;
c = getchar() ;
}
while(isdigit(c)){
res = (res << 1) + (res << 3) + (c ^ 48) ;
c = getchar() ;
}
return res * flag ;
}
void write(int x){
if(x < 0) {
putchar('-') ;
x = - x ;
}
if(x >= 10) write(x / 10) ;
putchar('0' + x % 10) ;
}
void write(int x , char c){
write(x) ;
putchar(c) ;
}
const int N = 3e2 + 10 ;
typedef pair<int , int> pii ;
typedef pair<double ,double> pdd ;
const int mod = 998244353 ;
const int inf = 1e9 + 10 ;
const int M = 1e5 + 10 ;
int dist[N][N] , dp[N][N] ;
struct T{
int a , b , c ;
} l[M] ;
void solve() {
int n = read() , m = read() ;
memset(dist , 0x3f , sizeof dist) ;
memset(dp , 0x3f , sizeof dp) ;
for(int i = 1 ; i <= n ; i ++) dist[i][i] = dp[i][i] = 0 ;
for(int i = 1 ; i <= m ; i ++){
int a = read() , b = read() , c = read() ;
l[i].a = a , l[i].b = b , l[i].c = c ;
dist[a][b] = dist[b][a] = 1 ;
dp[a][b] = dp[b][a] = min(dp[a][b] , c) ;
}
for(int k = 1 ; k <= n ; k ++){
for(int i = 1 ; i <= n ; i ++)
for(int j = 1 ; j <= n ; j ++)
dist[i][j] = min(dist[i][j] , dist[i][k] + dist[k][j]) ;
}
for(int i = 1 ; i <= m ; i ++){
int a = l[i].a , b = l[i].b , c = l[i].c ;
for(int j = 1 ; j <= n ; j ++)
dp[a][j] = min(dp[a][j] , (1 + dist[b][j]) * c) ;
for(int j = 1 ; j <= n ; j ++)
dp[b][j] = min(dp[b][j] , (1 + dist[a][j]) * c) ;
}
for(int k = 1 ; k <= n ; k ++){
for(int i = 1 ; i <= n ; i ++)
for(int j = 1 ; j <= n ; j ++)
dp[i][j] = min(dp[i][j] , dp[i][k] + dp[k][j]) ;
}
int ans = 0 ;
for(int i = 1 ; i <= n ; i ++)
for(int j = i + 1 ; j <= n ; j ++)
ans += dp[i][j] ;
write(ans) ;
}
signed main(void){
solve() ;
}