transformers库源码解析/src/transformers/tools/agents.py(二)from_pretrained()方法

def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        self.model = model
        self.tokenizer = tokenizer
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Convenience method to build a `LocalAgent` from a pretrained checkpoint.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
            kwargs:
                Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].

        Example:

        ```py
        import torch
        from transformers import LocalAgent

        agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
        agent.run("Draw me a picture of rivers and lakes.")
        ```
        """
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(model, tokenizer)

    @property
    def _model_device(self):
        if hasattr(self.model, "hf_device_map"):
            return list(self.model.hf_device_map.values())[0]
        for param in self.mode.parameters():
            return param.device

    def generate_one(self, prompt, stop):
        encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
        src_len = encoded_inputs["input_ids"].shape[1]
        stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
        outputs = self.model.generate(
            encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
        )

        result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
        # Inference API returns the stop sequence
        for stop_seq in stop:
            if result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return result


class StopSequenceCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever a sequence of tokens is encountered.

    Args:
        stop_sequences (`str` or `List[str]`):
            The sequence (or list of sequences) on which to stop execution.
        tokenizer:
            The tokenizer used to decode the model outputs.
    """

    def __init__(self, stop_sequences, tokenizer):
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
        return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)

这段代码主要定义了两个类:一个是LocalAgent,另一个是StopSequenceCriteria。下面我会逐行解释代码。

接下来是StopSequenceCriteria的定义。

  1. def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None): 这是LocalAgent类的构造函数,接受五个参数。model和tokenizer是用于生成文本的模型和分词器。chat_prompt_template、run_prompt_template和additional_tools这三个参数是可选的,分别代表聊天提示模板、运行提示模板和附加工具。

  2. self.model = modelself.tokenizer = tokenizer 这两行将输入的model和tokenizer保存为LocalAgent对象的属性。

  3. super().__init__(chat_prompt_template=chat_prompt_template, run_prompt_template=run_prompt_template, additional_tools=additional_tools) 调用父类的构造函数,传入chat_prompt_template、run_prompt_template和additional_tools。

  4. @classmethod 这是一个修饰器,表示下面的方法是类方法

  5. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 这是一个类方法,用于从预训练模型中创建LocalAgent对象。pretrained_model_name_or_path是预训练模型的名称或者路径,**kwargs是其他的关键字参数。

  6. model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) 从预训练模型中加载模型和分词器。

  7. return cls(model, tokenizer) 使用加载的模型和分词器创建LocalAgent对象,并返回。

  8. @property 是一个修饰器,表示下面的方法是一个属性

  9. def _model_device(self): 这个方法返回模型的设备

  10. if hasattr(self.model, "hf_device_map"): 判断模型是否有hf_device_map属性。

  11. return list(self.model.hf_device_map.values())[0] 如果有hf_device_map属性,返回第一个设备。

  12. for param in self.mode.parameters(): 如果没有hf_device_map属性,遍历模型的参数。

  13. return param.device 返回第一个参数的设备。

  14. def generate_one(self, prompt, stop): 这个方法用于根据给定的提示生成一段文本。prompt是提示,stop是停止标志。

  15. encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device) 对提示进行编码,并将编码结果移动到模型的设备上。

  16. src_len = encoded_inputs["input_ids"].shape[1] 获取输入的长度。

  17. stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)]) 创建停止条件,当生成的文本包含停止标志时停止生成。

  18. outputs = self.model.generate(encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria) 生成文本。

  19. result = self.tokenizer.decode(outputs[0].tolist()[src_len:]) 解码生成的文本。

  20. for stop_seq in20. for stop_seq in stop:` 遍历每一个停止序列。

  21. if result.endswith(stop_seq): 检查生成的结果是否以停止序列结尾。

  22. result = result[: -len(stop_seq)] 如果是,则将这个停止序列从结果中去掉。

  23. return result 返回生成的结果。

  24. class StopSequenceCriteria(StoppingCriteria): 定义一个名为StopSequenceCriteria的类,它继承自StoppingCriteria。这个类用于在生成过程中遇到特定序列时停止生成。

  25. def __init__(self, stop_sequences, tokenizer): 这是StopSequenceCriteria的构造函数,接受两个参数:停止序列和分词器。

  26. if isinstance(stop_sequences, str): 如果stop_sequences是字符串,那么将其转化为列表。

  27. stop_sequences = [stop_sequences]

  28. self.stop_sequences = stop_sequencesself.tokenizer = tokenizer 将输入的停止序列和分词器保存为StopSequenceCriteria对象的属性。

  29. def __call__(self, input_ids, scores, **kwargs) -> bool: 定义了该类的调用方法,输入参数为输入的id、得分以及其他关键字参数,返回值是布尔值。

  30. decoded_output = self.tokenizer.decode(input_ids.tolist()[0]) 将输入的id解码为文本。

  31. return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences) 如果解码出的文本以任何一个停止序列结尾,那么返回True,否则返回False。

对最后两个函数有更深入的理解:

  • generate_one LocalAgent 类的一个实例方法。这个方法通过使用实例的 tokenizer 对输入的提示 prompt 进行编码,然后调用模型的 generate 方法生成新的文本。生成的文本通过使用 StopSequenceCriteria 停止条件进行控制,如果生成的文本满足停止条件(即包含某个特定序列),则停止生成新的文本。生成的新文本通过 tokenizer 解码成字符串。如果解码出的结果以任何一个停止序列结尾,那么该停止序列将被去掉。最后,方法返回生成的结果。

  • StopSequenceCriteria 是一个继承自 StoppingCriteria 的子类。它重写了父类的 __call__ 方法。在这个方法中,输入的 id 通过使用实例的 tokenizer 解码为字符串,然后检查解码出的字符串是否以任何一个停止序列结尾,如果是,则返回 True,否则返回 False。这个返回值会被用来决定是否需要停止生成新的文本。

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/131578327