【Python】卷积神经网络

一、前言

卷积神经网络(Convolutional Neural Networks,CNN)是一种包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深度学习(deep learning)的代表算法之一。

卷积神经网络具有表征学习(representation learning)能力,能够按其阶层结构对输入信息进行平移不变分类(shift-invariant classification),因此也被称为“平移不变人工神经网络(Shift-Invariant Artificial Neural Networks, SIANN)”。

二、卷积神经网络

2.1 简介

卷积神经网络仿造生物的视知觉(visual perception)机制构建,可以进行监督学习和非监督学习,其隐含层内的卷积核参数共享和层间连接的稀疏性使得卷积神经网络能够以较小的计算量对格点化(grid-like topology)特征,例如像素和音频进行学习、有稳定的效果且对数据没有额外的特征工程(feature engineering)要求。

卷积神经网络最常用于分析视觉图像,在图像和视频识别、推荐系统、图像分类、图像分割、医学图像分析、自然语言处理、脑机接口和金融时间序列等领域都有应用。卷积神经网络是多层感知机的规范化版本,不同神经元的感受野部分重叠,使其覆盖整个视野。

卷积神经网络的概念最早可以追溯到上世纪60年代,Hubel等人通过对猫视觉皮层细胞的研究,提出了感受野这个概念。到80年代,Fukushima在感受野概念的基础之上提出了神经认知机的概念,可以看作是卷积神经网络的第一个实现网络,神经认知机将一个视觉模式分解成许多子模式(特征),然后进入分层递阶式相连的特征平面进行处理,它试图将视觉系统模型化,使其能够在即使物体有位移或轻微变形的时候,也能完成识别。

卷积神经网络由纽约大学的Yann Lecun于1998年提出(LeNet-5),其本质是一个多层感知机,成功的原因在于其所采用的局部连接和权值共享的方式:一方面减少了权值的数量使得网络易于优化;另一方面降低了模型的复杂度、减小了过拟合的风险。当网络的输入为图像时,这些优点将表现地更加明显。

卷积神经网络的研究始于二十世纪80至90年代,时间延迟网络和LeNet-5是最早出现的卷积神经网络;在二十一世纪后,随着深度学习理论的提出和数值计算设备的改进,卷积神经网络得到了快速发展,并被应用于计算机视觉、自然语言处理等领域。2012年AlexNet取得ImageNet比赛的分类任务的冠军,使得卷积神经网络真正爆发。

2.2 层级结构

卷积层(Convolutional Layer): CNN的核心部分是卷积层,通过卷积操作从输入数据中提取特征。卷积操作使用一个可学习的滤波器(或卷积核)对输入数据进行滑动操作,从而生成特征图。这有助于网络捕捉输入数据中的局部空间结构。

池化层(Pooling Layer): 池化操作用于降低特征图的空间维度,减少计算复杂性,并提高模型的鲁棒性。常见的池化操作有最大池化(选择局部区域中的最大值)和平均池化(计算局部区域中的平均值)。

激活函数(Activation Function): 激活函数引入非线性性,使得网络能够学习复杂的映射关系。常见的激活函数包括ReLU(Rectified Linear Unit)和其变种,用于在特征图中引入非线性。

全连接层(Fully Connected Layer): 在卷积层和输出层之间可能存在全连接层,用于将卷积层提取的特征映射转化为最终的输出。全连接层将每个神经元与前一层的所有神经元连接。

Dropout: Dropout是一种正则化技术,随机地在训练时丢弃一部分神经元,以减少过拟合的风险。

权重共享: 在卷积操作中,同一卷积核被用于整个输入图像,从而实现参数的共享,减少模型的参数数量。

卷积神经网络的训练过程通常包括前向传播(计算预测值)、反向传播(计算梯度)和权重更新。通过多次迭代这个过程,网络逐渐学到有效的特征表示,从而提高在特定任务上的性能。CNN在图像处理领域取得了显著的成就,也被用于其他领域,如自然语言处理。

2.3实现

此示例使用CIFAR-10数据集。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 定义卷积神经网络模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 16 * 16, 10)  # 假设输入图像大小为32x32,且有10个类别

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(-1, 32 * 16 * 16)  # 将特征图展平
        x = self.fc1(x)
        return x

# 定义数据预处理和加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 图像标准化
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

# 实例化模型、损失函数和优化器
net = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练网络
for epoch in range(5):  # 仅迭代5次作为示例
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:  # 每2000个小批量数据打印一次损失值
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

相关推荐

  1. Python神经网络

    2024-01-09 16:06:02       43 阅读
  2. 神经网络

    2024-01-09 16:06:02       24 阅读
  3. 神经网络

    2024-01-09 16:06:02       14 阅读
  4. 神经网络

    2024-01-09 16:06:02       9 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-01-09 16:06:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-09 16:06:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-09 16:06:02       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-09 16:06:02       20 阅读

热门阅读

  1. 【面试高频算法解析】算法练习7 贪心算法

    2024-01-09 16:06:02       38 阅读
  2. SpringBoot项目中开启MyBatis的SQL日志

    2024-01-09 16:06:02       37 阅读
  3. openc910源码LSU系列之<66>

    2024-01-09 16:06:02       31 阅读
  4. golang学习-流程控制

    2024-01-09 16:06:02       38 阅读
  5. pytest-mock 数据模拟

    2024-01-09 16:06:02       51 阅读
  6. 用 Socket.D 替代原生 WebSocket 做前端开发

    2024-01-09 16:06:02       39 阅读
  7. 常见连读技巧

    2024-01-09 16:06:02       37 阅读
  8. Linux CentOS官方文档之U盘安装

    2024-01-09 16:06:02       39 阅读
  9. ACP科普:为什么Scrum的冲刺周期不变?

    2024-01-09 16:06:02       39 阅读
  10. socket从客户端向主机传输一个文件

    2024-01-09 16:06:02       38 阅读
  11. Scrum产品负责人(CSPO)认证Scrum Product Owner

    2024-01-09 16:06:02       36 阅读