├── .gitignore ├── LICENSE ├── README.md ├── distribute.sh ├── docs ├── config.md └── tutorial.md ├── projects ├── cifar10_demo │   ├── ReadMe.md │   ├── cifar10.yaml │   └── main.py └── fake_demo │   ├── ReadMe.md │   ├── fake_cfg.yaml │   └── main.py ├── release.md ├── requirements.txt ├── setup.py ├── tests ├── test_model.py └── test_transforms.py └── torchline ├── __init__.py ├── config ├── __init__.py ├── config.py └── default.py ├── data ├── __init__.py ├── albumentaion_transforms.py ├── autoaugment.py ├── build.py ├── common_datasets.py ├── sampler.py └── transforms.py ├── engine ├── __init__.py ├── build.py ├── default_module.py └── utils.py ├── losses ├── __init__.py ├── build.py ├── focal_loss.py └── loss.py ├── models ├── __init__.py ├── build.py ├── dpn.py ├── efficientnet.py ├── mnasnet.py ├── mobilenet.py ├── nas_models.py ├── pnasnet.py └── resnet.py ├── trainer ├── __init__.py ├── build.py └── default_trainer.py └── utils ├── __init__.py ├── average_meter.py ├── logger.py ├── registry.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | dist 3 | build 4 | torchline.egg-info 5 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 marsggbo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchline v0.3.0.4 2 | 3 | > Easy-to-use PyTorch 4 | > 5 | > Only one configuration file is enough! 6 | > 7 | > You can change anything you want in just one configuration file. 8 | 9 | # Dependencies 10 | 11 | - Python>=3.6 12 | - PyTorch>=1.3.1 13 | - torchvision>=0.4.0,<0.5.0 14 | - yacs==0.1.6 15 | - pytorch-lightning<=0.7.6 16 | 17 | 18 | # Install 19 | 20 | - Before you install `torchline`, please make sure you have installed the above libraries. 21 | - You can use `torchline` on both Linux and Windows.
22 | 23 | ```bash 24 | pip install torchline 25 | ``` 26 | 27 | # Run demo 28 | 29 | ## Train a model on GPU 0 and GPU 1 30 | ```bash 31 | cd projects/cifar10_demo 32 | python main.py --config_file cifar10.yaml trainer.gpus [0,1] 33 | ``` 34 | 35 | ## Debug: add the command-line option `trainer.fast_dev_run True` 36 | ```bash 37 | cd projects/cifar10_demo 38 | python main.py --config_file cifar10.yaml trainer.gpus [0] trainer.fast_dev_run True 39 | ``` 40 | 41 | The CIFAR10 demo uses ResNet50, which was trained for 72 epochs and achieved its best result (94.39% validation accuracy) at epoch 54. 42 | 43 | # Thanks 44 | 45 | - [AutoML](https://zhuanlan.zhihu.com/automl) 46 | - [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning) 47 | 48 | -------------------------------------------------------------------------------- /distribute.sh: -------------------------------------------------------------------------------- 1 | rm -rf build dist torchline.egg-info 2 | python setup.py sdist bdist_wheel 3 | if (($1==1));then 4 | python -m twine upload dist/* 5 | elif (($1==2));then 6 | python -m twine upload --repository-url https://test.pypi.org/legacy/ dist/* 7 | else 8 | echo "Wrong argument, only 1 or 2 is supported" 9 | fi -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | 2 | # 1. Input 3 | 4 | ## 1.1 `input` 5 | 6 | 7 | - `size`: [112,112] # input image size 8 | 9 | # 2. Dataset 10 | 11 | ## 2.1 `dataset` 12 | 13 | ```yaml 14 | batch_size: 1 15 | dir: './datasets/mydataset' 16 | is_train: False 17 | name: 'fakedata' 18 | test_list: './datasets/test_list.txt' 19 | train_list: './datasets/train_list.txt' 20 | valid_list: './datasets/valid_list.txt' 21 | ``` 22 | 23 | ## 2.2 `dataloader` 24 | 25 | ```yaml 26 | num_workers: 0 27 | sample_test: 'default' 28 | sample_train: 'default' 29 | ``` 30 | 31 | ## 2.3 Data augmentation 32 | 33 | ### 2.3.1 Image augmentation 34 | 35 | #### 2.3.1.1 Based on torchvision 36 | ```yaml 37 | transforms: 38 | img: 39 | aug_cifar: False 40 | aug_imagenet: False 41 | center_crop: 42 | enable: 0 43 | color_jitter: 44 | brightness: 0.0 45 | contrast: 0.0 46 | enable: 0 47 | hue: 0.0 48 | saturation: 0.0 49 | random_crop: 50 | enable: 1 51 | padding: 4 52 | random_horizontal_flip: 53 | enable: 1 54 | p: 0.5 55 | random_resized_crop: 56 | enable: 0 57 | ratio: (0.75, 1.3333333333333333) 58 | scale: (0.5, 1.0) 59 | random_rotation: 60 | degrees: 10 61 | enable: 0 62 | random_vertical_flip: 63 | enable: 0 64 | p: 0.5 65 | resize: 66 | enable: 0 67 | name: 'DefaultTransforms' 68 | tensor: 69 | normalization: 70 | mean: [0.4914, 0.4822, 0.4465] 71 | std: [0.2023, 0.1994, 0.201] 72 | random_erasing: 73 | enable: 0 74 | p: 0.5 75 | ratio: ((0.3, 3.3),) 76 | scale: (0.02, 0.3) 77 | ``` 78 | 79 | #### 2.3.1.2 Based on albumentations 80 | 81 | 82 | `abtfs`: albumentations transforms 83 | ```yaml 84 | abtfs: 85 | random_grid_shuffle: 86 | enable: 0 87 | grid: 2 88 | channel_shuffle: 89 | enable: 0 90 | channel_dropout: 91 | enable: 0 92 | drop_range: (1, 1) 93 | fill_value: 127 94 | noise: 95 | enable: 1 96 | blur: 97 | enable: 0 98 | rotate: 99 | enable: 1 100 | bright: 101 | enable: 1 102 | distortion: 103 | enable: 0 104 | hue: 105 | enable: 0 106 | cutout: 107 | enable: 1 108 | num_holes: 10 109 | size: 20 110 | fill_value: 127 111 | ``` 112 | 113 | ### 2.3.2 Label transforms 114 | 115 | ```yaml 116 | label_transforms: 117 | name: 'default' 118 | ``` 119 |
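The pipeline selected by `transforms.name` can be built directly from a config object. A minimal sketch of switching to the albumentations pipeline from Python, mirroring `tests/test_transforms.py` (the particular toggles are illustrative):

```python
import torchline as tl

cfg = tl.config.get_cfg()
cfg.dataset.is_train = True                       # build the training-time pipeline
cfg.transforms.name = 'AlbumentationsTransforms'  # or 'DefaultTransforms' for torchvision
cfg.abtfs.cutout.enable = 1                       # toggle individual augmentations
transforms = tl.data.build_transforms(cfg)
```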
120 | # 3. Model 121 | 122 | 123 | ```yaml 124 | model: 125 | classes: 10 126 | name: 'FakeNet' 127 | pretrained: True 128 | # features: ['f4'] 129 | # features_fusion: 'sum' 130 | # finetune: False 131 | ``` 132 | 133 | # 4. Loss function 134 | 135 | ```yaml 136 | loss: 137 | class_weight: [] 138 | focal_loss: 139 | alpha: [] 140 | gamma: 2 141 | size_average: True 142 | label_smoothing: 0.1 143 | name: 'CrossEntropy' 144 | ``` 145 | 146 | # 5. Optimizer and learning-rate scheduler 147 | 148 | ```yaml 149 | optim: 150 | base_lr: 0.1 151 | momentum: 0.9 152 | name: 'sgd' 153 | scheduler: 154 | gamma: 0.1 155 | milestones: [150, 250] 156 | mode: 'min' 157 | name: 'MultiStepLR' 158 | patience: 10 159 | step_size: 10 160 | t_max: 10 161 | verbose: True 162 | weight_decay: 0.0005 163 | ``` 164 | 165 | # 6. Engine `module` 166 | 167 | ```yaml 168 | module: 169 | name: 'DefaultModule' 170 | ``` 171 | 172 | # 7. Trainer 173 | 174 | ```yaml 175 | trainer: 176 | accumulate_grad_batches: 1 177 | amp_level: 'O1' 178 | check_val_every_n_epoch: 1 179 | default_root_dir: './output_fakedata' 180 | distributed_backend: 'dp' 181 | fast_dev_run: False 182 | gpus: [] 183 | gradient_clip_val: 0 184 | log_gpu_memory: '' 185 | log_save_interval: 100 186 | logger: 187 | mlflow: 188 | experiment_name: 'torchline_logs' 189 | tracking_uri: './output' 190 | setting: 0 191 | test_tube: 192 | name: 'torchline_logs' 193 | save_dir: './output_fakedata' 194 | version: -1 195 | type: 'test_tube' 196 | max_epochs: 100 197 | min_epochs: 1 198 | name: 'DefaultTrainer' 199 | num_nodes: 1 200 | num_sanity_val_steps: 5 201 | overfit_pct: 0.0 202 | print_nan_grads: True 203 | process_position: 0 204 | resume_from_checkpoint: '' 205 | row_log_interval: 10 206 | show_progress_bar: False 207 | test_percent_check: 1.0 208 | track_grad_norm: -1 209 | train_percent_check: 1.0 210 | truncated_bptt_steps: '' 211 | use_amp: False 212 | val_check_interval: 1.0 213 | val_percent_check: 1.0 214 | weights_save_path: '' 215 | weights_summary: '' 216 | ``` 217 | 218 | # Miscellaneous 219 | 220 | ```yaml 221 | VERSION: 1 222 | DEFAULT_CUDNN_BENCHMARK: True 223 | SEED: 666 224 | topk: [1, 3] 225 | ```
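All of the options above live in a single yacs `CfgNode`. A short sketch of how a project typically consumes them, using the standard yacs override API (the file name and option values are illustrative):

```python
from torchline.config import get_cfg

cfg = get_cfg()                                  # defaults from torchline/config/default.py
cfg.merge_from_file('cifar10.yaml')              # override with a project config file
cfg.merge_from_list(['trainer.gpus', '[0,1]'])   # override single options, e.g. from the CLI
```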
-------------------------------------------------------------------------------- /docs/tutorial.md: -------------------------------------------------------------------------------- 1 | # Code structure 2 | 3 | 4 | The file structure follows the design of [detectron2](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=1&cad=rja&uact=8&ved=2ahUKEwiux_PXpLDmAhVOPnAKHVTjDVEQFjAAegQIBxAC&url=https%3A%2F%2Fgithub.com%2Ffacebookresearch%2Fdetectron2&usg=AOvVaw25FixXG7GH7dRKY6sOc2Oc), using a registry mechanism to switch flexibly between different modules. 5 | 6 | - `torchline` 7 | - `config`: Configuration Module 8 | - `config.py`: returns the base `CfgNode` 9 | - `default.py`: the default parameter settings; specific parameters can later be overridden by importing a `.yaml` file 10 | - `data`: dataset module, returns a `torch.utils.data.Dataset` 11 | - `build.py`: registers datasets; provides the `build_data` function to obtain different datasets, and the dataset registry 12 | - dataset files: 13 | - `common_datasets.py`: returns the MNIST and CIFAR10 datasets 14 | - transform files: 15 | - `transforms.py`: provides the `build_transforms` function to construct `transforms`; this file also contains the default transform class, `DefaultTransforms` 16 | - `autoaugment.py`: the automatic augmentation operations (AutoAugment) proposed by Google 17 | - `albumentation_transforms`: data augmentation based on the albumentations library 18 | - `engine`: 19 | - `default_module.py`: provides a template subclass of `LightningModule` 20 | - `losses`: 21 | - `build.py`: provides the `build_loss_fn` function and the loss registry 22 | - `loss.py`: provides some commonly used loss functions (`CrossEntropy()`) 23 | - `models`: 24 | - `build.py`: provides the `build_model` function and the model registry 25 | - `trainer`: 26 | - `utils.py`: 27 | - `registry.py`: the registry template 28 | - `logger.py`: the logging module 29 | - `main.py`: the entry point; later projects can be written with this file as a reference 30 | - `projects`: You can create your own project here. 31 | - `cifar10_demo`: A CIFAR10 demo project 32 | - `fake_demo`: uses randomly generated data, convenient for debugging and for trying out torchline 33 | 34 | 35 | # How to customize? 36 | 37 | To be completed... (a minimal example of registering a custom model is sketched at the end of this page) 38 | 39 | # Customizing the parameter configuration 40 | 41 | # Customizing the dataset 42 | 43 | # Customizing the `engine` template 44 | 45 | > You can customize how the **data is loaded**, the **optimizer settings**, and the **forward step** 46 | 47 | 48 | # Customizing a new `model` 49 | 50 | 51 | # Customizing the loss function 52 | 53 | # Customizing the `trainer`
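Until the sections above are filled in, the registry pattern used in `projects/fake_demo/main.py` is the best reference for most of them. A minimal sketch of registering and selecting a custom model (`TinyNet` is a hypothetical name):

```python
import torch
from torchline.config import get_cfg
from torchline.models import META_ARCH_REGISTRY, build_model

@META_ARCH_REGISTRY.register()
class TinyNet(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.feat = torch.nn.Conv2d(3, 16, 3, 1, 1)
        self.clf = torch.nn.Linear(16, cfg.model.classes)

    def forward(self, x):
        out = torch.nn.AdaptiveAvgPool2d(1)(self.feat(x)).flatten(1)
        return self.clf(out)

cfg = get_cfg()
cfg.model.name = 'TinyNet'   # build_model looks this name up in the registry
model = build_model(cfg)
```

Custom datasets, losses, and engine modules follow the same register-then-select-by-name pattern with their respective registries.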
-------------------------------------------------------------------------------- /projects/cifar10_demo/ReadMe.md: -------------------------------------------------------------------------------- 1 | 2 | # Run 3 | 4 | ```bash 5 | cd projects/cifar10_demo 6 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 7 | ``` 8 | 9 | # Restore training 10 | 11 | Just specify the `trainer.logger` settings. For example, suppose the log information of your previous experiment (metrics, checkpoints, etc.) was saved as follows: 12 | 13 | ```bash 14 | |___output 15 | |___lightning_log # specified by trainer.logger.test_tube.name 16 | |___version_0 # specified by trainer.logger.test_tube.version 17 | |___metrics.csv 18 | |___...(other log files) 19 | |___checkpoint 20 | |___ _checkpoint_epoch_60.ckpt 21 | ``` 22 | 23 | - `trainer.logger.setting`: 0 means the default settings, 1 means no logger, 2 means a custom logger 24 | - `trainer.logger.test_tube.name`: the logger name, e.g. lightning_log 25 | - `trainer.logger.test_tube.version`: the logger version, e.g. 0 26 | 27 | 28 | ```bash 29 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 trainer.logger.setting 2 trainer.logger.test_tube.name lightning_log trainer.logger.test_tube.version 0 30 | ``` 31 | 32 | # test_only 33 | 34 | Run evaluation only; the parameters are set as above 35 | ```bash 36 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 --test_only trainer.logger.setting 2 trainer.logger.test_tube.name lightning_log trainer.logger.test_tube.version 0 37 | ``` 38 | 39 | # predict_only 40 | 41 | Predict the images under a given path; the following two parameters must be set: 42 | 43 | - `predict_only.load_ckpt.checkpoint_path`: the checkpoint path 44 | - `predict_only.to_pred_file_path`: the path of the images to predict; it can be the path of a single image or of a directory containing multiple images 45 | 46 | ```bash 47 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 --predict_only predict_only.load_ckpt.checkpoint_path './output_cifar10/lightning_logs/version_0/checkpoints/_ckpt_epoch_69.ckpt' predict_only.to_pred_file_path '.' 48 | ``` -------------------------------------------------------------------------------- /projects/cifar10_demo/cifar10.yaml: -------------------------------------------------------------------------------- 1 | input: 2 | size: (32, 32) 3 | model: 4 | name: 'Resnet50' 5 | classes: 10 6 | dataset: 7 | name: 'CIFAR10' 8 | dir: './datasets/cifar10' 9 | batch_size: 128 10 | optim: 11 | name: 'sgd' 12 | base_lr: 0.1 13 | scheduler: 14 | name: 'MultiStepLR' 15 | milestones: [150, 250] 16 | gamma: 0.1 17 | transforms: 18 | tensor: 19 | normalization: 20 | mean: [0.4914, 0.4822, 0.4465] 21 | std: [0.2023, 0.1994, 0.2010] 22 | img: 23 | aug_cifar: False 24 | resize: 25 | enable: 0 26 | random_crop: 27 | enable: 1 28 | padding: 4 29 | random_horizontal_flip: 30 | enable: 1 31 | p: 0.5 32 | random_vertical_flip: 33 | enable: 0 34 | p: 0.5 35 | random_rotation: 36 | enable: 0 37 | trainer: 38 | default_root_dir: './output_cifar10' 39 | logger: 40 | setting: 0 41 | type: 'test_tube' 42 | mlflow: 43 | tracking_uri: './output_cifar10' 44 | test_tube: 45 | save_dir: './output_cifar10' 46 | name: 'torchline_logs' 47 | version: -1 # if <0, then use default version 48 | hooks: 49 | early_stopping: 50 | setting: 2 51 | patience: 20 52 | monitor: 'valid_acc_1' 53 | mode: 'max' 54 | model_checkpoint: 55 | setting: 2 56 | monitor: 'valid_acc_1' 57 | mode: 'max' 58 | filepath: '' 59 | topk: [1, 3] 60 | predict_only: 61 | type: 'ckpt' 62 | to_pred_file_path: '' # specify the path of images 63 | load_ckpt: 64 | checkpoint_path: '' # load_from_checkpoint 65 | load_metric: 66 | weights_path: '' # load_from_metrics 67 | tags_csv: '' 68 | on_gpu: True 69 | map_location: 'cuda:0' 70 | -------------------------------------------------------------------------------- /projects/cifar10_demo/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runs a model on a single node across N-gpus.
3 | """ 4 | import argparse 5 | import glob 6 | import os 7 | import shutil 8 | from argparse import ArgumentParser 9 | 10 | import numpy as np 11 | import torch 12 | from torchline.config import get_cfg 13 | from torchline.engine import build_module 14 | from torchline.trainer import build_trainer 15 | from torchline.utils import get_imgs_to_predict 16 | 17 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 18 | 19 | def main(hparams): 20 | """ 21 | Main training routine specific for this project 22 | :param hparams: 23 | """ 24 | 25 | cfg = get_cfg() 26 | cfg.setup_cfg_with_hparams(hparams) 27 | # only predict on some samples 28 | if hasattr(hparams, "predict_only") and hparams.predict_only: 29 | predict_only = cfg.predict_only 30 | if predict_only.type == 'ckpt': 31 | load_params = {key: predict_only.load_ckpt[key] for key in predict_only.load_ckpt} 32 | model = build_module(cfg) 33 | ckpt_path = load_params['checkpoint_path'] 34 | model.load_state_dict(torch.load(ckpt_path)['state_dict']) 35 | elif predict_only.type == 'metrics': 36 | load_params = {key: predict_only.load_metric[key] for key in predict_only.load_metric} 37 | model = build_module(cfg).load_from_metrics(**load_params) 38 | else: 39 | print(f'{cfg.predict_only.type} not supported') 40 | raise NotImplementedError 41 | 42 | model.eval() 43 | model.freeze() 44 | images = get_imgs_to_predict(cfg.predict_only.to_pred_file_path, cfg) 45 | if torch.cuda.is_available(): 46 | images['img_data'] = images['img_data'].cuda() 47 | model = model.cuda() 48 | predictions = model(images['img_data']) 49 | class_indices = torch.argmax(predictions, dim=1) 50 | for i, file in enumerate(images['img_file']): 51 | index = class_indices[i] 52 | print(f"{file} is {classes[index]}") 53 | return predictions.cpu() 54 | elif hasattr(hparams, "test_only") and hparams.test_only: 55 | model = build_module(cfg) 56 | trainer = build_trainer(cfg, hparams) 57 | trainer.test(model) 58 | else: 59 | model = build_module(cfg) 60 | trainer = build_trainer(cfg, hparams) 61 | trainer.fit(model) 62 | 63 | 64 | if __name__ == '__main__': 65 | # ------------------------ 66 | # TRAINING ARGUMENTS 67 | # ------------------------ 68 | # these are project-wide arguments 69 | 70 | root_dir = os.path.dirname(os.path.realpath(__file__)) 71 | parent_parser = ArgumentParser(add_help=False) 72 | 73 | # gpu args 74 | parent_parser.add_argument("--config_file", default="", metavar="FILE", help="path to config file") 75 | parent_parser.add_argument('--test_only', action='store_true', help='if true, return trainer.test(model). Validates only the test set') 76 | parent_parser.add_argument('--predict_only', action='store_true', help='if true run model(samples). 
Predict on the given samples.') 77 | parent_parser.add_argument("opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER) 78 | 79 | # each LightningModule defines arguments relevant to it 80 | hparams = parent_parser.parse_args() 81 | assert not (hparams.test_only and hparams.predict_only), "You can't set both 'test_only' and 'predict_only' True" 82 | 83 | # --------------------- 84 | # RUN TRAINING 85 | # --------------------- 86 | main(hparams) 87 | -------------------------------------------------------------------------------- /projects/fake_demo/ReadMe.md: -------------------------------------------------------------------------------- 1 | 2 | # Run 3 | 4 | ```bash 5 | cd projects/fake_demo 6 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file fake_cfg.yaml --gpus 1 7 | ``` 8 | 9 | # Restore training 10 | 11 | Just specify the `trainer.logger` settings. For example, suppose the log information of your previous experiment (metrics, checkpoints, etc.) was saved as follows: 12 | 13 | ```bash 14 | |___output 15 | |___lightning_log # specified by trainer.logger.test_tube.name 16 | |___version_0 # specified by trainer.logger.test_tube.version 17 | |___metrics.csv 18 | |___...(other log files) 19 | |___checkpoint 20 | |___ _checkpoint_epoch_60.ckpt 21 | ``` 22 | 23 | - `trainer.logger.setting`: 0 means the default settings, 1 means no logger, 2 means a custom logger 24 | - `trainer.logger.test_tube.name`: the logger name, e.g. lightning_log 25 | - `trainer.logger.test_tube.version`: the logger version, e.g. 0 26 | 27 | 28 | ```bash 29 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file fake_cfg.yaml --gpus 1 trainer.logger.setting 2 trainer.logger.test_tube.name lightning_log trainer.logger.test_tube.version 0 30 | ``` 31 | 32 | # test_only 33 | 34 | Run evaluation only; the parameters are set as above 35 | ```bash 36 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file fake_cfg.yaml --gpus 1 --test_only trainer.logger.setting 2 trainer.logger.test_tube.name lightning_log trainer.logger.test_tube.version 0 37 | ``` 38 | 39 | # predict_only 40 | 41 | Predict the images under a given path; the following two parameters must be set: 42 | 43 | - `predict_only.load_ckpt.checkpoint_path`: the checkpoint path 44 | - `predict_only.to_pred_file_path`: the path of the images to predict; it can be the path of a single image or of a directory containing multiple images 45 | 46 | ```bash 47 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file fake_cfg.yaml --gpus 1 --predict_only predict_only.load_ckpt.checkpoint_path './output/torchline_logs/version_0/checkpoints/_ckpt_epoch_69.ckpt' predict_only.to_pred_file_path '.'
48 | ``` -------------------------------------------------------------------------------- /projects/fake_demo/fake_cfg.yaml: -------------------------------------------------------------------------------- 1 | DEFAULT_CUDNN_BENCHMARK: True 2 | SEED: 666 3 | VERSION: 1 4 | abtfs: 5 | blur: 6 | blur_limit: 3 7 | enable: 0 8 | bright: 9 | clip_limit: 1 10 | enable: 1 11 | channel_dropout: 12 | drop_range: (1, 1) 13 | enable: 0 14 | fill_value: 127 15 | channel_shuffle: 16 | enable: 0 17 | cutout: 18 | enable: 1 19 | fill_value: 127 20 | num_holes: 10 21 | size: 20 22 | distortion: 23 | enable: 0 24 | hue: 25 | enable: 0 26 | noise: 27 | enable: 1 28 | random_grid_shuffle: 29 | enable: 0 30 | grid: 2 31 | rotate: 32 | enable: 1 33 | p: 1 34 | rotate_limit: 45 35 | scale_limit: 0.2 36 | shift_limit: 0.0625 37 | dataloader: 38 | num_workers: 0 39 | sample_test: 'default' 40 | sample_train: 'default' 41 | dataset: 42 | batch_size: 1 43 | dir: '' 44 | is_train: False 45 | name: 'fakedata' 46 | test_list: '' 47 | train_list: '' 48 | valid_list: '' 49 | hooks: 50 | early_stopping: 51 | min_delta: 0.0 52 | mode: 'max' 53 | monitor: 'valid_acc_1' 54 | patience: 20 55 | setting: 2 56 | verbose: 1 57 | model_checkpoint: 58 | filepath: '' 59 | mode: 'max' 60 | monitor: 'valid_acc_1' 61 | period: 1 62 | prefix: '' 63 | save_top_k: 1 64 | save_weights_only: False 65 | setting: 2 66 | verbose: 1 67 | input: 68 | size: (16, 16) 69 | label_transforms: 70 | name: 'default' 71 | log: 72 | name: '' 73 | path: '' 74 | loss: 75 | class_weight: [] 76 | focal_loss: 77 | alpha: [] 78 | gamma: 2 79 | size_average: True 80 | label_smoothing: 0.1 81 | name: 'CrossEntropy' 82 | model: 83 | classes: 10 84 | finetune: False 85 | name: 'FakeNet' 86 | pretrained: True 87 | module: 88 | name: 'DefaultModule' 89 | optim: 90 | base_lr: 0.1 91 | momentum: 0.9 92 | name: 'sgd' 93 | scheduler: 94 | gamma: 0.1 95 | milestones: [150, 250] 96 | mode: 'min' 97 | name: 'MultiStepLR' 98 | patience: 10 99 | step_size: 10 100 | t_0: 5 101 | t_max: 10 102 | t_mul: 20 103 | verbose: True 104 | weight_decay: 0.0005 105 | predict_only: 106 | load_ckpt: 107 | checkpoint_path: '' 108 | load_metric: 109 | map_location: '' 110 | on_gpu: True 111 | tags_csv: '' 112 | weights_path: '' 113 | to_pred_file_path: '' 114 | type: 'ckpt' 115 | topk: [1, 3] 116 | trainer: 117 | accumulate_grad_batches: 1 118 | auto_lr_find: False 119 | auto_scale_batch_size: False 120 | auto_select_gpus: False 121 | benchmark: False 122 | check_val_every_n_epoch: 1 123 | default_root_dir: './output' 124 | distributed_backend: 'dp' 125 | fast_dev_run: False 126 | gpus: [] 127 | gradient_clip_val: 0 128 | log_gpu_memory: '' 129 | log_save_interval: 100 130 | logger: 131 | mlflow: 132 | experiment_name: 'torchline_logs' 133 | tracking_uri: './output' 134 | setting: 2 135 | test_tube: 136 | name: 'torchline_logs' 137 | save_dir: './output' 138 | version: -1 139 | type: 'test_tube' 140 | max_epochs: 100 141 | max_steps: 1000 142 | min_epochs: 10 143 | min_steps: 400 144 | name: 'DefaultTrainer' 145 | num_nodes: 1 146 | num_sanity_val_steps: 5 147 | num_tpu_cores: '' 148 | overfit_pct: 0.0 149 | precision: 32 150 | print_nan_grads: True 151 | process_position: 0 152 | progress_bar_refresh_rate: 0 153 | reload_dataloaders_every_epoch: False 154 | replace_sampler_ddp: True 155 | resume_from_checkpoint: '' 156 | row_log_interval: 10 157 | show_progress_bar: False 158 | terminate_on_nan: False 159 | test_percent_check: 1.0 160 | track_grad_norm: -1 161 | train_percent_check: 1.0 162 
| truncated_bptt_steps: '' 163 | val_check_interval: 1.0 164 | val_percent_check: 1.0 165 | weights_save_path: '' 166 | weights_summary: '' 167 | transforms: 168 | img: 169 | aug_cifar: False 170 | aug_imagenet: False 171 | center_crop: 172 | enable: 0 173 | color_jitter: 174 | brightness: 0.0 175 | contrast: 0.0 176 | enable: 0 177 | hue: 0.0 178 | saturation: 0.0 179 | random_crop: 180 | enable: 1 181 | padding: 4 182 | random_horizontal_flip: 183 | enable: 1 184 | p: 0.5 185 | random_resized_crop: 186 | enable: 0 187 | ratio: (0.75, 1.3333333333333333) 188 | scale: (0.5, 1.0) 189 | random_rotation: 190 | degrees: 10 191 | enable: 0 192 | random_vertical_flip: 193 | enable: 0 194 | p: 0.5 195 | resize: 196 | enable: 0 197 | name: 'DefaultTransforms' 198 | tensor: 199 | normalization: 200 | mean: [0.4914, 0.4822, 0.4465] 201 | std: [0.2023, 0.1994, 0.201] 202 | random_erasing: 203 | enable: 0 204 | p: 0.5 205 | ratio: ((0.3, 3.3),) 206 | scale: (0.02, 0.3) -------------------------------------------------------------------------------- /projects/fake_demo/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runs a model on a single node across N-gpus. 3 | """ 4 | import argparse 5 | import glob 6 | import os 7 | import shutil 8 | import sys 9 | sys.path.append('../..') 10 | from argparse import ArgumentParser 11 | 12 | import numpy as np 13 | import torch 14 | 15 | from torchline.models import META_ARCH_REGISTRY 16 | from torchline.trainer import build_trainer 17 | from torchline.config import get_cfg 18 | from torchline.engine import build_module 19 | from torchline.utils import Logger, get_imgs_to_predict 20 | 21 | @META_ARCH_REGISTRY.register() 22 | class FakeNet(torch.nn.Module): 23 | def __init__(self, cfg): 24 | super(FakeNet, self).__init__() 25 | self.cfg = cfg 26 | h, w = cfg.input.size 27 | self.feat = torch.nn.Conv2d(3,16,3,1,1) 28 | self.clf = torch.nn.Linear(16, cfg.model.classes) 29 | print(self) 30 | 31 | def forward(self, x): 32 | b = x.shape[0] 33 | out = self.feat(x) 34 | out = torch.nn.AdaptiveAvgPool2d(1)(out).view(b, -1) 35 | out = self.clf(out) 36 | out = torch.softmax(out, dim=1) 37 | return out 38 | 39 | def main(hparams): 40 | """ 41 | Main training routine specific for this project 42 | :param hparams: 43 | """ 44 | 45 | cfg = get_cfg() 46 | cfg.setup_cfg_with_hparams(hparams) 47 | if hasattr(hparams, "test_only") and hparams.test_only: 48 | model = build_module(cfg) 49 | trainer = build_trainer(cfg, hparams) 50 | trainer.test(model) 51 | else: 52 | model = build_module(cfg) 53 | trainer = build_trainer(cfg, hparams) 54 | trainer.fit(model) 55 | 56 | 57 | if __name__ == '__main__': 58 | # ------------------------ 59 | # TRAINING ARGUMENTS 60 | # ------------------------ 61 | # these are project-wide arguments 62 | 63 | root_dir = os.path.dirname(os.path.realpath(__file__)) 64 | parent_parser = ArgumentParser(add_help=False) 65 | 66 | # gpu args 67 | parent_parser.add_argument("--config_file", default="./fake_cfg.yaml", metavar="FILE", help="path to config file") 68 | parent_parser.add_argument('--test_only', action='store_true', help='if true, return trainer.test(model). Validates only the test set') 69 | parent_parser.add_argument('--predict_only', action='store_true', help='if true run model(samples). 
Predict on the given samples.') 70 | parent_parser.add_argument("opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER) 71 | 72 | # each LightningModule defines arguments relevant to it 73 | hparams = parent_parser.parse_args() 74 | assert not (hparams.test_only and hparams.predict_only), "You can't set both 'test_only' and 'predict_only' True" 75 | 76 | # --------------------- 77 | # RUN TRAINING 78 | # --------------------- 79 | main(hparams) 80 | -------------------------------------------------------------------------------- /release.md: -------------------------------------------------------------------------------- 1 | 2 | # Release 3 | 4 | ## v0.1 5 | 6 | ### 2019.12.12 updates 7 | - Basic framework completed 8 | - cifar10_demo runs correctly 9 | 10 | 11 | ### 2019.12.13 updates 12 | - Implemented installation via setup 13 | - Improved the references between packages 14 | - Improved the relationships between modules 15 | - data 16 | - build_data 17 | - build_sampler 18 | - build_transforms 19 | - build_label_transforms 20 | - engine 21 | - build_module 22 | - losses 23 | - build_loss_fn 24 | - models 25 | - build_model 26 | 27 | ### 2019.12.16 updates 28 | 29 | - Updated the package version to `torchline-0.1.3` 30 | - Updated the structure of the transforms classes, making custom modifications more convenient 31 | 32 | ### 2019.12.17 updates 33 | - Images under a given path can now be predicted on their own (`predict_only` mode) 34 | - Completed the `test_only` mode 35 | - Added display of top-k results 36 | - Support restore training 37 | > For details see [project/skin/ReadMe.md](projects/skin/ReadMe.md) 38 | 39 | ### 2019.12.19 updates 40 | - Changed the logic of the various path parameter settings (see [skin/ReadMe.md](projects/skin/ReadMe.md)) 41 | 42 | 43 | ## v0.2 44 | 45 | ### 2020.02.18 46 | 47 | - Optimized the code logic 48 | - Abstracted out the `trainer` class 49 | - `trainer` is responsible for the overall computation logic; `DefaultModule`, defined under `engine`, specifies the concrete steps, including the model, the optimizer, the training/validation/test procedures, and data loading; `models` defines the concrete models, such as resnet. 50 | 51 | ## v0.2.2 52 | 53 | - Clearer code logic 54 | - Fixed a logging bug; usage is more flexible 55 | 56 | ## v0.2.2.2 57 | 58 | - Formatted the config output 59 | 60 | ## v0.2.3.0 61 | 62 | - Added a new logging mode (i.e. logging); logs can be saved to a file for easy inspection. Previously tqdm was used by default, and its major drawback is that you cannot clearly see whether the model has started to converge. 63 | - Introduced `AverageMeterGroup` etc. to record the metrics, which shows the overall convergence trend more clearly 64 | - Updated the `fake_data` demo 65 | - Fixed a bug in `CosineAnnealingLR` 66 | 67 | ## v0.2.3.1 68 | - Improved the output format of `AverageMeterGroup` for readability 69 | 70 | ## v0.2.3.2 71 | - Refactored the optimizer code to increase extensibility 72 | - Added learning-rate warm restarts (`CosineAnnealingWarmRestarts`) 73 | 74 | ## v0.2.3.3 75 | - Improved resume 76 | - Added albumentations data augmentation operations 77 | - Changed the previous logic between resize and crop 78 | 79 | ## v0.2.3.4 80 | - Abstracted the optimizer and scheduler definitions so they can be called directly from outside 81 | - Added a function to compute the model size 82 | 83 | ## v0.2.4.0 84 | - Added many SOTA model architectures, e.g. MnasNet, MobileNet 85 | - Unified the model structure (features, logits, forward, last_linear) 86 | 87 | ## v0.2.4.1 88 | - Fixed a single-node multi-GPU training bug 89 | - In this mode the batch size must be an integer multiple of the number of GPUs, otherwise the following error is reported: 90 | ```Python 91 | ValueError: only one element tensors can be converted to Python scalars 92 | ``` 93 | - Standardized the two logging modes: tqdm and logging 94 | 95 | ## v0.2.4.2 96 | - Fixed a bug in single-node multi-GPU training 97 | - Modified and unified the model forward function: features + logits 98 | 99 | ## v0.2.4.3 100 | - Updated the module forward function 101 | - Added a loss function that minimizes entropy 102 | 103 | ## v0.3.0.0 2020.05.11 104 | - Adapted to pytorch-lightning 0.7.5 105 | 106 | ## v0.3.0.1 2020.05.14 107 | - Improved automatic matching of the logger path; for example, when `resume_from_checkpoint` is set, the previous logger path is restored automatically; otherwise the logger version is bumped automatically. 108 | 109 | ## v0.3.0.2 2020.05.19 110 | - Allowed `DefaultModule` to store all `gt_labels` and `predictions` produced during computation, so that specified metrics such as precision or AUC can then be computed via the `analyze_result` function. 111 | 112 | ## v0.3.0.3 2020.05.19 113 | - Fixed the error the previous version raised under multi-GPU 114 | 115 | ## v0.3.0.4 2020.05.19 116 | - Fixed duplicated log printing under multi-GPU 117 | 118 | # TODO list 119 | 120 | 121 | - [x] Understand the logging mechanism 122 | - [x] Save and load the model and optimizer parameters 123 | - [x] Test loading the skin dataset 124 | - [x] Build the skin project 125 | - [x] Can a single image be predicted?
126 | - [x] Build a simple API interface 127 | - [x] Further improve the package imports 128 | - [x] Make the number of training epochs configurable 129 | - [X] Build in more SOTA models 130 | - [x] Print the learning rate every epoch 131 | - [x] Print the checkpoint's results when resuming 132 | - [x] Automatically match the checkpoint paths etc. when resuming 133 | - [x] Improve the log output 134 | - [x] Use albumentations for data augmentation 135 | - [x] Logic between the transforms resize and randomcrop 136 | - [x] Decouple the optimizer and scheduler from the engine 137 | - [x] ~~The resnet structure may be wrong: resnet50 should be 98MB, but the implementation is only 89.7~~. (Not a bug; the deviation came from setting classes to 10 when computing the size) 138 | - [x] Single-node multi-GPU test 139 | - [x] ~~Consider embedding the finetune setting into the model~~ (dropped, to avoid bloating the model code) 140 | - [x] Under multi-GPU, multithreading makes the logs print the results of different batches out of order; results should be printed only after they are aggregated, e.g. by moving `print_log` into `on_batch_end` (#v0.3.0.4) 141 | - [ ] Provide more default datasets 142 | - [ ] Improve the documentation 143 | - [x] ~~Evaluate whether hydra should replace yacs~~ (too much work) 144 | - [ ] Make the config parameters more robust and compatible 145 | - [ ] Multi-node multi-GPU test 146 | - [x] Template project: quickly create a new project based on torchline 147 | - [ ] Decouple `parse_cfg_for_scheduler` from `default_module` and move it into `utils.py` 148 | - [ ] Also save the scheduler parameters in the checkpoint, and add a setting to skip restoring the optimizer or scheduler 149 | - [x] Multiple copies of the logs are generated under multi-GPU; printed messages have the same issue (#v0.3.0.4) 150 | - [ ] Restructure the file layout 151 | - [x] Adapt to pytorch-lightning 0.7.5 152 | - [x] ~~Standardize parameter names, preferring full names, e.g. optimizer instead of optim # updated in major version v0.3.0.0~~ 153 | - [x] ~~albumentations and torchvision load images with cv2 and PIL respectively, i.e. numpy vs PIL.Image; how to unify the formats should be considered later.~~ (handled specially when implementing the transformations) 154 | - [ ] Make `print_log` in `Module` more generic -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | yacs==0.1.6 2 | pytorch-lightning==0.7.6 3 | pretrainedmodels>=0.7.4 4 | pillow>=6.2.0 5 | test-tube==0.7.5 6 | albumentations>=0.4.3 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('requirements.txt', 'r') as f: 4 | requirements = f.read().splitlines() 5 | 6 | setup( 7 | name="torchline", 8 | version="0.3.0.4", 9 | author="marsggbo", 10 | author_email="csxinhe@comp.hkbu.edu.hk", 11 | description="A framework for easy-to-use PyTorch", 12 | long_description=''' 13 | ML developers can easily use this framework to implement their ideas.
Our framework is built on top of pytorch_lightning, 14 | and its structure is inspired by detectron2''', 15 | long_description_content_type="text/markdown", 16 | url="https://github.com/marsggbo/torchline", 17 | packages=find_packages(exclude=("tests", "projects")), 18 | install_requires=requirements, 19 | classifiers=[ 20 | "Programming Language :: Python :: 3", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | ], 24 | python_requires='>=3.6', 25 | ) 26 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchline as tl 3 | cfg = tl.config.get_cfg() 4 | cfg.model.pretrained = False 5 | 6 | def cpu_test(): 7 | print('=====cpu=====') 8 | x = torch.rand(1,3,64,64) 9 | for m in tl.models.model_list: 10 | cfg.model.name = m 11 | net = tl.models.build_model(cfg) 12 | try: 13 | y = net(x) 14 | print(f"{m} pass") 15 | except Exception as e: 16 | print(f"{m} fail") 17 | print(str(e)) 18 | 19 | 20 | def single_gpu_test(): 21 | print('=====single_gpu=====') 22 | x = torch.rand(4,3,64,64) 23 | for m in tl.models.model_list: 24 | cfg.model.name = m 25 | net = tl.models.build_model(cfg).cuda() 26 | x = x.cuda() 27 | try: 28 | y = net(x) 29 | print(f"{m} pass") 30 | except Exception as e: 31 | print(f"{m} fail") 32 | print(str(e)) 33 | 34 | 35 | def multi_gpus_test(): 36 | print('=====multi_gpu=====') 37 | x = torch.rand(4,3,64,64) 38 | for m in tl.models.model_list: 39 | cfg.model.name = m 40 | net = tl.models.build_model(cfg).cuda() 41 | net = torch.nn.DataParallel(net, device_ids=[0,1]) 42 | x = x.cuda() 43 | try: 44 | y = net(x) 45 | print(f"{m} pass") 46 | except Exception as e: 47 | print(f"{m} fail") 48 | print(str(e)) 49 | 50 | 51 | if __name__ == '__main__': 52 | cpu_test() 53 | single_gpu_test() 54 | multi_gpus_test() -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import torchline as tl 2 | cfg = tl.config.get_cfg() 3 | cfg.dataset.is_train = True 4 | cfg.transforms.name = 'AlbumentationsTransforms' 5 | cfg.abtfs.hue.enable = 1 6 | cfg.abtfs.rotate.enable = 1 7 | cfg.abtfs.bright.enable = 1 8 | cfg.abtfs.noise.enable = 1 9 | tf2 = tl.data.build_transforms(cfg) -------------------------------------------------------------------------------- /torchline/__init__.py: -------------------------------------------------------------------------------- 1 | from .
import config, data, models, engine, losses, utils, trainer 2 | version = "0.3.0.4" 3 | -------------------------------------------------------------------------------- /torchline/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import CfgNode, get_cfg -------------------------------------------------------------------------------- /torchline/config/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | from yacs.config import CfgNode as _CfgNode 6 | 7 | from torchline.utils import Logger 8 | 9 | 10 | class CfgNode(_CfgNode): 11 | 12 | def _version(self): 13 | ''' 14 | Calculate the version of the configuration and logger. 15 | ''' 16 | def calc_version(path): 17 | if (not os.path.exists(path)) or len(os.listdir(path))==0: 18 | version = 0 19 | else: 20 | versions = [int(v.split('_')[-1]) for v in os.listdir(path)] 21 | version = max(versions)+1 22 | return version 23 | 24 | save_dir=self.trainer.default_root_dir 25 | logger_name=self.trainer.logger.test_tube.name 26 | path = os.path.join(save_dir, logger_name) 27 | if self.trainer.logger.setting==0: 28 | version = calc_version(path) 29 | elif self.trainer.logger.setting==2: 30 | if self.trainer.logger.test_tube.version<0: 31 | version = calc_version(path) 32 | else: 33 | version = int(self.trainer.logger.test_tube.version) 34 | return version 35 | 36 | def setup_cfg_with_hparams(self, hparams): 37 | """ 38 | Create configs and perform basic setups. 39 | Args: 40 | hparams: arguments parsed from the command line 41 | """ 42 | self.merge_from_file(hparams.config_file) 43 | self.merge_from_list(hparams.opts) 44 | self.update({'hparams': hparams}) 45 | 46 | # './outputs/torchline_logs/version_0/checkpoints/_ckpt_epoch_1.ckpt' 47 | ckpt_file = self.trainer.resume_from_checkpoint 48 | if ckpt_file: # resume cfg 49 | assert os.path.exists(ckpt_file), f"{ckpt_file} does not exist" 50 | ckpt_path = os.path.dirname(ckpt_file).split('/')[:-1] 51 | # set log cfg 52 | self.log.path = ''.join([p+'/' for p in ckpt_path]) 53 | self.log.name = os.path.join(self.log.path, 'log.txt') 54 | # set trainer logger cfg 55 | self.trainer.logger.setting = 2 56 | root = self.trainer.default_root_dir 57 | self.trainer.logger.test_tube.name = ckpt_file.replace(root,'').split('/')[1] 58 | # self.trainer.logger.test_tube.version 59 | else: 60 | version = self._version() 61 | save_dir = self.trainer.default_root_dir 62 | logger_name = self.trainer.logger.test_tube.name 63 | self.log.path = os.path.join(save_dir, logger_name, f"version_{version}") 64 | self.log.name = os.path.join(self.log.path, 'log.txt') 65 | self.freeze() 66 | 67 | os.makedirs(self.log.path, exist_ok=True) 68 | 69 | # copy config file 70 | src_cfg_file = hparams.config_file # source config file 71 | cfg_file_name = os.path.basename(src_cfg_file) # config file name 72 | dst_cfg_file = os.path.join(self.log.path, cfg_file_name) 73 | with open(dst_cfg_file, 'w') as f: 74 | hparams = self.hparams 75 | self.pop('hparams') 76 | f.write(str(self)) 77 | self.update({'hparams': hparams}) 78 | 79 | if not (hparams.test_only or hparams.predict_only): 80 | torch.backends.cudnn.benchmark = False 81 | else: 82 | torch.backends.cudnn.benchmark = self.DEFAULT_CUDNN_BENCHMARK 83 | 84 | Logger(__name__, self.log.name).getlogger().info("Running with full config:\n{}".format(self)) 85 | 86 | def __str__(self): 87 | def _indent(s_, num_spaces): 88 | s =
s_.split("\n") 89 | if len(s) == 1: 90 | return s_ 91 | first = s.pop(0) 92 | s = [(num_spaces * " ") + line for line in s] 93 | s = "\n".join(s) 94 | s = first + "\n" + s 95 | return s 96 | 97 | r = "" 98 | s = [] 99 | for k, v in sorted(self.items()): 100 | seperator = "\n" if isinstance(v, CfgNode) else " " 101 | v = f"'{v}'" if isinstance(v, str) else v 102 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 103 | attr_str = _indent(attr_str, 4) 104 | s.append(attr_str) 105 | r += "\n".join(s) 106 | return r 107 | 108 | global_cfg = CfgNode() 109 | 110 | def get_cfg(): 111 | ''' 112 | Get a copy of the default config. 113 | 114 | Returns: 115 | a CfgNode instance. 116 | ''' 117 | from .default import _C 118 | return _C.clone() -------------------------------------------------------------------------------- /torchline/config/default.py: -------------------------------------------------------------------------------- 1 | import random 2 | from .config import CfgNode as CN 3 | 4 | _C = CN() 5 | _C.VERSION = 1 6 | 7 | 8 | # ---------------------------------------------------------------------------- # 9 | # input 10 | # ---------------------------------------------------------------------------- # 11 | _C.input = CN() 12 | _C.input.size = (224, 224) 13 | 14 | # ---------------------------------------------------------------------------- # 15 | # dataset 16 | # ---------------------------------------------------------------------------- # 17 | _C.dataset = CN() 18 | _C.dataset.name = 'cifar10' 19 | _C.dataset.batch_size = 16 20 | _C.dataset.dir = './datasets/skin100_dataset/train' 21 | _C.dataset.train_list = './datasets/train_skin10.txt' 22 | _C.dataset.valid_list = './datasets/valid_skin10.txt' 23 | _C.dataset.test_list = './datasets/test_skin10.txt' 24 | _C.dataset.is_train = False # specify to load training or testing set 25 | 26 | 27 | # ---------------------------------------------------------------------------- # 28 | # transforms 29 | # ---------------------------------------------------------------------------- # 30 | 31 | _C.transforms = CN() # image transforms 32 | _C.transforms.name = 'DefaultTransforms' 33 | 34 | 35 | ## transforms for tensor 36 | _C.transforms.tensor = CN() 37 | # for skin100 38 | _C.transforms.tensor.normalization = CN() 39 | _C.transforms.tensor.normalization.mean = [0.6075, 0.4564, 0.4182] 40 | _C.transforms.tensor.normalization.std = [0.2158, 0.1871, 0.1826] 41 | # _C.transforms.tensor.normalization = { 42 | # 'mean':[0.6054, 0.4433, 0.4084], 43 | # 'std': [0.2125, 0.1816, 0.1786] # for skin10 44 | _C.transforms.tensor.random_erasing = CN() 45 | _C.transforms.tensor.random_erasing.enable = 0 46 | _C.transforms.tensor.random_erasing.p = 0.5 47 | _C.transforms.tensor.random_erasing.scale = (0.02, 0.3) # range of proportion of erased area against input image. 48 | _C.transforms.tensor.random_erasing.ratio = (0.3, 3.3), # range of aspect ratio of erased area. 
49 | 50 | 51 | # ---------------------------------------------------------------------------- # 52 | # albumentations transforms (abtfs) 53 | # ---------------------------------------------------------------------------- # 54 | 55 | _C.abtfs = CN() 56 | _C.abtfs.random_grid_shuffle = CN() 57 | _C.abtfs.random_grid_shuffle.enable = 0 58 | _C.abtfs.random_grid_shuffle.grid = 2 59 | 60 | _C.abtfs.channel_shuffle = CN() 61 | _C.abtfs.channel_shuffle.enable = 0 62 | 63 | _C.abtfs.channel_dropout = CN() 64 | _C.abtfs.channel_dropout.enable = 0 65 | _C.abtfs.channel_dropout.drop_range = (1, 1) 66 | _C.abtfs.channel_dropout.fill_value = 127 67 | 68 | _C.abtfs.noise = CN() 69 | _C.abtfs.noise.enable = 1 70 | 71 | _C.abtfs.blur = CN() 72 | _C.abtfs.blur.enable = 0 73 | _C.abtfs.blur.blur_limit = 3 74 | 75 | _C.abtfs.rotate = CN() 76 | _C.abtfs.rotate.enable = 1 77 | _C.abtfs.rotate.p = 1 78 | _C.abtfs.rotate.shift_limit = 0.0625 79 | _C.abtfs.rotate.scale_limit = 0.2 80 | _C.abtfs.rotate.rotate_limit = 45 81 | 82 | _C.abtfs.bright = CN() 83 | _C.abtfs.bright.enable = 1 84 | _C.abtfs.bright.clip_limit = 1 85 | 86 | _C.abtfs.distortion = CN() 87 | _C.abtfs.distortion.enable = 0 88 | 89 | _C.abtfs.hue = CN() 90 | _C.abtfs.hue.enable = 0 91 | 92 | _C.abtfs.cutout = CN() 93 | _C.abtfs.cutout.enable = 1 94 | _C.abtfs.cutout.num_holes = 10 95 | _C.abtfs.cutout.size = 20 96 | _C.abtfs.cutout.fill_value = 127 97 | 98 | # ---------------------------------------------------------------------------- # 99 | # torchvision transforms 100 | # ---------------------------------------------------------------------------- # 101 | 102 | ## transforms for PIL image 103 | _C.transforms.img = CN() 104 | 105 | ### modify the image size, only use one operation 106 | # random_resized_crop 107 | _C.transforms.img.random_resized_crop = CN() 108 | _C.transforms.img.random_resized_crop.enable = 0 109 | _C.transforms.img.random_resized_crop.scale = (0.5, 1.0) 110 | _C.transforms.img.random_resized_crop.ratio = (3/4, 4/3) 111 | 112 | # resize 113 | _C.transforms.img.resize = CN() 114 | _C.transforms.img.resize.enable = 1 115 | 116 | # random_crop 117 | _C.transforms.img.random_crop = CN() 118 | _C.transforms.img.random_crop.enable = 1 119 | _C.transforms.img.random_crop.padding = 0 120 | 121 | # center_crop 122 | _C.transforms.img.center_crop = CN() 123 | _C.transforms.img.center_crop.enable = 0 124 | 125 | ### without modifying the image size 126 | _C.transforms.img.aug_imagenet = False 127 | _C.transforms.img.aug_cifar = False 128 | 129 | # color_jitter 130 | _C.transforms.img.color_jitter = CN() 131 | _C.transforms.img.color_jitter.enable = 0 132 | _C.transforms.img.color_jitter.brightness = 0. 133 | _C.transforms.img.color_jitter.contrast = 0. 134 | _C.transforms.img.color_jitter.saturation = 0. 135 | _C.transforms.img.color_jitter.hue = 0. 
136 | 137 | # horizontal_flip 138 | _C.transforms.img.random_horizontal_flip = CN() 139 | _C.transforms.img.random_horizontal_flip.enable = 1 140 | _C.transforms.img.random_horizontal_flip.p = 0.5 141 | 142 | # vertical_flip 143 | _C.transforms.img.random_vertical_flip = CN() 144 | _C.transforms.img.random_vertical_flip.enable = 1 145 | _C.transforms.img.random_vertical_flip.p = 0.5 146 | 147 | # random_rotation 148 | _C.transforms.img.random_rotation = CN() 149 | _C.transforms.img.random_rotation.enable = 1 150 | _C.transforms.img.random_rotation.degrees = 10 151 | 152 | 153 | 154 | _C.label_transforms = CN() # label transforms 155 | _C.label_transforms.name = 'default' 156 | 157 | 158 | # ---------------------------------------------------------------------------- # 159 | # dataloader 160 | # ---------------------------------------------------------------------------- # 161 | _C.dataloader = CN() 162 | _C.dataloader.num_workers = 4 163 | _C.dataloader.sample_train = "default" 164 | _C.dataloader.sample_test = "default" 165 | 166 | 167 | # ---------------------------------------------------------------------------- # 168 | # model 169 | # ---------------------------------------------------------------------------- # 170 | _C.model = CN() 171 | _C.model.name = 'Resnet50' 172 | _C.model.classes = 10 173 | _C.model.pretrained = True 174 | _C.model.finetune = False 175 | 176 | 177 | # ---------------------------------------------------------------------------- # 178 | # optimizer 179 | # ---------------------------------------------------------------------------- # 180 | _C.optim = CN() 181 | _C.optim.name = 'adam' 182 | _C.optim.momentum = 0.9 183 | _C.optim.base_lr = 0.001 184 | # _C.optim.lr = _C.optim.base_lr # will be changed in v0.3.0.0 185 | _C.optim.weight_decay = 0.0005 186 | 187 | # scheduler 188 | _C.optim.scheduler = CN() 189 | _C.optim.scheduler.name = 'MultiStepLR' 190 | _C.optim.scheduler.gamma = 0.1 # decay factor 191 | 192 | # for CosineAnnealingLR 193 | _C.optim.scheduler.t_max = 10 194 | 195 | # for CosineAnnealingWarmRestarts 196 | _C.optim.scheduler.t_0 = 5 197 | _C.optim.scheduler.t_mul = 20 198 | 199 | # for ReduceLROnPlateau 200 | _C.optim.scheduler.mode = 'min' # min for loss, max for acc 201 | _C.optim.scheduler.patience = 10 202 | _C.optim.scheduler.verbose = True # print a log message each time the lr is updated 203 | 204 | # for StepLR 205 | _C.optim.scheduler.step_size = 10 206 | 207 | # for MultiStepLR 208 | _C.optim.scheduler.milestones = [10, 25, 35, 50] 209 | 210 | # _C.optimizer = _C.optim # enhance compatibility; will be changed in v0.3.0.0 211 | # ---------------------------------------------------------------------------- # 212 | # loss 213 | # ---------------------------------------------------------------------------- # 214 | _C.loss = CN() 215 | _C.loss.name = 'CrossEntropy' 216 | _C.loss.class_weight = [] 217 | _C.loss.label_smoothing = 0.1 # CrossEntropyLabelSmooth 218 | 219 | _C.loss.focal_loss = CN() 220 | _C.loss.focal_loss.alpha = [] # FocalLoss 221 | _C.loss.focal_loss.gamma = 2 222 | _C.loss.focal_loss.size_average = True 223 | # ---------------------------------------------------------------------------- # 224 | # hooks 225 | # ---------------------------------------------------------------------------- # 226 | _C.hooks = CN() 227 | 228 | ## EarlyStopping 229 | _C.hooks.early_stopping = CN() 230 | _C.hooks.early_stopping.setting = 2 # 0: True 1: False 2: custom 231 | _C.hooks.early_stopping.monitor = 'valid_loss' # or 'valid_acc_1' 232 | _C.hooks.early_stopping.min_delta = 0.
233 | _C.hooks.early_stopping.patience = 10 234 | _C.hooks.early_stopping.mode = 'min' # or 'max' 235 | _C.hooks.early_stopping.verbose = 1 236 | 237 | # ModelCheckpoint 238 | _C.hooks.model_checkpoint = CN() 239 | _C.hooks.model_checkpoint.setting = 2 # 0: True 1: False 2: custom 240 | _C.hooks.model_checkpoint.filepath = '' # an empty file path is recommended 241 | _C.hooks.model_checkpoint.monitor = 'valid_loss' 242 | _C.hooks.model_checkpoint.mode = 'min' 243 | _C.hooks.model_checkpoint.save_top_k = 1 244 | _C.hooks.model_checkpoint.save_weights_only = False 245 | _C.hooks.model_checkpoint.verbose = 1 246 | _C.hooks.model_checkpoint.period = 1 247 | _C.hooks.model_checkpoint.prefix = '' 248 | 249 | 250 | # ---------------------------------------------------------------------------- # 251 | # Module template 252 | # ---------------------------------------------------------------------------- # 253 | 254 | _C.module = CN() 255 | _C.module.name = 'DefaultModule' 256 | _C.module.analyze_result = False 257 | # analyze the predictions and gt_labels, e.g., compute the f1_score, precision 258 | # you can modify the `analyze_result(self, gt_labels, predictions)` in `default_module.py` 259 | 260 | 261 | # ---------------------------------------------------------------------------- # 262 | # Trainer 263 | # ---------------------------------------------------------------------------- # 264 | 265 | _C.trainer = CN() 266 | _C.trainer.name = 'DefaultTrainer' 267 | _C.trainer.default_save_path = './output' # will be removed 268 | _C.trainer.default_root_dir = './output' 269 | _C.trainer.gradient_clip_val = 0 # 0 means don't clip. 270 | _C.trainer.process_position = 0 271 | _C.trainer.num_nodes = 1 272 | _C.trainer.gpus = [] # list 273 | _C.trainer.auto_select_gpus = False # If `auto_select_gpus` is enabled and `gpus` is an integer, pick available gpus automatically. 274 | # This is especially useful when GPUs are configured to be in "exclusive mode", such that 275 | # only one process at a time can access them. 276 | _C.trainer.num_tpu_cores = '' # How many TPU cores to train on (1 or 8). 277 | _C.trainer.log_gpu_memory = "" # None, 'min_max', 'all'. Might slow performance 278 | _C.trainer.show_progress_bar = False # will be removed 279 | _C.trainer.progress_bar_refresh_rate = 0 # How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
280 | _C.trainer.overfit_pct = 0.0 # if 0>> policy = ImageNetPolicy() 16 | >>> transformed = policy(image) 17 | Example as a PyTorch Transform: 18 | >>> transform=transforms.Compose([ 19 | >>> transforms.Resize(256), 20 | >>> ImageNetPolicy(), 21 | >>> transforms.ToTensor()]) 22 | """ 23 | def __init__(self, fillcolor=(128, 128, 128)): 24 | self.policies = [ 25 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 26 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 27 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 28 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 29 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 30 | 31 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 32 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 33 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 34 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 35 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 36 | 37 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 38 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 39 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 40 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 41 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 42 | 43 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 44 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 45 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 46 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 47 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 48 | 49 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 50 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 51 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 52 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 53 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 54 | ] 55 | 56 | 57 | def __call__(self, img): 58 | policy_idx = random.randint(0, len(self.policies) - 1) 59 | return self.policies[policy_idx](img) 60 | 61 | def __repr__(self): 62 | return "AutoAugment ImageNet Policy" 63 | 64 | 65 | class CIFAR10Policy(object): 66 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
67 | Example: 68 | >>> policy = CIFAR10Policy() 69 | >>> transformed = policy(image) 70 | Example as a PyTorch Transform: 71 | >>> transform=transforms.Compose([ 72 | >>> transforms.Resize(256), 73 | >>> CIFAR10Policy(), 74 | >>> transforms.ToTensor()]) 75 | """ 76 | def __init__(self, fillcolor=(128, 128, 128)): 77 | self.policies = [ 78 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 79 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 80 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 81 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 82 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 83 | 84 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 85 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 86 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 87 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 88 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 89 | 90 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 91 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 92 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 93 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 94 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 95 | 96 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 97 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 98 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 99 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 100 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 101 | 102 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 103 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 104 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 105 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 106 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 107 | ] 108 | 109 | 110 | def __call__(self, img): 111 | policy_idx = random.randint(0, len(self.policies) - 1) 112 | return self.policies[policy_idx](img) 113 | 114 | def __repr__(self): 115 | return "AutoAugment CIFAR10 Policy" 116 | 117 | 118 | class SVHNPolicy(object): 119 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
120 | Example: 121 | >>> policy = SVHNPolicy() 122 | >>> transformed = policy(image) 123 | Example as a PyTorch Transform: 124 | >>> transform=transforms.Compose([ 125 | >>> transforms.Resize(256), 126 | >>> SVHNPolicy(), 127 | >>> transforms.ToTensor()]) 128 | """ 129 | def __init__(self, fillcolor=(128, 128, 128)): 130 | self.policies = [ 131 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 132 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 133 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 134 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 135 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 136 | 137 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 138 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 139 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 140 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 141 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 142 | 143 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 144 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 145 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 146 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 147 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 148 | 149 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 150 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 151 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 152 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 153 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 154 | 155 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 156 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 157 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 158 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 159 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 160 | ] 161 | 162 | 163 | def __call__(self, img): 164 | policy_idx = random.randint(0, len(self.policies) - 1) 165 | return self.policies[policy_idx](img) 166 | 167 | def __repr__(self): 168 | return "AutoAugment SVHN Policy" 169 | 170 | 171 | class SubPolicy(object): 172 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 173 | ranges = { 174 | "shearX": np.linspace(0, 0.3, 10), 175 | "shearY": np.linspace(0, 0.3, 10), 176 | "translateX": np.linspace(0, 150 / 331, 10), 177 | "translateY": np.linspace(0, 150 / 331, 10), 178 | "rotate": np.linspace(0, 30, 10), 179 | "color": np.linspace(0.0, 0.9, 10), 180 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int), 181 | "solarize": np.linspace(256, 0, 10), 182 | "contrast": np.linspace(0.0, 0.9, 10), 183 | "sharpness": np.linspace(0.0, 0.9, 10), 184 | "brightness": np.linspace(0.0, 0.9, 10), 185 | "autocontrast": [0] * 10, 186 | "equalize": [0] * 10, 187 | "invert": [0] * 10 188 | } 189 | 190 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 191 | def rotate_with_fill(img, magnitude): 192 | rot = img.convert("RGBA").rotate(magnitude) 193 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 194 | 195 | func = { 196 | "shearX": lambda img, magnitude: img.transform( 197 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 198 |
199 |             "shearY": lambda img, magnitude: img.transform(
200 |                 img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
201 |                 Image.BICUBIC, fillcolor=fillcolor),
202 |             "translateX": lambda img, magnitude: img.transform(
203 |                 img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
204 |                 fillcolor=fillcolor),
205 |             "translateY": lambda img, magnitude: img.transform(
206 |                 img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
207 |                 fillcolor=fillcolor),
208 |             "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
209 |             # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
210 |             "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
211 |             "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
212 |             "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
213 |             "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
214 |                 1 + magnitude * random.choice([-1, 1])),
215 |             "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
216 |                 1 + magnitude * random.choice([-1, 1])),
217 |             "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
218 |                 1 + magnitude * random.choice([-1, 1])),
219 |             "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
220 |             "equalize": lambda img, magnitude: ImageOps.equalize(img),
221 |             "invert": lambda img, magnitude: ImageOps.invert(img)
222 |         }
223 |
224 |         # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
225 |         #     operation1, ranges[operation1][magnitude_idx1],
226 |         #     operation2, ranges[operation2][magnitude_idx2])
227 |         self.p1 = p1
228 |         self.operation1 = func[operation1]
229 |         self.magnitude1 = ranges[operation1][magnitude_idx1]
230 |         self.p2 = p2
231 |         self.operation2 = func[operation2]
232 |         self.magnitude2 = ranges[operation2][magnitude_idx2]
233 |
234 |
235 |     def __call__(self, img):
236 |         if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
237 |         if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
238 |         return img
--------------------------------------------------------------------------------
/torchline/data/build.py:
--------------------------------------------------------------------------------
1 | from torchline.utils import Registry
2 |
3 | DATASET_REGISTRY = Registry("DATASET")
4 | DATASET_REGISTRY.__doc__ = """
5 | Registry for dataset, i.e. torch.utils.data.Dataset.
6 |
7 | The registered object will be called with `obj(cfg)`
8 | """
9 |
10 | __all__ = [
11 |     'DATASET_REGISTRY',
12 |     'build_data'
13 | ]
14 |
15 | def build_data(cfg):
16 |     """
17 |     Build the dataset, defined by `cfg.dataset.name`.
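    Example (a usage sketch, assuming the built-in 'CIFAR10' entry
    registered in common_datasets.py):

        >>> cfg.defrost()
        >>> cfg.dataset.name = 'CIFAR10'
        >>> cfg.freeze()
        >>> dataset = build_data(cfg)  # a torch.utils.data.Dataset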
18 | """ 19 | name = cfg.dataset.name 20 | return DATASET_REGISTRY.get(name)(cfg) 21 | -------------------------------------------------------------------------------- /torchline/data/common_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import CIFAR10 as _CIFAR10 3 | from torchvision.datasets import MNIST as _MNIST 4 | 5 | from .build import DATASET_REGISTRY 6 | from .transforms import build_transforms 7 | 8 | __all__ = [ 9 | 'MNIST', 10 | 'CIFAR10', 11 | 'FakeData', 12 | 'fakedata' 13 | ] 14 | 15 | @DATASET_REGISTRY.register() 16 | def MNIST(cfg): 17 | root = cfg.dataset.dir 18 | is_train = cfg.dataset.is_train 19 | transform = build_transforms(cfg) 20 | return _MNIST(root=root, train=is_train, transform=transform.transform, download=True) 21 | 22 | @DATASET_REGISTRY.register() 23 | def CIFAR10(cfg): 24 | root = cfg.dataset.dir 25 | is_train = cfg.dataset.is_train 26 | transform = build_transforms(cfg) 27 | return _CIFAR10(root=root, train=is_train, transform=transform.transform, download=True) 28 | 29 | class FakeData(torch.utils.data.Dataset): 30 | def __init__(self, size=64, num=100): 31 | if isinstance(size, int): 32 | self.size = [size, size] 33 | elif isinstance(size, list): 34 | self.size = size 35 | self.num = num 36 | self.data = torch.rand(num, 3, *size) 37 | self.labels = torch.randint(0, 10, (num,)) 38 | 39 | 40 | def __getitem__(self, index): 41 | return self.data[index], self.labels[index] 42 | 43 | def __len__(self): 44 | return self.num 45 | 46 | @DATASET_REGISTRY.register() 47 | def fakedata(cfg): 48 | size = cfg.input.size 49 | return FakeData(size) -------------------------------------------------------------------------------- /torchline/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchline.utils import Registry, Logger 4 | 5 | SAMPLER_REGISTRY = Registry('SAMPLER') 6 | SAMPLER_REGISTRY.__doc__ = """ 7 | Registry for dataset sampler, i.e. torch.utils.data.Sampler. 8 | 9 | The registered object will be called with `obj(cfg)` 10 | """ 11 | 12 | __all__ = [ 13 | 'build_sampler', 14 | 'SAMPLER_REGISTRY' 15 | ] 16 | 17 | def build_sampler(cfg): 18 | """ 19 | Built the dataset sampler, defined by `cfg.dataset.name`. 
20 | """ 21 | is_train = cfg.dataset.is_train 22 | if is_train: 23 | name = cfg.dataloader.sample_train 24 | else: 25 | name = cfg.dataloader.sample_test 26 | if name == 'default': 27 | return None 28 | return SAMPLER_REGISTRY.get(name)(cfg) 29 | 30 | @SAMPLER_REGISTRY.register() 31 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): 32 | """Samples elements randomly from a given list of indices for imbalanced dataset 33 | Arguments: 34 | indices (list, optional): a list of indices 35 | num_samples (int, optional): number of samples to draw 36 | """ 37 | 38 | def __init__(self, dataset, indices=None, num_samples=None): 39 | 40 | # if indices is not provided, 41 | # all elements in the dataset will be considered 42 | self.indices = list(range(len(dataset))) \ 43 | if indices is None else indices 44 | 45 | # if num_samples is not provided, 46 | # draw `len(indices)` samples in each iteration 47 | self.num_samples = len(self.indices) \ 48 | if num_samples is None else num_samples 49 | 50 | # distribution of classes in the dataset 51 | label_to_count = {} 52 | for idx in self.indices: 53 | label = self._get_label(dataset, idx) 54 | if label in label_to_count: 55 | label_to_count[label] += 1 56 | else: 57 | label_to_count[label] = 1 58 | 59 | # weight for each sample 60 | weights = [1.0 / label_to_count[self._get_label(dataset, idx)] 61 | for idx in self.indices] 62 | self.weights = torch.DoubleTensor(weights) 63 | 64 | def _get_label(self, dataset, idx): 65 | dataset_type = type(dataset) 66 | if dataset_type is torchvision.datasets.MNIST: 67 | return dataset.train_labels[idx].item() 68 | elif dataset_type is torchvision.datasets.ImageFolder: 69 | return dataset.imgs[idx][1] 70 | else: 71 | try: 72 | return dataset.imgs[idx][-1] 73 | except Exception as e: 74 | print(str(e)) 75 | raise NotImplementedError 76 | 77 | def __iter__(self): 78 | return (self.indices[i] for i in torch.multinomial( 79 | self.weights, self.num_samples, replacement=True)) 80 | 81 | def __len__(self): 82 | return self.num_samples 83 | 84 | 85 | @SAMPLER_REGISTRY.register() 86 | class DiffWeightedRandomSampler(torch.utils.data.sampler.Sampler): 87 | r""" 88 | Samples elements from a given list of indices with given probabilities (weights), with replacement. 
89 |
90 |     Arguments:
91 |         weights (sequence): a sequence of weights, not necessarily summing to one
92 |         num_samples (int): number of samples to draw
93 |
94 |     """
95 |
96 |     def __init__(self, indices, weights, num_samples=0):
97 |         if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples < 0:
98 |             raise ValueError("num_samples should be a non-negative integer "
99 |                              "value, but got num_samples={}".format(num_samples))
100 |         self.indices = indices
101 |         weights = [ weights[i] for i in self.indices ]
102 |         self.weights = torch.DoubleTensor(weights)
103 |         if num_samples == 0:
104 |             self.num_samples = len(self.weights)
105 |         else:
106 |             self.num_samples = num_samples
107 |         self.replacement = True
108 |
109 |     def __iter__(self):
110 |         return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, self.replacement))
111 |
112 |     def __len__(self):
113 |         return self.num_samples
--------------------------------------------------------------------------------
/torchline/data/transforms.py:
--------------------------------------------------------------------------------
1 |
2 | #coding=utf-8
3 | import torch
4 | import numpy as np
5 | import math
6 | import random
7 | import torchvision
8 | from torchvision import transforms
9 | from . import autoaugment
10 | from torchline.utils import Registry, Logger
11 |
12 | TRANSFORMS_REGISTRY = Registry('transforms')
13 | TRANSFORMS_REGISTRY.__doc__ = """
14 | Registry for data transform functions, i.e. torchvision.transforms
15 |
16 | The registered object will be called with `obj(cfg)`
17 | """
18 |
19 | LABEL_TRANSFORMS_REGISTRY = Registry('label_transforms')
20 | LABEL_TRANSFORMS_REGISTRY.__doc__ = """
21 | Registry for label transform functions, i.e. torchvision.transforms
22 |
23 | The registered object will be called with `obj(cfg)`
24 | """
25 |
26 | __all__ = [
27 |     'build_transforms',
28 |     'build_label_transforms',
29 |     'TRANSFORMS_REGISTRY',
30 |     'LABEL_TRANSFORMS_REGISTRY',
31 |     'DefaultTransforms',
32 |     '_DefaultTransforms',
33 |     'BaseTransforms'
34 | ]
35 |
36 |
37 | def build_transforms(cfg):
38 |     """
39 |     Build the transforms, defined by `cfg.transforms.name`.
40 |     """
41 |     name = cfg.transforms.name
42 |     return TRANSFORMS_REGISTRY.get(name)(cfg)
43 |
44 | def build_label_transforms(cfg):
45 |     """
46 |     Build the label transforms, defined by `cfg.label_transforms.name`.
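    Example (a sketch; 'default' is the built-in setting and simply
    returns None, i.e. no label transform is applied):

        >>> cfg.label_transforms.name = 'default'
        >>> build_label_transforms(cfg) is None
        True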
47 | """ 48 | name = cfg.label_transforms.name 49 | if name == 'default': 50 | return None 51 | return LABEL_TRANSFORMS_REGISTRY.get(name)(cfg) 52 | 53 | class BaseTransforms(object): 54 | def __init__(self, is_train, log_name): 55 | self.logger_print = Logger(__name__, log_name).getlogger() 56 | self.is_train = is_train 57 | 58 | def get_transform(self): 59 | if not self.is_train: 60 | self.logger_print.info('Generating validation transform ...') 61 | transform = self.valid_transform 62 | self.logger_print.info(f'Valid transform={transform}') 63 | else: 64 | self.logger_print.info('Generating training transform ...') 65 | transform = self.train_transform 66 | self.logger_print.info(f'Train transform={transform}') 67 | return transform 68 | 69 | @property 70 | def valid_transform(self): 71 | raise NotImplementedError 72 | 73 | @property 74 | def train_transform(self): 75 | raise NotImplementedError 76 | 77 | 78 | @TRANSFORMS_REGISTRY.register() 79 | def DefaultTransforms(cfg): 80 | is_train = cfg.dataset.is_train 81 | log_name = cfg.log.name 82 | mean = cfg.transforms.tensor.normalization.mean 83 | std = cfg.transforms.tensor.normalization.std 84 | img_size = cfg.input.size 85 | 86 | aug_imagenet = cfg.transforms.img.aug_imagenet 87 | aug_cifar = cfg.transforms.img.aug_cifar 88 | random_resized_crop = cfg.transforms.img.random_resized_crop 89 | resize = cfg.transforms.img.resize 90 | random_crop = cfg.transforms.img.random_crop 91 | center_crop = cfg.transforms.img.center_crop 92 | random_horizontal_flip = cfg.transforms.img.random_horizontal_flip 93 | random_vertical_flip = cfg.transforms.img.random_vertical_flip 94 | random_rotation = cfg.transforms.img.random_rotation 95 | color_jitter = cfg.transforms.img.color_jitter 96 | return _DefaultTransforms(is_train, log_name, img_size, 97 | aug_imagenet, aug_cifar, 98 | random_resized_crop, 99 | resize, 100 | random_crop, 101 | center_crop, 102 | random_horizontal_flip, 103 | random_vertical_flip, 104 | random_rotation, 105 | color_jitter, 106 | mean, std) 107 | 108 | class _DefaultTransforms(BaseTransforms): 109 | def __init__(self, is_train, log_name, img_size, 110 | aug_imagenet=False, aug_cifar=False, 111 | random_resized_crop={'enable':0}, 112 | resize={'enable':1}, 113 | random_crop={'enable':0, 'padding':0}, 114 | center_crop={'enable':0}, 115 | random_horizontal_flip={'enbale':0, 'p':0.5}, 116 | random_vertical_flip={'enbale':0, 'p':0.5}, 117 | random_rotation={'enbale':0, 'degrees':15}, 118 | color_jitter={'enable':0,'brightness':0.1, 'contrast':0.1, 'saturation':0.1, 'hue':0.1}, 119 | mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], *args, **kwargs): 120 | super(_DefaultTransforms, self).__init__(is_train, log_name) 121 | self.is_train = is_train 122 | self.mean = mean 123 | self.std = std 124 | self.img_size = img_size 125 | self.min_edge_size = min(self.img_size) 126 | self.normalize = transforms.Normalize(self.mean, self.std) 127 | self.aug_imagenet = aug_imagenet 128 | self.aug_cifar = aug_cifar 129 | self.random_resized_crop = random_resized_crop 130 | self.resize = resize 131 | self.random_crop = random_crop 132 | self.center_crop = center_crop 133 | self.random_horizontal_flip = random_horizontal_flip 134 | self.random_vertical_flip = random_vertical_flip 135 | self.random_rotation = random_rotation 136 | self.color_jitter = color_jitter 137 | self.transform = self.get_transform() 138 | 139 | @property 140 | def valid_transform(self): 141 | transform = transforms.Compose([ 142 | transforms.Resize(self.img_size), 143 | 
144 |             self.normalize
145 |         ])
146 |         return transform
147 |
148 |     @property
149 |     def train_transform(self):
150 |         # aug_imagenet
151 |         if self.aug_imagenet:
152 |             self.logger_print.info('Using imagenet augmentation')
153 |             transform = transforms.Compose([
154 |                 transforms.Resize(self.img_size),
155 |                 autoaugment.ImageNetPolicy(),
156 |                 transforms.ToTensor(),
157 |                 self.normalize
158 |             ])
159 |         # aug cifar
160 |         elif self.aug_cifar:
161 |             self.logger_print.info('Using cifar augmentation')
162 |             transform = transforms.Compose([
163 |                 transforms.Resize(self.img_size),
164 |                 autoaugment.CIFAR10Policy(),
165 |                 transforms.ToTensor(),
166 |                 self.normalize
167 |             ])
168 |         # customized transformations
169 |         else:
170 |             transform = self.read_transform_from_cfg()
171 |         return transform
172 |
173 |     def read_transform_from_cfg(self):
174 |         transform_list = []
175 |         self.check_conflict_options()
176 |
177 |         # resize and crop operation
178 |         if self.random_resized_crop['enable']:
179 |             transform_list.append(transforms.RandomResizedCrop(self.img_size))
180 |         elif self.resize['enable']:
181 |             transform_list.append(transforms.Resize(self.img_size))
182 |         if self.random_crop['enable']:
183 |             transform_list.append(transforms.RandomCrop(self.min_edge_size, padding=self.random_crop['padding']))
184 |         elif self.center_crop['enable']:
185 |             transform_list.append(transforms.CenterCrop(self.min_edge_size))
186 |
187 |         # ColorJitter
188 |         if self.color_jitter['enable']:
189 |             params = {key: self.color_jitter[key] for key in self.color_jitter
190 |                       if key != 'enable'}
191 |             transform_list.append(transforms.ColorJitter(**params))
192 |
193 |         # horizontal flip
194 |         if self.random_horizontal_flip['enable']:
195 |             p = self.random_horizontal_flip['p']
196 |             transform_list.append(transforms.RandomHorizontalFlip(p))
197 |
198 |         # vertical flip
199 |         if self.random_vertical_flip['enable']:
200 |             p = self.random_vertical_flip['p']
201 |             transform_list.append(transforms.RandomVerticalFlip(p))
202 |
203 |         # rotation
204 |         if self.random_rotation['enable']:
205 |             degrees = self.random_rotation['degrees']
206 |             transform_list.append(transforms.RandomRotation(degrees))
207 |         transform_list.append(transforms.ToTensor())
208 |         transform_list.append(self.normalize)
209 |         transform_list = transforms.Compose(transform_list)
210 |         assert len(transform_list.transforms) > 0, "You must apply transformations"
211 |         return transform_list
212 |
213 |     def check_conflict_options(self):
214 |         count = self.random_resized_crop['enable'] + \
215 |                 self.resize['enable']
216 |         assert count <= 1, 'You can only use one resize transform operation'
217 |
218 |         count = self.random_crop['enable'] + \
219 |                 self.center_crop['enable']
220 |         assert count <= 1, 'You can only use one crop transform operation'
--------------------------------------------------------------------------------
/torchline/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .default_module import *
2 | from .build import *
3 | from .utils import *
--------------------------------------------------------------------------------
/torchline/engine/build.py:
--------------------------------------------------------------------------------
1 |
2 | from torchline.utils import Registry, Logger
3 |
4 | MODULE_REGISTRY = Registry('MODULE')
5 | MODULE_REGISTRY.__doc__ = """
6 | Registry for module template, e.g. DefaultModule.
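A new template can be registered with the decorator and then selected via
`cfg.module.name`, e.g. (a sketch; `MyModule` is a hypothetical name):

    @MODULE_REGISTRY.register()
    class MyModule(DefaultModule):
        ...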
7 |
8 | The registered object will be called with `obj(cfg)`
9 | """
10 |
11 | __all__ = [
12 |     'MODULE_REGISTRY',
13 |     'build_module'
14 | ]
15 |
16 | def build_module(cfg):
17 |     """
18 |     Build the module template, defined by `cfg.module.name`.
19 |     """
20 |     name = cfg.module.name
21 |     return MODULE_REGISTRY.get(name)(cfg)
22 |
23 |
--------------------------------------------------------------------------------
/torchline/engine/default_module.py:
--------------------------------------------------------------------------------
1 | """
2 | Example template for defining a system
3 | """
4 | import logging
5 | import os
6 | from argparse import ArgumentParser
7 | from collections import OrderedDict
8 |
9 | import pytorch_lightning as pl
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torchvision
14 | import torchvision.transforms as transforms
15 | from pytorch_lightning import LightningModule
16 | from sklearn import metrics
17 | from torch import optim
18 | from torch.utils.data import DataLoader
19 | from torch.utils.data.distributed import DistributedSampler
20 |
21 | from torchline.data import build_data, build_sampler
22 | from torchline.losses import build_loss_fn
23 | from torchline.models import build_model
24 | from torchline.utils import AverageMeterGroup, topk_acc
25 |
26 | from .build import MODULE_REGISTRY
27 | from .utils import generate_optimizer, generate_scheduler
28 |
29 | __all__ = [
30 |     'DefaultModule'
31 | ]
32 |
33 | @MODULE_REGISTRY.register()
34 | class DefaultModule(LightningModule):
35 |     """
36 |     Sample model to show how to define a template
37 |     """
38 |
39 |     def __init__(self, cfg):
40 |         """
41 |         Pass in the parsed config to set up the module
42 |         :param cfg: yacs CfgNode
43 |         """
44 |         # init superclass
45 |         super(DefaultModule, self).__init__()
46 |         self.cfg = cfg
47 |         self.hparams = self.cfg.hparams
48 |         self.batch_size = self.cfg.dataset.batch_size
49 |
50 |         # if you specify an example input, the summary will show input/output for each layer
51 |         h, w = self.cfg.input.size
52 |         self.example_input_array = torch.rand(1, 3, h, w)
53 |         self.crt_batch_idx = 0
54 |         self.inputs = self.example_input_array
55 |
56 |         # build model
57 |         self.model = self.build_model(cfg)
58 |         self.train_meters = AverageMeterGroup()
59 |         self.valid_meters = AverageMeterGroup()
60 |
61 |     # ---------------------
62 |     # model SETUP
63 |     # ---------------------
64 |     def build_model(self, cfg):
65 |         """
66 |         Layout model
67 |         :return:
68 |         """
69 |         return build_model(cfg)
70 |
71 |     def build_loss_fn(self, cfg):
72 |         """
73 |         Layout loss_fn
74 |         :return:
75 |         """
76 |         return build_loss_fn(cfg)
77 |
78 |     def build_data(self, cfg, is_train):
79 |         """
80 |         Layout training dataset
81 |         :return:
82 |         """
83 |         cfg.defrost()
84 |         cfg.dataset.is_train = is_train
85 |         cfg.freeze()
86 |         return build_data(cfg)
87 |
88 |     def build_sampler(self, cfg, is_train):
89 |         """
90 |         Layout training sampler
91 |         :return:
92 |         """
93 |         cfg.defrost()
94 |         cfg.dataset.is_train = is_train
95 |         cfg.freeze()
96 |         return build_sampler(cfg)
97 |
98 |     # ---------------------
99 |     # Hooks
100 |     # ---------------------
101 |
102 |     def on_train_start(self):
103 |         ckpt_path = self.trainer.resume_from_checkpoint
104 |         # ckpt_path is None unless trainer.resume_from_checkpoint is set
105 |         if ckpt_path:
106 |             if os.path.exists(ckpt_path):
107 |                 ckpt = torch.load(ckpt_path)
108 |                 best = ckpt['checkpoint_callback_best']
109 |                 self.log_info(f"The best result of the ckpt is {best}")
110 |             else:
111 |                 print(f'{ckpt_path} does not exist')
112 |                 raise FileNotFoundError(ckpt_path)
113 |
114 |     def on_epoch_start(self):
115 |         if not self.cfg.trainer.progress_bar_refresh_rate:
116 |             # print current lr
117 |             if isinstance(self.trainer.optimizers, list):
118 |                 if len(self.trainer.optimizers) == 1:
119 |                     optimizer = self.trainer.optimizers[0]
120 |                     lr = optimizer.param_groups[0]["lr"]
121 |                     self.log_info(f"lr={lr:.4e}")
122 |                 else:
123 |                     for index, optimizer in enumerate(self.trainer.optimizers):
124 |                         lr = optimizer.param_groups[0]["lr"]
125 |                         name = str(optimizer).split('(')[0].strip()
126 |                         self.log_info(f"lr of {name}_{index} is {lr:.4e} ")
127 |             else:
128 |                 lr = self.trainer.optimizers.param_groups[0]["lr"]
129 |                 self.log_info(f"lr={lr:.4e}")
130 |
131 |     def on_epoch_end(self):
132 |         if not self.cfg.trainer.progress_bar_refresh_rate:
133 |             self.log_info(f'Final Train: {self.train_meters}')
134 |             self.log_info(f'Final Valid: {self.valid_meters}')
135 |             self.log_info("===========================\n")
136 |         self.train_meters = AverageMeterGroup()
137 |         self.valid_meters = AverageMeterGroup()
138 |
139 |     def log_info(self, *args, **kwargs):
140 |         if self.trainer.proc_rank == 0:
141 |             self.trainer.logger_print.info(*args, **kwargs)
142 |
143 |     def log_warning(self, *args, **kwargs):
144 |         if self.trainer.proc_rank == 0:
145 |             self.trainer.logger_print.warning(*args, **kwargs)
146 |     # ---------------------
147 |     # TRAINING
148 |     # ---------------------
149 |
150 |     def forward(self, x):
151 |         """
152 |         No special modification required for lightning, define as you normally would
153 |         :param x:
154 |         :return:
155 |
156 |         return middle features
157 |         features = self.model.features(x)
158 |         logits = self.model.logits(features)
159 |         return logits
160 |         """
161 |         return self.model(x)
162 |
163 |     def loss(self, predictions, gt_labels):
164 |         loss_fn = self.build_loss_fn(self.cfg)
165 |         return loss_fn(predictions, gt_labels)
166 |
167 |     def print_log(self, batch_idx, is_train, inputs, meters, save_examples=False):
168 |         flag = batch_idx % self.cfg.trainer.log_save_interval == 0
169 |         if not self.trainer.progress_bar_refresh_rate and flag:
170 |             if is_train:
171 |                 _type = 'Train'
172 |                 all_step = self.trainer.num_training_batches
173 |             else:
174 |                 _type = 'Valid'
175 |                 all_step = self.trainer.num_val_batches
176 |             crt_epoch, crt_step = self.trainer.current_epoch, batch_idx
177 |             all_epoch = self.trainer.max_epochs
178 |             log_info = f"{_type} Epoch {crt_epoch}/{all_epoch} step {crt_step}/{all_step} {meters}"
179 |             self.log_info(log_info)
180 |
181 |             if self.current_epoch==0 and batch_idx==0 and save_examples:
182 |                 if not os.path.exists('train_valid_samples'):
183 |                     os.makedirs('train_valid_samples')
184 |                 for i, img in enumerate(inputs[:5]):
185 |                     torchvision.transforms.ToPILImage()(img.cpu()).save(f'./train_valid_samples/{_type}_img{i}.jpg')
186 |
187 |     def training_step(self, batch, batch_idx):
188 |         """
189 |         Lightning calls this inside the training loop
190 |         :param batch:
191 |         :return:
192 |         """
193 |
194 |         # forward pass
195 |         inputs, gt_labels = batch
196 |         predictions = self.forward(inputs)
197 |
198 |         # calculate loss
199 |         loss_val = self.loss(predictions, gt_labels)
200 |
201 |         # acc
202 |         acc_results = topk_acc(predictions, gt_labels, self.cfg.topk)
203 |         tqdm_dict = {}
204 |
205 |         if self.on_gpu:
206 |             acc_results = [torch.tensor(x).to(loss_val.device.index) for x in acc_results]
207 |
208 |         # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
209 |         if self.trainer.use_dp or self.trainer.use_ddp2:
210 |             loss_val
= loss_val.unsqueeze(0) 211 | acc_results = [x.unsqueeze(0) for x in acc_results] 212 | 213 | tqdm_dict['train_loss'] = loss_val 214 | for i, k in enumerate(self.cfg.topk): 215 | tqdm_dict[f'train_acc_{k}'] = acc_results[i] 216 | 217 | output = OrderedDict({ 218 | 'loss': loss_val, 219 | 'progress_bar': tqdm_dict, 220 | 'log': tqdm_dict 221 | }) 222 | 223 | self.train_meters.update({key: val.item() for key, val in tqdm_dict.items()}) 224 | # self.print_log(batch_idx, True, inputs, self.train_meters) 225 | 226 | # can also return just a scalar instead of a dict (return loss_val) 227 | return output 228 | 229 | def validation_step(self, batch, batch_idx): 230 | """ 231 | Lightning calls this inside the validation loop 232 | :param batch: 233 | :return: 234 | """ 235 | inputs, gt_labels = batch 236 | predictions = self.forward(inputs) 237 | 238 | loss_val = self.loss(predictions, gt_labels) 239 | 240 | # acc 241 | val_acc_1, val_acc_k = topk_acc(predictions, gt_labels, self.cfg.topk) 242 | 243 | if self.on_gpu: 244 | val_acc_1 = val_acc_1.cuda(loss_val.device.index) 245 | val_acc_k = val_acc_k.cuda(loss_val.device.index) 246 | 247 | # in DP mode (default) make sure if result is scalar, there's another dim in the beginning 248 | if self.trainer.use_dp or self.trainer.use_ddp2: 249 | loss_val = loss_val.unsqueeze(0) 250 | val_acc_1 = val_acc_1.unsqueeze(0) 251 | val_acc_k = val_acc_k.unsqueeze(0) 252 | 253 | output = OrderedDict({ 254 | 'valid_loss': loss_val, 255 | 'valid_acc_1': val_acc_1, 256 | f'valid_acc_{self.cfg.topk[-1]}': val_acc_k, 257 | }) 258 | tqdm_dict = {k: v for k, v in dict(output).items()} 259 | self.valid_meters.update({key: val.item() for key, val in tqdm_dict.items()}) 260 | # self.print_log(batch_idx, False, inputs, self.valid_meters) 261 | 262 | if self.cfg.module.analyze_result: 263 | output.update({ 264 | 'predictions': predictions.detach(), 265 | 'gt_labels': gt_labels.detach(), 266 | }) 267 | # can also return just a scalar instead of a dict (return loss_val) 268 | return output 269 | 270 | def training_step_end(self, output): 271 | self.print_log(self.trainer.batch_idx, True, self.inputs, self.train_meters) 272 | return output 273 | 274 | def validation_step_end(self, output): 275 | self.crt_batch_idx += 1 276 | self.print_log(self.crt_batch_idx, False, self.inputs, self.valid_meters) 277 | return output 278 | 279 | def validation_epoch_end(self, outputs): 280 | """ 281 | Called at the end of validation to aggregate outputs 282 | :param outputs: list of individual outputs of each validation step 283 | :return: 284 | """ 285 | # if returned a scalar from validation_step, outputs is a list of tensor scalars 286 | # we return just the average in this case (if we want) 287 | # return torch.stack(outputs).mean() 288 | 289 | self.crt_batch_idx = 0 290 | tqdm_dict = {key: val.avg for key, val in self.valid_meters.meters.items()} 291 | valid_loss = torch.tensor(self.valid_meters.meters['valid_loss'].avg) 292 | valid_acc_1 = torch.tensor(self.valid_meters.meters['valid_acc_1'].avg) 293 | result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 294 | 'valid_loss': valid_loss, 295 | 'valid_acc_1': valid_acc_1} 296 | 297 | if self.cfg.module.analyze_result: 298 | predictions = [] 299 | gt_labels = [] 300 | for output in outputs: 301 | predictions.append(output['predictions']) 302 | gt_labels.append(output['gt_labels']) 303 | predictions = torch.cat(predictions) 304 | gt_labels = torch.cat(gt_labels) 305 | analyze_result = self.analyze_result(gt_labels, predictions) 306 | 
self.log_info(analyze_result)
307 |             # result.update({'analyze_result': analyze_result})
308 |         return result
309 |
310 |     def test_step(self, batch, batch_idx):
311 |         return self.validation_step(batch, batch_idx)
312 |
313 |     def test_epoch_end(self, outputs):
314 |         return self.validation_epoch_end(outputs)
315 |
316 |     def analyze_result(self, gt_labels, predictions):
317 |         '''
318 |         Args:
319 |             gt_labels: tensor (N)
320 |             predictions: tensor (N*C)
321 |         '''
322 |         return str(metrics.classification_report(gt_labels.cpu(), predictions.cpu().argmax(1)))
323 |
324 |     # ---------------------
325 |     # TRAINING SETUP
326 |     # ---------------------
327 |     @classmethod
328 |     def parse_cfg_for_scheduler(cls, cfg, scheduler_name):
329 |         if scheduler_name.lower() == 'CosineAnnealingLR'.lower():
330 |             params = {'T_max': cfg.optim.scheduler.t_max}
331 |         elif scheduler_name.lower() == 'CosineAnnealingWarmRestarts'.lower():
332 |             params = {'T_0': cfg.optim.scheduler.t_0, 'T_mult': cfg.optim.scheduler.t_mul}
333 |         elif scheduler_name.lower() == 'StepLR'.lower():
334 |             params = {'step_size': cfg.optim.scheduler.step_size, 'gamma': cfg.optim.scheduler.gamma}
335 |         elif scheduler_name.lower() == 'MultiStepLR'.lower():
336 |             params = {'milestones': cfg.optim.scheduler.milestones, 'gamma': cfg.optim.scheduler.gamma}
337 |         elif scheduler_name.lower() == 'ReduceLROnPlateau'.lower():
338 |             params = {'mode': cfg.optim.scheduler.mode, 'patience': cfg.optim.scheduler.patience,
339 |                       'verbose': cfg.optim.scheduler.verbose, 'factor': cfg.optim.scheduler.gamma}
340 |         else:
341 |             print(f"{scheduler_name} not implemented")
342 |             raise NotImplementedError
343 |         return params
344 |
345 |     def configure_optimizers(self):
346 |         """
347 |         return whatever optimizers we want here
348 |         :return: list of optimizers
349 |         """
350 |         optim_name = self.cfg.optim.name
351 |         momentum = self.cfg.optim.momentum
352 |         weight_decay = self.cfg.optim.weight_decay
353 |         lr = self.cfg.optim.base_lr
354 |         optimizer = generate_optimizer(self.model, optim_name, lr, momentum, weight_decay)
355 |         scheduler_params = self.parse_cfg_for_scheduler(self.cfg, self.cfg.optim.scheduler.name)
356 |         scheduler = generate_scheduler(optimizer, self.cfg.optim.scheduler.name, **scheduler_params)
357 |         return [optimizer], [scheduler]
358 |
359 |     def __dataloader(self, is_train):
360 |         # init data generators
361 |         dataset = self.build_data(self.cfg, is_train)
362 |
363 |         # when using multi-node (ddp) we need to add the datasampler
364 |         train_sampler = self.build_sampler(self.cfg, is_train)
365 |         batch_size = self.batch_size
366 |
367 |         if self.use_ddp:
368 |             train_sampler = DistributedSampler(dataset)
369 |
370 |         should_shuffle = train_sampler is None
371 |         loader = DataLoader(
372 |             dataset=dataset,
373 |             batch_size=batch_size,
374 |             shuffle=should_shuffle if is_train else False,
375 |             sampler=train_sampler,
376 |             num_workers=self.cfg.dataloader.num_workers
377 |         )
378 |
379 |         return loader
380 |
381 |     @pl.data_loader
382 |     def train_dataloader(self):
383 |         logging.info('training data loader called')
384 |         return self.__dataloader(is_train=True)
385 |
386 |     @pl.data_loader
387 |     def val_dataloader(self):
388 |         logging.info('val data loader called')
389 |         return self.__dataloader(is_train=False)
390 |
391 |     @pl.data_loader
392 |     def test_dataloader(self):
393 |         logging.info('test data loader called')
394 |         return self.__dataloader(is_train=False)
--------------------------------------------------------------------------------
/torchline/engine/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | __all__ = [
4 |     'generate_optimizer',
5 |     'generate_scheduler'
6 | ]
7 |
8 | def generate_optimizer(model, optim_name, lr, momentum=0.9, weight_decay=1e-5):
9 |     '''
10 |     return torch.optim.Optimizer
11 |     '''
12 |     if optim_name.lower() == 'sgd':
13 |         return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
14 |     elif optim_name.lower() == 'adadelta':
15 |         return torch.optim.Adadelta(model.parameters(), lr=lr, weight_decay=weight_decay)
16 |     elif optim_name.lower() == 'adam':
17 |         return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
18 |     elif optim_name.lower() == 'rmsprop':
19 |         return torch.optim.RMSprop(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
20 |     else:
21 |         print(f"{optim_name} not implemented")
22 |         raise NotImplementedError
23 |
24 | def generate_scheduler(optimizer, scheduler_name, **params):
25 |     '''
26 |     return torch.optim.lr_scheduler
27 |     '''
28 |     if scheduler_name.lower() == 'CosineAnnealingLR'.lower():
29 |         return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **params)
30 |     elif scheduler_name.lower() == 'CosineAnnealingWarmRestarts'.lower():
31 |         return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **params)
32 |     elif scheduler_name.lower() == 'StepLR'.lower():
33 |         return torch.optim.lr_scheduler.StepLR(optimizer, **params)
34 |     elif scheduler_name.lower() == 'MultiStepLR'.lower():
35 |         return torch.optim.lr_scheduler.MultiStepLR(optimizer, **params)
36 |     elif scheduler_name.lower() == 'ReduceLROnPlateau'.lower():
37 |         return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **params)
38 |     else:
39 |         print(f"{scheduler_name} not implemented")
40 |         raise NotImplementedError
--------------------------------------------------------------------------------
/torchline/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import *
2 | from .loss import *
3 | from .focal_loss import *
--------------------------------------------------------------------------------
/torchline/losses/build.py:
--------------------------------------------------------------------------------
1 | from torchline.utils import Registry
2 |
3 | LOSS_FN_REGISTRY = Registry("LOSS_FN")
4 | LOSS_FN_REGISTRY.__doc__ = """
5 | Registry for loss function, e.g. cross entropy loss.
6 |
7 | The registered object will be called with `obj(cfg)`
8 | """
9 |
10 | __all__ = [
11 |     'build_loss_fn',
12 |     'LOSS_FN_REGISTRY'
13 | ]
14 |
15 | def build_loss_fn(cfg):
16 |     """
17 |     Build the loss function, defined by `cfg.loss.name`.
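    Example (a sketch; 'CrossEntropy' and 'FocalLoss' are registered in
    loss.py and focal_loss.py of this package):

        >>> cfg.loss.name = 'FocalLoss'
        >>> loss_fn = build_loss_fn(cfg)
        >>> loss = loss_fn(predictions, gt_labels)  # placeholder tensors: logits and labels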
18 | """ 19 | name = cfg.loss.name 20 | return LOSS_FN_REGISTRY.get(name)(cfg) 21 | -------------------------------------------------------------------------------- /torchline/losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from .build import LOSS_FN_REGISTRY 6 | 7 | __all__ = [ 8 | 'FocalLoss', 9 | '_FocalLoss' 10 | ] 11 | 12 | @LOSS_FN_REGISTRY.register() 13 | def FocalLoss(cfg): 14 | alpha = cfg.loss.focal_loss.alpha 15 | gamma = cfg.loss.focal_loss.gamma 16 | size_average = cfg.loss.focal_loss.size_average 17 | num_classes = cfg.model.classes 18 | return _FocalLoss(alpha, gamma, num_classes, size_average) 19 | 20 | class _FocalLoss(torch.nn.Module): 21 | def __init__(self, alpha, gamma, num_classes, size_average=True): 22 | """focal_loss function: -α(1-yi)**γ *ce_loss(xi,yi) 23 | Args: 24 | alpha: class weight (default 0.25). 25 | When α is a 'list', it indicates the class-wise weights; 26 | When α is a constant, 27 | if in detection task, it indicates that the class-wise weights are[α, 1-α, 1-α, ...], 28 | the first class indicates the background 29 | if in classification task, it indicates that the class-wise weights are the same 30 | gamma: γ (default 2), focusing paramter smoothly adjusts the rate at which easy examples are down-weighted. 31 | num_classes: the number of classes 32 | size_average: (default 'mean'/'sum') specify the way to compute the loss value 33 | """ 34 | 35 | super(_FocalLoss,self).__init__() 36 | self.size_average = size_average 37 | if isinstance(alpha,list): 38 | assert len(alpha)==num_classes 39 | alpha /= np.sum(alpha) # setting the value in range of [0, 1] 40 | # print("Focal loss alpha = {}, assign specific weights for each class".format(alpha)) 41 | self.alpha = torch.Tensor(alpha) 42 | else: 43 | assert alpha<=1 44 | self.alpha = torch.zeros(num_classes)+0.00001 45 | 46 | # classification task 47 | # print("Focal loss alpha={}, the weight for each class is the same".format(alpha)) 48 | self.alpha += alpha 49 | 50 | # detection task # 如果α为一个常数,则降低第一类的影响,在目标检测中背景为第一类 51 | # print(" --- Focal_loss alpha = {} ,将对背景类进行衰减,请在目标检测任务中使用 --- ".format(alpha)) 52 | # self.alpha[0] += alpha 53 | # self.alpha[1:] += (1-alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes] 54 | self.gamma = gamma 55 | 56 | def forward(self, predictions, labels): 57 | """ 58 | focal_loss损失计算 59 | Args: 60 | preds: 预测类别. size:[B,N,C] or [B,C] 分别对应与检测与分类任务, B 批次, N检测框数, C类别数 61 | labels: 实际类别. 
size:[B,N] or [B] 62 | return: 63 | loss 64 | """ 65 | assert predictions.dim()==2 and labels.dim()==1 66 | preds = predictions.view(-1,predictions.size(-1)) # num*classes 67 | alpha = self.alpha.to(labels.device) 68 | 69 | # 这里并没有直接使用log_softmax, 因为后面会用到softmax的结果(当然你也可以使用log_softmax,然后进行exp操作) 70 | preds_softmax = F.softmax(preds, dim=1) 71 | preds_logsoft = torch.log(preds_softmax) 72 | 73 | # implement nll_loss ( crossempty = log_softmax + nll ) 74 | preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) # num*1 75 | preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1)) # num*1 76 | alpha = alpha.gather(0,labels.view(-1)) # num 77 | 78 | # calc loss 79 | # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ 80 | # shape: num*1 81 | loss = -torch.mul( torch.pow((1-preds_softmax), self.gamma), preds_logsoft ) 82 | # α * (1-pt)**γ * ce_loss 83 | # shape: 84 | loss = torch.mul(alpha, loss.t()) 85 | del preds 86 | del alpha 87 | if self.size_average: 88 | loss = loss.mean() 89 | else: 90 | loss = loss.sum() 91 | return loss -------------------------------------------------------------------------------- /torchline/losses/loss.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import torch 3 | import torch.nn.functional as F 4 | from .build import LOSS_FN_REGISTRY 5 | 6 | __all__ = [ 7 | 'CrossEntropy', 8 | 'CrossEntropyLabelSmooth', 9 | '_CrossEntropyLabelSmooth' 10 | ] 11 | 12 | @LOSS_FN_REGISTRY.register() 13 | def CrossEntropy(cfg): 14 | weight = cfg.loss.class_weight 15 | if weight in ['', None, []]: 16 | weight = None 17 | else: 18 | weight = torch.tensor(weight) 19 | if torch.cuda.is_available(): weight=weight.cuda() 20 | return torch.nn.CrossEntropyLoss(weight=weight) 21 | 22 | @LOSS_FN_REGISTRY.register() 23 | def CrossEntropyLabelSmooth(cfg): 24 | try: 25 | label_smoothing = cfg.loss.label_smoothing 26 | except: 27 | label_smoothing = 0.1 28 | return _CrossEntropyLabelSmooth(label_smoothing) 29 | 30 | class _CrossEntropyLabelSmooth(torch.nn.Module): 31 | def __init__(self, label_smoothing): 32 | super(_CrossEntropyLabelSmooth, self).__init__() 33 | self.label_smoothing = label_smoothing 34 | 35 | def forward(self, pred, target): 36 | logsoftmax = torch.nn.LogSoftmax(dim=1) 37 | n_classes = pred.size(1) 38 | # convert to one-hot 39 | target = torch.unsqueeze(target, 1) 40 | soft_target = torch.zeros_like(pred) 41 | soft_target.scatter_(1, target, 1) 42 | # label smoothing 43 | soft_target = soft_target * (1 - self.label_smoothing) + self.label_smoothing / n_classes 44 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1)) -------------------------------------------------------------------------------- /torchline/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import * 2 | from .resnet import * 3 | from .pnasnet import * 4 | from .efficientnet import * 5 | from .dpn import * 6 | from .nas_models import * 7 | from .mnasnet import * 8 | from .mobilenet import * 9 | 10 | model_list = [ 11 | 'DPN26', 12 | 'DPN92', 13 | 'EfficientNetB0', 14 | 'MNASNet0_5', 15 | 'MNASNet0_75', 16 | 'MNASNet1_0', 17 | 'MNASNet1_3', 18 | 'MobileNet_V2', 19 | 'Nasnetamobile', 20 | 'Pnasnet5large', 21 | 'PNASNetA', 22 | 'PNASNetB', 23 | 'Resnet18', 24 | 'Resnet34', 25 | 'Resnet50', 26 | 'Resnet101', 27 | 'Resnet152', 28 | 'Resnext50_32x4d', 29 | 'Resnext101_32x8d', 30 | 'Wide_resnet50_2', 31 | 'Wide_resnet101_2', 32 | ] 
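# Every name in model_list is registered in META_ARCH_REGISTRY (see build.py),
# so any of them can be built from the config alone, e.g. (a sketch):
#   cfg.model.name = 'Resnet50'
#   model = build_model(cfg)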
--------------------------------------------------------------------------------
/torchline/models/build.py:
--------------------------------------------------------------------------------
1 | from torchline.utils import Registry
2 |
3 | META_ARCH_REGISTRY = Registry("META_ARCH")
4 | META_ARCH_REGISTRY.__doc__ = """
5 | Registry for meta-architectures, i.e. the whole model.
6 |
7 | The registered object will be called with `obj(cfg)`
8 | and expected to return a `nn.Module` object.
9 | """
10 |
11 | __all__ = [
12 |     'build_model',
13 |     'META_ARCH_REGISTRY'
14 | ]
15 |
16 | def build_model(cfg):
17 |     """
18 |     Build the whole model, defined by `cfg.model.name`.
19 |     """
20 |     name = cfg.model.name
21 |     return META_ARCH_REGISTRY.get(name)(cfg)
22 |
--------------------------------------------------------------------------------
/torchline/models/dpn.py:
--------------------------------------------------------------------------------
1 | '''Dual Path Networks in PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .build import META_ARCH_REGISTRY
7 |
8 | __all__ = [
9 |     'DPN',
10 |     'DPN26',
11 |     'DPN92'
12 | ]
13 |
14 | class Bottleneck(nn.Module):
15 |     def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer):
16 |         super(Bottleneck, self).__init__()
17 |         self.out_planes = out_planes
18 |         self.dense_depth = dense_depth
19 |
20 |         self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False)
21 |         self.bn1 = nn.BatchNorm2d(in_planes)
22 |         self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False)
23 |         self.bn2 = nn.BatchNorm2d(in_planes)
24 |         self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False)
25 |         self.bn3 = nn.BatchNorm2d(out_planes+dense_depth)
26 |
27 |         self.shortcut = nn.Sequential()
28 |         if first_layer:
29 |             self.shortcut = nn.Sequential(
30 |                 nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False),
31 |                 nn.BatchNorm2d(out_planes+dense_depth)
32 |             )
33 |
34 |     def forward(self, x):
35 |         out = F.relu(self.bn1(self.conv1(x)))
36 |         out = F.relu(self.bn2(self.conv2(out)))
37 |         out = self.bn3(self.conv3(out))
38 |         x = self.shortcut(x)
39 |         d = self.out_planes
40 |         out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1)
41 |         out = F.relu(out)
42 |         return out
43 |
44 | @META_ARCH_REGISTRY.register()
45 | class DPN(nn.Module):
46 |     def __init__(self, cfg, num_classes=10):
47 |         super(DPN, self).__init__()
48 |         in_planes, out_planes = cfg['in_planes'], cfg['out_planes']
49 |         num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth']
50 |
51 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
52 |         self.bn1 = nn.BatchNorm2d(64)
53 |         self.last_planes = 64
54 |         self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1)
55 |         self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2)
56 |         self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2)
57 |         self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2)
58 |         self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
59 |         self.last_linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], num_classes)
60 |
61 |     def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride):
62 |         strides = [stride] + [1]*(num_blocks-1)
63 |
layers = [] 64 | for i,stride in enumerate(strides): 65 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 66 | self.last_planes = out_planes + (i+2) * dense_depth 67 | return nn.Sequential(*layers) 68 | 69 | def features(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layer1(out) 72 | out = self.layer2(out) 73 | out = self.layer3(out) 74 | out = self.layer4(out) 75 | return out 76 | 77 | def logits(self, x): 78 | out = self.g_avg_pool(x) 79 | out = self.last_linear(out.view(out.size(0), -1)) 80 | return out 81 | 82 | def forward(self, x): 83 | out = self.features(x) 84 | out = self.logits(out) 85 | return out 86 | 87 | @META_ARCH_REGISTRY.register() 88 | class DPN26(DPN): 89 | def __init__(self, cfg): 90 | num_classes = cfg.model.classes 91 | net_cfg = { 92 | 'in_planes': (96,192,384,768), 93 | 'out_planes': (256,512,1024,2048), 94 | 'num_blocks': (2,2,2,2), 95 | 'dense_depth': (16,32,24,128) 96 | } 97 | super(DPN26, self).__init__(net_cfg, num_classes) 98 | 99 | @META_ARCH_REGISTRY.register() 100 | class DPN92(DPN): 101 | def __init__(self, cfg): 102 | num_classes = cfg.model.classes 103 | net_cfg = { 104 | 'in_planes': (96,192,384,768), 105 | 'out_planes': (256,512,1024,2048), 106 | 'num_blocks': (3,4,20,3), 107 | 'dense_depth': (16,32,24,128) 108 | } 109 | super(DPN92, self).__init__(net_cfg, num_classes) 110 | -------------------------------------------------------------------------------- /torchline/models/efficientnet.py: -------------------------------------------------------------------------------- 1 | '''EfficientNet in PyTorch. 2 | 3 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks". 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | from .build import META_ARCH_REGISTRY 11 | 12 | __all__ = [ 13 | 'EfficientNet', 14 | 'EfficientNetB0', 15 | ] 16 | 17 | class Block(nn.Module): 18 | '''expand + depthwise + pointwise + squeeze-excitation''' 19 | 20 | def __init__(self, in_planes, out_planes, expansion, stride): 21 | super(Block, self).__init__() 22 | self.stride = stride 23 | 24 | planes = expansion * in_planes 25 | self.conv1 = nn.Conv2d( 26 | in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 27 | self.bn1 = nn.BatchNorm2d(planes) 28 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 29 | stride=stride, padding=1, groups=planes, bias=False) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.conv3 = nn.Conv2d( 32 | planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 33 | self.bn3 = nn.BatchNorm2d(out_planes) 34 | 35 | self.shortcut = nn.Sequential() 36 | if stride == 1 and in_planes != out_planes: 37 | self.shortcut = nn.Sequential( 38 | nn.Conv2d(in_planes, out_planes, kernel_size=1, 39 | stride=1, padding=0, bias=False), 40 | nn.BatchNorm2d(out_planes), 41 | ) 42 | 43 | # SE layers 44 | self.fc1 = nn.Conv2d(out_planes, out_planes//16, kernel_size=1) 45 | self.fc2 = nn.Conv2d(out_planes//16, out_planes, kernel_size=1) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = F.relu(self.bn2(self.conv2(out))) 50 | out = self.bn3(self.conv3(out)) 51 | shortcut = self.shortcut(x) if self.stride == 1 else out 52 | # Squeeze-Excitation 53 | w = F.avg_pool2d(out, out.size(2)) 54 | w = F.relu(self.fc1(w)) 55 | w = self.fc2(w).sigmoid() 56 | out = out * w + shortcut 57 | return out 58 | 59 | @META_ARCH_REGISTRY.register() 60 | class EfficientNet(nn.Module): 61 | def 
__init__(self, layers, num_classes=10):
62 |         super(EfficientNet, self).__init__()
63 |         self.layers = layers  # the layer configuration; consumed by _make_layers below
64 |         self.conv1 = nn.Conv2d(3, 32, kernel_size=3,
65 |                                stride=1, padding=1, bias=False)
66 |         self.bn1 = nn.BatchNorm2d(32)
67 |         self.layers = self._make_layers(in_planes=32)  # replaces the config list with the built nn.Sequential
68 |         self.last_linear = nn.Linear(layers[-1][1], num_classes)
69 |
70 |     def _make_layers(self, in_planes):
71 |         layers = []
72 |         for expansion, out_planes, num_blocks, stride in self.layers:
73 |             strides = [stride] + [1]*(num_blocks-1)
74 |             for stride in strides:
75 |                 layers.append(Block(in_planes, out_planes, expansion, stride))
76 |                 in_planes = out_planes
77 |         return nn.Sequential(*layers)
78 |
79 |     def features(self, x): # 2*3*32*32
80 |         out = F.relu(self.bn1(self.conv1(x))) # 2*32*32*32
81 |         out = self.layers(out) # 2*320*1*1
82 |         return out
83 |
84 |     def logits(self, x):
85 |         out = nn.AdaptiveAvgPool2d((1,1))(x) # 2*320*1*1
86 |         out = out.view(out.size(0), -1)
87 |         out = self.last_linear(out)
88 |         return out
89 |
90 |     def forward(self, x): # 2*3*32*32
91 |         out = self.features(x) # 2*320*1*1
92 |         out = self.logits(out) # 2*num_classes
93 |         return out
94 |
95 |
96 | @META_ARCH_REGISTRY.register()
97 | class EfficientNetB0(EfficientNet):
98 |     def __init__(self, cfg):
99 |         num_classes = cfg.model.classes
100 |         # (expansion, out_planes, num_blocks, stride)
101 |         layers = [(1, 16, 1, 2),
102 |                   (6, 24, 2, 1),
103 |                   (6, 40, 2, 2),
104 |                   (6, 80, 3, 2),
105 |                   (6, 112, 3, 1),
106 |                   (6, 192, 4, 2),
107 |                   (6, 320, 1, 2)]
108 |         super(EfficientNetB0, self).__init__(layers, num_classes)
--------------------------------------------------------------------------------
/torchline/models/mnasnet.py:
--------------------------------------------------------------------------------
1 | import types
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from torchvision.models.mnasnet import MNASNet
7 |
8 | from .build import META_ARCH_REGISTRY
9 |
10 | __all__ = [
11 |     'MNASNet',
12 |     'MNASNet0_5',
13 |     'MNASNet0_75',
14 |     'MNASNet1_0',
15 |     'MNASNet1_3',
16 | ]
17 |
18 | class MNASNet(nn.Module):
19 |     # Modify attributes
20 |     def __init__(self, model):
21 |         super(MNASNet, self).__init__()
22 |         for key, val in model.__dict__.items():
23 |             self.__dict__[key] = val
24 |         self.stem = model.layers[:8]
25 |         self.layer1 = model.layers[8]
26 |         self.layer2 = model.layers[9]
27 |         self.layer3 = model.layers[10]
28 |         self.layer4 = model.layers[11]
29 |         self.layer5 = model.layers[12]
30 |         self.layer6 = model.layers[13]
31 |         self.layer7 = model.layers[14:]
32 |         self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
33 |         self.last_linear = model.classifier
34 |
35 |     def features(self, x): # b*3*64*64
36 |         out = self.stem(x) # b*16*32*32
37 |         out = self.layer1(out) # b*16*16*16
38 |         out = self.layer2(out) # b*24*8*8
39 |         out = self.layer3(out) # b*40*4*4
40 |         out = self.layer4(out) # b*48*4*4
41 |         out = self.layer5(out) # b*96*2*2
42 |         out = self.layer6(out) # b*160*2*2
43 |         out = self.layer7(out) # b*1280*2*2
44 |         return out
45 |
46 |     def logits(self, x):
47 |         out = x.mean([2, 3])
48 |         out = self.last_linear(out)
49 |         return out
50 |
51 |     def forward(self, x):
52 |         out = self.features(x)
53 |         out = self.logits(out)
54 |         return out
55 |
56 | def generate_model(cfg, name):
57 |     pretrained=cfg.model.pretrained
58 |     classes = cfg.model.classes
59 |     if 'dropout' in cfg.model:
60 |         dropout = cfg.model.dropout
61 |     else:
62 |         dropout = 0.2
63 |     model = eval(f"models.{name}(pretrained={pretrained})")
64 |     if classes != 1000:
65 |         in_features = model.classifier[1].in_features
66 |         model.classifier = nn.Sequential(
67 |             nn.Dropout(p=dropout, inplace=True),
68 |             nn.Linear(in_features, classes, bias=False))
69 |     return MNASNet(model)
70 |
71 | @META_ARCH_REGISTRY.register()
72 | def MNASNet0_5(cfg):
73 |     return generate_model(cfg, 'mnasnet0_5')
74 |
75 | @META_ARCH_REGISTRY.register()
76 | def MNASNet0_75(cfg):
77 |     return generate_model(cfg, 'mnasnet0_75')
78 |
79 | @META_ARCH_REGISTRY.register()
80 | def MNASNet1_0(cfg):
81 |     return generate_model(cfg, 'mnasnet1_0')
82 |
83 | @META_ARCH_REGISTRY.register()
84 | def MNASNet1_3(cfg):
85 |     return generate_model(cfg, 'mnasnet1_3')
--------------------------------------------------------------------------------
/torchline/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | import types
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from torchvision.models.mobilenet import MobileNetV2
7 |
8 | from .build import META_ARCH_REGISTRY
9 |
10 | __all__ = [
11 |     'MobileNetV2',
12 |     'MobileNet_V2'
13 | ]
14 |
15 |
16 |
17 | class MobileNetV2(nn.Module):
18 |     # Modify attributes
19 |     def __init__(self, model):
20 |         super(MobileNetV2, self).__init__()
21 |         for key, val in model.__dict__.items():
22 |             self.__dict__[key] = val
23 |
24 |         self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
25 |         self.last_linear = self.classifier
26 |
27 |     def logits(self, x):
28 |         out = self.g_avg_pool(x)
29 |         out = out.view(out.size(0), -1)
30 |         out = self.last_linear(out)
31 |         return out
32 |
33 |     def forward(self, x):
34 |         out = self.features(x)
35 |         out = self.logits(out)
36 |         return out
37 |
38 | def generate_model(cfg, name):
39 |     pretrained=cfg.model.pretrained
40 |     classes = cfg.model.classes
41 |     if 'dropout' in cfg.model:
42 |         dropout = cfg.model.dropout
43 |     else:
44 |         dropout = 0.2
45 |     model = eval(f"models.{name}(pretrained={pretrained})")
46 |     if classes != 1000:
47 |         in_features = model.classifier[1].in_features
48 |         model.classifier = nn.Sequential(
49 |             nn.Dropout(p=dropout, inplace=True),
50 |             nn.Linear(in_features, classes, bias=False))
51 |     return MobileNetV2(model)
52 |
53 | @META_ARCH_REGISTRY.register()
54 | def MobileNet_V2(cfg):
55 |     return generate_model(cfg, 'mobilenet_v2')
--------------------------------------------------------------------------------
/torchline/models/nas_models.py:
--------------------------------------------------------------------------------
1 | import pretrainedmodels
2 | import torch.nn as nn
3 |
4 | from torchline.models import META_ARCH_REGISTRY
5 |
6 | __all__ = [
7 |     'Nasnetamobile',
8 |     'Pnasnet5large',
9 | ]
10 |
11 | def generate_model(cfg, name):
12 |     pretrained='imagenet' if cfg.model.pretrained else None
13 |     classes = cfg.model.classes
14 |     img_size = cfg.input.size[0]
15 |     model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained=pretrained)
16 |     model.avg_pool = nn.AdaptiveAvgPool2d(1)
17 |     if classes != 1000:
18 |         in_features = model.last_linear.in_features
19 |         model.last_linear = nn.Sequential(
20 |             nn.Linear(in_features, classes, bias=False))
21 |     return model
22 |
23 | @META_ARCH_REGISTRY.register()
24 | def Nasnetamobile(cfg):
25 |     return generate_model(cfg, 'nasnetamobile')
26 |
27 | @META_ARCH_REGISTRY.register()
28 | def Pnasnet5large(cfg):
29 |     return generate_model(cfg, 'pnasnet5large')
30 |
--------------------------------------------------------------------------------
/torchline/models/pnasnet.py:
--------------------------------------------------------------------------------
1 | '''PNASNet in PyTorch.
2 |
3 | Paper: Progressive Neural Architecture Search
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from .build import META_ARCH_REGISTRY
10 |
11 | __all__ = [
12 |     'PNASNet',
13 |     'PNASNetA',
14 |     'PNASNetB',
15 | ]
16 |
17 | class SepConv(nn.Module):
18 |     '''Separable Convolution.'''
19 |     def __init__(self, in_planes, out_planes, kernel_size, stride):
20 |         super(SepConv, self).__init__()
21 |         self.conv1 = nn.Conv2d(in_planes, out_planes,
22 |                                kernel_size, stride,
23 |                                padding=(kernel_size-1)//2,
24 |                                bias=False, groups=in_planes)
25 |         self.bn1 = nn.BatchNorm2d(out_planes)
26 |
27 |     def forward(self, x):
28 |         return self.bn1(self.conv1(x))
29 |
30 |
31 | class CellA(nn.Module):
32 |     def __init__(self, in_planes, out_planes, stride=1):
33 |         super(CellA, self).__init__()
34 |         self.stride = stride
35 |         self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
36 |         if stride==2:
37 |             self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
38 |             self.bn1 = nn.BatchNorm2d(out_planes)
39 |
40 |     def forward(self, x):
41 |         y1 = self.sep_conv1(x)
42 |         y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
43 |         if self.stride==2:
44 |             y2 = self.bn1(self.conv1(y2))
45 |         return F.relu(y1+y2)
46 |
47 | class CellB(nn.Module):
48 |     def __init__(self, in_planes, out_planes, stride=1):
49 |         super(CellB, self).__init__()
50 |         self.stride = stride
51 |         # Left branch
52 |         self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
53 |         self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride)
54 |         # Right branch
55 |         self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride)
56 |         if stride==2:
57 |             self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
58 |             self.bn1 = nn.BatchNorm2d(out_planes)
59 |         # Reduce channels
60 |         self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
61 |         self.bn2 = nn.BatchNorm2d(out_planes)
62 |
63 |     def forward(self, x):
64 |         # Left branch
65 |         y1 = self.sep_conv1(x)
66 |         y2 = self.sep_conv2(x)
67 |         # Right branch
68 |         y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
69 |         if self.stride==2:
70 |             y3 = self.bn1(self.conv1(y3))
71 |         y4 = self.sep_conv3(x)
72 |         # Concat & reduce channels
73 |         b1 = F.relu(y1+y2)
74 |         b2 = F.relu(y3+y4)
75 |         y = torch.cat([b1,b2], 1)
76 |         return F.relu(self.bn2(self.conv2(y)))
77 |
78 | class PNASNet(nn.Module):
79 |     def __init__(self, cell_type, num_cells, num_planes, num_classes=10):
80 |         super(PNASNet, self).__init__()
81 |         self.in_planes = num_planes
82 |         self.cell_type = cell_type
83 |
84 |         self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
85 |         self.bn1 = nn.BatchNorm2d(num_planes)
86 |
87 |         self.layer1 = self._make_layer(num_planes, num_cells=num_cells)
88 |         self.layer2 = self._downsample(num_planes*2)
89 |         self.layer3 = self._make_layer(num_planes*2, num_cells=num_cells)
90 |         self.layer4 = self._downsample(num_planes*4)
91 |         self.layer5 = self._make_layer(num_planes*4, num_cells=num_cells)
92 |         self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
93 |         self.last_linear = nn.Linear(num_planes*4, num_classes)
94 |
95 |     def _make_layer(self, planes, num_cells):
96 |         layers = []
97 |         for _ in range(num_cells):
98 |             layers.append(self.cell_type(self.in_planes, planes, stride=1))
99 |             self.in_planes = planes
100 |         return nn.Sequential(*layers)
101 | 
102 |     def _downsample(self, planes):
103 |         layer = self.cell_type(self.in_planes, planes, stride=2)
104 |         self.in_planes = planes
105 |         return layer
106 | 
107 |     def features(self, x):
108 |         out = F.relu(self.bn1(self.conv1(x)))
109 |         out = self.layer1(out)
110 |         out = self.layer2(out)
111 |         out = self.layer3(out)
112 |         out = self.layer4(out)
113 |         out = self.layer5(out)
114 |         return out
115 | 
116 |     def logits(self, x):
117 |         out = self.g_avg_pool(x)
118 |         out = self.last_linear(out.view(out.size(0), -1))
119 |         return out
120 | 
121 |     def forward(self, x):
122 |         out = self.features(x)
123 |         out = self.logits(out)
124 |         return out
125 | 
126 | @META_ARCH_REGISTRY.register()
127 | class PNASNetA(PNASNet):
128 |     def __init__(self, cfg):
129 |         num_classes = cfg.model.classes
130 |         super(PNASNetA, self).__init__(CellA, num_cells=6, num_planes=44, num_classes=num_classes)
131 | 
132 | @META_ARCH_REGISTRY.register()
133 | class PNASNetB(PNASNet):
134 |     def __init__(self, cfg):
135 |         num_classes = cfg.model.classes
136 |         super(PNASNetB, self).__init__(CellB, num_cells=6, num_planes=32, num_classes=num_classes)
137 | 
138 | 
--------------------------------------------------------------------------------
/torchline/models/resnet.py:
--------------------------------------------------------------------------------
1 | import types
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from torchvision.models.resnet import ResNet
7 | 
8 | from .build import META_ARCH_REGISTRY
9 | 
10 | __all__ = [
11 |     'ResNet',
12 |     'Resnet18',
13 |     'Resnet34',
14 |     'Resnet50',
15 |     'Resnet101',
16 |     'Resnet152',
17 |     'Resnext50_32x4d',
18 |     'Resnext101_32x8d',
19 |     'Wide_resnet50_2',
20 |     'Wide_resnet101_2'
21 | ]
22 | 
23 | 
24 | class ResNet(nn.Module):
25 |     # Modify attributes
26 |     def __init__(self, model):
27 |         super(ResNet, self).__init__()
28 |         for key, val in model.__dict__.items():
29 |             self.__dict__[key] = val
30 |         self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
31 |         self.last_linear = self.fc
32 | 
33 |     def features(self, x):
34 |         out = self.conv1(x)
35 |         out = self.bn1(out)
36 |         out = self.relu(out)
37 |         out = self.maxpool(out)
38 |         out = self.layer1(out)
39 |         out = self.layer2(out)
40 |         out = self.layer3(out)
41 |         out = self.layer4(out)
42 |         return out
43 | 
44 |     def logits(self, x):
45 |         out = self.g_avg_pool(x)
46 |         out = out.view(out.size(0), -1)
47 |         out = self.last_linear(out)
48 |         return out
49 | 
50 |     def forward(self, x):
51 |         out = self.features(x)
52 |         out = self.logits(out)
53 |         return out
54 | 
55 | def generate_model(cfg, name):
56 |     pretrained=cfg.model.pretrained
57 |     classes = cfg.model.classes
58 |     model = eval(f"models.{name}(pretrained={pretrained})")
59 |     if classes != 1000:
60 |         in_features = model.fc.in_features
61 |         model.fc = nn.Linear(in_features, classes, bias=False)
62 |     return ResNet(model)
63 | 
64 | @META_ARCH_REGISTRY.register()
65 | def Resnet18(cfg):
66 |     return generate_model(cfg, 'resnet18')
67 | 
68 | @META_ARCH_REGISTRY.register()
69 | def Resnet34(cfg):
70 |     return generate_model(cfg, 'resnet34')
71 | 
72 | @META_ARCH_REGISTRY.register()
73 | def Resnet50(cfg):
74 |     return generate_model(cfg, 'resnet50')
75 | 
76 | @META_ARCH_REGISTRY.register()
77 | def Resnet101(cfg):
78 |     return generate_model(cfg, 'resnet101')
79 | 
80 | @META_ARCH_REGISTRY.register()
81 | def Resnet152(cfg):
82 |     return generate_model(cfg, 'resnet152')
83 | 
84 | @META_ARCH_REGISTRY.register()
85 | def Resnext50_32x4d(cfg):
86 |     return generate_model(cfg, 'resnext50_32x4d')
87 | 
88 | @META_ARCH_REGISTRY.register()
89 | def Resnext101_32x8d(cfg):
90 |     return generate_model(cfg, 'resnext101_32x8d')
91 | 
92 | @META_ARCH_REGISTRY.register()
93 | def Wide_resnet50_2(cfg):
94 |     return generate_model(cfg, 'wide_resnet50_2')
95 | 
96 | @META_ARCH_REGISTRY.register()
97 | def Wide_resnet101_2(cfg):
98 |     return generate_model(cfg, 'wide_resnet101_2')
--------------------------------------------------------------------------------
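The `ResNet` and `PNASNet` wrappers deliberately split `forward` into `features()` and `logits()`, so the backbones double as feature extractors. A small sketch, assuming a `cfg` with the fields `generate_model` reads; the input size and printed shape are illustrative:

```python
# Sketch: using the features()/logits() split of the registered wrappers
# as a feature extractor. `cfg` is assumed to be torchline's CfgNode.
import torch
from torchline.models import Resnet18

model = Resnet18(cfg)                # registered factory defined above
model.eval()
with torch.no_grad():
    x = torch.randn(2, 3, 224, 224)  # example input size (assumption)
    feats = model.features(x)        # conv feature map, here (2, 512, 7, 7)
    preds = model.logits(feats)      # pooled + last_linear -> (2, cfg.model.classes)
```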
/torchline/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import *
2 | from .default_trainer import *
--------------------------------------------------------------------------------
/torchline/trainer/build.py:
--------------------------------------------------------------------------------
1 | from torchline.utils import Registry
2 | 
3 | TRAINER_REGISTRY = Registry("TRAINER")
4 | TRAINER_REGISTRY.__doc__ = """
5 | Registry for trainers, i.e. the objects that drive the whole training loop.
6 | 
7 | The registered object will be called with `obj(cfg, hparams)`
8 | and expected to return a `Trainer` object.
9 | """
10 | 
11 | __all__ = [
12 |     'TRAINER_REGISTRY',
13 |     'build_trainer',
14 | ]
15 | 
16 | def build_trainer(cfg, hparams):
17 |     """
18 |     Build the trainer specified by `cfg.trainer.name`.
19 |     """
20 |     name = cfg.trainer.name
21 |     return TRAINER_REGISTRY.get(name)(cfg, hparams)
22 | 
--------------------------------------------------------------------------------
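Because trainers are resolved through `TRAINER_REGISTRY`, swapping in a custom trainer is one registration plus a config change. A minimal sketch, assuming `cfg` and `hparams` are already built elsewhere (e.g. by the engine module) and `cfg` is a mutable `CfgNode`; the extra logging in the subclass is illustrative only:

```python
# Sketch: registering and building a custom trainer.
# DefaultTrainer is defined in default_trainer.py below.
from torchline.trainer import TRAINER_REGISTRY, build_trainer, DefaultTrainer

@TRAINER_REGISTRY.register()
class MyTrainer(DefaultTrainer):
    def __init__(self, cfg, hparams):
        super(MyTrainer, self).__init__(cfg, hparams)
        self.logger_print.info("MyTrainer initialized")  # illustrative

cfg.trainer.name = 'MyTrainer'        # point the config at the new trainer
trainer = build_trainer(cfg, hparams)  # hparams assumed to exist
```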
/torchline/trainer/default_trainer.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import shutil
4 | 
5 | from pytorch_lightning import Trainer
6 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
7 | from pytorch_lightning.logging import TestTubeLogger
8 | from torchline.utils import Logger
9 | 
10 | from .build import TRAINER_REGISTRY
11 | 
12 | __all__ = [
13 |     'DefaultTrainer'
14 | ]
15 | 
16 | 
17 | @TRAINER_REGISTRY.register()
18 | class DefaultTrainer(Trainer):
19 |     def __init__(self, cfg, hparams):
20 |         self.cfg = cfg
21 |         self.hparams = hparams
22 | 
23 |         self.logger = self._logger()
24 |         self.logger_print = Logger(__name__, cfg.log.name).getlogger()
25 |         self.early_stop_callback = self._early_stop_callback()
26 |         self.checkpoint_callback = self._checkpoint_callback()
27 | 
28 |         # you can update trainer_params to change different parameters
29 |         self.trainer_params = dict(self.cfg.trainer)
30 |         self.trainer_params.update({
31 |             'logger': self.logger,
32 |             'early_stop_callback': self.early_stop_callback,
33 |             'checkpoint_callback': self.checkpoint_callback,
34 |         })
35 |         self.trainer_params.pop('name')
36 |         for key in self.trainer_params:
37 |             self.trainer_params[key] = self.parse_cfg_param(self.trainer_params[key])
38 | 
39 |         super(DefaultTrainer, self).__init__(**self.trainer_params)
40 | 
41 |     def parse_cfg_param(self, cfg_item):
42 |         return cfg_item if cfg_item not in ['', []] else None
43 | 
44 |     def _logger(self):
45 | 
46 |         # logger
47 |         logger_cfg = self.cfg.trainer.logger
48 |         assert logger_cfg.setting in [0,1,2], "You can only set three logger levels [0,1,2], but you set {}".format(logger_cfg.setting)
49 |         if logger_cfg.type == 'mlflow':
50 |             raise NotImplementedError
51 |             # params = {key: logger_cfg.mlflow[key] for key in logger_cfg.mlflow}
52 |             # custom_logger = MLFlowLogger(**params)
53 |         elif logger_cfg.type == 'test_tube':
54 |             params = {key: logger_cfg.test_tube[key] for key in logger_cfg.test_tube} # key: save_dir, name, version
55 | 
56 |             # save_dir: logger root path
57 |             if self.cfg.trainer.default_root_dir:
58 |                 save_dir = self.cfg.trainer.default_root_dir
59 |             else:
60 |                 save_dir = logger_cfg.test_tube.save_dir
61 | 
62 |             # version
63 |             version = self.cfg.trainer.logger.test_tube.version
64 |             if logger_cfg.setting==1: # disable logger
65 |                 return False
66 |             elif logger_cfg.setting==2: # use custom logger
67 |                 # if version < 0, fall back to the version encoded in
68 |                 # cfg.log.path; otherwise keep the custom value
69 |                 if version < 0:
70 |                     version = int(self.cfg.log.path.replace('/','')[-1])
71 | 
72 | 
73 |             default_logger = TestTubeLogger(save_dir, name='torchline_logs', version=version)
74 |             params.update({'version': version, 'name':logger_cfg.test_tube.name, 'save_dir': save_dir})
75 |             custom_logger = TestTubeLogger(**params)
76 |         else:
77 |             print(f"{logger_cfg.type} not supported")
78 |             raise NotImplementedError
79 | 
80 |         loggers = {
81 |             0: default_logger,
82 |             1: False,
83 |             2: custom_logger
84 |         } # 0: True (default) 1: False 2: custom
85 |         logger = loggers[logger_cfg.setting]
86 | 
87 |         return logger
88 | 
89 |     def _early_stop_callback(self):
90 |         # early_stop_callback hooks
91 |         hooks = self.cfg.hooks
92 |         params = {key: hooks.early_stopping[key] for key in hooks.early_stopping if key != 'setting'}
93 |         early_stop_callbacks = {
94 |             0: True, # default setting
95 |             1: False, # do not use early stopping
96 |             2: EarlyStopping(**params) # use custom setting
97 |         }
98 |         assert hooks.early_stopping.setting in early_stop_callbacks, 'The setting of early stopping can only be in [0,1,2]'
99 |         early_stop_callback = early_stop_callbacks[hooks.early_stopping.setting]
100 |         return early_stop_callback
101 | 
102 |     def _checkpoint_callback(self):
103 |         # checkpoint_callback hooks
104 |         hooks = self.cfg.hooks
105 |         assert hooks.model_checkpoint.setting in [0,1,2], "You can only set three ckpt levels [0,1,2], but you set {}".format(hooks.model_checkpoint.setting)
106 | 
107 |         params = {key: hooks.model_checkpoint[key] for key in hooks.model_checkpoint if key != 'setting'}
108 |         if hooks.model_checkpoint.setting==2:
109 |             if hooks.model_checkpoint.filepath.strip()=='':
110 |                 filepath = os.path.join(self.cfg.log.path,'checkpoints')
111 |                 monitor = hooks.model_checkpoint.monitor
112 |                 filepath = os.path.join(filepath, '{epoch}-{%s:.2f}'%monitor)
113 |                 params.update({'filepath': filepath})
114 |             else:
115 |                 self.logger_print.warn("The specified checkpoint path is not recommended!")
116 |         checkpoint_callbacks = {
117 |             0: True,
118 |             1: False,
119 |             2: ModelCheckpoint(**params)
120 |         }
121 |         checkpoint_callback = checkpoint_callbacks[hooks.model_checkpoint.setting]
122 |         return checkpoint_callback
123 | 
--------------------------------------------------------------------------------
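`DefaultTrainer` resolves its logger, early stopping, and checkpointing through the same three-level convention: setting `0` keeps the pytorch-lightning default, `1` disables the feature, and `2` builds the component from the config's own fields. A hedged sketch of driving that convention in code (the key names are taken from the methods above; concrete values are illustrative):

```python
# Sketch: the 0/1/2 setting convention used by DefaultTrainer.
# 0 -> pytorch-lightning default, 1 -> disabled, 2 -> built from config.
# `cfg` and `hparams` are assumed to exist, as in build_trainer above.
cfg.trainer.logger.setting = 2            # custom TestTubeLogger
cfg.trainer.logger.type = 'test_tube'
cfg.hooks.early_stopping.setting = 1      # turn early stopping off
cfg.hooks.model_checkpoint.setting = 2    # custom ModelCheckpoint
cfg.hooks.model_checkpoint.filepath = ''  # empty -> saved under cfg.log.path/checkpoints

trainer = DefaultTrainer(cfg, hparams)
```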
/torchline/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .registry import *
3 | from .utils import *
4 | from .average_meter import *
--------------------------------------------------------------------------------
/torchline/utils/average_meter.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | 
3 | __all__ = [
4 |     'AverageMeterGroup',
5 |     'AverageMeter'
6 | ]
7 | 
8 | class AverageMeterGroup:
9 |     """
10 |     Average meter group for multiple average meters.
11 |     """
12 | 
13 |     def __init__(self, verbose_type='avg'):
14 |         self.meters = OrderedDict()
15 |         self.verbose_type = verbose_type
16 | 
17 |     def update(self, data):
18 |         """
19 |         Update the meter group with a dict of metrics.
20 |         Non-existent average meters will be created automatically.
21 |         """
22 |         for k, v in data.items():
23 |             if k not in self.meters:
24 |                 self.meters[k] = AverageMeter(k, ":.4f", self.verbose_type)
25 |             self.meters[k].update(v)
26 | 
27 |     def __getattr__(self, item):
28 |         return self.meters[item]
29 | 
30 |     def __getitem__(self, item):
31 |         return self.meters[item]
32 | 
33 |     def __str__(self):
34 |         return " ".join(f"{v}" for v in self.meters.values())
35 | 
36 |     def summary(self):
37 |         """
38 |         Return a summary string of group data.
39 |         """
40 |         return " ".join(v.summary() for v in self.meters.values())
41 | 
42 | 
43 | class AverageMeter:
44 |     """
45 |     Computes and stores the average and current value.
46 |     Parameters
47 |     ----------
48 |     name : str
49 |         Name to display.
50 |     fmt : str
51 |         Format string to print the values.
52 |     verbose_type : str
53 |         'all': value(avg)
54 |         'avg': avg
55 |     """
56 | 
57 |     def __init__(self, name, fmt=':f', verbose_type='avg'):
58 |         self.name = name
59 |         self.fmt = fmt
60 |         if verbose_type not in ['all', 'avg']:
61 |             print('Not supported verbose type, using default verbose, "avg"')
62 |             verbose_type = 'avg'
63 |         self.verbose_type = verbose_type
64 |         self.reset()
65 | 
66 |     def reset(self):
67 |         """
68 |         Reset the meter.
69 |         """
70 |         self.val = 0
71 |         self.avg = 0
72 |         self.sum = 0
73 |         self.count = 0
74 | 
75 |     def update(self, val, n=1):
76 |         """
77 |         Update with value and weight.
78 |         Parameters
79 |         ----------
80 |         val : float or int
81 |             The new value to be accounted in.
82 |         n : int
83 |             The weight of the new value.
84 |         """
85 |         self.val = val
86 |         self.sum += val * n
87 |         self.count += n
88 |         self.avg = self.sum / self.count
89 | 
90 |     def __str__(self):
91 |         if self.verbose_type=='all':
92 |             fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
93 |         elif self.verbose_type=='avg':
94 |             fmtstr = '{name} {avg' + self.fmt + '}'
95 |         else:
96 |             fmtstr = '{name} {avg' + self.fmt + '}'
97 |         return fmtstr.format(**self.__dict__)
98 | 
99 |     def summary(self):
100 |         fmtstr = '{name}: {avg' + self.fmt + '}'
101 |         return fmtstr.format(**self.__dict__)
--------------------------------------------------------------------------------
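A quick sketch of how these meters behave in a training loop; the metric values are made up, but the output format follows `__str__` and `summary()` above:

```python
# Sketch: tracking running metrics with AverageMeterGroup.
from torchline.utils import AverageMeterGroup

meters = AverageMeterGroup(verbose_type='all')
for step, (loss, acc) in enumerate([(0.9, 55.0), (0.7, 62.5), (0.5, 70.0)]):
    meters.update({'loss': loss, 'acc': acc})
    print(f"step {step}: {meters}")
    # last line: "step 2: loss 0.5000 (0.7000) acc 70.0000 (62.5000)"

print(meters.summary())  # averages only: "loss: 0.7000 acc: 62.5000"
```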
/torchline/utils/logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding:utf-8 -*-
3 | 
4 | import logging
5 | import time
6 | import os
7 | 
8 | __all__ = [
9 |     'Logger'
10 | ]
11 | 
12 | class Logger(object):
13 |     def __init__(self, logger_name=None, filename=None, *args, **kwargs):
14 |         '''
15 |         Specify the file path for saving logs, the log level, and the calling file.
16 |         Logs are written to the specified file.
17 |         '''
18 | 
19 |         # create a logger
20 |         if filename is None:
21 |             file = 'log.txt'
22 |         else:
23 |             file = filename
24 |         self.logger = logging.getLogger(logger_name)
25 |         self.logger.setLevel(logging.INFO)
26 |         self.logger.propagate = False
27 |         if (self.logger.hasHandlers()):
28 |             self.logger.handlers.clear()
29 |         formatter = logging.Formatter('[%(asctime)s] %(filename)s->%(funcName)s line:%(lineno)d [%(levelname)s]%(message)s')
30 | 
31 |         if file:
32 |             hdlr = logging.FileHandler(file, 'a', encoding='utf-8')
33 |             hdlr.setLevel(logging.INFO)
34 |             hdlr.setFormatter(formatter)
35 |             self.logger.addHandler(hdlr)
36 | 
37 |         strhdlr = logging.StreamHandler()
38 |         strhdlr.setLevel(logging.INFO)
39 |         strhdlr.setFormatter(formatter)
40 |         self.logger.addHandler(strhdlr)
41 | 
42 |         if file: hdlr.close()
43 |         strhdlr.close()
44 | 
45 |     def getlogger(self):
46 |         return self.logger
47 | 
--------------------------------------------------------------------------------
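`Logger` is a thin wrapper that wires a file handler and a stream handler with one shared format, the same pattern `DefaultTrainer` uses via `Logger(__name__, cfg.log.name).getlogger()`. A minimal sketch; the log filename here is an assumption:

```python
# Sketch: obtaining a configured logging.Logger instance.
from torchline.utils import Logger

logger = Logger(__name__, filename='train_log.txt').getlogger()
logger.info('starting training')  # goes to both stdout and train_log.txt
```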
/torchline/utils/registry.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | # reference: https://github.com/facebookresearch/fvcore/blob/51695ec2984ae7b4afae27a14d02986ef3f8afd8/fvcore/common/registry.py
4 | 
5 | from typing import Dict, Optional
6 | 
7 | __all__ = [
8 |     'Registry'
9 | ]
10 | 
11 | class Registry(object):
12 |     '''
13 |     The registry that provides name -> object mapping, to support third-party
14 |     users' custom modules.
15 |     To create a registry (e.g. a backbone registry):
16 |     .. code-block:: python
17 |         BACKBONE_REGISTRY = Registry('BACKBONE')
18 |     To register an object:
19 |     .. code-block:: python
20 |         @BACKBONE_REGISTRY.register()
21 |         class MyBackbone():
22 |             ...
23 |     Or:
24 |     .. code-block:: python
25 |         BACKBONE_REGISTRY.register(MyBackbone)
26 |     '''
27 | 
28 |     def __init__(self, name: str) -> None:
29 |         """
30 |         Args:
31 |             name (str): the name of this registry
32 |         """
33 |         self._name: str = name
34 |         self._obj_map: Dict[str, object] = {}
35 | 
36 |     def _do_register(self, name: str, obj: object) -> None:
37 |         assert (
38 |             name not in self._obj_map
39 |         ), "An object named '{}' was already registered in '{}' registry!".format(
40 |             name, self._name
41 |         )
42 |         self._obj_map[name] = obj
43 | 
44 |     def register(self, obj: object = None) -> Optional[object]:
45 |         '''
46 |         Register the given object under the name `obj.__name__`.
47 |         Can be used as either a decorator or not. See docstring of this class for usage.
48 |         '''
49 |         if obj is None:
50 |             # used as a decorator
51 |             def deco(func_or_class: object) -> object:
52 |                 name = func_or_class.__name__  # pyre-ignore
53 |                 self._do_register(name, func_or_class)
54 |                 return func_or_class
55 | 
56 |             return deco
57 | 
58 |         # used as a function call
59 |         name = obj.__name__  # pyre-ignore
60 |         self._do_register(name, obj)
61 | 
62 |     def get(self, name: str) -> object:
63 |         ret = self._obj_map.get(name)
64 |         if ret is None:
65 |             raise KeyError(
66 |                 "No object named '{}' found in '{}' registry!".format(
67 |                     name, self._name
68 |                 )
69 |             )
70 |         return ret
--------------------------------------------------------------------------------
/torchline/utils/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | 
4 | import torch
5 | from PIL import Image
6 | from torchvision import transforms
7 | 
8 | __all__ = [
9 |     'image_loader',
10 |     'get_imgs_to_predict',
11 |     'topk_acc',
12 |     'model_size'
13 | ]
14 | 
15 | def image_loader(filename, cfg):
16 |     '''load an image and convert it to tensor
17 |     Args:
18 |         filename: image filename
19 |         cfg: CfgNode
20 |     return:
21 |         torch.tensor
22 |     '''
23 |     image = Image.open(filename).convert('RGB')
24 |     mean = cfg.transforms.tensor.normalization.mean
25 |     std = cfg.transforms.tensor.normalization.std
26 |     img_size = cfg.input.size
27 | 
28 |     transform = transforms.Compose([
29 |         transforms.Resize(img_size),
30 |         transforms.ToTensor(),
31 |         transforms.Normalize(mean, std),
32 |     ])
33 |     image = transform(image)
34 |     return image
35 | 
36 | def get_imgs_to_predict(path, cfg):
37 |     '''load images which are only used for prediction or testing
38 |     Args:
39 |         path: str
40 |     return:
41 |         torch.tensor (N*C*H*W) for a single file; a dict with 'img_file'/'img_data' keys for a directory
42 |     '''
43 |     if os.path.isfile(path):
44 |         images = image_loader(path, cfg).unsqueeze(0)
45 |     elif os.path.isdir(path):
46 |         image_types = [os.path.join(path,'*.jpg'), os.path.join(path,'*.png')]
47 |         image_files = []
48 |         images = {
49 |             'img_file': [],
50 |             'img_data': []
51 |         }
52 |         for img_type in image_types:
53 |             image_files.extend(glob.glob(img_type))
54 |         for img_file in image_files:
55 |             images['img_file'].append(img_file)
56 |             images['img_data'].append(image_loader(img_file, cfg))
57 |         images['img_data'] = torch.stack(images['img_data'])
58 |     return images
59 | 
60 | def model_size(model):
61 |     return sum([p.numel() for p in model.parameters()])*4/1024**2  # float32 bytes -> MB
62 | 
63 | def topk_acc(output, target, topk=(1, 3)):
64 |     """Computes the precision@k for the specified values of k"""
65 |     maxk = max(topk)
66 | 
67 |     _, pred = output.topk(maxk, 1, True, True)
68 |     pred = pred.t()
69 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
70 | 
71 |     res = []
72 |     for k in topk:
73 |         correct_k = correct[:k].view(-1).float().sum(0)
74 |         res.append(correct_k*100.0/len(target.view(-1)))
75 |         # res.append(correct_k.mul_(100.0 / batch_size))
76 |     return torch.tensor(res)
77 | 
--------------------------------------------------------------------------------
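Note that `topk_acc` returns percentages rather than fractions. A tiny sketch with made-up logits for 4 samples over 5 classes:

```python
# Sketch: topk_acc on a toy batch.
import torch
from torchline.utils import topk_acc

output = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0],   # top-1 = class 1
                       [0.8, 0.1, 0.1, 0.0, 0.0],   # top-1 = class 0
                       [0.2, 0.3, 0.5, 0.0, 0.0],   # top-1 = class 2
                       [0.4, 0.3, 0.2, 0.1, 0.0]])  # top-1 = class 0
target = torch.tensor([1, 0, 1, 2])

acc1, acc3 = topk_acc(output, target, topk=(1, 3))
print(acc1.item(), acc3.item())  # 50.0 100.0 (2/4 top-1 hits, 4/4 top-3 hits)
```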