Lightweight fine-tuning and inference with stanford_alpaca

The current Alpaca model is fine-tuned from a 7B LLaMA model on 52K instruction-following examples generated with the techniques from the Self-Instruct paper, with some modifications.

Environment: A10 GPU (22 GB VRAM), cu117, driver 470.103.01. Packages in the environment (pip list):

absl-py                  1.4.0
accelerate               0.18.0
addict                   2.4.0
aenum                    3.1.12
aiofiles                 23.1.0
aiohttp                  3.8.4
aiosignal                1.3.1
albumentations           0.4.3
altair                   4.2.2
antlr4-python3-runtime   4.9.3
anyio                    3.6.2
appdirs                  1.4.4
asttokens                2.2.1
async-timeout            4.0.2
attrs                    22.2.0
backcall                 0.2.0
basicsr                  1.4.2
bcrypt                   4.0.1
beautifulsoup4           4.12.1
blendmodes               2022
blinker                  1.6
boltons                  23.0.0
braceexpand              0.1.7
cachetools               5.3.0
certifi                  2022.12.7
cffi                     1.15.1
chardet                  4.0.0
charset-normalizer       3.1.0
clean-fid                0.1.29
click                    8.1.3
clip-anytorch            2.5.2
cmake                    3.26.1
comm                     0.1.3
contourpy                1.0.7
cryptography             40.0.1
cssselect2               0.7.0
cycler                   0.11.0
datasets                 2.11.0
debugpy                  1.6.7
decorator                5.1.1
deprecation              2.1.0
diffusers                0.15.0.dev0
dill                     0.3.6
docker-pycreds           0.4.0
einops                   0.4.1
entrypoints              0.4
executing                1.2.0
facexlib                 0.2.5
fastapi                  0.94.0
ffmpy                    0.3.0
filelock                 3.10.7
filterpy                 1.4.5
fire                     0.5.0
font-roboto              0.0.1
fonts                    0.0.3
fonttools                4.39.3
frozenlist               1.3.3
fsspec                   2023.3.0
ftfy                     6.1.1
future                   0.18.3
gdown                    4.7.1
gfpgan                   1.3.8
gitdb                    4.0.10
GitPython                3.1.30
google-auth              2.17.2
google-auth-oauthlib     1.0.0
gradio                   3.16.2
grpcio                   1.53.0
h11                      0.12.0
httpcore                 0.15.0
httpx                    0.23.3
huggingface-hub          0.15.1
idna                     2.10
imageio                  2.9.0
imageio-ffmpeg           0.4.2
imgaug                   0.2.6
importlib-metadata       6.1.0
inflection               0.5.1
ipykernel                6.23.1
ipython                  8.13.2
jedi                     0.18.2
Jinja2                   3.1.2
joblib                   1.2.0
jsonmerge                1.8.0
jsonschema               4.17.3
jupyter_client           8.2.0
jupyter_core             5.3.0
kiwisolver               1.4.4
kornia                   0.6.7
lark                     1.1.2
lazy_loader              0.2
linkify-it-py            2.0.0
lit                      16.0.0
llvmlite                 0.39.1
lmdb                     1.4.0
lpips                    0.1.4
lxml                     4.9.2
Markdown                 3.4.3
markdown-it-py           2.2.0
MarkupSafe               2.1.2
matplotlib               3.7.1
matplotlib-inline        0.1.6
mdit-py-plugins          0.3.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.4
multiprocess             0.70.14
mypy-extensions          1.0.0
nest-asyncio             1.5.6
networkx                 3.1rc0
nltk                     3.8.1
numba                    0.56.4
numexpr                  2.8.4
numpy                    1.23.3
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
oauthlib                 3.2.2
omegaconf                2.2.3
open-clip-torch          2.7.0
openai                   0.27.7
opencv-python            4.7.0.72
opencv-python-headless   4.7.0.72
orjson                   3.8.9
packaging                23.0
pandas                   1.5.3
paramiko                 3.1.0
parso                    0.8.3
pathtools                0.1.2
pexpect                  4.8.0
pickleshare              0.7.5
piexif                   1.1.3
Pillow                   9.4.0
pip                      23.0.1
platformdirs             3.5.1
prompt-toolkit           3.0.38
protobuf                 3.20.3
psutil                   5.9.4
ptyprocess               0.7.0
pudb                     2019.2
pure-eval                0.2.2
pyarrow                  11.0.0
pyasn1                   0.4.8
pyasn1-modules           0.2.8
pycparser                2.21
pycryptodome             3.17
pydantic                 1.10.7
pydeck                   0.8.0
pyDeprecate              0.3.1
pydub                    0.25.1
Pygments                 2.14.0
Pympler                  1.0.1
PyNaCl                   1.5.0
pyparsing                3.0.9
pyre-extensions          0.0.23
pyrsistent               0.19.3
PySocks                  1.7.1
python-dateutil          2.8.2
python-multipart         0.0.6
pytorch-lightning        1.7.6
pytz                     2023.3
pytz-deprecation-shim    0.1.0.post0
PyWavelets               1.4.1
PyYAML                   6.0
pyzmq                    25.1.0
realesrgan               0.3.0
regex                    2023.3.23
reportlab                3.6.12
requests                 2.25.1
requests-oauthlib        1.3.1
resize-right             0.0.2
responses                0.18.0
rfc3986                  1.5.0
rich                     13.3.3
rouge-score              0.1.2
rsa                      4.9
safetensors              0.2.7
scikit-image             0.19.2
scipy                    1.10.1
semver                   3.0.0
sentencepiece            0.1.99
sentry-sdk               1.19.0
setproctitle             1.3.2
setuptools               59.6.0
six                      1.16.0
smmap                    5.0.0
sniffio                  1.3.0
soupsieve                2.4
stack-data               0.6.2
starlette                0.26.1
streamlit                1.20.0
svglib                   1.5.1
sympy                    1.12rc1
tb-nightly               2.13.0a20230405
tensorboard              2.12.1
tensorboard-data-server  0.7.0
tensorboard-plugin-wit   1.8.1
termcolor                2.3.0
test-tube                0.7.5
tifffile                 2023.3.21
timm                     0.6.7
tinycss2                 1.2.1
tokenizers               0.12.1
toml                     0.10.2
toolz                    0.12.0
torch                    2.0.1
torchdiffeq              0.2.3
torchmetrics             0.11.4
torchsde                 0.2.5
tornado                  6.2
tqdm                     4.65.0
traitlets                5.9.0
trampoline               0.1.2
transformers             4.28.0.dev0     /mnt/workspace/demos/alpaca/transformers
triton                   2.0.0
typing_extensions        4.5.0
typing-inspect           0.8.0
tzdata                   2023.3
tzlocal                  4.3
uc-micro-py              1.0.1
urllib3                  1.26.15
urwid                    2.1.2
uvicorn                  0.21.1
validators               0.20.0
wandb                    0.14.0
watchdog                 3.0.0
wcwidth                  0.2.6
webdataset               0.2.5
webencodings             0.5.1
websockets               11.0
Werkzeug                 2.2.3
wheel                    0.37.1
xformers                 0.0.16rc425
xxhash                   3.2.0
yapf                     0.32.0
yarl                     1.8.2
zipp                     3.15.0

Alpaca's VRAM requirements are fairly high: at present you basically need 32 GB of VRAM, although we can shrink the model configuration to reduce the memory footprint.

1. Download stanford_alpaca

!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/alpaca/stanford_alpaca.tgz
!tar -xvf stanford_alpaca.tgz

2. Install dependencies

!cd stanford_alpaca && pip uninstall -y torch torchvision && pip install -r requirements.txt && pip install gradio

!git clone https://github.com/huggingface/transformers.git && \
cd transformers && \
git checkout 165dd6dc916a43ed9b6ce8c1ed62c3fe8c28b6ef && \
pip install -e .
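To confirm the editable install is the one being picked up, a quick check (a sketch; the exact dev version string may differ):

import transformers

# Should print a 4.28.0.dev0-style version and a path inside the cloned repo.
print(transformers.__version__)
print(transformers.__file__)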

3. Data preparation

The data format is shown below. To fine-tune on your own data, convert it into this form (a conversion sketch follows the download command below):
"instruction": describes the task the model should perform.
"input": optional context or input for the task. For example, when the instruction is "Summarize the following article", the input is the article.
"output": the answer the model is expected to produce.

Example:
[
    {
        "instruction": "Give three tips for staying healthy.",
        "input": "",
        "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
    }
]

# Download the dataset; if a file with the same name already exists locally, rename it first.
!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/alpaca/alpaca_data.json
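A minimal conversion sketch for your own data; the source field names (task, context, answer) are hypothetical placeholders for however your records are stored:

import json

# Hypothetical source records; replace with your own loading logic.
records = [
    {"task": "Give three tips for staying healthy.", "context": "", "answer": "..."},
]

# Map each record into the instruction/input/output schema shown above.
alpaca_data = [
    {"instruction": r["task"], "input": r["context"], "output": r["answer"]}
    for r in records
]

with open("my_alpaca_data.json", "w", encoding="utf-8") as f:
    json.dump(alpaca_data, f, ensure_ascii=False, indent=4)

Pass the resulting file to train.py via --data_path.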

4. Fine-tune the model

4.1 Prepare the weights

The llama-7B weights are roughly 12 GB.

!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/alpaca/llama-7b-hf.tar.gz && tar -xvf llama-7b-hf.tar.gz
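That figure is consistent with fp16 storage at 2 bytes per parameter (a rough back-of-envelope sketch; 6.7B is the commonly cited parameter count for LLaMA-7B):

# fp16 weights take about 2 bytes per parameter
params = 6.7e9
print(f"{params * 2 / 1024**3:.1f} GiB")  # ~12.5 GiB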

4.2 Adjusting the parameters

You can adjust the model configuration to fit the available VRAM, so that the run can be tested on a single smaller GPU. Find the config.json under the pretrained-model path (here ./llama-7b-hf) and modify it as below; reducing parameters such as max_sequence_length and num_hidden_layers allows training with much less VRAM.

{
    "architectures": ["LLaMAForCausalLM"], 
    "bos_token_id": 0, 
    "eos_token_id": 1, 
    "hidden_act": "silu", 
    "hidden_size": 4096, 
    "intermediate_size": 11008, 
    "initializer_range": 0.02, 
    "max_sequence_length": 4, 
    "model_type": "llama", 
    "num_attention_heads": 32, 
    "num_hidden_layers": 4, 
    "pad_token_id": -1, 
    "rms_norm_eps": 1e-06, 
    "torch_dtype": "float16", 
    "transformers_version": "4.27.0.dev0", 
    "use_cache": true, 
    "vocab_size": 32000
}
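A small sketch that applies the same shrinkage programmatically; note that a 4-layer, length-4 configuration is only a smoke test of the pipeline, not a model that will produce meaningful output:

import json

cfg_path = "./llama-7b-hf/config.json"
with open(cfg_path) as f:
    cfg = json.load(f)

cfg["max_sequence_length"] = 4  # drastically shortened, smoke test only
cfg["num_hidden_layers"] = 4    # keep 4 of the original 32 layers

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=4)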

4.3 Training

Add the following at the top of stanford_alpaca/train.py (it disables Weights & Biases logging):

import os
os.environ["WANDB_DISABLED"] = "true"
# Run the training command
!torchrun --nproc_per_node=1 --master_port=29588 ./stanford_alpaca/train.py \
 --model_name_or_path "./llama-7b-hf" \
 --data_path ./alpaca_data.json \
 --bf16 False \
 --output_dir ./models/alpaca-2 \
 --num_train_epochs 1 \
 --per_device_train_batch_size 1 \
 --per_device_eval_batch_size 1 \
 --gradient_accumulation_steps 8 \
 --evaluation_strategy "no" \
 --save_strategy "steps" \
 --save_steps 20 \
 --save_total_limit 1 \
 --learning_rate 2e-5 \
 --model_max_length 4 \
 --weight_decay 0. \
 --warmup_ratio 0.03 \
 --lr_scheduler_type "cosine" \
 --logging_steps 1 \
 --fsdp "full_shard auto_wrap" \
 --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
 --tf32 False 
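For reference, the effective batch size implied by these flags, and the rough number of optimizer steps for one epoch over the 52K examples (a quick sanity-check sketch):

# Effective batch = GPUs * per-device batch * gradient accumulation steps
effective_batch = 1 * 1 * 8
print(effective_batch)           # 8
print(52000 // effective_batch)  # ~6500 optimizer steps per epoch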

5. Inference

import transformers

# Load the fine-tuned tokenizer and model, and move the model to the GPU.
tokenizers = transformers.LlamaTokenizer.from_pretrained("./models/alpaca-2")
model = transformers.LlamaForCausalLM.from_pretrained("./models/alpaca-2").cuda()
model.eval()

def gen(req):
    # Tokenize the request and move the tensors to the GPU.
    batch = tokenizers(req, return_tensors='pt', add_special_tokens=False)
    batch = {k: v.cuda() for k, v in batch.items()}
    # Sample a completion with nucleus sampling (temperature + top-p).
    full_completion = model.generate(inputs=batch["input_ids"],
                                     attention_mask=batch["attention_mask"],
                                     temperature=0.7,
                                     top_p=0.9,
                                     do_sample=True,
                                     num_beams=1,
                                     max_new_tokens=600,
                                     eos_token_id=tokenizers.eos_token_id,
                                     pad_token_id=tokenizers.pad_token_id)
    print(tokenizers.decode(full_completion[0]))

gen("List all Canadian provinces in alphabetical order.")

A ready-made generation script is also available at this path:

!wget  https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/alpaca/gen.py

6. Demo

import gradio as gr
import requests
import json
import transformers

tokenizers = transformers.LlamaTokenizer.from_pretrained("./models/alpaca-2")
model = transformers.LlamaForCausalLM.from_pretrained("./models/alpaca-2").cuda()
model.eval()


def inference(text):
    # Tokenize and move to GPU, then sample a completion (same settings as above).
    batch = tokenizers(text, return_tensors="pt", add_special_tokens=False)
    batch = {k: v.cuda() for k, v in batch.items()}
    full_completion = model.generate(inputs=batch["input_ids"],
                                     attention_mask=batch["attention_mask"],
                                     temperature=0.7,
                                     top_p=0.9,
                                     do_sample=True,
                                     num_beams=1,
                                     max_new_tokens=600,
                                     eos_token_id=tokenizers.eos_token_id,
                                     pad_token_id=tokenizers.pad_token_id)
    result = tokenizers.decode(full_completion[0])
    print(result)
    return result

demo = gr.Blocks()
with demo:
    input_prompt = gr.Textbox(label="Enter your request",
                                value="Write a news article about a safety inspection for me.",
                                lines=6)
    generated_txt = gr.Textbox(lines=6)

    b1 = gr.Button("Send")
    b1.click(inference, inputs=[input_prompt], outputs=generated_txt) 

demo.launch(enable_queue=True, share=True)

Reposted from blog.csdn.net/u012193416/article/details/131005622