├── .gitignore
├── DFINE.py
├── LICENSE.md
├── README.md
├── config_dfine.py
├── data
│   └── swiss_roll.pt
├── datasets.py
├── metrics.py
├── modules
│   ├── LDM.py
│   └── MLP.py
├── nn.py
├── python_utils.py
├── requirements.txt
├── time_series_utils.py
├── trainers
│   ├── BaseTrainer.py
│   └── TrainerDFINE.py
└── tutorial.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # unnecessary folders
10 | results*
11 | trainers/__pycache__
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 | *.pdf
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
--------------------------------------------------------------------------------
/DFINE.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | from modules.LDM import LDM
9 | from modules.MLP import MLP
10 | from nn import get_kernel_initializer_function, compute_mse, get_activation_function
11 |
12 | import torch
13 | import torch.nn as nn
14 |
15 |
16 | class DFINE(nn.Module):
17 | '''
18 | DFINE (Dynamical Flexible Inference for Nonlinear Embeddings) Model.
19 |
20 | DFINE is a novel neural network model of neural population activity with the ability to perform
21 | flexible inference while modeling the nonlinear latent manifold structure and linear temporal dynamics.
22 | To model neural population activity, two sets of latent factors are defined: the dynamic latent factors
23 | which characterize the linear temporal dynamics on a nonlinear manifold, and the manifold latent factors
24 | which describe this low-dimensional manifold that is embedded in the high-dimensional neural population activity space.
25 | These two separate sets of latent factors together enable all the above flexible inference properties
26 | by allowing for Kalman filtering on the manifold while also capturing embedding nonlinearities.
27 | Here are some mathematical notations used in this repository:
28 |         - y: The high dimensional neural population activity, (num_seq, num_steps, dim_y). It must be Gaussian distributed, e.g., Gaussian-smoothed firing rates, LFP, ECoG or EEG
29 | - a: The manifold latent factors, (num_seq, num_steps, dim_a).
30 | - x: The dynamic latent factors, (num_seq, num_steps, dim_x).
31 |
32 |
33 |     * Please note that DFINE can perform learning and inference for continuous data, trial-based data, or segmented continuous data. In the case of continuous data,
34 |     num_seq and batch_size can be set to 1, and the model is optimized from the long time-series (this is plain gradient descent rather than batch-based gradient descent).
35 |     In the case of trial-based data, the 3D tensor is passed as the shape (num_seq, num_steps, dim_y) suggests. In the case of segmented continuous data,
36 |     num_seq can be the number of segments, and DFINE provides both per-segment and concatenated inference at the end for the user's convenience. In the concatenated inference,
37 |     the assumption is that the concatenation of segments forms a continuous time-series (a single time-series with batch size of 1).
38 | '''
39 |
40 | def __init__(self, config):
41 | '''
42 |         Initializer for a DFINE object. Note that DFINE is a subclass of torch.nn.Module.
43 |
44 | Parameters:
45 | ------------
46 |
47 | - config: yacs.config.CfgNode, yacs config which contains all hyperparameters required to create the DFINE model
48 | Please see config_dfine.py for the hyperparameters, their default values and definitions.
49 | '''
50 |
51 | super(DFINE, self).__init__()
52 |
53 | # Get the config and dimension parameters
54 | self.config = config
55 |
56 | # Set the seed, seed is by default set to a random integer, see config_dfine.py
57 | torch.manual_seed(self.config.seed)
58 |
59 | # Set the factor dimensions and loss scales
60 | self._set_dims_and_scales()
61 |
62 | # Initialize LDM parameters
63 | A, C, W_log_diag, R_log_diag, mu_0, Lambda_0 = self._init_ldm_parameters()
64 |
65 | # Initialize the LDM
66 | self.ldm = LDM(dim_x=self.dim_x, dim_a=self.dim_a,
67 | A=A, C=C,
68 | W_log_diag=W_log_diag, R_log_diag=R_log_diag,
69 | mu_0=mu_0, Lambda_0=Lambda_0,
70 | is_W_trainable=self.config.model.is_W_trainable,
71 | is_R_trainable=self.config.model.is_R_trainable)
72 |
73 | # Initialize encoder and decoder(s)
74 | self.encoder = self._get_MLP(input_dim=self.dim_y,
75 | output_dim=self.dim_a,
76 | layer_list=self.config.model.hidden_layer_list,
77 | activation_str=self.config.model.activation)
78 |
79 | self.decoder = self._get_MLP(input_dim=self.dim_a,
80 | output_dim=self.dim_y,
81 | layer_list=self.config.model.hidden_layer_list[::-1],
82 | activation_str=self.config.model.activation)
83 |
84 | # If asked to train supervised model, get behavior mapper
85 | if self.config.model.supervise_behv:
86 | self.mapper = self._get_MLP(input_dim=self.dim_a,
87 | output_dim=self.dim_behv,
88 | layer_list=self.config.model.hidden_layer_list_mapper,
89 | activation_str=self.config.model.activation_mapper)
90 |
91 |
92 | def _set_dims_and_scales(self):
93 | '''
94 | Sets the observation (y), manifold latent factor (a) and dynamic latent factor (x)
95 | (and behavior data dimension if supervised model is to be trained) dimensions,
96 | as well as behavior reconstruction loss and regularization loss scales from config.
97 | '''
98 |
99 | # Set the dimensions
100 | self.dim_y = self.config.model.dim_y
101 | self.dim_a = self.config.model.dim_a
102 | self.dim_x = self.config.model.dim_x
103 |
104 | if self.config.model.supervise_behv:
105 | self.dim_behv = len(self.config.model.which_behv_dims)
106 |
107 | # Set the loss scales for behavior component and for the regularization
108 | if self.config.model.supervise_behv:
109 | self.scale_behv_recons = self.config.loss.scale_behv_recons
110 | self.scale_l2 = self.config.loss.scale_l2
111 |
112 |
113 | def _get_MLP(self, input_dim, output_dim, layer_list, activation_str='tanh'):
114 | '''
115 | Creates an MLP object
116 |
117 | Parameters:
118 | ------------
119 | - input_dim: int, Dimensionality of the input to the MLP network
120 | - output_dim: int, Dimensionality of the output of the MLP network
121 | - layer_list: list, List of number of neurons in each hidden layer
122 | - activation_str: str, Activation function's name, 'tanh' by default
123 |
124 | Returns:
125 | ------------
126 | - mlp_network: an instance of MLP class with desired architecture
127 | '''
128 |
129 | activation_fn = get_activation_function(activation_str)
130 | kernel_initializer_fn = get_kernel_initializer_function(self.config.model.nn_kernel_initializer)
131 |
132 | mlp_network = MLP(input_dim=input_dim,
133 | output_dim=output_dim,
134 | layer_list=layer_list,
135 | activation_fn=activation_fn,
136 | kernel_initializer_fn=kernel_initializer_fn
137 | )
138 | return mlp_network
139 |
140 |
141 | def _init_ldm_parameters(self):
142 | '''
143 | Initializes the LDM Module parameters
144 |
145 | Returns:
146 | ------------
147 | - A: torch.Tensor, shape: (self.dim_x, self.dim_x), State transition matrix of LDM
148 | - C: torch.Tensor, shape: (self.dim_a, self.dim_x), Observation matrix of LDM
149 | - W_log_diag: torch.Tensor, shape: (self.dim_x, ), Log-diagonal of dynamics noise covariance matrix (W, therefore it is diagonal and PSD)
150 | - R_log_diag: torch.Tensor, shape: (self.dim_a, ), Log-diagonal of observation noise covariance matrix (R, therefore it is diagonal and PSD)
151 | - mu_0: torch.Tensor, shape: (self.dim_x, ), Dynamic latent factor prediction initial condition (x_{0|-1}) for Kalman filtering
152 | - Lambda_0: torch.Tensor, shape: (self.dim_x, self.dim_x), Dynamic latent factor estimate error covariance initial condition (P_{0|-1}) for Kalman filtering
153 |
154 |         * We learn the log-diagonals of W and R to satisfy the PSD constraint for the covariance matrices. Diagonal W and R are used for the stability of learning,
155 |         similar to prior latent LDM works; see (Kao et al., Nature Communications, 2015) & (Abbaspourazad et al., IEEE TNSRE, 2019) for further info
156 | '''
157 |
158 | kernel_initializer_fn = get_kernel_initializer_function(self.config.model.ldm_kernel_initializer)
159 | A = kernel_initializer_fn(self.config.model.init_A_scale * torch.eye(self.dim_x, dtype=torch.float32))
160 | C = kernel_initializer_fn(self.config.model.init_C_scale * torch.randn(self.dim_a, self.dim_x, dtype=torch.float32))
161 |
162 | W_log_diag = torch.log(kernel_initializer_fn(torch.diag(self.config.model.init_W_scale * torch.eye(self.dim_x, dtype=torch.float32))))
163 | R_log_diag = torch.log(kernel_initializer_fn(torch.diag(self.config.model.init_R_scale * torch.eye(self.dim_a, dtype=torch.float32))))
164 |
165 | mu_0 = kernel_initializer_fn(torch.zeros(self.dim_x, dtype=torch.float32))
166 | Lambda_0 = kernel_initializer_fn(self.config.model.init_cov * torch.eye(self.dim_x, dtype=torch.float32))
167 |
168 | return A, C, W_log_diag, R_log_diag, mu_0, Lambda_0
169 |
170 |
171 | def forward(self, y, mask=None):
172 | '''
173 | Forward pass for DFINE Model
174 |
175 | Parameters:
176 | ------------
177 | - y: torch.Tensor, shape: (num_seq, num_steps, dim_y), High-dimensional neural observations
178 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether
179 | observations at each timestep exist (1) or are missing (0)
180 |
181 | Returns:
182 | ------------
183 |         - model_vars: dict, Dictionary which contains learned parameters, inferred latents, predictions and reconstructions. Keys are:
184 | - a_hat: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors.
185 | - a_pred: torch.Tensor, shape: (num_seq, num_steps-1, dim_a), Batch of predicted estimates of manifold latent factors (last index of the second dimension is removed)
186 | - a_filter: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of filtered estimates of manifold latent factors
187 | - a_smooth: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of smoothed estimates of manifold latent factors
188 | - x_pred: torch.Tensor, shape: (num_seq, num_steps-1, dim_x), Batch of predicted estimates of dynamic latent factors
189 | - x_filter: torch.Tensor, shape: (num_seq, num_steps, dim_x), Batch of filtered estimates of dynamic latent factors
190 | - x_smooth: torch.Tensor, shape: (num_seq, num_steps, dim_x), Batch of smoothed estimates of dynamic latent factors
191 | - Lambda_pred: torch.Tensor, shape: (num_seq, num_steps-1, dim_x, dim_x), Batch of predicted estimates of dynamic latent factor estimation error covariance
192 | - Lambda_filter: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Batch of filtered estimates of dynamic latent factor estimation error covariance
193 | - Lambda_smooth: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Batch of smoothed estimates of dynamic latent factor estimation error covariance
194 | - y_hat: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of projected estimates of neural observations
195 | - y_pred: torch.Tensor, shape: (num_seq, num_steps-1, dim_y), Batch of predicted estimates of neural observations
196 | - y_filter: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of filtered estimates of neural observations
197 | - y_smooth: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of smoothed estimates of neural observations
198 |             - A: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Repeated (tiled) state transition matrix of LDM, same for each time-step in the 2nd axis
199 |             - C: torch.Tensor, shape: (num_seq, num_steps, dim_a, dim_x), Repeated (tiled) observation matrix of LDM, same for each time-step in the 2nd axis
200 | - behv_hat: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of reconstructed behavior. None if unsupervised model is trained
201 |
202 | * Terminology definition:
203 | projected: noisy estimations of manifold latent factors after nonlinear manifold embedding via encoder
204 | predicted: one-step ahead predicted estimations (t+1|t), the first and last time indices are (1|0) and (T|T-1)
205 | filtered: causal estimations (t|t)
206 | smoothed: non-causal estimations (t|T)
207 | '''
208 |
209 | # Get the dimensions from y
210 | num_seq, num_steps, _ = y.shape
211 |
212 | # Create the mask if it's None
213 | if mask is None:
214 |             mask = torch.ones(y.shape[:-1], dtype=torch.float32, device=y.device).unsqueeze(dim=-1)
215 |
216 | # Get the encoded low-dimensional manifold factors (project via nonlinear manifold embedding) -> the outputs are (num_seq * num_steps, dim_a)
217 | a_hat = self.encoder(y.view(-1, self.dim_y))
218 |
219 | # Reshape the manifold latent factors back into 3D structure (num_seq, num_steps, dim_a)
220 | a_hat = a_hat.view(-1, num_steps, self.dim_a)
221 |
222 | # Run LDM to infer filtered and smoothed dynamic latent factors
223 | x_pred, x_filter, x_smooth, Lambda_pred, Lambda_filter, Lambda_smooth = self.ldm(a=a_hat, mask=mask, do_smoothing=True)
224 | A = self.ldm.A.repeat(num_seq, num_steps, 1, 1)
225 | C = self.ldm.C.repeat(num_seq, num_steps, 1, 1)
226 | a_pred = (C @ x_pred.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, num_steps, dim_a, dim_x) x (num_seq, num_steps, dim_x, 1) --> (num_seq, num_steps, dim_a)
227 | a_filter = (C @ x_filter.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, num_steps, dim_a, dim_x) x (num_seq, num_steps, dim_x, 1) --> (num_seq, num_steps, dim_a)
228 | a_smooth = (C @ x_smooth.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, num_steps, dim_a, dim_x) x (num_seq, num_steps, dim_x, 1) --> (num_seq, num_steps, dim_a)
229 |
230 |         # Remove the last timestep of predictions since it's T+1|T, which is not of interest
231 | x_pred = x_pred[:, :-1, :]
232 | Lambda_pred = Lambda_pred[:, :-1, :, :]
233 | a_pred = a_pred[:, :-1, :]
234 |
235 |         # Map a_smooth or a_hat to behavior if requested -> behv_hat shape: (num_seq, num_steps, dim_behv)
236 | if self.config.model.supervise_behv:
237 | if self.config.model.behv_from_smooth:
238 | behv_hat = self.mapper(a_smooth.view(-1, self.dim_a))
239 | else:
240 | behv_hat = self.mapper(a_hat.view(-1, self.dim_a))
241 | behv_hat = behv_hat.view(-1, num_steps, self.dim_behv)
242 | else:
243 | behv_hat = None
244 |
245 |         # Get filtered and smoothed estimates of neural observations. To perform k-step-ahead prediction,
246 |         # the get_k_step_ahead_prediction(...) function should be called after the forward pass.
247 | y_hat = self.decoder(a_hat.view(-1, self.dim_a))
248 | y_pred = self.decoder(a_pred.reshape(-1, self.dim_a))
249 | y_filter = self.decoder(a_filter.view(-1, self.dim_a))
250 | y_smooth = self.decoder(a_smooth.view(-1, self.dim_a))
251 |
252 | y_hat = y_hat.view(num_seq, -1, self.dim_y)
253 | y_pred = y_pred.view(num_seq, -1, self.dim_y)
254 | y_filter = y_filter.view(num_seq, -1, self.dim_y)
255 | y_smooth = y_smooth.view(num_seq, -1, self.dim_y)
256 |
257 |         # Dump inferred latents, predictions and reconstructions to a dictionary
258 | model_vars = dict(a_hat=a_hat, a_pred=a_pred, a_filter=a_filter, a_smooth=a_smooth,
259 | x_pred=x_pred, x_filter=x_filter, x_smooth=x_smooth,
260 | Lambda_pred=Lambda_pred, Lambda_filter=Lambda_filter, Lambda_smooth=Lambda_smooth,
261 | y_hat=y_hat, y_pred=y_pred, y_filter=y_filter, y_smooth=y_smooth,
262 | A=A, C=C, behv_hat=behv_hat)
263 | return model_vars
264 |
265 |
266 | def get_k_step_ahead_prediction(self, model_vars, k):
267 | '''
268 | Performs k-step ahead prediction of manifold latent factors, dynamic latent factors and neural observations.
269 |
270 | Parameters:
271 | ------------
272 | - model_vars: dict, Dictionary returned after forward(...) call. See the definition of forward(...) function for information.
273 | - x_filter: torch.Tensor, shape: (num_seq, num_steps, dim_x), Batch of filtered estimates of dynamic latent factors
274 | - A: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x) or (dim_x, dim_x), State transition matrix of LDM
275 |             - C: torch.Tensor, shape: (num_seq, num_steps, dim_a, dim_x) or (dim_a, dim_x), Observation matrix of LDM
276 | - k: int, Number of steps ahead for prediction
277 |
278 | Returns:
279 | ------------
280 | - y_pred_k: torch.Tensor, shape: (num_seq, num_steps-k, dim_y), Batch of predicted estimates of neural observations,
281 | the first index of the second dimension is y_{k|0}
282 | - a_pred_k: torch.Tensor, shape: (num_seq, num_steps-k, dim_a), Batch of predicted estimates of manifold latent factor,
283 | the first index of the second dimension is a_{k|0}
284 | - x_pred_k: torch.Tensor, shape: (num_seq, num_steps-k, dim_x), Batch of predicted estimates of dynamic latent factor,
285 | the first index of the second dimension is x_{k|0}
286 | '''
287 |
288 | # Check whether provided k value is valid or not
289 |         if not isinstance(k, int) or k <= 0:
290 |             raise ValueError('Number of steps ahead prediction value is invalid or of wrong type, k must be a positive integer!')
291 |
292 | # Extract the required variables from model_vars dictionary
293 | x_filter = model_vars['x_filter']
294 | A = model_vars['A']
295 | C = model_vars['C']
296 |
297 | # Get the required dimensions
298 | num_seq, num_steps, _ = x_filter.shape
299 |
300 | # Check if shapes of A and C are 4D where first 2 dimensions are (number of trials/time segments) and (number of steps)
301 | if len(A.shape) == 2:
302 | A = A.repeat(num_seq, num_steps, 1, 1)
303 |
304 | if len(C.shape) == 2:
305 | C = C.repeat(num_seq, num_steps, 1, 1)
306 |
307 | # Here is where k-step ahead prediction is iteratively performed
308 | x_pred_k = x_filter[:, :-k, ...] # [x_k|0, x_{k+1}|1, ..., x_{T}|{T-k}]
309 | for i in range(1, k+1):
310 | if i != k:
311 | x_pred_k = (A[:, i:-(k-i), ...] @ x_pred_k.unsqueeze(dim=-1)).squeeze(dim=-1)
312 | else:
313 | x_pred_k = (A[:, i:, ...] @ x_pred_k.unsqueeze(dim=-1)).squeeze(dim=-1)
314 | a_pred_k = (C[:, k:, ...] @ x_pred_k.unsqueeze(dim=-1)).squeeze(dim=-1)
315 |
316 | # After obtaining k-step ahead predicted manifold latent factors, they're decoded to obtain k-step ahead predicted neural observations
317 | y_pred_k = self.decoder(a_pred_k.view(-1, self.dim_a))
318 |
319 |         # Reshape predictions back to 3D structure after decoder (num_seq, num_steps-k, dim_y)
320 | y_pred_k = y_pred_k.reshape(num_seq, -1, self.dim_y)
321 |
322 | return y_pred_k, a_pred_k, x_pred_k
323 |
324 |
325 | def compute_loss(self, y, model_vars, mask=None, behv=None):
326 | '''
327 | Computes k-step ahead predicted MSE loss, regularization loss and behavior reconstruction loss
328 | if supervised model is being trained.
329 |
330 | Parameters:
331 | ------------
332 | - y: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of high-dimensional neural observations
333 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether
334 |                 observations at each timestep exist (1) or are missing (0);
335 |                 if None it will be set to ones.
336 | - model_vars: dict, Dictionary returned after forward(...) call. See the definition of forward(...) function for information.
337 | - behv: torch.tensor, shape: (num_seq, num_steps, dim_behv), Batch of behavior data
338 |
339 | Returns:
340 | ------------
341 | - loss: torch.Tensor, shape: (), Loss to optimize, which is sum of k-step-ahead MSE loss, L2 regularization loss and
342 | behavior reconstruction loss if model is supervised
343 | - loss_dict: dict, Dictionary which has all loss components to log on Tensorboard. Keys are (e.g. for config.loss.steps_ahead = [1, 2]):
344 | - steps_{k}_mse: torch.Tensor, shape: (), {k}-step ahead predicted masked MSE, k's are determined by config.loss.steps_ahead
345 | - model_loss: torch.Tensor, shape: (), Negative of sum of all steps_{k}_mse
346 | - behv_loss: torch.Tensor, shape: (), Behavior reconstruction loss, 0 if model is unsupervised
347 | - reg_loss: torch.Tensor, shape: (), L2 Regularization loss for DFINE encoder and decoder weights
348 | - total_loss: torch.Tensor, shape: (), Sum of model_loss, behv_loss and reg_loss
349 | '''
350 |
351 | # Create the mask if it's None
352 | if mask is None:
353 |             mask = torch.ones(y.shape[:-1], dtype=torch.float32, device=y.device).unsqueeze(dim=-1)
354 |
355 | # Dump individual loss values for logging or Tensorboard
356 | loss_dict = dict()
357 |
358 | # Iterate over multiple steps ahead
359 | k_steps_mse_sum = 0
360 | for _, k in enumerate(self.config.loss.steps_ahead):
361 | y_pred_k, _, _ = self.get_k_step_ahead_prediction(model_vars, k=k)
362 | mse_pred = compute_mse(y_flat=y[:, k:, :].reshape(-1, self.dim_y),
363 | y_hat_flat=y_pred_k.reshape(-1, self.dim_y),
364 | mask_flat=mask[:, k:, :].reshape(-1,))
365 | k_steps_mse_sum += mse_pred
366 | loss_dict[f'steps_{k}_mse'] = mse_pred
367 |
368 | model_loss = k_steps_mse_sum
369 | loss_dict['model_loss'] = model_loss
370 |
371 |         # Get MSE loss for behavior reconstruction, 0 if we don't supervise our model with behavior data
372 | if self.config.model.supervise_behv:
373 | behv_mse = compute_mse(y_flat=behv[..., self.config.model.which_behv_dims].reshape(-1, self.dim_behv),
374 | y_hat_flat=model_vars['behv_hat'].reshape(-1, self.dim_behv),
375 | mask_flat=mask.reshape(-1,))
376 | behv_loss = self.scale_behv_recons * behv_mse
377 | else:
378 | behv_mse = torch.tensor(0, dtype=torch.float32, device=model_loss.device)
379 | behv_loss = torch.tensor(0, dtype=torch.float32, device=model_loss.device)
380 | loss_dict['behv_mse'] = behv_mse
381 | loss_dict['behv_loss'] = behv_loss
382 |
383 | # L2 regularization loss
384 | reg_loss = 0
385 | for name, param in self.named_parameters():
386 | if 'weight' in name:
387 | reg_loss = reg_loss + self.scale_l2 * torch.norm(param)
388 | loss_dict['reg_loss'] = reg_loss
389 |
390 | # Final loss is summation of model loss (sum of k-step ahead MSEs), behavior reconstruction loss and L2 regularization loss
391 | loss = model_loss + behv_loss + reg_loss
392 | loss_dict['total_loss'] = loss
393 | return loss, loss_dict
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
--------------------------------------------------------------------------------
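
For reference, here is a minimal usage sketch of DFINE on synthetic data, assuming the project root is on the Python path (see README.md); the toy dimensions and batch shapes are arbitrary, and only the constructor, forward(...), get_k_step_ahead_prediction(...) and compute_loss(...) defined above are used.

```python
import torch

from config_dfine import get_default_config
from DFINE import DFINE

# Build a DFINE model from the default config, shrunk to a toy size
config = get_default_config()
config.model.dim_y = 10
config.model.dim_a = 4
config.model.dim_x = 4

model = DFINE(config)

# Synthetic batch: 8 sequences, 50 timesteps, dim_y-dimensional observations
y = torch.randn(8, 50, config.model.dim_y)

# Forward pass infers manifold (a) and dynamic (x) latents and reconstructions
model_vars = model(y)
print(model_vars['x_smooth'].shape)  # torch.Size([8, 50, 4])

# k-step-ahead prediction of latents and observations from filtered estimates
y_pred_2, a_pred_2, x_pred_2 = model.get_k_step_ahead_prediction(model_vars, k=2)
print(y_pred_2.shape)  # torch.Size([8, 48, 10])

# Total loss: sum of k-step-ahead MSEs plus L2 regularization (unsupervised here)
loss, loss_dict = model.compute_loss(y=y, model_vars=model_vars)
loss.backward()
```
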
/LICENSE.md:
--------------------------------------------------------------------------------
1 | This software is Copyright © 2023 The University of Southern California. All Rights Reserved.
2 |
3 | Permission to use, copy, modify, and distribute this software and its documentation for educational, research
4 | and non-profit purposes, without fee, and without a written agreement is hereby granted, provided that the
5 | above copyright notice, this paragraph and the following three paragraphs appear in all copies.
6 |
7 | Permission to make commercial use of this software may be obtained by contacting:
8 | USC Stevens Center for Innovation
9 | University of Southern California
10 | 1150 S. Olive Street, Suite 2300
11 | Los Angeles, CA 90115, USA
12 |
13 | This software program and documentation are copyrighted by The University of Southern California. The software
14 | program and documentation are supplied "as is", without any accompanying services from USC. USC does not warrant
15 | that the operation of the program will be uninterrupted or error-free. The end-user understands that the program
16 | was developed for research purposes and is advised not to rely exclusively on the program for any reason.
17 |
18 | IN NO EVENT SHALL THE UNIVERSITY OF SOUTHERN CALIFORNIA BE LIABLE TO ANY PARTY FOR
19 | DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST
20 | PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF THE
21 | UNIVERSITY OF SOUTHERN CALIFORNIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
22 | DAMAGE. THE UNIVERSITY OF SOUTHERN CALIFORNIA SPECIFICALLY DISCLAIMS ANY
23 | WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
24 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED
25 | HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF SOUTHERN CALIFORNIA HAS NO
26 | OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
27 | MODIFICATIONS.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Torch DFINE
2 | PyTorch implementation of DFINE: Dynamical Flexible Inference for Nonlinear Embeddings
3 |
4 | DFINE is a neural network model of neural population activity that is developed to enable accurate and
5 | flexible inference, whether causally in real time, non-causally, or even in the presence of missing neural observations.
6 | In addition, DFINE enables recursive and thus computationally efficient inference for real-time implementation.
7 | DFINE's capabilities are important for applications such as neurotechnologies and brain-computer interfaces.
8 |
9 | More information about the model and its training and inference methods can be found inside [tutorial.ipynb](tutorial.ipynb) and in our manuscript below.
10 |
11 | ## Publication
12 |
13 | Abbaspourazad, H.\*, Erturk, E.\*, Pesaran, B., & Shanechi, M. M. Dynamical flexible inference of nonlinear latent factors and structures in neural population activity. _Nature Biomedical Engineering_ (2023). https://www.nature.com/articles/s41551-023-01106-1
14 |
15 | Original preprint: https://www.biorxiv.org/content/10.1101/2023.03.13.532479v1
16 |
17 | ## Installation
18 | Torch DFINE requires Python version 3.8.* or 3.9.*. After a virtual environment with a compatible Python version is set up,
19 | navigate to the project folder and simply run the following command:
20 |
21 | ```
22 | pip install -r requirements.txt
23 | ```
24 |
25 | Then, navigate to the virtual environment's site-packages directory, create a file with a .pth extension and copy the
26 | main project directory path (e.g. .../torchDFINE) into that .pth file. This allows the desired modules to be imported from subdirectories;
27 | for instance, the TrainerDFINE class can be imported via ```from trainers.TrainerDFINE import TrainerDFINE``` (see the sketch below).
28 |
29 | ## DFINE Tutorial
30 | Please see [tutorial.ipynb](tutorial.ipynb) for further information and guidelines on DFINE's model, training, and inference.
31 |
32 | ## License
33 | Copyright (c) 2023 University of Southern California
34 | See full notice in [LICENSE.md](LICENSE.md)
35 | Hamidreza Abbaspourazad\*, Eray Erturk\* and Maryam M. Shanechi
36 | Shanechi Lab, University of Southern California
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
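
As a concrete sketch of the `.pth` step in the Installation section above, assuming it is run from the project root inside the activated virtual environment; the file name `torchDFINE.pth` is arbitrary, and the exact site-packages location depends on your platform and Python version.

```python
import site
from pathlib import Path

# Active environment's site-packages directory (location varies by platform/Python)
site_packages = Path(site.getsitepackages()[0])

# Write the absolute project path (e.g. .../torchDFINE) into a .pth file
(site_packages / 'torchDFINE.pth').write_text(f'{Path.cwd().resolve()}\n')
```
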
/config_dfine.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | from python_utils import flatten_dict, unflatten_dict
9 |
10 | from yacs.config import CfgNode as CN
11 | import torch
12 |
13 |
14 | #### Initialization of the default and recommended config (except dimensions and hidden layer lists; set those to suit the data being fit)
15 | _config = CN()
16 |
17 | ## Set device and seed
18 | _config.device = 'cpu'
19 | _config.seed = int(torch.randint(low=0, high=100000, size=(1,)))
20 |
21 | ## Dump model related settings
22 | _config.model = CN()
23 |
24 | # Hidden layer list where each element is the number of neurons for that hidden layer of DFINE encoder/decoder. Please use [20,20,20,20] for nonlinear manifold simulations.
25 | _config.model.hidden_layer_list = [32,32,32]
26 | # Activation function used in encoder and decoder layers
27 | _config.model.activation = 'tanh'
28 | # Dimensionality of neural observations
29 | _config.model.dim_y = 30
30 | # Dimensionality of manifold latent factor; a choice higher than dim_y (above) may lead to overfitting
31 | _config.model.dim_a = 16
32 | # Dimensionality of dynamic latent factor; it's recommended to set it the same as dim_a (above), please see Extended Data Fig. 8
33 | _config.model.dim_x = 16
34 | # Initialization scale of LDM state transition matrix
35 | _config.model.init_A_scale = 1
36 | # Initialization scale of LDM observation matrix
37 | _config.model.init_C_scale = 1
38 | # Initialization scale of LDM process noise covariance matrix
39 | _config.model.init_W_scale = 0.5
40 | # Initialization scale of LDM observation noise covariance matrix
41 | _config.model.init_R_scale = 0.5
42 | # Initialization scale of dynamic latent factor estimation error covariance matrix
43 | _config.model.init_cov = 1
44 | # Boolean for whether process noise covariance matrix W is learnable or not
45 | _config.model.is_W_trainable = True
46 | # Boolean for whether observation noise covariance matrix R is learnable or not
47 | _config.model.is_R_trainable = True
48 | # Initialization type of LDM parameters, see nn.get_kernel_initializer_function for detailed definition and supported types
49 | _config.model.ldm_kernel_initializer = 'default'
50 | # Initialization type of DFINE encoder and decoder parameters, see nn.get_kernel_initializer_function for detailed definition and supported types
51 | _config.model.nn_kernel_initializer = 'xavier_normal'
52 | # Boolean for whether to learn a behavior-supervised model or not. It must be set to True if a supervised model is to be trained.
53 | _config.model.supervise_behv = False
54 | # Hidden layer list for the behavior mapper where each element is the number of neurons for that hidden layer of the mapper
55 | _config.model.hidden_layer_list_mapper = [20,20,20]
56 | # Activation function used in mapper layers
57 | _config.model.activation_mapper = 'tanh'
58 | # List of dimensions of behavior data to be decoded by mapper, check for any dimensionality mismatch
59 | _config.model.which_behv_dims = [0,1,2,3]
60 | # Boolean for whether to decode behavior from a_smooth
61 | _config.model.behv_from_smooth = True
62 | # Main save directory for DFINE results, plots and checkpoints
63 | _config.model.save_dir = 'D:/DATA/DFINE_results'
64 | # Number of steps to save DFINE checkpoints
65 | _config.model.save_steps = 10
66 |
67 | ## Dump loss related settings
68 | _config.loss = CN()
69 |
70 | # L2 regularization loss scale (we recommend a grid-search for the best value, i.e., a grid of [1e-4, 5e-4, 1e-3, 2e-3]). Please use 0 for nonlinear manifold simulations as it leads to a better performance.
71 | _config.loss.scale_l2 = 2e-3
72 | # List of number of steps ahead for which DFINE is optimized. For unsupervised and supervised versions, default values are [1,2,3,4] and [1,2], respectively.
73 | _config.loss.steps_ahead = [1,2,3,4]
74 | # If _config.model.supervise_behv is True, scale for MSE of behavior reconstruction (We recommend a grid-search for the best value. It should be set to a large value).
75 | _config.loss.scale_behv_recons = 20
76 |
77 | ## Dump training related settings
78 | _config.train = CN()
79 |
80 | # Batch size
81 | _config.train.batch_size = 32
82 | # Number of epochs for which DFINE is trained
83 | _config.train.num_epochs = 200
84 | # Number of steps to check validation data performance
85 | _config.train.valid_step = 1
86 | # Number of steps to save training/validation plots
87 | _config.train.plot_save_steps = 50
88 | # Number of steps to print training/validation logs
89 | _config.train.print_log_steps = 10
90 |
91 | ## Dump loading settings
92 | _config.load = CN()
93 |
94 | # Checkpoint number to load
95 | _config.load.ckpt = -1
96 | # Boolean for whether to resume training from the epoch where checkpoint is saved
97 | _config.load.resume_train = False
98 |
99 | ## Dump learning rate related settings
100 | _config.lr = CN()
101 |
102 | # Learning rate scheduler type, options are explr (StepLR, purely exponential if explr.step_size == 1), cyclic (CyclicLR) or constantlr (constant learning rate, no scheduling)
103 | _config.lr.scheduler = 'explr'
104 | # Initial learning rate
105 | _config.lr.init = 0.02
106 |
107 | # Dump cyclic LR scheduler related settings, check https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CyclicLR.html for details
108 | _config.lr.cyclic = CN()
109 | # Minimum learning rate for cyclic LR scheduler
110 | _config.lr.cyclic.base_lr = 0.005
111 | # Maximum learning rate for cyclic LR scheduler
112 | _config.lr.cyclic.max_lr = 0.02
113 | # Envelope scale for exponential cyclic LR scheduler mode
114 | _config.lr.cyclic.gamma = 1
115 | # Mode for cyclic LR scheduler
116 | _config.lr.cyclic.mode = 'triangular'
117 | # Number of iterations in the increasing half of the cycle
118 | _config.lr.cyclic.step_size_up = 10
119 |
120 | # Dump exponential LR scheduler related settings, check https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html for details
121 | _config.lr.explr = CN()
122 | # Multiplicative factor of learning rate decay
123 | _config.lr.explr.gamma = 0.9
124 | # Steps to decay the learning rate, becomes purely exponential if step is 1
125 | _config.lr.explr.step_size = 15
126 |
127 | ## Dump optimizer related settings
128 | _config.optim = CN()
129 |
130 | # Epsilon for Adam optimizer
131 | _config.optim.eps = 1e-8
132 | # Gradient clipping norm
133 | _config.optim.grad_clip = 1
134 |
135 |
136 | def get_default_config():
137 | '''
138 | Creates the default config
139 |
140 | Returns:
141 | ------------
142 | - config: yacs.config.CfgNode, default DFINE config
143 | '''
144 |
145 | return _config.clone()
146 |
147 |
148 | def update_config(config, new_config):
149 | '''
150 | Updates the config
151 |
152 | Parameters:
153 | ------------
154 | - config: yacs.config.CfgNode or dict, Config to update
155 | - new_config: yacs.config.CfgNode or dict, Config with new settings and appropriate keys
156 |
157 | Returns:
158 | ------------
159 | - unflattened_config: yacs.config.CfgNode, Config with updated settings
160 | '''
161 |
162 | # Flatten both configs
163 | flat_config = flatten_dict(config)
164 | flat_new_config = flatten_dict(new_config)
165 |
166 | # Update and unflatten the config to return
167 | flat_config.update(flat_new_config)
168 | unflattened_config = CN(unflatten_dict(flat_config))
169 |
170 | return unflattened_config
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
--------------------------------------------------------------------------------
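
A short sketch of how these helpers compose, assuming this file is importable as `config_dfine` and that `update_config` accepts nested dicts or CfgNodes (they are flattened, merged and unflattened as defined above):

```python
from config_dfine import get_default_config, update_config

# Start from the default config and override a few settings with a nested dict
config = get_default_config()
config = update_config(config, {'model': {'dim_y': 50, 'hidden_layer_list': [64, 64]},
                                'train': {'num_epochs': 100}})

print(config.model.dim_y, config.train.num_epochs)  # 50 100
```
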
/data/swiss_roll.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShanechiLab/torchDFINE/48a981259a795c3d1725a348c81a1552093724ac/data/swiss_roll.pt
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | from torch.utils.data import Dataset
9 | import torch
10 |
11 |
12 | class DFINEDataset(Dataset):
13 | '''
14 | Dataset class for DFINE.
15 | '''
16 |
17 | def __init__(self, y, behv=None, mask=None):
18 | '''
19 |         Initializer for DFINEDataset. Note that this is a subclass of torch.utils.data.Dataset.
20 |
21 | Parameters:
22 | ------------
23 | - y: torch.Tensor, shape: (num_seq, num_steps, dim_y), High dimensional neural observations.
24 | - behv: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Behavior data. None by default.
25 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask for manifold latent factors which shows whether
26 |                 observations at each timestep exist (1) or are missing (0).
27 | None by default.
28 | '''
29 |
30 | self.y = y
31 |
32 | # If behv is not provided, initialize it by zeros.
33 | if behv is None:
34 | self.behv = torch.zeros(y.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1)
35 | else:
36 | self.behv = behv
37 |
38 | # If mask is not provided, initialize it by ones.
39 | if mask is None:
40 | self.mask = torch.ones(y.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1)
41 | else:
42 | self.mask = mask
43 |
44 |
45 | def __len__(self):
46 | '''
47 | Returns the length of the dataset
48 | '''
49 |
50 | return self.y.shape[0]
51 |
52 |
53 | def __getitem__(self, idx):
54 | '''
55 | Returns a tuple of neural observations, behavior and mask segments
56 | '''
57 |
58 | return self.y[idx, :, :], self.behv[idx, :, :], self.mask[idx, :, :]
--------------------------------------------------------------------------------
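
A short usage sketch for trial-based data, pairing the dataset with a standard `torch.utils.data.DataLoader` and marking a few timesteps as missing via the mask; the shapes are illustrative.

```python
import torch
from torch.utils.data import DataLoader

from datasets import DFINEDataset

# 100 trials, 50 timesteps, 30-dimensional neural observations
y = torch.randn(100, 50, 30)

# Mark timesteps 10-19 of every trial as missing (1: observed, 0: missing)
mask = torch.ones(100, 50, 1)
mask[:, 10:20, :] = 0

dataset = DFINEDataset(y=y, mask=mask)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

y_batch, behv_batch, mask_batch = next(iter(loader))
print(y_batch.shape, mask_batch.shape)  # torch.Size([32, 50, 30]) torch.Size([32, 50, 1])
```
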
/metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | from torchmetrics import Metric
9 | import torch
10 |
11 |
12 | class Mean(Metric):
13 | '''
14 | Mean metric class to log batch-averaged metrics to Tensorboard.
15 | '''
16 |
17 | def __init__(self):
18 | '''
19 | Initializer for Mean metric. Note that this class is a subclass of torchmetrics.Metric.
20 | '''
21 |
22 | super().__init__(dist_sync_on_step=False)
23 |
24 | # Define total sum and number of samples that sum is computed over
25 | self.add_state("sum", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")
26 | self.add_state("num_samples", default=torch.tensor(0, dtype=torch.float32), dist_reduce_fx="sum")
27 |
28 |
29 | def update(self, value, batch_size):
30 | '''
31 | Updates the total sum and number of samples
32 |
33 | Parameters:
34 | ------------
35 | - value: torch.Tensor, shape: (), Value to add to sum
36 | - batch_size: torch.Tensor, shape: (), Number of samples that 'value' is averaged over
37 | '''
38 |
39 | value = value.clone().detach()
40 | batch_size = torch.tensor(batch_size, dtype=torch.float32)
41 | self.sum += value.cpu() * batch_size
42 | self.num_samples += batch_size
43 |
44 |
45 | def reset(self):
46 | '''
47 | Resets the total sum and number of samples to 0
48 | '''
49 |
50 | self.sum = torch.tensor(0, dtype=torch.float32)
51 | self.num_samples = torch.tensor(0, dtype=torch.float32)
52 |
53 |
54 | def compute(self):
55 | '''
56 | Computes the mean metric.
57 |
58 | Returns:
59 | ------------
60 | - avg: Average value for the metric
61 | '''
62 |
63 | avg = self.sum / self.num_samples
64 | return avg
65 |
--------------------------------------------------------------------------------
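
A small sketch of the intended bookkeeping pattern, assuming batch-averaged loss values and the torchmetrics version pinned in requirements.txt:

```python
import torch

from metrics import Mean

mean_loss = Mean()

# Accumulate batch-averaged losses, weighted by their batch sizes
for loss, batch_size in [(torch.tensor(0.5), 32), (torch.tensor(0.4), 32), (torch.tensor(0.1), 16)]:
    mean_loss.update(loss, batch_size)

print(mean_loss.compute())  # weighted mean: (32*0.5 + 32*0.4 + 16*0.1) / 80 = 0.38
mean_loss.reset()           # start fresh, e.g. for the next epoch
```
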
/modules/LDM.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | class LDM(nn.Module):
13 | '''
14 | Linear Dynamical Model backbone for DFINE. This module is used for smoothing and filtering
15 | given a batch of trials/segments/time-series.
16 |
17 | LDM equations are as follows:
18 | x_{t+1} = Ax_{t} + w_{t}; cov(w_{t}) = W
19 | a_{t} = Cx_{t} + r_{t}; cov(r_{t}) = R
20 | '''
21 |
22 | def __init__(self, **kwargs):
23 | '''
24 | Initializer for an LDM object. Note that LDM is a subclass of torch.nn.Module.
25 |
26 | Parameters
27 | ------------
28 | - dim_x: int, Dimensionality of dynamic latent factors, default None
29 | - dim_a: int, Dimensionality of manifold latent factors, default None
30 | - is_W_trainable: bool, Whether dynamics noise covariance matrix (W) is learnt or not, default True
31 | - is_R_trainable: bool, Whether observation noise covariance matrix (R) is learnt or not, default True
32 | - A: torch.Tensor, shape: (self.dim_x, self.dim_x), State transition matrix of LDM, default identity
33 | - C: torch.Tensor, shape: (self.dim_a, self.dim_x), Observation matrix of LDM, default identity
34 | - mu_0: torch.Tensor, shape: (self.dim_x, ), Dynamic latent factor estimate initial condition (x_{0|-1}) for Kalman filtering, default zeros
35 | - Lambda_0: torch.Tensor, shape: (self.dim_x, self.dim_x), Dynamic latent factor estimate error covariance initial condition (P_{0|-1}) for Kalman Filtering, default identity
36 | - W_log_diag: torch.Tensor, shape: (self.dim_x, ), Log-diagonal of process noise covariance matrix (W, therefore it is diagonal and PSD), default ones
37 | - R_log_diag: torch.Tensor, shape: (self.dim_a, ), Log-diagonal of observation noise covariance matrix (R, therefore it is diagonal and PSD), default ones
38 | '''
39 |
40 | super(LDM, self).__init__()
41 |
42 | self.dim_x = kwargs.pop('dim_x', None)
43 | self.dim_a = kwargs.pop('dim_a', None)
44 |
45 | self.is_W_trainable = kwargs.pop('is_W_trainable', True)
46 | self.is_R_trainable = kwargs.pop('is_R_trainable', True)
47 |
48 |         # Initializers for identity, zeros and ones matrices
49 | self.eye_init = lambda shape, dtype=torch.float32: torch.eye(*shape, dtype=dtype)
50 | self.zeros_init = lambda shape, dtype=torch.float32: torch.zeros(*shape, dtype=dtype)
51 | self.ones_init = lambda shape, dtype=torch.float32: torch.ones(*shape, dtype=dtype)
52 |
53 | # Get initial values for LDM parameters
54 |         self.A = kwargs.pop('A', self.eye_init((self.dim_x, self.dim_x), dtype=torch.float32)).type(torch.FloatTensor)
55 |         self.C = kwargs.pop('C', self.eye_init((self.dim_a, self.dim_x), dtype=torch.float32)).type(torch.FloatTensor)
56 |
57 | # Get KF initial conditions
58 | self.mu_0 = kwargs.pop('mu_0', self.zeros_init((self.dim_x, ), dtype=torch.float32)).type(torch.FloatTensor)
59 | self.Lambda_0 = kwargs.pop('Lambda_0', self.eye_init((self.dim_x, self.dim_x), dtype=torch.float32)).type(torch.FloatTensor)
60 |
61 | # Get initial process and observation noise parameters
62 | self.W_log_diag = kwargs.pop('W_log_diag', self.ones_init((self.dim_x, ), dtype=torch.float32)).type(torch.FloatTensor)
63 | self.R_log_diag = kwargs.pop('R_log_diag', self.ones_init((self.dim_a, ), dtype=torch.float32)).type(torch.FloatTensor)
64 |
65 | # Register trainable parameters to module
66 | self._register_params()
67 |
68 |
69 | def _register_params(self):
70 | '''
71 | Registers the learnable LDM parameters as nn.Parameters
72 | '''
73 |
74 | # Check if LDM matrix shapes are consistent
75 | self._check_matrix_shapes()
76 |
77 | # Register LDM parameters
78 | self.A = torch.nn.Parameter(self.A, requires_grad=True)
79 | self.C = torch.nn.Parameter(self.C, requires_grad=True)
80 |
81 | self.W_log_diag = torch.nn.Parameter(self.W_log_diag, requires_grad=self.is_W_trainable)
82 | self.R_log_diag = torch.nn.Parameter(self.R_log_diag, requires_grad=self.is_R_trainable)
83 |
84 | self.mu_0 = torch.nn.Parameter(self.mu_0, requires_grad=True)
85 | self.Lambda_0 = torch.nn.Parameter(self.Lambda_0, requires_grad=True)
86 |
87 |
88 | def _check_matrix_shapes(self):
89 | '''
90 | Checks whether LDM parameters have the correct shapes, which are defined above in the constructor
91 | '''
92 |
93 | # Check A matrix's shape
94 | if self.A.shape != (self.dim_x, self.dim_x):
95 | assert False, 'Shape of A matrix is not (dim_x, dim_x)!'
96 |
97 | # Check C matrix's shape
98 | if self.C.shape != (self.dim_a, self.dim_x):
99 | assert False, 'Shape of C matrix is not (dim_a, dim_x)!'
100 |
101 | # Check mu_0 matrix's shape
102 | if len(self.mu_0.shape) != 1:
103 | self.mu_0 = self.mu_0.view(-1, )
104 |
105 | if self.mu_0.shape != (self.dim_x, ):
106 | assert False, 'Shape of mu_0 matrix is not (dim_x, )!'
107 |
108 | # Check Lambda_0 matrix's shape
109 | if self.Lambda_0.shape != (self.dim_x, self.dim_x):
110 | assert False, 'Shape of Lambda_0 matrix is not (dim_x, dim_x)!'
111 |
112 | # Check W_log_diag matrix's shape
113 | if len(self.W_log_diag.shape) != 1:
114 | self.W_log_diag = self.W_log_diag.view(-1, )
115 |
116 | if self.W_log_diag.shape != (self.dim_x, ):
117 | assert False, 'Shape of W_log_diag matrix is not (dim_x, )!'
118 |
119 | # Check R_log_diag matrix's shape
120 | if len(self.R_log_diag.shape) != 1:
121 | self.R_log_diag = self.R_log_diag.view(-1, )
122 |
123 | if self.R_log_diag.shape != (self.dim_a, ):
124 |             assert False, 'Shape of R_log_diag matrix is not (dim_a, )!'
125 |
126 |
127 | def _get_covariance_matrices(self):
128 | '''
129 | Get the process and observation noise covariance matrices from log-diagonals.
130 |
131 | Returns:
132 | ------------
133 | - W: torch.Tensor, shape: (self.dim_x, self.dim_x), Process noise covariance matrix
134 | - R: torch.Tensor, shape: (self.dim_a, self.dim_a), Observation noise covariance matrix
135 | '''
136 |
137 | W = torch.diag(torch.exp(self.W_log_diag))
138 | R = torch.diag(torch.exp(self.R_log_diag))
139 | return W, R
140 |
141 |
142 | def compute_forwards(self, a, mask=None):
143 | '''
144 | Performs the forward iteration of causal flexible Kalman filtering, given a batch of trials/segments/time-series
145 |
146 | Parameters:
147 | ------------
148 | - a: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors (outputs of encoder; nonlinear manifold embedding step)
149 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether
150 |                 observations at each timestep exist (1) or are missing (0)
151 |
152 | Returns:
153 | ------------
154 |         - mu_pred_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor predictions (t+1|t) where the first time index has x_{1|0}
155 |         - mu_t_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor filtered estimates (t|t) where the first time index has x_{0|0}
156 |         - Lambda_pred_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where the first time index has P_{1|0}
157 |         - Lambda_t_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where the first time index has P_{0|0}
158 | '''
159 |
160 | if mask is None:
161 |             mask = torch.ones(a.shape[:-1], dtype=torch.float32, device=a.device)
162 |
163 | num_seq, num_steps, _ = a.shape
164 |
165 | # Make sure that mask is 3D (last axis is 1-dimensional)
166 | if len(mask.shape) != len(a.shape):
167 | mask = mask.unsqueeze(dim=-1) # (num_seq, num_steps, 1)
168 |
169 | # To make sure we do not accidentally use the real outputs in the steps with missing values, set them to a dummy value, e.g., 0.
170 | # The dummy values of observations at masked points are irrelevant because:
171 | # Kalman disregards the observations by setting Kalman Gain to 0 in K = torch.mul(K, mask[:, t, ...].unsqueeze(dim=1)) @ line 204
172 | a_masked = torch.mul(a, mask) # (num_seq, num_steps, dim_a) x (num_seq, num_steps, 1)
173 |
174 | # Initialize mu_0 and Lambda_0
175 | mu_0 = self.mu_0.unsqueeze(dim=0).repeat(num_seq, 1) # (num_seq, dim_x)
176 | Lambda_0 = self.Lambda_0.unsqueeze(dim=0).repeat(num_seq, 1, 1) # (num_seq, dim_x, dim_x)
177 |
178 | mu_pred = mu_0 # (num_seq, dim_x)
179 | Lambda_pred = Lambda_0 # (num_seq, dim_x, dim_x)
180 |
181 | # Create empty arrays for filtered and predicted estimates, NOTE: The last time-step of the prediction has T+1|T, which may not be of interest
182 | mu_pred_all = torch.zeros((num_steps, num_seq, self.dim_x), dtype=torch.float32, device=mu_0.device)
183 | mu_t_all = torch.zeros((num_steps, num_seq, self.dim_x), dtype=torch.float32, device=mu_0.device)
184 |
185 | # Create empty arrays for filtered and predicted error covariance, NOTE: The last time-step of the prediction has T+1|T, which may not be of interest
186 | Lambda_pred_all = torch.zeros((num_steps, num_seq, self.dim_x, self.dim_x), dtype=torch.float32, device=mu_0.device)
187 | Lambda_t_all = torch.zeros((num_steps, num_seq, self.dim_x, self.dim_x), dtype=torch.float32, device=mu_0.device)
188 |
189 | # Get covariance matrices
190 | W, R = self._get_covariance_matrices()
191 |
192 | for t in range(num_steps):
193 | # Tile C matrix for each time segment
194 | C_t = self.C.repeat(num_seq, 1, 1)
195 |
196 | # Obtain residual
197 | a_pred = (C_t @ mu_pred.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, dim_a)
198 | r = a_masked[:, t, ...] - a_pred # (num_seq, dim_a)
199 |
200 | # Project system uncertainty into measurement space, get Kalman Gain
201 |             S = C_t @ Lambda_pred @ torch.permute(C_t, (0, 2, 1)) + R # (num_seq, dim_a, dim_a)
202 |             S_inv = torch.inverse(S) # (num_seq, dim_a, dim_a)
203 | K = Lambda_pred @ torch.permute(C_t, (0, 2, 1)) @ S_inv # (num_seq, dim_x, dim_a)
204 | K = torch.mul(K, mask[:, t, ...].unsqueeze(dim=1)) # (num_seq, dim_x, dim_a) x (num_seq, 1, 1)
205 |
206 | # Get current mu and Lambda
207 | mu_t = mu_pred + (K @ r.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, dim_x)
208 | I_KC = torch.eye(self.dim_x, dtype=torch.float32, device=mu_0.device) - K @ C_t # (num_seq, dim_x, dim_x)
209 | Lambda_t = I_KC @ Lambda_pred # (num_seq, dim_x, dim_x)
210 |
211 | # Tile A matrix for each time segment
212 | A_t = self.A.repeat(num_seq, 1, 1) # (num_seq, dim_x, dim_x)
213 |
214 | # Prediction
215 | mu_pred = (A_t @ mu_t.unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, dim_x, dim_x) x (num_seq, dim_x, 1) --> (num_seq, dim_x, 1) --> (num_seq, dim_x)
216 | Lambda_pred = A_t @ Lambda_t @ torch.permute(A_t, (0, 2, 1)) + W # (num_seq, dim_x, dim_x) x (num_seq, dim_x, dim_x) x (num_seq, dim_x, dim_x) --> (num_seq, dim_x, dim_x)
217 |
218 | # Keep predictions and updates
219 | mu_pred_all[t, ...] = mu_pred
220 | mu_t_all[t, ...] = mu_t
221 |
222 | Lambda_pred_all[t, ...] = Lambda_pred
223 | Lambda_t_all[t, ...] = Lambda_t
224 |
225 | return mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all
226 |
227 |
228 | def filter(self, a, mask=None):
229 | '''
230 | Performs Kalman Filtering
231 |
232 | Parameters:
233 | ------------
234 | - a: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors (outputs of encoder; nonlinear manifold embedding step)
235 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether
236 |                 observations at each timestep exist (1) or are missing (0)
237 |
238 | Returns:
239 | ------------
240 | - mu_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor predictions (t+1|t) where first index of the second dimension has x_{1|0}
241 | - mu_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor filtered estimates (t|t) where first index of the second dimension has x_{0|0}
242 | - Lambda_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where first index of the second dimension has P_{1|0}
243 | - Lambda_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where first index of the second dimension has P_{0|0}
244 | '''
245 |
246 | # Run the forward iteration
247 | mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all = self.compute_forwards(a=a, mask=mask)
248 |
249 |         # Swap num_seq and num_steps dimensions
250 | mu_pred_all = torch.permute(mu_pred_all, (1, 0, 2))
251 | mu_t_all = torch.permute(mu_t_all, (1, 0, 2))
252 | Lambda_pred_all = torch.permute(Lambda_pred_all, (1, 0, 2, 3))
253 | Lambda_t_all = torch.permute(Lambda_t_all, (1, 0, 2, 3))
254 |
255 | return mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all
256 |
257 |
258 | def compute_backwards(self, mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all):
259 | '''
260 | Performs backward iteration for Rauch-Tung-Striebel (RTS) Smoother
261 |
262 | Parameters:
263 | ------------
264 |         - mu_pred_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor predictions (t+1|t) where the first time index has x_{1|0}
265 |         - mu_t_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor filtered estimates (t|t) where the first time index has x_{0|0}
266 |         - Lambda_pred_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where the first time index has P_{1|0}
267 |         - Lambda_t_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where the first time index has P_{0|0}
268 |
269 | Returns:
270 | ------------
271 |         - mu_back_all: torch.Tensor, shape: (num_steps, num_seq, dim_x), Dynamic latent factor smoothed estimates (t|T) where the first time index has x_{0|T}
272 |         - Lambda_back_all: torch.Tensor, shape: (num_steps, num_seq, dim_x, dim_x), Dynamic latent factor estimation error covariance smoothed estimates (t|T) where the first time index has P_{0|T}
273 | '''
274 |
275 | # Get number of steps and number of trials
276 | num_steps, num_seq, _ = mu_pred_all.shape
277 |
278 | # Create empty arrays for smoothed dynamic latent factors and error covariances
279 | mu_back_all = torch.zeros((num_steps, num_seq, self.dim_x), dtype=torch.float32, device=mu_pred_all.device) # (num_steps, num_seq, dim_x)
280 | Lambda_back_all = torch.zeros((num_steps, num_seq, self.dim_x, self.dim_x), dtype=torch.float32, device=mu_pred_all.device) # (num_steps, num_seq, dim_x, dim_x)
281 |
282 | # Last smoothed estimation is equivalent to the filtered estimation
283 | mu_back_all[-1, ...] = mu_t_all[-1, ...]
284 | Lambda_back_all[-1, ...] = Lambda_t_all[-1, ...]
285 |
286 | # Initialize iterable parameter
287 | mu_back = mu_t_all[-1, ...]
288 | Lambda_back = Lambda_back_all[-1, ...]
289 |
290 | for t in range(num_steps-2, -1, -1): # iterate loop over reverse time: T-2, T-3, ..., 0, where the last time-step is T-1
291 | A_t = self.A.repeat(num_seq, 1, 1)
292 | J_t = Lambda_t_all[t, ...] @ torch.permute(A_t, (0, 2, 1)) @ torch.inverse(Lambda_pred_all[t, ...]) # (num_seq, dim_x, dim_x) x (num_seq, dim_x, dim_x) x (num_seq, dim_x, dim_x)
293 | mu_back = mu_t_all[t, ...] + (J_t @ (mu_back - mu_pred_all[t, ...]).unsqueeze(dim=-1)).squeeze(dim=-1) # (num_seq, dim_x) + (num_seq, dim_x, dim_x) x (num_seq, dim_x)
294 |
295 | Lambda_back = Lambda_t_all[t, ...] + J_t @ (Lambda_back - Lambda_pred_all[t, ...]) @ torch.permute(J_t, (0, 2, 1)) # (num_seq, dim_x, dim_x)
296 |
297 | mu_back_all[t, ...] = mu_back
298 | Lambda_back_all[t, ...] = Lambda_back
299 |
300 | return mu_back_all, Lambda_back_all
301 |
302 |
303 | def smooth(self, a, mask=None):
304 | '''
305 | Performs Rauch-Tung-Striebel (RTS) Smoothing
306 |
307 | Parameters:
308 | ------------
309 | - a: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors (outputs of encoder; nonlinear manifold embedding step)
310 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether
311 | observations at each timestep exist (1) or are missing (0)
312 |
313 | Returns:
314 | ------------
315 | - mu_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor predictions (t+1|t) where first index of the second dimension has x_{1|0}
316 | - mu_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor filtered estimates (t|t) where first index of the second dimension has x_{0|0}
317 | - mu_back_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor smoothed estimates (t|T) where first index of the second dimension has x_{0|T}
318 | - Lambda_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where first index of the second dimension has P_{1|0}
319 | - Lambda_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where first index of the second dimension has P_{0|0}
320 | - Lambda_back_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance smoothed estimates (t|T) where first index of the second dimension has P_{0|T}
321 | '''
322 |
323 | mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all = self.compute_forwards(a=a, mask=mask)
324 | mu_back_all, Lambda_back_all = self.compute_backwards(mu_pred_all=mu_pred_all,
325 | mu_t_all=mu_t_all,
326 | Lambda_pred_all=Lambda_pred_all,
327 | Lambda_t_all=Lambda_t_all)
328 |
329 | # Swap num_seq and num_steps dimensions
330 | mu_pred_all = torch.permute(mu_pred_all, (1, 0, 2))
331 | mu_t_all = torch.permute(mu_t_all, (1, 0, 2))
332 | mu_back_all = torch.permute(mu_back_all, (1, 0, 2))
333 |
334 | Lambda_pred_all = torch.permute(Lambda_pred_all, (1, 0, 2, 3))
335 | Lambda_t_all = torch.permute(Lambda_t_all, (1, 0, 2, 3))
336 | Lambda_back_all = torch.permute(Lambda_back_all, (1, 0, 2, 3))
337 |
338 | return mu_pred_all, mu_t_all, mu_back_all, Lambda_pred_all, Lambda_t_all, Lambda_back_all
339 |
340 |
341 | def forward(self, a, mask=None, do_smoothing=False):
342 | '''
343 | Forward pass function for LDM Module
344 |
345 | Parameters:
346 | ------------
347 | - a: torch.Tensor, shape: (num_seq, num_steps, dim_a), Batch of projected manifold latent factors (outputs of encoder; nonlinear manifold embedding step)
348 | - mask: torch.Tensor, shape: (num_seq, num_steps, 1), Mask input which shows whether
349 | observations at each timestep exist (1) or are missing (0)
350 | - do_smoothing: bool, Whether to run RTS smoothing or not
351 |
352 | Returns:
353 | ------------
354 | - mu_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor predictions (t+1|t) where first index of the second dimension has x_{1|0}
355 | - mu_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor filtered estimates (t|t) where first index of the second dimension has x_{0|0}
356 | - mu_back_all: torch.Tensor, shape: (num_seq, num_steps, dim_x), Dynamic latent factor smoothed estimates (t|T) where first index of the second dimension has x_{0|T}; an all-ones placeholder tensor if do_smoothing is False
357 | - Lambda_pred_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance predictions (t+1|t) where first index of the second dimension has P_{1|0}
358 | - Lambda_t_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance filtered estimates (t|t) where first index of the second dimension has P_{0|0}
359 | - Lambda_back_all: torch.Tensor, shape: (num_seq, num_steps, dim_x, dim_x), Dynamic latent factor estimation error covariance smoothed estimates (t|T) where first index of the second dimension has P_{0|T}; an all-ones placeholder tensor if do_smoothing is False
360 | '''
361 |
362 | if do_smoothing:
363 | mu_pred_all, mu_t_all, mu_back_all, Lambda_pred_all, Lambda_t_all, Lambda_back_all = self.smooth(a=a, mask=mask)
364 | else:
365 | mu_pred_all, mu_t_all, Lambda_pred_all, Lambda_t_all = self.filter(a=a, mask=mask)
366 | mu_back_all = torch.ones_like(mu_t_all, dtype=torch.float32, device=mu_t_all.device)
367 | Lambda_back_all = torch.ones_like(Lambda_t_all, dtype=torch.float32, device=Lambda_t_all.device)
368 |
369 | return mu_pred_all, mu_t_all, mu_back_all, Lambda_pred_all, Lambda_t_all, Lambda_back_all
370 |
371 |
372 |
373 |
374 |
375 |
--------------------------------------------------------------------------------
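
Note: compute_backwards above implements the standard Rauch-Tung-Striebel recursions J_t = P_{t|t} A^T P_{t+1|t}^{-1}, x_{t|T} = x_{t|t} + J_t (x_{t+1|T} - x_{t+1|t}) and P_{t|T} = P_{t|t} + J_t (P_{t+1|T} - P_{t+1|t}) J_t^T. Below is a minimal standalone sketch of the same backward pass for a single sequence; the dimensions and the stand-ins for the filter outputs are illustrative placeholders, not values from this repository.

import torch

# Illustrative dimensions and a stable state-transition matrix (placeholders)
T, dim_x = 10, 3
A = 0.9 * torch.eye(dim_x)

# Stand-ins for forward-pass (filter) outputs: x_{t|t}, x_{t+1|t}, P_{t|t}, P_{t+1|t}
mu_t = torch.randn(T, dim_x)
mu_pred = (A @ mu_t.unsqueeze(-1)).squeeze(-1)
Lambda_t = torch.eye(dim_x).repeat(T, 1, 1)
Lambda_pred = A @ Lambda_t @ A.T + 0.1 * torch.eye(dim_x)

# Backward recursion, initialized at the filtered estimate for t = T-1
mu_back, Lambda_back = mu_t[-1], Lambda_t[-1]
for t in range(T - 2, -1, -1):
    J_t = Lambda_t[t] @ A.T @ torch.inverse(Lambda_pred[t])                   # smoother gain
    mu_back = mu_t[t] + J_t @ (mu_back - mu_pred[t])                          # x_{t|T}
    Lambda_back = Lambda_t[t] + J_t @ (Lambda_back - Lambda_pred[t]) @ J_t.T  # P_{t|T}
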
/modules/MLP.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | import torch.nn as nn
9 |
10 |
11 | class MLP(nn.Module):
12 | '''
13 | MLP Module for the DFINE encoder and decoder, as well as the behavior mapper for supervised DFINE.
14 | The encoder maps high-dimensional neural observations to the low-dimensional manifold latent factor space,
15 | and the decoder maps manifold latent factors back to high-dimensional neural observations.
16 | '''
17 |
18 | def __init__(self, **kwargs):
19 | '''
20 | Initializer for an Encoder/Decoder/Mapper object. Note that Encoder/Decoder/Mapper is a subclass of torch.nn.Module.
21 |
22 | Parameters
23 | ------------
24 | input_dim: int, Dimensionality of inputs to the MLP, default None
25 | output_dim: int, Dimensionality of outputs of the MLP, default None
26 | layer_list: list, List of number of neurons in each hidden layer, default None
27 | kernel_initializer_fn: torch.nn.init, Hidden layer weight initialization function, default nn.init.xavier_normal_
28 | activation_fn: torch.nn, Activation function of neurons (a module instance), default nn.Tanh()
29 | '''
30 |
31 | super(MLP, self).__init__()
32 |
33 | self.input_dim = kwargs.pop('input_dim', None)
34 | self.output_dim = kwargs.pop('output_dim', None)
35 | self.layer_list = kwargs.pop('layer_list', None)
36 | self.kernel_initializer_fn = kwargs.pop('kernel_initializer_fn', nn.init.xavier_normal_)
37 | self.activation_fn = kwargs.pop('activation_fn', nn.Tanh())  # default must be a module instance, since forward calls it directly
38 |
39 | # Create the ModuleList to stack the hidden layers
40 | self.layers = nn.ModuleList()
41 |
42 | # Create the hidden layers and initialize their weights based on desired initialization function
43 | current_dim = self.input_dim
44 | for i, dim in enumerate(self.layer_list):
45 | self.layers.append(nn.Linear(current_dim, dim))
46 | self.kernel_initializer_fn(self.layers[i].weight)
47 | current_dim = dim
48 |
49 | # Create the output layer and initialize its weights based on the desired initialization function
50 | self.out_layer = nn.Linear(current_dim, self.output_dim)
51 | self.kernel_initializer_fn(self.out_layer.weight)
52 |
53 |
54 | def forward(self, inp):
55 | '''
56 | Forward pass function for MLP Module
57 |
58 | Parameters:
59 | ------------
60 | inp: torch.Tensor, shape: (num_seq * num_steps, input_dim), Flattened batch of inputs
61 |
62 | Returns:
63 | ------------
64 | out: torch.Tensor, shape: (num_seq * num_steps, output_dim), Flattened batch of outputs
65 | '''
66 |
67 | # Push neural observations through each hidden layer and its activation
68 | for layer in self.layers:
69 | inp = layer(inp)
70 | inp = self.activation_fn(inp)
71 |
72 | # Obtain the output
73 | out = self.out_layer(inp)
74 | return out
--------------------------------------------------------------------------------
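
A minimal usage sketch for the module above (all sizes are illustrative): inputs are expected flattened to (num_seq * num_steps, input_dim), and activation_fn should be a module instance such as nn.Tanh().

import torch
import torch.nn as nn
from modules.MLP import MLP

# Hypothetical sizes: 30-dim observations mapped to a 2-dim manifold latent space
encoder = MLP(input_dim=30, output_dim=2, layer_list=[64, 64], activation_fn=nn.Tanh())

y_flat = torch.randn(5 * 100, 30)  # (num_seq * num_steps, dim_y)
a_flat = encoder(y_flat)           # (num_seq * num_steps, dim_a)
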
/nn.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | def compute_mse(y_flat, y_hat_flat, mask_flat=None):
13 | '''
14 | Returns the average Mean Squared Error (MSE)
15 |
16 | Parameters:
17 | ------------
18 | - y_flat: torch.Tensor, shape: (num_samp, dim_y), True data to compute MSE of
19 | - y_hat_flat: torch.Tensor, shape: (num_samp, dim_y), Predicted/Reconstructed data to compute MSE of
20 | - mask_flat: torch.Tensor, shape: (num_samp, 1), Mask to compute MSE loss which shows whether
21 | observations at each timestep exist (1) or are missing (0)
22 |
23 | Returns:
24 | ------------
25 | - mse: torch.Tensor, Average MSE
26 | '''
27 |
28 | if mask_flat is None:
29 | mask_flat = torch.ones(y_flat.shape[:-1], dtype=torch.float32, device=y_flat.device)  # create the mask on the same device as y_flat
30 |
31 | # Make sure mask is 2D
32 | if len(mask_flat.shape) != len(y_flat.shape):
33 | mask_flat = mask_flat.unsqueeze(dim=-1)
34 |
35 | # Compute the MSEs and mask the timesteps where observations are missing
36 | mse = (y_flat - y_hat_flat) ** 2
37 | mse = torch.mul(mask_flat, mse)
38 |
39 | # Return the mean of the mse (over available observations)
40 | if mask_flat.shape[-1] != y_flat.shape[-1]: # i.e., mask_flat has a singleton last dimension
41 | num_el = mask_flat.sum() * y_flat.shape[-1]
42 | else:
43 | num_el = mask_flat.sum()
44 |
45 | mse = mse.sum() / num_el
46 | return mse
47 |
48 |
49 | def get_activation_function(activation_str):
50 | '''
51 | Returns activation function given the activation function's name
52 |
53 | Parameters:
54 | ----------------------
55 | - activation_str: str, Activation function's name
56 |
57 | Returns:
58 | ----------------------
59 | - activation_fn: torch.nn, Activation function
60 | '''
61 |
62 | if activation_str.lower() == 'elu':
63 | return nn.ELU()
64 | elif activation_str.lower() == 'hardtanh':
65 | return nn.Hardtanh()
66 | elif activation_str.lower() == 'leakyrelu':
67 | return nn.LeakyReLU()
68 | elif activation_str.lower() == 'relu':
69 | return nn.ReLU()
70 | elif activation_str.lower() == 'rrelu':
71 | return nn.RReLU()
72 | elif activation_str.lower() == 'sigmoid':
73 | return nn.Sigmoid()
74 | elif activation_str.lower() == 'mish':
75 | return nn.Mish()
76 | elif activation_str.lower() == 'tanh':
77 | return nn.Tanh()
78 | elif activation_str.lower() == 'tanhshrink':
79 | return nn.Tanhshrink()
80 | elif activation_str.lower() == 'linear':
81 | return lambda x: x
82 |
83 | def get_kernel_initializer_function(kernel_initializer_str):
84 | '''
85 | Returns kernel initialization function given the kernel initialization function's name
86 |
87 | Parameters:
88 | ----------------------
89 | - kernel_initializer_str: str, Kernel initialization function's name
90 |
91 | Returns:
92 | ----------------------
93 | - kernel_initializer_fn: torch.nn.init, Kernel initialization function
94 | '''
95 |
96 | if kernel_initializer_str.lower() == 'uniform':
97 | return nn.init.uniform_
98 | elif kernel_initializer_str.lower() == 'normal':
99 | return nn.init.normal_
100 | elif kernel_initializer_str.lower() == 'xavier_uniform':
101 | return nn.init.xavier_uniform_
102 | elif kernel_initializer_str.lower() == 'xavier_normal':
103 | return nn.init.xavier_normal_
104 | elif kernel_initializer_str.lower() == 'kaiming_uniform':
105 | return nn.init.kaiming_uniform_
106 | elif kernel_initializer_str.lower() == 'kaiming_normal':
107 | return nn.init.kaiming_normal_
108 | elif kernel_initializer_str.lower() == 'orthogonal':
109 | return nn.init.orthogonal_
110 | elif kernel_initializer_str.lower() == 'default':
111 | return lambda x: x
--------------------------------------------------------------------------------
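
A quick sketch of the helpers above (values are illustrative): compute_mse averages only over timesteps marked as observed in the mask, and the two getter functions resolve config strings to callables.

import torch
import torch.nn as nn
from nn import compute_mse, get_activation_function, get_kernel_initializer_function

y = torch.randn(8, 4)      # (num_samp, dim_y)
y_hat = torch.randn(8, 4)
mask = torch.ones(8, 1)
mask[2] = 0                # mark one timestep as missing

mse = compute_mse(y, y_hat, mask)  # averaged over the 7 observed timesteps only

activation_fn = get_activation_function('tanh')                           # -> nn.Tanh()
kernel_initializer_fn = get_kernel_initializer_function('xavier_normal')  # -> nn.init.xavier_normal_
kernel_initializer_fn(nn.Linear(4, 4).weight)                             # initializes in place
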
/python_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | import torch
9 | import numpy as np
10 |
11 |
12 | def carry_to_device(data, device, dtype=torch.float32):
13 | '''
14 | Carries dict/list of torch Tensors/numpy arrays to desired device recursively
15 |
16 | Parameters:
17 | ------------
18 | - data: torch.Tensor/np.ndarray/dict/list: Dictionary/list of torch Tensors/numpy arrays or torch Tensor/numpy array to be carried to desired device
19 | - device: str, Device name to carry the torch Tensors/numpy arrays to
20 | - dtype: torch.dtype, Data type for torch.Tensor to be returned, torch.float32 by default
21 |
22 | Returns:
23 | ------------
24 | - data: torch.Tensor/dict/list: Dictionary/list of torch.Tensors or torch Tensor carried to desired device
25 | '''
26 |
27 | if torch.is_tensor(data):
28 | return data.to(device)
29 |
30 | elif isinstance(data, np.ndarray):
31 | return torch.tensor(data, dtype=dtype).to(device)
32 |
33 | elif isinstance(data, dict):
34 | for key in data.keys():
35 | data[key] = carry_to_device(data[key], device, dtype)  # propagate dtype through the recursion
36 | return data
37 |
38 | elif isinstance(data, list):
39 | for i, d in enumerate(data):
40 | data[i] = carry_to_device(d, device, dtype)  # propagate dtype through the recursion
41 | return data
42 |
43 | else:
44 | return data
45 |
46 |
47 | def convert_to_tensor(x, dtype=torch.float32):
48 | '''
49 | Converts numpy.ndarray to torch.Tensor
50 |
51 | Parameters:
52 | ------------
53 | - x: np.ndarray, Numpy array to convert to torch.Tensor (if it's of type torch.Tensor already, it's returned without conversion)
54 | - dtype: torch.dtype, Data type for torch.Tensor to be returned, torch.float32 by default
55 |
56 | Returns:
57 | ------------
58 | - y: torch.Tensor, Converted tensor
59 | '''
60 |
61 | if isinstance(x, torch.Tensor):
62 | y = x
63 | elif isinstance(x, np.ndarray):
64 | y = torch.tensor(x, dtype=dtype) # use np.ndarray as an intermediate step so that the function works with tf tensors as well
65 | else:
66 | assert False, 'Only np.ndarray or torch.Tensor inputs can be converted to tensor'
67 | return y
68 |
69 |
70 | def flatten_dict(dictionary, level=[]):
71 | '''
72 | Flattens nested dictionary by putting '.' between nested keys, reference: https://stackoverflow.com/questions/6037503/python-unflatten-dict
73 |
74 | Parameters:
75 | ------------
76 | - dictionary: dict, Nested dictionary to be flattened
77 | - level: list, List of strings for recursion, initialized by empty list
78 |
79 | Returns:
80 | ------------
81 | - tmp_dict: dict, Flattened dictionary
82 | '''
83 |
84 | tmp_dict = {}
85 | for key, val in dictionary.items():
86 | if isinstance(val, dict):
87 | tmp_dict.update(flatten_dict(val, level + [key]))
88 | else:
89 | tmp_dict['.'.join(level + [key])] = val
90 | return tmp_dict
91 |
92 |
93 | def unflatten_dict(dictionary):
94 | '''
95 | Unflattens a flattened dictionary whose keys are joint string of nested keys separated by '.', reference: https://stackoverflow.com/questions/6037503/python-unflatten-dict
96 |
97 | Parameters:
98 | ------------
99 | - dictionary: dict, Flat dictionary to be unflattened
100 |
101 | Returns:
102 | ------------
103 | - resultDict: dict, Unflattened dictionary
104 | '''
105 |
106 | resultDict = dict()
107 | for key, value in dictionary.items():
108 | parts = key.split(".")
109 | d = resultDict
110 | for part in parts[:-1]:
111 | if part not in d:
112 | d[part] = dict()
113 | d = d[part]
114 | d[parts[-1]] = value
115 | return resultDict
116 |
117 |
--------------------------------------------------------------------------------
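
A short sketch of the utilities above (dictionary contents are illustrative): flatten_dict and unflatten_dict are inverses of each other, and carry_to_device recursively converts numpy arrays to float32 tensors on the target device.

import numpy as np
from python_utils import flatten_dict, unflatten_dict, carry_to_device

cfg = {'model': {'dim_x': 2, 'dim_a': 2}, 'lr': {'init': 1e-3}}
flat = flatten_dict(cfg)            # {'model.dim_x': 2, 'model.dim_a': 2, 'lr.init': 0.001}
assert unflatten_dict(flat) == cfg  # round trip recovers the nested structure

batch = {'y': np.zeros((4, 10, 3)), 'mask': np.ones((4, 10, 1))}
batch = carry_to_device(batch, device='cpu')  # numpy arrays -> float32 torch tensors
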
/requirements.txt:
--------------------------------------------------------------------------------
1 | dill==0.3.4
2 | matplotlib==3.5.1
3 | numpy==1.19.5
4 | scipy==1.5.4
5 | --find-links https://download.pytorch.org/whl/torch_stable.html
6 | torch==1.11.0+cu113
7 | torchmetrics==0.8.0
8 | tqdm==4.62.3
9 | yacs==0.1.6
10 | tensorboard
11 | scikit-learn # for the tutorial
--------------------------------------------------------------------------------
/time_series_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | from python_utils import convert_to_tensor
9 | import torch
10 | from scipy.stats import pearsonr
11 |
12 |
13 | def get_nrmse_error(y, y_hat, version_calculation='modified'):
14 | '''
15 | Computes normalized root-mean-squared error between two 3D tensors. Note that this operation is not symmetric.
16 |
17 | Parameters:
18 | ------------
19 | - y: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y), Tensor with true observations
20 | - y_hat: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y), Tensor with reconstructed/estimated observations
21 | - version_calculation: str, Version used to calculate the variance. If 'regular', the variance of each sequence is computed separately,
22 | which may yield unstable NRMSE values since some sequences may be constant or nearly constant,
23 | giving ~0 variance and hence arbitrarily large NRMSE. To prevent this, 'modified' computes the variance
24 | across the flattened sequences. 'modified' by default.
25 |
26 | Returns:
27 | ------------
28 | - normalized_error: torch.Tensor, shape: (dim_y,), Normalized root-mean-squared error for each data dimension
29 | - normalized_error_mean: torch.Tensor, shape: (), Normalized root-mean-squared error averaged across data dimensions
30 | '''
31 |
32 | # Check if dimensions are consistent
33 | assert y.shape == y_hat.shape, f'dimensions of y {y.shape} and y_hat {y_hat.shape} do not match'
34 | assert len(y.shape) == 3, 'mismatch in y dimension: y should be in the format of (num_seq, num_steps, dim_y)'
35 |
36 | y = convert_to_tensor(y).detach().cpu()
37 | y_hat = convert_to_tensor(y_hat).detach().cpu()
38 |
39 | # carry time to first dimension
40 | y = torch.permute(y, (1,0,2)) # (num_steps, num_seq, dim_x)
41 | y_hat = torch.permute(y_hat, (1,0,2)) # (num_steps, num_seq, dim_x)
42 |
43 | recons_error = torch.mean(torch.square(y - y_hat), dim=0)
44 |
45 | # way 1 to calculate variance
46 | if version_calculation == 'regular':
47 | var_y = torch.mean(torch.square(y - torch.mean(y, dim=0)), dim=0)
48 |
49 | # way 2 to calculate variance (sometimes data in a batch is flat; it's more robust to calculate the variance globally)
50 | elif version_calculation == 'modified':
51 | y_resh = torch.reshape(y, (-1, y.shape[2]))
52 | var_y = torch.mean(torch.square(y_resh - torch.mean(y_resh, dim=0)), dim=0)
53 | var_y = torch.tile(var_y.unsqueeze(dim=0), (y.shape[1], 1))
54 | normalized_error = torch.mean((torch.sqrt(recons_error) / torch.sqrt(var_y)), dim=0) # mean across batches
55 | normalized_error_mean = torch.mean(normalized_error)
56 |
57 | return normalized_error, normalized_error_mean
58 |
59 |
60 | def get_rmse_error(y, y_hat):
61 | '''
62 | Computes root-mean-squared error between two 3D tensors
63 |
64 | Parameters:
65 | ------------
66 | - y: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y), Tensor with true observations
67 | - y_hat: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y), Tensor with reconstructed/estimated observations
68 |
69 | Returns:
70 | ------------
71 | - rmse: torch.Tensor, shape: (dim_y,), Root-mean-squared error for each data dimension
72 | - rmse_mean: torch.Tensor, shape: (), Root-mean-squared error averaged across data dimensions
73 | '''
74 |
75 | # Check if dimensions are consistent
76 | assert y.shape == y_hat.shape, f'dimensions of y {y.shape} and y_hat {y_hat.shape} do not match'
77 |
78 | if len(y.shape) == 3:
79 | dim_y = y.shape[-1]
80 | y = y.reshape(-1, dim_y)
81 | y_hat = y_hat.reshape(-1, dim_y)
82 |
83 | y = convert_to_tensor(y).detach().cpu()
84 | y_hat = convert_to_tensor(y_hat).detach().cpu()
85 |
86 | rmse = torch.sqrt(torch.mean(torch.square(y-y_hat), dim=0))
87 | rmse_mean = torch.nanmean(rmse.nan_to_num(posinf=torch.nan, neginf=torch.nan)) # for stability purposes
88 | return rmse, rmse_mean
89 |
90 |
91 | def get_pearson_cc(y, y_hat):
92 | '''
93 | Computes Pearson correlation coefficients between two 2D tensors along the first (time) dimension. If 3D tensors are given, they are flattened across the first two dimensions.
94 |
95 | Parameters:
96 | ------------
97 | - y: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y) or (num_steps, dim_y), Tensor with true observations
98 | - y_hat: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y) or (num_steps, dim_y), Tensor with reconstructed/estimated observations
99 |
100 | Returns:
101 | ------------
102 | - ccs: torch.Tensor, shape: (dim_y,), Pearson correlation coefficients computed across first (time) dimension
103 | - ccs_mean: torch.Tensor, shape: (), Pearson correlation coefficients computed across first (time) dimension and averaged across data dimensions
104 | '''
105 |
106 | assert y.shape == y_hat.shape, f'dimensions of y {y.shape} and y_hat {y_hat.shape} do not match'
107 |
108 | if len(y.shape) == 3:
109 | dim_y = y.shape[-1]
110 | y = y.reshape(-1, dim_y)
111 | y_hat = y_hat.reshape(-1, dim_y)
112 |
113 | y = convert_to_tensor(y).detach().cpu().numpy() # make sure every array/tensor has .numpy() function, pearsonr works on ndarrays
114 | y_hat = convert_to_tensor(y_hat).detach().cpu().numpy()
115 |
116 | ccs = []
117 | for dim in range(y.shape[-1]):
118 | cc, _ = pearsonr(y[:, dim], y_hat[:, dim])
119 | ccs.append(cc)
120 |
121 | ccs = torch.tensor(ccs, dtype=torch.float32)
122 | ccs_mean = torch.nanmean(ccs.nan_to_num(posinf=torch.nan, neginf=torch.nan))
123 | return ccs, ccs_mean
124 |
125 |
126 |
127 | def z_score_tensor(y, fit=True, **kwargs):
128 | '''
129 | Performs z-scoring fitting and transformation.
130 |
131 | Parameters:
132 | ------------
133 | - y: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y) or (num_steps, dim_y), Tensor/array to z-score
134 | (and if fit is True, to learn mean and standard deviation)
135 | - fit: bool, Whether to learn mean and standard deviation from y. If False, learnt 'mean' and 'std' should be provided as keyword arguments.
136 | - mean: torch.Tensor, shape: (dim_y,), Mean used to transform y. If fit is True, it need not be provided since the mean is learnt. Zeros by default.
137 | - std: torch.Tensor, shape: (dim_y,), Standard deviation used to transform y. If fit is True, it need not be provided since the std is learnt. Ones by default.
138 |
139 | Returns:
140 | ------------
141 | y_z_scored: torch.Tensor/np.ndarray, shape: (num_seq, num_steps, dim_y) or (num_steps, dim_y), Z-scored tensor/array
142 | mean: torch.Tensor, shape: (dim_y,), Learnt mean. If fit is False, it's the mean provided via keyword (or the default)
143 | std: torch.Tensor, shape: (dim_y,), Learnt standard deviation. If fit is False, it's the std provided via keyword (or the default)
144 | '''
145 |
146 | # Make sure that gradients are turned off
147 | with torch.no_grad():
148 | y = convert_to_tensor(y)
149 |
150 | y_resh = y.reshape(-1, y.shape[-1])
151 | if fit:
152 | mean = torch.mean(y_resh, dim=0)
153 | std = torch.std(y_resh, dim=0)
154 | else:
155 | mean = kwargs.pop('mean', torch.zeros(y_resh.shape[-1]))  # tensor defaults so the indexing below works
156 | std = kwargs.pop('std', torch.ones(y_resh.shape[-1]))
157 |
158 | # to prevent nan values
159 | std[std==0] = 1
160 |
161 | y_resh = (y_resh - mean) / std
162 | y_z_scored = y_resh.reshape(y.shape)
163 | return y_z_scored, mean, std
164 |
165 |
--------------------------------------------------------------------------------
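
A short sketch of the metrics and z-scoring above (data is random and purely illustrative): z-scoring statistics are fit once on training data and then reused on held-out data with fit=False.

import torch
from time_series_utils import get_nrmse_error, get_pearson_cc, z_score_tensor

y = torch.randn(4, 100, 3)             # (num_seq, num_steps, dim_y)
y_hat = y + 0.1 * torch.randn_like(y)  # a slightly noisy "reconstruction"

_, nrmse_mean = get_nrmse_error(y, y_hat)  # 'modified' variance normalization by default
_, cc_mean = get_pearson_cc(y, y_hat)

# Fit z-scoring statistics on training data, then reuse them on held-out data
y_train_z, mean, std = z_score_tensor(y, fit=True)
y_test_z, _, _ = z_score_tensor(torch.randn(2, 100, 3), fit=False, mean=mean, std=std)
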
/trainers/BaseTrainer.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | import os
9 | import torch
10 | import datetime
11 | import logging
12 | from torch.utils.tensorboard import SummaryWriter
13 |
14 |
15 | class BaseTrainer:
16 | '''
17 | Base trainer class which is extended by TrainerDFINE.
18 | '''
19 |
20 | def __init__(self, config):
21 | '''
22 | Initializer of BaseTrainer.
23 |
24 | Parameters:
25 | ------------
26 | - config: yacs.config.CfgNode, yacs config which contains all hyperparameters required to create the DFINE model
27 | Please see config_dfine.py for the hyperparameters, their default values and definitions.
28 | '''
29 |
30 | self.config = config
31 |
32 | # Checkpoint and plot save directories, create directories if they don't exist
33 | self.ckpt_save_dir = os.path.join(self.config.model.save_dir, 'ckpts'); os.makedirs(self.ckpt_save_dir, exist_ok=True)
34 | self.plot_save_dir = os.path.join(self.config.model.save_dir, 'plots'); os.makedirs(self.plot_save_dir, exist_ok=True)
35 |
36 | # Training can be continued from where it was left off; by default, the training start epoch is 1
37 | self.start_epoch = 1
38 |
39 | # Tensorboard summary writer
40 | self.writer = SummaryWriter(log_dir=os.path.join(self.config.model.save_dir, 'summary'))
41 |
42 |
43 | def _save_config(self, config_save_name='config.yaml'):
44 | '''
45 | Saves the config inside config.model.save_dir
46 |
47 | Parameters:
48 | ------------
49 | - config_save_name: str, File name to save the config under, 'config.yaml' by default
50 | '''
51 |
52 | config_save_path = os.path.join(self.config.model.save_dir, config_save_name)
53 | with open(config_save_path, 'w') as outfile:
54 | outfile.write(self.config.dump())
55 |
56 |
57 | def _get_optimizer(self, params):
58 | '''
59 | Creates the Adam optimizer with initial learning rate and epsilon specified inside config by config.lr.init and config.optim.eps, respectively
60 |
61 | Parameters:
62 | ------------
63 | - params: Parameters to be optimized by the optimizer
64 |
65 | Returns:
66 | ------------
67 | - optimizer: Adam optimizer with desired learning rate, epsilon to optimize parameters specified by params
68 | '''
69 |
70 | optimizer = torch.optim.Adam(params=params,
71 | lr=self.config.lr.init,
72 | eps=self.config.optim.eps)
73 | return optimizer
74 |
75 |
76 | def _get_lr_scheduler(self):
77 | '''
78 | Creates the learning rate scheduler based on the scheduler type specified in config.lr.scheduler. Options are StepLR (explr), CyclicLR (cyclic) and LambdaLR (used as constantlr).
79 | '''
80 |
81 | if self.config.lr.scheduler.lower() == 'explr':
82 | scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, gamma=self.config.lr.explr.gamma, step_size=self.config.lr.explr.step_size)
83 | elif self.config.lr.scheduler.lower() == 'cyclic':
84 | scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=self.config.lr.cyclic.base_lr, max_lr=self.config.lr.cyclic.max_lr, mode=self.config.lr.cyclic.mode, gamma=self.config.lr.cyclic.gamma, step_size_up=self.config.lr.cyclic.step_size_up, cycle_momentum=False)
85 | elif self.config.lr.scheduler.lower() == 'constantlr':
86 | scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda x: 1)
87 | else:
88 | assert False, 'Only these learning rate schedulers are available: StepLR (explr), CyclicLR (cyclic) and LambdaLR (which is constantlr)!'
89 | return scheduler
90 |
91 |
92 | def _get_metrics(self):
93 | '''
94 | Empty function; the overriding function must return metric names as a list and metrics as a nested dictionary. Keys are:
95 | - train: dict, Training Mean metrics
96 | - valid: dict, Validation Mean metrics
97 | '''
98 | pass
99 |
100 |
101 | def _reset_metrics(self, train_valid='train'):
102 | '''
103 | Resets the metrics
104 |
105 | Parameters:
106 | ------------
107 | - train_valid: str, Which metrics to reset, 'train' by default
108 | '''
109 |
110 | for _, metric in self.metrics[train_valid].items():
111 | metric.reset()
112 |
113 |
114 | def _update_metrics(self, loss_dict, batch_size, train_valid='train', verbose=True):
115 | '''
116 | Updates the metrics
117 |
118 | Parameters:
119 | ------------
120 | - loss_dict: dict, Dictionary with loss values to log in Tensorboard
121 | - batch_size: int, Number of trials for which the metrics are computed for
122 | - train_valid: str, Which metrics to update, 'train' by default
123 | - verbose: bool, Whether to print the warning if a key in metric_names doesn't exist in loss_dict
124 | '''
125 |
126 | for key in self.metric_names:
127 | if key not in loss_dict:
128 | if verbose:
129 | self.logger.warning(f'{key} does not exist in loss_dict, metric cannot be updated!')
130 | else:
131 | pass
132 | else:
133 | self.metrics[train_valid][key].update(loss_dict[key], batch_size)
134 |
135 |
136 | def _get_logger(self, prefix='dfine'):
137 | '''
138 | Creates the logger which is saved as .log file under config.model.save_dir
139 |
140 | Parameters:
141 | ------------
142 | - prefix: str, Prefix which is used as logger's name and .log file's name, 'dfine' by default
143 |
144 | Returns:
145 | ------------
146 | - logger: logging.Logger, Logger object to write logs into .log file
147 | '''
148 |
149 | os.makedirs(self.config.model.save_dir, exist_ok=True)
150 | date_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
151 | log_path = os.path.join(self.config.model.save_dir, f'{prefix}_{date_time}.log')
152 |
153 | # from: https://stackoverflow.com/a/56689445/16228104
154 | logger = logging.getLogger(f'{prefix.upper()} Logger')
155 | logger.setLevel(logging.DEBUG)
156 |
157 | # Remove old handlers from the logger (loggers are cached by name) so that repeated calls don't keep writing to previous log files
158 | handlers = logger.handlers[:]
159 | for handler in handlers:
160 | logger.removeHandler(handler)
161 | handler.close()
162 |
163 | # Create file handler which logs even debug messages
164 | fh = logging.FileHandler(log_path, mode='w')
165 | fh.setLevel(logging.DEBUG)
166 |
167 | # Create console handler with a higher log level
168 | ch = logging.StreamHandler()
169 | ch.setLevel(logging.DEBUG)
170 |
171 | # Create formatter and add it to the handlers
172 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', '%m/%d/%Y %I:%M:%S %p')
173 | ch.setFormatter(formatter)
174 | fh.setFormatter(formatter)
175 |
176 | # Add the handlers to logger
177 | logger.addHandler(ch)
178 | logger.addHandler(fh)
179 |
180 | return logger
181 |
182 |
183 | def _load_ckpt(self, model, optimizer, lr_scheduler=None):
184 | '''
185 | Loads the checkpoint whose number is specified in the config by config.load.ckpt
186 |
187 | Parameters:
188 | ------------
189 | - model: torch.nn.Module, Initialized DFINE model to load the parameters to
190 | - optimizer: torch.optim.Adam, Initialized Adam optimizer to load optimizer parameters to (loading is skipped if config.load.resume_train is False)
191 | - lr_scheduler: torch.optim.lr_scheduler, Initialized learning rate scheduler to load learning rate scheduler parameters to, None by default (loading is skipped if config.load.resume_train is False)
192 |
193 | Returns:
194 | ------------
195 | - model: torch.nn.Module, Loaded DFINE model
196 | - optimizer: torch.optim.Adam, Loaded Adam optimizer (if config.load.resume_train is True, otherwise, initialized optimizer is returned)
197 | - lr_scheduler: torch.optim.lr_scheduler, Loaded learning rate scheduler (if config.load.resume_train is True, otherwise, initialized learning rate scheduler is returned)
198 | '''
199 |
200 | self.logger.warning('Optimizer and LR scheduler can be loaded only in resume_train mode, else they are re-initialized')
201 | load_path = os.path.join(self.config.model.save_dir, 'ckpts', f'{self.config.load.ckpt}_ckpt.pth')
202 | self.logger.info(f'Loading model from: {load_path}...')
203 |
204 | # Load the checkpoint
205 | try:
206 | ckpt = torch.load(load_path)
207 | except Exception:
208 | self.logger.error('Ckpt path does not exist!')
209 | assert False, ''
210 |
211 | # If config.load.resume_train is True, load optimizer and learning rate scheduler
212 | if self.config.load.resume_train:
213 | self.start_epoch = ckpt['epoch'] + 1 if isinstance(ckpt['epoch'], int) else 1
214 | try:
215 | optimizer.load_state_dict(ckpt['optimizer'])
216 | except Exception:
217 | self.logger.error('Optimizer cannot be loaded!, check if optimizer type is consistent!')
218 | assert False, ''
219 |
220 | if lr_scheduler is not None:
221 | try:
222 | lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
223 | except Exception:
224 | self.logger.error('LR scheduler cannot be loaded, check if scheduler type is consistent!')
225 | assert False, ''
226 |
227 | try:
228 | model.load_state_dict(ckpt['state_dict'])
229 | except Exception:
230 | self.logger.error('Given architecture in config does not match the architecture of given checkpoint!')
231 | assert False, ''
232 |
233 | self.logger.info(f'Checkpoint successfully loaded from {load_path}!')
234 | return model, optimizer, lr_scheduler
235 |
236 |
237 | def write_model_gradients(self, model, step, prefix='unclipped'):
238 | '''
239 | Logs the gradient norms to Tensorboard
240 |
241 | Parameters:
242 | ------------
243 | - model: torch.nn.Module, DFINE model whose gradients are to be logged
244 | - step: int, Step to log gradients for
245 | - prefix: str, Prefix for gradient norms to be logged, it can be 'clipped' or 'unclipped', 'unclipped' by default
246 | '''
247 |
248 | total_norm = 0
249 | for name, p in model.named_parameters():
250 | if p.grad is not None:
251 | grad_norm = p.grad.detach().cpu().norm(2)
252 | total_norm += grad_norm ** 2
253 | self.writer.add_scalar('grads/' + name + f"/{prefix}_grad", grad_norm, step)
254 |
255 | total_norm = total_norm ** 0.5
256 | self.writer.add_scalar(f'grads/total_grad_{prefix}_norm', total_norm, step)
257 |
258 |
259 | def _save_ckpt(self, epoch, model, optimizer, lr_scheduler=None):
260 | '''
261 | Saves the checkpoint under ckpt_save_dir (see __init__) with filename {epoch}_ckpt.pth
262 |
263 | Parameters:
264 | ------------
265 | - epoch: int, Epoch number for which the checkpoint is to be saved for
266 | - model: torch.nn.Module, DFINE model to be saved
267 | - optimizer: torch.optim.Adam, Adam optimizer to be saved
268 | - lr_scheduler: torch.optim.lr_scheduler, Learning rate scheduler to be saved
269 | '''
270 |
271 | save_path = os.path.join(self.ckpt_save_dir, f'{epoch}_ckpt.pth')
272 | if lr_scheduler is not None:
273 | torch.save({
274 | 'state_dict': model.state_dict(),
275 | 'optimizer': optimizer.state_dict(),
276 | 'lr_scheduler': lr_scheduler.state_dict(),
277 | 'epoch': epoch
278 | }, save_path)
279 | else:
280 | torch.save({
281 | 'state_dict': model.state_dict(),
282 | 'optimizer': optimizer.state_dict(),
283 | 'epoch': epoch
284 | }, save_path)
285 |
286 |
287 | def train_epoch(self, epoch, train_loader):
288 | '''
289 | Empty function, overwritten function performs single epoch training
290 |
291 | Parameters:
292 | ------------
293 | - epoch: int, Epoch number for which the training iterations are performed
294 | - train_loader: torch.utils.data.DataLoader, Training dataloader
295 | '''
296 |
297 | pass
298 |
299 |
300 | def valid_epoch(self, epoch, valid_loader):
301 | '''
302 | Empty function, overwritten function performs single epoch validation
303 |
304 | Parameters:
305 | ------------
306 | - epoch: int, Epoch number for which the validation iterations are performed
307 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader
308 | '''
309 |
310 | pass
311 |
312 |
313 | def train(self, train_loader, valid_loader):
314 | '''
315 | Empty function, overwritten function performs DFINE training for number of epochs specified in config.train.num_epochs
316 |
317 | Parameters:
318 | ------------
319 | - train_loader: torch.utils.data.DataLoader, Training dataloader
320 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader
321 | '''
322 |
323 | pass
--------------------------------------------------------------------------------
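
BaseTrainer supplies the shared bookkeeping (config saving, optimizer and scheduler construction, logging, checkpointing) while leaving the epoch loops abstract. A schematic sketch of a subclass is below; the class name and metric choice are hypothetical, and Mean comes from metrics.py as in TrainerDFINE:

from trainers.BaseTrainer import BaseTrainer
from metrics import Mean

class MyTrainer(BaseTrainer):
    def _get_metrics(self):
        # One running-mean tracker per metric name and per split
        metric_names = ['total_loss']
        metrics = {tv: {'total_loss': Mean()} for tv in ('train', 'valid')}
        return metric_names, metrics

    def train_epoch(self, epoch, train_loader):
        ...  # forward/backward over batches, then self._save_ckpt(...) as needed
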
/trainers/TrainerDFINE.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2023 University of Southern California
3 | See full notice in LICENSE.md
4 | Hamidreza Abbaspourazad*, Eray Erturk* and Maryam M. Shanechi
5 | Shanechi Lab, University of Southern California
6 | '''
7 |
8 | import os
9 | import torch
10 | from torch.nn.utils import clip_grad_norm_
11 | from tqdm import tqdm
12 | import numpy as np
13 | import timeit
14 | import matplotlib
15 | import matplotlib.pyplot as plt
16 | # matplotlib.use('Agg') # It disables interactive GUI backend. Commented by default but in case of config.train.plot_save_steps << config.train.num_epochs, please uncomment this line (otherwise, may throw matplotlib error).
17 |
18 | from trainers.BaseTrainer import BaseTrainer
19 | from DFINE import DFINE
20 | from python_utils import carry_to_device
21 | from time_series_utils import get_nrmse_error
22 | from metrics import Mean
23 |
24 | torch.set_printoptions(precision=3)
25 | np.set_printoptions(precision=3)
26 |
27 |
28 | class TrainerDFINE(BaseTrainer):
29 | '''
30 | Trainer class for DFINE model.
31 | '''
32 |
33 | def __init__(self, config):
34 | '''
35 | Initializer for a TrainerDFINE object. Note that TrainerDFINE is a subclass of trainers.BaseTrainer.
36 |
37 | Parameters
38 | ------------
39 | - config: yacs.config.CfgNode, yacs config which contains all hyperparameters required to create and train the DFINE model
40 | Please see config_dfine.py for the hyperparameters, their default values and definitions.
41 | '''
42 |
43 | super(TrainerDFINE, self).__init__(config)
44 |
45 | # Initialize training time statistics
46 | self.training_time = 0
47 | self.training_time_epochs = []
48 |
49 | # Initialize best validation losses
50 | self.best_val_loss = torch.inf
51 | if self.config.model.supervise_behv:
52 | self.best_val_behv_loss = torch.inf
53 |
54 | # Initialize logger
55 | self.logger = self._get_logger(prefix='dfine')
56 |
57 | # Set device
58 | self.device = 'cpu' if self.config.device == 'cpu' or not torch.cuda.is_available() else 'cuda:0'
59 | self.config.device = self.device # if cuda is asked in config but it's not available, config is also updated
60 |
61 | # Initialize the model, optimizer and learning rate scheduler
62 | self.dfine = DFINE(self.config)
63 | self.dfine.to(self.device) # carry the model to the desired device
64 | self.optimizer = self._get_optimizer(params=self.dfine.parameters())
65 | self.lr_scheduler = self._get_lr_scheduler()
66 |
67 | # Load ckpt if asked; the model with the best validation loss can also be loaded, which is saved under the name 'best_loss_ckpt.pth'
68 | if (isinstance(self.config.load.ckpt, int) and self.config.load.ckpt > 1) or isinstance(self.config.load.ckpt, str):
69 | self.dfine, self.optimizer, self.lr_scheduler = self._load_ckpt(model=self.dfine,
70 | optimizer=self.optimizer,
71 | lr_scheduler=self.lr_scheduler)
72 |
73 | # Get the metrics
74 | self.metric_names, self.metrics = self._get_metrics()
75 |
76 | # Save the config
77 | self._save_config()
78 |
79 |
80 | def _get_metrics(self):
81 | '''
82 | Creates the metric names and nested metrics dictionary.
83 |
84 | Returns:
85 | ------------
86 | - metric_names: list, Metric names to log in Tensorboard, which are the keys of train/valid defined below
87 | - metrics_dictionary: dict, nested metrics dictionary. Keys (and metric_names) are (e.g. for config.loss.steps_ahead = [1,2]):
88 | - train:
89 | - steps_{k}_mse: metrics.Mean, Training {k}-step ahead predicted MSE
90 | - model_loss: metrics.Mean, Training negative sum of {k}-step ahead predicted MSEs (e.g. steps_1_mse + steps_2_mse)
91 | - reg_loss: metrics.Mean, L2 regularization loss for DFINE encoder and decoder weights
92 | - behv_mse: metrics.Mean, Exists if config.model.supervise_behv is True, Training behavior MSE
93 | - behv_loss: metrics.Mean, Exists if config.model.supervise_behv is True, Training behavior reconstruction loss
94 | - total_loss: metrics.Mean, Sum of training model_loss, reg_loss and behv_loss (if config.model.supervise_behv is True)
95 | - valid:
96 | - steps_{k}_mse: metrics.Mean, Validation {k}-step ahead predicted MSE
97 | - model_loss: metrics.Mean, Validation negative sum of {k}-step ahead predicted MSEs (e.g. steps_1_mse + steps_2_mse)
98 | - reg_loss: metrics.Mean, L2 regularization loss for DFINE encoder and decoder weights
99 | - behv_mse: metrics.Mean, Exists if config.model.supervise_behv is True, Validation behavior MSE
100 | - behv_loss: metrics.Mean, Exists if config.model.supervise_behv is True, Validation behavior reconstruction loss
101 | - total_loss: metrics.Mean, Sum of validation model_loss, reg_loss and behv_loss (if config.model.supervise_behv is True)
102 | '''
103 |
104 | metric_names = []
105 | for k in self.config.loss.steps_ahead:
106 | metric_names.append(f'steps_{k}_mse')
107 |
108 | if self.config.model.supervise_behv:
109 | metric_names.append('behv_mse')
110 | metric_names.append('behv_loss')
111 | metric_names.append('model_loss')
112 | metric_names.append('reg_loss')
113 | metric_names.append('total_loss')
114 |
115 | metrics = {}
116 | metrics['train'] = {}
117 | metrics['valid'] = {}
118 |
119 | for key in metric_names:
120 | metrics['train'][key] = Mean()
121 | metrics['valid'][key] = Mean()
122 |
123 | return metric_names, metrics
124 |
125 |
126 | def _get_log_str(self, epoch, train_valid='train'):
127 | '''
128 | Creates the logging/printing string of training/validation statistics at each epoch
129 |
130 | Parameters:
131 | ------------
132 | - epoch: int, Number of epoch to log the statistics for
133 | - train_valid: str, Training or validation prefix to log the statistics, 'train' by default
134 |
135 | Returns:
136 | ------------
137 | - log_str: str, Logging string
138 | '''
139 |
140 | log_str = f'Epoch {epoch}, {train_valid.upper()}\n'
141 |
142 | # Logging k-step ahead predicted MSEs
143 | for k in self.config.loss.steps_ahead:
144 | if k == 1:
145 | log_str += f"{k}_step_mse: {self.metrics[train_valid][f'steps_{k}_mse'].compute():.5f}\n"
146 | else:
147 | log_str += f"{k}_steps_mse: {self.metrics[train_valid][f'steps_{k}_mse'].compute():.5f}\n"
148 |
149 | # Logging L2 regularization loss and L2 scale
150 | log_str += f"reg_loss: {self.metrics[train_valid]['reg_loss'].compute():.5f}, scale_l2: {self.dfine.scale_l2:.5f}\n"
151 |
152 | # If model is behavior-supervised, log behavior reconstruction loss
153 | if self.config.model.supervise_behv:
154 | log_str += f"behv_loss: {self.metrics[train_valid]['behv_loss'].compute():.5f}, scale_behv_recons: {self.dfine.scale_behv_recons:.5f}\n"
155 |
156 | # Finally, log model_loss and total_loss to optimize
157 | log_str += f"model_loss: {self.metrics[train_valid]['model_loss'].compute():.5f}, total_loss: {self.metrics[train_valid]['total_loss'].compute():.5f}\n"
158 | return log_str
159 |
160 |
161 | def train_epoch(self, epoch, train_loader, verbose=True):
162 | '''
163 | Performs single epoch training over batches, logging to Tensorboard and plot generation
164 |
165 | Parameters:
166 | ------------
167 | - epoch: int, Number of epoch to perform training iteration
168 | - train_loader: torch.utils.data.DataLoader, Training dataloader
169 | '''
170 |
171 | # Take the model into training mode
172 | self.dfine.train()
173 |
174 | # Reset the metrics at the beginning of each epoch
175 | self._reset_metrics(train_valid='train')
176 |
177 | # Keep track of update step for logging the gradient norms
178 | step = (epoch - 1) * len(train_loader) + 1
179 |
180 | # Record the time at which the training epoch starts
181 | start_time = timeit.default_timer()
182 |
183 | # Start iterating over batches
184 | with tqdm(train_loader, unit='batch') as tepoch:
185 | for _, batch in enumerate(tepoch):
186 | tepoch.set_description(f"Epoch {epoch}, TRAIN")
187 |
188 | # Carry data to device
189 | batch = carry_to_device(data=batch, device=self.device)
190 | y_batch, behv_batch, mask_batch = batch
191 |
192 | # Perform forward pass and compute loss
193 | model_vars = self.dfine(y=y_batch, mask=mask_batch)
194 | loss, loss_dict = self.dfine.compute_loss(y=y_batch,
195 | model_vars=model_vars,
196 | mask=mask_batch,
197 | behv=behv_batch)
198 |
199 | # Compute model gradients
200 | self.optimizer.zero_grad()
201 | loss.backward()
202 |
203 | # Log UNCLIPPED model gradients after gradient computations
204 | self.write_model_gradients(self.dfine, step=step, prefix='unclipped')
205 |
206 | # Skip gradient clipping for the first epoch
207 | if epoch > 1:
208 | clip_grad_norm_(self.dfine.parameters(), self.config.optim.grad_clip)
209 |
210 | # Log CLIPPED model gradients after gradient computations
211 | self.write_model_gradients(model=self.dfine, step=step, prefix='clipped')
212 |
213 | # Update model parameters
214 | self.optimizer.step()
215 |
216 | # Update metrics
217 | self._update_metrics(loss_dict=loss_dict,
218 | batch_size=y_batch.shape[0],
219 | train_valid='train',
220 | verbose=False)
221 |
222 | # Update the step
223 | step += 1
224 |
225 | # Get the runtime for the training epoch
226 | epoch_time = timeit.default_timer() - start_time
227 | self.training_time += epoch_time
228 | self.training_time_epochs.append(epoch_time)
229 |
230 | # Save model, optimizer and learning rate scheduler (we save the initial and the last model no matter what config.model.save_steps is)
231 | if epoch % self.config.model.save_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs:
232 | self._save_ckpt(epoch=epoch,
233 | model=self.dfine,
234 | optimizer=self.optimizer,
235 | lr_scheduler=self.lr_scheduler)
236 |
237 | # Write model summary
238 | self.write_summary(epoch, prefix='train')
239 |
240 | # Create and save plots from the last batch
241 | if epoch % self.config.train.plot_save_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs:
242 | self.create_plots(y_batch=y_batch,
243 | behv_batch=behv_batch,
244 | model_vars=model_vars,
245 | epoch=epoch,
246 | prefix='train')
247 |
248 | # Logging the training step information for last batch
249 | if verbose and (epoch % self.config.train.print_log_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs):
250 | log_str = self._get_log_str(epoch=epoch, train_valid='train')
251 | self.logger.info(log_str)
252 |
253 | # Update LR
254 | self.lr_scheduler.step()
255 |
256 |
257 | def valid_epoch(self, epoch, valid_loader, verbose=True):
258 | '''
259 | Performs single epoch validation over batches, logging to Tensorboard and plot generation
260 |
261 | Parameters:
262 | ------------
263 | - epoch: int, Number of epoch to perform validation
264 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader
265 | '''
266 |
267 | with torch.no_grad():
268 | # Take the model into evaluation mode
269 | self.dfine.eval()
270 |
271 | # Reset metrics at the beginning of each epoch
272 | self._reset_metrics(train_valid='valid')
273 |
274 | # Start iterating over the batches
275 | y_all, mask_all = [], []
276 | with tqdm(valid_loader, unit='batch') as tepoch:
277 | for _, batch in enumerate(tepoch):
278 | tepoch.set_description(f"Epoch {epoch}, VALID")
279 |
280 | # Carry data to device
281 | batch = carry_to_device(data=batch, device=self.device)
282 | y_batch, behv_batch, mask_batch = batch
283 | y_all.append(y_batch)
284 | mask_all.append(mask_batch)
285 |
286 | # Perform forward pass and compute loss
287 | model_vars = self.dfine(y=y_batch, mask=mask_batch)
288 | _, loss_dict = self.dfine.compute_loss(y=y_batch,
289 | model_vars=model_vars,
290 | mask=mask_batch,
291 | behv=behv_batch)
292 |
293 | # Update metrics
294 | self._update_metrics(loss_dict=loss_dict,
295 | batch_size=y_batch.shape[0],
296 | train_valid='valid',
297 | verbose=False)
298 |
299 | # Perform one-step-ahead prediction on the provided validation data, for evaluation
300 | y_all = torch.cat(y_all, dim=0)
301 | mask_all = torch.cat(mask_all, dim=0)
302 | model_vars_all = self.dfine(y=y_all, mask=mask_all)
303 | y_pred_all = model_vars_all['y_pred']
304 | _, one_step_ahead_nrmse = get_nrmse_error(y_all[:, 1:, :], y_pred_all)
305 | self.training_valid_one_step_nrmses.append(one_step_ahead_nrmse)
306 |
307 | # Write model summary
308 | self.write_summary(epoch, prefix='valid')
309 |
310 | # Save the best validation loss model (and best behavior reconstruction loss model if supervised)
311 | if self.metrics['valid']['model_loss'].compute() < self.best_val_loss:
312 | self.best_val_loss = self.metrics['valid']['model_loss'].compute()
313 | self._save_ckpt(epoch='best_loss',
314 | model=self.dfine,
315 | optimizer=self.optimizer,
316 | lr_scheduler=self.lr_scheduler)
317 |
318 | if self.config.model.supervise_behv:
319 | if self.metrics['valid']['behv_loss'].compute() < self.best_val_behv_loss:
320 | self.best_val_behv_loss = self.metrics['valid']['behv_loss'].compute()
321 | self._save_ckpt(epoch='best_behv_loss',
322 | model=self.dfine,
323 | optimizer=self.optimizer,
324 | lr_scheduler=self.lr_scheduler)
325 |
326 | # Create and save plots from last batch
327 | if epoch % self.config.train.plot_save_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs:
328 | self.create_plots(y_batch=y_batch,
329 | behv_batch=behv_batch,
330 | model_vars=model_vars,
331 | epoch=epoch,
332 | prefix='valid')
333 |
334 | if verbose and (epoch % self.config.train.print_log_steps == 0 or epoch == 1 or epoch == self.config.train.num_epochs):
335 | # Logging the validation step information for last batch
336 | log_str = self._get_log_str(epoch=epoch, train_valid='valid')
337 | self.logger.info(log_str)
338 |
339 |
340 | def train(self, train_loader, valid_loader=None):
341 | '''
342 | Performs full training of DFINE model
343 |
344 | Parameters:
345 | ------------
346 | - train_loader: torch.utils.data.DataLoader, Training dataloader
347 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader, None by default (if no valid_loader is provided, validation is skipped)
348 | '''
349 |
350 | # Keep track of the validation NRMSEs over the course of training
351 | self.training_valid_one_step_nrmses = []
352 |
353 | # Start iterating over the epochs
354 | for epoch in range(self.start_epoch, self.config.train.num_epochs + 1):
355 | # Perform validation with the initialized model
356 | if epoch == self.start_epoch and valid_loader is not None:
357 | self.valid_epoch(epoch, valid_loader, verbose=False)
358 |
359 | # Perform training iteration over train_loader
360 | self.train_epoch(epoch, train_loader)
361 |
362 | # Perform validation over valid_loader if it's not None and we're at validation epoch
363 | if (epoch % self.config.train.valid_step == 0) and isinstance(valid_loader, torch.utils.data.dataloader.DataLoader):
364 | self.valid_epoch(epoch, valid_loader)
365 |
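# Schematic end-to-end usage of the trainer above (the config construction is
# sketched only by name; see config_dfine.py and datasets.py for the
# repository's own entry points):
#
#     trainer = TrainerDFINE(config)             # config: yacs.config.CfgNode
#     trainer.train(train_loader, valid_loader)  # loaders yield (y, behv, mask) batches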
366 |
367 | def create_plots(self, y_batch, model_vars, behv_batch=None, mask_batch=None, epoch=1, trial_num=0, prefix='train'):
368 | '''
369 | Creates training/validation plots of neural reconstruction, manifold latent factors and dynamic latent factors
370 |
371 | Parameters:
372 | ------------
373 | - y_batch: torch.Tensor, shape: (num_seq, num_steps, dim_y), Batch of high-dimensional neural observations
374 | - model_vars: dict, Dictionary which contains inferred latents, predictions and reconstructions. See DFINE.forward for further details.
375 | - epoch: int, Number of epoch for which to create plot
376 | - behv_batch: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of behavior, None by default
377 | - trial_num: int, Trial number to plot
378 | - prefix: str, Plotname prefix to save plots
379 | '''
380 |
381 | # Create the mask if it's None
382 | if mask_batch is None:
383 | mask_batch = torch.ones(y_batch.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1)
384 |
385 | # Generate and save reconstructed neural observation plot
386 | self.create_y_plot(y_batch=y_batch, y_hat_batch=model_vars['y_hat'], mask_batch=mask_batch, epoch=epoch, trial_num=trial_num, prefix=f'{prefix}', feat_name='y_hat')
387 | # Generate and save smoothed neural observation plot
388 | self.create_y_plot(y_batch=y_batch, y_hat_batch=model_vars['y_smooth'], mask_batch=mask_batch, epoch=epoch, trial_num=trial_num, prefix=f'{prefix}', feat_name='y_smooth')
389 |
390 | # Generate and save k-step-ahead predicted neural observation plot
391 | self.create_k_step_ahead_plot(y_batch=y_batch, model_vars=model_vars, mask_batch=mask_batch, epoch=epoch, trial_num=trial_num, prefix=prefix)
392 |
393 | # Generate and save projected (encoder output directly) manifold latent factor plot
394 | self.create_latent_factor_plot(f=model_vars['a_hat'], epoch=epoch, trial_num=trial_num, prefix=prefix, feat_name='a_hat')
395 | # Generate and save smoothed manifold latent factor plot
396 | self.create_latent_factor_plot(f=model_vars['a_smooth'], epoch=epoch, trial_num=trial_num, prefix=prefix, feat_name='a_smooth')
397 | # Generate and save smoothed dynamic latent factor plot
398 | self.create_latent_factor_plot(f=model_vars['x_smooth'], epoch=epoch, trial_num=trial_num, prefix=prefix, feat_name='x_smooth')
399 | # Generate and save reconstructed behavior if model is behavior-supervised
400 | if self.config.model.supervise_behv and (behv_batch is not None):
401 | self.create_behv_recons_plot(behv_batch=behv_batch, behv_hat_batch=model_vars['behv_hat'], epoch=epoch, trial_num=trial_num, prefix=prefix)
402 |
403 | plt.close('all')
404 |
405 |
406 | def create_y_plot(self, y_batch, y_hat_batch, mask_batch=None, epoch=1, trial_num=0, prefix='train', feat_name='y_hat'):
407 | '''
408 | Creates true and estimated neural observation plots during training and validation
409 |
410 | Parameters:
411 | ------------
412 | - y_batch: torch.Tensor, shape: (num_seq, num_steps, dim_y), True high-dimensional neural observation
413 | - y_hat_batch: torch.Tensor, shape: (num_seq, num_steps, dim_y), Reconstructed high-dimensional neural observation, smoothed/filtered/reconstructed neural observation can be provided
414 | - mask_batch: torch.Tensor, shape: (num_seq, num_steps, 1), Mask for manifold latent factors which shows whether
415 | observations at each timestep exist (1) or are missing (0)
416 | - epoch: int, Number of epoch for which to create the plot
417 | - trial_num: int, Trial number in the batch to plot
418 | - prefix: str, Plotname prefix to save the plot
419 | - feat_name: str, Feature name of y_hat_batch (e.g. y_hat/y_smooth) used in plotname
420 | '''
421 |
422 | # Create the mask if it's None
423 | if mask_batch is None:
424 | mask_batch = torch.ones(y_batch.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1)
425 |
426 | # Detach tensors for plotting
427 | y_batch = y_batch.detach().cpu()
428 | y_hat_batch = y_hat_batch.detach().cpu()
429 | mask_batch = mask_batch.detach().cpu()
430 |
431 | # Mask y_batch and y_hat_batch
432 | num_seq, _, dim_y = y_batch.shape
433 | mask_bool_batch = mask_batch.type(torch.bool).tile(1, 1, self.dfine.dim_y)
434 | y_batch = y_batch[mask_bool_batch].reshape(num_seq, -1, self.dfine.dim_y)
435 | y_hat_batch = y_hat_batch[mask_bool_batch].reshape(num_seq, -1, self.dfine.dim_y)
436 |
437 | # Create the figure
438 | fig = plt.figure(figsize=(20,15))
439 | num_samples = y_batch.shape[1]
440 | color_index = range(num_samples)
441 | color_map = plt.cm.get_cmap('viridis')
442 |
443 | # Plot the true observations and noiseless observations (if it's provided)
444 | if dim_y >= 3:
445 | ax = fig.add_subplot(321, projection='3d')
446 | ax_m = ax.scatter(y_batch[trial_num, :, 0], y_batch[trial_num, :, 1], y_batch[trial_num, :, 2], c=color_index, vmin=0, vmax=num_samples, s=35, cmap=color_map, label='y_true')
447 |
448 | ax.set_xlabel('Dim 0')
449 | ax.set_ylabel('Dim 1')
450 | ax.set_zlabel('Dim 2')
451 | ax.set_title(f'True observations in 3d')
452 | ax.legend()
453 | fig.colorbar(ax_m)
454 |
455 | # Plot the reconstructed observation and noiseless observations (if it's provided)
456 | ax = fig.add_subplot(322, projection='3d')
457 | ax_m = ax.scatter(y_hat_batch[trial_num, :, 0], y_hat_batch[trial_num, :, 1], y_hat_batch[trial_num, :, 2], c=color_index, vmin=0, vmax=num_samples, s=35, cmap=color_map, label='y_hat')
458 | ax.set_title(f'Reconstructed observations in 3d')
459 | ax.set_xlabel('Dim 0')
460 | ax.set_ylabel('Dim 1')
461 | ax.set_zlabel('Dim 2')
462 | ax.legend()
463 | fig.colorbar(ax_m)
464 |
465 | # Plot Dim 0
466 | ax = fig.add_subplot(324)
467 | ax.plot(range(num_samples), y_batch[trial_num, :, 0], 'g', label='y_true')
468 | ax.plot(range(num_samples), y_hat_batch[trial_num, :, 0], 'b', label='y_hat')
469 | ax.set_title('Dim 0')
470 |
471 | if dim_y >= 2:
472 | # Plot Dim 1
473 | ax = fig.add_subplot(325)
474 | ax.plot(range(num_samples), y_batch[trial_num, :, 1], 'g', label='y_true')
475 | ax.plot(range(num_samples), y_hat_batch[trial_num, :, 1], 'b', label='y_hat')
476 | ax.set_title('Dim 1')
477 |
478 | if dim_y >= 3:
479 | # Plot Dim 2
480 | ax = fig.add_subplot(326)
481 | ax.plot(range(num_samples), y_batch[trial_num, :, 2], 'g', label='y_true')
482 | ax.plot(range(num_samples), y_hat_batch[trial_num, :, 2], 'b', label='y_hat')
483 | ax.set_title('Dim 2')
484 | ax.legend()
485 |
486 | # Save the plot under plot_save_dir
487 | plot_name = f'{prefix}_{feat_name}_{epoch}.png'
488 | plt.savefig(os.path.join(self.plot_save_dir, plot_name))
489 | plt.close('all')
490 |
491 |
492 | def create_k_step_ahead_plot(self, y_batch, model_vars, mask_batch=None, epoch=1, trial_num=0, prefix='train'):
493 | '''
494 | Creates true and k-step ahead predicted neural observation plots during training and validation
495 |
496 | Parameters:
497 | ------------
498 | - y_batch: torch.Tensor, shape: (num_seq, num_steps, dim_y), True high-dimensional neural observation
499 | - model_vars: dict, Dictionary which contains inferred latents, predictions and reconstructions. See DFINE.forward for further details.
500 | - mask_batch: torch.Tensor, shape: (num_seq, num_steps, 1), Mask for manifold latent factors which shows whether
501 | observations at each timestep exist (1) or are missing (0)
502 | - epoch: int, Epoch number for which to create the plot
503 | - trial_num: int, Trial number in the batch to plot
504 | - prefix: str, Plotname prefix to save the plot
505 | '''
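# Figure layout produced below: one row per k in config.loss.steps_ahead, with the time
# trace of Dim 0 on the left and a 3D scatter of the first 3 dimensions on the right.
# For example, steps_ahead = [1, 2, 4] (hypothetical config) yields a 3x2 grid.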
506 |
507 | num_total_steps = y_batch.shape[1]
508 |
509 | # Create the mask if it's None
510 | if mask_batch is None:
511 | mask_batch = torch.ones(y_batch.shape[:-1], dtype=torch.float32).unsqueeze(dim=-1)
512 |
513 | # Get the number of steps ahead for which DFINE is optimized and create the figure
514 | num_k = len(self.config.loss.steps_ahead)
515 | fig = plt.figure(figsize=(20, 20))
516 | fig_num = 1
517 |
518 | # Start iterating over steps ahead for plotting
519 | for k in self.config.loss.steps_ahead:
520 | # Get the k-step ahead prediction
521 | y_pred_k_batch, _, _ = self.dfine.get_k_step_ahead_prediction(model_vars, k)
522 |
523 | # Detach tensors for plotting and take timesteps from k to T (since we're plotting k-step ahead predictions)
524 | y_batch_k = y_batch[:, k:, ...].detach().cpu()
525 | mask_batch_k = mask_batch[:, k:, :].detach().cpu()
526 | y_pred_k_batch = y_pred_k_batch.detach().cpu()
527 |
528 | # Mask y and y_hat
529 | num_seq = y_batch.shape[0]
530 | mask_bool = mask_batch_k.type(torch.bool).tile(1, 1, self.dfine.dim_y)
531 | y_batch_k = y_batch_k[mask_bool].reshape(num_seq, -1, self.dfine.dim_y)
532 | y_pred_k_batch = y_pred_k_batch[mask_bool].reshape(num_seq, -1, self.dfine.dim_y)
533 |
534 | # Plot dimension 0
535 | ax = fig.add_subplot(num_k, 2, fig_num)
536 | ax.plot(range(k, num_total_steps), y_batch_k[trial_num, :, 0], 'g', label=f'{k}-step y_true')
537 | ax.plot(range(k, num_total_steps), y_pred_k_batch[trial_num, :, 0], 'b', label=f'{k}-step predicted y')
538 | ax.set_title(f'k={k} step ahead')
539 | ax.set_xlabel('Time')
540 | ax.set_ylabel('Dim 0')
541 | ax.legend()
542 | fig_num += 1
543 |
544 | # Plot the first 3 dimensions of the k-step prediction as a 3D scatter (mostly useful only when the manifold is visible in the first 3 dimensions)
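# NOTE: the scatter below indexes dims 0-2 of the observations unconditionally, so it
# assumes dim_y >= 3; for lower-dimensional observations it would raise an IndexError.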
545 | color_index = range(y_batch_k.shape[1])
546 | color_map = plt.cm.get_cmap('viridis')
547 | ax = fig.add_subplot(num_k, 2, fig_num, projection='3d')
548 | ax_m = ax.scatter(y_pred_k_batch[trial_num, :, 0], y_pred_k_batch[trial_num, :, 1], y_pred_k_batch[trial_num, :, 2], c=color_index, vmin=0, vmax=y_batch.shape[1], s=35, cmap=color_map, label=f'{k}-step predicted y')
549 | ax.set_title(f'k={k} step ahead')
550 | ax.set_xlabel('Dim 0')
551 | ax.set_ylabel('Dim 1')
552 | ax.set_zlabel('Dim 2')
553 | ax.legend()
554 | fig.colorbar(ax_m)
555 | fig_num += 1
556 |
557 | # Save the plot under plot_save_dir
558 | plot_name = f'{prefix}_k_step_obs_{epoch}.png'
559 | plt.savefig(os.path.join(self.plot_save_dir, plot_name))
560 | plt.close('all')
561 |
562 |
563 | def create_latent_factor_plot(self, f, epoch=1, trial_num=0, prefix='train', feat_name='x_smooth'):
564 | '''
565 | Creates dynamic/manifold latent factor plots during training/validation
566 |
567 | Parameters:
568 | ------------
569 | - f: torch.Tensor, shape: (num_seq, num_steps, dim_x/dim_a), Batch of inferred dynamic/manifold latent factors; smoothed/filtered factors can be provided
570 | - epoch: int, Epoch number for which to create the plot
571 | - trial_num: int, Trial number to plot
572 | - prefix: str, Plotname prefix to save plots
573 | - feat_name: str, Feature name of f (e.g. x_smooth/a_smooth) used in the plotname
574 | '''
575 |
576 | # Detach the tensor for plotting
577 | f = f.detach().cpu()
578 |
579 | # From feat_name, get whether it's manifold or dynamic latent factors
580 | if feat_name[0].lower() == 'x':
581 | factor_name = 'Dynamic'
582 | else:
583 | factor_name = 'Manifold'
584 |
585 | # Create the figure and colormap
586 | fig = plt.figure(figsize=(10,8))
587 | _, num_steps, dim_f = f.shape
588 | color_index = range(num_steps)
589 | color_map = plt.cm.get_cmap('viridis')
590 |
591 | if dim_f > 2:
592 | # Scatter first 3 dimensions of the latent factors
593 | ax = fig.add_subplot(221, projection='3d')
594 | ax_m = ax.scatter(f[trial_num, :, 0], f[trial_num, :, 1], f[trial_num, :, 2], c=color_index, vmin=0, vmax=num_steps, s=35, cmap=color_map)
595 | ax.set_xlabel('Dim 0')
596 | ax.set_ylabel('Dim 1')
597 | ax.set_zlabel('Dim 2')
598 | ax.set_title(f'{factor_name} latent factors in 3D')
599 | fig.colorbar(ax_m)
600 |
601 | # Scatter first 2 dimensions of the latent factors, top view
602 | ax = fig.add_subplot(222)
603 | ax_m = ax.scatter(f[trial_num, :, 0], f[trial_num, :, 1], c=color_index, vmin=0, vmax=num_steps, s=35, cmap=color_map)
604 | ax.set_xlabel('Dim 0')
605 | ax.set_ylabel('Dim 1')
606 | ax.set_title(f'{factor_name} latent factors from top')
607 | fig.colorbar(ax_m)
608 |
609 | # Plot the first dimension of the latent factors
610 | ax = fig.add_subplot(223)
611 | ax.plot(range(num_steps), f[trial_num, :, 0])
612 | ax.set_xlabel('Time')
613 | ax.set_ylabel('Dim 0')
614 |
615 | # Plot the second dimension of the latent factors
616 | ax = fig.add_subplot(224)
617 | ax.plot(range(num_steps), f[trial_num, :, 1])
618 | ax.set_xlabel('Time')
619 | ax.set_ylabel('Dim 1')
620 |
621 | elif dim_f == 2:
622 | # Scatter first 2 dimensions of the latent factors, top view
623 | ax = fig.add_subplot(221)
624 | ax_m = ax.scatter(f[trial_num, :, 0], f[trial_num, :, 1], c=color_index, vmin=0, vmax=num_steps, s=35, cmap=color_map)
625 | ax.set_xlabel('Dim 0')
626 | ax.set_ylabel('Dim 1')
627 | ax.set_title(f'{factor_name} latent factors from top')
628 | fig.colorbar(ax_m)
629 |
630 | # Plot the first dimension of the latent factors
631 | ax = fig.add_subplot(222)
632 | ax.plot(range(num_steps), f[trial_num, :, 0])
633 | ax.set_xlabel('Time')
634 | ax.set_ylabel('Dim 0')
635 |
636 | # Plot the second dimension of the latent factors
637 | ax = fig.add_subplot(223)
638 | ax.plot(range(num_steps), f[trial_num, :, 1])
639 | ax.set_xlabel('Time')
640 | ax.set_ylabel('Dim 1')
641 |
642 | else:
643 | # Plot the first dimension of the latent factors
644 | ax = fig.add_subplot(111)
645 | ax.plot(range(num_steps), f[trial_num, :, 0])
646 | ax.set_xlabel('Time')
647 | ax.set_ylabel('Dim 0')
648 | fig.suptitle(f'{factor_name} latent factors', fontsize=16)
649 |
650 | # Save the plot under plot_save_dir
651 | plot_name = f'{prefix}_{feat_name}_{epoch}.png'
652 | plt.savefig(os.path.join(self.plot_save_dir, plot_name))
653 | plt.close('all')
654 |
655 |
656 | def create_behv_recons_plot(self, behv_batch, behv_hat_batch, epoch=1, trial_num=0, prefix='train'):
657 | '''
658 | Creates behavior reconstruction plots during training/validation
659 |
660 | Parameters:
661 | ------------
662 | - behv_batch: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of true behavior
663 | - behv_hat_batch: torch.Tensor, shape: (num_seq, num_steps, dim_behv), Batch of reconstructed behavior
664 | - epoch: int, Epoch number for which to create the behavior reconstruction plot
665 | - trial_num: int, Trial number to plot
666 | - prefix: str, Plotname prefix to save plots
667 | '''
668 |
669 | # Create the figure and detach the tensors for plotting
670 | fig = plt.figure(figsize=(15,20))
671 | behv_batch = behv_batch.detach().cpu()
672 | behv_hat_batch = behv_hat_batch.detach().cpu()
673 |
674 | # Plot the desired behavior dimension
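# Note: behv_batch carries all behavior dimensions and is indexed by the original
# dimension index i, while behv_hat_batch contains only the decoded dimensions (in
# config.model.which_behv_dims order), hence the index k_i. The y-axis labels below
# use a 1-indexed 'Dim {i+1}' convention, unlike the 0-indexed labels elsewhere.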
675 | for k_i, i in enumerate(self.config.model.which_behv_dims):
676 | ax = fig.add_subplot(self.dfine.dim_behv, 1, k_i+1)
677 | ax.plot(behv_batch[trial_num, :, i], label='True Behavior', color='green')
678 | ax.plot(behv_hat_batch[trial_num, :, k_i], label='Decoded Behavior', color='red')
679 | ax.set_xlabel('Time')
680 | ax.set_ylabel(f'Dim {i+1}')
681 | ax.legend()
682 |
683 | # Save the plot under plot_save_dir
684 | plot_name = f'{prefix}_behv_{epoch}.png'
685 | plt.savefig(os.path.join(self.plot_save_dir, plot_name))
686 | plt.close('all')
687 |
688 |
689 | def save_encoding_results(self, train_loader, valid_loader=None, do_full_inference=True, save_results=True):
690 | '''
691 | Performs inference, reconstruction and prediction for the training data and the validation data (if provided), and records training and inference time statistics.
692 | If save_results is True, the encoding results are saved under {config.model.save_dir}/encoding_results.pt.
693 |
694 | Parameters:
695 | ------------
696 | - train_loader: torch.utils.data.DataLoader, Training dataloader
697 | - valid_loader: torch.utils.data.DataLoader, Validation dataloader, None by default (if no valid_loader is provided, validation inference is skipped)
698 | - do_full_inference: bool, Whether to also run inference on the single flattened sequence of all trials, in addition to batched inference over the segments
699 | - save_results: bool, Whether to save the assembled encoding_results dictionary under {config.model.save_dir}/encoding_results.pt (the dictionary is returned either way)
'''
700 | self.dfine.eval()
701 |
702 | with torch.no_grad():
703 | ############################################################################ BATCH INFERENCE ############################################################################
704 | # Create the keys for encoding results dictionary
705 | encoding_dict = {}
706 | encoding_dict['training_time'] = self.training_time
707 | encoding_dict['training_time_epochs'] = self.training_time_epochs
708 | encoding_dict['latent_inference_time'] = dict(train=0, valid=0)
709 |
710 | encoding_dict['x_pred'] = dict(train=[], valid=[])
711 | encoding_dict['x_filter'] = dict(train=[], valid=[])
712 | encoding_dict['x_smooth'] = dict(train=[], valid=[])
713 |
714 | encoding_dict['a_hat'] = dict(train=[], valid=[])
715 | encoding_dict['a_pred'] = dict(train=[], valid=[])
716 | encoding_dict['a_filter'] = dict(train=[], valid=[])
717 | encoding_dict['a_smooth'] = dict(train=[], valid=[])
718 |
719 | encoding_dict['mask'] = dict(train=[], valid=[])
720 |
721 | y_key_list = ['y', 'y_hat', 'y_filter', 'y_smooth', 'y_pred']
722 | for k in self.config.loss.steps_ahead:
723 | if k != 1:
724 | y_key_list.append(f'y_{k}_pred')
725 |
726 | for y_key in y_key_list:
727 | encoding_dict[y_key] = dict(train=[], valid=[])
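# For example, with config.loss.steps_ahead = [1, 2] (hypothetical setting), the
# observation-side keys are 'y', 'y_hat', 'y_filter', 'y_smooth', 'y_pred' and
# 'y_2_pred', each mapping to dict(train=[], valid=[]) to be filled below.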
728 |
729 | # If model is behavior-supervised, create the keys for behavior reconstruction
730 | if self.config.model.supervise_behv:
731 | encoding_dict['behv'] = dict(train=[], valid=[])
732 | encoding_dict['behv_hat'] = dict(train=[], valid=[])
733 |
734 | # Dump train_loader and valid_loader into a dictionary
735 | loaders = dict(train=train_loader, valid=valid_loader)
736 |
737 | # Start iterating over dataloaders
738 | for train_valid, loader in loaders.items():
739 | if isinstance(loader, torch.utils.data.DataLoader):
740 | # If loader is not None, start iterating over the batches
741 | for _, batch in enumerate(loader):
742 | # Keep track of latent inference start time
743 | start_time = timeit.default_timer()
744 |
745 | batch = carry_to_device(batch, device=self.device)
746 | y_batch, behv_batch, mask_batch = batch
747 | model_vars = self.dfine(y=y_batch, mask=mask_batch)
748 |
749 | # Add to the latent inference time over the batches
750 | encoding_dict['latent_inference_time'][train_valid] += timeit.default_timer() - start_time
751 |
752 | # Append the inference variables to the empty lists created in the beginning
753 | encoding_dict['x_pred'][train_valid].append(model_vars['x_pred'].detach().cpu())
754 | encoding_dict['x_filter'][train_valid].append(model_vars['x_filter'].detach().cpu())
755 | encoding_dict['x_smooth'][train_valid].append(model_vars['x_smooth'].detach().cpu())
756 |
757 | encoding_dict['a_hat'][train_valid].append(model_vars['a_hat'].detach().cpu())
758 | encoding_dict['a_pred'][train_valid].append(model_vars['a_pred'].detach().cpu())
759 | encoding_dict['a_filter'][train_valid].append(model_vars['a_filter'].detach().cpu())
760 | encoding_dict['a_smooth'][train_valid].append(model_vars['a_smooth'].detach().cpu())
761 |
762 | encoding_dict['mask'][train_valid].append(mask_batch.detach().cpu())
763 | encoding_dict['y'][train_valid].append(y_batch.detach().cpu())
764 | encoding_dict['y_hat'][train_valid].append(model_vars['y_hat'].detach().cpu())
765 | encoding_dict['y_pred'][train_valid].append(model_vars['y_pred'].detach().cpu())
766 | encoding_dict['y_filter'][train_valid].append(model_vars['y_filter'].detach().cpu())
767 | encoding_dict['y_smooth'][train_valid].append(model_vars['y_smooth'].detach().cpu())
768 |
769 | for k in self.config.loss.steps_ahead:
770 | if k != 1:
771 | y_pred_k, _, _ = self.dfine.get_k_step_ahead_prediction(model_vars, k)
772 | encoding_dict[f'y_{k}_pred'][train_valid].append(y_pred_k.detach().cpu())
773 |
774 | if self.config.model.supervise_behv:
775 | encoding_dict['behv'][train_valid].append(behv_batch.detach().cpu())
776 | encoding_dict['behv_hat'][train_valid].append(model_vars['behv_hat'].detach().cpu())
777 |
778 | # Convert lists to tensors
779 | encoding_dict['x_pred'][train_valid] = torch.cat(encoding_dict['x_pred'][train_valid], dim=0)
780 | encoding_dict['x_filter'][train_valid] = torch.cat(encoding_dict['x_filter'][train_valid], dim=0)
781 | encoding_dict['x_smooth'][train_valid] = torch.cat(encoding_dict['x_smooth'][train_valid], dim=0)
782 |
783 | encoding_dict['a_hat'][train_valid] = torch.cat(encoding_dict['a_hat'][train_valid], dim=0)
784 | encoding_dict['a_pred'][train_valid] = torch.cat(encoding_dict['a_pred'][train_valid], dim=0)
785 | encoding_dict['a_filter'][train_valid] = torch.cat(encoding_dict['a_filter'][train_valid], dim=0)
786 | encoding_dict['a_smooth'][train_valid] = torch.cat(encoding_dict['a_smooth'][train_valid], dim=0)
787 |
788 | encoding_dict['mask'][train_valid] = torch.cat(encoding_dict['mask'][train_valid], dim=0)
789 | for y_key in y_key_list:
790 | encoding_dict[y_key][train_valid] = torch.cat(encoding_dict[y_key][train_valid], dim=0)
791 |
792 | if self.config.model.supervise_behv:
793 | encoding_dict['behv'][train_valid] = torch.cat(encoding_dict['behv'][train_valid], dim=0)
794 | encoding_dict['behv_hat'][train_valid] = torch.cat(encoding_dict['behv_hat'][train_valid], dim=0)
795 |
796 | ############################################################################ FULL INFERENCE w/ FLATTENED SEQUENCE ############################################################################
797 | encoding_dict_full_inference = {}
798 |
799 | if do_full_inference:
800 | # Create the keys for the full-inference encoding results dictionary (initialized above)
802 | encoding_dict_full_inference['latent_inference_time'] = dict(train=0, valid=0)
803 |
804 | encoding_dict_full_inference['x_pred'] = dict(train=[], valid=[])
805 | encoding_dict_full_inference['x_filter'] = dict(train=[], valid=[])
806 | encoding_dict_full_inference['x_smooth'] = dict(train=[], valid=[])
807 |
808 | encoding_dict_full_inference['a_hat'] = dict(train=[], valid=[])
809 | encoding_dict_full_inference['a_pred'] = dict(train=[], valid=[])
810 | encoding_dict_full_inference['a_filter'] = dict(train=[], valid=[])
811 | encoding_dict_full_inference['a_smooth'] = dict(train=[], valid=[])
812 |
813 | encoding_dict_full_inference['mask'] = dict(train=[], valid=[])
814 |
815 | for y_key in y_key_list:
816 | encoding_dict_full_inference[y_key] = dict(train=[], valid=[])
817 |
818 | # If model is behavior-supervised, create the keys for behavior reconstruction
819 | if self.config.model.supervise_behv:
820 | encoding_dict_full_inference['behv'] = dict(train=[], valid=[])
821 | encoding_dict_full_inference['behv_hat'] = dict(train=[], valid=[])
822 |
823 | # Dump variables to encoding_dict_full_inference
824 | for train_valid, loader in loaders.items():
825 | if isinstance(loader, torch.utils.data.DataLoader):
826 | # Flatten the batches of neural observations and the corresponding mask (and the behavior, if the model is behavior-supervised)
827 | encoding_dict_full_inference['y'][train_valid] = encoding_dict['y'][train_valid].reshape(1, -1, self.dfine.dim_y)
828 | encoding_dict_full_inference['mask'][train_valid] = encoding_dict['mask'][train_valid].reshape(1, -1, 1)
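# Shape sketch (hypothetical numbers): 32 segments of 200 steps with dim_y = 30, i.e.
# a (32, 200, 30) tensor, become a single (1, 6400, 30) sequence, so latent inference
# runs over one uninterrupted trajectory instead of independent segments.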
829 |
830 | if self.config.model.supervise_behv:
831 | total_dim_behv = encoding_dict['behv'][train_valid].shape[-1]
832 | encoding_dict_full_inference['behv'][train_valid] = encoding_dict['behv'][train_valid].reshape(1, -1, total_dim_behv)
833 |
834 | # Keep track of latent inference start time
835 | start_time = timeit.default_timer()
836 | model_vars = self.dfine(y=encoding_dict_full_inference['y'][train_valid].to(self.device), mask=encoding_dict_full_inference['mask'][train_valid].to(self.device))
837 | encoding_dict_full_inference['latent_inference_time'][train_valid] += timeit.default_timer() - start_time
838 |
839 | # Store the inference variables (full inference runs on a single flattened sequence, so no appending/concatenation is needed)
840 | encoding_dict_full_inference['x_pred'][train_valid] = model_vars['x_pred'].detach().cpu()
841 | encoding_dict_full_inference['x_filter'][train_valid] = model_vars['x_filter'].detach().cpu()
842 | encoding_dict_full_inference['x_smooth'][train_valid] = model_vars['x_smooth'].detach().cpu()
843 |
844 | encoding_dict_full_inference['a_hat'][train_valid] = model_vars['a_hat'].detach().cpu()
845 | encoding_dict_full_inference['a_pred'][train_valid] = model_vars['a_pred'].detach().cpu()
846 | encoding_dict_full_inference['a_filter'][train_valid] = model_vars['a_filter'].detach().cpu()
847 | encoding_dict_full_inference['a_smooth'][train_valid] = model_vars['a_smooth'].detach().cpu()
848 |
849 | encoding_dict_full_inference['y_hat'][train_valid] = model_vars['y_hat'].detach().cpu()
850 | encoding_dict_full_inference['y_pred'][train_valid] = model_vars['y_pred'].detach().cpu()
851 | encoding_dict_full_inference['y_filter'][train_valid] = model_vars['y_filter'].detach().cpu()
852 | encoding_dict_full_inference['y_smooth'][train_valid] = model_vars['y_smooth'].detach().cpu()
853 |
854 | for k in self.config.loss.steps_ahead:
855 | if k != 1:
856 | y_pred_k, _, _ = self.dfine.get_k_step_ahead_prediction(model_vars, k)
857 | encoding_dict_full_inference[f'y_{k}_pred'][train_valid] = y_pred_k.detach().cpu()
858 |
859 | if self.config.model.supervise_behv:
860 | encoding_dict_full_inference['behv_hat'][train_valid] = model_vars['behv_hat'].detach().cpu()
861 |
862 | # Dump batch and full inference encoding dictionaries into encoding_results
863 | encoding_results = dict(batch_inference=encoding_dict, full_inference=encoding_dict_full_inference)
864 |
865 | # Save encoding dictionary as .pt file
866 | if save_results:
867 | torch.save(encoding_results, os.path.join(self.config.model.save_dir, 'encoding_results.pt'))
868 |
869 | return encoding_results
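# Usage sketch (hypothetical; `trainer` stands for a trained TrainerDFINE instance):
#   encoding_results = trainer.save_encoding_results(train_loader, valid_loader)
#   # ...or reload later from disk:
#   encoding_results = torch.load(os.path.join(trainer.config.model.save_dir, 'encoding_results.pt'))
#   x_smooth_train = encoding_results['batch_inference']['x_smooth']['train']  # (num_seq, num_steps, dim_x)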
870 |
871 |
872 | def write_summary(self, epoch, prefix='train'):
873 | '''
874 | Logs metrics to Tensorboard
875 |
876 | Parameters:
877 | ------------
878 | - epoch: int, Epoch number for which to log metrics
879 | - prefix: str, Prefix under which to log metrics (e.g. 'train' or 'valid')
880 | '''
881 |
882 | for key, val in self.metrics[prefix].items():
883 | self.writer.add_scalar(f'{prefix}/{key}', val.compute(), epoch)
884 |
885 | # The loss scale values below are the same for all prefixes, so log them only once (skipped when prefix is 'valid')
886 | if prefix != 'valid':
887 | self.writer.add_scalar('scale_l2', self.dfine.scale_l2, epoch)
888 | self.writer.add_scalar('learning_rate', self.lr_scheduler.get_last_lr()[0], epoch)
889 | if self.config.model.supervise_behv:
890 | self.writer.add_scalar('scale_behv_recons', self.dfine.scale_behv_recons, epoch)
891 |
892 |
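# The scalars written by write_summary can be browsed with TensorBoard, e.g.
# `tensorboard --logdir <log directory>`, where <log directory> is wherever self.writer
# (the SummaryWriter, presumably set up in BaseTrainer) writes its event files.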
--------------------------------------------------------------------------------