1.EM算法简介
EM算法的详解和样本集实例数学过程讲解,可以详见:https://blog.csdn.net/u012421852/article/details/79915908
2.EM算法的Python实现
# -*- coding: utf-8 -*-
"""
@author: 蔚蓝的天空Tom
Aim: implement the EM (Expectation Maximization) algorithm for the
two-coin problem: every sample is a sequence of tosses (1 = head, 0 = tail)
made with one of two coins, and the head probability of each coin is
estimated without knowing which coin produced which sample.
"""
import numpy as np


class CEM(object):
    """EM estimator for the head probabilities of two coins.

    The estimation runs immediately on construction; read the final
    estimates with ``GetResult()``.

    Args:
        samples: iterable of toss sequences; elements equal to 1 are counted
            as heads and elements equal to 0 as tails.
        pa: initial estimate of P(head) for coin a.
        pb: initial estimate of P(head) for coin b.
        threshold: iteration stops once the absolute change of both
            estimates in one M step is below this value.
        max_iter: upper bound on the number of EM iterations, so that a run
            that never converges (e.g. NaN estimates) still terminates.
    """

    def __init__(self, samples, pa, pb, threshold, max_iter=1000):
        self.samples = samples
        self.pa = pa
        self.pb = pb
        self.eStepRet = None  # per-sample [P(coin a), P(coin b)] from the last E step
        self.mStepRet = None  # kept for compatibility; never assigned by this class
        self.threshold = threshold
        self.max_iter = max_iter
        self.work()

    def likelihood_func(self, samples, p):
        """Return, for each sample, its likelihood under head probability p.

        The likelihood of one sample is p**heads * (1 - p)**tails.
        """
        ret = []
        for sample in samples:
            tosses = list(sample)
            heads = tosses.count(1)
            tails = tosses.count(0)
            ret.append(np.power(p, heads) * np.power(1 - p, tails))
        return ret

    def e_step(self):
        """E step: probability that each sample came from coin a / coin b.

        Stores an array with one row per sample and two columns
        (coin a, coin b) in ``self.eStepRet``; each row sums to 1.
        """
        # Likelihood of every sample under the current pa and pb.
        likelihood_a = self.likelihood_func(self.samples, self.pa)
        likelihood_b = self.likelihood_func(self.samples, self.pb)

        # Normalise each (a, b) pair so that it sums to 1.
        # NOTE: if both likelihoods of a sample are 0 the row becomes NaN;
        # the loop in work() is bounded by max_iter for that reason.
        pairs = np.column_stack((likelihood_a, likelihood_b))
        self.eStepRet = pairs / pairs.sum(axis=1, keepdims=True)
        print('eStepRet:\n', self.eStepRet)
        return

    def m_step(self):
        """M step: re-estimate pa and pb from the E-step responsibilities.

        Returns:
            True when the absolute change of both pa and pb is below
            ``self.threshold`` (i.e. the iteration has converged).
        """
        old_pa, old_pb = self.pa, self.pb
        print('old pa:', old_pa, 'old pb:', old_pb)

        # Expected head / tail counts attributed to each coin.
        h_a, t_a = 0, 0
        h_b, t_b = 0, 0
        for sample, resp in zip(self.samples, self.eStepRet):
            tosses = list(sample)
            heads = tosses.count(1)
            tails = tosses.count(0)
            h_a += heads * resp[0]
            t_a += tails * resp[0]
            h_b += heads * resp[1]
            t_b += tails * resp[1]

        self.pa = h_a / (h_a + t_a)
        self.pb = h_b / (h_b + t_b)
        print('new pa:', self.pa, 'new pb:', self.pb)

        gap_pa, gap_pb = self.pa - old_pa, self.pb - old_pb
        print('gap_pa:', gap_pa, 'gap_pb:', gap_pb)

        # Compare magnitudes: a large *negative* change is not convergence.
        return abs(gap_pa) < self.threshold and abs(gap_pb) < self.threshold

    def work(self):
        """Alternate E and M steps until convergence or max_iter iterations."""
        # Iterative rather than recursive so long runs cannot hit the
        # interpreter's recursion limit.
        for _ in range(self.max_iter):
            self.e_step()
            if self.m_step():
                print('stop em\n')
                return
        print('stop em: reached max_iter without converging\n')
        return

    def GetResult(self):
        """Return the current estimates as a tuple (pa, pb)."""
        return self.pa, self.pb


def CEM_manual():
    """Run the EM estimator on a small hand-written data set and print (pa, pb)."""
    # Five rounds of ten tosses each. The coin labels are the ones given in
    # the original post; the algorithm itself never sees them.
    samples = np.array([[1, 0, 1, 1, 1, 0, 1, 0, 1, 0],   # coin a, 6 heads 4 tails
                        [1, 0, 1, 1, 1, 0, 1, 0, 1, 1],   # coin b, 7 heads 3 tails
                        [1, 1, 1, 1, 1, 0, 1, 0, 1, 0],   # coin a, 7 heads 3 tails
                        [1, 0, 1, 1, 1, 1, 1, 0, 1, 0],   # coin b, 7 heads 3 tails
                        [1, 0, 1, 1, 1, 0, 0, 1, 0, 0]])  # coin a, 5 heads 5 tails
    # With those labels the complete-data estimates would be
    #   p(1|a) = (6 + 7 + 5) / 30 = 0.6
    #   p(1|b) = (7 + 7) / 20 = 0.7
    # Overall head frequency is 32 / 50 = 0.64.

    # Initial guesses: p(1|a) = 0.4, p(1|b) = 0.5
    pa, pb = 0.4, 0.5
    threshold = 0.00001
    em = CEM(samples, pa, pb, threshold)
    ret = em.GetResult()
    print(ret)
    return


if __name__ == '__main__':
    CEM_manual()
3.运行结果
runfile('C:/Users/l13277/EM.py', wdir='C:/Users/l13277') eStepRet: [[ 0.35215613 0.64784387] [ 0.26599464 0.73400536] [ 0.26599464 0.73400536] [ 0.26599464 0.73400536] [ 0.44914893 0.55085107]] old pa: 0.4 old pb: 0.5 new pa: 0.621811879589 new pb: 0.648553523107 gap_pa: 0.221811879589 gap_pb: 0.148553523107 eStepRet: [[ 0.51017251 0.48982749] [ 0.4813223 0.5186777 ] [ 0.4813223 0.5186777 ] [ 0.4813223 0.5186777 ] [ 0.53895512 0.46104488]] old pa: 0.621811879589 old pb: 0.648553523107 new pa: 0.63630074059 new pb: 0.643678879642 gap_pa: 0.0144888610009 gap_pb: -0.00487464346478 eStepRet: [[ 0.50320194 0.49679806] [ 0.49519623 0.50480377] [ 0.49519623 0.50480377] [ 0.49519623 0.50480377] [ 0.51120602 0.48879398]] old pa: 0.63630074059 old pb: 0.643678879642 new pa: 0.638975359185 new pb: 0.64102463807 gap_pa: 0.00267461859504 gap_pb: -0.00265424157212 eStepRet: [[ 0.50088945 0.49911055] [ 0.49866584 0.50133416] [ 0.49866584 0.50133416] [ 0.49866584 0.50133416] [ 0.50311303 0.49688697]] old pa: 0.638975359185 old pb: 0.64102463807 new pa: 0.639715379851 new pb: 0.640284620153 gap_pa: 0.000740020666399 gap_pb: -0.00074001791737 eStepRet: [[ 0.50024707 0.49975293] [ 0.4996294 0.5003706 ] [ 0.4996294 0.5003706 ] [ 0.4996294 0.5003706 ] [ 0.50086473 0.49913527]] old pa: 0.639715379851 old pb: 0.640284620153 new pa: 0.639920938888 new pb: 0.640079061112 gap_pa: 0.000205559037088 gap_pb: -0.000205559040588 eStepRet: [[ 0.50006863 0.49993137] [ 0.49989706 0.50010294] [ 0.49989706 0.50010294] [ 0.49989706 0.50010294] [ 0.5002402 0.4997598 ]] old pa: 0.639920938888 old pb: 0.640079061112 new pa: 0.639978038581 new pb: 0.640021961419 gap_pa: 5.7099692811e-05 gap_pb: -5.70996928321e-05 eStepRet: [[ 0.50001906 0.49998094] [ 0.4999714 0.5000286 ] [ 0.4999714 0.5000286 ] [ 0.4999714 0.5000286 ] [ 0.50006672 0.49993328]] old pa: 0.639978038581 old pb: 0.640021961419 new pa: 0.639993899606 new pb: 0.640006100394 gap_pa: 1.58610249222e-05 gap_pb: -1.58610249225e-05 eStepRet: [[ 
0.5000053 0.4999947 ] [ 0.49999206 0.50000794] [ 0.49999206 0.50000794] [ 0.49999206 0.50000794] [ 0.50001853 0.49998147]] old pa: 0.639993899606 old pb: 0.640006100394 new pa: 0.639998305446 new pb: 0.640001694554 gap_pa: 4.40584023775e-06 gap_pb: -4.40584023775e-06 stop em (0.63999830544606295, 0.64000169455393696)
(end)