【手撕算法系列】BN

BN的计算公式

在这里插入图片描述

BN中均值与方差的计算

在这里插入图片描述

所以对于输入x: b,c,h,w
则 mean: 1,c,1,1
	var: 1,c,1,1

代码

class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        # num_features:完全连接层的输出数量或卷积层的输出通道数。
        # num_dims:2表示完全连接层,4表示卷积层    
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)
 
    def forward(self, x, momentum=0.9, eps=1e-5):
        if self.training:
            assert len(x.shape) in (2, 4)
            #判断是全连接层还是卷积层,2代表全连接层,样本数和特征数;4代表卷积层,批量数,通道数,高宽
            if len(x.shape) == 2:
                # 使用全连接层的情况,计算特征维上的均值和方差
                mean = x.mean(dim=0, keepdim=True)
                var = x.var(dim=0, keepdim=True)
            else:
                # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
                mean = x.mean(dim=(0, 2, 3), keepdim=True)  # 1, c, 1, 1
                var = x.var(dim=(0, 2, 3), keepdim=True)

            # 训练模式下,用当前的均值和方差做标准化
            x_hat = (x - mean) / torch.sqrt(var + eps)
            # 更新移动平均的均值和方差
            self.moving_mean = momentum * self.moving_mean + (1.0 - momentum) * mean
            self.moving_var = momentum * self.moving_var + (1.0 - momentum) * var
        
        else:
            x_hat = (x - self.moving_mean) / torch.sqrt(self.moving_var + eps)

        out = self.gamma * x_hat + self.beta
        return out

相关推荐

  1. 算法系列】k-means

    2023-12-17 17:10:01       63 阅读
  2. 算法系列----Dijkstra单源最短路径

    2023-12-17 17:10:01       44 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2023-12-17 17:10:01       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2023-12-17 17:10:01       101 阅读
  3. 在Django里面运行非项目文件

    2023-12-17 17:10:01       82 阅读
  4. Python语言-面向对象

    2023-12-17 17:10:01       91 阅读

热门阅读

  1. 使用Yellowbrick绘制获取最佳聚类K值的示例

    2023-12-17 17:10:01       56 阅读
  2. 【vue filters 过滤器】vue页面 全局使用

    2023-12-17 17:10:01       57 阅读
  3. RK3568-PWM

    2023-12-17 17:10:01       52 阅读
  4. Optee在嵌入式系统中是否支持多线程机制

    2023-12-17 17:10:01       53 阅读
  5. Word Excel模版引擎

    2023-12-17 17:10:01       69 阅读
  6. 设计模式——原型模式代码示例

    2023-12-17 17:10:01       52 阅读
  7. 通过接口引用对象

    2023-12-17 17:10:01       54 阅读
  8. 一句话分清C/C++声明和定义

    2023-12-17 17:10:01       57 阅读
  9. Vue3源码梳理:响应式系统的前世今生

    2023-12-17 17:10:01       51 阅读
  10. 数据库处理与分组存储

    2023-12-17 17:10:01       53 阅读