Preface
Official GitHub
Paper: STAEformer: Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting
The STAEformer reproduction results match the experimental results reported in the paper, and I hit no pitfalls while setting up the environment.
1. Create a Virtual Environment
I created a virtual environment named STAEformer (any name works) and entered it with conda activate.
conda create -n STAEformer python==3.9.18
conda activate STAEformer
2. Install CUDA and PyTorch
The official repo requires pytorch>=1.11; here I install CUDA 11.7 and PyTorch 1.13.1 (plus the matching torchvision/torchaudio) into the virtual environment.
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
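Before moving on, it's worth confirming that PyTorch imports and sees the GPU inside the activated environment. A minimal check (run with the env's python; the `cuda_summary` helper is just for illustration):

```python
# quick sanity check: is torch importable, and does it see a CUDA device?
def cuda_summary():
    try:
        import torch  # assumes the conda install above succeeded
    except ImportError:
        return "torch not installed"
    return f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}"

print(cuda_summary())
```

If this prints `CUDA available: False` on a GPU machine, the pytorch-cuda build and the driver version are the first things to re-check.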
3. Required Packages
Install the dependency packages one by one and you're done.
pip install numpy
pip install pandas
pip install matplotlib
pip install pyyaml
pip install torchinfo
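To double-check that everything is in place, a small sketch that lists any still-missing packages; note that pyyaml is imported as `yaml` (the `missing` helper is my own, not part of the repo):

```python
import importlib.util

# report which of the required packages cannot be imported yet
def missing(modules):
    return [m for m in modules if importlib.util.find_spec(m) is None]

# import names for the pip packages above (pyyaml -> yaml)
required = ["numpy", "pandas", "matplotlib", "yaml", "torchinfo"]
print("missing:", missing(required))
```

An empty list means the environment is ready for training.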
4. Reproduction Results
Change into the project's model/ directory and run:
cd model/
python train.py -d <dataset>
The official repo thoughtfully ships all the datasets pre-packaged, so there is no need to hunt them down yourself.
<dataset> is one of:
- METRLA
- PEMSBAY
- PEMS03
- PEMS04
- PEMS07
- PEMS08
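To reproduce all six benchmarks in one sitting, the run command can be queued in a loop. A minimal sketch from inside model/ (here it only echoes each command; drop the echo to actually launch the runs):

```shell
# print the per-dataset training commands this loop would issue
for d in METRLA PEMSBAY PEMS03 PEMS04 PEMS07 PEMS08; do
  echo python train.py -d "$d"
done
```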
Taking METRLA as an example, the test results are:
--------- Test ---------
All Steps RMSE = 5.93878, MAE = 2.92491, MAPE = 8.00552
Step 1 RMSE = 3.97519, MAE = 2.26739, MAPE = 5.49113
Step 2 RMSE = 4.65860, MAE = 2.49998, MAPE = 6.28363
Step 3 RMSE = 5.09899, MAE = 2.65260, MAPE = 6.84915
Step 4 RMSE = 5.45347, MAE = 2.76543, MAPE = 7.31178
Step 5 RMSE = 5.72631, MAE = 2.86540, MAPE = 7.71518
Step 6 RMSE = 5.96758, MAE = 2.95223, MAPE = 8.09941
Step 7 RMSE = 6.20411, MAE = 3.02974, MAPE = 8.39245
Step 8 RMSE = 6.37668, MAE = 3.09545, MAPE = 8.68591
Step 9 RMSE = 6.52967, MAE = 3.15964, MAPE = 8.96343
Step 10 RMSE = 6.69013, MAE = 3.21868, MAPE = 9.21720
Step 11 RMSE = 6.82990, MAE = 3.27001, MAPE = 9.42466
Step 12 RMSE = 6.95626, MAE = 3.32242, MAPE = 9.63251
Inference time: 4.86 s
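For reference, the three metrics in the log above follow their standard definitions in traffic forecasting. A minimal pure-Python sketch with toy values (not data from the actual run; real implementations average over all sensors and horizons and mask zero ground-truth values in MAPE):

```python
import math

def rmse(y_true, y_pred):
    # root of the mean squared error
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))

def mae(y_true, y_pred):
    # mean absolute error
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def mape(y_true, y_pred):
    # mean absolute percentage error; zero targets are skipped here as a stand-in for masking
    return 100.0 * sum(abs(t - p) / abs(t) for t, p in zip(y_true, y_pred) if t != 0) / len(y_true)

y_true = [60.0, 55.0, 50.0]  # toy speeds, km/h
y_pred = [58.0, 57.0, 49.0]
print(rmse(y_true, y_pred), mae(y_true, y_pred), mape(y_true, y_pred))
```

As expected from the per-step rows above, errors grow with the prediction horizon: step 12 is the hardest to forecast.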