政安晨:【Keras机器学习示例演绎】(十九)—— 可视化网络学习内容

目录

简介

设置

建立特征提取模型

设置梯度上升过程

设置端到端滤波器可视化回路

可视化目标层中的前 64 个滤波器


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:显示 convnet 过滤器响应的视觉模式。

简介


在本示例中,我们将研究图像分类模型能学习到哪些视觉模式。我们将使用在 ImageNet 数据集上训练的 ResNet50V2 模型。

我们的过程很简单:我们将创建输入图像,最大限度地激活目标层

选在模型中间的某个位置:层 conv3_block4_out)中的特定滤波器。这些图像代表了过滤器响应模式的可视化。

设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np
import tensorflow as tf

# The dimensions of our input image
img_width = 180
img_height = 180
# Our target layer: we will visualize the filters from this layer.
# See `model.summary()` for list of layer names, if you want to change this.
layer_name = "conv3_block4_out"

建立特征提取模型

# Build a ResNet50V2 model loaded with pre-trained ImageNet weights
model = keras.applications.ResNet50V2(weights="imagenet", include_top=False)

# Set up a model that returns the activation values for our target layer
layer = model.get_layer(name=layer_name)
feature_extractor = keras.Model(inputs=model.inputs, outputs=layer.output)

设置梯度上升过程


我们要最大化的 "损失 "只是目标层中特定滤波器激活的平均值。为避免边界效应,我们将边界像素排除在外。

def compute_loss(input_image, filter_index):
    activation = feature_extractor(input_image)
    # We avoid border artifacts by only involving non-border pixels in the loss.
    filter_activation = activation[:, 2:-2, 2:-2, filter_index]
    return tf.reduce_mean(filter_activation)

我们的梯度上升函数只需计算上述损失相对于输入图像的梯度,并更新更新图像,使其趋向于更强地激活目标滤波器的状态。

@tf.function
def gradient_ascent_step(img, filter_index, learning_rate):
    with tf.GradientTape() as tape:
        tape.watch(img)
        loss = compute_loss(img, filter_index)
    # Compute gradients.
    grads = tape.gradient(loss, img)
    # Normalize gradients.
    grads = tf.math.l2_normalize(grads)
    img += learning_rate * grads
    return loss, img

设置端到端滤波器可视化回路

我们的流程如下

* 从接近 "全灰 "的随机图像(即视觉上的净图像)开始
* 重复应用上文定义的梯度上升阶跃函数
* 通过对输入图像进行归一化处理、居中裁剪并将其限制在 [0, 255] 范围内,将生成的输入图像转换为可显示的形式。

def initialize_image():
    # We start from a gray image with some random noise
    img = tf.random.uniform((1, img_width, img_height, 3))
    # ResNet50V2 expects inputs in the range [-1, +1].
    # Here we scale our random inputs to [-0.125, +0.125]
    return (img - 0.5) * 0.25


def visualize_filter(filter_index):
    # We run gradient ascent for 20 steps
    iterations = 30
    learning_rate = 10.0
    img = initialize_image()
    for iteration in range(iterations):
        loss, img = gradient_ascent_step(img, filter_index, learning_rate)

    # Decode the resulting input image
    img = deprocess_image(img[0].numpy())
    return loss, img


def deprocess_image(img):
    # Normalize array: center on 0., ensure variance is 0.15
    img -= img.mean()
    img /= img.std() + 1e-5
    img *= 0.15

    # Center crop
    img = img[25:-25, 25:-25, :]

    # Clip to [0, 1]
    img += 0.5
    img = np.clip(img, 0, 1)

    # Convert to RGB array
    img *= 255
    img = np.clip(img, 0, 255).astype("uint8")
    return img

让我们在目标图层中使用滤镜 0 试试看:

from IPython.display import Image, display

loss, img = visualize_filter(0)
keras.utils.save_img("0.png", img)

这就是目标层 0 号滤波器响应最大化的输入结果:

display(Image("0.png"))

可视化目标层中的前 64 个滤波器

现在,让我们将目标层中的前 64 个滤波器做成一个 8x8 的网格,以了解模型学习到的不同视觉模式的范围。

# Compute image inputs that maximize per-filter activations
# for the first 64 filters of our target layer
all_imgs = []
for filter_index in range(64):
    print("Processing filter %d" % (filter_index,))
    loss, img = visualize_filter(filter_index)
    all_imgs.append(img)

# Build a black picture with enough space for
# our 8 x 8 filters of size 128 x 128, with a 5px margin in between
margin = 5
n = 8
cropped_width = img_width - 25 * 2
cropped_height = img_height - 25 * 2
width = n * cropped_width + (n - 1) * margin
height = n * cropped_height + (n - 1) * margin
stitched_filters = np.zeros((width, height, 3))

# Fill the picture with our saved filters
for i in range(n):
    for j in range(n):
        img = all_imgs[i * n + j]
        stitched_filters[
            (cropped_width + margin) * i : (cropped_width + margin) * i + cropped_width,
            (cropped_height + margin) * j : (cropped_height + margin) * j
            + cropped_height,
            :,
        ] = img
keras.utils.save_img("stiched_filters.png", stitched_filters)

from IPython.display import Image, display

display(Image("stiched_filters.png"))

演绎展示:

Processing filter 0
Processing filter 1
Processing filter 2
Processing filter 3
Processing filter 4
Processing filter 5
Processing filter 6
Processing filter 7
Processing filter 8
Processing filter 9
Processing filter 10
Processing filter 11
Processing filter 12
Processing filter 13
Processing filter 14
Processing filter 15
Processing filter 16
Processing filter 17
Processing filter 18
Processing filter 19
Processing filter 20
Processing filter 21
Processing filter 22
Processing filter 23
Processing filter 24
Processing filter 25
Processing filter 26
Processing filter 27
Processing filter 28
Processing filter 29
Processing filter 30
Processing filter 31
Processing filter 32
Processing filter 33
Processing filter 34
Processing filter 35
Processing filter 36
Processing filter 37
Processing filter 38
Processing filter 39
Processing filter 40
Processing filter 41
Processing filter 42
Processing filter 43
Processing filter 44
Processing filter 45
Processing filter 46
Processing filter 47
Processing filter 48
Processing filter 49
Processing filter 50
Processing filter 51
Processing filter 52
Processing filter 53
Processing filter 54
Processing filter 55
Processing filter 56
Processing filter 57
Processing filter 58
Processing filter 59
Processing filter 60
Processing filter 61
Processing filter 62
Processing filter 63

图像分类模型通过将其输入分解为 "向量基 "纹理过滤器来观察世界。


最近更新

  1. TCP协议是安全的吗?

    2024-04-27 06:56:04       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-27 06:56:04       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-27 06:56:04       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-27 06:56:04       20 阅读

热门阅读

  1. 【开发记录】arm v7配置青龙面板

    2024-04-27 06:56:04       13 阅读
  2. Django用户注册并自动关联到某数据表条目

    2024-04-27 06:56:04       12 阅读
  3. mpv编译播放器无视频输出

    2024-04-27 06:56:04       10 阅读
  4. “npm error code ELSPROBLEMS“问题解决

    2024-04-27 06:56:04       16 阅读
  5. 二分搜索法

    2024-04-27 06:56:04       13 阅读
  6. 前端点击地图上的位置获取当前经纬度

    2024-04-27 06:56:04       11 阅读
  7. LeetCode 每日一题 ---- 【2739.总行驶距离】

    2024-04-27 06:56:04       14 阅读
  8. 数据整合与 IT 自动化:工业企业的转型之路

    2024-04-27 06:56:04       14 阅读