├── .gitignore ├── LICENSE ├── README.md ├── config └── base.json ├── core ├── base_model.py ├── base_network.py ├── logger.py ├── praser.py └── util.py ├── data ├── __init__.py ├── auto_augment.py └── dataset.py ├── experiments └── clean.sh ├── misc └── template.png ├── models ├── __init__.py ├── loss.py ├── metric.py ├── model.py └── network.py ├── new_project.py ├── requirements.txt ├── run.py └── slurm └── run.slurm /.gitignore: -------------------------------------------------------------------------------- 1 | # myself 2 | experiments/* 3 | !experiments/clean.sh 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Liangwei Jiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Template Using DistributedDataParallel 2 | 3 | This is a seed project for distributed PyTorch training, which was built to customize your network quickly. 4 | 5 | ### Overview 6 | 7 | Here is an overview of what this template can do, and most of them can be customized by the configure file. 8 | 9 | ![distributed pytorch template](misc/template.png) 10 | 11 | ### Basic Functions 12 | 13 | - checkpoint/resume training 14 | - progress bar (using tqdm) 15 | - progress logs (using logging) 16 | - progress visualization (using tensorboard) 17 | - finetune (partial network parameters training) 18 | - learning rate scheduler 19 | - random seed (reproducibility) 20 | 21 | ------ 22 | ### Features 23 | 24 | - distributed training using DistributedDataParallel 25 | - base class for extensibility 26 | - `.json` configure file for most parameter tuning 27 | - support multiple networks/losses/metrics definition 28 | - debug mode for fast test 🌟 29 | 30 | ------ 31 | ### Usage 32 | 33 | #### You Need to Know 34 | 35 | 1. cuDNN default settings are as follows for training, which may reduce your code reproducibility! Notice it to avoid unexpected behaviors. 
36 | 37 | ```python 38 | torch.backends.cudnn.enabled = True 39 | # speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 40 | if seed >=0 and gl_seed>=0: # slower, more reproducible 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | else: # faster, less reproducible, default setting 44 | torch.backends.cudnn.deterministic = False 45 | torch.backends.cudnn.benchmark = True 46 | ``` 47 | 48 | 2. The project allows custom classes/functions and parameters to be set through the configuration file. You can define datasets, losses, networks, etc. in a specific format. Take the `network` part as an example: 49 | 50 | ```yaml 51 | // import Network() class from models.network.py file with args 52 | "which_networks": [ 53 | { 54 | "name": ["models.network", "Network"], 55 | "args": { "init_type": "kaiming"} 56 | } 57 | ], 58 | 59 | // import multiple Networks from the default file with args 60 | "which_networks": [ 61 | {"name": "Network1", args: {"init_type": "kaiming"}}, 62 | {"name": "Network2", args: {"init_type": "kaiming"}}, 63 | ], 64 | 65 | // import multiple Networks from the default file without args 66 | "which_networks" : [ 67 | "Network1", // equivalent to {"name": "Network1", args: {}}, 68 | "Network2" 69 | ] 70 | 71 | // more details can be found in the More Details part and the init_obj function in praser.py 72 | ``` 73 | 74 | 75 | 76 | #### Start 77 | 78 | Run `run.py` with your settings. 79 | 80 | ```shell 81 | python run.py 82 | ``` 83 | 84 | More options can be found in `run.py` and `config/base.json`. 85 | 86 | 87 | #### Customize Dataset 88 | 89 | The dataset part decides what data is fed into the network. You can define your dataset with the following steps: 90 | 91 | 1. Put your dataset under the `data` folder. See `dataset.py` in this folder as an example. 92 | 2. Edit the **\[dataset\]\[train|test\]** part in `config/base.json` to import and initialize your dataset. 93 | 94 | ```yaml 95 | "datasets": { // train or test 96 | "train": { 97 | "which_dataset": { // import designated dataset using args 98 | "name": ["data.dataset", "Dataset"], 99 | "args":{ // args to init dataset 100 | "data_root": "/data/jlw/datasets/comofod" 101 | } 102 | }, 103 | "dataloader":{ 104 | "validation_split": 0.1, // percent or number 105 | "args":{ // args to init dataloader 106 | "batch_size": 2, // batch size in each gpu 107 | "num_workers": 4, 108 | "shuffle": true, 109 | "pin_memory": true, 110 | "drop_last": true 111 | } 112 | } 113 | }, 114 | } 115 | ``` 116 | 117 | ##### More details 118 | 119 | - You can import the dataset from a new file. The key `name` can be a list giving the file name and the class/function name, or a single string giving the class name in the default file (`data.dataset.py`). An example is as follows: 120 | 121 | ```yaml 122 | "name": ["data.dataset", "Dataset"], // import Dataset() class from data.dataset.py 123 | "name": "Dataset", // import Dataset() class from default file 124 | ``` 125 | 126 | - You can control and record more parameters through the configuration file. Take `data_root` as an example: you just need to add it to the `args` dict and edit the corresponding class to parse this value (a fuller, self-contained sketch follows after this list): 127 | 128 | ```yaml 129 | "args":{ // args to init dataset 130 | "data_root": "your data path" 131 | } 132 | ``` 133 | 134 | ```python 135 | class Dataset(data.Dataset): 136 | def __init__(self, data_root, phase='train', image_size=[256, 256], loader=pil_loader): 137 | imgs = make_dataset(data_root) # data_root value is from the configuration file 138 | ``` 139 | 140 | 141 |
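To make the flow from configuration to code concrete, here is a minimal sketch of a custom dataset file. The file name `data/custom_dataset.py`, the class name `PairedDataset`, and the extra `mask_root` argument are hypothetical and not part of this template; they only illustrate that every key you add under `args` in `config/base.json` is passed into `__init__` as a keyword argument by `init_obj` in `core/praser.py`:

```python
# data/custom_dataset.py -- hypothetical example, not shipped with this template
import os

import torch.utils.data as data
from PIL import Image
from torchvision import transforms


class PairedDataset(data.Dataset):
    """ data_root and mask_root both come from the `args` dict in config/base.json """
    def __init__(self, data_root, mask_root, phase='train', image_size=[256, 256]):
        # assumes images and masks share file names across the two (hypothetical) folders
        self.img_paths = sorted(os.path.join(data_root, f) for f in os.listdir(data_root))
        self.mask_paths = sorted(os.path.join(mask_root, f) for f in os.listdir(mask_root))
        self.tfs = transforms.Compose([
            transforms.Resize((image_size[0], image_size[1])),
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        img = Image.open(self.img_paths[index]).convert('RGB')
        mask = Image.open(self.mask_paths[index]).convert('L')
        return {'input': self.tfs(img), 'mask': self.tfs(mask),
                'path': os.path.basename(self.img_paths[index])}

    def __len__(self):
        return len(self.img_paths)
```

The matching config entry would then be `"name": ["data.custom_dataset", "PairedDataset"]` with `"args": {"data_root": "...", "mask_root": "..."}`; any keyword you do not list in `args` simply keeps its default value.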
142 | #### Customize Network 143 | 144 | The network part defines your learning network structure. You can define your network with the following steps: 145 | 146 | 1. Put your network under the `models` folder. See `network.py` in this folder as an example. 147 | 2. Edit the **\[model\][which_networks]** part in `config/base.json` to import and initialize your networks; note that it is a list. 148 | 149 | ```yaml 150 | "which_networks": [ // import designated list of networks using args 151 | { 152 | "name": "Network", 153 | "args": { // args to init network 154 | "init_type": "kaiming" 155 | } 156 | } 157 | ], 158 | ``` 159 | ##### More details 160 | 161 | - You can import networks from a new file. The key `name` can be a list giving the file name and the class/function name, or a single string giving the class name in the default file (`models.network.py`). An example is as follows: 162 | 163 | ```yaml 164 | "name": ["models.network", "Network"], // import Network() class from models.network.py 165 | "name": "Network", // import Network() class from default file 166 | ``` 167 | 168 | - You can control and record more parameters through the configuration file. Take `init_type` as an example: you just need to add it to the `args` dict and edit the corresponding class to parse this value (see the self-contained sketch after this list): 169 | 170 | ```yaml 171 | "args": { // args to init network 172 | "init_type": "kaiming" 173 | } 174 | ``` 175 | 176 | ```python 177 | class BaseNetwork(nn.Module): 178 | def __init__(self, init_type='kaiming', gain=0.02): 179 | super(BaseNetwork, self).__init__() # init_type value is from the configuration file 180 | class Network(BaseNetwork): 181 | def __init__(self, in_channels=3, **kwargs): 182 | super(Network, self).__init__(**kwargs) # get init_type value and pass it to the base network 183 | ``` 184 | 185 | - You can import multiple networks. Import them in the configuration file and use them in the model. 186 | 187 | ```yaml 188 | "which_networks": [ 189 | {"name": "Network1", args: {}}, 190 | {"name": "Network2", args: {}}, 191 | ], 192 | ``` 193 | 194 | 195 | 196 |
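Putting these pieces together, a new network only needs to subclass `BaseNetwork` (from `core/base_network.py`) and forward unrecognized keyword arguments such as `init_type` to it; `define_network` in `models/__init__.py` then calls `init_weights()` on it during training. Below is a minimal sketch under that assumption; the class name `TinyNetwork` and its layer sizes are invented purely for illustration:

```python
# models/network.py -- hypothetical addition alongside the existing Network class
import torch.nn as nn

from core.base_network import BaseNetwork


class TinyNetwork(BaseNetwork):
    def __init__(self, in_channels=3, out_channels=3, mid_channels=64, **kwargs):
        # init_type/gain coming from the config "args" reach BaseNetwork through **kwargs
        super(TinyNetwork, self).__init__(**kwargs)
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.body(x)
```

A config entry such as `{"name": ["models.network", "TinyNetwork"], "args": {"init_type": "kaiming", "mid_channels": 32}}` would then construct it as `TinyNetwork(init_type='kaiming', mid_channels=32)`.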
197 | #### Customize Model(Trainer) 198 | 199 | The model part defines your training process, including optimizers/losses/process control, etc. You can define your model with the following steps: 200 | 201 | 1. Put your model under the `models` folder. See `model.py` in this folder as an example. 202 | 2. Edit the **\[model\][which_model]** part in `config/base.json` to import and initialize your model. 203 | 204 | ```yaml 205 | "which_model": { // import designated model(trainer) using args 206 | "name": ["models.model", "Model"], 207 | "args": { // args to init model 208 | } 209 | }, 210 | ``` 211 | 212 | ##### More details 213 | 214 | - You can import the model from a new file. The key `name` can be a list giving the file name and the class/function name, or a single string giving the class name in the default file (`models.model.py`). An example is as follows: 215 | 216 | ```yaml 217 | "name": ["models.model", "Model"], // import Model() class / function(not recommended) from models.model.py (default is [models.model.py]) 218 | "name": "Model", // import Model() class from default file 219 | ``` 220 | 221 | - You can control and record more parameters through the configuration file. Please refer to the above `More details` part. 222 | 223 | 224 | ##### Losses and Metrics 225 | 226 | Losses and metrics are defined in the configuration file. You can also control and record more parameters through the configuration file; please refer to the above `More details` part. 227 | 228 | ```yaml 229 | "which_metrics": ["mae"], 230 | "which_losses": ["mse_loss"] 231 | ``` 232 | 233 | After the above steps, you need to rewrite several functions in `base_model.py`/`model.py` for your network and dataset. 234 | 235 | ##### Init step 236 | 237 | See the `__init__()` functions as an example. 238 | 239 | ##### Training/validation step 240 | 241 | See the `train_step()`/`val_step()` functions as an example. 242 | 243 | ##### Checkpoint/Resume training 244 | 245 | See the `save_everything()`/`load_everything()` functions as an example. 246 | 247 | 248 | 249 | #### Debug mode 250 | 251 | Sometimes we want to debug the process quickly to make sure the whole project works, so a debug mode is necessary. 252 | 253 | This mode reduces the dataset size and speeds up the training process. You just need to run the file with the `-d` option and edit the `debug` dict in the configuration file. 254 | 255 | ```shell 256 | python run.py -d 257 | ``` 258 | 259 | ```yaml 260 | "debug": { // args in debug mode, which will replace args in train 261 | "val_epoch": 1, 262 | "save_checkpoint_epoch": 1, 263 | "log_iter": 30, 264 | "debug_split": 50 // percent or number, changes the size of the dataloader to debug_split 265 | } 266 | ``` 267 | 268 | 269 | 270 | #### Customize More 271 | 272 | You can choose the random seed, experiment path, etc. in the configuration file. We will add more useful basic functions with related instructions. **Contributions for more extensive customization and code enhancements are welcome.** 273 | 274 | ------ 275 | ### Todo 276 | 277 | Here are some basic functions or examples that this repository has implemented or plans to implement: 278 | 279 | - [x] basic dataset/data_loader with validation split 280 | - [x] basic networks with weight initialization 281 | - [x] basic model (trainer) 282 | - [x] checkpoint/resume training 283 | - [x] progress bar (using tqdm) 284 | - [x] progress logs (using logging) 285 | - [x] progress visualization (using tensorboard) 286 | - [x] multi-gpu support (using DistributedDataParallel and torch.multiprocessing) 287 | - [x] finetune (partial network parameters training) 288 | - [x] learning rate scheduler 289 | - [x] random seed (reproducibility) 290 | - [x] multiple optimizers and schedulers by configuration file 291 | - [ ] parser arguments customization 292 | - [ ] more network examples 293 | 294 | 295 | ------ 296 | ### Acknowledgements 297 | 298 | We benefited a lot from the following projects: 299 | 300 | > 1. https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement 301 | > 2. https://github.com/researchmm/PEN-Net-for-Inpainting 302 | > 3. https://github.com/tczhangzhi/pytorch-distributed 303 | > 4.
https://github.com/victoresque/pytorch-template -------------------------------------------------------------------------------- /config/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "base", // experiments name 3 | "gpu_ids": [0], // gpu ids list, default is single 0 4 | "seed" : 2021, // random seed, seed <0 represents randomization not used 5 | "finetune_norm": false, // find the parameters to optimize 6 | 7 | "path": { //set every part file path 8 | "base_dir": "experiments", // base path for all log except resume_state 9 | "tb_logger": "tb_logger", // path of tensorboard logger 10 | "results": "results", 11 | "code": "code", // code backup 12 | "checkpoint": "checkpoint", 13 | // "resume_state": "experiments/debug_base_220226_214326/checkpoint/100" 14 | "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration 15 | }, 16 | 17 | "datasets": { // train or test 18 | "train": { 19 | "which_dataset": { // import designated dataset using arguments 20 | "name": ["data.dataset", "Dataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 21 | "args":{ // arguments to initialize dataset 22 | // "data_root": "/home/huangyecheng/dataset/cmfd/comofod" 23 | "data_root": "/data/jlw/datasets/comofod" 24 | } 25 | }, 26 | "dataloader":{ 27 | "validation_split": 0.1, // percent or number 28 | "args":{ // arguments to initialize dataloader 29 | "batch_size": 2, // batch size in each gpu 30 | "num_workers": 4, 31 | "shuffle": true, 32 | "pin_memory": true, 33 | "drop_last": true 34 | }, 35 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 36 | "batch_size": 1, // batch size in each gpu 37 | "num_workers": 4, 38 | "shuffle": false, 39 | "pin_memory": true, 40 | "drop_last": false 41 | } 42 | } 43 | }, 44 | "test": { 45 | "which_dataset": { 46 | "name": "Dataset", // import Dataset() class / function(not recommend) from default file 47 | "args":{ 48 | "data_root": "/data/jlw/datasets/comofod", 49 | "phase": "test" 50 | } 51 | }, 52 | "dataloader":{ 53 | "args":{ 54 | "batch_size": 1, 55 | "num_workers": 4, 56 | "pin_memory": true 57 | } 58 | } 59 | } 60 | }, 61 | 62 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 63 | "which_model": { // import designated model(trainer) using arguments 64 | "name": ["models.model", "Model"], // import Model() class / function(not recommend) from models.model.py (default is [models.model.py]) 65 | "args": { 66 | "ema_scheduler": { 67 | "ema_start": 1e3, 68 | "ema_iter": 1, 69 | "ema_decay": 0.9999 70 | }, 71 | "optimizers": [ 72 | { "lr": 1e-4, "weight_decay": 0.0001} 73 | ] 74 | } 75 | }, 76 | "which_networks": [ // import designated list of networks using arguments 77 | { 78 | "name": "Network", // import Network() class / function(not recommend) from default file (default is [models/network.py]) 79 | "args": { // arguments to initialize network 80 | "init_type": "kaiming" // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 81 | } 82 | } 83 | ], 84 | "which_losses": [ // import designated list of losses without arguments 85 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 86 | ], 87 | "which_metrics": [ // import designated list of metrics without arguments 88 | "mae" // import mae() 
function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 89 | ] 90 | }, 91 | 92 | "train": { // arguments for basic training 93 | "n_epoch": 500, // max epochs 94 | "n_iter": 1e6, // max interations, not limited now 95 | "val_epoch": 30, // valdation every specified number of epochs 96 | "save_checkpoint_epoch": 30, 97 | "log_iter": 1e4, // log every specified number of iterations 98 | "tensorboard" : true // tensorboardX enable 99 | }, 100 | 101 | "debug": { // arguments in debug mode, which will replace arguments in train 102 | "val_epoch": 1, 103 | "save_checkpoint_epoch": 1, 104 | "log_iter": 30, 105 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /core/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | from functools import partial 4 | import collections 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | import core.util as Util 11 | CustomResult = collections.namedtuple('CustomResult', 'name result') 12 | 13 | class BaseModel(): 14 | def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer): 15 | """ init model with basic input, which are from __init__(**kwargs) function in inherited class """ 16 | self.opt = opt 17 | self.phase = opt['phase'] 18 | self.set_device = partial(Util.set_device, rank=opt['global_rank']) 19 | 20 | ''' optimizers and schedulers ''' 21 | self.schedulers = [] 22 | self.optimizers = [] 23 | 24 | ''' process record ''' 25 | self.batch_size = self.opt['datasets'][self.phase]['dataloader']['args']['batch_size'] 26 | self.epoch = 0 27 | self.iter = 0 28 | 29 | self.phase_loader = phase_loader 30 | self.val_loader = val_loader 31 | self.metrics = metrics 32 | 33 | ''' logger to log file, which only work on GPU 0. writer to tensorboard and result file ''' 34 | self.logger = logger 35 | self.writer = writer 36 | self.results_dict = CustomResult([],[]) # {"name":[], "result":[]} 37 | 38 | def train(self): 39 | while self.epoch <= self.opt['train']['n_epoch'] and self.iter <= self.opt['train']['n_iter']: 40 | self.epoch += 1 41 | if self.opt['distributed']: 42 | ''' sets the epoch for this sampler. 
When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch ''' 43 | self.phase_loader.sampler.set_epoch(self.epoch) 44 | 45 | train_log = self.train_step() 46 | 47 | ''' save logged informations into log dict ''' 48 | train_log.update({'epoch': self.epoch, 'iters': self.iter}) 49 | 50 | ''' print logged informations to the screen and tensorboard ''' 51 | for key, value in train_log.items(): 52 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 53 | 54 | if self.epoch % self.opt['train']['save_checkpoint_epoch'] == 0: 55 | self.logger.info('Saving the self at the end of epoch {:.0f}'.format(self.epoch)) 56 | self.save_everything() 57 | 58 | if self.epoch % self.opt['train']['val_epoch'] == 0: 59 | self.logger.info("\n\n\n------------------------------Validation Start------------------------------") 60 | if self.val_loader is None: 61 | self.logger.warning('Validation stop where dataloader is None, Skip it.') 62 | else: 63 | val_log = self.val_step() 64 | for key, value in val_log.items(): 65 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 66 | self.logger.info("\n------------------------------Validation End------------------------------\n\n") 67 | self.logger.info('Number of Epochs has reached the limit, End.') 68 | 69 | def test(self): 70 | pass 71 | 72 | @abstractmethod 73 | def train_step(self): 74 | raise NotImplementedError('You must specify how to train your networks.') 75 | 76 | @abstractmethod 77 | def val_step(self): 78 | raise NotImplementedError('You must specify how to do validation on your networks.') 79 | 80 | def test_step(self): 81 | pass 82 | 83 | def print_network(self, network): 84 | """ print network structure, only work on GPU 0 """ 85 | if self.opt['global_rank'] !=0: 86 | return 87 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 88 | network = network.module 89 | 90 | s, n = str(network), sum(map(lambda x: x.numel(), network.parameters())) 91 | net_struc_str = '{}'.format(network.__class__.__name__) 92 | self.logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 93 | self.logger.info(s) 94 | 95 | def save_network(self, network, network_label): 96 | """ save network structure, only work on GPU 0 """ 97 | if self.opt['global_rank'] !=0: 98 | return 99 | save_filename = '{}_{}.pth'.format(self.epoch, network_label) 100 | save_path = os.path.join(self.opt['path']['checkpoint'], save_filename) 101 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 102 | network = network.module 103 | state_dict = network.state_dict() 104 | for key, param in state_dict.items(): 105 | state_dict[key] = param.cpu() 106 | torch.save(state_dict, save_path) 107 | 108 | def load_network(self, network, network_label, strict=True): 109 | if self.opt['path']['resume_state'] is None: 110 | return 111 | self.logger.info('Beign loading pretrained model [{:s}] ...'.format(network_label)) 112 | 113 | model_path = "{}_{}.pth".format(self. 
opt['path']['resume_state'], network_label) 114 | 115 | if not os.path.exists(model_path): 116 | self.logger.warning('Pretrained model in [{:s}] is not existed, Skip it'.format(model_path)) 117 | return 118 | 119 | self.logger.info('Loading pretrained model from [{:s}] ...'.format(model_path)) 120 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 121 | network = network.module 122 | network.load_state_dict(torch.load(model_path, map_location = lambda storage, loc: Util.set_device(storage)), strict=strict) 123 | 124 | def save_training_state(self): 125 | """ saves training state during training, only work on GPU 0 """ 126 | if self.opt['global_rank'] !=0: 127 | return 128 | assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.' 129 | state = {'epoch': self.epoch, 'iter': self.iter, 'schedulers': [], 'optimizers': []} 130 | for s in self.schedulers: 131 | state['schedulers'].append(s.state_dict()) 132 | for o in self.optimizers: 133 | state['optimizers'].append(o.state_dict()) 134 | save_filename = '{}.state'.format(self.epoch) 135 | save_path = os.path.join(self.opt['path']['checkpoint'], save_filename) 136 | torch.save(state, save_path) 137 | 138 | def resume_training(self): 139 | """ resume the optimizers and schedulers for training, only work when phase is test or resume training enable """ 140 | if self.phase!='train' or self. opt['path']['resume_state'] is None: 141 | return 142 | self.logger.info('Beign loading training states'.format()) 143 | assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.' 144 | 145 | state_path = "{}.state".format(self. opt['path']['resume_state']) 146 | 147 | if not os.path.exists(state_path): 148 | self.logger.warning('Training state in [{:s}] is not existed, Skip it'.format(state_path)) 149 | return 150 | 151 | self.logger.info('Loading training state for [{:s}] ...'.format(state_path)) 152 | resume_state = torch.load(state_path, map_location = lambda storage, loc: self.set_device(storage)) 153 | 154 | resume_optimizers = resume_state['optimizers'] 155 | resume_schedulers = resume_state['schedulers'] 156 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers {} != {}'.format(len(resume_optimizers), len(self.optimizers)) 157 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers {} != {}'.format(len(resume_schedulers), len(self.schedulers)) 158 | for i, o in enumerate(resume_optimizers): 159 | self.optimizers[i].load_state_dict(o) 160 | for i, s in enumerate(resume_schedulers): 161 | self.schedulers[i].load_state_dict(s) 162 | 163 | self.epoch = resume_state['epoch'] 164 | self.iter = resume_state['iter'] 165 | 166 | @abstractmethod 167 | def save_everything(self): 168 | raise NotImplementedError('You must specify how to save your networks, optimizers and schedulers.') 169 | -------------------------------------------------------------------------------- /core/base_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | class BaseNetwork(nn.Module): 3 | def __init__(self, init_type='kaiming', gain=0.02): 4 | super(BaseNetwork, self).__init__() 5 | self.init_type = init_type 6 | self.gain = gain 7 | 8 | def init_weights(self): 9 | """ 10 | initialize network's weights 11 | init_type: normal | xavier | kaiming | orthogonal 12 | 
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 13 | """ 14 | 15 | def init_func(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('InstanceNorm2d') != -1: 18 | if hasattr(m, 'weight') and m.weight is not None: 19 | nn.init.constant_(m.weight.data, 1.0) 20 | if hasattr(m, 'bias') and m.bias is not None: 21 | nn.init.constant_(m.bias.data, 0.0) 22 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 23 | if self.init_type == 'normal': 24 | nn.init.normal_(m.weight.data, 0.0, self.gain) 25 | elif self.init_type == 'xavier': 26 | nn.init.xavier_normal_(m.weight.data, gain=self.gain) 27 | elif self.init_type == 'xavier_uniform': 28 | nn.init.xavier_uniform_(m.weight.data, gain=1.0) 29 | elif self.init_type == 'kaiming': 30 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 31 | elif self.init_type == 'orthogonal': 32 | nn.init.orthogonal_(m.weight.data, gain=self.gain) 33 | elif self.init_type == 'none': # uses pytorch's default init method 34 | m.reset_parameters() 35 | else: 36 | raise NotImplementedError('initialization method [%s] is not implemented' % self.init_type) 37 | if hasattr(m, 'bias') and m.bias is not None: 38 | nn.init.constant_(m.bias.data, 0.0) 39 | 40 | self.apply(init_func) 41 | # propagate to children 42 | for m in self.children(): 43 | if hasattr(m, 'init_weights'): 44 | m.init_weights(self.init_type, self.gain) 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import importlib 4 | from datetime import datetime 5 | import logging 6 | import pandas as pd 7 | 8 | import core.util as Util 9 | 10 | class InfoLogger(): 11 | """ 12 | use logging to record log, only work on GPU 0 by judging global_rank 13 | """ 14 | def __init__(self, opt): 15 | self.opt = opt 16 | self.rank = opt['global_rank'] 17 | self.phase = opt['phase'] 18 | 19 | self.setup_logger(None, opt['path']['experiments_root'], opt['phase'], level=logging.INFO, screen=False) 20 | self.logger = logging.getLogger(opt['phase']) 21 | self.infologger_ftns = {'info', 'warning', 'debug'} 22 | 23 | def __getattr__(self, name): 24 | if self.rank != 0: # info only print on GPU 0. 25 | def wrapper(info, *args, **kwargs): 26 | pass 27 | return wrapper 28 | if name in self.infologger_ftns: 29 | print_info = getattr(self.logger, name, None) 30 | def wrapper(info, *args, **kwargs): 31 | print_info(info, *args, **kwargs) 32 | return wrapper 33 | 34 | @staticmethod 35 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False): 36 | """ set up logger """ 37 | l = logging.getLogger(logger_name) 38 | formatter = logging.Formatter( 39 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') 40 | log_file = os.path.join(root, '{}.log'.format(phase)) 41 | fh = logging.FileHandler(log_file, mode='a+') 42 | fh.setFormatter(formatter) 43 | l.setLevel(level) 44 | l.addHandler(fh) 45 | if screen: 46 | sh = logging.StreamHandler() 47 | sh.setFormatter(formatter) 48 | l.addHandler(sh) 49 | 50 | class VisualWriter(): 51 | """ 52 | use tensorboard to record visuals, support 'add_scalar', 'add_scalars', 'add_image', 'add_images', etc. funtion. 53 | Also integrated with save results function. 
54 | """ 55 | def __init__(self, opt, logger): 56 | log_dir = opt['path']['tb_logger'] 57 | self.result_dir = opt['path']['results'] 58 | enabled = opt['train']['tensorboard'] 59 | self.rank = opt['global_rank'] 60 | 61 | self.writer = None 62 | self.selected_module = "" 63 | 64 | if enabled and self.rank==0: 65 | log_dir = str(log_dir) 66 | 67 | # Retrieve vizualization writer. 68 | succeeded = False 69 | for module in ["tensorboardX", "torch.utils.tensorboard"]: 70 | try: 71 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 72 | succeeded = True 73 | break 74 | except ImportError: 75 | succeeded = False 76 | self.selected_module = module 77 | 78 | if not succeeded: 79 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 80 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ 81 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." 82 | logger.warning(message) 83 | 84 | self.epoch = 0 85 | self.iter = 0 86 | self.phase = '' 87 | 88 | self.tb_writer_ftns = { 89 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 90 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 91 | } 92 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 93 | self.custom_ftns = {'close'} 94 | self.timer = datetime.now() 95 | 96 | def set_iter(self, epoch, iter, phase='train'): 97 | self.phase = phase 98 | self.epoch = epoch 99 | self.iter = iter 100 | 101 | def save_images(self, results): 102 | result_path = os.path.join(self.result_dir, self.phase) 103 | os.makedirs(result_path, exist_ok=True) 104 | result_path = os.path.join(result_path, str(self.epoch)) 105 | os.makedirs(result_path, exist_ok=True) 106 | 107 | ''' get names and corresponding images from results[OrderedDict] ''' 108 | try: 109 | names = results['name'] 110 | outputs = Util.postprocess(results['result']) 111 | for i in range(len(names)): 112 | Image.fromarray(outputs[i]).save(os.path.join(result_path, names[i])) 113 | except: 114 | raise NotImplementedError('You must specify the context of name and result in save_current_results functions of model.') 115 | 116 | def close(self): 117 | self.writer.close() 118 | print('Close the Tensorboard SummaryWriter.') 119 | 120 | 121 | def __getattr__(self, name): 122 | """ 123 | If visualization is configured to use: 124 | return add_data() methods of tensorboard with additional information (step, tag) added. 125 | Otherwise: 126 | return a blank function handle that does nothing 127 | """ 128 | if name in self.tb_writer_ftns: 129 | add_data = getattr(self.writer, name, None) 130 | def wrapper(tag, data, *args, **kwargs): 131 | if add_data is not None: 132 | # add phase(train/valid) tag 133 | if name not in self.tag_mode_exceptions: 134 | tag = '{}/{}'.format(self.phase, tag) 135 | add_data(tag, data, self.iter, *args, **kwargs) 136 | return wrapper 137 | elif name in self.custom_ftns: 138 | customfunc = getattr(self.writer, name, None) 139 | def wrapper(*args, **kwargs): 140 | if customfunc is not None: 141 | customfunc(*args, **kwargs) 142 | return wrapper 143 | else: 144 | # default action for returning methods defined in this class, set_step() for instance. 
145 | try: 146 | attr = object.__getattr__(name) 147 | except AttributeError: 148 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 149 | return attr 150 | 151 | 152 | class LogTracker: 153 | """ 154 | record training numerical indicators. 155 | """ 156 | def __init__(self, *keys, phase='train'): 157 | self.phase = phase 158 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 159 | self.reset() 160 | 161 | def reset(self): 162 | for col in self._data.columns: 163 | self._data[col].values[:] = 0 164 | 165 | def update(self, key, value, n=1): 166 | self._data.total[key] += value * n 167 | self._data.counts[key] += n 168 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 169 | 170 | def avg(self, key): 171 | return self._data.average[key] 172 | 173 | def result(self): 174 | return {'{}/{}'.format(self.phase, k):v for k, v in dict(self._data.average).items()} 175 | -------------------------------------------------------------------------------- /core/praser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from collections import OrderedDict 4 | import json 5 | from pathlib import Path 6 | from datetime import datetime 7 | from functools import partial 8 | import importlib 9 | from types import FunctionType 10 | import shutil 11 | 12 | def init_obj(opt, logger, *args, default_file_name='default file', given_module=None, init_type='Network', **modify_kwargs): 13 | """ 14 | finds a function handle with the name given as 'name' in config, 15 | and returns the instance initialized with corresponding args. 16 | """ 17 | if opt is None or len(opt)<1: 18 | logger.info('Option is None when initialize {}'.format(init_type)) 19 | return None 20 | 21 | ''' default format is dict with name key ''' 22 | if isinstance(opt, str): 23 | opt = {'name': opt} 24 | logger.info('Config is a str, converts to a dict {}'.format(opt)) 25 | 26 | name = opt['name'] 27 | ''' name can be list, indicates the file and class name of function ''' 28 | if isinstance(name, list): 29 | file_name, class_name = name[0], name[1] 30 | else: 31 | file_name, class_name = default_file_name, name 32 | try: 33 | if given_module is not None: 34 | module = given_module 35 | else: 36 | module = importlib.import_module(file_name) 37 | 38 | attr = getattr(module, class_name) 39 | kwargs = opt.get('args', {}) 40 | kwargs.update(modify_kwargs) 41 | ''' import class or function with args ''' 42 | if isinstance(attr, type): 43 | ret = attr(*args, **kwargs) 44 | ret.__name__ = ret.__class__.__name__ 45 | elif isinstance(attr, FunctionType): 46 | ret = partial(attr, *args, **kwargs) 47 | ret.__name__ = attr.__name__ 48 | # ret = attr 49 | logger.info('{} [{:s}() form {:s}] is created.'.format(init_type, class_name, file_name)) 50 | except: 51 | raise NotImplementedError('{} [{:s}() form {:s}] not recognized.'.format(init_type, class_name, file_name)) 52 | return ret 53 | 54 | 55 | def mkdirs(paths): 56 | if isinstance(paths, str): 57 | os.makedirs(paths, exist_ok=True) 58 | else: 59 | for path in paths: 60 | os.makedirs(path, exist_ok=True) 61 | 62 | def get_timestamp(): 63 | return datetime.now().strftime('%y%m%d_%H%M%S') 64 | 65 | 66 | def write_json(content, fname): 67 | fname = Path(fname) 68 | with fname.open('wt') as handle: 69 | json.dump(content, handle, indent=4, sort_keys=False) 70 | 71 | class NoneDict(dict): 72 | def __missing__(self, key): 73 | return None 74 
| 75 | def dict_to_nonedict(opt): 76 | """ convert to NoneDict, which return None for missing key. """ 77 | if isinstance(opt, dict): 78 | new_opt = dict() 79 | for key, sub_opt in opt.items(): 80 | new_opt[key] = dict_to_nonedict(sub_opt) 81 | return NoneDict(**new_opt) 82 | elif isinstance(opt, list): 83 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 84 | else: 85 | return opt 86 | 87 | def dict2str(opt, indent_l=1): 88 | """ dict to string for logger """ 89 | msg = '' 90 | for k, v in opt.items(): 91 | if isinstance(v, dict): 92 | msg += ' ' * (indent_l * 2) + k + ':[\n' 93 | msg += dict2str(v, indent_l + 1) 94 | msg += ' ' * (indent_l * 2) + ']\n' 95 | else: 96 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 97 | return msg 98 | 99 | def parse(args): 100 | json_str = '' 101 | with open(args.config, 'r') as f: 102 | for line in f: 103 | line = line.split('//')[0] + '\n' 104 | json_str += line 105 | opt = json.loads(json_str, object_pairs_hook=OrderedDict) 106 | 107 | ''' replace the config context using args ''' 108 | opt['phase'] = args.phase 109 | if args.gpu_ids is not None: 110 | opt['gpu_ids'] = [int(id) for id in args.gpu_ids.split(',')] 111 | if args.batch is not None: 112 | opt['datasets'][opt['phase']]['dataloader']['args']['batch_size'] = args.batch 113 | 114 | ''' set cuda environment ''' 115 | if len(opt['gpu_ids']) > 1: 116 | opt['distributed'] = True 117 | else: 118 | opt['distributed'] = False 119 | 120 | ''' update name ''' 121 | if args.debug: 122 | opt['name'] = 'debug_{}'.format(opt['name']) 123 | elif opt['finetune_norm']: 124 | opt['name'] = 'finetune_{}'.format(opt['name']) 125 | else: 126 | opt['name'] = '{}_{}'.format(opt['phase'], opt['name']) 127 | 128 | ''' set log directory ''' 129 | experiments_root = os.path.join(opt['path']['base_dir'], '{}_{}'.format(opt['name'], get_timestamp())) 130 | mkdirs(experiments_root) 131 | 132 | ''' save json ''' 133 | write_json(opt, '{}/config.json'.format(experiments_root)) 134 | 135 | ''' change folder relative hierarchy ''' 136 | opt['path']['experiments_root'] = experiments_root 137 | for key, path in opt['path'].items(): 138 | if 'resume' not in key and 'base_dir' not in key and 'experiments_root' not in key: 139 | opt['path'][key] = os.path.join(experiments_root, path) 140 | mkdirs(opt['path'][key]) 141 | 142 | ''' debug mode ''' 143 | if 'debug' in opt['name']: 144 | opt['train'].update(opt['debug']) 145 | 146 | ''' code backup ''' 147 | # if os.path.exists(opt['path']['code']): 148 | # shutil.rmtree(opt['path']['code']) 149 | for name in os.listdir('.'): 150 | if name in ['config', 'models', 'core', 'slurm', 'data']: 151 | shutil.copytree(name, os.path.join(opt['path']['code'], name), ignore=shutil.ignore_patterns("*.pyc", "__pycache__")) 152 | if '.py' in name or '.sh' in name: 153 | shutil.copy(name, opt['path']['code']) 154 | return dict_to_nonedict(opt) 155 | 156 | 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /core/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | from torchvision.utils import make_grid 6 | import math 7 | 8 | def tensor2img(tensor, in_type='pt', out_type=np.uint8): 9 | ''' 10 | Converts a torch Tensor into an image Numpy array 11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 12 | Output: 3D(H,W,C) or 2D(H,W), [0,255], 
np.uint8 (default) 13 | in_type: zc(zero center)[-1, 1], pt(pytorch)[0,1] 14 | ''' 15 | if in_type == 'pt': 16 | tensor = tensor.squeeze().clamp_(*[0, 1]) # clamp 17 | elif in_type == 'zc': 18 | tensor = tensor.squeeze().clamp_(*[-1, 1]) # clamp 19 | n_dim = tensor.dim() 20 | if n_dim == 4: 21 | n_img = len(tensor) 22 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 23 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 24 | elif n_dim == 3: 25 | img_np = tensor.numpy() 26 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 27 | elif n_dim == 2: 28 | img_np = tensor.numpy() 29 | else: 30 | raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 31 | if in_type == 'pt': 32 | img_np = img_np * 255 33 | elif in_type == 'zc': 34 | img_np = (img_np+1) * 127.5 35 | return img_np.round().astype(out_type) 36 | 37 | def postprocess(images, in_type='pt', out_type=np.uint8): 38 | return [tensor2img(image, in_type=in_type, out_type=out_type) for image in images] 39 | 40 | 41 | def set_seed(seed, gl_seed=0): 42 | """ set random seed, gl_seed used in worker_init_fn function """ 43 | if seed >=0 and gl_seed>=0: 44 | seed += gl_seed 45 | torch.manual_seed(seed) 46 | torch.cuda.manual_seed_all(seed) 47 | np.random.seed(seed) 48 | random.seed(seed) 49 | 50 | ''' change the deterministic and benchmark maybe cause uncertain convolution behavior. 51 | speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html ''' 52 | if seed >=0 and gl_seed>=0: # slower, more reproducible 53 | torch.backends.cudnn.deterministic = True 54 | torch.backends.cudnn.benchmark = False 55 | else: # faster, less reproducible 56 | torch.backends.cudnn.deterministic = False 57 | torch.backends.cudnn.benchmark = True 58 | 59 | def set_gpu(args, distributed=False, rank=0): 60 | """ set parameter to gpu or ddp """ 61 | if distributed and isinstance(args, torch.nn.Module): 62 | return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True, find_unused_parameters=True) 63 | else: 64 | return args.cuda() 65 | 66 | def set_device(args, distributed=False, rank=0): 67 | """ set parameter to gpu or cpu """ 68 | if torch.cuda.is_available(): 69 | if isinstance(args, list): 70 | return (set_gpu(item, distributed, rank) for item in args) 71 | elif isinstance(args, dict): 72 | return {key:set_gpu(args[key], distributed, rank) for key in args} 73 | else: 74 | args = set_gpu(args, distributed, rank) 75 | return args 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | from torch.utils.data.distributed import DistributedSampler 5 | from torch import Generator, randperm 6 | from torch.utils.data import DataLoader, Subset 7 | 8 | import core.util as Util 9 | from core.praser import init_obj 10 | 11 | 12 | def define_dataloader(logger, opt): 13 | """ create train/test dataloader and validation dataloader, validation dataloader is None when phase is test or not GPU 0 """ 14 | '''create dataset and set random seed''' 15 | dataloader_args = opt['datasets'][opt['phase']]['dataloader']['args'] 16 | worker_init_fn = partial(Util.set_seed, gl_seed=opt['seed']) 17 | 18 | phase_dataset, val_dataset = define_dataset(logger, opt) 19 | 20 | '''create datasampler''' 21 | data_sampler = None 22 | if opt['distributed']: 23 | 
data_sampler = DistributedSampler(phase_dataset, shuffle=dataloader_args.get('shuffle', False), num_replicas=opt['world_size'], rank=opt['global_rank']) 24 | dataloader_args.update({'shuffle':False}) # sampler option is mutually exclusive with shuffle 25 | 26 | ''' create dataloader and validation dataloader ''' 27 | dataloader = DataLoader(phase_dataset, sampler=data_sampler, worker_init_fn=worker_init_fn, **dataloader_args) 28 | ''' val_dataloader don't use DistributedSampler to run only GPU 0! ''' 29 | if opt['global_rank']==0 and val_dataset is not None: 30 | dataloader_args.update(opt['datasets'][opt['phase']]['dataloader'].get('val_args',{})) 31 | val_dataloader = DataLoader(val_dataset, worker_init_fn=worker_init_fn, **dataloader_args) 32 | else: 33 | val_dataloader = None 34 | return dataloader, val_dataloader 35 | 36 | 37 | def define_dataset(logger, opt): 38 | ''' loading Dataset() class from given file's name ''' 39 | dataset_opt = opt['datasets'][opt['phase']]['which_dataset'] 40 | phase_dataset = init_obj(dataset_opt, logger, default_file_name='data.dataset', init_type='Dataset') 41 | val_dataset = None 42 | 43 | valid_len = 0 44 | data_len = len(phase_dataset) 45 | if 'debug' in opt['name']: 46 | debug_split = opt['debug'].get('debug_split', 1.0) 47 | if isinstance(debug_split, int): 48 | data_len = debug_split 49 | else: 50 | data_len *= debug_split 51 | 52 | dataloder_opt = opt['datasets'][opt['phase']]['dataloader'] 53 | valid_split = dataloder_opt.get('validation_split', 0) 54 | 55 | ''' divide validation dataset, valid_split==0 when phase is test or validation_split is 0. ''' 56 | if valid_split > 0.0 or 'debug' in opt['name']: 57 | if isinstance(valid_split, int): 58 | assert valid_split < data_len, "Validation set size is configured to be larger than entire dataset." 59 | valid_len = valid_split 60 | else: 61 | valid_len = int(data_len * valid_split) 62 | data_len -= valid_len 63 | phase_dataset, val_dataset = subset_split(dataset=phase_dataset, lengths=[data_len, valid_len], generator=Generator().manual_seed(opt['seed'])) 64 | 65 | logger.info('Dataset for {} have {} samples.'.format(opt['phase'], data_len)) 66 | if opt['phase'] == 'train': 67 | logger.info('Dataset for {} have {} samples.'.format('val', valid_len)) 68 | return phase_dataset, val_dataset 69 | 70 | def subset_split(dataset, lengths, generator): 71 | """ 72 | split a dataset into non-overlapping new datasets of given lengths. 
main code is from random_split function in pytorch 73 | """ 74 | indices = randperm(sum(lengths), generator=generator).tolist() 75 | Subsets = [] 76 | for offset, length in zip(np.add.accumulate(lengths), lengths): 77 | if length == 0: 78 | Subsets.append(None) 79 | else: 80 | Subsets.append(Subset(dataset, indices[offset - length : offset])) 81 | return Subsets 82 | -------------------------------------------------------------------------------- /data/auto_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from scipy import ndimage 4 | from PIL import Image, ImageEnhance, ImageOps 5 | 6 | 7 | class AutoAugment(object): 8 | def __init__(self): 9 | self.policies = [ 10 | ['Invert', 0.1, 7, 'Contrast', 0.2, 6], 11 | ['Rotate', 0.7, 2, 'TranslateX', 0.3, 9], 12 | ['Sharpness', 0.8, 1, 'Sharpness', 0.9, 3], 13 | ['ShearY', 0.5, 8, 'TranslateY', 0.7, 9], 14 | ['AutoContrast', 0.5, 8, 'Equalize', 0.9, 2], 15 | ['ShearY', 0.2, 7, 'Posterize', 0.3, 7], 16 | ['Color', 0.4, 3, 'Brightness', 0.6, 7], 17 | ['Sharpness', 0.3, 9, 'Brightness', 0.7, 9], 18 | ['Equalize', 0.6, 5, 'Equalize', 0.5, 1], 19 | ['Contrast', 0.6, 7, 'Sharpness', 0.6, 5], 20 | ['Color', 0.7, 7, 'TranslateX', 0.5, 8], 21 | ['Equalize', 0.3, 7, 'AutoContrast', 0.4, 8], 22 | ['TranslateY', 0.4, 3, 'Sharpness', 0.2, 6], 23 | ['Brightness', 0.9, 6, 'Color', 0.2, 8], 24 | ['Solarize', 0.5, 2, 'Invert', 0, 0.3], 25 | ['Equalize', 0.2, 0, 'AutoContrast', 0.6, 0], 26 | ['Equalize', 0.2, 8, 'Equalize', 0.6, 4], 27 | ['Color', 0.9, 9, 'Equalize', 0.6, 6], 28 | ['AutoContrast', 0.8, 4, 'Solarize', 0.2, 8], 29 | ['Brightness', 0.1, 3, 'Color', 0.7, 0], 30 | ['Solarize', 0.4, 5, 'AutoContrast', 0.9, 3], 31 | ['TranslateY', 0.9, 9, 'TranslateY', 0.7, 9], 32 | ['AutoContrast', 0.9, 2, 'Solarize', 0.8, 3], 33 | ['Equalize', 0.8, 8, 'Invert', 0.1, 3], 34 | ['TranslateY', 0.7, 9, 'AutoContrast', 0.9, 1], 35 | ] 36 | 37 | def __call__(self, img): 38 | img = apply_policy(img, self.policies[random.randrange(len(self.policies))]) 39 | return img 40 | 41 | 42 | class ImageNetAutoAugment(object): 43 | def __init__(self): 44 | self.policies = [ 45 | ['Posterize', 0.4, 8, 'Rotate', 0.6, 9], 46 | ['Solarize', 0.6, 5, 'AutoContrast', 0.6, 5], 47 | ['Equalize', 0.8, 8, 'Equalize', 0.6, 3], 48 | ['Posterize', 0.6, 7, 'Posterize', 0.6, 6], 49 | ['Equalize', 0.4, 7, 'Solarize', 0.2, 4], 50 | ['Equalize', 0.4, 4, 'Rotate', 0.8, 8], 51 | ['Solarize', 0.6, 3, 'Equalize', 0.6, 7], 52 | ['Posterize', 0.8, 5, 'Equalize', 1.0, 2], 53 | ['Rotate', 0.2, 3, 'Solarize', 0.6, 8], 54 | ['Equalize', 0.6, 8, 'Posterize', 0.4, 6], 55 | ['Rotate', 0.8, 8, 'Color', 0.4, 0], 56 | ['Rotate', 0.4, 9, 'Equalize', 0.6, 2], 57 | ['Equalize', 0.0, 0.7, 'Equalize', 0.8, 8], 58 | ['Invert', 0.6, 4, 'Equalize', 1.0, 8], 59 | ['Color', 0.6, 4, 'Contrast', 1.0, 8], 60 | ['Rotate', 0.8, 8, 'Color', 1.0, 2], 61 | ['Color', 0.8, 8, 'Solarize', 0.8, 7], 62 | ['Sharpness', 0.4, 7, 'Invert', 0.6, 8], 63 | ['ShearX', 0.6, 5, 'Equalize', 1.0, 9], 64 | ['Color', 0.4, 0, 'Equalize', 0.6, 3], 65 | ['Equalize', 0.4, 7, 'Solarize', 0.2, 4], 66 | ['Solarize', 0.6, 5, 'AutoContrast', 0.6, 5], 67 | ['Invert', 0.6, 4, 'Equalize', 1.0, 8], 68 | ['Color', 0.6, 4, 'Contrast', 1.0, 8], 69 | ['Equalize', 0.8, 8, 'Equalize', 0.6, 3] 70 | ] 71 | 72 | def __call__(self, img): 73 | img = apply_policy(img, self.policies[random.randrange(len(self.policies))]) 74 | return img 75 | 76 | 77 | operations = { 78 | 'ShearX': lambda img, 
magnitude: shear_x(img, magnitude), 79 | 'ShearY': lambda img, magnitude: shear_y(img, magnitude), 80 | 'TranslateX': lambda img, magnitude: translate_x(img, magnitude), 81 | 'TranslateY': lambda img, magnitude: translate_y(img, magnitude), 82 | 'Rotate': lambda img, magnitude: rotate(img, magnitude), 83 | 'AutoContrast': lambda img, magnitude: auto_contrast(img, magnitude), 84 | 'Invert': lambda img, magnitude: invert(img, magnitude), 85 | 'Equalize': lambda img, magnitude: equalize(img, magnitude), 86 | 'Solarize': lambda img, magnitude: solarize(img, magnitude), 87 | 'Posterize': lambda img, magnitude: posterize(img, magnitude), 88 | 'Contrast': lambda img, magnitude: contrast(img, magnitude), 89 | 'Color': lambda img, magnitude: color(img, magnitude), 90 | 'Brightness': lambda img, magnitude: brightness(img, magnitude), 91 | 'Sharpness': lambda img, magnitude: sharpness(img, magnitude), 92 | 'Cutout': lambda img, magnitude: cutout(img, magnitude), 93 | } 94 | 95 | 96 | def apply_policy(img, policy): 97 | if random.random() < policy[1]: 98 | img = operations[policy[0]](img, policy[2]) 99 | if random.random() < policy[4]: 100 | img = operations[policy[3]](img, policy[5]) 101 | 102 | return img 103 | 104 | 105 | def transform_matrix_offset_center(matrix, x, y): 106 | o_x = float(x) / 2 + 0.5 107 | o_y = float(y) / 2 + 0.5 108 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 109 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 110 | transform_matrix = offset_matrix @ matrix @ reset_matrix 111 | return transform_matrix 112 | 113 | 114 | def shear_x(img, magnitude): 115 | img = np.array(img) 116 | magnitudes = np.linspace(-0.3, 0.3, 11) 117 | 118 | transform_matrix = np.array([[1, random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]), 0], 119 | [0, 1, 0], 120 | [0, 0, 1]]) 121 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 122 | affine_matrix = transform_matrix[:2, :2] 123 | offset = transform_matrix[:2, 2] 124 | img = np.stack([ndimage.interpolation.affine_transform( 125 | img[:, :, c], 126 | affine_matrix, 127 | offset) for c in range(img.shape[2])], axis=2) 128 | img = Image.fromarray(img) 129 | return img 130 | 131 | 132 | def shear_y(img, magnitude): 133 | img = np.array(img) 134 | magnitudes = np.linspace(-0.3, 0.3, 11) 135 | 136 | transform_matrix = np.array([[1, 0, 0], 137 | [random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]), 1, 0], 138 | [0, 0, 1]]) 139 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 140 | affine_matrix = transform_matrix[:2, :2] 141 | offset = transform_matrix[:2, 2] 142 | img = np.stack([ndimage.interpolation.affine_transform( 143 | img[:, :, c], 144 | affine_matrix, 145 | offset) for c in range(img.shape[2])], axis=2) 146 | img = Image.fromarray(img) 147 | return img 148 | 149 | 150 | def translate_x(img, magnitude): 151 | img = np.array(img) 152 | magnitudes = np.linspace(-150/331, 150/331, 11) 153 | 154 | transform_matrix = np.array([[1, 0, 0], 155 | [0, 1, img.shape[1]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])], 156 | [0, 0, 1]]) 157 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 158 | affine_matrix = transform_matrix[:2, :2] 159 | offset = transform_matrix[:2, 2] 160 | img = np.stack([ndimage.interpolation.affine_transform( 161 | img[:, :, c], 162 | affine_matrix, 163 | offset) for c in range(img.shape[2])], axis=2) 164 | img = 
Image.fromarray(img) 165 | return img 166 | 167 | 168 | def translate_y(img, magnitude): 169 | img = np.array(img) 170 | magnitudes = np.linspace(-150/331, 150/331, 11) 171 | 172 | transform_matrix = np.array([[1, 0, img.shape[0]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])], 173 | [0, 1, 0], 174 | [0, 0, 1]]) 175 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 176 | affine_matrix = transform_matrix[:2, :2] 177 | offset = transform_matrix[:2, 2] 178 | img = np.stack([ndimage.interpolation.affine_transform( 179 | img[:, :, c], 180 | affine_matrix, 181 | offset) for c in range(img.shape[2])], axis=2) 182 | img = Image.fromarray(img) 183 | return img 184 | 185 | 186 | def rotate(img, magnitude): 187 | img = np.array(img) 188 | magnitudes = np.linspace(-30, 30, 11) 189 | theta = np.deg2rad(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 190 | transform_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], 191 | [np.sin(theta), np.cos(theta), 0], 192 | [0, 0, 1]]) 193 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 194 | affine_matrix = transform_matrix[:2, :2] 195 | offset = transform_matrix[:2, 2] 196 | img = np.stack([ndimage.interpolation.affine_transform( 197 | img[:, :, c], 198 | affine_matrix, 199 | offset) for c in range(img.shape[2])], axis=2) 200 | img = Image.fromarray(img) 201 | return img 202 | 203 | 204 | def auto_contrast(img, magnitude): 205 | img = ImageOps.autocontrast(img) 206 | return img 207 | 208 | 209 | def invert(img, magnitude): 210 | img = ImageOps.invert(img) 211 | return img 212 | 213 | 214 | def equalize(img, magnitude): 215 | img = ImageOps.equalize(img) 216 | return img 217 | 218 | 219 | def solarize(img, magnitude): 220 | magnitudes = np.linspace(0, 256, 11) 221 | img = ImageOps.solarize(img, random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 222 | return img 223 | 224 | 225 | def posterize(img, magnitude): 226 | magnitudes = np.linspace(4, 8, 11) 227 | img = ImageOps.posterize(img, int(round(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])))) 228 | return img 229 | 230 | 231 | def contrast(img, magnitude): 232 | magnitudes = np.linspace(0.1, 1.9, 11) 233 | img = ImageEnhance.Contrast(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 234 | return img 235 | 236 | 237 | def color(img, magnitude): 238 | magnitudes = np.linspace(0.1, 1.9, 11) 239 | img = ImageEnhance.Color(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 240 | return img 241 | 242 | 243 | def brightness(img, magnitude): 244 | magnitudes = np.linspace(0.1, 1.9, 11) 245 | img = ImageEnhance.Brightness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 246 | return img 247 | 248 | 249 | def sharpness(img, magnitude): 250 | magnitudes = np.linspace(0.1, 1.9, 11) 251 | img = ImageEnhance.Sharpness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 252 | return img 253 | 254 | 255 | def cutout(org_img, magnitude=None): 256 | 257 | magnitudes = np.linspace(0, 60/331, 11) 258 | 259 | img = np.copy(org_img) 260 | mask_val = img.mean() 261 | 262 | if magnitude is None: 263 | mask_size = 16 264 | else: 265 | mask_size = int(round(img.shape[0]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]))) 266 | top = np.random.randint(0 - mask_size//2, img.shape[0] - mask_size) 267 | left = np.random.randint(0 - mask_size//2, img.shape[1] - 
mask_size) 268 | bottom = top + mask_size 269 | right = left + mask_size 270 | 271 | if top < 0: 272 | top = 0 273 | if left < 0: 274 | left = 0 275 | 276 | img[top:bottom, left:right, :].fill(mask_val) 277 | 278 | img = Image.fromarray(img) 279 | 280 | return img 281 | 282 | 283 | class Cutout(object): 284 | 285 | def __init__(self, length=16): 286 | self.length = length 287 | 288 | def __call__(self, img): 289 | img = np.array(img) 290 | 291 | mask_val = img.mean() 292 | 293 | top = np.random.randint(0 - self.length//2, img.shape[0] - self.length) 294 | left = np.random.randint(0 - self.length//2, img.shape[1] - self.length) 295 | bottom = top + self.length 296 | right = left + self.length 297 | 298 | top = 0 if top < 0 else top 299 | left = 0 if left < 0 else left 300 | 301 | img[top:bottom, left:right, :] = mask_val 302 | 303 | img = Image.fromarray(img) 304 | 305 | return img -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torchvision import transforms 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | from .auto_augment import AutoAugment, ImageNetAutoAugment 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 20 | 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in sorted(fnames): 23 | if is_image_file(fname) and ('O' in fname or 'F' in fname): 24 | path = os.path.join(root, fname) 25 | images.append(path) 26 | 27 | return images 28 | 29 | def pil_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class Dataset(data.Dataset): 33 | def __init__(self, data_root, phase='train', image_size=[256, 256], loader=pil_loader): 34 | imgs = make_dataset(data_root) 35 | self.imgs = imgs 36 | if phase == 'train': 37 | self.tfs = transforms.Compose([ 38 | transforms.Resize((image_size[0], image_size[1])), 39 | ImageNetAutoAugment(), 40 | transforms.ToTensor() 41 | ]) 42 | else: 43 | self.tfs = transforms.Compose([ 44 | transforms.Resize((image_size[0], image_size[1])), 45 | transforms.ToTensor() 46 | ]) 47 | 48 | self.loader = loader 49 | 50 | def __getitem__(self, index): 51 | ret = {} 52 | path = self.imgs[index] 53 | img = self.loader(path) 54 | img = self.tfs(img) 55 | ret['input'] = img 56 | ret['path'] = path.rsplit("/")[-1] 57 | return ret 58 | 59 | def __len__(self): 60 | return len(self.imgs) 61 | -------------------------------------------------------------------------------- /experiments/clean.sh: -------------------------------------------------------------------------------- 1 | rm -rf debug* -------------------------------------------------------------------------------- /misc/template.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/distributed-pytorch-template/beeb4895e6fda5e6de474c87d636f716b439b7a2/misc/template.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from core.praser import init_obj 3 | 4 | def
create_model(**cfg_model): 5 | """ create_model """ 6 | opt = cfg_model['opt'] 7 | logger = cfg_model['logger'] 8 | 9 | model_opt = opt['model']['which_model'] 10 | model_opt['args'].update(cfg_model) 11 | model = init_obj(model_opt, logger, default_file_name='models.model', init_type='Model') 12 | 13 | return model 14 | 15 | def define_network(logger, opt, network_opt): 16 | """ define network with weights initialization """ 17 | net = init_obj(network_opt, logger, default_file_name='models.network', init_type='Network') 18 | 19 | if opt['phase'] == 'train': 20 | logger.info('Network [{}] weights initialized using the [{:s}] method.'.format(net.__class__.__name__, network_opt['args'].get('init_type', 'default'))) 21 | net.init_weights() 22 | return net 23 | 24 | 25 | def define_loss(logger, loss_opt): 26 | return init_obj(loss_opt, logger, default_file_name='models.loss', init_type='Loss') 27 | 28 | def define_metric(logger, metric_opt): 29 | return init_obj(metric_opt, logger, default_file_name='models.metric', init_type='Metric') -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | # class mse_loss(nn.Module): 7 | # def __init__(self) -> None: 8 | # super().__init__() 9 | # self.loss_fn = nn.MSELoss() 10 | # def forward(self, output, target): 11 | # return self.loss_fn(output, target) 12 | 13 | 14 | def mse_loss(output, target): 15 | return F.mse_loss(output, target) 16 | 17 | 18 | class FocalLoss(nn.Module): 19 | def __init__(self, gamma=2, alpha=None, size_average=True): 20 | super(FocalLoss, self).__init__() 21 | self.gamma = gamma 22 | self.alpha = alpha 23 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 24 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 25 | self.size_average = size_average 26 | 27 | def forward(self, input, target): 28 | if input.dim()>2: 29 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 30 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 31 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 32 | target = target.view(-1,1) 33 | 34 | logpt = F.log_softmax(input, dim=1) 35 | logpt = logpt.gather(1,target) 36 | logpt = logpt.view(-1) 37 | pt = Variable(logpt.data.exp()) 38 | 39 | if self.alpha is not None: 40 | if self.alpha.type()!=input.data.type(): 41 | self.alpha = self.alpha.type_as(input.data) 42 | at = self.alpha.gather(0,target.data.view(-1)) 43 | logpt = logpt * Variable(at) 44 | 45 | loss = -1 * (1-pt)**self.gamma * logpt 46 | if self.size_average: return loss.mean() 47 | else: return loss.sum() 48 | 49 | -------------------------------------------------------------------------------- /models/metric.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | def mae(input, target): 4 | with torch.no_grad(): 5 | loss = nn.L1Loss() 6 | output = loss(input, target) 7 | return output -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from core.base_model import BaseModel 4 | from core.logger import LogTracker 5 |
import copy 6 | class EMA(): 7 | def __init__(self, beta=0.9999): 8 | super().__init__() 9 | self.beta = beta 10 | def update_model_average(self, ma_model, current_model): 11 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 12 | old_weight, up_weight = ma_params.data, current_params.data 13 | ma_params.data = self.update_average(old_weight, up_weight) 14 | def update_average(self, old, new): 15 | if old is None: 16 | return new 17 | return old * self.beta + (1 - self.beta) * new 18 | 19 | class Model(BaseModel): 20 | def __init__(self, networks, losses, optimizers=None, ema_scheduler=None, **kwargs): 21 | ''' BaseModel must be initialized with kwargs ''' 22 | super(Model, self).__init__(**kwargs) 23 | 24 | ''' networks, dataloader, optimizers, losses, etc. ''' 25 | self.netG = networks[0] 26 | if ema_scheduler is not None: 27 | self.ema_scheduler = ema_scheduler 28 | self.netG_EMA = copy.deepcopy(self.netG) 29 | self.EMA = EMA(beta=self.ema_scheduler['ema_decay']) 30 | else: 31 | self.ema_scheduler = None 32 | 33 | ''' networks can be a list, and must be converted by the self.set_device function when using multiple GPUs. ''' 34 | self.netG = self.set_device(self.netG, distributed=self.opt['distributed']) 35 | if self.ema_scheduler is not None: 36 | self.netG_EMA = self.set_device(self.netG_EMA, distributed=self.opt['distributed']) 37 | self.load_networks() 38 | 39 | self.optG = torch.optim.Adam(list(filter(lambda p: p.requires_grad, self.netG.parameters())), **optimizers[0]) 40 | self.optimizers.append(self.optG) 41 | self.resume_training() 42 | 43 | self.loss_fn = losses[0] 44 | 45 | ''' can be overridden in an inherited class to log more information ''' 46 | self.train_metrics = LogTracker(*[m.__name__ for m in losses], phase='train') 47 | self.val_metrics = LogTracker(*[m.__name__ for m in losses], *[m.__name__ for m in self.metrics], phase='val') 48 | self.test_metrics = LogTracker(*[m.__name__ for m in losses], *[m.__name__ for m in self.metrics], phase='test') 49 | 50 | def set_input(self, data): 51 | ''' tensors must be moved to the right device with set_device ''' 52 | self.input = self.set_device(data['input']) 53 | self.path = data['path'] 54 | 55 | def get_current_visuals(self): 56 | visuals = { 57 | 'input': self.input.detach()[0].float().cpu() 58 | ,'output': self.output.detach()[0].float().cpu() 59 | } 60 | return visuals 61 | 62 | def save_current_results(self): 63 | self.results_dict = self.results_dict._replace(name=self.path, result=self.output.detach().float().cpu()) 64 | return self.results_dict._asdict() 65 | 66 | def train_step(self): 67 | self.netG.train() 68 | self.train_metrics.reset() 69 | for train_data in tqdm.tqdm(self.phase_loader): 70 | self.set_input(train_data) 71 | self.optG.zero_grad() 72 | self.output = self.netG(self.input) 73 | loss = self.loss_fn(self.output, self.input) 74 | loss.backward() 75 | self.optG.step() 76 | 77 | self.iter += self.batch_size 78 | self.writer.set_iter(self.epoch, self.iter, phase='train') 79 | self.train_metrics.update(self.loss_fn.__name__, loss.item()) 80 | self.writer.add_scalar(self.loss_fn.__name__, loss.item()) 81 | if self.iter % self.opt['train']['log_iter'] == 0: 82 | for key, value in self.train_metrics.result().items(): 83 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 84 | for key, value in self.get_current_visuals().items(): 85 | self.writer.add_image(key, value) 86 | if self.ema_scheduler is not None: 87 | if self.iter % self.ema_scheduler['ema_iter'] == 0 and self.iter > self.ema_scheduler['ema_start']: 88 |
self.logger.info('Update the EMA model at iteration {:.0f}'.format(self.iter)) 89 | self.EMA.update_model_average(self.netG_EMA, self.netG) 90 | 91 | for scheduler in self.schedulers: 92 | scheduler.step() 93 | return self.train_metrics.result() 94 | 95 | def val_step(self): 96 | self.netG.eval() 97 | self.val_metrics.reset() 98 | with torch.no_grad(): 99 | for val_data in tqdm.tqdm(self.val_loader): 100 | self.set_input(val_data) 101 | self.output = self.netG(self.input) 102 | loss = self.loss_fn(self.output, self.input) 103 | 104 | self.iter += self.batch_size 105 | self.writer.set_iter(self.epoch, self.iter, phase='val') 106 | self.val_metrics.update(self.loss_fn.__name__, loss.item()) 107 | self.writer.add_scalar(self.loss_fn.__name__, loss.item()) 108 | for met in self.metrics: 109 | key, value = met.__name__, met(self.input, self.output) 110 | self.writer.add_scalar(key, value) 111 | self.val_metrics.update(key, value) 112 | for key, value in self.get_current_visuals().items(): 113 | self.writer.add_image(key, value) 114 | self.writer.save_images(self.save_current_results()) 115 | 116 | return self.val_metrics.result() 117 | 118 | def load_networks(self): 119 | """ load pretrained model and training state. """ 120 | if self.opt['distributed']: 121 | netG_label = self.netG.module.__class__.__name__ 122 | else: 123 | netG_label = self.netG.__class__.__name__ 124 | self.load_network(network=self.netG, network_label=netG_label, strict=False) 125 | if self.ema_scheduler is not None: 126 | self.load_network(network=self.netG_EMA, network_label=netG_label+'_ema', strict=False) 127 | 128 | def save_everything(self): 129 | """ save pretrained model and training state, which is only done on GPU 0; optimizers and schedulers must be lists. """ 130 | if self.opt['distributed']: 131 | netG_label = self.netG.module.__class__.__name__ 132 | else: 133 | netG_label = self.netG.__class__.__name__ 134 | self.save_network(network=self.netG, network_label=netG_label) 135 | if self.ema_scheduler is not None: 136 | self.save_network(network=self.netG_EMA, network_label=netG_label+'_ema') 137 | self.save_training_state([self.optG], self.schedulers) 138 | -------------------------------------------------------------------------------- /models/network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from core.base_network import BaseNetwork 3 | class Network(BaseNetwork): 4 | def __init__(self, in_channels=3, **kwargs): 5 | super(Network, self).__init__(**kwargs) 6 | 7 | self.in_channels = in_channels 8 | cnums = 64 9 | self.down_net = nn.Sequential( 10 | Down2(in_channels, cnums), 11 | Down2(cnums, cnums*2), 12 | Down3(cnums*2, cnums*4), 13 | Down2(cnums*4, cnums*8), 14 | ) 15 | self.up_net = nn.Sequential( 16 | Up2(cnums*8, cnums*4), 17 | Up2(cnums*4, cnums*2), 18 | Up3(cnums*2, cnums*1), 19 | # Up2(cnums*1, cnums*1), 20 | nn.Upsample(scale_factor=2, mode='nearest', align_corners=None), 21 | nn.Conv2d(cnums*1, in_channels, kernel_size=3, stride=1, padding=1), 22 | nn.Tanh() 23 | ) 24 | 25 | def forward(self, _x): 26 | _x = self.down_net(_x) 27 | _x = self.up_net(_x) 28 | return _x 29 | 30 | 31 | class conv2DBatchNormRelu(nn.Module): 32 | def __init__( 33 | self, 34 | in_channels, 35 | n_filters, 36 | k_size, 37 | stride, 38 | padding, 39 | bias=True, 40 | dilation=1, 41 | with_bn=True, 42 | ): 43 | super(conv2DBatchNormRelu, self).__init__() 44 | 45 | conv_mod = nn.Conv2d(int(in_channels), 46 | int(n_filters), 47 | kernel_size=k_size, 48 |
padding=padding, 49 | stride=stride, 50 | bias=bias, 51 | dilation=dilation, ) 52 | 53 | if with_bn: 54 | self.cbr_unit = nn.Sequential(conv_mod, 55 | nn.BatchNorm2d(int(n_filters)), 56 | nn.ReLU(inplace=True)) 57 | else: 58 | self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True)) 59 | 60 | def forward(self, inputs): 61 | outputs = self.cbr_unit(inputs) 62 | return outputs 63 | 64 | 65 | class Down2(nn.Module): 66 | def __init__(self, in_size, out_size): 67 | super(Down2, self).__init__() 68 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 2, 1) 69 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 70 | 71 | def forward(self, inputs): 72 | outputs = self.conv1(inputs) 73 | outputs = self.conv2(outputs) 74 | return outputs 75 | 76 | 77 | class Down3(nn.Module): 78 | def __init__(self, in_size, out_size): 79 | super(Down3, self).__init__() 80 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 2, 1) 81 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 82 | self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 83 | 84 | def forward(self, inputs): 85 | outputs = self.conv1(inputs) 86 | outputs = self.conv2(outputs) 87 | outputs = self.conv3(outputs) 88 | return outputs 89 | 90 | 91 | class Up2(nn.Module): 92 | def __init__(self, in_size, out_size): 93 | super(Up2, self).__init__() 94 | self.up = nn.Upsample(scale_factor=2, mode='nearest', align_corners=None) 95 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 96 | self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 97 | 98 | def forward(self, inputs): 99 | outputs = self.up(inputs) 100 | outputs = self.conv1(outputs) 101 | outputs = self.conv2(outputs) 102 | return outputs 103 | 104 | 105 | class Up3(nn.Module): 106 | def __init__(self, in_size, out_size): 107 | super(Up3, self).__init__() 108 | self.up = nn.Upsample(scale_factor=2, mode='nearest', align_corners=None) 109 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 110 | self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 111 | self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 112 | 113 | def forward(self, inputs): 114 | outputs = self.up(inputs) 115 | outputs = self.conv1(outputs) 116 | outputs = self.conv2(outputs) 117 | outputs = self.conv3(outputs) 118 | return outputs 119 | 120 | -------------------------------------------------------------------------------- /new_project.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from shutil import copytree, ignore_patterns 4 | 5 | # Code from https://github.com/victoresque/pytorch-template/blob/master/new_project.py 6 | # This script initializes new pytorch project with the template files. 7 | # Run `python new_project.py ../MyNewProject` then new project named 8 | # MyNewProject will be made 9 | current_dir = Path() 10 | assert (current_dir / 'new_project.py').is_file(), 'Script should be executed in the pytorch-template directory' 11 | assert len(sys.argv) == 2, 'Specify a name for the new project. 
Example: python3 new_project.py MyNewProject' 12 | 13 | project_name = Path(sys.argv[1]) 14 | target_dir = current_dir / project_name 15 | 16 | ignore = [".git", "experiments", "new_project.py", "__pycache__"] 17 | copytree(current_dir, target_dir, ignore=ignore_patterns(*ignore)) 18 | print('New project initialized at', target_dir.absolute().resolve()) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6 2 | torchvision 3 | numpy 4 | pandas 5 | tqdm 6 | tensorboardX>=1.14 7 | 8 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | import torch 5 | import torch.multiprocessing as mp 6 | 7 | from core.logger import VisualWriter, InfoLogger 8 | import core.praser as Praser 9 | import core.util as Util 10 | from data import define_dataloader 11 | from models import create_model, define_network, define_loss, define_metric 12 | 13 | def main_worker(gpu, ngpus_per_node, opt): 14 | """ worker process running on each GPU """ 15 | if 'local_rank' not in opt: 16 | opt['local_rank'] = opt['global_rank'] = gpu 17 | if opt['distributed']: 18 | torch.cuda.set_device(int(opt['local_rank'])) 19 | print('using GPU {} for training'.format(int(opt['local_rank']))) 20 | torch.distributed.init_process_group(backend = 'nccl', 21 | init_method = opt['init_method'], 22 | world_size = opt['world_size'], 23 | rank = opt['global_rank'], 24 | group_name='mtorch' 25 | ) 26 | ''' set seed and cuDNN environment ''' 27 | torch.backends.cudnn.enabled = True 28 | warnings.warn('You have chosen to use cudnn for acceleration. torch.backends.cudnn.enabled=True') 29 | Util.set_seed(opt['seed']) 30 | 31 | ''' set logger ''' 32 | phase_logger = InfoLogger(opt) 33 | phase_writer = VisualWriter(opt, phase_logger) 34 | 35 | phase_logger.info('Create the log file in directory {}.\n'.format(opt['path']['experiments_root'])) 36 | 37 | ''' set networks and dataset ''' 38 | phase_loader, val_loader = define_dataloader(phase_logger, opt) # val_loader is None if phase is test. 
39 | networks = [define_network(phase_logger, opt, item_opt) for item_opt in opt['model']['which_networks']] 40 | 41 | ''' set metrics, loss, optimizer and schedulers ''' 42 | metrics = [define_metric(phase_logger, item_opt) for item_opt in opt['model']['which_metrics']] 43 | losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']] 44 | 45 | model = create_model( 46 | opt = opt, 47 | networks = networks, 48 | phase_loader = phase_loader, 49 | val_loader = val_loader, 50 | losses = losses, 51 | metrics = metrics, 52 | logger = phase_logger, 53 | writer = phase_writer 54 | ) 55 | 56 | phase_logger.info('Begin model {}.'.format(opt['phase'])) 57 | try: 58 | if opt['phase'] == 'train': 59 | model.train() 60 | else: 61 | model.test() 62 | finally: 63 | phase_writer.close() 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('-c', '--config', type=str, default='config/base.json', help='JSON file for configuration') 69 | parser.add_argument('-p', '--phase', type=str, choices=['train','test'], help='Run train or test', default='train') 70 | parser.add_argument('-b', '--batch', type=int, default=None, help='Batch size in every gpu') 71 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) 72 | parser.add_argument('-d', '--debug', action='store_true') 73 | parser.add_argument('-P', '--port', default='21012', type=str) 74 | 75 | ''' parser configs ''' 76 | args = parser.parse_args() 77 | opt = Praser.parse(args) 78 | 79 | ''' cuda devices ''' 80 | gpu_str = ','.join(str(x) for x in opt['gpu_ids']) 81 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str 82 | print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str)) 83 | 84 | ''' use DistributedDataParallel(DDP) and multiprocessing for multi-gpu training''' 85 | # [Todo]: multi GPU on multi machine 86 | if opt['distributed']: 87 | ngpus_per_node = len(opt['gpu_ids']) # or torch.cuda.device_count() 88 | opt['world_size'] = ngpus_per_node 89 | opt['init_method'] = 'tcp://127.0.0.1:'+ args.port 90 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt)) 91 | else: 92 | opt['world_size'] = 1 93 | main_worker(0, 1, opt) -------------------------------------------------------------------------------- /slurm/run.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o experiments/slurm.log 3 | #SBATCH -J base 4 | #SBATCH -p dell 5 | #SBATCH --gres=gpu:4 6 | #SBATCH -c 16 7 | python run.py -gpu 0,1,2,3 8 | --------------------------------------------------------------------------------
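Example invocations, sketched from run.py's argument parser, the default config/base.json, and slurm/run.slurm; the flag combinations below are illustrative rather than additional repository files:

python run.py -c config/base.json -p train -gpu 0 -d              # single-GPU debug run (-d enables debug mode)
python run.py -c config/base.json -p train -gpu 0,1,2,3 -P 21012  # 4-GPU DistributedDataParallel training, matching slurm/run.slurm
python run.py -c config/base.json -p test -gpu 0                  # evaluation with the same config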