题目传送门
题意:
一棵n个节点的树,第i个节点有一个权值ai。定义一个二元函数dist(x,y),dist(x,y)的值等于以x节点和y节点为端点的简单路径包括的节点个数,且这条路径上各个节点权值的gcd>1。询问最大dist(x,y)。
数据范围:n <= 2e5 ,ai <= 2e5。
题解:
这个题就是让你找出一条最长的简单路径,这条简单路径上的权值的gcd>1。最后让你输出这个长度。
朴素的做法显然是O(n²)的,我是不会优化,转而分析gcd。
你要发现一个现象就是节点的权值不大,可以进行质因数分解。并且权值的质因子只有十几个,是可以枚举的。
首先,打一个素数表,然后把每个节点的序号放进若干个set里面,这些set的下标是每个节点权值的质因子。
然后遍历每个质因数,在当前质因数set内的序号对应节点保留,其余视为不存在。此时得到若干个连通块。
通过这若干个连通块,你只需要找到最长的一条链即可,用这条链的长度去更新答案就好了。
对于一个连通块,找到最长的一条链的方法:取任意节点为根节点,然后dfs,定义Max[u]为以u节点为根节点的子树的最大深度。对于每个节点,若该节点的子树对应的两个最大Max分别是temp1和temp2,那么通过该节点的最长链为temp1+temp2+1。
感受:
这道题我想去优化O(n²)的朴素做法,用类似于树形dp的方法去优化,发现无法转移。(可能可以,但是我不会)
进而我发现了每个节点权值的质因子只有十几个,如果每个节点最多枚举十几次,那是可以接受的。
然后下面的思路就顺利了。
过题后看了网上各路大神解法,我以为我求最长链用的是点分治思想(通过当前节点的路径),但其实用的是树形dp的转移。
并且这道题真的可以用点分治做。。。。。。
代码:
#include<bits/stdc++.h>
using namespace std ;
typedef long long ll ;
const int maxn = 2e5 + 5 ;
int n ;
int a[maxn] ;
int num , head[maxn] ;
int cnt = 0 ;
bool vis[maxn] ;
bool used[maxn] ;
int Max[maxn] ;
int prime[maxn] ;
int ans = 0 ;
set<int> st[maxn] , p[maxn] ;
struct Edge
{
int v , next ;
} edge[maxn << 1] ;
void add_edge(int u , int v)
{
edge[num].v = v ;
edge[num].next = head[u] ;
head[u] = num ++ ;
}
void get_prime()
{
memset(vis , 0 , sizeof(vis)) ;
vis[1] = 1 ;
for(int i = 2 ; i <= 2e5 ; i ++)
{
if(!vis[i])
prime[++ cnt] = i ;
for(int j = 1 ; j <= cnt && i * prime[j] <= 2e5 ; j ++)
{
vis[i * prime[j]] = 1 ;
if(i % prime[j] == 0) break ;
}
}
}
void init()
{
for(int i = 1 ; i <= cnt ; i ++)
for(int j = prime[i] ; j <= 2e5 ; j += prime[i])
st[j].insert(prime[i]) ;
for(int i = 1 ; i <= n ; i ++)
for(auto x : st[a[i]])
p[x].insert(i) ;
}
void dfs(int u)
{
int temp1 = 0 , temp2 = 0 ;
used[u] = 0 ;
for(int i = head[u] ; i != -1 ; i = edge[i].next)
{
int v = edge[i].v ;
if(!used[v]) continue ;
dfs(v) ;
int x = Max[v] ;
if(x > temp1) temp2 = temp1 , temp1 = x ;
else temp2 = max(temp2 , x) ;
}
Max[u] = temp1 + 1 ;
ans = max(ans , temp1 + temp2 + 1) ;
}
void cal(int x)
{
for(auto u : p[x]) used[u] = 1 , Max[u] = 0 ;
for(auto u : p[x])
{
if(!used[u]) continue ;
dfs(u) ;
}
}
void solve()
{
for(int i = 1 ; i <= cnt ; i ++) cal(prime[i]) ;
}
int main()
{
scanf("%d" , &n) ;
for(int i = 1 ; i <= n ; i ++) scanf("%d" , &a[i]) ;
num = 0 , memset(head , -1 , sizeof(head)) ;
for(int i = 1 ; i <= n - 1 ; i ++)
{
int u , v ;
scanf("%d%d" , &u , &v) ;
add_edge(u , v) , add_edge(v , u) ;
}
get_prime() ;
init() ;
solve() ;
printf("%d\n" , ans) ;
return 0 ;
}