Problem
Solution
虚树+dp,但是dp挺难写的
写什么dp咯,分情况讨论+模拟可是O(n)的
详细说一下dp吧。我们设bel表示这个点最近的关键点,直接dp儿子,然后比较子树中的答案是否更优。但是最近的关键点可能并不在子树中,那么我们还需要重新再判断一下,是否父亲的最优答案可以更新儿子的最优答案。可以用两边dp解决。
怎么统计答案?不妨直接对虚树上的边进行考虑。注意因为要统计到一些不在虚树上的空子树的答案,我们直接用sz来统计答案。那么对于边有两种情况,要么整条边都是去同一个关键点更优,要么就是上面的一部分去往深度小的一个关键点,下面的一部分去另一个。我们对第二种情况进行考虑,怎么找分割点呢?其实是可以用二分的,但由于这个在树上,把这条边剖成一条序列会很麻烦,时间达到了
,我们可以采用另一种类似的思想倍增做。
一次询问的时间复杂度
因为在建虚树需要求很多次lca,我们推荐用树剖,最好用rmq。我用的倍增导致常数有点大,有四个点都是900ms。。在TLE的边缘试探.jpg
Code
#include <algorithm>
#include <cstring>
#include <cstdio>
#define rg register
using namespace std;
const int maxn=300010,INF=0x3f3f3f3f;
struct data{int v,w,nxt;}edge[maxn<<1];
int n,q,m,p,top,cnt,dfc,a[maxn],head[maxn],pos[maxn],bel[maxn],rem[maxn];
int b[maxn],ans[maxn],stk[maxn],sz[maxn],dfn[maxn],deep[maxn],f[maxn][20];
template <typename Tp> inline void read(Tp &x)
{
x=0;char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
}
inline void insert(int u,int v,int w){edge[++p]=(data){v,w,head[u]};head[u]=p;}
inline int cmp(const int &x,const int &y){return dfn[x]<dfn[y];}
void dfs(int x)
{
dfn[x]=++dfc;sz[x]=1;deep[x]=deep[f[x][0]]+1;
for(int i=1;i<=19;i++)
f[x][i]=f[f[x][i-1]][i-1];
for(int i=head[x];i;i=edge[i].nxt)
if(edge[i].v!=f[x][0])
f[edge[i].v][0]=x,dfs(edge[i].v),sz[x]+=sz[edge[i].v];
}
void input()
{
int u,v;
read(n);
for(int i=1;i<n;i++)
{
read(u);read(v);
insert(u,v,1);insert(v,u,1);
}
read(q);dfs(1);memset(head,0,sizeof(head));
}
int lca(int x,int y)
{
if(deep[x]<deep[y]) swap(x,y);
for(int i=19;~i;i--)
if(deep[f[x][i]]>=deep[y]) x=f[x][i];
if(x==y) return x;
for(int i=19;~i;i--)
if(f[x][i]^f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int dis(int x,int y){return deep[x]+deep[y]-(deep[lca(x,y)]<<1);}
void build()
{
int t;
read(m);cnt=top=p=0;
for(int i=1;i<=m;i++) read(a[i]),bel[a[i]]=b[i]=a[i];
sort(a+1,a+m+1,cmp);
if(bel[1]^1) stk[++top]=1;
for(int i=1;i<=m;i++)
{
t=0;
while(top)
{
t=lca(a[i],stk[top]);
if(top>1&&deep[t]<deep[stk[top-1]])
insert(stk[top-1],stk[top],dis(stk[top-1],stk[top])),top--;
else if(deep[t]<deep[stk[top]])
{insert(t,stk[top],dis(t,stk[top]));top--;break;}
else break;
}
if(stk[top]^t) stk[++top]=t;
stk[++top]=a[i];
}
while(top>1) insert(stk[top-1],stk[top],dis(stk[top-1],stk[top])),top--;
}
void dfs1(int x)
{
pos[++cnt]=x;rem[x]=sz[x];
for(int i=head[x],d1,d2;i;i=edge[i].nxt)
{
dfs1(edge[i].v);
if(!bel[edge[i].v]) continue;
d1=dis(x,bel[x]),d2=dis(x,bel[edge[i].v]);
if(!bel[x]||d1>d2||(d1==d2&&bel[edge[i].v]<bel[x]))
bel[x]=bel[edge[i].v];
}
}
void dfs2(int x)
{
for(int i=head[x],d1,d2;i;i=edge[i].nxt)
{
d1=dis(edge[i].v,bel[x]);d2=dis(edge[i].v,bel[edge[i].v]);
if(!bel[edge[i].v]||d1<d2||(d1==d2&&bel[x]<bel[edge[i].v]))
bel[edge[i].v]=bel[x];
dfs2(edge[i].v);
}
}
void calc(int x,int y)
{
int tmp=y,mid=y;
for(int i=19;~i;i--)
if(deep[f[tmp][i]]>deep[x]) tmp=f[tmp][i];
rem[x]-=sz[tmp];//rem表示的是这个节点构建虚树时被忽略的其他子树的总大小
if(bel[x]==bel[y]){ans[bel[x]]+=sz[tmp]-sz[y];return ;}
for(int i=19,d1,d2;~i;i--)
if(deep[f[mid][i]]>deep[x])
{
d1=dis(bel[x],f[mid][i]);d2=dis(bel[y],f[mid][i]);
if(d2<d1||(d2==d1&&bel[y]<bel[x])) mid=f[mid][i];
}
ans[bel[x]]+=sz[tmp]-sz[mid];ans[bel[y]]+=sz[mid]-sz[y];
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
#endif
input();
while(q--)
{
build();dfs1(1);dfs2(1);
for(int i=1;i<=cnt;i++)
for(int j=head[pos[i]];j;j=edge[j].nxt)
calc(pos[i],edge[j].v);
for(int i=1;i<=cnt;i++) ans[bel[pos[i]]]+=rem[pos[i]];
for(int i=1;i<=m;i++) printf("%d ",ans[b[i]]);putchar('\n');
for(int i=1;i<=cnt;i++) head[pos[i]]=bel[pos[i]]=ans[pos[i]]=0;
}
return 0;
}
RMQ
inline int min(int x,int y){return x<y?x:y;}
inline int max(int x,int y){return x>y?x:y;}
inline int getmin(int x,int y){return deep[x]<deep[y]?x:y;}
void dfs(int x,int pre)
{
deep[x]=deep[pre]+1;dfn[0][x]=cnt+1;
for(int i=head[x];i;i=edge[i].nxt)
if(edge[i].v!=pre)
{
a[++cnt]=x;mn[0][cnt]=a[cnt];dfs(edge[i].v,x);
}
a[++cnt]=x;mn[0][cnt]=a[cnt];dfn[1][x]=cnt;
}
int query(int l,int r)
{
int p=lg[r-l+1],k=r-(1<<p)+1;
return getmin(mn[p][l],mn[p][k]);
}
void init()
{
lg[0]=-1;
for(rg int i=1;i<=1000000;i++) lg[i]=lg[i>>1]+1;
for(int i=1;i<=19;i++)
for(rg int l=1,r=l+(1<<i-1);r<=cnt;l++,r++)
mn[i][l]=getmin(mn[i-1][l],mn[i-1][r]);
}
int lca(int x,int y){return query(min(dfn[0][x],dfn[0][y]),max(dfn[1][x],dfn[1][y]));}