MLP手写数字识别(3)-使用tf.data.Dataset模块制作模型输入(tensorflow)

1、tensorflow版本查看

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、MNIST数据集下载与预处理

(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()
train_images,test_images = tf.cast(train_images/255.0,tf.float32),tf.cast(test_images/255.0,tf.float32) # 归一化

tf.data.Dataset制作训练数据集

ds_train_image = tf.data.Dataset.from_tensor_slices(train_images)
ds_train_label = tf.data.Dataset.from_tensor_slices(train_labels)
ds_train = tf.data.Dataset.zip((ds_train_image,ds_train_label))
ds_train = ds_train.shuffle(10000).repeat().batch(64) # 乱序,无限次重复,每次取64张图片

print(ds_train_image)
print(ds_train_label)
print(ds_train)

在这里插入图片描述

tf.data.Dataset制作测试数据集

ds_test = tf.data.Dataset.from_tensor_slices((test_images,test_labels))
ds_test = ds_test.repeat().batch(64)

print(ds_test)

在这里插入图片描述

3、搭建MLP模型

from keras import Sequential
from keras.layers import Flatten,Dense
from keras import Input

model = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
# model.summary()

4、模型训练

调用model.compile()函数对训练模型进行设置

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

调用model.fit()配置训练参数,开始训练,并保存训练结果。

steps_per_epochs = train_images.shape[0]//64  # 937

H = model.fit(ds_train,  # 训练数据集
              steps_per_epoch=steps_per_epochs,  # 每个epoch训练步数
              validation_data=ds_test,  #验证数据集
              validation_steps=10000//64,
              epochs=10,
              verbose=1)

在这里插入图片描述

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-05-11 18:08:06       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-05-11 18:08:06       100 阅读
  3. 在Django里面运行非项目文件

    2024-05-11 18:08:06       82 阅读
  4. Python语言-面向对象

    2024-05-11 18:08:06       91 阅读

热门阅读

  1. 设计模式:命令模式

    2024-05-11 18:08:06       31 阅读
  2. 利用干扰源模型确定多通道音频信号盲源分离

    2024-05-11 18:08:06       34 阅读
  3. OceanBase OAT安装

    2024-05-11 18:08:06       30 阅读
  4. 单播、组播、广播

    2024-05-11 18:08:06       37 阅读
  5. PYTHON利用实时交易量智能股票交易系统

    2024-05-11 18:08:06       36 阅读
  6. MYSQL SQL优化思路和方法

    2024-05-11 18:08:06       77 阅读
  7. fastapi数据库连接池的模版

    2024-05-11 18:08:06       35 阅读
  8. D3.js实战:数据可视化高级技巧实例应用

    2024-05-11 18:08:06       35 阅读
  9. idea

    idea

    2024-05-11 18:08:06      33 阅读
  10. postman---认证(Certificates)是什么作用?

    2024-05-11 18:08:06       33 阅读
  11. git命令详解+使用样例

    2024-05-11 18:08:06       37 阅读