大语言模型微调指南

🧠

引言

大语言模型(LLM)的微调是让通用模型适应特定任务或领域的关键技术。本文将介绍主流的微调方法、工具链和最佳实践,帮助你打造专属的AI助手。

💡 何时微调:当需要模型学习特定领域知识、独特风格或复杂指令时,微调比提示工程更有效。但成本也更高,请权衡后再决定。

微调方法概览

1. 全参数微调(Full Fine-tuning)

更新模型所有参数,效果最好但资源消耗巨大。

2. LoRA / QLoRA

只训练低秩适配器,显存需求大幅降低,是目前最流行的方法。

3. Prefix Tuning

在输入前添加可训练的虚拟token前缀。

4. Prompt Tuning

仅优化输入prompt的嵌入向量,成本最低。

环境准备

# 创建虚拟环境
conda create -n finetune python=3.10
conda activate finetune

# 安装依赖
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install transformers==4.36.0
pip install datasets accelerate peft trl wandb
pip install bitsandbytes>=0.40.0

使用TRL进行SFT微调

from datasets import load_dataset
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

# 加载模型和分词器
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# 准备数据
dataset = load_dataset("json", data_files="training_data.jsonl")

# 配置训练器
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    dataset_text_field="text",
    max_seq_length=2048,
    args=TrainingArguments(
        output_dir="./output",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        save_strategy="epoch",
        logging_steps=10,
        fp16=True,
        optim="paged_adamw_8bit",
    ),
)

# 开始训练
trainer.train()

使用LoRA进行高效微调

from peft import LoraConfig, get_peft_model, TaskType

# 配置LoRA
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
)

# 将LoRA适配器添加到模型
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# 输出: trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06%

数据准备

数据格式

// JSONL格式(每行一个JSON对象)
{"text": "用户: 帮我写一个Python函数。\n助手: 当然可以,请问需要什么功能的函数?"}
{"text": "用户: 解释一下什么是闭包。\n助手: 闭包是指一个函数可以访问其词法作用域外的变量..."}

// 或使用对话格式
{
    "conversations": [
        {"role": "user", "content": "什么是机器学习?"},
        {"role": "assistant", "content": "机器学习是人工智能的一个分支..."}
    ]
}

数据质量建议

超参数调优

# 推荐超参数配置
learning_rate: 1e-4 ~ 5e-4  # LoRA可用稍大的学习率
batch_size: 4~16           # 根据显存调整
epochs: 1~5                # 过多会过拟合
max_seq_length: 2048/4096  # 根据数据长度
warmup_ratio: 0.05~0.1
weight_decay: 0.01

# 学习率调度
lr_scheduler_type: "cosine" 或 "linear"

模型推理和部署

from peft import PeftModel
from transformers import AutoModelForCausalLM

# 加载训练好的LoRA适配器
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = PeftModel.from_pretrained(base_model, "./output/checkpoint-1000")

# 合并模型(可选,用于更快推理)
merged_model = model.merge_and_unload()
merged_model.save_pretrained("./final-model")
tokenizer.save_pretrained("./final-model")

使用vLLM加速推理

from vllm import LLM, SamplingParams

llm = LLM(model="./final-model")
sampling_params = SamplingParams(temperature=0.7, max_tokens=512)

outputs = llm.generate(["解释一下什么是云计算?"], sampling_params)
print(outputs[0].outputs[0].text)

评估方法

自动评估

人工评估

常见问题

灾难性遗忘

模型在学习新知识时忘记旧知识。解决方案:混合通用数据、使用较小学习率。

过拟合

训练损失下降但验证损失上升。解决方案:增加数据、使用正则化、早停。

训练不稳定

梯度爆炸或Loss飙升。解决方案:梯度裁剪、学习率热身、使用DeepSpeed。

🔧 工具推荐:Axolotl提供了完整的微调工具链,Unsloth可加速LoRA训练2-5倍,推荐尝试。