《神经网络与深度学习:案例与实践》动手练习1.3

飞桨AI Studio星河社区-人工智能学习与实训社区

动手练习1.3

执行上述算子的反向过程,并验证梯度是否正确。

import math


class Op(object):
    def __init__(self):
        pass

    def __call__(self, inputs):
        return self.forward(inputs)

    # 前向函数
    # 输入:张量inputs
    # 输出:张量outputs
    def forward(self, inputs):
        # return outputs
        raise NotImplementedError

    # 反向函数
    # 输入:最终输出对outputs的梯度outputs_grads
    # 输出:最终输出对inputs的梯度inputs_grads
    def backward(self, outputs_grads):
        # return inputs_grads
        raise NotImplementedError


class add(Op):
    def __init__(self):
        super(add, self).__init__()

    def __call__(self, x, y):
        return self.forward(x, y)

    def forward(self, x, y):
        self.x = x
        self.y = y
        outputs = x + y
        return outputs

    def backward(self, grads):
        grads_x = grads * 1
        grads_y = grads * 1
        return grads_x, grads_y


class multiply(Op):
    def __init__(self):
        super(multiply, self).__init__()

    def __call__(self, x, y):
        return self.forward(x, y)

    def forward(self, x, y):
        self.x = x
        self.y = y
        outputs = x * y
        return outputs

    def backward(self, grads):
        grads_x = grads * self.y
        grads_y = grads * self.x
        return grads_x, grads_y


class exponential(Op):
    def __init__(self):
        super(exponential, self).__init__()

    def forward(self, x):
        self.x = x
        outputs = math.exp(x)
        return outputs

    def backward(self, grads):
        grads = grads * math.exp(self.x)
        return grads


a, b, c, d = 2, 3, 2, 2
multiply_op1 = multiply()
f1=multiply_op1(a,b)
multiply_op2 = multiply()
f2=multiply_op2(c,d)
add_op = add()
f3=add_op(f1,f2)
exp_op = exponential()
f4=exp_op(f3)

print(f4)

val1=exp_op.backward(grads=1)
val2=add_op.backward(val1)
val3=multiply_op1.backward(val2[0])
val4=multiply_op2.backward(val2[0])
print(val3)
print(val4)

相关推荐

  1. 【前言】神经网络深度学习简介

    2024-04-23 07:02:04       18 阅读
  2. 神经网络深度学习(三)

    2024-04-23 07:02:04       12 阅读
  3. 神经网络深度学习(四)

    2024-04-23 07:02:04       9 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-23 07:02:04       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-23 07:02:04       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-23 07:02:04       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-23 07:02:04       18 阅读

热门阅读

  1. 【Leetcode】并查集/DFS/BFS多解

    2024-04-23 07:02:04       12 阅读
  2. Hive进阶(5)----yarn的资源调度策略

    2024-04-23 07:02:04       12 阅读
  3. OSPF的防止环路的机制

    2024-04-23 07:02:04       12 阅读
  4. 从C到Py:Python的字符串及正则表达式

    2024-04-23 07:02:04       12 阅读
  5. Golang学习笔记--Gin框架

    2024-04-23 07:02:04       14 阅读
  6. F5应用及配置

    2024-04-23 07:02:04       12 阅读
  7. flutter 按钮动画 AnimatedPress

    2024-04-23 07:02:04       14 阅读
  8. Flutter 从源码扒一扒Stream机制

    2024-04-23 07:02:04       12 阅读
  9. 【26考研】考研备考计划4.22开始

    2024-04-23 07:02:04       13 阅读