如何用pytorch调用预训练Swin Transformer中的一个Swin block模块

1,首先,我们需要知道的是,想要调用预训练的Swin Transformer模型,必须要安装pytorch2,因为pytorch1对应的torchvision中不包含Swin Transformer。

2,pytorch2调用预训练模型时,不建议使用pretrained=True,这个用法即将淘汰,会报警告。最好用如下方式:

from torchvision.models.swin_transformer import swin_b, Swin_B_Weights  
  
model = swin_b(weights=Swin_B_Weights.DEFAULT)  

这里调用的就是swin_b在imagenet上的预训练模型

3,swin_b的模型结构如下(仅展示到第一个patch merging部分),在绝大部分情况下,我们可能需要的不是整个模型,而是其中的一个模块,比如SwinTransformerBlock。

SwinTransformer(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): Permute()
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): Sequential(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=128, out_features=384, bias=True)
          (proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=512, out_features=128, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=128, out_features=384, bias=True)
          (proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.021739130434782608, mode=row)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=512, out_features=128, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (2): PatchMerging(
      (reduction): Linear(in_features=512, out_features=256, bias=False)
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )

那么如何调用其中的SwinTransformerBlock呢。

由于该模型是个嵌套结构,而不是类似vgg一样简单的结构,所以不能直接用layer0=model.SwinTransformerBlock调用。

因为SwinTransformerBlock是Sequential下的子模块,故正确的调用代码如下:

swinblock = model.features[1][0]

结果如下,调用成功:

相关推荐

  1. Pytorch 第一讲】 如何加载训练模型

    2024-03-24 09:46:01       61 阅读
  2. pytorch模型训练学习率动态调整

    2024-03-24 09:46:01       39 阅读
  3. 关于训练模型一点感悟

    2024-03-24 09:46:01       51 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-24 09:46:01       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-24 09:46:01       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-24 09:46:01       82 阅读
  4. Python语言-面向对象

    2024-03-24 09:46:01       91 阅读

热门阅读

  1. 如何在OpenCV中实现实时人脸识别?

    2024-03-24 09:46:01       38 阅读
  2. 24计算机考研调剂 | 江西理工大学

    2024-03-24 09:46:01       41 阅读
  3. 日志收集监控告警平台的选型思考

    2024-03-24 09:46:01       38 阅读
  4. Github 2024-03-24 开源项目日报Top10

    2024-03-24 09:46:01       38 阅读
  5. 数据库第一次作业

    2024-03-24 09:46:01       36 阅读
  6. Ubuntu添加硬盘

    2024-03-24 09:46:01       40 阅读