def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
self.model = model
self.tokenizer = tokenizer
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Convenience method to build a `LocalAgent` from a pretrained checkpoint.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
kwargs:
Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].
Example:
```py
import torch
from transformers import LocalAgent
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
agent.run("Draw me a picture of rivers and lakes.")
```
"""
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(model, tokenizer)
@property
def _model_device(self):
if hasattr(self.model, "hf_device_map"):
return list(self.model.hf_device_map.values())[0]
for param in self.mode.parameters():
return param.device
def generate_one(self, prompt, stop):
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
src_len = encoded_inputs["input_ids"].shape[1]
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
outputs = self.model.generate(
encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
)
result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq):
result = result[: -len(stop_seq)]
return result
class StopSequenceCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever a sequence of tokens is encountered.
Args:
stop_sequences (`str` or `List[str]`):
The sequence (or list of sequences) on which to stop execution.
tokenizer:
The tokenizer used to decode the model outputs.
"""
def __init__(self, stop_sequences, tokenizer):
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
self.stop_sequences = stop_sequences
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
这段代码主要定义了两个类:一个是LocalAgent,另一个是StopSequenceCriteria。下面我会逐行解释代码。
接下来是StopSequenceCriteria
类的定义。
-
def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
这是LocalAgent类的构造函数,接受五个参数。model和tokenizer是用于生成文本的模型和分词器。chat_prompt_template、run_prompt_template和additional_tools这三个参数是可选的,分别代表聊天提示模板、运行提示模板和附加工具。 -
self.model = model
和self.tokenizer = tokenizer
这两行将输入的model和tokenizer保存为LocalAgent对象的属性。 -
super().__init__(chat_prompt_template=chat_prompt_template, run_prompt_template=run_prompt_template, additional_tools=additional_tools)
调用父类的构造函数,传入chat_prompt_template、run_prompt_template和additional_tools。 -
@classmethod
这是一个修饰器,表示下面的方法是类方法。 -
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
这是一个类方法,用于从预训练模型中创建LocalAgent对象。pretrained_model_name_or_path是预训练模型的名称或者路径,**kwargs是其他的关键字参数。 -
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
和tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
从预训练模型中加载模型和分词器。 -
return cls(model, tokenizer)
使用加载的模型和分词器创建LocalAgent对象,并返回。 -
@property
是一个修饰器,表示下面的方法是一个属性。 -
def _model_device(self):
这个方法返回模型的设备。 -
if hasattr(self.model, "hf_device_map"):
判断模型是否有hf_device_map属性。 -
return list(self.model.hf_device_map.values())[0]
如果有hf_device_map属性,返回第一个设备。 -
for param in self.mode.parameters():
如果没有hf_device_map属性,遍历模型的参数。 -
return param.device
返回第一个参数的设备。 -
def generate_one(self, prompt, stop):
这个方法用于根据给定的提示生成一段文本。prompt是提示,stop是停止标志。 -
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
对提示进行编码,并将编码结果移动到模型的设备上。 -
src_len = encoded_inputs["input_ids"].shape[1]
获取输入的长度。 -
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
创建停止条件,当生成的文本包含停止标志时停止生成。 -
outputs = self.model.generate(encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria)
生成文本。 -
result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
解码生成的文本。 -
for stop_seq in20.
for stop_seq in stop:` 遍历每一个停止序列。 -
if result.endswith(stop_seq):
检查生成的结果是否以停止序列结尾。 -
result = result[: -len(stop_seq)]
如果是,则将这个停止序列从结果中去掉。 -
return result
返回生成的结果。 -
class StopSequenceCriteria(StoppingCriteria):
定义一个名为StopSequenceCriteria的类,它继承自StoppingCriteria。这个类用于在生成过程中遇到特定序列时停止生成。 -
def __init__(self, stop_sequences, tokenizer):
这是StopSequenceCriteria的构造函数,接受两个参数:停止序列和分词器。 -
if isinstance(stop_sequences, str):
如果stop_sequences是字符串,那么将其转化为列表。 -
stop_sequences = [stop_sequences]
-
self.stop_sequences = stop_sequences
和self.tokenizer = tokenizer
将输入的停止序列和分词器保存为StopSequenceCriteria对象的属性。 -
def __call__(self, input_ids, scores, **kwargs) -> bool:
定义了该类的调用方法,输入参数为输入的id、得分以及其他关键字参数,返回值是布尔值。 -
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
将输入的id解码为文本。 -
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
如果解码出的文本以任何一个停止序列结尾,那么返回True,否则返回False。
对最后两个函数有更深入的理解:
-
generate_one
是LocalAgent
类的一个实例方法。这个方法通过使用实例的tokenizer
对输入的提示prompt
进行编码,然后调用模型的generate
方法生成新的文本。生成的文本通过使用StopSequenceCriteria
停止条件进行控制,如果生成的文本满足停止条件(即包含某个特定序列),则停止生成新的文本。生成的新文本通过tokenizer
解码成字符串。如果解码出的结果以任何一个停止序列结尾,那么该停止序列将被去掉。最后,方法返回生成的结果。 -
StopSequenceCriteria
是一个继承自StoppingCriteria
的子类。它重写了父类的__call__
方法。在这个方法中,输入的 id 通过使用实例的tokenizer
解码为字符串,然后检查解码出的字符串是否以任何一个停止序列结尾,如果是,则返回True
,否则返回False
。这个返回值会被用来决定是否需要停止生成新的文本。