├── .gitignore ├── .pre-commit-config.yaml ├── .python-version ├── LICENSE ├── README.md ├── experiments ├── __init__.py ├── batch_normal.py ├── batch_sizes.py ├── beautiful_mnist.py ├── cl_augment.py ├── cl_balance_labels.py ├── cl_digit_exp.py ├── cl_exp.py ├── cl_fashion_exp.py ├── continual_learning.py ├── continual_learning_fashion.py ├── long_run.py ├── low_precision.py ├── lr.py ├── lr_scale.py ├── lr_vendor16.py ├── market_depth.py ├── mnist_backprop_cmp.py ├── resnet18.py ├── scaling.py ├── unit_vector_mode.py └── utils.py ├── marketplace ├── __init__.py ├── continual_learning.py ├── nn.py ├── optimizers.py ├── random.py ├── training.py └── utils.py ├── plot ├── __init__.py ├── compare_with_all_at_once.py ├── plot_3d.py └── plot_continual_learning_fashion.py ├── pyproject.toml ├── tests ├── __init__.py ├── test_optimizers.py ├── test_random.py └── test_training.py └── uv.lock /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | #pdm.lock 116 | #pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 122 | #pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 125 | .pixi 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .envrc 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | #.idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data 200 | # refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | 204 | # Marimo 205 | marimo/_static/ 206 | marimo/_lsp/ 207 | __marimo__/ 208 | .idea 209 | *.safetensors 210 | runs 211 | mlruns 212 | mnist 213 | fashion_replay 214 | *.jsonl 215 | fashion_reply_article 216 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.11.6 5 | hooks: 6 | # Run the formatter. 7 | - id: ruff-format 8 | - repo: https://github.com/asottile/reorder_python_imports 9 | rev: v3.10.0 10 | hooks: 11 | - id: reorder-python-imports 12 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Launch Platform 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Marketplace 2 | Marketplace is a machine learning experiment aimed at training a model efficiently on a GPU without using backpropagation. 3 | The approach involves breaking down the layers of a machine learning model into smaller groups, running them with various parameter combinations. 4 | We select the best-performing parameter combination and mutate it with different parameter variants. 5 | To learn more about the concept, please refer to the articles: 6 | 7 | 1. [Marketplace: my first attempt at training without backprop on GPU efficiently](https://fangpenlin.com/posts/2025/08/18/marketplace-my-first-attempt-at-training-without-backprop-on-gpu-efficiently/) 8 | 2. [Marketplace V2 is all you need: A training algorithm on par with backprop that needs only forward pass](https://fangpenlin.com/posts/2025/09/02/marketplace-v2-is-all-you-need-a-training-algorithm-on-par-with-backprop/) 9 | 3. 
[Continual learning with the Marketplace algorithm: model learns new data through inference, not training](https://fangpenlin.com/posts/2025/09/09/continual-learning-with-marketplace-model-learns-new-data-with-mostly-inference/) 10 | 11 | For example, the [beautiful_mnist model](https://github.com/tinygrad/tinygrad/blob/c30a113b2a876cabaea1049601fea3a0b758c5b1/examples/beautiful_mnist.py) included in [Tinygrad](https://github.com/tinygrad/tinygrad)'s example folder can be broken down into three groups of layers: 12 | 13 | ```python 14 | from marketplace.training import Spec 15 | from marketplace.nn import Model 16 | from tinygrad import Tensor 17 | from tinygrad.nn import Conv2d 18 | from tinygrad.nn import InstanceNorm 19 | from tinygrad.nn import Linear 20 | 21 | [ 22 | Spec( 23 | model=Model( 24 | Conv2d(vendor_count, 1, 32, 5), 25 | Tensor.relu, 26 | Conv2d(vendor_count, 32, 32, 5), 27 | Tensor.relu, 28 | InstanceNorm(vendor_count, 32), 29 | Tensor.max_pool2d, 30 | ) 31 | ), 32 | Spec( 33 | model=Model( 34 | Conv2d(vendor_count, 32, 64, 3), 35 | Tensor.relu, 36 | Conv2d(vendor_count, 64, 64, 3), 37 | Tensor.relu, 38 | InstanceNorm(vendor_count, 64), 39 | Tensor.max_pool2d, 40 | lambda x: x.flatten(1), 41 | ), 42 | ), 43 | Spec( 44 | model=Model([Linear(vendor_count, 576, 10)]), 45 | ), 46 | ] 47 | ``` 48 | 49 | With that, we can run the model on the GPU with different combinations of parameters. 50 | The following code is a simple example of how to run a forward pass of the model. 51 | 52 | ```python 53 | from tinygrad import Tensor 54 | from tinygrad import TinyJit 55 | from marketplace.training import forward 56 | from marketplace.optimizers import Optimizer 57 | 58 | @TinyJit 59 | def forward_step() -> tuple[Tensor, Tensor, Tensor]: 60 | samples = Tensor.randint(batch_size, high=X_train.shape[0]) 61 | x = X_train[samples] 62 | y = Y_train[samples] 63 | batch_logits, batch_paths = forward(marketplace, x) 64 | loss = Tensor.stack( 65 | *(logits.sparse_categorical_crossentropy(y) for logits in batch_logits), 66 | dim=0, 67 | ) 68 | best_loss, best_index = loss.topk(1, largest=False) 69 | best_index = best_index.squeeze(0) 70 | accuracy = ( 71 | (batch_logits[best_index].sigmoid().argmax(axis=1) == y).sum() / batch_size 72 | ) * 100 73 | return ( 74 | best_loss.realize(), 75 | accuracy.realize(), 76 | batch_paths[best_index].realize(), 77 | ) 78 | 79 | lr = Tensor(1e-1).contiguous().realize() 80 | optimizer = Optimizer( 81 | marketplace=marketplace, 82 | learning_rate=lr, 83 | ) 84 | best_loss, best_accuracy, best_path = forward_step() 85 | 86 | ``` 87 | 88 | Now that we know the best parameter combination, we can mutate it with different parameter variants. 89 | 90 | TODO: the following is outdated and needs to be updated 91 | 92 | ```python 93 | @TinyJit 94 | def mutate_step(best_path: Tensor): 95 | mutate( 96 | marketplace=marketplace, 97 | leading_path=best_path, 98 | jitter=lr, 99 | ) 100 | 101 | mutate_step(best_path) 102 | 103 | ``` 104 | 105 | That's it. 106 | We just trained a model without using backpropagation, relying only on the forward pass! 107 | By repeating this process, we can train the model. 108 | Of course, this is still no match for backprop training, but it's an interesting start. 109 |
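Putting the two steps together, the whole training procedure is just this probe-and-mutate cycle repeated over and over. The loop below is a minimal sketch rather than code from this repository: it assumes the `forward_step`, `mutate_step`, and `lr` definitions above (including the outdated `mutate_step` snippet flagged in the TODO), and the step count and decay rate are arbitrary placeholders.

```python
for step in range(10_000):
    # Probe the marketplace and pick the best-performing parameter combination.
    best_loss, best_accuracy, best_path = forward_step()
    # Mutate the winning combination to produce the next generation of candidates.
    mutate_step(best_path)
    # Optionally decay the learning rate, which also drives the mutation jitter.
    lr.assign(lr * (1 - 1e-4)).realize()
    if step % 100 == 99:
        print(f"step {step}: loss={best_loss.item():.4f}, acc={best_accuracy.item():.2f}%")
```

110 | ## Experiments 111 | 112 | All of the experiments are in the `experiments` folder.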
113 | To run the training, you can use the following command: 114 | 115 | ```bash 116 | CUDA=1 uv run python -m experiments.beautiful_mnist 117 | ``` 118 | 119 | It comes with some arguments to control the training, you can see them by running: 120 | 121 | ```bash 122 | uv run python -m experiments.beautiful_mnist --help 123 | ``` 124 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaunchPlatform/marketplace/6487c3ea4d5c2df208bc9ce10a627cef5e8eeace/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/batch_normal.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import mlflow 4 | 5 | from .beautiful_mnist import make_marketplace 6 | from .beautiful_mnist import train 7 | from .utils import ensure_experiment 8 | from marketplace.multi_nn import MultiBatchNorm 9 | from marketplace.multi_nn import MultiInstanceNorm 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def main(): 15 | exp_id = ensure_experiment("Batch Normal vs Instance Normal") 16 | for batch_normal in [True, False]: 17 | with mlflow.start_run( 18 | run_name="batch-normal" if batch_normal else "instance-normal", 19 | experiment_id=exp_id, 20 | log_system_metrics=True, 21 | ): 22 | marketplace = make_marketplace( 23 | norm_cls=MultiBatchNorm if batch_normal else MultiInstanceNorm 24 | ) 25 | train( 26 | step_count=10_000, 27 | batch_size=512, 28 | initial_forward_pass=1, 29 | initial_lr=1e-3, 30 | lr_decay_rate=1e-4, 31 | marketplace=marketplace, 32 | ) 33 | 34 | 35 | if __name__ == "__main__": 36 | logging.basicConfig(level=logging.INFO) 37 | main() 38 | -------------------------------------------------------------------------------- /experiments/batch_sizes.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import mlflow 4 | 5 | from .beautiful_mnist import make_marketplace 6 | from .beautiful_mnist import train 7 | from .utils import ensure_experiment 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | PYRAMID32_HALF_UPSTREAM_STRUCTURE = [ 12 | # layer 0 13 | (2, 0), 14 | # layer 1 15 | (4, 2), 16 | # layer 2 (N/A) 17 | (4, 4), 18 | # layer 3 19 | (8, 4), 20 | # layer 4 21 | (16, 8), 22 | # layer 5 (N/A) 23 | (16, 16), 24 | # layer 6 25 | (32, 16), 26 | ] 27 | 28 | 29 | def main(): 30 | exp_id = ensure_experiment("Batch Size V3 (Sticky Leader)") 31 | for batch_size, forward_pass in [ 32 | (32, 1), 33 | (32, 2), 34 | (32, 4), 35 | (32, 8), 36 | (32, 16), 37 | (64, 1), 38 | (64, 2), 39 | (64, 4), 40 | (64, 8), 41 | (128, 1), 42 | (128, 2), 43 | (128, 4), 44 | (256, 1), 45 | (256, 2), 46 | (512, 1), 47 | ]: 48 | with mlflow.start_run( 49 | run_name=f"batch-size-{batch_size}-fw-{forward_pass}", 50 | experiment_id=exp_id, 51 | log_system_metrics=True, 52 | ): 53 | marketplace = make_marketplace(PYRAMID32_HALF_UPSTREAM_STRUCTURE) 54 | train( 55 | step_count=10_000, 56 | batch_size=batch_size, 57 | initial_forward_pass=forward_pass, 58 | initial_lr=1e-3, 59 | lr_decay_rate=1e-4, 60 | sticky_leaders=True, 61 | marketplace=marketplace, 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | logging.basicConfig(level=logging.INFO) 67 | main() 68 | -------------------------------------------------------------------------------- /experiments/beautiful_mnist.py: 
-------------------------------------------------------------------------------- 1 | # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 2 | import functools 3 | import logging 4 | import pathlib 5 | import sys 6 | import time 7 | import typing 8 | 9 | import click 10 | import mlflow 11 | from tinygrad import GlobalCounters 12 | from tinygrad import Tensor 13 | from tinygrad import TinyJit 14 | from tinygrad.helpers import getenv 15 | from tinygrad.helpers import trange 16 | from tinygrad.nn import Conv2d 17 | from tinygrad.nn import InstanceNorm 18 | from tinygrad.nn import Linear 19 | from tinygrad.nn.datasets import mnist 20 | 21 | from .utils import ensure_experiment 22 | from .utils import filter_classes 23 | from marketplace.nn import Model 24 | from marketplace.optimizers import Optimizer 25 | from marketplace.optimizers import UnitVectorMode 26 | from marketplace.training import forward 27 | from marketplace.training import Spec 28 | from marketplace.training import straight_forward 29 | from marketplace.utils import write_checkpoint 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | @functools.cache 35 | def load_data(): 36 | return mnist(fashion=getenv("FASHION")) 37 | 38 | 39 | def make_marketplace( 40 | structure: list[tuple[int, int]] | None = None, 41 | default_vendor_count: int = 8, 42 | ): 43 | if structure is None: 44 | structure = [ 45 | # layer 0 46 | (default_vendor_count, 0), 47 | # layer 1 48 | (default_vendor_count, 0), 49 | # layer 2 50 | (default_vendor_count, 0), 51 | ] 52 | return [ 53 | Spec( 54 | model=Model( 55 | Conv2d(1, 32, 5), 56 | Tensor.relu, 57 | Conv2d(32, 32, 5), 58 | Tensor.relu, 59 | InstanceNorm(32), 60 | Tensor.max_pool2d, 61 | ), 62 | vendor_count=structure[0][0], 63 | ), 64 | Spec( 65 | model=Model( 66 | Conv2d(32, 64, 3), 67 | Tensor.relu, 68 | Conv2d(64, 64, 3), 69 | Tensor.relu, 70 | InstanceNorm(64), 71 | Tensor.max_pool2d, 72 | lambda x: x.flatten(1), 73 | ), 74 | vendor_count=structure[1][0], 75 | upstream_sampling=structure[1][1], 76 | ), 77 | Spec( 78 | model=Model(Linear(576, 10)), 79 | vendor_count=structure[2][0], 80 | upstream_sampling=structure[2][1], 81 | ), 82 | ] 83 | 84 | 85 | def train( 86 | step_count: int, 87 | batch_size: int, 88 | initial_lr: float, 89 | lr_decay_rate: float, 90 | marketplace: list[Spec], 91 | meta_lr: float | None = None, 92 | probe_scale: float | None = None, 93 | marketplace_replica: int = 1, 94 | initial_forward_pass: int = 1, 95 | forward_pass_schedule: list[tuple[int, int]] | None = None, 96 | metrics_per_steps: int = 10, 97 | checkpoint_filepath: pathlib.Path | None = None, 98 | checkpoint_per_steps: int = 1000, 99 | only_classes: typing.Container[int] | None = None, 100 | exclude_missing_from_loss_func: bool = True, 101 | unit_vector_mode: UnitVectorMode = "per_spec", 102 | manual_seed: int | None = None, 103 | ): 104 | logger.info( 105 | "Running beautiful MNIST with step_count=%s, batch_size=%s, init_lr=%s, lr_decay=%s, meta_lr=%s, " 106 | "probe_scale=%s, marketplace_replica=%s, initial_forward_pass=%s, forward_pass_schedule=%s, " 107 | "metrics_per_steps=%s, checkpoint_filepath=%s, checkpoint_per_steps=%s, only_classes=%s, " 108 | "exclude_missing_from_loss_func=%s, unit_vector_mode=%s, manual_seed=%s", 109 | step_count, 110 | batch_size, 111 | initial_lr, 112 | lr_decay_rate, 113 | meta_lr, 114 | probe_scale, 115 | marketplace_replica, 116 | initial_forward_pass, 117 | metrics_per_steps, 118 | forward_pass_schedule, 119 
| checkpoint_filepath, 120 | checkpoint_per_steps, 121 | only_classes, 122 | exclude_missing_from_loss_func, 123 | unit_vector_mode, 124 | manual_seed, 125 | ) 126 | 127 | mlflow.log_param("step_count", step_count) 128 | mlflow.log_param("batch_size", batch_size) 129 | mlflow.log_param("marketplace_replica", marketplace_replica) 130 | mlflow.log_param("initial_forward_pass", initial_forward_pass) 131 | mlflow.log_param("lr", initial_lr) 132 | mlflow.log_param("lr_decay_rate", lr_decay_rate) 133 | mlflow.log_param("meta_lr", meta_lr) 134 | mlflow.log_param("probe_scale", probe_scale) 135 | mlflow.log_param("forward_pass_schedule", forward_pass_schedule) 136 | mlflow.log_param("metrics_per_steps", metrics_per_steps) 137 | mlflow.log_param("checkpoint_per_steps", checkpoint_per_steps) 138 | mlflow.log_param("only_classes", only_classes) 139 | mlflow.log_param("exclude_missing_from_loss_func", exclude_missing_from_loss_func) 140 | mlflow.log_param("unit_vector_mode", unit_vector_mode) 141 | mlflow.log_param("manual_seed", manual_seed) 142 | 143 | if manual_seed is not None: 144 | Tensor.manual_seed(manual_seed) 145 | 146 | X_train, Y_train, X_test, Y_test = load_data() 147 | 148 | if only_classes is not None: 149 | X_train, Y_train = filter_classes(X_train, Y_train, only=only_classes) 150 | X_test, Y_test = filter_classes(X_test, Y_test, only=only_classes) 151 | 152 | lr = Tensor(initial_lr).contiguous().realize() 153 | optimizer = Optimizer( 154 | marketplace=marketplace, 155 | learning_rate=lr, 156 | probe_scale=(Tensor(probe_scale) if probe_scale is not None else None), 157 | meta_learning_rate=(Tensor(meta_lr) if meta_lr is not None else None), 158 | ) 159 | loss_func = lambda x, y: x.sparse_categorical_crossentropy(y) 160 | if only_classes is not None and exclude_missing_from_loss_func: 161 | excluded_classes = frozenset(range(10)) - frozenset(only_classes) 162 | if len(excluded_classes) != 1: 163 | raise ValueError("Currently only support one excluded class") 164 | loss_func = lambda x, y: x.sparse_categorical_crossentropy( 165 | y, ignore_index=list(excluded_classes)[0] 166 | ) 167 | 168 | @TinyJit 169 | def forward_step(samples: Tensor) -> tuple[Tensor, Tensor, Tensor]: 170 | x = X_train[samples] 171 | y = Y_train[samples] 172 | batch_logits, batch_paths = forward( 173 | marketplace=marketplace, 174 | vendors=optimizer.vendors, 175 | x=x, 176 | ) 177 | loss = Tensor.stack( 178 | *(loss_func(logits, y) for logits in batch_logits), 179 | dim=0, 180 | ) 181 | accuracy = Tensor.stack( 182 | *( 183 | ((logits.argmax(axis=1) == y).sum() / batch_size) * 100 184 | for logits in batch_logits 185 | ), 186 | dim=0, 187 | ) 188 | return ( 189 | loss.realize(), 190 | accuracy.realize(), 191 | batch_paths.realize(), 192 | ) 193 | 194 | @TinyJit 195 | def compute_direction_vectors( 196 | loss: Tensor, paths: Tensor 197 | ) -> list[dict[str, Tensor]]: 198 | direction_vectors = optimizer.compute_direction_vectors( 199 | loss=loss, 200 | paths=paths, 201 | ) 202 | # TODO: optional 203 | Tensor.realize(*optimizer.schedule_lr_scale_update(direction_vectors)) 204 | return [ 205 | {key: params.realize() for key, params in delta.items()} 206 | for delta in direction_vectors 207 | ] 208 | 209 | @TinyJit 210 | def lr_scale_optimize_step( 211 | direction_vectors: list[dict[str, Tensor]] | None, learning_rates: Tensor | None 212 | ): 213 | Tensor.realize( 214 | *optimizer.schedule_weight_update( 215 | direction_delta=direction_vectors, learning_rates=learning_rates 216 | ) 217 | ) 218 | 
Tensor.realize(*optimizer.schedule_seeds_update()) 219 | Tensor.realize(*optimizer.schedule_delta_update()) 220 | 221 | @TinyJit 222 | def optimize_step( 223 | samples: Tensor, loss: Tensor, paths: Tensor 224 | ) -> tuple[Tensor, Tensor]: 225 | direction_vectors = optimizer.compute_direction_vectors( 226 | loss=loss, 227 | paths=paths, 228 | unit_vector_mode=unit_vector_mode, 229 | ) 230 | Tensor.realize( 231 | *optimizer.schedule_weight_update( 232 | direction_delta=direction_vectors, 233 | ) 234 | ) 235 | Tensor.realize(*optimizer.schedule_seeds_update()) 236 | Tensor.realize(*optimizer.schedule_delta_update()) 237 | 238 | # let's run forward pass again to see accuracy and loss 239 | x = X_train[samples] 240 | y = Y_train[samples] 241 | logits = straight_forward(marketplace, x) 242 | loss = loss_func(logits, y) 243 | accuracy = ((logits.argmax(axis=1) == y).sum() / batch_size) * 100 244 | return loss.realize(), accuracy.realize() 245 | 246 | @TinyJit 247 | def get_test_acc() -> Tensor: 248 | return ( 249 | straight_forward(marketplace, X_test).argmax(axis=1) == Y_test 250 | ).mean() * 100 251 | 252 | i = 0 253 | test_acc = float("nan") 254 | current_forward_pass = initial_forward_pass 255 | for i in (t := trange(step_count)): 256 | GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing 257 | 258 | if forward_pass_schedule is not None: 259 | for threshold, forward_pass in reversed(forward_pass_schedule): 260 | if i >= threshold: 261 | current_forward_pass = forward_pass 262 | break 263 | 264 | start_time = time.perf_counter() 265 | 266 | sample_batches = Tensor.randint( 267 | current_forward_pass, batch_size, high=X_train.shape[0] 268 | ).realize() 269 | 270 | # direction probing forward pass 271 | loss, _, paths = forward_step(sample_batches[0]) 272 | 273 | if meta_lr is not None: 274 | # lr scaling forward pass 275 | direction_vectors = compute_direction_vectors(loss=loss, paths=paths) 276 | loss, accuracy, paths = forward_step(sample_batches[0]) 277 | best_loss, best_index = loss.topk(1, largest=False) 278 | best_index = best_index.squeeze(0) 279 | best_accuracy = accuracy[best_index] 280 | best_path = paths[best_index] 281 | best_lr = optimizer.get_learning_rates(best_path) 282 | lr_scale_optimize_step(direction_vectors, best_lr) 283 | else: 284 | best_loss, best_accuracy = optimize_step( 285 | samples=sample_batches[0], loss=loss, paths=paths 286 | ) 287 | 288 | end_time = time.perf_counter() 289 | run_time = end_time - start_time 290 | if meta_lr is not None: 291 | optimizer.meta_learning_rate.assign( 292 | optimizer.meta_learning_rate * (1 - lr_decay_rate) 293 | ).realize() 294 | else: 295 | lr.assign(lr * (1 - lr_decay_rate)).realize() 296 | gflops = GlobalCounters.global_ops * 1e-9 / run_time 297 | 298 | if i % metrics_per_steps == (metrics_per_steps - 1): 299 | test_acc = get_test_acc().item() 300 | mlflow.log_metric("training/loss", best_loss.item(), step=i) 301 | mlflow.log_metric("training/accuracy", best_accuracy.item(), step=i) 302 | mlflow.log_metric("training/forward_pass", current_forward_pass, step=i) 303 | mlflow.log_metric("training/lr", lr.item(), step=i) 304 | mlflow.log_metric("training/gflops", gflops, step=i) 305 | mlflow.log_metric("testing/accuracy", test_acc, step=i) 306 | if meta_lr is not None: 307 | mlflow.log_metric( 308 | "testing/meta_lr", optimizer.meta_learning_rate.item(), step=i 309 | ) 310 | for j, spec_lr in enumerate(best_lr): 311 | mlflow.log_metric( 312 | f"training/adaptive_lr_{j}", spec_lr.item(), step=i 313 | ) 314 | 315 | if 
checkpoint_filepath is not None and i % checkpoint_per_steps == ( 316 | checkpoint_per_steps - 1 317 | ): 318 | write_checkpoint( 319 | marketplace=marketplace, 320 | global_step=i, 321 | output_filepath=pathlib.Path(checkpoint_filepath), 322 | ) 323 | 324 | t.set_description( 325 | f"loss: {best_loss.item():6.2f}, fw: {current_forward_pass}, lr: {lr.item():.2e}, " 326 | f"acc: {best_accuracy.item():.2f}%, vacc: {test_acc:.2f}%, {gflops:9,.2f} GFLOPS" 327 | ) 328 | if i is not None and checkpoint_filepath is not None: 329 | write_checkpoint( 330 | marketplace=marketplace, 331 | global_step=i, 332 | output_filepath=pathlib.Path(checkpoint_filepath), 333 | ) 334 | 335 | 336 | @click.command("beautiful_mnist") 337 | @click.option("--step-count", type=int, default=10_000, help="How many steps to run") 338 | @click.option("--batch-size", type=int, default=512, help="Size of batch") 339 | @click.option( 340 | "--initial-lr", type=float, default=1e-1, help="Initial learning rate value" 341 | ) 342 | @click.option("--lr-decay", type=float, default=1e-5, help="Learning rate decay rate") 343 | @click.option( 344 | "--meta-lr", 345 | type=click.FloatRange(0.0, 1.0, max_open=True), 346 | help="Enable learning rate scaling mode with the given meta-learning rate", 347 | ) 348 | @click.option( 349 | "--forward-pass", 350 | type=int, 351 | default=1, 352 | help="How many forward pass to run (simulate distributed computing)", 353 | ) 354 | @click.option( 355 | "--marketplace-replica", 356 | type=int, 357 | default=1, 358 | help="How many marketplace replica to run (simulate distributed computing)", 359 | ) 360 | @click.option("--vendor-count", type=int, default=8, help="Vendor count") 361 | @click.option("--seed", type=int, help="Set the random seed") 362 | @click.option( 363 | "--probe-scale", 364 | type=float, 365 | default=0.1, 366 | help="The scale we use to apply on LR for making the reconciled delta direction", 367 | ) 368 | @click.option( 369 | "--unit-vector-mode", 370 | type=click.Choice(UnitVectorMode, case_sensitive=False), 371 | default=UnitVectorMode.per_spec.value, 372 | help="The unit vector mode to use", 373 | ) 374 | @click.option( 375 | "--checkpoint-filepath", 376 | type=click.Path(dir_okay=False, writable=True), 377 | help="Filepath of checkpoint to write to", 378 | ) 379 | @click.option( 380 | "--checkpoint-per-steps", 381 | type=int, 382 | default=100, 383 | help="For how many steps we should write a checkpoint", 384 | ) 385 | @click.option("--run-name", type=str, help="Set the run name") 386 | def main( 387 | step_count: int, 388 | batch_size: int, 389 | initial_lr: float, 390 | lr_decay: float, 391 | meta_lr: float | None, 392 | forward_pass: int, 393 | marketplace_replica: int, 394 | vendor_count: int, 395 | seed: int | None, 396 | probe_scale: float | None, 397 | unit_vector_mode: UnitVectorMode, 398 | checkpoint_filepath: str, 399 | checkpoint_per_steps: int, 400 | run_name: str | None, 401 | ): 402 | # ref: https://github.com/tinygrad/tinygrad/issues/8617 403 | # With complex huge compute graph, tinygrad runs into recursion too deep issue, let's bump it up 404 | NEW_RECURSION_LIMIT = 100_000 405 | logger.info("Current recursion limit is %s", sys.getrecursionlimit()) 406 | sys.setrecursionlimit(NEW_RECURSION_LIMIT) 407 | logger.info("Set recursion limit to %s", NEW_RECURSION_LIMIT) 408 | with mlflow.start_run( 409 | experiment_id=ensure_experiment("Marketplace V2"), 410 | run_name="beautiful-mnist" if run_name is None else run_name, 411 | ): 412 | mlflow.log_param("vendor_count", 
vendor_count) 413 | train( 414 | step_count=step_count, 415 | batch_size=batch_size, 416 | initial_lr=initial_lr, 417 | lr_decay_rate=lr_decay, 418 | meta_lr=meta_lr, 419 | initial_forward_pass=forward_pass, 420 | probe_scale=probe_scale if probe_scale else None, 421 | unit_vector_mode=unit_vector_mode, 422 | manual_seed=seed, 423 | marketplace=make_marketplace( 424 | default_vendor_count=vendor_count, 425 | ), 426 | marketplace_replica=marketplace_replica, 427 | checkpoint_filepath=( 428 | pathlib.Path(checkpoint_filepath) 429 | if checkpoint_filepath is not None 430 | else None 431 | ), 432 | checkpoint_per_steps=checkpoint_per_steps, 433 | ) 434 | 435 | 436 | if __name__ == "__main__": 437 | logging.basicConfig(level=logging.INFO) 438 | main() 439 | -------------------------------------------------------------------------------- /experiments/cl_augment.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import mlflow 5 | 6 | from .beautiful_mnist import make_marketplace 7 | from .beautiful_mnist import train 8 | from .continual_learning import learn 9 | from .utils import ensure_experiment 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def main(): 15 | exp_id = ensure_experiment("Continual Learning Augment") 16 | checkpoint_file = pathlib.Path("continual-learning.safetensors") 17 | if not checkpoint_file.exists(): 18 | train_vendor_count = 16 19 | with mlflow.start_run( 20 | run_name="base-model", 21 | experiment_id=exp_id, 22 | log_system_metrics=True, 23 | ): 24 | marketplace = make_marketplace(default_vendor_count=train_vendor_count) 25 | mlflow.log_param("vendor_count", train_vendor_count) 26 | train( 27 | step_count=2_000, 28 | batch_size=512, 29 | initial_lr=1e-1, 30 | lr_decay_rate=1e-5, 31 | probe_scale=1e-1, 32 | marketplace=marketplace, 33 | manual_seed=42, 34 | checkpoint_filepath=checkpoint_file, 35 | ) 36 | else: 37 | logger.info("Checkpoint file %s already exists, skip", checkpoint_file) 38 | learn_vendor_count = 4 39 | for augment_old, augment_new in [ 40 | (False, False), 41 | (False, True), 42 | (True, False), 43 | (True, True), 44 | ]: 45 | with mlflow.start_run( 46 | run_name=f"augment-old-{augment_old}-new-{augment_new}", 47 | experiment_id=exp_id, 48 | log_system_metrics=True, 49 | ): 50 | marketplace = make_marketplace(default_vendor_count=learn_vendor_count) 51 | mlflow.log_param("vendor_count", learn_vendor_count) 52 | learn( 53 | step_count=100_000, 54 | batch_size=256, 55 | target_new_classes=(3,), 56 | augment_old=augment_old, 57 | augment_new=augment_new, 58 | new_train_size=32, 59 | initial_lr=1e-2, 60 | lr_decay_rate=0, 61 | probe_scale=1.0, 62 | forward_pass=1, 63 | marketplace=marketplace, 64 | manual_seed=42, 65 | input_checkpoint_filepath=checkpoint_file, 66 | ) 67 | 68 | 69 | if __name__ == "__main__": 70 | logging.basicConfig(level=logging.INFO) 71 | main() 72 | -------------------------------------------------------------------------------- /experiments/cl_balance_labels.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import mlflow 5 | 6 | from .beautiful_mnist import make_marketplace 7 | from .beautiful_mnist import train 8 | from .continual_learning import learn 9 | from .utils import ensure_experiment 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def main(): 15 | exp_id = ensure_experiment("Continual Learning Balance Labels V2") 16 | checkpoint_file = 
pathlib.Path("continual-learning.safetensors") 17 | if not checkpoint_file.exists(): 18 | train_vendor_count = 16 19 | with mlflow.start_run( 20 | run_name="base-model", 21 | experiment_id=exp_id, 22 | log_system_metrics=True, 23 | ): 24 | marketplace = make_marketplace(default_vendor_count=train_vendor_count) 25 | mlflow.log_param("vendor_count", train_vendor_count) 26 | train( 27 | step_count=2_000, 28 | batch_size=512, 29 | initial_lr=1e-1, 30 | lr_decay_rate=1e-5, 31 | probe_scale=1e-1, 32 | marketplace=marketplace, 33 | manual_seed=42, 34 | checkpoint_filepath=checkpoint_file, 35 | ) 36 | else: 37 | logger.info("Checkpoint file %s already exists, skip", checkpoint_file) 38 | learn_vendor_count = 8 39 | for balance_labels in [True, False]: 40 | with mlflow.start_run( 41 | run_name=f"balance-labels-{balance_labels}", 42 | experiment_id=exp_id, 43 | log_system_metrics=True, 44 | ): 45 | marketplace = make_marketplace(default_vendor_count=learn_vendor_count) 46 | mlflow.log_param("vendor_count", learn_vendor_count) 47 | learn( 48 | step_count=10_000, 49 | batch_size=256, 50 | target_new_classes=(3,), 51 | balance_labels=balance_labels, 52 | new_train_size=32, 53 | initial_lr=1e-2, 54 | lr_decay_rate=0, 55 | probe_scale=1.0, 56 | forward_pass=1, 57 | marketplace=marketplace, 58 | manual_seed=42, 59 | input_checkpoint_filepath=checkpoint_file, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | logging.basicConfig(level=logging.INFO) 65 | main() 66 | -------------------------------------------------------------------------------- /experiments/cl_digit_exp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import mlflow 5 | 6 | from .beautiful_mnist import make_marketplace 7 | from .beautiful_mnist import train 8 | from .continual_learning import learn 9 | from .utils import ensure_experiment 10 | from marketplace.optimizers import UnitVectorMode 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def main(): 16 | exp_id = ensure_experiment("Continual Learning Digit - Article") 17 | for exclude_missing_from_loss_func in [False, True]: 18 | checkpoint_file = pathlib.Path( 19 | f"continual-learning-v3-exclude-9-exclude-loss-{exclude_missing_from_loss_func}.safetensors" 20 | ) 21 | if not checkpoint_file.exists(): 22 | train_vendor_count = 16 23 | with mlflow.start_run( 24 | run_name=f"base-model-exclude-loss-{exclude_missing_from_loss_func}", 25 | experiment_id=exp_id, 26 | log_system_metrics=True, 27 | ): 28 | marketplace = make_marketplace(default_vendor_count=train_vendor_count) 29 | mlflow.log_param("vendor_count", train_vendor_count) 30 | train( 31 | step_count=2_000, 32 | batch_size=512, 33 | initial_lr=3e-1, 34 | lr_decay_rate=1e-5, 35 | probe_scale=1e-2, 36 | marketplace=marketplace, 37 | manual_seed=42, 38 | unit_vector_mode=UnitVectorMode.whole, 39 | checkpoint_filepath=checkpoint_file, 40 | # exclude 9 41 | only_classes=tuple(range(9)), 42 | exclude_missing_from_loss_func=exclude_missing_from_loss_func, 43 | ) 44 | else: 45 | logger.info("Checkpoint file %s already exists, skip", checkpoint_file) 46 | 47 | learn_vendor_count = 4 48 | replay_file = pathlib.Path( 49 | f"digit-exclude-loss-{exclude_missing_from_loss_func}.jsonl" 50 | ) 51 | with ( 52 | mlflow.start_run( 53 | run_name=f"exclude-loss-{exclude_missing_from_loss_func}", 54 | experiment_id=exp_id, 55 | log_system_metrics=True, 56 | ), 57 | replay_file.open("wt") as fo, 58 | ): 59 | marketplace = make_marketplace( 60 | 
default_vendor_count=learn_vendor_count, 61 | ) 62 | mlflow.log_param("vendor_count", learn_vendor_count) 63 | learn( 64 | step_count=100_000, 65 | batch_size=256, 66 | initial_lr=1e-2, 67 | lr_decay_rate=0, 68 | probe_scale=1.0, 69 | forward_pass=1, 70 | marketplace=marketplace, 71 | manual_seed=42, 72 | replay_file=fo, 73 | target_new_classes=(9,), 74 | input_checkpoint_filepath=checkpoint_file, 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | logging.basicConfig(level=logging.INFO) 80 | main() 81 | -------------------------------------------------------------------------------- /experiments/cl_exp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import mlflow 5 | 6 | from .beautiful_mnist import make_marketplace 7 | from .beautiful_mnist import train 8 | from .continual_learning import learn 9 | from .utils import ensure_experiment 10 | from marketplace.optimizers import UnitVectorMode 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def main(): 16 | exp_id = ensure_experiment("Continual Learning V4") 17 | checkpoint_file = pathlib.Path( 18 | "continual-learning-v3-exclude-9-neutral.safetensors" 19 | ) 20 | if not checkpoint_file.exists(): 21 | train_vendor_count = 16 22 | with mlflow.start_run( 23 | run_name="base-model", 24 | experiment_id=exp_id, 25 | log_system_metrics=True, 26 | ): 27 | marketplace = make_marketplace(default_vendor_count=train_vendor_count) 28 | mlflow.log_param("vendor_count", train_vendor_count) 29 | train( 30 | step_count=2_000, 31 | batch_size=512, 32 | initial_lr=3e-1, 33 | lr_decay_rate=1e-5, 34 | probe_scale=1e-2, 35 | marketplace=marketplace, 36 | manual_seed=42, 37 | unit_vector_mode=UnitVectorMode.whole, 38 | checkpoint_filepath=checkpoint_file, 39 | # exclude 9 40 | only_classes=tuple(range(9)), 41 | ) 42 | else: 43 | logger.info("Checkpoint file %s already exists, skip", checkpoint_file) 44 | for learn_vendor_count in [8, 16]: 45 | for fw in [4, 8, 16]: 46 | for lr in [9e-3, 1e-2, 2e-2, 3e-2, 1e-1, 2e-1]: 47 | for probe_scale in [1, 0.5, 0.1]: 48 | with mlflow.start_run( 49 | run_name=f"learn-vendor-{learn_vendor_count}-lr-{lr:.1e}-fw-{fw}-probe-scale-{probe_scale}", 50 | experiment_id=exp_id, 51 | log_system_metrics=True, 52 | ): 53 | marketplace = make_marketplace( 54 | default_vendor_count=learn_vendor_count 55 | ) 56 | mlflow.log_param("vendor_count", learn_vendor_count) 57 | learn( 58 | step_count=10_000, 59 | batch_size=256, 60 | target_new_classes=(9,), 61 | initial_lr=lr, 62 | lr_decay_rate=0, 63 | probe_scale=probe_scale, 64 | forward_pass=fw, 65 | marketplace=marketplace, 66 | manual_seed=42, 67 | input_checkpoint_filepath=checkpoint_file, 68 | ) 69 | 70 | 71 | if __name__ == "__main__": 72 | logging.basicConfig(level=logging.INFO) 73 | main() 74 | -------------------------------------------------------------------------------- /experiments/cl_fashion_exp.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | 4 | import mlflow 5 | 6 | from .beautiful_mnist import make_marketplace 7 | from .beautiful_mnist import train 8 | from .continual_learning_fashion import learn 9 | from .utils import ensure_experiment 10 | from marketplace.optimizers import UnitVectorMode 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def main(): 16 | exp_id = ensure_experiment("Continual Learning Fashion - Article") 17 | checkpoint_file = 
pathlib.Path("continual-learning-fashion.safetensors") 18 | if not checkpoint_file.exists(): 19 | train_vendor_count = 16 20 | with mlflow.start_run( 21 | run_name="base-model", 22 | experiment_id=exp_id, 23 | log_system_metrics=True, 24 | ): 25 | marketplace = make_marketplace(default_vendor_count=train_vendor_count) 26 | mlflow.log_param("vendor_count", train_vendor_count) 27 | train( 28 | step_count=2_000, 29 | batch_size=512, 30 | initial_lr=3e-1, 31 | lr_decay_rate=1e-5, 32 | probe_scale=1e-2, 33 | marketplace=marketplace, 34 | manual_seed=42, 35 | unit_vector_mode=UnitVectorMode.whole, 36 | checkpoint_filepath=checkpoint_file, 37 | ) 38 | else: 39 | logger.info("Checkpoint file %s already exists, skip", checkpoint_file) 40 | 41 | learn_vendor_count = 4 42 | for augment_old in [False, True]: 43 | replay_file = pathlib.Path(f"augment-old-{augment_old}.jsonl") 44 | with ( 45 | mlflow.start_run( 46 | run_name=f"augment-old-{augment_old}", 47 | experiment_id=exp_id, 48 | log_system_metrics=True, 49 | ), 50 | replay_file.open("wt") as fo, 51 | ): 52 | marketplace = make_marketplace( 53 | default_vendor_count=learn_vendor_count, 54 | ) 55 | mlflow.log_param("vendor_count", learn_vendor_count) 56 | learn( 57 | step_count=100_000, 58 | batch_size=256, 59 | new_train_size=16, 60 | initial_lr=1e-2, 61 | lr_decay_rate=0, 62 | probe_scale=1.0, 63 | forward_pass=1, 64 | augment_old=augment_old, 65 | marketplace=marketplace, 66 | manual_seed=42, 67 | replay_file=fo, 68 | input_checkpoint_filepath=checkpoint_file, 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | logging.basicConfig(level=logging.INFO) 74 | main() 75 | -------------------------------------------------------------------------------- /experiments/continual_learning.py: -------------------------------------------------------------------------------- 1 | # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 2 | import json 3 | import logging 4 | import pathlib 5 | import sys 6 | import time 7 | import typing 8 | from contextlib import nullcontext 9 | 10 | import click 11 | import mlflow 12 | import numpy as np 13 | from PIL import Image 14 | from tinygrad import dtypes 15 | from tinygrad import GlobalCounters 16 | from tinygrad import Tensor 17 | from tinygrad import TinyJit 18 | from tinygrad.helpers import trange 19 | from tinygrad.nn import Conv2d 20 | from tinygrad.nn import InstanceNorm 21 | from tinygrad.nn import Linear 22 | from tinygrad.nn.datasets import mnist 23 | 24 | from .utils import ensure_experiment 25 | from marketplace.continual_learning import forward_with_paths 26 | from marketplace.nn import Model 27 | from marketplace.optimizers import Optimizer 28 | from marketplace.optimizers import UnitVectorMode 29 | from marketplace.training import Spec 30 | from marketplace.training import straight_forward 31 | from marketplace.utils import load_checkpoint 32 | from marketplace.utils import write_checkpoint 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | LABEL_COUNT = 10 37 | 38 | 39 | # stolen from tinygrad 40 | # ref: https://github.com/tinygrad/tinygrad/blob/c6c16b294616447238d5d19974bceca52c9f2a40/extra/augment.py#L11-L21 41 | def augment_img( 42 | X: np.typing.NDArray, rotate: float = 10, px: int = 3 43 | ) -> np.typing.NDArray: 44 | Xaug = np.zeros_like(X) 45 | for i in range(len(X)): 46 | im = Image.fromarray(X[i]) 47 | im = im.rotate(np.random.randint(-rotate, rotate), resample=Image.BICUBIC) 48 | w, h = X.shape[1:] 49 | # upper left, lower 
left, lower right, upper right 50 | quad = np.random.randint(-px, px, size=(8)) + np.array([0, 0, 0, h, w, h, w, 0]) 51 | im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC) 52 | Xaug[i] = im 53 | return Xaug 54 | 55 | 56 | def make_marketplace( 57 | structure: list[tuple[int, int]] | None = None, 58 | default_vendor_count: int = 8, 59 | ): 60 | if structure is None: 61 | structure = [ 62 | # layer 0 63 | (default_vendor_count, 0), 64 | # layer 1 65 | (default_vendor_count, 0), 66 | # layer 2 67 | (default_vendor_count, 0), 68 | ] 69 | return [ 70 | Spec( 71 | model=Model( 72 | Conv2d(1, 32, 5), 73 | Tensor.relu, 74 | Conv2d(32, 32, 5), 75 | Tensor.relu, 76 | InstanceNorm(32), 77 | Tensor.max_pool2d, 78 | ), 79 | vendor_count=structure[0][0], 80 | ), 81 | Spec( 82 | model=Model( 83 | Conv2d(32, 64, 3), 84 | Tensor.relu, 85 | Conv2d(64, 64, 3), 86 | Tensor.relu, 87 | InstanceNorm(64), 88 | Tensor.max_pool2d, 89 | lambda x: x.flatten(1), 90 | ), 91 | vendor_count=structure[1][0], 92 | upstream_sampling=structure[1][1], 93 | ), 94 | Spec( 95 | model=Model(Linear(576, LABEL_COUNT)), 96 | vendor_count=structure[2][0], 97 | upstream_sampling=structure[2][1], 98 | ), 99 | ] 100 | 101 | 102 | def learn( 103 | step_count: int, 104 | batch_size: int, 105 | initial_lr: float, 106 | lr_decay_rate: float, 107 | marketplace: list[Spec], 108 | target_new_classes: tuple[int] = (9,), 109 | probe_scale: float | None = None, 110 | forward_pass: int = 1, 111 | metrics_per_steps: int = 10, 112 | input_checkpoint_filepath: pathlib.Path | None = None, 113 | checkpoint_filepath: pathlib.Path | None = None, 114 | checkpoint_per_steps: int = 1000, 115 | replay_file: typing.TextIO | None = None, 116 | manual_seed: int | None = None, 117 | ): 118 | logger.info( 119 | "Running beautiful MNIST continual learning with step_count=%s, batch_size=%s, init_lr=%s, lr_decay=%s, " 120 | "target_new_classes=%s, probe_scale=%s, forward_pass=%s, metrics_per_steps=%s, input_checkpoint_filepath=%s, " 121 | "checkpoint_filepath=%s, checkpoint_per_steps=%s, manual_seed=%s", 122 | step_count, 123 | batch_size, 124 | initial_lr, 125 | lr_decay_rate, 126 | target_new_classes, 127 | probe_scale, 128 | forward_pass, 129 | metrics_per_steps, 130 | input_checkpoint_filepath, 131 | checkpoint_filepath, 132 | checkpoint_per_steps, 133 | manual_seed, 134 | ) 135 | 136 | mlflow.log_param("step_count", step_count) 137 | mlflow.log_param("batch_size", batch_size) 138 | mlflow.log_param("forward_pass", forward_pass) 139 | mlflow.log_param("lr", initial_lr) 140 | mlflow.log_param("lr_decay_rate", lr_decay_rate) 141 | mlflow.log_param("target_new_classes", target_new_classes) 142 | mlflow.log_param("probe_scale", probe_scale) 143 | mlflow.log_param("metrics_per_steps", metrics_per_steps) 144 | mlflow.log_param("checkpoint_per_steps", checkpoint_per_steps) 145 | mlflow.log_param("manual_seed", manual_seed) 146 | 147 | if input_checkpoint_filepath is not None: 148 | load_checkpoint( 149 | marketplace=marketplace, input_filepath=input_checkpoint_filepath 150 | ) 151 | 152 | if manual_seed is not None: 153 | Tensor.manual_seed(manual_seed) 154 | 155 | X_train, Y_train, X_test, Y_test = mnist() 156 | 157 | lr = Tensor(initial_lr).contiguous().realize() 158 | optimizer = Optimizer( 159 | marketplace=marketplace, 160 | learning_rate=lr, 161 | probe_scale=(Tensor(probe_scale) if probe_scale is not None else None), 162 | ) 163 | 164 | @TinyJit 165 | def forward_step(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor, Tensor]: 166 | batch_paths = 
Tensor.stack( 167 | *( 168 | Tensor.randint( 169 | batch_size, low=0, high=spec.vendor_count, dtype=dtypes.uint 170 | ) 171 | for spec in marketplace 172 | ), 173 | dim=1, 174 | ) 175 | 176 | logits = forward_with_paths( 177 | marketplace=marketplace, 178 | paths=batch_paths, 179 | x=x, 180 | deltas=[ctx.delta for ctx in optimizer.spec_context], 181 | ) 182 | loss = logits.sparse_categorical_crossentropy(y, reduction="none") 183 | correct = logits.argmax(axis=1) == y 184 | return ( 185 | loss.realize(), 186 | correct.realize(), 187 | batch_paths.realize(), 188 | ) 189 | 190 | @TinyJit 191 | def optimize_step(loss: Tensor, paths: Tensor): 192 | direction_vectors = optimizer.compute_direction_vectors( 193 | loss=loss, 194 | paths=paths, 195 | unit_vector_mode=UnitVectorMode.whole, 196 | ) 197 | Tensor.realize( 198 | *optimizer.schedule_weight_update( 199 | direction_delta=direction_vectors, 200 | ) 201 | ) 202 | Tensor.realize(*optimizer.schedule_seeds_update()) 203 | Tensor.realize(*optimizer.schedule_delta_update()) 204 | 205 | @TinyJit 206 | def get_test_acc() -> tuple[Tensor, Tensor]: 207 | predictions = straight_forward(marketplace, X_test).argmax(axis=1) == Y_test 208 | new_labels = Y_test == target_new_classes[0] 209 | for new_label in target_new_classes[1:]: 210 | new_labels |= Y_test == new_label 211 | old_labels = ~new_labels 212 | return ( 213 | # old labels accuracy 214 | (((predictions & old_labels).sum() / old_labels.sum()) * 100).realize(), 215 | # new labels accuracy 216 | (((predictions & new_labels).sum() / new_labels.sum()) * 100).realize(), 217 | ) 218 | 219 | i = 0 220 | old_test_acc = float("nan") 221 | new_test_acc = float("nan") 222 | for i in (t := trange(step_count)): 223 | GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing 224 | 225 | start_time = time.perf_counter() 226 | 227 | all_samples = [] 228 | all_correct = [] 229 | all_loss = [] 230 | all_paths = [] 231 | all_old_loss = [] 232 | all_old_accuracy = [] 233 | all_new_loss = [] 234 | all_new_accuracy = [] 235 | for _ in range(forward_pass): 236 | samples = Tensor.randint( 237 | batch_size, low=0, high=batch_size, dtype=dtypes.uint 238 | ) 239 | x = X_train[samples] 240 | y = Y_train[samples] 241 | 242 | loss, correct, paths = forward_step(x, y) 243 | all_loss.append(loss) 244 | all_paths.append(paths) 245 | 246 | y = y.numpy() 247 | loss = loss.numpy() 248 | correct = correct.numpy() 249 | samples = samples.numpy() 250 | 251 | old_mask = ~np.isin(y, target_new_classes) 252 | old_loss = loss[old_mask] 253 | old_accuracy = correct[old_mask] 254 | new_mask = ~old_mask 255 | new_loss = loss[new_mask] 256 | new_accuracy = correct[new_mask] 257 | 258 | all_samples.append(samples) 259 | all_correct.append(correct) 260 | all_old_loss.append(old_loss) 261 | all_old_accuracy.append(old_accuracy) 262 | all_new_loss.append(new_loss) 263 | all_new_accuracy.append(new_accuracy) 264 | 265 | optimize_step(Tensor.cat(*all_loss), Tensor.cat(*all_paths)) 266 | 267 | old_loss = np.concatenate(all_old_loss).mean() 268 | old_accuracy = np.concatenate(all_old_accuracy).mean() * 100 269 | new_loss = np.concatenate(all_new_loss).mean() 270 | new_accuracy = np.concatenate(all_new_accuracy).mean() * 100 271 | 272 | end_time = time.perf_counter() 273 | run_time = end_time - start_time 274 | lr.assign(lr * (1 - lr_decay_rate)).realize() 275 | gflops = GlobalCounters.global_ops * 1e-9 / run_time 276 | 277 | if i % metrics_per_steps == (metrics_per_steps - 1): 278 | old_test_acc, new_test_acc = get_test_acc() 279 | 
old_test_acc = old_test_acc.item() 280 | new_test_acc = new_test_acc.item() 281 | mlflow.log_metric("learning/old_loss", old_loss.item(), step=i) 282 | mlflow.log_metric("learning/old_accuracy", old_accuracy.item(), step=i) 283 | mlflow.log_metric("learning/new_loss", new_loss.item(), step=i) 284 | mlflow.log_metric("learning/new_accuracy", new_accuracy.item(), step=i) 285 | mlflow.log_metric("learning/lr", lr.item(), step=i) 286 | mlflow.log_metric("learning/gflops", gflops, step=i) 287 | mlflow.log_metric("testing/old_accuracy", old_test_acc, step=i) 288 | mlflow.log_metric("testing/new_accuracy", new_test_acc, step=i) 289 | if replay_file is not None: 290 | replay_file.write( 291 | json.dumps( 292 | dict( 293 | samples=np.concatenate(all_samples).tolist(), 294 | correct=np.concatenate(all_correct).tolist(), 295 | global_step=i, 296 | ) 297 | ) 298 | ) 299 | 300 | if checkpoint_filepath is not None and i % checkpoint_per_steps == ( 301 | checkpoint_per_steps - 1 302 | ): 303 | write_checkpoint( 304 | marketplace=marketplace, 305 | global_step=i, 306 | output_filepath=pathlib.Path(checkpoint_filepath), 307 | ) 308 | 309 | t.set_description( 310 | f"loss: {old_loss.item():6.2f}/{new_loss.item():6.2f}, rl: {lr.item():.2e}, " 311 | f"acc: {old_accuracy.item():.2f}%/{new_accuracy.item():.2f}%, " 312 | f"vacc: {old_test_acc:.2f}%/{new_test_acc:.2f}%, {gflops:9,.2f} GFLOPS" 313 | ) 314 | if i is not None and checkpoint_filepath is not None: 315 | write_checkpoint( 316 | marketplace=marketplace, 317 | global_step=i, 318 | output_filepath=pathlib.Path(checkpoint_filepath), 319 | ) 320 | 321 | 322 | @click.command() 323 | @click.option("--step-count", type=int, default=10_000, help="How many steps to run") 324 | @click.option("--batch-size", type=int, default=256, help="Size of batch") 325 | @click.option( 326 | "--initial-lr", type=float, default=1e-2, help="Initial learning rate value" 327 | ) 328 | @click.option("--lr-decay", type=float, default=0, help="Learning rate decay rate") 329 | @click.option("--vendor-count", type=int, default=4, help="Vendor count") 330 | @click.option( 331 | "--forward-pass", 332 | type=int, 333 | default=1, 334 | help="How many forward pass to run (simulate distributed computing)", 335 | ) 336 | @click.option("--seed", type=int, help="Set the random seed") 337 | @click.option( 338 | "--probe-scale", 339 | type=float, 340 | default=1, 341 | help="The scale we use to apply on LR for making the reconciled delta direction", 342 | ) 343 | @click.option( 344 | "--input-checkpoint-filepath", 345 | type=click.Path(dir_okay=False, readable=True, exists=True), 346 | default="continual-learning-v3-exclude-9-neutral.safetensors", 347 | help="Filepath of checkpoint to read from", 348 | ) 349 | @click.option( 350 | "--checkpoint-filepath", 351 | type=click.Path(dir_okay=False, writable=True), 352 | help="Filepath of checkpoint to write to", 353 | ) 354 | @click.option( 355 | "--checkpoint-per-steps", 356 | type=int, 357 | default=100, 358 | help="For how many steps we should write a checkpoint", 359 | ) 360 | @click.option( 361 | "--replay-file", 362 | type=click.Path(dir_okay=False, writable=True), 363 | help="Filepath of replay JSON file to write to", 364 | ) 365 | @click.option("--run-name", type=str, help="Set the run name") 366 | def main( 367 | step_count: int, 368 | batch_size: int, 369 | initial_lr: float, 370 | lr_decay: float, 371 | vendor_count: int, 372 | forward_pass: int, 373 | seed: int | None, 374 | probe_scale: float | None, 375 | input_checkpoint_filepath: str, 
376 | checkpoint_filepath: str, 377 | checkpoint_per_steps: int, 378 | replay_file: str | None, 379 | run_name: str | None, 380 | ): 381 | # ref: https://github.com/tinygrad/tinygrad/issues/8617 382 | # With complex huge compute graph, tinygrad runs into recursion too deep issue, let's bump it up 383 | NEW_RECURSION_LIMIT = 100_000 384 | logger.info("Current recursion limit is %s", sys.getrecursionlimit()) 385 | sys.setrecursionlimit(NEW_RECURSION_LIMIT) 386 | logger.info("Set recursion limit to %s", NEW_RECURSION_LIMIT) 387 | if replay_file is not None: 388 | replay_filepath = pathlib.Path(replay_file) 389 | replay_file_ctx = replay_filepath.open("wt") 390 | else: 391 | replay_file_ctx = nullcontext() 392 | with ( 393 | mlflow.start_run( 394 | experiment_id=ensure_experiment("Continual Learning"), 395 | run_name="beautiful-mnist" if run_name is None else run_name, 396 | ), 397 | replay_file_ctx as replay_file, 398 | ): 399 | mlflow.log_param("vendor_count", vendor_count) 400 | learn( 401 | step_count=step_count, 402 | batch_size=batch_size, 403 | initial_lr=initial_lr, 404 | lr_decay_rate=lr_decay, 405 | probe_scale=probe_scale if probe_scale else None, 406 | forward_pass=forward_pass, 407 | manual_seed=seed, 408 | marketplace=make_marketplace( 409 | default_vendor_count=vendor_count, 410 | ), 411 | input_checkpoint_filepath=( 412 | pathlib.Path(input_checkpoint_filepath) 413 | if input_checkpoint_filepath is not None 414 | else None 415 | ), 416 | checkpoint_filepath=( 417 | pathlib.Path(checkpoint_filepath) 418 | if checkpoint_filepath is not None 419 | else None 420 | ), 421 | checkpoint_per_steps=checkpoint_per_steps, 422 | replay_file=replay_file, 423 | ) 424 | 425 | 426 | if __name__ == "__main__": 427 | logging.basicConfig(level=logging.INFO) 428 | main() 429 | -------------------------------------------------------------------------------- /experiments/continual_learning_fashion.py: -------------------------------------------------------------------------------- 1 | # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 2 | import json 3 | import logging 4 | import pathlib 5 | import sys 6 | import time 7 | import typing 8 | from contextlib import nullcontext 9 | 10 | import click 11 | import mlflow 12 | import numpy as np 13 | from PIL import Image 14 | from tinygrad import dtypes 15 | from tinygrad import GlobalCounters 16 | from tinygrad import Tensor 17 | from tinygrad import TinyJit 18 | from tinygrad.helpers import trange 19 | from tinygrad.nn import Conv2d 20 | from tinygrad.nn import InstanceNorm 21 | from tinygrad.nn import Linear 22 | from tinygrad.nn.datasets import mnist 23 | 24 | from .utils import ensure_experiment 25 | from marketplace.continual_learning import forward_with_paths 26 | from marketplace.nn import Model 27 | from marketplace.optimizers import Optimizer 28 | from marketplace.optimizers import UnitVectorMode 29 | from marketplace.training import Spec 30 | from marketplace.training import straight_forward 31 | from marketplace.utils import load_checkpoint 32 | from marketplace.utils import write_checkpoint 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | LABEL_COUNT = 10 37 | 38 | 39 | # stolen from tinygrad 40 | # ref: https://github.com/tinygrad/tinygrad/blob/c6c16b294616447238d5d19974bceca52c9f2a40/extra/augment.py#L11-L21 41 | def augment_img( 42 | X: np.typing.NDArray, rotate: float = 10, px: int = 3 43 | ) -> np.typing.NDArray: 44 | Xaug = np.zeros_like(X) 45 | for i in 
range(len(X)): 46 | im = Image.fromarray(X[i]) 47 | im = im.rotate(np.random.randint(-rotate, rotate), resample=Image.BICUBIC) 48 | w, h = X.shape[1:] 49 | # upper left, lower left, lower right, upper right 50 | quad = np.random.randint(-px, px, size=(8)) + np.array([0, 0, 0, h, w, h, w, 0]) 51 | im = im.transform((w, h), Image.QUAD, quad, resample=Image.BICUBIC) 52 | Xaug[i] = im 53 | return Xaug 54 | 55 | 56 | def make_marketplace( 57 | structure: list[tuple[int, int]] | None = None, 58 | default_vendor_count: int = 8, 59 | ): 60 | if structure is None: 61 | structure = [ 62 | # layer 0 63 | (default_vendor_count, 0), 64 | # layer 1 65 | (default_vendor_count, 0), 66 | # layer 2 67 | (default_vendor_count, 0), 68 | ] 69 | return [ 70 | Spec( 71 | model=Model( 72 | Conv2d(1, 32, 5), 73 | Tensor.relu, 74 | Conv2d(32, 32, 5), 75 | Tensor.relu, 76 | InstanceNorm(32), 77 | Tensor.max_pool2d, 78 | ), 79 | vendor_count=structure[0][0], 80 | ), 81 | Spec( 82 | model=Model( 83 | Conv2d(32, 64, 3), 84 | Tensor.relu, 85 | Conv2d(64, 64, 3), 86 | Tensor.relu, 87 | InstanceNorm(64), 88 | Tensor.max_pool2d, 89 | lambda x: x.flatten(1), 90 | ), 91 | vendor_count=structure[1][0], 92 | upstream_sampling=structure[1][1], 93 | ), 94 | Spec( 95 | model=Model(Linear(576, LABEL_COUNT)), 96 | vendor_count=structure[2][0], 97 | upstream_sampling=structure[2][1], 98 | ), 99 | ] 100 | 101 | 102 | def learn( 103 | step_count: int, 104 | batch_size: int, 105 | initial_lr: float, 106 | lr_decay_rate: float, 107 | marketplace: list[Spec], 108 | target_new_classes: tuple[int] = (3,), 109 | balance_labels: bool = True, 110 | augment_old: bool = False, 111 | augment_new: bool = False, 112 | new_train_size: int = 16, 113 | probe_scale: float | None = None, 114 | forward_pass: int = 1, 115 | metrics_per_steps: int = 10, 116 | input_checkpoint_filepath: pathlib.Path | None = None, 117 | checkpoint_filepath: pathlib.Path | None = None, 118 | checkpoint_per_steps: int = 1000, 119 | replay_file: typing.TextIO | None = None, 120 | manual_seed: int | None = None, 121 | ): 122 | logger.info( 123 | "Running beautiful MNIST continual learning with step_count=%s, batch_size=%s, init_lr=%s, lr_decay=%s, " 124 | "target_new_classes=%s, balance_labels=%s, augment_old=%s, augment_new=%s, new_train_size=%s, probe_scale=%s, " 125 | "forward_pass=%s, metrics_per_steps=%s, input_checkpoint_filepath=%s, checkpoint_filepath=%s, " 126 | "checkpoint_per_steps=%s, manual_seed=%s", 127 | step_count, 128 | batch_size, 129 | initial_lr, 130 | lr_decay_rate, 131 | target_new_classes, 132 | balance_labels, 133 | augment_old, 134 | augment_new, 135 | new_train_size, 136 | probe_scale, 137 | forward_pass, 138 | metrics_per_steps, 139 | input_checkpoint_filepath, 140 | checkpoint_filepath, 141 | checkpoint_per_steps, 142 | manual_seed, 143 | ) 144 | 145 | mlflow.log_param("step_count", step_count) 146 | mlflow.log_param("batch_size", batch_size) 147 | mlflow.log_param("forward_pass", forward_pass) 148 | mlflow.log_param("lr", initial_lr) 149 | mlflow.log_param("lr_decay_rate", lr_decay_rate) 150 | mlflow.log_param("target_new_classes", target_new_classes) 151 | mlflow.log_param("balance_labels", balance_labels) 152 | mlflow.log_param("augment_old", augment_old) 153 | mlflow.log_param("augment_new", augment_new) 154 | mlflow.log_param("new_train_size", new_train_size) 155 | mlflow.log_param("probe_scale", probe_scale) 156 | mlflow.log_param("metrics_per_steps", metrics_per_steps) 157 | mlflow.log_param("checkpoint_per_steps", checkpoint_per_steps) 158 | 
mlflow.log_param("manual_seed", manual_seed) 159 | 160 | if input_checkpoint_filepath is not None: 161 | load_checkpoint( 162 | marketplace=marketplace, input_filepath=input_checkpoint_filepath 163 | ) 164 | 165 | if manual_seed is not None: 166 | Tensor.manual_seed(manual_seed) 167 | 168 | X_train, Y_train, X_test, Y_test = mnist() 169 | new_X_train, new_Y_train, new_X_test, new_Y_test = mnist(fashion=True) 170 | 171 | if target_new_classes is not None: 172 | class_mask = np.isin(new_Y_train.numpy(), target_new_classes) 173 | target_new_X_train = Tensor(new_X_train.numpy()[class_mask]) 174 | target_new_Y_train = Tensor(new_Y_train.numpy()[class_mask]) 175 | class_mask = np.isin(new_Y_test.numpy(), target_new_classes) 176 | target_new_X_test = Tensor(new_X_test.numpy()[class_mask]) 177 | target_new_Y_test = Tensor(new_Y_test.numpy()[class_mask]) 178 | else: 179 | target_new_X_train = new_X_train 180 | target_new_Y_train = new_Y_train 181 | target_new_X_test = new_X_test 182 | target_new_Y_test = new_Y_test 183 | 184 | lr = Tensor(initial_lr).contiguous().realize() 185 | optimizer = Optimizer( 186 | marketplace=marketplace, 187 | learning_rate=lr, 188 | probe_scale=(Tensor(probe_scale) if probe_scale is not None else None), 189 | ) 190 | 191 | @TinyJit 192 | def forward_step( 193 | old_x: Tensor, 194 | old_y: Tensor, 195 | new_x: Tensor, 196 | new_y: Tensor, 197 | ) -> tuple[Tensor, Tensor, Tensor]: 198 | combined_x = Tensor.cat(old_x, new_x) 199 | combined_y = Tensor.cat(old_y, new_y) 200 | 201 | batch_paths = Tensor.stack( 202 | *( 203 | Tensor.randint( 204 | batch_size, low=0, high=spec.vendor_count, dtype=dtypes.uint 205 | ) 206 | for spec in marketplace 207 | ), 208 | dim=1, 209 | ) 210 | 211 | logits = forward_with_paths( 212 | marketplace=marketplace, 213 | paths=batch_paths, 214 | x=combined_x, 215 | deltas=[ctx.delta for ctx in optimizer.spec_context], 216 | ) 217 | loss = logits.sparse_categorical_crossentropy(combined_y, reduction="none") 218 | 219 | if balance_labels: 220 | # Notice: by adding the target classes from new dataset, we are changing the probability of each number 221 | # appearing. Not sure if it matters, but to make the model harder to blindly guess, we are balancing the 222 | # unbalanced labels by introducing the cross entropy weights. 
223 | weights = np.repeat((1 / LABEL_COUNT) * len(old_x), LABEL_COUNT) 224 | weights[target_new_classes] += (1 / len(target_new_classes)) * len(new_x) 225 | weights = batch_size / weights 226 | max_weight = weights.max() 227 | weights = Tensor(weights / max_weight, dtype=dtypes.default_float) 228 | loss *= weights[combined_y] 229 | correct = logits.argmax(axis=1) == combined_y 230 | return ( 231 | loss.realize(), 232 | correct.realize(), 233 | batch_paths.realize(), 234 | ) 235 | 236 | @TinyJit 237 | def optimize_step(loss: Tensor, paths: Tensor): 238 | direction_vectors = optimizer.compute_direction_vectors( 239 | loss=loss, 240 | paths=paths, 241 | unit_vector_mode=UnitVectorMode.whole, 242 | ) 243 | Tensor.realize( 244 | *optimizer.schedule_weight_update( 245 | direction_delta=direction_vectors, 246 | ) 247 | ) 248 | Tensor.realize(*optimizer.schedule_seeds_update()) 249 | Tensor.realize(*optimizer.schedule_delta_update()) 250 | 251 | @TinyJit 252 | def get_test_acc() -> tuple[Tensor, Tensor]: 253 | old = ( 254 | straight_forward(marketplace, X_test).argmax(axis=1) == Y_test 255 | ).mean() * 100 256 | new = ( 257 | straight_forward(marketplace, target_new_X_test).argmax(axis=1) 258 | == target_new_Y_test 259 | ).mean() * 100 260 | return old.realize(), new.realize() 261 | 262 | i = 0 263 | old_test_acc = float("nan") 264 | new_test_acc = float("nan") 265 | for i in (t := trange(step_count)): 266 | GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing 267 | 268 | start_time = time.perf_counter() 269 | old_train_size = batch_size - new_train_size 270 | 271 | all_old_samples = [] 272 | all_new_samples = [] 273 | all_old_correct = [] 274 | all_new_correct = [] 275 | all_loss = [] 276 | all_paths = [] 277 | all_old_loss = [] 278 | all_old_accuracy = [] 279 | all_new_loss = [] 280 | all_new_accuracy = [] 281 | for _ in range(forward_pass): 282 | old_samples = Tensor.randint( 283 | old_train_size, low=0, high=old_train_size, dtype=dtypes.uint 284 | ) 285 | new_samples = Tensor.randint( 286 | new_train_size, low=0, high=new_train_size, dtype=dtypes.uint 287 | ) 288 | old_x = X_train[old_samples] 289 | old_y = Y_train[old_samples] 290 | new_x = target_new_X_train[new_samples] 291 | new_y = target_new_Y_train[new_samples] 292 | # TODO: a bit slow, ideally run with a background loader 293 | if augment_old: 294 | old_x = old_x.reshape(-1, 28, 28).numpy().astype(np.uint8) 295 | old_x = Tensor( 296 | augment_img(old_x).reshape(-1, 1, 28, 28), 297 | dtype=dtypes.default_float, 298 | ) 299 | if augment_new: 300 | new_x = new_x.reshape(-1, 28, 28).numpy().astype(np.uint8) 301 | new_x = Tensor( 302 | augment_img(new_x).reshape(-1, 1, 28, 28), 303 | dtype=dtypes.default_float, 304 | ) 305 | 306 | loss, correct, paths = forward_step( 307 | old_x=old_x, 308 | old_y=old_y, 309 | new_x=new_x, 310 | new_y=new_y, 311 | ) 312 | old_loss = loss[:old_train_size].mean() 313 | old_correct = correct[:old_train_size] 314 | old_accuracy = old_correct.mean() 315 | 316 | new_loss = loss[old_train_size:].mean() 317 | new_correct = correct[old_train_size:] 318 | new_accuracy = new_correct.mean() 319 | 320 | all_old_samples.append(old_samples.numpy()) 321 | all_new_samples.append(new_samples.numpy()) 322 | all_loss.append(loss) 323 | all_paths.append(paths) 324 | all_old_loss.append(old_loss.numpy()) 325 | all_old_correct.append(old_correct.numpy()) 326 | all_old_accuracy.append(old_accuracy.numpy()) 327 | all_new_loss.append(new_loss.numpy()) 328 | all_new_correct.append(new_correct.numpy()) 329 | 
all_new_accuracy.append(new_accuracy.numpy()) 330 | 331 | optimize_step(Tensor.cat(*all_loss), Tensor.cat(*all_paths)) 332 | 333 | old_loss = np.array(all_old_loss).mean() 334 | old_accuracy = np.array(all_old_accuracy).mean() * 100 335 | new_loss = np.array(all_new_loss).mean() 336 | new_accuracy = np.array(all_new_accuracy).mean() * 100 337 | 338 | end_time = time.perf_counter() 339 | run_time = end_time - start_time 340 | lr.assign(lr * (1 - lr_decay_rate)).realize() 341 | gflops = GlobalCounters.global_ops * 1e-9 / run_time 342 | 343 | if i % metrics_per_steps == (metrics_per_steps - 1): 344 | old_test_acc, new_test_acc = get_test_acc() 345 | old_test_acc = old_test_acc.item() 346 | new_test_acc = new_test_acc.item() 347 | mlflow.log_metric("learning/old_loss", old_loss.item(), step=i) 348 | mlflow.log_metric("learning/old_accuracy", old_accuracy.item(), step=i) 349 | mlflow.log_metric("learning/new_loss", new_loss.item(), step=i) 350 | mlflow.log_metric("learning/new_accuracy", new_accuracy.item(), step=i) 351 | mlflow.log_metric("learning/lr", lr.item(), step=i) 352 | mlflow.log_metric("learning/gflops", gflops, step=i) 353 | mlflow.log_metric("testing/old_accuracy", old_test_acc, step=i) 354 | mlflow.log_metric("testing/new_accuracy", new_test_acc, step=i) 355 | if replay_file is not None: 356 | replay_file.write( 357 | json.dumps( 358 | dict( 359 | old_samples=np.concatenate(all_old_samples).tolist(), 360 | old_correct=np.concatenate(all_old_correct).tolist(), 361 | old_loss=np.array(all_old_loss).tolist(), 362 | old_test_acc=old_test_acc, 363 | new_samples=np.concatenate(all_new_samples).tolist(), 364 | new_correct=np.concatenate(all_new_correct).tolist(), 365 | new_loss=np.array(all_new_loss).tolist(), 366 | new_test_acc=new_test_acc, 367 | global_step=i, 368 | ) 369 | ) 370 | + "\n" 371 | ) 372 | 373 | if checkpoint_filepath is not None and i % checkpoint_per_steps == ( 374 | checkpoint_per_steps - 1 375 | ): 376 | write_checkpoint( 377 | marketplace=marketplace, 378 | global_step=i, 379 | output_filepath=pathlib.Path(checkpoint_filepath), 380 | ) 381 | 382 | t.set_description( 383 | f"loss: {old_loss.item():6.2f}/{new_loss.item():6.2f}, rl: {lr.item():.2e}, " 384 | f"acc: {old_accuracy.item():.2f}%/{new_accuracy.item():.2f}%, " 385 | f"vacc: {old_test_acc:.2f}%/{new_test_acc:.2f}%, {gflops:9,.2f} GFLOPS" 386 | ) 387 | if i is not None and checkpoint_filepath is not None: 388 | write_checkpoint( 389 | marketplace=marketplace, 390 | global_step=i, 391 | output_filepath=pathlib.Path(checkpoint_filepath), 392 | ) 393 | 394 | 395 | @click.command() 396 | @click.option("--step-count", type=int, default=10_000, help="How many steps to run") 397 | @click.option("--batch-size", type=int, default=256, help="Size of batch") 398 | @click.option( 399 | "--initial-lr", type=float, default=1e-2, help="Initial learning rate value" 400 | ) 401 | @click.option("--lr-decay", type=float, default=0, help="Learning rate decay rate") 402 | @click.option("--vendor-count", type=int, default=4, help="Vendor count") 403 | @click.option( 404 | "--forward-pass", 405 | type=int, 406 | default=1, 407 | help="How many forward pass to run (simulate distributed computing)", 408 | ) 409 | @click.option("--seed", type=int, help="Set the random seed") 410 | @click.option( 411 | "--probe-scale", 412 | type=float, 413 | default=1, 414 | help="The scale we use to apply on LR for making the reconciled delta direction", 415 | ) 416 | @click.option( 417 | "--input-checkpoint-filepath", 418 | 
type=click.Path(dir_okay=False, readable=True, exists=True), 419 | default="continual-learning-fashion.safetensors", 420 | help="Filepath of checkpoint to read from", 421 | ) 422 | @click.option( 423 | "--checkpoint-filepath", 424 | type=click.Path(dir_okay=False, writable=True), 425 | help="Filepath of checkpoint to write to", 426 | ) 427 | @click.option( 428 | "--checkpoint-per-steps", 429 | type=int, 430 | default=100, 431 | help="For how many steps we should write a checkpoint", 432 | ) 433 | @click.option( 434 | "--replay-file", 435 | type=click.Path(dir_okay=False, writable=True), 436 | help="Filepath of replay JSON file to write to", 437 | ) 438 | @click.option("--run-name", type=str, help="Set the run name") 439 | def main( 440 | step_count: int, 441 | batch_size: int, 442 | initial_lr: float, 443 | lr_decay: float, 444 | vendor_count: int, 445 | forward_pass: int, 446 | seed: int | None, 447 | probe_scale: float | None, 448 | input_checkpoint_filepath: str, 449 | checkpoint_filepath: str, 450 | checkpoint_per_steps: int, 451 | replay_file: str | None, 452 | run_name: str | None, 453 | ): 454 | # ref: https://github.com/tinygrad/tinygrad/issues/8617 455 | # With complex huge compute graph, tinygrad runs into recursion too deep issue, let's bump it up 456 | NEW_RECURSION_LIMIT = 100_000 457 | logger.info("Current recursion limit is %s", sys.getrecursionlimit()) 458 | sys.setrecursionlimit(NEW_RECURSION_LIMIT) 459 | logger.info("Set recursion limit to %s", NEW_RECURSION_LIMIT) 460 | if replay_file is not None: 461 | replay_filepath = pathlib.Path(replay_file) 462 | replay_file_ctx = replay_filepath.open("wt") 463 | else: 464 | replay_file_ctx = nullcontext() 465 | with ( 466 | mlflow.start_run( 467 | experiment_id=ensure_experiment("Continual Learning - Fashion"), 468 | run_name="beautiful-mnist" if run_name is None else run_name, 469 | ), 470 | replay_file_ctx as replay_file, 471 | ): 472 | mlflow.log_param("vendor_count", vendor_count) 473 | learn( 474 | step_count=step_count, 475 | batch_size=batch_size, 476 | initial_lr=initial_lr, 477 | lr_decay_rate=lr_decay, 478 | probe_scale=probe_scale if probe_scale else None, 479 | forward_pass=forward_pass, 480 | manual_seed=seed, 481 | marketplace=make_marketplace( 482 | default_vendor_count=vendor_count, 483 | ), 484 | input_checkpoint_filepath=( 485 | pathlib.Path(input_checkpoint_filepath) 486 | if input_checkpoint_filepath is not None 487 | else None 488 | ), 489 | checkpoint_filepath=( 490 | pathlib.Path(checkpoint_filepath) 491 | if checkpoint_filepath is not None 492 | else None 493 | ), 494 | checkpoint_per_steps=checkpoint_per_steps, 495 | replay_file=replay_file, 496 | ) 497 | 498 | 499 | if __name__ == "__main__": 500 | logging.basicConfig(level=logging.INFO) 501 | main() 502 | -------------------------------------------------------------------------------- /experiments/long_run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import mlflow 4 | 5 | from .beautiful_mnist import make_marketplace 6 | from .beautiful_mnist import train 7 | from .utils import ensure_experiment 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | PYRAMID32_HALF_UPSTREAM_STRUCTURE = [ 12 | # layer 0 13 | (2, 0), 14 | # layer 1 15 | (4, 2), 16 | # layer 2 (N/A) 17 | (0, 0), 18 | # layer 3 19 | (8, 4), 20 | # layer 4 21 | (16, 8), 22 | # layer 5 (N/A) 23 | (0, 0), 24 | # layer 6 25 | (32, 16), 26 | ] 27 | 28 | 29 | def main(): 30 | exp_id = ensure_experiment("Long Run") 31 | with 
mlflow.start_run( 32 | run_name=f"long-run-v2", 33 | experiment_id=exp_id, 34 | description="Find out how learning rate and decay rate affects the training process", 35 | log_system_metrics=True, 36 | tags=dict(round="5"), 37 | ): 38 | marketplace = make_marketplace(PYRAMID32_HALF_UPSTREAM_STRUCTURE) 39 | train( 40 | step_count=100_000, 41 | batch_size=32, 42 | initial_lr=1e-3, 43 | lr_decay_rate=1e-3, 44 | marketplace=marketplace, 45 | ) 46 | 47 | 48 | if __name__ == "__main__": 49 | logging.basicConfig(level=logging.INFO) 50 | main() 51 | -------------------------------------------------------------------------------- /experiments/low_precision.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import mlflow 4 | from tinygrad import Context 5 | 6 | from .beautiful_mnist import train 7 | from .utils import ensure_experiment 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | PYRAMID64_HALF_UPSTREAM_STRUCTURE = [ 13 | # layer 0 14 | (4, 0), 15 | # layer 1 16 | (8, 4), 17 | # layer 2 (N/A) 18 | (0, 0), 19 | # layer 3 20 | (16, 8), 21 | # layer 4 22 | (32, 16), 23 | # layer 5 (N/A) 24 | (0, 0), 25 | # layer 6 26 | (64, 32), 27 | ] 28 | 29 | 30 | def main(): 31 | exp_id = ensure_experiment("Low Precision") 32 | for fp16 in [ 33 | # False, 34 | True, 35 | ]: 36 | with mlflow.start_run( 37 | run_name=f"fp-16-{fp16}", 38 | experiment_id=exp_id, 39 | description="Find out if low precision training make any difference", 40 | log_system_metrics=True, 41 | tags=dict(round="5"), 42 | ): 43 | ctx_values = {} 44 | # XXX: it seems like FLOAT16 can only be passed in by env var with tinygrad? 45 | # if fp16: 46 | # ctx_values["FLOAT16"] = 1 47 | with Context(**ctx_values): 48 | mlflow.log_param("fp16", fp16) 49 | train( 50 | step_count=10_000, 51 | batch_size=32, 52 | initial_lr=1e-3, 53 | lr_decay_rate=4.5e-4, 54 | mp_structure=PYRAMID64_HALF_UPSTREAM_STRUCTURE, 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | logging.basicConfig(level=logging.INFO) 60 | main() 61 | -------------------------------------------------------------------------------- /experiments/lr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import mlflow 4 | 5 | from .beautiful_mnist import make_marketplace 6 | from .beautiful_mnist import train 7 | from .utils import ensure_experiment 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | VENDOR_COUNT = 8 12 | 13 | 14 | def main(): 15 | exp_id = ensure_experiment("Param Attribution LR with Unit Vector Fixed") 16 | for probe_scale in map(lambda x: 0.1 + x * 0.025, range(0, 10, 2)): 17 | for lr in map(lambda x: 0.1 + x * 0.025, range(0, 10, 2)): 18 | for decay in [1e-5]: 19 | probe_str = f"{probe_scale:.1e}" if probe_scale is not None else "none" 20 | with mlflow.start_run( 21 | run_name=f"probe-scale-{probe_str}-lr-{lr:.1e}-decay-{decay:.1e}", 22 | experiment_id=exp_id, 23 | log_system_metrics=True, 24 | ): 25 | marketplace = make_marketplace(default_vendor_count=VENDOR_COUNT) 26 | mlflow.log_param("vendor_count", VENDOR_COUNT) 27 | train( 28 | step_count=1_000, 29 | batch_size=512, 30 | initial_lr=lr, 31 | lr_decay_rate=decay, 32 | probe_scale=probe_scale, 33 | marketplace=marketplace, 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | logging.basicConfig(level=logging.INFO) 39 | main() 40 | -------------------------------------------------------------------------------- /experiments/lr_scale.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import mlflow 4 | 5 | from .beautiful_mnist import make_marketplace 6 | from .beautiful_mnist import train 7 | from .utils import ensure_experiment 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def main(): 13 | exp_id = ensure_experiment("LR Scale") 14 | for lr in [1e-2, 1e-3, 1e-4]: 15 | with mlflow.start_run( 16 | run_name=f"lr-scale-lr-{lr:.1e}-no-scale", 17 | experiment_id=exp_id, 18 | log_system_metrics=True, 19 | ): 20 | marketplace = make_marketplace(default_vendor_count=8) 21 | mlflow.log_param("vendor_count", 8) 22 | train( 23 | step_count=1_000, 24 | batch_size=512, 25 | initial_lr=lr, 26 | lr_decay_rate=1e-4, 27 | marketplace=marketplace, 28 | ) 29 | for lr_scale_start in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]: 30 | for lr_scale_end in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]: 31 | with mlflow.start_run( 32 | run_name=f"lr-scale-lr-{lr:.1e}-scale-{lr_scale_start}-{lr_scale_end}", 33 | experiment_id=exp_id, 34 | log_system_metrics=True, 35 | ): 36 | marketplace = make_marketplace(default_vendor_count=8) 37 | mlflow.log_param("vendor_count", 8) 38 | train( 39 | step_count=1_000, 40 | batch_size=512, 41 | initial_lr=lr, 42 | lr_decay_rate=1e-4, 43 | lr_scaling_range=(lr_scale_start, lr_scale_end), 44 | marketplace=marketplace, 45 | ) 46 | 47 | 48 | if __name__ == "__main__": 49 | logging.basicConfig(level=logging.INFO) 50 | main() 51 | -------------------------------------------------------------------------------- /experiments/lr_vendor16.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import mlflow 4 | 5 | from .beautiful_mnist import make_marketplace 6 | from .beautiful_mnist import train 7 | from .utils import ensure_experiment 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | VENDOR_COUNT = 16 12 | 13 | 14 | def main(): 15 | exp_id = ensure_experiment("Param Attribution LR Vendor 16") 16 | for probe_scale in map(lambda x: 0.1 + x * 0.025, range(10)): 17 | for lr in [0.06, 0.08, 0.2]: 18 | for decay in [1e-5]: 19 | probe_str = f"{probe_scale:.1e}" if probe_scale is not None else "none" 20 | with mlflow.start_run( 21 | run_name=f"probe-scale-{probe_str}-lr-{lr:.1e}-decay-{decay:.1e}", 22 | experiment_id=exp_id, 23 | log_system_metrics=True, 24 | ): 25 | marketplace = make_marketplace(default_vendor_count=VENDOR_COUNT) 26 | mlflow.log_param("vendor_count", VENDOR_COUNT) 27 | train( 28 | step_count=1_000, 29 | batch_size=512, 30 | initial_lr=lr, 31 | lr_decay_rate=decay, 32 | probe_scale=probe_scale, 33 | marketplace=marketplace, 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | logging.basicConfig(level=logging.INFO) 39 | main() 40 | -------------------------------------------------------------------------------- /experiments/market_depth.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import typing 4 | 5 | import mlflow 6 | from tinygrad import Tensor 7 | 8 | from .beautiful_mnist import make_marketplace 9 | from .beautiful_mnist import train 10 | from .utils import ensure_experiment 11 | from marketplace.multi_nn import MultiConv2d 12 | from marketplace.multi_nn import MultiInstanceNorm 13 | from marketplace.multi_nn import MultiLinear 14 | from marketplace.multi_nn import MultiModel 15 | from marketplace.multi_nn import MultiModelBase 16 | from marketplace.training import Spec 17 | 18 | logger = 
logging.getLogger(__name__) 19 | 20 | 21 | def make_marketplace_depth_1( 22 | vendor_count: int, 23 | norm_cls: typing.Type[MultiModelBase] = MultiInstanceNorm, 24 | ): 25 | return [ 26 | Spec( 27 | model=MultiModel( 28 | [ 29 | MultiConv2d(vendor_count, 1, 32, 5), 30 | Tensor.relu, 31 | MultiConv2d(vendor_count, 32, 32, 5), 32 | Tensor.relu, 33 | norm_cls(vendor_count, 32), 34 | Tensor.max_pool2d, 35 | MultiConv2d(vendor_count, 32, 64, 3), 36 | Tensor.relu, 37 | MultiConv2d(vendor_count, 64, 64, 3), 38 | Tensor.relu, 39 | norm_cls(vendor_count, 64), 40 | Tensor.max_pool2d, 41 | lambda x: x.flatten(1), 42 | MultiLinear(vendor_count, 576, 10), 43 | ] 44 | ), 45 | ), 46 | ] 47 | 48 | 49 | def main(): 50 | exp_id = ensure_experiment("Market Depth V2") 51 | for market_depth, vendor_count in [ 52 | (1, 8), 53 | (1, 16), 54 | (1, 32), 55 | (1, 64), 56 | (3, 8), 57 | (3, 16), 58 | ]: 59 | with mlflow.start_run( 60 | run_name=f"market-depth-{market_depth}-vendor-{vendor_count}", 61 | experiment_id=exp_id, 62 | description="Find out how market depth affects performance", 63 | log_system_metrics=True, 64 | ): 65 | if market_depth == 3: 66 | marketplace = make_marketplace(default_vendor_count=vendor_count) 67 | elif market_depth == 1: 68 | marketplace = make_marketplace_depth_1(vendor_count) 69 | else: 70 | raise ValueError(f"Unexpected depth {market_depth}") 71 | train( 72 | step_count=10_000, 73 | batch_size=512, 74 | initial_lr=1e-3, 75 | lr_decay_rate=1e-4, 76 | marketplace=marketplace, 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | logging.basicConfig(level=logging.INFO) 82 | 83 | # Notice: to run this exp, you may need to set a bigger MAX_KERNEL_BUFFERS value 84 | # like export MAX_KERNEL_BUFFERS=100 85 | 86 | # ref: https://github.com/tinygrad/tinygrad/issues/8617 87 | # With complex huge compute graph, tinygrad runs into recursion too deep issue, let's bump it up 88 | NEW_RECURSION_LIMIT = 100_000 89 | logger.info("Current recursion limit is %s", sys.getrecursionlimit()) 90 | logger.info("Set recursion limit to %s", NEW_RECURSION_LIMIT) 91 | sys.setrecursionlimit(NEW_RECURSION_LIMIT) 92 | main() 93 | -------------------------------------------------------------------------------- /experiments/mnist_backprop_cmp.py: -------------------------------------------------------------------------------- 1 | # model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392 2 | import logging 3 | import os 4 | import time 5 | import typing 6 | from typing import Callable 7 | 8 | import mlflow 9 | from tinygrad import dtypes 10 | from tinygrad import GlobalCounters 11 | from tinygrad import nn 12 | from tinygrad import Tensor 13 | from tinygrad import TinyJit 14 | from tinygrad.device import is_dtype_supported 15 | from tinygrad.helpers import colored 16 | from tinygrad.helpers import getenv 17 | from tinygrad.helpers import trange 18 | from tinygrad.nn.datasets import mnist 19 | from tinygrad.nn.state import load_state_dict 20 | from tinygrad.nn.state import safe_load 21 | 22 | from experiments.beautiful_mnist import make_marketplace 23 | from experiments.beautiful_mnist import train 24 | from experiments.utils import ensure_experiment 25 | 26 | DEPTH_3_MODEL_STATE_KEY_MAP = { 27 | "spec.0.layers.0.bias": "layers.0.bias", 28 | "spec.0.layers.0.weight": "layers.0.weight", 29 | "spec.0.layers.2.bias": "layers.2.bias", 30 | "spec.0.layers.2.weight": "layers.2.weight", 31 | "spec.0.layers.4.bias": "layers.4.bias", 32 | "spec.0.layers.4.weight": 
"layers.4.weight", 33 | "spec.0.layers.4.running_mean": "layers.4.running_mean", 34 | "spec.0.layers.4.running_var": "layers.4.running_var", 35 | "spec.1.layers.0.bias": "layers.6.bias", 36 | "spec.1.layers.0.weight": "layers.6.weight", 37 | "spec.1.layers.2.bias": "layers.8.bias", 38 | "spec.1.layers.2.weight": "layers.8.weight", 39 | "spec.1.layers.4.bias": "layers.10.bias", 40 | "spec.1.layers.4.weight": "layers.10.weight", 41 | "spec.1.layers.4.running_mean": "layers.10.running_mean", 42 | "spec.1.layers.4.running_var": "layers.10.running_var", 43 | "spec.2.layers.0.bias": "layers.13.bias", 44 | "spec.2.layers.0.weight": "layers.13.weight", 45 | } 46 | BATCH_NORM_KEYS = ["layers.4", "layers.10"] 47 | logger = logging.getLogger(__name__) 48 | 49 | 50 | class Model: 51 | def __init__(self, norm_cls: typing.Callable = nn.InstanceNorm): 52 | self.layers: list[Callable[[Tensor], Tensor]] = [ 53 | nn.Conv2d(1, 32, 5), 54 | Tensor.relu, 55 | nn.Conv2d(32, 32, 5), 56 | Tensor.relu, 57 | norm_cls(32), 58 | Tensor.max_pool2d, 59 | nn.Conv2d(32, 64, 3), 60 | Tensor.relu, 61 | nn.Conv2d(64, 64, 3), 62 | Tensor.relu, 63 | norm_cls(64), 64 | Tensor.max_pool2d, 65 | lambda x: x.flatten(1), 66 | nn.Linear(576, 10), 67 | ] 68 | 69 | def __call__(self, x: Tensor) -> Tensor: 70 | return x.sequential(self.layers) 71 | 72 | 73 | def train_mnist( 74 | optimizer_type: str = "adam", 75 | lr: float = 1e-3, 76 | batch_size: int = 512, 77 | step_count: int = 1_000, 78 | ): 79 | X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION")) 80 | 81 | model = Model(norm_cls=nn.BatchNorm) 82 | 83 | model_weights_filepath = os.environ.get("MODEL_WEIGHTS") 84 | if model_weights_filepath is not None: 85 | logger.info("Loading model weights from %s", model_weights_filepath) 86 | model_state = safe_load(model_weights_filepath) 87 | converted_state = { 88 | DEPTH_3_MODEL_STATE_KEY_MAP[key]: value 89 | for key, value in model_state.items() 90 | if key in DEPTH_3_MODEL_STATE_KEY_MAP 91 | } 92 | for key in BATCH_NORM_KEYS: 93 | converted_state[f"{key}.num_batches_tracked"] = Tensor.zeros( 94 | 1, 95 | dtype="long" if is_dtype_supported(dtypes.long) else "int", 96 | requires_grad=False, 97 | ) 98 | load_state_dict(model, converted_state) 99 | logger.info("Model weight loaded") 100 | 101 | if optimizer_type == "adam": 102 | opt = nn.optim.Adam(nn.state.get_parameters(model)) 103 | mlflow.log_param("optimizer", "adam") 104 | elif optimizer_type == "muon": 105 | opt = nn.optim.Muon(nn.state.get_parameters(model), lr=lr) 106 | mlflow.log_param("lr", lr) 107 | elif optimizer_type == "sgd": 108 | opt = nn.optim.SGD( 109 | nn.state.get_parameters(model), 110 | lr=lr, 111 | ) 112 | mlflow.log_param("lr", lr) 113 | else: 114 | raise ValueError("Unexpected type") 115 | mlflow.log_param("optimizer", optimizer_type) 116 | 117 | @TinyJit 118 | @Tensor.train() 119 | def train_step() -> Tensor: 120 | opt.zero_grad() 121 | samples = Tensor.randint(batch_size, high=X_train.shape[0]) 122 | loss = ( 123 | model(X_train[samples]) 124 | .sparse_categorical_crossentropy(Y_train[samples]) 125 | .backward() 126 | ) 127 | opt.step() 128 | return loss 129 | 130 | @TinyJit 131 | def get_test_acc() -> Tensor: 132 | return (model(X_test).argmax(axis=1) == Y_test).mean() * 100 133 | 134 | test_acc = float("nan") 135 | for i in (t := trange(step_count)): 136 | GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing 137 | loss = train_step() 138 | start_time = time.perf_counter() 139 | run_time = time.perf_counter() - start_time 140 | 
gflops = GlobalCounters.global_ops * 1e-9 / run_time 141 | if i % 10 == 9: 142 | test_acc = get_test_acc().item() 143 | mlflow.log_metric("training/loss", loss.item(), step=i) 144 | mlflow.log_metric("training/gflops", gflops, step=i) 145 | mlflow.log_metric("testing/accuracy", test_acc, step=i) 146 | t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%") 147 | 148 | # verify eval acc 149 | if target := getenv("TARGET_EVAL_ACC_PCT", 0.0): 150 | if test_acc >= target and test_acc != 100.0: 151 | print(colored(f"{test_acc=} >= {target}", "green")) 152 | else: 153 | raise ValueError(colored(f"{test_acc=} < {target}", "red")) 154 | 155 | 156 | if __name__ == "__main__": 157 | step_count = 1_000 158 | exp_id = ensure_experiment("Backprop Comparison V4") 159 | for optimizer_type in ["adam", "muon", "sgd"]: 160 | if optimizer_type == "adam": 161 | with mlflow.start_run( 162 | run_name=f"backprop-adam", 163 | experiment_id=exp_id, 164 | log_system_metrics=True, 165 | ): 166 | train_mnist(optimizer_type=optimizer_type, step_count=step_count) 167 | else: 168 | for lr_base in [1e-2, 1e-3, 1e-4]: 169 | for lr in list(map(lambda x: x * 1e-3, range(1, 10))): 170 | with mlflow.start_run( 171 | run_name=f"backprop-{optimizer_type}-lr-{lr}", 172 | experiment_id=exp_id, 173 | log_system_metrics=True, 174 | ): 175 | train_mnist( 176 | optimizer_type=optimizer_type, lr=lr, step_count=step_count 177 | ) 178 | with mlflow.start_run( 179 | run_name="marketplace-v2", 180 | experiment_id=exp_id, 181 | log_system_metrics=True, 182 | ): 183 | marketplace = make_marketplace(default_vendor_count=16) 184 | train( 185 | step_count=step_count, 186 | batch_size=512, 187 | initial_lr=1e-1, 188 | lr_decay_rate=1e-5, 189 | probe_scale=1e-1, 190 | marketplace=marketplace, 191 | manual_seed=42, 192 | ) 193 | -------------------------------------------------------------------------------- /experiments/resnet18.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import logging 4 | import pathlib 5 | import random 6 | import time 7 | import typing 8 | 9 | import mlflow 10 | import numpy as np 11 | from PIL import Image 12 | from tinygrad import nn 13 | from tinygrad import Tensor 14 | from tinygrad import TinyJit 15 | from tinygrad.helpers import GlobalCounters 16 | from tinygrad.helpers import trange 17 | from tinyloader.loader import load 18 | from tinyloader.loader import load_with_workers 19 | from tinyloader.loader import Loader 20 | 21 | from .utils import ensure_experiment 22 | from marketplace.multi_nn import MultiConv2d 23 | from marketplace.multi_nn import MultiInstanceNorm 24 | from marketplace.multi_nn import MultiLinear 25 | from marketplace.multi_nn import MultiModel 26 | from marketplace.multi_nn import MultiModelBase 27 | from marketplace.training import forward 28 | from marketplace.training import forward_with_path 29 | from marketplace.training import mutate 30 | from marketplace.training import Spec 31 | 32 | 33 | def get_train_files(basedir: pathlib.Path) -> list[str]: 34 | if not (files := glob.glob(p := str(basedir / "train/*/*"))): 35 | raise FileNotFoundError(f"No training files in {p}") 36 | return files 37 | 38 | 39 | def get_val_files(basedir: pathlib.Path) -> list[str]: 40 | if not (files := glob.glob(p := str(basedir / "val/*/*"))): 41 | raise FileNotFoundError(f"No training files in {p}") 42 | return files 43 | 44 | 45 | def get_imagenet_categories(basedir: pathlib.Path) -> dict[str, int]: 46 | ci = 
json.load(open(basedir / "imagenet_class_index.json")) 47 | return {v[0]: int(k) for k, v in ci.items()} 48 | 49 | 50 | def center_crop(img: Image) -> Image: 51 | rescale = min(img.size) / 256 52 | crop_left = (img.width - 224 * rescale) / 2.0 53 | crop_top = (img.height - 224 * rescale) / 2.0 54 | img = img.resize( 55 | (224, 224), 56 | Image.BILINEAR, 57 | box=(crop_left, crop_top, crop_left + 224 * rescale, crop_top + 224 * rescale), 58 | ) 59 | return img 60 | 61 | 62 | class ImageLoader(Loader): 63 | def __init__(self, img_categories: dict[str, int]): 64 | super().__init__() 65 | self.img_categories = img_categories 66 | 67 | def make_request(self, item: pathlib.Path) -> typing.Any: 68 | return item 69 | 70 | def load(self, request: pathlib.Path) -> tuple[np.typing.NDArray, ...]: 71 | x = Image.open(request) 72 | x = center_crop(x) 73 | x = np.transpose(np.asarray(x), (2, 0, 1)) 74 | y = self.img_categories[request.parts[-2]] 75 | return x, np.array(y) 76 | 77 | def post_process( 78 | self, response: tuple[np.typing.NDArray, ...] 79 | ) -> tuple[Tensor, ...]: 80 | x, y = response 81 | x = Tensor(x.copy()).contiguous().realize() 82 | y = Tensor(y).realize() 83 | return x, y 84 | 85 | 86 | class BasicBlock(MultiModelBase): 87 | def __init__( 88 | self, vendor_count: int, in_channels: int, out_channels: int, stride: int = 1 89 | ): 90 | super().__init__() 91 | self.vendor_count = vendor_count 92 | self.conv1 = MultiConv2d( 93 | vendor_count, 94 | in_channels, 95 | out_channels, 96 | kernel_size=3, 97 | stride=stride, 98 | padding=1, 99 | bias=False, 100 | ) 101 | self.bn1 = MultiInstanceNorm(vendor_count, out_channels) 102 | self.conv2 = MultiConv2d( 103 | vendor_count, 104 | out_channels, 105 | out_channels, 106 | kernel_size=3, 107 | stride=1, 108 | padding=1, 109 | bias=False, 110 | ) 111 | self.bn2 = MultiInstanceNorm(vendor_count, out_channels) 112 | 113 | self.downsample = lambda i, x: x 114 | if stride != 1 or in_channels != out_channels: 115 | self.downsample = MultiModel( 116 | [ 117 | MultiConv2d( 118 | vendor_count, 119 | in_channels, 120 | out_channels, 121 | kernel_size=1, 122 | stride=stride, 123 | bias=False, 124 | ), 125 | MultiInstanceNorm(vendor_count, out_channels), 126 | ] 127 | ) 128 | 129 | def __call__(self, i: Tensor, x: Tensor) -> Tensor: 130 | out = self.conv1(i, x) 131 | out = self.bn1(i, out) 132 | out = out.relu() 133 | out = self.conv2(i, out) 134 | out = self.bn2(i, out) 135 | out += self.downsample(i, x) 136 | out = out.relu() 137 | return out 138 | 139 | 140 | def make_marketplace(num_classes: int = 100, default_vendor_count: int = 4): 141 | layer0_vendor_count = default_vendor_count 142 | layer1_upstream_sampling = 0 143 | layer1_vendor_count = default_vendor_count 144 | layer2_upstream_sampling = 0 145 | layer2_vendor_count = default_vendor_count 146 | return [ 147 | Spec( 148 | model=MultiModel( 149 | [ 150 | MultiConv2d( 151 | layer0_vendor_count, 152 | in_channels=3, 153 | out_channels=64, 154 | kernel_size=7, 155 | stride=2, 156 | padding=3, 157 | bias=False, 158 | ), 159 | MultiInstanceNorm(layer0_vendor_count, 64), 160 | Tensor.relu, 161 | lambda x: x.max_pool2d( 162 | kernel_size=3, 163 | stride=2, 164 | padding=1, 165 | ), 166 | BasicBlock( 167 | layer0_vendor_count, 168 | in_channels=64, 169 | out_channels=64, 170 | stride=1, 171 | ), 172 | BasicBlock( 173 | layer0_vendor_count, 174 | in_channels=64, 175 | out_channels=64, 176 | stride=1, 177 | ), 178 | ] 179 | ), 180 | upstream_sampling=2, 181 | ), 182 | # layer1 183 | Spec( 184 | 
model=MultiModel( 185 | [ 186 | BasicBlock( 187 | layer1_vendor_count, 188 | in_channels=64, 189 | out_channels=128, 190 | stride=2, 191 | ), 192 | BasicBlock( 193 | layer1_vendor_count, 194 | in_channels=128, 195 | out_channels=128, 196 | stride=1, 197 | ), 198 | BasicBlock( 199 | layer1_vendor_count, 200 | in_channels=128, 201 | out_channels=256, 202 | stride=2, 203 | ), 204 | BasicBlock( 205 | layer1_vendor_count, 206 | in_channels=256, 207 | out_channels=256, 208 | stride=1, 209 | ), 210 | ] 211 | ), 212 | upstream_sampling=layer1_upstream_sampling, 213 | ), 214 | # layer4 215 | Spec( 216 | model=MultiModel( 217 | [ 218 | BasicBlock( 219 | layer2_vendor_count, 220 | in_channels=256, 221 | out_channels=512, 222 | stride=2, 223 | ), 224 | BasicBlock( 225 | layer2_vendor_count, 226 | in_channels=512, 227 | out_channels=512, 228 | stride=1, 229 | ), 230 | lambda x: x.avg_pool2d(kernel_size=7), 231 | lambda x: x.flatten(1), 232 | MultiLinear(layer2_vendor_count, 512, num_classes), 233 | ] 234 | ), 235 | upstream_sampling=layer2_upstream_sampling, 236 | ), 237 | ] 238 | 239 | 240 | def train( 241 | dataset_dir: pathlib.Path, 242 | marketplace: list[Spec], 243 | step_count: int = 500_000, 244 | batch_size: int = 64, 245 | num_workers: int = 8, 246 | initial_lr: float = 1e-3, 247 | lr_decay_rate: float = 4.5e-4, 248 | ): 249 | train_files = get_train_files(dataset_dir) 250 | val_files = get_val_files(dataset_dir) 251 | img_categories = get_imagenet_categories(dataset_dir) 252 | loader = ImageLoader(img_categories=img_categories) 253 | 254 | lr = Tensor(initial_lr) 255 | 256 | @TinyJit 257 | @MultiModelBase.learn() 258 | def forward_step(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]: 259 | batch_logits, batch_paths = forward(marketplace, x) 260 | return Tensor.stack( 261 | *(logits.sparse_categorical_crossentropy(y) for logits in batch_logits), 262 | dim=0, 263 | ).realize(), batch_paths.realize() 264 | 265 | @TinyJit 266 | def mutate_step( 267 | combined_loss: Tensor, combined_paths: Tensor 268 | ) -> tuple[Tensor, Tensor]: 269 | min_loss, min_loss_index = combined_loss.topk(1, largest=False) 270 | min_path = combined_paths[min_loss_index].flatten() 271 | mutate( 272 | marketplace=marketplace, 273 | leading_path=min_path, 274 | jitter=lr, 275 | ) 276 | return min_loss.realize(), min_path.realize() 277 | 278 | @TinyJit 279 | def get_test_acc(path: Tensor, x: Tensor, y: Tensor) -> Tensor: 280 | return ( 281 | forward_with_path(marketplace, x, path).argmax(axis=1) == y 282 | ).mean() * 100 283 | 284 | # with load_with_workers( 285 | # loader, 286 | # list(map(pathlib.Path, train_files)), 287 | # num_worker=num_workers, 288 | # shared_memory_enabled=True, 289 | # ) as generator: 290 | test_acc = float("nan") 291 | current_forward_pass = 1 292 | 293 | shuffled_train_files = list(map(pathlib.Path, train_files)) 294 | random.shuffle(shuffled_train_files) 295 | 296 | shuffled_test_files = list(map(pathlib.Path, val_files)) 297 | random.shuffle(shuffled_test_files) 298 | 299 | consumed_count = 0 300 | generator = load(loader, shuffled_train_files) 301 | for i in (t := trange(step_count)): 302 | GlobalCounters.reset() 303 | 304 | start_time = time.perf_counter() 305 | 306 | all_loss = [] 307 | all_paths = [] 308 | for _ in range(current_forward_pass): 309 | x_batch = [] 310 | y_batch = [] 311 | for _ in range(batch_size): 312 | x, y = next(generator) 313 | x_batch.append(x) 314 | y_batch.append(y) 315 | x = Tensor.stack(x_batch, dim=0).realize() 316 | y = Tensor.stack(y_batch, dim=0).realize() 317 | 
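# Each entry of batch_loss pairs with one row of batch_path (a sampled vendor path through the
# marketplace); the losses from every forward pass are concatenated below, and mutate_step keeps
# the lowest-loss path as the leading_path when mutating the marketplace.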
318 | batch_loss, batch_path = forward_step(x, y) 319 | all_loss.append(batch_loss) 320 | all_paths.append(batch_path) 321 | consumed_count += batch_size * current_forward_pass 322 | if len(shuffled_train_files) - consumed_count < ( 323 | batch_size * current_forward_pass 324 | ): 325 | random.shuffle(shuffled_test_files) 326 | generator = load(loader, shuffled_train_files) 327 | consumed_count = 0 328 | print("Out of training data, reload") 329 | 330 | combined_loss = Tensor.cat(*all_loss).realize() 331 | combined_paths = Tensor.cat(*all_paths).realize() 332 | 333 | loss, path = mutate_step( 334 | combined_loss=combined_loss, combined_paths=combined_paths 335 | ) 336 | 337 | end_time = time.perf_counter() 338 | run_time = end_time - start_time 339 | lr.replace(lr * (1 - lr_decay_rate)) 340 | gflops = GlobalCounters.global_ops * 1e-9 / run_time 341 | 342 | if i % 10 == (10 - 1): 343 | # TODO: optimize this 344 | test_generator = load(loader, shuffled_test_files) 345 | 346 | x_batch = [] 347 | y_batch = [] 348 | # XXX: well, this is not great, but let's some quick hack to make it works 349 | for _ in range(batch_size * 16): 350 | x, y = next(test_generator) 351 | x_batch.append(x) 352 | y_batch.append(y) 353 | 354 | x = Tensor.stack(x_batch, dim=0).realize() 355 | y = Tensor.stack(y_batch, dim=0).realize() 356 | test_acc = get_test_acc(path, x, y).item() 357 | 358 | mlflow.log_metric("training/loss", loss.item(), step=i) 359 | mlflow.log_metric("training/forward_pass", current_forward_pass, step=i) 360 | mlflow.log_metric("training/lr", lr.item(), step=i) 361 | mlflow.log_metric("training/gflops", gflops, step=i) 362 | mlflow.log_metric("testing/accuracy", test_acc, step=i) 363 | 364 | t.set_description( 365 | f"loss: {loss.item():6.2f}, fw: {current_forward_pass}, rl: {lr.item():e}, " 366 | f"acc: {test_acc:5.2f}%, {gflops:9,.2f} GFLOPS" 367 | ) 368 | 369 | 370 | def main(): 371 | exp_id = ensure_experiment("ResNet18") 372 | with mlflow.start_run(experiment_id=exp_id, run_name="resnet18"): 373 | mlflow.log_param("vendor_count", 4) 374 | mlflow.log_param("num_classes", 10) 375 | mlflow.log_param("dataset", "mnist") 376 | train( 377 | dataset_dir=pathlib.Path("mnist"), 378 | marketplace=make_marketplace(num_classes=10), 379 | ) 380 | 381 | 382 | if __name__ == "__main__": 383 | logging.basicConfig(level=logging.INFO) 384 | main() 385 | -------------------------------------------------------------------------------- /experiments/scaling.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | import mlflow 5 | 6 | from .beautiful_mnist import make_marketplace 7 | from .beautiful_mnist import train 8 | from .utils import ensure_experiment 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def main(): 14 | exp_id = ensure_experiment("Scaling V5 with LR scaling") 15 | for marketplace_replica in [64]: 16 | for forward_pass in [64]: 17 | with mlflow.start_run( 18 | run_name=f"scaling-mr-{marketplace_replica}-fw-{forward_pass}", 19 | experiment_id=exp_id, 20 | log_system_metrics=True, 21 | ): 22 | marketplace = make_marketplace(default_vendor_count=8) 23 | train( 24 | step_count=500, 25 | batch_size=512, 26 | initial_lr=1e-2, 27 | lr_decay_rate=1e-4, 28 | initial_forward_pass=forward_pass, 29 | lr_scaling_range=0.1, 30 | marketplace=marketplace, 31 | marketplace_replica=marketplace_replica, 32 | # Make initial weights the same so that the exp is less noisy 33 | manual_seed=42, 34 | ) 35 | 36 | 37 | if __name__ == 
"__main__": 38 | logging.basicConfig(level=logging.INFO) 39 | 40 | # ref: https://github.com/tinygrad/tinygrad/issues/8617 41 | # With complex huge compute graph, tinygrad runs into recursion too deep issue, let's bump it up 42 | NEW_RECURSION_LIMIT = 100_000 43 | logger.info("Current recursion limit is %s", sys.getrecursionlimit()) 44 | sys.setrecursionlimit(NEW_RECURSION_LIMIT) 45 | logger.info("Set recursion limit to %s", NEW_RECURSION_LIMIT) 46 | 47 | main() 48 | -------------------------------------------------------------------------------- /experiments/unit_vector_mode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import mlflow 4 | 5 | from .beautiful_mnist import make_marketplace 6 | from .beautiful_mnist import train 7 | from .utils import ensure_experiment 8 | from marketplace.optimizers import UnitVectorMode 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | VENDOR_COUNT = 8 13 | 14 | 15 | def main(): 16 | exp_id = ensure_experiment("Unit Vector Mode") 17 | for lr in [0.09, 0.1, 0.2, 0.3]: 18 | for mode in [UnitVectorMode.per_spec, UnitVectorMode.whole]: 19 | with mlflow.start_run( 20 | run_name=f"lr-{lr}-{mode.value}-scale-1e-3", 21 | experiment_id=exp_id, 22 | log_system_metrics=True, 23 | ): 24 | marketplace = make_marketplace(default_vendor_count=VENDOR_COUNT) 25 | mlflow.log_param("vendor_count", VENDOR_COUNT) 26 | train( 27 | step_count=1_000, 28 | batch_size=512, 29 | initial_lr=lr, 30 | lr_decay_rate=1e-5, 31 | probe_scale=0.001, 32 | unit_vector_mode=mode, 33 | marketplace=marketplace, 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | logging.basicConfig(level=logging.INFO) 39 | main() 40 | -------------------------------------------------------------------------------- /experiments/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | 4 | import mlflow 5 | import numpy as np 6 | from tinygrad import dtypes 7 | from tinygrad import Tensor 8 | from tinygrad.tensor import ReductionStr 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def ensure_experiment(name: str) -> str: 14 | try: 15 | experiment_id = mlflow.create_experiment( 16 | name=name, 17 | ) 18 | logger.info("Created experiment with name %s and id %s", name, experiment_id) 19 | return experiment_id 20 | except mlflow.exceptions.MlflowException as e: 21 | logger.info("Failed to create experiment with error: %s", e) 22 | # If experiment already exists, get its ID 23 | experiment = mlflow.get_experiment_by_name(name) 24 | experiment_id = experiment.experiment_id 25 | logger.info("Return existing experiment id %s for %s", experiment_id, name) 26 | return experiment_id 27 | 28 | 29 | def filter_classes( 30 | x: Tensor, y: Tensor, only: typing.Container 31 | ) -> tuple[Tensor, Tensor]: 32 | class_mask = np.isin(y.numpy(), only) 33 | return Tensor(x.numpy()[class_mask]), Tensor(y.numpy()[class_mask]) 34 | -------------------------------------------------------------------------------- /marketplace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaunchPlatform/marketplace/6487c3ea4d5c2df208bc9ce10a627cef5e8eeace/marketplace/__init__.py -------------------------------------------------------------------------------- /marketplace/continual_learning.py: -------------------------------------------------------------------------------- 1 | from tinygrad import Tensor 2 | 3 | from 
.optimizers import CachedDeltaVendor 4 | from .training import Spec 5 | 6 | 7 | def forward_with_paths( 8 | marketplace: list[Spec], 9 | x: Tensor, 10 | paths: Tensor, 11 | deltas: list[Tensor], 12 | ) -> Tensor: 13 | output = [] 14 | # TODO: this is extremely slow for Tinygrad JIT compiler, should find a better way to do it instead 15 | for xi, path in zip(x, paths): 16 | data = xi.unsqueeze(0) 17 | for spec, delta, idx in zip(marketplace, deltas, path): 18 | vendor = CachedDeltaVendor( 19 | model=spec.model, 20 | delta={key: params[idx] for key, params in delta.items()}, 21 | ) 22 | data = vendor(data) 23 | output.append(data[0]) 24 | return Tensor.stack(*output, dim=0) 25 | -------------------------------------------------------------------------------- /marketplace/nn.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from tinygrad import Tensor 4 | 5 | 6 | class Model: 7 | def __init__(self, *layers): 8 | self.layers: tuple[typing.Callable, ...] = layers 9 | 10 | def __call__(self, x: Tensor) -> Tensor: 11 | value = x 12 | for model in self.layers: 13 | value = model(value) 14 | return value 15 | -------------------------------------------------------------------------------- /marketplace/optimizers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import dataclasses 3 | import enum 4 | import typing 5 | 6 | from tinygrad import dtypes 7 | from tinygrad import Tensor 8 | from tinygrad.nn.state import get_parameters 9 | from tinygrad.nn.state import get_state_dict 10 | from tinygrad.nn.state import load_state_dict 11 | 12 | from .random import counter_advance_for 13 | from .random import RandomNumberGenerator 14 | from .training import Spec 15 | 16 | SEED_MAX = 2**64 17 | 18 | 19 | class UnitVectorMode(enum.Enum): 20 | whole = "whole" 21 | per_spec = "per_spec" 22 | 23 | 24 | @dataclasses.dataclass 25 | class SpecContext: 26 | seeds: Tensor 27 | delta: dict[str, Tensor] 28 | learning_rate: Tensor 29 | learning_rate_scales: Tensor | None = None 30 | 31 | 32 | class CachedDeltaVendor: 33 | def __init__(self, model: typing.Callable, delta: dict[str, Tensor]): 34 | self.model = model 35 | self.delta = delta 36 | 37 | def __call__(self, *args, **kwargs): 38 | vendored_model = copy.deepcopy(self.model) 39 | load_state_dict( 40 | vendored_model, 41 | state_dict={ 42 | key: param + self.delta[key] 43 | for key, param in get_state_dict(vendored_model).items() 44 | }, 45 | verbose=False, 46 | realize=False, 47 | ) 48 | return vendored_model(*args, **kwargs) 49 | 50 | 51 | class Optimizer: 52 | def __init__( 53 | self, 54 | marketplace: list[Spec], 55 | learning_rate: Tensor, 56 | # learning_rate_scale_range: Tensor | None = None, 57 | meta_learning_rate: Tensor | None = None, 58 | seeds: list[Tensor] | None = None, 59 | probe_scale: Tensor | None = None, 60 | make_rng: typing.Type[RandomNumberGenerator] = RandomNumberGenerator, 61 | ): 62 | self.marketplace = marketplace 63 | self.learning_rate = learning_rate 64 | self.meta_learning_rate = meta_learning_rate 65 | self.make_rng = make_rng 66 | self.probe_scale = probe_scale 67 | 68 | if seeds is not None: 69 | market_shape = tuple(spec.vendor_count for spec in marketplace) 70 | seeds_shape = tuple(len(vendor_seeds) for vendor_seeds in seeds) 71 | if seeds_shape != market_shape: 72 | raise ValueError( 73 | f"Provided seeds should the same shape {market_shape} as the depth of market but got {seeds_shape}" 74 | ) 75 | else: 76 | seeds 
= [ 77 | Tensor.cat( 78 | Tensor.zeros(1, dtype=dtypes.uint64), 79 | Tensor.randint( 80 | spec.vendor_count - 1, low=1, high=SEED_MAX, dtype=dtypes.uint64 81 | ), 82 | ) 83 | for spec in self.marketplace 84 | ] 85 | 86 | self.spec_context: list[SpecContext] = [ 87 | SpecContext( 88 | seeds=seeds[i].contiguous(), 89 | # allocate memory for delta 90 | delta={ 91 | key: Tensor.empty( 92 | spec.vendor_count, *params.shape, dtype=params.dtype 93 | ).contiguous() 94 | for key, params in get_state_dict(spec.model).items() 95 | }, 96 | learning_rate=( 97 | self.learning_rate.clone().contiguous() 98 | if self.meta_learning_rate is not None 99 | else self.learning_rate 100 | ), 101 | learning_rate_scales=( 102 | Tensor.zeros(spec.vendor_count).contiguous() 103 | if self.meta_learning_rate is not None 104 | else None 105 | ), 106 | ) 107 | for i, spec in enumerate(self.marketplace) 108 | ] 109 | 110 | Tensor.realize( 111 | *( 112 | # We need to realize all the parameters so that they are buffer instead of compute graph, otherwise the 113 | # update weights assign operation won't work. 114 | # ref: https://x.com/fangpenlin/status/1959405151455969607 115 | [ 116 | param.assign(param.contiguous()) 117 | for spec in self.marketplace 118 | for param in get_parameters(spec.model) 119 | ] 120 | # also realize seeds so that they are buffer 121 | + [ctx.seeds for ctx in self.spec_context] 122 | ) 123 | ) 124 | # Realize the delta, making them buffers 125 | Tensor.realize(*self.schedule_delta_update()) 126 | self.vendors = [ 127 | [ 128 | CachedDeltaVendor( 129 | model=spec.model, 130 | delta={key: params[i] for key, params in ctx.delta.items()}, 131 | ) 132 | for i in range(spec.vendor_count) 133 | ] 134 | for spec, ctx in zip(self.marketplace, self.spec_context) 135 | ] 136 | 137 | def get_seeds(self, path: Tensor) -> Tensor: 138 | return Tensor.cat( 139 | *( 140 | ctx.seeds[index].unsqueeze(0) 141 | for index, ctx in zip(path, self.spec_context) 142 | ), 143 | dim=0, 144 | ) 145 | 146 | def get_learning_rates(self, path: Tensor) -> Tensor: 147 | return Tensor.stack( 148 | *( 149 | ctx.learning_rate * (1 + ctx.learning_rate_scales[index]) 150 | for index, ctx in zip(path, self.spec_context) 151 | ), 152 | dim=0, 153 | ) 154 | 155 | def step( 156 | self, 157 | seeds: Tensor, 158 | learning_rates: Tensor | None = None, 159 | keep_leader: bool = True, 160 | ): 161 | Tensor.realize( 162 | *self.schedule_step( 163 | seeds, learning_rates=learning_rates, keep_leader=keep_leader 164 | ) 165 | ) 166 | 167 | def schedule_step( 168 | self, 169 | seeds: Tensor, 170 | learning_rates: Tensor | None = None, 171 | keep_leader: bool = True, 172 | ) -> list[Tensor]: 173 | return ( 174 | self.schedule_weight_update(seeds, learning_rates=learning_rates) 175 | + self.schedule_seeds_update(keep_leader) 176 | + self.schedule_delta_update() 177 | ) 178 | 179 | def schedule_weight_update( 180 | self, 181 | direction_delta: list[dict[str, Tensor]], 182 | learning_rates: Tensor | None = None, 183 | ) -> list[Tensor]: 184 | weight_updates = [] 185 | if learning_rates is None: 186 | learning_rates = self.learning_rate.expand(len(self.marketplace)) 187 | for spec, ctx, delta, lr in zip( 188 | self.marketplace, self.spec_context, direction_delta, learning_rates 189 | ): 190 | model_params = get_state_dict(spec.model) 191 | keys = sorted(list(model_params.keys())) 192 | effective_lr = ctx.learning_rate 193 | if self.meta_learning_rate is not None: 194 | weight_updates.append(ctx.learning_rate.assign(lr)) 195 | effective_lr = lr 196 | for 
key in keys: 197 | params = model_params[key] 198 | weight_updates.append(params.assign(params + delta[key] * effective_lr)) 199 | return weight_updates 200 | 201 | def schedule_seeds_update(self, keep_leader: bool = True): 202 | return [ 203 | ctx.seeds.assign( 204 | Tensor.cat( 205 | Tensor.zeros(1, dtype=dtypes.uint64), 206 | Tensor.randint( 207 | len(ctx.seeds) - 1, low=1, high=SEED_MAX, dtype=dtypes.uint64 208 | ), 209 | ) 210 | if keep_leader 211 | else Tensor.randint( 212 | *ctx.seeds.shape, low=1, high=SEED_MAX, dtype=dtypes.uint64 213 | ) 214 | ) 215 | for ctx in self.spec_context 216 | ] 217 | 218 | def schedule_delta_update(self) -> list[Tensor]: 219 | delta_updates = [] 220 | for ctx in self.spec_context: 221 | counter = 0 222 | keys = sorted(list(ctx.delta.keys())) 223 | for key in keys: 224 | params = ctx.delta[key] 225 | updated_params = Tensor.stack( 226 | *( 227 | self.make_delta( 228 | seed=seed, 229 | lr=( 230 | ctx.learning_rate 231 | if self.probe_scale is None 232 | else (ctx.learning_rate * self.probe_scale) 233 | ), 234 | counter=Tensor(counter, dtype=dtypes.uint), 235 | params=params[i], 236 | ) 237 | for i, seed in enumerate(ctx.seeds) 238 | ), 239 | dim=0, 240 | ) 241 | counter += counter_advance_for(params[0]) 242 | delta_updates.append(params.assign(updated_params)) 243 | return delta_updates 244 | 245 | def schedule_lr_scale_update(self, direction_vectors: Tensor) -> list[Tensor]: 246 | if self.meta_learning_rate is None: 247 | raise ValueError("LR scale not set") 248 | lr_updates = [] 249 | for ctx, vector in zip(self.spec_context, direction_vectors): 250 | # We use the final counter (after all params) for generating the lr delta. 251 | # TODO: extract this part? 252 | final_counter = 0 253 | keys = sorted(list(ctx.delta.keys())) 254 | for key in keys: 255 | params = ctx.delta[key] 256 | final_counter += counter_advance_for(params[0]) 257 | # Generate different LR to try out 258 | lr_updates.append( 259 | ctx.learning_rate_scales.assign( 260 | Tensor.stack( 261 | *( 262 | ( 263 | self.make_delta( 264 | seed=seed, 265 | counter=Tensor(final_counter, dtype=dtypes.uint), 266 | lr=self.meta_learning_rate, 267 | params=lr_scale, 268 | ) 269 | if i != 0 270 | # we always keep the original lr in the combinations, in case we cannot find any 271 | # improvement from scale, at least we are not making regression 272 | else Tensor.zeros_like(lr_scale) 273 | ) 274 | for i, (lr_scale, seed) in enumerate( 275 | zip(ctx.learning_rate_scales, ctx.seeds) 276 | ) 277 | ), 278 | dim=0, 279 | ) 280 | ) 281 | ) 282 | for key in keys: 283 | params = ctx.delta[key] 284 | updated_params = Tensor.stack( 285 | *( 286 | vector[key] * (ctx.learning_rate * (1 + lr_scale)) 287 | for i, lr_scale in enumerate(ctx.learning_rate_scales) 288 | ), 289 | dim=0, 290 | ) 291 | lr_updates.append(params.assign(updated_params)) 292 | return lr_updates 293 | 294 | def compute_direction_vectors( 295 | self, 296 | loss: Tensor, 297 | paths: Tensor, 298 | unit_vector_mode: UnitVectorMode = UnitVectorMode.per_spec, 299 | ) -> list[dict[str, Tensor]]: 300 | std, mean = loss.std_mean() 301 | std_loss = -((loss - mean) / std) 302 | reconciled_deltas = [] 303 | direction_vectors = [] 304 | vector_square_sum = [] 305 | for i, (spec, ctx) in enumerate(zip(self.marketplace, self.spec_context)): 306 | model_params = get_state_dict(spec.model) 307 | keys = sorted(list(model_params.keys())) 308 | counter = 0 309 | indexes = paths[:, i] 310 | reconciled_delta = {} 311 | for key in keys: 312 | reconciled_delta[key] = 
( 313 | # Take all the deltas and multiply them by their corresponding normalized loss, so that we can "reward" each
314 | # parameter in the delta accordingly and compose an overall better direction.
315 | ctx.delta[key][indexes]
316 | * std_loss.reshape(
317 | len(std_loss), *((1,) * len(model_params[key].shape))
318 | )
319 | ).sum(axis=0)
320 | counter += counter_advance_for(model_params[key])
321 | reconciled_deltas.append(reconciled_delta)
322 | # We treat all the parameter deltas in this spec as a single vector
323 | combined_vector = Tensor.cat(
324 | *[delta.flatten() for delta in reconciled_delta.values()]
325 | )
326 | if unit_vector_mode == UnitVectorMode.whole:
327 | # add up the squares of the vector's elements
328 | vector_square_sum.append(combined_vector.square().sum().unsqueeze(0))
329 | elif unit_vector_mode == UnitVectorMode.per_spec:
330 | vector_len = combined_vector.square().sum().sqrt()
331 | direction_vectors.append(
332 | {key: delta / vector_len for key, delta in reconciled_delta.items()}
333 | )
334 | else:
335 | raise ValueError(f"Unexpected unit vector mode {unit_vector_mode}")
336 | if unit_vector_mode == UnitVectorMode.per_spec:
337 | return direction_vectors
338 | vector_len = Tensor.cat(*vector_square_sum).sum().sqrt()
339 | return [
340 | # make them unit vectors
341 | {key: delta / vector_len for key, delta in reconciled_delta.items()}
342 | for reconciled_delta in reconciled_deltas
343 | ]
344 |
345 | def make_delta(
346 | self, seed: Tensor, lr: Tensor, counter: Tensor, params: Tensor
347 | ) -> Tensor:
348 | return (seed != 0).where(
349 | self.make_rng(seed=seed, counter=counter).uniform_like(
350 | params,
351 | low=-lr,
352 | high=lr,
353 | ),
354 | Tensor.zeros_like(params),
355 | )
356 | -------------------------------------------------------------------------------- /marketplace/random.py: -------------------------------------------------------------------------------- 1 | from tinygrad import Device
2 | from tinygrad import Tensor
3 | from tinygrad import UOp
4 | from tinygrad.dtype import DTypeLike
5 | from tinygrad.dtype import dtypes
6 | from tinygrad.dtype import to_dtype
7 | from tinygrad.helpers import all_int
8 | from tinygrad.helpers import argfix
9 | from tinygrad.helpers import ceildiv
10 | from tinygrad.helpers import prod
11 |
12 |
13 | # The original version takes two uint32 values as the key, but we want to use one uint64 key. To make the
14 | # compute graph simpler, let's change it a bit to use uint64 directly
15 | # ref: https://github.com/tinygrad/tinygrad/blob/b057a90d493664d37558eb6c5447bc5bd5c15009/tinygrad/tensor.py#L496-L500
16 | def _threefry_random_bits(key: Tensor, counts0: Tensor, counts1: Tensor) -> Tensor:
17 | x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
18 | x = x._apply_uop(UOp.threefry, key._broadcast_to(x.shape))
19 | counts0, counts1 = (
20 | (x & 0xFFFFFFFF).cast(dtypes.uint32),
21 | ((x >> 32) & 0xFFFFFFFF).cast(dtypes.uint32),
22 | )
23 | return counts0.cat(counts1)
24 |
25 |
26 | def counter_advance(*shape, dtype: DTypeLike | None = None) -> int:
27 | """Calculate counter advance for a given shape
28 |
29 | :param shape: Shape of tensor
30 | :param dtype: dtype of tensor
31 | :return: the amount the counter advances when generating random numbers of the given shape
32 | """
33 | if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)):
34 | raise ValueError(f"rand only supports float dtypes, got {dtype}")
35 | if (numel := prod(shape)) == 0:
36 | return 0
37 | return ceildiv(numel * dtype.itemsize, 4)
38 |
39 |
40 | def counter_advance_for(target: Tensor) -> int:
41 | return counter_advance(*target.shape, dtype=target.dtype)
42 |
43 |
44 | # we mostly follow the implementation of Tinygrad's `rand` function, but we use our own given seed value
45 | # ref: https://github.com/tinygrad/tinygrad/blob/b057a90d493664d37558eb6c5447bc5bd5c15009/tinygrad/tensor.py#L502-L549
46 | def rand(
47 | *shape,
48 | seed: Tensor,
49 | counter: Tensor = 0,
50 | device: str | None = None,
51 | dtype: DTypeLike | None = None,
52 | contiguous: bool = True,
53 | ) -> Tensor:
54 | if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)):
55 | raise ValueError(f"rand only supports float dtypes, got {dtype}")
56 | if not all_int(shape := argfix(*shape)) or not all(s >= 0 for s in shape):
57 | raise ValueError(f"invalid input {shape=}")
58 | if device is not None and not isinstance(device, str):
59 | raise ValueError(f"rand only supports single device, got {device=}")
60 | device = Device.canonicalize(device)
61 |
62 | if seed.dtype != dtypes.uint64:
63 | raise ValueError("Seed dtype needs to be uint64")
64 | if seed.ndim != 0:
65 | raise ValueError("Seed must be a scalar")
66 |
67 | # if shape has 0, return zero tensor
68 | if (numel := prod(shape)) == 0:
69 | return Tensor.zeros(shape, device=device, dtype=dtype)
70 |
71 | # how many 4-byte sets of random bits we should generate
72 | num = ceildiv(numel * dtype.itemsize, 4)
73 |
74 | # increase counter
75 | counter.assign(counter + num).contiguous()
76 | bits_count = counter - num
77 |
78 | # threefry random bits
79 | counts0 = (
80 | Tensor.arange(
81 | ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False
82 | )
83 | + bits_count
84 | )
85 | counts1 = counts0 + ceildiv(num, 2)
86 | bits = _threefry_random_bits(seed, counts0, counts1)[:num]
87 |
88 | # bitcast to uint with same number of bits
89 | _, nmant = dtypes.finfo(dtype)
90 | uint_dtype = {
91 | 1: dtypes.uint8,
92 | 2: dtypes.uint16,
93 | 4: dtypes.uint32,
94 | 8: dtypes.uint64,
95 | }[dtype.itemsize]
96 | bits = bits.bitcast(uint_dtype)
97 | # only randomize the mantissa bits and set the exponent to 1
98 | one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
99 | bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
100 | # bitcast back to the original dtype and reshape
101 | out =
bits.bitcast(dtype)[:numel].sub(1).reshape(shape) 102 | return out.contiguous() if contiguous else out 103 | 104 | 105 | class RandomNumberGenerator: 106 | def __init__(self, seed: Tensor, counter: Tensor | None = None): 107 | if seed.dtype != dtypes.uint64: 108 | raise ValueError("Seed dtype needs to be uint64") 109 | self.seed = seed 110 | self.counter = counter 111 | if self.counter is None: 112 | self.counter = Tensor.zeros(dtype=dtypes.uint).contiguous().realize() 113 | 114 | def rand( 115 | self, 116 | *shape, 117 | device: str | None = None, 118 | dtype: DTypeLike | None = None, 119 | contiguous: bool = True, 120 | ) -> Tensor: 121 | return rand( 122 | *shape, 123 | seed=self.seed, 124 | counter=self.counter, 125 | device=device, 126 | dtype=dtype, 127 | contiguous=contiguous, 128 | ) 129 | 130 | def uniform( 131 | self, *shape, low=0.0, high=1.0, dtype: DTypeLike | None = None 132 | ) -> Tensor: 133 | return ((high - low) * self.rand(*shape)).cast( 134 | dtype or dtypes.default_float 135 | ) + low 136 | 137 | def uniform_like(self, target: Tensor, low=0.0, high=1.0): 138 | return self.uniform(*target.shape, low=low, high=high, dtype=target.dtype) 139 | -------------------------------------------------------------------------------- /marketplace/training.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import typing 3 | 4 | from tinygrad import Tensor 5 | 6 | 7 | @dataclasses.dataclass 8 | class Spec: 9 | model: typing.Callable 10 | vendor_count: int 11 | upstream_sampling: int = 0 12 | 13 | 14 | def produce( 15 | spec: Spec, 16 | x: Tensor, 17 | vendors: list[typing.Callable], 18 | paths: Tensor | None = None, 19 | ) -> tuple[Tensor, Tensor]: 20 | """Produce various of output for the given model and its vendors with upstream sampling 21 | 22 | :param spec: spec of marketplace 23 | :param x: raw input data or intermediate products from the previous layer 24 | :param vendors: vendors for decorating a model 25 | :param paths: accumulated paths so far from the previous layers 26 | :return: (output_data, paths) 27 | """ 28 | if paths is None: 29 | # this is the first spec for taking in the raw input, let's feed data to all of them 30 | # TODO: use RANGIFY feature when it's ready to make JIT's job much easier 31 | output_data = Tensor.stack( 32 | *(vendor(x) for vendor in vendors), 33 | dim=0, 34 | ) 35 | paths = Tensor.arange(len(vendors)).unsqueeze(1) 36 | return output_data, paths 37 | if x.size(0) != paths.size(0): 38 | raise ValueError( 39 | "Provided input data's first dimension doesn't match with the paths' first dimension" 40 | ) 41 | 42 | if spec.upstream_sampling == 0: 43 | # when upstream sampling is zero, it means we sample the full input 44 | upstream_sampling = x.shape[0] 45 | input_indexes = Tensor.arange(x.shape[0]).expand(spec.vendor_count, -1) 46 | else: 47 | upstream_sampling = spec.upstream_sampling 48 | input_count = paths.size(0) 49 | # TODO: use RANGIFY? 50 | input_indexes = Tensor.stack( 51 | *( 52 | Tensor.randperm(input_count)[:upstream_sampling] 53 | for _ in range(spec.vendor_count) 54 | ), 55 | dim=0, 56 | ) 57 | 58 | input_data = x[input_indexes] 59 | # merge different batches for the same vendor into one. 
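# For example, with vendor_count=3 and upstream_sampling=2, where each upstream product
# has shape (batch, ...), input_data above has shape (3, 2, batch, ...); the reshape below
# folds it into (3, 2 * batch, ...) so each vendor runs a single forward pass, and the
# second reshape after the vendor calls splits the result back into (3 * 2, batch, ...).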
60 | merged_batches = input_data.reshape(input_data.shape[0], -1, *input_data.shape[3:]) 61 | if len(merged_batches) != len(vendors): 62 | raise ValueError( 63 | f"Unexpected size of merged batches {len(merged_batches)} and vendors {len(vendors)}" 64 | ) 65 | 66 | output_data = Tensor.stack( 67 | *(vendor(merged) for vendor, merged in zip(vendors, merged_batches)), 68 | dim=0, 69 | ) 70 | # breaking down merged batches back to individual batches 71 | output_data = output_data.reshape(-1, input_data.shape[2], *output_data.shape[2:]) 72 | 73 | prev_paths = paths[input_indexes].flatten(0, 1) 74 | current_paths = ( 75 | Tensor.arange(len(vendors)) 76 | .unsqueeze(1) 77 | .repeat(1, upstream_sampling) 78 | .flatten() 79 | .unsqueeze(1) 80 | ) 81 | merged_paths = prev_paths.cat(current_paths, dim=1) 82 | return output_data, merged_paths 83 | 84 | 85 | def forward( 86 | marketplace: list[Spec], 87 | x: Tensor, 88 | vendors: list[list[typing.Callable]], 89 | initial_paths: Tensor | None = None, 90 | ) -> tuple[Tensor, Tensor]: 91 | data = x 92 | acc_paths = initial_paths 93 | for spec, spec_vendors in zip(marketplace, vendors): 94 | data, acc_paths = produce( 95 | spec=spec, 96 | x=data, 97 | vendors=spec_vendors, 98 | paths=acc_paths, 99 | ) 100 | return data, acc_paths 101 | 102 | 103 | def straight_forward(specs: list[Spec], x: Tensor) -> Tensor: 104 | data = x 105 | for spec in specs: 106 | data = spec.model(data) 107 | return data 108 | -------------------------------------------------------------------------------- /marketplace/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | import pathlib 4 | 5 | from tinygrad import Tensor 6 | from tinygrad.nn.state import get_state_dict 7 | from tinygrad.nn.state import load_state_dict 8 | from tinygrad.nn.state import safe_load 9 | from tinygrad.nn.state import safe_save 10 | 11 | from .training import Spec 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def write_checkpoint( 17 | marketplace: list[Spec], 18 | global_step: int, 19 | output_filepath: pathlib.Path, 20 | ): 21 | logger.info( 22 | "Writing checkpoint with global_step %s to %s", global_step, output_filepath 23 | ) 24 | parameters = dict( 25 | itertools.chain.from_iterable( 26 | [ 27 | (f"spec.{i}.{key}", weights) 28 | for key, weights in get_state_dict(spec.model).items() 29 | ] 30 | for i, spec in enumerate(marketplace) 31 | ) 32 | ) 33 | checkpoint_tmp_filepath = output_filepath.with_suffix(".tmp") 34 | safe_save( 35 | parameters | dict(global_step=Tensor(global_step)), str(checkpoint_tmp_filepath) 36 | ) 37 | checkpoint_tmp_filepath.rename(output_filepath) 38 | logger.info( 39 | "Wrote checkpoint with global_step %s to %s", global_step, output_filepath 40 | ) 41 | 42 | 43 | def load_checkpoint( 44 | marketplace: list[Spec], 45 | input_filepath: pathlib.Path, 46 | ): 47 | logger.info("Loading checkpoint from %s", input_filepath) 48 | state = safe_load(input_filepath) 49 | 50 | for i, spec in enumerate(marketplace): 51 | prefix = f"spec.{i}." 
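# e.g. a checkpoint key like "spec.0.layers.0.weight" is stripped to "layers.0.weight"
# below before being loaded into the first spec's model (the key name here is only
# illustrative; actual names come from get_state_dict in write_checkpoint above).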
52 | spec_params = { 53 | key.removeprefix(prefix): params 54 | for key, params in state.items() 55 | if key.startswith(prefix) 56 | } 57 | load_state_dict(spec.model, spec_params) 58 | 59 | global_step = state.pop("global_step", None) 60 | if global_step is not None: 61 | global_step = global_step.item() 62 | logger.info( 63 | "Loaded checkpoint with global_step %s from %s", global_step, input_filepath 64 | ) 65 | -------------------------------------------------------------------------------- /plot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaunchPlatform/marketplace/6487c3ea4d5c2df208bc9ce10a627cef5e8eeace/plot/__init__.py -------------------------------------------------------------------------------- /plot/compare_with_all_at_once.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | # Define unit cost functions 6 | def unit_cost_all_at_once(N, C): 7 | """Calculate unit cost for all-at-once approach: U_A = sum(C(i))""" 8 | return sum(C(i) for i in range(N)) 9 | 10 | 11 | def unit_cost_marketplace(N, M, C): 12 | """Calculate unit cost for Marketplace approach: U_M = sum(C(i) / M^(N-i-1))""" 13 | return sum(C(i) / (M ** (N - i - 1)) for i in range(N)) 14 | 15 | 16 | # Define parameters 17 | M_values = [4, 8, 16, 32, 64] # Number of vendors per layer 18 | N_values = np.arange(1, 11) # Number of layers from 1 to 10 19 | C = lambda i: 1 # Constant computation cost per layer 20 | 21 | # Initialize lists for plotting 22 | unit_costs_a = [] 23 | unit_costs_m = {M: [] for M in M_values} 24 | ratios = {M: [] for M in M_values} 25 | 26 | # Calculate unit costs for each N and M 27 | for N in N_values: 28 | # All-at-once unit cost 29 | ua = unit_cost_all_at_once(N, C) 30 | unit_costs_a.append(ua) 31 | 32 | # Marketplace unit cost for each M 33 | for M in M_values: 34 | um = unit_cost_marketplace(N, M, C) 35 | unit_costs_m[M].append(um) 36 | ratios[M].append(ua / um if um != 0 else float("inf")) 37 | 38 | # Create plots 39 | plt.style.use("seaborn-v0_8") 40 | 41 | # Plot 1: Unit Cost vs. Number of Layers 42 | plt.figure(figsize=(10, 6)) 43 | plt.plot(N_values, unit_costs_a, label="All-at-Once (U_A)", marker="o", linewidth=2) 44 | for M in M_values: 45 | plt.plot( 46 | N_values, unit_costs_m[M], label=f"Marketplace (M={M})", marker="s", linewidth=2 47 | ) 48 | plt.xlabel("Number of Layers (N)") 49 | plt.ylabel("Unit Cost") 50 | plt.title("Unit Cost Comparison: All-at-Once vs. Marketplace") 51 | plt.legend() 52 | plt.grid(True) 53 | plt.yscale("log") # Log scale to better visualize differences 54 | plt.savefig("unit_cost_vs_N.png") 55 | plt.close() 56 | 57 | # Plot 2: Unit Cost vs. Number of Vendors for fixed N 58 | fixed_N = [3, 5, 7] # Different N values to compare 59 | M_range = np.arange(2, 11) 60 | plt.figure(figsize=(10, 6)) 61 | for N in fixed_N: 62 | ua = unit_cost_all_at_once(N, C) 63 | um_values = [unit_cost_marketplace(N, M, C) for M in M_range] 64 | plt.plot(M_range, um_values, label=f"Marketplace (N={N})", marker="s", linewidth=2) 65 | plt.axhline(y=ua, linestyle="--", label=f"All-at-Once (N={N})", alpha=0.7) 66 | plt.xlabel("Number of Vendors per Layer (M)") 67 | plt.ylabel("Unit Cost") 68 | plt.title("Unit Cost vs. 
Number of Vendors for Fixed N") 69 | plt.legend() 70 | plt.grid(True) 71 | plt.yscale("log") # Log scale for clarity 72 | plt.savefig("unit_cost_vs_M.png") 73 | plt.close() 74 | 75 | # Plot 3: Ratio of Unit Costs vs. Number of Layers 76 | plt.figure(figsize=(10, 6)) 77 | for M in M_values: 78 | plt.plot(N_values, ratios[M], label=f"M={M}", marker="^", linewidth=2) 79 | plt.xlabel("Number of Layers (N)") 80 | plt.ylabel("Efficiency Ratio (U_A / U_M)") 81 | plt.title("Efficiency Ratio: All-at-Once vs. Marketplace") 82 | plt.legend() 83 | plt.grid(True) 84 | plt.yscale("log") # Log scale to show exponential growth 85 | plt.savefig("efficiency_ratio_vs_N.png") 86 | plt.close() 87 | -------------------------------------------------------------------------------- /plot/plot_3d.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from mpl_toolkits.mplot3d import Axes3D 4 | 5 | plt.rcParams["text.usetex"] = True 6 | 7 | 8 | # Define unit cost functions 9 | def unit_cost_all_at_once(N, C): 10 | """Calculate unit cost for all-at-once approach: U_A = sum(C(i))""" 11 | return sum(C(i) for i in range(N)) 12 | 13 | 14 | def unit_cost_marketplace(N, M, C): 15 | """Calculate unit cost for Marketplace approach: U_M = sum(C(i) / M^(N-i-1))""" 16 | return sum(C(i) / (M ** (N - i - 1)) for i in range(N)) 17 | 18 | 19 | # Define parameters 20 | N_values = np.arange(1, 11) # Number of layers from 1 to 10 21 | M_values = np.arange(2, 11) # Number of vendors from 2 to 10 22 | C = lambda i: 1 # Constant computation cost per layer 23 | 24 | # Create meshgrid for N and M 25 | N_grid, M_grid = np.meshgrid(N_values, M_values) 26 | 27 | # Calculate unit costs 28 | U_M = np.zeros_like(N_grid, dtype=float) 29 | U_A = np.zeros_like(N_grid, dtype=float) 30 | for i in range(N_grid.shape[0]): 31 | for j in range(N_grid.shape[1]): 32 | N = int(N_grid[i, j]) 33 | M = int(M_grid[i, j]) 34 | U_M[i, j] = unit_cost_marketplace(N, M, C) 35 | U_A[i, j] = unit_cost_all_at_once(N, C) 36 | 37 | # Create 3D plot 38 | fig = plt.figure(figsize=(12, 8)) 39 | ax = fig.add_subplot(111, projection="3d") 40 | 41 | # Plot Marketplace unit cost surface 42 | surf1 = ax.plot_surface( 43 | N_grid, M_grid, U_M, cmap="viridis", alpha=0.7, label="Marketplace ($$U_M$$)" 44 | ) 45 | # Plot All-at-Once unit cost surface 46 | surf2 = ax.plot_surface( 47 | N_grid, M_grid, U_A, cmap="magma", alpha=0.7, label="All-at-Once ($$U_A$$)" 48 | ) 49 | 50 | # Add labels and title 51 | ax.set_xlabel("Depth of Market ($N$)") 52 | ax.set_ylabel("Number of Vendors ($M$)") 53 | ax.set_zlabel("Unit Cost") 54 | ax.set_title("Unit Costs: Marketplace vs. 
All-at-Once") 55 | 56 | # Add a color bar for each surface 57 | fig.colorbar(surf1, ax=ax, shrink=0.5, aspect=5, label="$U_M$") 58 | fig.colorbar(surf2, ax=ax, shrink=0.5, aspect=5, label="$U_A$") 59 | 60 | # Set z-axis to log scale for better visualization 61 | ax.set_zscale("log") 62 | 63 | # Save the plot 64 | plt.savefig("3d-unit-cost-comparison.png") 65 | plt.close() 66 | -------------------------------------------------------------------------------- /plot/plot_continual_learning_fashion.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import multiprocessing 4 | import pathlib 5 | import signal 6 | 7 | import click 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from matplotlib import ticker 11 | from tinygrad.nn.datasets import mnist 12 | 13 | # Simulated image data: list of (image, is_correct) pairs, None for empty cells 14 | logger = logging.getLogger(__name__) 15 | 16 | # Save the original default for potential restoration 17 | original_font_size = plt.rcParams["font.size"] 18 | # Scale up font size (e.g., 1.5x the default, which is usually 10pt) 19 | scale_factor = 3.8 20 | plt.rcParams["font.size"] = original_font_size * scale_factor 21 | 22 | X_train, _, _, _ = mnist() 23 | X_train = X_train.numpy() 24 | 25 | target_new_classes = (3,) 26 | new_X_train, new_Y_train, _, _ = mnist(fashion=True) 27 | class_mask = np.isin(new_Y_train.numpy(), target_new_classes) 28 | target_new_X_train = new_X_train.numpy()[class_mask] 29 | target_new_Y_train = new_Y_train.numpy()[class_mask] 30 | 31 | 32 | def plot_frame( 33 | old_images: np.typing.NDArray, 34 | old_correct: np.typing.NDArray, 35 | old_learning_accuracy: np.typing.NDArray, 36 | old_validation_accuracy: np.typing.NDArray, 37 | old_loss: np.typing.NDArray, 38 | new_images: np.typing.NDArray, 39 | new_correct: np.typing.NDArray, 40 | new_learning_accuracy: np.typing.NDArray, 41 | new_validation_accuracy: np.typing.NDArray, 42 | new_loss: np.typing.NDArray, 43 | steps: np.typing.NDArray, 44 | output_file: pathlib.Path, 45 | dpi: int = 50, 46 | ): 47 | images_top = list( 48 | zip( 49 | old_images, 50 | old_correct, 51 | ) 52 | ) 53 | images_bottom = list( 54 | zip( 55 | new_images, 56 | new_correct, 57 | ) 58 | ) 59 | 60 | # Set up figure with gridspec for images and charts 61 | fig = plt.figure(figsize=(32, 32)) 62 | fig.suptitle( 63 | f"Marketplace Continual Learning +Fashion CLS 3 - Step {steps[-1] + 1}", 64 | fontsize=48, 65 | ) 66 | gs = fig.add_gridspec(2, 2, width_ratios=[1, 1], hspace=0.2, wspace=0.2) 67 | 68 | # Image grid axes 69 | ax_top = fig.add_subplot(gs[0, 0]) 70 | ax_top.set_title( 71 | f"Old Data ({old_correct.sum()}/{len(old_correct)}, acc={old_learning_accuracy[-1]:.2f}%)" 72 | ) 73 | ax_bottom = fig.add_subplot(gs[1, 0]) 74 | ax_bottom.set_title( 75 | f"New Data ({new_correct.sum()}/{len(new_correct)}, acc={new_learning_accuracy[-1]:.2f}%)" 76 | ) 77 | 78 | # Chart axes 79 | ax_acc_top = fig.add_subplot(gs[0, 1]) 80 | ax_loss_top = ax_acc_top.twinx() 81 | ax_acc_bottom = fig.add_subplot(gs[1, 1]) 82 | ax_loss_bottom = ax_acc_bottom.twinx() 83 | 84 | # Function to plot a single grid 85 | def plot_grid(ax, images, grid_size=16): 86 | for i in range(grid_size): 87 | for j in range(grid_size): 88 | idx = i * grid_size + j 89 | if idx >= len(images): 90 | sub_ax = ax.inset_axes( 91 | [j / grid_size, 1 - (i + 1) / grid_size, 0.05, 0.05] 92 | ) 93 | sub_ax.axis("off") 94 | continue 95 | if images[idx] is not None and images[idx][0] 
is not None: 96 | img, is_correct = images[idx] 97 | sub_ax = ax.inset_axes( 98 | [j / grid_size, 1 - (i + 1) / grid_size, 0.05, 0.05] 99 | ) 100 | sub_ax.imshow(img, cmap="gray") 101 | border_color = "green" if is_correct else "red" 102 | for spine in sub_ax.spines.values(): 103 | spine.set_edgecolor(border_color) 104 | spine.set_linewidth(4) 105 | sub_ax.set_xticks([]) 106 | sub_ax.set_yticks([]) 107 | 108 | # Plot image grids 109 | plot_grid(ax_top, images_top) 110 | plot_grid(ax_bottom, images_bottom) 111 | 112 | # Plot accuracy and loss for top grid 113 | ax_acc_top.plot( 114 | steps, old_learning_accuracy, label="Learning Accuracy", color="blue" 115 | ) 116 | ax_acc_top.plot( 117 | steps, old_validation_accuracy, label="Validation Accuracy", color="green" 118 | ) 119 | ax_loss_top.plot(steps, old_loss, label="Loss", color="red", linestyle="--") 120 | ax_acc_top.set_title( 121 | f"Old Data (vacc={old_validation_accuracy[-1]:.2f}, loss={old_loss[-1]:.2f})" 122 | ) 123 | ax_acc_top.set_xlabel("Steps") 124 | ax_acc_top.set_ylabel("Accuracy", color="blue") 125 | ax_acc_top.set_ylim(0, 100) 126 | ax_loss_top.set_ylabel("Loss", color="red") 127 | ax_loss_top.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.3f")) 128 | ax_acc_top.tick_params(axis="y", colors="blue") 129 | ax_loss_top.tick_params(axis="y", colors="red") 130 | ax_acc_top.legend(loc="upper left") 131 | ax_loss_top.legend(loc="upper right") 132 | 133 | # Plot accuracy and loss for bottom grid 134 | ax_acc_bottom.plot( 135 | steps, new_learning_accuracy, label="Learning Accuracy", color="blue" 136 | ) 137 | ax_acc_bottom.plot( 138 | steps, new_validation_accuracy, label="Validation Accuracy", color="green" 139 | ) 140 | ax_loss_bottom.plot(steps, new_loss, label="Loss", color="red", linestyle="--") 141 | ax_acc_bottom.set_title( 142 | f"New Data (vacc={new_validation_accuracy[-1]:.2f}, loss={new_loss[-1]:.2f})" 143 | ) 144 | ax_acc_bottom.set_xlabel("Steps") 145 | ax_acc_bottom.set_ylabel("Accuracy", color="blue") 146 | ax_acc_bottom.set_ylim(0, 100) 147 | ax_loss_bottom.set_ylabel("Loss", color="red") 148 | ax_loss_bottom.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.3f")) 149 | ax_acc_bottom.tick_params(axis="y", colors="blue") 150 | ax_loss_bottom.tick_params(axis="y", colors="red") 151 | ax_acc_bottom.legend(loc="upper left") 152 | ax_loss_bottom.legend(loc="upper right") 153 | 154 | # Remove ticks from image grid axes 155 | ax_top.set_xticks([]) 156 | ax_top.set_yticks([]) 157 | ax_bottom.set_xticks([]) 158 | ax_bottom.set_yticks([]) 159 | 160 | plt.subplots_adjust(left=0.025, right=0.95, bottom=0.05, top=0.925) 161 | # plt.tight_layout() 162 | 163 | tmp_file = output_file.with_suffix(".tmp.png") 164 | plt.savefig(tmp_file, dpi=dpi, bbox_inches=None) 165 | tmp_file.rename(output_file) 166 | 167 | 168 | # ref: https://stackoverflow.com/a/6191991 169 | def init_worker(): 170 | signal.signal(signal.SIGINT, signal.SIG_IGN) 171 | 172 | 173 | def make_frame(kwargs): 174 | plot_frame(**kwargs) 175 | return kwargs["output_file"] 176 | 177 | 178 | @click.command() 179 | @click.argument( 180 | "INPUT_FILE", type=click.Path(dir_okay=False, exists=True, readable=True) 181 | ) 182 | @click.argument( 183 | "OUTPUT_FOLDER", 184 | type=click.Path(dir_okay=True, file_okay=False, exists=True, writable=True), 185 | ) 186 | @click.option("--limit", type=int) 187 | def main(input_file: str, output_folder: str, limit: int | None): 188 | steps = [] 189 | 190 | old_samples = [] 191 | old_correct = [] 192 | old_learning_accuracy = [] 
193 | old_validation_accuracy = [] 194 | old_loss = [] 195 | 196 | new_samples = [] 197 | new_correct = [] 198 | new_learning_accuracy = [] 199 | new_validation_accuracy = [] 200 | new_loss = [] 201 | 202 | with open(input_file) as replay_file: 203 | for line in replay_file.readlines(): 204 | data = json.loads(line) 205 | 206 | steps.append(data["global_step"]) 207 | 208 | old_learning_accuracy.append(np.array(data["old_correct"]).mean() * 100) 209 | old_validation_accuracy.append(data["old_test_acc"]) 210 | old_loss.append(data["old_loss"]) 211 | old_correct.append(data["old_correct"]) 212 | old_samples.append(data["old_samples"]) 213 | 214 | new_learning_accuracy.append(np.array(data["new_correct"]).mean() * 100) 215 | new_validation_accuracy.append(data["new_test_acc"]) 216 | new_loss.append(data["new_loss"]) 217 | new_correct.append(data["new_correct"]) 218 | new_samples.append(data["new_samples"]) 219 | 220 | old_learning_accuracy = np.array(old_learning_accuracy) 221 | old_validation_accuracy = np.array(old_validation_accuracy) 222 | old_loss = np.array(old_loss).mean(axis=1) 223 | old_correct = np.array(old_correct) 224 | old_samples = np.array(old_samples) 225 | 226 | new_learning_accuracy = np.array(new_learning_accuracy) 227 | new_validation_accuracy = np.array(new_validation_accuracy) 228 | new_loss = np.array(new_loss).mean(axis=1) 229 | new_correct = np.array(new_correct) 230 | new_samples = np.array(new_samples) 231 | 232 | steps = np.array(steps) 233 | 234 | def prepare_kwargs(item: tuple[int, int]): 235 | i, step = item 236 | count = i + 1 237 | output_file = pathlib.Path(output_folder) / f"{i}.png" 238 | logger.info("Writing %s (step %s) to %s", i, step, output_file) 239 | return dict( 240 | old_images=X_train[old_samples[i]].reshape(-1, 28, 28), 241 | old_correct=old_correct[i], 242 | old_learning_accuracy=old_learning_accuracy[:count], 243 | old_validation_accuracy=old_validation_accuracy[:count], 244 | old_loss=old_loss[:count], 245 | new_images=target_new_X_train[new_samples[i]].reshape(-1, 28, 28), 246 | new_correct=new_correct[i], 247 | new_learning_accuracy=new_learning_accuracy[:count], 248 | new_validation_accuracy=new_validation_accuracy[:count], 249 | new_loss=new_loss[:count], 250 | steps=steps[:count], 251 | output_file=output_file, 252 | ) 253 | 254 | limited_steps = steps 255 | if limit is not None: 256 | limited_steps = steps[:limit] 257 | with multiprocessing.Pool(16, init_worker) as pool: 258 | for output_file in pool.imap( 259 | make_frame, 260 | filter( 261 | lambda x: not x["output_file"].exists(), 262 | map(prepare_kwargs, enumerate(limited_steps)), 263 | ), 264 | ): 265 | logger.info("Wrote to %s", output_file) 266 | 267 | 268 | if __name__ == "__main__": 269 | logging.basicConfig(level=logging.INFO) 270 | main() 271 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "marketplace" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | dependencies = [ 8 | "click>=8.2.1", 9 | "mlflow>=3.2.0", 10 | "numpy>=2.3.2", 11 | "pillow>=11.3.0", 12 | "psutil>=7.0.0", 13 | "tinygrad>=0.11.0", 14 | "tinyloader>=0.1.3", 15 | ] 16 | 17 | [dependency-groups] 18 | dev = [ 19 | "capstone>=5.0.6", 20 | "matplotlib>=3.10.5", 21 | "pytest>=8.4.1", 22 | "seaborn>=0.13.2", 23 | "tensorboard>=2.20.0", 24 | "tensorboardx>=2.6.4", 25 | ] 26 | 27 | 
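For reference, below is a minimal usage sketch of the Spec / forward API defined in marketplace/training.py. It is not part of the repository: the toy Scale model and the concrete vendor values are made up for illustration, and only functions that appear in the files above are used.

from tinygrad import Tensor

from marketplace.training import Spec, forward, straight_forward


class Scale:
    def __init__(self, value: float):
        self.weight = Tensor(value).contiguous().realize()

    def __call__(self, x: Tensor) -> Tensor:
        return x * self.weight


# Two market layers: three vendors feed two vendors, each downstream vendor
# sampling two upstream products per step.
marketplace = [
    Spec(model=Scale(1.0), vendor_count=3),
    Spec(model=Scale(1.0), vendor_count=2, upstream_sampling=2),
]
vendors = [
    [Scale(v) for v in (0.5, 1.0, 2.0)],
    [Scale(v) for v in (3.0, 4.0)],
]

x = Tensor([[1.0, 2.0, 3.0]])  # one batch of raw input
outputs, paths = forward(marketplace, x, vendors)
# outputs stacks one product per sampled path; each row of paths records which
# vendor was picked at every layer, i.e. [upstream_index, downstream_index].
print(outputs.shape, paths.tolist())

# straight_forward runs the spec models themselves, without any vendors.
print(straight_forward(marketplace, x).tolist())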
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaunchPlatform/marketplace/6487c3ea4d5c2df208bc9ce10a627cef5e8eeace/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_optimizers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from tinygrad import dtypes 3 | from tinygrad import Tensor 4 | from tinygrad.nn.state import get_state_dict 5 | 6 | from marketplace.nn import Model 7 | from marketplace.optimizers import CachedDeltaVendor 8 | from marketplace.optimizers import Optimizer 9 | from marketplace.optimizers import SEED_MAX 10 | from marketplace.training import Spec 11 | 12 | 13 | class Multiply: 14 | def __init__(self, number: float): 15 | self.number = Tensor(number).contiguous().realize() 16 | 17 | def __call__(self, x: Tensor) -> Tensor: 18 | return x * self.number 19 | 20 | 21 | class Add: 22 | def __init__(self, number: float): 23 | self.number = Tensor(number).contiguous().realize() 24 | 25 | def __call__(self, x: Tensor) -> Tensor: 26 | return x + self.number 27 | 28 | 29 | @pytest.fixture 30 | def optimizer() -> Optimizer: 31 | return Optimizer( 32 | marketplace=[ 33 | Spec( 34 | model=Model( 35 | Multiply(3.0), 36 | Add(11.0), 37 | ), 38 | vendor_count=4, 39 | ), 40 | Spec( 41 | model=Model( 42 | Multiply(6.0), 43 | Add(3.0), 44 | Multiply(23.0), 45 | ), 46 | vendor_count=4, 47 | ), 48 | Spec(model=Multiply(5.0), vendor_count=2), 49 | ], 50 | learning_rate=Tensor(2.0).contiguous(), 51 | seeds=[ 52 | Tensor([0, 1, 2, 3], dtype=dtypes.uint64).contiguous(), 53 | Tensor([0, 1, 2, 3], dtype=dtypes.uint64).contiguous(), 54 | Tensor([0, 1], dtype=dtypes.uint64).contiguous(), 55 | ], 56 | ) 57 | 58 | 59 | def test_cached_delta_vendor(): 60 | model = Model( 61 | Multiply(3.0), 62 | Add(7.0), 63 | ) 64 | vendor = CachedDeltaVendor( 65 | model=model, 66 | delta={ 67 | "layers.0.number": Tensor(5.0), 68 | "layers.1.number": Tensor(2.0), 69 | }, 70 | ) 71 | x = Tensor(4) 72 | assert model(x).item() == (x.item() * 3) + 7 73 | assert vendor(x).item() == (x.item() * (3 + 5)) + (7 + 2) 74 | # ensure that we didn't change the weights of original model 75 | assert model(x).item() == (x.item() * 3) + 7 76 | 77 | model.layers[0].number.assign(4.0) 78 | assert model(x).item() == (x.item() * 4) + 7 79 | assert vendor(x).item() == (x.item() * (4 + 5)) + (7 + 2) 80 | 81 | 82 | def test_optimizer(optimizer: Optimizer): 83 | assert len(optimizer.spec_context) == len(optimizer.marketplace) 84 | # assert len(optimizer.vendors) == len(optimizer.marketplace) 85 | 86 | 87 | def test_optimizer_schedule_delta_update(optimizer: Optimizer): 88 | init_seeds = [ctx.seeds.tolist() for ctx in optimizer.spec_context] 89 | initial_deltas = [ 90 | {key: params.tolist() for key, params in ctx.delta.items()} 91 | for ctx in optimizer.spec_context 92 | ] 93 | for _ in range(10): 94 | Tensor.realize(*optimizer.schedule_delta_update()) 95 | new_delta = [ 96 | {key: params.tolist() for key, params in ctx.delta.items()} 97 | for ctx in optimizer.spec_context 98 | ] 99 | assert initial_deltas == new_delta 100 | for _ in range(5): 101 | for ctx in optimizer.spec_context: 102 | ctx.seeds.assign( 103 | Tensor.randint( 104 | *ctx.seeds.shape, low=0, high=SEED_MAX, dtype=dtypes.uint64 105 | ) 106 | ).realize() 107 | assert [ctx.seeds.tolist() for ctx in 
optimizer.spec_context] != init_seeds 108 | last_delta = None 109 | for _ in range(10): 110 | Tensor.realize(*optimizer.schedule_delta_update()) 111 | new_delta = [ 112 | {key: params.tolist() for key, params in ctx.delta.items()} 113 | for ctx in optimizer.spec_context 114 | ] 115 | assert new_delta != initial_deltas 116 | if last_delta is not None: 117 | assert new_delta == last_delta 118 | last_delta = new_delta 119 | 120 | 121 | def test_optimizer_schedule_weight_update(optimizer: Optimizer): 122 | initial_deltas = [ 123 | {key: params.numpy() for key, params in ctx.delta.items()} 124 | for ctx in optimizer.spec_context 125 | ] 126 | initial_weights = [ 127 | {key: params.numpy() for key, params in get_state_dict(spec.model).items()} 128 | for spec in optimizer.marketplace 129 | ] 130 | 131 | # Update with zero seeds, nothing should change 132 | Tensor.realize( 133 | *optimizer.schedule_weight_update( 134 | Tensor.zeros(len(optimizer.marketplace), dtype=dtypes.uint64) 135 | ) 136 | ) 137 | assert initial_weights == [ 138 | {key: params.numpy() for key, params in get_state_dict(spec.model).items()} 139 | for spec in optimizer.marketplace 140 | ] 141 | 142 | # Now the weight should change, but the second should remain the ame 143 | path = Tensor([1, 0, 1], dtype=dtypes.uint) 144 | seeds = optimizer.get_seeds(path) 145 | # the second seed should be zero 146 | assert seeds[1].item() == 0 147 | Tensor.realize(*optimizer.schedule_weight_update(seeds)) 148 | new_weights = [ 149 | {key: params.numpy() for key, params in get_state_dict(spec.model).items()} 150 | for spec in optimizer.marketplace 151 | ] 152 | assert new_weights[0] != initial_weights[0] 153 | assert new_weights[1] == initial_weights[1] 154 | assert new_weights[2] != initial_weights[2] 155 | 156 | for i in range(3): 157 | assert { 158 | key: init_params + initial_deltas[i][key][path[i].item()] 159 | for key, init_params in initial_weights[i].items() 160 | } == new_weights[i] 161 | -------------------------------------------------------------------------------- /tests/test_random.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from tinygrad import dtypes 3 | from tinygrad import Tensor 4 | from tinygrad.dtype import DTypeLike 5 | from tinygrad.helpers import ceildiv 6 | from tinygrad.helpers import prod 7 | 8 | from marketplace.random import counter_advance 9 | from marketplace.random import rand 10 | from marketplace.random import RandomNumberGenerator 11 | 12 | 13 | @pytest.fixture 14 | def rng() -> RandomNumberGenerator: 15 | return RandomNumberGenerator( 16 | seed=Tensor(0, dtype=dtypes.uint64), counter=Tensor(0, dtype=dtypes.uint) 17 | ) 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "shape, seed, counter, expected", 22 | [ 23 | ( 24 | (), 25 | Tensor(123456, dtype=dtypes.uint64), 26 | 0, 27 | 0.7353423833847046, 28 | ), 29 | ( 30 | (6,), 31 | Tensor(123456, dtype=dtypes.uint64), 32 | 0, 33 | [ 34 | 0.5452135801315308, 35 | 0.28107452392578125, 36 | 0.4398590326309204, 37 | 0.6165577173233032, 38 | 0.04700958728790283, 39 | 0.5229370594024658, 40 | ], 41 | ), 42 | ( 43 | (6,), 44 | Tensor(123456, dtype=dtypes.uint64), 45 | 2, 46 | [ 47 | 0.4398590326309204, 48 | 0.8822280168533325, 49 | 0.35901951789855957, 50 | 0.5229370594024658, 51 | 0.39503049850463867, 52 | 0.4783148765563965, 53 | ], 54 | ), 55 | ( 56 | (3, 2), 57 | Tensor(123456, dtype=dtypes.uint64), 58 | 0, 59 | [ 60 | [0.5452135801315308, 0.28107452392578125], 61 | [0.4398590326309204, 0.6165577173233032], 62 
| [0.04700958728790283, 0.5229370594024658], 63 | ], 64 | ), 65 | ], 66 | ) 67 | def test_rand( 68 | shape: tuple[int, ...], seed: Tensor, counter: int, expected: list | float 69 | ): 70 | counter_val = Tensor(counter) 71 | nums = rand(*shape, seed=seed, counter=counter_val) 72 | if isinstance(expected, list): 73 | assert nums.tolist() == expected 74 | else: 75 | assert nums.item() == expected 76 | assert ( 77 | counter_val.item() == ceildiv(prod(shape) * dtypes.float.itemsize, 4) + counter 78 | ) 79 | 80 | 81 | def test_rng_rand(rng: RandomNumberGenerator): 82 | random_numbers = rng.rand(512, 768).realize() 83 | assert random_numbers.min().item() >= 0.0 84 | assert random_numbers.max().item() < 1.0 85 | assert random_numbers.mean().item() == pytest.approx(0.5, rel=1e-03) 86 | 87 | 88 | def test_rng_uniform(rng: RandomNumberGenerator): 89 | random_numbers = rng.uniform(512, 768, low=0, high=10).realize() 90 | assert random_numbers.min().item() >= 0.0 91 | assert random_numbers.max().item() < 10.0 92 | assert random_numbers.mean().item() == pytest.approx(5, rel=1e-03) 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "shape, dtype, expected", 97 | [ 98 | ((0,), dtypes.float, 0), 99 | ((1, 2, 3), dtypes.float, 6), 100 | ((1, 2, 3), dtypes.float16, 3), 101 | ((5,), dtypes.float16, 3), 102 | ], 103 | ) 104 | def test_counter_advance( 105 | shape: tuple[int, ...], dtype: DTypeLike | None, expected: int 106 | ): 107 | assert counter_advance(*shape, dtype=dtype) == expected 108 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import pytest 4 | from tinygrad import Tensor 5 | 6 | from marketplace.training import produce 7 | from marketplace.training import Spec 8 | 9 | 10 | class Multiply: 11 | def __init__(self, value: float): 12 | self.weight = Tensor(value).contiguous().realize() 13 | 14 | def __call__(self, x: Tensor) -> Tensor: 15 | return x * self.weight 16 | 17 | 18 | class MultiplySum: 19 | def __init__(self, value: float): 20 | self.weight = Tensor(value).contiguous().realize() 21 | 22 | def __call__(self, x: Tensor): 23 | return x.sum(axis=1) * self.weight 24 | 25 | 26 | def realize(x: Tensor) -> list: 27 | return x.tolist() 28 | 29 | 30 | @pytest.mark.parametrize( 31 | "spec, vendors, x, expected", 32 | [ 33 | ( 34 | Spec( 35 | model=lambda: None, 36 | vendor_count=3, 37 | ), 38 | [Multiply(v) for v in [0.0, 1.0, 2.0]], 39 | Tensor([1.0, 2.0, 3.0]), 40 | ( 41 | Tensor( 42 | [ 43 | [0.0, 0.0, 0.0], 44 | [1.0, 2.0, 3.0], 45 | [2.0, 4.0, 6.0], 46 | ] 47 | ), 48 | Tensor([[0], [1], [2]]), 49 | ), 50 | ), 51 | ], 52 | ) 53 | def test_produce_with_input_data( 54 | spec: Spec, 55 | vendors: list[typing.Callable], 56 | x: Tensor, 57 | expected: tuple[Tensor, Tensor], 58 | ): 59 | assert list(map(realize, produce(spec=spec, vendors=vendors, x=x))) == list( 60 | map(realize, expected) 61 | ) 62 | 63 | 64 | @pytest.mark.parametrize( 65 | "spec, vendors, x, paths", 66 | [ 67 | ( 68 | Spec( 69 | model=lambda: None, 70 | vendor_count=3, 71 | upstream_sampling=2, 72 | ), 73 | [Multiply(v) for v in [1.0, 3.0, 5.0]], 74 | Tensor( 75 | [ 76 | [1.0, 2.0, 3.0], 77 | [2.0, 3.0, 4.0], 78 | [4.0, 5.0, 6.0], 79 | ] 80 | ), 81 | Tensor([[0], [1], [2]]), 82 | ), 83 | ( 84 | Spec( 85 | model=lambda: None, 86 | vendor_count=3, 87 | upstream_sampling=2, 88 | ), 89 | [MultiplySum(v) for v in [1.0, 3.0, 5.0]], 90 | Tensor( 91 | [ 92 | [[1.0, 2.0, 3.0]], 93 | 
[[2.0, 3.0, 4.0]], 94 | [[4.0, 5.0, 6.0]], 95 | ] 96 | ), 97 | Tensor([[0], [1], [2]]), 98 | ), 99 | ( 100 | Spec( 101 | model=lambda: None, 102 | vendor_count=3, 103 | upstream_sampling=0, 104 | ), 105 | [MultiplySum(v) for v in [1.0, 3.0, 5.0]], 106 | Tensor( 107 | [ 108 | [[1.0, 2.0, 3.0]], 109 | [[2.0, 3.0, 4.0]], 110 | [[4.0, 5.0, 6.0]], 111 | ] 112 | ), 113 | Tensor([[0], [1], [2]]), 114 | ), 115 | ], 116 | ) 117 | def test_produce(spec: Spec, vendors: list[typing.Callable], x: Tensor, paths: Tensor): 118 | output, out_paths = produce(spec=spec, vendors=vendors, x=x, paths=paths) 119 | assert all(v >= 0 and v < len(x) for v in out_paths[:, :1].flatten().tolist()) 120 | assert ( 121 | out_paths[:, 1:].tolist() 122 | == ( 123 | Tensor.arange(spec.vendor_count) 124 | .unsqueeze(1) 125 | .repeat(1, spec.upstream_sampling if spec.upstream_sampling > 0 else len(x)) 126 | .flatten() 127 | .unsqueeze(1) 128 | ).tolist() 129 | ) 130 | expected_output = [vendors[j.item()](x[i]).tolist() for i, j in out_paths] 131 | assert output.tolist() == expected_output 132 | --------------------------------------------------------------------------------
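As a closing illustration of the property the tests above exercise: marketplace.random produces identical values for identical (seed, counter) pairs, which is what allows the optimizer to re-create a vendor's weight delta from its seed instead of storing the delta itself. The snippet below is not part of the repository and the seed value is arbitrary.

from tinygrad import Tensor
from tinygrad import dtypes

from marketplace.random import RandomNumberGenerator

seed = Tensor(42, dtype=dtypes.uint64)  # arbitrary example seed
a = RandomNumberGenerator(seed=seed).uniform(2, 3, low=-0.1, high=0.1)
b = RandomNumberGenerator(seed=seed).uniform(2, 3, low=-0.1, high=0.1)
# the same seed with a fresh counter yields the exact same "delta" both times
assert a.tolist() == b.tolist()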