主席树入门

可持久化线段树

Posted by mental2008 on September 30, 2018

主席树思想是每个位置都维护一个线段树,线段树的节点是值的范围,然后第 $i$ 个线段树中某个区间 $[x, y]$ 维护的是,$a_1$-$a_i$ 中数字在 $[x, y]$ 范围内的个数。这里利用到了前缀和的思想。

常见主席树能处理的操作有:查询静态区间第 $k$ 小。

查询区间第 $k$ 小模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include<bits/stdc++.h>

using namespace std;

#define mem(a,i) memset(a,i,sizeof(a))
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define per(i,a,b) for(int i=a;i>=b;--i)
typedef long long ll;
const int maxn=2e5+5;
int a[maxn];
int b[maxn];
int rt[maxn];
int sum[maxn*32];
int L[maxn*32];
int R[maxn*32];
int cnt;
int build(int l,int r) {
    int root=++cnt;
    sum[root]=0;
    if(l==r) return root;
    int mid=(l+r)>>1;
    L[root]=build(l,mid);
    R[root]=build(mid+1,r);
    return root;
}
int update(int pre,int l,int r,int x) {
    int root=++cnt;
    L[root]=L[pre];
    R[root]=R[pre];
    sum[root]=sum[pre]+1;
    if(l==r) return root;
    int mid=(l+r)>>1;
    if(x<=mid) L[root]=update(L[pre],l,mid,x);
    else R[root]=update(R[pre],mid+1,r,x);
    return root;
}
int query(int u,int v,int l,int r,int k) {
    if(l==r) return l;
    int x=sum[L[v]]-sum[L[u]];
    int mid=(l+r)>>1;
    if(x>=k) return query(L[u],L[v],l,mid,k);
    else return query(R[u],R[v],mid+1,r,k-x);
}

int main() {
    int n,q;
    scanf("%d%d",&n,&q);
    rep(i,1,n) {
        scanf("%d",&a[i]);
        b[i]=a[i];
    }
    sort(b+1,b+n+1);
    int m=unique(b+1,b+n+1)-b-1;
    cnt=0;
    rt[0]=build(1,m);
    rep(i,1,n) {
        int t=lower_bound(b+1,b+m+1,a[i])-b;
        rt[i]=update(rt[i-1],1,m,t);
    }
    while(q--) {
        int x,y,z;
        scanf("%d%d%d",&x,&y,&z);
        int ans=query(rt[x-1],rt[y],1,m,z);
        printf("%d\n",b[ans]);
    }
    return 0;
}

当然,主席树并不是只能查询区间第 $k$ 大的值。作为一颗可持久化的线段树,它其实记录了所有的历史信息,所以可以回退到任一的历史版本。

hdu 4348

  • 操作C,对 [$l$, $r$] 区间加上 $d$,同时时间戳后移
  • 操作Q,查询 $t$ 时刻,[$l$, $r$] 的区间和
  • 操作H,查询当前时刻,[$l$, $r$] 的区间和
  • 操作B,回退到时刻 $t$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include<bits/stdc++.h>

using namespace std;

#define mem(a,i) memset(a,i,sizeof(a))
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define per(i,a,b) for(int i=a;i>=b;--i)
typedef long long ll;
const int maxn=1e5+5;
int n,q;
int rt[maxn];
int L[maxn*40];
int R[maxn*40];
ll sum[maxn*40];
ll add[maxn*40];
ll a[maxn];
int cnt;
int build(int l,int r) {
    int root=++cnt;
    add[root]=0;
    if(l==r) {
        sum[root]=a[l];
        return root;
    }
    int mid=(l+r)>>1;
    L[root]=build(l,mid);
    R[root]=build(mid+1,r);
    sum[root]=sum[L[root]]+sum[R[root]];
    return root;
}
int update(int pre,int l,int r,int ql,int qr,int d) {
    int root=++cnt;
    L[root]=L[pre];
    R[root]=R[pre];
    add[root]=add[pre];
    sum[root]=sum[pre];
    if(l>=ql&&r<=qr) {
        add[root]+=d;
        sum[root]+=(ll)d*(r-l+1);
        return root;
    }
    int mid=(l+r)>>1;
    if(ql<=mid) L[root]=update(L[pre],l,mid,ql,qr,d);
    if(qr>mid) R[root]=update(R[pre],mid+1,r,ql,qr,d);
    sum[root]=sum[L[root]]+sum[R[root]]+add[root]*(r-l+1);
    return root;
}
ll query(int root,int l,int r,int ql,int qr) {
    if(l>=ql&&r<=qr) return sum[root];
    ll res=0;
    if(l<=ql) {
        if(r>=qr) res+=add[root]*(qr-ql+1);
        else res+=add[root]*(r-ql+1);
    }
    else {
        res+=add[root]*(qr-l+1);
    }
    int mid=(l+r)>>1;
    if(ql<=mid) res+=query(L[root],l,mid,ql,qr);
    if(qr>mid) res+=query(R[root],mid+1,r,ql,qr);
    return res;
}

int main() {
    while(~scanf("%d%d",&n,&q)) {
        cnt=0;
        rep(i,1,n) scanf("%lld",&a[i]);
        rt[0]=build(1,n);
        int time=0;
        while(q--) {
            char o[5];
            scanf("%s",o);
            if(o[0]=='C') {
                time++;
                int l,r,d;
                scanf("%d%d%d",&l,&r,&d);
                rt[time]=update(rt[time-1],1,n,l,r,d);
            }
            else if(o[0]=='Q') {
                int l,r;
                scanf("%d%d",&l,&r);
                ll ans=query(rt[time],1,n,l,r);
                printf("%lld\n",ans);
            }
            else if(o[0]=='H') {
                int l,r,t;
                scanf("%d%d%d",&l,&r,&t);
                ll ans=query(rt[t],1,n,l,r);
                printf("%lld\n",ans);
            }
            else {
                int t;
                scanf("%d",&t);
                time=t;
            }
        }
    }
    return 0;
}