【AutoML】AutoKeras 训练数据收集并入库

今天开始将会不定期更新一些本人学习人工智能的笔记和感悟,由于本人在 python 和人工智能领域都属于新手,因此写得不好的地方请大家多多包涵。

上一篇文章中已简单讲解了如何在 VSCode 中搭建一个用于 AutoKeras 训练的环境,接下来的问题就是“训练的数据从哪里来?”经摸索总结出 3 种数据源获取方式:

  1. 通过爬虫爬取(需注意法律法规和目标网站是否存在反爬机制);
  2. 通过人工智能社区获取开放数据源,如:和鲸社区、飞桨 AI Studio、Hugging face、Kaggle等等。这些开放数据源一般都已整理过,只需稍加修改即可使用(要注意数据源的开放原则避免侵权);
  3. 到 Github 寻找(走投无路方案,一般技术相关的都可试试在上面寻找);

这次是工作需要开始接触人工智能,因此寻找的数据源是需要跟中药材相关的。市面上“医疗”、“医药”乃至“中医药”的数据源还是比较好找的,但“中药材”…就玩得有点儿偏了,几乎是没有现成的。找了一周“颗粒无收”,最终决定从“医疗”、“医药”和“中医药”数据源中尝试提取。

ok,现在方向确定了,接下来就是要决定要做什么样的人工智能了。想来想去,老板最能接受且快出效果的就是问答机器人了(类似 ChatGPT)。既然是问答机器人,那么训练的数据也应该优先考虑问答的数据源。上网搜索了一番决定使用以下 4 个数据源。

baike 数据集

获取地址:https://github.com/brightmart/nlp_chinese_corpus
baike 数据集提供了几乎所有领域的问答数据。本数据集中 train 数据总量就达到 1425170,valid 数据总量达到 44972。别看训练数据已达百万的量,里面中药、中医等领域相关数据很少。在数据入库前需要进行特殊处理,这部分我用 python 进行了一下处理,关键代码如下图:

...

##############
# description: 解析 baike 的 json 文件并进行批量保存
# param {*} file_path
# param {*} status
# return {*}
##############


def load_and_save_to_mysql(file_path, status):

    # 根据 status 字段判断使用那个表
    table_name = "baike_dataset_tmp"
    if status != "train":
        table_name = "baike_dataset_valid"

    # 获取 mysql 链接
    conn = mysql_util.get_connection()
    # 打开数据库连接游标
    cursor = conn.cursor()
    # 打开文件 io 读取
    with open(file_path, 'r') as f:
        # 初始化批次数组
        batch = []
        # 初始化计数器
        counter = 0
        # 遍历读取json文档并获取行
        for line in f:
            counter = counter+1
            # 记录一个批次的数据获取开始时间
            start_time = time.time()
            # 将行记录转换为 json 对象
            record = json.loads(line)

            for keyword in keywords:
                if keyword in record["category"]:
                    # 将数据加载到批次数组中
                    batch.append((record["qid"], record["category"],
                                 record["title"], record["desc"], record["answer"], 0))
                    break

            # 若已经到达批次条数上线
            if counter % batch_size == 0 and counter > 0:
                batch_insert_and_clean(
                    cursor, conn, status, table_name, start_time, counter, batch)
                # 重置批次数组
                batch = []

        # 若读取循环结束后发现批次数组内还存在数据,则进行最后一次批量插入
        if batch:
            batch_insert_and_clean(cursor, conn, status,
                                   table_name, start_time, counter, batch)
    cursor.close()
    mysql_util.close_connection(conn)
...

在执行之前还需要在 MySQL 数据库中建立对应的两个表 baike_dataset_train 和 baike_dataset_valid。这里需要注意的是,后续训练的数据量达到一定量级后就必须迁移到像 Elasticsearch、MongoDB 等其他 NoSQL 数据库中,一来是为了方便全文检索,二来“历史数据”并不需要频繁修改,对事务需求不高,因此数据存放 NoSQL 数据库中是较为合适的手段。

此外,本次训练将在本地先做验证,难免会有硬件不足的情况。所以在代码执行上将采用分批插入的方式降低内存和 CPU 的占用。

cMedQA2 数据集

获取地址:https://github.com/zhangsheng93/cMedQA2
cMedQA2 数据集来源于医渡云,是华人社区较大的医疗数据集。该数据集最大的特色是已经做了脱敏处理,且在 github 上这个数据集是支持 GPL-3.0 许可。本数据集提供两个 csv 文件,一个是 question,另一个是 answer。这两个 csv 文件要结合来使用,因此在数据入库后先创建两个表存放两类数据,之后再手动组装成一个表。其中数据入库的关键代码如下图:

...
def load_and_save_to_mysql(file_path, status):

    # 获取 mysql 链接
    conn = mysql_util.get_connection()
    # 打开数据库连接游标
    cursor = conn.cursor()
    # 打开文件 io 读取
    with open(file_path, 'r') as file:
        # 根据行来进行划分
        reader = csv.reader(file)
        # 初始化批次数组
        batch = []
        # 初始化循环计数器
        counter = 0
        # 剔除第一行
        next(reader)

        # 如果是问题类型的
        if status == "question":

            # 循环获取每行中的两个主要的字段
            for qid, content in reader:
                # 记录一个批次的数据获取开始时间
                start_time = time.time()
                # 将数据加载到批次数组中
                batch.append((qid, content, 0))
                # 循环计数器次数+1
                counter = counter+1

                # 若已经到达批次条数上线
                if counter % batch_size == 0 and counter > 0:
                    batch_insert_and_clean(cursor, conn, "INSERT INTO cmedqa_dataset_question_tmp VALUES (%s, %s, %s)", batch,
                                           start_time, counter, status)
                    # 重置批次数组
                    batch = []

            # 若读取循环结束后发现批次数组内还存在数据,则进行最后一次批量插入
            if batch:
                batch_insert_and_clean(cursor, conn, "INSERT INTO cmedqa_dataset_question_tmp VALUES (%s, %s, %s)", batch,
                                       start_time, counter, status)
        else:
            for aid, qid, content in reader:
                start_time = time.time()
                batch.append((qid, content, 0))
                counter = counter+1
                if counter % batch_size == 0 and counter > 0:
                    batch_insert_and_clean(cursor, conn, "INSERT INTO cmedqa_dataset_answer_tmp (QUESTION_ID,ANSWER,FLAG) VALUES ( %s, %s, %s)", batch,
                                           start_time, counter, status)
                    # 重置批次数组
                    batch = []

            if batch:
                batch_insert_and_clean(cursor, conn, "INSERT INTO cmedqa_dataset_answer_tmp (QUESTION_ID,ANSWER,FLAG) VALUES ( %s, %s, %s)", batch,
                                       start_time, counter, status)
    cursor.close()
    mysql_util.close_connection(conn)
...

总体来说跟 baike 的处理方式大同小异,都是分批插入数据的。但需要注意的是由于 csv 是有表头的,因此在遍历文件的时候要先用 next 函数跳过第一行。

像 baike 一样这里我们也需要先创建了两个表,一个是 cmedqa_dataset_question_train 表用于存放问题,数据量为 120000,另一个是 cmedqa_dataset_answer_train 表用于存放答案,数据量为 226266。最后以 cmedqa_dataset_answer_train 表作为基础表将两个表合并成一个新表 cmedqa_dataset。为此我们先创建 cmedqa_dataset 表如下图:

CREATE TABLE `cmedqa_dataset` (
  `ID` bigint NOT NULL AUTO_INCREMENT COMMENT '主键 id',
  `QUESTION` text CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '问题',
  `ANSWER` longtext CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '回答内容',
  `FLAG` int NOT NULL DEFAULT '0',
  PRIMARY KEY (`ID`)
) ENGINE=MyISAM DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Cmedqa2 数据集(用于数据训练)'

将 cmedqa_dataset_answer_train 表作为标准表 join cmedqa_dataset_question_train 表后提取 question 和 answer 字段然后批量插入到 cmedqa_dataset 表。如下图:

insert into cmedqa_dataset(QUESTION,ANSWER)
select b.QUESTION,a.answer
from cmedqa_dataset_answer_train a left join cmedqa_dataset_question_train b on a.question_id = b.id

paddle 数据集

获取地址:https://aistudio.baidu.com/datasetdetail/84360/0
这里的 paddle 指代的是飞桨 AI Studio平台,这次用到的是中文医学问答数据集。下载下来后是两个 txt 文件,一个是 train 数据集数量为 181012,另一个是 test 数据集数量为 45254。对于 txt 文件仍然采用 csv 工具进行解析,唯一不同的是本数据集中字段数据之间是通过 tab 制表符进行区分的,因此需要在读取的时候加上“\t” 进行处理。关键代码如下图:

...

# 获取 mysql 链接
conn = mysql_util.get_connection()
# 打开数据库连接游标
cursor = conn.cursor()
# 打开文件 io 读取
with open(file_path, 'r') as file:
    # 根据 tab 制表符来进行划分
    reader = csv.reader(file, delimiter='\t')
    # 初始化批次数组
    batch = []
    # 初始化循环计数器
    counter = 0
    
...

huatuo 数据集

获取地址:https://github.com/FreedomIntelligence/Huatuo-26M
最后一个是大名鼎鼎的华佗数据集,这次下载的是 26M 版本,最终入库时 train 数据量为 363420,test 数据量为 1000,valid 数据量为 1000。虽然数据集提供的是 jsonl 文件,但是可以当做是普通的 json 文件进行解析。如下图:

{
   "questions": [["肾皮质化脓性感染的诊断是什么?", "什么是肾皮质化脓性感染的诊断?"]], "answers": ["血液中白细胞总数和中性粒细胞升高,血培养可呈阳性。早期尿中无白细胞,当感染扩展到肾盂时,尿中可发现白细胞。尿培养的结果应与血培养相同。B超引导下穿刺抽脓培养可发现致病菌。影像学检查根据病变程度而有不同的表现。1.急性局灶性细菌性肾炎 腹平片常无明显异常。静脉尿路造影对诊断有一定帮助,少数病人可出现肾盂肾盏受压。B超检查示肾实质局灶性低回声区,边界不清。CT检查为低密度实质性肿块。增强后密度不均匀增强,仍低于正常肾组织,肿块边界不清,不同于肾皮质脓肿由新生血管形成的界限清楚的壁。有文献报道CT示肾实质局限性肿大并有多个层面肾筋膜增厚是该病定性诊断依据。2.肾皮质脓肿 腹部平片显示患侧肾脏增大,肾周围水肿使肾影模糊,腰大肌阴影不清楚或消失。当脓肿破裂到肾周围时,腰椎侧弯。静脉尿路造影可显示肾盂肾盏受压变形。B型超声:显示不规则的脓肿轮廓,脓肿为低回声区,或混合回声区,肾窦回声偏移,稍向肾边缘凸出。CT肾扫描显示肾皮质不规则低密度病灶,CT值介于囊肿和肿瘤之间,增强CT扫描边缘增强明显,中心部无增强。肾被膜、肾周筋膜增厚,与邻近组织界面消失。放射性核素肾扫描:显示肾占位病变,肾缺损区与肾囊肿相似,用67Ga可提示感染组织。"]}

每一行数据的格式都是这样,因此在获取数据时就要做适当的处理,如下图:

...
counter = counter+1
# 记录一个批次的数据获取开始时间
start_time = time.time()
# 将行记录转换为 json 对象
record = json.loads(line)
batch.append((record["questions"][0][0], record["answers"][0], 0))
...

由于 loads 函数已经将字符串加载成一个 json 对象了,因此在获取数据的时候需要 record[“questions”][0][0] 来获取嵌套的数据。

虽然说现在已经找到了大量的问答数据了,但是数据的真实性和可用性都是存疑的。像 baike 的数据都是来自于论坛的,你能够确定回答的那个人不是随便乱答的?像 cMedQA2 和 huatuo 数据集虽然医疗数据很多,但是中药材的数据究竟又有多少可用呢…这些都是未知之数。接下来就是要对这批数据进行数据清洗和数据整理了,这部分内容将会在另一篇文章中进行叙述。

最近更新

  1. TCP协议是安全的吗?

    2024-02-01 12:08:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-02-01 12:08:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-02-01 12:08:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-02-01 12:08:02       20 阅读

热门阅读

  1. 51单片机温湿度数据管理系统

    2024-02-01 12:08:02       33 阅读
  2. 【NGINX】NGINX如何阻止指定ip的请求

    2024-02-01 12:08:02       26 阅读
  3. 【issue-halcon例程学习】rim_simple.hdev

    2024-02-01 12:08:02       31 阅读
  4. 深度学习有何新进展?

    2024-02-01 12:08:02       32 阅读
  5. React和Vue实现路由懒加载

    2024-02-01 12:08:02       27 阅读
  6. SpringCloud

    2024-02-01 12:08:02       37 阅读
  7. 大数据之 Spark 比 MapReduce 快的原因

    2024-02-01 12:08:02       32 阅读