python实现回溯算法

什么是回溯算法?

回溯算法是一种经典的解决组合优化问题、搜索问题以及求解决策问题的算法。它通过不断地尝试各种可能的候选解,并在尝试过程中搜索问题的解空间,直到找到问题的解或者确定问题无解为止。回溯算法常用于解决诸如排列、组合、子集、棋盘类等问题。

下面是用Python实现回溯算法的示例代码,并对其进行详细解释:

def backtrack(candidate, path, result):
    if 满足结束条件:  # 如果已经满足结束条件
        result.append(path[:])  # 将当前路径添加到结果中
        return

    for 选择 in 候选集:  # 遍历所有候选选择
        if 当前选择合法:  # 如果当前选择是合法的
            做出选择  # 将当前选择添加到路径中
            backtrack(新的候选集, 新的路径, 结果)  # 递归调用,继续探索下一层决策树
            撤销选择  # 回溯,撤销当前选择,尝试其他选择

上述代码是回溯算法的一般模板。详细解释每个部分的含义:

1.backtrack 函数:这是回溯算法的核心函数。它接受三个参数:candidate 是当前的候选集,path 是当前的路径,result 是保存最终结果的列表。
2.结束条件:在 backtrack 函数的开头,我们先判断是否满足结束条件。如果满足结束条件,即已经找到了一个解,就将当前路径 path 添加到结果 result 中,并立即返回。
3.遍历候选集:在回溯算法中,我们会对当前的候选集进行遍历,尝试每一种可能的选择。
4.合法性检查:在遍历候选集的过程中,我们会对每一个候选选择进行合法性检查,判断当前选择是否符合问题的要求。
5.做出选择:如果当前选择是合法的,我们会将其添加到路径中,表示我们已经做出了一次选择。
6.递归调用:接着,我们会递归调用 backtrack 函数,继续探索下一层决策树。在递归调用中,我们会更新候选集、路径等参数,以便于进行下一层的选择。
7.撤销选择:在递归调用返回后,表示我们已经探索完了当前选择所导致的所有可能性,需要进行回溯。在回溯的过程中,我们会撤销当前选择,尝试其他选择,以便进一步探索其他可能的解。

这就是回溯算法的基本实现思路。通过不断地尝试各种可能的选择,并在尝试过程中搜索解空间,回溯算法能够有效地解决各种组合优化问题和搜索问题。

经典例题

1.N 皇后问题

给定一个 N × N 的棋盘,在棋盘上放置 N 个皇后,使得它们互相不能攻击,即任意两个皇后不能处于同一行、同一列或同一斜线上。

递归回溯:在每一行放置一个皇后,递归地处理下一行的放置。
合法性检查:在放置皇后时,检查当前位置是否与已放置皇后冲突。
回溯撤销:如果当前位置不能放置皇后,则撤销当前选择,尝试其他选择。

def solve_n_queens(n):
    def is_valid(row, col, queens):
        for r, c in queens:
            if r == row or c == col or abs(row - r) == abs(col - c):
                return False
        return True

    def backtrack(row, queens):
        if row == n:
            result.append(queens[:])
            return
        for col in range(n):
            if is_valid(row, col, queens):
                queens.append((row, col))
                backtrack(row + 1, queens)
                queens.pop()

    result = []
    backtrack(0, [])
    return [['.' * col + 'Q' + '.' * (n - col - 1) for row, col in solution] for solution in result]

# 测试
n = 4
print(solve_n_queens(n))

1.solve_n_queens 函数接受一个参数 n,表示棋盘的大小。
2.is_valid 函数用于检查当前位置 (row, col) 是否与已放置的皇后冲突,如果冲突则返回 False,否则返回 True。
3.backtrack 函数是回溯的核心函数,它递归地在每一行放置皇后,并进行合法性检查,如果合法则继续放置下一行的皇后,如果不合法则进行回溯撤销。
最后,返回所有合法的解。每个解使用二维列表表示,其中每个列表元素表示棋盘中一行的布局,‘Q’ 表示放置了皇后的位置,‘.’ 表示空白位置。

2.组合总和

给定一个候选数组 candidates 和一个目标数 target,找出候选数组中所有可以使数字和为目标数的组合。同一个数字可以被选取多次。
递归回溯:在每一层递归中,尝试使用当前候选数组中的数字来组合成目标数。
剪枝优化:在递归过程中,如果当前数字大于目标数,则可以提前结束递归。
去重处理:为了避免重复的组合,我们可以规定每次选择的数字必须不小于上一个选择的数字。

def combination_sum(candidates, target):
    def backtrack(start, path, target):
        if target == 0:
            result.append(path[:])
            return
        for i in range(start, len(candidates)):
            if candidates[i] > target:
                continue
            path.append(candidates[i])
            backtrack(i, path, target - candidates[i])
            path.pop()

    result = []
    candidates.sort()
    backtrack(0, [], target)
    return result

# 测试
candidates = [2, 3, 6, 7]
target = 7
print(combination_sum(candidates, target))

1.combination_sum 函数接受两个参数:candidates 是候选数组,target 是目标数。
2.backtrack 函数是回溯的核心函数,它接受三个参数:start 表示当前可选的起始索引,path 是当前的组合,target 是当前的目标数。
3.在 backtrack 函数中,如果 target 等于 0,则表示已经找到了一个组合,将当前组合 path 添加到结果列表 result 中。
4.然后,我们遍历候选数组中的数字,并递归调用 backtrack 函数进行下一层的组合。在递归调用中,我们更新 start 参数,以避免重复的组合。
5.如果当前数字大于目标数,则直接跳过,进行剪枝优化。
最后,返回所有找到的组合。

3.全排列

给定一个没有重复数字的序列,返回其所有可能的全排列

递归回溯:在每一层递归中,尝试使用当前可选的数字进行排列。
标记已使用:在递归过程中,需要标记已经使用过的数字,避免重复使用。
处理结果:当达到排列长度时,将当前排列添加到结果列表中。

def permute(nums):
    def backtrack(path):
        if len(path) == len(nums):
            result.append(path[:])
            return
        for num in nums:
            if num in path:
                continue
            path.append(num)
            backtrack(path)
            path.pop()

    result = []
    backtrack([])
    return result

# 测试
nums = [1, 2, 3]
print(permute(nums))

1.permute 函数接受一个参数 nums,表示输入的序列。
2.backtrack 函数是回溯的核心函数,它接受一个参数 path,表示当前的排列。
3.在 backtrack 函数中,如果当前排列长度等于输入序列长度,则表示已经找到一个全排列,将其添加到结果列表 result 中。
4.然后,我们遍历输入序列中的数字,并递归调用 backtrack 函数进行下一层的排列。在递归调用中,我们使用 path 参数来标记已经使用过的数字,避免重复使用。
5.最后,返回所有找到的全排列。

4.子集

给定一个数组,返回其所有可能的子集

递归回溯:在每一层递归中,尝试加入当前元素或不加入当前元素。
处理结果:当递归到底层时,将当前子集添加到结果列表中。

def subsets(nums):
    def backtrack(start, path):
        result.append(path[:])
        for i in range(start, len(nums)):
            path.append(nums[i])
            backtrack(i + 1, path)
            path.pop()

    result = []
    backtrack(0, [])
    return result

# 测试
nums = [1, 2, 3]
print(subsets(nums))

5.电话号码的字母组合

给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合

递归回溯:在每一层递归中,尝试加入当前数字对应的字母。
处理结果:当递归到底层时,将当前字母组合添加到结果列表中。

def letter_combinations(digits):
    if not digits:
        return []
    phone_map = {'2': 'abc', '3': 'def', '4': 'ghi', '5': 'jkl',
                 '6': 'mno', '7': 'pqrs', '8': 'tuv', '9': 'wxyz'}

    def backtrack(index, path):
        if index == len(digits):
            result.append(''.join(path))
            return
        for char in phone_map[digits[index]]:
            path.append(char)
            backtrack(index + 1, path)
            path.pop()

    result = []
    backtrack(0, [])
    return result

# 测试
digits = "23"
print(letter_combinations(digits))

6.岛屿数量

给定一个由 ‘0’ 和 ‘1’ 组成的二维网格地图,其中 ‘1’ 表示陆地,‘0’ 表示水域,计算岛屿的数量。岛屿被水域包围,并且水平或垂直相邻(不包含对角线)的陆地被认为是同一个岛屿。
实现思路:
DFS:遍历整个网格,当遇到陆地时,进行深度优先搜索,将与当前陆地相连的所有陆地标记为已访问,直到所有相连的陆地被访问完毕为止。
计数:每次找到一个新的岛屿时,增加岛屿数量。

def num_islands(grid):
    def dfs(row, col):
        if row < 0 or row >= len(grid) or col < 0 or col >= len(grid[0]) or grid[row][col] == '0':
            return
        grid[row][col] = '0'  # 标记为已访问
        for dr, dc in directions:
            dfs(row + dr, col + dc)

    if not grid:
        return 0
    directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    num_rows, num_cols = len(grid), len(grid[0])
    num_islands = 0
    for row in range(num_rows):
        for col in range(num_cols):
            if grid[row][col] == '1':
                num_islands += 1
                dfs(row, col)
    return num_islands

# 测试
grid = [
    ['1', '1', '0', '0', '0'],
    ['1', '1', '0', '0', '0'],
    ['0', '0', '1', '0', '0'],
    ['0', '0', '0', '1', '1']
]
print(num_islands(grid))

7.单词搜索

给定一个二维网格和一个单词,判断单词是否存在于网格中。字母相邻(包括对角线)的格子组成了单词
实现思路:
对于网格中的每个格子,都作为起点尝试进行深度优先搜索。
在深度优先搜索的过程中,递归地探索当前格子的上、下、左、右四个相邻格子,判断是否能够匹配单词中的下一个字母。
如果能够匹配,继续向下递归搜索;如果不能匹配或者超出了边界,则回溯到上一个格子,尝试其他方向的搜索。
在搜索的过程中,需要使用一个额外的标记数组来记录已经访问过的格子,避免重复访问。

def exist(board, word):
    def dfs(row, col, index):
        # 终止条件:单词已全部匹配
        if index == len(word):
            return True
        # 边界条件:越界或当前字母不匹配
        if row < 0 or row >= len(board) or col < 0 or col >= len(board[0]) or board[row][col] != word[index]:
            return False
        # 临时标记当前位置已访问
        temp = board[row][col]
        board[row][col] = '#'  
        # 递归搜索当前字母的上、下、左、右四个相邻格子
        for dr, dc in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
            if dfs(row + dr, col + dc, index + 1):
                return True
        # 回溯:恢复原始状态
        board[row][col] = temp
        return False

    if not board or not word:
        return False
    # 逐个尝试每个格子作为起点
    for row in range(len(board)):
        for col in range(len(board[0])):
            if dfs(row, col, 0):
                return True
    return False

# 测试
board = [
    ['A', 'B', 'C', 'E'],
    ['S', 'F', 'C', 'S'],
    ['A', 'D', 'E', 'E']
]
word = "ABCCED"
print(exist(board, word))  # 输出: True

最近更新

  1. TCP协议是安全的吗?

    2024-03-10 16:24:04       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-10 16:24:04       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-10 16:24:04       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-10 16:24:04       18 阅读

热门阅读

  1. Svelte之基础知识一

    2024-03-10 16:24:04       24 阅读
  2. 读书·基于RISC-V和FPGA的嵌入式系统设计·第3章

    2024-03-10 16:24:04       21 阅读
  3. pytorch升级打怪(一)

    2024-03-10 16:24:04       22 阅读
  4. 力扣77-组合

    2024-03-10 16:24:04       23 阅读
  5. 设计模式之单例模式

    2024-03-10 16:24:04       21 阅读
  6. IntelliJ IDEA分支svn

    2024-03-10 16:24:04       27 阅读