├── .gitignore
├── .vscode
│   └── launch.json
├── BNN
│   └── JaxLightning_BNN.py
├── LICENSE
├── README.md
├── ScoreBasedGenerativeModelling
│   ├── configs
│   │   └── config.yaml
│   ├── main.py
│   ├── src
│   │   ├── ScoreBased_Data.py
│   │   ├── ScoreBased_Hyperparameters.py
│   │   └── ScoreBased_Models.py
│   └── submit.sh
├── assets
│   ├── InANutshell.png
│   ├── automaticoptimization.png
│   ├── code.png
│   ├── dataloader.png
│   └── now_kiss.jpeg
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | #Local files
10 | **/wandb/**
11 | **/lightning_logs/**
12 | **/data/**
13 | **/checkpoints/**
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | **/wandb/**
42 | **/lightning_logs/**
43 | **/data/**
44 | **/checkpoints/**
45 | **/outputs/**
46 | **/pcluster_logs/**
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 | cover/
66 |
67 | # Translations
68 | *.mo
69 | *.pot
70 |
71 | # Django stuff:
72 | *.log
73 | local_settings.py
74 | db.sqlite3
75 | db.sqlite3-journal
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | .pybuilder/
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | # For a library or package, you might want to ignore these files since the code is
100 | # intended to run in multiple environments; otherwise, check them in:
101 | # .python-version
102 |
103 | # pipenv
104 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
105 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
106 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
107 | # install all needed dependencies.
108 | #Pipfile.lock
109 |
110 | # poetry
111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
112 | # This is especially recommended for binary packages to ensure reproducibility, and is more
113 | # commonly ignored for libraries.
114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
115 | #poetry.lock
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python Debugger: Current File",
9 | "type": "debugpy",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal"
13 | }
14 | ]
15 | }
--------------------------------------------------------------------------------
/BNN/JaxLightning_BNN.py:
--------------------------------------------------------------------------------
1 | import torch, numpy as np
2 | from torch import Tensor
3 | from torch.utils.data import DataLoader, TensorDataset, Dataset
4 | import pytorch_lightning as pl
5 | from pytorch_lightning.trainer import Trainer
7 | import matplotlib.pyplot as plt
8 | from typing import Tuple
9 | import einops
10 | import jax, jax.numpy as jnp
11 | import optax, equinox as eqx
12 |
13 |
14 | def numpy_collate(batch):
15 | if isinstance(batch[0], np.ndarray):
16 | return np.stack(batch)
17 | elif isinstance(batch[0], (tuple, list)):
18 | transposed = zip(*batch)
19 | return [numpy_collate(samples) for samples in transposed]
20 | else:
21 | return np.array(batch)
22 |
23 |
24 | class NumpyTensorDataset(Dataset):
25 | r"""Dataset wrapping tensors.
26 |
27 | Each sample will be retrieved by indexing tensors along the first dimension.
28 |
29 | Args:
30 | *tensors (Tensor): tensors that have the same size of the first dimension.
31 | """
32 |
33 | tensors: Tuple[Tensor, ...]
34 |
35 | def __init__(self, *tensors: Tensor) -> None:
36 | """
37 | PyTorch: tensor.size() -> tuple -> tensor.size(0) gives batch dim
38 | Numpy: nparray.size() is not indexable, so we use shape
39 | """
40 | assert all(
41 | tensors[0].shape[0] == tensor.shape[0] for tensor in tensors
42 | ), "Size mismatch between tensors"
43 | self.tensors = tensors
44 |
45 | def __getitem__(self, index):
46 | return tuple(tensor[index] for tensor in self.tensors)
47 |
48 | def __len__(self):
49 | return self.tensors[0].shape[0]
50 |
51 |
52 | class RegressionDataModule(pl.LightningDataModule):
53 | num_samples = 200
54 | x_noise_std = 0.01
55 | y_noise_std = 0.2
56 |
57 | def __init__(self):
58 | super().__init__()
59 |
60 | def setup(self, stage: str) -> None:
61 | x = jnp.linspace(-0.35, 0.55, self.num_samples)
62 |
63 | x_noise = np.random.normal(0.0, self.x_noise_std, size=x.shape)
64 |
65 | std = np.linspace(0, self.y_noise_std, self.num_samples) # * y_noise_std
66 | # print(std.shape)
67 | non_stationary_noise = np.random.normal(loc=0, scale=std)
68 | y_noise = non_stationary_noise
69 |
70 | y = (
71 | x
72 | + 0.3 * jnp.sin(2 * jnp.pi * (x + x_noise))
73 | + 0.3 * jnp.sin(4 * jnp.pi * (x + x_noise))
74 | + y_noise
75 | )
76 |
77 | x, y = x.reshape(-1, 1), y.reshape(-1, 1)
78 |
79 | x = (x - x.mean(axis=0)) / x.std(axis=0)
80 |         y = (y - y.mean(axis=0)) / y.std(axis=0)
81 |
82 | self.X = np.array(x)
83 | self.Y = np.array(y)
84 |
85 | def train_dataloader(self):
86 | return DataLoader(
87 | NumpyTensorDataset(self.X, self.Y),
88 | shuffle=True,
89 | batch_size=16,
90 | collate_fn=numpy_collate,
91 | drop_last=False,
92 | )
93 |
94 | def val_dataloader(self):
95 | return DataLoader(
96 |             NumpyTensorDataset(self.X[:100], self.Y[:100]),
97 | shuffle=True,
98 | batch_size=32,
99 | collate_fn=numpy_collate,
100 | )
101 |
102 |
103 | class BayesLinear(eqx.Module):
104 | weight_mu: jax.numpy.ndarray
105 | weight_rho: jax.numpy.ndarray
106 | bias_mu: jax.numpy.ndarray
107 | bias_rho: jax.numpy.ndarray
108 |
109 | def __init__(self, in_size, out_size, key):
110 | wkey, bkey = jax.random.split(key)
111 | self.weight_mu = (
112 | jax.random.normal(wkey, (out_size, in_size)) / (out_size + in_size) ** 0.5
113 | )
114 | self.weight_rho = jax.numpy.full(shape=self.weight_mu.shape, fill_value=-3)
115 | self.bias_mu = jax.random.normal(bkey, (out_size,)) * 0.1
116 | self.bias_rho = jax.numpy.full(shape=self.bias_mu.shape, fill_value=-3)
117 |
118 |     def __call__(self, x, key):
119 |         wkey, bkey = jax.random.split(key)  # independent noise for weights and bias
120 |         w_eps = jax.random.normal(key=wkey, shape=self.weight_mu.shape)
121 |         b_eps = jax.random.normal(key=bkey, shape=self.bias_mu.shape)
122 |         w = self.weight_mu + jax.nn.softplus(self.weight_rho) * w_eps
123 |         b = self.bias_mu + jax.nn.softplus(self.bias_rho) * b_eps
124 |
125 | return x @ w.T + b
126 |
127 | def kl_div(self):
128 | weight_scale = jax.nn.softplus(self.weight_rho)
129 | kl_div = jnp.log(1.0) - jnp.log(weight_scale)
130 | kl_div += (weight_scale**2 + (self.weight_mu - 0) ** 2) / (2) - 0.5
131 | return kl_div.sum()
132 |
133 |
134 | class BNN(eqx.Module):
135 | bnn: eqx.nn.Sequential
136 |
137 |     def __init__(self, key):
138 |         hidden = 51
139 |         keys = jax.random.split(key, 5)  # distinct init key per BayesLinear layer
140 |         self.bnn = eqx.nn.Sequential(
141 |             [
142 |                 BayesLinear(1, hidden, keys[0]),
143 |                 eqx.nn.Lambda(jax.nn.gelu),
144 |                 BayesLinear(hidden, hidden, keys[1]),
145 |                 eqx.nn.Lambda(jax.nn.gelu),
146 |                 BayesLinear(hidden, hidden, keys[2]),
147 |                 eqx.nn.Lambda(jax.nn.gelu),
148 |                 BayesLinear(hidden, hidden, keys[3]),
149 |                 eqx.nn.Lambda(jax.nn.gelu),
150 |                 BayesLinear(hidden, 1, keys[4]),
151 |             ]
152 |         )
152 |
153 | def __call__(self, x, key):
154 | return self.bnn(x, key=key)
155 |
156 | def kl_div(self):
157 | kl_div = 0
158 | for layer in self.bnn.layers:
159 | if type(layer) == BayesLinear:
160 | kl_div += layer.kl_div()
161 | return kl_div
162 |
163 |
164 | class JaxLightning(pl.LightningModule):
165 | def __init__(self):
166 | super().__init__()
167 | self.automatic_optimization = False
168 | self.key = jax.random.PRNGKey(1)
169 | self.MC = 5
170 |
171 | self.key, subkey = jax.random.split(self.key)
172 | self.bnn = BNN(key=subkey)
173 | self.global_step_ = 0
174 |
175 | @property
176 | def global_step(self):
177 | """
178 | self.global_step is an attribute without setter and is updated somewhere deep within Lightning
179 | Simply overwrite global_step as a property to access it normally in the LightningModule
180 | :return:
181 | """
182 | return self.global_step_
183 |
184 | def on_fit_start(self) -> None:
185 | self.num_data_samples = self.trainer.datamodule.num_samples
186 | self.viz_network("On Train Start")
187 |
188 | def training_step(self, batch):
189 | """Standard PyTorch Lightning training step ... but with Jax in it!"""
190 | self.global_step_ += 1
191 | data, target = (
192 | jnp.array(batch[0].reshape(-1, *batch[0].shape[1:])),
193 | jnp.array(batch[1]),
194 | )
195 | data = einops.repeat(data, "... -> b ...", b=self.MC)
196 | target = einops.repeat(target, "... -> b ...", b=self.MC)
197 |
198 | self.key, *subkeys = jax.random.split(
199 | self.key, num=self.MC + 1
200 | ) # creating new keys
201 | subkeys = jnp.stack(subkeys)
202 | """
203 | Jax in the middle of a Lightning module
204 | Call static gradient method from same PL module
205 | Calls in turn cost function doing the jit compiled forward, backward and gradient update step
206 | """
207 | loss, metrics, self.bnn, self.optim, self.opt_state = JaxLightning.make_step(
208 | self.bnn,
209 | data,
210 | target,
211 | self.num_data_samples,
212 | subkeys,
213 | self.optim,
214 | self.opt_state,
215 | )
216 |         """All the logging and perks you love about Lightning"""
217 |         logs = {
218 |             "Loss": loss.item(),
219 |             "global_step": self.global_step,
220 |             "current_epoch": self.current_epoch,
221 |         }
222 |         self.log_dict(logs, prog_bar=True)
223 |         return logs
224 |
225 | @staticmethod
226 | @eqx.filter_value_and_grad(has_aux=True)
227 | def criterion(model, x, y, num_samples, keys):
228 | """Jit-able criterion function including forward pass"""
229 | model_vmap = jax.vmap(
230 | model, in_axes=(0, 0)
231 | ) # takes [MC, B, F] features and [MC] keys
232 | pred = model_vmap(x, keys)
233 | assert pred.ndim == 3
234 | std = jax.lax.stop_gradient(pred.std(axis=0))
235 | mse = (y - pred.mean(axis=0)) ** 2
236 | nll = -(-0.5 * mse / std**2).sum() * num_samples
237 | kl = 1.0 * model.kl_div()
238 | return (nll + kl, {"nll": nll, "kl": kl, "std": std.mean(), "mse": mse.mean()})
239 |
240 | @staticmethod
241 | @eqx.filter_jit
242 | def make_step(model, x, y, num_samples, keys, optim, opt_state):
243 | """Jit-able gradient and parameter update"""
244 | (loss, metrics), grads = JaxLightning.criterion(model, x, y, num_samples, keys)
245 | updates, opt_state = optim.update(grads, opt_state)
246 | model = eqx.apply_updates(model, updates)
247 | return loss, metrics, model, optim, opt_state
248 |
249 | def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
250 | """
251 | In automatic_optimization=False mode, the global_step attribute is (for some reason) not automatically incremented, so we have to use our own should_stop criterion
252 | """
253 | if self.global_step >= self.trainer.max_steps and self.trainer.max_steps > 0:
254 | self.trainer.should_stop = True
255 |
256 | def on_train_end(self) -> None:
257 | self.viz_network(title="On Fit End")
258 |
259 | def viz_network(self, title=""):
260 | MC = 201
261 |
262 | self.key, *subkeys = jax.random.split(self.key, num=MC + 1) # creating new keys
263 | subkeys = jnp.stack(subkeys)
264 | pred_model = jax.jit(jax.vmap(self.bnn, in_axes=(0, 0)))
265 |
266 | viz_input = jnp.linspace(-4, 4, 100).reshape(-1, 1)
267 | x_mc = einops.repeat(viz_input, "... -> b ...", b=MC)
268 | print(x_mc.shape, len(subkeys))
269 | mc_pred = pred_model(x_mc, subkeys)
270 |
271 | fig = plt.figure(figsize=(10, 10))
272 | for pred in mc_pred.squeeze():
273 | plt.plot(
274 | viz_input, pred, color="red", alpha=1 / min(mc_pred.shape[0], 50.0)
275 | )
276 | X, Y = self.trainer.datamodule.X, self.trainer.datamodule.Y
277 | plt.scatter(X, Y, alpha=1 / min(mc_pred.shape[0], 1.0), s=5)
278 | plt.ylim(-3, 3)
279 | plt.title(title)
280 | plt.show()
281 |
282 | def configure_optimizers(self):
283 | self.optim = optax.adam(0.001)
284 | self.opt_state = self.optim.init(eqx.filter(self.bnn, eqx.is_array))
285 |
286 |
287 | print(jax.devices())
288 |
289 | bnn = JaxLightning()
290 | dm = RegressionDataModule()
291 | trainer = Trainer(max_steps=4000)
292 | trainer.fit(bnn, dm)
293 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 ludwigwinkler
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JaxLightning
2 | PyTorch Lightning + Jax = nice
3 |
4 | # PyTorch Lightning
5 |
6 | The package has become the go-to standard for ML research with PyTorch.
7 | Above all, it removes the boilerplate code that you usually have to write to kick off a simple experiment.
8 | Additionally, you get amazing logging, a clean general code structure, data management via LightningDataModules and other templates that make quick iteration a breeze.
9 |
10 | # Jax
11 |
12 | Recent packages such as Equinox and Treex are, at the top level, very similar in structure and handling to PyTorch.
13 | This makes the code very readable and succinct.
14 | The biggest advantage of Jax is probably its clean functional programming (I've come around to that) and its speed.
15 | Vmap, derivatives in all directions and automatic accelerator management (no more tensor.to(deviceXYZ)) are also part of the gift box.
16 | Also, the explicit random keys remove an entire class of possible problems.
17 |
18 | You can find a speed comparison in the deep learning course of [UvA](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html).
19 | The takeaway of that PyTorch vs Jax comparison is that Jax excels at compiling forward and backward passes consisting of lots of SIMD (single instruction multiple data) instructions, such as separate calls of small-ish kernel convolutions on locally independent data.
20 | The speed-up can be considerable at **2.5X-3.4X**, but is expected to regress to PyTorch's performance once the time spent executing those calls grows large relative to the overhead of launching them.
21 | That means that for large batch sizes, performance is dominated by single instructions on very large tensors; that depends on the hardware, and the XLA compiler can do little about it.
22 |
23 | # Can we get the best of both worlds? Yes, we can[1].
24 |
25 | References: [1] 44th President of the United States Barack Obama
26 |
27 | You can run any Jax code and thus any Jax neural network package inside of PyTorch Lightning, be it written in Flax/Haiku/Equinox/Treex and optimized with the common optimization libraries.
28 |
29 | The necessary steps are:
30 |
31 | **1. Turn off automatic optimization via `automatic_optimization=False` ... we gotta do that stuff ourselves.**
32 |
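A minimal sketch of step 1, following the `JaxLightning` module in `BNN/JaxLightning_BNN.py` of this repo (only the constructor is shown):

```python
import jax
import pytorch_lightning as pl


class JaxLightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # tell Lightning that we compute gradients and apply updates ourselves
        self.automatic_optimization = False
        # explicit JAX randomness instead of a global torch seed
        self.key = jax.random.PRNGKey(1)
```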
33 |
34 |
35 | **2. Load data strictly in 'numpy' mode by modifying the `collate_fn` of the PyTorch DataLoader.**
36 |
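Step 2 boils down to a `collate_fn` that stacks NumPy arrays instead of building `torch.Tensor`s; this is the `numpy_collate` helper used by both examples in this repo (`make_loader` below is just an illustrative wrapper):

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


def numpy_collate(batch):
    # recursively stack samples into NumPy arrays instead of torch.Tensors
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        return [numpy_collate(samples) for samples in zip(*batch)]
    else:
        return np.array(batch)


def make_loader(dataset: Dataset) -> DataLoader:
    # any torch Dataset works, as long as __getitem__ returns NumPy arrays
    return DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=numpy_collate)
```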
37 |
38 |
39 | **3. Run the forward, backward and gradient update step within Lightning with `@staticmethod` decorators.**
40 |
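Step 3 in a nutshell, condensed from the BNN example: the loss and the update step are pure `@staticmethod`s (no `self`), so Equinox can jit-compile the whole forward/backward/update chain. The squared-error loss here is a stand-in for whatever criterion your model actually uses:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import pytorch_lightning as pl


class JaxLightning(pl.LightningModule):
    # __init__, training_step, configure_optimizers etc. omitted

    @staticmethod
    @eqx.filter_value_and_grad
    def criterion(model, x, y):
        # pure function: JAX can trace and differentiate it
        pred = jax.vmap(model)(x)
        return jnp.mean((pred - y) ** 2)

    @staticmethod
    @eqx.filter_jit
    def make_step(model, x, y, optim, opt_state):
        # forward, backward and optimizer update inside one jitted call
        loss, grads = JaxLightning.criterion(model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, optim, opt_state
```

`training_step` then simply calls `JaxLightning.make_step(...)` with the current batch and reassigns the returned `model` and `opt_state` to `self`.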
41 |
42 |
43 | Done. Now we got the best of Jax with the best of Lightning. :)
44 |
45 | ### Tensors vs Arrays
46 |
47 | The main idea of combining the great and convenient code structure of PyTorch Lightning with the versatility of Jax is to restrict PyTorch Lightning to pure Numpy/Jax.Numpy until the data 'reaches' the Jax model.
48 | Therefore we can reuse almost all DataModules and Datasets and just remove the single line where data is cast to torch.Tensors.
49 | Thus the dataloaders/datamodules etc. are restricted to Numpy/Jax.Numpy operations, as sketched below.
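For example, the MNIST dataset used for the score-based model only overrides `__getitem__` to return NumPy arrays (see `src/ScoreBased_Data.py`; the class name `NumpyMNIST` below is just for illustration):

```python
import torchvision.datasets


class NumpyMNIST(torchvision.datasets.MNIST):
    """MNIST that hands out NumPy arrays instead of torch.Tensors."""

    def __getitem__(self, index: int):
        img, target = self.data[index], int(self.targets[index])
        # the single cast that keeps torch.Tensors out of the data pipeline
        img = img.numpy()
        if self.transform is not None:
            img = self.transform(img)
        return img, target
```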
50 |
51 | ### Optimization
52 |
53 | Secondly, we can't use PyTorch Lightning's automatic optimization, which is what makes setting up experiments in PL so convenient.
54 | But at the same time, Jax handles device placement automatically and moves arrays to the correct devices for us.
55 | Thus by simply setting the class variable `automatic_optimization=False` we gain complete control over all gradient computations and gradient descent optimization and tell PL that we'll do our optimization on our own.
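Concretely, `configure_optimizers` then just builds an optax optimizer and its state and stores both as attributes, roughly as both examples in this repo do (assuming the Equinox model is stored as `self.model`):

```python
import equinox as eqx
import optax
import pytorch_lightning as pl


class JaxLightning(pl.LightningModule):
    # ...

    def configure_optimizers(self):
        # nothing is returned to Lightning; we hold the optax state ourselves
        self.optim = optax.adam(3e-4)
        self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_array))
```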
56 |
57 | Since Jax requires pure functions, all we have to do is make the forward step a `@staticmethod` without the `self` argument.
58 | Similarly, we can create a static gradient function in the same way.
59 |
60 | Question: What do we gain?
61 |
62 | Answer: We can jit-compile the entire forward and backward pass with **JAX** with a simple decorator inside the training setup of **Pytorch Lightning**.
63 |
64 | Thus PyTorch Lightning takes care of all the data set management, the logging, the tracking and the overall training loop structure with all the convenience PL is famous for, and Jax does the fast computing inside of PL.
65 |
66 | Everybody wins ...
67 |
68 | ### Examples
69 |
70 | There are two examples of how I used PyTorch Lightning with Jax: my own Bayesian Neural Networks and Score-Based Generative Modelling ([using code by Patrick Kidger](https://docs.kidger.site/equinox/examples/score_based_diffusion/)).
71 |
72 | 
73 |
--------------------------------------------------------------------------------
/ScoreBasedGenerativeModelling/configs/config.yaml:
--------------------------------------------------------------------------------
1 | logging: online
2 | project: jax_score_based
3 | load_last_checkpoint: False
4 | show: True
5 | seed: 12345
6 | fast_dev_run: 0
7 | save_figs: False
8 | watch_model_training: False
9 | hparams_id: None
10 | experiment: mnist
--------------------------------------------------------------------------------
/ScoreBasedGenerativeModelling/main.py:
--------------------------------------------------------------------------------
1 | import sys, argparse
2 | import pathlib
3 | from pathlib import Path
4 | import matplotlib.pyplot as plt
5 | import wandb
6 | import jax, jax.numpy as jnp, equinox as eqx, optax, jax.random as jr
7 | import diffrax as dfx
8 | import functools as ft
9 | import pytorch_lightning as pl
10 | from pytorch_lightning.trainer import Trainer
11 | import torch, einops
12 |
13 | from pytorch_lightning.loggers import WandbLogger
14 |
15 | import hydra
16 | from omegaconf import DictConfig
17 |
18 |
19 | # filepath = Path().resolve()
20 | # phd_path = Path(str(filepath).split('PhD')[0] + 'PhD')
21 | # jax_path = Path(str(filepath).split('Jax')[0] + 'Jax')
22 | # score_path = Path(str(filepath).split('Jax')[0] + 'Jax/ScoreBasedGenerativeModelling')
23 | # sys.path.append(str(phd_path))
24 | # sys.path.append(str(jax_path))
25 | from src.ScoreBased_Data import MNISTDataModule
26 | from src.ScoreBased_Models import Mixer2d
27 | from src.ScoreBased_Hyperparameters import str2bool, process_hparams
28 |
29 |
30 | class JaxLightning(pl.LightningModule):
31 | def __init__(self, **kwargs):
32 | super().__init__()
33 | self.save_hyperparameters()
34 | self.automatic_optimization = False
35 |
36 | self.key = jax.random.PRNGKey(1)
37 | self.key, self.model_key, self.train_key, self.loader_key, self.sample_key = (
38 | jax.random.split(self.key, 5)
39 | )
40 |
41 | self.key, subkey = jax.random.split(self.key)
42 | self.model = Mixer2d(
43 | (1, 28, 28),
44 | patch_size=4,
45 | hidden_size=64,
46 | mix_patch_size=512,
47 | mix_hidden_size=512,
48 | num_blocks=4,
49 | t1=10,
50 | key=subkey,
51 | )
52 |
53 | self.t1 = 10.0
54 | self.int_beta = lambda t: t
55 | self.weight = lambda t: 1 - jnp.exp(-self.int_beta(t))
56 | self.dt0 = 0.1
57 | self.samples = 10
58 |
59 | self.global_step_ = 0
60 |
61 | self.configure_optimizers()
62 |
63 | def on_fit_start(self) -> None:
64 | pathlib.Path("checkpoints").mkdir(exist_ok=True)
65 |
66 | if self.hparams.load_last_checkpoint:
67 |             try:
68 |                 self.model = eqx.tree_deserialise_leaves(
69 |                     "checkpoints/ScoreBased/last.eqx", self.model
70 |                 )
71 |                 print("Loaded weights")
72 |             except Exception:
73 |                 print("Could not load weights")
74 |
75 | self.logger.log_image(
76 | key="Samples P1",
77 | images=[wandb.Image(self.sample(), caption="Samples P1")],
78 | )
79 |
80 | # def on_fit_end(self):
81 | # pathlib.Path.mkdir(f"checkpoints/ScoreBased", parents=True, exist_ok=True)
82 | # eqx.tree_serialise_leaves(f"checkpoints/ScoreBased/last.eqx", self.model)
83 |
84 | def training_step(self, batch):
85 | data = batch[0]
86 | value, self.model, self.train_key, self.opt_state = JaxLightning.make_step(
87 | self.model,
88 | self.weight,
89 | self.int_beta,
90 | data,
91 | self.t1,
92 | self.train_key,
93 | self.opt_state,
94 | self.optim.update,
95 | )
96 | dict_ = {"loss": torch.scalar_tensor(value.item())}
97 | self.log_dict(dict_, prog_bar=True)
98 | self.global_step_ += 1
99 | return dict_
100 |
101 | def sample(self):
102 | self.sample_key, *sample_key = jr.split(self.sample_key, self.samples**2 + 1)
103 | sample_key = jnp.stack(sample_key)
104 | sample_fn = ft.partial(
105 | JaxLightning.single_sample_fn,
106 | self.model,
107 | self.int_beta,
108 | (1, 28, 28),
109 | self.dt0,
110 | self.t1,
111 | )
112 | sample = jax.vmap(sample_fn)(sample_key)
113 | # sample = data_mean + data_std * sample
114 | # sample = jnp.clip(sample, data_min, data_max)
115 | sample = einops.rearrange(
116 | sample, "(n1 n2) 1 h w -> (n1 h) (n2 w)", n1=self.samples, n2=self.samples
117 | )
118 | fig = plt.figure()
119 | plt.imshow(sample, cmap="Greys")
120 | plt.axis("off")
121 | plt.title(f"{self.global_step_}")
122 | plt.tight_layout()
123 | if self.hparams.show:
124 | plt.show()
125 | return fig
126 |
127 | def validation_step(self, batch):
128 |         data = batch[0]
129 |         self.train_key, eval_key = jr.split(self.train_key)
130 |         # only evaluate the loss; no gradient update on validation data
131 |         value = JaxLightning.batch_loss_fn(
132 |             self.model,
133 |             self.weight,
134 |             self.int_beta,
135 |             data,
136 |             self.t1,
137 |             eval_key,
138 |         )
139 |         # dict_ = {"loss": torch.scalar_tensor(value.item())}
140 | self.log("Val_Loss", jnp.asarray(value).item(), prog_bar=True, batch_size=1)
141 |
142 | def on_validation_epoch_end(self) -> None:
143 | # pathlib.Path.mkdir(f"checkpoints/ScoreBased", parents=True, exist_ok=True)
144 | # eqx.tree_serialise_leaves(f"checkpoints/ScoreBased/last.eqx", self.model)
145 | self.logger.log_image(
146 | key="Samples P1",
147 | images=[wandb.Image(self.sample(), caption="Samples P1")],
148 | )
149 |
150 | def configure_optimizers(self):
151 | self.optim = optax.adam(3e-4)
152 | self.opt_state = self.optim.init(eqx.filter(self.model, eqx.is_inexact_array))
153 |
154 | @staticmethod
155 | @eqx.filter_jit
156 | def single_sample_fn(model, int_beta, data_shape, dt0, t1, key):
157 | """
158 | Sampling a single trajectory starting from normal noise at t1 and recovering data distribution at t0
159 | :param model:
160 | :param int_beta:
161 | :param data_shape:
162 | :param dt0:
163 | :param t1:
164 | :param key:
165 | :return:
166 | """
167 |
168 | def drift(t, y, args):
169 | """
170 | compute time derivative of function dß(t)/dt
171 | Noising SDE: dx(t) = -1/2 ß(t) x(t) dt + ß(t)^1/2 dW_t -> μ(x(t)) = - 1/2 ß(t) x(t) dt and σ^2 = ß(t)
172 | Reverse SDE: μ(x(tau)) = 1/2 ß(t) x(t) + ß(t) ∇ log p
173 | :param t:
174 | :param y:
175 | :param args:
176 | :return:
177 | """
178 | t = jnp.array(t)
179 | _, beta = jax.jvp(fun=int_beta, primals=(t,), tangents=(jnp.ones_like(t),))
180 | return (
181 | -0.5 * beta * (y + model(t, y))
182 | ) # negative because we use -dt0 when solving
183 |
184 | term = dfx.ODETerm(drift)
185 | solver = dfx.Tsit5()
186 | t0 = 0
187 | y1 = jr.normal(
188 | key, data_shape
189 | ) # noise at t1, from which integrate backwards to data distribution
190 | # reverse time, solve from t1 to t0
191 | sol = dfx.diffeqsolve(
192 | terms=term,
193 | solver=solver,
194 | t0=t1,
195 | t1=t0,
196 | dt0=-dt0,
197 | y0=y1,
198 | # adjoint=dfx.NoAdjoint(),
199 | )
200 | return sol.ys[0]
201 |
202 | @staticmethod
203 | def single_loss_fn(model, weight, int_beta, data, t, key):
204 | """
205 | OU process provides analytical mean and variance
206 | int_beta(t) = ß = θ
207 | E[X_t] = μ + exp[-θ t] ( X_0 - μ) w/ μ=0 gives =X_0 * exp[ - θ t ]
208 | V[X_t] = σ^2/(2θ) ( 1 - exp(-2 θ t) ) w/ σ^2=ß=θ gives = 1 - exp(-2 ß t)
209 | :param model:
210 | :param weight:
211 | :param int_beta:
212 | :param data:
213 | :param t:
214 | :param key:
215 | :return:
216 | """
217 | mean = data * jnp.exp(-0.5 * int_beta(t)) # analytical mean of OU process
218 | var = jnp.maximum(
219 | 1 - jnp.exp(-int_beta(t)), 1e-5
220 | ) # analytical variance of OU process
221 | std = jnp.sqrt(var)
222 | noise = jr.normal(key, data.shape)
223 | y = mean + std * noise
224 | pred = model(t, y)
225 | return weight(t) * jnp.mean((pred + noise / std) ** 2) # loss
226 |
227 | @staticmethod
228 | def batch_loss_fn(model, weight, int_beta, data, t1, key):
229 | batch_size = data.shape[0]
230 | tkey, losskey = jr.split(key)
231 | losskey = jr.split(losskey, batch_size)
232 | """
233 | Low-discrepancy sampling over t to reduce variance
234 | by sampling very evenly by sampling uniformly and independently from (t1-t0)/batch_size bins
235 | t = [U(0,1), U(1,2), U(2,3), ...]
236 | """
237 | t = jr.uniform(tkey, (batch_size,), minval=0, maxval=t1 / batch_size)
238 | t = t + (t1 / batch_size) * jnp.arange(batch_size)
239 | """ Fixing the first three arguments of single_loss_fn, leaving data, t and key as input """
240 | loss_fn = ft.partial(JaxLightning.single_loss_fn, model, weight, int_beta)
241 | loss_fn = jax.vmap(loss_fn)
242 | return jnp.mean(loss_fn(data, t, losskey))
243 |
244 | @staticmethod
245 | @eqx.filter_jit
246 | def make_step(model, weight, int_beta, data, t1, key, opt_state, opt_update):
247 | loss_fn = eqx.filter_value_and_grad(JaxLightning.batch_loss_fn)
248 | loss, grads = loss_fn(model, weight, int_beta, data, t1, key)
249 | updates, opt_state = opt_update(grads, opt_state)
250 | model = eqx.apply_updates(model, updates)
251 | key = jr.split(key, 1)[0]
252 | return loss, model, key, opt_state
253 |
254 |
255 | # hparams = argparse.ArgumentParser()
256 | # logging = [0, 1, 1 if torch.cuda.is_available() else 0][-1]
257 | # hparams = JaxLightning.args(
258 | # hparams,
259 | # logging=["disabled", "online"][logging],
260 | # project="jax_score_based",
261 | # load_last_checkpoint=True,
262 | # show=True,
263 | # seed=[12345, 2345, 98765][1],
264 | # fast_dev_run=0,
265 | # save_figs=[False, True][0],
266 | # watch_model_training=[False, True][0],
267 | # )
268 |
269 | # temp_args, _ = hparams.parse_known_args()
270 |
271 |
272 | @hydra.main(
273 | version_base=None,
274 | config_name="config",
275 | config_path="configs",
276 | )
277 | def main(cfg: DictConfig) -> None:
278 | print(cfg)
279 |
280 | hparams = process_hparams(cfg, print_hparams=True)
281 |
282 | logger = WandbLogger(
283 | project=hparams.project, name=hparams.experiment, mode=hparams.logging
284 | )
285 |
286 | dm = MNISTDataModule()
287 | dm.prepare_data()
288 | dm.setup()
289 |
290 | # exit()
291 |
292 | scorebased = JaxLightning(**hparams)
293 |
294 | trainer = Trainer(
295 | max_steps=1_000_000,
296 | accelerator="cpu",
297 | logger=logger,
298 | check_val_every_n_epoch=10,
299 | )
300 | trainer.fit(scorebased, dm)
301 |
302 |
303 | if __name__ == "__main__":
304 | main()
305 |
--------------------------------------------------------------------------------
/ScoreBasedGenerativeModelling/src/ScoreBased_Data.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | from pathlib import Path
3 | import numpy as np
4 | import pytorch_lightning as pl
5 | import torchvision
6 | import torchvision.datasets
7 | from torch.utils.data import DataLoader
8 |
9 | filepath = Path().resolve()
10 | phd_path = Path(str(filepath).split("PhD")[0] + "PhD")
11 | jax_path = Path(str(filepath).split("Jax")[0] + "Jax")
12 | score_path = Path(str(filepath).split("Jax")[0] + "Jax/ScoreBasedGenerativeModelling")
13 | sys.path.append(str(phd_path))
14 |
15 |
16 | def numpy_collate(batch):
17 | if isinstance(batch[0], np.ndarray):
18 | return np.stack(batch)
19 | elif isinstance(batch[0], (tuple, list)):
20 | transposed = zip(*batch)
21 | return [numpy_collate(samples) for samples in transposed]
22 | else:
23 | return np.array(batch)
24 |
25 |
26 | class MNIST(torchvision.datasets.MNIST):
27 | """
28 | Numpy version of MNIST
29 | """
30 |
31 | def __init__(self, **kwargs):
32 | super().__init__(**kwargs)
33 |
34 | def __getitem__(self, index: int):
35 | """
36 | Args:
37 | index (int): Index
38 |
39 | Returns:
40 | tuple: (image, target) where target is index of the target class.
41 | """
42 |
43 | img, target = self.data[index], int(self.targets[index])
44 |
45 | img = img.numpy()
46 |
47 | if self.transform is not None:
48 | img = self.transform(img)
49 |
50 | if self.target_transform is not None:
51 | target = self.target_transform(target)
52 |
53 | return img, target
54 |
55 |
56 | class MNISTDataModule(pl.LightningDataModule):
57 | def __init__(self):
58 | super().__init__()
59 | self.transform = torchvision.transforms.Compose(
60 | [
61 | torchvision.transforms.Lambda(lambda x: x / 255),
62 | torchvision.transforms.Lambda(lambda x: x.reshape(1, 28, 28)),
63 | ]
64 | )
65 |
66 | def prepare_data(self) -> None:
67 | if not os.path.isdir(path := score_path / "data"):
68 | torchvision.datasets.MNIST(root=path, download=True)
69 |
70 | def setup(self, stage=None):
71 | data_path = score_path / "data"
72 |
73 | self.train_data = MNIST(root=data_path, train=True, transform=self.transform)
74 | self.val_data = MNIST(root=data_path, train=False, transform=self.transform)
75 |
76 | def train_dataloader(self):
77 | return DataLoader(
78 | self.train_data, batch_size=64, shuffle=True, collate_fn=numpy_collate
79 | )
80 |
81 | def val_dataloader(self):
82 | return DataLoader(
83 | self.val_data, batch_size=64, shuffle=True, collate_fn=numpy_collate
84 | )
85 |
--------------------------------------------------------------------------------
/ScoreBasedGenerativeModelling/src/ScoreBased_Hyperparameters.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import hashlib
4 | import numbers
5 | import math
6 | import os
7 | import sys
8 | import time
9 | from pathlib import Path
10 | from omegaconf import OmegaConf
11 |
12 | from pytorch_lightning.loggers import WandbLogger
13 |
14 | import wandb
15 |
16 |
17 | def str2bool(v):
18 |     if isinstance(v, bool):
19 |         return v
20 |     elif isinstance(v, str):
21 |         if v.lower() in ("yes", "true", "t", "y", "1"):
22 |             return True
23 |         elif v.lower() in ("no", "false", "f", "n", "0"):
24 |             return False
25 |         else:
26 |             raise argparse.ArgumentTypeError(f"Invalid value: {v}")
27 |     elif isinstance(v, numbers.Number):
28 |         assert v in [0, 1]
29 |         return bool(v)
30 |     else:
31 |         raise argparse.ArgumentTypeError(f"Invalid value: {type(v)}")
33 |
34 |
35 | def process_hparams(hparams, print_hparams=True):
36 | # hparams = hparams.parse_args()
37 |
38 | if hparams.logging == "online":
39 | hparams.show = False
40 |
41 | """Create HParam ID for saving and loading checkpoints"""
42 | hashable_config = copy.deepcopy(hparams)
43 | id = hashlib.sha1(
44 | repr(sorted(hashable_config.__dict__.items())).encode()
45 | ).hexdigest() # int -> abs(int) -> str(abs(int)))
46 | hparams.hparams_id = id
47 |
48 | if print_hparams:
49 | print(OmegaConf.to_container(hparams, resolve=True))
50 | # [print(f"\t {key}: {value}") for key, value in sorted(hparams.__dict__.items())]
51 |
52 | return hparams
53 |
--------------------------------------------------------------------------------
/ScoreBasedGenerativeModelling/src/ScoreBased_Models.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import equinox as eqx
3 | import jax
4 | from flax import linen as fnn
5 | from jax import numpy as jnp, random as jr
6 |
7 |
8 | class Encoder(fnn.Module):
9 | features: int = 64
10 | training: bool = True
11 |
12 | @fnn.compact
13 | def __call__(self, x):
14 | z1 = fnn.Conv(self.features, kernel_size=(3, 3))(x)
15 | z1 = fnn.relu(z1)
16 | z1 = fnn.Conv(self.features, kernel_size=(3, 3))(z1)
17 | z1 = fnn.BatchNorm(use_running_average=not self.training)(z1)
18 | z1 = fnn.relu(z1)
19 | z1_pool = fnn.max_pool(z1, window_shape=(2, 2), strides=(2, 2))
20 |
21 | z2 = fnn.Conv(self.features * 2, kernel_size=(3, 3))(z1_pool)
22 | z2 = fnn.relu(z2)
23 | z2 = fnn.Conv(self.features * 2, kernel_size=(3, 3))(z2)
24 | z2 = fnn.BatchNorm(use_running_average=not self.training)(z2)
25 | z2 = fnn.relu(z2)
26 | z2_pool = fnn.max_pool(z2, window_shape=(2, 2), strides=(2, 2))
27 |
28 | z3 = fnn.Conv(self.features * 4, kernel_size=(3, 3))(z2_pool)
29 | z3 = fnn.relu(z3)
30 | z3 = fnn.Conv(self.features * 4, kernel_size=(3, 3))(z3)
31 | z3 = fnn.BatchNorm(use_running_average=not self.training)(z3)
32 | z3 = fnn.relu(z3)
33 | z3_pool = fnn.max_pool(z3, window_shape=(2, 2), strides=(2, 2))
34 |
35 | z4 = fnn.Conv(self.features * 8, kernel_size=(3, 3))(z3_pool)
36 | z4 = fnn.relu(z4)
37 | z4 = fnn.Conv(self.features * 8, kernel_size=(3, 3))(z4)
38 | z4 = fnn.BatchNorm(use_running_average=not self.training)(z4)
39 | z4 = fnn.relu(z4)
40 | z4_dropout = fnn.Dropout(0.5, deterministic=False)(z4)
41 | z4_pool = fnn.max_pool(z4_dropout, window_shape=(2, 2), strides=(2, 2))
42 |
43 | z5 = fnn.Conv(self.features * 16, kernel_size=(3, 3))(z4_pool)
44 | z5 = fnn.relu(z5)
45 | z5 = fnn.Conv(self.features * 16, kernel_size=(3, 3))(z5)
46 | z5 = fnn.BatchNorm(use_running_average=not self.training)(z5)
47 | z5 = fnn.relu(z5)
48 | z5_dropout = fnn.Dropout(0.5, deterministic=False)(z5)
49 |
50 | return z1, z2, z3, z4_dropout, z5_dropout
51 |
52 |
53 | class Decoder(fnn.Module):
54 | features: int = 64
55 | training: bool = True
56 |
57 | @fnn.compact
58 | def __call__(self, z1, z2, z3, z4_dropout, z5_dropout):
59 | z6_up = jax.image.resize(z5_dropout,
60 | shape=(z5_dropout.shape[0], z5_dropout.shape[1] * 2, z5_dropout.shape[2] * 2,
61 | z5_dropout.shape[3]),
62 | method='nearest')
63 | z6 = fnn.Conv(self.features * 8, kernel_size=(2, 2))(z6_up)
64 | z6 = fnn.relu(z6)
65 | z6 = jnp.concatenate([z4_dropout, z6], axis=3)
66 | z6 = fnn.Conv(self.features * 8, kernel_size=(3, 3))(z6)
67 | z6 = fnn.relu(z6)
68 | z6 = fnn.Conv(self.features * 8, kernel_size=(3, 3))(z6)
69 | z6 = fnn.BatchNorm(use_running_average=not self.training)(z6)
70 | z6 = fnn.relu(z6)
71 |
72 | z7_up = jax.image.resize(z6,
73 | shape=(z6.shape[0], z6.shape[1] * 2, z6.shape[2] * 2, z6.shape[3]),
74 | method='nearest')
75 | z7 = fnn.Conv(self.features * 4, kernel_size=(2, 2))(z7_up)
76 | z7 = fnn.relu(z7)
77 | z7 = jnp.concatenate([z3, z7], axis=3)
78 | z7 = fnn.Conv(self.features * 4, kernel_size=(3, 3))(z7)
79 | z7 = fnn.relu(z7)
80 | z7 = fnn.Conv(self.features * 4, kernel_size=(3, 3))(z7)
81 | z7 = fnn.BatchNorm(use_running_average=not self.training)(z7)
82 | z7 = fnn.relu(z7)
83 |
84 | z8_up = jax.image.resize(z7,
85 | shape=(z7.shape[0], z7.shape[1] * 2, z7.shape[2] * 2, z7.shape[3]),
86 | method='nearest')
87 | z8 = fnn.Conv(self.features * 2, kernel_size=(2, 2))(z8_up)
88 | z8 = fnn.relu(z8)
89 | z8 = jnp.concatenate([z2, z8], axis=3)
90 | z8 = fnn.Conv(self.features * 2, kernel_size=(3, 3))(z8)
91 | z8 = fnn.relu(z8)
92 | z8 = fnn.Conv(self.features * 2, kernel_size=(3, 3))(z8)
93 | z8 = fnn.BatchNorm(use_running_average=not self.training)(z8)
94 | z8 = fnn.relu(z8)
95 |
96 | z9_up = jax.image.resize(z8,
97 | shape=(z8.shape[0], z8.shape[1] * 2, z8.shape[2] * 2, z8.shape[3]),
98 | method='nearest')
99 | z9 = fnn.Conv(self.features, kernel_size=(2, 2))(z9_up)
100 | z9 = fnn.relu(z9)
101 | z9 = jnp.concatenate([z1, z9], axis=3)
102 | z9 = fnn.Conv(self.features, kernel_size=(3, 3))(z9)
103 | z9 = fnn.relu(z9)
104 | z9 = fnn.Conv(self.features, kernel_size=(3, 3))(z9)
105 | z9 = fnn.BatchNorm(use_running_average=not self.training)(z9)
106 | z9 = fnn.relu(z9)
107 |
108 | y = fnn.Conv(1, kernel_size=(1, 1))(z9)
109 | y = fnn.sigmoid(y)
110 |
111 | return y
112 |
113 |
114 | class UNet(fnn.Module):
115 | features: int = 64
116 | training: bool = True
117 |
118 | @fnn.compact
119 | def __call__(self, x):
120 |         z1, z2, z3, z4_dropout, z5_dropout = Encoder(features=self.features, training=self.training)(x)
121 |         y = Decoder(features=self.features, training=self.training)(z1, z2, z3, z4_dropout, z5_dropout)
122 |
123 | return y
124 |
125 |
126 | class MixerBlock(eqx.Module):
127 | patch_mixer: eqx.nn.MLP
128 | hidden_mixer: eqx.nn.MLP
129 | norm1: eqx.nn.LayerNorm
130 | norm2: eqx.nn.LayerNorm
131 |
132 | def __init__(self, num_patches, hidden_size, mix_patch_size, mix_hidden_size, *, key):
133 | tkey, ckey = jr.split(key, 2)
134 | self.patch_mixer = eqx.nn.MLP(num_patches, num_patches, mix_patch_size, depth=1, key=tkey)
135 | self.hidden_mixer = eqx.nn.MLP(hidden_size, hidden_size, mix_hidden_size, depth=1, key=ckey)
136 | self.norm1 = eqx.nn.LayerNorm((hidden_size, num_patches))
137 | self.norm2 = eqx.nn.LayerNorm((num_patches, hidden_size))
138 |
139 | def __call__(self, y):
140 | y = y + jax.vmap(self.patch_mixer)(self.norm1(y))
141 | y = einops.rearrange(y, "c p -> p c")
142 | y = y + jax.vmap(self.hidden_mixer)(self.norm2(y))
143 | y = einops.rearrange(y, "p c -> c p")
144 | return y
145 |
146 |
147 | class Mixer2d(eqx.Module):
148 | conv_in: eqx.nn.Conv2d
149 | conv_out: eqx.nn.ConvTranspose2d
150 | blocks: list
151 | norm: eqx.nn.LayerNorm
152 | t1: float
153 |
154 | def __init__(self, img_size, patch_size, hidden_size, mix_patch_size, mix_hidden_size, num_blocks, t1, *, key, ):
155 | input_size, height, width = img_size
156 | assert (height % patch_size) == 0
157 | assert (width % patch_size) == 0
158 | num_patches = (height // patch_size) * (width // patch_size)
159 | inkey, outkey, *bkeys = jr.split(key, 2 + num_blocks)
160 |
161 | self.conv_in = eqx.nn.Conv2d(input_size + 1, hidden_size, patch_size, stride=patch_size, key=inkey)
162 | self.conv_out = eqx.nn.ConvTranspose2d(hidden_size, input_size, patch_size, stride=patch_size, key=outkey)
163 |         self.blocks = [MixerBlock(num_patches, hidden_size, mix_patch_size, mix_hidden_size, key=bkey)
164 |                        for bkey in bkeys]
165 | self.norm = eqx.nn.LayerNorm((hidden_size, num_patches))
166 | self.t1 = t1
167 |
168 | def __call__(self, t, y):
169 | t = t / self.t1
170 | _, height, width = y.shape
171 | t = einops.repeat(t, "-> 1 h w", h=height, w=width)
172 | y = jnp.concatenate([y, t])
173 | y = self.conv_in(y)
174 | _, patch_height, patch_width = y.shape
175 | y = einops.rearrange(y, "c h w -> c (h w)")
176 | for block in self.blocks:
177 | y = block(y)
178 | y = self.norm(y)
179 | y = einops.rearrange(y, "c (h w) -> c h w", h=patch_height, w=patch_width)
180 | return self.conv_out(y)
181 |
--------------------------------------------------------------------------------
/ScoreBasedGenerativeModelling/submit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -J train
3 | #SBATCH --nodes=1
4 | #SBATCH --ntasks-per-node=1
5 | #SBATCH --gres=gpu:1 # per node
6 | #SBATCH -p gpu2
7 | #SBATCH --cpus-per-task=8
8 | #SBATCH --output=/homefs/home/winklep4/Work/JaxLightning/pcluster_logs/%A/%A.%a-output.txt
9 | #SBATCH --error=/homefs/home/winklep4/Work/JaxLightning/pcluster_logs/%A/%A.%a-error.txt
10 | ### SBATCH --time 1-00:00:00
11 | #SBATCH --signal=SIGUSR1@90
12 |
13 | . ~/.bashrc
14 | nvidia-smi
15 |
16 | eval "$(micromamba shell hook --shell bash)"
17 | micromamba activate jax
18 | export WANDB__SERVICE_WAIT=300
19 | export HYDRA_FULL_ERROR=1
20 |
21 | wandb login --host https://genentech.wandb.io
22 | wandb login --relogin
23 | wandb artifact cache cleanup 50G
24 |
25 | echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
26 | echo "${PWD}"
27 |
28 | ### Set the wandb agent command -------------------------------
29 | # Set the wandb agent command
30 | WANDB_AGENT_CMD="
31 | wandb agent ludwig-winkler/protein-correction_testing/3hngzz7v
32 | "
33 |
34 | # Check if an argument is provided
35 | if [ -n "$1" ]; then
36 | WANDB_AGENT_CMD="$WANDB_AGENT_CMD --count $1"
37 | fi
38 |
39 | # Run the wandb agent command
40 | $WANDB_AGENT_CMD
41 |
42 | ### Set the wandb agent command -------------------------------
43 |
44 | python main.py
--------------------------------------------------------------------------------
/assets/InANutshell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/InANutshell.png
--------------------------------------------------------------------------------
/assets/automaticoptimization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/automaticoptimization.png
--------------------------------------------------------------------------------
/assets/code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/code.png
--------------------------------------------------------------------------------
/assets/dataloader.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/dataloader.png
--------------------------------------------------------------------------------
/assets/now_kiss.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ludwigwinkler/JaxLightning/19fdd24822e252ea13e8b44546a0243aa38d305a/assets/now_kiss.jpeg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | aiohappyeyeballs==2.4.3
3 | aiohttp==3.10.10
4 | aiosignal==1.3.1
5 | antlr4-python3-runtime==4.9.3
6 | attrs==24.2.0
7 | certifi==2024.8.30
8 | charset-normalizer==3.4.0
9 | chex==0.1.87
10 | click==8.1.7
11 | contourpy==1.3.0
12 | cycler==0.12.1
13 | diffrax==0.6.0
14 | docker-pycreds==0.4.0
15 | einops==0.8.0
16 | equinox==0.11.8
17 | etils==1.10.0
18 | filelock==3.16.1
19 | flax==0.10.1
20 | fonttools==4.54.1
21 | frozenlist==1.5.0
22 | fsspec==2024.10.0
23 | gitdb==4.0.11
24 | GitPython==3.1.43
25 | humanize==4.11.0
26 | hydra-core==1.3.2
27 | idna==3.10
28 | importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1726082825846/work
29 | importlib_resources==6.4.5
30 | iniconfig==2.0.0
31 | jax @ file:///home/conda/feedstock_root/build_artifacts/jax_1729336343576/work
32 | jaxlib==0.4.34
33 | jaxtyping==0.2.34
34 | Jinja2==3.1.4
35 | kiwisolver==1.4.7
36 | lightning==2.4.0
37 | lightning-utilities==0.11.8
38 | lineax==0.0.7
39 | markdown-it-py==3.0.0
40 | MarkupSafe==3.0.2
41 | matplotlib==3.9.2
42 | mdurl==0.1.2
43 | ml-dtypes @ file:///home/conda/feedstock_root/build_artifacts/ml_dtypes_1726376268746/work
44 | mpmath==1.3.0
45 | msgpack==1.1.0
46 | multidict==6.1.0
47 | nest-asyncio==1.6.0
48 | networkx==3.4.2
49 | numpy==1.26.4
50 | nvidia-cublas-cu12==12.4.5.8
51 | nvidia-cuda-cupti-cu12==12.4.127
52 | nvidia-cuda-nvrtc-cu12==12.4.127
53 | nvidia-cuda-runtime-cu12==12.4.127
54 | nvidia-cudnn-cu12==9.1.0.70
55 | nvidia-cufft-cu12==11.2.1.3
56 | nvidia-curand-cu12==10.3.5.147
57 | nvidia-cusolver-cu12==11.6.1.9
58 | nvidia-cusparse-cu12==12.3.1.170
59 | nvidia-nccl-cu12==2.21.5
60 | nvidia-nvjitlink-cu12==12.4.127
61 | nvidia-nvtx-cu12==12.4.127
62 | omegaconf==2.3.0
63 | opt_einsum @ file:///home/conda/feedstock_root/build_artifacts/opt_einsum_1727392354687/work
64 | optax==0.2.3
65 | optimistix==0.0.9
66 | orbax-checkpoint==0.8.0
67 | packaging==24.1
68 | pillow==11.0.0
69 | platformdirs==4.3.6
70 | propcache==0.2.0
71 | protobuf==5.28.3
72 | psutil==6.1.0
73 | Pygments==2.18.0
74 | pyparsing==3.2.0
75 | python-dateutil==2.9.0.post0
76 | pytorch-lightning==2.4.0
77 | PyYAML==6.0.2
78 | requests==2.32.3
79 | rich==13.9.4
80 | scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy-split_1729480690820/work/dist/scipy-1.14.1-cp311-cp311-linux_x86_64.whl#sha256=1d6b83e78fb8e82dc54a9de4e73aed118775fa22ccdc582469171f891aa3e912
81 | sentry-sdk==2.17.0
82 | setproctitle==1.3.3
83 | six==1.16.0
84 | smmap==5.0.1
85 | sympy==1.13.1
86 | tensorstore==0.1.67
87 | toolz==1.0.0
88 | torch==2.5.1
89 | torchmetrics==1.5.1
90 | torchvision==0.20.1
91 | tqdm==4.66.6
92 | triton==3.1.0
93 | typeguard==2.13.3
94 | typing_extensions==4.12.2
95 | urllib3==2.2.3
96 | wandb==0.18.5
97 | yarl==1.17.1
98 | zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1726248574750/work
99 |
--------------------------------------------------------------------------------