A module for tracing the entire training process of models built with PyTorch

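The trace module described here tracks loss and accuracy throughout training. It generates line charts of loss and accuracy and, at the same time, saves the loss, the accuracy, and the model itself, so the data from the training process is preserved and can still be displayed even if training is interrupted. The module's source file is available in this project's GitHub repository. The description below is taken from the project's README. Because my exam week is approaching, I don't have time to write a Chinese introduction for now, so this post serves as a placeholder.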

A PyTorch Training Process Tracing Module

Xiangnan Zhang

School of Future Technology, Beijing Institute of Technology


This module, called "trace", traces the whole training process and visualizes accuracy and loss.

All loss and accuracy values are traced through a Statistic object. During training and testing, these values are appended to its list attributes. The model and the training data are saved and loaded together, which guarantees that the whole training process is traced. When the model and training data are saved, line charts of loss and accuracy are shown and saved, as in the following image.

[Image: line charts of loss and accuracy]

Importing

To import this module, write:

import trace
from trace import Statistic

NOTICE: The class Statistic should be imported separately, because it is a prerequisite of the functions load_statistic() and sys_load().

Details

Statistic(path)

Objects of this class store training and testing loss and accuracy. When you initialize your model, you should also create a Statistic object like this:

statis=Statistic(statis_path)

Its __init__() method creates the attributes self.train_loss, self.train_accuracy, self.test_loss and self.test_accuracy. Each of them is an empty list, so you can use the .append() method to append values in your train_loop and test_loop functions, like this:

def test_loop(test_ds,model,loss_fn,statis):
    model.eval()
    ...
    statis.test_loss.append(test_loss)
    statis.test_accuracy.append(accuracy)
    ...
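
The same appending pattern applies on the training side. The sketch below is a minimal, self-contained illustration in plain Python. StatisticSketch is a hypothetical stand-in for the module's Statistic class that reproduces only the four list attributes described above, and record_epoch is an illustrative helper that is not part of the module. Appending one averaged value per epoch is an assumption; the real loops may record values at a different granularity.

```python
class StatisticSketch:
    """Hypothetical stand-in for trace.Statistic: four empty lists only."""

    def __init__(self, path):
        self.path = path
        self.train_loss = []
        self.train_accuracy = []
        self.test_loss = []
        self.test_accuracy = []


def record_epoch(statis, batch_losses, n_correct, n_samples, training=True):
    """Append one epoch's mean loss and accuracy to the matching lists."""
    mean_loss = sum(batch_losses) / len(batch_losses)
    accuracy = n_correct / n_samples
    if training:
        statis.train_loss.append(mean_loss)
        statis.train_accuracy.append(accuracy)
    else:
        statis.test_loss.append(mean_loss)
        statis.test_accuracy.append(accuracy)
    return mean_loss, accuracy


# Illustrative numbers only, not results from a real model.
statis = StatisticSketch("statis.pkl")
record_epoch(statis, [0.5, 1.5], n_correct=3, n_samples=4, training=True)
record_epoch(statis, [1.0, 1.0, 1.0], n_correct=1, n_samples=2, training=False)
```

Keeping training and testing values in separate lists means each list can be plotted directly against its own epoch index.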

A Statistic object has two methods, self.draw() and self.save(), which draw the statistical charts and save the Statistic object as a pkl file. However, in most cases you should use the Sys object's .sys_conclude() method instead.

load_statistic(path)

This function loads a Statistic object. In most cases, however, you should use the sys_load() function instead.
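
Since the Statistic object is stored as a pkl file, a save and load round trip can be sketched with Python's standard pickle module. This is an assumption about the mechanism, not the module's actual code: StatisticSketch, save_statistic_sketch and load_statistic_sketch are hypothetical names used only for illustration. pickle records a class by its module and name, so the class definition has to be available when the file is loaded.

```python
import os
import pickle
import tempfile


class StatisticSketch:
    """Hypothetical stand-in for trace.Statistic."""

    def __init__(self, path):
        self.path = path
        self.train_loss = []
        self.train_accuracy = []
        self.test_loss = []
        self.test_accuracy = []


def save_statistic_sketch(statis):
    """Write the whole object to its own path (assumed pickle format)."""
    with open(statis.path, "wb") as f:
        pickle.dump(statis, f)


def load_statistic_sketch(path):
    """Read the object back; the class must be defined at this point."""
    with open(path, "rb") as f:
        return pickle.load(f)


# First session: record two values, then save.
statis_path = os.path.join(tempfile.mkdtemp(), "statis.pkl")
statis = StatisticSketch(statis_path)
statis.train_loss.append(0.9)
statis.train_loss.append(0.7)
save_statistic_sketch(statis)

# Later session: load and keep appending to the same history.
restored = load_statistic_sketch(statis_path)
restored.train_loss.append(0.6)
```

Because the lists are restored with their earlier contents, values appended after loading continue the same history instead of starting a new one.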

sys_load(model_path,statis_path)

This function loads both the model and the Statistic object. It is highly recommended that you use it to load these two items, because it guarantees that they are loaded together.

Sys(model,model_path,statis)

This class handles the model and the Statistic object together. You should create a Sys object after the model and the Statistic object are loaded or initialized, like this:

syst=trace.Sys(model,model_path,statis)

Then you can use the .sys_conclude() method to save both the model and the Statistic object:

syst.sys_conclude("ConvM(4_categ)")

A string that names the model must be passed to this method. Whenever you need to save your model and Statistic object, this method is highly recommended.


After saving, the model is stored as a pth file and the Statistic object as a pkl file. These suffixes should be included in the file paths.
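
Because the suffixes must be part of the paths, a small check before creating the Sys object can catch a missing or swapped suffix early. The helper below is a hypothetical convenience written for this description, not a function of the trace module, and the file names are illustrative.

```python
from pathlib import Path


def check_trace_paths(model_path, statis_path):
    """Raise ValueError unless the paths end in .pth and .pkl respectively."""
    if Path(model_path).suffix != ".pth":
        raise ValueError(f"model path should end with .pth: {model_path}")
    if Path(statis_path).suffix != ".pkl":
        raise ValueError(f"statistic path should end with .pkl: {statis_path}")
    return model_path, statis_path


# Illustrative file names; both pass the check and are returned unchanged.
model_path, statis_path = check_trace_paths("conv_model.pth", "conv_statis.pkl")
```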

Deficiency

When this module is used to trace data, train_loss increases dramatically at the beginning of a new training process, which can be seen in the image above. However, I'm not sure whether this is a problem with the module or with my model.
