在 transformers 库中,Auto Classes 是一项关键设计,它允许您使用预训练模型,而无需担心底层模型架构。这使得您的代码更简洁、更易于维护。例如,您只需更改模型名称即可轻松切换不同的模型架构;即使是运行模型的代码也大不相同。在本博文中,您将了解 Auto Classes 的工作原理以及如何在代码中使用它们。
通过我的书籍《Hugging Face Transformers中的NLP》,快速启动您的项目。它提供了带有工作代码的自学教程。
让我们开始吧!

在 Transformers 库中使用 Auto Classes
图片来源:Erik Mclean。部分权利保留。
概述
这篇博文分为三部分;它们是:
- 什么是 Auto Classes
- 如何使用 Auto Classes
- Auto Classes 的局限性
什么是 Auto Classes
transformers 库中没有名为“AutoClass”的类。相反,有几个类以“Auto”前缀命名。
在自然语言处理的 Transformer 模型中,您将从一些文本开始。您需要将文本转换为 token,然后将 token 转换为 token ID。然后将 token ID 输入模型以获得输出。输出应转换回文本。
在此过程中,您将需要一个分词器和一个主模型。根据任务,例如文本分类或问答,您可以使用同一模型的不同变体。它们核心相同,但使用不同的“头部”来执行任务。
鉴于工作流程在高级别上是标准化的,唯一的区别在于模型应如何精确地操作。该库中有数十种模型架构。您不必了解所有这些细节。但如果您了解,您可以编写如下代码:
1 2 3 4 5 6 7 8 9 10 11 12 |
import torch from transformers import DistilBertTokenizer, DistilBertForSequenceClassification model_name = "KernAI/stock-news-distilbert" tokenizer = DistilBertTokenizer.from_pretrained(model_name) model = DistilBertForSequenceClassification.from_pretrained(model_name) text = "Machine Learning Mastery is a nice website." inputs = tokenizer(text, return_tensors="pt") with torch.no_grad(): logits = model(**inputs).logits predicted_class_id = logits.argmax().item() |
首先,这不是使用模型的最冗长的方式。在 transformers 库中,您可以定义一个裸露的 DistilBertTokenizer
对象,然后从文件加载词汇表,定义特殊 token,以及其他规则,例如是否强制所有字母小写。其次,创建 DistilBertForSequenceClassification
对象首先应创建一个 config 对象 DistilBertConfig
,该对象定义模型的超参数。然后,您可以从 checkpoint 加载权重。但您可以想象这需要大量工作。
在上文中,您通过使用 from_pretrained()
方法已经简化了工作流程。这是为了从互联网下载预训练模型,其中包含了 config 和相应的分词器参数。但是,上面的代码先设置了模型,然后加载了权重和参数。它假设下载的模型文件与架构兼容。例如,模型可能需要一个名为 hidden_size
的参数,而下载的文件必须不称为 hidden_dim
。
记住每个模型架构的类名并不容易。因此,Auto Classes 的设计就是为了隐藏这种复杂性。
如何使用 Auto Classes
以 DistilBERT 为例,有多种变体。首先,有完全相同的模型的 PyTorch、TensorFlow 和 Flax 实现。其次,DistilBERT 是基础模型的名称。在此基础上,您可以为各种任务添加不同的“头部”。您可以获得
- 基础模型(
DistilBertModel
),它输出原始隐藏状态, - 用于掩码语言模型(
DistilBertForMaskedLM
)的模型,它预测被掩码的 token 应该是什么, - 用于序列分类(
DistilBertForSequenceClassification
)的模型,它用于将整个输入标记为预定义的类别, - 用于问答(
DistilBertForQuestionAnswering
)的模型,它用于从提供的上下文中查找指定问题的答案, - 用于 token 分类(
DistilBertForTokenClassification
)的模型,它用于将每个 token 分类到某个类别, - 用于多项选择任务(
DistilBertForMultipleChoice
)的模型,它比较问题的多个答案并对每个答案的可能性进行评分。
这些都是相同的基本模型,但具有不同的头部。这不是不同变体的唯一列表,因为某些基本模型可能有 DistilBERT 中没有的头部,而某些基本模型可能没有 DistilBERT 拥有的头部。
只要您知道如何为特定任务使用模型,就可以轻松切换到另一个模型。例如,以下代码运行正常,没有任何错误。
1 2 3 4 5 6 7 8 9 10 11 12 |
import torch from transformers import GPT2Tokenizer, OPTForSequenceClassification model_name = "ArthurZ/opt-350m-dummy-sc" tokenizer = GPT2Tokenizer.from_pretrained(model_name) model = OPTForSequenceClassification.from_pretrained(model_name) text = "Machine Learning Mastery is a nice website." inputs = tokenizer(text, return_tensors="pt") with torch.no_grad(): logits = model(**inputs).logits predicted_class_id = logits.argmax().item() |
忽略输出,这段代码只更改了分词器和模型的名称。这是 transformers 库标准化接口的结果。但看看上面的代码:您需要知道以“ArthurZ/opt-350m-dummy-sc”存储的模型使用的是 OPTForSequenceClassification
架构(可能可以从名称中猜出)。您还需要知道分词器是 GPT2Tokenizer
(可能无法从名称中猜出,但可以从文档中找到)。
如果只更改模型名称就可以让代码工作,那就更方便了。这时 Auto Classes 就派上用场了。代码将如下所示:
1 2 3 4 5 6 7 8 9 10 11 12 |
import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification model_name = "ArthurZ/opt-350m-dummy-sc" # 或 "KernAI/stock-news-distilbert" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSequenceClassification.from_pretrained(model_name) text = "Machine Learning Mastery is a nice website." inputs = tokenizer(text, return_tensors="pt") with torch.no_grad(): logits = model(**inputs).logits predicted_class_id = logits.argmax().item() |
您使用了 AutoTokenizer
和 AutoModelForSequenceClassification
。现在,当您更改模型名称时,代码将起作用。这是因为 Auto Classes 会自动下载模型并检查其 config 文件。然后,根据 config 文件中的指定,它将实例化正确的分词器和模型——所有这些都无需您的输入。
请注意,上面的示例使用的是 PyTorch。您要求分词器为您提供 PyTorch 张量,而模型本身也是 PyTorch 的。这是 transformers 库中的默认设置。但如果模型支持,您可以创建一个 TensorFlow/Keras 等效项,只需对代码稍作修改。
1 2 3 4 5 6 7 8 9 10 11 |
import tensorflow as tf from transformers import AutoTokenizer, TFAutoModelForSequenceClassification model_name = "KernAI/stock-news-distilbert" tokenizer = AutoTokenizer.from_pretrained(model_name) model = TFAutoModelForSequenceClassification.from_pretrained(model_name, from_pt=True) text = "Machine Learning Mastery is a nice website." inputs = tokenizer(text, return_tensors="tf") logits = model(**inputs).logits predicted_class_id = tf.math.argmax(logits).numpy() |
您可以尝试使用另一个模型“ArthurZ/opt-350m-dummy-sc”,您应该会看到一个错误。这是因为 OPTForSequenceClassification
类没有对应的 TFOPTForSequenceClassification
。
Auto Classes 的局限性
transformers 库中有许多 Auto Classes。对于 NLP 任务,一些例子是 AutoModel
、AutoModelForCausalLM
、AutoModelForMaskedLM
、AutoModelForSequenceClassification
、AutoModelForQuestionAnswering
、AutoModelForTokenClassification
、AutoModelForMultipleChoice
、AutoModelForTextEncoding
和 AutoModelForNextSentencePrediction
。请注意,其中每一个都针对不同的任务(即基本模型之上的不同头部),并且并非所有模型都支持所有任务。例如,在上一节中,您了解到有 DistilBertForMaskedLM
,因此您可以使用 AutoModelForMaskedLM
和 DistilBERT 模型名称创建一个,但您不能使用 AutoModelForCausalLM
创建 DistilBERT 模型,因为没有 DistilBertForCausalLM
类。
另外,请注意,您会看到以下代码的警告:
1 2 3 4 |
from transformers import AutoModelForSequenceClassification model_name = "distilbert-base-uncased" model = AutoModelForSequenceClassification.from_pretrained(model_name) |
您会看到以下警告:
1 2 |
DistilBertForSequenceClassification 的一些权重未从 distilbert-base-uncased 的模型 checkpoint 初始化,并且是新初始化的:['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight'] 您应该在该下游任务上训练此模型,以便能够将其用于预测和推理。 |
这是因为模型名称“distilbert-base-uncased”仅包含基础模型。它的 config 足以创建 DistilBERT 系列下的所有模型,因为它们的区别仅在于头部。但是,基础模型没有特定头部的权重。当您实例化一个模型并尝试加载权重时,库会发现某些层未初始化,这些层只能使用随机权重作为占位符。这也意味着模型尚未按预期工作。您需要使用自己的数据集训练模型,或者从另一个模型加载权重,例如上一示例中的“KernAI/stock-news-distilbert”。
Auto Classes 的第二个限制是它是一个深度学习模型的包装器。也就是说,它期望一个数值张量并输出一个数值张量。这就是为什么您在上面的示例中需要使用分词器的原因。如果您不需要操作这些张量,而只是将模型用于某个任务,您可以通过使用 pipeline()
函数进一步简化代码。
1 2 3 4 5 6 7 8 9 |
import torch from transformers import pipeline model_name = "KernAI/stock-news-distilbert" classifier = pipeline(model=model_name) text = "Machine Learning Mastery is a nice website." prediction = classifier(text) print(prediction) |
这个示例实际上比上面的任何示例都做得更多。它解释了模型的输出,并为您提供了人类可读的输出。您可以看到它的输出是:
1 |
[{'label': 'positive', 'score': 0.9953118562698364}] |
进一步阅读
以下是一些您可能会发现有用的进一步阅读材料:
总结
在本博文中,您了解了如何在 transformers 库中使用 Auto Classes。它是特定模型类的替代品,因此您可以让库根据模型配置确定要使用的正确类。这使得您可以通过更改名称或路径轻松地在不同模型或 checkpoint 之间切换,而无需进行任何代码更改。使用 Auto Classes 比使用 pipeline API 稍微复杂一些,但它可以避免您需要弄清楚要使用哪个正确类的麻烦。
好文!