问答(Question Answering)是一项至关重要的自然语言处理任务,它使机器能够通过从给定的上下文中提取相关信息来理解和响应人类的问题。DistilBERT,作为 BERT 的蒸馏版本,在构建问答系统时提供了性能和计算效率之间的绝佳平衡。
在本教程中,您将学习如何使用 DistilBERT 和 transformers
库构建一个强大的问答(Q&A)系统。您将学习从基本实现到高级功能的方方面面。特别是,您将学习:
- 如何使用 DistilBERT 实现基本的问答系统
- 用于提高答案质量的高级技术
通过我的书籍《Hugging Face Transformers中的NLP》,快速启动您的项目。它提供了带有工作代码的自学教程。
让我们开始吧。

使用 DistilBERT 和 Transformers 构建问答系统
照片来源:Ana Municio。部分权利保留。
概述
这篇文章分为三个部分;它们是
- 构建一个简单的问答系统
- 处理长上下文
- 构建专家系统
构建一个简单的问答系统
问答系统不仅仅是将问题抛给模型并获得答案。您希望答案准确且有充分的依据。做到这一点的方法是提供一个“上下文”来查找答案。虽然这可以防止模型回答开放式问题,但也能防止它**凭空捏造**答案。能够完成此任务的模型将能够理解问题和上下文,这比仅仅一个语言模型要强大得多。
能够执行此任务的模型是 BERT。下面,您将使用 DistilBERT 模型构建一个简单的问答系统。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering, pipeline device = "cuda" if torch.cuda.is_available() else "cpu" model_name = "distilbert-base-uncased-distilled-squad" tokenizer = DistilBertTokenizer.from_pretrained(model_name) model = DistilBertForQuestionAnswering.from_pretrained(model_name) qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device) max_answer_length = 50 top_k = 3 question = "What is the capital of France?" context = "France is a country in Western Europe. Its capital is Paris, which is known for its art, fashion, gastronomy and culture." result = qa_pipeline(question=question, context=context, max_answer_len=max_answer_length, top_k=top_k) print(f"Question: {question}") print(f"Context: {context}") print(result) |
您将看到的输出是:
1 2 3 4 5 6 |
已将设备设置为使用 CPU 问题:法国的首都是哪里? 上下文:法国是西欧的一个国家。它的首都是巴黎,以其艺术、时尚、美食和文化而闻名。 [{'score': 0.9776948690414429, 'start': 54, 'end': 59, 'answer': 'Paris'}, {'score': 0.017595181241631508, 'start': 54, 'end': 60, 'answer': 'Paris,'}, {'score': 0.0026904228143393993, 'start': 39, 'end': 59, 'answer': 'Its capital is Paris'}] |
我们使用的模型是 distilbert-base-uncased-distilled-squad
,这是一个使用 SQuAD 数据集微调的 DistilBERT 模型。它是一个“不区分大小写”的模型,这意味着它将输入视为不区分大小写。这是一个经过微调的模型,可以在知识蒸馏方面表现更好。因此,它特别适用于需要理解问题和上下文的问答任务。
要使用它,您创建了一个使用 transformers
库的 pipeline
。您请求它成为一个问答管道,但指定了要使用的模型和分词器,而不是让 pipeline()
函数为您选择一个。
当您调用管道时,您会提供问题和上下文。模型将在上下文中找到答案并返回答案。但是,它返回的不是一个简单的答案,而是答案在上下文中出现的位置以及答案的分数(介于 0 和 1 之间)。由于 top_k
设置为 3,因此返回了三个这样的答案。
从输出中,您可以发现得分最高的答案只是“Paris”(在上下文字符串中的字符位置 54 到 59),但其他答案也并非错误,只是表述方式不同。您可以修改上面的代码,根据分数选择最佳答案。
处理长上下文
这个简单的问答系统的问题在于它只能处理短上下文。模型对它能接受的最大序列长度有限制,对于这个特定模型是 512 个 token。
通常,这个限制的问题不在于问题,而在于上下文,因为您通常有一大段文本作为背景信息,而问题是您想从中找到答案的单个句子。要解决这个问题,您可以“分块”,即,将长上下文字符串拆分成更小的块,然后逐个馈送给问答模型。您应该重复使用问题,但遍历不同的块来查找答案。
使用 top_k=3
,您可以期望从每个块中获得 3 个答案。由于每个答案都有一个分数,您可以简单地选择得分最高的答案。您也可以在找到最佳答案之前丢弃得分低的答案。这样,您就可以判断上下文是否提供了足够的信息来回答问题。
让我们看看如何实现这一点。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import time from dataclasses import dataclass import torch from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering, pipeline @dataclass class QAConfig: """QA 设置的配置""" max_sequence_length: int = 512 max_answer_length: int = 50 top_k: int = 3 threshold: float = 0.5 class QASystem: """带分块的问答系统""" def __init__(self, model_name="distilbert-base-uncased-distilled-squad", device=None): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.tokenizer = DistilBertTokenizer.from_pretrained(model_name) self.model = DistilBertForQuestionAnswering.from_pretrained(model_name) # 初始化管道用于简单查询和答案缓存 self.qa_pipeline = pipeline("question-answering", model=self.model, tokenizer=self.tokenizer, device=self.device) self.answer_cache = {} def preprocess_context(self, context, max_length=512): """将长上下文分割成小于 max_length 的块""" chunks = [] current_chunk = [] current_length = 0 for word in context.split(): if current_length + 1 + len(word) > max_length: chunks.append(" ".join(current_chunk)) current_chunk = [word] current_length = len(word) else: current_chunk.append(word) current_length += 1 + len(word) # 长度为空格 + 单词 # 添加最后一个块(如果非空) if current_chunk: chunks.append(" ".join(current_chunk)) return chunks def get_answer(self, question, context, config): """带有置信度的答案获取""" # 检查缓存 cache_key = (question, context) if cache_key in self.answer_cache: return self.answer_cache[cache_key] # 预处理上下文为块 context_chunks = self.preprocess_context(context, config.max_sequence_length) # 从所有块中获取答案 answers = [] for chunk in context_chunks: result = self.qa_pipeline(question=question, context=chunk, max_answer_len=config.max_answer_length, top_k=config.top_k) assert isinstance(result, list) for answer in result: if answer["score"] >= config.threshold: answers.append(answer) # 返回最佳答案或指示未找到答案 if answers: best_answer = max(answers, key=lambda x: x["score"]) result = { "answer": best_answer["answer"], "confidence": best_answer["score"], } else: result = { "answer": "未找到答案", "confidence": 0.0, } # 缓存结果 self.answer_cache[cache_key] = result return result config = QAConfig(max_sequence_length=512, max_answer_length=50, threshold=0.5) qa_system = QASystem() context = """ Python 编程语言由 Guido van Rossum 创建,并于 1991 年发布。 Python 以其简单的语法和可读性而闻名。它已成为最受欢迎的 编程语言之一,尤其是在数据科学和机器学习等领域。 该语言由 Python 监督委员会维护,并由大型 社区贡献者开发。 """ questions = [ "谁创建了 Python?", "Python 何时发布?", "为什么 Python 如此受欢迎?", "Python 以什么闻名?" ] for question in questions: start_time = time.time() answer = qa_system.get_answer(question, context, config) duration = time.time() - start_time print(f"问题: {question}") print(f"答案: {answer['answer']}") print(f"置信度: {answer['confidence']:.2f}") print(f"耗时: {duration:.2f}s") print("-" * 50) |
这会将工作流程封装到一个类中,使其更易于使用。您将问题和上下文传递给 get_answer()
方法,它将返回得分最高的答案。
在 get_answer()
方法中,如果答案已在缓存中,它将立即返回。否则,它将通过空格分割上下文,将每个块保持在长度限制以下。然后,将每个块与问题进行匹配,从问答模型中获取答案(及分数)。只有得分高于阈值的答案才被认为是有效的。然后选择最佳答案。可能找不到得分足够高的答案。在这种情况下,您将其标记为“未找到答案”。
为了方便语法,使用的参数存储在 dataclass
对象中。请注意,它将 max_sequence_length
设置为 512。这是一个保守的选择,因为模型可以处理多达 512 个 token,大约相当于 1500 个字符。但是,设置较低的序列长度可以帮助模型更有效地运行,因为 Transformer 模型的时空复杂性与序列长度呈二次方关系。
此代码的输出是
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
设备设置为使用 cuda 问题:谁创建了 Python? 答案:Guido van Rossum 置信度:1.00 持续时间:0.10 秒 -------------------------------------------------- 问题:Python 是何时发布的? 答案:1991 置信度:0.98 持续时间:0.00 秒 -------------------------------------------------- 问题:为什么 Python 如此受欢迎? 答案:未找到答案 置信度:0.00 持续时间:0.00 秒 -------------------------------------------------- 问题:Python 以什么而闻名? 答案:未找到答案 置信度:0.00 持续时间:0.00 秒 -------------------------------------------------- |
您可能会注意到,上述实现可能存在一个问题,即一个块被分割在一个句子中间,而最合适的答案可能就在其中。在这种情况下,您可能会发现问答模型找不到答案,或者返回了一个次优的答案。这是 preprocess_context()
方法算法中的一个问题。您可以考虑使用更长的块大小或创建带有重叠词的块。您可以尝试将其作为一项练习来实现。
构建专家系统
有了上述问答系统作为构建块,您可以自动化构建问题上下文的过程。通过一个可以用作问答上下文的文档数据库,您可以构建一个专家系统,该系统可以回答各种各样的问题。
构建一个好的专家系统是一项复杂的任务,涉及许多考虑因素。但是,高级框架并不难理解。这与 RAG(检索增强生成)的想法相似,其中上下文是从文档数据库中检索的,然后由模型生成答案。一个关键的组成部分是能够检索与问题最相关上下文的数据库。让我们看看如何构建一个。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import collections class ContextManager: def __init__(self, max_contexts=10): self.contexts = collections.OrderedDict() self.max_contexts = max_contexts def add_context(self, context_id, context): """Add context with automatic cleanup""" if len(self.contexts) >= self.max_contexts: self.contexts.popitem(last=False) self.contexts[context_id] = context def get_context(self, context_id): """Get context by ID""" return self.contexts.get(context_id) def search_relevant_context(self, question, top_k=3): """Search for relevant contexts based on relevance score""" relevant_contexts = [] for context_id, context in self.contexts.items(): relevance_score = self._calculate_relevance(question, context) relevant_contexts.append((relevance_score, context_id)) return sorted(relevant_contexts, reverse=True)[:top_k] def _calculate_relevance(self, question, context): """Calculate relevance score between question and context. This is a simple counting the number of overlap words """ question_words = set(question.lower().split()) context_words = set(context.lower().split()) return len(question_words.intersection(context_words)) / len(question_words) |
这个类名为 ContextManager
。您可以使用上下文 ID 向其中添加一段文本,上下文管理器将只保留有限数量的上下文。您可以使用上下文 ID 来检索文本。但最重要的函数是 search_relevant_context()
,它将根据提供的问句搜索最相关的上下文。您可以使用不同的算法来计算相关性得分。这里使用了一个简单的算法,即计算重叠词的数量,或者使用 Jaccard 相似度。
有了这个类,您可以构建一个专家系统,该系统可以回答各种各样的问题。下面是如何使用它的示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
... context_manager = ContextManager(max_contexts=10) context_manager.add_context("python", """ Python 是一种高级解释型编程语言,由 Guido van Rossum 创建,于 1991 年发布。 Python 的设计理念强调代码的可读性,并显著使用了空白字符。 Python 具有动态类型系统和自动内存管理,并支持多种编程 范式,包括结构化、面向对象和函数式编程。 """) context_manager.add_context("machine_learning", """ 机器学习是一个研究领域,它使计算机能够在没有 明确编程的情况下进行学习。它是人工智能的一个分支,其基本思想是系统 可以从数据中学习、识别模式并做出决策,而最少的人工干预。 """) config = QAConfig(max_sequence_length=512, max_answer_length=50, threshold=0.5) qa_system = QASystem() question = "Who created Python?" relevant_contexts = context_manager.search_relevant_context(question, top_k=1) if relevant_contexts: relevance, context_id = relevant_contexts[0] context = context_manager.get_context(context_id) print(f"Question: {question}") print(f"Most relevant context: {context_id} (relevance: {relevance:.2f})") print(context) answer = qa_system.get_answer(question, context, config) print(f"答案: {answer['answer']}") print(f"置信度: {answer['confidence']:.2f}") else: print("No relevant context found.") |
您首先将一些上下文添加到上下文管理器中。根据所需的上下文管理器最大大小,您可以向系统中添加大量文本。然后,您可以根据问题搜索最相关的上下文。然后,您可以像前一节一样,将问题和上下文馈送到问答系统中以获取答案,其中块状化和迭代查找最佳答案是在后台完成的。
您可以扩展此功能,尝试使用前几个上下文来查找更广泛的上下文中的答案。这是一种避免在上下文中找不到最佳答案的简单方法。但是,如果您有更好的方法来评估上下文的相关性,例如使用神经网络模型来计算相关性得分,您可能不需要尝试许多上下文。
上述输出将是
1 2 3 4 5 6 7 8 9 10 |
问题:谁创建了 Python? 最相关的上下文:python (相关性:0.33) Python 是一种高级解释型编程语言,由 Guido van Rossum 创建,于 1991 年发布。 Python 的设计理念强调代码的可读性,并显著使用了空白字符。 Python 具有动态类型系统和自动内存管理,并支持多种编程 范式,包括结构化、面向对象和函数式编程。 答案:Guido van Rossum 置信度:1.00 |
总而言之,以下是完整的代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import collections import time from dataclasses import dataclass import torch from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering, pipeline @dataclass class QAConfig: """QA 设置的配置""" max_sequence_length: int = 512 max_answer_length: int = 50 top_k: int = 3 threshold: float = 0.5 class QASystem: """带分块的问答系统""" def __init__(self, model_name="distilbert-base-uncased-distilled-squad", device=None): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.tokenizer = DistilBertTokenizer.from_pretrained(model_name) self.model = DistilBertForQuestionAnswering.from_pretrained(model_name) # 初始化管道用于简单查询和答案缓存 self.qa_pipeline = pipeline("question-answering", model=self.model, tokenizer=self.tokenizer, device=self.device) self.answer_cache = {} def preprocess_context(self, context, max_length=512): """将长上下文分割成小于 max_length 的块""" chunks = [] current_chunk = [] current_length = 0 for word in context.split(): if current_length + 1 + len(word) > max_length: chunks.append(" ".join(current_chunk)) current_chunk = [word] current_length = len(word) else: current_chunk.append(word) current_length += 1 + len(word) # 长度为空格 + 单词 # 添加最后一个块(如果非空) if current_chunk: chunks.append(" ".join(current_chunk)) return chunks def get_answer(self, question, context, config): """带有置信度的答案获取""" # 检查缓存 cache_key = (question, context) if cache_key in self.answer_cache: return self.answer_cache[cache_key] # 预处理上下文为块 context_chunks = self.preprocess_context(context, config.max_sequence_length) # 从所有块中获取答案 answers = [] for chunk in context_chunks: result = self.qa_pipeline(question=question, context=chunk, max_answer_len=config.max_answer_length, top_k=config.top_k) assert isinstance(result, list) for answer in result: if answer["score"] >= config.threshold: answers.append(answer) # 返回最佳答案或指示未找到答案 if answers: best_answer = max(answers, key=lambda x: x["score"]) result = { "answer": best_answer["answer"], "confidence": best_answer["score"], } else: result = { "answer": "未找到答案", "confidence": 0.0, } # 缓存结果 self.answer_cache[cache_key] = result return result class ContextManager: def __init__(self, max_contexts=10): self.contexts = collections.OrderedDict() self.max_contexts = max_contexts def add_context(self, context_id, context): """Add context with automatic cleanup""" if len(self.contexts) >= self.max_contexts: self.contexts.popitem(last=False) self.contexts[context_id] = context def get_context(self, context_id): """Get context by ID""" return self.contexts.get(context_id) def search_relevant_context(self, question, top_k=3): """Search for relevant contexts based on relevance score""" relevant_contexts = [] for context_id, context in self.contexts.items(): relevance_score = self._calculate_relevance(question, context) relevant_contexts.append((relevance_score, context_id)) return sorted(relevant_contexts, reverse=True)[:top_k] def _calculate_relevance(self, question, context): """Calculate relevance score between question and context. This is a simple counting the number of overlap words """ question_words = set(question.lower().split()) context_words = set(context.lower().split()) return len(question_words.intersection(context_words)) / len(question_words) context_manager = ContextManager(max_contexts=10) context_manager.add_context("python", """ Python 是一种高级解释型编程语言,由 Guido van Rossum 创建,于 1991 年发布。 Python 的设计理念强调代码的可读性,并显著使用了空白字符。 Python 具有动态类型系统和自动内存管理,并支持多种编程 范式,包括结构化、面向对象和函数式编程。 """) context_manager.add_context("machine_learning", """ 机器学习是一个研究领域,它使计算机能够在没有 明确编程的情况下进行学习。它是人工智能的一个分支,其基本思想是系统 可以从数据中学习、识别模式并做出决策,而最少的人工干预。 """) config = QAConfig(max_sequence_length=512, max_answer_length=50, threshold=0.5) qa_system = QASystem() question = "Who created Python?" relevant_contexts = context_manager.search_relevant_context(question, top_k=1) if relevant_contexts: relevance, context_id = relevant_contexts[0] context = context_manager.get_context(context_id) print(f"Question: {question}") print(f"Most relevant context: {context_id} (relevance: {relevance:.2f})") print(context) answer = qa_system.get_answer(question, context, config) print(f"答案: {answer['answer']}") print(f"置信度: {answer['confidence']:.2f}") else: print("No relevant context found.") |
进一步阅读
以下是一些您可能会觉得有用的资源:
- DistilBERT 模型 本教程中使用
- 什么是 RAG(检索增强生成)?
- 问答 Pipeline 在 transformers 文档中
总结
在本教程中,您使用 DistilBERT 构建了一个全面的问答系统。特别是,您学习了如何
- 使用 transformers 中的 pipeline 函数构建问答系统
- 通过分块处理大型上下文
- 使用上下文管理器管理上下文并在其之上构建专家系统
暂无评论。