隐马尔科夫模型(HMM)的无监督学习算法java实现(baum-welch迭代求解),包括串行以及并行实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_37667364/article/details/83718931

HMM的原理就不说了,这里主要说算法的实现。

实际实现起来并不是很困难,前提是你仔细看过hmm的原理,然后很多实现就照着公式写出对应的代码,比如前向算法,后向算法,参数更新都是有明确的公式的,只需要对应写成代码,这里需要提到2点技巧。

1,所有概率需要取对数,这是因为有的概率实在是太小了,容易溢出,或者精度不够。

2.对一个求和的式子取对数概率时需要用到一个技巧。下面直接贴出我写的关于这个计算技巧的理解。

LogSum计算技巧

如果需要计算下面的式子:

其中α是一个对数概率,也就是说只知道 而不知道 ,此时如果直接计算会溢出,为了解决这个问题就可以用到这个logSum计算技巧。传入参数是一个数组,每个元素为 ,找到其中的最大值max=max{ln (αi)}:

,根据:

示例代码:

这是一个java版本实现的无监督HMM,包括学习算法和预测算法,算法没有错误,我已经做过多次测试,但是由于HMM的训练算法就是EM算法,而EM算法对初值十分敏感,所以训练时必须给定一些先验条件,即需要给定HMM中的参数pi,A,B至少其中一个,不然训练出来的参数将时一样的,毫无意义。当然无监督的HMM效果依然不如监督学习的HMM,我测试了一下分词,给定了一个监督学习HMM分词的参数来训练无监督的HMM,效果如下:

无监督HMM效果:给定参数pi和B
参数已收敛....
最终参数:
pi:[0.0, -2.1474836360090876E9, -2.147483633470365E9, -2.1474836334854264E9]
A:
[-2.1474836482889004E9, -2.2141536347013195, -0.115686913295864, -2.147483648337081E9]
[-2.1474836479972153E9, -1.0239098431251064, -0.4450188807874582, -2.1474836480612097E9]
[-0.7149451174254677, -2.1474836483350754E9, -2.147483648333808E9, -0.6718142750423626]
[-0.4322715044602754, -2.1474836481090927E9, -2.1474836481949987E9, -1.0470634684331648]
[原标题, :, 日, 媒拍, 到, 了, 现场, 罕见, 一幕, 据, 日本, 新闻, 网, (, N, NN, ), 9月, 8日, 报道, ,, 日前, ,, 日本, 海上, 自卫队, 现役, 最大, 战舰, 之一, 的, 直升, 机航, 母, “, 加贺, ”, 号在, 南, 海航, 行时, ,, 遭多, 艘, 中国, 海军, 战舰, 抵, 近跟, 踪, 监视, 。]


监督学习HMM:
[原, 标题, :, 日媒, 拍到, 了, 现场, 罕见, 一幕, 据, 日本, 新闻网, (, NN, N)9月8日, 报道, ,, 日前, ,, 日本, 海上, 自卫队, 现役, 最大, 战舰, 之, 一, 的, 直升, 机航母, “, 加贺, ”, 号, 在, 南海, 航行, 时, ,, 遭多, 艘, 中国, 海军, 战舰, 抵近, 跟踪, 监视, 。]

虽然没有指定参数A,但是可以看到学习出来的A还是有准确性,比如B转移到B的概率为0,B转移到S的概率为0,M转移到M的概率为0,M转移到S的概率为0.....这和监督学习的HMM一样的。

这样看来这个算法确实有效。

光从结果来看无监督的HMM指定了pi和B参数,整体效果还是差于监督学习的HMM。测试语料只有人民日报1998的分割语料。

由于原代码比较长,并且本来是写到我的开源项目中的,所以不简单是整合到一个类中就是所有代码,还包含了一些依赖。

这里我整理出了串行版本的只包好2个依赖的代码,供学习使用,由于训练中很多步骤都可以并行实现,所以我并行了一些消耗时间的步骤,要比串行的快得多。过几天我会更新到github上,完整的源码请参考我的开源项目:https://github.com/colin0000007/CONLP

代码中需要用到语料以及HMM的参数A和B:

https://pan.baidu.com/s/1CgEaHEUgn25FKfsGl3kwBg

下面是串行版本的代码,:

package com.outsider.test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
/**
 * 
 * 无监督学习的HMM实现
 * 少量数据建议串行
 * 大量数据,几十万,百万甚至更高的数据强烈建议并行训练,性能是串行的好4倍以上
 * @author outsider
 */
public class UnsupervisedFirstOrderGeneralHMM{
	private double precision = 1e-7;
	/**
	 * 训练数据长度
	 */
	private int sequenceLen;
	public Logger logger = Logger.getLogger(UnsupervisedFirstOrderGeneralHMM.class.getName());
	/**初始状态概率**/
	protected double[] pi;
	/**转移概率**/
	protected double[][] transferProbability1;
	/**发射概率**/
	protected double[][] emissionProbability;
	/**定义无穷大**/
	public static final double INFINITY = (double) -Math.pow(2, 31);
	/**状态值集合的大小**/
	protected int stateNum;
	/**观测值集合的大小**/
	protected int observationNum;
	public UnsupervisedFirstOrderGeneralHMM() {
		super();
	}
	public UnsupervisedFirstOrderGeneralHMM(int stateNum, int observationNum, double[] pi,
			double[][] transferProbability1, double[][] emissionProbability) {
		this.stateNum = stateNum;
		this.observationNum = observationNum;
		this.pi = pi;
		this.transferProbability1 = transferProbability1;
		this.emissionProbability = emissionProbability;
	}
	public UnsupervisedFirstOrderGeneralHMM(int stateNum, int observationNum) {
		this.stateNum = stateNum;
		this.observationNum = observationNum;
		initParameters();
	}
	/**
	 * λ是HMM参数的总称
	 */
	
	/**
	 * 训练方法
	 * @param x 训练序列数据
	 * @param maxIter 最大迭代次数
	 * @param precision 精度
	 */
	public void train(int[] x, int maxIter, double precision) {
		this.sequenceLen = x.length;
		baumWelch(x, maxIter, precision);
	}
	
	public void train(int[] x) {
		this.sequenceLen = x.length;
		//不做概率归一化
	}
	
	/**
	 * baumWelch算法迭代求解
	 * 迭代时存在这样的现象:新参数和上一次的参数差反而会变大,但是到后面这个误差值几乎会收敛
	 * 所以迭代终止的条件有2个:
	 * 1.达到最大迭代次数
	 * 2.参数A,B,pi中的值相比上一次的最大误差小于某个精度值则认为收敛
	 * 3.若1中给的精度值太大,则可能导致无法收敛,所以增加了一个条件,如果当前迭代的误差和上一次迭代的误差小于某个值(这里给定1e-7),
	 * 可以认为收敛了。
	 * @param x 观测序列
	 * @param maxIter 最大迭代次数,如果传入<=0的数则默认为Integer.MAX_VALUE,相当于不收敛就不跳出循环
	 * @param precision 参数误差的精度小于precision就认为收敛
	 */
	protected void baumWelch(int[] x, int maxIter, double precision) {
		int iter = 0;
		double oldMaxError = 0;
		if(maxIter <= 0) {
			maxIter = Integer.MAX_VALUE;
		}
		//初始化各种参数
		double[][] alpha = new double[sequenceLen][stateNum];
		double[][] beta = new double[sequenceLen][stateNum];
		double[][] gamma  = new double[sequenceLen][stateNum];
		double[][][] ksi = new double[sequenceLen][stateNum][stateNum];
		while(iter < maxIter) {
			logger.info("\niter"+iter+"...");
			long start = System.currentTimeMillis();
			//计算各种参数,为更新模型参数做准备,对应EM中的E步
			calcAlpha(x, alpha);
			calcBeta(x, beta);
			calcGamma(x, alpha, beta, gamma);
			calcKsi(x, alpha, beta, ksi);
			//更新参数,对应EM中的M步
			double[][] oldA = generateOldA();
			//double[][] oldB = generateOldB();
			//double[] oldPi = pi.clone();
			updateLambda(x, gamma, ksi);
			//double maxError = calcError(oldA, oldPi, oldB);
			double maxError = calcError(oldA, null, null);
			logger.info("max_error:"+maxError);
			if(maxError < precision || (Math.abs(maxError-oldMaxError)) < this.precision) {
				logger.info("参数已收敛....");
				break;
			}
			oldMaxError = maxError;
			iter++;
			long end = System.currentTimeMillis();
			logger.info("本次迭代结束,耗时:"+(end - start)+"毫秒");
		}
		logger.info("最终参数:");
		logger.info("pi:"+Arrays.toString(pi));
		logger.info("A:");
		for(int i = 0; i < transferProbability1.length; i++) {
			logger.info(Arrays.toString(transferProbability1[i]));
		}
	}

	/**
	 * 保存旧的参数A
	 * @return
	 */
	protected double[][] generateOldA() {
		double[][] oldA = new double[stateNum][stateNum];
		for(int i = 0; i < stateNum; i++) {
			for(int j = 0; j < stateNum; j++) {
				oldA[i][j] = transferProbability1[i][j];
			}
		}
		return oldA;
	}
	/**
	 * 保存旧的参数B
	 * @return
	 */
	protected double[][] generateOldB() {
		double[][] oldB = new double[stateNum][observationNum];
		for(int i = 0; i < stateNum; i++) {
			for(int j = 0; j < observationNum; j++) {
				oldB[i][j] = emissionProbability[i][j];
			}
		}
		return oldB;
	}
	/**
	 * 暂时只计算参数A的误差
	 * 发现计算B和pi会发现参数误差越来越大的现象,基本不能收敛
	 * @param old
	 * @return
	 */
	protected double calcError(double[][] oldA, double[] oldPi, double[][] oldB) {
		double maxError = 0;
		for(int i =0 ; i < stateNum; i++) {
			/*double tmp1 = Math.abs(pi[i] - oldPi[i]);
			maxError = tmp1 > maxError ? tmp1 : maxError;*/
			for(int j =0; j < stateNum; j++) {
				double tmp = Math.abs(oldA[i][j] - transferProbability1[i][j]);
				maxError = tmp > maxError ? tmp : maxError;
			}
			/*for(int k =0; k < observationNum; k++) {
				double tmp2 = Math.abs(emissionProbability[i][k] - oldB[i][k]);
				maxError = tmp2 > maxError ? tmp2 : maxError;
			}*/
		}
		return maxError;
	}
	/**
	 * 概率初始化为0
	 */
	public void initParameters() {
		//初始概率随机初始化
		pi = new double[stateNum];
		transferProbability1 = new double[stateNum][stateNum];
		emissionProbability = new double[stateNum][observationNum];
		//概率初始化为0
		for(int i = 0; i < stateNum; i++) {
			pi[i] = INFINITY;
			for(int j = 0; j < stateNum; j++) {
				transferProbability1[i][j] = INFINITY;
			}
			for(int k = 0; k < observationNum; k++) {
				emissionProbability[i][k] = INFINITY;
			}
		}
	}
	/**
	 * 数组求和
	 * @param arr
	 * @return
	 */
	public static double sum(double[] arr) {
		double sum = 0;
		for(int i = 0; i < arr.length;i++) {
			sum += arr[i];
		}
		return sum;
	}
	/**
	 * 随机初始化参数PI
	 */
	public void randomInitPi() {
		for(int i = 0; i < stateNum; i++) {
			pi[i] = Math.random() * 100;
		}
		//log归一化
		double sum = Math.log(sum(pi));
		for(int i =0; i < stateNum; i++) {
			if(pi[i] == 0) {
				pi[i] = INFINITY;
				continue;
			}
			pi[i] = Math.log(pi[i]) - sum;
		}
	}
	/**
	 * 随机初始化参数A
	 */
	public void randomInitA() {
		for(int i = 0; i < stateNum; i++) {
			for(int j = 0; j < stateNum; j++) {
				transferProbability1[i][j] = Math.random()*100;;
			}
			double sum = Math.log(sum(transferProbability1[i]));
			for(int k = 0; k < stateNum; k++) {
				if(transferProbability1[i][k] == 0) {
					transferProbability1[i][k] = INFINITY;
					continue;
				}
				transferProbability1[i][k]  = Math.log(transferProbability1[i][k]) - sum;
			}
		}
	}
	/**
	 * 随机初始化参数B
	 */
	public void randomInitB() {
		for(int i = 0; i < stateNum; i++) {
			for(int j = 0; j < observationNum; j++) {
				emissionProbability[i][j] = Math.random()*100;;
			}
			double sum = Math.log(sum(emissionProbability[i]));
			for(int k = 0; k < observationNum; k++) {
				if(emissionProbability[i][k] == 0) {
					emissionProbability[i][k] = INFINITY;
					continue;
				}
				emissionProbability[i][k]  = Math.log(emissionProbability[i][k]) - sum;
			}
		}
	}
	
	/**
	 * 随机初始化所有参数
	 */
	public void randomInitAllParameters() {
		randomInitA();
		randomInitB();
		randomInitPi();
	}
	
	/**
	 * 前向算法,根据当前参数λ计算α
	 * α是一个序列长度*状态长度的矩阵
	 * 已检测,应该没有问题
	 */
	protected void calcAlpha(int[] x, double[][] alpha) {
		logger.info("计算alpha...");
		long start = System.currentTimeMillis();
		//double[][] alpha = new double[sequenceLen][stateNum];
		//alpha t=0初始值
		for(int i = 0; i < stateNum; i++) {
			alpha[0][i] = pi[i] + emissionProbability[i][x[0]];
		}
		double[] logProbaArr = new double[stateNum];
		for(int t = 1; t < sequenceLen; t++) {
			for(int i = 0; i < stateNum; i++) {
				for(int j = 0; j < stateNum; j++) {
					logProbaArr[j]	= (alpha[t -1][j] + transferProbability1[j][i]);
				}
				alpha[t][i] = logSum(logProbaArr) + emissionProbability[i][x[t]];
			}
		}
		long end = System.currentTimeMillis();
		logger.info("计算结束...耗时:"+ (end - start) +"毫秒");
		//return alpha;
	}
	/**
	 * 后向算法,根据当前参数λ计算β
	 * 
	 * @param x
	 */
	protected void calcBeta(int[] x, double[][] beta) {
		logger.info("计算beta...");
		long start = System.currentTimeMillis();
		//double[][] beta = new double[sequenceLen][stateNum];
		//初始概率beta[T][i] = 1
		for(int i = 0; i < stateNum; i++) {
			beta[sequenceLen-1][i] = 1;
		}
		double[] logProbaArr = new double[stateNum];
		for(int t = sequenceLen -2; t >= 0; t--) {
			for(int i = 0; i < stateNum; i++) {
				for(int j = 0; j < stateNum; j++) {
					logProbaArr[j] = transferProbability1[i][j] + 
							emissionProbability[j][x[t+1]] +
							beta[t + 1][j];
				}
				beta[t][i] = logSum(logProbaArr);
			}
		}
		long end = System.currentTimeMillis();
		logger.info("计算结束...耗时:"+ (end - start) +"毫秒");
		//return beta;
	}
	
	/**
	 * 根据当前参数λ计算ξ
	 * @param x 观测结点
	 * @param alpha 前向概率
	 * @param beta 后向概率
	 */
	protected void calcKsi(int[] x, double[][] alpha, double[][] beta, double[][][] ksi) {
		logger.info("计算ksi...");
		long start = System.currentTimeMillis();
		//double[][][] ksi = new double[sequenceLen][stateNum][stateNum];
		double[] logProbaArr = new double[stateNum * stateNum];
		for(int t = 0; t < sequenceLen -1; t++) {
			int k = 0;
			for(int i = 0; i < stateNum; i++) {
				for(int j = 0; j < stateNum; j++) {
					ksi[t][i][j] = alpha[t][i] + transferProbability1[i][j] +
							emissionProbability[j][x[t+1]]+beta[t+1][j];
					logProbaArr[k++] = ksi[t][i][j];
				}
			}
			double logSum = logSum(logProbaArr);//分母
			for(int i = 0; i < stateNum; i++) {
				for(int j = 0; j < stateNum; j++) {
					ksi[t][i][j] -= logSum;//分子除分母
				}
			}
		}
		long end = System.currentTimeMillis();
		logger.info("计算结束...耗时:"+ (end - start) +"毫秒");
		//return ksi;
	}
	
	/**
	 * 根据当前参数λ,计算γ
	 * @param x
	 */
	protected void calcGamma(int[] x, double[][] alpha, double[][] beta, double[][] gamma) {
		logger.info("计算gamma...");
		long start = System.currentTimeMillis();
		//double[][] gamma  = new double[sequenceLen][stateNum];
		for(int t = 0; t < sequenceLen; t++) {
			//分母需要求LogSum
			for(int i = 0; i < stateNum; i++) {
				gamma[t][i] = alpha[t][i] + beta[t][i];
			}
			double logSum = logSum(gamma[t]);//分母部分
			for(int j = 0; j < stateNum; j++) {
				gamma[t][j] = gamma[t][j] - logSum;
			}
		}
		long end = System.currentTimeMillis();
		logger.info("计算结束...耗时:"+ (end - start) +"毫秒");
		//return gamma;
	}
	
	/**
	 * 更新参数
	 */
	protected void updateLambda(int[] x ,double[][] gamma, double[][][] ksi) {
		//顺序可以颠倒
		updatePi(gamma);
		updateA(ksi, gamma);
		updateB(x, gamma);
	}
	
	/**
	 * 更新参数pi
	 * @param gamma
	 */
	public void updatePi(double[][] gamma) {
		//更新HMM中的参数pi
		for(int i = 0; i < stateNum; i++) {
			pi[i] = gamma[0][i];
		}
	}
	/**
	 * 更新参数A
	 * @param ksi
	 * @param gamma
	 */
	protected void updateA(double[][][] ksi, double[][] gamma) {
		logger.info("更新参数转移概率A...");
		////由于在更新A都要用到对不同状态的前T-1的gamma值求和,所以这里先算
		double[] gammaSum = new double[stateNum];
		double[] tmp = new double[sequenceLen -1];
		for(int i = 0; i < stateNum; i++) {
			for(int t = 0; t < sequenceLen -1; t++) {
				tmp[t] = gamma[t][i];
			}
			gammaSum[i]  = logSum(tmp);
		}
		long start1 = System.currentTimeMillis();
		//更新HMM中的参数A
		double[] ksiLogProbArr = new double[sequenceLen - 1];
		for(int i = 0; i < stateNum; i++) {
			for(int j = 0; j < stateNum; j++) {
				for(int t = 0; t < sequenceLen -1; t++) {
					ksiLogProbArr[t] = ksi[t][i][j];
				}
				transferProbability1[i][j] = logSum(ksiLogProbArr) - gammaSum[i];
			}
		}
		long end1 = System.currentTimeMillis();
		logger.info("更新完毕...耗时:"+(end1 - start1)+"毫秒");
	}
	/**
	 * 更新参数B
	 * @param x
	 * @param gamma
	 */
	protected void updateB(int[] x, double[][] gamma) {
		//下面需要用到gamma求和为了减少重复计算,这里直接先计算
		//由于在更新B时都要用到对不同状态的所有gamma值求和,所以这里先算
		double[] gammaSum2 = new double[stateNum];
		double[] tmp2 = new double[sequenceLen];
		for(int i = 0; i < stateNum; i++) {
			for(int t = 0; t < sequenceLen; t++) {
				tmp2[t] = gamma[t][i];
			}
			gammaSum2[i]  = logSum(tmp2);
		}
		logger.info("更新状态下分布概率B...");
		long start2 = System.currentTimeMillis();
		ArrayList<Double> valid = new ArrayList<Double>();
		for(int i = 0; i < stateNum; i++) {
			for(int k = 0; k < observationNum; k++) {
				valid.clear();//由于这里没有初始化造成了计算出错的问题
				for(int t = 0; t < sequenceLen; t++) {
					if(x[t] == k) {
						valid.add(gamma[t][i]);
					}
				}
				//B[i][k],i状态下k的分布为概率0,
				if(valid.size() == 0) {
					emissionProbability[i][k] = INFINITY;
					continue;
				}
				//对分子求logSum
				double[] validArr = new double[valid.size()];
				for(int q = 0; q < valid.size(); q++) {
					validArr[q] = valid.get(q);
				}
				double validSum = logSum(validArr);
				//分母的logSum已经在上面做了
				emissionProbability[i][k] = validSum - gammaSum2[i];
			}
		}
		long end2 = System.currentTimeMillis();
		logger.info("更新完毕...耗时:"+(end2 - start2)+"毫秒");
	}
	
	/**
	 * parallelUpdateB并行计算部分
	 * @param x
	 * @param gamma
	 * @param i
	 * @param gammaSum2
	 */
	protected void updateBinSpecificState(int[] x, double[][] gamma, int i, double[] gammaSum2) {
		List<Double> valid = new ArrayList<>();
		for(int k = 0; k < observationNum; k++) {
			valid.clear();//由于这里没有初始化造成了计算出错的问题
			for(int t = 0; t < sequenceLen; t++) {
				if(x[t] == k) {
					valid.add(gamma[t][i]);
				}
			}
			//B[i][k],i状态下k的分布为概率0,
			if(valid.size() == 0) {
				emissionProbability[i][k] = INFINITY;
				continue;
			}
			//对分子求logSum
			double[] validArr = new double[valid.size()];
			for(int q = 0; q < valid.size(); q++) {
				validArr[q] = valid.get(q);
			}
			double validSum = logSum(validArr);
			//分母的logSum已经在上面做了
			emissionProbability[i][k] = validSum - gammaSum2[i];
		}
	}
	
	
	/**
	 * logSum计算技巧
	 * @param tmp
	 * @return
	 */
	public double logSum(double[] logProbaArr) {
		if(logProbaArr.length == 0) {
			return INFINITY;
		}
		double max = max(logProbaArr);
		double result = 0;
		for(int i = 0; i < logProbaArr.length; i++) {
			result += Math.exp(logProbaArr[i] - max);
		}
		return max + Math.log(result);
	}
	/**
	 * 设置先验概率pi
	 * 必须传入取对数后的概率
	 * @param pi
	 */
	public void setPriorPi(double[] pi){
		this.pi = pi;
	}
	/**
	 * 设置先验转移概率A
	 * 必须传入取对数的概率
	 * @param trtransferProbability1
	 */
	public void setPriorTransferProbability1(double[][] trtransferProbability1){
		this.transferProbability1 = trtransferProbability1;
	}
	/**
	 * 设置先验状态下的观测分布概率,B
	 * 必须传入取对数的概率
	 * @param emissionProbability
	 */
	public void setPriorEmissionProbability(double[][] emissionProbability) {
		this.emissionProbability = emissionProbability;
	}
	
	public static double max(double[] arr) {
		double max = arr[0];
		for(int i = 1; i < arr.length;i++) {
			max = arr[i] > max ? arr[i] : max;
		}
		return max;
	}
	
	/**
	 * 维特比解码
	 * @param O 观测序列,输入的是经过编码处理的,而不是原始数据,
	 * 比如,如果序列是字符串,那么输入必须是一系列的字符的编码而不是字符本身
	 * @return 返回预测结果,
	 */
	public int[] verterbi(int[] O) {
		double[][] deltas = new double[O.length][this.stateNum];
		//保存deltas[t][i]的值是由上一个哪个状态产生的
		int[][] states = new int[O.length][this.stateNum];
		//初始化deltas[0][]
		for(int i = 0;i < this.stateNum; i++) {
			deltas[0][i] = pi[i] + emissionProbability[i][O[0]];
		}
		//计算deltas
		for(int t = 1; t < O.length; t++) {
			for(int i = 0; i < this.stateNum; i++) {
				deltas[t][i] = deltas[t-1][0]+transferProbability1[0][i];
				for(int j = 1; j < this.stateNum; j++) {
					double tmp = deltas[t-1][j]+transferProbability1[j][i];
					if (tmp > deltas[t][i]) {
						deltas[t][i] = tmp;
						states[t][i] = j;
					}
				}
				deltas[t][i] += emissionProbability[i][O[t]];
			}
		}
		//回溯找到最优路径
		int[] predict = new int[O.length];
		double max = deltas[O.length-1][0];
		for(int i = 1; i < this.stateNum; i++) {
			if(deltas[O.length-1][i] > max) {
				max = deltas[O.length-1][i];
				predict[O.length-1] = i;				
			}
		}
		for(int i = O.length-2;i >= 0;i-- ) {
			predict[i] = states[i+1][predict[i+1]];
		}
		return predict;
	}
	
	//测试
	public static void main(String[] args) {
		UnsupervisedFirstOrderGeneralHMM hmm = new UnsupervisedFirstOrderGeneralHMM(4, 65536);
		//关闭日志打印
		//CONLPLogger.closeLogger(hmm.logger);
		//由于是监督学习的语料所以这里需要去掉其中的分隔符
		String path = "src/pku_training.splitBy2space.utf8";
		String data = IOUtils.readText(path, "utf-8");
		String[] d2 = data.split("  ");
		StringBuilder sb = new StringBuilder();
		for(String word : d2) {
			sb.append(word);
		}
		data = sb.toString();
		//训练数据
		int[] x = SegmentationUtils.str2int(data);
		//由于串行很慢,可以只取训练数据的前10000个来训练
		int[] minX = new int[10000];
		 System.arraycopy(x, 0, minX, 0, 10000);
		//训练之前设置先验概率,必须设置,EM对初始值敏感,如果不设置默认为都为0,所有参数都将一样,没有意义
		//如果只给了其中一些参数的先验值,可以随机初始化其他参数,例如
		//hmm.randomInitA();
		//hmm.randomInitB();
		//hmm.randomInitPi();
		//hmm.randomInitAllParameters();
		//设置先验信息至少设置参数pi,A,B中的一个
		hmm.setPriorPi(new double[] {-1.138130826175848, -2.632826946498266, -1.138130826175848, -1.2472622308278396});
		hmm.setPriorTransferProbability1((double[][]) IOUtils.readObject("src/A"));
		hmm.setPriorEmissionProbability((double[][]) IOUtils.readObject("src/B"));
		//开始训练
		hmm.train(minX, -1, 0.5);
		String str = "原标题:日媒拍到了现场罕见一幕" + 
				"据日本新闻网(NNN)9月8日报道,日前,日本海上自卫队现役最大战舰之一的直升机航母“加贺”号在南海航行时,遭多艘中国海军战舰抵近跟踪监视。" ; 
		//将词转换为对应的Unicode码
		int[] O = SegmentationUtils.str2int(str);
		int[] predict = hmm.verterbi(O);
		System.out.println(Arrays.toString(predict));
		String[] res = SegmentationUtils.decode(predict, str);
		System.out.println(Arrays.toString(res));
	}
}		

依赖IoUtils:

package com.outsider.test;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;

public class IOUtils {
	public static String readTextWithLineCheckBreak(String path, String encoding) {
		return readText(path, encoding, "\n");
	}
	/**
	 * 读取文本文件,返回整个字符串,不包括换行符号
	 * @param path 文件路径
	 * @param encoding 编码,传入null或者空串使用默认编码
	 * @return
	 */
	public static String readText(String path, String encoding) {
		return readText(path, encoding, null);
	}
	/**
	 * 读取文本,指定每一行末尾符号
	 * @param path
	 * @param encoding
	 * @param lineEndStr
	 * @return
	 */
	public static String readText(String path, String encoding, String lineEndStr) {
		try {
			if(lineEndStr == null) {
				lineEndStr = "";
			}
			BufferedReader reader = null;
			if((!encoding.trim().equals(""))&&encoding!=null) {
				reader = new BufferedReader(new InputStreamReader(new FileInputStream(path),encoding));
			} else {
				reader = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
			}
			String s="";
			StringBuilder sb  = new StringBuilder();
			while((s=reader.readLine())!=null) {
				sb.append(s+lineEndStr);
			}
			reader.close();
			return sb.toString();
		} catch (UnsupportedEncodingException e) {
			e.printStackTrace();
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		}
		return null;
	}
	/**
	 * 读取文本文件,返回整个字符串,不包括换行符号
	 * @param path 文件路径
	 * @param encoding 编码,传入null或者空串使用默认编码
	 * @param addNewLine 是否加换行符
	 * @return
	 */
	public static List<String> readTextAndReturnLinesCheckLineBreak(String path, String encoding, boolean addNewLine) {
		try {
			String lineBreak;
			if(addNewLine) {
				lineBreak = "\n";
			} else {
				lineBreak = "";
			}
			BufferedReader reader = null;
			if((!encoding.trim().equals(""))&&encoding!=null) {
				reader = new BufferedReader(new InputStreamReader(new FileInputStream(path),encoding));
			} else {
				reader = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
			}
			String s="";
			List<String> list = new ArrayList<>();
			while((s=reader.readLine())!=null) {
				list.add(s+lineBreak);
			}
			reader.close();
			return list;
		} catch (UnsupportedEncodingException e) {
			e.printStackTrace();
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		}
		return null;
	}
	
	public static List<String> readTextAndReturnLines(String path, String encoding){
		return readTextAndReturnLinesCheckLineBreak(path, encoding, false);
	}
	/**
	 * 读取文本的每一行
	 * 并且返回数组形式
	 * @param path
	 * @param encoding
	 * @return
	 */
	public static String[] readTextAndReturnLinesOfArray(String path, String encoding){
		List<String> lines = readTextAndReturnLines(path, encoding);
		String[] arr = new String[lines.size()];
		lines.toArray(arr);
		return arr;
	}
	/**
	 * 写入文本文件
	 * @param data
	 * @param path
	 * @param encoding
	 */
	public static void writeTextData2File(String data,String path,String encoding) {
		try {
			BufferedWriter writer = null;
			if((!encoding.trim().equals(""))&&encoding!=null) {
				writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path),encoding));
			} else {
				writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path)));
			}
			writer.write(data);
			writer.close();
		} catch (UnsupportedEncodingException e) {
			e.printStackTrace();
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	
	/**
	 * 把对象写入文件
	 * @param path
	 * @param object
	 */
	public static void writeObject2File(String path, Object object) {
		try {
			ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(path));
			out.writeObject(object);
			out.close();
		} catch (Exception e) {
			e.printStackTrace();
		} 
	}
	/**
	 * 读取对象
	 * @param path
	 * @return
	 */
	public static Object readObject(String path) {
		try {
			ObjectInputStream in = new ObjectInputStream(new FileInputStream(path));
			return in.readObject();
		} catch (Exception e) {
			e.printStackTrace();
		}
		return null; 
	}
	
}

依赖的SegmentationUtils:

package com.outsider.test;

import java.util.ArrayList;
import java.util.List;

public class SegmentationUtils {
	/**
	 * 将字符串数组的每一个字符串中的字符直接转换为Unicode码
	 * @param strs 字符串数组
	 * @return Unicode值
	 */
	public static List<int[]> strs2int(String[] strs) {
		List<int[]> res = new ArrayList<>(strs.length);
		for(int i = 0; i < strs.length;i++) {
			int[] O = new int[strs[i].length()];
			for(int j = 0; j < strs[i].length();j++) {
				O[j] = strs[i].charAt(j);
			}
			res.add(O);
		}
		return res;
	}
	
	public static int[] str2int(String str) {
		return strs2int(new String[] {str}).get(0);
	}
	/**
	 * 根据预测结果解码
	 * BEMS 0123
	 * @param predict 预测结果
	 * @param sentence 句子
	 * @return
	 */
	public static String[] decode(int[] predict, String sentence) {
		List<String> res = new ArrayList<>();
		char[] chars = sentence.toCharArray();
		for(int i = 0; i < predict.length;i++) {
			if(predict[i] == 0 || predict[i] == 1) {
				int a = i;
				while(predict[i] != 2) {
					i++;
					if(i == predict.length) {
						break;
					}
				}
				int b = i;
				if(b == predict.length) {
					b--;
				}
				res.add(new String(chars,a,b-a+1));
			} else {
				res.add(new String(chars,i,1));
			}
		}
		String[] s = new String[res.size()];
		return res.toArray(s);
	}
}

猜你喜欢

转载自blog.csdn.net/qq_37667364/article/details/83718931