生日礼物(京东2016实习生真题)
思路
代码
#include <cstdlib>
#include <string>
#include <iostream>
#include <fstream>
#include <vector>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <map>
#include <set>
#include <stdio.h>
#include <numeric>
#include <algorithm>
#include <functional>
#include <stack>
#include <queue>
using namespace std;
typedef long long ll;
const int MOD = 1e9 + 7;
typedef unsigned char uchar;
int dirs[8][2] = { -1, -1, -1, 0, -1, 1, 0, -1, 0, 1, 1, -1, 1, 0, 1, 1 };
#define MAX_DIST 1e9
struct stu
{
stu(){}
int w_ = 0;
int h_ = 0;
int idx_ = 0;
};
// w first
bool cmp_w(const stu& s1, const stu& s2)
{
return s1.w_ < s2.w_;
}
// result中存储着最长子串对应的index(最初输入的index)
int LISS(vector<stu> &vec, vector<int>& result){
int k = 1;//记录最后一个尾元素的位置
int max_len = 0;
// prev:所有节点的前驱节点
vector<int> prev( vec.size(), 0 );
//求最长递增序列
vector<int> dp(vec.size(), 1); // dp矩阵
for (int i = 0; i < vec.size(); ++i)
{
prev[i] = i;
for (int j = 0; j < i; ++j)
{
// 在特定子串中的序列的高度都是严格递增的,同时如果子串相等,
// 则不修改之前的值,因为优先选择之前的值
if (vec[i].h_ > vec[j].h_ && dp[j] + 1 > dp[i])
{
prev[i] = j;
dp[i] = dp[j] + 1;
}
}
if (dp[i] > max_len)
{
k = i;
max_len = dp[i];
}
}
int curr_max = max_len - 1;
while (k != prev[k]) // 当前驱节点不是自身的时候,说明可以继续向前遍历
{
result[curr_max--] = vec[k].idx_;
k = prev[k];
}
result[curr_max] = vec[k].idx_;
return max_len;
}
void func(int n, int w, int h)
{
vector<stu> vec(n);
int cnt = 0;
for (int i = 0; i < n; ++i)
{
int wi, hi;
scanf("%d %d", &wi, &hi);
if (wi <= w || hi <= h)
continue;
vec[cnt].w_ = wi;
vec[cnt].h_ = hi;
vec[cnt++].idx_ = i;
}
if (cnt == 0)
{
printf("0\n");
return;
}
vec.resize(cnt);
// 只保留满足条件的w和h,进行后面的子串识别
std::sort(vec.begin(), vec.end(), cmp_w);
vector<int> result(cnt, 0);
int max_len = LISS(vec, result);
vector<int> output_idx(cnt, 0);
int curr_idx = 0, last_w = 0;;
for (int i = 0; i < max_len; ++i)
{
int j = 0;
for (; j < cnt; ++j)
{
if (vec[j].idx_ == result[i])
break;
}
// 如果和之前的宽度不同,则说明是w和h都是严格递增的
if (vec[j].w_ != last_w)
{
output_idx[curr_idx++] = result[i];
last_w = vec[j].w_;
}
}
printf("%d\n", curr_idx);
for (int i = 0; i < curr_idx; ++i)
{
if (i == 0)
printf("%d", output_idx[i] + 1);
else
printf(" %d", output_idx[i] + 1);
}
printf("\n");
}
int main(int /*argc*/, char** /*argv*/)
{
int n, w, h;
while (scanf("%d %d %d", &n, &w, &h) != EOF)
{
func(n, w, h);
}
system("pause");
return 0;
}