使用父文档检索器(PDR)构建 RAG 的最佳实践,并使用 PDR 进行元数据过滤。

你是否正在寻找使用 CSV/Excel 文件构建高效的 RAG?那么这篇文章绝对适合你 :)

在本文中,我将谈论构建高效的 RAG 的最佳实践,尤其是当您尝试在 LangChain 中使用父文档检索器,并将 csv/excel 文件与多列作为数据集时。

顺便问一下,为什么特别是父文件检索器?

哈哈哈.... 根据我的经验,运行多次实验后,当您有冗长的文档时,最好使用父文档检索器。

 让我们开始吧!

首先,我们有一个 csv/excel 数据,这意味着虽然一个列可以作为回答问题的来源,但可能还有许多其他的列,比如国家名称、日期和其他列,可以作为元数据来改进答案。

现在让我们首先看看改善答案的不同技巧,然后再谈谈如何使用元数据。

  1. 创建自定义的 CSV 加载器:创建一个类似于 Lang chain 的 CSVLoader 的自定义 CSV 加载器,进行少量定制。这将帮助我们定义哪些列需要被视为“页面内容”,哪些列需要被视为“元数据”。

    您可以使用此自定义的 CSVLoader 来加载您的 csv 文件。让我举个例子:
import csv
from typing import Dict, List, Optional
from langchain.document_loaders.base import BaseLoader
from langchain.docstore.document import Document


class CSVLoader(BaseLoader):
    """Loads a CSV file into a list of documents.

    Each document represents one row of the CSV file. Every row is converted into a
    key/value pair and outputted to a new line in the document's page_content.

    The source for each document loaded from csv is set to the value of the
    `file_path` argument for all doucments by default.
    You can override this by setting the `source_column` argument to the
    name of a column in the CSV file.
    The source of each document will then be set to the value of the column
    with the name specified in `source_column`.

    Output Example:
        .. code-block:: txt

            column1: value1
            column2: value2
            column3: value3
    """

    def __init__(
        self,
        file_path: str,
        source_column: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,   # < ADDED
        csv_args: Optional[Dict] = None,
        encoding: Optional[str] = None,
    ):
        self.file_path = file_path
        self.source_column = source_column
        self.encoding = encoding
        self.csv_args = csv_args or {}
        self.metadata_columns = metadata_columns        # < ADDED

    def load(self) -> List[Document]:
        """Load data into document objects."""

        docs = []
        with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
            csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
            for i, row in enumerate(csv_reader):
                content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items() if k == "Source column name")
                try:
                    source = (
                        row[self.source_column]
                        if self.source_column is not None
                        else self.file_path
                    )
                except KeyError:
                    raise ValueError(
                        f"Source column '{self.source_column}' not found in CSV file."
                    )
                metadata = {"source": source, "row": i}
                # ADDED TO SAVE METADATA
                if self.metadata_columns:
                    for k, v in row.items():
                        if k in self.metadata_columns:
                            metadata[k] = v
                # END OF ADDED CODE
                doc = Document(page_content=content, metadata=metadata)
                docs.append(doc)

        return docs

2. 两个分割器:通常我们只使用一个文本分割器将长文本分割成多个更小的块,但在父文档检索器的情况下,我们使用两个分割器。一个用于具有更多上下文的较大块(让我们称这些较大块为父级),另一个用于具有更好语义意义的较小块(让我们称这些较小块为子级)。专业提示✨:在创建子文档时,玩转块大小:块大小在确定 RAG 系统如何生成答案方面起着重要作用。

我建议最初尝试不同的块大小,直到您觉得创建的子文档之间有较少的重叠,并且生成的答案符合您的期望。

parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)

child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)

3. 存储父级和子级块:存储嵌入✨:初始化父文档检索器并将文档添加到检索器是在运行时发生的事情,不仅需要长时间运行,而且每次查询相同的内容都需要花费成本来创建相同的嵌入,您可以使用下面的代码行来存储已创建的嵌入。

vectorstore.persist()

 使用元数据:✨

现在来到有趣的话题。默认情况下,尽管根据文档中的元数据参数应用了元数据,但我们无法使用 LangChain 的父检索器来使用元数据,应用它不会根据提供的过滤器筛选检索到的文档。

因此,我们需要编写一个自定义类,根据元数据过滤相关文档。这有两种可能的方法:

  1. 从向量存储中获取检索到的相关文件,然后在其上应用元数据过滤器。
  2. 在执行向量搜索时应用元数据过滤,并仅返回唯一文档。

第二个选项是解决这个问题的正确方法。因此,始终创建一个自定义类,并使用自定义函数根据元数据过滤器检索最相关的文档。

通过这种方式,您不仅可以在搜索时应用筛选器,还可以增加检索到的文档数量(默认情况下只能获取 4 个相关文档)。让我举个例子,说明如何做到这一点。

我正在考虑将国家名称和产品名称作为我的数据中的两列,这些列如上所述已加载为元数据(创建自定义 CSVLoader)。

class ParentDocumentRetriever(BaseRetriever):
    vectorstore: VectorStore
    docstore: BaseStore[str, Document]
    id_key: str = "doc_id"
    search_kwargs: dict = Field(default_factory=dict)
    child_splitter: TextSplitter
    parent_splitter: Optional[TextSplitter] = None

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        metadata_filter: Optional[Dict[str, Any]] = None
    ) -> List[Document]:
        all_results = []
        if metadata_filter:
            # Iterate over each key-value pair in the metadata_filter
            unique_ids = set()

            # Iterate over each key-value pair in the metadata_filter
            for key, value in metadata_filter.items():
                # Perform the similarity search for the current key-value pair
                sub_docs = self.vectorstore.similarity_search(query, k=10, filter={key: value}, **self.search_kwargs)
                ids = [d.metadata[self.id_key] for d in sub_docs]

                # Add unique document IDs to the set
                unique_ids.update(ids)

            # Retrieve documents from the docstore based on the unique IDs
            all_results = self.docstore.mget(list(unique_ids))
            print("Filtering documents with metadata:", metadata_filter)
            filtered_documents = []

            for document in all_results:
                if document is not None:
                    match = all(
                        any(value in document.metadata.get(key, []) for value in values)
                        if isinstance(document.metadata.get(key), list)
                        else document.metadata.get(key) in values
                        for key, values in metadata_filter.items() if values
                        )
                if match:
                    filtered_documents.append(document)

            docs = filtered_documents
        else:
            sub_docs = self.vectorstore.similarity_search(query, k=10, **self.search_kwargs)
            ids = []
            for d in sub_docs:
                if d.metadata[self.id_key] not in ids:
                    ids.append(d.metadata[self.id_key])
            docs = self.docstore.mget(ids)

        return [d for d in docs if d is not None]

    def add_documents(
        self,
        documents: List[Document],
        ids: Optional[List[str]] = None,
        add_to_docstore: bool = True,
    ) -> None:
        if self.parent_splitter is not None:
            documents = self.parent_splitter.split_documents(documents)
        if ids is None:
            doc_ids = [str(uuid.uuid4()) for _ in documents]
            if not add_to_docstore:
                raise ValueError(
                    "If ids are not passed in, `add_to_docstore` MUST be True"
                )
        else:
            if len(documents) != len(ids):
                raise ValueError(
                    "Got uneven list of documents and ids. "
                    "If `ids` is provided, should be same length as `documents`."
                )
            doc_ids = ids

        docs = []
        full_docs = []
        for i, doc in enumerate(documents):
            _id = doc_ids[i]
            sub_docs = self.child_splitter.split_documents([doc])
            for _doc in sub_docs:
                _doc.metadata[self.id_key] = _id
            docs.extend(sub_docs)
            full_docs.append((_id, doc))
        self.vectorstore.add_documents(docs)
        if add_to_docstore:
            self.docstore.mset(full_docs)

让我解释一下当您调用检索器的 get_relavant_documents 函数时会发生什么:

对于您提供的每个元数据筛选器(国家和产品名称)

parent_retriever = ParentDocumentRetriever(vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
    )
parent_retriever.get_relevant_documents(query, metadata_filter={"Country":"Canada","ProductName":"Sample"})

矢量存储将在单独应用过滤器后执行相似性搜索,将它们组合起来,然后过滤出唯一的文档,这些唯一的文档可用于查询,即矢量搜索将发生 n 次,其中 n 是 metadata_filter 字典中键值对的数量,然后过滤出唯一的文档。

一旦您从检索器中获得了最相关的文档,就该是将它们链在一起的时候了:

context = parent_retriever.get_relevant_documents(query, metadata_filter={"Country":"Canada","ProductName":"Sample"})
response = llm_chain({"context": context, "question": query})

生成的响应将根据要求应用正确的元数据过滤器,符合预期。

希望这对你有帮助! :)

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2024-06-12 21:06:06       17 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-12 21:06:06       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-12 21:06:06       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-12 21:06:06       18 阅读

热门阅读

  1. 第10天:数据库模型(基础)

    2024-06-12 21:06:06       7 阅读
  2. 短剧推荐2024-03

    2024-06-12 21:06:06       6 阅读
  3. 百度地图瓦片下载地址

    2024-06-12 21:06:06       7 阅读
  4. GPT-4o的综合评估与前景展望

    2024-06-12 21:06:06       6 阅读
  5. 全面解析C++对象的向上和向下类型转换”

    2024-06-12 21:06:06       7 阅读
  6. Web前端开发海报:揭示前端设计的魅力与技巧

    2024-06-12 21:06:06       10 阅读
  7. Anconda环境迁移

    2024-06-12 21:06:06       7 阅读