一维卷积神经网络的特征可视化

随着以深度学习为代表的人工智能技术的不断发展,许多具有重要意义的深度学习模型和算法被开发出来,应用于计算机视觉、自然语言处理、语音处理、生物医疗、金融应用等众多行业领域。深度学习先进的数据挖掘、训练和分析能力来源于深度神经网络的海量模型参数以及高度非线性。也正因为深度学习算法的高度复杂性,许多模型往往难以解释其内部工作原理,这导致这些模型被称为缺乏可解释性的“黑箱模型”。

随着AI应用渗透到各行各业,AI的科技伦理受到广泛的关注。而科技伦理的一个核心议题就是可解释人工智能XAI。从社会科学角度,可解释性是指人对决策原因的理解程度,可解释性越高,人就越能理解为什么做出这样的决策。对应于AI领域,可解释性是指能够在一定程度上揭示AI模型内部工作机制和对模型结果的进行解释,帮助用户理解模型是如何做出预测或决策的。

因此,本文简单地对一维卷积神经网络的特征进行可视化,运行环境为Python,研究对象为心电信号。

首先导入相关库

import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import signal_screen
import signal_screen_tools

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv1D, MaxPool1D, Flatten, BatchNormalization, Input
from tensorflow.keras.callbacks import ModelCheckpoint

数据导入及处理

# load data
data_train = pd.read_csv("mitbih_train.csv", sep=",", header=None).to_numpy()
data_test = pd.read_csv("mitbih_test.csv", sep=",", header=None).to_numpy()

# get X and y
X_train, y_train = data_train[:, :data_train.shape[1]-2], data_train[:, -1]
X_test, y_test = data_test[:, :data_test.shape[1]-2], data_test[:, -1]

# number of categories
num_of_categories = np.unique(y_train).shape[0]

del data_train, data_test

#indexing examples to show visualisations
examples_to_visualise = [np.where(y_test == i)[0][0] for i in range(5)]
titles = [ "nonectopic", "supraventricular ectopic beat", "ventricular ectopic beat", "fusion beat", "unknown"]

# creation of tensors
X_train = np.expand_dims(tf.convert_to_tensor(X_train), axis=2)
X_test = np.expand_dims(tf.convert_to_tensor(X_test), axis=2)

# one-hot encoding for 5 categories
y_train = tf.one_hot(y_train, num_of_categories)
y_test = tf.one_hot(y_test, num_of_categories)

建立模型并进行训练

# basic model
model = Sequential([
    Input(shape=[X_train.shape[1], 1]),
    Conv1D(filters=16, kernel_size=3, activation="relu"),
    BatchNormalization(),
    MaxPool1D(),
    Conv1D(filters=32, kernel_size=3, activation="relu"),
    BatchNormalization(),
    Conv1D(filters=64, kernel_size=3, activation="relu"),
    BatchNormalization(),
    Flatten(),
    Dense(20, activation="relu"),
    Dense(num_of_categories, activation="softmax")
]
)

# train process

model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

checkPoint = ModelCheckpoint(filepath="model.h5", save_weights_only=False, monitor='val_accuracy',
                            mode='max', save_best_only=True)

model.fit(x=np.expand_dims(X_train, axis=2), y=y_train,
          batch_size=128, epochs=10, validation_data=(np.expand_dims(X_test, axis=2), y_test),
          callbacks=[checkPoint])

model = tf.keras.models.load_model("model.h5")
loss, acc = model.evaluate(np.expand_dims(X_test, axis=2), y_test)

采用Occlusion Sensitivity方法进行可视化,相关的参考文献较多。

fig, axs = plt.subplots(nrows=5, ncols=1)
fig.suptitle("Occlusion sensitivity")
fig.tight_layout()
fig.set_size_inches(10, 10)
axs = axs.ravel()

for c, row, ax, title in zip(range(5), examples_to_visualise, axs, titles):
    sensitivity, _ = signal_screen.calculate_occlusion_sensitivity(model=model,
                                                         data=np.expand_dims(X_test[row, :], axis=(0, 2)),
                                                         c=c,
                                                         number_of_zeros=[15])

    # create gradient plot
    signal_screen_tools.plot_with_gradient(ax=ax, y=X_test[row, :].ravel(), gradient=sensitivity[0], title=title)
    ax.set_xlabel("Samples[-]")
    ax.set_ylabel("ECG [-]")

plt.show()

采用Saliency map方法进行可视化。

采用Grad-CAM方法进行可视化。

工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

最近更新

  1. TCP协议是安全的吗?

    2024-04-04 00:16:01       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-04 00:16:01       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-04 00:16:01       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-04 00:16:01       20 阅读

热门阅读

  1. 网络学习笔记 01 计算机硬件基础 - 数据的表示

    2024-04-04 00:16:01       13 阅读
  2. 从零开始学RSA:N不互素

    2024-04-04 00:16:01       12 阅读
  3. The Morning Star

    2024-04-04 00:16:01       15 阅读
  4. Windows——什么是进程?

    2024-04-04 00:16:01       11 阅读
  5. (译) 理解 Elixir 中的宏 Macro, 第四部分:深入化

    2024-04-04 00:16:01       13 阅读
  6. ViT模型实现-数据处理

    2024-04-04 00:16:01       19 阅读
  7. android 内存优化

    2024-04-04 00:16:01       16 阅读
  8. 财务管理 基础1:除了利润,一切都是扯淡

    2024-04-04 00:16:01       17 阅读