tensorrt-llm知识

tensorrt-llm如何打印logits和probs

#首先在engine编译时加入参数--gather_all_token_logits
trtllm-build --checkpoint_dir ./tmp \
        --output_dir $2/ \
        --gather_all_token_logits

#其次执行tensorrt_llm/examples/run.py,比如是两卡执行,下面放在shell脚本中执行
type=fp8
output_path=output_$type
mpirun -n 2 --allow-run-as-root python3 run.py --input_text="test.txt" --max_output_len 10 \
  --engine_dir /engine/$type --max_input_length 4000 --no_prompt_template \
  --temperature 0.1 --tokenizer_dir /engine/$type/tokenizer_path \
  --output_logits_npy ./$output_path/logits --output_log_probs_npy ./$output_path/log_probs --output_cum_log_probs_npy ./$output_path/cum_log_probs

#最后分析生成的logits_generation.npy文件,python代码如下
import numpy as np
import torch

prefix = 'output_int8'
array = np.load(f'{prefix}/logits_generation.npy')
print(array.shape)

for round in range(1):
    k = 10 #取得分最高的前10个token_id
    arr = array[0][0][round]
    values, indices = torch.topk(torch.from_numpy(arr), k)
    print("Top-k values:", values)

    from transformers import AutoTokenizer
    path = '/engine/tokenizer_path'
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    if len(indices) == 1:
        tokens = tokenizer.convert_ids_to_tokens([indices])
        print(tokens)
    else:
        for id in indices:
            tokens = tokenizer.decode([id])
            # tokens = tokenizer.convert_ids_to_tokens([id])
            print(tokens, end=' ')
            # print(f'{id}:{tokens}', end=' ')
        print()
print(tokenizer.encode("None"), tokenizer.encode("80"), tokenizer.encode("8"))
print(array[0][0][0][tokenizer.encode("None")[0]])
print(array[0][0][0][tokenizer.encode("80")[1]])
print(array[0][0][0][tokenizer.encode("8")[1]])

相关推荐

  1. tensorrt-llm知识

    2024-07-19 13:52:03       19 阅读
  2. TensorRT-llm入门

    2024-07-19 13:52:03       33 阅读
  3. TensorRT-LLM保姆级教程(一)-快速入门

    2024-07-19 13:52:03       59 阅读
  4. TensorRT 自学笔记001 基础知识点和学习资源

    2024-07-19 13:52:03       64 阅读
  5. LLM推理及加速知识

    2024-07-19 13:52:03       34 阅读
  6. ChatGPT高效提问—基础知识LM、PLM以及LLM

    2024-07-19 13:52:03       86 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-19 13:52:03       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-19 13:52:03       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-19 13:52:03       58 阅读
  4. Python语言-面向对象

    2024-07-19 13:52:03       69 阅读

热门阅读

  1. 芯片基础 | `wire`类型引发的学习

    2024-07-19 13:52:03       19 阅读
  2. oracle extract的使用

    2024-07-19 13:52:03       23 阅读
  3. mysql、oracle、db2数据库连接参数

    2024-07-19 13:52:03       19 阅读
  4. 什么是TCP/IP协议

    2024-07-19 13:52:03       23 阅读
  5. 初识synchronized

    2024-07-19 13:52:03       23 阅读
  6. 【QT】001第一个程序

    2024-07-19 13:52:03       19 阅读
  7. 【深度学习】CycleGAN

    2024-07-19 13:52:03       22 阅读
  8. 一篇就够mysql高阶知识总结

    2024-07-19 13:52:03       19 阅读
  9. oracle创建服务

    2024-07-19 13:52:03       22 阅读