MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)

查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

1.MNIST的数据集下载与预处理

import tensorflow as tf
from keras.datasets import mnist
from keras.utils import to_categorical

(train_x,train_y),(test_x,test_y) = mnist.load_data()
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32) # 归一化
y_train,y_test = to_categorical(train_y),to_categorical(test_y) # onehot
print(X_train[:5])
print(y_train[:5])

2.搭建MLP模型

from keras import Sequential
from keras.layers import Flatten,Dense
from keras import Input

model = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

3.模型训练

3.1 调用model.compile()函数对训练模型进行设置

model.compile(optimizer='adam',
			  loss='categorical_crossentropy',
              metrics=['accuracy'])
  • loss=‘categorical_crossentropy’: 损失函数设置为交叉熵损失函数,在深度学习中用交叉熵模式训练效果会比较好。
  • optimizer=‘adam’: 优化器设置为adam, 在深度学习中可以让训练更快收敛,并提高准确率。
  • metrics=[‘accuracy’]:评估模式设置为准确度评估模式。

loss参数常用的损失函数

  • binary_crossentropy: 亦称作对数损失,logloss
  • categorical_crossentropy: 交叉熵损失函数,亦称作多类的对数损失,注意使用该目标函数时,需要将标签转化为onehot形式
  • sparse_categorical_crossentropy:稀疏交叉熵损失函数。
  • kullback_leibler_divergence: 从预测值概率分布Q到真值概率分布P的信息增益,用以度量两个分布的差异
  • poisson: 即(pred-target*log(pred))的均值
  • cosine_proximity:预测值与真实标签的余弦距离平均值的相反数

优化器

  • SGD
  • RMSprop
  • Adagrad
  • Adadelta
  • Adam
  • Adamax
  • Nadam
  • TFOptimizer

评估模式

  • binary_accuracy: 对二分类问题,计算在所有预测值上的平均正确率
  • categorical_accuracy: 对多分类问题,计算在所有预测值上的平均正确率
  • sparse_categorical_accuracy:与categorical_accuracy相同,在对稀疏的目标值预测时有用
  • top_k_categorical_accuracy: 计算top-k正确率,当预测值的前K个值中存在目标类别即认为预测正确
  • sparse_top_k_categorical_accuracy: 与top_k_categorical_accuracy作用相同,但适用于稀疏情况

3.2 调用model.fit()配置训练参数,开始训练,并保存训练结果。

H = model.fit(x=X_train,
			  y=y_train,
			  validation_split=0.2,
			  epochs=20,
		      batch_size=128,
			  verbose=1)

在这里插入图片描述

4.显示模型准确率和误差

import matplotlib.pyplot as plt

def show_train(history,train,validation):
    plt.plot(history.epoch, history.history[train],label=train)
    plt.plot(history.epoch, history.history[validation],label=validation)
    plt.title(train)
    plt.legend()
    plt.show()
    
show_train(H,'loss','val_loss')
show_train(H,'accuracy','val_accuracy')

在这里插入图片描述

5.使用测试数据进行识别

import numpy as np
import matplotlib.pyplot as plt

def pred_plot_images_lables(images,labels,start_idx,num=5):
    # 预测
    res = model.predict(images[start_idx:start_idx+num])
    res = np.argmax(res,axis=1)

    # 画图
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

pred_plot_images_lables(X_test,test_y,0,5)

在这里插入图片描述

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-05-04 11:44:03       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-05-04 11:44:03       101 阅读
  3. 在Django里面运行非项目文件

    2024-05-04 11:44:03       82 阅读
  4. Python语言-面向对象

    2024-05-04 11:44:03       91 阅读

热门阅读

  1. C++泛型算法2——谓词,lambda表达式

    2024-05-04 11:44:03       24 阅读
  2. Web开发:使用url引用图片

    2024-05-04 11:44:03       34 阅读
  3. 等级保护科普小知识

    2024-05-04 11:44:03       29 阅读
  4. 设计模式(软件设计师第5版)

    2024-05-04 11:44:03       31 阅读
  5. 【C++并发编程】(二)线程的创建、分离和连接

    2024-05-04 11:44:03       37 阅读
  6. MySQL45讲(一)(42)

    2024-05-04 11:44:03       29 阅读
  7. pycharm批量注释或取消多行

    2024-05-04 11:44:03       35 阅读