├── README.md
├── example.ipynb
├── setup.py
└── tsdiff
    ├── __init__.py
    ├── csdi
    │   ├── config
    │   │   └── base.yaml
    │   ├── dataset_physio.py
    │   ├── dataset_pm25.py
    │   ├── diff_models.py
    │   ├── download.py
    │   ├── exe_physio.py
    │   ├── exe_pm25.py
    │   ├── experiment.py
    │   ├── experiment.yaml
    │   ├── main_model.py
    │   └── utils.py
    ├── data
    │   ├── __init__.py
    │   └── generate.py
    ├── diffusion
    │   ├── __init__.py
    │   ├── beta_scheduler.py
    │   ├── continuous_diffusion.py
    │   ├── discrete_diffusion.py
    │   └── noise.py
    ├── forecasting
    │   ├── __init__.py
    │   ├── experiment.py
    │   ├── experiment.yaml
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── score_estimator.py
    │   │   ├── score_network.py
    │   │   └── time_grad_network.py
    │   ├── train.py
    │   └── train_deepvar.py
    ├── neural_process
    │   ├── __init__.py
    │   ├── experiment.py
    │   ├── experiment.yaml
    │   └── train.py
    ├── synthetic
    │   ├── __init__.py
    │   ├── data.py
    │   ├── diffusion_model.py
    │   ├── discriminator_experiment.py
    │   ├── discriminator_experiment.yaml
    │   ├── experiment.py
    │   ├── experiment.yaml
    │   ├── nf_model.py
    │   ├── ode_model.py
    │   ├── sde_model.py
    │   └── train.py
    ├── test
    │   ├── __init__.py
    │   ├── test_beta_scheduler.py
    │   ├── test_ddpm.py
    │   └── test_sde_diffusion.py
    └── utils
        ├── __init__.py
        ├── dotdict.py
        ├── epsilon_theta.py
        ├── exception.py
        ├── feedforward.py
        ├── positional_encoding.py
        └── trainer.py

/README.md:
--------------------------------------------------------------------------------
# Modeling Temporal Data as Continuous Functions with Stochastic Process Diffusion

Marin Biloš, Kashif Rasul, Anderson Schneider, Yuriy Nevmyvaka, Stephan Günnemann

International Conference on Machine Learning (ICML), 2023

Abstract: Temporal data such as time series can be viewed as discretized measurements of the underlying function. To build a generative model for such data we have to model the stochastic process that governs it. We propose a solution by defining the denoising diffusion model in the function space which also allows us to naturally handle irregularly-sampled observations.
The forward process gradually adds noise to functions, preserving their continuity, while the learned reverse process removes the noise and returns functions as new samples. To this end, we define suitable noise sources and introduce novel denoising and score-matching models. We show how our method can be used for multivariate probabilistic forecasting and imputation, and how our model can be interpreted as a neural process.

## Simple example

You can find a runnable example of the DSPD-GP model in a [self-contained notebook](example.ipynb). For full experiments follow the instructions below.

## Installation

Run the following to install the local package `tsdiff` and other required packages:
```sh
pip install -e .
```

Run tests:
```sh
pytest
```

## Probabilistic modeling

Generate synthetic data:
```sh
python -m tsdiff.data.generate
```

Run example:
```sh
python -m tsdiff.synthetic.train --seed 1 --dataset lorenz --diffusion GaussianDiffusion --model rnn
```
Other options (see [file](tsdiff/synthetic/train.py)):
```
python -m tsdiff.synthetic.train
    --dataset [cir|lorenz|ou|predator_prey|sine|sink]
    --diffusion [GaussianDiffusion|OUDiffusion|GPDiffusion|ContinuousGaussianDiffusion|ContinuousOUDiffusion|ContinuousGPDiffusion]
    --model [feedforward|rnn|transformer]
    --gp_sigma [float]
    --ou_theta [float]
```

See the files [experiment.py](tsdiff/synthetic/experiment.py) and [experiment.yaml](tsdiff/synthetic/experiment.yaml) for replicating the experiments, as well as [discriminator_experiment.py](tsdiff/synthetic/discriminator_experiment.py) and [discriminator_experiment.yaml](tsdiff/synthetic/discriminator_experiment.yaml) for the discriminator experiment.
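
The option lists for `tsdiff.synthetic.train` can be expanded into concrete command lines. The following is a minimal sketch that assumes only the flag names and values listed in the README; the helper `build_commands` is hypothetical and not part of `tsdiff`, and it only builds command strings without running anything.

```python
from itertools import product

# Values copied from the option list in the README.
DATASETS = ["cir", "lorenz", "ou", "predator_prey", "sine", "sink"]
DIFFUSIONS = [
    "GaussianDiffusion", "OUDiffusion", "GPDiffusion",
    "ContinuousGaussianDiffusion", "ContinuousOUDiffusion", "ContinuousGPDiffusion",
]
MODELS = ["feedforward", "rnn", "transformer"]


def build_commands(seed=1):
    """Hypothetical helper: one command string per (dataset, diffusion, model)."""
    commands = []
    for dataset, diffusion, model in product(DATASETS, DIFFUSIONS, MODELS):
        commands.append(
            f"python -m tsdiff.synthetic.train --seed {seed} "
            f"--dataset {dataset} --diffusion {diffusion} --model {model}"
        )
    return commands


if __name__ == "__main__":
    # Print the first few commands as a quick check.
    for command in build_commands()[:3]:
        print(command)
```

Whether every combination is meaningful (for example, which diffusions use `--gp_sigma` or `--ou_theta`) is not stated here, so treat the full grid as a starting point rather than the experiment setup; the repository's own `experiment.yaml` files define the actual configurations.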

## Forecasting

Example:
```sh
python -m tsdiff.forecasting.train --seed 1 --dataset electricity_nips --network timegrad_rnn --noise ou --epochs 100
```
Other options can be found in [train.py](tsdiff/forecasting/train.py). See [experiment.py](tsdiff/forecasting/experiment.py) and [experiment.yaml](tsdiff/forecasting/experiment.yaml) for reproducing the experiments.

## Neural process

Example:
```sh
python -m tsdiff.neural_process.train
```
Use [experiment.py](tsdiff/neural_process/experiment.py) and [experiment.yaml](tsdiff/neural_process/experiment.yaml) to reproduce the experiments.


## Imputation experiment

The setup and the code are very similar to the official implementation of CSDI [[arxiv](https://arxiv.org/abs/2107.03502), [github](https://github.com/ermongroup/CSDI)].

In our version, we keep the actual times of the observations, instead of rounding them up to the nearest hour.

Download data:
```sh
python -m tsdiff.csdi.download physio
```

Running the code:
```sh
# original paper
python -m tsdiff.csdi.exe_physio

# our paper
python -m tsdiff.csdi.exe_physio --gp_noise
```
Other parameters include `--unconditional`, `--nsample` (e.g., 100), `--testmissingratio` (e.g., 0.5).

Replicating the experiments: `seml tsdiff_csdi add tsdiff/csdi/experiment.yaml start`. File [experiment.yaml](tsdiff/csdi/experiment.yaml) contains the hyperparameter configuration. File [experiment.py](tsdiff/csdi/experiment.py) runs the experiments.
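
To illustrate the difference between keeping the actual observation times and bucketing them by hour, here is a minimal sketch. It assumes the `HH:MM` time format of the PhysioNet records and the 48-hour window used in `tsdiff/csdi/dataset_physio.py`; the function names are illustrative and not part of the package.

```python
WINDOW_HOURS = 48  # length of the observation window used for the PhysioNet data


def hour_bucket(timestamp):
    """Hourly bucketing: keep only the hour of an 'HH:MM' timestamp."""
    hours, _ = map(int, timestamp.split(":"))
    return hours


def relative_time(timestamp):
    """Actual observation time as a fraction of the 48-hour window."""
    hours, minutes = map(int, timestamp.split(":"))
    return (hours * 60 + minutes) / 60 / WINDOW_HOURS


if __name__ == "__main__":
    # Two measurements in the same hour share a bucket but keep distinct times.
    for t in ["00:07", "00:37", "12:30"]:
        print(t, hour_bucket(t), relative_time(t))
```

Two measurements taken within the same hour fall into the same bucket, while their relative times stay distinct, which is the information the irregularly-sampled setting relies on.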


## Citation

```
@inproceedings{bilos2022diffusion,
  title={Modeling Temporal Data as Continuous Functions with Stochastic Process Diffusion},
  author={Bilo{\v{s}}, Marin and Rasul, Kashif and Schneider, Anderson and Nevmyvaka, Yuriy and G{\"u}nnemann, Stephan},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2023},
}
```

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

install_requires = [
    'numpy>=1.21.5',
    'pytest>=6.2.4',
    'scipy>=1.7.1',
    'torch>=1.12.1',
    'pytorch-lightning==1.6.0',
    'torchdiffeq==0.2.3',
    'torchsde==0.2.5',
    'matplotlib>=3.4.3',
    'seaborn==0.11.1',
    'pytorchts==0.6.0',
    'gluonts==0.9.*',
    'wget==3.2',
]

with open('README.md', 'r') as f:
    long_description = f.read()

setup(name='tsdiff',
      version='0.1.0',
      description='Time series diffusion',
      long_description=long_description,
      long_description_content_type='text/markdown',
      url='',
      author='Marin Bilos',
      author_email='marin.bilos@morganstanley.com',  # also: marin.bilos@tum.de
      packages=find_packages(),
      install_requires=install_requires,
      python_requires='>=3.7',
      zip_safe=False,
)

--------------------------------------------------------------------------------
/tsdiff/__init__.py:
--------------------------------------------------------------------------------
from . import diffusion
from . import utils

--------------------------------------------------------------------------------
/tsdiff/csdi/config/base.yaml:
--------------------------------------------------------------------------------

#type: args

train:
  epochs: 200
  batch_size: 16
  lr: 1.0e-3

diffusion:
  layers: 4
  channels: 64
  nheads: 8
  diffusion_embedding_dim: 128
  beta_start: 0.0001
  beta_end: 0.5
  num_steps: 50
  schedule: "quad"

model:
  is_unconditional: 0
  timeemb: 128
  featureemb: 16
  target_strategy: "random"

--------------------------------------------------------------------------------
/tsdiff/csdi/dataset_physio.py:
--------------------------------------------------------------------------------
import pickle

import os
import re
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset

from pathlib import Path
# download.py writes to ./data relative to the working directory,
# so the data directory is resolved the same way here
DATA_DIR = Path('data')


# 35 attributes which contain enough non-missing values
attributes = ['DiasABP', 'HR', 'Na', 'Lactate', 'NIDiasABP', 'PaO2', 'WBC', 'pH', 'Albumin', 'ALT', 'Glucose', 'SaO2',
              'Temp', 'AST', 'Bilirubin', 'HCO3', 'BUN', 'RespRate', 'Mg', 'HCT', 'SysABP', 'FiO2', 'K', 'GCS',
              'Cholesterol', 'NISysABP', 'TroponinT', 'MAP', 'TroponinI', 'PaCO2', 'Platelets', 'Urine', 'NIMAP',
              'Creatinine', 'ALP']


def extract_hour(x):
    h, _ = map(int, x.split(":"))
    return h


def parse_data(x):
    # extract the last value for each attribute
    x = x.set_index("Parameter").to_dict()["Value"]

    values = []

    for attr in attributes:
        if attr in x:
            values.append(x[attr])
        else:
            values.append(np.nan)
    return values


def parse_id(id_, missing_ratio=0.1):
    data = pd.read_csv(DATA_DIR / "physio/set-a/{}.txt".format(id_))
    # keep the actual observation time, rescaled by the 48-hour window
    extract_relative_minutes = lambda x: (int(x[0]) * 60 + int(x[1])) / 60 / 48
    data['relativeTime'] = data['Time'].str.split(':').apply(extract_relative_minutes)

    # set hour
    data["Time"] = data["Time"].apply(lambda x: extract_hour(x))


    # create data for 48 hours x 35 attributes
    observed_values = []
    observed_times = []
    for h in range(48):
        observed_values.append(parse_data(data[data["Time"] == h]))
        times = data[data['Time'] == h]['relativeTime']
        observed_times.append(times.iloc[-1] if len(times) > 0 else np.nan)
    observed_values = np.array(observed_values)
    observed_times = np.array(observed_times)
    observed_masks = ~np.isnan(observed_values)

    # randomly set some percentage as ground-truth
    masks = observed_masks.reshape(-1).copy()
    obs_indices = np.where(masks)[0].tolist()
    miss_indices = np.random.choice(
        obs_indices, int(len(obs_indices) * missing_ratio), replace=False
    )
    masks[miss_indices] = False
    gt_masks = masks.reshape(observed_masks.shape)

    observed_values = np.nan_to_num(observed_values)
    observed_times = np.nan_to_num(observed_times)
    observed_masks = observed_masks.astype("float32")
    gt_masks = gt_masks.astype("float32")

    return observed_values, observed_masks, gt_masks, observed_times


def get_idlist():
    patient_id = []
    for filename in os.listdir(DATA_DIR / "physio/set-a"):
        match = re.search(r"\d{6}", filename)
        if match:
            patient_id.append(match.group())
    patient_id = np.sort(patient_id)
    return patient_id


class Physio_Dataset(Dataset):
    def __init__(self, eval_length=48, use_index_list=None, missing_ratio=0.0, seed=0):
        self.eval_length = eval_length
        np.random.seed(seed)  # seed for ground truth choice

        self.observed_values = []
        self.observed_masks = []
        self.gt_masks = []
        self.observed_times = []
        path = (
            DATA_DIR / ("physio_missing" + str(missing_ratio) + "_seed" + str(seed) + ".pk")
        )

        if not os.path.isfile(path):  # if the dataset file does not exist, create it
            idlist = get_idlist()
            for id_ in idlist:
                try:
                    observed_values, observed_masks, gt_masks, observed_times = parse_id(
                        id_, missing_ratio
                    )
                    self.observed_values.append(observed_values)
                    self.observed_masks.append(observed_masks)
                    self.gt_masks.append(gt_masks)
                    self.observed_times.append(observed_times)
                except Exception as e:
                    print(id_, e)
                    continue
            self.observed_values = np.array(self.observed_values)
            self.observed_masks = np.array(self.observed_masks)
            self.gt_masks = np.array(self.gt_masks)
            self.observed_times = np.array(self.observed_times)

            # calc mean and std and normalize values
            # (it is the same normalization as Cao et al. (2018) (https://github.com/caow13/BRITS))
            tmp_values = self.observed_values.reshape(-1, 35)
            tmp_masks = self.observed_masks.reshape(-1, 35)
            mean = np.zeros(35)
            std = np.zeros(35)
            for k in range(35):
                c_data = tmp_values[:, k][tmp_masks[:, k] == 1]
                mean[k] = c_data.mean()
                std[k] = c_data.std()
            self.observed_values = (
                (self.observed_values - mean) / std * self.observed_masks
            )

            with open(path, "wb") as f:
                pickle.dump(
                    [self.observed_values, self.observed_masks, self.gt_masks, self.observed_times], f
                )
        else:  # load the dataset file
            with open(path, "rb") as f:
                self.observed_values, self.observed_masks, self.gt_masks, self.observed_times = pickle.load(
                    f
                )
        if use_index_list is None:
            self.use_index_list = np.arange(len(self.observed_values))
        else:
            self.use_index_list = use_index_list

    def __getitem__(self, org_index):
        index = self.use_index_list[org_index]
        s = {
            "observed_data": self.observed_values[index],
            "observed_mask": self.observed_masks[index],
            "gt_mask": self.gt_masks[index],
            "timepoints": np.arange(self.eval_length),
            "times": self.observed_times[index],
        }
        return s

    def __len__(self):
        return len(self.use_index_list)


def get_dataloader(seed=1, nfold=0, batch_size=16, missing_ratio=0.1):

    # only to obtain total length of dataset
    dataset = Physio_Dataset(missing_ratio=missing_ratio, seed=seed)
    indlist = np.arange(len(dataset))

    np.random.seed(seed)
    np.random.shuffle(indlist)

    # 5-fold test
    start = int(nfold * 0.2 * len(dataset))
    end = int((nfold + 1) * 0.2 * len(dataset))
    test_index = indlist[start:end]
    remain_index = np.delete(indlist, np.arange(start, end))

    np.random.seed(seed)
    np.random.shuffle(remain_index)
    num_train = int(len(dataset) * 0.7)
    train_index = remain_index[:num_train]
    valid_index = remain_index[num_train:]

    dataset = Physio_Dataset(
        use_index_list=train_index, missing_ratio=missing_ratio, seed=seed
    )
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    valid_dataset = Physio_Dataset(
        use_index_list=valid_index, missing_ratio=missing_ratio, seed=seed
    )
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_dataset = Physio_Dataset(
        use_index_list=test_index, missing_ratio=missing_ratio, seed=seed
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader

--------------------------------------------------------------------------------
/tsdiff/csdi/dataset_pm25.py:
--------------------------------------------------------------------------------
import pickle
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
import torch


class PM25_Dataset(Dataset):
    def __init__(self, eval_length=36, target_dim=36, mode="train", validindex=0):
        self.eval_length = eval_length
        self.target_dim = target_dim

        path = "./data/pm25/pm25_meanstd.pk"
        with open(path, "rb") as f:
            self.train_mean, self.train_std = pickle.load(f)
        if mode == "train":
            month_list = [1, 2, 4, 5, 7, 8, 10, 11]
            # 1st,4th,7th,10th months are excluded from histmask (since the months are used for creating missing patterns in test dataset)
            flag_for_histmask = [0, 1, 0, 1, 0, 1, 0, 1]
            month_list.pop(validindex)
            flag_for_histmask.pop(validindex)
        elif mode == "valid":
            month_list = [1, 2, 4, 5, 7, 8, 10, 11]
            month_list = month_list[validindex : validindex + 1]
        elif mode == "test":
            month_list = [3, 6, 9, 12]
        self.month_list = month_list

        # create data for batch
        self.observed_data = []  # values (separated into each month)
        self.observed_mask = []  # masks (separated into each month)
        self.gt_mask = []  # ground-truth masks (separated into each month)
        self.index_month = []  # indicate month
        self.position_in_month = []  # indicate the start position in month (length is the same as index_month)
        self.valid_for_histmask = []  # whether the sample is used for histmask
        self.use_index = []  # to separate train/valid/test
        self.cut_length = []  # excluded from evaluation targets

        df = pd.read_csv(
            "./data/pm25/Code/STMVL/SampleData/pm25_ground.txt",
            index_col="datetime",
            parse_dates=True,
        )
        df_gt = pd.read_csv(
            "./data/pm25/Code/STMVL/SampleData/pm25_missing.txt",
            index_col="datetime",
            parse_dates=True,
        )
        for i in range(len(month_list)):
            current_df = df[df.index.month == month_list[i]]
            current_df_gt = df_gt[df_gt.index.month == month_list[i]]
            current_length = len(current_df) - eval_length + 1

            last_index = len(self.index_month)
            self.index_month += np.array([i] * current_length).tolist()
            self.position_in_month += np.arange(current_length).tolist()
            if mode == "train":
                self.valid_for_histmask += np.array(
                    [flag_for_histmask[i]] * current_length
                ).tolist()

            # mask values for observed indices are 1
            c_mask = 1 - current_df.isnull().values
            c_gt_mask = 1 - current_df_gt.isnull().values
            c_data = (
                (current_df.fillna(0).values - self.train_mean) / self.train_std
            ) * c_mask

            self.observed_mask.append(c_mask)
            self.gt_mask.append(c_gt_mask)
            self.observed_data.append(c_data)

            if mode == "test":
                n_sample = len(current_df) // eval_length
                # interval size is eval_length (missing values are imputed only once)
                c_index = np.arange(
                    last_index, last_index + eval_length * n_sample, eval_length
                )
                self.use_index += c_index.tolist()
                self.cut_length += [0] * len(c_index)
                if len(current_df) % eval_length != 0:  # avoid double-count for the last time-series
                    self.use_index += [len(self.index_month) - 1]
                    self.cut_length += [eval_length - len(current_df) % eval_length]

        if mode != "test":
            self.use_index = np.arange(len(self.index_month))
            self.cut_length = [0] * len(self.use_index)

        # masks for 1st,4th,7th,10th months are used for creating missing patterns in test data,
        # so these months are excluded from histmask to avoid leakage
        if mode == "train":
            ind = -1
            self.index_month_histmask = []
            self.position_in_month_histmask = []

            for i in range(len(self.index_month)):
                while True:
                    ind += 1
                    if ind == len(self.index_month):
                        ind = 0
                    if self.valid_for_histmask[ind] == 1:
                        self.index_month_histmask.append(self.index_month[ind])
                        self.position_in_month_histmask.append(
                            self.position_in_month[ind]
                        )
                        break
        else:  # dummy (histmask is only used for training)
            self.index_month_histmask = self.index_month
            self.position_in_month_histmask = self.position_in_month

    def __getitem__(self, org_index):
        index = self.use_index[org_index]
        c_month = self.index_month[index]
        c_index = self.position_in_month[index]
        hist_month = self.index_month_histmask[index]
        hist_index = self.position_in_month_histmask[index]
        s = {
            "observed_data": self.observed_data[c_month][
                c_index : c_index + self.eval_length
            ],
            "observed_mask": self.observed_mask[c_month][
                c_index : c_index + self.eval_length
            ],
            "gt_mask": self.gt_mask[c_month][
                c_index : c_index + self.eval_length
            ],
            "hist_mask": self.observed_mask[hist_month][
                hist_index : hist_index + self.eval_length
            ],
            "timepoints": np.arange(self.eval_length),
            "cut_length": self.cut_length[org_index],
        }

        return s

    def __len__(self):
        return len(self.use_index)


def get_dataloader(batch_size, device, validindex=0):
    dataset = PM25_Dataset(mode="train", validindex=validindex)
    train_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=1, shuffle=True
    )
    dataset_test = PM25_Dataset(mode="test", validindex=validindex)
    test_loader = DataLoader(
        dataset_test, batch_size=batch_size, num_workers=1, shuffle=False
    )
    dataset_valid = PM25_Dataset(mode="valid", validindex=validindex)
    valid_loader = DataLoader(
        dataset_valid, batch_size=batch_size, num_workers=1, shuffle=False
    )

    scaler = torch.from_numpy(dataset.train_std).to(device).float()
    mean_scaler = torch.from_numpy(dataset.train_mean).to(device).float()

    return train_loader, valid_loader, test_loader, scaler, mean_scaler

--------------------------------------------------------------------------------
/tsdiff/csdi/diff_models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def get_torch_trans(heads=8, layers=1, channels=64):
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu"
    )
    return nn.TransformerEncoder(encoder_layer, num_layers=layers)


def Conv1d_with_init(in_channels, out_channels, kernel_size):
    layer = nn.Conv1d(in_channels, out_channels, kernel_size)
    nn.init.kaiming_normal_(layer.weight)
    return layer


class DiffusionEmbedding(nn.Module):
    def __init__(self, num_steps, embedding_dim=128, projection_dim=None):
        super().__init__()
        if projection_dim is None:
            projection_dim = embedding_dim
        self.register_buffer(
            "embedding",
            self._build_embedding(num_steps, embedding_dim // 2),
            persistent=False,
        )
        self.projection1 = nn.Linear(embedding_dim, projection_dim)
        self.projection2 = nn.Linear(projection_dim, projection_dim)

    def forward(self, diffusion_step):
        x = self.embedding[diffusion_step]
        x = self.projection1(x)
        x = F.silu(x)
        x = self.projection2(x)
        x = F.silu(x)
        return x

    def _build_embedding(self, num_steps, dim=64):
        steps = torch.arange(num_steps).unsqueeze(1)  # (T,1)
        frequencies = 10.0 ** (torch.arange(dim) / (dim - 1) * 4.0).unsqueeze(0)  # (1,dim)
        table = steps * frequencies  # (T,dim)
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T,dim*2)
        return table


class diff_CSDI(nn.Module):
    def __init__(self, config, inputdim=2):
        super().__init__()
        self.channels = config["channels"]

        self.diffusion_embedding = DiffusionEmbedding(
            num_steps=config["num_steps"],
            embedding_dim=config["diffusion_embedding_dim"],
        )

        self.input_projection = Conv1d_with_init(inputdim, self.channels, 1)
        self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1)
        self.output_projection2 = Conv1d_with_init(self.channels, 1, 1)
        nn.init.zeros_(self.output_projection2.weight)

        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    side_dim=config["side_dim"],
                    channels=self.channels,
                    diffusion_embedding_dim=config["diffusion_embedding_dim"],
                    nheads=config["nheads"],
                )
                for _ in range(config["layers"])
            ]
        )

    def forward(self, x, cond_info, diffusion_step):
        B, inputdim, K, L = x.shape

        x = x.reshape(B, inputdim, K * L)
        x = self.input_projection(x)
        x = F.relu(x)
        x = x.reshape(B, self.channels, K, L)

        diffusion_emb = self.diffusion_embedding(diffusion_step)

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, cond_info, diffusion_emb)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = x.reshape(B, self.channels, K * L)
        x = self.output_projection1(x)  # (B,channel,K*L)
        x = F.relu(x)
        x = self.output_projection2(x)  # (B,1,K*L)
        x = x.reshape(B, K, L)
        return x


class ResidualBlock(nn.Module):
    def __init__(self, side_dim, channels, diffusion_embedding_dim, nheads):
        super().__init__()
        self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels)
        self.cond_projection = Conv1d_with_init(side_dim, 2 * channels, 1)
        self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1)
        self.output_projection = Conv1d_with_init(channels, 2 * channels, 1)

        self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=channels)
        self.feature_layer = get_torch_trans(heads=nheads, layers=1, channels=channels)

    def forward_time(self, y, base_shape):
        B, channel, K, L = base_shape
        if L == 1:
            return y
        y = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L)
        y = self.time_layer(y.permute(2, 0, 1)).permute(1, 2, 0)
        y = y.reshape(B, K, channel, L).permute(0, 2, 1, 3).reshape(B, channel, K * L)
        return y

    def forward_feature(self, y, base_shape):
        B, channel, K, L = base_shape
        if K == 1:
            return y
        y = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K)
        y = self.feature_layer(y.permute(2, 0, 1)).permute(1, 2, 0)
        y = y.reshape(B, L, channel, K).permute(0, 2, 3, 1).reshape(B, channel, K * L)
        return y

    def forward(self, x, cond_info, diffusion_emb):
        B, channel, K, L = x.shape
        base_shape = x.shape
        x = x.reshape(B, channel, K * L)

        diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(-1)  # (B,channel,1)
        y = x + diffusion_emb

        y = self.forward_time(y, base_shape)
        y = self.forward_feature(y, base_shape)  # (B,channel,K*L)
        y = self.mid_projection(y)  # (B,2*channel,K*L)

        _, cond_dim, _, _ = cond_info.shape
        cond_info = cond_info.reshape(B, cond_dim, K * L)
        cond_info = self.cond_projection(cond_info)  # (B,2*channel,K*L)
        y = y + cond_info

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)  # (B,channel,K*L)
        y = self.output_projection(y)

        residual, skip = torch.chunk(y, 2, dim=1)
        x = x.reshape(base_shape)
        residual = residual.reshape(base_shape)
        skip = skip.reshape(base_shape)
        return (x + residual) / math.sqrt(2.0), skip

--------------------------------------------------------------------------------
/tsdiff/csdi/download.py:
--------------------------------------------------------------------------------
import tarfile
import zipfile
import sys
import os
import wget
import requests
import pandas as pd
import pickle

os.makedirs("data/", exist_ok=True)
if sys.argv[1] == "physio":
    url = "https://physionet.org/files/challenge-2012/1.0.0/set-a.tar.gz?download"
    wget.download(url, out="data")
    with tarfile.open("data/set-a.tar.gz", "r:gz") as t:
        t.extractall(path="data/physio")

elif sys.argv[1] == "pm25":
    url = "https://www.microsoft.com/en-us/research/wp-content/uploads/2016/06/STMVL-Release.zip"
    urlData = requests.get(url).content
    filename = "data/STMVL-Release.zip"
    with open(filename, mode="wb") as f:
        f.write(urlData)
    with zipfile.ZipFile(filename) as z:
        z.extractall("data/pm25")

    def create_normalizer_pm25():
        df = pd.read_csv(
            "./data/pm25/Code/STMVL/SampleData/pm25_ground.txt",
            index_col="datetime",
            parse_dates=True,
        )
        # the normalizer is computed on non-test months only
        test_month = [3, 6, 9, 12]
        for i in test_month:
            df = df[df.index.month != i]
        mean = df.describe().loc["mean"].values
        std = df.describe().loc["std"].values
        path = "./data/pm25/pm25_meanstd.pk"
        with open(path, "wb") as f:
            pickle.dump([mean, std], f)
    create_normalizer_pm25()

--------------------------------------------------------------------------------
/tsdiff/csdi/exe_physio.py:
--------------------------------------------------------------------------------
import argparse
import torch
import datetime
import json
import yaml
import os
from pathlib import Path

from tsdiff.csdi.main_model import CSDI_Physio
from tsdiff.csdi.dataset_physio import get_dataloader
from tsdiff.csdi.utils import train, evaluate

parser = argparse.ArgumentParser(description="CSDI")
parser.add_argument("--config", type=str, default="base.yaml")
parser.add_argument('--device', default='cuda:0', help='Device to run on')
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--testmissingratio", type=float, default=0.1)
parser.add_argument(
    "--nfold", type=int, default=0, help="for 5fold test (valid value:[0-4])"
)
parser.add_argument("--unconditional", action="store_true")
parser.add_argument("--gp_noise", action="store_true")
parser.add_argument("--modelfolder", type=str, default="")
parser.add_argument("--nsample", type=int, default=100)

args = parser.parse_args()
print(args)

path = Path(__file__).parents[0].resolve() / 'config' / args.config
with open(path, "r") as f:
    config = yaml.safe_load(f)

config["model"]["is_unconditional"] = args.unconditional
config["model"]["test_missing_ratio"] = args.testmissingratio

print(json.dumps(config, indent=4))

current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
foldername = "./save/physio_fold" + str(args.nfold) + "_" + current_time + "/"
print('model folder:', foldername)
os.makedirs(foldername, exist_ok=True)
with open(foldername + "config.json", "w") as f:
    json.dump(config, f, indent=4)

train_loader, valid_loader, test_loader = get_dataloader(
    seed=args.seed,
    nfold=args.nfold,
    batch_size=config["train"]["batch_size"],
    missing_ratio=config["model"]["test_missing_ratio"],
)

model = CSDI_Physio(config, args.device, gp_noise=args.gp_noise, gp_sigma=0.02).to(args.device)

if args.modelfolder == "":
    train(
        model,
        config["train"],
        train_loader,
        valid_loader=valid_loader,
        foldername=foldername,
    )
else:
    model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth"))

evaluate(model, test_loader, nsample=args.nsample, scaler=1, foldername=foldername)

--------------------------------------------------------------------------------
/tsdiff/csdi/exe_pm25.py:
--------------------------------------------------------------------------------
import argparse
import torch
import datetime
import json
import yaml
import os

from dataset_pm25 import get_dataloader
from main_model import CSDI_PM25
from utils import train, evaluate

parser = argparse.ArgumentParser(description="CSDI")
parser.add_argument("--config", type=str, default="base.yaml")
parser.add_argument('--device', default='cuda:0', help='Device to run on')
parser.add_argument("--modelfolder", type=str, default="")
parser.add_argument(
    "--targetstrategy", type=str, default="mix", choices=["mix", "random", "historical"]
)
parser.add_argument(
    "--validationindex", type=int, default=0, help="index of month used for validation (value:[0-7])"
)
parser.add_argument("--nsample", type=int, default=100)
parser.add_argument("--unconditional", action="store_true")

args = parser.parse_args()
print(args)

path = "config/" + args.config
with open(path, "r") as f:
    config = yaml.safe_load(f)

config["model"]["is_unconditional"] = args.unconditional
config["model"]["target_strategy"] = args.targetstrategy

print(json.dumps(config, indent=4))

current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
foldername = (
    "./save/pm25_validationindex" + str(args.validationindex) + "_" + current_time + "/"
)

print('model folder:', foldername)
os.makedirs(foldername, exist_ok=True)
with open(foldername + "config.json", "w") as f:
    json.dump(config, f, indent=4)

train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader(
    config["train"]["batch_size"], device=args.device, validindex=args.validationindex
)
model = CSDI_PM25(config, args.device).to(args.device)

if args.modelfolder == "":
    train(
        model,
        config["train"],
        train_loader,
        valid_loader=valid_loader,
        foldername=foldername,
    )
else:
    model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth"))

evaluate(
    model,
    test_loader,
    nsample=args.nsample,
    scaler=scaler,
    mean_scaler=mean_scaler,
    foldername=foldername,
)

--------------------------------------------------------------------------------
/tsdiff/csdi/experiment.py:
--------------------------------------------------------------------------------
from main_model import CSDI_Physio
from dataset_physio import get_dataloader
from utils import train, evaluate

import seml
from sacred import Experiment

ex = Experiment()
seml.setup_logger(ex)

@ex.config
def config():
    overwrite = None
    db_collection = None
    if db_collection is not None:
        ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite))

@ex.automain
def run(
    seed: int,
    lr: float,
    batch_size: int,
    epochs: int,
    nsample: int,
    testmissingratio: float,
    gp_noise: bool,
    is_unconditional: bool,
    timeemb: int,
    featureemb: int,
    target_strategy: str,
    num_steps: int,
    schedule: str,
    beta_start: float,
    beta_end: float,
    layers: int,
    channels: int,
    nheads: int,
    diffusion_embedding_dim: int,
    gp_sigma: float = None,
    device: str = 'cuda:0',
):

    # rebuild the nested config layout used by base.yaml from the flat seml parameters
    config = dict(
        train=dict(
            lr=lr,
            batch_size=batch_size,
            epochs=epochs,
        ),
        model=dict(
            timeemb=timeemb,
            featureemb=featureemb,
            is_unconditional=is_unconditional,
            target_strategy=target_strategy,
            test_missing_ratio=testmissingratio,
        ),
        diffusion=dict(
            num_steps=num_steps,
            schedule=schedule,
            beta_start=beta_start,
            beta_end=beta_end,
            layers=layers,
            channels=channels,
            nheads=nheads,
            diffusion_embedding_dim=diffusion_embedding_dim,
        ),
    )

    train_loader, valid_loader, test_loader = get_dataloader(
        seed=seed,
        nfold=0,
        batch_size=batch_size,
        missing_ratio=testmissingratio,
    )

    model = CSDI_Physio(config, device, gp_noise=gp_noise, gp_sigma=gp_sigma).to(device)

    train(
        model,
        config["train"],
        train_loader,
        valid_loader=valid_loader,
        foldername='',
    )

    results = evaluate(model, test_loader, nsample=nsample, scaler=1, foldername='')
    return results

--------------------------------------------------------------------------------
/tsdiff/csdi/experiment.yaml:
--------------------------------------------------------------------------------
seml:
  executable: csdi/experiment.py
  name: tsdiff_csdi
  output_dir: experiments/logs
  project_root_dir: ../
  conda_environment: place_your_env_here

slurm:
  experiments_per_job: 2
  sbatch_options:
    gres: gpu:1  # num GPUs
    mem: 16G  # memory
    cpus-per-task: 2  # num cores
    time: 0-08:00  # max time, D-HH:MM
    partition: gpu_all

fixed:
  epochs: 200
  batch_size: 16
  lr: 1.0e-3
  layers: 4
  channels: 64
  nheads: 8
  diffusion_embedding_dim: 128
  beta_start: 0.0001
  beta_end: 0.5
  num_steps: 50
  schedule: 'quad'
  is_unconditional: 0
  timeemb: 128
  featureemb: 16
  target_strategy: 'random'

grid:
  seed:
    type: range
    min: 1
    max: 11
    step: 1

  testmissingratio:
    type: choice
    options:
      - 0.1
      - 0.5
      - 0.9

  nsample:
    type: choice
    options:
      - 20

baseline:
  fixed:
    gp_noise: False

our:
  fixed:
    gp_noise: True

  grid:
    gp_sigma:
      type: choice
      options:
        - 0.005
        - 0.01
        - 0.02
        - 0.05
        - 0.1

--------------------------------------------------------------------------------
/tsdiff/csdi/main_model.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from tsdiff.csdi.diff_models import diff_CSDI 5 | 6 | from tsdiff.diffusion.noise import GaussianProcess 7 | 8 | 9 | class CSDI_base(nn.Module): 10 | def __init__(self, target_dim, config, device, gp_noise=False, gp_sigma=None): 11 | super().__init__() 12 | self.device = device 13 | self.target_dim = target_dim 14 | 15 | self.gp_noise = gp_noise 16 | if self.gp_noise: 17 | self.gp = GaussianProcess(target_dim, sigma=gp_sigma) 18 | 19 | self.emb_time_dim = config["model"]["timeemb"] 20 | self.emb_feature_dim = config["model"]["featureemb"] 21 | self.is_unconditional = config["model"]["is_unconditional"] 22 | self.target_strategy = config["model"]["target_strategy"] 23 | 24 | self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim 25 | if self.is_unconditional == False: 26 | self.emb_total_dim += 1 # for conditional mask 27 | self.embed_layer = nn.Embedding( 28 | num_embeddings=self.target_dim, embedding_dim=self.emb_feature_dim 29 | ) 30 | 31 | config_diff = config["diffusion"] 32 | config_diff["side_dim"] = self.emb_total_dim 33 | 34 | input_dim = 1 if self.is_unconditional == True else 2 35 | self.diffmodel = diff_CSDI(config_diff, input_dim) 36 | 37 | # parameters for diffusion models 38 | self.num_steps = config_diff["num_steps"] 39 | if config_diff["schedule"] == "quad": 40 | self.beta = np.linspace( 41 | config_diff["beta_start"] ** 0.5, config_diff["beta_end"] ** 0.5, self.num_steps 42 | ) ** 2 43 | elif config_diff["schedule"] == "linear": 44 | self.beta = np.linspace( 45 | config_diff["beta_start"], config_diff["beta_end"], self.num_steps 46 | ) 47 | 48 | self.alpha_hat = 1 - self.beta 49 | self.alpha = np.cumprod(self.alpha_hat) 50 | self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1) 51 | 52 | def time_embedding(self, pos, d_model=128): 53 | pe = 
torch.zeros(pos.shape[0], pos.shape[1], d_model).to(self.device) 54 | position = pos.unsqueeze(2) 55 | div_term = 1 / torch.pow( 56 | 10, torch.arange(0, d_model, 2).to(self.device) / d_model 57 | ) 58 | pe[:, :, 0::2] = torch.sin(position * div_term) 59 | pe[:, :, 1::2] = torch.cos(position * div_term) 60 | return pe 61 | 62 | def get_randmask(self, observed_mask): 63 | rand_for_mask = torch.rand_like(observed_mask) * observed_mask 64 | rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1) 65 | for i in range(len(observed_mask)): 66 | sample_ratio = np.random.rand() # missing ratio 67 | num_observed = observed_mask[i].sum().item() 68 | num_masked = round(num_observed * sample_ratio) 69 | rand_for_mask[i][rand_for_mask[i].topk(num_masked).indices] = -1 70 | cond_mask = (rand_for_mask > 0).reshape(observed_mask.shape).float() 71 | return cond_mask 72 | 73 | def get_hist_mask(self, observed_mask, for_pattern_mask=None): 74 | if for_pattern_mask is None: 75 | for_pattern_mask = observed_mask 76 | if self.target_strategy == "mix": 77 | rand_mask = self.get_randmask(observed_mask) 78 | 79 | cond_mask = observed_mask.clone() 80 | for i in range(len(cond_mask)): 81 | mask_choice = np.random.rand() 82 | if self.target_strategy == "mix" and mask_choice > 0.5: 83 | cond_mask[i] = rand_mask[i] 84 | else: # draw another sample for histmask (i-1 corresponds to another sample) 85 | cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1] 86 | return cond_mask 87 | 88 | def get_side_info(self, observed_times, cond_mask): 89 | B, K, L = cond_mask.shape 90 | 91 | time_embed = self.time_embedding(observed_times, self.emb_time_dim) # (B,L,emb) 92 | time_embed = time_embed.unsqueeze(2).expand(-1, -1, K, -1) 93 | feature_embed = self.embed_layer( 94 | torch.arange(self.target_dim).to(self.device) 95 | ) # (K,emb) 96 | feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1) 97 | 98 | side_info = torch.cat([time_embed, feature_embed], dim=-1) # (B,L,K,*) 99 | 
side_info = side_info.permute(0, 3, 2, 1) # (B,*,K,L) 100 | 101 | if self.is_unconditional == False: 102 | side_mask = cond_mask.unsqueeze(1) # (B,1,K,L) 103 | side_info = torch.cat([side_info, side_mask], dim=1) 104 | 105 | return side_info 106 | 107 | def calc_loss_valid( 108 | self, observed_data, cond_mask, observed_mask, observed_times, side_info, is_train 109 | ): 110 | loss_sum = 0 111 | for t in range(self.num_steps): # calculate loss for all t 112 | loss = self.calc_loss( 113 | observed_data, cond_mask, observed_mask, observed_times, side_info, is_train, set_t=t 114 | ) 115 | loss_sum += loss.detach() 116 | return loss_sum / self.num_steps 117 | 118 | def calc_loss( 119 | self, observed_data, cond_mask, observed_mask, observed_times, side_info, is_train, set_t=-1 120 | ): 121 | B, K, L = observed_data.shape 122 | if is_train != 1: # for validation 123 | t = (torch.ones(B) * set_t).long().to(self.device) 124 | else: 125 | t = torch.randint(0, self.num_steps, [B]).to(self.device) 126 | current_alpha = self.alpha_torch[t] # (B,1,1) 127 | noise = torch.randn_like(observed_data) 128 | 129 | if self.gp_noise: 130 | # Compute GP covariance matrix 131 | L = self.gp.covariance_cholesky(t=observed_times.unsqueeze(-1)) 132 | # Replace normal noise with GP noise 133 | noise = (L @ noise.transpose(-1, -2)).transpose(-1, -2) 134 | 135 | noisy_data = (current_alpha ** 0.5) * observed_data + (1.0 - current_alpha) ** 0.5 * noise 136 | 137 | total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask) 138 | 139 | predicted = self.diffmodel(total_input, side_info, t) # (B,K,L) 140 | 141 | target_mask = observed_mask - cond_mask 142 | residual = (noise - predicted) * target_mask 143 | num_eval = target_mask.sum() 144 | loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1) 145 | return loss 146 | 147 | def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask): 148 | if self.is_unconditional == True: 149 | total_input = 
noisy_data.unsqueeze(1) # (B,1,K,L) 150 | else: 151 | cond_obs = (cond_mask * observed_data).unsqueeze(1) 152 | noisy_target = ((1 - cond_mask) * noisy_data).unsqueeze(1) 153 | total_input = torch.cat([cond_obs, noisy_target], dim=1) # (B,2,K,L) 154 | 155 | return total_input 156 | 157 | def impute(self, observed_data, cond_mask, observed_times, side_info, n_samples): 158 | B, K, L = observed_data.shape 159 | 160 | imputed_samples = torch.zeros(B, n_samples, K, L).to(self.device) 161 | 162 | # Precompute GP covariance matrix 163 | if self.gp_noise: 164 | L = self.gp.covariance_cholesky(t=observed_times.unsqueeze(-1)) 165 | 166 | for i in range(n_samples): 167 | # generate noisy observation for unconditional model 168 | if self.is_unconditional == True: 169 | noisy_obs = observed_data 170 | noisy_cond_history = [] 171 | for t in range(self.num_steps): 172 | noise = torch.randn_like(noisy_obs) 173 | 174 | # Replace unit normal noise with GP noise 175 | if self.gp_noise: 176 | noise = (L @ noise.transpose(-1, -2)).transpose(-1, -2) 177 | 178 | noisy_obs = (self.alpha_hat[t] ** 0.5) * noisy_obs + self.beta[t] ** 0.5 * noise 179 | noisy_cond_history.append(noisy_obs * cond_mask) 180 | 181 | current_sample = torch.randn_like(observed_data) 182 | 183 | # Replace initial unit normal noise with GP noise 184 | if self.gp_noise: 185 | current_sample = (L @ current_sample.transpose(-1, -2)).transpose(-1, -2) 186 | 187 | for t in range(self.num_steps - 1, -1, -1): 188 | if self.is_unconditional == True: 189 | diff_input = cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample 190 | diff_input = diff_input.unsqueeze(1) # (B,1,K,L) 191 | else: 192 | cond_obs = (cond_mask * observed_data).unsqueeze(1) 193 | noisy_target = ((1 - cond_mask) * current_sample).unsqueeze(1) 194 | diff_input = torch.cat([cond_obs, noisy_target], dim=1) # (B,2,K,L) 195 | predicted = self.diffmodel(diff_input, side_info, torch.tensor([t]).to(self.device)) 196 | 197 | coeff1 = 1 / 
self.alpha_hat[t] ** 0.5 198 | coeff2 = (1 - self.alpha_hat[t]) / (1 - self.alpha[t]) ** 0.5 199 | current_sample = coeff1 * (current_sample - coeff2 * predicted) 200 | 201 | if t > 0: 202 | noise = torch.randn_like(current_sample) 203 | 204 | # Replace unit normal noise with GP noise 205 | if self.gp_noise: 206 | noise = (L @ noise.transpose(-1, -2)).transpose(-1, -2) 207 | 208 | sigma = ((1.0 - self.alpha[t - 1]) / (1.0 - self.alpha[t]) * self.beta[t]) ** 0.5 209 | current_sample += sigma * noise 210 | 211 | imputed_samples[:, i] = current_sample.detach() 212 | return imputed_samples 213 | 214 | def forward(self, batch, is_train=1): 215 | ( 216 | observed_data, 217 | observed_mask, 218 | observed_times, 219 | gt_mask, 220 | for_pattern_mask, 221 | _, 222 | ) = self.process_data(batch) 223 | if is_train == 0: 224 | cond_mask = gt_mask 225 | elif self.target_strategy != "random": 226 | cond_mask = self.get_hist_mask( 227 | observed_mask, for_pattern_mask=for_pattern_mask 228 | ) 229 | else: 230 | cond_mask = self.get_randmask(observed_mask) 231 | 232 | side_info = self.get_side_info(observed_times, cond_mask) 233 | 234 | loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid 235 | 236 | return loss_func(observed_data, cond_mask, observed_mask, observed_times, side_info, is_train) 237 | 238 | def evaluate(self, batch, n_samples): 239 | ( 240 | observed_data, 241 | observed_mask, 242 | observed_times, 243 | gt_mask, 244 | _, 245 | cut_length, 246 | ) = self.process_data(batch) 247 | 248 | with torch.no_grad(): 249 | cond_mask = gt_mask 250 | target_mask = observed_mask - cond_mask 251 | 252 | side_info = self.get_side_info(observed_times, cond_mask) 253 | 254 | samples = self.impute(observed_data, cond_mask, observed_times, side_info, n_samples) 255 | 256 | for i in range(len(cut_length)): # to avoid double evaluation 257 | target_mask[i, ..., 0 : cut_length[i].item()] = 0 258 | return samples, observed_data, target_mask, observed_mask, observed_times 
259 |
260 |
261 | class CSDI_PM25(CSDI_base):
262 |     def __init__(self, config, device, target_dim=36):
263 |         super(CSDI_PM25, self).__init__(target_dim, config, device)
264 |
265 |     def process_data(self, batch):
266 |         observed_data = batch["observed_data"].to(self.device).float()
267 |         observed_mask = batch["observed_mask"].to(self.device).float()
268 |         observed_tp = batch["timepoints"].to(self.device).float()
269 |         gt_mask = batch["gt_mask"].to(self.device).float()
270 |         cut_length = batch["cut_length"].to(self.device).long()
271 |         for_pattern_mask = batch["hist_mask"].to(self.device).float()
272 |
273 |         observed_data = observed_data.permute(0, 2, 1)
274 |         observed_mask = observed_mask.permute(0, 2, 1)
275 |         gt_mask = gt_mask.permute(0, 2, 1)
276 |         for_pattern_mask = for_pattern_mask.permute(0, 2, 1)
277 |
278 |         return (
279 |             observed_data,
280 |             observed_mask,
281 |             observed_tp,  # third slot: CSDI_base.forward/evaluate unpack it as observed_times
282 |             gt_mask,
283 |             for_pattern_mask,
284 |             cut_length,
285 |         )
286 |
287 |
288 | class CSDI_Physio(CSDI_base):
289 |     def __init__(self, config, device, gp_noise=False, gp_sigma=None, target_dim=35):
290 |         super(CSDI_Physio, self).__init__(target_dim, config, device, gp_noise, gp_sigma)
291 |
292 |     def process_data(self, batch):
293 |         observed_data = batch["observed_data"].to(self.device).float()
294 |         observed_mask = batch["observed_mask"].to(self.device).float()
295 |         observed_tp = batch["timepoints"].to(self.device).float()
296 |         observed_times = batch["times"].to(self.device).float()
297 |         gt_mask = batch["gt_mask"].to(self.device).float()
298 |
299 |         observed_data = observed_data.permute(0, 2, 1)
300 |         observed_mask = observed_mask.permute(0, 2, 1)
301 |         gt_mask = gt_mask.permute(0, 2, 1)
302 |
303 |         cut_length = torch.zeros(len(observed_data)).long().to(self.device)
304 |         for_pattern_mask = observed_mask
305 |
306 |         return (
307 |             observed_data,
308 |             observed_mask,
309 |             observed_times,
310 |             gt_mask,
311 |             for_pattern_mask,
312 |             cut_length,
313 |         )
314 |
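The noise schedule that `CSDI_base.__init__` builds in `main_model.py` can be reproduced without NumPy or PyTorch, which may help when checking how quickly the signal is destroyed for a given configuration. The sketch below is an illustration only and not part of the repository: the helper names `make_beta_schedule`, `cumulative_alpha` and `noising_coefficients` are made up here, and raising `ValueError` for an unknown schedule name is a choice of this sketch (the source only defines `self.beta` for `quad` and `linear`). Note that the source names `1 - beta` as `alpha_hat` and its cumulative product as `alpha`; the sketch follows that naming.

```python
import math


def make_beta_schedule(beta_start, beta_end, num_steps, schedule="quad"):
    """Per-step noise levels beta_t, mirroring the "quad" and "linear" branches."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")

    def linspace(lo, hi):
        # Evenly spaced values from lo to hi inclusive (pure-Python np.linspace).
        step = (hi - lo) / (num_steps - 1)
        return [lo + i * step for i in range(num_steps)]

    if schedule == "quad":
        # Interpolate linearly in sqrt(beta), then square the result.
        return [v ** 2 for v in linspace(math.sqrt(beta_start), math.sqrt(beta_end))]
    if schedule == "linear":
        return linspace(beta_start, beta_end)
    # Assumption of this sketch: fail loudly on an unknown schedule name.
    raise ValueError(f"unknown schedule: {schedule!r}")


def cumulative_alpha(beta):
    """Running product of (1 - beta_t); the source stores this as `alpha`."""
    alpha, running = [], 1.0
    for b in beta:
        running *= 1.0 - b
        alpha.append(running)
    return alpha


def noising_coefficients(alpha, t):
    """Weights (signal, noise) in noisy = signal * data + noise * eps at step t."""
    return math.sqrt(alpha[t]), math.sqrt(1.0 - alpha[t])


# Values taken from tsdiff/csdi/experiment.yaml.
beta = make_beta_schedule(beta_start=0.0001, beta_end=0.5, num_steps=50, schedule="quad")
alpha = cumulative_alpha(beta)
signal, noise = noising_coefficients(alpha, t=len(beta) - 1)
print(f"last step: signal weight {signal:.4f}, noise weight {noise:.4f}")
```

The two weights returned by `noising_coefficients` correspond to `current_alpha ** 0.5` and `(1.0 - current_alpha) ** 0.5` in `calc_loss`; their squares sum to one, so the marginal variance is preserved when the data and the noise both have unit variance.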
-------------------------------------------------------------------------------- /tsdiff/csdi/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.optim import Adam 4 | from tqdm import tqdm 5 | import pickle 6 | 7 | 8 | def train( 9 | model, 10 | config, 11 | train_loader, 12 | valid_loader=None, 13 | valid_epoch_interval=5, 14 | foldername="", 15 | ): 16 | optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-6) 17 | if foldername != "": 18 | output_path = foldername + "/model.pth" 19 | 20 | p1 = int(0.75 * config["epochs"]) 21 | p2 = int(0.9 * config["epochs"]) 22 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( 23 | optimizer, milestones=[p1, p2], gamma=0.1 24 | ) 25 | 26 | best_valid_loss = 1e10 27 | for epoch_no in range(config["epochs"]): 28 | avg_loss = 0 29 | model.train() 30 | with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it: 31 | for batch_no, train_batch in enumerate(it, start=1): 32 | optimizer.zero_grad() 33 | 34 | loss = model(train_batch) 35 | loss.backward() 36 | avg_loss += loss.item() 37 | optimizer.step() 38 | it.set_postfix( 39 | ordered_dict={ 40 | "avg_epoch_loss": avg_loss / batch_no, 41 | "epoch": epoch_no, 42 | }, 43 | refresh=False, 44 | ) 45 | lr_scheduler.step() 46 | if valid_loader is not None and (epoch_no + 1) % valid_epoch_interval == 0: 47 | model.eval() 48 | avg_loss_valid = 0 49 | with torch.no_grad(): 50 | with tqdm(valid_loader, mininterval=5.0, maxinterval=50.0) as it: 51 | for batch_no, valid_batch in enumerate(it, start=1): 52 | loss = model(valid_batch, is_train=0) 53 | avg_loss_valid += loss.item() 54 | it.set_postfix( 55 | ordered_dict={ 56 | "valid_avg_epoch_loss": avg_loss_valid / batch_no, 57 | "epoch": epoch_no, 58 | }, 59 | refresh=False, 60 | ) 61 | if best_valid_loss > avg_loss_valid: 62 | best_valid_loss = avg_loss_valid 63 | print( 64 | "\n best loss is updated to ", 65 | 
avg_loss_valid / batch_no, 66 | "at", 67 | epoch_no, 68 | ) 69 | 70 | if foldername != "": 71 | torch.save(model.state_dict(), output_path) 72 | 73 | 74 | def quantile_loss(target, forecast, q: float, eval_points) -> float: 75 | return 2 * torch.sum( 76 | torch.abs((forecast - target) * eval_points * ((target <= forecast) * 1.0 - q)) 77 | ) 78 | 79 | 80 | def calc_denominator(target, eval_points): 81 | return torch.sum(torch.abs(target * eval_points)) 82 | 83 | 84 | def calc_quantile_CRPS(target, forecast, eval_points, mean_scaler, scaler): 85 | target = target * scaler + mean_scaler 86 | forecast = forecast * scaler + mean_scaler 87 | 88 | quantiles = np.arange(0.05, 1.0, 0.05) 89 | denom = calc_denominator(target, eval_points) 90 | CRPS = 0 91 | for i in range(len(quantiles)): 92 | q_pred = [] 93 | for j in range(len(forecast)): 94 | q_pred.append(torch.quantile(forecast[j : j + 1], quantiles[i], dim=1)) 95 | q_pred = torch.cat(q_pred, 0) 96 | q_loss = quantile_loss(target, q_pred, quantiles[i], eval_points) 97 | CRPS += q_loss / denom 98 | return CRPS.item() / len(quantiles) 99 | 100 | 101 | def evaluate(model, test_loader, nsample=100, scaler=1, mean_scaler=0, foldername=""): 102 | 103 | with torch.no_grad(): 104 | model.eval() 105 | mse_total = 0 106 | mae_total = 0 107 | evalpoints_total = 0 108 | 109 | all_target = [] 110 | all_observed_point = [] 111 | all_observed_time = [] 112 | all_evalpoint = [] 113 | all_generated_samples = [] 114 | with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it: 115 | for batch_no, test_batch in enumerate(it, start=1): 116 | output = model.evaluate(test_batch, nsample) 117 | 118 | samples, c_target, eval_points, observed_points, observed_time = output 119 | samples = samples.permute(0, 1, 3, 2) # (B,nsample,L,K) 120 | c_target = c_target.permute(0, 2, 1) # (B,L,K) 121 | eval_points = eval_points.permute(0, 2, 1) 122 | observed_points = observed_points.permute(0, 2, 1) 123 | 124 | samples_median = samples.median(dim=1) 
125 | all_target.append(c_target) 126 | all_evalpoint.append(eval_points) 127 | all_observed_point.append(observed_points) 128 | all_observed_time.append(observed_time) 129 | all_generated_samples.append(samples) 130 | 131 | mse_current = ( 132 | ((samples_median.values - c_target) * eval_points) ** 2 133 | ) * (scaler ** 2) 134 | mae_current = ( 135 | torch.abs((samples_median.values - c_target) * eval_points) 136 | ) * scaler 137 | 138 | mse_total += mse_current.sum().item() 139 | mae_total += mae_current.sum().item() 140 | evalpoints_total += eval_points.sum().item() 141 | 142 | it.set_postfix( 143 | ordered_dict={ 144 | "rmse_total": np.sqrt(mse_total / evalpoints_total), 145 | "mae_total": mae_total / evalpoints_total, 146 | "batch_no": batch_no, 147 | }, 148 | refresh=True, 149 | ) 150 | 151 | all_target = torch.cat(all_target, dim=0) 152 | all_evalpoint = torch.cat(all_evalpoint, dim=0) 153 | all_observed_point = torch.cat(all_observed_point, dim=0) 154 | all_observed_time = torch.cat(all_observed_time, dim=0) 155 | all_generated_samples = torch.cat(all_generated_samples, dim=0) 156 | 157 | crps = calc_quantile_CRPS( 158 | all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler 159 | ) 160 | 161 | results = dict( 162 | rmse=np.sqrt(mse_total / evalpoints_total), 163 | mae=mae_total / evalpoints_total, 164 | crps=crps, 165 | ) 166 | 167 | print("RMSE:", results['rmse']) 168 | print("MAE:", results['mae']) 169 | print("CRPS:", results['crps']) 170 | 171 | return results 172 | -------------------------------------------------------------------------------- /tsdiff/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbilos/tsdiff/32b23f2b7f5ec4d68bc533dda8a0096086cbd0ab/tsdiff/data/__init__.py -------------------------------------------------------------------------------- /tsdiff/data/generate.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torchdiffeq import odeint 5 | from torchsde import sdeint 6 | from pathlib import Path 7 | 8 | DATA_DIR = Path(__file__).parents[2] / 'data/synthetic' 9 | DATA_DIR.mkdir(parents=True, exist_ok=True) 10 | 11 | def generate_OU(N=10_000, mu=0.02, theta=0.1, sigma=0.4, regular=True): 12 | class OU(nn.Module): 13 | noise_type = 'diagonal' 14 | sde_type = 'ito' 15 | 16 | def __init__(self, mu, theta, sigma): 17 | super().__init__() 18 | self.mu = mu 19 | self.theta = theta 20 | self.sigma = sigma 21 | 22 | # Drift 23 | def f(self, t, y): 24 | return self.mu * t - self.theta * y 25 | 26 | # Diffusion 27 | def g(self, t, y): 28 | return self.sigma * torch.ones_like(y).to(y) 29 | 30 | f = OU(mu, theta, sigma) 31 | if regular: 32 | t = torch.linspace(0, 63, 64) 33 | x0 = torch.randn(N, 1) 34 | 35 | with torch.no_grad(): 36 | x = sdeint(f, x0, t, dt=0.1).transpose(0, 1) 37 | 38 | t = t.view(1, -1, 1).expand_as(x[...,:1]) 39 | 40 | np.savez(DATA_DIR / 'ou.npz', t=t.numpy(), x=x.numpy()) 41 | else: 42 | np.savez(DATA_DIR / 'ou_irregular.npz', t=t.numpy(), x=x.numpy()) 43 | 44 | 45 | def generate_CIR(N=10_000, a=1, b=1.2, sigma=0.2, regular=True): 46 | class CIR(nn.Module): 47 | """ Cox-Ingersoll-Ross """ 48 | noise_type = 'diagonal' 49 | sde_type = 'ito' 50 | 51 | def __init__(self, a, b, sigma): 52 | super().__init__() 53 | self.a = a 54 | self.b = b 55 | self.sigma = sigma 56 | 57 | def f(self, t, y): 58 | return self.a * (self.b - y) 59 | 60 | def g(self, t, y): 61 | return self.sigma * y.sqrt() 62 | 63 | f = CIR(a, b, sigma) 64 | if regular: 65 | t = torch.linspace(0, 63, 64) 66 | x0 = torch.randn(N, 1).abs() 67 | 68 | with torch.no_grad(): 69 | x = sdeint(f, x0, t, dt=0.1).transpose(0, 1) 70 | 71 | t = t.view(1, -1, 1).expand_as(x[...,:1]) 72 | 73 | np.savez(DATA_DIR / 'cir.npz', t=t.numpy(), x=x.numpy()) 74 | else: 75 
| np.savez(DATA_DIR / 'cir_irregular.npz', t=t.numpy(), x=x.numpy()) 76 | 77 | 78 | def generate_lorenz(N=10_000, rho=28, sigma=10, beta=2.667, regular=True): 79 | class Lorenz(nn.Module): 80 | def __init__(self, rho, sigma, beta): 81 | super().__init__() 82 | self.rho = rho 83 | self.sigma = sigma 84 | self.beta = beta 85 | 86 | def forward(self, t, inp): 87 | x, y, z = inp.chunk(3, dim=-1) 88 | dx = self.sigma * (y - x) 89 | dy = x * self.rho - y - x * z 90 | dz = x * y - self.beta * z 91 | d_inp = torch.cat([dx, dy, dz], -1) 92 | return d_inp 93 | 94 | f = Lorenz(rho, sigma, beta) 95 | if regular: 96 | t = torch.linspace(0, 2, 100) 97 | x0 = torch.randn(N, 3) * 10 98 | 99 | with torch.no_grad(): 100 | x = odeint(f, x0, t, method='dopri5').transpose(0, 1) 101 | t = t.view(1, -1, 1).expand_as(x[...,:1]) 102 | 103 | np.savez(DATA_DIR / 'lorenz.npz', t=t.numpy(), x=x.numpy()) 104 | else: 105 | np.savez(DATA_DIR / 'lorenz_irregular.npz', t=t.numpy(), x=x.numpy()) 106 | 107 | 108 | def generate_sine(N=10_000, regular=True): 109 | a = torch.rand(N, 1, 5) + 3 110 | b = torch.rand(N, 1, 5) * 0.5 111 | c = torch.rand(N, 1, 5) 112 | 113 | if regular: 114 | t = torch.linspace(0, 10, 100).view(1, -1, 1).repeat(N, 1, 1) 115 | x = (c * torch.sin(a * t + b)).sum(-1, keepdim=True) 116 | 117 | np.savez(DATA_DIR / 'sine.npz', t=t.numpy(), x=x.numpy()) 118 | else: 119 | np.savez(DATA_DIR / 'sine_irregular.npz', t=t.numpy(), x=x.numpy()) 120 | 121 | 122 | def generate_predator_prey(N=10_000, regular=True): 123 | class PredatorPrey(nn.Module): 124 | def forward(self, t, y): 125 | y1, y2 = y.chunk(2, dim=-1) 126 | dy = torch.cat([ 127 | 2/3 * y1 - 2/3 * y1 * y2, 128 | y1 * y2 - y2, 129 | ], -1) 130 | return dy 131 | 132 | f = PredatorPrey() 133 | if regular: 134 | t = torch.linspace(0, 10, 64) 135 | x0 = torch.rand(N, 2) 136 | 137 | with torch.no_grad(): 138 | x = odeint(f, x0, t, method='dopri5').transpose(0, 1) 139 | 140 | t = t.view(1, -1, 1).expand_as(x[...,:1]) 141 | 142 | 
np.savez(DATA_DIR / 'predator_prey.npz', t=t.numpy(), x=x.numpy()) 143 | else: 144 | np.savez(DATA_DIR / 'predator_prey_irregular.npz', t=t.numpy(), x=x.numpy()) 145 | 146 | def generate_sink(N=10_000, regular=True): 147 | class Sink(nn.Module): 148 | def forward(self, t, y): 149 | A = torch.Tensor([[-4, 10], [-3, 2]]).to(y) 150 | return y @ A 151 | 152 | f = Sink() 153 | if regular: 154 | t = torch.linspace(0, 3, 64) 155 | x0 = torch.rand(N, 2) 156 | 157 | with torch.no_grad(): 158 | x = odeint(f, x0, t, method='dopri5').transpose(0, 1) 159 | 160 | t = t.view(1, -1, 1).expand_as(x[...,:1]) 161 | 162 | np.savez(DATA_DIR / 'sink.npz', t=t.numpy(), x=x.numpy()) 163 | else: 164 | np.savez(DATA_DIR / 'sink_irregular.npz', t=t.numpy(), x=x.numpy()) 165 | 166 | 167 | 168 | if __name__ == '__main__': 169 | generate_OU() 170 | generate_CIR() 171 | generate_sine() 172 | generate_lorenz() 173 | generate_predator_prey() 174 | generate_sink() 175 | -------------------------------------------------------------------------------- /tsdiff/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .beta_scheduler import * 2 | from .discrete_diffusion import * 3 | from .continuous_diffusion import * 4 | -------------------------------------------------------------------------------- /tsdiff/diffusion/beta_scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | def get_beta_scheduler(name: str) -> Callable: 7 | if name == 'linear': 8 | return BetaLinear 9 | 10 | def get_loss_weighting(name: str) -> Callable: 11 | if name == 'exponential': 12 | return exponential_loss_weighting 13 | 14 | class BetaLinear(nn.Module): 15 | """ 16 | Linear scheduling for beta. 17 | Input t is always from interval [0, 1]. 
18 | 19 | Args: 20 | start: Lower bound (float) 21 | end: Upper bound (float) 22 | """ 23 | def __init__(self, start: float, end: float): 24 | super().__init__() 25 | self.start = start 26 | self.end = end 27 | 28 | def forward(self, t: Tensor) -> Tensor: 29 | return self.start * (1 - t) + self.end * t 30 | 31 | def integral(self, t: Tensor) -> Tensor: 32 | return 0.5 * (self.end - self.start) * t.square() + self.start * t 33 | 34 | 35 | def exponential_loss_weighting(beta_fn, i): 36 | return 1 - torch.exp(-beta_fn.integral(i)) 37 | -------------------------------------------------------------------------------- /tsdiff/diffusion/continuous_diffusion.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple, Optional, Union 2 | from torchtyping import TensorType 3 | from functools import partial 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.distributions as td 9 | 10 | from torchsde import sdeint 11 | from torchdiffeq import odeint 12 | 13 | from tsdiff.diffusion.noise import Normal, OrnsteinUhlenbeck, GaussianProcess 14 | 15 | 16 | class ContinuousDiffusion(nn.Module): 17 | """ 18 | Continuous diffusion using SDEs (https://arxiv.org/abs/2011.13456) 19 | 20 | Args: 21 | dim: Dimension of data 22 | beta_fn: Scheduler for noise levels 23 | t1: Final diffusion time 24 | noise_fn: Type of noise 25 | predict_gaussian_noise: Whether to approximate score with unit normal 26 | loss_weighting: Function returning loss weights given diffusion time 27 | """ 28 | def __init__( 29 | self, 30 | dim: int, 31 | beta_fn: Callable, 32 | t1: float = 1.0, 33 | noise_fn: Callable = None, 34 | loss_weighting: Callable = None, 35 | is_time_series: bool = False, 36 | predict_gaussian_noise: bool = True, 37 | **kwargs, 38 | ): 39 | super().__init__() 40 | self.dim = dim 41 | self.t1 = t1 42 | self.predict_gaussian_noise = predict_gaussian_noise 43 | self.is_time_series = is_time_series 
44 | 45 | self.beta_fn = beta_fn 46 | self.noise = noise_fn 47 | self.loss_weighting = partial(loss_weighting or (lambda beta, i: 1), beta_fn) 48 | 49 | def forward( 50 | self, 51 | x: TensorType[..., 'dim'], 52 | i: TensorType[..., 1], 53 | _return_all: Optional[bool] = False, # For internal use only 54 | **kwargs, 55 | ) -> Tuple[TensorType[..., 'dim'], TensorType[..., 'dim']]: 56 | 57 | noise_gaussian = torch.randn_like(x) 58 | 59 | if self.is_time_series: 60 | cov = self.noise.covariance(**kwargs) 61 | L = torch.linalg.cholesky(cov) 62 | noise = L @ noise_gaussian 63 | else: 64 | noise = noise_gaussian 65 | 66 | beta_int = self.beta_fn.integral(i) 67 | 68 | mean = x * torch.exp(-beta_int / 2) 69 | std = (1 - torch.exp(-beta_int)).clamp(1e-5).sqrt() 70 | 71 | y = mean + std * noise 72 | 73 | if _return_all: 74 | return y, noise, mean, std, cov if self.is_time_series else None 75 | 76 | if self.predict_gaussian_noise: 77 | return y, noise_gaussian 78 | else: 79 | return y, noise 80 | 81 | def get_loss( 82 | self, 83 | model: Callable, 84 | x: TensorType[..., 'dim'], 85 | **kwargs, 86 | ) -> TensorType[..., 1]: 87 | 88 | i = torch.rand(x.shape[0], *(1,) * len(x.shape[1:])).expand_as(x[...,:1]).to(x) 89 | i = i * self.t1 90 | 91 | x_noisy, noise = self.forward(x, i, **kwargs) 92 | 93 | pred_noise = model(x_noisy, i=i, **kwargs) 94 | loss = self.loss_weighting(i) * (pred_noise - noise)**2 95 | 96 | return loss 97 | 98 | def _get_score(self, model, x, i, L=None, **kwargs): 99 | """ 100 | Returns score: ∇_xs log p(xs) 101 | """ 102 | if isinstance(i, float): 103 | i = torch.Tensor([i]).to(x) 104 | if i.shape[:-1] != x.shape[:-1]: 105 | i = i.view(*(1,) * len(x.shape)).expand_as(x[...,:1]) 106 | 107 | beta_int = self.beta_fn.integral(i) 108 | std = (1 - torch.exp(-beta_int)).clamp(1e-5).sqrt() 109 | 110 | noise = model(x, i=i, **kwargs) 111 | 112 | if L is not None: 113 | # We have to compute the score using -Sigma.inv() @ noise / std 114 | # assuming noise~N(0, 
Sigma). 115 | # If `predict_gaussian_noise=False`, compute (LL^T).inv() 116 | # Else, we can simplify (LL^T).inv() @ L @ noise 117 | # to (L^T).inv() @ noise, where noise~N(0, I). 118 | # So we anyways have to do (L^T).inv(), and sometimes L.inv() 119 | if not self.predict_gaussian_noise: 120 | noise = torch.linalg.solve_triangular(L, noise, upper=False) 121 | noise = torch.linalg.solve_triangular(L.transpose(-1, -2), noise, upper=True) 122 | 123 | score = -noise / std 124 | return score 125 | 126 | @torch.no_grad() 127 | def log_prob( 128 | self, 129 | model: Callable, 130 | x: Union[TensorType[..., 'dim'], TensorType[..., 'seq_len', 'dim']], 131 | num_samples: int = 1, 132 | **kwargs, 133 | ) -> TensorType[..., 1]: 134 | model.train() # Allows backprop through RNN 135 | self._e = torch.randn(num_samples, *x.shape).to(x) 136 | 137 | if self.is_time_series: 138 | cov = self.noise.covariance(**kwargs) 139 | L = torch.linalg.cholesky(cov) 140 | else: 141 | L = None 142 | 143 | def drift(i, state): 144 | y, _ = state 145 | with torch.set_grad_enabled(True): 146 | y = y.requires_grad_(True) 147 | score = self._get_score(model, y, i=i, L=L, **kwargs) 148 | if self.is_time_series: 149 | # Have to include `cov` since g(t) = "scalar" * L @ dW 150 | score = cov @ score 151 | dy = -0.5 * self.beta_fn(i) * (y + score) 152 | divergence = divergence_approx(dy, y, self._e, num_samples=num_samples) 153 | return dy, -divergence 154 | 155 | interval = torch.Tensor([0, self.t1]).to(x) 156 | 157 | # states = odeint(drift, (x, torch.zeros_like(x).to(x)), interval, rtol=1e-6, atol=1e-5) 158 | states = odeint(drift, (x, torch.zeros_like(x).to(x)), interval, 159 | method='rk4', options={'step_size': .01}) 160 | y, div = states[0][-1], states[1][-1] 161 | 162 | if self.is_time_series: 163 | p0 = td.Independent(torch.distributions.MultivariateNormal( 164 | torch.zeros_like(y).transpose(-1, -2), 165 | cov.unsqueeze(-3).repeat_interleave(self.dim, dim=-3), 166 | ), 1) 167 | log_prob = 
p0.log_prob(y.transpose(-1, -2)) - div.sum([-1, -2]) 168 | log_prob = log_prob / x.shape[-2] 169 | else: 170 | p0 = td.Independent(td.Normal(torch.zeros_like(y), torch.ones_like(y)), 1) 171 | log_prob = p0.log_prob(y) - div.sum(-1) 172 | 173 | return log_prob.unsqueeze(-1) 174 | 175 | @torch.no_grad() 176 | def sample( 177 | self, 178 | model: Callable, 179 | num_samples: int, 180 | device: str = None, 181 | use_ode: bool = True, 182 | **kwargs, 183 | ) -> TensorType['num_samples', 'dim']: 184 | if isinstance(num_samples, int): 185 | num_samples = (num_samples,) 186 | 187 | sampler = self.ode_sample if use_ode else self.sde_sample 188 | return sampler(model, num_samples, device, **kwargs) 189 | 190 | @torch.no_grad() 191 | def ode_sample( 192 | self, 193 | model: Callable, 194 | num_samples: int, 195 | device: str = None, 196 | **kwargs, 197 | ) -> TensorType['num_samples', 'dim']: 198 | if self.is_time_series: 199 | cov = self.noise.covariance(**kwargs) 200 | L = torch.linalg.cholesky(cov) 201 | else: 202 | L = None 203 | 204 | def drift(i, y): 205 | score = self._get_score(model, y, i=i, L=L, **kwargs) 206 | if self.is_time_series: 207 | # Have to include `cov` since g(t) = "scalar" * L @ dW 208 | score = cov @ score 209 | return -0.5 * self.beta_fn(i) * (y + score) 210 | 211 | x = self.noise(*num_samples, **kwargs).to(device) 212 | t = torch.Tensor([self.t1, 0]).to(device) 213 | y = odeint(drift, x, t, method='rk4', options={'step_size': .01})[1] 214 | # y = odeint(drift, x, t, rtol=1e-6, atol=1e-5)[1] 215 | 216 | return y 217 | 218 | @torch.no_grad() 219 | def sde_sample( 220 | self, 221 | model: Callable, 222 | num_samples: int, 223 | device: str = None, 224 | **kwargs, 225 | ) -> TensorType['num_samples', 'dim']: 226 | 227 | if self.is_time_series: 228 | cov = self.noise.covariance(**kwargs) 229 | L = torch.linalg.cholesky(cov) 230 | else: 231 | L = None 232 | 233 | is_time_series = self.is_time_series 234 | 235 | x = self.noise(*num_samples, 
**kwargs).to(device) 236 | shape = x.shape 237 | x = x.transpose(-2, -1).flatten(0, -2) 238 | 239 | class SDE(nn.Module): 240 | noise_type = 'general' if is_time_series else 'diagonal' 241 | sde_type = 'ito' 242 | 243 | def __init__(self, beta_fn, _get_score): 244 | super().__init__() 245 | self.beta_fn = beta_fn 246 | self._get_score = _get_score 247 | 248 | def f(self, i, inp): 249 | i = -i 250 | inp = inp.view(*shape) # Reshape back to original 251 | 252 | score = self._get_score(model, inp, i=i, L=L, **kwargs) 253 | if is_time_series: 254 | score = cov @ score 255 | 256 | dx = self.beta_fn(i) * (0.5 * inp + score) 257 | 258 | if is_time_series: 259 | return dx.transpose(-1, -2).flatten(0, -2) 260 | return dx.view(-1, shape[-1]) 261 | 262 | def g(self, i, inp): 263 | i = -i 264 | beta = -self.beta_fn(i).sqrt() 265 | 266 | if is_time_series: 267 | return (beta * L).repeat_interleave(shape[-1], dim=0) 268 | return beta.view(1, 1).repeat(np.prod(shape[:-1]), shape[-1]).to(device) 269 | 270 | sde = SDE(self.beta_fn, self._get_score) 271 | interval = torch.Tensor([-self.t1, 0]).to(device) # Time from -t1 to 0 272 | 273 | step_size = self.t1 / 100 274 | if not is_time_series: 275 | x = x.view(-1, shape[-1]) 276 | else: 277 | x = x.view(-1, shape[-2]) 278 | y = sdeint(sde, x, interval, dt=step_size)[-1] 279 | y = y.view(*shape) 280 | 281 | return y 282 | 283 | 284 | class ContinuousGaussianDiffusion(ContinuousDiffusion): 285 | """ Continuous diffusion using Gaussian noise """ 286 | def __init__(self, dim: int, beta_fn: Callable, predict_gaussian_noise=None, **kwargs): 287 | super().__init__(dim, beta_fn, noise_fn=Normal(dim), predict_gaussian_noise=True, **kwargs) 288 | 289 | 290 | class ContinuousOUDiffusion(ContinuousDiffusion): 291 | """ Continuous diffusion using noise coming from an OU process """ 292 | def __init__(self, dim: int, beta_fn: Callable, predict_gaussian_noise: bool = False, theta: float = 0.5, **kwargs): 293 | super().__init__( 294 | dim=dim, 295 | 
beta_fn=beta_fn, 296 | noise_fn=OrnsteinUhlenbeck(dim, theta=theta), 297 | predict_gaussian_noise=predict_gaussian_noise, 298 | is_time_series=True, 299 | **kwargs, 300 | ) 301 | 302 | 303 | class ContinuousGPDiffusion(ContinuousDiffusion): 304 | """ Continuous diffusion using noise coming from a Gaussian process """ 305 | def __init__(self, dim: int, beta_fn: Callable, predict_gaussian_noise: bool = False, sigma: float = 0.1, **kwargs): 306 | super().__init__( 307 | dim=dim, 308 | beta_fn=beta_fn, 309 | noise_fn=GaussianProcess(dim, sigma=sigma), 310 | predict_gaussian_noise=predict_gaussian_noise, 311 | is_time_series=True, 312 | **kwargs, 313 | ) 314 | 315 | 316 | def divergence_approx(output, input, e, num_samples=1): 317 | out = 0 318 | for i in range(num_samples): 319 | out += torch.autograd.grad(output, input, e[i], create_graph=True)[0].detach() * e[i] 320 | return out / num_samples 321 | -------------------------------------------------------------------------------- /tsdiff/diffusion/discrete_diffusion.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Tuple, Union 2 | from torchtyping import TensorType 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.distributions as td 7 | 8 | from tsdiff.diffusion.noise import Normal, OrnsteinUhlenbeck, GaussianProcess 9 | 10 | 11 | class DiscreteDiffusion(nn.Module): 12 | """ 13 | Discrete diffusion (https://arxiv.org/abs/2006.11239) 14 | 15 | Args: 16 | dim: Dimension of data 17 | num_steps: Number of diffusion steps 18 | beta_fn: Scheduler for noise levels 19 | noise_fn: Type of noise 20 | parallel_elbo: Whether to compute ELBO in parallel or not 21 | """ 22 | def __init__( 23 | self, 24 | dim: int, 25 | num_steps: int, 26 | beta_fn: Callable, 27 | noise_fn: Callable, 28 | parallel_elbo: bool = False, 29 | is_time_series: bool = False, 30 | predict_gaussian_noise: bool = True, 31 | **kwargs, 32 | ): 33 | super().__init__() 34 | 
self.dim = dim 35 | self.num_steps = num_steps 36 | self.parallel_elbo = parallel_elbo 37 | self.is_time_series = is_time_series 38 | self.predict_gaussian_noise = predict_gaussian_noise 39 | 40 | self.betas = beta_fn(torch.linspace(0, 1, num_steps)) 41 | self.alphas = torch.cumprod(1 - self.betas, dim=0) 42 | 43 | self.noise = noise_fn 44 | 45 | def forward( 46 | self, 47 | x: TensorType[..., 'dim'], 48 | i: TensorType[..., 1], 49 | **kwargs, 50 | ) -> Tuple[TensorType[..., 'dim'], TensorType[..., 'dim']]: 51 | 52 | noise_gaussian = torch.randn_like(x) 53 | 54 | if self.is_time_series: 55 | cov = self.noise.covariance(**kwargs) 56 | L = torch.linalg.cholesky(cov) 57 | noise = L @ noise_gaussian 58 | else: 59 | noise = noise_gaussian 60 | 61 | alpha = self.alphas[i.long()].to(x) 62 | y = torch.sqrt(alpha) * x + torch.sqrt(1 - alpha) * noise 63 | 64 | if self.predict_gaussian_noise: 65 | return y, noise_gaussian 66 | else: 67 | return y, noise 68 | 69 | def get_loss( 70 | self, 71 | model: Callable, 72 | x: TensorType[..., 'dim'], 73 | **kwargs, 74 | ) -> TensorType[..., 'dim']: 75 | 76 | i = torch.randint(0, self.num_steps, size=(x.shape[0],)) 77 | i = i.view(-1, *(1,) * len(x.shape[1:])).expand_as(x[...,:1]).to(x) 78 | 79 | x_noisy, noise = self.forward(x, i, **kwargs) 80 | 81 | pred_noise = model(x_noisy, i=i, **kwargs) 82 | loss = (pred_noise - noise)**2 83 | 84 | return loss 85 | 86 | @torch.no_grad() 87 | def sample( 88 | self, 89 | model: Callable, 90 | num_samples: Union[int, Tuple], 91 | device: str = 'cpu', 92 | **kwargs, 93 | ) -> TensorType['*num_samples', 'dim']: 94 | if isinstance(num_samples, int): 95 | num_samples = (num_samples,) 96 | 97 | x = self.noise(*num_samples, **kwargs).to(device) 98 | 99 | if self.is_time_series and self.predict_gaussian_noise: 100 | cov = self.noise.covariance(**kwargs) 101 | L = torch.linalg.cholesky(cov) 102 | else: 103 | L = None 104 | 105 | for diff_step in reversed(range(0, self.num_steps)): 106 | alpha = 
self.alphas[diff_step] 107 | beta = self.betas[diff_step] 108 | 109 | # An alternative can be: 110 | # alpha_prev = self.alphas[diff_step - 1] 111 | # sigma = beta * (1 - alpha_prev) / (1 - alpha) 112 | sigma = beta 113 | 114 | if diff_step == 0: 115 | z = 0 116 | else: 117 | z = self.noise(*num_samples, **kwargs).to(device) 118 | 119 | i = torch.Tensor([diff_step]).expand_as(x[...,:1]).to(device) 120 | pred_noise = model(x, i=i, **kwargs) 121 | 122 | if L is not None: 123 | pred_noise = L @ pred_noise 124 | 125 | x = (x - beta * pred_noise / (1 - alpha).sqrt()) / (1 - beta).sqrt() + sigma.sqrt() * z 126 | 127 | return x 128 | 129 | @torch.no_grad() 130 | def log_prob( 131 | self, 132 | model: Callable, 133 | x: TensorType[..., 'dim'], 134 | num_samples: int = 1, 135 | **kwargs, 136 | ) -> TensorType[..., 1]: 137 | if self.is_time_series and self.predict_gaussian_noise: 138 | cov = self.noise.covariance(**kwargs) 139 | L = torch.linalg.cholesky(cov) 140 | else: 141 | L = None 142 | 143 | func = self._elbo_parallel if self.parallel_elbo else self._elbo_sequential 144 | return func(model, x, num_samples=num_samples, L=L, **kwargs) 145 | 146 | def _elbo_parallel( 147 | self, 148 | model: Callable, 149 | x: TensorType[..., 'dim'], 150 | L: TensorType[..., 'seq_len', 'seq_len'], 151 | num_samples: int = 1, 152 | **kwargs, 153 | ) -> TensorType[..., 1]: 154 | """ 155 | Computes ELBO over all diffusion steps in parallel, 156 | then averages over `num_samples` runs. 157 | If diffusion `num_steps` large (and `num_samples` small) 158 | it will be heavy on the GPU memory. 159 | 160 | Args: 161 | model: Denoising diffusion model 162 | x: Clean input data 163 | num_samples: How many times to compute ELBO, final 164 | result is averaged over all ELBO samples 165 | **kwargs: Can be time, latent etc. 
depending on a model 166 | """ 167 | elbo = 0 168 | 169 | i = expand_to_x(torch.arange(self.num_steps), x).expand(-1, *x[...,:1].shape).contiguous() 170 | alphas = expand_to_x(self.alphas, x) 171 | betas = expand_to_x(self.betas, x) 172 | 173 | xt, kwargs = expand_x_and_kwargs(x, kwargs, self.num_steps) 174 | 175 | for _ in range(num_samples): 176 | # Get diffused outputs 177 | xt, _ = self.forward(x, i, **kwargs) # [num_steps, ..., dim] 178 | 179 | # Output predicted noise 180 | epsilon = model(xt, i=i, **kwargs) 181 | 182 | if L is not None: 183 | epsilon = L @ epsilon 184 | 185 | # p(x_{t-1} | x_t) 186 | p_mu = get_p_mu(xt, betas, alphas, epsilon) 187 | px = td.Independent(td.Normal(p_mu[1:], betas[1:].sqrt()), 1) 188 | 189 | # p(x_0 | x_1) 190 | log_prob_x0_x1 = td.Independent(td.Normal(p_mu[0], betas[0].sqrt()), 1).log_prob(x) 191 | assert log_prob_x0_x1.shape == x.shape[:-1] 192 | 193 | # q(x_{t-1} | x_0, x_t), t > 1 194 | qx = get_qx(x.unsqueeze(0), xt[1:], alphas[1:], alphas[:-1], betas[1:]) 195 | 196 | # KL[q(x_{t-1} | x_0, x_t) || p(x_{t-1} | x_t)] 197 | kl_q_p = td.kl_divergence(qx, px).sum(0) 198 | assert kl_q_p.shape == x.shape[:-1] 199 | 200 | # ELBO 201 | elbo_contribution = (log_prob_x0_x1 - kl_q_p) / num_samples 202 | elbo += elbo_contribution 203 | 204 | elbo = reduce_elbo(elbo, x) 205 | return elbo 206 | 207 | def _elbo_sequential( 208 | self, 209 | model: Callable, 210 | x: TensorType[..., 'dim'], 211 | L: TensorType[..., 'seq_len', 'seq_len'], 212 | num_samples: int = 1, 213 | **kwargs, 214 | ) -> TensorType[..., 1]: 215 | """ 216 | Computes ELBO as a sum of diffusion steps - sequentially. 217 | 218 | Args: 219 | model: Denoising diffusion model 220 | x: Clean input data 221 | num_samples: How many times to compute ELBO, final 222 | result is averaged over all ELBO samples 223 | **kwargs: Can be time, latent etc.
depending on a model 224 | """ 225 | elbo = 0 226 | 227 | x, kwargs = expand_x_and_kwargs(x, kwargs, num_samples) 228 | 229 | for i in range(self.num_steps): 230 | # Prepare variables 231 | beta = self.betas[i].to(x) 232 | alpha = self.alphas[i].to(x) 233 | step = torch.Tensor([i]).expand_as(x[...,:1]).to(x) 234 | 235 | # Diffuse and predict noise 236 | xt, _ = self.forward(x, i=step, **kwargs) 237 | epsilon = model(xt, i=step, **kwargs) 238 | 239 | if L is not None: 240 | epsilon = L @ epsilon 241 | 242 | assert xt.shape == x.shape == epsilon.shape 243 | 244 | # p(x_{t-1} | x_t) 245 | p_mu = get_p_mu(xt, beta, alpha, epsilon) 246 | px = td.Independent(td.Normal(p_mu, beta.sqrt()), 1) 247 | 248 | if i == 0: 249 | elbo = elbo + px.log_prob(x).mean(0) 250 | else: 251 | prev_alpha = self.alphas[i - 1].to(x) 252 | 253 | # q(x_{t-1} | x_0, x_t), t > 1 254 | qx = get_qx(x, xt, alpha, prev_alpha, beta) 255 | 256 | # KL[q(x_{t-1} | x_0, x_t) || p(x_{t-1} | x_t)] 257 | kl = td.kl_divergence(qx, px).mean(0) 258 | elbo = elbo - kl 259 | 260 | elbo = reduce_elbo(elbo, x) 261 | return elbo 262 | 263 | 264 | class GaussianDiffusion(DiscreteDiffusion): 265 | """ Discrete diffusion with Gaussian noise """ 266 | def __init__(self, dim: int, num_steps: int, beta_fn: Callable, **kwargs): 267 | super().__init__(dim, num_steps, beta_fn, noise_fn=Normal(dim), **kwargs) 268 | 269 | 270 | class OUDiffusion(DiscreteDiffusion): 271 | """ Discrete diffusion with noise coming from an OU process """ 272 | def __init__( 273 | self, 274 | dim: int, 275 | num_steps: int, 276 | beta_fn: Callable, 277 | predict_gaussian_noise: bool, 278 | theta: float = 0.5, 279 | **kwargs, 280 | ): 281 | super().__init__( 282 | dim=dim, 283 | num_steps=num_steps, 284 | beta_fn=beta_fn, 285 | noise_fn=OrnsteinUhlenbeck(dim, theta=theta), 286 | is_time_series=True, 287 | predict_gaussian_noise=predict_gaussian_noise, 288 | **kwargs, 289 | ) 290 | 291 | 292 | class GPDiffusion(DiscreteDiffusion): 293 | """ Discrete diffusion 
with noise coming from a Gaussian process """ 294 | def __init__( 295 | self, 296 | dim: int, 297 | num_steps: int, 298 | beta_fn: Callable, 299 | predict_gaussian_noise: bool, 300 | sigma: float = 0.1, 301 | **kwargs, 302 | ): 303 | super().__init__( 304 | dim=dim, 305 | num_steps=num_steps, 306 | beta_fn=beta_fn, 307 | noise_fn=GaussianProcess(dim, sigma=sigma), 308 | is_time_series=True, 309 | predict_gaussian_noise=predict_gaussian_noise, 310 | **kwargs, 311 | ) 312 | 313 | 314 | def expand_to_x(inputs, x): 315 | return inputs.view(-1, *(1,) * len(x.shape)).to(x) 316 | 317 | def expand_x_and_kwargs(x, kwargs, N): 318 | # Expand dimensions 319 | x = x.unsqueeze(0).repeat_interleave(N, dim=0) 320 | 321 | # A hacky solution to repeat dimensions in all kwargs (latent, t, etc.) 322 | for key, value in kwargs.items(): 323 | if torch.is_tensor(value): 324 | kwargs[key] = value.unsqueeze(0).repeat_interleave(N, dim=0) 325 | 326 | return x, kwargs 327 | 328 | def reduce_elbo( 329 | elbo: TensorType['batch', Any], 330 | x: TensorType[Any], 331 | ) -> TensorType['batch', 1]: 332 | # Reduce ELBO over all but batch dimension: (B, ...) 
-> (B,) 333 | elbo = elbo.view(elbo.shape[0], -1).sum(1) 334 | 335 | if len(x.shape) > 2: 336 | elbo = elbo / x.shape[-2] 337 | 338 | return elbo.unsqueeze(1) 339 | 340 | def get_p_mu(xt, beta, alpha, epsilon): 341 | mu = 1 / (1 - beta).sqrt() * (xt - beta / (1 - alpha).sqrt() * epsilon) 342 | return mu 343 | 344 | def get_qx(x, xt, alpha, prev_alpha, beta): 345 | q_mu_1 = torch.sqrt(prev_alpha) * beta / (1 - alpha) * x 346 | q_mu_2 = torch.sqrt(1 - beta) * (1 - prev_alpha) / (1 - alpha) * xt 347 | q_mu = q_mu_1 + q_mu_2 348 | 349 | q_sigma = beta * (1 - prev_alpha) / (1 - alpha) 350 | 351 | qx = td.Independent(td.Normal(q_mu, q_sigma.expand_as(q_mu).sqrt()), 1) 352 | return qx 353 | -------------------------------------------------------------------------------- /tsdiff/diffusion/noise.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from torchtyping import TensorType 3 | 4 | import numpy as np 5 | import scipy.fftpack 6 | from functools import lru_cache 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class Normal(nn.Module): 13 | def __init__(self, dim: int, **kwargs): 14 | super().__init__() 15 | self.dim = dim 16 | 17 | def forward(self, *shape, **kwargs): 18 | return torch.randn(*shape, self.dim) 19 | 20 | def covariance(self, **kwargs): 21 | return torch.eye(self.dim) 22 | 23 | 24 | class Wiener(nn.Module): 25 | """ 26 | Wiener process / Brownian motion. 
27 | """ 28 | def __init__(self, dim: int): 29 | super().__init__() 30 | self.dim = dim 31 | 32 | def forward( 33 | self, 34 | t: Union[TensorType['seq_len'], TensorType[..., 'seq_len', 1]], 35 | **kwargs, 36 | ) -> Union[TensorType['seq_len'], TensorType[..., 'seq_len', 'dim']]: 37 | one_dimensional = len(t.shape) == 1 38 | 39 | if one_dimensional: 40 | t = t.unsqueeze(-1) 41 | t = t.repeat_interleave(self.dim, dim=-1) 42 | 43 | dt = torch.diff(t, dim=-2, prepend=torch.zeros_like(t[...,:1,:]).to(t)) 44 | dw = torch.randn_like(dt) * dt.clamp(1e-5).sqrt() 45 | w = dw.cumsum(dim=-2) 46 | 47 | if one_dimensional and self.dim == 1: 48 | w = w.squeeze(-1) 49 | return w 50 | 51 | 52 | class OrnsteinUhlenbeck(nn.Module): 53 | """ 54 | Ornstein-Uhlenbeck process. 55 | 56 | Args: 57 | theta: Diffusion param, higher value = spikier (float) 58 | """ 59 | def __init__(self, dim: int, theta: float = 0.5): 60 | super().__init__() 61 | self.dim = dim 62 | self.theta = theta 63 | self.wiener = Wiener(dim) 64 | 65 | def forward( 66 | self, 67 | *args, 68 | t: TensorType[..., 'seq_len', 1], 69 | **kwargs, 70 | ) -> TensorType[..., 'seq_len', 'dim']: 71 | 72 | delta = torch.diff(t, dim=-2, prepend=torch.zeros_like(t[...,:1,:])) 73 | coeff = torch.exp(-self.theta * delta) 74 | 75 | sample = [] 76 | 77 | x = torch.randn(*t.shape[:-2], 1, self.dim).to(t) 78 | for i in range(coeff.shape[-2]): 79 | z = torch.randn(*t.shape[:-2], 1, self.dim).to(t) 80 | c = coeff[...,i,None,:] 81 | x = c * x + torch.sqrt(1 - c**2) * z 82 | sample.append(x) 83 | 84 | sample = torch.cat(sample, dim=-2) 85 | return sample 86 | 87 | def covariance( 88 | self, 89 | t: TensorType[..., 'seq_len', 1], 90 | diag_epsilon: float = 1e-4, 91 | **kwargs, 92 | ) -> TensorType[..., 'seq_len', 'seq_len']: 93 | t = t.squeeze(-1) 94 | diag = torch.eye(t.shape[-1]).to(t) * diag_epsilon 95 | cov = torch.exp(-(t.unsqueeze(-1) - t.unsqueeze(-2)).abs() * self.theta) 96 | return cov + diag 97 | 98 | def covariance_cholesky(self, 
t: TensorType[..., 'seq_len', 1]) -> TensorType[..., 'seq_len', 'seq_len']: 99 | return torch.linalg.cholesky(self.covariance(t)) 100 | 101 | def covariance_inverse(self, t: TensorType[..., 'seq_len', 1]) -> TensorType[..., 'seq_len', 'seq_len']: 102 | return torch.linalg.inv(self.covariance(t)) 103 | 104 | 105 | class GaussianProcess(nn.Module): 106 | """ 107 | Gaussian random field for one-dimensional (temporal) data. 108 | """ 109 | def __init__(self, dim: int, sigma: float = 0.1): 110 | super().__init__() 111 | self.dim = dim 112 | self.sigma = sigma 113 | 114 | def forward( 115 | self, 116 | *args, 117 | t: TensorType[..., 'N', 1], 118 | **kwargs, 119 | ) -> TensorType[..., 'N', 'dim']: 120 | # If N is very large this could become slow 121 | # In that case, consider using sparse GP 122 | L = self.covariance_cholesky(t) 123 | e = torch.randn(*t.shape[:-1], self.dim).to(t) 124 | return L @ e 125 | 126 | def covariance( 127 | self, 128 | t: TensorType[..., 'N', 1], 129 | diag_epsilon: float = 1e-4, 130 | **kwargs, 131 | ) -> TensorType[..., 'N', 'N']: 132 | if t.shape[-1] != 1 or len(t.shape) < 2: 133 | t = t.unsqueeze(-1) 134 | distance = t - t.transpose(-1, -2) 135 | diag = torch.eye(t.shape[-2]).to(t) * diag_epsilon 136 | return torch.exp(-torch.square(distance / self.sigma)) + diag 137 | 138 | def covariance_cholesky(self, t: TensorType[..., 'N', 1]) -> TensorType[..., 'N', 'N']: 139 | return torch.linalg.cholesky(self.covariance(t)) 140 | 141 | def covariance_inverse(self, t: TensorType[..., 'N', 1]) -> TensorType[..., 'N', 'N']: 142 | return torch.linalg.inv(self.covariance(t)) 143 | -------------------------------------------------------------------------------- /tsdiff/forecasting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbilos/tsdiff/32b23f2b7f5ec4d68bc533dda8a0096086cbd0ab/tsdiff/forecasting/__init__.py 
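The time-series noise classes in `tsdiff/diffusion/noise.py` reduce to two steps: build a covariance matrix over the observation times, then multiply its Cholesky factor by white noise. Below is a minimal, dependency-free sketch of that idea for the OU kernel. It assumes plain Python lists instead of torch tensors and a single one-dimensional series; the helper names `ou_covariance`, `cholesky`, and `correlated_noise` and the example time grid are illustrative and not part of the package.

```python
import math
import random


def ou_covariance(t, theta=0.5, diag_epsilon=1e-4):
    """OU kernel exp(-theta * |t_i - t_j|) plus a small diagonal jitter."""
    n = len(t)
    return [
        [math.exp(-theta * abs(t[i] - t[j])) + (diag_epsilon if i == j else 0.0)
         for j in range(n)]
        for i in range(n)
    ]


def cholesky(a):
    """Return lower-triangular L such that L @ L^T equals a."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def correlated_noise(L, rng):
    """Map white noise z ~ N(0, I) to L @ z, which has covariance L @ L^T."""
    n = len(L)
    z = [rng.gauss(0.0, 1.0) for _ in range(n)]
    return [sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(n)]


# Irregularly spaced observation times (hypothetical values).
times = [0.0, 0.1, 0.5, 0.6, 2.0]
cov = ou_covariance(times)
L = cholesky(cov)
noise = correlated_noise(L, random.Random(0))
```

Nearby times such as 0.5 and 0.6 receive strongly correlated noise, while the point at 2.0 is more weakly tied to the rest. The diagonal jitter plays the same role as `diag_epsilon` in the package: it keeps the Cholesky factorization numerically stable when two times nearly coincide.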
-------------------------------------------------------------------------------- /tsdiff/forecasting/experiment.py: -------------------------------------------------------------------------------- 1 | import seml 2 | from sacred import Experiment 3 | from tsdiff.forecasting.train import train 4 | 5 | ex = Experiment() 6 | seml.setup_logger(ex) 7 | 8 | @ex.config 9 | def config(): 10 | overwrite = None 11 | db_collection = None 12 | if db_collection is not None: 13 | ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite)) 14 | 15 | @ex.automain 16 | def run( 17 | seed: int, 18 | dataset: str, 19 | network: str, 20 | noise: str, 21 | diffusion_steps: int, 22 | epochs: int, 23 | batch_size: int, 24 | learning_rate: float, 25 | num_cells: int, 26 | hidden_dim: int, 27 | residual_layers: int, 28 | ): 29 | results = train(**locals()) 30 | return results 31 | -------------------------------------------------------------------------------- /tsdiff/forecasting/experiment.yaml: -------------------------------------------------------------------------------- 1 | seml: 2 | executable: experiments/forecasting/experiment.py 3 | name: tsdiff_forecast 4 | output_dir: experiments/logs 5 | project_root_dir: ../../ 6 | conda_environment: place_your_env_here 7 | 8 | slurm: 9 | experiments_per_job: 1 10 | sbatch_options: 11 | gres: gpu:1 # num GPUs 12 | mem: 16G # memory 13 | cpus-per-task: 2 # num cores 14 | time: 0-08:00 # max time, D-HH:MM 15 | partition: gpu_all 16 | 17 | 18 | fixed: 19 | epochs: 10 20 | learning_rate: 1e-3 21 | batch_size: 64 22 | 23 | grid: 24 | seed: 25 | type: range 26 | min: 1 27 | max: 4 28 | step: 1 29 | 30 | dataset: 31 | type: choice 32 | options: 33 | - electricity_nips 34 | - solar_nips 35 | - traffic_nips 36 | - exchange_rate_nips 37 | 38 | diffusion_steps: 39 | type: choice 40 | options: 41 | - 100 42 | 43 | num_cells: 44 | type: choice 45 | options: 46 | - 40 47 | 48 | hidden_dim: 49 | type: choice 50 | options: 51 | - 
100 52 | 53 | residual_layers: 54 | type: choice 55 | options: 56 | - 8 57 | 58 | old: 59 | grid: 60 | network: 61 | type: choice 62 | options: 63 | - timegrad 64 | - timegrad_old 65 | 66 | noise: 67 | type: choice 68 | options: 69 | - normal 70 | 71 | cnn: 72 | grid: 73 | network: 74 | type: choice 75 | options: 76 | - timegrad_cnn 77 | 78 | noise: 79 | type: choice 80 | options: 81 | - normal 82 | - ou 83 | - gp 84 | -------------------------------------------------------------------------------- /tsdiff/forecasting/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .score_estimator import * 2 | from .score_network import * 3 | from .time_grad_network import * 4 | -------------------------------------------------------------------------------- /tsdiff/forecasting/models/score_estimator.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional 2 | 3 | import torch 4 | 5 | from gluonts.dataset.field_names import FieldName 6 | from gluonts.time_feature import TimeFeature 7 | from gluonts.torch.model.predictor import PyTorchPredictor 8 | from gluonts.torch.util import copy_parameters 9 | from gluonts.model.predictor import Predictor 10 | from gluonts.transform import ( 11 | Transformation, 12 | Chain, 13 | InstanceSplitter, 14 | ExpectedNumInstanceSampler, 15 | ValidationSplitSampler, 16 | TestSplitSampler, 17 | RenameFields, 18 | AsNumpyArray, 19 | ExpandDimArray, 20 | AddObservedValuesIndicator, 21 | AddTimeFeatures, 22 | VstackFeatures, 23 | SetFieldIfNotPresent, 24 | TargetDimIndicator, 25 | ) 26 | from gluonts.core.component import validated 27 | 28 | from pts.feature import ( 29 | fourier_time_features_from_frequency, 30 | lags_for_fourier_time_features_from_frequency, 31 | ) 32 | from pts.model import PyTorchEstimator 33 | from pts.model.utils import get_module_forward_input_names 34 | 35 | from tsdiff.utils import 
TrainerForecasting 36 | 37 | 38 | class ScoreEstimator(PyTorchEstimator): 39 | def __init__( 40 | self, 41 | training_net: Callable, 42 | prediction_net: Callable, 43 | noise: str, 44 | input_size: int, 45 | freq: str, 46 | prediction_length: int, 47 | target_dim: int, 48 | trainer: TrainerForecasting = TrainerForecasting(), 49 | context_length: Optional[int] = None, 50 | num_layers: int = 2, 51 | num_cells: int = 40, 52 | cell_type: str = "GRU", 53 | num_parallel_samples: int = 100, 54 | dropout_rate: float = 0.1, 55 | cardinality: List[int] = [1], 56 | embedding_dimension: int = 5, 57 | hidden_dim: int = 100, 58 | diff_steps: int = 100, 59 | loss_type: str = "l2", 60 | beta_end=0.1, 61 | beta_schedule="linear", 62 | residual_layers=8, 63 | residual_channels=8, 64 | dilation_cycle_length=2, 65 | scaling: bool = True, 66 | pick_incomplete: bool = False, 67 | lags_seq: Optional[List[int]] = None, 68 | time_features: Optional[List[TimeFeature]] = None, 69 | old: bool = False, 70 | time_feat_dim: int = 4, 71 | **kwargs, 72 | ) -> None: 73 | super().__init__(trainer=trainer, **kwargs) 74 | 75 | self.training_net = training_net 76 | self.prediction_net = prediction_net 77 | self.noise = noise 78 | 79 | self.old = old 80 | 81 | self.freq = freq 82 | self.context_length = context_length if context_length is not None else prediction_length 83 | 84 | self.input_size = input_size 85 | self.prediction_length = prediction_length 86 | self.target_dim = target_dim 87 | self.time_feat_dim = time_feat_dim 88 | self.num_layers = num_layers 89 | self.num_cells = num_cells 90 | self.cell_type = cell_type 91 | self.num_parallel_samples = num_parallel_samples 92 | self.dropout_rate = dropout_rate 93 | self.cardinality = cardinality 94 | self.embedding_dimension = embedding_dimension 95 | 96 | self.conditioning_length = hidden_dim 97 | self.diff_steps = diff_steps 98 | self.loss_type = loss_type 99 | self.beta_end = beta_end 100 | self.beta_schedule = beta_schedule 101 | 
self.residual_layers = residual_layers 102 | self.residual_channels = residual_channels 103 | self.dilation_cycle_length = dilation_cycle_length 104 | 105 | self.lags_seq = ( 106 | lags_seq 107 | if lags_seq is not None 108 | else lags_for_fourier_time_features_from_frequency(freq_str=freq) 109 | ) 110 | 111 | self.time_features = ( 112 | time_features 113 | if time_features is not None 114 | else fourier_time_features_from_frequency(self.freq) 115 | ) 116 | 117 | self.history_length = self.context_length + max(self.lags_seq) 118 | self.pick_incomplete = pick_incomplete 119 | self.scaling = scaling 120 | 121 | self.train_sampler = ExpectedNumInstanceSampler( 122 | num_instances=1.0, 123 | min_past=0 if pick_incomplete else self.history_length, 124 | min_future=prediction_length, 125 | ) 126 | 127 | self.validation_sampler = ValidationSplitSampler( 128 | min_past=0 if pick_incomplete else self.history_length, 129 | min_future=prediction_length, 130 | ) 131 | 132 | def create_transformation(self) -> Transformation: 133 | return Chain( 134 | [ 135 | AsNumpyArray( 136 | field=FieldName.TARGET, 137 | expected_ndim=2, 138 | ), 139 | # maps the target to (1, T) 140 | # if the target data is uni dimensional 141 | ExpandDimArray( 142 | field=FieldName.TARGET, 143 | axis=None, 144 | ), 145 | AddObservedValuesIndicator( 146 | target_field=FieldName.TARGET, 147 | output_field=FieldName.OBSERVED_VALUES, 148 | ), 149 | AddTimeFeatures( 150 | start_field=FieldName.START, 151 | target_field=FieldName.TARGET, 152 | output_field=FieldName.FEAT_TIME, 153 | time_features=self.time_features, 154 | pred_length=self.prediction_length, 155 | ), 156 | VstackFeatures( 157 | output_field=FieldName.FEAT_TIME, 158 | input_fields=[FieldName.FEAT_TIME], 159 | ), 160 | SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]), 161 | TargetDimIndicator( 162 | field_name="target_dimension_indicator", 163 | target_field=FieldName.TARGET, 164 | ), 165 | 
AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), 166 | ] 167 | ) 168 | 169 | def create_instance_splitter(self, mode: str): 170 | assert mode in ["training", "validation", "test"] 171 | 172 | instance_sampler = { 173 | "training": self.train_sampler, 174 | "validation": self.validation_sampler, 175 | "test": TestSplitSampler(), 176 | }[mode] 177 | 178 | return InstanceSplitter( 179 | target_field=FieldName.TARGET, 180 | is_pad_field=FieldName.IS_PAD, 181 | start_field=FieldName.START, 182 | forecast_start_field=FieldName.FORECAST_START, 183 | instance_sampler=instance_sampler, 184 | past_length=self.history_length, 185 | future_length=self.prediction_length, 186 | time_series_fields=[ 187 | FieldName.FEAT_TIME, 188 | FieldName.OBSERVED_VALUES, 189 | ], 190 | ) + ( 191 | RenameFields( 192 | { 193 | f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", 194 | f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", 195 | } 196 | ) 197 | ) 198 | 199 | def create_training_network(self, device: torch.device): 200 | return self.training_net( 201 | noise=self.noise, 202 | input_size=self.input_size, 203 | target_dim=self.target_dim, 204 | num_layers=self.num_layers, 205 | num_cells=self.num_cells, 206 | cell_type=self.cell_type, 207 | history_length=self.history_length, 208 | context_length=self.context_length, 209 | prediction_length=self.prediction_length, 210 | dropout_rate=self.dropout_rate, 211 | cardinality=self.cardinality, 212 | embedding_dimension=self.embedding_dimension, 213 | diff_steps=self.diff_steps, 214 | loss_type=self.loss_type, 215 | beta_end=self.beta_end, 216 | beta_schedule=self.beta_schedule, 217 | residual_layers=self.residual_layers, 218 | residual_channels=self.residual_channels, 219 | dilation_cycle_length=self.dilation_cycle_length, 220 | lags_seq=self.lags_seq, 221 | scaling=self.scaling, 222 | conditioning_length=self.conditioning_length, 223 | time_feat_dim=self.time_feat_dim, 224 | ).to(device) 225 | 226 | def 
create_predictor( 227 | self, 228 | transformation: Transformation, 229 | trained_network: Any, 230 | device: torch.device, 231 | ) -> Predictor: 232 | prediction_network = self.prediction_net( 233 | noise=self.noise, 234 | input_size=self.input_size, 235 | target_dim=self.target_dim, 236 | num_layers=self.num_layers, 237 | num_cells=self.num_cells, 238 | cell_type=self.cell_type, 239 | history_length=self.history_length, 240 | context_length=self.context_length, 241 | prediction_length=self.prediction_length, 242 | dropout_rate=self.dropout_rate, 243 | cardinality=self.cardinality, 244 | embedding_dimension=self.embedding_dimension, 245 | diff_steps=self.diff_steps, 246 | loss_type=self.loss_type, 247 | beta_end=self.beta_end, 248 | beta_schedule=self.beta_schedule, 249 | residual_layers=self.residual_layers, 250 | residual_channels=self.residual_channels, 251 | dilation_cycle_length=self.dilation_cycle_length, 252 | lags_seq=self.lags_seq, 253 | scaling=self.scaling, 254 | conditioning_length=self.conditioning_length, 255 | num_parallel_samples=self.num_parallel_samples, 256 | time_feat_dim=self.time_feat_dim, 257 | ).to(device) 258 | 259 | copy_parameters(trained_network, prediction_network) 260 | input_names = get_module_forward_input_names(prediction_network) 261 | prediction_splitter = self.create_instance_splitter("test") 262 | 263 | return PyTorchPredictor( 264 | input_transform=transformation + prediction_splitter, 265 | input_names=input_names, 266 | prediction_net=prediction_network, 267 | batch_size=self.trainer.batch_size, 268 | freq=self.freq, 269 | prediction_length=self.prediction_length, 270 | device=device, 271 | ) 272 | -------------------------------------------------------------------------------- /tsdiff/forecasting/models/score_network.py: -------------------------------------------------------------------------------- 1 | from torchtyping import TensorType 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 
7 | from pts.model import weighted_average 8 | from pts.modules import MeanScaler 9 | 10 | from tsdiff.diffusion import OUDiffusion, BetaLinear 11 | 12 | 13 | class ScoreTrainingNetwork(nn.Module): 14 | """ 15 | Score training network. 16 | 17 | Args: 18 | context_length: Size of history 19 | prediction_length: Size of prediction 20 | target_dim: Dimension of data 21 | time_feat_dim: Dimension of covariates 22 | conditioning_length: Hidden dimension 23 | beta_end: Final diffusion scale 24 | diff_steps: Number of diffusion steps 25 | residual_layers: Number of residual layers 26 | residual_channels: Number of residual channels 27 | dilation_cycle_length: Dilation cycle length 28 | """ 29 | def __init__( 30 | self, 31 | context_length: int, 32 | prediction_length: int, 33 | target_dim: int, 34 | time_feat_dim: int, 35 | conditioning_length: int, 36 | beta_end: float, 37 | diff_steps: int, 38 | residual_layers: int, 39 | residual_channels: int, 40 | dilation_cycle_length: int, 41 | **kwargs, 42 | ): 43 | super().__init__() 44 | self.context_length = context_length 45 | self.prediction_length = prediction_length 46 | 47 | # hidden_dim = conditioning_length 48 | # self.context_rnn = nn.GRU(target_dim + time_feat_dim, hidden_dim, num_layers=2, bidirectional=True, batch_first=True) 49 | 50 | self.diffusion = OUDiffusion(target_dim, BetaLinear(1e-4, beta_end), diff_steps) 51 | self.denoise_fn = DenoisingModel( 52 | dim=target_dim + time_feat_dim, 53 | residual_channels=residual_channels, 54 | latent_dim=conditioning_length, 55 | residual_hidden=conditioning_length, 56 | ) 57 | 58 | self.scaler = MeanScaler(keepdim=True) 59 | 60 | def forward( 61 | self, 62 | target_dimension_indicator: TensorType['batch', 'dim'], 63 | past_time_feat: TensorType['batch', 'history_length', 'feat_dim'], 64 | past_target_cdf: TensorType['batch', 'history_length', 'dim'], 65 | past_observed_values: TensorType['batch', 'history_length', 'dim'], 66 | past_is_pad: TensorType['batch', 
'history_length'], 67 | future_time_feat: TensorType['batch', 'prediction_length', 'feat_dim'], 68 | future_target_cdf: TensorType['batch', 'prediction_length', 'dim'], 69 | future_observed_values: TensorType['batch', 'prediction_length', 'dim'], 70 | ) -> TensorType[()]: 71 | 72 | past_time_feat = past_time_feat[...,-self.context_length:,:] 73 | past_target_cdf = past_target_cdf[...,-self.context_length:,:] 74 | past_observed_values = past_observed_values[...,-self.context_length:,:] 75 | past_is_pad = past_is_pad[...,-self.context_length:] 76 | 77 | past_observed_values = torch.min(past_observed_values, 1 - past_is_pad.unsqueeze(-1)) 78 | _, scale = self.scaler(past_target_cdf, past_observed_values) 79 | 80 | history = past_target_cdf / scale 81 | target = future_target_cdf / scale 82 | 83 | t = torch.arange(self.prediction_length).view(1, -1, 1).repeat(target.shape[0], 1, 1).to(target) 84 | 85 | loss = self.diffusion.get_loss(self.denoise_fn, target, t=t, history=history, covariates=future_time_feat) 86 | 87 | loss_weights, _ = future_observed_values.min(dim=-1, keepdim=True) 88 | loss = weighted_average(loss, weights=loss_weights, dim=1) 89 | 90 | return loss.mean() 91 | 92 | 93 | class ScorePredictionNetwork(ScoreTrainingNetwork): 94 | def __init__(self, num_parallel_samples: int, **kwargs) -> None: 95 | super().__init__(**kwargs) 96 | self.num_parallel_samples = num_parallel_samples 97 | 98 | def forward( 99 | self, 100 | target_dimension_indicator: TensorType['batch', 'dim'], 101 | past_time_feat: TensorType['batch', 'history_length', 'feat_dim'], 102 | past_target_cdf: TensorType['batch', 'history_length', 'dim'], 103 | past_observed_values: TensorType['batch', 'history_length', 'dim'], 104 | past_is_pad: TensorType['batch', 'history_length'], 105 | future_time_feat: TensorType['batch', 'prediction_length', 'feat_dim'], 106 | ) -> TensorType['batch', 'num_samples', 'prediction_length', 'dim']: 107 | 108 | past_observed_values = 
torch.min(past_observed_values, 1 - past_is_pad.unsqueeze(-1)) 109 | 110 | rnn_states, scale = self.get_rnn_state( 111 | past_time_feat=past_time_feat, 112 | past_target_cdf=past_target_cdf, 113 | past_observed_values=past_observed_values, 114 | future_time_feat=future_time_feat, 115 | ) 116 | 117 | t = torch.arange(self.prediction_length).view(1, -1, 1) 118 | t = t.repeat(rnn_states.shape[0] * self.num_parallel_samples, 1, 1).to(rnn_states) 119 | 120 | rnn_states = rnn_states.repeat_interleave(self.num_parallel_samples, dim=0) 121 | 122 | samples = self.diffusion.sample(self.denoise_fn, t=t, latent=rnn_states) 123 | samples = samples.unflatten(0, (-1, self.num_parallel_samples)) * scale.unsqueeze(1) 124 | 125 | return samples 126 | -------------------------------------------------------------------------------- /tsdiff/forecasting/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | from copy import deepcopy 5 | 6 | from gluonts.dataset.multivariate_grouper import MultivariateGrouper 7 | from gluonts.dataset.repository.datasets import get_dataset 8 | from gluonts.evaluation.backtest import make_evaluation_predictions 9 | from gluonts.evaluation import MultivariateEvaluator 10 | 11 | from tsdiff.forecasting.models import ( 12 | ScoreEstimator, 13 | TimeGradTrainingNetwork_AutoregressiveOld, TimeGradPredictionNetwork_AutoregressiveOld, 14 | TimeGradTrainingNetwork_Autoregressive, TimeGradPredictionNetwork_Autoregressive, 15 | TimeGradTrainingNetwork_All, TimeGradPredictionNetwork_All, 16 | TimeGradTrainingNetwork_RNN, TimeGradPredictionNetwork_RNN, 17 | TimeGradTrainingNetwork_Transformer, TimeGradPredictionNetwork_Transformer, 18 | TimeGradTrainingNetwork_CNN, TimeGradPredictionNetwork_CNN, 19 | ) 20 | from tsdiff.utils import NotSupportedModelNoiseCombination, TrainerForecasting 21 | 22 | import warnings 23 | warnings.simplefilter(action='ignore', 
category=FutureWarning) 24 | 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | 27 | def energy_score(forecast, target): 28 | obs_dist = np.mean(np.linalg.norm((forecast - target), axis=-1)) 29 | pair_dist = np.mean( 30 | np.linalg.norm(forecast[:, np.newaxis, ...] - forecast, axis=-1) 31 | ) 32 | return obs_dist - pair_dist * 0.5 33 | 34 | def train( 35 | seed: int, 36 | dataset: str, 37 | network: str, 38 | noise: str, 39 | diffusion_steps: int, 40 | epochs: int, 41 | learning_rate: float, 42 | batch_size: int, 43 | num_cells: int, 44 | hidden_dim: int, 45 | residual_layers: int, 46 | ): 47 | np.random.seed(seed) 48 | torch.manual_seed(seed) 49 | 50 | covariance_dim = 4 if dataset != 'exchange_rate_nips' else -4 51 | 52 | # Load data 53 | dataset = get_dataset(dataset, regenerate=False) 54 | 55 | target_dim = int(dataset.metadata.feat_static_cat[0].cardinality) 56 | 57 | train_grouper = MultivariateGrouper(max_target_dim=min(2000, target_dim)) 58 | test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test) / len(dataset.train)), max_target_dim=min(2000, target_dim)) 59 | dataset_train = train_grouper(dataset.train) 60 | dataset_test = test_grouper(dataset.test) 61 | 62 | val_window = 20 * dataset.metadata.prediction_length 63 | dataset_train = list(dataset_train) 64 | dataset_val = [] 65 | for i in range(len(dataset_train)): 66 | x = deepcopy(dataset_train[i]) 67 | x['target'] = x['target'][:,-val_window:] 68 | dataset_val.append(x) 69 | dataset_train[i]['target'] = dataset_train[i]['target'][:,:-val_window] 70 | 71 | # Load model 72 | if network == 'timegrad': 73 | if noise != 'normal': 74 | raise NotSupportedModelNoiseCombination 75 | training_net, prediction_net = TimeGradTrainingNetwork_Autoregressive, TimeGradPredictionNetwork_Autoregressive 76 | elif network == 'timegrad_old': 77 | if noise != 'normal': 78 | raise NotSupportedModelNoiseCombination 79 | training_net, prediction_net = 
TimeGradTrainingNetwork_AutoregressiveOld, TimeGradPredictionNetwork_AutoregressiveOld 80 | elif network == 'timegrad_all': 81 | training_net, prediction_net = TimeGradTrainingNetwork_All, TimeGradPredictionNetwork_All 82 | elif network == 'timegrad_rnn': 83 | training_net, prediction_net = TimeGradTrainingNetwork_RNN, TimeGradPredictionNetwork_RNN 84 | elif network == 'timegrad_transformer': 85 | training_net, prediction_net = TimeGradTrainingNetwork_Transformer, TimeGradPredictionNetwork_Transformer 86 | elif network == 'timegrad_cnn': 87 | training_net, prediction_net = TimeGradTrainingNetwork_CNN, TimeGradPredictionNetwork_CNN 88 | 89 | estimator = ScoreEstimator( 90 | training_net=training_net, 91 | prediction_net=prediction_net, 92 | noise=noise, 93 | target_dim=target_dim, 94 | prediction_length=dataset.metadata.prediction_length, 95 | context_length=dataset.metadata.prediction_length, 96 | cell_type='GRU', 97 | num_cells=num_cells, 98 | hidden_dim=hidden_dim, 99 | residual_layers=residual_layers, 100 | input_size=target_dim * 4 + covariance_dim, 101 | freq=dataset.metadata.freq, 102 | loss_type='l2', 103 | scaling=True, 104 | diff_steps=diffusion_steps, 105 | beta_end=20 / diffusion_steps, 106 | beta_schedule='linear', 107 | num_parallel_samples=100, 108 | pick_incomplete=True, 109 | trainer=TrainerForecasting( 110 | device=device, 111 | epochs=epochs, 112 | learning_rate=learning_rate, 113 | num_batches_per_epoch=100, 114 | batch_size=batch_size, 115 | patience=10, 116 | ), 117 | ) 118 | 119 | # Training 120 | predictor = estimator.train(dataset_train, dataset_val, num_workers=8) 121 | 122 | # Evaluation 123 | forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test, predictor=predictor, num_samples=100) 124 | forecasts = list(forecast_it) 125 | targets = list(ts_it) 126 | 127 | score = energy_score( 128 | forecast=np.array([x.samples for x in forecasts]), 129 | target=np.array([x[-dataset.metadata.prediction_length:] for x in 
targets])[:,None,...], 130 | ) 131 | 132 | evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:], target_agg_funcs={'sum': np.sum}) 133 | agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test)) 134 | 135 | metrics = dict( 136 | CRPS=agg_metric['mean_wQuantileLoss'], 137 | ND=agg_metric['ND'], 138 | NRMSE=agg_metric['NRMSE'], 139 | CRPS_sum=agg_metric['m_sum_mean_wQuantileLoss'], 140 | ND_sum=agg_metric['m_sum_ND'], 141 | NRMSE_sum=agg_metric['m_sum_NRMSE'], 142 | energy_score=score, 143 | ) 144 | metrics = { k: float(v) for k,v in metrics.items() } 145 | 146 | return metrics 147 | 148 | 149 | if __name__ == '__main__': 150 | parser = argparse.ArgumentParser(description='Train forecasting model.') 151 | parser.add_argument('--seed', type=int, default=1) 152 | parser.add_argument('--dataset', type=str) 153 | parser.add_argument('--network', type=str, choices=[ 154 | 'timegrad', 'timegrad_old', 'timegrad_all', 'timegrad_rnn', 'timegrad_transformer', 'timegrad_cnn' 155 | ]) 156 | parser.add_argument('--noise', type=str, choices=['normal', 'ou', 'gp']) 157 | parser.add_argument('--diffusion_steps', type=int, default=100) 158 | parser.add_argument('--epochs', type=int, default=100) 159 | parser.add_argument('--learning_rate', type=float, default=1e-3) 160 | parser.add_argument('--batch_size', type=int, default=64) 161 | parser.add_argument('--num_cells', type=int, default=100) 162 | parser.add_argument('--hidden_dim', type=int, default=100) 163 | parser.add_argument('--residual_layers', type=int, default=8) 164 | args = parser.parse_args() 165 | 166 | metrics = train(**args.__dict__) 167 | 168 | for key, value in metrics.items(): 169 | print(f'{key}:\t{value:.4f}') 170 | 171 | # Example: 172 | # python -m tsdiff.forecasting.train --seed 1 --dataset electricity_nips --network timegrad_rnn --noise ou --epochs 100 173 | -------------------------------------------------------------------------------- /tsdiff/forecasting/train_deepvar.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import argparse 4 | 5 | import torch 6 | from gluonts.dataset.multivariate_grouper import MultivariateGrouper 7 | from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset 8 | from pts.model.deepvar import DeepVAREstimator 9 | from pts import Trainer 10 | from gluonts.evaluation.backtest import make_evaluation_predictions 11 | from gluonts.evaluation import MultivariateEvaluator 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | def energy_score(forecast, target): 15 | obs_dist = np.mean(np.linalg.norm((forecast - target), axis=-1)) 16 | pair_dist = np.mean( 17 | np.linalg.norm(forecast[:, np.newaxis, ...] - forecast, axis=-1) 18 | ) 19 | return obs_dist - pair_dist * 0.5 20 | 21 | def train(dataset_name): 22 | covariance_dim = 4 if dataset_name != 'exchange_rate_nips' else -4 23 | 24 | dataset = get_dataset(dataset_name, regenerate=False) 25 | 26 | train_grouper = MultivariateGrouper(max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality)) 27 | test_grouper = MultivariateGrouper(num_test_dates=int(len(dataset.test)/len(dataset.train)), max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality)) 28 | 29 | dataset_train = train_grouper(dataset.train) 30 | dataset_test = test_grouper(dataset.test) 31 | 32 | evaluator = MultivariateEvaluator(quantiles=(np.arange(20)/20.0)[1:], target_agg_funcs={'sum': np.sum}) 33 | 34 | target_dim = int(dataset.metadata.feat_static_cat[0].cardinality) 35 | 36 | estimator = DeepVAREstimator( 37 | input_size=target_dim * 4 + covariance_dim + 3, 38 | target_dim=target_dim, 39 | prediction_length=dataset.metadata.prediction_length, 40 | context_length=dataset.metadata.prediction_length*4, 41 | freq=dataset.metadata.freq, 42 | trainer=Trainer( 43 | device=device, 44 | epochs=40, 45 | learning_rate=1e-3, 46 | num_batches_per_epoch=100, 47 
| batch_size=64, 48 | ) 49 | ) 50 | 51 | predictor = estimator.train(dataset_train) 52 | forecast_it, ts_it = make_evaluation_predictions(dataset=dataset_test, 53 | predictor=predictor, 54 | num_samples=100) 55 | forecasts = list(forecast_it) 56 | targets = list(ts_it) 57 | 58 | agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test)) 59 | 60 | score = energy_score( 61 | forecast=np.array([x.samples for x in forecasts]), 62 | target=np.array([x[-dataset.metadata.prediction_length:] for x in targets])[:,None,...], 63 | ) 64 | 65 | metrics = dict( 66 | CRPS=agg_metric['mean_wQuantileLoss'], 67 | ND=agg_metric['ND'], 68 | NRMSE=agg_metric['NRMSE'], 69 | CRPS_sum=agg_metric['m_sum_mean_wQuantileLoss'], 70 | ND_sum=agg_metric['m_sum_ND'], 71 | NRMSE_sum=agg_metric['m_sum_NRMSE'], 72 | energy_score=score, 73 | ) 74 | metrics = { k: float(v) for k,v in metrics.items() } 75 | 76 | return metrics 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser(description='Train forecasting model.') 80 | parser.add_argument('--dataset', type=str, choices=['exchange_rate_nips', 'electricity_nips', 'solar_nips', 'traffic_nips']) 81 | args = parser.parse_args() 82 | 83 | metrics = train(args.dataset) 84 | 85 | for key, value in metrics.items(): 86 | print(f'{key}:\t{value:.4f}') 87 | -------------------------------------------------------------------------------- /tsdiff/neural_process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbilos/tsdiff/32b23f2b7f5ec4d68bc533dda8a0096086cbd0ab/tsdiff/neural_process/__init__.py -------------------------------------------------------------------------------- /tsdiff/neural_process/experiment.py: -------------------------------------------------------------------------------- 1 | import seml 2 | from sacred import Experiment 3 | from tsdiff.neural_process.train import train 4 | 5 | ex = Experiment() 6 | seml.setup_logger(ex) 7 | 
8 | @ex.config 9 | def config(): 10 | overwrite = None 11 | db_collection = None 12 | if db_collection is not None: 13 | ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite)) 14 | 15 | @ex.automain 16 | def run(seed: int, use_gp: bool, param: float): 17 | _, result = train(seed=seed, gp=use_gp, param=param) 18 | return result 19 | -------------------------------------------------------------------------------- /tsdiff/neural_process/experiment.yaml: -------------------------------------------------------------------------------- 1 | seml: 2 | executable: experiments/neural_process/experiment.py 3 | name: tsdiff_np 4 | output_dir: experiments/logs 5 | project_root_dir: ../../ 6 | conda_environment: place_your_env_here 7 | 8 | slurm: 9 | experiments_per_job: 1 10 | sbatch_options: 11 | gres: gpu:1 # num GPUs 12 | mem: 16G # memory 13 | cpus-per-task: 2 # num cores 14 | time: 0-08:00 # max time, D-HH:MM 15 | partition: gpu_all 16 | 17 | 18 | grid: 19 | seed: 20 | type: range 21 | min: 1 22 | max: 6 23 | step: 1 24 | 25 | use_gp: 26 | type: choice 27 | options: 28 | - True 29 | - False 30 | 31 | param: 32 | type: choice 33 | options: 34 | - 0.002 35 | - 0.02 36 | - 0.2 37 | - 2 38 | -------------------------------------------------------------------------------- /tsdiff/neural_process/train.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from functools import partial 4 | from scipy.stats import multivariate_normal 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import DataLoader, Dataset 9 | from gluonts.evaluation.metrics import mse 10 | 11 | from tsdiff.diffusion import OUDiffusion, GPDiffusion, BetaLinear 12 | from tsdiff.utils import PositionalEncoding 13 | 14 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 15 | if torch.cuda.is_available(): torch.cuda.set_device(device) 16 | 17 | def quantile_loss(target, forecast, 
q): 18 | """ Adapted from gluonts """ 19 | return 2 * np.sum(np.abs((forecast - target) * ((target <= forecast) - q))) 20 | 21 | def q_mean_loss(target, forecast): 22 | q_loss = [] 23 | for q in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]: 24 | forecast_quantile = np.quantile(forecast, q) 25 | q_loss.append(quantile_loss(target, forecast_quantile, q)) 26 | return np.mean(q_loss) / np.abs(target).sum() 27 | 28 | def radial_basis_kernel(x, y, sigma): 29 | dist = (y[None,:] - x[:,None])**2 30 | return np.exp(-dist / sigma) 31 | 32 | def generate_data( 33 | N, 34 | train_ratio=0.5, 35 | mean_num_points=8, 36 | min_num_points=5, 37 | max_num_points=50, 38 | min_t=0, 39 | max_t=1, 40 | sigma=0.05, 41 | seed=1, 42 | ): 43 | np.random.seed(seed) 44 | 45 | lengths = np.random.poisson(mean_num_points, size=(N,)) 46 | lengths = np.clip(lengths, min_num_points, max_num_points) 47 | 48 | x_time, x_values = [], [] 49 | y_time, y_values = [], [] 50 | 51 | q_loss = [] 52 | mse_error = [] 53 | 54 | for n in lengths: 55 | t = np.random.uniform(min_t, max_t, size=n) 56 | 57 | cov = radial_basis_kernel(t, t, sigma) + 1e-4 * np.eye(n) # Add to diagonal for stability 58 | L = np.linalg.cholesky(cov) 59 | 60 | points = L @ np.random.normal(size=(n,)) 61 | 62 | i = int(n * train_ratio) 63 | i = np.clip(i, 1, n - 1) 64 | 65 | x_time.append(t[:i]) 66 | y_time.append(t[i:]) 67 | x_values.append(points[:i]) 68 | y_values.append(points[i:]) 69 | 70 | # Evaluate the quantile loss 71 | t_ = t[i:] 72 | t = t[:i] 73 | 74 | kxx = radial_basis_kernel(t, t, sigma) 75 | kxx = kxx + 1e-2 * np.eye(len(kxx)) 76 | kxx_inv = np.linalg.inv(kxx) 77 | kyx = radial_basis_kernel(t_, t, sigma) 78 | kyy = radial_basis_kernel(t_, t_, sigma) 79 | 80 | mean = kyx @ kxx_inv @ points[:i] 81 | cov = kyy - kyx @ kxx_inv @ kyx.T 82 | # cov += 1e-3 * np.eye(len(cov)) 83 | 84 | dist = multivariate_normal(mean, cov) 85 | 86 | sample = dist.rvs(100) 87 | 88 | q_loss.append(q_mean_loss(points[i:], sample)) 89 | 
mse_error.append(mse(points[i:], sample.mean(0))) 90 | 91 | q_loss = np.mean(q_loss) 92 | mse_error = np.mean(mse_error) 93 | return x_values, x_time, y_values, y_time, q_loss, mse_error 94 | 95 | 96 | class NumpyDataset(Dataset): 97 | def __init__(self, x, tx, y, ty): 98 | super().__init__() 99 | assert len(x) == len(tx) == len(y) == len(ty) 100 | assert all(len(a) == len(b) for a, b in zip(x, tx)) 101 | assert all(len(a) == len(b) for a, b in zip(y, ty)) 102 | 103 | self.x = x 104 | self.tx = tx 105 | self.y = y 106 | self.ty = ty 107 | 108 | def __getitem__(self, ind): 109 | return self.x[ind], self.tx[ind], self.y[ind], self.ty[ind] 110 | 111 | def __len__(self): 112 | return len(self.x) 113 | 114 | def collate_fn(batch, device): 115 | x, tx, y, ty = list(zip(*batch)) 116 | max_x_len = max(map(len, x)) 117 | max_y_len = max(map(len, y)) 118 | 119 | def get_mask(arr, max_len): 120 | mask = np.array([np.concatenate([np.ones(len(x)), np.zeros(max_len - len(x))]) for x in arr]) 121 | return (1 - torch.Tensor(mask)).bool().to(device) 122 | x_pad = get_mask(x, max_x_len) 123 | y_pad = get_mask(y, max_y_len) 124 | 125 | pad = lambda arr, max_len: torch.Tensor(np.array([np.pad(s, (0, max_len - len(s))) for s in arr])).unsqueeze(-1).to(device) 126 | x = pad(x, max_x_len) 127 | y = pad(y, max_y_len) 128 | tx = pad(tx, max_x_len) 129 | ty = pad(ty, max_y_len) 130 | 131 | return x, y, tx, ty, x_pad, y_pad 132 | 133 | class Denoiser(nn.Module): 134 | def __init__(self, dim, hidden_dim): 135 | super().__init__() 136 | self.i_emb = PositionalEncoding(hidden_dim, max_value=100) 137 | self.t_emb = PositionalEncoding(hidden_dim, max_value=1) 138 | 139 | self.linear1 = nn.Linear(dim, hidden_dim) 140 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 141 | 142 | self.i_proj = nn.Sequential( 143 | nn.Linear(hidden_dim, hidden_dim), 144 | nn.Tanh(), 145 | nn.Linear(hidden_dim, hidden_dim), 146 | ) 147 | 148 | self.z_proj = nn.Sequential( 149 | nn.Linear(dim, hidden_dim), 150 | nn.Tanh(), 
151 | nn.Linear(hidden_dim, hidden_dim), 152 | ) 153 | 154 | self.net = nn.Sequential( 155 | nn.Linear(hidden_dim, hidden_dim), 156 | nn.Tanh(), 157 | nn.Linear(hidden_dim, hidden_dim), 158 | nn.Tanh(), 159 | nn.Linear(hidden_dim, dim), 160 | ) 161 | 162 | self.sigma = nn.Parameter(torch.randn(1)) 163 | 164 | def forward(self, y, *, t, i, z, tz, z_pad): 165 | k = torch.exp(-torch.square(t - tz.transpose(-1, -2)) / self.sigma**2) 166 | 167 | z = self.z_proj(z) 168 | z = z * (1 - z_pad.float().unsqueeze(-1)) 169 | 170 | z = k @ z 171 | 172 | i = self.i_emb(i) 173 | i = self.i_proj(i) 174 | 175 | y = self.linear1(y) + z 176 | y = torch.relu(y) 177 | 178 | y = self.linear2(y) + i 179 | y = torch.relu(y) 180 | 181 | y = self.net(y) 182 | return y 183 | 184 | class NeuralDenoisingProcess(nn.Module): 185 | def __init__(self, dim, hidden_dim, gp, param): 186 | super().__init__() 187 | self.dim = dim 188 | self.hidden_dim = hidden_dim 189 | 190 | if gp: 191 | self.diffusion = GPDiffusion(dim, beta_fn=BetaLinear(1e-4, 0.2), num_steps=100, sigma=param) 192 | else: 193 | self.diffusion = OUDiffusion(dim, beta_fn=BetaLinear(1e-4, 0.2), num_steps=100, theta=param) 194 | self.denoise_fn = Denoiser(dim, hidden_dim) 195 | 196 | def forward(self, x, y, tx, ty, x_pad, y_pad): 197 | y = torch.cat([x, y], 1) 198 | ty = torch.cat([tx, ty], 1) 199 | y_pad = torch.cat([x_pad, y_pad], 1) 200 | 201 | loss = self.diffusion.get_loss(self.denoise_fn, y, t=ty, tz=tx, z=x, z_pad=x_pad) 202 | loss = loss.mean(-1) * (1 - y_pad.float()) 203 | return loss.mean() 204 | 205 | def sample(self, x, tx, x_pad, ty, num_samples=1): 206 | x = x.repeat_interleave(num_samples, dim=0) 207 | tx = tx.repeat_interleave(num_samples, dim=0) 208 | x_pad = x_pad.repeat_interleave(num_samples, dim=0) 209 | ty = ty.repeat_interleave(num_samples, dim=0) 210 | return self.diffusion.sample(self.denoise_fn, num_samples=ty.shape[:-1], device=tx.device, 211 | t=ty, tz=tx, z=x, z_pad=x_pad) 212 | 213 | def evaluate(model, 
testloader): 214 | num_samples = 100 215 | 216 | targets, forecasts = [], [] 217 | mse_err = [] 218 | q_loss = [] 219 | 220 | for _, batch in enumerate(testloader): 221 | x, y, tx, ty, x_pad, y_pad = batch 222 | 223 | y_pred = model.sample(x, tx, x_pad, ty, num_samples=num_samples) 224 | 225 | target, forecast = y.cpu().squeeze().numpy(), y_pred.cpu().squeeze().numpy() 226 | 227 | targets.append(target) 228 | forecasts.append(forecast) 229 | 230 | assert target.shape == forecast.mean(0).shape 231 | 232 | q_loss.append(q_mean_loss(target, forecast)) 233 | mse_err.append(mse(target, forecast.mean(0))) 234 | 235 | return np.mean(q_loss), np.mean(mse_err) 236 | 237 | def sample(model, testloader, gp, param): 238 | num_samples = 30 239 | T = 50 240 | 241 | loader_it = iter(testloader) 242 | 243 | samples = [] 244 | xs = [] 245 | txs = [] 246 | ys = [] 247 | tys = [] 248 | 249 | for _ in range(10): 250 | x, y, tx, ty, x_pad, y_pad = next(loader_it) 251 | t = torch.linspace(0, 1, T).view(1, -1, 1).to(x) 252 | y_samples = model.sample(x, tx, x_pad, t, num_samples=num_samples) 253 | 254 | samples.append(y_samples.detach().cpu().numpy()) 255 | xs.append(x.detach().cpu().numpy()) 256 | txs.append(tx.detach().cpu().numpy()) 257 | ys.append(y.detach().cpu().numpy()) 258 | tys.append(ty.detach().cpu().numpy()) 259 | 260 | root = '/tsdiff/experiments/neural_process/samples' 261 | filename = f'{root}/{"gp" if gp else "ou"}-{param}.pkl' 262 | 263 | with open(filename, 'wb') as f: 264 | data = dict(sample=samples, x=xs, tx=txs, y=ys, ty=tys) 265 | pickle.dump(data, f) 266 | 267 | def train( 268 | seed: int, 269 | gp: bool, 270 | param: float, 271 | epochs: int = 200, 272 | batch_size: int = 128, 273 | train_size: int = 800, 274 | test_size: int = 200, 275 | hidden_dim: int = 32, 276 | ): 277 | # Generate data 278 | x_values, x_time, y_values, y_time, train_q_loss, train_mse_error = generate_data(train_size, seed=seed) 279 | trainset = NumpyDataset(x_values, x_time, y_values, y_time) 
280 | trainloader = DataLoader(trainset, batch_size=batch_size, collate_fn=partial(collate_fn, device=device), shuffle=True) 281 | 282 | x_values, x_time, y_values, y_time, test_q_loss, test_mse_error = generate_data(test_size, seed=seed + 10) 283 | testset = NumpyDataset(x_values, x_time, y_values, y_time) 284 | testloader = DataLoader(testset, batch_size=1, collate_fn=partial(collate_fn, device=device), shuffle=False) 285 | 286 | # Make model 287 | model = NeuralDenoisingProcess( 288 | dim=1, 289 | hidden_dim=hidden_dim, 290 | gp=gp, 291 | param=param, 292 | ).to(device) 293 | optim = torch.optim.Adam(model.parameters(), weight_decay=1e-5) 294 | 295 | # Train 296 | for epoch in range(epochs): 297 | for batch in trainloader: 298 | optim.zero_grad() 299 | loss = model(*batch) 300 | loss.backward() 301 | optim.step() 302 | if epoch % 10 == 0: 303 | print(f'Epoch {epoch}, loss: {loss:.4f}') 304 | 305 | # Evaluate 306 | q_loss, mse_error = evaluate(model, testloader) 307 | 308 | if seed == 1: 309 | sample(model, testloader, gp, param) 310 | 311 | result = dict( 312 | q_loss=q_loss, 313 | mse_error=mse_error, 314 | train_q_loss=train_q_loss, 315 | train_mse_error=train_mse_error, 316 | test_q_loss=test_q_loss, 317 | test_mse_error=test_mse_error, 318 | ) 319 | 320 | return model, result 321 | 322 | if __name__ == '__main__': 323 | model, result = train(seed=1, param=0.02, epochs=20, gp=True) 324 | print(result) 325 | -------------------------------------------------------------------------------- /tsdiff/synthetic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbilos/tsdiff/32b23f2b7f5ec4d68bc533dda8a0096086cbd0ab/tsdiff/synthetic/__init__.py -------------------------------------------------------------------------------- /tsdiff/synthetic/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from 
torch.utils.data import TensorDataset, DataLoader, random_split 4 | from pytorch_lightning import LightningDataModule 5 | from pathlib import Path 6 | 7 | class DataModule(LightningDataModule): 8 | def __init__(self, name, batch_size: int, test_batch_size: int = None): 9 | super().__init__() 10 | self.name = name 11 | self.batch_size = batch_size 12 | self.test_batch_size = test_batch_size or batch_size 13 | 14 | dataset = self._load_dataset() 15 | self.trainset, self.valset, self.testset = self._split_train_val_test(dataset) 16 | 17 | @property 18 | def dim(self): 19 | return self.trainset[0][0].shape[-1] 20 | 21 | @property 22 | def x_mean(self): 23 | return torch.cat([x[0] for x in self.trainset], 0).mean(0) 24 | 25 | @property 26 | def x_std(self): 27 | return torch.cat([x[0] for x in self.trainset], 0).std(0).clamp(1e-4) 28 | 29 | @property 30 | def t_max(self): 31 | return torch.cat([x[1] for x in self.trainset], 0).max().item() 32 | 33 | def train_dataloader(self): 34 | return DataLoader(self.trainset, batch_size=self.batch_size, shuffle=True) 35 | 36 | def val_dataloader(self): 37 | return DataLoader(self.valset, batch_size=self.test_batch_size, shuffle=False) 38 | 39 | def test_dataloader(self): 40 | return DataLoader(self.testset, batch_size=self.test_batch_size, shuffle=False) 41 | 42 | def _load_dataset(self): 43 | filepath = Path(__file__).parents[2] / f'data/synthetic/{self.name}.npz' 44 | data = np.load(filepath) 45 | dataset = TensorDataset(torch.Tensor(data['x']), torch.Tensor(data['t'])) 46 | return dataset 47 | 48 | def _split_train_val_test(self, dataset): 49 | train_len, val_len = int(0.6 * len(dataset)), int(0.2 * len(dataset)) 50 | return random_split(dataset, lengths=[train_len, val_len, len(dataset) - train_len - val_len]) 51 | -------------------------------------------------------------------------------- /tsdiff/synthetic/diffusion_model.py: -------------------------------------------------------------------------------- 1 | import math 
2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from pytorch_lightning import LightningModule 7 | from torchtyping import TensorType 8 | 9 | import tsdiff 10 | from tsdiff.diffusion.beta_scheduler import get_beta_scheduler, get_loss_weighting 11 | from tsdiff.utils import PositionalEncoding, FeedForward 12 | 13 | 14 | class FeedForwardModel(nn.Module): 15 | def __init__(self, dim, hidden_dim, max_i, num_layers=3, **kwargs): 16 | super().__init__() 17 | self.t_enc = PositionalEncoding(hidden_dim, max_value=1) 18 | self.i_enc = PositionalEncoding(hidden_dim, max_value=max_i) 19 | self.input_proj = nn.Linear(dim, hidden_dim) 20 | self.net = FeedForward(3 * hidden_dim, [hidden_dim] * num_layers, dim) 21 | 22 | def forward(self, x, *, t, i, **kwargs): 23 | t = self.t_enc(t) 24 | i = self.i_enc(i) 25 | x = self.input_proj(x) 26 | x = torch.cat([x, t, i], -1) 27 | return self.net(x) 28 | 29 | class RNNModel(nn.Module): 30 | def __init__(self, dim, hidden_dim, max_i, num_layers=2, bidirectional=True, **kwargs): 31 | super().__init__() 32 | self.hidden_dim = hidden_dim 33 | self.num_layers = num_layers 34 | self.directions = 2 if bidirectional else 1 35 | 36 | self.t_enc = PositionalEncoding(hidden_dim, max_value=1) 37 | self.i_enc = PositionalEncoding(hidden_dim, max_value=max_i) 38 | self.init_proj = FeedForward(hidden_dim, [], self.num_layers * self.directions * hidden_dim) 39 | self.input_proj = FeedForward(dim, [], hidden_dim) 40 | self.rnn = nn.GRU( 41 | 3 * hidden_dim, 42 | hidden_dim, 43 | num_layers=num_layers, 44 | bidirectional=bidirectional, 45 | batch_first=True, 46 | ) 47 | self.output_proj = FeedForward(self.directions * hidden_dim, [], dim) 48 | 49 | def forward( 50 | self, 51 | x: TensorType['B', 'L', 'D'], 52 | *, 53 | t: TensorType['B', 'L', 1], 54 | i: TensorType['B', 'L', 1], 55 | **kwargs, 56 | ) -> TensorType['B', 'L', 'D']: 57 | shape = x.shape 58 | 59 | t = self.t_enc(t.view(-1, shape[-2], 1)) 60 | i = 
self.i_enc(i.view(-1, shape[-2], 1)) 61 | x = self.input_proj(x.view(-1, *shape[-2:])) 62 | 63 | init = self.init_proj(i[:,0]) 64 | init = init.view(self.num_layers * self.directions, -1, self.hidden_dim) 65 | 66 | x = torch.cat([x, t, i], -1) 67 | 68 | y, _ = self.rnn(x, init) 69 | y = self.output_proj(y) 70 | y = y.view(*shape) 71 | 72 | return y 73 | 74 | 75 | class ResidualBlock(nn.Module): 76 | def __init__(self, dim, hidden_size, residual_channels, dilation, padding_mode): 77 | super().__init__() 78 | self.step_projection = nn.Linear(hidden_size, residual_channels) 79 | self.time_projection = nn.Linear(hidden_size, residual_channels) 80 | 81 | self.x_step_proj = nn.Sequential( 82 | nn.Conv2d(residual_channels, residual_channels, kernel_size=1, padding='same', padding_mode=padding_mode), 83 | nn.LeakyReLU(0.4), 84 | ) 85 | self.x_time_proj = nn.Sequential( 86 | nn.Conv2d(residual_channels, residual_channels, kernel_size=1, padding='same', padding_mode=padding_mode), 87 | nn.LeakyReLU(0.4), 88 | ) 89 | 90 | self.latent_projection = nn.Conv2d( 91 | 1, 2 * residual_channels, kernel_size=1, padding='same', padding_mode=padding_mode, 92 | ) 93 | self.dilated_conv = nn.Conv2d( 94 | 1 * residual_channels, 95 | 2 * residual_channels, 96 | kernel_size=3, 97 | dilation=dilation, 98 | padding='same', 99 | padding_mode=padding_mode, 100 | ) 101 | self.output_projection = nn.Conv2d( 102 | residual_channels, 2 * residual_channels, kernel_size=1, padding='same', padding_mode=padding_mode, 103 | ) 104 | 105 | def forward(self, x, t=None, i=None): 106 | i = self.step_projection(i).transpose(-1, -2).unsqueeze(-1) 107 | 108 | y = x + i 109 | y = y + self.x_step_proj(y) 110 | 111 | t = self.time_projection(t).transpose(-1, -2).unsqueeze(-1) 112 | y = y + self.x_time_proj(y + t) 113 | 114 | y = self.dilated_conv(y) 115 | 116 | gate, filter = y.chunk(2, dim=1) 117 | y = torch.sigmoid(gate) * torch.tanh(filter) 118 | 119 | y = self.output_projection(y) 120 | y = F.leaky_relu(y, 0.4) 
121 | 122 | residual, skip = y.chunk(2, dim=1) 123 | return (x + residual) / math.sqrt(2), skip 124 | 125 | class CNNModel(nn.Module): 126 | def __init__(self, dim, hidden_dim, max_i, num_layers=8, residual_channels=8, padding_mode='circular'): 127 | super().__init__() 128 | 129 | self.input_projection = nn.Conv2d(1, residual_channels, kernel_size=1, padding='same', padding_mode=padding_mode) 130 | self.step_embedding = PositionalEncoding(hidden_dim, max_value=max_i) 131 | self.time_embedding = PositionalEncoding(hidden_dim, max_value=1) 132 | 133 | self.residual_layers = nn.ModuleList([ 134 | ResidualBlock(dim, hidden_dim, residual_channels, dilation=2**(i % 2), padding_mode=padding_mode) 135 | for i in range(num_layers) 136 | ]) 137 | 138 | self.skip_projection = nn.Conv2d( 139 | residual_channels, residual_channels, kernel_size=3, padding='same', padding_mode=padding_mode, 140 | ) 141 | self.output_projection = nn.Conv2d( 142 | residual_channels, 1, kernel_size=3, padding='same', padding_mode=padding_mode, 143 | ) 144 | 145 | self.time_proj = nn.Sequential( 146 | nn.Linear(1, hidden_dim), 147 | nn.LeakyReLU(0.4), 148 | nn.Linear(hidden_dim, hidden_dim), 149 | nn.LeakyReLU(0.4), 150 | ) 151 | 152 | def forward(self, x, t=None, i=None, **kwargs): 153 | shape = x.shape 154 | 155 | x = x.view(-1, *x.shape[-2:]) 156 | t = t.view(-1, *t.shape[-2:]) 157 | i = i.view(-1, *i.shape[-2:]) 158 | 159 | x = x.unsqueeze(1) 160 | x = self.input_projection(x) 161 | x = F.leaky_relu(x, 0.4) 162 | 163 | i = self.step_embedding(i) 164 | t = self.time_proj(t) 165 | 166 | skip_agg = 0 167 | for layer in self.residual_layers: 168 | x, skip = layer(x, t=t, i=i) 169 | skip_agg = skip_agg + skip 170 | 171 | x = skip_agg / math.sqrt(len(self.residual_layers)) 172 | x = self.skip_projection(x) 173 | x = F.leaky_relu(x, 0.4) 174 | x = self.output_projection(x).squeeze(1) 175 | 176 | x = x.view(*shape) 177 | return x 178 | 179 | 180 | class TransformerModel(nn.Module): 181 | def 
__init__(self, dim, hidden_dim, max_i, num_layers=8, 182 | num_ref_points=10, **kwargs): 183 | super().__init__() 184 | self.hidden_dim = hidden_dim 185 | self.num_layers = num_layers 186 | 187 | self.t_enc = PositionalEncoding(hidden_dim, max_value=1) 188 | self.i_enc = PositionalEncoding(hidden_dim, max_value=max_i) 189 | 190 | self.input_proj = FeedForward(dim, [], hidden_dim) 191 | 192 | self.proj = FeedForward(3 * hidden_dim, [], hidden_dim, final_activation=nn.ReLU()) 193 | 194 | self.enc_att = [] 195 | self.i_proj = [] 196 | for _ in range(num_layers): 197 | self.enc_att.append(nn.MultiheadAttention(hidden_dim, num_heads=1, batch_first=True)) 198 | self.i_proj.append(nn.Linear(3 * hidden_dim, hidden_dim)) 199 | self.enc_att = nn.ModuleList(self.enc_att) 200 | self.i_proj = nn.ModuleList(self.i_proj) 201 | 202 | self.output_proj = FeedForward(hidden_dim, [], dim) 203 | 204 | def forward( 205 | self, 206 | x: TensorType['B', 'L', 'D'], 207 | *, 208 | t: TensorType['B', 'L', 1], 209 | i: TensorType['B', 'L', 1], 210 | **kwargs, 211 | ) -> TensorType['B', 'L', 'D']: 212 | shape = x.shape 213 | 214 | x = x.view(-1, *shape[-2:]) 215 | t = t.view(-1, shape[-2], 1) 216 | i = i.view(-1, shape[-2], 1) 217 | 218 | x = self.input_proj(x) 219 | t = self.t_enc(t) 220 | i = self.i_enc(i) 221 | 222 | x = self.proj(torch.cat([x, t, i], -1)) 223 | 224 | for att_layer, i_proj in zip(self.enc_att, self.i_proj): 225 | y, _ = att_layer( 226 | query=x, 227 | key=x, 228 | value=x, 229 | ) 230 | x = x + torch.relu(y) 231 | 232 | x = self.output_proj(x) 233 | x = x.view(*shape) 234 | return x 235 | 236 | class DiffusionModule(LightningModule): 237 | def __init__( 238 | self, 239 | # Data params 240 | dim: int, 241 | data_mean: torch.Tensor = None, 242 | data_std: torch.Tensor = None, 243 | max_t: float = None, 244 | # Diffusion params 245 | diffusion: str = None, 246 | gp_sigma: float = None, 247 | ou_theta: float = None, 248 | discrete_num_steps: int = None, 249 | 
predict_gaussian_noise: bool = None, 250 | continuous_t1: float = None, 251 | beta_fn: str = None, 252 | beta_start: float = None, 253 | beta_end: float = None, 254 | loss_weighting: str = None, 255 | # NN params 256 | model: str = None, 257 | hidden_dim: int = None, 258 | # Training params 259 | learning_rate: float = None, 260 | weight_decay: float = None, 261 | **kwargs, 262 | ): 263 | super().__init__() 264 | 265 | self.save_hyperparameters() 266 | 267 | self.dim = dim 268 | self.data_mean = data_mean 269 | self.data_std = data_std 270 | self.max_t = max_t 271 | 272 | self.learning_rate = learning_rate 273 | self.weight_decay = weight_decay 274 | 275 | self.diffusion = getattr(tsdiff.diffusion, diffusion)( 276 | dim=dim, 277 | beta_fn=get_beta_scheduler(beta_fn)(beta_start, beta_end), 278 | sigma=gp_sigma, 279 | theta=ou_theta, 280 | num_steps=discrete_num_steps, 281 | predict_gaussian_noise=predict_gaussian_noise, 282 | t1=continuous_t1, 283 | loss_weighting=get_loss_weighting(loss_weighting), 284 | ) 285 | 286 | if model == 'rnn': 287 | model = RNNModel 288 | elif model == 'feedforward': 289 | model = FeedForwardModel 290 | elif model == 'cnn': 291 | model = CNNModel 292 | elif model == 'transformer': 293 | model = TransformerModel 294 | 295 | max_i = continuous_t1 if 'Continuous' in diffusion else discrete_num_steps 296 | 297 | self.model = model( 298 | dim=dim, 299 | hidden_dim=hidden_dim, 300 | max_i=max_i, 301 | ) 302 | 303 | def forward(self, batch, log_name=None): 304 | x, t = self._normalize_batch(batch) 305 | loss = self.diffusion.get_loss(self.model, x, t=t).mean() 306 | if log_name is not None: 307 | self.log(log_name, loss) 308 | return loss 309 | 310 | def training_step(self, batch, batch_idx): 311 | return self.forward(batch, 'train_loss') 312 | 313 | def validation_step(self, batch, batch_idx): 314 | return self.forward(batch, 'val_loss') 315 | 316 | def test_step(self, batch, batch_idx): 317 | x, t = self._normalize_batch(batch) 318 | log_prob 
= self.diffusion.log_prob(self.model, x, t=t, num_samples=5) 319 | log_prob = log_prob - self.data_std.log().sum() 320 | self.log('test_log_prob', log_prob.mean()) 321 | 322 | def configure_optimizers(self): 323 | optimizer = torch.optim.Adam( 324 | self.parameters(), 325 | lr=self.learning_rate, 326 | weight_decay=self.weight_decay, 327 | ) 328 | return optimizer 329 | 330 | def sample(self, t, **kwargs): 331 | t = t / self.max_t 332 | samples = self.diffusion.sample( 333 | self.model.to(t), 334 | num_samples=t.shape[:-1], 335 | t=t, 336 | device=t.device, 337 | **kwargs, 338 | ) 339 | 340 | return samples * self.data_std.to(t) + self.data_mean.to(t) 341 | 342 | def _normalize_batch(self, batch): 343 | x, t = batch 344 | x = (x - self.data_mean.to(x)) / self.data_std.to(x) 345 | t = t / self.max_t 346 | return x, t 347 | -------------------------------------------------------------------------------- /tsdiff/synthetic/discriminator_experiment.py: -------------------------------------------------------------------------------- 1 | import seml 2 | import numpy as np 3 | from sacred import Experiment 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import TensorDataset, DataLoader 10 | 11 | from pytorch_lightning import LightningModule, Trainer 12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 13 | 14 | from tsdiff.synthetic.train import train 15 | 16 | DATA_DIR = Path(__file__).parents[2].resolve() / 'data/synthetic' 17 | SAMPLE_DIR = Path(__file__).parents[2].resolve() / 'data/samples' 18 | 19 | ex = Experiment() 20 | seml.setup_logger(ex) 21 | 22 | class Net(nn.Module): 23 | def __init__(self, dim, hidden_dim): 24 | super().__init__() 25 | self.emb = nn.Linear(dim, hidden_dim) 26 | self.transformer = nn.TransformerEncoder( 27 | nn.TransformerEncoderLayer(hidden_dim, nhead=4, batch_first=True), 28 | num_layers=4, 29 | ) 30 | 31 | self.proj = 
nn.Sequential( 32 | nn.Linear(hidden_dim, hidden_dim), 33 | nn.ReLU(), 34 | nn.Linear(hidden_dim, 1), 35 | nn.Sigmoid(), 36 | ) 37 | 38 | def forward(self, x): 39 | h = self.emb(x) 40 | h = self.transformer(h) 41 | h = h.mean(dim=1) 42 | return self.proj(h).squeeze(-1) 43 | 44 | class Model(LightningModule): 45 | def __init__(self, dim, hidden_dim, lr, weight_decay): 46 | super().__init__() 47 | self.dim = dim 48 | self.hidden_dim = hidden_dim 49 | self.lr = lr 50 | self.weight_decay = weight_decay 51 | self.save_hyperparameters() 52 | self.loss = nn.BCELoss(reduction='mean') 53 | self.net = Net(dim, hidden_dim) 54 | 55 | def forward(self, x): 56 | return self.net(x) 57 | 58 | def training_step(self, batch, batch_nb, log_name='train_loss'): 59 | x, y = batch 60 | loss = self.loss(input=self.forward(x), target=y) 61 | self.log(log_name, loss, prog_bar=True) 62 | return loss 63 | 64 | def validation_step(self, batch, batch_nb): 65 | return self.training_step(batch, batch_nb, log_name='val_loss') 66 | 67 | @torch.no_grad() 68 | def test_step(self, batch, batch_idx): 69 | x, y = batch 70 | y_pred = self(x) 71 | loss = self.loss(input=y_pred, target=y) 72 | accuracy = torch.sum((y_pred > 0.5).float() == y) / len(y) 73 | self.log("test_loss", loss) 74 | self.log("test_acc", accuracy) 75 | 76 | def configure_optimizers(self): 77 | return torch.optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.weight_decay) 78 | 79 | 80 | @ex.config 81 | def config(): 82 | overwrite = None 83 | db_collection = None 84 | if db_collection is not None: 85 | ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite)) 86 | 87 | @ex.automain 88 | def run( 89 | seed: int, 90 | dataset: str, 91 | model: str, 92 | diffusion: str, 93 | epochs: int, 94 | batch_size: int, 95 | gp_sigma: float = None, 96 | ou_theta: float = None, 97 | ): 98 | np.random.seed(seed) 99 | torch.manual_seed(seed) 100 | 101 | # DATA 102 | filename = SAMPLE_DIR / 
f'{dataset}-{diffusion}-{model}-{gp_sigma or ou_theta}.npy' 103 | synthetic = np.load(filename) 104 | 105 | filename = DATA_DIR / f'{dataset}.npz' 106 | data = np.load(filename)['x'][:len(synthetic)] 107 | 108 | dim = data.shape[-1] 109 | 110 | all_data = torch.Tensor(np.concatenate([synthetic, data], 0)) 111 | all_data = (all_data - all_data.mean()) / all_data.std().clamp(1e-4) 112 | all_labels = torch.cat([torch.ones(len(synthetic)), torch.zeros(len(data))], 0) 113 | 114 | ind = torch.randperm(len(all_data)) 115 | all_data = all_data[ind] 116 | all_labels = all_labels[ind] 117 | 118 | ind1, ind2 = int(0.6 * len(all_data)), int(0.8 * len(all_data)) 119 | 120 | trainloader = DataLoader(TensorDataset(all_data[:ind1], all_labels[:ind1]), batch_size=batch_size) 121 | valloader = DataLoader(TensorDataset(all_data[ind1:ind2], all_labels[ind1:ind2]), batch_size=batch_size) 122 | testloader = DataLoader(TensorDataset(all_data[ind2:], all_labels[ind2:]), batch_size=batch_size) 123 | 124 | # TRAINING 125 | model = Model(dim, hidden_dim=128, lr=1e-3, weight_decay=1e-5) 126 | 127 | checkpointing = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, filename='best-checkpoint') 128 | early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=20) 129 | trainer = Trainer( 130 | gpus=1, 131 | auto_select_gpus=True, 132 | max_epochs=epochs, 133 | log_every_n_steps=1, 134 | enable_checkpointing=True, 135 | callbacks=[early_stopping, checkpointing], 136 | ) 137 | trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=valloader) 138 | 139 | # TESTING 140 | model = Model.load_from_checkpoint(checkpointing.best_model_path) 141 | metrics = trainer.test(model, testloader) 142 | 143 | return metrics[0] 144 | -------------------------------------------------------------------------------- /tsdiff/synthetic/discriminator_experiment.yaml: -------------------------------------------------------------------------------- 1 | seml: 2 | executable: 
experiments/synthetic/discriminator_experiment.py 3 | name: tsdiff_discriminator 4 | output_dir: experiments/logs 5 | project_root_dir: ../../ 6 | conda_environment: place_your_env_here 7 | 8 | slurm: 9 | experiments_per_job: 1 10 | sbatch_options: 11 | gres: gpu:1 # num GPUs 12 | mem: 16G # memory 13 | cpus-per-task: 2 # num cores 14 | time: 0-08:00 # max time, D-HH:MM 15 | partition: gpu_all 16 | 17 | 18 | fixed: 19 | epochs: 100 20 | batch_size: 256 21 | 22 | grid: 23 | seed: 24 | type: range 25 | min: 1 26 | max: 6 27 | step: 1 28 | 29 | dataset: 30 | type: choice 31 | options: 32 | - cir 33 | - lorenz 34 | - ou 35 | - predator_prey 36 | - sine 37 | - sink 38 | 39 | model: 40 | type: choice 41 | options: 42 | - feedforward 43 | - rnn 44 | 45 | gaussian: 46 | grid: 47 | diffusion: 48 | type: choice 49 | options: 50 | - GaussianDiffusion 51 | - ContinuousGaussianDiffusion 52 | 53 | gp: 54 | grid: 55 | diffusion: 56 | type: choice 57 | options: 58 | - GPDiffusion 59 | - ContinuousGPDiffusion 60 | 61 | gp_sigma: 62 | type: choice 63 | options: 64 | - 0.01 65 | - 0.1 66 | - 1 67 | 68 | ou: 69 | grid: 70 | diffusion: 71 | type: choice 72 | options: 73 | - OUDiffusion 74 | - ContinuousOUDiffusion 75 | 76 | ou_theta: 77 | type: choice 78 | options: 79 | - 5 80 | - 0.5 81 | - 0.05 82 | 83 | ode: 84 | fixed: 85 | model: ode 86 | diffusion: None 87 | 88 | nf: 89 | fixed: 90 | model: nf 91 | diffusion: None 92 | -------------------------------------------------------------------------------- /tsdiff/synthetic/experiment.py: -------------------------------------------------------------------------------- 1 | import seml 2 | from sacred import Experiment 3 | from tsdiff.synthetic.train import train 4 | 5 | ex = Experiment() 6 | seml.setup_logger(ex) 7 | 8 | @ex.config 9 | def config(): 10 | overwrite = None 11 | db_collection = None 12 | if db_collection is not None: 13 | ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite)) 14 | 15 | 
@ex.automain 16 | def run( 17 | seed: int, 18 | dataset: str, 19 | model: str, 20 | diffusion: str, 21 | epochs: int, 22 | learning_rate: float, 23 | batch_size: int, 24 | predict_gaussian_noise: bool, 25 | gp_sigma: float = None, 26 | ou_theta: float = None, 27 | ): 28 | results = train(**locals()) 29 | return results 30 | -------------------------------------------------------------------------------- /tsdiff/synthetic/experiment.yaml: -------------------------------------------------------------------------------- 1 | seml: 2 | executable: synthetic/experiment.py 3 | name: tsdiff_syntetic_final 4 | output_dir: ../logs 5 | project_root_dir: ../ 6 | conda_environment: place_your_env_here 7 | 8 | slurm: 9 | experiments_per_job: 2 10 | sbatch_options: 11 | gres: gpu:1 # num GPUs 12 | mem: 16G # memory 13 | cpus-per-task: 2 # num cores 14 | time: 0-08:00 # max time, D-HH:MM 15 | partition: gpu_all 16 | 17 | 18 | fixed: 19 | epochs: 1000 20 | learning_rate: 1e-3 21 | batch_size: 256 22 | predict_gaussian_noise: False 23 | 24 | grid: 25 | seed: 26 | type: range 27 | min: 1 28 | max: 6 29 | step: 1 30 | 31 | dataset: 32 | type: choice 33 | options: 34 | - cir 35 | - lorenz 36 | - ou 37 | - predator_prey 38 | - sine 39 | - sink 40 | 41 | model: 42 | type: choice 43 | options: 44 | - feedforward 45 | - rnn 46 | - transformer 47 | 48 | gaussian: 49 | grid: 50 | diffusion: 51 | type: choice 52 | options: 53 | - GaussianDiffusion 54 | - ContinuousGaussianDiffusion 55 | 56 | gp: 57 | grid: 58 | diffusion: 59 | type: choice 60 | options: 61 | - GPDiffusion 62 | - ContinuousGPDiffusion 63 | 64 | gp_sigma: 65 | type: choice 66 | options: 67 | - 0.01 68 | - 0.1 69 | - 1 70 | 71 | predict_gaussian_noise: 72 | type: choice 73 | options: 74 | - True 75 | - False 76 | 77 | ou: 78 | grid: 79 | diffusion: 80 | type: choice 81 | options: 82 | - OUDiffusion 83 | - ContinuousOUDiffusion 84 | 85 | ou_theta: 86 | type: choice 87 | options: 88 | - 5 89 | - 0.5 90 | - 0.05 91 | 92 | 
predict_gaussian_noise: 93 | type: choice 94 | options: 95 | - True 96 | - False 97 | 98 | # ode: 99 | # fixed: 100 | # model: ode 101 | # epochs: 100 102 | # batch_size: 16 103 | # diffusion: None 104 | 105 | # nf: 106 | # fixed: 107 | # model: nf 108 | # diffusion: None 109 | -------------------------------------------------------------------------------- /tsdiff/synthetic/nf_model.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from torchtyping import TensorType 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.distributions as td 8 | from pytorch_lightning import LightningModule 9 | import stribor as st 10 | from stribor.flows.cumsum import diff 11 | 12 | import tsdiff 13 | from tsdiff.utils import PositionalEncoding 14 | 15 | 16 | class Cumsum(st.ElementwiseTransform): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, x: TensorType[..., 'dim'], **kwargs) -> TensorType[..., 'dim']: 21 | y = x.cumsum(dim=-2) 22 | return y 23 | 24 | def inverse(self, y: TensorType[..., 'dim'], **kwargs) -> TensorType[..., 'dim']: 25 | x = y.diff(dim=-2, prepend=torch.zeros_like(y[...,:1,:])) 26 | return x 27 | 28 | def log_det_jacobian( 29 | self, x: TensorType[..., 'dim'], y: TensorType[..., 'dim'], **kwargs, 30 | ) -> TensorType[..., 1]: 31 | return torch.zeros_like(x[...,:1]).to(x) 32 | 33 | def log_diag_jacobian( 34 | self, x: TensorType[..., 'dim'], y: TensorType[..., 'dim'], **kwargs, 35 | ) -> TensorType[..., 'dim']: 36 | return torch.zeros_like(x).to(x) 37 | 38 | class Diff(Cumsum): 39 | forward = Cumsum.inverse 40 | inverse = Cumsum.forward 41 | 42 | 43 | class Model(nn.Module): 44 | def __init__(self, dim, hidden_dim): 45 | super().__init__() 46 | 47 | transforms = [ 48 | Cumsum() 49 | ] 50 | 51 | for i in range(12): 52 | transforms.append( 53 | st.Coupling( 54 | st.Affine( 55 | dim=dim, 56 | latent_net=st.net.MLP(dim + hidden_dim, 
[hidden_dim] * 2, 2 * dim), 57 | ), 58 | mask='none' if dim == 1 else f'ordered_{i % 2}', 59 | ) 60 | ) 61 | 62 | base_dist = td.Independent(td.Normal(torch.zeros(dim).cuda(), torch.ones(dim).cuda()), 1) 63 | self.flow = st.NormalizingFlow(base_dist, transforms) 64 | self.t_enc = PositionalEncoding(hidden_dim - 1, max_value=1) 65 | 66 | def log_prob( 67 | self, 68 | x: TensorType['B', 'L', 'D'], 69 | t: TensorType['B', 'L', 1], 70 | **kwargs, 71 | ) -> TensorType['B', 'L', 'D']: 72 | t = torch.cat([t, self.t_enc(t)], -1) 73 | log_prob = self.flow.log_prob(x, latent=t) 74 | return log_prob 75 | 76 | def sample(self, t, **kwargs): 77 | t = torch.cat([t, self.t_enc(t)], -1) 78 | samples = self.flow.sample(num_samples=t.shape[:-1], latent=t) 79 | return samples 80 | 81 | class NFModule(LightningModule): 82 | def __init__( 83 | self, 84 | # Data params 85 | dim: int, 86 | data_mean: torch.Tensor = None, 87 | data_std: torch.Tensor = None, 88 | max_t: float = None, 89 | # NN params 90 | hidden_dim: int = None, 91 | # Training params 92 | learning_rate: float = None, 93 | weight_decay: float = None, 94 | **kwargs, 95 | ): 96 | super().__init__() 97 | self.save_hyperparameters() 98 | 99 | self.dim = dim 100 | self.data_mean = data_mean 101 | self.data_std = data_std 102 | self.max_t = max_t 103 | 104 | self.learning_rate = learning_rate 105 | self.weight_decay = weight_decay 106 | 107 | self.model = Model( 108 | dim=dim, 109 | hidden_dim=hidden_dim, 110 | ) 111 | 112 | def forward(self, batch, log_name=None): 113 | x, t = self._normalize_batch(batch) 114 | loss = -self.model.log_prob(x, t).mean() 115 | if log_name is not None: 116 | self.log(log_name, loss) 117 | return loss 118 | 119 | def training_step(self, batch, batch_idx): 120 | return self.forward(batch, 'train_loss') 121 | 122 | def validation_step(self, batch, batch_idx): 123 | return self.forward(batch, 'val_loss') 124 | 125 | def test_step(self, batch, batch_idx): 126 | x, t = self._normalize_batch(batch) 127 | 
log_prob = self.model.log_prob(x, t) / t.shape[-2] 128 | log_prob = log_prob - self.data_std.log().sum() 129 | self.log('test_log_prob', log_prob.mean()) 130 | 131 | def configure_optimizers(self): 132 | return torch.optim.Adam( 133 | self.parameters(), 134 | lr=self.learning_rate, 135 | weight_decay=self.weight_decay, 136 | ) 137 | 138 | @torch.no_grad() 139 | def sample(self, t, **kwargs): 140 | t = t / self.max_t 141 | self.model = self.model.to(t) 142 | samples = self.model.sample(t) 143 | return samples * self.data_std.to(t) + self.data_mean.to(t) 144 | 145 | def _normalize_batch(self, batch): 146 | x, t = batch 147 | x = (x - self.data_mean.to(x)) / self.data_std.to(x) 148 | t = t / self.max_t 149 | return x, t 150 | -------------------------------------------------------------------------------- /tsdiff/synthetic/ode_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as td 5 | from pytorch_lightning import LightningModule 6 | from torchdiffeq import odeint_adjoint as odeint 7 | 8 | 9 | 10 | class ODEFunc(nn.Module): 11 | def __init__(self, dim, hidden_dim): 12 | super().__init__() 13 | self.net = nn.Sequential( 14 | nn.Linear(hidden_dim + 1, hidden_dim), 15 | nn.Tanh(), 16 | nn.Linear(hidden_dim, hidden_dim), 17 | nn.Tanh(), 18 | ) 19 | 20 | def forward(self, t, state): 21 | y, diff = state 22 | y = torch.cat([t * diff, y], -1) 23 | dy = self.net(y) * diff 24 | return dy, torch.zeros_like(diff).to(dy) 25 | 26 | 27 | class Model(nn.Module): 28 | def __init__(self, dim, hidden_dim): 29 | super().__init__() 30 | self.enc_net = nn.Sequential( 31 | nn.Linear(dim, hidden_dim), 32 | nn.GRU(hidden_dim, 2 * hidden_dim, num_layers=2, batch_first=True), 33 | ) 34 | self.net = ODEFunc(dim, hidden_dim) 35 | self.proj = nn.Sequential( 36 | nn.Linear(hidden_dim, hidden_dim), 37 | nn.Tanh(), 38 | nn.Linear(hidden_dim, 2 * 
dim), 39 | ) 40 | 41 | def encoder(self, x): 42 | _, h = self.enc_net(x) 43 | h = h[-1].unsqueeze(-2).repeat_interleave(x.shape[-2], dim=-2) 44 | mu, sigma = h.chunk(2, dim=-1) 45 | sigma = F.softplus(0.1 + 0.9 * sigma) 46 | return mu, sigma 47 | 48 | def decoder(self, z, t): 49 | z = odeint(self.net, (z, t), torch.Tensor([0, 1]).to(t)) 50 | z = z[0][1] # first state, solution at t=1 (with reparam.) 51 | x_mu, x_sigma = self.proj(z).chunk(2, dim=-1) 52 | x_sigma = F.softplus(0.1 + 0.9 * x_sigma) 53 | return x_mu, x_sigma 54 | 55 | 56 | class ODEModule(LightningModule): 57 | def __init__( 58 | self, 59 | # Data params 60 | dim: int, 61 | data_mean: torch.Tensor = None, 62 | data_std: torch.Tensor = None, 63 | max_t: float = None, 64 | # NN params 65 | hidden_dim: int = None, 66 | # Training params 67 | learning_rate: float = None, 68 | weight_decay: float = None, 69 | **kwargs, 70 | ): 71 | super().__init__() 72 | self.dim = dim 73 | self.hidden_dim = hidden_dim 74 | self.learning_rate = learning_rate 75 | self.weight_decay = weight_decay 76 | self.data_mean = data_mean 77 | self.data_std = data_std 78 | self.max_t = max_t 79 | 80 | self.save_hyperparameters() 81 | 82 | self.model = Model(dim, hidden_dim) 83 | 84 | def get_loss(self, x, t): 85 | mu, sigma = self.model.encoder(x.flip(dims=[-2])) 86 | z = torch.randn_like(mu) * sigma + mu 87 | 88 | x_mu, x_sigma = self.model.decoder(z, t) 89 | 90 | px = td.Normal(x_mu, x_sigma) 91 | pz = td.Normal(mu, sigma) 92 | qz = td.Normal(torch.zeros_like(mu), torch.ones_like(sigma)) 93 | 94 | kl = td.kl_divergence(pz, qz) 95 | 96 | # loss = -px.log_prob(x) + kl.sum(-1, keepdim=True) 97 | loss = (x - x_mu)**2 + kl.sum(-1, keepdim=True) 98 | return loss 99 | 100 | def forward(self, batch, log_name=None): 101 | x, t = self._normalize_batch(batch) 102 | loss = self.get_loss(x, t).mean() 103 | if log_name is not None: 104 | self.log(log_name, loss) 105 | return loss 106 | 107 | def training_step(self, batch, batch_idx): 108 | return 
self.forward(batch, 'train_loss') 109 | 110 | def validation_step(self, batch, batch_idx): 111 | return self.forward(batch, 'val_loss') 112 | 113 | def test_step(self, batch, batch_idx): 114 | x, t = self._normalize_batch(batch) 115 | log_prob = -self.get_loss(x, t) / x.shape[-2] 116 | log_prob = log_prob - self.data_std.log().sum() 117 | self.log('test_log_prob', log_prob.mean()) 118 | 119 | def configure_optimizers(self): 120 | return torch.optim.Adam( 121 | self.parameters(), 122 | lr=self.learning_rate, 123 | weight_decay=self.weight_decay, 124 | ) 125 | 126 | @torch.no_grad() 127 | def sample(self, t, **kwargs): 128 | t = t / self.max_t 129 | z = torch.randn(*t.shape[:-2], 1, self.hidden_dim).to(t) 130 | z = z.repeat_interleave(t.shape[-2], dim=-2) 131 | samples, _ = self.model.decoder(z, t) 132 | return samples * self.data_std.to(t) + self.data_mean.to(t) 133 | 134 | def _normalize_batch(self, batch): 135 | x, t = batch 136 | x = (x - self.data_mean.to(x)) / self.data_std.to(x) 137 | t = t / self.max_t 138 | return x, t 139 | -------------------------------------------------------------------------------- /tsdiff/synthetic/sde_model.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from torchtyping import TensorType 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | from torch import distributions 8 | import torchsde 9 | from pytorch_lightning import LightningModule 10 | 11 | def _stable_division(a, b, epsilon=1e-7): 12 | b = torch.where(b.abs().detach() > epsilon, b, torch.full_like(b, fill_value=epsilon) * b.sign()) 13 | return a / b 14 | 15 | class Encoder(nn.Module): 16 | def __init__(self, input_size, hidden_size, output_size): 17 | super(Encoder, self).__init__() 18 | self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size) 19 | self.lin = nn.Linear(hidden_size, output_size) 20 | 21 | def forward(self, inp): 22 | out, _ = self.gru(inp) 23 | out = 
self.lin(out) 24 | return out 25 | 26 | 27 | class LatentSDE(nn.Module): 28 | sde_type = "ito" 29 | noise_type = "diagonal" 30 | 31 | def __init__(self, data_size, latent_size, context_size, hidden_size): 32 | super(LatentSDE, self).__init__() 33 | # Encoder. 34 | self.encoder = Encoder(input_size=data_size, hidden_size=hidden_size, output_size=context_size) 35 | self.qz0_net = nn.Linear(context_size, latent_size + latent_size) 36 | 37 | # Decoder. 38 | self.f_net = nn.Sequential( 39 | nn.Linear(latent_size + context_size, hidden_size), 40 | nn.Softplus(), 41 | nn.Linear(hidden_size, hidden_size), 42 | nn.Softplus(), 43 | nn.Linear(hidden_size, latent_size), 44 | ) 45 | self.h_net = nn.Sequential( 46 | nn.Linear(latent_size, hidden_size), 47 | nn.Softplus(), 48 | nn.Linear(hidden_size, hidden_size), 49 | nn.Softplus(), 50 | nn.Linear(hidden_size, latent_size), 51 | ) 52 | # This needs to be an element-wise function for the SDE to satisfy diagonal noise. 53 | self.g_nets = nn.ModuleList( 54 | [ 55 | nn.Sequential( 56 | nn.Linear(1, hidden_size), 57 | nn.Softplus(), 58 | nn.Linear(hidden_size, 1), 59 | nn.Sigmoid() 60 | ) 61 | for _ in range(latent_size) 62 | ] 63 | ) 64 | self.projector = nn.Linear(latent_size, data_size) 65 | 66 | self.pz0_mean = nn.Parameter(torch.zeros(1, latent_size)) 67 | self.pz0_logstd = nn.Parameter(torch.zeros(1, latent_size)) 68 | 69 | self._ctx = None 70 | 71 | def contextualize(self, ctx): 72 | self._ctx = ctx # A tuple of tensors of sizes (T,), (T, batch_size, d). 73 | 74 | def f(self, t, y): 75 | ts, ctx = self._ctx 76 | i = min(torch.searchsorted(ts, t, right=True), len(ts) - 1) 77 | return self.f_net(torch.cat((y, ctx[i]), dim=1)) 78 | 79 | def h(self, t, y): 80 | return self.h_net(y) 81 | 82 | def g(self, t, y): # Diagonal diffusion. 
83 | y = torch.split(y, split_size_or_sections=1, dim=1) 84 | out = [g_net_i(y_i) for (g_net_i, y_i) in zip(self.g_nets, y)] 85 | return torch.cat(out, dim=1) 86 | 87 | def forward(self, xs, ts, noise_std, adjoint=False, method="euler"): 88 | # Contextualization is only needed for posterior inference. 89 | xs = xs.transpose(0, 1) 90 | 91 | ctx = self.encoder(torch.flip(xs, dims=(0,))) 92 | ctx = torch.flip(ctx, dims=(0,)) 93 | self.contextualize((ts, ctx)) 94 | 95 | qz0_mean, qz0_logstd = self.qz0_net(ctx[0]).chunk(chunks=2, dim=1) 96 | z0 = qz0_mean + qz0_logstd.exp() * torch.randn_like(qz0_mean) 97 | 98 | if adjoint: 99 | # Must use the argument `adjoint_params`, since `ctx` is not part of the input to `f`, `g`, and `h`. 100 | adjoint_params = ( 101 | (ctx,) + 102 | tuple(self.f_net.parameters()) + tuple(self.g_nets.parameters()) + tuple(self.h_net.parameters()) 103 | ) 104 | zs, log_ratio = torchsde.sdeint_adjoint( 105 | self, z0, ts, adjoint_params=adjoint_params, dt=1e-2, logqp=True, method=method) 106 | else: 107 | zs, log_ratio = torchsde.sdeint(self, z0, ts, dt=1e-2, logqp=True, method=method) 108 | 109 | _xs = self.projector(zs) 110 | xs_dist = torch.distributions.Normal(loc=_xs, scale=noise_std) 111 | 112 | log_pxs = xs_dist.log_prob(xs).sum(dim=(0, 2)).mean(dim=0) 113 | 114 | qz0 = torch.distributions.Normal(loc=qz0_mean, scale=qz0_logstd.exp()) 115 | pz0 = torch.distributions.Normal(loc=self.pz0_mean, scale=self.pz0_logstd.exp()) 116 | logqp0 = torch.distributions.kl_divergence(qz0, pz0).sum(dim=1).mean(dim=0) 117 | logqp_path = log_ratio.sum(dim=0).mean(dim=0) 118 | return log_pxs, logqp0 + logqp_path 119 | 120 | @torch.no_grad() 121 | def sample(self, batch_size, ts, bm=None): 122 | eps = torch.randn(size=(batch_size, *self.pz0_mean.shape[1:]), device=self.pz0_mean.device) 123 | z0 = self.pz0_mean + self.pz0_logstd.exp() * eps 124 | zs = torchsde.sdeint(self, z0, ts, names={'drift': 'h'}, dt=1e-3, bm=bm) 125 | # Most of the times in ML, we don't 
sample the observation noise for visualization purposes. 126 | _xs = self.projector(zs) 127 | return _xs 128 | 129 | class SDEModule(LightningModule): 130 | def __init__( 131 | self, 132 | # Data params 133 | dim: int, 134 | data_mean: torch.Tensor = None, 135 | data_std: torch.Tensor = None, 136 | max_t: float = None, 137 | # NN params 138 | hidden_dim: int = None, 139 | # Training params 140 | learning_rate: float = None, 141 | weight_decay: float = None, 142 | **kwargs, 143 | ): 144 | super().__init__() 145 | self.dim = dim 146 | self.hidden_dim = hidden_dim 147 | self.learning_rate = learning_rate 148 | self.weight_decay = weight_decay 149 | self.data_mean = data_mean 150 | self.data_std = data_std 151 | self.max_t = max_t 152 | 153 | self.save_hyperparameters() 154 | 155 | self.latent_sde = LatentSDE(dim, hidden_dim, hidden_dim, hidden_dim) 156 | 157 | def forward(self, batch, log_name=None): 158 | x, t = self._normalize_batch(batch) 159 | t = t[0,:,0] 160 | log_prob, kl = self.latent_sde(x, t, noise_std=0.01) 161 | loss = -log_prob + kl 162 | loss = loss.mean() 163 | if log_name is not None: 164 | self.log(log_name, loss) 165 | return loss 166 | 167 | def training_step(self, batch, batch_idx): 168 | return self.forward(batch, 'train_loss') 169 | 170 | def validation_step(self, batch, batch_idx): 171 | return self.forward(batch, 'val_loss') 172 | 173 | def test_step(self, batch, batch_idx): 174 | log_prob = -self(batch) / batch[0].shape[-2] 175 | log_prob = log_prob - self.data_std.log().sum() 176 | self.log('test_log_prob', log_prob.mean()) 177 | 178 | def configure_optimizers(self): 179 | return torch.optim.Adam( 180 | self.latent_sde.parameters(), 181 | lr=self.learning_rate, 182 | weight_decay=self.weight_decay, 183 | ) 184 | 185 | @torch.no_grad() 186 | def sample(self, t, **kwargs): 187 | t = t / self.max_t 188 | t = t.cpu() 189 | 190 | eps = torch.randn(t.shape[0], 1).to(t) 191 | bm = torchsde.BrownianInterval( 192 | t0=t[0,0,0], 193 | t1=t[0,-1,0], 194 
| size=(t.shape[0], 1), 195 | device=t.device, 196 | levy_area_approximation='space-time', 197 | ) 198 | 199 | samples = self.latent_sde.sample_q(ts=t[0,:,0], batch_size=t.shape[0], eps=eps, bm=bm).squeeze() 200 | samples = samples.transpose(0, 1).unsqueeze(-1) 201 | return samples * self.data_std.to(t) + self.data_mean.to(t) 202 | 203 | def _normalize_batch(self, batch): 204 | x, t = batch 205 | x = (x - self.data_mean.to(x)) / self.data_std.to(x) 206 | t = t / self.max_t 207 | return x, t 208 | -------------------------------------------------------------------------------- /tsdiff/synthetic/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import argparse 3 | import numpy as np 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | from pytorch_lightning import Trainer 9 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 10 | 11 | from tsdiff.synthetic.data import DataModule 12 | from tsdiff.synthetic.diffusion_model import DiffusionModule 13 | from tsdiff.synthetic.ode_model import ODEModule 14 | from tsdiff.synthetic.nf_model import NFModule 15 | from tsdiff.synthetic.sde_model import SDEModule 16 | 17 | warnings.simplefilter(action='ignore', category=(np.VisibleDeprecationWarning)) 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | SAMPLE_DIR = Path(__file__).parents[2].resolve() / 'data/samples' 21 | SAMPLE_DIR.mkdir(exist_ok=True, parents=True) 22 | 23 | def train( 24 | *, 25 | seed: int, 26 | dataset: str, 27 | diffusion: str, 28 | model: str, 29 | gp_sigma: float = None, 30 | ou_theta: float = None, 31 | beta_start: float = None, 32 | beta_end: float = None, 33 | batch_size: int = 256, 34 | hidden_dim: int = 128, 35 | predict_gaussian_noise: bool = True, 36 | beta_fn: str = 'linear', 37 | discrete_num_steps: int = 100, 38 | continuous_t1: float = 1, 39 | loss_weighting: str = 'exponential', 40 | learning_rate: float = 1e-3, 41 | 
weight_decay: float = 0, 42 | epochs: int = 100, 43 | patience: int = 20, 44 | return_model: bool = False, 45 | ): 46 | np.random.seed(seed) 47 | torch.manual_seed(seed) 48 | 49 | # Load data 50 | datamodule = DataModule(dataset, batch_size=batch_size) 51 | 52 | if diffusion is not None: 53 | if 'Continuous' in diffusion: 54 | beta_start, beta_end = 0.1, 20 55 | else: 56 | beta_start, beta_end = 1e-4, 20 / discrete_num_steps 57 | 58 | if model == 'ode': 59 | Module = ODEModule 60 | elif model == 'nf': 61 | Module = NFModule 62 | elif model == 'sde': 63 | Module = SDEModule 64 | else: 65 | Module = DiffusionModule 66 | 67 | # Load model 68 | module = Module( 69 | dim=datamodule.dim, 70 | data_mean=datamodule.x_mean, 71 | data_std=datamodule.x_std, 72 | max_t=datamodule.t_max, 73 | diffusion=diffusion, 74 | model=model, 75 | predict_gaussian_noise=predict_gaussian_noise, 76 | gp_sigma=gp_sigma, 77 | ou_theta=ou_theta, 78 | beta_fn=beta_fn, 79 | discrete_num_steps=discrete_num_steps, 80 | beta_start=beta_start, 81 | beta_end=beta_end, 82 | continuous_t1=continuous_t1, 83 | loss_weighting=loss_weighting, 84 | hidden_dim=hidden_dim, 85 | learning_rate=learning_rate, 86 | weight_decay=weight_decay, 87 | ) 88 | 89 | # Train 90 | checkpointing = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, filename='best-checkpoint') 91 | early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=patience) 92 | trainer = Trainer( 93 | gpus=1, 94 | auto_select_gpus=True, 95 | max_epochs=epochs, 96 | log_every_n_steps=1, 97 | enable_checkpointing=True, 98 | callbacks=[early_stopping, checkpointing], 99 | ) 100 | 101 | trainer.fit(module, train_dataloaders=datamodule.train_dataloader(), val_dataloaders=datamodule.val_dataloader()) 102 | 103 | # Load best model 104 | module = Module.load_from_checkpoint(checkpointing.best_model_path) 105 | 106 | # Evaluation 107 | metrics = trainer.test(module, datamodule.test_dataloader()) 108 | 109 | # Generate samples 110 | if 
seed == 1: 111 | t = datamodule.trainset[:1000][1].to(device) 112 | samples = module.sample(t=t, use_ode=True) 113 | np.save(SAMPLE_DIR / f'{dataset}-{diffusion}-{model}-{gp_sigma or ou_theta}-{predict_gaussian_noise}', samples.detach().cpu().numpy()) 114 | 115 | if return_model: 116 | return module, datamodule, trainer, metrics 117 | 118 | return metrics[0] 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser(description='Train synthetic data model.') 123 | parser.add_argument('--seed', type=int, default=1) 124 | parser.add_argument('--dataset', type=str, choices=['cir', 'lorenz', 'ou', 'predator_prey', 'sine', 'sink']) 125 | parser.add_argument('--diffusion', type=str, choices=[ 126 | 'GaussianDiffusion', 'OUDiffusion', 'GPDiffusion', 127 | 'ContinuousGaussianDiffusion', 'ContinuousOUDiffusion', 'ContinuousGPDiffusion', 128 | ]) 129 | parser.add_argument('--model', type=str, choices=['feedforward', 'rnn', 'cnn', 'ode', 'transformer']) 130 | parser.add_argument('--gp_sigma', type=float, default=0.1) 131 | parser.add_argument('--ou_theta', type=float, default=0.5) 132 | parser.add_argument('--epochs', type=int, default=100) 133 | args = parser.parse_args() 134 | 135 | metrics = train(**args.__dict__) 136 | 137 | for key, value in metrics.items(): 138 | print(f'{key}:\t{value:.4f}') 139 | -------------------------------------------------------------------------------- /tsdiff/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbilos/tsdiff/32b23f2b7f5ec4d68bc533dda8a0096086cbd0ab/tsdiff/test/__init__.py -------------------------------------------------------------------------------- /tsdiff/test/test_beta_scheduler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from tsdiff.diffusion.beta_scheduler import BetaLinear 5 | 6 | 7 | def test_linear(): 8 | torch.manual_seed(123) 9 | 10
| f = BetaLinear(start=6, end=-3) 11 | 12 | t = torch.linspace(0, 1, 100).view(1, -1, 1).repeat(10, 1, 4) 13 | t = t.requires_grad_(True) 14 | 15 | beta = f(t) 16 | 17 | # Check boundaries 18 | assert (beta[:,0] == 6).all() 19 | assert (beta[:,-1] == -3).all() 20 | 21 | # Check integral 22 | beta_int = f.integral(t) 23 | 24 | beta_int_derivative = torch.autograd.grad(beta_int.sum(), t)[0] 25 | assert torch.allclose(beta, beta_int_derivative, atol=1e-6) 26 | -------------------------------------------------------------------------------- /tsdiff/test/test_ddpm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from tsdiff.diffusion import GaussianDiffusion, OUDiffusion, GPDiffusion 5 | from tsdiff.diffusion.beta_scheduler import BetaLinear 6 | 7 | 8 | @pytest.mark.parametrize('input_shape', [(1,), (1, 1), (10, 20, 4), (4, 3, 7, 2)]) 9 | def test_shapes(input_shape): 10 | num_steps = 1000 11 | x = torch.randn(*input_shape) 12 | i = torch.randint_like(x, 0, num_steps) 13 | 14 | diffusion = GaussianDiffusion(dim=input_shape[-1], beta_fn=BetaLinear(1e-4, 0.02), num_steps=num_steps) 15 | 16 | y, noise = diffusion(x, i) 17 | assert not torch.isnan(y).any() and not torch.isnan(noise).any() 18 | assert y.shape == noise.shape == x.shape 19 | 20 | 21 | def test_alpha_vs_beta(): 22 | torch.manual_seed(123) 23 | 24 | N = 100_000 # Number of samples 25 | max_steps = 1000 26 | steps = 300 # Number of steps after which we calculate statistics 27 | 28 | x = torch.randn(N, 1).square() + 1 # Sample from distribution that is not normal 29 | i = torch.Tensor([steps]).repeat(N).unsqueeze(-1) 30 | 31 | diffusion = GaussianDiffusion(dim=1, beta_fn=BetaLinear(1e-4, 0.02), num_steps=max_steps) 32 | 33 | y, _ = diffusion(x, i) 34 | 35 | y_ = x.clone() 36 | for j in range(steps): 37 | y_ = torch.sqrt(1 - diffusion.betas[j]) * y_ + torch.sqrt(diffusion.betas[j]) * torch.randn_like(y_) 38 | 39 | 
assert torch.allclose(y.mean(), y_.mean(), atol=0.1) 40 | assert torch.allclose(y.std(), y_.std(), atol=0.1) 41 | 42 | 43 | @pytest.mark.parametrize('diffusion', [OUDiffusion, GPDiffusion]) 44 | @pytest.mark.parametrize('input_shape', [(1, 1), (1, 1, 1), (10, 20, 4), (4, 3, 7, 2)]) 45 | @pytest.mark.parametrize('predict_gaussian_noise', [True, False]) 46 | def test_time_diffusion_shapes(diffusion, input_shape, predict_gaussian_noise): 47 | num_steps = 1000 48 | x = torch.randn(*input_shape) 49 | t = torch.rand(*input_shape[:-1], 1) 50 | i = torch.randint_like(x, 0, num_steps) 51 | 52 | diffusion = diffusion( 53 | dim=input_shape[-1], 54 | beta_fn=BetaLinear(1e-4, 0.02), 55 | predict_gaussian_noise=predict_gaussian_noise, 56 | num_steps=num_steps, 57 | ) 58 | 59 | y, noise = diffusion(x, i=i, t=t) 60 | assert not torch.isnan(y).any() and not torch.isnan(noise).any() 61 | assert y.shape == noise.shape == x.shape 62 | 63 | 64 | @pytest.mark.parametrize('diffusion', [OUDiffusion, GPDiffusion]) 65 | @pytest.mark.parametrize('dim', [1, 3]) 66 | def test_time_diffusion_alpha_vs_beta(diffusion, dim): 67 | torch.manual_seed(123) 68 | 69 | N = 10_000 # Number of samples 70 | max_steps = 1000 71 | steps = 300 # Number of steps after which we calculate statistics 72 | 73 | x = torch.randn(N, 10, dim).square() + 1 # Sample from distribution that is not unit normal 74 | t = torch.linspace(0, 1, 10).view(1, -1, 1).repeat(N, 1, 1) 75 | i = torch.Tensor([steps]).view(1, 1, 1).repeat(N, 10, 1) 76 | 77 | diffusion = diffusion( 78 | dim=dim, 79 | beta_fn=BetaLinear(1e-4, 0.02), 80 | num_steps=max_steps, 81 | predict_gaussian_noise=True, 82 | ) 83 | 84 | y, _ = diffusion(x, t=t, i=i) 85 | 86 | y_ = x.clone() 87 | for j in range(steps): 88 | y_ = torch.sqrt(1 - diffusion.betas[j]) * y_ + torch.sqrt(diffusion.betas[j]) * diffusion.noise(t=t) 89 | 90 | assert torch.allclose(y.mean(0), y_.mean(0), atol=0.1) 91 | assert torch.allclose(y.std(0), y_.std(0), atol=0.1) 92 | 93 | 94 | 
@pytest.mark.parametrize('diffusion', [GaussianDiffusion, OUDiffusion, GPDiffusion]) 95 | @pytest.mark.parametrize('parallel_elbo', [True, False]) 96 | @pytest.mark.parametrize('predict_gaussian_noise', [True, False]) 97 | def test_latent_inputs(diffusion, parallel_elbo, predict_gaussian_noise): 98 | torch.manual_seed(123) 99 | 100 | max_steps = 100 101 | N, T, D, H = 32, 10, 4, 64 102 | 103 | x = torch.randn(N, T, D) 104 | t = torch.rand(N, T, 1).sort(dim=1)[0] 105 | latent = torch.randn(N, T, H) 106 | 107 | class Model(nn.Module): 108 | def __init__(self): 109 | super().__init__() 110 | self.net = nn.Sequential( 111 | nn.Linear(D + H + 2, H), 112 | nn.Tanh(), 113 | nn.Linear(H, D), 114 | ) 115 | 116 | def forward(self, x, t, i, latent): 117 | x = torch.cat([x, t, i, latent], -1) 118 | return self.net(x) 119 | 120 | model = Model() 121 | diffusion = diffusion( 122 | dim=D, 123 | beta_fn=BetaLinear(1e-4, 2 / max_steps * 10), 124 | num_steps=max_steps, 125 | predict_gaussian_noise=predict_gaussian_noise, 126 | parallel_elbo=parallel_elbo, 127 | ) 128 | 129 | samples = diffusion.sample(model, num_samples=(N, T), t=t, latent=latent) 130 | assert samples.shape == (N, T, D) 131 | assert not torch.any(torch.isnan(samples)) 132 | 133 | elbo = diffusion.log_prob(model, x, t=t, latent=latent, num_samples=30) 134 | assert elbo.shape == (N, 1) 135 | assert not torch.any(torch.isnan(elbo)) 136 | 137 | # Empirical result to catch bigger errors 138 | elbo = elbo.mean() 139 | assert elbo > -36 and elbo < -30 140 | 141 | loss = diffusion.get_loss(model, x, t=t, latent=latent) 142 | assert loss.shape == (N, T, D) 143 | assert not torch.any(torch.isnan(loss)) 144 | -------------------------------------------------------------------------------- /tsdiff/test/test_sde_diffusion.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | import torchsde 5 | 6 | from tsdiff.diffusion import ( 7 | 
GaussianDiffusion, 8 | OUDiffusion, 9 | ContinuousGaussianDiffusion, 10 | ContinuousGPDiffusion, 11 | ContinuousOUDiffusion, 12 | ) 13 | from tsdiff.diffusion.beta_scheduler import BetaLinear 14 | from tsdiff.diffusion.noise import OrnsteinUhlenbeck 15 | 16 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | 18 | 19 | def test_sde_diffusion_as_ddpm(): 20 | num_steps = 1000 21 | step = 300 22 | N = 100_000 23 | beta_start, beta_end = 1e-4, 2 / num_steps * 10 24 | 25 | x = torch.randn(N, 1).square() + 1 26 | i = torch.Tensor([step]).repeat(N).unsqueeze(-1) 27 | 28 | beta_fn = BetaLinear(beta_start, beta_end) 29 | rescaled_beta_fn = BetaLinear(beta_start / num_steps, beta_end / num_steps) 30 | 31 | ddpm = GaussianDiffusion(dim=1, beta_fn=beta_fn, num_steps=num_steps, predict_gaussian_noise=False) 32 | sdediff = ContinuousGaussianDiffusion(dim=1, beta_fn=rescaled_beta_fn, predict_gaussian_noise=False) 33 | 34 | y1, _ = ddpm(x, i) 35 | y2, _ = sdediff(x, i) 36 | 37 | assert torch.allclose(y1.mean(), y2.mean(), atol=0.1) 38 | assert torch.allclose(y1.std(), y2.std(), atol=0.1) 39 | 40 | 41 | def test_sde_ou_diffusion_as_ddpm(): 42 | num_steps = 1000 43 | step = 300 44 | N = 100_000 45 | beta_start, beta_end = 1e-4, 2 / num_steps * 10 46 | 47 | x = torch.randn(N, 10, 1).square() + 1 48 | t = torch.linspace(0, 1, 10).view(1, -1, 1).repeat(N, 1, 1) 49 | i = torch.Tensor([step]).view(1, 1, 1).repeat(N, 10, 1) 50 | 51 | beta_fn = BetaLinear(beta_start, beta_end) 52 | rescaled_beta_fn = BetaLinear(beta_start / num_steps, beta_end / num_steps) 53 | 54 | ddpm = OUDiffusion(dim=1, beta_fn=beta_fn, num_steps=num_steps, predict_gaussian_noise=False) 55 | sdediff = ContinuousOUDiffusion(dim=1, beta_fn=rescaled_beta_fn, t1=num_steps, predict_gaussian_noise=False) 56 | 57 | y1, _ = ddpm(x, t=t, i=i) 58 | y2, _ = sdediff(x, t=t, i=i) 59 | 60 | assert torch.allclose(y1.mean(0), y2.mean(0), atol=0.1) 61 | assert torch.allclose(y1.std(0), y2.std(0), atol=0.1) 62 | 63 | 64 | def 
test_sde_diffusion_score(): 65 | torch.manual_seed(123) 66 | 67 | N = 10_000 68 | 69 | x = torch.randn(1).repeat(N).unsqueeze(-1).square() + 1 70 | i = torch.Tensor([0.3]).repeat(N).unsqueeze(-1) 71 | 72 | sdediff = ContinuousGaussianDiffusion(dim=1, beta_fn=BetaLinear(0, 10)) 73 | 74 | y, noise, mean, std, _ = sdediff(x, i, _return_all=True, predict_gaussian_noise=False) 75 | score = -noise / std 76 | y = y.requires_grad_(True) 77 | 78 | mean = mean.unique() 79 | std = std.unique() 80 | 81 | assert len(mean) == 1 and len(std) == 1 82 | assert torch.allclose(y.mean(), mean, atol=0.05) 83 | assert torch.allclose(y.std(), std, atol=0.05) 84 | 85 | dist = torch.distributions.MultivariateNormal(mean, std.unsqueeze(-1).square()) 86 | log_prob = dist.log_prob(y) 87 | 88 | true_score = torch.autograd.grad(log_prob.sum(), y)[0] 89 | 90 | assert torch.allclose(true_score, score, atol=1e-4) 91 | 92 | 93 | @pytest.mark.parametrize('Diffusion', [ContinuousGPDiffusion, ContinuousOUDiffusion]) 94 | @pytest.mark.parametrize('predict_gaussian_noise', [True, False]) 95 | def test_sde_time_diffusion_score(Diffusion, predict_gaussian_noise): 96 | np.random.seed(132) 97 | torch.manual_seed(123) 98 | 99 | N = 10_000 100 | T = 10 101 | t1 = 1 102 | 103 | x = torch.randn(1, T, 1).repeat(N, 1, 1).square().to(device) + 1 104 | t = torch.linspace(0, 1, T).view(1, -1, 1).repeat(N, 1, 1).to(device) 105 | 106 | class SDE(torch.nn.Module): 107 | sde_type = 'ito' 108 | noise_type = 'general' 109 | 110 | def __init__(self, beta_fn, cov): 111 | super().__init__() 112 | self.beta_fn = beta_fn 113 | self.L = torch.linalg.cholesky(cov) 114 | 115 | def f(self, t, x): 116 | return -0.5 * self.beta_fn(t) * x 117 | 118 | def g(self, t, x): 119 | return self.beta_fn(t).sqrt() * self.L 120 | 121 | 122 | for ratio in [0.01, 0.3, 0.9]: 123 | i = torch.Tensor([ratio * t1]).view(1, 1, 1).repeat(N, T, 1).to(device) 124 | 125 | beta_fn = BetaLinear(0.1, 10) 126 | sdediff = Diffusion( 127 | dim=1, 128 | t1=t1, 
129 | beta_fn=beta_fn, 130 | predict_gaussian_noise=predict_gaussian_noise, 131 | ).to(device) 132 | 133 | y, diff_noise, diff_mean, diff_std, diff_cov = sdediff(x, t=t, i=i, _return_all=True) 134 | L = torch.linalg.cholesky(diff_cov[0]) 135 | 136 | y = y.requires_grad_(True) 137 | 138 | assert len(diff_mean.unique()) == 10 139 | assert torch.all(diff_cov[0] == diff_cov[1]) 140 | 141 | diff_mean = diff_mean[0,:,0] 142 | diff_cov_tilde = (diff_cov * diff_std**2)[0] 143 | 144 | def model(*args, noise=None, **kwargs): 145 | # "Fake" model that perfectly predicts the noise 146 | # In case `predict_gaussian_noise=True`, undo covariance 147 | if predict_gaussian_noise: 148 | noise = torch.linalg.inv(L) @ noise 149 | return noise 150 | 151 | model_score = sdediff._get_score(model, x, i=i, t=t, L=L, noise=diff_noise) 152 | 153 | # Statistics from diffusion function vs. empirical covariance of diffused values 154 | empirical_diff_mean = y.mean([0, 2]) 155 | empirical_diff_cov = torch.cov(y.squeeze(-1).T) 156 | assert torch.allclose(empirical_diff_mean, diff_mean, atol=0.05) 157 | assert torch.allclose(empirical_diff_cov, diff_cov_tilde, atol=0.05) 158 | 159 | # Empirical score vs. score from diffusion function 160 | diff_dist = torch.distributions.MultivariateNormal(diff_mean, diff_cov_tilde) 161 | log_prob = diff_dist.log_prob(y.squeeze(-1)) 162 | empirical_diff_score = torch.autograd.grad(log_prob.sum(), y)[0] 163 | assert torch.allclose(empirical_diff_score, model_score, atol=0.05, rtol=0.05) 164 | 165 | # True theoretical covariance vs. calculated in diffusion function 166 | time_cov = sdediff.noise.covariance(t) 167 | assert torch.allclose(time_cov[0], diff_cov, atol=0.05) 168 | true_cov = (1 - torch.exp(-beta_fn.integral(i[0][0]))) * time_cov[0] 169 | assert torch.allclose(true_cov, diff_cov_tilde, atol=0.05) 170 | 171 | # SDE sequential computation vs. 
direct covariance computation 172 | sde = SDE(beta_fn=beta_fn, cov=time_cov) 173 | times = torch.Tensor([0, ratio * t1]).to(y) 174 | with torch.no_grad(): 175 | true_y = torchsde.sdeint(sde, x.squeeze(-1), times, dt=5e-4)[-1] 176 | sde_cov = torch.cov(true_y.T) 177 | assert torch.allclose(true_cov, sde_cov, atol=0.05) 178 | -------------------------------------------------------------------------------- /tsdiff/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dotdict import * 2 | from .exception import * 3 | from .feedforward import * 4 | from .positional_encoding import * 5 | from .trainer import * 6 | from .epsilon_theta import * 7 | -------------------------------------------------------------------------------- /tsdiff/utils/dotdict.py: -------------------------------------------------------------------------------- 1 | class dotdict(dict): 2 | """ Dot notation access to dict attributes """ 3 | __getattr__ = dict.get 4 | __setattr__ = dict.__setitem__ 5 | __delattr__ = dict.__delitem__ 6 | -------------------------------------------------------------------------------- /tsdiff/utils/epsilon_theta.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DiffusionEmbedding(nn.Module): 9 | def __init__(self, dim, proj_dim, max_steps=500): 10 | super().__init__() 11 | self.register_buffer( 12 | "embedding", self._build_embedding(dim, max_steps), persistent=False 13 | ) 14 | self.projection1 = nn.Linear(dim * 2, proj_dim) 15 | self.projection2 = nn.Linear(proj_dim, proj_dim) 16 | 17 | def forward(self, diffusion_step): 18 | x = self.embedding[diffusion_step] 19 | x = self.projection1(x) 20 | x = F.silu(x) 21 | x = self.projection2(x) 22 | x = F.silu(x) 23 | return x 24 | 25 | def _build_embedding(self, dim, max_steps): 26 | steps = 
torch.arange(max_steps).unsqueeze(1) # [T,1] 27 | dims = torch.arange(dim).unsqueeze(0) # [1,dim] 28 | table = steps * 10.0 ** (dims * 4.0 / dim) # [T,dim] 29 | table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) 30 | return table 31 | 32 | 33 | class ResidualBlock(nn.Module): 34 | def __init__(self, hidden_size, residual_channels, dilation): 35 | super().__init__() 36 | self.dilated_conv = nn.Conv1d( 37 | residual_channels, 38 | 2 * residual_channels, 39 | 3, 40 | padding=dilation, 41 | dilation=dilation, 42 | padding_mode="circular", 43 | ) 44 | self.diffusion_projection = nn.Linear(hidden_size, residual_channels) 45 | self.conditioner_projection = nn.Conv1d( 46 | 1, 2 * residual_channels, 1, padding=2, padding_mode="circular" 47 | ) 48 | self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) 49 | 50 | nn.init.kaiming_normal_(self.conditioner_projection.weight) 51 | nn.init.kaiming_normal_(self.output_projection.weight) 52 | 53 | def forward(self, x, conditioner, diffusion_step): 54 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 55 | conditioner = self.conditioner_projection(conditioner) 56 | 57 | y = x + diffusion_step 58 | y = self.dilated_conv(y) + conditioner 59 | 60 | gate, filter = torch.chunk(y, 2, dim=1) 61 | y = torch.sigmoid(gate) * torch.tanh(filter) 62 | 63 | y = self.output_projection(y) 64 | y = F.leaky_relu(y, 0.4) 65 | residual, skip = torch.chunk(y, 2, dim=1) 66 | return (x + residual) / math.sqrt(2.0), skip 67 | 68 | 69 | class CondUpsampler(nn.Module): 70 | def __init__(self, cond_length, target_dim): 71 | super().__init__() 72 | self.linear1 = nn.Linear(cond_length, target_dim // 2) 73 | self.linear2 = nn.Linear(target_dim // 2, target_dim) 74 | 75 | def forward(self, x): 76 | x = self.linear1(x) 77 | x = F.leaky_relu(x, 0.4) 78 | x = self.linear2(x) 79 | x = F.leaky_relu(x, 0.4) 80 | return x 81 | 82 | 83 | class EpsilonTheta(nn.Module): 84 | def __init__( 85 | self, 86 | 
target_dim, 87 | cond_length, 88 | time_emb_dim=16, 89 | residual_layers=8, 90 | residual_channels=8, 91 | dilation_cycle_length=2, 92 | residual_hidden=64, 93 | ): 94 | super().__init__() 95 | self.input_projection = nn.Conv1d( 96 | 1, residual_channels, 1, padding=2, padding_mode="circular" 97 | ) 98 | self.diffusion_embedding = DiffusionEmbedding( 99 | time_emb_dim, proj_dim=residual_hidden 100 | ) 101 | self.cond_upsampler = CondUpsampler( 102 | target_dim=target_dim, cond_length=cond_length 103 | ) 104 | self.residual_layers = nn.ModuleList( 105 | [ 106 | ResidualBlock( 107 | residual_channels=residual_channels, 108 | dilation=2 ** (i % dilation_cycle_length), 109 | hidden_size=residual_hidden, 110 | ) 111 | for i in range(residual_layers) 112 | ] 113 | ) 114 | self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 3) 115 | self.output_projection = nn.Conv1d(residual_channels, 1, 3) 116 | 117 | nn.init.kaiming_normal_(self.input_projection.weight) 118 | nn.init.kaiming_normal_(self.skip_projection.weight) 119 | nn.init.zeros_(self.output_projection.weight) 120 | 121 | def forward(self, inputs, time, cond): 122 | x = self.input_projection(inputs) 123 | x = F.leaky_relu(x, 0.4) 124 | 125 | diffusion_step = self.diffusion_embedding(time) 126 | cond_up = self.cond_upsampler(cond) 127 | skip = [] 128 | for layer in self.residual_layers: 129 | x, skip_connection = layer(x, cond_up, diffusion_step) 130 | skip.append(skip_connection) 131 | 132 | x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers)) 133 | x = self.skip_projection(x) 134 | x = F.leaky_relu(x, 0.4) 135 | x = self.output_projection(x) 136 | return x 137 | -------------------------------------------------------------------------------- /tsdiff/utils/exception.py: -------------------------------------------------------------------------------- 1 | class NotSupportedModelNoiseCombination(Exception): 2 | pass 3 | 
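For reference, the sinusoidal lookup table that `DiffusionEmbedding._build_embedding` (in `tsdiff/utils/epsilon_theta.py`) constructs can be reproduced without PyTorch. The sketch below is an illustration only: `build_embedding` is a hypothetical pure-Python helper, not part of the package, and it returns nested lists instead of a tensor.

```python
import math


def build_embedding(dim, max_steps):
    """Pure-Python sketch of the [max_steps, 2 * dim] sinusoidal table."""
    table = []
    for step in range(max_steps):
        # One angle per embedding dimension: step * 10^(4 * d / dim)
        angles = [step * 10.0 ** (d * 4.0 / dim) for d in range(dim)]
        # Sines first, then cosines, concatenated along the feature axis
        row = [math.sin(a) for a in angles] + [math.cos(a) for a in angles]
        table.append(row)
    return table


table = build_embedding(dim=4, max_steps=3)
print(len(table), len(table[0]))
```

Each row holds `dim` sines followed by the matching `dim` cosines, which is why `projection1` in `DiffusionEmbedding` takes `dim * 2` input features.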
-------------------------------------------------------------------------------- /tsdiff/utils/feedforward.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable 2 | from torchtyping import TensorType 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class FeedForward(nn.Module): 8 | def __init__( 9 | self, 10 | in_dim: int, 11 | hidden_dims: List[int], 12 | out_dim: int, 13 | activation: Callable = nn.ReLU(), 14 | final_activation: Callable = None, 15 | ): 16 | super().__init__() 17 | 18 | hidden_dims = hidden_dims[:] 19 | hidden_dims.append(out_dim) 20 | 21 | layers = [nn.Linear(in_dim, hidden_dims[0])] 22 | 23 | for i in range(len(hidden_dims) - 1): 24 | layers.append(activation) 25 | layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1])) 26 | 27 | if final_activation is not None: 28 | layers.append(final_activation) 29 | 30 | self.net = nn.Sequential(*layers) 31 | 32 | def forward(self, x: TensorType[..., 'in_dim']) -> TensorType[..., 'out_dim']: 33 | return self.net(x) 34 | -------------------------------------------------------------------------------- /tsdiff/utils/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from torchtyping import TensorType 6 | 7 | class PositionalEncoding(nn.Module): 8 | def __init__(self, dim: int, max_value: float): 9 | super().__init__() 10 | self.max_value = max_value 11 | 12 | linear_dim = dim // 2 13 | periodic_dim = dim - linear_dim 14 | 15 | self.scale = torch.exp(-2 * torch.arange(0, periodic_dim).float() * math.log(self.max_value) / periodic_dim) 16 | self.shift = torch.zeros(periodic_dim) 17 | self.shift[::2] = 0.5 * math.pi 18 | 19 | self.linear_proj = nn.Linear(1, linear_dim) 20 | 21 | def forward(self, t: TensorType[..., 'length', 1]) -> TensorType[..., 'length', 'dim']: 22 | periodic = torch.sin(t * 
self.scale.to(t) + self.shift.to(t)) 23 | linear = self.linear_proj(t / self.max_value) 24 | 25 | return torch.cat([linear, periodic], -1) 26 | -------------------------------------------------------------------------------- /tsdiff/utils/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from copy import deepcopy 4 | from tqdm.auto import tqdm 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.optim import Adam 9 | from torch.optim.lr_scheduler import OneCycleLR 10 | from torch.utils.data import DataLoader 11 | 12 | from gluonts.core.component import validated 13 | from pts import Trainer 14 | 15 | 16 | class TrainerForecasting(Trainer): 17 | @validated() 18 | def __init__( 19 | self, 20 | epochs: int = 100, 21 | batch_size: int = 32, 22 | num_batches_per_epoch: int = 50, 23 | learning_rate: float = 1e-3, 24 | weight_decay: float = 1e-6, 25 | maximum_learning_rate: float = 1e-2, 26 | clip_gradient: Optional[float] = None, 27 | patience: int = None, 28 | device: Optional[Union[torch.device, str]] = None, 29 | **kwargs, 30 | ) -> None: 31 | self.epochs = epochs 32 | self.batch_size = batch_size 33 | self.num_batches_per_epoch = num_batches_per_epoch 34 | self.learning_rate = learning_rate 35 | self.weight_decay = weight_decay 36 | self.maximum_learning_rate = maximum_learning_rate 37 | self.clip_gradient = clip_gradient 38 | self.patience = patience 39 | self.device = device 40 | 41 | def __call__( 42 | self, 43 | net: nn.Module, 44 | train_iter: DataLoader, 45 | validation_iter: Optional[DataLoader] = None, 46 | ) -> None: 47 | 48 | optimizer = Adam(net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 49 | 50 | lr_scheduler = OneCycleLR( 51 | optimizer, 52 | max_lr=self.maximum_learning_rate, 53 | steps_per_epoch=self.num_batches_per_epoch, 54 | epochs=self.epochs, 55 | ) 56 | 57 | # Early stopping setup 58 | best_loss = float('inf') 59 | waiting 
= 0 60 | best_net = deepcopy(net.state_dict()) 61 | 62 | # Training loop 63 | for epoch_no in range(self.epochs): 64 | # mark epoch start time 65 | cumm_epoch_loss = 0.0 66 | total = self.num_batches_per_epoch - 1 67 | 68 | # training loop 69 | with tqdm(train_iter, total=total) as it: 70 | for batch_no, data_entry in enumerate(it, start=1): 71 | 72 | optimizer.zero_grad() 73 | 74 | inputs = [v.to(self.device) for v in data_entry.values()] 75 | 76 | loss = net(*inputs) 77 | 78 | if isinstance(loss, (list, tuple)): 79 | loss = loss[0] 80 | 81 | loss.backward() 82 | 83 | if self.clip_gradient is not None: 84 | nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient) 85 | 86 | optimizer.step() 87 | lr_scheduler.step() 88 | 89 | cumm_epoch_loss += loss.item() 90 | avg_epoch_loss = cumm_epoch_loss / batch_no 91 | it.set_postfix( 92 | { 93 | "epoch": f"{epoch_no + 1}/{self.epochs}", 94 | "avg_loss": avg_epoch_loss, 95 | }, 96 | refresh=False, 97 | ) 98 | 99 | if self.num_batches_per_epoch == batch_no: 100 | break 101 | it.close() 102 | 103 | # validation loop 104 | if validation_iter is not None: 105 | cumm_epoch_loss_val = 0.0 106 | with tqdm(validation_iter, total=total, colour="green") as it: 107 | 108 | for batch_no, data_entry in enumerate(it, start=1): 109 | inputs = [v.to(self.device) for v in data_entry.values()] 110 | with torch.no_grad(): 111 | output = net(*inputs) 112 | if isinstance(output, (list, tuple)): 113 | loss = output[0] 114 | else: 115 | loss = output 116 | 117 | cumm_epoch_loss_val += loss.item() 118 | avg_epoch_loss_val = cumm_epoch_loss_val / batch_no 119 | it.set_postfix( 120 | { 121 | "epoch": f"{epoch_no + 1}/{self.epochs}", 122 | "avg_loss": avg_epoch_loss, 123 | "avg_val_loss": avg_epoch_loss_val, 124 | }, 125 | refresh=False, 126 | ) 127 | 128 | if self.num_batches_per_epoch == batch_no: 129 | break 130 | it.close() 131 | 132 | # Early stopping logic 133 | if avg_epoch_loss_val < best_loss: 134 | best_loss = avg_epoch_loss_val 135 | 
best_net = deepcopy(net.state_dict()) 136 | waiting = 0 137 | elif self.patience is not None and waiting > self.patience: 138 | print(f'Early stopping at epoch {epoch_no}') 139 | break 140 | else: 141 | waiting += 1 142 | 143 | # mark epoch end time and log time cost of current epoch 144 | 145 | net.load_state_dict(best_net) 146 | 147 | # python -m tsdiff.train_forecasting --seed 1 --dataset electricity_nips --network timegrad_rnn --noise ou --epochs 100 148 | --------------------------------------------------------------------------------
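The early-stopping bookkeeping in `TrainerForecasting.__call__` can be studied in isolation. The following is a minimal sketch that replays the same counter logic over a list of validation losses; `early_stopping_epoch` is a hypothetical helper introduced here for illustration, and treating `patience=None` as "never stop early" is an assumption of this sketch.

```python
from typing import Optional, Sequence


def early_stopping_epoch(val_losses: Sequence[float], patience: Optional[int]) -> Optional[int]:
    """Return the epoch index at which training would stop, or None if it runs to the end."""
    best_loss = float('inf')
    waiting = 0
    for epoch_no, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: remember it and reset the counter
            best_loss = loss
            waiting = 0
        elif patience is not None and waiting > patience:
            # The counter already exceeds the patience: stop at this epoch
            return epoch_no
        else:
            # No improvement yet, but still within the allowed budget
            waiting += 1
    return None


losses = [1.0, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9]
print(early_stopping_epoch(losses, patience=2))
```

Because the check is `waiting > patience` and the counter is incremented only after the check, the loop tolerates `patience + 1` non-improving epochs and stops on the next one.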