基于P-Tuningv2轻量微调和推理chatglm

类ChatGPT的部署与微调(下):从GLM、ChatGLM到MOSS、ChatDoctor、可商用_v_JULY_v的博客-CSDN博客随着『GPT4多模态/Microsoft 365 Copilot/Github Copilot X/ChatGPT插件』的推出,绝大部分公司的技术 产品 服务,以及绝大部分人的工作都将被革新一遍类似iPhone的诞生 大家面向iOS编程 有了App Store现在有了ChatGPT插件/GPT应用商店,以后很多公司 很多人面向GPT编程(很快技术人员分两种,一种懂GPT,一种不懂GPT)然ChatGPT/GPT4基本不可能开源了,而通过上文《https://blog.csdn.net/v_JULY_v/article/details/129880836ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。 ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T Token的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。不过,由于 ChatGLM-6B 的规模较小,目前已知其具有相当多的局限性,如事实性/数学逻辑错误,可能生成有害/有偏见内容,较弱的上下文能力,自我认知混乱,以及对英文指示生成与中文指示完全矛盾的内容。

在阿里云PAI平台上跑

absl-py                  1.4.0
accelerate               0.18.0
addict                   2.4.0
aenum                    3.1.12
aiofiles                 23.1.0
aiohttp                  3.8.4
aiosignal                1.3.1
albumentations           0.4.3
altair                   4.2.2
antlr4-python3-runtime   4.9.3
anyio                    3.6.2
appdirs                  1.4.4
asttokens                2.2.1
async-timeout            4.0.2
attrs                    22.2.0
backcall                 0.2.0
basicsr                  1.4.2
bcrypt                   4.0.1
beautifulsoup4           4.12.1
blendmodes               2022
blinker                  1.6
boltons                  23.0.0
braceexpand              0.1.7
cachetools               5.3.0
certifi                  2022.12.7
cffi                     1.15.1
chardet                  4.0.0
charset-normalizer       3.1.0
clean-fid                0.1.29
click                    8.1.3
clip-anytorch            2.5.2
cmake                    3.26.1
comm                     0.1.3
contourpy                1.0.7
cpm-kernels              1.0.11
cryptography             40.0.1
cssselect2               0.7.0
cycler                   0.11.0
datasets                 2.11.0
debugpy                  1.6.7
decorator                5.1.1
deprecation              2.1.0
diffusers                0.15.0
dill                     0.3.6
docker-pycreds           0.4.0
einops                   0.4.1
entrypoints              0.4
executing                1.2.0
facexlib                 0.2.5
fastapi                  0.94.0
ffmpy                    0.3.0
filelock                 3.10.7
filterpy                 1.4.5
font-roboto              0.0.1
fonts                    0.0.3
fonttools                4.39.3
frozenlist               1.3.3
fsspec                   2023.3.0
ftfy                     6.1.1
future                   0.18.3
gdown                    4.7.1
gfpgan                   1.3.8
gitdb                    4.0.10
GitPython                3.1.30
google-auth              2.17.2
google-auth-oauthlib     1.0.0
gradio                   3.16.2
grpcio                   1.53.0
h11                      0.12.0
httpcore                 0.15.0
httpx                    0.23.3
huggingface-hub          0.13.3
icetk                    0.0.4
idna                     2.10
imageio                  2.9.0
imageio-ffmpeg           0.4.2
imgaug                   0.2.6
importlib-metadata       6.1.0
inflection               0.5.1
ipykernel                6.23.1
ipython                  8.13.2
jedi                     0.18.2
jieba                    0.42.1
Jinja2                   3.1.2
joblib                   1.2.0
jsonmerge                1.8.0
jsonschema               4.17.3
jupyter_client           8.2.0
jupyter_core             5.3.0
kiwisolver               1.4.4
kornia                   0.6.7
lark                     1.1.2
lazy_loader              0.2
linkify-it-py            2.0.0
lit                      16.0.0
llvmlite                 0.39.1
lmdb                     1.4.0
lpips                    0.1.4
lxml                     4.9.2
Markdown                 3.4.3
markdown-it-py           2.2.0
MarkupSafe               2.1.2
matplotlib               3.7.1
matplotlib-inline        0.1.6
mdit-py-plugins          0.3.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.4
multiprocess             0.70.14
mypy-extensions          1.0.0
nest-asyncio             1.5.6
networkx                 3.1rc0
nltk                     3.8.1
numba                    0.56.4
numexpr                  2.8.4
numpy                    1.23.3
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
oauthlib                 3.2.2
omegaconf                2.2.3
open-clip-torch          2.7.0
opencv-python            4.7.0.72
opencv-python-headless   4.7.0.72
orjson                   3.8.9
packaging                23.0
pandas                   1.5.3
paramiko                 3.1.0
parso                    0.8.3
pathtools                0.1.2
pexpect                  4.8.0
pickleshare              0.7.5
piexif                   1.1.3
Pillow                   9.4.0
pip                      23.0.1
platformdirs             3.5.1
prompt-toolkit           3.0.38
protobuf                 3.20.0
psutil                   5.9.4
ptyprocess               0.7.0
pudb                     2019.2
pure-eval                0.2.2
pyarrow                  11.0.0
pyasn1                   0.4.8
pyasn1-modules           0.2.8
pycparser                2.21
pycryptodome             3.17
pydantic                 1.10.7
pydeck                   0.8.0
pyDeprecate              0.3.1
pydub                    0.25.1
Pygments                 2.14.0
Pympler                  1.0.1
PyNaCl                   1.5.0
pyparsing                3.0.9
pyre-extensions          0.0.23
pyrsistent               0.19.3
PySocks                  1.7.1
python-dateutil          2.8.2
python-multipart         0.0.6
pytorch-lightning        1.7.6
pytz                     2023.3
pytz-deprecation-shim    0.1.0.post0
PyWavelets               1.4.1
PyYAML                   6.0
pyzmq                    25.1.0
realesrgan               0.3.0
regex                    2023.3.23
reportlab                3.6.12
requests                 2.25.1
requests-oauthlib        1.3.1
resize-right             0.0.2
responses                0.18.0
rfc3986                  1.5.0
rich                     13.3.3
rouge-chinese            1.0.3
rsa                      4.9
safetensors              0.2.7
scikit-image             0.19.2
scipy                    1.10.1
semver                   3.0.0
sentencepiece            0.1.99
sentry-sdk               1.19.0
setproctitle             1.3.2
setuptools               59.6.0
six                      1.16.0
smmap                    5.0.0
sniffio                  1.3.0
soupsieve                2.4
stack-data               0.6.2
starlette                0.26.1
streamlit                1.20.0
svglib                   1.5.1
sympy                    1.12rc1
tb-nightly               2.13.0a20230405
tensorboard              2.12.1
tensorboard-data-server  0.7.0
tensorboard-plugin-wit   1.8.1
test-tube                0.7.5
tifffile                 2023.3.21
timm                     0.6.7
tinycss2                 1.2.1
tokenizers               0.12.1
toml                     0.10.2
toolz                    0.12.0
torch                    1.13.1+cu117
torchdiffeq              0.2.3
torchmetrics             0.11.4
torchsde                 0.2.5
torchvision              0.14.1+cu117
tornado                  6.2
tqdm                     4.65.0
traitlets                5.9.0
trampoline               0.1.2
transformers             4.27.1
triton                   2.0.0
typing_extensions        4.5.0
typing-inspect           0.8.0
tzdata                   2023.3
tzlocal                  4.3
uc-micro-py              1.0.1
urllib3                  1.26.15
urwid                    2.1.2
uvicorn                  0.21.1
validators               0.20.0
wandb                    0.14.0
watchdog                 3.0.0
wcwidth                  0.2.6
webdataset               0.2.5
webencodings             0.5.1
websockets               11.0
Werkzeug                 2.2.3
wheel                    0.37.1
xformers                 0.0.16rc425
xxhash                   3.2.0
yapf                     0.32.0
yarl                     1.8.2
zipp                     3.15.0

python39+cu117+驱动:470.103.1

apt list | grep cuda

 1.下载chatglm-6b

阿里的镜像源和官方的git代码比,版本就有些落后了,建议使用官方的版本。

import os
dsw_region = os.environ.get("dsw_region")
url_link = {
    "cn-shanghai": "https://atp-modelzoo-sh.oss-cn-shanghai-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-hangzhou": "https://atp-modelzoo.oss-cn-hangzhou-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-shenzhen": "https://atp-modelzoo-sz.oss-cn-shenzhen-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz",
    "cn-beijing": "https://atp-modelzoo-bj.oss-cn-beijing-internal.aliyuncs.com/release/tutorials/chatGLM/ChatGLM-6B-main.tar.gz", 
}

# 获取阿里云实例所在区域进行下载
path = url_link[dsw_region]
os.environ['LINK_CHAT'] = path
!wget $LINK_CHAT
!tar -xvf ChatGLM-6B-main.tar.gz 

# 去掉link中-internal可以自己下载

2.安装依赖

!cd ChatGLM-6B-main && pip install -r requirements.txt && \
pip install rouge_chinese nltk jieba datasets 

3.数据准备

用户可以使用我们提供的数据(AdvertiseGen_Simple),也可以使用自定义数据。
数据文件为json文件。json文件中每条数据是一个字典,记录输入文本和输出文本。将 train.sh 文件以及 evaluate.sh 文件中的 train_file、validation_file 和 test_file 修改为对应的数据集路径,并将 prompt_column 和 response_column 改为 JSON 文件中输入文本和输出文本对应的 KEY。
以我们提供的AdvertiseGen_Simple数据集为例,prompt_column为content,response_column为summary 。

4.微调模型 

!cd ChatGLM-6B-main/ptuning && bash train.sh

PRE_SEQ_LEN=8
LR=1e-2

CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_train \
    --train_file AdvertiseGen_Simple/train.json \
    --validation_file AdvertiseGen_Simple/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --logging_steps 10 \
    --save_steps 6 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --num_train_epochs 1
    #--quantization_bit 4 \    

把int4注掉了就可以跑,否则报错:RuntimeError: Library cublasLt is not initialized。

在demos/chatglm_6b/ChatGLM-6B-main/ptuning/main.py中设置os.environ["WANDB_DISABLED"] = "true",关掉wandb。

5.模型评测

评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在 ./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt。

!cd ChatGLM-6B-main/ptuning &&  bash evaluate.sh

PRE_SEQ_LEN=8
CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2
STEP=6

CUDA_VISIBLE_DEVICES=0 python main.py \
    --do_predict \
    --validation_file AdvertiseGen_Simple/dev.json \
    --test_file AdvertiseGen_Simple/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP  \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN 
    #--quantization_bit 4

6.模型推理

!cd ChatGLM-6B-main/ && python web_demo.py

猜你喜欢

转载自blog.csdn.net/u012193416/article/details/130999628