Building a WebUI with Gradio
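The script below uses Gradio Blocks to build a streaming chat WebUI ("智能助手", an intelligent-assistant page) on top of a Hugging Face transformers causal language model. The interface has a chat pane, send / regenerate / new-topic buttons, and an accordion of generation parameters (temperature, top_p, and how many rounds of earlier context to fold into the prompt). Generation runs on a background thread while TextIteratorStreamer pushes new tokens into the chat window as they arrive.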

import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import torch
import sys
import argparse

with gr.Blocks() as demo:
    gr.Markdown("""<h1><center>智能助手</center></h1>""")  # page title: "Intelligent Assistant"
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    state = gr.State()  # declared but not used by any callback
    with gr.Row():
        clear = gr.Button("新话题")          # "New topic": clears the chat
        re_generate = gr.Button("重新回答")  # "Answer again": regenerates the last reply
        sent_bt = gr.Button("发送")          # "Send"
    with gr.Accordion("生成参数", open=False):  # "Generation parameters"
        slider_temp = gr.Slider(minimum=0, maximum=1, label="temperature", value=0.3)
        slider_top_p = gr.Slider(minimum=0.5, maximum=1, label="top_p", value=0.95)
        # "上文轮次": how many earlier rounds of the conversation to include as context
        slider_context_times = gr.Slider(minimum=0, maximum=5, label="上文轮次", value=0, step=2.0)

    def user(user_message, history):
        # Clear the textbox and append the new message with an empty slot for the reply.
        return "", history + [[user_message, None]]

    def bot(history, temperature, top_p, slider_context_times):
        # On regeneration the last reply is already filled in; blank it first.
        if history[-1][1] is not None:
            history[-1][1] = None
            yield history
        slider_context_times = int(slider_context_times)
        history_true = history[1:-1]  # candidate context turns (the current question is excluded)
        prompt = ''
        if slider_context_times > 0:
            # Fold the last `slider_context_times` completed turns into the
            # "<s>Human: ...</s><s>Assistant: ...</s>" chat template.
            prompt += '\n'.join([
                ("<s>Human: " + one_chat[0].replace('<br>', '\n') + '\n</s>' if one_chat[0] else '')
                + "<s>Assistant: " + one_chat[1].replace('<br>', '\n') + '\n</s>'
                for one_chat in history_true[-slider_context_times:]
            ])
        prompt += "<s>Human: " + history[-1][0].replace('<br>', '\n') + "\n</s><s>Assistant: "

        # Tokenize and keep only the last 512 prompt tokens to bound the input length.
        input_ids = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).input_ids[:, -512:].to('cuda')
        generate_input = {
            "input_ids": input_ids,
            "max_new_tokens": 512,
            "do_sample": True,
            "top_k": 50,
            "top_p": top_p,
            "temperature": temperature,
            "repetition_penalty": 1.3,
            "streamer": streamer,
            "eos_token_id": tokenizer.eos_token_id,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.pad_token_id,
        }
        # Run generate() in a background thread; the streamer yields text as it is produced.
        thread = Thread(target=model.generate, kwargs=generate_input)
        thread.start()

        start_time = time.time()
        bot_message = ''
        print('Human:', history[-1][0])
        print('Assistant: ', end='', flush=True)
        for new_text in streamer:
            print(new_text, end='', flush=True)
            if len(new_text) == 0:
                continue
            if new_text != '</s>':
                bot_message += new_text
            # If the model starts inventing the next "Human:" turn, cut the reply off there.
            if 'Human:' in bot_message:
                bot_message = bot_message.split('Human:')[0]
            history[-1][1] = bot_message
            yield history
        end_time = time.time()
        print()
        # Log generation time, reply length, and seconds per character (guarding against an empty reply).
        print('生成耗时:', end_time - start_time, '文字长度:', len(bot_message),
              '字耗时:', (end_time - start_time) / max(len(bot_message), 1))

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, slider_temp, slider_top_p, slider_context_times], chatbot
    )
    sent_bt.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, slider_temp, slider_top_p, slider_context_times], chatbot
    )
    re_generate.click(bot, [chatbot, slider_temp, slider_top_p, slider_context_times], chatbot)
    clear.click(lambda: [], None, chatbot, queue=False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, help='model name or path')
    parser.add_argument("--is_4bit", action='store_true', help='load a 4-bit GPTQ-quantized model')
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)

    if not args.is_4bit:
        # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map='auto', torch_dtype=torch.float16, load_in_8bit=True)
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map='auto', torch_dtype=torch.float16)
        model.eval()
    else:
        from auto_gptq import AutoGPTQForCausalLM
        model = AutoGPTQForCausalLM.from_quantized(args.model_name_or_path, low_cpu_mem_usage=True, device="cuda:0",
                                                   use_triton=False, inject_fused_attention=False, inject_fused_mlp=False)

    # skip_prompt=True keeps the echoed prompt out of the streamed reply.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    # demo.queue().launch(share=False, debug=True)
    demo.queue(concurrency_count=80, max_size=100).launch(max_threads=150, share=False, inbrowser=True,
                                                          server_name="0.0.0.0", server_port=8000)
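
To launch the UI, run the script with a model path or Hugging Face hub ID, for example python webui.py --model_name_or_path /path/to/model, adding --is_4bit for a GPTQ checkpoint; the app is then served on port 8000 on all interfaces. (webui.py is just a stand-in name for wherever the script is saved.)

The chat template assembled in bot is easier to see in isolation. The sketch below is simplified from the version in bot (it drops the <br> handling and the context-round limit) and uses made-up history values; it runs without any model:

# Hypothetical two-turn history; the last turn is the pending question.
history = [["你好", "你好，请问有什么可以帮您？"], ["介绍一下Gradio", None]]
context = history[:-1]  # earlier, completed turns
prompt = '\n'.join(
    "<s>Human: " + q + "\n</s>" + "<s>Assistant: " + a + "\n</s>"
    for q, a in context
)
prompt += "<s>Human: " + history[-1][0] + "\n</s><s>Assistant: "
print(prompt)

Each completed turn is wrapped in <s>Human: ...</s><s>Assistant: ...</s> markers, and the prompt ends with an open "<s>Assistant: " so that the model's continuation is the reply.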
