【代码复现】STAEformer


前言

官方Github
论文:STAEformer: Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting

STAEformer复现结果能够对的上论文实验结果,并且配置环境时没有遇到雷点。

一、创建虚拟环境

我创建了一个名称为STAEformer的虚拟环境(名称可以更改),并conda activate进入虚拟环境。

conda create -n STAEformer python==3.9.18
conda activate STAEformer

二、安装cuda、pytorch

官方代码仓要求pytorch>=1.11,这里我在虚拟环境安装cuda11.7和pytorch1.13.1及其相关依赖。

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia

三、Required Packages

依赖包依次安装就完了。

pip install numpy
pip install pandas
pip install matplotlib
pip install pyyaml
pip install torchinfo

四、复现结果

cd model/进入项目model文件夹下,运行

python train.py -d <dataset>

官方很贴心地把数据集都给整理好了,不用自己去找了。
<dataset>:

  • METRLA
  • PEMSBAY
  • PEMS03
  • PEMS04
  • PEMS07
  • PEMS08

复现METRLA为例,测试结果为:

--------- Test ---------
All Steps RMSE = 5.93878, MAE = 2.92491, MAPE = 8.00552
Step 1 RMSE = 3.97519, MAE = 2.26739, MAPE = 5.49113
Step 2 RMSE = 4.65860, MAE = 2.49998, MAPE = 6.28363
Step 3 RMSE = 5.09899, MAE = 2.65260, MAPE = 6.84915
Step 4 RMSE = 5.45347, MAE = 2.76543, MAPE = 7.31178
Step 5 RMSE = 5.72631, MAE = 2.86540, MAPE = 7.71518
Step 6 RMSE = 5.96758, MAE = 2.95223, MAPE = 8.09941
Step 7 RMSE = 6.20411, MAE = 3.02974, MAPE = 8.39245
Step 8 RMSE = 6.37668, MAE = 3.09545, MAPE = 8.68591
Step 9 RMSE = 6.52967, MAE = 3.15964, MAPE = 8.96343
Step 10 RMSE = 6.69013, MAE = 3.21868, MAPE = 9.21720
Step 11 RMSE = 6.82990, MAE = 3.27001, MAPE = 9.42466
Step 12 RMSE = 6.95626, MAE = 3.32242, MAPE = 9.63251
Inference time: 4.86 s

相关推荐

  1. 代码STAEformer

    2024-07-14 06:10:04       22 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-14 06:10:04       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-14 06:10:04       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-14 06:10:04       58 阅读
  4. Python语言-面向对象

    2024-07-14 06:10:04       69 阅读

热门阅读

  1. python中的pickle模块和json模块

    2024-07-14 06:10:04       23 阅读
  2. ClickHouse实战第二章-ClickHouse 的安装调试

    2024-07-14 06:10:04       25 阅读
  3. Spring事件监听机制详解

    2024-07-14 06:10:04       22 阅读
  4. 案例:分库分表与SELECT * 发生的线上问题

    2024-07-14 06:10:04       24 阅读
  5. TypeScript的类型谓词与控制流分析

    2024-07-14 06:10:04       26 阅读
  6. ThreadLocal详解

    2024-07-14 06:10:04       22 阅读
  7. 小程序如何刷新当前页面

    2024-07-14 06:10:04       25 阅读
  8. qt 根据名称获取按钮,并添加点击事件

    2024-07-14 06:10:04       19 阅读
  9. Linux开发讲课37--- ARM的22个常用概念

    2024-07-14 06:10:04       27 阅读