回溯法理论、应用及模板

1. 入门

有时会遇到这样一类题目,它的问题可以分解,但是又不能得出明确的动态规划或是递归解法,此时可以考虑用回溯法解决此类问题。回溯法的优点 在于其程序结构明确,可读性强,易于理解,而且通过对问题的分析可以大大提高运行效率。但是,对于可以得出明显的递推公式迭代求解的问题,还是不要用回溯法,因为它花费的时间比较长。

回溯算法的基本思想是:从一条路往前走,能进则进,不能进则退回来,换一条路再试。可以认为回溯算法一个”通用解题法“,这是由他试探性的行为决定的,就好比求一个最优解,我可能没有很好的概念知道怎么做会更快的求出这个最优解,但是我可以尝试所有的方法,先试探性的尝试每一个组合,看看到底通不通,如果不通,则折回去,由最近的一个节点继续向前尝试其他的组合,如此反复。这样所有解都出来了,在做一下比较,能求不出最优解吗?

2. 基本定义和概念

回溯法中,首先需要明确下面三个概念:

  • 约束函数:约束函数是根据题意定出的。通过描述合法解的一般特征用于去除不合法的解,从而避免继续搜索出这个不合法解的剩余部分。因此,约束函数是对于任何状态空间树上的节点都有效、等价的。
  • 状态空间树:刚刚已经提到,状态空间树是一个对所有解的图形描述。树上的每个子节点的解都只有一个部分与父节点不同。
  • 扩展节点、活结点、死结点:所谓扩展节点,就是当前正在求出它的子节点的节点,在DFS中,只允许有一个扩展节点。活结点就是通过与约束函数的对照,节点本身和其父节点均满足约束函数要求的节点;死结点反之。由此很容易知道死结点是不必求出其子节点的(没有意义)。

3. 为什么用DFS

深度优先搜索(DFS)和广度优先搜索(FIFO)在分支界限法中,一般用的是FIFO或最小耗费搜索;其思想是一次性将一个节点的所有子节点求出并将其放入一个待求子节点的队列。通过遍历这个队列(队列在 遍历过程中不断增长)完成搜索。而DFS的作法则是将每一条合法路径求出后再转而向上求第二条合法路径。而在回溯法中,一般都用DFS。为什么呢?这是因 为可以通过约束函数杀死一些节点从而节省时间,由于DFS是将路径逐一求出的,通过在求路径的过程中杀死节点即可省去求所有子节点所花费的时间。FIFO 理论上也是可以做到这样的,但是通过对比不难发现,DFS在以这种方法解决问题时思路要清晰非常多。

回溯法可以被认为是一个有过剪枝的DFS过程,利用回溯法解题的具体步骤如下:

首先,要通过读题完成下面三个步骤:

  1. 描述解的形式,定义一个解空间,它包含问题的所有解;
  2. 构造状态空间树;
  3. 构造约束函数(用于杀死节点);

然后就要通过DFS思想完成回溯,完整过程如下:

  1. 设置初始化的方案(给变量赋初值,读入已知数据等)。
  2. 变换方式去试探,若全部试完则转(7)。
  3. 判断此法是否成功(通过约束函数),不成功则转(2)。
  4. 试探成功则前进一步再试探。
  5. 正确方案还未找到则转(2)。
  6. 已找到一种方案则记录并打印。
  7. 退回一步(回溯),若未退到头则转(2)。
  8. 已退到头则结束或打印无解。

4. 回溯方法的步骤

回溯方法的步骤如下:

  1. 定义一个解空间,它包含问题的解。
  2. 用适于搜索的方式组织该空间。
  3. 用深度优先法搜索该空间,利用限界函数避免移动到不可能产生解的子空间。

5. 回溯模板

首先我们来看一道题目:

Combinations:Given two integers n and k,return all possible combinations of k numbers out of 1 … n. For example, If n = 4 and k =2, a solution is:
[
[2,4],
[3,4],
[2,3],
[1,2],
[1,3],
[1,4],
]
即:给你两个整数 n和k,从1-n中选择k个数字的组合。比如n=4,那么从1,2,3,4中选取两个数字的组合,包括图上所述的四种。

然后我们看看题目给出的框架:

public class Solution {
    public List<List<Integer>> combine(int n, int k) {
       
    }
}

要求返回的类型是List<List> 也就是说将所有可能的组合list(由整数构成)放入另一个list(由list构成)中。
现在进行套路教学:要求返回List<List>,那我就给你一个List<List>,因此

  1. 定义一个全局List<List> result=new ArrayList<List>();
  2. 定义一个辅助的方法(函数)public void backtracking(int n,int k, Listlist){}
    n, k 总是要有的吧,加上这两个参数,前面提到List<Integer> 是数字的组合,也是需要的吧,这三个是必须的,没问题吧。(可以尝试性地写参数,最后不需要的删除)
  3. 接着就是我们的重头戏了,如何实现这个算法?对于n=4,k=2,1,2,3,4中选2个数字,我们可以做如下尝试,加入先选择1,那我们只需要再选择一个数字,注意这时候k=1了(此时只需要选择1个数字啦)。当然,我们也可以先选择2,3 或者4,通俗化一点,我们可以选择 1~n 的所有数字,这个是可以用一个循环来描述?每次选择一个加入我们的链表list中,下一次只要再选择k-1个数字。那什么时候结束呢?当然是 k<0 的时候啦,这时候都选完了。
    有了上面的分析,我们可以开始填写
    public void backtracking(int n,int k, List<Integer> list){}
    中的内容。
publicvoid backtracking(int n,int k,int start,List<Integer> list){
        if(k<0)        return;
        else if(k==0){
            //k==0表示已经找到了k个数字的组合,这时候加入全局result中
            result.add(new ArrayList(list));
 
        }else{
            for(int i=start;i<=n;i++){
                list.add(i);    //尝试性的加入i
                //开始回溯啦,下一次要找的数字减少一个所以用k-1,i+1见后面分析
                backtracking(n,k-1,i+1,list);
                //(留白,有用=。=)
            }
        }
    }

观察一下上述代码,我们加入了一个start变量,它是i的起点。为什么要加入它呢?比如我们第一次加入了1,下一次搜索的时候还能再搜索1了么?肯定不可以啊!我们必须从他的下一个数字开始,也就是2 、3或者4啦。所以start就是一个开始标记这个很重要啦!
这时候我们在主方法中加入backtracking(n,k,1,list);调试后发现答案不对啊!为什么我的答案比他长那么多?
在这里插入图片描述

回溯回溯当然要退回再走啦,你不退回,当然又臭又长了!所以我们要在刚才代码注释留白处加上退回语句。仔细分析刚才的过程,我们每次找到了1,2这一对答案以后,下一次希望2退出然后让3进来,1 3就是我们要找的下一个组合。如果不回退,找到了2 ,3又进来,找到了3,4又进来,所以就出现了我们的错误答案。正确的做法就是加上:list.remove(list.size()-1);他的作用就是每次清除一个空位 让后续元素加入。寻找成功,最后一个元素要退位,寻找不到,方法不可行,那么我们回退,也要移除最后一个元素。
所以完整的程序如下:

public class Solution {
   List<List<Integer>> result=new ArrayList<List<Integer>>();
   public List<List<Integer>> combine(int n, int k) {
       List<Integer> list=new ArrayList<Integer>();
       backtracking(n,k,1,list);
       return result;
    }
   public void backtracking(int n,int k,int start,List<Integer>list){
       if(k<0) return ;
       else if(k==0){
           result.add(new ArrayList(list));
       }else{
           for(int i=start;i<=n;i++){
                list.add(i);
                backtracking(n,k-1,i+1,list);
                list.remove(list.size()-1);
            }
       }
    }
}

是不是有点想法了?那么我们操刀一下。

6. 应用实例

6.1 LeetCode 39. Combination Sum(组合总数)

Given a set of candidate numbers (candidates) (without duplicates) and a target number (target), find all unique combinations in candidates where the candidate numbers sums to target.

The same repeated number may be chosen from candidates unlimited number of times.

Note:

  • All numbers (including target) will be positive integers.
  • The solution set must not contain duplicate combinations.

题目:
给定一组候选数字 (候选) (不带重复项) 和目标数字 (目标), 可以在候选数字总和为目标的候选项中查找所有唯一的组合。

同样重复的数字可以从候选者无限次数中选择。

Example 1:

Input: candidates = [2,3,6,7], target = 7,
A solution set is:
[
  [7],
  [2,2,3]
]

Example 2:

Input: candidates = [2,3,5], target = 8,
A solution set is:
[
  [2,2,2,2],
  [2,3,3],
  [3,5]
]

按照前述的套路走一遍:

public class Solution {
   List<List<Integer>> result=new ArrayList<List<Integer>>();
   public List<List<Integer>> combinationSum(int[] candidates,int target) {
       Arrays.sort(candidates);
       List<Integer> list=new ArrayList<Integer>();
       return result;
    }
   public void backtracking(int[] candidates,int target,int start,){        
    }
}
  1. 全局List<List> result先定义
  2. 回溯backtracking方法要定义,数组candidates 目标target 开头start 辅助链表List list都加上。
  3. 分析算法:以[2,3,6,7] 每次尝试加入数组任何一个值,用循环来描述,表示依次选定一个值
for(inti=start; i<candidates.length; i++){

           list.add(candidates[i]);

       }

接下来回溯方法再调用。比如第一次选了2,下次还能再选2是吧,所以每次start都可以从当前i开始(ps:如果不允许重复,从i+1开始)。第一次选择2,下一次要凑的数就不是7了,而是7-2,也就是5,一般化就是remain = target - candidates[i],所以回溯方法为:

backtracking(candidates, target-candidates[i], i, list);

然后加上退回语句:list.remove(list.size()-1);
那么什么时候找到的解符合要求呢?自然是remain(注意区分初始的target)= 0 了,表示之前的组合恰好能凑出target。如果 remain < 0 表示凑的数太大了,组合不可行,要回退。当remain>0 说明凑的还不够,继续凑。
所以完整方法如下:

publicclass Solution {
    List<List<Integer>> result=newArrayList<List<Integer>>();
    public List<List<Integer>>combinationSum(int[] candidates, int target) {
        Arrays.sort(candidates);//所给数组可能无序,排序保证解按照非递减组合
        List<Integer> list=newArrayList<Integer>();
        backtracking(candidates,target,0,list);//给定target,start=0表示从数组第一个开始
        return result;//返回解的组合链表
    }
    public void backtracking(int[]candidates,int target,int start,List<Integer> list){
       
            if(target<0)    return;//凑过头了
            else if(target==0){
               
                result.add(newArrayList<>(list));//正好凑出答案,开心地加入解的链表
               
            }else{
                for(inti=start;i<candidates.length;i++){//循环试探每个数
                    list.add(candidates[i]);//尝试加入
		   //下一次凑target-candidates[i],允许重复,还是从i开始
                   backtracking(candidates,target-candidates[i],i,list);                   
		   list.remove(list.size()-1);//回退
                }
            }
       
    }
}

其对应的python版本如下:

class Solution:
    def combinationSum(self, candidates, target):
        """
        :type candidates: List[int]
        :type target: int
        :rtype: List[List[int]]
        """
        candidates.sort()
        Solution.res = []
        self.DFS(candidates, target, 0, [])
        
        return Solution.res
    
    def DFS(self, candidates, target, start, temp_res):
        if target == 0:
            Solution.res.append(temp_res[:])
            return
        for i in range(start, len(candidates)):
            if candidates[i] > target:
                return
            self.DFS(candidates, target - candidates[i], i, temp_res + [candidates[i]])
            

这里一定很迷惑,为什么转到了python版本之后就不用后退了呢?(对应list.remove(list.size()-1);
事实上,问题出在了
self.DFS(candidates, target - candidates[i], i, temp_res + [candidates[i]])

由于python解决方案中直接将temp_res + [candidates[i]]放在了递归语句中,则在递归遇到不满足条件跳出时,对应的temp_res也会将之前输入时加上的[candidates[i]]去掉;而上面java程序采用的是先将[candidates[i]]加到了temp_res,再传入递归程序DFS中,故而即便是跳出递归,temp_res并不会去除之前加上的temp_res,需要手动再加上后退程序:temp_res.pop().

故而也可看与java对应的python版本:

class Solution:
    def combinationSum(self, candidates, target):
        """
        :type candidates: List[int]
        :type target: int
        :rtype: List[List[int]]
        """
        candidates.sort()
        Solution.res = []
        self.DFS(candidates, target, 0, [])

        return Solution.res

    def DFS(self, candidates, target, start, temp_res):
        if target == 0:
            Solution.res.append(temp_res[:])
            return
        for i in range(start, len(candidates)):
            if candidates[i] > target:
                return
            temp_res.append(candidates[i])
            self.DFS(candidates, target - candidates[i], i, temp_res)
            temp_res.pop()
            

注:还是推荐采用下面这种方式,因为直接将对temp_res的操作放在递归程序的输入函数中,容易出现一些问题;本题之所以成果是因为temp_res + [candidates[i]]的妥当使用,事实上,如果将其换为temp_res.append(candidates[i]),程序就会出现错误:
在这里插入图片描述

从debug的过程发现,出错原因在使用temp_res + [candidates[i]]会立刻对temp_res进行转换;而采用temp_res.append(candidates[i])temp_res并不会立刻发生变化,而是直到下次达到此语句时候才进行了变化,与希望过程不符。

6.2 LeetCode 22. Generate Parentheses(括号生成)

原题

Given n pairs of parentheses, write a function to generate all combinations of well-formed parentheses.

For example, given n = 3, a solution set is:

题目:
给出 n 代表生成括号的对数,请你写出一个函数,使其能够生成所有可能的并且有效的括号组合。

Example:

[
  "((()))",
  "(()())",
  "(())()",
  "()(())",
  "()()()"
]

Solution:

class Solution:
    def generateParenthesis(self, n):
        """
        :type n: int
        :rtype: List[str]
        """
        Solution.res = []
        n, start, temp_res, len_left, len_right, candidate_list = n, 0, "", 0, 0, ["(", ")"]
        self.backtracking(n, len_left, len_right, temp_res, candidate_list)
        return Solution.res

    def backtracking(self, n, len_left, len_right, temp_res, candidate_list):

        # len_left = len_right = 0

        if len_left == n and len_right == n:
            Solution.res.append(temp_res[:])

        if len_left > n or len_right > n or len_right > len_left:
            return

        if len(temp_res) < 2 * n:
            for i in range(len(candidate_list)):
                temp_res += candidate_list[i]
                if candidate_list[i] == "(":
                    self.backtracking(n, len_left + 1, len_right, temp_res, candidate_list)
                else:
                    self.backtracking(n, len_left, len_right + 1, temp_res, candidate_list)
                temp_res = temp_res[: len(temp_res) - 1]

依旧采用的回溯法,只是针对字符串进行操作需要注意:

  1. 对于字符串进行回溯,一个很好的方法是直接写入回溯表达式(对于列表格式慎用,append不支持这种事实上赋值),如上述for循环表达式,可以替换为:
        if len(temp_res) < 2 * n:
            for i in range(len(candidate_list)):
                # temp_res += candidate_list[i]
                if candidate_list[i] == "(":
                    self.backtracking(n, len_left + 1, len_right, temp_res + candidate_list[i], candidate_list)
                else:
                    self.backtracking(n, len_left, len_right + 1, temp_res + candidate_list[i], candidate_list)
                # temp_res = temp_res[: len(temp_res) - 1]
  1. 对于回溯法字符串删除,可采用 temp_res = temp_res[: len(temp_res) - 1] 曲线救国。

参考文献:

回溯详解及其应用:Leetcode 39 combination sum
[python]回溯法模板
手把手教你中的回溯算法——多一点套路

猜你喜欢

转载自blog.csdn.net/Dby_freedom/article/details/82933845