├── .gitignore ├── LICENSE ├── README.md ├── configs ├── 20220223_cifar100.yml ├── 20220309_cifar10.yml ├── 20220531_cifar100_ddp.yaml ├── __init__.py └── config_loader.py ├── figs └── Torch-atom.png ├── main.py ├── main_ddp.py ├── run_ddp.sh └── src ├── __init__.py ├── controller.py ├── datasets ├── __init__.py ├── cifar.py ├── dataset_builder.py ├── dataset_config.yml ├── transform_builder.py └── transforms.py ├── losses ├── __init__.py ├── classification.py ├── loss_builder.py └── loss_wrapper.py ├── metrics ├── __init__.py ├── accuracy.py └── metric_builder.py ├── models ├── __init__.py ├── blocks │ └── __init__.py ├── mobilenetv2.py ├── model_builder.py ├── resnet.py ├── shufflenet.py ├── shufflenetv2.py └── vgg.py ├── optimizer ├── __init__.py ├── optimizer_builder.py ├── optimizer_config.yml └── optimizers.py ├── schemes ├── __init__.py ├── lr_schemes.py ├── scheme_builder.py └── scheme_config.yml ├── trainer.py └── utils ├── __init__.py ├── dist.py ├── logger.py ├── meter.py └── netio.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Eric Shaw 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch-atom 2 | 3 | [TOC] 4 | 5 | ## Introduction 6 | 7 | A basic and simple training framework for pytorch, easy for extension. 8 | 9 | ![architecture figure](./figs/Torch-atom.png) 10 | 11 | ## Dependence 12 | 13 | - torch==1.7.0+cu110 (>= 1.6 for data distributed parallel) 14 | - torchvision==0.8.0 15 | - easydict==1.9 16 | - tensorboard==2.7 17 | - tensorboardX==2.4 18 | - PyYAML==5.3.1 19 | 20 | ## Features 21 | 22 | - Any module can be easily customized 23 | - Not abstract, easy to learn, develop and debug 24 | - With a lot of repetitive work reduction, it can be more easily control the training process in development and research 25 | - Friendly to multi-network interactive training, such as GAN, transfer learning, knowledge distillation, etc. 26 | - DDP training support. 27 | 28 | ## Train 29 | 30 | ```shell 31 | python main.py --model resnet18 --save_dir cifar100_resnet18 --config_path ./configs/cifar100.yml 32 | ``` 33 | 34 | for ddp training, you can just run the below command in the terminal 35 | 36 | ```shell 37 | sh run_ddp.sh 38 | ``` 39 | 40 | ## Results 41 | 42 | ### CIFAR100 43 | 44 | **All CIFAR100 log files** can be downloaded **in package** here: [pan.baidu](https://pan.baidu.com/s/15tY1a1gkOhhpaTdkCbwtXQ?pwd=3lp2) 45 | code:3lp2 46 | 47 | 48 | | Network | Accuracy | log | 49 | | :----------: | :------: | :----------------------------------------------------------: | 50 | | resnet18 | 76.46 | [pan.baidu](https://pan.baidu.com/s/1m73kIlUcg2k7cMeVvBCpag?pwd=ewnd) code: ewnd | 51 | | resnet34 | 77.23 | [pan.baidu](https://pan.baidu.com/s/1xqwaG2u-x5RFhYS4YA_EYA?pwd=dq4r) code: dq4r | 52 | | resnet50 | 76.82 | [pan.baidu](https://pan.baidu.com/s/1qCrTAo8Dj07qpfW3gcF6tA?pwd=1e62) code: 1e62 | 53 | | resnet101 | 77.32 | 
[pan.baidu](https://pan.baidu.com/s/1-dxxym_FBAOweNQCMnYVSQ?pwd=myfv) code: myfv | 54 | | vgg11_bn | 70.52 | [pan.baidu](https://pan.baidu.com/s/1J3axxkoAxenM3iHWOAw39A?pwd=2pun) code: 2pun | 55 | | vgg13_bn | 73.71 | [pan.baidu](https://pan.baidu.com/s/1idoXxVrUJOuGwOCcfjpAew?pwd=4vmm) code: 4vmm | 56 | | mobilenetV2 | 68.99 | [pan.baidu](https://pan.baidu.com/s/1MpXEHlpKD-lRrfISAnJr8Q?pwd=e93w) code: e93w | 57 | | shufflenet | 71.17 | [pan.baidu](https://pan.baidu.com/s/1ZjgyYoHU5_lIldFxvMcV6Q?pwd=lnvy) code: lnvy | 58 | | shufflenetV2 | 71.16 | [pan.baidu](https://pan.baidu.com/s/1PRUqSljNNNGsj-76FJh0Bg?pwd=vmi6) code: vmi6 | 59 | 60 | ### CIFAR10 61 | 62 | **All CIFAR10 log files** can be downloaded **in package** here: [pan.baidu](https://pan.baidu.com/s/1wKXd54FKl0irE8zi0BofGg?pwd=3iqz) 63 | code:3iqz 64 | 65 | | Network | Accuracy | log | 66 | | :----------: | :------: | :--: | 67 | | resnet18 | 94.92 | [pan.baidu](https://pan.baidu.com/s/1-x6WUNGectas1Mzc9U0g7Q?pwd=a20j) code: a20j | 68 | | resnet34 | 94.80 | [pan.baidu](https://pan.baidu.com/s/1sHMr2uumiwx13XjgO58O7g?pwd=q8h1) code: q8h1 | 69 | | resnet50 | 94.81 | [pan.baidu](https://pan.baidu.com/s/1R_DSUDOg39WDiwb7teW7bw?pwd=f3wr) code: f3wr | 70 | | resnet101 | 95.45 | [pan.baidu](https://pan.baidu.com/s/1YTQaNkIDMtEfGf1q3XINPA?pwd=d3i8) code: d3i8 | 71 | | vgg11_bn | 92.21 | [pan.baidu](https://pan.baidu.com/s/1ne6HaB8_tbIk_NTfRtvbyg?pwd=di45) code: di45 | 72 | | vgg13_bn | 93.74 | [pan.baidu](https://pan.baidu.com/s/1KoHkv7LMK1x-kJUrt7imbg?pwd=su1z) code: su1z | 73 | | mobilenetV2 | 90.92 | [pan.baidu](https://pan.baidu.com/s/1wXwYh6IWyoQKVZ9V1TxOow?pwd=todf) code: todf | 74 | | shufflenet | 92.06 | [pan.baidu](https://pan.baidu.com/s/198vPh8UydLoM-JPAsmqzWg?pwd=1xr2) code: 1xr2 | 75 | | shufflenetV2 | 91.61 | [pan.baidu](https://pan.baidu.com/s/1AFUa17uJWviZil05EYqB2Q?pwd=8swu) code: 8swu | 76 | 77 | ## Demo Train 78 | 79 | - import modules 80 | 81 | ```python 82 | from src import DatasetBuilder, 
TransformBuilder, ModelBuilder, LossBuilder, LossWrapper, OptimizerBuilder, SchedulerBuilder, MetricBuilder 83 | from torch.utils.data import DataLoader 84 | from src import Controller 85 | from src.utils import AverageMeter 86 | ``` 87 | 88 | - Load your dataloader and transform 89 | 90 | ```python 91 | transform_name = 'cifar100_transform' # your transform function name 92 | dataset_name = 'CIFAR100' # your dataset class name 93 | train_transform, val_transform = TransformBuilder.load(transform_name) 94 | trainset, trainset_config = DatasetBuilder.load(dataset_name=dataset_name, transform=train_transform, train=True) 95 | valset, valset_config = DatasetBuilder.load(dataset_name=dataset_name, transform=val_transform, train=False) 96 | train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True) 97 | val_loader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True) 98 | ``` 99 | 100 | - Load model, loss wrapper, metrics, optimizer, schemes and controller 101 | 102 | ```python 103 | epochs = 20 104 | 105 | model = ModelBuilder.load("resnet18", num_classes=100) 106 | 107 | # if categorical weights: LossBuilder.load("CrossEntropyLoss", weight=torch.tensor([1.0, 2.0]).float()) 108 | loss_func1 = LossBuilder.load("CrossEntropyLoss") 109 | 110 | # Only one loss function, cross entropy, set its weight to 1.0 111 | loss_wrapper = LossWrapper([loss_func1], [1.0]) 112 | 113 | model = model.cuda() 114 | loss_wrapper = loss_wrapper.cuda() 115 | metric_functions = [MetricBuilder.load('Accuracy')] 116 | 117 | optimizer, optimizer_param = OptimizerBuilder.load('Adam', model.parameters(), lr=0.1) 118 | scheduler, scheduler_param = SchedulerBuilder.load("cosine_annealing_lr", optimizer, max_epoch=epochs) 119 | controller = Controller(loss_wrapper=loss_wrapper, model=model, optimizer=optimizer) 120 | ``` 121 | 122 | - Train ! 
123 | 124 | ```python 125 | for epoch in range(epochs): 126 | # train 127 | model.train() 128 | loss_recorder = AverageMeter(type='scalar', name='total loss') 129 | loss_list_recorder = AverageMeter(type='tuple', num_scalar=1, names=["CrossEntropyLoss"]) 130 | metric_list_recorder = AverageMeter(type='tuple', num_scalar=1, names=["Accuracy"]) 131 | for (img, label) in train_loader: 132 | img = img.cuda() 133 | label = label.cuda() 134 | loss, loss_tuple, output_no_grad = controller.train_step(img, label) 135 | 136 | loss_recorder.update(loss.item(), img.size(0)) 137 | loss_list_recorder.update(loss_tuple, img.size(0)) 138 | 139 | metrics = tuple([func(output_no_grad, label) for func in metric_functions]) 140 | metric_list_recorder.update(metrics, img.size(0)) 141 | 142 | print(f"total loss:{loss_recorder.get_value()} loss_tuple:{loss_list_recorder} metrics:{metric_list_recorder}") 143 | 144 | # eval 145 | model.eval() 146 | # ... 147 | ``` 148 | 149 | 150 | 151 | ## Customize 152 | 153 | ### Customize Dataset 154 | 155 | - In `src/datasets` directory, define your customized `mydataset.py`, following `cifar.py`. 156 | - CIFAR class needs some parameters for initialization, such as `root`, `train`, `download`, which can be specified in `src/datasets/dataset_config.yml`. Note that `transform` needs to be set in `transforms.py`; details can be found in *Customize Transform*. 157 | - In `src/datasets/dataset_builder.py`, please import your dataset class. 
For example, if the `MyDataset` class is defined in `mydataset.py`, add `from .mydataset import *` in `dataset_builder.py` 158 | - In `configs/xxx.yml`, set `dataset.name` to `MyDataset` 159 | 160 | ### Customize Transform 161 | 162 | - In `src/datasets/transforms.py`, define your transform function, named `my_transform_func`, which returns `train_transform` and `val_transform` 163 | - In `configs/xxx.yml`, please set `dataset.transform_name` to `my_transform_func` 164 | 165 | ### Customize Model 166 | 167 | - In `src/models` directory, define your customized model, such as `my_model.py`, and define the module class `MyModel`. Please refer to `resnet.py` 168 | - In `src/models/model_builder.py`, import your model. `from .my_model import *` under the `try` process, and `from my_model import *` under the `except` process. It is just convenient for debugging. 169 | - In `configs/xxx.yml`, set the `model['name']` to `MyModel` 170 | 171 | ### Customize Loss Function 172 | 173 | - In `src/losses` directory, you can define customized loss Module, such as `CrossEntropyLoss` in `classification.py`. 174 | - Then import your loss Module in `loss_builder.py` 175 | - Maybe your model is supervised by multiple loss functions, which have different weights, so `LossWrapper` Module in `src/losses/loss_wrapper.py` may meet the requirement. 176 | - In `configs/xxx.yml`, please add your loss names and weights into `train.criterion.names` and `train.criterion.loss_weights` respectively. 177 | 178 | ### Customize Optimizer 179 | 180 | - In `src/optimizer` directory, `optimizers.py` can be found, please define your customized optimizer here. For example, `SGD` and `Adam` are already defined; `parameters` and `lr` should be specified, and other params need to be specified by `*args, **kwargs`. 
Please refer to 181 | 182 | ```python 183 | # src/optimizer/optimizers.py 184 | def SGD(parameters, lr, *args, **kwargs) -> optim.Optimizer: 185 | optimizer = optim.SGD(parameters, lr=lr, *args, **kwargs) 186 | return optimizer 187 | ``` 188 | 189 | - Other parameters, such as `weight_decay`, can be set in `src/optimizer/optimizer_config.yml`. See the YAML below; scientific notation such as `5e-4` is fine, because it is converted in `src/optimizer/optimizer_builder.py`. 190 | 191 | ```yaml 192 | # src/optimizer/optimizer_config.yml 193 | SGD: 194 | momentum: 0.9 195 | weight_decay: 5e-4 196 | dampening: 0 197 | nesterov: False 198 | ``` 199 | 200 | - In `configs/xxx.yml`, set the `train['lr']`, and set the `train['optimizer']` to `SGD` 201 | 202 | ### Customize Schemes 203 | 204 | - In `src/schemes/lr_schemes.py`, define your learning rate scheme function, named `my_scheduler`, which requires some params, such as `optimizer`, `epochs` and so on. 205 | - Other params can be specified easily in `src/schemes/scheme_config.yml` 206 | - In `configs/xxx.yml`, set the `train.schedule` to `my_scheduler` 207 | 208 | ### Customize Metrics 209 | 210 | - In `src/metrics/` directory, define your metric, such as `Accuracy` in `accuracy.py`, which computes the metric of predictions and target and returns a metric scalar 211 | - Import your metric in `metric_builder.py`, for example, `from .accuracy import *` 212 | - Multiple metrics are supported; in `configs/xxx.yml`, add your metrics into `train.metric.names`. 
While training the model, the checkpoint-saving strategy follows `train.metric.key_metric_name` in `configs/xxx.yml`; more details can be found in *Customize Checkpoint Saving Strategy* 213 | 214 | ### Customize Training and Validation Procedure for One Batch 215 | 216 | - In `src/controller.py`, please feel free to build your training and validation step 217 | - Training step returns `loss`, `loss_tuple` and `output_no_grad`, where `loss_tuple` and `output_no_grad` are only used for logging; whether `loss` has a gradient or not depends on you. 218 | 219 | ### Customize Checkpoint Saving Strategy 220 | 221 | - After training epoch, validation epoch will be performed in general. Torch-atom's NetIO in `src/utils/netio.py` will save the best state dict according to `key_metric_name` and `strategy` in `configs/xxx.yml` 222 | - Of course, checkpoint can be saved each `save_freq` epoch, which can be set in `configs/xxx.yml` as well 223 | 224 | ## Change Log 225 | 226 | - 2022.6.2 DDP support for training 227 | 228 | ## Todo 229 | 230 | - [x] DDP training 231 | - [ ] More experiment results 232 | - [ ] More widely-used datasets and models 233 | - [ ] Some visualization code for analysis 234 | - bad case analysis 235 | - data augmentation visualization 236 | - ... 
237 | 238 | 239 | 240 | ## Acknowledgement 241 | 242 | Torch-atom draws ideas from and builds on the following projects: 243 | 244 | [open-mmlab/mmclassification](https://github.com/open-mmlab/mmclassification) 245 | 246 | [weiaicunzai/pytorch-cifar100](https://github.com/weiaicunzai/pytorch-cifar100) 247 | 248 | 249 | 250 | ## Citation 251 | 252 | If you find this project useful in your research, please consider citing: 253 | 254 | ``` 255 | @misc{2022torchatom, 256 | title={Torch-atom: A basic and simple training framework for pytorch}, 257 | author={Baitan Shao}, 258 | howpublished = {\url{https://github.com/shaoeric/torch-atom}}, 259 | year={2022} 260 | } 261 | ``` 262 | 263 | ## License 264 | 265 | [The MIT License | Open Source Initiative](https://opensource.org/licenses/MIT) 266 | 267 | 268 | 269 | ## Finally 270 | 271 | Please feel free to submit issues, :) -------------------------------------------------------------------------------- /configs/20220223_cifar100.yml: -------------------------------------------------------------------------------- 1 | name: "cifar100_classification" 2 | 3 | environment: 4 | cuda: 5 | flag: True 6 | seed: 42 7 | 8 | dataset: 9 | name: "CIFAR100" 10 | transform_name: "cifar100_transform" 11 | 12 | model: 13 | name: "resnet34" 14 | ckpt: "" 15 | resume: False 16 | strict: True 17 | num_classes: 100 18 | 19 | train: 20 | epochs: 200 21 | start_epoch: 1 22 | batch_size: 128 23 | lr: 0.1 24 | optimizer: "SGD" # optimizer params can be configured at src/optimizer/optimizer_config.yml 25 | schedule: "multi_step_lr" # scheduler params can be configured at src/schemes/scheme_config.yml 26 | criterion: 27 | names: ["CrossEntropyLoss"] 28 | loss_weights: [1.0] 29 | metric: 30 | names: ["Accuracy"] 31 | key_metric_name: "Accuracy" # model should be saved if key_metric_name meets best score and strategy is not "none" 32 | strategy: "max" # saving strategy of key_metric_name, choices: ["max", "min", "none"] 33 | val_freq: 1 34 | 35 | 
output: 36 | ckpt_root: "./ckpt/" 37 | save_dir: "cifar100_resnet32" 38 | save_freq: 20 39 | log_dir: "./logs/" 40 | tensorboard: True 41 | log_step_freq: -1 # greater or equal than 0, print log each log_step_freq step 42 | log_epoch_freq: 1 43 | -------------------------------------------------------------------------------- /configs/20220309_cifar10.yml: -------------------------------------------------------------------------------- 1 | name: "cifar10_classification" 2 | 3 | environment: 4 | cuda: 5 | flag: True 6 | 7 | seed: 42 8 | 9 | dataset: 10 | name: "CIFAR10" 11 | transform_name: "cifar10_transform" 12 | 13 | model: 14 | name: "resnet34" 15 | ckpt: "" 16 | resume: False 17 | strict: True 18 | num_classes: 10 19 | 20 | train: 21 | epochs: 200 22 | start_epoch: 1 23 | batch_size: 128 24 | lr: 0.1 25 | optimizer: "SGD" # optimizer params can be configured at src/optimizer/optimizer_config.yml 26 | schedule: "multi_step_lr" # scheduler params can be configured at src/schemes/scheme_comfig.yml 27 | criterion: 28 | names: ["CrossEntropyLoss"] 29 | loss_weights: [1.0] 30 | metric: 31 | names: ["Accuracy"] 32 | key_metric_name: "Accuracy" # model should be saved if key_metric_name meets best score and strategy is not "none" 33 | strategy: "max" # saving strategy of key_metric_name, choises: ["max", "min", "none"] 34 | val_freq: 1 35 | 36 | output: 37 | ckpt_root: "./ckpt/" 38 | save_dir: "cifar10_resnet32" 39 | save_freq: 20 40 | log_dir: "./logs/" 41 | tensorboard: True 42 | log_step_freq: -1 # greater or equal than 0, print log each log_step_freq step 43 | log_epoch_freq: 1 44 | -------------------------------------------------------------------------------- /configs/20220531_cifar100_ddp.yaml: -------------------------------------------------------------------------------- 1 | name: "cifar100_classification_ddp" 2 | 3 | environment: 4 | cuda: 5 | flag: True 6 | seed: 42 7 | ddp: False 8 | local_rank: -1 9 | num_gpu: 0 10 | 11 | dataset: 12 | name: "CIFAR100" 
class ConfigLoader:
    """Load YAML experiment configs into attribute-accessible dictionaries."""

    def __init__(self) -> None:
        # Absolute path of the directory that contains this module.
        self.PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))

    @staticmethod
    def load(config_filepath: str) -> dict:
        """Parse a YAML file and wrap the result in an ``EasyDict``.

        Args:
            config_filepath: path of the YAML configuration file.

        Returns:
            The parsed configuration as an ``EasyDict`` (a ``dict`` subclass),
            so values can be read as ``config.train.lr`` or
            ``config['train']['lr']``.

        Raises:
            OSError: if the file cannot be opened.
            yaml.YAMLError: if the file is not valid YAML.
        """
        # YAML is a Unicode format; do not depend on the platform's default
        # encoding (which is not UTF-8 on every OS).
        with open(config_filepath, 'r', encoding='utf-8') as f:
            # NOTE(review): FullLoader can construct more than plain
            # scalars/lists/dicts. The shipped configs only need
            # ``yaml.safe_load``; prefer it if config files may ever come
            # from an untrusted source.
            data = yaml.load(f, Loader=yaml.FullLoader)
        return EasyDict(data)


if __name__ == '__main__':
    config = ConfigLoader.load("configs/20220223_cifar100.yml")
    print(config)
    # The criterion section defines ``names`` and ``loss_weights``; there is
    # no ``weights`` key, so the old lookup raised KeyError.
    print(config['train']['criterion']['loss_weights'])
def build_trainer(config):
    """Assemble model, loss wrapper and checkpoint I/O into a ``Trainer``.

    Args:
        config: experiment configuration as returned by ``ConfigLoader.load``.
            Reads ``model.name``, ``model.num_classes``, ``model.resume``,
            ``model.ckpt``, ``train.criterion.names``,
            ``train.criterion.loss_weights`` and ``environment.cuda.flag``.

    Returns:
        Trainer: trainer wired with the model, the loss wrapper and ``NetIO``.

    Raises:
        ValueError: if the number of loss names and loss weights differ.
    """
    netio = NetIO(config)

    model = ModelBuilder.load(config.model['name'], num_classes=config.model['num_classes'])
    if config.model['resume']:
        model = netio.load_file(model, config.model['ckpt'])

    # Build the losses that the config asks for instead of hard-coding
    # "CrossEntropyLoss", so ``train.criterion.names`` is actually honoured.
    criterion = config.train.criterion
    loss_names = list(criterion['names'])
    loss_weights = list(criterion['loss_weights'])
    if len(loss_names) != len(loss_weights):
        raise ValueError(
            "train.criterion.names has {} entries but train.criterion.loss_weights has {}".format(
                len(loss_names), len(loss_weights)))
    loss_funcs = [LossBuilder.load(name) for name in loss_names]
    # ``loss_weights`` is already a list (e.g. [1.0]); pass it as-is rather
    # than nesting it into [[1.0]].
    # NOTE(review): LossWrapper is not visible here — confirm it expects one
    # flat weight per loss, as the README demo suggests.
    loss_wrapper = LossWrapper(loss_funcs, loss_weights)

    if config.environment.cuda.flag:
        model = model.cuda()
        loss_wrapper = loss_wrapper.cuda()

    trainer = Trainer(config=config, model=model, wrapper=loss_wrapper, ioer=netio)
    return trainer
def set_seed(seed):
    """Seed every random number generator the training pipeline may use.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU and CUDA generators.

    Args:
        seed (int): seed value shared by all generators.
    """
    # Function-scope import: the module does not import ``random`` at the top.
    import random

    # Python's own RNG was previously left unseeded, so anything built on
    # ``random`` (e.g. custom augmentations) differed from run to run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def build_trainer(config):
    """Assemble the per-process model, loss wrapper and ``Trainer`` for DDP.

    Only the process with ``local_rank == 0`` keeps a ``NetIO`` instance and
    hands it to the trainer; every other rank passes ``ioer=None``.

    Args:
        config: experiment configuration. Reads ``environment.local_rank``,
            ``model.name``, ``model.num_classes``, ``model.resume``,
            ``model.ckpt``, ``train.criterion.names`` and
            ``train.criterion.loss_weights``.

    Returns:
        Trainer: trainer created with ``ddp=True`` for this process.

    Raises:
        ValueError: if the number of loss names and loss weights differ.
    """
    local_rank = config.environment.local_rank
    netio = NetIO(config) if local_rank == 0 else None
    model = ModelBuilder.load(config.model['name'], num_classes=config.model['num_classes'])
    if config.model['resume']:
        # ``netio`` is None on non-zero ranks, so ``netio.load_file`` used to
        # raise AttributeError there. Every rank needs the checkpoint weights,
        # so fall back to a throw-away loader that is never given to the trainer.
        # NOTE(review): NetIO's constructor is not visible here — confirm it
        # is safe to instantiate on non-zero ranks (e.g. directory creation).
        loader = netio if netio is not None else NetIO(config)
        model = loader.load_file(model, config.model['ckpt'])

    # Build the losses that the config asks for instead of hard-coding
    # "CrossEntropyLoss", so ``train.criterion.names`` is actually honoured.
    criterion = config.train.criterion
    loss_names = list(criterion['names'])
    loss_weights = list(criterion['loss_weights'])
    if len(loss_names) != len(loss_weights):
        raise ValueError(
            "train.criterion.names has {} entries but train.criterion.loss_weights has {}".format(
                len(loss_names), len(loss_weights)))
    loss_funcs = [LossBuilder.load(name) for name in loss_names]
    # ``loss_weights`` is already a list (e.g. [1.0]); pass it as-is rather
    # than nesting it into [[1.0]].
    # NOTE(review): LossWrapper is not visible here — confirm it expects one
    # flat weight per loss, as the README demo suggests.
    loss_wrapper = LossWrapper(loss_funcs, loss_weights)

    # Place the model and the loss on this process's GPU.
    model = model.cuda(local_rank)
    loss_wrapper = loss_wrapper.cuda(local_rank)

    trainer = Trainer(config=config, model=model, wrapper=loss_wrapper, ioer=netio, ddp=True)
    return trainer
def main_worker(local_rank, num_gpus, args, config):
    """Run the full train/validate loop inside one distributed process.

    Args:
        local_rank (int): index of the GPU this process is bound to.
        num_gpus (int): number of visible GPUs (kept for interface
            compatibility; not used by the loop itself).
        args: parsed command-line arguments (kept for interface compatibility).
        config: experiment configuration; reads ``train.start_epoch`` and
            ``train.epochs``.
    """
    start_epoch = config.train['start_epoch']
    max_epoch = config.train['epochs'] + 1

    # Bind this process to its GPU *before* creating the NCCL process group,
    # so collectives are set up on the right device. Use the ``local_rank``
    # parameter rather than reaching back into ``args``.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    try:
        trainer = build_trainer(config)
        (train_loader, trainset_config), (val_loader, valset_config) = build_dataloader(config)

        # Only rank 0 reports, to avoid duplicated log lines.
        is_reporter = trainer.logger is not None and local_rank == 0
        if is_reporter:
            for section in (trainset_config, valset_config, config.model,
                            config.train, config.output, config.environment):
                trainer.logger.info(section)

        for epoch in range(start_epoch, max_epoch):
            # Pass the epoch number to the distributed sampler before each epoch.
            train_loader.sampler.set_epoch(epoch)

            trainer.train(epoch, train_loader)
            trainer.validate(epoch, val_loader)

        if is_reporter:
            trainer.logger.info("best metric: {}".format(trainer.ioer.get_best_score()))
    finally:
        # Always release the process group, even when training fails.
        dist.destroy_process_group()
class Controller(object):
    """
    Tie a model, its optimizer and a loss wrapper into single train / validate steps.

    Args:
        loss_wrapper (LossWrapper): callable invoked as
            ``loss_wrapper(output, [label])`` returning
            ``(loss, loss_tuple, output_no_grad)``
        model (nn.Module): model applied to the input tensor
        optimizer (optim.Optimizer): optimizer stepped once per ``train_step``
    """

    def __init__(self,
                 loss_wrapper: LossWrapper,
                 model: nn.Module,
                 optimizer: optim.Optimizer
                 ) -> None:
        self.loss_wrapper = loss_wrapper
        self.model = model
        self.optimizer = optimizer

    def train_step(self, input: torch.Tensor, label: torch.Tensor, *args, **kwargs):
        """
        Define the training process for the model, easy to extend to multiple models

        Args:
            input (torch.Tensor): input tensor of the model
            label (torch.Tensor): ground truth of the input tensor

        Returns:
            loss (torch.FloatTensor): train loss
            loss_tuple (tuple[torch.FloatTensor]): a tuple of loss items
            output_no_grad (torch.FloatTensor): model output without grad
        """
        self.optimizer.zero_grad()
        output = self.model(input)
        loss, loss_tuple, output_no_grad = self.loss_wrapper(output, [label])
        loss.backward()
        self.optimizer.step()
        return loss, loss_tuple, output_no_grad

    def validate_step(self, input: torch.Tensor, label: torch.Tensor, *args, **kwargs):
        """
        Define the validation process for the model

        Args:
            input (torch.Tensor): input tensor for the model
            label (torch.Tensor): ground truth for the input tensor

        Returns:
            loss (torch.FloatTensor): validation loss item, without grad
            loss_tuple (tuple[torch.FloatTensor]): a tuple of loss items
            output_no_grad (torch.FloatTensor): model output without grad
        """
        # Nothing is back-propagated here, so skip building the autograd graph
        # instead of building it and detaching afterwards.
        with torch.no_grad():
            output = self.model(input)
            loss, loss_tuple, output_no_grad = self.loss_wrapper(output, [label])
        return loss.detach(), loss_tuple, output_no_grad
def __init__(self, root=None, train=True, transform=None, target_transform=None, download=False, trainset_sample_ratio=1.0, seed=1, *args, **kwargs):
    """
    Load the batch files of the requested split and optionally subsample the training set.

    Args:
        root (string): root directory that contains (or will receive) ``base_folder``
        train (bool): load ``train_list`` if True, otherwise ``test_list``
        transform (callable, optional): stored by the parent class and applied to the image
        target_transform (callable, optional): stored by the parent class and applied to the target
        download (bool): call ``self.download()`` before the integrity check
        trainset_sample_ratio (float): fraction of the training set to keep;
            ignored for the test split, which always uses 1.0
        seed (int): seed of the private random generator used for subsampling

    Raises:
        RuntimeError: if the batch files are missing or fail the integrity check
    """
    super(CIFAR10, self).__init__(root, transform=transform,
                                  target_transform=target_transform)

    self.train = train  # training set or test set

    if download:
        self.download()

    if not self._check_integrity():
        raise RuntimeError('Dataset not found or corrupted.' +
                           ' You can use download=True to download it')

    if self.train:
        downloaded_list = self.train_list
        self.sample_ratio = trainset_sample_ratio
    else:
        downloaded_list = self.test_list
        self.sample_ratio = 1.0

    self.data = []
    self.targets = []

    # now load the pickled numpy arrays
    # NOTE: pickle executes code on load; these files passed the md5 check above.
    for file_name, _checksum in downloaded_list:
        file_path = os.path.join(self.root, self.base_folder, file_name)
        with open(file_path, 'rb') as f:
            entry = pickle.load(f, encoding='latin1')
        self.data.append(entry['data'])
        # CIFAR-10 batches store 'labels', CIFAR-100 files store 'fine_labels'
        if 'labels' in entry:
            self.targets.extend(entry['labels'])
        else:
            self.targets.extend(entry['fine_labels'])

    self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
    self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

    print("dataset: {} target: {}".format(len(self.data), len(self.targets)))

    if self.sample_ratio < 1.0:
        num_keep = int(len(self.data) * self.sample_ratio)
        # Draw WITHOUT replacement so the subset holds num_keep distinct images,
        # and use a private generator so the global NumPy RNG is left untouched.
        rng = np.random.RandomState(seed)
        sample_idx = rng.choice(len(self.data), size=num_keep, replace=False)
        self.data = self.data[sample_idx]
        self.targets = np.array(self.targets)[sample_idx].tolist()

    print("sample dataset: {} target:{}".format(len(self.data), len(self.targets)))
    self._load_meta()
def _check_integrity(self):
    """
    Check every batch file of both splits against its expected MD5.

    Returns:
        bool: True only if each file in ``train_list`` and ``test_list``
        passes ``check_integrity``; checking stops at the first failure
    """
    batch_dir = os.path.join(self.root, self.base_folder)
    return all(
        check_integrity(os.path.join(batch_dir, entry[0]), entry[1])
        for entry in (self.train_list + self.test_list)
    )
def _build_float_loader():
    """Return a private SafeLoader subclass that also resolves floats such as ``1e-3``."""

    class _FloatSafeLoader(yaml.SafeLoader):
        pass

    # Registering on the subclass keeps the shared yaml.SafeLoader class untouched;
    # registering on SafeLoader itself would add one more resolver on every call.
    _FloatSafeLoader.add_implicit_resolver(
        u'tag:yaml.org,2002:float',
        re.compile(u'''^(?:
         [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
        |[-+]?\\.(?:inf|Inf|INF)
        |\\.(?:nan|NaN|NAN))$''', re.X),
        list(u'-+0123456789.'))
    return _FloatSafeLoader


def parse_dataset_config():
    """
    Read ``DATASET_CONFIG`` and return its content as an ``EasyDict``.

    The file is parsed with a safe loader whose float resolver additionally
    accepts exponent notation without a decimal point (the second alternative
    of the pattern), so such values load as floats rather than strings.

    Returns:
        EasyDict: the parsed configuration
    """
    with open(DATASET_CONFIG, 'r') as f:
        data = yaml.load(f, Loader=_build_float_loader())
    return EasyDict(data)
class TransformBuilder:
    """Look up a transform factory by name among this module's globals."""

    def __init__(self) -> None:
        pass

    @staticmethod
    def load(name: str = 'cifar100_transform'):
        """
        Call the module-level factory called ``name`` and return its two transforms.

        Args:
            name (str): name of a callable in this module's globals that
                returns exactly two items

        Returns:
            tuple: ``(train_transform, val_transform)``

        Raises:
            KeyError: if no global with that name exists
        """
        factory = globals()[name]
        # Unpacking enforces that the factory yields exactly two items
        train_transform, val_transform = factory()
        return train_transform, val_transform
class CrossEntropyLoss(nn.Module):
    """Thin module wrapper that forwards to ``nn.CrossEntropyLoss``."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the wrapped ``nn.CrossEntropyLoss`` from the given arguments."""
        super().__init__()
        self.loss_func = nn.CrossEntropyLoss(*args, **kwargs)

    def forward(self, input, label):
        """Return the wrapped criterion evaluated on ``input`` and ``label``."""
        return self.loss_func(input, label)
__init__(self) -> None: 16 | pass 17 | 18 | @staticmethod 19 | def load(loss_name, *args, **kwargs) -> nn.Module: 20 | loss_func = globals()[loss_name] 21 | return loss_func(*args, **kwargs) 22 | 23 | 24 | if __name__ == '__main__': 25 | import torch 26 | loss_func = LossBuilder.load("CrossEntropyLoss", weight=torch.tensor(list(range(10))).float()) 27 | x = torch.randn(size=(2, 10)) # input 28 | y = torch.tensor([0, 1]).long() # label 29 | loss = loss_func.forward(x, y) 30 | print(loss) -------------------------------------------------------------------------------- /src/losses/loss_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List 4 | 5 | 6 | class LossWrapper(nn.Module): 7 | def __init__(self, loss_func_list: List[nn.Module], loss_weight_list: List[float], *args, **kwargs) -> None: 8 | super(LossWrapper, self).__init__() 9 | self.loss_func_list = loss_func_list 10 | self.loss_weight_list = loss_weight_list 11 | assert len(self.loss_func_list) == len(self.loss_weight_list), "length of loss function list should match the length of loss weight list" 12 | self.num_meter = len(self.loss_func_list) 13 | if len(self.loss_func_list) == 1: 14 | self.loss_weight_list = [1.0] 15 | 16 | def forward(self, pred: torch.Tensor, targets: List[torch.Tensor], *args, **kwargs): 17 | """ 18 | Calculate the total loss between model prediction and target list 19 | 20 | Args: 21 | pred (torch.Tensor): model prediction 22 | targets (List[torch.Tensor]): a list of targets for multi-task / multi loss training 23 | 24 | Returns: 25 | loss (torch.FloatTensor): a weighted loss tensor 26 | loss_list (tuple[torch.FloatTensor]): a tuple of loss item 27 | pred (torch.FloatTensor): model output without grad 28 | """ 29 | assert len(self.loss_func_list) == len(targets), "length of loss function list should match the length of targets" 30 | 31 | loss = 0.0 32 | loss_list = [] 33 | for 
def accuracy(output, target, topk=(1,)):
    """
    Compute the precision@k for the specified values of k.

    Args:
        output (torch.Tensor): scores; the top-k is taken along dimension 1
        target (torch.Tensor): ground-truth class index per sample
        topk (tuple[int]): values of k to evaluate

    Returns:
        list[torch.Tensor]: one single-element tensor per k holding the
        percentage (0 to 100) of samples whose target is among the k
        highest-scoring classes
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, largest first, sorted
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` comes from a transposed tensor and is not contiguous, so
        # view(-1) raises for k > 1; reshape copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class MetricBuilder:
    """Instantiate a metric by looking its name up among this module's globals."""

    def __init__(self) -> None:
        pass

    @staticmethod
    def load(metric_name: str, *args, **kwargs) -> object:
        """
        Call the module-level object called ``metric_name`` with the given arguments.

        Args:
            metric_name (str): name of a callable in this module's globals

        Returns:
            object: whatever that callable returns

        Raises:
            KeyError: if no global with that name exists
        """
        return globals()[metric_name](*args, **kwargs)
class LinearBottleNeck(nn.Module):
    """
    MobileNetV2 bottleneck: 1x1 expansion, 3x3 depthwise conv, 1x1 projection.

    Args:
        in_channels (int): channels of the input
        out_channels (int): channels of the output
        stride (int): stride of the depthwise convolution
        t (int): expansion factor; the hidden width is ``in_channels * t``
        class_num (int): accepted but not used by this block
    """

    def __init__(self, in_channels, out_channels, stride, t=6, class_num=100):
        super().__init__()

        hidden = in_channels * t
        stages = [
            # pointwise expansion
            nn.Conv2d(in_channels, hidden, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),

            # depthwise: one group per hidden channel
            nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1, groups=hidden),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),

            # pointwise projection, no activation afterwards
            nn.Conv2d(hidden, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        ]
        self.residual = nn.Sequential(*stages)

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        out = self.residual(x)

        # The skip connection is only added when input and output shapes match
        shapes_match = self.stride == 1 and self.in_channels == self.out_channels
        if not shapes_match:
            return out

        out += x
        return out
class ModelBuilder:
    """Build a model by looking its factory up among this module's globals."""

    def __init__(self) -> None:
        pass

    @staticmethod
    def load(model_name, *args, **kwargs) -> nn.Module:
        """
        Call the module-level factory called ``model_name`` with the given arguments.

        Args:
            model_name (str): name of a callable in this module's globals

        Returns:
            nn.Module: whatever that factory returns

        Raises:
            KeyError: if no global with that name exists
        """
        return globals()[model_name](*args, **kwargs)
class BasicBlock(nn.Module):
    """Basic Block for resnet 18 and resnet 34

    Args:
        in_channels (int): channels of the input
        out_channels (int): base width of the block
        stride (int): stride of the first 3x3 convolution
    """

    # BasicBlock and BottleNeck produce different output widths;
    # the class attribute `expansion` tells them apart.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        out_width = out_channels * BasicBlock.expansion

        # residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_width),
        )

        # shortcut: identity unless the residual branch changes the shape,
        # in which case a 1x1 convolution matches the dimensions
        if stride != 1 or in_channels != out_width:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_width),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        merged = self.residual_function(x) + self.shortcut(x)
        return nn.functional.relu(merged, inplace=True)
class ResNet(nn.Module):
    """
    ResNet with a 3x3 stem, four residual stages, global average pooling and a linear head.

    Args:
        block: block class (basic block or bottle neck block); called as
            ``block(in_channels, out_channels, stride)`` and must expose ``expansion``
        num_block (sequence[int]): number of blocks in each of the four stages
        num_classes (int): size of the final linear layer's output
    """

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        # width fed to the next block; _make_layer advances it as blocks are added
        self.in_channels = 64

        stem = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.conv1 = nn.Sequential(*stem)

        # we use a different input size than the original paper,
        # so conv2_x's stride is 1
        stage_specs = (
            ("conv2_x", 64, 1),
            ("conv3_x", 128, 2),
            ("conv4_x", 256, 2),
            ("conv5_x", 512, 2),
        )
        for idx, (attr, width, first_stride) in enumerate(stage_specs):
            setattr(self, attr, self._make_layer(block, width, num_block[idx], first_stride))

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Make one resnet layer. A 'layer' here is a stage that may contain
        more than one residual block, not a single neural network layer
        such as a conv layer.

        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer
        Return:
            return a resnet layer
        """
        # only the first block of the layer uses `stride`; all others use 1
        per_block_strides = [stride] + [1] * (num_blocks - 1)

        blocks = []
        for block_stride in per_block_strides:
            blocks.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*blocks)

    def forward(self, x):
        features = x
        pipeline = (self.conv1, self.conv2_x, self.conv3_x,
                    self.conv4_x, self.conv5_x, self.avg_pool)
        for stage in pipeline:
            features = stage(features)

        # flatten everything except the batch dimension before the linear head
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
class BasicConv2d(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place ReLU."""

    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size, **kwargs)
        self.bn = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class ChannelShuffle(nn.Module):
    """Interleave the channels of a grouped-convolution output.

    Args:
        groups: number of channel groups; must divide the channel count of
            the tensors passed to ``forward``.
    """

    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        batchsize, channels, height, width = x.size()
        channels_per_group = channels // self.groups

        # "suppose a convolutional layer with g groups whose output has
        # g x n channels; we first reshape the output channel dimension
        # into (g, n)"
        x = x.view(batchsize, self.groups, channels_per_group, height, width)

        # "transposing and then flattening it back as the input of next layer."
        x = x.transpose(1, 2).contiguous()
        return x.view(batchsize, -1, height, width)


class DepthwiseConv2d(nn.Module):
    """Conv2d + BatchNorm2d; the caller passes ``groups`` to make it depthwise."""

    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
        super().__init__()
        self.depthwise = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size, **kwargs),
            nn.BatchNorm2d(output_channels),
        )

    def forward(self, x):
        return self.depthwise(x)


class PointwiseConv2d(nn.Module):
    """1x1 Conv2d + BatchNorm2d; extra kwargs (e.g. ``groups``) go to the conv."""

    def __init__(self, input_channels, output_channels, **kwargs):
        super().__init__()
        self.pointwise = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, **kwargs),
            nn.BatchNorm2d(output_channels),
        )

    def forward(self, x):
        return self.pointwise(x)


class ShuffleNetUnit(nn.Module):
    """One ShuffleNet unit: grouped 1x1 -> shuffle -> depthwise 3x3 -> grouped 1x1.

    Args:
        input_channels: channels of the incoming tensor.
        output_channels: channels of the unit's output.
        stage: stage index of the enclosing network (see the note in
            ``__init__``; it currently does not change the layers built).
        stride: stride of the depthwise convolution and of the pooled shortcut.
        groups: group count of the two pointwise convolutions.
    """

    def __init__(self, input_channels, output_channels, stage, stride, groups):
        super().__init__()

        # "Similar to [9], we set the number of bottleneck channels to 1/4
        # of the output channels for each ShuffleNet unit."
        bottleneck_channels = output_channels // 4

        # NOTE(review): the paper says stage 2 should not use group
        # convolution on the first pointwise layer. The previous code had a
        # `stage == 2` branch for this, but it rebuilt exactly the same
        # grouped layer. Grouped convolution is kept here so parameter shapes
        # (and therefore existing checkpoints) do not change; `stage` is
        # accepted but unused. Confirm whether groups=1 is wanted for stage 2.
        # The attribute name `bottlneck` (sic) is kept because it is part of
        # the state_dict keys.
        self.bottlneck = nn.Sequential(
            PointwiseConv2d(input_channels, bottleneck_channels, groups=groups),
            nn.ReLU(inplace=True),
        )

        self.channel_shuffle = ChannelShuffle(groups)

        self.depthwise = DepthwiseConv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            groups=bottleneck_channels,
            stride=stride,
            padding=1,
        )

        # "As for the case where ShuffleNet is applied with stride,
        # we simply make two modifications (see Fig 2 (c)):
        # (i) add a 3 x 3 average pooling on the shortcut path;
        # (ii) replace the element-wise addition with channel concatenation"
        concat_shortcut = stride != 1 or input_channels != output_channels

        # With concatenation the branch only supplies the channels that the
        # shortcut does not already carry.
        expand_channels = output_channels - input_channels if concat_shortcut else output_channels
        self.expand = PointwiseConv2d(bottleneck_channels, expand_channels, groups=groups)

        self.relu = nn.ReLU(inplace=True)

        if concat_shortcut:
            # Pool with the unit's own stride so both paths keep the same
            # spatial size; a fixed stride of 2 broke stride-1 units whose
            # channel count grows.
            self.shortcut = nn.AvgPool2d(3, stride=stride, padding=1)
            self.fusion = self._cat
        else:
            self.shortcut = nn.Sequential()
            self.fusion = self._add

    def _add(self, x, y):
        return torch.add(x, y)

    def _cat(self, x, y):
        return torch.cat([x, y], dim=1)

    def forward(self, x):
        shortcut = self.shortcut(x)

        shuffled = self.bottlneck(x)
        shuffled = self.channel_shuffle(shuffled)
        shuffled = self.depthwise(shuffled)
        shuffled = self.expand(shuffled)

        output = self.fusion(shortcut, shuffled)
        return self.relu(output)


class ShuffleNet(nn.Module):
    """ShuffleNet classifier with three stages of ShuffleNetUnit.

    Args:
        num_blocks: number of units in stage 2, 3 and 4.
        num_classes: output size of the final linear layer.
        groups: group count for the pointwise convolutions; one of 1, 2, 3,
            4 or 8.

    Raises:
        ValueError: if ``groups`` has no width table entry.
    """

    # Output widths of [conv1, stage2, stage3, stage4] per group count.
    # The groups=1 row ends in 576 (twice 288, as in the paper); it used to
    # read 567.
    _OUT_CHANNELS = {
        1: [24, 144, 288, 576],
        2: [24, 200, 400, 800],
        3: [24, 240, 480, 960],
        4: [24, 272, 544, 1088],
        8: [24, 384, 768, 1536],
    }

    def __init__(self, num_blocks, num_classes=100, groups=3):
        super().__init__()

        if groups not in self._OUT_CHANNELS:
            raise ValueError(
                "unsupported groups {!r}; expected one of {}".format(
                    groups, sorted(self._OUT_CHANNELS)
                )
            )
        out_channels = self._OUT_CHANNELS[groups]

        self.conv1 = BasicConv2d(3, out_channels[0], 3, padding=1, stride=1)
        # Input width expected by the next unit; advanced by _make_stage.
        self.input_channels = out_channels[0]

        self.stage2 = self._make_stage(
            ShuffleNetUnit, num_blocks[0], out_channels[1], stride=2, stage=2, groups=groups
        )
        self.stage3 = self._make_stage(
            ShuffleNetUnit, num_blocks[1], out_channels[2], stride=2, stage=3, groups=groups
        )
        self.stage4 = self._make_stage(
            ShuffleNetUnit, num_blocks[2], out_channels[3], stride=2, stage=4, groups=groups
        )

        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(out_channels[3], num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.avg(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def _make_stage(self, block, num_blocks, output_channels, stride, stage, groups):
        """Build one ShuffleNet stage.

        Args:
            block: block type, shuffle unit
            num_blocks: how many blocks the stage contains
            output_channels: output channel count of every block in the stage
            stride: stride of the first block; later blocks use stride 1
            stage: stage index, forwarded unchanged to every block
            groups: group number of group convolution
        Return:
            the stage as an ``nn.Sequential``
        """
        strides = [stride] + [1] * (num_blocks - 1)

        # Collected under a separate name: reusing `stage` for this list made
        # every block receive the list instead of the stage index.
        units = []
        for unit_stride in strides:
            units.append(
                block(
                    self.input_channels,
                    output_channels,
                    stride=unit_stride,
                    stage=stage,
                    groups=groups,
                )
            )
            self.input_channels = output_channels

        return nn.Sequential(*units)


def shufflenet(num_classes):
    """Build the default ShuffleNet (groups=3, stage depths 4-8-4)."""
    return ShuffleNet([4, 8, 4], num_classes=num_classes)
def channel_split(x, split):
    """split a tensor into two pieces along channel dimension
    Args:
        x: input tensor
        split:(int) channel size for each pieces
    """
    assert x.size(1) == split * 2
    return torch.split(x, split, dim=1)


def channel_shuffle(x, groups):
    """channel shuffle operation
    Args:
        x: input tensor
        groups: input branch number
    """
    batch_size, channels, height, width = x.size()
    channels_per_group = channels // groups

    # (groups, n) -> (n, groups) on the channel axis, then flatten back.
    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    return x.view(batch_size, -1, height, width)


class ShuffleUnit(nn.Module):
    """ShuffleNet V2 unit.

    When ``stride == 1`` and the channel count is unchanged, the input is
    split in half: one half passes through untouched, the other through the
    residual branch. Otherwise both branches process the full input and each
    produces ``out_channels // 2`` channels. The two results are
    concatenated and channel-shuffled.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels of the unit's output.
        stride: stride of the depthwise convolutions.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        if stride != 1 or in_channels != out_channels:
            branch_out = out_channels // 2
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, 1),
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
                nn.BatchNorm2d(in_channels),
                nn.Conv2d(in_channels, branch_out, 1),
                nn.BatchNorm2d(branch_out),
                nn.ReLU(inplace=True),
            )

            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
                nn.BatchNorm2d(in_channels),
                nn.Conv2d(in_channels, branch_out, 1),
                nn.BatchNorm2d(branch_out),
                nn.ReLU(inplace=True),
            )
        else:
            self.shortcut = nn.Sequential()

            # Only half of the channels go through the residual branch.
            half = in_channels // 2
            self.residual = nn.Sequential(
                nn.Conv2d(half, half, 1),
                nn.BatchNorm2d(half),
                nn.ReLU(inplace=True),
                nn.Conv2d(half, half, 3, stride=stride, padding=1, groups=half),
                nn.BatchNorm2d(half),
                nn.Conv2d(half, half, 1),
                nn.BatchNorm2d(half),
                nn.ReLU(inplace=True),
            )

    def forward(self, x):
        if self.stride == 1 and self.out_channels == self.in_channels:
            shortcut, residual = channel_split(x, self.in_channels // 2)
        else:
            shortcut = x
            residual = x

        shortcut = self.shortcut(shortcut)
        residual = self.residual(residual)
        x = torch.cat([shortcut, residual], dim=1)
        return channel_shuffle(x, 2)


class ShuffleNetV2(nn.Module):
    """ShuffleNet V2 classifier.

    Args:
        ratio: width multiplier; one of 0.5, 1, 1.5 or 2.
        num_classes: output size of the final linear layer.

    Raises:
        ValueError: if ``ratio`` is not one of the supported values.
    """

    def __init__(self, ratio=1, num_classes=100):
        super().__init__()
        if ratio == 0.5:
            out_channels = [48, 96, 192, 1024]
        elif ratio == 1:
            out_channels = [116, 232, 464, 1024]
        elif ratio == 1.5:
            out_channels = [176, 352, 704, 1024]
        elif ratio == 2:
            out_channels = [244, 488, 976, 2048]
        else:
            # The exception used to be constructed but never raised, so an
            # unsupported ratio failed later with an UnboundLocalError.
            raise ValueError('unsupported ratio number')

        self.pre = nn.Sequential(
            nn.Conv2d(3, 24, 3, padding=1),
            nn.BatchNorm2d(24),
        )

        self.stage2 = self._make_stage(24, out_channels[0], 3)
        self.stage3 = self._make_stage(out_channels[0], out_channels[1], 7)
        self.stage4 = self._make_stage(out_channels[1], out_channels[2], 3)
        self.conv5 = nn.Sequential(
            nn.Conv2d(out_channels[2], out_channels[3], 1),
            nn.BatchNorm2d(out_channels[3]),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(out_channels[3], num_classes)

    def forward(self, x):
        x = self.pre(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        return self.fc(x)

    def _make_stage(self, in_channels, out_channels, repeat):
        """Build a stage: one stride-2 unit followed by ``repeat`` stride-1 units."""
        layers = [ShuffleUnit(in_channels, out_channels, 2)]
        # range() also terminates for a negative count, where the previous
        # `while repeat:` countdown never reached zero.
        for _ in range(repeat):
            layers.append(ShuffleUnit(out_channels, out_channels, 1))

        return nn.Sequential(*layers)


def shufflenetv2(num_classes):
    """Build a ShuffleNet V2 with the default width ratio of 1."""
    return ShuffleNetV2(num_classes=num_classes)
# Layer plans per VGG variant: an int is the output width of a 3x3
# convolution, 'M' is a 2x2 max-pool.
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
          512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG classifier: a convolutional feature stack plus a 3-layer MLP head.

    Args:
        features: module producing the feature map; after flattening it must
            yield 512 values per sample, the input width of the head.
        num_classes: output size of the last linear layer.
    """

    def __init__(self, features, num_classes=100):
        super().__init__()
        self.features = features

        head = [
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        feature_map = self.features(x)
        flat = feature_map.view(feature_map.size()[0], -1)
        return self.classifier(flat)


def make_layers(cfg, batch_norm=False):
    """Turn a layer plan into an ``nn.Sequential`` feature extractor.

    Args:
        cfg: iterable of ints (conv output widths) and 'M' (max-pool).
        batch_norm: insert a BatchNorm2d after every convolution when True.
    """
    modules = []
    width_in = 3  # the first convolution takes a 3-channel input

    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            modules.append(nn.Conv2d(width_in, entry, kernel_size=3, padding=1))
            if batch_norm:
                modules.append(nn.BatchNorm2d(entry))
            modules.append(nn.ReLU(inplace=True))
            width_in = entry

    return nn.Sequential(*modules)


def vgg11_bn(num_classes):
    """VGG-11 (plan 'A') with batch normalisation."""
    features = make_layers(cfg['A'], batch_norm=True)
    return VGG(features, num_classes=num_classes)


def vgg13_bn(num_classes):
    """VGG-13 (plan 'B') with batch normalisation."""
    features = make_layers(cfg['B'], batch_norm=True)
    return VGG(features, num_classes=num_classes)


def vgg16_bn(num_classes):
    """VGG-16 (plan 'D') with batch normalisation."""
    features = make_layers(cfg['D'], batch_norm=True)
    return VGG(features, num_classes=num_classes)


def vgg19_bn(num_classes):
    """VGG-19 (plan 'E') with batch normalisation."""
    features = make_layers(cfg['E'], batch_norm=True)
    return VGG(features, num_classes=num_classes)
29 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 30 | |[-+]?\\.(?:inf|Inf|INF) 31 | |\\.(?:nan|NaN|NAN))$''', re.X), 32 | list(u'-+0123456789.')) 33 | with open(OPTIMIZER_CONFIG, 'r') as f: 34 | data = yaml.load(f, Loader=loader) 35 | return EasyDict(data) 36 | 37 | 38 | class OptimizerBuilder: 39 | def __init__(self) -> None: 40 | pass 41 | 42 | @staticmethod 43 | def load(optimizer_name: str, parameters: Iterator[nn.Parameter], lr: float) -> Tuple[optim.Optimizer, dict]: 44 | config = parse_optimizer_config() 45 | optimizer_param = config[optimizer_name] 46 | optimizer = globals()[optimizer_name](parameters=parameters, lr=lr, **optimizer_param) 47 | return optimizer, {optimizer_name: optimizer_param} 48 | 49 | 50 | if __name__ == '__main__': 51 | import torch.nn as nn 52 | model = nn.Linear(3, 2) 53 | optimizer, param = OptimizerBuilder.load('Adam', model.parameters(), 0.1) 54 | print(param) -------------------------------------------------------------------------------- /src/optimizer/optimizer_config.yml: -------------------------------------------------------------------------------- 1 | SGD: 2 | momentum: 0.9 3 | weight_decay: 5e-4 4 | dampening: 0 5 | nesterov: False 6 | 7 | Adam: 8 | betas: [0.9, 0.999] 9 | eps: 1e-8 10 | weight_decay: 0 11 | amsgrad: False -------------------------------------------------------------------------------- /src/optimizer/optimizers.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | 3 | __all__ = [ 4 | "SGD", "Adam" 5 | ] 6 | 7 | 8 | def SGD(parameters, lr, *args, **kwargs) -> optim.Optimizer: 9 | optimizer = optim.SGD(parameters, lr=lr, *args, **kwargs) 10 | return optimizer 11 | 12 | 13 | def Adam(parameters, lr, *args, **kwargs) -> optim.Optimizer: 14 | optimizer = optim.Adam(parameters, lr=lr, *args, **kwargs) 15 | return optimizer -------------------------------------------------------------------------------- /src/schemes/__init__.py: 
def constant_lr(optimizer, epochs, *args, **kwargs):
    """Scheduler that keeps the learning rate unchanged.

    Implemented as a ``StepLR`` with ``gamma=1.0``, so every step multiplies
    the rate by one. Only ``last_epoch`` is read from ``kwargs``.
    """
    last_epoch = kwargs.get("last_epoch", -1)
    # gamma is 1.0, so the step size has no effect on the rate; it only has
    # to be positive because StepLR divides by it.
    step_size = max(1, epochs)
    return lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=1.0, last_epoch=last_epoch)


def multi_step_lr(optimizer, epochs, *args, **kwargs):
    """``MultiStepLR`` scheduler.

    Reads ``milestones`` (default ``[60, 120, 160]``), ``gamma`` (default 0.1)
    and ``last_epoch`` (default -1) from ``kwargs``; ``epochs`` is unused.
    """
    milestones = kwargs.get("milestones", [60, 120, 160])
    gamma = kwargs.get("gamma", 0.1)
    last_epoch = kwargs.get("last_epoch", -1)
    return lr_scheduler.MultiStepLR(optimizer, milestones, gamma=gamma, last_epoch=last_epoch)


def cosine_annealing_lr(optimizer, epochs, *args, **kwargs):
    """``CosineAnnealingLR`` scheduler.

    Reads ``T_max`` (default ``epochs // 5``, at least 1), ``eta_min``
    (default 1e-6) and ``last_epoch`` (default -1) from ``kwargs``.
    """
    # epochs // 5 is 0 for fewer than five epochs, and the scheduler divides
    # by T_max when it steps.
    T_max = kwargs.get("T_max", max(1, epochs // 5))
    eta_min = kwargs.get("eta_min", 1e-6)
    last_epoch = kwargs.get("last_epoch", -1)
    return lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min, last_epoch=last_epoch)


def warmup_cosine_annealing_lr(optimizer, epochs, *args, **kwargs):
    """``CosineAnnealingWarmRestarts`` scheduler.

    Reads ``T_0``, ``T_mult`` (default 1), ``eta_min`` (default 1e-6) and
    ``last_epoch`` (default -1) from ``kwargs``.

    ``CosineAnnealingWarmRestarts`` only accepts a positive integer ``T_0``,
    so the previous default of 0 always raised. A missing, ``None`` or
    non-positive ``T_0`` now means "choose automatically" and becomes
    ``max(1, epochs // 5)``, the same default period that
    ``cosine_annealing_lr`` uses.
    """
    T_0 = kwargs.get("T_0", 0)
    if T_0 is None or T_0 <= 0:
        T_0 = max(1, epochs // 5)
    T_mult = kwargs.get("T_mult", 1)
    eta_min = kwargs.get("eta_min", 1e-6)
    last_epoch = kwargs.get("last_epoch", -1)
    return lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=T_0, T_mult=T_mult, eta_min=eta_min, last_epoch=last_epoch
    )
30 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 31 | |[-+]?\\.(?:inf|Inf|INF) 32 | |\\.(?:nan|NaN|NAN))$''', re.X), 33 | list(u'-+0123456789.')) 34 | with open(SCHEME_CONFIG, 'r') as f: 35 | data = yaml.load(f, Loader=loader) 36 | return EasyDict(data) 37 | 38 | 39 | class SchedulerBuilder: 40 | def __init__(self): 41 | pass 42 | 43 | @staticmethod 44 | def load(scheduler_func_name: str, optimizer: optim, max_epoch: int): 45 | scheduler_func = globals()[scheduler_func_name] 46 | config = parse_scheduler_config() 47 | scheduler_params = config[scheduler_func_name] 48 | scheduler = scheduler_func(optimizer, max_epoch, **scheduler_params) 49 | return scheduler, {scheduler_func_name: scheduler_params} 50 | 51 | 52 | if __name__ == "__main__": 53 | from torch import optim 54 | from torch import nn 55 | import matplotlib.pyplot as plt 56 | epochs = 100 57 | m = nn.Linear(3, 5) 58 | optimizer = optim.SGD(m.parameters(), lr=0.001) 59 | scheduler, scheduler_param = SchedulerBuilder.load("cosine_annealing_lr", optimizer, epochs) 60 | lrs = [] 61 | for i in range(epochs): 62 | scheduler.step() 63 | lr = scheduler.get_lr() 64 | lrs.append(lr) 65 | plt.plot(lrs) 66 | plt.show() -------------------------------------------------------------------------------- /src/schemes/scheme_config.yml: -------------------------------------------------------------------------------- 1 | constant_lr: 2 | step_size: -1 3 | gamma: 1.0 4 | last_epoch: -1 5 | 6 | multi_step_lr: 7 | milestones: [60, 120, 160] 8 | gamma: 0.1 9 | last_epoch: -1 10 | 11 | cosine_annealing_lr: 12 | T_max: 10 13 | eta_min: 1e-6 14 | last_epoch: -1 15 | 16 | warmup_cosine_annealing_lr: 17 | T_0: 0 18 | T_mult: 1 19 | eta_min: 1e-6 20 | last_epoch: -1 -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | 
class Trainer:
    """Runs training and validation epochs for one model.

    Args:
        config: experiment configuration; the ``train``, ``output`` and (for
            DDP) ``environment`` sections are read.
        model: network to train.
        wrapper: loss wrapper; its ``num_meter`` gives the number of loss terms.
        ioer: checkpoint writer, called once per validation epoch.
        ddp: True when running under ``torch.distributed``.
    """

    def __init__(self,
                 config: EasyDict,
                 model: nn.Module,
                 wrapper: LossWrapper,
                 ioer: NetIO,
                 ddp: bool = False,
                 *args, **kwargs
                 ) -> None:

        self.config = config
        self.loss_wrapper = wrapper
        self.ioer = ioer
        self.logger, self.summary = None, None
        self.ddp = ddp

        self.__parse_config()

        self.metric_func_list = self.__get_metrics()

        self.__num_loss = self.loss_wrapper.num_meter
        self.__num_metric = len(self.metric_func_list)

        self.model = model
        (self.optimizer, self.optimizer_params), (self.scheduler, self.scheduler_params) = self.__build_optimizer(self.model, self.lr)

        # The module that actually runs forward/backward. Under DDP this is
        # the DistributedDataParallel wrapper; it used to be created in a
        # local variable and dropped, so gradients were never synchronised.
        # `self.model` stays the object the caller passed in and is what gets
        # checkpointed.
        # NOTE(review): assumes Controller only calls the model; confirm it
        # needs no attribute that the DDP wrapper hides.
        self.__forward_model = self.__wrap_for_ddp(self.model)

        self.controller = Controller(loss_wrapper=self.loss_wrapper, model=self.__forward_model, optimizer=self.optimizer)
        if self.is_main_process():
            self.logger, self.summary = LoggerBuilder(config).load()
            if self.logger is not None:
                self.logger.info(self.optimizer_params)
                self.logger.info(self.scheduler_params)

        self.__global_step = 0
        print("optimizer: ", self.local_rank, self.optimizer)

    def __parse_config(self):
        """Copy the settings this class uses out of ``self.config``."""
        self.max_epoch = self.config.train["epochs"]
        self.lr = self.config.train["lr"]
        self.loss_names = self.config.train['criterion']['names']
        self.metric_names = self.config.train["metric"]["names"]
        self.key_metric_name = self.config.train["metric"]["key_metric_name"]

        self.log_step_freq = self.config.output["log_step_freq"]
        self.log_epoch_freq = self.config.output["log_epoch_freq"]

        # Single-process defaults: both attributes are read outside DDP too
        # (e.g. in the print calls), which raised AttributeError before.
        self.local_rank = 0
        self.num_gpu = 1
        if self.ddp:
            self.local_rank = self.config.environment["local_rank"]
            self.num_gpu = self.config.environment["num_gpu"]

    def __build_optimizer(self, model: nn.Module, lr: float, *args, **kwargs):
        """Create the optimizer and scheduler named in ``config.train``.

        Returns:
            ``((optimizer, optimizer_config), (scheduler, scheduler_config))``.
        """
        optimizer_name = self.config.train.optimizer
        scheduler_name = self.config.train.schedule

        optimizer, optimizer_config = OptimizerBuilder.load(optimizer_name, model.parameters(), lr)
        scheduler, scheduler_config = SchedulerBuilder.load(scheduler_name, optimizer, self.max_epoch)

        return (optimizer, optimizer_config), (scheduler, scheduler_config)

    def __wrap_for_ddp(self, model: nn.Module) -> nn.Module:
        """Return ``model`` wrapped in DistributedDataParallel when DDP is on.

        A model that is already wrapped, or any model outside DDP mode, is
        returned unchanged.
        """
        ddp_cls = torch.nn.parallel.DistributedDataParallel
        if not self.ddp or isinstance(model, ddp_cls):
            return model
        return ddp_cls(model, device_ids=[self.local_rank])

    def __get_metrics(self):
        """Resolve every configured metric name to its metric function."""
        return [MetricBuilder.load(metric_name) for metric_name in self.metric_names]

    def __to_device(self, data, target):
        """Move one batch to this process's GPU (the default GPU outside DDP)."""
        device = self.local_rank if self.ddp else None
        data = data.cuda(device, non_blocking=True)
        target = target.cuda(device, non_blocking=True)
        return data, target

    def __reduce_batch_stats(self, loss, loss_tuple, metrics):
        """Average the loss, every loss term and every metric over all processes."""
        loss = self.reduce_mean(loss, self.num_gpu)
        loss_tuple = tuple(self.reduce_mean(loss_each, self.num_gpu) for loss_each in loss_tuple)
        metrics = tuple(self.reduce_mean(metric_each, self.num_gpu) for metric_each in metrics)
        return loss, loss_tuple, metrics

    def train(self, epoch: int, dataloader: DataLoader):
        """Train for one epoch, then advance the scheduler by one step.

        Each batch must be a mapping with ``"image"`` and ``"label"`` entries.
        """
        if self.logger is not None and self.is_main_process():
            self.logger.info("Training epoch [{} / {}]".format(epoch, self.max_epoch))

        use_cuda = torch.cuda.is_available()
        self.__forward_model.train()

        # get_last_lr() reports the rate currently in effect; get_lr() is
        # only meant to be called from inside scheduler.step().
        current_lr = self.scheduler.get_last_lr()[0]
        print("lr: ", self.local_rank, current_lr)

        if self.summary is not None and self.is_main_process():
            self.summary.add_scalar("train/lr", current_lr, epoch)

        loss_recorder = AverageMeter(type='scalar', name='total loss')
        loss_list_recorder = AverageMeter(type='tuple', num_scalar=self.__num_loss, names=self.loss_names)
        metric_list_recorder = AverageMeter(type='tuple', num_scalar=self.__num_metric, names=self.metric_names)

        # === current epoch begins training ===
        for batch_idx, batch in enumerate(dataloader):
            data = batch["image"].float()
            target = batch["label"].long()
            batch_size = data.size(0)
            if use_cuda:
                data, target = self.__to_device(data, target)

            loss, loss_tuple, output_no_grad = self.controller.train_step(data, target)
            metrics = tuple(func(output_no_grad, target) for func in self.metric_func_list)

            if self.ddp:
                torch.distributed.barrier()
                loss, loss_tuple, metrics = self.__reduce_batch_stats(loss, loss_tuple, metrics)

            loss_recorder.update(loss.item(), batch_size)
            loss_list_recorder.update(loss_tuple, batch_size)
            metric_list_recorder.update(metrics, batch_size)

            if self.log_step_freq > 0 and self.__global_step % self.log_step_freq == 0:
                if self.logger and self.is_main_process():
                    msg = "[Train] Epoch:[{}/{}] batch:[{}/{}] loss: {:.4f} loss list: {} metric list: {}".format(epoch, self.max_epoch, batch_idx + 1, len(dataloader),
                                                                                        loss_recorder.get_value(), loss_list_recorder, metric_list_recorder)
                    self.logger.info(msg)

            self.__global_step += 1
        # === current epoch finishes training ===

        # Guarded like log_step_freq: a frequency of 0 disables epoch logging
        # and no longer divides by zero.
        if self.log_epoch_freq > 0 and epoch % self.log_epoch_freq == 0:
            if self.logger and self.is_main_process():
                msg = "[Train] Epoch:[{}/{}] loss: {:.4f} loss list: {} metric list: {}".format(epoch, self.max_epoch, loss_recorder.get_value(), loss_list_recorder, metric_list_recorder)
                self.logger.info(msg)
            if self.summary and self.is_main_process():
                self.summary.add_scalar("train/epoch_loss", loss_recorder.get_value(), epoch)
                names = metric_list_recorder.get_name()
                values = metric_list_recorder.get_value()
                for name, value in zip(names, values):
                    self.summary.add_scalar("train/epoch_{}".format(name), value, epoch)

        self.scheduler.step()

    def validate(self, epoch: int, dataloader: DataLoader):
        """Evaluate for one epoch and hand the result to the checkpoint writer."""
        self.__forward_model.eval()
        loss_recorder = AverageMeter(type="scalar", name='total loss')
        # Sized by the number of loss terms; it was sized by the number of
        # metrics, which is only correct when the two counts happen to match.
        loss_list_recorder = AverageMeter(type="tuple", num_scalar=self.__num_loss, names=self.loss_names)
        metric_list_recorder = AverageMeter(type='tuple', num_scalar=self.__num_metric, names=self.metric_names)
        use_cuda = torch.cuda.is_available()
        val_step = 0
        with torch.no_grad():
            # === current epoch begins validation ===
            for batch_idx, batch in enumerate(dataloader):
                data = batch["image"].float()
                target = batch["label"].long()
                batch_size = data.size(0)
                if use_cuda:
                    data, target = self.__to_device(data, target)
                loss, loss_tuple, output_no_grad = self.controller.validate_step(data, target)

                metrics = tuple(func(output_no_grad, target) for func in self.metric_func_list)

                if self.ddp:
                    loss, loss_tuple, metrics = self.__reduce_batch_stats(loss, loss_tuple, metrics)
                loss_recorder.update(loss.item(), batch_size)
                loss_list_recorder.update(loss_tuple, batch_size)
                metric_list_recorder.update(metrics, batch_size)

                if self.log_step_freq > 0 and val_step % self.log_step_freq == 0:
                    if self.logger and self.is_main_process():
                        msg = "[Validation] Epoch:[{}/{}] batch:[{}/{}] loss: {:.4f} loss list: {} metric list: {}".format(epoch, self.max_epoch, batch_idx + 1, len(dataloader),
                                                                                        loss_recorder.get_value(), loss_list_recorder, metric_list_recorder)
                        self.logger.info(msg)
                val_step += 1
            # === current epoch finishes validation ===

        if self.log_epoch_freq > 0 and epoch % self.log_epoch_freq == 0:
            if self.logger and self.is_main_process():
                msg = "[Validation] Epoch:[{}/{}] loss: {:.4f} loss list: {} metric list: {}".format(epoch, self.max_epoch, loss_recorder.get_value(), loss_list_recorder, metric_list_recorder)
                self.logger.info(msg)
            if self.summary and self.is_main_process():
                self.summary.add_scalar("val/epoch_loss", loss_recorder.get_value(), epoch)
                names = metric_list_recorder.get_name()
                values = metric_list_recorder.get_value()
                for name, value in zip(names, values):
                    self.summary.add_scalar("val/epoch_{}".format(name), value, epoch)

        # save checkpoint referring to the save_freq and the saving strategy, besides record the key metric value
        if self.ioer and self.is_main_process():
            self.ioer.save_file(self.model, epoch, metric_list_recorder.get_value_by_name(self.key_metric_name))

    def is_main_process(self):
        """True outside DDP, and for the process with local rank 0 under DDP."""
        return (not self.ddp) or (self.local_rank == 0)

    def reduce_mean(self, tensor, nprocs):
        """Return a copy of ``tensor`` summed over all processes and divided by ``nprocs``."""
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= nprocs
        return rt
12 | if not dist.is_available(): 13 | return 0 14 | if not dist.is_initialized(): 15 | return 0 16 | return dist.get_rank() 17 | 18 | def is_main_process(): 19 | return get_rank() == 0 20 | 21 | def synchronize(): 22 | """ 23 | Helper function to synchronize (barrier) among all processes when 24 | using distributed training 25 | """ 26 | if not dist.is_available(): 27 | return 28 | if not dist.is_initialized(): 29 | return 30 | world_size = dist.get_world_size() 31 | if world_size == 1: 32 | return 33 | dist.barrier() 34 | 35 | def reduce_value(value, average=True): 36 | world_size = get_world_size() 37 | if world_size < 2: 38 | return value 39 | with torch.no_grad(): 40 | dist.all_reduce(value) 41 | if average: 42 | value /= world_size 43 | return value -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from tensorboardX import SummaryWriter 5 | 6 | 7 | class LoggerBuilder: 8 | def __init__(self, config) -> None: 9 | self.config = config 10 | self.log_dir = os.path.join(config.output["log_dir"], config.output["save_dir"]) 11 | self.log_file = os.path.join(self.log_dir, "train.log") 12 | 13 | def load(self): 14 | logger = self.__load_logger() 15 | writer = self.__load_tensorboard() 16 | return logger, writer 17 | 18 | def __load_logger(self) -> logging.Logger: 19 | logger = logging.getLogger(self.config.name) 20 | logger.setLevel(logging.INFO) 21 | formatter = logging.Formatter( 22 | '%(asctime)s - %(name)s - %(funcName)s - %(lineno)d - %(message)s',datefmt='%Y-%m-%d %H:%M:%S' 23 | ) 24 | os.makedirs(self.log_dir, exist_ok=True) 25 | file_handler = logging.FileHandler(self.log_file) 26 | file_handler.setLevel(logging.INFO) 27 | file_handler.setFormatter(formatter) 28 | 29 | console_handler = logging.StreamHandler() 30 | console_handler.setLevel(logging.INFO) 31 | 
def _as_number(value):
    """Return ``value`` as a plain number.

    Objects exposing ``.item()`` are unwrapped through it; anything else
    (int, float) is returned unchanged. This lets the meters be printed
    regardless of which kind of value they were fed.
    """
    return value.item() if hasattr(value, "item") else value


class AverageScalarMeter(object):
    """Computes and stores the batch-size-weighted average of one scalar."""

    def __init__(self, *args, **kwargs):
        """Create an empty meter.

        Keyword Args:
            name: display name of the tracked quantity (required).

        Raises:
            KeyError: if ``name`` is not given.
        """
        if 'name' not in kwargs:
            raise KeyError("name should be specified")
        self.name = kwargs['name']
        self.__sum = 0
        self.__count = 0

    def update(self, val, batchsize=1):
        """Accumulate ``val`` weighted by ``batchsize``."""
        self.__sum += val * batchsize
        self.__count += batchsize

    def get_value(self):
        """Return the running average, or 0 if nothing was recorded yet."""
        if self.__count == 0:
            return 0
        return self.__sum / self.__count

    def get_value_by_name(self, name):
        """Return the running average; ``name`` is ignored for a scalar meter."""
        return self.get_value()

    def get_name(self):
        """Return the display name of this meter."""
        return self.name

    def __repr__(self) -> str:
        return "{}: {}".format(self.name, round(_as_number(self.get_value()), 4))


class AverageTupleMeter(object):
    """Tracks several named scalars at once, one AverageScalarMeter each."""

    def __init__(self, num_scalar: int, *args, **kwargs) -> None:
        """Create ``num_scalar`` empty meters.

        Args:
            num_scalar: number of scalars tracked in parallel.

        Keyword Args:
            names: sequence of display names, indexed 0..num_scalar-1 (required).

        Raises:
            KeyError: if ``names`` is not given.
        """
        assert isinstance(num_scalar, int)
        if 'names' not in kwargs:
            raise KeyError('names should be specified')
        self.names = kwargs['names']
        self.__num_scalar = num_scalar
        self.__meter_list = [AverageScalarMeter(name=self.names[i]) for i in range(self.__num_scalar)]

    def update(self, vals, batchsize=1):
        """Accumulate one value per tracked scalar, each weighted by ``batchsize``."""
        assert isinstance(vals, tuple), "vals is not tuple"
        assert len(vals) == self.__num_scalar
        for meter, val in zip(self.__meter_list, vals):
            meter.update(val, batchsize)

    def get_value(self):
        """Return the running averages as a tuple, in the order of ``names``."""
        return tuple(meter.get_value() for meter in self.__meter_list)

    def get_value_by_name(self, name):
        """Return the running average of the scalar called ``name``.

        Raises:
            ValueError: if ``name`` is not one of the configured names
                (when ``names`` is a list or tuple).
        """
        idx = self.names.index(name)
        return self.__meter_list[idx].get_value()

    def get_name(self):
        """Return the configured names."""
        return self.names

    def __repr__(self) -> str:
        # Values may be plain numbers (including the 0 of an empty meter) or
        # objects with .item(); _as_number handles both. The trailing space
        # after each entry is kept so log lines look the same as before.
        return "".join(
            "{}: {} ".format(name, round(_as_number(value), 4))
            for name, value in zip(self.names, self.get_value())
        )


class AverageMeter(object):
    """Facade that builds either a scalar meter or a tuple meter.

    ``type='scalar'`` requires the keyword ``name``; ``type='tuple'``
    requires ``num_scalar`` and the keyword ``names``.
    """

    def __init__(self, type: str="scalar", num_scalar: int=1, *args, **kwargs) -> None:
        assert type in ('scalar', 'tuple')
        assert isinstance(num_scalar, int)
        self.type = type
        self.num_scalar = num_scalar
        if self.type == 'scalar':
            self.meter = AverageScalarMeter(*args, **kwargs)
        else:
            self.meter = AverageTupleMeter(self.num_scalar, *args, **kwargs)

    def update(self, val, batchsize=1):
        """Forward the update to the wrapped meter (default weight 1, like the inner meters)."""
        self.meter.update(val, batchsize)

    def get_value(self):
        """Return the wrapped meter's running average(s)."""
        return self.meter.get_value()

    def get_value_by_name(self, name):
        """Return the running average registered under ``name``."""
        return self.meter.get_value_by_name(name)

    def get_name(self):
        """Return the wrapped meter's name (scalar) or names (tuple)."""
        return self.meter.get_name()

    def __repr__(self) -> str:
        return repr(self.meter)


if __name__ == '__main__':
    # The meters require explicit names; the demo supplies them.
    meter_scalar = AverageMeter(type='scalar', name='loss')
    meter_scalar.update(3.0, 2)
    meter_scalar.update(1.0, 3)
    meter_scalar.update(2.0, 2)
    meter_scalar.update(4.0, 1)
    print(meter_scalar.get_value())

    meter_tuple = AverageMeter(type='tuple', num_scalar=2, names=['loss', 'acc'])
    meter_tuple.update((3.0, 2), 2)
    meter_tuple.update((1.0, 1), 3)
    meter_tuple.update((2.0, 2), 2)
    meter_tuple.update((4.0, 3), 1)
    print(meter_tuple.get_value())
class NetIO:
    """Load model weights and save checkpoints as configured.

    Values read from ``config``:
        output["ckpt_root"], output["save_dir"]: joined to form the checkpoint directory.
        output["save_freq"]: <= 0 keeps a single rolling ``last.pth``;
            otherwise ``{epoch}.pth`` is written every ``save_freq`` epochs.
        model["strict"]: forwarded to ``load_state_dict``.
        train["metric"]["key_metric_name"]: key under which the metric is stored.
        train["metric"]["strategy"]: "max", "min" or "none"; decides when
            ``best.pth`` is refreshed.

    Raises:
        NotImplementedError: if the strategy is not one of the three above.
    """

    def __init__(self, config) -> None:
        self.config = config

        self.__save_dir = os.path.join(config.output["ckpt_root"], config.output["save_dir"])
        self.__save_freq = config.output["save_freq"]
        self.__strict = config.model["strict"]
        self.__key_metric = config.train["metric"]["key_metric_name"]
        self.__strategy = config.train["metric"]["strategy"]

        # Start from the worst possible score so the first epoch always wins.
        if self.__strategy == "max":
            self.__best_score = -float("inf")
        elif self.__strategy == "min":
            self.__best_score = float("inf")
        elif self.__strategy == "none":
            self.__best_score = None
        else:
            raise NotImplementedError("Strategy not defined")

        os.makedirs(self.__save_dir, exist_ok=True)

    def load_file(self, net: nn.Module, weight_path: str):
        """Load the ``"state_dict"`` entry of a checkpoint into ``net`` and return ``net``.

        The checkpoint is mapped to CPU first; strictness follows the config.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        serial = torch.load(weight_path, map_location="cpu")
        net.load_state_dict(serial["state_dict"], strict=self.__strict)
        return net

    def save_file(self, net: nn.Module, epoch: int, metric: float, *args, **kwargs):
        """Save a checkpoint for ``epoch`` and refresh ``best.pth`` on improvement.

        Args:
            net: model to save. ``DataParallel`` and
                ``DistributedDataParallel`` wrappers are unwrapped so the
                saved keys carry no ``module.`` prefix and can be loaded
                into a bare model.
            epoch: epoch number; stored in the checkpoint and used for the
                periodic file name.
            metric: key metric for this epoch. A value exposing ``.item()``
                is converted to a plain Python number before being stored
                and compared.
        """
        parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
        bare_net = net.module if isinstance(net, parallel_wrappers) else net
        state_dict = bare_net.state_dict()

        # Keep the checkpoint and the best score free of device tensors.
        if hasattr(metric, "item"):
            metric = metric.item()

        dic = {
            "state_dict": state_dict,
            "epoch": epoch,
            self.__key_metric: metric,
        }

        # Periodic / rolling checkpoint.
        if self.__save_freq <= 0:
            torch.save(dic, os.path.join(self.__save_dir, "last.pth"))
        elif epoch % self.__save_freq == 0:
            torch.save(dic, os.path.join(self.__save_dir, "{}.pth".format(epoch)))

        # Best checkpoint; with strategy "none" neither branch matches.
        improved = (
            (self.__strategy == "max" and metric > self.__best_score)
            or (self.__strategy == "min" and metric < self.__best_score)
        )
        if improved:
            torch.save(dic, os.path.join(self.__save_dir, "best.pth"))
            self.__best_score = metric

    def get_best_score(self):
        """Return the best key-metric value seen so far (None for strategy "none")."""
        return self.__best_score
--------------------------------------------------------------------------------