【机器学习】基于CTC模型的语音转换可编辑文本研究

1.引言

1.1语音识别技术的研究背景

1.1.1.语音识别技术的需求

语音识别技术的研究和发展,对于提升人类与机器的交互方式具有深远的影响。首先,它极大地提高了工作效率和便利性。通过语音指令控制设备,用户可以更快捷地完成任务,无需手动输入或操作。例如,在办公环境中,语音识别可以快速完成文档编辑、邮件发送等任务;在家庭中,智能家居设备响应语音指令,实现灯光、温度等的调节。这种交互方式不仅节省时间,也使得那些由于身体条件限制而难以使用传统输入方式的人群能够更轻松地使用技术。

其次,语音识别技术对于实现无障碍访问至关重要。它为视觉障碍人士提供了一种全新的与技术互动的方式,使他们能够独立完成日常任务,如阅读信息、上网浏览等。此外,多语言支持的语音识别技术也促进了全球化背景下的交流与合作,帮助不同语言背景的人们跨越语言障碍,更有效地沟通和使用技术。

最后,语音识别技术在教育、安全、紧急响应以及数据收集和分析等领域的应用,进一步扩展了其研究的价值。在教育领域,它可以帮助学生练习发音和语言技能,提供个性化的学习体验。在安全领域,语音识别可以用于身份验证,提高交易和通信的安全性。在紧急情况下,如医疗急救或灾难响应,快速准确的语音识别可以挽救生命。同时,自动转录的语音数据为市场研究、客户服务和产品开发等领域提供了宝贵的信息资源。

综上所述,语音识别技术的研究不仅推动了人工智能和机器学习领域的进步,而且为社会带来了更广泛的应用和便利,其重要性和影响力正日益增强。

1.1.2. 语音识别的挑战

当我们着手进行语音识别研究时,一个核心的挑战是音频信号与文本转录之间的对齐问题。尽管我们拥有音频片段和相应的文本记录,但音频中的声音与文本中的字符之间的精确对应关系往往是未知的。这种对齐的不确定性大大增加了训练一个高效语音识别系统的任务难度。

在没有精确对齐信息的情况下,采用简单的对应规则——比如假设每个字符对应固定数量的音频样本——是不可行的。由于说话人的语速、口音、情感表达等因素的差异,这种假设很快就会被现实情况所打破。例如,一个快速说话的人可能在很短的时间内说出许多字符,而一个慢速说话的人则可能在较长的时间段内只说出几个字符。

解决这一问题的一种方法是手动对齐,即将每个字符与音频中的特定位置进行匹配。虽然这种方法在理论上可以提供准确的模型训练数据,但它在实际操作中存在巨大的工作量,特别是对于大规模数据集而言,这种方法几乎是不可行的。这不仅需要耗费大量的人力和时间,而且很难保证一致性和准确性。

这种对齐问题并非语音识别领域所独有。在其他领域,如手写文字识别、视频分析等,我们也面临着类似的挑战。在手写文字识别中,需要确定笔划的顺序和结构,以正确识别单词和句子。而在视频分析中,动作识别和标注需要将视频中的连续帧与特定的动作或事件相对应。这些任务都需要精确的时间或空间对齐,以确保识别的准确性。

为了克服这些挑战,研究人员开发了各种算法和模型来自动学习数据中的对齐关系。例如,在语音识别中,可以使用深度学习技术,如循环神经网络(RNN)和长短期记忆网络(LSTM),它们能够捕捉音频信号中的时间依赖性,并学习字符与音频之间的复杂映射关系。在手写文字识别中,卷积神经网络(CNN)可以用于识别笔划模式和结构。而在视频分析中,两维和三维卷积网络可以用于捕捉视频中的空间和时间特征。

1.2. CTC模型简介

为了更正式地描述,让我们考虑将输入序列 𝑋 = [𝑥₁, 𝑥₂, …, 𝑥𝑇](如音频)映射到相应的输出序列 𝑌 = [𝑦₁, 𝑦₂, …, 𝑦𝑈](如转录文本)。我们想要找到从 𝑋 到 𝑌 的准确映射。

然而,使用更简单的监督学习算法会面临一些挑战,特别是:

  1. 𝑋 和 𝑌 的长度都可能变化。
  2. 𝑋 和 𝑌 长度的比例也可能变化。
  3. 我们没有 𝑋 和 𝑌 元素之间的准确对齐(对应)。

CTC 算法克服了这些挑战。对于给定的 𝑋,它给出了所有可能 𝑌 的输出分布。我们可以使用这个分布来推断可能的输出或评估给定输出的概率。

并非所有计算损失函数和执行推理的方法都是可行的。我们需要 CTC 高效地执行这两项任务。

损失函数:对于给定的输入,我们希望训练模型以最大化其对正确答案的分配概率。为此,我们需要有效地计算条件概率 𝑝(𝑌|𝑋)。函数 𝑝(𝑌|𝑋) 也必须是可微的,以便我们可以使用梯度下降。

推理:自然地,在训练模型之后,我们想要使用它来在给定的 𝑋 下推断可能的 𝑌。这意味着解决 𝑌∗ = argmax𝑌 𝑝(𝑌|𝑋) 的问题。理想情况下,可以高效地找到 𝑌∗。使用 CTC,我们将寻求一个不太昂贵且易于找到的近似解。

1.3.研究内容

语音识别技术,作为人工智能领域内的一个重要分支,是计算机科学与计算语言学相互交融的产物。它的目标是使计算机能够理解并转换人类的语音为可编辑的文本。这项技术常被称为自动语音识别(ASR)、计算机语音识别或语音到文本(STT)。它综合了计算机科学、语言学和计算机工程等多个学科的研究成果,为人工智能的进步和人机交互的优化提供了坚实的基础。

本文将深入探讨如何利用二维卷积神经网络(2D CNN)、循环神经网络(RNN)以及连接时序分类(CTC)损失函数,构建一个高效的自动语音识别系统。CTC算法在解决序列预测问题,尤其是处理音频信号与字符之间的对齐问题时,显示出了其卓越的能力。

在构建该系统的过程中,我们将采用LJSpeech数据集,这是一个源自LibriVox项目的开源资源。该数据集包含了由单一说话者朗读的短音频片段,这些片段均选自7本非小说类书籍。使用此数据集,我们可以训练出一个能够理解不同语境下语音信号的模型。

为了评估模型的性能,我们将采用词错误率(WER)作为主要的评价指标。WER是衡量语音识别准确性的一种标准方法,它通过计算识别结果中替换、插入和删除错误的数量,并与原始文本中的单词总数进行比较来得出。这一指标能够全面反映模型在语音识别任务中的准确性和稳定性。我们将使用jiwer这个开源Python库来计算WER,它提供了一套便捷的工具,以便于评估ASR系统的性能。

接下来,本文将详细阐述构建基于CTC的ASR模型的各个环节,包括数据预处理、模型架构设计、训练策略以及性能评估等。通过这些内容,读者不仅能够深入理解ASR系统的构建原理和关键技术,而且能够获得宝贵的知识和经验,为未来的研究和实践打下坚实的基础。

2. 语音识别实现过程

2.1. 软件包安装和设置

安装jiwer软件包

pip install jiwer

设置

# 导入 pandas 库,用于数据处理  
import pandas as pd  
  
# 导入 numpy 库,用于数值计算  
import numpy as np  
  
# 导入 tensorflow 库,用于深度学习模型构建  
import tensorflow as tf  
  
# 从 tensorflow 中导入 keras 模块,用于构建和训练深度学习模型  
from tensorflow import keras  
  
# 从 keras 中导入 layers 模块,包含各种神经网络层  
from tensorflow.keras import layers  
  
# 导入 matplotlib.pyplot 库,用于绘图  
import matplotlib.pyplot as plt  
  
# 导入 IPython.display 模块,用于在 Jupyter 环境中显示图像等  
from IPython import display  
  
# 导入 jiwer 库,用于计算词错误率(WER)  
from jiwer import wer  

2.2.加载数据集

本文使用LJSpeech数据集作为研究的数据资源。

2.2.1.LJSpeech数据集简介

LJSpeech数据集是一个为文本到语音合成(TTS)任务设计的公共数据集,由一位女性演讲者朗读7本非小说类书籍的段落录制而成。该数据集包含13,100个简短的音频剪辑,每个音频剪辑都是单通道16位PCM WAV格式,采样率为22050 Hz。音频剪辑的长度从1秒到10秒不等,总长度约为24小时。

数据集附带有元数据文件(通常为CSV格式),其中包含每个音频文件的ID、转录(Transcription)和规范化转录(Normalized transcription)。ID对应.wav文件的名称,转录列显示了读者说出的原始单词,而规范化转录则将数字、序数和货币单位等展开为完整的单词。这一特性使得LJSpeech数据集在文本到语音合成的研究和开发中特别有用,因为它可以帮助模型更好地学习特定说话者的语音特征和表达方式。

2.2.2. 加载数据集步骤

要加载LJSpeech数据集,您可以按照以下步骤进行:

  1. 下载数据集

    • 您可以通过直接链接下载LJSpeech数据集:http://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
    • 或者,您可以使用百度网盘链接下载,链接为:链接,提取码为:7o1a
  2. 解压数据集

    • 下载完成后,您会得到一个名为LJSpeech-1.1.tar.bz2的压缩包。
    • 在Linux系统中,您可以使用以下命令解压:tar -jxvf LJSpeech-1.1.tar.bz2
  3. 了解数据集结构

    • 解压后,您会发现数据集包含两个主要部分:音频文件和元数据。
    • 音频文件位于/wavs/文件夹中,每个文件都是一个单通道16位PCM WAV,采样率为22,050 Hz。
    • 元数据存储在metadata.csv文件中,该文件包含每个音频文件的标签(转录)信息。
  4. 处理元数据

    • metadata.csv文件包含以下字段:
      • ID:这是对应.wav文件的名称
      • Transcription:读者说出的单词(UTF-8)
      • Normalized transcription:使用数字、序数和货币单位进行转录并扩展为完整单词(UTF-8)
    • 在本演示中,我们将使用“Normalized transcription”字段作为标签。
  5. 加载数据

    • 根据您的应用程序或框架,您可能需要编写代码来遍历/wavs/文件夹中的音频文件,并使用metadata.csv中的相应“Normalized transcription”作为标签。
    • 如果您正在使用Python,并希望利用现有的数据处理库(如pandas),您可以先加载metadata.csv文件到一个DataFrame中,然后按照需要进行处理。
# 导入所需的库  
import pandas as pd  
import tensorflow as tf  
from tensorflow.keras import utils  
  
# 数据集的下载链接  
data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"  
  
# 使用keras的utils.get_file函数下载并解压数据集  
# 如果数据集已经下载过,则不会重复下载  
data_path = utils.get_file("LJSpeech-1.1", data_url, untar=True)  
  
# 设置音频文件和元数据文件的路径  
wavs_path = data_path + "/wavs/"  
metadata_path = data_path + "/metadata.csv"  
  
# 读取元数据文件并解析  
# 注意:原始数据可能使用“|”作为分隔符,但具体应查看文件确认  
metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3, encoding='utf-8')  
  
# 为DataFrame设置列名  
metadata_df.columns = ["file_name", "transcription", "normalized_transcription"]  
  
# 保留需要的列:文件名和规范化转录  
metadata_df = metadata_df[["file_name", "normalized_transcription"]]  
  
# 如果需要对数据集进行随机排序,可以使用sample方法  
# 这里设置frac=1表示保留全部数据,但进行随机排序  
# 如果不需要随机排序,可以省略这行代码  
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)  
  
# 显示前3行数据  
print(metadata_df.head(3))

上述代码的主要功能如下:

  1. 下载和解压数据集
    - 使用 TensorFlow(通过其 Keras API)的 utils.get_file 函数从指定的 URL 下载 LJSpeech 数据集(一个 tar.bz2 压缩包)。
    - 如果数据集已经下载过,则不会重复下载。
    - 解压下载的压缩包到指定的路径。

  2. 设置文件路径
    - 定义音频文件(wav 文件)所在的路径 wavs_path
    - 定义元数据文件(CSV 文件)所在的路径 metadata_path

  3. 读取元数据文件
    - 使用 pandas 的 read_csv 函数读取元数据 CSV 文件。
    - 假设 CSV 文件的字段之间使用 “|” 作为分隔符(但实际中需要根据文件内容来确定分隔符)。
    - 读取文件时指定了 header=None,因为 CSV 文件可能没有明确的标题行。
    - 读取时使用了 quoting=3,它告诉 pandas 如何处理引号内的字段(具体行为取决于 pandas 的实现和 CSV 文件的具体内容)。
    - 设置了文件的编码为 ‘utf-8’,以确保能够正确读取 UTF-8 编码的文本。

  4. 处理元数据
    - 为读取的 DataFrame 设置列名:file_name(音频文件名)、transcription(原始转录)和 normalized_transcription(规范化转录)。
    - 保留需要的列:file_namenormalized_transcription,因为在这个示例中我们只关心文件名和规范化后的转录文本。
    - 如果需要,可以对 DataFrame 进行随机排序(通过 sample(frac=1)),然后重置索引(通过 reset_index(drop=True))。这在某些情况下可能是有用的,比如当你想要打乱数据集的顺序时。

  5. 显示部分数据
    - 使用 print(metadata_df.head(3)) 显示元数据 DataFrame 的前 3 行。这有助于验证数据是否已成功加载并处理。

2.2.3.划分数据集

我们现在将数据分割成训练集和验证集。

# 计算分割点,用于将数据集分为90%的训练集和10%的验证集
split = int(len(metadata_df) * 0.90)

# 根据计算出的分割点,将数据集分割为训练集和验证集
df_train = metadata_df[:split]  # 训练集为前90%的数据
df_val = metadata_df[split:]   # 验证集为剩余10%的数据

# 打印训练集的大小
print(f"训练集大小: {len(df_train)}")

# 打印验证集的大小
print(f"验证集大小: {len(df_val)}")

2.3.数据预处理

2.3.1.准备词汇表
# 定义接受的字符集合,用于转录文本
characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "]

# 将字符映射到整数,用于模型的输入
# 使用keras的StringLookup层,设置词汇表为characters,并且定义一个未登录词(OOV)的标记
char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="")

# 将整数映射回原始字符,用于模型的输出
# 使用keras的StringLookup层,词汇表为char_to_num的词汇表反转,同样定义OOV标记
num_to_char = keras.layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True
)

# 打印词汇表和它的大小
print(
    f"词汇表是: {char_to_num.get_vocabulary()} "
    f"(大小 = {char_to_num.vocabulary_size()})"
)

这段代码的主要功能是为自动语音识别(ASR)系统或任何需要文本转录的任务准备字符编码映射。具体来说,它包括以下几个步骤:

  1. 定义字符集:创建一个包含所有接受字符的列表,这些字符将用于文本转录。在这个例子中,字符集包括小写英文字母、一些标点符号以及空格。

  2. 字符到数字的映射:使用keras.layers.StringLookup层来将字符映射到整数。这个映射过程对于将文本数据转换为模型可以处理的数值型数据是必要的。vocabulary参数指定了所有可能的字符,oov_token参数定义了一个特殊的标记,用于表示词汇表之外的字符(在这个例子中没有使用OOV标记,所以设置为空字符串)。

  3. 数字到字符的映射:再次使用keras.layers.StringLookup层,但是这次设置invert=True,创建一个从整数映射回原始字符的反向映射。这样,模型的输出可以被转换回原始文本形式。

  4. 打印词汇表和大小:打印出创建的词汇表和它的大小。词汇表大小是指映射中包含的唯一字符数量。

2.3.2.定义转化函数

在文本到语音(TTS)或自然语言处理(NLP)等任务中,预处理是一个关键步骤,它涉及到将原始数据转换为模型可以处理的格式。一旦我们准备好了词汇表(例如,在TTS中,这可能是字符或音素列表),下一步通常是定义一个函数来执行数据集的预处理和转换。

例如,在TTS中,预处理可能包括以下几个步骤:

  1. 文本清洗:删除不必要的字符、标点符号、特殊符号等。
  2. 文本归一化:将所有文本转换为小写(如果需要的话),并执行其他任何必要的文本转换。
  3. 词汇映射:使用前面准备好的词汇表,将文本转换为模型可以理解的格式,例如将字符或单词映射到唯一的索引或编码。
  4. 数据增强(可选):根据任务需求,可能还需要对数据进行增强,例如添加噪声、改变语速等。

为了实现这些转换,我们可以编写一个预处理函数,该函数接受原始文本作为输入,并返回经过转换的数据。这个函数将使用我们前面准备的词汇表和其他任何必要的转换逻辑。

# 定义音频处理的参数
# 帧长度,即STFT窗口中的样本数
frame_length = 256
# 帧步长,连续STFT窗口之间的样本数
frame_step = 160
# FFT大小,如果不指定,则使用大于等于frame_length的最小2的幂
fft_length = 384

def encode_single_sample(wav_file, label):
    ###########################################
    ## 音频处理
    ###########################################
    # 1. 读取wav文件
    file = tf.io.read_file(wavs_path + wav_file + ".wav")
    # 2. 解码wav文件,获取音频数据和采样率
    audio, _ = tf.audio.decode_wav(file)
    # 3. 去除音频数据的维度,只保留一维数据
    audio = tf.squeeze(audio, axis=-1)
    # 4. 将音频数据类型转换为浮点数
    audio = tf.cast(audio, tf.float32)
    # 5. 计算频谱图,使用STFT
    spectrogram = tf.signal.stft(
        audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length
    )
    # 6. 取频谱图的幅度,即实部和虚部的模
    spectrogram = tf.abs(spectrogram)
    # 7. 对频谱图进行开平方操作
    spectrogram = tf.math.pow(spectrogram, 0.5)
    # 8. 对频谱图进行归一化处理
    means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
    stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
    spectrogram = (spectrogram - means) / (stddevs + 1e-10)

    ###########################################
    ## 文本标签处理
    ###########################################
    # 9. 将标签转换为小写
    label = tf.strings.lower(label)
    # 10. 将标签字符串分割成字符序列
    label = tf.strings.unicode_split(label, input_encoding="UTF-8")
    # 11. 将字符序列映射到整数序列
    label = char_to_num(label)

    # 返回包含频谱图和标签映射的字典
    return {"spectrogram": spectrogram, "label": label}

这段代码定义了一个函数encode_single_sample,它用于处理单个音频样本并将其转换为模型训练所需的格式。以下是代码的主要功能步骤:

  1. 定义帧长度、帧步长和FFT大小
    - frame_length:窗口长度,用于STFT(短时傅里叶变换)的样本数。
    - frame_step:帧步长,连续STFT窗口之间的样本数。
    - fft_length:应用FFT的大小,如果没有提供,则使用比frame_length大的最小2的幂。

  2. 读取和解码音频文件
    - 读取指定路径下的WAV文件。
    - 使用tf.audio.decode_wav解码WAV文件,获取音频数据和采样率。

  3. 转换音频数据类型

    • 将音频数据从原始类型转换为tf.float32,以便进行数学运算。
  4. 计算频谱图
    - 使用tf.signal.stft计算音频的短时傅里叶变换,得到频谱图。

  5. 提取频谱图的幅度信息
    - 通过取绝对值和开平方来获取频谱图的幅度信息。

  6. 归一化处理
    - 对频谱图进行归一化,使用均值和标准差进行标准化。

  7. 处理标签
    - 将标签转换为小写,以保证一致性。
    - 使用tf.strings.unicode_split将标签字符串分割成字符序列。

  8. 字符到数字的映射
    - 使用之前定义的char_to_num将每个字符映射到相应的整数。

  9. 返回结果
    - 函数返回一个包含频谱图和标签映射的字典,这是模型所期望的输入格式。

整体来看,这个函数是将原始音频文件和对应的文本标签转换为模型训练所需的数值型数据。通过STFT转换音频信号为频谱图,并对其进行归一化处理,同时将文本标签转换为数字序列,为模型训练做好准备。
创建数据集对象

我们创建一个tf.data.Dataset对象,该对象按照输入中出现的顺序生成转换后的元素。

在TensorFlow中,tf.data.Dataset API 提供了一种高效、灵活的方式来构建输入管道,用于机器学习模型。使用tf.data.Dataset,你可以从各种数据源(如文件、内存中的数据结构等)读取数据,应用各种转换(如映射、批处理、打乱等),并将数据以合适的形式提供给模型进行训练、评估或预测。

当你有一个原始数据集(例如,一个包含文本文件或音频文件的列表),并且你已经定义了如何将每个元素转换为模型可以理解的格式(通过前面提到的预处理函数),你就可以使用tf.data.Dataset来创建一个可以迭代的数据集对象。

# 定义批量大小
batch_size = 32

# 定义训练数据集
# 从数据框df_train中提取文件名和标准化转录文本的列表
train_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_train["file_name"]), list(df_train["normalized_transcription"]))
)

# 使用map函数和encode_single_sample函数处理每个样本
# num_parallel_calls设置为AUTOTUNE以自动调整并行调用数
# 使用padded_batch函数将样本打包成批量,自动填充不足批量大小的样本
# prefetch函数用于优化性能,提前从数据集中提取数据
train_dataset = (
    train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# 定义验证数据集
# 从数据框df_val中提取文件名和标准化转录文本的列表
validation_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_val["file_name"]), list(df_val["normalized_transcription"]))
)

# 与训练数据集类似,使用map函数和encode_single_sample函数处理每个样本
# 然后使用padded_batch和prefetch函数进行批处理和预提取
validation_dataset = (
    validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

这段代码主要完成了以下任务:

  1. 设置批量大小batch_size为32,这是模型训练时每个批次的样本数量。

  2. 定义训练数据集train_dataset
    - 使用tf.data.Dataset.from_tensor_slices从训练数据框df_train中提取文件名和转录文本的列表,创建数据集。
    - 使用map函数和encode_single_sample函数处理数据集中的每个样本,num_parallel_calls=tf.data.AUTOTUNE允许TensorFlow自动调整并行处理的数量以优化性能。
    - 使用padded_batch函数将处理后的样本打包成批量,如果最后一个批次的样本数不足batch_size,会自动进行填充。
    - 使用prefetch函数提前从数据集中提取数据,进一步提高训练时的数据加载效率。

  3. 定义验证数据集validation_dataset
    - 类似于训练数据集的创建过程,从验证数据框df_val中提取文件名和转录文本的列表,创建数据集。
    - 同样使用mappadded_batchprefetch函数处理和优化数据加载。

通过这种方式,训练和验证数据集被准备好,可以用于模型的训练和评估。

2.2.3 数据可视化

在数据集中可视化一个示例,包括音频片段、频谱图和对应的标签,通常涉及多个步骤。由于音频数据和标签(如转录文本)是不同的数据类型,因此需要使用不同的可视化技术。以下是一个概述,说明如何实现这种可视化:

  1. 加载音频数据:首先从数据集中加载一个音频文件。这通常可以通过Python的音频处理库(如librosasoundfile)来完成。
  2. 绘制音频波形图:使用matplotlib等库,你可以绘制音频的波形图,显示音频信号的振幅随时间的变化。
  3. 计算频谱图(Spectrogram):频谱图是一种显示音频信号频率内容的可视化方式。你可以使用librosa库来计算音频的短时傅里叶变换(STFT),然后绘制频谱图。
  4. 显示标签(文本转录):标签(通常是文本转录)可以简单地以文本形式显示,或者如果你想要更高级的可视化,可以考虑使用文本到语音(TTS)引擎将文本转录转换回语音,但这样做通常比较复杂且不是直接的可视化。

这段代码的主要功能是展示训练数据集中的第一个样本的频谱图和对应的音频波形。

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from IPython.display import display, Audio

# 创建一个大小为8x5英寸的图形
fig = plt.figure(figsize=(8, 5))

# 从训练数据集中取出一个批次的样本
for batch in train_dataset.take(1):
    # 获取频谱图数据,并转换为numpy数组
    spectrogram = batch[0][0].numpy()
    # 转置频谱图,去除全0行,并重新转换为numpy数组
    spectrogram = np.array([np.trim_zeros(x) for x in np.transpose(spectrogram)])
    # 获取标签数据,并转换为字符串
    label = batch[1][0]
    label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
    
    # 绘制频谱图
    ax = plt.subplot(2, 1, 1)  # 创建2行1列的子图布局,并获取第1个子图
    ax.imshow(spectrogram, vmax=1)  # 绘制频谱图,设置最大值为1
    ax.set_title(label)  # 设置子图标题为标签内容
    ax.axis("off")  # 关闭坐标轴显示

    # 读取音频文件并绘制音频波形
    file = tf.io.read_file(wavs_path + list(df_train["file_name"])[0] + ".wav")
    audio, _ = tf.audio.decode_wav(file)  # 解码wav文件,获取音频数据
    audio = audio.numpy()  # 将音频数据转换为numpy数组
    
    ax = plt.subplot(2, 1, 2)  # 获取第2个子图
    plt.plot(audio)  # 绘制音频波形
    ax.set_title("Signal Wave")  # 设置子图标题
    ax.set_xlim(0, len(audio))  # 设置x轴显示范围为音频数据的长度
    
    # 播放音频
    display.display(Audio(np.transpose(audio), rate=16000))

# 显示图形
plt.show()

代码的主要步骤包括:

  1. 创建一个图形,并设置其大小。
  2. 从训练数据集中取出一个批次的样本。
  3. 获取并处理频谱图数据,包括转置、去除全0行等。
  4. 获取标签数据,并将其转换为可读的字符串。
  5. 绘制频谱图,并设置标题和关闭坐标轴。
  6. 读取音频文件,解码并获取音频数据。
  7. 绘制音频波形,并设置标题和x轴显示范围。
  8. 使用IPython.display.Audio播放音频。
  9. 显示图形。

2.4 建立模型

2.4.1.定义CTC模型损失函数

下面的代码定义了一个函数CTCLoss,用于计算连接时序分类(CTC)损失值。CTC损失是用于序列预测问题,特别是像语音识别这样的任务中,其中输入是一个连续的信号,而输出是一个离散的序列。

import tensorflow as tf
from keras import backend as K

def CTCLoss(y_true, y_pred):
    # 计算训练时的损失值
    # 获取批次中样本的数量
    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
    # 获取预测序列的长度
    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    # 获取真实标签序列的长度
    label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

    # 将输入长度和标签长度扩展到批次中每个样本的长度
    input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
    label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

    # 计算CTC损失
    # y_true是真实标签的one-hot编码形式
    # y_pred是模型预测的logits
    # input_length和label_length分别是预测序列和真实标签序列的长度
    loss = K.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    
    return loss

函数CTCLoss接收两个参数:

  • y_true:真实标签的one-hot编码形式,形状为[batch_size, label_length]
  • y_pred:模型预测的logits,形状为[batch_size, input_length, num_classes]

函数的主要步骤:

  1. 计算批次大小batch_len,即批次中样本的数量。
  2. 获取模型预测序列的长度input_length和真实标签序列的长度label_length
  3. input_lengthlabel_length扩展到批次中每个样本的长度,以便与y_truey_pred对齐。
  4. 使用keras.backend.ctc_batch_cost函数计算CTC损失。这个函数接受真实标签、预测logits、输入长度和标签长度作为参数,并返回损失值。

CTC损失是一个重要的指标,特别是在处理像语音识别这样的序列预测问题时,它允许模型学习如何将连续的输入信号映射到正确的序列输出,即使在训练数据中存在噪声或不准确的对齐。

2.4.2. 定义模型

下面的代码定义了一个构建深度学习模型的函数build_model,并展示了如何使用这个函数来创建一个类似于DeepSpeech2的模型。然后,它使用指定的参数调用这个函数,并打印出模型的概述。

import tensorflow as tf
from tensorflow.keras import layers, Model

def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
    """构建类似于DeepSpeech2的模型。"""
    # 模型输入
    input_spectrogram = layers.Input((None, input_dim), name="input")
    # 扩展维度以使用2D卷积
    x = layers.Reshape((-1, input_dim, 1), name="expand_dim")(input_spectrogram)
    # 卷积层1
    x = layers.Conv2D(
        filters=32,
        kernel_size=[11, 41],
        strides=[2, 2],
        padding="same",
        use_bias=False,
        name="conv_1",
    )(x)
    x = layers.BatchNormalization(name="conv_1_bn")(x)
    x = layers.ReLU(name="conv_1_relu")(x)
    # 卷积层2
    x = layers.Conv2D(
        filters=32,
        kernel_size=[11, 21],
        strides=[1, 2],
        padding="same",
        use_bias=False,
        name="conv_2",
    )(x)
    x = layers.BatchNormalization(name="conv_2_bn")(x)
    x = layers.ReLU(name="conv_2_relu")(x)
    # 重塑结果体积以馈送到RNN层
    x = layers.Reshape((-1, x.shape[-2] * x.shape[-1]))(x)
    # RNN层
    for i in range(1, rnn_layers + 1):
        recurrent = layers.GRU(
            units=rnn_units,
            activation="tanh",
            recurrent_activation="sigmoid",
            use_bias=True,
            return_sequences=True,
            reset_after=True,
            name=f"gru_{i}",
        )
        x = layers.Bidirectional(
            recurrent, name=f"bidirectional_{i}", merge_mode="concat"
        )(x)
        if i < rnn_layers:
            x = layers.Dropout(rate=0.5)(x)
    # 密集连接层
    x = layers.Dense(units=rnn_units * 2, name="dense_1")(x)
    x = layers.ReLU(name="dense_1_relu")(x)
    x = layers.Dropout(rate=0.5)(x)
    # 分类层
    output = layers.Dense(units=output_dim + 1, activation="softmax")(x)
    # 模型
    model = Model(input_spectrogram, output, name="DeepSpeech_2")
    # 优化器
    opt = keras.optimizers.Adam(learning_rate=1e-4)
    # 编译模型并返回
    model.compile(optimizer=opt, loss=CTCLoss)
    return model

# 获取模型
model = build_model(
    input_dim=fft_length // 2 + 1,  # 输入维度
    output_dim=char_to_num.vocabulary_size(),  # 输出维度,即词汇表大小
    rnn_units=512,  # RNN单元数
)
# 打印模型概述
model.summary(line_length=110)

函数build_model的主要步骤:

  1. 定义模型输入input_spectrogram,其形状为(None, input_dim),其中None表示可变的时间步长。
  2. 使用Reshape层扩展输入维度,以适应2D卷积。
  3. 添加两个卷积层(Conv2D),每个卷积层后面跟着批量归一化(BatchNormalization)和ReLU激活函数。
  4. 使用Reshape层将卷积层的输出重塑为适合RNN层的形状。
  5. 循环创建多个双向GRU层(Bidirectional包装GRU),并在RNN层之间添加dropout层以防止过拟合。
  6. 添加一个密集连接层(Dense),后面跟一个ReLU激活函数和dropout层。
  7. 添加一个分类层,使用softmax激活函数,输出维度为output_dim + 1(包括一个用于CTC损失的空白符号)。
  8. 创建并编译Keras模型,使用Adam优化器和之前定义的CTC损失函数。

最后,使用指定的参数调用build_model函数创建模型,并使用model.summary打印模型的层和参数的详细概述。

2.5. 训练和评估模型

2.5.1.定义解码算法和回调函数
import numpy as np

def decode_batch_predictions(pred):
    # 输入长度,假设每个样本的输入长度等于pred的第二维度
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    
    # 使用贪婪搜索解码CTC输出,对于复杂任务可以使用束搜索
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
    
    # 迭代结果,将解码后的序列转换回文本
    output_text = []
    for result in results:
        # 将解码结果转换为字符串
        result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
        output_text.append(result)
    
    return output_text
from keras import backend as K

class CallbackEval(keras.callbacks.Callback):
    """每个epoch结束后显示一批输出的回调类。"""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def on_epoch_end(self, epoch: int, logs=None):
        predictions = []
        targets = []
        # 迭代数据集中的每个批次
        for batch in self.dataset:
            X, y = batch
            # 预测批次数据
            batch_predictions = model.predict(X)
            # 解码批次预测
            batch_predictions = decode_batch_predictions(batch_predictions)
            predictions.extend(batch_predictions)
            for label in y:
                # 将标签转换为字符串
                label = (
                    tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
                )
                targets.append(label)
        
        # 计算词错误率(WER)
        wer_score = wer(targets, predictions)
        print("-" * 100)
        print(f"词错误率(Word Error Rate): {wer_score:.4f}")
        print("-" * 100)
        # 随机选择两个样本展示目标和预测结果
        for i in np.random.randint(0, len(predictions), 2):
            print(f"目标文本    : {targets[i]}")
            print(f"预测文本    : {predictions[i]}")
            print("-" * 100)
  1. 函数 decode_batch_predictions
    这个函数接受模型预测的输出pred,使用CTC解码算法(这里使用了贪婪搜索)来将预测的字符索引解码回文本字符串,并返回一个包含所有解码文本的列表。
  2. 回调类 CallbackEval
    这个回调类在每个epoch结束后被调用,它使用decode_batch_predictions函数解码模型的预测输出,并将解码后的文本与实际的目标文本进行比较。然后,它计算并打印整个数据集的词错误率(WER),并随机展示几个样本的目标文本和预测文本进行对比。
2.5.2.训练模型

下面的代码展示了如何使用Keras模型进行训练,并在每个epoch结束后使用自定义的回调函数来评估验证数据集上的转录性能。

# 定义训练的轮数(epoch数)
epochs = 1

# 创建CallbackEval回调类的一个实例,传入验证数据集
validation_callback = CallbackEval(validation_dataset)

# 训练模型
# 使用fit方法训练模型,传入训练数据集和验证数据集
# epochs参数指定训练的轮数
# callbacks参数传入自定义的回调函数列表,这里只有一个回调函数
history = model.fit(
    train_dataset,               # 训练数据集
    validation_data=validation_dataset,  # 验证数据集
    epochs=epochs,                # 训练的轮数
    callbacks=[validation_callback],     # 回调函数列表
)

代码的主要步骤包括:

  1. 设置训练的epoch数,这里设置为1,表示模型将训练1轮。
  2. 创建CallbackEval类的实例validation_callback,并将验证数据集validation_dataset作为参数传入。这个回调类将在每个epoch结束后被调用,以输出验证数据集上的转录结果和词错误率(WER)。
  3. 使用模型的fit方法开始训练过程。传入训练数据集train_dataset和验证数据集validation_datasetepochs参数指定了模型训练的轮数。callbacks参数传入了一个包含validation_callback的列表,这样在每个epoch结束后,CallbackEval类的on_epoch_end方法会被调用。

通过这种方式,你可以在训练过程中监控模型在验证数据集上的性能,及时发现过拟合等问题,并根据需要调整模型结构或训练参数。

2.6. 推理预测

下面的代码是在模型训练结束后,用于评估模型在验证数据集上的性能,特别是通过计算词错误率(WER)来衡量。

import numpy as np
# 假设已经定义了wer函数,用于计算词错误率

# 初始化两个列表,用于存储模型的预测结果和实际的目标文本
predictions = []
targets = []

# 遍历验证数据集中的每个批次
for batch in validation_dataset:
    X, y = batch  # X是特征数据,y是标签数据
    # 使用模型对当前批次的数据进行预测
    batch_predictions = model.predict(X)
    # 解码CTC输出,将预测的字符索引转换为文本字符串
    batch_predictions = decode_batch_predictions(batch_predictions)
    # 将解码后的预测结果添加到predictions列表中
    predictions.extend(batch_predictions)
    # 遍历批次中的每个标签,解码并将其添加到targets列表中
    for label in y:
        label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
        targets.append(label)

# 计算所有预测结果和目标文本之间的词错误率
wer_score = wer(targets, predictions)

# 打印结果
print("-" * 100)
print(f"词错误率(Word Error Rate): {wer_score:.4f}")
print("-" * 100)

# 随机选择几个样本,打印目标文本和模型的预测文本
for i in np.random.randint(0, len(predictions), 5):
    print(f"目标文本    : {targets[i]}")
    print(f"预测文本    : {predictions[i]}")
    print("-" * 100)

代码的主要步骤包括:

  1. 初始化两个空列表predictionstargets,用于存储模型的预测文本和验证数据集的实际文本。
  2. 遍历验证数据集validation_dataset中的所有批次,使用模型对每个批次的数据进行预测,并将预测结果解码为文本字符串。
  3. 将解码后的预测文本添加到predictions列表中,同时将每个批次的目标文本解码并添加到targets列表中。
  4. 使用wer函数计算predictionstargets之间的词错误率wer_score
  5. 打印出词错误率的结果,并随机选择几个样本展示目标文本和模型的预测文本进行对比。

3.总结和展望

3.1.总结

本文详细介绍了语音识别技术的研究背景、挑战、CTC模型、研究内容以及实现过程。语音识别作为人机交互的重要方式,在提高工作效率、实现无障碍访问、推动多语言交流等方面具有重要意义。同时,它也面临音频信号与文本转录对齐等挑战。CTC算法通过提供所有可能输出序列的分布,有效解决了这些问题。

研究内容包括使用2D CNN、RNN和CTC损失函数构建高效的ASR系统,以及使用LJSpeech数据集训练模型。评估模型性能时,采用了WER作为主要评价指标,并通过jiwer库计算。

实现过程中,涉及软件包安装、数据集加载与预处理、模型构建、训练与评估以及推理预测。在数据预处理阶段,定义了词汇表并准备了字符到数字的映射。模型构建采用了类似于DeepSpeech2的架构,包括卷积层、Bidirectional GRU层和Dense层。训练过程中使用了自定义的回调函数来输出转录结果和WER。

3.2.展望

语音识别技术的未来研究可以从以下几个方面展开:

  1. 模型优化:进一步优化模型结构,提高识别准确率和鲁棒性。
  2. 数据增强:探索更多数据增强方法,提高模型对不同口音和噪声的适应能力。
  3. 多语言支持:扩展模型以支持更多语言,促进跨文化交流。
  4. 实时识别:优化模型以实现实时语音识别,满足更多实际应用需求。
  5. 端到端学习:研究端到端的语音识别模型,减少预处理步骤,提高效率。
  6. 低资源语言:针对资源匮乏的语言,研究如何利用少量数据训练有效模型。
  7. 可解释性:提高模型的可解释性,帮助理解模型的决策过程。

随着技术的不断发展,语音识别将在更多领域发挥作用,为人类社会带来更多便利。

参考文献

Mohamed Reda Bouadjenek and Ngoc Dung Huynh. (2024-6-13). CTC Automatic Speech Recognition (ASR) Example. 检索自:https://keras.io/examples/audio/ctc_asr/

附录1 示例代码

# 导入所需的库
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from IPython import display
from jiwer import wer

# 下载并加载LJSpeech数据集
data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2" 
data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
wavs_path = data_path + "/wavs/"  # 音频文件路径
metadata_path = data_path + "/metadata.csv"  # 元数据文件路径

# 读取元数据文件并解析
metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
metadata_df.columns = ["file_name", "transcription", "normalized_transcription"]
# 只保留文件名和规范化转录
metadata_df = metadata_df[["file_name", "normalized_transcription"]]
# 随机排序数据
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
# 显示前3行数据
print(metadata_df.head(3))

# 划分训练集和验证集
split = int(len(metadata_df) * 0.90)
df_train = metadata_df[:split]
df_val = metadata_df[split:]
print(f"训练集大小: {len(df_train)}")
print(f"验证集大小: {len(df_val)}")

# 准备词汇表
characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "]
char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="")
num_to_char = keras.layers.StringLookup(vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True)
print(f"词汇表是: {char_to_num.get_vocabulary()} (大小 = {char_to_num.vocabulary_size()})")

# 定义处理单个样本的函数
def encode_single_sample(wav_file, label):
    # 读取WAV文件并解码
    file = tf.io.read_file(wavs_path + wav_file + ".wav")
    audio, _ = tf.audio.decode_wav(file)
    audio = tf.squeeze(audio, axis=-1)
    audio = tf.cast(audio, tf.float32)
    
    # 计算频谱图
    spectrogram = tf.signal.stft(
        audio, frame_length=256, frame_step=160, fft_length=384
    )
    spectrogram = tf.abs(spectrogram)
    spectrogram = tf.math.pow(spectrogram, 0.5)
    
    # 归一化处理
    means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
    stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
    spectrogram = (spectrogram - means) / (stddevs + 1e-10)
    
    # 处理标签
    label = tf.strings.lower(label)
    label = tf.strings.unicode_split(label, input_encoding="UTF-8")
    label = char_to_num(label)
    
    return spectrogram, label

# 创建数据集对象
batch_size = 32
train_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_train["file_name"]), list(df_train["normalized_transcription"]))
).map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.padded_batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

validation_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_val["file_name"]), list(df_val["normalized_transcription"]))
).map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = validation_dataset.padded_batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

# 数据可视化
fig = plt.figure(figsize=(8, 5))
for batch in train_dataset.take(1):
    spectrogram = batch[0][0].numpy()
    label = tf.strings.reduce_join(num_to_char(batch[1][0])).numpy().decode("utf-8")
    plt.subplot(2, 1, 1)
    plt.imshow(spectrogram, vmax=1)
    plt.title(label)
    plt.axis("off")
    file = tf.io.read_file(wavs_path + list(df_train["file_name"])[0] + ".wav")
    audio, _ = tf.audio.decode_wav(file)
    plt.subplot(2, 1, 2)
    plt.plot(audio.numpy())
    plt.title("Signal Wave")
    plt.xlim(0, len(audio))
    display.display(Audio(audio.numpy(), rate=16000))
plt.show()

# 定义CTC损失函数
def CTCLoss(y_true, y_pred):
    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")
    input_length = input_length * tf.ones((batch_len, 1), dtype="int64")
    label_length = label_length * tf.ones((batch_len, 1), dtype="int64")
    loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    return loss

# 定义DeepSpeech2模型
def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
    input_spectrogram = layers.Input((None, input_dim))
    x = layers.Reshape((-1, input_dim, 1))(input_spectrogram)
    x = layers.Conv2D(filters=32, kernel_size=[11, 41], strides=[2, 2], padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters=32, kernel_size=[11, 21], strides=[1, 2], padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Reshape((-1, x.shape[-2] * x.shape[-1]))(x)
    for i in range(rnn_layers):
        recurrent = layers.GRU(
            units=rnn_units,
            activation="tanh",
            recurrent_activation="sigmoid",
            use_bias=True,
            return_sequences=True,
            reset_after=True,
            name=f"gru_{i+1}"
        )
        x = layers.Bidirectional(recurrent, merge_mode="concat", name=f"bidirectional_{i+1}")(x)
        if i < rnn_layers - 1:
            x = layers.Dropout(rate=0.5)(x)
    x = layers.Dense(units=rnn_units * 2, name="dense_1")(x)
    x = layers.ReLU()(x)
    x = layers.Dropout(rate=0.5)(x)
    output = layers.Dense(units=output_dim + 1, activation="softmax")(x)
    model = keras.Model(input_spectrogram, output, name="DeepSpeech_2")
    opt = keras.optimizers.Adam(learning_rate=1e-4)
    model.compile(optimizer=opt, loss=CTCLoss)
    return model

# 获取模型
model = build_model(
    input_dim=384 // 2 + 1,
    output_dim=char_to_num.vocabulary_size(),
    rnn_units=512,
)
model.summary()

# 训练和评估模型
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
    output_text = []
    for result in results:
        result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
        output_text.append(result)
    return output_text

class CallbackEval(keras.callbacks.Callback):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
    def on_epoch_end(self, epoch, logs=None):
        predictions = []
        targets = []
        for batch in self.dataset:
            X, y = batch
            batch_predictions = model.predict(X)
            batch_predictions = decode_batch_predictions(batch_predictions)
            predictions.extend(batch_predictions)
            for label in y:
                label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
                targets.append(label)
        wer_score = wer(targets, predictions)
        print("-" * 100)
        print(f"词错误率(Word Error Rate): {wer_score:.4f}")
        print("-" * 100)
        for i in np.random.randint(0, len(predictions), 2):
            print(f"目标文本    : {targets[i]}")
            print(f"预测文本    : {predictions[i]}")
            print("-" * 100)

# 开始训练过程
epochs = 1
validation_callback = CallbackEval(validation_dataset)
history = model.fit(
    train_dataset,
    validation_data=validation_dataset    # 训练的轮数
    epochs=epochs,
    # 训练过程中的回调函数
    callbacks=[validation_callback],
)

# 推理预测
# 评估模型在验证数据集上的性能
predictions = []
targets = []
for batch in validation_dataset:
    X, y = batch
    batch_predictions = model.predict(X)
    batch_predictions = decode_batch_predictions(batch_predictions)
    predictions.extend(batch_predictions)
    for label in y:
        label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
        targets.append(label)
wer_score = wer(targets, predictions)
print("-" * 100)
print(f"词错误率(Word Error Rate): {wer_score:.4f}")
print("-" * 100)
# 随机选择几个样本展示目标和预测结果
for i in np.random.randint(0, len(predictions), 5):
    print(f"目标文本    : {targets[i]}")
    print(f"预测文本    : {predictions[i]}")
    print("-" * 100)

"""
## 结论

在实践中,你应该训练大约50轮或更多。使用 `GeForce RTX 2080 Ti` GPU 进行训练,每个epoch大约需要5-6分钟。
我们在50轮训练后的模型 `词错误率 (WER)` 大约在 `16%` 到 `17%` 之间。

以下是大约第50轮训练后的一些转录样本:

**音频文件:LJ017-0009.wav**
- 目标文本    : sir thomas overbury was undoubtedly poisoned by lord rochester in the reign of james the first
- 预测文本    : cer thomas overbery was undoubtedly poisoned by lordrochester in the reign of james the first

**音频文件:LJ003-0340.wav**
- 目标文本    : the committee does not seem to have yet understood that newgate could be only and properly replaced
- 预测文本    : the committee does not seem to have yet understood that newgate could be only and proberly replace

**音频文件:LJ011-0136.wav**
- 目标文本    : still no sentence of death was carried out for the offense and in eighteen thirtytwo
- 预测文本    : still no sentence of death was carried out for the offense and in eighteen thirtytwo
"""

最近更新

  1. TCP协议是安全的吗?

    2024-06-14 07:52:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-14 07:52:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-14 07:52:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-14 07:52:02       18 阅读

热门阅读

  1. postman接口测试工具详解

    2024-06-14 07:52:02       6 阅读
  2. React@16.x(27)useCallBack

    2024-06-14 07:52:02       8 阅读
  3. 深入理解服务器路由功能:配置与应用详解

    2024-06-14 07:52:02       4 阅读
  4. TCP是什么、UDP是什么,它们有什么区别

    2024-06-14 07:52:02       6 阅读
  5. WHAT - React 学习系列(一)

    2024-06-14 07:52:02       9 阅读
  6. .NET C# 实现国密算法加解密

    2024-06-14 07:52:02       3 阅读
  7. VB.net与C# 调用InitializeComponent的区别

    2024-06-14 07:52:02       6 阅读
  8. HarmonyOS(35) @State使用注意事项

    2024-06-14 07:52:02       5 阅读