题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4547
思路:由于点很多所以需要邻接表存数据,先dfs记录每个点的深度,然后tarjan找到最近共公祖先,然后根据题意输出
这到题需要map来讲字符串转换成数字来记录;
父目录到子目录只需一步,子目录到父目录所需步数就是深度差,转成的目录不是父目录那么转成另一目录的步数就是到两者最近公共祖先的深度的差再加一,否则不用加。而如果两者是一样的,也就是它要转成它自己,那么所需的步数就是0.
#include<stdio.h>
#include<string.h>
#include<string>
#include<queue>
#include<stack>
#include<map>
#include<vector>
#include<set>
#include<algorithm>
#define inf 0x3f3f3f3f
using namespace std;
const int N=4e5+10;
struct node
{
int v,ne;
}edge[N];
int head[N];
struct node1//存查询的路径
{
int v,ne,num;
}qre[N];
int f[N];
struct node2//存查询的点和他们的最先公共祖先
{
int v,u,lca;
}ans[N];
int vis[N];//标记是否访问
int deep[N];//记录深度
int pre[N];
int isroot[N];//找根
int n,e,m,top;
map<string,int>mp;//把字符串转换成数字
void add(int a,int b)
{
edge[e].v=b;
edge[e].ne=head[a];
head[a]=e++;
}
void addq(int a,int b,int c)
{
qre[e].v=b;
qre[e].ne=f[a];
qre[e].num=c;
f[a]=e++;
}
void init()
{
top=0;
memset(f,-1,sizeof(f));
memset(head,-1,sizeof(head));
memset(vis,0,sizeof(vis));
memset(deep,0,sizeof(deep));
for(int i=1;i<=n;i++)
pre[i]=i;
}
void dfs(int a)//dfs访问每个点并记录深度
{
for(int i=head[a];i!=-1;i=edge[i].ne)
{
int v=edge[i].v;
if(!deep[v])
{
deep[v]=deep[a]+1;
dfs(v);
}
}
}
int found(int r)//找祖先并且压缩路径
{
if(r==pre[r])
return r;
else
{
pre[r]=found(pre[r]);
return pre[r];
}
}
int change(char str[])
{
if(mp.find(str)==mp.end())
{
mp[str]=++top;
return top;
}
else
return mp[str];
}
void tarjan(int u)
{
for(int i=f[u];i!=-1;i=qre[i].ne)
{
int v=qre[i].v;
if(vis[v])
{
int temp=qre[i].num;
ans[temp].lca=found(v);
}
}
vis[u]=1;
for(int i=head[u];i!=-1;i=edge[i].ne)
{
int v=edge[i].v;
if(!vis[v])
{
tarjan(v);
pre[v]=u;
}
}
}
void solve()
{
for(int i=1;i<=n;i++)
{
if(!isroot[i])
{
deep[i]=1;
dfs(i);
tarjan(i);
}
}
}
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
scanf("%d %d",&n,&m);
int a,b;
init();
memset(isroot,0,sizeof(isroot));
mp.clear();
char str1[45],str2[45];
e=0;
for(int i=1;i<=n-1;i++)
{
scanf("%s",str1);
scanf("%s",str2);
a=change(str1);
b=change(str2);
add(a,b);
add(b,a);
isroot[a]=1;
}
e=0;
for(int i=1;i<=m;i++)
{
scanf("%s",str1);
scanf("%s",&str2);
a=mp[str1];
b=mp[str2];
addq(a,b,i);
addq(b,a,i);
ans[i].u=a;
ans[i].v=b;
}
solve();
for(int i=1;i<=m;i++)
{
int t=deep[ans[i].u]-deep[ans[i].lca];
if(ans[i].v!=ans[i].lca)
t++;
if(ans[i].v==ans[i].u)
t=0;
printf("%d\n",t);
}
}
return 0;
}