这道题,我们用一个count数组来记录当前根节点的子树上,每种距离的个数有多少,然后所有点对和模3等于零的情况,即可变为零节点的个数的平方加上1节点个数乘以2节点的再乘个2就统计完成了
所以solve函数就变成了:
int solve(int rt,int len)//每做一次solve 即要重新得到一次dis数组
{
Count[0] = Count[1] = Count[2] = 0;
dis[rt] = len%3;
get_dis(rt,0,len);
return Count[0]*Count[0]+Count[1]*Count[2]*2;//得到模3等于0的点
}
其他部分就都是模板的套路了
AC代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 2e5+7;
#define inf 0x3f3f3f3f
int head[MAXN],dis[MAXN],maxson[MAXN],siz[MAXN],vis[MAXN];
int SIZE,cnt,num,maxx,root,ans,Count[3];
struct Edge
{
int to,w,next;
}edge[MAXN<<1];
void addedge(int u,int v,int w)
{
edge[++cnt].to = v;
edge[cnt].w = w;
edge[cnt].next = head[u];
head[u] = cnt;
}
void get_root(int u,int fa)
{
siz[u] = 1;
maxson[u] = 0;
for(int i = head[u];i;i = edge[i].next){
int v = edge[i].to;
if(vis[v] || fa == v) continue;
get_root(v,u);
siz[u] += siz[v];
maxson[u] = max(maxson[u],siz[v]);
}
maxson[u] = max(maxson[u],SIZE-siz[u]);
if(maxx > maxson[u]) root = u,maxx = maxson[u];
}
void get_dis(int u,int fa,int d)
{
dis[++num] = d%3;//直接模3
Count[dis[num]]++;
for(int i = head[u];i;i = edge[i].next){
int v = edge[i].to;
if(vis[v] || fa == v) continue;
get_dis(v,u,d+edge[i].w);
}
return ;
}
int solve(int rt,int len)//每做一次solve 即要重新得到一次dis数组
{
Count[0] = Count[1] = Count[2] = 0;
dis[rt] = len%3;
get_dis(rt,0,len);
return Count[0]*Count[0]+Count[1]*Count[2]*2;//得到模3等于0的点
}
void Divide(int rt)
{
ans = ans + solve(rt,0);//先加一遍根节点的答案
vis[rt] = 1;
for(int i = head[rt];i;i = edge[i].next){
int v = edge[i].to;
if(vis[v]) continue;
ans = ans - solve(v,edge[i].w);
SIZE = siz[v];
root = 0;
maxx = inf;
get_root(v,rt);
Divide(root);
}
}
int gcd(int a,int b)
{
if(a < b) swap(a,b);
int r;
while(a%b){
r = a%b;
a = b;
b = r;
}
return b;
}
int main()
{
int n;
scanf("%d",&n);
for(int i = 1;i < n;i ++){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
addedge(u,v,w);
addedge(v,u,w);
}
memset(vis,0,sizeof(vis));
maxx = inf;
SIZE = n;
root = 0;
//printf("1.---\n");
get_root(1,0);
//printf("2.---\n");
Divide(root);
//printf("---ans = %d\n",ans);
int sum = n*n;
printf("%d/%d\n",ans/gcd(ans,sum),sum/gcd(ans,sum));
return 0;
}