题目地址:https://www.nowcoder.com/questionTerminal/128d8d7d1898406b817fc69baa20602f
链接:https://www.nowcoder.com/questionTerminal/128d8d7d1898406b817fc69baa20602f
来源:牛客网
牛牛和羊羊非常无聊.他们有n + m个共同朋友,他们中有n个是无聊的,m个是不无聊的。每个小时牛牛和羊羊随机选择两个不同的朋友A和B.(如果存在多种可能的pair(A, B),任意一个被选到的概率相同。),然后牛牛会和朋友A进行交谈,羊羊会和朋友B进行交谈。在交谈之后,如果被选择的朋友之前不是无聊会变得无聊。现在你需要计算让所有朋友变得无聊所需要的时间的期望值。
输入描述:
输入包括两个整数n 和 m(1 ≤ n, m ≤ 50)
输出描述:
输出一个实数,表示需要时间的期望值,四舍五入保留一位小数。
示例1
输入
2 1
输出
1.5
【解决】考虑数学期望的求解
m=0,数学期望是0,不用交谈。
m=1,用数学期望的累加,会一直加到无穷,如果m=1还好求解。
因为数学期望推导有无穷个项的数列,对于m>1的情况,无法推导出递推公式。
考虑用动态规划的方法求解。思路如下图。
记dp(n,m)为(n,m)的数学期望,表示通过谈话将(n,m)转化为(n+m,0)的时间期望。
对于(n,m)这个状态,如果再取一次对话,会有三种情况,(n,m),(n+1,m-1),(n+2,m-2)。
这里有个理解的困难点,这里求的是数学期望,很明显能得出dp(n+1,m-1),dp(n+2, m-2)的数学期望是较小的,所以这里的公式是要在dp(n+1,m-1)+1和dp(n+2,m-2)+1。
也可以这么理解,dp(n,m)可以分成这三块
一、dp(n,m) 概率是c1/s,
二、dp(n+1, m-1) 概率是c2/s
三、dp(n+2,m-2) 概率是c3/s
数学期望的计算公式是 各取值的可能性 * 概率 累加求和。这里因为是多交谈了一次,转换而来的,说明dp(n,m)比累加求和要大1。
公式推导如下图所示。
注意,对于m>2的情况,数组dp要初始化,dp[0]相当于是dp[n+m, 0],为0
dp[1]相当于是dp(n+m-1, 1),它的期望是m+n/2.0f
import java.util.Scanner;
/**
* Created by Administrator on 2018/6/23 0023.
*/
public class Main {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
while (scanner.hasNext()) {
int n = scanner.nextInt();
int m = scanner.nextInt();
float[] dp = new float[m + 1];
if (m == 0) {
System.out.println(0);
return;
}
if (m == 1) {
float res = (n + 1) / 2.0f; //和下面的不同,此处已知m=1
String r = String.format("%.1f", res);
System.out.println(r);
return;
}
dp[1] = (n+ m) / 2.0f; //注意此处
for (int i = 2; i <= m; i++) {
dp[i] = p1(m + n-i, i) + p2(m + n -i, i) * (dp[i-1] + 1) + p3(m + n -i, i) * (dp[i-2] + 1);
}
String r = String.format("%.1f", dp[m]);
System.out.println(r);
}
}
private static float p1(int n, int m) {
int c1 = n * (n-1) / 2;
int s = (n + m) * (n + m -1) / 2;
float res = c1 * 1.0f / (s - c1);
return res;
}
private static float p2(int n, int m) {
int c1 = n * (n-1) / 2;
int c2 = n * m;
int s = (n + m) * (n + m -1) / 2;
float res = c2 * 1.0f / (s - c1);
return res;
}
private static float p3(int n, int m) {
int c1 = n * (n-1) / 2;
int c3 = m * (m -1) / 2;
int s = (n + m) * (n + m -1) / 2;
float res = c3 * 1.0f / (s - c1);
return res;
}
}