transformers 库为许多流行的 Transformer 模型提供了一个干净且文档齐全的接口。它不仅使源代码更易于阅读和理解,还提供了一种与模型交互的标准方法。在前一篇文章中,您已经看到了如何使用 DistilBERT 等模型进行自然语言处理任务。在这篇文章中,您将学习如何为您自己的目的微调该模型。这会将模型的使用从推理扩展到训练。具体来说,您将学习:
- 如何准备用于训练的数据集
- 如何使用辅助库训练模型
通过我的书籍《Hugging Face Transformers中的NLP》,快速启动您的项目。它提供了带有工作代码的自学教程。
让我们开始吧。

为问答任务微调 DistilBERT
照片由 Lea Fabienne 拍摄。保留部分权利。
概述
这篇博文分为三部分;它们是:
- 为自定义问答微调 DistilBERT
- 数据集和预处理
- 运行训练
为自定义问答微调 DistilBERT
在 transformers 库中使用模型的最简单方法是创建 pipeline,它隐藏了许多与模型交互的细节。
您可能不想创建 pipeline 而要单独设置模型的一个原因是,您希望在自己的数据集上微调模型。这在使用 pipeline 时是不可能的,因为您需要检查模型在损失函数下的原始输出,而这通常是隐藏在 pipeline 中的。
通常,预训练模型是使用通用数据集创建的。但是,它可能无法在特定领域上很好地工作,尤其是当领域中的语言与通用用法显著不同时。这就是可以尝试微调的地方。
微调的难点可能在于一个好的数据集的可用性。这通常非常昂贵且耗时。为说明起见,下面我们将使用一个通用的、公开可用的数据集,称为 SQuAD(斯坦福问答数据集)。
得益于 transformers 库高度通用和干净的设计,微调模型非常简单。下面是如何在 SQuAD 数据集上微调模型的示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering, \ Trainer, TrainingArguments from datasets import load_dataset # 加载 SQuAD 数据集 dataset = load_dataset("squad") # 加载分词器和模型 model_name = "distilbert-base-uncased" tokenizer = DistilBertTokenizerFast.from_pretrained(model_name) model = DistilBertForQuestionAnswering.from_pretrained(model_name) # 对数据集进行分词 def preprocess_function(examples): questions = [q.strip() for q in examples["question"]] inputs = tokenizer( questions, examples["context"], max_length=384, truncation="only_second", return_offsets_mapping=True, padding="max_length", ) offset_mapping = inputs.pop("offset_mapping") answers = examples["answers"] start_positions = [] end_positions = [] for i, offsets in enumerate(offset_mapping): answer = answers[i] start_char = answer["answer_start"][0] end_char = start_char + len(answer["text"][0]) sequence_ids = inputs.sequence_ids(i) # 查找上下文的开始和结束 context_start = sequence_ids.index(1) context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1) # 如果答案未完全包含在上下文中,则标记为 (0, 0) if offsets[context_start][0] > end_char or offsets[context_end][1] < start_char: start_positions.append(0) end_positions.append(0) else: # 否则,查找开始和结束的 token 位置 idx = context_start while idx <= context_end and offsets[idx][0] <= start_char: idx += 1 start_positions.append(idx - 1) idx = context_end while idx >= context_start and offsets[idx][1] >= end_char: idx -= 1 end_positions.append(idx + 1) inputs["start_positions"] = start_positions inputs["end_positions"] = end_positions return inputs # 对数据集应用预处理 tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) # 定义训练参数 training_args = TrainingArguments( output_dir="./results", evaluation_strategy="epoch", learning_rate=2e-5, per_device_train_batch_size=16, per_device_eval_batch_size=16, num_train_epochs=3, weight_decay=0.01, ) # 初始化 Trainer trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["validation"], tokenizer=tokenizer, ) # 训练模型并保存结果 trainer.train() model.save_pretrained("./fine-tuned-distilbert-squad") tokenizer.save_pretrained("./fine-tuned-distilbert-squad") |
这段代码有点复杂。让我们一步一步地分解。
数据集和预处理
SQuAD 数据集是用于问答的流行数据集,可以在 Hugging Face hub 上找到。您可以使用 Hugging Face 的 datasets 库中的 load_dataset()
函数来加载它。
1 2 3 |
from datasets import load_dataset dataset = load_dataset("squad") |
每个数据集都不同。但这个特定的数据集是类似字典的,包含“title”、“context”、“question”和“answers”键。“context”是一段中等长度的文本。“question”是一个问题句子。“answers”是一个字典,其中包含“text
”和“answer_start
”键。“text”映射到问题的简短答案字符串。“answer_start”映射到答案在上下文中的起始位置。“title”可以忽略,因为它提供了上下文摘录的文章标题。
要将数据集用于训练,您需要了解模型期望的输入以及它产生的输出类型。对于 DistilBERT 进行问答,模型通过 DistilBertForQuestionAnswering
类的实现进行固定,除非您决定编写自己的模型实现。在此类中,模型将序列整数 token ID 作为输入,输出是两个 logits 向量,一个用于答案的开始位置,一个用于结束位置。
您可以在上一篇文章中找到模型输入和输出格式的详细信息。或者,您可以在 DistilBertForQuestionAnswering 类文档中找到详细信息。
为了将数据集用于训练,您需要进行一些预处理。这是为了将数据集转换为与模型的输入和输出格式匹配的格式。从 Hugging Face hub 加载的数据集对象允许您使用 map()
方法进行此操作,其中转换实现为自定义函数 preprocess_function()
。
1 2 3 4 |
... tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names) |
请注意,preprocess_function()
接受来自数据集的批次,因为您在 map()
方法中使用了 batched=True
。
在 preprocess_function()
中,分词器使用来自 examples["question"]
的问题和来自 examples["context"]
的上下文来调用。问题会去除额外的空格,上下文会截断以适应 384 个 token 的最大长度。此函数中使用分词器的方式与您在前一篇文章中看到的不同。
1 2 3 4 5 6 7 8 9 |
... inputs = tokenizer( questions, examples["context"], max_length=384, truncation="only_second", return_offsets_mapping=True, padding="max_length", ) |
首先,分词器使用问题和上下文的批次进行调用。对于可能不规则的输入,分词器会将输入填充到批次的最大长度。其次,使用 return_offsets_mapping=True
,分词器返回一个包含“input_ids
”、“attention_mask
”和“offset_mapping
”键的字典。“input_ids
”是整数 token ID 的序列。“attention_mask
”是一个二进制掩码,指示哪些 token 是真实的(1),哪些是填充的(0)。“offset_mapping
”是通过设置 return_offsets_mapping=True
添加的,它是一个元组列表,指示每个 token 在原始文本中的字符位置(开始和结束偏移量)。
分词器输出的 input_ids
以以下格式连接问题和上下文:
1 |
[CLS] 问题 [SEP] 上下文 [SEP] |
这就是模型所期望的。来自数据集的答案是一个字符串,以及答案在原始上下文中可以找到的字符偏移量。这与模型生成的 token 位置的 logits 不同。因此,您在 preprocess_function()
中使用了一个 for 循环来重新创建答案的开始和结束 token 位置。
在此代码中,分词器使用其他参数进行调用。设置 return_offsets_mapping=True
将使返回的对象包含 offset_mapping
,这是一个元组列表,用于识别每个输入文本中每个 token 的开始和结束位置。
首先,offset_mapping
从分词器返回的对象中弹出,因为它对于训练不是必需的。然后,对于每个答案,您需要从上下文中识别字符的开始和结束偏移量。您可以使用类似以下的代码来验证这一点:
1 2 3 4 |
... start_char = answer["answer_start"][0] end_char = start_char + len(answer["text"][0]) assert answer["text"] == context[start_char:end_char] |
即使您知道字符偏移量,模型也是根据 token 位置进行操作的。
请记住,分词器连接了问题和上下文。幸运的是,分词器提供了线索来识别输出中的上下文的开始和结束。在 inputs.sequence_ids(i)
中,它是一个整数或 None 的 Python 列表,对应于批次中的元素 i
。列表包含 None 表示特殊 token 的位置,包含一个整数表示实际输入的 token 位置。在您的用例中,您先调用了问题分词器,然后是上下文分词器,因此整数 0 对应于问题,1 对应于上下文。
因此,您可以通过检查 sequence_ids
列表中整数 1 第一次和最后一次出现的位置来识别上下文的 token 开始和结束偏移量。
1 2 3 4 |
... sequence_ids = inputs.sequence_ids(i) context_start = sequence_ids.index(1) context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1) |
知道了上下文的开始和结束 token 位置后,您仍然需要检查答案是否被任何 token 覆盖。这通过逐个检查 token 来完成。您使用 for 循环遍历每一对偏移量,并检查答案的开始和结束字符位置是否在任何 token 内。如果是,则 token 的位置将被记为 start_positions
和 end_positions
。对于未找到的答案(例如,由于上下文被裁剪),则将其设置为 0。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
... # 如果答案未完全包含在上下文中,则标记为 (0, 0) if offsets[context_start][0] > end_char or offsets[context_end][1] < start_char: start_positions.append(0) end_positions.append(0) else: # 否则,查找开始和结束的 token 位置 idx = context_start while idx <= context_end and offsets[idx][0] <= start_char: idx += 1 start_positions.append(idx - 1) idx = context_end while idx >= context_start and offsets[idx][1] >= end_char: idx -= 1 end_positions.append(idx + 1) |
在 preprocess_function()
的末尾,将返回 inputs
对象。它是一个类似字典的对象,键为 input_ids
、attention_masks
、start_positions
和 end_positions
。您必须不要更改这些键的名称,因为 DistilBertForQuestionAnswering
类期望在 forward()
方法中使用这些参数。
DistilBERT 模型期望您使用 input_ids
参数来调用它。如果您使用大批次调用,还需要 attention_masks
来告知输入中的哪些 token 是填充的。如果您使用可选的开始和结束位置调用,还将计算交叉熵损失。这就是 transformers 库的设计方式,它可以帮助您使用相同的接口在推理和训练中调用模型。
运行训练
要运行此代码,您需要安装以下软件包:
1 |
pip install torch datasets transformers accelerate |
虽然您可以预期 torch
、transformers
和 datasets
是必需的。accelerate
包是使用 transformers
库中的 Trainer
类时的依赖项。
您可能期望训练像 DistilBERT 这样的复杂模型需要大量代码。确实,这并不容易,因为您需要决定使用什么优化器、训练的 epoch 数以及批次大小、学习率、权重衰减等超参数。您甚至需要处理检查点,以便在中断时可以恢复训练。
这就是 Trainer
类被引入的原因。您只需设置训练参数,然后使用数据集设置 Trainer
,然后运行训练。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
... training_args = TrainingArguments( output_dir="./results", evaluation_strategy="epoch", learning_rate=2e-5, per_device_train_batch_size=16, per_device_eval_batch_size=16, num_train_epochs=3, weight_decay=0.01, ) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["validation"], processing_class=tokenizer, ) trainer.train() |
Trainer
将在一次函数调用中处理检查点、日志记录和评估。训练完成后,您只需将微调后的模型(以及分词器,因为它们是一起加载的)以 Hugging Face 格式保存即可。
1 2 3 |
... model.save_pretrained("./fine-tuned-distilbert-squad") tokenizer.save_pretrained("./fine-tuned-distilbert-squad") |
您只需要做这些。即使您没有指定使用 GPU 进行训练,Trainer
也会自动检测您系统上的 GPU 并使用它来加速进程。上面的代码虽然不长,但它是 DistilBERT 在 SQuAD 数据集上进行微调的完整代码。
如果运行此代码,您将期望看到以下输出:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
DistilBertForQuestionAnswering 的一些权重未从 checkpoint distilbert-base-uncased 初始化,而是新初始化的:['qa_outputs.bias', 'qa_outputs.weight'] 您可能应该在此模型上进行下游任务训练,以便能够将其用于预测和推理。 Map: 100%|████████████████████████████████| 10570/10570 [00:01<00:00, 5387.60 examples/s] {'loss': 2.9462, 'grad_norm': 13.834440231323242, 'learning_rate': 1.9391171993911722e-05, 'epoch': 0.09} {'loss': 1.7333, 'grad_norm': 14.540811538696289, 'learning_rate': 1.8782343987823442e-05, 'epoch': 0.18} {'loss': 1.5268, 'grad_norm': 15.629022598266602, 'learning_rate': 1.8173515981735163e-05, 'epoch': 0.27} {'loss': 1.4487, 'grad_norm': 20.17080307006836, 'learning_rate': 1.756468797564688e-05, 'epoch': 0.37} {'loss': 1.3957, 'grad_norm': 21.543432235717773, 'learning_rate': 1.69558599695586e-05, 'epoch': 0.46} {'loss': 1.3816, 'grad_norm': 15.349509239196777, 'learning_rate': 1.634703196347032e-05, 'epoch': 0.55} {'loss': 1.314, 'grad_norm': 14.986817359924316, 'learning_rate': 1.573820395738204e-05, 'epoch': 0.64} {'loss': 1.2313, 'grad_norm': 15.443862915039062, 'learning_rate': 1.5129375951293761e-05, 'epoch': 0.73} {'loss': 1.2613, 'grad_norm': 10.729198455810547, 'learning_rate': 1.4520547945205482e-05, 'epoch': 0.82} {'loss': 1.1976, 'grad_norm': 18.681406021118164, 'learning_rate': 1.39117199391172e-05, 'epoch': 0.91} {'eval_loss': 1.142066240310669, 'eval_runtime': 14.8679, 'eval_samples_per_second': 710.926, 'eval_steps_per_second': 44.458, 'epoch': 1.0} {'loss': 1.1858, 'grad_norm': 15.170207023620605, 'learning_rate': 1.330289193302892e-05, 'epoch': 1.0} {'loss': 0.962, 'grad_norm': 14.375147819519043, 'learning_rate': 1.2694063926940641e-05, 'epoch': 1.1} {'loss': 0.9994, 'grad_norm': 13.867342948913574, 'learning_rate': 1.2085235920852361e-05, 'epoch': 1.19} {'loss': 0.9912, 'grad_norm': 13.35099983215332, 'learning_rate': 1.147640791476408e-05, 'epoch': 1.28} {'loss': 0.976, 'grad_norm': 18.943002700805664, 'learning_rate': 1.08675799086758e-05, 'epoch': 1.37} {'loss': 0.9687, 'grad_norm': 12.70341968536377, 'learning_rate': 1.025875190258752e-05, 'epoch': 1.46} {'loss': 0.949, 'grad_norm': 10.327693939208984, 'learning_rate': 9.64992389649924e-06, 'epoch': 1.55} {'loss': 0.9482, 'grad_norm': 17.166929244995117, 'learning_rate': 9.04109589041096e-06, 'epoch': 1.64} {'loss': 0.9248, 'grad_norm': 23.135452270507812, 'learning_rate': 8.432267884322679e-06, 'epoch': 1.74} {'loss': 0.9289, 'grad_norm': 15.964847564697266, 'learning_rate': 7.823439878234399e-06, 'epoch': 1.83} {'loss': 0.9605, 'grad_norm': 10.738043785095215, 'learning_rate': 7.214611872146119e-06, 'epoch': 1.92} {'eval_loss': 1.0946319103240967, 'eval_runtime': 14.7779, 'eval_samples_per_second': 715.256, 'eval_steps_per_second': 44.729, 'epoch': 2.0} {'loss': 0.9376, 'grad_norm': 22.791458129882812, 'learning_rate': 6.605783866057839e-06, 'epoch': 2.01} {'loss': 0.7745, 'grad_norm': 15.398698806762695, 'learning_rate': 5.996955859969558e-06, 'epoch': 2.1} {'loss': 0.7458, 'grad_norm': 17.4672908782959, 'learning_rate': 5.388127853881279e-06, 'epoch': 2.19} {'loss': 0.7636, 'grad_norm': 13.833612442016602, 'learning_rate': 4.779299847792998e-06, 'epoch': 2.28} {'loss': 0.7803, 'grad_norm': 11.179983139038086, 'learning_rate': 4.170471841704719e-06, 'epoch': 2.37} {'loss': 0.7666, 'grad_norm': 9.601215362548828, 'learning_rate': 3.5616438356164386e-06, 'epoch': 2.47} {'loss': 0.7784, 'grad_norm': 24.625328063964844, 'learning_rate': 2.9528158295281586e-06, 'epoch': 2.56} {'loss': 0.7389, 'grad_norm': 13.041014671325684, 'learning_rate': 2.343987823439878e-06, 'epoch': 2.65} {'loss': 0.7636, 'grad_norm': 12.822973251342773, 'learning_rate': 1.7351598173515982e-06, 'epoch': 2.74} {'loss': 0.7625, 'grad_norm': 12.254212379455566, 'learning_rate': 1.1263318112633182e-06, 'epoch': 2.83} {'loss': 0.727, 'grad_norm': 8.469372749328613, 'learning_rate': 5.17503805175038e-07, 'epoch': 2.92} {'eval_loss': 1.1390912532806396, 'eval_runtime': 14.8303, 'eval_samples_per_second': 712.731, 'eval_steps_per_second': 44.571, 'epoch': 3.0} {'train_runtime': 1106.5639, 'train_samples_per_second': 237.489, 'train_steps_per_second': 14.843, 'train_loss': 1.0775353780946775, 'epoch': 3.0} 100%|██████████████████████████████████████████████| 16425/16425 [18:26<00:00, 14.84it/s] |
即使使用性能不错的 GPU,这仍然需要一些时间来运行。然而,您正在对新数据集进行预训练模型的微调。这比从头开始训练要快得多,也容易得多。
训练完成后,您可以在其他项目中通过以下路径加载模型:
1 2 3 4 5 6 |
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering model_path = "./fine-tuned-distilbert-squad" tokenizer = DistilBertTokenizerFast.from_pretrained(model_path) model = DistilBertForQuestionAnswering.from_pretrained(model_path) ... |
请确保 model_path
是您项目中用于查找已保存模型文件的正确路径。
进一步阅读
以下是本帖中使用到的类和方法的文档链接:
- DistilBERT on Hugging Face transformer documentation
- DistilBertForQuestionAnswering on Hugging Face transformer documentation
- Trainer on Hugging Face transformer documentation
总结
在本帖中,您学习了如何为自定义问答任务微调 DistilBERT。即使 DistilBERT 和问答被用作示例,您也可以将相同的流程应用于其他模型和任务。特别是,您学习了:
- 如何准备用于训练的数据集
- 如何使用 transformers 库中的
Trainer
接口来训练或微调模型
暂无评论。