用全连接对手写数字识别案例(附解决TensorFlow2.x没有examples问题)

数据集介绍

数据集直接调用可能出现问题,建议从官网直接下载下来,下载存在这四个文件

手写数字识别数据集下载:

链接:https://pan.baidu.com/s/1nqhP4yPNcqefKYs91jp9ng?pwd=xe1h 
提取码:xe1h

55000行训练数据集(minst.train)和10000行测试数据集(mnist.test)。

每一个数据单元分为两部分:一张含手写数字的图片和一个对应的标签 训练数据集图片mnist.train.images,训练数据集的标签是mnist.train.labels

特征值:黑白图片,每张包含28像素*28像素 

目标值:分类 --one-hot编码  那一列标号为1属于哪个类别(0-9数字表示10列,对应看标号1属于                 哪列表示哪个值)             

Mnist数据获取API

TensorFlow框架自带了获取这个数据集的接口,我们直接调用即可

》from trnsorflow.exampls.tutorials.mnist import_data

       mnist = input_data.read_data_sets(path,one_hot = True)

             mnist.train.next_batch(100) 提供批量获取

             mnist.train.image、 labels

             mnist.test.image、 labels

实战

1、网络设计 采用一层,即最后一个输出层的神经网络 ----全连接神经网络

2、全连接层计算

y = w1x1 +w2x2+....+b

x[None,784] * weights[784,10]  + bias[10]= y_predict[None,10]

平均损失:error=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,

                   logits=y_predict, name=None))

优化损失:梯度下降

计算准确率:比较输出结果最大值所在位置和真实值的最大值所在位置   一致返回1不一致返回0

                      计算平均

3、代码:

import tensorflow as tf
from tensorflow.examples.speech_commands import input_data

tf.compat.v1.disable_eager_execution()
from tensorflow.examples.tutorials.mnist import input_data

def full_connection():
    """
    用全连接对手写数字进行识别
    :return:
    """
    # 1)准备数据
    mnist = input_data.read_data_sets("D:\heima\Python深度之神经网络资料\02-代码\mnist_data", one_hot=True)
    # 用占位符定义真实数据
    X = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 784])
    y_true = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 10])

    # 2)构造模型 - 全连接
    # [None, 784] * W[784, 10] + Bias = [None, 10]
    weights = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[784, 10], stddev=0.01))
    bias = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[10], stddev=0.1))
    y_predict = tf.matmul(X, weights) + bias

    # 3)构造损失函数
    loss_list = tf.nn.softmax_cross_entropy_with_logits(logits=y_predict, labels=y_true)
    loss = tf.reduce_mean(loss_list)

    # 4)优化损失
    # optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.01).minimize(loss)

    # 5)增加准确率计算
    bool_list = tf.equal(tf.argmax(y_true, axis=1), tf.argmax(y_predict, axis=1))
    accuracy = tf.reduce_mean(tf.cast(bool_list, tf.float32))

    # 初始化变量
    init = tf.compat.v1.global_variables_initializer()

    # 开启会话
    with tf.compat.v1.Session() as sess:

        # 初始化变量
        sess.run(init)

        # 开始训练
        for i in range(5000):
            # 获取真实值
            image, label = mnist.train.next_batch(500)

            _, loss_value, accuracy_value = sess.run([optimizer, loss, accuracy], feed_dict={X: image, y_true: label})

            print("第%d次的损失为%f,准确率为%f" % (i+1, loss_value, accuracy_value))


    return None

if __name__ == "__main__":
    full_connection()

解决无examples问题

TensorFlow2.X版本没有examples 下载后又发现里面缺少tutoritus,我是下载的直接放在example里面

将下载的exaples放入下面这个路径:

然后将下载的tutorials放到examples里面 问题就解决了

examples下载:

链接:https://pan.baidu.com/s/1fGan_JGGARIUPror3x6mvw?pwd=3kaz 
提取码:3kaz

tutorials包下载:

链接:https://pan.baidu.com/s/1adA6hUWbxfXXFNRpkYoJIw 
提取码:ffg6

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-22 16:16:01       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-22 16:16:01       106 阅读
  3. 在Django里面运行非项目文件

    2024-04-22 16:16:01       87 阅读
  4. Python语言-面向对象

    2024-04-22 16:16:01       96 阅读

热门阅读

  1. 基于Python调用Gurobi求解器的入门文档

    2024-04-22 16:16:01       35 阅读
  2. 艾体宝观察 | 2024,如何开展网络安全风险分析

    2024-04-22 16:16:01       35 阅读
  3. 【无标题】

    2024-04-22 16:16:01       29 阅读
  4. PHP按自然月计算未来日期

    2024-04-22 16:16:01       30 阅读
  5. 使用Django Rest Framework设计与实现用户注册API

    2024-04-22 16:16:01       29 阅读
  6. SQL开窗函数

    2024-04-22 16:16:01       34 阅读
  7. 机器学习实战-k近邻分类

    2024-04-22 16:16:01       30 阅读