[关闭]
@gzr666 2019-07-24T04:03:01.000000Z 字数 6596 阅读 759

【算法•日更•第二十二期】数据结构:线段树

▎前置知识

树&分治

线段树是一种高级数据结构,必须先学会树([戳这里了解][1])和分治([戳这里了解][2])。

▎什么是线段树?

☞『前言』

想必分治学的还不错的话就一定会知道二分查找这种东西。
那么二分是什么样的?在一个数列中不断的折半查找,显然,这是一维的。
如果改成了二维,又会是怎样的?这便是线段树。

☞『定义』

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,实际应用时一般还要开4N的数组以免越界,因此有时需要离散化让空间压缩。

其实线段树就是把二分查找的过程记录下来的树。
每一个节点存储的都是一段区间,和二叉搜索树一样,都是用来查找的,只不过能维护的东西更多,如图所示,线段树就是这样的:

☞『实现操作』

  • 单点加
  • 区间加
  • 区间乘
  • 区间修改
  • 区间询问
  • ……

▎如何实现简单的线段树?

1.建树

不论实现什么操作,我们肯定需要先建起来这棵树。
但是指针、邻接表看起来有点令人头疼,那么我们就采用数组模拟树。
还记得吗?这种树看起来与众不同,首先它是一棵二叉树,其次,它又是一棵完全二叉树。
那么完全二叉树有什么性质?假设当前节点编号为i,那么左孩子编号一定为2*i,右孩子编号一定为2*i+1。
那么我们就递归建树就好了。
代码如下:

  1. void build(int o,int l,int r)
  2. {
  3. if(l==r) {sumv[o]=a[l];return;}
  4. int mid=(l+r)>>1;
  5. build(o<<1,l,mid);
  6. build(o<<1|1,mid+1,r);
  7. pushup(o);
  8. }

那么pushup操作是什么呢?就是合并下面的两个区间,记录这两个区间的和。
pushup的代码很简单:void pushup(int o) {sumv[o]=sumv[o<<1]+sumv[o<<1|1];}

2.单点加

如果只是用于询问区间和的话,那么使用ST表足矣。不同的是,线段树支持修改数字。
那么我们如何修改一个数字的值呢?当然是先找到这个数字,加上相应的值,上面父节点就只要沿路pushup就可以修改了。
代码如下:

  1. void change(int o,int l,int r,int k,int v)
  2. {
  3. if(l==r) {sumv[o]+=v;return;}
  4. int mid=(l+r)>>1;
  5. if(k<=mid) change(o<<1,l,mid,k,v);
  6. else change(o<<1|1,mid+1,r,k,v);
  7. pushup(o);
  8. }

3.区间加

有时题目会相当变态,让你实现区间加,那么我们该怎样解决?难道是暴力单点加
显然不行,我们应该先找到这个区间(也可能是多个区间组成),然后修改了区间,但是它的子区间还没有修改啊,所以我们就需要下放,同时打好标记,这样就会准确的知道哪些子区间需要修改。
那么下放我们又定义了一个新的函数pushdown,代码是这样的:

  1. void pushdown(int o,int l,int r)
  2. {
  3. if(add[o]==0) return;
  4. add[o<<1]+=add[o];
  5. add[o<<1|1]+=add[o];
  6. int mid=(l+r)>>1;
  7. sumv[o<<1]+=add[o]*(mid-l+1);
  8. sumv[o<<1|1]=add[o]*(r-mid);
  9. add[o]=0;
  10. }

其中add是标记,sumv是树存的和,注意:一定要下方后清除该标记。
还有一个做标记的函数puttagvoid puttag(int o,int l,int r,int v) {add[o]+=v;sumv[o]+=(r-l+1)*v;}
整个操作的代码是这样的:

  1. void optadd(int o,int l,int r,int ql,int qr,int v)
  2. {
  3. if(ql<=l&&r<=qr) {puttag(o,l,r,v);return;}
  4. int mid=(l+r)>>1;
  5. pushdown(o,l,r);
  6. if(ql<=mid) optadd(o<<1,l,mid,ql,qr,v);
  7. if(qr>mid) optadd(o<<1|1,mid+1,r,ql,qr,v);
  8. pushup(o);
  9. }

4.询问区间和

询问呢区间就一定要先找到区间,当然询问区间也有可能是多个树的区间构成的,所以首先来思考询问区间和当前区间可能是什么样的。
显然,要么是整个当前区间在询问区间中,要么只有一部分是。图解一下:

当整个区间都被包含时,我们直接返回和就可以了。
只有一部分时就要继续分类讨论:当ql<=mid时,一部分肯定在左区间;当qr>mid时,那么右区间肯定有一部分。
所以整个代码就很清晰了:

  1. int main()
  2. {
  3. cin>>n>>m;
  4. for(int i=1;i<=n;i++)
  5. cin>>a[i];
  6. build(1,1,n);
  7. for(int i=1;i<=m;i++)
  8. {
  9. int pos,x,y,z;
  10. cin>>pos>>x>>y;
  11. if(pos==1) change(1,1,n,x,y);
  12. else if(pos==2)
  13. {
  14. cin>>z;
  15. optadd(1,1,n,x,y,z);
  16. }
  17. else cout<<query(1,1,n,x,y);
  18. }
  19. return 0;
  20. }

亲测能用:

代码整合+小白版代码

代码如下:

  1. #include<iostream>
  2. using namespace std;
  3. int sumv[10000],add[10000],n,m,a[10000];
  4. void pushup(int o) {sumv[o]=sumv[o<<1]+sumv[o<<1|1];}
  5. void puttag(int o,int l,int r,int v) {add[o]+=v;sumv[o]+=(r-l+1)*v;}
  6. void pushdown(int o,int l,int r)
  7. {
  8. if(add[o]==0) return;
  9. add[o<<1]+=add[o];
  10. add[o<<1|1]+=add[o];
  11. int mid=(l+r)>>1;
  12. sumv[o<<1]+=add[o]*(mid-l+1);
  13. sumv[o<<1|1]=add[o]*(r-mid);
  14. add[o]=0;
  15. }
  16. void build(int o,int l,int r)
  17. {
  18. if(l==r) {sumv[o]=a[l];return;}
  19. int mid=(l+r)>>1;
  20. build(o<<1,l,mid);
  21. build(o<<1|1,mid+1,r);
  22. pushup(o);
  23. }
  24. void change(int o,int l,int r,int k,int v)
  25. {
  26. if(l==r) {sumv[o]+=v;return;}
  27. int mid=(l+r)>>1;
  28. if(k<=mid) change(o<<1,l,mid,k,v);
  29. else change(o<<1|1,mid+1,r,k,v);
  30. pushup(o);
  31. }
  32. void optadd(int o,int l,int r,int ql,int qr,int v)
  33. {
  34. if(ql<=l&&r<=qr) {puttag(o,l,r,v);return;}
  35. int mid=(l+r)>>1;
  36. pushdown(o,l,r);
  37. if(ql<=mid) optadd(o<<1,l,mid,ql,qr,v);
  38. if(qr>mid) optadd(o<<1|1,mid+1,r,ql,qr,v);
  39. pushup(o);
  40. }
  41. int query(int o,int l,int r,int ql,int qr)
  42. {
  43. if(ql<=l&&r<=qr) return sumv[o];
  44. int ans=0,mid=(l+r)>>1;
  45. pushdown(o,l,r);
  46. if(ql<=mid) ans+=query(o<<1,l,mid,ql,qr);
  47. if(qr>mid) ans+=query(o<<1|1,mid+1,r,ql,qr);
  48. return ans;
  49. }
  50. int main()
  51. {
  52. cin>>n>>m;
  53. for(int i=1;i<=n;i++)
  54. cin>>a[i];
  55. build(1,1,n);
  56. for(int i=1;i<=m;i++)
  57. {
  58. int pos,x,y,z;
  59. cin>>pos>>x>>y;
  60. if(pos==1) change(1,1,n,x,y);
  61. else if(pos==2)
  62. {
  63. cin>>z;
  64. optadd(1,1,n,x,y,z);
  65. }
  66. else cout<<query(1,1,n,x,y);
  67. }
  68. return 0;
  69. }

这里面用到了很多位运算,如果你是小白的话一定要尽快学位运算(快的可不止一星半点),那么发一个不带位运算的代码:

  1. #include<iostream>
  2. using namespace std;
  3. int sumv[10000],add[10000],n,m,a[10000];
  4. void pushup(int o) {sumv[o]=sumv[o*2]+sumv[o*2+1];}
  5. void puttag(int o,int l,int r,int v) {add[o]+=v;sumv[o]+=(r-l+1)*v;}
  6. void pushdown(int o,int l,int r)
  7. {
  8. if(add[o]==0) return;
  9. add[o*2]+=add[o];
  10. add[o*2+1]+=add[o];
  11. int mid=(l+r)>>1;
  12. sumv[o*1]+=add[o]*(mid-l+1);
  13. sumv[o*1+1]=add[o]*(r-mid);
  14. add[o]=0;
  15. }
  16. void build(int o,int l,int r)
  17. {
  18. if(l==r) {sumv[o]=a[l];return;}
  19. int mid=(l+r)/2;
  20. build(o*2,l,mid);
  21. build(o*2+1,mid+1,r);
  22. pushup(o);
  23. }
  24. void change(int o,int l,int r,int k,int v)
  25. {
  26. if(l==r) {sumv[o]+=v;return;}
  27. int mid=(l+r)/2;
  28. if(k<=mid) change(o*2,l,mid,k,v);
  29. else change(o*2+1,mid+1,r,k,v);
  30. pushup(o);
  31. }
  32. void optadd(int o,int l,int r,int ql,int qr,int v)
  33. {
  34. if(ql<=l&&r<=qr) {puttag(o,l,r,v);return;}
  35. int mid=(l+r)/2;
  36. pushdown(o,l,r);
  37. if(ql<=mid) optadd(o*2,l,mid,ql,qr,v);
  38. if(qr>mid) optadd(o*2+1,mid+1,r,ql,qr,v);
  39. pushup(o);
  40. }
  41. int query(int o,int l,int r,int ql,int qr)
  42. {
  43. if(ql<=l&&r<=qr) return sumv[o];
  44. int ans=0,mid=(l+r)/2;
  45. pushdown(o,l,r);
  46. if(ql<=mid) ans+=query(o*2,l,mid,ql,qr);
  47. if(qr>mid) ans+=query(o*2+1,mid+1,r,ql,qr);
  48. return ans;
  49. }
  50. int main()
  51. {
  52. cin>>n>>m;
  53. for(int i=1;i<=n;i++)
  54. cin>>a[i];
  55. build(1,1,n);
  56. for(int i=1;i<=m;i++)
  57. {
  58. int pos,x,y,z;
  59. cin>>pos>>x>>y;
  60. if(pos==1) change(1,1,n,x,y);
  61. else if(pos==2)
  62. {
  63. cin>>z;
  64. optadd(1,1,n,x,y,z);
  65. }
  66. else cout<<query(1,1,n,x,y);
  67. }
  68. return 0;
  69. }

▎实战演练

废话不多说,直接上题:


P1047 校门外的树

题目描述

某校大门外长度为L的马路上有一排树,每两棵相邻的树之间的间隔都是1米。我们可以把马路看成一个数轴,马路的一端在数轴0的位置,另一端在L的位置;数轴上的每个整数点,即0,1,2,…,L,都种有一棵树。

由于马路上有一些区域要用来建地铁。这些区域用它们在数轴上的起始点和终止点表示。已知任一区域的起始点和终止点的坐标都是整数,区域之间可能有重合的部分。现在要把这些区域中的树(包括区域端点处的两棵树)移走。你的任务是计算将这些树都移走后,马路上还有多少棵树。

输入输出格式

输入格式:

第一行有2个整数L(1≤L≤10000)和 M(1≤M≤100),L代表马路的长度,M代表区域的数目,L和M之间用一个空格隔开。
接下来的M行每行包含2个不同的整数,用一个空格隔开,表示一个区域的起始点和终止点的坐标。

输出格式:

1个整数,表示马路上剩余的树的数目。

输入输出样例

输入样例#1:
500 3
150 300
100 200
470 471
输出样例#1:
298

说明

NOIP2005普及组第二题

对于的数据,区域之间没有重合的部分;

对于其它的数据,区域之间有重合的情况。

这道题简直就是模板题,只要把区间加换成区间修改就可以了。

代码如下:

  1. #include<iostream>
  2. using namespace std;
  3. int sumv[100000],add[100000],l,m;
  4. void pushup(int o) {sumv[o]=sumv[o<<1]+sumv[o<<1|1];}
  5. void puttag(int o,int l,int r) {add[o]=1;sumv[o]=0;}
  6. void pushdown(int o,int l,int r)
  7. {
  8. if(add[o]==0) return;
  9. add[o<<1]=1;
  10. add[o<<1|1]=1;
  11. sumv[o<<1]=0;
  12. sumv[o<<1|1]=0;
  13. add[o]=0;
  14. }
  15. void build(int o,int l,int r)
  16. {
  17. add[o]=0;
  18. if(l==r) {sumv[o]=1;return;}
  19. int mid=(l+r)>>1;
  20. build(o<<1,l,mid);
  21. build(o<<1|1,mid+1,r);
  22. pushup(o);
  23. }
  24. void optadd(int o,int l,int r,int ql,int qr)
  25. {
  26. if(ql<=l&&r<=qr) {puttag(o,l,r);return;}
  27. int mid=(l+r)>>1;
  28. pushdown(o,l,r);
  29. if(ql<=mid) optadd(o<<1,l,mid,ql,qr);
  30. if(qr>mid) optadd(o<<1|1,mid+1,r,ql,qr);
  31. pushup(o);
  32. }
  33. //int query(int o,int l,int r,int ql,int qr)
  34. //{
  35. // if(ql<=l&&r<=qr) return sumv[o];
  36. // int mid=(l+r)>>1;
  37. // pushdown(o,l,r);
  38. // if(ql<=mid) int t1=optadd(sumv[o<<1],l,mid,ql,qr);
  39. // if(r>qr) int t2=optadd(sumv[o<<1|1],mid+1,r,ql,qr);
  40. // return t1+t2;
  41. //}
  42. int main()
  43. {
  44. cin>>l>>m;
  45. build(1,1,l+1);
  46. for(int i=1;i<=m;i++)
  47. {
  48. int x,y;
  49. cin>>x>>y;
  50. optadd(1,1,l+1,x+1,y+1);
  51. }
  52. cout<<sumv[1];
  53. return 0;
  54. }

偷偷的告诉你,这道题可以暴力做的:

  1. #include<iostream>
  2. using namespace std;
  3. int main()
  4. {
  5. int l,m,vis[100000],ans=0;
  6. cin>>l>>m;
  7. for(int i=1;i<=m;i++)
  8. {
  9. int x,y;
  10. cin>>x>>y;
  11. for(int i=x;i<=y;i++)
  12. vis[i]=1;
  13. }
  14. for(int i=0;i<=l;i++)
  15. if(vis[i]==0) ans++;
  16. cout<<ans;
  17. return 0;
  18. }
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注