链接
题解
折腾了一下午,我人没了QwQ
题目说:在树有这样的三个点 ,满足 ,让你数这样的三元组 的个数
这是我的错误思路:
肯定存在一个中心
,使得
在以
为根时处在不同的子树中,我如果以
为根建立有根树,那么如果我枚举
的话,那么这个三个点的分布只有两种情况:
情况一:一个在
往上走的那个分量里,另外两个处于
的不同子树中
情况二:处在
的三个不同子树中
我发现第一种情况很难计算,最后也没想出来怎么做
这是正确的思路:
既然已经建立了以
为根节点的有根树,就不要再考虑把题目给的条件看作是“以
为根建树时xxxx”.。
这个时候,应该着眼于三元组
在以
为根的有根树中的
,在这个
处统计才是正确的思路
在
处统计三元组,我枚举
的深度
,那么另外两个点
必须是长这样:
把满足这样的分布的
的对数,记作
,注意
包含了
的所有可能取值对应的方案
再记录一个 表示以 为根的子树深度为 的点的数目
要利用长链剖分辅助,重复利用空间,并减少重复计算,最后的时间复杂度是 的
代码
#include <bits/stdc++.h>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 100010
#define maxe 200010
#define cl(x) memset(x,0,sizeof(x))
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
ll read(ll x=0)
{
ll c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
struct Graph
{
int etot, head[maxn], to[maxe], next[maxe], w[maxe];
void clear(int N)
{
for(int i=1;i<=N;i++)head[i]=0;
etot=0;
}
void adde(int a, int b, int c=0){to[++etot]=b;w[etot]=c;next[etot]=head[a];head[a]=etot;}
#define forp(_,__) for(int p=__.head[_];p;p=__.next[p])
}G;
struct Longest_Chain_Decomposition
{
ll tot, len[maxn], son[maxn], depth[maxn], istop[maxn];
void dfs(Graph& G, ll u, ll fa)
{
son[u]=0;
len[u]=1;
depth[u]=depth[fa]+1;
istop[u]=false;
forp(u,G)
{
ll v(G.to[p]); if(v==fa)continue;
dfs(G,v,u);
if(len[v]+1>len[u])len[u]=len[v]+1, son[u]=v;
}
forp(u,G)
{
ll v(G.to[p]); if(v==fa)continue;
if(v!=son[u])istop[v]=true;
}
}
void run(Graph& G, ll root)
{
tot=0;
depth[0]=0;
dfs(G,root,0);
istop[root]=true;
}
}lcd;
ll *f[maxn], *g[maxn], pool[maxn<<2], tot, n, ans;
void dfs(ll u, ll fa)
{
if(lcd.istop[u])
{
f[u] = pool + tot;
tot += lcd.len[u];
tot += lcd.len[u];
g[u] = pool + tot;
tot += lcd.len[u];
}
if(lcd.son[u])
{
f[lcd.son[u]] = f[u] + 1;
g[lcd.son[u]] = g[u] - 1;
dfs(lcd.son[u],u);
}
f[u][0]=1;
forp(u,G)
{
ll v(G.to[p]), i; if(v==fa or v==lcd.son[u])continue;
dfs(v,u);
rep(i,0,lcd.len[v]-1)
{
ans += f[v][i] * g[u][i+1];
if(i>1)ans += g[v][i] * f[u][i-1];
}
rep(i,0,lcd.len[v]-1)
{
if(i>0)g[u][i-1] += g[v][i];
g[u][i+1] += f[v][i]*f[u][i+1];
f[u][i+1] += f[v][i];
}
}
ans += g[u][0];
}
int main()
{
ll i, u, v;
n = read();
rep(i,1,n-1)
{
u = read(), v = read();
G.adde(u,v); G.adde(v,u);
}
lcd.run(G,1);
dfs(1,0);
printf("%lld\n",ans);
return 0;
}