目 录CONTENT

文章目录

【深度学习】PEFT TaskType 任务类型详解

EulerBlind
2025-11-11 / 0 评论 / 0 点赞 / 0 阅读 / 0 字

PEFT(Parameter-Efficient Fine-Tuning)是 Hugging Face 提供的参数高效微调库,在使用 LoRA、Prefix Tuning 等微调方法时,需要指定 TaskType 来告诉 PEFT 库当前任务的类型。正确选择任务类型对于模型微调的成功至关重要。

TaskType 枚举值

PEFT 支持以下 6 种任务类型:

from peft import TaskType

TaskType.SEQ_CLS              # 序列分类
TaskType.SEQ_2_SEQ_LM         # 序列到序列语言建模
TaskType.CAUSAL_LM            # 因果语言建模
TaskType.TOKEN_CLS            # 词元分类
TaskType.QUESTION_ANS         # 问答任务
TaskType.FEATURE_EXTRACTION   # 特征提取

各任务类型详解

1. SEQ_CLS - 序列分类

用途: 对整个输入序列进行分类,输出单个类别标签。

适用场景:

  • 情感分析(正面/负面/中性)
  • 文本分类(新闻分类、垃圾邮件检测)
  • 意图识别(用户意图分类)
  • 文本相似度判断(相似/不相似)

示例:

from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.SEQ_CLS  # 用于情感分析、文本分类
)

典型模型: BERT、RoBERTa 用于分类任务


2. SEQ_2_SEQ_LM - 序列到序列语言建模

用途: 将一个序列转换为另一个序列,输入和输出都是序列。

适用场景:

  • 机器翻译(中英文互译)
  • 文本摘要(长文本压缩为摘要)
  • 对话生成(多轮对话)
  • 文本改写(同义改写、风格转换)
  • 代码生成(自然语言到代码)

示例:

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj"],
    task_type=TaskType.SEQ_2_SEQ_LM  # 用于翻译、摘要、对话
)

典型模型: T5、BART、mT5


3. CAUSAL_LM - 因果语言建模

用途: 自回归语言建模,根据前面的 token 预测下一个 token。

适用场景:

  • 文本生成(续写、创作)
  • 代码补全(IDE 代码提示)
  • 对话生成(单轮对话)
  • 故事创作
  • 指令跟随(Instruction Following)

示例:

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type=TaskType.CAUSAL_LM  # 用于文本生成、代码补全
)

典型模型: GPT、LLaMA、Qwen、ChatGLM


4. TOKEN_CLS - 词元分类

用途: 对序列中的每个 token 进行分类标注。

适用场景:

  • 命名实体识别(NER)- 识别人名、地名、机构名
  • 词性标注(POS Tagging)
  • 中文分词
  • 序列标注任务
  • 实体抽取

示例:

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.TOKEN_CLS  # 用于 NER、词性标注
)

典型模型: BERT、RoBERTa 用于序列标注


5. QUESTION_ANS - 问答任务

用途: 给定问题和上下文,从上下文中提取答案。

适用场景:

  • 阅读理解(MRC - Machine Reading Comprehension)
  • 抽取式问答(从文档中提取答案片段)
  • 文档问答(基于文档的问答系统)
  • 知识库问答

示例:

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.QUESTION_ANS  # 用于阅读理解、问答
)

典型模型: BERT、RoBERTa、ELECTRA 用于问答任务


6. FEATURE_EXTRACTION - 特征提取

用途: 提取文本的隐藏状态(hidden states)作为嵌入向量或特征,用于下游任务。

适用场景:

  • 文本嵌入(Text Embedding)
  • 语义相似度计算
  • 向量检索(检索增强生成 RAG)
  • 句子/文档表示学习
  • 下游任务的预训练特征提取

示例:

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModel

# 加载基础模型
base_model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")

# 配置 LoRA,用于特征提取任务
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION  # 用于 embedding 提取
)

# 应用 LoRA
model = get_peft_model(base_model, lora_config)

# 使用模型提取特征
def get_embedding(text):
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)
        # 使用 [CLS] token 的 hidden state 作为句子 embedding
        return output.last_hidden_state[:, 0, :]

典型模型: Sentence-BERT、Qwen3-Embedding、BGE、E5


选择指南

任务类型输入输出典型应用
SEQ_CLS序列单个标签情感分析、文本分类
SEQ_2_SEQ_LM序列序列翻译、摘要、对话
CAUSAL_LM序列前缀下一个 token文本生成、代码补全
TOKEN_CLS序列每个 token 的标签NER、词性标注
QUESTION_ANS问题+上下文答案片段阅读理解、问答
FEATURE_EXTRACTION序列隐藏状态向量Embedding、检索

注意事项

  1. 任务类型影响模型结构: 不同的 TaskType 会影响 PEFT 如何配置模型的输出层和损失函数。

  2. 与模型架构匹配: 确保选择的 TaskType 与你的基础模型架构兼容。例如,GPT 系列模型通常使用 CAUSAL_LM,BERT 系列可用于 SEQ_CLSTOKEN_CLSQUESTION_ANS

  3. 特征提取的特殊性: FEATURE_EXTRACTION 不涉及分类头,主要用于提取中间层的表示,适合 embedding 模型微调。

  4. 多任务场景: 如果需要在同一模型上执行多种任务,可能需要分别训练不同的适配器,或使用支持多任务的配置。

正确选择 TaskType 是 PEFT 微调成功的第一步,它确保了模型能够针对特定任务进行有效的参数高效微调。

0
博主关闭了所有页面的评论