传知代码-事件因果提取论文复现(论文复现)

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

1. 论文概述

本文对论文进行复现:Event Causality Extraction with Event Argument Correlations
事件因果识别(ECI)旨在检测两个给定文本事件之间是否存在因果关系,这对于理解事件因果关系至关重要。然而,ECI任务忽略了关键的事件结构和因果关系组件信息,导致在下游应用中存在困难。因此,论文提出了一种新颖的任务,名为事件因果提取(ECE),旨在从纯文本中提取因果事件对及其结构化事件信息。ECE任务更具挑战性,因为每个事件可能包含多个事件参数,需要考虑事件之间的细粒度关联来确定因果事件对。因此,论文提出了一种采用双网格标注方案的方法,以捕捉ECE的事件内部和事件间参数之间的关联。此外,他们设计了一种事件类型增强的模型架构,以实现双网格标注方案。实验证明了该方法的有效性,并进行了广泛的分析,指出了ECE的若干未来研究方向。
本次代码复现在原有代码的基础上已经下载好bert模型参数,直接训练就好

2. 论文方法

2.1 任务形式化

事件因果提取(ECE)旨在从纯文本中推导出因果事件对。在这里,一个因果事件对包含一个因果组件和一个结果组件,每个组件表示具有特定事件类型及其事件参数和事件角色的事件。给定一段文本,事件因果提取系统需要预测出其中所有的因果事件对,如图1所示.
在这里插入图片描述

2.2 模型

在这里插入图片描述

2.2.1 编码层

该层派生了句子中单词的上下文表示和事件类型。为了方便后续的事件参数预测,我们打算进行事件类型感知的编码,将文本表示与事件类型信息相结合。具体来说,我们将事件类型连接到句子前面,并使用BERT进行编码,这是因为它具有深度自注意力架构。输入序列的组织形式如下所示
在这里插入图片描述

2.2.2 网络表示层

每个网格中的每个条目分别模拟了一个标记与一个事件类型之间的关系,用于事件参数推导。对于连接第j个事件类型ej和句子中第i个标记的条目,其表示gji可以通过融合函数得到,通过整合ti和ej的语义来获得。直观地说,可以通过各种语义融合方式实现,包括连接或加法。考虑到相同的事件参数跨度在不同的事件类型中可能扮演不同的角色,因此事件参数的决策应该取决于事件类型。因此,应该表明事件类型和标记之间的条件依赖关系。因此,我们采用了条件层归一化(CLN)来实现。CLN主要基于层归一化,但是它根据先前的条件动态计算增益和偏差,而不是直接将它们作为可学习参数部署在神经网络中。给定事件类型表示ej作为条件和标记表示hi,通过CLN实现融合函数如下:
在这里插入图片描述

采用两个语义融合函数,c 和r,分别为因果和效果网格表派生条目表示。每个语义融合函数都由一层CLN实现,因此条目表示为:
在这里插入图片描述

2.2.3 训练和推断

由于每个表中的可以同时分配多个标记,我们对条目表示进行多标签分类。具体来说,一个全连接网络预测了每个标记的概率:
在这里插入图片描述

3.实验部分

3.1 数据集

在中国知识图谱和语义计算会议2021(CCKS2021)发布的语料库上进行实验。该语料库来自公共新闻和报道,包含7000个句子,平均长度为104个标记。它标注了15,816个事件,其中包含7908个因果事件对,涵盖了39种事件类型和3种事件角色,即产品、地区和行业。为了适应这个语料库的ECE任务,根据因果事件类型将其分成训练/验证/测试集。具体来说,CCKS2021被划分为训练/验证/测试集,比例为8:1:1。将拆分的数据集命名为ECE-CCKS。
在这里插入图片描述

3.2 实验步骤

step1:安装环境依赖

  • torch 1.7.1+cu110
  • transformers 4.5.1

step2:创建名为"log"的目录,并切换到名为"src"的目录中

mkdir log
cd ./src/

step3:训练

python train.py --task_name ece_task --training 1 --debug 0 --hidden_size 768
在这里插入图片描述

step4:推理

python train.py --task_name ece_task --training 0 --debug 0 --hidden_size 768 --model_name model_name

3.3 实验结果

在这里插入图片描述
在这里插入图片描述

4. 核心代码

# start
class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()

        self.hidden_size = args.hidden_size

        self.bert_embedding = BertModel.from_pretrained(args.bert_path)
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_path)
        
        # Grid Representation and Classification for the Cause Table
        self.gridmodel = GridModel(hidden_size=self.hidden_size, type_num=int(len(tt_map) // 2), dropout=args.dropout)
        # Grid Representation and Classification for the Effect Table
        self.gridmodel2 = GridModel(hidden_size=self.hidden_size, type_num=int(len(tt_map) // 2), dropout=args.dropout)
        self.dp = nn.Dropout( args.dropout  )
        
        self.thresh = args.thresh

    def encoding(self, input_ids, segment_ids, input_masks):

        out = self.bert_embedding(input_ids=input_ids, attention_mask=input_masks, token_type_ids=segment_ids)
        input_embs = out.last_hidden_state
        input_embs = self.dp( input_embs  )
        return input_embs

    def obtain_embs(self, input_embs_, label_indexs ):

        etype_embs = torch.index_select(input_embs_, dim = 1, index = label_indexs[0])
        input_embs = input_embs_[:, -max_seq_len: , :]
        return input_embs, etype_embs


    def run(self, input_ids, segment_ids, input_maks, label_indexs):
        batch_size = input_ids.shape[0]

        input_embs_  = self.encoding(input_ids, segment_ids, input_maks)
        input_embs, etype_embs = self.obtain_embs(input_embs_, label_indexs)
        # input_embs: [batch, seq, dim]
        # etype_embs: [batch, type_num, dim]

        tt_outputs_1 = self.gridmodel(input_embs, etype_embs)
        tt_outputs_2 = self.gridmodel2(input_embs, etype_embs)
        # Concat the output of two tables to derive the final loss
        tt_outputs = torch.sigmoid(torch.cat([tt_outputs_1, tt_outputs_2], dim = -1))
        return tt_outputs


    def forward(self, input_ids, segment_ids, input_maks, label_indexs):

        tt_outputs = self.run(input_ids, segment_ids, input_maks, label_indexs)
        return tt_outputs

    
    def inference(self, text_ids, input_ids, segment_ids, input_maks, label_indexs):
        
        result = {'text_id': text_ids[0], 'result': []}
        batch_size = input_ids.shape[0]

        tt_outputs_ = self.run(input_ids, segment_ids, input_maks, label_indexs)
        tt_outputs = tt_outputs_.squeeze(0).detach().cpu().numpy() # [type_num, seq, type_nums]
        
        # Decode ent
        input_ids = input_ids.squeeze(0)[-max_seq_len:]
        sent_len = torch.sum(input_maks.squeeze(0)[-max_seq_len:]).item()
        heads, tails, iids = np.where(tt_outputs > self.thresh)
        ent_dict, event_ent_dict = {}, {}
        reason_dict, result_dict = {}, {}
        

        for (etype_id, token_id, iid) in list(zip(heads, tails, iids)):
            # etype_id: index along the column, namely event types
            # token_id: index along the row, namely the position of token
            # iid: index of the predefined tags
            
            etype = etype_id2type[etype_id]
            tag_type = tt_id2type[iid]
            tag, ent_type, ent_pos = tag_type.split('-')

            
            ### Step1: Argument Span Decoding ###
            # In the Cause Table
            if (tt_map['Rea2Rea-product-H'] <= iid <= tt_map['Rea2Rea-industry-T']) or (tt_map['Rea2Res-product-H'] <= iid <= tt_map['Rea2Res-industry-T']):
                if etype not in reason_dict:
                    reason_dict[etype] = { 'reason':{'product': {'H': [], 'T': []},  'region': {'H': [], 'T': []}, 'industry': {'H': [], 'T': []}}, 
                                           'result':{'product': {'H': [], 'T': []},  'region': {'H': [], 'T': []}, 'industry': {'H': [], 'T': []}}}
                if tag == 'Rea2Rea':
                    reason_dict[etype]['reason'][ent_type][ent_pos].append(token_id) 
                elif tag == 'Rea2Res':
                    reason_dict[etype]['result'][ent_type][ent_pos].append(token_id) 
            # In the Effect Table
            elif (tt_map['Res2Res-product-H'] <= iid <= tt_map['Res2Res-industry-T']) or (tt_map['Res2Rea-product-H'] <= iid <= tt_map['Res2Rea-industry-T']):
                if etype not in result_dict:
                    result_dict[etype] = {'reason':{'product': {'H': [], 'T': []},  'region': {'H': [], 'T': []}, 'industry': {'H': [], 'T': []}}, 
                                          'result':{'product': {'H': [], 'T': []},  'region': {'H': [], 'T': []}, 'industry': {'H': [], 'T': []}}}
                if tag == 'Res2Rea':
                    result_dict[etype]['reason'][ent_type][ent_pos].append(token_id) 
                elif tag == 'Res2Res':
                    # pdb.set_trace()
                    result_dict[etype]['result'][ent_type][ent_pos].append(token_id) 
 
        
        reason_ent_dict, result_ent_dict = {}, {}
        # In the cause table
        for etype in reason_dict:
            reason_ent_dict[etype] = {'reason': [], 'result': []}
            for tag in reason_dict[etype]: # tag: reason(Intra) / result(Inter)
                for key in reason_dict[etype][tag]: # key: product / region / industry
                    for ent_hid in reason_dict[etype][tag][key]['H']:
                        ent_tid_list = [ii for ii in reason_dict[etype][tag][key]['T'] if ii >= ent_hid]
                        if len(ent_tid_list) > 0:
                            ent_tid = min(ent_tid_list)
                            if max(ent_hid, ent_tid) < sent_len:
                                ent_text = "".join( self.tokenizer.convert_ids_to_tokens( input_ids[ent_hid: ent_tid + 1] ) )
                                reason_ent_dict[etype][tag].append((ent_text, key))
        # In the effect table
        for etype in result_dict:
            result_ent_dict[etype] = {'reason': [], 'result': []}
            for tag in result_dict[etype]: # tag: reason(Inter) / result(Intra)
                for key in result_dict[etype][tag]: # key: product / region / industry
                    for ent_hid in result_dict[etype][tag][key]['H']:
                        ent_tid_list = [ii for ii in result_dict[etype][tag][key]['T'] if ii >= ent_hid]                        
                        if len(ent_tid_list) > 0:
                            ent_tid = min(ent_tid_list)
                            if max(ent_hid, ent_tid) < sent_len:
                                ent_text = "".join( self.tokenizer.convert_ids_to_tokens( input_ids[ent_hid: ent_tid + 1] ) )
                                result_ent_dict[etype][tag].append((ent_text, key))

        
        ### Step3: Decode event pair ###
        for reason_type in reason_ent_dict:
            for result_type in result_ent_dict:
                reason_args = [item for item in reason_ent_dict[reason_type]['reason'] if item in result_ent_dict[result_type]['reason']]
                result_args = [item for item in result_ent_dict[result_type]['result'] if item in reason_ent_dict[reason_type]['result']]
                if max( len(reason_args), len(result_args)) != 0:
                    rr_pair = {'reason_type': reason_type, 'result_type': result_type, 
                        'reason_product': set(), 'reason_region': set(), 'reason_industry': set(), 
                        'result_product': set(), 'result_region': set(), 'result_industry': set()}

                    for item in reason_args:
                        ent_text, ent_type = item[-2], 'reason_' + item[-1]
                        rr_pair[ent_type].add(ent_text)
                    for item in result_args:
                        ent_text, ent_type = item[-2], 'result_' + item[-1]
                        rr_pair[ent_type].add(ent_text)

                    for key in ['reason_product', 'reason_region', 'reason_industry', 'result_product', 'result_region', 'result_industry']:
                        rr_pair[key] = ",".join(list(rr_pair[key])) if len(rr_pair[key]) != 0 else ""
                    result['result'].append(rr_pair)         
        return result

源码下载

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-12 03:30:02       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-12 03:30:02       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-12 03:30:02       58 阅读
  4. Python语言-面向对象

    2024-07-12 03:30:02       69 阅读

热门阅读

  1. linux:vi命令

    2024-07-12 03:30:02       17 阅读
  2. vagrant远程连接不上问题

    2024-07-12 03:30:02       19 阅读
  3. Android Gradle开发与应用(一): Gradle基础

    2024-07-12 03:30:02       22 阅读
  4. Android Gradle 开发与应用 (八): Gradle 与持续集成(CI)

    2024-07-12 03:30:02       22 阅读
  5. 宪法学学习笔记(个人向) Part.3

    2024-07-12 03:30:02       18 阅读
  6. 【Unity】RPG2D龙城纷争(十)战斗系统之角色战斗

    2024-07-12 03:30:02       22 阅读
  7. DP学习——策略模式

    2024-07-12 03:30:02       18 阅读
  8. UNIAPP 使用地图 百度 高德 腾讯地图路线轨迹

    2024-07-12 03:30:02       21 阅读
  9. 理解李彦宏的“不卷模型,卷应用”理念

    2024-07-12 03:30:02       23 阅读