问题 E: 树统计
题目描述
骗分过样例,暴力出奇迹。
关于树的算法有一大堆,样样都是毒瘤。
比如说 NOIP2018 提高组的 D2T3,如果会动态 DP 的做法那么就马上想到正解,但是 Tweetuzki 不会动态 DP,就只好骗分了。
可惜树题的码量也是超级大的。听说好多学长都会动态 DP,但是考场上调不出来,只好暴力分收场了。疯狂暗示
Tweetuzki 当时暴力写挂了,有 4 个点写成了死循环……于是分数白白少了 16 分。Tweetuzki 一想起这事,不禁夙夜忧叹,辗转反侧。
现在他又遇到一道毒瘤树上问题了,他下定决心:这次一定要把暴力分写满!
题目是这样的:
有一棵 n 个点的树,边有边权,每个点有颜色 ci。求所有颜色不同的点对的距离之和。由于答案可能很大,你只需要输出其对 998,244,353 取模的结果即可。
形式化地讲,记 u 号点和 v 号点在树上的距离为 dist(u,v),求:
输入
输入文件将会遵循以下格式:
n type
c1 c2 ⋯ cn
u1 v1 w1
u2 v2 w2
⋮
un−1 vn−1 wn−1
第一行两个正整数 n,type(2≤n≤2×105,1≤type≤6),其中 n 表示点数,type为部分分类型,详见数据范围,type=0 表示样例数据。
第二行输入 n 个正整数 ci(1≤ci≤109),表示每个点的颜色。
接下来n−1 行,每行输入三个正整数 ui,vi,wi(1≤ui<vi≤n,1≤wi≤109),描述这棵树。
输出
输出一行一个非负整数,表示答案对 998,244,353 取模的结果。
样例输入 Copy
4 0
1 2 3 3
1 2 5
2 3 4
3 4 7
样例输出 Copy
90
提示
满足条件的点对有 (1,2),(1,3),(1,4),(2,1),(2,3),(2,4),(3,1),(3,2),(4,1),(4,2),故答案为 5+9+16+5+4+11+9+4+16+11=90。
Subtask #1:n≤300, type=1。
Subtask #2:n≤2 000, type≤2。
Subtask #3:n≤10 000, type≤3。
Subtask #4:对于第 i (1≤i≤n) 号点,ci=i。type=4。
Subtask #5 :对于第 i(1≤i<n)条边,ui+1=vi。type=5。
Subtask #6:无特殊性质,type≤6。
题目大意:
中文题目
题目思路:
考虑将答案转换一下:
设全集为所有点间的距离,所求的点距离之和即为:
所有点间的距离 - 相同颜色点之间的距离
所以题目转换为如何求所有点间的距离与颜色相同点间的距离
其实颜色相同点间的距离与所有点间的距离求法一致
求所有点间的距离:
对整棵树进行dfs,考虑每条边的贡献,每条边贡献的路径数即为:该边左边的点*该边右边的点(方案数)
所以答案就很显然了。
那么对于颜色相同点间的距离呢?
考虑和所有点间距离相同的解法,每次对一种颜色进行dfs
但是这样复杂度可能会变成O(n*n)
所以此时采用虚树的方法,将相同颜色的点构造一棵虚树,然后在虚树上跑一下dp就可以了
注意此时算时,虚树新增的lca 可能颜色不相同,所以不可计算入内
Code:
/*** keep hungry and calm CoolGuang!***/
#pragma GCC optimize(2)
//#include <bits/stdc++.h>
#include<stdio.h>
#include<queue>
#include<algorithm>
#include<string.h>
#include<iostream>
#define debug(x) cout<<#x<<":"<<x<<endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pp;
const ll INF=1e17;
const int maxn=3e5+6;
const int mod=998244353;
const double eps=1e-3;
inline bool read(ll &num)
{char in;bool IsN=false;
in=getchar();if(in==EOF) return false;while(in!='-'&&(in<'0'||in>'9')) in=getchar();if(in=='-'){ IsN=true;num=0;}else num=in-'0';while(in=getchar(),in>='0'&&in<='9'){num*=10,num+=in-'0';}if(IsN) num=-num;return true;}
ll n,m,p;
int dfn[maxn],deep[maxn],f[maxn][21];///时间戳 深度 fa数组
vector< pair<int,int> >v[maxn];///存图
vector<int>g[maxn];
int st[maxn],s = 0;///模拟栈的实现
int ldfn = 0;
int vis[maxn];
ll c[maxn];
ll ans = 0;
ll dfs(int u,int fa,ll w){
deep[u] = deep[fa] + 1;
dfn[u] = ++ldfn;f[u][0] = fa;
c[u] =c[fa]+w;
ll temp = 1;
for(int k=1;k<=20;k++) f[u][k] = f[f[u][k-1]][k-1];
for(auto x:v[u]){
if(x.first == fa) continue;
temp += dfs(x.first,u,x.second);
}
ans = (ans + temp*(n-temp)%mod*w%mod*2)%mod;
return temp;
}
int LCA(int u,int v){///求lca
if(deep[u] < deep[v]) swap(u,v);
for(int k=20;k>=0;k--) if(deep[v]<=deep[f[u][k]]) u = f[u][k];
if(u == v) return u;
for(int k=20;k>=0;k--){
if(f[u][k]!=f[v][k]){
u = f[u][k];
v = f[v][k];
}
}
return f[u][0];
}
void Insert(int x){///虚树中插入节点x
if(!s) {st[++s] = x;return ;}
int lca = LCA(st[s],x);
if(lca == st[s]) {st[++s] = x;return ;}
while(s>1&&dfn[lca]<=dfn[st[s-1]]){
g[st[s-1]].push_back(st[s]);
g[st[s]].push_back(st[s-1]);
s--;
}
if(lca != st[s]){
g[lca].push_back(st[s]);
g[st[s]].push_back(lca);
st[s] = lca;
}
st[++s] = x;
}
struct Query{
int c,id;
bool friend operator<(Query a,Query b){
if(a.c == b.c) return dfn[a.id]<dfn[b.id];
return a.c < b.c;
}
}q[maxn];
ll tempans = 0;
ll tempcot = 0;
ll dfs(int u,int fa){
ll temp = vis[u];///节点数
ll dis = c[u]-c[fa];
for(int e:g[u]){
if(e == fa) continue;
temp += dfs(e,u);
}
g[u].clear();
tempans = (tempans + dis%mod*(tempcot - temp)%mod*temp*2)%mod;
return temp;
}
int main(){
read(n);read(m);
for(int i=1;i<=n;i++){
scanf("%d",&q[i].c);
q[i].id = i;
}
for(int i=1;i<=n-1;i++){
int x,y,w;scanf("%d%d%d",&x,&y,&w);
v[x].push_back({y,w});
v[y].push_back({x,w});
}
dfs(1,1,0);
/// debug(ans);
sort(q+1,q+1+n);
for(int i=1;i<=n;i++){
int k = i;
s = tempans = tempcot = 0;
while(q[k].c == q[i].c&&k<=n){
Insert(q[k].id);
vis[q[k].id] = 1;
tempcot++;
k++;
}
while(s>1){
g[st[s-1]].push_back(st[s]);
g[st[s]].push_back(st[s-1]);
s--;
}
dfs(q[i].id,q[i].id);
ans = (ans-tempans+mod)%mod;
k = i;
while(q[k].c == q[i].c&&k<=n){
vis[q[k].id] = 0;
k++;
}
i = k-1;
}
printf("%lld\n",ans);
return 0;
}
/**
10
1 5 13
1 9 6
2 1 19
2 4 8
2 3 91
5 6 8
7 5 4
7 8 31
10 7 9
3
4 5 7 8 3
**/
相关补充:
此题的做法应该还有启发式合并 与 点分治
初见这题,我无奈地写了一下点分治,结果只过了一半的样例
感觉是因为写了一个map记录,使得复杂度变成了nlg^2
这里附带一下代码(正确性未知,1s卡T,3s估计可以检验):
/*** keep hungry and calm CoolGuang!***/
#include <bits/stdc++.h>
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline")
#include<stdio.h>
#include<queue>
#include<algorithm>
#include<string.h>
#include<iostream>
#define debug(x) cout<<#x<<":"<<x<<endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const ll INF=2e18;
const int maxn=2e5+6;
const int mod=1e9+7;
const double eps=1e-15;
inline bool read(ll &num)
{char in;bool IsN=false;
in=getchar();if(in==EOF) return false;while(in!='-'&&(in<'0'||in>'9')) in=getchar();if(in=='-'){ IsN=true;num=0;}else num=in-'0';while(in=getchar(),in>='0'&&in<='9'){num*=10,num+=in-'0';}if(IsN) num=-num;return true;}
ll n,m,p;
struct node{
int e,next;
ll w;
}edge[maxn*2];
int head[maxn*2];
ll cnt = 0;
unordered_map<int,ll>b,c;///颜色,数量
int a[maxn];
void addedge(int u,int v,ll w){
edge[cnt] = node{v,head[u],w};
head[u] = cnt++;
}
ll ans = 0;
int sz[maxn],mx[maxn],sum = n,rt = 0;
int vis[maxn];
void getrt(int u,int fa){
mx[u] = 0;sz[u] = 1;
for(int i=head[u];~i;i=edge[i].next){
int e = edge[i].e;
if(e==fa||vis[e]) continue;
getrt(e,u);
sz[u] += sz[e];
mx[u] = max(mx[u],sz[e]);
}
mx[u] = max(mx[u],sum-sz[u]);
if(mx[rt]>mx[u]) rt = u;
}
ll tempsum = 0,tempcot = 0;
void getdis(int u,int fa,ll w,ll f){
b[a[u]] = (b[a[u]] + w*f)%mod;
tempsum = (tempsum + w*f)%mod;///该节点到所有颜色的距离之和
tempcot = (tempcot + 1*f)%mod;
c[a[u]] = (c[a[u]] + 1*f)%mod;
for(int i=head[u];~i;i=edge[i].next){
int e = edge[i].e;
if(e == fa||vis[e]) continue;
getdis(e,u,w+edge[i].w,f);
}
}
void go(int u,int fa,ll w){
/// printf("%d %lld %lld %d %lld %lld\n",u,tempcot,tempsum,a[u],c[a[u]],b[a[u]]);
ans = ans +(((tempcot-c[a[u]])*w)%mod + tempsum-b[a[u]])%mod;///加上所有贡献
for(int i=head[u];~i;i=edge[i].next){
int e = edge[i].e;
if(e == fa||vis[e]) continue;
go(e,u,w+edge[i].w);
}
}
void again(int u,int fa,ll w,int ct){
if(a[u] != a[ct]) ans = (ans+w)%mod;
for(int i=head[u];~i;i=edge[i].next){
int e = edge[i].e;
if(e==fa||vis[e]) continue;
again(e,u,w+edge[i].w,ct);
}
}
void calc(int u){///calculate the subtree of u
b.clear();c.clear();
tempsum = 0;tempcot = 0;
getdis(u,u,0,1ll);
/// printf("%lld %lld\n",tempsum,tempcot);
for(int i=head[u];~i;i=edge[i].next){
int e = edge[i].e;
if(vis[e]) continue;
getdis(e,u,edge[i].w,-1ll);
go(e,u,edge[i].w);
getdis(e,u,edge[i].w,1ll);
again(e,u,edge[i].w,u);
}
}
void solve(int u){
/// printf("%d----------\n",rt);
vis[u] = 1;calc(u);
for(int i=head[u];~i;i=edge[i].next){
int e = edge[i].e;
if(vis[e]) continue;
sum = sz[e];mx[rt = 0] = n;
getrt(e,u);solve(rt);
}
}
int main()
{
memset(head,-1,sizeof(head));
read(n);read(p);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<=n-1;i++){
int x,y;ll w;
scanf("%d%d%lld",&x,&y,&w);
addedge(x,y,w);
addedge(y,x,w);
}
sum = n;
mx[rt = 0] = n;
getrt(1,1);
ans = 0;
solve(rt);
printf("%lld\n",ans);
return 0;
}
全当复习一下点分治