你是否正在寻找使用 CSV/Excel 文件构建高效的 RAG?那么这篇文章绝对适合你 :)
在本文中,我将谈论构建高效的 RAG 的最佳实践,尤其是当您尝试在 LangChain 中使用父文档检索器,并将 csv/excel 文件与多列作为数据集时。
顺便问一下,为什么特别是父文件检索器?
哈哈哈.... 根据我的经验,运行多次实验后,当您有冗长的文档时,最好使用父文档检索器。
让我们开始吧!
首先,我们有一个 csv/excel 数据,这意味着虽然一个列可以作为回答问题的来源,但可能还有许多其他的列,比如国家名称、日期和其他列,可以作为元数据来改进答案。
现在让我们首先看看改善答案的不同技巧,然后再谈谈如何使用元数据。
- 创建自定义的 CSV 加载器:创建一个类似于 Lang chain 的 CSVLoader 的自定义 CSV 加载器,进行少量定制。这将帮助我们定义哪些列需要被视为“页面内容”,哪些列需要被视为“元数据”。
您可以使用此自定义的 CSVLoader 来加载您的 csv 文件。让我举个例子:
import csv
from typing import Dict, List, Optional
from langchain.document_loaders.base import BaseLoader
from langchain.docstore.document import Document
class CSVLoader(BaseLoader):
"""Loads a CSV file into a list of documents.
Each document represents one row of the CSV file. Every row is converted into a
key/value pair and outputted to a new line in the document's page_content.
The source for each document loaded from csv is set to the value of the
`file_path` argument for all doucments by default.
You can override this by setting the `source_column` argument to the
name of a column in the CSV file.
The source of each document will then be set to the value of the column
with the name specified in `source_column`.
Output Example:
.. code-block:: txt
column1: value1
column2: value2
column3: value3
"""
def __init__(
self,
file_path: str,
source_column: Optional[str] = None,
metadata_columns: Optional[List[str]] = None, # < ADDED
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
):
self.file_path = file_path
self.source_column = source_column
self.encoding = encoding
self.csv_args = csv_args or {}
self.metadata_columns = metadata_columns # < ADDED
def load(self) -> List[Document]:
"""Load data into document objects."""
docs = []
with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items() if k == "Source column name")
try:
source = (
row[self.source_column]
if self.source_column is not None
else self.file_path
)
except KeyError:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
metadata = {"source": source, "row": i}
# ADDED TO SAVE METADATA
if self.metadata_columns:
for k, v in row.items():
if k in self.metadata_columns:
metadata[k] = v
# END OF ADDED CODE
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)
return docs
2. 两个分割器:通常我们只使用一个文本分割器将长文本分割成多个更小的块,但在父文档检索器的情况下,我们使用两个分割器。一个用于具有更多上下文的较大块(让我们称这些较大块为父级),另一个用于具有更好语义意义的较小块(让我们称这些较小块为子级)。专业提示✨:在创建子文档时,玩转块大小:块大小在确定 RAG 系统如何生成答案方面起着重要作用。
我建议最初尝试不同的块大小,直到您觉得创建的子文档之间有较少的重叠,并且生成的答案符合您的期望。
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
3. 存储父级和子级块:存储嵌入✨:初始化父文档检索器并将文档添加到检索器是在运行时发生的事情,不仅需要长时间运行,而且每次查询相同的内容都需要花费成本来创建相同的嵌入,您可以使用下面的代码行来存储已创建的嵌入。
vectorstore.persist()
使用元数据:✨
现在来到有趣的话题。默认情况下,尽管根据文档中的元数据参数应用了元数据,但我们无法使用 LangChain 的父检索器来使用元数据,应用它不会根据提供的过滤器筛选检索到的文档。
因此,我们需要编写一个自定义类,根据元数据过滤相关文档。这有两种可能的方法:
- 从向量存储中获取检索到的相关文件,然后在其上应用元数据过滤器。
- 在执行向量搜索时应用元数据过滤,并仅返回唯一文档。
第二个选项是解决这个问题的正确方法。因此,始终创建一个自定义类,并使用自定义函数根据元数据过滤器检索最相关的文档。
通过这种方式,您不仅可以在搜索时应用筛选器,还可以增加检索到的文档数量(默认情况下只能获取 4 个相关文档)。让我举个例子,说明如何做到这一点。
我正在考虑将国家名称和产品名称作为我的数据中的两列,这些列如上所述已加载为元数据(创建自定义 CSVLoader)。
class ParentDocumentRetriever(BaseRetriever):
vectorstore: VectorStore
docstore: BaseStore[str, Document]
id_key: str = "doc_id"
search_kwargs: dict = Field(default_factory=dict)
child_splitter: TextSplitter
parent_splitter: Optional[TextSplitter] = None
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
metadata_filter: Optional[Dict[str, Any]] = None
) -> List[Document]:
all_results = []
if metadata_filter:
# Iterate over each key-value pair in the metadata_filter
unique_ids = set()
# Iterate over each key-value pair in the metadata_filter
for key, value in metadata_filter.items():
# Perform the similarity search for the current key-value pair
sub_docs = self.vectorstore.similarity_search(query, k=10, filter={key: value}, **self.search_kwargs)
ids = [d.metadata[self.id_key] for d in sub_docs]
# Add unique document IDs to the set
unique_ids.update(ids)
# Retrieve documents from the docstore based on the unique IDs
all_results = self.docstore.mget(list(unique_ids))
print("Filtering documents with metadata:", metadata_filter)
filtered_documents = []
for document in all_results:
if document is not None:
match = all(
any(value in document.metadata.get(key, []) for value in values)
if isinstance(document.metadata.get(key), list)
else document.metadata.get(key) in values
for key, values in metadata_filter.items() if values
)
if match:
filtered_documents.append(document)
docs = filtered_documents
else:
sub_docs = self.vectorstore.similarity_search(query, k=10, **self.search_kwargs)
ids = []
for d in sub_docs:
if d.metadata[self.id_key] not in ids:
ids.append(d.metadata[self.id_key])
docs = self.docstore.mget(ids)
return [d for d in docs if d is not None]
def add_documents(
self,
documents: List[Document],
ids: Optional[List[str]] = None,
add_to_docstore: bool = True,
) -> None:
if self.parent_splitter is not None:
documents = self.parent_splitter.split_documents(documents)
if ids is None:
doc_ids = [str(uuid.uuid4()) for _ in documents]
if not add_to_docstore:
raise ValueError(
"If ids are not passed in, `add_to_docstore` MUST be True"
)
else:
if len(documents) != len(ids):
raise ValueError(
"Got uneven list of documents and ids. "
"If `ids` is provided, should be same length as `documents`."
)
doc_ids = ids
docs = []
full_docs = []
for i, doc in enumerate(documents):
_id = doc_ids[i]
sub_docs = self.child_splitter.split_documents([doc])
for _doc in sub_docs:
_doc.metadata[self.id_key] = _id
docs.extend(sub_docs)
full_docs.append((_id, doc))
self.vectorstore.add_documents(docs)
if add_to_docstore:
self.docstore.mset(full_docs)
让我解释一下当您调用检索器的 get_relavant_documents 函数时会发生什么:
对于您提供的每个元数据筛选器(国家和产品名称)
parent_retriever = ParentDocumentRetriever(vectorstore=vectorstore,
docstore=store,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
)
parent_retriever.get_relevant_documents(query, metadata_filter={"Country":"Canada","ProductName":"Sample"})
矢量存储将在单独应用过滤器后执行相似性搜索,将它们组合起来,然后过滤出唯一的文档,这些唯一的文档可用于查询,即矢量搜索将发生 n 次,其中 n 是 metadata_filter 字典中键值对的数量,然后过滤出唯一的文档。
一旦您从检索器中获得了最相关的文档,就该是将它们链在一起的时候了:
context = parent_retriever.get_relevant_documents(query, metadata_filter={"Country":"Canada","ProductName":"Sample"})
response = llm_chain({"context": context, "question": query})
生成的响应将根据要求应用正确的元数据过滤器,符合预期。
希望这对你有帮助! :)