# Get the column names for input/target.
prompt_column = data_args.prompt_column
response_column = data_args.response_column
history_column = data_args.history_column
# Temporarily set max_target_length for training.
max_target_length = data_args.max_target_length
def preprocess_function_eval(examples):
inputs, targets = [], []
for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]:
query = examples[prompt_column][i]
history = examples[history_column][i] if history_column is not None else None
prompt = tokenizer.build_prompt(query, history)
inputs.append(prompt)
targets.append(examples[response_column][i])
inputs = [prefix + inp for inp in inputs]
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
if data_args.ignore_pad_token_for_loss:
labels["input_ids"] = [
[(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["input_ids"]
return model_inputs
def preprocess_function_train(examples):
max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
model_inputs = {
"input_ids": [],
"labels": [],
}
for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]:
query, answer = examples[prompt_column][i], examples[response_column][i]
history = examples[history_column][i] if history_column is not None else None
prompt = tokenizer.build_prompt(query, history)
prompt = prefix + prompt
a_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,
max_length=data_args.max_source_length)
b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,
max_length=data_args.max_target_length)
context_length = len(a_ids)
input_ids = a_ids + b_ids + [tokenizer.eos_token_id]
labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]
pad_len = max_seq_length - len(input_ids)
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
labels = labels + [tokenizer.pad_token_id] * pad_len
if data_args.ignore_pad_token_for_loss:
labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
model_inputs["input_ids"].append(input_ids)
model_inputs["labels"].append(labels)
return model_inputs
def print_dataset_example(example):
print("input_ids", example["input_ids"])
print("inputs", tokenizer.decode(example["input_ids"]))
print("label_ids", example["labels"])
print("labels", tokenizer.decode(example["labels"]))
if training_args.do_train:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
with training_args.main_process_first(desc="train dataset map pre-processing"):
train_dataset = train_dataset.map(
preprocess_function_train,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on train dataset",
)
print_dataset_example(train_dataset[0])
if training_args.do_eval:
max_target_length = data_args.val_max_target_length
if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
with training_args.main_process_first(desc="validation dataset map pre-processing"):
eval_dataset = eval_dataset.map(
preprocess_function_eval,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on validation dataset",
)
print_dataset_example(eval_dataset[0])
if training_args.do_predict:
max_target_length = data_args.val_max_target_length
if "test" not in raw_datasets:
raise ValueError("--do_predict requires a test dataset")
predict_dataset = raw_datasets["test"]
if data_args.max_predict_samples is not None:
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
with training_args.main_process_first(desc="prediction dataset map pre-processing"):
predict_dataset = predict_dataset.map(
preprocess_function_eval,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on prediction dataset",
)
print_dataset_example(predict_dataset[0])
# Data collator
label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=None,
padding=False
)
# Metric
def compute_metrics(eval_preds):
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
if data_args.ignore_pad_token_for_loss:
# Replace -100 in the labels as we can't decode them.
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
score_dict = {
"rouge-1": [],
"rouge-2": [],
"rouge-l": [],
"bleu-4": []
}
for pred, label in zip(decoded_preds, decoded_labels):
hypothesis = list(jieba.cut(pred))
reference = list(jieba.cut(label))
rouge = Rouge()
scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
result = scores[0]
for k, v in result.items():
score_dict[k].append(round(v["f"] * 100, 4))
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
for k, v in score_dict.items():
score_dict[k] = float(np.mean(v))
return score_dict
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = (
training_args.generation_max_length
if training_args.generation_max_length is not None
else data_args.val_max_target_length
)
training_args.generation_num_beams = (
data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
)
# Initialize our Trainer
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
save_changed=model_args.pre_seq_len is not None
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
# elif last_checkpoint is not None:
# checkpoint = last_checkpoint
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
train_result = trainer.train(resume_from_checkpoint=checkpoint)
# trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluation
results = {}
max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Predict ***")
predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
metrics = predict_results.metrics
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
if trainer.is_world_process_zero():
if training_args.predict_with_generate:
predictions = tokenizer.batch_decode(
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
predictions = [pred.strip() for pred in predictions]
labels = tokenizer.batch_decode(
predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
labels = [label.strip() for label in labels]
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
with open(output_prediction_file, "w", encoding="utf-8") as writer:
for p, l in zip(predictions, labels):
res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
writer.write(f"{res}\n")
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
-
prompt_column, response_column, history_column: 这些变量被定义为用于读取训练数据的列名。prompt_column和response_column分别是提问和回答的列,history_column是聊天记录的列。
-
max_target_length: 这个变量是指预测的最大长度。
-
preprocess_function_eval: 这是一个预处理函数,用于在评估阶段对数据进行处理。它创建了输入和目标列表,然后迭代数据集中的每个示例。对于每个示例,它检查是否有prompt和response,然后使用tokenizer将prompt和history转换为模型可以理解的格式。然后,所有的输入都被添加到一个前缀,并用tokenizer进行编码。最后,对目标进行同样的处理,并将处理后的输入和目标加入到模型输入中。
-
preprocess_function_train: 这是一个预处理函数,用于在训练阶段对数据进行处理。它的处理方式与eval的预处理函数类似,但有一些不同之处,例如它还添加了一个eos(end of sentence) token到输入和标签的末尾,并确保输入和标签的长度都符合最大序列长度。
-
print_dataset_example: 这个函数用于打印数据集中的一个示例。
-
training_args.do_train: 这是一个条件语句,如果训练参数中的do_train设定为True,那么它会执行训练数据的预处理并打印一个训练数据的示例。
-
training_args.do_eval: 这也是一个条件语句,如果训练参数中的do_eval设定为True,那么它会执行验证数据的预处理并打印一个验证数据的示例。
-
training_args.do_predict: 同样是一个条件语句,如果训练参数中的do_predict设定为True,那么它会执行测试数据的预处理并打印一个测试数据的示例。
-
label_pad_token_id, data_collator: 这些变量被定义为处理序列到序列任务的工具。label_pad_token_id是用于填充标签的token的ID,data_collator用于处理批量数据。
-
compute_metrics: 这个函数用于计算评估指标。它首先解码预测和标签,然后计算ROUGE和BLEU评分。
-
trainer: 这个变量是一个Seq2SeqTrainer对象,它用于训练模型。
-
if training_args.do_train: 如果训练参数中的do_train设定为True,那么它会执行训练,并保存训练的指标和状态。
-
if training_args.do_eval: 如果训练参数中的do_eval设定为True,那么它会执行评估,并保存评估的指标。
-
if training_args.do_predict: 如果训练参数中的do_predict设定为True,那么它会执行预测,并保存预测的指标。最后,它还将预测的结果写入文件。
-
def _mp_fn(index): 这个函数是用于TPUs的函数,如果在TPU上运行,则会调用这个函数。
-
if name == "main": 这个判断语句用来判断是否为脚本的运行入口,如果是,则执行main()函数。
整个脚本是用于训练和评估一个聊天机器人模型的代码。脚本先进行数据的预处理,然后定义一个模型训练器并使用它来训练模型。在训练后,脚本使用模型进行评估和预测,并将预测结果保存到文件中。在整个过程中,脚本使用了许多与序列到序列任务相关的工具,如数据整理器和评估指标计算函数。
当然,让我们逐行来解析这些代码。
prompt_column = data_args.prompt_column
:从数据参数中获取提示列名称,也就是用于提问的列。response_column = data_args.response_column
:从数据参数中获取回答列的名称,也就是作为回答或目标的列。history_column = data_args.history_column
:从数据参数中获取历史对话列的名称,如果存在的话,这些历史对话将被用作提问的上下文。
以下是预处理函数,它们用于将输入和目标列进行格式化和分词。格式化的结果将被用于模型的训练和验证。
preprocess_function_eval
和preprocess_function_train
:这两个函数是为评估和训练准备数据的。它们从示例数据中提取问题和回答,并根据需要将其进行格式化和分词。然后它们会将输入和目标添加到model_inputs
列表中,然后返回这个列表。
在接下来的代码中,我们根据是否要进行训练、评估或预测,以及提供的数据集中是否包含所需的部分(训练、验证或测试),来分别处理数据集。
-
if training_args.do_train:
:如果设置了训练标志,那么就需要检查是否提供了训练数据集,然后根据需要进行预处理。然后打印出第一个训练样例。 -
if training_args.do_eval:
:类似地,如果设置了评估标志,那么就需要检查是否提供了评估数据集,并进行预处理。然后打印出第一个评估样例。 -
if training_args.do_predict:
:对于预测,我们需要检查是否提供了测试数据集,并进行预处理。然后打印出第一个测试样例。 -
data_collator = DataCollatorForSeq2Seq(...)
:创建一个数据整理器,用于将预处理后的输入数据组装成可以直接喂入模型的批次。
接下来是评估指标的计算函数。这个函数将模型的预测结果与实际标签进行比较,然后计算并返回指标分数。
compute_metrics
:这个函数接收预测和标签,首先进行解码,然后计算rouge和bleu分数。
接着,我们覆盖一些解码参数,然后初始化训练器,并进行训练、评估和预测。
-
trainer = Seq2SeqTrainer(...)
:初始化一个训练器,它将用于训练、评估和预测。 -
if training_args.do_train:
:如果设置了训练标志,就进行训练,并在训练结束后保存模型和指标。 -
if training_args.do_eval:
:如果设置了评估标志,就进行评估,并记录并保存评估指标。 -
if training_args.do_predict:
:如果设置了预测标志,就进行预测,并记录并保存预测指标。如果预测是使用生成方法完成的,就将预测和标签保存到文件中。
最后,如果此脚本是作为主脚本运行的,就调用main
函数。
if __name__ == "__main__":
:如果此脚本是作为主脚本运行的,就调用main
函数。这是Python的一种常见模式,用于检查脚本是直接运行还是作为模块导入。只有在直接运行脚本时,__name__
的值才会是"__main__"
,因此只有在这种情况下,才会调用main
函数。
for i in range(len(examples[prompt_column])):
if examples[prompt_column][i] and examples[response_column][i]:
query = examples[prompt_column][i]
history = examples[history_column][i] if history_column is not None else None
prompt = tokenizer.build_prompt(query, history)
inputs.append(prompt)
targets.append(examples[response_column][i])