下载NLP_gluedata数据集的脚本

import os
import sys
import shutil
import argparse
import tempfile
import io
import urllib.request as urllib
import zipfile
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK_PATH = {"CoLA":'https://dl.fbaipublicfiles.com/glue/data/CoLA.zip',
             "SST":'https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
             "QQP":'https://dl.fbaipublicfiles.com/glue/data/STS-B.zip',
             "STS":'https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip',
             "MNLI":'https://dl.fbaipublicfiles.com/glue/data/MNLI.zip',
             "QNLI":'https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip',
             "RTE":'https://dl.fbaipublicfiles.com/glue/data/RTE.zip',
             "WNLI":'https://dl.fbaipublicfiles.com/glue/data/WNLI.zip',
             "diagnostic":'https://dl.fbaipublicfiles.com/glue/data/AX.tsv'}
MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
#data_dir:数据保存目录,path_to_data:指向已提取 MRPC 数据的目录的路径
def format_mrpc(data_dir, path_to_data):
    print("处理 MRPC中...")
    mrpc_dir = os.path.join(data_dir, "mrpc")
    if not os.path.exists(mrpc_dir):
        os.mkdir(mrpc_dir)
    # 如果 path_to_data有值(即不为空),则使用该路径下的文件
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    # 否则从url下载,并保存到mrpc_dir里
    else:
        try:
            mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
            mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
            urllib.urlretrieve(MRPC_TRAIN, mrpc_train_file)
            urllib.urlretrieve(MRPC_TEST, mrpc_test_file)
        except urllib.error.HTTPError:
            print("下载MRPC错误!")
            return
    assert os.path.isfile(mrpc_train_file), "训练数据没有找到! %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "测试数据没有找到! %s" % mrpc_test_file
    # 代码读取测试集文件,并将其转换为特定的 TSV格式
    with io.open(mrpc_test_file, encoding='utf-8') as data_fh, \
        io.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf-8') as test_fh:
            header = data_fh.readline()
            test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
            for idx, row in enumerate(data_fh):
                label, id1, id2, s1, s2 = row.strip().split('\t')
                test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    # try:
    #     urllib.urlretrieve(TASK_PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
    # except KeyError or urllib.error.HTTPError:
    #     print("\tError downloading standard development IDs for MRPC. You will need to manually split your data.")
    #     return
    # dev_ids = []
    # with io.open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding='utf-8') as ids_fh:
    #     for row in ids_fh:
    #         dev_ids.append(row.strip().split('\t'))
     # io.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf-8') as dev_fh:
    with io.open(mrpc_train_file, encoding='utf-8') as data_fh, \
         io.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf-8') as train_fh:
            header = data_fh.readline()
            train_fh.write(header)
            # dev_fh.write(header)
            for row in data_fh:
                label, id1, id2, s1, s2 = row.strip().split('\t')
                # if [id1, id2] in dev_ids:
                #     dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
                # else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
    print("\t MRPC 完工!")
# 下载和提取方法,参数:task:要下载的数据集的名称,data_dir:数据应该被提取到的目录
def download_and_extract(task, data_dir):
    print("下载提取 %s..." % task)
    data_file = "%s.zip" % task
    # 下载文件,数据会被保存到 data_file
    urllib.urlretrieve(TASK_PATH[task], data_file)
    # 使用 zipfile 模块来解压下载的.zip文件,with语句确保 
    # zipfile.ZipFile 对象在使用后会被正确关闭。extractall
    # 方法用于将压缩文件的内容解压到 data_dir 指定的目录
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    # 使用 os.remove方法删除下载的.zip 文件,以节省存储空间
    os.remove(data_file)
    print("\t%s 完工!" %task)
#获取任务列表方法
def get_tasks(task_names):
    #生成任务名称列表
    task_names = task_names.split(',')
    print(task_names)
    if "all" in task_names:# 如果任务名称列表内有all,代表下载所有
        tasks = TASKS
    else:# 如果不是下载所有
        tasks = []# 那就构建任务列表
        for task_name in task_names:
            #断言任务名称在我们定义的任务里,没有就提示,没找到
            assert task_name in TASKS, "Task %s not found!" % task_name
            # 能到这里来,证明它在我们定义的任务列表里
            tasks.append(task_name)
    return tasks
# 下载diagnostic的方法
def download_diagnostic(data_dir):
    print("下载解析diagnostic...")
    if not os.path.exists(os.path.join(data_dir,"diagnostic")):
        os.mkdir(os.path.join(data_dir,"diagnostic"))
    data_file = os.path.join(data_dir,"diagnostic", "diagnostic.tsv")
    urllib.urlretrieve(TASK_PATH["diagnostic"],data_file)
    print("\t diagnostic完工!" )
    return
# argparse 用于处理命令行参数,os 用于与操作系统交互,例如检查目录是否存在或创建新目录
def main(arguments):#传入参数
    print(arguments)
    parser = argparse.ArgumentParser()# 设置命令行参数解析器
    # data_dir 或 -d :用于指定保存数据的目录。如果没有提供,默认值是 'glue_data'。
    parser.add_argument('-d', '--data_dir', help='要保存的文件夹路径:', type=str, default='glue_data')
    # tasks 或-t:用于指定要下载数据的任务,多个任务之间用逗号分隔。默认值是 'all',表示下载所有任务的数据。  
    parser.add_argument('-t', '--tasks', help='要下载哪些文件',type=str, default='all')
    #path_to_mrpc:用于指定包含已提取的MRPC数据的目录。
    parser.add_argument('--path_to_mrpc', help='要提取的MRPC路径',
                        type=str, default='')
    # 这行代码将解析命令行参数,并将结果存储在 args 对象中。
    args = parser.parse_args(arguments)
    # 如果指定的数据目录不存在,则创建它。
    if not os.path.exists(args.data_dir):
        os.mkdir(args.data_dir)
    # 获取要处理的任务列表
    tasks = get_tasks(args.tasks)
    # 对于每个任务,根据任务的名称执行不同的操作
    for task in tasks:
        # 如果任务是 'MRPC',则调用 `format_mrpc` 函数来格式化 MRPC 数据
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        # 如果任务是 'diagnostic',则调用 `download_diagnostic` 函数来下载相关数据
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        # 对于其他任务,调用 `download_and_extract` 函数来下载并提取数据
        else:
            download_and_extract(task, args.data_dir)
# 它返回 sys.argv 列表中从第二个元素开始到最后一个元素的所有元素。
# 这样做是为了跳过脚本名称本身,只获取传递给脚本的参数
sys.exit(main(sys.argv[1:]))


 

相关推荐

  1. 下载NLP_gluedata数据脚本

    2024-03-22 15:58:02       21 阅读
  2. FF++数据下载脚本代码

    2024-03-22 15:58:02       29 阅读
  3. 在hf-mirror下载数据方式

    2024-03-22 15:58:02       15 阅读
  4. 开源数据下载地址

    2024-03-22 15:58:02       40 阅读

最近更新

  1. 专业课笔记——(第十二章:文件的读写)

    2024-03-22 15:58:02       0 阅读
  2. lvs集群

    2024-03-22 15:58:02       0 阅读
  3. Perl 语言入门学习

    2024-03-22 15:58:02       0 阅读
  4. 大模型/NLP/算法面试题总结3——BERT和T5的区别?

    2024-03-22 15:58:02       1 阅读
  5. 单元测试核心类备忘

    2024-03-22 15:58:02       1 阅读

热门阅读

  1. 类似于 FastAdmin的快速后台开发框架都有哪些

    2024-03-22 15:58:02       20 阅读
  2. k8s工作节点主要模块

    2024-03-22 15:58:02       22 阅读
  3. 大数据开发(HBase真题)

    2024-03-22 15:58:02       17 阅读
  4. Puppet 2024年度报告:平台工程发掘 DevOps 无限潜质

    2024-03-22 15:58:02       20 阅读
  5. 后台发送GET/POST方法

    2024-03-22 15:58:02       16 阅读
  6. Qt Excel文件读写

    2024-03-22 15:58:02       21 阅读
  7. 9. Linux 信号详解

    2024-03-22 15:58:02       21 阅读
  8. 在Linux/Ubuntu/Debian中创建自己的命令快捷方式

    2024-03-22 15:58:02       21 阅读
  9. 以太网网络变压器

    2024-03-22 15:58:02       23 阅读