├── .gitignore ├── LICENSE ├── README.md ├── basic_utils ├── __init__.py ├── dist_util.py ├── logger.py └── sampler.py ├── config ├── __init__.py ├── base.py └── train.py ├── data ├── __init__.py └── dataset.py ├── models └── __init__.py ├── requirements.txt ├── run_train.sh ├── train.py └── utils ├── __init__.py ├── initialization.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Manual ### 2 | # Add extra ignored paths on your own 3 | 4 | ### WandB and Torch ### 5 | wandb/ 6 | **/wandb/ 7 | *.pth 8 | *.pt 9 | 10 | ### IDE ### 11 | .vscode 12 | .idea/ 13 | *.iml 14 | modules.xml 15 | *.ipr 16 | 17 | ### Python Cache ### 18 | *.log 19 | *.pyc 20 | __pycache__/ 21 | 22 | ### Python Virtual Environments ### 23 | .env 24 | .venv 25 | env/ 26 | venv/ 27 | ENV/ 28 | env.bak/ 29 | venv.bak/ 30 | 31 | ### OS Related ### 32 | *~ 33 | .fuse_hidden* 34 | .directory 35 | .Trash-* 36 | .nfs* 37 | .DS_Store 38 | .AppleDouble 39 | .LSOverride 40 | Icon 41 | ._* 42 | .DocumentRevisions-V100 43 | .fseventsd 44 | .Spotlight-V100 45 | .TemporaryItems 46 | .Trashes 47 | .VolumeIcon.icns 48 | .com.apple.timemachine.donotpresent 49 | .AppleDB 50 | .AppleDesktop 51 | Network Trash Folder 52 | Temporary Items 53 | .apdisk 54 | Thumbs.db 55 | Thumbs.db:encryptable 56 | ehthumbs.db 57 | ehthumbs_vista.db 58 | *.stackdump 59 | [Dd]esktop.ini 60 | $RECYCLE.BIN/ 61 | *.lnk 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 kdha0727 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies 
of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch pipeline with `torch.distributed` 2 | 3 | `utils.trainer.TrainLoop` will run training loop - compatible with torch.distributed.run 4 | 5 | ## All you need to do 6 | 7 | * Complete `config/train.py`'s `TrainSettings`(or `YourSettings`) class. 8 | * this setting class is compatible with argparse and json. 9 | * Complete `data/__init__.py`'s `load_data_from_args` function. 10 | * Complete `model` package. 11 | * Complete `utils/initialization.py`'s `create_model_from_config` function. 12 | * Complete some method of `utils/trainer.py`'s `TrainLoop` class. 13 | * `log_loss_dict` method: logging function of loss values dict. 14 | * `compute_losses` method: calculate `losses` from `micro_batch` and TrainLoop vars 15 | * `backward_from_losses` method: make single `loss` from `losses`, and run `loss.backward()` 16 | * `__init__` method: add your extra values to TrainLoop vars if needed. 17 | * Complete `train.py` to make sense with all code signatures you modified. 
 18 | * Copy the default train settings into a JSON file with the command below, then modify that file: 19 | ``` 20 | python3 -c "from config import TrainSettings as T; print(T().json(indent=2))" >> train_config.json 21 | ``` 22 | 23 |
24 | View simplest train.py script example: 25 | 26 | ```python 27 | from torch.distributed.elastic.multiprocessing.errors import record 28 | 29 | 30 | def main(): 31 | 32 | import os 33 | import torch 34 | from basic_utils import dist_util 35 | 36 | if os.getenv("LOCAL_RANK", None) and not dist_util.is_initialized(): 37 | dist_util.setup_dist() 38 | with dist_util.with_dist_cleanup(): 39 | main() 40 | return 41 | rank = dist_util.get_rank() 42 | dist_util.barrier() 43 | 44 | class Model(torch.nn.Module): 45 | 46 | def __init__(self): 47 | super().__init__() 48 | self.param = torch.nn.Parameter(torch.ones(1)) 49 | 50 | def forward(self, x, target=None): 51 | output = self.param * x 52 | if target is not None: 53 | return (target - output) ** 2 54 | return output 55 | 56 | model = Model() 57 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-1) 58 | 59 | model.to(dist_util.dev()) 60 | dist_util.barrier() 61 | 62 | if dist_util.is_initialized(): 63 | ddp_kwargs = dict( 64 | broadcast_buffers=False, 65 | bucket_cap_mb=128, 66 | find_unused_parameters=False, 67 | ) 68 | if torch.cuda.is_available(): 69 | ddp_kwargs.update(device_ids=[dist_util.dev()], output_device=dist_util.dev()) 70 | ddp_model = torch.nn.parallel.DistributedDataParallel(model, **ddp_kwargs) 71 | else: 72 | ddp_model = model 73 | 74 | dist_util.sequential_print("rank", rank, "param :", model.param.data.item()) 75 | dist_util.print_master_node() 76 | 77 | data = torch.ones(1, device=dist_util.dev()) * (dist_util.get_rank() + 1) 78 | target = torch.ones(1, device=dist_util.dev()) * 0.5 * (dist_util.get_rank() + 1) 79 | 80 | with ddp_model.no_sync() if dist_util.is_initialized() else dist_util.dummy_context(): 81 | loss = ddp_model(data, target) 82 | dist_util.sequential_print("rank", rank, "loss :", loss.item()) 83 | dist_util.print_master_node() 84 | 85 | loss.backward() 86 | dist_util.sequential_print("rank", rank, "grad :", model.param.grad.item()) 87 | dist_util.print_master_node() 88 | 89 | 
loss = ddp_model(data, target) 90 | dist_util.sequential_print("rank", rank, "loss :", loss.item()) 91 | dist_util.print_master_node() 92 | 93 | loss.backward() 94 | dist_util.sequential_print("rank", rank, "sync_grad :", model.param.grad.item()) 95 | dist_util.print_master_node() 96 | 97 | optimizer.step() 98 | dist_util.sequential_print("rank", rank, "updated_param :", model.param.data.item()) 99 | dist_util.barrier() 100 | 101 | 102 | if __name__ == "__main__": 103 | record(main)() 104 | 105 | ``` 106 | 107 | Execute it with... 108 | 109 | ```bash 110 | torchrun --nproc_per_node gpu train.py 111 | ``` 112 | 113 | Or without distributed training... 114 | 115 | ```bash 116 | python3 train.py 117 | ``` 118 | 119 |
120 | 121 | ## How to run 122 | 123 | after completion, you can run train script with 124 | 125 | ```bash 126 | torchrun --nproc_per_node gpu train.py --config_json train_config.json 127 | ``` 128 | 129 | ## Citations 130 | 131 | ```bibtex 132 | @inproceedings{gong2022diffuseq, 133 | author = {Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng}, 134 | booktitle = {International Conference on Learning Representations, ICLR}, 135 | title = {{DiffuSeq}: Sequence to Sequence Text Generation with Diffusion Models}, 136 | year = 2023 137 | } 138 | ``` 139 | -------------------------------------------------------------------------------- /basic_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kdha0727/distributed-pipeline/ed11369eebbd2f59655ad47affec436ed0b64284/basic_utils/__init__.py -------------------------------------------------------------------------------- /basic_utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 kdha0727 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy 4 | # of this software and associated documentation files (the "Software"), to deal 5 | # in the Software without restriction, including without limitation the rights 6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | # copies of the Software, and to permit persons to whom the Software is 8 | # furnished to do so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
def is_initialized() -> bool:
    """
    Report whether the torch.distributed (c10d) runtime is initialized.

    Safe to call on PyTorch builds that ship without distributed support:
    in that case the answer is simply False.
    """
    if not dist.is_available():
        return False
    # When PyTorch is compiled without c10d, `is_initialized` is omitted from
    # the torch.distributed namespace, so look it up defensively.
    probe = getattr(dist, "is_initialized", None)
    if probe is None:
        return False
    return probe()
@contextlib.contextmanager
def with_dist_cleanup():
    """
    Context Manager or Decorator version of cleanup_dist().

    cleanup_dist() is called when the managed block exits, including when the
    block raises, so a failing run does not leave the process group alive.
    The exception itself is not swallowed; it propagates after cleanup.
    """
    try:
        yield
    finally:
        # Must run on the error path too, hence `finally`.
        cleanup_dist()
def broadcast(
        tensor: torch.Tensor,
        src: Optional[int] = 0,
        group: Optional[dist.ProcessGroup] = None,
        async_op: bool = False
) -> Optional[Any]:
    """
    Synchronize a Tensor across ranks from {src} rank. (default=0)
    Does nothing (and returns None) when the distributed runtime is not initialized.
    :param tensor: a single torch.Tensor; it is passed straight to torch.distributed.broadcast.
    :param src: source rank to sync params from. default is 0.
    :param group: process group forwarded to torch.distributed.broadcast.
    :param async_op: forwarded to torch.distributed.broadcast.
    :return: whatever torch.distributed.broadcast returns. With async_op=True this is
        expected to be the async work handle the caller must wait on (see the
        torch.distributed docs); previously it was discarded, so an asynchronous
        broadcast could never be awaited.
    """
    if not is_initialized():
        return None
    with torch.no_grad():
        return dist.broadcast(tensor, src, group=group, async_op=async_op)
249 | Arguments are passed to print function. 250 | """ 251 | if get_rank() == 0: 252 | print(*args, **kwargs) 253 | barrier() 254 | 255 | 256 | @contextlib.contextmanager 257 | def dummy_context() -> ContextManager[None]: 258 | """ 259 | Dummy context manager. 260 | """ 261 | yield 262 | 263 | 264 | @contextlib.contextmanager 265 | def no_sync(*modules: DistributedDataParallel) -> ContextManager[None]: 266 | """ 267 | Context Manager or Decorator of multiple modules' no_sync(). 268 | Usage 269 | with no_sync(ddp_module_1, ddp_module_2, ddp_module_3, ddp_module_n): 270 | ... # your code 271 | """ 272 | if not modules: 273 | yield 274 | else: 275 | with contextlib.ExitStack() as stk: 276 | for module in modules: 277 | if not isinstance(module, DistributedDataParallel): 278 | raise TypeError( 279 | "arg {!r} should be instance of DistributedDataParallel, got {}" 280 | .format(module, type(module)) 281 | ) 282 | stk.enter_context(module.no_sync()) 283 | yield 284 | 285 | 286 | try: 287 | # patch some os module functions - since file io uses os.sync 288 | import nt # NOQA 289 | os.sync = nt.sync = lambda: None # signature: () -> None 290 | except ImportError: 291 | pass 292 | -------------------------------------------------------------------------------- /basic_utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | from abc import abstractmethod, ABC 7 | import os 8 | import sys 9 | import json 10 | import time 11 | import datetime 12 | import tempfile 13 | import warnings 14 | from collections import defaultdict 15 | from contextlib import contextmanager 16 | import wandb 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(ABC): 27 | @abstractmethod 28 
class HumanOutputFormat(KVWriter, SeqWriter):
    """
    Writer that renders key/value tables and plain text lines in a human-readable form.

    :param filename_or_file: either a path (str), which is opened for text writing and
        closed again by close(), or an already open writable stream (e.g. sys.stdout),
        which is used as-is and never closed by this object.
    """

    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "wt")
            self.own_file = True
        else:
            # This class only ever writes to the stream, so `write` is the
            # capability to check for (the previous check looked for `read`).
            # The 1-tuple keeps %-formatting correct when the argument is itself a tuple.
            assert hasattr(filename_or_file, "write"), (
                "expected file or str, got %s" % (filename_or_file,)
            )
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        """Write the mapping `kvs` as a boxed two-column table and flush the stream."""
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, "__float__"):
                valstr = "%-8.3g" % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if not key2str:
            print("WARNING: tried to write empty key-value dict")
            return
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = "-" * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append(
                "| %s%s | %s%s |"
                % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
            )
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    @staticmethod
    def _truncate(s):
        """Shorten `s` to at most 30 characters, marking the cut with '...'."""
        maxlen = 30
        return s[: maxlen - 3] + "..." if len(s) > maxlen else s

    def writeseq(self, seq):
        """Write the strings in `seq` separated by single spaces, then a newline, and flush."""
        self.file.write(" ".join(seq))
        self.file.write("\n")
        self.file.flush()

    def close(self):
        """Close the underlying file, but only if this object opened it."""
        if self.own_file:
            self.file.close()
def make_output_format(format, ev_dir, log_suffix=""):
    """
    Build the writer object for one output format name.

    `ev_dir` is created if missing. `format` is one of "stdout", "log", "json",
    "csv" or "tensorboard"; file-backed writers are placed inside `ev_dir` with
    `log_suffix` inserted into the file (or directory) name.
    Raises ValueError for any other format name.
    """
    os.makedirs(ev_dir, exist_ok=True)

    def located(template):
        # Path inside ev_dir with the suffix substituted into the name template.
        return os.path.join(ev_dir, template % log_suffix)

    if format == "stdout":
        return HumanOutputFormat(sys.stdout)
    if format == "log":
        return HumanOutputFormat(located("log%s.txt"))
    if format == "json":
        return JSONOutputFormat(located("progress%s.json"))
    if format == "csv":
        return CSVOutputFormat(located("progress%s.csv"))
    if format == "tensorboard":
        return TensorBoardOutputFormat(located("tb%s"))
    raise ValueError("Unknown format specified: %s" % (format,))
================================================================ 213 | 214 | 215 | def logkv(key, val): 216 | """ 217 | Log a value of some diagnostic 218 | Call this once for each diagnostic quantity, each iteration 219 | If called many times, last value will be used. 220 | """ 221 | get_current().logkv(key, val) 222 | 223 | 224 | def logkv_mean(key, val): 225 | """ 226 | The same as logkv(), but if called many times, values averaged. 227 | """ 228 | get_current().logkv_mean(key, val) 229 | 230 | 231 | def logkvs(d): 232 | """ 233 | Log a dictionary of key-value pairs 234 | """ 235 | for (k, v) in d.items(): 236 | logkv(k, v) 237 | 238 | 239 | def dumpkvs(): 240 | """ 241 | Write all of the diagnostics from the current iteration 242 | """ 243 | return get_current().dumpkvs() 244 | 245 | 246 | def getkvs(): 247 | return get_current().name2val 248 | 249 | 250 | def log(*args, level=INFO): 251 | """ 252 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 253 | """ 254 | get_current().log(*args, level=level) 255 | 256 | 257 | def debug(*args): 258 | log(*args, level=DEBUG) 259 | 260 | 261 | def info(*args): 262 | log(*args, level=INFO) 263 | 264 | 265 | def warn(*args): 266 | log(*args, level=WARN) 267 | 268 | 269 | def error(*args): 270 | log(*args, level=ERROR) 271 | 272 | 273 | def set_level(level): 274 | """ 275 | Set logging threshold on current logger. 276 | """ 277 | get_current().set_level(level) 278 | 279 | 280 | def set_comm(comm): 281 | get_current().set_comm(comm) 282 | 283 | 284 | def get_dir(): 285 | """ 286 | Get directory that log files are being written to. 
def profile(n):
    """
    Decorator factory that times every call of the decorated function under profile_kv(n).

    Usage:
    @profile("my_func")
    def my_func(): code

    The wrapper keeps the decorated function's metadata (__name__, __doc__, ...).
    """
    # Local import: only this function needs functools.
    from functools import wraps

    def decorator_with_name(func):
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name
(See right below class definition) 337 | # So that you can still log to the terminal without setting up any output files 338 | CURRENT = None # Current logger being used by the free functions above 339 | 340 | def __init__(self, dir, output_formats, comm=None): 341 | self.name2val = defaultdict(float) # values this iteration 342 | self.name2cnt = defaultdict(int) 343 | self.level = INFO 344 | self.dir = dir 345 | self.output_formats = output_formats 346 | self.comm = comm 347 | 348 | # Logging API, forwarded 349 | # ---------------------------------------- 350 | def logkv(self, key, val): 351 | self.name2val[key] = val 352 | 353 | def logkv_mean(self, key, val): 354 | oldval, cnt = self.name2val[key], self.name2cnt[key] 355 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 356 | self.name2cnt[key] = cnt + 1 357 | 358 | def dumpkvs(self): 359 | if self.comm is None: 360 | d = self.name2val 361 | else: 362 | d = mpi_weighted_mean( 363 | self.comm, 364 | { 365 | name: (val, self.name2cnt.get(name, 1)) 366 | for (name, val) in self.name2val.items() 367 | }, 368 | ) 369 | if self.comm.rank != 0: 370 | d["dummy"] = 1 # so we don't get a warning about empty dict 371 | # LISA 372 | out = d.copy() # Return the dict for unit testing purposes 373 | if int(os.environ['LOCAL_RANK']) == 0: 374 | wandb.log({**d}) 375 | for fmt in self.output_formats: 376 | if isinstance(fmt, KVWriter): 377 | fmt.writekvs(d) 378 | self.name2val.clear() 379 | self.name2cnt.clear() 380 | return out 381 | 382 | def log(self, *args, level=INFO): 383 | if self.level <= level: 384 | self._do_log(args) 385 | 386 | # Configuration 387 | # ---------------------------------------- 388 | def set_level(self, level): 389 | self.level = level 390 | 391 | def set_comm(self, comm): 392 | self.comm = comm 393 | 394 | def get_dir(self): 395 | return self.dir 396 | 397 | def close(self): 398 | for fmt in self.output_formats: 399 | fmt.close() 400 | 401 | # Misc 402 | # 
def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean on the process whose comm.rank is 0, an empty dict elsewhere.

    Values that cannot be converted with float() are skipped with a warning.
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank != 0:
        return {}

    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for (name, (val, count)) in n2vc.items():
            try:
                val = float(val)
            except (TypeError, ValueError):
                # float() raises TypeError (not ValueError) for None, lists, etc.;
                # both mean "not averageable", so warn and skip instead of crashing.
                warnings.warn(
                    "WARNING: tried to compute mean on non-float {}={}".format(
                        name, val
                    )
                )
            else:
                name2sum[name] += val * count
                name2count[name] += count
    return {name: name2sum[name] / name2count[name] for name in name2sum}
class InfiniteSampler(torch.utils.data.Sampler):
    """
    Sampler that yields dataset indices forever, sharded across replicas.

    A global position counter walks the index order cyclically; a replica yields
    the entry at every position whose counter value modulo `num_replicas` equals
    its `rank`. With `shuffle=True` the order starts as a seeded permutation and,
    after each position, that entry is swapped with one at most `window - 1`
    positions behind it (cyclically), where `window = round(len(dataset) * window_size)`.
    The swap is performed on every position regardless of rank, so the evolving
    order does not depend on which replica is iterating.

    :param dataset: sized dataset; must be non-empty.
    :param num_replicas: number of replicas; None means dist_util.get_world_size().
    :param rank: index of this replica; None means dist_util.get_rank().
    :param shuffle: whether to permute (and keep re-mixing) the index order.
    :param seed: base seed; the generator is seeded with `seed + epoch`.
    :param window_size: fraction of the dataset length used as the swap window, in [0, 1].
    """

    def __init__(
            self,
            dataset, num_replicas: int = 1,
            rank: int = None, shuffle: bool = True,
            seed: int = 0, window_size: float = 0.5
    ) -> None:
        # Resolve the None fallbacks BEFORE validating: the default rank=None used
        # to hit `0 <= rank` and raise TypeError, making the default unusable.
        if num_replicas is None or rank is None:
            from . import dist_util  # imported lazily; only the fallbacks need it
            if num_replicas is None:
                num_replicas = dist_util.get_world_size()
            if rank is None:
                rank = dist_util.get_rank()
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        # NOTE(review): the base-class constructor signature appears to differ
        # between torch releases (some require the unused data_source argument,
        # others deprecate it); try the argument-free form first - confirm
        # against the torch versions you support.
        try:
            super().__init__()
        except TypeError:
            super().__init__(dataset)
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size
        # __iter__ seeds its generator with seed + epoch; this attribute was never
        # initialised, so shuffle=True raised AttributeError on first iteration.
        self.epoch = 0

    def __iter__(self):
        length = len(self.dataset)
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            order = torch.randperm(length, generator=g).tolist()  # type: ignore[arg-type]
            window = round(length * self.window_size)
        else:
            g = None
            order = list(range(length))  # type: ignore[arg-type]
            window = 0

        idx = 0
        while True:
            i = idx % length
            if idx % self.num_replicas == self.rank:
                yield order[i]
            if window >= 2:
                # Swap the current slot with a random slot up to window-1 positions
                # back (cyclically). Done for every idx, not only this rank's.
                offset = int(torch.randint(window, size=(), generator=g))
                j = (i - offset) % length
                order[i], order[j] = order[j], order[i]
            idx += 1
def _make_choice_parser(name, choices):
    """Build an argparse ``type`` callable mapping a CLI string onto one of ``choices``."""

    def parse(arg):
        for ch in choices:
            if str(ch) == arg:
                return ch
        raise ValueError

    # argparse prints the callable's __name__ in "invalid <name> value" errors.
    parse.__name__ = parse.__qualname__ = name
    return parse


class ArgparseCompatibleBaseModel(BaseModel):
    """
    Pydantic model that can be converted to, and built from, argparse arguments.

    Every field becomes a ``--<field name>`` option; fields whose type is itself
    a ``BaseModel`` subclass are expanded into an argument group of their own.
    """

    class Config:  # Override this to allow extra kwargs
        extra = "forbid"

    @classmethod
    def from_argparse(cls, namespace, __top=True):
        """
        Build an instance from a parsed ``argparse.Namespace`` (or a plain dict).

        Values are consumed field by field; at the top level an assertion checks
        that no unknown keys are left over. The caller's namespace/dict is not
        modified: the top-level call works on a shallow copy, which nested
        model fields then share so the left-over check still sees every key.
        """
        if not isinstance(namespace, dict):
            namespace = vars(namespace)
        if __top:
            namespace = dict(namespace)
        kwargs = {}
        for name, field in cls.__fields__.items():
            if isinstance(field.type_, type) and issubclass(field.type_, BaseModel):
                kwargs[name] = ArgparseCompatibleBaseModel.from_argparse.__func__(field.type_, namespace, __top=False)  # NOQA
            else:
                kwargs[name] = namespace.pop(name)
        assert not (__top and namespace), str(namespace)
        return cls(**kwargs)

    @classmethod
    def to_argparse(cls, parser_or_group=None):
        """
        Register one ``--<field>`` option per model field.

        :param parser_or_group: parser or argument group to extend; a new
            ``ArgumentParser`` with default-showing help is created when omitted.
        :return: the parser or group that received the arguments.
        """
        if parser_or_group is None:
            parser_or_group = Ap(formatter_class=Df)
        for name, field in cls.__fields__.items():
            if isinstance(field.type_, type) and issubclass(field.type_, BaseModel):
                group = parser_or_group.add_argument_group(name)
                ArgparseCompatibleBaseModel.to_argparse.__func__(field.type_, group)  # NOQA
                continue
            kw = dict(dest=name, type=field.type_, default=field.default,
                      help=field.field_info.description, required=field.required)
            if getattr(field.type_, '__origin__', None) is Literal:
                choices = tuple(get_args(field.outer_type_))
                # A closure replaces the former exec()-generated function.
                kw.update(type=_make_choice_parser(name, choices), choices=choices,
                          metavar="{" + ", ".join(map(str, choices)) + "}")
            elif isinstance(field.type_, type) and issubclass(field.type_, bool):
                kw.update(type=bool_validator, metavar="{true, false}")
            parser_or_group.add_argument("--" + name, **kw)
        return parser_or_group

    @classmethod
    def from_argv(cls, argv=None):
        """Parse ``argv`` (``sys.argv[1:]`` when None) and build an instance from it."""
        return cls.from_argparse(cls.to_argparse().parse_args(argv))
@final
class TrainSettings(
    YourSettings,
    # TODO: inherit setting classes "reversely", due to python mro.
    DataSettings,
    GeneralSettings
):
    """
    Complete training configuration, merged from all setting classes.

    It can be filled from individual command-line options or, when the parser
    was built with ``add_json=True``, from a JSON file given by
    ``--config_json``. If ``--config_json`` is supplied it takes precedence and
    the individual options are ignored.
    """

    @classmethod
    def to_argparse(cls, parser_or_group=None, add_json=False):
        """
        Register all settings as command-line options.

        :param parser_or_group: parser or group to extend; created when omitted.
        :param add_json: additionally register ``--config_json`` and put every
            option into a "settings" argument group.
        :return: the parser (or group) that was extended.
        """
        if not add_json:
            return super(TrainSettings, cls).to_argparse(parser_or_group)
        if parser_or_group is None:
            parser_or_group = Ap(formatter_class=Df)
        setting_group = parser_or_group.add_argument_group(title="settings")
        setting_group.add_argument(
            "--config_json", type=str, required=False,
            help="You can alter arguments all below by config_json file.")
        # The settings go straight into the group. Registering them inside one
        # mutually exclusive group made argparse reject any command line that
        # set two different settings at once.
        super(TrainSettings, cls).to_argparse(setting_group)
        return parser_or_group

    @classmethod
    def from_argparse(cls, namespace, __top=True):
        """
        Build the settings from a parsed namespace (or dict).

        A truthy ``config_json`` entry makes the settings load from that file;
        otherwise the remaining entries are used as field values. The caller's
        namespace is left untouched.
        """
        values = dict(namespace) if isinstance(namespace, dict) else dict(vars(namespace))
        config_json = values.pop("config_json", None)
        if config_json:
            return cls.parse_file(config_json)
        return super(TrainSettings, cls).from_argparse(values)
= 0, 9 | window_size: float = 0.5, 10 | num_loader_proc: int = 1, 11 | ): 12 | from basic_utils import dist_util 13 | from basic_utils.sampler import InfiniteSampler 14 | from torch.utils.data.distributed import DistributedSampler 15 | from torch.utils.data import DataLoader 16 | from .dataset import CustomDataset 17 | 18 | rank = dist_util.get_rank() 19 | num_replicas = dist_util.get_world_size() 20 | 21 | # TODO: add your data-loading function from split & data_dir & your arguments 22 | dataset_kwargs = dict(data_dir=data_dir, split=split) # TODO 23 | dataset = CustomDataset(**dataset_kwargs) 24 | 25 | sampler_kwargs = dict( 26 | num_replicas=num_replicas, 27 | rank=rank, 28 | shuffle=not deterministic, 29 | seed=hash(seed) 30 | ) 31 | if loop: 32 | sampler = InfiniteSampler(dataset, **sampler_kwargs, window_size=window_size) 33 | else: 34 | sampler = DistributedSampler(dataset, **sampler_kwargs, drop_last=False) 35 | 36 | return DataLoader( 37 | dataset, 38 | batch_size=batch_size // num_replicas, 39 | sampler=sampler, 40 | num_workers=num_loader_proc, 41 | persistent_workers=num_loader_proc > 0, 42 | # pin_memory=True, 43 | ) 44 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # Add your dataset configuration 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class CustomDataset(Dataset): 6 | # TODO: implement this class 7 | 8 | def __init__(self, *args, **kwargs): 9 | raise NotImplementedError 10 | 11 | def __getitem__(self, item): 12 | raise NotImplementedError 13 | 14 | def __len__(self): 15 | raise NotImplementedError 16 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 
def train(rank, args):
    """
    Run one training process: set up paths, logging, data, model and the loop.

    :param rank: rank of this process; rank 0 creates the checkpoint
        directories and additionally logs to stdout.
    :param args: settings object; it is read by attribute and through
        ``args.dict()``, and ``args.checkpoint_path`` is filled in here when empty.
    """

    # Import dependencies
    import time
    import json

    # Import everything
    from data import load_data_from_args
    from basic_utils import dist_util, logger
    from utils.initialization import create_model_from_config, seed_all
    from utils.trainer import TrainLoop

    dist_util.barrier()  # Sync

    # Set checkpoint path
    folder_name = "model_checkpoints/"
    if not os.path.isdir(folder_name) and rank == 0:
        os.mkdir(folder_name)
    if not args.checkpoint_path:
        # NOTE(review): the timestamp is evaluated independently in every
        # process, so ranks could end up with different paths if the clock
        # second changes between them - confirm whether this needs a broadcast.
        model_file = f"Run_{args.dataset}_lr{args.lr}" \
                     f"_seed{args.seed}_{time.strftime('%Y%m%d-%H:%M:%S')}"  # TODO: add your naming rule by args
        args.checkpoint_path = os.path.join(folder_name, model_file)
    if not os.path.isdir(args.checkpoint_path) and rank == 0:
        os.mkdir(args.checkpoint_path)

    # Configure log and seed
    # Every rank writes "log" and "csv" output; only rank 0 also prints to stdout.
    logger.configure(dir=args.checkpoint_path, format_strs=["log", "csv"] + (["stdout"] if rank == 0 else []))
    seed_all(args.seed)

    # Prepare dataloader
    logger.log("### Creating data loader...")
    dist_util.barrier()  # Sync
    data = load_data_from_args(
        split='train',
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        # TODO: add args on your own
        deterministic=False,
        loop=True,
        seed=args.seed,
        num_loader_proc=args.data_loader_workers,
    )
    # Validation data uses the same arguments but deterministic (unshuffled) order.
    data_valid = load_data_from_args(
        split='valid',
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        # TODO: add args on your own
        deterministic=True,
        loop=True,
        seed=args.seed,
        num_loader_proc=args.data_loader_workers,
    )
    dist_util.barrier()  # Sync

    # Initialize model
    logger.log("### Creating model...")
    model = create_model_from_config(**args.dict())

    # Load model to each node's device
    model.to(dist_util.dev())
    dist_util.barrier()  # Sync

    # Count and log total params
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.log(f'### The parameter count is {pytorch_total_params}')

    # Save training args
    # The message is logged by every rank, but only rank 0 writes the file.
    training_args_path = f'{args.checkpoint_path}/training_args.json'
    if not os.path.exists(training_args_path):
        logger.log(f'### Saving the hyperparameters to {args.checkpoint_path}/training_args.json')
        if dist_util.get_rank() == 0:
            with open(training_args_path, 'w') as fp:
                json.dump(args.dict(), fp, indent=2)

    # Init wandb
    if dist_util.get_rank() == 0:
        # TODO: Uncomment and customize your wandb setting on your own, or just use environ.
        import wandb
        wandb.init(
            mode=os.getenv("WANDB_MODE", "online"),
            # entity=os.getenv("WANDB_ENTITY", ""),
            # project=os.getenv("WANDB_PROJECT", ""),
        )
        # NOTE(review): uses args.__dict__ here while the JSON dump above uses
        # args.dict() - confirm both expose the same keys.
        wandb.config.update(args.__dict__, allow_val_change=True)
    dist_util.barrier()  # Sync last

    # Run train loop
    logger.log("### Training...")
    train_loop = TrainLoop(
        model=model,
        # TODO: add your argument
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        weight_decay=args.weight_decay,
        learning_steps=args.learning_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=data_valid,
        eval_interval=args.eval_interval,
        eval_callbacks=[]
    )
    train_loop.run_loop()
class TrainLoop:
    """
    Training loop with micro-batching, EMA weights, checkpointing and optional DDP.

    Three hooks are template TODOs and must be completed by the user:
    ``log_loss_dict``, ``compute_losses`` and ``backward_from_losses``.
    """

    def log_loss_dict(self, mode, losses, *args, **kwargs):  # mode: train or eval
        """Log the values in ``losses``; ``mode`` is ``"train"`` or ``"eval"``."""
        # TODO: Add your loss logging function with logger.logkv_mean
        raise NotImplementedError

    def compute_losses(self, micro_batch):
        """Compute the losses for one micro-batch (already moved to the device)."""
        # TODO: Add your loss function with logger.logkv_mean
        raise NotImplementedError

    @staticmethod
    def backward_from_losses(losses):
        """Reduce ``losses`` to a single loss and back-propagate it."""
        # TODO: Add your loss-reducing function
        loss = ...
        loss.backward()

    @classmethod
    def get_batch_length(cls, batch):
        """
        Return the number of samples in ``batch``.

        Tensors report ``shape[0]``; dicts, lists and tuples are resolved
        through their first value/element.

        :raises TypeError: for any other batch type.
        """
        if isinstance(batch, torch.Tensor):
            return batch.shape[0]
        elif isinstance(batch, dict):
            return cls.get_batch_length(next(iter(batch.values())))
        elif isinstance(batch, (list, tuple)):
            return cls.get_batch_length(batch[0])
        else:
            # TODO: Add your custom function if needed
            raise TypeError("Unsupported batch type: {}".format(type(batch).__name__))

    def __init__(
            self,
            *,
            model,
            # TODO: Add your arguments
            data,
            batch_size,
            microbatch,
            lr,
            ema_rate,
            log_interval,
            save_interval,
            resume_checkpoint,
            weight_decay=0.0,
            learning_steps=0,
            checkpoint_path='',
            gradient_clipping=-1.,
            eval_data=None,
            eval_interval=-1,
            eval_callbacks=(),  # e.g. plotting utils for wandb
    ):
        """
        :param model: module to train.
        :param data: iterable (or iterator) of training batches.
        :param batch_size: per-step batch size.
        :param microbatch: micro-batch size; values <= 0 mean "use batch_size".
        :param lr: initial learning rate.
        :param ema_rate: one number, or a comma-separated string of EMA rates.
        :param log_interval: dump logged values every this many steps.
        :param save_interval: save checkpoints every this many steps.
        :param resume_checkpoint: model checkpoint to resume from when the log
            directory does not already contain one.
        :param weight_decay: AdamW weight decay.
        :param learning_steps: total number of steps; 0 means run forever and
            disables learning-rate annealing.
        :param checkpoint_path: directory checkpoints are written to.
        :param gradient_clipping: max gradient norm; values <= 0 disable clipping.
        :param eval_data: optional iterable (or iterator) of evaluation batches.
        :param eval_interval: evaluate every this many steps; values <= 0
            disable evaluation.
        :param eval_callbacks: callables invoked with this loop on rank 0 after
            each evaluation.
        """
        self.model = model
        self.data = data
        self.eval_data = eval_data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = float(lr)
        if isinstance(ema_rate, (int, float)):
            self.ema_rate = [float(ema_rate)]
        else:
            # Tolerate empty segments such as "" or a trailing comma.
            self.ema_rate = [float(x) for x in ema_rate.split(",") if x.strip()]
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.weight_decay = weight_decay
        self.learning_steps = learning_steps
        self.gradient_clipping = gradient_clipping

        # TODO: Add your arguments

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist_util.get_world_size()

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.eval_callbacks = list(eval_callbacks)

        self.checkpoint_path = checkpoint_path  # DEBUG **

        self._load_and_sync_parameters()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self._load_optimizer_state()
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

        if dist_util.is_initialized():
            self.use_ddp = True
            print(dist_util.dev())
            ddp_kwargs = dict(
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
            # DDP only accepts device_ids/output_device for GPU modules; for
            # CPU modules they must stay unset (the check used to be inverted).
            if torch.device(dist_util.dev()).type == "cuda":
                ddp_kwargs.update(device_ids=[dist_util.dev()])
                ddp_kwargs.update(output_device=dist_util.dev())
            self.ddp_model = DistributedDataParallel(self.model, **ddp_kwargs)
        else:
            self.use_ddp = False
            self.ddp_model = self.model

        # In checkpoint-loading process, sometimes GPU 0 is used and allocated.
        # After all process is done, free GPU 0 by function below:
        torch.cuda.empty_cache()

    # # # # # # # # # # # # # # # Implemented Functions # # # # # # # # # # # # # # #

    def _load_and_sync_parameters(self):
        """Load model weights from the resume checkpoint (if any) and sync them."""
        resume_checkpoint = self.find_resume_checkpoint() or self.resume_checkpoint
        if not resume_checkpoint:
            return

        self.resume_step = self.parse_resume_step_from_filename(resume_checkpoint)
        # NOTE(review): loading is done on rank 0 only and then shared through
        # sync_params, following the upstream guided-diffusion layout - confirm.
        if dist_util.get_rank() == 0:
            logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            self.model.load_state_dict(
                dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev())
            )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """Return EMA parameters for ``rate``, loaded from a checkpoint when one exists."""
        ema_params = copy.deepcopy(self.master_params)
        main_checkpoint = self.find_resume_checkpoint() or self.resume_checkpoint
        if not main_checkpoint:
            # Fall back to a copy of the current weights instead of None, which
            # would break the later EMA update.
            return ema_params
        ema_checkpoint = self.find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist_util.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(ema_checkpoint, map_location=dist_util.dev())
                ema_params = self._state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore the optimizer state saved next to the resume checkpoint, if present."""
        main_checkpoint = self.find_resume_checkpoint() or self.resume_checkpoint
        if not main_checkpoint:
            return
        opt_checkpoint = self.find_opt_checkpoint(main_checkpoint, self.resume_step)
        # find_opt_checkpoint already checked existence and returns None when
        # the file is missing, so None must not be handed to bf.exists().
        if opt_checkpoint:
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
            self.opt.load_state_dict(state_dict)

    def run_loop(self):
        """Train until ``learning_steps`` is reached (forever when it is 0)."""
        # iter() makes plain iterables such as a DataLoader usable with next();
        # for objects that already are iterators it returns them unchanged.
        data_iter = iter(self.data)
        do_eval = self.eval_data is not None and self.eval_interval > 0
        eval_iter = iter(self.eval_data) if do_eval else None
        while (
            not self.learning_steps
            or self.step + self.resume_step < self.learning_steps
        ):
            batch = next(data_iter)
            self.run_step(batch)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            if do_eval and self.step % self.eval_interval == 0:
                cond_eval = next(eval_iter)
                self.forward_only(cond_eval)
                print('eval on validation set')
                logger.dumpkvs()
                if dist_util.get_rank() == 0:
                    for callback in self.eval_callbacks:
                        callback(self)
            if self.step > 0 and self.step % self.save_interval == 0:
                self.save()
            self.step += 1
        # Save the last checkpoint if the final step did not already save it.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch):
        """Run forward/backward, the optimizer update and step logging for one batch."""
        self.forward_backward(batch)
        self.optimize()
        self.log_step()

    def _zero_grad(self):
        """Detach and zero the gradients of all model parameters in place."""
        for param in self.model_params:
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()

    def _common_forward(self, total_batch, start_index):
        """
        Compute the losses for the micro-batch starting at ``start_index``.

        ``total_batch`` is treated as a dict of sliceable values here. Under
        DDP, gradient synchronisation is suppressed for every micro-batch
        except the last one of the batch.
        """
        micro_batch = {
            k: v[start_index: start_index + self.microbatch].to(dist_util.dev())
            for k, v in total_batch.items()
        }
        last_batch = (start_index + self.microbatch) >= self.get_batch_length(total_batch)

        if last_batch or not self.use_ddp:
            losses = self.compute_losses(micro_batch)
        else:
            with self.ddp_model.no_sync():
                losses = self.compute_losses(micro_batch)
        return losses

    @torch.no_grad()
    def forward_only(self, batch):
        """Evaluate ``batch`` micro-batch by micro-batch and log the losses as "eval"."""
        self._zero_grad()
        for i in range(0, self.get_batch_length(batch), self.microbatch):
            losses = self._common_forward(batch, i)
            self.log_loss_dict(mode="eval", losses=losses)

    def forward_backward(self, batch):
        """Accumulate gradients over all micro-batches and log the losses as "train"."""
        self._zero_grad()
        for i in range(0, self.get_batch_length(batch), self.microbatch):
            losses = self._common_forward(batch, i)
            # Training losses were previously logged under mode="eval".
            self.log_loss_dict(mode="train", losses=losses)
            self.backward_from_losses(losses)

    def optimize(self):
        """Clip gradients (if enabled), anneal the LR, step the optimizer, update EMAs."""
        if self.gradient_clipping > 0:
            self.grad_clip()
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def grad_clip(self):
        """Clip the gradient norm to ``self.gradient_clipping``."""
        max_grad_norm = self.gradient_clipping
        if hasattr(self.opt, "clip_grad_norm"):
            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
            self.opt.clip_grad_norm(max_grad_norm)
        else:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                max_grad_norm,
            )

    def _anneal_lr(self):
        """Linearly decay the learning rate towards 0 over ``learning_steps``."""
        if not self.learning_steps:
            return
        frac_done = (self.step + self.resume_step) / self.learning_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def _log_grad_norm(self):
        """Log the L2 norm over all parameter gradients as "grad_norm"."""
        sqsum = 0.0
        for p in self.master_params:
            if p.grad is not None:
                sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", math.sqrt(sqsum))

    def log_step(self):
        """Log the global step and the number of samples processed so far."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """Write model, EMA and optimizer checkpoints, then wait for all ranks."""
        self._save_checkpoint(0, self.master_params)
        for r, p in zip(self.ema_rate, self.ema_params):
            self._save_checkpoint(r, p)
        self._save_opt()
        dist_util.barrier()

    def _save_checkpoint(self, rate, params):
        """Save ``params`` as ``model_<step>.pt`` (falsy rate) or ``ema_<rate>_<step>.pt``."""
        state_dict = self._master_params_to_state_dict(params)
        if dist_util.get_rank() == 0:
            logger.log(f"saving model {rate}...")
            if not rate:
                filename = f"model_{(self.step+self.resume_step):06d}.pt"
            else:
                filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
            print('writing to', bf.join(self.checkpoint_path, filename))
            with bf.BlobFile(bf.join(self.checkpoint_path, filename), "wb") as f:
                torch.save(state_dict, f)

    def _save_opt(self):
        """Save the optimizer state as ``opt_<step>.pt`` on rank 0."""
        if dist_util.get_rank() == 0:
            logger.log(f"saving optimizer...")
            filename = f"opt_{(self.step+self.resume_step):06d}.pt"
            print('writing to', bf.join(self.checkpoint_path, filename))
            with bf.BlobFile(bf.join(self.checkpoint_path, filename), "wb") as f:
                torch.save(self.opt.state_dict(), f)

    def _master_params_to_state_dict(self, master_params, key=None):
        """
        Map a parameter list back onto the model's state-dict keys.

        :param master_params: parameters in ``model.named_parameters()`` order.
        :param key: when given, return only the parameter stored under that name.
        :raises KeyError: if ``key`` is given but names no model parameter.
        """
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            if key is not None and key == name:
                return master_params[i]
            state_dict[name] = master_params[i]
        if key is not None:
            raise KeyError(key)
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        """Return the state-dict entries in ``model.named_parameters()`` order."""
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        return params

    @staticmethod
    def parse_resume_step_from_filename(filename):
        """
        Parse filenames of the form path/to/model_NNNNNN.pt, where NNNNNN is the
        checkpoint's number of steps (any number of trailing digits is accepted).

        :raises ValueError: if the name does not end in digits before ``.pt``.
        """
        filename: str = os.path.basename(filename)
        assert filename.startswith('model') and filename[-3:] == '.pt', "Invalid model name"
        stem = filename[:-3]
        # Take every trailing digit rather than a fixed six-character slice,
        # so step counts of 1,000,000 and above are not truncated.
        digits = stem[len(stem.rstrip('0123456789')):]
        return int(digits)

    @staticmethod
    def find_resume_checkpoint():
        """Return the newest ``model*.pt`` in the current log directory, or None."""
        log_dir = logger.get_current().dir
        model_weights = [
            s for s in os.listdir(log_dir)
            if s.endswith(".pt") and s.startswith("model")
        ]
        if not model_weights:
            return None

        def step_of(name):
            try:
                return TrainLoop.parse_resume_step_from_filename(name)
            except ValueError:
                return -1  # names without a step number rank lowest

        # Compare by numeric step (name as tie-breaker) instead of pure string
        # order, which mis-orders step counts of differing digit length.
        latest = max(model_weights, key=lambda s: (step_of(s), s))
        return os.path.join(log_dir, latest)

    @staticmethod
    def find_ema_checkpoint(main_checkpoint, step, rate):
        """Return the EMA checkpoint path next to ``main_checkpoint``, or None if absent."""
        if not main_checkpoint:
            return None
        filename = f"ema_{rate}_{step:06d}.pt"
        path = bf.join(bf.dirname(main_checkpoint), filename)
        if bf.exists(path):
            return path
        return None

    @staticmethod
    def find_opt_checkpoint(main_checkpoint, step):
        """Return the optimizer checkpoint path next to ``main_checkpoint``, or None if absent."""
        if not main_checkpoint:
            return None
        filename = f"opt_{step:06d}.pt"
        path = bf.join(bf.dirname(main_checkpoint), filename)
        if bf.exists(path):
            return path
        return None

    __call__ = run_loop