Quick Test Notes: Microsoft GraphRAG + Local Models + Gradio

Installation (GraphRAG requires Python 3.10 or newer at the time of writing)

pip install graphrag


mkdir -p ./ragtest/input

# copy your documents into ./ragtest/input/

python -m graphrag.index --init --root ./ragtest
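On success, --init scaffolds the project directory. Roughly (exact contents vary by GraphRAG version):

ragtest/
├── .env            # environment variables, holds GRAPHRAG_API_KEY
├── settings.yaml   # pipeline configuration, edited below
├── prompts/        # default prompt templates (entity_extraction.txt, ...)
└── input/          # your source documents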


Edit settings.yaml

The key changes are in the llm and embeddings sections, which point at local OpenAI-compatible endpoints:


encoding_model: cl100k_base
skip_workflows: []
llm:
  api_key: ${GRAPHRAG_API_KEY}
  type: openai_chat # or azure_openai_chat
  model: qwen2-instruct
  model_supports_json: true # recommended if this is available for your model.
  # max_tokens: 4000
  # request_timeout: 180.0
  api_base: http://192.168.2.2:9997/v1/
  # api_version: 2024-02-15-preview
  # organization: <organization_id>
  # deployment_name: <azure_model_deployment_name>
  # tokens_per_minute: 150_000 # set a leaky bucket throttle
  # requests_per_minute: 10_000 # set a leaky bucket throttle
  # max_retries: 10
  # max_retry_wait: 10.0
  # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
  concurrent_requests: 5 # the number of parallel inflight requests that may be made

parallelization:
  stagger: 0.3
  # num_threads: 50 # the number of threads to use for parallel processing

async_mode: threaded # or asyncio

embeddings:
  ## parallelization: override the global parallelization settings for embeddings
  async_mode: threaded # or asyncio
  llm:
    api_key: ${GRAPHRAG_API_KEY}
    type: openai_embedding # or azure_openai_embedding
    model: bge-large-zh-v1.5
    api_base: http://127.0.0.1:9997/v1/
    # api_version: 2024-02-15-preview
    # organization: <organization_id>
    # deployment_name: <azure_model_deployment_name>
    # tokens_per_minute: 150_000 # set a leaky bucket throttle
    # requests_per_minute: 10_000 # set a leaky bucket throttle
    # max_retries: 10
    # max_retry_wait: 10.0
    # sleep_on_rate_limit_recommendation: true # whether to sleep when azure suggests wait-times
    # concurrent_requests: 25 # the number of parallel inflight requests that may be made
    # batch_size: 16 # the number of documents to send in a single request
    # batch_max_tokens: 8191 # the maximum number of tokens to send in a single request
    # target: required # or optional
  


chunks:
  size: 300
  overlap: 100
  group_by_columns: [id] # by default, we don't allow chunks to cross documents
    
input:
  type: file # or blob
  file_type: text # or csv
  base_dir: "input"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$"

cache:
  type: file # or blob
  base_dir: "cache"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

storage:
  type: file # or blob
  base_dir: "output/${timestamp}/artifacts"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

reporting:
  type: file # or console, blob
  base_dir: "output/${timestamp}/reports"
  # connection_string: <azure_blob_storage_connection_string>
  # container_name: <azure_blob_storage_container_name>

entity_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/entity_extraction.txt"
  entity_types: [organization,person,geo,event]
  max_gleanings: 0

summarize_descriptions:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/summarize_descriptions.txt"
  max_length: 500

claim_extraction:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  # enabled: true
  prompt: "prompts/claim_extraction.txt"
  description: "Any claims or facts that could be relevant to information discovery."
  max_gleanings: 0

community_report:
  ## llm: override the global llm settings for this task
  ## parallelization: override the global parallelization settings for this task
  ## async_mode: override the global async_mode settings for this task
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

cluster_graph:
  max_cluster_size: 10

embed_graph:
  enabled: false # if true, will generate node2vec embeddings for nodes
  # num_walks: 10
  # walk_length: 40
  # window_size: 2
  # iterations: 3
  # random_seed: 597832

umap:
  enabled: false # if true, will generate UMAP embeddings for nodes

snapshots:
  graphml: false
  raw_entities: false
  top_level_nodes: false

local_search:
  # text_unit_prop: 0.5
  # community_prop: 0.1
  # conversation_history_max_turns: 5
  # top_k_mapped_entities: 10
  # top_k_relationships: 10
  # max_tokens: 12000

global_search:
  # max_tokens: 12000
  # data_max_tokens: 12000
  # map_max_tokens: 1000
  # reduce_max_tokens: 2000
  # concurrency: 32
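The ${GRAPHRAG_API_KEY} placeholder is read from the .env file that --init generated. A local Xinference server (without auth enabled) doesn't validate the key, so any non-empty placeholder should do:

# ragtest/.env
GRAPHRAG_API_KEY=EMPTY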

LLM model: Qwen2-72B-Instruct
Embedding model: bge-large-zh-v1.5

Both models are deployed locally with Xinference.
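For reference, a minimal sketch of serving both models with Xinference; the flags follow Xinference's CLI and built-in model registry and may differ between versions:

# start the Xinference server on the port referenced in settings.yaml
xinference-local --host 0.0.0.0 --port 9997

# launch the chat model and the embedding model
xinference launch --model-name qwen2-instruct --model-type LLM --size-in-billions 72
xinference launch --model-name bge-large-zh-v1.5 --model-type embedding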

Build the index and knowledge graph

python -m graphrag.index --root ./ragtest
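Per the storage section of settings.yaml, the run writes its artifacts (parquet files for entities, relationships, community reports, and so on) under ./ragtest/output/<timestamp>/artifacts/.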

Success screen (screenshot omitted)

Global and local queries

Global search map-reduces over the community reports and suits corpus-wide questions; local search expands from the entities matched by the question and suits specific ones:

python -m graphrag.query \
--root ./ragtest \
--method global \
"你的问题"


python -m graphrag.query \
--root ./ragtest \
--method local \
"你的问题"

Gradio code

import subprocess
import sys

import gradio as gr


def parse_text(text):
    """Convert model output to HTML for gr.Chatbot: turn ``` fences into
    <pre><code> blocks and escape special characters inside them."""
    lines = [line for line in text.split("\n") if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # opening fence: start a code block, keeping the language tag
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # closing fence
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if count % 2 == 1:
                # inside a code block: escape characters so they render literally
                line = line.replace("`", "\\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(history):
    # graphrag.query is stateless per invocation, so only the latest user
    # message is sent; earlier turns in `history` are ignored.
    query = history[-1][0]
    print("\n\n====conversation====\n", query)

    python_path = sys.executable
    # Build the local-search command. Because subprocess.run receives a list
    # (no shell=True), the query needs no shell quoting; shlex.quote here
    # would wrap the question in literal quote characters.
    cmd = [
        python_path, "-m", "graphrag.query",
        "--root", "./ragtest",
        "--method", "local",
        query,
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True, encoding="utf-8")
        output = result.stdout
        if output:
            # Keep only the answer after the CLI's success marker.
            response = output.split("SUCCESS: Local Search Response:", 1)[-1]
            history[-1][1] += response.strip()
        else:
            history[-1][1] += "(no output from graphrag.query)"
        yield history
    except subprocess.CalledProcessError as e:
        print(e)
        history[-1][1] += f"Error: {e.stderr or e}"
        yield history


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">GraphRAG 测试</h1>""")
    chatbot = gr.Chatbot(height=600)

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")


    def user(query, history):
        return "", history + [[parse_text(query), ""]]


    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot], chatbot
    )

demo.queue()
demo.launch(server_name="0.0.0.0", server_port=9901, inbrowser=True, share=False)
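Save the script (e.g. as app.py; any name works), run python app.py, and open http://<server-ip>:9901. One caveat: each question shells out to graphrag.query, which reloads the index artifacts on every call, so responses are slow. It keeps the demo simple, but calling GraphRAG's query machinery directly from Python would avoid the per-request startup cost.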

I'm not sure whether this is down to the model's capability or to my own setup, but personally the results felt mediocre.
