题目:
In a given integer array A, we must move every element of A to either list B or list C. (B and C initially start empty.)
Return true if and only if after such a move, it is possible that the average value of B is equal to the average value of C, and B and C are both non-empty.
Example : Input: [1,2,3,4,5,6,7,8] Output: true Explanation: We can split the array into [1,4,5,8] and [2,3,6,7], and both of them have the average of 4.5.
Note:
- The length of
A
will be in the range [1, 30]. A[i]
will be in the range of[0, 10000]
.
思路:
第一次在Leetcode上面见到NP的题目(即目前还没有找到多项式解法的题目)。所以理论上这个算法的时间复杂度是指数级的。不过好消息是,通过分析,我们可以发现我们只需要列举有限多个情况就可以得到答案。整个过程分为三步:
1)如果一个长度为n的数组可以被划分为A和B两个数组,我们假设A的长度小于B并且A的大小是k,那么:total_sum / n == A_sum / k == B_sum / (n - k),其中1 <= k <= n / 2。那么可以知道:A_sum = total_sum * k / n。由于A_sum一定是个整数,所以我们可以推导出total_sum * k % n == 0,那就是说,对于特定的total_sum和n而言,符合条件的k不会太多。这样我们在第一步中就首先验证是否存在符合条件的k,如果不存在就可以提前返回false。
2)如果经过第一步的验证,发现确实有符合条件的k,那么我们在第二步中,就试图产生k个子元素的所有组合,并且计算他们的和。这里的思路就有点类似于背包问题了,我们的做法是:定义vector<vector<unordered_set<int>>> sums,其中sums[i][j]表示A[0, i]这个子数组中的任意j个元素的所有可能和。可以得到递推公式是:sums[i][j] = sums[i - 1][j] "join" (sums[i][j - 1] + A[i]),其中等式右边的第一项表示这j个元素中不包含A[i],而第二项表示这j个元素包含A[i]。这样就可以采用动态规划的思路得到sums[n - 1][k]了(1 <= k <= n / 2)。
3)有了sums[n - 1][k],我们就检查sums[n - 1][k]中是否包含(total_sum * k / n)。一旦发现符合条件的k,就返回true,否则就返回false。
在递推公式中我们发现,sums[i][j]仅仅和sums[i - 1][j],sums[i][j - 1]有关,所以可以进一步将空间复杂度从O(n^2*M)降低到O(n*M),其中M是n中的所有元素的组合数(可能高达O(2^n))。时间复杂度为O(n^3*M)。
代码:
class Solution { public: bool splitArraySameAverage(vector<int>& A) { int n = A.size(), m = n / 2; int totalSum = accumulate(A.begin(), A.end(), 0); // early pruning bool isPossible = false; for (int i = 1; i <= m; ++i) { if (totalSum * i % n == 0) { isPossible = true; break; } } if (!isPossible) { return false; } // DP like knapsack vector<unordered_set<int>> sums(m + 1); sums[0].insert(0); for (int num: A) { // for each element in A, we try to add it to sums[i] by joining sums[i - 1] for (int i = m; i >= 1; --i) { for (const int t: sums[i - 1]) { sums[i].insert(t + num); } } } for (int i = 1; i <= m; ++i) { if (totalSum * i % n == 0 && sums[i].find(totalSum * i / n) != sums[i].end()) { return true; } } return false; } };