有趣的线段树维护——吉老师线段树学习笔记

发布于 2021-03-17  259 次阅读


本来没打算写这玩意儿,结果学习了之后觉得吉老师线段树真的很有意思,所以就决定简单写一下学习笔记。

本博客参考:echo 的 博客 与 代码启示,jiry_2 的 课件

但是个人感觉该博客对于 pushdown 部分还是有点难以理解,所以决定用自己的语言写一篇博客阐述一下。

吉老师线段树

吉如一(jiry_2)的 PDF 见博客上方博客参考部分。

模板题:Luogu6242

题意:写一棵线段树,维护区间加、区间取最值、区间求和、区间最值与区间历史最值操作。

先不考虑历史最值问题,我们容易发现暴力维护区间取最值操作(本题中为将 $A_i$ 变成 $\min(A_i,v)$ 暴力方法即:一直往下搜直至搜到 $l=r$ 或 $v>max$ 为止,$max$ 为区间最大值)的复杂度是 $O(n^2)$ ,不可接受。我们考虑如何维护这个东西。

吉老师提出了一个很好的解决方案:对于每一个结点,我们维护区间最大值,最大值出现次数,区间严格次大值(用 $sec$ 表示)。我们维护好了这些信息之后,我们发现我们可以在区间取 $\min$ 操作时不断遍历线段树,直到 $sec<v$,由于我们维护好了区间最大值和最大值出现次数,我们可以很容易地维护其他需要维护的信息。吉老师也承认了自己在课件上的 $O(n\log n)$ 证明是伪证,但是可以确定的是,该算法的复杂度不多于 $O(n\log^2n)$。

如果对证明有兴趣,可以参考]吉如一的讨论,灵梦的洛谷日报文章, 以及吉如一的集训队论文。

那么,加入了历史最值操作之后要怎么维护呢?我将会下文的标记下传(pushdown)部分详细阐述。

接着将会着重写三个板块的内容:数值上传(pushup),标记下传(pushdown), 区间取最值(update_min)。

数值上传(pushup)

我们可以直接上传区间最大值、历史最大值和区间和。但是怎么维护区间严格次大值和最大值出现次数呢?我们先判断左右区间的区间最大值是否相同:若相同,将左右区间最大值出现次数相加,并取左右区间的严格次大值中较大的那个作为区间次大值;若不同,取最大值较大的那个区间的最大值区间次数,并将该区间的次大值与另一个区间的严格次大值比较以更新区间次大值。

该部分应该比较容易理解。

inline void pushup(int u){
    tree[u].maxa=max(tree[ls].maxa,tree[rs].maxa);
    tree[u].maxb=max(tree[ls].maxb,tree[rs].maxb);
    tree[u].sum=tree[ls].sum+tree[rs].sum;//直接上传
    if(tree[ls].maxa==tree[rs].maxa){//三种情况,不同取法
        tree[u].sec=max(tree[ls].sec,tree[rs].sec);
        tree[u].cnt=tree[ls].cnt+tree[rs].cnt;
    }
    else if(tree[ls].maxa>tree[rs].maxa){
        tree[u].sec=max(tree[ls].sec,tree[rs].maxa);
        tree[u].cnt=tree[ls].cnt;
    }
    else{
        tree[u].sec=max(tree[ls].maxa,tree[rs].sec);
        tree[u].cnt=tree[rs].cnt;
    }
}

标记下传(pushdown)

在这里着重讲一下加入了历史最大值操作要怎么维护。

首先我们要维护四个 tag,即 $add_a,add_a1,add_b,add_b1$。

struct Segment_Tree{
    int add_a,add_a1,add_b,add_b1;
    int maxa,sec,maxb,sum,cnt,len;
}tree[N<<3];

这四个 tag 分别是干什么的呢?由于在吉老师线段树中,我们区间取 $\min$ 时只会更新区间最大值的值,因此我们考虑将每个区间的数分为两类:最大数和非最大数,然后进行维护。这样我们在区间取 $\min$ 时就可以很好的只对最大值进行操作了。因此,我们用 $add_a,add_a1$ 分别表示当前区间最大值的加标记和当前区间非最大值的加标记(加标记是什么?参考 【模板】线段树 1),用 $add_b,add_b1$ 分别表示当前区间最大值历史上加的最多的那次的加标记以及对应的非最大值的历史上加的最多的那次的加标记。

什么是历史上加的最多的那次的加标记?我们举个例子,我们曾经给一个区间的最大值打上了 $+5$ 的加标记,然后下一个操作,这个区间的最大值要 $-2$,我们打这个标记后,$add_a$ 就变成了 $+3$,而 $add_b$ 仍为 $+5$。因为他曾经加的最多的是 $5$,所以他的子区间都在某一个时刻是 $+5$ 的,这个时刻才可能是历史最大的位置。

这四个标记理解了之后,后面就比较容易理解了。由于标记下传比较复杂,我们通常用 pushdown 和 update 两个函数来维护。对于 pushdown,我们先判断最大值在哪边,如果在左边,那么左边的最大值受到的影响是当前区间对最大值的影响,而右边的最大值受到的影响仅是当前区间对次大值的影响(当然可能左右两边都有最大值,那么两边的最大值受到的影响均是当前区间对最大值的影响)。如果最大值在右边,反过来就可以了。

inline void pushdown(int u){
    int maxx=max(tree[ls].maxa,tree[rs].maxa);
    if(tree[ls].maxa==maxx)//最大值在左边?
        update(ls,tree[u].add_a,tree[u].add_b,tree[u].add_a1,tree[u].add_b1);//是,给最大值的标记给左边。
    else update(ls,tree[u].add_a1,tree[u].add_b1,tree[u].add_a1,tree[u].add_b1);//不是,左边的最大值获得的标记也应该是该区间的非最大值获得的标记。
    if(tree[rs].maxa==maxx)//最大值在右边?
        update(rs,tree[u].add_a,tree[u].add_b,tree[u].add_a1,tree[u].add_b1);//类似上边
    else update(rs,tree[u].add_a1,tree[u].add_b1,tree[u].add_a1,tree[u].add_b1);
    tree[u].add_a=tree[u].add_b=tree[u].add_a1=tree[u].add_b1=0;
}

这时候可能有疑问,我们仅仅判断了当前的最大值在哪里,怎么 $add_b$ 也跟着走了,这个东西不是维护历史的吗?如果你理解了上面的,那么该标记表示的是当前最大值历史上加的最多的那次,所以这个标记实质上还是打给当前最大值的,当然是与 $add_a$ 是一起的。

那么 update 呢?我们先看代码。

inline void update(int u,int k1,int k2,int k3,int k4){
    tree[u].sum+=k1*tree[u].cnt+k3*(tree[u].len-tree[u].cnt);//更新区间和
    tree[u].maxb=max(tree[u].maxb,tree[u].maxa+k2);//更新历史最大值
    tree[u].add_b=max(tree[u].add_b,tree[u].add_a+k2);
    tree[u].add_b1=max(tree[u].add_b1,tree[u].add_a1+k4);
    tree[u].maxa+=k1;//更新当前最大值
    tree[u].add_a+=k1;
    tree[u].add_a1+=k3;
    if(tree[u].sec!=-1e18) tree[u].sec+=k3;//更新次大值
}

$k1,k2$ 分别表示:当前/历史最大值要加的数。

$k3,k4$ 分别表示:当前/历史非最大值要加的数。

然后我们考虑如果更新各个标记。区间和很好理解,就是最大值加的量加上非最大值加的量。对于 $add_b$,它是历史上加的最多的那次,所以我们判断一下当前的 $add_a$ 加上 $k1$ 是否超过了历史,若是,更新历史(如果不理解,请先理解好 $add_b$ 的含义)。$add_b1$ 类似。对于历史最大值,我们看看当前最大值加上历史上加的最多的时候的加标记的值(即 $k2$)是否超过了历史最大值,若超过了,更新历史最大值。其他信息的维护就很容易了,自行看上方代码即可。

区间取最值(update_min)

首先,如果该区间的最大值比要取的 $v$(下面代码中为 $k$)小,显然直接返回即可。若 $sec \le v \le maxa$,那么就更新最大值。怎么更新呢?我们将 $maxa$ 变成了 $k$,也就是对于最大值,减去了 $maxa-k$,换言之,加了 $k-maxa$,用 update 函数操作即可。若不属于这两种情况中的任意一种,那就需要继续遍历到子区间。

inline void update_min(int u,int l,int r,int L,int R,int k){
    if(l>R||r<L||k>=tree[u].maxa) return;//直接返回
    if(l>=L&&r<=R&&k>=tree[u].sec){//更新再返回
        update(u,k-tree[u].maxa,k-tree[u].maxa,0,0);
        return;
    }
    pushdown(u);//继续遍历
    int mid=(l+r)>>1;
    update_min(ls,l,mid,L,R,k);update_min(rs,mid+1,r,L,R,k);
    pushup(u);
}

这个函数不难理解,不多做赘述。其他函数均与普通的线段树类似,这里也不谈了。

完整代码

#include<bits/stdc++.h>
#define int long long
#define ls u<<1
#define rs u<<1|1
#define getchar()(p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
using namespace std;
char buf[1<<21],*p1=buf,*p2=buf;
template <typename T>
inline void read(T& r) {
    r=0;bool w=0; char ch=getchar();
    while(ch<'0'||ch>'9') w=ch=='-'?1:0,ch=getchar();
    while(ch>='0'&&ch<='9') r=r*10+(ch^48), ch=getchar();
    r=w?-r:r;
}
const int N=5e5+5;
struct Segment_Tree{
    int add_a,add_a1,add_b,add_b1;
    int maxa,sec,maxb,sum,cnt,len;
}tree[N<<3];
int n,m,a[N];
inline void pushup(int u){
    tree[u].maxa=max(tree[ls].maxa,tree[rs].maxa);
    tree[u].maxb=max(tree[ls].maxb,tree[rs].maxb);
    tree[u].sum=tree[ls].sum+tree[rs].sum;
    if(tree[ls].maxa==tree[rs].maxa){
        tree[u].sec=max(tree[ls].sec,tree[rs].sec);
        tree[u].cnt=tree[ls].cnt+tree[rs].cnt;
    }
    else if(tree[ls].maxa>tree[rs].maxa){
        tree[u].sec=max(tree[ls].sec,tree[rs].maxa);
        tree[u].cnt=tree[ls].cnt;
    }
    else{
        tree[u].sec=max(tree[ls].maxa,tree[rs].sec);
        tree[u].cnt=tree[rs].cnt;
    }
}
inline void update(int u,int k1,int k2,int k3,int k4){
    tree[u].sum+=k1*tree[u].cnt+k3*(tree[u].len-tree[u].cnt);
    tree[u].maxb=max(tree[u].maxb,tree[u].maxa+k2);
    tree[u].add_b=max(tree[u].add_b,tree[u].add_a+k2);
    tree[u].add_b1=max(tree[u].add_b1,tree[u].add_a1+k4);
    tree[u].maxa+=k1;
    tree[u].add_a+=k1;
    tree[u].add_a1+=k3;
    if(tree[u].sec!=-1e18) tree[u].sec+=k3;
}
inline void pushdown(int u){
    int maxx=max(tree[ls].maxa,tree[rs].maxa);
    if(tree[ls].maxa==maxx)
        update(ls,tree[u].add_a,tree[u].add_b,tree[u].add_a1,tree[u].add_b1);
    else update(ls,tree[u].add_a1,tree[u].add_b1,tree[u].add_a1,tree[u].add_b1);
    if(tree[rs].maxa==maxx)
        update(rs,tree[u].add_a,tree[u].add_b,tree[u].add_a1,tree[u].add_b1);
    else update(rs,tree[u].add_a1,tree[u].add_b1,tree[u].add_a1,tree[u].add_b1);
    tree[u].add_a=tree[u].add_b=tree[u].add_a1=tree[u].add_b1=0;
}
inline void build(int u,int l,int r){
    tree[u].len=r-l+1;
    if(l==r){
        tree[u].maxa=tree[u].maxb=tree[u].sum=a[l];
        tree[u].sec=-1e18;tree[u].cnt=1;
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);build(rs,mid+1,r);
    pushup(u);
}
inline void update_add(int u,int l,int r,int L,int R,int k){
    if(l>R||r<L) return;
    if(l>=L&&r<=R){
        update(u,k,k,k,k);
        return;
    }
    pushdown(u);
    int mid=(l+r)>>1;
    update_add(ls,l,mid,L,R,k);update_add(rs,mid+1,r,L,R,k);
    pushup(u);
}
inline void update_min(int u,int l,int r,int L,int R,int k){
    if(l>R||r<L||k>=tree[u].maxa) return;
    if(l>=L&&r<=R&&k>=tree[u].sec){
        update(u,k-tree[u].maxa,k-tree[u].maxa,0,0);
        return;
    }
    pushdown(u);
    int mid=(l+r)>>1;
    update_min(ls,l,mid,L,R,k);update_min(rs,mid+1,r,L,R,k);
    pushup(u);
}
inline int query_sum(int u,int l,int r,int L,int R){
    if(l>R||r<L) return 0;
    if(l>=L&&r<=R) return tree[u].sum;
    int mid=(l+r)>>1;
    pushdown(u);
    return query_sum(ls,l,mid,L,R)+query_sum(rs,mid+1,r,L,R);
}
inline int query_maxa(int u,int l,int r,int L,int R){
    if(l>R||r<L) return -1e18;
    if(l>=L&&r<=R) return tree[u].maxa;
    int mid=(l+r)>>1;
    pushdown(u);
    return max(query_maxa(ls,l,mid,L,R),query_maxa(rs,mid+1,r,L,R));
}
inline int query_maxb(int u,int l,int r,int L,int R){
    if(l>R||r<L) return -1e18;
    if(l>=L&&r<=R) return tree[u].maxb;
    int mid=(l+r)>>1;
    pushdown(u);
    return max(query_maxb(ls,l,mid,L,R),query_maxb(rs,mid+1,r,L,R));
}
signed main(){
    read(n);read(m);
    for(int i=1;i<=n;i++) read(a[i]);
    build(1,1,n);
    for(int i=1,opt,l,r,k;i<=m;i++){
        read(opt);read(l);read(r);
        if(opt==1){
            read(k);
            update_add(1,1,n,l,r,k);
        }
        else if(opt==2){
            read(k);
            update_min(1,1,n,l,r,k);
        }
        else if(opt==3){
            printf("%lld\n",query_sum(1,1,n,l,r));
        }
        else if(opt==4){
            printf("%lld\n",query_maxa(1,1,n,l,r));
        }
        else{
            printf("%lld\n",query_maxb(1,1,n,l,r));
        }
    }
    return 0;
}

月流华 岁遗沙 万古吴钩出玉匣