├── .gitignore ├── DFINE.py ├── LICENSE.md ├── README.md ├── config_dfine.py ├── data └── swiss_roll.pt ├── datasets.py ├── metrics.py ├── modules ├── LDM.py └── MLP.py ├── nn.py ├── python_utils.py ├── requirements.txt ├── time_series_utils.py ├── trainers ├── BaseTrainer.py └── TrainerDFINE.py └── tutorial.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # unnecessary folders 10 | results* 11 | trainers/__pycache__ 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | *.pdf 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /DFINE.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. 
Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | from modules.LDM import LDM 9 | from modules.MLP import MLP 10 | from nn import get_kernel_initializer_function, compute_mse, get_activation_function 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | class DFINE(nn.Module): 17 | ''' 18 | DFINE (Dynamical Flexible Inference for Nonlinear Embeddings) Model. 19 | 20 | DFINE is a novel neural network model of neural population activity with the ability to perform 21 | flexible inference while modeling the nonlinear latent manifold structure and linear temporal dynamics. 22 | To model neural population activity, two sets of latent factors are defined: the dynamic latent factors 23 | which characterize the linear temporal dynamics on a nonlinear manifold, and the manifold latent factors 24 | which describe this low-dimensional manifold that is embedded in the high-dimensional neural population activity space. 25 | These two separate sets of latent factors together enable all the above flexible inference properties 26 | by allowing for Kalman filtering on the manifold while also capturing embedding nonlinearities. 27 | Here are some mathematical notations used in this repository: 28 | - y: The high-dimensional neural population activity, (num_seq, num_steps, dim_y). It must be Gaussian distributed, e.g., Gaussian-smoothed firing rates, LFP, ECoG or EEG 29 | - a: The manifold latent factors, (num_seq, num_steps, dim_a). 30 | - x: The dynamic latent factors, (num_seq, num_steps, dim_x). 31 | 32 | 33 | * Please note that DFINE can perform learning and inference for continuous data, trial-based data, or segmented continuous data. In the case of continuous data, 34 | num_seq and batch_size can be set to 1, and we let the model be optimized from the long time-series (this is plain gradient descent rather than mini-batch gradient descent). 35 | In case of trial-based data, we can just pass the 3D tensor as the shape (num_seq, num_steps, dim_y) suggests. In case of segmented continuous data, 36 | num_seq can be the number of segments and DFINE provides both per-segment and concatenated inference at the end for the user's convenience. In the concatenated inference, 37 | the assumption is that the concatenated segments form a continuous time-series (a single time-series with a batch size of 1). 38 | ''' 39 | 40 | def __init__(self, config): 41 | ''' 42 | Initializer for a DFINE object. Note that DFINE is a subclass of torch.nn.Module. 43 | 44 | Parameters: 45 | ------------ 46 | 47 | - config: yacs.config.CfgNode, yacs config which contains all hyperparameters required to create the DFINE model 48 | Please see config_dfine.py for the hyperparameters, their default values and definitions.
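Example (a minimal sketch; get_default_config() is defined in config_dfine.py of this repository):
    from config_dfine import get_default_config
    model = DFINE(get_default_config())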
49 | ''' 50 | 51 | super(DFINE, self).__init__() 52 | 53 | # Get the config and dimension parameters 54 | self.config = config 55 | 56 | # Set the seed; by default, the seed is a random integer (see config_dfine.py) 57 | torch.manual_seed(self.config.seed) 58 | 59 | # Set the factor dimensions and loss scales 60 | self._set_dims_and_scales() 61 | 62 | # Initialize LDM parameters 63 | A, C, W_log_diag, R_log_diag, mu_0, Lambda_0 = self._init_ldm_parameters() 64 | 65 | # Initialize the LDM 66 | self.ldm = LDM(dim_x=self.dim_x, dim_a=self.dim_a, 67 | A=A, C=C, 68 | W_log_diag=W_log_diag, R_log_diag=R_log_diag, 69 | mu_0=mu_0, Lambda_0=Lambda_0, 70 | is_W_trainable=self.config.model.is_W_trainable, 71 | is_R_trainable=self.config.model.is_R_trainable) 72 | 73 | # Initialize encoder and decoder(s) 74 | self.encoder = self._get_MLP(input_dim=self.dim_y, 75 | output_dim=self.dim_a, 76 | layer_list=self.config.model.hidden_layer_list, 77 | activation_str=self.config.model.activation) 78 | 79 | self.decoder = self._get_MLP(input_dim=self.dim_a, 80 | output_dim=self.dim_y, 81 | layer_list=self.config.model.hidden_layer_list[::-1], 82 | activation_str=self.config.model.activation) 83 | 84 | # If asked to train a supervised model, create the behavior mapper 85 | if self.config.model.supervise_behv: 86 | self.mapper = self._get_MLP(input_dim=self.dim_a, 87 | output_dim=self.dim_behv, 88 | layer_list=self.config.model.hidden_layer_list_mapper, 89 | activation_str=self.config.model.activation_mapper) 90 | 91 | 92 | def _set_dims_and_scales(self): 93 | ''' 94 | Sets the observation (y), manifold latent factor (a) and dynamic latent factor (x) dimensions 95 | (and the behavior data dimension if a supervised model is to be trained), 96 | as well as the behavior reconstruction loss and regularization loss scales from the config.
97 | ''' 98 | 99 | # Set the dimensions 100 | self.dim_y = self.config.model.dim_y 101 | self.dim_a = self.config.model.dim_a 102 | self.dim_x = self.config.model.dim_x 103 | 104 | if self.config.model.supervise_behv: 105 | self.dim_behv = len(self.config.model.which_behv_dims) 106 | 107 | # Set the loss scales for the behavior component and for the regularization 108 | if self.config.model.supervise_behv: 109 | self.scale_behv_recons = self.config.loss.scale_behv_recons 110 | self.scale_l2 = self.config.loss.scale_l2 111 | 112 | 113 | def _get_MLP(self, input_dim, output_dim, layer_list, activation_str='tanh'): 114 | ''' 115 | Creates an MLP object 116 | 117 | Parameters: 118 | ------------ 119 | - input_dim: int, Dimensionality of the input to the MLP network 120 | - output_dim: int, Dimensionality of the output of the MLP network 121 | - layer_list: list, List of number of neurons in each hidden layer 122 | - activation_str: str, Activation function's name, 'tanh' by default 123 | 124 | Returns: 125 | ------------ 126 | - mlp_network: an instance of the MLP class with the desired architecture 127 | ''' 128 | 129 | activation_fn = get_activation_function(activation_str) 130 | kernel_initializer_fn = get_kernel_initializer_function(self.config.model.nn_kernel_initializer) 131 | 132 | mlp_network = MLP(input_dim=input_dim, 133 | output_dim=output_dim, 134 | layer_list=layer_list, 135 | activation_fn=activation_fn, 136 | kernel_initializer_fn=kernel_initializer_fn 137 | ) 138 | return mlp_network 139 | 140 | 141 | def _init_ldm_parameters(self): 142 | ''' 143 | Initializes the LDM Module parameters 144 | 145 | Returns: 146 | ------------ 147 | - A: torch.Tensor, shape: (self.dim_x, self.dim_x), State transition matrix of LDM 148 | - C: torch.Tensor, shape: (self.dim_a, self.dim_x), Observation matrix of LDM 149 | - W_log_diag: torch.Tensor, shape: (self.dim_x, ), Log-diagonal of dynamics noise covariance matrix (W, therefore it is diagonal and PSD) 150 | - R_log_diag: torch.Tensor, shape: (self.dim_a, ), Log-diagonal of observation noise covariance matrix (R, therefore it is diagonal and PSD) 151 | - mu_0: torch.Tensor, shape: (self.dim_x, ), Dynamic latent factor prediction initial condition (x_{0|-1}) for Kalman filtering 152 | - Lambda_0: torch.Tensor, shape: (self.dim_x, self.dim_x), Dynamic latent factor estimate error covariance initial condition (P_{0|-1}) for Kalman filtering 153 | 154 | * We learn the log-diagonals of matrices W and R to satisfy the PSD constraint for the covariance matrices.
Diagonal W and R are used for the stability of learning, 155 | similar to prior latent LDM works; see (Kao et al., Nature Communications, 2015) and (Abbaspourazad et al., IEEE TNSRE, 2019) for further information 156 | ''' 157 | 158 | kernel_initializer_fn = get_kernel_initializer_function(self.config.model.ldm_kernel_initializer) 159 | A = kernel_initializer_fn(self.config.model.init_A_scale * torch.eye(self.dim_x, dtype=torch.float32)) 160 | C = kernel_initializer_fn(self.config.model.init_C_scale * torch.randn(self.dim_a, self.dim_x, dtype=torch.float32)) 161 | 162 | W_log_diag = torch.log(kernel_initializer_fn(torch.diag(self.config.model.init_W_scale * torch.eye(self.dim_x, dtype=torch.float32)))) 163 | R_log_diag = torch.log(kernel_initializer_fn(torch.diag(self.config.model.init_R_scale * torch.eye(self.dim_a, dtype=torch.float32)))) 164 | 165 | mu_0 = kernel_initializer_fn(torch.zeros(self.dim_x, dtype=torch.float32)) 166 | Lambda_0 = kernel_initializer_fn(self.config.model.init_cov * torch.eye(self.dim_x, dtype=torch.float32)) 167 | 168 | return A, C, W_log_diag, R_log_diag, mu_0, Lambda_0 169 | 170 | 171 | def forward(self, y, mask=None): 172 | ''' 173 | Forward pass for DFINE Model 174 | 175 | Parameters: 176 | ------------ 177 | - y: torch.Tensor, shape: (num_seq, num_steps, dim_y), High-dimensional neural observations 178 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether 179 | observations at each timestep exist (1) or are missing (0) 180 | 181 | Returns: 182 | ------------ 183 | - model_vars: dict, Dictionary which contains learned parameters, inferred latents, predictions and reconstructions. Keys are: 184 | - a_hat: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors.
185 | - a_pred: torch.Tensor, shape: (num_seq, num_steps-1, dim_a), Batch of predicted estimates of manifold latent factors (last index of the second dimension is removed) 186 | - a_filter: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of filtered estimates of manifold latent factors 187 | - a_smooth: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of smoothed estimates of manifold latent factors 188 | - x_pred: torch.Tensor, shape: (num_seq, num_steps-1, dim_x), Batch of predicted estimates of dynamic latent factors 189 | - x_filter: torch.Tensor, shape: (num_seq, num_steps, dim_x), Batch of filtered estimates of dynamic latent factors 190 | - x_smooth: torch.Tensor, shape: (num_seq, num_steps, dim_x), Batch of smoothed estimates of dynamic latent factors 191 | - Lambda_pred: torch.Tensor, shape: (num_seq, num_steps-1, dim_x, dim_x), Batch of predicted estimates of dynamic latent factor estimation error covariance 192 | - Lambda_filter: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Batch of filtered estimates of dynamic latent factor estimation error covariance 193 | - Lambda_smooth: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Batch of smoothed estimates of dynamic latent factor estimation error covariance 194 | - y_hat: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of projected estimates of neural observations 195 | - y_pred: torch.Tensor, shape: (num_seq, num_steps-1, dim_y), Batch of predicted estimates of neural observations 196 | - y_filter: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of filtered estimates of neural observations 197 | - y_smooth: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of smoothed estimates of neural observations 198 | - A: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Repeated (tiled) state transition matrix of LDM, same for each time-step in the 2nd axis 199 | - C: torch.Tensor, shape: (num_seq, num_steps, dim_a, dim_x), Repeated (tiled) observation matrix of LDM, same for each time-step in the 2nd axis 200 | - behv_hat: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of reconstructed behavior.
None if an unsupervised model is trained 201 | 202 | * Terminology definition: 203 | projected: noisy estimates of manifold latent factors after nonlinear manifold embedding via encoder 204 | predicted: one-step-ahead predicted estimates (t+1|t), the first and last time indices are (1|0) and (T|T-1) 205 | filtered: causal estimates (t|t) 206 | smoothed: non-causal estimates (t|T) 207 | ''' 208 | 209 | # Get the dimensions from y 210 | num_seq, num_steps, _ = y.shape 211 | 212 | # Create the mask if it's None 213 | if mask is None: 214 | mask = torch.ones(y.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1) 215 | 216 | # Get the encoded low-dimensional manifold factors (project via nonlinear manifold embedding) -> the outputs are (num_seq * num_steps, dim_a) 217 | a_hat = self.encoder(y.view(-1, self.dim_y)) 218 | 219 | # Reshape the manifold latent factors back into 3D structure (num_seq, num_steps, dim_a) 220 | a_hat = a_hat.view(-1, num_steps, self.dim_a) 221 | 222 | # Run LDM to infer filtered and smoothed dynamic latent factors 223 | x_pred, x_filter, x_smooth, Lambda_pred, Lambda_filter, Lambda_smooth = self.ldm(a=a_hat, mask=mask, do_smoothing=True) 224 | A = self.ldm.A.repeat(num_seq, num_steps, 1, 1) 225 | C = self.ldm.C.repeat(num_seq, num_steps, 1, 1) 226 | a_pred = (C @ x_pred.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, num_steps, dim_a, dim_x) x (num_seq, num_steps, dim_x, 1) --> (num_seq, num_steps, dim_a) 227 | a_filter = (C @ x_filter.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, num_steps, dim_a, dim_x) x (num_seq, num_steps, dim_x, 1) --> (num_seq, num_steps, dim_a) 228 | a_smooth = (C @ x_smooth.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, num_steps, dim_a, dim_x) x (num_seq, num_steps, dim_x, 1) --> (num_seq, num_steps, dim_a) 229 | 230 | # Remove the last timestep of predictions since it's T+1|T, which is not of interest 231 | x_pred = x_pred[:, :-1, :] 232 | Lambda_pred = Lambda_pred[:, :-1, :, :] 233 | a_pred = a_pred[:, :-1, :] 234 | 235 | # Supervise a_hat or a_smooth to behavior if requested -> behv_hat shape: (num_seq, num_steps, dim_behv) 236 | if self.config.model.supervise_behv: 237 | if self.config.model.behv_from_smooth: 238 | behv_hat = self.mapper(a_smooth.view(-1, self.dim_a)) 239 | else: 240 | behv_hat = self.mapper(a_hat.view(-1, self.dim_a)) 241 | behv_hat = behv_hat.view(-1, num_steps, self.dim_behv) 242 | else: 243 | behv_hat = None 244 | 245 | # Get filtered and smoothed estimates of neural observations. To perform k-step-ahead prediction, 246 | # the get_k_step_ahead_prediction(...) function should be called after the forward pass.
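# (For reference: that function iterates x_{t+k|t} = A x_{t+k-1|t} starting from the filtered x_{t|t},
# maps a_{t+k|t} = C x_{t+k|t}, and decodes y_{t+k|t} = decoder(a_{t+k|t}); see its definition below.)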
247 | y_hat = self.decoder(a_hat.view(-1, self.dim_a)) 248 | y_pred = self.decoder(a_pred.reshape(-1, self.dim_a)) 249 | y_filter = self.decoder(a_filter.view(-1, self.dim_a)) 250 | y_smooth = self.decoder(a_smooth.view(-1, self.dim_a)) 251 | 252 | y_hat = y_hat.view(num_seq, -1, self.dim_y) 253 | y_pred = y_pred.view(num_seq, -1, self.dim_y) 254 | y_filter = y_filter.view(num_seq, -1, self.dim_y) 255 | y_smooth = y_smooth.view(num_seq, -1, self.dim_y) 256 | 257 | # Dump inferred latents, predictions and reconstructions to a dictionary 258 | model_vars = dict(a_hat=a_hat, a_pred=a_pred, a_filter=a_filter, a_smooth=a_smooth, 259 | x_pred=x_pred, x_filter=x_filter, x_smooth=x_smooth, 260 | Lambda_pred=Lambda_pred, Lambda_filter=Lambda_filter, Lambda_smooth=Lambda_smooth, 261 | y_hat=y_hat, y_pred=y_pred, y_filter=y_filter, y_smooth=y_smooth, 262 | A=A, C=C, behv_hat=behv_hat) 263 | return model_vars 264 | 265 | 266 | def get_k_step_ahead_prediction(self, model_vars, k): 267 | ''' 268 | Performs k-step ahead prediction of manifold latent factors, dynamic latent factors and neural observations. 269 | 270 | Parameters: 271 | ------------ 272 | - model_vars: dict, Dictionary returned after the forward(...) call. See the definition of the forward(...) function for information. 273 | - x_filter: torch.Tensor, shape: (num_seq, num_steps, dim_x), Batch of filtered estimates of dynamic latent factors 274 | - A: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x) or (dim_x, dim_x), State transition matrix of LDM 275 | - C: torch.Tensor, shape: (num_seq, num_steps, dim_a, dim_x) or (dim_a, dim_x), Observation matrix of LDM 276 | - k: int, Number of steps ahead for prediction 277 | 278 | Returns: 279 | ------------ 280 | - y_pred_k: torch.Tensor, shape: (num_seq, num_steps-k, dim_y), Batch of predicted estimates of neural observations, 281 | the first index of the second dimension is y_{k|0} 282 | - a_pred_k: torch.Tensor, shape: (num_seq, num_steps-k, dim_a), Batch of predicted estimates of manifold latent factor, 283 | the first index of the second dimension is a_{k|0} 284 | - x_pred_k: torch.Tensor, shape: (num_seq, num_steps-k, dim_x), Batch of predicted estimates of dynamic latent factor, 285 | the first index of the second dimension is x_{k|0} 286 | ''' 287 | 288 | # Check whether the provided k value is valid or not 289 | if k <= 0 or not isinstance(k, int): 290 | assert False, 'Number of steps ahead prediction value is invalid or of wrong type, k must be a positive integer!' 291 | 292 | # Extract the required variables from the model_vars dictionary 293 | x_filter = model_vars['x_filter'] 294 | A = model_vars['A'] 295 | C = model_vars['C'] 296 | 297 | # Get the required dimensions 298 | num_seq, num_steps, _ = x_filter.shape 299 | 300 | # Check if shapes of A and C are 4D where first 2 dimensions are (number of trials/time segments) and (number of steps) 301 | if len(A.shape) == 2: 302 | A = A.repeat(num_seq, num_steps, 1, 1) 303 | 304 | if len(C.shape) == 2: 305 | C = C.repeat(num_seq, num_steps, 1, 1) 306 | 307 | # Here is where k-step ahead prediction is iteratively performed 308 | x_pred_k = x_filter[:, :-k, ...] # initialized with filtered estimates; becomes [x_{k|0}, x_{k+1|1}, ..., x_{T|T-k}] after the loop below 309 | for i in range(1, k+1): 310 | if i != k: 311 | x_pred_k = (A[:, i:-(k-i), ...] @ x_pred_k.unsqueeze(dim=-1)).squeeze(dim=-1) 312 | else: 313 | x_pred_k = (A[:, i:, ...] @ x_pred_k.unsqueeze(dim=-1)).squeeze(dim=-1) 314 | a_pred_k = (C[:, k:, ...]
@ x_pred_k.unsqueeze(dim=-1)).squeeze(dim=-1) 315 | 316 | # After obtaining k-step ahead predicted manifold latent factors, they're decoded to obtain k-step ahead predicted neural observations 317 | y_pred_k = self.decoder(a_pred_k.view(-1, self.dim_a)) 318 | 319 | # Reshape the predictions back into 3D structure after the decoder (num_seq, num_steps, dim_y) 320 | y_pred_k = y_pred_k.reshape(num_seq, -1, self.dim_y) 321 | 322 | return y_pred_k, a_pred_k, x_pred_k 323 | 324 | 325 | def compute_loss(self, y, model_vars, mask=None, behv=None): 326 | ''' 327 | Computes k-step ahead predicted MSE loss, regularization loss and behavior reconstruction loss 328 | if a supervised model is being trained. 329 | 330 | Parameters: 331 | ------------ 332 | - y: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of high-dimensional neural observations 333 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether 334 | observations at each timestep exist (1) or are missing (0); 335 | if None, it will be set to ones. 336 | - model_vars: dict, Dictionary returned after the forward(...) call. See the definition of the forward(...) function for information. 337 | - behv: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of behavior data 338 | 339 | Returns: 340 | ------------ 341 | - loss: torch.Tensor, shape: (), Loss to optimize, which is the sum of the k-step-ahead MSE loss, the L2 regularization loss and 342 | the behavior reconstruction loss if the model is supervised 343 | - loss_dict: dict, Dictionary which has all loss components to log on Tensorboard. Keys are (e.g. for config.loss.steps_ahead = [1, 2]): 344 | - steps_{k}_mse: torch.Tensor, shape: (), {k}-step ahead predicted masked MSE, k's are determined by config.loss.steps_ahead 345 | - model_loss: torch.Tensor, shape: (), Sum of all steps_{k}_mse 346 | - behv_loss: torch.Tensor, shape: (), Behavior reconstruction loss, 0 if the model is unsupervised 347 | - reg_loss: torch.Tensor, shape: (), L2 regularization loss for DFINE encoder and decoder weights 348 | - total_loss: torch.Tensor, shape: (), Sum of model_loss, behv_loss and reg_loss 349 | ''' 350 | 351 | # Create the mask if it's None 352 | if mask is None: 353 | mask = torch.ones(y.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1) 354 | 355 | # Dump individual loss values for logging or Tensorboard 356 | loss_dict = dict() 357 | 358 | # Iterate over multiple steps ahead 359 | k_steps_mse_sum = 0 360 | for _, k in enumerate(self.config.loss.steps_ahead): 361 | y_pred_k, _, _ = self.get_k_step_ahead_prediction(model_vars, k=k) 362 | mse_pred = compute_mse(y_flat=y[:, k:, :].reshape(-1, self.dim_y), 363 | y_hat_flat=y_pred_k.reshape(-1, self.dim_y), 364 | mask_flat=mask[:, k:, :].reshape(-1,)) 365 | k_steps_mse_sum += mse_pred 366 | loss_dict[f'steps_{k}_mse'] = mse_pred 367 | 368 | model_loss = k_steps_mse_sum 369 | loss_dict['model_loss'] = model_loss 370 | 371 | # Get MSE loss for behavior reconstruction; 0 if we don't supervise the model with behavior data 372 | if self.config.model.supervise_behv: 373 | behv_mse = compute_mse(y_flat=behv[..., self.config.model.which_behv_dims].reshape(-1, self.dim_behv), 374 | y_hat_flat=model_vars['behv_hat'].reshape(-1, self.dim_behv), 375 | mask_flat=mask.reshape(-1,)) 376 | behv_loss = self.scale_behv_recons * behv_mse 377 | else: 378 | behv_mse = torch.tensor(0, dtype=torch.float32, device=model_loss.device) 379 | behv_loss = torch.tensor(0, dtype=torch.float32, device=model_loss.device) 380 | loss_dict['behv_mse'] = behv_mse 381 |
loss_dict['behv_loss'] = behv_loss 382 | 383 | # L2 regularization loss 384 | reg_loss = 0 385 | for name, param in self.named_parameters(): 386 | if 'weight' in name: 387 | reg_loss = reg_loss + self.scale_l2 * torch.norm(param) 388 | loss_dict['reg_loss'] = reg_loss 389 | 390 | # Final loss is summation of model loss (sum of k-step ahead MSEs), behavior reconstruction loss and L2 regularization loss 391 | loss = model_loss + behv_loss + reg_loss 392 | loss_dict['total_loss'] = loss 393 | return loss, loss_dict 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | This software is Copyright © 2023 The University of Southern California. All Rights Reserved. 2 | 3 | Permission to use, copy, modify, and distribute this software and its documentation for educational, research 4 | and non-profit purposes, without fee, and without a written agreement is hereby granted, provided that the 5 | above copyright notice, this paragraph and the following three paragraphs appear in all copies. 6 | 7 | Permission to make commercial use of this software may be obtained by contacting: 8 | USC Stevens Center for Innovation 9 | University of Southern California 10 | 1150 S. Olive Street, Suite 2300 11 | Los Angeles, CA 90115, USA 12 | 13 | This software program and documentation are copyrighted by The University of Southern California. The software 14 | program and documentation are supplied "as is", without any accompanying services from USC. USC does not warrant 15 | that the operation of the program will be uninterrupted or error-free. The end-user understands that the program 16 | was developed for research purposes and is advised not to rely exclusively on the program for any reason. 17 | 18 | IN NO EVENT SHALL THE UNIVERSITY OF SOUTHERN CALIFORNIA BE LIABLE TO ANY PARTY FOR 19 | DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST 20 | PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE 21 | UNIVERSITY OF SOUTHERN CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH 22 | DAMAGE. THE UNIVERSITY OF SOUTHERN CALIFORNIA SPECIFICALLY DISCLAIMS ANY 23 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 24 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED 25 | HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS NO 26 | OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR 27 | MODIFICATIONS. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch DFINE 2 | PyTorch implementation of DFINE: Dynamical Flexible Inference for Nonlinear Embeddings 3 | 4 | DFINE is a neural network model of neural population activity that is developed to enable accurate and 5 | flexible inference, whether causally in real time, non-causally, or even in the presence of missing neural observations. 6 | Also, DFINE enables recursive and thus computationally efficient inference for real-time implementation. 7 | DFINE's capabilities are important for applications such as neurotechnologies and brain-computer interfaces. 
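As a quick orientation, here is a minimal end-to-end training sketch (a hedged example: the synthetic data is a stand-in for real recordings, and the `trainer.train(...)` entry point is an assumption for illustration; see [tutorial.ipynb](tutorial.ipynb) for the authoritative walkthrough):

```
import torch
from torch.utils.data import DataLoader

from config_dfine import get_default_config
from datasets import DFINEDataset
from trainers.TrainerDFINE import TrainerDFINE

# Build a config; the dimensions should be set to suit the data at hand
config = get_default_config()
config.model.dim_y, config.model.dim_a, config.model.dim_x = 30, 16, 16

# y: neural observations of shape (num_seq, num_steps, dim_y); random data as a stand-in
y = torch.randn(64, 100, config.model.dim_y)
train_loader = DataLoader(DFINEDataset(y), batch_size=config.train.batch_size)

trainer = TrainerDFINE(config)  # wraps the DFINE model defined in DFINE.py
trainer.train(train_loader)     # assumed training entry point
```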
8 | 9 | More information about the model and its training and inference methods can be found inside [tutorial.ipynb](tutorial.ipynb) and in our manuscript below. 10 | 11 | ## Publication 12 | 13 | Abbaspourazad, H.\*, Erturk, E.\*, Pesaran, B., & Shanechi, M. M. Dynamical flexible inference of nonlinear latent factors and structures in neural population activity. _Nature Biomedical Engineering_ (2023). https://www.nature.com/articles/s41551-023-01106-1 14 | 15 | Original preprint: https://www.biorxiv.org/content/10.1101/2023.03.13.532479v1 16 | 17 | ## Installation 18 | Torch DFINE requires Python version 3.8.* or 3.9.*. After a virtual environment with a compatible Python version is set up, 19 | navigate to the project folder and simply run the following command: 20 | 21 | ``` 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | Then, navigate to the virtual environment's site-packages directory, create a file with a .pth extension and copy the 26 | main project directory path (e.g. .../torchDFINE) into that .pth file. This allows importing the desired modules from subdirectories; 27 | for instance, the TrainerDFINE class can be imported with ```from trainers.TrainerDFINE import TrainerDFINE```. 28 | 29 | ## DFINE Tutorial 30 | Please see [tutorial.ipynb](tutorial.ipynb) for further information and guidelines on DFINE's model, training, and inference. 31 | 32 | ## License 33 | Copyright (c) 2023 University of Southern California
34 | See full notice in [LICENSE.md](LICENSE.md)
35 | Hamidreza Abbaspourazad\*, Eray Erturk\* and Maryam M. Shanechi
36 | Shanechi Lab, University of Southern California 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /config_dfine.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | from python_utils import flatten_dict, unflatten_dict 9 | 10 | from yacs.config import CfgNode as CN 11 | import torch 12 | 13 | 14 | #### Initialization of the default and recommended config (except the dimensions and hidden layer lists, which should be set to suit the data to fit) 15 | _config = CN() 16 | 17 | ## Set device and seed 18 | _config.device = 'cpu' 19 | _config.seed = int(torch.randint(low=0, high=100000, size=(1,))) 20 | 21 | ## Dump model related settings 22 | _config.model = CN() 23 | 24 | # Hidden layer list where each element is the number of neurons for that hidden layer of the DFINE encoder/decoder. Please use [20,20,20,20] for nonlinear manifold simulations. 25 | _config.model.hidden_layer_list = [32,32,32] 26 | # Activation function used in encoder and decoder layers 27 | _config.model.activation = 'tanh' 28 | # Dimensionality of neural observations 29 | _config.model.dim_y = 30 30 | # Dimensionality of the manifold latent factor; a choice higher than dim_y (above) may lead to overfitting 31 | _config.model.dim_a = 16 32 | # Dimensionality of the dynamic latent factor; it's recommended to set it the same as dim_a (above), please see Extended Data Fig. 8 33 | _config.model.dim_x = 16 34 | # Initialization scale of LDM state transition matrix 35 | _config.model.init_A_scale = 1 36 | # Initialization scale of LDM observation matrix 37 | _config.model.init_C_scale = 1 38 | # Initialization scale of LDM process noise covariance matrix 39 | _config.model.init_W_scale = 0.5 40 | # Initialization scale of LDM observation noise covariance matrix 41 | _config.model.init_R_scale = 0.5 42 | # Initialization scale of dynamic latent factor estimation error covariance matrix 43 | _config.model.init_cov = 1 44 | # Boolean for whether process noise covariance matrix W is learnable or not 45 | _config.model.is_W_trainable = True 46 | # Boolean for whether observation noise covariance matrix R is learnable or not 47 | _config.model.is_R_trainable = True 48 | # Initialization type of LDM parameters, see nn.get_kernel_initializer_function for detailed definition and supported types 49 | _config.model.ldm_kernel_initializer = 'default' 50 | # Initialization type of DFINE encoder and decoder parameters, see nn.get_kernel_initializer_function for detailed definition and supported types 51 | _config.model.nn_kernel_initializer = 'xavier_normal' 52 | # Boolean for whether to learn a behavior-supervised model; it must be set to True to train the supervised version.
53 | _config.model.supervise_behv = False 54 | # Hidden layer list for the behavior mapper where each element is the number of neurons for that hidden layer of the mapper 55 | _config.model.hidden_layer_list_mapper = [20,20,20] 56 | # Activation function used in mapper layers 57 | _config.model.activation_mapper = 'tanh' 58 | # List of dimensions of the behavior data to be decoded by the mapper; check for any dimensionality mismatch 59 | _config.model.which_behv_dims = [0,1,2,3] 60 | # Boolean for whether to decode behavior from a_smooth 61 | _config.model.behv_from_smooth = True 62 | # Main save directory for DFINE results, plots and checkpoints 63 | _config.model.save_dir = 'D:/DATA/DFINE_results' 64 | # Number of steps between saving DFINE checkpoints 65 | _config.model.save_steps = 10 66 | 67 | ## Dump loss related settings 68 | _config.loss = CN() 69 | 70 | # L2 regularization loss scale (we recommend a grid-search for the best value, i.e., a grid of [1e-4, 5e-4, 1e-3, 2e-3]). Please use 0 for nonlinear manifold simulations as it leads to a better performance. 71 | _config.loss.scale_l2 = 2e-3 72 | # List of numbers of steps ahead for which DFINE is optimized. For the unsupervised and supervised versions, default values are [1,2,3,4] and [1,2], respectively. 73 | _config.loss.steps_ahead = [1,2,3,4] 74 | # If _config.model.supervise_behv is True, scale for the MSE of behavior reconstruction (we recommend a grid-search for the best value; it should be set to a large value). 75 | _config.loss.scale_behv_recons = 20 76 | 77 | ## Dump training related settings 78 | _config.train = CN() 79 | 80 | # Batch size 81 | _config.train.batch_size = 32 82 | # Number of epochs for which DFINE is trained 83 | _config.train.num_epochs = 200 84 | # Number of steps between validation performance checks 85 | _config.train.valid_step = 1 86 | # Number of steps between saving training/validation plots 87 | _config.train.plot_save_steps = 50 88 | # Number of steps between printing training/validation logs 89 | _config.train.print_log_steps = 10 90 | 91 | ## Dump loading settings 92 | _config.load = CN() 93 | 94 | # Number of the checkpoint to load 95 | _config.load.ckpt = -1 96 | # Boolean for whether to resume training from the epoch where the checkpoint was saved 97 | _config.load.resume_train = False 98 | 99 | ## Dump learning rate related settings 100 | _config.lr = CN() 101 | 102 | # Learning rate scheduler type, options are explr (StepLR, purely exponential if explr.step_size == 1), cyclic (CyclicLR) or constantlr (constant learning rate, no scheduling) 103 | _config.lr.scheduler = 'explr' 104 | # Initial learning rate 105 | _config.lr.init = 0.02 106 | 107 | # Dump cyclic LR scheduler related settings, check https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html for details 108 | _config.lr.cyclic = CN() 109 | # Minimum learning rate for cyclic LR scheduler 110 | _config.lr.cyclic.base_lr = 0.005 111 | # Maximum learning rate for cyclic LR scheduler 112 | _config.lr.cyclic.max_lr = 0.02 113 | # Envelope scale for exponential cyclic LR scheduler mode 114 | _config.lr.cyclic.gamma = 1 115 | # Mode for cyclic LR scheduler 116 | _config.lr.cyclic.mode = 'triangular' 117 | # Number of iterations in the increasing half of the cycle 118 | _config.lr.cyclic.step_size_up = 10 119 | 120 | # Dump exponential LR scheduler related settings, check https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html for details 121 | _config.lr.explr = CN() 122 | # Multiplicative factor of learning rate decay 123 |
_config.lr.explr.gamma = 0.9 124 | # Number of steps to decay the learning rate; the schedule becomes purely exponential if step_size is 1 125 | _config.lr.explr.step_size = 15 126 | 127 | ## Dump optimizer related settings 128 | _config.optim = CN() 129 | 130 | # Epsilon for Adam optimizer 131 | _config.optim.eps = 1e-8 132 | # Gradient clipping norm 133 | _config.optim.grad_clip = 1 134 | 135 | 136 | def get_default_config(): 137 | ''' 138 | Creates the default config 139 | 140 | Returns: 141 | ------------ 142 | - config: yacs.config.CfgNode, default DFINE config 143 | ''' 144 | 145 | return _config.clone() 146 | 147 | 148 | def update_config(config, new_config): 149 | ''' 150 | Updates the config 151 | 152 | Parameters: 153 | ------------ 154 | - config: yacs.config.CfgNode or dict, Config to update 155 | - new_config: yacs.config.CfgNode or dict, Config with new settings and appropriate keys 156 | 157 | Returns: 158 | ------------ 159 | - unflattened_config: yacs.config.CfgNode, Config with updated settings 160 | ''' 161 | 162 | # Flatten both configs 163 | flat_config = flatten_dict(config) 164 | flat_new_config = flatten_dict(new_config) 165 | 166 | # Update and unflatten the config to return 167 | flat_config.update(flat_new_config) 168 | unflattened_config = CN(unflatten_dict(flat_config)) 169 | 170 | return unflattened_config 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | -------------------------------------------------------------------------------- /data/swiss_roll.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanechiLab/torchDFINE/48a981259a795c3d1725a348c81a1552093724ac/data/swiss_roll.pt -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | from torch.utils.data import Dataset 9 | import torch 10 | 11 | 12 | class DFINEDataset(Dataset): 13 | ''' 14 | Dataset class for DFINE. 15 | ''' 16 | 17 | def __init__(self, y, behv=None, mask=None): 18 | ''' 19 | Initializer for DFINEDataset. Note that this is a subclass of torch.utils.data.Dataset. 20 | 21 | Parameters: 22 | ------------ 23 | - y: torch.Tensor, shape: (num_seq, num_steps, dim_y), High-dimensional neural observations. 24 | - behv: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Behavior data. None by default. 25 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask for manifold latent factors which shows whether 26 | observations at each timestep exist (1) or are missing (0). 27 | None by default. 28 | ''' 29 | 30 | self.y = y 31 | 32 | # If behv is not provided, initialize it with zeros. 33 | if behv is None: 34 | self.behv = torch.zeros(y.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1) 35 | else: 36 | self.behv = behv 37 | 38 | # If mask is not provided, initialize it with ones.
39 | if mask is None: 40 | self.mask = torch.ones(y.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1) 41 | else: 42 | self.mask = mask 43 | 44 | 45 | def __len__(self): 46 | ''' 47 | Returns the length of the dataset 48 | ''' 49 | 50 | return self.y.shape[0] 51 | 52 | 53 | def __getitem__(self, idx): 54 | ''' 55 | Returns a tuple of neural observations, behavior and mask segments 56 | ''' 57 | 58 | return self.y[idx, :, :], self.behv[idx, :, :], self.mask[idx, :, :] -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | from torchmetrics import Metric 9 | import torch 10 | 11 | 12 | class Mean(Metric): 13 | ''' 14 | Mean metric class to log batch-averaged metrics to Tensorboard. 15 | ''' 16 | 17 | def __init__(self): 18 | ''' 19 | Initializer for Mean metric. Note that this class is a subclass of torchmetrics.Metric. 20 | ''' 21 | 22 | super().__init__(dist_sync_on_step=False) 23 | 24 | # Define total sum and number of samples that sum is computed over 25 | self.add_state("sum", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum") 26 | self.add_state("num_samples", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum") 27 | 28 | 29 | def update(self, value, batch_size): 30 | ''' 31 | Updates the total sum and number of samples 32 | 33 | Parameters: 34 | ------------ 35 | - value: torch.Tensor, shape: (), Value to add to sum 36 | - batch_size: torch.Tensor, shape: (), Number of samples that 'value' is averaged over 37 | ''' 38 | 39 | value = value.clone().detach() 40 | batch_size = torch.tensor(batch_size, dtype=torch.float32) 41 | self.sum += value.cpu() * batch_size 42 | self.num_samples += batch_size 43 | 44 | 45 | def reset(self): 46 | ''' 47 | Resets the total sum and number of samples to 0 48 | ''' 49 | 50 | self.sum = torch.tensor(0, dtype=torch.float32) 51 | self.num_samples = torch.tensor(0, dtype=torch.float32) 52 | 53 | 54 | def compute(self): 55 | ''' 56 | Computes the mean metric. 57 | 58 | Returns: 59 | ------------ 60 | - avg: Average value for the metric 61 | ''' 62 | 63 | avg = self.sum / self.num_samples 64 | return avg 65 | -------------------------------------------------------------------------------- /modules/LDM.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class LDM(nn.Module): 13 | ''' 14 | Linear Dynamical Model backbone for DFINE. This module is used for smoothing and filtering 15 | given a batch of trials/segments/time-series. 16 | 17 | LDM equations are as follows: 18 | x_{t+1} = Ax_{t} + w_{t}; cov(w_{t}) = W 19 | a_{t} = Cx_{t} + r_{t}; cov(r_{t}) = R 20 | ''' 21 | 22 | def __init__(self, **kwargs): 23 | ''' 24 | Initializer for an LDM object. Note that LDM is a subclass of torch.nn.Module. 
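A minimal construction (a sketch; every argument below falls back to the defaults listed next) is, e.g., LDM(dim_x=16, dim_a=16), which yields identity A and C and all-ones log-diagonals for the noise covariances.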
25 | 26 | Parameters 27 | ------------ 28 | - dim_x: int, Dimensionality of dynamic latent factors, default None 29 | - dim_a: int, Dimensionality of manifold latent factors, default None 30 | - is_W_trainable: bool, Whether dynamics noise covariance matrix (W) is learnt or not, default True 31 | - is_R_trainable: bool, Whether observation noise covariance matrix (R) is learnt or not, default True 32 | - A: torch.Tensor, shape: (self.dim_x, self.dim_x), State transition matrix of LDM, default identity 33 | - C: torch.Tensor, shape: (self.dim_a, self.dim_x), Observation matrix of LDM, default identity 34 | - mu_0: torch.Tensor, shape: (self.dim_x, ), Dynamic latent factor estimate initial condition (x_{0|-1}) for Kalman filtering, default zeros 35 | - Lambda_0: torch.Tensor, shape: (self.dim_x, self.dim_x), Dynamic latent factor estimate error covariance initial condition (P_{0|-1}) for Kalman Filtering, default identity 36 | - W_log_diag: torch.Tensor, shape: (self.dim_x, ), Log-diagonal of process noise covariance matrix (W, therefore it is diagonal and PSD), default ones 37 | - R_log_diag: torch.Tensor, shape: (self.dim_a, ), Log-diagonal of observation noise covariance matrix (R, therefore it is diagonal and PSD), default ones 38 | ''' 39 | 40 | super(LDM, self).__init__() 41 | 42 | self.dim_x = kwargs.pop('dim_x', None) 43 | self.dim_a = kwargs.pop('dim_a', None) 44 | 45 | self.is_W_trainable = kwargs.pop('is_W_trainable', True) 46 | self.is_R_trainable = kwargs.pop('is_R_trainable', True) 47 | 48 | # Initializer for identity matrix, zeros matrix and ones matrix 49 | self.eye_init = lambda shape, dtype=torch.float32: torch.eye(*shape, dtype=dtype) 50 | self.zeros_init = lambda shape, dtype=torch.float32: torch.zeros(*shape, dtype=dtype) 51 | self.ones_init = lambda shape, dtype=torch.float32: torch.ones(*shape, dtype=dtype) 52 | 53 | # Get initial values for LDM parameters 54 | self.A = kwargs.pop('A', self.eye_init((self.dim_x, self.dim_x), dtype=torch.float32).unsqueeze(dim=0)).type(torch.FloatTensor) 55 | self.C = kwargs.pop('C', self.eye_init((self.dim_a, self.dim_x), dtype=torch.float32).unsqueeze(dim=0)).type(torch.FloatTensor) 56 | 57 | # Get KF initial conditions 58 | self.mu_0 = kwargs.pop('mu_0', self.zeros_init((self.dim_x, ), dtype=torch.float32)).type(torch.FloatTensor) 59 | self.Lambda_0 = kwargs.pop('Lambda_0', self.eye_init((self.dim_x, self.dim_x), dtype=torch.float32)).type(torch.FloatTensor) 60 | 61 | # Get initial process and observation noise parameters 62 | self.W_log_diag = kwargs.pop('W_log_diag', self.ones_init((self.dim_x, ), dtype=torch.float32)).type(torch.FloatTensor) 63 | self.R_log_diag = kwargs.pop('R_log_diag', self.ones_init((self.dim_a, ), dtype=torch.float32)).type(torch.FloatTensor) 64 | 65 | # Register trainable parameters to module 66 | self._register_params() 67 | 68 | 69 | def _register_params(self): 70 | ''' 71 | Registers the learnable LDM parameters as nn.Parameters 72 | ''' 73 | 74 | # Check if LDM matrix shapes are consistent 75 | self._check_matrix_shapes() 76 | 77 | # Register LDM parameters 78 | self.A = torch.nn.Parameter(self.A, requires_grad=True) 79 | self.C = torch.nn.Parameter(self.C, requires_grad=True) 80 | 81 | self.W_log_diag = torch.nn.Parameter(self.W_log_diag, requires_grad=self.is_W_trainable) 82 | self.R_log_diag = torch.nn.Parameter(self.R_log_diag, requires_grad=self.is_R_trainable) 83 | 84 | self.mu_0 = torch.nn.Parameter(self.mu_0, requires_grad=True) 85 | self.Lambda_0 = torch.nn.Parameter(self.Lambda_0, 
requires_grad=True) 86 | 87 | 88 | def _check_matrix_shapes(self): 89 | ''' 90 | Checks whether the LDM parameters have the correct shapes, which are defined above in the constructor 91 | ''' 92 | 93 | # Check A matrix's shape 94 | if self.A.shape != (self.dim_x, self.dim_x): 95 | assert False, 'Shape of A matrix is not (dim_x, dim_x)!' 96 | 97 | # Check C matrix's shape 98 | if self.C.shape != (self.dim_a, self.dim_x): 99 | assert False, 'Shape of C matrix is not (dim_a, dim_x)!' 100 | 101 | # Check mu_0 matrix's shape 102 | if len(self.mu_0.shape) != 1: 103 | self.mu_0 = self.mu_0.view(-1, ) 104 | 105 | if self.mu_0.shape != (self.dim_x, ): 106 | assert False, 'Shape of mu_0 matrix is not (dim_x, )!' 107 | 108 | # Check Lambda_0 matrix's shape 109 | if self.Lambda_0.shape != (self.dim_x, self.dim_x): 110 | assert False, 'Shape of Lambda_0 matrix is not (dim_x, dim_x)!' 111 | 112 | # Check W_log_diag matrix's shape 113 | if len(self.W_log_diag.shape) != 1: 114 | self.W_log_diag = self.W_log_diag.view(-1, ) 115 | 116 | if self.W_log_diag.shape != (self.dim_x, ): 117 | assert False, 'Shape of W_log_diag matrix is not (dim_x, )!' 118 | 119 | # Check R_log_diag matrix's shape 120 | if len(self.R_log_diag.shape) != 1: 121 | self.R_log_diag = self.R_log_diag.view(-1, ) 122 | 123 | if self.R_log_diag.shape != (self.dim_a, ): 124 | assert False, 'Shape of R_log_diag matrix is not (dim_a, )!' 125 | 126 | 127 | def _get_covariance_matrices(self): 128 | ''' 129 | Get the process and observation noise covariance matrices from the log-diagonals. 130 | 131 | Returns: 132 | ------------ 133 | - W: torch.Tensor, shape: (self.dim_x, self.dim_x), Process noise covariance matrix 134 | - R: torch.Tensor, shape: (self.dim_a, self.dim_a), Observation noise covariance matrix 135 | ''' 136 | 137 | W = torch.diag(torch.exp(self.W_log_diag)) 138 | R = torch.diag(torch.exp(self.R_log_diag)) 139 | return W, R 140 | 141 | 142 | def compute_forwards(self, a, mask=None): 143 | ''' 144 | Performs the forward iteration of causal flexible Kalman filtering, given a batch of trials/segments/time-series 145 | 146 | Parameters: 147 | ------------ 148 | - a: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors (outputs of encoder; nonlinear manifold embedding step) 149 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether 150 | observations at each timestep exist (1) or are missing (0) 151 | 152 | Returns: 153 | ------------ 154 | - mu_pred_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor predictions (t+1|t) where the first index of the first dimension has x_{1|0} 155 | - mu_t_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor filtered estimates (t|t) where the first index of the first dimension has x_{0|0} 156 | - Lambda_pred_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where the first index of the first dimension has P_{1|0} 157 | - Lambda_t_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where the first index of the first dimension has P_{0|0} 158 | ''' 159 | 160 | if mask is None: 161 | mask = torch.ones(a.shape[:-1], dtype=torch.float32) 162 | 163 | num_seq, num_steps, _ = a.shape 164 | 165 | # Make sure that mask is 3D (last axis is 1-dimensional) 166 | if len(mask.shape) != len(a.shape): 167 | mask = mask.unsqueeze(dim=-1) # (num_seq,
num_steps, 1) 168 | 169 | # To make sure we do not accidentally use the real outputs in the steps with missing values, set them to a dummy value, e.g., 0. 170 | # The dummy values of observations at masked points are irrelevant because: 171 | # Kalman disregards the observations by setting the Kalman Gain to 0 in K = torch.mul(K, mask[:, t, ...].unsqueeze(dim=1)) @ line 204 172 | a_masked = torch.mul(a, mask) # (num_seq, num_steps, dim_a) x (num_seq, num_steps, 1) 173 | 174 | # Initialize mu_0 and Lambda_0 175 | mu_0 = self.mu_0.unsqueeze(dim=0).repeat(num_seq, 1) # (num_seq, dim_x) 176 | Lambda_0 = self.Lambda_0.unsqueeze(dim=0).repeat(num_seq, 1, 1) # (num_seq, dim_x, dim_x) 177 | 178 | mu_pred = mu_0 # (num_seq, dim_x) 179 | Lambda_pred = Lambda_0 # (num_seq, dim_x, dim_x) 180 | 181 | # Create empty arrays for filtered and predicted estimates, NOTE: The last time-step of the prediction has T+1|T, which may not be of interest 182 | mu_pred_all = torch.zeros((num_steps, num_seq, self.dim_x), dtype=torch.float32, device=mu_0.device) 183 | mu_t_all = torch.zeros((num_steps, num_seq, self.dim_x), dtype=torch.float32, device=mu_0.device) 184 | 185 | # Create empty arrays for filtered and predicted error covariances, NOTE: The last time-step of the prediction has T+1|T, which may not be of interest 186 | Lambda_pred_all = torch.zeros((num_steps, num_seq, self.dim_x, self.dim_x), dtype=torch.float32, device=mu_0.device) 187 | Lambda_t_all = torch.zeros((num_steps, num_seq, self.dim_x, self.dim_x), dtype=torch.float32, device=mu_0.device) 188 | 189 | # Get covariance matrices 190 | W, R = self._get_covariance_matrices() 191 | 192 | for t in range(num_steps): 193 | # Tile C matrix for each time segment 194 | C_t = self.C.repeat(num_seq, 1, 1) 195 | 196 | # Obtain residual 197 | a_pred = (C_t @ mu_pred.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, dim_a) 198 | r = a_masked[:, t, ...] - a_pred # (num_seq, dim_a) 199 | 200 | # Project system uncertainty into measurement space, get Kalman Gain 201 | S = C_t @ Lambda_pred @ torch.permute(C_t, (0, 2, 1)) + R # (num_seq, dim_a, dim_a) 202 | S_inv = torch.inverse(S) # (num_seq, dim_a, dim_a) 203 | K = Lambda_pred @ torch.permute(C_t, (0, 2, 1)) @ S_inv # (num_seq, dim_x, dim_a) 204 | K = torch.mul(K, mask[:, t, ...].unsqueeze(dim=1)) # (num_seq, dim_x, dim_a) x (num_seq, 1, 1) 205 | 206 | # Get current mu and Lambda 207 | mu_t = mu_pred + (K @ r.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, dim_x) 208 | I_KC = torch.eye(self.dim_x, dtype=torch.float32, device=mu_0.device) - K @ C_t # (num_seq, dim_x, dim_x) 209 | Lambda_t = I_KC @ Lambda_pred # (num_seq, dim_x, dim_x) 210 | 211 | # Tile A matrix for each time segment 212 | A_t = self.A.repeat(num_seq, 1, 1) # (num_seq, dim_x, dim_x) 213 | 214 | # Prediction 215 | mu_pred = (A_t @ mu_t.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, dim_x, dim_x) x (num_seq, dim_x, 1) --> (num_seq, dim_x, 1) --> (num_seq, dim_x) 216 | Lambda_pred = A_t @ Lambda_t @ torch.permute(A_t, (0, 2, 1)) + W # (num_seq, dim_x, dim_x) x (num_seq, dim_x, dim_x) x (num_seq, dim_x, dim_x) --> (num_seq, dim_x, dim_x) 217 | 218 | # Keep predictions and updates 219 | mu_pred_all[t, ...] = mu_pred 220 | mu_t_all[t, ...] = mu_t 221 | 222 | Lambda_pred_all[t, ...] = Lambda_pred 223 | Lambda_t_all[t, ...]
= Lambda_t 224 | 225 | return mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all 226 | 227 | 228 | def filter(self, a, mask=None): 229 | ''' 230 | Performs Kalman Filtering 231 | 232 | Parameters: 233 | ------------ 234 | - a: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors (outputs of encoder; nonlinear manifold embedding step) 235 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether 236 | observations at each timestep exist (1) or are missing (0) 237 | 238 | Returns: 239 | ------------ 240 | - mu_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor predictions (t+1|t) where the first index of the second dimension has x_{1|0} 241 | - mu_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor filtered estimates (t|t) where the first index of the second dimension has x_{0|0} 242 | - Lambda_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where the first index of the second dimension has P_{1|0} 243 | - Lambda_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where the first index of the second dimension has P_{0|0} 244 | ''' 245 | 246 | # Run the forward iteration 247 | mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all = self.compute_forwards(a=a, mask=mask) 248 | 249 | # Swap num_seq and num_steps dimensions 250 | mu_pred_all = torch.permute(mu_pred_all, (1, 0, 2)) 251 | mu_t_all = torch.permute(mu_t_all, (1, 0, 2)) 252 | Lambda_pred_all = torch.permute(Lambda_pred_all, (1, 0, 2, 3)) 253 | Lambda_t_all = torch.permute(Lambda_t_all, (1, 0, 2, 3)) 254 | 255 | return mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all 256 | 257 | 258 | def compute_backwards(self, mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all): 259 | ''' 260 | Performs the backward iteration for the Rauch-Tung-Striebel (RTS) Smoother 261 | 262 | Parameters: 263 | ------------ 264 | - mu_pred_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor predictions (t+1|t) where the first index of the first dimension has x_{1|0} 265 | - mu_t_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor filtered estimates (t|t) where the first index of the first dimension has x_{0|0} 266 | - Lambda_pred_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where the first index of the first dimension has P_{1|0} 267 | - Lambda_t_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where the first index of the first dimension has P_{0|0} 268 | 269 | Returns: 270 | ------------ 271 | - mu_back_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor smoothed estimates (t|T) where the first index of the first dimension has x_{0|T} 272 | - Lambda_back_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance smoothed estimates (t|T) where the first index of the first dimension has P_{0|T} 273 | ''' 274 | 275 | # Get number of steps and number of trials 276 | num_steps, num_seq, _ = mu_pred_all.shape 277 | 278 | # Create empty arrays for smoothed dynamic latent factors and error covariances 279 | mu_back_all = torch.zeros((num_steps, num_seq, self.dim_x), dtype=torch.float32,
device=mu_pred_all.device) # (num_steps, num_seq, dim_x) 280 | Lambda_back_all = torch.zeros((num_steps, num_seq, self.dim_x, self.dim_x), dtype=torch.float32, device=mu_pred_all.device) # (num_steps, num_seq, dim_x, dim_x) 281 | 282 | # The last smoothed estimate is equivalent to the filtered estimate 283 | mu_back_all[-1, ...] = mu_t_all[-1, ...] 284 | Lambda_back_all[-1, ...] = Lambda_t_all[-1, ...] 285 | 286 | # Initialize the iterable parameters 287 | mu_back = mu_t_all[-1, ...] 288 | Lambda_back = Lambda_back_all[-1, ...] 289 | 290 | for t in range(num_steps-2, -1, -1): # iterate the loop over reverse time: T-2, T-3, ..., 0, where the last time-step is T-1 291 | A_t = self.A.repeat(num_seq, 1, 1) 292 | J_t = Lambda_t_all[t, ...] @ torch.permute(A_t, (0, 2, 1)) @ torch.inverse(Lambda_pred_all[t, ...]) # (num_seq, dim_x, dim_x) x (num_seq, dim_x, dim_x) x (num_seq, dim_x, dim_x) 293 | mu_back = mu_t_all[t, ...] + (J_t @ (mu_back - mu_pred_all[t, ...]).unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, dim_x) + (num_seq, dim_x, dim_x) x (num_seq, dim_x) 294 | 295 | Lambda_back = Lambda_t_all[t, ...] + J_t @ (Lambda_back - Lambda_pred_all[t, ...]) @ torch.permute(J_t, (0, 2, 1)) # (num_seq, dim_x, dim_x) 296 | 297 | mu_back_all[t, ...] = mu_back 298 | Lambda_back_all[t, ...] = Lambda_back 299 | 300 | return mu_back_all, Lambda_back_all 301 | 302 | 303 | def smooth(self, a, mask=None): 304 | ''' 305 | Performs Rauch-Tung-Striebel (RTS) Smoothing 306 | 307 | Parameters: 308 | ------------ 309 | - a: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors (outputs of encoder; nonlinear manifold embedding step) 310 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether 311 | observations at each timestep exist (1) or are missing (0) 312 | 313 | Returns: 314 | ------------ 315 | - mu_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor predictions (t+1|t) where the first index of the second dimension has x_{1|0} 316 | - mu_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor filtered estimates (t|t) where the first index of the second dimension has x_{0|0} 317 | - mu_back_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor smoothed estimates (t|T) where the first index of the second dimension has x_{0|T} 318 | - Lambda_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where the first index of the second dimension has P_{1|0} 319 | - Lambda_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where the first index of the second dimension has P_{0|0} 320 | - Lambda_back_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance smoothed estimates (t|T) where the first index of the second dimension has P_{0|T} 321 | ''' 322 | 323 | mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all = self.compute_forwards(a=a, mask=mask) 324 | mu_back_all, Lambda_back_all = self.compute_backwards(mu_pred_all=mu_pred_all, 325 | mu_t_all=mu_t_all, 326 | Lambda_pred_all=Lambda_pred_all, 327 | Lambda_t_all=Lambda_t_all) 328 | 329 | # Swap num_seq and num_steps dimensions 330 | mu_pred_all = torch.permute(mu_pred_all, (1, 0, 2)) 331 | mu_t_all = torch.permute(mu_t_all, (1, 0, 2)) 332 | mu_back_all = torch.permute(mu_back_all, (1, 0, 2)) 333 | 334 | Lambda_pred_all =
torch.permute(Lambda_pred_all, (1, 0, 2, 3)) 335 | Lambda_t_all = torch.permute(Lambda_t_all, (1, 0, 2, 3)) 336 | Lambda_back_all = torch.permute(Lambda_back_all, (1, 0, 2, 3)) 337 | 338 | return mu_pred_all, mu_t_all, mu_back_all, Lambda_pred_all, Lambda_t_all, Lambda_back_all 339 | 340 | 341 | def forward(self, a, mask=None, do_smoothing=False): 342 | ''' 343 | Forward pass function for LDM Module 344 | 345 | Parameters: 346 | ------------ 347 | - a: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors (outputs of encoder; nonlinear manifold embedding step) 348 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether 349 | observations at each timestep exists (1) or are missing (0) 350 | do_smoothing: bool, Whether to run RTS Smoothing or not 351 | 352 | Returns: 353 | ------------ 354 | - mu_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor predictions (t+1|t) where first index of the second dimension has x_{1|0} 355 | - mu_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor filtered estimates (t|t) where first index of the second dimension has x_{0|0} 356 | - mu_back_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor smoothed estimates (t|T) where first index of the second dimension has x_{0|T}. Ones tensor if do_smoothing is False 357 | - Lambda_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where first index of the second dimension has P_{1|0} 358 | - Lambda_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where first index of the second dimension has P_{0|0} 359 | - Lambda_back_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance smoothed estimates (t|T) where first index of the second dimension has P_{0|T}. Ones tensor if do_smoothing is False 360 | ''' 361 | 362 | if do_smoothing: 363 | mu_pred_all, mu_t_all, mu_back_all, Lambda_pred_all, Lambda_t_all, Lambda_back_all = self.smooth(a=a, mask=mask) 364 | else: 365 | mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all = self.filter(a=a, mask=mask) 366 | mu_back_all = torch.ones_like(mu_t_all, dtype=torch.float32, device=mu_t_all.device) 367 | Lambda_back_all = torch.ones_like(Lambda_t_all, dtype=torch.float32, device=Lambda_t_all.device) 368 | 369 | return mu_pred_all, mu_t_all, mu_back_all, Lambda_pred_all, Lambda_t_all, Lambda_back_all 370 | 371 | 372 | 373 | 374 | 375 | -------------------------------------------------------------------------------- /modules/MLP.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class MLP(nn.Module): 12 | ''' 13 | MLP Module for DFINE encoder and decoder in addition to the mapper to behavior for supervised DFINE. 14 | Encoder encodes the high-dimensional neural observations into low-dimensional manifold latent factors space 15 | and decoder decodes the manifold latent factors into high-dimensional neural observations. 
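As a rough usage sketch (all dimensions and layer sizes below are made up for illustration, and the repository root is assumed to be on the Python path), an encoder instance could be built and applied as follows:

```python
# Illustrative only: dim_y=30, dim_a=2 and the layer sizes are hypothetical.
import torch
import torch.nn as nn

from modules.MLP import MLP

encoder = MLP(input_dim=30, output_dim=2, layer_list=[64, 64],
              kernel_initializer_fn=nn.init.xavier_normal_,
              activation_fn=nn.Tanh())   # pass an activation instance, not the class

y_flat = torch.randn(8 * 100, 30)        # (num_seq * num_steps, dim_y)
a_flat = encoder(y_flat)                 # (num_seq * num_steps, dim_a)
print(a_flat.shape)                      # torch.Size([800, 2])
```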
16 | ''' 17 | 18 | def __init__(self, **kwargs): 19 | ''' 20 | Initializer for an Encoder/Decoder/Mapper object. Note that Encoder/Decoder/Mapper is a subclass of torch.nn.Module. 21 | 22 | Parameters 23 | ------------ 24 | input_dim: int, Dimensionality of inputs to the MLP, default None 25 | output_dim: int, Dimensionality of outputs of the MLP, default None 26 | layer_list: list, List of number of neurons in each hidden layer, default None 27 | kernel_initializer_fn: torch.nn.init, Hidden layer weight initialization function, default nn.init.xavier_normal_ 28 | activation_fn: torch.nn, Activation function of neurons (an instance such as nn.Tanh(), since it is called directly in forward), default nn.Tanh() 29 | ''' 30 | 31 | super(MLP, self).__init__() 32 | 33 | self.input_dim = kwargs.pop('input_dim', None) 34 | self.output_dim = kwargs.pop('output_dim', None) 35 | self.layer_list = kwargs.pop('layer_list', None) 36 | self.kernel_initializer_fn = kwargs.pop('kernel_initializer_fn', nn.init.xavier_normal_) 37 | self.activation_fn = kwargs.pop('activation_fn', nn.Tanh()) # default is an instance, not the class, so it can be applied in forward 38 | 39 | # Create the ModuleList to stack the hidden layers 40 | self.layers = nn.ModuleList() 41 | 42 | # Create the hidden layers and initialize their weights based on desired initialization function 43 | current_dim = self.input_dim 44 | for i, dim in enumerate(self.layer_list): 45 | self.layers.append(nn.Linear(current_dim, dim)) 46 | self.kernel_initializer_fn(self.layers[i].weight) 47 | current_dim = dim 48 | 49 | # Create the output layer and initialize its weights based on desired initialization function 50 | self.out_layer = nn.Linear(current_dim, self.output_dim) 51 | self.kernel_initializer_fn(self.out_layer.weight) 52 | 53 | 54 | def forward(self, inp): 55 | ''' 56 | Forward pass function for MLP Module 57 | 58 | Parameters: 59 | ------------ 60 | inp: torch.Tensor, shape: (num_seq * num_steps, input_dim), Flattened batch of inputs 61 | 62 | Returns: 63 | ------------ 64 | out: torch.Tensor, shape: (num_seq * num_steps, output_dim), Flattened batch of outputs 65 | ''' 66 | 67 | # Push neural observations through each hidden layer 68 | for layer in self.layers: 69 | inp = layer(inp) 70 | inp = self.activation_fn(inp) 71 | 72 | # Obtain the output 73 | out = self.out_layer(inp) 74 | return out -------------------------------------------------------------------------------- /nn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. 
Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | def compute_mse(y_flat, y_hat_flat, mask_flat=None): 13 | ''' 14 | Returns average Mean Square Error (MSE) 15 | 16 | Parameters: 17 | ------------ 18 | - y_flat: torch.Tensor, shape: (num_samp, dim_y), True data to compute MSE of 19 | - y_hat_flat: torch.Tensor, shape: (num_samp, dim_y), Predicted/Reconstructed data to compute MSE of 20 | - mask_flat: torch.Tensor, shape: (num_samp, 1), Mask to compute MSE loss which shows whether 21 | observations at each timestep exist (1) or are missing (0) 22 | 23 | Returns: 24 | ------------ 25 | - mse: torch.Tensor, Average MSE 26 | ''' 27 | 28 | if mask_flat is None: 29 | mask_flat = torch.ones(y_flat.shape[:-1], dtype=torch.float32, device=y_flat.device) # create the default mask on the same device as y_flat 30 | 31 | # Make sure mask is 2D 32 | if len(mask_flat.shape) != len(y_flat.shape): 33 | mask_flat = mask_flat.unsqueeze(dim=-1) 34 | 35 | # Compute the MSEs and mask the timesteps where observations are missing 36 | mse = (y_flat - y_hat_flat) ** 2 37 | mse = torch.mul(mask_flat, mse) 38 | 39 | # Return the mean of the mse (over available observations) 40 | if mask_flat.shape[-1] != y_flat.shape[-1]: # which means mask_flat has a singleton last dimension 41 | num_el = mask_flat.sum() * y_flat.shape[-1] 42 | else: 43 | num_el = mask_flat.sum() 44 | 45 | mse = mse.sum() / num_el 46 | return mse 47 | 48 | 49 | def get_activation_function(activation_str): 50 | ''' 51 | Returns activation function given the activation function's name 52 | 53 | Parameters: 54 | ---------------------- 55 | - activation_str: str, Activation function's name 56 | 57 | Returns: 58 | ---------------------- 59 | - activation_fn: torch.nn, Activation function 60 | ''' 61 | 62 | if activation_str.lower() == 'elu': 63 | return nn.ELU() 64 | elif activation_str.lower() == 'hardtanh': 65 | return nn.Hardtanh() 66 | elif activation_str.lower() == 'leakyrelu': 67 | return nn.LeakyReLU() 68 | elif activation_str.lower() == 'relu': 69 | return nn.ReLU() 70 | elif activation_str.lower() == 'rrelu': 71 | return nn.RReLU() 72 | elif activation_str.lower() == 'sigmoid': 73 | return nn.Sigmoid() 74 | elif activation_str.lower() == 'mish': 75 | return nn.Mish() 76 | elif activation_str.lower() == 'tanh': 77 | return nn.Tanh() 78 | elif activation_str.lower() == 'tanhshrink': 79 | return nn.Tanhshrink() 80 | elif activation_str.lower() == 'linear': 81 | return lambda x: x 82 | 83 | def get_kernel_initializer_function(kernel_initializer_str): 84 | ''' 85 | Returns kernel initialization function given the kernel initialization function's name 86 | 87 | Parameters: 88 | ---------------------- 89 | - kernel_initializer_str: str, Kernel initialization function's name 90 | 91 | Returns: 92 | ---------------------- 93 | - kernel_initializer_fn: torch.nn.init, Kernel initialization function 94 | ''' 95 | 96 | if kernel_initializer_str.lower() == 'uniform': 97 | return nn.init.uniform_ 98 | elif kernel_initializer_str.lower() == 'normal': 99 | return nn.init.normal_ 100 | elif kernel_initializer_str.lower() == 'xavier_uniform': 101 | return nn.init.xavier_uniform_ 102 | elif kernel_initializer_str.lower() == 'xavier_normal': 103 | return nn.init.xavier_normal_ 104 | elif kernel_initializer_str.lower() == 'kaiming_uniform': 105 | return nn.init.kaiming_uniform_ 106 | elif kernel_initializer_str.lower() == 'kaiming_normal': 107 | return nn.init.kaiming_normal_ 108 | elif kernel_initializer_str.lower() == 'orthogonal': 109 | return 
nn.init.orthogonal_ 110 | elif kernel_initializer_str.lower() == 'default': 111 | return lambda x:x -------------------------------------------------------------------------------- /python_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | import torch 9 | import numpy as np 10 | 11 | 12 | def carry_to_device(data, device, dtype=torch.float32): 13 | ''' 14 | Carries dict/list of torch Tensors/numpy arrays to desired device recursively 15 | 16 | Parameters: 17 | ------------ 18 | - data: torch.Tensor/np.ndarray/dict/list: Dictionary/list of torch Tensors/numpy arrays or torch Tensor/numpy array to be carried to desired device 19 | - device: str, Device name to carry the torch Tensors/numpy arrays to 20 | - dtype: torch.dtype, Data type for torch.Tensor to be returned, torch.float32 by default 21 | 22 | Returns: 23 | ------------ 24 | - data: torch.Tensor/dict/list: Dictionary/list of torch.Tensors or torch Tensor carried to desired device 25 | ''' 26 | 27 | if torch.is_tensor(data): 28 | return data.to(device) 29 | 30 | elif isinstance(data, np.ndarray): 31 | return torch.tensor(data, dtype=dtype).to(device) 32 | 33 | elif isinstance(data, dict): 34 | for key in data.keys(): 35 | data[key] = carry_to_device(data[key], device) 36 | return data 37 | 38 | elif isinstance(data, list): 39 | for i, d in enumerate(data): 40 | data[i] = carry_to_device(d, device) 41 | return data 42 | 43 | else: 44 | return data 45 | 46 | 47 | def convert_to_tensor(x, dtype=torch.float32): 48 | ''' 49 | Converts numpy.ndarray to torch.Tensor 50 | 51 | Parameters: 52 | ------------ 53 | - x: np.ndarray, Numpy array to convert to torch.Tensor (if it's of type torch.Tensor already, it's returned without conversion) 54 | - dtype: torch.dtype, Data type for torch.Tensor to be returned, torch.float32 by default 55 | 56 | Returns: 57 | ------------ 58 | - y: torch.Tensor, Converted tensor 59 | ''' 60 | 61 | if isinstance(x, torch.Tensor): 62 | y = x 63 | elif isinstance(x, np.ndarray): 64 | y = torch.tensor(x, dtype=dtype) # use np.ndarray as middle step so that function works with tf tensors as well 65 | else: 66 | assert False, 'Only Numpy array can be converted to tensor' 67 | return y 68 | 69 | 70 | def flatten_dict(dictionary, level=[]): 71 | ''' 72 | Flattens nested dictionary by putting '.' 
between nested keys, reference: https://stackoverflow.com/questions/6037503/python-unflatten-dict 73 | 74 | Parameters: 75 | ------------ 76 | - dictionary: dict, Nested dictionary to be flattened 77 | - level: list, List of strings for recursion, initialized by empty list 78 | 79 | Returns: 80 | ------------ 81 | - tmp_dict: dict, Flattened dictionary 82 | ''' 83 | 84 | tmp_dict = {} 85 | for key, val in dictionary.items(): 86 | if isinstance(val, dict): 87 | tmp_dict.update(flatten_dict(val, level + [key])) 88 | else: 89 | tmp_dict['.'.join(level + [key])] = val 90 | return tmp_dict 91 | 92 | 93 | def unflatten_dict(dictionary): 94 | ''' 95 | Unflattens a flattened dictionary whose keys are joint string of nested keys separated by '.', reference: https://stackoverflow.com/questions/6037503/python-unflatten-dict 96 | 97 | Parameters: 98 | ------------ 99 | - dictionary: dict, Flat dictionary to be unflattened 100 | 101 | Returns: 102 | ------------ 103 | - resultDict: dict, Unflattened dictionary 104 | ''' 105 | 106 | resultDict = dict() 107 | for key, value in dictionary.items(): 108 | parts = key.split(".") 109 | d = resultDict 110 | for part in parts[:-1]: 111 | if part not in d: 112 | d[part] = dict() 113 | d = d[part] 114 | d[parts[-1]] = value 115 | return resultDict 116 | 117 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dill==0.3.4 2 | matplotlib==3.5.1 3 | numpy==1.19.5 4 | scipy==1.5.4 5 | --find-links https://download.pytorch.org/whl/torch_stable.html 6 | torch==1.11.0+cu113 7 | torchmetrics==0.8.0 8 | tqdm==4.62.3 9 | yacs==0.1.6 10 | tensorboard 11 | scikit-learn # for the tutorial -------------------------------------------------------------------------------- /time_series_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | from python_utils import convert_to_tensor 9 | import torch 10 | from scipy.stats import pearsonr 11 | 12 | 13 | def get_nrmse_error(y, y_hat, version_calculation='modified'): 14 | ''' 15 | Computes normalized root-mean-squared error between two 3D tensors. Note that this operation is not symmetric. 16 | 17 | Parameters: 18 | ------------ 19 | - y: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y), Tensor with true observations 20 | - y_hat: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y), Tensor with reconstructed/estimated observations 21 | - version_calculation: str, Version to calculate the variance. If 'regular', variance of each sequence is computed separately, 22 | which may result in unstable nrmse value since some sequences may be constant or close to being constant and 23 | results in ~0 variance, so high/unreasonable nrmse. To prevent that, variance is computed across flattened sequence in 24 | 'modified' mode. 'modified' by default. 
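Concretely, with N = num_seq, T = num_steps and d indexing data dimensions, the 'modified' normalized error computed below is equivalently

$$\mathrm{NRMSE}_d = \frac{1}{N}\sum_{n=1}^{N} \frac{\sqrt{\frac{1}{T}\sum_{t=1}^{T}\left(y_{t,n,d}-\hat{y}_{t,n,d}\right)^{2}}}{\sqrt{\operatorname{Var}_d(y)}},$$

where $\operatorname{Var}_d(y)$ is the variance of dimension d over all flattened timesteps and sequences; the 'regular' mode instead computes the variance separately per sequence.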
25 | 26 | Returns: 27 | ------------ 28 | - normalized_error: torch.Tensor, shape: (dim_y,), Normalized root-mean-squared error for each data dimension 29 | - normalized_error_mean: torch.Tensor, shape: (), Normalized root-mean-squared error averaged across data dimensions 30 | ''' 31 | 32 | # Check if dimensions are consistent 33 | assert y.shape == y_hat.shape, f'dimensions of y {y.shape} and y_hat {y_hat.shape} do not match' 34 | assert len(y.shape) == 3, 'mismatch in y dimension: y should be in the format of (num_seq, num_steps, dim_y)' 35 | 36 | y = convert_to_tensor(y).detach().cpu() 37 | y_hat = convert_to_tensor(y_hat).detach().cpu() 38 | 39 | # carry time to first dimension 40 | y = torch.permute(y, (1,0,2)) # (num_steps, num_seq, dim_y) 41 | y_hat = torch.permute(y_hat, (1,0,2)) # (num_steps, num_seq, dim_y) 42 | 43 | recons_error = torch.mean(torch.square(y - y_hat), dim=0) 44 | 45 | # way 1 to calculate variance 46 | if version_calculation == 'regular': 47 | var_y = torch.mean(torch.square(y - torch.mean(y, dim=0)), dim=0) 48 | 49 | # way 2 to calculate variance (sometimes data in a batch is flat, so it's more robust to calculate variance globally) 50 | elif version_calculation == 'modified': 51 | y_resh = torch.reshape(y, (-1, y.shape[2])) 52 | var_y = torch.mean(torch.square(y_resh - torch.mean(y_resh, dim=0)), dim=0) 53 | var_y = torch.tile(var_y.unsqueeze(dim=0), (y.shape[1], 1)) 54 | normalized_error = torch.mean((torch.sqrt(recons_error) / torch.sqrt(var_y)), dim=0) # mean across sequences 55 | normalized_error_mean = torch.mean(normalized_error) 56 | 57 | return normalized_error, normalized_error_mean 58 | 59 | 60 | def get_rmse_error(y, y_hat): 61 | ''' 62 | Computes root-mean-squared error between two 3D tensors 63 | 64 | Parameters: 65 | ------------ 66 | - y: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y), Tensor with true observations 67 | - y_hat: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y), Tensor with reconstructed/estimated observations 68 | 69 | Returns: 70 | ------------ 71 | - rmse: torch.Tensor, shape: (dim_y,), Root-mean-squared error for each data dimension 72 | - rmse_mean: torch.Tensor, shape: (), Root-mean-squared error averaged across data dimensions 73 | ''' 74 | 75 | # Check if dimensions are consistent 76 | assert y.shape == y_hat.shape, f'dimensions of y {y.shape} and y_hat {y_hat.shape} do not match' 77 | 78 | if len(y.shape) == 3: 79 | dim_y = y.shape[-1] 80 | y = y.reshape(-1, dim_y) 81 | y_hat = y_hat.reshape(-1, dim_y) 82 | 83 | y = convert_to_tensor(y).detach().cpu() 84 | y_hat = convert_to_tensor(y_hat).detach().cpu() 85 | 86 | rmse = torch.sqrt(torch.mean(torch.square(y-y_hat), dim=0)) 87 | rmse_mean = torch.nanmean(rmse.nan_to_num(posinf=torch.nan, neginf=torch.nan)) # for stability purposes 88 | return rmse, rmse_mean 89 | 90 | 91 | def get_pearson_cc(y, y_hat): 92 | ''' 93 | Computes the Pearson correlation coefficient between two 2D tensors along the first (time) dimension. If 3D tensors are given, they are first reshaped by flattening the sequence and time dimensions. 
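For a quick sanity check of this function (shapes below are arbitrary), note that any positive affine transform of y is perfectly correlated with y:

```python
# Sanity-check sketch: an affine transform preserves Pearson correlation.
import torch
from time_series_utils import get_pearson_cc

y = torch.randn(4, 50, 3)                     # (num_seq, num_steps, dim_y)
ccs, ccs_mean = get_pearson_cc(y, 2 * y + 1)  # per-dimension CCs and their mean
print(ccs)                                    # ~tensor([1., 1., 1.])
```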
94 | 95 | Parameters: 96 | ------------ 97 | - y: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y) or (num_steps, dim_y), Tensor with true observations 98 | - y_hat: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y) or (num_steps, dim_y), Tensor with reconstructed/estimated observations 99 | 100 | Returns: 101 | ------------ 102 | - ccs: torch.Tensor, shape: (dim_y,), Pearson correlation coefficients computed across first (time) dimension 103 | - ccs_mean: torch.Tensor, shape: (), Pearson correlation coefficients computed across first (time) dimension and averaged across data dimensions 104 | ''' 105 | 106 | assert y.shape == y_hat.shape, f'dimensions of y {y.shape} and y_hat {y_hat.shape} do not match' 107 | 108 | if len(y.shape) == 3: 109 | dim_y = y.shape[-1] 110 | y = y.reshape(-1, dim_y) 111 | y_hat = y_hat.reshape(-1, dim_y) 112 | 113 | y = convert_to_tensor(y).detach().cpu().numpy() # make sure every array/tensor has .numpy() function, pearsonr works on ndarrays 114 | y_hat = convert_to_tensor(y_hat).detach().cpu().numpy() 115 | 116 | ccs = [] 117 | for dim in range(y.shape[-1]): 118 | cc, _ = pearsonr(y[:, dim], y_hat[:, dim]) 119 | ccs.append(cc) 120 | 121 | ccs = torch.tensor(ccs, dtype=torch.float32) 122 | ccs_mean = torch.nanmean(ccs.nan_to_num(posinf=torch.nan, neginf=torch.nan)) 123 | return ccs, ccs_mean 124 | 125 | 126 | 127 | def z_score_tensor(y, fit=True, **kwargs): 128 | ''' 129 | Performs z-score fitting and transformation. 130 | 131 | Parameters: 132 | ------------ 133 | - y: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y) or (num_steps, dim_y), Tensor/array to z-score 134 | (and if fit is True, to learn mean and standard deviation) 135 | - fit: bool, Whether to learn mean and standard deviation from y. If False, learnt 'mean' and 'std' should be provided as keyword arguments. 136 | - mean: torch.Tensor, shape: (dim_y,), Mean to transform y. If fit is True, it's not necessary to provide since mean is going to be learnt. Zeros by default. 137 | - std: torch.Tensor, shape: (dim_y,), Standard deviation to transform y. If fit is True, it's not necessary to provide since std is going to be learnt. Ones by default. 138 | 139 | Returns: 140 | ------------ 141 | y_z_scored: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y) or (num_steps, dim_y), Z-scored tensor/array 142 | mean: torch.Tensor, shape: (dim_y,), Learnt mean if fit is True; otherwise the mean provided via keyword, or the default 143 | std: torch.Tensor, shape: (dim_y,), Learnt standard deviation if fit is True; otherwise the std provided via keyword, or the default 144 | ''' 145 | 146 | # Make sure that gradients are turned off 147 | with torch.no_grad(): 148 | y = convert_to_tensor(y) 149 | 150 | y_resh = y.reshape(-1, y.shape[-1]) 151 | if fit: 152 | mean = torch.mean(y_resh, dim=0) 153 | std = torch.std(y_resh, dim=0) 154 | else: 155 | mean = kwargs.pop('mean', torch.zeros(y_resh.shape[-1])) 156 | std = kwargs.pop('std', torch.ones(y_resh.shape[-1])) 157 | 158 | # to prevent nan values (std must be a tensor for this in-place guard) 159 | std[std==0] = 1 160 | 161 | y_resh = (y_resh - mean) / std 162 | y_z_scored = y_resh.reshape(y.shape) 163 | return y_z_scored, mean, std 164 | 165 | -------------------------------------------------------------------------------- /trainers/BaseTrainer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. 
Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | import os 9 | import torch 10 | import datetime 11 | import logging 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | 15 | class BaseTrainer: 16 | ''' 17 | Base trainer class which is overwritten by TrainerDFINE. 18 | ''' 19 | 20 | def __init__(self, config): 21 | ''' 22 | Initializer of BaseTrainer. 23 | 24 | Parameters: 25 | ------------ 26 | - config: yacs.config.CfgNode, yacs config which contains all hyperparameters required to create the DFINE model 27 | Please see config_dfine.py for the hyperparameters, their default values and definitions. 28 | ''' 29 | 30 | self.config = config 31 | 32 | # Checkpoint and plot save directories, create directories if they don't exist 33 | self.ckpt_save_dir = os.path.join(self.config.model.save_dir, 'ckpts'); os.makedirs(self.ckpt_save_dir, exist_ok=True) 34 | self.plot_save_dir = os.path.join(self.config.model.save_dir, 'plots'); os.makedirs(self.plot_save_dir, exist_ok=True) 35 | 36 | # Training can be continued from where it was left off; by default, the training start epoch is 1 37 | self.start_epoch = 1 38 | 39 | # Tensorboard summary writer 40 | self.writer = SummaryWriter(log_dir=os.path.join(self.config.model.save_dir, 'summary')) 41 | 42 | 43 | def _save_config(self, config_save_name='config.yaml'): 44 | ''' 45 | Saves the config inside config.model.save_dir 46 | 47 | Parameters: 48 | ------------ 49 | - config_save_name: str, Full file name to save the config, 'config.yaml' by default 50 | ''' 51 | 52 | config_save_path = os.path.join(self.config.model.save_dir, config_save_name) 53 | with open(config_save_path, 'w') as outfile: 54 | outfile.write(self.config.dump()) 55 | 56 | 57 | def _get_optimizer(self, params): 58 | ''' 59 | Creates the Adam optimizer with initial learning rate and epsilon specified inside config by config.lr.init and config.optim.eps, respectively 60 | 61 | Parameters: 62 | ------------ 63 | - params: Parameters to be optimized by the optimizer 64 | 65 | Returns: 66 | ------------ 67 | - optimizer: Adam optimizer with desired learning rate, epsilon to optimize parameters specified by params 68 | ''' 69 | 70 | optimizer = torch.optim.Adam(params=params, 71 | lr=self.config.lr.init, 72 | eps=self.config.optim.eps) 73 | return optimizer 74 | 75 | 76 | def _get_lr_scheduler(self): 77 | ''' 78 | Creates the learning rate scheduler based on the scheduler type specified in config.lr.scheduler. Options are StepLR (explr), CyclicLR (cyclic) and LambdaLR (used as constantlr). 79 | ''' 80 | 81 | if self.config.lr.scheduler.lower() == 'explr': 82 | scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, gamma=self.config.lr.explr.gamma, step_size=self.config.lr.explr.step_size) 83 | elif self.config.lr.scheduler.lower() == 'cyclic': 84 | scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=self.config.lr.cyclic.base_lr, max_lr=self.config.lr.cyclic.max_lr, mode=self.config.lr.cyclic.mode, gamma=self.config.lr.cyclic.gamma, step_size_up=self.config.lr.cyclic.step_size_up, cycle_momentum=False) 85 | elif self.config.lr.scheduler.lower() == 'constantlr': 86 | scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1) 87 | else: 88 | assert False, 'Only these learning rate schedulers are available: StepLR (explr), CyclicLR (cyclic) and LambdaLR (which is constantlr)!' 
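# For reference, a compact map of the config fields each branch above reads
# (field names exactly as used above; no new options are implied):
#   config.lr.scheduler == 'explr'      -> StepLR(gamma=config.lr.explr.gamma, step_size=config.lr.explr.step_size)
#   config.lr.scheduler == 'cyclic'     -> CyclicLR(base_lr=config.lr.cyclic.base_lr, max_lr=config.lr.cyclic.max_lr,
#                                                   mode=config.lr.cyclic.mode, gamma=config.lr.cyclic.gamma,
#                                                   step_size_up=config.lr.cyclic.step_size_up, cycle_momentum=False)
#   config.lr.scheduler == 'constantlr' -> LambdaLR(lr_lambda=lambda x: 1), i.e. a constant learning rate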
89 | return scheduler 90 | 91 | 92 | def _get_metrics(self): 93 | ''' 94 | Empty function; the overwritten function must return metric names as a list and metrics as a nested dictionary. Keys are: 95 | - train: dict, Training Mean metrics 96 | - valid: dict, Validation Mean metrics 97 | ''' 98 | pass 99 | 100 | 101 | def _reset_metrics(self, train_valid='train'): 102 | ''' 103 | Resets the metrics 104 | 105 | Parameters: 106 | ------------ 107 | - train_valid: str, Which metrics to reset, 'train' by default 108 | ''' 109 | 110 | for _, metric in self.metrics[train_valid].items(): 111 | metric.reset() 112 | 113 | 114 | def _update_metrics(self, loss_dict, batch_size, train_valid='train', verbose=True): 115 | ''' 116 | Updates the metrics 117 | 118 | Parameters: 119 | ------------ 120 | - loss_dict: dict, Dictionary with loss values to log in Tensorboard 121 | - batch_size: int, Number of trials for which the metrics are computed 122 | - train_valid: str, Which metrics to update, 'train' by default 123 | - verbose: bool, Whether to print the warning if a key in metric_names doesn't exist in loss_dict 124 | ''' 125 | 126 | for key in self.metric_names: 127 | if key not in loss_dict: 128 | if verbose: 129 | self.logger.warning(f'{key} does not exist in loss_dict, metric cannot be updated!') 130 | else: 131 | pass 132 | else: 133 | self.metrics[train_valid][key].update(loss_dict[key], batch_size) 134 | 135 | 136 | def _get_logger(self, prefix='dfine'): 137 | ''' 138 | Creates the logger which is saved as .log file under config.model.save_dir 139 | 140 | Parameters: 141 | ------------ 142 | - prefix: str, Prefix which is used as logger's name and .log file's name, 'dfine' by default 143 | 144 | Returns: 145 | ------------ 146 | - logger: logging.Logger, Logger object to write logs into .log file 147 | ''' 148 | 149 | os.makedirs(self.config.model.save_dir, exist_ok=True) 150 | date_time = datetime.datetime.now().strftime("%m-%d_%H-%M") 151 | log_path = os.path.join(self.config.model.save_dir, f'{prefix}_{date_time}.log') 152 | 153 | # from: https://stackoverflow.com/a/56689445/16228104 154 | logger = logging.getLogger(f'{prefix.upper()} Logger') 155 | logger.setLevel(logging.DEBUG) 156 | 157 | # Remove old handlers from logger (since logger is a static object) so that repeated calls don't keep writing to previous log files 158 | handlers = logger.handlers[:] 159 | for handler in handlers: 160 | logger.removeHandler(handler) 161 | handler.close() 162 | 163 | # Create file handler which logs even debug messages 164 | fh = logging.FileHandler(log_path, mode='w') 165 | fh.setLevel(logging.DEBUG) 166 | 167 | # Create console handler with a higher log level 168 | ch = logging.StreamHandler() 169 | ch.setLevel(logging.DEBUG) 170 | 171 | # Create formatter and add it to the handlers 172 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', '%m/%d/%Y %I:%M:%S %p') 173 | ch.setFormatter(formatter) 174 | fh.setFormatter(formatter) 175 | 176 | # Add the handlers to logger 177 | logger.addHandler(ch) 178 | logger.addHandler(fh) 179 | 180 | return logger 181 | 182 | 183 | def _load_ckpt(self, model, optimizer, lr_scheduler=None): 184 | ''' 185 | Loads the checkpoint whose number is specified in the config by config.load.ckpt 186 | 187 | Parameters: 188 | ------------ 189 | - model: torch.nn.Module, Initialized DFINE model to load the parameters to 190 | - optimizer: torch.optim.Adam, Initialized Adam optimizer to load optimizer parameters to (loading is skipped if 
config.load.resume_train is False) 191 | - lr_scheduler: torch.optim.lr_scheduler, Initialized learning rate scheduler to load learning rate scheduler parameters to, None by default (loading is skipped if config.load.resume_train is False) 192 | 193 | Returns: 194 | ------------ 195 | - model: torch.nn.Module, Loaded DFINE model 196 | - optimizer: torch.optim.Adam, Loaded Adam optimizer (if config.load.resume_train is True, otherwise, initialized optimizer is returned) 197 | - lr_scheduler: torch.optim.lr_scheduler, Loaded learning rate scheduler (if config.load.resume_train is True, otherwise, initialized learning rate scheduler is returned) 198 | ''' 199 | 200 | self.logger.warning('Optimizer and LR scheduler can be loaded only in resume_train mode, else they are re-initialized') 201 | load_path = os.path.join(self.config.model.save_dir, 'ckpts', f'{self.config.load.ckpt}_ckpt.pth') 202 | self.logger.info(f'Loading model from: {load_path}...') 203 | 204 | # Load the checkpoint 205 | try: 206 | ckpt = torch.load(load_path) 207 | except Exception: 208 | self.logger.error('Ckpt path does not exist!') 209 | assert False, '' 210 | 211 | # If config.load.resume_train is True, load optimizer and learning rate scheduler 212 | if self.config.load.resume_train: 213 | self.start_epoch = ckpt['epoch'] + 1 if isinstance(ckpt['epoch'], int) else 1 214 | try: 215 | optimizer.load_state_dict(ckpt['optimizer']) 216 | except Exception: 217 | self.logger.error('Optimizer cannot be loaded, check if optimizer type is consistent!') 218 | assert False, '' 219 | 220 | if lr_scheduler is not None: 221 | try: 222 | lr_scheduler.load_state_dict(ckpt['lr_scheduler']) 223 | except Exception: 224 | self.logger.error('LR scheduler cannot be loaded, check if scheduler type is consistent!') 225 | assert False, '' 226 | 227 | try: 228 | model.load_state_dict(ckpt['state_dict']) 229 | except Exception: 230 | self.logger.error('Given architecture in config does not match the architecture of given checkpoint!') 231 | assert False, '' 232 | 233 | self.logger.info(f'Checkpoint successfully loaded from {load_path}!') 234 | return model, optimizer, lr_scheduler 235 | 236 | 237 | def write_model_gradients(self, model, step, prefix='unclipped'): 238 | ''' 239 | Logs the gradient norms to Tensorboard 240 | 241 | Parameters: 242 | ------------ 243 | - model: torch.nn.Module, DFINE model whose gradients are to be logged 244 | - step: int, Step to log gradients for 245 | - prefix: str, Prefix for gradient norms to be logged, it can be 'clipped' or 'unclipped', 'unclipped' by default 246 | ''' 247 | 248 | total_norm = 0 249 | for name, p in model.named_parameters(): 250 | if p.grad is not None: 251 | grad_norm = p.grad.detach().data.cpu().norm(2) 252 | total_norm += grad_norm ** 2 253 | self.writer.add_scalar('grads/' + name + f"/{prefix}_grad", grad_norm, step) 254 | 255 | total_norm = total_norm ** 0.5 256 | self.writer.add_scalar(f'grads/total_grad_{prefix}_norm', total_norm, step) 257 | 258 | 259 | def _save_ckpt(self, epoch, model, optimizer, lr_scheduler=None): 260 | ''' 261 | Saves the checkpoint under ckpt_save_dir (see __init__) with filename {epoch}_ckpt.pth 262 | 263 | Parameters: 264 | ------------ 265 | - epoch: int, Epoch number for which the checkpoint is to be saved for 266 | - model: torch.nn.Module, DFINE model to be saved 267 | - optimizer: torch.optim.Adam, Adam optimizer to be saved 268 | - lr_scheduler: torch.optim.lr_scheduler, Learning rate scheduler to be saved 269 | ''' 270 | 271 | save_path = os.path.join(self.ckpt_save_dir, f'{epoch}_ckpt.pth') 
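# The checkpoint saved below is a plain dict, so it can also be inspected or
# loaded manually outside the trainer, e.g.:
#   ckpt = torch.load(save_path)
#   ckpt.keys()  # 'state_dict', 'optimizer', 'epoch' (+ 'lr_scheduler' when one is passed)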
272 | if lr_scheduler is not None: 273 | torch.save({ 274 | 'state_dict': model.state_dict(), 275 | 'optimizer': optimizer.state_dict(), 276 | 'lr_scheduler': lr_scheduler.state_dict(), 277 | 'epoch': epoch 278 | }, save_path) 279 | else: 280 | torch.save({ 281 | 'state_dict': model.state_dict(), 282 | 'optimizer': optimizer.state_dict(), 283 | 'epoch': epoch 284 | }, save_path) 285 | 286 | 287 | def train_epoch(self, epoch, train_loader): 288 | ''' 289 | Empty function, overwritten function performs single epoch training 290 | 291 | Parameters: 292 | ------------ 293 | - epoch: int, Epoch number for which the training iterations are performed 294 | - train_loader: torch.utils.data.DataLoader, Training dataloader 295 | ''' 296 | 297 | pass 298 | 299 | 300 | def valid_epoch(self, epoch, valid_loader): 301 | ''' 302 | Empty function, overwritten function performs single epoch validation 303 | 304 | Parameters: 305 | ------------ 306 | - epoch: int, Epoch number for which the validation iterations are performed 307 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader 308 | ''' 309 | 310 | pass 311 | 312 | 313 | def train(self, train_loader, valid_loader): 314 | ''' 315 | Empty function, overwritten function performs DFINE training for number of epochs specified in config.train.num_epochs 316 | 317 | Parameters: 318 | ------------ 319 | - train_loader: torch.utils.data.DataLoader, Training dataloader 320 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader 321 | ''' 322 | 323 | pass -------------------------------------------------------------------------------- /trainers/TrainerDFINE.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2023 University of Southern California 3 | See full notice in LICENSE.md 4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi 5 | Shanechi Lab, University of Southern California 6 | ''' 7 | 8 | import os 9 | import torch 10 | from torch.nn.utils import clip_grad_norm_ 11 | from tqdm import tqdm 12 | import numpy as np 13 | import timeit 14 | import matplotlib 15 | import matplotlib.pyplot as plt 16 | # matplotlib.use('Agg') # It disables the interactive GUI backend. Commented out by default, but in case of config.train.plot_save_steps << config.train.num_epochs, please uncomment this line (otherwise matplotlib may throw an error). 17 | 18 | from trainers.BaseTrainer import BaseTrainer 19 | from DFINE import DFINE 20 | from python_utils import carry_to_device 21 | from time_series_utils import get_nrmse_error 22 | from metrics import Mean 23 | 24 | torch.set_printoptions(precision=3) 25 | np.set_printoptions(precision=3) 26 | 27 | 28 | class TrainerDFINE(BaseTrainer): 29 | ''' 30 | Trainer class for DFINE model. 31 | ''' 32 | 33 | def __init__(self, config): 34 | ''' 35 | Initializer for a TrainerDFINE object. Note that TrainerDFINE is a subclass of trainers.BaseTrainer. 36 | 37 | Parameters 38 | ------------ 39 | - config: yacs.config.CfgNode, yacs config which contains all hyperparameters required to create and train the DFINE model 40 | Please see config_dfine.py for the hyperparameters, their default values and definitions. 
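For orientation, a minimal end-to-end sketch follows; the config accessor name and all data shapes below are assumptions for illustration only (tutorial.ipynb shows the canonical flow):

```python
# Hypothetical quick-start: synthetic data, assumed get_default_config() helper.
import torch
from torch.utils.data import DataLoader, TensorDataset

from config_dfine import get_default_config   # assumed accessor in config_dfine.py
from trainers.TrainerDFINE import TrainerDFINE

config = get_default_config()
config.model.save_dir = './results/example'

# Loaders yield (y, behv, mask) triples, matching the unpacking in train_epoch
y = torch.randn(32, 100, config.model.dim_y)  # (num_seq, num_steps, dim_y)
behv = torch.zeros(32, 100, 1)                # placeholder; unused unless supervise_behv
mask = torch.ones(32, 100, 1)                 # all observations present
loader = DataLoader(TensorDataset(y, behv, mask), batch_size=8)

trainer = TrainerDFINE(config)
trainer.train(loader, valid_loader=loader)
```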
41 | ''' 42 | 43 | super(TrainerDFINE, self).__init__(config) 44 | 45 | # Initialize training time statistics 46 | self.training_time = 0 47 | self.training_time_epochs = [] 48 | 49 | # Initialize best validation losses 50 | self.best_val_loss = torch.inf 51 | if self.config.model.supervise_behv: 52 | self.best_val_behv_loss = torch.inf 53 | 54 | # Initialize logger 55 | self.logger = self._get_logger(prefix='dfine') 56 | 57 | # Set device 58 | self.device = 'cpu' if self.config.device == 'cpu' or not torch.cuda.is_available() else 'cuda:0' 59 | self.config.device = self.device # if cuda is asked in config but it's not available, config is also updated 60 | 61 | # Initialize the model, optimizer and learning rate scheduler 62 | self.dfine = DFINE(self.config); 63 | self.dfine.to(self.device) # carry the model to the desired device 64 | self.optimizer = self._get_optimizer(params=self.dfine.parameters()) 65 | self.lr_scheduler = self._get_lr_scheduler() 66 | 67 | # Load ckpt if asked, model with best validation model loss can be loaded as well, which is saved with name 'best_loss_ckpt.pth' 68 | if (isinstance(self.config.load.ckpt, int) and self.config.load.ckpt > 1) or isinstance(self.config.load.ckpt, str): 69 | self.dfine, self.optimizer, self.lr_scheduler = self._load_ckpt(model=self.dfine, 70 | optimizer=self.optimizer, 71 | lr_scheduler=self.lr_scheduler) 72 | 73 | # Get the metrics 74 | self.metric_names, self.metrics = self._get_metrics() 75 | 76 | # Save the config 77 | self._save_config() 78 | 79 | 80 | def _get_metrics(self): 81 | ''' 82 | Creates the metric names and nested metrics dictionary. 83 | 84 | Returns: 85 | ------------ 86 | - metric_names: list, Metric names to log in Tensorboard, which are the keys of train/valid defined below 87 | - metrics_dictionary: dict, nested metrics dictionary. Keys (and metric_names) are (e.g. for config.loss.steps_ahead = [1,2]): 88 | - train: 89 | - steps_{k}_mse: metrics.Mean, Training {k}-step ahead predicted MSE 90 | - model_loss: metrics.Mean, Training negative sum of {k}-step ahead predicted MSEs (e.g. steps_1_mse + steps_2_mse) 91 | - reg_loss: metrics.Mean, L2 regularization loss for DFINE encoder and decoder weights 92 | - behv_mse: metrics.Mean, Exists if config.model.supervise_behv is True, Training behavior MSE 93 | - behv_loss: metrics.Mean, Exists if config.model.supervise_behv is True, Training behavior reconstruction loss 94 | - total_loss: metrics.Mean, Sum of training model_loss, reg_loss and behv_loss (if config.model.supervise_behv is True) 95 | - valid: 96 | - steps_{k}_mse: metrics.Mean, Validation {k}-step ahead predicted MSE 97 | - model_loss: metrics.Mean, Validation negative sum of {k}-step ahead predicted MSEs (e.g. 
steps_1_mse + steps_2_mse) 98 | - reg_loss: metrics.Mean, L2 regularization loss for DFINE encoder and decoder weights 99 | - behv_mse: metrics.Mean, Exists if config.model.supervise_behv is True, Validation behavior MSE 100 | - behv_loss: metrics.Mean, Exists if config.model.supervise_behv is True, Validation behavior reconstruction loss 101 | - total_loss: metrics.Mean, Sum of validation model_loss, reg_loss and behv_loss (if config.model.supervise_behv is True) 102 | ''' 103 | 104 | metric_names = [] 105 | for k in self.config.loss.steps_ahead: 106 | metric_names.append(f'steps_{k}_mse') 107 | 108 | if self.config.model.supervise_behv: 109 | metric_names.append('behv_mse') 110 | metric_names.append('behv_loss') 111 | metric_names.append('model_loss') 112 | metric_names.append('reg_loss') 113 | metric_names.append('total_loss') 114 | 115 | metrics = {} 116 | metrics['train'] = {} 117 | metrics['valid'] = {} 118 | 119 | for key in metric_names: 120 | metrics['train'][key] = Mean() 121 | metrics['valid'][key] = Mean() 122 | 123 | return metric_names, metrics 124 | 125 | 126 | def _get_log_str(self, epoch, train_valid='train'): 127 | ''' 128 | Creates the logging/printing string of training/validation statistics at each epoch 129 | 130 | Parameters: 131 | ------------ 132 | - epoch: int, Number of epoch to log the statistics for 133 | - train_valid: str, Training or validation prefix to log the statistics, 'train' by default 134 | 135 | Returns: 136 | ------------ 137 | - log_str: str, Logging string 138 | ''' 139 | 140 | log_str = f'Epoch {epoch}, {train_valid.upper()}\n' 141 | 142 | # Logging k-step ahead predicted MSEs 143 | for k in self.config.loss.steps_ahead: 144 | if k == 1: 145 | log_str += f"{k}_step_mse: {self.metrics[train_valid][f'steps_{k}_mse'].compute():.5f}\n" 146 | else: 147 | log_str += f"{k}_steps_mse: {self.metrics[train_valid][f'steps_{k}_mse'].compute():.5f}\n" 148 | 149 | # Logging L2 regularization loss and L2 scale 150 | log_str += f"reg_loss: {self.metrics[train_valid]['reg_loss'].compute():.5f}, scale_l2: {self.dfine.scale_l2:.5f}\n" 151 | 152 | # If model is behavior-supervised, log behavior reconstruction loss 153 | if self.config.model.supervise_behv: 154 | log_str += f"behv_loss: {self.metrics[train_valid]['behv_loss'].compute():.5f}, scale_behv_recons: {self.dfine.scale_behv_recons:.5f}\n" 155 | 156 | # Finally, log model_loss and total_loss to optimize 157 | log_str += f"model_loss: {self.metrics[train_valid]['model_loss'].compute():.5f}, total_loss: {self.metrics[train_valid]['total_loss'].compute():.5f}\n" 158 | return log_str 159 | 160 | 161 | def train_epoch(self, epoch, train_loader, verbose=True): 162 | ''' 163 | Performs single epoch training over batches, logging to Tensorboard and plot generation 164 | 165 | Parameters: 166 | ------------ 167 | - epoch: int, Number of epoch to perform training iteration 168 | - train_loader: torch.utils.data.DataLoader, Training dataloader 169 | ''' 170 | 171 | # Take the model into training mode 172 | self.dfine.train() 173 | 174 | # Reset the metrics at the beginning of each epoch 175 | self._reset_metrics(train_valid='train') 176 | 177 | # Keep track of update step for logging the gradient norms 178 | step = (epoch - 1) * len(train_loader) + 1 179 | 180 | # Keep the time which training epoch starts 181 | start_time = timeit.default_timer() 182 | 183 | # Start iterating over batches 184 | with tqdm(train_loader, unit='batch') as tepoch: 185 | for _, batch in enumerate(tepoch): 186 | tepoch.set_description(f"Epoch 
{epoch}, TRAIN") 187 | 188 | # Carry data to device 189 | batch = carry_to_device(data=batch, device=self.device) 190 | y_batch, behv_batch, mask_batch = batch 191 | 192 | # Perform forward pass and compute loss 193 | model_vars = self.dfine(y=y_batch, mask=mask_batch) 194 | loss, loss_dict = self.dfine.compute_loss(y=y_batch, 195 | model_vars=model_vars, 196 | mask=mask_batch, 197 | behv=behv_batch) 198 | 199 | # Compute model gradients 200 | self.optimizer.zero_grad() 201 | loss.backward() 202 | 203 | # Log UNCLIPPED model gradients after gradient computations 204 | self.write_model_gradients(self.dfine, step=step, prefix='unclipped') 205 | 206 | # Skip gradient clipping for the first epoch 207 | if epoch > 1: 208 | clip_grad_norm_(self.dfine.parameters(), self.config.optim.grad_clip) 209 | 210 | # Log CLIPPED model gradients after gradient computations 211 | self.write_model_gradients(model=self.dfine, step=step, prefix='clipped') 212 | 213 | # Update model parameters 214 | self.optimizer.step() 215 | 216 | # Update metrics 217 | self._update_metrics(loss_dict=loss_dict, 218 | batch_size=y_batch.shape[0], 219 | train_valid='train', 220 | verbose=False) 221 | 222 | # Update the step 223 | step += 1 224 | 225 | # Get the runtime for the training epoch 226 | epoch_time = timeit.default_timer() - start_time 227 | self.training_time += epoch_time 228 | self.training_time_epochs.append(epoch_time) 229 | 230 | # Save model, optimizer and learning rate scheduler (we save the initial and the last model no matter what config.model.save_steps is) 231 | if epoch % self.config.model.save_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs: 232 | self._save_ckpt(epoch=epoch, 233 | model=self.dfine, 234 | optimizer=self.optimizer, 235 | lr_scheduler=self.lr_scheduler) 236 | 237 | # Write model summary 238 | self.write_summary(epoch, prefix='train') 239 | 240 | # Create and save plots from the last batch 241 | if epoch % self.config.train.plot_save_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs: 242 | self.create_plots(y_batch=y_batch, 243 | behv_batch=behv_batch, 244 | model_vars=model_vars, 245 | epoch=epoch, 246 | prefix='train') 247 | 248 | # Logging the training step information for last batch 249 | if verbose and (epoch % self.config.train.print_log_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs): 250 | log_str = self._get_log_str(epoch=epoch, train_valid='train') 251 | self.logger.info(log_str) 252 | 253 | # Update LR 254 | self.lr_scheduler.step() 255 | 256 | 257 | def valid_epoch(self, epoch, valid_loader, verbose=True): 258 | ''' 259 | Performs single epoch validation over batches, logging to Tensorboard and plot generation 260 | 261 | Parameters: 262 | ------------ 263 | - epoch: int, Number of epoch to perform validation 264 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader 265 | ''' 266 | 267 | with torch.no_grad(): 268 | # Take the model into evaluation mode 269 | self.dfine.eval() 270 | 271 | # Reset metrics at the beginning of each epoch 272 | self._reset_metrics(train_valid='valid') 273 | 274 | # Start iterating over the batches 275 | y_all, mask_all = [], [] 276 | with tqdm(valid_loader, unit='batch') as tepoch: 277 | for _, batch in enumerate(tepoch): 278 | tepoch.set_description(f"Epoch {epoch}, VALID") 279 | 280 | # Carry data to device 281 | batch = carry_to_device(data=batch, device=self.device) 282 | y_batch, behv_batch, mask_batch = batch 283 | y_all.append(y_batch) 284 | mask_all.append(mask_batch) 285 | 
286 | # Perform forward pass and compute loss 287 | model_vars = self.dfine(y=y_batch, mask=mask_batch) 288 | _, loss_dict = self.dfine.compute_loss(y=y_batch, 289 | model_vars=model_vars, 290 | mask=mask_batch, 291 | behv=behv_batch) 292 | 293 | # Update metrics 294 | self._update_metrics(loss_dict=loss_dict, 295 | batch_size=y_batch.shape[0], 296 | train_valid='valid', 297 | verbose=False) 298 | 299 | # Perform one-step-ahead prediction on the provided validation data, for evaluation 300 | y_all = torch.cat(y_all, dim=0) 301 | mask_all = torch.cat(mask_all, dim=0) 302 | model_vars_all = self.dfine(y=y_all, mask=mask_all) 303 | y_pred_all = model_vars_all['y_pred'] 304 | _, one_step_ahead_nrmse = get_nrmse_error(y_all[:, 1:, :], y_pred_all) 305 | self.training_valid_one_step_nrmses.append(one_step_ahead_nrmse) 306 | 307 | # Write model summary 308 | self.write_summary(epoch, prefix='valid') 309 | 310 | # Save the best validation loss model (and best behavior reconstruction loss model if supervised) 311 | if self.metrics['valid']['model_loss'].compute() < self.best_val_loss: 312 | self.best_val_loss = self.metrics['valid']['model_loss'].compute() 313 | self._save_ckpt(epoch='best_loss', 314 | model=self.dfine, 315 | optimizer=self.optimizer, 316 | lr_scheduler=self.lr_scheduler) 317 | 318 | if self.config.model.supervise_behv: 319 | if self.metrics['valid']['behv_loss'].compute() < self.best_val_behv_loss: 320 | self.best_val_behv_loss = self.metrics['valid']['behv_loss'].compute() 321 | self._save_ckpt(epoch='best_behv_loss', 322 | model=self.dfine, 323 | optimizer=self.optimizer, 324 | lr_scheduler=self.lr_scheduler) 325 | 326 | # Create and save plots from last batch 327 | if epoch % self.config.train.plot_save_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs: 328 | self.create_plots(y_batch=y_batch, 329 | behv_batch=behv_batch, 330 | model_vars=model_vars, 331 | epoch=epoch, 332 | prefix='valid') 333 | 334 | if verbose and (epoch % self.config.train.print_log_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs): 335 | # Logging the validation step information for last batch 336 | log_str = self._get_log_str(epoch=epoch, train_valid='valid') 337 | self.logger.info(log_str) 338 | 339 | 340 | def train(self, train_loader, valid_loader=None): 341 | ''' 342 | Performs full training of DFINE model 343 | 344 | Parameters: 345 | ------------ 346 | - train_loader: torch.utils.data.DataLoader, Training dataloader 347 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader, None by default (if no valid_loader is provided, validation is skipped) 348 | ''' 349 | 350 | # Bookkeeping the validation NRMSEs over the course of training 351 | self.training_valid_one_step_nrmses = [] 352 | 353 | # Start iterating over the epochs 354 | for epoch in range(self.start_epoch, self.config.train.num_epochs + 1): 355 | # Perform validation with the initialized model (skipped if no valid_loader is provided) 356 | if epoch == self.start_epoch and isinstance(valid_loader, torch.utils.data.dataloader.DataLoader): 357 | self.valid_epoch(epoch, valid_loader, verbose=False) 358 | 359 | # Perform training iteration over train_loader 360 | self.train_epoch(epoch, train_loader) 361 | 362 | # Perform validation over valid_loader if it's not None and we're at validation epoch 363 | if (epoch % self.config.train.valid_step == 0) and isinstance(valid_loader, torch.utils.data.dataloader.DataLoader): 364 | self.valid_epoch(epoch, valid_loader) 365 | 366 | 367 | def create_plots(self, y_batch, model_vars, behv_batch=None, mask_batch=None, epoch=1, trial_num=0, prefix='train'): 368 | ''' 369 | 
Creates training/validation plots of neural reconstruction, manifold latent factors and dynamic latent factors 370 | 371 | Parameters: 372 | ------------ 373 | - y_batch: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of high-dimensional neural observations 374 | - model_vars: dict, Dictionary which contains inferrred latents, predictions and reconstructions. See DFINE.forward for further details. 375 | - epoch: int, Number of epoch for which to create plot 376 | - behv_batch: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of behavior, None by default 377 | - trial_num: int, Trial number to plot 378 | - prefix: str, Plotname prefix to save plots 379 | ''' 380 | 381 | # Create the mask if it's None 382 | if mask_batch is None: 383 | mask_batch = torch.ones(y_batch.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1) 384 | 385 | # Generate and save reconstructed neural observation plot 386 | self.create_y_plot(y_batch=y_batch, y_hat_batch=model_vars['y_hat'], mask_batch=mask_batch, epoch=epoch, trial_num=trial_num, prefix=f'{prefix}', feat_name='y_hat') 387 | # Generate and save smoothed neural observation plot 388 | self.create_y_plot(y_batch=y_batch, y_hat_batch=model_vars['y_smooth'], mask_batch=mask_batch, epoch=epoch, trial_num=trial_num, prefix=f'{prefix}', feat_name='y_smooth') 389 | 390 | # Generate and save smoothed manifold latent factor plot 391 | self.create_k_step_ahead_plot(y_batch=y_batch, model_vars=model_vars, mask_batch=mask_batch, epoch=epoch, trial_num=trial_num, prefix=prefix) 392 | 393 | # Generate and save projected (encoder output directly) manifold latent factor plot 394 | self.create_latent_factor_plot(f=model_vars['a_hat'], epoch=epoch, trial_num=trial_num, prefix=prefix, feat_name='a_hat') 395 | # Generate and save smoothed manifold latent factor plot 396 | self.create_latent_factor_plot(f=model_vars['a_smooth'], epoch=epoch, trial_num=trial_num, prefix=prefix, feat_name='a_smooth') 397 | # Generate and save smoothed dynamic latent factor plot 398 | self.create_latent_factor_plot(f=model_vars['x_smooth'], epoch=epoch, trial_num=trial_num, prefix=prefix, feat_name='x_smooth') 399 | # Generate and save reconstructed behavior if model is behavior-supervised 400 | if self.config.model.supervise_behv and (behv_batch is not None): 401 | self.create_behv_recons_plot(behv_batch=behv_batch, behv_hat_batch=model_vars['behv_hat'], epoch=epoch, trial_num=trial_num, prefix=prefix) 402 | 403 | plt.close('all') 404 | 405 | 406 | def create_y_plot(self, y_batch, y_hat_batch, mask_batch=None, epoch=1, trial_num=0, prefix='train', feat_name='y_hat'): 407 | ''' 408 | Creates true and estimated neural observation plots during training and validation 409 | 410 | Parameters: 411 | ------------ 412 | - y_batch: torch.Tensor, shape: (num_seq, num_steps, dim_y), True high-dimensional neural observation 413 | - y_hat_batch: torch.Tensor, shape: (num_seq, num_steps, dim_y), Reconstructed high-dimensional neural observation, smoothed/filtered/reconstructed neural observation can be provided 414 | - mask_batch: torch.Tensor, shape: (num_seq, num_steps, 1), Mask for manifold latent factors which shows whether 415 | observations at each timestep exists (1) or are missing (0) 416 | - epoch: int, Number of epoch for which to create the plot 417 | - trial_num:, int, Trial number in the batch to plot 418 | - prefix: str, Plotname prefix to save the plot 419 | - feat_name: str, Feature name of y_hat_batch (e.g. 
y_hat/y_smooth) used in plotname 420 | ''' 421 | 422 | # Create the mask if it's None 423 | if mask_batch is None: 424 | mask_batch = torch.ones(y_batch.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1) 425 | 426 | # Detach tensors for plotting 427 | y_batch = y_batch.detach().cpu() 428 | y_hat_batch = y_hat_batch.detach().cpu() 429 | mask_batch = mask_batch.detach().cpu() 430 | 431 | # Mask y_batch and y_hat_batch 432 | num_seq, _, dim_y = y_batch.shape 433 | mask_bool_batch = mask_batch.type(torch.bool).tile(1, 1, self.dfine.dim_y) 434 | y_batch = y_batch[mask_bool_batch].reshape(num_seq, -1, self.dfine.dim_y) 435 | y_hat_batch = y_hat_batch[mask_bool_batch].reshape(num_seq, -1, self.dfine.dim_y) 436 | 437 | # Create the figure 438 | fig = plt.figure(figsize=(20,15)) 439 | num_samples = y_batch.shape[1] 440 | color_index = range(num_samples) 441 | color_map = plt.cm.get_cmap('viridis') 442 | 443 | # Plot the true observations and noiseless observations (if it's provided) 444 | if dim_y >= 3: 445 | ax = fig.add_subplot(321, projection='3d') 446 | ax_m = ax.scatter(y_batch[trial_num, :, 0], y_batch[trial_num, :, 1], y_batch[trial_num, :, 2], c=color_index, vmin=0, vmax=num_samples, s=35, cmap=color_map, label='y_true') 447 | 448 | ax.set_xlabel('Dim 0') 449 | ax.set_ylabel('Dim 1') 450 | ax.set_zlabel('Dim 2') 451 | ax.set_title(f'True observations in 3d') 452 | ax.legend() 453 | fig.colorbar(ax_m) 454 | 455 | # Plot the reconstructed observation and noiseless observations (if it's provided) 456 | ax = fig.add_subplot(322, projection='3d') 457 | ax_m = ax.scatter(y_hat_batch[trial_num, :, 0], y_hat_batch[trial_num, :, 1], y_hat_batch[trial_num, :, 2], c=color_index, vmin=0, vmax=num_samples, s=35, cmap=color_map, label='y_hat') 458 | ax.set_title(f'Reconstructed observations in 3d') 459 | ax.set_xlabel('Dim 0') 460 | ax.set_ylabel('Dim 1') 461 | ax.set_zlabel('Dim 2') 462 | ax.legend() 463 | fig.colorbar(ax_m) 464 | 465 | # Plot Dim 0 466 | ax = fig.add_subplot(324) 467 | ax.plot(range(num_samples), y_batch[trial_num, :, 0], 'g', label='y_true') 468 | ax.plot(range(num_samples), y_hat_batch[trial_num, :, 0], 'b', label='y_hat') 469 | ax.set_title('Dim 0') 470 | 471 | if dim_y >= 2: 472 | # Plot Dim 1 473 | ax = fig.add_subplot(325) 474 | ax.plot(range(num_samples), y_batch[trial_num, :, 1], 'g', label='y_true') 475 | ax.plot(range(num_samples), y_hat_batch[trial_num, :, 1] ,'b', label='y_hat') 476 | ax.set_title('Dim 1') 477 | 478 | if dim_y >= 3: 479 | # Plot Dim 2 480 | ax = fig.add_subplot(326) 481 | ax.plot(range(num_samples), y_batch[trial_num, :, 2], 'g', label='y_true') 482 | ax.plot(range(num_samples), y_hat_batch[trial_num, :,2], 'b', label='y_hat') 483 | ax.set_title('Dim 2') 484 | ax.legend() 485 | 486 | # Save the plot under plot_save_dir 487 | plot_name = f'{prefix}_{feat_name}_{epoch}.png' 488 | plt.savefig(os.path.join(self.plot_save_dir, plot_name)) 489 | plt.close('all') 490 | 491 | 492 | def create_k_step_ahead_plot(self, y_batch, model_vars, mask_batch=None, epoch=1, trial_num=0, prefix='train'): 493 | ''' 494 | Creates true and k-step ahead predicted neural observation plots during training and validation 495 | 496 | Parameters: 497 | ------------ 498 | - y_batch: torch.Tensor, shape: (num_seq, num_steps, dim_y), True high-dimensional neural observation 499 | - model_vars: dict, Dictionary which contains inferrred latents, predictions and reconstructions. See DFINE.forward for further details. 
500 | - mask_batch: torch.Tensor, shape: (num_seq, num_steps, 1), Mask for manifold latent factors which shows whether 501 | observations at each timestep exists (1) or are missing (0) 502 | - epoch: int, Number of epoch for which to create the plot 503 | - trial_num:, int, Trial number in the batch to plot 504 | - prefix: str, Plotname prefix to save the plot 505 | ''' 506 | 507 | num_total_steps = y_batch.shape[1] 508 | 509 | # Create the mask if it's None 510 | if mask_batch is None: 511 | mask_batch = torch.ones(y_batch.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1) 512 | 513 | # Get the number of steps ahead for which DFINE is optimized and create the figure 514 | num_k = len(self.config.loss.steps_ahead) 515 | fig = plt.figure(figsize=(20, 20)) 516 | fig_num = 1 517 | 518 | # Start iterating over steps ahead for plotting 519 | for k in self.config.loss.steps_ahead: 520 | # Get the k-step ahead prediction 521 | y_pred_k_batch, _, _ = self.dfine.get_k_step_ahead_prediction(model_vars, k) 522 | 523 | # Detach tensors for plotting and take timesteps from k to T (since we're plotting k-step ahead predictions) 524 | y_batch_k = y_batch[:, k:, ...].detach().cpu() 525 | mask_batch_k = mask_batch[:, k:, :].detach().cpu() 526 | y_pred_k_batch = y_pred_k_batch.detach().cpu() 527 | 528 | # Mask y and y_hat 529 | num_seq = y_batch.shape[0] 530 | mask_bool = mask_batch_k.type(torch.bool).tile(1, 1, self.dfine.dim_y) 531 | y_batch_k = y_batch_k[mask_bool].reshape(num_seq, -1, self.dfine.dim_y) 532 | y_pred_k_batch = y_pred_k_batch[mask_bool].reshape(num_seq, -1, self.dfine.dim_y) 533 | 534 | # Plot dimension 0 535 | ax = fig.add_subplot(num_k, 2, fig_num) 536 | ax.plot(range(k, num_total_steps), y_batch_k[trial_num, :, 0], 'g', label=f'{k}-step y_true') 537 | ax.plot(range(k, num_total_steps), y_pred_k_batch[trial_num, :, 0], 'b', label=f'{k}-step predicted y') 538 | ax.set_title(f'k={k} step ahead') 539 | ax.set_xlabel('Time') 540 | ax.set_ylabel('Dim 0') 541 | ax.legend() 542 | fig_num += 1 543 | 544 | # Plot first 3 dimensions of k-step prediction as 3D scatter plot (mostly not useful visualization unless the manifold is obvious in first 3 dimensions) 545 | color_index = range(y_batch_k.shape[1]) 546 | color_map = plt.cm.get_cmap('viridis') 547 | ax = fig.add_subplot(num_k, 2, fig_num, projection='3d') 548 | ax_m = ax.scatter(y_pred_k_batch[trial_num, :, 0], y_pred_k_batch[trial_num, :, 1], y_pred_k_batch[trial_num, :, 2], c=color_index, vmin=0, vmax=y_batch.shape[1], s=35, cmap=color_map, label=f'{k}-step predicted y') 549 | ax.set_title(f'k={k} step ahead') 550 | ax.set_xlabel('Dim 0') 551 | ax.set_ylabel('Dim 1') 552 | ax.set_zlabel('Dim 2') 553 | ax.legend() 554 | fig.colorbar(ax_m) 555 | fig_num += 1 556 | 557 | # Save the plot under plot_save_dir 558 | plot_name = f'{prefix}_k_step_obs_{epoch}.png' 559 | plt.savefig(os.path.join(self.plot_save_dir, plot_name)) 560 | plt.close('all') 561 | 562 | 563 | def create_latent_factor_plot(self, f, epoch=1, trial_num=0, prefix='train', feat_name='x_smooth'): 564 | ''' 565 | Creates dynamic latent factor plots during training/validation 566 | 567 | Parameters: 568 | ------------ 569 | - f: torch.Tensor, shape: (num_seq, num_steps, dim_x/dim_a), Batch of inferred dynamic/manifold latent factors, smoothed/filtered factors can be provided 570 | - epoch: int, Number of epoch for which to create dynamic latent factor plot 571 | - trial_num: int, Trial number to plot 572 | - prefix: str, Plotname prefix to save plots 573 | - feat_name: str, Feature 
563 |     def create_latent_factor_plot(self, f, epoch=1, trial_num=0, prefix='train', feat_name='x_smooth'):
564 |         '''
565 |         Creates dynamic/manifold latent factor plots during training/validation
566 | 
567 |         Parameters:
568 |         ------------
569 |         - f: torch.Tensor, shape: (num_seq, num_steps, dim_x/dim_a), Batch of inferred dynamic/manifold latent factors, smoothed/filtered factors can be provided
570 |         - epoch: int, Number of epoch for which to create the latent factor plot
571 |         - trial_num: int, Trial number to plot
572 |         - prefix: str, Plotname prefix to save plots
573 |         - feat_name: str, Feature name of f (e.g. x_smooth/a_smooth) used in plotname
574 |         '''
575 | 
576 |         # Detach the tensor for plotting
577 |         f = f.detach().cpu()
578 | 
579 |         # From feat_name, get whether it's manifold or dynamic latent factors
580 |         if feat_name[0].lower() == 'x':
581 |             factor_name = 'Dynamic'
582 |         else:
583 |             factor_name = 'Manifold'
584 | 
585 |         # Create the figure and colormap
586 |         fig = plt.figure(figsize=(10, 8))
587 |         _, num_steps, dim_f = f.shape
588 |         color_index = range(num_steps)
589 |         color_map = plt.cm.get_cmap('viridis')
590 | 
591 |         if dim_f > 2:
592 |             # Scatter first 3 dimensions of the latent factors
593 |             ax = fig.add_subplot(221, projection='3d')
594 |             ax_m = ax.scatter(f[trial_num, :, 0], f[trial_num, :, 1], f[trial_num, :, 2], c=color_index, vmin=0, vmax=num_steps, s=35, cmap=color_map)
595 |             ax.set_xlabel('Dim 0')
596 |             ax.set_ylabel('Dim 1')
597 |             ax.set_zlabel('Dim 2')
598 |             ax.set_title(f'{factor_name} latent factors in 3D')
599 |             fig.colorbar(ax_m)
600 | 
601 |             # Scatter first 2 dimensions of the latent factors, top view
602 |             ax = fig.add_subplot(222)
603 |             ax_m = ax.scatter(f[trial_num, :, 0], f[trial_num, :, 1], c=color_index, vmin=0, vmax=num_steps, s=35, cmap=color_map)
604 |             ax.set_xlabel('Dim 0')
605 |             ax.set_ylabel('Dim 1')
606 |             ax.set_title(f'{factor_name} latent factors from top')
607 |             fig.colorbar(ax_m)
608 | 
609 |             # Plot the first dimension of the latent factors over time
610 |             ax = fig.add_subplot(223)
611 |             ax.plot(range(num_steps), f[trial_num, :, 0])
612 |             ax.set_xlabel('Time')
613 |             ax.set_ylabel('Dim 0')
614 | 
615 |             # Plot the second dimension of the latent factors over time
616 |             ax = fig.add_subplot(224)
617 |             ax.plot(range(num_steps), f[trial_num, :, 1])
618 |             ax.set_xlabel('Time')
619 |             ax.set_ylabel('Dim 1')
620 | 
621 |         elif dim_f == 2:
622 |             # Scatter first 2 dimensions of the latent factors, top view
623 |             ax = fig.add_subplot(221)
624 |             ax_m = ax.scatter(f[trial_num, :, 0], f[trial_num, :, 1], c=color_index, vmin=0, vmax=num_steps, s=35, cmap=color_map)
625 |             ax.set_xlabel('Dim 0')
626 |             ax.set_ylabel('Dim 1')
627 |             ax.set_title(f'{factor_name} latent factors from top')
628 |             fig.colorbar(ax_m)
629 | 
630 |             # Plot the first dimension of the latent factors over time
631 |             ax = fig.add_subplot(222)
632 |             ax.plot(range(num_steps), f[trial_num, :, 0])
633 |             ax.set_xlabel('Time')
634 |             ax.set_ylabel('Dim 0')
635 | 
636 |             # Plot the second dimension of the latent factors over time
637 |             ax = fig.add_subplot(223)
638 |             ax.plot(range(num_steps), f[trial_num, :, 1])
639 |             ax.set_xlabel('Time')
640 |             ax.set_ylabel('Dim 1')
641 | 
642 |         else:
643 |             # Plot the only dimension of the latent factors over time
644 |             ax = fig.add_subplot(111)
645 |             ax.plot(range(num_steps), f[trial_num, :, 0])
646 |             ax.set_xlabel('Time')
647 |             ax.set_ylabel('Dim 0')
648 |         fig.suptitle(f'{factor_name} latent factors', fontsize=16)
649 | 
650 |         # Save the plot under plot_save_dir
651 |         plot_name = f'{prefix}_{feat_name}_{epoch}.png'
652 |         plt.savefig(os.path.join(self.plot_save_dir, plot_name))
653 |         plt.close('all')
654 | 
655 | 
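    # The factor type above is inferred from the first letter of feat_name, so a typical
    # (hypothetical) call pattern with a trainer instance would be:
    #     trainer.create_latent_factor_plot(model_vars['x_smooth'], epoch=50, feat_name='x_smooth')  # titled 'Dynamic'
    #     trainer.create_latent_factor_plot(model_vars['a_smooth'], epoch=50, feat_name='a_smooth')  # titled 'Manifold'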
656 |     def create_behv_recons_plot(self, behv_batch, behv_hat_batch, epoch=1, trial_num=0, prefix='train'):
657 |         '''
658 |         Creates behavior reconstruction plots during training/validation
659 | 
660 |         Parameters:
661 |         ------------
662 |         - behv_batch: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of true behavior
663 |         - behv_hat_batch: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of reconstructed behavior
664 |         - epoch: int, Number of epoch for which to create the behavior reconstruction plot
665 |         - trial_num: int, Trial number to plot
666 |         - prefix: str, Plotname prefix to save plots
667 |         '''
668 | 
669 |         # Create the figure and detach the tensors for plotting
670 |         fig = plt.figure(figsize=(15, 20))
671 |         behv_batch = behv_batch.detach().cpu()
672 |         behv_hat_batch = behv_hat_batch.detach().cpu()
673 | 
674 |         # Plot the desired behavior dimensions
675 |         for k_i, i in enumerate(self.config.model.which_behv_dims):
676 |             ax = fig.add_subplot(self.dfine.dim_behv, 1, k_i+1)
677 |             ax.plot(behv_batch[trial_num, :, i], label='True Behavior', color='green')
678 |             ax.plot(behv_hat_batch[trial_num, :, k_i], label='Decoded Behavior', color='red')
679 |             ax.set_xlabel('Time')
680 |             ax.set_ylabel(f'Dim {i+1}')
681 |             ax.legend()
682 | 
683 |         # Save the plot under plot_save_dir
684 |         plot_name = f'{prefix}_behv_{epoch}.png'
685 |         plt.savefig(os.path.join(self.plot_save_dir, plot_name))
686 |         plt.close('all')
687 | 
688 | 
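    # Indexing note for the loop above: behv_batch carries the full behavior and is indexed by the
    # original dimension i, while behv_hat_batch only contains the decoded subset selected by
    # config.model.which_behv_dims and is therefore indexed by the running index k_i. For example,
    # with the (hypothetical) setting
    #     config.model.which_behv_dims = [0, 2]
    # behv_hat_batch[..., 1] is compared against behv_batch[..., 2].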
689 |     def save_encoding_results(self, train_loader, valid_loader=None, do_full_inference=True, save_results=True):
690 |         '''
691 |         Performs inference, reconstruction and prediction for the training data and the validation data (if provided), and saves training and inference time statistics.
692 |         Then, the encoding results are saved under {config.model.save_dir}/encoding_results.pt.
693 |         Parameters:
694 |         ------------
695 |         - train_loader: torch.utils.data.DataLoader, Training dataloader
696 |         - valid_loader: torch.utils.data.DataLoader, Validation dataloader, None by default (if no valid_loader is provided, validation inference is skipped)
697 |         - do_full_inference: bool, Whether to additionally perform inference on the single flattened sequence obtained by concatenating all trials/segments
698 |         - save_results: bool, Whether to save the encoding results dictionary to disk
699 |         '''
700 |         self.dfine.eval()
701 | 
702 |         with torch.no_grad():
703 |             ############################################################################ BATCH INFERENCE ############################################################################
704 |             # Create the keys for encoding results dictionary
705 |             encoding_dict = {}
706 |             encoding_dict['training_time'] = self.training_time
707 |             encoding_dict['training_time_epochs'] = self.training_time_epochs
708 |             encoding_dict['latent_inference_time'] = dict(train=0, valid=0)
709 | 
710 |             encoding_dict['x_pred'] = dict(train=[], valid=[])
711 |             encoding_dict['x_filter'] = dict(train=[], valid=[])
712 |             encoding_dict['x_smooth'] = dict(train=[], valid=[])
713 | 
714 |             encoding_dict['a_hat'] = dict(train=[], valid=[])
715 |             encoding_dict['a_pred'] = dict(train=[], valid=[])
716 |             encoding_dict['a_filter'] = dict(train=[], valid=[])
717 |             encoding_dict['a_smooth'] = dict(train=[], valid=[])
718 | 
719 |             encoding_dict['mask'] = dict(train=[], valid=[])
720 | 
721 |             y_key_list = ['y', 'y_hat', 'y_filter', 'y_smooth', 'y_pred']
722 |             for k in self.config.loss.steps_ahead:
723 |                 if k != 1:
724 |                     y_key_list.append(f'y_{k}_pred')
725 | 
726 |             for y_key in y_key_list:
727 |                 encoding_dict[y_key] = dict(train=[], valid=[])
728 | 
729 |             # If model is behavior-supervised, create the keys for behavior reconstruction
730 |             if self.config.model.supervise_behv:
731 |                 encoding_dict['behv'] = dict(train=[], valid=[])
732 |                 encoding_dict['behv_hat'] = dict(train=[], valid=[])
733 | 
734 |             # Dump train_loader and valid_loader into a dictionary
735 |             loaders = dict(train=train_loader, valid=valid_loader)
736 | 
737 |             # Start iterating over the dataloaders
738 |             for train_valid, loader in loaders.items():
739 |                 if isinstance(loader, torch.utils.data.dataloader.DataLoader):
740 |                     # If the loader is not None, start iterating over the batches
741 |                     for batch in loader:
742 |                         # Keep track of latent inference start time
743 |                         start_time = timeit.default_timer()
744 | 
745 |                         batch = carry_to_device(batch, device=self.device)
746 |                         y_batch, behv_batch, mask_batch = batch
747 |                         model_vars = self.dfine(y=y_batch, mask=mask_batch)
748 | 
749 |                         # Add to the latent inference time over the batches
750 |                         encoding_dict['latent_inference_time'][train_valid] += timeit.default_timer() - start_time
751 | 
752 |                         # Append the inference variables to the empty lists created in the beginning
753 |                         encoding_dict['x_pred'][train_valid].append(model_vars['x_pred'].detach().cpu())
754 |                         encoding_dict['x_filter'][train_valid].append(model_vars['x_filter'].detach().cpu())
755 |                         encoding_dict['x_smooth'][train_valid].append(model_vars['x_smooth'].detach().cpu())
756 | 
757 |                         encoding_dict['a_hat'][train_valid].append(model_vars['a_hat'].detach().cpu())
758 |                         encoding_dict['a_pred'][train_valid].append(model_vars['a_pred'].detach().cpu())
759 |                         encoding_dict['a_filter'][train_valid].append(model_vars['a_filter'].detach().cpu())
760 |                         encoding_dict['a_smooth'][train_valid].append(model_vars['a_smooth'].detach().cpu())
761 | 
762 |                         encoding_dict['mask'][train_valid].append(mask_batch.detach().cpu())
763 |                         encoding_dict['y'][train_valid].append(y_batch.detach().cpu())
764 |                         encoding_dict['y_hat'][train_valid].append(model_vars['y_hat'].detach().cpu())
765 |                         encoding_dict['y_pred'][train_valid].append(model_vars['y_pred'].detach().cpu())
766 |                         encoding_dict['y_filter'][train_valid].append(model_vars['y_filter'].detach().cpu())
767 |                         encoding_dict['y_smooth'][train_valid].append(model_vars['y_smooth'].detach().cpu())
768 | 
769 |                         for k in self.config.loss.steps_ahead:
770 |                             if k != 1:
771 |                                 y_pred_k, _, _ = self.dfine.get_k_step_ahead_prediction(model_vars, k)
772 |                                 encoding_dict[f'y_{k}_pred'][train_valid].append(y_pred_k.detach().cpu())
773 | 
774 |                         if self.config.model.supervise_behv:
775 |                             encoding_dict['behv'][train_valid].append(behv_batch.detach().cpu())
776 |                             encoding_dict['behv_hat'][train_valid].append(model_vars['behv_hat'].detach().cpu())
777 | 
778 |                     # Convert the lists to tensors
779 |                     encoding_dict['x_pred'][train_valid] = torch.cat(encoding_dict['x_pred'][train_valid], dim=0)
780 |                     encoding_dict['x_filter'][train_valid] = torch.cat(encoding_dict['x_filter'][train_valid], dim=0)
781 |                     encoding_dict['x_smooth'][train_valid] = torch.cat(encoding_dict['x_smooth'][train_valid], dim=0)
782 | 
783 |                     encoding_dict['a_hat'][train_valid] = torch.cat(encoding_dict['a_hat'][train_valid], dim=0)
784 |                     encoding_dict['a_pred'][train_valid] = torch.cat(encoding_dict['a_pred'][train_valid], dim=0)
785 |                     encoding_dict['a_filter'][train_valid] = torch.cat(encoding_dict['a_filter'][train_valid], dim=0)
786 |                     encoding_dict['a_smooth'][train_valid] = torch.cat(encoding_dict['a_smooth'][train_valid], dim=0)
787 | 
788 |                     encoding_dict['mask'][train_valid] = torch.cat(encoding_dict['mask'][train_valid], dim=0)
789 |                     for y_key in y_key_list:
790 |                         encoding_dict[y_key][train_valid] = torch.cat(encoding_dict[y_key][train_valid], dim=0)
791 | 
792 |                     if self.config.model.supervise_behv:
793 |                         encoding_dict['behv'][train_valid] = torch.cat(encoding_dict['behv'][train_valid], dim=0)
794 |                         encoding_dict['behv_hat'][train_valid] = torch.cat(encoding_dict['behv_hat'][train_valid], dim=0)
795 | 
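            # At this point every entry of encoding_dict holds CPU tensors concatenated over
            # batches along the sequence axis, e.g. (shapes shown for illustration):
            #     encoding_dict['x_smooth']['train']   # (num_train_seq, num_steps, dim_x)
            #     encoding_dict['y_hat']['valid']      # (num_valid_seq, num_steps, dim_y)
            #     encoding_dict['mask']['train']       # (num_train_seq, num_steps, 1)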
796 |             ############################################################################ FULL INFERENCE w/ FLATTENED SEQUENCE ############################################################################
797 |             encoding_dict_full_inference = {}
798 | 
799 |             if do_full_inference:
800 |                 # Create the keys for the full-inference encoding results dictionary
801 |                 encoding_dict_full_inference['latent_inference_time'] = dict(train=0, valid=0)
802 | 
803 | 
804 |                 encoding_dict_full_inference['x_pred'] = dict(train=[], valid=[])
805 |                 encoding_dict_full_inference['x_filter'] = dict(train=[], valid=[])
806 |                 encoding_dict_full_inference['x_smooth'] = dict(train=[], valid=[])
807 | 
808 |                 encoding_dict_full_inference['a_hat'] = dict(train=[], valid=[])
809 |                 encoding_dict_full_inference['a_pred'] = dict(train=[], valid=[])
810 |                 encoding_dict_full_inference['a_filter'] = dict(train=[], valid=[])
811 |                 encoding_dict_full_inference['a_smooth'] = dict(train=[], valid=[])
812 | 
813 |                 encoding_dict_full_inference['mask'] = dict(train=[], valid=[])
814 | 
815 |                 for y_key in y_key_list:
816 |                     encoding_dict_full_inference[y_key] = dict(train=[], valid=[])
817 | 
818 |                 # If model is behavior-supervised, create the keys for behavior reconstruction
819 |                 if self.config.model.supervise_behv:
820 |                     encoding_dict_full_inference['behv'] = dict(train=[], valid=[])
821 |                     encoding_dict_full_inference['behv_hat'] = dict(train=[], valid=[])
822 | 
823 |                 # Dump variables to encoding_dict_full_inference
824 |                 for train_valid, loader in loaders.items():
825 |                     if isinstance(loader, torch.utils.data.dataloader.DataLoader):
826 |                         # Flatten the batches of neural observations, the corresponding mask, and the behavior if the model is supervised
827 |                         encoding_dict_full_inference['y'][train_valid] = encoding_dict['y'][train_valid].reshape(1, -1, self.dfine.dim_y)
828 |                         encoding_dict_full_inference['mask'][train_valid] = encoding_dict['mask'][train_valid].reshape(1, -1, 1)
829 | 
830 |                         if self.config.model.supervise_behv:
831 |                             total_dim_behv = encoding_dict['behv'][train_valid].shape[-1]
832 |                             encoding_dict_full_inference['behv'][train_valid] = encoding_dict['behv'][train_valid].reshape(1, -1, total_dim_behv)
833 | 
834 |                         # Keep track of latent inference start time
835 |                         start_time = timeit.default_timer()
836 |                         model_vars = self.dfine(y=encoding_dict_full_inference['y'][train_valid].to(self.device), mask=encoding_dict_full_inference['mask'][train_valid].to(self.device))
837 |                         encoding_dict_full_inference['latent_inference_time'][train_valid] += timeit.default_timer() - start_time
838 | 
839 |                         # Store the inference variables
840 |                         encoding_dict_full_inference['x_pred'][train_valid] = model_vars['x_pred'].detach().cpu()
841 |                         encoding_dict_full_inference['x_filter'][train_valid] = model_vars['x_filter'].detach().cpu()
842 |                         encoding_dict_full_inference['x_smooth'][train_valid] = model_vars['x_smooth'].detach().cpu()
843 | 
844 |                         encoding_dict_full_inference['a_hat'][train_valid] = model_vars['a_hat'].detach().cpu()
845 |                         encoding_dict_full_inference['a_pred'][train_valid] = model_vars['a_pred'].detach().cpu()
846 |                         encoding_dict_full_inference['a_filter'][train_valid] = model_vars['a_filter'].detach().cpu()
847 |                         encoding_dict_full_inference['a_smooth'][train_valid] = model_vars['a_smooth'].detach().cpu()
848 | 
849 |                         encoding_dict_full_inference['y_hat'][train_valid] = model_vars['y_hat'].detach().cpu()
850 |                         encoding_dict_full_inference['y_pred'][train_valid] = model_vars['y_pred'].detach().cpu()
851 |                         encoding_dict_full_inference['y_filter'][train_valid] = model_vars['y_filter'].detach().cpu()
852 |                         encoding_dict_full_inference['y_smooth'][train_valid] = model_vars['y_smooth'].detach().cpu()
853 | 
854 |                         for k in self.config.loss.steps_ahead:
855 |                             if k != 1:
856 |                                 y_pred_k, _, _ = self.dfine.get_k_step_ahead_prediction(model_vars, k)
857 |                                 encoding_dict_full_inference[f'y_{k}_pred'][train_valid] = y_pred_k.detach().cpu()
858 | 
859 |                         if self.config.model.supervise_behv:
860 |                             encoding_dict_full_inference['behv_hat'][train_valid] = model_vars['behv_hat'].detach().cpu()
861 | 
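            # The flattening above reshapes the batch of segments (num_seq, num_steps, dim_y) into a
            # single sequence (1, num_seq * num_steps, dim_y), so the Kalman filter runs once over the
            # concatenated time-series instead of restarting at every segment boundary. A quick shape
            # check (hypothetical sizes):
            #     y = torch.randn(8, 200, 30)       # 8 segments, 200 steps each, dim_y = 30
            #     y_flat = y.reshape(1, -1, 30)     # -> (1, 1600, 30)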
862 |             # Dump the batch and full inference encoding dictionaries into encoding_results
863 |             encoding_results = dict(batch_inference=encoding_dict, full_inference=encoding_dict_full_inference)
864 | 
865 |             # Save the encoding dictionary as a .pt file
866 |             if save_results:
867 |                 torch.save(encoding_results, os.path.join(self.config.model.save_dir, 'encoding_results.pt'))
868 | 
869 |         return encoding_results
870 | 
871 | 
872 |     def write_summary(self, epoch, prefix='train'):
873 |         '''
874 |         Logs metrics to TensorBoard
875 | 
876 |         Parameters:
877 |         ------------
878 |         - epoch: int, Number of epoch for which to log metrics
879 |         - prefix: str, Prefix with which to log metrics
880 |         '''
881 | 
882 |         for key, val in self.metrics[prefix].items():
883 |             self.writer.add_scalar(f'{prefix}/{key}', val.compute(), epoch)
884 | 
885 |         # The rest logs the loss scale values and the learning rate; these are the same for all prefixes, so log them only once (skip for 'valid')
886 |         if prefix != 'valid':
887 |             self.writer.add_scalar('scale_l2', self.dfine.scale_l2, epoch)
888 |             self.writer.add_scalar('learning_rate', self.lr_scheduler.get_last_lr()[0], epoch)
889 |             if self.config.model.supervise_behv:
890 |                 self.writer.add_scalar('scale_behv_recons', self.dfine.scale_behv_recons, epoch)
891 | 
892 | 
--------------------------------------------------------------------------------
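# A minimal sketch of consuming the saved results after a completed training run (hypothetical
# snippet; the keys mirror those created in save_encoding_results above):
#     import os, torch
#     results = torch.load(os.path.join(config.model.save_dir, 'encoding_results.pt'))
#     y_smooth_train = results['batch_inference']['y_smooth']['train']   # (num_seq, num_steps, dim_y)
#     x_smooth_full = results['full_inference']['x_smooth']['train']     # (1, total_steps, dim_x)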