一、代码
1. MNIST数据集的引变量c
dataset = MnistDataset()#数据集MNIST
latent_spec = [
(Uniform(62), False),#62类默认是false
(Categorical(10), True),#离散码c1
(Uniform(1, fix_std=True), True),#连续码c2
(Uniform(1, fix_std=True), True),#连续码c3
]
隐变量是由列表构成(列表是由一系列按特定顺序排列的元素组成,用[]来表示,并用逗号分隔其中的元素})
2. 互信息的参数定义
互信息计算起始于如下两个变量:
reg_z:表示了模型开始随机生成的隐变量
fake_ref_z_dist_info:表示了经过Encoder计算后的隐变量分布信息
根据连续型和离散型的分类,两个变量分成了以下四个变量:
cont_reg_z:reg_z的连续变量部分
cont_reg_dist_info:fake_ref_z_dist_info的连续变量部分
disc_reg_z:reg_z的离散变量部分