昇思25天学习打卡营第21天|RNN实现情感分类

课程打卡凭证

RNN模型

RNN(Recurrent Neural Network,循环神经网络)是一种专门用于处理序列(sequence)数据的神经网络模型。与传统的神经网络不同,RNN在处理序列中的每个元素时,都会考虑之前元素的信息,即RNN具有“记忆”功能。这使得RNN非常适合处理如文本、语音、时间序列等具有时序特性的数据。

RNN的基本结构包括输入层、隐藏层和输出层,但与传统的神经网络不同的是,RNN的隐藏层之间也存在连接,这些连接允许信息从序列的一个时间步传递到下一个时间步。具体来说,RNN在每个时间步都会接收一个输入,并基于当前的输入和隐藏层的上一个状态(即记忆)来计算新的隐藏状态,然后基于新的隐藏状态产生输出。其基本结构如下图所示。

尽管RNN在处理序列数据上表现出了强大的能力,但它也存在一些问题,如梯度消失或梯度爆炸,这限制了它学习长期依赖关系的能力。为了克服这些问题,研究者们提出了多种RNN的变体,如LSTM(详见昇思25天学习打卡营第20天|LSTM+CRF序列标注-CSDN博客)。它通过引入门控机制(遗忘门、输入门、输出门)来控制信息的传递,有效地解决了梯度消失或梯度爆炸的问题,从而能够学习长期依赖关系。具体结构如下图所示。

因此,这里选择LSTM变种来进行特征提取。

训练过程

数据准备

导入必要的库与模块。

下载IMDB影评数据集,这是一个经典的情感分类数据集。

数据加载

定义一个IMDBData类,用于加载和处理IMDB数据集。它将IMDB数据集中的影评文本加载并处理为一个Python的迭代对象。

加载训练集。

加载IMDB数据集,并将其转化为MindSpore的GeneratorDataset对象,以便于后续的训练和测试。

划分训练集和测试集。

加载预训练词向量

加载GloVe词向量,并将其转化为适用于MindSpore的数据结构。

下载并加载GloVe词向量,并输出词汇表的大小。

使用加载的GloVe词汇表和嵌入矩阵获取单词 "the" 的索引和对应的嵌入向量。

数据集预处理

文本数据的操作,包括词汇表查找、填充、类型转换操作。

对训练集和测试集进行文本转换和标签转换。

在原有的训练集中进一步手动划分训练集和验证集。

将训练集和测试集分别打包成批次大小为64的批次,丢弃最后一个批次中不足64个样本的部分。

模型构建

定义一个简单的双向LSTM模型,它包括Embedding层、LSTM层、全连接层,可以将输入的文本数据经过嵌入层和双向LSTM层的处理后,通过全连接层输出最终的分类预测结果。

损失函数与优化器

设置模型的结构参数(隐藏状态维度、输出维度、LSTM层数、是否双向等)和训练参数(学习率、填充标记索引等),初始化一个包含嵌入层、双向LSTM层和全连接层的RNN模型,并定义用于二分类任务的损失函数和Adam优化器。

训练逻辑

forward_fn执行前向传播计算损失;grad_fn计算损失函数对模型参数的梯度;train_step执行一步训练,包括计算梯度和更新模型参数;train_one_epoch执行整个训练过程的一个周期,迭代训练数据集并显示训练进度。

评估指标和逻辑

计算二分类任务中每个批次的准确率。

定义评估函数evaluate,用于评估模型在测试数据集上的性能。

模型训练与保存

模型加载与测试

查看模型在测试集上的效果。

自定义输入测试

由于只进行了两轮训练,可以发现模型效果并不好。

相关推荐

  1. 25学习24|RNN实现情感分类

    2024-07-17 15:52:04       20 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-17 15:52:04       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-17 15:52:04       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-17 15:52:04       58 阅读
  4. Python语言-面向对象

    2024-07-17 15:52:04       69 阅读

热门阅读

  1. ES6基本语法(一)

    2024-07-17 15:52:04       22 阅读
  2. 100道ajax面试题、练习题

    2024-07-17 15:52:04       21 阅读
  3. Flask与Django框架比较

    2024-07-17 15:52:04       20 阅读
  4. MPNN消息传递神经网络

    2024-07-17 15:52:04       25 阅读
  5. C# —— (左移 右移 异或 与 或 )运算规则

    2024-07-17 15:52:04       20 阅读
  6. 知识加油站

    2024-07-17 15:52:04       20 阅读
  7. 鹈鹕优化算法(POA)及其Python和MATLAB实现

    2024-07-17 15:52:04       22 阅读
  8. 互联网开发工作现状的深度剖析

    2024-07-17 15:52:04       20 阅读
  9. 用c语言写一个贪吃蛇游戏

    2024-07-17 15:52:04       25 阅读