题目:
给定两个整数 n 和 k,返回 1 ... n 中所有可能的 k 个数的组合。
例:
输入: n = 4, k = 2 输出: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4], ]
来源:
解题思路:回溯
想用暴力搜索解决,但还是无从下手,这时我们采用回溯的办法解决。
回溯配合递归使用,递归前访问,递归后回溯。访问与回溯是一对相反操作。
代码一:
class Solution {
public:
vector< vector<int> > result;
vector<int> path;
vector< vector<int> > combine(int n, int k) {
go(n, k, 1);
return result;
}
void go(int n, int k, int start) {
if (path.size() == k) {
// 保存结果
result.push_back(path);
return;
}
for (int i = start; i <= n; i++) {
path.push_back(i); // 访问
go(n, k, i+1); // 递归
path.pop_back(); // 回溯
}
}
};
代码一暴力搜索全部可能的组合,包括path.size()<k的情况。对于这种情况,需要剪枝以提升代码效率。
例如,若集合[1,2,3,4],k=3,当start指向3时,后面只剩一个数字4了,已经不足了,此时就没必要对这种情况递归调用了。
剩余的个数:n - start
path中已有的个数:path.size()
当前start:1,(start指向的数字还没有进入path,也没有计算在剩余中)
所以当 剩余的个数 + path中已有的个数 + 当前start < k时,就退出循环,这就是剪枝条件:n - i + path.size() + 1 < k,
优化后的代码二:
class Solution {
public:
vector< vector<int> > result;
vector<int> path;
vector< vector<int> > combine(int n, int k) {
go(n, k, 1);
return result;
}
void go(int n, int k, int start) {
if (path.size() == k) {
result.push_back(path);
return;
}
for (int i = start; i <= n; i++) {
if (n - i + path.size() + 1 < k) break;
path.push_back(i);
go(n, k, i+1);
path.pop_back();
}
}
};
回溯+递归,写出来的代码简单易懂。这之前使用暴力写过一次代码,现在读起来自己都读不懂了,贴出来对比一下,见下面代码。
意思是先指定第一个结果,后面每个结果在前一个结果上+1,就像存在这么一个+1运算符:current = prev + 1。注意数字超了要进位。
/**
* Return an array of arrays of size *returnSize.
* The sizes of the arrays are returned as *returnColumnSizes array.
* Note: Both returned array and *columnSizes array must be malloced, assume caller calls free().
*/
int combine_size(int n, int k) {
if (k > n / 2) k = n - k;
int s = 1;
for (int i = 1; i <= k; i++) {
s *= n--;
s /= i;
}
return s;
}
int** combine(int n, int k, int* returnSize, int** returnColumnSizes){
int sz = combine_size(n, k); // 先计算空间大小
*returnSize = sz;
// 申请空间
int **ret = (int**)malloc(sizeof(int*) * sz);
int *sizes = (int*)malloc(sizeof(int) * sz);
for (int i = 0; i < sz; i++) {
int *t = (int*)malloc(sizeof(int) * (k+1));
t[k] = n + 1;
ret[i] = t;
sizes[i] = k;
}
*returnColumnSizes = sizes;
// 初始第一个结果
for (int i = 0; i < k; i++) {
ret[0][i] = i + 1;
}
int p = 1;
while (p < sz) {
// ret[p] = ret[p-1] + 1
int *pre = ret[p-1];
int *cur = ret[p];
// p指向行,而pos指向列,从最后一列算起
// n=6 k=3, 1,2,3,6 -> 1,2,4,6
int pos = k - 1;
while (pos >= 0) {
if (pre[pos] < pre[pos+1] - 1) {
break;
}
pos--;
}
for (int i = 0; i < pos; i++) {
cur[i] = pre[i];
}
cur[pos] = pre[pos] + 1;
for (int i = pos + 1; i < k; i++) {
cur[i] = cur[i-1] + 1;
}
p++;
}
return ret;
}