├── LICENSE.md ├── README.md ├── apps ├── data.py └── synthetic.py ├── krylov ├── arnoldi.py ├── cg.py ├── gmres.py ├── preconditioner.py └── test_krylov.py ├── neuralif ├── logger.py ├── loss.py ├── models.py └── utils.py ├── test.py └── train.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Paul Häusner 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural incomplete factorization 2 | 3 | This repository contains the code for learning incomplete factorization preconditioners directly from data by [Paul Häusner](https://paulhausner.github.io), Aleix Nieto Juscafresa, [Ozan Öktem](https://www.kth.se/profile/ozan), and [Jens Sjölund](https://jsjol.github.io/). 
4 | 5 | ## Installation 6 | 7 | In order to run the training and testing, you need to install the following Python dependencies: 8 | 9 | - pytorch 10 | - pytorch-geometric 11 | - scipy 12 | - networkx 13 | 14 | For validation and testing, the following packages are also required: 15 | 16 | - [matplotlib](https://matplotlib.org) 17 | - [numml](https://github.com/nicknytko/numml) (for efficient forward-backward substitution) 18 | - [ilupp](https://github.com/c-f-h/ilupp) (for baseline incomplete factorization preconditioners) 19 | 20 | ## Implementation 21 | 22 | The repository consists of several parts. The `krylov` folder provides implementations of the conjugate gradient and GMRES methods. Further, several preconditioners (Jacobi, ILU, IC) are implemented. 23 | 24 | The `neuralif` module contains the code for the learned preconditioner. The models.py file contains the different models that can be used, and loss.py implements several different loss functions. 25 | 26 | A synthetic dataset is provided in the folder `apps`. A minimal usage sketch is given at the end of this README. 27 | 28 | ## References 29 | 30 | If our code helps your research or work, please consider citing our papers. The following are BibTeX references: 31 | 32 | ``` 33 | @article{hausner2024neural, 34 | title={Neural incomplete factorization: learning preconditioners for the conjugate gradient method}, 35 | author={Paul H{\"a}usner and Ozan {\"O}ktem and Jens Sj{\"o}lund}, 36 | journal={Transactions on Machine Learning Research}, 37 | issn={2835-8856}, 38 | year={2024}, 39 | url={https://openreview.net/forum?id=FozLrZ3CI5} 40 | } 41 | 42 | @InProceedings{hausner2025learning, 43 | title={Learning incomplete factorization preconditioners for {GMRES}}, 44 | author={H{\"a}usner, Paul and Nieto Juscafresa, Aleix and Sj{\"o}lund, Jens}, 45 | booktitle={Proceedings of the 6th Northern Lights Deep Learning Conference (NLDL)}, 46 | pages={85--99}, 47 | year={2025}, 48 | volume={265}, 49 | series={Proceedings of Machine Learning Research}, 50 | publisher={PMLR}, 51 | } 52 | 53 | ``` 54 | 55 | Please feel free to reach out if you have any questions or comments.
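## Example usage

The following is a minimal sketch of how the individual pieces fit together. It only uses functions defined in this repository, but it assumes the synthetic dataset has already been generated (e.g. by running `python apps/synthetic.py` from the repository root); the full training and evaluation pipelines are in `train.py` and `test.py`, and exact paths or entry points may differ.

```python
import torch

from apps.data import get_dataloader, graph_to_matrix
from krylov.cg import preconditioned_conjugate_gradient
from krylov.preconditioner import get_preconditioner

# load one problem instance from the generated test set (./data/Random/test/)
loader = get_dataloader("random", mode="test")

for data in loader:
    # sparse system matrix A and right-hand side b,
    # cast to float64 to match the baseline preconditioners
    A, b = graph_to_matrix(data)
    A, b = A.to(torch.float64), b.to(torch.float64)

    # classical baseline; "ic", "ilu", or "learned" (with model=...) work analogously
    prec = get_preconditioner(data, "jacobi")

    # run preconditioned CG and report the iteration count
    errors, x_hat = preconditioned_conjugate_gradient(A, b, M=prec, rtol=1e-8)
    print(f"Solved system of size {b.shape[0]} in {len(errors) - 1} CG iterations")
    break
```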
56 | 57 | Contact: Paul Häusner, paul.hausner@it.uu.se 58 | -------------------------------------------------------------------------------- /apps/data.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | 3 | import numpy as np 4 | import torch 5 | from scipy.sparse import coo_matrix 6 | from torch_geometric.data import Data 7 | 8 | from torch_geometric.loader import DataLoader 9 | 10 | 11 | def matrix_to_graph_sparse(A, b): 12 | edge_index = torch.tensor(list(map(lambda x: [x[0], x[1]], zip(A.row, A.col))), dtype=torch.long) 13 | edge_features = torch.tensor(list(map(lambda x: [x], A.data)), dtype=torch.float) 14 | node_features = torch.tensor(list(map(lambda x: [x], b)), dtype=torch.float) 15 | 16 | # diag_elements = edge_index[:, 0] == edge_index[:, 1] 17 | # node_features = edge_features[diag_elements] 18 | # node_features = torch.cat((node_features, torch.tensor(list(map(lambda x: [x], b)), dtype=torch.float)), dim=1) 19 | 20 | # Embed the information into data object 21 | data = Data(x=node_features, edge_index=edge_index.t().contiguous(), edge_attr=edge_features) 22 | return data 23 | 24 | 25 | def matrix_to_graph(A, b): 26 | return matrix_to_graph_sparse(coo_matrix(A), b) 27 | 28 | 29 | def graph_to_matrix(data, normalize=False): 30 | A = torch.sparse_coo_tensor(data.edge_index, data.edge_attr[:, 0].squeeze(), requires_grad=False) 31 | b = data.x[:, 0].squeeze() 32 | 33 | if normalize: 34 | b = b / torch.linalg.norm(b) 35 | 36 | return A, b 37 | 38 | 39 | def get_dataloader(dataset, n=0, batch_size=1, spd=True, mode="train", size=None, graph=True): 40 | # Setup datasets 41 | 42 | if dataset == "random": 43 | data = FolderDataset(f"./data/Random/{mode}/", n, size=size, graph=graph) 44 | 45 | else: 46 | raise NotImplementedError("Dataset not implemented, Available: random") 47 | 48 | # Data Loaders 49 | if mode == "train": 50 | dataloader = DataLoader(data, batch_size=batch_size, shuffle=True) 51 | else: 52 | dataloader = DataLoader(data, batch_size=1, shuffle=False) 53 | 54 | return dataloader 55 | 56 | 57 | class FolderDataset(torch.utils.data.Dataset): 58 | def __init__(self, folder, n, graph=True, size=None) -> None: 59 | super().__init__() 60 | 61 | self.graph = True 62 | assert self.graph, "Graph keyword is depracated, only graph=True is supported." 63 | 64 | if n != 0: 65 | if self.graph: 66 | self.files = list(filter(lambda x: x.split("/")[-1].split('_')[0] == str(n), glob(folder+'*.pt'))) 67 | else: 68 | self.files = list(filter(lambda x: x.split("/")[-1].split('_')[0] == str(n), glob(folder+'*.npz'))) 69 | else: 70 | file_ending = "pt" if self.graph else "npz" 71 | self.files = list(glob(folder+f'*.{file_ending}')) 72 | 73 | if size is not None: 74 | assert len(self.files) >= size, f"Only {len(self.files)} files found in {folder} with n={n}" 75 | self.files = self.files[:size] 76 | 77 | if len(self.files) == 0: 78 | raise FileNotFoundError(f"No files found in {folder} with n={n}") 79 | 80 | def __len__(self): 81 | return len(self.files) 82 | 83 | def __getitem__(self, idx): 84 | if self.graph: 85 | g = torch.load(self.files[idx], weights_only=False) 86 | 87 | else: 88 | # deprecated... 
89 | d = np.load(self.files[idx], allow_pickle=True) 90 | g = matrix_to_graph(d["A"], d["b"]) 91 | 92 | return g 93 | -------------------------------------------------------------------------------- /apps/synthetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import scipy 6 | from scipy.sparse import coo_matrix 7 | 8 | from data import matrix_to_graph 9 | 10 | 11 | def generate_sparse_random(n, alpha=1e-4, random_state=0, sol=False, ood=False): 12 | # We add to spd matricies since the sparsity is only enforced on the cholesky decomposition 13 | # generare a lower trinagular matrix 14 | # Random state 15 | rng = np.random.RandomState(random_state) 16 | 17 | if alpha is None: 18 | alpha = rng.uniform(1e-4, 1e-2) 19 | 20 | # this is 1% sparsity for n = 10 000 21 | sparsity = 10e-4 22 | 23 | # create out of distribution samples 24 | if ood: 25 | factor = rng.uniform(0.22, 2.2) 26 | sparsity = factor * sparsity 27 | 28 | nnz = int(sparsity * n ** 2) 29 | rows = [rng.randint(0, n) for _ in range(nnz)] 30 | cols = [rng.randint(0, n) for _ in range(nnz)] 31 | 32 | uniques = set(zip(rows, cols)) 33 | rows, cols = zip(*uniques) 34 | 35 | # generate values 36 | vals = np.array([rng.normal(0, 1) for _ in cols]) 37 | 38 | M = coo_matrix((vals, (rows, cols)), shape=(n, n)) 39 | I = scipy.sparse.identity(n) 40 | 41 | # create spd matrix 42 | A = (M @ M.T) + alpha * I 43 | print(f"Generated matrix with {100 * (A.nnz / n**2) :.2f}% non-zero elements: ({A.nnz} non-zeros)") 44 | 45 | # right hand side is uniform 46 | b = rng.uniform(0, 1, size=n) 47 | 48 | # We want a high-accuracy solution, so we use a direct sparse solver here. 49 | # only produce when in test mode 50 | if sol: 51 | # generate solution using dense method for accuracy reasons 52 | x, _ = scipy.sparse.linalg.cg(A, b) 53 | 54 | else: 55 | x = None 56 | 57 | return A, x, b 58 | 59 | 60 | def create_dataset(n, samples, alpha=1e-2, graph=True, rs=0, mode='train', solution=False): 61 | if mode != 'train': 62 | assert rs != 0, 'rs must be set for test and val to avoid overlap' 63 | 64 | print(f"Generating {samples} samples for the {mode} dataset.") 65 | 66 | for sam in range(samples): 67 | # generate solution only for val and test 68 | 69 | A, x, b = generate_sparse_random(n, random_state=(rs + sam), alpha=alpha, sol=solution, 70 | ood=(mode=="test_ood")) 71 | 72 | if graph: 73 | graph = matrix_to_graph(A, b) 74 | if x is not None: 75 | graph.s = torch.tensor(x, dtype=torch.float) 76 | graph.n = n 77 | torch.save(graph, f'./data/Random/{mode}/{n}_{sam}.pt') 78 | else: 79 | A = coo_matrix(A) 80 | np.savez(f'./data/Random/{mode}/{n}_{sam}.npz', A=A, b=b, x=x) 81 | 82 | 83 | if __name__ == '__main__': 84 | # create the folders and subfolders where the data is stored 85 | os.makedirs(f'./data/Random/train', exist_ok=True) 86 | os.makedirs(f'./data/Random/val', exist_ok=True) 87 | os.makedirs(f'./data/Random/test', exist_ok=True) 88 | 89 | # create 10k dataset 90 | n = 10_000 91 | alpha=10e-4 92 | 93 | create_dataset(n, 1000, alpha=alpha, mode='train', rs=0, graph=True, solution=True) 94 | create_dataset(n, 10, alpha=alpha, mode='val', rs=10000, graph=True) 95 | create_dataset(n, 100, alpha=alpha, mode='test', rs=103600, graph=True) 96 | -------------------------------------------------------------------------------- /krylov/arnoldi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def 
arnoldi(M, A, r0, m, tol=1e-12): 5 | """ 6 | This function computes an orthonormal basis 7 | 8 | V_m = {v_1,...,v_{m+1}} 9 | 10 | of K_{m+1}(A, r^{(0)}) = span{r^{(0)}, Ar^{(0)}, ..., A^{m}r^{(0)}}. 11 | 12 | Input parameters: 13 | ----------------- 14 | A: array_like 15 | An (n x n) array. 16 | 17 | b: array_like 18 | Initial vector of length n 19 | 20 | m: int 21 | One less than the dimension of the Krylov subspace. Must be > 0. 22 | 23 | r0: array_like 24 | Initial residual (length n) 25 | 26 | tol: 27 | Tolerance for convergence 28 | 29 | Output: 30 | ------- 31 | Q: numpy.array 32 | n x (m + 1) array, the columns are an orthonormal basis of the Krylov subspace. 33 | 34 | H: numpy.array 35 | An (m + 1) x m array. It is the matrix A on basis Q. It is upper Hessenberg. 36 | """ 37 | 38 | # Check inputs 39 | n = A.shape[0] 40 | d = r0.dtype 41 | 42 | # assert A.shape == (n, n) and b.shape == (n,) and r0.shape == (n,), "Matrix and vector dimensions don not match" 43 | # assert isinstance(m, int) and m >= 0, "m must be a positive integer" 44 | 45 | m = min(m, n) 46 | 47 | # Initialize matrices 48 | V = torch.zeros((n, m + 1), dtype=d) 49 | H = torch.zeros((m + 1, m), dtype=d) 50 | 51 | # Normalize input vector and use for Krylov vector 52 | beta = torch.linalg.norm(r0) 53 | V[:, 0] = r0 / beta 54 | 55 | for k in range(1, m + 1): 56 | # Generate a new candidate vector 57 | w = M(A @ V[:, k - 1]) # Note that here is different from arnoldi_one_iter as we iter over k from 1 to m. 58 | # In arnoldi_one_iter we have k as inputo to the function and we have V[:, k - 1] as k starts at 0. 59 | 60 | # Orthogonalization 61 | for j in range(k): 62 | H[j, k - 1] = V[:, j] @ w 63 | w -= H[j, k - 1] * V[:, j] 64 | 65 | H[k, k - 1] = torch.linalg.norm(w) 66 | 67 | # Check convergence 68 | if H[k, k - 1] <= tol: 69 | return V, H 70 | 71 | # Normalize and store the new basis vector 72 | V[:, k] = w / H[k, k - 1] 73 | 74 | return V, H 75 | 76 | 77 | def arnoldi_step(M, A, V, k, left=True, tol=1e-12): 78 | 79 | n = A.shape[0] # Dimension of the matrix 80 | d = A.dtype # Data type of the matrix 81 | 82 | # Initialize k + 2 nonzero elements of H along column k 83 | h_k = torch.zeros(k + 2, dtype=d) 84 | 85 | # Calculate the new vector in the Krylov subspace 86 | if left: 87 | v_new = M(A @ V[:, k]) 88 | else: 89 | v_new = A @ M(V[:, k]) 90 | # Calculate the first k elements of the kth Hessenberg column 91 | for j in range(k + 1): 92 | h_k[j] = v_new @ V[:, j] 93 | v_new -= h_k[j] * V[:, j] 94 | 95 | # Add the k+1 element 96 | h_k[k + 1] = torch.linalg.norm(v_new) 97 | 98 | # Early termination with exact solution 99 | if h_k[k + 1] <= tol: 100 | return h_k, None 101 | 102 | # Find the new orthogonal vector in the basis of the Krylov subspace 103 | v_new /= h_k[k + 1] 104 | 105 | return h_k, v_new 106 | -------------------------------------------------------------------------------- /krylov/cg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def stopping_criterion(A, rk, b): 5 | return torch.inner(rk, rk) / torch.inner(b, b) 6 | 7 | 8 | def conjugate_gradient(A, b, x0=None, x_true=None, rtol=1e-8, max_iter=100_000): 9 | x_hat = x0 if x0 is not None else torch.zeros_like(b) 10 | r = b - A@x_hat # residual 11 | p = r.clone() # search direction 12 | 13 | # Errors is a tuple of (error, residual) 14 | error_i = (x_hat - x_true) if x_true is not None else torch.zeros_like(b, requires_grad=False) 15 | res = stopping_criterion(A, r, b) 16 | errors = 
[(torch.inner(error_i, A@error_i), res)] 17 | 18 | for _ in range(max_iter): 19 | if res < rtol: 20 | break 21 | 22 | Ap = A@p 23 | r_norm = torch.inner(r, r) 24 | 25 | a = r_norm / torch.inner(Ap, p) # step length 26 | x_hat = x_hat + a * p 27 | r = r - a * Ap 28 | p = r + (torch.inner(r, r) / r_norm) * p 29 | 30 | error_i = (x_hat - x_true) if x_true is not None else torch.zeros_like(b, requires_grad=False) 31 | res = stopping_criterion(A, r, b) 32 | errors.append((torch.inner(error_i, A@error_i), res)) 33 | 34 | return errors, x_hat 35 | 36 | 37 | def preconditioned_conjugate_gradient(A, b, M=None, x0=None, x_true=None, rtol=1e-8, max_iter=100_000): 38 | # prec should be a function solving the linear equation system Mz=r one way or another 39 | # M is the preconditioner approximation of A^-1 or split approximation of MM^T=A 40 | # Saad, 2003 Algorithm 9.1 41 | 42 | if M is None: 43 | M = lambda x: x 44 | 45 | x_hat = x0 if x0 is not None else torch.zeros_like(b) 46 | 47 | rk = b - A@x_hat 48 | zk = M(rk) 49 | pk = zk.clone() 50 | 51 | # Errors is a tuple of (error, residual) 52 | error_i = (x_hat - x_true) if x_true is not None else torch.zeros_like(b, requires_grad=False) 53 | res = stopping_criterion(A, zk, b) 54 | errors = [(torch.inner(error_i, A@error_i), res)] 55 | 56 | for _ in range(max_iter): 57 | if res < rtol: 58 | break 59 | 60 | # precomputations 61 | Ap = A@pk 62 | rz = torch.inner(rk, zk) 63 | 64 | a = rz / torch.inner(Ap, pk) # step length 65 | x_hat = x_hat + a * pk 66 | rk = rk - a * Ap 67 | zk = M(rk) 68 | beta = torch.inner(rk, zk) / rz 69 | pk = zk + beta * pk 70 | 71 | error_i = (x_hat - x_true) if x_true is not None else torch.zeros_like(b, requires_grad=False) 72 | res = stopping_criterion(A, rk, b) 73 | errors.append((torch.inner(error_i, A@error_i), res)) 74 | 75 | return errors, x_hat 76 | -------------------------------------------------------------------------------- /krylov/gmres.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from krylov.arnoldi import arnoldi, arnoldi_step 4 | 5 | 6 | def back_substitution(A, b): 7 | """ 8 | Solve a linear system using back substitution. 9 | 10 | !!! Onyl for testing purposes, use torch.linalg.solve_triangular instead !!! 11 | 12 | Args: 13 | ---------- 14 | A: Coefficient matrix (must be upper triangular). 15 | b: Column vector of constants. 16 | 17 | Returns: 18 | -------- 19 | list: Solution vector. 20 | 21 | Raises: ValueError: If the matrix A is not square or if its dimensions are incompatible with the vector b. 22 | """ 23 | 24 | n = len(b) 25 | 26 | # Check if A is a square matrix 27 | if len(A) != n or any(len(row) != n for row in A): 28 | raise ValueError("Matrix A must be square.") 29 | 30 | # Check if dimensions of A and b are compatible 31 | if len(A) != len(b): 32 | raise ValueError("Dimensions of A and b are incompatible.") 33 | 34 | x = torch.zeros(n, dtype=b.dtype) 35 | 36 | for i in range(n - 1, -1, -1): 37 | x[i] = (b[i] - torch.sum(A[i, i+1:] * x[i+1:])) / A[i, i] 38 | 39 | return x 40 | 41 | 42 | def gmres(A, b, M=None, left=True, x0=None, x_true=None, atol=1e-8, rtol=None, max_iter=100_000, restart=None, plot=False): 43 | """ 44 | Restarted Generalized Minimal RESidual method for solving linear systems. 45 | 46 | Parameters: 47 | ----------- 48 | A : Coefficient matrix of the linear system. 49 | 50 | b : Right-hand side vector of the linear system. 51 | 52 | M : Preconditioner operator needs to allow M(x) to be computed for any vector x. 
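For example, a diagonal (Jacobi-style) preconditioner can be supplied as a callable such as M = lambda x: D @ x for a sparse diagonal matrix D, which is how the unit tests in test_krylov.py exercise it; if M is None, the identity preconditioner is used.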
53 | 54 | x0 : Initial guess for the solution. 55 | 56 | max_iter : Maximum number of iterations. When None, it is set to the dimension of A. 57 | 58 | restart : Number of iterations before restart. If None, the method will not restart. 59 | 60 | rtol, atol : Tolerance for convergence. 61 | 62 | plot : If True, plot the convergence of the method (makes algorithm slower). 63 | 64 | Returns: 65 | -------- 66 | errors : Error and residual norm at each iteration. 67 | 68 | xk : Approximate solution of the linear system. 69 | """ 70 | 71 | n = A.shape[0] 72 | 73 | if max_iter is None or max_iter > n: 74 | max_iter = n 75 | 76 | if x0 is None: 77 | x0 = torch.zeros(n, dtype=b.dtype) 78 | 79 | if M is None: 80 | # identity preconditioner 81 | M = lambda x: x 82 | 83 | if left: 84 | r0 = M(b - A @ x0) 85 | else: 86 | r0 = b - A @ x0 87 | 88 | p0 = torch.linalg.norm(r0) 89 | 90 | pk = p0.clone() 91 | beta = p0.clone() 92 | 93 | def compute_solution(R, Q, V, beta, x0): 94 | # yk = back_substitution(R[:-1, :], beta*Q[0][:-1]) 95 | yk = torch.linalg.solve_triangular(R[:-1, :], beta * Q[0][:-1].reshape(-1, 1), upper=True) 96 | yk = yk.reshape(-1) 97 | if left: 98 | xk = x0 + V[:, :-1]@yk # Compute the new approximation x0 + V_{k}y 99 | else: 100 | xk = x0 + M(V[:, :-1]@yk) 101 | 102 | return xk 103 | 104 | error_i = (x0 - x_true) if x_true is not None else torch.zeros_like(b, requires_grad=False) 105 | errors = [(torch.linalg.norm(error_i), p0)] 106 | 107 | # Initialize the Arnoldi algorithm 108 | V_ = r0.clone().reshape(-1, 1) / beta 109 | H_ = torch.zeros((n+1, 1), dtype=b.dtype) 110 | 111 | k = 0 112 | for i in range(max_iter): 113 | 114 | # ARNOLDI ALGORITHM for Krylov basis 115 | # Arnoldi algorithm to generate V_{k+1} and H_{K+1, K} 116 | # V, H = arnoldi(M, A, r0, k+1) 117 | 118 | h_new, v_new = arnoldi_step(M, A, V_, k, left=left) 119 | H_ = torch.cat((H_, torch.zeros((n + 1, 1))), axis=1) 120 | H_[:k+2, -1] = h_new 121 | 122 | if v_new is None: 123 | # found exact solution (this does not happen in practice) 124 | # ?
not sure if we need to recompute the QR decomposition 125 | Q, R = torch.linalg.qr(H_[:k+2, 1:], mode='complete') # system of size m 126 | V_ = torch.cat((V_, torch.zeros(n, 1)), axis=1) 127 | errors.append((0, 0)) # logging reasons 128 | break 129 | 130 | V_ = torch.cat((V_, v_new.reshape(-1, 1)), axis=1) 131 | 132 | # QR DECOMPOSITION 133 | # TODO: can be achieved with rotation matrices afaik 134 | Q, R = torch.linalg.qr(H_[:k+2, 1:], mode='complete') # system of size m 135 | pk = torch.abs(beta*Q[0, k+1]) # Compute norm of residual vector 136 | 137 | k += 1 138 | 139 | # LOGGING 140 | if plot: 141 | xk = compute_solution(R, Q, V_, beta, x0) 142 | error_i = (xk - x_true) if x_true is not None else torch.zeros_like(b, requires_grad=False) 143 | errors.append((torch.norm(error_i), pk)) 144 | else: 145 | errors.append((errors[-1][0], pk)) 146 | 147 | # STOPPING CRITERIA 148 | if atol is not None and pk < atol: 149 | break 150 | if rtol is not None and pk < rtol*p0: 151 | break 152 | 153 | # RESTART (don't restart if we are in the last iteration) 154 | elif restart is not None and k == restart and i < max_iter - 1: 155 | 156 | # Compute current solution 157 | x0 = compute_solution(R, Q, V_, beta, x0) 158 | 159 | # Reset iterates 160 | r0 = M(b - A @ x0) 161 | p0 = torch.linalg.norm(r0) 162 | beta = p0.clone() 163 | pk = p0.clone() 164 | k = 0 165 | 166 | # Reset Arnoldi algorithm 167 | V_ = r0.clone().reshape(-1, 1) / beta 168 | H_ = torch.zeros((n+1, 1), dtype=b.dtype) 169 | 170 | xk = compute_solution(R, Q, V_, beta, x0) 171 | return errors, xk 172 | -------------------------------------------------------------------------------- /krylov/preconditioner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import numml as nm 4 | 5 | import ilupp 6 | 7 | from neuralif.utils import torch_sparse_to_scipy, time_function 8 | from neuralif.models import NeuralIF 9 | 10 | 11 | class Preconditioner: 12 | def __init__(self, A, **kwargs): 13 | self.breakdown = False 14 | self.nnz = 0 15 | self.time = 0 16 | self.n = kwargs.get("n", 0) 17 | 18 | def timed_setup(self, A, **kwargs): 19 | start = time_function() 20 | self.setup(A, **kwargs) 21 | stop = time_function() 22 | self.time = stop - start 23 | 24 | def get_inverse(self): 25 | ones = torch.ones(self.n) 26 | offset = torch.zeros(1).to(torch.int64) 27 | 28 | I = torch.sparse.spdiags(ones, offset, (self.n, self.n)) 29 | I = I.to(torch.float64) 30 | 31 | return I 32 | 33 | def get_p_matrix(self): 34 | return self.get_inverse() 35 | 36 | def check_breakdown(self, P): 37 | if np.isnan(np.min(P)): 38 | self.breakdown = True 39 | 40 | def __call__(self, x): 41 | return x 42 | 43 | 44 | class JacobiPreconditioner(Preconditioner): 45 | def __init__(self, A, **kwargs): 46 | super().__init__(A, **kwargs) 47 | self.timed_setup(A) 48 | 49 | self.nnz = A.shape[0] 50 | 51 | def get_p_matrix(self): 52 | # we need to reinvert the matrix 53 | diag = 1 / self.P.values() 54 | offset = torch.zeros(1).to(torch.int64) 55 | 56 | Pinv = torch.sparse.spdiags(diag, offset, (self.n, self.n)) 57 | Pinv = Pinv.to(torch.float64) 58 | 59 | return Pinv 60 | 61 | def get_inverse(self): 62 | return self.P 63 | 64 | def setup(self, A): 65 | # We choose L = 1/D = diag(1/a11, 1/a22, ..., 1/ann) 66 | # data = 1 / torch.Tensor(torch.sqrt(A.diagonal())) 67 | data = 1 / torch.Tensor(A.diagonal()) 68 | indices = torch.vstack((torch.arange(A.shape[0]), torch.arange(A.shape[0]))) 69 | 70 | p = 
torch.sparse_coo_tensor(indices, data, size=A.shape) 71 | self.P = p.to(torch.float64).to_sparse_csr() 72 | 73 | def __call__(self, x): 74 | return self.P@x 75 | 76 | 77 | class ICholPreconditioner(Preconditioner): 78 | def __init__(self, A, **kwargs): 79 | super().__init__(A, **kwargs) 80 | 81 | self.timed_setup(A, **kwargs) 82 | self.nnz = self.L.nnz 83 | 84 | def setup(self, A, **kwargs): 85 | 86 | fill_in = kwargs.get("fill_in", 0.0) 87 | threshold = kwargs.get("threshold", 0.0) 88 | 89 | if fill_in == 0.0 and threshold == 0.0: 90 | icholprec = ilupp.ichol0(A.astype(np.float64).tocsr()) 91 | else: 92 | icholprec = ilupp.icholt(A.astype(np.float64).tocsr(), 93 | add_fill_in=fill_in, 94 | threshold=threshold) 95 | 96 | # icholprec = icholprec.astype(np.float32) 97 | self.check_breakdown(icholprec) 98 | 99 | # convert to nummel sparse format 100 | self.L = nm.sparse.SparseCSRTensor(icholprec) 101 | self.U = nm.sparse.SparseCSRTensor(icholprec.T) 102 | 103 | def get_p_matrix(self): 104 | return self.L@self.U 105 | 106 | def __call__(self, x): 107 | return fb_solve(self.L, self.U, x) 108 | 109 | 110 | class ILUPreconditioner(Preconditioner): 111 | def __init__(self, A, **kwargs): 112 | super().__init__(A, **kwargs) 113 | self.timed_setup(A, **kwargs) 114 | 115 | # don't count the diagonal twice in the process... 116 | self.nnz = self.L.nnz + self.U.nnz - A.shape[0] 117 | 118 | def get_inverse(self): 119 | L_inv = torch.inverse(self.L.to_dense()) 120 | U_inv = torch.inverse(self.U.to_dense()) 121 | 122 | return U_inv@L_inv 123 | 124 | def get_p_matrix(self): 125 | return self.L@self.U 126 | 127 | def setup(self, A, **kwargs): 128 | # compute ILU preconditioner using ilupp 129 | B = ilupp.ILU0Preconditioner(A.astype(np.float64).tocsr()) 130 | 131 | L, U = B.factors() 132 | 133 | # check breakdowns 134 | self.check_breakdown(L) 135 | self.check_breakdown(U) 136 | 137 | # convert to nummel sparse format 138 | self.L = nm.sparse.SparseCSRTensor(L) 139 | self.U = nm.sparse.SparseCSRTensor(U) 140 | 141 | def __call__(self, x): 142 | return fb_solve(self.L, self.U, x) 143 | 144 | 145 | class LearnedPreconditioner(Preconditioner): 146 | def __init__(self, data, model, **kwargs): 147 | super().__init__(data, **kwargs) 148 | 149 | self.model = model 150 | self.spd = isinstance(model, NeuralIF) 151 | 152 | self.timed_setup(data, **kwargs) 153 | 154 | if self.spd: 155 | self.nnz = self.L.nnz 156 | else: 157 | self.nnz = self.L.nnz + self.U.nnz - data.x.shape[0] 158 | 159 | def setup(self, data, **kwargs): 160 | L, U, _ = self.model(data) 161 | 162 | self.L = L.to("cpu").to(torch.float64) 163 | self.U = U.to("cpu").to(torch.float64) 164 | 165 | def get_inverse(self): 166 | L_inv = torch.inverse(self.L.to_dense()) 167 | U_inv = torch.inverse(self.U.to_dense()) 168 | 169 | return U_inv@L_inv 170 | 171 | def get_p_matrix(self): 172 | return self.L@self.U 173 | 174 | def __call__(self, x): 175 | return fb_solve(self.L, self.U, x, unit_upper=not self.spd) 176 | 177 | 178 | def fb_solve(L, U, r, unit_lower=False, unit_upper=False): 179 | y = L.solve_triangular(upper=False, unit=unit_lower, b=r) 180 | z = U.solve_triangular(upper=True, unit=unit_upper, b=y) 181 | return z 182 | 183 | 184 | def fb_solve_joint(LU, r): 185 | # Note: solve triangular ignores the values in lower/upper triangle 186 | y = LU.solve_triangular(upper=False, unit=False, b=r) 187 | z = LU.solve_triangular(upper=True, unit=False, b=y) 188 | return z 189 | 190 | 191 | # generate preconditioner 192 | def get_preconditioner(data, name, 
**kwargs): 193 | 194 | if name == "baseline" or name == "direct": 195 | return Preconditioner(None, n=data.x.shape[0] ,**kwargs) 196 | 197 | elif name == "learned": 198 | return LearnedPreconditioner(data, **kwargs) 199 | 200 | # convert to sparse matrix 201 | A = torch.sparse_coo_tensor(data.edge_index, data.edge_attr.squeeze(), 202 | dtype=torch.float64, requires_grad=False) 203 | A_s = torch_sparse_to_scipy(A) 204 | 205 | if name == "ic" or name == "ichol": 206 | return ICholPreconditioner(A_s, **kwargs) 207 | 208 | elif name == "ilu": 209 | return ILUPreconditioner(A_s, **kwargs) 210 | 211 | elif name == "jacobi": 212 | return JacobiPreconditioner(A_s, n=data.x.shape[0], **kwargs) 213 | 214 | else: 215 | raise NotImplementedError(f"Preconditioner {name} not implemented!") 216 | -------------------------------------------------------------------------------- /krylov/test_krylov.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from krylov.arnoldi import arnoldi, arnoldi_step 7 | from krylov.cg import conjugate_gradient, preconditioned_conjugate_gradient 8 | from krylov.gmres import gmres, back_substitution 9 | 10 | 11 | class TestKrylov(unittest.TestCase): 12 | 13 | def test_arnoldi(self): 14 | n = 100 15 | m = 4 16 | 17 | M = lambda x: x 18 | A = torch.rand(n, n, dtype=torch.float64) 19 | b = torch.rand(n, dtype=torch.float64) 20 | x0 = torch.zeros(n, dtype=torch.float64) 21 | 22 | r0 = M(b - A @ x0) 23 | 24 | V, H = arnoldi(M, A, r0, m) 25 | 26 | assert torch.linalg.cond(V) - 1 < 1e-5, f"Condition number of V is {torch.linalg.cond(V)}" 27 | assert torch.linalg.norm(V.T @ V - torch.eye(m + 1, dtype=torch.float64)) < 1e-5, "V is not orthogonal" 28 | 29 | for i in range(1, m + 2): 30 | cond_number = torch.linalg.cond(V[:,:i]) 31 | assert cond_number - 1 < 1e-5, f"Condition number of V[:,:{i}] is {cond_number}" 32 | 33 | # A@V[:,:-1] = V@H 34 | torch.testing.assert_close(A@V[:,:-1], V@H, rtol=1e-5, atol=1e-5) 35 | 36 | def test_arnoldi_step(self): 37 | n = 100 38 | m = 5 39 | 40 | M = lambda x: x 41 | A = torch.rand(n, n, dtype=torch.float64) 42 | b = torch.rand(n, dtype=torch.float64) 43 | x0 = torch.zeros(n, dtype=torch.float64) 44 | r0 = M(b - A @ x0) 45 | 46 | # classical approach 47 | Q, H = arnoldi(M, A, r0, m) 48 | 49 | # initalize the V matrix 50 | V = torch.zeros((n, 1), dtype=torch.float64) 51 | beta = np.linalg.norm(r0) 52 | 53 | V[:, 0] = r0 / beta 54 | R = torch.zeros((m + 1, 1), dtype=torch.float64) 55 | 56 | for i in range(m): 57 | h, v = arnoldi_step(M, A, V, i) 58 | 59 | V = torch.cat((V, torch.zeros((n, 1), dtype=torch.float64)), axis=1) 60 | V[:, i + 1] = v 61 | 62 | R = torch.cat((R, torch.zeros((m + 1, 1), dtype=torch.float64)), axis=1) 63 | R[:(i + 2), i] = h 64 | 65 | torch.testing.assert_close(Q, V, rtol=1e-5, atol=1e-5) 66 | torch.testing.assert_close(H, R[:, :-1], rtol=1e-5, atol=1e-5) 67 | 68 | def test_gmres(self): 69 | A, b = discretise_poisson(10) 70 | e, x_hat = gmres(A, b) 71 | 72 | x_direct = torch.linalg.inv(A.to_dense()) @ b 73 | torch.testing.assert_close(x_hat, x_direct, rtol=1e-5, atol=1e-5) 74 | 75 | def test_gmres_restart(self): 76 | A, b = discretise_poisson(10) 77 | _, x_hat = gmres(A, b, restart=10) 78 | _, x_hat2 = gmres(A, b) 79 | x_direct = torch.linalg.inv(A.to_dense()) @ b 80 | 81 | torch.testing.assert_close(x_hat, x_direct, rtol=1e-5, atol=1e-5) 82 | torch.testing.assert_close(x_hat, x_hat2, rtol=1e-5, atol=1e-5) 83 | 84 | def 
test_gmres_preconditioner(self): 85 | A, b = discretise_poisson(10) 86 | 87 | M = torch.diag(torch.rand(100, dtype=torch.float64)).to_sparse_coo() 88 | precond = lambda x: M@x 89 | 90 | e, x_hat = gmres(A, b, M=precond) 91 | x_direct = torch.linalg.inv(A.to_dense()) @ b 92 | 93 | torch.testing.assert_close(x_hat, x_direct, rtol=1e-5, atol=1e-5) 94 | 95 | def test_gmres_right(self): 96 | A, b = discretise_poisson(10) 97 | M = torch.diag(torch.rand(100, dtype=torch.float64)).to_sparse_coo() 98 | precond = lambda x: M@x 99 | 100 | e, x_hat = gmres(A, b, M=precond, left=False) 101 | 102 | x_direct = torch.linalg.inv(A.to_dense()) @ b 103 | torch.testing.assert_close(x_hat, x_direct, rtol=1e-5, atol=1e-5) 104 | 105 | def test_gmres_identity(self): 106 | A = torch.eye(100, dtype=torch.float64) 107 | b = torch.rand(100, dtype=torch.float64) 108 | 109 | e, x_hat = gmres(A, b) 110 | 111 | assert len(e) == 2, f"GMRES converged in {len(e) - 1} iterations" 112 | torch.testing.assert_close(x_hat, b, rtol=1e-5, atol=1e-5) 113 | 114 | def test_gmres_vs_cg(self): 115 | A, b = discretise_poisson(10) 116 | 117 | _, x_cg = conjugate_gradient(A, b) 118 | _, x_gmres = gmres(A, b) 119 | 120 | torch.testing.assert_close(x_cg, x_gmres, rtol=1e-5, atol=1e-5) 121 | 122 | def test_conjugate_gradient(self): 123 | A = torch.rand(100, 100, dtype=torch.float64) 124 | A = A @ A.T + 0.1 * torch.eye(100, dtype=torch.float64) 125 | x = torch.rand(100, dtype=torch.float64) 126 | 127 | # obtain rhs and normalize 128 | b = A @ x 129 | b = b / torch.linalg.norm(b) 130 | 131 | _, x_hat = conjugate_gradient(A, b, x, rtol=1e-8, max_iter=100_000) 132 | 133 | # check that the solution has a small residual 134 | res_norm = torch.linalg.norm(A @ x_hat - b) / torch.linalg.norm(b) 135 | 136 | assert res_norm < 1e-3, f"Residual norm is {res_norm}" 137 | 138 | def test_cg_preconditioner(self): 139 | A = torch.rand(100, 100, dtype=torch.float64) 140 | A = A @ A.T + 0.1 * torch.eye(100, dtype=torch.float64) 141 | 142 | M = torch.rand(100, 100, dtype=torch.float64) 143 | M = M @ M.T + 0.1 * torch.eye(100, dtype=torch.float64) 144 | precond = lambda x: M@x 145 | 146 | b = torch.rand(100, dtype=torch.float64) 147 | 148 | _, x_hat_1 = conjugate_gradient(A, b, x_true=None, rtol=1e-15) 149 | _, x_hat_2 = preconditioned_conjugate_gradient(A, b, M=precond, x_true=None, rtol=1e-15) 150 | 151 | torch.testing.assert_close(x_hat_1, x_hat_2, rtol=1e-6, atol=1e-6) 152 | 153 | def test_backsubsitution(self): 154 | 155 | # create data 156 | A = torch.rand(100, 100, dtype=torch.float64) 157 | U = torch.triu(A) 158 | b = torch.rand((100, 1), dtype=torch.float64) 159 | 160 | # solve the system using back substitution 161 | x = back_substitution(U, b) 162 | x_true = torch.linalg.solve_triangular(U, b, upper=True) 163 | 164 | torch.testing.assert_close(x, x_true.squeeze(), rtol=1e-5, atol=1e-5) 165 | 166 | 167 | def discretise_poisson(N): 168 | """Generate the matrix and rhs associated with the discrete Poisson operator.""" 169 | 170 | nelements = 5 * N**2 - 16 * N + 16 171 | 172 | row_ind = np.zeros(nelements, dtype=np.float64) 173 | col_ind = np.zeros(nelements, dtype=np.float64) 174 | data = np.zeros(nelements, dtype=np.float64) 175 | 176 | f = np.zeros(N * N, dtype=np.float64) 177 | 178 | count = 0 179 | for j in range(N): 180 | for i in range(N): 181 | if i == 0 or i == N - 1 or j == 0 or j == N - 1: 182 | row_ind[count] = col_ind[count] = j * N + i 183 | data[count] = 1 184 | f[j * N + i] = 0 185 | count += 1 186 | 187 | else: 188 | row_ind[count : count 
+ 5] = j * N + i 189 | col_ind[count] = j * N + i 190 | col_ind[count + 1] = j * N + i + 1 191 | col_ind[count + 2] = j * N + i - 1 192 | col_ind[count + 3] = (j + 1) * N + i 193 | col_ind[count + 4] = (j - 1) * N + i 194 | 195 | data[count] = 4 * (N - 1)**2 196 | data[count + 1 : count + 5] = - (N - 1)**2 197 | f[j * N + i] = 1 198 | 199 | count += 5 200 | 201 | # create the sparse pytorch matrix 202 | idx = np.vstack((row_ind, col_ind)) 203 | A = torch.sparse_coo_tensor(idx, data, (N**2, N**2)).coalesce() 204 | b = torch.tensor(f) 205 | 206 | return A, b 207 | 208 | 209 | if __name__ == '__main__': 210 | unittest.main() 211 | -------------------------------------------------------------------------------- /neuralif/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dataclasses import dataclass, field 4 | from typing import List 5 | import torch 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import scipy.stats as st 9 | 10 | from neuralif.utils import kA_bound 11 | 12 | 13 | @dataclass 14 | class TestResults: 15 | method: str 16 | dataset: str 17 | folder: str 18 | 19 | # for learned distinguish between different models 20 | model_name: str = "" 21 | 22 | # general parameters 23 | seed: int = 0 24 | target: float = 1e-8 25 | solver: str = "cg" 26 | 27 | # store the results of the test evaluation 28 | n: List[int] = field(default_factory=list) 29 | # cond_pa: List[float] = field(default_factory=list) 30 | nnz_a: List[float] = field(default_factory=list) 31 | nnz_p: List[float] = field(default_factory=list) 32 | p_times: List[float] = field(default_factory=list) 33 | overhead: List[float] = field(default_factory=list) 34 | 35 | # store results from solver (cg or gmres) run 36 | solver_time: List[float] = field(default_factory=list) 37 | solver_iterations: List[float] = field(default_factory=list) 38 | solver_error: List[float] = field(default_factory=list) 39 | solver_residual: List[float] = field(default_factory=list) 40 | 41 | # more advanved loggings (not always set) 42 | distribution: List[torch.Tensor] = field(default_factory=list) 43 | loss1: List[float] = field(default_factory=list) 44 | loss2: List[float] = field(default_factory=list) 45 | 46 | def log(self, nnz_a, nnz_p, plot=False): 47 | 48 | self.nnz_a.append(nnz_a) 49 | self.nnz_p.append(nnz_p) 50 | 51 | if plot: 52 | self.plot_convergence() 53 | 54 | def log_solve(self, n, solver_time, solver_iterations, solver_error, solver_residual, p_time, overhead): 55 | self.n.append(n) 56 | self.solver_time.append(solver_time) 57 | self.solver_iterations.append(solver_iterations) 58 | self.solver_error.append(solver_error) 59 | self.solver_residual.append(solver_residual) 60 | self.p_times.append(p_time) 61 | self.overhead.append(overhead) 62 | 63 | def log_eigenval_dist(self, dist, plot=False): 64 | # eigenvalue of singular value dist :) 65 | 66 | self.distribution.append(dist.numpy()) 67 | 68 | if plot: 69 | self.plot_eigvals(dist) 70 | 71 | def log_loss(self, loss1, loss2, plot=False): 72 | self.loss1.append(loss1) 73 | self.loss2.append(loss2) 74 | 75 | if plot: 76 | self.plot_loss() 77 | 78 | def plot_convergence(self): 79 | 80 | # check convergence speed etc. 
81 | error_0 = self.solver_error[-1][0] 82 | # errors = [fun(r[0]) for r in res] 83 | # residuals = [fun(r[1]) for r in res] 84 | 85 | if self.solver == "cg" and False: 86 | bounds = [error_0 * kA_bound(self.cond_pa[-1], k) for k in range(len(self.solver_residual[-1]))] 87 | else: 88 | bounds = None 89 | 90 | plt.plot(self.solver_error[-1], label="error ($|| x_i - x_* ||$)") 91 | plt.plot(self.solver_residual[-1], label="residual ($||r ||_2$)") 92 | 93 | if bounds is not None: 94 | plt.plot(bounds, "--", label="k(A)-bound") 95 | 96 | plt.plot([self.target for _ in self.solver_residual[-1]], ":") 97 | 98 | plt.grid(alpha=0.3) 99 | 100 | plt.yscale("log") 101 | plt.title(f"Convergence: {self.method} in {len(self.solver_residual[-1]) - 1} iterations") 102 | plt.xlabel("iteration") 103 | plt.ylabel("log10") 104 | plt.legend() 105 | 106 | sample = len(self.solver_time) 107 | plt.savefig(f"{self.folder}/convergence_{self.solver}_{self.method}_{sample}.pdf") 108 | plt.close() 109 | 110 | def plot_eigvals(self, dist, name=""): 111 | 112 | c = torch.max(dist) / torch.min(dist) 113 | 114 | # plt.rcParams["font.family"] = "Times New Roman" 115 | plt.rcParams["font.size"] = 14 116 | 117 | plt.grid(alpha=0.3) 118 | 119 | bins=20 120 | # bins=[0, 0.01, 0.1, 0.2,0.3,0.5,1,1.5,2,3,4,5] 121 | plt.hist(dist.tolist(), density=True, bins=bins, alpha=0.7) 122 | mn, mx = plt.xlim() 123 | plt.xlim(mn, mx) 124 | kde_xs = np.linspace(mn, mx, 300) 125 | kde = st.gaussian_kde(dist.tolist()) 126 | plt.plot(kde_xs, kde.pdf(kde_xs), "--", alpha=0.7) 127 | 128 | # plt.xscale("log") 129 | # plt.xlim(right=2) 130 | 131 | plt.title(f"$\kappa(A)=${c.item():.2e}") 132 | plt.ylabel("Frequency") 133 | plt.xlabel("$\lambda$") 134 | plt.savefig(f"{self.folder}/eigenvalues_{self.method}_{name}.png") 135 | plt.close() 136 | 137 | def plot_loss(self): 138 | i = len(self.solver_time) - 1 139 | 140 | # plt.rcParams["font.family"] = "Times New Roman" 141 | plt.rcParams["font.size"] = 14 142 | 143 | fig, axs = plt.subplots(1, 3, figsize=plt.figaspect(1/3)) 144 | # fig.suptitle(f"{self.method.upper()} Error: {self.loss[-1]:.2f}" + m) 145 | 146 | im1 = axs[0].imshow(torch.abs(self.A), interpolation='none', cmap='Blues') 147 | im1.set_clim(0, 1) 148 | axs[0].set_title("A") 149 | im2 = axs[1].imshow(torch.abs(self.L), interpolation='none', cmap='Blues') 150 | im2.set_clim(0, 1) 151 | axs[1].set_title("L") 152 | 153 | res = torch.abs(self.L@self.L.T - self.A) 154 | 155 | # show at points where A is non-zero 156 | # res = torch.where(torch.abs(self.A) > 0, res, torch.zeros_like(res)) 157 | 158 | im3 = axs[2].imshow(res, interpolation='none', cmap='Reds') 159 | im3.set_clim(0, 1) 160 | axs[2].set_title("L@L.T - A") 161 | 162 | # add colorbat 163 | fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.8, wspace=0.4, hspace=0.1) 164 | cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.8]) 165 | fig.colorbar(im3,cax=cb_ax) 166 | 167 | # share y-axis 168 | for ax in fig.get_axes(): 169 | ax.label_outer() 170 | 171 | # save as file 172 | plt.savefig(f"{self.folder}/chol_factorization_{self.method}_{i}.png") 173 | plt.close() 174 | 175 | def print_summary(self): 176 | for key, value in self.get_summary_dict().items(): 177 | print(f"{key}:\t{value}") 178 | print() 179 | 180 | def get_total_p_time(self): 181 | # needs to return an array with the total time required for the preconditioner 182 | # p_time + inv_time + overhead 183 | return [p + o for p, o in zip(self.p_times, self.overhead)] 184 | 185 | def get_summary_dict(self): 186 | # check where ch did 
not break down 187 | valid_samples = np.asarray(self.solver_iterations) > 0 188 | 189 | data = { 190 | f"time_{self.method}": np.mean(self.p_times, where=valid_samples), 191 | f"overhead_{self.method}": np.mean(self.overhead, where=valid_samples), 192 | f"{self.solver}_time_{self.method}": np.mean(self.solver_time, where=valid_samples), 193 | f"{self.solver}_iterations_{self.method}": np.mean(self.solver_iterations, where=valid_samples), 194 | f"total_time_{self.method}": np.mean(list(map(lambda x: x[0] + x[1], zip(self.get_total_p_time(), self.solver_time))), where=valid_samples), 195 | f"time-per-iter": np.sum(self.solver_time, where=valid_samples) / np.sum(self.solver_iterations, where=valid_samples), 196 | f"nnz_a_{self.method}": np.mean(self.nnz_a), 197 | f"nnz_p_{self.method}": np.mean(self.nnz_p), 198 | } 199 | 200 | # add information about failure runs... 201 | if np.sum(valid_samples) < len(self.solver_iterations): 202 | data = {**data, **{f"success_rate_{self.method}": np.sum(valid_samples) / len(self.solver_iterations)}} 203 | 204 | return data 205 | 206 | def save_results(self): 207 | fn = f"{self.folder}/test_{self.method}.npz" 208 | 209 | np.savez(fn, n=self.n, 210 | p_time=self.p_times, 211 | overhead_time=self.overhead, 212 | nnz_a=self.nnz_a, 213 | nnz_p=self.nnz_p, 214 | solver=self.solver, 215 | solver_time=self.solver_time, 216 | solver_iterations=self.solver_iterations, 217 | solver_error=np.asarray(self.solver_error, dtype="object"), 218 | solver_residual=np.asarray(self.solver_residual, dtype="object"), 219 | eig_distribution=np.asarray(self.distribution, dtype="object"), 220 | loss1=self.loss1, 221 | loss2=self.loss2) 222 | 223 | 224 | @dataclass 225 | class TrainResults: 226 | folder: str 227 | 228 | # training 229 | loss: List[float] = field(default_factory=list) 230 | grad_norm: List[float] = field(default_factory=list) 231 | times: List[float] = field(default_factory=list) 232 | 233 | # validation 234 | log_freq: int = 100 235 | val_loss: List[float] = field(default_factory=list) 236 | val_its: List[float] = field(default_factory=list) 237 | 238 | def log(self, loss, grad_norm, time): 239 | self.loss.append(loss) 240 | self.grad_norm.append(grad_norm) 241 | self.times.append(time) 242 | 243 | def log_val(self, val_loss, val_its): 244 | self.val_loss.append(val_loss) 245 | self.val_its.append(val_its) 246 | 247 | def save_results(self): 248 | fn = f"{self.folder}/training.npz" 249 | np.savez(fn, loss=self.loss, grad_norm=self.grad_norm, 250 | val_loss=self.val_loss, val_cond=self.val_its) 251 | 252 | 253 | def create_folder(folder=None): 254 | if folder is None: 255 | folder = f"./results/{os.path.basename(__file__).split('.')[0]}" 256 | 257 | if not os.path.exists(folder): 258 | os.makedirs(folder) 259 | 260 | return folder 261 | -------------------------------------------------------------------------------- /neuralif/loss.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | 4 | from apps.data import graph_to_matrix 5 | 6 | 7 | warnings.filterwarnings('ignore', '.*Sparse CSR tensor support is in beta state.*') 8 | 9 | 10 | def frobenius_loss(L, A, sparse=True): 11 | # * Cholesky decomposition style loss 12 | 13 | if type(L) is tuple: 14 | U = L[1] 15 | L = L[0] 16 | else: 17 | U = L.T 18 | 19 | if sparse: 20 | # Not directly supported in pyotrch: 21 | # https://github.com/pytorch/pytorch/issues/95169 22 | # https://github.com/rusty1s/pytorch_sparse/issues/45 23 | r = L@U - A 24 | return 
torch.norm(r) 25 | 26 | else: 27 | A = A.to_dense().squeeze() 28 | L = L.to_dense().squeeze() 29 | U = U.to_dense().squeeze() 30 | 31 | return torch.linalg.norm(L@U - A, ord="fro") 32 | 33 | 34 | def sketched_loss(L, A, c=None, normalized=False): 35 | # both cholesky and LU decomposition 36 | 37 | if type(L) is tuple: 38 | U = L[1] 39 | L = L[0] 40 | else: 41 | U = L.T 42 | 43 | eps = 1e-8 44 | 45 | z = torch.randn((A.shape[0], 1), device=L.device) 46 | 47 | # if normalized: 48 | # z = z / torch.linalg.norm(z) # z-vector also should have unit length 49 | 50 | est = L@(U@z) - A@z 51 | norm = torch.linalg.vector_norm(est, ord=2) # vector norm 52 | 53 | if normalized and c is None: 54 | norm = norm / torch.linalg.vector_norm(A@z, ord=2) 55 | elif normalized: 56 | norm = norm / (c + eps) 57 | 58 | return norm 59 | 60 | 61 | def supervised_loss(L, A, x): 62 | # Note: Ax = b 63 | 64 | if type(L) is tuple: 65 | U = L[1] 66 | L = L[0] 67 | else: 68 | U = L.T 69 | 70 | if x is None: 71 | # if x is None, we recompute the solution in every time step. 72 | with torch.no_grad(): 73 | b = torch.randn((A.shape[0], 1), device=L.device) 74 | x = torch.linalg.solve(A.to_dense(), b) 75 | else: 76 | b = A@x 77 | 78 | res = L@(U@x) - b 79 | return torch.linalg.vector_norm(res, ord=2) 80 | 81 | 82 | def dircet_min_loss(L, A, x): 83 | # get L and U factors 84 | if type(L) is tuple: 85 | U = L[1] 86 | L = L[0] 87 | else: 88 | U = L.T 89 | 90 | if x is None: 91 | # if x is None, we recompute the solution in every time step. 92 | with torch.no_grad(): 93 | b = torch.randn((A.shape[0], 1), device=L.device) 94 | x = torch.linalg.solve(A.to_dense(), b) 95 | else: 96 | b = A@x 97 | 98 | res = L@(U@x) 99 | return torch.linalg.vector_norm(res, ord=2) 100 | 101 | 102 | def combined_loss(L, A, x, w=1): 103 | # combined loss 104 | loss1 = sketched_loss(L, A) 105 | loss2 = supervised_loss(L, A, x) 106 | 107 | return w * loss1 + loss2 108 | 109 | 110 | def loss(output, data, config=None, **kwargs): 111 | 112 | # load the data 113 | with torch.no_grad(): 114 | A, b = graph_to_matrix(data) 115 | 116 | # compute loss 117 | if config is None: 118 | # this is the regular loss used to train NeuralIF 119 | l = sketched_loss(output, A, normalized=False) 120 | 121 | elif config == "normalized": 122 | l = sketched_loss(output, A, kwargs.get("c", None), normalized=True) 123 | 124 | elif config == "supervised": 125 | l = supervised_loss(output, A, data.s.squeeze()) 126 | 127 | elif config == "inverted": 128 | l = supervised_loss(output, A, None) 129 | 130 | elif config == "combined": 131 | l = combined_loss(output, A, data.s.squeeze()) 132 | 133 | elif config == "combined-supervised": 134 | l = combined_loss(output, A, None) 135 | 136 | elif config == "frobenius": 137 | l = frobenius_loss(output, A, sparse=False) 138 | 139 | else: 140 | raise ValueError("Invalid loss configuration") 141 | 142 | 143 | return l 144 | -------------------------------------------------------------------------------- /neuralif/models.py: -------------------------------------------------------------------------------- 1 | import numml.sparse as sp 2 | import torch 3 | import torch.nn as nn 4 | import torch_geometric 5 | import torch_geometric.nn as pyg 6 | from torch_geometric.nn import aggr 7 | from torch_geometric.utils import to_scipy_sparse_matrix 8 | from scipy.sparse import tril 9 | 10 | from neuralif.utils import TwoHop, gershgorin_norm 11 | 12 | 13 | ############################ 14 | # Layers # 15 | ############################ 16 | class 
GraphNet(nn.Module): 17 | # Follows roughly the outline of torch_geometric.nn.MessagePassing() 18 | # As shown in https://github.com/deepmind/graph_nets 19 | # Here is a helpful python implementation: 20 | # https://github.com/NVIDIA/GraphQSat/blob/main/gqsat/models.py 21 | # Also allows multirgaph GNN via edge_2_features 22 | def __init__(self, node_features, edge_features, global_features=0, hidden_size=0, 23 | aggregate="mean", activation="relu", skip_connection=False, edge_features_out=None): 24 | 25 | super().__init__() 26 | 27 | # different aggregation functions 28 | if aggregate == "sum": 29 | self.aggregate = aggr.SumAggregation() 30 | elif aggregate == "mean": 31 | self.aggregate = aggr.MeanAggregation() 32 | elif aggregate == "max": 33 | self.aggregate = aggr.MaxAggregation() 34 | elif aggregate == "softmax": 35 | self.aggregate = aggr.SoftmaxAggregation(learn=True) 36 | else: 37 | raise NotImplementedError(f"Aggregation '{aggregate}' not implemented") 38 | 39 | self.global_aggregate = aggr.MeanAggregation() 40 | 41 | add_edge_fs = 1 if skip_connection else 0 42 | edge_features_out = edge_features if edge_features_out is None else edge_features_out 43 | 44 | # Graph Net Blocks (see https://arxiv.org/pdf/1806.01261.pdf) 45 | self.edge_block = MLP([global_features + (edge_features + add_edge_fs) + (2 * node_features), 46 | hidden_size, 47 | edge_features_out], 48 | activation=activation) 49 | 50 | self.node_block = MLP([global_features + edge_features_out + node_features, 51 | hidden_size, 52 | node_features], 53 | activation=activation) 54 | 55 | # optional set of blocks for global GNN 56 | self.global_block = None 57 | if global_features > 0: 58 | self.global_block = MLP([edge_features_out + node_features + global_features, 59 | hidden_size, 60 | global_features], 61 | activation=activation) 62 | 63 | def forward(self, x, edge_index, edge_attr, g=None): 64 | row, col = edge_index 65 | 66 | if self.global_block is not None: 67 | assert g is not None, "Need global features for global block" 68 | 69 | # run the edge update and aggregate features 70 | edge_embedding = self.edge_block(torch.cat([torch.ones(x[row].shape[0], 1, device=x.device) * g, 71 | x[row], x[col], edge_attr], dim=1)) 72 | aggregation = self.aggregate(edge_embedding, row) 73 | 74 | 75 | agg_features = torch.cat([torch.ones(x.shape[0], 1, device=x.device) * g, x, aggregation], dim=1) 76 | node_embeddings = self.node_block(agg_features) 77 | 78 | # aggregate over all edges and nodes (always mean) 79 | mp_global_aggr = g 80 | edge_aggregation_global = self.global_aggregate(edge_embedding) 81 | node_aggregation_global = self.global_aggregate(node_embeddings) 82 | 83 | # compute the new global embedding 84 | # the old global feature is part of mp_global_aggr 85 | global_embeddings = self.global_block(torch.cat([node_aggregation_global, 86 | edge_aggregation_global, 87 | mp_global_aggr], dim=1)) 88 | 89 | return edge_embedding, node_embeddings, global_embeddings 90 | 91 | else: 92 | # update edge features and aggregate 93 | edge_embedding = self.edge_block(torch.cat([x[row], x[col], edge_attr], dim=1)) 94 | aggregation = self.aggregate(edge_embedding, row) 95 | agg_features = torch.cat([x, aggregation], dim=1) 96 | # update node features 97 | node_embeddings = self.node_block(agg_features) 98 | return edge_embedding, node_embeddings, None 99 | 100 | 101 | class MLP(nn.Module): 102 | def __init__(self, width, layer_norm=False, activation="relu", activate_final=False): 103 | super().__init__() 104 | width = 
list(filter(lambda x: x > 0, width)) 105 | assert len(width) >= 2, "Need at least one layer in the network!" 106 | 107 | lls = nn.ModuleList() 108 | for k in range(len(width)-1): 109 | lls.append(nn.Linear(width[k], width[k+1], bias=True)) 110 | if k != (len(width)-2) or activate_final: 111 | if activation == "relu": 112 | lls.append(nn.ReLU()) 113 | elif activation == "tanh": 114 | lls.append(nn.Tanh()) 115 | elif activation == "leakyrelu": 116 | lls.append(nn.LeakyReLU()) 117 | elif activation == "sigmoid": 118 | lls.append(nn.Sigmoid()) 119 | else: 120 | raise NotImplementedError(f"Activation '{activation}' not implemented") 121 | 122 | if layer_norm: 123 | lls.append(nn.LayerNorm(width[-1])) 124 | 125 | self.m = nn.Sequential(*lls) 126 | 127 | def forward(self, x): 128 | return self.m(x) 129 | 130 | 131 | class MP_Block(nn.Module): 132 | # L@L.T matrix multiplication graph layer 133 | # Aligns the computation of L@L.T - A with the learned updates 134 | def __init__(self, skip_connections, first, last, edge_features, node_features, global_features, hidden_size, **kwargs) -> None: 135 | super().__init__() 136 | 137 | # first and second aggregation 138 | if "aggregate" in kwargs and kwargs["aggregate"] is not None: 139 | aggr = kwargs["aggregate"] if len(kwargs["aggregate"]) == 2 else kwargs["aggregate"] * 2 140 | else: 141 | aggr = ["mean", "sum"] 142 | 143 | act = kwargs["activation"] if "activation" in kwargs else "relu" 144 | 145 | edge_features_in = 1 if first else edge_features 146 | edge_features_out = 1 if last else edge_features 147 | 148 | # We use 2 graph nets in order to operate on the upper and lower triangular parts of the matrix 149 | self.l1 = GraphNet(node_features=node_features, edge_features=edge_features_in, global_features=global_features, 150 | hidden_size=hidden_size, skip_connection=(not first and skip_connections), 151 | aggregate=aggr[0], activation=act, edge_features_out=edge_features) 152 | 153 | self.l2 = GraphNet(node_features=node_features, edge_features=edge_features, global_features=global_features, 154 | hidden_size=hidden_size, aggregate=aggr[1], activation=act, edge_features_out=edge_features_out) 155 | 156 | def forward(self, x, edge_index, edge_attr, global_features): 157 | edge_embedding, node_embeddings, global_features = self.l1(x, edge_index, edge_attr, g=global_features) 158 | 159 | # flip row and column indices 160 | edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0) 161 | edge_embedding, node_embeddings, global_features = self.l2(node_embeddings, edge_index, edge_embedding, g=global_features) 162 | 163 | return edge_embedding, node_embeddings, global_features 164 | 165 | 166 | ############################ 167 | # Networks # 168 | ############################ 169 | class NeuralPCG(nn.Module): 170 | def __init__(self, **kwargs): 171 | # NeuralPCG follows the Encoder-Process-Decoder architecture 172 | super().__init__() 173 | 174 | # Network hyper-parameters 175 | self._latent_size = kwargs["latent_size"] 176 | self._num_layers = 2 177 | self._message_passing_steps = kwargs["message_passing_steps"] 178 | 179 | # NeuralPCG uses constant number of features for input and output 180 | self._node_features = 1 181 | self._edge_features = 1 182 | 183 | # Pre-network transformations 184 | self.transforms = None 185 | 186 | # Encoder - Process - Decoder architecture 187 | self.encoder_nodes = MLP([self._node_features] + [self._latent_size] * self._num_layers) 188 | self.encoder_edges = MLP([self._edge_features] + [self._latent_size] * 
self._num_layers) 189 | 190 | # decoder do not have a layer norm 191 | self.decoder_edges = MLP([self._latent_size] * self._num_layers + [1]) 192 | 193 | # message passing layers 194 | self.message_passing = nn.ModuleList([GraphNet(self._latent_size, self._latent_size, 195 | hidden_size=self._latent_size, 196 | aggregate="mean") 197 | for _ in range(self._message_passing_steps)]) 198 | 199 | def forward(self, data): 200 | if self.transforms: 201 | data = self.transforms(data) 202 | 203 | x_nodes, x_edges, edge_index = data.x, data.edge_attr, data.edge_index 204 | 205 | # save diag elements for later 206 | diag_idx = edge_index[0] == edge_index[1] 207 | diag_values = data.edge_attr[diag_idx].clone() 208 | 209 | latent_edges = self.encoder_edges(x_edges) 210 | latent_nodes = self.encoder_nodes(x_nodes) 211 | 212 | for message_passing_layer in self.message_passing: 213 | latent_edges, latent_nodes, _ = message_passing_layer(latent_nodes, edge_index, latent_edges) 214 | 215 | # Convert to lower triangular part of a matrix 216 | decoded_edges = self.decoder_edges(latent_edges) 217 | 218 | return self.transform_output_matrix(diag_idx, diag_values, x_nodes, edge_index, decoded_edges) 219 | 220 | def transform_output_matrix(self, diag_idx, diag_vals, node_x, edge_index, edge_values): 221 | # set the diagonal elements 222 | # the diag element gets duplicated later so we need to divide by 2 223 | edge_values[diag_idx] = 0.5 * torch.sqrt(diag_vals) 224 | size = node_x.size()[0] 225 | 226 | if torch.is_inference_mode_enabled(): 227 | 228 | # use scipy to symmetrize output 229 | m = to_scipy_sparse_matrix(edge_index, edge_values) 230 | m = m + m.T 231 | m = tril(m) 232 | 233 | # efficient sparse numml format 234 | l = sp.SparseCSRTensor(m) 235 | u = sp.SparseCSRTensor(m.T) 236 | 237 | return l, u, None 238 | 239 | else: 240 | # symmetrize the output by stacking things! 
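            # each entry below is emitted twice -- once as (i, j) and once as (j, i) --
            # and coalesce() sums duplicate coordinates: off-diagonal predictions are
            # symmetrised (value(i, j) + value(j, i)) and the diagonal is doubled,
            # which is exactly why it was pre-scaled by 0.5 * sqrt(a_ii) above.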
241 | transpose_index = torch.stack([edge_index[1], edge_index[0]], dim=0) 242 | 243 | sym_value = torch.cat([edge_values, edge_values]) 244 | sym_index = torch.cat([edge_index, transpose_index], dim=1) 245 | sym_value = sym_value.squeeze() 246 | 247 | # return only lower triangular part 248 | m = sym_index[0] <= sym_index[1] 249 | 250 | t = torch.sparse_coo_tensor(sym_index[:, m], sym_value[m], 251 | size=(size, size)) 252 | t = t.coalesce() 253 | 254 | return t, None, None 255 | 256 | 257 | class PreCondNet(nn.Module): 258 | # BASELINE MODEL 259 | # No splitting of the matrix into lower and upper part for alignment 260 | # Used for the ablation study 261 | def __init__(self, **kwargs) -> None: 262 | super().__init__() 263 | 264 | self.global_features = kwargs["global_features"] 265 | self.latent_size = kwargs["latent_size"] 266 | # node features are augmented with local degree profile 267 | self.augment_node_features = kwargs["augment_nodes"] 268 | 269 | num_node_features = 8 if self.augment_node_features else 1 270 | message_passing_steps = kwargs["message_passing_steps"] 271 | 272 | self.skip_connections = kwargs["skip_connections"] 273 | 274 | # create the layers 275 | self.mps = torch.nn.ModuleList() 276 | for l in range(message_passing_steps): 277 | self.mps.append(GraphNet(num_node_features, 278 | edge_features=1, 279 | hidden_size=self.latent_size, 280 | skip_connection=(l > 0 and self.skip_connections))) 281 | 282 | def forward(self, data): 283 | 284 | if self.augment_node_features: 285 | data = augment_features(data) 286 | 287 | # get the input data 288 | edge_embedding = data.edge_attr 289 | node_embedding = data.x 290 | edge_index = data.edge_index 291 | 292 | # copy the input data (only edges of original matrix A) for skip connections 293 | a_edges = edge_embedding.clone() 294 | 295 | # compute the output of the network 296 | for i, layer in enumerate(self.mps): 297 | if i != 0 and self.skip_connections: 298 | edge_embedding = torch.cat([edge_embedding, a_edges], dim=1) 299 | 300 | edge_embedding, node_embedding, _ = layer(node_embedding, edge_index, edge_embedding) 301 | 302 | # transform the output into a matrix 303 | return self.transform_output_matrix(node_embedding, edge_index, edge_embedding) 304 | 305 | def transform_output_matrix(self, node_x, edge_index, edge_values): 306 | # force diagonal to be positive (via activation function) 307 | diag = edge_index[0] == edge_index[1] 308 | edge_values[diag] = torch.sqrt(torch.exp(edge_values[diag])) 309 | edge_values = edge_values.squeeze() 310 | 311 | size = node_x.size()[0] 312 | 313 | if torch.is_inference_mode_enabled(): 314 | # use scipy to symmetrize output 315 | m = to_scipy_sparse_matrix(edge_index, edge_values) 316 | m = m + m.T 317 | m = tril(m) 318 | 319 | # efficient sparse numml format 320 | l = sp.SparseCSRTensor(m) 321 | u = sp.SparseCSRTensor(m.T) 322 | 323 | # Return upper and lower matrix l and u 324 | return l, u, None 325 | 326 | else: 327 | # symmetrize the output 328 | # we basicially just stack the indices of the matrix and it's transpose 329 | # when coalesce the result, these results get summed up. 
330 | transpose_index = torch.stack([edge_index[1], edge_index[0]], dim=0) 331 | 332 | sym_value = torch.cat([edge_values, edge_values]) 333 | sym_index = torch.cat([edge_index, transpose_index], dim=1) 334 | 335 | # find lower triangular indices 336 | m = sym_index[0] <= sym_index[1] 337 | 338 | # return only lower triangular part of the data 339 | t = torch.sparse_coo_tensor(sym_index[:, m], sym_value[m], 340 | size=(size, size)) 341 | 342 | # take care of duplicate values (to force the output to be symmetric) 343 | t = t.coalesce() 344 | 345 | return t, None, None 346 | 347 | 348 | class NeuralIF(nn.Module): 349 | # Neural Incomplete factorization 350 | def __init__(self, drop_tol=0, **kwargs) -> None: 351 | super().__init__() 352 | 353 | self.global_features = kwargs["global_features"] 354 | self.latent_size = kwargs["latent_size"] 355 | # node features are augmented with local degree profile 356 | self.augment_node_features = kwargs["augment_nodes"] 357 | 358 | num_node_features = 8 if self.augment_node_features else 1 359 | message_passing_steps = kwargs["message_passing_steps"] 360 | 361 | # edge feature representation in the latent layers 362 | edge_features = kwargs.get("edge_features", 1) 363 | 364 | self.skip_connections = kwargs["skip_connections"] 365 | 366 | self.mps = torch.nn.ModuleList() 367 | for l in range(message_passing_steps): 368 | # skip connections are added to all layers except the first one 369 | self.mps.append(MP_Block(skip_connections=self.skip_connections, 370 | first=l==0, 371 | last=l==(message_passing_steps-1), 372 | edge_features=edge_features, 373 | node_features=num_node_features, 374 | global_features=self.global_features, 375 | hidden_size=self.latent_size, 376 | activation=kwargs["activation"], 377 | aggregate=kwargs["aggregate"])) 378 | 379 | # node decodings 380 | self.node_decoder = MLP([num_node_features, self.latent_size, 1]) if kwargs["decode_nodes"] else None 381 | 382 | # diag-aggregation for normalization of rows 383 | self.normalize_diag = kwargs["normalize_diag"] if "normalize_diag" in kwargs else False 384 | self.diag_aggregate = aggr.SumAggregation() 385 | 386 | # normalization 387 | self.graph_norm = pyg.norm.GraphNorm(num_node_features) if ("graph_norm" in kwargs and kwargs["graph_norm"]) else None 388 | 389 | # drop tolerance and additional fill-ins and more sparsity 390 | self.tau = drop_tol 391 | self.two = kwargs.get("two_hop", False) 392 | 393 | def forward(self, data): 394 | # ! data could be batched here...(not implemented) 395 | 396 | if self.augment_node_features: 397 | data = augment_features(data, skip_rhs=True) 398 | 399 | # add additional edges to the data 400 | if self.two: 401 | data = TwoHop()(data) 402 | 403 | # * in principle it is possible to integrate reordering here. 
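        # a hypothetical sketch (not implemented in this model) of what such a
        # reordering could look like, assuming scipy is available -- reverse
        # Cuthill-McKee yields a fill-reducing node permutation:
        #
        #   from scipy.sparse.csgraph import reverse_cuthill_mckee
        #   A_sp = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes).tocsr()
        #   perm = torch.as_tensor(reverse_cuthill_mckee(A_sp, symmetric_mode=True).astype("int64"))
        #   inv = torch.argsort(perm)                 # inv[old] = new node label
        #   data.x = data.x[perm]                     # reorder node features
        #   data.edge_index = inv[data.edge_index]    # relabel edge endpoints
        #   (device handling and batching omitted in this sketch)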
404 | 405 | data = ToLowerTriangular()(data) 406 | 407 | # get the input data 408 | edge_embedding = data.edge_attr 409 | l_index = data.edge_index 410 | 411 | if self.graph_norm is not None: 412 | node_embedding = self.graph_norm(data.x, batch=data.batch) 413 | else: 414 | node_embedding = data.x 415 | 416 | # copy the input data (only edges of original matrix A) 417 | a_edges = edge_embedding.clone() 418 | 419 | if self.global_features > 0: 420 | global_features = torch.zeros((1, self.global_features), device=data.x.device, requires_grad=False) 421 | # feature ideas: nnz, 1-norm, inf-norm col/row var, min/max variability, avg distances to nnz 422 | else: 423 | global_features = None 424 | 425 | # compute the output of the network 426 | for i, layer in enumerate(self.mps): 427 | if i != 0 and self.skip_connections: 428 | edge_embedding = torch.cat([edge_embedding, a_edges], dim=1) 429 | 430 | edge_embedding, node_embedding, global_features = layer(node_embedding, l_index, edge_embedding, global_features) 431 | 432 | # transform the output into a matrix 433 | return self.transform_output_matrix(node_embedding, l_index, edge_embedding, a_edges) 434 | 435 | def transform_output_matrix(self, node_x, edge_index, edge_values, a_edges): 436 | # force diagonal to be positive 437 | diag = edge_index[0] == edge_index[1] 438 | 439 | # normalize diag such that it has zero residual 440 | if self.normalize_diag: 441 | # copy the diag of matrix A 442 | a_diag = a_edges[diag] 443 | 444 | # compute the row norm 445 | square_values = torch.pow(edge_values, 2) 446 | aggregated = self.diag_aggregate(square_values, edge_index[0]) 447 | 448 | # now, we renormalize the edge values such that they are the square root of the original value... 449 | edge_values = torch.sqrt(a_diag[edge_index[0]]) * edge_values / torch.sqrt(aggregated[edge_index[0]]) 450 | 451 | else: 452 | # otherwise, just take the edge values as they are... 453 | # but take the square root as it is numerically better 454 | # edge_values[diag] = torch.exp(edge_values[diag]) 455 | edge_values[diag] = torch.sqrt(torch.exp(edge_values[diag])) 456 | 457 | # node decoder 458 | node_output = self.node_decoder(node_x).squeeze() if self.node_decoder is not None else None 459 | 460 | # ! this if should only be activated when the model is in production!! 461 | if torch.is_inference_mode_enabled(): 462 | 463 | # we can decide to remove small elements during inference from the preconditioner matrix 464 | if self.tau != 0: 465 | small_value = (torch.abs(edge_values) <= self.tau).squeeze() 466 | 467 | # small value and not diagonal 468 | elems = torch.logical_and(small_value, torch.logical_not(diag)) 469 | 470 | # might be able to do this easily! 471 | edge_values[elems] = 0 472 | 473 | # remove zeros from the sparse representation 474 | filt = (edge_values != 0).squeeze() 475 | edge_values = edge_values[filt] 476 | edge_index = edge_index[:, filt] 477 | 478 | # ! this is the way to go!! 
479 | # Doing pytorch -> scipy -> numml is a lot faster than pytorch -> numml on CPU 480 | # On GPU it is faster to go to pytorch -> numml -> CPU 481 | 482 | # convert to scipy sparse matrix 483 | # m = to_scipy_sparse_matrix(edge_index, matrix_values) 484 | m = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(), 485 | size=(node_x.size()[0], node_x.size()[0])) 486 | # type=torch.double) 487 | 488 | # produce L and U seperatly 489 | l = sp.SparseCSRTensor(m) 490 | u = sp.SparseCSRTensor(m.T) 491 | 492 | return l, u, node_output 493 | 494 | else: 495 | # For training and testing (computing regular losses for examples.) 496 | # does not need to be performance optimized! 497 | # use torch sparse directly 498 | t = torch.sparse_coo_tensor(edge_index, edge_values.squeeze(), 499 | size=(node_x.size()[0], node_x.size()[0])) 500 | 501 | # normalized l1 norm is best computed here! 502 | # l2_nn = torch.linalg.norm(edge_values, ord=2) 503 | l1_penalty = torch.sum(torch.abs(edge_values)) / len(edge_values) 504 | 505 | return t, l1_penalty, node_output 506 | 507 | 508 | class LearnedLU(nn.Module): 509 | 510 | def __init__(self, *args, **kwargs) -> None: 511 | super().__init__() 512 | 513 | self.global_features = kwargs["global_features"] 514 | self.augment_node_features = kwargs["augment_nodes"] 515 | 516 | num_node_features = 8 if self.augment_node_features else 1 517 | 518 | message_passing_steps = kwargs["message_passing_steps"] 519 | self.skip_connections = kwargs["skip_connections"] 520 | self.layers = nn.ModuleList() 521 | 522 | # use a smooth activation function for the diagonal during training 523 | self.smooth_activation = kwargs.get("smooth_activation", True) 524 | self.epsilon = kwargs.get("epsilon", 0.001) 525 | 526 | num_edge_features = 32 527 | hidden_size = 32 528 | 529 | for l in range(message_passing_steps): 530 | first_layer = l == 0 531 | last_layer = l == (message_passing_steps - 1) 532 | 533 | self.layers.append( 534 | GraphNet( 535 | skip_connection=(l != 0 and self.skip_connections), 536 | edge_features=2 if first_layer else num_edge_features, 537 | edge_features_out=1 if last_layer else num_edge_features, 538 | hidden_size=hidden_size, 539 | node_features=num_node_features, 540 | global_features=self.global_features 541 | ) 542 | ) 543 | 544 | def forward(self, data): 545 | a_edges = data.edge_attr.clone() 546 | 547 | if self.augment_node_features: 548 | data = augment_features(data) 549 | 550 | # add remaining self loops 551 | data.edge_index, data.edge_attr = torch_geometric.utils.add_remaining_self_loops(data.edge_index, data.edge_attr) 552 | 553 | edge_embedding = data.edge_attr 554 | node_embedding = data.x 555 | edge_index = data.edge_index 556 | 557 | 558 | # add positional encoding features 559 | row, col = data.edge_index 560 | lower_mask = row > col 561 | upper_mask = row < col 562 | additional_edge_feature = torch.zeros_like(a_edges) 563 | additional_edge_feature[lower_mask] = -1 564 | additional_edge_feature[upper_mask] = 1 565 | edge_embedding = torch.cat([edge_embedding, additional_edge_feature], dim=1) 566 | 567 | if self.global_features > 0: 568 | global_features = torch.zeros((1, self.global_features), device=data.x.device, requires_grad=False) 569 | else: 570 | global_features = None 571 | 572 | for i, layer in enumerate(self.layers): 573 | if i != 0 and self.skip_connections: 574 | edge_embedding = torch.cat([edge_embedding, a_edges], dim=1) 575 | 576 | edge_embedding, node_embedding, global_features = layer(node_embedding, edge_index, edge_embedding, 
global_features) 577 | 578 | return self.transform_output_matrix(a_edges, node_embedding, edge_index, edge_embedding) 579 | 580 | def transform_output_matrix(self, a_edges, node_x, edge_index, edge_values): 581 | """ 582 | Transform the output into L and U matrices. 583 | 584 | Parameters: 585 | a_edges (Tensor): Original edge attributes. 586 | node_x (Tensor): Node features. 587 | edge_index (Tensor): Edge indices. 588 | edge_values (Tensor): Edge values. 589 | tolerance (float): Tolerance for small values. 590 | 591 | Returns: 592 | tuple: Lower and upper matrices, and L1 norm. 593 | """ 594 | 595 | @torch.no_grad() 596 | def step_activation(x, eps=0.05): 597 | # activation function to enfore the diagonal to be non-zero 598 | # - replace small values with epsilon 599 | # - replace zeros with epsilon 600 | s = torch.where(torch.abs(x) > eps, x, torch.sign(x) * eps) 601 | return torch.where(s == 0, eps, s) 602 | 603 | def smooth_activation(x, eps=0.05): 604 | return x * (1 + torch.exp(-torch.abs((4 / eps) * x) + 2)) 605 | 606 | # create masks to split the edge values 607 | lower_mask = edge_index[0] >= edge_index[1] 608 | upper_mask = edge_index[0] <= edge_index[1] 609 | diag_mask = edge_index[0] == edge_index[1] 610 | 611 | # create values and indices for lower part 612 | lower_indices = edge_index[:, lower_mask] 613 | lower_values = edge_values[lower_mask][:, 0].squeeze() 614 | 615 | # create values and indices for upper part 616 | upper_indices = edge_index[:, upper_mask] 617 | upper_values = edge_values[upper_mask][:, -1].squeeze() 618 | 619 | # enforce diagonal to be unit valued for the upper part 620 | upper_values[diag_mask[upper_mask]] = 1 621 | 622 | # appy activation function to lower part 623 | if torch.is_inference_mode_enabled(): 624 | lower_values[diag_mask[lower_mask]] = step_activation(lower_values[diag_mask[lower_mask]], eps=self.epsilon) 625 | elif self.smooth_activation: 626 | lower_values[diag_mask[lower_mask]] = smooth_activation(lower_values[diag_mask[lower_mask]], eps=self.epsilon) 627 | 628 | # construct L and U matrix 629 | n = node_x.size()[0] 630 | 631 | # convert to lower and upper matrices 632 | lower_matrix = torch.sparse_coo_tensor(lower_indices, lower_values.squeeze(), size=(n, n)) 633 | upper_matrix = torch.sparse_coo_tensor(upper_indices, upper_values.squeeze(), size=(n, n)) 634 | 635 | if torch.is_inference_mode_enabled(): 636 | # convert to numml format 637 | l = sp.SparseCSRTensor(lower_matrix) 638 | u = sp.SparseCSRTensor(upper_matrix) 639 | 640 | return l, u, None 641 | 642 | else: 643 | # min diag element as a regularization term 644 | bound = torch.min(torch.abs(lower_values[diag_mask[lower_mask]])) 645 | 646 | return (lower_matrix, upper_matrix), bound, None 647 | 648 | 649 | ############################ 650 | # HELPERS # 651 | ############################ 652 | def augment_features(data, skip_rhs=False): 653 | # transform nodes to include more features 654 | 655 | if skip_rhs: 656 | # use instead notde position as an input feature! 
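        # NeuralIF calls this with skip_rhs=True: the right-hand side stored in
        # data.x is then replaced by the node index before the degree and
        # dominance features are appended below.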
657 | data.x = torch.arange(data.x.size()[0], device=data.x.device).unsqueeze(1) 658 | 659 | data = torch_geometric.transforms.LocalDegreeProfile()(data) 660 | 661 | # diagonal dominance and diagonal decay from the paper 662 | row, col = data.edge_index 663 | diag = (row == col) 664 | diag_elem = torch.abs(data.edge_attr[diag]) 665 | # remove diagonal elements by setting them to zero 666 | non_diag_elem = data.edge_attr.clone() 667 | non_diag_elem[diag] = 0 668 | 669 | row_sums = aggr.SumAggregation()(torch.abs(non_diag_elem), row) 670 | alpha = diag_elem / row_sums 671 | row_dominance_feature = alpha / (alpha + 1) 672 | row_dominance_feature = torch.nan_to_num(row_dominance_feature, nan=1.0) 673 | 674 | # compute diagonal decay features 675 | row_max = aggr.MaxAggregation()(torch.abs(non_diag_elem), row) 676 | alpha = diag_elem / row_max 677 | row_decay_feature = alpha / (alpha + 1) 678 | row_decay_feature = torch.nan_to_num(row_decay_feature, nan=1.0) 679 | 680 | data.x = torch.cat([data.x, row_dominance_feature, row_decay_feature], dim=1) 681 | 682 | return data 683 | 684 | 685 | class ToLowerTriangular(torch_geometric.transforms.BaseTransform): 686 | def __init__(self, inplace=False): 687 | self.inplace = inplace 688 | 689 | def __call__(self, data, order=None): 690 | if not self.inplace: 691 | data = data.clone() 692 | 693 | # TODO: if order is given use that one instead 694 | if order is not None: 695 | raise NotImplementedError("Custom ordering not yet implemented...") 696 | 697 | # transform the data into lower triag graph 698 | # this should be a data transformation (maybe?) 699 | rows, cols = data.edge_index[0], data.edge_index[1] 700 | fil = cols <= rows 701 | l_index = data.edge_index[:, fil] 702 | edge_embedding = data.edge_attr[fil] 703 | 704 | data.edge_index, data.edge_attr = l_index, edge_embedding 705 | return data 706 | -------------------------------------------------------------------------------- /neuralif/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | import numpy as np 6 | import torch 7 | import torch_geometric 8 | import scipy 9 | import scipy.sparse 10 | 11 | import time 12 | import os 13 | import psutil 14 | 15 | from torch_geometric.utils import coalesce, remove_self_loops, to_torch_coo_tensor, to_edge_index 16 | 17 | 18 | class TwoHop(torch_geometric.transforms.BaseTransform): 19 | 20 | def forward(self, data): 21 | assert data.edge_index is not None 22 | edge_index, edge_attr = data.edge_index, data.edge_attr 23 | num_nodes = data.num_nodes 24 | 25 | adj = to_torch_coo_tensor(edge_index, size=num_nodes) 26 | 27 | adj = adj @ adj 28 | 29 | edge_index2, _ = to_edge_index(adj) 30 | edge_index2, _ = remove_self_loops(edge_index2) 31 | 32 | edge_index = torch.cat([edge_index, edge_index2], dim=1) 33 | 34 | if edge_attr is not None: 35 | # We treat newly added edge features as "zero-features": 36 | edge_attr2 = edge_attr.new_zeros(edge_index2.size(1), 37 | *edge_attr.size()[1:]) 38 | edge_attr = torch.cat([edge_attr, edge_attr2], dim=0) 39 | 40 | data.edge_index, data.edge_attr = coalesce(edge_index, edge_attr, 41 | num_nodes) 42 | 43 | return data 44 | 45 | 46 | def gradient_clipping(model, clip=None): 47 | # track the gradient norm 48 | total_norm = 0.0 49 | for p in model.parameters(): 50 | if p.grad is not None: 51 | param_norm = p.grad.detach().data.norm(2) 52 | total_norm += param_norm.item() ** 2 53 | 54 | if clip is not None: 55 
| torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 56 | 57 | return total_norm 58 | 59 | 60 | def save_dict_to_file(dictionary, filename): 61 | with open(filename, 'w') as file: 62 | file.write(json.dumps(dictionary)) 63 | 64 | 65 | def count_parameters(model): 66 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 67 | 68 | 69 | def num_non_zeros(P): 70 | return torch.linalg.norm(P.flatten(), ord=0) 71 | 72 | 73 | def frob_norm_sparse(data): 74 | return torch.pow(torch.sum(torch.pow(data, 2)), 0.5) 75 | 76 | 77 | def filter_small_values(A, threshold=1e-5): 78 | # only keep the values above threshold 79 | return torch.where(torch.abs(A) < threshold, torch.zeros_like(A), A) 80 | 81 | 82 | def plot_graph(data): 83 | # transofrm to networkx 84 | g = torch_geometric.utils.to_networkx(data, to_undirected=True) 85 | # remove the self loops for readability 86 | filtered_edges = list(filter(lambda x: x[0] != x[1], g.edges())) 87 | nx.draw(g, edgelist=filtered_edges) 88 | plt.show() 89 | 90 | 91 | def print_graph_statistics(data): 92 | print(data.validate()) 93 | print(data.is_directed()) 94 | print(data.num_nodes) 95 | 96 | 97 | def elapsed_since(start): 98 | return time.strftime("%H:%M:%S", time.gmtime(time.time() - start)) 99 | 100 | 101 | def get_process_memory(): 102 | process = psutil.Process(os.getpid()) 103 | mem_info = process.memory_info() 104 | return mem_info.rss 105 | 106 | 107 | def profile(func): 108 | def wrapper(*args, **kwargs): 109 | mem_before = get_process_memory() 110 | start = time.time() 111 | result = func(*args, **kwargs) 112 | elapsed_time = elapsed_since(start) 113 | mem_after = get_process_memory() 114 | print("{}: memory before: {:,}, after: {:,}, consumed: {:,}; exec time: {}".format( 115 | func.__name__, 116 | mem_before, mem_after, mem_after - mem_before, 117 | elapsed_time)) 118 | return result 119 | return wrapper 120 | 121 | 122 | def test_spd(A): 123 | # the matrix should be symmetric positive definite 124 | np.testing.assert_allclose(A, A.T, atol=1e-6) 125 | assert np.linalg.eigh(A)[0].min() > 0 126 | 127 | 128 | def kA_bound(cond, k): 129 | return 2 * ((torch.sqrt(cond) - 1) / (torch.sqrt(cond) + 1)) ** k 130 | 131 | 132 | def eigenval_distribution(P, A): 133 | if P == None: 134 | return torch.linalg.eigh(A)[0] 135 | else: 136 | return torch.linalg.eigh(P@A@P.T)[0] 137 | 138 | 139 | def condition_number(P, A, invert=False, split=True): 140 | if invert: 141 | if split: 142 | P = torch.linalg.solve_triangular(P, torch.eye(P.size()[0], device=P.device, requires_grad=False), upper=False) 143 | else: 144 | P = torch.linalg.inv(P) 145 | 146 | if split: 147 | # P.T@A@P is wrong! 148 | # Not sure what the difference is between P@A@P.T and P.T@P@A? 
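        # For a split preconditioner with P = L^{-1}, P@A@P.T is the symmetric
        # congruence transform used by PCG, while P.T@P@A = P.T (P@A@P.T) P^{-T}
        # is merely similar to it: both share the same eigenvalues, but only the
        # symmetric form has its condition number equal to the eigenvalue ratio,
        # since torch.linalg.cond is computed from singular values.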
149 | return torch.linalg.cond(P@A@P.T) 150 | else: 151 | return torch.linalg.cond(P@A) 152 | 153 | 154 | def l1_output_norm(P): 155 | # normalized output norm 156 | return torch.sum(torch.abs(P)) / P.size()[0] 157 | 158 | 159 | def rademacher(n, m=1, device=None): 160 | if device == None: 161 | return torch.sign(torch.randn(n, m, requires_grad=False)) 162 | else: 163 | return torch.sign(torch.randn(n, m, device=device, requires_grad=False)) 164 | 165 | 166 | def torch_sparse_to_scipy(A): 167 | A = A.coalesce() 168 | d = A.values().squeeze().numpy() 169 | i, j = A.indices().numpy() 170 | A_s = scipy.sparse.coo_matrix((d, (i, j))) 171 | 172 | return A_s 173 | 174 | 175 | def gershgorin_norm(A, graph=False): 176 | 177 | if graph: 178 | row, col = A.edge_index 179 | agg = torch_geometric.nn.aggr.SumAggregation() 180 | 181 | row_sum = agg(torch.abs(A.edge_attr), row) 182 | col_sum = agg(torch.abs(A.edge_attr), col) 183 | 184 | else: 185 | # compute the normalization factor 186 | n = A.size()[0] 187 | 188 | # compute row and column sums 189 | row_sum = torch.sum(torch.abs(A.to_dense()), dim=1) 190 | col_sum = torch.sum(torch.abs(A.to_dense()), dim=0) 191 | 192 | gamma = torch.min(torch.max(row_sum), torch.max(col_sum)) 193 | return gamma 194 | 195 | 196 | time_function = lambda: time.perf_counter() 197 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import datetime 4 | 5 | import numpy as np 6 | import scipy 7 | import scipy.sparse 8 | import torch 9 | import json 10 | 11 | from krylov.cg import conjugate_gradient, preconditioned_conjugate_gradient 12 | from krylov.gmres import gmres 13 | from krylov.preconditioner import get_preconditioner 14 | 15 | from neuralif.models import NeuralIF, NeuralPCG, PreCondNet, LearnedLU 16 | from neuralif.utils import torch_sparse_to_scipy, time_function 17 | from neuralif.logger import TestResults 18 | 19 | from apps.data import matrix_to_graph_sparse, get_dataloader 20 | 21 | 22 | @torch.inference_mode() 23 | def test(model, test_loader, device, folder, save_results=False, dataset="random", solver="cg"): 24 | 25 | if save_results: 26 | os.makedirs(folder, exist_ok=False) 27 | 28 | print() 29 | print(f"Test:\t{len(test_loader.dataset)} samples") 30 | print(f"Solver:\t{solver} solver") 31 | print() 32 | 33 | # Two modes: either test baselines or the learned preconditioner 34 | if model is None: 35 | methods = ["baseline", "jacobi", "ilu"] 36 | else: 37 | assert solver in ["cg", "gmres"], "Data-driven method only works with CG or GMRES" 38 | methods = ["learned"] 39 | 40 | # using direct solver 41 | if solver == "direct": 42 | methods = ["direct"] 43 | 44 | for method in methods: 45 | print(f"Testing {method} preconditioner") 46 | 47 | test_results = TestResults(method, dataset, folder, 48 | model_name= f"\n{model.__class__.__name__}" if method == "learned" else "", 49 | target=1e-6, 50 | solver=solver) 51 | 52 | for sample, data in enumerate(test_loader): 53 | plot = save_results and sample == (len(test_loader.dataset) - 1) 54 | 55 | # Getting the preconditioners 56 | start = time_function() 57 | 58 | data = data.to(device) 59 | prec = get_preconditioner(data, method, model=model) 60 | 61 | # Get properties... 
62 | p_time = prec.time 63 | breakdown = prec.breakdown 64 | nnzL = prec.nnz 65 | 66 | stop = time_function() 67 | 68 | A = torch.sparse_coo_tensor(data.edge_index, data.edge_attr.squeeze(), 69 | dtype=torch.float64, 70 | requires_grad=False).to("cpu").to_sparse_csr() 71 | 72 | b = data.x[:, 0].squeeze().to("cpu").to(torch.float64) 73 | b_norm = torch.linalg.norm(b) 74 | 75 | # we assume that b is unit norm wlog 76 | b = b / b_norm 77 | solution = data.s.to("cpu").to(torch.float64).squeeze() / b_norm if hasattr(data, "s") else None 78 | 79 | overhead = (stop - start) - (p_time) 80 | 81 | # RUN CONJUGATE GRADIENT 82 | start_solver = time_function() 83 | 84 | solver_settings = { 85 | "max_iter": 10_000, 86 | "x0": None 87 | } 88 | 89 | if breakdown: 90 | res = [] 91 | 92 | elif solver == "direct": 93 | 94 | # convert to sparse matrix (scipy) 95 | A_ = torch.sparse_coo_tensor(data.edge_index, data.edge_attr.squeeze(), 96 | dtype=torch.float64, requires_grad=False) 97 | 98 | # scipy sparse... 99 | A_s = torch_sparse_to_scipy(A_).tocsr() 100 | 101 | # override start time 102 | start_solver = time_function() 103 | 104 | dense = False 105 | 106 | if dense: 107 | _ = scipy.linalg.solve(A_.to_dense().numpy(), b.numpy(), assume_a='pos') 108 | else: 109 | _ = scipy.sparse.linalg.spsolve(A_s, b.numpy()) 110 | 111 | # dummy values... 112 | res = [(torch.Tensor([0]), torch.Tensor([0]))] * 2 113 | 114 | elif solver == "cg" and method == "baseline": 115 | # no preconditioner required when using baseline method 116 | res, _ = conjugate_gradient(A, b, x_true=solution, 117 | rtol=test_results.target, **solver_settings) 118 | 119 | elif solver == "cg": 120 | res, _ = preconditioned_conjugate_gradient(A, b, M=prec, x_true=solution, 121 | rtol=test_results.target, **solver_settings) 122 | 123 | elif solver == "gmres": 124 | 125 | res, _ = gmres(A, b, M=prec, x_true=solution, 126 | **solver_settings, plot=plot, 127 | atol=test_results.target, 128 | left=False) 129 | 130 | stop_solver = time_function() 131 | solver_time = (stop_solver - start_solver) 132 | 133 | # LOGGING 134 | test_results.log_solve(A.shape[0], solver_time, len(res) - 1, 135 | np.array([r[0].item() for r in res]), 136 | np.array([r[1].item() for r in res]), 137 | p_time, overhead) 138 | 139 | # ANALYSIS of the preconditioner and its effects! 
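            # fill-in is measured by comparing nnz(A) against nnz(L); the optional
            # svd block below additionally computes the singular values of
            # A @ prec.get_inverse() as well as the Frobenius errors
            # ||P - A||_F and ||P @ A^{-1} - I||_F of the preconditioner matrix P.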
140 | nnzA = A._nnz() 141 | 142 | test_results.log(nnzA, nnzL, plot=plot) 143 | 144 | svd = False 145 | if svd: 146 | # compute largest and smallest singular value 147 | Pinv = prec.get_inverse() 148 | APinv = A.to_dense() @ Pinv 149 | 150 | # compute the singular values of the preconditioned matrix 151 | S = torch.linalg.svdvals(APinv) 152 | 153 | # print the smallest and largest singular value 154 | test_results.log_eigenval_dist(S, plot=plot) 155 | 156 | # compute the loss of the preconditioner 157 | p = prec.get_p_matrix() 158 | loss1 = torch.linalg.norm(p.to_dense() - A.to_dense(), ord="fro") 159 | 160 | a_inv = torch.linalg.inv(A.to_dense()) 161 | loss2 = torch.linalg.norm(p.to_dense()@a_inv - torch.eye(a_inv.shape[0]), ord="fro") 162 | 163 | test_results.log_loss(loss1, loss2, plot=False) 164 | 165 | print(f"Smallest singular value: {S[-1]} | Largest singular value: {S[0]} | Condition number: {S[0] / S[-1]}") 166 | print(f"Loss Lmax: {loss1}\tLoss Lmin: {loss2}") 167 | print() 168 | 169 | if save_results: 170 | test_results.save_results() 171 | 172 | test_results.print_summary() 173 | 174 | 175 | def load_checkpoint(model, args, device): 176 | # load the saved weights of the model and the hyper-parameters 177 | checkpoint = args.checkpoint 178 | 179 | if checkpoint == "latest": 180 | # list all the directories in the results folder 181 | d = os.listdir("./results/") 182 | d.sort() 183 | 184 | config = None 185 | 186 | # find the latest checkpoint 187 | for i in range(len(d)): 188 | if os.path.isdir("./results/" + d[-i-1]): 189 | dir_contents = os.listdir("./results/" + d[-i-1]) 190 | 191 | # looking for a directory with both config and model weights 192 | if "config.json" in dir_contents and "final_model.pt" in dir_contents: 193 | # load the config.json file 194 | with open("./results/" + d[-i-1] + "/config.json") as f: 195 | config = json.load(f) 196 | 197 | if config["model"] != args.model: 198 | config = None 199 | continue 200 | 201 | if "best_model.pt" in dir_contents: 202 | checkpoint = "./results/" + d[-i-1] + "/best_model.pt" 203 | break 204 | else: 205 | checkpoint = "./results/" + d[-i-1] + "/final_model.pt" 206 | break 207 | if config is None: 208 | print("Checkpoint not found...") 209 | 210 | # neuralif has optional drop tolerance... 
211 | if args.model == "neuralif": 212 | config["drop_tol"] = args.drop_tol 213 | 214 | # intialize model and hyper-parameters 215 | model = model(**config) 216 | print(f"load checkpoint: {checkpoint}") 217 | 218 | model.load_state_dict(torch.load(checkpoint, weights_only=False, map_location=torch.device(device))) 219 | 220 | elif checkpoint is not None: 221 | with open(checkpoint + "/config.json") as f: 222 | config = json.load(f) 223 | 224 | if args.model == "neuralif": 225 | config["drop_tol"] = args.drop_tol 226 | 227 | model = model(**config) 228 | print(f"load checkpoint: {checkpoint}") 229 | model.load_state_dict(torch.load(checkpoint + f"/{args.weights}.pt", 230 | map_location=torch.device(model.device))) 231 | 232 | else: 233 | model = model(**{"global_features": 0, "latent_size": 8, "augment_nodes": False, 234 | "message_passing_steps": 3, "skip_connections": True, "activation": "relu", 235 | "aggregate": None, "decode_nodes": False}) 236 | 237 | print("No checkpoint provided, using random weights") 238 | 239 | return model 240 | 241 | 242 | def warmup(model, device): 243 | # set testing parameters 244 | model.to(device) 245 | model.eval() 246 | 247 | # run model warmup 248 | test_size = 1_000 249 | matrix = scipy.sparse.coo_matrix((np.ones(test_size), (np.arange(test_size), np.arange(test_size)))) 250 | data = matrix_to_graph_sparse(matrix, torch.ones(test_size)) 251 | data.to(device) 252 | _ = model(data) 253 | 254 | print("Model warmup done...") 255 | 256 | 257 | # argument is the model to load and the dataset to evaluate on 258 | def argparser(): 259 | parser = argparse.ArgumentParser() 260 | 261 | parser.add_argument("--name", type=str, default=None) 262 | parser.add_argument("--device", type=int, required=False) 263 | 264 | # select data driven model to run 265 | parser.add_argument("--model", type=str, required=False, default="none") 266 | parser.add_argument("--checkpoint", type=str, required=False) 267 | parser.add_argument("--weights", type=str, required=False, default="model") 268 | parser.add_argument("--drop_tol", type=float, default=0) 269 | 270 | parser.add_argument("--solver", type=str, default="cg") 271 | 272 | # select dataset and subset 273 | parser.add_argument("--dataset", type=str, required=False, default="random") 274 | parser.add_argument("--subset", type=str, required=False, default="test") 275 | parser.add_argument("--n", type=int, required=False, default=0) 276 | parser.add_argument("--samples", type=int, required=False, default=None) 277 | 278 | # select if to save 279 | parser.add_argument("--save", action='store_true', default=False) 280 | 281 | return parser.parse_args() 282 | 283 | 284 | def main(): 285 | args = argparser() 286 | 287 | if args.device is not None: 288 | test_device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu") 289 | else: 290 | test_device = "cpu" 291 | 292 | if args.name is not None: 293 | folder = "results/" + args.name 294 | else: 295 | folder = folder = "results/" + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 296 | 297 | print() 298 | print(f"Using device: {test_device}") 299 | # torch.set_num_threads(1) 300 | 301 | # Load the model 302 | if args.model == "nif" or args.model == "neuralif": 303 | print("Use model: NeuralIF") 304 | model = NeuralIF 305 | 306 | elif args.model == "lu" or args.model == "learnedlu": 307 | print("Use model: LU") 308 | model = LearnedLU 309 | 310 | assert args.solver == "gmres", "LU only supports GMRES solver" 311 | 312 | elif args.model == "neural_pcg" or 
args.model == "neuralpcg": 313 | print("Use model: NeuralPCG") 314 | model = NeuralPCG 315 | 316 | elif args.model == "precondnet": 317 | print("Use model: precondnet") 318 | model = PreCondNet 319 | 320 | elif args.model == "none": 321 | print("Running non-data-driven baselines") 322 | model = None 323 | 324 | else: 325 | raise NotImplementedError(f"Model {args.model} not available.") 326 | 327 | if model is not None: 328 | model = load_checkpoint(model, args, test_device) 329 | warmup(model, test_device) 330 | 331 | spd = args.solver == "cg" or args.solver == "direct" 332 | testdata_loader = get_dataloader(args.dataset, n=args.n, batch_size=1, mode=args.subset, 333 | size=args.samples, spd=spd, graph=True) 334 | 335 | # Evaluate the model 336 | test(model, testdata_loader, test_device, folder, 337 | save_results=args.save, dataset=args.dataset, solver=args.solver) 338 | 339 | 340 | if __name__ == "__main__": 341 | main() 342 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import argparse 4 | import pprint 5 | import torch 6 | import torch_geometric 7 | import time 8 | 9 | from apps.data import get_dataloader, graph_to_matrix 10 | 11 | from neuralif.utils import count_parameters, save_dict_to_file, condition_number, eigenval_distribution, gershgorin_norm 12 | from neuralif.logger import TrainResults, TestResults 13 | from neuralif.loss import loss 14 | from neuralif.models import NeuralPCG, NeuralIF, PreCondNet, LearnedLU 15 | 16 | from krylov.cg import preconditioned_conjugate_gradient 17 | from krylov.gmres import gmres 18 | from krylov.preconditioner import LearnedPreconditioner 19 | 20 | 21 | @torch.no_grad() 22 | def validate(model, validation_loader, solve=False, solver="cg", **kwargs): 23 | model.eval() 24 | 25 | acc_loss = 0.0 26 | num_loss = 0 27 | acc_solver_iters = 0.0 28 | 29 | for i, data in enumerate(validation_loader): 30 | data = data.to(device) 31 | 32 | # construct problem data 33 | A, b = graph_to_matrix(data) 34 | 35 | # run conjugate gradient method 36 | # this requires the learned preconditioner to be reasonably good! 
37 |         if solve:
38 |             # run CG on CPU
39 |             with torch.inference_mode():
40 |                 preconditioner = LearnedPreconditioner(data, model)
41 | 
42 |             A = A.to("cpu").to(torch.float64)
43 |             b = b.to("cpu").to(torch.float64)
44 |             x_init = None
45 | 
46 |             solver_start = time.time()
47 | 
48 |             if solver == "cg":
49 |                 l, x_hat = preconditioned_conjugate_gradient(A.to("cpu"), b.to("cpu"), M=preconditioner,
50 |                                                              x0=x_init, rtol=1e-6, max_iter=1_000)
51 |             elif solver == "gmres":
52 |                 l, x_hat = gmres(A, b, M=preconditioner, x0=x_init, atol=1e-6, max_iter=1_000, left=False)
53 |             else:
54 |                 raise NotImplementedError("Solver not implemented, choose between CG and GMRES!")
55 | 
56 |             solver_stop = time.time()
57 | 
58 |             # Measure preconditioning performance
59 |             solver_time = (solver_stop - solver_start)
60 |             acc_solver_iters += len(l) - 1
61 | 
62 |         else:
63 |             output, _, _ = model(data)
64 | 
65 |             # Here, we compute the loss using the full Frobenius norm (no estimator)
66 |             # l = frobenius_loss(output, A)
67 | 
68 |             l = loss(output, data, config="frobenius")
69 | 
70 |             acc_loss += l.item()
71 |             num_loss += 1
72 | 
73 |     if solve:
74 |         # print(f"Smallest eigenvalue: {dist[0]}")
75 |         print(f"Validation\t iterations:\t{acc_solver_iters / len(validation_loader):.2f}")
76 |         return acc_solver_iters / len(validation_loader)
77 | 
78 |     else:
79 |         print(f"Validation loss:\t{acc_loss / num_loss:.2f}")
80 |         return acc_loss / len(validation_loader)
81 | 
82 | 
83 | def main(config):
84 |     if config["save"]:
85 |         os.makedirs(folder, exist_ok=True)
86 |         save_dict_to_file(config, os.path.join(folder, "config.json"))
87 | 
88 |     # global seed-ish
89 |     torch_geometric.seed_everything(config["seed"])
90 | 
91 |     # args for the model
92 |     model_args = {k: config[k] for k in ["latent_size", "message_passing_steps", "skip_connections",
93 |                                          "augment_nodes", "global_features", "decode_nodes",
94 |                                          "normalize_diag", "activation", "aggregate", "graph_norm",
95 |                                          "two_hop", "edge_features", "normalize"]
96 |                   if k in config}
97 | 
98 |     # run the GMRES algorithm instead of CG (?)
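    # CG assumes a symmetric positive definite preconditioner, which the
    # Cholesky-style models provide via L @ L.T; LearnedLU produces a general
    # L/U pair, so the flag below switches both the data loading and the
    # validation solver to GMRES.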
99 | gmres = False 100 | 101 | # Create model 102 | if config["model"] == "neuralpcg": 103 | model = NeuralPCG(**model_args) 104 | 105 | elif config["model"] == "nif" or config["model"] == "neuralif" or config["model"] == "inf": 106 | model = NeuralIF(**model_args) 107 | 108 | elif config["model"] == "precondnet": 109 | model = PreCondNet(**model_args) 110 | 111 | elif config["model"] == "lu" or config["model"] == "learnedlu": 112 | gmres = True 113 | model = LearnedLU(**model_args) 114 | 115 | else: 116 | raise NotImplementedError 117 | 118 | model.to(device) 119 | 120 | print(f"Number params in model: {count_parameters(model)}") 121 | print() 122 | 123 | optimizer = torch.optim.AdamW(model.parameters()) 124 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=20) 125 | 126 | # Setup datasets 127 | train_loader = get_dataloader(config["dataset"], config["n"], config["batch_size"], 128 | spd=not gmres, mode="train") 129 | 130 | validation_loader = get_dataloader(config["dataset"], config["n"], 1, spd=(not gmres), mode="val") 131 | 132 | best_val = float("inf") 133 | logger = TrainResults(folder) 134 | 135 | # todo: compile the model 136 | # compiled_model = torch.compile(model, mode="reduce-overhead") 137 | # model = torch_geometric.compile(model, mode="reduce-overhead") 138 | 139 | total_it = 0 140 | 141 | # Train loop 142 | for epoch in range(config["num_epochs"]): 143 | running_loss = 0.0 144 | grad_norm = 0.0 145 | 146 | start_epoch = time.perf_counter() 147 | 148 | for it, data in enumerate(train_loader): 149 | # increase iteration count 150 | total_it += 1 151 | 152 | # enable training mode 153 | model.train() 154 | 155 | start = time.perf_counter() 156 | data = data.to(device) 157 | 158 | output, reg, _ = model(data) 159 | l = loss(output, data, c=reg, config=config["loss"]) 160 | 161 | # if reg: 162 | # l = l + config["regularizer"] * reg 163 | 164 | l.backward() 165 | running_loss += l.item() 166 | 167 | # track the gradient norm 168 | if "gradient_clipping" in config and config["gradient_clipping"]: 169 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config["gradient_clipping"]) 170 | 171 | else: 172 | total_norm = 0.0 173 | 174 | for p in model.parameters(): 175 | if p.grad is not None: 176 | param_norm = p.grad.detach().data.norm(2) 177 | total_norm += param_norm.item() ** 2 178 | 179 | grad_norm = total_norm ** 0.5 / config["batch_size"] 180 | 181 | # update network parameters 182 | optimizer.step() 183 | optimizer.zero_grad() 184 | 185 | logger.log(l.item(), grad_norm, time.perf_counter() - start) 186 | 187 | # Do validation after 100 updates (to support big datasets) 188 | # convergence is expected to be pretty fast... 189 | if (total_it + 1) % 1000 == 0: 190 | 191 | # start with cg-checks after 5 iterations 192 | val_its = validate(model, validation_loader, solve=True, 193 | solver="gmres" if gmres else "cg") 194 | 195 | # use scheduler 196 | # if config["scheduler"]: 197 | # scheduler.step(val_loss) 198 | 199 | logger.log_val(None, val_its) 200 | 201 | # val_perf = val_cgits if val_cgits > 0 else val_loss 202 | val_perf = val_its 203 | 204 | if val_perf < best_val: 205 | if config["save"]: 206 | torch.save(model.state_dict(), f"{folder}/best_model.pt") 207 | best_val = val_perf 208 | 209 | epoch_time = time.perf_counter() - start_epoch 210 | 211 | # save model every epoch for analysis... 
212 | if config["save"]: 213 | torch.save(model.state_dict(), f"{folder}/model_epoch{epoch+1}.pt") 214 | 215 | print(f"Epoch {epoch+1} \t loss: {1/len(train_loader) * running_loss} \t time: {epoch_time}") 216 | 217 | # save fully trained model 218 | if config["save"]: 219 | logger.save_results() 220 | torch.save(model.to(torch.float).state_dict(), f"{folder}/final_model.pt") 221 | 222 | # Test the model 223 | # wandb.run.summary["validation_chol"] = best_val 224 | print() 225 | print("Best validation loss:", best_val) 226 | 227 | 228 | def argparser(): 229 | parser = argparse.ArgumentParser() 230 | 231 | parser.add_argument("--name", type=str, default=None) 232 | parser.add_argument("--device", type=int, required=False) 233 | parser.add_argument("--save", action='store_true') 234 | 235 | # Training parameters 236 | parser.add_argument("--seed", type=int, default=42) 237 | parser.add_argument("--n", type=int, default=0) 238 | parser.add_argument("--batch_size", type=int, default=1) 239 | parser.add_argument("--num_epochs", type=int, default=100) 240 | parser.add_argument("--dataset", type=str, default="random") 241 | parser.add_argument("--loss", type=str, required=False) 242 | parser.add_argument("--gradient_clipping", type=float, default=1.0) 243 | 244 | parser.add_argument("--regularizer", type=float, default=0) 245 | parser.add_argument("--scheduler", action='store_true', default=False) 246 | 247 | # Model parameters 248 | parser.add_argument("--model", type=str, default="neuralif") 249 | 250 | parser.add_argument("--normalize", action='store_true', default=False) 251 | parser.add_argument("--latent_size", type=int, default=8) 252 | parser.add_argument("--message_passing_steps", type=int, default=3) 253 | parser.add_argument("--decode_nodes", action='store_true', default=False) 254 | parser.add_argument("--normalize_diag", action='store_true', default=False) 255 | parser.add_argument("--aggregate", nargs="*", type=str) 256 | parser.add_argument("--activation", type=str, default="relu") 257 | 258 | # NIF parameters 259 | parser.add_argument("--skip_connections", action='store_true', default=True) 260 | parser.add_argument("--augment_nodes", action='store_true') 261 | parser.add_argument("--global_features", type=int, default=0) 262 | parser.add_argument("--edge_features", type=int, default=1) 263 | parser.add_argument("--graph_norm", action='store_true') 264 | parser.add_argument("--two_hop", action='store_true') 265 | 266 | return parser.parse_args() 267 | 268 | 269 | if __name__ == "__main__": 270 | args = argparser() 271 | 272 | if args.device is None: 273 | device = "cpu" 274 | print("Warning!! Using cpu only training") 275 | print("If you have a GPU use that with the command --device {id}") 276 | print() 277 | else: 278 | device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu") 279 | 280 | if args.name is not None: 281 | folder = "results/" + args.name 282 | else: 283 | folder = folder = "results/" + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 284 | 285 | print(f"Using device: {device}") 286 | print("Using config: ") 287 | pprint.pprint(vars(args)) 288 | print() 289 | 290 | # run experiments 291 | main(vars(args)) 292 | --------------------------------------------------------------------------------