├── .gitignore ├── .vscode └── launch.json ├── BNN └── JaxLightning_BNN.py ├── LICENSE ├── README.md ├── ScoreBasedGenerativeModelling ├── configs │ └── config.yaml ├── main.py ├── src │ ├── ScoreBased_Data.py │ ├── ScoreBased_Hyperparameters.py │ └── ScoreBased_Models.py └── submit.sh ├── assets ├── InANutshell.png ├── automaticoptimization.png ├── code.png ├── dataloader.png └── now_kiss.jpeg └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | #Local files 10 | **/wandb/** 11 | **/lightning_logs/** 12 | **/data/** 13 | **/checkpoints/** 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | **/wandb/** 42 | **/lightning_logs/** 43 | **/data/** 44 | **/checkpoints/** 45 | **/outputs/** 46 | **/pcluster_logs/** 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | cover/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | .pybuilder/ 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | # For a library or package, you might want to ignore these files since the code is 100 | # intended to run in multiple environments; otherwise, check them in: 101 | # .python-version 102 | 103 | # pipenv 104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 107 | # install all needed dependencies. 108 | #Pipfile.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /BNN/JaxLightning_BNN.py: -------------------------------------------------------------------------------- 1 | import torch, numpy as np 2 | from torch import Tensor 3 | from torch.utils.data import DataLoader, TensorDataset, Dataset 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.trainer import Trainer 6 | from pl_bolts.datamodules import MNISTDataModule 7 | import matplotlib.pyplot as plt 8 | from typing import Tuple 9 | import einops 10 | import jax, jax.numpy as jnp 11 | import optax, equinox as eqx 12 | 13 | 14 | def numpy_collate(batch): 15 | if isinstance(batch[0], np.ndarray): 16 | return np.stack(batch) 17 | elif isinstance(batch[0], (tuple, list)): 18 | transposed = zip(*batch) 19 | return [numpy_collate(samples) for samples in transposed] 20 | else: 21 | return np.array(batch) 22 | 23 | 24 | class NumpyTensorDataset(Dataset): 25 | r"""Dataset wrapping tensors. 26 | 27 | Each sample will be retrieved by indexing tensors along the first dimension. 28 | 29 | Args: 30 | *tensors (Tensor): tensors that have the same size of the first dimension. 31 | """ 32 | 33 | tensors: Tuple[Tensor, ...] 
34 | 35 | def __init__(self, *tensors: Tensor) -> None: 36 | """ 37 | PyTorch: tensor.size() -> tuple -> tensor.size(0) gives batch dim 38 | Numpy: nparray.size() is not indexable, so we use shape 39 | """ 40 | assert all( 41 | tensors[0].shape[0] == tensor.shape[0] for tensor in tensors 42 | ), "Size mismatch between tensors" 43 | self.tensors = tensors 44 | 45 | def __getitem__(self, index): 46 | return tuple(tensor[index] for tensor in self.tensors) 47 | 48 | def __len__(self): 49 | return self.tensors[0].shape[0] 50 | 51 | 52 | class RegressionDataModule(pl.LightningDataModule): 53 | num_samples = 200 54 | x_noise_std = 0.01 55 | y_noise_std = 0.2 56 | 57 | def __init__(self): 58 | super().__init__() 59 | 60 | def setup(self, stage: str) -> None: 61 | x = jnp.linspace(-0.35, 0.55, self.num_samples) 62 | 63 | x_noise = np.random.normal(0.0, self.x_noise_std, size=x.shape) 64 | 65 | std = np.linspace(0, self.y_noise_std, self.num_samples) # * y_noise_std 66 | # print(std.shape) 67 | non_stationary_noise = np.random.normal(loc=0, scale=std) 68 | y_noise = non_stationary_noise 69 | 70 | y = ( 71 | x 72 | + 0.3 * jnp.sin(2 * jnp.pi * (x + x_noise)) 73 | + 0.3 * jnp.sin(4 * jnp.pi * (x + x_noise)) 74 | + y_noise 75 | ) 76 | 77 | x, y = x.reshape(-1, 1), y.reshape(-1, 1) 78 | 79 | x = (x - x.mean(axis=0)) / x.std(axis=0) 80 | y = (y - x.mean(axis=0)) / y.std(axis=0) 81 | 82 | self.X = np.array(x) 83 | self.Y = np.array(y) 84 | 85 | def train_dataloader(self): 86 | return DataLoader( 87 | NumpyTensorDataset(self.X, self.Y), 88 | shuffle=True, 89 | batch_size=16, 90 | collate_fn=numpy_collate, 91 | drop_last=False, 92 | ) 93 | 94 | def val_dataloader(self): 95 | return DataLoader( 96 | TensorDataset(self.X[:100], self.Y[:100]), 97 | shuffle=True, 98 | batch_size=32, 99 | collate_fn=numpy_collate, 100 | ) 101 | 102 | 103 | class BayesLinear(eqx.Module): 104 | weight_mu: jax.numpy.ndarray 105 | weight_rho: jax.numpy.ndarray 106 | bias_mu: jax.numpy.ndarray 107 | bias_rho: jax.numpy.ndarray 108 | 109 | def __init__(self, in_size, out_size, key): 110 | wkey, bkey = jax.random.split(key) 111 | self.weight_mu = ( 112 | jax.random.normal(wkey, (out_size, in_size)) / (out_size + in_size) ** 0.5 113 | ) 114 | self.weight_rho = jax.numpy.full(shape=self.weight_mu.shape, fill_value=-3) 115 | self.bias_mu = jax.random.normal(bkey, (out_size,)) * 0.1 116 | self.bias_rho = jax.numpy.full(shape=self.bias_mu.shape, fill_value=-3) 117 | 118 | def __call__(self, x, key): 119 | w_eps = jax.random.normal(key=key, shape=self.weight_mu.shape) 120 | b_eps = jax.random.normal(key=key, shape=self.bias_mu.shape) 121 | w = self.weight_mu + jax.nn.softplus(self.weight_rho) * w_eps 122 | b = self.bias_mu + jax.nn.softplus(self.bias_rho) * b_eps 123 | # print(x.shape, w.shape, b.shape) 124 | 125 | return x @ w.T + b 126 | 127 | def kl_div(self): 128 | weight_scale = jax.nn.softplus(self.weight_rho) 129 | kl_div = jnp.log(1.0) - jnp.log(weight_scale) 130 | kl_div += (weight_scale**2 + (self.weight_mu - 0) ** 2) / (2) - 0.5 131 | return kl_div.sum() 132 | 133 | 134 | class BNN(eqx.Module): 135 | bnn: eqx.nn.Sequential 136 | 137 | def __init__(self, key): 138 | hidden = 51 139 | self.bnn = eqx.nn.Sequential( 140 | [ 141 | BayesLinear(1, hidden, key), 142 | eqx.nn.Lambda(jax.nn.gelu), 143 | BayesLinear(hidden, hidden, key), 144 | eqx.nn.Lambda(jax.nn.gelu), 145 | BayesLinear(hidden, hidden, key), 146 | eqx.nn.Lambda(jax.nn.gelu), 147 | BayesLinear(hidden, hidden, key), 148 | eqx.nn.Lambda(jax.nn.gelu), 149 | BayesLinear(hidden, 
1, key), 150 | ] 151 | ) 152 | 153 | def __call__(self, x, key): 154 | return self.bnn(x, key=key) 155 | 156 | def kl_div(self): 157 | kl_div = 0 158 | for layer in self.bnn.layers: 159 | if type(layer) == BayesLinear: 160 | kl_div += layer.kl_div() 161 | return kl_div 162 | 163 | 164 | class JaxLightning(pl.LightningModule): 165 | def __init__(self): 166 | super().__init__() 167 | self.automatic_optimization = False 168 | self.key = jax.random.PRNGKey(1) 169 | self.MC = 5 170 | 171 | self.key, subkey = jax.random.split(self.key) 172 | self.bnn = BNN(key=subkey) 173 | self.global_step_ = 0 174 | 175 | @property 176 | def global_step(self): 177 | """ 178 | self.global_step is an attribute without setter and is updated somewhere deep within Lightning 179 | Simply overwrite global_step as a property to access it normally in the LightningModule 180 | :return: 181 | """ 182 | return self.global_step_ 183 | 184 | def on_fit_start(self) -> None: 185 | self.num_data_samples = self.trainer.datamodule.num_samples 186 | self.viz_network("On Train Start") 187 | 188 | def training_step(self, batch): 189 | """Standard PyTorch Lightning training step ... but with Jax in it!""" 190 | self.global_step_ += 1 191 | data, target = ( 192 | jnp.array(batch[0].reshape(-1, *batch[0].shape[1:])), 193 | jnp.array(batch[1]), 194 | ) 195 | data = einops.repeat(data, "... -> b ...", b=self.MC) 196 | target = einops.repeat(target, "... -> b ...", b=self.MC) 197 | 198 | self.key, *subkeys = jax.random.split( 199 | self.key, num=self.MC + 1 200 | ) # creating new keys 201 | subkeys = jnp.stack(subkeys) 202 | """ 203 | Jax in the middle of a Lightning module 204 | Call static gradient method from same PL module 205 | Calls in turn cost function doing the jit compiled forward, backward and gradient update step 206 | """ 207 | loss, metrics, self.bnn, self.optim, self.opt_state = JaxLightning.make_step( 208 | self.bnn, 209 | data, 210 | target, 211 | self.num_data_samples, 212 | subkeys, 213 | self.optim, 214 | self.opt_state, 215 | ) 216 | """All the logging and perks you love about Lightining""" 217 | dict = { 218 | "Loss": loss.item(), 219 | "global_step": self.global_step, 220 | "current_epoch": self.current_epoch, 221 | } 222 | self.log_dict(dict, prog_bar=True) 223 | return dict 224 | 225 | @staticmethod 226 | @eqx.filter_value_and_grad(has_aux=True) 227 | def criterion(model, x, y, num_samples, keys): 228 | """Jit-able criterion function including forward pass""" 229 | model_vmap = jax.vmap( 230 | model, in_axes=(0, 0) 231 | ) # takes [MC, B, F] features and [MC] keys 232 | pred = model_vmap(x, keys) 233 | assert pred.ndim == 3 234 | std = jax.lax.stop_gradient(pred.std(axis=0)) 235 | mse = (y - pred.mean(axis=0)) ** 2 236 | nll = -(-0.5 * mse / std**2).sum() * num_samples 237 | kl = 1.0 * model.kl_div() 238 | return (nll + kl, {"nll": nll, "kl": kl, "std": std.mean(), "mse": mse.mean()}) 239 | 240 | @staticmethod 241 | @eqx.filter_jit 242 | def make_step(model, x, y, num_samples, keys, optim, opt_state): 243 | """Jit-able gradient and parameter update""" 244 | (loss, metrics), grads = JaxLightning.criterion(model, x, y, num_samples, keys) 245 | updates, opt_state = optim.update(grads, opt_state) 246 | model = eqx.apply_updates(model, updates) 247 | return loss, metrics, model, optim, opt_state 248 | 249 | def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None: 250 | """ 251 | In automatic_optimization=False mode, the global_step attribute is (for some reason) not automatically incremented, so we have 
to use our own should_stop criterion 252 | """ 253 | if self.global_step >= self.trainer.max_steps and self.trainer.max_steps > 0: 254 | self.trainer.should_stop = True 255 | 256 | def on_train_end(self) -> None: 257 | self.viz_network(title="On Fit End") 258 | 259 | def viz_network(self, title=""): 260 | MC = 201 261 | 262 | self.key, *subkeys = jax.random.split(self.key, num=MC + 1) # creating new keys 263 | subkeys = jnp.stack(subkeys) 264 | pred_model = jax.jit(jax.vmap(self.bnn, in_axes=(0, 0))) 265 | 266 | viz_input = jnp.linspace(-4, 4, 100).reshape(-1, 1) 267 | x_mc = einops.repeat(viz_input, "... -> b ...", b=MC) 268 | print(x_mc.shape, len(subkeys)) 269 | mc_pred = pred_model(x_mc, subkeys) 270 | 271 | fig = plt.figure(figsize=(10, 10)) 272 | for pred in mc_pred.squeeze(): 273 | plt.plot( 274 | viz_input, pred, color="red", alpha=1 / min(mc_pred.shape[0], 50.0) 275 | ) 276 | X, Y = self.trainer.datamodule.X, self.trainer.datamodule.Y 277 | plt.scatter(X, Y, alpha=1 / min(mc_pred.shape[0], 1.0), s=5) 278 | plt.ylim(-3, 3) 279 | plt.title(title) 280 | plt.show() 281 | 282 | def configure_optimizers(self): 283 | self.optim = optax.adam(0.001) 284 | self.opt_state = self.optim.init(eqx.filter(self.bnn, eqx.is_array)) 285 | 286 | 287 | print(jax.devices()) 288 | 289 | bnn = JaxLightning() 290 | dm = RegressionDataModule() 291 | trainer = Trainer(max_steps=4000) 292 | trainer.fit(bnn, dm) 293 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 ludwigwinkler 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JaxLightning 2 | PyTorch Lightning + Jax = nice 3 | 4 | # PyTorch Lightning 5 | 6 | The package has become the go-to standard of ML research for PyTorch. 7 | Above all, it removes all the boiler plate code that you usually have to write for kicking off a simple experiment. 8 | Additionally, you got amazing logging, general code structure, data management via LightningDataModules and other templates making quick iteration a breeze. 9 | 10 | # Jax 11 | 12 | Recent packages such as Equinox and Treex are at the top level very similar in structure and handling like PyTorch. 
13 | This makes the code very readable and succinct. 14 | The biggest advantage of Jax is probably its clean functional programming (I've come around to that) and its speed. 15 | Vmap, derivatives in all directions and automatic accelerator management (no more tensor.to(deviceXYZ)) are also part of the gift box. 16 | Also, the explicit random keys remove an entire class of possible problems. 17 | 18 | You can find a speed comparison in the deep learning course of [UvA](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html). 19 | The takeaway of that PyTorch vs Jax comparison is that Jax excels at compiling forward and backward passes consisting of lots of SIMD (single instruction, multiple data) operations, such as separate calls to small-ish convolution kernels on locally independent data. 20 | The speed-up can be considerable at **2.5X-3.4X**, but is expected to regress towards PyTorch's performance once the time spent executing those calls grows relative to the overhead of launching them. 21 | That means that for large batch sizes, runtime is dominated by single instructions on very large tensors, which depends on the hardware, and the XLA compiler can do little about that. 22 | 23 | # Can we get the best of both worlds? Yes, we can[1]. 24 | 25 | References: [1] 44th President of the United States Barack Obama 26 | 27 | You can run any Jax code, and thus any Jax neural network package, inside of PyTorch Lightning, be it written in Flax/Haiku/Equinox/Treex and optimized with the common optimization libraries. 28 | 29 | The necessary steps are: 30 | 31 | **1. Turn off automatic optimization via `automatic_optimization=False` ... we have to do that ourselves.** 32 | 33 | Dataloader 34 | 35 | **2. Load data strictly in 'numpy' mode by modifying the `collate_fn` of the PyTorch DataLoader.** 36 | 37 | Dataloader 38 | 39 | **3. Run the forward, backward and gradient update steps within Lightning as `@staticmethod`s.** 40 | 41 | Dataloader 42 | 43 | Done. Now we get the best of Jax with the best of Lightning. :) 44 | 45 | ### Tensors vs Arrays 46 | 47 | The main idea of combining the great and convenient code structure of PyTorch Lightning with the versatility of Jax is to restrict PyTorch Lightning to pure Numpy/Jax.Numpy until the data 'reaches' the Jax model. 48 | Therefore we can reuse almost all DataModules and Datasets and only remove the single line where data is cast to torch.Tensors. 49 | The dataloaders/datamodules etc. are thus restricted to Numpy/Jax.Numpy operations. 50 | 51 | ### Optimization 52 | 53 | Secondly, we can't use PyTorch Lightning's automatic optimization, which is part of what makes setting up experiments in PL so convenient. 54 | In exchange, Jax handles device placement automatically and moves arrays to the correct devices on its own. 55 | Thus, by simply setting the attribute `automatic_optimization=False`, we gain complete control over all gradient computations and gradient descent optimization and tell PL that we'll do the optimization on our own. 56 | 57 | Since Jax requires pure functions, all we have to do is make the forward step a `@staticmethod` without the `self` argument. 58 | Similarly, we can create a static gradient function in the same way. 59 | 60 | Question: What do we gain? 61 | 62 | Answer: We can jit-compile the entire forward and backward pass with **JAX** with a simple decorator, right inside the training setup of **PyTorch Lightning**.
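In a nutshell, the whole pattern condenses to something like the sketch below. It is a stripped-down version of what `BNN/JaxLightning_BNN.py` does; the module name `MyJaxModule` and the tiny `eqx.nn.MLP` regression model are placeholders for illustration, not part of this repo.

```python
import jax, jax.numpy as jnp
import equinox as eqx
import optax
import pytorch_lightning as pl


class MyJaxModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # step 1: we handle optimization ourselves
        self.key = jax.random.PRNGKey(0)
        self.key, subkey = jax.random.split(self.key)
        self.model = eqx.nn.MLP(in_size=1, out_size=1, width_size=64, depth=2, key=subkey)

    def configure_optimizers(self):
        # optax instead of torch.optim; no PyTorch optimizer is returned to Lightning
        self.optim = optax.adam(1e-3)
        self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_array))

    @staticmethod
    @eqx.filter_value_and_grad
    def loss_fn(model, x, y):
        # pure function of (model, data): jit/grad-friendly, no `self` anywhere
        pred = jax.vmap(model)(x)
        return jnp.mean((pred - y) ** 2)

    @staticmethod
    @eqx.filter_jit
    def make_step(model, x, y, optim, opt_state):
        # forward, backward and parameter update compiled as a single XLA program
        loss, grads = MyJaxModule.loss_fn(model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    def training_step(self, batch, batch_idx):
        # step 2: the batch arrives as NumPy arrays thanks to the numpy collate_fn
        x, y = jnp.asarray(batch[0]), jnp.asarray(batch[1])
        # step 3: call the static, jit-compiled update and store the new model/opt state
        loss, self.model, self.opt_state = MyJaxModule.make_step(
            self.model, x, y, self.optim, self.opt_state
        )
        self.log("loss", loss.item(), prog_bar=True)
```

The only PyTorch pieces left are the `Trainer` and the `DataLoader`; everything the model touches is a Jax array, and the new model and optimizer state are simply written back onto `self` after each step.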
63 | 64 | Thus PyTorch Lightning takes care of all the data set management, the logging, the tracking and the overall training loop structure with all the convenience PL is famous for, and Jax does the fast computing inside of PL. 65 | 66 | Everybody wins ... 67 | 68 | ### Examples 69 | 70 | There are two examples how I used PyTorch Lightning for my own Bayesian Neural Networks and Score Based Generative Modelling ([used code by Patrick Kidger](https://docs.kidger.site/equinox/examples/score_based_diffusion/)). 71 | 72 | ![](assets/now_kiss.jpeg) 73 | -------------------------------------------------------------------------------- /ScoreBasedGenerativeModelling/configs/config.yaml: -------------------------------------------------------------------------------- 1 | logging: online 2 | project: jax_score_based 3 | load_last_checkpoint: False 4 | show: True 5 | seed: 12345 6 | fast_dev_run: 0 7 | save_figs: False 8 | watch_model_training: False 9 | hparams_id: None 10 | experiment: mnist -------------------------------------------------------------------------------- /ScoreBasedGenerativeModelling/main.py: -------------------------------------------------------------------------------- 1 | import sys, argparse 2 | import pathlib 3 | from pathlib import Path 4 | import matplotlib.pyplot as plt 5 | import wandb 6 | import jax, jax.numpy as jnp, equinox as eqx, optax, jax.random as jr 7 | import diffrax as dfx 8 | import functools as ft 9 | import pytorch_lightning as pl 10 | from pytorch_lightning.trainer import Trainer 11 | import torch, einops 12 | 13 | from pytorch_lightning.loggers import WandbLogger 14 | 15 | import hydra 16 | from omegaconf import DictConfig 17 | 18 | 19 | # filepath = Path().resolve() 20 | # phd_path = Path(str(filepath).split('PhD')[0] + 'PhD') 21 | # jax_path = Path(str(filepath).split('Jax')[0] + 'Jax') 22 | # score_path = Path(str(filepath).split('Jax')[0] + 'Jax/ScoreBasedGenerativeModelling') 23 | # sys.path.append(str(phd_path)) 24 | # sys.path.append(str(jax_path)) 25 | from src.ScoreBased_Data import MNISTDataModule 26 | from src.ScoreBased_Models import Mixer2d 27 | from src.ScoreBased_Hyperparameters import str2bool, process_hparams 28 | 29 | 30 | class JaxLightning(pl.LightningModule): 31 | def __init__(self, **kwargs): 32 | super().__init__() 33 | self.save_hyperparameters() 34 | self.automatic_optimization = False 35 | 36 | self.key = jax.random.PRNGKey(1) 37 | self.key, self.model_key, self.train_key, self.loader_key, self.sample_key = ( 38 | jax.random.split(self.key, 5) 39 | ) 40 | 41 | self.key, subkey = jax.random.split(self.key) 42 | self.model = Mixer2d( 43 | (1, 28, 28), 44 | patch_size=4, 45 | hidden_size=64, 46 | mix_patch_size=512, 47 | mix_hidden_size=512, 48 | num_blocks=4, 49 | t1=10, 50 | key=subkey, 51 | ) 52 | 53 | self.t1 = 10.0 54 | self.int_beta = lambda t: t 55 | self.weight = lambda t: 1 - jnp.exp(-self.int_beta(t)) 56 | self.dt0 = 0.1 57 | self.samples = 10 58 | 59 | self.global_step_ = 0 60 | 61 | self.configure_optimizers() 62 | 63 | def on_fit_start(self) -> None: 64 | pathlib.Path("checkpoints").mkdir(exist_ok=True) 65 | 66 | if self.hparams.load_last_checkpoint: 67 | try: 68 | self.model = eqx.tree_deserialise_leaves( 69 | f"checkpoints/ScoreBased/last.eqx", self.model 70 | ) 71 | print("Loaded weights") 72 | except: 73 | print("Didnt load weights") 74 | 75 | self.logger.log_image( 76 | key="Samples P1", 77 | images=[wandb.Image(self.sample(), caption="Samples P1")], 78 | ) 79 | 80 | # def on_fit_end(self): 81 | # 
pathlib.Path.mkdir(f"checkpoints/ScoreBased", parents=True, exist_ok=True) 82 | # eqx.tree_serialise_leaves(f"checkpoints/ScoreBased/last.eqx", self.model) 83 | 84 | def training_step(self, batch): 85 | data = batch[0] 86 | value, self.model, self.train_key, self.opt_state = JaxLightning.make_step( 87 | self.model, 88 | self.weight, 89 | self.int_beta, 90 | data, 91 | self.t1, 92 | self.train_key, 93 | self.opt_state, 94 | self.optim.update, 95 | ) 96 | dict_ = {"loss": torch.scalar_tensor(value.item())} 97 | self.log_dict(dict_, prog_bar=True) 98 | self.global_step_ += 1 99 | return dict_ 100 | 101 | def sample(self): 102 | self.sample_key, *sample_key = jr.split(self.sample_key, self.samples**2 + 1) 103 | sample_key = jnp.stack(sample_key) 104 | sample_fn = ft.partial( 105 | JaxLightning.single_sample_fn, 106 | self.model, 107 | self.int_beta, 108 | (1, 28, 28), 109 | self.dt0, 110 | self.t1, 111 | ) 112 | sample = jax.vmap(sample_fn)(sample_key) 113 | # sample = data_mean + data_std * sample 114 | # sample = jnp.clip(sample, data_min, data_max) 115 | sample = einops.rearrange( 116 | sample, "(n1 n2) 1 h w -> (n1 h) (n2 w)", n1=self.samples, n2=self.samples 117 | ) 118 | fig = plt.figure() 119 | plt.imshow(sample, cmap="Greys") 120 | plt.axis("off") 121 | plt.title(f"{self.global_step_}") 122 | plt.tight_layout() 123 | if self.hparams.show: 124 | plt.show() 125 | return fig 126 | 127 | def validation_step(self, batch): 128 | data = batch[0] 129 | value, self.model, self.train_key, self.opt_state = JaxLightning.make_step( 130 | self.model, 131 | self.weight, 132 | self.int_beta, 133 | data, 134 | self.t1, 135 | self.train_key, 136 | self.opt_state, 137 | self.optim.update, 138 | ) 139 | # dict_ = {"loss": torch.scalar_tensor(value.item())} 140 | self.log("Val_Loss", jnp.asarray(value).item(), prog_bar=True, batch_size=1) 141 | 142 | def on_validation_epoch_end(self) -> None: 143 | # pathlib.Path.mkdir(f"checkpoints/ScoreBased", parents=True, exist_ok=True) 144 | # eqx.tree_serialise_leaves(f"checkpoints/ScoreBased/last.eqx", self.model) 145 | self.logger.log_image( 146 | key="Samples P1", 147 | images=[wandb.Image(self.sample(), caption="Samples P1")], 148 | ) 149 | 150 | def configure_optimizers(self): 151 | self.optim = optax.adam(3e-4) 152 | self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_inexact_array)) 153 | 154 | @staticmethod 155 | @eqx.filter_jit 156 | def single_sample_fn(model, int_beta, data_shape, dt0, t1, key): 157 | """ 158 | Sampling a single trajectory starting from normal noise at t1 and recovering data distribution at t0 159 | :param model: 160 | :param int_beta: 161 | :param data_shape: 162 | :param dt0: 163 | :param t1: 164 | :param key: 165 | :return: 166 | """ 167 | 168 | def drift(t, y, args): 169 | """ 170 | compute time derivative of function dß(t)/dt 171 | Noising SDE: dx(t) = -1/2 ß(t) x(t) dt + ß(t)^1/2 dW_t -> μ(x(t)) = - 1/2 ß(t) x(t) dt and σ^2 = ß(t) 172 | Reverse SDE: μ(x(tau)) = 1/2 ß(t) x(t) + ß(t) ∇ log p 173 | :param t: 174 | :param y: 175 | :param args: 176 | :return: 177 | """ 178 | t = jnp.array(t) 179 | _, beta = jax.jvp(fun=int_beta, primals=(t,), tangents=(jnp.ones_like(t),)) 180 | return ( 181 | -0.5 * beta * (y + model(t, y)) 182 | ) # negative because we use -dt0 when solving 183 | 184 | term = dfx.ODETerm(drift) 185 | solver = dfx.Tsit5() 186 | t0 = 0 187 | y1 = jr.normal( 188 | key, data_shape 189 | ) # noise at t1, from which integrate backwards to data distribution 190 | # reverse time, solve from t1 to t0 191 | sol = 
dfx.diffeqsolve( 192 | terms=term, 193 | solver=solver, 194 | t0=t1, 195 | t1=t0, 196 | dt0=-dt0, 197 | y0=y1, 198 | # adjoint=dfx.NoAdjoint(), 199 | ) 200 | return sol.ys[0] 201 | 202 | @staticmethod 203 | def single_loss_fn(model, weight, int_beta, data, t, key): 204 | """ 205 | OU process provides analytical mean and variance 206 | int_beta(t) = ß = θ 207 | E[X_t] = μ + exp[-θ t] ( X_0 - μ) w/ μ=0 gives =X_0 * exp[ - θ t ] 208 | V[X_t] = σ^2/(2θ) ( 1 - exp(-2 θ t) ) w/ σ^2=ß=θ gives = 1 - exp(-2 ß t) 209 | :param model: 210 | :param weight: 211 | :param int_beta: 212 | :param data: 213 | :param t: 214 | :param key: 215 | :return: 216 | """ 217 | mean = data * jnp.exp(-0.5 * int_beta(t)) # analytical mean of OU process 218 | var = jnp.maximum( 219 | 1 - jnp.exp(-int_beta(t)), 1e-5 220 | ) # analytical variance of OU process 221 | std = jnp.sqrt(var) 222 | noise = jr.normal(key, data.shape) 223 | y = mean + std * noise 224 | pred = model(t, y) 225 | return weight(t) * jnp.mean((pred + noise / std) ** 2) # loss 226 | 227 | @staticmethod 228 | def batch_loss_fn(model, weight, int_beta, data, t1, key): 229 | batch_size = data.shape[0] 230 | tkey, losskey = jr.split(key) 231 | losskey = jr.split(losskey, batch_size) 232 | """ 233 | Low-discrepancy sampling over t to reduce variance 234 | by sampling very evenly by sampling uniformly and independently from (t1-t0)/batch_size bins 235 | t = [U(0,1), U(1,2), U(2,3), ...] 236 | """ 237 | t = jr.uniform(tkey, (batch_size,), minval=0, maxval=t1 / batch_size) 238 | t = t + (t1 / batch_size) * jnp.arange(batch_size) 239 | """ Fixing the first three arguments of single_loss_fn, leaving data, t and key as input """ 240 | loss_fn = ft.partial(JaxLightning.single_loss_fn, model, weight, int_beta) 241 | loss_fn = jax.vmap(loss_fn) 242 | return jnp.mean(loss_fn(data, t, losskey)) 243 | 244 | @staticmethod 245 | @eqx.filter_jit 246 | def make_step(model, weight, int_beta, data, t1, key, opt_state, opt_update): 247 | loss_fn = eqx.filter_value_and_grad(JaxLightning.batch_loss_fn) 248 | loss, grads = loss_fn(model, weight, int_beta, data, t1, key) 249 | updates, opt_state = opt_update(grads, opt_state) 250 | model = eqx.apply_updates(model, updates) 251 | key = jr.split(key, 1)[0] 252 | return loss, model, key, opt_state 253 | 254 | 255 | # hparams = argparse.ArgumentParser() 256 | # logging = [0, 1, 1 if torch.cuda.is_available() else 0][-1] 257 | # hparams = JaxLightning.args( 258 | # hparams, 259 | # logging=["disabled", "online"][logging], 260 | # project="jax_score_based", 261 | # load_last_checkpoint=True, 262 | # show=True, 263 | # seed=[12345, 2345, 98765][1], 264 | # fast_dev_run=0, 265 | # save_figs=[False, True][0], 266 | # watch_model_training=[False, True][0], 267 | # ) 268 | 269 | # temp_args, _ = hparams.parse_known_args() 270 | 271 | 272 | @hydra.main( 273 | version_base=None, 274 | config_name="config", 275 | config_path="configs", 276 | ) 277 | def main(cfg: DictConfig) -> None: 278 | print(cfg) 279 | 280 | hparams = process_hparams(cfg, print_hparams=True) 281 | 282 | logger = WandbLogger( 283 | project=hparams.project, name=hparams.experiment, mode=hparams.logging 284 | ) 285 | 286 | dm = MNISTDataModule() 287 | dm.prepare_data() 288 | dm.setup() 289 | 290 | # exit() 291 | 292 | scorebased = JaxLightning(**hparams) 293 | 294 | trainer = Trainer( 295 | max_steps=1_000_000, 296 | accelerator="cpu", 297 | logger=logger, 298 | check_val_every_n_epoch=10, 299 | ) 300 | trainer.fit(scorebased, dm) 301 | 302 | 303 | if __name__ == "__main__": 304 
| main() 305 | -------------------------------------------------------------------------------- /ScoreBasedGenerativeModelling/src/ScoreBased_Data.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from pathlib import Path 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torchvision 6 | import torchvision.datasets 7 | from torch.utils.data import DataLoader 8 | 9 | filepath = Path().resolve() 10 | phd_path = Path(str(filepath).split("PhD")[0] + "PhD") 11 | jax_path = Path(str(filepath).split("Jax")[0] + "Jax") 12 | score_path = Path(str(filepath).split("Jax")[0] + "Jax/ScoreBasedGenerativeModelling") 13 | sys.path.append(str(phd_path)) 14 | 15 | 16 | def numpy_collate(batch): 17 | if isinstance(batch[0], np.ndarray): 18 | return np.stack(batch) 19 | elif isinstance(batch[0], (tuple, list)): 20 | transposed = zip(*batch) 21 | return [numpy_collate(samples) for samples in transposed] 22 | else: 23 | return np.array(batch) 24 | 25 | 26 | class MNIST(torchvision.datasets.MNIST): 27 | """ 28 | Numpy version of MNIST 29 | """ 30 | 31 | def __init__(self, **kwargs): 32 | super().__init__(**kwargs) 33 | 34 | def __getitem__(self, index: int): 35 | """ 36 | Args: 37 | index (int): Index 38 | 39 | Returns: 40 | tuple: (image, target) where target is index of the target class. 41 | """ 42 | 43 | img, target = self.data[index], int(self.targets[index]) 44 | 45 | img = img.numpy() 46 | 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | if self.target_transform is not None: 51 | target = self.target_transform(target) 52 | 53 | return img, target 54 | 55 | 56 | class MNISTDataModule(pl.LightningDataModule): 57 | def __init__(self): 58 | super().__init__() 59 | self.transform = torchvision.transforms.Compose( 60 | [ 61 | torchvision.transforms.Lambda(lambda x: x / 255), 62 | torchvision.transforms.Lambda(lambda x: x.reshape(1, 28, 28)), 63 | ] 64 | ) 65 | 66 | def prepare_data(self) -> None: 67 | if not os.path.isdir(path := score_path / "data"): 68 | torchvision.datasets.MNIST(root=path, download=True) 69 | 70 | def setup(self, stage=None): 71 | data_path = score_path / "data" 72 | 73 | self.train_data = MNIST(root=data_path, train=True, transform=self.transform) 74 | self.val_data = MNIST(root=data_path, train=False, transform=self.transform) 75 | 76 | def train_dataloader(self): 77 | return DataLoader( 78 | self.train_data, batch_size=64, shuffle=True, collate_fn=numpy_collate 79 | ) 80 | 81 | def val_dataloader(self): 82 | return DataLoader( 83 | self.val_data, batch_size=64, shuffle=True, collate_fn=numpy_collate 84 | ) 85 | -------------------------------------------------------------------------------- /ScoreBasedGenerativeModelling/src/ScoreBased_Hyperparameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import hashlib 4 | import numbers 5 | import math 6 | import os 7 | import sys 8 | import time 9 | from pathlib import Path 10 | from omegaconf import OmegaConf 11 | 12 | from pytorch_lightning.loggers import WandbLogger 13 | 14 | import wandb 15 | 16 | 17 | def str2bool(v): 18 | if isinstance(v, bool): 19 | return v 20 | elif type(v) == str: 21 | if v.lower() in ("yes", "true", "t", "y", "1"): 22 | return True 23 | elif v.lower() in ("no", "false", "f", "n", "0"): 24 | return False 25 | elif isinstance(v, numbers.Number): 26 | assert v in [0, 1] 27 | if v == 1: 28 | return True 29 | if v == 0: 30 | return False 31 
| else: 32 | raise argparse.ArgumentTypeError(f"Invalid Value: {type(v)}") 33 | 34 | 35 | def process_hparams(hparams, print_hparams=True): 36 | # hparams = hparams.parse_args() 37 | 38 | if hparams.logging == "online": 39 | hparams.show = False 40 | 41 | """Create HParam ID for saving and loading checkpoints""" 42 | hashable_config = copy.deepcopy(hparams) 43 | id = hashlib.sha1( 44 | repr(sorted(hashable_config.__dict__.items())).encode() 45 | ).hexdigest() # int -> abs(int) -> str(abs(int))) 46 | hparams.hparams_id = id 47 | 48 | if print_hparams: 49 | print(OmegaConf.to_container(hparams, resolve=True)) 50 | # [print(f"\t {key}: {value}") for key, value in sorted(hparams.__dict__.items())] 51 | 52 | return hparams 53 | -------------------------------------------------------------------------------- /ScoreBasedGenerativeModelling/src/ScoreBased_Models.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import equinox as eqx 3 | import jax 4 | from flax import linen as fnn 5 | from jax import numpy as jnp, random as jr 6 | 7 | 8 | class Encoder(fnn.Module): 9 | features: int = 64 10 | training: bool = True 11 | 12 | @fnn.compact 13 | def __call__(self, x): 14 | z1 = fnn.Conv(self.features, kernel_size=(3, 3))(x) 15 | z1 = fnn.relu(z1) 16 | z1 = fnn.Conv(self.features, kernel_size=(3, 3))(z1) 17 | z1 = fnn.BatchNorm(use_running_average=not self.training)(z1) 18 | z1 = fnn.relu(z1) 19 | z1_pool = fnn.max_pool(z1, window_shape=(2, 2), strides=(2, 2)) 20 | 21 | z2 = fnn.Conv(self.features * 2, kernel_size=(3, 3))(z1_pool) 22 | z2 = fnn.relu(z2) 23 | z2 = fnn.Conv(self.features * 2, kernel_size=(3, 3))(z2) 24 | z2 = fnn.BatchNorm(use_running_average=not self.training)(z2) 25 | z2 = fnn.relu(z2) 26 | z2_pool = fnn.max_pool(z2, window_shape=(2, 2), strides=(2, 2)) 27 | 28 | z3 = fnn.Conv(self.features * 4, kernel_size=(3, 3))(z2_pool) 29 | z3 = fnn.relu(z3) 30 | z3 = fnn.Conv(self.features * 4, kernel_size=(3, 3))(z3) 31 | z3 = fnn.BatchNorm(use_running_average=not self.training)(z3) 32 | z3 = fnn.relu(z3) 33 | z3_pool = fnn.max_pool(z3, window_shape=(2, 2), strides=(2, 2)) 34 | 35 | z4 = fnn.Conv(self.features * 8, kernel_size=(3, 3))(z3_pool) 36 | z4 = fnn.relu(z4) 37 | z4 = fnn.Conv(self.features * 8, kernel_size=(3, 3))(z4) 38 | z4 = fnn.BatchNorm(use_running_average=not self.training)(z4) 39 | z4 = fnn.relu(z4) 40 | z4_dropout = fnn.Dropout(0.5, deterministic=False)(z4) 41 | z4_pool = fnn.max_pool(z4_dropout, window_shape=(2, 2), strides=(2, 2)) 42 | 43 | z5 = fnn.Conv(self.features * 16, kernel_size=(3, 3))(z4_pool) 44 | z5 = fnn.relu(z5) 45 | z5 = fnn.Conv(self.features * 16, kernel_size=(3, 3))(z5) 46 | z5 = fnn.BatchNorm(use_running_average=not self.training)(z5) 47 | z5 = fnn.relu(z5) 48 | z5_dropout = fnn.Dropout(0.5, deterministic=False)(z5) 49 | 50 | return z1, z2, z3, z4_dropout, z5_dropout 51 | 52 | 53 | class Decoder(fnn.Module): 54 | features: int = 64 55 | training: bool = True 56 | 57 | @fnn.compact 58 | def __call__(self, z1, z2, z3, z4_dropout, z5_dropout): 59 | z6_up = jax.image.resize(z5_dropout, 60 | shape=(z5_dropout.shape[0], z5_dropout.shape[1] * 2, z5_dropout.shape[2] * 2, 61 | z5_dropout.shape[3]), 62 | method='nearest') 63 | z6 = fnn.Conv(self.features * 8, kernel_size=(2, 2))(z6_up) 64 | z6 = fnn.relu(z6) 65 | z6 = jnp.concatenate([z4_dropout, z6], axis=3) 66 | z6 = fnn.Conv(self.features * 8, kernel_size=(3, 3))(z6) 67 | z6 = fnn.relu(z6) 68 | z6 = fnn.Conv(self.features * 8, kernel_size=(3, 3))(z6) 69 | z6 = 
fnn.BatchNorm(use_running_average=not self.training)(z6) 70 | z6 = fnn.relu(z6) 71 | 72 | z7_up = jax.image.resize(z6, 73 | shape=(z6.shape[0], z6.shape[1] * 2, z6.shape[2] * 2, z6.shape[3]), 74 | method='nearest') 75 | z7 = fnn.Conv(self.features * 4, kernel_size=(2, 2))(z7_up) 76 | z7 = fnn.relu(z7) 77 | z7 = jnp.concatenate([z3, z7], axis=3) 78 | z7 = fnn.Conv(self.features * 4, kernel_size=(3, 3))(z7) 79 | z7 = fnn.relu(z7) 80 | z7 = fnn.Conv(self.features * 4, kernel_size=(3, 3))(z7) 81 | z7 = fnn.BatchNorm(use_running_average=not self.training)(z7) 82 | z7 = fnn.relu(z7) 83 | 84 | z8_up = jax.image.resize(z7, 85 | shape=(z7.shape[0], z7.shape[1] * 2, z7.shape[2] * 2, z7.shape[3]), 86 | method='nearest') 87 | z8 = fnn.Conv(self.features * 2, kernel_size=(2, 2))(z8_up) 88 | z8 = fnn.relu(z8) 89 | z8 = jnp.concatenate([z2, z8], axis=3) 90 | z8 = fnn.Conv(self.features * 2, kernel_size=(3, 3))(z8) 91 | z8 = fnn.relu(z8) 92 | z8 = fnn.Conv(self.features * 2, kernel_size=(3, 3))(z8) 93 | z8 = fnn.BatchNorm(use_running_average=not self.training)(z8) 94 | z8 = fnn.relu(z8) 95 | 96 | z9_up = jax.image.resize(z8, 97 | shape=(z8.shape[0], z8.shape[1] * 2, z8.shape[2] * 2, z8.shape[3]), 98 | method='nearest') 99 | z9 = fnn.Conv(self.features, kernel_size=(2, 2))(z9_up) 100 | z9 = fnn.relu(z9) 101 | z9 = jnp.concatenate([z1, z9], axis=3) 102 | z9 = fnn.Conv(self.features, kernel_size=(3, 3))(z9) 103 | z9 = fnn.relu(z9) 104 | z9 = fnn.Conv(self.features, kernel_size=(3, 3))(z9) 105 | z9 = fnn.BatchNorm(use_running_average=not self.training)(z9) 106 | z9 = fnn.relu(z9) 107 | 108 | y = fnn.Conv(1, kernel_size=(1, 1))(z9) 109 | y = fnn.sigmoid(y) 110 | 111 | return y 112 | 113 | 114 | class UNet(fnn.Module): 115 | features: int = 64 116 | training: bool = True 117 | 118 | @fnn.compact 119 | def __call__(self, x): 120 | z1, z2, z3, z4_dropout, z5_dropout = Encoder(self.training)(x) 121 | y = Decoder(self.training)(z1, z2, z3, z4_dropout, z5_dropout) 122 | 123 | return y 124 | 125 | 126 | class MixerBlock(eqx.Module): 127 | patch_mixer: eqx.nn.MLP 128 | hidden_mixer: eqx.nn.MLP 129 | norm1: eqx.nn.LayerNorm 130 | norm2: eqx.nn.LayerNorm 131 | 132 | def __init__(self, num_patches, hidden_size, mix_patch_size, mix_hidden_size, *, key): 133 | tkey, ckey = jr.split(key, 2) 134 | self.patch_mixer = eqx.nn.MLP(num_patches, num_patches, mix_patch_size, depth=1, key=tkey) 135 | self.hidden_mixer = eqx.nn.MLP(hidden_size, hidden_size, mix_hidden_size, depth=1, key=ckey) 136 | self.norm1 = eqx.nn.LayerNorm((hidden_size, num_patches)) 137 | self.norm2 = eqx.nn.LayerNorm((num_patches, hidden_size)) 138 | 139 | def __call__(self, y): 140 | y = y + jax.vmap(self.patch_mixer)(self.norm1(y)) 141 | y = einops.rearrange(y, "c p -> p c") 142 | y = y + jax.vmap(self.hidden_mixer)(self.norm2(y)) 143 | y = einops.rearrange(y, "p c -> c p") 144 | return y 145 | 146 | 147 | class Mixer2d(eqx.Module): 148 | conv_in: eqx.nn.Conv2d 149 | conv_out: eqx.nn.ConvTranspose2d 150 | blocks: list 151 | norm: eqx.nn.LayerNorm 152 | t1: float 153 | 154 | def __init__(self, img_size, patch_size, hidden_size, mix_patch_size, mix_hidden_size, num_blocks, t1, *, key, ): 155 | input_size, height, width = img_size 156 | assert (height % patch_size) == 0 157 | assert (width % patch_size) == 0 158 | num_patches = (height // patch_size) * (width // patch_size) 159 | inkey, outkey, *bkeys = jr.split(key, 2 + num_blocks) 160 | 161 | self.conv_in = eqx.nn.Conv2d(input_size + 1, hidden_size, patch_size, stride=patch_size, key=inkey) 162 | 
self.conv_out = eqx.nn.ConvTranspose2d(hidden_size, input_size, patch_size, stride=patch_size, key=outkey) 163 | self.blocks = [MixerBlock(num_patches, hidden_size, mix_patch_size, mix_hidden_size, key=bkey) for bkey in 164 | bkeys] 165 | self.norm = eqx.nn.LayerNorm((hidden_size, num_patches)) 166 | self.t1 = t1 167 | 168 | def __call__(self, t, y): 169 | t = t / self.t1 170 | _, height, width = y.shape 171 | t = einops.repeat(t, "-> 1 h w", h=height, w=width) 172 | y = jnp.concatenate([y, t]) 173 | y = self.conv_in(y) 174 | _, patch_height, patch_width = y.shape 175 | y = einops.rearrange(y, "c h w -> c (h w)") 176 | for block in self.blocks: 177 | y = block(y) 178 | y = self.norm(y) 179 | y = einops.rearrange(y, "c (h w) -> c h w", h=patch_height, w=patch_width) 180 | return self.conv_out(y) 181 | -------------------------------------------------------------------------------- /ScoreBasedGenerativeModelling/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -J train 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 5 | #SBATCH --gres=gpu:1 # per node 6 | #SBATCH -p gpu2 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH --output=/homefs/home/winklep4/Work/JaxLightning/pcluster_logs/%A/%A.%a-output.txt 9 | #SBATCH --error=/homefs/home/winklep4/Work/JaxLightning/pcluster_logs/%A/%A.%a-error.txt 10 | ### SBATCH --time 1-00:00:00 11 | #SBATCH --signal=SIGUSR1@90 12 | 13 | . ~/.bashrc 14 | nvidia-smi 15 | 16 | eval "$(micromamba shell hook --shell bash)" 17 | micromamba activate jax 18 | export WANDB__SERVICE_WAIT=300 19 | export HYDRA_FULL_ERROR=1 20 | 21 | wandb login --host https://genentech.wandb.io 22 | wandb login --relogin 23 | wandb artifact cache cleanup 50G 24 | 25 | echo "SLURM_JOB_ID = ${SLURM_JOB_ID}" 26 | echo "${PWD}" 27 | 28 | ### Set the wandb agent command ------------------------------- 29 | # Set the wandb agent command 30 | WANDB_AGENT_CMD=" 31 | wandb agent ludwig-winkler/protein-correction_testing/3hngzz7v 32 | " 33 | 34 | # Check if an argument is provided 35 | if [ -n "$1" ]; then 36 | WANDB_AGENT_CMD="$WANDB_AGENT_CMD --count $1" 37 | fi 38 | 39 | # Run the wandb agent command 40 | $WANDB_AGENT_CMD 41 | 42 | ### Set the wandb agent command ------------------------------- 43 | 44 | python main.py -------------------------------------------------------------------------------- /assets/InANutshell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/InANutshell.png -------------------------------------------------------------------------------- /assets/automaticoptimization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/automaticoptimization.png -------------------------------------------------------------------------------- /assets/code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/code.png -------------------------------------------------------------------------------- /assets/dataloader.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/dataloader.png -------------------------------------------------------------------------------- /assets/now_kiss.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/now_kiss.jpeg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohappyeyeballs==2.4.3 3 | aiohttp==3.10.10 4 | aiosignal==1.3.1 5 | antlr4-python3-runtime==4.9.3 6 | attrs==24.2.0 7 | certifi==2024.8.30 8 | charset-normalizer==3.4.0 9 | chex==0.1.87 10 | click==8.1.7 11 | contourpy==1.3.0 12 | cycler==0.12.1 13 | diffrax==0.6.0 14 | docker-pycreds==0.4.0 15 | einops==0.8.0 16 | equinox==0.11.8 17 | etils==1.10.0 18 | filelock==3.16.1 19 | flax==0.10.1 20 | fonttools==4.54.1 21 | frozenlist==1.5.0 22 | fsspec==2024.10.0 23 | gitdb==4.0.11 24 | GitPython==3.1.43 25 | humanize==4.11.0 26 | hydra-core==1.3.2 27 | idna==3.10 28 | importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1726082825846/work 29 | importlib_resources==6.4.5 30 | iniconfig==2.0.0 31 | jax @ file:///home/conda/feedstock_root/build_artifacts/jax_1729336343576/work 32 | jaxlib==0.4.34 33 | jaxtyping==0.2.34 34 | Jinja2==3.1.4 35 | kiwisolver==1.4.7 36 | lightning==2.4.0 37 | lightning-utilities==0.11.8 38 | lineax==0.0.7 39 | markdown-it-py==3.0.0 40 | MarkupSafe==3.0.2 41 | matplotlib==3.9.2 42 | mdurl==0.1.2 43 | ml-dtypes @ file:///home/conda/feedstock_root/build_artifacts/ml_dtypes_1726376268746/work 44 | mpmath==1.3.0 45 | msgpack==1.1.0 46 | multidict==6.1.0 47 | nest-asyncio==1.6.0 48 | networkx==3.4.2 49 | numpy==1.26.4 50 | nvidia-cublas-cu12==12.4.5.8 51 | nvidia-cuda-cupti-cu12==12.4.127 52 | nvidia-cuda-nvrtc-cu12==12.4.127 53 | nvidia-cuda-runtime-cu12==12.4.127 54 | nvidia-cudnn-cu12==9.1.0.70 55 | nvidia-cufft-cu12==11.2.1.3 56 | nvidia-curand-cu12==10.3.5.147 57 | nvidia-cusolver-cu12==11.6.1.9 58 | nvidia-cusparse-cu12==12.3.1.170 59 | nvidia-nccl-cu12==2.21.5 60 | nvidia-nvjitlink-cu12==12.4.127 61 | nvidia-nvtx-cu12==12.4.127 62 | omegaconf==2.3.0 63 | opt_einsum @ file:///home/conda/feedstock_root/build_artifacts/opt_einsum_1727392354687/work 64 | optax==0.2.3 65 | optimistix==0.0.9 66 | orbax-checkpoint==0.8.0 67 | packaging==24.1 68 | pillow==11.0.0 69 | platformdirs==4.3.6 70 | propcache==0.2.0 71 | protobuf==5.28.3 72 | psutil==6.1.0 73 | Pygments==2.18.0 74 | pyparsing==3.2.0 75 | python-dateutil==2.9.0.post0 76 | pytorch-lightning==2.4.0 77 | PyYAML==6.0.2 78 | requests==2.32.3 79 | rich==13.9.4 80 | scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy-split_1729480690820/work/dist/scipy-1.14.1-cp311-cp311-linux_x86_64.whl#sha256=1d6b83e78fb8e82dc54a9de4e73aed118775fa22ccdc582469171f891aa3e912 81 | sentry-sdk==2.17.0 82 | setproctitle==1.3.3 83 | six==1.16.0 84 | smmap==5.0.1 85 | sympy==1.13.1 86 | tensorstore==0.1.67 87 | toolz==1.0.0 88 | torch==2.5.1 89 | torchmetrics==1.5.1 90 | torchvision==0.20.1 91 | tqdm==4.66.6 92 | triton==3.1.0 93 | typeguard==2.13.3 94 | typing_extensions==4.12.2 95 | urllib3==2.2.3 96 | wandb==0.18.5 97 | yarl==1.17.1 98 | zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1726248574750/work 99 | 
--------------------------------------------------------------------------------