题意
维护一棵树,支持区间求和,区间最大值,区间最小值,区间取反,单点修改。
解析
这是一道很好的树剖练手题,思维难度不高,代码存在一些细节,本题解主要来说明这一部分,也用来警醒自己做题目时需要注意细节,尽可能避免调代码调半天的情况。
边权转点权
在树剖的两边 dfs 过程中,第一遍可以将边权赋给边端点中深度较大的那一个,即儿子节点。我们用儿子上的点权来表示了边权。
两遍 dfs 代码如下:
inline void dfs1(int u,int f){
dep[u]=dep[f]+1;
fa[u]=f;siz[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f) continue;
val[v]=e[i].val;//将边权赋给儿子
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
}
inline void dfs2(int u,int Top){
id[u]=++tot;
a[tot]=val[u];//获得边权后,儿子有了点权,将其赋到数列中
top[u]=Top;
if(!son[u]) return;
dfs2(son[u],Top);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
如何区间取反
在区间取反时,我们面对 $tag$ 需要 pushdown 的时候需要保证所记录的和/最大值/最小值都正确变化。
显然的,区间和直接取反即可,而最大值/最小值,容易发现他们取反之后就分别变成了新的最小值/最大值,因此取反后交换即可。
pushdown 部分代码如下:
inline void pushdown(int u){
if(!tree[u].tag) return;
tree[ls].tag^=1;tree[rs].tag^=1;
tree[ls].sum=-tree[ls].sum;tree[rs].sum=-tree[rs].sum;
tree[ls].mx=-tree[ls].mx;tree[rs].mx=-tree[rs].mx;
tree[ls].mn=-tree[ls].mn;tree[rs].mn=-tree[rs].mn;
swap(tree[ls].mx,tree[ls].mn);swap(tree[rs].mx,tree[rs].mn);
tree[u].tag=0;
}
切记不能直接互相赋值,除非你乐意维护一个中间量。
如何树上修改/查询
这里以区间求和为例。
inline void SumTree(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=querysum(1,1,n,id[top[x]],id[x]);//tag 1
x=fa[top[x]];
}
if(x!=y){
if(dep[x]>dep[y]) swap(x,y);
ans+=querysum(1,1,n,id[x]+1,id[y]); //tag 2
}
printf("%lld\n",ans);
}
注意用注释打上标记的这两行,$id_{top_x}$ 没有加 $1$,而 $id_x$ 加 $1$,原因是:当你在树上向上跳时,由于你要跳到链顶的上一个点,因此链顶所表示的那个边权是需要计算的,而第二个 $tag$ 那一行,需要求的边是从 $x$ 至 $y$ ($y$ 较深),因此 $x$ 上方那条边,即 $x$ 代表的边权是不计算的,需要加 $1$。
修改输入的第 $i$ 条边
我们需要在输入时记录每一条边的的 $u,v$ 两点(如果你写链式前向星之类的话,实际上已经记录好了),然后由于我们边权赋给了边的儿子,因此我们修改时应该是找出 $u,v$ 中深度较深的那一点,修改边权转点权之后这一点的点权。
修改单条边的代码如下:
if(ch[0]=='C'){
scanf("%lld%lld",&x,&w);
int tmp;
u=e[x<<1].from,v=e[x<<1].to;//x<<1 是因为使用链式前向星,正反都建了边
if(dep[u]>dep[v]) tmp=u;//修改深的点
else tmp=v;
updateTree(tmp,w);
}
处理好这几个问题之后,剩下的就是树剖的板子了,不再赘述。
代码
#include<bits/stdc++.h>
#define int long long
#define ls u<<1
#define rs u<<1|1
using namespace std;
const int N=2e5+5;
const int inf=2e9;
struct SegmentTree{
int len,tag,sum,mx,mn;
}tree[N<<4];
struct edge{
int from,to,nxt,val;
}e[N<<1];
int cnt,head[N],n,m,val[N];
char ch[20];
inline void add(int u,int v,int w){
e[++cnt].to=v;
e[cnt].from=u;
e[cnt].nxt=head[u];
e[cnt].val=w;
head[u]=cnt;
}
int dep[N],fa[N],son[N],siz[N],top[N],id[N],a[N];
inline void dfs1(int u,int f){
dep[u]=dep[f]+1;
fa[u]=f;siz[u]=1;
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==f) continue;
val[v]=e[i].val;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
}
int tot;
inline void dfs2(int u,int Top){
id[u]=++tot;
a[tot]=val[u];
top[u]=Top;
if(!son[u]) return;
dfs2(son[u],Top);
for(int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
//以下为线段树
inline void pushup(int u){
tree[u].sum=tree[ls].sum+tree[rs].sum;
tree[u].mx=max(tree[ls].mx,tree[rs].mx);
tree[u].mn=min(tree[ls].mn,tree[rs].mn);
}
inline void pushdown(int u){
if(!tree[u].tag) return;
tree[ls].tag^=1;tree[rs].tag^=1;
tree[ls].sum=-tree[ls].sum;tree[rs].sum=-tree[rs].sum;
tree[ls].mx=-tree[ls].mx;tree[rs].mx=-tree[rs].mx;
tree[ls].mn=-tree[ls].mn;tree[rs].mn=-tree[rs].mn;
swap(tree[ls].mx,tree[ls].mn);swap(tree[rs].mx,tree[rs].mn);
tree[u].tag=0;
}
inline void build(int u,int l,int r){
if(l==r){
tree[u].sum=tree[u].mx=tree[u].mn=a[l];
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);build(rs,mid+1,r);
pushup(u);
}
inline void update(int u,int l,int r,int k,int w){
if(k<l||k>r) return;
if(l==r){
tree[u].sum=tree[u].mx=tree[u].mn=w;
return;
}
pushdown(u);
int mid=(l+r)>>1;
update(ls,l,mid,k,w);update(rs,mid+1,r,k,w);
pushup(u);
}
inline void makeop(int u,int l,int r,int L,int R){
if(r<L||l>R) return;
if(l>=L&&r<=R){
tree[u].tag^=1;
tree[u].sum=-tree[u].sum;
tree[u].mx=-tree[u].mx;
tree[u].mn=-tree[u].mn;
swap(tree[u].mx,tree[u].mn);
return;
}
pushdown(u);
int mid=(l+r)>>1;
makeop(ls,l,mid,L,R);makeop(rs,mid+1,r,L,R);
pushup(u);
}
inline int querysum(int u,int l,int r,int L,int R){
if(r<L||l>R) return 0;
if(l>=L&&r<=R) return tree[u].sum;
pushdown(u);
int mid=(l+r)>>1;
return querysum(ls,l,mid,L,R)+querysum(rs,mid+1,r,L,R);
}
inline int querymx(int u,int l,int r,int L,int R){
if(r<L||l>R) return -inf;
if(l>=L&&r<=R) return tree[u].mx;
pushdown(u);
int mid=(l+r)>>1;
return max(querymx(ls,l,mid,L,R),querymx(rs,mid+1,r,L,R));
}
inline int querymn(int u,int l,int r,int L,int R){
if(r<L||l>R) return inf;
if(l>=L&&r<=R) return tree[u].mn;
pushdown(u);
int mid=(l+r)>>1;
return min(querymn(ls,l,mid,L,R),querymn(rs,mid+1,r,L,R));
}
//以上为线段树
inline void updateTree(int x,int w){
update(1,1,n,id[x],w);
}
inline void NTree(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
makeop(1,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(x!=y){
if(dep[x]>dep[y]) swap(x,y);
makeop(1,1,n,id[x]+1,id[y]);
}
}
inline void SumTree(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=querysum(1,1,n,id[top[x]],id[x]);
x=fa[top[x]];
}
if(x!=y){
if(dep[x]>dep[y]) swap(x,y);
ans+=querysum(1,1,n,id[x]+1,id[y]);
}
printf("%lld\n",ans);
}
inline void MaxTree(int x,int y){
int ans=-inf;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=max(querymx(1,1,n,id[top[x]],id[x]),ans);
x=fa[top[x]];
}
if(x!=y){
if(dep[x]>dep[y]) swap(x,y);
ans=max(querymx(1,1,n,id[x]+1,id[y]),ans);
}
printf("%lld\n",ans);
}
inline void MinTree(int x,int y){
int ans=inf;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=min(querymn(1,1,n,id[top[x]],id[x]),ans);
x=fa[top[x]];
}
if(x!=y){
if(dep[x]>dep[y]) swap(x,y);
ans=min(querymn(1,1,n,id[x]+1,id[y]),ans);
}
printf("%lld\n",ans);
}
signed main(){
scanf("%lld",&n);
for(int i=1,u,v,w;i<n;i++){
scanf("%lld%lld%lld",&u,&v,&w);
add(u+1,v+1,w);add(v+1,u+1,w);
}
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
scanf("%lld",&m);
for(int i=1,w,x,u,v;i<=m;i++){
scanf("%s",ch);
if(ch[0]=='C'){
scanf("%lld%lld",&x,&w);
int tmp;
u=e[x<<1].from,v=e[x<<1].to;
if(dep[u]>dep[v]) tmp=u;
else tmp=v;
updateTree(tmp,w);
}
else if(ch[0]=='N'){
scanf("%lld%lld",&u,&v);
NTree(u+1,v+1);
}
else if(ch[0]=='S'){
scanf("%lld%lld",&u,&v);
SumTree(u+1,v+1);
}
else if(ch[1]=='A'){
scanf("%lld%lld",&u,&v);
MaxTree(u+1,v+1);
}
else{
scanf("%lld%lld",&u,&v);
MinTree(u+1,v+1);
}
}
return 0;
}