注意区分打印网络参数的个数和打印网络参数(权重和偏置)的个数
在TensorFlow 1.0 中,可以通过使用tf.trainable_variables()获取模型的所有可训练参数(即权重和偏置),并使用sess.run()在会话中运行这些变量来打印它们的值。
打印网络参数(权重和偏置)
import tensorflow as tf
# 构建模型
# 创建会话
with tf.Session() as sess:
# 初始化所有变量
sess.run(tf.global_variables_initializer())
# 获取所有可训练的变量
trainable_vars = tf.trainable_variables()
# 打印每个变量的名称和值
for var in trainable_vars:
print(var.name)
print(sess.run(var))
打印出网络参数的个数,需要获取每个可训练参数的形状,然后计算它们的乘积来得到每个参数的元素个数。最后,将所有参数的元素个数相加即可得到网络参数的总个数。
import tensorflow as tf
import numpy as np
# 构建模型
# 创建会话
with tf.Session() as sess:
# 初始化所有变量
sess.run(tf.global_variables_initializer())
# 获取所有可训练的变量
trainable_vars = tf.trainable_variables()
# 计算所有参数的总个数
total_parameters = 0
for variable in trainable_vars:
# 获取变量的形状,例如[5, 5, 1, 32]表示一个5x5的32通道卷积核
shape = variable.get_shape()
# 计算当前变量的参数个数,为形状的各维大小的乘积
variable_parametes = 1
for dim in shape:
variable_parametes *= dim.value
# 将当前变量的参数个数加到总个数上
total_parameters += variable_parametes
print("Total number of parameters in the network: {}".format(total_parameters))