【upc】树统计 | 虚树、边权统计

问题 E: 树统计

时间限制: 1 Sec  内存限制: 128 MB
提交 状态

题目描述

骗分过样例,暴力出奇迹。
关于树的算法有一大堆,样样都是毒瘤。
比如说 NOIP2018 提高组的 D2T3,如果会动态 DP 的做法那么就马上想到正解,但是 Tweetuzki 不会动态 DP,就只好骗分了。
可惜树题的码量也是超级大的。听说好多学长都会动态 DP,但是考场上调不出来,只好暴力分收场了。疯狂暗示
Tweetuzki 当时暴力写挂了,有 4 个点写成了死循环……于是分数白白少了 16 分。Tweetuzki 一想起这事,不禁夙夜忧叹,辗转反侧。
现在他又遇到一道毒瘤树上问题了,他下定决心:这次一定要把暴力分写满!
题目是这样的:
有一棵 n 个点的树,边有边权,每个点有颜色 ci。求所有颜色不同的点对的距离之和。由于答案可能很大,你只需要输出其对 998,244,353 取模的结果即可。
形式化地讲,记 u 号点和 v 号点在树上的距离为 dist(u,v),求:

输入

输入文件将会遵循以下格式:
n type
c1 c2 ⋯ cn
u1 v1 w1
u2 v2 w2

un−1 vn−1 wn−1
第一行两个正整数 n,type(2≤n≤2×105,1≤type≤6),其中 n 表示点数,type为部分分类型,详见数据范围,type=0 表示样例数据。
第二行输入 n 个正整数 ci(1≤ci≤109),表示每个点的颜色。
接下来n−1 行,每行输入三个正整数 ui,vi,wi(1≤ui<vi≤n,1≤wi≤109),描述这棵树。

输出

输出一行一个非负整数,表示答案对 998,244,353 取模的结果。

样例输入 Copy

4 0
1 2 3 3
1 2 5
2 3 4
3 4 7

样例输出 Copy

90

提示

满足条件的点对有 (1,2),(1,3),(1,4),(2,1),(2,3),(2,4),(3,1),(3,2),(4,1),(4,2),故答案为 5+9+16+5+4+11+9+4+16+11=90。

Subtask #1:n≤300, type=1。
Subtask #2:n≤2 000, type≤2。
Subtask #3:n≤10 000, type≤3。
Subtask #4:对于第 i (1≤i≤n) 号点,ci=i。type=4。
Subtask #5 :对于第 i(1≤i<n)条边,ui+1=vi。type=5。
Subtask #6:无特殊性质,type≤6。

题目大意:

中文题目

题目思路:

考虑将答案转换一下:

设全集为所有点间的距离,所求的点距离之和即为:

所有点间的距离 - 相同颜色点之间的距离

所以题目转换为如何求所有点间的距离与颜色相同点间的距离

其实颜色相同点间的距离与所有点间的距离求法一致

求所有点间的距离:

对整棵树进行dfs,考虑每条边的贡献,每条边贡献的路径数即为:该边左边的点*该边右边的点(方案数)

所以答案就很显然了。

那么对于颜色相同点间的距离呢?

考虑和所有点间距离相同的解法,每次对一种颜色进行dfs

但是这样复杂度可能会变成O(n*n)

所以此时采用虚树的方法,将相同颜色的点构造一棵虚树,然后在虚树上跑一下dp就可以了

注意此时算时,虚树新增的lca 可能颜色不相同,所以不可计算入内

Code:

/*** keep hungry and calm CoolGuang!***/
#pragma GCC optimize(2)
//#include <bits/stdc++.h>
#include<stdio.h>
#include<queue>
#include<algorithm>
#include<string.h>
#include<iostream>
#define debug(x) cout<<#x<<":"<<x<<endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pp;
const ll INF=1e17;
const int maxn=3e5+6;
const int mod=998244353;
const double eps=1e-3;
inline bool read(ll &num)
{char in;bool IsN=false;
    in=getchar();if(in==EOF) return false;while(in!='-'&&(in<'0'||in>'9')) in=getchar();if(in=='-'){ IsN=true;num=0;}else num=in-'0';while(in=getchar(),in>='0'&&in<='9'){num*=10,num+=in-'0';}if(IsN) num=-num;return true;}

ll n,m,p;
int dfn[maxn],deep[maxn],f[maxn][21];///时间戳 深度 fa数组
vector< pair<int,int> >v[maxn];///存图
vector<int>g[maxn];
int st[maxn],s = 0;///模拟栈的实现
int ldfn = 0;
int vis[maxn];
ll c[maxn];
ll ans = 0;
ll dfs(int u,int fa,ll w){
    deep[u] = deep[fa] + 1;
    dfn[u] = ++ldfn;f[u][0] = fa;
    c[u] =c[fa]+w;
    ll temp = 1;
    for(int k=1;k<=20;k++) f[u][k] = f[f[u][k-1]][k-1];
    for(auto x:v[u]){
        if(x.first == fa) continue;
        temp += dfs(x.first,u,x.second);
    }
    ans = (ans + temp*(n-temp)%mod*w%mod*2)%mod;
    return temp;
}
int LCA(int u,int v){///求lca
    if(deep[u] < deep[v]) swap(u,v);
    for(int k=20;k>=0;k--) if(deep[v]<=deep[f[u][k]]) u = f[u][k];
    if(u == v) return u;
    for(int k=20;k>=0;k--){
        if(f[u][k]!=f[v][k]){
            u = f[u][k];
            v = f[v][k];
        }
    }
    return f[u][0];
}
void Insert(int x){///虚树中插入节点x
    if(!s) {st[++s] = x;return ;}
    int lca = LCA(st[s],x);
    if(lca == st[s]) {st[++s] = x;return ;}
    while(s>1&&dfn[lca]<=dfn[st[s-1]]){
        g[st[s-1]].push_back(st[s]);
        g[st[s]].push_back(st[s-1]);
        s--;
    }
    if(lca != st[s]){
        g[lca].push_back(st[s]);
        g[st[s]].push_back(lca);
        st[s] = lca;
    }
    st[++s] = x;
}
struct  Query{
    int c,id;
    bool friend operator<(Query a,Query b){
        if(a.c == b.c) return dfn[a.id]<dfn[b.id];
        return a.c < b.c;
    }
}q[maxn];
ll tempans = 0;
ll tempcot = 0;
ll dfs(int u,int fa){
    ll temp = vis[u];///节点数
    ll dis = c[u]-c[fa];
    for(int e:g[u]){
        if(e == fa) continue;
        temp += dfs(e,u);
    }
    g[u].clear();
    tempans = (tempans + dis%mod*(tempcot - temp)%mod*temp*2)%mod;
    return temp;
}
int main(){
    read(n);read(m);
    for(int i=1;i<=n;i++){
        scanf("%d",&q[i].c);
        q[i].id = i;
    }
    for(int i=1;i<=n-1;i++){
        int x,y,w;scanf("%d%d%d",&x,&y,&w);
        v[x].push_back({y,w});
        v[y].push_back({x,w});
    }
    dfs(1,1,0);
  ///  debug(ans);
    sort(q+1,q+1+n);
    for(int i=1;i<=n;i++){
        int k = i;
        s = tempans = tempcot = 0;
        while(q[k].c == q[i].c&&k<=n){
            Insert(q[k].id);
            vis[q[k].id] = 1;
            tempcot++;
            k++;
        }
        while(s>1){
            g[st[s-1]].push_back(st[s]);
            g[st[s]].push_back(st[s-1]);
            s--;
        }
        dfs(q[i].id,q[i].id);
        
        ans = (ans-tempans+mod)%mod;
       
        k = i;
        while(q[k].c == q[i].c&&k<=n){
            vis[q[k].id] = 0;
            k++;
        }
        i = k-1;
    }
    printf("%lld\n",ans);
    return 0;
}
/**
10
1 5 13
1 9 6
2 1 19
2 4 8
2 3 91
5 6 8
7 5 4
7 8 31
10 7 9
3
4 5 7 8 3
**/

相关补充:

此题的做法应该还有启发式合并 与 点分治

初见这题,我无奈地写了一下点分治,结果只过了一半的样例

感觉是因为写了一个map记录,使得复杂度变成了nlg^2

这里附带一下代码(正确性未知,1s卡T,3s估计可以检验):

/*** keep hungry and calm CoolGuang!***/
#include <bits/stdc++.h>
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline")
#include<stdio.h>
#include<queue>
#include<algorithm>
#include<string.h>
#include<iostream>
#define debug(x) cout<<#x<<":"<<x<<endl;
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const ll INF=2e18;
const int maxn=2e5+6;
const int mod=1e9+7;
const double eps=1e-15;
inline bool read(ll &num)
{char in;bool IsN=false;
    in=getchar();if(in==EOF) return false;while(in!='-'&&(in<'0'||in>'9')) in=getchar();if(in=='-'){ IsN=true;num=0;}else num=in-'0';while(in=getchar(),in>='0'&&in<='9'){num*=10,num+=in-'0';}if(IsN) num=-num;return true;}
ll n,m,p;
struct node{
    int e,next;
    ll w;
}edge[maxn*2];
int head[maxn*2];
ll cnt = 0;
unordered_map<int,ll>b,c;///颜色,数量
int a[maxn];
void addedge(int u,int v,ll w){
    edge[cnt] = node{v,head[u],w};
    head[u] = cnt++;
}
ll ans = 0;
int sz[maxn],mx[maxn],sum = n,rt = 0;
int vis[maxn];
void getrt(int u,int fa){
    mx[u] = 0;sz[u] = 1;
    for(int i=head[u];~i;i=edge[i].next){
        int e = edge[i].e;
        if(e==fa||vis[e]) continue;
        getrt(e,u);
        sz[u] += sz[e];
        mx[u] = max(mx[u],sz[e]);
    }
    mx[u] = max(mx[u],sum-sz[u]);
    if(mx[rt]>mx[u]) rt = u;
}
ll tempsum = 0,tempcot = 0;
void getdis(int u,int fa,ll w,ll f){

    b[a[u]] = (b[a[u]] + w*f)%mod;
    tempsum = (tempsum + w*f)%mod;///该节点到所有颜色的距离之和

    tempcot = (tempcot + 1*f)%mod;
    c[a[u]] = (c[a[u]] + 1*f)%mod;

    for(int i=head[u];~i;i=edge[i].next){
        int e = edge[i].e;
        if(e == fa||vis[e]) continue;
        getdis(e,u,w+edge[i].w,f);
    }
}
void go(int u,int fa,ll w){
   /// printf("%d %lld %lld %d %lld %lld\n",u,tempcot,tempsum,a[u],c[a[u]],b[a[u]]);
    ans = ans +(((tempcot-c[a[u]])*w)%mod + tempsum-b[a[u]])%mod;///加上所有贡献
    for(int i=head[u];~i;i=edge[i].next){
        int e = edge[i].e;
        if(e == fa||vis[e]) continue;
        go(e,u,w+edge[i].w);
    }
}
void again(int u,int fa,ll w,int ct){
    if(a[u] != a[ct]) ans = (ans+w)%mod;
    for(int i=head[u];~i;i=edge[i].next){
        int e = edge[i].e;
        if(e==fa||vis[e]) continue;
        again(e,u,w+edge[i].w,ct);
    }
}
void calc(int u){///calculate the subtree of u
    b.clear();c.clear();
    tempsum = 0;tempcot = 0;
    getdis(u,u,0,1ll);
  ///  printf("%lld %lld\n",tempsum,tempcot);
    for(int i=head[u];~i;i=edge[i].next){
        int e = edge[i].e;
        if(vis[e]) continue;
        getdis(e,u,edge[i].w,-1ll);
        go(e,u,edge[i].w);
        getdis(e,u,edge[i].w,1ll);
        again(e,u,edge[i].w,u);
    }
}
void solve(int u){
   /// printf("%d----------\n",rt);
    vis[u] = 1;calc(u);
    for(int i=head[u];~i;i=edge[i].next){
        int e = edge[i].e;
        if(vis[e]) continue;
        sum = sz[e];mx[rt = 0] = n;
        getrt(e,u);solve(rt);
    }
}
int main()
{
    memset(head,-1,sizeof(head));
    read(n);read(p);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    for(int i=1;i<=n-1;i++){
        int x,y;ll w;
        scanf("%d%d%lld",&x,&y,&w);
        addedge(x,y,w);
        addedge(y,x,w);
    }
    sum = n;
    mx[rt = 0] = n;
    getrt(1,1);
    ans = 0;
    solve(rt);
    printf("%lld\n",ans);
    return 0;
}

全当复习一下点分治

猜你喜欢

转载自blog.csdn.net/qq_43857314/article/details/107186161