PyTorch Semantic Segmentation (2) -------- Model Building

The classic architecture is still U-Net. You could also train with a ready-made implementation from a third-party library (PyTorch itself does not bundle a U-Net; a sketch of that option follows below), but to understand the model better, we build it from scratch.
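If you do want an off-the-shelf U-Net, one common option is the segmentation_models_pytorch package. A minimal sketch, assuming that package is installed (the ResNet-34 encoder here is just an illustrative choice; any encoder the library supports works the same way):

import segmentation_models_pytorch as smp

# Ready-made U-Net with a pretrained ResNet-34 encoder
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
)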

unet.py:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
    """Conv-BN-ReLU followed by 2x upsampling.

    Note: this block is not used by MainNet below; the decoder uses
    UpConcat instead. It is kept here for reference."""
    def __init__(self, in_channel, out_channel):
        super(Up, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.block(x)
        out = F.interpolate(x, scale_factor=2)  # nearest-neighbor by default
        return out


class Down(nn.Module):
    """Conv-BN-ReLU block; with the default stride=2 it also halves H and W,
    standing in for the max-pooling of the original U-Net."""
    def __init__(self, in_channel, out_channel, stride=2):
        super(Down, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.block(x)


class UpConcat(nn.Module):
    """Decoder block: upsample the deeper feature map 2x, concatenate it with
    the skip connection from the encoder, then apply two 3x3 convolutions.
    Channel math: in_channel (deep) + out_channel (skip) -> out_channel."""
    def __init__(self, in_channel, out_channel):
        super(UpConcat, self).__init__()
        self.up = nn.Upsample(scale_factor=2)  # nearest-neighbor by default
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel + out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
        )

    def forward(self, in_map1, in_map2):
        in_map2 = self.up(in_map2)                  # match the skip's H and W
        out = torch.cat([in_map1, in_map2], dim=1)  # concatenate along channels
        return self.conv2(out)


class MainNet(nn.Module):
    def __init__(self, num_classes):
        super(MainNet, self).__init__()
        # Encoder: down1 keeps the resolution, the rest halve it
        self.down1 = Down(3, 64, stride=1)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.down5 = Down(512, 1024)

        # Decoder: mirror the encoder, fusing skip connections
        self.up5concat = UpConcat(1024, 512)
        self.up4concat = UpConcat(512, 256)
        self.up3concat = UpConcat(256, 128)
        self.up2concat = UpConcat(128, 64)

        # 1x1 conv to num_classes; Sigmoid yields independent per-class
        # probabilities (for mutually exclusive classes you would drop the
        # Sigmoid and train with CrossEntropyLoss instead)
        self.head = nn.Sequential(
            nn.Conv2d(64, num_classes, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Encoder (shapes shown for a 3x512x512 input)
        feat1 = self.down1(x)       # (3, 512, 512)   -> (64, 512, 512)
        feat2 = self.down2(feat1)   # (64, 512, 512)  -> (128, 256, 256)
        feat3 = self.down3(feat2)   # (128, 256, 256) -> (256, 128, 128)
        feat4 = self.down4(feat3)   # (256, 128, 128) -> (512, 64, 64)
        feat5 = self.down5(feat4)   # (512, 64, 64)   -> (1024, 32, 32)

        # Decoder
        feat4_up = self.up5concat(feat4, feat5)     # -> (512, 64, 64)
        feat3_up = self.up4concat(feat3, feat4_up)  # -> (256, 128, 128)
        feat2_up = self.up3concat(feat2, feat3_up)  # -> (128, 256, 256)
        feat1_up = self.up2concat(feat1, feat2_up)  # -> (64, 512, 512)

        return self.head(feat1_up)  # -> (num_classes, 512, 512)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = torch.zeros((1, 3, 512, 512)).to(device)
    model = MainNet(num_classes=3).to(device)

    out = model(tensor)
    print(out.shape)  # expected: torch.Size([1, 3, 512, 512])

    from torchsummary import torchsummary
    torchsummary.summary(model, (3, 512, 512), device=device.type)

    # Optional: profile FLOPs and parameter count with thop
    # from thop import profile
    # flops, params = profile(model, inputs=(tensor,))
    # print("FLOPs=", str(flops / 1e9) + "G")
    # print("params=", str(params / 1e6) + "M")
    # FLOPs= 63.406604288G
    # params= 14.127683M
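Since the head ends in a Sigmoid, the network outputs independent per-class probabilities in [0, 1], which pairs naturally with a binary cross-entropy loss over one-hot masks. A minimal training-step sketch using the MainNet from unet.py above, with a dummy batch standing in for a real data loader (the images/masks names are placeholders):

import torch
import torch.nn as nn

model = MainNet(num_classes=3)
criterion = nn.BCELoss()  # matches the Sigmoid output in [0, 1]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Dummy batch; H and W must be divisible by 16 so the four stride-2
# downsamplings line up with the 2x upsamplings in the decoder.
images = torch.randn(2, 3, 512, 512)
masks = torch.randint(0, 2, (2, 3, 512, 512)).float()  # random 0/1 stand-ins for one-hot masks

model.train()
optimizer.zero_grad()
preds = model(images)           # (2, 3, 512, 512), values in [0, 1]
loss = criterion(preds, masks)
loss.backward()
optimizer.step()
print("loss:", loss.item())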
