完整项目
run.py是整个项目的入口,它包含两部分,一是使用argparse工具,配置相关参数;二是整个项目的流程框架,各个模块/函数的调用。
目录
1. 参数配置
#声明argparse对象 可附加说明
parser = argparse.ArgumentParser(description='Chinese Text Classification')
#模型是必须设置的参数(required=True) 类型是字符串
parser.add_argument('--model', type=str, required=True, help='choose a model: bert, bert_CNN,bert_DPCNN,bert_RNN,bert_RCNN,ERNIE')
#解析参数
args = parser.parse_args()
2. 项目流程
if __name__ == '__main__':
dataset = 'THUCNews' # 使用的数据集
model_name = args.model #获取选择的模型名字
x = import_module('models.' + model_name)#根据所选模型名字在models包下 获取相应模块(.py)
config = x.Config(dataset)# 每一个模块(.py)中都有一个模型定义类 和与该模型相关的配置类(定义该模型的超参数) 初始化配置类的对象
# 设置随机种子 确保每次运行的条件(模型参数初始化、数据集的切分或打乱等)是一样的
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
torch.backends.cudnn.deterministic = True
start_time = time.time()
print("Loading data...")
#数据预处理
train_data, dev_data, test_data = build_dataset(config)#构建训练集、验证集、测试集
# 构建训练集、验证集、测试集迭代器
train_iter = build_iterator(train_data, config)
dev_iter = build_iterator(dev_data, config)
test_iter = build_iterator(test_data, config)
time_dif = get_time_dif(start_time)
print("Time usage:", time_dif)
# #构建模型对象 并to_device
model = x.Model(config).to(config.device) #传入模型相应配置类的对象 包含该模型的配置信息
train(config, model, train_iter, dev_iter, test_iter)#训练