├── .gitignore ├── LICENSE ├── README.md ├── data ├── air_subsampled_train_perc_10.npz ├── air_subsampled_train_perc_100.npz ├── air_subsampled_train_perc_2.npz ├── air_subsampled_train_perc_20.npz ├── air_subsampled_train_perc_5.npz ├── energy_subsampled_train_perc_10.npz ├── energy_subsampled_train_perc_100.npz ├── energy_subsampled_train_perc_2.npz ├── energy_subsampled_train_perc_20.npz ├── energy_subsampled_train_perc_5.npz ├── sine_subsampled_train_perc_10.npz ├── sine_subsampled_train_perc_100.npz ├── sine_subsampled_train_perc_2.npz ├── sine_subsampled_train_perc_20.npz ├── sine_subsampled_train_perc_5.npz ├── stockv_subsampled_train_perc_10.npz ├── stockv_subsampled_train_perc_100.npz ├── stockv_subsampled_train_perc_2.npz ├── stockv_subsampled_train_perc_20.npz └── stockv_subsampled_train_perc_5.npz ├── outputs ├── gen_data │ └── .gitignore ├── models │ └── .gitignore └── tsne │ └── .gitignore ├── requirements.txt └── src ├── __init__.py ├── compare_plot.ipynb ├── config ├── __init__.py └── hyperparameters.yaml ├── data_utils.py ├── paths.py ├── vae ├── __init__.py ├── timevae.py ├── vae_base.py ├── vae_conv_model.py ├── vae_dense_model.py └── vae_utils.py ├── vae_pipeline.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # Environments 8 | .env 9 | .venv 10 | env/ 11 | venv/ 12 | ENV/ 13 | env.bak/ 14 | venv.bak/ 15 | 16 | # Jupyter Notebook 17 | **/.ipynb_checkpoints/ 18 | 19 | tmp/ 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yunzhe Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of TimeVAE 2 | 3 | This repository provides an unofficial PyTorch implementation of the TimeVAE model for generating synthetic time-series data, along with two baseline models: a dense VAE and a convolutional VAE. The file structures and usage closely follow the original TensorFlow implementation to ensure consistency and ease of use. 4 | 5 | Original Tensorflow Repo: [TimeVAE for Synthetic Timeseries Data Generation](https://github.com/abudesai/timeVAE) 6 | 7 | ## Paper Reference 8 | 9 | For a detailed explanation of the methodology, please refer to the original paper: [TIMEVAE: A VARIATIONAL AUTO-ENCODER FOR MULTIVARIATE TIME SERIES GENERATION](https://arxiv.org/abs/2111.08095). 10 | 11 | ## Comparison 12 | 13 | The PyTorch model was trained and evaluated using the provided dataset and the default hyperparameters as in the TensorFlow implementation, each for 1000 epochs, achieving similar convergence (see Figure 4 from the original paper). The plotting script can be found in [src/compare_plot.ipynb](https://github.com/wangyz1999/timeVAE-pytorch/blob/main/src/compare_plot.ipynb). 14 | 15 | ![TSNE TIMEVAE](https://github.com/user-attachments/assets/887a776c-7df6-46f4-9a16-301eb6021967) 16 | 17 | 18 | ## Installation 19 | 20 | Create a virtual environment and install dependencies: 21 | 22 | ```bash 23 | python -m venv venv 24 | source venv/bin/activate # On Windows use `venv\Scripts\activate` 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Usage 29 | 30 | 1. **Prepare Data**: Save your data as a numpy array with shape `(n_samples, n_timesteps, n_features)` in the `./data/` folder in `.npz` format. The filename without the extension will be used as the dataset name (e.g., `my_data.npz` will be referred to as `my_data`). Alternatively, use one of the existing datasets provided in the `./data/` folder. 31 | 32 | 2. **Configure Pipeline**: 33 | 34 | - Update the dataset name and model type in `./src/vae_pipeline.py`: 35 | ```python 36 | dataset = "my_data" # Your dataset name 37 | model_name = "timeVAE" # Choose between vae_dense, vae_conv, or timeVAE 38 | ``` 39 | - Set hyperparameters in `./src/config/hyperparameters.yaml`. Key hyperparameters include `latent_dim`, `hidden_layer_sizes`, `reconstruction_wt`, and `batch_size`. 40 | 41 | 3. **Run the Script**: 42 | 43 | ```bash 44 | python src/vae_pipeline.py 45 | ``` 46 | 47 | 4. **Outputs**: 48 | - Trained models are saved in `./outputs/models//`. 49 | - Generated synthetic data is saved in `./outputs/gen_data//` in `.npz` format. 50 | - t-SNE plots are saved in `./outputs/tsne//` in `.png` format. 51 | 52 | ## Hyperparameters 53 | 54 | The four key hyperparameters for the VAE models are: 55 | 56 | - `latent_dim`: Number of latent dimensions (default: 8). 57 | - `hidden_layer_sizes`: Number of hidden units or filters (default: [50, 100, 200]). 58 | - `reconstruction_wt`: Weight for the reconstruction loss (default: 3.0). 59 | - `batch_size`: Training batch size (default: 16). 60 | 61 | For `timeVAE`: 62 | 63 | - `trend_poly`: Degree of polynomial trend component (default: 0). 64 | - `custom_seas`: Custom seasonalities as a list of tuples (default: null). 65 | - `use_residual_conn`: Use residual connection (default: true). 66 | 67 | > The default settings for the timeVAE model set it to operate as the base model without interpretable components. 68 | 69 | ## License 70 | 71 | This project is licensed under the MIT License. See the `LICENSE` file for details. 72 | -------------------------------------------------------------------------------- /data/air_subsampled_train_perc_10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/air_subsampled_train_perc_10.npz -------------------------------------------------------------------------------- /data/air_subsampled_train_perc_100.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/air_subsampled_train_perc_100.npz -------------------------------------------------------------------------------- /data/air_subsampled_train_perc_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/air_subsampled_train_perc_2.npz -------------------------------------------------------------------------------- /data/air_subsampled_train_perc_20.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/air_subsampled_train_perc_20.npz -------------------------------------------------------------------------------- /data/air_subsampled_train_perc_5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/air_subsampled_train_perc_5.npz -------------------------------------------------------------------------------- /data/energy_subsampled_train_perc_10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/energy_subsampled_train_perc_10.npz -------------------------------------------------------------------------------- /data/energy_subsampled_train_perc_100.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/energy_subsampled_train_perc_100.npz -------------------------------------------------------------------------------- /data/energy_subsampled_train_perc_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/energy_subsampled_train_perc_2.npz -------------------------------------------------------------------------------- /data/energy_subsampled_train_perc_20.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/energy_subsampled_train_perc_20.npz -------------------------------------------------------------------------------- /data/energy_subsampled_train_perc_5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/energy_subsampled_train_perc_5.npz -------------------------------------------------------------------------------- /data/sine_subsampled_train_perc_10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/sine_subsampled_train_perc_10.npz -------------------------------------------------------------------------------- /data/sine_subsampled_train_perc_100.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/sine_subsampled_train_perc_100.npz -------------------------------------------------------------------------------- /data/sine_subsampled_train_perc_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/sine_subsampled_train_perc_2.npz -------------------------------------------------------------------------------- /data/sine_subsampled_train_perc_20.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/sine_subsampled_train_perc_20.npz -------------------------------------------------------------------------------- /data/sine_subsampled_train_perc_5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/sine_subsampled_train_perc_5.npz -------------------------------------------------------------------------------- /data/stockv_subsampled_train_perc_10.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/stockv_subsampled_train_perc_10.npz -------------------------------------------------------------------------------- /data/stockv_subsampled_train_perc_100.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/stockv_subsampled_train_perc_100.npz -------------------------------------------------------------------------------- /data/stockv_subsampled_train_perc_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/stockv_subsampled_train_perc_2.npz -------------------------------------------------------------------------------- /data/stockv_subsampled_train_perc_20.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/stockv_subsampled_train_perc_20.npz -------------------------------------------------------------------------------- /data/stockv_subsampled_train_perc_5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/data/stockv_subsampled_train_perc_5.npz -------------------------------------------------------------------------------- /outputs/gen_data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /outputs/models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /outputs/tsne/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu124 2 | 3 | torch==2.4.0 4 | pandas==1.3.4 5 | numpy==1.26.4 6 | scikit-learn==1.5.0 7 | matplotlib==3.7.2 8 | pyyaml==6.0.1 9 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/src/__init__.py -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/src/config/__init__.py -------------------------------------------------------------------------------- /src/config/hyperparameters.yaml: -------------------------------------------------------------------------------- 1 | timeVAE: 2 | latent_dim: 8 3 | hidden_layer_sizes: 4 | - 50 5 | - 100 6 | - 200 7 | reconstruction_wt: 3.0 8 | batch_size: 16 9 | use_residual_conn: true 10 | trend_poly: 0 11 | custom_seas: null 12 | 13 | vae_dense: 14 | latent_dim: 8 15 | hidden_layer_sizes: 16 | - 50 17 | - 100 18 | - 200 19 | reconstruction_wt: 3.0 20 | batch_size: 16 21 | 22 | vae_conv: 23 | latent_dim: 8 24 | hidden_layer_sizes: 25 | - 50 26 | - 100 27 | - 200 28 | reconstruction_wt: 3.0 29 | batch_size: 16 30 | -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import yaml 5 | import numpy as np 6 | 7 | SCALER_FNAME = "scaler.pkl" 8 | 9 | 10 | def load_yaml_file(file_path): 11 | with open(file_path, "r", encoding="utf-8") as file: 12 | loaded = yaml.safe_load(file) 13 | return loaded 14 | 15 | 16 | def load_data(data_dir: str, dataset: str) -> np.ndarray: 17 | """ 18 | Load data from a dataset located in a directory. 19 | 20 | Args: 21 | data_dir (str): The directory where the dataset is located. 22 | dataset (str): The name of the dataset file (without the .npz extension). 23 | 24 | Returns: 25 | np.ndarray: The loaded dataset. 26 | """ 27 | return get_npz_data(os.path.join(data_dir, f"{dataset}.npz")) 28 | 29 | 30 | def save_data(data: np.ndarray, output_file: str) -> None: 31 | """ 32 | Save data to a .npz file. 33 | 34 | Args: 35 | data (np.ndarray): The data to save. 36 | output_file (str): The path to the .npz file to save the data to. 37 | 38 | Returns: 39 | None 40 | """ 41 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 42 | np.savez_compressed(output_file, data=data) 43 | 44 | 45 | def get_npz_data(input_file: str) -> np.ndarray: 46 | """ 47 | Load data from a .npz file. 48 | 49 | Args: 50 | input_file (str): The path to the .npz file. 51 | 52 | Returns: 53 | np.ndarray: The data array extracted from the .npz file. 54 | """ 55 | loaded = np.load(input_file) 56 | return loaded["data"] 57 | 58 | 59 | def split_data( 60 | data: np.ndarray, valid_perc: float, shuffle: bool = True, seed: int = 123 61 | ) -> tuple[np.ndarray, np.ndarray]: 62 | """ 63 | Split the data into training and validation sets. 64 | 65 | Args: 66 | data (np.ndarray): The dataset to split. 67 | valid_perc (float): The percentage of data to use for validation. 68 | shuffle (bool, optional): Whether to shuffle the data before splitting. 69 | Defaults to True. 70 | seed (int, optional): The random seed to use for shuffling the data. 71 | Defaults to 123. 72 | 73 | Returns: 74 | tuple[np.ndarray, np.ndarray]: A tuple containing the training data and 75 | validation data arrays. 76 | """ 77 | N = data.shape[0] 78 | N_train = int(N * (1 - valid_perc)) 79 | 80 | if shuffle: 81 | np.random.seed(seed) 82 | data = data.copy() 83 | np.random.shuffle(data) 84 | 85 | train_data = data[:N_train] 86 | valid_data = data[N_train:] 87 | return train_data, valid_data 88 | 89 | 90 | class MinMaxScaler: 91 | """Min Max normalizer. 92 | Args: 93 | - data: original data 94 | 95 | Returns: 96 | - norm_data: normalized data 97 | """ 98 | 99 | def fit_transform(self, data): 100 | self.fit(data) 101 | scaled_data = self.transform(data) 102 | return scaled_data 103 | 104 | def fit(self, data): 105 | self.mini = np.min(data, 0) 106 | self.range = np.max(data, 0) - self.mini 107 | return self 108 | 109 | def transform(self, data): 110 | numerator = data - self.mini 111 | scaled_data = numerator / (self.range + 1e-7) 112 | return scaled_data 113 | 114 | def inverse_transform(self, data): 115 | data *= self.range 116 | data += self.mini 117 | return data 118 | 119 | 120 | def inverse_transform_data(data, scaler): 121 | return scaler.inverse_transform(data.copy()) 122 | 123 | 124 | def scale_data(train_data, valid_data): 125 | scaler = MinMaxScaler() 126 | scaled_train_data = scaler.fit_transform(train_data) 127 | scaled_valid_data = scaler.transform(valid_data) 128 | return scaled_train_data, scaled_valid_data, scaler 129 | 130 | 131 | def save_scaler(scaler: MinMaxScaler, dir_path: str) -> None: 132 | """ 133 | Save a MinMaxScaler to a file. 134 | 135 | Args: 136 | scaler (MinMaxScaler): The scaler to save. 137 | dir_path (str): The path to the directory where the scaler will be saved. 138 | 139 | Returns: 140 | None 141 | """ 142 | os.makedirs(dir_path, exist_ok=True) 143 | scaler_fpath = os.path.join(dir_path, SCALER_FNAME) 144 | with open(scaler_fpath, "wb") as file: 145 | pickle.dump(scaler, file) 146 | 147 | 148 | def load_scaler(dir_path: str) -> MinMaxScaler: 149 | """ 150 | Load a MinMaxScaler from a file. 151 | 152 | Args: 153 | dir_path (str): The path to the file from which the scaler will be loaded. 154 | 155 | Returns: 156 | MinMaxScaler: The loaded scaler. 157 | """ 158 | scaler_fpath = os.path.join(dir_path, SCALER_FNAME) 159 | with open(scaler_fpath, "rb") as file: 160 | scaler = pickle.load(file) 161 | return scaler 162 | -------------------------------------------------------------------------------- /src/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | 5 | DATASETS_DIR = os.path.join(ROOT_DIR, "data") 6 | 7 | OUTPUTS_DIR = os.path.join(ROOT_DIR, "outputs") 8 | 9 | GEN_DATA_DIR = os.path.join(OUTPUTS_DIR, "gen_data") 10 | MODELS_DIR = os.path.join(OUTPUTS_DIR, "models") 11 | TSNE_DIR = os.path.join(OUTPUTS_DIR, "tsne") 12 | 13 | SRC_DIR = os.path.join(ROOT_DIR, "src") 14 | 15 | CONFIG_DIR = os.path.join(SRC_DIR, "config") 16 | CFG_FILE_PATH = os.path.join(CONFIG_DIR, "config.yaml") 17 | HYPERPARAMETERS_FILE_PATH = os.path.join(CONFIG_DIR, "hyperparameters.yaml") 18 | 19 | 20 | # MODEL ARTIFACTS 21 | SCALER_FILE_PATH = os.path.join(MODELS_DIR, "scaler.pkl") 22 | -------------------------------------------------------------------------------- /src/vae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangyz1999/timeVAE-pytorch/7f8193051924d89fa16ecb6e8cdfff7622e950da/src/vae/__init__.py -------------------------------------------------------------------------------- /src/vae/timevae.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import joblib 7 | 8 | from vae.vae_base import BaseVariationalAutoencoder, Sampling 9 | 10 | 11 | class TrendLayer(nn.Module): 12 | def __init__(self, seq_len, feat_dim, latent_dim, trend_poly): 13 | super(TrendLayer, self).__init__() 14 | self.seq_len = seq_len 15 | self.feat_dim = feat_dim 16 | self.latent_dim = latent_dim 17 | self.trend_poly = trend_poly 18 | self.trend_dense1 = nn.Linear(self.latent_dim, self.feat_dim * self.trend_poly) 19 | self.trend_dense2 = nn.Linear(self.feat_dim * self.trend_poly, self.feat_dim * self.trend_poly) 20 | 21 | def forward(self, z): 22 | trend_params = F.relu(self.trend_dense1(z)) 23 | trend_params = self.trend_dense2(trend_params) 24 | trend_params = trend_params.view(-1, self.feat_dim, self.trend_poly) 25 | 26 | lin_space = torch.arange(0, float(self.seq_len), 1, device=z.device) / self.seq_len 27 | poly_space = torch.stack([lin_space ** float(p + 1) for p in range(self.trend_poly)], dim=0) 28 | 29 | trend_vals = torch.matmul(trend_params, poly_space) 30 | trend_vals = trend_vals.permute(0, 2, 1) 31 | return trend_vals 32 | 33 | 34 | class SeasonalLayer(nn.Module): 35 | def __init__(self, seq_len, feat_dim, latent_dim, custom_seas): 36 | super(SeasonalLayer, self).__init__() 37 | self.seq_len = seq_len 38 | self.feat_dim = feat_dim 39 | self.custom_seas = custom_seas 40 | 41 | self.dense_layers = nn.ModuleList([ 42 | nn.Linear(latent_dim, feat_dim * num_seasons) 43 | for num_seasons, len_per_season in custom_seas 44 | ]) 45 | 46 | 47 | def _get_season_indexes_over_seq(self, num_seasons, len_per_season): 48 | season_indexes = torch.arange(num_seasons).unsqueeze(1) + torch.zeros( 49 | (num_seasons, len_per_season), dtype=torch.int32 50 | ) 51 | season_indexes = season_indexes.view(-1) 52 | season_indexes = season_indexes.repeat(self.seq_len // len_per_season + 1)[: self.seq_len] 53 | return season_indexes 54 | 55 | def forward(self, z): 56 | N = z.shape[0] 57 | ones_tensor = torch.ones((N, self.feat_dim, self.seq_len), dtype=torch.int32, device=z.device) 58 | 59 | all_seas_vals = [] 60 | for i, (num_seasons, len_per_season) in enumerate(self.custom_seas): 61 | season_params = self.dense_layers[i](z) 62 | season_params = season_params.view(-1, self.feat_dim, num_seasons) 63 | 64 | season_indexes_over_time = self._get_season_indexes_over_seq( 65 | num_seasons, len_per_season 66 | ).to(z.device) 67 | 68 | dim2_idxes = ones_tensor * season_indexes_over_time.view(1, 1, -1) 69 | season_vals = torch.gather(season_params, 2, dim2_idxes) 70 | 71 | all_seas_vals.append(season_vals) 72 | 73 | all_seas_vals = torch.stack(all_seas_vals, dim=-1) 74 | all_seas_vals = torch.sum(all_seas_vals, dim=-1) 75 | all_seas_vals = all_seas_vals.permute(0, 2, 1) 76 | 77 | return all_seas_vals 78 | 79 | def compute_output_shape(self, input_shape): 80 | return (input_shape[0], self.seq_len, self.feat_dim) 81 | 82 | 83 | class LevelModel(nn.Module): 84 | def __init__(self, latent_dim, feat_dim, seq_len): 85 | super(LevelModel, self).__init__() 86 | self.latent_dim = latent_dim 87 | self.feat_dim = feat_dim 88 | self.seq_len = seq_len 89 | self.level_dense1 = nn.Linear(self.latent_dim, self.feat_dim) 90 | self.level_dense2 = nn.Linear(self.feat_dim, self.feat_dim) 91 | self.relu = nn.ReLU() 92 | 93 | def forward(self, z): 94 | level_params = self.relu(self.level_dense1(z)) 95 | level_params = self.level_dense2(level_params) 96 | level_params = level_params.view(-1, 1, self.feat_dim) 97 | 98 | ones_tensor = torch.ones((1, self.seq_len, 1), dtype=torch.float32, device=z.device) 99 | level_vals = level_params * ones_tensor 100 | return level_vals 101 | 102 | 103 | class ResidualConnection(nn.Module): 104 | def __init__(self, seq_len, feat_dim, hidden_layer_sizes, latent_dim, encoder_last_dense_dim): 105 | super(ResidualConnection, self).__init__() 106 | self.seq_len = seq_len 107 | self.feat_dim = feat_dim 108 | self.hidden_layer_sizes = hidden_layer_sizes 109 | 110 | self.dense = nn.Linear(latent_dim, encoder_last_dense_dim) 111 | self.deconv_layers = nn.ModuleList() 112 | in_channels = hidden_layer_sizes[-1] 113 | 114 | for i, num_filters in enumerate(reversed(hidden_layer_sizes[:-1])): 115 | self.deconv_layers.append( 116 | nn.ConvTranspose1d(in_channels, num_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 117 | ) 118 | in_channels = num_filters 119 | 120 | self.deconv_layers.append( 121 | nn.ConvTranspose1d(in_channels, feat_dim, kernel_size=3, stride=2, padding=1, output_padding=1) 122 | ) 123 | 124 | L_in = encoder_last_dense_dim // hidden_layer_sizes[-1] 125 | for i in range(len(hidden_layer_sizes)): 126 | L_in = (L_in - 1) * 2 - 2 * 1 + 3 + 1 127 | L_final = L_in 128 | 129 | self.final_dense = nn.Linear(feat_dim * L_final, seq_len * feat_dim) 130 | 131 | def forward(self, z): 132 | batch_size = z.size(0) 133 | x = F.relu(self.dense(z)) 134 | x = x.view(batch_size, -1, self.hidden_layer_sizes[-1]) 135 | x = x.transpose(1, 2) 136 | 137 | for deconv in self.deconv_layers[:-1]: 138 | x = F.relu(deconv(x)) 139 | x = F.relu(self.deconv_layers[-1](x)) 140 | 141 | x = x.flatten(1) 142 | x = self.final_dense(x) 143 | residuals = x.view(-1, self.seq_len, self.feat_dim) 144 | return residuals 145 | 146 | 147 | class TimeVAEEncoder(nn.Module): 148 | def __init__(self, seq_len, feat_dim, hidden_layer_sizes, latent_dim): 149 | super(TimeVAEEncoder, self).__init__() 150 | self.seq_len = seq_len 151 | self.feat_dim = feat_dim 152 | self.latent_dim = latent_dim 153 | self.hidden_layer_sizes = hidden_layer_sizes 154 | self.layers = [] 155 | self.layers.append(nn.Conv1d(feat_dim, hidden_layer_sizes[0], kernel_size=3, stride=2, padding=1)) 156 | self.layers.append(nn.ReLU()) 157 | 158 | for i, num_filters in enumerate(hidden_layer_sizes[1:]): 159 | self.layers.append(nn.Conv1d(hidden_layer_sizes[i], num_filters, kernel_size=3, stride=2, padding=1)) 160 | self.layers.append(nn.ReLU()) 161 | 162 | self.layers.append(nn.Flatten()) 163 | 164 | self.encoder_last_dense_dim = self._get_last_dense_dim(seq_len, feat_dim, hidden_layer_sizes) 165 | 166 | self.encoder = nn.Sequential(*self.layers) 167 | self.z_mean = nn.Linear(self.encoder_last_dense_dim, latent_dim) 168 | self.z_log_var = nn.Linear(self.encoder_last_dense_dim, latent_dim) 169 | 170 | def forward(self, x): 171 | x = x.transpose(1, 2) 172 | x = self.encoder(x) 173 | z_mean = self.z_mean(x) 174 | z_log_var = self.z_log_var(x) 175 | z = Sampling()([z_mean, z_log_var]) 176 | return z_mean, z_log_var, z 177 | 178 | def _get_last_dense_dim(self, seq_len, feat_dim, hidden_layer_sizes): 179 | with torch.no_grad(): 180 | x = torch.randn(1, feat_dim, seq_len) 181 | for conv in self.layers: 182 | x = conv(x) 183 | return x.numel() 184 | 185 | class TimeVAEDecoder(nn.Module): 186 | def __init__(self, seq_len, feat_dim, hidden_layer_sizes, latent_dim, trend_poly=0, custom_seas=None, use_residual_conn=True, encoder_last_dense_dim=None): 187 | super(TimeVAEDecoder, self).__init__() 188 | self.seq_len = seq_len 189 | self.feat_dim = feat_dim 190 | self.hidden_layer_sizes = hidden_layer_sizes 191 | self.latent_dim = latent_dim 192 | self.trend_poly = trend_poly 193 | self.custom_seas = custom_seas 194 | self.use_residual_conn = use_residual_conn 195 | self.encoder_last_dense_dim = encoder_last_dense_dim 196 | self.level_model = LevelModel(self.latent_dim, self.feat_dim, self.seq_len) 197 | 198 | if use_residual_conn: 199 | self.residual_conn = ResidualConnection(seq_len, feat_dim, hidden_layer_sizes, latent_dim, encoder_last_dense_dim) 200 | 201 | def forward(self, z): 202 | outputs = self.level_model(z) 203 | if self.trend_poly is not None and self.trend_poly > 0: 204 | trend_vals = TrendLayer(self.seq_len, self.feat_dim, self.latent_dim, self.trend_poly)(z) 205 | outputs += trend_vals 206 | 207 | # custom seasons 208 | if self.custom_seas is not None and len(self.custom_seas) > 0: 209 | cust_seas_vals = SeasonalLayer(self.seq_len, self.feat_dim, self.latent_dim, self.custom_seas)(z) 210 | outputs += cust_seas_vals 211 | 212 | if self.use_residual_conn: 213 | residuals = self.residual_conn(z) 214 | outputs += residuals 215 | 216 | return outputs 217 | 218 | 219 | class TimeVAE(BaseVariationalAutoencoder): 220 | model_name = "TimeVAE" 221 | 222 | def __init__( 223 | self, 224 | hidden_layer_sizes=None, 225 | trend_poly=0, 226 | custom_seas=None, 227 | use_residual_conn=True, 228 | **kwargs, 229 | ): 230 | super(TimeVAE, self).__init__(**kwargs) 231 | 232 | if hidden_layer_sizes is None: 233 | hidden_layer_sizes = [50, 100, 200] 234 | 235 | self.hidden_layer_sizes = hidden_layer_sizes 236 | self.trend_poly = trend_poly 237 | self.custom_seas = custom_seas 238 | self.use_residual_conn = use_residual_conn 239 | 240 | self.encoder = self._get_encoder() 241 | self.decoder = self._get_decoder() 242 | 243 | for layer in self.modules(): 244 | if isinstance(layer, nn.Linear): 245 | nn.init.xavier_uniform_(layer.weight) 246 | if layer.bias is not None: 247 | nn.init.zeros_(layer.bias) 248 | 249 | def _get_encoder(self): 250 | return TimeVAEEncoder(self.seq_len, self.feat_dim, self.hidden_layer_sizes, self.latent_dim) 251 | 252 | def _get_decoder(self): 253 | return TimeVAEDecoder(self.seq_len, self.feat_dim, self.hidden_layer_sizes, self.latent_dim, self.trend_poly, self.custom_seas, self.use_residual_conn, self.encoder.encoder_last_dense_dim) 254 | 255 | def save(self, model_dir: str): 256 | os.makedirs(model_dir, exist_ok=True) 257 | torch.save(self.state_dict(), os.path.join(model_dir, f"{self.model_name}_weights.pth")) 258 | 259 | if self.custom_seas is not None: 260 | self.custom_seas = [(int(num_seasons), int(len_per_season)) for num_seasons, len_per_season in self.custom_seas] 261 | 262 | dict_params = { 263 | "seq_len": self.seq_len, 264 | "feat_dim": self.feat_dim, 265 | "latent_dim": self.latent_dim, 266 | "reconstruction_wt": self.reconstruction_wt, 267 | "hidden_layer_sizes": list(self.hidden_layer_sizes), 268 | "trend_poly": self.trend_poly, 269 | "custom_seas": self.custom_seas, 270 | "use_residual_conn": self.use_residual_conn, 271 | } 272 | params_file = os.path.join(model_dir, f"{self.model_name}_parameters.pkl") 273 | joblib.dump(dict_params, params_file) 274 | 275 | @classmethod 276 | def load(cls, model_dir: str) -> "TimeVAE": 277 | params_file = os.path.join(model_dir, f"{cls.model_name}_parameters.pkl") 278 | dict_params = joblib.load(params_file) 279 | vae_model = TimeVAE(**dict_params) 280 | vae_model.load_state_dict(torch.load(os.path.join(model_dir, f"{cls.model_name}_weights.pth"))) 281 | return vae_model -------------------------------------------------------------------------------- /src/vae/vae_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.utils.data import DataLoader, TensorDataset 8 | import joblib 9 | 10 | 11 | class Sampling(nn.Module): 12 | def forward(self, inputs): 13 | z_mean, z_log_var = inputs 14 | batch = z_mean.size(0) 15 | dim = z_mean.size(1) 16 | epsilon = torch.randn(batch, dim).to(z_mean.device) 17 | return z_mean + torch.exp(0.5 * z_log_var) * epsilon 18 | 19 | class BaseVariationalAutoencoder(nn.Module, ABC): 20 | model_name = None 21 | 22 | def __init__( 23 | self, 24 | seq_len, 25 | feat_dim, 26 | latent_dim, 27 | reconstruction_wt=3.0, 28 | batch_size=16, 29 | **kwargs 30 | ): 31 | super(BaseVariationalAutoencoder, self).__init__() 32 | self.seq_len = seq_len 33 | self.feat_dim = feat_dim 34 | self.latent_dim = latent_dim 35 | self.reconstruction_wt = reconstruction_wt 36 | self.batch_size = batch_size 37 | self.encoder = None 38 | self.decoder = None 39 | self.sampling = Sampling() 40 | 41 | def fit_on_data(self, train_data, max_epochs=1000, verbose=0): 42 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 43 | self.to(device) 44 | 45 | train_tensor = torch.FloatTensor(train_data).to(device) 46 | train_dataset = TensorDataset(train_tensor) 47 | train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) 48 | 49 | optimizer = optim.Adam(self.parameters()) 50 | 51 | for epoch in range(max_epochs): 52 | self.train() 53 | total_loss = 0 54 | reconstruction_loss = 0 55 | kl_loss = 0 56 | 57 | for batch in train_loader: 58 | X = batch[0] 59 | optimizer.zero_grad() 60 | 61 | z_mean, z_log_var, z = self.encoder(X) 62 | reconstruction = self.decoder(z) 63 | 64 | loss, recon_loss, kl = self.loss_function(X, reconstruction, z_mean, z_log_var) 65 | 66 | # Normalize the loss by the batch size 67 | loss = loss / X.size(0) 68 | recon_loss = recon_loss / X.size(0) 69 | kl = kl / X.size(0) 70 | 71 | loss.backward() 72 | optimizer.step() 73 | 74 | total_loss += loss.item() 75 | reconstruction_loss += recon_loss.item() 76 | kl_loss += kl.item() 77 | 78 | if verbose: 79 | print(f"Epoch {epoch + 1}/{max_epochs} | Total loss: {total_loss / len(train_loader):.4f} | " 80 | f"Recon loss: {reconstruction_loss / len(train_loader):.4f} | " 81 | f"KL loss: {kl_loss / len(train_loader):.4f}") 82 | 83 | def forward(self, X): 84 | z_mean, z_log_var, z = self.encoder(X) 85 | x_decoded = self.decoder(z_mean) 86 | return x_decoded 87 | 88 | def predict(self, X): 89 | self.eval() 90 | with torch.no_grad(): 91 | X = torch.FloatTensor(X).to(next(self.parameters()).device) 92 | z_mean, z_log_var, z = self.encoder(X) 93 | x_decoded = self.decoder(z_mean) 94 | return x_decoded.cpu().detach().numpy() 95 | 96 | def get_num_trainable_variables(self): 97 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 98 | 99 | def get_prior_samples(self, num_samples): 100 | device = next(self.parameters()).device 101 | Z = torch.randn(num_samples, self.latent_dim).to(device) 102 | samples = self.decoder(Z) 103 | return samples.cpu().detach().numpy() 104 | 105 | def get_prior_samples_given_Z(self, Z): 106 | Z = torch.FloatTensor(Z).to(next(self.parameters()).device) 107 | samples = self.decoder(Z) 108 | return samples.cpu().detach().numpy() 109 | 110 | @abstractmethod 111 | def _get_encoder(self, **kwargs): 112 | raise NotImplementedError 113 | 114 | @abstractmethod 115 | def _get_decoder(self, **kwargs): 116 | raise NotImplementedError 117 | 118 | def _get_reconstruction_loss(self, X, X_recons): 119 | def get_reconst_loss_by_axis(X, X_recons, dim): 120 | x_r = torch.mean(X, dim=dim) 121 | x_c_r = torch.mean(X_recons, dim=dim) 122 | err = torch.pow(x_r - x_c_r, 2) 123 | loss = torch.sum(err) 124 | return loss 125 | 126 | err = torch.pow(X - X_recons, 2) 127 | reconst_loss = torch.sum(err) 128 | 129 | reconst_loss += get_reconst_loss_by_axis(X, X_recons, dim=2) # by time axis 130 | # reconst_loss += get_reconst_loss_by_axis(X, X_recons, dim=1) # by feature axis 131 | 132 | return reconst_loss 133 | 134 | def loss_function(self, X, X_recons, z_mean, z_log_var): 135 | reconstruction_loss = self._get_reconstruction_loss(X, X_recons) 136 | kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) 137 | total_loss = self.reconstruction_wt * reconstruction_loss + kl_loss 138 | return total_loss, reconstruction_loss, kl_loss 139 | 140 | def save_weights(self, model_dir): 141 | if self.model_name is None: 142 | raise ValueError("Model name not set.") 143 | os.makedirs(model_dir, exist_ok=True) 144 | torch.save(self.encoder.state_dict(), os.path.join(model_dir, f"{self.model_name}_encoder_wts.pth")) 145 | torch.save(self.decoder.state_dict(), os.path.join(model_dir, f"{self.model_name}_decoder_wts.pth")) 146 | 147 | def load_weights(self, model_dir): 148 | self.encoder.load_state_dict(torch.load(os.path.join(model_dir, f"{self.model_name}_encoder_wts.pth"))) 149 | self.decoder.load_state_dict(torch.load(os.path.join(model_dir, f"{self.model_name}_decoder_wts.pth"))) 150 | 151 | def save(self, model_dir): 152 | os.makedirs(model_dir, exist_ok=True) 153 | self.save_weights(model_dir) 154 | dict_params = { 155 | "seq_len": self.seq_len, 156 | "feat_dim": self.feat_dim, 157 | "latent_dim": self.latent_dim, 158 | "reconstruction_wt": self.reconstruction_wt, 159 | "hidden_layer_sizes": list(self.hidden_layer_sizes) if hasattr(self, 'hidden_layer_sizes') else None, 160 | } 161 | params_file = os.path.join(model_dir, f"{self.model_name}_parameters.pkl") 162 | joblib.dump(dict_params, params_file) 163 | 164 | if __name__ == "__main__": 165 | pass -------------------------------------------------------------------------------- /src/vae/vae_conv_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import joblib 7 | 8 | from vae.vae_base import BaseVariationalAutoencoder, Sampling 9 | 10 | 11 | class ConvEncoder(nn.Module): 12 | def __init__(self, seq_len, feat_dim, hidden_layer_sizes, latent_dim): 13 | super(ConvEncoder, self).__init__() 14 | self.conv_layers = nn.ModuleList() 15 | in_channels = feat_dim 16 | for i, num_filters in enumerate(hidden_layer_sizes): 17 | self.conv_layers.append( 18 | nn.Conv1d(in_channels, num_filters, kernel_size=3, stride=2, padding=1) 19 | ) 20 | in_channels = num_filters 21 | 22 | self.encoder_last_dense_dim = self._get_last_dense_dim(seq_len, feat_dim, hidden_layer_sizes) 23 | self.z_mean = nn.Linear(self.encoder_last_dense_dim, latent_dim) 24 | self.z_log_var = nn.Linear(self.encoder_last_dense_dim, latent_dim) 25 | self.sampling = Sampling() 26 | 27 | def _get_last_dense_dim(self, seq_len, feat_dim, hidden_layer_sizes): 28 | with torch.no_grad(): 29 | x = torch.randn(1, feat_dim, seq_len) 30 | for conv in self.conv_layers: 31 | x = conv(x) 32 | return x.numel() 33 | 34 | def forward(self, x): 35 | x = x.transpose(1, 2) 36 | for conv in self.conv_layers: 37 | x = F.relu(conv(x)) 38 | x = x.flatten(1) 39 | z_mean = self.z_mean(x) 40 | z_log_var = self.z_log_var(x) 41 | z = self.sampling((z_mean, z_log_var)) 42 | return z_mean, z_log_var, z 43 | 44 | 45 | class ConvDecoder(nn.Module): 46 | def __init__(self, seq_len, feat_dim, hidden_layer_sizes, latent_dim, encoder_last_dense_dim): 47 | super(ConvDecoder, self).__init__() 48 | 49 | self.seq_len = seq_len 50 | self.feat_dim = feat_dim 51 | self.hidden_layer_sizes = hidden_layer_sizes 52 | 53 | self.dense = nn.Linear(latent_dim, encoder_last_dense_dim) 54 | self.deconv_layers = nn.ModuleList() 55 | in_channels = hidden_layer_sizes[-1] 56 | 57 | for i, num_filters in enumerate(reversed(hidden_layer_sizes[:-1])): 58 | self.deconv_layers.append( 59 | nn.ConvTranspose1d(in_channels, num_filters, kernel_size=3, stride=2, padding=1, output_padding=1) 60 | ) 61 | in_channels = num_filters 62 | 63 | self.deconv_layers.append( 64 | nn.ConvTranspose1d(in_channels, feat_dim, kernel_size=3, stride=2, padding=1, output_padding=1) 65 | ) 66 | 67 | L_in = encoder_last_dense_dim // hidden_layer_sizes[-1] 68 | for i in range(len(hidden_layer_sizes)): 69 | L_in = (L_in - 1) * 2 - 2 * 1 + 3 + 1 70 | L_final = L_in 71 | 72 | self.final_dense = nn.Linear(feat_dim * L_final, seq_len * feat_dim) 73 | 74 | def forward(self, z): 75 | batch_size = z.size(0) 76 | x = F.relu(self.dense(z)) 77 | x = x.view(batch_size, -1, self.hidden_layer_sizes[-1]) 78 | x = x.transpose(1, 2) 79 | 80 | for deconv in self.deconv_layers[:-1]: 81 | x = F.relu(deconv(x)) 82 | x = F.relu(self.deconv_layers[-1](x)) 83 | 84 | x = x.flatten(1) 85 | x = self.final_dense(x) 86 | x = x.view(-1, self.seq_len, self.feat_dim) 87 | return x 88 | 89 | class VariationalAutoencoderConv(BaseVariationalAutoencoder): 90 | model_name = "VAE_Conv" 91 | 92 | def __init__(self, hidden_layer_sizes=None, **kwargs): 93 | super(VariationalAutoencoderConv, self).__init__(**kwargs) 94 | if hidden_layer_sizes is None: 95 | hidden_layer_sizes = [50, 100, 200] 96 | 97 | self.hidden_layer_sizes = hidden_layer_sizes 98 | self.encoder = self._get_encoder() 99 | self.decoder = self._get_decoder() 100 | 101 | for layer in self.modules(): 102 | if isinstance(layer, nn.Linear): 103 | nn.init.xavier_uniform_(layer.weight) 104 | if layer.bias is not None: 105 | nn.init.zeros_(layer.bias) 106 | 107 | def _get_encoder(self): 108 | return ConvEncoder(self.seq_len, self.feat_dim, self.hidden_layer_sizes, self.latent_dim) 109 | 110 | def _get_decoder(self): 111 | return ConvDecoder(self.seq_len, self.feat_dim, self.hidden_layer_sizes, self.latent_dim, self.encoder.encoder_last_dense_dim) 112 | 113 | @classmethod 114 | def load(cls, model_dir): 115 | params_file = os.path.join(model_dir, f"{cls.model_name}_parameters.pkl") 116 | dict_params = joblib.load(params_file) 117 | vae_model = VariationalAutoencoderConv(**dict_params) 118 | vae_model.load_weights(model_dir) 119 | return vae_model -------------------------------------------------------------------------------- /src/vae/vae_dense_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import joblib 6 | 7 | from vae.vae_base import BaseVariationalAutoencoder, Sampling 8 | 9 | class DenseEncoder(nn.Module): 10 | def __init__(self, seq_len, feat_dim, hidden_layer_sizes, latent_dim): 11 | super(DenseEncoder, self).__init__() 12 | input_size = seq_len * feat_dim 13 | 14 | encoder_layers = [] 15 | encoder_layers.append(nn.Flatten()) 16 | 17 | for M_out in hidden_layer_sizes: 18 | encoder_layers.append(nn.Linear(input_size, M_out)) 19 | encoder_layers.append(nn.ReLU()) 20 | input_size = M_out 21 | 22 | self.encoder = nn.Sequential(*encoder_layers) 23 | self.z_mean = nn.Linear(input_size, latent_dim) 24 | self.z_log_var = nn.Linear(input_size, latent_dim) 25 | self.sampling = Sampling() 26 | 27 | def forward(self, x): 28 | x = self.encoder(x) 29 | z_mean = self.z_mean(x) 30 | z_log_var = self.z_log_var(x) 31 | z = self.sampling((z_mean, z_log_var)) 32 | return z_mean, z_log_var, z 33 | 34 | class DenseDecoder(nn.Module): 35 | def __init__(self, seq_len, feat_dim, hidden_layer_sizes, latent_dim): 36 | super(DenseDecoder, self).__init__() 37 | decoder_layers = [] 38 | input_size = latent_dim 39 | self.seq_len = seq_len 40 | self.feat_dim = feat_dim 41 | 42 | for M_out in hidden_layer_sizes: 43 | decoder_layers.append(nn.Linear(input_size, M_out)) 44 | decoder_layers.append(nn.ReLU()) 45 | input_size = M_out 46 | 47 | decoder_layers.append(nn.Linear(input_size, seq_len * feat_dim)) 48 | self.decoder = nn.Sequential(*decoder_layers) 49 | 50 | def forward(self, z): 51 | decoder_output = self.decoder(z) 52 | reshaped_output = decoder_output.view(-1, self.seq_len, self.feat_dim) 53 | return reshaped_output 54 | 55 | class VariationalAutoencoderDense(BaseVariationalAutoencoder): 56 | model_name = "VAE_Dense" 57 | 58 | def __init__(self, hidden_layer_sizes, **kwargs): 59 | super(VariationalAutoencoderDense, self).__init__(**kwargs) 60 | 61 | if hidden_layer_sizes is None: 62 | hidden_layer_sizes = [50, 100, 200] 63 | 64 | self.hidden_layer_sizes = hidden_layer_sizes 65 | 66 | self.encoder = self._get_encoder() 67 | self.decoder = self._get_decoder() 68 | 69 | for layer in self.modules(): 70 | if isinstance(layer, nn.Linear): 71 | nn.init.xavier_uniform_(layer.weight) 72 | if layer.bias is not None: 73 | nn.init.zeros_(layer.bias) 74 | 75 | def _get_encoder(self): 76 | return DenseEncoder(self.seq_len, self.feat_dim, self.hidden_layer_sizes, self.latent_dim) 77 | 78 | def _get_decoder(self): 79 | return DenseDecoder(self.seq_len, self.feat_dim, list(reversed(self.hidden_layer_sizes)), self.latent_dim) 80 | 81 | @classmethod 82 | def load(cls, model_dir: str) -> "VariationalAutoencoderDense": 83 | params_file = os.path.join(model_dir, f"{cls.model_name}_parameters.pkl") 84 | dict_params = joblib.load(params_file) 85 | vae_model = VariationalAutoencoderDense(**dict_params) 86 | vae_model.load_state_dict(torch.load(os.path.join(model_dir, f"{cls.model_name}_weights.pth"))) 87 | return vae_model 88 | 89 | def save_weights(self, model_dir): 90 | torch.save(self.state_dict(), os.path.join(model_dir, f"{self.model_name}_weights.pth")) 91 | -------------------------------------------------------------------------------- /src/vae/vae_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Union, List, Optional 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from vae.vae_dense_model import VariationalAutoencoderDense as VAE_Dense 8 | from vae.vae_conv_model import VariationalAutoencoderConv as VAE_Conv 9 | from vae.timevae import TimeVAE 10 | 11 | 12 | def set_seeds(seed: int = 111) -> None: 13 | """ 14 | Set seeds for reproducibility. 15 | 16 | Args: 17 | seed (int): The seed value to set. 18 | """ 19 | # Set the seed for PyTorch 20 | torch.manual_seed(seed) 21 | 22 | # Set the seed for NumPy 23 | np.random.seed(seed) 24 | 25 | # Set the seed for Python built-in random module 26 | random.seed(seed) 27 | 28 | 29 | def instantiate_vae_model( 30 | vae_type: str, sequence_length: int, feature_dim: int, batch_size: int, **kwargs 31 | ) -> Union[VAE_Dense, VAE_Conv, TimeVAE]: 32 | """ 33 | Instantiate a Variational Autoencoder (VAE) model based on the specified type. 34 | 35 | Args: 36 | vae_type (str): The type of VAE model to instantiate. 37 | One of ('vae_dense', 'vae_conv', 'timeVAE'). 38 | sequence_length (int): The sequence length. 39 | feature_dim (int): The feature dimension. 40 | batch_size (int): Batch size for training. 41 | 42 | Returns: 43 | Union[VAE_Dense, VAE_Conv, TimeVAE]: The instantiated VAE model. 44 | 45 | Raises: 46 | ValueError: If an unrecognized VAE type is provided. 47 | """ 48 | set_seeds(seed=123) 49 | 50 | if vae_type == "vae_dense": 51 | vae = VAE_Dense( 52 | seq_len=sequence_length, 53 | feat_dim=feature_dim, 54 | batch_size=batch_size, 55 | **kwargs, 56 | ) 57 | elif vae_type == "vae_conv": 58 | vae = VAE_Conv( 59 | seq_len=sequence_length, 60 | feat_dim=feature_dim, 61 | batch_size=batch_size, 62 | **kwargs, 63 | ) 64 | elif vae_type == "timeVAE": 65 | vae = TimeVAE( 66 | seq_len=sequence_length, 67 | feat_dim=feature_dim, 68 | batch_size=batch_size, 69 | **kwargs, 70 | ) 71 | else: 72 | raise ValueError( 73 | f"Unrecognized model type [{vae_type}]. " 74 | "Please choose from vae_dense, vae_conv, timeVAE." 75 | ) 76 | 77 | return vae 78 | 79 | 80 | def train_vae(vae, train_data, max_epochs, verbose=0): 81 | """ 82 | Train a VAE model. 83 | 84 | Args: 85 | vae (Union[VAE_Dense, VAE_Conv, TimeVAE]): The VAE model to train. 86 | train_data (np.ndarray): The training data which must be of shape 87 | [num_samples, window_len, feature_dim]. 88 | max_epochs (int, optional): The maximum number of epochs to train 89 | the model. 90 | Defaults to 100. 91 | verbose (int, optional): Verbose arg for keras model.fit() 92 | """ 93 | vae.fit_on_data(train_data, max_epochs, verbose) 94 | 95 | 96 | def save_vae_model(vae, dir_path: str) -> None: 97 | """ 98 | Save the weights of a VAE model. 99 | 100 | Args: 101 | vae (Union[VAE_Dense, VAE_Conv, TimeVAE]): The VAE model to save. 102 | dir_path (str): The directory to save the model weights. 103 | """ 104 | vae.save(dir_path) 105 | 106 | 107 | def load_vae_model(vae_type: str, dir_path: str) -> Union[VAE_Dense, VAE_Conv, TimeVAE]: 108 | """ 109 | Load a VAE model from the specified directory. 110 | 111 | Args: 112 | vae_type (str): The type of VAE model to load. 113 | One of ('vae_dense', 'vae_conv', 'timeVAE'). 114 | dir_path (str): The directory containing the model weights. 115 | 116 | Returns: 117 | Union[VAE_Dense, VAE_Conv, TimeVAE]: The loaded VAE model. 118 | """ 119 | if vae_type == "vae_dense": 120 | vae = VAE_Dense.load(dir_path) 121 | elif vae_type == "vae_conv": 122 | vae = VAE_Conv.load(dir_path) 123 | elif vae_type == "timeVAE": 124 | vae = TimeVAE.load(dir_path) 125 | else: 126 | raise ValueError( 127 | f"Unrecognized model type [{vae_type}]. " 128 | "Please choose from vae_dense, vae_conv, timeVAE." 129 | ) 130 | 131 | return vae 132 | 133 | 134 | def get_posterior_samples(vae, data): 135 | """ 136 | Get posterior samples from the VAE model. 137 | 138 | Args: 139 | vae (Union[VAE_Dense, VAE_Conv, TimeVAE]): The trained VAE model. 140 | data (np.ndarray): The data to generate posterior samples from. 141 | 142 | Returns: 143 | np.ndarray: The posterior samples. 144 | """ 145 | return vae.predict(data) 146 | 147 | 148 | def get_prior_samples(vae, num_samples: int): 149 | """ 150 | Get prior samples from the VAE model. 151 | 152 | Args: 153 | vae (Union[VAE_Dense, VAE_Conv, TimeVAE]): The trained VAE model. 154 | num_samples (int): The number of samples to generate. 155 | 156 | Returns: 157 | np.ndarray: The prior samples. 158 | """ 159 | return vae.get_prior_samples(num_samples=num_samples) 160 | -------------------------------------------------------------------------------- /src/vae_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import paths 6 | from data_utils import ( 7 | load_yaml_file, 8 | load_data, 9 | split_data, 10 | scale_data, 11 | inverse_transform_data, 12 | save_scaler, 13 | save_data, 14 | ) 15 | from vae.vae_utils import ( 16 | instantiate_vae_model, 17 | train_vae, 18 | save_vae_model, 19 | get_posterior_samples, 20 | get_prior_samples, 21 | load_vae_model, 22 | ) 23 | from visualize import plot_samples, plot_latent_space_samples, visualize_and_save_tsne 24 | 25 | 26 | def run_vae_pipeline(dataset_name: str, vae_type: str): 27 | # ---------------------------------------------------------------------------------- 28 | # Load data, perform train/valid split, scale data 29 | 30 | # read data 31 | data = load_data(data_dir=paths.DATASETS_DIR, dataset=dataset_name) 32 | 33 | # split data into train/valid splits 34 | train_data, valid_data = split_data(data, valid_perc=0.1, shuffle=True) 35 | 36 | # scale data 37 | scaled_train_data, scaled_valid_data, scaler = scale_data(train_data, valid_data) 38 | 39 | # ---------------------------------------------------------------------------------- 40 | # Instantiate and train the VAE Model 41 | 42 | # load hyperparameters from yaml file 43 | hyperparameters = load_yaml_file(paths.HYPERPARAMETERS_FILE_PATH)[vae_type] 44 | 45 | # instantiate the model 46 | _, sequence_length, feature_dim = scaled_train_data.shape 47 | vae_model = instantiate_vae_model( 48 | vae_type=vae_type, 49 | sequence_length=sequence_length, 50 | feature_dim=feature_dim, 51 | **hyperparameters, 52 | ) 53 | 54 | # train vae 55 | train_vae( 56 | vae=vae_model, 57 | train_data=scaled_train_data, 58 | max_epochs=200, 59 | verbose=1, 60 | ) 61 | 62 | # ---------------------------------------------------------------------------------- 63 | # Save scaler and model 64 | model_save_dir = os.path.join(paths.MODELS_DIR, dataset_name) 65 | # save scaler 66 | save_scaler(scaler=scaler, dir_path=model_save_dir) 67 | # Save vae 68 | save_vae_model(vae=vae_model, dir_path=model_save_dir) 69 | 70 | # ---------------------------------------------------------------------------------- 71 | # Visualize posterior samples 72 | x_decoded = get_posterior_samples(vae_model, scaled_train_data) 73 | plot_samples( 74 | samples1=scaled_train_data, 75 | samples1_name="Original Train", 76 | samples2=x_decoded, 77 | samples2_name="Reconstructed Train", 78 | num_samples=5, 79 | ) 80 | # ---------------------------------------------------------------------------------- 81 | # Generate prior samples, visualize and save them 82 | 83 | # Generate prior samples 84 | prior_samples = get_prior_samples(vae_model, num_samples=train_data.shape[0]) 85 | # Plot prior samples 86 | plot_samples( 87 | samples1=prior_samples, 88 | samples1_name="Prior Samples", 89 | num_samples=5, 90 | ) 91 | 92 | # visualize t-sne of original and prior samples 93 | visualize_and_save_tsne( 94 | samples1=scaled_train_data, 95 | samples1_name="Original", 96 | samples2=prior_samples, 97 | samples2_name="Generated (Prior)", 98 | scenario_name=f"Model-{vae_type} Dataset-{dataset_name}", 99 | save_dir=os.path.join(paths.TSNE_DIR, dataset_name), 100 | max_samples=2000, 101 | ) 102 | 103 | # inverse transformer samples to original scale and save to dir 104 | inverse_scaled_prior_samples = inverse_transform_data(prior_samples, scaler) 105 | save_data( 106 | data=inverse_scaled_prior_samples, 107 | output_file=os.path.join( 108 | os.path.join(paths.GEN_DATA_DIR, dataset_name), 109 | f"{vae_type}_{dataset_name}_prior_samples.npz", 110 | ), 111 | ) 112 | 113 | # ---------------------------------------------------------------------------------- 114 | # If latent_dim == 2, plot latent space 115 | if hyperparameters["latent_dim"] == 2: 116 | plot_latent_space_samples(vae=vae_model, n=8, figsize=(15, 15)) 117 | 118 | # ---------------------------------------------------------------------------------- 119 | # later.... load model 120 | loaded_model = load_vae_model(vae_type, model_save_dir).to(next(vae_model.parameters()).device) 121 | 122 | # Verify that loaded model produces same posterior samples 123 | new_x_decoded = loaded_model.predict(scaled_train_data) 124 | print( 125 | "Preds from orig and loaded models equal: ", 126 | np.allclose(x_decoded, new_x_decoded, atol=1e-5), 127 | ) 128 | 129 | # ---------------------------------------------------------------------------------- 130 | 131 | 132 | if __name__ == "__main__": 133 | # check `/data/` for available datasets 134 | dataset = "sine_subsampled_train_perc_20" 135 | 136 | # models: vae_dense, vae_conv, timeVAE 137 | model_name = "vae_conv" 138 | 139 | run_vae_pipeline(dataset, model_name) 140 | -------------------------------------------------------------------------------- /src/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd, numpy as np 6 | from sklearn.manifold import TSNE 7 | 8 | TITLE_FONT_SIZE = 16 9 | 10 | 11 | def plot_samples( 12 | samples1: np.ndarray, 13 | samples1_name: str, 14 | samples2: Optional[np.ndarray] = None, 15 | samples2_name: Optional[str] = None, 16 | num_samples: int = 5, 17 | ) -> None: 18 | """ 19 | Plot one or two sets of samples. 20 | 21 | Args: 22 | samples1 (np.ndarray): The first set of samples to plot. 23 | samples1_name (str): The name for the first set of samples in the plot title. 24 | samples2 (Optional[np.ndarray]): The second set of samples to plot. 25 | Defaults to None. 26 | samples2_name (Optional[str]): The name for the second set of samples in the 27 | plot title. 28 | Defaults to None. 29 | num_samples (int, optional): The number of samples to plot. 30 | Defaults to 5. 31 | 32 | Returns: 33 | None 34 | """ 35 | if samples2 is not None: 36 | fig, axs = plt.subplots(num_samples, 2, figsize=(10, 6)) 37 | else: 38 | fig, axs = plt.subplots(num_samples, 1, figsize=(6, 8)) 39 | 40 | for i in range(num_samples): 41 | rnd_idx1 = np.random.choice(len(samples1)) 42 | sample1 = samples1[rnd_idx1] 43 | 44 | if samples2 is not None: 45 | rnd_idx2 = np.random.choice(len(samples2)) 46 | sample2 = samples2[rnd_idx2] 47 | 48 | axs[i, 0].plot(sample1) 49 | axs[i, 0].set_title(samples1_name) 50 | 51 | axs[i, 1].plot(sample2) 52 | axs[i, 1].set_title(samples2_name) 53 | else: 54 | axs[i].plot(sample1) 55 | axs[i].set_title(samples1_name) 56 | 57 | if samples2 is not None: 58 | fig.suptitle(f"{samples1_name} vs {samples2_name}", fontsize=TITLE_FONT_SIZE) 59 | else: 60 | fig.suptitle(samples1_name, fontsize=TITLE_FONT_SIZE) 61 | 62 | fig.tight_layout() 63 | plt.show() 64 | 65 | 66 | def plot_latent_space_samples(vae, n: int, figsize: tuple) -> None: 67 | """ 68 | Plot samples from a 2D latent space. 69 | 70 | Args: 71 | vae: The VAE model with a method to generate samples from latent space. 72 | n (int): Number of points in each dimension of the grid. 73 | figsize (tuple): Figure size for the plot. 74 | """ 75 | scale = 3.0 76 | grid_x = np.linspace(-scale, scale, n) 77 | grid_y = np.linspace(-scale, scale, n)[::-1] 78 | grid_size = len(grid_x) 79 | 80 | # Generate the latent space grid 81 | Z2 = np.array([[x, y] for x in grid_x for y in grid_y]) 82 | 83 | # Generate samples from the VAE given the latent space coordinates 84 | X_recon = vae.get_prior_samples_given_Z(Z2) 85 | X_recon = np.squeeze(X_recon) 86 | 87 | fig, axs = plt.subplots(grid_size, grid_size, figsize=figsize) 88 | 89 | # Plot each generated sample 90 | for k, (i, yi) in enumerate(enumerate(grid_y)): 91 | for j, xi in enumerate(grid_x): 92 | axs[i, j].plot(X_recon[k]) 93 | axs[i, j].set_title(f"z1={np.round(xi, 2)}; z2={np.round(yi, 2)}") 94 | k += 1 95 | 96 | fig.suptitle("Generated Samples From 2D Embedded Space", fontsize=TITLE_FONT_SIZE) 97 | fig.tight_layout() 98 | plt.show() 99 | 100 | 101 | def avg_over_dim(data: np.ndarray, axis: int) -> np.ndarray: 102 | """ 103 | Average over the feature dimension of the data. 104 | 105 | Args: 106 | data (np.ndarray): The data to average over. 107 | axis (int): Axis to average over. 108 | 109 | Returns: 110 | np.ndarray: The data averaged over the feature dimension. 111 | """ 112 | return np.mean(data, axis=axis) 113 | 114 | 115 | def visualize_and_save_tsne( 116 | samples1: np.ndarray, 117 | samples1_name: str, 118 | samples2: np.ndarray, 119 | samples2_name: str, 120 | scenario_name: str, 121 | save_dir: str, 122 | max_samples: int = 1000, 123 | ) -> None: 124 | """ 125 | Visualize the t-SNE of two sets of samples and save to file. 126 | 127 | Args: 128 | samples1 (np.ndarray): The first set of samples to plot. 129 | samples1_name (str): The name for the first set of samples in the plot title. 130 | samples2 (np.ndarray): The second set of samples to plot. 131 | samples2_name (str): The name for the second set of samples in the 132 | plot title. 133 | scenario_name (str): The scenario name for the given samples. 134 | save_dir (str): Dir path to which to save the file. 135 | max_samples (int): Maximum number of samples to use in the plot. Samples should 136 | be limited because t-SNE is O(n^2). 137 | """ 138 | if samples1.shape != samples2.shape: 139 | raise ValueError( 140 | "Given pairs of samples dont match in shapes. Cannot create t-SNE.\n" 141 | f"sample1 shape: {samples1.shape}; sample2 shape: {samples2.shape}" 142 | ) 143 | 144 | samples1_2d = avg_over_dim(samples1, axis=2) 145 | samples2_2d = avg_over_dim(samples2, axis=2) 146 | 147 | # num of samples used in the t-SNE plot 148 | used_samples = min(samples1_2d.shape[0], max_samples) 149 | 150 | # Combine the original and generated samples 151 | combined_samples = np.vstack( 152 | [samples1_2d[:used_samples], samples2_2d[:used_samples]] 153 | ) 154 | 155 | # Compute the t-SNE of the combined samples 156 | tsne = TSNE(n_components=2, perplexity=40, n_iter=300, random_state=42) 157 | tsne_samples = tsne.fit_transform(combined_samples) 158 | 159 | # Create a DataFrame for the t-SNE samples 160 | tsne_df = pd.DataFrame( 161 | { 162 | "tsne_1": tsne_samples[:, 0], 163 | "tsne_2": tsne_samples[:, 1], 164 | "sample_type": [samples1_name] * used_samples 165 | + [samples2_name] * used_samples, 166 | } 167 | ) 168 | 169 | # Plot the t-SNE samples 170 | plt.figure(figsize=(8, 8)) 171 | for sample_type, color in zip([samples1_name, samples2_name], ["red", "blue"]): 172 | if sample_type is not None: 173 | indices = tsne_df["sample_type"] == sample_type 174 | plt.scatter( 175 | tsne_df.loc[indices, "tsne_1"], 176 | tsne_df.loc[indices, "tsne_2"], 177 | label=sample_type, 178 | color=color, 179 | alpha=0.5, 180 | s=100, 181 | ) 182 | 183 | plt.title(f"t-SNE for {scenario_name}") 184 | plt.legend() 185 | 186 | # Save the plot to a file 187 | os.makedirs(save_dir, exist_ok=True) 188 | plt.savefig(os.path.join(save_dir, f"{scenario_name}.png")) 189 | 190 | plt.show() 191 | --------------------------------------------------------------------------------