昇思训练营打卡第二十四天(LSTM+CRF序列标注)

LSTM(Long Short-Term Memory,长短时记忆网络)是一种特殊的循环神经网络(RNN),由Hochreiter和Schmidhuber在1997年提出。它旨在解决传统RNN在处理长距离依赖问题时出现的梯度消失和梯度爆炸问题。以下是LSTM的一些主要特点:

  1. 细胞状态(Cell State):LSTM的核心是细胞状态,它贯穿整个LSTM网络,使得信息可以在网络中长时间传递。

  2. 门结构(Gates):LSTM通过门结构来控制信息的流入、流出和遗忘。主要包括以下三个门:

    • 遗忘门(Forget Gate):决定从细胞状态中丢弃什么信息。
    • 输入门(Input Gate):决定哪些新的信息被存储在细胞状态中。
    • 输出门(Output Gate):决定从细胞状态中输出什么信息。
  3. 梯度消失和梯度爆炸问题的缓解:由于LSTM的门结构,它能够在长序列中保持稳定的梯度,从而有效地缓解梯度消失和梯度爆炸问题。

CRF(Conditional Random Field,条件随机场)是一种概率图模型,常用于自然语言处理(NLP)中的序列标注任务,如词性标注、命名实体识别(NER)等。CRF模型能够考虑上下文信息,对序列中的每个元素(如词语)进行标注,使得整个序列的标注结果尽可能合理。

以下是CRF的一些关键特点:

  1. 无向图:CRF是一种无向图模型,它假设序列中的每个元素与其相邻元素之间存在依赖关系。

  2. 条件概率:CRF模型定义了一个条件概率分布,它表示在给定输入序列的情况下,输出标签序列的概率。

  3. 特征函数:CRF模型使用特征函数来描述输入与输出之间的关系。特征函数可以是基于输入序列和输出标签的任意函数,例如当前词、前后词、词性等。

  4. 全局最优标注:CRF在预测时会考虑整个序列的信息,寻找全局最优的标签序列,而不是单独对每个元素进行最优标注。

  5. 训练与解码:CRF的训练通常使用最大似然估计,而解码(即预测)时通常使用维特比算法(Viterbi algorithm)来找到最有可能的标签序列。

CRF模型的结构通常包括以下两个部分:

  • 发射特征(Emission Features):这些特征与输入序列中的每个元素相关,描述了输入元素与其对应标签的关系。
  • 转移特征(Transition Features):这些特征描述了标签序列中相邻标签之间的关系。

CRF模型的优势在于它能够有效地利用上下文信息,并且能够通过特征函数灵活地定义输入与输出之间的关系。这使得CRF在处理序列标注问题时通常比其他模型(如基于规则的模型或简单的概率模型)表现得更好。

Score计算

def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags)
    # tags: (seq_length, batch_size)
    # mask: (seq_length, batch_size)

    seq_length, batch_size = tags.shape
    mask = mask.astype(emissions.dtype)

    # 将score设置为初始转移概率
    # shape: (batch_size,)
    score = start_trans[tags[0]]
    # score += 第一次发射概率
    # shape: (batch_size,)
    score += emissions[0, mnp.arange(batch_size), tags[0]]

    for i in range(1, seq_length):
        # 标签由i-1转移至i的转移概率(当mask == 1时有效)
        # shape: (batch_size,)
        score += trans[tags[i - 1], tags[i]] * mask[i]

        # 预测tags[i]的发射概率(当mask == 1时有效)
        # shape: (batch_size,)
        score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]

    # 结束转移
    # shape: (batch_size,)
    last_tags = tags[seq_ends, mnp.arange(batch_size)]
    # score += 结束转移概率
    # shape: (batch_size,)
    score += end_trans[last_tags]

    return score

Normalizer计算

def compute_normalizer(emissions, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags)
    # mask: (seq_length, batch_size)

    seq_length = emissions.shape[0]

    # 将score设置为初始转移概率,并加上第一次发射概率
    # shape: (batch_size, num_tags)
    score = start_trans + emissions[0]

    for i in range(1, seq_length):
        # 扩展score的维度用于总score的计算
        # shape: (batch_size, num_tags, 1)
        broadcast_score = score.expand_dims(2)

        # 扩展emission的维度用于总score的计算
        # shape: (batch_size, 1, num_tags)
        broadcast_emissions = emissions[i].expand_dims(1)

        # 根据公式(7),计算score_i
        # 此时broadcast_score是由第0个到当前Token所有可能路径
        # 对应score的log_sum_exp
        # shape: (batch_size, num_tags, num_tags)
        next_score = broadcast_score + trans + broadcast_emissions

        # 对score_i做log_sum_exp运算,用于下一个Token的score计算
        # shape: (batch_size, num_tags)
        next_score = ops.logsumexp(next_score, axis=1)

        # 当mask == 1时,score才会变化
        # shape: (batch_size, num_tags)
        score = mnp.where(mask[i].expand_dims(1), next_score, score)

    # 最后加结束转移概率
    # shape: (batch_size, num_tags)
    score += end_trans
    # 对所有可能的路径得分求log_sum_exp
    # shape: (batch_size,)
    return ops.logsumexp(score, axis=1)

Viterbi算法

def viterbi_decode(emissions, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags)
    # mask: (seq_length, batch_size)

    seq_length = mask.shape[0]

    score = start_trans + emissions[0]
    history = ()

    for i in range(1, seq_length):
        broadcast_score = score.expand_dims(2)
        broadcast_emission = emissions[i].expand_dims(1)
        next_score = broadcast_score + trans + broadcast_emission

        # 求当前Token对应score取值最大的标签,并保存
        indices = next_score.argmax(axis=1)
        history += (indices,)

        next_score = next_score.max(axis=1)
        score = mnp.where(mask[i].expand_dims(1), next_score, score)

    score += end_trans

    return score, history

def post_decode(score, history, seq_length):
    # 使用Score和History计算最佳预测序列
    batch_size = seq_length.shape[0]
    seq_ends = seq_length - 1
    # shape: (batch_size,)
    best_tags_list = []

    # 依次对一个Batch中每个样例进行解码
    for idx in range(batch_size):
        # 查找使最后一个Token对应的预测概率最大的标签,
        # 并将其添加至最佳预测序列存储的列表中
        best_last_tag = score[idx].argmax(axis=0)
        best_tags = [int(best_last_tag.asnumpy())]

        # 重复查找每个Token对应的预测概率最大的标签,加入列表
        for hist in reversed(history[:seq_ends[idx]]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(int(best_last_tag.asnumpy()))

        # 将逆序求解的序列标签重置为正序
        best_tags.reverse()
        best_tags_list.append(best_tags)

    return best_tags_list

CRF层

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform

def sequence_mask(seq_length, max_length, batch_first=False):
    """根据序列实际长度和最大长度生成mask矩阵"""
    range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)
    result = range_vector < seq_length.view(seq_length.shape + (1,))
    if batch_first:
        return result.astype(ms.int64)
    return result.astype(ms.int64).swapaxes(0, 1)

class CRF(nn.Cell):
    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.reduction = reduction
        self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')
        self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')
        self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')

    def construct(self, emissions, tags=None, seq_length=None):
        if tags is None:
            return self._decode(emissions, seq_length)
        return self._forward(emissions, tags, seq_length)

    def _forward(self, emissions, tags=None, seq_length=None):
        if self.batch_first:
            batch_size, max_length = tags.shape
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        else:
            max_length, batch_size = tags.shape

        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)

        mask = sequence_mask(seq_length, max_length)

        # shape: (batch_size,)
        numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)
        # shape: (batch_size,)
        denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
        # shape: (batch_size,)
        llh = denominator - numerator

        if self.reduction == 'none':
            return llh
        if self.reduction == 'sum':
            return llh.sum()
        if self.reduction == 'mean':
            return llh.mean()
        return llh.sum() / mask.astype(emissions.dtype).sum()

    def _decode(self, emissions, seq_length=None):
        if self.batch_first:
            batch_size, max_length = emissions.shape[:2]
            emissions = emissions.swapaxes(0, 1)
        else:
            batch_size, max_length = emissions.shape[:2]

        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)

        mask = sequence_mask(seq_length, max_length)

        return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)

BiLSTM+CRF模型

class BiLSTM_CRF(nn.Cell):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')
        self.crf = CRF(num_tags, batch_first=True)

    def construct(self, inputs, seq_length, tags=None):
        embeds = self.embedding(inputs)
        outputs, _ = self.lstm(embeds, seq_length=seq_length)
        feats = self.hidden2tag(outputs)

        crf_outs = self.crf(feats, tags, seq_length)
        return crf_outs
embedding_dim = 16
hidden_dim = 32

training_data = [(
    "清 华 大 学 ".split(),
    "B I I I ".split()
), (
    "重 庆 ".split(),
    "B I ".split()
)]

word_to_idx = {}
word_to_idx['<pad>'] = 0
for sentence, tags in training_data:
    for word in sentence:
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)

tag_to_idx = {"B": 0, "I": 1, "O": 2}

len(word_to_idx)
model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)

def train_step(data, seq_length, label):
    loss, grads = grad_fn(data, seq_length, label)
    optimizer(grads)
    return loss
def prepare_sequence(seqs, word_to_idx, tag_to_idx):
    seq_outputs, label_outputs, seq_length = [], [], []
    max_len = max([len(i[0]) for i in seqs])

    for seq, tag in seqs:
        seq_length.append(len(seq))
        idxs = [word_to_idx[w] for w in seq]
        labels = [tag_to_idx[t] for t in tag]
        idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])
        labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])
        seq_outputs.append(idxs)
        label_outputs.append(labels)

    return ms.Tensor(seq_outputs, ms.int64), \
            ms.Tensor(label_outputs, ms.int64), \
            ms.Tensor(seq_length, ms.int64)
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)
data.shape, label.shape, seq_length.shape
from tqdm import tqdm

steps = 500
with tqdm(total=steps) as t:
    for i in range(steps):
        loss = train_step(data, seq_length, label)
        t.set_postfix(loss=loss)
        t.update(1)
score, history = model(data, seq_length)
score
predict = post_decode(score, history, seq_length)
predict
idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}

def sequence_to_tag(sequences, idx_to_tag):
    outputs = []
    for seq in sequences:
        outputs.append([idx_to_tag[i] for i in seq])
    return outputs
sequence_to_tag(predict, idx_to_tag)

相关推荐

  1. 训练第二(LSTM+CRF序列标注

    2024-07-13 19:52:02       17 阅读
  2. 训练第二(RNN实现情感分类)

    2024-07-13 19:52:02       17 阅读
  3. 25学习第18 | LSTM+CRF序列标注

    2024-07-13 19:52:02       19 阅读
  4. 25学习第23|LSTM+CRF序列标注

    2024-07-13 19:52:02       22 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-13 19:52:02       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-13 19:52:02       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-13 19:52:02       58 阅读
  4. Python语言-面向对象

    2024-07-13 19:52:02       69 阅读

热门阅读

  1. Nginx 日志统计分析命令

    2024-07-13 19:52:02       21 阅读
  2. 天童美语:放假给孩子看什么地理纪录片

    2024-07-13 19:52:02       17 阅读
  3. Perl 语言开发(十三):网络编程

    2024-07-13 19:52:02       22 阅读
  4. 块设备驱动实现--模拟一个块设备

    2024-07-13 19:52:02       16 阅读
  5. Docker

    2024-07-13 19:52:02       15 阅读
  6. docker

    2024-07-13 19:52:02       20 阅读
  7. qint64 pendingDatagramSize() const;

    2024-07-13 19:52:02       20 阅读
  8. ThreadLocal有哪些应用场景?底层如何实现?

    2024-07-13 19:52:02       21 阅读
  9. IPython:提升Python编程效率的实用技巧与案例

    2024-07-13 19:52:02       19 阅读
  10. 赋值运算符.二

    2024-07-13 19:52:02       18 阅读
  11. 数据结构第25节 深度优先搜索

    2024-07-13 19:52:02       16 阅读
  12. Python面试题:如何在 Python 中发送 HTTP 请求?

    2024-07-13 19:52:02       18 阅读
  13. ThreadLocal使用的场景有哪些?

    2024-07-13 19:52:02       18 阅读