├── .gitignore
├── Config
├── config.yaml
├── eeg.yaml
├── energy.yaml
├── etth.yaml
├── fmri.yaml
├── mujoco.yaml
├── mujoco_sssd.yaml
├── sines.yaml
├── solar.yaml
├── solar_update.yaml
└── stocks.yaml
├── Data
├── Place the dataset here
├── build_dataloader.py
└── readme.md
├── Experiments
├── eeg_multiple_classes.ipynb
├── metric_pytorch.ipynb
├── metric_tensorflow.ipynb
├── mujoco_imputation.ipynb
├── solar_nips_forecasting.ipynb
└── view_interpretability.ipynb
├── LICENSE
├── Models
├── interpretable_diffusion
│ ├── classifier.py
│ ├── gaussian_diffusion.py
│ ├── model_utils.py
│ └── transformer.py
└── ts2vec
│ ├── models
│ ├── dilated_conv.py
│ ├── encoder.py
│ └── losses.py
│ ├── ts2vec.py
│ └── utils.py
├── README.md
├── Tutorial_0.ipynb
├── Tutorial_1.ipynb
├── Tutorial_2.ipynb
├── Utils
├── Data_utils
│ ├── eeg_dataset.py
│ ├── mujoco_dataset.py
│ ├── real_datasets.py
│ └── sine_dataset.py
├── context_fid.py
├── cross_correlation.py
├── discriminative_metric.py
├── imputation_utils.py
├── io_utils.py
├── masking_utils.py
├── metric_utils.py
└── predictive_metric.py
├── engine
├── logger.py
├── lr_sch.py
└── solver.py
├── figures
├── Flowchart.svg
├── fig1.jpg
├── fig2.jpg
├── fig3.jpg
└── fig4.jpg
├── main.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/Config/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 160
5 | feature_size: 5
6 | n_layer_enc: 1
7 | n_layer_dec: 2
8 | d_model: 64 # 4 X 16
9 | timesteps: 200
10 | sampling_timesteps: 200
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.1
16 | resid_pd: 0.1
17 | kernel_size: 5
18 | padding_size: 2
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 1000
23 | results_folder: ./Checkpoints_syn
24 | gradient_accumulate_every: 2
25 | save_cycle: 100 # max_epochs // 10
26 | ema:
27 | decay: 0.99
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 200
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 100
40 | verbose: False
41 |
--------------------------------------------------------------------------------
/Config/eeg.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 24
5 | feature_size: 14
6 | n_layer_enc: 3
7 | n_layer_dec: 2
8 | d_model: 64 # 4 X 16
9 | timesteps: 500
10 | sampling_timesteps: 100
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.0
16 | resid_pd: 0.0
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | classifier:
21 | target: Models.interpretable_diffusion.classifier.Classifier
22 | params:
23 | seq_length: 24
24 | feature_size: 14
25 | num_classes: 2
26 | n_layer_enc: 3
27 | n_embd: 64 # 4 X 16
28 | n_heads: 4
29 | mlp_hidden_times: 4
30 | attn_pd: 0.0
31 | resid_pd: 0.0
32 | max_len: 24 # seq_length
33 | num_head_channels: 8
34 |
35 | solver:
36 | base_lr: 1.0e-5
37 | max_epochs: 12000
38 | results_folder: ./Checkpoints_eeg
39 | gradient_accumulate_every: 2
40 | save_cycle: 1200 # max_epochs // 10
41 | ema:
42 | decay: 0.995
43 | update_interval: 10
44 |
45 | scheduler:
46 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
47 | params:
48 | factor: 0.5
49 | patience: 3000
50 | min_lr: 1.0e-5
51 | threshold: 1.0e-1
52 | threshold_mode: rel
53 | warmup_lr: 8.0e-4
54 | warmup: 500
55 | verbose: False
56 |
57 | dataloader:
58 | train_dataset:
59 | target: Utils.Data_utils.eeg_dataset.EEGDataset
60 | params:
61 | data_root: ../Data/datasets/EEG_Eye_State.arff
62 | window: 24 # seq_length
63 | save2npy: True
64 | neg_one_to_one: True
65 | period: train
66 |
67 | batch_size: 128
68 | sample_size: 256
69 | shuffle: True
--------------------------------------------------------------------------------
/Config/energy.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 24
5 | feature_size: 28
6 | n_layer_enc: 4
7 | n_layer_dec: 3
8 | d_model: 96 # 4 X 24
9 | timesteps: 1000
10 | sampling_timesteps: 1000
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.0
16 | resid_pd: 0.0
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 25000
23 | results_folder: ./Checkpoints_energy
24 | gradient_accumulate_every: 2
25 | save_cycle: 2500 # max_epochs // 10
26 | ema:
27 | decay: 0.995
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 5000
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 500
40 | verbose: False
41 |
42 | dataloader:
43 | train_dataset:
44 | target: Utils.Data_utils.real_datasets.CustomDataset
45 | params:
46 | name: energy
47 | proportion: 1.0 # Set to rate < 1 if training conditional generation
48 | data_root: ./Data/datasets/energy_data.csv
49 | window: 24 # seq_length
50 | save2npy: True
51 | neg_one_to_one: True
52 | seed: 123
53 | period: train
54 |
55 | test_dataset:
56 | target: Utils.Data_utils.real_datasets.CustomDataset
57 | params:
58 | name: energy
59 | proportion: 0.9 # rate
60 | data_root: ./Data/datasets/energy_data.csv
61 | window: 24 # seq_length
62 | save2npy: True
63 | neg_one_to_one: True
64 | seed: 123
65 | period: test
66 | style: separate
67 | distribution: geometric
68 | coefficient: 1.0e-2
69 | step_size: 5.0e-2
70 | sampling_steps: 250
71 |
72 | batch_size: 64
73 | sample_size: 256
74 | shuffle: True
--------------------------------------------------------------------------------
/Config/etth.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 24
5 | feature_size: 7
6 | n_layer_enc: 3
7 | n_layer_dec: 2
8 | d_model: 64 # 4 X 16
9 | timesteps: 500
10 | sampling_timesteps: 500
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.0
16 | resid_pd: 0.0
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 18000
23 | results_folder: ./Checkpoints_etth
24 | gradient_accumulate_every: 2
25 | save_cycle: 1800 # max_epochs // 10
26 | ema:
27 | decay: 0.995
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 4000
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 500
40 | verbose: False
41 |
42 | dataloader:
43 | train_dataset:
44 | target: Utils.Data_utils.real_datasets.CustomDataset
45 | params:
46 | name: etth
47 | proportion: 1.0 # Set to rate < 1 if training conditional generation
48 | data_root: ./Data/datasets/ETTh.csv
49 | window: 24 # seq_length
50 | save2npy: True
51 | neg_one_to_one: True
52 | seed: 123
53 | period: train
54 |
55 | test_dataset:
56 | target: Utils.Data_utils.real_datasets.CustomDataset
57 | params:
58 | name: etth
59 | proportion: 0.9 # rate
60 | data_root: ./Data/datasets/ETTh.csv
61 | window: 24 # seq_length
62 | save2npy: True
63 | neg_one_to_one: True
64 | seed: 123
65 | period: test
66 | style: separate
67 | distribution: geometric
68 | coefficient: 1.0e-2
69 | step_size: 5.0e-2
70 | sampling_steps: 200
71 |
72 | batch_size: 128
73 | sample_size: 256
74 | shuffle: True
--------------------------------------------------------------------------------
/Config/fmri.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 24
5 | feature_size: 50
6 | n_layer_enc: 4
7 | n_layer_dec: 4
8 | d_model: 96 # 4 X 24
9 | timesteps: 1000
10 | sampling_timesteps: 1000
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.0
16 | resid_pd: 0.0
17 | kernel_size: 5
18 | padding_size: 2
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 15000
23 | results_folder: ./Checkpoints_fmri
24 | gradient_accumulate_every: 2
25 | save_cycle: 1500 # max_epochs // 10
26 | ema:
27 | decay: 0.995
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 3000
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 500
40 | verbose: False
41 |
42 | dataloader:
43 | train_dataset:
44 | target: Utils.Data_utils.real_datasets.fMRIDataset
45 | params:
46 | name: fMRI
47 | proportion: 1.0 # Set to rate < 1 if training conditional generation
48 | data_root: ./Data/datasets/fMRI
49 | window: 24 # seq_length
50 | save2npy: True
51 | neg_one_to_one: True
52 | seed: 123
53 | period: train
54 |
55 | test_dataset:
56 | target: Utils.Data_utils.real_datasets.fMRIDataset
57 | params:
58 | name: fMRI
59 | proportion: 0.9 # rate
60 | data_root: ./Data/datasets/fMRI
61 | window: 24 # seq_length
62 | save2npy: True
63 | neg_one_to_one: True
64 | seed: 123
65 | period: test
66 | style: separate
67 | distribution: geometric
68 | coefficient: 1.0e-2
69 | step_size: 5.0e-2
70 | sampling_steps: 250
71 |
72 | batch_size: 64
73 | sample_size: 256
74 | shuffle: True
--------------------------------------------------------------------------------
/Config/mujoco.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 24
5 | feature_size: 14
6 | n_layer_enc: 3
7 | n_layer_dec: 2
8 | d_model: 64 # 4 X 16
9 | timesteps: 1000
10 | sampling_timesteps: 1000
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.0
16 | resid_pd: 0.0
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 14000
23 | results_folder: ./Checkpoints_mujoco
24 | gradient_accumulate_every: 2
25 | save_cycle: 1400 # max_epochs // 10
26 | ema:
27 | decay: 0.995
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 3000
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 500
40 | verbose: False
41 |
42 | dataloader:
43 | train_dataset:
44 | target: Utils.Data_utils.mujoco_dataset.MuJoCoDataset
45 | params:
46 | num: 10000
47 | dim: 14
48 | window: 24 # seq_length
49 | save2npy: True
50 | neg_one_to_one: True
51 | seed: 123
52 | period: train
53 |
54 | test_dataset:
55 | target: Utils.Data_utils.mujoco_dataset.MuJoCoDataset
56 | params:
57 | num: 1000
58 | dim: 14
59 | window: 24 # seq_length
60 | save2npy: True
61 | neg_one_to_one: True
62 | seed: 123
63 | style: separate
64 | period: test
65 | distribution: geometric
66 | coefficient: 1.0e-2
67 | step_size: 5.0e-2
68 | sampling_steps: 250
69 |
70 | batch_size: 128
71 | sample_size: 256
72 | shuffle: True
--------------------------------------------------------------------------------
/Config/mujoco_sssd.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 100
5 | feature_size: 14
6 | n_layer_enc: 3
7 | n_layer_dec: 3
8 | d_model: 64 # 4 X 16
9 | timesteps: 500
10 | sampling_timesteps: 500
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0
16 | resid_pd: 0
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 12000
23 | results_folder: ./Checkpoints_mujoco_sssd
24 | gradient_accumulate_every: 2
25 | save_cycle: 1200 # max_epochs // 10
26 | ema:
27 | decay: 0.995
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 3000
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 500
40 | verbose: False
41 |
--------------------------------------------------------------------------------
/Config/sines.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 24
5 | feature_size: 5
6 | n_layer_enc: 1
7 | n_layer_dec: 2
8 | d_model: 64 # 4 X 16
9 | timesteps: 500
10 | sampling_timesteps: 500
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.0
16 | resid_pd: 0.0
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 12000
23 | results_folder: ./Checkpoints_sine
24 | gradient_accumulate_every: 2
25 | save_cycle: 1200 # max_epochs // 10
26 | ema:
27 | decay: 0.995
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 3000
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 500
40 | verbose: False
41 |
42 | dataloader:
43 | train_dataset:
44 | target: Utils.Data_utils.sine_dataset.SineDataset
45 | params:
46 | num: 10000
47 | dim: 5
48 | window: 24 # seq_length
49 | save2npy: True
50 | neg_one_to_one: True
51 | seed: 123
52 | period: train
53 |
54 | test_dataset:
55 | target: Utils.Data_utils.sine_dataset.SineDataset
56 | params:
57 | num: 1000
58 | dim: 5
59 | window: 24 # seq_length
60 | save2npy: True
61 | neg_one_to_one: True
62 | seed: 123
63 | style: separate
64 | period: test
65 | distribution: geometric
66 | coefficient: 1.0e-2
67 | step_size: 5.0e-2
68 | sampling_steps: 200
69 |
70 | batch_size: 128
71 | sample_size: 256
72 | shuffle: True
--------------------------------------------------------------------------------
/Config/solar.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 192
5 | feature_size: 128
6 | n_layer_enc: 4
7 | n_layer_dec: 4
8 | d_model: 96 # 4 X 24
9 | timesteps: 500
10 | sampling_timesteps: 500
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.0
16 | resid_pd: 0.0
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 1500
23 | results_folder: ./Checkpoints_solar
24 | gradient_accumulate_every: 2
25 | save_cycle: 150 # max_epochs // 10
26 | ema:
27 | decay: 0.995
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 300
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 100
40 | verbose: False
41 |
--------------------------------------------------------------------------------
/Config/solar_update.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 192
5 | feature_size: 137
6 | n_layer_enc: 4
7 | n_layer_dec: 4
8 | d_model: 96 # 4 X 24
9 | timesteps: 500
10 | sampling_timesteps: 500
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.5
16 | resid_pd: 0.5
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 1000
23 | results_folder: ./Checkpoints_solar_nips
24 | gradient_accumulate_every: 2
25 | save_cycle: 100 # max_epochs // 10
26 | ema:
27 | decay: 0.9
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 300
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 100
40 | verbose: False
41 |
--------------------------------------------------------------------------------
/Config/stocks.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS
3 | params:
4 | seq_length: 24
5 | feature_size: 6
6 | n_layer_enc: 2
7 | n_layer_dec: 2
8 | d_model: 64 # 4 X 16
9 | timesteps: 500
10 | sampling_timesteps: 500
11 | loss_type: 'l1'
12 | beta_schedule: 'cosine'
13 | n_heads: 4
14 | mlp_hidden_times: 4
15 | attn_pd: 0.0
16 | resid_pd: 0.0
17 | kernel_size: 1
18 | padding_size: 0
19 |
20 | solver:
21 | base_lr: 1.0e-5
22 | max_epochs: 10000
23 | results_folder: ./Checkpoints_stock
24 | gradient_accumulate_every: 2
25 | save_cycle: 1000 # max_epochs // 10
26 | ema:
27 | decay: 0.995
28 | update_interval: 10
29 |
30 | scheduler:
31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup
32 | params:
33 | factor: 0.5
34 | patience: 2000
35 | min_lr: 1.0e-5
36 | threshold: 1.0e-1
37 | threshold_mode: rel
38 | warmup_lr: 8.0e-4
39 | warmup: 500
40 | verbose: False
41 |
42 | dataloader:
43 | train_dataset:
44 | target: Utils.Data_utils.real_datasets.CustomDataset
45 | params:
46 | name: stock
47 | proportion: 1.0 # Set to rate < 1 if training conditional generation
48 | data_root: ./Data/datasets/stock_data.csv
49 | window: 24 # seq_length
50 | save2npy: True
51 | neg_one_to_one: True
52 | seed: 123
53 | period: train
54 |
55 | test_dataset:
56 | target: Utils.Data_utils.real_datasets.CustomDataset
57 | params:
58 | name: stock
59 | proportion: 0.9 # rate
60 | data_root: ./Data/datasets/stock_data.csv
61 | window: 24 # seq_length
62 | save2npy: True
63 | neg_one_to_one: True
64 | seed: 123
65 | period: test
66 | style: separate
67 | distribution: geometric
68 | coefficient: 1.0e-2
69 | step_size: 5.0e-2
70 | sampling_steps: 200
71 |
72 | batch_size: 64
73 | sample_size: 256
74 | shuffle: True
--------------------------------------------------------------------------------
/Data/Place the dataset here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Y-debug-sys/Diffusion-TS/566307e6cf2d8095e58de4c6e3a6ae965b69b5b5/Data/Place the dataset here
--------------------------------------------------------------------------------
/Data/build_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from Utils.io_utils import instantiate_from_config
3 |
4 |
def build_dataloader(config, args=None):
    """Build the training dataloader described by ``config['dataloader']``.

    Args:
        config: Parsed YAML configuration dict; must contain a ``dataloader``
            section with ``batch_size``, ``shuffle`` and ``train_dataset``.
        args: Optional namespace with a ``save_dir`` attribute. When given,
            ``save_dir`` is injected into the dataset params as
            ``output_dir``. The declared default is ``None``, so the
            attribute access is guarded (the previous code crashed with
            ``AttributeError`` when the default was used).

    Returns:
        dict with keys ``'dataloader'`` (a ``torch.utils.data.DataLoader``)
        and ``'dataset'`` (the instantiated train dataset).
    """
    dl_cfg = config['dataloader']
    batch_size = dl_cfg['batch_size']
    shuffle = dl_cfg['shuffle']

    # Only forward the output directory when a namespace was provided,
    # honouring the args=None default in the signature.
    if args is not None:
        dl_cfg['train_dataset']['params']['output_dir'] = args.save_dir
    dataset = instantiate_from_config(dl_cfg['train_dataset'])

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=True,
        sampler=None,
        # NOTE(review): drop_last is deliberately tied to the shuffle flag
        # (ragged final batch dropped only for shuffled training) — confirm.
        drop_last=shuffle,
    )

    return {
        'dataloader': dataloader,
        'dataset': dataset,
    }
25 |
def build_dataloader_cond(config, args=None):
    """Build the conditional-sampling (test) dataloader from ``config``.

    Args:
        config: Parsed YAML configuration dict; must contain a ``dataloader``
            section with ``sample_size`` and ``test_dataset``.
        args: Optional namespace providing ``save_dir`` and ``mode``. For
            ``mode == 'infill'`` the namespace must also carry
            ``missing_ratio``; for ``mode == 'predict'`` it must carry
            ``pred_len``. The declared default is ``None``, so all attribute
            accesses are guarded (the previous code crashed with
            ``AttributeError`` when the default was used).

    Returns:
        dict with keys ``'dataloader'`` (a non-shuffled
        ``torch.utils.data.DataLoader``) and ``'dataset'`` (the instantiated
        test dataset).
    """
    dl_cfg = config['dataloader']
    batch_size = dl_cfg['sample_size']
    test_cfg = dl_cfg['test_dataset']

    # Guard the optional namespace, matching the args=None default.
    if args is not None:
        test_cfg['params']['output_dir'] = args.save_dir
        if args.mode == 'infill':
            test_cfg['params']['missing_ratio'] = args.missing_ratio
        elif args.mode == 'predict':
            test_cfg['params']['predict_length'] = args.pred_len
    test_dataset = instantiate_from_config(test_cfg)

    dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,       # evaluation order must be deterministic
        num_workers=0,
        pin_memory=True,
        sampler=None,
        drop_last=False,     # keep every test sample
    )

    return {
        'dataloader': dataloader,
        'dataset': test_dataset,
    }
49 |
50 |
51 | if __name__ == '__main__':
52 | pass
53 |
54 |
--------------------------------------------------------------------------------
/Data/readme.md:
--------------------------------------------------------------------------------
1 | ## 🚄 Get Started
2 |
3 | Please download **dataset.zip**, then unzip and copy it to the location indicated by '`Place the dataset here`'.
4 |
5 | > 🔔 **dataset.zip** can be downloaded [here](https://drive.google.com/file/d/11DI22zKWtHjXMnNGPWNUbyGz-JiEtZy6/view?usp=sharing).
6 |
7 | ## 🔨 Build Dataloader
8 |
9 | ### 🕶️ Pipeline Overview
10 |
11 | The figure below shows everything that happens after calling *build_dataloader*
12 |
13 |
14 |

15 |
16 |
17 | Below we show the specific meaning of the saved file:
18 |
19 | - `{name}_norm_truth_{length}_train.npy` - The saved data is used for training model, which has been normalized into [0, 1].
20 | - `{name}_norm_truth_{length}_test.npy` - The saved data is used for inference, which has been normalized into [0, 1].
21 | - `{name}_ground_truth_{length}_train.npy` - The saved data is used for training model, however, it is raw data that has not been normalized.
22 | - `{name}_ground_truth_{length}_test.npy` - The saved data is used for inference, however, it is raw data that has not been normalized.
23 | - `{name}_masking_{length}.npy` - The saved mask-sequences indicate the generation target of imputation or forecasting.
24 |
25 | > 🔔 Note that the generated time series `ddpm_fake_{name}.npy` or `ddpm_{mode}_{name}.npy` are also normalized into [0, 1]. You can restore them by adding following codes in **main.py**:
26 | ```
27 | line 86: samples = dataset.scaler.inverse_transform(samples.reshape(-1, samples.shape[-1])).reshape(samples.shape)
28 | line 93: samples = dataset.scaler.inverse_transform(samples.reshape(-1, samples.shape[-1])).reshape(samples.shape)
29 | ```
30 |
31 | ### 📝 Custom Dataset
32 |
33 | Real-world sequences (or any self-prepared sequences) need to be configured via the following:
34 |
35 | * Create and check the settings in your **.yaml file** and modify it such as *seq_length* and *feature_size* etc. if necessary.
36 | * Convert the real-world time series into **.csv file**, then put it to the repo like our template datasets.
37 | * Make sure that non-numeric rows and columns are not included in the training data, or you may need to modify codes in **./Utils/Data_utils/real_datasets.py**.
38 | - Remove the header if it exists:
39 | ```
40 | line 132: df = pd.read_csv(filepath, header=0)
41 | # set `header=None` if it does not exist.
42 | ```
43 | - Delete rows and columns, here using the first column as an example:
44 | ```
45 | line 133: df.drop(df.columns[0], axis=1, inplace=True)
46 | ```
47 |
48 | > 🔔 Please set `use_ff=False` at line 54 in **./Models/interpretable_diffusion/gaussian_diffusion.py** if your temporal data is highly irregular.
--------------------------------------------------------------------------------
/Experiments/metric_pytorch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Context-FID Score Presentation\n",
8 | "## Necessary packages and functions call\n",
9 | "\n",
10 |     "- Context-FID score: A useful metric measures how well the synthetic time series windows ”fit” into the local context of the time series"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import os\n",
20 | "import torch\n",
21 | "import numpy as np\n",
22 | "import sys\n",
23 | "sys.path.append(os.path.join(os.path.dirname('__file__'), '../'))\n",
24 | "from Utils.context_fid import Context_FID\n",
25 | "from Utils.metric_utils import display_scores\n",
26 | "from Utils.cross_correlation import CrossCorrelLoss"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "## Data Loading\n",
34 | "\n",
35 | "Load original dataset and preprocess the loaded data."
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 2,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "iterations = 5\n",
45 | "ori_data = np.load('../toy_exp/samples/sine_ground_truth_24_train.npy')\n",
46 | "# ori_data = np.load('../OUTPUT/{dataset_name}/samples/{dataset_name}_norm_truth_{seq_length}_train.npy') # Uncomment the line if dataset other than Sine is used.\n",
47 | "fake_data = np.load('../toy_exp/ddpm_fake_sines.npy')"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "## Context-FID Score\n",
55 | "\n",
56 | "- The Frechet Inception distance-like score is based on unsupervised time series embeddings. It is able to score the fit of the fixed length synthetic samples into their context of (often much longer) true time series.\n",
57 | "\n",
58 | "- The lowest scoring models correspond to the best performing models in downstream tasks"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 3,
64 | "metadata": {},
65 | "outputs": [
66 | {
67 | "name": "stdout",
68 | "output_type": "stream",
69 | "text": [
70 | "Iter 0: context-fid = 0.007570160590888924 \n",
71 | "\n",
72 | "Iter 1: context-fid = 0.008166182104461794 \n",
73 | "\n",
74 | "Iter 2: context-fid = 0.007970870497297756 \n",
75 | "\n",
76 | "Iter 3: context-fid = 0.007697393798102688 \n",
77 | "\n",
78 | "Iter 4: context-fid = 0.008525671090148759 \n",
79 | "\n",
80 | "Final Score: 0.007986055616179984 ± 0.00047287486643304236\n"
81 | ]
82 | }
83 | ],
84 | "source": [
85 | "context_fid_score = []\n",
86 | "\n",
87 | "for i in range(iterations):\n",
88 | " context_fid = Context_FID(ori_data[:], fake_data[:ori_data.shape[0]])\n",
89 | " context_fid_score.append(context_fid)\n",
90 | " print(f'Iter {i}: ', 'context-fid =', context_fid, '\\n')\n",
91 | " \n",
92 | "display_scores(context_fid_score)"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "## Correlational Score\n",
100 | "\n",
101 | "- The metric uses the absolute error of the auto-correlation estimator by real data and synthetic data as the metric to assess the temporal dependency.\n",
102 | "\n",
103 | "- For d > 1, it uses the l1-norm of the difference between cross correlation matrices."
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 4,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "def random_choice(size, num_select=100):\n",
113 | " select_idx = np.random.randint(low=0, high=size, size=(num_select,))\n",
114 | " return select_idx"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 5,
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "Iter 0: cross-correlation = 0.019852216581699826 \n",
127 | "\n",
128 | "Iter 1: cross-correlation = 0.01816951370252664 \n",
129 | "\n",
130 | "Iter 2: cross-correlation = 0.022373661672091448 \n",
131 | "\n",
132 | "Iter 3: cross-correlation = 0.012407943886992933 \n",
133 | "\n",
134 | "Iter 4: cross-correlation = 0.010309792931556355 \n",
135 | "\n",
136 | "Final Score: 0.01662262575497344 ± 0.006316425881906014\n"
137 | ]
138 | }
139 | ],
140 | "source": [
141 | "x_real = torch.from_numpy(ori_data)\n",
142 | "x_fake = torch.from_numpy(fake_data)\n",
143 | "\n",
144 | "correlational_score = []\n",
145 | "size = int(x_real.shape[0] / iterations)\n",
146 | "\n",
147 | "for i in range(iterations):\n",
148 | " real_idx = random_choice(x_real.shape[0], size)\n",
149 | " fake_idx = random_choice(x_fake.shape[0], size)\n",
150 | " corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')\n",
151 | " loss = corr.compute(x_fake[fake_idx, :, :])\n",
152 | " correlational_score.append(loss.item())\n",
153 | " print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\\n')\n",
154 | "\n",
155 | "display_scores(correlational_score)"
156 | ]
157 | }
158 | ],
159 | "metadata": {
160 | "kernelspec": {
161 | "display_name": "PT3.8",
162 | "language": "python",
163 | "name": "pt3.8"
164 | },
165 | "language_info": {
166 | "codemirror_mode": {
167 | "name": "ipython",
168 | "version": 3
169 | },
170 | "file_extension": ".py",
171 | "mimetype": "text/x-python",
172 | "name": "python",
173 | "nbconvert_exporter": "python",
174 | "pygments_lexer": "ipython3",
175 | "version": "3.8.13"
176 | }
177 | },
178 | "nbformat": 4,
179 | "nbformat_minor": 2
180 | }
181 |
--------------------------------------------------------------------------------
/Experiments/metric_tensorflow.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Metric Presentation and Visualization\n",
8 | "## Necessary packages and functions call\n",
9 | "\n",
10 | "- DDPM-TS: Interpretable Diffusion for Time Series Generation\n",
11 | "- Metrics: \n",
12 | " - discriminative_metrics\n",
13 | " - predictive_metrics\n",
14 | " - visualization"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 4,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import os\n",
24 | "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
25 | "\n",
26 | "import sys\n",
27 | "sys.path.append(os.path.join(os.path.dirname('__file__'), '../'))\n",
28 | "\n",
29 | "import warnings\n",
30 | "warnings.filterwarnings(\"ignore\")\n",
31 | "\n",
32 | "import numpy as np\n",
33 | "import tensorflow as tf\n",
34 | "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
35 | "if gpus:\n",
36 | " try:\n",
37 | " for gpu in gpus:\n",
38 | " tf.config.experimental.set_memory_growth(gpu, True)\n",
39 | " except RuntimeError as e:\n",
40 | " print(e)\n",
41 | "\n",
42 | "from Utils.metric_utils import display_scores\n",
43 | "from Utils.discriminative_metric import discriminative_score_metrics\n",
44 | "from Utils.predictive_metric import predictive_score_metrics"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "## Data Loading\n",
52 | "\n",
53 | "Load original dataset and preprocess the loaded data."
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 5,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "iterations = 5\n",
63 | "ori_data = np.load('../toy_exp/samples/sine_ground_truth_24_train.npy')\n",
64 | "# ori_data = np.load('../OUTPUT/{dataset_name}/samples/{dataset_name}_norm_truth_{seq_length}_train.npy') # Uncomment the line if dataset other than Sine is used.\n",
65 | "fake_data = np.load('../toy_exp/ddpm_fake_sines.npy')"
66 | ]
67 | },
68 | {
69 | "cell_type": "markdown",
70 | "metadata": {},
71 | "source": [
72 | "## Evaluate the generated data\n",
73 | "\n",
74 | "### 1. Discriminative score\n",
75 | "\n",
76 | "To evaluate the classification accuracy between original and synthetic data using post-hoc RNN network. The output is | classification accuracy - 0.5 |.\n",
77 | "\n",
78 | "- metric_iteration: the number of iterations for metric computation."
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 6,
84 | "metadata": {
85 | "scrolled": false
86 | },
87 | "outputs": [
88 | {
89 | "name": "stderr",
90 | "output_type": "stream",
91 | "text": [
92 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:48<00:00, 8.77it/s]\n"
93 | ]
94 | },
95 | {
96 | "name": "stdout",
97 | "output_type": "stream",
98 | "text": [
99 | "Iter 0: 0.00649999999999995 , 0.5825 , 0.4305 \n",
100 | "\n"
101 | ]
102 | },
103 | {
104 | "name": "stderr",
105 | "output_type": "stream",
106 | "text": [
107 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:48<00:00, 8.75it/s]\n"
108 | ]
109 | },
110 | {
111 | "name": "stdout",
112 | "output_type": "stream",
113 | "text": [
114 | "Iter 1: 0.0034999999999999476 , 0.4425 , 0.5645 \n",
115 | "\n"
116 | ]
117 | },
118 | {
119 | "name": "stderr",
120 | "output_type": "stream",
121 | "text": [
122 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:32<00:00, 9.39it/s]\n"
123 | ]
124 | },
125 | {
126 | "name": "stdout",
127 | "output_type": "stream",
128 | "text": [
129 | "Iter 2: 0.0007500000000000284 , 0.46 , 0.5415 \n",
130 | "\n"
131 | ]
132 | },
133 | {
134 | "name": "stderr",
135 | "output_type": "stream",
136 | "text": [
137 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:46<00:00, 8.82it/s]\n"
138 | ]
139 | },
140 | {
141 | "name": "stdout",
142 | "output_type": "stream",
143 | "text": [
144 | "Iter 3: 0.02200000000000002 , 0.535 , 0.509 \n",
145 | "\n"
146 | ]
147 | },
148 | {
149 | "name": "stderr",
150 | "output_type": "stream",
151 | "text": [
152 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:48<00:00, 8.77it/s]\n"
153 | ]
154 | },
155 | {
156 | "name": "stdout",
157 | "output_type": "stream",
158 | "text": [
159 | "Iter 4: 0.007249999999999979 , 0.5365 , 0.478 \n",
160 | "\n",
161 | "sine:\n",
162 | "Final Score: 0.007999999999999985 ± 0.010231963047355045\n",
163 | "\n"
164 | ]
165 | }
166 | ],
167 | "source": [
168 | "discriminative_score = []\n",
169 | "\n",
170 | "for i in range(iterations):\n",
171 | " temp_disc, fake_acc, real_acc = discriminative_score_metrics(ori_data[:], fake_data[:ori_data.shape[0]])\n",
172 | " discriminative_score.append(temp_disc)\n",
173 | " print(f'Iter {i}: ', temp_disc, ',', fake_acc, ',', real_acc, '\\n')\n",
174 | " \n",
175 | "print('sine:')\n",
176 | "display_scores(discriminative_score)\n",
177 | "print()"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "## Evaluate the generated data\n",
185 | "\n",
186 | "### 2. Predictive score\n",
187 | "\n",
188 |     "To evaluate the prediction performance under the train-on-synthetic, test-on-real setting. More specifically, we use a post-hoc RNN architecture to predict one step ahead and report the performance in terms of MAE. \n",
189 | "\n",
190 | "The model learns to predict the last dimension with one more step."
191 | ]
192 | },
193 | {
194 | "cell_type": "code",
195 | "execution_count": 7,
196 | "metadata": {
197 | "scrolled": false
198 | },
199 | "outputs": [
200 | {
201 | "name": "stderr",
202 | "output_type": "stream",
203 | "text": [
204 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [04:59<00:00, 16.67it/s]\n"
205 | ]
206 | },
207 | {
208 | "name": "stdout",
209 | "output_type": "stream",
210 | "text": [
211 | "0 epoch: 0.09314072894950631 \n",
212 | "\n"
213 | ]
214 | },
215 | {
216 | "name": "stderr",
217 | "output_type": "stream",
218 | "text": [
219 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [04:52<00:00, 17.08it/s]\n"
220 | ]
221 | },
222 | {
223 | "name": "stdout",
224 | "output_type": "stream",
225 | "text": [
226 | "1 epoch: 0.09352861018187178 \n",
227 | "\n"
228 | ]
229 | },
230 | {
231 | "name": "stderr",
232 | "output_type": "stream",
233 | "text": [
234 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [05:04<00:00, 16.41it/s]\n"
235 | ]
236 | },
237 | {
238 | "name": "stdout",
239 | "output_type": "stream",
240 | "text": [
241 | "2 epoch: 0.09322281220011404 \n",
242 | "\n"
243 | ]
244 | },
245 | {
246 | "name": "stderr",
247 | "output_type": "stream",
248 | "text": [
249 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [04:56<00:00, 16.85it/s]\n"
250 | ]
251 | },
252 | {
253 | "name": "stdout",
254 | "output_type": "stream",
255 | "text": [
256 | "3 epoch: 0.09313142589665173 \n",
257 | "\n"
258 | ]
259 | },
260 | {
261 | "name": "stderr",
262 | "output_type": "stream",
263 | "text": [
264 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [05:01<00:00, 16.57it/s]\n"
265 | ]
266 | },
267 | {
268 | "name": "stdout",
269 | "output_type": "stream",
270 | "text": [
271 | "4 epoch: 0.09342775384532898 \n",
272 | "\n",
273 | "sine:\n",
274 | "Final Score: 0.09329026621469458 ± 0.00022198749029314803\n",
275 | "\n"
276 | ]
277 | }
278 | ],
279 | "source": [
280 | "predictive_score = []\n",
281 | "for i in range(iterations):\n",
282 | " temp_pred = predictive_score_metrics(ori_data, fake_data[:ori_data.shape[0]])\n",
283 | " predictive_score.append(temp_pred)\n",
284 | " print(i, ' epoch: ', temp_pred, '\\n')\n",
285 | " \n",
286 | "print('sine:')\n",
287 | "display_scores(predictive_score)\n",
288 | "print()"
289 | ]
290 | }
291 | ],
292 | "metadata": {
293 | "kernelspec": {
294 | "display_name": "Tensorflow_3.8",
295 | "language": "python",
296 | "name": "python3"
297 | },
298 | "language_info": {
299 | "codemirror_mode": {
300 | "name": "ipython",
301 | "version": 3
302 | },
303 | "file_extension": ".py",
304 | "mimetype": "text/x-python",
305 | "name": "python",
306 | "nbconvert_exporter": "python",
307 | "pygments_lexer": "ipython3",
308 | "version": "3.8.13"
309 | }
310 | },
311 | "nbformat": 4,
312 | "nbformat_minor": 2
313 | }
314 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 XXX
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Models/interpretable_diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from torch import nn
6 | from Models.interpretable_diffusion.model_utils import LearnablePositionalEncoding, Conv_MLP,\
7 | AdaLayerNorm, GELU2
8 |
9 |
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that runs in float32 and casts the result back to the input dtype."""

    def forward(self, x):
        # Normalize in fp32 for numerical stability, then restore caller's dtype.
        out = super().forward(x.float())
        return out.type(x.dtype)
13 |
14 |
def normalization(channels):
    """
    Build the default normalization layer used throughout this model.

    :param channels: number of input channels.
    :return: a float32-stable ``GroupNorm32`` with 8 groups.
    """
    return GroupNorm32(8, channels)
23 |
24 |
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: spatial dimensionality of the convolution (1, 2, or 3).
    :raises ValueError: for any other value of ``dims``.
    """
    conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    try:
        return conv_classes[dims](*args, **kwargs)
    except KeyError:
        raise ValueError(f"unsupported dimensions: {dims}") from None
36 |
37 |
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order:
    the packed tensor is split into Q, K, V first, then into heads.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads  # channel width must be divisible by 3 * n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Split into Q, K, V (each [N x (H*C) x T]); heads are folded into
        # the batch dimension by the .view() calls below.
        q, k, v = qkv.chunk(3, dim=1)
        # Scale q and k each by ch**-0.25 so the product is scaled by 1/sqrt(ch).
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        ) # More stable with f16 than dividing afterwards
        # Softmax in fp32, then cast back (same mixed-precision trick as above).
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)
67 |
68 |
69 |
class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Pools a feature map to one vector per sample: flattens the trailing dims,
    prepends the mean as an extra token, runs one QKV attention pass, and
    returns the projected output at the mean-token position.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        # 1x1 conv producing concatenated Q, K, V for every token.
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        # Final projection; output_dim defaults to embed_dim.
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        # Flatten all trailing (spatial) dims into one token axis.
        x = x.reshape(b, c, -1)  # NC(HW)
        # Prepend the mean token; its attention output is the pooled vector.
        x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # Keep only the mean-token position.
        return x[:, :, 0]
95 |
96 |
class FullAttention(nn.Module):
    """Multi-head self-attention that also returns the head-averaged attention map."""

    def __init__(self,
                 n_embd, # the embed dim
                 n_head, # the number of heads
                 attn_pdrop=0.1, # attention dropout prob
                 resid_pdrop=0.1, # residual attention dropout prob
                 ):
        super().__init__()
        assert n_embd % n_head == 0
        # Shared Linear projections for all heads (split per-head in forward).
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, mask=None):
        batch, seq_len, dim = x.size()
        head_dim = dim // self.n_head

        def split_heads(proj):
            # (B, T, C) -> (B, nh, T, hs)
            return proj(x).view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key)
        q = split_heads(self.query)
        v = split_heads(self.value)

        # Scaled dot-product scores, (B, nh, T, T). `mask` is accepted for
        # interface compatibility but is not applied here.
        scores = torch.matmul(q, k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        attn = self.attn_drop(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back: (B, T, C)
        y = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch, seq_len, dim)
        attn_avg = attn.mean(dim=1, keepdim=False)  # (B, T, T), averaged over heads

        # output projection
        return self.resid_drop(self.proj(y)), attn_avg
134 |
135 |
class EncoderBlock(nn.Module):
    """ an unassuming Transformer block: AdaLayerNorm -> attention, then
    LayerNorm -> MLP, each wrapped in a residual connection """
    def __init__(self,
                 n_embd=1024,
                 n_head=16,
                 attn_pdrop=0.1,
                 resid_pdrop=0.1,
                 mlp_hidden_times=4,
                 activate='GELU'
                 ):
        super().__init__()

        # Pre-attention norm is adaptive: conditioned on the timestep
        # embedding (and optionally a label embedding).
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = FullAttention(
                n_embd=n_embd,
                n_head=n_head,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                )

        assert activate in ['GELU', 'GELU2']
        act = nn.GELU() if activate == 'GELU' else GELU2()

        # Position-wise feed-forward with `mlp_hidden_times` expansion factor.
        self.mlp = nn.Sequential(
                nn.Linear(n_embd, mlp_hidden_times * n_embd),
                act,
                nn.Linear(mlp_hidden_times * n_embd, n_embd),
                nn.Dropout(resid_pdrop),
            )

    def forward(self, x, timestep, mask=None, label_emb=None):
        # Attention sub-layer; `att` is the head-averaged attention map.
        a, att = self.attn(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + a
        x = x + self.mlp(self.ln2(x)) # only one really use encoder_output
        return x, att
172 |
173 |
class Encoder(nn.Module):
    """
    Stack of ``EncoderBlock``s applied in order with shared timestep/label
    conditioning.

    :param n_layer: number of transformer blocks.
    :param n_embd: embedding dimension passed to each block.
    :param n_head: number of attention heads per block.
    """
    def __init__(
        self,
        n_layer=14,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.,
        resid_pdrop=0.,
        mlp_hidden_times=4,
        block_activate='GELU',
    ):
        super().__init__()

        # ModuleList instead of Sequential: the blocks need extra arguments
        # (t, mask, label_emb), so Sequential.forward could never be used.
        # Children are still indexed numerically, so checkpoint keys match.
        self.blocks = nn.ModuleList([EncoderBlock(
                n_embd=n_embd,
                n_head=n_head,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                mlp_hidden_times=mlp_hidden_times,
                activate=block_activate,
        ) for _ in range(n_layer)])

    def forward(self, input, t, padding_masks=None, label_emb=None):
        """Run every block in order; per-block attention maps are discarded."""
        x = input
        for block in self.blocks:
            x, _ = block(x, t, mask=padding_masks, label_emb=label_emb)
        return x
201 |
202 |
class Classifier(nn.Module):
    """
    Transformer-based classifier for (noised) time series.

    Embeds the input with a Conv_MLP, adds learnable positional encodings,
    runs a timestep-conditioned Encoder stack, then pools the sequence with
    GroupNorm + SiLU + AttentionPool2d (which treats the time axis as the
    channel axis) to produce class logits.

    :param feature_size: number of features per time step.
    :param seq_length: number of time steps; fed as the channel count to the
        pooling head, so it must be compatible with GroupNorm's 8 groups and
        with ``num_head_channels`` — TODO confirm against configs.
    :param num_classes: number of output classes.
    """
    def __init__(
        self,
        feature_size,
        seq_length,
        num_classes=2,
        n_layer_enc=5,
        n_embd=1024,
        n_heads=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate='GELU',
        max_len=2048,
        num_head_channels=8,
        **kwargs
    ):
        super().__init__()
        self.emb = Conv_MLP(feature_size, n_embd, resid_pdrop=resid_pdrop)
        self.encoder = Encoder(n_layer_enc, n_embd, n_heads, attn_pdrop, resid_pdrop, mlp_hidden_times, block_activate)
        self.pos_enc = LearnablePositionalEncoding(n_embd, dropout=resid_pdrop, max_len=max_len)

        # The pooling head requires an explicit head-channel count.
        assert num_head_channels != -1
        self.out = nn.Sequential(
            normalization(seq_length),
            nn.SiLU(),
            AttentionPool2d(
                seq_length, num_head_channels, num_classes
            ),
        )

    def forward(self, input, t, padding_masks=None):
        """Return class logits (batch, num_classes) for input (batch, seq, features)."""
        emb = self.emb(input)
        inp_enc = self.pos_enc(emb)
        output = self.encoder(inp_enc, t, padding_masks=padding_masks)

        return self.out(output)
240 |
241 |
if __name__ == '__main__':
    # Smoke test: push one random batch through a small Classifier on GPU
    # and print parameter counts plus the output shape.
    device = torch.device('cuda:0')
    input = torch.randn(128, 64, 14).to(device)
    t = torch.randint(0, 1000, (128, ), device=device)

    def count_model_parameters(model):
        # Summarize parameter counts and their fp32 memory footprint.
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_size_mb = total_params * 4 / (1024 * 1024)  # assuming float32
        trainable_size_mb = trainable_params * 4 / (1024 * 1024)
        return {'Total Parameters': total_params, 'Trainable Parameters': trainable_params,
                'Total Size (MB)': total_size_mb, 'Trainable Size (MB)': trainable_size_mb}

    model = Classifier(
        feature_size=14,
        n_layer_enc=3,
        seq_length=64,
        num_classes=2,
        n_embd=64,
        n_heads=4,
        attn_pdrop=0.,
        resid_pdrop=0.,
        mlp_hidden_times=4,
        block_activate='GELU',
        max_len=64,
    ).to(device)
    print(count_model_parameters(model))
    output = model(input, t)
    print(output.shape)
--------------------------------------------------------------------------------
/Models/interpretable_diffusion/model_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import scipy
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from torch import nn, einsum
7 | from functools import partial
8 | from einops import rearrange, reduce
9 | from scipy.fftpack import next_fast_len
10 |
11 |
def exists(x):
    """Return True when *x* is not None, False otherwise."""
    return x is not None
23 |
def default(val, d):
    """
    Return *val* unless it is None, otherwise fall back to *d*.

    *d* may be a plain value or a zero-argument callable producing one.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
38 |
def identity(t, *args, **kwargs):
    """Return *t* unchanged; any extra positional or keyword arguments are ignored."""
    return t
52 |
def extract(a, t, x_shape):
    """
    Gather per-timestep coefficients from a schedule tensor.

    :param a: tensor of schedule values indexed along its last dim.
    :param t: integer tensor of indices, shape (batch, ...).
    :param x_shape: shape of the tensor the result must broadcast against.
    :return: gathered values reshaped to (batch, 1, ..., 1) with
        len(x_shape) - 1 trailing singleton dims.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # Trailing singleton dims make the result broadcast against x_shape.
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
67 |
def cond_fn(x, t, classifier=None, y=None, classifier_scale=1.):
    """
    Classifier-guidance gradient: d/dx of the log-probability the classifier
    assigns to the target labels, scaled by *classifier_scale*.

    :param x: input tensor for which gradients are computed.
    :param t: timestep tensor forwarded to the classifier.
    :param classifier: module mapping (x, t) to class logits.
    :param y: target labels; must not be None.
    :param classifier_scale: multiplier applied to the gradient.
    :return: gradient tensor with the same shape as *x*.
    """
    assert y is not None
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        log_probs = F.log_softmax(classifier(x_in, t), dim=-1)
        # Per-sample log-probability of its own target class.
        selected = log_probs[range(len(log_probs)), y.view(-1)]
        grad, = torch.autograd.grad(selected.sum(), x_in)
        return grad * classifier_scale
89 |
90 | # normalization functions
91 |
def normalize_to_neg_one_to_one(x):
    """Map values from [0, 1] to [-1, 1]."""
    return 2 * x - 1
94 |
def unnormalize_to_zero_to_one(x):
    """Map values from [-1, 1] back to [0, 1]."""
    return 0.5 * (x + 1)
97 |
98 |
99 | # sinusoidal positional embeds
100 |
class SinusoidalPosEmb(nn.Module):
    """
    Sinusoidal positional embedding.

    Maps a batch of scalar positions to `dim`-dimensional embeddings built
    from sine and cosine waves on a geometric frequency ladder.

    Attributes:
        dim (int): output embedding dimension.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(-step * torch.arange(half_dim, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
123 |
124 |
125 | # learnable positional embeds
126 |
class LearnablePositionalEncoding(nn.Module):
    """
    Positional encoding whose per-position vectors are trained parameters.

    Attributes:
        d_model (int): embedding dimension.
        dropout (float): dropout rate applied after adding the encodings.
        max_len (int): maximum (and expected) input sequence length.
    """
    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super(LearnablePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # One trainable vector per position; positions are always 0..max_len-1
        # in order, so no index lookup is needed.
        self.pe = nn.Parameter(torch.empty(1, max_len, d_model))
        nn.init.uniform_(self.pe, -0.02, 0.02)

    def forward(self, x):
        r"""Add the positional table to x and apply dropout.

        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]
        """
        return self.dropout(x + self.pe)
158 |
159 |
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.
    Edge-replication padding keeps the output length equal to the input length
    when stride is 1.
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Replicate edge time steps; total padding = kernel_size - 1, with the
        # extra element (for even kernels) going to the front.
        tail_pad = (self.kernel_size - 1) // 2
        head_pad = self.kernel_size - 1 - tail_pad
        head = x[:, 0:1, :].repeat(1, head_pad, 1)
        tail = x[:, -1:, :].repeat(1, tail_pad, 1)
        padded = torch.cat([head, x, tail], dim=1)
        # AvgPool1d pools over the last dim, so swap time/channel around it.
        return self.avg(padded.permute(0, 2, 1)).permute(0, 2, 1)
177 |
178 |
class series_decomp(nn.Module):
    """
    Series decomposition block: splits a series into residual and trend,
    where trend = moving average and residual = input - trend.
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        residual = x - trend
        return residual, trend
191 |
192 |
class series_decomp_multi(nn.Module):
    """
    Series decomposition with several moving-average kernels.

    Each kernel produces a trend candidate; a learned per-element softmax
    over the kernels mixes the candidates into the final trend.

    :param kernel_size: iterable of moving-average kernel sizes.
    """
    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        # nn.ModuleList (was a plain Python list): sub-modules must be
        # registered so they follow .to(device) and appear in state_dict.
        self.moving_avg = nn.ModuleList([moving_avg(kernel, stride=1) for kernel in kernel_size])
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        # Per-kernel trends stacked along a new last dim: (..., n_kernels).
        trends = torch.cat([avg(x).unsqueeze(-1) for avg in self.moving_avg], dim=-1)
        # Softmax mixing weights per element over the kernel dimension.
        weights = nn.Softmax(-1)(self.layer(x.unsqueeze(-1)))
        moving_mean = torch.sum(trends * weights, dim=-1)
        res = x - moving_mean
        return res, moving_mean
211 |
212 |
class Transpose(nn.Module):
    """Wrapper around torch.transpose() so it can be used inside nn.Sequential."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape  # pair of dims to swap

    def forward(self, x):
        dim0, dim1 = self.shape
        return x.transpose(dim0, dim1)
221 |
222 |
class Conv_MLP(nn.Module):
    """Feature embedder: kernel-3 Conv1d over time mapping in_dim -> out_dim channels."""

    def __init__(self, in_dim, out_dim, resid_pdrop=0.):
        super().__init__()
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_dim, out_dim, 3, stride=1, padding=1),
            nn.Dropout(p=resid_pdrop),
        )

    def forward(self, x):
        # (B, T, C_in) -> conv over the channel dim -> back to (B, T, C_out).
        out = self.sequential(x)
        return out.transpose(1, 2)
234 |
235 |
class Transformer_MLP(nn.Module):
    """Convolutional feed-forward block: expand to mlp_hidden_times * n_embd, then project back."""

    def __init__(self, n_embd, mlp_hidden_times, act, resid_pdrop):
        super().__init__()
        hidden = int(mlp_hidden_times * n_embd)
        self.sequential = nn.Sequential(
            nn.Conv1d(in_channels=n_embd, out_channels=hidden, kernel_size=1, padding=0),
            act,
            nn.Conv1d(in_channels=hidden, out_channels=hidden, kernel_size=3, padding=1),
            act,
            nn.Conv1d(in_channels=hidden, out_channels=n_embd, kernel_size=3, padding=1),
            nn.Dropout(p=resid_pdrop),
        )

    def forward(self, x):
        # Operates channels-first: x is (B, n_embd, T).
        return self.sequential(x)
250 |
251 |
class GELU2(nn.Module):
    """
    Sigmoid-based GELU approximation: x * sigmoid(1.702 * x).

    1.702 is the standard constant making the sigmoid approximate the
    Gaussian CDF (Hendrycks & Gimpel, 2016).
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.sigmoid: F.sigmoid is deprecated and forwards to the same op.
        return x * torch.sigmoid(1.702 * x)
257 |
258 |
class AdaLayerNorm(nn.Module):
    """
    Adaptive LayerNorm: the scale and shift come from a sinusoidal timestep
    embedding, optionally combined with a label embedding.
    """
    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        # Produces concatenated [scale | shift], hence 2 * n_embd outputs.
        self.linear = nn.Linear(n_embd, n_embd*2)
        # Affine params come from the conditioning, so disable the learned ones.
        self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)

    def forward(self, x, timestep, label_emb=None):
        emb = self.emb(timestep)
        if label_emb is not None:
            emb = emb + label_emb
        # (B, 2*n_embd) -> (B, 1, 2*n_embd) so it broadcasts over the sequence.
        emb = self.linear(self.silu(emb)).unsqueeze(1)
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = self.layernorm(x) * (1 + scale) + shift
        return x
275 |
276 |
class AdaInsNorm(nn.Module):
    """
    Adaptive InstanceNorm: like AdaLayerNorm but normalizes with
    InstanceNorm1d before applying the conditioned scale and shift.
    """
    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        # Produces concatenated [scale | shift], hence 2 * n_embd outputs.
        self.linear = nn.Linear(n_embd, n_embd*2)
        self.instancenorm = nn.InstanceNorm1d(n_embd)

    def forward(self, x, timestep, label_emb=None):
        emb = self.emb(timestep)
        if label_emb is not None:
            emb = emb + label_emb
        emb = self.linear(self.silu(emb)).unsqueeze(1)
        scale, shift = torch.chunk(emb, 2, dim=2)
        # InstanceNorm1d expects (B, C, T), so transpose around the call.
        x = self.instancenorm(x.transpose(-1, -2)).transpose(-1,-2) * (1 + scale) + shift
        return x
--------------------------------------------------------------------------------
/Models/interpretable_diffusion/transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | from torch import nn
7 | from einops import rearrange, reduce, repeat
8 | from Models.interpretable_diffusion.model_utils import LearnablePositionalEncoding, Conv_MLP,\
9 | AdaLayerNorm, Transpose, GELU2, series_decomp
10 |
11 |
class TrendBlock(nn.Module):
    """
    Model trend of time series using the polynomial regressor.

    A small conv stack maps the input to per-feature polynomial coefficients,
    which are evaluated against a fixed (t, t^2, t^3) basis on a normalized
    time grid.
    """
    def __init__(self, in_dim, out_dim, in_feat, out_feat, act):
        super(TrendBlock, self).__init__()
        trend_poly = 3  # degree of the polynomial basis
        self.trend = nn.Sequential(
            nn.Conv1d(in_channels=in_dim, out_channels=trend_poly, kernel_size=3, padding=1),
            act,
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_feat, out_feat, 3, stride=1, padding=1)
        )

        # Fixed polynomial basis, shape (trend_poly, out_dim).
        # NOTE(review): kept as a plain attribute (not a buffer) to preserve
        # existing state_dict keys; moved to the input's device per call.
        lin_space = torch.arange(1, out_dim + 1, 1) / (out_dim + 1)
        self.poly_space = torch.stack([lin_space ** float(p + 1) for p in range(trend_poly)], dim=0)

    def forward(self, input):
        """Project *input* to polynomial coefficients and evaluate the trend curve."""
        # Coefficients (b, out_feat, trend_poly) @ basis (trend_poly, out_dim);
        # the original transposed twice for no effect, which is dropped here.
        coeffs = self.trend(input)
        trend_vals = torch.matmul(coeffs, self.poly_space.to(input.device))
        return trend_vals.transpose(1, 2)
35 |
36 |
class MovingBlock(nn.Module):
    """
    Model trend of time series using the moving average.
    """
    def __init__(self, out_dim):
        super(MovingBlock, self).__init__()
        # Window grows with the sequence length but is clamped to [4, 24].
        window = max(min(int(out_dim / 4), 24), 4)
        self.decomp = series_decomp(window)

    def forward(self, input):
        b, c, h = input.shape
        residual, trend_vals = self.decomp(input)
        return residual, trend_vals
50 |
51 |
class FourierLayer(nn.Module):
    """
    Model seasonality of time series using the inverse DFT.

    Keeps only the top-k highest-amplitude frequency bins of the input
    (k = factor * log(n_bins)) and reconstructs the signal from them as a
    sum of cosines.
    """
    def __init__(self, d_model, low_freq=1, factor=1):
        super().__init__()
        self.d_model = d_model
        self.factor = factor      # scales how many frequency bins are kept
        self.low_freq = low_freq  # number of leading (low-frequency) bins to drop

    def forward(self, x):
        """x: (b, t, d)"""
        b, t, d = x.shape
        x_freq = torch.fft.rfft(x, dim=1)

        # Drop the lowest `low_freq` bins (e.g. DC); for even t also drop
        # the Nyquist bin so every kept bin has a proper conjugate partner.
        if t % 2 == 0:
            x_freq = x_freq[:, self.low_freq:-1]
            f = torch.fft.rfftfreq(t)[self.low_freq:-1]
        else:
            x_freq = x_freq[:, self.low_freq:]
            f = torch.fft.rfftfreq(t)[self.low_freq:]

        x_freq, index_tuple = self.topk_freq(x_freq)
        # Broadcast the frequency grid to (b, f, d), then select the same
        # bins that topk_freq kept.
        f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device)
        f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device)
        return self.extrapolate(x_freq, f, t)

    def extrapolate(self, x_freq, f, t):
        # Append conjugate bins at negated frequencies so the cosine sum
        # reconstructs a real-valued signal.
        x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
        f = torch.cat([f, -f], dim=1)
        t = rearrange(torch.arange(t, dtype=torch.float),
                      't -> () () t ()').to(x_freq.device)

        # Inverse DFT written as amp * cos(2*pi*f*t + phase), summed over bins.
        amp = rearrange(x_freq.abs(), 'b f d -> b f () d')
        phase = rearrange(x_freq.angle(), 'b f d -> b f () d')
        x_time = amp * torch.cos(2 * math.pi * f * t + phase)
        return reduce(x_time, 'b f t d -> b t d', 'sum')

    def topk_freq(self, x_freq):
        # Pick the top-k bins by magnitude, independently for every
        # (batch, feature) pair; mesh indices rebuild the full index tuple.
        length = x_freq.shape[1]
        top_k = int(self.factor * math.log(length))
        values, indices = torch.topk(x_freq.abs(), top_k, dim=1, largest=True, sorted=True)
        mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)), indexing='ij')
        index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1))
        x_freq = x_freq[index_tuple]
        return x_freq, index_tuple
98 |
99 |
class SeasonBlock(nn.Module):
    """
    Model seasonality of time series using the Fourier series.

    A 1x1 convolution predicts per-channel Fourier coefficients, which are
    combined with fixed sine/cosine bases evaluated on a normalized grid.
    """
    def __init__(self, in_dim, out_dim, factor=1):
        super(SeasonBlock, self).__init__()
        season_poly = factor * min(32, int(out_dim // 2))
        self.season = nn.Conv1d(in_channels=in_dim, out_channels=season_poly, kernel_size=1, padding=0)
        fourier_space = torch.arange(0, out_dim, 1) / out_dim
        # Split the basis count between cosine and sine terms; the sine
        # half gets the extra one when season_poly is odd.
        n_cos = season_poly // 2
        n_sin = season_poly - n_cos
        cos_basis = torch.stack([torch.cos(2 * np.pi * p * fourier_space) for p in range(1, n_cos + 1)], dim=0)
        sin_basis = torch.stack([torch.sin(2 * np.pi * p * fourier_space) for p in range(1, n_sin + 1)], dim=0)
        self.poly_space = torch.cat([cos_basis, sin_basis])

    def forward(self, input):
        coeffs = self.season(input)  # (batch, season_poly, length)
        season_vals = torch.matmul(coeffs.transpose(1, 2), self.poly_space.to(coeffs.device))
        return season_vals.transpose(1, 2)  # (batch, out_dim, length)
121 |
122 |
class FullAttention(nn.Module):
    """Multi-head self-attention.

    Returns the projected attention output together with the attention
    map averaged over heads. The `mask` argument is accepted for interface
    compatibility but is not applied.
    """
    def __init__(self,
                 n_embd,          # the embed dim
                 n_head,          # the number of heads
                 attn_pdrop=0.1,  # attention dropout prob
                 resid_pdrop=0.1, # residual attention dropout prob
                 ):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads packed together
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, mask=None):
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # scaled dot-product attention, (B, nh, T, T)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # weighted sum of values, heads merged back to (B, T, C)
        y = torch.matmul(weights, v).transpose(1, 2).contiguous().view(B, T, C)
        att = weights.mean(dim=1)  # head-averaged map, (B, T, T)

        # output projection
        return self.resid_drop(self.proj(y)), att
160 |
161 |
class CrossAttention(nn.Module):
    """Multi-head cross-attention.

    Queries come from `x`; keys and values come from the conditioning
    sequence (`encoder_output`). Returns the projected output and the
    attention map averaged over heads. `mask` is accepted for interface
    compatibility but is not applied.
    """
    def __init__(self,
                 n_embd,          # the embed dim
                 condition_embd,  # condition dim
                 n_head,          # the number of heads
                 attn_pdrop=0.1,  # attention dropout prob
                 resid_pdrop=0.1, # residual attention dropout prob
                 ):
        super().__init__()
        assert n_embd % n_head == 0
        # key/value project from the condition space, query from the input
        self.key = nn.Linear(condition_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(condition_embd, n_embd)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, encoder_output, mask=None):
        B, T, C = x.size()
        T_E = encoder_output.size(1)
        head_dim = C // self.n_head

        q = self.query(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = self.key(encoder_output).view(B, T_E, self.n_head, head_dim).transpose(1, 2)
        v = self.value(encoder_output).view(B, T_E, self.n_head, head_dim).transpose(1, 2)

        # scaled dot-product attention, (B, nh, T, T_E)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # weighted sum of values, heads merged back to (B, T, C)
        y = torch.matmul(weights, v).transpose(1, 2).contiguous().view(B, T, C)
        att = weights.mean(dim=1)  # head-averaged map, (B, T, T_E)

        # output projection
        return self.resid_drop(self.proj(y)), att
202 |
203 |
class EncoderBlock(nn.Module):
    """Transformer encoder block: adaptive-norm self-attention followed by
    a position-wise MLP, both with residual connections."""
    def __init__(self,
                 n_embd=1024,
                 n_head=16,
                 attn_pdrop=0.1,
                 resid_pdrop=0.1,
                 mlp_hidden_times=4,
                 activate='GELU'
                 ):
        super().__init__()

        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )

        assert activate in ['GELU', 'GELU2']
        act = nn.GELU() if activate == 'GELU' else GELU2()

        hidden = mlp_hidden_times * n_embd
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, hidden),
            act,
            nn.Linear(hidden, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, timestep, mask=None, label_emb=None):
        # Self-attention with diffusion-timestep-adaptive layer norm.
        attn_out, att = self.attn(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + attn_out
        # Position-wise feed-forward with a residual connection.
        x = x + self.mlp(self.ln2(x))
        return x, att
240 |
241 |
class Encoder(nn.Module):
    """A stack of `EncoderBlock`s applied sequentially.

    Each block takes the diffusion timestep (and optional label embedding)
    in addition to the hidden states, so the blocks are held in an
    nn.ModuleList and called explicitly — nn.Sequential's forward accepts
    only a single input and was never usable here. The state_dict layout
    ("blocks.<i>.<...>") is unchanged by this container swap.
    """
    def __init__(
        self,
        n_layer=14,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.,
        resid_pdrop=0.,
        mlp_hidden_times=4,
        block_activate='GELU',
    ):
        super().__init__()

        self.blocks = nn.ModuleList([EncoderBlock(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
            mlp_hidden_times=mlp_hidden_times,
            activate=block_activate,
        ) for _ in range(n_layer)])

    def forward(self, input, t, padding_masks=None, label_emb=None):
        x = input
        for block in self.blocks:
            # per-block attention maps are discarded
            x, _ = block(x, t, mask=padding_masks, label_emb=label_emb)
        return x
269 |
270 |
class DecoderBlock(nn.Module):
    """ an unassuming Transformer block

    Besides self-attention and cross-attention over the encoder output,
    each decoder block extracts an interpretable trend and seasonal
    component from the hidden states and returns them alongside the
    mean-removed features.
    """
    def __init__(self,
                 n_channel,
                 n_feat,
                 n_embd=1024,
                 n_head=16,
                 attn_pdrop=0.1,
                 resid_pdrop=0.1,
                 mlp_hidden_times=4,
                 activate='GELU',
                 condition_dim=1024,
                 ):
        super().__init__()

        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        self.attn1 = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        self.attn2 = CrossAttention(
            n_embd=n_embd,
            condition_embd=condition_dim,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )

        self.ln1_1 = AdaLayerNorm(n_embd)

        assert activate in ['GELU', 'GELU2']
        act = nn.GELU() if activate == 'GELU' else GELU2()

        # Decomposition heads; the commented lines are the alternative
        # implementations kept by the original authors.
        self.trend = TrendBlock(n_channel, n_channel, n_embd, n_feat, act=act)
        # self.decomp = MovingBlock(n_channel)
        self.seasonal = FourierLayer(d_model=n_embd)
        # self.seasonal = SeasonBlock(n_channel, n_channel)

        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

        # Doubles the channel axis so it can be split between the trend
        # and seasonal heads in forward().
        self.proj = nn.Conv1d(n_channel, n_channel * 2, 1)
        self.linear = nn.Linear(n_embd, n_feat)

    def forward(self, x, encoder_output, timestep, mask=None, label_emb=None):
        # Self-attention with timestep-adaptive layer norm.
        a, att = self.attn1(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + a
        # Cross-attention over the encoder output.
        # NOTE(review): label_emb is not passed to ln1_1 here, unlike ln1
        # above — confirm whether that is intentional.
        a, att = self.attn2(self.ln1_1(x, timestep), encoder_output, mask=mask)
        x = x + a
        # Duplicate channels and split: x1 -> trend head, x2 -> seasonal head.
        x1, x2 = self.proj(x).chunk(2, dim=1)
        trend, season = self.trend(x1), self.seasonal(x2)
        x = x + self.mlp(self.ln2(x))
        # Remove the per-channel mean; it is returned separately, projected
        # to feature space, so the caller can reassemble the signal.
        m = torch.mean(x, dim=1, keepdim=True)
        return x - m, self.linear(m), trend, season
333 |
334 |
class Decoder(nn.Module):
    """A stack of `DecoderBlock`s accumulating interpretable components.

    Per-layer seasonal and trend residuals are summed across layers; the
    per-layer channel means are concatenated along the channel axis.
    Blocks are stored in an nn.ModuleList (not nn.Sequential) because each
    block's forward takes several arguments and was always called
    explicitly; the state_dict layout is unchanged.
    """
    def __init__(
        self,
        n_channel,
        n_feat,
        n_embd=1024,
        n_head=16,
        n_layer=10,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate='GELU',
        condition_dim=512
    ):
        super().__init__()
        self.d_model = n_embd
        self.n_feat = n_feat
        self.blocks = nn.ModuleList([DecoderBlock(
            n_feat=n_feat,
            n_channel=n_channel,
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
            mlp_hidden_times=mlp_hidden_times,
            activate=block_activate,
            condition_dim=condition_dim,
        ) for _ in range(n_layer)])

    def forward(self, x, t, enc, padding_masks=None, label_emb=None):
        b, c, _ = x.shape
        mean = []
        season = torch.zeros((b, c, self.d_model), device=x.device)
        trend = torch.zeros((b, c, self.n_feat), device=x.device)
        for block in self.blocks:
            x, residual_mean, residual_trend, residual_season = \
                block(x, enc, t, mask=padding_masks, label_emb=label_emb)
            season += residual_season
            trend += residual_trend
            mean.append(residual_mean)

        mean = torch.cat(mean, dim=1)  # one entry per layer along dim 1
        return x, mean, trend, season
379 |
380 |
class Transformer(nn.Module):
    """Interpretable encoder-decoder transformer.

    The forward pass decomposes its prediction into a trend component and
    a seasonal+residual component (returned separately when
    `return_res=True`).
    """
    def __init__(
        self,
        n_feat,
        n_channel,
        n_layer_enc=5,
        n_layer_dec=14,
        n_embd=1024,
        n_heads=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate='GELU',
        max_len=2048,
        conv_params=None,
        **kwargs
    ):
        super().__init__()
        self.emb = Conv_MLP(n_feat, n_embd, resid_pdrop=resid_pdrop)
        self.inverse = Conv_MLP(n_embd, n_feat, resid_pdrop=resid_pdrop)

        # Heuristic kernel size for the seasonal projection when none given.
        if conv_params is None or conv_params[0] is None:
            if n_feat < 32 and n_channel < 64:
                kernel_size, padding = 1, 0
            else:
                kernel_size, padding = 5, 2
        else:
            kernel_size, padding = conv_params

        self.combine_s = nn.Conv1d(n_embd, n_feat, kernel_size=kernel_size, stride=1, padding=padding,
                                   padding_mode='circular', bias=False)
        self.combine_m = nn.Conv1d(n_layer_dec, 1, kernel_size=1, stride=1, padding=0,
                                   padding_mode='circular', bias=False)

        self.encoder = Encoder(n_layer_enc, n_embd, n_heads, attn_pdrop, resid_pdrop, mlp_hidden_times, block_activate)
        self.pos_enc = LearnablePositionalEncoding(n_embd, dropout=resid_pdrop, max_len=max_len)

        self.decoder = Decoder(n_channel, n_feat, n_embd, n_heads, n_layer_dec, attn_pdrop, resid_pdrop, mlp_hidden_times,
                               block_activate, condition_dim=n_embd)
        self.pos_dec = LearnablePositionalEncoding(n_embd, dropout=resid_pdrop, max_len=max_len)

    def forward(self, input, t, padding_masks=None, return_res=False):
        emb = self.emb(input)
        inp_enc = self.pos_enc(emb)
        enc_cond = self.encoder(inp_enc, t, padding_masks=padding_masks)

        inp_dec = self.pos_dec(emb)
        output, mean, trend, season = self.decoder(inp_dec, t, enc_cond, padding_masks=padding_masks)

        res = self.inverse(output)
        res_m = torch.mean(res, dim=1, keepdim=True)
        # Project the accumulated seasonal component to feature space once;
        # the original recomputed this projection in the return_res branch.
        season_comp = self.combine_s(season.transpose(1, 2)).transpose(1, 2)
        residual = res - res_m
        # Trend = aggregated per-layer means + decoder residual mean + trend sum.
        trend = self.combine_m(mean) + res_m + trend

        if return_res:
            return trend, season_comp, residual

        return trend, season_comp + residual
439 |
440 |
if __name__ == '__main__':
    # Intentionally empty: this module only provides model definitions.
    pass
--------------------------------------------------------------------------------
/Models/ts2vec/models/dilated_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
class SamePadConv(nn.Module):
    """1-D convolution that preserves the input length ("same" padding),
    including for dilated kernels and even receptive fields."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
        super().__init__()
        self.receptive_field = (kernel_size - 1) * dilation + 1
        pad = self.receptive_field // 2
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            padding=pad,
            dilation=dilation,
            groups=groups
        )
        # An even receptive field yields one extra output step; trim it.
        self.remove = 1 if self.receptive_field % 2 == 0 else 0

    def forward(self, x):
        out = self.conv(x)
        return out[:, :, : -self.remove] if self.remove > 0 else out
24 |
class ConvBlock(nn.Module):
    """Two same-padded dilated convolutions with a (possibly projected)
    residual connection; inputs are pre-activated with GELU."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False):
        super().__init__()
        self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation)
        self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation)
        # A 1x1 projection is used whenever the channel count changes;
        # the final block always projects.
        needs_proj = in_channels != out_channels or final
        self.projector = nn.Conv1d(in_channels, out_channels, 1) if needs_proj else None

    def forward(self, x):
        if self.projector is None:
            residual = x
        else:
            residual = self.projector(x)
        out = self.conv1(F.gelu(x))
        out = self.conv2(F.gelu(out))
        return out + residual
39 |
class DilatedConvEncoder(nn.Module):
    """Stack of ConvBlocks with exponentially increasing dilation (2**i)."""
    def __init__(self, in_channels, channels, kernel_size):
        super().__init__()
        blocks = []
        prev = in_channels
        last = len(channels) - 1
        for i, ch in enumerate(channels):
            blocks.append(ConvBlock(
                prev,
                ch,
                kernel_size=kernel_size,
                dilation=2 ** i,
                final=(i == last),
            ))
            prev = ch
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)
56 |
--------------------------------------------------------------------------------
/Models/ts2vec/models/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .dilated_conv import DilatedConvEncoder
6 |
def generate_continuous_mask(B, T, n=5, l=0.1):
    """Boolean mask (B, T): each row has `n` random False segments of
    length `l`. Floats for `n` / `l` are interpreted as fractions of T."""
    res = torch.full((B, T), True, dtype=torch.bool)
    if isinstance(n, float):
        n = int(n * T)
    n = max(min(n, T // 2), 1)

    if isinstance(l, float):
        l = int(l * T)
    l = max(l, 1)

    for row in res:
        for _ in range(n):
            start = np.random.randint(T - l + 1)
            row[start:start + l] = False
    return res
22 |
def generate_binomial_mask(B, T, p=0.5):
    """Boolean mask (B, T) whose entries are independently True with probability p."""
    draws = np.random.binomial(1, p, size=(B, T))
    return torch.from_numpy(draws).to(torch.bool)
25 |
class TSEncoder(nn.Module):
    """Dilated-convolution time-series encoder with random timestamp masking.

    Maps (B, T, input_dims) -> (B, T, output_dims). During training,
    timestamps are randomly masked (zeroed after the input projection)
    according to `mask_mode`.
    """
    def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.mask_mode = mask_mode  # default masking strategy used in training
        self.input_fc = nn.Linear(input_dims, hidden_dims)
        self.feature_extractor = DilatedConvEncoder(
            hidden_dims,
            [hidden_dims] * depth + [output_dims],
            kernel_size=3
        )
        self.repr_dropout = nn.Dropout(p=0.1)

    def forward(self, x, mask=None):  # x: B x T x input_dims
        # Timestamps where any feature is NaN are treated as missing.
        nan_mask = ~x.isnan().any(axis=-1)
        # NOTE(review): this zeroing mutates the caller's tensor in place.
        x[~nan_mask] = 0
        x = self.input_fc(x)  # B x T x Ch

        # generate & apply mask: a string selects a strategy, otherwise
        # `mask` is used directly as a boolean keep-mask
        if mask is None:
            # random masking only during training; keep everything at eval
            if self.training:
                mask = self.mask_mode
            else:
                mask = 'all_true'

        if mask == 'binomial':
            mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device)
        elif mask == 'continuous':
            mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device)
        elif mask == 'all_true':
            mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
        elif mask == 'all_false':
            mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool)
        elif mask == 'mask_last':
            mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
            mask[:, -1] = False

        # never keep originally-missing timestamps, whatever the mask says
        mask &= nan_mask
        x[~mask] = 0

        # conv encoder
        x = x.transpose(1, 2)  # B x Ch x T
        x = self.repr_dropout(self.feature_extractor(x))  # B x Co x T
        x = x.transpose(1, 2)  # B x T x Co

        return x
74 |
--------------------------------------------------------------------------------
/Models/ts2vec/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0):
    """Hierarchical contrastive loss over multiple temporal scales.

    At each scale the instance-wise and temporal contrastive terms are
    combined with weights alpha / (1 - alpha); the temporal resolution is
    then halved with max pooling and the process repeats. The temporal
    term is only applied once the scale depth reaches `temporal_unit`.
    The accumulated loss is averaged over the number of scales visited.
    """
    loss = torch.tensor(0., device=z1.device)
    d = 0  # number of scales accumulated so far
    while z1.size(1) > 1:
        if alpha != 0:
            loss += alpha * instance_contrastive_loss(z1, z2)
        if d >= temporal_unit:
            if 1 - alpha != 0:
                loss += (1 - alpha) * temporal_contrastive_loss(z1, z2)
        d += 1
        # halve the temporal resolution for the next scale
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)
    if z1.size(1) == 1:
        # coarsest scale (T == 1): only the instance-wise term applies
        if alpha != 0:
            loss += alpha * instance_contrastive_loss(z1, z2)
        d += 1
    return loss / d
23 |
def instance_contrastive_loss(z1, z2):
    """Contrast each instance against the others at every timestamp.

    For each t, the 2B representations (B from each view) form a
    similarity matrix; every sample's positive is its counterpart from the
    other view. Returns 0 for a batch of size 1 (no negatives exist).
    """
    B, T = z1.size(0), z1.size(1)
    if B == 1:
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=0).transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))        # T x 2B x 2B
    # Drop the diagonal (self-similarity) by folding the strict lower and
    # upper triangles into a (2B-1)-column matrix.
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]
    logits = logits + torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)

    idx = torch.arange(B, device=z1.device)
    # Positive positions: sample i (view 1) vs sample i (view 2) and back.
    loss = (logits[:, idx, B + idx - 1].mean() + logits[:, B + idx, idx].mean()) / 2
    return loss
38 |
def temporal_contrastive_loss(z1, z2):
    """Contrast each timestamp against the others within a sample.

    The 2T representations of a sample (T from each view) form a
    similarity matrix; every timestamp's positive is the same timestamp in
    the other view. Returns 0 when T == 1 (no negatives exist).
    """
    B, T = z1.size(0), z1.size(1)
    if T == 1:
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=1)            # B x 2T x C
    sim = torch.matmul(z, z.transpose(1, 2))  # B x 2T x 2T
    # Drop the diagonal by folding the strict triangles together.
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]
    logits = logits + torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)

    idx = torch.arange(T, device=z1.device)
    loss = (logits[:, idx, T + idx - 1].mean() + logits[:, T + idx, idx].mean()) / 2
    return loss
52 |
--------------------------------------------------------------------------------
/Models/ts2vec/ts2vec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.utils.data import TensorDataset, DataLoader
4 | import numpy as np
5 | from Models.ts2vec.models.encoder import TSEncoder
6 | from Models.ts2vec.models.losses import hierarchical_contrastive_loss
7 | from Models.ts2vec.utils import take_per_row, split_with_nan, centerize_vary_length_series, torch_pad_nan
8 |
9 |
class TS2Vec:
    '''The TS2Vec model'''

    def __init__(
        self,
        input_dims,
        output_dims=320,
        hidden_dims=64,
        depth=10,
        device='cuda',
        lr=0.001,
        batch_size=16,
        max_train_length=None,
        temporal_unit=0,
        after_iter_callback=None,
        after_epoch_callback=None
    ):
        ''' Initialize a TS2Vec model.

        Args:
            input_dims (int): The input dimension. For a univariate time series, this should be set to 1.
            output_dims (int): The representation dimension.
            hidden_dims (int): The hidden dimension of the encoder.
            depth (int): The number of hidden residual blocks in the encoder.
            device (str or torch.device): The device used for training and inference.
            lr (float): The learning rate.
            max_train_length (Union[int, NoneType]): The maximum allowed sequence length for training. A sequence longer than max_train_length is cropped into several sequences, each of which is no longer than max_train_length.
            batch_size (int): The batch size.
            temporal_unit (int): The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory.
            after_iter_callback (Union[Callable, NoneType]): A callback function that would be called after each iteration.
            after_epoch_callback (Union[Callable, NoneType]): A callback function that would be called after each epoch.
        '''

        super().__init__()
        self.device = device
        self.lr = lr
        self.batch_size = batch_size
        self.max_train_length = max_train_length
        self.temporal_unit = temporal_unit

        self._net = TSEncoder(input_dims=input_dims, output_dims=output_dims, hidden_dims=hidden_dims, depth=depth).to(self.device)
        # `net` is an SWA-style running average of `_net`'s parameters and
        # is the model used for inference; `_net` receives gradient updates.
        self.net = torch.optim.swa_utils.AveragedModel(self._net)
        self.net.update_parameters(self._net)

        self.after_iter_callback = after_iter_callback
        self.after_epoch_callback = after_epoch_callback

        self.n_epochs = 0
        self.n_iters = 0

    def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False):
        ''' Training the TS2Vec model.

        Args:
            train_data (numpy.ndarray): The training data. It should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
            n_epochs (Union[int, NoneType]): The number of epochs. When this reaches, the training stops.
            n_iters (Union[int, NoneType]): The number of iterations. When this reaches, the training stops. If both n_epochs and n_iters are not specified, a default setting would be used that sets n_iters to 200 for a dataset with size <= 100000, 600 otherwise.
            verbose (bool): Whether to print the training loss after each epoch.

        Returns:
            loss_log: a list containing the training losses on each epoch.
        '''
        assert train_data.ndim == 3

        if n_iters is None and n_epochs is None:
            n_iters = 200 if train_data.size <= 100000 else 600  # default param for n_iters

        # Over-long sequences are chopped into `sections` pieces (NaN-padded
        # to equal length) and treated as extra instances.
        if self.max_train_length is not None:
            sections = train_data.shape[1] // self.max_train_length
            if sections >= 2:
                train_data = np.concatenate(split_with_nan(train_data, sections, axis=1), axis=0)

        # If any sequence has leading/trailing all-NaN timestamps, re-center
        # the series so the valid spans line up.
        temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0)
        if temporal_missing[0] or temporal_missing[-1]:
            train_data = centerize_vary_length_series(train_data)

        # drop instances that are entirely NaN
        train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)]

        train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float))
        train_loader = DataLoader(train_dataset, batch_size=min(self.batch_size, len(train_dataset)), shuffle=True, drop_last=True)

        optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr)

        loss_log = []

        while True:
            if n_epochs is not None and self.n_epochs >= n_epochs:
                break

            cum_loss = 0
            n_epoch_iters = 0

            interrupted = False
            for batch in train_loader:
                if n_iters is not None and self.n_iters >= n_iters:
                    interrupted = True
                    break

                x = batch[0]
                # Randomly window the batch down to max_train_length.
                if self.max_train_length is not None and x.size(1) > self.max_train_length:
                    window_offset = np.random.randint(x.size(1) - self.max_train_length + 1)
                    x = x[:, window_offset : window_offset + self.max_train_length]
                x = x.to(self.device)

                # Sample two overlapping crops per batch; the overlap
                # [crop_left, crop_right) of length crop_l is where the two
                # views are contrasted against each other.
                ts_l = x.size(1)
                crop_l = np.random.randint(low=2 ** (self.temporal_unit + 1), high=ts_l+1)
                crop_left = np.random.randint(ts_l - crop_l + 1)
                crop_right = crop_left + crop_l
                crop_eleft = np.random.randint(crop_left + 1)
                crop_eright = np.random.randint(low=crop_right, high=ts_l + 1)
                crop_offset = np.random.randint(low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0))

                optimizer.zero_grad()

                # view 1: ends at crop_right; keep only the overlap suffix
                out1 = self._net(take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft))
                out1 = out1[:, -crop_l:]

                # view 2: starts at crop_left; keep only the overlap prefix
                out2 = self._net(take_per_row(x, crop_offset + crop_left, crop_eright - crop_left))
                out2 = out2[:, :crop_l]

                loss = hierarchical_contrastive_loss(
                    out1,
                    out2,
                    temporal_unit=self.temporal_unit
                )

                loss.backward()
                optimizer.step()
                # refresh the averaged (inference) copy after every step
                self.net.update_parameters(self._net)

                cum_loss += loss.item()
                n_epoch_iters += 1

                self.n_iters += 1

                if self.after_iter_callback is not None:
                    self.after_iter_callback(self, loss.item())

            if interrupted:
                break

            cum_loss /= n_epoch_iters
            loss_log.append(cum_loss)
            if verbose:
                print(f"Epoch #{self.n_epochs}: loss={cum_loss}")
            self.n_epochs += 1

            if self.after_epoch_callback is not None:
                self.after_epoch_callback(self, cum_loss)

        return loss_log

    def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
        # Encode one batch with the averaged net and apply the requested
        # temporal pooling. `slicing` trims the output back to the window
        # of interest when the input carried extra context (see `encode`).
        out = self.net(x.to(self.device, non_blocking=True), mask)
        if encoding_window == 'full_series':
            if slicing is not None:
                out = out[:, slicing]
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size = out.size(1),
            ).transpose(1, 2)

        elif isinstance(encoding_window, int):
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size = encoding_window,
                stride = 1,
                padding = encoding_window // 2
            ).transpose(1, 2)
            if encoding_window % 2 == 0:
                # even kernels yield one extra timestep; drop it
                out = out[:, :-1]
            if slicing is not None:
                out = out[:, slicing]

        elif encoding_window == 'multiscale':
            # concatenate max-pooled representations at kernel sizes
            # 3, 5, 9, ... (2^(p+1) + 1) along the feature axis
            p = 0
            reprs = []
            while (1 << p) + 1 < out.size(1):
                t_out = F.max_pool1d(
                    out.transpose(1, 2),
                    kernel_size = (1 << (p + 1)) + 1,
                    stride = 1,
                    padding = 1 << p
                ).transpose(1, 2)
                if slicing is not None:
                    t_out = t_out[:, slicing]
                reprs.append(t_out)
                p += 1
            out = torch.cat(reprs, dim=-1)

        else:
            if slicing is not None:
                out = out[:, slicing]

        return out.cpu()

    def encode(self, data, mask=None, encoding_window=None, casual=False, sliding_length=None, sliding_padding=0, batch_size=None):
        ''' Compute representations using the model.

        Args:
            data (numpy.ndarray): This should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
            mask (str): The mask used by encoder can be specified with this parameter. This can be set to 'binomial', 'continuous', 'all_true', 'all_false' or 'mask_last'.
            encoding_window (Union[str, int]): When this param is specified, the computed representation would be the max pooling over this window. This can be set to 'full_series', 'multiscale' or an integer specifying the pooling kernel size.
            casual (bool): When this param is set to True, the future information would not be encoded into the representation of each timestamp. (The name is a historical misspelling of "causal", kept for backward compatibility.)
            sliding_length (Union[int, NoneType]): The length of sliding window. When this param is specified, a sliding inference would be applied on the time series.
            sliding_padding (int): This param specifies the contextual data length used for inference every sliding windows.
            batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, this would be the same batch size as training.

        Returns:
            repr: The representations for data.
        '''
        assert self.net is not None, 'please train or load a net first'
        assert data.ndim == 3
        if batch_size is None:
            batch_size = self.batch_size
        n_samples, ts_l, _ = data.shape

        org_training = self.net.training
        self.net.eval()

        dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
        loader = DataLoader(dataset, batch_size=batch_size)

        with torch.no_grad():
            output = []
            for batch in loader:
                x = batch[0]
                if sliding_length is not None:
                    reprs = []
                    # When the dataset is smaller than a batch, windows from
                    # several sliding positions are buffered and encoded
                    # together to keep the batch full.
                    if n_samples < batch_size:
                        calc_buffer = []
                        calc_buffer_l = 0
                    for i in range(0, ts_l, sliding_length):
                        # window [i, i + sliding_length) plus sliding_padding
                        # context on the left (and right, unless casual)
                        l = i - sliding_padding
                        r = i + sliding_length + (sliding_padding if not casual else 0)
                        x_sliding = torch_pad_nan(
                            x[:, max(l, 0) : min(r, ts_l)],
                            left=-l if l<0 else 0,
                            right=r-ts_l if r>ts_l else 0,
                            dim=1
                        )
                        if n_samples < batch_size:
                            if calc_buffer_l + n_samples > batch_size:
                                out = self._eval_with_pooling(
                                    torch.cat(calc_buffer, dim=0),
                                    mask,
                                    slicing=slice(sliding_padding, sliding_padding+sliding_length),
                                    encoding_window=encoding_window
                                )
                                reprs += torch.split(out, n_samples)
                                calc_buffer = []
                                calc_buffer_l = 0
                            calc_buffer.append(x_sliding)
                            calc_buffer_l += n_samples
                        else:
                            out = self._eval_with_pooling(
                                x_sliding,
                                mask,
                                slicing=slice(sliding_padding, sliding_padding+sliding_length),
                                encoding_window=encoding_window
                            )
                            reprs.append(out)

                    # flush any windows still waiting in the buffer
                    if n_samples < batch_size:
                        if calc_buffer_l > 0:
                            out = self._eval_with_pooling(
                                torch.cat(calc_buffer, dim=0),
                                mask,
                                slicing=slice(sliding_padding, sliding_padding+sliding_length),
                                encoding_window=encoding_window
                            )
                            reprs += torch.split(out, n_samples)
                            calc_buffer = []
                            calc_buffer_l = 0

                    out = torch.cat(reprs, dim=1)
                    if encoding_window == 'full_series':
                        # pool across the full reassembled series
                        out = F.max_pool1d(
                            out.transpose(1, 2).contiguous(),
                            kernel_size = out.size(1),
                        ).squeeze(1)
                else:
                    out = self._eval_with_pooling(x, mask, encoding_window=encoding_window)
                    if encoding_window == 'full_series':
                        out = out.squeeze(1)

                output.append(out)

            output = torch.cat(output, dim=0)

        self.net.train(org_training)
        return output.numpy()

    def save(self, fn):
        ''' Save the model to a file.

        Args:
            fn (str): filename.
        '''
        # only the averaged (inference) network is persisted
        torch.save(self.net.state_dict(), fn)

    def load(self, fn):
        ''' Load the model from a file.

        Args:
            fn (str): filename.
        '''
        state_dict = torch.load(fn, map_location=self.device)
        self.net.load_state_dict(state_dict)
319 |
320 |
--------------------------------------------------------------------------------
/Models/ts2vec/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 | import torch
5 | import random
6 | from datetime import datetime
7 |
def pkl_save(name, var):
    """Pickle *var* into the file at path *name* (overwrites)."""
    with open(name, 'wb') as handle:
        pickle.dump(var, handle)
11 |
def pkl_load(name):
    """Unpickle and return the object stored at path *name*."""
    with open(name, 'rb') as handle:
        return pickle.load(handle)
15 |
def torch_pad_nan(arr, left=0, right=0, dim=0):
    """Pad *arr* with NaN blocks along *dim*: *left* slots before and
    *right* slots after. Returns *arr* unchanged when no padding is asked.
    """
    pieces = []
    if left > 0:
        shape = list(arr.shape)
        shape[dim] = left
        pieces.append(torch.full(shape, np.nan))
    pieces.append(arr)
    if right > 0:
        shape = list(arr.shape)
        shape[dim] = right
        pieces.append(torch.full(shape, np.nan))
    if len(pieces) == 1:
        return arr
    return torch.cat(pieces, dim=dim)
26 |
def pad_nan_to_target(array, target_length, axis=0, both_side=False):
    """Grow *array* to *target_length* along *axis* by padding with NaN.

    When *both_side* is set the padding is split (floor-half before, the
    remainder after); otherwise all NaNs go after. Arrays already at or
    beyond the target length are returned untouched.
    """
    assert array.dtype in [np.float16, np.float32, np.float64]
    missing = target_length - array.shape[axis]
    if missing <= 0:
        return array
    widths = [(0, 0) for _ in range(array.ndim)]
    if both_side:
        before = missing // 2
        widths[axis] = (before, missing - before)
    else:
        widths[axis] = (0, missing)
    return np.pad(array, pad_width=widths, mode='constant', constant_values=np.nan)
38 |
def split_with_nan(x, sections, axis=0):
    """Split *x* into *sections* chunks along *axis* (np.array_split
    semantics) and NaN-pad the shorter chunks so every chunk matches the
    length of the first (longest) one.
    """
    assert x.dtype in [np.float16, np.float32, np.float64]
    chunks = np.array_split(x, sections, axis=axis)
    longest = chunks[0].shape[axis]
    return [pad_nan_to_target(chunk, longest, axis=axis) for chunk in chunks]
46 |
def take_per_row(A, indx, num_elem):
    """Gather, from each row i of *A*, the slice of *num_elem* consecutive
    columns starting at ``indx[i]``.
    """
    gather_cols = indx[:, None] + np.arange(num_elem)
    row_sel = torch.arange(gather_cols.shape[0])[:, None]
    return A[row_sel, gather_cols]
50 |
def centerize_vary_length_series(x):
    """Shift each series in *x* (B, T, C) so its valid span is centred.

    A timestep counts as missing when ALL channels are NaN. The leading
    and trailing missing runs are measured per row and each row is rolled
    so the observed segment sits in the middle of the time axis.
    NOTE(review): the shift is circular (modular indexing), so padding
    from one end wraps around to the other — presumably acceptable because
    the wrapped positions are NaN; confirm against callers.
    """
    # Index of the first timestep that is not all-NaN == length of leading gap.
    prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1)
    # Same trick on the time-reversed series gives the trailing gap length.
    suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1)
    # How far each row must move so the gaps become (roughly) equal.
    offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros
    rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]]
    offset[offset < 0] += x.shape[1]  # make offsets non-negative for modular shift
    column_indices = column_indices - offset[:, np.newaxis]
    return x[rows, column_indices]
59 |
def data_dropout(arr, p):
    """Return a copy of *arr* with a fraction *p* of its (B, T) positions
    set to NaN (when *arr* has trailing feature dims, the whole feature
    vector at a dropped position becomes NaN). The input is not modified.

    Args:
        arr: array of shape (B, T, ...) to corrupt.
        p: fraction of the B*T positions to drop (0 <= p <= 1).
    """
    B, T = arr.shape[0], arr.shape[1]
    # Bug fix: `np.bool` was removed in NumPy 1.24; use the builtin bool dtype.
    mask = np.full(B*T, False, dtype=bool)
    ele_sel = np.random.choice(
        B*T,
        size=int(B*T*p),
        replace=False
    )
    mask[ele_sel] = True
    res = arr.copy()
    res[mask.reshape(B, T)] = np.nan
    return res
72 |
def name_with_datetime(prefix='default'):
    """Return "<prefix>_YYYYmmdd_HHMMSS" stamped with the current local time."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return prefix + '_' + stamp
76 |
def init_dl_program(
    device_name,
    seed=None,
    use_cudnn=True,
    deterministic=False,
    benchmark=False,
    use_tf32=False,
    max_threads=None
):
    """Initialise RNG seeds, thread limits and backend flags, and return
    the requested torch device(s).

    Args:
        device_name: a device spec (str or int) or a list of them.
        seed: base seed; consecutive values seed `random`, NumPy, torch and
            each CUDA device so their streams differ.
        use_cudnn / deterministic / benchmark: cuDNN backend flags.
        use_tf32: enable TF32 matmul/cudnn kernels where supported.
        max_threads: cap on torch intra-op and inter-op thread counts.

    Returns:
        A single torch.device when one name was given, else a list of them.
    """
    import torch
    if max_threads is not None:
        torch.set_num_threads(max_threads)  # intraop
        if torch.get_num_interop_threads() != max_threads:
            torch.set_num_interop_threads(max_threads)  # interop
        try:
            import mkl
        # Bug fix: the original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit; only a missing module is expected.
        except ImportError:
            pass
        else:
            mkl.set_num_threads(max_threads)

    if seed is not None:
        random.seed(seed)
        seed += 1
        np.random.seed(seed)
        seed += 1
        torch.manual_seed(seed)

    if isinstance(device_name, (str, int)):
        device_name = [device_name]

    devices = []
    for t in reversed(device_name):
        t_device = torch.device(t)
        devices.append(t_device)
        if t_device.type == 'cuda':
            assert torch.cuda.is_available()
            torch.cuda.set_device(t_device)
            if seed is not None:
                seed += 1
                torch.cuda.manual_seed(seed)
    devices.reverse()
    torch.backends.cudnn.enabled = use_cudnn
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark

    if hasattr(torch.backends.cudnn, 'allow_tf32'):
        torch.backends.cudnn.allow_tf32 = use_tf32
        torch.backends.cuda.matmul.allow_tf32 = use_tf32

    return devices if len(devices) > 1 else devices[0]
128 |
129 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion-TS: Interpretable Diffusion for General Time Series Generation
2 |
3 | [](https://github.com/Y-debug-sys/Diffusion-TS/stargazers)
4 | [](https://github.com/Y-debug-sys/Diffusion-TS/network)
5 | [](https://github.com/Y-debug-sys/Diffusion-TS/blob/main/LICENSE)
6 |
7 |
8 |
9 | > **Abstract:** Denoising diffusion probabilistic models (DDPMs) are becoming the leading paradigm for generative models. It has recently shown breakthroughs in audio synthesis, time series imputation and forecasting. In this paper, we propose Diffusion-TS, a novel diffusion-based framework that generates multivariate time series samples of high quality by using an encoder-decoder transformer with disentangled temporal representations, in which the decomposition technique guides Diffusion-TS to capture the semantic meaning of time series while transformers mine detailed sequential information from the noisy model input. Different from existing diffusion-based approaches, we train the model to directly reconstruct the sample instead of the noise in each diffusion step, combining a Fourier-based loss term. Diffusion-TS is expected to generate time series satisfying both interpretablity and realness. In addition, it is shown that the proposed Diffusion-TS can be easily extended to conditional generation tasks, such as forecasting and imputation, without any model changes. This also motivates us to further explore the performance of Diffusion-TS under irregular settings. Finally, through qualitative and quantitative experiments, results show that Diffusion-TS achieves the state-of-the-art results on various realistic analyses of time series.
10 |
11 | Diffusion-TS is a diffusion-based framework that generates general time series samples both conditionally and unconditionally. As shown in Figure 1, the framework contains two parts: a sequence encoder and an interpretable decoder which decomposes the time series into seasonal part and trend part. The trend part contains the polynomial regressor and extracted mean of each block output. For seasonal part, we reuse trigonometric representations based on Fourier series. Regarding training, sampling and more details, please refer to [our paper](https://openreview.net/pdf?id=4h1apFjO99) in ICLR 2024.
12 |
13 |
14 |
15 |
16 | Figure 1: Overall Architecture of Diffusion-TS.
17 |
18 |
19 | 🎤 **Update (2025/2/28)**: We have added an additional [experiment](https://github.com/Y-debug-sys/Diffusion-TS/blob/main/Experiments/eeg_multiple_classes.ipynb) for the EEG dataset and incorporated the Classifier Guidance mechanism. Now, Diffusion-TS supports multi-class generation.
20 |
21 | ## Dataset Preparation
22 |
23 | All the four real-world datasets (Stocks, ETTh1, Energy and fMRI) can be obtained from [Google Drive](https://drive.google.com/file/d/11DI22zKWtHjXMnNGPWNUbyGz-JiEtZy6/view?usp=sharing). Please download **dataset.zip**, then unzip and copy it to the folder `./Data` in our repository. EEG dataset can be downloaded from [here](https://drive.google.com/file/d/1IqwE0wbCT1orVdZpul2xFiNkGnYs4t89/view?usp=sharing) and should also be placed in the aforementioned `./Data/dataset` folder.
24 |
25 | ## Running the Code
26 |
27 | The code requires conda3 (or miniconda3), and one CUDA capable GPU. The instructions below guide you regarding running the codes in this repository.
28 |
29 | ### Environment & Libraries
30 |
31 | The full libraries list is provided as a `requirements.txt` in this repo. Please create a virtual environment with `conda` or `venv` and run
32 |
33 | ~~~bash
34 | (myenv) $ pip install -r requirements.txt
35 | ~~~
36 |
37 | ### Training & Sampling
38 |
39 | For training, you can reproduce the experimental results of all benchmarks by running
40 |
41 | ~~~bash
42 | (myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --train
43 | ~~~
44 |
45 | **Note:** We also provide the corresponding `.yaml` files (only stocks, sines, mujoco, etth, energy and fmri) under the folder `./Config` where all possible options can be altered. You may need to change some parameters in the model for different scenarios. For example, we use the whole data to train the model for unconditional evaluation, so *training_ratio* is set to 1 by default. As for conditional generation, we need to divide the data set, so it should be changed to a value < 1.
46 |
47 | While training, the script will save check points to the *results* folder after a fixed number of epochs. Once trained, please use the saved model for sampling by running
48 |
49 | #### Unconstrained
50 | ```bash
51 | (myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --sample 0 --milestone {checkpoint_number}
52 | ```
53 |
54 | #### Imputation
55 | ```bash
56 | (myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --sample 1 --milestone {checkpoint_number} --mode infill --missing_ratio {missing_ratio}
57 | ```
58 |
59 | #### Forecasting
60 | ```bash
61 | (myenv) $ python main.py --name {dataset_name} --config_file {config.yaml} --gpu 0 --sample 1 --milestone {checkpoint_number} --mode predict --pred_len {pred_len}
62 | ```
63 |
64 |
65 | ## Visualization and Evaluation
66 |
67 | After sampling, synthetic data and original data are stored in `.npy` file format under the *output* folder, which can be directly read to calculate quantitative metrics such as discriminative, predictive, correlational and context-FID score. You can also reproduce the visualization results using t-SNE or kernel plotting, and all of this evaluation code can be found in the folder `./Utils`. Please refer to `.ipynb` tutorial files in this repo for more detailed implementations.
68 |
69 | **Note:** All the metrics can be found in the `./Experiments` folder. Additionally, by default, for datasets other than the Sine dataset (because it does not need normalization), their normalized forms are saved in `{...}_norm_truth.npy`. Therefore, when you run the Jupyter notebook for datasets other than Sine, just uncomment and rewrite the corresponding code written at the beginning.
70 |
71 | ### Main Results
72 |
73 | #### Standard TS Generation
74 |
75 | Table 1: Results of 24-length Time-series Generation.
76 |
77 |
78 |
79 |
80 | #### Long-term TS Generation
81 |
82 | Table 2: Results of Long-term Time-series Generation.
83 |
84 |
85 |
86 |
87 | #### Conditional TS Generation
88 |
89 |
90 |
91 | Figure 2: Visualizations of Time-series Imputation and Forecasting.
92 |
93 |
94 |
95 | ## Authors
96 |
97 | * Paper Authors : Xinyu Yuan, Yan Qiao
98 |
99 | * Code Author : Xinyu Yuan
100 |
101 | * Contact : yxy5315@gmail.com
102 |
103 |
104 | ## Citation
105 | If you find this repo useful, please cite our paper via
106 | ```bibtex
107 | @inproceedings{yuan2024diffusionts,
108 | title={Diffusion-{TS}: Interpretable Diffusion for General Time Series Generation},
109 | author={Xinyu Yuan and Yan Qiao},
110 | booktitle={The Twelfth International Conference on Learning Representations},
111 | year={2024},
112 | url={https://openreview.net/forum?id=4h1apFjO99}
113 | }
114 | ```
115 |
116 |
117 | ## Acknowledgement
118 |
119 | We appreciate the following github repos a lot for their valuable code base:
120 |
121 | https://github.com/lucidrains/denoising-diffusion-pytorch
122 |
123 | https://github.com/cientgu/VQ-Diffusion
124 |
125 | https://github.com/XiangLi1999/Diffusion-LM
126 |
127 | https://github.com/philipperemy/n-beats
128 |
129 | https://github.com/salesforce/ETSformer
130 |
131 | https://github.com/ermongroup/CSDI
132 |
133 | https://github.com/jsyoon0823/TimeGAN
134 |
--------------------------------------------------------------------------------
/Utils/Data_utils/eeg_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from scipy.io import arff
7 | from scipy import stats
8 | from copy import deepcopy
9 | from torch.utils.data import Dataset
10 | from Utils.masking_utils import noise_mask
11 | from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one
12 | from sklearn.preprocessing import MinMaxScaler
13 |
14 |
class EEGDataset(Dataset):
    """EEG Eye State dataset split into two classes of fixed-length windows.

    The .arff recording is outlier-cleaned, normalised per column, and cut
    into overlapping windows of length ``window`` separately for the
    eyes-open and eyes-closed segments.  Samples are MinMax-scaled and,
    when ``neg_one_to_one`` is set, further mapped to [-1, 1]; labels are
    0 for the first class's windows and 1 for the second's.
    """

    def __init__(
        self,
        data_root,
        window=64,
        save2npy=True,
        neg_one_to_one=True,
        period='train',
        output_dir='./OUTPUT'
    ):
        """
        Args:
            data_root (str): path to the EEG .arff file.
            window (int): sequence length of each sample.
            save2npy (bool): stored but unused here; kept for interface
                parity with the other dataset classes.
            neg_one_to_one (bool): additionally rescale [0, 1] -> [-1, 1].
            period (str): 'train' or 'test'; controls __getitem__'s output.
            output_dir (str): root folder for the 'samples' subdirectory.
        """
        super(EEGDataset, self).__init__()
        assert period in ['train', 'test'], 'period must be train or test.'

        self.auto_norm, self.save2npy = neg_one_to_one, save2npy
        self.data_0, self.data_1, self.scaler = self.read_data(data_root, window)
        # First data_0.shape[0] samples are class 0, the rest class 1.
        self.labels = np.zeros(self.data_0.shape[0] + self.data_1.shape[0]).astype(np.int64)
        self.labels[self.data_0.shape[0]:] = 1
        self.rawdata = np.vstack([self.data_0, self.data_1])
        self.dir = os.path.join(output_dir, 'samples')
        os.makedirs(self.dir, exist_ok=True)

        self.window, self.period = window, period
        self.len, self.var_num = self.rawdata.shape[0], self.rawdata.shape[-1]

        self.samples = self.normalize(self.rawdata)

        # np.save(os.path.join(self.dir, 'eeg_ground_0_truth.npy'), self.data_0)
        # np.save(os.path.join(self.dir, 'eeg_ground_1_truth.npy'), self.data_1)

        self.sample_num = self.samples.shape[0]

    def read_data(self, filepath, length):
        """
        Reads the data from the given filepath, removes outliers, classifies the data into two classes,
        and scales the data using MinMaxScaler.

        Args:
            filepath (str): Path to the .arff file containing the EEG data.
            length (int): Length of the window for classification.

        Returns:
            tuple: (class-0 windows (N0, length, C), class-1 windows
            (N1, length, C), MinMaxScaler fitted on both classes' flattened
            timesteps).
        """
        data = arff.loadarff(filepath)
        df = pd.DataFrame(data[0])
        # eyeDetection arrives as bytes from arff; cast to 0/1 ints.
        df['eyeDetection'] = df['eyeDetection'].astype('int')

        df = self.__OutlierRemoval__(df)
        df_0, df_1 = self.__Classify__(df, length=length)
        # df_0.to_csv('./EEG_Eye_State_0.csv', index=False)
        # df_1.to_csv('./EEG_Eye_State_1.csv', index=False)

        # Each row of df_i holds the channels concatenated horizontally;
        # reshape back to (num_windows, length, num_channels).
        data_0 = df_0.values.reshape(df_0.shape[0], length, -1)
        data_1 = df_1.values.reshape(df_1.shape[0], length, -1)

        # print(f"Class 0: {data_0.shape}, Class 1: {data_1.shape}")

        # Fit the scaler on every timestep of both classes together.
        data = np.vstack([data_0.reshape(-1, data_0.shape[-1]), data_1.reshape(-1, data_1.shape[-1])])

        scaler = MinMaxScaler()
        scaler = scaler.fit(data)

        return data_0, data_1, scaler

    @staticmethod
    def __OutlierRemoval__(df):
        """
        Removes outliers from the dataframe using z-score method and interpolates the missing values.

        Args:
            df (pd.DataFrame): Dataframe containing the EEG data.

        Returns:
            pd.DataFrame: Cleaned dataframe with outliers removed and missing values interpolated.
        """
        temp_data_frame = deepcopy(df)
        clean_data_frame = deepcopy(df)
        # First pass: blank out samples with |z| > 3 per channel (last
        # column, the label, is skipped), then linearly interpolate.
        for column in temp_data_frame.columns[:-1]:
            temp_data_frame[str(column)+'z_score'] = stats.zscore(temp_data_frame[column])
            clean_data_frame[column] = temp_data_frame.loc[temp_data_frame[str(column)+'z_score'].abs()<=3][column]

        clean_data_frame.interpolate(method='linear', inplace=True)

        temp_data_frame = deepcopy(clean_data_frame)
        clean_data_frame_second = deepcopy(clean_data_frame)

        # Second pass repeats the z-score filter on the cleaned frame.
        for column in temp_data_frame.columns[:-1]:
            temp_data_frame[str(column)+'z_score'] = stats.zscore(temp_data_frame[column])
            clean_data_frame_second[column] = temp_data_frame.loc[temp_data_frame[str(column)+'z_score'].abs()<=3][column]

        clean_data_frame_second.interpolate(method='linear', inplace=True)
        # NOTE(review): the second pass builds `clean_data_frame_second`
        # but the FIRST-pass frame is returned, so the second pass has no
        # effect — confirm whether returning `clean_data_frame_second` was
        # intended before changing behavior.
        return clean_data_frame

    @staticmethod
    def __Classify__(df, length=100):
        """
        Classifies the data into two classes based on the eyeDetection column and creates signals for the two classes.

        Args:
            df (pd.DataFrame): Dataframe containing the EEG data.
            length (int): Length of the window for classification.

        Returns:
            pd.DataFrame: Dataframe containing the signals for class 0.
            pd.DataFrame: Dataframe containing the signals for class 1.
        """
        # normalize the columns between -1 and 1
        # (note: `mean`, `max`, `min` intentionally shadow the builtins here)
        mean, max, min = df.mean(), df.max(), df.min()
        df = 2*(df - mean) / (max - min)

        # After the normalisation above, eyeDetection's 0->1 transitions
        # show up in the diff as +2 and 1->0 transitions as -2.
        df['edge'] = df['eyeDetection'].diff()
        # First diff is NaN; chained assignment — newer pandas may warn here.
        df['edge'][0] = 0.0

        starting = df['edge'][df['edge']==2]
        # Drop the final rising edge so starts and ends pair up.
        starting = starting.iloc[:-1]
        starting_time = starting.index.values
        ending = df['edge'][df['edge']==-2]
        ending_time = ending.index.values
        # Prepend 0 so the segment before the first rising edge is usable.
        end_time_with_0 = np.insert(ending_time, 0, 0)

        signal_0 = []
        singal_1 = []

        # create signal 0
        # Slide a window through each eyes-closed segment, trimming 50
        # samples from both segment borders to avoid transition artefacts.
        for start, end in zip(end_time_with_0, starting_time):
            for i in range(start+50, end-length-50, 1):
                temp = []
                # Exclude the last two columns (eyeDetection and edge).
                for channel in df.columns[:-2]:
                    temp.append(df[channel][i:i+length])

                signal_0.append(np.hstack(temp))

        # create signal 1
        for start, end in zip(starting_time, ending_time):
            for i in range(start+50, end-length-50, 1):
                temp = []
                for channel in df.columns[:-2]:
                    temp.append(df[channel][i:i+length])

                singal_1.append(np.hstack(temp))

        df_0 = pd.DataFrame(signal_0)
        df_1 = pd.DataFrame(singal_1)

        # min_samples = min(df_0.shape[0], df_1.shape[0]) - 1
        # # chop the data to the same length
        # df_0 = df_0.iloc[:min_samples, :]
        # df_1 = df_1.iloc[:min_samples, :]

        return df_0, df_1

    def __getitem__(self, ind):
        # Test period also yields the class label for evaluation.
        if self.period == 'test':
            x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
            y = self.labels[ind]  # (1,) int
            return torch.from_numpy(x).float(), torch.tensor(y)
        x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
        return torch.from_numpy(x).float()

    def __len__(self):
        return self.sample_num

    def normalize(self, sq):
        """Scale windows (N, window, C) into the model's input range."""
        d = self.__normalize(sq.reshape(-1, self.var_num))
        data = d.reshape(-1, self.window, self.var_num)
        return data

    def __normalize(self, rawdata):
        # MinMax-scale to [0, 1], then optionally to [-1, 1].
        data = self.scaler.transform(rawdata)
        if self.auto_norm:
            data = normalize_to_neg_one_to_one(data)
        return data

    def unnormalize(self, sq):
        """Inverse of :meth:`normalize`: back to the raw value range."""
        d = self.__unnormalize(sq.reshape(-1, self.var_num))
        return d.reshape(-1, self.window, self.var_num)

    def __unnormalize(self, data):
        if self.auto_norm:
            data = unnormalize_to_zero_to_one(data)
        x = data
        return self.scaler.inverse_transform(x)

    def shift_period(self, period):
        """Switch between 'train' and 'test' __getitem__ behavior in place."""
        assert period in ['train', 'test'], 'period must be train or test.'
        self.period = period
--------------------------------------------------------------------------------
/Utils/Data_utils/mujoco_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | from torch.utils.data import Dataset
6 | from sklearn.preprocessing import MinMaxScaler
7 |
8 | from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one
9 | from Utils.masking_utils import noise_mask
10 |
11 |
class MuJoCoDataset(Dataset):
    """Hopper trajectories rolled out in the DeepMind Control Suite.

    Generates ``num`` windows of length ``window`` with ``dim`` state
    variables (first half qpos, second half qvel), MinMax-scales them and
    optionally maps them to [-1, 1].  In the ``test`` period a boolean
    observation mask is additionally built, either random (imputation,
    ``missing_ratio``) or suffix-zero (forecasting, ``predict_length``).
    """

    def __init__(
        self,
        window=128,
        num=30000,
        dim=12,
        save2npy=True,
        neg_one_to_one=True,
        seed=123,
        scalar=None,
        period='train',
        output_dir='./OUTPUT',
        predict_length=None,
        missing_ratio=None,
        style='separate',
        distribution='geometric',
        mean_mask_length=3
    ):
        """
        Args:
            window (int): sequence length of each trajectory sample.
            num (int): number of trajectories to roll out.
            dim (int): state dimension (qpos + qvel halves).
            save2npy (bool): persist ground-truth/normalised arrays to disk.
            neg_one_to_one (bool): additionally rescale [0, 1] -> [-1, 1].
            seed (int): RNG seed for simulation (and masking in test period).
            scalar: optional pre-fitted scaler overriding the fitted one.
            period (str): 'train' or 'test'.
            predict_length (int|None): forecasting horizon (test only).
            missing_ratio (float|None): imputation missing rate (test only).
            style / distribution / mean_mask_length: noise_mask parameters.
        """
        super(MuJoCoDataset, self).__init__()
        assert period in ['train', 'test'], 'period must be train or test.'
        if period == 'train':
            # Bug fix: the original used bitwise `~` on a bool; ~True == -2
            # and ~False == -1 are both truthy, so the assert never fired.
            assert not (predict_length is not None or missing_ratio is not None), \
                'predict_length/missing_ratio are only valid in the test period.'

        self.window, self.var_num = window, dim
        self.auto_norm = neg_one_to_one
        self.dir = os.path.join(output_dir, 'samples')
        os.makedirs(self.dir, exist_ok=True)
        self.pred_len, self.missing_ratio = predict_length, missing_ratio
        self.style, self.distribution, self.mean_mask_length = style, distribution, mean_mask_length

        self.rawdata, self.scaler = self._generate_random_trajectories(n_samples=num, seed=seed)
        if scalar is not None:
            self.scaler = scalar  # caller-supplied scaler wins

        self.period, self.save2npy = period, save2npy
        self.samples = self.normalize(self.rawdata)
        self.sample_num = self.samples.shape[0]

        if period == 'test':
            if missing_ratio is not None:
                self.masking = self.mask_data(seed)
            elif predict_length is not None:
                # Forecasting: observe the prefix (1), hide the horizon (0).
                masks = np.ones(self.samples.shape)
                masks[:, -predict_length:, :] = 0
                self.masking = masks.astype(bool)
            else:
                raise NotImplementedError()

    def _generate_random_trajectories(self, n_samples, seed=123):
        """Roll out `n_samples` hopper trajectories and fit a MinMaxScaler.

        Raises:
            Exception: when dm_control is not installed.
        """
        try:
            from dm_control import suite  # noqa: F401
        except ImportError as e:
            raise Exception('Deepmind Control Suite is required to generate the dataset.') from e

        env = suite.load('hopper', 'stand')
        physics = env.physics

        # Store the state of the RNG to restore later.
        st0 = np.random.get_state()
        np.random.seed(seed)

        data = np.zeros((n_samples, self.window, self.var_num))
        for i in range(n_samples):
            with physics.reset_context():
                # x and z positions of the hopper. We want z > 0 for the hopper to stay above ground.
                physics.data.qpos[:2] = np.random.uniform(0, 0.5, size=2)
                physics.data.qpos[2:] = np.random.uniform(-2, 2, size=physics.data.qpos[2:].shape)
                physics.data.qvel[:] = np.random.uniform(-5, 5, size=physics.data.qvel.shape)

            # First half of the feature vector is qpos, second half qvel.
            for t in range(self.window):
                data[i, t, :self.var_num // 2] = physics.data.qpos
                data[i, t, self.var_num // 2:] = physics.data.qvel
                physics.step()

        # Restore RNG.
        np.random.set_state(st0)

        scaler = MinMaxScaler()
        scaler = scaler.fit(data.reshape(-1, self.var_num))
        return data, scaler

    def normalize(self, sq):
        """Scale raw windows into the model's input range, saving ground
        truth and normalised arrays when `save2npy` is set."""
        d = self.__normalize(sq.reshape(-1, self.var_num))
        data = d.reshape(-1, self.window, self.var_num)
        if self.save2npy:
            np.save(os.path.join(self.dir, f"mujoco_ground_truth_{self.window}_{self.period}.npy"), sq)

            # The "norm truth" file is always in [0, 1] space, so undo the
            # optional [-1, 1] mapping before saving.
            if self.auto_norm:
                np.save(os.path.join(self.dir, f"mujoco_norm_truth_{self.window}_{self.period}.npy"), unnormalize_to_zero_to_one(data))
            else:
                np.save(os.path.join(self.dir, f"mujoco_norm_truth_{self.window}_{self.period}.npy"), data)

        return data

    def __normalize(self, rawdata):
        # MinMax-scale to [0, 1], then optionally to [-1, 1].
        data = self.scaler.transform(rawdata)
        if self.auto_norm:
            data = normalize_to_neg_one_to_one(data)
        return data

    def unnormalize(self, sq):
        """Inverse of :meth:`normalize`: back to raw simulation units."""
        d = self.__unnormalize(sq.reshape(-1, self.var_num))
        return d.reshape(-1, self.window, self.var_num)

    def __unnormalize(self, data):
        if self.auto_norm:
            data = unnormalize_to_zero_to_one(data)
        x = data
        return self.scaler.inverse_transform(x)

    def mask_data(self, seed=2023):
        """Build per-sample boolean observation masks for imputation."""
        masks = np.ones_like(self.samples)
        # Store the state of the RNG to restore later.
        st0 = np.random.get_state()
        np.random.seed(seed)

        for idx in range(self.samples.shape[0]):
            x = self.samples[idx, :, :]  # (seq_length, feat_dim) array
            mask = noise_mask(x, self.missing_ratio, self.mean_mask_length, self.style,
                              self.distribution)  # (seq_length, feat_dim) boolean array
            masks[idx, :, :] = mask

        if self.save2npy:
            np.save(os.path.join(self.dir, f"mujoco_masking_{self.window}.npy"), masks)

        # Restore RNG.
        np.random.set_state(st0)
        return masks.astype(bool)

    def __getitem__(self, ind):
        # Test period also yields the observation mask.
        if self.period == 'test':
            x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
            m = self.masking[ind, :, :]  # (seq_length, feat_dim) boolean array
            return torch.from_numpy(x).float(), torch.from_numpy(m)
        x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
        return torch.from_numpy(x).float()

    def __len__(self):
        return self.sample_num
151 |
--------------------------------------------------------------------------------
/Utils/Data_utils/real_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from scipy import io
7 | from sklearn.preprocessing import MinMaxScaler
8 | from torch.utils.data import Dataset
9 | from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one
10 | from Utils.masking_utils import noise_mask
11 |
12 |
class CustomDataset(Dataset):
    """Sliding-window dataset over a single real-world csv time series.

    The raw series is MinMax-scaled (optionally to [-1, 1]), cut into all
    overlapping windows of length ``window`` and split into train/test
    parts by ``proportion``.  In the ``test`` period a boolean observation
    mask is built per sample for imputation (``missing_ratio``) or
    forecasting (``predict_length``).
    """

    def __init__(
        self,
        name,
        data_root,
        window=64,
        proportion=0.8,
        save2npy=True,
        neg_one_to_one=True,
        seed=123,
        period='train',
        output_dir='./OUTPUT',
        predict_length=None,
        missing_ratio=None,
        style='separate',
        distribution='geometric',
        mean_mask_length=3
    ):
        """
        Args:
            name (str): dataset name (used in saved file names; 'etth'
                additionally drops the csv's date column).
            data_root (str): path to the csv file.
            window (int): sequence length of each sample.
            proportion (float): train split fraction of the windows.
            save2npy (bool): persist ground-truth/normalised arrays to disk.
            neg_one_to_one (bool): additionally rescale [0, 1] -> [-1, 1].
            seed (int): RNG seed for splitting and masking.
            period (str): 'train' or 'test'.
            predict_length (int|None): forecasting horizon (test only).
            missing_ratio (float|None): imputation missing rate (test only).
            style / distribution / mean_mask_length: noise_mask parameters.
        """
        super(CustomDataset, self).__init__()
        assert period in ['train', 'test'], 'period must be train or test.'
        if period == 'train':
            # Bug fix: the original used bitwise `~` on a bool; ~True == -2
            # and ~False == -1 are both truthy, so the assert never fired.
            assert not (predict_length is not None or missing_ratio is not None), \
                'predict_length/missing_ratio are only valid in the test period.'
        self.name, self.pred_len, self.missing_ratio = name, predict_length, missing_ratio
        self.style, self.distribution, self.mean_mask_length = style, distribution, mean_mask_length
        self.rawdata, self.scaler = self.read_data(data_root, self.name)
        self.dir = os.path.join(output_dir, 'samples')
        os.makedirs(self.dir, exist_ok=True)

        self.window, self.period = window, period
        self.len, self.var_num = self.rawdata.shape[0], self.rawdata.shape[-1]
        # Number of overlapping windows (stride 1).
        self.sample_num_total = max(self.len - self.window + 1, 0)
        self.save2npy = save2npy
        self.auto_norm = neg_one_to_one

        self.data = self.__normalize(self.rawdata)
        train, inference = self.__getsamples(self.data, proportion, seed)

        self.samples = train if period == 'train' else inference
        if period == 'test':
            if missing_ratio is not None:
                self.masking = self.mask_data(seed)
            elif predict_length is not None:
                # Forecasting: observe the prefix (1), hide the horizon (0).
                masks = np.ones(self.samples.shape)
                masks[:, -predict_length:, :] = 0
                self.masking = masks.astype(bool)
            else:
                raise NotImplementedError()
        self.sample_num = self.samples.shape[0]

    def __getsamples(self, data, proportion, seed):
        """Cut the normalised series into windows and split train/test,
        optionally saving ground-truth and normalised arrays."""
        x = np.zeros((self.sample_num_total, self.window, self.var_num))
        for i in range(self.sample_num_total):
            start = i
            end = i + self.window
            x[i, :, :] = data[start:end, :]

        train_data, test_data = self.divide(x, proportion, seed)

        if self.save2npy:
            # Only write test files when a test split actually exists.
            if 1 - proportion > 0:
                np.save(os.path.join(self.dir, f"{self.name}_ground_truth_{self.window}_test.npy"), self.unnormalize(test_data))
            np.save(os.path.join(self.dir, f"{self.name}_ground_truth_{self.window}_train.npy"), self.unnormalize(train_data))
            # The "norm truth" files are always in [0, 1] space.
            if self.auto_norm:
                if 1 - proportion > 0:
                    np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_test.npy"), unnormalize_to_zero_to_one(test_data))
                np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_train.npy"), unnormalize_to_zero_to_one(train_data))
            else:
                if 1 - proportion > 0:
                    np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_test.npy"), test_data)
                np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_train.npy"), train_data)

        return train_data, test_data

    def normalize(self, sq):
        """Scale windows (N, window, C) into the model's input range."""
        d = sq.reshape(-1, self.var_num)
        d = self.scaler.transform(d)
        if self.auto_norm:
            d = normalize_to_neg_one_to_one(d)
        return d.reshape(-1, self.window, self.var_num)

    def unnormalize(self, sq):
        """Inverse of :meth:`normalize`: back to raw value range."""
        d = self.__unnormalize(sq.reshape(-1, self.var_num))
        return d.reshape(-1, self.window, self.var_num)

    def __normalize(self, rawdata):
        # MinMax-scale to [0, 1], then optionally to [-1, 1].
        data = self.scaler.transform(rawdata)
        if self.auto_norm:
            data = normalize_to_neg_one_to_one(data)
        return data

    def __unnormalize(self, data):
        if self.auto_norm:
            data = unnormalize_to_zero_to_one(data)
        x = data
        return self.scaler.inverse_transform(x)

    @staticmethod
    def divide(data, ratio, seed=2023):
        """Split `data` along axis 0 into the leading `ratio` fraction and
        the remainder (sequential, not shuffled)."""
        size = data.shape[0]
        # Store the state of the RNG to restore later.
        st0 = np.random.get_state()
        np.random.seed(seed)

        regular_train_num = int(np.ceil(size * ratio))
        # id_rdm = np.random.permutation(size)
        id_rdm = np.arange(size)
        regular_train_id = id_rdm[:regular_train_num]
        irregular_train_id = id_rdm[regular_train_num:]

        regular_data = data[regular_train_id, :]
        irregular_data = data[irregular_train_id, :]

        # Restore RNG.
        np.random.set_state(st0)
        return regular_data, irregular_data

    @staticmethod
    def read_data(filepath, name=''):
        """Read a single .csv and fit a MinMaxScaler on its values.

        Returns:
            tuple: (raw ndarray (T, C), fitted MinMaxScaler).
        """
        df = pd.read_csv(filepath, header=0)
        if name == 'etth':
            # ETTh csv's first column is the timestamp; drop it.
            df.drop(df.columns[0], axis=1, inplace=True)
        data = df.values
        scaler = MinMaxScaler()
        scaler = scaler.fit(data)
        return data, scaler

    def mask_data(self, seed=2023):
        """Build per-sample boolean observation masks for imputation."""
        masks = np.ones_like(self.samples)
        # Store the state of the RNG to restore later.
        st0 = np.random.get_state()
        np.random.seed(seed)

        for idx in range(self.samples.shape[0]):
            x = self.samples[idx, :, :]  # (seq_length, feat_dim) array
            mask = noise_mask(x, self.missing_ratio, self.mean_mask_length, self.style,
                              self.distribution)  # (seq_length, feat_dim) boolean array
            masks[idx, :, :] = mask

        if self.save2npy:
            np.save(os.path.join(self.dir, f"{self.name}_masking_{self.window}.npy"), masks)

        # Restore RNG.
        np.random.set_state(st0)
        return masks.astype(bool)

    def __getitem__(self, ind):
        # Test period also yields the observation mask.
        if self.period == 'test':
            x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
            m = self.masking[ind, :, :]  # (seq_length, feat_dim) boolean array
            return torch.from_numpy(x).float(), torch.from_numpy(m)
        x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
        return torch.from_numpy(x).float()

    def __len__(self):
        return self.sample_num
170 |
171 |
class fMRIDataset(CustomDataset):
    """CustomDataset variant backed by the simulated fMRI `sim4.mat` file."""

    def __init__(
        self,
        proportion=1.,
        **kwargs
    ):
        # Default to using the whole series for training (proportion = 1).
        super().__init__(proportion=proportion, **kwargs)

    @staticmethod
    def read_data(filepath, name=''):
        """Load the fMRI time-series matrix and fit a MinMaxScaler on it.

        Args:
            filepath (str): directory containing `sim4.mat`.
            name (str): unused; kept for signature compatibility.

        Returns:
            tuple: (raw ndarray (T, C), fitted MinMaxScaler).
        """
        ts = io.loadmat(filepath + '/sim4.mat')['ts']
        scaler = MinMaxScaler().fit(ts)
        return ts, scaler
188 |
--------------------------------------------------------------------------------
/Utils/Data_utils/sine_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | from tqdm.auto import tqdm
6 | from torch.utils.data import Dataset
7 |
8 | from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one
9 | from Utils.masking_utils import noise_mask
10 |
11 |
class SineDataset(Dataset):
    """Synthetic multivariate sine-wave dataset.

    Every sample is a (window, dim) array; each feature channel is a sine wave
    with an independently drawn frequency and phase, rescaled to [0, 1] during
    generation and optionally mapped to [-1, 1] (``neg_one_to_one``).

    In 'test' mode exactly one of ``missing_ratio`` (imputation) or
    ``predict_length`` (forecasting) must be provided; a boolean observation
    mask is built accordingly (True = observed, False = to be predicted).
    """

    def __init__(
        self,
        window=128,
        num=30000,
        dim=12,
        save2npy=True,
        neg_one_to_one=True,
        seed=123,
        period='train',
        output_dir='./OUTPUT',
        predict_length=None,
        missing_ratio=None,
        style='separate',
        distribution='geometric',
        mean_mask_length=3
    ):
        super(SineDataset, self).__init__()
        assert period in ['train', 'test'], 'period must be train or test.'
        if period == 'train':
            # BUG FIX: the original used bitwise `~` on a bool; ~True == -2 and
            # ~False == -1 are both truthy, so the assertion could never fail.
            assert not (predict_length is not None or missing_ratio is not None), ''

        self.pred_len, self.missing_ratio = predict_length, missing_ratio
        self.style, self.distribution, self.mean_mask_length = style, distribution, mean_mask_length

        self.dir = os.path.join(output_dir, 'samples')
        os.makedirs(self.dir, exist_ok=True)

        self.rawdata = self.sine_data_generation(no=num, seq_len=window, dim=dim, save2npy=save2npy,
                                                 seed=seed, dir=self.dir, period=period)
        self.auto_norm = neg_one_to_one
        self.samples = self.normalize(self.rawdata)
        self.var_num = dim
        self.sample_num = self.samples.shape[0]
        self.window = window

        self.period, self.save2npy = period, save2npy
        if period == 'test':
            if missing_ratio is not None:
                self.masking = self.mask_data(seed)
            elif predict_length is not None:
                # Forecasting: observe everything but the last `predict_length` steps.
                masks = np.ones(self.samples.shape)
                masks[:, -predict_length:, :] = 0
                self.masking = masks.astype(bool)
            else:
                raise NotImplementedError()

    def normalize(self, rawdata):
        """Map [0, 1] data to [-1, 1] when auto_norm is on; otherwise pass through.

        BUG FIX: the original returned an unbound local (UnboundLocalError)
        whenever ``neg_one_to_one`` was False.
        """
        if self.auto_norm:
            return normalize_to_neg_one_to_one(rawdata)
        return rawdata

    def unnormalize(self, data):
        """Inverse of `normalize`: map [-1, 1] back to [0, 1] when auto_norm is on.

        BUG FIX: same unbound-local issue as `normalize` when auto_norm is off.
        """
        if self.auto_norm:
            return unnormalize_to_zero_to_one(data)
        return data

    @staticmethod
    def sine_data_generation(no, seq_len, dim, save2npy=True, seed=123, dir="./", period='train'):
        """Sine data generation.

        Args:
            - no: the number of samples
            - seq_len: sequence length of the time-series
            - dim: feature dimensions
            - save2npy: whether to dump the generated array to disk
            - seed: RNG seed (the global NumPy state is saved and restored)
            - dir: output directory for the .npy dump
            - period: 'train' or 'test'; only used in the dump filename

        Returns:
            - data: (no, seq_len, dim) array of sine waves scaled to [0, 1]
        """
        # Store the state of the RNG to restore later.
        st0 = np.random.get_state()
        np.random.seed(seed)

        data = list()
        for i in tqdm(range(0, no), total=no, desc="Sampling sine-dataset"):
            temp = list()
            for k in range(dim):
                # Randomly drawn frequency and phase per feature channel.
                freq = np.random.uniform(0, 0.1)
                phase = np.random.uniform(0, 0.1)

                temp_data = [np.sin(freq * j + phase) for j in range(seq_len)]
                temp.append(temp_data)

            # (dim, seq_len) -> (seq_len, dim)
            temp = np.transpose(np.asarray(temp))
            # Rescale from [-1, 1] to [0, 1].
            temp = (temp + 1)*0.5
            data.append(temp)

        # Restore RNG.
        np.random.set_state(st0)
        data = np.array(data)
        if save2npy:
            np.save(os.path.join(dir, f"sine_ground_truth_{seq_len}_{period}.npy"), data)

        return data

    def mask_data(self, seed=2023):
        """Build a per-sample boolean observation mask for imputation tests."""
        masks = np.ones_like(self.samples)
        # Store the state of the RNG to restore later.
        st0 = np.random.get_state()
        np.random.seed(seed)

        for idx in range(self.samples.shape[0]):
            x = self.samples[idx, :, :]  # (seq_length, feat_dim) array
            mask = noise_mask(x, self.missing_ratio, self.mean_mask_length, self.style,
                              self.distribution)  # (seq_length, feat_dim) boolean array
            masks[idx, :, :] = mask

        if self.save2npy:
            np.save(os.path.join(self.dir, f"sine_masking_{self.window}.npy"), masks)

        # Restore RNG.
        np.random.set_state(st0)
        return masks.astype(bool)

    def __getitem__(self, ind):
        if self.period == 'test':
            x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
            m = self.masking[ind, :, :]  # (seq_length, feat_dim) boolean array
            return torch.from_numpy(x).float(), torch.from_numpy(m)
        x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
        return torch.from_numpy(x).float()

    def __len__(self):
        return self.sample_num
145 |
--------------------------------------------------------------------------------
/Utils/context_fid.py:
--------------------------------------------------------------------------------
1 | import scipy
2 | import numpy as np
3 |
4 | from Models.ts2vec.ts2vec import TS2Vec
5 |
6 |
def calculate_fid(act1, act2):
    """Frechet distance between two sets of activation vectors.

    Both inputs are (n_samples, n_features) arrays. Returns the scalar
    ||mu1 - mu2||^2 + Tr(S1 + S2 - 2*sqrt(S1 @ S2)).
    """
    mu1, mu2 = act1.mean(axis=0), act2.mean(axis=0)
    sigma1 = np.cov(act1, rowvar=False)
    sigma2 = np.cov(act2, rowvar=False)
    # Squared distance between the two means.
    mean_term = np.sum((mu1 - mu2)**2.0)
    covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
    # sqrtm can produce tiny imaginary components from numerical noise.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    cov_term = np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return mean_term + cov_term
21 |
def Context_FID(ori_data, generated_data):
    """Contextual FID: embed both sets with a freshly trained TS2Vec encoder,
    then compute the FID between the two embedding clouds.
    """
    encoder = TS2Vec(input_dims=ori_data.shape[-1], device=0, batch_size=8, lr=0.001, output_dims=320,
                     max_train_length=3000)
    encoder.fit(ori_data, verbose=False)
    real_emb = encoder.encode(ori_data, encoding_window='full_series')
    fake_emb = encoder.encode(generated_data, encoding_window='full_series')
    # NOTE(review): both sets are shuffled with the SAME permutation drawn from
    # ori_data's length. FID statistics are permutation-invariant, and this
    # would index out of range if generated_data had fewer samples -- verify
    # callers always pass equally sized sets.
    order = np.random.permutation(ori_data.shape[0])
    return calculate_fid(real_emb[order], fake_emb[order])
--------------------------------------------------------------------------------
/Utils/cross_correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
def cacf_torch(x, max_lag, dim=(0, 1)):
    """Cross/auto-correlation features of a batch of multivariate series.

    Args:
        x: assumed (batch, seq_len, n_features) tensor -- TODO confirm layout.
        max_lag: number of lags (0 .. max_lag-1) to evaluate.
        dim: dims over which mean/std are taken for standardization.

    Returns:
        (batch, max_lag, n_pairs) tensor, where n_pairs is the number of
        lower-triangular feature index pairs (diagonal included).
    """
    # All (i, j) feature pairs with i >= j, as two parallel index lists.
    def get_lower_triangular_indices(n):
        return [list(x) for x in torch.tril_indices(n, n)]

    ind = get_lower_triangular_indices(x.shape[2])
    # Standardize each feature with statistics over `dim`.
    # NOTE(review): `keepdims` is the NumPy-style alias; older torch versions
    # only accept `keepdim` -- confirm the pinned torch supports it.
    x = (x - x.mean(dim, keepdims=True)) / x.std(dim, keepdims=True)
    # Left/right series of every index pair.
    x_l = x[..., ind[0]]
    x_r = x[..., ind[1]]
    cacf_list = list()
    for i in range(max_lag):
        # Lag-i product: x_l advanced by i steps against x_r's start.
        y = x_l[:, i:] * x_r[:, :-i] if i > 0 else x_l * x_r
        cacf_i = torch.mean(y, (1))  # average over the time axis
        cacf_list.append(cacf_i)
    cacf = torch.cat(cacf_list, 1)
    return cacf.reshape(cacf.shape[0], -1, len(ind[0]))
20 |
21 |
class Loss(nn.Module):
    """Base class for regularized, thresholded metric losses.

    Subclasses implement `compute`, returning a componentwise loss tensor;
    `forward` caches that tensor and returns `reg * mean`. The `success`
    property reports whether every component is within `threshold`.
    """

    def __init__(self, name, reg=1.0, transform=lambda x: x, threshold=10., backward=False, norm_foo=lambda x: x):
        super(Loss, self).__init__()
        self.name = name
        self.reg = reg                # scalar weight on the mean loss
        self.transform = transform    # input pre-transform used by subclasses
        self.threshold = threshold    # per-component success threshold
        self.backward = backward
        self.norm_foo = norm_foo      # norm applied by subclasses

    def forward(self, x_fake):
        # Cache the componentwise loss so `success` can inspect it later.
        self.loss_componentwise = self.compute(x_fake)
        return self.reg * self.loss_componentwise.mean()

    def compute(self, x_fake):
        # Must be provided by concrete subclasses.
        raise NotImplementedError()

    @property
    def success(self):
        return torch.all(self.loss_componentwise <= self.threshold)
42 |
43 |
class CrossCorrelLoss(Loss):
    """Penalizes deviation of lag-0 cross-correlation structure from x_real's."""

    def __init__(self, x_real, **kwargs):
        super(CrossCorrelLoss, self).__init__(norm_foo=lambda x: torch.abs(x).sum(0), **kwargs)
        # Target statistic: batch-averaged lag-0 correlations of the real data.
        self.cross_correl_real = cacf_torch(self.transform(x_real), 1).mean(0)[0]

    def compute(self, x_fake):
        fake_stat = cacf_torch(self.transform(x_fake), 1).mean(0)[0]
        deviation = self.norm_foo(fake_stat - self.cross_correl_real.to(x_fake.device))
        return deviation / 10.
--------------------------------------------------------------------------------
/Utils/discriminative_metric.py:
--------------------------------------------------------------------------------
1 | """Reimplement TimeGAN-pytorch Codebase.
2 |
3 | Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
4 | "Time-series Generative Adversarial Networks,"
5 | Neural Information Processing Systems (NeurIPS), 2019.
6 |
7 | Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks
8 |
9 | Last updated Date: October 18th 2021
10 | Code author: Zhiwei Zhang (bitzzw@gmail.com)
11 |
12 | -----------------------------
13 |
14 | predictive_metrics.py
15 |
16 | Note: Use post-hoc RNN to classify original data and synthetic data
17 |
18 | Output: discriminative score (np.abs(classification accuracy - 0.5))
19 | """
20 |
21 | # Necessary Packages
22 | import tensorflow as tf
23 | import tensorflow._api.v2.compat.v1 as tf1
24 | import numpy as np
25 | from sklearn.metrics import accuracy_score
26 | from Utils.metric_utils import train_test_divide, extract_time
27 |
28 |
def batch_generator(data, time, batch_size):
    """Mini-batch generator.

    Args:
      - data: time-series data
      - time: time information
      - batch_size: the number of samples in each batch

    Returns:
      - X_mb: time-series data in each batch
      - T_mb: time information in each batch
    """
    # Draw batch_size indices without replacement (all of them if fewer exist).
    chosen = np.random.permutation(len(data))[:batch_size]
    X_mb = [data[i] for i in chosen]
    T_mb = [time[i] for i in chosen]
    return X_mb, T_mb
49 |
50 |
def discriminative_score_metrics(ori_data, generated_data):
    """Use post-hoc RNN to classify original data and synthetic data.

    Trains a single-GRU discriminator to separate real from generated
    sequences; a score close to 0 means the two are hard to distinguish.

    Args:
      - ori_data: original data, array-like of shape (no, seq_len, dim)
      - generated_data: generated synthetic data of the same shape

    Returns:
      - discriminative_score: np.abs(classification accuracy - 0.5)
      - fake_acc: accuracy on the synthetic half of the test set
      - real_acc: accuracy on the original half of the test set
    """
    # Initialization on the Graph
    tf1.reset_default_graph()

    # Basic Parameters
    no, seq_len, dim = np.asarray(ori_data).shape

    # Set maximum sequence length and each sequence length
    ori_time, ori_max_seq_len = extract_time(ori_data)
    # BUG FIX: the original computed generated_time from `ori_data`, so the
    # discriminator was fed wrong sequence lengths for the synthetic batch.
    generated_time, generated_max_seq_len = extract_time(generated_data)
    max_seq_len = max([ori_max_seq_len, generated_max_seq_len])

    ## Build a post-hoc RNN discriminator network
    # Network parameters
    hidden_dim = int(dim/2)
    iterations = 2000
    batch_size = 128

    # Input place holders
    # Feature
    X = tf1.placeholder(tf.float32, [None, max_seq_len, dim], name = "myinput_x")
    X_hat = tf1.placeholder(tf.float32, [None, max_seq_len, dim], name = "myinput_x_hat")

    T = tf1.placeholder(tf.int32, [None], name = "myinput_t")
    T_hat = tf1.placeholder(tf.int32, [None], name = "myinput_t_hat")

    # discriminator function
    def discriminator(x, t):
        """Single-GRU discriminator.

        Args:
          - x: time-series data
          - t: time information

        Returns:
          - y_hat_logit: logits of the discriminator output
          - y_hat: discriminator output (sigmoid probability)
          - d_vars: discriminator variables
        """
        with tf1.variable_scope("discriminator", reuse = tf1.AUTO_REUSE) as vs:
            d_cell = tf1.nn.rnn_cell.GRUCell(num_units=hidden_dim, activation=tf.nn.tanh, name = 'd_cell')
            d_outputs, d_last_states = tf1.nn.dynamic_rnn(d_cell, x, dtype=tf.float32, sequence_length = t)
            y_hat_logit = tf1.layers.dense(d_last_states, 1, activation=None)
            y_hat = tf.nn.sigmoid(y_hat_logit)
            d_vars = [v for v in tf1.all_variables() if v.name.startswith(vs.name)]

        return y_hat_logit, y_hat, d_vars

    y_logit_real, y_pred_real, d_vars = discriminator(X, T)
    y_logit_fake, y_pred_fake, _ = discriminator(X_hat, T_hat)

    # Loss for the discriminator: real labelled 1, synthetic labelled 0.
    d_loss_real = tf1.reduce_mean(tf1.nn.sigmoid_cross_entropy_with_logits(logits = y_logit_real,
                                                                           labels = tf1.ones_like(y_logit_real)))
    d_loss_fake = tf1.reduce_mean(tf1.nn.sigmoid_cross_entropy_with_logits(logits = y_logit_fake,
                                                                           labels = tf1.zeros_like(y_logit_fake)))
    d_loss = d_loss_real + d_loss_fake

    # optimizer
    d_solver = tf1.train.AdamOptimizer().minimize(d_loss, var_list = d_vars)

    ## Train the discriminator
    # Start session and initialize
    sess = tf1.Session()
    sess.run(tf1.global_variables_initializer())

    # Train/test division for both original and generated data
    train_x, train_x_hat, test_x, test_x_hat, train_t, train_t_hat, test_t, test_t_hat = \
        train_test_divide(ori_data, generated_data, ori_time, generated_time)

    from tqdm.auto import tqdm

    # Training step
    for itt in tqdm(range(iterations), desc='training', total=iterations):

        # Batch setting
        X_mb, T_mb = batch_generator(train_x, train_t, batch_size)
        X_hat_mb, T_hat_mb = batch_generator(train_x_hat, train_t_hat, batch_size)

        # Train discriminator
        _, step_d_loss = sess.run([d_solver, d_loss],
                                  feed_dict={X: X_mb, T: T_mb, X_hat: X_hat_mb, T_hat: T_hat_mb})

    ## Test the performance on the testing set
    y_pred_real_curr, y_pred_fake_curr = sess.run([y_pred_real, y_pred_fake],
                                                  feed_dict={X: test_x, T: test_t, X_hat: test_x_hat, T_hat: test_t_hat})
    # FIX: release TF resources instead of leaking the session.
    sess.close()

    y_pred_final = np.squeeze(np.concatenate((y_pred_real_curr, y_pred_fake_curr), axis = 0))
    y_label_final = np.concatenate((np.ones([len(y_pred_real_curr),]), np.zeros([len(y_pred_fake_curr),])), axis = 0)

    # Compute the accuracy
    acc = accuracy_score(y_label_final, (y_pred_final>0.5))

    fake_acc = accuracy_score(np.zeros([len(y_pred_fake_curr),]), (y_pred_fake_curr>0.5))
    # BUG FIX: the label vector for real_acc was sized by len(y_pred_fake_curr);
    # sklearn raises when the two test halves differ in size.
    real_acc = accuracy_score(np.ones([len(y_pred_real_curr),]), (y_pred_real_curr>0.5))

    discriminative_score = np.abs(0.5-acc)
    return discriminative_score, fake_acc, real_acc
161 |
--------------------------------------------------------------------------------
/Utils/imputation_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | import matplotlib.pyplot as plt
5 |
6 | from torch import nn
7 |
8 |
def get_quantile(samples, q, dim=1):
    """q-th quantile of `samples` along `dim`, returned as a NumPy array."""
    quantile = torch.quantile(samples, q, dim=dim)
    return quantile.cpu().numpy()
11 |
def plot_sample(ori_data, gen_data, masks, sample_idx=0):
    """Plot, per feature of one sample, the observed points, the masked ground
    truth, and the generated median with a 5%-95% band.

    Args:
        ori_data: (sample_num, seq_len, feat_dim) ground-truth array.
        gen_data: stack of generated imputations; quantiles taken over dim 0.
        masks: same shape as ori_data; 1 = observed, 0 = missing/imputed.
        sample_idx: which sample to plot.
    """
    plt.rcParams["font.size"] = 12
    sample_num, seq_len, feat_dim = ori_data.shape
    ncols = 4
    # BUG FIX / generalization: the grid was hard-coded to 7x4, which only fits
    # exactly 28 features; size it from feat_dim instead. squeeze=False keeps
    # axes 2-D even for a single row.
    nrows = max(1, (feat_dim + ncols - 1) // ncols)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 2 * nrows + 1), squeeze=False)
    observed = ori_data * masks

    # Median and 5%/95% quantiles over generated draws, spliced with observations.
    quantiles = []
    quantiles.append(get_quantile(torch.from_numpy(gen_data), 0.5, dim=0) * (1 - masks) + observed)
    quantiles.append(get_quantile(torch.from_numpy(gen_data), 0.05, dim=0) * (1 - masks) + observed)
    quantiles.append(get_quantile(torch.from_numpy(gen_data), 0.95, dim=0) * (1 - masks) + observed)

    for feat_idx in range(feat_dim):
        row = feat_idx // ncols
        col = feat_idx % ncols

        # Observed (kept) ground-truth points -- drawn as red crosses below.
        df_x = pd.DataFrame({"x": np.arange(0, seq_len), "val": ori_data[sample_idx, :, feat_idx],
                             "y": masks[sample_idx, :, feat_idx]})
        df_x = df_x[df_x.y != 0]

        # Masked (to-be-imputed) ground-truth points -- drawn as blue circles.
        df_o = pd.DataFrame({"x": np.arange(0, seq_len), "val": ori_data[sample_idx, :, feat_idx],
                             "y": (1 - masks)[sample_idx, :, feat_idx]})
        df_o = df_o[df_o.y != 0]

        axes[row][col].plot(range(0, seq_len), quantiles[0][sample_idx, :, feat_idx], color='g', linestyle='solid', label='Diffusion-TS')
        axes[row][col].fill_between(range(0, seq_len), quantiles[1][sample_idx, :, feat_idx],
                                    quantiles[2][sample_idx, :, feat_idx], color='g', alpha=0.3)

        axes[row][col].plot(df_o.x, df_o.val, color='b', marker='o', linestyle='None')
        axes[row][col].plot(df_x.x, df_x.val, color='r', marker='x', linestyle='None')

        if col == 0:
            plt.setp(axes[row, 0], ylabel='value')
        # BUG FIX: the original tested `row == -1`, which is never true (row is a
        # floor division result >= 0), so the x-axis label was never drawn.
        if row == nrows - 1:
            plt.setp(axes[-1, col], xlabel='time')
    plt.tight_layout()
    plt.show()
48 |
49 |
class MaskedLoss(nn.Module):
    """ Masked MSE Loss
    """

    def __init__(self, reduction: str = 'mean', mode='mse'):
        super().__init__()
        self.reduction = reduction
        # 'mse' selects squared error; any other mode falls back to L1.
        self.loss = nn.MSELoss(reduction=reduction) if mode == 'mse' else nn.L1Loss(reduction=reduction)

    def forward(self,
                y_pred: torch.Tensor, y_true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        """Compute the loss between a target value and a prediction.

        Args:
            y_pred: Estimated values
            y_true: Target values
            mask: boolean tensor; True where values count, False where ignored

        Returns
        -------
        if reduction == 'none':
            (num_active,) loss for each active element, with gradient attached.
        if reduction == 'mean':
            scalar mean loss over the active elements, with gradient attached.
        """
        # Keep only the active positions before applying the criterion.
        active_pred = torch.masked_select(y_pred, mask)
        active_true = torch.masked_select(y_true, mask)
        return self.loss(active_pred, active_true)
86 |
def random_mask(observed_values, missing_ratio=0.1, seed=1984):
    """Hide a random fraction of the observed entries for imputation evaluation.

    Args:
        observed_values: array possibly containing NaNs (NaN = never observed).
        missing_ratio: fraction of observed entries to hide as ground truth.
        seed: seed for the hidden-entry draw (global RNG state is restored).

    Returns:
        (values, observed_mask, gt_mask) float tensors; gt_mask marks the
        entries the model is allowed to see (observed minus hidden ones).
    """
    observed_masks = ~np.isnan(observed_values)

    # Flatten so hidden positions can be sampled among observed entries only.
    flat_mask = observed_masks.reshape(-1).copy()
    obs_indices = np.where(flat_mask)[0].tolist()

    # Store the state of the RNG to restore later.
    st0 = np.random.get_state()
    np.random.seed(seed)

    hidden = np.random.choice(obs_indices, int(len(obs_indices) * missing_ratio), replace=False)

    # Restore RNG.
    np.random.set_state(st0)

    flat_mask[hidden] = False
    gt_masks = flat_mask.reshape(observed_masks.shape)

    clean_values = np.nan_to_num(observed_values)
    return (torch.from_numpy(clean_values).float(),
            torch.from_numpy(observed_masks).float(),
            torch.from_numpy(gt_masks).float())
--------------------------------------------------------------------------------
/Utils/io_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import yaml
4 | import json
5 | import torch
6 | import random
7 | import warnings
8 | import importlib
9 | import numpy as np
10 |
11 |
def load_yaml_config(path):
    """Parse the YAML file at `path` and return the resulting object."""
    with open(path) as stream:
        return yaml.full_load(stream)
16 |
def save_config_to_yaml(config, path):
    """Dump `config` as YAML to `path` (which must end in '.yaml').

    Fix: removed the redundant f.close() that ran inside the `with` block.
    """
    assert path.endswith('.yaml')
    with open(path, 'w') as f:
        f.write(yaml.dump(config))
22 |
def save_dict_to_json(d, path, indent=None):
    """Write `d` to `path` as JSON.

    Fix: the original opened the file without ever closing it (handle leak);
    use a context manager and stream directly into the file object.
    """
    with open(path, 'w') as f:
        json.dump(d, f, indent=indent)
25 |
def load_dict_from_json(path):
    """Read and return the JSON document stored at `path`.

    Fix: the original opened the file without closing it (handle leak).
    """
    with open(path, 'r') as f:
        return json.load(f)
28 |
def write_args(args, path):
    """Append a human-readable dump of `args` (plus torch/cudnn versions and
    the command line) to the file at `path`.

    Fix: removed the redundant close() call inside the `with` block.
    """
    # Public attributes only (skip dunders/privates).
    args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith('_'))
    with open(path, 'a') as args_file:
        args_file.write('==> torch version: {}\n'.format(torch.__version__))
        args_file.write('==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
        args_file.write('==> Cmd:\n')
        args_file.write(str(sys.argv))
        args_file.write('\n==> args:\n')
        for k, v in sorted(args_dict.items()):
            args_file.write(' %s: %s\n' % (str(k), str(v)))
40 |
def seed_everything(seed, cudnn_deterministic=False):
    """Seed python.random, NumPy, and PyTorch (CPU + all CUDA devices) RNGs.

    Args:
        seed: integer seed; when None the RNG state is left untouched.
        cudnn_deterministic: also force deterministic CUDNN kernels (slower).
    """
    if seed is not None:
        print(f"Global seed set to {seed}")
        for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
            seeder(seed)
        torch.backends.cudnn.deterministic = False

    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
64 |
def merge_opts_to_config(config, opts):
    """Apply command-line overrides to a nested config dict.

    `opts` is a flat [name, value, name, value, ...] list where each name is a
    dot-separated path into `config`; the string value is cast to the type of
    the entry it replaces.
    """
    def set_by_path(node, path, value):
        key = path[0]
        if len(path) == 1:
            # Cast to the existing entry's type so ints/floats stay typed.
            node[key] = type(node[key])(value)
        else:
            node[key] = set_by_path(node[key], path[1:], value)
        return node

    if opts is not None and len(opts) > 0:
        assert len(opts) % 2 == 0, "each opts should be given by the name and values! The length shall be even number!"
        for name, value in zip(opts[0::2], opts[1::2]):
            config = set_by_path(config, name.split('.'), value)
    return config
81 |
def modify_config_for_debug(config):
    """Shrink dataloader settings (no workers, batch of 1) for quick debugging."""
    config['dataloader'].update(num_workers=0, batch_size=1)
    return config
86 |
def get_model_parameters_info(model):
    """Count trainable/non-trainable parameters per top-level child of `model`
    and overall, returning human-readable strings ('1.5M', '32K', ...).
    """
    stats = {'overall': {'trainable': 0, 'non_trainable': 0, 'total': 0}}
    for child_name, child in model.named_children():
        child_stats = {'trainable': 0, 'non_trainable': 0}
        for _, param in child.named_parameters():
            bucket = 'trainable' if param.requires_grad else 'non_trainable'
            child_stats[bucket] += param.numel()
        child_stats['total'] = child_stats['trainable'] + child_stats['non_trainable']
        stats[child_name] = child_stats

        stats['overall']['trainable'] += child_stats['trainable']
        stats['overall']['non_trainable'] += child_stats['non_trainable']
        stats['overall']['total'] += child_stats['total']

    def format_number(num):
        # Binary thresholds (2**10 = K, ...), strictly greater-than, matching
        # the original behaviour.
        for bound, unit in ((2**30, 'G'), (2**20, 'M'), (2**10, 'K')):
            if num > bound:
                return '{}{}'.format(round(float(num) / bound, 2), unit)
        return '{}'.format(num)

    def humanize(d):
        # Recursively replace every leaf count with its formatted string.
        for key, value in d.items():
            if isinstance(value, dict):
                humanize(value)
            else:
                d[key] = format_number(value)

    humanize(stats)
    return stats
131 |
def format_seconds(seconds):
    """Render a duration in seconds as 'Dd:HHh:MMm:SSs', dropping leading zero units."""
    total_hours = int(seconds // 3600)
    minutes = int(seconds // 60 - total_hours * 60)
    secs = int(seconds % 60)
    days, hours = divmod(total_hours, 24)

    # Emit only the largest non-zero unit and everything below it.
    if days:
        return '{:d}d:{:02d}h:{:02d}m:{:02d}s'.format(days, hours, minutes, secs)
    if hours:
        return '{:02d}h:{:02d}m:{:02d}s'.format(hours, minutes, secs)
    if minutes:
        return '{:02d}m:{:02d}s'.format(minutes, secs)
    return '{:02d}s'.format(secs)
153 |
def instantiate_from_config(config):
    """Instantiate an object described by {'target': 'pkg.mod.Class', 'params': {...}}.

    Returns:
        None when `config` is None, otherwise target(**params).
    Raises:
        KeyError: when the 'target' key is absent.
    """
    if config is None:
        return None
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    module_name, class_name = config["target"].rsplit(".", 1)
    target_cls = getattr(importlib.import_module(module_name, package=None), class_name)
    kwargs = config.get("params", dict())
    return target_cls(**kwargs)
162 |
def class_from_string(class_name):
    """Resolve a dotted path 'pkg.mod.Cls' to the class object itself."""
    module_name, attr = class_name.rsplit(".", 1)
    return getattr(importlib.import_module(module_name, package=None), attr)
167 |
def get_all_file(dir, end_with='.h5'):
    """Recursively collect file paths under `dir` whose names end with any of
    the given suffixes (`end_with` may be a string or an iterable of strings).

    Note: the parameter name `dir` shadows the builtin but is kept for
    interface compatibility with existing keyword callers.
    """
    suffixes = [end_with] if isinstance(end_with, str) else end_with
    matches = []
    for root, _, files in os.walk(dir):
        # Each file is appended at most once, even if several suffixes match.
        matches.extend(os.path.join(root, name)
                       for name in files
                       if any(name.endswith(sfx) for sfx in suffixes))
    return matches
179 |
def get_sub_dirs(dir, abs=True):
    """List the immediate entries of `dir` (files AND directories, despite the
    name); joined with `dir` when `abs` is True, bare names otherwise.
    """
    entries = os.listdir(dir)
    return [os.path.join(dir, e) for e in entries] if abs else entries
185 |
def get_model_buffer(model):
    """Return the non-parameter entries (registered buffers) of `model`'s state_dict."""
    param_names = {name for name, _ in model.named_parameters()}
    return {key: tensor for key, tensor in model.state_dict().items() if key not in param_names}
195 |
--------------------------------------------------------------------------------
/Utils/masking_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 |
5 |
def costume_collate(data, max_len=None, mask_compensation=False):
    """Build mini-batch tensors from a list of (X, mask) tuples. Mask input. Create
    Args:
        data: len(batch_size) list of tuples (X, mask).
            - X: torch tensor of shape (seq_length, feat_dim); variable seq_length.
            - mask: boolean torch tensor of shape (seq_length, feat_dim); variable seq_length.
        max_len: global fixed sequence length. Used for architectures requiring fixed length input,
            where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are padded with 0s
    Returns:
        X: (batch_size, padded_length, feat_dim) torch tensor of masked features (input)
        targets: (batch_size, padded_length, feat_dim) torch tensor of unmasked features (output)
        target_masks: (batch_size, padded_length, feat_dim) boolean torch tensor
            0 indicates masked values to be predicted, 1 indicates unaffected/"active" feature values
        padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 ignore (padding)
    """

    batch_size = len(data)
    features, masks = zip(*data)

    # Stack and pad features and masks (convert 2D to 3D tensors, i.e. add batch dimension)
    lengths = [X.shape[0] for X in features]  # original sequence length for each time series
    if max_len is None:
        # Dynamic batch length: pad every sample to the longest one present.
        max_len = max(lengths)
    X = torch.zeros(batch_size, max_len, features[0].shape[-1])  # (batch_size, padded_length, feat_dim)
    target_masks = torch.zeros_like(
        X, dtype=torch.bool)  # (batch_size, padded_length, feat_dim) masks related to objective
    for i in range(batch_size):
        # Clip to max_len; shorter sequences keep the zero padding written above.
        end = min(lengths[i], max_len)
        X[i, :end, :] = features[i][:end, :]
        target_masks[i, :end, :] = masks[i][:end, :]

    # Keep the unmasked values as the prediction target, then zero out the
    # masked positions in the network input.
    targets = X.clone()
    X = X * target_masks  # mask input
    if mask_compensation:
        # Rescale the surviving features so W @ X is unchanged on average.
        X = compensate_masking(X, target_masks)

    padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.int16),
                                 max_len=max_len)  # (batch_size, padded_length) boolean tensor, "1" means keep
    # target_masks = ~target_masks # inverse logic: 0 now means ignore, 1 means predict
    return X, targets, target_masks, padding_masks
46 |
47 |
def compensate_masking(X, mask):
    """
    Rescale partially-masked feature vectors so that W @ X is unchanged on
    average: X' = X * feat_dim / num_active, where num_active counts the
    unmasked features per time step (clamped below at 1 to avoid div-by-zero).
    Args:
        X: (batch_size, seq_length, feat_dim) torch tensor
        mask: (batch_size, seq_length, feat_dim); 0 = masked, 1 = active
    Returns:
        (batch_size, seq_length, feat_dim) compensated features
    """
    # Unmasked feature count for every time step: (batch_size, seq_length, 1).
    active_per_step = torch.sum(mask, dim=-1).unsqueeze(-1)
    floor = torch.ones(active_per_step.shape, dtype=torch.int16)
    active_per_step = torch.max(active_per_step, floor)
    return X.shape[-1] * X / active_per_step
64 |
65 |
def padding_mask(lengths, max_len=None):
    """
    Used to mask padded positions: creates a (batch_size, max_len) boolean mask
    from a tensor of sequence lengths, where 1 means keep element at this
    position (time step).

    Args:
        lengths: (batch_size,) integer tensor of true sequence lengths.
        max_len: padded length; inferred from the longest sequence when None.
    """
    batch_size = lengths.numel()
    if max_len is None:
        # BUG FIX: the original called `lengths.max_val()`, which does not exist
        # on torch tensors and raised AttributeError whenever max_len was
        # omitted; use the real max() reduction instead.
        max_len = int(lengths.max())
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1)))
77 |
78 |
def noise_mask(X, masking_ratio, lm=3, mode='separate', distribution='geometric', exclude_feats=None):
    """
    Random boolean mask shaped like X; 0/False marks positions to be masked.
    Args:
        X: (seq_length, feat_dim) numpy array for a single sample
        masking_ratio: proportion of positions to mask on average
        lm: mean masked-run length; used only with the 'geometric' distribution
        mode: 'separate' masks each feature independently; any other value
            masks all features at the same time steps
        distribution: 'geometric' for Markov-chain runs of mean length `lm`;
            any other value gives i.i.d. Bernoulli masking
        exclude_feats: feature indices that must stay fully observed
    Returns:
        boolean numpy array with X's shape; False = masked
    """
    if exclude_feats is not None:
        exclude_feats = set(exclude_feats)

    if distribution == 'geometric':
        # Stateful (Markov chain) masking with geometric run lengths.
        if mode == 'separate':
            mask = np.ones(X.shape, dtype=bool)
            for feat in range(X.shape[1]):
                if exclude_feats is None or feat not in exclude_feats:
                    mask[:, feat] = geom_noise_mask_single(X.shape[0], lm, masking_ratio)
            return mask
        # One temporal mask replicated across every feature column.
        return np.tile(np.expand_dims(geom_noise_mask_single(X.shape[0], lm, masking_ratio), 1), X.shape[1])

    # Independent Bernoulli masking with p(keep) = 1 - masking_ratio.
    if mode == 'separate':
        return np.random.choice(np.array([True, False]), size=X.shape, replace=True,
                                p=(1 - masking_ratio, masking_ratio))
    return np.tile(np.random.choice(np.array([True, False]), size=(X.shape[0], 1), replace=True,
                                    p=(1 - masking_ratio, masking_ratio)), X.shape[1])
117 |
118 |
def geom_noise_mask_single(L, lm, masking_ratio):
    """
    Boolean keep-mask of length L generated by a two-state Markov chain:
    masked runs average `lm` steps and cover `masking_ratio` of the sequence;
    run lengths in both states follow geometric distributions.
    Args:
        L: length of mask and sequence to be masked
        lm: average length of masking subsequences (streaks of 0s)
        masking_ratio: proportion of L to be masked
    Returns:
        (L,) boolean numpy array; False = masked ('dropped')
    """
    keep = np.ones(L, dtype=bool)
    p_stop_mask = 1 / lm  # geometric stopping prob. of a masked run
    # Chosen so the chain spends `masking_ratio` of its time masked.
    p_stop_keep = p_stop_mask * masking_ratio / (1 - masking_ratio)
    switch_prob = [p_stop_mask, p_stop_keep]

    # State 0 = masking, 1 = keeping; start masked with prob. masking_ratio.
    state = int(np.random.rand() > masking_ratio)
    for t in range(L):
        keep[t] = state  # the state value doubles as the mask value
        if np.random.rand() < switch_prob[state]:
            state = 1 - state

    return keep
--------------------------------------------------------------------------------
/Utils/metric_utils.py:
--------------------------------------------------------------------------------
1 | ## Necessary Packages
2 | import scipy.stats
3 | import numpy as np
4 | import seaborn as sns
5 | import matplotlib.pyplot as plt
6 |
7 | from sklearn.manifold import TSNE
8 | from sklearn.decomposition import PCA
9 |
10 |
def display_scores(results):
    """Print the mean of `results` with a 95% Student-t confidence half-width.

    Args:
        results: sequence of scalar metric scores (one per run).

    Returns:
        (mean, sigma): the mean and the half-width of the 95% confidence
        interval (also printed). Returning them is backward-compatible;
        previous callers ignored the implicit None.
    """
    results = np.asarray(results, dtype=float)
    mean = np.mean(results)
    sigma = scipy.stats.sem(results)
    # Bug fix: degrees of freedom were hard-coded as 5-1, silently assuming
    # exactly 5 runs; use the actual sample size instead.
    sigma = sigma * scipy.stats.t.ppf((1 + 0.95) / 2., len(results) - 1)
    print('Final Score: ', f'{mean} \xB1 {sigma}')
    return mean, sigma
17 |
18 |
def train_test_divide (data_x, data_x_hat, data_t, data_t_hat, train_rate=0.8):
    """Divide train and test data for both original and synthetic data.

    Args:
      - data_x: original data
      - data_x_hat: generated data
      - data_t: original time
      - data_t_hat: generated time
      - train_rate: ratio of training data from the original data

    Returns:
      train/test splits of both datasets and their time information, in the
      order (train_x, train_x_hat, test_x, test_x_hat,
             train_t, train_t_hat, test_t, test_t_hat).
    """
    def _split(xs, ts):
        # One random train/test partition shared by the data and its times.
        count = len(xs)
        perm = np.random.permutation(count)
        cut = int(count * train_rate)
        tr_idx, te_idx = perm[:cut], perm[cut:]
        return ([xs[i] for i in tr_idx], [xs[i] for i in te_idx],
                [ts[i] for i in tr_idx], [ts[i] for i in te_idx])

    # Original data first, then synthetic — matches the RNG consumption
    # order of the previous implementation.
    train_x, test_x, train_t, test_t = _split(data_x, data_t)
    train_x_hat, test_x_hat, train_t_hat, test_t_hat = _split(data_x_hat, data_t_hat)

    return train_x, train_x_hat, test_x, test_x_hat, train_t, train_t_hat, test_t, test_t_hat
52 |
53 |
def extract_time (data):
    """Returns Maximum sequence length and each sequence length.

    Args:
      - data: original data, an iterable of 2-D arrays shaped (seq_len, dim)

    Returns:
      - time: list with the length of every sequence
      - max_seq_len: maximum sequence length (0 when `data` is empty)
    """
    time = [len(seq[:, 0]) for seq in data]
    max_seq_len = max(time, default=0)
    return time, max_seq_len
71 |
72 |
def visualization(ori_data, generated_data, analysis, compare=3000):
    """Using PCA, tSNE or kernel density estimation to compare generated data
    against original data.

    Each series is first reduced to its per-timestep mean over the feature
    axis, yielding one (seq_len,) vector per sample.

    Args:
        - ori_data: original data, array of shape (N, seq_len, dim)
        - generated_data: generated synthetic data, array of shape (M, seq_len, dim)
        - analysis: 'tsne', 'pca' or 'kernel'
        - compare: maximum number of samples drawn from each dataset
    """
    # Analysis sample size (for faster computation), bounded by BOTH datasets.
    # Bug fix: the original indexed generated_data with a permutation of
    # ori_data's indices, which raises when the synthetic set is smaller.
    anal_sample_no = min(compare, ori_data.shape[0], generated_data.shape[0])
    ori_idx = np.random.permutation(ori_data.shape[0])[:anal_sample_no]
    gen_idx = np.random.permutation(generated_data.shape[0])[:anal_sample_no]

    ori_data = np.asarray(ori_data)[ori_idx]
    generated_data = np.asarray(generated_data)[gen_idx]

    no, seq_len, dim = ori_data.shape

    # Mean over features -> (anal_sample_no, seq_len); replaces the original
    # O(n^2) row-by-row np.concatenate loop with one vectorized reduction.
    prep_data = np.mean(ori_data, axis=2)
    prep_data_hat = np.mean(generated_data, axis=2)

    # Visualization parameter
    colors = ["red" for i in range(anal_sample_no)] + ["blue" for i in range(anal_sample_no)]

    if analysis == 'pca':
        # PCA Analysis: fit on the original data, project both sets
        pca = PCA(n_components=2)
        pca.fit(prep_data)
        pca_results = pca.transform(prep_data)
        pca_hat_results = pca.transform(prep_data_hat)

        # Plotting
        f, ax = plt.subplots(1)
        plt.scatter(pca_results[:, 0], pca_results[:, 1],
                    c=colors[:anal_sample_no], alpha=0.2, label="Original")
        plt.scatter(pca_hat_results[:, 0], pca_hat_results[:, 1],
                    c=colors[anal_sample_no:], alpha=0.2, label="Synthetic")

        ax.legend()
        plt.title('PCA plot')
        plt.xlabel('x-pca')
        plt.ylabel('y_pca')
        plt.show()

    elif analysis == 'tsne':

        # Do t-SNE Analysis together so both sets share one embedding
        prep_data_final = np.concatenate((prep_data, prep_data_hat), axis=0)

        # TSNE anlaysis
        tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
        tsne_results = tsne.fit_transform(prep_data_final)

        # Plotting
        f, ax = plt.subplots(1)

        plt.scatter(tsne_results[:anal_sample_no, 0], tsne_results[:anal_sample_no, 1],
                    c=colors[:anal_sample_no], alpha=0.2, label="Original")
        plt.scatter(tsne_results[anal_sample_no:, 0], tsne_results[anal_sample_no:, 1],
                    c=colors[anal_sample_no:], alpha=0.2, label="Synthetic")

        ax.legend()

        plt.title('t-SNE plot')
        plt.xlabel('x-tsne')
        plt.ylabel('y_tsne')
        plt.show()

    elif analysis == 'kernel':

        # Kernel density estimate of the pooled per-timestep means.
        # NOTE(review): sns.distplot is deprecated in seaborn >= 0.11
        # (kdeplot/displot are the replacements); kept for compatibility
        # with the pinned environment.
        f, ax = plt.subplots(1)
        sns.distplot(prep_data, hist=False, kde=True, kde_kws={'linewidth': 5}, label='Original', color="red")
        sns.distplot(prep_data_hat, hist=False, kde=True, kde_kws={'linewidth': 5, 'linestyle':'--'}, label='Synthetic', color="blue")
        # Plot formatting
        plt.legend()
        plt.xlabel('Data Value')
        plt.ylabel('Data Density Estimate')
        plt.show()
        plt.close()
171 |
172 |
if __name__ == '__main__':
    # No standalone behaviour; this module is imported for its helper functions.
    pass
--------------------------------------------------------------------------------
/Utils/predictive_metric.py:
--------------------------------------------------------------------------------
1 | """Reimplement TimeGAN-pytorch Codebase.
2 |
3 | Reference: Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar,
4 | "Time-series Generative Adversarial Networks,"
5 | Neural Information Processing Systems (NeurIPS), 2019.
6 |
7 | Paper link: https://papers.nips.cc/paper/8789-time-series-generative-adversarial-networks
8 |
9 | Last updated Date: October 18th 2021
10 | Code author: Zhiwei Zhang (bitzzw@gmail.com)
11 |
12 | -----------------------------
13 |
14 | predictive_metrics.py
15 |
16 | Note: Use Post-hoc RNN to predict one-step ahead (last feature)
17 | """
18 |
19 | # Necessary Packages
20 | import tensorflow as tf
21 | import tensorflow._api.v2.compat.v1 as tf1
22 | tf.compat.v1.disable_eager_execution()
23 | import numpy as np
24 | from sklearn.metrics import mean_absolute_error
25 | from Utils.metric_utils import extract_time
26 |
27 |
def predictive_score_metrics(ori_data, generated_data):
  """Report the performance of Post-hoc RNN one-step ahead prediction.

  A small GRU predictor is trained on the *synthetic* data to forecast the
  last feature one step ahead from the remaining features, then evaluated on
  the *original* data (train-on-synthetic, test-on-real).

  Args:
    - ori_data: original data, array of shape (no, seq_len, dim)
    - generated_data: generated synthetic data of the same shape

  Returns:
    - predictive_score: MAE of the predictions on the original data
  """
  # Initialization on the Graph
  tf1.reset_default_graph()

  # Basic Parameters
  no, seq_len, dim = ori_data.shape

  # Set maximum sequence length and each sequence length.
  # Bug fix: generated_time was computed from ori_data, so the synthetic
  # sequence lengths fed to the RNN were wrong whenever the datasets differed.
  ori_time, ori_max_seq_len = extract_time(ori_data)
  generated_time, generated_max_seq_len = extract_time(generated_data)
  max_seq_len = max([ori_max_seq_len, generated_max_seq_len])

  ## Builde a post-hoc RNN predictive network
  # Network parameters
  hidden_dim = int(dim/2)
  iterations = 5000
  batch_size = 128

  # Input place holders
  X = tf1.placeholder(tf.float32, [None, max_seq_len-1, dim-1], name = "myinput_x")
  T = tf1.placeholder(tf.int32, [None], name = "myinput_t")
  Y = tf1.placeholder(tf.float32, [None, max_seq_len-1, 1], name = "myinput_y")

  # Predictor function
  def predictor (x, t):
    """Simple predictor function.

    Args:
      - x: time-series data
      - t: time information

    Returns:
      - y_hat: prediction
      - p_vars: predictor variables
    """
    with tf1.variable_scope("predictor", reuse = tf1.AUTO_REUSE) as vs:
      p_cell = tf1.nn.rnn_cell.GRUCell(num_units=hidden_dim, activation=tf.nn.tanh, name = 'p_cell')
      p_outputs, p_last_states = tf1.nn.dynamic_rnn(p_cell, x, dtype=tf.float32, sequence_length = t)
      y_hat_logit = tf1.layers.dense(p_outputs, 1, activation=None)
      y_hat = tf.nn.sigmoid(y_hat_logit)
      p_vars = [v for v in tf1.all_variables() if v.name.startswith(vs.name)]

    return y_hat, p_vars

  y_pred, p_vars = predictor(X, T)
  # Loss for the predictor (mean absolute error)
  p_loss = tf1.losses.absolute_difference(Y, y_pred)
  # optimizer
  p_solver = tf1.train.AdamOptimizer().minimize(p_loss, var_list = p_vars)

  ## Training
  # Session start
  sess = tf1.Session()
  sess.run(tf1.global_variables_initializer())

  from tqdm.auto import tqdm

  # Training using Synthetic dataset
  for itt in tqdm(range(iterations), desc='training', total=iterations):

    # Set mini-batch: inputs are all-but-last steps of the first dim-1
    # features; targets are the last feature shifted one step ahead.
    idx = np.random.permutation(len(generated_data))
    train_idx = idx[:batch_size]

    X_mb = list(generated_data[i][:-1,:(dim-1)] for i in train_idx)
    T_mb = list(generated_time[i]-1 for i in train_idx)
    Y_mb = list(np.reshape(generated_data[i][1:,(dim-1)],[len(generated_data[i][1:,(dim-1)]),1]) for i in train_idx)

    # Train predictor
    _, step_p_loss = sess.run([p_solver, p_loss], feed_dict={X: X_mb, T: T_mb, Y: Y_mb})

  ## Test the trained model on the original data (all `no` samples, shuffled)
  idx = np.random.permutation(len(ori_data))
  train_idx = idx[:no]

  X_mb = list(ori_data[i][:-1,:(dim-1)] for i in train_idx)
  T_mb = list(ori_time[i]-1 for i in train_idx)
  Y_mb = list(np.reshape(ori_data[i][1:,(dim-1)], [len(ori_data[i][1:,(dim-1)]),1]) for i in train_idx)

  # Prediction
  pred_Y_curr = sess.run(y_pred, feed_dict={X: X_mb, T: T_mb})

  # Compute the performance in terms of MAE averaged over samples
  MAE_temp = 0
  for i in range(no):
    MAE_temp = MAE_temp + mean_absolute_error(Y_mb[i], pred_Y_curr[i,:,:])

  predictive_score = MAE_temp / no

  return predictive_score
135 |
--------------------------------------------------------------------------------
/engine/logger.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import time
7 | import torch
8 | from Utils.io_utils import write_args, save_config_to_yaml
9 |
10 |
class Logger(object):
    """Experiment logger.

    Persists the CLI args (and optionally the resolved config) under
    ``<save_dir>/configs``, appends timestamped messages to
    ``<save_dir>/logs/log.txt``, and optionally mirrors scalars/images to
    tensorboard when ``args.tensorboard`` is set.
    """

    def __init__(self, args):
        self.args = args
        self.save_dir = args.save_dir

        os.makedirs(self.save_dir, exist_ok=True)

        # save the args and config
        self.config_dir = os.path.join(self.save_dir, 'configs')
        os.makedirs(self.config_dir, exist_ok=True)
        file_name = os.path.join(self.config_dir, 'args.txt')
        write_args(args, file_name)

        log_dir = os.path.join(self.save_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        # Append mode so resumed runs keep the previous history.
        self.text_writer = open(os.path.join(log_dir, 'log.txt'), 'a')
        if args.tensorboard:
            self.log_info('using tensorboard')
            self.tb_writer = torch.utils.tensorboard.SummaryWriter(log_dir=log_dir)
        else:
            self.tb_writer = None

    def save_config(self, config):
        """Dump the resolved config to <save_dir>/configs/config.yaml."""
        save_config_to_yaml(config, os.path.join(self.config_dir, 'config.yaml'))

    def log_info(self, info, check_primary=True):
        """Print `info` and append it, timestamped, to the text log.

        `check_primary` is accepted for interface compatibility but unused.
        """
        print(info)
        info = str(info)
        time_str = time.strftime('%Y-%m-%d-%H-%M')
        info = '{}: {}'.format(time_str, info)
        if not info.endswith('\n'):
            info += '\n'
        self.text_writer.write(info)
        self.text_writer.flush()

    def add_scalar(self, **kargs):
        """Log a scalar (no-op when tensorboard is disabled)."""
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(**kargs)

    def add_scalars(self, **kargs):
        """Log several scalars under one main tag (no-op when tensorboard is disabled)."""
        if self.tb_writer is not None:
            self.tb_writer.add_scalars(**kargs)

    def add_image(self, **kargs):
        """Log an image (no-op when tensorboard is disabled)."""
        if self.tb_writer is not None:
            self.tb_writer.add_image(**kargs)

    def add_images(self, **kargs):
        """Log a batch of images (no-op when tensorboard is disabled)."""
        if self.tb_writer is not None:
            self.tb_writer.add_images(**kargs)

    def close(self):
        """Close file handles.

        Bug fix: the original unconditionally called ``self.tb_writer.close()``
        and crashed with AttributeError when tensorboard was disabled
        (``tb_writer`` is None in that case).
        """
        self.text_writer.close()
        if self.tb_writer is not None:
            self.tb_writer.close()
70 |
71 |
--------------------------------------------------------------------------------
/engine/lr_sch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch import inf
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
class ReduceLROnPlateauWithWarmup(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        warmup_lr: float or None, the learning rate to be touched after warmup
        warmup: int, the number of steps to warmup
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose=False, warmup_lr=None,
                 warmup=0):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # min_lr may be a scalar (broadcast to all groups) or one value per group.
        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

        self.warmup_lr = warmup_lr
        self.warmup = warmup

        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        # _init_is_better also calls _prepare_for_warmup at its end.
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _prepare_for_warmup(self):
        """Precompute the per-group lr increment applied on each warmup step."""
        if self.warmup_lr is not None:
            if isinstance(self.warmup_lr, (list, tuple)):
                if len(self.warmup_lr) != len(self.optimizer.param_groups):
                    raise ValueError("expected {} warmup_lrs, got {}".format(
                        len(self.optimizer.param_groups), len(self.warmup_lr)))
                self.warmup_lrs = list(self.warmup_lr)
            else:
                self.warmup_lrs = [self.warmup_lr] * len(self.optimizer.param_groups)
        else:
            self.warmup_lrs = None
        if self.warmup > self.last_epoch:
            # NOTE(review): assumes warmup_lrs is not None whenever warmup is
            # still in progress; warmup > 0 with warmup_lr=None would raise
            # a TypeError here — confirm intended usage.
            curr_lrs = [group['lr'] for group in self.optimizer.param_groups]
            self.warmup_lr_steps = [max(0, (self.warmup_lrs[i] - curr_lrs[i])/float(self.warmup)) for i in range(len(curr_lrs))]
        else:
            self.warmup_lr_steps = None

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics):
        """Record a new metric value; warm up first, then plateau-reduce."""
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if epoch <= self.warmup:
            # Still warming up: only increase the lr, no plateau tracking yet.
            self._increase_lr(epoch)
        else:
            if self.is_better(current, self.best):
                self.best = current
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.in_cooldown:
                self.cooldown_counter -= 1
                self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

            if self.num_bad_epochs > self.patience:
                self._reduce_lr(epoch)
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        """Multiply each group's lr by `factor`, clipped at min_lr; skip
        updates smaller than eps."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    def _increase_lr(self, epoch):
        # used for warmup
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr + self.warmup_lr_steps[i], self.min_lrs[i])
            param_group['lr'] = new_lr
            if self.verbose:
                print('Epoch {:5d}: increasing learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        """True while the post-reduction cooldown window is active."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """Compare metric `a` against `best` under the configured mode/threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate mode settings, set the 'worst' sentinel, and prepare warmup."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

        self._prepare_for_warmup()

    def state_dict(self):
        """Return the scheduler state (everything except the optimizer)."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore state and re-derive mode-dependent/warmup fields."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
202 |
203 |
class CosineAnnealingLRWithWarmup(object):
    """Cosine-annealing learning-rate schedule with an initial linear warmup.

    For the first `warmup` steps the lr of every param group is increased
    linearly towards `warmup_lr`; afterwards it follows a half-cosine decay
    from the peak lr reached during warmup down towards `min_lr` over
    `T_max` total steps.

    args:
        optimizer: wrapped torch optimizer.
        T_max: total number of schedule steps (warmup included).
        last_epoch: index of the last step taken; -1 means "not started".
        verbose: if True, print a message on every lr change.
        min_lr: scalar or per-group list; lower bound on the lr.
        warmup_lr: float or None, the learning rate to be touched after warmup
        warmup: int, the number of steps to warmup
    """

    def __init__(self, optimizer, T_max, last_epoch=-1, verbose=False,
                 min_lr=0, warmup_lr=None, warmup=0):
        self.optimizer = optimizer
        self.T_max = T_max
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.warmup_lr = warmup_lr
        self.warmup = warmup

        # min_lr may be a scalar (broadcast to all groups) or one value per group.
        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)
        # Peak lrs for the cosine phase; start at min_lr and are raised to the
        # highest lr seen during warmup (see _increase_lr).
        # NOTE(review): with warmup == 0 these stay equal to min_lr, which pins
        # the lr at min_lr in _reduce_lr — the class appears to assume
        # warmup > 0; confirm intended usage.
        self.max_lrs = [lr for lr in self.min_lrs]

        self._prepare_for_warmup()

    def step(self):
        """Advance one step: linear warmup first, cosine decay afterwards."""
        epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if epoch <= self.warmup:
            self._increase_lr(epoch)
        else:
            self._reduce_lr(epoch)

    def _reduce_lr(self, epoch):
        """Cosine-decay each group's lr from its max_lr towards its min_lr."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            # progress runs 0 -> 1 over the post-warmup portion of T_max
            progress = float(epoch - self.warmup) / float(max(1, self.T_max - self.warmup))
            factor = max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
            old_lr = float(param_group['lr'])
            new_lr = max(self.max_lrs[i] * factor, self.min_lrs[i])
            param_group['lr'] = new_lr
            if self.verbose:
                print('Epoch {:5d}: reducing learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    def _increase_lr(self, epoch):
        # used for warmup: add the precomputed per-step increment and track the
        # peak lr reached so the cosine phase decays from it
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = old_lr + self.warmup_lr_steps[i]
            param_group['lr'] = new_lr
            self.max_lrs[i] = max(self.max_lrs[i], new_lr)
            if self.verbose:
                print('Epoch {:5d}: increasing learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    def _prepare_for_warmup(self):
        """Precompute the per-group lr increment applied on each warmup step."""
        if self.warmup_lr is not None:
            if isinstance(self.warmup_lr, (list, tuple)):
                if len(self.warmup_lr) != len(self.optimizer.param_groups):
                    raise ValueError("expected {} warmup_lrs, got {}".format(
                        len(self.optimizer.param_groups), len(self.warmup_lr)))
                self.warmup_lrs = list(self.warmup_lr)
            else:
                self.warmup_lrs = [self.warmup_lr] * len(self.optimizer.param_groups)
        else:
            self.warmup_lrs = None
        if self.warmup > self.last_epoch:
            # NOTE(review): assumes warmup_lrs is not None whenever warmup is
            # still in progress; warmup > 0 with warmup_lr=None would raise here.
            curr_lrs = [group['lr'] for group in self.optimizer.param_groups]
            self.warmup_lr_steps = [max(0, (self.warmup_lrs[i] - curr_lrs[i])/float(self.warmup)) for i in range(len(curr_lrs))]
        else:
            self.warmup_lr_steps = None


    def state_dict(self):
        """Return the scheduler state (everything except the optimizer)."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore state and rebuild the warmup increments."""
        self.__dict__.update(state_dict)
        self._prepare_for_warmup()
--------------------------------------------------------------------------------
/engine/solver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import torch
5 | import numpy as np
6 | import torch.nn.functional as F
7 |
8 | from pathlib import Path
9 | from tqdm.auto import tqdm
10 | from ema_pytorch import EMA
11 | from torch.optim import Adam
12 | from torch.nn.utils import clip_grad_norm_
13 | from Utils.io_utils import instantiate_from_config, get_model_parameters_info
14 |
15 |
16 | sys.path.append(os.path.join(os.path.dirname(__file__), '../'))
17 |
def cycle(dl):
    """Yield items from `dl` forever, restarting each time it is exhausted."""
    while True:
        yield from dl
22 |
23 |
24 | class Trainer(object):
    def __init__(self, config, args, model, dataloader, logger=None):
        """Set up optimizer, EMA copy, LR scheduler and checkpoint folder.

        Args:
            config: parsed config dict; reads the 'solver' section.
            args: command-line namespace (args.name is used for log messages).
            model: diffusion model; must expose a `betas` buffer and `seq_length`.
            dataloader: dict with key 'dataloader' holding a torch DataLoader.
            logger: optional Logger for text/tensorboard output.
        """
        super().__init__()
        self.model = model
        # infer the training device from the model's registered `betas` buffer
        self.device = self.model.betas.device
        self.train_num_steps = config['solver']['max_epochs']
        self.gradient_accumulate_every = config['solver']['gradient_accumulate_every']
        self.save_cycle = config['solver']['save_cycle']
        self.dl = cycle(dataloader['dataloader'])  # infinite iterator for train()
        self.dataloader = dataloader['dataloader']  # raw loader kept for train_classfier()
        self.step = 0
        self.milestone = 0
        self.args, self.config = args, config
        self.logger = logger

        # checkpoints are grouped per sequence length
        self.results_folder = Path(config['solver']['results_folder'] + f'_{model.seq_length}')
        os.makedirs(self.results_folder, exist_ok=True)

        start_lr = config['solver'].get('base_lr', 1.0e-4)
        ema_decay = config['solver']['ema']['decay']
        ema_update_every = config['solver']['ema']['update_interval']

        self.opt = Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=start_lr, betas=[0.9, 0.96])
        self.ema = EMA(self.model, beta=ema_decay, update_every=ema_update_every).to(self.device)

        sc_cfg = config['solver']['scheduler']
        sc_cfg['params']['optimizer'] = self.opt  # NOTE: mutates the shared config dict
        self.sch = instantiate_from_config(sc_cfg)

        if self.logger is not None:
            self.logger.log_info(str(get_model_parameters_info(self.model)))
        self.log_frequency = 100
56 |
57 | def save(self, milestone, verbose=False):
58 | if self.logger is not None and verbose:
59 | self.logger.log_info('Save current model to {}'.format(str(self.results_folder / f'checkpoint-{milestone}.pt')))
60 | data = {
61 | 'step': self.step,
62 | 'model': self.model.state_dict(),
63 | 'ema': self.ema.state_dict(),
64 | 'opt': self.opt.state_dict(),
65 | }
66 | torch.save(data, str(self.results_folder / f'checkpoint-{milestone}.pt'))
67 |
68 | def save_classifier(self, milestone, verbose=False):
69 | if self.logger is not None and verbose:
70 | self.logger.log_info('Save current classifer to {}'.format(str(self.results_folder / f'ckpt_classfier-{milestone}.pt')))
71 | data = {
72 | 'step': self.step_classifier,
73 | 'classifier': self.classifier.state_dict()
74 | }
75 | torch.save(data, str(self.results_folder / f'ckpt_classfier-{milestone}.pt'))
76 |
77 | def load(self, milestone, verbose=False):
78 | if self.logger is not None and verbose:
79 | self.logger.log_info('Resume from {}'.format(str(self.results_folder / f'checkpoint-{milestone}.pt')))
80 | device = self.device
81 | data = torch.load(str(self.results_folder / f'checkpoint-{milestone}.pt'), map_location=device)
82 | self.model.load_state_dict(data['model'])
83 | self.step = data['step']
84 | self.opt.load_state_dict(data['opt'])
85 | self.ema.load_state_dict(data['ema'])
86 | self.milestone = milestone
87 |
88 | def load_classifier(self, milestone, verbose=False):
89 | if self.logger is not None and verbose:
90 | self.logger.log_info('Resume from {}'.format(str(self.results_folder / f'ckpt_classfier-{milestone}.pt')))
91 | device = self.device
92 | data = torch.load(str(self.results_folder / f'ckpt_classfier-{milestone}.pt'), map_location=device)
93 | self.classifier.load_state_dict(data['classifier'])
94 | self.step_classifier = data['step']
95 | self.milestone_classifier = milestone
96 |
    def train(self):
        """Main diffusion training loop.

        Runs `train_num_steps` optimizer steps with gradient accumulation,
        gradient clipping at norm 1.0, per-step LR scheduling keyed on the
        training loss, EMA updates, a checkpoint every `save_cycle` steps,
        and optional tensorboard scalar logging every `log_frequency` steps.
        """
        device = self.device
        step = 0
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info('{}: start training...'.format(self.args.name), check_primary=False)

        with tqdm(initial=step, total=self.train_num_steps) as pbar:
            while step < self.train_num_steps:
                total_loss = 0.
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)
                    # self-supervised objective: model reconstructs its input
                    loss = self.model(data, target=data)
                    loss = loss / self.gradient_accumulate_every  # normalize for accumulation
                    loss.backward()
                    total_loss += loss.item()

                pbar.set_description(f'loss: {total_loss:.6f}')

                clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()
                self.sch.step(total_loss)  # scheduler consumes the accumulated loss
                self.opt.zero_grad()
                self.step += 1
                step += 1
                self.ema.update()

                with torch.no_grad():
                    # periodic checkpointing
                    if self.step != 0 and self.step % self.save_cycle == 0:
                        self.milestone += 1
                        self.save(self.milestone)
                        # self.logger.log_info('saved in {}'.format(str(self.results_folder / f'checkpoint-{self.milestone}.pt')))

                    if self.logger is not None and self.step % self.log_frequency == 0:
                        # info = '{}: train'.format(self.args.name)
                        # info = info + ': Epoch {}/{}'.format(self.step, self.train_num_steps)
                        # info += ' ||'
                        # info += '' if loss_f == 'none' else ' Fourier Loss: {:.4f}'.format(loss_f.item())
                        # info += '' if loss_r == 'none' else ' Reglarization: {:.4f}'.format(loss_r.item())
                        # info += ' | Total Loss: {:.6f}'.format(total_loss)
                        # self.logger.log_info(info)
                        self.logger.add_scalar(tag='train/loss', scalar_value=total_loss, global_step=self.step)

                pbar.update(1)

        print('training complete')
        if self.logger is not None:
            self.logger.log_info('Training done, time: {:.2f}'.format(time.time() - tic))
145 |
146 | def sample(self, num, size_every, shape=None, model_kwargs=None, cond_fn=None):
147 | if self.logger is not None:
148 | tic = time.time()
149 | self.logger.log_info('Begin to sample...')
150 | samples = np.empty([0, shape[0], shape[1]])
151 | num_cycle = int(num // size_every) + 1
152 |
153 | for _ in range(num_cycle):
154 | sample = self.ema.ema_model.generate_mts(batch_size=size_every, model_kwargs=model_kwargs, cond_fn=cond_fn)
155 | samples = np.row_stack([samples, sample.detach().cpu().numpy()])
156 | torch.cuda.empty_cache()
157 |
158 | if self.logger is not None:
159 | self.logger.log_info('Sampling done, time: {:.2f}'.format(time.time() - tic))
160 | return samples
161 |
    def restore(self, raw_dataloader, shape=None, coef=1e-1, stepsize=1e-1, sampling_steps=50):
        """Impute masked series by conditional (infill) sampling with the EMA model.

        Args:
            raw_dataloader: yields (x, t_m) pairs; t_m is the observation mask
                with the same shape as x (observed entries kept via x * t_m).
            shape: [seq_len, feat_dim] of one series, used to size the output buffers.
            coef: guidance coefficient, forwarded to the sampler via model_kwargs.
            stepsize: guidance learning rate, forwarded via model_kwargs.
            sampling_steps: number of reverse steps; the full sampler is used only
                when this equals the model's trained timestep count.

        Returns:
            (samples, reals, masks): numpy arrays stacked over the whole loader.
        """
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info('Begin to restore...')
        model_kwargs = {}
        model_kwargs['coef'] = coef
        model_kwargs['learning_rate'] = stepsize
        samples = np.empty([0, shape[0], shape[1]])
        reals = np.empty([0, shape[0], shape[1]])
        masks = np.empty([0, shape[0], shape[1]])

        for idx, (x, t_m) in enumerate(raw_dataloader):
            x, t_m = x.to(self.device), t_m.to(self.device)
            if sampling_steps == self.model.num_timesteps:
                # exact (full-length) reverse diffusion
                sample = self.ema.ema_model.sample_infill(shape=x.shape, target=x*t_m, partial_mask=t_m,
                                                          model_kwargs=model_kwargs)
            else:
                # accelerated sampler with a reduced number of timesteps
                sample = self.ema.ema_model.fast_sample_infill(shape=x.shape, target=x*t_m, partial_mask=t_m, model_kwargs=model_kwargs,
                                                               sampling_timesteps=sampling_steps)

            # NOTE(review): np.row_stack is removed in NumPy 2.0; np.vstack is
            # the drop-in replacement if the environment is upgraded.
            samples = np.row_stack([samples, sample.detach().cpu().numpy()])
            reals = np.row_stack([reals, x.detach().cpu().numpy()])
            masks = np.row_stack([masks, t_m.detach().cpu().numpy()])

        if self.logger is not None:
            self.logger.log_info('Imputation done, time: {:.2f}'.format(time.time() - tic))
        return samples, reals, masks
        # return samples
190 |
191 | def forward_sample(self, x_start):
192 | b, c, h = x_start.shape
193 | noise = torch.randn_like(x_start, device=self.device)
194 | t = torch.randint(0, self.model.num_timesteps, (b,), device=self.device).long()
195 | x_t = self.model.q_sample(x_start=x_start, t=t, noise=noise).detach()
196 | return x_t, t
197 |
198 | def train_classfier(self, classifier):
199 | device = self.device
200 | step = 0
201 | self.milestone_classifier = 0
202 | self.step_classifier = 0
203 | dataloader = self.dataloader
204 | dataloader.dataset.shift_period('test')
205 | dataloader = cycle(dataloader)
206 |
207 | self.classifier = classifier
208 | self.opt_classifier = Adam(filter(lambda p: p.requires_grad, self.classifier.parameters()), lr=5.0e-4)
209 |
210 | if self.logger is not None:
211 | tic = time.time()
212 | self.logger.log_info('{}: start training classifier...'.format(self.args.name), check_primary=False)
213 |
214 | with tqdm(initial=step, total=self.train_num_steps) as pbar:
215 | while step < self.train_num_steps:
216 | total_loss = 0.
217 | for _ in range(self.gradient_accumulate_every):
218 | x, y = next(dataloader)
219 | x, y = x.to(device), y.to(device)
220 | x_t, t = self.forward_sample(x)
221 | logits = classifier(x_t, t)
222 | loss = F.cross_entropy(logits, y)
223 | loss = loss / self.gradient_accumulate_every
224 | loss.backward()
225 | total_loss += loss.item()
226 |
227 | pbar.set_description(f'loss: {total_loss:.6f}')
228 |
229 | self.opt_classifier.step()
230 | self.opt_classifier.zero_grad()
231 | self.step_classifier += 1
232 | step += 1
233 |
234 | with torch.no_grad():
235 | if self.step_classifier != 0 and self.step_classifier % self.save_cycle == 0:
236 | self.milestone_classifier += 1
237 | self.save(self.milestone_classifier)
238 |
239 | if self.logger is not None and self.step_classifier % self.log_frequency == 0:
240 | self.logger.add_scalar(tag='train/loss', scalar_value=total_loss, global_step=self.step)
241 |
242 | pbar.update(1)
243 |
244 | print('training complete')
245 | if self.logger is not None:
246 | self.logger.log_info('Training done, time: {:.2f}'.format(time.time() - tic))
247 |
248 | # return classifier
249 |
250 |
--------------------------------------------------------------------------------
/figures/fig1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Y-debug-sys/Diffusion-TS/566307e6cf2d8095e58de4c6e3a6ae965b69b5b5/figures/fig1.jpg
--------------------------------------------------------------------------------
/figures/fig2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Y-debug-sys/Diffusion-TS/566307e6cf2d8095e58de4c6e3a6ae965b69b5b5/figures/fig2.jpg
--------------------------------------------------------------------------------
/figures/fig3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Y-debug-sys/Diffusion-TS/566307e6cf2d8095e58de4c6e3a6ae965b69b5b5/figures/fig3.jpg
--------------------------------------------------------------------------------
/figures/fig4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Y-debug-sys/Diffusion-TS/566307e6cf2d8095e58de4c6e3a6ae965b69b5b5/figures/fig4.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import numpy as np
5 |
6 | from engine.logger import Logger
7 | from engine.solver import Trainer
8 | from Data.build_dataloader import build_dataloader, build_dataloader_cond
9 | from Models.interpretable_diffusion.model_utils import unnormalize_to_zero_to_one
10 | from Utils.io_utils import load_yaml_config, seed_everything, merge_opts_to_config, instantiate_from_config
11 |
12 |
def parse_args():
    """Parse the command-line options for training / sampling.

    Returns:
        argparse.Namespace carrying all flags plus a derived `save_dir`
        attribute (`<output>/<name>`).
    """
    p = argparse.ArgumentParser(description='PyTorch Training Script')
    p.add_argument('--name', type=str, default=None)

    # I/O and logging
    p.add_argument('--config_file', type=str, default=None,
                   help='path of config file')
    p.add_argument('--output', type=str, default='OUTPUT',
                   help='directory to save the results')
    p.add_argument('--tensorboard', action='store_true',
                   help='use tensorboard for logging')

    # reproducibility / device selection
    p.add_argument('--cudnn_deterministic', action='store_true', default=False,
                   help='set cudnn.deterministic True')
    p.add_argument('--seed', type=int, default=12345,
                   help='seed for initializing training.')
    p.add_argument('--gpu', type=int, default=None,
                   help='GPU id to use. If given, only the specific gpu will be'
                        ' used, and ddp will be disabled')

    # training / sampling mode
    p.add_argument('--train', action='store_true', default=False, help='Train or Test.')
    p.add_argument('--sample', type=int, default=0,
                   choices=[0, 1], help='Condition or Uncondition.')
    p.add_argument('--mode', type=str, default='infill',
                   help='Infilling or Forecasting.')
    p.add_argument('--milestone', type=int, default=10)
    p.add_argument('--missing_ratio', type=float, default=0., help='Ratio of Missing Values.')
    p.add_argument('--pred_len', type=int, default=0, help='Length of Predictions.')

    # trailing KEY VALUE pairs that override the YAML config
    p.add_argument('opts', help='Modify config options using the command-line',
                   default=None, nargs=argparse.REMAINDER)

    args = p.parse_args()
    args.save_dir = os.path.join(args.output, str(args.name))
    return args
53 |
def main():
    """Entry point: train the diffusion model or sample from a checkpoint.

    Three mutually exclusive paths, selected by the CLI flags:
      * --train                                 : fit the model;
      * --sample 1 with mode in {infill,predict}: conditional restoration
        (imputation / forecasting) from milestone `--milestone`;
      * otherwise                               : unconditional generation
        from milestone `--milestone`.
    Results are written as .npy files under `args.save_dir`.
    """
    args = parse_args()

    if args.seed is not None:
        seed_everything(args.seed)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    config = load_yaml_config(args.config_file)
    config = merge_opts_to_config(config, args.opts)

    logger = Logger(args)
    logger.save_config(config)

    model = instantiate_from_config(config['model']).cuda()

    # Evaluate the conditional-sampling predicate once instead of
    # duplicating it in both branches below.
    conditional = args.sample == 1 and args.mode in ['infill', 'predict']
    if conditional:
        test_dataloader_info = build_dataloader_cond(config, args)
    dataloader_info = build_dataloader(config, args)
    trainer = Trainer(config=config, args=args, model=model, dataloader=dataloader_info, logger=logger)

    if args.train:
        trainer.train()
    elif conditional:
        trainer.load(args.milestone)
        dataloader, dataset = test_dataloader_info['dataloader'], test_dataloader_info['dataset']
        coef = config['dataloader']['test_dataset']['coefficient']
        stepsize = config['dataloader']['test_dataset']['step_size']
        sampling_steps = config['dataloader']['test_dataset']['sampling_steps']
        samples, *_ = trainer.restore(dataloader, [dataset.window, dataset.var_num], coef, stepsize, sampling_steps)
        if dataset.auto_norm:
            samples = unnormalize_to_zero_to_one(samples)
        np.save(os.path.join(args.save_dir, f'ddpm_{args.mode}_{args.name}.npy'), samples)
    else:
        trainer.load(args.milestone)
        dataset = dataloader_info['dataset']
        samples = trainer.sample(num=len(dataset), size_every=2001, shape=[dataset.window, dataset.var_num])
        if dataset.auto_norm:
            samples = unnormalize_to_zero_to_one(samples)
        np.save(os.path.join(args.save_dir, f'ddpm_fake_{args.name}.npy'), samples)

if __name__ == '__main__':
    main()
99 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | einops==0.6.0
3 | ema-pytorch==0.2.1
4 | matplotlib==3.6.0
5 | pandas==1.5.0
6 | scikit-learn==1.1.2
7 | scipy==1.8.1
8 | seaborn==0.12.2
9 | tqdm==4.64.1
10 | dm-control==1.0.12
11 | dm-env==1.6
12 | dm-tree==0.1.8
13 | mujoco==2.3.4
14 | gluonts==0.12.6
15 | pyyaml==6.0
16 |
--------------------------------------------------------------------------------