通过os.dup sys.stdout.fileno捕获标准输出,判断pytorch算子是否fallback到了cpu

通过os.dup sys.stdout.fileno捕获标准输出,判断pytorch算子是否fallback到了cpu

某种设备在运行pytorch算子时,如果不支持会自动fallback到cpu,输出的tensor.device却不是cpu,我希望能获取到这个状态。本文通过捕获标准输出,根据终端是否输出fallback字符串,判断是否触发了fallback

一.代码


import threading
import sys
import os

class CheckFallback:
    def __init__(self,enable=True):        
        self.is_fallback=False
        self.enable=enable
        if self.enable:
            self.stdout_fileno_origin = sys.stdout.fileno()
            self.stdout_fileno_dup = os.dup(self.stdout_fileno_origin)
            self.stdout_pipe = os.pipe()
            os.dup2(self.stdout_pipe[1], self.stdout_fileno_origin)
            os.close(self.stdout_pipe[1])
            self.stdout_messages = ''
            self.running=True
            self.task = threading.Thread(target=self.read_pipe)
            self.task.start()

    def read_pipe(self):
        while self.running:
            msg = os.read(self.stdout_pipe[0], 8192)
            if msg:
                self.stdout_messages+=msg.decode('utf-8')
    
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            self.running=False
            os.close(self.stdout_fileno_origin)
            self.task.join()
            os.close(self.stdout_pipe[0])
            os.dup2(self.stdout_fileno_dup, self.stdout_fileno_origin)
            os.close(self.stdout_fileno_dup)
            #检查终端是否有fallback信息输出
            if self.stdout_messages.find("fallback")>=0:
                self.is_fallback=True

import torch
A=torch.ones((512,65024),dtype=torch.float16).to("your_device")
with CheckFallback() as f:
    C=torch.ops.aten.gelu.default(A)    
print(f.is_fallback)
print(C.shape,C.device)

with CheckFallback() as f:
    A=torch.ones((1,32),dtype=torch.float16).to("your_device")
    C=torch.ops.aten.pow(A,A)
print(f.is_fallback)
print(C.shape,C.device)

相关推荐

  1. 通过浏览器判断是否安装APP

    2024-05-12 11:04:10       29 阅读
  2. 如何判断服务器是否被入侵

    2024-05-12 11:04:10       37 阅读
  3. 拦截pytorch算子,dump输入输出

    2024-05-12 11:04:10       22 阅读
  4. 将Linux 标准输出,错误输出重定向文件

    2024-05-12 11:04:10       41 阅读
  5. php 如何判断是否上传文件、图片

    2024-05-12 11:04:10       29 阅读
  6. js判断是否T+N的时间限制

    2024-05-12 11:04:10       47 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-12 11:04:10       19 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-12 11:04:10       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-12 11:04:10       20 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-12 11:04:10       20 阅读

热门阅读

  1. TensorFlow和PyTorch的对比

    2024-05-12 11:04:10       9 阅读
  2. MongoDB聚合运算符:$toString

    2024-05-12 11:04:10       9 阅读
  3. Flutter备用依赖

    2024-05-12 11:04:10       10 阅读
  4. 什么是渐进式框架

    2024-05-12 11:04:10       8 阅读
  5. matlab人脸识别

    2024-05-12 11:04:10       8 阅读
  6. 基于STM32的衣柜防潮系统设计的毕业论文

    2024-05-12 11:04:10       8 阅读
  7. Android中C++如何读写json文件

    2024-05-12 11:04:10       12 阅读