├── .gitignore ├── README.md ├── experiments ├── README.md ├── convergence │ ├── convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb │ ├── convergence_mnist_64workers_1000ms_seed1337_largeffn.ipynb │ ├── convergence_mnist_fail01_64workers_1000ms_seed1338_dmoe1024x4_cpu.ipynb │ ├── dmoe_emulator.py │ └── faulty_dmoe_emulator.py └── throughput │ ├── baseline_throughput.py │ ├── layers.py │ ├── rpc_throughput.py │ ├── throughput_client.py │ └── throughput_server.py ├── lib ├── __init__.py ├── client │ ├── __init__.py │ ├── gating_function.py │ └── remote_expert.py ├── network │ └── __init__.py ├── runtime │ ├── __init__.py │ ├── expert_backend.py │ └── task_pool.py ├── server │ ├── __init__.py │ ├── connection_handler.py │ └── network_handler.py └── utils │ ├── __init__.py │ ├── connection.py │ ├── data.py │ ├── nested.py │ ├── proto.py │ ├── serializer.py │ ├── shared_arrays.py │ ├── shared_future.py │ └── threading.py ├── requirements.txt ├── scheme.png └── scheme_pad.png /.gitignore: -------------------------------------------------------------------------------- 1 | # node and NPM 2 | npm-debug.log 3 | node_modules 4 | 5 | # swap files 6 | *~ 7 | *.swp 8 | 9 | examples/data/* 10 | examples/runs/* 11 | examples/.ipynb_checkpoints/* 12 | 13 | env.sh 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | bin/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | eggs/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg/ 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | 49 | # Translations 50 | *.mo 51 | 52 | # Mr Developer 53 | .mr.developer.cfg 54 | .project 55 | .pydevproject 56 | .idea 57 | .ipynb_checkpoints 58 | 59 | # Rope 60 | .ropeproject 61 | 62 | # Django stuff: 63 | *.log 64 | *.pot 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | docs/tmp* 69 | 70 | # OS X garbage 71 | .DS_Store 72 | 73 | # Debian things 74 | debian/reproducible-experiment-platform 75 | debian/files 76 | *.substvars 77 | *.debhelper.log 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Learning@home: Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts 2 | ![img](./scheme.png) 3 | 4 | PyTorch original implementation of ["Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts"](https://arxiv.org/abs/2002.04013) (NeurIPS 2020). 5 | 6 | __TL;DR:__ Learning@home is an approach for training large (up to multi-terabyte) neural networks on hardware provided by volunteers with unreliable and slow connection. 7 | 8 | __This repository__ contains a snapshot of Learning@home that was used to conduct initial experiments. While this snapshot implements the main functionality of Learning@home, it should be treated as a testbed to reproduce our experiments, __not__ as a finished library (see limitations below). To see an updated implementation designed for practical use, please refer to the [hivemind](https://github.com/learning-at-home/hivemind) project. 9 | 10 | 11 | ## What do I need to run it? 
12 | * One or several computers, each equipped with at least one GPU
13 | * Each computer should have at least two open ports (if not, consider SSH port forwarding)
14 | * Some popular Linux x64 distribution
15 |   * Tested on Ubuntu 16.04; should work fine on any popular Linux x64 distribution and even macOS;
16 |   * Running on Windows natively is not supported, please use a VM or Docker;
17 | 
18 | ## How do I run it?
19 | 1. Clone or download this repo. `cd` to its root directory.
20 | 2. Create a working Python environment. [Anaconda](https://www.anaconda.com/) works fine.
21 | 3. Install packages from `requirements.txt`
22 | 4. Follow the instructions in the next section
23 | 
24 | ## Running the experiments
25 | 
26 | ### Throughput
27 | 
28 | All three scripts are contained in the folder `throughput` and are ready for customized benchmark runs.
29 | 
30 | To run the baseline with parameters from the paper, use
31 | 
32 | ```python baseline_throughput.py --batches-for-latency 5 --batches-for-throughput 10 --batch-size 4 --throughput-runs 5 --linspace-points 10 --block-type transformer --layers-per-gpu 56 --gpus 0 1 2 3```
33 | 
34 | For testing Learning@home throughput under latency, first start the server for each GPU you have with
35 | 
36 | ```python throughput_server.py -a 16 -p PORT_NUMBER --block-type BLOCK_TYPE --gpu GPU_NUMBER```
37 | 
38 | and then run a multiple-trainer client with commands like
39 | 
40 | ```python throughput_client.py -j 64 --batches-for-latency 5 --batches-for-throughput 2 --throughput-runs 5 --linspace-points 10 --layers-per-gpu 56 --block-type ffn --hosts HOSTNAME1:PORT_NUMBER1 HOSTNAME2:PORT_NUMBER2```
41 | 
42 | ```python throughput_client.py -j 64 --batches-for-latency 5 --batches-for-throughput 2 --throughput-runs 5 --linspace-points 10 --layers-per-gpu 56 --block-type transformer --max-ping 0.2 --hosts HOSTNAME1:PORT_NUMBER1 HOSTNAME2:PORT_NUMBER2 --batch-size 4```
43 | 
44 | ### Convergence
45 | This experiment can be conducted both in a distributed setting and with an emulator. We recommend using the emulator to make results hardware-agnostic and to reduce variance due to CPU and network interference from other processes.
46 | 
47 | You can find notebooks for the [large FFN](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_largeffn.ipynb), [DMoE with 64 experts](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb), and [DMoE with 4096 experts](./experiments/convergence/convergence_mnist_fail01_64workers_1000ms_seed1338_dmoe1024x4_cpu.ipynb) in [`./experiments/convergence`](./experiments/convergence).
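The notebooks build their model around the `EmulatedDMoE` layer from [`dmoe_emulator.py`](./experiments/convergence/dmoe_emulator.py). As a quick orientation, here is a minimal sketch (not the exact notebook code) of wiring it into an MNIST classifier with the hyperparameters of the 64-expert setup; the expert architecture, the overall layer arrangement and the `update_every_inputs` value are illustrative assumptions.

```python
import torch
import torch.nn as nn

from dmoe_emulator import EmulatedDMoE, get_non_expert_params

layer_dim, num_blocks = 512, 4                  # values from the 64-expert notebook
num_experts, num_active_experts = 64, 4
update_every_steps = 10
update_every_inputs = 256                       # assumed value (batch_size * num_trainers)

def make_expert(dim):
    # a simple feed-forward expert; the notebook's expert block may differ
    return nn.Sequential(nn.Linear(dim, 4 * dim), nn.ReLU(), nn.Linear(4 * dim, dim))

model = nn.Sequential(
    nn.Linear(28 ** 2, layer_dim),              # flattened MNIST input
    *(EmulatedDMoE(layer_dim, num_experts, num_active_experts,
                   update_every_inputs, update_every_steps,
                   Expert=make_expert,
                   Optimizer=lambda params: torch.optim.Adam(params))
      for _ in range(num_blocks)),
    nn.Linear(layer_dim, 10),                   # 10 MNIST classes
)

# experts are updated automatically inside EmulatedDMoE after accumulating enough gradients,
# so the trainer's optimizer must only see the non-expert parameters
opt = torch.optim.Adam(get_non_expert_params(model))
```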
48 | 49 | Below we include the full grid of parameters used to conduct convergence experiments: 50 | 51 | | `setup` | `notebook` |`experts_per_layer` | `num_trainers` | `batch_size` | `delay_ms` | 52 | |---|---|---|---|---|---| 53 | | `100ms large ffn` | [click](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_largeffn.ipynb)|`-`|`64`| `4` |`100`| 54 | | `100ms 64 experts` | [click](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb)|`16`|`16`| `4`|`100`| 55 | | `100ms 256 experts` | [click](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb)|`64`|`64`|`4`|`100`| 56 | | `100ms 4096 experts` | [click](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb)|`1024`|`64`|`8`|`100`| 57 | | `1000ms large ffn` | [click](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_largeffn.ipynb)|`-`|`64`| `4` |`1000`| 58 | | `1000ms 64 experts` | [click](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb)|`16`|`16`| `4`|`1000`| 59 | | `1000ms 256 experts` | [click](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb)|`64`|`64`|`4`|`1000`| 60 | | `1000ms 4096 experts` | [click](./experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb)|`1024`|`64`|`8`|`1000`| 61 | | `10% failure 64 experts` | [click](./experiments/convergence/convergence_mnist_fail01_64workers_1000ms_seed1338_dmoe1024x4_cpu.ipynb)|`16`|`16`| `4`|`1000`| 62 | | `10% failure 256 experts` | [click](./experiments/convergence/convergence_mnist_fail01_64workers_1000ms_seed1338_dmoe1024x4_cpu.ipynb)|`64`|`64`|`4`|`1000`| 63 | | `10% failure 4096 experts` | [click](./experiments/convergence/convergence_mnist_fail01_64workers_1000ms_seed1338_dmoe1024x4_cpu.ipynb)|`1024`|`64`|`8`|`1000`| 64 | 65 | You can reproduce the curves in Figure 4 by opening the associated notebook, setting parameters as described in the table and iterating through random seeds 1337-1341 (including both borders). 66 | 67 | Please note that these experiments can take up a lot of GPU memory due to storing "stale" gradients. With 16 trainers, the code should fit well into consumer GPU. For 4096 experts, we bypassed the memory limit by running on CPU. 68 | 69 | ### Gating function over DHT 70 | We also provide a reference implementation of DMoE gating function over Kademlia DHT via `lib.GatingFunction`. 71 | 72 | In order to test our implementation, you need to do two things: 73 | 74 | First, set up DHT with at least one server process: 75 | ```python 76 | import torch 77 | import lib 78 | 79 | # initial kademlia node 80 | node_zero = lib.TesseractNetwork(port=ROOT_PORT, start=True) 81 | 82 | 83 | # create experts. 
Warning: expert uids must be unique
84 | experts = {}
85 | for expert_uid in expert_uids:
86 |     expert = torch.jit.script(NetworkBlock(1024))
87 |     expert_backend = lib.ExpertBackend(
88 |         name=expert_uid, expert=expert, opt=torch.optim.Adam(expert.parameters(), amsgrad=True),
89 |         args_schema=(lib.BatchTensorProto(1024),), outputs_schema=lib.BatchTensorProto(1024),
90 |         max_batch_size=2048, pool_size=8)
91 |     experts[expert_uid] = expert_backend
92 | 
93 | # set up server(s)
94 | runtime = lib.TesseractServer(lib.TesseractNetwork(('127.0.0.1', ROOT_PORT), port=SOME_OTHER_PORT, start=True),
95 |                               experts, port=PORTS[0], conn_handler_processes=64,
96 |                               sender_threads=1, device=torch.device('cuda'),
97 |                               start=True)
98 | # after creating node_zero you can create additional TesseractServer instances in separate processes
99 | ```
100 | 
101 | Second, create a client process and connect to any DHT node:
102 | ```python
103 | import torch
104 | import lib
105 | 
106 | # create one or several backends with expert uids following the "expert.[0-32).[0-32)" pattern
107 | # all backends must have TesseractNetwork active
108 | 
109 | network = lib.TesseractNetwork(('127.0.0.1', ROOT_PORT), port=SOME_NEW_PORT, start=True)
110 | dmoe = lib.GatingFunction(in_features=1024, grid_size=[32, 32], k_best=4, network=network, uid_prefix='expert')
111 | 
112 | average_out = dmoe(torch.randn(32, 1024))
113 | average_out.sum().backward()
114 | ```
115 | 
116 | 
117 | 
118 | ## Learning@home quick tour
119 | 
120 | __Trainer process:__
121 | * __`RemoteExpert`__ (`lib/client/remote_expert.py`) behaves like a PyTorch module with autograd support but actually sends requests to a remote runtime.
122 | * __`GatingFunction`__ (`lib/client/gating_function.py`) finds the best experts for a given input and either returns them as `RemoteExpert` instances or applies them right away.
123 | 
124 | __Runtime process:__
125 | * __`TesseractRuntime`__ (`lib/runtime/__init__.py`) aggregates batches and performs inference/training of experts according to their priority.
126 | * __`TesseractServer`__ (`lib/server/__init__.py`) wraps the runtime and periodically uploads experts into the DHT.
127 | 
128 | __DHT:__
129 | * __`TesseractNetwork`__ (`lib/network/__init__.py`) is a node of a Kademlia-based DHT that stores metadata used by the trainer and runtime.
130 | 
131 | ## Limitations
132 | As stated above, this implementation is a testbed for experiments, not a feature-complete library. More specifically:
133 | 
134 | * After finding the best experts across the DHT, a client still connects to these experts via hostname/port. The updated version connects to experts via the DHT, allowing users to host servers with no public hostname or behind NAT.
135 | * Runtime processes do not handle errors. In the updated version, any errors on the server are reported to the client.
136 | * This implementation uses the basic Kademlia protocol. The updated version modifies Kademlia to speed up the search for alive experts.
137 | 
138 | An updated version of the library is available at https://github.com/learning-at-home/hivemind.
139 | 
140 | ## References
141 | [Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts](https://arxiv.org/abs/2002.04013) (Max Ryabinin and Anton Gusev, NeurIPS 2020).
142 | ```
143 | @misc{ryabinin2020crowdsourced,
144 |       title={Towards Crowdsourced Training of Large Neural Networks using Decentralized Mixture-of-Experts},
145 |       author={Max Ryabinin and Anton Gusev},
146 |       year={2020},
147 |       eprint={2002.04013},
148 |       archivePrefix={arXiv},
149 |       primaryClass={cs.DC}
150 | }
151 | ```
--------------------------------------------------------------------------------
/experiments/README.md:
--------------------------------------------------------------------------------
1 | # Running the experiments
2 | 
3 | This folder contains scripts and notebooks necessary for reproducing the results reported in the paper.
4 | To run them, please refer to the corresponding subsections of this guide.
5 | 
6 | ### Throughput
7 | 
8 | All three scripts are contained in the folder `throughput` and are ready for customized benchmark runs.
9 | 
10 | To run the baseline with parameters from the paper, use
11 | 
12 | ```python baseline_throughput.py --batches-for-latency 5 --batches-for-throughput 10 --batch-size 4 --throughput-runs 5 --linspace-points 10 --block-type transformer --layers-per-gpu 56 --gpus 0 1 2 3```
13 | 
14 | For testing Learning@home throughput under latency, first start the server for each GPU you have with
15 | 
16 | ```python throughput_server.py -a 16 -p PORT_NUMBER --block-type BLOCK_TYPE --gpu GPU_NUMBER```
17 | 
18 | and then run a multiple-trainer client with commands like
19 | 
20 | ```python throughput_client.py -j 64 --batches-for-latency 5 --batches-for-throughput 2 --throughput-runs 5 --linspace-points 10 --layers-per-gpu 56 --block-type ffn --hosts HOSTNAME1:PORT_NUMBER1 HOSTNAME2:PORT_NUMBER2```
21 | 
22 | ```python throughput_client.py -j 64 --batches-for-latency 5 --batches-for-throughput 2 --throughput-runs 5 --linspace-points 10 --layers-per-gpu 56 --block-type transformer --max-ping 0.2 --hosts HOSTNAME1:PORT_NUMBER1 HOSTNAME2:PORT_NUMBER2 --batch-size 4```
23 | 
24 | ### Convergence
25 | This experiment can be conducted both in a distributed setting and with an emulator. We recommend using the emulator to make results hardware-agnostic and to reduce variance due to CPU and network interference from other processes.
26 | 
27 | You can find notebooks for [DMoE with 64 experts](./convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb) and [large FFN](./convergence/convergence_mnist_64workers_1000ms_seed1337_largeffn.ipynb) in [`./convergence`](./convergence).
28 | 
29 | In order to reproduce our results, one should run these notebooks with 5 different random seeds and aggregate the statistics saved in the last cell. Please note that these experiments can take up a lot of GPU memory due to storing "stale" gradients. With 16 workers, the code should fit well into a consumer GPU. For 64 workers, we bypassed the memory limit by sending gradients to `.cpu()` at the cost of longer experiment time.
30 | 
31 | ### Gating function over DHT
32 | We also provide a reference implementation of the DMoE gating function over a Kademlia DHT via `lib.GatingFunction`.
33 | 
34 | In order to test our implementation, you need to do two things:
35 | 
36 | First, set up a DHT with at least one server process:
37 | ```python
38 | import torch
39 | import lib
40 | 
41 | # initial Kademlia node
42 | node_zero = lib.TesseractNetwork(port=ROOT_PORT, start=True)
43 | 
44 | 
45 | # create experts.
Warning: expert uids must be unique 46 | experts = {} 47 | for expert_uid in expert_uids: 48 | expert = torch.jit.script(NetworkBlock(1024)) 49 | expert_backend = lib.ExpertBackend( 50 | name=expert_uid, expert=expert, opt=torch.optim.Adam(expert.parameters(), amsgrad=True), 51 | args_schema=(lib.BatchTensorProto(1024),), outputs_schema=lib.BatchTensorProto(1024), 52 | max_batch_size=2048, pool_size=8) 53 | experts[expert_uid] = expert_backend 54 | 55 | # set up server(s) 56 | runtime = lib.TesseractServer(lib.TesseractNetwork(('127.0.0.1', ROOT_PORT), port=SOME_OTHER_PORT, start=rue), 57 | experts, port=PORTS[0], conn_handler_processes=64, 58 | sender_threads=1, device=torch.device('cuda'), 59 | start=True) 60 | # after creating node_zero you can create additional TesseractServer instances in separate processes 61 | ``` 62 | 63 | Second, create a client process and connect to any DHT node: 64 | ```python 65 | import torch 66 | import lib 67 | 68 | # create one or several backends with expert uids following the "expert.[0-32).[0-32)" pattern 69 | # all backends must have TesseractNetwork active 70 | 71 | network = lib.TesseractNetwork(('127.0.0.1', ROOT_PORT), port=SOME_NEW_PORT, start=True) 72 | dmoe = lib.GatingFunction(in_features=1024, grid_size=[32, 32], k_best=4, network=network, uid_prefix='expert') 73 | 74 | average_out = dmoe(torch.randn(32, 1024)) 75 | average_out.sum().backward() 76 | ``` 77 | -------------------------------------------------------------------------------- /experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_dmoe64x4.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import time\n", 10 | "import random\n", 11 | "import threading\n", 12 | "from functools import partial\n", 13 | "\n", 14 | "import numpy as np\n", 15 | "\n", 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import torch.nn.functional as F\n", 19 | "from torchvision import datasets, transforms\n", 20 | "from tqdm import tqdm\n", 21 | "\n", 22 | "from IPython.display import clear_output\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "%matplotlib inline\n", 25 | "\n", 26 | "from dmoe_emulator import EmulatedDMoE, get_non_expert_params" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "seed = 1337\n", 36 | "torch.manual_seed(seed)\n", 37 | "np.random.seed(seed)\n", 38 | "random.seed(seed)\n", 39 | "\n", 40 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 41 | "layer_dim = 512\n", 42 | "num_blocks = 4\n", 43 | "num_experts = 64\n", 44 | "num_active_experts = 4\n", 45 | "\n", 46 | "batch_size = 4\n", 47 | "num_trainers = 64\n", 48 | "\n", 49 | "delay_ms = 1000\n", 50 | "\n", 51 | "eval_interval = 1024\n", 52 | "total_steps = eval_interval * 20\n", 53 | "update_every_steps = 10\n", 54 | "\n", 55 | "in_features = 28 ** 2\n", 56 | "num_classes = 10" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "name": "stderr", 66 | "output_type": "stream", 67 | "text": [ 68 | "\r", 69 | " 0%| | 0/20480 [00:00" 130 | ] 131 | }, 132 | "metadata": { 133 | "needs_background": "light" 134 | }, 135 | "output_type": "display_data" 136 | }, 137 | { 138 | "name": "stderr", 139 | "output_type": "stream", 140 | "text": [ 141 | 
"#20541\tloss=0.030536\tdelay=41: 20541it [1:21:21, 7.18it/s]" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "def trainer_thread_method():\n", 147 | " \"\"\" train model on batches, emulate network latency \"\"\"\n", 148 | " train_loader = torch.utils.data.DataLoader(\n", 149 | " datasets.MNIST('../data', train=True, download=True,\n", 150 | " transform=transforms.Compose([\n", 151 | " transforms.ToTensor(),\n", 152 | " transforms.Normalize((0.1307,), (0.3081,)),\n", 153 | " transforms.Lambda(lambda x: x.view(-1))\n", 154 | " ])),\n", 155 | " batch_size=batch_size, shuffle=True,\n", 156 | " )\n", 157 | "\n", 158 | " while True:\n", 159 | " for xb, yb in train_loader:\n", 160 | " xb, yb = xb.to(device), yb.to(device)\n", 161 | "\n", 162 | " with lock_model:\n", 163 | " model.train(True)\n", 164 | " initial_step_index = len(train_history)\n", 165 | " logits = model(xb)\n", 166 | " loss = F.cross_entropy(logits, yb)\n", 167 | " \n", 168 | " opt.zero_grad()\n", 169 | " loss.backward()\n", 170 | " grads = [param.grad.clone() if param.grad is not None else None\n", 171 | " for param in non_expert_params]\n", 172 | "\n", 173 | " emulate_latency()\n", 174 | "\n", 175 | " with lock_model:\n", 176 | " model.train(True)\n", 177 | " opt.zero_grad()\n", 178 | " for param, grad in zip(non_expert_params, grads):\n", 179 | " param.grad[...] = grad\n", 180 | " opt.step()\n", 181 | " train_history.append(dict(\n", 182 | " loss=loss.item(), \n", 183 | " delay_steps=len(train_history) - initial_step_index,\n", 184 | " ))\n", 185 | " progress.desc = f'#{len(train_history)}\\tloss={loss.item():4f}\\tdelay={train_history[-1][\"delay_steps\"]}'\n", 186 | " progress.update(1)\n", 187 | " \n", 188 | " if len(train_history) % eval_interval == 0 or len(train_history) >= total_steps:\n", 189 | " need_to_eval.set()\n", 190 | " if len(train_history) >= total_steps:\n", 191 | " return\n", 192 | " \n", 193 | "\n", 194 | "def evaluate():\n", 195 | " test_loader = torch.utils.data.DataLoader(\n", 196 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 197 | " transforms.ToTensor(),\n", 198 | " transforms.Normalize((0.1307,), (0.3081,)),\n", 199 | " transforms.Lambda(lambda x: x.view(-1))\n", 200 | " ])),\n", 201 | " batch_size=batch_size, num_workers=4, pin_memory=True,\n", 202 | " )\n", 203 | " with lock_model, torch.no_grad():\n", 204 | " model.train(False)\n", 205 | " loss_numerator = acc_numerator = denominator = 0.0\n", 206 | " for xb, yb in test_loader:\n", 207 | " xb, yb = xb.to(device), yb.to(device)\n", 208 | " logits = model(xb)\n", 209 | " loss_numerator += F.cross_entropy(logits, yb).item() * len(yb)\n", 210 | " acc_numerator += (logits.argmax(-1).to(yb.dtype) == yb).to(torch.float32).sum()\n", 211 | " denominator += len(yb)\n", 212 | " return dict(loss=loss_numerator / denominator,\n", 213 | " acc=acc_numerator / denominator,\n", 214 | " num_updates=len(train_history))\n", 215 | " \n", 216 | "\n", 217 | "# finally, run training\n", 218 | "trainers = [threading.Thread(target=trainer_thread_method) for i in range(num_trainers)]\n", 219 | "for trainer in trainers:\n", 220 | " trainer.start()\n", 221 | " \n", 222 | "while len(train_history) < total_steps:\n", 223 | " need_to_eval.wait(), need_to_eval.clear()\n", 224 | " val_metrics = evaluate()\n", 225 | " val_history.append(val_metrics)\n", 226 | " \n", 227 | " clear_output(True)\n", 228 | " plt.figure(figsize=[12, 6])\n", 229 | " plt.subplot(1, 2, 1); plt.title('train loss'); plt.ylim(0, 10)\n", 230 | " plt.plot([info['loss'] for info 
in train_history])\n", 231 | " \n", 232 | " \n", 233 | " plt.subplot(1, 2, 2); plt.title('val accuracy'); plt.grid()\n", 234 | " plt.plot(*zip(*((info['num_updates'], info['acc']) for info in val_history)))\n", 235 | " plt.show()\n", 236 | "\n", 237 | "for trainer in trainers:\n", 238 | " trainer.join()" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 5, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "logs/delay1000ms_dmoe4outof64experts_seed1337.pkl\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "import os\n", 256 | "import pickle\n", 257 | "!mkdir -p logs\n", 258 | "\n", 259 | "num_files = len(os.listdir('logs'))\n", 260 | "fname = f'logs/delay{delay_ms}ms_dmoe{num_active_experts}outof{num_experts}experts_seed{seed}.pkl'\n", 261 | "print(fname)\n", 262 | "with open(fname, 'wb') as f_out:\n", 263 | " pickle.dump(dict(train_history=train_history, val_history=val_history), f_out)" 264 | ] 265 | } 266 | ], 267 | "metadata": { 268 | "kernelspec": { 269 | "display_name": "py38", 270 | "language": "python", 271 | "name": "py38" 272 | }, 273 | "language_info": { 274 | "codemirror_mode": { 275 | "name": "ipython", 276 | "version": 3 277 | }, 278 | "file_extension": ".py", 279 | "mimetype": "text/x-python", 280 | "name": "python", 281 | "nbconvert_exporter": "python", 282 | "pygments_lexer": "ipython3", 283 | "version": "3.8.1" 284 | } 285 | }, 286 | "nbformat": 4, 287 | "nbformat_minor": 2 288 | } 289 | -------------------------------------------------------------------------------- /experiments/convergence/convergence_mnist_64workers_1000ms_seed1337_largeffn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import time\n", 10 | "import random\n", 11 | "import threading\n", 12 | "from functools import partial\n", 13 | "\n", 14 | "import numpy as np\n", 15 | "\n", 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import torch.nn.functional as F\n", 19 | "from torchvision import datasets, transforms\n", 20 | "from tqdm import tqdm\n", 21 | "\n", 22 | "from IPython.display import clear_output\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "%matplotlib inline\n", 25 | "\n", 26 | "from dmoe_emulator import EmulatedDMoE, get_non_expert_params" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "seed = 1337\n", 36 | "torch.manual_seed(seed)\n", 37 | "np.random.seed(seed)\n", 38 | "random.seed(seed)\n", 39 | "\n", 40 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 41 | "layer_dim = 1024\n", 42 | "num_blocks = 4\n", 43 | "\n", 44 | "batch_size = 4\n", 45 | "num_trainers = 64\n", 46 | "\n", 47 | "delay_ms = 1000\n", 48 | "\n", 49 | "eval_interval = 1024\n", 50 | "total_steps = eval_interval * 20\n", 51 | "update_every_steps = 10\n", 52 | "\n", 53 | "in_features = 28 ** 2\n", 54 | "num_classes = 10" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [ 62 | { 63 | "name": "stderr", 64 | "output_type": "stream", 65 | "text": [ 66 | "\r", 67 | " 0%| | 0/20480 [00:00" 118 | ] 119 | }, 120 | "metadata": { 121 | "needs_background": "light" 122 | }, 123 | "output_type": "display_data" 124 | }, 125 | { 126 | "name": "stderr", 127 | "output_type": "stream", 
128 | "text": [ 129 | "#20540\tloss=0.051598\tdelay=50: 20540it [15:24, 19.05it/s]" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "def trainer_thread_method():\n", 135 | " \"\"\" train model on batches, emulate network latency \"\"\"\n", 136 | " train_loader = torch.utils.data.DataLoader(\n", 137 | " datasets.MNIST('../data', train=True, download=True,\n", 138 | " transform=transforms.Compose([\n", 139 | " transforms.ToTensor(),\n", 140 | " transforms.Normalize((0.1307,), (0.3081,)),\n", 141 | " transforms.Lambda(lambda x: x.view(-1))\n", 142 | " ])),\n", 143 | " batch_size=batch_size, shuffle=True,\n", 144 | " )\n", 145 | "\n", 146 | " while True:\n", 147 | " for xb, yb in train_loader:\n", 148 | " xb, yb = xb.to(device), yb.to(device)\n", 149 | "\n", 150 | " with lock_model:\n", 151 | " model.train(True)\n", 152 | " initial_step_index = len(train_history)\n", 153 | " logits = model(xb)\n", 154 | " loss = F.cross_entropy(logits, yb)\n", 155 | " \n", 156 | " opt.zero_grad()\n", 157 | " loss.backward()\n", 158 | " grads = [param.grad.clone() if param.grad is not None else None\n", 159 | " for param in non_expert_params]\n", 160 | "\n", 161 | " emulate_latency()\n", 162 | "\n", 163 | " with lock_model:\n", 164 | " model.train(True)\n", 165 | " opt.zero_grad()\n", 166 | " for param, grad in zip(non_expert_params, grads):\n", 167 | " param.grad[...] = grad\n", 168 | " opt.step()\n", 169 | " train_history.append(dict(\n", 170 | " loss=loss.item(), \n", 171 | " delay_steps=len(train_history) - initial_step_index,\n", 172 | " ))\n", 173 | " progress.desc = f'#{len(train_history)}\\tloss={loss.item():4f}\\tdelay={train_history[-1][\"delay_steps\"]}'\n", 174 | " progress.update(1)\n", 175 | " \n", 176 | " if len(train_history) % eval_interval == 0 or len(train_history) >= total_steps:\n", 177 | " need_to_eval.set()\n", 178 | " if len(train_history) >= total_steps:\n", 179 | " return\n", 180 | " \n", 181 | "\n", 182 | "def evaluate():\n", 183 | " test_loader = torch.utils.data.DataLoader(\n", 184 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 185 | " transforms.ToTensor(),\n", 186 | " transforms.Normalize((0.1307,), (0.3081,)),\n", 187 | " transforms.Lambda(lambda x: x.view(-1))\n", 188 | " ])),\n", 189 | " batch_size=batch_size, num_workers=4, pin_memory=True,\n", 190 | " )\n", 191 | " with lock_model, torch.no_grad():\n", 192 | " model.train(False)\n", 193 | " loss_numerator = acc_numerator = denominator = 0.0\n", 194 | " for xb, yb in test_loader:\n", 195 | " xb, yb = xb.to(device), yb.to(device)\n", 196 | " logits = model(xb)\n", 197 | " loss_numerator += F.cross_entropy(logits, yb).item() * len(yb)\n", 198 | " acc_numerator += (logits.argmax(-1).to(yb.dtype) == yb).to(torch.float32).sum()\n", 199 | " denominator += len(yb)\n", 200 | " return dict(loss=loss_numerator / denominator,\n", 201 | " acc=acc_numerator / denominator,\n", 202 | " num_updates=len(train_history))\n", 203 | " \n", 204 | "\n", 205 | "# finally, run training\n", 206 | "trainers = [threading.Thread(target=trainer_thread_method) for i in range(num_trainers)]\n", 207 | "for trainer in trainers:\n", 208 | " trainer.start()\n", 209 | " \n", 210 | "while len(train_history) < total_steps:\n", 211 | " need_to_eval.wait(), need_to_eval.clear()\n", 212 | " val_metrics = evaluate()\n", 213 | " val_history.append(val_metrics)\n", 214 | " \n", 215 | " clear_output(True)\n", 216 | " plt.figure(figsize=[12, 6])\n", 217 | " plt.subplot(1, 2, 1); plt.title('train loss'); plt.ylim(0, 10)\n", 218 | " 
plt.plot([info['loss'] for info in train_history])\n", 219 | " \n", 220 | " \n", 221 | " plt.subplot(1, 2, 2); plt.title('val accuracy'); plt.grid()\n", 222 | " plt.plot(*zip(*((info['num_updates'], info['acc']) for info in val_history)))\n", 223 | " plt.show()\n", 224 | "\n", 225 | "for trainer in trainers:\n", 226 | " trainer.join()" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 5, 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "name": "stdout", 236 | "output_type": "stream", 237 | "text": [ 238 | "logs/delay1000ms_ffn1024_seed1337.pkl\n" 239 | ] 240 | }, 241 | { 242 | "name": "stderr", 243 | "output_type": "stream", 244 | "text": [ 245 | "\r", 246 | "#20543\tloss=0.018364\tdelay=72: 20543it [15:40, 19.05it/s]" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "import os\n", 252 | "import pickle\n", 253 | "!mkdir -p logs\n", 254 | "\n", 255 | "num_files = len(os.listdir('logs'))\n", 256 | "fname = f'logs/delay{delay_ms}ms_ffn{layer_dim}_seed{seed}.pkl'\n", 257 | "print(fname)\n", 258 | "with open(fname, 'wb') as f_out:\n", 259 | " pickle.dump(dict(train_history=train_history, val_history=val_history), f_out)" 260 | ] 261 | } 262 | ], 263 | "metadata": { 264 | "kernelspec": { 265 | "display_name": "py38", 266 | "language": "python", 267 | "name": "py38" 268 | }, 269 | "language_info": { 270 | "codemirror_mode": { 271 | "name": "ipython", 272 | "version": 3 273 | }, 274 | "file_extension": ".py", 275 | "mimetype": "text/x-python", 276 | "name": "python", 277 | "nbconvert_exporter": "python", 278 | "pygments_lexer": "ipython3", 279 | "version": "3.8.1" 280 | } 281 | }, 282 | "nbformat": 4, 283 | "nbformat_minor": 2 284 | } 285 | -------------------------------------------------------------------------------- /experiments/convergence/convergence_mnist_fail01_64workers_1000ms_seed1338_dmoe1024x4_cpu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "env: CUDA_VISIBLE_DEVICES=\n", 13 | "env: OMP_NUM_THREADS=48\n", 14 | "env: MKL_NUM_THREADS=48\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "%env CUDA_VISIBLE_DEVICES=\n", 20 | "%env OMP_NUM_THREADS=48\n", 21 | "%env MKL_NUM_THREADS=48\n", 22 | "import time\n", 23 | "import random\n", 24 | "import threading\n", 25 | "from functools import partial\n", 26 | "\n", 27 | "import numpy as np\n", 28 | "\n", 29 | "import torch\n", 30 | "import torch.nn as nn\n", 31 | "import torch.nn.functional as F\n", 32 | "from torchvision import datasets, transforms\n", 33 | "from tqdm import tqdm\n", 34 | "\n", 35 | "from IPython.display import clear_output\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "%matplotlib inline\n", 38 | "\n", 39 | "from faulty_dmoe_emulator import EmulatedFaultyDMoE, get_non_expert_params" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "seed = 1338\n", 49 | "torch.manual_seed(seed)\n", 50 | "np.random.seed(seed)\n", 51 | "random.seed(seed)\n", 52 | "\n", 53 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 54 | "layer_dim = 512\n", 55 | "num_blocks = 4\n", 56 | "num_experts = 1024\n", 57 | "num_active_experts = 4\n", 58 | "\n", 59 | "batch_size = 8\n", 60 | "num_trainers = 64\n", 61 | "failure_rate = 0.1\n", 62 | "\n", 63 | "delay_ms = 1000\n", 64 | "\n", 65 | 
"eval_interval = 1024\n", 66 | "total_steps = eval_interval * 35\n", 67 | "update_every_steps = 10\n", 68 | "\n", 69 | "in_features = 28 ** 2\n", 70 | "num_classes = 10" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stderr", 80 | "output_type": "stream", 81 | "text": [ 82 | "\r", 83 | " 0%| | 0/35840 [00:00" 145 | ] 146 | }, 147 | "metadata": { 148 | "needs_background": "light" 149 | }, 150 | "output_type": "display_data" 151 | }, 152 | { 153 | "name": "stderr", 154 | "output_type": "stream", 155 | "text": [ 156 | "#35903\tloss=0.003334\tdelay=33: 35903it [17:10:38, 8.16it/s]" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "def trainer_thread_method():\n", 162 | " \"\"\" train model on batches, emulate network latency \"\"\"\n", 163 | " train_loader = torch.utils.data.DataLoader(\n", 164 | " datasets.MNIST('../data', train=True, download=True,\n", 165 | " transform=transforms.Compose([\n", 166 | " transforms.ToTensor(),\n", 167 | " transforms.Normalize((0.1307,), (0.3081,)),\n", 168 | " transforms.Lambda(lambda x: x.view(-1))\n", 169 | " ])),\n", 170 | " batch_size=batch_size, shuffle=True,\n", 171 | " )\n", 172 | "\n", 173 | " while True:\n", 174 | " for xb, yb in train_loader:\n", 175 | " xb, yb = xb.to(device), yb.to(device)\n", 176 | "\n", 177 | " with lock_model:\n", 178 | " model.train(True)\n", 179 | " initial_step_index = len(train_history)\n", 180 | " logits = model(xb)\n", 181 | " loss = F.cross_entropy(logits, yb)\n", 182 | " \n", 183 | " opt.zero_grad()\n", 184 | " loss.backward()\n", 185 | " grads = [param.grad.clone() if param.grad is not None else None\n", 186 | " for param in non_expert_params]\n", 187 | "\n", 188 | " emulate_latency()\n", 189 | "\n", 190 | " with lock_model:\n", 191 | " model.train(True)\n", 192 | " opt.zero_grad()\n", 193 | " for param, grad in zip(non_expert_params, grads):\n", 194 | " param.grad[...] 
= grad\n", 195 | " opt.step()\n", 196 | " train_history.append(dict(\n", 197 | " loss=loss.item(), \n", 198 | " delay_steps=len(train_history) - initial_step_index,\n", 199 | " ))\n", 200 | " progress.desc = f'#{len(train_history)}\\tloss={loss.item():4f}\\tdelay={train_history[-1][\"delay_steps\"]}'\n", 201 | " progress.update(1)\n", 202 | " \n", 203 | " if len(train_history) % eval_interval == 0 or len(train_history) >= total_steps:\n", 204 | " need_to_eval.set()\n", 205 | " if len(train_history) >= total_steps:\n", 206 | " return\n", 207 | " \n", 208 | "\n", 209 | "def evaluate():\n", 210 | " test_loader = torch.utils.data.DataLoader(\n", 211 | " datasets.MNIST('../data', train=False, transform=transforms.Compose([\n", 212 | " transforms.ToTensor(),\n", 213 | " transforms.Normalize((0.1307,), (0.3081,)),\n", 214 | " transforms.Lambda(lambda x: x.view(-1))\n", 215 | " ])),\n", 216 | " batch_size=batch_size, num_workers=0, pin_memory=True,\n", 217 | " )\n", 218 | " with lock_model, torch.no_grad():\n", 219 | " model.train(False)\n", 220 | " loss_numerator = acc_numerator = denominator = 0.0\n", 221 | " for xb, yb in test_loader:\n", 222 | " xb, yb = xb.to(device), yb.to(device)\n", 223 | " logits = model(xb)\n", 224 | " loss_numerator += F.cross_entropy(logits, yb).item() * len(yb)\n", 225 | " acc_numerator += (logits.argmax(-1).to(yb.dtype) == yb).to(torch.float32).sum()\n", 226 | " denominator += len(yb)\n", 227 | " return dict(loss=loss_numerator / denominator,\n", 228 | " acc=acc_numerator / denominator,\n", 229 | " num_updates=len(train_history))\n", 230 | " \n", 231 | "\n", 232 | "# finally, run training\n", 233 | "trainers = [threading.Thread(target=trainer_thread_method) for i in range(num_trainers)]\n", 234 | "for trainer in trainers:\n", 235 | " trainer.start()\n", 236 | " \n", 237 | "while len(train_history) < total_steps:\n", 238 | " need_to_eval.wait(), need_to_eval.clear()\n", 239 | " val_metrics = evaluate()\n", 240 | " val_history.append(val_metrics)\n", 241 | " \n", 242 | " clear_output(True)\n", 243 | " plt.figure(figsize=[12, 6])\n", 244 | " plt.subplot(1, 2, 1); plt.title('train loss'); plt.ylim(0, 10)\n", 245 | " plt.plot([info['loss'] for info in train_history])\n", 246 | " \n", 247 | " \n", 248 | " plt.subplot(1, 2, 2); plt.title('val accuracy'); plt.grid()\n", 249 | " plt.plot(*zip(*((info['num_updates'], info['acc']) for info in val_history)))\n", 250 | " plt.show()\n", 251 | "\n", 252 | "for trainer in trainers:\n", 253 | " trainer.join()" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 5, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "logs/gpu_delay1000ms_failrate0.1_dmoe4outof1024experts_seed1338.pkl\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "import os\n", 271 | "import pickle\n", 272 | "os.system('mkdir -p logs')\n", 273 | "\n", 274 | "num_files = len(os.listdir('logs'))\n", 275 | "fname = f'logs/gpu_delay{delay_ms}ms_failrate0.1_dmoe{num_active_experts}outof{num_experts}experts_seed{seed}.pkl'\n", 276 | "print(fname)\n", 277 | "with open(fname, 'wb') as f_out:\n", 278 | " pickle.dump(dict(train_history=train_history, val_history=val_history), f_out)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [] 287 | } 288 | ], 289 | "metadata": { 290 | "kernelspec": { 291 | "display_name": "py38", 292 | "language": "python", 293 | "name": "py38" 294 | }, 295 | 
"language_info": { 296 | "codemirror_mode": { 297 | "name": "ipython", 298 | "version": 3 299 | }, 300 | "file_extension": ".py", 301 | "mimetype": "text/x-python", 302 | "name": "python", 303 | "nbconvert_exporter": "python", 304 | "pygments_lexer": "ipython3", 305 | "version": "3.8.1" 306 | } 307 | }, 308 | "nbformat": 4, 309 | "nbformat_minor": 2 310 | } 311 | -------------------------------------------------------------------------------- /experiments/convergence/dmoe_emulator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class EmulatedDMoE(nn.Module): 7 | def __init__(self, in_features, num_experts, num_active, update_every_inputs, update_every_steps, 8 | Expert, Optimizer): 9 | """ 10 | A local mixture of experts module that emulates the behavior of DMoE from Learning@home 11 | The emulation concerns two main aspects: 12 | * Individual experts are updated not by the trainer but automatically after accumulating enough gradients 13 | * Forward/backward pass takes 14 | 15 | 16 | Warning: experts should NOT be optimized manually! Use get_non_expert_parameters(model) for optimizer 17 | 18 | :param in_features: input dimension 19 | :param num_experts: total number of experts 20 | :param num_active: only *this many* experts with highest score participate in a computation 21 | :param update_every_inputs: automatically triggers an update for an expert after it was used 22 | on *this many* inputs; this counter resets after every update 23 | :param update_every_steps: automatic update happens at most *this many* steps after expert processed 24 | its first input; this counter resets after every update 25 | :param Expert: callable(dim)-> nn.Module that receives and returns vectors of this dimension 26 | :param Optimizer: callable(Sequence[nn.Parameter]) -> torch.optim.Optimizer 27 | """ 28 | super().__init__() 29 | self.gating_pre_normalize = nn.LayerNorm(in_features) 30 | self.expert_keys = nn.Parameter(torch.randn(in_features, num_experts)) 31 | 32 | self.experts = nn.ModuleList([Expert(in_features) for _ in range(num_experts)]) 33 | self.expert_optimizers = {expert: Optimizer(expert.parameters()) for expert in self.experts} 34 | self.register_buffer('expert_inputs_since_update', torch.zeros(num_experts, dtype=torch.int64)) 35 | self.register_buffer('expert_steps_since_first_input', torch.zeros(num_experts, dtype=torch.int64)) 36 | 37 | self.num_active = num_active 38 | self.update_every_inputs = update_every_inputs 39 | self.update_every_steps = update_every_steps 40 | 41 | def forward(self, input): 42 | assert len(input.shape) == 2 43 | batch_size = len(input) 44 | if self.training: 45 | self.maybe_update_experts() 46 | 47 | gating_logits = self.gating_pre_normalize(input) @ F.normalize(self.expert_keys, dim=-1) 48 | chosen_ids = torch.argsort(gating_logits, dim=-1, descending=True)[..., :self.num_active] 49 | 50 | outputs = [] 51 | for i in range(batch_size): 52 | chosen_experts = [self.experts[chosen_id] for chosen_id in chosen_ids[i]] 53 | 54 | weights = F.softmax(gating_logits[i][chosen_ids[i]], dim=-1) 55 | expert_outputs = torch.stack([expert(input[i]) for expert in chosen_experts], dim=-1) 56 | output = expert_outputs @ weights 57 | outputs.append(output) 58 | 59 | outputs = torch.stack(outputs, dim=0) 60 | 61 | 62 | # update expert usage counts 63 | if self.training: 64 | self.expert_inputs_since_update.scatter_add_( 65 | 0, chosen_ids.reshape(-1), 66 | 
torch.ones_like(chosen_ids, device=self.expert_inputs_since_update.device).view(-1)) 67 | self.expert_steps_since_first_input += (self.expert_inputs_since_update > 0).to(torch.int64) 68 | return outputs 69 | 70 | def maybe_update_experts(self): 71 | for i, expert in enumerate(self.experts): 72 | if self.expert_inputs_since_update[i] >= self.update_every_inputs or \ 73 | self.expert_steps_since_first_input[i] >= self.update_every_steps: 74 | self.expert_optimizers[expert].step() 75 | self.expert_optimizers[expert].zero_grad() 76 | self.expert_inputs_since_update[i] = 0 77 | self.expert_steps_since_first_input[i] = 0 78 | 79 | def get_non_expert_params(model): 80 | expert_params = set(param for module in model.modules() if isinstance(module, EmulatedDMoE) 81 | for param in module.parameters()) 82 | return [param for param in model.parameters() if param not in expert_params] 83 | 84 | -------------------------------------------------------------------------------- /experiments/convergence/faulty_dmoe_emulator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class EmulatedFaultyDMoE(nn.Module): 7 | def __init__(self, in_features, num_experts, num_active, update_every_inputs, update_every_steps, failure_rate, 8 | Expert, Optimizer): 9 | """ 10 | A local mixture of experts module that emulates the behavior of DMoE from Learning@home 11 | The emulation concerns two main aspects: 12 | * Individual experts are updated not by the trainer but automatically after accumulating enough gradients 13 | * Forward/backward pass takes 14 | 15 | 16 | Warning: experts should NOT be optimized manually! Use get_non_expert_parameters(model) for optimizer 17 | 18 | :param in_features: input dimension 19 | :param num_experts: total number of experts 20 | :param num_active: only *this many* experts with highest score participate in a computation 21 | :param update_every_inputs: automatically triggers an update for an expert after it was used 22 | on *this many* inputs; this counter resets after every update 23 | :param update_every_steps: automatic update happens at most *this many* steps after expert processed 24 | its first input; this counter resets after every update 25 | :param Expert: callable(dim)-> nn.Module that receives and returns vectors of this dimension 26 | :param Optimizer: callable(Sequence[nn.Parameter]) -> torch.optim.Optimizer 27 | """ 28 | super().__init__() 29 | self.gating_pre_normalize = nn.LayerNorm(in_features) 30 | self.expert_keys = nn.Parameter(torch.randn(in_features, num_experts)) 31 | 32 | self.experts = nn.ModuleList([Expert(in_features) for _ in range(num_experts)]) 33 | self.expert_optimizers = {expert: Optimizer(expert.parameters()) for expert in self.experts} 34 | self.register_buffer('expert_inputs_since_update', torch.zeros(num_experts, dtype=torch.int64)) 35 | self.register_buffer('expert_steps_since_first_input', torch.zeros(num_experts, dtype=torch.int64)) 36 | 37 | self.num_active = num_active 38 | self.update_every_inputs = update_every_inputs 39 | self.update_every_steps = update_every_steps 40 | self.failure_rate = failure_rate 41 | 42 | def forward(self, input): 43 | assert len(input.shape) == 2 44 | batch_size = len(input) 45 | if self.training: 46 | self.maybe_update_experts() 47 | 48 | gating_logits = self.gating_pre_normalize(input) @ F.normalize(self.expert_keys, dim=-1) 49 | if self.failure_rate != 0: 50 | gating_logits = 
torch.where(torch.rand_like(gating_logits) < self.failure_rate, 51 | gating_logits - float('inf'), gating_logits) 52 | chosen_ids = torch.argsort(gating_logits, dim=-1, descending=True)[..., :self.num_active] 53 | 54 | outputs = [] 55 | for i in range(batch_size): 56 | chosen_experts = [self.experts[chosen_id] for chosen_id in chosen_ids[i]] 57 | 58 | weights = F.softmax(gating_logits[i][chosen_ids[i]], dim=-1) 59 | expert_outputs = torch.stack([expert(input[i]) for expert in chosen_experts], dim=-1) 60 | output = expert_outputs @ weights 61 | outputs.append(output) 62 | 63 | outputs = torch.stack(outputs, dim=0) 64 | 65 | 66 | # update expert usage counts 67 | if self.training: 68 | self.expert_inputs_since_update.scatter_add_( 69 | 0, chosen_ids.reshape(-1), 70 | torch.ones_like(chosen_ids, device=self.expert_inputs_since_update.device).view(-1)) 71 | self.expert_steps_since_first_input += (self.expert_inputs_since_update > 0).to(torch.int64) 72 | return outputs 73 | 74 | def maybe_update_experts(self): 75 | for i, expert in enumerate(self.experts): 76 | if self.expert_inputs_since_update[i] >= self.update_every_inputs or \ 77 | self.expert_steps_since_first_input[i] >= self.update_every_steps: 78 | self.expert_optimizers[expert].step() 79 | self.expert_optimizers[expert].zero_grad() 80 | self.expert_inputs_since_update[i] = 0 81 | self.expert_steps_since_first_input[i] = 0 82 | 83 | def get_non_expert_params(model): 84 | expert_params = set(param for module in model.modules() if isinstance(module, EmulatedFaultyDMoE) 85 | for param in module.parameters()) 86 | return [param for param in model.parameters() if param not in expert_params] 87 | 88 | -------------------------------------------------------------------------------- /experiments/throughput/baseline_throughput.py: -------------------------------------------------------------------------------- 1 | import time 2 | from argparse import ArgumentParser 3 | from functools import partial 4 | from itertools import chain, repeat 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | from layers import name_to_block, name_to_input 11 | 12 | 13 | class ModelParallelNetwork(nn.Module): 14 | def __init__(self, hid_dim, block_factory, gpus, layers_per_gpu): 15 | super().__init__() 16 | 17 | self.gpus = gpus 18 | self.blocks = nn.ModuleList( 19 | [nn.Sequential(*(torch.jit.script(block_factory(hid_dim)) for _ in range(layers_per_gpu))).to(device) for device in gpus] 20 | ) 21 | 22 | def forward(self, x, ping=None): 23 | for device, layer_list in zip(self.gpus, self.blocks): 24 | x = x.to(device, non_blocking=True) 25 | x = layer_list(x) 26 | return x 27 | 28 | 29 | class DummyCrowdsourcedNetwork(nn.Module): 30 | def __init__(self, hid_dim, block_factory, gpus, layers_per_gpu): 31 | super().__init__() 32 | self.gpu_devices = list(chain.from_iterable(repeat(gpus, layers_per_gpu))) 33 | self.layers = nn.ModuleList( 34 | [torch.jit.script(block_factory(hid_dim)).to(device) for device in self.gpu_devices] 35 | ) 36 | 37 | def forward(self, x, ping): 38 | for device, layer in zip(self.gpu_devices, self.layers): 39 | x = x.to(device, non_blocking=True) 40 | x = layer(x) 41 | # emulate network lag 42 | time.sleep(ping * np.random.weibull(1)) 43 | return x 44 | 45 | 46 | def measure_perf(model_class, batches_for_latency, batches_for_throughput, throughput_runs, ping, input_factory, batch_size, hid_dim, 47 | **kwargs): 48 | m = model_class(hid_dim=hid_dim, **kwargs) 49 | time_per_batch = [] 50 | z = input_factory(batch_size, 
hid_dim).pin_memory() 51 | out_buf = input_factory(batch_size, hid_dim).pin_memory() 52 | with torch.no_grad(): 53 | # latency: avg time to obtain a result per single processed batch 54 | for _ in range(batches_for_latency + 1): 55 | start = time.time() 56 | output = m(z, ping=ping) 57 | out_buf.copy_(output, non_blocking=True) 58 | torch.cuda.synchronize() 59 | time_per_batch.append(time.time() - start) 60 | # throughput: examples/sec when results are asynchronous 61 | throughputs = [] 62 | for run in range(throughput_runs): 63 | start = time.time() 64 | for _ in range(batches_for_throughput): 65 | output = m(z, ping=ping) 66 | out_buf.copy_(output, non_blocking=True) 67 | torch.cuda.synchronize() 68 | throughputs.append(batch_size * batches_for_throughput / (time.time() - start)) 69 | return np.mean(time_per_batch[1:]), np.std(time_per_batch[1:], ddof=1), np.mean(throughputs), np.std(throughputs, ddof=1) 70 | 71 | 72 | def main(args): 73 | np.random.seed(0) 74 | torch.manual_seed(0) 75 | measure_func = partial(measure_perf, batches_for_latency=args.batches_for_latency, batches_for_throughput=args.batches_for_throughput, 76 | throughput_runs=args.throughput_runs, gpus=args.gpus, layers_per_gpu=args.layers_per_gpu, hid_dim=args.hid_dim, 77 | block_factory=name_to_block[args.block_type], batch_size=args.batch_size, 78 | input_factory=name_to_input[args.block_type]) 79 | 80 | avg_latency, std_latency, avg_throughput, std_throughput = measure_func(ModelParallelNetwork, ping=0) 81 | print(f'ModelParallel (fast, ping=0.00):\t{avg_latency:.2f}±{std_latency:.2f}\t{avg_throughput:.2f}±{std_throughput:.2f}') 82 | 83 | for ping in np.linspace(0, args.max_ping, args.linspace_points): 84 | avg_latency, std_latency, avg_throughput, std_throughput = measure_func(DummyCrowdsourcedNetwork, ping=ping) 85 | print(f'ModelParallel (slow, ping={ping:.2f}):\t{avg_latency:.2f}±{std_latency:.2f}\t{avg_throughput:.2f}±{std_throughput:.2f}') 86 | 87 | 88 | if __name__ == '__main__': 89 | parser = ArgumentParser() 90 | parser.add_argument('--hid-dim', type=int, default=1024) 91 | parser.add_argument('--batches-for-latency', type=int, default=10) 92 | parser.add_argument('--batches-for-throughput', type=int, default=100) 93 | parser.add_argument('--batch-size', type=int, default=2048) 94 | parser.add_argument('--throughput-runs', type=int, default=10) 95 | parser.add_argument('--max-ping', type=float, default=0.2) 96 | parser.add_argument('--linspace-points', type=int, default=10) 97 | parser.add_argument('--gpus', type=int, nargs='+', required=True) 98 | parser.add_argument('--layers-per-gpu', type=int, default=56) 99 | parser.add_argument('--block-type', choices=name_to_block.keys(), required=True) 100 | args = parser.parse_args() 101 | main(args) 102 | -------------------------------------------------------------------------------- /experiments/throughput/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | 5 | class FeedforwardBlock(nn.Module): 6 | def __init__(self, hid_dim): 7 | super().__init__() 8 | self.layers = nn.Sequential( 9 | nn.Linear(hid_dim, 4 * hid_dim), 10 | nn.LayerNorm(4 * hid_dim), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(4 * hid_dim, 4 * hid_dim), 13 | nn.LayerNorm(4 * hid_dim), 14 | nn.ReLU(inplace=True), 15 | nn.Linear(4 * hid_dim, hid_dim), 16 | ) 17 | 18 | def forward(self, x): 19 | return x + self.layers(x) 20 | 21 | 22 | class TransformerEncoderLayer(nn.Module): 23 | """ 24 | A slight 
modification of torch.nn.TransformerEncoderLayer which allows for torch.jit scripting 25 | """ 26 | 27 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1): 28 | super().__init__() 29 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 30 | # Implementation of Feedforward model 31 | self.linear1 = nn.Linear(d_model, dim_feedforward) 32 | self.dropout = nn.Dropout(dropout) 33 | self.linear2 = nn.Linear(dim_feedforward, d_model) 34 | 35 | self.norm1 = nn.LayerNorm(d_model) 36 | self.norm2 = nn.LayerNorm(d_model) 37 | self.dropout1 = nn.Dropout(dropout) 38 | self.dropout2 = nn.Dropout(dropout) 39 | 40 | self.activation = torch.nn.GELU() 41 | 42 | def forward(self, src): 43 | src.transpose_(0, 1) 44 | src2 = self.self_attn(src, src, src)[0] 45 | src = src + self.dropout1(src2) 46 | src = self.norm1(src) 47 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 48 | src = src + self.dropout2(src2) 49 | src = self.norm2(src) 50 | src.transpose_(0, 1) 51 | return src 52 | 53 | 54 | name_to_block = {'ffn': lambda hid_dim: FeedforwardBlock(hid_dim), 55 | 'transformer': lambda hid_dim: TransformerEncoderLayer(hid_dim, nhead=16)} 56 | name_to_input = {'ffn': lambda batch_size, hid_dim: torch.empty((batch_size, hid_dim)), 57 | 'transformer': lambda batch_size, hid_dim: torch.empty((batch_size, 512, hid_dim))} 58 | -------------------------------------------------------------------------------- /experiments/throughput/rpc_throughput.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from argparse import ArgumentParser 4 | from functools import partial 5 | from itertools import chain, repeat 6 | 7 | import numpy as np 8 | import torch 9 | import torch.distributed.rpc as rpc 10 | import torch.nn as nn 11 | from layers import name_to_block, name_to_input 12 | from torch.distributed.rpc import rpc_sync 13 | from tqdm import trange 14 | 15 | 16 | class BlockWorker(nn.Module): 17 | def __init__(self, hid_dim, block_type): 18 | super().__init__() 19 | self.block = name_to_block[block_type](hid_dim).cuda() 20 | 21 | def forward(self, x): 22 | return self.block(x.cuda()).cpu() 23 | 24 | 25 | def _call_method(method, rref, *args, **kwargs): 26 | return method(rref.local_value(), *args, **kwargs) 27 | 28 | 29 | def _remote_method(method, rref, *args, **kwargs): 30 | args = [method, rref] + list(args) 31 | return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs) 32 | 33 | 34 | class ModelParallelRPC(nn.Module): 35 | def __init__(self, hid_dim, block_type, workers, layers_per_gpu): 36 | super().__init__() 37 | self.workers = list(chain.from_iterable(repeat(workers, layers_per_gpu))) 38 | self.layer_rrefs = [rpc.remote(worker, BlockWorker, args=(hid_dim, block_type)) for worker in self.workers] 39 | 40 | def forward(self, x): 41 | for layer_rref in self.layer_rrefs: 42 | x = _remote_method(BlockWorker.forward, layer_rref, x) 43 | return x 44 | 45 | 46 | def measure_perf(model_class, batches_for_latency, batches_for_throughput, throughput_runs, input_factory, 47 | batch_size, hid_dim, 48 | **kwargs): 49 | m = model_class(hid_dim=hid_dim, **kwargs) 50 | time_per_batch = [] 51 | z = input_factory(batch_size, hid_dim) 52 | with torch.no_grad(): 53 | # throughput: examples/sec when results are asynchronous 54 | throughputs = [] 55 | for run in trange(throughput_runs): 56 | start = time.time() 57 | for _ in range(batches_for_throughput): 58 | output = m(z) 59 | 
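# examples per second for this run: synchronously processed batches divided by elapsed wall-clock time
# (unlike baseline_throughput.py, this script reports throughput only; time_per_batch is never filled)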
throughputs.append(batch_size * batches_for_throughput / (time.time() - start)) 60 | print(throughputs) 61 | return np.mean(throughputs), np.std(throughputs, ddof=1) 62 | 63 | 64 | def main(args): 65 | np.random.seed(0) 66 | torch.manual_seed(0) 67 | measure_func = partial(measure_perf, batches_for_latency=args.batches_for_latency, 68 | batches_for_throughput=args.batches_for_throughput, 69 | throughput_runs=args.throughput_runs, 70 | layers_per_gpu=args.layers_per_gpu, 71 | hid_dim=args.hid_dim, 72 | block_type=args.block_type, batch_size=args.batch_size, 73 | input_factory=name_to_input[args.block_type], 74 | workers=[f'worker{rank}' for rank in range(1, args.world_size)]) 75 | 76 | avg_throughput, std_throughput = measure_func(ModelParallelRPC) 77 | print(f'ModelParallel:\t{avg_throughput:.2f}±{std_throughput:.2f}') 78 | 79 | 80 | if __name__ == '__main__': 81 | parser = ArgumentParser() 82 | parser.add_argument('--hid-dim', type=int, default=1024) 83 | parser.add_argument('--batches-for-latency', type=int, default=10) 84 | parser.add_argument('--batches-for-throughput', type=int, default=100) 85 | parser.add_argument('--batch-size', type=int, default=2048) 86 | parser.add_argument('--throughput-runs', type=int, default=10) 87 | parser.add_argument('--rank', type=int, required=True) 88 | parser.add_argument('--world-size', type=int, required=True) 89 | parser.add_argument('--layers-per-gpu', type=int, default=56) 90 | parser.add_argument('--block-type', choices=name_to_block.keys(), required=True) 91 | args = parser.parse_args() 92 | rpc.init_rpc(f"worker{args.rank}", rank=args.rank, world_size=args.world_size) 93 | if args.rank == 0: 94 | main(args) 95 | rpc.shutdown() 96 | -------------------------------------------------------------------------------- /experiments/throughput/throughput_client.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from argparse import ArgumentParser 3 | from functools import partial 4 | from itertools import chain 5 | from multiprocessing import Pool 6 | from time import time, sleep 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | from layers import name_to_block, name_to_input 13 | 14 | sys.path.append('../../') 15 | import lib.client 16 | 17 | 18 | class ExpertsWithLatency(nn.Module): 19 | def __init__(self, experts): 20 | super().__init__() 21 | self.experts = nn.Sequential(*experts) 22 | 23 | def forward(self, x, ping): 24 | for layer in self.experts: 25 | x = layer(x) 26 | sleep(ping * np.random.weibull(1)) 27 | return x 28 | 29 | 30 | @torch.no_grad() 31 | def measure_perf(ping, model, x, num_batches): 32 | latencies = [] 33 | for batch in range(num_batches + 1): 34 | start = time() 35 | output = model(x, ping=ping) 36 | latencies.append(time() - start) 37 | return latencies[1:] 38 | 39 | 40 | def main(args): 41 | np.random.seed(0) 42 | torch.manual_seed(0) 43 | remote_experts = [] 44 | for address in args.hosts: 45 | host, port = address.split(':') 46 | experts_at_host = [lib.RemoteExpert(f'expert{i}', host=host, port=int(port)) for i in range(args.layers_per_gpu)] 47 | remote_experts.append(experts_at_host) 48 | experts = list(chain.from_iterable(zip(*remote_experts))) 49 | experts_with_delay = ExpertsWithLatency(experts) 50 | 51 | input = name_to_input[args.block_type](args.batch_size, args.hid_dim) 52 | measure_func = partial(measure_perf, model=experts_with_delay, x=input, num_batches=args.batches_for_throughput) 53 | 54 | with Pool(args.jobs) as p: 55 | for 
ping in np.linspace(0, args.max_ping, args.linspace_points): 56 | latencies = measure_perf(ping, experts_with_delay, input, args.batches_for_latency) 57 | 58 | throughputs = [] 59 | for run in range(args.throughput_runs): 60 | processing_start = time() 61 | results = p.map(measure_func, [ping for _ in range(args.jobs)]) 62 | # assume that all processes were working synchronously, then number of processed examples per second 63 | # needs to be summed by number of processes 64 | throughput = args.jobs * args.batch_size * (args.batches_for_throughput + 1) / (time() - processing_start) 65 | throughputs.append(throughput) 66 | avg_latency, std_latency = np.mean(latencies), np.std(latencies, ddof=1) 67 | avg_throughput, std_throughput = np.mean(throughputs), np.std(throughputs, ddof=1) 68 | print(f'ModelParallel (ours, ping={ping:.2f}):\t{avg_latency:.2f}±{std_latency:.2f}\t{avg_throughput:.2f}±{std_throughput:.2f}') 69 | 70 | 71 | if __name__ == '__main__': 72 | parser = ArgumentParser() 73 | parser.add_argument('-j', '--jobs', type=int, required=True) 74 | parser.add_argument('--hosts', nargs='+', required=True) 75 | parser.add_argument('--hid-dim', type=int, default=1024) 76 | parser.add_argument('--batches-for-latency', type=int, default=10) 77 | parser.add_argument('--batches-for-throughput', type=int, default=100) 78 | parser.add_argument('--throughput-runs', type=int, default=10) 79 | parser.add_argument('--batch-size', type=int, default=2048) 80 | parser.add_argument('--linspace-points', type=int, default=10) 81 | parser.add_argument('--layers-per-gpu', type=int, default=56) 82 | parser.add_argument('--block-type', choices=name_to_block.keys(), required=True) 83 | parser.add_argument('--max-ping', type=float, default=0.2) 84 | args = parser.parse_args() 85 | main(args) 86 | -------------------------------------------------------------------------------- /experiments/throughput/throughput_server.py: -------------------------------------------------------------------------------- 1 | import multiprocessing.managers 2 | import sys 3 | from argparse import ArgumentParser 4 | 5 | import torch 6 | 7 | from layers import name_to_block 8 | 9 | sys.path.append('../../') 10 | import lib 11 | 12 | 13 | def main(args): 14 | inp_shape = (args.hid_dim,) if args.block_type == 'ffn' else (512, args.hid_dim) 15 | with multiprocessing.managers.SharedMemoryManager() as shm_manager, multiprocessing.Manager() as array_headers_manager: 16 | try: 17 | array_headers = array_headers_manager.dict() 18 | experts = {} 19 | for i in range(args.layers_per_gpu): 20 | expert = torch.jit.script(name_to_block[args.block_type](args.hid_dim)) 21 | experts[f'expert{i}'] = lib.ExpertBackend(name=f'expert{i}', 22 | expert=expert, opt=torch.optim.Adam(expert.parameters()), 23 | args_schema=(lib.BatchTensorProto(*inp_shape),), 24 | outputs_schema=lib.BatchTensorProto(*inp_shape), 25 | max_batch_size=args.max_batch_size, 26 | shm_manager=shm_manager, array_headers=array_headers, 27 | pool_size=8) 28 | 29 | lib.TesseractServer(None, experts, port=args.port, conn_handler_processes=args.handler_processes, 30 | sender_threads=4, device=torch.device('cuda', args.gpu), 31 | start=True) 32 | except KeyboardInterrupt: 33 | print('Finishing') 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = ArgumentParser() 38 | parser.add_argument('-a', '--handler-processes', type=int, default=256) 39 | parser.add_argument('-p', '--port', type=int, required=True) 40 | parser.add_argument('--hid-dim', type=int, default=1024) 41 | 
parser.add_argument('--max-batch-size', type=int, default=2048)
42 | 
43 |     parser.add_argument('--gpu', type=int, required=True)
44 |     parser.add_argument('--layers-per-gpu', type=int, default=56)
45 |     parser.add_argument('--block-type', choices=name_to_block.keys(), required=True)
46 |     args = parser.parse_args()
47 |     main(args)
48 | 
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .client import *
3 | from .server import *
4 | from .network import *
5 | 
--------------------------------------------------------------------------------
/lib/client/__init__.py:
--------------------------------------------------------------------------------
1 | from .gating_function import GatingFunction
2 | from .remote_expert import RemoteExpert
3 | 
--------------------------------------------------------------------------------
/lib/client/gating_function.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 | import multiprocessing.pool
3 | from functools import partial
4 | from typing import Tuple, List, Dict, Any
5 | 
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | 
10 | from .remote_expert import RemoteExpert
11 | from ..utils import nested_map, check_numpy, run_and_await_k
12 | 
13 | 
14 | class GatingFunction(nn.Module):
15 |     def __init__(self, *, in_features, grid_size: Tuple[int], network, num_workers=None,
16 |                  k_best, k_min=1, timeout_after_k_min=1.0, uid_prefix='', expert_padding=None):
17 |         super().__init__()
18 |         self.network, self.grid_size = network, grid_size
19 |         self.uid_prefix, self.expert_padding = uid_prefix, expert_padding
20 |         self.k_best, self.k_min, self.timeout_after_k_min = k_best, k_min, timeout_after_k_min
21 | 
22 |         self.thread_pool = mp.pool.ThreadPool(num_workers or k_best * 2)
23 |         self.proj = nn.Linear(in_features, sum(grid_size))  # jointly predict logits for all grid dimensions
24 | 
25 |     def forward(self, input: torch.Tensor, *args, **kwargs) -> Tuple[List[List[RemoteExpert]], torch.Tensor]:
26 |         """
27 |         Choose k best experts with beam search, then call chosen experts and average their outputs.
28 |         :param input: input features (any extra args/kwargs are forwarded to the chosen experts); each tensor has 0-th axis dedicated to batch (aka batch-first)
29 |         :return: averaged predictions of all experts that delivered on time
30 |         """
31 |         assert len(input.shape) == 2
32 | 
33 |         # 1.
compute scores and find most appropriate experts with beam search 34 | grid_scores = self.proj(input).split_with_sizes(self.grid_size, dim=-1) 35 | batch_experts = self.beam_search(grid_scores, self.k_best) 36 | # ^-- List[batch_size] of List[RemoteExpert] chosen for every input in batch 37 | 38 | # 2.1 call chosen experts (run them in background to save time) 39 | batch_outputs_async = [ 40 | self.thread_pool.apply_async(self._run_experts, 41 | args=[chosen_experts, input[i: i + 1], *(tensor[i: i + 1] for tensor in args)], 42 | kwds={key: tensor[i: i + 1] for key, tensor in kwargs.items()}) 43 | for i, chosen_experts in enumerate(batch_experts) 44 | ] 45 | 46 | # 2.2 compute *differentiable* logits for each expert 47 | batch_expert_logits = self._score_experts(grid_scores, batch_experts) 48 | # ^-- List[batch_size] of Dict[RemoteExpert, logit] before softmax for each active expert 49 | 50 | batch_outputs = [] 51 | for output_async, expert_logits in zip(batch_outputs_async, batch_expert_logits): 52 | expert_outputs: Dict[RemoteExpert, Any] = output_async.get() 53 | flat_experts, flat_outputs = zip(*expert_outputs.items()) 54 | 55 | # 3.1. normalize logits over only those experts that DID return output 56 | flat_logits = torch.stack([expert_logits[expert] for expert in flat_experts]) 57 | flat_weights = torch.softmax(flat_logits, dim=-1) 58 | 59 | # 3.2. average each output across experts 60 | average_outputs = nested_map( 61 | lambda *tensors: sum(x * weight for x, weight in zip(tensors, flat_weights)), *flat_outputs) 62 | 63 | batch_outputs.append(average_outputs) 64 | 65 | # 4. concatenate mixture outputs from individual experts 66 | return nested_map(lambda *tensors: torch.cat(tensors, dim=0), *batch_outputs) 67 | 68 | def beam_search(self, grid_scores: List[torch.Tensor], k_best: int, **kwargs) -> List[List[RemoteExpert]]: 69 | """ 70 | Find and return k best experts in the grid using (exact) beam search of the product space 71 | :param grid_scores: scores predicted for each dimension in the grid, 72 | :type grid_scores: a sequence of tensors of shape[batch_size, self.grid_size[i]] 73 | :param k_best: how many of the top experts participate in the computation 74 | :param kwargs: extra keyword parameters passed to self.network.first_k_active 75 | :returns: a list of *batch_size* lists that contain chosen experts for one sample 76 | each inner list contains RemoteExpert instances for *up to* k_best experts 77 | """ 78 | assert len(grid_scores) == len(self.grid_size) 79 | assert all(len(dim_scores.shape) == 2 for dim_scores in grid_scores) 80 | batch_size = len(grid_scores[0]) 81 | beam = np.array([[self.uid_prefix]] * batch_size, dtype=object) # [batch_size, up_to_beam_size] 82 | scores = np.zeros([batch_size, 1], dtype=np.float64) 83 | 84 | delimeters = np.array(self.network.UID_DELIMETER)[None, None, None] # pre-compute numpy array for fast concat 85 | 86 | for dim_index, dim_scores in enumerate(grid_scores): 87 | dim_scores = check_numpy(dim_scores) 88 | assert dim_scores.shape[-1] == self.grid_size[dim_index] 89 | 90 | # create all possible successsors from current beam 91 | dim_indices = np.arange(dim_scores.shape[1]).astype(str) 92 | new_candidates = beam[:, :, None] + delimeters + dim_indices[None, None, :] 93 | new_candidates = new_candidates.reshape([batch_size, -1]) 94 | 95 | new_scores = scores[:, :, None] + dim_scores[:, None, :] 96 | new_scores = new_scores.reshape([batch_size, -1]) 97 | 98 | # select k best candidates according to scores but only those that are still active 
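            # note: each first_k_active call below is dispatched to self.thread_pool, so the DHT lookups for
            # different batch rows overlap; rows with fewer than k_best live prefixes are padded with
            # self.expert_padding and -inf scores before the next grid dimension is expanded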
99 | new_order = np.argsort(- new_scores, axis=-1) 100 | top_alive_lookups = [ 101 | self.thread_pool.apply_async(self.network.first_k_active, args=(cands[order], k_best), kwds=kwargs) 102 | for cands, order in zip(new_candidates, new_order)] 103 | 104 | batch_cand_to_score = [ 105 | dict(zip(cands, cand_scores)) for cands, cand_scores in zip(new_candidates, new_scores)] 106 | 107 | top_alive_prefixes = [result.get() for result in top_alive_lookups] 108 | top_alive_scores = [list(map(cand_to_score.get, top_cands)) 109 | for cand_to_score, top_cands in zip(batch_cand_to_score, top_alive_prefixes)] 110 | 111 | # pad up to beam size 112 | beam = np.array([row + [self.expert_padding] * (k_best - len(row)) 113 | for row in top_alive_prefixes], dtype='object') 114 | scores = np.array([row + [-float('inf')] * (k_best - len(row)) 115 | for row in top_alive_scores], dtype='float32') 116 | 117 | unique_experts = self.network.get_experts(list(set( 118 | uid for row in beam for uid in row if uid != self.expert_padding))) 119 | unique_experts_by_uid = {expert.uid: expert for expert in unique_experts if expert != self.expert_padding} 120 | 121 | return [ 122 | [unique_experts_by_uid[uid] for uid in row if uid in unique_experts_by_uid] 123 | for row in beam] 124 | 125 | def _run_experts(self, experts: List[RemoteExpert], *args, **kwargs) -> Dict[RemoteExpert, torch.Tensor]: 126 | outputs = run_and_await_k([partial(expert, *args, **kwargs) for expert in experts], 127 | k=self.k_min, timeout_after_k=self.timeout_after_k_min) 128 | return {expert: output for expert, output in zip(experts, outputs) 129 | if not isinstance(output, BaseException)} 130 | 131 | def _score_experts(self, grid_scores: List[torch.Tensor], 132 | experts: List[List[RemoteExpert]]) -> List[Dict[RemoteExpert, torch.Tensor]]: 133 | flat_experts = [expert for row in experts for expert in row] 134 | flat_batch_indices = torch.tensor([i for i, row in enumerate(experts) 135 | for uid in range(len(row))]) 136 | 137 | grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64) 138 | for i, expert in enumerate(flat_experts): 139 | expert_indices = expert.uid[len(self.uid_prefix) + len(self.network.UID_DELIMETER):] 140 | expert_indices = list(map(int, expert_indices.split(self.network.UID_DELIMETER))) 141 | grid_indices[i] = expert_indices 142 | 143 | scores_per_dim = [ 144 | dim_scores[flat_batch_indices, dim_indices] if len(flat_batch_indices) else torch.zeros(0) 145 | for dim_scores, dim_indices in zip(grid_scores, grid_indices.T)] 146 | flat_scores = torch.sum(torch.stack(scores_per_dim, dim=0), dim=0) 147 | 148 | output_dicts = [dict() for _ in range(len(experts))] 149 | for batch_i, expert, score in zip(check_numpy(flat_batch_indices), 150 | flat_experts, flat_scores): 151 | output_dicts[batch_i][expert] = score 152 | 153 | return output_dicts 154 | -------------------------------------------------------------------------------- /lib/client/remote_expert.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..utils import nested_flatten, DUMMY, PytorchSerializer, nested_pack, nested_compare, Connection 7 | 8 | 9 | class RemoteExpert(nn.Module): 10 | """ 11 | A simple module that runs forward/backward of an expert hosted on a remote machine. 12 | Works seamlessly with pytorch autograd. 
(this is essentially a simple RPC function) 13 | 14 | Warning: RemoteExpert currently assumes that you provide it with correct input shapes. 15 | Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime. 16 | 17 | :param uid: unique expert identifier 18 | :param host: hostname where TesseractServer operates 19 | :param port: port to which TesseractServer listens 20 | """ 21 | 22 | def __init__(self, uid, host='127.0.0.1', port=8080): 23 | super().__init__() 24 | self.uid, self.host, self.port = uid, host, port 25 | self._info = None 26 | 27 | def forward(self, *args, **kwargs): 28 | assert len(kwargs) == len(self.info['keyword_names']), f"Keyword args should be {self.info['keyword_names']}" 29 | kwargs = {key: kwargs[key] for key in self.info['keyword_names']} 30 | # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors 31 | 32 | forward_inputs = (args, kwargs) 33 | 34 | if not nested_compare(forward_inputs, self.info['forward_schema']): 35 | raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?") 36 | 37 | flat_outputs = _RemoteModuleCall.apply(DUMMY, self.uid, self.host, self.port, *nested_flatten(forward_inputs)) 38 | # Note: we send DUMMY to prevent torch from excluding expert from backward if no other inputs require grad 39 | return nested_pack(flat_outputs, structure=self.info['outputs_schema']) 40 | 41 | @property 42 | def info(self): 43 | if self._info is None: 44 | connection = Connection.create(self.host, self.port) 45 | connection.send_raw('info', PytorchSerializer.dumps(self.uid)) 46 | self._info = PytorchSerializer.loads(connection.recv_message()[1]) 47 | return self._info 48 | 49 | def extra_repr(self): 50 | return f"uid={self.uid}, host={self.host}, port={self.port}" 51 | 52 | 53 | class _RemoteModuleCall(torch.autograd.Function): 54 | """ Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead. """ 55 | 56 | @staticmethod 57 | def forward(ctx, dummy: torch.Tensor, 58 | uid: str, host: str, port: int, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: 59 | # Note: *inputs are flattened input tensors that follow the expert's info['input_schema'] 60 | inputs = tuple(map(torch.Tensor.detach, inputs)) # detach to avoid pickling the computation graph 61 | ctx.uid, ctx.host, ctx.port = uid, host, port 62 | ctx.save_for_backward(*inputs) 63 | 64 | connection = Connection.create(ctx.host, ctx.port) 65 | connection.send_raw('fwd_', PytorchSerializer.dumps((ctx.uid, inputs))) 66 | rtype, msg = connection.recv_message() 67 | return tuple(PytorchSerializer.loads(msg)) # flattened expert outputs 68 | 69 | @staticmethod 70 | def backward(ctx, *grad_outputs) -> Tuple[Optional[torch.Tensor], ...]: 71 | connection = Connection.create(ctx.host, ctx.port) 72 | payload = tuple(nested_flatten((ctx.saved_tensors, grad_outputs))) 73 | connection.send_raw('bwd_', PytorchSerializer.dumps((ctx.uid, payload))) 74 | rtype, msg = connection.recv_message() 75 | grad_inputs = PytorchSerializer.loads(msg) 76 | return (DUMMY, None, None, None, *grad_inputs) 77 | -------------------------------------------------------------------------------- /lib/network/__init__.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import datetime 3 | import multiprocessing as mp 4 | from typing import Tuple, List, Optional 5 | 6 | from kademlia.network import Server 7 | 8 | from .. 
import run_in_background, repeated, SharedFuture, PickleSerializer, RemoteExpert 9 | 10 | 11 | class TesseractNetwork(mp.Process): 12 | UID_DELIMETER = '.' # splits expert uids over this delimeter 13 | HEARTBEAT_EXPIRATION = 120 # expert is inactive iff it fails to post timestamp for *this many seconds* 14 | make_key = "{}::{}".format 15 | 16 | def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False): 17 | super().__init__() 18 | self.port, self.initial_peers = port, initial_peers 19 | self._pipe, self.pipe = mp.Pipe(duplex=False) 20 | self.server = Server() 21 | if start: 22 | self.start() 23 | 24 | def run(self) -> None: 25 | loop = asyncio.new_event_loop() 26 | asyncio.set_event_loop(loop) 27 | loop.run_until_complete(self.server.listen(self.port)) 28 | loop.run_until_complete(self.server.bootstrap(self.initial_peers)) 29 | run_in_background(repeated(loop.run_forever)) 30 | 31 | while True: 32 | method, args, kwargs = self._pipe.recv() 33 | getattr(self, method)(*args, **kwargs) 34 | 35 | def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]: 36 | """ Find experts across DHT using their ids; Return a list of [RemoteExpert if found else None]""" 37 | future, _future = SharedFuture.make_pair() 38 | self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future))) 39 | return future.result() 40 | 41 | def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture): 42 | loop = asyncio.get_event_loop() 43 | lookup_futures = [asyncio.run_coroutine_threadsafe( 44 | self.server.get(self.make_key('expert', uid)), loop) for uid in uids] 45 | current_time = datetime.datetime.now() 46 | 47 | experts = [None] * len(uids) 48 | for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)): 49 | if lookup.result() is not None: 50 | (host, port), timestamp = PickleSerializer.loads(lookup.result()) 51 | if (current_time - timestamp).total_seconds() <= heartbeat_expiration: 52 | experts[i] = RemoteExpert(uid=uid, host=host, port=port) 53 | 54 | future.set_result(experts) 55 | 56 | def declare_experts(self, uids: List[str], addr, port, wait_timeout=0): 57 | """ 58 | Make experts available to DHT; update timestamps if already available 59 | :param uids: a list of expert ids to update 60 | :param addr: hostname that can be used to call this expert 61 | :param port: port that can be used to call this expert 62 | :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish 63 | """ 64 | done_event = mp.Event() if wait_timeout else None 65 | self.pipe.send(('_declare_experts', [], dict(uids=uids, addr=addr, port=port, done_event=done_event))) 66 | if done_event is not None: 67 | done_event.wait(wait_timeout) 68 | 69 | def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]): 70 | loop = asyncio.get_event_loop() 71 | timestamp = datetime.datetime.now() 72 | expert_metadata = PickleSerializer.dumps(((addr, port), timestamp)) 73 | prefix_metadata = PickleSerializer.dumps(timestamp) 74 | 75 | unique_prefixes = set() 76 | 77 | for uid in uids: 78 | asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('expert', uid), expert_metadata), loop) 79 | uid_parts = uid.split(self.UID_DELIMETER) 80 | unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))]) 81 | 82 | for prefix in unique_prefixes: 83 | 
asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop) 84 | 85 | if done_event is not None: 86 | done_event.set() 87 | 88 | def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None): 89 | """ 90 | Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search 91 | :param prefixes: a list of uid prefixes ordered from highest to lowest priority 92 | :param k: return at most *this many* active prefixes 93 | :param heartbeat_expiration: consider expert active if his last heartbeat was sent at most this many seconds ago 94 | :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k 95 | :returns: a list of at most :k: prefixes that have at least one active expert each; 96 | """ 97 | future, _future = SharedFuture.make_pair() 98 | self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration, 99 | max_prefetch=max_prefetch or k, future=_future))) 100 | return future.result() 101 | 102 | def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture): 103 | loop = asyncio.get_event_loop() 104 | lookup_prefetch = [asyncio.run_coroutine_threadsafe( 105 | self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]] 106 | current_time = datetime.datetime.now() 107 | 108 | active_prefixes = [] 109 | 110 | for i, prefix in enumerate(prefixes): 111 | lookup = lookup_prefetch[i] 112 | 113 | if lookup.result() is not None: 114 | timestamp = PickleSerializer.loads(lookup.result()) 115 | if (current_time - timestamp).total_seconds() <= heartbeat_expiration: 116 | active_prefixes.append(prefix) 117 | if len(active_prefixes) >= k: 118 | future.set_result(active_prefixes) 119 | return 120 | 121 | # pre-dispatch the next request in line 122 | if len(lookup_prefetch) < len(prefixes): 123 | lookup_prefetch.append( 124 | asyncio.run_coroutine_threadsafe(self.server.get( 125 | self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop) 126 | ) 127 | 128 | # could not find enough active prefixes; return what we can 129 | future.set_result(active_prefixes) 130 | -------------------------------------------------------------------------------- /lib/runtime/__init__.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | from itertools import chain 3 | from selectors import DefaultSelector, EVENT_READ 4 | from typing import Dict, List 5 | 6 | import torch 7 | import tqdm 8 | from prefetch_generator import BackgroundGenerator 9 | 10 | from .expert_backend import ExpertBackend 11 | from .task_pool import TaskPool, TaskPoolBase 12 | from ..utils import check_numpy 13 | 14 | 15 | class TesseractRuntime: 16 | def __init__(self, expert_backends: Dict[str, ExpertBackend], prefetch_batches=64, sender_threads: int = 1, 17 | device: torch.device = None): 18 | """ 19 | A group of processes that process tasks for multiple experts on a shared device 20 | :param expert_backends: a dict [expert uid -> ExpertBackend] 21 | :param prefetch_batches: generate up to this many batches in advance 22 | :param start: start process at the end of __init__ 23 | """ 24 | super().__init__() 25 | self.expert_backends = expert_backends 26 | self.pools = tuple(chain(*(expert.get_pools() for expert in expert_backends.values()))) 27 | self.device, 
self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads 28 | 29 | def main(self): 30 | progress = tqdm.tqdm(bar_format='{desc}, {rate_fmt}') 31 | for pool in self.pools: 32 | if not pool.is_alive(): 33 | pool.start() 34 | if self.device is not None: 35 | for expert_backend in self.expert_backends.values(): 36 | expert_backend.to(self.device) 37 | 38 | with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool, DefaultSelector() as sel: 39 | for pool in self.pools: 40 | sel.register(pool.batch_receiver, EVENT_READ, pool) 41 | try: 42 | for pool, batch_index, batch in BackgroundGenerator( 43 | self.iterate_minibatches_from_pools(selector=sel), self.prefetch_batches): 44 | outputs = pool.process_func(*batch) 45 | progress.update(len(outputs[0])) 46 | progress.desc = f'{pool.uid=} {len(outputs[0])=}' 47 | output_sender_pool.apply_async(self.send_outputs_to_pool, args=(pool, batch_index, outputs)) 48 | except KeyboardInterrupt: 49 | print('Runtime caught KeyboardInterrupt, exiting') 50 | for pool in self.pools: 51 | pool.join() 52 | 53 | def iterate_minibatches_from_pools(self, selector, timeout=None): 54 | """ 55 | Chooses pool according to priority, then copies exposed batch and frees the buffer 56 | """ 57 | while True: 58 | try: 59 | # wait until at least one batch_receiver becomes available 60 | ready_objects = selector.select() 61 | ready_pools = (key.data for (key, events) in ready_objects) 62 | pool = max(ready_pools, key=lambda pool: pool.priority) 63 | 64 | batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device) 65 | yield pool, batch_index, batch_tensors 66 | except KeyboardInterrupt: 67 | break 68 | 69 | def send_outputs_to_pool(self, pool: TaskPool, batch_index: int, outputs: List[torch.Tensor]): 70 | return pool.send_outputs_from_runtime(batch_index, [check_numpy(output) for output in outputs]) 71 | -------------------------------------------------------------------------------- /lib/runtime/expert_backend.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Sequence, Any, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .task_pool import TaskPool 7 | from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorProto, DUMMY_BATCH_SIZE, nested_map 8 | 9 | 10 | class ExpertBackend(nn.Module): 11 | def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *, 12 | args_schema: Tuple[BatchTensorProto, ...] = None, 13 | kwargs_schema: Dict[str, BatchTensorProto] = None, 14 | outputs_schema: Union[BatchTensorProto, Tuple[BatchTensorProto, ...]] = None, 15 | **kwargs): 16 | """ 17 | ExpertBackend implements how a given expert processes tasks. 18 | By default, there are two tasks: 19 | * forward receives inputs and produces outputs 20 | * backward receives gradients w.r.t. outputs, computes gradients w.r.t. inputs and trains the expert 21 | 22 | All incoming tasks are grouped by type (forward/backward) and sent into the corresponding pool, 23 | where tasks are grouped into minibatches and prepared for processing on device; 24 | The results are dispatched to task authors with SharedFuture.set_result. 25 | 26 | :param expert: nn.Module to be wrapped into a backend. 
Arbitrary pytorch module with a few limitations:
27 |             * Experts must always receive the same set of *args and **kwargs and produce output tensors of the same type
28 |             * All *args, **kwargs and outputs must be *tensors* where the 0-th dimension represents the batch size
29 |             * We recommend using experts that are ~invariant to the order in which they process batches
30 | 
31 |         :param opt: torch optimizer to be applied on every backward call
32 |         :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
33 |         :param kwargs_schema: description of keyword arguments to expert.forward, dict of BatchTensorProto
34 |         :param outputs_schema: description of outputs from expert.forward, nested structure of BatchTensorProto
35 |         :param kwargs: extra parameters to be forwarded into TaskPool.__init__
36 |         """
37 |         super().__init__()
38 |         self.expert, self.opt, self.name = expert, opt, name
39 | 
40 |         self.args_schema = args_schema = tuple(args_schema or ())
41 |         self.kwargs_schema = kwargs_schema = dict(kwargs_schema or {})
42 |         assert len(args_schema) or len(kwargs_schema), "expert must receive at least one positional or keyword input." \
43 |                                                        " Did you forget to provide args_schema/kwargs_schema?"
44 | 
45 |         if outputs_schema is None:
46 |             # run expert once to get outputs schema
47 |             dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
48 |             dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
49 |             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
50 |             outputs_schema = nested_map(BatchTensorProto.from_tensor, dummy_outputs)
51 | 
52 |         self.forward_schema = (self.args_schema, self.kwargs_schema)
53 |         self.outputs_schema = outputs_schema
54 |         self.forward_pool = TaskPool(
55 |             self.forward, inputs_schema=tuple(nested_flatten(self.forward_schema)),
56 |             outputs_schema=tuple(nested_flatten(self.outputs_schema)), uid=f'{self.name}_forward', **kwargs)
57 | 
58 |         self.backward_schema = (self.forward_schema, self.outputs_schema)  # original inputs and grad w.r.t. outputs
59 |         self.backward_pool = TaskPool(
60 |             self.backward, inputs_schema=tuple(nested_flatten(self.backward_schema)),
61 |             outputs_schema=tuple(nested_flatten(self.forward_schema)),  # return grads w.r.t.
inputs with same schema 62 | uid=f'{self.name}_backward', **kwargs) 63 | 64 | def forward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: 65 | args, kwargs = nested_pack(inputs, structure=self.forward_schema) 66 | 67 | with torch.no_grad(): 68 | outputs = self.expert(*args, **kwargs) 69 | 70 | # Note: TaskPool requires function to accept and return a **list** of values, we pack/unpack it on client side 71 | return tuple(nested_flatten(outputs)) 72 | 73 | def backward(self, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]: 74 | (args, kwargs), grad_outputs = nested_pack(inputs, structure=self.backward_schema) 75 | 76 | with torch.enable_grad(): 77 | args = [tensor.detach().requires_grad_(True) for tensor in args] 78 | kwargs = {input_key: tensor.detach().requires_grad_(True) for input_key, tensor in kwargs.items()} 79 | 80 | outputs = self.expert(*args, **kwargs) 81 | assert nested_compare(outputs, grad_outputs), "outputs and grad_outputs must have the same structure" 82 | 83 | outputs_flat = tuple(nested_flatten(outputs)) 84 | 85 | grad_outputs_flat = tuple(map( 86 | lambda grad, out: grad.to(device=out.device, dtype=out.dtype, non_blocking=True), 87 | nested_flatten(grad_outputs), outputs_flat)) 88 | torch.autograd.backward(outputs_flat, grad_tensors=grad_outputs_flat, 89 | create_graph=False, retain_graph=False) 90 | self.apply_gradients() 91 | 92 | return tuple(x.grad if isinstance(x.grad, torch.Tensor) else torch.zeros_like(x) 93 | for x in nested_flatten((args, kwargs))) 94 | 95 | def apply_gradients(self) -> None: 96 | self.opt.step() 97 | self.opt.zero_grad() 98 | 99 | def get_pools(self) -> Sequence[TaskPool]: 100 | return self.forward_pool, self.backward_pool 101 | 102 | def get_info(self) -> Dict[str, Any]: 103 | return dict(forward_schema=self.forward_schema, outputs_schema=self.outputs_schema, 104 | keyword_names=tuple(self.kwargs_schema.keys())) 105 | -------------------------------------------------------------------------------- /lib/runtime/task_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | Task pool is responsible for receiving tasks and grouping them together for processing (but not processing itself) 3 | """ 4 | import ctypes 5 | import multiprocessing as mp 6 | import os 7 | import threading 8 | import uuid 9 | from collections import namedtuple 10 | from concurrent.futures import Future 11 | from queue import Empty 12 | from typing import List, Tuple, Dict, Any 13 | 14 | import numpy as np 15 | import torch 16 | 17 | from ..utils import SharedFuture, SharedArrays, BatchTensorProto, SharedArray, check_numpy, time 18 | 19 | Task = namedtuple("Task", ("future", "args")) 20 | 21 | 22 | class TaskPoolBase(mp.Process): 23 | """ A pool that accepts tasks and forms batches for parallel processing, interacts with TesseractRuntime """ 24 | 25 | def __init__(self, process_func: callable): 26 | super().__init__() 27 | self.process_func = process_func 28 | self._priority = mp.Value(ctypes.c_double, 1.0) # higher priority = the more urgent to process this pool 29 | 30 | def run(self): 31 | raise NotImplementedError() 32 | 33 | def submit_task(self, *args: torch.Tensor) -> Future: 34 | raise NotImplementedError() 35 | 36 | def form_batch(self, *args, **kwargs) -> List[Task]: 37 | raise NotImplementedError() 38 | 39 | def iterate_minibatches(self, *args, **kwargs): 40 | while True: 41 | yield self.form_batch(*args, **kwargs) 42 | 43 | @property 44 | def priority(self): 45 | return self._priority.value 46 | 47 | 
@priority.setter 48 | def priority(self, value): 49 | self._priority.value = float(value) 50 | 51 | @property 52 | def empty(self): 53 | raise NotImplementedError() 54 | 55 | 56 | class TaskPool(TaskPoolBase): 57 | 58 | def __init__(self, process_func: callable, 59 | inputs_schema: Tuple[BatchTensorProto, ...], outputs_schema: Tuple[BatchTensorProto, ...], 60 | max_batch_size: int, min_batch_size=1, timeout=None, pool_size=None, prefetch_batches=1, uid=None, 61 | shm_manager=None, array_headers=None, start=False): 62 | """ 63 | Naive implementation of task pool that forms batch from earliest submitted tasks 64 | :param process_func: function to be applied to every formed batch; called by TesseractRuntime 65 | Note: process_func should accept only *args Tensors and return a list of output Tensors 66 | :param inputs_schema: description of arguments to process_func, list of BatchTensorProto, positional only 67 | :param outputs_schema: description of outputs from process_func, list of BatchTensorProto, must return a list 68 | :param max_batch_size: process at most this many inputs in a batch (task contains have one or several inputs) 69 | :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more 70 | :param timeout: wait for a subsequent task for at most this many seconds 71 | :param pool_size: store at most this many unprocessed tasks in a queue 72 | :param prefetch_batches: prepare up to this many *batches* in background for faster off-loading to runtime 73 | :param uid: pool identifier used for shared array allocation 74 | :param start: if True, start automatically at the end of __init__ 75 | """ 76 | 77 | super().__init__(process_func) 78 | self.min_batch_size, self.max_batch_size, self.timeout = min_batch_size, max_batch_size, timeout 79 | self.inputs_schema, self.outputs_schema = list(inputs_schema), list(outputs_schema) 80 | self.uid = uid or uuid.uuid4() 81 | self.prefetch_batches = prefetch_batches 82 | 83 | # interaction with ConnectionHandlers 84 | self.tasks = mp.Queue(maxsize=pool_size or 0) 85 | self.undispatched_task_timestamps = mp.SimpleQueue() 86 | 87 | # interaction with TesseractRuntime 88 | self.shared_arrays = SharedArrays(array_headers=array_headers, shm_manager=shm_manager) 89 | 90 | self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False) # send/recv array names that contain batch inputs 91 | self.batch_received = mp.Event() # runtime can notify pool that it can send next batch 92 | 93 | self.outputs_receiver, self.outputs_sender = mp.Pipe(duplex=False) # send/recv array names that contain outputs 94 | self.outputs_received = mp.Event() # pool can notify runtime that it can send next outputs 95 | 96 | if start: 97 | self.start() 98 | 99 | def submit_task(self, *args: torch.Tensor) -> Future: 100 | future1, future2 = SharedFuture.make_pair() 101 | self.tasks.put(Task(future1, args)) 102 | self.undispatched_task_timestamps.put(time.time()) 103 | return future2 104 | 105 | def form_batch(self) -> List[Task]: 106 | batch_tasks = [] 107 | total_size = 0 108 | 109 | while total_size < self.max_batch_size: 110 | if total_size >= self.min_batch_size and self.tasks.empty(): 111 | break # timeout reached, returning incomplete batch 112 | 113 | try: 114 | task = self.tasks.get(timeout=self.timeout) 115 | except Empty: 116 | exc = TimeoutError(f"Timeout reached but batch doesn't contain >={self.min_batch_size} elements yet.") 117 | for task in batch_tasks: 118 | task.future.set_exception(exc) 119 | raise exc 120 | 121 | if 
task.future.set_running_or_notify_cancel(): 122 | batch_tasks.append(task) 123 | total_size += self.get_task_size(task) 124 | 125 | return batch_tasks 126 | 127 | def run(self, *args, status_timeout=0.1, **kwargs): 128 | print(f'Starting pool, {os.getpid()=}') 129 | pending_batches = {} # Dict[batch uuid, List[SharedFuture]] for each batch currently in runtime 130 | self.batch_received.set(), self.outputs_received.set() # initial state: no batches/outputs pending 131 | output_thread = threading.Thread(target=self._pool_output_loop, args=[pending_batches]) 132 | try: 133 | output_thread.start() 134 | self._pool_input_loop(pending_batches, *args, **kwargs) 135 | except KeyboardInterrupt: 136 | print('Pool caught KeyboardInterrupt, exiting') 137 | finally: 138 | output_thread.join() 139 | self.shared_arrays.shm_manager.shutdown() 140 | 141 | def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwargs): 142 | """ Thread method that continually forms batches and sends them to runtime """ 143 | prev_num_tasks = 0 # number of tasks currently in shared buffer 144 | batch_index = max(pending_batches.keys(), default=0) 145 | batch_iterator = self.iterate_minibatches(*args, **kwargs) 146 | 147 | while True: 148 | self.batch_received.wait() # wait for runtime to receive (copy) previous batch 149 | 150 | # SIDE-EFFECT - compute pool priority from timestamp of earliest undispatched task 151 | # assumes that tasks are processed in the same order as they are created 152 | for skip_i in range(prev_num_tasks): 153 | finished_task_timestamp = self.undispatched_task_timestamps.get() # earlier timestamp = higher priority 154 | if skip_i == prev_num_tasks - 1: 155 | self.priority = finished_task_timestamp 156 | 157 | batch_tasks = next(batch_iterator) 158 | # save batch futures, _output_loop will deliver on them later 159 | pending_batches[batch_index] = batch_tasks 160 | 161 | # find or create shared arrays for current batch size 162 | batch_size = sum(map(self.get_task_size, batch_tasks)) 163 | shared_keys, shared_buffers = zip( 164 | *self.get_or_create_buffers(batch_size, self.inputs_schema, name='inputs')) 165 | 166 | self.batch_received.clear() # sending next batch... 
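            # note: np.concatenate writes each input column directly into the shared-memory buffer,
            # so the runtime process can pick up the batch by key without pickling any tensors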
167 | for i, buffer in enumerate(shared_buffers): 168 | np.concatenate([task.args[i] for task in batch_tasks], out=buffer) # assemble batch from tasks 169 | 170 | self.batch_sender.send((batch_index, shared_keys)) # send input keys, trigger runtime to receive batch 171 | batch_index += 1 172 | prev_num_tasks = len(batch_tasks) 173 | 174 | def _pool_output_loop(self, pending_batches: Dict[Any, List[Task]]): 175 | """ Thread method that continually receives results from runtime and dispatches them to task Futures """ 176 | 177 | while True: 178 | try: 179 | batch_index, output_keys = self.outputs_receiver.recv() 180 | batch_outputs = [self.shared_arrays[key].copy() for key in output_keys] 181 | self.outputs_received.set() # runtime can now send next output 182 | 183 | # split batch into partitions for individual tasks 184 | batch_tasks = pending_batches.pop(batch_index) 185 | task_sizes = [self.get_task_size(task) for task in batch_tasks] 186 | task_sections = np.cumsum(task_sizes)[:-1] # index in batch where task begins, for all tasks expert first 187 | outputs_per_task = zip(*(np.split(array, task_sections) for array in batch_outputs)) 188 | 189 | # dispatch results to futures 190 | for task, task_outputs in zip(batch_tasks, outputs_per_task): 191 | task.future.set_result(tuple( 192 | proto.convert_array_to_tensor(array) for proto, array in zip(self.outputs_schema, task_outputs) 193 | )) 194 | except KeyboardInterrupt: 195 | break 196 | 197 | @property 198 | def empty(self): 199 | return not self.batch_receiver.poll() 200 | 201 | def load_batch_to_runtime(self, timeout=None, device=None) -> Tuple[Any, List[torch.Tensor]]: 202 | """ receive next batch of numpy arrays """ 203 | if not self.batch_receiver.poll(timeout): 204 | raise TimeoutError() 205 | 206 | batch_index, input_keys = self.batch_receiver.recv() 207 | batch_inputs = [self.shared_arrays[key].copy() for key in input_keys] 208 | self.batch_received.set() # pool can now prepare next batch 209 | batch_inputs = [tensor_proto.convert_array_to_tensor(array).to(device, non_blocking=True) 210 | for array, tensor_proto in zip(batch_inputs, self.inputs_schema)] 211 | return batch_index, batch_inputs 212 | 213 | def send_outputs_from_runtime(self, batch_index: int, batch_outputs: List[np.ndarray]): 214 | """ send results for a processed batch, previously loaded through receive_batch """ 215 | batch_size = len(batch_outputs[0]) 216 | shared_keys, shared_buffers = zip(*self.get_or_create_buffers(batch_size, self.outputs_schema, name='outputs')) 217 | self.outputs_received.wait(), self.outputs_received.clear() # wait for pool to receive (copy) previous outputs 218 | 219 | for output, buffer in zip(batch_outputs, shared_buffers): 220 | np.copyto(dst=buffer, src=output) 221 | 222 | self.outputs_sender.send((batch_index, shared_keys)) 223 | 224 | def get_task_size(self, task: Task) -> int: 225 | """ compute task processing complexity (used for batching); defaults to batch size """ 226 | return len(task.args[0]) if task.args else 1 227 | 228 | def get_or_create_buffers(self, batch_size, schema: List[BatchTensorProto], name: str = '') -> List[SharedArray]: 229 | """ get or create a shared arrays for inputs and outputs with a given batch dimension """ 230 | 231 | for i, proto in enumerate(schema): 232 | key = f"pool_{self.uid}__batchsize_{batch_size}__{name}_{i}" 233 | if key in self.shared_arrays: 234 | arr = self.shared_arrays[key] 235 | else: 236 | self.shared_arrays[key] = arr = SharedArray.from_array(check_numpy(proto.make_empty(batch_size))) 
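            # note: buffers are cached per (pool uid, batch size, tensor index); a shared array is allocated
            # only the first time a given batch size is encountered and is reused afterwards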
237 | yield key, arr 238 | -------------------------------------------------------------------------------- /lib/server/__init__.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout 4 | from typing import Dict 5 | 6 | from .connection_handler import handle_connection 7 | from .network_handler import NetworkHandlerThread 8 | from ..network import TesseractNetwork 9 | from ..runtime import TesseractRuntime, ExpertBackend 10 | 11 | 12 | class TesseractServer: 13 | def __init__(self, network: TesseractNetwork, expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1', 14 | port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False, 15 | **kwargs): 16 | self.network, self.experts, self.update_period = network, expert_backends, update_period 17 | self.addr, self.port = addr, port 18 | self.conn_handlers = conn_handler_processes 19 | self.runtime = TesseractRuntime(self.experts, **kwargs) 20 | 21 | if start: 22 | self.start() 23 | 24 | def start(self): 25 | if self.network: 26 | if not self.network.is_alive(): 27 | self.network.start() 28 | 29 | network_thread = NetworkHandlerThread(experts=self.experts, network=self.network, 30 | addr=self.addr, port=self.port, update_period=self.update_period) 31 | network_thread.start() 32 | 33 | processes = self.spawn_connection_handlers() 34 | try: 35 | self.runtime.main() 36 | finally: 37 | for process in processes: 38 | process.join() 39 | if self.network: 40 | network_thread.join() 41 | 42 | def spawn_connection_handlers(self): 43 | sock = socket(AF_INET, SOCK_STREAM) 44 | sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) 45 | sock.bind(('', self.port)) 46 | sock.listen() 47 | sock.settimeout(self.update_period) 48 | 49 | processes = [mp.Process(target=socket_loop, args=(sock, self.experts)) for _ in range(self.conn_handlers)] 50 | for process in processes: 51 | process.start() 52 | return processes 53 | 54 | 55 | def socket_loop(sock, experts): 56 | """ catch connections, send tasks to processing, respond with results """ 57 | print(f'Spawned connection handler pid={os.getpid()}') 58 | while True: 59 | try: 60 | handle_connection(sock.accept(), experts) 61 | except KeyboardInterrupt as e: 62 | print(f'Socket loop has caught {type(e)}, exiting') 63 | break 64 | except (timeout, BrokenPipeError, ConnectionResetError, NotImplementedError): 65 | continue 66 | -------------------------------------------------------------------------------- /lib/server/connection_handler.py: -------------------------------------------------------------------------------- 1 | from socket import socket 2 | from typing import Tuple, Dict 3 | 4 | from .. 
import PytorchSerializer, Connection 5 | from ..runtime.expert_backend import ExpertBackend 6 | 7 | 8 | def handle_connection(connection_tuple: Tuple[socket, str], experts: Dict[str, ExpertBackend]): 9 | with Connection(*connection_tuple) as connection: 10 | try: 11 | header = connection.recv_header() 12 | payload = PytorchSerializer.loads(connection.recv_raw()) 13 | 14 | if header == 'fwd_': 15 | uid, inputs = payload 16 | response = experts[uid].forward_pool.submit_task(*inputs).result() 17 | elif header == 'bwd_': 18 | uid, inputs_and_grad_outputs = payload 19 | response = experts[uid].backward_pool.submit_task(*inputs_and_grad_outputs).result() 20 | elif header == 'info': 21 | uid = payload 22 | response = experts[uid].get_info() 23 | else: 24 | raise NotImplementedError(f"Unknown header: {header}") 25 | 26 | connection.send_raw('rest', PytorchSerializer.dumps(response)) 27 | except RuntimeError: 28 | # socket connection broken 29 | pass 30 | -------------------------------------------------------------------------------- /lib/server/network_handler.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | 4 | from ..network import TesseractNetwork 5 | 6 | 7 | class NetworkHandlerThread(threading.Thread): 8 | def __init__(self, experts, network: TesseractNetwork, 9 | update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080): 10 | super(NetworkHandlerThread, self).__init__() 11 | self.port = port 12 | self.addr = addr 13 | self.experts = experts 14 | self.network = network 15 | self.update_period = update_period 16 | 17 | def run(self) -> None: 18 | while True: 19 | self.network.declare_experts(self.experts.keys(), self.addr, self.port) 20 | time.sleep(self.update_period) 21 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .connection import * 2 | from .data import * 3 | from .nested import * 4 | from .proto import * 5 | from .serializer import * 6 | from .shared_arrays import * 7 | from .shared_future import * 8 | from .threading import * 9 | -------------------------------------------------------------------------------- /lib/utils/connection.py: -------------------------------------------------------------------------------- 1 | from contextlib import AbstractContextManager 2 | from socket import socket 3 | from typing import Tuple 4 | 5 | 6 | class Connection(AbstractContextManager): 7 | header_size = 4 # number of characters in all headers 8 | payload_length_size = 8 # number of bytes used to encode payload length 9 | 10 | __slots__ = ('conn', 'addr') 11 | 12 | def __init__(self, conn: socket, addr: Tuple[str, int]): 13 | self.conn, self.addr = conn, addr 14 | 15 | @staticmethod 16 | def create(host: str, port: int): 17 | sock = socket() 18 | addr = (host, port) 19 | sock.connect(addr) 20 | return Connection(sock, addr) 21 | 22 | def send_raw(self, header: str, content: bytes): 23 | self.conn.send(header.encode()) 24 | self.conn.send(len(content).to_bytes(self.payload_length_size, byteorder='big')) 25 | 26 | total_sent = 0 27 | while total_sent < len(content): 28 | sent = self.conn.send(content[total_sent:]) 29 | if sent == 0: 30 | raise RuntimeError("socket connection broken") 31 | total_sent = total_sent + sent 32 | 33 | def recv_header(self) -> str: 34 | return self.conn.recv(self.header_size).decode() 35 | 36 | def recv_raw(self, max_package: int = 2048) -> 
bytes: 37 | length = int.from_bytes(self.conn.recv(self.payload_length_size), byteorder='big') 38 | chunks = [] 39 | bytes_recd = 0 40 | while bytes_recd < length: 41 | chunk = self.conn.recv(min(length - bytes_recd, max_package)) 42 | if chunk == b'': 43 | raise RuntimeError("socket connection broken") 44 | chunks.append(chunk) 45 | bytes_recd = bytes_recd + len(chunk) 46 | ret = b''.join(chunks) 47 | assert len(ret) == length 48 | return ret 49 | 50 | def recv_message(self) -> Tuple[str, bytes]: 51 | return self.recv_header(), self.recv_raw() 52 | 53 | def __exit__(self, *exc_info): 54 | self.conn.close() 55 | -------------------------------------------------------------------------------- /lib/utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def check_numpy(x): 6 | """ Makes sure x is a numpy array """ 7 | if isinstance(x, torch.Tensor): 8 | return x.detach().cpu().numpy() 9 | else: 10 | return np.asarray(x) 11 | 12 | 13 | DUMMY = torch.empty(0, requires_grad=True) 14 | -------------------------------------------------------------------------------- /lib/utils/nested.py: -------------------------------------------------------------------------------- 1 | """ utility functions that help you process nested dicts, tuples, lists and namedtuples """ 2 | 3 | 4 | def nested_compare(t, u): 5 | """ 6 | Return whether nested structure of t1 and t2 matches. 7 | """ 8 | if isinstance(t, (list, tuple)): 9 | if not isinstance(u, type(t)): 10 | return False 11 | if len(t) != len(u): 12 | return False 13 | for a, b in zip(t, u): 14 | if not nested_compare(a, b): 15 | return False 16 | return True 17 | 18 | if isinstance(t, dict): 19 | if not isinstance(u, dict): 20 | return False 21 | if set(t.keys()) != set(u.keys()): 22 | return False 23 | for k in t: 24 | if not nested_compare(t[k], u[k]): 25 | return False 26 | return True 27 | 28 | else: 29 | return True 30 | 31 | 32 | def nested_flatten(t): 33 | """ 34 | Turn nested list/tuple/dict into a flat iterator. 35 | """ 36 | if isinstance(t, (list, tuple)): 37 | for x in t: 38 | yield from nested_flatten(x) 39 | elif isinstance(t, dict): 40 | for k, v in sorted(t.items()): 41 | yield from nested_flatten(v) 42 | else: 43 | yield t 44 | 45 | 46 | def nested_pack(flat, structure): 47 | """ 48 | Restore nested structure from flattened state 49 | :param flat: result of nested_flatten 50 | :param structure: used as example when recovering structure 51 | :returns: nested structure like :structure: filled with elements of :flat: 52 | """ 53 | return _nested_pack(iter(flat), structure) 54 | 55 | 56 | def _nested_pack(flat_iter, structure): 57 | if is_namedtuple(structure): 58 | return type(structure)(*[ 59 | _nested_pack(flat_iter, x) 60 | for x in structure] 61 | ) 62 | elif isinstance(structure, (list, tuple)): 63 | return type(structure)( 64 | _nested_pack(flat_iter, x) 65 | for x in structure 66 | ) 67 | elif isinstance(structure, dict): 68 | return { 69 | k: _nested_pack(flat_iter, v) 70 | for k, v in sorted(structure.items()) 71 | } 72 | else: 73 | return next(flat_iter) 74 | 75 | 76 | def is_namedtuple(x): 77 | """Checks if x is a namedtuple instance. 
Taken from https://stackoverflow.com/a/2166841 .""" 78 | t = type(x) 79 | b = t.__bases__ 80 | if len(b) != 1 or b[0] != tuple: return False 81 | f = getattr(t, '_fields', None) 82 | if not isinstance(f, tuple): return False 83 | return all(type(n) == str for n in f) 84 | 85 | 86 | def nested_map(fn, *t): 87 | # Check arguments. 88 | if not t: 89 | raise ValueError('Expected 2+ arguments, got 1') 90 | for i in range(1, len(t)): 91 | if not nested_compare(t[0], t[i]): 92 | msg = 'Nested structure of %r and %r differs' 93 | raise ValueError(msg % (t[0], t[i])) 94 | 95 | # Map. 96 | flat = map(nested_flatten, t) 97 | return nested_pack(map(fn, *flat), t[0]) 98 | -------------------------------------------------------------------------------- /lib/utils/proto.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | 3 | import numpy as np 4 | import torch 5 | 6 | DUMMY_BATCH_SIZE = 3 # used for dummy runs only 7 | 8 | 9 | @dataclass(init=True, repr=True, frozen=True) 10 | class ProtoBase: 11 | pass 12 | 13 | 14 | @dataclass(init=True, repr=True, frozen=True) 15 | class ArrayProto(ProtoBase): 16 | shape: tuple 17 | dtype: np.dtype 18 | strides: tuple = None 19 | order: str = 'C' 20 | 21 | @classmethod 22 | def from_array(cls, arr: np.ndarray): 23 | return cls(arr.shape, arr.dtype, strides=arr.strides, order='CF'[np.isfortran(arr)]) 24 | 25 | def make_empty(self, **kwargs): 26 | properties = asdict(self) 27 | properties.update(kwargs) 28 | return np.ndarray(**properties) 29 | 30 | def make_from_buffer(self, buffer, offset=0): 31 | return np.ndarray(self.shape, self.dtype, buffer, offset, 32 | strides=self.strides, order=self.order) 33 | 34 | @property 35 | def nbytes(self): 36 | return np.dtype(self.dtype).itemsize * np.prod(self.shape) 37 | 38 | 39 | @dataclass(init=True, repr=True, frozen=True) 40 | class TensorProto(ProtoBase): 41 | size: tuple 42 | dtype: torch.dtype = None 43 | layout: torch.layout = torch.strided 44 | device: torch.device = None 45 | requires_grad: bool = False 46 | pin_memory: bool = False 47 | 48 | @property 49 | def shape(self): 50 | return self.size 51 | 52 | @classmethod 53 | def from_tensor(cls, tensor: torch.Tensor): 54 | return cls(tensor.shape, tensor.dtype, tensor.layout, tensor.device, tensor.requires_grad, tensor.is_pinned()) 55 | 56 | def make_empty(self, **kwargs): 57 | properties = asdict(self) 58 | properties.update(kwargs) 59 | return torch.empty(**properties) 60 | 61 | def convert_array_to_tensor(self, array: np.ndarray): 62 | tensor = torch.as_tensor(array, dtype=self.dtype, device=self.device) 63 | tensor = tensor.requires_grad_(self.requires_grad).to(self.device, non_blocking=True) 64 | return tensor.pin_memory() if self.pin_memory else tensor 65 | 66 | 67 | @dataclass(init=True, repr=True, frozen=True) 68 | class BatchTensorProto(TensorProto): 69 | """ torch Tensor with a variable 0-th dimension, used to describe batched data """ 70 | 71 | def __init__(self, *instance_size, **kwargs): # compatibility: allow initializing with *size 72 | if len(instance_size) == 1 and isinstance(instance_size[0], (list, tuple, torch.Size)): 73 | instance_size = instance_size[0] # we were given size as the only parameter instead of *parameters 74 | super().__init__((None, *instance_size), **kwargs) 75 | 76 | @classmethod 77 | def from_tensor(cls, tensor: torch.Tensor): 78 | return cls(*tensor.shape[1:], dtype=tensor.dtype, layout=tensor.layout, 79 | device=tensor.device, 
requires_grad=tensor.requires_grad, pin_memory=tensor.is_pinned()) 80 | 81 | def make_empty(self, batch_size, **kwargs): 82 | assert self.shape[0] is None, "Make sure 0-th dimension is not specified (set to None)" 83 | return super().make_empty(size=(batch_size, *self.shape[1:]), **kwargs) 84 | -------------------------------------------------------------------------------- /lib/utils/serializer.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import joblib 4 | import torch 5 | from six import BytesIO 6 | 7 | 8 | class JoblibSerializer: 9 | 10 | @staticmethod 11 | def dumps(obj) -> bytes: 12 | s = BytesIO() 13 | joblib.dump(obj, s) 14 | return s.getvalue() 15 | 16 | @staticmethod 17 | def loads(buf: bytes): 18 | return joblib.load(BytesIO(buf)) 19 | 20 | 21 | class PickleSerializer: 22 | @staticmethod 23 | def dumps(obj) -> bytes: 24 | return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) 25 | 26 | @staticmethod 27 | def loads(buf: bytes): 28 | return pickle.loads(buf) 29 | 30 | 31 | class PytorchSerializer: 32 | 33 | @staticmethod 34 | def dumps(obj) -> bytes: 35 | s = BytesIO() 36 | torch.save(obj, s, pickle_protocol=pickle.HIGHEST_PROTOCOL) 37 | return s.getvalue() 38 | 39 | @staticmethod 40 | def loads(buf: bytes): 41 | return torch.load(BytesIO(buf)) 42 | -------------------------------------------------------------------------------- /lib/utils/shared_arrays.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dictionary of numpy arrays stored in shared memory. Multiprocessing-friendly 3 | """ 4 | import multiprocessing as mp 5 | import multiprocessing.shared_memory 6 | 7 | import numpy as np 8 | 9 | from .proto import ArrayProto 10 | 11 | 12 | class SharedArrays: 13 | def __init__(self, array_headers=None, shm_manager=None): 14 | """ 15 | A dict-like collection of numpy arrays shared between processes using multiprocessing.shared_memory 16 | :param array_headers: a shared dictionary of { key -> (ArrayProto, SharedMemory.name) } 17 | if not specified, creates a new shared dictionary using new multiprocessing.Manager() 18 | :param shm_manager: shared memory manager to be used when allocating new shared arrays 19 | if not specified, creates a new one 20 | """ 21 | assert array_headers is None or isinstance(array_headers, mp.managers.DictProxy) 22 | assert shm_manager is None or isinstance(shm_manager, mp.managers.SharedMemoryManager) 23 | if array_headers is not None: 24 | self.array_headers = array_headers 25 | else: 26 | self.array_headers_manager = mp.Manager() 27 | self.array_headers = self.array_headers_manager.dict() 28 | if shm_manager is None: 29 | shm_manager = mp.managers.SharedMemoryManager() 30 | shm_manager.start() 31 | self.shm_manager = shm_manager 32 | 33 | def fork(self): 34 | """ create a linked instance of SharedArrays that uses the same data and shm_manager """ 35 | return SharedArrays(self.array_headers, self.shm_manager) 36 | 37 | def __getitem__(self, key): 38 | proto, shmem_name = self.array_headers[key] 39 | return SharedArray(proto, mp.shared_memory.SharedMemory(name=shmem_name)) 40 | 41 | def __contains__(self, key): 42 | return key in self.array_headers 43 | 44 | def __setitem__(self, key, arr): 45 | if not isinstance(arr, SharedArray): 46 | raise ValueError("setitem only works with SharedArray values. For normal arrays, use:\n" 47 | "arr_shared = SharedArrays.create_array(key, ArrayProto.from_array(arr))\n" 48 | "arr_shared[...] 
= arr # note that arr not shared itself, but copied into a SharedArray") 49 | self.array_headers[key] = (ArrayProto.from_array(arr), arr.shared_memory.name) 50 | 51 | def __delitem__(self, key): 52 | del self.array_headers[key] 53 | 54 | def __repr__(self): 55 | return repr({key: self[key] for key in self.keys()}) 56 | 57 | def __len__(self): 58 | return len(self.array_headers) 59 | 60 | def keys(self): 61 | return self.array_headers.keys() 62 | 63 | def create_array(self, key, proto: ArrayProto): 64 | """ Create and return a shared array under the specified key. if key already exists, overwrite """ 65 | self[key] = shared_array = SharedArray(proto, self.shm_manager.SharedMemory(size=proto.nbytes)) 66 | return shared_array 67 | 68 | 69 | class SharedArray(np.ndarray): 70 | """ 71 | A subclass of numpy array that stores SharedMemory as an attribute; 72 | Use this class to prevent SharedMemory buffer from accidentally getting deallocated 73 | Details on subclassing numpy: https://docs.scipy.org/doc/numpy/user/basics.subclassing.html 74 | #simple-example-adding-an-extra-attribute-to-ndarray 75 | """ 76 | 77 | def __new__(subtype, proto: ArrayProto, shared_memory: mp.shared_memory.SharedMemory, offset=0): 78 | obj = super(SharedArray, subtype).__new__( 79 | subtype, proto.shape, proto.dtype, shared_memory.buf, offset, proto.strides, proto.order) 80 | obj.shared_memory = shared_memory 81 | return obj 82 | 83 | def __array_finalize__(self, obj): 84 | # make sure that shared memory is passed along to tensors that share its data, e.g. arr[::2] 85 | if obj is None: return # explicit creation: do nothing 86 | self.shared_memory = getattr(obj, 'shared_memory', None) 87 | 88 | def __array_wrap__(self, out_arr, context=None): 89 | return np.asarray(out_arr) # after out-of-place operation we no longer need to store sharedmemory 90 | 91 | @classmethod 92 | def from_array(cls, arr: np.ndarray, shared_memory: mp.shared_memory.SharedMemory = None): 93 | """ Create SharedArray from a regular numpy array (out-of-place) """ 94 | proto = ArrayProto.from_array(arr) 95 | shared_memory = shared_memory or mp.shared_memory.SharedMemory(create=True, size=proto.nbytes) 96 | proto.make_from_buffer(shared_memory.buf)[...] = arr 97 | return cls(proto, shared_memory) 98 | 99 | def __repr__(self): 100 | return super().__repr__() + '; shared_memory={}'.format(self.shared_memory) 101 | -------------------------------------------------------------------------------- /lib/utils/shared_future.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import multiprocessing.connection 3 | from concurrent.futures import Future, CancelledError 4 | from warnings import warn 5 | 6 | 7 | class SharedFuture(Future): 8 | """ Multiprocessing version of concurrent.futures.Future, interacts between two processes via Pipe """ 9 | STATES = 'pending', 'running', 'cancelled', 'finished', 'exception' 10 | STATE_PENDING, STATE_RUNNING, STATE_CANCELLED, STATE_FINISHED, STATE_EXCEPTION = STATES 11 | 12 | def __init__(self, connection: mp.connection.Connection): 13 | """ manually create MPFuture. 
Please use MPFuture.make_pair instead """ 14 | self.connection = connection 15 | self.state = self.STATE_PENDING 16 | self._result = None 17 | self._exception = None 18 | 19 | @classmethod 20 | def make_pair(cls): 21 | """ Create a pair of linked futures to be used in two processes """ 22 | connection1, connection2 = mp.Pipe() 23 | return cls(connection1), cls(connection2) 24 | 25 | def _recv(self, timeout): 26 | if self.state in (self.STATE_PENDING, self.STATE_RUNNING): 27 | if not self.connection.poll(timeout): 28 | raise TimeoutError() 29 | try: 30 | status, payload = self.connection.recv() 31 | except BrokenPipeError as e: 32 | status, payload = self.STATE_EXCEPTION, e 33 | 34 | assert status in self.STATES 35 | self.state = status 36 | 37 | if status == self.STATE_FINISHED: 38 | self._result = payload 39 | elif status == self.STATE_EXCEPTION: 40 | self._exception = payload 41 | elif status in (self.STATE_RUNNING, self.STATE_CANCELLED): 42 | pass # only update self.state 43 | else: 44 | raise ValueError("Result status should not be self.STATE_PENDING") 45 | 46 | def set_result(self, result): 47 | try: 48 | self.state, self._result = self.STATE_FINISHED, result 49 | self.connection.send((self.STATE_FINISHED, result)) 50 | return True 51 | except BrokenPipeError: 52 | return False 53 | 54 | def set_exception(self, exception: BaseException): 55 | try: 56 | self.state, self._exception = self.STATE_EXCEPTION, exception 57 | self.connection.send((self.STATE_EXCEPTION, exception)) 58 | return True 59 | except BrokenPipeError: 60 | return False 61 | 62 | def set_running_or_notify_cancel(self): 63 | return True 64 | 65 | def cancel(self): 66 | raise NotImplementedError() 67 | 68 | def result(self, timeout=None): 69 | self._recv(timeout) 70 | if self.state == self.STATE_FINISHED: 71 | return self._result 72 | elif self.state == self.STATE_EXCEPTION: 73 | raise self._exception 74 | else: 75 | assert self.state == self.STATE_CANCELLED 76 | raise CancelledError() 77 | 78 | def exception(self, timeout=None): 79 | self._recv(timeout) 80 | return self._exception 81 | 82 | def done(self): 83 | return self.state in (self.STATE_FINISHED, self.STATE_EXCEPTION, self.STATE_CANCELLED) 84 | 85 | def running(self): 86 | return self.state == self.STATE_RUNNING 87 | 88 | def cancelled(self): 89 | warn("cancelled not implemented") 90 | return False 91 | 92 | def add_done_callback(self, callback): 93 | raise NotImplementedError() 94 | 95 | def __repr__(self): 96 | try: 97 | self._recv(timeout=0) 98 | except TimeoutError: 99 | pass 100 | if self.state == self.STATE_FINISHED: 101 | return "<SharedFuture at 0x{:x}: finished, result type {}>".format(id(self), type(self._result)) 102 | elif self.state == self.STATE_EXCEPTION: 103 | return "<SharedFuture at 0x{:x}: failed, exception type {}>".format(id(self), type(self._exception)) 104 | else: 105 | return "<SharedFuture at 0x{:x}: {}>".format(id(self), self.state) 106 | -------------------------------------------------------------------------------- /lib/utils/threading.py: -------------------------------------------------------------------------------- 1 | import time 2 | from concurrent.futures import Future, TimeoutError 3 | from itertools import count 4 | from threading import Thread, Event, Lock 5 | 6 | 7 | def run_in_background(func: callable, *args, **kwargs): 8 | """ run f(*args, **kwargs) in background and return Future for its outputs """ 9 | future = Future() 10 | 11 | def _run(): 12 | try: 13 | future.set_result(func(*args, **kwargs)) 14 | except Exception as e: 15 | future.set_exception(e) 16 | 17 | Thread(target=_run).start() 18 | return future 19 | 20 | 21 | def repeated(func:
callable, n_times=None): 22 | """ Wraps :func: so that it runs forever or for a specified number of times; use with run_in_background """ 23 | 24 | def repeat(): 25 | for i in count(): 26 | if n_times is not None and i >= n_times: 27 | break 28 | func() 29 | 30 | return repeat 31 | 32 | 33 | def add_event_callback(event: Event, callback, timeout=None): 34 | """ Add callback that will be executed asynchronously when event is set """ 35 | return Thread(target=lambda: callback() if event.wait(timeout) else None).start() 36 | 37 | 38 | class CountdownEvent(Event): 39 | def __init__(self, count_to: int, initial=0): 40 | """ An event that must be incremented :count_to: times before it is considered set """ 41 | super().__init__() 42 | self.value = initial 43 | self.count_to = count_to 44 | self.lock = Lock() 45 | self.increment(by=0) # trigger set/unset depending on initial value 46 | 47 | def increment(self, by=1): 48 | with self.lock: 49 | self.value += by 50 | if self.value >= self.count_to: 51 | super().set() 52 | else: 53 | super().clear() 54 | return self.value 55 | 56 | def clear(self): 57 | return self.increment(by=-self.value) 58 | 59 | 60 | def await_first(*events: Event, k=1, timeout=None): 61 | """ 62 | wait until at least k (default=1) events are set; return True on success, raise TimeoutError otherwise 63 | # Note: after k successes we manually *set* all events so that the waiting callback threads can exit. 64 | """ 65 | events_done = CountdownEvent(count_to=k) 66 | for event in events: 67 | add_event_callback(event, callback=events_done.increment, timeout=timeout) 68 | 69 | if events_done.wait(timeout=timeout): 70 | [event.set() for event in events] 71 | return True 72 | else: 73 | raise TimeoutError() 74 | 75 | 76 | def run_and_await_k(jobs: callable, k, timeout_after_k=0, timeout_total=None): 77 | """ 78 | Runs all :jobs: asynchronously, waits for at least k of them to finish 79 | :param jobs: functions to call 80 | :param k: how many functions should finish 81 | :param timeout_after_k: after reaching k finished jobs, wait for this long before cancelling 82 | :param timeout_total: if specified, cancel unfinished jobs after this many seconds 83 | :returns: a list of either results or exceptions for each job 84 | """ 85 | assert k <= len(jobs) 86 | start_time = time.time() 87 | min_successful_jobs = CountdownEvent(count_to=k) 88 | max_failed_jobs = CountdownEvent(count_to=len(jobs) - k + 1) 89 | 90 | def _run_and_increment(run_job: callable): 91 | try: 92 | result = run_job() 93 | min_successful_jobs.increment() 94 | return result 95 | except Exception as e: 96 | max_failed_jobs.increment() 97 | return e 98 | 99 | def _run_and_await(run_job: callable): 100 | # call function asynchronously. Increment counter after finished 101 | future = run_in_background(_run_and_increment, run_job) 102 | 103 | try: # await for success counter to reach k OR for fail counter to reach n - k + 1 104 | await_first(min_successful_jobs, max_failed_jobs, 105 | timeout=None if timeout_total is None else timeout_total - time.time() + start_time) 106 | except TimeoutError as e: # counter didn't reach k jobs in timeout_total 107 | return future.result() if future.done() else e 108 | 109 | try: # await for subsequent jobs if asked to 110 | return future.result(timeout=timeout_after_k) 111 | except TimeoutError as e: 112 | future.cancel() 113 | return e 114 | 115 | except Exception as e: # job failed with exception. Ignore it.
116 | return e 117 | 118 | results = [run_in_background(_run_and_await, f) for f in jobs] 119 | results = [result.result() for result in results] 120 | if min_successful_jobs.is_set(): 121 | return results 122 | elif max_failed_jobs.is_set(): 123 | raise ValueError("Could not get enough results: too many jobs failed.") 124 | else: 125 | raise TimeoutError("Could not get enough results: reached timeout_total.") 126 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | joblib 3 | numpy 4 | requests 5 | tqdm 6 | kademlia 7 | prefetch_generator -------------------------------------------------------------------------------- /scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mryab/learning-at-home/bfcaaf9df4b43f6f400c1581d26ce85ae02661a4/scheme.png -------------------------------------------------------------------------------- /scheme_pad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mryab/learning-at-home/bfcaaf9df4b43f6f400c1581d26ce85ae02661a4/scheme_pad.png --------------------------------------------------------------------------------
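As a quick illustration of how the utilities in `lib/utils/shared_future.py` and `lib/utils/threading.py` fit together, here is a minimal usage sketch. It is not a file from this snapshot: it assumes the repository root is on `PYTHONPATH` (so that `lib.utils` is importable) and that the `fork` start method is available, as on Linux.

```python
# Minimal usage sketch (an illustrative example under the assumptions above,
# not part of the snapshot).
import multiprocessing as mp
import time
from functools import partial

from lib.utils.shared_future import SharedFuture
from lib.utils.threading import run_and_await_k


def slow_job(delay):
    time.sleep(delay)
    return delay


def worker(future):
    # the child process fulfils its side of the linked future pair
    future.set_result(42)


if __name__ == '__main__':
    # run four jobs and wait until at least three of them finish;
    # the 5-second job is reported as a TimeoutError and keeps running in the background
    jobs = [partial(slow_job, d) for d in (0.1, 0.2, 0.3, 5.0)]
    print(run_and_await_k(jobs, k=3, timeout_after_k=0, timeout_total=1.0))

    # pass a result between two processes through a pair of linked futures
    local_side, remote_side = SharedFuture.make_pair()
    child = mp.get_context('fork').Process(target=worker, args=(remote_side,))
    child.start()
    print(local_side.result())  # 42
    child.join()
```

Note that `SharedFuture.make_pair` communicates over a `multiprocessing.Pipe`, so either side of the pair can fulfil or read the future, while `run_and_await_k` returns per-job results or exceptions once at least `k` jobs have succeeded (and raises if too many of them fail).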