├── README.md ├── dLDS_discrete ├── LICENSE.txt ├── README.md ├── __pycache__ │ ├── create_params.cpython-310.pyc │ └── main_functions.cpython-310.pyc ├── build │ └── lib │ │ └── dlds_discrete │ │ ├── __init__.py │ │ ├── create_params.py │ │ ├── main_functions.py │ │ ├── test │ │ ├── __init__.py │ │ └── main_functions.py │ │ └── train_discrete_model_example.py ├── dLDS_discrete.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt ├── dLDS_discrete2.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── dLDS_discrete_2022.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt ├── dist │ ├── dLDS_discrete-0.0.1-py3-none-any.whl │ ├── dLDS_discrete-0.0.1.tar.gz │ ├── dLDS_discrete-0.0.11-py3-none-any.whl │ ├── dLDS_discrete-0.0.11.tar.gz │ ├── dLDS_discrete-0.0.2-py3-none-any.whl │ ├── dLDS_discrete-0.0.2.tar.gz │ ├── dLDS_discrete-0.0.3-py3-none-any.whl │ ├── dLDS_discrete-0.0.3.tar.gz │ ├── dLDS_discrete-0.0.4-py3-none-any.whl │ ├── dLDS_discrete-0.0.4.tar.gz │ ├── dLDS_discrete-0.0.5-py3-none-any.whl │ ├── dLDS_discrete-0.0.5.tar.gz │ ├── dLDS_discrete-0.0.6-py3-none-any.whl │ ├── dLDS_discrete-0.0.6.tar.gz │ ├── dLDS_discrete-0.0.7-py3-none-any.whl │ ├── dLDS_discrete-0.0.7.tar.gz │ ├── dLDS_discrete-0.0.8-py3-none-any.whl │ ├── dLDS_discrete-0.0.8.tar.gz │ ├── dLDS_discrete-0.0.9-py3-none-any.whl │ ├── dLDS_discrete-0.0.9.tar.gz │ ├── dLDS_discrete-0.0.91-py3-none-any.whl │ ├── dLDS_discrete-0.0.91.tar.gz │ ├── dLDS_discrete2-0.0.2-py3-none-any.whl │ ├── dLDS_discrete2-0.0.2.tar.gz │ ├── dLDS_discrete_2022-0.1.1-py3-none-any.whl │ └── dLDS_discrete_2022-0.1.1.tar.gz ├── dlds_discrete │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── main_functions.cpython-310.pyc │ ├── create_params.py │ ├── main_functions.py │ ├── test │ │ ├── __init__.py │ │ └── main_functions.py │ └── train_discrete_model_example.py ├── pyproject.toml └── setup.py ├── discrete_dlds_visualization.ipynb ├── discrete_results_notebook ├── FHN_with_reg_0_3_spgl1.npy ├── Worm1_WT_Stim.mat ├── Worm1_WT_Stim_Zhat_discretestates.mat ├── c_elegans_dlds_results.npy ├── fhn_dyn_non_reg.npy ├── fhn_reg_effect │ ├── fhn_5sub0_01regspgl1_iters10.npy │ ├── fhn_5sub0_05regspgl1_iters10.npy │ ├── fhn_5sub0_1regspgl1_iters10.npy │ ├── fhn_5sub0_4regspgl1_iters10.npy │ ├── fhn_5sub0_6regspgl1_iters10.npy │ ├── fhn_5sub0_7regspgl1_iters10.npy │ └── fhn_5sub0regspgl1_iters10.npy ├── lorenz_5sub0_55regspgl1_iters10.npy ├── lorenz_5sub0reg_results.pkl ├── lorenz_reg_effect │ ├── lorenz_5sub0_01regspgl1_iters10.npy │ ├── lorenz_5sub0_05regspgl1_iters10.npy │ ├── lorenz_5sub0_1regspgl1_iters10.npy │ ├── lorenz_5sub0_4regspgl1_iters10.npy │ ├── lorenz_5sub0_6regspgl1_iters10.npy │ ├── lorenz_5sub0_7regspgl1_iters10.npy │ └── lorenz_5sub0regspgl1_iters10.npy ├── multifhn_2sub0reg.npy ├── multilorenz_3sub0reg.npy ├── num_sub_fhn │ ├── fhn_2sub0_3regspgl1_iters10.npy │ ├── fhn_3sub0_3regspgl1_iters10.npy │ ├── fhn_4sub0_3regspgl1_iters10.npy │ ├── fhn_5sub0_3regspgl1_iters10.npy │ ├── fhn_6sub0_3regspgl1_iters10.npy │ └── fhn_7sub0_3regspgl1_iters10.npy └── num_sub_lorenz │ ├── lorenz_2sub0_55regspgl1_iters10.npy │ ├── lorenz_3sub0_55regspgl1_iters10.npy │ ├── lorenz_4sub0_55regspgl1_iters10.npy │ ├── lorenz_5sub0_55regspgl1_iters10.npy │ ├── lorenz_6sub0_55regspgl1_iters10.npy │ └── lorenz_7sub0_55regspgl1_iters10.npy └── paper figures ├── DynamicsLearningModel.png ├── 
FHN_dLDS_vs_rslds.png ├── c elegans.PNG ├── lorenz dLDS vs rSLDS.png └── lorenz_new_f1.png /README.md: -------------------------------------------------------------------------------- 1 | _The discrete part of the dLDS model described in:_ 2 | *** NOW AT JMLR *** 3 | **Noga Mudrik, Yenho Chen, Eva Yezerets, Christopher Rozell, Adam Charles. "Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics". 2024. JMLR [link](https://www.jmlr.org/papers/volume25/23-0777/23-0777.pdf)** 4 | 5 | Learning interpretable representations of neural dynamics at a population level is a crucial first step to understanding how neural activity patterns over time relate to perception and behavior. Models of neural dynamics often focus on either low-dimensional projections of neural activity, or on learning dynamical systems that explicitly relate to the neural state over time. We discuss how these two approaches are interrelated by considering dynamical systems as representative of flows on a low-dimensional manifold. Building on this concept, we propose a new decomposed dynamical system model that represents complex nonstationary and nonlinear dynamics of time-series data as a sparse combination of simpler, more interpretable components. The decomposed nature of the dynamics generalizes over previous switched approaches and enables modeling of overlapping and non-stationary drifts in the dynamics. We further present a dictionary learning-driven approach to model fitting, where we leverage recent results in tracking sparse vectors over time. We demonstrate that our model can learn efficient representations and smoothly transition between dynamical modes in both continuous-time and discrete-time examples. We show results on low-dimensional linear and nonlinear attractors to demonstrate that decomposed systems can well approximate nonlinear dynamics. Additionally, we apply our model to C. elegans data, illustrating a diversity of dynamics that is obscured when classified into discrete states. 6 | 7 | 8 | # Outline: 9 | ## A) Installation Instructions 10 | ## B) Main Model Figures 11 | ## C) Package and functions description 12 | 13 | ================================================================= 14 | # A) Installation and Code Instructions 15 | Our discrete model can also be pip-installed using the dlds_discrete package, as described at https://pypi.org/project/dLDS-discrete-2022/ 16 | 1. Make sure you have _os_, _pickle_, and _itertools_ installed in your Python environment 17 | 2. In the cmd, run: _pip install dLDS-discrete-2022_ 18 | 3. Import the package: _import dlds_discrete_ 19 | 4. Import all functions in the main_functions script: _from dlds_discrete.main_functions import *_ 20 | 5. Call the desired function, as described below (in section (C)) 21 | 22 | 23 | If you prefer to use the GitHub code, the main code is located at _\dLDS-Discrete-Python-Model\dLDS_discrete\dlds_discrete_, where: 24 | 1) **'main_functions'** = Python script with the model and plotting functions. 25 | 2) **'create_params'** = Python script that defines the default parameters. 26 | 3) **'train_discrete_model_example'** = an example of how to train our model. 27 | 28 | The Python notebook at the home directory (**discrete_dlds_visualization.ipynb**) may help you understand how to use the code.
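For convenience, here is a minimal end-to-end sketch of steps 2-5 above (install, import, create a toy dynamic, train, reconstruct, plot). It is only a sketch: the calls follow the signatures documented in section (C), and the five values returned by train_model_include_D follow the bundled train_discrete_model_example.py script; exact return values may vary between package versions.

```python
# Minimal usage sketch (assumes: pip install dLDS-discrete-2022)
from dlds_discrete.main_functions import *  # create_dynamics, train_model_include_D, create_reco, visualize_dyn

# 1) Create a ground-truth Lorenz trajectory (k X T matrix)
latent_dyn = create_dynamics(type_dyn='lorenz', max_time=10, dt=0.01)

# 2) Train the discrete dLDS model with 3 sub-dynamics (same call as the example in section (C));
#    the returned values follow train_discrete_model_example.py
coefficients, F, latent_dyn, error_reco_array, D = train_model_include_D(
    10, 0.01, 'lorenz', 3, data=latent_dyn, GD_decay=0.99)

# 3) Reconstruct the latent dynamics from the learned operators (F) and coefficients (c)
cur_reco = create_reco(latent_dyn=latent_dyn, coefficients=coefficients, F=F)

# 4) Plot the reconstruction, coloring each time point by its dominant sub-dynamic
visualize_dyn(cur_reco, color_by_dominant=True, coefficients=coefficients)
```

Passing data=latent_dyn makes the model fit the simulated trajectory directly; with data=[] the function would instead generate the dynamics internally from dynamics_type, dt, and max_time.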
29 | 30 | ================================================================= 31 | # B) Main Model Figures 32 | 33 | ### Model description 34 | ![image](https://user-images.githubusercontent.com/90283200/171279434-f27ec55e-e34c-46c1-bb9a-7efb5b3c018c.png) 35 | 36 | ### Lorenz - visualization of the coefficients' smooth temporal evolution 37 | ![image](https://user-images.githubusercontent.com/90283200/171278476-3bb4fa4b-935e-4334-851e-4a9ede69fa3c.png) 38 | 39 | 40 | ### FitzHugh-Nagumo model - dLDS vs rSLDS 41 | ![image](https://user-images.githubusercontent.com/90283200/171278183-0c1866b8-34a5-4a4c-9d26-9cb4f5ee7764.png) 42 | 43 | 44 | ### Lorenz - dLDS vs rSLDS 45 | ![image](https://user-images.githubusercontent.com/90283200/171278344-f0585304-cee0-499c-af1c-62aaa67028ec.png) 46 | 47 | ### _C. elegans_ 48 | ![image](https://user-images.githubusercontent.com/90283200/171279482-fb59ffa1-8755-475a-a97b-c161afc615a1.png) 49 | 50 | 51 | ================================================================= 52 | # C) Package and functions description 53 | 54 | ## Main Useful Functions: 55 | 56 | ### 1. create_dynamics: 57 | _create sample dynamics_ 58 | 59 | 60 | 61 | **create_dynamics**_(type_dyn = 'cyl', max_time = 1000, dt = 0.01, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2})_ 62 | 63 | #### Detailed Description: 64 | Create ground truth dynamics. 65 | Inputs: 66 | type_dyn = Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 67 | max_time = integer. Number of time points for the dynamics. Relevant only if data is empty; 68 | dt = time interval for the dynamics. 69 | params_ex = dictionary of parameters for the dynamics, e.g. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2} 70 | 71 | 72 | Outputs: 73 | dynamics: k X T matrix of the dynamics 74 | 75 | 76 | 77 | 78 | ### 2. train_model_include_D: 79 | _main function to train the model._ 80 | 81 | **train_model_include_D**_(max_time = 500, dt = 0.1, dynamics_type = 'cyl',num_subdyns = 3, 82 | error_reco = np.inf, data = [], step_f = 30, GD_decay = 0.85, max_error = 1e-3, 83 | max_iter = 3000, F = [], coefficients = [], params= {'update_c_type':'inv','reg_term':0,'smooth_term':0}, 84 | epsilon_error_change = 10**(-5), D = [], x_former =[], latent_dim = None, include_D = False,step_D = 30, reg1=0,reg_f =0 , 85 | max_data_reco = 1e-3, sigma_mix_f = 0.1, action_along_time = 'median', to_print = True, seed = 0, seed_f = 0, 86 | normalize_eig = True, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}, 87 | init_distant_F = False,max_corr = 0.1, decaying_reg = 1, other_params_c={}, include_last_up = False)_ 88 | 89 | #### Detailed Description: 90 | max_time = Number of time points for the dynamics. Relevant only if data is empty; 91 | dt = time interval for the dynamics 92 | dynamics_type = type of the dynamics. Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 93 | num_subdyns = number of sub-dynamics 94 | error_reco = initial error for the reconstruction (do not touch) 95 | data = if one wants to use a pre-defined ground-truth dynamics. If not empty - it overwrites max_time, dt, and dynamics_type 96 | step_f = initial step size for GD on the sub-dynamics 97 | GD_decay = Gradient descent decay rate 98 | max_error = Threshold for the model error. If the model arrives at a lower reconstruction error - the training ends. 99 | max_iter = # of max. iterations for training the model 100 | F = pre-defined sub-dynamics. Keep empty if random. 101 | coefficients = pre-defined coefficients.
Keep empty if random. 102 | params = dictionary that includes info about the regularization and coefficients solver. e.g. {'update_c_type':'inv','reg_term':0,'smooth_term':0} 103 | epsilon_error_change = check if the sub-dynamics do not change by at least epsilon_error_change, for at least the 5 last iterations. Otherwise - add noise to f 104 | D = pre-defined D matrix (keep empty if D = I) 105 | latent_dim = If D != I, it is the pre-defined latent dynamics. 106 | include_D = If True -> D != I; If False -> D = I 107 | step_D = GD step for updating D, only if include_D is true 108 | reg1 = if include_D is true -> L1 regularization on D 109 | reg_f = if include_D is true -> Frobenius norm regularization on D 110 | max_data_reco = if include_D is true -> threshold for the error on the reconstruction of the data (continue training if the error (y - Dx)^2 > max_data_reco) 111 | sigma_mix_f = std of the noise added to mix f 112 | action_along_time = the function to take on the error over time. Can be 'median' or 'mean' 113 | to_print = whether to print the error value while training (boolean) 114 | seed = random seed 115 | seed_f = random seed for initializing f 116 | normalize_eig = whether to normalize each sub-dynamic by dividing by the highest absolute eigenvalue 117 | params_ex = parameters related to the creation of the ground truth dynamics. e.g. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2} 118 | init_distant_F = when initializing F -> make sure that the correlation between each pair of {f}_i does not exceed a threshold 119 | max_corr = max correlation between each pair of initial sub-dyns (relevant only if init_distant_F is True) 120 | decaying_reg = decaying factor for the l1 regularization on the coefficients. If 1 - there is no decay. (should be a scalar in (0,1]) 121 | other_params_c = additional parameters for the update step of c 122 | include_last_up = add another update step of the coefficients at the end 123 | 124 | * example call (for Lorenz, w. 3 operators): train_model_include_D(10, 0.01, 'lorenz', 3, GD_decay = 0.99) 125 | 126 | 127 | 128 | ### 3. create_reco: 129 | _create the dynamics reconstruction using the operators and coefficients obtained by dLDS (F, c)._ 130 | 131 | 132 | **create_reco**_(latent_dyn,coefficients, F, type_find = 'median',min_far =10, smooth_coeffs = False, 133 | smoothing_params = {'wind':5})_ 134 | #### Detailed Description: 135 | This function creates the reconstruction 136 | Inputs: 137 | latent_dyn = the ground truth latent dynamics 138 | coefficients = the operators' coefficients ({c_i(t)}) 139 | F = a list of transport operators (a list with M transport operators, 140 | each is a square matrix, k X k, where k is the latent dynamics 141 | dimension) 142 | type_find = 'median' 143 | min_far = 10 144 | smooth_coeffs = False 145 | smoothing_params = {'wind':5} 146 | 147 | Outputs: 148 | cur_reco = dLDS reconstruction of the latent dynamics 149 | 150 | 151 | 152 | ### 4.
visualize_dyn: 153 | _visualization of the dynamics, with various coloring options_ 154 | 155 | 156 | **visualize_dyn**_(dyn,ax = [], params_plot = {},turn_off_back = False, marker_size = 10, include_line = False, 157 | color_sig = [],cmap = 'cool', return_fig = False, color_by_dominant = False, coefficients =[], 158 | figsize = (5,5),colorbar = False, colors = [],vmin = None,vmax = None, color_mix = False, alpha = 0.4, 159 | colors_dyns = np.array(['r','g','b','yellow']), add_text = 't ', text_points = [],fontsize_times = 18, 160 | marker = "o",delta_text = 0.5, color_for_0 =None, legend = [],fig = [],return_mappable = False)_ 161 | #### Detailed Description: 162 | Inputs: 163 | dyn = dynamics to plot. Should be a np.array with size k X T 164 | ax = the subplot to plot in. (optional). If empty list -> the function will create a subplot 165 | params_plot = additional parameters for the plotting (optional). Can include plotting-related keys like xlabel, ylabel, title, etc. 166 | turn_off_back= disable backgroud of the plot? (optional). Boolean 167 | marker_size = marker size of the plot (optional). Integer 168 | include_line = add a curve to the plot (in addition to the scatter plot). Boolean 169 | color_sig = the color signal. 170 | If empty and color_by_dominant is true - color by the dominant dynamics. 171 | If empty and not color_by_dominant - color by time. 172 | cmap = color map 173 | colors = if not empty -> pre-defined colors for the different sub-dynamics. 174 | If empty -> colors are according to the cmap. 175 | color_mix = relevant only if color_by_dominant is True. In this case the colors need to be in the form of [r,g,b] 176 | Output: 177 | h (only if return_fig) -> returns the figure 178 | 179 | -------------------------------------------------------------------------------- /dLDS_discrete/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /dLDS_discrete/README.md: -------------------------------------------------------------------------------- 1 | **The discrete model described in:** _Noga Mudrik*, Yenho Chen*, Eva Yezerets, Christopher Rozell, Adam Charles. "Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics". 
2022_ 2 | 3 | 4 | Learning interpretable representations of neural dynamics at a population level is 5 | a crucial first step to understanding how neural activity patterns over time relate 6 | to perception and behavior. Models of neural dynamics often focus on either 7 | low-dimensional projections of neural activity, or on learning dynamical systems 8 | that explicitly relate to the neural state over time. We discuss how these two 9 | approaches are interrelated by considering dynamical systems as representative of 10 | flows on a low-dimensional manifold. Building on this concept, we propose a new 11 | decomposed dynamical system model that represents complex nonstationary and 12 | nonlinear dynamics of time-series data as a sparse combination of simpler, more 13 | interpretable components. The decomposed nature of the dynamics generalizes 14 | over previous switched approaches and enables modeling of overlapping and 15 | non-stationary drifts in the dynamics. We further present a dictionary learning- 16 | driven approach to model fitting, where we leverage recent results in tracking sparse 17 | vectors over time. We demonstrate that our model can learn efficient representations 18 | and smoothly transition between dynamical modes in both continuous-time and 19 | discrete-time examples. We show results on low-dimensional linear and nonlinear 20 | attractors to demonstrate that decomposed systems can well approximate nonlinear 21 | dynamics. Additionally, we apply our model to C. elegans data, illustrating a 22 | diversity of dynamics that is obscured when classified into discrete states. 23 | 24 | # Installation Instructions: 25 | 1. (if itertools not installed): sudo pip3 install more-itertools [in the cmd] 26 | 2. (if pickle not installed): pip install pickle-mixin [in the cmd] 27 | 3. !pip install dLDS-discrete [in the cmd] 28 | 4. from dlds_discrete import main_functions [in Python console] 29 | 5. from dlds_discrete.main_functions import * [in Python console] 30 | 6. Use any function from the ones described below 31 | 32 | 33 | 34 | ## Main Useful Functions: 35 | 36 | ### 1. create_dynamics: 37 | _create sample dynamics_ 38 | 39 | 40 | 41 | **create_dynamics**_(type_dyn = 'cyl', max_time = 1000, dt = 0.01, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2})_ 42 | 43 | #### Detailed Description: 44 | Create ground truth dynamics. 45 | Inputs: 46 | type_dyn = Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 47 | max_time = integer. Number of time points for the dynamics. Relevant only if data is empty; 48 | dt = time interval for the dynamics. 49 | params_ex = dictionary of parameters for the dynamics. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}): 50 | 51 | 52 | Outputs: 53 | dynamics: k X T matrix of the dynamics 54 | 55 | 56 | 57 | 58 | ### 2. 
train_model_include_D: 59 | _main function to train the model._ 60 | 61 | **train_model_include_D**_(max_time = 500, dt = 0.1, dynamics_type = 'cyl',num_subdyns = 3, 62 | error_reco = np.inf, data = [], step_f = 30, GD_decay = 0.85, max_error = 1e-3, 63 | max_iter = 3000, F = [], coefficients = [], params= {'update_c_type':'inv','reg_term':0,'smooth_term':0}, 64 | epsilon_error_change = 10**(-5), D = [], x_former =[], latent_dim = None, include_D = False,step_D = 30, reg1=0,reg_f =0 , 65 | max_data_reco = 1e-3, sigma_mix_f = 0.1, action_along_time = 'median', to_print = True, seed = 0, seed_f = 0, 66 | normalize_eig = True, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}, 67 | init_distant_F = False,max_corr = 0.1, decaying_reg = 1, other_params_c={}, include_last_up = False)_ 68 | 69 | #### Detailed Description: 70 | max_time = Number of time points for the dynamics. Relevant only if data is empty; 71 | dt = time interval for the dynamics 72 | dynamics_type = type of the dynamics. Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 73 | num_subdyns = number of sub-dynamics 74 | error_reco = intial error for the reconstruction (do not touch) 75 | data = if one wants to use a pre define groud-truth dynamics. If not empty - it overwrites max_time, dt, and dynamics_type 76 | step_f = initial step size for GD on the sub-dynamics 77 | GD_decay = Gradient descent decay rate 78 | max_error = Threshold for the model error. If the model arrives at a lower reconstruction error - the training ends. 79 | max_iter = # of max. iterations for training the model 80 | F = pre-defined sub-dynamics. Keep empty if random. 81 | coefficients = pre-defined coefficients. Keep empty if random. 82 | params = dictionary that includes info about the regularization and coefficients solver. e.g. {'update_c_type':'inv','reg_term':0,'smooth_term':0} 83 | epsilon_error_change = check if the sub-dynamics do not change by at least epsilon_error_change, for at least 5 last iterations. Otherwise - add noise to f 84 | D = pre-defined D matrix (keep empty if D = I) 85 | latent_dim = If D != I, it is the pre-defined latent dynamics. 86 | include_D = If True -> D !=I; If False -> D = I 87 | step_D = GD step for updating D, only if include_D is true 88 | reg1 = if include_D is true -> L1 regularization on D 89 | reg_f = if include_D is true -> Frobenius norm regularization on D 90 | max_data_reco = if include_D is true -> threshold for the error on the reconstruction of the data (continue training if the error (y - Dx)^2 > max_data_reco) 91 | sigma_mix_f = std of noise added to mix f 92 | action_along_time = the function to take on the error over time. Can be 'median' or 'mean' 93 | to_print = to print error value while training? (boolean) 94 | seed = random seed 95 | seed_f = random seed for initializing f 96 | normalize_eig = whether to normalize each sub-dynamic by dividing by the highest abs eval 97 | params_ex = parameters related to the creation of the ground truth dynamics. e.g. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2} 98 | init_distant_F = when initializing F -> make sure that the correlation between each pair of {f}_i does not exeed a threshold 99 | max_corr = max correlation between each pair of initial sub-dyns (relevant only if init_distant_F is True) 100 | decaying_reg = decaying factor for the l1 regularization on the coefficients. If 1 - there is no decay. 
(should be a scalar in (0,1]) 101 | other_params_c = additional parameters for the update step of c 102 | include_last_up = add another update step of the coefficients at the end 103 | 104 | * example call (for Lorenz, w. 3 operators): train_model_include_D(10, 0.01, 'lorenz', 3, GD_decay = 0.99) 105 | 106 | 107 | 108 | ### 3. create_reco: 109 | _create the dynamics reconstruction using the operators and coefficients obtained by dLDS (F, c)._ 110 | 111 | 112 | **create_reco**_(latent_dyn,coefficients, F, type_find = 'median',min_far =10, smooth_coeffs = False, 113 | smoothing_params = {'wind':5})_ 114 | #### Detailed Description: 115 | This function creates the reconstruction 116 | Inputs: 117 | latent_dyn = the ground truth latent dynamics 118 | coefficients = the operators coefficients ({$c(t)_i}) 119 | F = a list of transport operators (a list with M transport operators, 120 | each is a square matrix, kXk, where k is the latent dynamics 121 | dimension ) 122 | type_find = 'median' 123 | min_far = 10 124 | smooth_coeffs= False 125 | smoothing_params = {'wind':5} 126 | 127 | Outputs: 128 | cur_reco = dLDS reconstruction of the latent dynamics 129 | 130 | 131 | 132 | ### 4. visualize_dyn: 133 | _visualization of the dynamics, with various coloring options_ 134 | 135 | 136 | **visualize_dyn**_(dyn,ax = [], params_plot = {},turn_off_back = False, marker_size = 10, include_line = False, 137 | color_sig = [],cmap = 'cool', return_fig = False, color_by_dominant = False, coefficients =[], 138 | figsize = (5,5),colorbar = False, colors = [],vmin = None,vmax = None, color_mix = False, alpha = 0.4, 139 | colors_dyns = np.array(['r','g','b','yellow']), add_text = 't ', text_points = [],fontsize_times = 18, 140 | marker = "o",delta_text = 0.5, color_for_0 =None, legend = [],fig = [],return_mappable = False)_ 141 | #### Detailed Description: 142 | Inputs: 143 | dyn = dynamics to plot. Should be a np.array with size k X T 144 | ax = the subplot to plot in. (optional). If empty list -> the function will create a subplot 145 | params_plot = additional parameters for the plotting (optional). Can include plotting-related keys like xlabel, ylabel, title, etc. 146 | turn_off_back= disable backgroud of the plot? (optional). Boolean 147 | marker_size = marker size of the plot (optional). Integer 148 | include_line = add a curve to the plot (in addition to the scatter plot). Boolean 149 | color_sig = the color signal. 150 | If empty and color_by_dominant is true - color by the dominant dynamics. 151 | If empty and not color_by_dominant - color by time. 152 | cmap = color map 153 | colors = if not empty -> pre-defined colors for the different sub-dynamics. 154 | If empty -> colors are according to the cmap. 155 | color_mix = relevant only if color_by_dominant is True. 
In this case the colors need to be in the form of [r,g,b] 156 | Output: 157 | h (only if return_fig) -> returns the figure 158 | 159 | -------------------------------------------------------------------------------- /dLDS_discrete/__pycache__/create_params.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/__pycache__/create_params.cpython-310.pyc -------------------------------------------------------------------------------- /dLDS_discrete/__pycache__/main_functions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/__pycache__/main_functions.cpython-310.pyc -------------------------------------------------------------------------------- /dLDS_discrete/build/lib/dlds_discrete/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon May 30 08:49:16 2022 4 | 5 | @author: noga mudrik 6 | """ 7 | 8 | from main_functions import * 9 | -------------------------------------------------------------------------------- /dLDS_discrete/build/lib/dlds_discrete/create_params.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics 3 | @code author: noga mudrik 4 | """ 5 | 6 | 7 | #%% Call Main Function 8 | 9 | 10 | from main_functions import * 11 | 12 | update_c_type = 'inv' # can be also lasso or smooth 13 | step_f = 30 14 | step_f_original = step_f 15 | num_subdyns = 3; 16 | dynamic_type ='cyl' 17 | max_time = 50 18 | dt = 0.1 19 | noise = 0 20 | speed_change = 0 21 | is_D_I = True 22 | reg_term = 0.01 23 | max_iter = 6000 24 | num_iter = max_iter 25 | GD_decay = 0.99 26 | max_error = 1e-8 27 | max_data_reco = max_error 28 | acumulated_error = False 29 | error_order_initial = 2 30 | error_step_add = 120 31 | error_step_max = 20 32 | error_step_max_display = error_step_max -2 #+ 2 33 | error_step_add = 120 34 | epsilon_error_change= 10**(-8) 35 | sigma_mix_f = 0.1 36 | weights_orders = np.linspace(1,2**error_step_max,error_step_max)[::-1] 37 | grad_vec = np.linspace(0.99,1.01,error_step_max) 38 | 39 | action_along_time = 'median' #can be mean 40 | 41 | latent_dyn = create_dynamics(type_dyn = dynamic_type, max_time = max_time, dt = dt) 42 | F = [init_mat((latent_dyn.shape[0], latent_dyn.shape[0]),normalize=True) for i in range(num_subdyns)] 43 | coefficients = init_mat((num_subdyns,latent_dyn.shape[1]-1)) 44 | include_D = False 45 | 46 | error_reco = np.inf 47 | error_reco_all = np.inf*np.ones((1,error_step_max)) 48 | cur_reco = create_reco(latent_dyn=latent_dyn, coefficients= coefficients, F=F) 49 | error_reco_array = np.inf*np.ones((1,max(error_step_max_display,error_step_max))) 50 | error_reco_array_med = np.inf*np.ones((1,max(error_step_max_display,error_step_max))) 51 | error_order = error_step_max 52 | 53 | seed_f = 0 54 | same_c = False 55 | data = [] 56 | 57 | 58 | 59 | dyn_radius = 5 60 | dyn_num_cyls = 5 61 | dyn_bias = 0 62 | addition_save = [] 63 | to_print = True 64 | name_auto = False 65 | to_save_without_ask = False 66 | ylim_small = [-15,15] 67 | 68 | seed_f = 0 69 | data =[] 70 
| 71 | reg_term =0 72 | smooth_term = 0 73 | noise_max = 1 74 | max_error = 1e-8 75 | include_D = False 76 | exp_power = 0.1 77 | 78 | start_from_c = False 79 | 80 | start_sparse_c = False 81 | init_distant_F = False 82 | max_corr = 0.1 83 | decaying_reg = 0.999 84 | normalize_eig = True 85 | params_ex = {'radius':dyn_radius , 'num_cyls': dyn_num_cyls, 'bias':dyn_bias,'exp_power':exp_power} 86 | 87 | bias_term = False 88 | center_dynamics = False 89 | bias_out = False 90 | noise_interval = 0.1 91 | 92 | 93 | 94 | params_update_c = {'reg_term': reg_term, 'update_c_type':update_c_type,'smooth_term':smooth_term} 95 | 96 | 97 | 98 | 99 | width_des = 0.6 100 | t_break = 280 101 | factor_power = 0.3 102 | quarter_initial = 'low' 103 | smooth_window = 3 104 | colors = [[1,0.1,0],[0.1,0,1], [0,1,0]] 105 | s_scatter = 200 106 | start_run = 30 107 | plot_movemean = False 108 | max_time_plot = 500 109 | colors_dyn = ['r','g','b'] 110 | n_samples_range = np.arange(5,106,20) 111 | num_noises = 5 112 | max_dyn =5 113 | 114 | init_distant_F = True 115 | step_f = 50 116 | epsilon_error_change= 10**(-8) 117 | 118 | 119 | action_along_time = 'median' #can also be 'mean' 120 | 121 | error_reco = np.inf 122 | 123 | error_reco_array = [] -------------------------------------------------------------------------------- /dLDS_discrete/build/lib/dlds_discrete/test/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon May 30 08:49:16 2022 4 | 5 | @author: noga mudrik 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /dLDS_discrete/build/lib/dlds_discrete/train_discrete_model_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics 3 | @code author: noga mudrik 4 | """ 5 | 6 | #%% Imports: 7 | 8 | from importlib import reload 9 | 10 | import main_functions 11 | main_functions = reload(main_functions) 12 | from main_functions import * 13 | from datetime import date 14 | 15 | 16 | """ 17 | Parameters 18 | """ 19 | exec(open('create_params.py').read()) 20 | 21 | addi_save = date.today().strftime('%d%m%y') # For saving 22 | if 'addition_save' not in locals(): addition_save = [] 23 | 24 | update_c_types = ['inv'] #['spgl1'] 25 | num_iters = [10] 26 | max_iter = 6000 27 | is_D_I = True 28 | 29 | 30 | """ 31 | Parameters to choose 32 | """ 33 | dt = float(input('dt (rec for Lorenz 0.01, rec for FHN 0.2)')) 34 | max_time = float(input('max time (rec for Lorenz 10, rec for FHN 200)')) 35 | dynamic_type = input('dynamic type (e.g. lorenz, FHN)') 36 | addi_name = input('additional name id') 37 | num_subdyns = [int(input('num_dyns (m)'))] 38 | include_last_up = str2bool(input('include last up? 
(for FHN reg)')) 39 | reg_vals_new = [float(input('reg_val_input (tau)'))] 40 | addition_save.append(addi_save) 41 | latent_dyn = create_dynamics(type_dyn = dynamic_type, max_time = max_time, dt = dt) 42 | include_D = False 43 | to_load = False 44 | 45 | 46 | 47 | name_auto = True 48 | normalize_eig = True 49 | to_print = False 50 | seed_f = 0 51 | dt_range = np.linspace(0.001, 1, 20) 52 | exp_power = 0.1 53 | 54 | """ 55 | Runnining over the parameters 56 | """ 57 | for num_iter in num_iters: 58 | for reg_term in reg_vals_new: 59 | for update_c_type in update_c_types : 60 | for num_subs in num_subdyns: 61 | to_save_without_ask = True 62 | sigma_mix_f = 0.1 63 | F = [init_mat((latent_dyn.shape[0], latent_dyn.shape[0]),normalize=True) for i in range(num_subs)] 64 | coefficients = init_mat((num_subs,latent_dyn.shape[1]-1)) 65 | save_name = '%s_%gsub%greg%s_iters%s'%(dynamic_type,num_subs,reg_term, update_c_type, str(num_iter)) 66 | data = latent_dyn 67 | 68 | params_update_c = {'reg_term': reg_term, 'update_c_type':update_c_type,'smooth_term' :smooth_term, 'num_iters': num_iter, 'threshkind':'soft'} 69 | 70 | if to_save_without_ask: to_save = True 71 | else: to_save = str2bool(input('To save?')) 72 | 73 | coefficients, F, latent_dyn, error_reco_array, D = train_model_include_D(max_time , dt , dynamic_type, num_subdyns = num_subs, 74 | data = data, step_f = step_f, GD_decay = GD_decay, 75 | max_error = max_error, max_iter = max_iter, 76 | include_D = include_D, seed_f = seed_f, 77 | normalize_eig = normalize_eig, 78 | to_print = to_print, params = params_update_c ) 79 | if to_save: 80 | if name_auto: pass 81 | else: save_name = input('save_name') 82 | save_dict = {'F':F, 'coefficients':coefficients, 'latent_dyn': latent_dyn, 'max_time': max_time, 'dt':dt,'dyn_type':dynamic_type, 83 | 'error_reco_array' :error_reco_array, 'D':D} 84 | save_file_dynamics(save_name, ['main_folder_results', dynamic_type, 'clean%s'%addi_name,update_c_type ]+addition_save, save_dict ) 85 | save_file_dynamics(save_name, ['main_folder_results' ,dynamic_type, 'clean%s'%addi_name,update_c_type ]+addition_save, [], type_save = '.pkl' ) 86 | 87 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: dLDS-discrete 3 | Version: 0.0.91 4 | Summary: dLDS discrete model package 5 | Author: noga mudrik 6 | Author-email: 7 | Classifier: Programming Language :: Python :: 3 8 | Classifier: License :: OSI Approved :: MIT License 9 | Classifier: Operating System :: OS Independent 10 | Requires-Python: >=3.8 11 | Description-Content-Type: text/markdown 12 | License-File: LICENSE.txt 13 | 14 | **The discrete model described in:** _Noga Mudrik*, Yenho Chen*, Eva Yezerets, Christopher Rozell, Adam Charles. "Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics". 2022_ 15 | 16 | 17 | Learning interpretable representations of neural dynamics at a population level is 18 | a crucial first step to understanding how neural activity patterns over time relate 19 | to perception and behavior. Models of neural dynamics often focus on either 20 | low-dimensional projections of neural activity, or on learning dynamical systems 21 | that explicitly relate to the neural state over time. 
We discuss how these two 22 | approaches are interrelated by considering dynamical systems as representative of 23 | flows on a low-dimensional manifold. Building on this concept, we propose a new 24 | decomposed dynamical system model that represents complex nonstationary and 25 | nonlinear dynamics of time-series data as a sparse combination of simpler, more 26 | interpretable components. The decomposed nature of the dynamics generalizes 27 | over previous switched approaches and enables modeling of overlapping and 28 | non-stationary drifts in the dynamics. We further present a dictionary learning- 29 | driven approach to model fitting, where we leverage recent results in tracking sparse 30 | vectors over time. We demonstrate that our model can learn efficient representations 31 | and smoothly transition between dynamical modes in both continuous-time and 32 | discrete-time examples. We show results on low-dimensional linear and nonlinear 33 | attractors to demonstrate that decomposed systems can well approximate nonlinear 34 | dynamics. Additionally, we apply our model to C. elegans data, illustrating a 35 | diversity of dynamics that is obscured when classified into discrete states. 36 | 37 | # Installation Instructions: 38 | 1. (if itertools not installed): sudo pip3 install more-itertools [in the cmd] 39 | 2. (if pickle not installed): pip install pickle-mixin [in the cmd] 40 | 3. !pip install dLDS-discrete [in the cmd] 41 | 4. from dlds_discrete import main_functions [in Python console] 42 | 5. from main_functions import * [in Python console] 43 | 44 | 45 | 46 | ## Main Useful Functions: 47 | 48 | ### 1. create_dynamics: 49 | _create sample dynamics_ 50 | 51 | 52 | 53 | **create_dynamics**_(type_dyn = 'cyl', max_time = 1000, dt = 0.01, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2})_ 54 | 55 | #### Detailed Description: 56 | Create ground truth dynamics. 57 | Inputs: 58 | type_dyn = Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 59 | max_time = integer. Number of time points for the dynamics. Relevant only if data is empty; 60 | dt = time interval for the dynamics. 61 | params_ex = dictionary of parameters for the dynamics. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}): 62 | 63 | 64 | Outputs: 65 | dynamics: k X T matrix of the dynamics 66 | 67 | 68 | 69 | 70 | ### 2. train_model_include_D: 71 | _main function to train the model._ 72 | 73 | **train_model_include_D**_(max_time = 500, dt = 0.1, dynamics_type = 'cyl',num_subdyns = 3, 74 | error_reco = np.inf, data = [], step_f = 30, GD_decay = 0.85, max_error = 1e-3, 75 | max_iter = 3000, F = [], coefficients = [], params= {'update_c_type':'inv','reg_term':0,'smooth_term':0}, 76 | epsilon_error_change = 10**(-5), D = [], x_former =[], latent_dim = None, include_D = False,step_D = 30, reg1=0,reg_f =0 , 77 | max_data_reco = 1e-3, sigma_mix_f = 0.1, action_along_time = 'median', to_print = True, seed = 0, seed_f = 0, 78 | normalize_eig = True, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}, 79 | init_distant_F = False,max_corr = 0.1, decaying_reg = 1, other_params_c={}, include_last_up = False)_ 80 | 81 | #### Detailed Description: 82 | max_time = Number of time points for the dynamics. Relevant only if data is empty; 83 | dt = time interval for the dynamics 84 | dynamics_type = type of the dynamics. 
Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 85 | num_subdyns = number of sub-dynamics 86 | error_reco = intial error for the reconstruction (do not touch) 87 | data = if one wants to use a pre define groud-truth dynamics. If not empty - it overwrites max_time, dt, and dynamics_type 88 | step_f = initial step size for GD on the sub-dynamics 89 | GD_decay = Gradient descent decay rate 90 | max_error = Threshold for the model error. If the model arrives at a lower reconstruction error - the training ends. 91 | max_iter = # of max. iterations for training the model 92 | F = pre-defined sub-dynamics. Keep empty if random. 93 | coefficients = pre-defined coefficients. Keep empty if random. 94 | params = dictionary that includes info about the regularization and coefficients solver. e.g. {'update_c_type':'inv','reg_term':0,'smooth_term':0} 95 | epsilon_error_change = check if the sub-dynamics do not change by at least epsilon_error_change, for at least 5 last iterations. Otherwise - add noise to f 96 | D = pre-defined D matrix (keep empty if D = I) 97 | latent_dim = If D != I, it is the pre-defined latent dynamics. 98 | include_D = If True -> D !=I; If False -> D = I 99 | step_D = GD step for updating D, only if include_D is true 100 | reg1 = if include_D is true -> L1 regularization on D 101 | reg_f = if include_D is true -> Frobenius norm regularization on D 102 | max_data_reco = if include_D is true -> threshold for the error on the reconstruction of the data (continue training if the error (y - Dx)^2 > max_data_reco) 103 | sigma_mix_f = std of noise added to mix f 104 | action_along_time = the function to take on the error over time. Can be 'median' or 'mean' 105 | to_print = to print error value while training? (boolean) 106 | seed = random seed 107 | seed_f = random seed for initializing f 108 | normalize_eig = whether to normalize each sub-dynamic by dividing by the highest abs eval 109 | params_ex = parameters related to the creation of the ground truth dynamics. e.g. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2} 110 | init_distant_F = when initializing F -> make sure that the correlation between each pair of {f}_i does not exeed a threshold 111 | max_corr = max correlation between each pair of initial sub-dyns (relevant only if init_distant_F is True) 112 | decaying_reg = decaying factor for the l1 regularization on the coefficients. If 1 - there is no decay. (should be a scalar in (0,1]) 113 | other_params_c = additional parameters for the update step of c 114 | include_last_up = add another update step of the coefficients at the end 115 | 116 | * example call (for Lorenz, w. 3 operators): train_model_include_D(10, 0.01, 'lorenz', 3, GD_decay = 0.99) 117 | 118 | 119 | 120 | ### 3. 
create_reco: 121 | _create the dynamics reconstruction using the operators and coefficients obtained by dLDS (F, c)._ 122 | 123 | 124 | **create_reco**_(latent_dyn,coefficients, F, type_find = 'median',min_far =10, smooth_coeffs = False, 125 | smoothing_params = {'wind':5})_ 126 | #### Detailed Description: 127 | This function creates the reconstruction 128 | Inputs: 129 | latent_dyn = the ground truth latent dynamics 130 | coefficients = the operators coefficients ({$c(t)_i}) 131 | F = a list of transport operators (a list with M transport operators, 132 | each is a square matrix, kXk, where k is the latent dynamics 133 | dimension ) 134 | type_find = 'median' 135 | min_far = 10 136 | smooth_coeffs= False 137 | smoothing_params = {'wind':5} 138 | 139 | Outputs: 140 | cur_reco = dLDS reconstruction of the latent dynamics 141 | 142 | 143 | 144 | ### 4. visualize_dyn: 145 | _visualization of a dynamics, with various coloring options_ 146 | 147 | 148 | **visualize_dyn**_(dyn,ax = [], params_plot = {},turn_off_back = False, marker_size = 10, include_line = False, 149 | color_sig = [],cmap = 'cool', return_fig = False, color_by_dominant = False, coefficients =[], 150 | figsize = (5,5),colorbar = False, colors = [],vmin = None,vmax = None, color_mix = False, alpha = 0.4, 151 | colors_dyns = np.array(['r','g','b','yellow']), add_text = 't ', text_points = [],fontsize_times = 18, 152 | marker = "o",delta_text = 0.5, color_for_0 =None, legend = [],fig = [],return_mappable = False)_ 153 | #### Detailed Description: 154 | Inputs: 155 | dyn = dynamics to plot. Should be a np.array with size k X T 156 | ax = the subplot to plot in. (optional). If empty list -> the function will create a subplot 157 | params_plot = additional parameters for the plotting (optional). Can include plotting-related keys like xlabel, ylabel, title, etc. 158 | turn_off_back= disable backgroud of the plot? (optional). Boolean 159 | marker_size = marker size of the plot (optional). Integer 160 | include_line = add a curve to the plot (in addition to the scatter plot). Boolean 161 | color_sig = the color signal. 162 | If empty and color_by_dominant is true - color by the dominant dynamics. 163 | If empty and not color_by_dominant - color by time. 164 | cmap = color map 165 | colors = if not empty -> pre-defined colors for the different sub-dynamics. 166 | If empty -> colors are according to the cmap. 167 | color_mix = relevant only if color_by_dominant is True. 
In this case the colors need to be in the form of [r,g,b] 168 | Output: 169 | h (only if return_fig) -> returns the figure 170 | 171 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE.txt 2 | README.md 3 | pyproject.toml 4 | setup.py 5 | dLDS_discrete.egg-info/PKG-INFO 6 | dLDS_discrete.egg-info/SOURCES.txt 7 | dLDS_discrete.egg-info/dependency_links.txt 8 | dLDS_discrete.egg-info/requires.txt 9 | dLDS_discrete.egg-info/top_level.txt 10 | dlds_discrete/__init__.py 11 | dlds_discrete/create_params.py 12 | dlds_discrete/main_functions.py 13 | dlds_discrete/train_discrete_model_example.py 14 | dlds_discrete/test/__init__.py 15 | dlds_discrete/test/main_functions.py -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | scipy 4 | scipy 5 | pandas 6 | webcolors 7 | seaborn 8 | colormap 9 | sklearn 10 | pylops 11 | dill 12 | mat73 13 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | dlds_discrete 2 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete2.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: dLDS-discrete2 3 | Version: 0.0.2 4 | Summary: dLDS discrete model package 5 | Author: noga mudrik 6 | Author-email: 7 | Classifier: Programming Language :: Python :: 3 8 | Classifier: License :: OSI Approved :: MIT License 9 | Classifier: Operating System :: OS Independent 10 | Requires-Python: >=3.6 11 | Description-Content-Type: text/markdown 12 | License-File: LICENSE.txt 13 | 14 | **DISCRETE MODEL VISUALZATIONS** 15 | 16 | 17 | **The discrete model described in:** _Noga Mudrik*, Yenho Chen*, Eva Yezerets, Christopher Rozell, Adam Charles. "Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics". 2022_ 18 | 19 | 20 | Learning interpretable representations of neural dynamics at a population level is 21 | a crucial first step to understanding how neural activity patterns over time relate 22 | to perception and behavior. Models of neural dynamics often focus on either 23 | low-dimensional projections of neural activity, or on learning dynamical systems 24 | that explicitly relate to the neural state over time. We discuss how these two 25 | approaches are interrelated by considering dynamical systems as representative of 26 | flows on a low-dimensional manifold. Building on this concept, we propose a new 27 | decomposed dynamical system model that represents complex nonstationary and 28 | nonlinear dynamics of time-series data as a sparse combination of simpler, more 29 | interpretable components. 
The decomposed nature of the dynamics generalizes 30 | over previous switched approaches and enables modeling of overlapping and 31 | non-stationary drifts in the dynamics. We further present a dictionary learning- 32 | driven approach to model fitting, where we leverage recent results in tracking sparse 33 | vectors over time. We demonstrate that our model can learn efficient representations 34 | and smoothly transition between dynamical modes in both continuous-time and 35 | discrete-time examples. We show results on low-dimensional linear and nonlinear 36 | attractors to demonstrate that decomposed systems can well approximate nonlinear 37 | dynamics. Additionally, we apply our model to C. elegans data, illustrating a 38 | diversity of dynamics that is obscured when classified into discrete states. 39 | 40 | 41 | ## Main Useful Functions: 42 | 43 | ### 1. create_dynamics: 44 | _create sample dynamics_ 45 | 46 | 47 | 48 | **create_dynamics**_(type_dyn = 'cyl', max_time = 1000, dt = 0.01, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2})_ 49 | 50 | #### Detailed Description: 51 | Create ground truth dynamics. 52 | Inputs: 53 | type_dyn = Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 54 | max_time = integer. Number of time points for the dynamics. Relevant only if data is empty; 55 | dt = time interval for the dynamics. 56 | params_ex = dictionary of parameters for the dynamics. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}): 57 | 58 | 59 | Outputs: 60 | dynamics: k X T matrix of the dynamics 61 | 62 | 63 | 64 | 65 | ### 2. train_model_include_D: 66 | _main function to train the model._ 67 | 68 | **train_model_include_D**_(max_time = 500, dt = 0.1, dynamics_type = 'cyl',num_subdyns = 3, 69 | error_reco = np.inf, data = [], step_f = 30, GD_decay = 0.85, max_error = 1e-3, 70 | max_iter = 3000, F = [], coefficients = [], params= {'update_c_type':'inv','reg_term':0,'smooth_term':0}, 71 | epsilon_error_change = 10**(-5), D = [], x_former =[], latent_dim = None, include_D = False,step_D = 30, reg1=0,reg_f =0 , 72 | max_data_reco = 1e-3, sigma_mix_f = 0.1, action_along_time = 'median', to_print = True, seed = 0, seed_f = 0, 73 | normalize_eig = True, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}, 74 | init_distant_F = False,max_corr = 0.1, decaying_reg = 1, other_params_c={}, include_last_up = False)_ 75 | 76 | #### Detailed Description: 77 | max_time = Number of time points for the dynamics. Relevant only if data is empty; 78 | dt = time interval for the dynamics 79 | dynamics_type = type of the dynamics. Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 80 | num_subdyns = number of sub-dynamics 81 | error_reco = intial error for the reconstruction (do not touch) 82 | data = if one wants to use a pre define groud-truth dynamics. If not empty - it overwrites max_time, dt, and dynamics_type 83 | step_f = initial step size for GD on the sub-dynamics 84 | GD_decay = Gradient descent decay rate 85 | max_error = Threshold for the model error. If the model arrives at a lower reconstruction error - the training ends. 86 | max_iter = # of max. iterations for training the model 87 | F = pre-defined sub-dynamics. Keep empty if random. 88 | coefficients = pre-defined coefficients. Keep empty if random. 89 | params = dictionary that includes info about the regularization and coefficients solver. e.g. 
{'update_c_type':'inv','reg_term':0,'smooth_term':0} 90 | epsilon_error_change = check if the sub-dynamics do not change by at least epsilon_error_change, for at least 5 last iterations. Otherwise - add noise to f 91 | D = pre-defined D matrix (keep empty if D = I) 92 | latent_dim = If D != I, it is the pre-defined latent dynamics. 93 | include_D = If True -> D !=I; If False -> D = I 94 | step_D = GD step for updating D, only if include_D is true 95 | reg1 = if include_D is true -> L1 regularization on D 96 | reg_f = if include_D is true -> Frobenius norm regularization on D 97 | max_data_reco = if include_D is true -> threshold for the error on the reconstruction of the data (continue training if the error (y - Dx)^2 > max_data_reco) 98 | sigma_mix_f = std of noise added to mix f 99 | action_along_time = the function to take on the error over time. Can be 'median' or 'mean' 100 | to_print = to print error value while training? (boolean) 101 | seed = random seed 102 | seed_f = random seed for initializing f 103 | normalize_eig = whether to normalize each sub-dynamic by dividing by the highest abs eval 104 | params_ex = parameters related to the creation of the ground truth dynamics. e.g. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2} 105 | init_distant_F = when initializing F -> make sure that the correlation between each pair of {f}_i does not exeed a threshold 106 | max_corr = max correlation between each pair of initial sub-dyns (relevant only if init_distant_F is True) 107 | decaying_reg = decaying factor for the l1 regularization on the coefficients. If 1 - there is no decay. (should be a scalar in (0,1]) 108 | other_params_c = additional parameters for the update step of c 109 | include_last_up = add another update step of the coefficients at the end 110 | 111 | * example call (for Lorenz, w. 3 operators): train_model_include_D(10, 0.01, 'lorenz', 3, GD_decay = 0.99) 112 | 113 | 114 | 115 | ### 3. create_reco: 116 | _create the dynamics reconstruction using the operators and coefficients obtained by dLDS (F, c)._ 117 | 118 | 119 | **create_reco**_(latent_dyn,coefficients, F, type_find = 'median',min_far =10, smooth_coeffs = False, 120 | smoothing_params = {'wind':5})_ 121 | #### Detailed Description: 122 | This function creates the reconstruction 123 | Inputs: 124 | latent_dyn = the ground truth latent dynamics 125 | coefficients = the operators coefficients ({$c(t)_i}) 126 | F = a list of transport operators (a list with M transport operators, 127 | each is a square matrix, kXk, where k is the latent dynamics 128 | dimension ) 129 | type_find = 'median' 130 | min_far = 10 131 | smooth_coeffs= False 132 | smoothing_params = {'wind':5} 133 | 134 | Outputs: 135 | cur_reco = dLDS reconstruction of the latent dynamics 136 | 137 | 138 | 139 | ### 4. visualize_dyn: 140 | _visualization of a dynamics, with various coloring options_ 141 | 142 | 143 | **visualize_dyn**_(dyn,ax = [], params_plot = {},turn_off_back = False, marker_size = 10, include_line = False, 144 | color_sig = [],cmap = 'cool', return_fig = False, color_by_dominant = False, coefficients =[], 145 | figsize = (5,5),colorbar = False, colors = [],vmin = None,vmax = None, color_mix = False, alpha = 0.4, 146 | colors_dyns = np.array(['r','g','b','yellow']), add_text = 't ', text_points = [],fontsize_times = 18, 147 | marker = "o",delta_text = 0.5, color_for_0 =None, legend = [],fig = [],return_mappable = False)_ 148 | #### Detailed Description: 149 | Inputs: 150 | dyn = dynamics to plot. 
Should be a np.array with size k X T 151 | ax = the subplot to plot in. (optional). If empty list -> the function will create a subplot 152 | params_plot = additional parameters for the plotting (optional). Can include plotting-related keys like xlabel, ylabel, title, etc. 153 | turn_off_back= disable backgroud of the plot? (optional). Boolean 154 | marker_size = marker size of the plot (optional). Integer 155 | include_line = add a curve to the plot (in addition to the scatter plot). Boolean 156 | color_sig = the color signal. 157 | If empty and color_by_dominant is true - color by the dominant dynamics. 158 | If empty and not color_by_dominant - color by time. 159 | cmap = color map 160 | colors = if not empty -> pre-defined colors for the different sub-dynamics. 161 | If empty -> colors are according to the cmap. 162 | color_mix = relevant only if color_by_dominant is True. In this case the colors need to be in the form of [r,g,b] 163 | Output: 164 | h (only if return_fig) -> returns the figure 165 | 166 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete2.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE.txt 2 | README.md 3 | pyproject.toml 4 | setup.py 5 | dLDS_discrete2.egg-info/PKG-INFO 6 | dLDS_discrete2.egg-info/SOURCES.txt 7 | dLDS_discrete2.egg-info/dependency_links.txt 8 | dLDS_discrete2.egg-info/top_level.txt 9 | dlds_discrete/__init__.py 10 | dlds_discrete/create_params.py 11 | dlds_discrete/main_functions.py 12 | dlds_discrete/train_discrete_model_example.py 13 | dlds_discrete/test/__init__.py 14 | dlds_discrete/test/main_functions.py -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete2.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete2.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | dlds_discrete 2 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete_2022.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: dLDS-discrete-2022 3 | Version: 0.1.1 4 | Summary: dLDS discrete model package 5 | Author: noga mudrik 6 | Author-email: 7 | Classifier: Programming Language :: Python :: 3 8 | Classifier: License :: OSI Approved :: MIT License 9 | Classifier: Operating System :: OS Independent 10 | Requires-Python: >=3.8 11 | Description-Content-Type: text/markdown 12 | License-File: LICENSE.txt 13 | 14 | **The discrete model described in:** _Noga Mudrik*, Yenho Chen*, Eva Yezerets, Christopher Rozell, Adam Charles. "Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics". 2022_ 15 | 16 | 17 | Learning interpretable representations of neural dynamics at a population level is 18 | a crucial first step to understanding how neural activity patterns over time relate 19 | to perception and behavior. Models of neural dynamics often focus on either 20 | low-dimensional projections of neural activity, or on learning dynamical systems 21 | that explicitly relate to the neural state over time. 
We discuss how these two 22 | approaches are interrelated by considering dynamical systems as representative of 23 | flows on a low-dimensional manifold. Building on this concept, we propose a new 24 | decomposed dynamical system model that represents complex nonstationary and 25 | nonlinear dynamics of time-series data as a sparse combination of simpler, more 26 | interpretable components. The decomposed nature of the dynamics generalizes 27 | over previous switched approaches and enables modeling of overlapping and 28 | non-stationary drifts in the dynamics. We further present a dictionary learning- 29 | driven approach to model fitting, where we leverage recent results in tracking sparse 30 | vectors over time. We demonstrate that our model can learn efficient representations 31 | and smoothly transition between dynamical modes in both continuous-time and 32 | discrete-time examples. We show results on low-dimensional linear and nonlinear 33 | attractors to demonstrate that decomposed systems can well approximate nonlinear 34 | dynamics. Additionally, we apply our model to C. elegans data, illustrating a 35 | diversity of dynamics that is obscured when classified into discrete states. 36 | 37 | # Installation Instructions: 38 | 1. (if itertools not installed): sudo pip3 install more-itertools [in the cmd] 39 | 2. (if pickle not installed): pip install pickle-mixin [in the cmd] 40 | 3. !pip install dLDS-discrete [in the cmd] 41 | 4. from dlds_discrete import main_functions [in Python console] 42 | 5. from dlds_discrete.main_functions import * [in Python console] 43 | 6. Use any function from the ones described below 44 | 45 | 46 | 47 | ## Main Useful Functions: 48 | 49 | ### 1. create_dynamics: 50 | _create sample dynamics_ 51 | 52 | 53 | 54 | **create_dynamics**_(type_dyn = 'cyl', max_time = 1000, dt = 0.01, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2})_ 55 | 56 | #### Detailed Description: 57 | Create ground truth dynamics. 58 | Inputs: 59 | type_dyn = Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 60 | max_time = integer. Number of time points for the dynamics. Relevant only if data is empty; 61 | dt = time interval for the dynamics. 62 | params_ex = dictionary of parameters for the dynamics. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}): 63 | 64 | 65 | Outputs: 66 | dynamics: k X T matrix of the dynamics 67 | 68 | 69 | 70 | 71 | ### 2. train_model_include_D: 72 | _main function to train the model._ 73 | 74 | **train_model_include_D**_(max_time = 500, dt = 0.1, dynamics_type = 'cyl',num_subdyns = 3, 75 | error_reco = np.inf, data = [], step_f = 30, GD_decay = 0.85, max_error = 1e-3, 76 | max_iter = 3000, F = [], coefficients = [], params= {'update_c_type':'inv','reg_term':0,'smooth_term':0}, 77 | epsilon_error_change = 10**(-5), D = [], x_former =[], latent_dim = None, include_D = False,step_D = 30, reg1=0,reg_f =0 , 78 | max_data_reco = 1e-3, sigma_mix_f = 0.1, action_along_time = 'median', to_print = True, seed = 0, seed_f = 0, 79 | normalize_eig = True, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}, 80 | init_distant_F = False,max_corr = 0.1, decaying_reg = 1, other_params_c={}, include_last_up = False)_ 81 | 82 | #### Detailed Description: 83 | max_time = Number of time points for the dynamics. Relevant only if data is empty; 84 | dt = time interval for the dynamics 85 | dynamics_type = type of the dynamics. 
Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 86 | num_subdyns = number of sub-dynamics 87 | error_reco = intial error for the reconstruction (do not touch) 88 | data = if one wants to use a pre define groud-truth dynamics. If not empty - it overwrites max_time, dt, and dynamics_type 89 | step_f = initial step size for GD on the sub-dynamics 90 | GD_decay = Gradient descent decay rate 91 | max_error = Threshold for the model error. If the model arrives at a lower reconstruction error - the training ends. 92 | max_iter = # of max. iterations for training the model 93 | F = pre-defined sub-dynamics. Keep empty if random. 94 | coefficients = pre-defined coefficients. Keep empty if random. 95 | params = dictionary that includes info about the regularization and coefficients solver. e.g. {'update_c_type':'inv','reg_term':0,'smooth_term':0} 96 | epsilon_error_change = check if the sub-dynamics do not change by at least epsilon_error_change, for at least 5 last iterations. Otherwise - add noise to f 97 | D = pre-defined D matrix (keep empty if D = I) 98 | latent_dim = If D != I, it is the pre-defined latent dynamics. 99 | include_D = If True -> D !=I; If False -> D = I 100 | step_D = GD step for updating D, only if include_D is true 101 | reg1 = if include_D is true -> L1 regularization on D 102 | reg_f = if include_D is true -> Frobenius norm regularization on D 103 | max_data_reco = if include_D is true -> threshold for the error on the reconstruction of the data (continue training if the error (y - Dx)^2 > max_data_reco) 104 | sigma_mix_f = std of noise added to mix f 105 | action_along_time = the function to take on the error over time. Can be 'median' or 'mean' 106 | to_print = to print error value while training? (boolean) 107 | seed = random seed 108 | seed_f = random seed for initializing f 109 | normalize_eig = whether to normalize each sub-dynamic by dividing by the highest abs eval 110 | params_ex = parameters related to the creation of the ground truth dynamics. e.g. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2} 111 | init_distant_F = when initializing F -> make sure that the correlation between each pair of {f}_i does not exeed a threshold 112 | max_corr = max correlation between each pair of initial sub-dyns (relevant only if init_distant_F is True) 113 | decaying_reg = decaying factor for the l1 regularization on the coefficients. If 1 - there is no decay. (should be a scalar in (0,1]) 114 | other_params_c = additional parameters for the update step of c 115 | include_last_up = add another update step of the coefficients at the end 116 | 117 | * example call (for Lorenz, w. 3 operators): train_model_include_D(10, 0.01, 'lorenz', 3, GD_decay = 0.99) 118 | 119 | 120 | 121 | ### 3. 
create_reco: 122 | _create the dynamics reconstruction using the operators and coefficients obtained by dLDS (F, c)._ 123 | 124 | 125 | **create_reco**_(latent_dyn,coefficients, F, type_find = 'median',min_far =10, smooth_coeffs = False, 126 | smoothing_params = {'wind':5})_ 127 | #### Detailed Description: 128 | This function creates the reconstruction 129 | Inputs: 130 | latent_dyn = the ground truth latent dynamics 131 | coefficients = the operators coefficients ({$c(t)_i}) 132 | F = a list of transport operators (a list with M transport operators, 133 | each is a square matrix, kXk, where k is the latent dynamics 134 | dimension ) 135 | type_find = 'median' 136 | min_far = 10 137 | smooth_coeffs= False 138 | smoothing_params = {'wind':5} 139 | 140 | Outputs: 141 | cur_reco = dLDS reconstruction of the latent dynamics 142 | 143 | 144 | 145 | ### 4. visualize_dyn: 146 | _visualization of a dynamics, with various coloring options_ 147 | 148 | 149 | **visualize_dyn**_(dyn,ax = [], params_plot = {},turn_off_back = False, marker_size = 10, include_line = False, 150 | color_sig = [],cmap = 'cool', return_fig = False, color_by_dominant = False, coefficients =[], 151 | figsize = (5,5),colorbar = False, colors = [],vmin = None,vmax = None, color_mix = False, alpha = 0.4, 152 | colors_dyns = np.array(['r','g','b','yellow']), add_text = 't ', text_points = [],fontsize_times = 18, 153 | marker = "o",delta_text = 0.5, color_for_0 =None, legend = [],fig = [],return_mappable = False)_ 154 | #### Detailed Description: 155 | Inputs: 156 | dyn = dynamics to plot. Should be a np.array with size k X T 157 | ax = the subplot to plot in. (optional). If empty list -> the function will create a subplot 158 | params_plot = additional parameters for the plotting (optional). Can include plotting-related keys like xlabel, ylabel, title, etc. 159 | turn_off_back= disable backgroud of the plot? (optional). Boolean 160 | marker_size = marker size of the plot (optional). Integer 161 | include_line = add a curve to the plot (in addition to the scatter plot). Boolean 162 | color_sig = the color signal. 163 | If empty and color_by_dominant is true - color by the dominant dynamics. 164 | If empty and not color_by_dominant - color by time. 165 | cmap = color map 166 | colors = if not empty -> pre-defined colors for the different sub-dynamics. 167 | If empty -> colors are according to the cmap. 168 | color_mix = relevant only if color_by_dominant is True. 
In this case the colors need to be in the form of [r,g,b] 169 | Output: 170 | h (only if return_fig) -> returns the figure 171 | 172 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete_2022.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE.txt 2 | README.md 3 | pyproject.toml 4 | setup.py 5 | dLDS_discrete_2022.egg-info/PKG-INFO 6 | dLDS_discrete_2022.egg-info/SOURCES.txt 7 | dLDS_discrete_2022.egg-info/dependency_links.txt 8 | dLDS_discrete_2022.egg-info/requires.txt 9 | dLDS_discrete_2022.egg-info/top_level.txt 10 | dlds_discrete/__init__.py 11 | dlds_discrete/create_params.py 12 | dlds_discrete/main_functions.py 13 | dlds_discrete/train_discrete_model_example.py 14 | dlds_discrete/test/__init__.py 15 | dlds_discrete/test/main_functions.py -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete_2022.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete_2022.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | scipy 4 | scipy 5 | pandas 6 | webcolors 7 | seaborn 8 | colormap 9 | sklearn 10 | pylops 11 | dill 12 | mat73 13 | easydev 14 | -------------------------------------------------------------------------------- /dLDS_discrete/dLDS_discrete_2022.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | dlds_discrete 2 | -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.1-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.1-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.1.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.1.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.11-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.11-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.11.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.11.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.2-py3-none-any.whl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.2-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.2.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.3-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.3-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.3.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.3.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.4-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.4-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.4.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.4.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.5-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.5-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.5.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.5.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.6-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.6-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.6.tar.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.6.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.7-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.7-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.7.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.7.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.8-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.8-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.8.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.8.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.9-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.9-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.9.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.9.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.91-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.91-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete-0.0.91.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete-0.0.91.tar.gz -------------------------------------------------------------------------------- 
/dLDS_discrete/dist/dLDS_discrete2-0.0.2-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete2-0.0.2-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete2-0.0.2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete2-0.0.2.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete_2022-0.1.1-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete_2022-0.1.1-py3-none-any.whl -------------------------------------------------------------------------------- /dLDS_discrete/dist/dLDS_discrete_2022-0.1.1.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dist/dLDS_discrete_2022-0.1.1.tar.gz -------------------------------------------------------------------------------- /dLDS_discrete/dlds_discrete/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon May 30 08:49:16 2022 4 | 5 | @author: noga mudrik 6 | """ 7 | import os 8 | 9 | #os.chdir('.dlds_discrete') 10 | #from main_functions import * 11 | -------------------------------------------------------------------------------- /dLDS_discrete/dlds_discrete/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dlds_discrete/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /dLDS_discrete/dlds_discrete/__pycache__/main_functions.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/dLDS_discrete/dlds_discrete/__pycache__/main_functions.cpython-310.pyc -------------------------------------------------------------------------------- /dLDS_discrete/dlds_discrete/create_params.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics 3 | @code author: noga mudrik 4 | """ 5 | 6 | 7 | #%% Call Main Function 8 | 9 | 10 | from main_functions import * 11 | 12 | update_c_type = 'inv' # can be also lasso or smooth 13 | step_f = 30 14 | step_f_original = step_f 15 | num_subdyns = 3; 16 | dynamic_type ='cyl' 17 | max_time = 50 18 | dt = 0.1 19 | noise = 0 20 | speed_change = 0 21 | is_D_I = True 22 | reg_term = 0.01 23 | max_iter = 6000 24 | num_iter = 
max_iter 25 | GD_decay = 0.99 26 | max_error = 1e-8 27 | max_data_reco = max_error 28 | acumulated_error = False 29 | error_order_initial = 2 30 | error_step_add = 120 31 | error_step_max = 20 32 | error_step_max_display = error_step_max -2 #+ 2 33 | error_step_add = 120 34 | epsilon_error_change= 10**(-8) 35 | sigma_mix_f = 0.1 36 | weights_orders = np.linspace(1,2**error_step_max,error_step_max)[::-1] 37 | grad_vec = np.linspace(0.99,1.01,error_step_max) 38 | 39 | action_along_time = 'median' #can be mean 40 | 41 | latent_dyn = create_dynamics(type_dyn = dynamic_type, max_time = max_time, dt = dt) 42 | F = [init_mat((latent_dyn.shape[0], latent_dyn.shape[0]),normalize=True) for i in range(num_subdyns)] 43 | coefficients = init_mat((num_subdyns,latent_dyn.shape[1]-1)) 44 | include_D = False 45 | 46 | error_reco = np.inf 47 | error_reco_all = np.inf*np.ones((1,error_step_max)) 48 | cur_reco = create_reco(latent_dyn=latent_dyn, coefficients= coefficients, F=F) 49 | error_reco_array = np.inf*np.ones((1,max(error_step_max_display,error_step_max))) 50 | error_reco_array_med = np.inf*np.ones((1,max(error_step_max_display,error_step_max))) 51 | error_order = error_step_max 52 | 53 | seed_f = 0 54 | same_c = False 55 | data = [] 56 | 57 | 58 | 59 | dyn_radius = 5 60 | dyn_num_cyls = 5 61 | dyn_bias = 0 62 | addition_save = [] 63 | to_print = True 64 | name_auto = False 65 | to_save_without_ask = False 66 | ylim_small = [-15,15] 67 | 68 | seed_f = 0 69 | data =[] 70 | 71 | reg_term =0 72 | smooth_term = 0 73 | noise_max = 1 74 | max_error = 1e-8 75 | include_D = False 76 | exp_power = 0.1 77 | 78 | start_from_c = False 79 | 80 | start_sparse_c = False 81 | init_distant_F = False 82 | max_corr = 0.1 83 | decaying_reg = 0.999 84 | normalize_eig = True 85 | params_ex = {'radius':dyn_radius , 'num_cyls': dyn_num_cyls, 'bias':dyn_bias,'exp_power':exp_power} 86 | 87 | bias_term = False 88 | center_dynamics = False 89 | bias_out = False 90 | noise_interval = 0.1 91 | 92 | 93 | 94 | params_update_c = {'reg_term': reg_term, 'update_c_type':update_c_type,'smooth_term':smooth_term} 95 | 96 | 97 | 98 | 99 | width_des = 0.6 100 | t_break = 280 101 | factor_power = 0.3 102 | quarter_initial = 'low' 103 | smooth_window = 3 104 | colors = [[1,0.1,0],[0.1,0,1], [0,1,0]] 105 | s_scatter = 200 106 | start_run = 30 107 | plot_movemean = False 108 | max_time_plot = 500 109 | colors_dyn = ['r','g','b'] 110 | n_samples_range = np.arange(5,106,20) 111 | num_noises = 5 112 | max_dyn =5 113 | 114 | init_distant_F = True 115 | step_f = 50 116 | epsilon_error_change= 10**(-8) 117 | 118 | 119 | action_along_time = 'median' #can also be 'mean' 120 | 121 | error_reco = np.inf 122 | 123 | error_reco_array = [] -------------------------------------------------------------------------------- /dLDS_discrete/dlds_discrete/main_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics 3 | @code author: noga mudrik 4 | """ 5 | 6 | """ 7 | Imports 8 | """ 9 | 10 | # simaple imports 11 | import matplotlib 12 | import numpy as np 13 | from scipy import linalg 14 | import pandas as pd 15 | #import random 16 | 17 | # Plotting imports 18 | from webcolors import name_to_rgb 19 | import matplotlib.pyplot as plt 20 | import itertools 21 | import seaborn as sns 22 | from colormap import rgb2hex 23 | 24 | # Linear algebra imports 25 | from numpy.linalg import matrix_power 26 | from scipy.linalg import 
expm 27 | from sklearn import linear_model 28 | try: 29 | import pylops 30 | except: 31 | print('itertools was not uploaded') 32 | import itertools 33 | 34 | # os and files loading imports 35 | import os 36 | import dill 37 | import mat73 38 | import warnings 39 | import pickle 40 | sep = os.sep 41 | 42 | 43 | #%% FHN model 44 | 45 | 46 | def create_FHN(dt = 0.01, max_t = 100, I_ext = 0.5, b = 0.7, a = 0.8 , tau = 20, v0 = -0.5, w0 = 0, 47 | params = {'exp_power' : 0.9, 'change_speed': False}): 48 | time_points = np.arange(0, max_t, dt) 49 | if params['change_speed']: 50 | time_points = time_points**params['exp_power'] 51 | 52 | w_full = [] 53 | v_full = [] 54 | v = v0 55 | w = w0 56 | for t in time_points: 57 | v, w = cal_next_FHN(v,w, dt , max_t , I_ext , b, a , tau) 58 | v_full.append(v) 59 | w_full.append(w) 60 | return v_full, w_full 61 | 62 | 63 | 64 | def cal_next_FHN(v,w, dt = 0.01, max_t = 300, I_ext = 0.5, b = 0.7, a = 0.8 , tau = 20) : 65 | v_next = v + dt*(v - (v**3)/3 - w + I_ext) 66 | w_next = w + dt/tau*(v + a - b*w) 67 | return v_next, w_next 68 | 69 | #%% Lorenz attractor dynamics definition 70 | 71 | def lorenz(x, y, z, s=10, r=25, b=2.667): 72 | """ 73 | Inputs: 74 | x, y, z: a point of interest in three dimensional space 75 | s, r, b: parameters defining the lorenz attractor 76 | Outputs: 77 | x_dot, y_dot, z_dot: values of the lorenz attractor's partial 78 | derivatives at the point x, y, z 79 | """ 80 | x_dot = s*(y - x) 81 | y_dot = r*x - y - x*z 82 | z_dot = x*y - b*z 83 | return x_dot, y_dot, z_dot 84 | 85 | def create_lorenz_mat(t = [], initial_conds = (0., 1., 1.05) , txy = []): 86 | """ 87 | Create the lorenz dynamics 88 | """ 89 | if len(t) == 0: t = np.arange(0,1000,0.01) 90 | if len(txy) == 0: txy = t 91 | 92 | xs = np.zeros(len(t)-1) 93 | ys = np.zeros(len(t)-1) 94 | zs = np.zeros(len(t)-1) 95 | 96 | # Set initial values 97 | xs[0], ys[0], zs[0] = initial_conds 98 | 99 | 100 | for i in range(len(t[:-2])): 101 | dt_z = t[i+1] - t[i] 102 | dt_xy = txy[i+1] - txy[i] 103 | x_dot, y_dot, z_dot = lorenz(xs[i], ys[i], zs[i]) 104 | xs[i + 1] = xs[i] + (x_dot * dt_xy) 105 | ys[i + 1] = ys[i] + (y_dot * dt_xy) 106 | zs[i + 1] = zs[i] + (z_dot * dt_z) 107 | return xs, ys, zs 108 | 109 | def load_mat_file(mat_name , mat_path = '',sep = sep): 110 | """ 111 | Function to load mat files. Useful for uploading the c. elegans data. 112 | Example: 113 | load_mat_file('WT_Stim.mat') 114 | """ 115 | data_dict = mat73.loadmat(mat_path+sep+mat_name) 116 | return data_dict 117 | 118 | def create_dynamics(type_dyn = 'cyl', max_time = 1000, dt = 0.01, 119 | params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}): 120 | """ 121 | Create ground truth dynamics. 122 | Inputs: 123 | type_dyn = Can be 'cyl', 'lorenz','FHN', 'multi_cyl', 'torus', 'circ2d', 'spiral' 124 | max_time = integer. Number of time points for the dynamics. Relevant only if data is empty; 125 | dt = time interval for the dynamics. 126 | params_ex = dictionary of parameters for the dynamics. 
{'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}): 127 | 128 | 129 | Outputs: 130 | dynamics: k X T matrix of the dynamics 131 | 132 | """ 133 | t = np.arange(0, max_time, dt) 134 | if type_dyn == 'cyl': 135 | x = params_ex['radius']*np.sin(t) 136 | y = params_ex['radius']*np.cos(t) 137 | z = t + params_ex['bias'] 138 | 139 | 140 | dynamics = np.vstack([x.flatten(),y.flatten(),z.flatten()]) 141 | elif type_dyn == 'spiral': 142 | x = t*np.sin(t) 143 | y = t*np.cos(t) 144 | z = t 145 | 146 | dynamics = np.vstack([x.flatten(),y.flatten(),z.flatten()]) 147 | elif type_dyn == 'lorenz': 148 | txy = t 149 | 150 | x,y,z = create_lorenz_mat(t, txy = txy) 151 | dynamics = np.vstack([x.flatten(),y.flatten(),z.flatten()]) 152 | elif type_dyn == 'torus': 153 | R=5; r=1; 154 | u=np.arange(0,max_time,dt); 155 | v=np.arange(0,max_time,dt); 156 | [u,v]=np.meshgrid(u,v); 157 | x=(R+r*np.cos(v)) @ np.cos(u); 158 | y=(R+r*np.cos(v)) @ np.sin(u); 159 | z=r*np.sin(v); 160 | dynamics = np.vstack([x.flatten(),y.flatten(),z.flatten()]) 161 | elif type_dyn == 'circ2d': 162 | x = params_ex['radius']*np.sin(t) 163 | y = params_ex['radius']*np.cos(t) 164 | dynamics = np.vstack([x.flatten(),y.flatten()]) 165 | elif type_dyn == 'multi_cyl': 166 | dynamics0 = create_dynamics('cyl',max_time = 50,params_ex = params_ex) 167 | list_dyns = [] 168 | for dyn_num in range(params_ex['num_cyls']): 169 | np.random.seed(dyn_num) 170 | random_trans = np.random.rand(dynamics0.shape[0],dynamics0.shape[0])-0.5 171 | transformed_dyn = random_trans @ dynamics0 172 | list_dyns.append(transformed_dyn) 173 | dynamics = np.hstack(list_dyns) 174 | elif type_dyn == 'c_elegans': 175 | mat_c_elegans = load_mat_file('WT_NoStim.mat','E:\CoDyS-Python-rep-\other_models') # 176 | dynamics = mat_c_elegans['WT_NoStim']['traces'].T 177 | elif type_dyn == 'lorenz_2d': 178 | txy = t 179 | 180 | x,y,z = create_lorenz_mat(t, txy = txy) 181 | dynamics = np.vstack([x.flatten(),z.flatten()]) 182 | elif type_dyn.lower() == 'fhn': 183 | v_full, w_full = create_FHN(dt = dt, max_t = max_time, I_ext = 0.5, 184 | b = 0.7, a = 0.8 , tau = 20, v0 = -0.5, w0 = 0, 185 | params = {'exp_power' : params_ex['exp_power'], 'change_speed': False}) 186 | 187 | dynamics = np.vstack([v_full, w_full]) 188 | return dynamics 189 | 190 | 191 | 192 | #%% Basic Model Functions 193 | #%% Main Model Training 194 | def train_model_include_D(max_time = 500, dt = 0.1, dynamics_type = 'cyl',num_subdyns = 3, 195 | error_reco = np.inf, data = [], step_f = 30, GD_decay = 0.85, max_error = 1e-3, 196 | max_iter = 3000, F = [], coefficients = [], params= {'update_c_type':'inv','reg_term':0,'smooth_term':0}, 197 | epsilon_error_change = 10**(-5), D = [], x_former =[], latent_dim = None, include_D = False,step_D = 30, reg1=0,reg_f =0 , 198 | max_data_reco = 1e-3, sigma_mix_f = 0.1, action_along_time = 'median', to_print = True, seed = 0, seed_f = 0, 199 | normalize_eig = True, params_ex = {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2}, 200 | init_distant_F = False,max_corr = 0.1, decaying_reg = 1, other_params_c={}, include_last_up = False): 201 | 202 | """ 203 | This is the main function to train the model! 204 | Inputs: 205 | max_time = Number of time points for the dynamics. Relevant only if data is empty; 206 | dt = time interval for the dynamics 207 | dynamics_type = type of the dynamics. 
Can be 'cyl', 'lorenz', 'multi_cyl', 'torus', 'circ2d', 'spiral' 208 | num_subdyns = number of sub-dynamics 209 | error_reco = intial error for the reconstruction (do not touch) 210 | data = if one wants to use a pre define groud-truth dynamics. If not empty - it overwrites max_time, dt, and dynamics_type 211 | step_f = initial step size for GD on the sub-dynamics 212 | GD_decay = Gradient descent decay rate 213 | max_error = Threshold for the model error. If the model arrives at a lower reconstruction error - the training ends. 214 | max_iter = # of max. iterations for training the model 215 | F = pre-defined sub-dynamics. Keep empty if random. 216 | coefficients = pre-defined coefficients. Keep empty if random. 217 | params = dictionary that includes info about the regularization and coefficients solver. e.g. {'update_c_type':'inv','reg_term':0,'smooth_term':0} 218 | epsilon_error_change = check if the sub-dynamics do not change by at least epsilon_error_change, for at least 5 last iterations. Otherwise - add noise to f 219 | D = pre-defined D matrix (keep empty if D = I) 220 | latent_dim = If D != I, it is the pre-defined latent dynamics. 221 | include_D = If True -> D !=I; If False -> D = I 222 | step_D = GD step for updating D, only if include_D is true 223 | reg1 = if include_D is true -> L1 regularization on D 224 | reg_f = if include_D is true -> Frobenius norm regularization on D 225 | max_data_reco = if include_D is true -> threshold for the error on the reconstruction of the data (continue training if the error (y - Dx)^2 > max_data_reco) 226 | sigma_mix_f = std of noise added to mix f 227 | action_along_time = the function to take on the error over time. Can be 'median' or 'mean' 228 | to_print = to print error value while training? (boolean) 229 | seed = random seed 230 | seed_f = random seed for initializing f 231 | normalize_eig = whether to normalize each sub-dynamic by dividing by the highest abs eval 232 | params_ex = parameters related to the creation of the ground truth dynamics. e.g. {'radius':1, 'num_cyls': 5, 'bias':0,'exp_power':0.2} 233 | init_distant_F = when initializing F -> make sure that the correlation between each pair of {f}_i does not exeed a threshold 234 | max_corr = max correlation between each pair of initial sub-dyns (relevant only if init_distant_F is True) 235 | decaying_reg = decaying factor for the l1 regularization on the coefficients. If 1 - there is no decay. 
(should be a scalar in (0,1]) 236 | other_params_c = additional parameters for the update step of c 237 | include_last_up = add another update step of the coefficients at the end 238 | """ 239 | # if D != I 240 | if not include_D and len(data) > 1: latent_dyn = data 241 | 242 | step_f_original = step_f 243 | 244 | # Define data and number of dyns 245 | if len(data) == 0 : 246 | data = create_dynamics(type_dyn = dynamics_type, max_time = max_time, dt = dt, params_ex = params_ex) 247 | if not include_D: latent_dyn = data 248 | else: 249 | if isinstance(data, np.ndarray) and len(data) > 1: 250 | if not include_D: latent_dyn = data 251 | else: 252 | if len(data) == 1: 253 | data = data[0] 254 | if not include_D: latent_dyn = data 255 | else: 256 | raise ValueError('The parameter "data" is invalid') 257 | 258 | # Define # of time points 259 | n_times = data.shape[1] 260 | 261 | # Default value for the latent dimension if D != I 262 | if include_D and np.isnan(latent_dim): 263 | latent_dim = int(np.max([data.shape[0] / 5,3])) 264 | else: latent_dim = data.shape[0]; 265 | 266 | if include_D: # If model needs to study D 267 | if len(D) == 0: D = init_mat(size_mat = (data.shape[0], latent_dim) , dist_type ='sparse', init_params={'k':4}) 268 | elif D.shape[0] != data.shape[0]: raise ValueError('# of rows in D should be = # rows in the data ') 269 | 270 | else: 271 | latent_dyn = data 272 | 273 | if len(F) == 0: 274 | F = [init_mat((latent_dim, latent_dim),normalize=True,r_seed = seed_f+i) for i in range(num_subdyns)] 275 | # Check that initial F's are far enough from each other 276 | if init_distant_F: 277 | F = check_F_dist_init(F, max_corr = max_corr) 278 | 279 | """ 280 | Initialize Coeffs 281 | """ 282 | if len(coefficients) == 0: 283 | coefficients = init_mat((num_subdyns,n_times-1)) 284 | if len(params) == 0: params = {'update_c_type':'inv','reg_term':0,'smooth_term':0} 285 | 286 | if not include_D: 287 | cur_reco = create_reco(latent_dyn=latent_dyn, coefficients= coefficients, F=F) 288 | 289 | 290 | data_reco_error = np.inf 291 | 292 | 293 | counter = 1 294 | 295 | error_reco_array = [] 296 | 297 | while data_reco_error > max_data_reco and (counter < max_iter): 298 | 299 | ### Store Iteration Results 300 | 301 | 302 | """ 303 | Update x 304 | """ 305 | 306 | if include_D: 307 | latent_dyn = update_X(D, data,random_state=seed) 308 | 309 | 310 | """ 311 | Decay reg 312 | """ 313 | if params['update_c_type'] == 'lasso': 314 | params['reg_term'] = params['reg_term']*decaying_reg 315 | 316 | """ 317 | Update coefficients 318 | """ 319 | if counter != 1: coefficients = update_c(F,latent_dyn, params,random_state=seed,other_params=other_params_c) 320 | 321 | 322 | 323 | """ 324 | Update D 325 | """ 326 | 327 | if include_D: 328 | one_dyn: D = update_D(D, step_D, latent_dyn, data, reg1,reg_f) 329 | 330 | 331 | 332 | """ 333 | Update F 334 | """ 335 | 336 | F = update_f_all(latent_dyn,F,coefficients,step_f,normalize=False, action_along_time= action_along_time, normalize_eig = normalize_eig ) 337 | 338 | step_f *= GD_decay 339 | 340 | if include_D: 341 | data_reco_error = np.mean((data - D @ latent_dyn)**2) 342 | mid_reco = create_reco(latent_dyn, coefficients, F) 343 | error_reco = np.mean((latent_dyn -mid_reco)**2) 344 | 345 | error_reco_array.append(error_reco) 346 | 347 | if np.mean(np.abs(np.diff(error_reco_array[-5:]))) < epsilon_error_change: 348 | F = [f_i + sigma_mix_f*np.random.randn(f_i.shape[0],f_i.shape[1]) for f_i in F] 349 | print('mixed F') 350 | 351 | if to_print: 352 | print('Error = 
%s'%str(error_reco) ) 353 | if include_D: print('Error reco y = %s'%str(data_reco_error)) 354 | 355 | counter += 1 356 | if counter == max_iter: print('Arrived to max iter') 357 | 358 | 359 | # Post training adjustments 360 | if include_last_up: 361 | coefficients = update_c(F, latent_dyn,params, {'reg_term': 0, 'update_c_type':'inv','smooth_term' :0, 'num_iters': 10, 'threshkind':'soft'}) 362 | else: 363 | coefficients = update_c(F, latent_dyn, params,other_params=other_params_c) 364 | 365 | print(error_reco_array) 366 | if not include_D: 367 | D = []; 368 | return coefficients, F, latent_dyn, error_reco_array, D 369 | 370 | 371 | def update_D(former_D, step_D , x, y, reg1 = 0, reg_f= 0) : 372 | """ 373 | Update the matrix D by applying GD. Relevant just in case where D != I 374 | """ 375 | 376 | if reg1 == 0 and reg_f ==0: 377 | D = y @ linalg.pinv(x) 378 | else: 379 | basic_error = -2*(y - former_D @ x ) @ x.T 380 | if reg1 != 0: reg1_error = np.sum(np.sign(former_D)) 381 | else: reg1_error = 0 382 | if reg_f != 0: reg_f_error = 2*former_D 383 | reg_f_error = 0 384 | D = former_D - step_D *(basic_error + reg1*reg1_error + reg_f* reg_f_error) 385 | return D 386 | 387 | def update_X(D, data, reg1 = 0, former_x = [], random_state = 0, other_params ={}): 388 | """ 389 | Update the latent dynamics. Relevant just in case where D != I 390 | """ 391 | if reg1 == 0 : 392 | x = linalg.pinv(D) @ data 393 | else: 394 | clf = linear_model.Lasso(alpha=reg1,random_state=random_state, **other_params) 395 | clf.fit(D,data) 396 | x = np.array(clf.coef_) 397 | return x 398 | 399 | def check_F_dist_init(F, max_corr = 0.1): 400 | """ 401 | This function aims to validate that the matrices in F are far enough from each other 402 | """ 403 | combs = list(itertools.combinations(np.arange(len(F)),2)) 404 | corr_bool = [spec_corr(F[comb_s[0]],F[comb_s[1]]) > max_corr for comb_s in combs] 405 | counter= 100 406 | while (corr_bool == False).any(): 407 | counter +=1 408 | for comb_num,comb in enumerate(combs): 409 | if spec_corr(F[comb[0]],F[comb[1]]) > max_corr: 410 | fi_new = init_mat(np.shape(F[0]),dist_type = 'norm',r_seed = counter) 411 | F[comb[0]] = fi_new 412 | return F 413 | 414 | def spec_corr(v1,v2): 415 | """ 416 | absolute value of correlation 417 | """ 418 | corr = np.corrcoef(v1[:],v2[:]) 419 | return np.abs(corr[0,1]) 420 | 421 | 422 | def init_mat(size_mat, r_seed = 0, dist_type = 'norm', init_params = {'loc':0,'scale':1}, normalize = False): 423 | """ 424 | This is an initialization function to initialize matrices like G_i and c. 425 | Inputs: 426 | size_mat = 2-element tuple or list, describing the shape of the mat 427 | r_seed = random seed (should be integer) 428 | dist_type = distribution type for initialization; can be 'norm' (normal dist), 'uni' (uniform dist),'inti', 'sprase', 'regional' 429 | init_params = a dictionary with params for initialization. The keys depends on 'dist_type'. 
430 | keys for norm -> ['loc','scale'] 431 | keys for inti and uni -> ['low','high'] 432 | keys for sparse -> ['k'] -> number of non-zeros in each row 433 | keys for regional -> ['k'] -> repeats of the sub-dynamics allocations 434 | normalize = whether to normalize the matrix 435 | Output: 436 | the random matrix with size 'size_mat' 437 | """ 438 | np.random.seed(r_seed) 439 | random.seed(r_seed) 440 | if dist_type == 'norm': 441 | rand_mat = np.random.normal(loc=init_params['loc'],scale = init_params['scale'], size= size_mat) 442 | elif dist_type == 'uni': 443 | if 'high' not in init_params.keys() or 'low' not in init_params.keys(): 444 | raise KeyError('Initialization did not work since low or high boundries were not set') 445 | rand_mat = np.random.uniform(init_params['low'],init_params['high'], size= size_mat) 446 | elif dist_type == 'inti': 447 | if 'high' not in init_params.keys() or 'low' not in init_params.keys(): 448 | raise KeyError('Initialization did not work since low or high boundries were not set') 449 | rand_mat = np.random.randint(init_params['low'],init_params['high'], size= size_mat) 450 | elif dist_type == 'sparse': 451 | if 'k' not in init_params.keys(): 452 | raise KeyError('Initialization did not work since k was not set') 453 | 454 | k=init_params['k'] 455 | b1 = [random.sample(list(np.arange(size_mat[0])),np.random.randint(1,np.min([size_mat[0],k]))) for i in range(size_mat[1])] 456 | b2 = [[i]*len(el) for i,el in enumerate(b1)] 457 | rand_mat = np.zeros((size_mat[0], size_mat[1])) 458 | rand_mat[np.hstack(b1), np.hstack(b2)] = 1 459 | elif dist_type == 'regional': 460 | if 'k' not in init_params.keys(): 461 | raise KeyError('Initialization did not work since k was not set for regional initialization') 462 | 463 | k=init_params['k'] 464 | splits = [len(split) for split in np.split(np.arange(size_mat[1]),k)] 465 | cur_repeats = [np.repeat(np.eye(size_mat[0]), int(np.ceil(split_len/size_mat[0])),axis = 1) for split_len in splits] 466 | cur_repeats = np.hstack(cur_repeats)[:size_mat[1]] 467 | 468 | rand_mat = cur_repeats 469 | else: 470 | raise NameError('Unknown dist type!') 471 | if normalize: 472 | rand_mat = norm_mat(rand_mat) 473 | return rand_mat 474 | 475 | 476 | def norm_mat(mat, type_norm = 'evals', to_norm = True): 477 | """ 478 | This function comes to norm matrices. 479 | Inputs: 480 | mat = the matrix to norm 481 | type_norm = what type of normalization to apply. Can be: 482 | - 'evals' - normalize by dividing by the max eigen-value 483 | - 'max' - divide by the maximum abs value in the matrix 484 | - 'exp' - normalization using matrix exponential (matrix exponential) 485 | to_norm = whether to norm or not to. 486 | Output: 487 | the normalized matrix 488 | """ 489 | if to_norm: 490 | if type_norm == 'evals': 491 | eigenvalues, _ = linalg.eig(mat) 492 | mat = mat / np.max(np.abs(eigenvalues)) 493 | elif type_norm == 'max': 494 | mat = mat / np.max(np.abs(mat)) 495 | elif type_norm == 'exp': 496 | mat = np.exp(-np.trace(mat))*expm(mat) 497 | return mat 498 | 499 | 500 | def update_c(F, latent_dyn, 501 | params_update_c = {'update_c_type':'inv','reg_term':0,'smooth_term':0, 'to_norm_fx' : False},clear_dyn = [], 502 | direction = 'c2n',other_params = {'warm_start':False},random_state=0 , skip_error = False, cofficients = []): 503 | """ 504 | The function comes to update the coefficients of the sub-dynamics, {c_i}, by solving the inverse or solving lasso. 505 | Inputs: 506 | F = list of sub-dynamics. Should be a list of k X k arrays. 
507 | latent_dyn = latent_dynamics (dynamics dimensions X time) 508 | params_update_c = dictionary with keys: 509 | update_c_type = options: 510 | - 'inv' (least squares) 511 | - 'lasso' (sklearn lasso) 512 | - 'fista' (https://pylops.readthedocs.io/en/latest/api/generated/pylops.optimization.sparsity.FISTA.html) 513 | - 'omp' (https://pylops.readthedocs.io/en/latest/gallery/plot_ista.html#sphx-glr-gallery-plot-ista-py) 514 | - 'ista' (https://pylops.readthedocs.io/en/latest/api/generated/pylops.optimization.sparsity.ISTA.html) 515 | - 'IRLS' (https://pylops.readthedocs.io/en/latest/api/generated/pylops.optimization.sparsity.IRLS.html) 516 | - 'spgl1' (https://pylops.readthedocs.io/en/latest/api/generated/pylops.optimization.sparsity.SPGL1.html) 517 | 518 | 519 | - . Refers to the way the coefficients should be claculated (inv -> no l1 regularization) 520 | reg_term = scalar between 0 to 1, describe the reg. term on the cofficients 521 | smooth_term = scalar between 0 to 1, describe the smooth term on the cofficients (c_t - c_(t-1)) 522 | direction = can be c2n (clean to noise) OR n2c (noise to clean) 523 | other_params = additional parameters for the lasso solver (optional) 524 | random_state = random state for reproducability (optional) 525 | skip_error = whether to skip an error when solving the inverse for c (optional) 526 | cofficients = needed only if smooth_term > 0. This is the reference coefficients matrix to apply the constraint (c_hat_t - c_(t-1)) on. 527 | 528 | Outputs: 529 | coefficients matrix (k X T), type = np.array 530 | 531 | example: 532 | coeffs = update_c(np.random.rand(2,2), np.random.rand(2,15),{}) 533 | """ 534 | 535 | if isinstance(latent_dyn,list): 536 | if len(latent_dyn) == 1: several_dyns = False 537 | else: several_dyns = True 538 | else: 539 | several_dyns = False 540 | if several_dyns: 541 | n_times = latent_dyn[0].shape[1]-1 542 | else: 543 | n_times = latent_dyn.shape[1]-1 544 | 545 | 546 | params_update_c = {**{'update_c_type':'inv', 'smooth_term': 0, 'reg_term':0},**params_update_c} 547 | if len(clear_dyn) == 0: 548 | clear_dyn = latent_dyn 549 | if direction == 'n2c': 550 | latent_dyn, clear_dyn =clear_dyn, latent_dyn 551 | if isinstance(F,np.ndarray): F = [F] 552 | coeffs_list = [] 553 | 554 | 555 | for time_point in np.arange(n_times): 556 | if not several_dyns: 557 | cur_dyn = clear_dyn[:,time_point] 558 | next_dyn = latent_dyn[:,time_point+1] 559 | total_next_dyn = next_dyn 560 | f_x_mat = [] 561 | for f_i in F: 562 | f_x_mat.append(f_i @ cur_dyn) 563 | stacked_fx = np.vstack(f_x_mat).T 564 | stacked_fx[stacked_fx> 10**8] = 10**8 565 | else: 566 | total_next_dyn = [] 567 | for dyn_num in range(len(latent_dyn)): 568 | cur_dyn = clear_dyn[dyn_num][:,time_point] 569 | next_dyn = latent_dyn[dyn_num][:,time_point+1] 570 | total_next_dyn.extend(next_dyn.flatten().tolist()) 571 | f_x_mat = [] 572 | for f_num,f_i in enumerate(F): 573 | f_x_mat.append(f_i @ cur_dyn) 574 | if dyn_num == 0: 575 | stacked_fx = np.vstack(f_x_mat).T 576 | else: 577 | stacked_fx = np.vstack([stacked_fx, np.vstack(f_x_mat).T ]) 578 | stacked_fx[stacked_fx> 10**8] = 10**8 579 | 580 | total_next_dyn = np.reshape(np.array(total_next_dyn), (-1,1)) 581 | if len(F) == 1: stacked_fx = np.reshape(stacked_fx,[-1,1]) 582 | if params_update_c['smooth_term'] > 0 and time_point > 0 : 583 | if len(cofficients) == 0: 584 | warnings.warn("Warning: you called the smoothing option without defining coefficients") 585 | if params_update_c['smooth_term'] > 0 and time_point > 0 and len(cofficients) > 0 
: 586 | c_former = cofficients[:,time_point-1].reshape((-1,1)) 587 | total_next_dyn_full = np.hstack([total_next_dyn, np.sqrt(params_update_c['smooth_term'])*c_former]) 588 | stacked_fx_full = np.hstack([stacked_fx, np.sqrt(params_update_c['smooth_term'])*np.eye(len(stacked_fx))]) 589 | else: 590 | total_next_dyn_full = total_next_dyn 591 | stacked_fx_full = stacked_fx 592 | 593 | if params_update_c['update_c_type'] == 'inv' or (params_update_c['reg_term'] == 0 and params_update_c['smooth_term'] == 0): 594 | try: 595 | coeffs =linalg.pinv(stacked_fx_full) @ total_next_dyn_full.reshape((-1,1)) 596 | except: 597 | if not skip_error: 598 | raise NameError('A problem in taking the inverse of fx when looking for the model coefficients') 599 | else: 600 | return np.nan*np.ones((len(F), latent_dyn.shape[1])) 601 | elif params_update_c['update_c_type'] == 'lasso' : 602 | 603 | clf = linear_model.Lasso(alpha=params_update_c['reg_term'],random_state=random_state, **other_params) 604 | clf.fit(stacked_fx_full,total_next_dyn_full.T ) 605 | coeffs = np.array(clf.coef_) 606 | 607 | elif params_update_c['update_c_type'].lower() == 'fista' : 608 | Aop = pylops.MatrixMult(stacked_fx_full) 609 | #print('fista') 610 | if 'threshkind' not in params_update_c: params_update_c['threshkind'] ='soft' 611 | 612 | coeffs = pylops.optimization.sparsity.FISTA(Aop, total_next_dyn_full.flatten(), niter=params_update_c['num_iters'],eps = params_update_c['reg_term'] , threshkind = params_update_c.get('threshkind') )[0] 613 | 614 | elif params_update_c['update_c_type'].lower() == 'ista' : 615 | #print('ista') 616 | 617 | if 'threshkind' not in params_update_c: params_update_c['threshkind'] ='soft' 618 | Aop = pylops.MatrixMult(stacked_fx_full) 619 | coeffs = pylops.optimization.sparsity.ISTA(Aop, total_next_dyn_full.flatten(), niter=params_update_c['num_iters'] , 620 | eps = params_update_c['reg_term'],threshkind = params_update_c.get('threshkind'))[0] 621 | 622 | 623 | 624 | elif params_update_c['update_c_type'].lower() == 'omp' : 625 | #print('omp') 626 | Aop = pylops.MatrixMult(stacked_fx_full) 627 | coeffs = pylops.optimization.sparsity.OMP(Aop, total_next_dyn_full.flatten(), niter_outer=params_update_c['num_iters'], sigma=params_update_c['reg_term'])[0] 628 | 629 | 630 | elif params_update_c['update_c_type'].lower() == 'spgl1' : 631 | #print('spgl1') 632 | Aop = pylops.MatrixMult(stacked_fx_full) 633 | coeffs = pylops.optimization.sparsity.SPGL1(Aop, total_next_dyn_full.flatten(),iter_lim = params_update_c['num_iters'], 634 | tau = params_update_c['reg_term'])[0] 635 | 636 | 637 | elif params_update_c['update_c_type'].lower() == 'irls' : 638 | #print('irls') 639 | Aop = pylops.MatrixMult(stacked_fx_full) 640 | 641 | coeffs = pylops.optimization.sparsity.IRLS(Aop, total_next_dyn_full.flatten(), nouter=50, espI = params_update_c['reg_term'])[0] 642 | 643 | 644 | else: 645 | 646 | 647 | raise NameError('Unknown update c type') 648 | coeffs_list.append(coeffs.flatten()) 649 | coeffs_final = np.vstack(coeffs_list) 650 | 651 | return coeffs_final.T 652 | 653 | 654 | def create_next(latent_dyn, coefficients, F,time_point): 655 | """ 656 | This function evaluate the dynamics at t+1 given the value of the dynamics at time t, the sub-dynamics, and other model parameters 657 | Inputs: 658 | latent_dyn = the latent dynamics (can be either ground truth or estimated). 
[k X T] 659 | coefficients = the sub-dynamics coefficients (used by the model) 660 | F = a list of np.arrays, each np.array is a sub-dynamic with size kXk 661 | time_point = current time point 662 | order = how many time points in the future we want to estimate 663 | Outputs: 664 | k X 1 np.array describing the dynamics at time_point+1 665 | 666 | """ 667 | if isinstance(F[0],list): 668 | F = [np.array(f_i) for f_i in F] 669 | 670 | if latent_dyn.shape[1] > 1: 671 | cur_A = np.dstack([coefficients[i,time_point]*f_i @ latent_dyn[:, time_point] for i,f_i in enumerate(F)]).sum(2).T 672 | else: 673 | cur_A = np.dstack([coefficients[i,time_point]*f_i @ latent_dyn for i,f_i in enumerate(F)]).sum(2).T 674 | return cur_A 675 | 676 | def create_ci_fi_xt(latent_dyn,F,coefficients, cumulative = False, mute_infs = 10**50, 677 | max_inf = 10**60): 678 | 679 | """ 680 | An intermediate step for the reconstruction - 681 | Specifically - It calculated the error that should be taken in the GD step for updating f: 682 | f - eta * output_of(create_ci_fi_xt) 683 | output: 684 | 3d array of the gradient step (unweighted): [k X k X time] 685 | """ 686 | 687 | if max_inf <= mute_infs: 688 | raise ValueError('max_inf should be higher than mute-infs') 689 | curse_dynamics = latent_dyn 690 | 691 | all_grads = [] 692 | for time_point in np.arange(latent_dyn.shape[1]-1): 693 | if cumulative: 694 | if time_point > 0: 695 | previous_A = cur_A 696 | else: 697 | previous_A = curse_dynamics[:,0] 698 | cur_A = create_next(np.reshape(previous_A,[-1,1]), coefficients, F,time_point) 699 | else: 700 | cur_A = create_next(curse_dynamics, coefficients, F,time_point) 701 | next_A = latent_dyn[:,time_point+1] 702 | 703 | """ 704 | The actual step 705 | """ 706 | 707 | if cumulative: 708 | gradient_val = (next_A - cur_A) @ previous_A.T 709 | else: 710 | gradient_val = (next_A - cur_A) @ curse_dynamics[:, time_point].T 711 | all_grads.append(gradient_val) 712 | return np.dstack(all_grads) 713 | 714 | 715 | def update_f_all(latent_dyn,F,coefficients,step_f, normalize = False, acumulated_error = False, 716 | action_along_time = 'mean', weights_power = 1.2, normalize_eig = True): 717 | 718 | """ 719 | Update all the sub-dynamics {f_i} using GD 720 | """ 721 | 722 | if action_along_time == 'mean': 723 | 724 | all_grads = create_ci_fi_xt(latent_dyn,F,coefficients) 725 | new_f_s = [norm_mat(f_i-2*step_f*norm_mat(np.mean(all_grads[:,:,:]*np.reshape(coefficients[i,:], [1,1,-1]), 2),to_norm = normalize),to_norm = normalize_eig ) for i,f_i in enumerate(F)] 726 | elif action_along_time == 'median': 727 | all_grads = create_ci_fi_xt(latent_dyn,F,coefficients) 728 | 729 | new_f_s = [norm_mat(f_i-2*step_f*norm_mat(np.median(all_grads[:,:,:]*np.reshape(coefficients[i,:], [1,1,-1]), 2),to_norm = normalize),to_norm = normalize_eig ) for i,f_i in enumerate(F)] 730 | 731 | else: 732 | raise NameError('Unknown action along time. 
Should be mean or median') 733 | for f_num in range(len(new_f_s)): 734 | rand_mat = np.random.rand(new_f_s[f_num].shape[0],new_f_s[f_num].shape[1]) 735 | new_f_s[f_num][np.isnan(new_f_s[f_num])] = rand_mat[np.isnan(new_f_s[f_num])] .flatten() 736 | 737 | return new_f_s 738 | 739 | 740 | 741 | 742 | #%% Plotting functions 743 | def add_bar_dynamics(coefficients_n, ax_all_all = [],min_max_points = [10,100,200,300,400,500], 744 | colors = np.array(['r','g','b','yellow']), centralize = False): 745 | if isinstance(ax_all_all, list) and len(ax_all_all) == 0: 746 | fig, ax_all_all = plt.subplots(1,len(min_max_points), figsize = (8*len(min_max_points), 7)) 747 | 748 | max_bar = np.max(np.abs(coefficients_n[:,min_max_points])) 749 | for pair_num,val in enumerate(min_max_points): 750 | ax_all = ax_all_all[pair_num] 751 | 752 | 753 | ax_all.bar(np.arange(coefficients_n.shape[0]),coefficients_n[:,val], 754 | color = np.array(colors)[:coefficients_n.shape[0]], 755 | alpha = 0.3) 756 | # ax_all.set_title('t = %s'%str(val), fontsize = 40, fontweight = 'bold') 757 | ax_all.get_xaxis().set_ticks([]) #for ax in ax_all] 758 | ax_all.get_yaxis().set_ticks([]) #for ax in ax_all] 759 | ax_all.spines['top'].set_visible(False) 760 | 761 | ax_all.spines['right'].set_visible(False) 762 | ax_all.spines['bottom'].set_visible(False) 763 | ax_all.spines['left'].set_visible(False) 764 | ax_all.axhline(0, ls = '-',alpha = 0.5, color = 'black', lw = 6) 765 | ax_all.set_ylim([-max_bar,max_bar]) 766 | 767 | def plot_sub_effect(sub_dyn, rec_rad_all = 5, colors = ['r','g','b','m'], alpha = 0.8, ax = [], 768 | n_points = 100, figsize = (10,10), params_labels = {'title':'sub-dyn effect'}, lw = 4): 769 | params_labels = {**{'zlabel':None}, **params_labels} 770 | if isinstance(ax,list) and len(ax) == 0: 771 | fig, ax = plt.subplots(figsize = figsize) 772 | if len(colors) == 1: colors = [colors]*4 773 | if not isinstance(rec_rad_all,list): rec_rad_all = [rec_rad_all] 774 | ax.axhline(0, alpha = 0.1, color = 'black', ls = 'dotted') 775 | ax.axvline(0, alpha = 0.1, color = 'black', ls = 'dotted') 776 | for rec_rad in rec_rad_all: 777 | ax.plot([-rec_rad, rec_rad],[rec_rad,rec_rad],alpha = alpha**2, color = colors[0], ls ='--',lw=lw) 778 | ax.plot([-rec_rad, rec_rad],[-rec_rad,-rec_rad],alpha = alpha**2, color = colors[1], ls = '--',lw=lw) 779 | ax.plot([rec_rad, rec_rad],[-rec_rad,rec_rad],alpha = alpha**2, color = colors[2], ls = '--',lw=lw) 780 | ax.plot([-rec_rad,-rec_rad], [ -rec_rad, rec_rad],alpha = alpha**2, color = colors[3], ls = '--',lw=lw) 781 | 782 | 783 | if not (sub_dyn == 0).all(): 784 | sub_dyn = norm_mat(sub_dyn, type_norm = 'evals') 785 | effect_up = sub_dyn @ np.vstack([np.linspace(-rec_rad, rec_rad, n_points), [rec_rad]*n_points]) 786 | effect_down = sub_dyn @ np.vstack([np.linspace(-rec_rad, rec_rad, n_points), [-rec_rad]*n_points]) 787 | effect_right = sub_dyn @ np.vstack([[rec_rad]*n_points,np.linspace(-rec_rad, rec_rad, n_points)]) 788 | effect_left = sub_dyn @ np.vstack([[-rec_rad]*n_points,np.linspace(-rec_rad, rec_rad, n_points)]) 789 | ax.plot(effect_up[0,:],effect_up[1,:],alpha = alpha, color = colors[0],lw=lw) 790 | ax.plot(effect_down[0,:],effect_down[1,:],alpha = alpha, color = colors[1],lw=lw) 791 | ax.plot(effect_right[0,:],effect_right[1,:],alpha = alpha, color = colors[2],lw=lw) 792 | ax.plot(effect_left[0,:],effect_left[1,:],alpha = alpha, color = colors[3],lw=lw) 793 | # Up 794 | add_arrow(ax, [0,rec_rad], [np.mean(effect_up[0,:]),np.mean(effect_up[1,:])],arrowprops = {'facecolor' :colors[0]}) 
795 | add_arrow(ax, [0,-rec_rad], [np.mean(effect_down[0,:]),np.mean(effect_down[1,:])],arrowprops = {'facecolor' :colors[1]}) 796 | add_arrow(ax, [rec_rad,0], [np.mean(effect_right[0,:]),np.mean(effect_right[1,:])],arrowprops = {'facecolor' :colors[2]}) 797 | add_arrow(ax, [-rec_rad,0], [np.mean(effect_left[0,:]),np.mean(effect_left[1,:])],arrowprops = {'facecolor' :colors[3]}) 798 | add_labels(ax, **params_labels) 799 | 800 | 801 | def add_dummy_sub_legend(ax, colors,lenf, label_base = 'f'): 802 | dummy_lines = [] 803 | for i,color in enumerate(colors[:lenf]): 804 | dummy_lines.append(ax.plot([],[],c = color, label = '%s %s'%(label_base, str(i)))[0]) 805 | ax.set_title('Dynamics colored by mix of colors of the dominant dynamics') 806 | legend = ax.legend([dummy_lines[i] for i in range(len(dummy_lines))], ['f %s'%str(i) for i in range(len(colors))], loc = 'upper left') 807 | ax.legend() 808 | 809 | def plot_subs_effects_2d(F, colors =[['r','maroon','darkred','coral'],['forestgreen','limegreen','darkgreen','springgreen']] , alpha = 0.7 , rec_rad_all = 5, 810 | n_points = 100, params_labels = {'title':'sub-dyn effect'}, lw = 4, evec_colors = ['r','g'], include_dyn = False, loc_leg = 'upper left', 811 | axs = [], fig = []): 812 | if include_dyn: 813 | fig, axs = plt.subplots(len(F), 3, figsize = (35,8*len(F)),sharey='col', sharex = 'col') 814 | else: 815 | if isinstance(axs,list) and len(axs) == 0: 816 | fig, axs = plt.subplots(len(F), 2, figsize = (30,8*len(F)),sharey='col', sharex = 'col') 817 | 818 | if isinstance(colors[0], list): 819 | [plot_sub_effect(f_i, rec_rad_all , colors[i] , alpha, axs[i,1], n_points, params_labels = {'title':'f %s effect'%str(i+1)}, lw = lw) for i,f_i in enumerate(F)] 820 | 821 | else: 822 | [plot_sub_effect(f_i, rec_rad_all , colors , alpha, axs[i,1], n_points, params_labels = {'title':'f %s effect'%str(i+1)}, lw = lw) for i,f_i in enumerate(F)] 823 | 824 | [plot_evals_evecs(axs[i,0], f_i, evec_colors[i] , alpha) for i,f_i in enumerate(F)] 825 | dummy_lines = [] 826 | dummy_lines.append(axs[0,0].plot([],[], c="black", ls = '--', lw = lw)[0]) 827 | dummy_lines.append(axs[0,0].plot([],[], c="black", ls = '-', lw = lw)[0]) 828 | 829 | legend = axs[0,1].legend([dummy_lines[i] for i in [0,1]], ['Original', 'after sub-dynamic transform'], loc = loc_leg ) 830 | axs[0,1].add_artist(legend) 831 | if include_dyn: 832 | [quiver_plot(sub_dyn = f, ax = axs[i,2], chosen_color = evec_colors[i], type_plot='streamplot',cons_color =True, xlabel = 'dv',ylabel = 'dw') for i,f in enumerate(F)] 833 | [axs[i,2].set_title('f %s'%str(i+1), fontsize = 18) for i in range(len(F))] 834 | 835 | fig.subplots_adjust(wspace = 0.4, hspace = 0.4) 836 | 837 | 838 | def plot_subs(F, axs = [],params_F_plot = {'cmap':'PiYG'}, include_sup = True,annot = True): 839 | """ 840 | This function plots heatmaps of the sub-dynamics 841 | """ 842 | params_F_plot = {**{'cmap':'PiYG'},**params_F_plot} 843 | if isinstance(axs,list): 844 | if len(axs) == 0: 845 | fig, axs = plt.subplots(1,len(F), sharex = True,sharey = True) 846 | 847 | [sns.heatmap(f_i, ax = axs[i],annot=annot, **params_F_plot) for i,f_i in enumerate(F)] 848 | [ax.set_title('f#%g'%i) for i,ax in enumerate(axs)] 849 | if include_sup: plt.suptitle('Sub-Dynamics') 850 | plt.subplots_adjust(hspace = 0.5,wspace = 0.5) 851 | 852 | 853 | def plot_evals_evecs(ax, sub_dyn, colors =['r','g','b','m'] , alpha = 0.7, title ='evals'): 854 | eigenvalues, eigenvectors = linalg.eig(sub_dyn) 855 | for eval_num, eigenval in enumerate(eigenvalues): 856 | 
#add_arrow(ax, [0,0], [eigenvectors[0,eval_num], eigenvectors[1,eval_num]], arrowprops = {'facecolor' :colors[eval_num]}) 857 | ax.scatter( np.real(eigenval),np.imag(eigenval), alpha = alpha, color = colors, s = 300) #[eval_num]) 858 | ax.set_xlabel('Real') 859 | ax.set_ylabel('Imag') 860 | ax.axhline(0, alpha = 0.1, color = 'black', ls = 'dotted') 861 | ax.axvline(0, alpha = 0.1, color = 'black', ls = 'dotted') 862 | ax.set_title('evals') 863 | 864 | 865 | def plot_3d_color_scatter(latent_dyn,coefficients, ax = [], figsize = (15,10), delta = 0.4, colors = []): 866 | 867 | if latent_dyn.shape[0] != 3: 868 | print('Dynamics is not 3d') 869 | pass 870 | else: 871 | if len(colors) == 0: 872 | colors = ['r','g','b'] 873 | if isinstance(ax,list) and len(ax) == 0: 874 | fig, ax = plt.subplots(figsize = figsize, subplot_kw={'projection':'3d'}) 875 | for row in range(coefficients.shape[0]): 876 | coefficients_row = coefficients[row] 877 | coefficients_row[coefficients_row == 0] = 0.01 878 | 879 | ax.scatter(latent_dyn[0,:]+delta*row,latent_dyn[1,:]+delta*row,latent_dyn[2,:]+delta, s = coefficients_row**0.3, c = colors[row]) 880 | ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 881 | ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 882 | ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 883 | ax.grid(False) 884 | 885 | 886 | def add_arrow(ax, start, end,arrowprops = {'facecolor' : 'black', 'width':1, 'alpha' :0.2} ): 887 | arrowprops = {**{'facecolor' : 'black', 'width':1.5, 'alpha' :0.2, 'edgecolor':'none'}, **arrowprops} 888 | ax.annotate('',ha = 'center', va = 'bottom', xytext = start,xy =end, 889 | arrowprops = arrowprops) 890 | 891 | 892 | def rgb_to_hex(rgb_vec): 893 | r = rgb_vec[0]; g = rgb_vec[1]; b = rgb_vec[2] 894 | return rgb2hex(int(255*r), int(255*g), int(255*b)) 895 | 896 | def remove_edges(ax): 897 | ax.spines['top'].set_visible(False) 898 | ax.spines['right'].set_visible(False) 899 | ax.spines['bottom'].set_visible(False) 900 | ax.spines['left'].set_visible(False) 901 | ax.get_xaxis().set_ticks([]) 902 | ax.get_yaxis().set_ticks([]) 903 | 904 | def quiver_plot(sub_dyn = [], xmin = -5, xmax = 5, ymin = -5, ymax = 5, ax = [], chosen_color = 'red', 905 | alpha = 0.4, w = 0.02, type_plot = 'quiver', zmin = -5, zmax = 5, cons_color = False, 906 | return_artist = False,xlabel = 'x',ylabel = 'y',quiver_3d = False,inter=2): 907 | """ 908 | type_plot - can be either quiver or streamplot 909 | """ 910 | 911 | if len(sub_dyn) == 0: 912 | sub_dyn = np.array([[0,-1],[1,0]]) 913 | 914 | 915 | if ymin >= ymax: 916 | raise ValueError('ymin should be < ymax') 917 | elif xmin >=xmax: 918 | raise ValueError('xmin should be < xmax') 919 | else: 920 | 921 | if not quiver_3d: 922 | if isinstance(ax,list) and len(ax) == 0: 923 | fig, ax = plt.subplots() 924 | X, Y = np.meshgrid(np.arange(xmin, xmax), np.arange(ymin,ymax)) 925 | 926 | new_mat = sub_dyn - np.eye(len(sub_dyn)) 927 | 928 | U = new_mat[0,:] @ np.vstack([X.flatten(), Y.flatten()]) 929 | V = new_mat[1,:] @ np.vstack([X.flatten(), Y.flatten()]) 930 | 931 | if type_plot == 'quiver': 932 | h = ax.quiver(X,Y,U,V, color = chosen_color, alpha = alpha, width = w) 933 | elif type_plot == 'streamplot': 934 | 935 | 936 | x = np.linspace(xmin,xmax,100) 937 | y = np.linspace(ymin,ymax,100) 938 | X, Y = np.meshgrid(x, y) 939 | new_mat = sub_dyn - np.eye(len(sub_dyn)) 940 | U = new_mat[0,:] @ np.vstack([X.flatten(), Y.flatten()]) 941 | V = new_mat[1,:] @ np.vstack([X.flatten(), Y.flatten()]) 942 | 943 | 944 | if cons_color: 945 | 946 | if len(chosen_color[:]) == 3 
and isinstance(chosen_color, (list,np.ndarray)): 947 | color_stream = rgb_to_hex(chosen_color) 948 | elif isinstance(chosen_color, str) and chosen_color[0] != '#': 949 | color_stream = list(name_to_rgb(chosen_color)) 950 | else: 951 | color_stream = chosen_color 952 | 953 | else: 954 | new_mat_color = np.abs(new_mat @ np.vstack([x.flatten(), y.flatten()])) 955 | color_stream = new_mat_color.T @ new_mat_color 956 | try: 957 | h = ax.streamplot(np.linspace(xmin,xmax,100),np.linspace(ymin,ymax,100),U.reshape(X.shape),V.reshape(Y.shape), color = color_stream) #chosen_color 958 | except: 959 | h = ax.streamplot(np.linspace(xmin,xmax,100),np.linspace(ymin,ymax,100),U.reshape(X.shape),V.reshape(Y.shape), color = chosen_color) #chosen_color 960 | else: 961 | raise NameError('Wrong plot name') 962 | else: 963 | if isinstance(ax,list) and len(ax) == 0: 964 | fig, ax = plt.subplots(subplot_kw={'projection':'3d'}) 965 | X, Y , Z = np.meshgrid(np.arange(xmin, xmax,inter), np.arange(ymin,ymax,inter), np.arange(zmin,zmax,inter)) 966 | 967 | new_mat = sub_dyn - np.eye(len(sub_dyn)) 968 | U = np.zeros(X.shape); V = np.zeros(X.shape); W = np.zeros(X.shape); 969 | 970 | for xloc in np.arange(X.shape[0]): 971 | for yloc in np.arange(X.shape[1]): 972 | for zloc in np.arange(X.shape[2]): 973 | U[xloc,yloc,zloc] = new_mat[0,:] @ np.array([X[xloc,yloc,zloc] ,Y[xloc,yloc,zloc] ,Z[xloc,yloc,zloc] ]).reshape((-1,1)) 974 | V[xloc,yloc,zloc] = new_mat[1,:] @ np.array([X[xloc,yloc,zloc] ,Y[xloc,yloc,zloc] ,Z[xloc,yloc,zloc] ]).reshape((-1,1)) 975 | W[xloc,yloc,zloc] = new_mat[2,:] @ np.array([X[xloc,yloc,zloc] ,Y[xloc,yloc,zloc] ,Z[xloc,yloc,zloc] ]).reshape((-1,1)) 976 | 977 | if type_plot == 'quiver': 978 | h = ax.quiver(X,Y,Z,U,V,W, color = chosen_color, alpha = alpha,lw = 1.5, length=0.8, normalize=True,arrow_length_ratio=0.5)#, width = w 979 | ax.grid(False) 980 | elif type_plot == 'streamplot': 981 | raise NameError('streamplot is not accepted for the 3d case') 982 | 983 | else: 984 | raise NameError('Wront plot name') 985 | if quiver_3d: zlabel ='z' 986 | else: zlabel = None 987 | 988 | add_labels(ax, zlabel = zlabel, xlabel = xlabel, ylabel = ylabel) 989 | if return_artist: return h 990 | 991 | 992 | def visualize_dyn(dyn,ax = [], params_plot = {},turn_off_back = False, marker_size = 10, include_line = False, 993 | color_sig = [],cmap = 'cool', return_fig = False, color_by_dominant = False, coefficients =[], 994 | figsize = (5,5),colorbar = False, colors = [],vmin = None,vmax = None, color_mix = False, alpha = 0.4, 995 | colors_dyns = np.array(['r','g','b','yellow']), add_text = 't ', text_points = [],fontsize_times = 18, 996 | marker = "o",delta_text = 0.5, color_for_0 =None, legend = [],fig = [],return_mappable = False): 997 | """ 998 | Plot the multi-dimensional dynamics 999 | Inputs: 1000 | dyn = dynamics to plot. Should be a np.array with size k X T 1001 | ax = the subplot to plot in (optional) 1002 | params_plot = additional parameters for the plotting (optional). Can include plotting-related keys like xlabel, ylabel, title, etc. 1003 | turn_off_back= disable backgroud of the plot? (optional). Boolean 1004 | marker_size = marker size of the plot (optional). Integer 1005 | include_line = add a curve to the plot (in addition to the scatter plot). Boolean 1006 | color_sig = the color signal. if empty and color_by_dominant - color by the dominant dynamics. If empty and not color_by_dominant - color by time. 1007 | cmap = cmap 1008 | colors = if not empty -> pre-defined colors for the different sub-dynamics. 
Otherwise - colors are according to the cmap. 1009 | color_mix = relevant only if color_by_dominant. In this case the colors need to be in the form of [r,g,b] 1010 | Output: 1011 | (only if return_fig) -> returns the figure 1012 | 1013 | """ 1014 | if not isinstance(color_sig,list) and not isinstance(color_sig,np.ndarray): color_sig = [color_sig] 1015 | 1016 | 1017 | if isinstance(ax,list) and len(ax) == 0: 1018 | if dyn.shape[0] == 3: 1019 | fig, ax = plt.subplots(figsize = figsize, subplot_kw={'projection':'3d'}) 1020 | else: 1021 | fig, ax = plt.subplots(figsize = figsize) 1022 | 1023 | 1024 | 1025 | if include_line: 1026 | if dyn.shape[0] == 3: 1027 | ax.plot(dyn[0,:], dyn[1,:], dyn[2,:],alpha = 0.2) 1028 | else: 1029 | ax.plot(dyn[0,:], dyn[1,:], alpha = 0.2) 1030 | if len(legend) > 0: 1031 | [ax.scatter([],[], c = colors_dyns[i], label = legend[i], s = 10) for i in np.arange(len(legend))] 1032 | ax.legend() 1033 | # Create color sig 1034 | if len(color_sig) == 0: 1035 | color_sig = np.arange(dyn.shape[1]) 1036 | if color_by_dominant and (coefficients.shape[1] == dyn.shape[1]-1 or coefficients.shape[1] == dyn.shape[1]): 1037 | if color_mix: 1038 | if len(colors) == 0 or not np.shape(colors)[0] == 3: raise ValueError('colors mat should have 3 rows') 1039 | else: 1040 | 1041 | color_sig = ((np.array(colors)[:,:coefficients.shape[0]] @ np.abs(coefficients)) / np.max(np.abs(coefficients).sum(0).reshape((1,-1)))).T 1042 | color_sig[np.isnan(color_sig) ] = 0.1 1043 | dyn = dyn[:,:-1] 1044 | else: 1045 | 1046 | color_sig_tmp = find_dominant_dyn(coefficients) 1047 | if len(colors_dyns) > 0: 1048 | color_sig = colors_dyns[color_sig_tmp] 1049 | elif len(color_sig) == 0: 1050 | color_sig=color_sig_tmp 1051 | else: 1052 | color_sig=np.array(color_sig)[color_sig_tmp] 1053 | if len(color_sig.flatten()) < dyn.shape[1]: dyn = dyn[:,:len(color_sig.flatten())] 1054 | if color_for_0: 1055 | 1056 | color_sig[np.sum(coefficients,0) == 0] = color_for_0 1057 | 1058 | 1059 | if dyn.shape[0] > 2: 1060 | if len(colors) == 0: 1061 | h = ax.scatter(dyn[0,:], dyn[1,:], dyn[2,:], marker = marker, s = marker_size,c= color_sig,cmap = cmap, alpha = alpha, 1062 | vmin = vmin, vmax = vmax) 1063 | else: 1064 | h = ax.scatter(dyn[0,:], dyn[1,:], dyn[2,:], marker =marker, s = marker_size,c= color_sig, alpha = alpha) 1065 | else: 1066 | dyn = np.array(dyn) 1067 | 1068 | if len(colors) == 0: 1069 | h = ax.scatter(dyn[0,:], dyn[1,:], marker = marker, s = marker_size,c= color_sig,cmap = cmap, alpha = alpha, 1070 | vmin = vmin, vmax = vmax) 1071 | else: 1072 | h = ax.scatter(dyn[0,:], dyn[1,:], marker = marker, s = marker_size,c= color_sig, alpha = alpha) 1073 | 1074 | params_plot['zlabel'] = None 1075 | if len(params_plot) > 0: 1076 | if dyn.shape[0] == 3: 1077 | if 'xlabel' in params_plot.keys(): 1078 | add_labels(ax, xlabel=params_plot.get('xlabel'), ylabel=params_plot.get('ylabel'), zlabel=params_plot.get('zlabel'), title=params_plot.get('title'), 1079 | xlim = params_plot.get('xlim'), ylim =params_plot.get('ylim'), zlim =params_plot.get('zlim')) 1080 | elif 'zlabel' in params_plot.keys(): 1081 | add_labels(ax, zlabel=params_plot.get('zlabel'), title=params_plot.get('title'), 1082 | xlim = params_plot.get('xlim'), ylim =params_plot.get('ylim'), zlim =params_plot.get('zlim')) 1083 | else: 1084 | add_labels(ax, title=params_plot.get('title'), 1085 | xlim = params_plot.get('xlim'), ylim =params_plot.get('ylim'), zlim =params_plot.get('zlim')) 1086 | else: 1087 | if 'xlabel' in params_plot.keys(): 1088 | add_labels(ax, 
xlabel=params_plot.get('xlabel'), ylabel=params_plot.get('ylabel'), zlabel=None, title=params_plot.get('title'), 1089 | xlim = params_plot.get('xlim'), ylim =params_plot.get('ylim'), zlim =None) 1090 | elif 'zlabel' in params_plot.keys(): 1091 | add_labels(ax, zlabel=None, title=params_plot.get('title'), 1092 | xlim = params_plot.get('xlim'), ylim =params_plot.get('ylim'), zlim =None) 1093 | else: 1094 | add_labels(ax, title=params_plot.get('title'), 1095 | xlim = params_plot.get('xlim'), ylim =params_plot.get('ylim'), zlim =None,zlabel = None); 1096 | if len(text_points) > 0: 1097 | 1098 | if dyn.shape[0] == 3: 1099 | [ax.text(dyn[0,t]+delta_text,dyn[1,t]+delta_text,dyn[2,t]+delta_text, '%s = %s'%(add_text, str(t)), fontsize =fontsize_times, fontweight = 'bold') for t in text_points] 1100 | else: 1101 | [ax.text(dyn[0,t]+delta_text,dyn[1,t]+delta_text, '%s = %s'%(add_text, str(t)), fontsize =fontsize_times, fontweight = 'bold') for t in text_points] 1102 | 1103 | remove_edges(ax) 1104 | ax.set_axis_off() 1105 | if colorbar: 1106 | fig.colorbar(h, cax=ax, position = 'top') 1107 | if return_mappable: 1108 | return h 1109 | 1110 | 1111 | 1112 | #%% Helper Functions and Post-Analysis Functions 1113 | def str2bool(str_to_change): 1114 | """ 1115 | Transform 'true' or 'yes' to True boolean variable 1116 | Example: 1117 | str2bool('true') - > True 1118 | """ 1119 | if isinstance(str_to_change, str): 1120 | str_to_change = (str_to_change.lower() == 'true') or (str_to_change.lower() == 'yes') 1121 | return str_to_change 1122 | 1123 | def norm_over_time(coefficients, type_norm = 'normal'): 1124 | if type_norm == 'normal': 1125 | coefficients_norm = (coefficients - np.mean(coefficients,1).reshape((-1,1)))/np.std(coefficients, 1).reshape((-1,1)) 1126 | return coefficients_norm 1127 | 1128 | def norm_coeffs(coefficients, type_norm, same_width = True,width_des = 0.7,factor_power = 0.9, min_width = 0.01): 1129 | """ 1130 | type_norm can be: 'sum_abs', 'norm','abs' 1131 | """ 1132 | if type_norm == 'norm': 1133 | coefficients_n = norm_over_time(np.abs(coefficients), type_norm = 'normal') 1134 | coefficients_n = coefficients_n - np.min(coefficients_n,1).reshape((-1,1)) 1135 | 1136 | elif type_norm == 'sum_abs': 1137 | coefficients[np.abs(coefficients) < min_width] = min_width 1138 | coefficients_n = np.abs(coefficients) / np.sum(np.abs(coefficients),1).reshape((-1,1)) 1139 | elif type_norm == 'abs': 1140 | coefficients[np.abs(coefficients) < min_width] = min_width 1141 | coefficients_n = np.abs(coefficients) 1142 | elif type_norm == 'no_norm': 1143 | coefficients_n = coefficients 1144 | else: 1145 | raise NameError('Invalid type_norm value') 1146 | 1147 | 1148 | coefficients_n[coefficients_n < min_width] = min_width 1149 | if same_width: coefficients_n = width_des*(np.abs(coefficients_n)**factor_power) / np.sum(np.abs(coefficients_n)**factor_power,axis = 0) 1150 | else: coefficients_n = np.abs(coefficients_n) / np.sum(np.abs(coefficients_n),axis = 0) 1151 | coefficients_n[coefficients_n < min_width] = min_width 1152 | return coefficients_n 1153 | 1154 | 1155 | def movmfunc(func, mat, window = 3, direction = 0): 1156 | """ 1157 | moving window with applying the function func on the matrix 'mat' towrads the direction 'direction' 1158 | """ 1159 | if len(mat.shape) == 1: 1160 | mat = mat.reshape((-1,1)) 1161 | direction = 0 1162 | addition = int(np.ceil((window-1)/2)) 1163 | if direction == 0: 1164 | mat_wrap = np.vstack([np.nan*np.ones((addition,np.shape(mat)[1])), mat, 
np.nan*np.ones((addition,np.shape(mat)[1]))]) 1165 | movefunc_res = np.vstack([func(mat_wrap[i-addition:i+addition,:],axis = direction) for i in range(addition, np.shape(mat_wrap)[0]-addition)]) 1166 | elif direction == 1: 1167 | mat_wrap = np.hstack([np.nan*np.ones((np.shape(mat)[0],addition)), mat, np.nan*np.ones((np.shape(mat)[0],addition))]) 1168 | movefunc_res = np.vstack([func(mat_wrap[:,i-addition:i+addition],axis = direction) for i in range(addition, np.shape(mat_wrap)[1]-addition)]).T 1169 | return movefunc_res 1170 | 1171 | def create_reco(latent_dyn,coefficients, F, type_find = 'median',min_far =10, smooth_coeffs = False, 1172 | smoothing_params = {'wind':5}): 1173 | """ 1174 | This function creates the dLDS reconstruction of the latent dynamics 1175 | Inputs: 1176 | latent_dyn = the ground truth latent dynamics 1177 | coefficients = the operators coefficients (c(t)_i) 1178 | F = a list of transport operators (a list with M transport operators, 1179 | each is a square matrix, kXk, where k is the latent dynamics 1180 | dimension ) 1181 | type_find = legacy parameter (not used in the current implementation) 1182 | min_far = legacy parameter (not used in the current implementation) 1183 | smooth_coeffs = whether to median-smooth the coefficients over time before reconstructing 1184 | smoothing_params = smoothing options, e.g. {'wind':5} (window length) 1185 | 1186 | Outputs: 1187 | cur_reco = dLDS reconstruction of the latent dynamics 1188 | 1189 | """ 1190 | if smooth_coeffs: 1191 | coefficients = movmfunc(np.nanmedian, coefficients, window = smoothing_params['wind'], direction = 1) 1192 | 1193 | 1194 | cur_reco = np.hstack([create_next(latent_dyn, coefficients, F,time_point) for time_point in range(latent_dyn.shape[1]-1)]) 1195 | cur_reco = np.hstack([latent_dyn[:,0].reshape((-1,1)),cur_reco]) 1196 | 1197 | 1198 | return cur_reco 1199 | 1200 | 
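# A minimal usage sketch for the reconstruction above (illustrative only; the names and shapes below are
# assumptions for the example, not values computed elsewhere in this file). Given M sub-dynamics
# F = [f_1, ..., f_M] (each k x k), coefficients c of shape (M, T-1), and a latent trajectory x of shape (k, T),
# create_next / create_reco implement the discrete dLDS one-step rule
#     x_hat[:, t+1] = sum_i c[i, t] * F[i] @ x[:, t],
# with the true initial point x[:, 0] prepended, so a reconstruction and its relative error can be obtained as:
#     x_hat = create_reco(latent_dyn = x, coefficients = c, F = F)
#     rel_err = np.linalg.norm(x_hat - x) / np.linalg.norm(x)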
1201 | def add_labels(ax, xlabel='X', ylabel='Y', zlabel='Z', title='', xlim = None, ylim = None, zlim = None,xticklabels = np.array([None]), 1202 | yticklabels = np.array([None] ), xticks = [], yticks = [], legend = [], ylabel_params = {},zlabel_params = {}, xlabel_params = {}, title_params = {}): 1203 | """ 1204 | This function adds labels, titles, limits, etc. to figures; 1205 | Inputs: 1206 | ax = the subplot to edit 1207 | xlabel = xlabel 1208 | ylabel = ylabel 1209 | zlabel = zlabel (if the figure is 2d please define zlabel = None) 1210 | etc. 1211 | """ 1212 | if xlabel != '' and xlabel != None: ax.set_xlabel(xlabel, **xlabel_params) 1213 | if ylabel != '' and ylabel != None:ax.set_ylabel(ylabel, **ylabel_params) 1214 | if zlabel != '' and zlabel != None:ax.set_zlabel(zlabel,**zlabel_params) 1215 | if title != '' and title != None: ax.set_title(title, **title_params) 1216 | if xlim != None: ax.set_xlim(xlim) 1217 | if ylim != None: ax.set_ylim(ylim) 1218 | if zlim != None: ax.set_zlim(zlim) 1219 | 1220 | if (np.array(xticklabels) != None).any(): 1221 | if len(xticks) == 0: xticks = np.arange(len(xticklabels)) 1222 | ax.set_xticks(xticks); 1223 | ax.set_xticklabels(xticklabels); 1224 | if (np.array(yticklabels) != None).any(): 1225 | if len(yticks) == 0: yticks = np.arange(len(yticklabels)) +0.5 1226 | ax.set_yticks(yticks); 1227 | ax.set_yticklabels(yticklabels); 1228 | if len(legend) > 0: ax.legend(legend) 1229 | 1230 | 1231 | 1232 | def find_dominant_dyn(coefficients): 1233 | """ 1234 | This function finds the index of the most dominant sub-dynamic in each time point. It should be used when comparing to rSLDS 1235 | Input: 1236 | coefficients: np.array of size M X T (M = number of sub-dynamics) 1237 | Output: 1238 | an array with len T, containing the index of the most dominant sub-dynamic at each time point 1239 | """ 1240 | domi = np.argmax(np.abs(coefficients),0) 1241 | return domi 1242 | 1243 | 1244 | 1245 | 1246 | #%% Saving 1247 | 1248 | def check_save_name(save_name, invalid_signs = '!@#$%^&*.,:;', addi_path = [], sep=sep) : 1249 | """ 1250 | Replace invalid characters in the save name with underscores (and optionally prepend a path) 1251 | """ 1252 | for invalid_sign in invalid_signs: save_name = save_name.replace(invalid_sign,'_') 1253 | if len(addi_path) == 0: return save_name 1254 | else: 1255 | path_name = sep.join(addi_path) 1256 | return path_name +sep + save_name 1257 | 1258 | def save_file_dynamics(save_name, folders_names,to_save =[], invalid_signs = '!@#$%^&*.,:;', sep = sep , type_save = '.npy'): 1259 | """ 1260 | Save dynamics & model results 1261 | """ 1262 | save_name = check_save_name(save_name, invalid_signs) 1263 | path_name = sep.join(folders_names) 1264 | if not os.path.exists(path_name): 1265 | os.makedirs(path_name) 1266 | if type_save == '.npy': 1267 | if not save_name.endswith('.npy'): save_name = save_name + '.npy' 1268 | np.save(path_name +sep + save_name, to_save) 1269 | elif type_save == '.pkl': 1270 | if not save_name.endswith('.pkl'): save_name = save_name + '.pkl' 1271 | dill.dump_session(path_name +sep + save_name) 1272 | 1273 | def saveLoad(opt,filename): # save or load the module-level variable 'calc' via pickle 1274 | global calc 1275 | if opt == "save": 1276 | f = open(filename, 'wb') 1277 | pickle.dump(calc, f, 2) 1278 | f.close() 1279 | 1280 | elif opt == "load": 1281 | f = open(filename, 'rb') 1282 | calc = pickle.load(f) 1283 | else: 1284 | print('Invalid saveLoad option') 1285 | 1286 | def load_vars(folders_names , save_name ,sep=sep , ending = '.pkl',full_name = False): 1287 | """ 1288 | Load results previously saved 1289 | Example: 1290 | load_vars('' , 'save_c.pkl' ,sep=sep , ending = '.pkl',full_name = False) 1291 | """ 1292 | if full_name: 1293 | dill.load_session(save_name) 1294 | else: 1295 | if len(folders_names) > 0: path_name = sep.join(folders_names) 1296 | else: path_name = '' 1297 | 1298 | if not save_name.endswith(ending): save_name = '%s%s'%(save_name , ending) 1299 | dill.load_session(path_name +sep +save_name) 1300 | 1301 | 1302 | 1303 | 1304 | def create_colors(len_colors, perm = [0,1,2]): 1305 | """ 1306 | Create a set of discrete colors with a one-directional order 1307 | Input: 1308 | len_colors = number of different colors needed 1309 | Output: 1310 | 3 X len_colors matrix depicting the colors in the cols 1311 | """ 1312 | colors = np.vstack([np.linspace(0,1,len_colors),(1-np.linspace(0,1,len_colors))**2,1-np.linspace(0,1,len_colors)]) 1313 | colors = colors[perm, :] 1314 | return colors 1315 | 1316 | 1317 | 1318 | #%% Plot tricolor curve for Lorenz 1319 | 1320 | def min_dist(dotA1, dotA2, dotB1, dotB2, num_sects = 500): 1321 | x_lin = np.linspace(dotA1[0], dotA2[0]) 1322 | y_lin = np.linspace(dotA1[1], dotA2[1]) 1323 | x_lin_or = np.linspace(dotB1[0], dotB2[0]) 1324 | y_lin_or = np.linspace(dotB1[1], dotB2[1]) 1325 | dist_list = [] 1326 | for pairA_num, pairAx in enumerate(x_lin): 1327 | pairAy = y_lin[pairA_num] 1328 | for pairB_num, pairBx in enumerate(x_lin_or): 1329 | pairBy = y_lin_or[pairB_num] 1330 | dist = (pairAx - pairBx)**2 + (pairAy - pairBy)**2 1331 | dist_list.append(dist) 1332 | return dist_list 1333 | 1334 | 1335 | def find_perpendicular(d1, d2, perp_length = 1, prev_v = [], next_v = [], ref_point = [],choose_meth = 
'intersection',initial_point = 'mid', 1336 | direction_initial = 'low', return_unchose = False, layer_num = 0): 1337 | """ 1338 | This function find the 2 point of the orthogonal vector to a vector defined by points d1,d2 1339 | d1 = first data point 1340 | d2 = second data point 1341 | perp_length = desired width 1342 | prev_v = previous value of v. Needed only if choose_meth == 'prev' 1343 | next_v = next value of v. Needed only if choose_meth == 'prev' 1344 | ref_point = reference point for the 'smooth' case, or for 2nd+ layers 1345 | choose_meth = 'intersection' (eliminate intersections) OR 'smooth' (smoothing with previous prediction) OR 'prev' (eliminate convexity) 1346 | direction_initial = to which direction take the first perp point 1347 | return_unchose = whether to return unchosen directions 1348 | 1349 | """ 1350 | # Check Input 1351 | if d2[0] == d1[0] and d2[1] == d1[1]: 1352 | raise ValueError('d1 and d2 are the same point') 1353 | 1354 | # Define start point for un-perp curve 1355 | if initial_point == 'mid': 1356 | perp_begin = (np.array(d1) + np.array(d2))/2 1357 | d1_perp = perp_begin 1358 | elif initial_point == 'end': d1_perp = d2 1359 | elif initial_point == 'start': d1_perp = d1 1360 | else: raise NameError('Unknown intial point') 1361 | 1362 | # If perpendicular direction is according to 'intersection' elimination 1363 | if choose_meth == 'intersection': 1364 | if len(prev_v) > 0: intersected_curve1 = prev_v 1365 | else: intersected_curve1 = d1 1366 | if len(next_v) > 0: intersected_curve2 = next_v 1367 | else: intersected_curve2 = d2 1368 | 1369 | # If a horizontal line 1370 | if d2[0] == d1[0]: d2_perp = np.array([d1_perp[0]+perp_length, d1_perp[1]]) 1371 | # If a vertical line 1372 | elif d2[1] == d1[1]: d2_perp = np.array([d1_perp[0], d1_perp[1]+perp_length]) 1373 | else: 1374 | m = (d2[1]-d1[1])/(d2[0]-d1[0]) 1375 | m_per = -1/m # Slope of perp curve 1376 | theta1 = np.arctan(m_per) 1377 | theta2 = theta1 + np.pi 1378 | 1379 | # if smoothing 1380 | if choose_meth == 'smooth' or choose_meth == 'intersection': 1381 | if len(ref_point) == 0: 1382 | smooth_val =[] 1383 | else: smooth_val = np.array(ref_point) 1384 | 1385 | # if by convexity 1386 | if choose_meth == 'prev': 1387 | if len(prev_v) > 0 and len(next_v) > 0: # both sides are provided 1388 | prev_mid_or = (np.array(prev_v) + np.array(next_v))/2 1389 | elif len(prev_v) > 0 and len(next_v) == 0: # only the previous side is provided 1390 | prev_mid_or = (np.array(prev_v) + np.array(d2))/2 1391 | elif len(next_v) > 0 and len(prev_v) == 0: # only the next side is provided 1392 | prev_mid_or = (np.array(d1) + np.array(next_v))/2 1393 | else: 1394 | raise ValueError('prev or next should be defined (to detect convexity)!') 1395 | 1396 | if choose_meth == 'prev': 1397 | prev_mid = prev_mid_or 1398 | elif choose_meth == 'smooth': 1399 | prev_mid = smooth_val 1400 | elif choose_meth == 'intersection': 1401 | prev_mid = smooth_val 1402 | 1403 | x_shift = perp_length * np.cos(theta1) 1404 | y_shift = perp_length * np.sin(theta1) 1405 | d2_perp1 = np.array([d1_perp[0] + x_shift, d1_perp[1]+ y_shift]) 1406 | 1407 | x_shift2 = perp_length * np.cos(theta2) 1408 | y_shift2 = perp_length * np.sin(theta2) 1409 | d2_perp2 = np.array([d1_perp[0] + x_shift2, d1_perp[1]+ y_shift2]) 1410 | options_last = [d2_perp1, d2_perp2] 1411 | 1412 | # Choose the option that goes outside 1413 | if len(prev_mid) > 0: 1414 | 1415 | 1416 | if len(ref_point) > 0 and layer_num > 0: # here ref point is a point of a different dynamics layer from 
which we want to take distance 1417 | dist1 = np.sum((smooth_val - d2_perp1)**2) 1418 | dist2 = np.sum((smooth_val - d2_perp2)**2) 1419 | max_opt = np.argmax([dist1, dist2]) 1420 | 1421 | elif choose_meth == 'intersection': 1422 | dist1 = np.min(min_dist(prev_mid, d2_perp1, intersected_curve1, intersected_curve2)) 1423 | dist2 = np.min(min_dist(prev_mid, d2_perp2, intersected_curve1, intersected_curve2)) 1424 | max_opt = np.argmax([dist1,dist2]) 1425 | 1426 | else: 1427 | dist1 = np.sum((prev_mid - d2_perp1)**2) 1428 | dist2 = np.sum((prev_mid - d2_perp2)**2) 1429 | max_opt = np.argmin([dist1,dist2]) 1430 | else: 1431 | 1432 | if len(ref_point) > 0 and layer_num >0: # here ref point is a point of a different dynamics layer from which we want to take distance 1433 | dist1 = np.sum((ref_point - d2_perp1)**2) 1434 | dist2 = np.sum((ref_point - d2_perp2)**2) 1435 | max_opt = np.argmax([dist1, dist2]) 1436 | 1437 | elif direction_initial == 'low': 1438 | max_opt = np.argmin([d2_perp1[1], d2_perp2[1]]) 1439 | elif direction_initial == 'high': 1440 | max_opt = np.argmax([d2_perp1[1], d2_perp2[1]]) 1441 | elif direction_initial == 'right' : 1442 | max_opt = np.argmax([d2_perp1[0], d2_perp2[0]]) 1443 | elif direction_initial == 'left': 1444 | max_opt = np.argmin([d2_perp1[0], d2_perp2[0]]) 1445 | 1446 | 1447 | else: 1448 | raise NameError('Invalid direction initial value') 1449 | 1450 | d2_perp = options_last[max_opt] # take the desired direction 1451 | if return_unchose: 1452 | d2_perp_unchose = options_last[np.abs(1 - max_opt)] 1453 | return d1_perp, d2_perp, d2_perp_unchose 1454 | return d1_perp, d2_perp 1455 | 1456 | 1457 | def find_lows_high(coeff_row, latent_dyn, choose_meth ='intersection',factor_power = 0.9, initial_point = 'start', 1458 | direction_initial = 'low', return_unchose = False, ref_point = [], layer_num = 0): 1459 | """ 1460 | Calculates the coordinates of the 'high' values of a specific kayer 1461 | """ 1462 | 1463 | if return_unchose: unchosen_highs = [] 1464 | ### Initialize 1465 | x_highs_y_highs = []; x_lows_y_lows = [] 1466 | if isinstance(ref_point, np.ndarray): 1467 | if len(ref_point.shape) > 1: 1468 | ref_shape_all = ref_point 1469 | else: 1470 | ref_shape_all = np.array([]) 1471 | else: 1472 | ref_shape_all = np.array([]) 1473 | # Iterate over time 1474 | for t_num in range(0,latent_dyn.shape[1]-2): 1475 | d1_coeff = latent_dyn[:,t_num] 1476 | d2_coeff = latent_dyn[:,t_num+1] 1477 | prev_v = latent_dyn[:,t_num-1] 1478 | next_v = latent_dyn[:,t_num+2] 1479 | c_len = (coeff_row[t_num] + coeff_row[t_num+1])/2 1480 | 1481 | if len(ref_shape_all) > 0 and ref_shape_all.shape[0] > t_num and layer_num > 0: # and ref_shape_all.shape[1] >1 1482 | ref_point = ref_shape_all[t_num,:] 1483 | 1484 | 1485 | if len(ref_point) > 0 and layer_num > 0 : #and t_num < 3 1486 | pass 1487 | 1488 | 1489 | elif t_num > 2 and (choose_meth == 'smooth' or choose_meth == 'intersection'): 1490 | ref_point = d2_perp 1491 | else: 1492 | ref_point = [] 1493 | 1494 | 1495 | if return_unchose: d1_perp, d2_perp, d2_perp_unchosen = find_perpendicular(d1_coeff, d2_coeff,c_len**factor_power, prev_v = prev_v, next_v=next_v,ref_point = ref_point , choose_meth = choose_meth, initial_point=initial_point, direction_initial =direction_initial, return_unchose = return_unchose,layer_num=layer_num)# c_len 1496 | else: d1_perp, d2_perp = find_perpendicular(d1_coeff, d2_coeff,c_len**factor_power, prev_v = prev_v, next_v=next_v,ref_point = ref_point , choose_meth = choose_meth, initial_point=initial_point, 
direction_initial= direction_initial, return_unchose = return_unchose,layer_num=layer_num)# c_len 1497 | # Add results to results lists 1498 | x_lows_y_lows.append([d1_perp[0],d1_perp[1]]) 1499 | x_highs_y_highs.append([d2_perp[0],d2_perp[1]]) 1500 | if return_unchose: unchosen_highs.append([d2_perp_unchosen[0],d2_perp_unchosen[1]]) 1501 | # return 1502 | if return_unchose: 1503 | return x_lows_y_lows, x_highs_y_highs, unchosen_highs 1504 | return x_lows_y_lows, x_highs_y_highs 1505 | 1506 | def spec_corr(v1,v2): 1507 | """ 1508 | absolute value of correlation 1509 | """ 1510 | corr = np.corrcoef(v1[:],v2[:]) 1511 | return np.abs(corr[0,1]) 1512 | 1513 | 1514 | def plot_multi_colors(store_dict,min_time_plot = 0,max_time_plot = -100, colors = ['green','red','blue'], ax = [], 1515 | fig = [], alpha = 0.99, smooth_window = 3, factor_power = 0.9, coefficients_n = [], to_scatter = False, 1516 | to_scatter_only_one = False ,choose_meth = 'intersection', title = ''): 1517 | """ 1518 | store_dict is a dictionary with the high estimation results. 1519 | example: 1520 | store_dict , coefficients_n = calculate_high_for_all(coefficients,choose_meth = 'intersection',width_des = width_des, latent_dyn = latent_dyn, direction_initial = direction_initial,factor_power = factor_power, return_unchose=True) 1521 | 1522 | """ 1523 | if len(colors) < len(store_dict): raise ValueError('Not enough colors were provided') 1524 | if isinstance(ax, list) and len(ax) == 0: fig, ax = plt.subplots(figsize = (20,20)) 1525 | for key_counter, (key,set_plot) in enumerate(store_dict.items()): 1526 | if key_counter == 0: 1527 | x_lows_y_lows = store_dict[key][0] 1528 | x_highs_y_highs = store_dict[key][1] 1529 | low_ref =[] 1530 | high_ref = [] 1531 | else: 1532 | low_ref = [np.array(x_highs_y_highs)[min_time_plot,0], np.array(x_highs_y_highs)[min_time_plot,1]] 1533 | high_ref = [np.array(x_highs_y_highs)[max_time_plot,0],np.array(x_highs_y_highs)[max_time_plot,1]] 1534 | if len(coefficients_n) > 0: 1535 | # Define the length of the last perp. 1536 | c_len = (coefficients_n[key,max_time_plot-1] + coefficients_n[key,max_time_plot])/2 1537 | # Create perp. in the last point 1538 | d1_p, d2_p =find_perpendicular([np.array(x_lows_y_lows)[max_time_plot-2,0],np.array(x_lows_y_lows)[max_time_plot-2,1]], 1539 | [np.array(x_lows_y_lows)[max_time_plot-1,0],np.array(x_lows_y_lows)[max_time_plot-1,1]], 1540 | perp_length = c_len**factor_power, 1541 | ref_point = high_ref, 1542 | choose_meth = 'intersection',initial_point = 'end') 1543 | # Define the length of the first perp. 1544 | c_len_start = (coefficients_n[key,min_time_plot-1] + coefficients_n[key,min_time_plot])/2 1545 | # Create perp. 
in the first point 1546 | d1_p_start =[np.array(x_highs_y_highs)[min_time_plot,0],np.array(x_highs_y_highs)[min_time_plot,1]] 1547 | 1548 | d2_p_start= [np.array(x_highs_y_highs)[min_time_plot+1,0],np.array(x_highs_y_highs)[min_time_plot+1,1]] 1549 | 1550 | x_lows_y_lows = store_dict[key][0] 1551 | x_highs_y_highs = store_dict[key][1] 1552 | 1553 | stack_x = np.hstack([np.array(x_lows_y_lows)[min_time_plot:max_time_plot,0].flatten(), np.array([d2_p[0]]), np.array(x_highs_y_highs)[max_time_plot-1:min_time_plot+1:-1,0].flatten(),np.array([d2_p_start[0]])]) 1554 | stack_y = np.hstack([np.array(x_lows_y_lows)[min_time_plot:max_time_plot,1].flatten(), np.array([d2_p[1]]),np.array(x_highs_y_highs)[max_time_plot-1:min_time_plot+1:-1,1].flatten(),np.array([d2_p_start[1]])]) 1555 | 1556 | else: 1557 | stack_x = np.hstack([np.array(x_lows_y_lows)[min_time_plot:max_time_plot,0].flatten(), np.array(x_highs_y_highs)[max_time_plot:min_time_plot:,0].flatten()]) 1558 | stack_y = np.hstack([np.array(x_lows_y_lows)[min_time_plot:max_time_plot,1].flatten(), np.array(x_highs_y_highs)[max_time_plot:min_time_plot:,1].flatten()]) 1559 | stack_x_smooth = stack_x 1560 | stack_y_smooth = stack_y 1561 | if key_counter !=0: 1562 | ax.fill(stack_x_smooth, stack_y_smooth, alpha=0.3, facecolor=colors[key_counter], edgecolor=None, zorder=1, snap = True)# 1563 | else: 1564 | ax.fill(stack_x_smooth, stack_y_smooth, alpha=alpha, facecolor=colors[key_counter], edgecolor=None, zorder=1, snap = True)# 1565 | 1566 | if to_scatter or (to_scatter_only_one and key == np.max(list(store_dict.keys()))): 1567 | 1568 | 1569 | ax.scatter(np.array(x_lows_y_lows)[min_time_plot:max_time_plot,0].flatten(), np.array(x_lows_y_lows)[min_time_plot:max_time_plot,1].flatten(), c = 'black', alpha = alpha, s = 45) 1570 | 1571 | remove_edges(ax) 1572 | if not title == '': 1573 | ax.set_title(title, fontsize = 20) 1574 | 1575 | 1576 | def calculate_high_for_all(coefficients, choose_meth = 'both', same_width = True,factor_power = 0.9, width_des = 0.7, 1577 | initial_point = 'start', latent_dyn = [], 1578 | direction_initial = 'low', return_unchose = False, type_norm = 'norm',min_width =0.01): 1579 | """ 1580 | Create the dictionary to store results 1581 | """ 1582 | if len(latent_dyn) == 0: raise ValueError('Empty latent dyn was provided') 1583 | 1584 | # Coeffs normalization 1585 | coefficients_n = norm_coeffs(coefficients, type_norm, same_width = same_width, width_des = width_des,factor_power =factor_power,min_width=min_width ) 1586 | 1587 | # Initialization 1588 | store_dict = {} 1589 | dyn_use = latent_dyn 1590 | ref_point = [] 1591 | 1592 | for row in range(coefficients_n.shape[0]): 1593 | 1594 | coeff_row = coefficients_n[row,:] 1595 | # Store the results for each layer 1596 | if return_unchose: 1597 | x_lows_y_lows, x_highs_y_highs,x_highs_y_highs2 = find_lows_high(coeff_row,dyn_use, choose_meth = choose_meth, factor_power=factor_power, 1598 | initial_point = initial_point,direction_initial = direction_initial, 1599 | return_unchose = return_unchose, ref_point = ref_point,layer_num = row ) 1600 | store_dict[row] = [x_lows_y_lows, x_highs_y_highs,x_highs_y_highs2] 1601 | else: 1602 | x_lows_y_lows, x_highs_y_highs = find_lows_high(coeff_row,dyn_use, choose_meth = choose_meth, factor_power=factor_power, 1603 | initial_point = initial_point, direction_initial = direction_initial , 1604 | return_unchose = return_unchose,ref_point = ref_point ,layer_num=row) 1605 | store_dict[row] = [x_lows_y_lows, x_highs_y_highs] 1606 | # Update the reference 
points 1607 | if initial_point == 'mid': 1608 | dyn_use = np.array(x_highs_y_highs).T 1609 | dyn_use = (dyn_use[:,1:] + dyn_use[:,:-1])/2 1610 | dyn_use = np.hstack([latent_dyn[:,:2], dyn_use, latent_dyn[:,-2:]]) 1611 | else: 1612 | dyn_use = np.array(x_highs_y_highs).T 1613 | 1614 | ref_point = np.array(x_lows_y_lows) 1615 | return store_dict, coefficients_n 1616 | 1617 | 1618 | 1619 | #%% Plot 2d axis of coeffs for fig 2 1620 | 1621 | def plot_3d_dyn_basis(F, coefficients, projection = [0,-1], ax = [], fig = [], time_emph = [], n_times = 5, 1622 | type_plot = 'quiver',range_p = 10,s=200, w = 0.05/3, alpha0 = 0.3, 1623 | time_emph_text = [10, 20, 30, 50,80, 100,200,300,400,500], turn_off_back = True, lim1 = np.nan, 1624 | ax_qui = [], 1625 | ax_base = [], to_title = True, loc_title = 'title', include_bar =True, axs_basis_colored = [], 1626 | colors_dyns = np.array(['r','g','b','yellow']) , plot_dyn_by_colorbase = False, remove_edges_ax = False, include_dynamics = False, 1627 | latent_dyn = [],fontsize_times = 16,delta_text = 0.1, delta_text_y = 0,delta_text_z = 0, 1628 | new_colors = True, include_quiver = True, base_narrow = True,colors = [],color_by_dom = False, 1629 | quiver_3d = False, s_all = 10,to_remove_edge = True, to_grid = False, cons_color = False): 1630 | """ 1631 | ax = subplot to plot coefficients over time 1632 | colors = should be a mat of k X 3 1633 | """ 1634 | if not F[0].shape[0] ==3: quiver_3d = False 1635 | if len(colors) ==0: 1636 | if color_by_dom: 1637 | color_sig_tmp = find_dominant_dyn(np.abs(coefficients)) 1638 | colors = colors_dyns[color_sig_tmp] 1639 | colors_base = np.zeros(coefficients.shape[1]) 1640 | 1641 | else: 1642 | colors_base = np.linspace(0,1,coefficients.shape[1]).reshape((-1,1)) 1643 | colors = np.hstack([colors_base, 1-colors_base, colors_base**2]) 1644 | 1645 | if isinstance(ax,list) and len(ax) == 0: 1646 | if len(F) == 3: fig, ax = plt.subplots(subplot_kw={'projection':'3d'}, figsize= (10,10)) 1647 | elif len(F) == 2: fig, ax = plt.subplots(figsize= (10,10)) 1648 | else: raise ValueError('Invalid dim for F') 1649 | if len(time_emph) == 0: 1650 | time_emph =np.linspace(0,coefficients.shape[1]-2, n_times+1)[1:].astype(int) 1651 | 1652 | if include_dynamics: 1653 | 1654 | if len(latent_dyn) == 0: raise ValueError('You should provide latent dyn as input if "include dynamics" it True') 1655 | if len(F[0]) == 3: 1656 | fig_dyn,ax_dyn = plt.subplots(figsize = (15,15),subplot_kw={'projection':'3d'}) 1657 | if new_colors: 1658 | 1659 | ax_dyn.scatter(latent_dyn[0,:len(colors_base)], latent_dyn[1,:len(colors_base)],latent_dyn[2,:len(colors_base)], color = colors,alpha = 0.3) 1660 | ax_dyn.scatter(latent_dyn[0,time_emph],latent_dyn[1,time_emph],latent_dyn[2,time_emph], c = 'black', s = 300) 1661 | else: 1662 | c_sig = np.arange(latent_dyn.shape[1]) 1663 | ax_dyn.scatter(latent_dyn[0,:], latent_dyn[1,:],latent_dyn[2,:], c = c_sig,alpha = 0.3, cmap = 'viridis', s = 100) 1664 | ax_dyn.scatter(latent_dyn[0,time_emph],latent_dyn[1,time_emph],latent_dyn[2,time_emph], c = c_sig[time_emph], s = 300, cmap = 'viridis') 1665 | [ax_dyn.text(latent_dyn[0,t] + delta_text,latent_dyn[1,t]+delta_text_y,latent_dyn[2,t]+delta_text_z, 't = %s'%str(t), fontsize =fontsize_times, fontweight = 'bold') for t in time_emph] 1666 | ax_dyn.set_axis_off() 1667 | 1668 | else: 1669 | fig_dyn,ax_dyn = plt.subplots(figsize = (10,10)) 1670 | if new_colors: 1671 | 1672 | 1673 | ax_dyn.scatter(latent_dyn[0,:len(colors_base)], latent_dyn[1,:len(colors_base)], color = colors,alpha = 0.3, s 
= 50) 1674 | ax_dyn.scatter(latent_dyn[0,time_emph],latent_dyn[1,time_emph], c = 'black', s = 200) 1675 | else: 1676 | c_sig = np.arange(latent_dyn.shape[1]) 1677 | ax_dyn.scatter(latent_dyn[0,:], latent_dyn[1,:], c = c_sig,alpha = 0.3, cmap = 'viridis', s = 100) 1678 | ax_dyn.scatter(latent_dyn[0,time_emph],latent_dyn[1,time_emph], c = c_sig[time_emph], s = 300, cmap = 'viridis') 1679 | [ax_dyn.text(latent_dyn[0,t] + delta_text,latent_dyn[1,t]+delta_text_y, 't = %s'%str(t), fontsize =fontsize_times, fontweight = 'bold') for t in time_emph] 1680 | remove_edges(ax_dyn) 1681 | 1682 | if len(F[0]) == 3: 1683 | if quiver_3d: 1684 | if type_plot == 'streamplot': 1685 | type_plot = 'quiver' 1686 | print('If quiver_3d then type_plot need to be quiver (currently is streamplot)') 1687 | if include_quiver: 1688 | if isinstance(ax_qui, list) and len(ax_qui)== 0: 1689 | fig_qui, ax_qui = plt.subplots(1,len(time_emph), figsize= (7*len(time_emph),5) ,subplot_kw={'projection':'3d'}) 1690 | if isinstance(ax_base, list) and len(ax_base)==0: 1691 | if base_narrow: 1692 | fig_base, ax_base = plt.subplots(len(F),1, figsize= (5,7*len(F)) ,subplot_kw={'projection':'3d'}) 1693 | else: 1694 | fig_base, ax_base = plt.subplots(1,len(F),figsize= (7*len(F),5 ),subplot_kw={'projection':'3d'}) 1695 | else: 1696 | 1697 | F = [f[:,projection] for f in F] 1698 | F = [f[projection, :] for f in F] 1699 | if include_quiver: 1700 | if isinstance(ax_qui, list) and len(ax_qui)== 0: fig_qui, ax_qui = plt.subplots(1,len(time_emph), figsize= (7*len(time_emph),5) ) 1701 | if isinstance(ax_base, list) and len(ax_base)==0: 1702 | if base_narrow: 1703 | fig_base, ax_base = plt.subplots(len(F),1, figsize= (5,7*len(F)) ) 1704 | else: 1705 | fig_base, ax_base = plt.subplots(1,len(F),figsize= (7*len(F),5 )) 1706 | 1707 | elif len(F[0]) == 2: 1708 | if isinstance(ax_qui, list) and len(ax_qui)==0: fig_qui, ax_qui = plt.subplots(1,len(time_emph), figsize= (7*len(time_emph),5)) 1709 | if isinstance(ax_base, list) and len(ax_base)==0: 1710 | if base_narrow: 1711 | fig_base, ax_base = plt.subplots(len(F), 1,figsize= (5,7*len(F) )) 1712 | else: 1713 | fig_base, ax_base = plt.subplots(1,len(F),figsize= (7*len(F),5 )) 1714 | if len(F[0]) == 3: 1715 | 1716 | cmap = matplotlib.cm.get_cmap('viridis') 1717 | if new_colors: 1718 | 1719 | ax.scatter(coefficients[0,:],coefficients[1,:],coefficients[2,:], c = colors, alpha = alpha0, s = s_all) 1720 | 1721 | ax.scatter(coefficients[0,time_emph],coefficients[1,time_emph],coefficients[2,time_emph], c = 'black', 1722 | s = s) 1723 | 1724 | [plot_reco_dyn(coefficients, F, time_point, type_plot = type_plot, range_p = range_p, color =colors[time_point] , 1725 | w = w, ax = ax_qui[i], quiver_3d = quiver_3d, cons_color = cons_color) for i, time_point in enumerate(time_emph)] 1726 | else: 1727 | 1728 | cmap = matplotlib.cm.get_cmap('viridis') 1729 | colors_base = np.arange(coefficients.shape[1]) 1730 | ax.scatter(coefficients[0,:],coefficients[1,:],coefficients[2,:], c = colors_base, alpha = alpha0, s = s_all) 1731 | ax.scatter(coefficients[0,time_emph],coefficients[1,time_emph],coefficients[2,time_emph], c = 'black', s = s, alpha = np.min([alpha0*2, 1])) 1732 | if include_quiver: 1733 | [plot_reco_dyn(coefficients, F, time_point, type_plot = type_plot, range_p = range_p, 1734 | color = cmap(time_point/colors.shape[0]) , 1735 | w = w, ax = ax_qui[i], quiver_3d = quiver_3d, cons_color = cons_color) for i, time_point in enumerate(time_emph)] 1736 | 1737 | if to_title and include_quiver: 1738 | if loc_title == 
'title': 1739 | [ax_qui[i].set_title('t = ' + str(time_point), fontsize =fontsize_times*3 , fontweight = 'bold') for i, time_point in enumerate(time_emph)] 1740 | else: 1741 | [ax_qui[i].set_ylabel('t = ' + str(time_point), fontsize =fontsize_times, fontweight = 'bold') for i, time_point in enumerate(time_emph)] 1742 | 1743 | [ax.text(coefficients[0,time_point]+delta_text,coefficients[1,time_point]+delta_text_y,coefficients[2,time_point]+delta_text_z,'t = ' + str(time_point), fontsize =fontsize_times, fontweight = 'bold') for time_point in time_emph_text] 1744 | ax.set_xlabel('f1');ax.set_ylabel('f2');ax.set_zlabel('f3'); 1745 | 1746 | else: 1747 | 1748 | 1749 | cmap = matplotlib.cm.get_cmap('viridis') 1750 | if new_colors: 1751 | ax.scatter(coefficients[0,:],coefficients[1,:], c = colors, alpha = alpha0, s = s_all) 1752 | ax.scatter(coefficients[0,time_emph],coefficients[1,time_emph], c = colors[time_emph], s = s) 1753 | if include_quiver: 1754 | [plot_reco_dyn(coefficients, F, time_point, type_plot = type_plot, range_p = range_p, color =colors[time_point] , w = w, ax = ax_qui[i],cons_color=cons_color ) for i, time_point in enumerate(time_emph)] 1755 | else: 1756 | colors_base = np.arange(coefficients.shape[1]) 1757 | ax.scatter(coefficients[0,:],coefficients[1,:], c = colors_base, alpha = alpha0, s = s_all) 1758 | ax.scatter(coefficients[0,time_emph],coefficients[1,time_emph], c = colors_base[time_emph], s = s) 1759 | if include_quiver: 1760 | [plot_reco_dyn(coefficients, F, time_point, type_plot = type_plot, range_p = range_p, color = cmap(time_point/colors.shape[0]) , w = w, ax = ax_qui[i],cons_color= cons_color ) for i, time_point in enumerate(time_emph)] 1761 | if latent_dyn.shape[0] == 3: 1762 | [ax.text(coefficients[0,time_point]+delta_text,coefficients[1,time_point]+delta_text_y, coefficients[2,time_point]+delta_text_z,'t = ' + str(time_point),fontsize = fontsize_times ,fontweight = 'bold') for time_point in time_emph_text] 1763 | 1764 | else: 1765 | [ax.text(coefficients[0,time_point]+delta_text,coefficients[1,time_point]+delta_text_y,'t = ' + str(time_point),fontsize = fontsize_times ,fontweight = 'bold') for time_point in time_emph_text] 1766 | 1767 | ax.set_xlabel('f1');ax.set_ylabel('f2'); 1768 | if remove_edges_ax: remove_edges(ax) 1769 | 1770 | if to_title and include_quiver: 1771 | if loc_title == 'title': 1772 | [ax_qui[i].set_title('t = ' + str(time_point), fontsize = 30) for i, time_point in enumerate(time_emph)] 1773 | else: 1774 | [ax_qui[i].set_ylabel('t = ' + str(time_point), fontsize = 30) for i, time_point in enumerate(time_emph)] 1775 | if to_remove_edge: 1776 | if include_quiver: [remove_edges(ax_spec) for ax_spec in ax_qui] 1777 | [remove_edges(ax_spec) for ax_spec in ax_base] 1778 | ax.set_xticks([]) 1779 | ax.set_yticks([]) 1780 | if quiver_3d: 1781 | ax.set_zticks([]) 1782 | [quiver_plot(f,-range_p, range_p, -range_p, range_p, ax = ax_base[f_num],chosen_color = 'black', w = w, type_plot = type_plot,cons_color =cons_color,quiver_3d = quiver_3d ) for f_num, f in enumerate(F)] 1783 | [ax_base_spec.set_title('f %s'%str(i), fontsize = 16) for i, ax_base_spec in enumerate(ax_base)] 1784 | 1785 | 1786 | if turn_off_back and len(F) == 3: 1787 | ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 1788 | ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 1789 | ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 1790 | if not to_grid and len(F) == 3: 1791 | ax.grid(False) 1792 | ax.set_zticks([]) 1793 | ax.xaxis._axinfo['juggled'] = (0,0,0) 1794 | ax.yaxis._axinfo['juggled'] = (1,1,1) 
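# (Descriptive note: _axinfo['juggled'] is an undocumented matplotlib mplot3d internal, not public API; the assignments above and below appear to re-pin which pane each 3-D axis line is drawn against so the view stays clean once the grid and z-ticks are removed. Its behaviour may change between matplotlib versions.)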
1795 | ax.zaxis._axinfo['juggled'] = (2,2,2) 1796 | if not np.isnan(lim1): 1797 | ax.set_xlim([-lim1,lim1]) 1798 | ax.set_ylim([-lim1,lim1]) 1799 | 1800 | 1801 | if include_bar: 1802 | if base_narrow: 1803 | fig_all_all, ax_all_all = plt.subplots(len(time_emph),1, figsize = (6,len(time_emph)*7)) 1804 | else: 1805 | ax_all_all = [] 1806 | add_bar_dynamics(coefficients, ax_all_all = ax_all_all, min_max_points = time_emph, colors = colors_dyns, 1807 | centralize = True) 1808 | 1809 | if isinstance( axs_basis_colored ,list) and len( axs_basis_colored ) == 0: 1810 | if base_narrow: 1811 | if quiver_3d: fig_basis_colored , axs_basis_colored = plt.subplots( len(F),1,figsize = (5,6*len(F)),subplot_kw={'projection':'3d'}) 1812 | else: fig_basis_colored , axs_basis_colored = plt.subplots( len(F),1,figsize = (5,6*len(F))) 1813 | 1814 | else: 1815 | if quiver_3d: fig_basis_colored , axs_basis_colored = plt.subplots( 1, len(F), figsize = (6*len(F),5),subplot_kw={'projection':'3d'}) 1816 | else: fig_basis_colored , axs_basis_colored = plt.subplots( 1, len(F), figsize = (6*len(F),5)) 1817 | [quiver_plot(f,-range_p, range_p, -range_p, range_p, ax = axs_basis_colored[f_num],alpha = 0.7, chosen_color = colors_dyns[f_num], w = w, type_plot = type_plot, cons_color = cons_color, quiver_3d=quiver_3d ) for f_num, f in enumerate(F)] 1818 | [remove_edges(ax_spec) for ax_spec in axs_basis_colored] 1819 | if quiver_3d: [ax.set_zticks([]) for ax in axs_basis_colored] 1820 | 1821 | 1822 | 1823 | def plot_reco_dyn(coefficients, F, time_point, type_plot = 'quiver', range_p = 10, color = 'black', 1824 | w = 0.05/3, ax = [], cons_color= False, to_remove_edges = False, projection = [0,1], 1825 | return_artist = False, 1826 | xlabel = 'x',ylabel = 'y',quiver_3d = False): 1827 | if isinstance(ax,list) and len(ax) == 0: 1828 | 1829 | fig, ax = plt.subplots() 1830 | 1831 | if len(F) == 3: 1832 | merge_dyn_at_t_break = coefficients[0,time_point] * F[0]+coefficients[1,time_point] * F[1]+coefficients[2,time_point] * F[2] 1833 | 1834 | if not quiver_3d: 1835 | 1836 | merge_dyn_at_t_break = merge_dyn_at_t_break[:, projection] 1837 | merge_dyn_at_t_break = merge_dyn_at_t_break[projection,:] 1838 | 1839 | elif len(F) == 2: 1840 | merge_dyn_at_t_break = coefficients[0,time_point] * F[0]+coefficients[1,time_point] * F[1] 1841 | 1842 | art = quiver_plot(sub_dyn = merge_dyn_at_t_break, chosen_color = color, xmin = -range_p, 1843 | xmax = range_p, ymin= -range_p,ymax= range_p, ax = ax, w = w, type_plot=type_plot, 1844 | cons_color= cons_color, return_artist = return_artist, xlabel = xlabel, ylabel = ylabel, 1845 | quiver_3d = quiver_3d) 1846 | if to_remove_edges: remove_edges(ax) 1847 | if return_artist: 1848 | return art 1849 | 1850 | 1851 | 1852 | def plot_c_space(coefficients,latent_dyn = [], axs = [], fig = [], xlim = [-50,50], ylim = [-50,50], add_midline = True, d3 = True, cmap = 'winter', color_sig = [], 1853 | title = '', times_plot= [], cmap_f = []): 1854 | if len(times_plot) > 0 and isinstance(cmap_f, list): cmap_f = plt.cm.get_cmap(cmap) 1855 | if len(color_sig) == 0: color_sig = latent_dyn[0,:-1] 1856 | if isinstance(axs, list) and len(axs) == 0: 1857 | 1858 | if coefficients.shape[0] == 3: 1859 | fig, axs = plt.subplots(figsize = (15,15),subplot_kw={'projection':'3d'}) 1860 | d3 = True 1861 | h = axs.scatter(coefficients[0,:], coefficients[1,:],coefficients[2,:], c = color_sig, cmap = cmap) 1862 | if len(times_plot) > 0: 1863 | axs.scatter(coefficients[0,times_plot], coefficients[1,times_plot], coefficients[2,times_plot], c 
=cmap_f(color_sig[times_plot]/np.max(color_sig)),s = 500 ) 1864 | elif coefficients.shape[0] == 2: 1865 | if d3: 1866 | fig, axs = plt.subplots(figsize = (15,15),subplot_kw={'projection':'3d'}) 1867 | h = axs.scatter(coefficients[0,:], coefficients[1,:], np.arange(coefficients.shape[1]), c = color_sig, cmap = cmap) 1868 | if len(times_plot) > 0: 1869 | zax = np.arange(coefficients.shape[1]) 1870 | axs.scatter(coefficients[0,times_plot], coefficients[1,times_plot], zax[times_plot], c =cmap_f(color_sig[times_plot]/np.max(color_sig)),s = 500 ) 1871 | else: 1872 | fig, axs = plt.subplots(figsize = (15,15)) 1873 | h = axs.scatter(coefficients[0,:], coefficients[1,:], c = color_sig, cmap = cmap) 1874 | if len(times_plot) > 0: 1875 | axs.scatter(coefficients[0,times_plot], coefficients[1,times_plot], c =cmap_f(color_sig[times_plot]/np.max(color_sig)),s = 500 ) 1876 | else: 1877 | print('Invalid coefficients shape in axis 0') 1878 | 1879 | if len(xlim) > 0: axs.set_xlim(xlim) 1880 | if len(ylim) > 0: axs.set_ylim(ylim) 1881 | if not isinstance(fig, list): fig.colorbar(h) 1882 | 1883 | if d3: 1884 | axs.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 1885 | axs.grid(False) 1886 | axs.set_axis_off() 1887 | if add_midline: 1888 | if d3: 1889 | axs.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 1890 | axs.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 1891 | axs.plot([0,0],[np.min(coefficients[1,:]),np.max(coefficients[1,:])],[0,0], color = 'black', ls = '--', alpha = 0.3) 1892 | axs.plot([np.min(coefficients[0,:]),np.nanmax(coefficients[0,:])],[0,0],[0,0], color = 'black', ls = '--', alpha = 0.3) 1893 | axs.plot([0,0],[0,0],[0,1.3*coefficients.shape[1]],color = 'black', alpha = 0.3, ls = '--') 1894 | axs.view_init(elev=15, azim=30) 1895 | else: 1896 | axs.axhline(0, color = 'black', ls = '--', alpha = 0.3) 1897 | axs.axvline(0, color = 'black', ls = '--', alpha = 0.3) 1898 | if len(title) > 0: 1899 | axs.set_title(title) 1900 | 1901 | 1902 | 1903 | -------------------------------------------------------------------------------- /dLDS_discrete/dlds_discrete/test/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon May 30 08:49:16 2022 4 | 5 | @author: noga mudrik 6 | """ 7 | 8 | -------------------------------------------------------------------------------- /dLDS_discrete/dlds_discrete/train_discrete_model_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decomposed Linear Dynamical Systems (dLDS) for learning the latent components of neural dynamics 3 | @code author: noga mudrik 4 | """ 5 | 6 | #%% Imports: 7 | 8 | from importlib import reload 9 | 10 | import main_functions 11 | main_functions = reload(main_functions) 12 | from main_functions import * 13 | from datetime import date 14 | 15 | 16 | """ 17 | Parameters 18 | """ 19 | exec(open('create_params.py').read()) 20 | 21 | addi_save = date.today().strftime('%d%m%y') # For saving 22 | if 'addition_save' not in locals(): addition_save = [] 23 | 24 | update_c_types = ['inv'] #['spgl1'] 25 | num_iters = [10] 26 | max_iter = 6000 27 | is_D_I = True 28 | 29 | 30 | """ 31 | Parameters to choose 32 | """ 33 | dt = float(input('dt (rec for Lorenz 0.01, rec for FHN 0.2)')) 34 | max_time = float(input('max time (rec for Lorenz 10, rec for FHN 200)')) 35 | dynamic_type = input('dynamic type (e.g. 
lorenz, FHN)') 36 | addi_name = input('additional name id') 37 | num_subdyns = [int(input('num_dyns (m)'))] 38 | include_last_up = str2bool(input('include last up? (for FHN reg)')) 39 | reg_vals_new = [float(input('reg_val_input (tau)'))] 40 | addition_save.append(addi_save) 41 | latent_dyn = create_dynamics(type_dyn = dynamic_type, max_time = max_time, dt = dt) 42 | include_D = False 43 | to_load = False 44 | 45 | 46 | 47 | name_auto = True 48 | normalize_eig = True 49 | to_print = False 50 | seed_f = 0 51 | dt_range = np.linspace(0.001, 1, 20) 52 | exp_power = 0.1 53 | 54 | """ 55 | Runnining over the parameters 56 | """ 57 | for num_iter in num_iters: 58 | for reg_term in reg_vals_new: 59 | for update_c_type in update_c_types : 60 | for num_subs in num_subdyns: 61 | to_save_without_ask = True 62 | sigma_mix_f = 0.1 63 | F = [init_mat((latent_dyn.shape[0], latent_dyn.shape[0]),normalize=True) for i in range(num_subs)] 64 | coefficients = init_mat((num_subs,latent_dyn.shape[1]-1)) 65 | save_name = '%s_%gsub%greg%s_iters%s'%(dynamic_type,num_subs,reg_term, update_c_type, str(num_iter)) 66 | data = latent_dyn 67 | 68 | params_update_c = {'reg_term': reg_term, 'update_c_type':update_c_type,'smooth_term' :smooth_term, 'num_iters': num_iter, 'threshkind':'soft'} 69 | 70 | if to_save_without_ask: to_save = True 71 | else: to_save = str2bool(input('To save?')) 72 | 73 | coefficients, F, latent_dyn, error_reco_array, D = train_model_include_D(max_time , dt , dynamic_type, num_subdyns = num_subs, 74 | data = data, step_f = step_f, GD_decay = GD_decay, 75 | max_error = max_error, max_iter = max_iter, 76 | include_D = include_D, seed_f = seed_f, 77 | normalize_eig = normalize_eig, 78 | to_print = to_print, params = params_update_c ) 79 | if to_save: 80 | if name_auto: pass 81 | else: save_name = input('save_name') 82 | save_dict = {'F':F, 'coefficients':coefficients, 'latent_dyn': latent_dyn, 'max_time': max_time, 'dt':dt,'dyn_type':dynamic_type, 83 | 'error_reco_array' :error_reco_array, 'D':D} 84 | save_file_dynamics(save_name, ['main_folder_results', dynamic_type, 'clean%s'%addi_name,update_c_type ]+addition_save, save_dict ) 85 | save_file_dynamics(save_name, ['main_folder_results' ,dynamic_type, 'clean%s'%addi_name,update_c_type ]+addition_save, [], type_save = '.pkl' ) 86 | 87 | -------------------------------------------------------------------------------- /dLDS_discrete/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /dLDS_discrete/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | packages=setuptools.find_packages(), 8 | author="noga mudrik", 9 | 10 | name="dLDS_discrete_2022", 11 | version="0.1.01", 12 | 13 | author_email="", 14 | description="dLDS discrete model package", 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | 18 | classifiers=[ 19 | "Programming Language :: Python :: 3", 20 | "License :: OSI Approved :: MIT License", 21 | "Operating System :: OS Independent", 22 | ], 23 | 24 | python_requires=">=3.8", 25 | install_requires = ['numpy', 'matplotlib','scipy','scipy','pandas','webcolors', 26 | 'seaborn','colormap','sklearn', 
'pylops','dill','mat73', 'easydev'] 27 | ) 28 | 29 | -------------------------------------------------------------------------------- /discrete_results_notebook/FHN_with_reg_0_3_spgl1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/FHN_with_reg_0_3_spgl1.npy -------------------------------------------------------------------------------- /discrete_results_notebook/Worm1_WT_Stim.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/Worm1_WT_Stim.mat -------------------------------------------------------------------------------- /discrete_results_notebook/Worm1_WT_Stim_Zhat_discretestates.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/Worm1_WT_Stim_Zhat_discretestates.mat -------------------------------------------------------------------------------- /discrete_results_notebook/c_elegans_dlds_results.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/c_elegans_dlds_results.npy -------------------------------------------------------------------------------- /discrete_results_notebook/fhn_dyn_non_reg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/fhn_dyn_non_reg.npy -------------------------------------------------------------------------------- /discrete_results_notebook/fhn_reg_effect/fhn_5sub0_01regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/fhn_reg_effect/fhn_5sub0_01regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/fhn_reg_effect/fhn_5sub0_05regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/fhn_reg_effect/fhn_5sub0_05regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/fhn_reg_effect/fhn_5sub0_1regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/fhn_reg_effect/fhn_5sub0_1regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/fhn_reg_effect/fhn_5sub0_4regspgl1_iters10.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/fhn_reg_effect/fhn_5sub0_4regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/fhn_reg_effect/fhn_5sub0_6regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/fhn_reg_effect/fhn_5sub0_6regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/fhn_reg_effect/fhn_5sub0_7regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/fhn_reg_effect/fhn_5sub0_7regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/fhn_reg_effect/fhn_5sub0regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/fhn_reg_effect/fhn_5sub0regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_5sub0_55regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_5sub0_55regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_5sub0reg_results.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_5sub0reg_results.pkl -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_01regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_01regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_05regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_05regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_1regspgl1_iters10.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_1regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_4regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_4regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_6regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_6regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_7regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0_7regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/lorenz_reg_effect/lorenz_5sub0regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/multifhn_2sub0reg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/multifhn_2sub0reg.npy -------------------------------------------------------------------------------- /discrete_results_notebook/multilorenz_3sub0reg.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/multilorenz_3sub0reg.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_fhn/fhn_2sub0_3regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_fhn/fhn_2sub0_3regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_fhn/fhn_3sub0_3regspgl1_iters10.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_fhn/fhn_3sub0_3regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_fhn/fhn_4sub0_3regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_fhn/fhn_4sub0_3regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_fhn/fhn_5sub0_3regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_fhn/fhn_5sub0_3regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_fhn/fhn_6sub0_3regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_fhn/fhn_6sub0_3regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_fhn/fhn_7sub0_3regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_fhn/fhn_7sub0_3regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_lorenz/lorenz_2sub0_55regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_lorenz/lorenz_2sub0_55regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_lorenz/lorenz_3sub0_55regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_lorenz/lorenz_3sub0_55regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_lorenz/lorenz_4sub0_55regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_lorenz/lorenz_4sub0_55regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_lorenz/lorenz_5sub0_55regspgl1_iters10.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_lorenz/lorenz_5sub0_55regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_lorenz/lorenz_6sub0_55regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_lorenz/lorenz_6sub0_55regspgl1_iters10.npy -------------------------------------------------------------------------------- /discrete_results_notebook/num_sub_lorenz/lorenz_7sub0_55regspgl1_iters10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/discrete_results_notebook/num_sub_lorenz/lorenz_7sub0_55regspgl1_iters10.npy -------------------------------------------------------------------------------- /paper figures/DynamicsLearningModel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/paper figures/DynamicsLearningModel.png -------------------------------------------------------------------------------- /paper figures/FHN_dLDS_vs_rslds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/paper figures/FHN_dLDS_vs_rslds.png -------------------------------------------------------------------------------- /paper figures/c elegans.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/paper figures/c elegans.PNG -------------------------------------------------------------------------------- /paper figures/lorenz dLDS vs rSLDS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/paper figures/lorenz dLDS vs rSLDS.png -------------------------------------------------------------------------------- /paper figures/lorenz_new_f1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dLDS-Decomposed-Linear-Dynamics/dLDS-Discrete-Python-Model/dd0b6df8f8a23031dad4d36a5227c5a17786ab26/paper figures/lorenz_new_f1.png --------------------------------------------------------------------------------
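
The `.npy` and `.pkl` files listed above are the training outputs used by the results notebook. As a minimal, hypothetical sketch (not part of the repository code): assuming such a file stores the same `save_dict` written by `train_discrete_model_example.py` (keys `'F'`, `'coefficients'`, `'latent_dyn'`, `'error_reco_array'`, `'D'`, ...) as a pickled dictionary inside the `.npy` archive, it could be inspected roughly as follows; the exact loading call depends on how `save_file_dynamics` actually serializes the dictionary.

import numpy as np

# Hypothetical loader for one of the result files listed above. It assumes the
# .npy archive holds a pickled Python dict with the save_dict keys from
# train_discrete_model_example.py ('F', 'coefficients', 'latent_dyn', ...).
res = np.load('discrete_results_notebook/lorenz_5sub0_55regspgl1_iters10.npy',
              allow_pickle=True).item()      # recover the stored dict

F            = res['F']             # learned sub-dynamics matrices (f_i)
coefficients = res['coefficients']  # per-time-step coefficients (c_t), one row per f_i
latent_dyn   = res['latent_dyn']    # latent trajectory the model was trained on

print('number of sub-dynamics:', len(F))
print('coefficients shape    :', coefficients.shape)
print('latent trajectory     :', latent_dyn.shape)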