回溯

概念

回溯通常有一个增量构造答案的过程,这个过程一般由递归实现。比如:

  1. 原问题:构造一个长度为n的字符串。
  2. 子问题:在枚举一个字符串后,就变成构造一个长度为n-1的字符串了。

回溯三问

  1. 当前的操作?
  2. 子问题?
  3. 下一个子问题?
    dfs(i) -> dfs(i+1)

子集型例题

MAPPING = "", "", "abc", "def", "ghi", "jkl", "mno", "pqrs", "tuv", "wxyz"
class Solution:
    def letterCombinations(self, digits: str) -> List[str]:
        n=len(digits)
        if n==0:
            return []
        ans=[]
        path=['']*n
        def dfs(i):
            if i==n:
                ans.append(''.join(path))
                return
            for c in MAPPING[int(digits[i])]:
                path[i]=c
                dfs(i+1)
        dfs(0)
        return ans

# 1.输入的角度,每个数选或不选
class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        ans = []
        path = []
        n = len(nums)
        def dfs(i: int) -> None:
            if i == n:
                ans.append(path.copy())  # 固定答案
                return
            # 不选 nums[i]
            dfs(i + 1)
            # 选 nums[i]
            path.append(nums[i])
            dfs(i + 1)
            path.pop()  # 恢复现场
        dfs(0)
        return ans

# 2.答案的角度,选的哪几个数
class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        ans = []
        path = []
        n = len(nums)
        def dfs(i: int) -> None:
            ans.append(path.copy())  # 固定答案
            if i == n:
                return
            for j in range(i, n):  # 枚举选择的数字
                path.append(nums[j])
                dfs(j + 1)
                path.pop()  # 恢复现场
        dfs(0)
        return ans

# 1.答案的角度 枚举字符串结束位置
class Solution:
    def partition(self, s: str) -> List[List[str]]:
        ans = []
        path = []
        n = len(s)
        def dfs(i: int) -> None:
            if i == n:
                ans.append(path.copy())  # 固定答案
                return
            for j in range(i, n):  # 枚举子串的结束位置
                t = s[i: j + 1]
                if t == t[::-1]:  # 判断是否回文
                    path.append(t)
                    dfs(j + 1)
                    path.pop()  # 恢复现场
        dfs(0)
        return ans

# 2.输入的角度,假设俩字符间有逗号,那么就看每个逗号选还是不选
class Solution:
    def partition(self, s: str) -> List[List[str]]:
        ans = []
        path = []
        n = len(s)

        # start 表示当前这段回文子串的开始位置
        def dfs(i: int, start: int) -> None:
            if i == n:
                ans.append(path.copy())  # 固定答案
                return

            # 不选 i 和 i+1 之间的逗号(i=n-1 时右边没有逗号)
            if i < n - 1:
                dfs(i + 1, start)

            # 选 i 和 i+1 之间的逗号
            t = s[start: i + 1]
            if t == t[::-1]:  # 判断是否回文
                path.append(t)
                dfs(i + 1, i + 1)
                path.pop()  # 恢复现场

        dfs(0, 0)
        return ans

组合型例题

class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        ans=[]
        path=[]
        def dfs(i,start):
            if i==k:
                ans.append(path.copy())
            for j in range(start,n+1):
                # 剪枝
                if j>n-k+len(path)+1:
                    return 
                path.append(j)
                dfs(i+1,j+1)
                path.pop()
        dfs(0,1)
        return ans

class Solution:
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        ans=[]
        path=[]
        def dfs(i,start):
            if i==k and sum(path)==n:
                ans.append(path.copy())
            for j in range(start,10):
                if sum(path)>n:
                    return
                path.append(j)
                dfs(i+1,j+1)
                path.pop()
        dfs(0,1)
        return ans

class Solution:
    def generateParenthesis(self, n: int) -> List[str]:
        m = n * 2
        ans = []
        path = [''] * m
        def dfs(i: int, open: int) -> None:
            if i == m:
                ans.append(''.join(path))
                return
            if open < n:  # 可以填左括号
                path[i] = '('
                dfs(i + 1, open + 1)
            if i - open < open:  # 可以填右括号
                path[i] = ')'
                dfs(i + 1, open)
        dfs(0, 0)
        return ans

排列型例题

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        ans=[]
        path=[]
        n=len(nums)
        def dfs(i):
            if i==n:
                ans.append(path.copy())
            for j in range(n):
                if nums[j] not in path:
                    path.append(nums[j])
                    dfs(i+1)
                    path.pop()
        dfs(0)
        return ans

# 自己写的版本
class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        ans=[]
        path=[]*n
        a=Counter()
        b=Counter()
        def dfs(i):
            if i==n:
                a.clear()
                b.clear()
                #判断是否存在斜对角,斜对角行列和或行列差为定值
                for x,value in enumerate(path):
                    if a[x-value]==0:
                        a[x-value]+=1
                    else:
                        return
                    if b[x+value]==0:
                        b[x+value]+=1
                    else:
                        return
                templ=[""]*n
                for k in range(n):
                    temp=path[k]
                    templ[k]="."*n
                    l=list(templ[k])
                    l[temp]="Q"
                    templ[k]="".join(l)
                ans.append(templ.copy())
            # 全排列保证不在同一行同一列
            for j in range(n):
                if j not in path:
                    path.append(j)
                    dfs(i+1)
                    path.pop()
        dfs(0)
        return ans
# 推荐写法
class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        ans=[]
        col=[0]*n
        def dfs(r,s):
            if r==n:
                ans.append(['.'*c+'Q'+'.'*(n-1-c) for c in col])
                return
            for c in s:
                if all(r+c!=R+col[R] and r-c!=R-col[R] for R in range(r)):
                    col[r]=c
                    dfs(r+1,s-{c})
        dfs(0,set(range(n)))
        return ans