├── .gitignore
├── LICENSE
├── README.md
├── basic_utils
├── __init__.py
├── dist_util.py
├── logger.py
└── sampler.py
├── config
├── __init__.py
├── base.py
└── train.py
├── data
├── __init__.py
└── dataset.py
├── models
└── __init__.py
├── requirements.txt
├── run_train.sh
├── train.py
└── utils
├── __init__.py
├── initialization.py
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Manual ###
2 | # Add extra ignored paths on your own
3 |
4 | ### WandB and Torch ###
5 | wandb/
6 | **/wandb/
7 | *.pth
8 | *.pt
9 |
10 | ### IDE ###
11 | .vscode
12 | .idea/
13 | *.iml
14 | modules.xml
15 | *.ipr
16 |
17 | ### Python Cache ###
18 | *.log
19 | *.pyc
20 | __pycache__/
21 |
22 | ### Python Virtual Environments ###
23 | .env
24 | .venv
25 | env/
26 | venv/
27 | ENV/
28 | env.bak/
29 | venv.bak/
30 |
31 | ### OS Related ###
32 | *~
33 | .fuse_hidden*
34 | .directory
35 | .Trash-*
36 | .nfs*
37 | .DS_Store
38 | .AppleDouble
39 | .LSOverride
40 | Icon
41 | ._*
42 | .DocumentRevisions-V100
43 | .fseventsd
44 | .Spotlight-V100
45 | .TemporaryItems
46 | .Trashes
47 | .VolumeIcon.icns
48 | .com.apple.timemachine.donotpresent
49 | .AppleDB
50 | .AppleDesktop
51 | Network Trash Folder
52 | Temporary Items
53 | .apdisk
54 | Thumbs.db
55 | Thumbs.db:encryptable
56 | ehthumbs.db
57 | ehthumbs_vista.db
58 | *.stackdump
59 | [Dd]esktop.ini
60 | $RECYCLE.BIN/
61 | *.lnk
62 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 kdha0727
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch pipeline with `torch.distributed`
2 |
3 | `utils.trainer.TrainLoop` runs the training loop and is compatible with `torch.distributed.run` (`torchrun`).
4 |
5 | ## All you need to do
6 |
7 | * Complete the `TrainSettings` (or `YourSettings`) class in `config/train.py`.
8 | * This settings class is compatible with both argparse and JSON.
9 | * Complete `data/__init__.py`'s `load_data_from_args` function.
10 | * Complete the `models` package.
11 | * Complete `utils/initialization.py`'s `create_model_from_config` function.
12 | * Complete the following methods of the `TrainLoop` class in `utils/trainer.py`:
13 | * `log_loss_dict` method: logging function of loss values dict.
14 | * `compute_losses` method: calculate `losses` from `micro_batch` and TrainLoop vars
15 | * `backward_from_losses` method: make single `loss` from `losses`, and run `loss.backward()`
16 | * `__init__` method: add your extra values to TrainLoop vars if needed.
17 | * Update `train.py` so that it is consistent with all the code signatures you modified.
18 | * Copy the default training settings into a JSON file with the command below, then modify that file:
19 | ```
20 | python3 -c "from config import TrainSettings as T; print(T().json(indent=2))" >> train_config.json
21 | ```
22 |
23 |
24 | Here is the simplest example of a `train.py` script:
25 |
26 | ```python
27 | from torch.distributed.elastic.multiprocessing.errors import record
28 |
29 |
30 | def main():
31 |
32 | import os
33 | import torch
34 | from basic_utils import dist_util
35 |
36 | if os.getenv("LOCAL_RANK", None) and not dist_util.is_initialized():
37 | dist_util.setup_dist()
38 | with dist_util.with_dist_cleanup():
39 | main()
40 | return
41 | rank = dist_util.get_rank()
42 | dist_util.barrier()
43 |
44 | class Model(torch.nn.Module):
45 |
46 | def __init__(self):
47 | super().__init__()
48 | self.param = torch.nn.Parameter(torch.ones(1))
49 |
50 | def forward(self, x, target=None):
51 | output = self.param * x
52 | if target is not None:
53 | return (target - output) ** 2
54 | return output
55 |
56 | model = Model()
57 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
58 |
59 | model.to(dist_util.dev())
60 | dist_util.barrier()
61 |
62 | if dist_util.is_initialized():
63 | ddp_kwargs = dict(
64 | broadcast_buffers=False,
65 | bucket_cap_mb=128,
66 | find_unused_parameters=False,
67 | )
68 | if torch.cuda.is_available():
69 | ddp_kwargs.update(device_ids=[dist_util.dev()], output_device=dist_util.dev())
70 | ddp_model = torch.nn.parallel.DistributedDataParallel(model, **ddp_kwargs)
71 | else:
72 | ddp_model = model
73 |
74 | dist_util.sequential_print("rank", rank, "param :", model.param.data.item())
75 | dist_util.print_master_node()
76 |
77 | data = torch.ones(1, device=dist_util.dev()) * (dist_util.get_rank() + 1)
78 | target = torch.ones(1, device=dist_util.dev()) * 0.5 * (dist_util.get_rank() + 1)
79 |
80 | with ddp_model.no_sync() if dist_util.is_initialized() else dist_util.dummy_context():
81 | loss = ddp_model(data, target)
82 | dist_util.sequential_print("rank", rank, "loss :", loss.item())
83 | dist_util.print_master_node()
84 |
85 | loss.backward()
86 | dist_util.sequential_print("rank", rank, "grad :", model.param.grad.item())
87 | dist_util.print_master_node()
88 |
89 | loss = ddp_model(data, target)
90 | dist_util.sequential_print("rank", rank, "loss :", loss.item())
91 | dist_util.print_master_node()
92 |
93 | loss.backward()
94 | dist_util.sequential_print("rank", rank, "sync_grad :", model.param.grad.item())
95 | dist_util.print_master_node()
96 |
97 | optimizer.step()
98 | dist_util.sequential_print("rank", rank, "updated_param :", model.param.data.item())
99 | dist_util.barrier()
100 |
101 |
102 | if __name__ == "__main__":
103 | record(main)()
104 |
105 | ```
106 |
107 | Execute it with...
108 |
109 | ```bash
110 | torchrun --nproc_per_node gpu train.py
111 | ```
112 |
113 | Or without distributed training...
114 |
115 | ```bash
116 | python3 train.py
117 | ```
118 |
119 |
120 |
121 | ## How to run
122 |
123 | After completing the steps above, you can run the training script with:
124 |
125 | ```bash
126 | torchrun --nproc_per_node gpu train.py --config_json train_config.json
127 | ```
128 |
129 | ## Citations
130 |
131 | ```bibtex
132 | @inproceedings{gong2022diffuseq,
133 | author = {Gong, Shansan and Li, Mukai and Feng, Jiangtao and Wu, Zhiyong and Kong, Lingpeng},
134 | booktitle = {International Conference on Learning Representations, ICLR},
135 | title = {{DiffuSeq}: Sequence to Sequence Text Generation with Diffusion Models},
136 | year = 2023
137 | }
138 | ```
139 |
--------------------------------------------------------------------------------
/basic_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdha0727/distributed-pipeline/ed11369eebbd2f59655ad47affec436ed0b64284/basic_utils/__init__.py
--------------------------------------------------------------------------------
/basic_utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 kdha0727
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy
4 | # of this software and associated documentation files (the "Software"), to deal
5 | # in the Software without restriction, including without limitation the rights
6 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | # copies of the Software, and to permit persons to whom the Software is
8 | # furnished to do so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 | """
21 | Helpers for distributed training.
22 | The functions in this module also work when the script is not running in a torch.distributed.run environment.
23 | Write your code as though you were using torch.distributed.run - it still works when you run the script directly!
24 | """
25 | __author__ = "https://github.com/kdha0727"
26 |
27 | import io
28 | import os
29 | import contextlib
30 | import functools
31 | from typing import overload, Sequence, Optional, Union, ContextManager, Any
32 |
33 | import torch
34 | import torch.distributed as dist
35 | from torch.nn.parallel.distributed import DistributedDataParallel
36 | from torch.cuda import is_available as _cuda_available
37 |
38 |
# Module-level defaults describing a single, non-distributed process.
# setup_dist() overwrites both once a process group has been initialized.
RANK: int = 0
WORLD_SIZE: int = 1
41 |
42 |
43 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
44 | # Setup Tools #
45 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
46 |
def is_initialized() -> bool:
    """
    Tell whether the c10d (distributed) runtime has been initialized.

    When PyTorch is built without c10d, ``is_initialized`` is missing from the
    ``torch.distributed`` namespace, so it is looked up defensively and a
    missing attribute is treated as "not initialized".
    """
    if not dist.is_available():
        return False
    probe = getattr(dist, "is_initialized", None)
    if probe is None:
        return False
    return probe()
54 |
55 |
@overload
def setup_dist(temp_dir: str, rank: int, world_size: int) -> None: ...


@overload
def setup_dist() -> None: ...


def setup_dist(*args):
    """
    Set up a distributed process group.

    Usage
        1. setup_dist(temp_dir, rank, world_size)
            : initialize through a shared file located in ``temp_dir``.
        2. setup_dist()
            : initialize from environment variables (as set by torchrun).

    Does nothing when a process group is already initialized.

    :raises TypeError: if called with a number of arguments other than 0 or 3.
    """
    global RANK, WORLD_SIZE

    if is_initialized():
        return

    # gloo is selected when CUDA is unavailable or on Windows; nccl otherwise.
    backend = "gloo" if (not _cuda_available()) or (os.name == "nt") else "nccl"

    if len(args) == 3:
        temp_dir, rank, world_size = args
        assert os.path.isdir(temp_dir), f"temp_dir {temp_dir} is not a directory"
        assert isinstance(rank, int), f"rank {rank} must be int"
        assert isinstance(world_size, int), f"world_size {world_size} must be int"
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            # Windows paths need forward slashes and a third slash in the URL.
            init_method = 'file:///' + init_file.replace('\\', '/')
        else:
            init_method = f'file://{init_file}'
        dist.init_process_group(backend=backend, init_method=init_method, rank=rank, world_size=world_size)
    elif len(args) == 0:
        assert os.getenv("LOCAL_RANK", None) is not None, "environ LOCAL_RANK is not set"
        dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        raise TypeError("setup_dist() takes 0 or 3 arguments")

    RANK = rank
    WORLD_SIZE = world_size

    # The rank/world-size/device helpers may be memoized; anything they cached
    # before this point reflects the pre-setup defaults, so drop it.
    for helper in (get_rank, get_world_size, dev):
        clear = getattr(helper, "cache_clear", None)
        if clear is not None:
            clear()

    if _cuda_available():
        torch.cuda.set_device(dev())
        torch.cuda.empty_cache()
104 |
105 |
def cleanup_dist():
    """
    Destroy the distributed process group if one is initialized; otherwise do nothing.
    """
    if not is_initialized():
        return
    dist.destroy_process_group()
112 |
113 |
@contextlib.contextmanager
def with_dist_cleanup():
    """
    Context Manager or Decorator version of cleanup_dist().

    cleanup_dist() runs when the managed body finishes, including when it
    raises, so the process group is not left behind after a failure.
    """
    try:
        yield
    finally:
        cleanup_dist()
121 |
122 |
123 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
124 | # General Tools #
125 | # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
126 |
@functools.lru_cache(maxsize=None)
def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """
    Wrapper of torch.distributed.get_rank.

    Returns the rank inside ``group`` when a group is given and the process
    group is initialized; otherwise returns the module-level RANK.
    """
    query_group = group is not None and is_initialized()
    return dist.get_rank(group=group) if query_group else RANK
136 |
137 |
@functools.lru_cache(maxsize=None)
def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """
    Wrapper of torch.distributed.get_world_size.

    Returns the size of ``group`` when a group is given and the process group
    is initialized; otherwise returns the module-level WORLD_SIZE.
    """
    query_group = group is not None and is_initialized()
    return dist.get_world_size(group=group) if query_group else WORLD_SIZE
147 |
148 |
def barrier(*args: ..., **kwargs: ...) -> None:
    """
    Wrapper for torch.distributed.barrier.

    Forwards all arguments to ``dist.barrier`` when the process group is
    initialized; does nothing otherwise.
    see `torch.distributed.distributed_c10d.barrier.__doc__` for more information.
    """
    if not is_initialized():
        return None
    return dist.barrier(*args, **kwargs)
157 |
158 |
@contextlib.contextmanager
def synchronized() -> ContextManager[None]:
    """
    Run the managed body between two barrier() calls: one on entry and one
    after the body completes normally.
    """
    barrier()  # wait before entering the body
    yield None
    barrier()  # wait again after the body has finished
167 |
168 |
@functools.lru_cache(maxsize=None)
def dev(group: Optional[dist.ProcessGroup] = None) -> torch.device:
    """
    Get the device to use for torch.distributed.

    Returns the CPU device when CUDA is unavailable. Otherwise the CUDA index
    is taken from the LOCAL_RANK environment variable when it is set, because
    a global (or group) rank can exceed the number of GPUs on a node in
    multi-node runs. Without LOCAL_RANK, the rank from get_rank(group) is
    wrapped to the number of visible devices.

    :param group: process group forwarded to get_rank() for the fallback path.
    """
    if not _cuda_available():
        return torch.device("cpu")
    local_rank = os.getenv("LOCAL_RANK")
    if local_rank is not None:
        index = int(local_rank)
    else:
        # Identical to the plain rank whenever rank < device count.
        index = get_rank(group) % max(torch.cuda.device_count(), 1)
    return torch.device("cuda:{}".format(index))
177 |
178 |
try:
    import blobfile as bf  # NOQA: F401
except ImportError:
    def load_state_dict(local_or_remote_path: Union[str, os.PathLike], **kwargs: ...) -> Any:
        """
        Load a PyTorch file directly with torch.load (blobfile is not installed).
        Extra keyword arguments are forwarded to torch.load.
        """
        return torch.load(local_or_remote_path, **kwargs)
else:
    def load_state_dict(local_or_remote_path: Union[str, os.PathLike], **kwargs: ...) -> Any:
        """
        Load a PyTorch file by reading it fully through blobfile and handing
        the bytes to torch.load. Extra keyword arguments are forwarded to torch.load.
        """
        with bf.BlobFile(local_or_remote_path, "rb") as stream:
            payload = stream.read()
        return torch.load(io.BytesIO(payload), **kwargs)
194 |
195 |
def broadcast(
        tensor: torch.Tensor,
        src: Optional[int] = 0,
        group: Optional[dist.ProcessGroup] = None,
        async_op: bool = False
) -> Optional[Any]:
    """
    Synchronize a Tensor across ranks from {src} rank. (default=0)
    Does nothing (and returns None) when no process group is initialized.

    :param tensor: torch.Tensor, overwritten in place with the value from {src}.
    :param src: source rank to sync the tensor from. default is 0.
    :param group: process group forwarded to torch.distributed.broadcast.
    :param async_op: forwarded to torch.distributed.broadcast.
    :return: whatever torch.distributed.broadcast returns (the async work
        handle when async_op is True), so callers can wait on it.
    """
    if not is_initialized():
        return None
    with torch.no_grad():
        return dist.broadcast(tensor, src, group=group, async_op=async_op)
213 |
214 |
def sync_params(
        params: Sequence[torch.Tensor],
        src: Optional[int] = 0,
        group: Optional[dist.ProcessGroup] = None,
        async_op: bool = False
) -> Optional[Sequence[Any]]:
    """
    Synchronize a sequence of Tensors across ranks from {src} rank. (default=0)
    Does nothing (and returns None) when no process group is initialized.

    :param params: Sequence of torch.Tensor, each overwritten in place.
    :param src: source rank to sync params from. default is 0.
    :param group: process group forwarded to torch.distributed.broadcast.
    :param async_op: forwarded to torch.distributed.broadcast.
    :return: None for synchronous calls; with async_op=True, the list of values
        returned by torch.distributed.broadcast (one per parameter) so the
        caller can wait for completion instead of losing the handles.
    """
    if not is_initialized():
        return None
    with torch.no_grad():
        handles = [
            dist.broadcast(p, src, group=group, async_op=async_op)
            for p in params
        ]
    return handles if async_op else None
232 |
233 |
def sequential_print(*args: ..., **kwargs: ...) -> None:
    """
    Print the arguments once per process, one rank at a time in rank order.
    All arguments are handed to the built-in print function.
    """
    my_turn = get_rank()
    for turn in range(get_world_size()):
        if turn == my_turn:
            print(*args, **kwargs)
        # every process waits here so that the next rank prints afterwards
        barrier()
244 |
245 |
def print_master_node(*args: ..., **kwargs: ...) -> None:
    """
    Print the arguments on rank 0 only, then synchronize all processes.
    All arguments are handed to the built-in print function.
    """
    is_master = get_rank() == 0
    if is_master:
        print(*args, **kwargs)
    barrier()
254 |
255 |
@contextlib.contextmanager
def dummy_context() -> ContextManager[None]:
    """
    No-op context manager: enters and exits without doing anything.
    """
    yield None
262 |
263 |
@contextlib.contextmanager
def no_sync(*modules: DistributedDataParallel) -> ContextManager[None]:
    """
    Enter the no_sync() context of every given DDP module at once.
    Usage
        with no_sync(ddp_module_1, ddp_module_2, ddp_module_3, ddp_module_n):
            ...  # your code

    Raises TypeError when an argument is not a DistributedDataParallel instance.
    With no arguments this is a plain pass-through context.
    """
    with contextlib.ExitStack() as stack:
        for ddp in modules:
            if isinstance(ddp, DistributedDataParallel):
                stack.enter_context(ddp.no_sync())
                continue
            raise TypeError(
                "arg {!r} should be instance of DistributedDataParallel, got {}"
                .format(ddp, type(ddp))
            )
        yield
284 |
285 |
# Where the ``nt`` module can be imported, replace os.sync / nt.sync with a
# no-op - since file io uses os.sync. Elsewhere the import fails and nothing changes.
with contextlib.suppress(ImportError):
    import nt  # NOQA
    os.sync = nt.sync = lambda: None  # signature: () -> None
292 |
--------------------------------------------------------------------------------
/basic_utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | from abc import abstractmethod, ABC
7 | import os
8 | import sys
9 | import json
10 | import time
11 | import datetime
12 | import tempfile
13 | import warnings
14 | from collections import defaultdict
15 | from contextlib import contextmanager
16 | import wandb
17 |
# Log levels: Logger.log() emits a message when the logger's level is
# less than or equal to the level of the message.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

# Higher than every message level above, so a logger set to this level emits none of them.
DISABLED = 50
24 |
25 |
class KVWriter(ABC):
    """Interface for output formats that accept a dict of key/value diagnostics."""

    @abstractmethod
    def writekvs(self, kvs):
        """Write the mapping ``kvs``; concrete writers must override this."""
        raise NotImplementedError
30 |
31 |
class SeqWriter(ABC):
    """Interface for output formats that accept a sequence of strings."""

    @abstractmethod
    def writeseq(self, seq):
        """Write the items of ``seq``; concrete writers must override this."""
        raise NotImplementedError
36 |
37 |
class HumanOutputFormat(KVWriter, SeqWriter):
    """
    Writes key/value tables and plain message lines in a human-readable form
    to a file path (opened and owned by this object) or to an existing
    file-like object (left open on close()).
    """

    def __init__(self, filename_or_file):
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, "wt")
            self.own_file = True
        else:
            # This object only ever writes, so that is the capability to check.
            assert hasattr(filename_or_file, "write"), (
                "expected file or str, got %s" % filename_or_file
            )
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        """Write ``kvs`` as a boxed two-column table and flush the file."""
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if hasattr(val, "__float__"):
                valstr = "%-8.3g" % val
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print("WARNING: tried to write empty key-value dict")
            return
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the data; 7 accounts for the "| ", " | " and " |" framing.
        dashes = "-" * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
            lines.append(
                "| %s%s | %s%s |"
                % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
            )
        lines.append(dashes)
        self.file.write("\n".join(lines) + "\n")

        # Flush the output to the file
        self.file.flush()

    @staticmethod
    def _truncate(s):
        """Shorten ``s`` to at most 30 characters, ending in "..." when cut."""
        maxlen = 30
        return s[: maxlen - 3] + "..." if len(s) > maxlen else s

    def writeseq(self, seq):
        """Write the items of ``seq`` separated by single spaces, then a newline."""
        seq = list(seq)
        for (i, elem) in enumerate(seq):
            self.file.write(elem)
            if i < len(seq) - 1:  # add space unless this is the last one
                self.file.write(" ")
        self.file.write("\n")
        self.file.flush()

    def close(self):
        """Close the file only if this object opened it."""
        if self.own_file:
            self.file.close()
99 |
100 |
class JSONOutputFormat(KVWriter):
    """Appends each key/value dict to a file as one JSON object per line."""

    def __init__(self, filename):
        self.file = open(filename, "wt")

    def writekvs(self, kvs):
        """
        Serialize ``kvs`` as a single JSON line and flush.

        Values exposing a ``dtype`` attribute are converted with float() first.
        The conversion is done on a copy: ``kvs`` is the caller's dict and may
        be handed to other writers afterwards, so it must not be modified.
        """
        record = {
            k: (float(v) if hasattr(v, "dtype") else v)
            for k, v in kvs.items()
        }
        self.file.write(json.dumps(record) + "\n")
        self.file.flush()

    def close(self):
        """Close the underlying file."""
        self.file.close()
114 |
115 |
class CSVOutputFormat(KVWriter):
    """
    Writes key/value dicts as CSV rows, widening the header (and rewriting
    the rows already written) whenever a previously unseen key shows up.
    """

    def __init__(self, filename):
        self.file = open(filename, "w+t")
        self.keys = []
        self.sep = ","

    def writekvs(self, kvs):
        """Append one row for ``kvs``; keys absent from ``kvs`` become empty cells."""
        new_keys = sorted(kvs.keys() - self.keys)
        if new_keys:
            self.keys.extend(new_keys)
            # Re-read what is on disk, then rewrite it from the start with the
            # wider header and one extra empty cell per new key on every old row.
            self.file.seek(0)
            old_lines = self.file.readlines()
            self.file.seek(0)
            self.file.write(",".join(self.keys) + "\n")
            padding = self.sep * len(new_keys)
            for old in old_lines[1:]:
                self.file.write(old[:-1] + padding + "\n")
        cells = []
        for column in self.keys:
            value = kvs.get(column)
            cells.append("" if value is None else str(value))
        self.file.write(",".join(cells) + "\n")
        self.file.flush()

    def close(self):
        """Close the underlying file."""
        self.file.close()
151 |
152 |
class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.

    NOTE(review): relies on ``tensorflow.python.pywrap_tensorflow.EventsWriter``
    and ``tf.Summary`` - presumably TensorFlow 1.x era APIs; confirm against
    the installed TensorFlow version before use.
    """

    def __init__(self, dir):
        # Create the target directory and an events writer whose file prefix
        # is "<abs dir>/events". TensorFlow is imported lazily here, so it is
        # only required when this output format is actually constructed.
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = "events"
        path = os.path.join(os.path.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat

        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))

    def writekvs(self, kvs):
        """Write every item of ``kvs`` as a scalar summary at the current step, then advance the step."""
        def summary_val(k, v):
            # Each value is coerced with float(); non-numeric values raise here.
            kwargs = {"tag": k, "simple_value": float(v)}
            return self.tf.Summary.Value(**kwargs)

        summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
        event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
        event.step = (
            self.step
        )  # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        """Close the events writer once; later calls are no-ops."""
        if self.writer:
            self.writer.Close()
            self.writer = None
192 |
193 |
def make_output_format(format, ev_dir, log_suffix=""):
    """
    Build the output writer named by ``format`` ("stdout", "log", "json",
    "csv" or "tensorboard"), creating ``ev_dir`` first. File-based writers
    put their file inside ``ev_dir`` with ``log_suffix`` in the file name.

    Raises ValueError for any other format name.
    """
    os.makedirs(ev_dir, exist_ok=True)

    def in_dir(name):
        return os.path.join(ev_dir, name)

    # Builders are lazy so that only the requested writer is constructed.
    builders = (
        ("stdout", lambda: HumanOutputFormat(sys.stdout)),
        ("log", lambda: HumanOutputFormat(in_dir("log%s.txt" % log_suffix))),
        ("json", lambda: JSONOutputFormat(in_dir("progress%s.json" % log_suffix))),
        ("csv", lambda: CSVOutputFormat(in_dir("progress%s.csv" % log_suffix))),
        ("tensorboard", lambda: TensorBoardOutputFormat(in_dir("tb%s" % log_suffix))),
    )
    for name, build in builders:
        if format == name:
            return build()
    raise ValueError("Unknown format specified: %s" % (format,))
208 |
209 |
210 | # ================================================================
211 | # API
212 | # ================================================================
213 |
214 |
def logkv(key, val):
    """
    Record one diagnostic value on the current logger.
    Call once per diagnostic quantity per iteration; when the same key is
    logged several times, the last value wins.
    """
    current = get_current()
    current.logkv(key, val)
222 |
223 |
def logkv_mean(key, val):
    """
    Like logkv(), except that repeated calls for the same key are averaged.
    """
    current = get_current()
    current.logkv_mean(key, val)
229 |
230 |
def logkvs(d):
    """
    Record every key/value pair of the dictionary ``d`` via logkv().
    """
    for name, value in d.items():
        logkv(name, value)
237 |
238 |
def dumpkvs():
    """
    Flush the diagnostics collected on the current logger for this iteration
    and return what the logger's dumpkvs() returns.
    """
    current = get_current()
    return current.dumpkvs()
244 |
245 |
def getkvs():
    """Return the current logger's name-to-value mapping (the live object, not a copy)."""
    current = get_current()
    return current.name2val
248 |
249 |
def log(*args, level=INFO):
    """
    Send the sequence of args to the current logger at the given level.
    The logger forwards them to its sequence writers (console and output
    files, if configured) when its threshold allows that level.
    """
    current = get_current()
    current.log(*args, level=level)
255 |
256 |
def debug(*args):
    """Log ``args`` at DEBUG level."""
    return log(*args, level=DEBUG)
259 |
260 |
def info(*args):
    """Log ``args`` at INFO level."""
    return log(*args, level=INFO)
263 |
264 |
def warn(*args):
    """Log ``args`` at WARN level."""
    return log(*args, level=WARN)
267 |
268 |
def error(*args):
    """Log ``args`` at ERROR level."""
    return log(*args, level=ERROR)
271 |
272 |
def set_level(level):
    """
    Change the logging threshold of the current logger.
    """
    current = get_current()
    current.set_level(level)
278 |
279 |
def set_comm(comm):
    """Set the communicator the current logger uses when aggregating values in dumpkvs()."""
    current = get_current()
    current.set_comm(comm)
282 |
283 |
def get_dir():
    """
    Return the directory recorded on the current logger, i.e. where its log
    files are being written.
    """
    current = get_current()
    return current.get_dir()
290 |
291 |
# Alternative names bound to the same functions as logkv / dumpkvs.
record_tabular = logkv
dump_tabular = dumpkvs
294 |
295 |
@contextmanager
def profile_kv(scopename):
    """
    Time the managed body and add the elapsed wall-clock seconds to the
    current logger's value stored under "wait_<scopename>". The time is
    recorded even when the body raises.
    """
    logkey = "wait_" + scopename
    started = time.time()
    try:
        yield
    finally:
        values = get_current().name2val
        values[logkey] += time.time() - started
304 |
305 |
def profile(n):
    """
    Decorator factory that runs the decorated function inside profile_kv(n),
    so each call's duration is accumulated under the key "wait_<n>".

    Usage:
    @profile("my_func")
    def my_func(): code

    The wrapper preserves the decorated function's name and docstring.
    """
    import functools  # imported locally so this block needs no new module-level import

    def decorator_with_name(func):
        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            with profile_kv(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name
321 |
322 |
323 | # ================================================================
324 | # Backend
325 | # ================================================================
326 |
327 |
def get_current():
    """Return the active Logger, configuring the default one first if none exists yet."""
    current = Logger.CURRENT
    if current is None:
        _configure_default_logger()
        current = Logger.CURRENT
    return current
333 |
334 |
class Logger(object):
    """
    Collects key/value diagnostics for one iteration and forwards them, as
    well as plain log messages, to a list of output formats.
    """

    DEFAULT = None  # Set by _configure_default_logger(); used by reset() as the fallback logger
    CURRENT = None  # Current logger being used by the free functions above

    def __init__(self, dir, output_formats, comm=None):
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)  # number of logkv_mean() samples per key
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats
        self.comm = comm

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Store ``val`` under ``key``, replacing any earlier value."""
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        """Fold ``val`` into the running mean kept under ``key``."""
        oldval, cnt = self.name2val[key], self.name2cnt[key]
        self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
        self.name2cnt[key] = cnt + 1

    def dumpkvs(self):
        """
        Write the collected values to wandb (local rank 0 only) and to every
        KVWriter output format, clear them, and return a copy of what was written.
        """
        if self.comm is None:
            d = self.name2val
        else:
            d = mpi_weighted_mean(
                self.comm,
                {
                    name: (val, self.name2cnt.get(name, 1))
                    for (name, val) in self.name2val.items()
                },
            )
            if self.comm.rank != 0:
                d["dummy"] = 1  # so we don't get a warning about empty dict
        out = d.copy()  # Return the dict for unit testing purposes
        # LOCAL_RANK is only present under a torchrun-style launcher; a plain
        # `python3 train.py` run has no such variable and is treated as rank 0.
        if int(os.environ.get('LOCAL_RANK', 0)) == 0:
            wandb.log({**d})
        for fmt in self.output_formats:
            if isinstance(fmt, KVWriter):
                fmt.writekvs(d)
        self.name2val.clear()
        self.name2cnt.clear()
        return out

    def log(self, *args, level=INFO):
        """Emit ``args`` when this logger's threshold is at or below ``level``."""
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def set_comm(self, comm):
        self.comm = comm

    def get_dir(self):
        return self.dir

    def close(self):
        """Close every output format."""
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        """Hand the stringified ``args`` to every SeqWriter output format."""
        for fmt in self.output_formats:
            if isinstance(fmt, SeqWriter):
                fmt.writeseq(map(str, args))
407 |
408 |
def get_rank_without_mpi_import():
    """
    Read the MPI rank from the environment (PMI_RANK first, then
    OMPI_COMM_WORLD_RANK); return 0 when neither variable is set.
    """
    # check environment variables here instead of importing mpi4py
    # to avoid calling MPI_Init() when this module is imported
    candidates = ("PMI_RANK", "OMPI_COMM_WORLD_RANK")
    return next(
        (int(os.environ[name]) for name in candidates if name in os.environ),
        0,
    )
416 |
417 |
def mpi_weighted_mean(comm, local_name2valcount):
    """
    Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
    Perform a weighted average over dicts that are each on a different node
    Input: local_name2valcount: dict mapping key -> (value, count)
    Returns: key -> mean on rank 0, an empty dict on every other rank.

    Values that cannot be converted with float() (including None) are skipped
    with a warning, and keys whose total count is zero are left out of the
    result instead of dividing by zero.
    """
    all_name2valcount = comm.gather(local_name2valcount)
    if comm.rank != 0:
        return {}

    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for (name, (val, count)) in n2vc.items():
            try:
                val = float(val)
            except (TypeError, ValueError):
                warnings.warn(
                    "WARNING: tried to compute mean on non-float {}={}".format(
                        name, val
                    )
                )
                continue
            name2sum[name] += val * count
            name2count[name] += count
    return {
        name: name2sum[name] / name2count[name]
        for name in name2sum
        if name2count[name]
    }
446 |
447 |
def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
    """
    Install a new Logger as Logger.CURRENT.

    The directory comes from ``dir``, else the OPENAI_LOGDIR environment
    variable, else a timestamped folder under the system temp directory.
    Output formats come from ``format_strs``, else from OPENAI_LOG_FORMAT
    (rank 0) or OPENAI_LOG_FORMAT_MPI (other ranks).
    If comm is provided, average all numerical stats across that comm
    """
    if dir is None:
        dir = os.getenv("OPENAI_LOGDIR")
    if dir is None:
        stamp = datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")
        dir = os.path.join(tempfile.gettempdir(), stamp)
    assert isinstance(dir, str)
    dir = os.path.expanduser(dir)
    os.makedirs(os.path.expanduser(dir), exist_ok=True)

    rank = get_rank_without_mpi_import()
    if rank > 0:
        # keep the files of different ranks apart
        log_suffix = log_suffix + "-rank%03i" % rank

    if format_strs is None:
        if rank == 0:
            env_name, fallback = "OPENAI_LOG_FORMAT", "stdout,log,csv"
        else:
            env_name, fallback = "OPENAI_LOG_FORMAT_MPI", "log"
        format_strs = os.getenv(env_name, fallback).split(",")
    # empty entries are dropped
    output_formats = [
        make_output_format(name, dir, log_suffix) for name in format_strs if name
    ]

    Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
    if output_formats:
        log("Logging to %s" % dir)
478 |
479 |
def _configure_default_logger():
    """Run configure() with its defaults and remember the resulting logger as Logger.DEFAULT."""
    configure()
    default_logger = Logger.CURRENT
    Logger.DEFAULT = default_logger
483 |
484 |
def reset():
    """Close the current logger and switch back to Logger.DEFAULT, unless it is already active."""
    if Logger.CURRENT is Logger.DEFAULT:
        return
    Logger.CURRENT.close()
    Logger.CURRENT = Logger.DEFAULT
    log("Reset logger")
490 |
491 |
@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
    """
    Temporarily replace the current logger with a freshly configured one.
    On exit (even after an exception) the temporary logger is closed and the
    previous one is restored.
    """
    previous = Logger.CURRENT
    configure(dir=dir, format_strs=format_strs, comm=comm)
    try:
        yield
    finally:
        temporary = Logger.CURRENT
        temporary.close()
        Logger.CURRENT = previous
501 |
--------------------------------------------------------------------------------
/basic_utils/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class InfiniteSampler(torch.utils.data.Sampler):
    """
    Sampler that yields dataset indices forever, sharded across replicas.

    Global position ``idx`` is served by the replica whose rank equals
    ``idx % num_replicas``. With ``shuffle`` enabled the order starts as a
    seeded random permutation and is perturbed continuously: after each
    position is visited, its entry is swapped with one chosen at random from
    the ``window`` positions at or before it (wrapping around).
    """

    def __init__(
            self,
            dataset, num_replicas: int = 1,
            rank: int = None, shuffle: bool = True,
            seed: int = 0, window_size: float = 0.5
    ) -> None:
        """
        :param dataset: sized dataset to draw indices from (must be non-empty).
        :param num_replicas: number of participating processes; None means
            "ask dist_util.get_world_size()".
        :param rank: index of this process; None means "ask dist_util.get_rank()".
        :param shuffle: start from a random permutation and keep perturbing it.
        :param seed: base seed of the permutation generator.
        :param window_size: fraction of the dataset (0..1) used as swap window.
        """
        # Resolve the defaults BEFORE validating them; comparing ``None`` with
        # an int would raise TypeError. The import is deferred so that callers
        # who pass both values explicitly do not need the dist helpers.
        if num_replicas is None or rank is None:
            from . import dist_util
            if num_replicas is None:
                num_replicas = dist_util.get_world_size()
            if rank is None:
                rank = dist_util.get_rank()
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        try:
            super().__init__(dataset)
        except TypeError:
            # NOTE(review): fallback for torch releases whose Sampler no longer
            # accepts the data_source argument - confirm against the pinned version.
            super().__init__()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size
        # Added to the seed in __iter__; stays 0 unless set_epoch() is called.
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        """Set the offset added to ``seed`` the next time iteration starts."""
        self.epoch = epoch

    def __iter__(self):
        length = len(self.dataset)
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            order = torch.randperm(length, generator=g).tolist()  # type: ignore[arg-type]
            window = round(length * self.window_size)
        else:
            g = None
            order = list(range(length))  # type: ignore[arg-type]
            window = 0

        idx = 0
        while True:
            i = idx % length
            if idx % self.num_replicas == self.rank:
                yield order[i]
            if window >= 2:
                # Plain int offset so list indexing never sees a tensor.
                offset = int(torch.randint(window, size=(), generator=g))
                j = (i - offset) % length
                order[i], order[j] = order[j], order[i]
            idx += 1
45 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdha0727/distributed-pipeline/ed11369eebbd2f59655ad47affec436ed0b64284/config/__init__.py
--------------------------------------------------------------------------------
/config/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Pre-configured base setting class, type (choice), and item with description.
3 | Compatible with ArgumentParser.
4 | """
5 | from typing import Literal
6 | try:
7 | from typing import get_args
8 | except ImportError:
9 | from typing_extensions import get_args
10 | from argparse import ArgumentParser as Ap, ArgumentDefaultsHelpFormatter as Df
11 | from pydantic import BaseModel, Field, validator
12 | from pydantic.validators import bool_validator
13 |
14 |
class ArgparseCompatibleBaseModel(BaseModel):
    """
    pydantic model that can describe itself as argparse arguments and be
    rebuilt from the parsed result.

    Fields whose type is itself a BaseModel subclass become argument groups;
    their own fields are added as flat ``--name`` options.
    """

    class Config:  # Override this to allow extra kwargs
        extra = "forbid"

    @classmethod
    def from_argparse(cls, namespace, __top=True):
        """
        Build an instance from an argparse Namespace (or a plain dict).

        :param namespace: parsed arguments; every field name must be present.
        :param __top: internal flag, False for recursive calls on nested models.
        :raises KeyError: if a field is missing from ``namespace``.
        :raises AssertionError: if the top-level call leaves unused entries.
        """
        if not isinstance(namespace, dict):
            namespace = vars(namespace)
        if __top:
            # Work on a copy so the caller's Namespace/dict is not emptied.
            # Nested calls deliberately share this one dict so that leftover
            # keys can be detected after every model took its fields.
            namespace = dict(namespace)
        kwargs = {}
        for name, field in cls.__fields__.items():
            if isinstance(field.type_, type) and issubclass(field.type_, BaseModel):
                kwargs[name] = ArgparseCompatibleBaseModel.from_argparse.__func__(field.type_, namespace, __top=False)  # NOQA
            else:
                kwargs[name] = namespace.pop(name)
        assert not (__top and namespace), str(namespace)
        return cls(**kwargs)

    @classmethod
    def to_argparse(cls, parser_or_group=None):
        """
        Add one ``--<field>`` option per field to a parser or argument group.

        :param parser_or_group: target to extend; a new ArgumentParser with
            default-showing help is created when None.
        :return: the parser or group that received the arguments.
        """

        def make_choice_parser(arg_name, choices):
            # Converter mapping the command-line string back to the original
            # Literal member (which need not be a str). Named after the
            # argument so that name shows up in argparse's error message.
            def parse(arg):
                for ch in choices:
                    if str(ch) == arg:
                        return ch
                raise ValueError

            parse.__name__ = parse.__qualname__ = arg_name
            return parse

        if parser_or_group is None:
            parser_or_group = Ap(formatter_class=Df)
        for name, field in cls.__fields__.items():
            if isinstance(field.type_, type) and issubclass(field.type_, BaseModel):
                group = parser_or_group.add_argument_group(name)
                ArgparseCompatibleBaseModel.to_argparse.__func__(field.type_, group)  # NOQA
                continue
            kw = dict(dest=name, type=field.type_, default=field.default,
                      help=field.field_info.description, required=field.required)
            if getattr(field.type_, '__origin__', None) is Literal:
                choices = tuple(get_args(field.outer_type_))
                kw.update(type=make_choice_parser(name, choices), choices=choices,
                          metavar="{" + ", ".join(map(str, choices)) + "}")
            elif isinstance(field.type_, type) and issubclass(field.type_, bool):
                kw.update(type=bool_validator, metavar="{true, false}")
            parser_or_group.add_argument("--" + name, **kw)
        return parser_or_group

    @classmethod
    def from_argv(cls, argv=None):
        """Parse ``argv`` (None: let argparse read sys.argv) and build an instance."""
        return cls.from_argparse(cls.to_argparse().parse_args(argv))
60 |
61 |
62 | S = Setting = ArgparseCompatibleBaseModel # Alias
63 |
64 |
def choice(*args):
    """Return the ``Literal`` type whose members are exactly ``args``."""
    # Subscripting with the tuple is the same call as Literal.__getitem__(args).
    return Literal[args]
67 |
68 |
69 | C = Choice = choice # Alias
70 |
71 |
def item(default, description=None):
    """
    Shorthand for a pydantic ``Field`` with a default value and a description.

    :param default: default value of the field.
    :param description: text stored on the field; to_argparse uses it as the help string.
    """
    return Field(default, description=description)
74 |
75 |
76 | _ = Item = item # Alias
77 |
78 |
79 | Validator = validator # Alias
80 |
81 |
82 | __all__ = (
83 | 'ArgparseCompatibleBaseModel', 'Setting', 'S',
84 | 'choice', 'Choice', 'C',
85 | 'item', 'Item', '_',
86 | 'validator', 'Validator',
87 | )
88 |
89 |
if __name__ == '__main__':
    # Small self-demo: declare nested settings, print the generated help text,
    # then parse the real command line and dump the resulting config as YAML.

    import yaml

    class Config1(S):
        a: int = _(1, description='this is a')
        b: int = _(2, description='this is b')

    class Config2(S):
        # Literal-typed field: exposed with argparse ``choices``.
        c: C('choice1', 'choice2') = _('choice2', description='this is c')
        # bool field: parsed with pydantic's bool_validator.
        d: bool = _(True, description='this is d')

    class Config(S):
        # Nested models become argument groups named after the field.
        conf1: Config1 = Config1()
        conf2: Config2 = Config2()

    Config.to_argparse().print_help()
    yaml.dump(Config.from_argv().dict(), __import__('sys').stdout)
108 |
--------------------------------------------------------------------------------
/config/train.py:
--------------------------------------------------------------------------------
1 | from typing import final
2 | from argparse import ArgumentParser as Ap, ArgumentDefaultsHelpFormatter as Df
3 | from .base import S, Choice, Item as _
4 |
5 |
class GeneralSettings(S):
    # Optimisation, scheduling and checkpointing options of a training run.
    # Every field becomes a ``--<name>`` command-line option; the description
    # is shown as its help text.
    lr: float = _(1e-4, description="Learning Rate")
    batch_size: int = _(2048, description="Batch size of running step and optimizing")
    microbatch: int = _(64, description="Batch size for forward and backward")
    learning_steps: int = _(320000, description="Steps for whole iteration")
    log_interval: int = _(20, description="Steps per log")
    save_interval: int = _(2000, description="Steps per save")
    eval_interval: int = _(1000, description="Steps per eval")
    ema_rate: str = _("0.5,0.9,0.99", description="EMA rate. separate rates by comma(',').")
    seed: int = _(102, description="Seed for train or test.")
    resume_checkpoint: str = _("", description="Checkpoint path(.pt) to resume training")
    checkpoint_path: str = _("", description="! This will be automatically updated while training !")
    gradient_clipping: float = _(0., description="Gradient clipping (>0), default: 0 (no clipping). ")
    weight_decay: float = _(0., description="Weight decay.")
33 |
34 |
class DataSettings(S):
    # Options describing which dataset to use and how to load it.
    dataset: str = _("dataset", description="Name of dataset.")
    data_dir: str = _("datasets/dataset", description="Path for dataset to be saved.")
    data_loader_workers: int = _(2, description="num_workers for DataLoader.")
42 |
43 |
class YourSettings(S):
    # Empty extension point: declares no fields until project-specific
    # settings are added here. It is one of the bases of TrainSettings.
    # TODO: add extra settings on your own
    pass
47 |
48 |
@final
class TrainSettings(
    YourSettings,
    # TODO: inherit setting classes "reversely", due to python mro.
    DataSettings,
    GeneralSettings
):
    # Final settings object of the training script: the union of all setting
    # classes above, optionally loadable from a JSON file via --config_json.

    @classmethod
    def to_argparse(cls, parser_or_group=None, add_json=False):
        """
        Register every setting as a command-line option.

        :param parser_or_group: parser/group to extend (created when None).
        :param add_json: when True, place all options in a "settings" group
            and add ``--config_json`` for loading the settings from a file.
        :return: the parser or group that received the arguments.
        """
        if not add_json:
            return super(TrainSettings, cls).to_argparse(parser_or_group)
        if parser_or_group is None:
            parser_or_group = Ap(formatter_class=Df)
        setting_group = parser_or_group.add_argument_group(title="settings")
        setting_group.add_argument(
            "--config_json", type=str, required=False,
            help="You can alter arguments all below by config_json file.")
        # The settings go into the plain group. Putting them into one mutually
        # exclusive group would make argparse reject any command line that
        # sets two of them together (e.g. --lr together with --seed).
        super(TrainSettings, cls).to_argparse(setting_group)
        return parser_or_group

    @classmethod
    def from_argparse(cls, namespace, __top=True):
        """
        Build the settings from parsed arguments.

        A non-empty ``config_json`` entry wins: the file is parsed and all
        other entries are ignored. Otherwise the remaining entries are used.

        :param namespace: argparse Namespace or dict; it is not modified.
        :param __top: accepted for signature compatibility with the base
            class; not used here.
        """
        if isinstance(namespace, dict):
            values = dict(namespace)
        else:
            values = dict(vars(namespace))
        # config_json is not a model field, so it must be removed before the
        # base class checks for leftover entries.
        config_json = values.pop("config_json", None)
        if config_json:
            return cls.parse_file(config_json)
        return super(TrainSettings, cls).from_argparse(values)
78 |
79 |
80 | __all__ = ('TrainSettings', )
81 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
def load_data_from_args(
        split,
        data_dir,
        # TODO: add dataset args on your own
        batch_size: int,
        deterministic: bool = False,
        loop: bool = True,
        seed: ... = 0,
        window_size: float = 0.5,
        num_loader_proc: int = 1,
):
    """
    Create the DataLoader of one dataset split for the current process.

    :param split: forwarded to CustomDataset as ``split``.
    :param data_dir: forwarded to CustomDataset as ``data_dir``.
    :param batch_size: batch size over all processes; each loader receives
        ``batch_size // world_size`` samples per batch.
    :param deterministic: disable shuffling in the sampler.
    :param loop: use the endless InfiniteSampler instead of a one-pass
        DistributedSampler.
    :param seed: hashed with ``hash()`` and passed to the sampler as its seed.
    :param window_size: InfiniteSampler swap window (only used when looping).
    :param num_loader_proc: number of DataLoader worker processes.
    """
    from basic_utils import dist_util
    from basic_utils.sampler import InfiniteSampler
    from torch.utils.data.distributed import DistributedSampler
    from torch.utils.data import DataLoader
    from .dataset import CustomDataset

    this_rank = dist_util.get_rank()
    world_size = dist_util.get_world_size()

    # TODO: add your data-loading function from split & data_dir & your arguments
    dataset = CustomDataset(data_dir=data_dir, split=split)  # TODO

    # Arguments understood by both sampler classes.
    shared_sampler_args = {
        "num_replicas": world_size,
        "rank": this_rank,
        "shuffle": not deterministic,
        "seed": hash(seed),
    }
    if loop:
        sampler = InfiniteSampler(dataset, window_size=window_size, **shared_sampler_args)
    else:
        sampler = DistributedSampler(dataset, drop_last=False, **shared_sampler_args)

    use_workers = num_loader_proc > 0
    per_process_batch = batch_size // world_size
    return DataLoader(
        dataset,
        batch_size=per_process_batch,
        sampler=sampler,
        num_workers=num_loader_proc,
        persistent_workers=use_workers,
        # pin_memory=True,
    )
44 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # Add your dataset configuration
2 | from torch.utils.data import Dataset
3 |
4 |
class CustomDataset(Dataset):
    """Template dataset: every method raises NotImplementedError until it is filled in."""
    # TODO: implement this class

    def __init__(self, *args, **kwargs):
        # load_data_from_args constructs this class with the keyword
        # arguments ``data_dir`` and ``split``.
        raise NotImplementedError

    def __getitem__(self, item):
        # Should return the sample stored at index ``item``.
        raise NotImplementedError

    def __len__(self):
        # Should return the number of samples; the samplers call len() on the dataset.
        raise NotImplementedError
16 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdha0727/distributed-pipeline/ed11369eebbd2f59655ad47affec436ed0b64284/models/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -f https://download.pytorch.org/whl/torch_stable.html
2 | torch>=1.9.0
3 | -f https://download.pytorch.org/whl/torch_stable.html
4 | torchvision>=0.10.0
5 |
6 | blobfile
7 | pydantic
8 |
--------------------------------------------------------------------------------
/run_train.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node gpu train.py --config_json train_config.json
2 | # python -m torch.distributed.launch --use_env --nproc_per_node 4 train.py --config_json train_config.json # <=1.8
3 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.distributed.elastic.multiprocessing.errors import record
3 | from config.train import TrainSettings
4 |
5 |
def parse_args():
    """Parse the process command line with the TrainSettings parser (including --config_json)."""
    return TrainSettings.to_argparse(add_json=True).parse_args()
9 |
10 |
def train(rank, args):
    """
    Run one training job in the current process.

    Steps: choose/create the checkpoint directory, configure logging and
    seeding, build the train and validation loaders, create the model, store
    the hyperparameters, initialise wandb on rank 0 and run TrainLoop.

    :param rank: rank of this process; rank 0 creates directories.
    :param args: TrainSettings instance; ``args.checkpoint_path`` is filled in
        here when it is empty.
    """

    # Import dependencies
    import time
    import json

    # Import everything
    from data import load_data_from_args
    from basic_utils import dist_util, logger
    from utils.initialization import create_model_from_config, seed_all
    from utils.trainer import TrainLoop

    dist_util.barrier()  # Sync

    # Set checkpoint path
    folder_name = "model_checkpoints/"
    if not os.path.isdir(folder_name) and rank == 0:
        os.mkdir(folder_name)
    if not args.checkpoint_path:
        # NOTE(review): the timestamp is taken separately in every process, so
        # ranks may end up with different directory names - confirm.
        model_file = f"Run_{args.dataset}_lr{args.lr}" \
                     f"_seed{args.seed}_{time.strftime('%Y%m%d-%H:%M:%S')}"  # TODO: add your naming rule by args
        args.checkpoint_path = os.path.join(folder_name, model_file)
    if not os.path.isdir(args.checkpoint_path) and rank == 0:
        os.mkdir(args.checkpoint_path)

    # Configure log and seed
    # Only rank 0 additionally logs to stdout.
    logger.configure(dir=args.checkpoint_path, format_strs=["log", "csv"] + (["stdout"] if rank == 0 else []))
    seed_all(args.seed)

    # Prepare dataloader
    logger.log("### Creating data loader...")
    dist_util.barrier()  # Sync
    data = load_data_from_args(
        split='train',
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        # TODO: add args on your own
        deterministic=False,
        loop=True,
        seed=args.seed,
        num_loader_proc=args.data_loader_workers,
    )
    # The validation loader also loops endlessly, but without shuffling.
    data_valid = load_data_from_args(
        split='valid',
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        # TODO: add args on your own
        deterministic=True,
        loop=True,
        seed=args.seed,
        num_loader_proc=args.data_loader_workers,
    )
    dist_util.barrier()  # Sync

    # Initialize model
    logger.log("### Creating model...")
    # All settings are passed as keyword arguments; the factory picks what it needs.
    model = create_model_from_config(**args.dict())

    # Load model to each node's device
    model.to(dist_util.dev())
    dist_util.barrier()  # Sync

    # Count and log total params
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.log(f'### The parameter count is {pytorch_total_params}')

    # Save training args
    training_args_path = f'{args.checkpoint_path}/training_args.json'
    # An existing file (e.g. from a resumed run) is left untouched.
    if not os.path.exists(training_args_path):
        logger.log(f'### Saving the hyperparameters to {args.checkpoint_path}/training_args.json')
        if dist_util.get_rank() == 0:
            with open(training_args_path, 'w') as fp:
                json.dump(args.dict(), fp, indent=2)

    # Init wandb
    if dist_util.get_rank() == 0:
        # TODO: Uncomment and customize your wandb setting on your own, or just use environ.
        import wandb
        wandb.init(
            mode=os.getenv("WANDB_MODE", "online"),
            # entity=os.getenv("WANDB_ENTITY", ""),
            # project=os.getenv("WANDB_PROJECT", ""),
        )
        wandb.config.update(args.__dict__, allow_val_change=True)
    dist_util.barrier()  # Sync last

    # Run train loop
    logger.log("### Training...")
    train_loop = TrainLoop(
        model=model,
        # TODO: add your argument
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        weight_decay=args.weight_decay,
        learning_steps=args.learning_steps,
        checkpoint_path=args.checkpoint_path,
        gradient_clipping=args.gradient_clipping,
        eval_data=data_valid,
        eval_interval=args.eval_interval,
        eval_callbacks=[]
    )
    train_loop.run_loop()
119 |
120 |
@record
def main(namespace):
    """
    Entry point: turn the parsed namespace into TrainSettings, set up the
    distributed backend when a launcher exported LOCAL_RANK, and train.

    dist_util.cleanup_dist() is called when training ends, also on failure.
    """
    # Create config from parsed argument namespace
    settings: TrainSettings = TrainSettings.from_argparse(namespace)

    # Import dist_util
    from basic_utils import dist_util

    # Setup distributed
    launched_distributed = os.environ.get("LOCAL_RANK") is not None
    if launched_distributed:
        dist_util.setup_dist()
    current_rank = dist_util.get_rank()

    # Run training
    try:
        train(current_rank, settings)
    finally:
        dist_util.cleanup_dist()
140 |
141 |
142 | if __name__ == "__main__":
143 | main(parse_args())
144 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdha0727/distributed-pipeline/ed11369eebbd2f59655ad47affec436ed0b64284/utils/__init__.py
--------------------------------------------------------------------------------
/utils/initialization.py:
--------------------------------------------------------------------------------
def seed_all(seed, deterministic=False):
    """
    Seed Python's ``random``, NumPy and torch.

    :param seed: base seed (converted with ``int``).
    :param deterministic: when True every rank uses the same seed and cuDNN is
        put into deterministic, non-benchmark mode; when False the rank is
        added to the seed so that processes differ.
    """
    import random
    import numpy as np
    import torch
    from basic_utils.dist_util import get_rank

    effective_seed = int(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True  # NOQA
        torch.backends.cudnn.benchmark = False  # NOQA
    else:
        effective_seed += get_rank()  # Make seed differ by node rank

    # torch.manual_seed contains torch.cuda.manual_seed_all
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(effective_seed)
    # TODO: Add your seeder if needed
16 |
17 |
def create_model_from_config(
        *,
        argument1,
        argument2,
        argument3,
        **_
):
    """
    Template model factory (keyword-only arguments).

    train() calls it with every setting as a keyword argument; names not
    listed explicitly are swallowed by ``**_``. The placeholder parameters
    ``argument1..3`` are required as written and are meant to be replaced.
    """
    # TODO: Implement this function
    # Placeholder: currently returns the Ellipsis object, not a model.
    model = ...
    return model
28 |
--------------------------------------------------------------------------------
/utils/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Trainer copied and adapted from:
3 | https://github.com/Shark-NLP/DiffuSeq/blob/8bfafcbb26df218073b8117234afb9de9dfcbec9/train_util.py
4 | """
5 | import copy
6 | import os
7 |
8 | import blobfile as bf
9 | import math
10 | import torch
11 | from torch.nn.parallel.distributed import DistributedDataParallel
12 | from torch.optim import AdamW
13 |
14 | from basic_utils import dist_util, logger
15 |
16 |
class TrainLoop:
    """
    Training-loop skeleton: micro-batched forward/backward, AdamW updates with
    linear learning-rate annealing, EMA copies of the parameters, periodic
    logging, evaluation and checkpointing.

    The hooks ``log_loss_dict``, ``compute_losses`` and
    ``backward_from_losses`` must be implemented before the loop can run.
    """

    def log_loss_dict(self, mode, losses, *args, **kwargs):  # mode: train or eval
        # TODO: Add your loss logging function with logger.logkv_mean
        raise NotImplementedError

    def compute_losses(self, micro_batch):
        # TODO: Add your loss function with logger.logkv_mean
        raise NotImplementedError

    @staticmethod
    def backward_from_losses(losses):
        # TODO: Add your loss-reducing function, e.g.
        #   loss = <reduce `losses` to a scalar tensor>
        #   loss.backward()
        # Raises like the other unimplemented hooks instead of failing with an
        # AttributeError on a placeholder value.
        raise NotImplementedError

    @classmethod
    def get_batch_length(cls, batch):
        """
        Return the size of the first dimension of a batch.

        Tensors report ``shape[0]``; dicts and lists/tuples delegate to their
        first value/element.

        :raises TypeError: for any other batch type.
        """
        if isinstance(batch, torch.Tensor):
            return batch.shape[0]
        elif isinstance(batch, dict):
            return cls.get_batch_length(batch[list(batch.keys())[0]])
        elif isinstance(batch, (list, tuple)):
            return cls.get_batch_length(batch[0])
        else:
            # TODO: Add your custom function if needed
            raise TypeError("Unsupported batch type: {}".format(type(batch).__name__))

    def __init__(
            self,
            *,
            model,
            # TODO: Add your arguments
            data,
            batch_size,
            microbatch,
            lr,
            ema_rate,
            log_interval,
            save_interval,
            resume_checkpoint,
            weight_decay=0.0,
            learning_steps=0,
            checkpoint_path='',
            gradient_clipping=-1.,
            eval_data=None,
            eval_interval=-1,
            eval_callbacks=(),  # e.g. plotting utils for wandb
    ):
        """
        :param model: module to train (already on its device).
        :param data: iterable or iterator of training batches (dicts of tensors).
        :param batch_size: batch size; also the micro-batch size when
            ``microbatch`` is not positive.
        :param microbatch: number of samples per forward/backward pass.
        :param lr: initial learning rate.
        :param ema_rate: EMA rate(s): a number, a comma-separated string, or
            an iterable of numbers.
        :param log_interval: dump the logger every this many steps.
        :param save_interval: save a checkpoint every this many steps.
        :param resume_checkpoint: model checkpoint to resume from when none is
            found in the logger directory.
        :param weight_decay: AdamW weight decay.
        :param learning_steps: total number of steps; 0 means run forever and
            disables learning-rate annealing.
        :param checkpoint_path: directory that receives the checkpoints.
        :param gradient_clipping: max gradient norm; values <= 0 disable it.
        :param eval_data: optional iterable or iterator of evaluation batches.
        :param eval_interval: evaluate every this many steps; values <= 0
            disable evaluation.
        :param eval_callbacks: callables invoked with this object on rank 0
            after each evaluation.
        """
        self.model = model
        self.data = data
        self.eval_data = eval_data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = float(lr)
        if isinstance(ema_rate, (int, float)):
            self.ema_rate = [float(ema_rate)]
        elif isinstance(ema_rate, str):
            self.ema_rate = [float(x) for x in ema_rate.split(",") if x.strip()]
        else:
            self.ema_rate = [float(x) for x in ema_rate]
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.weight_decay = weight_decay
        self.learning_steps = learning_steps
        self.gradient_clipping = gradient_clipping

        # TODO: Add your arguments

        self.step = 0
        self.resume_step = 0
        # NOTE(review): train.py already passes the batch size of all
        # processes together; confirm whether multiplying again is intended.
        self.global_batch = self.batch_size * dist_util.get_world_size()

        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.eval_callbacks = list(eval_callbacks)

        self.checkpoint_path = checkpoint_path  # DEBUG **

        self._load_and_sync_parameters()

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
        if self.resume_step:
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self._load_optimizer_state()
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
            ]

        if dist_util.is_initialized():
            self.use_ddp = True
            device = dist_util.dev()
            print(device)
            ddp_kwargs = dict(
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
            # device_ids/output_device are only valid for GPU modules; for a
            # CPU module they must be left unset.
            if torch.device(device).type == "cuda":
                ddp_kwargs.update(device_ids=[device])
                ddp_kwargs.update(output_device=device)
            self.ddp_model = DistributedDataParallel(self.model, **ddp_kwargs)
        else:
            self.use_ddp = False
            self.ddp_model = self.model

        # In checkpoint-loading process, sometimes GPU 0 is used and allocated.
        # After all process is done, free GPU 0 by function below:
        torch.cuda.empty_cache()

    # # # # # # # # # # # # # # # Implemented Functions # # # # # # # # # # # # # # #

    def _load_and_sync_parameters(self):
        """Load the newest model checkpoint on rank 0 (if any) and sync parameters."""
        resume_checkpoint = self.find_resume_checkpoint() or self.resume_checkpoint
        if not resume_checkpoint:
            return

        self.resume_step = self.parse_resume_step_from_filename(resume_checkpoint)
        if dist_util.get_rank() == 0:
            logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
            self.model.load_state_dict(
                dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev())
            )

        dist_util.sync_params(self.model.parameters())

    def _load_ema_parameters(self, rate):
        """
        Return the EMA parameter list for ``rate``.

        Starts from a copy of the current parameters and, when a matching
        ``ema_<rate>_<step>.pt`` file exists, replaces it on rank 0 with the
        stored values before syncing. Always returns a list.
        """
        ema_params = copy.deepcopy(self.master_params)
        main_checkpoint = self.find_resume_checkpoint() or self.resume_checkpoint
        # find_ema_checkpoint returns None when there is no main checkpoint or
        # no EMA file; in both cases the copy made above is kept.
        ema_checkpoint = self.find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
        if ema_checkpoint:
            if dist_util.get_rank() == 0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(ema_checkpoint, map_location=dist_util.dev())
                ema_params = self._state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore the optimizer state saved next to the resumed checkpoint, if present."""
        main_checkpoint = self.find_resume_checkpoint() or self.resume_checkpoint
        if not main_checkpoint:
            return
        # find_opt_checkpoint already checks existence and yields None when
        # the file is missing, so no second exists() call on a possible None.
        opt_checkpoint = self.find_opt_checkpoint(main_checkpoint, self.resume_step)
        if opt_checkpoint:
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
            self.opt.load_state_dict(state_dict)

    def run_loop(self):
        """
        Train until ``learning_steps`` is reached (forever when it is 0),
        logging, evaluating and saving at the configured intervals, and save
        once more at the end unless the last step was just saved.
        """
        # iter() makes next() work for iterables such as a DataLoader and is
        # a no-op for objects that already are iterators (e.g. generators).
        data_iter = iter(self.data)
        eval_iter = iter(self.eval_data) if self.eval_data is not None else None
        while (
            not self.learning_steps
            or self.step + self.resume_step < self.learning_steps
        ):
            batch = next(data_iter)
            self.run_step(batch)
            if self.step % self.log_interval == 0:
                logger.dumpkvs()
            # A non-positive interval (the default is -1) disables evaluation.
            if (
                eval_iter is not None
                and self.eval_interval > 0
                and self.step % self.eval_interval == 0
            ):
                cond_eval = next(eval_iter)
                self.forward_only(cond_eval)
                print('eval on validation set')
                logger.dumpkvs()
                if dist_util.get_rank() == 0:
                    for callback in self.eval_callbacks:
                        callback(self)
            if self.step > 0 and self.step % self.save_interval == 0:
                self.save()
            self.step += 1
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch):
        """One optimisation step: forward/backward, optimizer update, step logging."""
        self.forward_backward(batch)
        self.optimize()
        self.log_step()

    def _zero_grad(self):
        """Zero existing gradients in place."""
        for param in self.model_params:
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()

    def _common_forward(self, total_batch, start_index):
        """
        Compute the losses of one micro-batch.

        ``total_batch`` is a dict whose values are sliced along their first
        dimension and moved to the device. Under DDP, gradient
        synchronisation is skipped for every micro-batch except the last.
        """
        micro_batch = {
            k: v[start_index: start_index + self.microbatch].to(dist_util.dev())
            for k, v in total_batch.items()
        }
        last_batch = (start_index + self.microbatch) >= self.get_batch_length(total_batch)

        if last_batch or not self.use_ddp:
            losses = self.compute_losses(micro_batch)
        else:
            with self.ddp_model.no_sync():
                losses = self.compute_losses(micro_batch)
        return losses

    @torch.no_grad()
    def forward_only(self, batch):
        """Evaluate ``batch`` micro-batch by micro-batch, logging losses as "eval"."""
        self._zero_grad()
        for i in range(0, self.get_batch_length(batch), self.microbatch):
            losses = self._common_forward(batch, i)
            self.log_loss_dict(mode="eval", losses=losses)

    def forward_backward(self, batch):
        """Accumulate gradients over all micro-batches, logging losses as "train"."""
        self._zero_grad()
        for i in range(0, self.get_batch_length(batch), self.microbatch):
            losses = self._common_forward(batch, i)
            self.log_loss_dict(mode="train", losses=losses)
            self.backward_from_losses(losses)

    def optimize(self):
        """Clip (optional), log the gradient norm, anneal the LR, step, update EMAs."""
        if self.gradient_clipping > 0:
            self.grad_clip()
        self._log_grad_norm()
        self._anneal_lr()
        self.opt.step()
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)

    def grad_clip(self):
        """Clip the gradient norm to ``gradient_clipping``."""
        max_grad_norm = self.gradient_clipping
        if hasattr(self.opt, "clip_grad_norm"):
            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
            self.opt.clip_grad_norm(max_grad_norm)
        else:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                max_grad_norm,
            )

    def _anneal_lr(self):
        """Linearly decay the learning rate towards 0 at ``learning_steps``."""
        if not self.learning_steps:
            return
        frac_done = (self.step + self.resume_step) / self.learning_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def _log_grad_norm(self):
        """Log the L2 norm over all existing parameter gradients."""
        sqsum = 0.0
        for p in self.master_params:
            if p.grad is not None:
                sqsum += (p.grad ** 2).sum().item()
        logger.logkv_mean("grad_norm", math.sqrt(sqsum))

    def log_step(self):
        """Log the absolute step number and the number of samples seen."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """Write model, EMA and optimizer checkpoints, then wait on the barrier."""
        self._save_checkpoint(0, self.master_params)
        for r, p in zip(self.ema_rate, self.ema_params):
            self._save_checkpoint(r, p)
        self._save_opt()
        dist_util.barrier()

    def _save_checkpoint(self, rate, params):
        """Save ``params`` as ``model_<step>.pt`` (falsy rate) or ``ema_<rate>_<step>.pt`` on rank 0."""
        state_dict = self._master_params_to_state_dict(params)
        if dist_util.get_rank() == 0:
            logger.log(f"saving model {rate}...")
            if not rate:
                filename = f"model_{(self.step+self.resume_step):06d}.pt"
            else:
                filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
            print('writing to', bf.join(self.checkpoint_path, filename))
            with bf.BlobFile(bf.join(self.checkpoint_path, filename), "wb") as f:
                torch.save(state_dict, f)

    def _save_opt(self):
        """Save the optimizer state as ``opt_<step>.pt`` on rank 0."""
        if dist_util.get_rank() == 0:
            logger.log(f"saving optimizer...")
            filename = f"opt_{(self.step+self.resume_step):06d}.pt"
            print('writing to', bf.join(self.checkpoint_path, filename))
            with bf.BlobFile(bf.join(self.checkpoint_path, filename), "wb") as f:
                torch.save(self.opt.state_dict(), f)

    def _master_params_to_state_dict(self, master_params, key=None):
        """
        Return the model state dict with parameter entries replaced by
        ``master_params`` (matched by position in ``named_parameters()``).

        With ``key`` given, return only the entry for that parameter name.

        :raises KeyError: if ``key`` names no parameter.
        """
        state_dict = self.model.state_dict()
        for i, (name, _value) in enumerate(self.model.named_parameters()):
            assert name in state_dict
            if key is not None and key == name:
                return master_params[i]
            state_dict[name] = master_params[i]
        if key is not None:
            raise KeyError(key)
        return state_dict

    def _state_dict_to_master_params(self, state_dict):
        """Return the state-dict values in ``named_parameters()`` order."""
        params = [state_dict[name] for name, _ in self.model.named_parameters()]
        return params

    @staticmethod
    def parse_resume_step_from_filename(filename):
        """
        Parse filenames of the form path/to/model_NNNNNN.pt (the underscore is
        optional), where NNNNNN is the checkpoint's number of steps; any
        number of digits is accepted.

        :raises AssertionError: if the basename does not look like model*.pt.
        :raises ValueError: if the name carries no trailing step number.
        """
        filename: str = os.path.basename(filename)
        assert filename.startswith('model') and filename[-3:] == '.pt', "Invalid model name"
        stem = filename[:-3]
        # Take the run of digits at the end of the stem.
        digits = stem[len(stem.rstrip("0123456789")):]
        return int(digits)

    @staticmethod
    def find_resume_checkpoint():
        """
        Return the path of the model checkpoint with the highest step number
        in the current logger directory, or None if there is none.
        """
        def step_of(name):
            # Numeric step for ordering; names without trailing digits sort first.
            stem = name[:-3]
            digits = stem[len(stem.rstrip("0123456789")):]
            return int(digits) if digits else -1

        log_dir = logger.get_current().dir
        model_weights = sorted(
            (s for s in os.listdir(log_dir) if s.endswith(".pt") and s.startswith("model")),
            key=lambda s: (step_of(s), s),
        )
        if model_weights:
            path = os.path.join(log_dir, model_weights[-1])
            return path

    @staticmethod
    def find_ema_checkpoint(main_checkpoint, step, rate):
        """Return the EMA checkpoint stored next to ``main_checkpoint``, or None."""
        if not main_checkpoint:
            return None
        filename = f"ema_{rate}_{step:06d}.pt"
        path = bf.join(bf.dirname(main_checkpoint), filename)
        if bf.exists(path):
            return path
        return None

    @staticmethod
    def find_opt_checkpoint(main_checkpoint, step):
        """Return the optimizer checkpoint stored next to ``main_checkpoint``, or None."""
        if not main_checkpoint:
            return None
        filename = f"opt_{step:06d}.pt"
        path = bf.join(bf.dirname(main_checkpoint), filename)
        if bf.exists(path):
            return path
        return None

    __call__ = run_loop
359 |
360 |
def update_ema(target_params, source_params, rate=0.99):
    """
    Move each target tensor towards its source tensor in place:
    ``target = rate * target + (1 - rate) * source``.

    :param target_params: the target parameter sequence (modified in place).
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    source_weight = 1 - rate
    for averaged, current in zip(target_params, source_params):
        # detach() shares storage, so the in-place ops update the target itself.
        shadow = averaged.detach()
        shadow.mul_(rate)
        shadow.add_(current, alpha=source_weight)
372 |
--------------------------------------------------------------------------------