├── README.md ├── .gitignore ├── 00_mdof_osa ├── mdof_solutions.py ├── aux_funcs.py └── mdof_osa_pinn.py ├── auxUtils.py └── _mdofPINN ├── pinnUtils.py └── pinnModels.py /README.md: -------------------------------------------------------------------------------- 1 | # Physics-informed neural networks (PINNs) for dynamical systems 2 | 3 | This repo includes a series of physics-informed neural networks for various dynamical systems. Please feel free to flag any bugs or suggest improvements. 4 | 5 | ### Required python packages 6 | ``` 7 | python >= 3.10 8 | numpy >2.x 9 | scipy 10 | matplotlib 11 | dynasim >= 1.4 12 | pytorch >= 2.0 13 | tqdm 14 | pyro 15 | seaborn 16 | arviz 17 | ``` 18 | 19 | ### Citation 20 | 21 | This code is provided for the paper https://arxiv.org/abs/2410.01340 22 | ``` 23 | @article{haywood2024response, 24 | title={Response Estimation and System Identification of Dynamical Systems via Physics-Informed Neural Networks}, 25 | author={Haywood-Alexander, Marcus and Arcieri, Giacomo and Kamariotis, Antonios and Chatzi, Eleni}, 26 | journal={arXiv preprint arXiv:2410.01340}, 27 | year={2024} 28 | } 29 | ``` 30 | 31 | # General PINN Definition 32 | 33 | ## Artificial Neural Networks 34 | For regression problems, the aim of an ANN is to map from an $n$-dimensional input, $\mathbf{x}$, to a $k$-dimensional output, $\mathbf{y}$. 35 | For $N$ hidden layers of a neural network, the output of each of the $m$ nodes is passed through an activation function $\sigma$. 36 | ```math 37 | \mathcal{N}_{\mathbf{y}}(\mathbf{x};\mathbf{W}, \mathbf{B}) := \sigma(\mathbf{w}^l \mathbf{x}^{l-1} + \mathbf{b}^l), \quad \mathrm{for}\; l = 2,...,N 38 | ``` 39 | where $\mathbf{W}=\{\mathbf{w}^1,...,\mathbf{w}^N\}$ and $\mathbf{B}=\{\mathbf{b}^1,...,\mathbf{b}^N\}$ are the weights and biases of the network, respectively. 40 | These then form the trainable parameters of the network, $\mathbf{\Theta} = \{\mathbf{W},\mathbf{B}\}$.
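As a minimal sketch of the layer recursion above, the snippet below evaluates a small fully-connected network in plain Python. It assumes a `tanh` activation on every hidden layer and a linear output layer, which mirrors how `build_net` in `mdof_osa_pinn.py` ends with a plain `nn.Linear`. The helper names `dense` and `mlp_forward` and all numeric values are illustrative and not part of this repo.

```python
import math

def dense(w, b, x):
    """Affine map w @ x + b, with the weight matrix given as a list of rows."""
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(w, b)]

def mlp_forward(weights, biases, x):
    """Apply tanh after every hidden layer; keep the output layer linear."""
    h = x
    for w, b in zip(weights[:-1], biases[:-1]):
        h = [math.tanh(a) for a in dense(w, b, h)]
    return dense(weights[-1], biases[-1], h)

# Hypothetical parameters: 2 inputs -> 3 hidden nodes -> 1 output
W = [
    [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.8]],  # hidden layer weights (3 x 2)
    [[1.0, -1.0, 0.5]],                      # output layer weights (1 x 3)
]
B = [[0.0, 0.1, -0.1], [0.2]]

y = mlp_forward(W, B, [1.0, 2.0])
print(y)
```

Here `W` and `B` play the role of $\mathbf{W}$ and $\mathbf{B}$; in the repo itself these are held by `torch.nn` layers and updated by the optimiser.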
41 | 42 | With target output data $\mathbf{y}^*$ from the domain of observations $\mathbf{x}^*\in\Omega_0$, the "optimal" parameters are commonly determined using a simple mean-squared-error objective function, 43 | ```math 44 | L_{obs}(\mathbf{x}^*;\mathbf{\Theta}) = \langle \mathbf{y}^* - \mathcal{N}_{\mathbf{y}}(\mathbf{x}^*;\mathbf{\Theta}) \rangle _{\Omega_{0}}, \qquad 45 | \langle \bullet \rangle _{\Omega_{\kappa}} = \frac{1}{N_{\kappa}}\sum_{x\in\Omega_{\kappa}}\left|\left|\bullet\right|\right|^2 46 | ``` 47 | 48 | ## Physics-Informed Neural Network 49 | 50 | If the physics of the system is known (or estimated) in the form of ordinary or partial differential equations, then this can be embedded into the objective function over which the NN parameters are optimised. 51 | Consider a general form of the PDE, 52 | ```math 53 | \mathcal{F}(\mathbf{y},\mathbf{x};\theta) = 0 54 | ``` 55 | for some nonlinear operator $\mathcal{F}$ acting on $\mathbf{y}(\mathbf{x})$, where $\theta$ are parameters of the equation. For example, the wave equation (here in its general form), 56 | ```math 57 | \frac{\partial^2 u}{\partial t^2} = c_1^2 \frac{\partial^2 u}{\partial x_1^2} + c_2^2 \frac{\partial^2 u}{\partial x_2^2} + ... + c_n^2 \frac{\partial^2 u}{\partial x_n^2} 58 | ``` 59 | where $\mathbf{y}=\{u_1,u_2,...,u_n\}$, $\mathbf{x}=\{x_1,x_2,...,x_n\}$, and $\theta = \{c_1,c_2,...,c_n\}$. 60 | 61 | When predicting the output from a neural network, we can also create an estimate of the nonlinear operator, $\mathcal{F}(\mathcal{N}_{\mathbf{y}},\mathbf{x};\theta)$. 62 | This can then be used directly as an additional objective function to be minimised, since the PDE is satisfied when this value equals zero.
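To illustrate the residual idea, here is a minimal sketch for a simple harmonic oscillator, $\ddot{u} + \omega^2 u = 0$. Two closed-form candidates stand in for the network output, and the second derivative is taken with a central finite difference rather than automatic differentiation. Both choices are assumptions made to keep the example dependency-free (the repo itself uses `torch.autograd.grad`), and the names `residual`, `mean_sq`, `exact` and `wrong` are hypothetical.

```python
import math

def residual(u, t, omega, h=1e-3):
    """F(u, t; omega) = u'' + omega^2 * u, with u'' from a central difference."""
    d2u = (u(t + h) - 2.0 * u(t) + u(t - h)) / h**2
    return d2u + omega**2 * u(t)

def mean_sq(values):
    """Mean of squared values: the <.> operator for scalar residuals."""
    return sum(v * v for v in values) / len(values)

omega = 2.0
t_col = [0.1 * i for i in range(50)]  # hypothetical collocation points

def exact(t):
    """Candidate that satisfies the ODE."""
    return math.cos(omega * t)

def wrong(t):
    """Candidate with the wrong frequency, which violates the ODE."""
    return math.cos(1.2 * omega * t)

loss_exact = mean_sq([residual(exact, t, omega) for t in t_col])
loss_wrong = mean_sq([residual(wrong, t, omega) for t in t_col])
print(loss_exact, loss_wrong)
```

The candidate that satisfies the ODE gives a residual close to zero (limited only by the finite-difference error), while the mismatched one gives a clearly non-zero value, which is what makes the residual usable as a training objective.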
63 | Given the domain of collocation points, $\mathbf{x}_p \in \Omega_p$, this term is defined as, 64 | ```math 65 | L_{pde}(\mathbf{x}_p;\mathbf{\Theta},\theta) = \langle \mathcal{F}(\mathcal{N}_\mathbf{y_p},\mathbf{x}_p;\theta) \rangle _{\Omega_p}, \qquad \mathcal{N}_\mathbf{y_p} = \mathcal{N}_\mathbf{y}(\mathbf{x}_p;\mathbf{\Theta}) 66 | ``` 67 | Then, we can combine the observation objective function with the PDE objective function, and minimise the sum, 68 | 69 | ```math 70 | L = L_{obs} + \Lambda L_{pde} 71 | ``` 72 | where $\Lambda$ is a normalisation parameter that brings the objective function values to the same order of magnitude to aid optimisation. In this work, a combination of the input normalisation parameters is often used to set the value of $\Lambda$. 73 | 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .gitattributes 2 | 3 | **/__pycache__/ 4 | 5 | **/checkpoints/ 6 | 7 | **/results/ 8 | 9 | **/.DS_Store 10 | .DS_Store 11 | .vscode/ 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | **.pyc 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | cover/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | .pybuilder/ 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | # For a library or package, you might want to ignore these files since the code is 100 | # intended to run in multiple environments; otherwise, check them in: 101 | # .python-version 102 | 103 | # pipenv 104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 107 | # install all needed dependencies. 108 | #Pipfile.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
119 | #pdm.lock 120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 121 | # in version control. 122 | # https://pdm.fming.dev/#use-with-ide 123 | .pdm.toml 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # PyCharm 169 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 170 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 171 | # and can be added to the global gitignore or merged into this file. For a more nuclear 172 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
173 | #.idea/ -------------------------------------------------------------------------------- /00_mdof_osa/mdof_solutions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from math import pi 4 | import scipy 5 | from typing import Union, Tuple 6 | 7 | Tensor = Union[torch.Tensor, np.ndarray] 8 | TensorFloat = Union[torch.Tensor, float] 9 | 10 | class nonlinearity: 11 | 12 | def __init__(self, dofs, gk_exp = None, gc_exp = None): 13 | 14 | self.dofs = dofs 15 | self.gk_exp = gk_exp 16 | self.gc_exp = gc_exp 17 | 18 | def Kn_func(self, kn_): 19 | 20 | Kn = torch.diag(kn_) - torch.diag(kn_[1:], 1) 21 | return Kn 22 | 23 | def gk_func(self, x, xdot): 24 | if self.gk_exp is not None: 25 | return torch.sign(x) * torch.abs(x) ** self.gk_exp 26 | else: 27 | return torch.zeros_like(x) 28 | 29 | def Cn_func(self, cn_): 30 | 31 | Cn = torch.diag(cn_) - torch.diag(cn_[1:], 1) 32 | return Cn 33 | 34 | def gc_func(self, x, xdot): 35 | if type(self.gc_exp) == float: 36 | return torch.sign(xdot) * torch.abs(xdot) ** self.gc_exp 37 | elif self.gc_exp == 'vdp': 38 | return (x**2 - 1) * xdot 39 | else: 40 | return torch.zeros_like(xdot) 41 | 42 | def mat_func(self, kn_, cn_, invM): 43 | 44 | Kn = self.Kn_func(kn_) 45 | Cn = self.Cn_func(cn_) 46 | 47 | return torch.cat(( 48 | torch.zeros((self.dofs, 2*self.dofs)), 49 | torch.cat((-invM @ Kn, -invM @ Cn), dim=1) 50 | ), dim=0) 51 | 52 | def gz_func(self, z): 53 | x_ = z[:self.dofs, :] - torch.cat((torch.zeros((1, z.shape[1])), z[:self.dofs-1, :]), dim=0) 54 | xdot_ = z[self.dofs:, :] - torch.cat((torch.zeros((1, z.shape[1])), z[self.dofs:-1, :]), dim=0) 55 | return torch.cat((self.gk_func(x_, xdot_), self.gc_func(x_, xdot_)), dim=0) 56 | 57 | def gen_ndof_cantilever(m_: TensorFloat, c_: TensorFloat, k_: TensorFloat, ndof: int = None, return_numpy: bool = False, connected_damping: bool = True) -> Tuple[Tensor, Tensor, Tensor]: 58 | if torch.is_tensor(m_): 
59 | ndof = m_.shape[0] 60 | else: 61 | m_ = m_ * torch.ones((ndof)) 62 | c_ = c_ * torch.ones((ndof)) 63 | k_ = k_ * torch.ones((ndof)) 64 | M = torch.zeros((ndof,ndof), dtype=torch.float32) 65 | C = torch.zeros((ndof,ndof), dtype=torch.float32) 66 | K = torch.zeros((ndof,ndof), dtype=torch.float32) 67 | for i in range(ndof): 68 | M[i,i] = m_[i] 69 | for i in range(ndof-1): 70 | if connected_damping: 71 | C[i,i] = c_[i] + c_[i+1] 72 | C[i,i+1] = -c_[i+1] 73 | else: 74 | C[i,i] = c_[i] 75 | K[i,i] = k_[i] + k_[i+1] 76 | K[i,i+1] = -k_[i+1] 77 | C[-1,-1] = c_[-1] 78 | K[-1,-1] = k_[-1] 79 | C = torch.triu(C) + torch.triu(C, 1).T 80 | K = torch.triu(K) + torch.triu(K, 1).T 81 | if return_numpy: 82 | return M.numpy(), C.numpy(), K.numpy() 83 | else: 84 | return M, C, K 85 | 86 | def add_noise(x: np.ndarray, db: Tuple[float, None] = None, SNR: Tuple[float, None] = None, seed: int = 43810) -> np.ndarray: 87 | 88 | ns = x.shape[0] 89 | nd = x.shape[1] 90 | x_noisy = np.zeros_like(x) 91 | 92 | match [db, SNR]: 93 | case [float(), None]: 94 | noise_amp = 10.0 ** (db / 10.0) 95 | for i in range(nd): 96 | np.random.seed(seed + i) 97 | noise_x = np.random.normal(loc=0.0, scale=np.sqrt(noise_amp), size=ns) 98 | x_noisy[:,i] = x[:,i] + noise_x 99 | case [None, float()]: 100 | for i in range(nd): 101 | np.random.seed(seed + i) 102 | P_sig = 10 * np.log10(np.mean(x[:, i]**2)) 103 | P_noise = P_sig - SNR 104 | noise_amp = 10 ** (P_noise / 10.0) 105 | noise_x = np.random.normal(loc=0.0, scale=np.sqrt(noise_amp), size=ns) 106 | x_noisy[:,i] = x[:,i] + noise_x 107 | case [float(), float()]: 108 | raise Exception("Over specified, please select either db or SNR") 109 | case [None, None]: 110 | raise Exception("No noise level specified") 111 | return x_noisy 112 | 113 | 114 | -------------------------------------------------------------------------------- /00_mdof_osa/aux_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import numpy as np 3 | import copy 4 | from datetime import date 5 | def dropout(dropouts, *data_): 6 | data_dropped = [None] * len(data_) 7 | for i, data in enumerate(data_): 8 | data_dropped[i] = data.clone() 9 | for j in dropouts: 10 | data_dropped[i][:,j] = torch.zeros_like(data[:,j]) 11 | 12 | return data_dropped 13 | 14 | class test_parser: 15 | 16 | pass 17 | 18 | class sparse_recov_parser(test_parser): 19 | 20 | def __init__(self, sparsity_type, nonlin_type, error_type, snr, dofs, p_obs_drop): 21 | 22 | self.nonlin_type = nonlin_type 23 | 24 | self.sparsity_type = sparsity_type 25 | a = np.arange(dofs) 26 | match sparsity_type: 27 | case 'domain_interpolation': 28 | # self.dropouts = [1, 3] 29 | step = round(100/(100-p_obs_drop)) 30 | # self.dropouts = a[a%step==1].tolist() 31 | dropout_dropouts = a[1::step].tolist() 32 | self.dropouts = np.delete(a, dropout_dropouts).tolist() 33 | testid1 = 'inter' 34 | case 'domain_extension': 35 | # self.dropouts = [0, 1] 36 | self.dropouts = a[aint(dofs * (100 - p_obs_drop)/100)].tolist() 36 | testid1 = 'exten' 37 | case list(): 38 | self.dropouts = a[sparsity_type].tolist() 39 | p_obs_drop = 100 * len(self.dropouts) / dofs 40 | testid1 = 'custom' 41 | 42 | self.error_type = error_type 43 | match error_type: 44 | case 'No error': 45 | testid4 = 'nonerr' 46 | case 'Value error': 47 | testid4 = 'valerr' 48 | case 'Linear model': 49 | testid4 = 'linmod' 50 | case 'Force missing': 51 | testid4 = 'frcmia' 52 | 53 | testid2 = f'{int(p_obs_drop):d}dr' 54 | testid3 = f'{dofs:d}dof' 55 | 56 | # self.nonlin_type = nonlin_type 57 | # match self.nonlin_type: 58 | # case 'exponent_damping': 59 | # testid2 = 'exd' 60 | # case 'vanDerPol_damping': 61 | # testid2 = 'vdp' 62 | # case 'duffing_stiffness': 63 | # testid2 = 'duf' 64 | self.snr = snr 65 | testid5 = f'{int(snr):d}snr' 66 | 67 | self.test_id = f'sr_{testid1}_{testid2}_{testid3}_{testid4}_{testid5}' 68 | 69 | self.date_id = date.today().strftime("%Y%m%d") 70 | 71 | self.file_id = 
f'{self.date_id}__{self.test_id}' 72 | 73 | 74 | def prescribe_params(self, true_params, min_error = 0.1, max_error = 0.25): 75 | 76 | prescr_params = copy.deepcopy(true_params) 77 | gt_params = copy.deepcopy(true_params) 78 | 79 | if self.error_type == 'Value error': 80 | params_error = ['c_', 'cn_', 'k_', 'kn_'] 81 | for param_key in params_error: 82 | # randomly add or a subtract an error on the value between the min and max error 83 | for i in range(prescr_params[param_key].shape[0]): 84 | np.random.seed(43810+i) 85 | gt_params[param_key][i] += gt_params[param_key][i] * np.random.uniform(low = min_error, high = max_error) * [-1,1][np.random.randint(0, 2)] 86 | elif self.error_type == 'Linear model': 87 | prescr_params['kn_'] = np.zeros_like(true_params['kn_']) 88 | prescr_params['cn_'] = np.zeros_like(true_params['cn_']) 89 | 90 | return prescr_params, gt_params 91 | 92 | def pinn_param_dict(self, prescr_params): 93 | 94 | param_dict = {} 95 | n_dof = prescr_params['m_'].shape[0] 96 | for i in range(n_dof): 97 | for param in ['m_', 'c_', 'k_']: 98 | param_dict.update({f"{param}{i}" : {"type" : "constant", "value" : torch.tensor(prescr_params[param][i], dtype=torch.float32)}}) 99 | if self.error_type != 'Linear model': 100 | match self.nonlin_type: 101 | case 'vanDerPol_damping' | 'exponent_damping': 102 | for i in range(n_dof): 103 | param_dict.update({f"cn_{i}" : {"type" : "constant", "value" : torch.tensor(prescr_params['cn_'][i], dtype=torch.float32)}}) 104 | param_dict.update({f"kn_{i}" : {"type" : "constant", "value" : torch.tensor(prescr_params['kn_'][i], dtype=torch.float32)}}) 105 | case 'duffing_stiffness': 106 | for i in range(n_dof): 107 | param_dict.update({f"kn_{i}" : {"type" : "constant", "value" : torch.tensor(prescr_params['kn_'][i], dtype=torch.float32)}}) 108 | param_dict.update({f"cn_{i}" : {"type" : "constant", "value" : torch.tensor(prescr_params['cn_'][i], dtype=torch.float32)}}) 109 | else: 110 | param_dict.update({f"cn_{i}" : {"type" : 
"constant", "value" : torch.tensor(0.0, dtype=torch.float32)} for i in range(n_dof)}) 111 | param_dict.update({f"kn_{i}" : {"type" : "constant", "value" : torch.tensor(0.0, dtype=torch.float32)} for i in range(n_dof)}) 112 | 113 | return param_dict 114 | 115 | 116 | class param_est_parser(test_parser): 117 | 118 | def __init__(self, system_type, nonlin_type, n_dof, force_loc, snr, num_time_samps, num_repeats): 119 | 120 | self.system_type = system_type 121 | self.nonlin_type = nonlin_type 122 | self.n_dof = n_dof 123 | self.force_loc = force_loc 124 | self.snr = snr 125 | 126 | match system_type: 127 | case 'first_nonlin': 128 | testid1 = 'firstnln' 129 | case 'inter_nonlin': 130 | testid1 = 'internln' 131 | case 'fully_nonlin': 132 | if nonlin_type == 'vanDerPol_damping': 133 | testid1 = 'vandpd' 134 | else: 135 | testid1 = 'fullnln' 136 | 137 | testid2 = f'{n_dof:d}dof' 138 | 139 | if force_loc == -1: 140 | testid3 = 'fn' 141 | elif force_loc == 0: 142 | testid3 = 'f1' 143 | 144 | self.test_id = f'sr_{testid1}_{testid2}_{testid3}_snr{int(snr):d}' 145 | 146 | self.date_id = date.today().strftime("%Y%m%d") 147 | 148 | self.file_id = f'{self.date_id}__{self.test_id}' 149 | 150 | def prescribe_params(self, cn, kn, dofs): 151 | 152 | match self.system_type: 153 | case 'first_nonlin': 154 | cn_ = np.zeros((dofs)) 155 | cn_[0] = cn 156 | kn_ = np.zeros((dofs)) 157 | kn_[0] = kn 158 | case 'inter_nonlin': 159 | cn_ = np.zeros((dofs)) 160 | for i in range(0, cn_.shape[0], 2): 161 | cn_[i] = cn 162 | kn_ = np.zeros((dofs)) 163 | for i in range(0, kn_.shape[0], 2): 164 | kn_[i] = kn 165 | case 'fully_nonlin': 166 | cn_ = cn * np.ones((dofs)) 167 | kn_ = kn * np.ones((dofs)) 168 | return cn_, kn_ 169 | 170 | def pinn_param_dict(self, m_, c_, k_, cn_, kn_): 171 | 172 | param_dict = { 173 | "m_" : { 174 | "type" : "constant", 175 | "value" : torch.tensor(m_, dtype=torch.float32) 176 | }, 177 | "c_" : { 178 | "type" : "variable", 179 | "value" : torch.ones(c_.shape[0], 
dtype=torch.float32) 180 | }, 181 | "k_" : { 182 | "type" : "variable", 183 | "value" : torch.ones(k_.shape[0], dtype=torch.float32) 184 | }, 185 | } 186 | match self.system_type: 187 | case 'first_nonlin': 188 | param_dict['cn_'] = { 189 | 'type' : 'variable', 190 | 'value' : torch.tensor(0.0, dtype=torch.float32) 191 | } 192 | param_dict['kn_'] = { 193 | 'type' : 'variable', 194 | 'value' : torch.tensor(1.0, dtype=torch.float32) 195 | } 196 | case 'inter_nonlin' | 'fully_nonlin': 197 | if self.nonlin_type == 'duffing_stiffness': 198 | param_dict['cn_'] = { 199 | 'type' : 'variable', 200 | 'value' : torch.zeros(self.n_dof, dtype=torch.float32) 201 | } 202 | param_dict['kn_'] = { 203 | 'type' : 'variable', 204 | 'value' : torch.ones(self.n_dof, dtype=torch.float32) 205 | } 206 | elif self.nonlin_type in ('vanDerPol_damping', 'exponent_damping'): 207 | param_dict['cn_'] = { 208 | 'type' : 'variable', 209 | 'value' : torch.ones(self.n_dof, dtype=torch.float32) 210 | } 211 | param_dict['kn_'] = { 212 | 'type' : 'variable', 213 | 'value' : torch.zeros(self.n_dof, dtype=torch.float32) 214 | } 215 | 216 | return param_dict 217 | 218 | def pinn_explc_dict(self, m_, c_, k_, cn_, kn_): 219 | 220 | param_dict = { 221 | "m_" : { 222 | "type" : "constant", 223 | "value" : torch.tensor(m_, dtype=torch.float32) 224 | }, 225 | "c_" : { 226 | "type" : "constant", 227 | "value" : torch.tensor(c_, dtype=torch.float32) 228 | }, 229 | "k_" : { 230 | "type" : "constant", 231 | "value" : torch.tensor(k_, dtype=torch.float32) 232 | }, 233 | } 234 | match self.system_type: 235 | case 'first_nonlin': 236 | param_dict['cn_'] = { 237 | 'type' : 'variable', 238 | 'value' : torch.tensor(0.0, dtype=torch.float32) 239 | } 240 | param_dict['kn_'] = { 241 | 'type' : 'variable', 242 | 'value' : torch.tensor(1.0, dtype=torch.float32) 243 | } 244 | case 'inter_nonlin' | 'fully_nonlin': 245 | param_dict['cn_'] = { 246 | 'type' : 'constant', 247 | 'value' : torch.tensor(cn_, dtype=torch.float32) 248 | } 
249 | param_dict['kn_'] = { 250 | 'type' : 'constant', 251 | 'value' : torch.tensor(kn_, dtype=torch.float32) 252 | } 253 | 254 | return param_dict 255 | 256 | 257 | class state_param_parser(test_parser): 258 | 259 | def __init__(self, sparsity_type, nonlin_type, error_type, snr, dofs, p_obs_drop): 260 | 261 | self.nonlin_type = nonlin_type 262 | 263 | self.sparsity_type = sparsity_type 264 | a = np.arange(dofs) 265 | match sparsity_type: 266 | case 'domain_interpolation': 267 | if p_obs_drop == 40.0: 268 | self.dropouts = np.array([1, 3, 6, 8, 11, 13, 16, 18]) 269 | self.dropouts = np.delete(self.dropouts, np.argwhere(self.dropouts>dofs)).tolist() 270 | else: 271 | # self.dropouts = [1, 3] 272 | step = round(100/(100-p_obs_drop)) 273 | # self.dropouts = a[a%step==1].tolist() 274 | dropout_dropouts = a[1::step].tolist() 275 | self.dropouts = np.delete(a, dropout_dropouts).tolist() 276 | testid1 = 'inter' 277 | case 'domain_extension': 278 | # self.dropouts = [0, 1] 279 | self.dropouts = a[a np.ndarray: 370 | 371 | ns = x.shape[0] 372 | nd = x.shape[1] 373 | x_noisy = np.zeros_like(x) 374 | 375 | match [db, SNR]: 376 | case [float(), None]: 377 | noise_amp = 10.0 ** (db / 10.0) 378 | for i in range(nd): 379 | np.random.seed(seed + i) 380 | noise_x = np.random.normal(loc=0.0, scale=np.sqrt(noise_amp), size=ns) 381 | x_noisy[:,i] = x[:,i] + noise_x 382 | case [None, float()]: 383 | P_sig_ = 10 * np.log10(np.mean(np.mean(x**2, axis=1), axis=0)) 384 | P_noise_ = P_sig_ - SNR 385 | noise_amp_ = 10 ** (P_noise_ / 10.0) 386 | for i in range(nd): 387 | np.random.seed(seed + i) 388 | if np.mean(x[:, i]**2) == 0: 389 | # noise_x = np.random.normal(loc=0.0, scale=np.sqrt(noise_amp_), size=ns) 390 | noise_x = np.zeros(ns) 391 | else: 392 | P_sig = 10 * np.log10(np.mean(x[:, i]**2)) 393 | P_noise = P_sig - SNR 394 | noise_amp = 10 ** (P_noise / 10.0) 395 | noise_x = np.random.normal(loc=0.0, scale=np.sqrt(noise_amp), size=ns) 396 | x_noisy[:,i] = x[:,i] + noise_x 397 | case 
[float(), float()]: 398 | raise Exception("Over specified, please select either db or SNR") 399 | case [None, None]: 400 | raise Exception("No noise level specified") 401 | return x_noisy 402 | 403 | -------------------------------------------------------------------------------- /00_mdof_osa/mdof_osa_pinn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | import matplotlib.pyplot as plt 6 | 7 | from scipy.stats import qmc 8 | 9 | from tqdm import tqdm 10 | from tqdm.auto import tqdm as tqdma 11 | from IPython import display 12 | 13 | from typing import Tuple, Union 14 | Tensor = Union[torch.Tensor, np.ndarray] 15 | 16 | def max_mag_data(data: Tensor, axis: int = None) -> Tensor: 17 | """ 18 | Compute the maximum magnitude of data along the specified axis. 19 | """ 20 | if torch.is_tensor(data): 21 | if axis is None: 22 | data_max = torch.max(torch.max(torch.abs(data))) 23 | else: 24 | data_max = torch.max(torch.abs(data),dim=axis)[0] 25 | else: 26 | data_max = np.max(np.abs(data),axis=axis) 27 | return data_max 28 | 29 | def range_data(data: Tensor, axis: int = None) -> Tensor: 30 | """ 31 | Compute the range of data along the specified axis. 32 | """ 33 | if torch.is_tensor(data): 34 | if axis is None: 35 | data_range = torch.max(torch.max(data)) - torch.min(torch.min(data)) 36 | else: 37 | data_range = torch.max(data,dim=axis)[0] - torch.min(data,dim=axis)[0] 38 | else: 39 | data_range = np.max(data, axis=axis) - np.min(data, axis=axis) 40 | return data_range 41 | 42 | def normalise(data: Tensor, norm_type: str = "var", norm_dir: str = "all") -> Tuple[Tensor, Tuple[Tensor, Tensor]]: 43 | """ 44 | Normalize data based on the specified normalization type and direction. 
45 | """ 46 | if norm_type=="var": 47 | if len(data.shape)>1 and norm_dir=="axis": 48 | mean = data.mean(axis=0) 49 | std = data.std(axis=0) 50 | else: 51 | mean = data.mean() 52 | std = data.std() 53 | data_norm = (data-mean)/std 54 | return data_norm, (mean, std) 55 | elif norm_type=="range": 56 | if len(data.shape)>1 and norm_dir=="axis": 57 | dmax = range_data(data,axis=0) 58 | else: 59 | dmax = range_data(data) 60 | data_norm = data/dmax 61 | return data_norm, dmax 62 | elif norm_type=="max": 63 | if len(data.shape)>1 and norm_dir=="axis": 64 | dmax = max_mag_data(data,axis=0) 65 | else: 66 | dmax = max_mag_data(data) 67 | data_norm = data/dmax 68 | return data_norm, dmax 69 | 70 | def nonlin_state_transform(z: torch.Tensor) -> torch.Tensor: 71 | n_dof = int(z.shape[0]/2) 72 | return (z[:n_dof,:] - torch.cat((torch.zeros(1, z.shape[1]), z[:n_dof-1, :]), dim=0))**3 73 | 74 | 75 | class osa_pinn_mdof(nn.Module): 76 | 77 | def __init__(self, config): 78 | super().__init__() 79 | self.n_input = config["n_input"] 80 | self.n_output = config["n_output"] 81 | self.n_hidden = config["n_hidden"] 82 | self.n_layers = config["n_layers"] 83 | self.n_dof = config["n_dof"] 84 | self.activation = nn.Tanh 85 | self.device = config["device"] 86 | 87 | self.build_net() 88 | 89 | self.configure(config) 90 | 91 | def build_net(self): 92 | self.net = nn.Sequential( 93 | nn.Sequential(*[nn.Linear(self.n_input, self.n_hidden), self.activation()]), 94 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden, self.n_hidden), self.activation()]) for _ in range(self.n_layers-1)]), 95 | nn.Linear(self.n_hidden, self.n_output) 96 | ) 97 | return 0 98 | 99 | def build_ed_net(self): 100 | self.ed_net = nn.Sequential( 101 | nn.Sequential(*[nn.Linear(1, self.n_hidden), self.activation()]), 102 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden, self.n_hidden), self.activation()]) for _ in range(self.n_layers-1)]), 103 | nn.Linear(self.n_hidden, self.n_output) 104 | ) 105 | return 
self.ed_net 106 | 107 | def forward(self, x0, v0, t, f0=None): 108 | if f0 is None: 109 | x = torch.cat((x0, v0, t.view(-1,1)), dim=1) 110 | else: 111 | x = torch.cat((x0, v0, f0, t.view(-1,1)), dim=1) 112 | y = self.net(x) 113 | return y 114 | 115 | def configure(self, config): 116 | 117 | self.config = config 118 | 119 | self.nct = config["nct"] # number of time collocation points 120 | self.col_dt = config["col_dt"] # delta-t for time collocation points 121 | 122 | self.set_phys_params() 123 | self.set_norm_params() 124 | 125 | def set_phys_params(self): 126 | config = self.config 127 | self.param_attrs = {} 128 | for param_name, param_dict in config["phys_params"].items(): 129 | self.param_attrs[param_name] = param_dict["type"] 130 | if param_dict["type"] == "constant": 131 | setattr(self,param_name,param_dict["value"]) 132 | elif param_dict["type"] == "variable": 133 | self.register_parameter(param_name, nn.Parameter(torch.ones(self.n_dof))) 134 | if hasattr(self,"M") and hasattr(self,"C") and hasattr(self,"K"): 135 | self.A = torch.cat(( 136 | torch.cat((torch.zeros((self.n_dof,self.n_dof)),torch.eye(self.n_dof)), dim=1), 137 | torch.cat((-torch.linalg.inv(self.M)@self.K, -torch.linalg.inv(self.M)@self.C), dim=1) 138 | ), dim=0) 139 | elif hasattr(self,"M"): 140 | self.m_ = torch.diag(self.M) # takes diagonal from mass matrix if set as constant 141 | 142 | if hasattr(self,"M"): 143 | self.H = torch.cat((torch.zeros((self.n_dof,self.n_dof)),torch.linalg.inv(self.M)), dim=0) 144 | if hasattr(self,"Kn") and config["nonlinearity"]=="cubic": 145 | self.An = torch.cat(( 146 | torch.zeros((self.n_dof,self.n_dof)), 147 | -torch.linalg.inv(self.M)@self.Kn 148 | ), dim=0) 149 | 150 | def set_norm_params(self): 151 | config = self.config 152 | self.alpha_t = config["alphas"]["t"] 153 | self.alpha_x = config["alphas"]["x"] 154 | self.alpha_v = config["alphas"]["v"] 155 | self.alpha_z = torch.cat((self.alpha_x*torch.ones(self.n_dof,1), 
self.alpha_v*torch.ones(self.n_dof,1)), dim=0).float().to(self.device) 156 | self.alpha_f = config["alphas"]["f"] 157 | for param_name, param_dict in config["phys_params"].items(): 158 | if param_dict["type"] == "variable": 159 | setattr(self,"alpha_"+param_name[:-1], config["alphas"][param_name[:-1]]) 160 | 161 | def set_aux_funcs(self, nonlin_func): 162 | self.nonlin_func = nonlin_func 163 | 164 | def set_switches(self, lambdas: dict) -> None: 165 | switches = {} 166 | for key, value in lambdas.items(): 167 | switches[key] = value>0.0 168 | self.switches = switches 169 | 170 | def set_colls_and_obs(self, t_data, x_data, v_data, f_data=None): 171 | 172 | # _data -> [samples, dof] 173 | n_obs = x_data.shape[0]-1 174 | 175 | # Observation set (uses the state one data point ahead as the target) 176 | self.x_obs = x_data[:-1,:] # initial displacement input 177 | self.v_obs = v_data[:-1,:] # initial velocity input 178 | self.t_obs = torch.zeros((n_obs,1)) 179 | for i in range(n_obs): 180 | self.t_obs[i] = t_data[i+1] - t_data[i] # time at end of horizon (window) 181 | if f_data is not None: 182 | self.f_obs = f_data[:-1,:] # force input 183 | self.z_obs = torch.cat((x_data[1:,:], v_data[1:,:]), dim=1).requires_grad_() # state (displacement and velocity) at end of window (output) 184 | 185 | # Collocation set (sets a copy of the x0, v0 for a vector of time over the time horizon) 186 | x_col = torch.zeros((n_obs*self.nct,self.n_dof)) 187 | v_col = torch.zeros((n_obs*self.nct,self.n_dof)) 188 | t_col = torch.zeros((n_obs*self.nct,1)) 189 | f_col = torch.zeros((n_obs*self.nct,self.n_dof)) 190 | t_pred = torch.zeros((n_obs*self.nct,1)) 191 | 192 | for i in range(n_obs): 193 | for j in range(self.n_dof): 194 | x_col[self.nct*i:self.nct*(i+1),j] = x_data[i,j].item()*torch.ones(self.nct) 195 | v_col[self.nct*i:self.nct*(i+1),j] = v_data[i,j].item()*torch.ones(self.nct) 196 | if f_data is not None: 197 | f_col[self.nct*i:self.nct*(i+1),j] = f_data[i,j].item()*torch.ones(self.nct) 198 | 
t_col[self.nct*i:self.nct*(i+1),0] = torch.linspace(0, t_data[i+1].item()-t_data[i].item(), self.nct) 199 | 200 | # absolute time of each predicted output: the current time in the data plus the offset within the window 201 | t_pred[self.nct*i:self.nct*(i+1),0] = t_data[i] + torch.linspace(0, t_data[i+1].item()-t_data[i].item(), self.nct) 202 | 203 | self.x_col = x_col.requires_grad_() 204 | self.v_col = v_col.requires_grad_() 205 | self.t_col = t_col.requires_grad_() 206 | if f_data is not None: 207 | self.f_col = f_col.requires_grad_() 208 | 209 | self.ic_ids = torch.argwhere(t_col[:,0]==torch.tensor(0.0)).view(-1) 210 | 211 | return t_pred 212 | 213 | def loss_func(self, obs_data: torch.Tensor, col_data: torch.Tensor, lambdas: dict, ic_ids: np.ndarray | None = None) -> Tuple[torch.Tensor, list]: 214 | 215 | z_obs = obs_data[:, :2*self.n_dof] 216 | x0_obs = obs_data[:, 2*self.n_dof : 3*self.n_dof] 217 | v0_obs = obs_data[:, 3*self.n_dof : 4*self.n_dof] 218 | f0_obs = obs_data[:, 4*self.n_dof : 5*self.n_dof] 219 | t_obs = obs_data[:, -1] 220 | 221 | if self.switches['obs']: 222 | # generate prediction at observation points 223 | if f0_obs is None: 224 | zh_obs_hat = self.forward(x0_obs, v0_obs, t_obs) 225 | else: 226 | zh_obs_hat = self.forward(x0_obs, v0_obs, t_obs, f0_obs) 227 | R_obs = torch.sqrt(torch.sum((zh_obs_hat - z_obs)**2, dim=1)) 228 | 229 | x0_col = col_data[..., : self.n_dof].reshape(-1, self.n_dof) 230 | v0_col = col_data[..., self.n_dof : 2*self.n_dof].reshape(-1, self.n_dof) 231 | f0_col = col_data[..., 2*self.n_dof : 3*self.n_dof].reshape(-1, self.n_dof) 232 | t_col = col_data[..., -1].reshape(-1, 1) 233 | 234 | if self.switches['ode']: 235 | # generate prediction over prediction horizon (collocation domain) 236 | if f0_col is None: 237 | zp_col_hat = self.forward(x0_col, v0_col, t_col) 238 | else: 239 | zp_col_hat = self.forward(x0_col, v0_col, t_col, f0_col) 240 | 241 | # retrieve derivatives 242 | dzdt = 
torch.zeros_like(zp_col_hat) 243 | for i in range(zp_col_hat.shape[1]): 244 | dzdt[:, i] = torch.autograd.grad(zp_col_hat[:, i], t_col, torch.ones_like(zp_col_hat[:, i]), create_graph=True)[0][:,0] # ∂_t-hat N_z-hat 245 | 246 | # retrieve physical parameters 247 | if hasattr(self,"A"): 248 | M, C, K = self.M, self.C, self.K 249 | A = self.A 250 | else: 251 | params = {} 252 | for param_name, param_dict in self.config["phys_params"].items(): 253 | if param_dict["type"] == "constant": 254 | params[param_name] = param_dict["value"] 255 | else: 256 | params[param_name] = getattr(self,param_name)*getattr(self,"alpha_"+param_name[:-1]) 257 | M, C, K = self.param_func(params["m_"],params["c_"],params["k_"]) 258 | invM = torch.diag(1/torch.diag(M)) 259 | A = torch.cat(( 260 | torch.cat((torch.zeros((self.n_dof, self.n_dof)), torch.eye(self.n_dof)), dim=1), 261 | torch.cat((-invM @ K, -invM @ C), dim=1) 262 | ), dim=0).requires_grad_() 263 | if f0_col is not None: 264 | if hasattr(self,"H"): 265 | H = self.H 266 | else: 267 | H = torch.cat((torch.zeros((self.n_dof, self.n_dof)), invM), dim=0) 268 | 269 | if self.nonlinearity is not None: 270 | An = self.nonlinearity.mat_func(params['kn_'], params['cn_'], invM) 271 | 272 | # calculate ode residual 273 | match [self.nonlinearity, f0_col]: 274 | case [None, None]: 275 | R_ = (self.alpha_z / self.alpha_t) * dzdt.T - A @ (self.alpha_z * zp_col_hat.T) 276 | R_ode = R_[self.n_dof:, :].T 277 | case [_, None]: 278 | gz = self.nonlinearity.gz_func(self.alpha_z*zp_col_hat.T) 279 | R_ = (self.alpha_z / self.alpha_t) * dzdt.T - A @ (self.alpha_z * zp_col_hat.T) - An @ gz 280 | R_ode = R_[self.n_dof:, :].T 281 | case [None, torch.Tensor()]: 282 | R_ = (self.alpha_z / self.alpha_t) * dzdt.T - A @ (self.alpha_z * zp_col_hat.T) - H @ (self.alpha_f * f0_col.T) 283 | R_ode = R_[self.n_dof:, :].T 284 | case [_, torch.Tensor()]: 285 | gz = self.nonlinearity.gz_func(self.alpha_z * zp_col_hat.T) 286 | R_ = (self.alpha_z / self.alpha_t) * dzdt.T - A @ 
(self.alpha_z * zp_col_hat.T) - An @ gz - H @ (self.alpha_f * f0_col.T) 287 | R_ode = R_[self.n_dof:, :].T 288 | 289 | if self.switches['cc']: 290 | # continuity condition residual 291 | R_cc = R_[:self.n_dof,:].T 292 | else: 293 | R_cc = torch.zeros((2, 2)) 294 | 295 | if self.switches['ic']: 296 | if ic_ids is None: 297 | raise Exception("Initial condition switch is on but no indices were given") 298 | else: 299 | # initial condition residual (initial state is [x0, v0], matching the layout of z_obs) 300 | R_ic = self.alpha_z * torch.cat((x0_col, v0_col), dim=1)[ic_ids, :] - self.alpha_z * zp_col_hat[ic_ids, :] 301 | else: 302 | R_ic = torch.zeros((2, 2)) 303 | 304 | L_obs = lambdas['obs'] * torch.mean(R_obs**2) 305 | L_ic = lambdas['ic'] * torch.sum(torch.mean(R_ic**2, dim=0), dim=0) 306 | L_cc = lambdas['cc'] * torch.sum(torch.mean(R_cc**2, dim=0), dim=0) 307 | L_ode = lambdas['ode'] * torch.sum(torch.mean(R_ode**2, dim=0), dim=0) 308 | 309 | loss = L_obs + L_ic + L_cc + L_ode 310 | 311 | if math.isnan(loss): 312 | raise Exception("Loss is NaN") 313 | 314 | return loss, [L_obs, L_ic, L_cc, L_ode] 315 | 316 | def predict(self, pred_data): 317 | z0_pred = pred_data['z0_hat'] 318 | t_pred = pred_data["t_hat"] 319 | f_pred = pred_data["f_hat"] 320 | 321 | if f_pred is None: 322 | xp = self.forward(z0_pred, t_pred) 323 | else: 324 | xp = self.forward(z0_pred, t_pred, f_pred) 325 | return xp 326 | 327 | # if self.param_discovery: 328 | # xp_ed = self.ed_net(self.t_ed_col) 329 | # return xp, xp_ed, self.t_ed_col 330 | # else: 331 | # return xp 332 | 333 | 334 | class osa_mdof_dataset(torch.utils.data.Dataset): 335 | 336 | def __init__(self, t_data, x_data, v_data, f_data = None, data_config = None, device = torch.device("cpu")): 337 | 338 | n_dof = x_data.shape[1] 339 | if isinstance(data_config, dict): 340 | self.subsample = data_config['nct'] # number to subsample 341 | self.nct = data_config['nct'] # number of collocation points 342 | else: 343 | self.subsample = 8 344 | self.nct = 4 345 | nct = self.nct 346 | if x_data.shape[1] != v_data.shape[1]: 347 | 
raise Exception("Dimension mismatch for data, please check DOFs dimension of data") 348 | 349 | # normalise data based on range 350 | t_data, alpha_t = normalise(t_data, "range") 351 | x_data, alpha_x = normalise(x_data, "range", "all") 352 | v_data, alpha_v = normalise(v_data, "range", "all") 353 | if f_data is not None: 354 | f_data, alpha_f = normalise(f_data, "range", "all") 355 | 356 | # create dataset 357 | nt = t_data.shape[0] 358 | sobol_sampler = qmc.Sobol(d=1, seed=43810) 359 | samples = sobol_sampler.random_base2(m=int(np.log2(nt/32))) 360 | # Scale the Sobol samples from [0, 1) to integer indices from 0 to nt-1. 361 | sub_ind = np.sort((samples * nt).astype(int), axis=0).squeeze() 362 | # sub_ind = np.sort(qmc.Sobol(d=1, seed=43810).integers(l_bounds=nt, n=int(nt/self.subsample)), axis=0).squeeze() 363 | t_data_sub = t_data[sub_ind] 364 | x_data_sub = x_data[sub_ind, :] 365 | v_data_sub = v_data[sub_ind, :] 366 | f_data_sub = f_data[sub_ind, :] if f_data is not None else None 367 | n_obs = x_data_sub.shape[0] - 1 368 | 369 | # observation set (uses state one data point ahead) 370 | x0_obs = x_data_sub[:-1, :] # initial displacement input 371 | v0_obs = v_data_sub[:-1, :] # initial velocity input 372 | t_obs = t_data_sub[:-1, :] # time location in signal for observation (for plotting) 373 | dt_obs = torch.zeros((n_obs, 1)) # delta t for input to network 374 | for i in range(n_obs): 375 | dt_obs[i] = t_data_sub[i+1] - t_data_sub[i] 376 | if f_data is not None: 377 | f0_obs = f_data_sub[:-1, :] # force input 378 | z_obs = torch.cat((x_data_sub[1:,:], v_data_sub[1:,:]), dim=1).requires_grad_() # state (displacement and velocity) at end of window (output) 379 | 380 | # collocation set (sets a copy of x0, v0 for a vector of time over the time horizon) 381 | x0_col = torch.zeros((n_obs, nct, n_dof)) 382 | v0_col = torch.zeros((n_obs, nct, n_dof)) 383 | dt_col = torch.zeros((n_obs, nct, 1)) 384 | t_col = torch.zeros((n_obs, nct, 1)) 385 | if f_data is not None: 386 | f0_col = 
torch.zeros((n_obs, nct, n_dof)) 387 | 388 | for i in range(n_obs): 389 | x0_col[i, :, :] = x_data_sub[i, :]*torch.ones((nct, n_dof)) 390 | v0_col[i, :, :] = v_data_sub[i, :]*torch.ones((nct, n_dof)) 391 | dt_col[i, :, 0] = torch.linspace(0, t_data_sub[i+1].item() - t_data_sub[i].item(), nct) 392 | if f_data is not None: 393 | f0_col[i, :, :] = f_data_sub[i, :]*torch.ones((nct, n_dof)) 394 | 395 | # generates a vector of the time for the predicted output, by simply adding the total window onto the current time in the data 396 | t_col[i, :, 0] = t_data_sub[i] + torch.linspace(0, t_data_sub[i+1].item() - t_data_sub[i].item(), nct) 397 | 398 | if f_data is not None: 399 | # concatenate into one large dataset 400 | data = torch.cat((x_data, v_data, f_data, t_data), dim=1) 401 | obs_data = torch.cat((z_obs, x0_obs, v0_obs, f0_obs, dt_obs, t_obs), dim=1) 402 | col_data = torch.cat((x0_col, v0_col, f0_col, dt_col, t_col), dim=2) 403 | self.alphas = { 404 | "x" : alpha_x, 405 | "v" : alpha_v, 406 | "f" : alpha_f, 407 | "t" : alpha_t 408 | } 409 | else: 410 | # concatenate into one large dataset 411 | data = torch.cat((x_data, v_data, t_data), dim=1) 412 | obs_data = torch.cat((z_obs, t_obs, x0_obs, v0_obs), dim=1) 413 | col_data = torch.cat((t_col, x0_col, v0_col), dim=2) 414 | self.alphas = { 415 | "x" : alpha_x, 416 | "v" : alpha_v, 417 | "t" : alpha_t 418 | } 419 | 420 | self.ground_truth = data.to(device) 421 | self.obs_data = obs_data.to(device) 422 | self.col_data = col_data.to(device) 423 | 424 | def __getitem__(self, index: int) -> np.ndarray: 425 | return self.obs_data[index, ...], self.col_data[index, ...] 
426 | 427 | def get_original(self, index: int) -> np.ndarray: 428 | return self.ground_truth[index] 429 | 430 | def __len__(self) -> int: 431 | return self.obs_data.shape[0] 432 | 433 | def __repr__(self) -> str: 434 | return self.__class__.__name__ 435 | 436 | 437 | 438 | class bbnn(nn.Module): 439 | 440 | def __init__(self, N_INPUT, N_OUTPUT, N_HIDDEN, N_LAYERS): 441 | super().__init__() 442 | self.n_input = N_INPUT 443 | self.n_output = N_OUTPUT 444 | self.n_hidden = N_HIDDEN 445 | self.n_layers = N_LAYERS 446 | self.activation = nn.Tanh 447 | 448 | self.build_net() 449 | 450 | def build_net(self): 451 | self.net = nn.Sequential( 452 | nn.Sequential(*[nn.Linear(self.n_input, self.n_hidden), self.activation()]), 453 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden, self.n_hidden), self.activation()]) for _ in range(self.n_layers-1)]), 454 | nn.Linear(self.n_hidden, self.n_output) 455 | ) 456 | return self.net 457 | 458 | def forward(self, x): 459 | x = self.net(x) 460 | return x 461 | 462 | def predict(self, tp): 463 | yp = self.forward(tp) 464 | return yp 465 | 466 | def loss_func(self, x_obs, y_obs): 467 | yp_obs = self.forward(x_obs) 468 | loss = torch.mean((yp_obs - y_obs)**2) 469 | return loss 470 | 471 | class ParamClipper(object): 472 | 473 | def __init__(self, frequency=5): 474 | self.frequency = frequency 475 | 476 | def __call__(self, module): 477 | if hasattr(module, 'phys_params'): 478 | params = module.phys_params.data 479 | params = params.clamp(0,1) 480 | module.phys_params.data = params 481 | -------------------------------------------------------------------------------- /_mdofPINN/pinnUtils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Union, Tuple 4 | 5 | import matplotlib.pyplot as plt 6 | from IPython import display 7 | 8 | from tqdm import tqdm 9 | from tqdm.auto import tqdm as tqdma 10 | 11 | Tensor = Union[torch.Tensor, 
np.ndarray] 12 | TensorFloat = Union[torch.Tensor, float] 13 | 14 | class ParamClipper(object): 15 | 16 | def __init__(self, param_lims: Union[dict, None] = None): 17 | self.param_lims = param_lims 18 | 19 | def __call__(self, module): 20 | 21 | for i in range(module.n_dof): 22 | if hasattr(module, f'c_{i}'): 23 | params_c = getattr(module, f'c_{i}').data 24 | params_c = params_c.clamp(0, None) 25 | getattr(module, f'c_{i}').data = params_c 26 | if hasattr(module, f'k_{i}'): 27 | params_k = getattr(module, f'k_{i}').data 28 | params_k = params_k.clamp(0, None) 29 | getattr(module, f'k_{i}').data = params_k 30 | if hasattr(module, f'kn_{i}'): 31 | params_kn = getattr(module, f'kn_{i}').data 32 | params_kn = params_kn.clamp(0, None) 33 | getattr(module, f'kn_{i}').data = params_kn 34 | if hasattr(module, f'cn_{i}'): 35 | params_cn = getattr(module, f'cn_{i}').data 36 | params_cn = params_cn.clamp(0, None) 37 | getattr(module, f'cn_{i}').data = params_cn 38 | 39 | def dropout(dropouts, *data_): 40 | data_dropped = [None] * len(data_) 41 | for i, data in enumerate(data_): 42 | data_dropped[i] = data.clone() 43 | for j in dropouts: 44 | data_dropped[i][:,j] = torch.zeros_like(data[:,j]) 45 | 46 | return data_dropped 47 | 48 | def max_mag_data(data: Tensor, axis: Union[int, None] = None) -> Tensor: 49 | """ 50 | Compute the maximum absolute value of data, over all elements if axis is None, otherwise along the given axis. 51 | """ 52 | if torch.is_tensor(data): 53 | if axis is None: 54 | data_max = torch.max(torch.max(torch.abs(data))) 55 | else: 56 | data_max = torch.max(torch.abs(data),dim=axis)[0] 57 | else: 58 | data_max = np.max(np.abs(data),axis=axis) 59 | return data_max 60 | 61 | def range_data(data: Tensor, axis: Union[int, None] = None) -> Tensor: 62 | """ 63 | Compute the range (maximum minus minimum) of data, over all elements if axis is None, otherwise along the given axis. 
64 | """ 65 | if torch.is_tensor(data): 66 | if axis is None: 67 | data_range = torch.max(torch.max(data)) - torch.min(torch.min(data)) 68 | else: 69 | data_range = torch.max(data,dim=axis)[0] - torch.min(data,dim=axis)[0] 70 | else: 71 | data_range = np.max(data, axis=axis) - np.min(data, axis=axis) 72 | return data_range 73 | 74 | def normalise(data: Tensor, norm_type: str = "var", norm_dir: str = "all") -> Tuple[Tensor, Tuple[Tensor, Tensor]]: 75 | """ 76 | Normalize data based on the specified normalization type and direction. 77 | """ 78 | if norm_type=="var": 79 | if len(data.shape)>1 and norm_dir=="axis": 80 | mean = data.mean(axis=0) 81 | std = data.std(axis=0) 82 | else: 83 | mean = data.mean() 84 | std = data.std() 85 | data_norm = (data-mean)/std 86 | return data_norm, (mean, std) 87 | elif norm_type=="range": 88 | if len(data.shape)>1 and norm_dir=="axis": 89 | dmax = range_data(data,axis=0) 90 | else: 91 | dmax = range_data(data) 92 | data_norm = data/dmax 93 | return data_norm, dmax 94 | elif norm_type=="max": 95 | if len(data.shape)>1 and norm_dir=="axis": 96 | dmax = max_mag_data(data,axis=0) 97 | else: 98 | dmax = max_mag_data(data) 99 | data_norm = data/dmax 100 | return data_norm, dmax 101 | 102 | def nonlin_state_transform(z: torch.Tensor) -> torch.Tensor: 103 | n_dof = int(z.shape[0]/2) 104 | return (z[:n_dof,:] - torch.cat((torch.zeros(1, z.shape[1]), z[:n_dof-1, :]), dim=0))**3 105 | 106 | 107 | class mdof_pinn_plotter: 108 | 109 | def __init__(self, n_dof, n_cols, figsize=(18,16)): 110 | 111 | # if n_dof > n_cols: 112 | # sub_rows = n_dof // 3 + int((n_dof%3)!=0) 113 | # sub_cols = n_cols 114 | # else: 115 | # sub_rows = 1 116 | # sub_cols = n_dof 117 | 118 | # mosaic_key = [[None]] * (sub_rows * 2) 119 | # for j in range(sub_rows): 120 | # mosaic_key[2 * j] = [f'dsp_dof_{3*j+d:d}' for d in range(3)] 121 | # mosaic_key[2 * j + 1] = [f'vel_dof_{3*j+d:d}' for d in range(3)] 122 | # mosaic_key.extend([['loss_plot'] * 3] * 5) 123 | 124 | 
mosaic_key = [[f'dsp_dof_{d:d}', f'vel_dof_{d:d}', f'acc_dof_{d:d}', f'frc_dof_{d:d}'] for d in range(n_dof)] 125 | mosaic_key.extend([['loss_plot'] * 4] * 5) 126 | 127 | self.fig, self.axs = plt.subplot_mosaic( 128 | mosaic_key, 129 | figsize=figsize, 130 | facecolor='w' 131 | ) 132 | 133 | def plot_joint_loss_hist(self, ax, loss_hist, pinn_type='normal'): 134 | n_epoch = len(loss_hist) 135 | indices = np.arange(1,n_epoch+1) 136 | if n_epoch > 20000: 137 | step = int(np.floor(n_epoch/10000)) 138 | loss_hist = loss_hist[::step,:] 139 | indices = indices[::step] 140 | if pinn_type == 'normal': 141 | labels = ["L_obs", "L_occ", "L_cc", "L_ode", "L"] 142 | else: 143 | labels = ["L_obs", "L_f", "L_cc", "L_ode", "L"] 144 | colors = ["tab:blue", "tab:purple", "tab:red", "tab:green", "black"] 145 | ax.cla() 146 | for i in range(len(labels)): 147 | ax.plot(indices, loss_hist[:,i], color=colors[i], label=labels[i]) 148 | ax.set_yscale('log') 149 | ax.legend() 150 | 151 | def sort_data(self, vec2sort: np.ndarray, *data_: tuple[np.ndarray,...]): 152 | sort_ids = np.argsort(vec2sort) 153 | sorted_data_ = [None] * len(data_) 154 | for i, data in enumerate(data_): 155 | sorted_data_[i] = np.zeros_like(data) 156 | if len(data.shape) > 1: 157 | for j in range(data.shape[1]): 158 | sorted_data_[i][:,j] = data[sort_ids,j].squeeze() 159 | else: 160 | sorted_data_[i] = data[sort_ids] 161 | if len(data_) > 1: 162 | return tuple(sorted_data_), sort_ids 163 | else: 164 | return sorted_data_[0], sort_ids 165 | 166 | def plot_result(self, axs_m, ground_truth, obs_data, prediction, alphas, n_dof, eq_pred = False): 167 | for ax in axs_m: 168 | axs_m[ax].cla() 169 | xL = np.amax(ground_truth["t"]) 170 | for dof in range(n_dof): 171 | axs_m[f'dsp_dof_{dof:d}'].plot(ground_truth["t"], ground_truth["x_hat"][:,dof], color="grey", linewidth=0.5, alpha=0.5, label="Exact solution") 172 | axs_m[f'dsp_dof_{dof:d}'].plot(prediction["t_hat"] * alphas["t"].item(), 
prediction["x_hat"][:,dof]*alphas["x"].item(), color="tab:blue", linewidth=0.5, alpha=0.8, linestyle='--', label="Neural network prediction") 173 | yLx = np.amax(np.abs(ground_truth["x_hat"][:,dof])) 174 | axs_m[f'dsp_dof_{dof:d}'].set_xlim(-0.05*xL, 1.05*xL) 175 | axs_m[f'dsp_dof_{dof:d}'].set_ylim(-1.1*yLx, 1.1*yLx) 176 | 177 | axs_m[f'vel_dof_{dof:d}'].plot(ground_truth["t"], ground_truth["v_hat"][:,dof], color="grey", linewidth=0.5, alpha=0.5, label="Exact solution") 178 | axs_m[f'vel_dof_{dof:d}'].plot(prediction["t_hat"] * alphas["t"].item(), prediction["v_hat"][:,dof]*alphas["v"].item(), color="tab:red", linewidth=0.5, alpha=0.8, linestyle='--', label="Neural network prediction") 179 | yLv = np.amax(np.abs(ground_truth["v_hat"][:,dof])) 180 | axs_m[f'vel_dof_{dof:d}'].set_xlim(-0.05*xL, 1.05*xL) 181 | axs_m[f'vel_dof_{dof:d}'].set_ylim(-1.1*yLv, 1.1*yLv) 182 | 183 | axs_m[f'acc_dof_{dof:d}'].plot(obs_data["t_hat"] * alphas["t"].item(), obs_data["a_hat"][:, dof] * alphas["a"].item(), color="tab:olive", linewidth=0.5, alpha=0.8, label='Training Data') 184 | axs_m[f'acc_dof_{dof:d}'].plot(ground_truth["t"], ground_truth["a_hat"][:,dof], color="grey", linewidth=0.5, alpha=0.5, label="Exact solution") 185 | axs_m[f'acc_dof_{dof:d}'].plot(prediction["t_hat"] * alphas["t"].item(), prediction["a_hat"][:,dof]*alphas["a"].item(), color="tab:orange", linewidth=0.5, alpha=0.8, linestyle='--', label="Neural network prediction") 186 | yla = np.amax(np.abs(ground_truth["a_hat"][:,dof])) 187 | axs_m[f'acc_dof_{dof:d}'].set_xlim(-0.05*xL, 1.05*xL) 188 | axs_m[f'acc_dof_{dof:d}'].set_ylim(-1.1*yla, 1.1*yla) 189 | 190 | axs_m[f'frc_dof_{dof:d}'].plot(obs_data["t_hat"] * alphas["t"].item(), obs_data["f_hat"][:, dof] * alphas["f"].item(), color="tab:olive", linewidth=0.5, alpha=0.8, label='Training Data') 191 | axs_m[f'frc_dof_{dof:d}'].plot(ground_truth["t"], ground_truth["f_hat"][:,dof], color="grey", linewidth=0.5, alpha=0.5, label="Exact solution") 192 | 
axs_m[f'frc_dof_{dof:d}'].plot(prediction["t_hat"] * alphas["t"].item(), prediction["f_hat"][:,dof]*alphas["f"].item(), color="tab:purple", linewidth=0.5, alpha=0.8, linestyle='--', label="Neural network prediction") 193 | if eq_pred: 194 | axs_m[f'frc_dof_{dof:d}'].plot(prediction["t_hat"] * alphas["t"].item(), prediction["f_hat_eq"][:,dof]*alphas["f"].item(), color="tab:cyan", linewidth=0.5, alpha=0.8, linestyle='--', label="Equation prediction") 195 | ylf = np.amax(np.abs(ground_truth["f_hat"][:,dof])) 196 | axs_m[f'frc_dof_{dof:d}'].set_xlim(-0.05*xL, 1.05*xL) 197 | axs_m[f'frc_dof_{dof:d}'].set_ylim(-1.1*ylf, 1.1*ylf) 198 | 199 | def plot_train_update(self, ground_truth, obs_data, prediction, alphas, n_dof, loss_hist, eq_pred=False): 200 | 201 | self.plot_result(self.axs, ground_truth, obs_data, prediction, alphas, n_dof, eq_pred) 202 | self.plot_joint_loss_hist(self.axs['loss_plot'], np.array(loss_hist), pinn_type='spi' if eq_pred else 'normal') 203 | 204 | class mdof_pinn_trainer: 205 | 206 | def __init__(self, train_dataset, data_config, n_dof, device, train_loader, col_domain=True): 207 | 208 | self.train_dataset = train_dataset 209 | self.device = device 210 | self.data_config = data_config 211 | self.train_loader = train_loader 212 | self.n_dof = n_dof 213 | self.col_domain = col_domain 214 | 215 | self.num_obs_samps = len(train_dataset) * data_config['seq_len'] 216 | self.num_col_samps = len(train_dataset) * data_config['subsample'] * data_config['seq_len'] 217 | 218 | def print_params(self, mdof_model, gt_params) -> str: 219 | write_str = 'c : k : cn : kn : \n' 220 | for i in range(mdof_model.n_dof): 221 | for param in ['m_', 'c_', 'k_', 'kn_', 'cn_']: 222 | if mdof_model.config["phys_params"][f'{param}{i}']["type"] == 'variable': 223 | write_str += f'{param}{i}: {(getattr(mdof_model, f"{param}{i}").item())*(getattr(mdof_model, f"alpha_{param[:-1]}")):.3f} ' 224 | else: 225 | write_str += f'{param}{i}: {getattr(mdof_model, f"{param}{i}"):.3f} ' 226 | 
write_str += f'[{gt_params[param][i]:.2f}] ' 227 | write_str += '\n' 228 | return write_str 229 | 230 | def train(self, 231 | num_epochs, 232 | mdof_model, 233 | print_step, 234 | net_optimisers, 235 | plotter, 236 | ground_truth, 237 | pinn_config, 238 | param_optimiser = None, 239 | loss_hist = None, 240 | print_params = False, 241 | param_clipper = None, 242 | schedulers = None, 243 | profile = False, 244 | ): 245 | 246 | self.prediction = { 247 | "t_hat" : None, 248 | "x_hat" : None, 249 | "v_hat" : None, 250 | "a_hat" : None, 251 | "f_hat" : None 252 | } 253 | 254 | self.obs_data_dict = { 255 | "t_hat" : None, 256 | "a_hat" : None, 257 | "f_hat" : None 258 | } 259 | 260 | if 'acc_obs_method' in pinn_config.keys(): 261 | acc_obs_method = pinn_config['acc_obs_method'] 262 | else: 263 | acc_obs_method = 'obs_model' 264 | 265 | epoch = 0 266 | if loss_hist is None: 267 | self.loss_hist = [] 268 | else: 269 | self.loss_hist = loss_hist 270 | progress_bar = tqdm(total=num_epochs) 271 | 272 | if profile: 273 | profile.start() 274 | try: 275 | while epoch < num_epochs: 276 | 277 | write_string = '' 278 | write_string += f'Epoch {epoch:d}\n' 279 | phase_loss = 0. 
280 | losses = [0.0] * 4 281 | mdof_model.train() 282 | for i, (obs_data, col_data) in enumerate(self.train_loader): 283 | if profile: 284 | profile.step() 285 | 286 | ### parse data 287 | acc_obs = obs_data[..., :self.n_dof].float().to(self.device).requires_grad_() # [sample, sequence, dof] 288 | f_obs = obs_data[..., self.n_dof:2*self.n_dof].float().to(self.device).requires_grad_() # [sample, sequence, dof] 289 | time_obs_ = [obs_data[:, nq, -1].reshape(-1, 1).to(self.device).requires_grad_() for nq in range(self.data_config['seq_len'])] 290 | 291 | ### parse data into lists 292 | # acc_obs = [obs_data[:, nq, :self.n_dof].float().to(self.device).requires_grad_() for nq in range(self.data_config['seq_len'])] # [samples, n_dof] * seq_len 293 | # f_obs = [obs_data[:, nq, self.n_dof:2*self.n_dof].float().to(self.device).requires_grad_() for nq in range(self.data_config['seq_len'])] 294 | # time_obs_ = [obs_data[:, nq, -1].reshape(-1, 1).to(self.device).requires_grad_() for nq in range(self.data_config['seq_len'])] 295 | 296 | # unroll collocation data 297 | if self.col_domain: 298 | time_col_ = [col_data[:, :, nq, -1].reshape(-1, 1).to(self.device).requires_grad_() for nq in range(self.data_config['seq_len'])] 299 | force_col_ = [col_data[:, :, nq, self.n_dof:-1].reshape(-1, self.n_dof).float().to(self.device) for nq in range(self.data_config['seq_len'])] 300 | else: 301 | time_col_ = None 302 | force_col_ = None 303 | 304 | ### Calculate loss and backpropagate 305 | for optim in net_optimisers: 306 | optim.zero_grad() 307 | if param_optimiser is not None: 308 | param_optimiser.zero_grad() 309 | loss, losses_i, _ = mdof_model.loss_func(time_obs_, acc_obs, f_obs, time_col_, force_col_, pinn_config['lambds'], pinn_config['dropouts'], acc_obs_method) 310 | phase_loss += loss.item() 311 | losses = [losses[j] + loss_i for j, loss_i in enumerate(losses_i)] 312 | loss.backward() 313 | for optim in net_optimisers: 314 | optim.step() 315 | if param_optimiser is not None: 316 
| param_optimiser.step() 317 | if param_clipper is not None: 318 | mdof_model.apply(param_clipper) 319 | 320 | self.loss_hist.append([loss_it.item() for loss_it in losses] + [phase_loss]) 321 | write_string += f'\tLoss {loss:.4e}\n' 322 | if schedulers is not None: 323 | for scheduler in schedulers: 324 | scheduler.step() 325 | write_string += f'\tLearning rate: {schedulers[0].get_last_lr()}\n' 326 | 327 | if (epoch + 1) % print_step == 0: 328 | 329 | mdof_model.eval() 330 | 331 | t_obs = np.zeros((self.num_obs_samps, 1)) 332 | a_obs = np.zeros((self.num_obs_samps, self.n_dof)) 333 | f_obs = np.zeros((self.num_obs_samps, self.n_dof)) 334 | 335 | t_pred = np.zeros((self.num_col_samps, 1)) 336 | z_pred = np.zeros((self.num_col_samps, 2*self.n_dof)) 337 | a_pred = np.zeros((self.num_col_samps, self.n_dof)) 338 | f_pred = np.zeros((self.num_col_samps, self.n_dof)) 339 | if mdof_model.pinn_type == 'spi': 340 | f_pred_eq = np.zeros((self.num_col_samps, self.n_dof)) 341 | 342 | for i, (obs_data, col_data) in enumerate(self.train_loader): 343 | 344 | inpoint_o = i * self.data_config['batch_size'] * self.data_config['seq_len'] 345 | outpoint_o = (i+1) * self.data_config['batch_size'] * self.data_config['seq_len'] 346 | t_obs[inpoint_o:outpoint_o] = obs_data[..., -1].cpu().reshape(-1,1) 347 | a_obs[inpoint_o:outpoint_o] = obs_data[..., :self.n_dof].cpu().reshape(-1, self.n_dof) 348 | f_obs[inpoint_o:outpoint_o] = obs_data[..., self.n_dof:2*self.n_dof].cpu().reshape(-1, self.n_dof) 349 | 350 | inpoint_ = i * self.data_config['batch_size'] * self.data_config['subsample'] * self.data_config['seq_len'] 351 | outpoint_ = (i + 1) * self.data_config['batch_size'] * self.data_config['subsample'] * self.data_config['seq_len'] 352 | if self.col_domain: 353 | t_pred_list = [obs_data[:, nq, -1].reshape(-1, 1).to(self.device).requires_grad_() for nq in range(self.data_config['seq_len'])] 354 | else: 355 | t_pred_list = [col_data[:, :, nq, -1].reshape(-1, 
1).to(self.device).requires_grad_() for nq in range(self.data_config['seq_len'])] 356 | if mdof_model.pinn_type == 'spi': 357 | z_pred_, f_pred_, a_pred_, t_pred_, f_pred_eq_ = mdof_model.predict(t_pred_list) 358 | f_pred_eq[inpoint_:outpoint_, :] = f_pred_eq_.detach().cpu().reshape(-1, self.n_dof).numpy() 359 | else: 360 | z_pred_, f_pred_, a_pred_, t_pred_ = mdof_model.predict(t_pred_list) 361 | 362 | t_pred[inpoint_:outpoint_] = t_pred_.detach().cpu().numpy().reshape(-1,1) 363 | z_pred[inpoint_:outpoint_, :] = z_pred_.detach().cpu().reshape(-1, 2*self.n_dof).numpy() 364 | a_pred[inpoint_:outpoint_, :] = a_pred_.detach().cpu().reshape(-1, self.n_dof).numpy() 365 | f_pred[inpoint_:outpoint_, :] = f_pred_.detach().cpu().reshape(-1, self.n_dof).numpy() 366 | 367 | (a_obs, f_obs, t_obs), _ = plotter.sort_data(t_obs[:,0], a_obs, f_obs, t_obs) 368 | if mdof_model.pinn_type == 'spi': 369 | (z_pred, f_pred, a_pred, t_pred, f_pred_eq), _ = plotter.sort_data(t_pred[:,0], z_pred, f_pred, a_pred, t_pred, f_pred_eq) 370 | self.prediction["f_hat_eq"] = f_pred_eq 371 | else: 372 | (z_pred, f_pred, a_pred, t_pred), _ = plotter.sort_data(t_pred[:,0], z_pred, f_pred, a_pred, t_pred) 373 | 374 | self.prediction['t_hat'] = t_pred 375 | self.prediction["x_hat"] = z_pred[:, :self.n_dof] 376 | self.prediction["v_hat"] = z_pred[:, self.n_dof:] 377 | self.prediction["a_hat"] = a_pred 378 | self.prediction["f_hat"] = f_pred 379 | 380 | self.obs_data_dict['t_hat'] = t_obs 381 | self.obs_data_dict['a_hat'] = a_obs 382 | self.obs_data_dict['f_hat'] = f_obs 383 | 384 | if mdof_model.pinn_type == 'spi': 385 | plotter.plot_train_update(ground_truth, self.obs_data_dict, self.prediction, pinn_config['alphas'], self.n_dof, self.loss_hist, eq_pred=True) 386 | else: 387 | plotter.plot_train_update(ground_truth, self.obs_data_dict, self.prediction, pinn_config['alphas'], self.n_dof, self.loss_hist) 388 | 389 | display.clear_output(wait=True) 390 | display.display(plt.gcf()) 391 | if print_params: 392 
| write_string += self.print_params(mdof_model, ground_truth['params']) 393 | tqdma.write(write_string) 394 | epoch += 1 395 | progress_bar.update(1) 396 | except KeyboardInterrupt: 397 | progress_bar.close() 398 | 399 | if profile: 400 | profile.stop() 401 | 402 | display.clear_output() 403 | 404 | print(write_string) 405 | 406 | class mdof_stoch_pinn_plotter: 407 | 408 | def __init__(self, n_dof, n_cols, figsize=(18,16), plot_force=False): 409 | 410 | self.plot_force = plot_force 411 | 412 | if n_dof > n_cols: 413 | sub_rows = n_dof // n_cols + int((n_dof%n_cols)!=0) 414 | else: 415 | sub_rows = 1 416 | 417 | if plot_force: 418 | mosaic_key = [[None]] * (sub_rows * 3) 419 | for j in range(sub_rows): 420 | mosaic_key[3 * j] = [f'dsp_dof_{n_cols*j+d:d}' for d in range(n_cols)] 421 | mosaic_key[3 * j + 1] = [f'vel_dof_{n_cols*j+d:d}' for d in range(n_cols)] 422 | mosaic_key[3 * j + 2] = [f'frc_dof_{n_cols*j+d:d}' for d in range(n_cols)] 423 | else: 424 | mosaic_key = [[None]] * (sub_rows * 2) 425 | for j in range(sub_rows): 426 | mosaic_key[2 * j] = [f'dsp_dof_{n_cols*j+d:d}' for d in range(n_cols)] 427 | mosaic_key[2 * j + 1] = [f'vel_dof_{n_cols*j+d:d}' for d in range(n_cols)] 428 | mosaic_key.extend([['loss_plot'] * n_cols] * sub_rows) 429 | 430 | self.fig, self.axs = plt.subplot_mosaic( 431 | mosaic_key, 432 | figsize=figsize, 433 | facecolor='w' 434 | ) 435 | 436 | def plot_joint_loss_hist(self, ax, loss_hist): 437 | n_epoch = len(loss_hist) 438 | indices = np.arange(1,n_epoch+1) 439 | if n_epoch > 20000: 440 | step = int(np.floor(n_epoch/10000)) 441 | loss_hist = loss_hist[::step,:] 442 | indices = indices[::step] 443 | labels = ["L_obs", "L_cc", "L_ode", "L_nc", "L"] 444 | colors = ["tab:blue", "tab:red", "tab:green", "tab:purple", "black"] 445 | ax.cla() 446 | for i in range(len(labels)): 447 | ax.plot(indices, loss_hist[:,i], color=colors[i], label=labels[i]) 448 | ax.set_yscale('symlog') 449 | # if np.amin(loss_hist) < 1e-3: 450 | ax.set_ylim(-1e5, -1e3) 
451 | ax.legend() 452 | 453 | def sort_data(self, vec2sort: np.ndarray, *data_: tuple[np.ndarray,...]): 454 | sort_ids = np.argsort(vec2sort) 455 | sorted_data_ = [None] * len(data_) 456 | for i, data in enumerate(data_): 457 | sorted_data_[i] = np.zeros_like(data) 458 | if len(data.shape) > 1: 459 | for j in range(data.shape[1]): 460 | sorted_data_[i][:,j] = data[sort_ids,j].squeeze() 461 | else: 462 | sorted_data_[i] = data[sort_ids] 463 | if len(data_) > 1: 464 | return tuple(sorted_data_), sort_ids 465 | else: 466 | return sorted_data_[0], sort_ids 467 | 468 | def plot_result(self, axs_m, ground_truth, obs_data, prediction, alphas, n_dof): 469 | for ax in axs_m: 470 | axs_m[ax].cla() 471 | xL = np.amax(ground_truth["t"]) 472 | for dof in range(n_dof): 473 | # displacement 474 | axs_m[f'dsp_dof_{dof:d}'].plot(obs_data["t_hat"] * alphas["t"].item(), obs_data["x_hat"][:, dof] * alphas["x"].item(), color="tab:olive", linewidth=0.5, alpha=0.8, label='Training Data') 475 | axs_m[f'dsp_dof_{dof:d}'].plot(ground_truth["t"], ground_truth["x_hat"][:,dof], color="grey", linewidth=0.5, alpha=0.5, label="Exact solution") 476 | axs_m[f'dsp_dof_{dof:d}'].plot(prediction["t_hat"] * alphas["t"].item(), prediction["x_hat"][:,dof]*alphas["x"].item(), color="tab:blue", linewidth=0.5, alpha=0.8, linestyle='--', label="Neural network prediction") 477 | axs_m[f'dsp_dof_{dof:d}'].fill_between((prediction["t_hat"]*alphas["t"].item()).squeeze(), (prediction["x_hat"][:,dof]-2*prediction['sigma_x'])*alphas["x"].item(), (prediction["x_hat"][:,dof]+2*prediction['sigma_x'])*alphas["x"].item(), alpha=0.25, color="tab:blue", label=r"$2\sigma$ Range") 478 | yLx = np.amax(np.abs(ground_truth["x_hat"][:,dof])) 479 | axs_m[f'dsp_dof_{dof:d}'].set_xlim(-0.05*xL, 1.05*xL) 480 | axs_m[f'dsp_dof_{dof:d}'].set_ylim(-1.1*yLx, 1.1*yLx) 481 | 482 | # velocity 483 | axs_m[f'vel_dof_{dof:d}'].plot(obs_data["t_hat"] * alphas["t"].item(), obs_data["v_hat"][:, dof] * alphas["v"].item(), color="tab:olive", 
linewidth=0.5, alpha=0.8, label='Training Data') 484 | axs_m[f'vel_dof_{dof:d}'].plot(ground_truth["t"], ground_truth["v_hat"][:,dof], color="grey", linewidth=0.5, alpha=0.5, label="Exact solution") 485 | axs_m[f'vel_dof_{dof:d}'].plot(prediction["t_hat"] * alphas["t"].item(), prediction["v_hat"][:,dof]*alphas["v"].item(), color="tab:red", linewidth=0.5, alpha=0.8, linestyle='--', label="Neural network prediction") 486 | axs_m[f'vel_dof_{dof:d}'].fill_between((prediction["t_hat"]*alphas["t"].item()).squeeze(), (prediction["v_hat"][:,dof]-2*prediction['sigma_v'])*alphas["v"].item(), (prediction["v_hat"][:,dof]+2*prediction['sigma_v'])*alphas["v"].item(), alpha=0.25, color="tab:blue", label=r"$2\sigma$ Range") 487 | yLv = np.amax(np.abs(ground_truth["v_hat"][:,dof])) 488 | axs_m[f'vel_dof_{dof:d}'].set_xlim(-0.05*xL, 1.05*xL) 489 | axs_m[f'vel_dof_{dof:d}'].set_ylim(-1.1*yLv, 1.1*yLv) 490 | 491 | # force 492 | if self.plot_force: 493 | axs_m[f'frc_dof_{dof:d}'].plot( 494 | prediction["t_hat"] * alphas["t"].item(), 495 | obs_data["f_hat"][:, dof] * alphas["f"].item(), 496 | color="tab:olive", linewidth=0.5, alpha=0.8, label='Observation Data' 497 | ) 498 | axs_m[f'frc_dof_{dof:d}'].plot( 499 | ground_truth["t"], 500 | ground_truth["f_hat"][:,dof], 501 | color="grey", linewidth=0.5, alpha=0.5, label="Exact solution" 502 | ) 503 | axs_m[f'frc_dof_{dof:d}'].plot( 504 | prediction["t_hat"] * alphas["t"].item(), 505 | prediction["f_hat"][:,dof], 506 | color="tab:green", linewidth=0.5, alpha=0.8, linestyle='--', label="Prediction" 507 | ) 508 | # axs_m[f'frc_dof_{dof:d}'].fill_between( 509 | # (prediction["t_hat"] * alphas["t"].item()).squeeze(), 510 | # (prediction["f_hat"][:,dof]) - (2*prediction['sigma_f'])*alphas["f"].item(), 511 | # (prediction["f_hat"][:,dof]) + (2*prediction['sigma_f'])*alphas["f"].item(), 512 | # alpha=0.25, color="tab:blue", label=r"$2\sigma$ Range" 513 | # ) 514 | # yLf = np.amax(np.abs(ground_truth["f_hat"][:,dof])) 515 | yLf = 
np.amax(np.abs(prediction["f_hat"][:,dof])) 516 | axs_m[f'frc_dof_{dof:d}'].set_xlim(-0.05*xL, 1.05*xL) 517 | axs_m[f'frc_dof_{dof:d}'].set_ylim(-1.1*yLf, 1.1*yLf) 518 | 519 | 520 | def plot_train_update(self, ground_truth, obs_data, prediction, alphas, n_dof, loss_hist): 521 | 522 | self.plot_result(self.axs, ground_truth, obs_data, prediction, alphas, n_dof) 523 | self.plot_joint_loss_hist(self.axs['loss_plot'], np.array(loss_hist)) 524 | 525 | 526 | class mdof_stoch_pinn_trainer: 527 | 528 | def __init__(self, train_dataset, data_config, n_dof, device, train_loader): 529 | 530 | self.train_dataset = train_dataset 531 | self.device = device 532 | self.data_config = data_config 533 | self.train_loader = train_loader 534 | self.n_dof = n_dof 535 | 536 | self.num_obs_samps = len(train_dataset) * data_config['seq_len'] * data_config['num_repeats'] 537 | self.num_col_samps = len(train_dataset) * data_config['subsample'] * data_config['seq_len'] 538 | 539 | def train(self, num_epochs, mdof_model, print_step, optimisers, plotter, ground_truth, pinn_config, loss_hist=None, print_params=False): 540 | 541 | net_optimisers = optimisers['nets'] 542 | noise_optimiser = optimisers['noise'] 543 | 544 | t_obs = np.zeros((self.num_obs_samps, 1)) 545 | z_obs = np.zeros((self.num_obs_samps, 2*self.n_dof)) 546 | 547 | t_pred = np.zeros((self.num_col_samps, 1)) 548 | z_pred = np.zeros((self.num_col_samps, 2*self.n_dof)) 549 | f_pred = np.zeros((self.num_col_samps, self.n_dof)) 550 | 551 | self.prediction = { 552 | "t_hat" : None, 553 | "x_hat" : None, 554 | "v_hat" : None, 555 | "f_hat" : None 556 | } 557 | 558 | self.obs_data_dict = { 559 | "t_hat" : None, 560 | "x_hat" : None, 561 | "v_hat" : None 562 | } 563 | 564 | epoch = 0 565 | if loss_hist is None: 566 | self.loss_hist = [] 567 | else: 568 | self.loss_hist = loss_hist 569 | if 'progress_bar' in globals(): 570 | del progress_bar # noqa: F821 571 | display.clear_output() 572 | progress_bar = tqdm(total=num_epochs) 573 | 574 | 
try: 575 | while epoch < num_epochs: 576 | 577 | write_string = '' 578 | write_string += f'Epoch {epoch:d}\n' 579 | phase_loss = 0. 580 | losses = [0.0] * 4 581 | for i, (obs_data, col_data) in enumerate(self.train_loader): 582 | 583 | # parse observation domain data 584 | time_obs_ = [obs_data[:, :, nq, -1].reshape(-1, 1).requires_grad_() for nq in range(self.data_config['seq_len'])] 585 | state_obs = [obs_data[:, :, nq, :2*self.n_dof].reshape(-1, 2*self.n_dof).float().to(self.device) for nq in range(self.data_config['seq_len'])] 586 | state_obs = torch.cat([state_obs[nq].unsqueeze(1) for nq in range(self.data_config['seq_len'])], dim=1) 587 | 588 | # parse collocation domain data 589 | time_col_ = [col_data[:, :, nq, -1].reshape(-1, 1).requires_grad_() for nq in range(self.data_config['seq_len'])] 590 | force_col_ = [col_data[:, :, nq, 2*self.n_dof:-1].reshape(-1, self.n_dof).float().to(self.device) for nq in range(self.data_config['seq_len'])] 591 | 592 | # network_optimizer.zero_grad() 593 | for optim in net_optimisers: 594 | optim.zero_grad() 595 | noise_optimiser.zero_grad() 596 | loss, losses_i, _ = mdof_model.loss_func(time_obs_, state_obs, time_col_, force_col_, pinn_config['lambds'], pinn_config['dropouts']) 597 | phase_loss += loss.item() 598 | losses = [losses[j] + loss_i for j, loss_i in enumerate(losses_i)] 599 | loss.backward() 600 | # network_optimizer.step() 601 | for optim in net_optimisers: 602 | optim.step() 603 | noise_optimiser.step() 604 | 605 | self.loss_hist.append([loss_it.item() for loss_it in losses] + [phase_loss]) 606 | write_string += f'\tLoss {loss.item():.4e}\n' 607 | write_string += f'Obs loss: {losses[0].item():.4e}, CC loss: {losses[1].item():.4e}, Ode loss: {losses[2].item():.4e}, Nc loss: {losses[3].item():.4e}\n' 608 | 609 | if (epoch + 1) % print_step == 0: 610 | 611 | obs_step = self.data_config['batch_size'] * self.data_config['seq_len'] * self.data_config['num_repeats'] 612 | col_step = self.data_config['batch_size'] * 
self.data_config['subsample'] * self.data_config['seq_len'] 613 | 614 | for i, (obs_data, col_data) in enumerate(self.train_loader): 615 | 616 | inpoint_o = i * obs_step 617 | outpoint_o = (i+1) * obs_step 618 | t_obs[inpoint_o:outpoint_o] = obs_data[..., -1].cpu().reshape(-1,1) 619 | z_obs[inpoint_o:outpoint_o] = obs_data[..., :2*self.n_dof].cpu().reshape(-1, 2*self.n_dof) 620 | 621 | inpoint_ = i * col_step 622 | outpoint_ = (i + 1) * col_step 623 | t_col_ = [col_data[:, :, nq, -1].reshape(-1, 1).requires_grad_() for nq in range(self.data_config['seq_len'])] 624 | force_col_ = [col_data[:, :, nq, 2*self.n_dof:-1].reshape(-1, self.n_dof).float().to(self.device) for nq in range(self.data_config['seq_len'])] 625 | # z_pred_, f_pred_, t_pred_, _ = mdof_model.predict(t_col_) 626 | z_pred_, f_pred_, t_pred_, _, force_col = mdof_model.predict(t_col_, f_col=force_col_) 627 | 628 | t_pred[inpoint_:outpoint_] = t_pred_.detach().cpu().numpy().reshape(-1,1) 629 | z_pred[inpoint_:outpoint_, :] = z_pred_.detach().cpu().reshape(-1, 2*self.n_dof).numpy() 630 | f_pred[inpoint_:outpoint_, :] = f_pred_.detach().cpu().reshape(-1, self.n_dof).numpy() 631 | 632 | (z_obs, t_obs), _ = plotter.sort_data(t_obs[:,0], z_obs, t_obs) 633 | (z_pred, f_pred, t_pred), _ = plotter.sort_data(t_pred[:,0], z_pred, f_pred, t_pred) 634 | 635 | self.prediction['t_hat'] = t_pred 636 | self.prediction["x_hat"] = z_pred[:, :self.n_dof] 637 | self.prediction["v_hat"] = z_pred[:, self.n_dof:] 638 | self.prediction["f_hat"] = f_pred 639 | # self.prediction["sigma_z"] = mdof_model.sigma_z.detach().item() 640 | self.prediction["sigma_x"] = mdof_model.sigma_x.detach().item() 641 | self.prediction["sigma_v"] = mdof_model.sigma_v.detach().item() 642 | self.prediction["sigma_f"] = mdof_model.sigma_f.detach().item() 643 | 644 | self.obs_data_dict['t_hat'] = t_obs 645 | self.obs_data_dict['x_hat'] = z_obs[:, :self.n_dof] 646 | self.obs_data_dict['v_hat'] = z_obs[:, self.n_dof:] 647 | self.obs_data_dict['f_hat'] = 
force_col.detach() 648 | 649 | plotter.plot_train_update(ground_truth, self.obs_data_dict, self.prediction, pinn_config['alphas'], self.n_dof, self.loss_hist) 650 | 651 | display.clear_output(wait=True) 652 | display.display(plt.gcf()) 653 | 654 | # write_string += f'State noise: {self.prediction["sigma_z"]:.4e}\n' 655 | write_string += f'Displ noise: {self.prediction["sigma_x"]:.4e}\n' 656 | write_string += f'Veloc noise: {self.prediction["sigma_v"]:.4e}\n' 657 | write_string += f'Force noise: {self.prediction["sigma_f"]:.4e}\n' 658 | 659 | if print_params: 660 | write_string += 'c : k : cn : kn : \n' 661 | for j in range(self.n_dof): 662 | write_string += '%d : ' % (j+1) 663 | for param in ['c_','k_','cn_','kn_']: 664 | if pinn_config['phys_params'][param]['type']=='constant': 665 | if len(getattr(mdof_model, param).shape) == 0: 666 | write_string += '%.4f ' % getattr(mdof_model, param) 667 | else: 668 | write_string += '%.4f ' % getattr(mdof_model, param)[j] 669 | elif pinn_config['phys_params'][param]['type']=='variable': 670 | if len(getattr(mdof_model, param).shape) == 0: 671 | write_string += '%.4f ' % (getattr(mdof_model, param)*pinn_config['alphas'][param[:-1]]) 672 | else: 673 | write_string += '%.4f ' % (getattr(mdof_model, param)[j]*pinn_config['alphas'][param[:-1]]) 674 | write_string += '[%.4f] ' % ground_truth['params'][param][j] 675 | write_string += '\n' 676 | tqdma.write(write_string) 677 | epoch += 1 678 | progress_bar.update(1) 679 | except KeyboardInterrupt: 680 | progress_bar.close() 681 | 682 | display.clear_output() 683 | 684 | print(write_string) 685 | -------------------------------------------------------------------------------- /_mdofPINN/pinnModels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | from pinnUtils import normalise 6 | 7 | from typing import Tuple, Union, Optional 8 | Tensor = 
Union[torch.Tensor, np.ndarray] 9 | TensorFloat = Union[torch.Tensor, float] 10 | 11 | class nonlinearity: 12 | 13 | def __init__(self, dofs, gk_exp = None, gc_exp = None): 14 | 15 | self.dofs = dofs 16 | self.gk_exp = gk_exp 17 | self.gc_exp = gc_exp 18 | 19 | def Kn_func(self, kn_): 20 | 21 | Kn = torch.diag(kn_) - torch.diag(kn_[1:], 1) 22 | return Kn 23 | 24 | def gk_func(self, x, xdot): 25 | if self.gk_exp is not None: 26 | return torch.sign(x) * torch.abs(x) ** self.gk_exp 27 | else: 28 | return torch.zeros_like(x) 29 | 30 | def Cn_func(self, cn_): 31 | 32 | Cn = torch.diag(cn_) - torch.diag(cn_[1:], 1) 33 | return Cn 34 | 35 | def gc_func(self, x, xdot): 36 | if isinstance(self.gc_exp, float): 37 | return torch.sign(xdot) * torch.abs(xdot) ** self.gc_exp 38 | elif self.gc_exp == 'vdp': 39 | return (x**2 - 1) * xdot 40 | else: 41 | return torch.zeros_like(xdot) 42 | 43 | def mat_func(self, kn_, cn_, invM): 44 | 45 | Kn = self.Kn_func(kn_) 46 | Cn = self.Cn_func(cn_) 47 | 48 | return torch.cat(( 49 | torch.zeros((self.dofs, 2*self.dofs)), 50 | torch.cat((-invM @ Kn, -invM @ Cn), dim=1) 51 | ), dim=0) 52 | 53 | def gz_func(self, z): 54 | if len(z.shape) == 3: 55 | dofs = int(z.shape[1]/2) 56 | x_ = z[:, :dofs, :] - torch.cat((torch.zeros((z.shape[0], 1, z.shape[2])), z[:, :dofs-1, :]), dim=1) 57 | xdot_ = z[:, dofs:, :] - torch.cat((torch.zeros((z.shape[0], 1, z.shape[2])), z[:, dofs:-1, :]), dim=1) 58 | return torch.cat((self.gk_func(x_, xdot_), self.gc_func(x_, xdot_)), dim=1) 59 | else: 60 | dofs = int(z.shape[0]/2) 61 | x_ = z[:dofs, :] - torch.cat((torch.zeros((1, z.shape[1])), z[:dofs-1, :]), dim=0) 62 | xdot_ = z[dofs:, :] - torch.cat((torch.zeros((1, z.shape[1])), z[dofs:-1, :]), dim=0) 63 | return torch.cat((self.gk_func(x_, xdot_), self.gc_func(x_, xdot_)), dim=0) 64 | 65 | def gen_ndof_cantilever(m_: TensorFloat, c_: TensorFloat, k_: TensorFloat, ndof: int = None, return_numpy: bool = False, connected_damping: bool = True) -> Tuple[Tensor, Tensor, 
Tensor]: 66 | if torch.is_tensor(m_): 67 | ndof = m_.shape[0] 68 | else: 69 | m_ = m_ * torch.ones((ndof)) 70 | c_ = c_ * torch.ones((ndof)) 71 | k_ = k_ * torch.ones((ndof)) 72 | M = torch.zeros((ndof,ndof), dtype=torch.float32) 73 | C = torch.zeros((ndof,ndof), dtype=torch.float32) 74 | K = torch.zeros((ndof,ndof), dtype=torch.float32) 75 | for i in range(ndof): 76 | M[i,i] = m_[i] 77 | for i in range(ndof-1): 78 | if connected_damping: 79 | C[i,i] = c_[i] + c_[i+1] 80 | C[i,i+1] = -c_[i+1] 81 | else: 82 | C[i,i] = c_[i] 83 | K[i,i] = k_[i] + k_[i+1] 84 | K[i,i+1] = -k_[i+1] 85 | C[-1,-1] = c_[-1] 86 | K[-1,-1] = k_[-1] 87 | C = torch.triu(C) + torch.triu(C, 1).T 88 | K = torch.triu(K) + torch.triu(K, 1).T 89 | if return_numpy: 90 | return M.numpy(), C.numpy(), K.numpy() 91 | else: 92 | return M, C, K 93 | 94 | class mdof_pinn_model(nn.Module): 95 | 96 | def __init__(self, config: dict): 97 | super().__init__() 98 | self.n_input = config["n_input"] 99 | self.n_output = config["n_output"] 100 | self.n_hidden = config["n_hidden"] 101 | self.n_layers = config["n_layers"] 102 | self.seq_len = config["seq_len"] 103 | self.n_dof = config["n_dof"] 104 | if 'activation' in config.keys(): 105 | self.activation = getattr(nn, config["activation"]) 106 | else: 107 | # self.activation = nn.Tanh 108 | self.activation = nn.SiLU 109 | if 'net_split' in config.keys(): 110 | self.net_split = config["net_split"] 111 | else: 112 | self.net_split = False 113 | self.device = config["device"] 114 | 115 | self.build_nets() 116 | self.config = config 117 | 118 | self.pinn_type = 'normal' 119 | 120 | def gather_params(self): 121 | self.net_params_list = [] 122 | if self.net_split: 123 | for net in self.nets: 124 | for net_ in net: self.net_params_list.append(net_.parameters()) 125 | else: 126 | for net in self.nets: 127 | self.net_params_list.append(net.parameters()) 128 | 129 | def build_nets(self): 130 | if self.net_split: 131 | nets = [[None] * 2 for _ in range(self.seq_len)] # independent inner list per sequence step (avoids aliasing) 132 | for net_n in 
range(self.seq_len): 133 | nets[net_n][0] = nn.Sequential( 134 | nn.Sequential(*[nn.Linear(self.n_input, self.n_hidden), self.activation()]), 135 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden, self.n_hidden), self.activation()]) for _ in range(self.n_layers-1)]), 136 | nn.Linear(self.n_hidden, self.n_dof) 137 | ).to(self.device) 138 | nets[net_n][1] = nn.Sequential( 139 | nn.Sequential(*[nn.Linear(self.n_input, self.n_hidden), self.activation()]), 140 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden, self.n_hidden), self.activation()]) for _ in range(self.n_layers-1)]), 141 | nn.Linear(self.n_hidden, self.n_dof) 142 | ).to(self.device) 143 | self.nets = tuple(nets) 144 | self.network_parameters = [] 145 | for net in self.nets: 146 | for net_ in net: self.network_parameters += list(net_.parameters()) 147 | pass 148 | else: 149 | nets = [None] * self.seq_len 150 | for net_n in range(self.seq_len): 151 | nets[net_n] = nn.Sequential( 152 | nn.Sequential(*[nn.Linear(self.n_input, self.n_hidden), self.activation()]), 153 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden, self.n_hidden), self.activation()]) for _ in range(self.n_layers-1)]), 154 | nn.Linear(self.n_hidden, self.n_output) 155 | ).to(self.device) 156 | self.nets = tuple(nets) 157 | # self.nets = nn.ModuleList(self.nets) 158 | self.network_parameters = [] 159 | for net in self.nets: self.network_parameters += list(net.parameters()) 160 | 161 | def build_net(self) -> int: 162 | self.net = nn.Sequential( 163 | nn.Sequential(*[nn.Linear(self.n_input * self.seq_len, self.n_hidden * self.seq_len), self.activation()]), 164 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden * self.seq_len, self.n_hidden * self.seq_len), self.activation()]) for _ in range(self.n_layers-1)]), 165 | nn.Linear(self.n_hidden * self.seq_len, self.n_output * self.seq_len), 166 | nn.Unflatten(dim=1, unflattened_size = (self.seq_len, self.n_output)) 167 | ) 168 | return 0 169 | 170 | def forward(self, x: 
torch.Tensor) -> torch.Tensor: 171 | """ 172 | Forward pass through the neural networks. 173 | 174 | Args: 175 | x [(torch.Tensor)]: List of inputs to networks [samples x n_input] * seq_len 176 | 177 | Returns: 178 | tuple[torch.Tensor]: Network outputs [samples x n_output] * seq_len 179 | """ 180 | if self.net_split: 181 | y = [] 182 | for nq in range(self.seq_len): 183 | z1 = self.nets[nq][0](x[nq]) 184 | z2 = self.nets[nq][1](x[nq]) 185 | y.append(torch.cat((z1, z2), dim=1)) 186 | else: 187 | y = [torch.zeros((x[0].shape[0], self.n_output)).to(self.device)] * self.seq_len 188 | for nq in range(self.seq_len): 189 | y[nq] = self.nets[nq](x[nq]) 190 | return tuple(y) # [samples x n_output] * seq_len 191 | 192 | def forward_new(self, x: torch.Tensor) -> torch.Tensor: 193 | """ 194 | Forward pass through the neural network. 195 | 196 | Args: 197 | x (torch.Tensor): Input to network [seq_len x samples x n_input] 198 | 199 | Returns: 200 | torch.Tensor: Output tensor [seq_len x samples x n_output] 201 | """ 202 | y = torch.zeros((self.seq_len, x.shape[1], self.n_output)).to(self.device) 203 | for nq in range(self.seq_len): 204 | y[nq] = self.nets[nq](x[nq].unsqueeze(1)) 205 | return y # [seq_len x samples x n_output] 206 | 207 | def configure(self, param_func, nonlin_func) -> None: 208 | """ 209 | Configures the neural network using the stored config 210 | 211 | Args: 212 | param_func, nonlin_func: function building (M, C, K) from the parameter vectors; nonlinearity object (or None) 213 | """ 214 | 215 | self.param_func = param_func 216 | self.nonlinearity = nonlin_func 217 | 218 | self.set_phys_params() 219 | self.set_norm_params() 220 | 221 | self.gather_params() 222 | self.set_switches(self.config['lambds']) 223 | 224 | def set_phys_params(self) -> None: 225 | """ 226 | Sets the physical parameters of the model, adding them as either constants or parameters for optimisation 227 | """ 228 | config = self.config 229 | self.param_attrs = {} 230 | self.system_parameters = [] 231 | 232 | #TODO: implement sparse matrices to see if it speeds up computation 233 | 234 | if all (param["type"] == 'constant' for param in 
config["phys_params"].values()): 235 | params = {} 236 | for param_name, param_dict in self.config["phys_params"].items(): 237 | params[param_name] = param_dict["value"] 238 | m_vec, c_vec, k_vec, kn_vec, cn_vec = self.param_parser(params) 239 | M, C, K = self.param_func(m_vec, c_vec, k_vec) 240 | self.M = M 241 | self.C = C 242 | self.K = K 243 | invM = torch.diag(1/torch.diag(M)) 244 | # state matrices 245 | self.A = torch.cat(( 246 | torch.cat((torch.zeros((self.n_dof,self.n_dof)),torch.eye(self.n_dof)), dim=1), 247 | torch.cat((-invM @ K, -invM @ C), dim=1) 248 | ), dim=0) 249 | self.H = torch.cat((torch.zeros((self.n_dof, self.n_dof)), invM), dim=0) 250 | if self.nonlinearity is not None: 251 | self.An = self.nonlinearity.mat_func(kn_vec, cn_vec, invM) 252 | # observation matrices 253 | self.B = torch.cat((-invM @ K, -invM @ C), dim=1) 254 | self.D = invM 255 | if self.nonlinearity is not None: 256 | self.Bn = self.nonlinearity.mat_func(kn_vec, cn_vec, invM)[self.n_dof:, :] 257 | 258 | for param_name, param_dict in config["phys_params"].items(): 259 | self.param_attrs[param_name] = param_dict["type"] 260 | if param_dict["type"] == "constant": 261 | setattr(self,param_name,param_dict["value"]) 262 | elif param_dict["type"] == "variable": 263 | self.register_parameter(param_name, nn.Parameter(param_dict["value"])) 264 | self.system_parameters.append(getattr(self, param_name)) 265 | if hasattr(self, "M") and hasattr(self, "C") and hasattr(self, "K"): 266 | self.A = torch.cat(( 267 | torch.cat((torch.zeros((self.n_dof,self.n_dof)),torch.eye(self.n_dof)), dim=1), 268 | torch.cat((-torch.linalg.inv(self.M)@self.K, -torch.linalg.inv(self.M)@self.C), dim=1) 269 | ), dim=0) 270 | elif hasattr(self,"M"): 271 | self.m_ = torch.diag(self.M) # takes diagonal from mass matrix if set as constant 272 | 273 | obs_dropouts = config["dropouts"] 274 | obs_keep = [j for j in range(self.n_dof) if j not in obs_dropouts] 275 | self.Sa = torch.diag(torch.tensor(obs_keep)) 276 | 277 | 
def set_norm_params(self) -> None: 278 | """ 279 | Set normalisation parameters of the model 280 | """ 281 | config = self.config 282 | 283 | # signal value norms 284 | self.alpha_t = config["alphas"]["t"].clone().detach().to(self.device) 285 | self.alpha_x = config["alphas"]["x"].clone().detach().to(self.device) 286 | self.alpha_v = config["alphas"]["v"].clone().detach().to(self.device) 287 | self.alpha_z = torch.cat((self.alpha_x.item()*torch.ones(self.n_dof,1), self.alpha_v.item()*torch.ones(self.n_dof,1)), dim=0).float().to(self.device) 288 | self.alpha_a = config["alphas"]["a"].clone().detach().to(self.device) 289 | self.alpha_f = config["alphas"]["f"].clone().detach().to(self.device) 290 | 291 | # system parameter norms 292 | self.alpha_c = config["alphas"]["c"].clone().detach().to(self.device) 293 | self.alpha_k = config["alphas"]["k"].clone().detach().to(self.device) 294 | self.alpha_kn = config["alphas"]["kn"].clone().detach().to(self.device) 295 | self.alpha_cn = config["alphas"]["cn"].clone().detach().to(self.device) 296 | # for param_name, param_dict in config["phys_params"].items(): 297 | # if param_dict["type"] == "variable": 298 | # setattr(self,"alpha_"+param_name[:-2], config["alphas"][param_name[:-2]]) 299 | # else: 300 | # setattr(self, "alpha_"+param_name[:-2], 1.0) 301 | 302 | def set_aux_funcs(self, param_func, nonlin_func): 303 | self.param_func = param_func 304 | self.nonlinearity = nonlin_func 305 | 306 | def set_switches(self, lambdas: dict) -> None: 307 | """ 308 | Sets switches for the residual/loss calculation so that terms with zero weighting are skipped, avoiding unnecessary computation 309 | Args: 310 | lambdas (dict): dictionary of lambda weighting parameters 311 | """ 312 | switches = {} 313 | for key, value in lambdas.items(): 314 | switches[key] = value>0.0 315 | self.switches = switches 316 | 317 | def param_parser(self, params) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 318 | """ 319 | Parses physical parameters into vectors 
320 | 321 | Args: 322 | params (dict): dictionary of physical parameters 323 | 324 | Returns: 325 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: mass, damping, stiffness, nonlinear stiffness and nonlinear damping vectors 326 | """ 327 | 328 | m_vec = torch.zeros(self.n_dof) 329 | c_vec = torch.zeros(self.n_dof) 330 | k_vec = torch.zeros(self.n_dof) 331 | kn_vec = torch.zeros(self.n_dof) 332 | cn_vec = torch.zeros(self.n_dof) 333 | 334 | for i in range(self.n_dof): 335 | m_vec[i] = params[f'm_{i}'] 336 | c_vec[i] = params[f'c_{i}'] 337 | k_vec[i] = params[f'k_{i}'] 338 | kn_vec[i] = params[f'kn_{i}'] 339 | cn_vec[i] = params[f'cn_{i}'] 340 | 341 | return m_vec, c_vec, k_vec, kn_vec, cn_vec 342 | 343 | def retrieve_state_matrices(self, theta_s: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 344 | if theta_s is None: 345 | if hasattr(self,"A"): 346 | M, C, K = self.M, self.C, self.K 347 | A = self.A 348 | else: 349 | params = {} 350 | for param_name, param_dict in self.config["phys_params"].items(): 351 | if param_dict["type"] == "constant": 352 | params[param_name] = param_dict["value"] 353 | else: 354 | if param_name[-2] == "_": 355 | params[param_name] = getattr(self,param_name)*getattr(self,"alpha_"+param_name[:-2]) 356 | else: 357 | params[param_name] = getattr(self,param_name)*getattr(self,"alpha_"+param_name[:-3]) 358 | m_vec, c_vec, k_vec, kn_vec, cn_vec = self.param_parser(params) 359 | M, C, K = self.param_func(m_vec, c_vec, k_vec) 360 | invM = torch.diag(1/torch.diag(M)) 361 | A = torch.cat(( 362 | torch.cat((torch.zeros((self.n_dof, self.n_dof)), torch.eye(self.n_dof)), dim=1), 363 | torch.cat((-invM @ K, -invM @ C), dim=1) 364 | ), dim=0).requires_grad_() 365 | else: 366 | if hasattr(self,"A"): 367 | M, C, K = self.M, self.C, self.K 368 | A = self.A 369 | else: 370 | M, C, K = self.param_func(self.m_, theta_s[:self.n_dof]*self.alpha_c, theta_s[self.n_dof:2*self.n_dof]*self.alpha_k) 371 | 
invM = torch.diag(1/torch.diag(M)) 372 | A = torch.cat(( 373 | torch.cat((torch.zeros((self.n_dof, self.n_dof)), torch.eye(self.n_dof)), dim=1), 374 | torch.cat((-invM @ K, -invM @ C), dim=1) 375 | ), dim=0).requires_grad_() 376 | if hasattr(self,"H"): 377 | H = self.H 378 | else: 379 | H = torch.cat((torch.zeros((self.n_dof, self.n_dof)), invM), dim=0) 380 | 381 | # nonlinear parameters 382 | if self.nonlinearity is not None: 383 | if hasattr(self,"An"): 384 | An = self.An 385 | else: 386 | An = self.nonlinearity.mat_func(kn_vec, cn_vec, invM) 387 | else: 388 | An = torch.zeros((self.n_dof, self.n_dof)) 389 | 390 | return A, H, An 391 | 392 | def retrieve_obs_matrices(self, theta_s: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 393 | if theta_s is None: 394 | if hasattr(self,"B"): 395 | B = self.B 396 | else: 397 | params = {} 398 | for param_name, param_dict in self.config["phys_params"].items(): 399 | if param_dict["type"] == "constant": 400 | params[param_name] = param_dict["value"] 401 | else: 402 | if param_name[-2] == "_": 403 | params[param_name] = getattr(self,param_name)*getattr(self,"alpha_"+param_name[:-2]) 404 | else: 405 | params[param_name] = getattr(self,param_name)*getattr(self,"alpha_"+param_name[:-3]) 406 | m_vec, c_vec, k_vec, kn_vec, cn_vec = self.param_parser(params) 407 | M, C, K = self.param_func(m_vec, c_vec, k_vec) 408 | invM = torch.diag(1/torch.diag(M)) 409 | B = torch.cat((-invM @ K, -invM @ C), dim=1).requires_grad_() 410 | else: 411 | if hasattr(self,"B"): 412 | B = self.B 413 | else: 414 | M, C, K = self.param_func(self.m_, theta_s[:self.n_dof]*self.alpha_c, theta_s[self.n_dof:2*self.n_dof]*self.alpha_k) 415 | invM = torch.diag(1/torch.diag(M)) 416 | B = torch.cat((-invM @ K, -invM @ C), dim=1).requires_grad_() 417 | if hasattr(self,"D"): 418 | D = self.D 419 | else: 420 | D = invM 421 | if self.nonlinearity is not None: 422 | if hasattr(self,"Bn"): 423 | Bn = self.Bn 424 | else: 425 | Bn = 
self.nonlinearity.mat_func(kn_vec, cn_vec, invM)[self.n_dof:, :] 426 | else: 427 | Bn = torch.zeros((self.n_dof, 2 * self.n_dof)) 428 | 429 | return B, D, Bn 430 | 431 | def loss_func( 432 | self, 433 | t_obs_list: list[torch.Tensor], 434 | acc_obs: torch.Tensor, 435 | f_obs: Optional[torch.Tensor], 436 | t_col_list: Optional[list[torch.Tensor]], # when None use same domain as observations 437 | f_col: Optional[torch.Tensor], # when None use same domain as observations 438 | lambdas: dict, 439 | obs_dropouts: Optional[list] = None, 440 | acc_obs_method: str = 'obs_model' 441 | ) -> Tuple[torch.Tensor, list, dict]: 442 | """ 443 | Calculates residuals for loss functions 444 | 445 | Args: 446 | t_obs_list [(torch.Tensor)]: time values for observation domain [[sample x 1]] * sequence 447 | acc_obs (torch.Tensor): observations of acceleration [sample x sequence x dof] 448 | f_obs (torch.Tensor): observations of force [sample x sequence x dof] 449 | t_col_list [(torch.Tensor)]: time values for collocation domain [[sample x 1]] * sequence 450 | f_col (torch.Tensor): measurements of force in collocation domain [sample x sequence x dof] 451 | lambdas (dict): dictionary of loss weighting parameters 452 | 453 | Returns: 454 | loss: total loss 455 | losses: list of individual losses 456 | residuals: dictionary of residuals 457 | 458 | """ 459 | 460 | if self.switches['obs'] or self.switches['occ']: 461 | # generate prediction at observation points 462 | zp_obs_hat = self.forward(t_obs_list) # list of predicted states [samples x state_dim] * seq_len 463 | 464 | # retrieve system matrices 465 | B, D, Bn = self.retrieve_obs_matrices() 466 | 467 | if self.switches['obs']: 468 | # calculate residuals 469 | #TODO: use Sa matrix for dropouts 470 | if obs_dropouts: 471 | idx_obs = [j for j in range(self.n_dof) if j not in obs_dropouts] # indices of observed DOFs 472 | R_obs = torch.zeros((t_obs_list[0].shape[0])).to(self.device) # initialise observation residuals [samples] 473 | for nq 
in range(self.seq_len): 474 | B_ = B[idx_obs, :] # reduced observation matrix 475 | Bn_ = Bn[idx_obs, :] # reduced nonlinear state matrix 476 | D_ = D[idx_obs, :][:, idx_obs] # reduced force matrix 477 | acc_obs_ = self.alpha_a * acc_obs[:, nq, idx_obs].T # acceleration observations [n_dof_obs x samples] 478 | f_obs_ = self.alpha_f * f_obs[:, nq, idx_obs].T if f_obs is not None else None # force observations [n_dof_obs x samples] 479 | alpha_z_ = self.alpha_z 480 | if self.nonlinearity is None: 481 | if f_obs is None: 482 | R_obs_seq = (B_ @ (alpha_z_ * zp_obs_hat[nq].T) - acc_obs_).T 483 | else: 484 | R_obs_seq = (B_ @ (alpha_z_ * zp_obs_hat[nq].T) + D_ @ f_obs_ - acc_obs_).T 485 | else: 486 | gz = self.nonlinearity.gz_func(alpha_z_ * zp_obs_hat[nq].T) 487 | if f_obs is None: 488 | R_obs_seq = (B_ @ (alpha_z_ * zp_obs_hat[nq].T) + Bn_ @ gz - acc_obs_).T 489 | else: 490 | R_obs_seq = (B_ @ (alpha_z_ * zp_obs_hat[nq].T) + Bn_ @ gz + D_ @ f_obs_ - acc_obs_).T 491 | R_obs += torch.sqrt(torch.sum(R_obs_seq**2, dim=1)) 492 | else: 493 | R_obs = torch.zeros((t_obs_list[0].shape[0])).to(self.device) 494 | for nq in range(self.seq_len): 495 | acc_obs_ = self.alpha_a * acc_obs[:, nq, :].T 496 | f_obs_ = self.alpha_f * f_obs[:, nq, :].T if f_obs is not None else None # guard: f_obs may be None 497 | alpha_z_ = self.alpha_z 498 | match [self.nonlinearity, f_obs]: 499 | case None, None: 500 | R_obs_seq = (B @ (alpha_z_ * zp_obs_hat[nq].T) - acc_obs_).T 501 | R_obs += torch.sqrt(torch.sum(R_obs_seq**2, dim=1)) 502 | case [_, None]: 503 | R_obs_seq = (B @ (alpha_z_ * zp_obs_hat[nq].T) + Bn @ self.nonlinearity.gz_func(alpha_z_ * zp_obs_hat[nq].T) - acc_obs_).T 504 | R_obs += torch.sqrt(torch.sum(R_obs_seq**2, dim=1)) 505 | case [None, torch.Tensor()]: 506 | R_obs_seq = (B @ (alpha_z_ * zp_obs_hat[nq].T) + D @ f_obs_ - acc_obs_).T 507 | R_obs += torch.sqrt(torch.sum(R_obs_seq**2, dim=1)) 508 | case [_, torch.Tensor()]: 509 | R_obs_seq = (B @ (alpha_z_ * zp_obs_hat[nq].T) + Bn @ self.nonlinearity.gz_func(alpha_z_ * zp_obs_hat[nq].T) + D @ 
f_obs_ - acc_obs_).T 510 | R_obs += torch.sqrt(torch.sum(R_obs_seq**2, dim=1)) 511 | else: 512 | R_obs = torch.zeros((2)) 513 | 514 | # if acc_obs_method in ['deriv_continuity', 'both']: 515 | if self.switches['occ']: 516 | if obs_dropouts: 517 | idx_keep = [j for j in range(self.n_dof) if j not in obs_dropouts] # indices to keep in terms of DOFs 518 | idx_keep_2 = idx_keep + [j+self.n_dof for j in idx_keep] # indices to keep in terms of state vector 519 | R_occ = torch.zeros((t_obs_list[0].shape[0])).to(self.device) 520 | dzdt_list_obs = [torch.zeros((t_obs_list[0].shape[0], len(idx_keep))).to(self.device) for _ in range(self.seq_len)] 521 | for nq in range(self.seq_len): 522 | acc_obs_ = self.alpha_a * acc_obs[:, nq, idx_keep] # acceleration observations [n_dof x No] 523 | for i, idx in enumerate(idx_keep): 524 | dzdt_list_obs[nq][:, i] = torch.autograd.grad(zp_obs_hat[nq][:, idx+self.n_dof], t_obs_list[nq], torch.ones_like(zp_obs_hat[nq][:, i]), create_graph=True)[0][:, 0] # first derivative of velocity 525 | R_occ += torch.sqrt(torch.sum(((self.alpha_z[-1]/self.alpha_t) * dzdt_list_obs[nq] - acc_obs_)**2, dim=1)) 526 | else: 527 | R_occ = torch.zeros((t_obs_list[0].shape[0])).to(self.device) 528 | dzdt_list_obs = [torch.zeros((t_obs_list[0].shape[0], 2*self.n_dof)).to(self.device) for _ in range(self.seq_len)] 529 | for nq in range(self.seq_len): 530 | acc_obs_ = self.alpha_a * acc_obs[:, nq, :].T 531 | for i in range(self.n_dof): 532 | dzdt_list_obs[nq][:, i+self.n_dof] = torch.autograd.grad(zp_obs_hat[nq][:, i], t_obs_list[nq], torch.ones_like(zp_obs_hat[nq][:, i]), create_graph=True)[0][:, 0] 533 | R_occ += torch.sqrt(torch.sum(((self.alpha_z[-1]/self.alpha_t) * dzdt_list_obs[nq] - acc_obs_)**2, dim=1)) 534 | else: 535 | R_occ = torch.zeros((2)) 536 | 537 | # force contribution checker 538 | # t_pred_stack = torch.cat(t_obs_list, dim=0).detach() 539 | # zp_pred_stack = self.alpha_z * torch.cat(zp_obs_hat, dim=0).detach().T 540 | # f_obs_stack = self.alpha_f 
* torch.cat([f_obs[:, nq, :] for nq in range(self.seq_len)], dim=0).detach().T 541 | # xx = zp_pred_stack[:self.n_dof, :] 542 | # vv = zp_pred_stack[self.n_dof:, :] 543 | # zn = nonlin_state_transform(zp_pred_stack) 544 | # C, K, Kn, = self.C, self.K, -self.Bn[:, :self.n_dof] * 10 545 | # lin_spring_contr = K @ xx 546 | # lin_damp_contr = C @ vv 547 | # nonlin_spring_contr = Kn @ zn[:self.n_dof, :] 548 | # acc_check = D @ (-lin_spring_contr - lin_damp_contr - nonlin_spring_contr + f_obs_stack) 549 | 550 | # match acc_obs_method: 551 | # case 'obs_model': 552 | # R_obs = R_obs_obs 553 | # case 'deriv_continuity': 554 | # R_obs = R_obs_dcc 555 | # case 'both': 556 | # R_obs = R_obs_obs + R_obs_dcc 557 | 558 | if self.switches['ode'] or self.switches['cc']: 559 | 560 | # generate or retrieve prediction over collocation domain 561 | dzdt_list_col = [torch.zeros((t_obs_list[0].shape[0], 2*self.n_dof)).to(self.device) for _ in range(self.seq_len)] 562 | 563 | if t_col_list is None: 564 | zp_col_hat = zp_obs_hat 565 | t_col_list = t_obs_list 566 | # if acc_obs_method in ['deriv_continuity', 'both']: 567 | if self.switches['occ']:# have already calculated velocity derivatives in observation loss 568 | if obs_dropouts: # there are some dropouts so recover only the derivatives that were not dropped out 569 | for nq in range(self.seq_len): 570 | dzdt_list_col[nq][:, idx_keep] = dzdt_list_obs[nq] 571 | for i in obs_dropouts: 572 | dzdt_list_col[nq][:, i + self.n_dof] = torch.autograd.grad(zp_obs_hat[nq][:, i+self.n_dof], t_obs_list[nq], torch.ones_like(zp_obs_hat[nq][:, i]), create_graph=True)[0][:, 0] 573 | else: # no dropouts so retrieve all derivatives of velocity 574 | for nq in range(self.seq_len): 575 | dzdt_list_col[nq][self.n_dof:] = dzdt_list_obs[nq] 576 | for nq in range(self.seq_len): # generate derivatives of displacement 577 | for i in range(self.n_dof): 578 | dzdt_list_col[nq][:, i] = torch.autograd.grad(zp_obs_hat[nq][:, i], t_obs_list[nq], 
torch.ones_like(zp_obs_hat[nq][:, i]), create_graph=True)[0][:, 0] 579 | else: 580 | # generate all derivatives 581 | dzdt_list_col = [torch.zeros((t_obs_list[0].shape[0], 2*self.n_dof)).to(self.device) for _ in range(self.seq_len)] 582 | for nq in range(self.seq_len): 583 | for i in range(2*self.n_dof): 584 | dzdt_list_col[nq][:, i] = torch.autograd.grad(zp_obs_hat[nq][:, i], t_obs_list[nq], torch.ones_like(zp_obs_hat[nq][:, i]), create_graph=True)[0][:, 0] 585 | if f_obs is not None: 586 | f_col = f_obs 587 | else: # collocation domain is different to observation domain, so must generate all derivatives 588 | zp_col_hat = self.forward(t_col_list) 589 | # generate derivatives 590 | dzdt_list_col = [torch.zeros((t_col_list[0].shape[0], 2*self.n_dof)).to(self.device) for _ in range(self.seq_len)] 591 | for nq in range(self.seq_len): 592 | for i in range(2*self.n_dof): 593 | dzdt_list_col[nq][:, i] = torch.autograd.grad(zp_col_hat[nq][:, i], t_col_list[nq], torch.ones_like(zp_col_hat[nq][:, i]), create_graph=True)[0][:, 0] 594 | 595 | # retrieve physical parameters 596 | A, H, An = self.retrieve_state_matrices() 597 | 598 | R_ode = torch.zeros((t_col_list[0].shape[0])).to(self.device) 599 | R_cc = torch.zeros((t_col_list[0].shape[0])).to(self.device) 600 | for nq in range(self.seq_len): 601 | f_col_ = self.alpha_f * f_col[:, nq, :].T if f_col is not None else None # guard: f_col may be None 602 | dzdt_ = dzdt_list_col[nq] 603 | match [self.nonlinearity, f_col]: 604 | case None, None: 605 | R_ = (self.alpha_z / self.alpha_t) * dzdt_.T - A @ (self.alpha_z * zp_col_hat[nq].T) 606 | case [_, None]: 607 | gz = self.nonlinearity.gz_func(self.alpha_z*zp_col_hat[nq].T) 608 | R_ = (self.alpha_z / self.alpha_t)*dzdt_.T - A @ (self.alpha_z * zp_col_hat[nq].T) - An @ gz 609 | case [None, torch.Tensor()]: 610 | R_ = (self.alpha_z / self.alpha_t) * dzdt_.T - A @ (self.alpha_z * zp_col_hat[nq].T) - H @ (f_col_) 611 | case [_, torch.Tensor()]: 612 | gz = self.nonlinearity.gz_func(self.alpha_z * zp_col_hat[nq].T) 613 | R_ = (self.alpha_z / self.alpha_t) * dzdt_.T - A @ 
(self.alpha_z * zp_col_hat[nq].T) - An @ gz - H @ (f_col_) 614 | R_cc += torch.sqrt(torch.sum((R_[:self.n_dof, :])**2, dim=0)) / self.seq_len 615 | R_ode += torch.sqrt(torch.sum((R_[self.n_dof:, :])**2, dim=0)) / self.seq_len 616 | 617 | # continuity condition residual 618 | # R_cc = R_[:self.n_dof, :].T 619 | else: 620 | R_ode = torch.zeros((2)) 621 | R_cc = torch.zeros((2)) 622 | 623 | residuals = { 624 | "R_obs" : R_obs, 625 | "R_occ" : R_occ, 626 | "R_cc" : R_cc, 627 | "R_ode" : R_ode 628 | } 629 | 630 | L_obs = lambdas['obs'] * torch.mean(R_obs**2) 631 | L_occ = lambdas['occ'] * torch.mean(R_occ**2) 632 | L_cc = lambdas['cc'] * torch.mean(R_cc**2) 633 | L_ode = lambdas['ode'] * torch.mean(R_ode**2) 634 | 635 | loss = L_obs + L_occ + L_cc + L_ode 636 | 637 | if math.isnan(loss): 638 | raise Exception("Loss is NaN, upsi") 639 | 640 | return loss, [L_obs, L_occ, L_cc, L_ode], residuals 641 | 642 | def predict(self, t_pred_list, theta_s=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 643 | """ 644 | Predict state, force and acceleration values, returned together with the sorted time vector 645 | 646 | Arguments: 647 | t_pred_list: [samples] * seq_len 648 | theta_s: [n_dof] * 3 649 | """ 650 | zp_list = self.forward(t_pred_list) 651 | # retrieve derivatives 652 | dzdt_list = [torch.zeros((t_pred_list[0].shape[0], 2*self.n_dof)) for _ in range(self.seq_len)] 653 | for nq in range(self.seq_len): 654 | for i in range(2*self.n_dof): 655 | dzdt_list[nq][:, i] = torch.autograd.grad(zp_list[nq][:, i], t_pred_list[nq], torch.ones_like(zp_list[nq][:, i]), create_graph=True)[0][:, 0] 656 | 657 | # reshape 658 | t_pred_flat, sort_ids = torch.sort(torch.cat(t_pred_list, dim=0).reshape(-1)) 659 | zp_flat = torch.cat(zp_list, dim=0)[sort_ids, :] 660 | dzdt_ = torch.cat(dzdt_list, dim=0) 661 | dzdt = dzdt_[sort_ids, :] 662 | 663 | # retrieve physical parameters 664 | A, H, An = self.retrieve_state_matrices(theta_s) 665 | 666 | match self.nonlinearity: 667 | case None: 668 | Hf_pred = (self.alpha_z / self.alpha_t) * dzdt.T - A @ (self.alpha_z 
* zp_flat.T) 669 | case _: 670 | gz = self.nonlinearity.gz_func(self.alpha_z * zp_flat.T) 671 | Hf_pred = (self.alpha_z / self.alpha_t) * dzdt.T - A @ (self.alpha_z * zp_flat.T) - An @ gz 672 | M = torch.diag(1/torch.diag(H[self.n_dof:, :])) 673 | f_pred = M @ Hf_pred[self.n_dof:, :] / self.alpha_f 674 | 675 | #TODO: check acceleration prediction with derivative of state instead using observer model 676 | B, D, Bn = self.retrieve_obs_matrices(theta_s) 677 | match self.nonlinearity: 678 | case None: 679 | a_pred = (B @ (self.alpha_z * zp_flat.T) + D @ (self.alpha_f * f_pred)) / self.alpha_a 680 | case _: 681 | a_pred = (B @ (self.alpha_z * zp_flat.T) + Bn @ gz + D @ (self.alpha_f * f_pred)) / self.alpha_a 682 | 683 | return zp_flat, f_pred.T, a_pred.T, t_pred_flat 684 | 685 | 686 | class mdof_stoch_pinn(nn.Module): 687 | 688 | def __init__(self, config: dict): 689 | super().__init__() 690 | self.n_input = config["n_input"] 691 | self.n_output = config["n_output"] 692 | self.n_hidden = config["n_hidden"] 693 | self.n_layers = config["n_layers"] 694 | self.seq_len = config["seq_len"] 695 | self.n_dof = config["n_dof"] 696 | if 'activation' in config.keys(): 697 | self.activation = getattr(nn, config["activation"]) 698 | else: 699 | # self.activation = nn.Tanh 700 | self.activation = nn.SiLU 701 | self.activation = nn.ELU 702 | self.device = config["device"] 703 | self.t_pi = torch.tensor(math.pi) 704 | 705 | self.build_nets() 706 | 707 | self.configure(config) 708 | self.gather_params() 709 | self.set_switches(config['lambds']) 710 | 711 | def gather_params(self): 712 | self.net_params_list = [] 713 | for net in self.nets: 714 | self.net_params_list.append(net.parameters()) 715 | 716 | def build_nets(self) -> int: 717 | nets = [None] * self.seq_len 718 | for net_n in range(self.seq_len): 719 | nets[net_n] = nn.Sequential( 720 | nn.Sequential(*[nn.Linear(self.n_input, self.n_hidden), self.activation()]), 721 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden, 
self.n_hidden), self.activation()]) for _ in range(self.n_layers-1)]), 722 | nn.Linear(self.n_hidden, self.n_output) 723 | ) 724 | nets[net_n].to(self.device) 725 | self.nets = tuple(nets) 726 | self.network_parameters = [] 727 | for net in self.nets: 728 | self.network_parameters += list(net.parameters()) 729 | return 0 730 | 731 | def build_net(self) -> int: 732 | self.net = nn.Sequential( 733 | nn.Sequential(*[nn.Linear(self.n_input * self.seq_len, self.n_hidden * self.seq_len), self.activation()]), 734 | nn.Sequential(*[nn.Sequential(*[nn.Linear(self.n_hidden * self.seq_len, self.n_hidden * self.seq_len), self.activation()]) for _ in range(self.n_layers-1)]), 735 | nn.Linear(self.n_hidden * self.seq_len, self.n_output * self.seq_len), 736 | nn.Unflatten(dim=1, unflattened_size = (self.seq_len, self.n_output)) 737 | ) 738 | return 0 739 | 740 | def forward(self, x: torch.Tensor) -> torch.Tensor: 741 | """ 742 | Forward pass through the neural network. 743 | 744 | Args: 745 | x (torch.Tensor): Input to network 746 | 747 | Returns: 748 | torch.Tensor: Output tensor. 
749 | """ 750 | # y = self.net(x) 751 | y = [torch.zeros((x[0].shape[0], self.n_output), device=self.device)] * self.seq_len 752 | # y = torch.zeros((self.seq_len, x[0].shape[0], self.n_output), device=self.device) 753 | for nq in range(self.seq_len): 754 | y[nq] = self.nets[nq](x[nq]) 755 | return tuple(y) 756 | 757 | def configure(self, config: dict) -> None: 758 | """ 759 | Configures neural network 760 | 761 | Args: 762 | config (dict): Configuration parameters 763 | """ 764 | 765 | self.config = config 766 | 767 | self.nonlinearity = False 768 | # self.param_func = config["param_func"] 769 | 770 | self.set_phys_params() 771 | self.set_norm_params() 772 | self.set_noise_params() 773 | 774 | def set_phys_params(self) -> None: 775 | """ 776 | Set physical parameters of model, and adds them as either constants or parameters for optimisation 777 | """ 778 | config = self.config 779 | self.param_attrs = {} 780 | self.system_parameters = [] 781 | for param_name, param_dict in config["phys_params"].items(): 782 | self.param_attrs[param_name] = param_dict["type"] 783 | if param_dict["type"] == "constant": 784 | setattr(self,param_name,param_dict["value"]) 785 | elif param_dict["type"] == "variable": 786 | self.register_parameter(param_name, nn.Parameter(param_dict["value"])) 787 | self.system_parameters.append(getattr(self, param_name)) 788 | if hasattr(self, "M") and hasattr(self, "C") and hasattr(self, "K"): 789 | self.A = torch.cat(( 790 | torch.cat((torch.zeros((self.n_dof,self.n_dof)), torch.eye(self.n_dof)), dim=1), 791 | torch.cat((-torch.linalg.inv(self.M)@self.K, -torch.linalg.inv(self.M)@self.C), dim=1) 792 | ), dim=0) 793 | elif hasattr(self,"M"): 794 | self.m_ = torch.diag(self.M) # takes diagonal from mass matrix if set as constant 795 | 796 | def set_noise_params(self) -> None: 797 | """ 798 | Set noise parameters in likelihood equations 799 | """ 800 | config = self.config 801 | self.noise_parameters = [] 802 | for param_name, param_val in 
config["noise_params"].items(): 803 | self.register_parameter(param_name, nn.Parameter(param_val)) 804 | self.noise_parameters.append(getattr(self, param_name)) 805 | 806 | def set_norm_params(self) -> None: 807 | """ 808 | Set normalisation parameters of the model 809 | """ 810 | config = self.config 811 | self.alpha_t = config["alphas"]["t"].clone().detach().to(self.device) 812 | self.alpha_x = config["alphas"]["x"].clone().detach().to(self.device) 813 | self.alpha_v = config["alphas"]["v"].clone().detach().to(self.device) 814 | self.alpha_z = torch.cat((config["alphas"]["x"]*torch.ones(self.n_dof,1), config["alphas"]["v"]*torch.ones(self.n_dof,1)), dim=0).float().to(self.device) 815 | self.alpha_f = config["alphas"]["f"].clone().detach().to(self.device) 816 | for param_name, param_dict in config["phys_params"].items(): 817 | if param_dict["type"] == "variable": 818 | setattr(self,"alpha_"+param_name[:-1],config["alphas"][param_name[:-1]]) 819 | else: 820 | setattr(self, "alpha_"+param_name[:-1], 1.0) 821 | 822 | def set_aux_funcs(self, param_func, nonlin_func): 823 | self.param_func = param_func 824 | self.nonlinearity = nonlin_func 825 | 826 | def set_switches(self, lambdas: dict) -> None: 827 | """ 828 | Sets switches for residual/loss calculation so that terms with zero weight are skipped, avoiding unnecessary computation 829 | Args: 830 | lambdas (dict): dictionary of lambda weighting parameters 831 | """ 832 | switches = {} 833 | for key, value in lambdas.items(): 834 | switches[key] = value>0.0 835 | if self.seq_len == 1: 836 | switches['ncc'] = False 837 | self.switches = switches 838 | 839 | def loss_func(self, t_obs: torch.Tensor, z_obs: torch.Tensor, t_col: torch.Tensor, f_col: torch.Tensor, lambdas: dict, obs_dropouts: Tuple[list, bool]=False) -> Tuple[torch.Tensor, list, dict]: 840 | """ 841 | Calculates residuals for loss functions 842 | 843 | Args: 844 | t_obs list(torch.Tensor): time values for observation domain [No x 1] x seq_len 845 | z_obs (torch.Tensor): observations of 
state [No x seq_len x 2*ndof] 846 | t_col list(torch.Tensor): time values for collocation domain [Nc x 1] x seq_len 847 | f_col list(torch.Tensor): measurements of force in collocation domain [Nc x ndof] x seq_len 848 | lambdas (dict): dictionary of loss weighting parameters 849 | 850 | Returns: 851 | loss: total loss 852 | losses: list of individual losses 853 | residuals: dictionary of residuals 854 | 855 | """ 856 | 857 | if self.switches['obs']: 858 | # generate prediction at observation points 859 | zp_obs_hat = self.forward(t_obs) 860 | if obs_dropouts: 861 | idx = [] # ids for displacements that should be included 862 | idx2 = [] # ids for velocities that should be included 863 | for j in range(self.n_dof): 864 | if j not in obs_dropouts: 865 | idx.append(j) 866 | idx2.append(j+self.n_dof) 867 | idx.extend(idx2) # all ids in states that should be included 868 | n_obs_sq = t_obs[0].shape[0] # number of observations per sequence 869 | n_obs_state = 2*(self.n_dof-len(obs_dropouts)) # number of observations in the states 870 | R_obs = torch.zeros((n_obs_sq*self.seq_len, n_obs_state)) # empty matrix for residuals [n_samps, 2n_dof] 871 | for nq in range(self.seq_len): 872 | # R_obs += torch.sqrt(torch.sum((zp_obs_hat[nq][:, idx].reshape(-1, 2*(self.n_dof-len(obs_dropouts))) - z_obs[:, nq, idx].reshape(-1, 2*(self.n_dof-len(obs_dropouts))))**2, dim=1)) 873 | R_obs[nq*n_obs_sq:(nq+1)*n_obs_sq, :] = zp_obs_hat[nq][:, idx].reshape(-1, n_obs_state) - z_obs[:, nq, idx].reshape(-1, n_obs_state) 874 | else: 875 | # R_obs = torch.zeros((t_obs.shape[0])) 876 | # for nq in range(self.seq_len): 877 | # R_obs += torch.sqrt(torch.sum((zp_obs_hat[nq] - z_obs[:, nq, :])**2, dim=1)) 878 | # R_obs = torch.zeros((n_obs_sq*self.seq_len)) 879 | n_obs_sq = t_obs[0].shape[0] # number of observations per sequence 880 | R_obs = torch.zeros((n_obs_sq*self.seq_len, 2*self.n_dof)) 881 | for nq in range(self.seq_len): 882 | # R_obs[nq*n_obs_sq:(nq+1)*n_obs_sq] = 
torch.sqrt(torch.sum((zp_obs_hat[nq] - z_obs[:, nq, :])**2, dim=1)) 883 | R_obs[nq*n_obs_sq:(nq+1)*n_obs_sq, :] = zp_obs_hat[nq] - z_obs[:, nq, :] 884 | 885 | if self.switches['ncc']: 886 | net_in_ids = [torch.argmin(t_obs[n]) for n in range(self.seq_len)] 887 | net_out_ids = [torch.argmax(t_obs[n]) for n in range(self.seq_len)] 888 | R_ncc = torch.zeros((self.seq_len - 1, 1)) 889 | 890 | for nq in range(self.seq_len-1): 891 | R_ncc[nq] = torch.sqrt(torch.sum((zp_obs_hat[nq][net_out_ids[nq], :] - zp_obs_hat[nq + 1][net_in_ids[nq + 1], :])**2, dim=0)) 892 | else: 893 | R_ncc = torch.ones((self.n_dof, 1)) 894 | 895 | if self.switches['ode'] or self.switches['cc']: 896 | # generate prediction over collocation domain 897 | zp_col_hat_ = self.forward(t_col) 898 | 899 | # retrieve derivatives 900 | dxdt_list = [torch.zeros((t_col[0].shape[0], 2*self.n_dof)) for _ in range(self.seq_len)] 901 | for nq in range(self.seq_len): 902 | for i in range(2*self.n_dof): 903 | dxdt_list[nq][:, i] = torch.autograd.grad(zp_col_hat_[nq][:, i], t_col[nq], torch.ones_like(zp_col_hat_[nq][:, i]), create_graph=True)[0][:, 0] 904 | 905 | # reshape 906 | t_pred_flat, sort_ids = torch.sort(torch.cat(t_col, dim=0).reshape(-1)) 907 | zp_col_hat = torch.cat(zp_col_hat_, dim=0)[sort_ids, :] 908 | f_col = torch.cat(f_col, dim=0)[sort_ids, :] 909 | dxdt_ = torch.cat(dxdt_list, dim=0) 910 | dxdt = dxdt_[sort_ids, :] 911 | 912 | # retrieve physical parameters 913 | # linear parameters 914 | if hasattr(self,"A"): 915 | M, C, K = self.M, self.C, self.K 916 | A = self.A 917 | else: 918 | params = {} 919 | for param_name, param_dict in self.config["phys_params"].items(): 920 | if param_dict["type"] == "constant": 921 | params[param_name] = param_dict["value"] 922 | else: 923 | params[param_name] = getattr(self,param_name)*getattr(self,"alpha_"+param_name[:-1]) 924 | M, C, K = self.param_func(params["m_"], params["c_"], params["k_"]) 925 | invM = torch.diag(1/torch.diag(M)) 926 | A = torch.cat(( 927 | 
torch.cat((torch.zeros((self.n_dof, self.n_dof)), torch.eye(self.n_dof)), dim=1), 928 | torch.cat((-invM @ K, -invM @ C), dim=1) 929 | ), dim=0).requires_grad_() 930 | if f_col is not None: 931 | if hasattr(self,"H"): 932 | H = self.H 933 | else: 934 | H = torch.cat((torch.zeros((self.n_dof, self.n_dof)), invM), dim=0) 935 | 936 | # nonlinear parameters 937 | if self.nonlinearity is not None: 938 | An = self.nonlinearity.mat_func(params['kn_'], params['cn_'], invM) 939 | 940 | if self.switches['ode'] or self.switches['cc']: 941 | match [self.nonlinearity, f_col]: 942 | case None, None: 943 | R_ = (self.alpha_z / self.alpha_t) * dxdt.T - A @ (self.alpha_z * zp_col_hat.T) 944 | case [_, None]: 945 | gz = self.nonlinearity.gz_func(self.alpha_z*zp_col_hat.T) 946 | R_ = (self.alpha_z / self.alpha_t)*dxdt.T - A @ (self.alpha_z * zp_col_hat.T) - An @ gz 947 | case [None, torch.Tensor()]: 948 | R_ = (self.alpha_z / self.alpha_t) * dxdt.T - A @ (self.alpha_z * zp_col_hat.T) - H @ (self.alpha_f * f_col.T) 949 | case [_, torch.Tensor()]: 950 | gz = self.nonlinearity.gz_func(self.alpha_z * zp_col_hat.T) 951 | R_ = (self.alpha_z / self.alpha_t) * dxdt.T - A @ (self.alpha_z * zp_col_hat.T) - An @ gz - H @ (self.alpha_f * f_col.T) 952 | R_ode = R_[self.n_dof:, :].T / self.alpha_f 953 | else: 954 | R_ode = torch.zeros((self.n_dof, 1)) 955 | 956 | if self.switches['cc']: 957 | # continuity condition residual 958 | R_cc = R_[:self.n_dof, :].T / self.alpha_v 959 | else: 960 | R_cc = torch.zeros((self.n_dof, 1)) 961 | 962 | residuals = { 963 | "R_obs" : R_obs, 964 | "R_ncc" : R_ncc, 965 | "R_cc" : R_cc, 966 | "R_ode" : R_ode 967 | } 968 | 969 | # likelihoods 970 | N_o = R_obs.shape[0] 971 | # Sigma = torch.diag(torch.cat((self.sigma_x * torch.ones(self.n_dof), self.sigma_v * torch.ones(self.n_dof)))) 972 | if self.switches['obs']: 973 | ## same sigma for all states 974 | # log_likeli_obs = -N_o * torch.log(self.sigma_z) - (N_o/2) * torch.log(2*self.t_pi) - 0.5 * 
torch.sum((R_obs**2/self.sigma_z**2)) 975 | 976 | ## looping with separate sigma_x and sigma_v 977 | sigmas = [self.sigma_x, self.sigma_v] 978 | log_likeli_obs = torch.tensor(0.0) 979 | for d in range(2): 980 | if obs_dropouts: 981 | log_likeli_obs += (- 0.5 * N_o * torch.log(2*self.t_pi) - N_o * torch.log(sigmas[d]) - 0.5 * torch.sum(torch.sum(R_obs[:, int(d*n_obs_state/2):int((d+1)*n_obs_state/2)]**2, dim=1)/sigmas[d]**2, dim=0)) 982 | else: 983 | log_likeli_obs += (- 0.5 * N_o * torch.log(2*self.t_pi) - N_o * torch.log(sigmas[d]) - 0.5 * torch.sum(torch.sum(R_obs[:, d*self.n_dof:(d+1)*self.n_dof]**2, dim=1)/sigmas[d]**2, dim=0)) 984 | L_obs = lambdas['obs'] * - log_likeli_obs 985 | 986 | ## full multivariate calc 987 | # dist_log_term = torch.tensor(0., dtype=torch.float32) 988 | # for i in range(R_obs.shape[0]): 989 | # dist_log_term += R_obs[i,:].T @ torch.linalg.inv(Sigma) @ R_obs[i,:] 990 | # log_likeli_obs = -(N_o/2) * (2 * self.n_dof * torch.log(2*self.t_pi) + torch.log(torch.linalg.det(Sigma))) - 0.5 * dist_log_term 991 | else: 992 | log_likeli_obs = torch.tensor(0.0) 993 | L_obs = torch.tensor(0.0) 994 | 995 | N_c = R_ode.shape[0] 996 | log_likeli_ode = torch.tensor(0.0) 997 | if self.switches['ode']: 998 | log_likeli_ode = -0.5 * N_c * torch.log(2*self.t_pi) - N_c * torch.log(self.sigma_f) - 0.5 * torch.sum(torch.sum(R_ode**2, dim=1)/self.sigma_f**2, dim=0) 999 | L_ode = lambdas['ode'] * - log_likeli_ode 1000 | else: 1001 | log_likeli_ode = torch.tensor(0.0) 1002 | L_ode = torch.tensor(0.0) 1003 | 1004 | if self.switches['cc']: 1005 | log_likeli_cc = - 0.5 * N_c * torch.log(2*self.t_pi) - N_c * torch.log(self.sigma_v) - 0.5 * torch.sum(torch.sum(R_cc**2, dim=1)/self.sigma_v**2, dim=0) 1006 | L_cc = lambdas['cc'] * - log_likeli_cc 1007 | else: 1008 | L_cc = torch.tensor(0.0) 1009 | 1010 | L_ncc = lambdas['ncc'] * torch.mean(R_ncc**2, dim=0) # mean squared network-continuity residual 1011 | 1012 | loss = L_obs + L_ode + L_cc + L_ncc 1013 | 1014 | if math.isnan(loss): 1015 | raise Exception("Nan 
again for some bloody reason") 1016 | 1017 | return loss, [L_obs, L_cc, L_ode, L_ncc], residuals 1018 | 1019 | def predict(self, t_pred, theta_s=None, f_col=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 1020 | """ 1021 | Predict state values 1022 | """ 1023 | zp_ = self.forward(t_pred) 1024 | # tp_sort = [None] * self.seq_len 1025 | # zp_sort = [None] * self.seq_len 1026 | # fc_sort = [None] * self.seq_len 1027 | # retrieve derivatives 1028 | dxdt_list = [torch.zeros((t_pred[0].shape[0], 2*self.n_dof)) for _ in range(self.seq_len)] 1029 | for nq in range(self.seq_len): 1030 | for i in range(2*self.n_dof): 1031 | dxdt_list[nq][:, i] = torch.autograd.grad(zp_[nq][:, i], t_pred[nq], torch.ones_like(zp_[nq][:, i]), create_graph=True)[0][:, 0] 1032 | # tp_sort[nq], sort_ids = torch.sort(t_pred[nq].reshape(-1)) 1033 | # zp_sort[nq] = zp_[nq][sort_ids, :] 1034 | # if f_col is not None: 1035 | # fc_sort[nq] = f_col[nq][sort_ids, :] 1036 | # dxdt_list[nq] = dxdt_list[nq][sort_ids, :] 1037 | 1038 | # reshape 1039 | # t_pred_flat = torch.cat(tp_sort, dim=0) 1040 | t_pred_flat, sort_ids = torch.sort(torch.cat(t_pred, dim=0).reshape(-1)) 1041 | zp_flat = torch.cat(zp_, dim=0)[sort_ids, :] 1042 | if f_col is not None: 1043 | f_col_ = torch.cat(f_col, dim=0)[sort_ids, :] 1044 | dxdt = torch.cat(dxdt_list, dim=0)[sort_ids, :] 1045 | 1046 | # retrieve physical parameters 1047 | if hasattr(self,"A") and (theta_s is None): 1048 | M, C, K = self.M, self.C, self.K 1049 | A = self.A 1050 | else: 1051 | if theta_s is None: 1052 | params = {} 1053 | for param_name, param_dict in self.config["phys_params"].items(): 1054 | if param_dict["type"] == "constant": 1055 | params[param_name] = param_dict["value"] 1056 | else: 1057 | # params[param_name] = self.param_transforms[param_name](getattr(self,param_name)) 1058 | params[param_name] = getattr(self,param_name)*getattr(self,"alpha_"+param_name[:-1]) 1059 | M, C, K = self.param_func(params["m_"], params["c_"], params["k_"]) 1060 | 
else: 1061 | M, C, K = self.param_func(self.m_, theta_s[:self.n_dof]*self.alpha_c, theta_s[self.n_dof:2*self.n_dof]*self.alpha_k) 1062 | invM = torch.diag(1/torch.diag(M)) 1063 | A = torch.cat(( 1064 | torch.cat((torch.zeros((self.n_dof, self.n_dof), device=self.device), torch.eye(self.n_dof, device=self.device)), dim=1), 1065 | torch.cat((-invM @ K, -invM @ C), dim=1) 1066 | ), dim=0).requires_grad_() 1067 | 1068 | # nonlinear parameters 1069 | if self.nonlinearity is not None: 1070 | if theta_s is None: 1071 | An = self.nonlinearity.mat_func(params['kn_'], params['cn_'], invM) 1072 | else: 1073 | kn_ = torch.tensor(theta_s[2*self.n_dof:3*self.n_dof], dtype=torch.float32) 1074 | cn_ = torch.zeros_like(kn_) 1075 | An = self.nonlinearity.mat_func(kn_, cn_, invM) 1076 | 1077 | match self.nonlinearity: 1078 | case None: 1079 | Hf_pred = (self.alpha_z / self.alpha_t) * dxdt.T - A @ (self.alpha_z * zp_flat.T) 1080 | case _: 1081 | gz = self.nonlinearity.gz_func(self.alpha_z * zp_flat.T) 1082 | Hf_pred = (self.alpha_z / self.alpha_t) * dxdt.T - A @ (self.alpha_z * zp_flat.T) - An @ gz 1083 | f_pred = M @ Hf_pred[self.n_dof:, :] 1084 | 1085 | if f_col is None: 1086 | return zp_flat, f_pred.T, t_pred_flat, dxdt 1087 | else: 1088 | return zp_flat, f_pred.T, t_pred_flat, dxdt, f_col_ 1089 | 1090 | def locked_force_pred(self, theta_s, z_pred, dzdt_pred): 1091 | 1092 | # retrieve physical parameters 1093 | M, C, K = self.param_func(self.m_, theta_s[:self.n_dof]*self.alpha_c, theta_s[self.n_dof:2*self.n_dof]*self.alpha_k) 1094 | invM = torch.diag(1/torch.diag(M)) 1095 | # M1, C1, K1 = self.param_func(self.m_, torch.ones(self.n_dof), 15.0*torch.ones(self.n_dof)) 1096 | A = torch.cat(( 1097 | torch.cat((torch.zeros((self.n_dof, self.n_dof), device=self.device), torch.eye(self.n_dof, device=self.device)), dim=1), 1098 | torch.cat((-invM @ K, -invM@C), dim=1) 1099 | ), dim=0).requires_grad_() 1100 | # A1 = torch.cat(( 1101 | # torch.cat((torch.zeros((self.n_dof, self.n_dof), 
device=self.device), torch.eye(self.n_dof, device=self.device)), dim=1), 1102 | # torch.cat((-invM @ K1, -invM@C1), dim=1) 1103 | # ), dim=0).requires_grad_() 1104 | 1105 | # nonlinear parameters 1106 | if self.nonlinearity is not None: 1107 | if self.nonlinearity.gk_exp is not None: 1108 | kn__ = theta_s[2*self.n_dof:3*self.n_dof] 1109 | cn__ = torch.zeros(self.n_dof) 1110 | elif self.nonlinearity.gc_exp is not None: 1111 | kn__ = torch.zeros(self.n_dof) 1112 | cn__ = theta_s[2*self.n_dof:3*self.n_dof] 1113 | # cn__1 = 0.75 * torch.ones(self.n_dof) / self.alpha_cn 1114 | #cn__[0] = theta_s[-2] 1115 | 1116 | An = self.nonlinearity.mat_func(self.alpha_kn * kn__, self.alpha_cn * cn__, invM) 1117 | # An1 = self.nonlinearity.mat_func(self.alpha_kn * kn__, self.alpha_cn * cn__1, invM) 1118 | 1119 | match self.nonlinearity: 1120 | case None: 1121 | Hf_pred = (self.alpha_z / self.alpha_t) * dzdt_pred.T - A @ (self.alpha_z * z_pred.T) 1122 | case _: 1123 | gz = self.nonlinearity.gz_func(self.alpha_z * z_pred.T) 1124 | Hf_pred = (self.alpha_z / self.alpha_t) * dzdt_pred.T - A @ (self.alpha_z * z_pred.T) - An @ gz 1125 | # Hf_pred1 = (self.alpha_z / self.alpha_t) * dzdt_pred.T - A1 @ (self.alpha_z * z_pred.T) - An1 @ gz 1126 | f_pred = M @ Hf_pred[self.n_dof:, :] 1127 | # f_pred1 = M @ Hf_pred1[self.n_dof:, :] 1128 | self.f_pred = f_pred.T 1129 | 1130 | return f_pred.T / self.alpha_f 1131 | 1132 | def phys_log_likeli(self, theta_s): 1133 | 1134 | z_pred = self.z_pred 1135 | dzdt_pred = self.dzdt_pred 1136 | f_col = self.f_col 1137 | 1138 | # retrieve physical parameters 1139 | M, C, K = self.param_func(self.m_, theta_s[:self.n_dof]*self.alpha_c, theta_s[self.n_dof:2*self.n_dof]*self.alpha_k) 1140 | invM = torch.diag(1/torch.diag(M)) 1141 | A = torch.cat(( 1142 | torch.cat((torch.zeros((self.n_dof, self.n_dof), device=self.device), torch.eye(self.n_dof, device=self.device)), dim=1), 1143 | torch.cat((-invM @ K, -invM@C), dim=1) 1144 | ), dim=0).requires_grad_() 1145 | 1146 | 
# nonlinear parameters 1147 | if self.nonlinearity is not None: 1148 | # An = self.nonlinearity.mat_func(theta_s[2*self.n_dof:3*self.n_dof], theta_s[3*self.n_dof:4*self.n_dof], invM) 1149 | kn__ = torch.tensor(theta_s[2*self.n_dof:3*self.n_dof], dtype=torch.float32) 1150 | cn__ = torch.zeros(self.n_dof) 1151 | An = self.nonlinearity.mat_func(self.alpha_kn * kn__, self.alpha_cn * cn__, invM) 1152 | 1153 | match self.nonlinearity: 1154 | case None: 1155 | Hf_pred = (self.alpha_z / self.alpha_t) * dzdt_pred.T - A @ (self.alpha_z * z_pred.T) 1156 | case _: 1157 | gz = self.nonlinearity.gz_func(self.alpha_z * z_pred.T) 1158 | Hf_pred = (self.alpha_z / self.alpha_t) * dzdt_pred.T - A @ (self.alpha_z * z_pred.T) - An @ gz 1159 | f_pred = M @ Hf_pred[self.n_dof:, :] 1160 | 1161 | sigma_f = theta_s[-1] 1162 | 1163 | N_c = f_pred.shape[0] 1164 | f_res = torch.sqrt(torch.sum((f_col * self.alpha_f - f_pred.T)**2, dim=1)) 1165 | log_likeli_ode = - N_c * torch.log(sigma_f) - (N_c/2) * torch.log(2*self.t_pi) - 0.5 * torch.sum((f_res**2/(sigma_f**2))) 1166 | return log_likeli_ode 1167 | 1168 | 1169 | class mdof_dataset(torch.utils.data.Dataset): 1170 | 1171 | def __init__(self, t_data, acc_data, f_data = None, data_config = None, device = torch.device("cpu"), force_drop = False): 1172 | 1173 | if data_config is not None: 1174 | self.subsample = data_config['subsample'] # subsample simulates lower sample rate 1175 | self.seq_len = data_config['seq_len'] # number of sequences the time series is split into 1176 | else: 1177 | self.subsample = 1 1178 | self.seq_len = 1 1179 | n_dof = acc_data.shape[1] 1180 | if f_data is not None and acc_data.shape[1] != f_data.shape[1]: 1181 | raise Exception("Dimension mismatch for data, please check DOFs dimension of data") 1182 | 1183 | # normalise data based on range 1184 | t_data, alpha_t = normalise(t_data, "range") 1185 | acc_data, alpha_a = normalise(acc_data, "range", "all") 1186 | if f_data is not None: 1187 | if not force_drop: 1188 | f_data, alpha_f = normalise(f_data, 
"range", "all") 1189 | else: 1190 | f_data, alpha_f = torch.zeros_like(f_data), torch.tensor(0.0) 1191 | 1192 | # concatenate into one large dataset 1193 | data = torch.cat((acc_data, f_data, t_data), dim=1) 1194 | self.alphas = { 1195 | "a" : alpha_a, 1196 | "f" : alpha_f, 1197 | "t" : alpha_t 1198 | } 1199 | 1200 | # reshape to batches and sequences 1201 | # ndof for acc, 1 ndof for force, 1 for time 1202 | self.ground_truth = data.to(device) 1203 | col_data = data[:(data.shape[0] // (self.seq_len * self.subsample)) * (self.subsample * self.seq_len)] # cut off excess data 1204 | 1205 | # create observation data from subsample 1206 | obs_data = col_data[::self.subsample] 1207 | # self.obs_data = obs_data.reshape((-1, self.seq_len, 3 * n_dof + 1)).to(device) 1208 | self.obs_data = obs_data.T.reshape((2*n_dof+1, self.seq_len, -1)).permute(2, 1, 0).to(device) 1209 | 1210 | # create collocation data 1211 | # self.col_data = col_data.reshape((-1, self.seq_len, self.subsample, 3 * n_dof + 1)).to(device) 1212 | self.col_data = col_data.T.reshape((2*n_dof+1, self.seq_len, self.subsample, -1)).permute(3, 2, 1, 0).to(device) 1213 | else: 1214 | # concatenate into one large dataset 1215 | data = torch.cat((acc_data, t_data), dim=1) 1216 | self.alphas = { 1217 | "a" : alpha_a, 1218 | "t" : alpha_t 1219 | } 1220 | 1221 | # reshape to number of batches 1222 | # ndof for acc, 1 for time 1223 | self.ground_truth = data.to(device) 1224 | col_data = data[:(data.shape[0] // (self.seq_len * self.subsample)) * (self.seq_len * self.subsample)] # cut off excess data 1225 | 1226 | # create observation data from subsample 1227 | obs_data = col_data[::self.subsample] 1228 | # self.data = data.reshape((-1, self.seq_len, 2 * n_dof + 1)).to(device) 1229 | self.obs_data = obs_data.T.reshape((n_dof+1, self.seq_len, -1)).T.to(device) 1230 | 1231 | # self.col_data = col_data.reshape((-1, self.subsample, self.seq_len, 2 * n_dof + 1)).to(device) 1232 | self.col_data = 
col_data.T.reshape((n_dof+1, self.seq_len, self.subsample, -1)).T.to(device) 1233 | 1234 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: 1235 | return self.obs_data[index, ...], self.col_data[index, ...] 1236 | 1237 | def get_original(self, index: int) -> torch.Tensor: 1238 | return self.ground_truth[index] 1239 | 1240 | def __len__(self) -> int: 1241 | return self.obs_data.shape[0] 1242 | 1243 | def __repr__(self) -> str: 1244 | return self.__class__.__name__ 1245 | 1246 | 1247 | class mdof_stoch_dataset(torch.utils.data.Dataset): 1248 | 1249 | def __init__(self, t_data, x_data, v_data, f_data = None, snr = 50.0, num_repeats = 1, data_config = None, device = torch.device("cpu")): 1250 | 1251 | self.num_repeats = num_repeats 1252 | 1253 | if data_config is not None: 1254 | self.subsample = data_config['subsample'] # subsample simulates lower sample rate 1255 | self.seq_len = data_config['seq_len'] # number of sequences the time series is split into 1256 | else: 1257 | self.subsample = 1 1258 | self.seq_len = 1 1259 | n_dof = x_data.shape[1] 1260 | self.n_dof = n_dof 1261 | if x_data.shape[1] != v_data.shape[1]: 1262 | raise Exception("Dimension mismatch for data, please check DOFs dimension of data") 1263 | 1264 | # add noise and collate to shapes [num_samps, dof, num_repeats] 1265 | t_noisy = torch.cat([t_data.unsqueeze(2) for _ in range(num_repeats)], dim=2) 1266 | xx_noisy_list = [torch.tensor(add_noise(x_data, SNR=snr, seed = 42 + j)).unsqueeze(2) for j in range(num_repeats)] 1267 | vv_noisy_list = [torch.tensor(add_noise(v_data, SNR=snr, seed = 8 + j)).unsqueeze(2) for j in range(num_repeats)] 1268 | xx_noisy = torch.cat(xx_noisy_list, dim=2) 1269 | vv_noisy = torch.cat(vv_noisy_list, dim=2) 1270 | if f_data is not None: 1271 | ff_noisy_list = [torch.tensor(add_noise(f_data, SNR=snr, seed = 16 + j)).unsqueeze(2) for j in range(num_repeats)] 1272 | ff_noisy = torch.cat(ff_noisy_list, dim=2) 1273 | 1274 | # normalise data based on range 1275 | t_obs, alpha_t = 
normalise(t_noisy, "range", "all") 1276 | x_obs, alpha_x = normalise(xx_noisy, "range", "all") 1277 | v_obs, alpha_v = normalise(vv_noisy, "range", "all") 1278 | if f_data is not None: 1279 | f_obs, alpha_f = normalise(ff_noisy, "range", "all") 1280 | self.alphas = { 1281 | "x" : alpha_x, 1282 | "v" : alpha_v, 1283 | "f" : alpha_f, 1284 | "t" : alpha_t 1285 | } 1286 | 1287 | # dimension - 2 ndof for state, 1 ndof for force, 1 for time 1288 | 1289 | # create observation set from noisy data 1290 | # concatenate and subsample into observation dataset [num_samps, dimension, num_repeats] 1291 | obs_data_ = torch.cat((x_obs, v_obs, f_obs, t_obs), dim=1)[::self.subsample] 1292 | # cutoff excess data before reshape 1293 | obs_data_ = obs_data_[:(obs_data_.shape[0] // self.seq_len) * self.seq_len] 1294 | # reshape to [num_samps_per_seq, num_repeats, seq_len, dimension] 1295 | obs_data = torch.zeros((obs_data_.shape[0] // self.seq_len, num_repeats, self.seq_len, 3*n_dof+1)) 1296 | for i in range(self.num_repeats): 1297 | obs_data[:, i, :, :] = obs_data_[..., i].T.reshape((3*n_dof+1, self.seq_len, -1)).T 1298 | self.obs_data = obs_data.to(device) 1299 | 1300 | # create collocation data 1301 | # concatenate into [num_samps, dimension] 1302 | col_data_ = torch.tensor(np.concatenate((x_data/alpha_x, v_data/alpha_v, f_data/alpha_f, t_data/alpha_t), axis=1), dtype=torch.float32) 1303 | # set ground truth 1304 | self.ground_truth = col_data_.to(device) 1305 | # cutoff excess data before reshape 1306 | col_data = col_data_[:(col_data_.shape[0] // (self.seq_len)) * (self.seq_len)] # cut off excess data 1307 | # reshape to [num_samps_per_seq, subsample, seq_len, dimension] 1308 | self.col_data = col_data.T.reshape((3*n_dof+1, self.seq_len, self.subsample, -1)).T.to(device) 1309 | 1310 | else: 1311 | pass # sort this out once you have figured what you want to do 1312 | 1313 | raise NotImplementedError("Lack of force data not implemented yet") 1314 | 
--------------------------------------------------------------------------------