树链剖分——重链剖分,原理剖析,代码详解


零、前言

树链剖分是一种进行树上操作的在线算法,剖分的方式分为重链剖分和长链剖分。本文介绍重链剖分,其通过利用树上dfs时间戳的连续性,以及轻重链概念的引入,可以将一棵树转化为一段连续序列,一条路径转化为不超过 logn 段连续区间,结合线段树区间维护,可以较为快速的树上查询/修改。
关于线段树:高级搜索-线段树[C/C++]

一、重链剖分

1.1 概念引入

1.1.1 重儿子

重儿子:父节点所有儿子中,所在子树结点数目最多的结点。

1.1.2 重边

重边:父节点和重儿子连成的边。

1.1.3 重链

重链:由多条重边连接而成的路径。

如下图,红色边代表重边,黄色结点代表重儿子,红色边连接而成的路径就是重链

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

1.2 重要结论

  1. 整棵树会被剖分成若干条重链
  2. 轻儿子一定是每条重链的顶点
  3. 任意一条路径会被切分为O(logn)条重链

结论1、2不再证明,只证明结论3:

证明比较巧妙(似乎所有算法都是这样,雾)

我们只需证明根节点到任意叶子结点的路径上最多不超过logn条重链

我们考虑当前结点 cur ,父节点fa,son[fa]为fa的重儿子,size[u]代表结点u为根子树的大小

那么我们从叶子结点往上走,会经历若干重链,若干轻边,我们下面证明经过一条轻边后,size[cur]至少变为原来2倍

  • 假如size[son[fa]] <= size[fa] / 2,那么自然有son[cur] <= size[fa] / 2,那么向上走一条轻边自然size至少扩大二倍
  • 假如size[son[fa]] > size[fa] / 2,size[cur] < size[fa] - size[son[fa]] < size[fa] / 2,么向上走一条轻边size仍至少扩大二倍

得证,故从叶子结点向上最多走logn条轻边,而因为经过一条轻边一定会先经过一条重链,故重链也最多logn条,故得证,故结论3得证

1.3 重链剖分算法实现

1.3.1 所需数组

f[u]: u的父节点
dep[u]: u的深度
son[u]: u的重儿子
sz[u]: u所在子树大小
top[u]: u所在重链的顶点

1.3.2 算法流程

  • dfs1:处理fa、dep、sz、son
  • dfs2:处理top

1.3.3 算法实现

dfs预处理:O(n+m)

void dfs1(int x, int father){
    //深度、父节点、初始化sz
	dep[x] = dep[father] + 1, fa[x] = father, sz[x] = 1;
	for(int i = head[x]; ~i; i = edges[i].nxt){
		int y = edges[i].v;
		if(y == father) continue;
		dfs1(y, x);
		sz[x] += sz[y];	//累加sz
		if(sz[son[x]] < sz[y]) son[x] = y;	//维护重儿子
	}
}

void dfs2(int x, int t){
	top[x] = t;
	if(!son[x]) return;
	dfs2(son[x], t);	//重儿子和父亲的top相同
	for(int i = head[x]; ~i; i = edges[i].nxt){
		int y = edges[i].v;
		if(y == fa[x] || y == son[x]) continue;
		dfs2(y, y);	//轻儿子自己就是所在重链的顶点
	}
}

1.4 应用之LCA

重链剖分一个比较经典的应用就是求lca,效率比倍增要快些,不过二者代码都挺好写,当然tarjan也不错,不过个人感觉tarjan那个思想容易忘。

1.4.1 OJ链接

P3379 【模板】最近公共祖先(LCA) - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

1.4.2 算法流程

对于查询(x, y),二者到lca的路径上都满足不超过logn条重链,我们由于预处理了top和fa,我们可以通过x = fa[top[x]]跳到上一条重链

那么x和y各自向上跳不超过logn次就能到达同一条重链了,而且必然满足最终结果为x或y处于lca的位置,这个比较简单,可以自己想一下

1.4.3 AC代码

#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 5e5 + 10, M = 1e6 + 10;

int n, m, root, head[N], idx;
int sz[N], fa[N], son[N], top[N], dep[N];

struct edge{
	int v, nxt;
}edges[M];

void addedge(int u, int v){
	edges[idx] = { v, head[u] }, head[u] = idx++;
}

void add(int u, int v){
	addedge(u, v), addedge(v, u);
}

void dfs1(int x, int father){
	dep[x] = dep[father] + 1, fa[x] = father, sz[x] = 1;
	for(int i = head[x]; ~i; i = edges[i].nxt){
		int y = edges[i].v;
		if(y == father) continue;
		dfs1(y, x);
		sz[x] += sz[y];
		if(sz[son[x]] < sz[y]) son[x] = y;
	}
}

void dfs2(int x, int t){
	top[x] = t;
	if(!son[x]) return;
	dfs2(son[x], t);
	for(int i = head[x]; ~i; i = edges[i].nxt){
		int y = edges[i].v;
		if(y == fa[x] || y == son[x]) continue;
		dfs2(y, y);
	}
}

int lca(int x, int y){
	while(top[x] != top[y]){
		if(dep[top[x]] < dep[top[y]]) swap(x, y);
		x = fa[top[x]];
	}
	//此时已经在一条重链上
	return dep[x] < dep[y] ? x : y; 
}

int main(){
	//freopen("in.txt", "r", stdin);
	ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
	memset(head, -1, sizeof head);
	cin >> n >> m >> root;
	for(int i = 0, a, b; i < n - 1; i++)
		cin >> a >> b, add(a, b);
	dfs1(root, 0), dfs2(root, root);
	for(int i = 0, a, b; i < m; i++)
		cin >> a >> b, cout << lca(a, b) << '\n';
	return 0;
}

1.5 树上查询/修改

如果只是用来求lca,那么就杀鸡用牛刀了,重链剖分的精髓其实在于树上查询/修改。

1.5.1 算法原理

我们尝试在dfs2的同时为每个结点打上时间戳,dfs中时间戳是一个非常重要的概念,我们有一个经典的结论就是:**同一个连通块内结点的时间戳集合是一个连续区间。**这也是tarjan算法的核心,因为我们深搜一定是搜完了当前连通块才会去搜其他连通块,所以同一连通块内的时间戳集合自然是连续的。

那么当我们在dfs2中为结点打上时间戳后,我们多了哪些可用信息?我们不妨记时间戳数组为id[]

  • 以u为根的子树内的结点的时间戳是一个连续区间——[id[u], id[u] + sz[u] - 1]
  • 从结点u所处重链顶点top[u]到u的路径可以由区间——[id[top[u]], id[u]]表示
  • 由于任意路径有不超过logn条重链,那么任意路径可以划分为不超过logn个连续区间

基于上面三条,如果树中结点带权的话,我们可以利用线段树和时间戳进行任意子树,任意路径的权值和的查询/修改

1.5.2 算法流程

初始化:

  • 除了前面的各个数组,还多了w[u]代表结点u的权值,id[u]代表结点u的时间戳,nw[x]代表时间戳为x的结点的权值
  • dfs1:处理fa、dep、sz、son
  • dfs2:处理top、id、nw
  • 线段树对nw[]递归建树

任意子树修改

子树u结点区间为 [id[u], id[u] + sz[u] - 1]

直接调用线段树区间修改接口即可

时间复杂度:O(log n)

任意子树查询

子树u结点区间为 [id[u], id[u] + sz[u] - 1]

直接调用线段树区间查询接口即可

时间复杂度:O(log n)

任意路径修改

  • 和lca的代码相似,每次可以从一条重链跳到上一条重链
  • 对于x,y间路径修改,我们不妨假设dep[top[x]] > dep[top[y]]
  • 那么我们对[id[top[x]], id[x]]进行区间修改
  • 然后跳到下一条重链,x = fa[top[x]]
  • 直到x,y处于同一条重链
  • 此时进行最后一次区间修改[id[y], id[x]],(不妨假设dep[x] >dep[y])

这样就完成了路径上所有重链上的结点的修改,自然完成了路径修改

时间复杂度:O(log^2 n)

任意路径查询

和路径修改相同,只不过不是进行区间修改而是区间查询

时间复杂度:O(log^2 n)

1.5.3 代码实现

void dfs1(int u, int father, int dep)
{ // 父子关系以及sz处理
    d[u] = dep, fa[u] = father, sz[u] = 1;
    for (int i = head[u]; ~i; i = edges[i].nxt)
    {
        int v = edges[i].v;
        if (v == father)
            continue;
        dfs1(v, u, dep + 1);
        sz[u] += sz[v];
        if (sz[son[u]] < sz[v])
            son[u] = v;
    }
}
void dfs2(int u, int t)
{
    nw[id[u] = ++tot] = w[u], top[u] = t;	//tot用来记录当前时间戳
    if (!son[u])
        return;
    dfs2(son[u], t);
    for (int i = head[u]; ~i; i = edges[i].nxt)
    {
        int v = edges[i].v;
        if (v == son[u] || v == fa[u])
            continue;
        dfs2(v, v);
    }
}

//void pushup(int p)	
//void pushdown(int p)	标记下传
//void update(int p, int l, int r, int k)	区间修改
//LL query(int p, int l, int r) 	区间查询
void build(int p, int l, int r)	//递归建树
{
    tr[p] = {l, r, nw[l]};
    if (l == r)
        return;
    int mid = (l + r) >> 1;
    build(lc, l, mid), build(rc, mid + 1, r);
    pushup(p);
}

void update_path(int x, int y, int k)	//路径修改
{
    while (top[x] != top[y])
    {
        if (d[top[x]] < d[top[y]])
            swap(x, y);
        update(1, id[top[x]], id[x], k);
        x = fa[top[x]];
    }
    if (d[x] < d[y])
        swap(x, y);
    update(1, id[y], id[x], k);
}
LL query_path(int x, int y)	//路径查询
{
    LL res = 0;
    while (top[x] != top[y])
    {
        if (d[top[x]] < d[top[y]])
            swap(x, y);
        res = (res + query(1, id[top[x]], id[x])) % mod;
        x = fa[top[x]];
    }
    if (d[x] < d[y])
        swap(x, y);
    res = (res + query(1, id[y], id[x])) % mod;
    return res;
}
void update_tr(int x, int k)	//子树修改
{
    update(1, id[x], id[x] + sz[x] - 1, k);
}
LL query_tr(int x)	//子树查询
{
    return query(1, id[x], id[x] + sz[x] - 1);
}
//main
    dfs1(root, -1, 1);
    dfs2(root, root);
    build(1, 1, n);

1.6 OJ练习

1.6.1 重链剖分模板

1.6.1.1 原题链接

P3384 【模板】重链剖分/树链剖分 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

1.6.1.2 思路分析

板子题,复现上面的算法流程即可

1.6.1.3 AC代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
#define lc p << 1
#define rc p << 1 | 1
const int N = 1e5 + 10, M = N << 1;
int n, m, root, mod;
struct edge
{
    int v, nxt;
} edges[M];
int head[N], idx;
int w[N], id[N], nw[N], tot;
int d[N], sz[N], top[N], fa[N], son[N];
struct node
{
    int l, r;
    LL sum, tag;
} tr[N << 2];
void addedge(int u, int v)
{
    edges[idx] = {v, head[u]}, head[u] = idx++;
}
void add(int u, int v)
{
    addedge(u, v), addedge(v, u);
}
void dfs1(int u, int father, int dep)
{ // 父子关系以及sz处理
    d[u] = dep, fa[u] = father, sz[u] = 1;
    for (int i = head[u]; ~i; i = edges[i].nxt)
    {
        int v = edges[i].v;
        if (v == father)
            continue;
        dfs1(v, u, dep + 1);
        sz[u] += sz[v];
        if (sz[son[u]] < sz[v])
            son[u] = v;
    }
}
void dfs2(int u, int t)
{
    nw[id[u] = ++tot] = w[u], top[u] = t;
    if (!son[u])
        return;
    dfs2(son[u], t);
    for (int i = head[u]; ~i; i = edges[i].nxt)
    {
        int v = edges[i].v;
        if (v == son[u] || v == fa[u])
            continue;
        dfs2(v, v);
    }
}

void pushup(int p)
{
    tr[p].sum = (tr[lc].sum + tr[rc].sum) % mod;
}
void build(int p, int l, int r)
{
    tr[p] = {l, r, nw[l]};
    if (l == r)
        return;
    int mid = (l + r) >> 1;
    build(lc, l, mid), build(rc, mid + 1, r);
    pushup(p);
}
void pushdown(int p)
{
    if (tr[p].tag)
    {
        tr[lc].sum = (tr[lc].sum + (tr[lc].r - tr[lc].l + 1) * tr[p].tag + mod) % mod;
        tr[rc].sum = (tr[rc].sum + (tr[rc].r - tr[rc].l + 1) * tr[p].tag + mod) % mod;
        tr[lc].tag += tr[p].tag, tr[rc].tag += tr[p].tag;
        tr[p].tag = 0;
    }
}
void update(int p, int l, int r, int k)
{
    if (l <= tr[p].l && tr[p].r <= r)
    {
        tr[p].tag += k, tr[p].sum = (tr[p].sum + (tr[p].r - tr[p].l + 1) * k + mod) % mod;
        return;
    }
    pushdown(p);
    int mid = (tr[p].l + tr[p].r) >> 1;
    if (l <= mid)
        update(lc, l, r, k);
    if (r > mid)
        update(rc, l, r, k);
    pushup(p);
}
LL query(int p, int l, int r)
{
    if (l <= tr[p].l && tr[p].r <= r)
    {
        return tr[p].sum;
    }
    pushdown(p);
    int mid = (tr[p].l + tr[p].r) >> 1;
    LL res = 0;
    if (l <= mid)
        res = (res + query(lc, l, r)) % mod;
    if (r > mid)
        res = (res + query(rc, l, r)) % mod;
    return res;
}
void update_path(int x, int y, int k)
{
    while (top[x] != top[y])
    {
        if (d[top[x]] < d[top[y]])
            swap(x, y);
        update(1, id[top[x]], id[x], k);
        x = fa[top[x]];
    }
    if (d[x] < d[y])
        swap(x, y);
    update(1, id[y], id[x], k);
}
LL query_path(int x, int y)
{
    LL res = 0;
    while (top[x] != top[y])
    {
        if (d[top[x]] < d[top[y]])
            swap(x, y);
        res = (res + query(1, id[top[x]], id[x])) % mod;
        x = fa[top[x]];
    }
    if (d[x] < d[y])
        swap(x, y);
    res = (res + query(1, id[y], id[x])) % mod;
    return res;
}
void update_tr(int x, int k)
{
    update(1, id[x], id[x] + sz[x] - 1, k);
}
LL query_tr(int x)
{
    return query(1, id[x], id[x] + sz[x] - 1);
}
int main()
{
    //freopen("in.txt", "r", stdin);
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    memset(head, -1, sizeof head);
    cin >> n >> m >> root >> mod;
    for (int i = 1; i <= n; i++)
        cin >> w[i], w[i] %= mod;
    for (int i = 1, a, b; i < n; i++)
        cin >> a >> b, add(a, b);

    dfs1(root, -1, 1);
    dfs2(root, root);
    build(1, 1, n);
    for (int i = 0, a, b, c; i < m; i++)
    {
        cin >> a;
        if (a == 1)
        {
            cin >> a >> b >> c;
            update_path(a, b, c);
        }
        else if (a == 2)
        {
            cin >> a >> b;
            cout << query_path(a, b) << '\n';
        }
        else if (a == 3)
        {
            cin >> a >> b;
            update_tr(a, b);
        }
        else
        {
            cin >> a;
            cout << query_tr(a) << '\n';
        }
    }
    /*
    1 x y z,表示将树从 x 到 y 结点最短路径上所有节点的值都加上 z

    2 x y,表示求树从 x 到 y 结点最短路径上所有节点的值之和。

    3 x z,表示将以 x 为根节点的子树内所有节点值都加上 z。

    4 x 表示求以 x 为根节点的子树内所有节点值之和
    */
    return 0;
}

1.6.2 [NOI2015] 软件包管理器

1.6.2.1 原题链接

[P2146 NOI2015] 软件包管理器 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

1.6.2.1 思路分析

这个题就很板子,而且很舒服,因为fa数组直接给你了

那么题目两个操作的指向性很强:

install x就是从根到x路径全变1

uninstall x就是从根到x路径全变0

这个其实还是重剖板子题,甚至更简单,我们只需要想一下怎么处理线段树的标记即可。

对于线段树的标记,如果为-1,则表示无标记,不需下传,如果为1,则代表结点左右子区间全变1,如果为0,则代表结点左右子区间全变0,这个标记下传自然是可以实现的

对于每次要输出改变多少结点状态,我们先存一下根节点的sum值也就是整个区间的和,然后跟修改后的根结点sum值做差即可

1.6.2.3 AC代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 1e5 + 10;
#define lc p << 1
#define rc p << 1 | 1
int n, q, head[N], idx, tot;
int id[N], fa[N], d[N], w[N], sz[N], top[N], son[N];
struct edge{
	int v, nxt;
}edges[N];
void addedge(int u, int v){
	edges[idx] = { v, head[u] }, head[u] = idx++;
}
struct node{
	int l, r, sum, tag;
}tr[N << 2];
void build(int p, int l, int r){
	tr[p] = { l, r, 0, -1 };
	if(l == r) return;
	int mid = (tr[p].l + tr[p].r) >> 1;
	build(lc, l, mid), build(rc, mid + 1, r);
}
void pushup(int p){
	tr[p].sum = tr[lc].sum + tr[rc].sum;
}
void pushdown(int p){
	if(~tr[p].tag){
		tr[lc].sum = (tr[lc].r - tr[lc].l  + 1) * tr[p].tag;
		tr[rc].sum = (tr[rc].r - tr[rc].l  + 1) * tr[p].tag;
		tr[lc].tag = tr[rc].tag = tr[p].tag;
		tr[p].tag = -1;
	}
}
void update(int p, int l, int r, int k){
	if(l <= tr[p].l && tr[p].r <= r){
		tr[p].tag = k, tr[p].sum = (tr[p].r - tr[p].l + 1) * k;
		return;
	}
	pushdown(p);
	int mid = (tr[p].l + tr[p].r) >> 1;
	if(l <= mid) update(lc, l, r, k);
	if(r > mid) update(rc, l, r, k);
	pushup(p);
}
int query(int p, int l, int r){
	if(l <= tr[p].l && tr[p].r <= r){
		return tr[p].sum;
	}
	pushdown(p);
	int mid = (tr[p].l + tr[p].r) >> 1, ret = 0;
	if(l <= mid) ret += query(lc, l, r);
	if(r > mid) ret += query(rc, l, r);
	return ret;
}
void update_path(int x, int y, int k){
	
	while(top[x] != top[y]){
		if(d[top[x]] < d[top[y]]) swap(x, y);
		update(1, id[top[x]], id[x], k);
		x = fa[top[x]];
	}
	if(d[x] < d[y]) swap(x, y);
	update(1, id[y], id[x], k);
}
void update_tr(int x, int k){
	update(1, id[x], id[x] + sz[x] - 1, k);
}

void dfs1(int u, int dep){
	d[u] = dep, sz[u] = 1;
	for(int i = head[u]; ~i; i = edges[i].nxt){
		int v = edges[i].v;
		dfs1(v, dep + 1);
		sz[u] += sz[v];
		if(sz[son[u]] < sz[v]) son[u] = v;
	}
}

void dfs2(int u, int t){
	id[u] = ++tot, top[u] = t;
	if(!son[u]) return;
	dfs2(son[u], t);
	for(int i = head[u]; ~i; i = edges[i].nxt){
		int v = edges[i].v;
		if(v != son[u])
			dfs2(v, v);
	}
}
int main(){
	//freopen("in.txt", "r", stdin);
	ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
	memset(head, -1, sizeof head);
	cin >> n;
	for(int i = 2; i <= n; i++) cin >> fa[i], addedge(++fa[i], i);
	dfs1(1, 1), dfs2(1, 1), build(1, 1, n);
	cin >> q;
	string opt;
	for(int i = 0, x; i < q; i++){
		cin >> opt >> x, ++x;
		if(opt[0] == 'i'){
			int t = tr[1].sum;
			update_path(1, x, 1);
			cout << tr[1].sum - t << '\n';
		}
		else{
			int t = tr[1].sum;
			update_tr(x, 0);
			cout << t - tr[1].sum << '\n';		
		}
	}
	return 0;
}

相关推荐

  1. opengl polygon 三角

    2024-04-09 05:06:02       17 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-09 05:06:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-09 05:06:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-09 05:06:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-09 05:06:02       18 阅读

热门阅读

  1. hash模式和history模式的区别

    2024-04-09 05:06:02       18 阅读
  2. 关于SpringBoot的配置文件

    2024-04-09 05:06:02       13 阅读
  3. MySQL-commit,rollback

    2024-04-09 05:06:02       13 阅读
  4. 探索 C++ 中的 string 类

    2024-04-09 05:06:02       10 阅读
  5. Inotify

    Inotify

    2024-04-09 05:06:02      11 阅读
  6. PCL 三角形到三角形的距离

    2024-04-09 05:06:02       11 阅读
  7. 计算机病毒传播原理

    2024-04-09 05:06:02       15 阅读
  8. VPS入门指南:理解并有效利用虚拟专用服务器

    2024-04-09 05:06:02       12 阅读
  9. 力扣由浅至深 每日一题.23 Nim 游戏

    2024-04-09 05:06:02       11 阅读
  10. 测试细节的测试工程师

    2024-04-09 05:06:02       11 阅读