-
借助工具下载 GraFITi 官方代码至
./GraFITi/
-
安装 tsdm
- 下载 tsdm 官方代码 至
./tsdm/
- 创建 conda 虚拟环境,注意
python=3.11
- 用
./GraFITi/tsdm
替换./tsdm/src/tsdm
- 将
./tsdm/src/tsdm/viz/_config.py
中的USE_TEX: Final[bool] = matplotlib.checkdep_usetex(True)
改为USE_TEX: Final[bool] = False
- 进入
./tsdm/
目录,执行pip install -e .
- 下载 tsdm 官方代码 至
-
修改
./GraFITi/train_grafiti.py
-
创建模型存储目录
if not os.path.exists('saved_models/'): os.makedirs('saved_models/')
-
修改优化器配置
OPTIMIZER_CONFIG = { "lr": ARGS.learn_rate, "betas": ARGS.betas, "weight_decay": ARGS.weight_decay, }
-
如果需要,添加
tqdm
打印进度条
-
-
进入
./GraFITi/
目录,运行如下命令运行官方示例,如果提示缺包自行安装即可
python train_grafiti.py --epochs 200 --learn-rate 0.001 --batch-size 128 --attn-head 1 --latent-dim 128 --nlayers 4 --dataset physionet2012 --fold 0 -ct 36 -ft 12
-
下载本项目
-
创建 conda 虚拟环境,注意
python=3.11
-
进入
tsdm-main
目录,执行pip install -e .
-
进入
./GraFITi/
目录,运行如下命令运行官方示例,如果提示缺包自行安装即可python train_grafiti.py --epochs 200 --learn-rate 0.001 --batch-size 128 --attn-head 1 --latent-dim 128 --nlayers 4 --dataset physionet2012 --fold 0 -ct 36 -ft 12