# 测试环境 (test environment):
#     pymc3==3.10.0
# 代码部分 (code):
import pymc3 as pm
from scipy import stats
class variational_method:
    """Estimate the mean of a normal likelihood (known sigma) with ADVI.

    The model fitted by :meth:`fit` is::

        mu ~ Normal(mu_prior_mu, sigma_prior_mu)
        x  ~ Normal(mu, sigma_true)        # observed = data
    """

    def fit(self, data, sigma_true, mu_prior_mu, sigma_prior_mu, samples,
            iter_num, learning_rate=0.01, plot=True):
        """Approximate the posterior of ``mu`` via ADVI and summarize it.

        Args:
            data: Observed samples of ``x``.
            sigma_true: Likelihood standard deviation, treated as known/fixed.
            mu_prior_mu: Mean of the normal prior on ``mu``.
            sigma_prior_mu: Standard deviation of the normal prior on ``mu``.
            samples: Number of draws taken from the fitted approximation.
            iter_num: Number of ADVI optimisation iterations.
            learning_rate: Step size for the Adam optimiser (default 0.01).
            plot: If true, draw a trace plot of the approximate-posterior draws.

        Returns:
            The result of ``pm.summary`` on the draws (with ArviZ-backed
            PyMC3 this is presumably a pandas DataFrame — confirm for the
            pinned version).
        """
        with pm.Model():
            mu = pm.Normal('mu', mu=mu_prior_mu, sigma=sigma_prior_mu)
            pm.Normal('x', mu=mu, sigma=sigma_true, observed=data)
            approx = pm.fit(
                n=iter_num,
                method=pm.ADVI(),
                obj_optimizer=pm.adam(learning_rate=learning_rate),
            )
            trace = approx.sample(draws=samples)
            # Plot/summarize while the model context is still active so the
            # trace -> InferenceData conversion can use the model; outside
            # it ArviZ warns and returns less complete results.
            if plot:
                pm.plot_trace(trace)
            return pm.summary(trace)
# Data-generating assumption: the samples follow a normal distribution.
mu_true = 1.5
sigma_true = 2
data = stats.norm.rvs(mu_true, sigma_true, size=200)
# Prior on mu: normal.
mu_prior_mu = 0.5
sigma_prior_mu = 1
# Initial value intended for the mu update.
# NOTE(review): mu_init is never passed to fit(), so it currently has no
# effect on the optimisation.
mu_init = 0
# Remaining variational-inference settings.
samples = 120  # draws taken from the fitted approximation
iter_num = 50000  # ADVI optimisation iterations
vm = variational_method()
summary = vm.fit(data, sigma_true, mu_prior_mu, sigma_prior_mu, samples, iter_num)
# Print explicitly: a bare expression statement is only echoed in an
# interactive/notebook session, not when this file runs as a script.
print(summary)
# 输出结果 (output):