# @package _global_
# The directive above is Hydra's package header; it has to stay at the very top of this file.
# The DYffusion options below are nested under the `diffusion` key, matching the
# `diffusion.<option>=...` command-line override shown further down.

defaults:
  - _base.yaml
  - _self_

diffusion:
  _target_: src.diffusion.dyffusion.DYffusion

  # ---- Interpolator source. IMPORTANT: One of the following must be specified:
  # Option 1. Interpolator pytorch module
  interpolator: null
  # Option 2. Wandb run id of the interpolator
  #   (set the desired file name in interpolator_wandb_ckpt_filename, e.g. "last.ckpt")
  interpolator_run_id: null
  # Only used if interpolator_run_id is not None
  interpolator_wandb_ckpt_filename: 'best-val_20ens_mems_ipol_avg_crps.ckpt'
  # Option 3. Local path to the interpolator checkpoint
  #   (what is saved by, e.g., torch.save(model.state_dict(), path))
  interpolator_local_checkpoint_path: null

  # ---- Contribution of each loss term to the total loss.
  # Main loss term (diffusion loss)
  lambda_reconstruction: 0.5
  # Auxiliary loss term that simulates one step of the diffusion process
  lambda_reconstruction2: 0.5

  # How to condition the forecaster network. Options: "none", "data", "data+noise"
  #   "data":       condition on the input data (i.e., the initial conditions at time t0)
  #   "data+noise": condition on a linear interpolation between the input data and a noise sample
  #   "none":       do not condition the forecaster network.
  #                 The only input will be the output of the interpolator
  forward_conditioning: 'data'

  # ---- Using auxiliary diffusion steps (k > 0 in the paper)
  # The following parameters are only used if additional_interpolation_steps or
  # additional_interpolation_steps_factor > 0
  # If 'before_t1_only', all auxiliary diffusion steps are added before t1
  schedule: 'before_t1_only'
  # k, how many additional diffusion steps to add. Only used if schedule='before_t1_only'
  additional_interpolation_steps: 0
  # Only used if schedule='linear'
  additional_interpolation_steps_factor: 0
  # Whether to interpolate before t1 too. Must be true if schedule='before_t1_only'
  interpolate_before_t1: true

  # Time encoding refers to the way the time is encoded for the forecaster network for a given
  # diffusion step. Options: "dynamics", "discrete".
  # Recommended: "dynamics", i.e. use actual timestep
  time_encoding: 'dynamics'

  # Enabling stochastic dropout in the interpolator is strongly recommended for better performance
  enable_interpolator_dropout: true  # Keep true!

  # ---- Sampling related parameters:
  # Sampling algorithm. Options: 'cold', 'naive'. Strongly recommended: 'cold'
  sampling_type: 'cold'
  # Accelerate sampling when k > 0, by using fewer diffusion steps by skipping some auxiliary
  # diffusion steps. E.g. set to "every2nd" to skip every second auxiliary diffusion step.
  # Only used if k > 0
  sampling_schedule: null
  # Whether to refine the intermediate interpolator predictions by re-running the interpolator
  # (line 6 in Algorithm 1).
  # It is recommended to set this to false during training.
  # At validation time you may set it to true to see if it improves the results:
  #   python run.py mode=test diffusion.refine_intermediate_predictions=True logger.wandb.id=???
  refine_intermediate_predictions: false
  # Set to true to use the direct forecaster's prediction of x_{t+h} rather than a cold-sampled
  # one (when sampling_type='cold')
  # NOTE(review): the key name suggests that true enables cold sampling for the last step, which
  # reads opposite to the sentence above -- confirm against the DYffusion sampling code.
  use_cold_sampling_for_last_step: false

  # Do not change, it is automatically inferred by DYffusion
  timesteps: ${datamodule.horizon}
  log_every_t: null