树状数组


树状数组详解

什么是树状数组:是一种维护前缀和并且可以支持查询,添加的操作的一种数据结构。

为什么使用树状数组?他的add操作与sum操作时间复杂度都是logn。相比于传统的前缀和数组,创建是一个O(n)的复杂度,但是查询却是O(1)的复杂度,树状数组取了一个中,使得两种操作的耗时比较平均。

lowbit的概念:

二进制理解:就是该数x,转化为二进制数之后,从右向左数,遇到第一个1,这一段数字的大小,就是最低位的1的大小,这么说有点抽象….hh

举个例子:比如说8,他的二进制是1000,所以返回的就是1000,即8.

在比如说12,他的二进制是1100,那么他返回的就是100,就是4.

直白理解:对于奇数来说,lowbit就是1,因为他的二进制数的第一位绝对是1,开始即结束,对于偶数来说,lowbit其实就是他的最大2的次方的因子。

计算方式:

int lowbit(int x)
{
    return x & -x ;
}

这是树状数组的精华所在,所谓树状数组,说白了,其实就是用数组来实现树的某些功能,只不过不用建树而已,这里的lowbit其实就是划分父子节点的关键所在,同时这也是树状数组可以将查询操作与建立操作控制在logn的关键。

小声bb,这其实让我想到了背包问题,好像是多重背包问题的二进制优化,也是通过这样的思想,让本来O(N)的复杂度,变成logn。

树状数组的查询与建立操作

知道了lowbit之后,其实树状数组就是很好理解了。

也就是说,其实树状数组的序号以及序号之间的关系就已经定下来了,需要根据不同的题目来确定value以及序号的实际含义而已。

那我们如何建立一个树状数组呢,哎,可以通过普通的数组来进行建立一个树状数组。

代码如下:

void add(int x , int k)
{
    for(int i = x ; i <= n ; i += lowbit(i))
    {
        tr[i] += k ; 
    }
}

因为类似于前缀和数组,他的每一个子节点的改变,都会使得他的父节点有改变,所以我们一直往上改变。

好了,那如何查询一个节点的前缀和呢?

int sum(int x)
{
    int res = 0 ;
    for(int i = x ; i ; i -= lowbit(i))
    {
        res += tr[i] ; 
    }
    return res ; 
}

好了,以上两项是树状数组的基本操作,知道了这两种操作之后,我们可以玩很多花样,比如后面的例题。

可以在value上做一些花样,比如说存的是次数(其实是因为题目有局限性,因为横坐标与纵坐标是范围一致的),所以我们可以用树状数组的序号来存普通数组的值。

也可以在存的数值上做一些改变,比如想要快速改变一段区间的值,我们甚至可以存目标数组的差分数组,这样同时也可以快速的查询单点值,因为差分数组的区间和就是普通数组的一个单点值吗。

ok,下面是一些例题。

楼兰图腾


在完成了分配任务之后,西部 314 来到了楼兰古城的西部。

相传很久以前这片土地上(比楼兰古城还早)生活着两个部落,一个部落崇拜尖刀(V),一个部落崇拜铁锹(),他们分别用 V 的形状来代表各自部落的图腾。

西部 314 在楼兰古城的下面发现了一幅巨大的壁画,壁画上被标记出了 n个点,经测量发现这 n个点的水平位置和竖直位置是两两不同的。

西部 314认为这幅壁画所包含的信息与这 n 个点的相对位置有关,因此不妨设坐标分别为 (1,y1),(2,y2),…,(n,yn),其中y1∼yn 是 1 到 n 的一个排列。

西部 314 打算研究这幅壁画中包含着多少个图腾。

如果三个点 (i,yi),(j,yj),(k,yk) 满足 1≤i<j<k≤n 且 yi>yj,yj<yk,则称这三个点构成 V 图腾;

如果三个点 (i,yi),(j,yj),(k,yk) 满足 1≤i<j<k≤n 且 yi<yj,yj>yk,则称这三个点构成 图腾;

西部 314 想知道,这 n 个点中两个部落图腾的数目。

因此,你需要编写一个程序来求出 V 的个数和 的个数。

输入格式

第一行一个数 n。

第二行是 n个数,分别代表y1,y2,…,yn。

输出格式

两个数,中间用空格隔开,依次为 V 的个数和 的个数。

数据范围

对于所有数据,n≤200000,且输出答案不会超过 int64。
y1∼yn 是 11 到 n 的一个排列。

输入样例:

5
1 5 3 2 4

输出样例:

3 4

思路解释:炸一看此题,感觉思路比较自然,就是对于每一个节点来说,先将其看成最低点,分别找两边大于他的点的个数,两边一乘,就是此点v的值,然后将每个点的v的值加起来,就是总共的v的值,倒v也是如此。

最白的思路就是直接暴力,我们暴力来看一下。

#include 

using namespace std ;

const int N = 2e5 + 10 ; 

int n ; 

typedef long long LL ; 

int a[N] ,  le[N] ,  ri[N] ; 

LL res = 0 , ans = 0 ; 

int main()
{
    cin >> n ; 

    for(int i = 1 ; i <= n ;  ++ i)
    {
        cin >> a[i] ; 
    }

    for(int i = 2 ; i <= n-1 ; ++ i)
    {
        int l = 0 , r = 0 ;
        int tar = a[i] ;  
        for(int j = 1 ; j < i ; ++ j)
        {
            if(a[j] > tar)
            {
                l ++ ; 
            }
        }


        for(int j = i + 1 ; j <= n ; ++ j)
        {
            if(a[j] > tar)
            {
                r++ ; 
            }
        }

        le[i] = l , ri[i] = r ;     
        res += (LL)l * r ;
    }
 
    memset(le , 0 , sizeof le) ; 
    memset(ri , 0 , sizeof ri) ; 
    for(int i = 2 ; i <= n-1 ; ++ i)
    {
        int l = 0 , r = 0 ;
        int tar = a[i] ;  
        for(int j = 1 ; j < i ; ++ j)
        {
            if(a[j] < tar)
            {
                l ++ ; 
            }
        }


        for(int j = i + 1 ; j <= n ; ++ j)
        {
            if(a[j] < tar)
            {
                r++ ; 
            }
        }

        le[i] = l , ri[i] = r ;     
        ans += (LL)l * r ;
    }    
    cout << res << " " << ans << endl ; 
    return  0 ;

}

不出所料,是n方的复杂度,n最大是20万,肯定会TLE。

我们于是想如何用树状数组来进行存储,我们于是有注意到一点,yn的排列也是1-n,也就是说,这在逻辑上可以看成是一个正方形,什么意思呢?也就是说,我们可以将其与sum函数连接起来,sum可以查询在sum之前所有value的值。

那我们的思路自然就出来了,用树状数组的下标,存的是yn,每一个的value存的是yn出现的次数。在申明一个low和great数组,每一个存放比当前节点高或者是低的节点的数量即可。

贴个代码:

#include 

using namespace std ;

const int N = 2e5 + 10 ; 

typedef long long LL ; 

int a[N] ,tr[N] , low[N] , high[N];

int n ; 

int lowbit(int x)
{
    return x & -x ;
}

void add(int x , int k)
{
    for(int i = x ; i <= n ; i += lowbit(i))
    {
        tr[i] += k ; 
    }
}

int sum(int x)
{
    int res = 0 ;
    
    for(int i = x ; i ; i -= lowbit(i))
    {
        res += tr[i] ; 
    }
    return res ; 
}

int main()
{
    cin >> n ;
    
    for(int i = 1;  i <= n ; ++ i)
    {
        cin >> a[i] ; 
    }
    
    for(int i = 1 ; i <= n ; ++ i)
    {
        int y = a[i] ; 
        low[i] = sum(y-1) ; 
        high[i] = sum(n) - sum(y) ; 
        add(y,1) ; 
    }
    
    memset(tr , 0 , sizeof tr) ; 
    
    LL ans = 0 , res = 0 ; 
    
    for(int i = n ; i >= 1 ; -- i)
    {
        int y  = a[i];
        ans += (LL) low[i] * (sum(y-1)) ; 
        res += (LL) high[i] * (sum(n) - sum(y)) ; 
        add(y,1) ; 
    }
    cout << res << " " << ans << endl ; 
    return  0 ; 
}

一个简单的整数问题


给定长度为 N的数列A,然后输入 M 行操作指令。

第一类指令形如 C l r d,表示把数列中第 l∼r 个数都加 dd。

第二类指令形如 Q x,表示询问数列中第 x 个数的值。

对于每个询问,输出一个整数表示答案。

输入格式

第一行包含两个整数 N 和 M。

第二行包含 N 个整数 A[i]。

接下来 M 行表示 M 条指令,每条指令的格式如题目描述所示。

输出格式

对于每个询问,输出一个整数表示答案。

每个答案占一行。

数据范围

1≤N,M≤105,
|d|≤10000,
|A[i]|≤109

输入样例:

10 5
1 2 3 4 5 6 7 8 9 10
Q 4
Q 1
Q 2
C 1 6 3
Q 2

输出样例:

4
1
2
5

这一题题意很简单,Q是查询下标对应值,C是对数组的一个区间进行修改操作。

我们的正常思路应该是,循环add,但是那样的话,这一题很显然就超时了,而且这根本就不会这样考的。

所以,我们立刻想到了一种O(n)的修改区间值的方法,没错,就是差分,他比较适合于用两个端点的值来进行修改一整段区间的值。

代码如下:

#include 

using namespace std ;

const int N = 1e5 + 10 ; 

int a[N] , tr[N] ; 

int n , m  ; 

int lowbit(int x)
{
    return x & -x  ;
}

void add(int x , int k)
{
    for(int i = x ; i <= n ; i += lowbit(i))
    {
        tr[i] += k ; 
    }
}


int sum(int x)
{
    int res = 0 ; 
    for(int i = x ; i ;  i -= lowbit(i))
    {
        res += tr[i] ; 
    }
    return res ; 
}


int main()
{
    cin >> n >> m ; 
    
    for(int i = 1 ; i <= n ; ++ i)
    {
        cin >> a[i] ;
        add(i , a[i] - a[i-1]) ; 
    }
    
    while(m--)
    {
        string op ; 
        cin >> op ; 
        if(op[0] == 'C')
        {
            int l , r , d ; 
            cin >> l >> r >> d ; 
            add(l , d) , add(r+1 , -d) ; 
        }
        else
        {
            int num ; 
            cin >> num ; 
            int res = sum(num) ; 
            cout << res << endl ; 
        }
    }
    return  0 ;
}

一个简单的整数问题II


给定一个长度为 N 的数列 A,以及 M 条指令,每条指令可能是以下两种之一:

  1. C l r d,表示把 A[l],A[l+1],…,A[r] 都加上 d。
  2. Q l r,表示询问数列中第 l∼r 个数的和。

对于每个询问,输出一个整数表示答案。

输入格式

第一行两个整数 N,M。

第二行 N 个整数A[i]。

接下来 M 行表示 M 条指令,每条指令的格式如题目描述所示。

输出格式

对于每个询问,输出一个整数表示答案。

每个答案占一行。

数据范围

1≤N,M≤105,
|d|≤10000,
|A[i]|≤109

输入样例:

10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4

输出样例:

4
55
9
15

基本思路:与上一题相比的话,这里是在Q的时候,是一个区间查询,而不是单点和,这就比较蛋疼。为什么呢?因为我们知道,我们在C的时候,采用了用树状数组来维护一个差分数组的方法,但是这个区间和的查询本来应该是前缀和的事情,但是这等于是组了三个等级,我们如果在申请一个前缀和数组的话,维护起来比较麻烦,于是我们通过数学推导。

如何可以从差分数组中得出一段区间的和,那这样的话,我们只需要再次申请一个差分数组的变种即可。

推导如下:

a1 = b1 
a2 = b2 + b1
a3 = b3 + b2 + b1
    ...
an = bn + bn-1 + ... + b1
    
那我们如果要求
a1 + a2 + a3 + ... + an = 
b1 +
b1 + b2
b1 + b2 + b3
  ...
b1 + ... +  ... + ....  bn
    
我们把矩阵补全,
b1 + ... + ... + ... + bn
b1 +  ........ . . . . . 
b1 + b2 .... . . . . . . 
b1 + b2 + b3 ......  . .
  ...   .. . . . . . .. 
b1 + ... +  ... + ....  bn
    经推导可以得到:
Sn = (n+1) * Sb(n) - S(i*b(i))
    也就是说我们只需要在申请差分数组之外,在申请一个差分数组*i的差分数组即可。

    
    
    对于这个tri数组来说,add的就是tr数组的值*i。其余与tr数组并无大异。

逻辑就是这样,下面是代码展示:

#include 

using namespace std ;

const int N = 1e5 + 10 ; 

typedef long long LL ; 

int a[N] ; 

LL tr[N] , tri[N] ; 

int n , m  ; 

int lowbit(int x)
{
    return x & -x ; 
}

void add(int x , LL k , LL num[])
{
    for(int i = x ; i <= n ; i += lowbit(i))
    {
        num[i] += k ; 
    }
}

LL sum(int x , LL num[])
{
    LL res = 0 ; 
    for(int i = x ; i ;  i -= lowbit(i))
    {
        res += num[i] ; 
    }
    return res ; 
}

LL getsum(int x)
{
    return sum(x , tr) * (x+1) - sum(x , tri) ; 
}


int main()
{
    cin >> n >> m ; 
    
    for(int i = 1 ; i <= n ; ++ i)
    {
        cin >> a[i] ;
        add(i , a[i] - a[i-1] ,tr) ; 
        add(i , (LL)(a[i] - a[i-1]) * i , tri) ; 
    }
    
    while(m--)
    {
        string op ; 
        cin >> op ; 
        if(op[0] == 'C')
        {
            int l , r , d ;  cin >> l >> r >> d ; 
            add(l , d , tr) , add(r+1 , -d , tr) ; 
            add(l , d *l , tri) , add(r+1 , (-d)*(r+1) , tri) ; 
        }
        else
        {
            int l , r ; 
            cin >> l >> r ; 
            LL res = getsum(r) - getsum(l-1) ; 
            printf("%lld\n" , res) ; 
        }
    }
    return  0 ;
}

今天就更到这,下次在更…….


文章作者: 罗林
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 罗林 !
  目录