树链剖分入门学习笔记

发布于 2021-02-25  251 次阅读


这个算法还是比较容易理解的,只是代码调的我有点崩溃(还不是我太蒻了),因此学习笔记写的会比较简单。

本博客参考:PoPoQQQ 的课件;ChinHhh 的博客; attack 的代码启发。

前置

  • 知识点:DFS序、线段树。

然后是一些概念:

  • 重儿子:一个节点的子节点中,$size$(即子树大小)最大的那个。;
  • 轻儿子:非重儿子的子节点;
  • 重边:一个点到它的重儿子的边;
  • 轻边:一个点到它的轻儿子的边;
  • 重链:由重边连结形成的链;
  • 链顶:一条重链中深度最小的节点。

显然,一个非叶节点有且仅有一个重儿子。

树链剖分模板题:Luogu3384

预处理

预处理主要是两遍 dfs。

dfs1

第一次 dfs 处理出每个点和深度/父亲/子树大小/重儿子。

inline void dfs1(int u,int f,int depth){
    dep[u]=depth;//深度
    fa[u]=f;siz[u]=1;//父亲
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f) continue;
        dfs1(v,u,depth+1);
        siz[u]+=siz[v];//子树大小
        if(siz[v]>siz[son[u]]) son[u]=v;//重儿子
    }
}

dfs2

第二次 dfs 处理出 dfs 序(注意,这个 dfs 序是在先走重儿子前提下建立的 dfs 序)并把树上点的值赋到对应的新编号上,同时处理每个点所在链的链顶。

inline void dfs2(int u,int Top){
    id[u]=++tot;//记录在 dfs 序中新编号
    a[tot]=v[u];//赋值
    top[u]=Top;//处理链顶
    if(!son[u]) return;
    dfs2(son[u],Top);//先走重儿子,Top 不变
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);//轻儿子的链顶变成他自己
    }
}

容易发现两个性质:

  • 每一条重链的编号是连续的;
  • 每一棵子树的编号是连续的。

这两个性质是我们处理问题的关键。

建树

就是把新编号形成的新序列建成线段树,与正常线段树建树相同,不多阐述,可见文末完整代码。

简单问题的处理

树链剖分实际上就是把树上的问题转到新的依照 dfs 序形成的序列中进行处理,因而这里介绍模板题中四个问题的处理,即处理一段路径或是处理一棵子树。处理方法可以在不同题目间转化/引申。

求两点路径上的点权和

求两点路径上的点权和时,我们假定点 $x$ 为链顶更深的那个点,那我们每次都将 $ans$ 加上 $x$ 点至其链顶的点权和,然后再将他跳到其链顶上面一个点。我们反复执行这个步骤,直至两点在同一条链上,直接求出区间和即可。

线段树模板就略去了,有需要的可以在文末给出的代码中查看。

inline void queryTree(int x,int y){
    int ret=0;
    while(top[x]!=top[y]){//跳到处于同一条链上为止
        if(dep[top[x]]<dep[top[y]]) swap(x,y);//跳链顶深的那个点
        ret+=query(1,1,n,id[top[x]],id[x]);//转到线段树
        ret%=mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ret+=query(1,1,n,id[x],id[y]);//转到线段树
    ret%=mod;
    printf("%d\n",ret);
}

将两点路径上的点权都加上一个值

处理方法同上,直接给出代码。

inline void addTree(int x,int y,int k){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        add(1,1,n,id[top[x]],id[x],k);//转到线段树
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    add(1,1,n,id[x],id[y],k);//转到线段树
}

求子树的点权和

由于子树内编号是连续的,且我们记录了子树大小,这个问题就非常容易处理了。

scanf("%d",&x);
printf("%d\n",query(1,1,n,id[x],id[x]+siz[x]-1));//直接转为线段树上求区间和

子树的点权都加上一个值

处理方法同上,直接给出代码。

scanf("%d%d",&x,&z);z%=mod;
add(1,1,n,id[x],id[x]+siz[x]-1,z);//直接转为线段树上区间加

总结 & 一道练手题

简单总结一下树剖的步骤:

  1. 根据题目需要写出线段树处理问题的代码;
  2. 两边 dfs ,然后建树;
  3. 再把问题需要处理的树上路径转化成线段树上的区间。

Luogu1505 [国家集训队]旅游 主要是码量不小,练手挺好,有一个简单的知识点,就是将边权转化为点权,也比较容易。

复杂度证明

树链剖分的复杂度为 $O(n\log^2n)$。

简单说,就是每次询问,树上跳来跳去需要一个 $\log$,然后线段树查询需要一个 $\log$。

线段树查询的复杂度这个不谈,稍微证明一下为什么树上最多跳 $O(\log n)$ 次。

容易想到,我们只需要证明任意一点到根的路径上有不超过 $O(\log n)$ 条重链,由于轻重链交错,所以只需要证明有不超过 $O(\log n)$ 条轻边即可。

于是:任意一点 $x$ 沿着父亲往上走,如果走过一条轻边,因为这是轻边,所以父亲一定有个重儿子比当前儿子大,也就是说 size[fa[x]]>size[x]*2,因此,显然有走到根节点最多走的轻边不超过 $O(\log n)$ 条。

完整代码(Luogu3384 【模板】轻重链剖分)

#include<bits/stdc++.h>
#define ls u<<1
#define rs u<<1|1
using namespace std;
const int N=2e5+5;
struct SegmentTree{
    int len,tag,sum;
}tree[N<<3];
struct edge{
    int to,nxt;
}e[N<<1];
int cnt,head[N],n,m,mod,r;
int v[N],a[N];
inline void add(int u,int v){
    e[++cnt].to=v;
    e[cnt].nxt=head[u];
    head[u]=cnt;
}
int dep[N],fa[N],son[N],siz[N],top[N],id[N];
inline void dfs1(int u,int f,int depth){
    dep[u]=depth;
    fa[u]=f;siz[u]=1;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==f) continue;
        dfs1(v,u,depth+1);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]]) son[u]=v;
    }
}
int tot;
inline void dfs2(int u,int Top){
    id[u]=++tot;
    a[tot]=v[u];
    top[u]=Top;
    if(!son[u]) return;
    dfs2(son[u],Top);
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}
inline void pushdown(int u){
    if(!tree[u].tag) return;
    tree[ls].tag=(tree[ls].tag+tree[u].tag)%mod;
    tree[rs].tag=(tree[rs].tag+tree[u].tag)%mod;
    tree[ls].sum=(tree[ls].sum+tree[ls].len*tree[u].tag)%mod;
    tree[rs].sum=(tree[rs].sum+tree[rs].len*tree[u].tag)%mod;
    tree[u].tag=0;
}
inline void pushup(int u){
    tree[u].sum=(tree[ls].sum+tree[rs].sum)%mod;
}
inline void build(int u,int l,int r){
    tree[u].len=r-l+1;
    if(l==r){
        tree[u].sum=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    pushup(u);
}
inline void add(int u,int l,int r,int L,int R,int k){
    if(l>R||r<L) return;
    if(l>=L&&r<=R){
        tree[u].sum+=tree[u].len*k;
        tree[u].tag+=k;
        return;
    }
    pushdown(u);
    int mid=(l+r)>>1;
    add(ls,l,mid,L,R,k);
    add(rs,mid+1,r,L,R,k);
    pushup(u);
}
inline int query(int u,int l,int r,int L,int R){
    if(l>R||r<L) return 0;
    if(l>=L&&r<=R) return tree[u].sum;
    pushdown(u);
    int ret=0;
    int mid=(l+r)>>1;
    ret+=query(ls,l,mid,L,R);
    ret+=query(rs,mid+1,r,L,R);
    ret%=mod;
    return ret;
}
inline void addTree(int x,int y,int k){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        add(1,1,n,id[top[x]],id[x],k);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    add(1,1,n,id[x],id[y],k);
}
inline void queryTree(int x,int y){
    int ret=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        ret+=query(1,1,n,id[top[x]],id[x]);
        ret%=mod;
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    ret+=query(1,1,n,id[x],id[y]);
    ret%=mod;
    printf("%d\n",ret);
}
int main(){
    scanf("%d%d%d%d",&n,&m,&r,&mod);
    for(int i=1;i<=n;i++) scanf("%d",&v[i]);
    for(int i=1,u,v;i<n;i++){
        scanf("%d%d",&u,&v);
        add(u,v);add(v,u);
    }
    dfs1(r,0,1);
    dfs2(r,r);
    build(1,1,n);
    while(m--){
        int opt,x,y,z;
        scanf("%d",&opt);
        if(opt==1){
            scanf("%d%d%d",&x,&y,&z);
            z%=mod;
            addTree(x,y,z);
        }
        else if(opt==2){
            scanf("%d%d",&x,&y);
            queryTree(x,y);
        }
        else if(opt==3){
            scanf("%d%d",&x,&z);z%=mod;
            add(1,1,n,id[x],id[x]+siz[x]-1,z);
        }
        else{
            scanf("%d",&x);
            printf("%d\n",query(1,1,n,id[x],id[x]+siz[x]-1));
        }
    }
    return 0;
}

月流华 岁遗沙 万古吴钩出玉匣