权值线段树

yuanheci 2022年12月08日 458次浏览

  线段树的两个拓展:动态开点线段树权值线段树。

  线段树有两种写法:堆形式存储结点形式存储

  权值线段树可以用堆形式也可以用动态开点。
  一般用堆形式才需要离散化,因为这样就可以一开始建树建成4 * NN是数量,不会很大)的,否则就需要建成4 * M的(M是值域·,很大,会MLE),堆形式在建树的时候是需要实实在在建出来的,因此
MEL

  动态开点的话一般就不用离散化了(而且动态开点还有一个重要功能是能保持在线)。
因为每次是用到的时候才现场开辟结点。因此虽然定义的时候是tr[M * 2],但是程序运行过程
中真正开辟的空间其实是N * 2范围的,因此不会MLE


  通常来说,线段树占用空间是总区间长的常数倍,空间复杂度是 。然而,有时候很巨大,而我们又不需要使用所有的节点,这时便可以动态开点——不再一次性建好树,而是一边修改、查询一边建立。我们不再用堆形式中p*2p*2+1代表左右儿子,而是用lr记录左右儿子的编号。设总查询次数为 m,则这样的总空间复杂度为 mlogn

前置知识——桶
  桶是一种数据结构。数据结构的用途是以一种特殊方式统计数据,使得我们能够快速地修改、查询我们想要的那部分数据。但是一般我们在想要统计一组数据的时候,我们更关注的是这些数据都是什么。就比如我们现在要统计一个数列,我们更关心的是这个数列里到底有那些数,而不是特别关心这些数都出现了几次。
  桶就打破了这个现状,作为一种“特殊”的数据结构,它所统计的就是每个数据在数据集合中一共出现了多少次

  权值线段树维护的是桶,按值域开空间,维护的是个数。
  简单线段树维护的是信息,按个数(下标)开空间,维护的是特定信息。

权值线段树的用途
  权值线段树可以解决数列第k大/小的问题。

  这里要注意!我们只能对给定数列解决整个数列的第k大/小,并不能解决数列的子区间的第k大/小。如果要求给定区间[l, r]中的第k大/小,那就得用主席树了。

  当权值线段树给出的权值有很大的情况时,如1e9,直接存储肯定是会爆空间的,此时可以有两种处理方式:
(1)离线方式:离散化~ yyds
(2)在线方式:动态开点(这里学习的是蓝书上的动态开点写法,用build()返回,也有其他写法(传引用),选择自己喜欢的开点方式就可以啦~·);

练习题:《第k小正数》


堆形式权值线段树 + 离散化

#include <bits/stdc++.h>
using namespace std;

const int N = 10010;
struct Node{
	int l, r;
	int cnt;
}tr[N * 4];

vector<int> a;
int n, k;

int find(int x){
	int l = 0, r = a.size() - 1;
	while(l < r){
		int mid = l + r >> 1;
		if(a[mid] >= x) r = mid;
		else l = mid + 1; 
	}
	return l;
}

void pushup(int u){
	tr[u].cnt = tr[u << 1].cnt + tr[u << 1 | 1].cnt; 
}

void build(int u, int l, int r){
	tr[u] = {l, r};
	if(l != r){
		int mid = l + r >> 1;
		build(u << 1, l, mid);
		build(u << 1 | 1, mid + 1, r);
	}
}

void modify(int u, int x){
	if(tr[u].l == x && tr[u].r == x) tr[u].cnt++;
	else{
		int mid = tr[u].l + tr[u].r >> 1;
		if(x <= mid) modify(u << 1, x);
		else modify(u << 1 | 1, x);
		pushup(u);
	}
}

int query(int u, int k){
	if(tr[u].l == tr[u].r) return tr[u].l;
	else{
		int cnt = tr[u << 1].cnt;
		if(k <= cnt) return query(u << 1, k);
		else return query(u << 1 | 1, k - cnt);
	}
}

int main(){
	scanf("%d%d", &n, &k);
	for(int i = 1; i <= n; i++){
		int x;
		scanf("%d", &x);
		a.push_back(x);
	}
	sort(a.begin(), a.end());
	a.erase(unique(a.begin(), a.end()), a.end());
	if(a.size() < k) puts("NO RESULT");
	else{
		build(1, 0, a.size() - 1);
		for(int i = 0; i < a.size(); i++){
			modify(1, find(a[i]));
		}
		printf("%d\n", a[query(1, k)]);
	}
	
	return 0;
} 

真正的动态开点写法(没用带离散化,一般用在题目要求在线的情况下)
注意: 动态开点的精髓在于需要用到了再去开结点从而防止MLE,而预先tr[N * 2]比较大是没关系的,因为c++c++ 只有真正用到了空间才会去开辟,因此一开始开大一点没关系,最终的空间复杂度只取决于程序运行过程中实际需要开辟的空间大小。

#include <bits/stdc++.h>
using namespace std;

const int N = 30010;
struct Node{
	int l, r;
	int cnt;
}tr[N * 2];

int n, k;
int root, idx;
bool st[N];
int cnt;

int build(){
	int p = ++idx;
	tr[p].l = tr[p].r = tr[p].cnt = 0;
	return p;
}

void pushup(int u){
	tr[u].cnt = tr[tr[u].l].cnt + tr[tr[u].r].cnt;
}

void insert(int u, int l, int r, int x){
	if(l == r) tr[u].cnt++;
	else{
		int mid = l + r >> 1;
		if(x <= mid) {
			if(!tr[u].l) tr[u].l = build();
			insert(tr[u].l, l, mid, x);
		} 
		else{
			if(!tr[u].r) tr[u].r = build();
			insert(tr[u].r, mid + 1, r, x);
		}
		pushup(u);
	}
}

int query(int u, int l, int r, int k){
	if(l == r) return l;
	else{
		int mid = l + r >> 1;
		int cc = tr[tr[u].l].cnt;
		if(k <= cc) return query(tr[u].l, l, mid, k);
		else return query(tr[u].r, mid + 1, r, k - cc);
	}
}

int main(){
	scanf("%d%d", &n, &k);
	root = build();
	for(int i = 0; i < n; i++){
		int x;
		scanf("%d", &x); 
		if(st[x]) continue;
		st[x] = true;
		cnt++;
		insert(root, 1, N - 1, x);
	}
	if(cnt < k) puts("NO RESULT");
	else printf("%d\n", query(root, 1, N - 1, k));
	
	return 0;
}

动态开点权值线段树 + 离散化

#include <bits/stdc++.h>
using namespace std;

const int N = 10010;

struct Node{
	int l, r;
	int cnt;
}tr[N * 2];

vector<int> nums;
int a[N];
int n, k, idx, root;

int find(int x){
	int l = 0, r = nums.size() - 1;
	while(l < r){
		int mid = l + r >> 1;
		if(nums[mid] >= x) r = mid;
		else l = mid + 1;
	}
	return l;
}

//动态开点的权值线段树 
int build(){
	int p = ++idx;
	tr[p].l = tr[p].r = tr[p].cnt = 0;
	return p;
}

void insert(int u, int l, int r, int x){
	if(l == r){
		tr[u].cnt++;
		return;
	}
	int mid = l + r >> 1;
	if(x <= mid) {
		if(!tr[u].l) tr[u].l = build();
		insert(tr[u].l, l, mid, x);
	}
	else if(x > mid) {
		if(!tr[u].r) tr[u].r = build();
		insert(tr[u].r, mid + 1, r, x);
	}
	tr[u].cnt = tr[tr[u].l].cnt + tr[tr[u].r].cnt;
}

int query(int u, int l, int r, int k){
	if(l == r) return l;
	int cnt = tr[tr[u].l].cnt;
	int mid = l + r >> 1;
	if(k <= cnt) return query(tr[u].l, l, mid, k);
	else return query(tr[u].r, mid + 1, r, k - cnt);
}

int main(){
	scanf("%d%d", &n, &k);
	for(int i = 1; i <= n; i++){
		scanf("%d", &a[i]);
		nums.push_back(a[i]);
	}
	
	sort(nums.begin(), nums.end());
	nums.erase(unique(nums.begin(), nums.end()), nums.end());
	
	if(nums.size() < k){
		puts("NO RESULT");
		return 0;
	}
	
	root = build();
	
	for(int i = 0; i < nums.size(); i++){    //一开始写的是 <=n ...逆天 
		insert(root, 0, nums.size() - 1, find(nums[i]));
	}
	
	printf("%d\n", nums[query(root, 0, nums.size() - 1, k)]);
	
	return 0;
}