TensorFlow打印网络参数的个数

注意区分打印网络参数的个数和打印网络参数(权重和偏置)的个数

在TensorFlow 1.0 中,可以通过使用tf.trainable_variables()获取模型的所有可训练参数(即权重和偏置),并使用sess.run()在会话中运行这些变量来打印它们的值。

打印网络参数(权重和偏置)

import tensorflow as tf

# 构建模型

# 创建会话
with tf.Session() as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())
    
    # 获取所有可训练的变量
    trainable_vars = tf.trainable_variables()
    
    # 打印每个变量的名称和值
    for var in trainable_vars:
        print(var.name)
        print(sess.run(var))

打印出网络参数的个数,需要获取每个可训练参数的形状,然后计算它们的乘积来得到每个参数的元素个数。最后,将所有参数的元素个数相加即可得到网络参数的总个数。

import tensorflow as tf
import numpy as np

# 构建模型

# 创建会话
with tf.Session() as sess:
    # 初始化所有变量
    sess.run(tf.global_variables_initializer())
    
    # 获取所有可训练的变量
    trainable_vars = tf.trainable_variables()
    
    # 计算所有参数的总个数
    total_parameters = 0
    for variable in trainable_vars:
        # 获取变量的形状,例如[5, 5, 1, 32]表示一个5x5的32通道卷积核
        shape = variable.get_shape()
        
        # 计算当前变量的参数个数,为形状的各维大小的乘积
        variable_parametes = 1
        for dim in shape:
            variable_parametes *= dim.value
        
        # 将当前变量的参数个数加到总个数上
        total_parameters += variable_parametes
    
    print("Total number of parameters in the network: {}".format(total_parameters))

相关推荐

  1. TensorFlow打印网络参数个数

    2024-03-24 07:12:01       39 阅读
  2. sqlalchemy打印querySQL和参数

    2024-03-24 07:12:01       23 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-24 07:12:01       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-24 07:12:01       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-24 07:12:01       82 阅读
  4. Python语言-面向对象

    2024-03-24 07:12:01       91 阅读

热门阅读

  1. ORACLE 知识整理

    2024-03-24 07:12:01       36 阅读
  2. TensorFlow 的基本概念和使用场景

    2024-03-24 07:12:01       37 阅读
  3. C语言中的static关键字

    2024-03-24 07:12:01       44 阅读
  4. 二进制源码部署mysql8.0.35

    2024-03-24 07:12:01       31 阅读
  5. node.js中常用的命令及示例

    2024-03-24 07:12:01       36 阅读