因为想要学可持久化平衡树,但是之前用平衡树基本都是splay这种需要旋转的,不利于可持久化,所以今天来学一下fhq-treap这种不需要旋转的平衡树
fhq-treap是一种基于分裂(split)和合并(merge)的一种treap,下文将会对其各个操作是如何利用分裂与合并进行详细说明
定义
下方代码全部基于以下定义:
int root, idx; // 分别表示根结点编号和当前用到哪个结点
int val[N]; // 结点权值
int pri[N]; // 结点优先级
int sz[N]; // 结点子树大小
int ch[N][2]; // 结点左右儿子
fhq-treap依然是一棵中序遍历按照val排序的平衡树,但是它的结构依赖于pri优先级,一个结点的两个儿子的pri一定比这个结点要小,然后我们随机取pri,就可以保证treap基本平衡
分裂 split
void split(int u, int x, int& lt, int& rt) // 把treap分成小于等于x和大于x的两个部分
{
if (u == 0)
{
lt = rt = 0;
return;
}
if (val[u] <= x) // 当前点小于x,说明分界点在右儿子
{
lt = u;
split(ch[u][1], x, ch[u][1], rt);
}
else // 当前点大于x,说明分界点在左儿子
{
rt = u;
split(ch[u][0], x, lt, ch[u][0]);
}
pushup(u); // 更新结点u信息
}
合并 merge
把要合并的两棵子树的根结点记为 lt
和 rt
,如果 lt
的优先级比 rt
大 就把 rt
合并到 lt
的右子树,否则把 lt
合并到 rt
的左子树
int merge(int lt, int rt) // 合并根结点编号为lt和rt的两棵treap
{
if (lt == 0 || rt == 0) return lt + rt;
if (pri[lt] > pri[rt]) // lt的优先级比rt大 就把rt合并到lt的右子树
{
ch[lt][1] = merge(ch[lt][1], rt);
pushup(lt);
return lt;
}
else // rt的优先级比lt大 就把lt合并到rt的左子树
{
ch[rt][0] = merge(lt, ch[rt][0]);
pushup(rt);
return rt;
}
}
插入 insert
插入就是创建单个结点的treap然后进行合并
void getnode(int x) // 创建权值为x的新结点
{
sz[++ idx] = 1;
ch[idx][0] = ch[idx][1] = 0;
val[idx] = x;
pri[idx] = rand(); // 优先值取随机数保证树形状随机
}
void insert(int x) // 将权值为x的新结点插入treap
{
int lt, rt;
split(root, x, lt, rt); // 先把原treap分成小于等于x和大于x两个部分
getnode(x); // 创建值为x的结点
root = merge(merge(lt, idx), rt); // 先合并小于等于x的子树和新结点 再将其与大于x的子树合并
return; // 返回新结点编号
}
删除 del
删除单个结点
void del(int x) // 删去权值为x的【一个】结点
{
int lt, md, rt;
split(root, x, lt, rt); // 先把原treap分成小于等于x和大于x两个部分
split(lt, x - 1, lt, md); // 再把小于等于x的部分分成小于等于x-1和大于x-1两个部分
// 此时以md为根结点的子树内所有结点的权值都是x
md = merge(ch[md][0], ch[md][1]); // 合并md的左右儿子 即删去md这个结点
root = merge(merge(lt, md), rt); // 先合并小于等于x-1的部分和等于x的部分 再将其与大于x的子树合并
return; // 返回根结点编号
}
删除所有值为 x 的结点
void del(int x) // 删去权值为x的【所有】结点
{
int lt, md, rt;
split(root, x, lt, rt); // 先把原treap分成小于等于x和大于x两个部分
split(lt, x - 1, lt, md); // 再把小于等于x的部分分成小于等于x-1和大于x-1两个部分
// 此时以md为根结点的子树内所有结点的权值都是x 直接删掉
root = merge(lt, rt); // 合并小于等于x-1和大于x的部分
return; // 返回根结点编号
}
根据值查询排名 rk
int rk(int x) // 查询权值为x的结点排名
{
int lt, rt, res;
split(root, x - 1, lt, rt); // 把原treap分成小于等于x-1和大于x-1两个部分
res = sz[lt] + 1; // 权值为x的结点编号就是小于等于x-1的结点个数+1
root = merge(lt, rt); // 恢复treap
return res;
}
查询第k大的结点编号
int kth(int u, int k) // 在以u为根的子树中查找排名为k的结点【编号】
{
int cur = u;
while (1)
{
if (k <= sz[ch[cur][0]]) cur = ch[cur][0];
else
{
k -= sz[ch[cur][0]] + 1;
if (k <= 0) return cur;
else cur = ch[cur][1];
}
}
}
查询前驱结点编号
int pre(int x) // 找权值为x的结点前驱【编号】
{
int lt, rt, res;
split(root, x - 1, lt, rt); // 把原treap分成小于等于x-1和大于x-1两个部分
res = kth(lt, sz[lt]); // 在小于等于x-1的部分找权值最大的结点编号
root = merge(lt, rt); // 恢复treap
return res;
}
查询后继结点编号
int suf(int x)
{
int lt, rt, res;
split(root, x, lt, rt); // 把原treap分成小于等于x和大于x两个部分
res = val[kth(rt, 1)]; // 在大于x的部分找权值最小的结点编号
root = merge(lt, rt); // 恢复treap
return res;
}
例题
P3369 【模板】普通平衡树
#include <bits/stdc++.h>
using namespace std;
#define int long long
using i64 = long long;
typedef pair<int, int> PII;
typedef pair<int, char> PIC;
typedef pair<double, double> PDD;
typedef pair<int, PII> PIII;
typedef pair<int, pair<int, bool>> PIIB;
const int N = 1e5 + 10;
const int mod = 1e9 + 7;
const int maxn = 1e6 + 10;
const int mod1 = 954169327;
const int mod2 = 906097321;
const int INF = 0x3f3f3f3f3f3f3f3f;
int root, idx; // 分别表示根结点编号和当前用到哪个结点
int val[N]; // 结点权值
int pri[N]; // 结点优先级
int sz[N]; // 结点子树大小
int ch[N][2]; // 结点左右儿子
void pushup(int u)
{
sz[u] = sz[ch[u][0]] + sz[ch[u][1]] + 1;
}
void split(int u, int x, int& lt, int& rt)
{
if (u == 0)
{
lt = rt = 0;
return;
}
if (val[u] <= x)
{
lt = u;
split(ch[u][1], x, ch[u][1], rt);
}
else
{
rt = u;
split(ch[u][0], x, lt, ch[u][0]);
}
pushup(u);
}
int merge(int lt, int rt)
{
if (lt == 0 || rt == 0) return lt + rt;
if (pri[lt] > pri[rt])
{
ch[lt][1] = merge(ch[lt][1], rt);
pushup(lt);
return lt;
}
else
{
ch[rt][0] = merge(lt, ch[rt][0]);
pushup(rt);
return rt;
}
}
void getnode(int x)
{
sz[++ idx] = 1;
ch[idx][0] = ch[idx][1] = 0;
val[idx] = x;
pri[idx] = rand();
}
int insert(int x)
{
int lt, rt;
split(root, x, lt, rt);
getnode(x);
root = merge(merge(lt, idx), rt);
return idx;
}
int del(int x)
{
int lt, md, rt;
split(root, x, lt, rt);
split(lt, x - 1, lt, md);
md = merge(ch[md][0], ch[md][1]);
root = merge(merge(lt, md), rt);
return root;
}
int rk(int x)
{
int lt, rt, res;
split(root, x - 1, lt, rt);
res = sz[lt] + 1;
root = merge(lt, rt);
return res;
}
int kth(int u, int k)
{
int cur = u;
while (1)
{
if (k <= sz[ch[cur][0]]) cur = ch[cur][0];
else
{
k -= sz[ch[cur][0]] + 1;
if (k <= 0) return cur;
else cur = ch[cur][1];
}
}
}
int pre(int x)
{
int lt, rt, res;
split(root, x - 1, lt, rt);
res = val[kth(lt, sz[lt])];
root = merge(lt, rt);
return res;
}
int suf(int x)
{
int lt, rt, res;
split(root, x, lt, rt);
res = val[kth(rt, 1)];
root = merge(lt, rt);
return res;
}
void solve()
{
int n;
cin >> n;
while (n -- )
{
int op, x;
cin >> op >> x;
if (op == 1) insert(x);
else if (op == 2) del(x);
else if (op == 3)
{
insert(x);
cout << rk(x) << '\n';
del(x);
}
else if (op == 4) cout << val[kth(root, x)] << '\n';
else if (op == 5)
{
insert(x);
cout << pre(x) << '\n';
del(x);
}
else if (op == 6)
{
insert(x);
cout << suf(x) << '\n';
del(x);
}
}
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int t = 1;
// cin >> t;
while (t--)
{
solve();
}
}