Sparse MLP

上图展示了本文网络的整体架构。与ViT、MLP-Mixer和Swin Transformer类似,空间分辨率为H×W的输入图像被分割为不重叠的patch。作者在网络中采用了4×4的patch大小,每个patch被reshape成一个48维的向量,然后由一个线性层映射到一个c维embedding

import torch, os, datetime
from torch import nn


class sMLPBlock(nn.Module):
    def __init__(self, h=224, w=224, c=3):
        super().__init__()
        self.proj_h = nn.Linear(h, h)
        self.proj_w = nn.Linear(w, w)
        self.fuse = nn.Linear(3 * c, c)

    def forward(self, x):
        x_h = self.proj_h(x.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
        x_w = self.proj_w(x)
        x_id = x
        x_fuse = torch.cat([x_h, x_w, x_id], dim=1)
        out = self.fuse(x_fuse.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return out


if __name__ == '__main__':
    input = torch.randn(2, 3, 224, 224)
    smlp = sMLPBlock(h=224, w=224)
    out = smlp(input)
    print(out.shape)

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2024-02-22 16:50:05       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-02-22 16:50:05       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-02-22 16:50:05       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-02-22 16:50:05       20 阅读

热门阅读

  1. 文生视频Sora

    2024-02-22 16:50:05       30 阅读
  2. YOLOv8模型部署

    2024-02-22 16:50:05       33 阅读
  3. 存储过程与高级编程语言:解析其差异与融合

    2024-02-22 16:50:05       26 阅读
  4. IDEA打开已有vue项目

    2024-02-22 16:50:05       28 阅读
  5. 设计模式--组合模式(Composite Pattern)

    2024-02-22 16:50:05       29 阅读
  6. 设计模式详解(十一)——组合模式

    2024-02-22 16:50:05       34 阅读
  7. C#中的`out`关键字

    2024-02-22 16:50:05       25 阅读