python手写数字识别(PaddlePaddle框架、MNIST数据集)

python手写数字识别(PaddlePaddle框架、MNIST数据集)

import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import Compose, Normalize

transform = Compose([Normalize(mean=[127.5],
                               std=[127.5],
                               data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
# 使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')

# 用paddle.nn下的API,如Conv2D、MaxPool2D、Linear完成卷积神经网络的构建
class CNN(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1,stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x


# 开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练
train_loader = paddle.io.DataLoader(train_dataset, batch_size=128, shuffle=True)
# 加载训练集 batch_size 设为 128
def train(model):
    model.train()
    epochs = 3
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    # 用Adam作为优化函数
    print("Training:")
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = data[1]
            predicts = model(x_data)
            loss = F.cross_entropy(predicts, y_data)
            # 计算损失
            acc = paddle.metric.accuracy(predicts, y_data)
            loss.backward()

            if batch_id % 300 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            optim.step()
            optim.clear_grad()
model = CNN()
train(model)


# 训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=128)
# 加载测试数据集
def test(model):
    model.eval()
    print("Testing:")
    for batch_id, data in enumerate(test_loader()):
        x_data = data[0]
        y_data = data[1]
        predicts = model(x_data)
        # 获取预测结果
        loss = F.cross_entropy(predicts, y_data)
        acc = paddle.metric.accuracy(predicts, y_data)

        if batch_id % 50 == 0:
            print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))
test(model)

在这里插入图片描述

相关推荐

  1. MINIST数据&数字识别

    2024-05-16 06:32:13       30 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-05-16 06:32:13       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-05-16 06:32:13       101 阅读
  3. 在Django里面运行非项目文件

    2024-05-16 06:32:13       82 阅读
  4. Python语言-面向对象

    2024-05-16 06:32:13       91 阅读

热门阅读

  1. 《图像处理的璀璨星空:技术演进与热点聚焦》

    2024-05-16 06:32:13       35 阅读
  2. Uniapp基础面试

    2024-05-16 06:32:13       40 阅读
  3. iOS 学习资料

    2024-05-16 06:32:13       36 阅读
  4. Rust语言实现图像编码转换

    2024-05-16 06:32:13       37 阅读
  5. DB类的学习

    2024-05-16 06:32:13       34 阅读
  6. 从HTTP迁移到HTTPS:一篇全面的测试方案设计指南

    2024-05-16 06:32:13       39 阅读
  7. MyBatis的一二级缓存区别

    2024-05-16 06:32:13       32 阅读
  8. http 和 https 的区别及原理解析

    2024-05-16 06:32:13       35 阅读
  9. 阅读笔记——《代码整洁之道》ch2

    2024-05-16 06:32:13       33 阅读
  10. ifconfig 无输出

    2024-05-16 06:32:13       31 阅读