分2部分第一部分总的介绍,第二部分自己对代码的理解
目录
扫描二维码关注公众号,回复:
12371008 查看本文章
序列标注任务代码总体流程
- 主要用于实体提取项目,亦可用于其他序列标注项目,基于BiLSTM+CRF实现
实现流程
(1)数据预处理: 将人工标注数据转化为bio标注格式数据
(2)模型构建
(3) 模型训练及参数调节
- batch_size用于设置批次大小,本地电脑一般设置在32已下
- hidden_dim 用于设置隐藏层维度
- lr用于设置学习率,一般开始选择0.001,当训练精度难提高时,加载最优模型,调节学习率为前者1/10,再次训练
- cudnnlstm用于选择是否选用cudnnlstm进行训练
- dropout参数用于加强模型泛化能力,一般设置为0.5,当模型训练完毕,由ckpt转成pb时,需要去掉dropout
(4) 保存模型为ckpt格式
(5)模型部署
采用java语言编写部署代码
以上参考gitlab上面的介绍
序列标注任务代码分块解释(以train实现为例)
代码的总入口在main.py
根据超参数选择程序mode运行,mode分别为test_ckpt,test_pb,train 超参数默认为train
数据预处理
dataPreprocess.py代码主要做了3件事
1.转bio
根据resultId 将mongo数据库的内容导出
转换数据成序列标注格式,保存的数据格式为
标签id,标注结果id,标注集合表
2,获取tag
根据训练数据,转成带有oldtag标签的数据
3.分开存储内容
cutFile 按照num的大小分成f1,f2,f3三个部分存储
模型构建
查看main.py中Apply类下的train
模型训练主函数
导入一份基本包含所以词的文件,并弄成char-id与 id-char的形式备用
接着load_data作用 返回
train_data,dev_data,tag2id,id2tag
train_data,dev_data将训练,测试集返回 单词索引-单词字符索引 -标记索引
tag2id,id2tag 将tag 如(E-proj_name:9)序列化
接着BatchManager作用,返回
向上取整,获得一共多少节
-----------个人认为从这里,终于结束了预训练过程
接着开启tensorflow的会话,
创建模型,重用参数
在model模块中能看到构建图层的全部
self.add_placeholders() 增添placeholders self.lookup_layer_op() 添加特征 “words” self.bilstm_op() 增添双向lstm self.loss_layer_op() loss层 crf的运用 self.optimizer_op() 定义优化器 self.add_summary() 添加摘要
接着batch运行model
接着evaluate运行 返回
f1值
接着model saved
整体完毕