语义检索系统排序阶段使用PaddleInference进行模型推理

import os
import sys
from scipy.special import softmax
from scipy.special import expit
import numpy as np
import pandas as pd
import paddle
from paddle import inference
import paddle.nn.functional as F
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset
from paddlenlp.utils.log import logger
from paddlenlp.transformers import AutoTokenizer, AutoModel
sys.path.append('.')

# Run-time settings for the ranking-stage inference script.
max_seq_length = 128  # length limit forwarded to the tokenizer
batch_size = 32  # number of query/title pairs sent to the predictor per run

# Exported inference model: program file and parameter file under model_dir.
model_dir = './output'
model_file = f"{model_dir}/inference.predict.pdmodel"
params_file = f"{model_dir}/inference.predict.pdiparams"

use_tensorrt = True
model_name_or_path = 'ernie-3.0-medium-zh'  # name handed to AutoTokenizer
input_file = './datasets/yysy/sort/test_pairwise.csv'

# Describe the exported model (program file + parameter file) to the
# Paddle Inference engine.
config = inference.Config(model_file, params_file)

# Run on GPU: select the device for the Paddle runtime, then enable GPU use
# on the inference config with the arguments (100, 0).
# NOTE(review): per the Paddle Inference API these are presumably the initial
# memory pool size in MB and the GPU device id -- confirm.
paddle.set_device('gpu')
config.enable_use_gpu(100, 0)

# Lookup from a precision name to the matching Paddle Inference precision enum.
precision_map = dict(
    fp16=inference.PrecisionType.Half,
    fp32=inference.PrecisionType.Float32,
    int8=inference.PrecisionType.Int8,
)
precision = 'fp32'
precision_mode = precision_map[precision]

# NOTE(review): the TensorRT engine call below is commented out, so neither
# `use_tensorrt` nor `precision_mode` currently influences the predictor.
# if use_tensorrt:
#     config.enable_tensorrt_engine(max_batch_size=batch_size,
#                                   min_subgraph_size=30,
#                                   precision_mode=precision_mode)

# Turn the feed/fetch ops off; inputs and outputs are exchanged through the
# predictor's input/output handles (copy_from_cpu / copy_to_cpu) instead.
config.switch_use_feed_fetch_ops(False)

predictor = inference.create_predictor(config)

# Fetch the model's input names once and bind one handle per input. (The
# original issued an extra get_input_names() call whose result was discarded.)
input_names = predictor.get_input_names()
input_handles = [predictor.get_input_handle(name) for name in input_names]

# Only the first model output is read by this script.
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])

# Tokenizer loaded by name from `model_name_or_path`.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

def read_text_pair(data_path):
    """Yield query/title records from a tab-separated text file.

    Every line has its trailing whitespace stripped and is then split on
    tab characters. Only lines that produce exactly three fields are used:
    the first two fields become ``query`` and ``title`` and the third one is
    discarded. Any other line is skipped silently.

    Args:
        data_path: Path of the UTF-8 encoded file to read.

    Yields:
        dict: ``{'query': <field 0>, 'title': <field 1>}`` per accepted line.
    """
    with open(data_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            fields = raw_line.rstrip().split("\t")
            if len(fields) == 3:
                query, title, _ = fields
                yield {'query': query, 'title': title}

# Load the test pairs through PaddleNLP, using read_text_pair as the reader
# and requesting a non-lazy dataset.
test_ds = load_dataset(read_text_pair, data_path=input_file, lazy=False)

# Copy every record into a plain dict that holds only the two text fields.
data = [{'query': d['query'], 'title': d['title']} for d in test_ds]

# Cut the records into consecutive chunks of at most `batch_size` items; the
# final chunk may be shorter. (This statement previously carried a stray
# leading space, which is an IndentationError at module level.)
batches = [
    data[idx:idx + batch_size]
    for idx in range(0, len(data), batch_size)
]

def convert_example(example, tokenizer, max_seq_length=512, is_test=False):
    """Tokenize one query/title pair into model input features.

    Args:
        example: Mapping with ``"query"`` and ``"title"`` entries; a
            ``"label"`` entry is also read when ``is_test`` is false.
        tokenizer: Callable invoked with ``text``, ``text_pair`` and
            ``max_seq_len`` keyword arguments; its result must provide
            ``"input_ids"`` and ``"token_type_ids"``.
        max_seq_length: Value forwarded to the tokenizer as ``max_seq_len``.
        is_test: When true, no label is read or returned.

    Returns:
        ``(input_ids, token_type_ids)`` when ``is_test`` is true, otherwise
        ``(input_ids, token_type_ids, label)`` where ``label`` is a
        one-element int64 NumPy array.
    """
    encoding = tokenizer(text=example["query"],
                         text_pair=example["title"],
                         max_seq_len=max_seq_length)
    features = (encoding["input_ids"], encoding["token_type_ids"])

    # Prediction path: there is no label to attach.
    if is_test:
        return features

    # Training / evaluation path: append the label as an int64 array.
    return (*features, np.array([example["label"]], dtype="int64"))

# Collate function: pads every sequence in a batch to a common length and
# returns (input_ids, segment_ids) as int64 arrays. It does not depend on the
# batch, so it is built once instead of on every loop iteration. Segment ids
# are padded with the tokenizer's token-type pad id rather than the token pad
# id.
batchify_fn = Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64"),  # input ids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id,
        dtype="int64"),  # segment ids
)

# One model output row per input pair, in input order.
results = []
for batch_data in batches:
    # Numeric (input_ids, token_type_ids) features for each pair in the batch.
    examples = [
        convert_example(text,
                        tokenizer,
                        max_seq_length=max_seq_length,
                        is_test=True)
        for text in batch_data
    ]
    input_ids, segment_ids = batchify_fn(examples)

    # Feed the padded arrays to the first two model inputs, in that order.
    input_handles[0].copy_from_cpu(input_ids)
    input_handles[1].copy_from_cpu(segment_ids)
    predictor.run()

    # Per the original author's note the model output is already a
    # probability, so no sigmoid is applied here.
    sim_score = output_handle.copy_to_cpu()
    results.extend(sim_score)

# Show the first 20 input pairs next to their predicted scores.
for rank, record in enumerate(data[:20]):
    score = results[rank]
    print('Data: {} \t prob: {}'.format(record, score))

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-10 22:24:03       5 阅读
  2. Could not load dynamic library ‘cudart64_100.dll’

    2024-07-10 22:24:03       5 阅读
  3. 在Django里面运行非项目文件

    2024-07-10 22:24:03       4 阅读
  4. Python语言-面向对象

    2024-07-10 22:24:03       7 阅读

热门阅读

  1. 【C语言】通过fgets和fscanf了解读写文件流的概念

    2024-07-10 22:24:03       10 阅读
  2. mac上修改jupyterlab工作目录

    2024-07-10 22:24:03       9 阅读
  3. mongoexport导出聚合查询的mongo数据

    2024-07-10 22:24:03       8 阅读
  4. k8s离线安装单节点elasticsearch7.x

    2024-07-10 22:24:03       10 阅读