[洛谷P4719] 动态DP模板

题意简述

一棵 \(n\)个点的树,点带点权。
\(m\) 次操作,每次操作给定 \(x,y\) ,表示修改点 \(x\) 的权值为 \(y\)
每次操作后求出这棵树的最大权独立集的权值大小。

\(n,m \leq 10^5\)


题解

首先有一个 \(O(nm)\)\(DP\)
\(f[u][0/1]\) 分别表示以 \(u\) 为根的子树中,\(u\) 不选/选 的最大独立集权值大小。
\[ f[u][0]=\sum\limits_{fa[v]=u} max(f[v][0],f[v][1]) \\ f[u][1]=val[u]+\sum\limits_{fa[v]=u} f[v][0] \]

显然超时。考虑如何优化。
注意到每次只修改一个点,也就是说只有该点到根节点的路径上的点的 \(dp\) 值有变化。
这并没有什么卵用,如果是链就废了。但这提示我们考虑树剖(神逻辑…)

\(g[u][0/1]\) 表示只考虑所有轻儿子时的 \(dp\) 值。
\[ g[u][0]=\sum\limits_{v为轻子} max(f[v][0],f[v][1]) \\ g[u][1]=val[u]+\sum\limits_{v为轻子} f[v][0] \]
\(v\)\(u\) 的重子,那么
\[ f[u][0]=g[u][0]+max(f[v][0],f[v][1]) \\ f[u][1]=g[u][1]+f[v][0] \]
这可以写成广义矩阵乘法形式(+变成 \(max\),乘变为+):
\[ \begin{equation*} \left( \begin{array}{cc} g[u][0]& g[u][0] \\ g[u][1]& -\infty \end{array} \right ) \times \left( \begin{array}{cc} f[v][0]\\ f[v][1] \end{array} \right ) = \left( \begin{array}{cc} f[u][0]\\ f[u][1] \end{array} \right ) \end{equation*} \]
(而且可以发现,此式可用在“更新”中:\((原) \times (新加入)=(新)\)
根据此式递推下去,树根的 \(dp\) 值就是树根所在的重链的 \(\left( \begin{array}{cc} g[u][0]& g[u][0] \\ g[u][1]& -\infty \end{array} \right )\) 矩阵乘积 再乘上 \(\left( \begin{array}{cc} 0\\ 0 \end{array} \right )\)
用线段树维护每个点的 \(\left( \begin{array}{cc} g[u][0]& g[u][0] \\ g[u][1]& -\infty \end{array} \right )\) 和区间矩阵积即可。

总体思路有了后,还剩两个细节。
一是,如何预处理出 \(g[u][0/1]\)
还记得前面说的 \((原) \times (新加入)=(新)\) 嘛?
最原始 \(g[u][0]=0,g[u][1]=val[u]\)\(dfs\)一遍,每个点的轻子的 \(f\) 值更新此点的 \(g\) 值;最后别忘了加上重子,求出此点的 \(f\) 值并传到其父节点。

二是,具体如何修改。
首先修改 \(x\) 的矩阵(变了的是 \(g[x][1]\)),然后往上跳链,修改跳到的 重链顶端点的父节点 的矩阵。
在修改时还会发现,被修改的点可能不止一个轻子。但由于每个轻子对该点答案的贡献是独立的(详见转移方程),只需记下该轻子在修改前后对父节点的贡献,然后相减更新父节点的矩阵。

细节还是很多的。


代码

代码写吐+调吐……更想不到的是写个题解都如此煎熬……

#include<cstdio>
#include<iostream>
#include<algorithm>

#define INF 1000000000

using namespace std;

int read(){
    int x=0,f=1;
    char ch=getchar();
    while(!isdigit(ch) && ch!='-') ch=getchar();
    if(ch=='-') f=-1,ch=getchar();
    while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    return x*f;
}

const int N = 100005;

int n,m,val[N];
struct node{
    int v;
    node *nxt;
}pool[N*2],*h[N];
int cnt1;
void addedge(int u,int v){
    node *p=&pool[++cnt1],*q=&pool[++cnt1];
    p->v=v;p->nxt=h[u];h[u]=p;
    q->v=u;q->nxt=h[v];h[v]=q;
}

struct Mat{
    int a[2][2];
    Mat() { a[0][0]=a[0][1]=a[1][0]=a[1][1]=0; }
    Mat operator * (const Mat &b) const{
        Mat c;
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++){
                c.a[i][j]=-INF;
                for(int k=0;k<2;k++) 
                    c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]);
            }
        return c;
    }
}m0[N],mm[N*2];

int dfn[N],top[N],sz[N],son[N],tot,bot[N],re[N],fa[N];
void dfs1(int u){
    int v,Bson=0;
    sz[u]=1;
    for(node *p=h[u];p;p=p->nxt)
        if(!sz[v=p->v]){
            fa[v]=u;
            dfs1(v);
            sz[u]+=sz[v];
            if(sz[v]>Bson) Bson=sz[v],son[u]=v;
        }
}
void dfs2(int u){
    int v=son[u];
    if(v){
        top[v]=top[u];
        dfn[v]=++tot;
        re[tot]=v;
        dfs2(v);
    }
    else bot[top[u]]=u;
    for(node *p=h[u];p;p=p->nxt)
        if(!dfn[v=p->v]){
            top[v]=v;
            dfn[v]=++tot;
            re[tot]=v;
            dfs2(v);
        }
}
int g0,g1;
void getm(int u){
    int v;
    Mat c;
    m0[u].a[0][0]=m0[u].a[0][1]=0; /**/
    m0[u].a[1][0]=val[u]; m0[u].a[1][1]=-INF;
    for(node *p=h[u];p;p=p->nxt)
        if(fa[v=p->v]==u && v!=son[u]){
            getm(v);
            c.a[0][0]=g0; c.a[1][0]=g1; c.a[0][1]=c.a[1][1]=0;
            c=m0[u]*c;
            m0[u].a[0][0]=m0[u].a[0][1]=c.a[0][0];
            m0[u].a[1][0]=c.a[1][0]; m0[u].a[1][1]=-INF;
        }
    if(v=son[u]){
        getm(v);
        c.a[0][0]=g0; c.a[1][0]=g1; c.a[0][1]=c.a[1][1]=0;
        c=m0[u]*c;
        g0=c.a[0][0]; g1=c.a[1][0];/**/
    }
    else{ g0=0; g1=val[u]; }
}

int cnt,root,ch[N*2][2];
void build(int x,int l,int r){
    if(l==r) { mm[x]=m0[re[l]]; return; }
    int mid=(l+r)>>1;
    build(ch[x][0]=++cnt,l,mid);
    build(ch[x][1]=++cnt,mid+1,r);
    mm[x]=mm[ch[x][0]]*mm[ch[x][1]];
}
void change(int x,int l,int r,int c,int y0,int y1){
    if(l==r) { 
        mm[x].a[0][0]+=y0; mm[x].a[0][1]+=y0;
        mm[x].a[1][0]+=y1; mm[x].a[1][1]=-INF;
        return; 
    }
    int mid=(l+r)>>1;
    if(c<=mid) change(ch[x][0],l,mid,c,y0,y1);
    else change(ch[x][1],mid+1,r,c,y0,y1);
    mm[x]=mm[ch[x][0]]*mm[ch[x][1]];
}
Mat sum(int x,int l,int r,int L,int R){
    if(L<=l && r<=R) return mm[x];
    int mid=(l+r)>>1;
    if(R<=mid) return sum(ch[x][0],l,mid,L,R);/**/
    else if(L>mid) return sum(ch[x][1],mid+1,r,L,R); /**/
    return sum(ch[x][0],l,mid,L,mid)*sum(ch[x][1],mid+1,r,mid+1,R);
}

void jump(int x,int y){
    Mat g;
    int p0,p1,gg0,gg1;
    g0=0; g1=y;
    while(x){ /**/
        g=sum(root,1,n,dfn[top[x]],dfn[bot[top[x]]]);
        p0=g.a[0][0]; p1=g.a[1][0];/**/
        change(root,1,n,dfn[x],g0,g1);
        g=sum(root,1,n,dfn[top[x]],dfn[bot[top[x]]]);
        gg0=g.a[0][0]; gg1=g.a[1][0];
        g1=gg0-p0; g0=max(gg0,gg1)-max(p0,p1);
        x=fa[top[x]];
    }
}

int main()
{
    n=read(); m=read();
    for(int i=1;i<=n;i++) val[i]=read();
    for(int i=1;i<n;i++) addedge(read(),read());
    
    dfs1(1);
    top[1]=1; dfn[1]=++tot; re[tot]=1; dfs2(1);
    getm(1);
    build(root=++cnt,1,n);
    
    Mat cur;
    int x,y;
    while(m--){
        x=read(); y=read();
        jump(x,y-val[x]); val[x]=y;
        cur=sum(root,1,n,1,dfn[bot[1]]);
        printf("%d\n",max(cur.a[0][0],cur.a[1][0])); /**/
    }
    
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/lindalee/p/12336828.html