VGG介绍及Pytorch实现

VGG是一种经典的卷积神经网络架构,由牛津大学视觉几何组(Visual Geometry Group)提出。VGG以其简单而有效的设计而闻名,其核心思想是通过多层深度的卷积和池化层来逐渐提取图像特征,并且通过堆叠多个卷积层和池化层来增加网络的深度。VGG网络结构中主要采用3x3大小的卷积核和2x2大小的最大池化核,这种统一的设计使得网络架构非常规整,易于理解和实现。VGG网络共有多个版本,其中VGG16和VGG19是最为常用的两个版本,分别包含16和19个卷积层,以及若干个全连接层。虽然VGG相对于其他深度学习模型而言较为简单,但其在图像分类等任务上表现出色,成为了深度学习领域的经典模型之一。

这篇文章很简单,就是卷积块的堆叠, 注意力卷积的通道输入输出即可,没什么难度。

Pytorch代码

# Define VGG-16 and VGG-19.
import torch

cfg = {
    'VGG-16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG-19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}


# VGG-16 and VGG-19
class VGGNet(torch.nn.Module):
    def __init__(self, VGG_type, num_classes):
        super(VGGNet, self).__init__()
        self.features = self._make_layers(cfg[VGG_type])
        self.classifier = torch.nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':  # MaxPool2d
                layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [torch.nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           torch.nn.BatchNorm2d(x),
                           torch.nn.ReLU(inplace=True)]
                in_channels = x
        layers += [torch.nn.AvgPool2d(kernel_size=1, stride=1)]
        return torch.nn.Sequential(*layers)  # The number of parameters is more than one.

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-24 00:32:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-24 00:32:04       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-24 00:32:04       82 阅读
  4. Python语言-面向对象

    2024-03-24 00:32:04       91 阅读

热门阅读

  1. 一种通用的ES索引创建设计方案

    2024-03-24 00:32:04       42 阅读
  2. Docker 中安装 Redis

    2024-03-24 00:32:04       45 阅读
  3. es6类,判断数据类型

    2024-03-24 00:32:04       38 阅读
  4. 设计模式,简单工厂模式

    2024-03-24 00:32:04       39 阅读
  5. js实现读取excel文件

    2024-03-24 00:32:04       37 阅读
  6. 模型参数加载

    2024-03-24 00:32:04       40 阅读
  7. oracle添加用户

    2024-03-24 00:32:04       46 阅读
  8. 第四章 可变参数模板

    2024-03-24 00:32:04       37 阅读
  9. SQL运维_Unix下MySQL-5.5.11配置文件示例

    2024-03-24 00:32:04       43 阅读
  10. TensorFlow的研究应用与开发~深度学习

    2024-03-24 00:32:04       47 阅读