神经网络设计过程

1.可根据Iris特征直接判断

2.神经网络方法,采集大量的Iris特征,分类对应标签,构成数据集。

将数据集喂入搭好的神经网络结构,网络通过反向传播优化参数得到模型。

有新的网络送入到模型里,模型会给出识别结果。

3.具体实现过程。

Iris有4个特征,分类总共三种,构成的神经网络如下所示

MP模型

 

转化为Iris Y(1, 3) = X (1, 4) * W(4, 3) + b(3,) 三个偏置项

此网络为全连接网络

神经网络执行前向传播  y = x * w + b

 

因为W, B是随机初始化,所以答案不准确

运用损失函数定义预测值(y)与标准答案(y_)之间的差距。

损失函数可以判断当前W和b的优劣,当损失函数值最小时,W和b最优

损失函数的表达方法之一就是:均方误差:MSE(y, y_) = ( \sum_{k = 0}^{n} (y - y)^{2})/ n

表示网络前向传播推理与标准答案之间的差距。

目的:找到最优的W和b

梯度:函数对各个参数求偏导后的向量。函数梯度下降方向是函数减小的方向

梯度下降法:沿损失函数梯度下降的方向,寻找损失函数的最小值,得到最优的参数的方法。

学习率:当学习率设置的过小时,收敛过程变得缓慢,过大,会在错过最小值

梯度下降法更新参数的计算。 lr 学习率

 求wt+1 = wt - lr*对wt求偏导

 

 这里假设损失函数为(W+1)^2, 对w进行求偏导

import tensorflow as tf

w = tf.Variable(tf.constant(5, dtype=tf.float32))
lr = 0.2
epoch = 40

for epoch in range(epoch):  # for epoch 定义顶层循环,表示对数据集循环epoch次,此例数据集数据仅有1个w,初始化时候constant赋值为5,循环40次迭代。
    with tf.GradientTape() as tape:  # with结构到grads框起了梯度的计算过程。
        loss = tf.square(w + 1)
    grads = tape.gradient(loss, w)  # .gradient函数告知谁对谁求导

    w.assign_sub(lr * grads)  # .assign_sub 对变量做自减 即:w -= lr*grads 即 w = w - lr*grads
    print("After %s epoch,w is %f,loss is %f" % (epoch, w.numpy(), loss))

# lr初始值:0.2   请自改学习率  0.001  0.999 看收敛过程
# 最终目的:找到 loss 最小 即 w = -1 的最优参数w

相关推荐

  1. Pytorch:神经网络过程代码详解

    2024-07-10 20:46:03       18 阅读
  2. 解密神经网络:深入探究传播机制与学习过程

    2024-07-10 20:46:03       29 阅读
  3. 通过神经网络模拟人类大脑的学习过程

    2024-07-10 20:46:03       37 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-10 20:46:03       5 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-10 20:46:03       5 阅读
  3. 在Django里面运行非项目文件

    2024-07-10 20:46:03       4 阅读
  4. Python语言-面向对象

    2024-07-10 20:46:03       7 阅读

热门阅读

  1. redis实现延时队列

    2024-07-10 20:46:03       10 阅读
  2. Shell选择结构

    2024-07-10 20:46:03       14 阅读
  3. Poincaré图和SD2计算参考

    2024-07-10 20:46:03       10 阅读
  4. C#控件总结

    2024-07-10 20:46:03       9 阅读
  5. STM32(一):安装环境

    2024-07-10 20:46:03       10 阅读
  6. 数据中台真的适合你的企业吗?

    2024-07-10 20:46:03       9 阅读
  7. [AIGC] ClickHouse的表引擎介绍

    2024-07-10 20:46:03       13 阅读
  8. go 函数

    2024-07-10 20:46:03       11 阅读
  9. 玩转springboot之springboot项目监测

    2024-07-10 20:46:03       10 阅读
  10. 【LeetCode】每日一题:跳跃游戏 II

    2024-07-10 20:46:03       10 阅读
  11. Python面试题: 如何在 Python 中实现一个线程池?

    2024-07-10 20:46:03       13 阅读