一、Lora微调mt5-xxl
GPU要求:至少 A100-SXM4-80GB * 2
batch_size:A100-SXM4-80GB * 2情况下最大 16
备注:mt5-xxl全参数微调,batch_size=2时,A100-SXM4-80GB至少需要5张
run_finetune_lora.py
import logging
import os
import sys
import numpy as np
from datasets import Dataset
from peft import PeftModel
from peft import LoraModel, LoraConfig, get_peft_model
import random
import torch
import json
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, HfArgumentParser, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from arguments import ModelArguments, DataTrainingArguments
from utils import get_extract_metrics_f1
# Disable Weights & Biases reporting so HF Trainer runs without a W&B account
# or network access (Trainer auto-enables wandb when the package is installed).
os.environ["WANDB_DISABLED"] = "true"
# NOTE(review): the original line was garbled by HTML-entity encoding
# ('...getLogger("__main__&#'); reconstructed as the string "__main__".
# logging.getLogger(__name__) would be the more conventional idiom, but the
# original's literal string is preserved to avoid changing logger names.
logger = logging.getLogger("__main__")