hdu 4734 F(x) (数位DP,memset优化,变加为减)

原题地址:http://acm.hdu.edu.cn/showproblem.php?pid=4734

先贴两个讲的比较好的博客:http://www.cnblogs.com/dilthey/p/8545485.html
https://blog.csdn.net/wust_zzwh/article/details/52100392#t2

说一下这题的坑点。
由于时限只给了500ms,然后这题一开始不经过转化的话是不能够使用memset优化的(即memset只初始化一次),因为你dp里存放的结果和a的取值是有关系的,你如果取的数据使得f(a)的值小,那么你dp里存的结果会大一些,反之就会小一些。

所以解这题的方法就需要将加法变为减法,使得转化后能够使用memset优化,不再在memset上浪费时间。

贴一个大佬对转化的解释
这里写图片描述

还有解释

题目给了个f(x)的定义:F(x) = An * 2n-1 + An-1 * 2n-2 + … + A2 * 2 + A1 * 1,Ai是十进制数位,然后给出a,b求区间[0,b]内满足f(i)<=f(a)的i的个数。
常规想:这个f(x)计算就和数位计算是一样的,就是加了权值,所以dp[pos][sum],这状态是基本的。a是题目给定的,f(a)是变化的不过f(a)最大好像是4600的样子。如果要memset优化就要加一维存f(a)的不同取值,那就是dp[10][4600][4600],这显然不合法。
这个时候就要用减法了,dp[pos][sum],sum不是存当前枚举的数的前缀和(加权的),而是枚举到当前pos位,后面还需要凑sum的权值和的个数,
也就是说初始的是时候sum是f(a),枚举一位就减去这一位在计算f(i)的权值,那么最后枚举完所有位 sum>=0时就是满足的,后面的位数凑足sum位就可以了。
仔细想想这个状态是与f(a)无关的(新手似乎很难理解),一个状态只有在sum>=0时才满足,如果我们按常规的思想求f(i)的话,那么最后sum>=f(a)才是满足的条件。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6 + 4;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
int n, m, cnt, t;
int dp[10][4600];//dp[i][j],当前在第i位,还距离sum权值的数量
int a[100];
int p[20];
int dfs(int pos,  int sum, int limit) {
    if(pos == -1) {
        if(sum >= 0) return 1;
        else return 0;
    }
    if(sum < 0) return 0;
    if(!limit && dp[pos][sum] != -1) return dp[pos][sum]; //因为这是从高位一位一位记忆化搜索,所以会有重复的
    int up = limit ? a[pos] : 9;
    int ans = 0;
    for(int i = 0; i <= up; i++) {
        ans += dfs(pos - 1,  sum - p[pos] * i, limit && i == up);
    }
    if(!limit) dp[pos][sum] = ans;
    return ans;
}

int  init(int x) {
    cnt = 0;
    while(x) {
        a[cnt++] = x % 10;
        x /= 10;
    }
    int ans = 0;
    for(int i = 0; i < cnt; i++) {
        ans += p[i] * a[i];
    }
    return ans;
}
int solve(int x, int y) {
    p[0] = 1;
    for(int i = 1; i <= 10; i++) p[i] = p[i - 1] * 2;
    int fa = init(x);
    cnt = 0;
    while(y) {
        a[cnt++] = y % 10;
        y /= 10;
    }
    int ans = dfs(cnt - 1, fa, true);
    return ans;
}

int main() {
    scanf("%d", &t);
    memset(dp, -1, sizeof(dp));
    for(int i = 1; i <= t; i++) {
        scanf("%d%d", &n, &m);
        printf("Case #%d: ", i);
        printf("%d\n", solve(n, m));
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/yiqzq/article/details/81136427