├── .gitignore ├── Config ├── config.yaml ├── eeg.yaml ├── energy.yaml ├── etth.yaml ├── fmri.yaml ├── mujoco.yaml ├── mujoco_sssd.yaml ├── sines.yaml ├── solar.yaml ├── solar_update.yaml └── stocks.yaml ├── Data ├── Place the dataset here ├── build_dataloader.py └── readme.md ├── Experiments ├── eeg_multiple_classes.ipynb ├── metric_pytorch.ipynb ├── metric_tensorflow.ipynb ├── mujoco_imputation.ipynb ├── solar_nips_forecasting.ipynb └── view_interpretability.ipynb ├── LICENSE ├── Models ├── interpretable_diffusion │ ├── classifier.py │ ├── gaussian_diffusion.py │ ├── model_utils.py │ └── transformer.py └── ts2vec │ ├── models │ ├── dilated_conv.py │ ├── encoder.py │ └── losses.py │ ├── ts2vec.py │ └── utils.py ├── README.md ├── Tutorial_0.ipynb ├── Tutorial_1.ipynb ├── Tutorial_2.ipynb ├── Utils ├── Data_utils │ ├── eeg_dataset.py │ ├── mujoco_dataset.py │ ├── real_datasets.py │ └── sine_dataset.py ├── context_fid.py ├── cross_correlation.py ├── discriminative_metric.py ├── imputation_utils.py ├── io_utils.py ├── masking_utils.py ├── metric_utils.py └── predictive_metric.py ├── engine ├── logger.py ├── lr_sch.py └── solver.py ├── figures ├── Flowchart.svg ├── fig1.jpg ├── fig2.jpg ├── fig3.jpg └── fig4.jpg ├── main.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /Config/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 160 5 | feature_size: 5 6 | n_layer_enc: 1 7 | n_layer_dec: 2 8 | d_model: 64 # 4 X 16 9 | timesteps: 200 10 | sampling_timesteps: 200 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.1 16 | resid_pd: 0.1 17 | kernel_size: 5 18 | padding_size: 2 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 1000 23 | results_folder: ./Checkpoints_syn 24 | gradient_accumulate_every: 2 25 | save_cycle: 100 # max_epochs // 10 26 | ema: 27 | decay: 0.99 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 200 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 100 40 | verbose: False 41 | -------------------------------------------------------------------------------- /Config/eeg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 24 5 | feature_size: 14 6 | n_layer_enc: 3 7 | n_layer_dec: 2 8 | d_model: 64 # 4 X 16 9 | timesteps: 500 10 | sampling_timesteps: 100 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.0 16 | resid_pd: 0.0 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | classifier: 21 | target: Models.interpretable_diffusion.classifier.Classifier 22 | params: 23 | seq_length: 24 24 | feature_size: 14 25 | num_classes: 2 26 | n_layer_enc: 3 27 | n_embd: 64 # 4 X 16 28 | n_heads: 4 29 | mlp_hidden_times: 4 30 | attn_pd: 0.0 31 | resid_pd: 0.0 32 | 
max_len: 24 # seq_length 33 | num_head_channels: 8 34 | 35 | solver: 36 | base_lr: 1.0e-5 37 | max_epochs: 12000 38 | results_folder: ./Checkpoints_eeg 39 | gradient_accumulate_every: 2 40 | save_cycle: 1200 # max_epochs // 10 41 | ema: 42 | decay: 0.995 43 | update_interval: 10 44 | 45 | scheduler: 46 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 47 | params: 48 | factor: 0.5 49 | patience: 3000 50 | min_lr: 1.0e-5 51 | threshold: 1.0e-1 52 | threshold_mode: rel 53 | warmup_lr: 8.0e-4 54 | warmup: 500 55 | verbose: False 56 | 57 | dataloader: 58 | train_dataset: 59 | target: Utils.Data_utils.eeg_dataset.EEGDataset 60 | params: 61 | data_root: ../Data/datasets/EEG_Eye_State.arff 62 | window: 24 # seq_length 63 | save2npy: True 64 | neg_one_to_one: True 65 | period: train 66 | 67 | batch_size: 128 68 | sample_size: 256 69 | shuffle: True -------------------------------------------------------------------------------- /Config/energy.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 24 5 | feature_size: 28 6 | n_layer_enc: 4 7 | n_layer_dec: 3 8 | d_model: 96 # 4 X 24 9 | timesteps: 1000 10 | sampling_timesteps: 1000 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.0 16 | resid_pd: 0.0 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 25000 23 | results_folder: ./Checkpoints_energy 24 | gradient_accumulate_every: 2 25 | save_cycle: 2500 # max_epochs // 10 26 | ema: 27 | decay: 0.995 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 5000 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 500 40 | verbose: False 41 | 42 | dataloader: 43 | train_dataset: 44 | target: 
Utils.Data_utils.real_datasets.CustomDataset 45 | params: 46 | name: energy 47 | proportion: 1.0 # Set to rate < 1 if training conditional generation 48 | data_root: ./Data/datasets/energy_data.csv 49 | window: 24 # seq_length 50 | save2npy: True 51 | neg_one_to_one: True 52 | seed: 123 53 | period: train 54 | 55 | test_dataset: 56 | target: Utils.Data_utils.real_datasets.CustomDataset 57 | params: 58 | name: energy 59 | proportion: 0.9 # rate 60 | data_root: ./Data/datasets/energy_data.csv 61 | window: 24 # seq_length 62 | save2npy: True 63 | neg_one_to_one: True 64 | seed: 123 65 | period: test 66 | style: separate 67 | distribution: geometric 68 | coefficient: 1.0e-2 69 | step_size: 5.0e-2 70 | sampling_steps: 250 71 | 72 | batch_size: 64 73 | sample_size: 256 74 | shuffle: True -------------------------------------------------------------------------------- /Config/etth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 24 5 | feature_size: 7 6 | n_layer_enc: 3 7 | n_layer_dec: 2 8 | d_model: 64 # 4 X 16 9 | timesteps: 500 10 | sampling_timesteps: 500 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.0 16 | resid_pd: 0.0 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 18000 23 | results_folder: ./Checkpoints_etth 24 | gradient_accumulate_every: 2 25 | save_cycle: 1800 # max_epochs // 10 26 | ema: 27 | decay: 0.995 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 4000 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 500 40 | verbose: False 41 | 42 | dataloader: 43 | train_dataset: 44 | target: Utils.Data_utils.real_datasets.CustomDataset 45 | params: 46 | name: 
etth 47 | proportion: 1.0 # Set to rate < 1 if training conditional generation 48 | data_root: ./Data/datasets/ETTh.csv 49 | window: 24 # seq_length 50 | save2npy: True 51 | neg_one_to_one: True 52 | seed: 123 53 | period: train 54 | 55 | test_dataset: 56 | target: Utils.Data_utils.real_datasets.CustomDataset 57 | params: 58 | name: etth 59 | proportion: 0.9 # rate 60 | data_root: ./Data/datasets/ETTh.csv 61 | window: 24 # seq_length 62 | save2npy: True 63 | neg_one_to_one: True 64 | seed: 123 65 | period: test 66 | style: separate 67 | distribution: geometric 68 | coefficient: 1.0e-2 69 | step_size: 5.0e-2 70 | sampling_steps: 200 71 | 72 | batch_size: 128 73 | sample_size: 256 74 | shuffle: True -------------------------------------------------------------------------------- /Config/fmri.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 24 5 | feature_size: 50 6 | n_layer_enc: 4 7 | n_layer_dec: 4 8 | d_model: 96 # 4 X 24 9 | timesteps: 1000 10 | sampling_timesteps: 1000 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.0 16 | resid_pd: 0.0 17 | kernel_size: 5 18 | padding_size: 2 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 15000 23 | results_folder: ./Checkpoints_fmri 24 | gradient_accumulate_every: 2 25 | save_cycle: 1500 # max_epochs // 10 26 | ema: 27 | decay: 0.995 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 3000 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 500 40 | verbose: False 41 | 42 | dataloader: 43 | train_dataset: 44 | target: Utils.Data_utils.real_datasets.fMRIDataset 45 | params: 46 | name: fMRI 47 | proportion: 1.0 # Set to rate < 1 if training conditional generation 48 | 
data_root: ./Data/datasets/fMRI 49 | window: 24 # seq_length 50 | save2npy: True 51 | neg_one_to_one: True 52 | seed: 123 53 | period: train 54 | 55 | test_dataset: 56 | target: Utils.Data_utils.real_datasets.fMRIDataset 57 | params: 58 | name: fMRI 59 | proportion: 0.9 # rate 60 | data_root: ./Data/datasets/fMRI 61 | window: 24 # seq_length 62 | save2npy: True 63 | neg_one_to_one: True 64 | seed: 123 65 | period: test 66 | style: separate 67 | distribution: geometric 68 | coefficient: 1.0e-2 69 | step_size: 5.0e-2 70 | sampling_steps: 250 71 | 72 | batch_size: 64 73 | sample_size: 256 74 | shuffle: True -------------------------------------------------------------------------------- /Config/mujoco.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 24 5 | feature_size: 14 6 | n_layer_enc: 3 7 | n_layer_dec: 2 8 | d_model: 64 # 4 X 16 9 | timesteps: 1000 10 | sampling_timesteps: 1000 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.0 16 | resid_pd: 0.0 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 14000 23 | results_folder: ./Checkpoints_mujoco 24 | gradient_accumulate_every: 2 25 | save_cycle: 1400 # max_epochs // 10 26 | ema: 27 | decay: 0.995 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 3000 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 500 40 | verbose: False 41 | 42 | dataloader: 43 | train_dataset: 44 | target: Utils.Data_utils.mujoco_dataset.MuJoCoDataset 45 | params: 46 | num: 10000 47 | dim: 14 48 | window: 24 # seq_length 49 | save2npy: True 50 | neg_one_to_one: True 51 | seed: 123 52 | period: train 53 | 54 | test_dataset: 55 | target: 
Utils.Data_utils.mujoco_dataset.MuJoCoDataset 56 | params: 57 | num: 1000 58 | dim: 14 59 | window: 24 # seq_length 60 | save2npy: True 61 | neg_one_to_one: True 62 | seed: 123 63 | style: separate 64 | period: test 65 | distribution: geometric 66 | coefficient: 1.0e-2 67 | step_size: 5.0e-2 68 | sampling_steps: 250 69 | 70 | batch_size: 128 71 | sample_size: 256 72 | shuffle: True -------------------------------------------------------------------------------- /Config/mujoco_sssd.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 100 5 | feature_size: 14 6 | n_layer_enc: 3 7 | n_layer_dec: 3 8 | d_model: 64 # 4 X 16 9 | timesteps: 500 10 | sampling_timesteps: 500 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0 16 | resid_pd: 0 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 12000 23 | results_folder: ./Checkpoints_mujoco_sssd 24 | gradient_accumulate_every: 2 25 | save_cycle: 1200 # max_epochs // 10 26 | ema: 27 | decay: 0.995 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 3000 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 500 40 | verbose: False 41 | -------------------------------------------------------------------------------- /Config/sines.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 24 5 | feature_size: 5 6 | n_layer_enc: 1 7 | n_layer_dec: 2 8 | d_model: 64 # 4 X 16 9 | timesteps: 500 10 | sampling_timesteps: 500 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | 
mlp_hidden_times: 4 15 | attn_pd: 0.0 16 | resid_pd: 0.0 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 12000 23 | results_folder: ./Checkpoints_sine 24 | gradient_accumulate_every: 2 25 | save_cycle: 1200 # max_epochs // 10 26 | ema: 27 | decay: 0.995 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 3000 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 500 40 | verbose: False 41 | 42 | dataloader: 43 | train_dataset: 44 | target: Utils.Data_utils.sine_dataset.SineDataset 45 | params: 46 | num: 10000 47 | dim: 5 48 | window: 24 # seq_length 49 | save2npy: True 50 | neg_one_to_one: True 51 | seed: 123 52 | period: train 53 | 54 | test_dataset: 55 | target: Utils.Data_utils.sine_dataset.SineDataset 56 | params: 57 | num: 1000 58 | dim: 5 59 | window: 24 # seq_length 60 | save2npy: True 61 | neg_one_to_one: True 62 | seed: 123 63 | style: separate 64 | period: test 65 | distribution: geometric 66 | coefficient: 1.0e-2 67 | step_size: 5.0e-2 68 | sampling_steps: 200 69 | 70 | batch_size: 128 71 | sample_size: 256 72 | shuffle: True -------------------------------------------------------------------------------- /Config/solar.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 192 5 | feature_size: 128 6 | n_layer_enc: 4 7 | n_layer_dec: 4 8 | d_model: 96 # 4 X 24 9 | timesteps: 500 10 | sampling_timesteps: 500 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.0 16 | resid_pd: 0.0 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 1500 23 | results_folder: ./Checkpoints_solar 24 | gradient_accumulate_every: 2 25 | 
save_cycle: 150 # max_epochs // 10 26 | ema: 27 | decay: 0.995 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 300 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 100 40 | verbose: False 41 | -------------------------------------------------------------------------------- /Config/solar_update.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 192 5 | feature_size: 137 6 | n_layer_enc: 4 7 | n_layer_dec: 4 8 | d_model: 96 # 4 X 24 9 | timesteps: 500 10 | sampling_timesteps: 500 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.5 16 | resid_pd: 0.5 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 1000 23 | results_folder: ./Checkpoints_solar_nips 24 | gradient_accumulate_every: 2 25 | save_cycle: 100 # max_epochs // 10 26 | ema: 27 | decay: 0.9 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 300 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 100 40 | verbose: False 41 | -------------------------------------------------------------------------------- /Config/stocks.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: Models.interpretable_diffusion.gaussian_diffusion.Diffusion_TS 3 | params: 4 | seq_length: 24 5 | feature_size: 6 6 | n_layer_enc: 2 7 | n_layer_dec: 2 8 | d_model: 64 # 4 X 16 9 | timesteps: 500 10 | sampling_timesteps: 500 11 | loss_type: 'l1' 12 | beta_schedule: 'cosine' 13 | n_heads: 4 14 | mlp_hidden_times: 4 15 | attn_pd: 0.0 16 | resid_pd: 
0.0 17 | kernel_size: 1 18 | padding_size: 0 19 | 20 | solver: 21 | base_lr: 1.0e-5 22 | max_epochs: 10000 23 | results_folder: ./Checkpoints_stock 24 | gradient_accumulate_every: 2 25 | save_cycle: 1000 # max_epochs // 10 26 | ema: 27 | decay: 0.995 28 | update_interval: 10 29 | 30 | scheduler: 31 | target: engine.lr_sch.ReduceLROnPlateauWithWarmup 32 | params: 33 | factor: 0.5 34 | patience: 2000 35 | min_lr: 1.0e-5 36 | threshold: 1.0e-1 37 | threshold_mode: rel 38 | warmup_lr: 8.0e-4 39 | warmup: 500 40 | verbose: False 41 | 42 | dataloader: 43 | train_dataset: 44 | target: Utils.Data_utils.real_datasets.CustomDataset 45 | params: 46 | name: stock 47 | proportion: 1.0 # Set to rate < 1 if training conditional generation 48 | data_root: ./Data/datasets/stock_data.csv 49 | window: 24 # seq_length 50 | save2npy: True 51 | neg_one_to_one: True 52 | seed: 123 53 | period: train 54 | 55 | test_dataset: 56 | target: Utils.Data_utils.real_datasets.CustomDataset 57 | params: 58 | name: stock 59 | proportion: 0.9 # rate 60 | data_root: ./Data/datasets/stock_data.csv 61 | window: 24 # seq_length 62 | save2npy: True 63 | neg_one_to_one: True 64 | seed: 123 65 | period: test 66 | style: separate 67 | distribution: geometric 68 | coefficient: 1.0e-2 69 | step_size: 5.0e-2 70 | sampling_steps: 200 71 | 72 | batch_size: 64 73 | sample_size: 256 74 | shuffle: True -------------------------------------------------------------------------------- /Data/Place the dataset here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Y-debug-sys/Diffusion-TS/566307e6cf2d8095e58de4c6e3a6ae965b69b5b5/Data/Place the dataset here -------------------------------------------------------------------------------- /Data/build_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from Utils.io_utils import instantiate_from_config 3 | 4 | 5 | def build_dataloader(config, 
args=None): 6 | batch_size = config['dataloader']['batch_size'] 7 | jud = config['dataloader']['shuffle'] 8 | config['dataloader']['train_dataset']['params']['output_dir'] = args.save_dir 9 | dataset = instantiate_from_config(config['dataloader']['train_dataset']) 10 | 11 | dataloader = torch.utils.data.DataLoader(dataset, 12 | batch_size=batch_size, 13 | shuffle=jud, 14 | num_workers=0, 15 | pin_memory=True, 16 | sampler=None, 17 | drop_last=jud) 18 | 19 | dataload_info = { 20 | 'dataloader': dataloader, 21 | 'dataset': dataset 22 | } 23 | 24 | return dataload_info 25 | 26 | def build_dataloader_cond(config, args=None): 27 | batch_size = config['dataloader']['sample_size'] 28 | config['dataloader']['test_dataset']['params']['output_dir'] = args.save_dir 29 | if args.mode == 'infill': 30 | config['dataloader']['test_dataset']['params']['missing_ratio'] = args.missing_ratio 31 | elif args.mode == 'predict': 32 | config['dataloader']['test_dataset']['params']['predict_length'] = args.pred_len 33 | test_dataset = instantiate_from_config(config['dataloader']['test_dataset']) 34 | 35 | dataloader = torch.utils.data.DataLoader(test_dataset, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=0, 39 | pin_memory=True, 40 | sampler=None, 41 | drop_last=False) 42 | 43 | dataload_info = { 44 | 'dataloader': dataloader, 45 | 'dataset': test_dataset 46 | } 47 | 48 | return dataload_info 49 | 50 | 51 | if __name__ == '__main__': 52 | pass 53 | 54 | -------------------------------------------------------------------------------- /Data/readme.md: -------------------------------------------------------------------------------- 1 | ## 🚄 Get Started 2 | 3 | Please download **dataset.zip**, then unzip and copy it to the location indicated by '`Place the dataset here`'. 4 | 5 | > 🔔 **dataset.zip** can be downloaded [here](https://drive.google.com/file/d/11DI22zKWtHjXMnNGPWNUbyGz-JiEtZy6/view?usp=sharing). 
6 | 7 | ## 🔨 Build Dataloader 8 | 9 | ### 🕶️ Pipeline Overview 10 | 11 | The figure below shows everything that happens after calling *build_dataloader* 12 | 13 |
14 | 15 |
16 | 17 | Below we show the specific meaning of the saved file: 18 | 19 | - `{name}_norm_truth_{length}_train.npy` - The saved data is used for training the model, which has been normalized into [0, 1]. 20 | - `{name}_norm_truth_{length}_test.npy` - The saved data is used for inference, which has been normalized into [0, 1]. 21 | - `{name}_ground_truth_{length}_train.npy` - The saved data is used for training the model, however, it is raw data that has not been normalized. 22 | - `{name}_ground_truth_{length}_test.npy` - The saved data is used for inference, however, it is raw data that has not been normalized. 23 | - `{name}_masking_{length}.npy` - The saved mask-sequences indicate the generation target of imputation or forecasting. 24 | 25 | > 🔔 Note that the generated time series `ddpm_fake_{name}.npy` or `ddpm_{mode}_{name}.npy` are also normalized into [0, 1]. You can restore them by adding the following code in **main.py**: 26 | ``` 27 | line 86: samples = dataset.scaler.inverse_transform(samples.reshape(-1, samples.shape[-1])).reshape(samples.shape) 28 | line 93: samples = dataset.scaler.inverse_transform(samples.reshape(-1, samples.shape[-1])).reshape(samples.shape) 29 | ``` 30 | 31 | ### 📝 Custom Dataset 32 | 33 | Real-world sequences (or any self-prepared sequences) need to be configured via the following: 34 | 35 | * Create and check the settings in your **.yaml file** and modify it such as *seq_length* and *feature_size* etc. if necessary. 36 | * Convert the real-world time series into **.csv file**, then put it to the repo like our template datasets. 37 | * Make sure that non-numeric rows and columns are not included in the training data, or you may need to modify codes in **./Utils/Data_utils/real_datasets.py**. 38 | - Remove the header if it exists: 39 | ``` 40 | line 132: df = pd.read_csv(filepath, header=0) 41 | # set `header=None` if it does not exist. 
42 | ``` 43 | - Delete rows and columns, here using the first column as an example: 44 | ``` 45 | line 133: df.drop(df.columns[0], axis=1, inplace=True) 46 | ``` 47 | 48 | > 🔔 Please set `use_ff=False` at line 54 in **./Models/interpretable_diffusion/gaussian_diffusion.py** if your temporal data is highly irregular. -------------------------------------------------------------------------------- /Experiments/metric_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Context-FID Score Presentation\n", 8 | "## Necessary packages and functions call\n", 9 | "\n", 10 | "- Context-FID score: A useful metric measures how well the synthetic time series windows ”fit” into the local context of the time series" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "import torch\n", 21 | "import numpy as np\n", 22 | "import sys\n", 23 | "sys.path.append(os.path.join(os.path.dirname('__file__'), '../'))\n", 24 | "from Utils.context_fid import Context_FID\n", 25 | "from Utils.metric_utils import display_scores\n", 26 | "from Utils.cross_correlation import CrossCorrelLoss" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Data Loading\n", 34 | "\n", 35 | "Load original dataset and preprocess the loaded data."
36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "iterations = 5\n", 45 | "ori_data = np.load('../toy_exp/samples/sine_ground_truth_24_train.npy')\n", 46 | "# ori_data = np.load('../OUTPUT/{dataset_name}/samples/{dataset_name}_norm_truth_{seq_length}_train.npy') # Uncomment the line if dataset other than Sine is used.\n", 47 | "fake_data = np.load('../toy_exp/ddpm_fake_sines.npy')" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "## Context-FID Score\n", 55 | "\n", 56 | "- The Frechet Inception distance-like score is based on unsupervised time series embeddings. It is able to score the fit of the fixed length synthetic samples into their context of (often much longer) true time series.\n", 57 | "\n", 58 | "- The lowest scoring models correspond to the best performing models in downstream tasks" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "Iter 0: context-fid = 0.007570160590888924 \n", 71 | "\n", 72 | "Iter 1: context-fid = 0.008166182104461794 \n", 73 | "\n", 74 | "Iter 2: context-fid = 0.007970870497297756 \n", 75 | "\n", 76 | "Iter 3: context-fid = 0.007697393798102688 \n", 77 | "\n", 78 | "Iter 4: context-fid = 0.008525671090148759 \n", 79 | "\n", 80 | "Final Score: 0.007986055616179984 ± 0.00047287486643304236\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "context_fid_score = []\n", 86 | "\n", 87 | "for i in range(iterations):\n", 88 | " context_fid = Context_FID(ori_data[:], fake_data[:ori_data.shape[0]])\n", 89 | " context_fid_score.append(context_fid)\n", 90 | " print(f'Iter {i}: ', 'context-fid =', context_fid, '\\n')\n", 91 | " \n", 92 | "display_scores(context_fid_score)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "## 
Correlational Score\n", 100 | "\n", 101 | "- The metric uses the absolute error of the auto-correlation estimator by real data and synthetic data as the metric to assess the temporal dependency.\n", 102 | "\n", 103 | "- For d > 1, it uses the l1-norm of the difference between cross correlation matrices." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "def random_choice(size, num_select=100):\n", 113 | " select_idx = np.random.randint(low=0, high=size, size=(num_select,))\n", 114 | " return select_idx" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 5, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "Iter 0: cross-correlation = 0.019852216581699826 \n", 127 | "\n", 128 | "Iter 1: cross-correlation = 0.01816951370252664 \n", 129 | "\n", 130 | "Iter 2: cross-correlation = 0.022373661672091448 \n", 131 | "\n", 132 | "Iter 3: cross-correlation = 0.012407943886992933 \n", 133 | "\n", 134 | "Iter 4: cross-correlation = 0.010309792931556355 \n", 135 | "\n", 136 | "Final Score: 0.01662262575497344 ± 0.006316425881906014\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "x_real = torch.from_numpy(ori_data)\n", 142 | "x_fake = torch.from_numpy(fake_data)\n", 143 | "\n", 144 | "correlational_score = []\n", 145 | "size = int(x_real.shape[0] / iterations)\n", 146 | "\n", 147 | "for i in range(iterations):\n", 148 | " real_idx = random_choice(x_real.shape[0], size)\n", 149 | " fake_idx = random_choice(x_fake.shape[0], size)\n", 150 | " corr = CrossCorrelLoss(x_real[real_idx, :, :], name='CrossCorrelLoss')\n", 151 | " loss = corr.compute(x_fake[fake_idx, :, :])\n", 152 | " correlational_score.append(loss.item())\n", 153 | " print(f'Iter {i}: ', 'cross-correlation =', loss.item(), '\\n')\n", 154 | "\n", 155 | "display_scores(correlational_score)" 156 | ] 157 | } 158 | ], 159 | 
"metadata": { 160 | "kernelspec": { 161 | "display_name": "PT3.8", 162 | "language": "python", 163 | "name": "pt3.8" 164 | }, 165 | "language_info": { 166 | "codemirror_mode": { 167 | "name": "ipython", 168 | "version": 3 169 | }, 170 | "file_extension": ".py", 171 | "mimetype": "text/x-python", 172 | "name": "python", 173 | "nbconvert_exporter": "python", 174 | "pygments_lexer": "ipython3", 175 | "version": "3.8.13" 176 | } 177 | }, 178 | "nbformat": 4, 179 | "nbformat_minor": 2 180 | } 181 | -------------------------------------------------------------------------------- /Experiments/metric_tensorflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Metric Presentation and Visualization\n", 8 | "## Necessary packages and functions call\n", 9 | "\n", 10 | "- DDPM-TS: Interpretable Diffusion for Time Series Generation\n", 11 | "- Metrics: \n", 12 | " - discriminative_metrics\n", 13 | " - predictive_metrics\n", 14 | " - visualization" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 4, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import os\n", 24 | "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", 25 | "\n", 26 | "import sys\n", 27 | "sys.path.append(os.path.join(os.path.dirname('__file__'), '../'))\n", 28 | "\n", 29 | "import warnings\n", 30 | "warnings.filterwarnings(\"ignore\")\n", 31 | "\n", 32 | "import numpy as np\n", 33 | "import tensorflow as tf\n", 34 | "gpus = tf.config.experimental.list_physical_devices('GPU')\n", 35 | "if gpus:\n", 36 | " try:\n", 37 | " for gpu in gpus:\n", 38 | " tf.config.experimental.set_memory_growth(gpu, True)\n", 39 | " except RuntimeError as e:\n", 40 | " print(e)\n", 41 | "\n", 42 | "from Utils.metric_utils import display_scores\n", 43 | "from Utils.discriminative_metric import discriminative_score_metrics\n", 44 | "from Utils.predictive_metric import 
predictive_score_metrics" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Data Loading\n", 52 | "\n", 53 | "Load original dataset and preprocess the loaded data." 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 5, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "iterations = 5\n", 63 | "ori_data = np.load('../toy_exp/samples/sine_ground_truth_24_train.npy')\n", 64 | "# ori_data = np.load('../OUTPUT/{dataset_name}/samples/{dataset_name}_norm_truth_{seq_length}_train.npy') # Uncomment the line if dataset other than Sine is used.\n", 65 | "fake_data = np.load('../toy_exp/ddpm_fake_sines.npy')" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Evaluate the generated data\n", 73 | "\n", 74 | "### 1. Discriminative score\n", 75 | "\n", 76 | "To evaluate the classification accuracy between original and synthetic data using post-hoc RNN network. The output is | classification accuracy - 0.5 |.\n", 77 | "\n", 78 | "- metric_iteration: the number of iterations for metric computation." 
79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 6, 84 | "metadata": { 85 | "scrolled": false 86 | }, 87 | "outputs": [ 88 | { 89 | "name": "stderr", 90 | "output_type": "stream", 91 | "text": [ 92 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:48<00:00, 8.77it/s]\n" 93 | ] 94 | }, 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "Iter 0: 0.00649999999999995 , 0.5825 , 0.4305 \n", 100 | "\n" 101 | ] 102 | }, 103 | { 104 | "name": "stderr", 105 | "output_type": "stream", 106 | "text": [ 107 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:48<00:00, 8.75it/s]\n" 108 | ] 109 | }, 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "Iter 1: 0.0034999999999999476 , 0.4425 , 0.5645 \n", 115 | "\n" 116 | ] 117 | }, 118 | { 119 | "name": "stderr", 120 | "output_type": "stream", 121 | "text": [ 122 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:32<00:00, 9.39it/s]\n" 123 | ] 124 | }, 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "Iter 2: 0.0007500000000000284 , 0.46 , 0.5415 \n", 130 | "\n" 131 | ] 132 | }, 133 | { 134 | "name": "stderr", 135 | "output_type": "stream", 136 | "text": [ 137 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:46<00:00, 8.82it/s]\n" 138 | ] 139 | }, 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "Iter 3: 0.02200000000000002 , 0.535 , 0.509 \n", 145 | "\n" 146 | ] 147 | }, 148 | { 149 | "name": "stderr", 150 | "output_type": "stream", 151 | "text": [ 152 | "training: 100%|████████████████████████████████████████████████████████████████████| 2000/2000 [03:48<00:00, 8.77it/s]\n" 153 | ] 154 | }, 155 | { 156 | "name": "stdout", 157 | "output_type": "stream", 158 | "text": [ 159 
| "Iter 4: 0.007249999999999979 , 0.5365 , 0.478 \n", 160 | "\n", 161 | "sine:\n", 162 | "Final Score: 0.007999999999999985 ± 0.010231963047355045\n", 163 | "\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 168 | "discriminative_score = []\n", 169 | "\n", 170 | "for i in range(iterations):\n", 171 | " temp_disc, fake_acc, real_acc = discriminative_score_metrics(ori_data[:], fake_data[:ori_data.shape[0]])\n", 172 | " discriminative_score.append(temp_disc)\n", 173 | " print(f'Iter {i}: ', temp_disc, ',', fake_acc, ',', real_acc, '\\n')\n", 174 | " \n", 175 | "print('sine:')\n", 176 | "display_scores(discriminative_score)\n", 177 | "print()" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "## Evaluate the generated data\n", 185 | "\n", 186 | "### 2. Predictive score\n", 187 | "\n", 188 | "To evaluate the prediction performance on train on synthetic, test on real setting. More specifically, we use Post-hoc RNN architecture to predict one-step ahead and report the performance in terms of MAE. \n", 189 | "\n", 190 | "The model learns to predict the last dimension with one more step." 
191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 7, 196 | "metadata": { 197 | "scrolled": false 198 | }, 199 | "outputs": [ 200 | { 201 | "name": "stderr", 202 | "output_type": "stream", 203 | "text": [ 204 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [04:59<00:00, 16.67it/s]\n" 205 | ] 206 | }, 207 | { 208 | "name": "stdout", 209 | "output_type": "stream", 210 | "text": [ 211 | "0 epoch: 0.09314072894950631 \n", 212 | "\n" 213 | ] 214 | }, 215 | { 216 | "name": "stderr", 217 | "output_type": "stream", 218 | "text": [ 219 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [04:52<00:00, 17.08it/s]\n" 220 | ] 221 | }, 222 | { 223 | "name": "stdout", 224 | "output_type": "stream", 225 | "text": [ 226 | "1 epoch: 0.09352861018187178 \n", 227 | "\n" 228 | ] 229 | }, 230 | { 231 | "name": "stderr", 232 | "output_type": "stream", 233 | "text": [ 234 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [05:04<00:00, 16.41it/s]\n" 235 | ] 236 | }, 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "2 epoch: 0.09322281220011404 \n", 242 | "\n" 243 | ] 244 | }, 245 | { 246 | "name": "stderr", 247 | "output_type": "stream", 248 | "text": [ 249 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [04:56<00:00, 16.85it/s]\n" 250 | ] 251 | }, 252 | { 253 | "name": "stdout", 254 | "output_type": "stream", 255 | "text": [ 256 | "3 epoch: 0.09313142589665173 \n", 257 | "\n" 258 | ] 259 | }, 260 | { 261 | "name": "stderr", 262 | "output_type": "stream", 263 | "text": [ 264 | "training: 100%|████████████████████████████████████████████████████████████████████| 5000/5000 [05:01<00:00, 16.57it/s]\n" 265 | ] 266 | }, 267 | { 268 | "name": "stdout", 269 | "output_type": "stream", 270 | "text": [ 271 | "4 epoch: 0.09342775384532898 \n", 272 | 
"\n", 273 | "sine:\n", 274 | "Final Score: 0.09329026621469458 ± 0.00022198749029314803\n", 275 | "\n" 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "predictive_score = []\n", 281 | "for i in range(iterations):\n", 282 | " temp_pred = predictive_score_metrics(ori_data, fake_data[:ori_data.shape[0]])\n", 283 | " predictive_score.append(temp_pred)\n", 284 | " print(i, ' epoch: ', temp_pred, '\\n')\n", 285 | " \n", 286 | "print('sine:')\n", 287 | "display_scores(predictive_score)\n", 288 | "print()" 289 | ] 290 | } 291 | ], 292 | "metadata": { 293 | "kernelspec": { 294 | "display_name": "Tensorflow_3.8", 295 | "language": "python", 296 | "name": "python3" 297 | }, 298 | "language_info": { 299 | "codemirror_mode": { 300 | "name": "ipython", 301 | "version": 3 302 | }, 303 | "file_extension": ".py", 304 | "mimetype": "text/x-python", 305 | "name": "python", 306 | "nbconvert_exporter": "python", 307 | "pygments_lexer": "ipython3", 308 | "version": "3.8.13" 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 2 313 | } 314 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 XXX 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts the result back to the input dtype."""

    def forward(self, x):
        # Normalize in fp32 for numerical stability (e.g. under fp16 training),
        # then restore the caller's dtype.
        return super().forward(x.float()).type(x.dtype)


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    # 8 groups is the convention used throughout these models.
    return GroupNorm32(8, channels)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: spatial dimensionality of the convolution (1, 2 or 3).
    :raises ValueError: for any other value of ``dims``.
    """
    conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    try:
        conv_cls = conv_types[dims]
    except KeyError:
        raise ValueError(f"unsupported dimensions: {dims}") from None
    return conv_cls(*args, **kwargs)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, seq_len = qkv.shape
        assert width % (3 * self.n_heads) == 0
        head_dim = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        # Scale q and k each by ch^(-1/4) (equivalent to dividing scores by
        # sqrt(ch)); more stable with f16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(batch * self.n_heads, head_dim, seq_len),
            (k * scale).view(batch * self.n_heads, head_dim, seq_len),
        )
        # Softmax in fp32 for stability, then cast back.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = torch.einsum(
            "bts,bcs->bct", weight, v.reshape(batch * self.n_heads, head_dim, seq_len)
        )
        return out.reshape(batch, -1, seq_len)


class AttentionPool2d(nn.Module):
    """
    Attention-based global pooling over the spatial positions of a feature map.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        flat = x.reshape(b, c, -1)  # NC(HW)
        # Prepend the spatial mean as a "summary" token, then attend over all tokens.
        flat = torch.cat([flat.mean(dim=-1, keepdim=True), flat], dim=-1)  # NC(HW+1)
        pooled = self.c_proj(self.attention(self.qkv_proj(flat)))
        # The summary position carries the pooled representation.
        return pooled[:, :, 0]
class FullAttention(nn.Module):
    """Multi-head self-attention that also returns the head-averaged attention map."""

    def __init__(self,
                 n_embd,            # the embed dim
                 n_head,            # the number of heads
                 attn_pdrop=0.1,    # attention dropout prob
                 resid_pdrop=0.1,   # residual attention dropout prob
                 ):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, mask=None):
        batch, seq_len, dim = x.size()
        head_dim = dim // self.n_head

        def split_heads(proj):
            # (B, T, C) -> (B, nh, T, hs)
            return proj(x).view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q = split_heads(self.query)
        k = split_heads(self.key)
        v = split_heads(self.value)

        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))  # (B, nh, T, T)
        scores = self.attn_drop(F.softmax(scores, dim=-1))
        ctx = (scores @ v).transpose(1, 2).contiguous().view(batch, seq_len, dim)
        att_map = scores.mean(dim=1, keepdim=False)  # average heads -> (B, T, T)

        # output projection
        return self.resid_drop(self.proj(ctx)), att_map


class EncoderBlock(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self,
                 n_embd=1024,
                 n_head=16,
                 attn_pdrop=0.1,
                 resid_pdrop=0.1,
                 mlp_hidden_times=4,
                 activate='GELU',
                 ):
        super().__init__()

        # Timestep-conditioned norm before attention, plain LayerNorm before MLP.
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )

        assert activate in ['GELU', 'GELU2']
        act = nn.GELU() if activate == 'GELU' else GELU2()

        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, timestep, mask=None, label_emb=None):
        attn_out, att = self.attn(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + attn_out               # residual around attention
        x = x + self.mlp(self.ln2(x))  # residual around MLP
        return x, att


class Encoder(nn.Module):
    """Stack of EncoderBlocks, threading the diffusion timestep through each layer."""

    def __init__(
        self,
        n_layer=14,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.,
        resid_pdrop=0.,
        mlp_hidden_times=4,
        block_activate='GELU',
    ):
        super().__init__()

        self.blocks = nn.Sequential(*[
            EncoderBlock(
                n_embd=n_embd,
                n_head=n_head,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop,
                mlp_hidden_times=mlp_hidden_times,
                activate=block_activate,
            )
            for _ in range(n_layer)
        ])

    def forward(self, input, t, padding_masks=None, label_emb=None):
        # Attention maps from the blocks are discarded; only features flow on.
        x = input
        for block in self.blocks:
            x, _ = block(x, t, mask=padding_masks, label_emb=label_emb)
        return x


class Classifier(nn.Module):
    """
    Transformer encoder plus attention-pooling head that classifies a (noisy)
    time series conditioned on the diffusion timestep t.
    """

    def __init__(
        self,
        feature_size,
        seq_length,
        num_classes=2,
        n_layer_enc=5,
        n_embd=1024,
        n_heads=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate='GELU',
        max_len=2048,
        num_head_channels=8,
        **kwargs
    ):
        super().__init__()
        self.emb = Conv_MLP(feature_size, n_embd, resid_pdrop=resid_pdrop)
        self.encoder = Encoder(n_layer_enc, n_embd, n_heads, attn_pdrop,
                               resid_pdrop, mlp_hidden_times, block_activate)
        self.pos_enc = LearnablePositionalEncoding(n_embd, dropout=resid_pdrop, max_len=max_len)

        assert num_head_channels != -1
        # Pool over the sequence dimension (seq_length channels after encoding).
        self.out = nn.Sequential(
            normalization(seq_length),
            nn.SiLU(),
            AttentionPool2d(
                seq_length, num_head_channels, num_classes
            ),
        )

    def forward(self, input, t, padding_masks=None):
        hidden = self.pos_enc(self.emb(input))
        hidden = self.encoder(hidden, t, padding_masks=padding_masks)
        return self.out(hidden)


if __name__ == '__main__':
    device = torch.device('cuda:0')
    input = torch.randn(128, 64, 14).to(device)
    t = torch.randint(0, 1000, (128, ), device=device)

    def count_model_parameters(model):
        # Parameter counts and fp32 sizes, split into total vs. trainable.
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        to_mb = lambda count: count * 4 / (1024 * 1024)  # assuming float32
        return {'Total Parameters': total_params, 'Trainable Parameters': trainable_params,
                'Total Size (MB)': to_mb(total_params), 'Trainable Size (MB)': to_mb(trainable_params)}

    model = Classifier(
        feature_size=14,
        n_layer_enc=3,
        seq_length=64,
        num_classes=2,
        n_embd=64,
        n_heads=4,
        attn_pdrop=0.,
        resid_pdrop=0.,
        mlp_hidden_times=4,
        block_activate='GELU',
        max_len=64,
    ).to(device)
    print(count_model_parameters(model))
    output = model(input, t)
    print(output.shape)
def exists(x):
    """Return True when *x* is not None."""
    return x is not None

def default(val, d):
    """
    Return *val* if it is not None, otherwise the default.

    Args:
        val: The value to check.
        d: The default value, or a zero-argument callable producing it.

    Returns:
        *val* when present, otherwise the (possibly computed) default.
    """
    if exists(val):
        return val
    return d() if callable(d) else d

def identity(t, *args, **kwargs):
    """Return *t* unchanged; any extra arguments are accepted and ignored."""
    return t

def extract(a, t, x_shape):
    """
    Gather ``a[t]`` per batch element and reshape for broadcasting against ``x_shape``.

    Args:
        a (torch.Tensor): tensor of per-timestep values to index.
        t (torch.Tensor): batch of indices into the last dim of ``a``.
        x_shape (tuple): shape of the tensor the result will broadcast with.

    Returns:
        torch.Tensor: shape ``(batch, 1, 1, ...)`` with ``len(x_shape) - 1`` trailing ones.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))

def cond_fn(x, t, classifier=None, y=None, classifier_scale=1.):
    """
    Gradient of the classifier's log-probability of *y* with respect to *x*.

    Args:
        x (torch.Tensor): input tensor to differentiate with respect to.
        t (torch.Tensor): diffusion timestep tensor, passed to the classifier.
        classifier (nn.Module): model mapping (x, t) to class logits.
        y (torch.Tensor): target labels; must not be None.
        classifier_scale (float): multiplier applied to the gradient.

    Returns:
        torch.Tensor: d log p(y|x) / dx, scaled by ``classifier_scale``.
    """
    assert y is not None
    with torch.enable_grad():
        # Detach so the gradient graph starts fresh at x.
        x_in = x.detach().requires_grad_(True)
        log_probs = F.log_softmax(classifier(x_in, t), dim=-1)
        # Pick each sample's log-probability of its own label.
        selected = log_probs[range(log_probs.shape[0]), y.view(-1)]
        grad, = torch.autograd.grad(selected.sum(), x_in)
        return grad * classifier_scale

# normalization functions

def normalize_to_neg_one_to_one(x):
    """Map data from [0, 1] to [-1, 1]."""
    return x * 2 - 1

def unnormalize_to_zero_to_one(x):
    """Map data from [-1, 1] back to [0, 1]."""
    return (x + 1) * 0.5


# sinusoidal positional embeds

class SinusoidalPosEmb(nn.Module):
    """
    Sinusoidal positional embedding module.

    Maps a 1-D batch of positions/timesteps to ``dim``-dimensional embeddings
    built from sines and cosines at geometrically spaced frequencies.

    Attributes:
        dim (int): The dimension of the positional embeddings.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000.
        freqs = torch.exp(
            torch.arange(half_dim, device=x.device) * -(math.log(10000) / (half_dim - 1))
        )
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class LearnablePositionalEncoding(nn.Module):
    """
    Learnable positional encoding module.

    A trainable embedding is kept per position (up to ``max_len``) and added
    to the input before dropout.

    Attributes:
        d_model (int): The dimension of the positional embeddings.
        dropout (float): The dropout rate applied after adding the embeddings.
        max_len (int): The maximum length of the input sequences.
    """
    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super(LearnablePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Each position gets its own embedding.
        # Since indices are always 0 ... max_len, we don't have to do a look-up.
        self.pe = nn.Parameter(torch.empty(1, max_len, d_model))  # requires_grad automatically set to True
        nn.init.uniform_(self.pe, -0.02, 0.02)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required);
               sequence length must not exceed ``max_len``.
        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]
        """
        # Slice to the actual sequence length so inputs shorter than max_len
        # also work (previously the add only broadcast when len == max_len).
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    Input/output shape: [batch, length, channels]; the sequence is edge-padded
    so the output keeps the input length (for stride=1).
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # padding on the both ends of time series (replicate edge values)
        pad_end = (self.kernel_size - 1) // 2
        pad_front = self.kernel_size - 1 - pad_end
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block: splits a series into residual and trend,
    where the trend is a moving average and residual = input - trend.
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        moving_mean = self.moving_avg(x)
        res = x - moving_mean
        return res, moving_mean


class series_decomp_multi(nn.Module):
    """
    Series decomposition block using several moving-average kernels whose
    outputs are mixed by data-dependent softmax weights.

    Args:
        kernel_size (list[int]): kernel sizes of the parallel moving averages.
    """
    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        # nn.ModuleList (rather than a plain Python list) so the sub-modules
        # are registered: they follow .to(device)/.cuda() and appear in
        # .modules() / state_dict like any other child.
        self.moving_avg = nn.ModuleList(moving_avg(kernel, stride=1) for kernel in kernel_size)
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        means = [avg(x).unsqueeze(-1) for avg in self.moving_avg]
        means = torch.cat(means, dim=-1)
        # Per-element softmax weights over the kernel dimension.
        weights = nn.Softmax(-1)(self.layer(x.unsqueeze(-1)))
        moving_mean = torch.sum(means * weights, dim=-1)
        res = x - moving_mean
        return res, moving_mean


class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for Sequential module. """
    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.transpose(*self.shape)


class Conv_MLP(nn.Module):
    """Pointwise-ish embedding: Conv1d over the feature dim with dropout.

    Input [B, L, in_dim] -> output [B, L, out_dim].
    """
    def __init__(self, in_dim, out_dim, resid_pdrop=0.):
        super().__init__()
        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_dim, out_dim, 3, stride=1, padding=1),
            nn.Dropout(p=resid_pdrop),
        )

    def forward(self, x):
        return self.sequential(x).transpose(1, 2)


class Transformer_MLP(nn.Module):
    """Convolutional feed-forward block operating on [B, n_embd, L] tensors."""
    def __init__(self, n_embd, mlp_hidden_times, act, resid_pdrop):
        super().__init__()
        self.sequential = nn.Sequential(
            nn.Conv1d(in_channels=n_embd, out_channels=int(mlp_hidden_times * n_embd), kernel_size=1, padding=0),
            act,
            nn.Conv1d(in_channels=int(mlp_hidden_times * n_embd), out_channels=int(mlp_hidden_times * n_embd), kernel_size=3, padding=1),
            act,
            nn.Conv1d(in_channels=int(mlp_hidden_times * n_embd), out_channels=n_embd, kernel_size=3, padding=1),
            nn.Dropout(p=resid_pdrop),
        )

    def forward(self, x):
        return self.sequential(x)
class GELU2(nn.Module):
    """Sigmoid-approximated GELU activation: ``x * sigmoid(1.702 * x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.sigmoid replaces the deprecated F.sigmoid alias; numerically identical.
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    """
    Adaptive LayerNorm: per-sample scale and shift are predicted from the
    diffusion timestep embedding (plus an optional label embedding).
    """
    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        # Affine params come from the timestep, so the norm itself has none.
        self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)

    def forward(self, x, timestep, label_emb=None):
        emb = self.emb(timestep)
        if label_emb is not None:
            emb = emb + label_emb
        # Predict (scale, shift); unsqueeze broadcasts over the sequence dim.
        emb = self.linear(self.silu(emb)).unsqueeze(1)
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = self.layernorm(x) * (1 + scale) + shift
        return x


class AdaInsNorm(nn.Module):
    """
    Adaptive InstanceNorm: like AdaLayerNorm but normalizes per channel over
    the sequence dimension (InstanceNorm1d) before the learned scale/shift.
    """
    def __init__(self, n_embd):
        super().__init__()
        self.emb = SinusoidalPosEmb(n_embd)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.instancenorm = nn.InstanceNorm1d(n_embd)

    def forward(self, x, timestep, label_emb=None):
        emb = self.emb(timestep)
        if label_emb is not None:
            emb = emb + label_emb
        emb = self.linear(self.silu(emb)).unsqueeze(1)
        scale, shift = torch.chunk(emb, 2, dim=2)
        # InstanceNorm1d expects (B, C, L), so transpose around the norm.
        x = self.instancenorm(x.transpose(-1, -2)).transpose(-1, -2) * (1 + scale) + shift
        return x
class TrendBlock(nn.Module):
    """
    Model trend of time series using the polynomial regressor.

    The input is projected onto a small set of polynomial basis functions
    (t, t^2, t^3) evaluated on a normalized time grid of length ``out_dim``.
    """
    def __init__(self, in_dim, out_dim, in_feat, out_feat, act):
        super(TrendBlock, self).__init__()
        trend_poly = 3  # degree of the polynomial basis
        self.trend = nn.Sequential(
            nn.Conv1d(in_channels=in_dim, out_channels=trend_poly, kernel_size=3, padding=1),
            act,
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_feat, out_feat, 3, stride=1, padding=1)
        )

        # Normalized time grid in (0, 1), then powers 1..trend_poly of it.
        grid = torch.arange(1, out_dim + 1, 1) / (out_dim + 1)
        self.poly_space = torch.stack([grid ** float(p + 1) for p in range(trend_poly)], dim=0)

    def forward(self, input):
        b, c, h = input.shape
        coeffs = self.trend(input).transpose(1, 2)
        # Combine the predicted coefficients with the polynomial basis.
        trend_vals = torch.matmul(coeffs.transpose(1, 2), self.poly_space.to(coeffs.device))
        return trend_vals.transpose(1, 2)


class MovingBlock(nn.Module):
    """
    Model trend of time series using the moving average.
    """
    def __init__(self, out_dim):
        super(MovingBlock, self).__init__()
        # Window is roughly a quarter of the length, clamped to [4, 24].
        window = max(min(int(out_dim / 4), 24), 4)
        self.decomp = series_decomp(window)

    def forward(self, input):
        b, c, h = input.shape
        residual, trend_vals = self.decomp(input)
        return residual, trend_vals
class FourierLayer(nn.Module):
    """
    Model seasonality of time series using the inverse DFT.

    Keeps only the dominant frequency components of the input (per batch
    element and channel) and re-synthesizes the signal from them.
    """
    def __init__(self, d_model, low_freq=1, factor=1):
        super().__init__()
        self.d_model = d_model
        self.factor = factor      # scales how many top frequencies are kept
        self.low_freq = low_freq  # lowest frequency bin to keep (drops DC by default)

    def forward(self, x):
        """x: (b, t, d)"""
        b, t, d = x.shape
        x_freq = torch.fft.rfft(x, dim=1)

        # Drop bins below low_freq (and the Nyquist bin when t is even).
        if t % 2 == 0:
            x_freq = x_freq[:, self.low_freq:-1]
            f = torch.fft.rfftfreq(t)[self.low_freq:-1]
        else:
            x_freq = x_freq[:, self.low_freq:]
            f = torch.fft.rfftfreq(t)[self.low_freq:]

        x_freq, index_tuple = self.topk_freq(x_freq)
        # Broadcast the frequency grid to (b, f, d) and select the kept bins.
        f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device)
        f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device)
        return self.extrapolate(x_freq, f, t)

    def extrapolate(self, x_freq, f, t):
        # Pair each kept coefficient with its conjugate at -f so the
        # reconstructed signal is real-valued.
        x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
        f = torch.cat([f, -f], dim=1)
        t = rearrange(torch.arange(t, dtype=torch.float),
                      't -> () () t ()').to(x_freq.device)

        amp = rearrange(x_freq.abs(), 'b f d -> b f () d')
        phase = rearrange(x_freq.angle(), 'b f d -> b f () d')
        # Inverse DFT restricted to the kept frequencies: a sum of cosines.
        x_time = amp * torch.cos(2 * math.pi * f * t + phase)
        return reduce(x_time, 'b f t d -> b t d', 'sum')

    def topk_freq(self, x_freq):
        # Keep the int(factor * log(length)) largest-magnitude bins,
        # independently per (batch, channel).
        length = x_freq.shape[1]
        top_k = int(self.factor * math.log(length))
        values, indices = torch.topk(x_freq.abs(), top_k, dim=1, largest=True, sorted=True)
        # Advanced-indexing tuple selecting those bins per batch/channel.
        mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2)), indexing='ij')
        index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1))
        x_freq = x_freq[index_tuple]
        return x_freq, index_tuple
103 | """ 104 | def __init__(self, in_dim, out_dim, factor=1): 105 | super(SeasonBlock, self).__init__() 106 | season_poly = factor * min(32, int(out_dim // 2)) 107 | self.season = nn.Conv1d(in_channels=in_dim, out_channels=season_poly, kernel_size=1, padding=0) 108 | fourier_space = torch.arange(0, out_dim, 1) / out_dim 109 | p1, p2 = (season_poly // 2, season_poly // 2) if season_poly % 2 == 0 \ 110 | else (season_poly // 2, season_poly // 2 + 1) 111 | s1 = torch.stack([torch.cos(2 * np.pi * p * fourier_space) for p in range(1, p1 + 1)], dim=0) 112 | s2 = torch.stack([torch.sin(2 * np.pi * p * fourier_space) for p in range(1, p2 + 1)], dim=0) 113 | self.poly_space = torch.cat([s1, s2]) 114 | 115 | def forward(self, input): 116 | b, c, h = input.shape 117 | x = self.season(input) 118 | season_vals = torch.matmul(x.transpose(1, 2), self.poly_space.to(x.device)) 119 | season_vals = season_vals.transpose(1, 2) 120 | return season_vals 121 | 122 | 123 | class FullAttention(nn.Module): 124 | def __init__(self, 125 | n_embd, # the embed dim 126 | n_head, # the number of heads 127 | attn_pdrop=0.1, # attention dropout prob 128 | resid_pdrop=0.1, # residual attention dropout prob 129 | ): 130 | super().__init__() 131 | assert n_embd % n_head == 0 132 | # key, query, value projections for all heads 133 | self.key = nn.Linear(n_embd, n_embd) 134 | self.query = nn.Linear(n_embd, n_embd) 135 | self.value = nn.Linear(n_embd, n_embd) 136 | 137 | # regularization 138 | self.attn_drop = nn.Dropout(attn_pdrop) 139 | self.resid_drop = nn.Dropout(resid_pdrop) 140 | # output projection 141 | self.proj = nn.Linear(n_embd, n_embd) 142 | self.n_head = n_head 143 | 144 | def forward(self, x, mask=None): 145 | B, T, C = x.size() 146 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 147 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 148 | v = self.value(x).view(B, T, self.n_head, C // 
self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T)

        att = F.softmax(att, dim=-1)  # (B, nh, T, T)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side, (B, T, C)
        att = att.mean(dim=1, keepdim=False)  # head-averaged attention map, (B, T, T)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, att


class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from x, keys/values from encoder_output."""
    def __init__(self,
                 n_embd,          # the embed dim
                 condition_embd,  # condition dim
                 n_head,          # the number of heads
                 attn_pdrop=0.1,   # attention dropout prob
                 resid_pdrop=0.1,  # residual attention dropout prob
                 ):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(condition_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(condition_embd, n_embd)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, encoder_output, mask=None):
        # NOTE(review): `mask` is accepted but never applied here.
        B, T, C = x.size()
        B, T_E, _ = encoder_output.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(encoder_output).view(B, T_E, self.n_head, C // self.n_head).transpose(1, 2)    # (B, nh, T_E, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)                 # (B, nh, T, hs)
        v = self.value(encoder_output).view(B, T_E, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T_E, hs)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T_E)

        att = F.softmax(att, dim=-1)  # (B, nh, T, T_E)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T_E) x (B, nh, T_E, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side, (B, T, C)
        att = att.mean(dim=1, keepdim=False)  # (B, T, T_E)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, att


class EncoderBlock(nn.Module):
    """ an unassuming Transformer block """
    def __init__(self,
                 n_embd=1024,
                 n_head=16,
                 attn_pdrop=0.1,
                 resid_pdrop=0.1,
                 mlp_hidden_times=4,
                 activate='GELU'
                 ):
        super().__init__()

        # AdaLayerNorm conditions the normalization on the diffusion timestep (and label).
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )

        assert activate in ['GELU', 'GELU2']
        act = nn.GELU() if activate == 'GELU' else GELU2()

        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, timestep, mask=None, label_emb=None):
        # Pre-norm residual attention followed by a pre-norm residual MLP.
        a, att = self.attn(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + a
        x = x + self.mlp(self.ln2(x))
        return x, att


class Encoder(nn.Module):
    """Stack of timestep-conditioned self-attention EncoderBlocks."""
    def __init__(
        self,
        n_layer=14,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.,
        resid_pdrop=0.,
        mlp_hidden_times=4,
        block_activate='GELU',
    ):
        super().__init__()

        self.blocks = nn.Sequential(*[EncoderBlock(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
            mlp_hidden_times=mlp_hidden_times,
            activate=block_activate,
        ) for _ in range(n_layer)])

    def forward(self, input, t, padding_masks=None, label_emb=None):
        x = input
        # Blocks are invoked manually (not through nn.Sequential.__call__) because each
        # takes extra arguments; per-block attention maps are discarded.
        for block_idx in range(len(self.blocks)):
            x, _ = self.blocks[block_idx](x, t, mask=padding_masks, label_emb=label_emb)
        return x


class DecoderBlock(nn.Module):
    """ an unassuming Transformer block """
    def __init__(self,
                 n_channel,
                 n_feat,
                 n_embd=1024,
                 n_head=16,
                 attn_pdrop=0.1,
                 resid_pdrop=0.1,
                 mlp_hidden_times=4,
                 activate='GELU',
                 condition_dim=1024,
                 ):
        super().__init__()

        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        self.attn1 = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        self.attn2 = CrossAttention(
            n_embd=n_embd,
            condition_embd=condition_dim,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )

        self.ln1_1 = AdaLayerNorm(n_embd)

        assert activate in ['GELU', 'GELU2']
        act = nn.GELU() if activate == 'GELU' else GELU2()

        # Interpretable decomposition heads: polynomial trend + Fourier seasonality.
        # The commented alternatives are the moving-average / Fourier-series variants.
        self.trend = TrendBlock(n_channel, n_channel, n_embd, n_feat, act=act)
        # self.decomp = MovingBlock(n_channel)
        self.seasonal = FourierLayer(d_model=n_embd)
        # self.seasonal = SeasonBlock(n_channel, n_channel)

        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

        # 1x1 conv doubles the channel dim; the halves feed the trend/season heads.
        self.proj = nn.Conv1d(n_channel, n_channel * 2, 1)
        self.linear = nn.Linear(n_embd, n_feat)

    def forward(self, x, encoder_output, timestep, mask=None, label_emb=None):
        a, att = self.attn1(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + a
        a, att = self.attn2(self.ln1_1(x, timestep), encoder_output, mask=mask)
        x = x + a
        x1, x2 = self.proj(x).chunk(2, dim=1)
        trend, season = self.trend(x1),
self.seasonal(x2)
        x = x + self.mlp(self.ln2(x))
        # Separate the per-channel mean: the zero-mean residual continues through the
        # stack, while the mean contributes to the trend estimate.
        m = torch.mean(x, dim=1, keepdim=True)
        return x - m, self.linear(m), trend, season


class Decoder(nn.Module):
    """Stack of DecoderBlocks; accumulates season/trend residuals and collects per-block means."""
    def __init__(
        self,
        n_channel,
        n_feat,
        n_embd=1024,
        n_head=16,
        n_layer=10,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate='GELU',
        condition_dim=512
    ):
        super().__init__()
        self.d_model = n_embd
        self.n_feat = n_feat
        self.blocks = nn.Sequential(*[DecoderBlock(
            n_feat=n_feat,
            n_channel=n_channel,
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
            mlp_hidden_times=mlp_hidden_times,
            activate=block_activate,
            condition_dim=condition_dim,
        ) for _ in range(n_layer)])

    def forward(self, x, t, enc, padding_masks=None, label_emb=None):
        b, c, _ = x.shape
        mean = []
        # Season/trend components are summed over all decoder blocks.
        season = torch.zeros((b, c, self.d_model), device=x.device)
        trend = torch.zeros((b, c, self.n_feat), device=x.device)
        for block_idx in range(len(self.blocks)):
            x, residual_mean, residual_trend, residual_season = \
                self.blocks[block_idx](x, enc, t, mask=padding_masks, label_emb=label_emb)
            season += residual_season
            trend += residual_trend
            mean.append(residual_mean)

        mean = torch.cat(mean, dim=1)  # (b, n_layer, n_feat): one mean per decoder block
        return x, mean, trend, season


class Transformer(nn.Module):
    """
    Encoder-decoder transformer for Diffusion-TS. The decoder output is decomposed into
    an interpretable trend plus a seasonality-with-residual ("season_error") component.
    """
    def __init__(
        self,
        n_feat,
        n_channel,
        n_layer_enc=5,
        n_layer_dec=14,
        n_embd=1024,
        n_heads=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate='GELU',
        max_len=2048,
        conv_params=None,
        **kwargs
    ):
        super().__init__()
        self.emb = Conv_MLP(n_feat, n_embd, resid_pdrop=resid_pdrop)
        self.inverse = Conv_MLP(n_embd, n_feat, resid_pdrop=resid_pdrop)

        # Small feature/channel counts get a 1x1 combine conv, larger ones a 5-wide one,
        # unless explicit (kernel_size, padding) conv_params are supplied.
        if conv_params is None or conv_params[0] is None:
            if n_feat < 32 and n_channel < 64:
                kernel_size, padding = 1, 0
            else:
                kernel_size, padding = 5, 2
        else:
            kernel_size, padding = conv_params

        # combine_s merges the accumulated seasonality back to feature space;
        # combine_m merges the per-block means into a single trend offset.
        self.combine_s = nn.Conv1d(n_embd, n_feat, kernel_size=kernel_size, stride=1, padding=padding,
                                   padding_mode='circular', bias=False)
        self.combine_m = nn.Conv1d(n_layer_dec, 1, kernel_size=1, stride=1, padding=0,
                                   padding_mode='circular', bias=False)

        self.encoder = Encoder(n_layer_enc, n_embd, n_heads, attn_pdrop, resid_pdrop, mlp_hidden_times, block_activate)
        self.pos_enc = LearnablePositionalEncoding(n_embd, dropout=resid_pdrop, max_len=max_len)

        self.decoder = Decoder(n_channel, n_feat, n_embd, n_heads, n_layer_dec, attn_pdrop, resid_pdrop, mlp_hidden_times,
                               block_activate, condition_dim=n_embd)
        self.pos_dec = LearnablePositionalEncoding(n_embd, dropout=resid_pdrop, max_len=max_len)

    def forward(self, input, t, padding_masks=None, return_res=False):
        emb = self.emb(input)
        inp_enc = self.pos_enc(emb)
        enc_cond = self.encoder(inp_enc, t, padding_masks=padding_masks)

        inp_dec = self.pos_dec(emb)
        output, mean, trend, season = self.decoder(inp_dec, t, enc_cond, padding_masks=padding_masks)

        # Final synthesis: trend = combined block means + residual mean + polynomial trend;
        # season_error = combined seasonality + zero-mean residual.
        res = self.inverse(output)
        res_m = torch.mean(res, dim=1, keepdim=True)
        season_error = self.combine_s(season.transpose(1, 2)).transpose(1, 2) + res - res_m
        trend = self.combine_m(mean) + res_m + trend

        if return_res:
            # Return the interpretable components separately.
            return trend, self.combine_s(season.transpose(1, 2)).transpose(1, 2), res - res_m

        return trend, season_error


if __name__ == '__main__':
    pass
# --- Models/ts2vec/models/dilated_conv.py ---
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

class SamePadConv(nn.Module):
    """1-D convolution that preserves the input length ('same' padding) for any dilation."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
        super().__init__()
        self.receptive_field = (kernel_size - 1) * dilation + 1
        half_pad = self.receptive_field // 2
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            padding=half_pad,
            dilation=dilation,
            groups=groups
        )
        # An even receptive field pads one step too far on the right; trim it in forward().
        self.remove = 1 if self.receptive_field % 2 == 0 else 0

    def forward(self, x):
        out = self.conv(x)
        return out if self.remove == 0 else out[:, :, :-self.remove]

class ConvBlock(nn.Module):
    """Two dilated same-pad convolutions wrapped by a (possibly projected) residual skip."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False):
        super().__init__()
        self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation)
        self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation)
        # A 1x1 projection is needed when channel counts differ (or on the final block).
        needs_projection = in_channels != out_channels or final
        self.projector = nn.Conv1d(in_channels, out_channels, 1) if needs_projection else None

    def forward(self, x):
        skip = x if self.projector is None else self.projector(x)
        out = self.conv1(F.gelu(x))
        out = self.conv2(F.gelu(out))
        return out + skip

class DilatedConvEncoder(nn.Module):
    """Stack of ConvBlocks whose dilation grows exponentially (2**i per layer)."""
    def __init__(self, in_channels, channels, kernel_size):
        super().__init__()
        blocks = []
        prev_ch = in_channels
        last = len(channels) - 1
        for i, ch in enumerate(channels):
            blocks.append(ConvBlock(
                prev_ch,
                ch,
                kernel_size=kernel_size,
                dilation=2 ** i,
                final=(i == last)
            ))
            prev_ch = ch
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)
# --- Models/ts2vec/models/encoder.py ---
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from .dilated_conv import DilatedConvEncoder

def generate_continuous_mask(B, T, n=5, l=0.1):
    """Boolean (B, T) mask with `n` random contiguous spans of length `l` set False per row.

    `n` and `l` given as floats are interpreted as fractions of T.
    """
    res = torch.full((B, T), True, dtype=torch.bool)
    if isinstance(n, float):
        n = int(n * T)
    n = max(min(n, T // 2), 1)

    if isinstance(l, float):
        l = int(l * T)
    l = max(l, 1)

    for i in range(B):
        for _ in range(n):
            t = np.random.randint(T - l + 1)
            res[i, t:t + l] = False
    return res

def generate_binomial_mask(B, T, p=0.5):
    """Boolean (B, T) keep-mask where each position is independently True with probability p."""
    return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool)

class TSEncoder(nn.Module):
    """Dilated-convolution time-series encoder with random timestamp masking."""
    def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.mask_mode = mask_mode
        self.input_fc = nn.Linear(input_dims, hidden_dims)
        self.feature_extractor = DilatedConvEncoder(
            hidden_dims,
            [hidden_dims] * depth + [output_dims],
            kernel_size=3
        )
        self.repr_dropout = nn.Dropout(p=0.1)

    def forward(self, x, mask=None):  # x: B x T x input_dims
        # Timestamps where any feature is NaN are treated as missing.
        nan_mask = ~x.isnan().any(axis=-1)
        # NOTE(review): this zeroing mutates the caller's tensor in place.
        x[~nan_mask] = 0
        x = self.input_fc(x)  # B x T x Ch

        # generate & apply mask (True = keep); training defaults to self.mask_mode,
        # evaluation keeps every timestamp.
        if mask is None:
            if self.training:
                mask = self.mask_mode
            else:
                mask = 'all_true'

        if mask == 'binomial':
            mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device)
        elif mask == 'continuous':
            mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device)
        elif mask == 'all_true':
            mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
        elif mask == 'all_false':
            mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool)
        elif mask == 'mask_last':
            mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool)
            mask[:, -1] = False

        mask &= nan_mask  # never keep originally-missing positions
        x[~mask] = 0

        # conv encoder
        x = x.transpose(1, 2)  # B x Ch x T
        x = self.repr_dropout(self.feature_extractor(x))  # B x Co x T
        x = x.transpose(1, 2)  # B x T x Co

        return x
# --- Models/ts2vec/models/losses.py ---
import torch
from torch import nn
import torch.nn.functional as F


def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0):
    """Multi-scale contrastive loss: the time axis is halved by max-pooling at each
    level, and instance/temporal contrastive terms are averaged over all levels."""
    loss = torch.tensor(0., device=z1.device)
    d = 0
    while z1.size(1) > 1:
        if alpha != 0:
            loss += alpha * instance_contrastive_loss(z1, z2)
        if d >= temporal_unit:
            # Temporal contrast only kicks in once the scale exceeds temporal_unit.
            if 1 - alpha != 0:
                loss += (1 - alpha) * temporal_contrastive_loss(z1, z2)
        d += 1
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)
    if z1.size(1) == 1:
        if alpha != 0:
            loss += alpha * instance_contrastive_loss(z1, z2)
        d += 1
    return loss / d

def instance_contrastive_loss(z1, z2):
    """Contrast each instance against the other 2B-1 instances at every timestamp."""
    B, T = z1.size(0), z1.size(1)
    if B == 1:
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=0)  # 2B x T x C
    z = z.transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
    # Remove the self-similarity diagonal by stitching the strict lower and upper triangles.
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # T x 2B x (2B-1)
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)

    i = torch.arange(B, device=z1.device)
    # Positive pairs are (i, B+i): the two augmented views of the same instance.
    loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
    return loss

def temporal_contrastive_loss(z1, z2):
    """Contrast each timestamp against the other 2T-1 timestamps within an instance."""
    B, T = z1.size(0), z1.size(1)
    if T == 1:
        return z1.new_tensor(0.)
    z = torch.cat([z1, z2], dim=1)  # B x 2T x C
    sim = torch.matmul(z, z.transpose(1, 2))  # B x 2T x 2T
    # Remove the self-similarity diagonal by stitching the strict lower and upper triangles.
    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # B x 2T x (2T-1)
    logits += torch.triu(sim, diagonal=1)[:, :, 1:]
    logits = -F.log_softmax(logits, dim=-1)

    t = torch.arange(T, device=z1.device)
    # Positive pairs are (t, T+t): the same timestamp in the two views.
    loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2
    return loss
# --- Models/ts2vec/ts2vec.py ---
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from Models.ts2vec.models.encoder import TSEncoder
from Models.ts2vec.models.losses import hierarchical_contrastive_loss
from Models.ts2vec.utils import take_per_row, split_with_nan, centerize_vary_length_series, torch_pad_nan


class TS2Vec:
    '''The TS2Vec model'''

    def __init__(
        self,
        input_dims,
        output_dims=320,
        hidden_dims=64,
        depth=10,
        device='cuda',
        lr=0.001,
        batch_size=16,
        max_train_length=None,
        temporal_unit=0,
        after_iter_callback=None,
        after_epoch_callback=None
    ):
        ''' Initialize a TS2Vec model.
        Args:
            input_dims (int): The input dimension. For a univariate time series, this should be set to 1.
            output_dims (int): The representation dimension.
            hidden_dims (int): The hidden dimension of the encoder.
            depth (int): The number of hidden residual blocks in the encoder.
            device (int): The gpu used for training and inference.
            lr (int): The learning rate.
            batch_size (int): The batch size.
            max_train_length (Union[int, NoneType]): The maximum allowed sequence length for training. A longer sequence is cropped into sub-sequences, each no longer than this value.
            temporal_unit (int): The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory.
            after_iter_callback (Union[Callable, NoneType]): A callback function that would be called after each iteration.
            after_epoch_callback (Union[Callable, NoneType]): A callback function that would be called after each epoch.
        '''

        super().__init__()
        self.device = device
        self.lr = lr
        self.batch_size = batch_size
        self.max_train_length = max_train_length
        self.temporal_unit = temporal_unit

        self._net = TSEncoder(input_dims=input_dims, output_dims=output_dims, hidden_dims=hidden_dims, depth=depth).to(self.device)
        # SWA-style weight averaging: `self.net` tracks the running average of `self._net`
        # and is the network used for inference.
        self.net = torch.optim.swa_utils.AveragedModel(self._net)
        self.net.update_parameters(self._net)

        self.after_iter_callback = after_iter_callback
        self.after_epoch_callback = after_epoch_callback

        self.n_epochs = 0
        self.n_iters = 0

    def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False):
        ''' Training the TS2Vec model.

        Args:
            train_data (numpy.ndarray): The training data. It should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
            n_epochs (Union[int, NoneType]): The number of epochs. When this reaches, the training stops.
            n_iters (Union[int, NoneType]): The number of iterations. When this reaches, the training stops. If both n_epochs and n_iters are not specified, a default setting would be used that sets n_iters to 200 for a dataset with size <= 100000, 600 otherwise.
            verbose (bool): Whether to print the training loss after each epoch.

        Returns:
            loss_log: a list containing the training losses on each epoch.
        '''
        assert train_data.ndim == 3

        if n_iters is None and n_epochs is None:
            n_iters = 200 if train_data.size <= 100000 else 600  # default param for n_iters

        # Crop over-long sequences into max_train_length pieces (stacked as new instances).
        if self.max_train_length is not None:
            sections = train_data.shape[1] // self.max_train_length
            if sections >= 2:
                train_data = np.concatenate(split_with_nan(train_data, sections, axis=1), axis=0)

        # If a fully-missing timestamp sits at either end, re-center the series.
        temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0)
        if temporal_missing[0] or temporal_missing[-1]:
            train_data = centerize_vary_length_series(train_data)

        # Drop instances that are entirely NaN.
        train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)]

        train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float))
        train_loader = DataLoader(train_dataset, batch_size=min(self.batch_size, len(train_dataset)), shuffle=True, drop_last=True)

        optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr)

        loss_log = []

        while True:
            if n_epochs is not None and self.n_epochs >= n_epochs:
                break

            cum_loss = 0
            n_epoch_iters = 0

            interrupted = False
            for batch in train_loader:
                if n_iters is not None and self.n_iters >= n_iters:
                    interrupted = True
                    break

                x = batch[0]
                # Random window crop when the batch is still longer than max_train_length.
                if self.max_train_length is not None and x.size(1) > self.max_train_length:
                    window_offset = np.random.randint(x.size(1) - self.max_train_length + 1)
                    x = x[:, window_offset : window_offset + self.max_train_length]
                x = x.to(self.device)

                # Sample two overlapping crops; their intersection (length crop_l) is the
                # region on which the two views are contrasted.
                ts_l = x.size(1)
                crop_l = np.random.randint(low=2 ** (self.temporal_unit + 1), high=ts_l + 1)
                crop_left = np.random.randint(ts_l - crop_l + 1)
                crop_right = crop_left + crop_l
                crop_eleft = np.random.randint(crop_left + 1)
                crop_eright = np.random.randint(low=crop_right, high=ts_l + 1)
                crop_offset = np.random.randint(low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0))

                optimizer.zero_grad()

                out1 = self._net(take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft))
                out1 = out1[:, -crop_l:]  # right-align view 1 onto the overlap

                out2 = self._net(take_per_row(x, crop_offset + crop_left, crop_eright - crop_left))
                out2 = out2[:, :crop_l]   # left-align view 2 onto the overlap

                loss = hierarchical_contrastive_loss(
                    out1,
                    out2,
                    temporal_unit=self.temporal_unit
                )

                loss.backward()
                optimizer.step()
                # Fold the freshly-updated weights into the averaged inference model.
                self.net.update_parameters(self._net)

                cum_loss += loss.item()
                n_epoch_iters += 1

                self.n_iters += 1

                if self.after_iter_callback is not None:
                    self.after_iter_callback(self, loss.item())

            if interrupted:
                break

            # NOTE(review): if the iteration budget is exhausted before the first batch of
            # an epoch, n_epoch_iters is 0 and this raises ZeroDivisionError — confirm.
            cum_loss /= n_epoch_iters
            loss_log.append(cum_loss)
            if verbose:
                print(f"Epoch #{self.n_epochs}: loss={cum_loss}")
            self.n_epochs += 1

            if self.after_epoch_callback is not None:
                self.after_epoch_callback(self, cum_loss)

        return loss_log

    def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
        # Inference uses the weight-averaged network (`self.net`), not the raw one.
        out = self.net(x.to(self.device, non_blocking=True), mask)
        if encoding_window == 'full_series':
            if slicing is not None:
                out = out[:, slicing]
            # Max-pool over the whole time axis -> one representation per series.
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size=out.size(1),
            ).transpose(1, 2)

        elif isinstance(encoding_window, int):
            out = F.max_pool1d(
out.transpose(1, 2),
                kernel_size=encoding_window,
                stride=1,
                padding=encoding_window // 2
            ).transpose(1, 2)
            if encoding_window % 2 == 0:
                # An even window pads one extra step on the right; drop it.
                out = out[:, :-1]
            if slicing is not None:
                out = out[:, slicing]

        elif encoding_window == 'multiscale':
            # Concatenate max-pooled features at window sizes 3, 5, 9, ... (2**(p+1)+1).
            p = 0
            reprs = []
            while (1 << p) + 1 < out.size(1):
                t_out = F.max_pool1d(
                    out.transpose(1, 2),
                    kernel_size=(1 << (p + 1)) + 1,
                    stride=1,
                    padding=1 << p
                ).transpose(1, 2)
                if slicing is not None:
                    t_out = t_out[:, slicing]
                reprs.append(t_out)
                p += 1
            out = torch.cat(reprs, dim=-1)

        else:
            if slicing is not None:
                out = out[:, slicing]

        return out.cpu()

    def encode(self, data, mask=None, encoding_window=None, casual=False, sliding_length=None, sliding_padding=0, batch_size=None):
        ''' Compute representations using the model.

        Args:
            data (numpy.ndarray): This should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
            mask (str): The mask used by encoder can be specified with this parameter. This can be set to 'binomial', 'continuous', 'all_true', 'all_false' or 'mask_last'.
            encoding_window (Union[str, int]): When this param is specified, the computed representation would be the max pooling over this window. This can be set to 'full_series', 'multiscale' or an integer specifying the pooling kernel size.
            casual (bool): When this param is set to True, future information would not be encoded into the representation of each timestamp. (NOTE(review): presumably a typo for "causal"; kept for API compatibility.)
            sliding_length (Union[int, NoneType]): The length of sliding window. When this param is specified, a sliding inference would be applied on the time series.
            sliding_padding (int): This param specifies the contextual data length used for inference every sliding windows.
            batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, this would be the same batch size as training.

        Returns:
            repr: The representations for data.
        '''
        assert self.net is not None, 'please train or load a net first'
        assert data.ndim == 3
        if batch_size is None:
            batch_size = self.batch_size
        n_samples, ts_l, _ = data.shape

        org_training = self.net.training
        self.net.eval()

        dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
        loader = DataLoader(dataset, batch_size=batch_size)

        with torch.no_grad():
            output = []
            for batch in loader:
                x = batch[0]
                if sliding_length is not None:
                    reprs = []
                    # For small datasets, buffer several NaN-padded windows and run them
                    # through the network together to fill up the batch.
                    if n_samples < batch_size:
                        calc_buffer = []
                        calc_buffer_l = 0
                    for i in range(0, ts_l, sliding_length):
                        l = i - sliding_padding
                        r = i + sliding_length + (sliding_padding if not casual else 0)
                        # NaN-pad the window where it extends past either series end.
                        x_sliding = torch_pad_nan(
                            x[:, max(l, 0) : min(r, ts_l)],
                            left=-l if l < 0 else 0,
                            right=r - ts_l if r > ts_l else 0,
                            dim=1
                        )
                        if n_samples < batch_size:
                            if calc_buffer_l + n_samples > batch_size:
                                out = self._eval_with_pooling(
                                    torch.cat(calc_buffer, dim=0),
                                    mask,
                                    slicing=slice(sliding_padding, sliding_padding + sliding_length),
                                    encoding_window=encoding_window
                                )
                                reprs += torch.split(out, n_samples)
                                calc_buffer = []
                                calc_buffer_l = 0
                            calc_buffer.append(x_sliding)
                            calc_buffer_l += n_samples
                        else:
                            out = self._eval_with_pooling(
                                x_sliding,
                                mask,
                                slicing=slice(sliding_padding, sliding_padding + sliding_length),
                                encoding_window=encoding_window
                            )
                            reprs.append(out)

                    # Flush any windows still left in the buffer.
                    if n_samples < batch_size:
                        if calc_buffer_l > 0:
                            out = self._eval_with_pooling(
                                torch.cat(calc_buffer, dim=0),
                                mask,
                                slicing=slice(sliding_padding, sliding_padding + sliding_length),
                                encoding_window=encoding_window
                            )
                            reprs += torch.split(out, n_samples)
                            calc_buffer = []
                            calc_buffer_l = 0

                    out = torch.cat(reprs, dim=1)
                    if encoding_window == 'full_series':
                        # Pool the stitched sliding-window representations over time.
                        out = F.max_pool1d(
                            out.transpose(1, 2).contiguous(),
                            kernel_size=out.size(1),
                        ).squeeze(1)
                else:
                    out = self._eval_with_pooling(x, mask, encoding_window=encoding_window)
                    if encoding_window == 'full_series':
                        out = out.squeeze(1)

                output.append(out)

            output = torch.cat(output, dim=0)

        self.net.train(org_training)
        return output.numpy()

    def save(self, fn):
        ''' Save the model to a file.

        Args:
            fn (str): filename.
        '''
        # Saves the averaged network's weights (the one used for inference).
        torch.save(self.net.state_dict(), fn)

    def load(self, fn):
        ''' Load the model from a file.

        Args:
            fn (str): filename.
        '''
        state_dict = torch.load(fn, map_location=self.device)
        self.net.load_state_dict(state_dict)

# --- Models/ts2vec/utils.py ---
import os
import numpy as np
import pickle
import torch
import random
from datetime import datetime

def pkl_save(name, var):
    """Pickle `var` to the file `name`."""
    with open(name, 'wb') as f:
        pickle.dump(var, f)

def pkl_load(name):
    """Load and return a pickled object from the file `name`."""
    with open(name, 'rb') as f:
        return pickle.load(f)

def torch_pad_nan(arr, left=0, right=0, dim=0):
    """Pad tensor `arr` with NaNs on the left/right ends along dimension `dim`."""
    if left > 0:
        padshape = list(arr.shape)
        padshape[dim] = left
        arr = torch.cat((torch.full(padshape, np.nan), arr), dim=dim)
    if right > 0:
        padshape = list(arr.shape)
        padshape[dim] = right
        arr = torch.cat((arr, torch.full(padshape, np.nan)), dim=dim)
    return arr
# (pad_nan_to_target and the remaining utils continue in the next chunk)
# --- Models/ts2vec/utils.py (continued) ---
def pad_nan_to_target(array, target_length, axis=0, both_side=False):
    """Pad `array` with NaNs along `axis` until it reaches `target_length`.

    Args:
        array (np.ndarray): float array to pad.
        target_length (int): desired size along `axis`.
        axis (int): axis to pad along.
        both_side (bool): split the padding evenly between both ends.

    Returns:
        np.ndarray: the padded array (the input itself if already long enough).
    """
    assert array.dtype in [np.float16, np.float32, np.float64]
    pad_size = target_length - array.shape[axis]
    if pad_size <= 0:
        return array
    npad = [(0, 0)] * array.ndim
    if both_side:
        npad[axis] = (pad_size // 2, pad_size - pad_size // 2)
    else:
        npad[axis] = (0, pad_size)
    return np.pad(array, pad_width=npad, mode='constant', constant_values=np.nan)

def split_with_nan(x, sections, axis=0):
    """Split `x` into `sections` chunks along `axis`, NaN-padding the shorter chunks
    so every chunk matches the first chunk's length."""
    assert x.dtype in [np.float16, np.float32, np.float64]
    arrs = np.array_split(x, sections, axis=axis)
    target_length = arrs[0].shape[axis]
    for i in range(len(arrs)):
        arrs[i] = pad_nan_to_target(arrs[i], target_length, axis=axis)
    return arrs

def take_per_row(A, indx, num_elem):
    """Gather `num_elem` consecutive columns from each row of `A`, starting at the
    per-row offsets `indx` (negative offsets wrap, as with tensor indexing)."""
    all_indx = indx[:, None] + np.arange(num_elem)
    return A[torch.arange(all_indx.shape[0])[:, None], all_indx]

def centerize_vary_length_series(x):
    """Shift each series so its non-NaN segment is centered along the time axis."""
    prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1)
    suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1)
    offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros
    rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]]
    # Negative shifts are realized as wrap-around column re-indexing.
    offset[offset < 0] += x.shape[1]
    column_indices = column_indices - offset[:, np.newaxis]
    return x[rows, column_indices]

def data_dropout(arr, p):
    """Randomly set a fraction `p` of (instance, timestamp) positions of `arr` to NaN.

    Returns a copy; the input array is left untouched.
    """
    B, T = arr.shape[0], arr.shape[1]
    # FIX: `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # `bool` is the equivalent dtype and works on all NumPy versions.
    mask = np.full(B * T, False, dtype=bool)
    ele_sel = np.random.choice(
        B * T,
        size=int(B * T * p),
        replace=False
    )
    mask[ele_sel] = True
    res = arr.copy()
    res[mask.reshape(B, T)] = np.nan
    return res

def name_with_datetime(prefix='default'):
    """Return `prefix` suffixed with the current timestamp (YYYYmmdd_HHMMSS)."""
    now = datetime.now()
    return prefix + '_' + now.strftime("%Y%m%d_%H%M%S")

def init_dl_program(
    device_name,
    seed=None,
    use_cudnn=True,
    deterministic=False,
    benchmark=False,
    use_tf32=False,
    max_threads=None
):
    """Initialize devices, RNG seeds and backend flags for a deep-learning program.

    Args:
        device_name: a device spec (or list of specs) accepted by torch.device.
        seed: base RNG seed; each RNG (python/numpy/torch/cuda) gets a distinct
            derived seed so their streams differ.
        use_cudnn / deterministic / benchmark / use_tf32: cuDNN & TF32 backend flags.
        max_threads: cap on intra-op and inter-op CPU threads.

    Returns:
        A single torch.device, or a list when several device names were given.
    """
    import torch
    if max_threads is not None:
        torch.set_num_threads(max_threads)  # intraop
        if torch.get_num_interop_threads() != max_threads:
            torch.set_num_interop_threads(max_threads)  # interop
        try:
            import mkl
        except ImportError:
            # mkl is optional; silently skip when it is not installed.
            pass
        else:
            mkl.set_num_threads(max_threads)

    if seed is not None:
        random.seed(seed)
        seed += 1
        np.random.seed(seed)
        seed += 1
        torch.manual_seed(seed)

    if isinstance(device_name, (str, int)):
        device_name = [device_name]

    devices = []
    for t in reversed(device_name):
        t_device = torch.device(t)
        devices.append(t_device)
        if t_device.type == 'cuda':
            assert torch.cuda.is_available()
            torch.cuda.set_device(t_device)
            if seed is not None:
                seed += 1
                torch.cuda.manual_seed(seed)
    devices.reverse()
    torch.backends.cudnn.enabled = use_cudnn
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = benchmark

    if hasattr(torch.backends.cudnn, 'allow_tf32'):
        torch.backends.cudnn.allow_tf32 = use_tf32
        torch.backends.cuda.matmul.allow_tf32 = use_tf32

    return devices if len(devices) > 1 else devices[0]
# --- README.md content follows in the original file ---
probabilistic models (DDPMs) are becoming the leading paradigm for generative models. They have recently shown breakthroughs in audio synthesis, time series imputation and forecasting. In this paper, we propose Diffusion-TS, a novel diffusion-based framework that generates multivariate time series samples of high quality by using an encoder-decoder transformer with disentangled temporal representations, in which the decomposition technique guides Diffusion-TS to capture the semantic meaning of time series while transformers mine detailed sequential information from the noisy model input. Different from existing diffusion-based approaches, we train the model to directly reconstruct the sample instead of the noise in each diffusion step, combining a Fourier-based loss term. Diffusion-TS is expected to generate time series satisfying both interpretability and realness. In addition, it is shown that the proposed Diffusion-TS can be easily extended to conditional generation tasks, such as forecasting and imputation, without any model changes. This also motivates us to further explore the performance of Diffusion-TS under irregular settings. Finally, through qualitative and quantitative experiments, results show that Diffusion-TS achieves state-of-the-art results on various realistic analyses of time series. 10 | 11 | Diffusion-TS is a diffusion-based framework that generates general time series samples both conditionally and unconditionally. As shown in Figure 1, the framework contains two parts: a sequence encoder and an interpretable decoder which decomposes the time series into seasonal part and trend part. The trend part contains the polynomial regressor and extracted mean of each block output. For seasonal part, we reuse trigonometric representations based on Fourier series. Regarding training, sampling and more details, please refer to [our paper](https://openreview.net/pdf?id=4h1apFjO99) in ICLR 2024. 12 | 13 |

14 | 15 |
16 | Figure 1: Overall Architecture of Diffusion-TS. 17 |

18 | 19 | 🎤 **Update (2025/2/28)**: We have added an additional [experiment](https://github.com/Y-debug-sys/Diffusion-TS/blob/main/Experiments/eeg_multiple_classes.ipynb) for the EEG dataset and incorporated the Classifier Guidance mechanism. Now, Diffusion-TS supports multi-class generation. 20 | 21 | ## Dataset Preparation 22 | 23 | All the four real-world datasets (Stocks, ETTh1, Energy and fMRI) can be obtained from [Google Drive](https://drive.google.com/file/d/11DI22zKWtHjXMnNGPWNUbyGz-JiEtZy6/view?usp=sharing). Please download **dataset.zip**, then unzip and copy it to the folder `./Data` in our repository. The EEG dataset can be downloaded from [here](https://drive.google.com/file/d/1IqwE0wbCT1orVdZpul2xFiNkGnYs4t89/view?usp=sharing) and should also be placed in the aforementioned `./Data/dataset` folder. 24 | 25 | ## Running the Code 26 | 27 | The code requires conda3 (or miniconda3), and one CUDA capable GPU. The instructions below guide you through running the code in this repository. 28 | 29 | ### Environment & Libraries 30 | 31 | The full libraries list is provided as a `requirements.txt` in this repo. Please create a virtual environment with `conda` or `venv` and run 32 | 33 | ~~~bash 34 | (myenv) $ pip install -r requirements.txt 35 | ~~~ 36 | 37 | ### Training & Sampling 38 | 39 | For training, you can reproduce the experimental results of all benchmarks by running 40 | 41 | ~~~bash 42 | (myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --train 43 | ~~~ 44 | 45 | **Note:** We also provided the corresponding `.yml` files (only stocks, sines, mujoco, etth, energy and fmri) under the folder `./Config` where all possible options can be altered. You may need to change some parameters in the model for different scenarios. For example, we use the whole data to train model for unconditional evaluation, then *training_ratio* is set to 1 by default.
As for conditional generation, we need to divide the dataset, so it should be changed to a value < 1. 46 | 47 | While training, the script will save checkpoints to the *results* folder after a fixed number of epochs. Once trained, please use the saved model for sampling by running 48 | 49 | #### Unconstrained 50 | ```bash 51 | (myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --sample 0 --milestone {checkpoint_number} 52 | ``` 53 | 54 | #### Imputation 55 | ```bash 56 | (myenv) $ python main.py --name {name} --config_file {config.yaml} --gpu 0 --sample 1 --milestone {checkpoint_number} --mode infill --missing_ratio {missing_ratio} 57 | ``` 58 | 59 | #### Forecasting 60 | ```bash 61 | (myenv) $ python main.py --name {dataset_name} --config_file {config.yaml} --gpu 0 --sample 1 --milestone {checkpoint_number} --mode predict --pred_len {pred_len} 62 | ``` 63 | 64 | 65 | ## Visualization and Evaluation 66 | 67 | After sampling, synthetic data and original data are stored in `.npy` file format under the *output* folder, which can be directly read to calculate quantitative metrics such as discriminative, predictive, correlational and context-FID score. You can also reproduce the visualization results using t-SNE or kernel plotting, and all of this evaluation code can be found in the folder `./Utils`. Please refer to `.ipynb` tutorial files in this repo for more detailed implementations. 68 | 69 | **Note:** All the metrics can be found in the `./Experiments` folder. Additionally, by default, for datasets other than the Sine dataset (because it does not need normalization), their normalized forms are saved in `{...}_norm_truth.npy`. Therefore, when you run the Jupyter notebook for a dataset other than Sine, just uncomment and rewrite the corresponding code written at the beginning. 70 | 71 | ### Main Results 72 | 73 | #### Standard TS Generation 74 |

75 | Table 1: Results of 24-length Time-series Generation. 76 |
77 | 78 |

79 | 80 | #### Long-term TS Generation 81 |

82 | Table 2: Results of Long-term Time-series Generation. 83 |
84 | 85 |

86 | 87 | #### Conditional TS Generation 88 |

89 | 90 |
91 | Figure 2: Visualizations of Time-series Imputation and Forecasting. 92 |

93 | 94 | 95 | ## Authors 96 | 97 | * Paper Authors : Xinyu Yuan, Yan Qiao 98 | 99 | * Code Author : Xinyu Yuan 100 | 101 | * Contact : yxy5315@gmail.com 102 | 103 | 104 | ## Citation 105 | If you find this repo useful, please cite our paper via 106 | ```bibtex 107 | @inproceedings{yuan2024diffusionts, 108 | title={Diffusion-{TS}: Interpretable Diffusion for General Time Series Generation}, 109 | author={Xinyu Yuan and Yan Qiao}, 110 | booktitle={The Twelfth International Conference on Learning Representations}, 111 | year={2024}, 112 | url={https://openreview.net/forum?id=4h1apFjO99} 113 | } 114 | ``` 115 | 116 | 117 | ## Acknowledgement 118 | 119 | We appreciate the following github repos a lot for their valuable code base: 120 | 121 | https://github.com/lucidrains/denoising-diffusion-pytorch 122 | 123 | https://github.com/cientgu/VQ-Diffusion 124 | 125 | https://github.com/XiangLi1999/Diffusion-LM 126 | 127 | https://github.com/philipperemy/n-beats 128 | 129 | https://github.com/salesforce/ETSformer 130 | 131 | https://github.com/ermongroup/CSDI 132 | 133 | https://github.com/jsyoon0823/TimeGAN 134 | -------------------------------------------------------------------------------- /Utils/Data_utils/eeg_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from scipy.io import arff 7 | from scipy import stats 8 | from copy import deepcopy 9 | from torch.utils.data import Dataset 10 | from Utils.masking_utils import noise_mask 11 | from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one 12 | from sklearn.preprocessing import MinMaxScaler 13 | 14 | 15 | class EEGDataset(Dataset): 16 | def __init__( 17 | self, 18 | data_root, 19 | window=64, 20 | save2npy=True, 21 | neg_one_to_one=True, 22 | period='train', 23 | output_dir='./OUTPUT' 24 | ): 25 | super(EEGDataset, self).__init__() 26 | 
assert period in ['train', 'test'], 'period must be train or test.' 27 | 28 | self.auto_norm, self.save2npy = neg_one_to_one, save2npy 29 | self.data_0, self.data_1, self.scaler = self.read_data(data_root, window) 30 | self.labels = np.zeros(self.data_0.shape[0] + self.data_1.shape[0]).astype(np.int64) 31 | self.labels[self.data_0.shape[0]:] = 1 32 | self.rawdata = np.vstack([self.data_0, self.data_1]) 33 | self.dir = os.path.join(output_dir, 'samples') 34 | os.makedirs(self.dir, exist_ok=True) 35 | 36 | self.window, self.period = window, period 37 | self.len, self.var_num = self.rawdata.shape[0], self.rawdata.shape[-1] 38 | 39 | self.samples = self.normalize(self.rawdata) 40 | 41 | # np.save(os.path.join(self.dir, 'eeg_ground_0_truth.npy'), self.data_0) 42 | # np.save(os.path.join(self.dir, 'eeg_ground_1_truth.npy'), self.data_1) 43 | 44 | self.sample_num = self.samples.shape[0] 45 | 46 | def read_data(self, filepath, length): 47 | """ 48 | Reads the data from the given filepath, removes outliers, classifies the data into two classes, 49 | and scales the data using MinMaxScaler. 50 | 51 | Args: 52 | filepath (str): Path to the .arff file containing the EEG data. 53 | length (int): Length of the window for classification. 
54 | """ 55 | data = arff.loadarff(filepath) 56 | df = pd.DataFrame(data[0]) 57 | df['eyeDetection'] = df['eyeDetection'].astype('int') 58 | 59 | df = self.__OutlierRemoval__(df) 60 | df_0, df_1 = self.__Classify__(df, length=length) 61 | # df_0.to_csv('./EEG_Eye_State_0.csv', index=False) 62 | # df_1.to_csv('./EEG_Eye_State_1.csv', index=False) 63 | 64 | data_0 = df_0.values.reshape(df_0.shape[0], length, -1) 65 | data_1 = df_1.values.reshape(df_1.shape[0], length, -1) 66 | 67 | # print(f"Class 0: {data_0.shape}, Class 1: {data_1.shape}") 68 | 69 | data = np.vstack([data_0.reshape(-1, data_0.shape[-1]), data_1.reshape(-1, data_1.shape[-1])]) 70 | 71 | scaler = MinMaxScaler() 72 | scaler = scaler.fit(data) 73 | 74 | return data_0, data_1, scaler 75 | 76 | @staticmethod 77 | def __OutlierRemoval__(df): 78 | """ 79 | Removes outliers from the dataframe using z-score method and interpolates the missing values. 80 | 81 | Args: 82 | df (pd.DataFrame): Dataframe containing the EEG data. 83 | 84 | Returns: 85 | pd.DataFrame: Cleaned dataframe with outliers removed and missing values interpolated. 
86 | """ 87 | temp_data_frame = deepcopy(df) 88 | clean_data_frame = deepcopy(df) 89 | for column in temp_data_frame.columns[:-1]: 90 | temp_data_frame[str(column)+'z_score'] = stats.zscore(temp_data_frame[column]) 91 | clean_data_frame[column] = temp_data_frame.loc[temp_data_frame[str(column)+'z_score'].abs()<=3][column] 92 | 93 | clean_data_frame.interpolate(method='linear', inplace=True) 94 | 95 | temp_data_frame = deepcopy(clean_data_frame) 96 | clean_data_frame_second = deepcopy(clean_data_frame) 97 | 98 | for column in temp_data_frame.columns[:-1]: 99 | temp_data_frame[str(column)+'z_score'] = stats.zscore(temp_data_frame[column]) 100 | clean_data_frame_second[column] = temp_data_frame.loc[temp_data_frame[str(column)+'z_score'].abs()<=3][column] 101 | 102 | clean_data_frame_second.interpolate(method='linear', inplace=True) 103 | return clean_data_frame 104 | 105 | @staticmethod 106 | def __Classify__(df, length=100): 107 | """ 108 | Classifies the data into two classes based on the eyeDetection column and creates signals for the two classes. 109 | 110 | Args: 111 | df (pd.DataFrame): Dataframe containing the EEG data. 112 | length (int): Length of the window for classification. 113 | 114 | Returns: 115 | pd.DataFrame: Dataframe containing the signals for class 0. 116 | pd.DataFrame: Dataframe containing the signals for class 1. 
117 | """ 118 | # normalize the columns between -1 and 1 119 | mean, max, min = df.mean(), df.max(), df.min() 120 | df = 2*(df - mean) / (max - min) 121 | 122 | df['edge'] = df['eyeDetection'].diff() 123 | df['edge'][0] = 0.0 124 | 125 | starting = df['edge'][df['edge']==2] 126 | starting = starting.iloc[:-1] 127 | starting_time = starting.index.values 128 | ending = df['edge'][df['edge']==-2] 129 | ending_time = ending.index.values 130 | end_time_with_0 = np.insert(ending_time, 0, 0) 131 | 132 | signal_0 = [] 133 | singal_1 = [] 134 | 135 | # create signal 0 136 | for start, end in zip(end_time_with_0, starting_time): 137 | for i in range(start+50, end-length-50, 1): 138 | temp = [] 139 | for channel in df.columns[:-2]: 140 | temp.append(df[channel][i:i+length]) 141 | 142 | signal_0.append(np.hstack(temp)) 143 | 144 | # create signal 1 145 | for start, end in zip(starting_time, ending_time): 146 | for i in range(start+50, end-length-50, 1): 147 | temp = [] 148 | for channel in df.columns[:-2]: 149 | temp.append(df[channel][i:i+length]) 150 | 151 | singal_1.append(np.hstack(temp)) 152 | 153 | df_0 = pd.DataFrame(signal_0) 154 | df_1 = pd.DataFrame(singal_1) 155 | 156 | # min_samples = min(df_0.shape[0], df_1.shape[0]) - 1 157 | # # chop the data to the same length 158 | # df_0 = df_0.iloc[:min_samples, :] 159 | # df_1 = df_1.iloc[:min_samples, :] 160 | 161 | return df_0, df_1 162 | 163 | def __getitem__(self, ind): 164 | if self.period == 'test': 165 | x = self.samples[ind, :, :] # (seq_length, feat_dim) array 166 | y = self.labels[ind] # (1,) int 167 | return torch.from_numpy(x).float(), torch.tensor(y) 168 | x = self.samples[ind, :, :] # (seq_length, feat_dim) array 169 | return torch.from_numpy(x).float() 170 | 171 | def __len__(self): 172 | return self.sample_num 173 | 174 | def normalize(self, sq): 175 | d = self.__normalize(sq.reshape(-1, self.var_num)) 176 | data = d.reshape(-1, self.window, self.var_num) 177 | return data 178 | 179 | def __normalize(self, 
rawdata): 180 | data = self.scaler.transform(rawdata) 181 | if self.auto_norm: 182 | data = normalize_to_neg_one_to_one(data) 183 | return data 184 | 185 | def unnormalize(self, sq): 186 | d = self.__unnormalize(sq.reshape(-1, self.var_num)) 187 | return d.reshape(-1, self.window, self.var_num) 188 | 189 | def __unnormalize(self, data): 190 | if self.auto_norm: 191 | data = unnormalize_to_zero_to_one(data) 192 | x = data 193 | return self.scaler.inverse_transform(x) 194 | 195 | def shift_period(self, period): 196 | assert period in ['train', 'test'], 'period must be train or test.' 197 | self.period = period -------------------------------------------------------------------------------- /Utils/Data_utils/mujoco_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from torch.utils.data import Dataset 6 | from sklearn.preprocessing import MinMaxScaler 7 | 8 | from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one 9 | from Utils.masking_utils import noise_mask 10 | 11 | 12 | class MuJoCoDataset(Dataset): 13 | def __init__( 14 | self, 15 | window=128, 16 | num=30000, 17 | dim=12, 18 | save2npy=True, 19 | neg_one_to_one=True, 20 | seed=123, 21 | scalar=None, 22 | period='train', 23 | output_dir='./OUTPUT', 24 | predict_length=None, 25 | missing_ratio=None, 26 | style='separate', 27 | distribution='geometric', 28 | mean_mask_length=3 29 | ): 30 | super(MuJoCoDataset, self).__init__() 31 | assert period in ['train', 'test'], 'period must be train or test.' 
32 | if period == 'train': 33 | assert ~(predict_length is not None or missing_ratio is not None), '' 34 | 35 | self.window, self.var_num = window, dim 36 | self.auto_norm = neg_one_to_one 37 | self.dir = os.path.join(output_dir, 'samples') 38 | os.makedirs(self.dir, exist_ok=True) 39 | self.pred_len, self.missing_ratio = predict_length, missing_ratio 40 | self.style, self.distribution, self.mean_mask_length = style, distribution, mean_mask_length 41 | 42 | self.rawdata, self.scaler = self._generate_random_trajectories(n_samples=num, seed=seed) 43 | if scalar is not None: 44 | self.scaler = scalar 45 | 46 | self.period, self.save2npy = period, save2npy 47 | self.samples = self.normalize(self.rawdata) 48 | self.sample_num = self.samples.shape[0] 49 | 50 | if period == 'test': 51 | if missing_ratio is not None: 52 | self.masking = self.mask_data(seed) 53 | elif predict_length is not None: 54 | masks = np.ones(self.samples.shape) 55 | masks[:, -predict_length:, :] = 0 56 | self.masking = masks.astype(bool) 57 | else: 58 | raise NotImplementedError() 59 | 60 | def _generate_random_trajectories(self, n_samples, seed=123): 61 | try: 62 | from dm_control import suite # noqa: F401 63 | except ImportError as e: 64 | raise Exception('Deepmind Control Suite is required to generate the dataset.') from e 65 | 66 | env = suite.load('hopper', 'stand') 67 | physics = env.physics 68 | 69 | # Store the state of the RNG to restore later. 70 | st0 = np.random.get_state() 71 | np.random.seed(seed) 72 | 73 | data = np.zeros((n_samples, self.window, self.var_num)) 74 | for i in range(n_samples): 75 | with physics.reset_context(): 76 | # x and z positions of the hopper. We want z > 0 for the hopper to stay above ground. 
77 | physics.data.qpos[:2] = np.random.uniform(0, 0.5, size=2) 78 | physics.data.qpos[2:] = np.random.uniform(-2, 2, size=physics.data.qpos[2:].shape) 79 | physics.data.qvel[:] = np.random.uniform(-5, 5, size=physics.data.qvel.shape) 80 | 81 | for t in range(self.window): 82 | data[i, t, :self.var_num // 2] = physics.data.qpos 83 | data[i, t, self.var_num // 2:] = physics.data.qvel 84 | physics.step() 85 | 86 | # Restore RNG. 87 | np.random.set_state(st0) 88 | 89 | scaler = MinMaxScaler() 90 | scaler = scaler.fit(data.reshape(-1, self.var_num)) 91 | return data, scaler 92 | 93 | def normalize(self, sq): 94 | d = self.__normalize(sq.reshape(-1, self.var_num)) 95 | data = d.reshape(-1, self.window, self.var_num) 96 | if self.save2npy: 97 | np.save(os.path.join(self.dir, f"mujoco_ground_truth_{self.window}_{self.period}.npy"), sq) 98 | 99 | if self.auto_norm: 100 | np.save(os.path.join(self.dir, f"mujoco_norm_truth_{self.window}_{self.period}.npy"), unnormalize_to_zero_to_one(data)) 101 | else: 102 | np.save(os.path.join(self.dir, f"mujoco_norm_truth_{self.window}_{self.period}.npy"), data) 103 | 104 | return data 105 | 106 | def __normalize(self, rawdata): 107 | data = self.scaler.transform(rawdata) 108 | if self.auto_norm: 109 | data = normalize_to_neg_one_to_one(data) 110 | return data 111 | 112 | def unnormalize(self, sq): 113 | d = self.__unnormalize(sq.reshape(-1, self.var_num)) 114 | return d.reshape(-1, self.window, self.var_num) 115 | 116 | def __unnormalize(self, data): 117 | if self.auto_norm: 118 | data = unnormalize_to_zero_to_one(data) 119 | x = data 120 | return self.scaler.inverse_transform(x) 121 | 122 | def mask_data(self, seed=2023): 123 | masks = np.ones_like(self.samples) 124 | # Store the state of the RNG to restore later. 
125 | st0 = np.random.get_state() 126 | np.random.seed(seed) 127 | 128 | for idx in range(self.samples.shape[0]): 129 | x = self.samples[idx, :, :] # (seq_length, feat_dim) array 130 | mask = noise_mask(x, self.missing_ratio, self.mean_mask_length, self.style, 131 | self.distribution) # (seq_length, feat_dim) boolean array 132 | masks[idx, :, :] = mask 133 | 134 | if self.save2npy: 135 | np.save(os.path.join(self.dir, f"mujoco_masking_{self.window}.npy"), masks) 136 | 137 | # Restore RNG. 138 | np.random.set_state(st0) 139 | return masks.astype(bool) 140 | 141 | def __getitem__(self, ind): 142 | if self.period == 'test': 143 | x = self.samples[ind, :, :] # (seq_length, feat_dim) array 144 | m = self.masking[ind, :, :] # (seq_length, feat_dim) boolean array 145 | return torch.from_numpy(x).float(), torch.from_numpy(m) 146 | x = self.samples[ind, :, :] # (seq_length, feat_dim) array 147 | return torch.from_numpy(x).float() 148 | 149 | def __len__(self): 150 | return self.sample_num 151 | -------------------------------------------------------------------------------- /Utils/Data_utils/real_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from scipy import io 7 | from sklearn.preprocessing import MinMaxScaler 8 | from torch.utils.data import Dataset 9 | from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one 10 | from Utils.masking_utils import noise_mask 11 | 12 | 13 | class CustomDataset(Dataset): 14 | def __init__( 15 | self, 16 | name, 17 | data_root, 18 | window=64, 19 | proportion=0.8, 20 | save2npy=True, 21 | neg_one_to_one=True, 22 | seed=123, 23 | period='train', 24 | output_dir='./OUTPUT', 25 | predict_length=None, 26 | missing_ratio=None, 27 | style='separate', 28 | distribution='geometric', 29 | mean_mask_length=3 30 | ): 31 | super(CustomDataset, self).__init__() 32 | 
assert period in ['train', 'test'], 'period must be train or test.' 33 | if period == 'train': 34 | assert ~(predict_length is not None or missing_ratio is not None), '' 35 | self.name, self.pred_len, self.missing_ratio = name, predict_length, missing_ratio 36 | self.style, self.distribution, self.mean_mask_length = style, distribution, mean_mask_length 37 | self.rawdata, self.scaler = self.read_data(data_root, self.name) 38 | self.dir = os.path.join(output_dir, 'samples') 39 | os.makedirs(self.dir, exist_ok=True) 40 | 41 | self.window, self.period = window, period 42 | self.len, self.var_num = self.rawdata.shape[0], self.rawdata.shape[-1] 43 | self.sample_num_total = max(self.len - self.window + 1, 0) 44 | self.save2npy = save2npy 45 | self.auto_norm = neg_one_to_one 46 | 47 | self.data = self.__normalize(self.rawdata) 48 | train, inference = self.__getsamples(self.data, proportion, seed) 49 | 50 | self.samples = train if period == 'train' else inference 51 | if period == 'test': 52 | if missing_ratio is not None: 53 | self.masking = self.mask_data(seed) 54 | elif predict_length is not None: 55 | masks = np.ones(self.samples.shape) 56 | masks[:, -predict_length:, :] = 0 57 | self.masking = masks.astype(bool) 58 | else: 59 | raise NotImplementedError() 60 | self.sample_num = self.samples.shape[0] 61 | 62 | def __getsamples(self, data, proportion, seed): 63 | x = np.zeros((self.sample_num_total, self.window, self.var_num)) 64 | for i in range(self.sample_num_total): 65 | start = i 66 | end = i + self.window 67 | x[i, :, :] = data[start:end, :] 68 | 69 | train_data, test_data = self.divide(x, proportion, seed) 70 | 71 | if self.save2npy: 72 | if 1 - proportion > 0: 73 | np.save(os.path.join(self.dir, f"{self.name}_ground_truth_{self.window}_test.npy"), self.unnormalize(test_data)) 74 | np.save(os.path.join(self.dir, f"{self.name}_ground_truth_{self.window}_train.npy"), self.unnormalize(train_data)) 75 | if self.auto_norm: 76 | if 1 - proportion > 0: 77 | 
np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_test.npy"), unnormalize_to_zero_to_one(test_data)) 78 | np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_train.npy"), unnormalize_to_zero_to_one(train_data)) 79 | else: 80 | if 1 - proportion > 0: 81 | np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_test.npy"), test_data) 82 | np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_train.npy"), train_data) 83 | 84 | return train_data, test_data 85 | 86 | def normalize(self, sq): 87 | d = sq.reshape(-1, self.var_num) 88 | d = self.scaler.transform(d) 89 | if self.auto_norm: 90 | d = normalize_to_neg_one_to_one(d) 91 | return d.reshape(-1, self.window, self.var_num) 92 | 93 | def unnormalize(self, sq): 94 | d = self.__unnormalize(sq.reshape(-1, self.var_num)) 95 | return d.reshape(-1, self.window, self.var_num) 96 | 97 | def __normalize(self, rawdata): 98 | data = self.scaler.transform(rawdata) 99 | if self.auto_norm: 100 | data = normalize_to_neg_one_to_one(data) 101 | return data 102 | 103 | def __unnormalize(self, data): 104 | if self.auto_norm: 105 | data = unnormalize_to_zero_to_one(data) 106 | x = data 107 | return self.scaler.inverse_transform(x) 108 | 109 | @staticmethod 110 | def divide(data, ratio, seed=2023): 111 | size = data.shape[0] 112 | # Store the state of the RNG to restore later. 113 | st0 = np.random.get_state() 114 | np.random.seed(seed) 115 | 116 | regular_train_num = int(np.ceil(size * ratio)) 117 | # id_rdm = np.random.permutation(size) 118 | id_rdm = np.arange(size) 119 | regular_train_id = id_rdm[:regular_train_num] 120 | irregular_train_id = id_rdm[regular_train_num:] 121 | 122 | regular_data = data[regular_train_id, :] 123 | irregular_data = data[irregular_train_id, :] 124 | 125 | # Restore RNG. 
126 | np.random.set_state(st0) 127 | return regular_data, irregular_data 128 | 129 | @staticmethod 130 | def read_data(filepath, name=''): 131 | """Reads a single .csv 132 | """ 133 | df = pd.read_csv(filepath, header=0) 134 | if name == 'etth': 135 | df.drop(df.columns[0], axis=1, inplace=True) 136 | data = df.values 137 | scaler = MinMaxScaler() 138 | scaler = scaler.fit(data) 139 | return data, scaler 140 | 141 | def mask_data(self, seed=2023): 142 | masks = np.ones_like(self.samples) 143 | # Store the state of the RNG to restore later. 144 | st0 = np.random.get_state() 145 | np.random.seed(seed) 146 | 147 | for idx in range(self.samples.shape[0]): 148 | x = self.samples[idx, :, :] # (seq_length, feat_dim) array 149 | mask = noise_mask(x, self.missing_ratio, self.mean_mask_length, self.style, 150 | self.distribution) # (seq_length, feat_dim) boolean array 151 | masks[idx, :, :] = mask 152 | 153 | if self.save2npy: 154 | np.save(os.path.join(self.dir, f"{self.name}_masking_{self.window}.npy"), masks) 155 | 156 | # Restore RNG. 
157 | np.random.set_state(st0) 158 | return masks.astype(bool) 159 | 160 | def __getitem__(self, ind): 161 | if self.period == 'test': 162 | x = self.samples[ind, :, :] # (seq_length, feat_dim) array 163 | m = self.masking[ind, :, :] # (seq_length, feat_dim) boolean array 164 | return torch.from_numpy(x).float(), torch.from_numpy(m) 165 | x = self.samples[ind, :, :] # (seq_length, feat_dim) array 166 | return torch.from_numpy(x).float() 167 | 168 | def __len__(self): 169 | return self.sample_num 170 | 171 | 172 | class fMRIDataset(CustomDataset): 173 | def __init__( 174 | self, 175 | proportion=1., 176 | **kwargs 177 | ): 178 | super().__init__(proportion=proportion, **kwargs) 179 | 180 | @staticmethod 181 | def read_data(filepath, name=''): 182 | """Reads a single .csv 183 | """ 184 | data = io.loadmat(filepath + '/sim4.mat')['ts'] 185 | scaler = MinMaxScaler() 186 | scaler = scaler.fit(data) 187 | return data, scaler 188 | -------------------------------------------------------------------------------- /Utils/Data_utils/sine_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from tqdm.auto import tqdm 6 | from torch.utils.data import Dataset 7 | 8 | from Models.interpretable_diffusion.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one 9 | from Utils.masking_utils import noise_mask 10 | 11 | 12 | class SineDataset(Dataset): 13 | def __init__( 14 | self, 15 | window=128, 16 | num=30000, 17 | dim=12, 18 | save2npy=True, 19 | neg_one_to_one=True, 20 | seed=123, 21 | period='train', 22 | output_dir='./OUTPUT', 23 | predict_length=None, 24 | missing_ratio=None, 25 | style='separate', 26 | distribution='geometric', 27 | mean_mask_length=3 28 | ): 29 | super(SineDataset, self).__init__() 30 | assert period in ['train', 'test'], 'period must be train or test.' 
# --- Utils/Data_utils/sine_dataset.py: methods of the sine dataset class ---

def normalize(self, rawdata):
    """Scale data from [0, 1] to [-1, 1] when `auto_norm` is set.

    FIX: the original returned an unbound local (NameError) when
    `auto_norm` was False; pass the data through unchanged instead.
    """
    if self.auto_norm:
        return normalize_to_neg_one_to_one(rawdata)
    return rawdata

def unnormalize(self, data):
    """Inverse of `normalize`: map [-1, 1] back to [0, 1] when `auto_norm` is set.

    FIX: same unbound-local bug as `normalize` when `auto_norm` is False.
    """
    if self.auto_norm:
        return unnormalize_to_zero_to_one(data)
    return data

@staticmethod
def sine_data_generation(no, seq_len, dim, save2npy=True, seed=123, dir="./", period='train'):
    """Generate `no` sine time-series of shape (seq_len, dim), scaled to [0, 1].

    Each feature draws an independent frequency and phase from U(0, 0.1).
    The global NumPy RNG state is saved and restored so seeding here does
    not disturb callers.

    Args:
        no: number of samples
        seq_len: sequence length of the time-series
        dim: feature dimensions
    Returns:
        data: (no, seq_len, dim) array of generated series
    """
    st0 = np.random.get_state()  # preserve caller RNG state
    np.random.seed(seed)

    data = []
    for _ in tqdm(range(no), total=no, desc="Sampling sine-dataset"):
        series = []
        for _ in range(dim):
            freq = np.random.uniform(0, 0.1)
            phase = np.random.uniform(0, 0.1)
            # sin(freq * t + phase), then affinely mapped from [-1, 1] to [0, 1]
            series.append([np.sin(freq * j + phase) for j in range(seq_len)])
        # transpose feature-major -> time-major: (seq_len, dim)
        data.append((np.transpose(np.asarray(series)) + 1) * 0.5)

    np.random.set_state(st0)  # restore caller RNG state
    data = np.array(data)
    if save2npy:
        np.save(os.path.join(dir, f"sine_ground_truth_{seq_len}_{period}.npy"), data)
    return data

def mask_data(self, seed=2023):
    """Draw a boolean keep-mask per sample via `noise_mask`; optionally save it."""
    masks = np.ones_like(self.samples)
    st0 = np.random.get_state()  # preserve caller RNG state
    np.random.seed(seed)

    for idx in range(self.samples.shape[0]):
        x = self.samples[idx, :, :]  # (seq_length, feat_dim)
        masks[idx, :, :] = noise_mask(x, self.missing_ratio, self.mean_mask_length,
                                      self.style, self.distribution)

    if self.save2npy:
        np.save(os.path.join(self.dir, f"sine_masking_{self.window}.npy"), masks)

    np.random.set_state(st0)  # restore caller RNG state
    return masks.astype(bool)

def __getitem__(self, ind):
    # Test period also returns the imputation mask for the sample.
    if self.period == 'test':
        x = self.samples[ind, :, :]   # (seq_length, feat_dim)
        m = self.masking[ind, :, :]   # (seq_length, feat_dim) boolean
        return torch.from_numpy(x).float(), torch.from_numpy(m)
    return torch.from_numpy(self.samples[ind, :, :]).float()

def __len__(self):
    return self.sample_num

# --- Utils/context_fid.py ---

def calculate_fid(act1, act2):
    """Frechet distance between Gaussian fits of two activation sets.

    FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2)).
    """
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))
    # sqrtm may introduce tiny imaginary components; keep the real part
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

def Context_FID(ori_data, generated_data):
    """Contextual FID: FID computed on TS2Vec full-series embeddings.

    A TS2Vec encoder is fitted on the original data, then both real and
    generated series are embedded and compared with `calculate_fid`.
    """
    model = TS2Vec(input_dims=ori_data.shape[-1], device=0, batch_size=8, lr=0.001,
                   output_dims=320, max_train_length=3000)
    model.fit(ori_data, verbose=False)
    ori_representation = model.encode(ori_data, encoding_window='full_series')
    gen_representation = model.encode(generated_data, encoding_window='full_series')
    # NOTE(review): the shared permutation does not change the FID statistics;
    # kept for parity with the original implementation.
    idx = np.random.permutation(ori_data.shape[0])
    return calculate_fid(ori_representation[idx], gen_representation[idx])
# --- Utils/cross_correlation.py ---

def cacf_torch(x, max_lag, dim=(0, 1)):
    """Cross-autocorrelation estimates for lags 0..max_lag-1.

    x: (batch, time, channels). Channels are standardised over `dim`, then
    lagged products are averaged over time for every lower-triangular
    channel pair (diagonal included).
    Returns a tensor of shape (batch, max_lag, n_pairs).
    """
    n_channels = x.shape[2]
    left_idx, right_idx = [list(row) for row in torch.tril_indices(n_channels, n_channels)]

    # standardise so that averaged products are correlations
    z = (x - x.mean(dim, keepdims=True)) / x.std(dim, keepdims=True)
    z_left = z[..., left_idx]
    z_right = z[..., right_idx]

    per_lag = []
    for lag in range(max_lag):
        prod = z_left[:, lag:] * z_right[:, :-lag] if lag > 0 else z_left * z_right
        per_lag.append(torch.mean(prod, (1)))
    stacked = torch.cat(per_lag, 1)
    return stacked.reshape(stacked.shape[0], -1, len(left_idx))


class Loss(nn.Module):
    """Base class for regularised scalar losses with a per-component success test."""

    def __init__(self, name, reg=1.0, transform=lambda x: x, threshold=10., backward=False, norm_foo=lambda x: x):
        super(Loss, self).__init__()
        self.name = name
        self.reg = reg                # scaling applied to the mean loss
        self.transform = transform    # preprocessing applied to inputs
        self.threshold = threshold    # componentwise success bound
        self.backward = backward
        self.norm_foo = norm_foo      # reduction applied by subclasses

    def forward(self, x_fake):
        # cache the componentwise loss so `success` can inspect it afterwards
        self.loss_componentwise = self.compute(x_fake)
        return self.reg * self.loss_componentwise.mean()

    def compute(self, x_fake):
        raise NotImplementedError()

    @property
    def success(self):
        return torch.all(self.loss_componentwise <= self.threshold)


class CrossCorrelLoss(Loss):
    """Penalises the gap between real and fake lag-0 cross-correlations."""

    def __init__(self, x_real, **kwargs):
        super(CrossCorrelLoss, self).__init__(norm_foo=lambda x: torch.abs(x).sum(0), **kwargs)
        # reference statistics from the real data, averaged over the batch
        self.cross_correl_real = cacf_torch(self.transform(x_real), 1).mean(0)[0]

    def compute(self, x_fake):
        cross_correl_fake = cacf_torch(self.transform(x_fake), 1).mean(0)[0]
        gap = self.norm_foo(cross_correl_fake - self.cross_correl_real.to(x_fake.device))
        return gap / 10.
# --- Utils/discriminative_metric.py ---

def batch_generator(data, time, batch_size):
    """Sample a random mini-batch (no replacement within the batch).

    Args:
        data: time-series data (indexable)
        time: per-sample time information (indexable, parallel to data)
        batch_size: number of samples per batch
    Returns:
        X_mb: sampled time-series
        T_mb: matching time information
    """
    idx = np.random.permutation(len(data))
    batch_idx = idx[:batch_size]
    X_mb = [data[i] for i in batch_idx]
    T_mb = [time[i] for i in batch_idx]
    return X_mb, T_mb


def discriminative_score_metrics(ori_data, generated_data):
    """Train a post-hoc GRU discriminator to tell original from synthetic data.

    Args:
        ori_data: original data, shape (no, seq_len, dim)
        generated_data: generated synthetic data, same shape
    Returns:
        discriminative_score: |0.5 - test accuracy| (lower is better)
        fake_acc: accuracy on synthetic test samples
        real_acc: accuracy on original test samples
    """
    tf1.reset_default_graph()

    no, seq_len, dim = np.asarray(ori_data).shape

    # Per-sequence lengths and the global maximum.
    ori_time, ori_max_seq_len = extract_time(ori_data)
    # FIX: was extract_time(ori_data) — the generated set was never measured.
    generated_time, generated_max_seq_len = extract_time(generated_data)
    max_seq_len = max([ori_max_seq_len, generated_max_seq_len])

    # Network / training hyper-parameters.
    hidden_dim = int(dim / 2)
    iterations = 2000
    batch_size = 128

    # Placeholders: features and lengths for real (X, T) and synthetic (X_hat, T_hat).
    X = tf1.placeholder(tf.float32, [None, max_seq_len, dim], name="myinput_x")
    X_hat = tf1.placeholder(tf.float32, [None, max_seq_len, dim], name="myinput_x_hat")
    T = tf1.placeholder(tf.int32, [None], name="myinput_t")
    T_hat = tf1.placeholder(tf.int32, [None], name="myinput_t_hat")

    def discriminator(x, t):
        """GRU + dense head; returns logits, sigmoid outputs and the variable list."""
        with tf1.variable_scope("discriminator", reuse=tf1.AUTO_REUSE) as vs:
            d_cell = tf1.nn.rnn_cell.GRUCell(num_units=hidden_dim, activation=tf.nn.tanh, name='d_cell')
            d_outputs, d_last_states = tf1.nn.dynamic_rnn(d_cell, x, dtype=tf.float32, sequence_length=t)
            y_hat_logit = tf1.layers.dense(d_last_states, 1, activation=None)
            y_hat = tf.nn.sigmoid(y_hat_logit)
            d_vars = [v for v in tf1.all_variables() if v.name.startswith(vs.name)]
        return y_hat_logit, y_hat, d_vars

    y_logit_real, y_pred_real, d_vars = discriminator(X, T)
    y_logit_fake, y_pred_fake, _ = discriminator(X_hat, T_hat)

    # Standard discriminator loss: real -> 1, fake -> 0.
    d_loss_real = tf1.reduce_mean(tf1.nn.sigmoid_cross_entropy_with_logits(
        logits=y_logit_real, labels=tf1.ones_like(y_logit_real)))
    d_loss_fake = tf1.reduce_mean(tf1.nn.sigmoid_cross_entropy_with_logits(
        logits=y_logit_fake, labels=tf1.zeros_like(y_logit_fake)))
    d_loss = d_loss_real + d_loss_fake

    d_solver = tf1.train.AdamOptimizer().minimize(d_loss, var_list=d_vars)

    sess = tf1.Session()
    sess.run(tf1.global_variables_initializer())

    # Train/test division for both original and generated data.
    train_x, train_x_hat, test_x, test_x_hat, train_t, train_t_hat, test_t, test_t_hat = \
        train_test_divide(ori_data, generated_data, ori_time, generated_time)

    from tqdm.auto import tqdm

    for _ in tqdm(range(iterations), desc='training', total=iterations):
        X_mb, T_mb = batch_generator(train_x, train_t, batch_size)
        X_hat_mb, T_hat_mb = batch_generator(train_x_hat, train_t_hat, batch_size)
        sess.run([d_solver, d_loss],
                 feed_dict={X: X_mb, T: T_mb, X_hat: X_hat_mb, T_hat: T_hat_mb})

    # Evaluate on the held-out split.
    y_pred_real_curr, y_pred_fake_curr = sess.run(
        [y_pred_real, y_pred_fake],
        feed_dict={X: test_x, T: test_t, X_hat: test_x_hat, T_hat: test_t_hat})

    y_pred_final = np.squeeze(np.concatenate((y_pred_real_curr, y_pred_fake_curr), axis=0))
    y_label_final = np.concatenate((np.ones([len(y_pred_real_curr), ]),
                                    np.zeros([len(y_pred_fake_curr), ])), axis=0)

    acc = accuracy_score(y_label_final, (y_pred_final > 0.5))
    fake_acc = accuracy_score(np.zeros([len(y_pred_fake_curr), ]), (y_pred_fake_curr > 0.5))
    # FIX: label vector was previously sized by the *fake* predictions.
    real_acc = accuracy_score(np.ones([len(y_pred_real_curr), ]), (y_pred_real_curr > 0.5))

    discriminative_score = np.abs(0.5 - acc)
    return discriminative_score, fake_acc, real_acc


# --- Utils/imputation_utils.py ---

def get_quantile(samples, q, dim=1):
    """Quantile of `samples` along `dim`, returned as a NumPy array."""
    return torch.quantile(samples, q, dim=dim).cpu().numpy()
def plot_sample(ori_data, gen_data, masks, sample_idx=0):
    """Plot imputation medians with 90% bands per feature on a 7x4 grid.

    Args:
        ori_data: (sample_num, seq_len, feat_dim) ground-truth array
        gen_data: generated samples; quantiles are taken over dim 0
        masks: 1 = observed, 0 = held out / to impute
        sample_idx: which sample of the batch to plot
    """
    plt.rcParams["font.size"] = 12
    fig, axes = plt.subplots(nrows=7, ncols=4, figsize=(12, 15))
    sample_num, seq_len, feat_dim = ori_data.shape
    observed = ori_data * masks

    # Median and 5%/95% quantiles over the generated samples; observed
    # positions pass through unchanged.
    quantiles = []
    for q in (0.5, 0.05, 0.95):
        quantiles.append(get_quantile(torch.from_numpy(gen_data), q, dim=0) * (1 - masks) + observed)

    for feat_idx in range(feat_dim):
        row, col = divmod(feat_idx, 4)

        # observed points (mask == 1)
        df_x = pd.DataFrame({"x": np.arange(0, seq_len),
                             "val": ori_data[sample_idx, :, feat_idx],
                             "y": masks[sample_idx, :, feat_idx]})
        df_x = df_x[df_x.y != 0]

        # held-out points (mask == 0)
        df_o = pd.DataFrame({"x": np.arange(0, seq_len),
                             "val": ori_data[sample_idx, :, feat_idx],
                             "y": (1 - masks)[sample_idx, :, feat_idx]})
        df_o = df_o[df_o.y != 0]

        ax = axes[row][col]
        ax.plot(range(0, seq_len), quantiles[0][sample_idx, :, feat_idx],
                color='g', linestyle='solid', label='Diffusion-TS')
        ax.fill_between(range(0, seq_len), quantiles[1][sample_idx, :, feat_idx],
                        quantiles[2][sample_idx, :, feat_idx], color='g', alpha=0.3)
        ax.plot(df_o.x, df_o.val, color='b', marker='o', linestyle='None')
        ax.plot(df_x.x, df_x.val, color='r', marker='x', linestyle='None')

        if col == 0:
            plt.setp(axes[row, 0], ylabel='value')
        # FIX: original tested `row == -1`, which can never hold (row >= 0),
        # so the x-axis label was never drawn; label the bottom row instead.
        if row == axes.shape[0] - 1:
            plt.setp(axes[-1, col], xlabel='time')
    plt.tight_layout()
    plt.show()


class MaskedLoss(nn.Module):
    """MSE (or L1) loss evaluated only at positions where `mask` is True."""

    def __init__(self, reduction: str = 'mean', mode='mse'):
        super().__init__()
        self.reduction = reduction
        if mode == 'mse':
            self.loss = nn.MSELoss(reduction=self.reduction)
        else:
            self.loss = nn.L1Loss(reduction=self.reduction)

    def forward(self,
                y_pred: torch.Tensor, y_true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        """Compute the loss between a target value and a prediction.

        Args:
            y_pred: estimated values
            y_true: target values
            mask: boolean tensor; True = include position, False = ignore

        Returns:
            if reduction == 'none': per-active-element loss tensor;
            if reduction == 'mean': scalar mean loss over active elements.
        """
        masked_pred = torch.masked_select(y_pred, mask)
        masked_true = torch.masked_select(y_true, mask)
        return self.loss(masked_pred, masked_true)
# --- Utils/imputation_utils.py (continued) ---

def random_mask(observed_values, missing_ratio=0.1, seed=1984):
    """Hide a fraction of the observed entries as imputation ground truth.

    NaNs are treated as unobserved. Returns float tensors:
    (values with NaNs zeroed, observed mask, ground-truth mask).
    """
    observed_masks = ~np.isnan(observed_values)

    flat = observed_masks.reshape(-1).copy()
    obs_indices = np.where(flat)[0].tolist()

    # Seed locally and restore the global RNG state afterwards.
    rng_state = np.random.get_state()
    np.random.seed(seed)
    miss_indices = np.random.choice(
        obs_indices, (int)(len(obs_indices) * missing_ratio), replace=False
    )
    np.random.set_state(rng_state)

    flat[miss_indices] = False
    gt_masks = flat.reshape(observed_masks.shape)

    observed_values = np.nan_to_num(observed_values)
    return (torch.from_numpy(observed_values).float(),
            torch.from_numpy(observed_masks).float(),
            torch.from_numpy(gt_masks).float())


# --- Utils/io_utils.py ---

def load_yaml_config(path):
    """Parse a YAML file into a Python object."""
    with open(path) as f:
        return yaml.full_load(f)


def save_config_to_yaml(config, path):
    """Serialise `config` to a .yaml file (path must end with .yaml)."""
    assert path.endswith('.yaml')
    with open(path, 'w') as f:
        f.write(yaml.dump(config))


def save_dict_to_json(d, path, indent=None):
    """Write a dict to a JSON file."""
    json.dump(d, open(path, 'w'), indent=indent)


def load_dict_from_json(path):
    """Read a dict from a JSON file."""
    return json.load(open(path, 'r'))


def write_args(args, path):
    """Append run metadata (torch/cudnn versions, argv, args) to a text file."""
    args_dict = {name: getattr(args, name) for name in dir(args) if not name.startswith('_')}
    with open(path, 'a') as args_file:
        args_file.write('==> torch version: {}\n'.format(torch.__version__))
        args_file.write('==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
        args_file.write('==> Cmd:\n')
        args_file.write(str(sys.argv))
        args_file.write('\n==> args:\n')
        for k, v in sorted(args_dict.items()):
            args_file.write(' %s: %s\n' % (str(k), str(v)))


def seed_everything(seed, cudnn_deterministic=False):
    """Seed python.random, NumPy and PyTorch RNGs.

    Args:
        seed: integer seed for global random state (None = leave unseeded)
        cudnn_deterministic: force deterministic cuDNN kernels (slower)
    """
    if seed is not None:
        print(f"Global seed set to {seed}")
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = False

    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')


def merge_opts_to_config(config, opts):
    """Override nested config entries from a flat [name, value, ...] list.

    Dotted names address nested keys; values are cast to the existing type.
    """
    def modify_dict(c, name_parts, v):
        if len(name_parts) == 1:
            # cast to the type already stored at this key
            c[name_parts[0]] = type(c[name_parts[0]])(v)
        else:
            c[name_parts[0]] = modify_dict(c[name_parts[0]], name_parts[1:], v)
        return c

    if opts is not None and len(opts) > 0:
        assert len(opts) % 2 == 0, "each opts should be given by the name and values! The length shall be even number!"
        for name, value in zip(opts[0::2], opts[1::2]):
            config = modify_dict(config, name.split('.'), value)
    return config


def modify_config_for_debug(config):
    """Shrink the dataloader settings for quick debug runs."""
    config['dataloader']['num_workers'] = 0
    config['dataloader']['batch_size'] = 1
    return config
# --- Utils/io_utils.py (continued) ---

def get_model_parameters_info(model):
    """Count trainable/non-trainable parameters per top-level child of `model`.

    Returns a nested dict with human-readable counts (K/M/G, powers of 1024)
    plus an 'overall' summary.
    """
    parameters = {'overall': {'trainable': 0, 'non_trainable': 0, 'total': 0}}
    for child_name, child_module in model.named_children():
        counts = {'trainable': 0, 'non_trainable': 0}
        for _, p in child_module.named_parameters():
            key = 'trainable' if p.requires_grad else 'non_trainable'
            counts[key] += p.numel()
        counts['total'] = counts['trainable'] + counts['non_trainable']
        parameters[child_name] = counts

        parameters['overall']['trainable'] += counts['trainable']
        parameters['overall']['non_trainable'] += counts['non_trainable']
        parameters['overall']['total'] += counts['total']

    def format_number(num):
        # binary prefixes: G = 2**30, M = 2**20, K = 2**10
        for bound, unit in ((2 ** 30, 'G'), (2 ** 20, 'M'), (2 ** 10, 'K')):
            if num > bound:
                return '{}{}'.format(round(float(num) / bound, 2), unit)
        return '{}{}'.format(num, '')

    def format_dict(d):
        for k, v in d.items():
            if isinstance(v, dict):
                format_dict(v)
            else:
                d[k] = format_number(v)

    format_dict(parameters)
    return parameters


def format_seconds(seconds):
    """Render a duration as [Dd:]HHh:MMm:SSs, omitting leading zero fields."""
    h = int(seconds // 3600)
    m = int(seconds // 60 - h * 60)
    s = int(seconds % 60)
    d, h = divmod(h, 24)

    if d:
        return '{:d}d:{:02d}h:{:02d}m:{:02d}s'.format(d, h, m, s)
    if h:
        return '{:02d}h:{:02d}m:{:02d}s'.format(h, m, s)
    if m:
        return '{:02d}m:{:02d}s'.format(m, s)
    return '{:02d}s'.format(s)


def instantiate_from_config(config):
    """Build the object described by config['target'] with config['params']."""
    if config is None:
        return None
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return class_from_string(config["target"])(**config.get("params", dict()))


def class_from_string(class_name):
    """Resolve a dotted 'package.module.Class' path to the class object."""
    module, cls = class_name.rsplit(".", 1)
    return getattr(importlib.import_module(module, package=None), cls)


def get_all_file(dir, end_with='.h5'):
    """Recursively list files under `dir` whose name ends with any given suffix."""
    suffixes = [end_with] if isinstance(end_with, str) else end_with
    filenames = []
    for root, _, files in os.walk(dir):
        for f in files:
            if any(f.endswith(s) for s in suffixes):
                filenames.append(os.path.join(root, f))
    return filenames


def get_sub_dirs(dir, abs=True):
    """List immediate entries of `dir`; absolute paths unless abs=False."""
    entries = os.listdir(dir)
    return [os.path.join(dir, s) for s in entries] if abs else entries


def get_model_buffer(model):
    """Return state-dict entries that are buffers (in state_dict but not parameters)."""
    state_dict = model.state_dict()
    param_names = {n for n, _ in model.named_parameters()}
    return {k: state_dict[k] for k in state_dict if k not in param_names}
# --- Utils/masking_utils.py ---

def costume_collate(data, max_len=None, mask_compensation=False):
    """Build mini-batch tensors from a list of (X, mask) tuples.

    Args:
        data: list of (X, mask) tuples with variable seq_length:
            X: (seq_length, feat_dim) torch tensor;
            mask: (seq_length, feat_dim) boolean torch tensor.
        max_len: fixed output length; longer sequences are clipped, shorter
            padded with 0s. Defaults to the batch maximum.
        mask_compensation: rescale unmasked entries via `compensate_masking`.
    Returns:
        X: (batch_size, padded_length, feat_dim) masked input features
        targets: unmasked features (prediction target)
        target_masks: boolean tensor; 0 = masked value to predict, 1 = active
        padding_masks: (batch_size, padded_length) boolean; 1 = keep, 0 = padding
    """
    batch_size = len(data)
    features, masks = zip(*data)

    lengths = [X.shape[0] for X in features]  # original length per series
    if max_len is None:
        max_len = max(lengths)

    X = torch.zeros(batch_size, max_len, features[0].shape[-1])
    target_masks = torch.zeros_like(X, dtype=torch.bool)
    for i in range(batch_size):
        end = min(lengths[i], max_len)  # clip sequences longer than max_len
        X[i, :end, :] = features[i][:end, :]
        target_masks[i, :end, :] = masks[i][:end, :]

    targets = X.clone()
    X = X * target_masks  # zero out the positions to be predicted
    if mask_compensation:
        X = compensate_masking(X, target_masks)

    padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.int16), max_len=max_len)
    return X, targets, target_masks, padding_masks


def compensate_masking(X, mask):
    """Rescale unmasked entries so that W @ X is unaffected on average.

    If p is the proportion of active elements per step, X' = X * feat_dim / num_active.
    Args:
        X: (batch_size, seq_length, feat_dim) torch tensor
        mask: same shape; 0 = mask and predict, 1 = active input
    Returns:
        compensated features, same shape as X
    """
    num_active = torch.sum(mask, dim=-1).unsqueeze(-1)  # (batch, seq, 1)
    # clamp to 1 to avoid division by zero at fully-masked steps
    num_active = torch.max(num_active, torch.ones(num_active.shape, dtype=torch.int16))
    return X.shape[-1] * X / num_active


def padding_mask(lengths, max_len=None):
    """(batch_size, max_len) boolean mask, True at valid (non-padded) steps."""
    batch_size = lengths.numel()
    if max_len is None:
        # FIX: original called the nonexistent `lengths.max_val()`,
        # raising AttributeError whenever max_len was not supplied.
        max_len = int(lengths.max())
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1)))


def noise_mask(X, masking_ratio, lm=3, mode='separate', distribution='geometric', exclude_feats=None):
    """Random boolean mask shaped like X; 0 marks positions to hide.

    Args:
        X: (seq_length, feat_dim) numpy array for a single sample
        masking_ratio: proportion of entries to mask on average
        lm: average masked-run length (geometric distribution only)
        mode: 'separate' masks each feature independently; otherwise all
            features are masked at the same positions concurrently
        distribution: 'geometric' for Markov-chain runs, anything else for
            i.i.d. Bernoulli masking
        exclude_feats: feature indices to leave fully unmasked
    Returns:
        boolean numpy array, same shape as X
    """
    if exclude_feats is not None:
        exclude_feats = set(exclude_feats)

    if distribution == 'geometric':  # stateful (Markov chain)
        if mode == 'separate':
            mask = np.ones(X.shape, dtype=bool)
            for m in range(X.shape[1]):
                if exclude_feats is None or m not in exclude_feats:
                    mask[:, m] = geom_noise_mask_single(X.shape[0], lm, masking_ratio)
        else:
            # replicate one mask across all features
            mask = np.tile(np.expand_dims(geom_noise_mask_single(X.shape[0], lm, masking_ratio), 1), X.shape[1])
    else:  # i.i.d. Bernoulli with keep-probability 1 - masking_ratio
        if mode == 'separate':
            mask = np.random.choice(np.array([True, False]), size=X.shape, replace=True,
                                    p=(1 - masking_ratio, masking_ratio))
        else:
            mask = np.tile(np.random.choice(np.array([True, False]), size=(X.shape[0], 1), replace=True,
                                            p=(1 - masking_ratio, masking_ratio)), X.shape[1])
    return mask
def geom_noise_mask_single(L, lm, masking_ratio):
    """Boolean keep-mask of length L produced by a two-state Markov chain.

    Masked (0) runs have average length `lm`; overall about `masking_ratio`
    of the entries are 0. Run lengths are geometrically distributed.

    Args:
        L: length of the mask
        lm: average length of masked runs
        masking_ratio: proportion of L to mask
    Returns:
        (L,) boolean numpy array; 0 = drop, 1 = keep
    """
    keep_mask = np.ones(L, dtype=bool)
    p_m = 1 / lm  # P(stop a masked run) — geometric parameter
    p_u = p_m * masking_ratio / (1 - masking_ratio)  # P(stop an unmasked run)
    stop_prob = [p_m, p_u]

    # state 0 = masking, 1 = keeping; start masked with prob masking_ratio
    state = int(np.random.rand() > masking_ratio)
    for i in range(L):
        keep_mask[i] = state  # state value doubles as the keep flag
        if np.random.rand() < stop_prob[state]:
            state = 1 - state
    return keep_mask


# --- Utils/metric_utils.py ---

def display_scores(results):
    """Print mean ± half-width of a 95% t-interval (n=5 runs assumed)."""
    mean = np.mean(results)
    half_width = scipy.stats.sem(results) * scipy.stats.t.ppf((1 + 0.95) / 2., 5 - 1)
    print('Final Score: ', f'{mean} \xB1 {half_width}')


def train_test_divide(data_x, data_x_hat, data_t, data_t_hat, train_rate=0.8):
    """Divide train and test data for both original and synthetic data.

    Args:
        data_x / data_x_hat: original / generated data
        data_t / data_t_hat: original / generated time information
        train_rate: fraction assigned to the training split
    """
    def _split(xs, ts):
        # independent random permutation per data set
        perm = np.random.permutation(len(xs))
        cut = int(len(xs) * train_rate)
        tr, te = perm[:cut], perm[cut:]
        return [xs[i] for i in tr], [xs[i] for i in te], [ts[i] for i in tr], [ts[i] for i in te]

    train_x, test_x, train_t, test_t = _split(data_x, data_t)
    train_x_hat, test_x_hat, train_t_hat, test_t_hat = _split(data_x_hat, data_t_hat)
    return train_x, train_x_hat, test_x, test_x_hat, train_t, train_t_hat, test_t, test_t_hat


def extract_time(data):
    """Per-sample sequence lengths and the maximum length.

    Args:
        data: iterable of (seq_len, dim) arrays
    Returns:
        time: list of per-sample lengths
        max_seq_len: maximum sequence length
    """
    time = [len(sample[:, 0]) for sample in data]
    return time, max(time, default=0)


def visualization(ori_data, generated_data, analysis, compare=3000):
    """2-D comparison of real vs synthetic series via 'pca', 'tsne' or 'kernel'.

    Each series is first collapsed to its per-timestep feature mean.

    Args:
        ori_data: original data, (no, seq_len, dim)
        generated_data: generated synthetic data, same layout
        analysis: 'tsne', 'pca' or 'kernel'
        compare: cap on the number of samples analysed
    """
    anal_sample_no = min([compare, ori_data.shape[0]])
    idx = np.random.permutation(ori_data.shape[0])[:anal_sample_no]

    ori_data = ori_data[idx]
    generated_data = generated_data[idx]
    no, seq_len, dim = ori_data.shape

    # collapse features: each sample becomes its per-timestep mean, (1, seq_len)
    for i in range(anal_sample_no):
        if (i == 0):
            prep_data = np.reshape(np.mean(ori_data[0, :, :], 1), [1, seq_len])
            prep_data_hat = np.reshape(np.mean(generated_data[0, :, :], 1), [1, seq_len])
        else:
            prep_data = np.concatenate(
                (prep_data, np.reshape(np.mean(ori_data[i, :, :], 1), [1, seq_len])))
            prep_data_hat = np.concatenate(
                (prep_data_hat, np.reshape(np.mean(generated_data[i, :, :], 1), [1, seq_len])))

    # red = original, blue = synthetic
    colors = ["red" for _ in range(anal_sample_no)] + ["blue" for _ in range(anal_sample_no)]

    if analysis == 'pca':
        # fit PCA on the original data, project both sets
        pca = PCA(n_components=2)
        pca.fit(prep_data)
        pca_results = pca.transform(prep_data)
        pca_hat_results = pca.transform(prep_data_hat)

        f, ax = plt.subplots(1)
        plt.scatter(pca_results[:, 0], pca_results[:, 1],
                    c=colors[:anal_sample_no], alpha=0.2, label="Original")
        plt.scatter(pca_hat_results[:, 0], pca_hat_results[:, 1],
                    c=colors[anal_sample_no:], alpha=0.2, label="Synthetic")
        ax.legend()
        plt.title('PCA plot')
        plt.xlabel('x-pca')
        plt.ylabel('y_pca')
        plt.show()

    elif analysis == 'tsne':
        # embed both sets jointly so they share one t-SNE space
        prep_data_final = np.concatenate((prep_data, prep_data_hat), axis=0)
        tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
        tsne_results = tsne.fit_transform(prep_data_final)

        f, ax = plt.subplots(1)
        plt.scatter(tsne_results[:anal_sample_no, 0], tsne_results[:anal_sample_no, 1],
                    c=colors[:anal_sample_no], alpha=0.2, label="Original")
        plt.scatter(tsne_results[anal_sample_no:, 0], tsne_results[anal_sample_no:, 1],
                    c=colors[anal_sample_no:], alpha=0.2, label="Synthetic")
        ax.legend()
        plt.title('t-SNE plot')
        plt.xlabel('x-tsne')
        plt.ylabel('y_tsne')
        plt.show()

    elif analysis == 'kernel':
        # kernel density estimate over the collapsed values
        f, ax = plt.subplots(1)
        sns.distplot(prep_data, hist=False, kde=True,
                     kde_kws={'linewidth': 5}, label='Original', color="red")
        sns.distplot(prep_data_hat, hist=False, kde=True,
                     kde_kws={'linewidth': 5, 'linestyle': '--'}, label='Synthetic', color="blue")
        plt.legend()
        plt.xlabel('Data Value')
        plt.ylabel('Data Density Estimate')
        plt.show()
        plt.close()


if __name__ == '__main__':
    pass
def predictive_score_metrics(ori_data, generated_data):
    """Report the performance of Post-hoc RNN one-step ahead prediction.

    Trains a GRU predictor on the *synthetic* data to forecast the last
    feature one step ahead, then evaluates its MAE on the *original* data
    (the "train on synthetic, test on real" protocol from TimeGAN).

    Args:
      - ori_data: original data, shape (no, seq_len, dim)
      - generated_data: generated synthetic data, same shape

    Returns:
      - predictive_score: MAE of the predictions on the original data
    """
    # Initialization on the Graph
    tf1.reset_default_graph()

    # Basic Parameters
    no, seq_len, dim = ori_data.shape

    # Set maximum sequence length and each sequence length.
    # BUG FIX: sequence lengths of the synthetic set were previously
    # extracted from `ori_data`; they must come from `generated_data`.
    ori_time, ori_max_seq_len = extract_time(ori_data)
    generated_time, generated_max_seq_len = extract_time(generated_data)
    max_seq_len = max([ori_max_seq_len, generated_max_seq_len])

    ## Build a post-hoc RNN predictive network
    # Network parameters
    hidden_dim = int(dim / 2)
    iterations = 5000
    batch_size = 128

    # Input placeholders: X holds the first (dim-1) features of the first
    # (seq_len-1) steps, Y holds the last feature shifted one step ahead.
    X = tf1.placeholder(tf.float32, [None, max_seq_len - 1, dim - 1], name="myinput_x")
    T = tf1.placeholder(tf.int32, [None], name="myinput_t")
    Y = tf1.placeholder(tf.float32, [None, max_seq_len - 1, 1], name="myinput_y")

    def predictor(x, t):
        """Simple GRU predictor.

        Args:
          - x: time-series data (first dim-1 features)
          - t: per-sample sequence lengths

        Returns:
          - y_hat: one-step-ahead prediction of the last feature
          - p_vars: predictor variables
        """
        with tf1.variable_scope("predictor", reuse=tf1.AUTO_REUSE) as vs:
            p_cell = tf1.nn.rnn_cell.GRUCell(num_units=hidden_dim, activation=tf.nn.tanh, name='p_cell')
            p_outputs, p_last_states = tf1.nn.dynamic_rnn(p_cell, x, dtype=tf.float32, sequence_length=t)
            y_hat_logit = tf1.layers.dense(p_outputs, 1, activation=None)
            y_hat = tf.nn.sigmoid(y_hat_logit)
            p_vars = [v for v in tf1.all_variables() if v.name.startswith(vs.name)]

        return y_hat, p_vars

    y_pred, p_vars = predictor(X, T)
    # MAE loss for the predictor
    p_loss = tf1.losses.absolute_difference(Y, y_pred)
    # optimizer
    p_solver = tf1.train.AdamOptimizer().minimize(p_loss, var_list=p_vars)

    ## Training
    # Session start
    sess = tf1.Session()
    sess.run(tf1.global_variables_initializer())

    from tqdm.auto import tqdm

    # Training using the synthetic dataset only
    for itt in tqdm(range(iterations), desc='training', total=iterations):

        # Random mini-batch from the generated data
        idx = np.random.permutation(len(generated_data))
        train_idx = idx[:batch_size]

        X_mb = list(generated_data[i][:-1, :(dim - 1)] for i in train_idx)
        T_mb = list(generated_time[i] - 1 for i in train_idx)
        Y_mb = list(np.reshape(generated_data[i][1:, (dim - 1)],
                               [len(generated_data[i][1:, (dim - 1)]), 1]) for i in train_idx)

        # Train predictor
        _, step_p_loss = sess.run([p_solver, p_loss], feed_dict={X: X_mb, T: T_mb, Y: Y_mb})

    ## Test the trained model on the original data (all `no` samples,
    ## in a random order — order does not affect the mean MAE)
    idx = np.random.permutation(len(ori_data))
    train_idx = idx[:no]

    X_mb = list(ori_data[i][:-1, :(dim - 1)] for i in train_idx)
    T_mb = list(ori_time[i] - 1 for i in train_idx)
    Y_mb = list(np.reshape(ori_data[i][1:, (dim - 1)], [len(ori_data[i][1:, (dim - 1)]), 1]) for i in train_idx)

    # Prediction
    pred_Y_curr = sess.run(y_pred, feed_dict={X: X_mb, T: T_mb})

    # Compute the performance in terms of MAE
    MAE_temp = 0
    for i in range(no):
        MAE_temp = MAE_temp + mean_absolute_error(Y_mb[i], pred_Y_curr[i, :, :])

    predictive_score = MAE_temp / no

    return predictive_score
class Logger(object):
    """Experiment logger: mirrors messages to stdout and a text file, and
    optionally to tensorboard.

    Creates `<save_dir>/configs` (args + YAML config dumps) and
    `<save_dir>/logs` (text log and, if enabled, tensorboard events).
    """

    def __init__(self, args):
        self.args = args
        self.save_dir = args.save_dir

        os.makedirs(self.save_dir, exist_ok=True)

        # save the args and config for reproducibility
        self.config_dir = os.path.join(self.save_dir, 'configs')
        os.makedirs(self.config_dir, exist_ok=True)
        file_name = os.path.join(self.config_dir, 'args.txt')
        write_args(args, file_name)

        log_dir = os.path.join(self.save_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        # append mode so resumed runs keep the earlier history
        self.text_writer = open(os.path.join(log_dir, 'log.txt'), 'a')
        if args.tensorboard:
            self.log_info('using tensorboard')
            # BUG FIX: `import torch` does not reliably expose
            # `torch.utils.tensorboard`; import the submodule explicitly.
            from torch.utils.tensorboard import SummaryWriter
            self.tb_writer = SummaryWriter(log_dir=log_dir)
        else:
            self.tb_writer = None

    def save_config(self, config):
        """Dump the merged YAML config next to the saved args."""
        save_config_to_yaml(config, os.path.join(self.config_dir, 'config.yaml'))

    def log_info(self, info, check_primary=True):
        """Print `info` and append a timestamped copy to the text log."""
        print(info)
        info = str(info)
        time_str = time.strftime('%Y-%m-%d-%H-%M')
        info = '{}: {}'.format(time_str, info)
        if not info.endswith('\n'):
            info += '\n'
        self.text_writer.write(info)
        self.text_writer.flush()

    def add_scalar(self, **kargs):
        """Log a scalar to tensorboard (no-op when disabled)."""
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(**kargs)

    def add_scalars(self, **kargs):
        """Log several scalars to tensorboard (no-op when disabled)."""
        if self.tb_writer is not None:
            self.tb_writer.add_scalars(**kargs)

    def add_image(self, **kargs):
        """Log an image to tensorboard (no-op when disabled)."""
        if self.tb_writer is not None:
            self.tb_writer.add_image(**kargs)

    def add_images(self, **kargs):
        """Log a batch of images to tensorboard (no-op when disabled)."""
        if self.tb_writer is not None:
            self.tb_writer.add_images(**kargs)

    def close(self):
        """Close all writers.

        BUG FIX: `tb_writer` is None when tensorboard is disabled; the old
        unconditional `tb_writer.close()` raised AttributeError.
        """
        self.text_writer.close()
        if self.tb_writer is not None:
            self.tb_writer.close()
class ReduceLROnPlateauWithWarmup(object):
    """Reduce learning rate when a metric has stopped improving, after an
    optional linear warmup phase.

    During the first `warmup` steps the lr of each param group is increased
    linearly towards `warmup_lr`; afterwards the scheduler behaves like
    torch's ReduceLROnPlateau.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr is reduced when
            the monitored quantity has stopped decreasing; in `max` mode
            when it has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be reduced.
            new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after which
            the learning rate will be reduced. Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming normal
            operation after lr has been reduced. Default: 0.
        min_lr (float or list): Lower bound on the lr of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr; smaller updates are
            ignored. Default: 1e-8.
        verbose (bool): If True, prints a message on each update.
        warmup_lr (float, list or None): the lr to be reached after warmup;
            None disables warmup.
        warmup (int): the number of steps to warm up.
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose=False, warmup_lr=None,
                 warmup=0):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

        self.warmup_lr = warmup_lr
        self.warmup = warmup

        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _prepare_for_warmup(self):
        """Precompute the per-group lr increments applied during warmup."""
        if self.warmup_lr is not None:
            if isinstance(self.warmup_lr, (list, tuple)):
                if len(self.warmup_lr) != len(self.optimizer.param_groups):
                    raise ValueError("expected {} warmup_lrs, got {}".format(
                        len(self.optimizer.param_groups), len(self.warmup_lr)))
                self.warmup_lrs = list(self.warmup_lr)
            else:
                self.warmup_lrs = [self.warmup_lr] * len(self.optimizer.param_groups)
        else:
            self.warmup_lrs = None
        # BUG FIX: previously `warmup > last_epoch` with warmup_lr=None
        # indexed into the None `warmup_lrs` and raised TypeError.
        if self.warmup_lrs is not None and self.warmup > self.last_epoch:
            curr_lrs = [group['lr'] for group in self.optimizer.param_groups]
            self.warmup_lr_steps = [
                max(0, (self.warmup_lrs[i] - curr_lrs[i]) / float(self.warmup))
                for i in range(len(curr_lrs))
            ]
        else:
            self.warmup_lr_steps = None

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics):
        """Advance one epoch: warm up, or compare `metrics` to the best."""
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if epoch <= self.warmup:
            self._increase_lr(epoch)
        else:
            if self.is_better(current, self.best):
                self.best = current
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.in_cooldown:
                self.cooldown_counter -= 1
                self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

            if self.num_bad_epochs > self.patience:
                self._reduce_lr(epoch)
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        # Multiply each group's lr by `factor`, bounded below by min_lr;
        # updates smaller than eps are skipped.
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    def _increase_lr(self, epoch):
        # Linear warmup; no-op when no warmup target lr was configured.
        if self.warmup_lr_steps is None:
            return
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr + self.warmup_lr_steps[i], self.min_lrs[i])
            param_group['lr'] = new_lr
            if self.verbose:
                print('Epoch {:5d}: increasing learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """True when `a` improves on `best` under the configured mode."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

        self._prepare_for_warmup()

    def state_dict(self):
        # the optimizer is intentionally excluded (it has its own state_dict)
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
class CosineAnnealingLRWithWarmup(object):
    """Cosine-annealing learning-rate schedule with an optional linear warmup.

    After `warmup` steps the lr decays from its per-group peak (`max_lrs`)
    towards `min_lr` following half a cosine period over `T_max` steps.

    Args:
        optimizer: wrapped optimizer.
        T_max (int): total number of steps of the schedule.
        last_epoch (int): index of the last step. Default: -1.
        verbose (bool): print a message on every lr update. Default: False.
        min_lr (float or list): lower bound(s) on the lr. Default: 0.
        warmup_lr (float, list or None): the lr to be reached after warmup;
            None disables warmup.
        warmup (int): the number of steps to warm up.
    """

    def __init__(self, optimizer, T_max, last_epoch=-1, verbose=False,
                 min_lr=0, warmup_lr=None, warmup=0):
        self.optimizer = optimizer
        self.T_max = T_max
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.warmup_lr = warmup_lr
        self.warmup = warmup

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)
        # BUG FIX: the decay peak used to be initialized from `min_lrs`,
        # which pinned the lr to `min_lr` forever when no warmup was used;
        # start from the optimizer's current lrs instead (warmup still
        # raises the peak via `_increase_lr`).
        self.max_lrs = [group['lr'] for group in optimizer.param_groups]

        self._prepare_for_warmup()

    def step(self):
        """Advance one step: warm up or apply the cosine decay."""
        epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if epoch <= self.warmup:
            self._increase_lr(epoch)
        else:
            self._reduce_lr(epoch)

    def _reduce_lr(self, epoch):
        # Cosine decay from max_lrs to min_lrs over the post-warmup span.
        for i, param_group in enumerate(self.optimizer.param_groups):
            progress = float(epoch - self.warmup) / float(max(1, self.T_max - self.warmup))
            factor = max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
            old_lr = float(param_group['lr'])
            new_lr = max(self.max_lrs[i] * factor, self.min_lrs[i])
            param_group['lr'] = new_lr
            if self.verbose:
                print('Epoch {:5d}: reducing learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    def _increase_lr(self, epoch):
        # Linear warmup; no-op when no warmup target lr was configured
        # (BUG FIX: previously indexed into a None warmup_lr_steps).
        if self.warmup_lr_steps is None:
            return
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = old_lr + self.warmup_lr_steps[i]
            param_group['lr'] = new_lr
            # track the warmup peak so the cosine decay starts from it
            self.max_lrs[i] = max(self.max_lrs[i], new_lr)
            if self.verbose:
                print('Epoch {:5d}: increasing learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    def _prepare_for_warmup(self):
        """Precompute the per-group lr increments applied during warmup."""
        if self.warmup_lr is not None:
            if isinstance(self.warmup_lr, (list, tuple)):
                if len(self.warmup_lr) != len(self.optimizer.param_groups):
                    raise ValueError("expected {} warmup_lrs, got {}".format(
                        len(self.optimizer.param_groups), len(self.warmup_lr)))
                self.warmup_lrs = list(self.warmup_lr)
            else:
                self.warmup_lrs = [self.warmup_lr] * len(self.optimizer.param_groups)
        else:
            self.warmup_lrs = None
        # BUG FIX: with the default args (warmup=0, last_epoch=-1,
        # warmup_lr=None) the old code indexed into None and divided by
        # float(0); guard on both conditions.
        if self.warmup_lrs is not None and self.warmup > max(0, self.last_epoch):
            curr_lrs = [group['lr'] for group in self.optimizer.param_groups]
            self.warmup_lr_steps = [
                max(0, (self.warmup_lrs[i] - curr_lrs[i]) / float(self.warmup))
                for i in range(len(curr_lrs))
            ]
        else:
            self.warmup_lr_steps = None

    def state_dict(self):
        # the optimizer is intentionally excluded (it has its own state_dict)
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._prepare_for_warmup()
self.step = 0 35 | self.milestone = 0 36 | self.args, self.config = args, config 37 | self.logger = logger 38 | 39 | self.results_folder = Path(config['solver']['results_folder'] + f'_{model.seq_length}') 40 | os.makedirs(self.results_folder, exist_ok=True) 41 | 42 | start_lr = config['solver'].get('base_lr', 1.0e-4) 43 | ema_decay = config['solver']['ema']['decay'] 44 | ema_update_every = config['solver']['ema']['update_interval'] 45 | 46 | self.opt = Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=start_lr, betas=[0.9, 0.96]) 47 | self.ema = EMA(self.model, beta=ema_decay, update_every=ema_update_every).to(self.device) 48 | 49 | sc_cfg = config['solver']['scheduler'] 50 | sc_cfg['params']['optimizer'] = self.opt 51 | self.sch = instantiate_from_config(sc_cfg) 52 | 53 | if self.logger is not None: 54 | self.logger.log_info(str(get_model_parameters_info(self.model))) 55 | self.log_frequency = 100 56 | 57 | def save(self, milestone, verbose=False): 58 | if self.logger is not None and verbose: 59 | self.logger.log_info('Save current model to {}'.format(str(self.results_folder / f'checkpoint-{milestone}.pt'))) 60 | data = { 61 | 'step': self.step, 62 | 'model': self.model.state_dict(), 63 | 'ema': self.ema.state_dict(), 64 | 'opt': self.opt.state_dict(), 65 | } 66 | torch.save(data, str(self.results_folder / f'checkpoint-{milestone}.pt')) 67 | 68 | def save_classifier(self, milestone, verbose=False): 69 | if self.logger is not None and verbose: 70 | self.logger.log_info('Save current classifer to {}'.format(str(self.results_folder / f'ckpt_classfier-{milestone}.pt'))) 71 | data = { 72 | 'step': self.step_classifier, 73 | 'classifier': self.classifier.state_dict() 74 | } 75 | torch.save(data, str(self.results_folder / f'ckpt_classfier-{milestone}.pt')) 76 | 77 | def load(self, milestone, verbose=False): 78 | if self.logger is not None and verbose: 79 | self.logger.log_info('Resume from {}'.format(str(self.results_folder / 
    def train(self):
        """Main diffusion training loop.

        Runs for `train_num_steps` optimizer steps with gradient
        accumulation, gradient clipping at norm 1.0, EMA weight tracking,
        a checkpoint every `save_cycle` steps, and scalar logging every
        `log_frequency` steps.
        """
        device = self.device
        step = 0
        if self.logger is not None:
            tic = time.time()
            self.logger.log_info('{}: start training...'.format(self.args.name), check_primary=False)

        with tqdm(initial=step, total=self.train_num_steps) as pbar:
            while step < self.train_num_steps:
                total_loss = 0.
                # accumulate gradients over several mini-batches before stepping
                for _ in range(self.gradient_accumulate_every):
                    data = next(self.dl).to(device)
                    # NOTE(review): the model is trained as an autoencoder-style
                    # diffusion objective — target equals the input batch.
                    loss = self.model(data, target=data)
                    loss = loss / self.gradient_accumulate_every
                    loss.backward()
                    total_loss += loss.item()

                pbar.set_description(f'loss: {total_loss:.6f}')

                clip_grad_norm_(self.model.parameters(), 1.0)
                self.opt.step()
                # scheduler is stepped on the raw training loss each iteration
                # (plateau-style API; see engine/lr_sch.py)
                self.sch.step(total_loss)
                self.opt.zero_grad()
                self.step += 1
                step += 1
                self.ema.update()

                with torch.no_grad():
                    # periodic checkpointing
                    if self.step != 0 and self.step % self.save_cycle == 0:
                        self.milestone += 1
                        self.save(self.milestone)
                        # self.logger.log_info('saved in {}'.format(str(self.results_folder / f'checkpoint-{self.milestone}.pt')))

                    if self.logger is not None and self.step % self.log_frequency == 0:
                        self.logger.add_scalar(tag='train/loss', scalar_value=total_loss, global_step=self.step)

                pbar.update(1)

        print('training complete')
        if self.logger is not None:
            self.logger.log_info('Training done, time: {:.2f}'.format(time.time() - tic))
155 | samples = np.row_stack([samples, sample.detach().cpu().numpy()]) 156 | torch.cuda.empty_cache() 157 | 158 | if self.logger is not None: 159 | self.logger.log_info('Sampling done, time: {:.2f}'.format(time.time() - tic)) 160 | return samples 161 | 162 | def restore(self, raw_dataloader, shape=None, coef=1e-1, stepsize=1e-1, sampling_steps=50): 163 | if self.logger is not None: 164 | tic = time.time() 165 | self.logger.log_info('Begin to restore...') 166 | model_kwargs = {} 167 | model_kwargs['coef'] = coef 168 | model_kwargs['learning_rate'] = stepsize 169 | samples = np.empty([0, shape[0], shape[1]]) 170 | reals = np.empty([0, shape[0], shape[1]]) 171 | masks = np.empty([0, shape[0], shape[1]]) 172 | 173 | for idx, (x, t_m) in enumerate(raw_dataloader): 174 | x, t_m = x.to(self.device), t_m.to(self.device) 175 | if sampling_steps == self.model.num_timesteps: 176 | sample = self.ema.ema_model.sample_infill(shape=x.shape, target=x*t_m, partial_mask=t_m, 177 | model_kwargs=model_kwargs) 178 | else: 179 | sample = self.ema.ema_model.fast_sample_infill(shape=x.shape, target=x*t_m, partial_mask=t_m, model_kwargs=model_kwargs, 180 | sampling_timesteps=sampling_steps) 181 | 182 | samples = np.row_stack([samples, sample.detach().cpu().numpy()]) 183 | reals = np.row_stack([reals, x.detach().cpu().numpy()]) 184 | masks = np.row_stack([masks, t_m.detach().cpu().numpy()]) 185 | 186 | if self.logger is not None: 187 | self.logger.log_info('Imputation done, time: {:.2f}'.format(time.time() - tic)) 188 | return samples, reals, masks 189 | # return samples 190 | 191 | def forward_sample(self, x_start): 192 | b, c, h = x_start.shape 193 | noise = torch.randn_like(x_start, device=self.device) 194 | t = torch.randint(0, self.model.num_timesteps, (b,), device=self.device).long() 195 | x_t = self.model.q_sample(x_start=x_start, t=t, noise=noise).detach() 196 | return x_t, t 197 | 198 | def train_classfier(self, classifier): 199 | device = self.device 200 | step = 0 201 | 
self.milestone_classifier = 0 202 | self.step_classifier = 0 203 | dataloader = self.dataloader 204 | dataloader.dataset.shift_period('test') 205 | dataloader = cycle(dataloader) 206 | 207 | self.classifier = classifier 208 | self.opt_classifier = Adam(filter(lambda p: p.requires_grad, self.classifier.parameters()), lr=5.0e-4) 209 | 210 | if self.logger is not None: 211 | tic = time.time() 212 | self.logger.log_info('{}: start training classifier...'.format(self.args.name), check_primary=False) 213 | 214 | with tqdm(initial=step, total=self.train_num_steps) as pbar: 215 | while step < self.train_num_steps: 216 | total_loss = 0. 217 | for _ in range(self.gradient_accumulate_every): 218 | x, y = next(dataloader) 219 | x, y = x.to(device), y.to(device) 220 | x_t, t = self.forward_sample(x) 221 | logits = classifier(x_t, t) 222 | loss = F.cross_entropy(logits, y) 223 | loss = loss / self.gradient_accumulate_every 224 | loss.backward() 225 | total_loss += loss.item() 226 | 227 | pbar.set_description(f'loss: {total_loss:.6f}') 228 | 229 | self.opt_classifier.step() 230 | self.opt_classifier.zero_grad() 231 | self.step_classifier += 1 232 | step += 1 233 | 234 | with torch.no_grad(): 235 | if self.step_classifier != 0 and self.step_classifier % self.save_cycle == 0: 236 | self.milestone_classifier += 1 237 | self.save(self.milestone_classifier) 238 | 239 | if self.logger is not None and self.step_classifier % self.log_frequency == 0: 240 | self.logger.add_scalar(tag='train/loss', scalar_value=total_loss, global_step=self.step) 241 | 242 | pbar.update(1) 243 | 244 | print('training complete') 245 | if self.logger is not None: 246 | self.logger.log_info('Training done, time: {:.2f}'.format(time.time() - tic)) 247 | 248 | # return classifier 249 | 250 | -------------------------------------------------------------------------------- /figures/fig1.jpg: -------------------------------------------------------------------------------- 
def parse_args():
    """Build and parse the command-line arguments for training/sampling.

    Returns:
        argparse.Namespace with an extra `save_dir` attribute set to
        `<output>/<name>`.
    """
    parser = argparse.ArgumentParser(description='PyTorch Training Script')
    parser.add_argument('--name', type=str, default=None)
    parser.add_argument('--config_file', type=str, default=None,
                        help='path of config file')
    parser.add_argument('--output', type=str, default='OUTPUT',
                        help='directory to save the results')
    parser.add_argument('--tensorboard', action='store_true',
                        help='use tensorboard for logging')

    # reproducibility / device options
    parser.add_argument('--cudnn_deterministic', action='store_true', default=False,
                        help='set cudnn.deterministic True')
    parser.add_argument('--seed', type=int, default=12345,
                        help='seed for initializing training.')
    parser.add_argument('--gpu', type=int, default=None,
                        help='GPU id to use. If given, only the specific gpu will be'
                             ' used, and ddp will be disabled')

    # training / sampling options
    parser.add_argument('--train', action='store_true', default=False, help='Train or Test.')
    parser.add_argument('--sample', type=int, default=0,
                        choices=[0, 1], help='Condition or Uncondition.')
    parser.add_argument('--mode', type=str, default='infill',
                        help='Infilling or Forecasting.')
    parser.add_argument('--milestone', type=int, default=10)
    parser.add_argument('--missing_ratio', type=float, default=0., help='Ratio of Missing Values.')
    parser.add_argument('--pred_len', type=int, default=0, help='Length of Predictions.')

    # any remaining tokens are treated as config overrides
    parser.add_argument('opts', help='Modify config options using the command-line',
                        default=None, nargs=argparse.REMAINDER)

    parsed = parser.parse_args()
    # every run writes into its own named output directory
    parsed.save_dir = os.path.join(parsed.output, f'{parsed.name}')
    return parsed
build_dataloader(config, args) 73 | trainer = Trainer(config=config, args=args, model=model, dataloader=dataloader_info, logger=logger) 74 | 75 | if args.train: 76 | trainer.train() 77 | elif args.sample == 1 and args.mode in ['infill', 'predict']: 78 | trainer.load(args.milestone) 79 | dataloader, dataset = test_dataloader_info['dataloader'], test_dataloader_info['dataset'] 80 | coef = config['dataloader']['test_dataset']['coefficient'] 81 | stepsize = config['dataloader']['test_dataset']['step_size'] 82 | sampling_steps = config['dataloader']['test_dataset']['sampling_steps'] 83 | samples, *_ = trainer.restore(dataloader, [dataset.window, dataset.var_num], coef, stepsize, sampling_steps) 84 | if dataset.auto_norm: 85 | samples = unnormalize_to_zero_to_one(samples) 86 | # samples = dataset.scaler.inverse_transform(samples.reshape(-1, samples.shape[-1])).reshape(samples.shape) 87 | np.save(os.path.join(args.save_dir, f'ddpm_{args.mode}_{args.name}.npy'), samples) 88 | else: 89 | trainer.load(args.milestone) 90 | dataset = dataloader_info['dataset'] 91 | samples = trainer.sample(num=len(dataset), size_every=2001, shape=[dataset.window, dataset.var_num]) 92 | if dataset.auto_norm: 93 | samples = unnormalize_to_zero_to_one(samples) 94 | # samples = dataset.scaler.inverse_transform(samples.reshape(-1, samples.shape[-1])).reshape(samples.shape) 95 | np.save(os.path.join(args.save_dir, f'ddpm_fake_{args.name}.npy'), samples) 96 | 97 | if __name__ == '__main__': 98 | main() 99 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | einops==0.6.0 3 | ema-pytorch==0.2.1 4 | matplotlib==3.6.0 5 | pandas==1.5.0 6 | scikit-learn==1.1.2 7 | scipy==1.8.1 8 | seaborn==0.12.2 9 | tqdm==4.64.1 10 | dm-control==1.0.12 11 | dm-env==1.6 12 | dm-tree==0.1.8 13 | mujoco==2.3.4 14 | gluonts==0.12.6 15 | pyyaml==6.0 16 | 
--------------------------------------------------------------------------------