Exporting PyTorch while/for loops to ONNX

Problem:

When a model is exported via tracing, the number of for-loop iterations is fixed at trace time and does not change with the input.

Solution:

torch.jit.script

For example:

import torch

class LoopAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        h = x
        # the loop count depends on the input's first dimension
        for i in range(x.size(0)):
            h = h + 1
        return h

input_1 = torch.ones(3, 16)
model = LoopAdd()
traced_model = torch.jit.trace(model, (input_1,))
print(traced_model.graph)
graph(%self : __torch__.LoopAdd,
      %x : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu)):
  %7 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %8 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %h.1 : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%x, %7, %8) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %10 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %11 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %h : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%h.1, %10, %11) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %13 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %14 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %15 : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%h, %13, %14) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  return (%15)
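The three unrolled aten::add nodes above mean the trace only matches inputs whose first dimension is 3. A small check makes the mismatch concrete (a sketch; the 5x16 input is just for illustration):

```python
import torch

class LoopAdd(torch.nn.Module):
    def forward(self, x):
        h = x
        for i in range(x.size(0)):
            h = h + 1
        return h

model = LoopAdd()
# trace with a batch of 3: the loop is unrolled into three +1 nodes
traced = torch.jit.trace(model, (torch.ones(3, 16),))

# a batch of 5 should add 5, but the trace still adds the baked-in 3
print(model(torch.ones(5, 16))[0, 0].item())   # 6.0
print(traced(torch.ones(5, 16))[0, 0].item())  # 4.0
```

An ONNX file exported from this trace would carry the same baked-in behavior.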

After converting to a ScriptModule:

import torch

class LoopAdd(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()

    @torch.jit.script_method
    def forward(self, x):
        h = x
        # scripting compiles the loop itself, so the trip count stays dynamic
        for i in range(x.size(0)):
            h = h + 1
        return h

input_1 = torch.ones(3, 16)
model = LoopAdd()
traced_model = torch.jit.trace(model, (input_1,))
print(traced_model.graph)
graph(%self : __torch__.LoopAdd,
      %x.1 : Tensor):
  %8 : bool = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:18:8
  %4 : int = prim::Constant[value=0]() # /home/mark.yj/GPT-SoVITS/b.py:18:30
  %11 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:19:20
  %5 : int = aten::size(%x.1, %4) # /home/mark.yj/GPT-SoVITS/b.py:18:23
  %h : Tensor = prim::Loop(%5, %8, %x.1) # /home/mark.yj/GPT-SoVITS/b.py:18:8
    block0(%i : int, %h.9 : Tensor):
      %h.3 : Tensor = aten::add(%h.9, %11, %11) # /home/mark.yj/GPT-SoVITS/b.py:19:16
      -> (%8, %h.3)
  return (%h)

The graph now contains a prim::Loop node, which means the loop is no longer unrolled into a static graph with a fixed iteration count.
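Besides subclassing torch.jit.ScriptModule, the plain torch.jit.script function achieves the same dynamic loop on an ordinary nn.Module (a sketch):

```python
import torch

class LoopAdd(torch.nn.Module):
    def forward(self, x):
        h = x
        for i in range(x.size(0)):
            h = h + 1
        return h

# compile the Python source instead of tracing example inputs through it
scripted = torch.jit.script(LoopAdd())

print("prim::Loop" in str(scripted.graph))       # True: the loop survives
print(scripted(torch.ones(5, 16))[0, 0].item())  # 6.0: trip count follows the input
```

Passing such a scripted module to torch.onnx.export then emits an ONNX Loop operator rather than a fixed chain of Add nodes.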
