题目描述
分析
一道简单的虚树加dp题。
显然拉出虚树之后对每条边二分出最优点然后给答案取min即可。
dp的设法是,f[x][012]表示x子树所有点到x的距离的0,1,2次幂。up[f][012]表示x子树外所有点到x。
虚树怎么建呢?
很显然虚树的点就是点集里所有点以及他们按dfn排序后,相邻两个的lca。
为了建出虚树,我们要维护一个深度递增的单调栈。
给出的点按dfn排序后,我们逐个加入虚树。
每次把一个点x和栈顶元素y求lca,然后把单调栈里深度大于lca的全部弹掉,加入lca,加入x。注意lca如果原本就有了就不需要加了。可以看出这个lca实际上就是x和点集的上一个元素的lca。
每次弹栈的时候,我们连边,即栈顶元素和底下一个元素连边,而如果lca的深度在他们之间,则连到lca上。
最后把栈清空一下就行啦。
接下来就是dp随便搞搞。
代码
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
#define cmin(a,b) (a=(a<b)?a:b)
typedef long long ll;
typedef double db;
const int N=2e5+5,mo=998244353;
int dfn[N],td,f[N][25],g[N][25],Log[N],dis[N],pd[N],q,n,m,z,y,x,i,j,K,a[N],sta[N],st,lca,d[N],lst,rt;
db med;
ll upf[N][3],F[N][3],go[N][3],C,val,ans,tmp0,tmp1,len,X;
bool cmp(int x,int y) {return dfn[x]<dfn[y];}
int read()
{
int x=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while ('0'<=ch&&ch<='9')
{
x=x*10+ch-'0';
ch=getchar();
}
return x;
}
int buf[20];
void Print(ll x)
{
buf[0]=0;
while (x) buf[++buf[0]]=x%10,x/=10;
if (!buf[0]) putchar('0');
while (buf[0]) putchar('0'+buf[buf[0]--]);
putchar('\n');
}
int tt,b[N],c[N],nxt[N],fst[N];
void cr(int x,int y,int z)
{
tt++;
b[tt]=y;
c[tt]=z;
nxt[tt]=fst[x];
fst[x]=tt;
}
int t1,b1[N],c1[N],nxt1[N],fst1[N];
void cr1(int x,int y,int z)
{
t1++;
b1[t1]=y;
c1[t1]=z;
nxt1[t1]=fst1[x];
fst1[x]=t1;
}
void dfs(int x,int y)
{
dfn[x]=++td;
f[x][0]=y;
dis[x]=dis[y]+1;
int i;
fo(i,1,20) f[x][i]=f[f[x][i-1]][i-1],g[x][i]=g[x][i-1]+g[f[x][i-1]][i-1];
for(int p=fst[x];p;p=nxt[p])
if (b[p]!=y)
{
g[b[p]][0]=c[p];
dfs(b[p],x);
}
}
int Lca(int x,int y)
{
if (dis[x]<dis[y]) swap(x,y);
int i;
fd(i,20,0) if (dis[f[x][i]]>=dis[y]) x=f[x][i];
if (x==y) return x;
fd(i,20,0) if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int Len(int x,int y)
{
if (dis[x]<dis[y]) swap(x,y);
int ret=0,i;
fd(i,20,0) if (dis[f[x][i]]>=dis[y]) ret+=g[x][i],x=f[x][i];
return ret;
}
void thr(int x)
{
d[++d[0]]=x;
F[x][0]=(pd[x]==q);
int y;
for(int p=fst1[x];p;p=nxt1[p])
{
y=b1[p];
thr(y);
go[y][0]=F[y][0];
F[x][0]+=F[y][0];
go[y][1]=F[y][1]+F[y][0]*c1[p];
F[x][1]+=go[y][1];
go[y][2]=F[y][2]+2*F[y][1]*c1[p]+F[y][0]*c1[p]*c1[p];
F[x][2]+=go[y][2];
}
}
void dp(int x)
{
for (int p=fst1[x];p;p=nxt1[p])
{
y=b1[p];
upf[y][0]=upf[x][0]+F[x][0]-go[y][0];
upf[y][1]=upf[x][1]+F[x][1]-go[y][1];
upf[y][2]=upf[x][2]+F[x][2]-go[y][2]+upf[y][1]*c1[p]*2+upf[y][0]*c1[p]*c1[p];
upf[y][1]+=upf[y][0]*c1[p];
dp(y);
}
}
void solve(int x)
{
for (int p=fst1[x];p;p=nxt1[p])
{
y=b1[p];
len=c1[p];
tmp0=upf[y][0];
tmp1=upf[x][1]+F[x][1]-go[y][1];
med=(tmp1+len*tmp0-F[y][1])/db(tmp0+F[y][0]);
z=y;
X=0;
fd(i,20,0) if (X+g[z][i]<=med&&dis[f[z][i]]>=dis[x]) X+=g[z][i],z=f[z][i];
if (dis[f[z][0]]>=dis[x]&&med-db(X)>db(X+g[z][0])-med) X+=g[z][0],z=f[z][0];
C=F[y][2]+upf[x][2]+F[x][2]-go[y][2];
val=C+X*X*F[y][0]+2*X*F[y][1]+(len-X)*(len-X)*tmp0+2*(len-X)*tmp1;
cmin(ans,val);
solve(b1[p]);
}
}
int main()
{
freopen("t8.in","r",stdin);
freopen("tree.out","w",stdout);
n=read();m=read();
fo(i,1,n-1)
{
x=read();y=read();z=read();
cr(x,y,z);
cr(y,x,z);
}
fo(i,1,n) Log[i]=trunc(log(i)/log(2));
dfs(1,0);
fo(q,1,m)
{
fo(i,1,d[0])
{
fst1[d[i]]=0;
fo(j,0,2) upf[d[i]][j]=F[d[i]][j]=0;
}
d[0]=0;
t1=0;
K=read();
fo(i,1,K) a[i]=read(),pd[a[i]]=q;
sort(a+1,a+1+K,cmp);
sta[st=1]=a[1];
fo(i,2,K)
{
lca=Lca(a[i],sta[st]);
lst=0;
while (st&&dis[lca]<dis[sta[st]])
{
if (dis[sta[st-1]]>=dis[lca]) cr1(sta[st-1],sta[st],Len(sta[st-1],sta[st]));
lst=sta[st--];
}
if (lca!=sta[st])
{
if (dis[lca]<dis[lst]) cr1(lca,lst,Len(lca,lst));
sta[++st]=lca;
}
sta[++st]=a[i];
}
while (st>1) cr1(sta[st-1],sta[st],Len(sta[st-1],sta[st])),st--;
rt=sta[1];
st--;
ans=1e18;
d[0]=0;
thr(rt);
dp(rt);
solve(rt);
if (K==1) ans=0;
Print(ans);
}
}