【问题描述】
给定一个整数数组 nums ,你可以对它进行一些操作。
每次操作中,选择任意一个 nums[i] ,删除它并获得 nums[i] 的点数。之后,你必须删除每个等于 nums[i] - 1 或 nums[i] + 1 的元素。
开始你拥有 0 个点数。返回你能通过这些操作获得的最大点数。
示例 1:
输入: nums = [3, 4, 2]
输出: 6
解释:
删除 4 来获得 4 个点数,因此 3 也被删除。
之后,删除 2 来获得 2 个点数。总共获得 6 个点数。
示例 2:
输入: nums = [2, 2, 3, 3, 3, 4]
输出: 9
解释:
删除 3 来获得 3 个点数,接着要删除两个 2 和 4 。
之后,再次删除 3 获得 3 个点数,再次删除 3 获得 3 个点数。
总共获得 9 个点数。
注意:
nums的长度最大为20000。
每个整数nums[i]的大小都在[1, 10000]范围内。
【思路】
这一题先用dfs做的
int dfs(vector<int>& nums, int k, int num1, int num2)
{
if(k == 0)
{
if(nums[k] != num1 && nums[k] != num2)
return nums[k];
else
return 0;
}
else
{
if(nums[k] != num1 && nums[k] != num2)
return max(dfs(nums, k - 1, nums[k] + 1, nums[k] - 1) + nums[k], dfs(nums, k - 1, num1, num2));
else
return dfs(nums, k - 1, num1, num2);
}
}
我的想法也很直观,觉得这一题就是01背包的变种,思想一样,就是对每一个数字都有删或不删的两种选择,然后取最大值就好了。我的参数k代表当前数在数组中的位数,num1, num2代表其不能删的数字。初始调用dfs(nums, length - 1, 0, 0);
这个思路比较好想,但程序肯定是超时的,接下来就是要用dp,记忆化搜索的方法提速。但是我发现,我的递归参数少说也得3个,因为要确定位置、不能删的两个数这种信息,如果用dp数组记忆化搜索需要dp[20000][10000][10000]这肯定不行,被我pass了。
那么我怎么办呢?没办法,只好上网查了查,发现他们的思路真是太妙了。
用dp[i]表示处理到数字i时能获得的最大点数
先用一个数组cnt把每个数字i出现的次数存到cnt[i]中去,然后对于dp[i] = max(dp[i - 1], dp[i - 2] + cnt[i] * i);
其实核心思想跟我是很像的,01背包嘛,最妙的一点就在于不考虑 i + 1的情况,因为是顺序遍历,我们只需要考虑,删除 i 之后,i - 1这个数字我就不能动了,所以我的dp[i] = dp[i - 2] + cnt[i] * i,这点思想要学会!
我之前想的是按数组的位数遍历,而他们想的是按照数组元素的值从小到大来遍历,确实很妙
AC代码:
class Solution {
public:
/*
int dfs(vector<int>& nums, int k, int num1, int num2)
{
if(k == 0)
{
if(nums[k] != num1 && nums[k] != num2)
return nums[k];
else
return 0;
}
else
{
if(nums[k] != num1 && nums[k] != num2)
return max(dfs(nums, k - 1, nums[k] + 1, nums[k] - 1) + nums[k], dfs(nums, k - 1, num1, num2));
else
return dfs(nums, k - 1, num1, num2);
}
}*/
int dp[10001];
int cnt[10001]; //数字i的个数
int DP(vector<int>& nums, int max_n)
{
dp[0] = 0;
dp[1] = cnt[1];
for(int i = 2;i <= max_n;i++)
{
dp[i] = max(dp[i - 1], cnt[i] * i + dp[i - 2]);
}
return dp[max_n];
}
int deleteAndEarn(vector<int>& nums) {
if(nums.size() == 0)
return 0;
sort(nums.begin(), nums.end());
int length = nums.size();
int max_n = nums[length - 1];
for(int i = 1;i <= max_n;i++)
{
for(int j = 0;j < length;j++)
{
if(nums[j] == i)
cnt[i]++;
}
}
return DP(nums, max_n);
}
};