// 树剖基础模板*真 P3384 【模板】树链剖分 — heavy-light decomposition (tree chain decomposition) template for Luogu P3384

#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cstdio>
#include<cmath>
// Segment-tree helper macros. Every body is fully parenthesized so the macros
// expand safely inside larger expressions (e.g. `lson + 1`, `fmid + 1`).
// fmid: midpoint of the interval covered by node `rt` (expects `tree` and `rt` in scope).
#define fmid ((tree[rt].l + tree[rt].r) >> 1)
// lson / rson: children of the heap-indexed node `rt` (expects `rt` in scope).
#define lson (rt << 1)
#define rson (rt << 1 | 1)

using namespace std;
typedef long long ll;
// Array capacity; vertices are 1-indexed, so this must exceed the largest n read in main.
const int maxn = 3e5 + 10;

// One directed adjacency-list entry. The edges leaving a vertex form a singly
// linked list threaded through `next`; the list is terminated by -1.
struct data
{
    int next;   // index of the following edge with the same source vertex, or -1
    int to;     // destination vertex of this edge
};
// Segment-tree node covering the DFS-order positions [l, r].
struct node
{
    int l;      // left end of the covered interval
    int r;      // right end of the covered interval
    int sum;    // sum of the covered values, kept reduced modulo `mod`
    int lazy;   // pending addition not yet pushed down to the children
};
int mod;    // modulus applied to every stored and reported sum (read in main)
// Per-vertex arrays (vertices are 1-indexed): f = parent, son = heavy child,
// sz = subtree size, deep = depth, id = DFS position, top = top of the heavy chain,
// head = first adjacency-list edge (-1 when none). rk is indexed by DFS position
// and gives back the vertex. NOTE(review): n and m are shadowed by main's locals.
int n, m, f[maxn], son[maxn], sz[maxn], deep[maxn], id[maxn], rk[maxn], head[maxn], top[maxn];
int a[maxn];    // initial vertex values
// tot: number of edges stored so far; time: DFS position counter used by dfs2.
// NOTE(review): the global name `time` can collide with ::time / std::time when
// <ctime> or <time.h> gets pulled in transitively; consider renaming it together
// with its uses in dfs2 and main.
int tot, time;
// Written as `struct data` (elaborated type specifier): with `using namespace std;`
// in C++17, a bare `data` is ambiguous with the std::data function templates.
struct data edge[maxn<<1];
node tree[maxn<<2];

// Append the directed edge u -> v at the front of u's adjacency list.
// The fields are assigned one by one instead of through a compound literal,
// which is only a GNU extension in C++.
void add(int u, int v)
{
    edge[tot].next = head[u];
    edge[tot].to = v;
    head[u] = tot++;
}
// First DFS: for every vertex in the subtree of x, records its depth, its parent,
// its subtree size and its heavy child (the child with the largest subtree; on a
// tie the child met first is kept). f[x] must already be set; for the root it is
// the zero-initialized 0.
void dfs1(int x)
{
    deep[x] = deep[f[x]] + 1;
    sz[x] = 1;
    for(int e = head[x]; e != -1; e = edge[e].next)
    {
        int child = edge[e].to;
        if(child == f[x])
            continue;               // do not walk back up to the parent
        f[child] = x;
        dfs1(child);
        sz[x] += sz[child];
        if(sz[son[x]] < sz[child])  // son[x] == 0 means "none yet"; sz[0] is 0
            son[x] = child;
    }
}
// Second DFS: numbers the vertices in preorder, visiting the heavy child first,
// so every heavy chain and every subtree occupies a contiguous range of DFS
// positions. t is the top vertex of the chain that x belongs to.
void dfs2(int x, int t)
{
    top[x] = t;
    ++time;
    id[x] = time;
    rk[time] = x;
    if(son[x] != 0)
        dfs2(son[x], t);            // the heavy child extends the current chain
    for(int e = head[x]; e != -1; e = edge[e].next)
    {
        int child = edge[e].to;
        if(child == f[x] || child == son[x])
            continue;
        dfs2(child, child);         // every light child starts its own chain
    }
}
// Recompute node rt's sum from its two children, modulo `mod`.
// The addition is carried out in 64 bits: both child sums are below `mod`, and
// mod (an int) may be as large as INT_MAX, so an int addition could overflow.
void pushup(int rt)
{
    tree[rt].sum = (int)(((long long)tree[lson].sum + tree[rson].sum) % mod);
}
// Build the segment-tree node rt over DFS positions [l, r] and clear its lazy tag.
// A leaf at position l holds the initial value of vertex rk[l], reduced modulo `mod`.
void build(int l, int r, int rt)
{
    tree[rt].l = l;
    tree[rt].r = r;
    tree[rt].lazy = 0;
    if(l != r)
    {
        int mid = (l + r) >> 1;     // same split point as fmid for this node
        build(l, mid, lson);
        build(mid + 1, r, rson);
        pushup(rt);
    }
    else
    {
        tree[rt].sum = a[rk[l]] % mod;
    }
}

// Push node rt's pending addition down to both children, then clear it.
// All arithmetic is done in 64 bits and every stored value is reduced modulo
// `mod`. Otherwise `lazy * length` overflows int for large tags or intervals, and
// tags accumulated without reduction eventually overflow as well.
void pushdown(int rt)
{
    if(!tree[rt].lazy)
        return;
    long long tag = tree[rt].lazy % mod;
    int child[2] = {lson, rson};
    for(int k = 0; k < 2; k++)
    {
        int c = child[k];
        long long len = tree[c].r - tree[c].l + 1;
        tree[c].sum = (int)((tree[c].sum + tag * len) % mod);
        tree[c].lazy = (int)((tree[c].lazy + tag) % mod);
    }
    tree[rt].lazy = 0;
}
// Add val to every DFS position in [l, r] that lies inside node rt's interval.
// val is first normalized into [0, mod), and all products are formed in 64 bits.
// `length * val` and the accumulated tags otherwise overflow int (mod may be as
// large as INT_MAX), and a negative val would otherwise leave negative sums.
void update(int l, int r, int rt, int val)
{
    if(tree[rt].l >= l && tree[rt].r <= r)
    {
        long long v = ((long long)val % mod + mod) % mod;
        long long len = tree[rt].r - tree[rt].l + 1;
        tree[rt].sum = (int)((tree[rt].sum + v * len) % mod);
        tree[rt].lazy = (int)((tree[rt].lazy + v) % mod);
        return;
    }
    pushdown(rt);
    int mid = fmid;
    if(mid >= l) update(l, r, lson, val);
    if(mid < r) update(l, r, rson, val);
    pushup(rt);
}
// Sum of the values at DFS positions [l, r] inside node rt's interval, modulo `mod`.
// Partial results are accumulated in 64 bits: each is below `mod` (which may be as
// large as INT_MAX), so adding two of them as ints could overflow.
int query(int l, int r, int rt)
{
    if(tree[rt].l >= l && tree[rt].r <= r)
        return tree[rt].sum % mod;
    pushdown(rt);
    int mid = fmid;
    long long ans = 0;
    if(mid >= l) ans += query(l, r, lson);
    if(mid < r) ans += query(l, r, rson);
    return (int)(ans % mod);
}
// Path update: add z (reduced modulo `mod`) to every vertex on the tree path x–y.
// Repeatedly lifts whichever endpoint has the deeper chain top, updating the chain
// segment it leaves, until both endpoints lie on the same heavy chain; that final
// segment is then updated as one DFS range.
void up(int x, int y, int z)
{
    int u = x, v = y;
    int delta = z % mod;
    for(; top[u] != top[v]; u = f[top[u]])
    {
        if(deep[top[u]] < deep[top[v]])
            swap(u, v);
        update(id[top[u]], id[u], 1, delta);
    }
    if(id[u] > id[v])
        swap(u, v);
    update(id[u], id[v], 1, delta);
}

// Path query: sum of the values on the tree path x–y, modulo `mod`.
// Walks the heavy chains the same way `up` does. The running total is kept in
// 64 bits, because each chain sum is below `mod` (up to INT_MAX) and adding two of
// them in int could overflow.
int ask(int x, int y)
{
    long long ans = 0;
    while(top[x] != top[y])
    {
        if(deep[top[x]] < deep[top[y]]) swap(x, y);
        ans = (ans + query(id[top[x]], id[x], 1)) % mod;
        x = f[top[x]];
    }
    if(id[x] > id[y])
        swap(x, y);
    ans = (ans + query(id[x], id[y], 1)) % mod;
    return (int)ans;
}
void upran(int x, int z) // 对x子树更新
{
    update(id[x], id[x] + sz[x] - 1, 1, z);
}
// Subtree query: sum of the values in the subtree of x, modulo `mod`. The subtree
// is the contiguous DFS range [id[x], id[x] + sz[x] - 1].
int askran(int x)
{
    int lo = id[x];
    int hi = id[x] + sz[x] - 1;
    return query(lo, hi, 1) % mod;
}

// Input format: n (vertices), m (operations), root r, modulus mod; then the n
// initial vertex values and n-1 undirected edges. Then m operations follow:
//   1 x y z : add z to every vertex on the path x–y
//   2 x y   : print the path sum x–y modulo `mod`
//   3 x z   : add z to every vertex in the subtree of x
//   any other op (4) x : print the subtree sum of x modulo `mod`
int main()
{
    int n, m, r;
    memset(head, -1, sizeof(head));
    time = 0, tot = 0;
    scanf("%d%d%d%d", &n, &m, &r, &mod);
    for(int i = 1; i <= n; i++)
        scanf("%d", &a[i]);
    for(int i = 1; i < n; i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v);
        add(v, u);
    }
    dfs1(r);
    // The root heads its own heavy chain, so the chain top passed for it is r
    // itself (0 is not a vertex and must not be recorded as a chain top).
    dfs2(r, r);
    build(1, n, 1);
    while(m--)
    {
        int op;
        scanf("%d", &op);
        if(op == 1)
        {
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            up(x, y, z);
        }
        else if(op == 2)
        {
            int x, y;
            scanf("%d%d", &x, &y);
            printf("%d\n", ask(x, y));
        }
        else if(op == 3)
        {
            int x, z;
            scanf("%d%d", &x, &z);
            upran(x, z);
        }
        else
        {
            int x;
            scanf("%d", &x);
            printf("%d\n", askran(x));
        }
    }
    return 0;
}






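// Blog page footer (residue from the scraped page):
//   发布了40 篇原创文章 · 获赞 13 · 访问量 847  (40 original posts · 13 likes · 847 views)
//   猜你喜欢  ("You may also like")
// 转载自 (reposted from): blog.csdn.net/weixin_43891021/article/details/102864140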