├── .gitignore
├── LICENSE
├── README.md
├── distribute.sh
├── docs
│   ├── config.md
│   └── tutorial.md
├── projects
│   ├── cifar10_demo
│   │   ├── ReadMe.md
│   │   ├── cifar10.yaml
│   │   └── main.py
│   └── fake_demo
│       ├── ReadMe.md
│       ├── fake_cfg.yaml
│       └── main.py
├── release.md
├── requirements.txt
├── setup.py
├── tests
│   ├── test_model.py
│   └── test_transforms.py
└── torchline
    ├── __init__.py
    ├── config
    │   ├── __init__.py
    │   ├── config.py
    │   └── default.py
    ├── data
    │   ├── __init__.py
    │   ├── albumentaion_transforms.py
    │   ├── autoaugment.py
    │   ├── build.py
    │   ├── common_datasets.py
    │   ├── sampler.py
    │   └── transforms.py
    ├── engine
    │   ├── __init__.py
    │   ├── build.py
    │   ├── default_module.py
    │   └── utils.py
    ├── losses
    │   ├── __init__.py
    │   ├── build.py
    │   ├── focal_loss.py
    │   └── loss.py
    ├── models
    │   ├── __init__.py
    │   ├── build.py
    │   ├── dpn.py
    │   ├── efficientnet.py
    │   ├── mnasnet.py
    │   ├── mobilenet.py
    │   ├── nas_models.py
    │   ├── pnasnet.py
    │   └── resnet.py
    ├── trainer
    │   ├── __init__.py
    │   ├── build.py
    │   └── default_trainer.py
    └── utils
        ├── __init__.py
        ├── average_meter.py
        ├── logger.py
        ├── registry.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | dist
3 | build
4 | torchline.egg-info
5 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 marsggbo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torchline v0.3.0.4
2 |
3 | > Easy-to-use PyTorch
4 | >
5 | > One configuration file is enough!
6 | >
7 | > You can change anything you want in a single configuration file.
8 |
9 | # Dependencies
10 |
11 | - Python>=3.6
12 | - Pytorch>=1.3.1
13 | - torchvision>=0.4.0,<0.5.0
14 | - yacs==0.1.6
15 | - pytorch-lightning<=0.7.6
16 |
17 |
18 | # Install
19 |
20 | - Before you install `torchline`, please make sure you have installed the above libraries.
21 | - You can use `torchline` on both Linux and Windows.
22 |
23 | ```bash
24 | pip install torchline
25 | ```
26 |
27 | # Run demo
28 |
29 | ## Train the model on GPU 0 and GPU 1
30 | ```bash
31 | cd projects/cifar10_demo
32 | python main.py --config_file cifar10.yaml trainer.gpus [0,1]
33 | ```
34 |
35 | ## Debug: append `trainer.fast_dev_run True` to the command line
36 | ```bash
37 | cd projects/cifar10_demo
38 | python main.py --config_file cifar10.yaml trainer.gpus [0] trainer.fast_dev_run True
39 | ```
40 |
41 | The CIFAR10 demo uses ResNet50; trained for 72 epochs, it reached its best result (94.39% validation accuracy) at epoch 54.
42 |
43 | # Thanks
44 |
45 | - [AutoML](https://zhuanlan.zhihu.com/automl)
46 | - [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning)
47 |
48 |
--------------------------------------------------------------------------------
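The demo commands above are thin wrappers around a small programmatic flow. A minimal sketch of that flow, mirroring `projects/cifar10_demo/main.py` (the yaml path and the overrides are illustrative):

```python
import argparse
from argparse import ArgumentParser

from torchline.config import get_cfg
from torchline.engine import build_module
from torchline.trainer import build_trainer

# the argument set that cfg.setup_cfg_with_hparams expects
parser = ArgumentParser(add_help=False)
parser.add_argument("--config_file", default="", metavar="FILE")
parser.add_argument("--test_only", action="store_true")
parser.add_argument("--predict_only", action="store_true")
parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
hparams = parser.parse_args(["--config_file", "cifar10.yaml", "trainer.gpus", "[0]"])

cfg = get_cfg()                        # defaults from torchline/config/default.py
cfg.setup_cfg_with_hparams(hparams)    # merge the yaml file and the CLI overrides
model = build_module(cfg)              # LightningModule chosen by cfg.module.name
trainer = build_trainer(cfg, hparams)  # trainer chosen by cfg.trainer.name
trainer.fit(model)
```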
/distribute.sh:
--------------------------------------------------------------------------------
1 | rm -rf build dist torchline.egg-info
2 | python setup.py sdist bdist_wheel
3 | if (($1==1));then
4 | python -m twine upload dist/*
5 | elif (($1==2));then
6 | python -m twine upload --repository-url https://test.pypi.org/legacy/ dist/*
7 | else
8 | echo "Wrong command, only support 1 or 2"
9 | fi
--------------------------------------------------------------------------------
/docs/config.md:
--------------------------------------------------------------------------------
1 |
2 | # 1. Input
3 |
4 | ## 1.1 `input`
5 |
6 |
7 | - `size`: [112,112] # input image size
8 |
9 | # 2. Dataset
10 |
11 | ## 2.1 `dataset`
12 |
13 | ```yaml
14 | batch_size: 1
15 | dir: './datasets/mydataset'
16 | is_train: False
17 | name: 'fakedata'
18 | test_list: './datasets/test_list.txt'
19 | train_list: './datasets/train_list.txt'
20 | valid_list: './datasets/valid_list.txt'
21 | ```
22 |
23 | ## 2.2 `dataloader`
24 |
25 | ```yaml
26 | num_workers: 0
27 | sample_test: 'default'
28 | sample_train: 'default'
29 | ```
30 |
31 | ## 2.3 Data augmentation
32 |
33 | ### 2.3.1 Image data augmentation
34 |
35 | #### 2.3.1.1 Based on torchvision
36 | ```yaml
37 | transforms:
38 | img:
39 | aug_cifar: False
40 | aug_imagenet: False
41 | center_crop:
42 | enable: 0
43 | color_jitter:
44 | brightness: 0.0
45 | contrast: 0.0
46 | enable: 0
47 | hue: 0.0
48 | saturation: 0.0
49 | random_crop:
50 | enable: 1
51 | padding: 4
52 | random_horizontal_flip:
53 | enable: 1
54 | p: 0.5
55 | random_resized_crop:
56 | enable: 0
57 | ratio: (0.75, 1.3333333333333333)
58 | scale: (0.5, 1.0)
59 | random_rotation:
60 | degrees: 10
61 | enable: 0
62 | random_vertical_flip:
63 | enable: 0
64 | p: 0.5
65 | resize:
66 | enable: 0
67 | name: 'DefaultTransforms'
68 | tensor:
69 | normalization:
70 | mean: [0.4914, 0.4822, 0.4465]
71 | std: [0.2023, 0.1994, 0.201]
72 | random_erasing:
73 | enable: 0
74 | p: 0.5
75 | ratio: (0.3, 3.3)
76 | scale: (0.02, 0.3)
77 | ```
78 |
79 | #### 2.3.1.2 Based on albumentations
80 |
81 |
82 | `abtfs`: albumentations transforms
83 | ```yaml
84 | abtfs:
85 | random_grid_shuffle:
86 | enable: 0
87 | grid: 2
88 | channel_shuffle:
89 | enable: 0
90 | channel_dropout:
91 | enable: 0
92 | drop_range: (1, 1)
93 | fill_value: 127
94 | noise:
95 | enable: 1
96 | blur:
97 | enable: 0
98 | rotate:
99 | enable: 1
100 | bright:
101 | enable: 1
102 | distortion:
103 | enable: 0
104 | hue:
105 | enable: 0
106 | cutout:
107 | enable: 1
108 | num_holes: 10
109 | size: 20
110 | fill_value: 127
111 | ```
112 |
113 | ### 2.3.2 Label transforms
114 |
115 | ```yaml
116 | label_transforms:
117 | name: 'default'
118 | ```
119 |
120 | # 3. Model
121 |
122 |
123 | ```yaml
124 | model:
125 | classes: 10
126 | name: 'FakeNet'
127 | pretrained: True
128 | # features: ['f4']
129 | # features_fusion: 'sum'
130 | # finetune: False
131 | ```
132 |
133 | # 4. Loss function
134 |
135 | ```yaml
136 | loss:
137 | class_weight: []
138 | focal_loss:
139 | alpha: []
140 | gamma: 2
141 | size_average: True
142 | label_smoothing: 0.1
143 | name: 'CrossEntropy'
144 | ```
145 |
146 | # 5. Optimizer and learning-rate scheduler
147 |
148 | ```yaml
149 | optim:
150 | base_lr: 0.1
151 | momentum: 0.9
152 | name: 'sgd'
153 | scheduler:
154 | gamma: 0.1
155 | milestones: [150, 250]
156 | mode: 'min'
157 | name: 'MultiStepLR'
158 | patience: 10
159 | step_size: 10
160 | t_max: 10
161 | verbose: True
162 | weight_decay: 0.0005
163 | ```
164 |
165 | # 6. Engine module
166 |
167 | ```yaml
168 | module:
169 | name: 'DefaultModule'
170 | ```
171 |
172 | # 7. Trainer
173 |
174 | ```yaml
175 | trainer:
176 | accumulate_grad_batches: 1
177 | amp_level: 'O1'
178 | check_val_every_n_epoch: 1
179 | default_root_dir: './output_fakedata'
180 | distributed_backend: 'dp'
181 | fast_dev_run: False
182 | gpus: []
183 | gradient_clip_val: 0
184 | log_gpu_memory: ''
185 | log_save_interval: 100
186 | logger:
187 | mlflow:
188 | experiment_name: 'torchline_logs'
189 | tracking_uri: './output'
190 | setting: 0
191 | test_tube:
192 | name: 'torchline_logs'
193 | save_dir: './output_fakedata'
194 | version: -1
195 | type: 'test_tube'
196 | max_epochs: 100
197 | min_epochs: 1
198 | name: 'DefaultTrainer'
199 | num_nodes: 1
200 | num_sanity_val_steps: 5
201 | overfit_pct: 0.0
202 | print_nan_grads: True
203 | process_position: 0
204 | resume_from_checkpoint: ''
205 | row_log_interval: 10
206 | show_progress_bar: False
207 | test_percent_check: 1.0
208 | track_grad_norm: -1
209 | train_percent_check: 1.0
210 | truncated_bptt_steps: ''
211 | use_amp: False
212 | val_check_interval: 1.0
213 | val_percent_check: 1.0
214 | weights_save_path: ''
215 | weights_summary: ''
216 | ```
217 |
218 | # Miscellaneous
219 |
220 | ```yaml
221 | VERSION: 1
222 | DEFAULT_CUDNN_BENCHMARK: True
223 | SEED: 666
224 | topk: [1, 3]
225 | ```
226 |
--------------------------------------------------------------------------------
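All of the options above live in a single yacs `CfgNode`, so a `.yaml` file and command-line `opts` are simply merged on top of the defaults. A minimal sketch of that merge (the yaml filename is a placeholder):

```python
from torchline.config import get_cfg

cfg = get_cfg()                                # defaults from torchline/config/default.py
cfg.merge_from_file('my_experiment.yaml')      # placeholder yaml carrying your overrides
cfg.merge_from_list(['optim.base_lr', '0.01',  # the same key/value pairs the CLI `opts` carry
                     'trainer.max_epochs', '50'])
print(cfg.optim.base_lr)  # -> 0.01
```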
/docs/tutorial.md:
--------------------------------------------------------------------------------
1 | # Code structure
2 |
3 |
4 | The file layout follows the design of [detectron2](https://github.com/facebookresearch/detectron2), and a registry mechanism is used to switch flexibly between the different modules.
5 |
6 | - `torchline`
7 |   - `config`: configuration module
8 |     - `config.py`: returns the base `CfgNode`
9 |     - `default.py`: default settings; individual options can later be overridden by importing a `.yaml` file
10 |   - `data`: dataset module, returns a `torch.utils.data.Dataset`
11 |     - `build.py`: registers datasets; provides the `build_data` function to obtain the different datasets, plus the dataset registry
12 |     - dataset files:
13 |       - `common_datasets.py`: returns the MNIST and CIFAR10 datasets
14 |     - transform files:
15 |       - `transforms.py`: provides the `build_transforms` function to build `transforms`; also contains the default transform class, `DefaultTransforms`
16 |       - `autoaugment.py`: Google's AutoAugment data-augmentation operations
17 |       - `albumentation_transforms`: data augmentation based on the albumentations library
18 |   - `engine`:
19 |     - `default_module.py`: provides a template subclass of `LightningModule`
20 |   - `losses`:
21 |     - `build.py`: provides the `build_loss_fn` function and the loss registry
22 |     - `loss.py`: provides common loss functions (`CrossEntropy()`)
23 |   - `models`:
24 |     - `build.py`: provides the `build_model` function and the model registry
25 |   - `trainer`:
26 |   - `utils`:
27 |     - `registry.py`: registry template
28 |     - `logger.py`: logging module
29 | - `main.py`: the entry point; new projects can follow this file as a template
30 | - `projects`: You can create your own project here.
31 |   - `cifar10_demo`: a CIFAR10 demo project
32 |   - `fake_demo`: uses randomly generated data, convenient for debugging and for trying out torchline
33 |
34 |
35 | # How to customize?
36 |
37 | To be completed...
38 |
39 | # Custom configuration
40 |
41 | # Custom dataset
42 |
43 | # Custom `engine` template
44 |
45 | > You can customize **data loading**, the **optimizer setup**, and the **forward step**.
46 |
47 |
48 | # Custom `model`
49 |
50 |
51 | # Custom loss function
52 |
53 | # Custom `trainer`
--------------------------------------------------------------------------------
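The registry mechanism described above is exactly what `projects/fake_demo/main.py` uses to plug in its `FakeNet`. A minimal sketch of registering and building a custom model the same way (`TinyNet` is illustrative, not part of torchline):

```python
import torch
from torchline.config import get_cfg
from torchline.models import META_ARCH_REGISTRY, build_model

@META_ARCH_REGISTRY.register()  # registered under its class name
class TinyNet(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.clf = torch.nn.Linear(3, cfg.model.classes)

    def forward(self, x):
        # global average pool over H and W, then classify
        return self.clf(x.mean(dim=(2, 3)))

cfg = get_cfg()
cfg.model.name = 'TinyNet'      # cfg.model.name selects the registered class
model = build_model(cfg)
```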
/projects/cifar10_demo/ReadMe.md:
--------------------------------------------------------------------------------
1 |
2 | # Run
3 |
4 | ```bash
5 | cd projects/cifar10_demo
6 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1
7 | ```
8 |
9 | # Restore training
10 |
11 | Just point `trainer.logger` at the previous run. For example, suppose the logs of an earlier experiment (metrics, checkpoints, etc.) were saved as follows:
12 |
13 | ```bash
14 | |___output
15 | |___lightning_log # specified by trainer.logger.test_tube.name
16 | |___version_0 # specified by trainer.logger.test_tube.version
17 | |___metrics.csv
18 | |___...(other log files)
19 | |___checkpoint
20 | |___ _checkpoint_epoch_60.ckpt
21 | ```
22 |
23 | - `trainer.logger.setting`: 0 means the default setting, 1 means no logger, 2 means a custom logger
24 | - `trainer.logger.test_tube.name`: name of the logger, e.g. lightning_log
25 | - `trainer.logger.test_tube.version`: version of the logger, e.g. 0
26 |
27 |
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 trainer.logger.setting 2 trainer.logger.test_tube.name lightning_log trainer.logger.test_tube.version 0
30 | ```
31 |
32 | # test_only
33 |
34 | Run evaluation on the test set only; the arguments are the same as above.
35 | ```bash
36 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 --test_only trainer.logger.setting 2 trainer.logger.test_tube.name lightning_log trainer.logger.test_tube.version 0
37 | ```
38 |
39 | # predict_only
40 |
41 | Predict the images under a given path. The following two options must be set:
42 |
43 | - `predict_only.load_ckpt.checkpoint_path`: path to the checkpoint
44 | - `predict_only.to_pred_file_path`: path of the image(s) to predict; either a single image or a folder containing multiple images
45 |
46 | ```bash
47 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 --predict_only predict_only.load_ckpt.checkpoint_path './output_cifar10/lightning_logs/version_0/checkpoints/_ckpt_epoch_69.ckpt' predict_only.to_pred_file_path '.'
48 | ```
--------------------------------------------------------------------------------
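Under `type: 'ckpt'`, `predict_only` simply loads the saved `state_dict` into a freshly built module, as `main.py` does. A minimal standalone sketch of those steps (the checkpoint path is a placeholder):

```python
import torch
from torchline.config import get_cfg
from torchline.engine import build_module
from torchline.utils import get_imgs_to_predict

cfg = get_cfg()
model = build_module(cfg)
ckpt = torch.load('path/to/_ckpt_epoch_69.ckpt')  # placeholder checkpoint path
model.load_state_dict(ckpt['state_dict'])
model.eval()
model.freeze()                                    # LightningModule helper: disable grads

images = get_imgs_to_predict('./imgs', cfg)       # a single image or a folder of images
predictions = model(images['img_data'])
print(torch.argmax(predictions, dim=1))           # predicted class indices
```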
/projects/cifar10_demo/cifar10.yaml:
--------------------------------------------------------------------------------
1 | input:
2 | size: (32, 32)
3 | model:
4 | name: 'Resnet50'
5 | classes: 10
6 | dataset:
7 | name: 'CIFAR10'
8 | dir: './datasets/cifar10'
9 | batch_size: 128
10 | optim:
11 | name: 'sgd'
12 | base_lr: 0.1
13 | scheduler:
14 | name: 'MultiStepLR'
15 | milestones: [150, 250]
16 | gamma: 0.1
17 | transforms:
18 | tensor:
19 | normalization:
20 | mean: [0.4914, 0.4822, 0.4465]
21 | std: [0.2023, 0.1994, 0.2010]
22 | img:
23 | aug_cifar: False
24 | resize:
25 | enable: 0
26 | random_crop:
27 | enable: 1
28 | padding: 4
29 | random_horizontal_flip:
30 | enable: 1
31 | p: 0.5
32 | random_vertical_flip:
33 | enable: 0
34 | p: 0.5
35 | random_rotation:
36 | enable: 0
37 | trainer:
38 | default_root_dir: './output_cifar10'
39 | logger:
40 | setting: 0
41 | type: 'test_tube'
42 | mlflow:
43 | tracking_uri: './output_cifar10'
44 | test_tube:
45 | save_dir: './output_cifar10'
46 | name: 'torchline_logs'
47 | version: -1 # if <0, then use default version
48 | hooks:
49 | early_stopping:
50 | setting: 2
51 | patience: 20
52 | monitor: 'valid_acc_1'
53 | mode: 'max'
54 | model_checkpoint:
55 | setting: 2
56 | monitor: 'valid_acc_1'
57 | mode: 'max'
58 | filepath: ''
59 | topk: [1, 3]
60 | predict_only:
61 | type: 'ckpt'
62 | to_pred_file_path: '' # specify the path of images
63 | load_ckpt:
64 | checkpoint_path: '' # load_from_checkpoint
65 | load_metric:
66 | weights_path: '' # load_from_metrics
67 | tags_csv: ''
68 | on_gpu: True
69 | map_location: 'cuda:0'
70 |
--------------------------------------------------------------------------------
/projects/cifar10_demo/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Runs a model on a single node across N-gpus.
3 | """
4 | import argparse
5 | import glob
6 | import os
7 | import shutil
8 | from argparse import ArgumentParser
9 |
10 | import numpy as np
11 | import torch
12 | from torchline.config import get_cfg
13 | from torchline.engine import build_module
14 | from torchline.trainer import build_trainer
15 | from torchline.utils import get_imgs_to_predict
16 |
17 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
18 |
19 | def main(hparams):
20 | """
21 | Main training routine specific for this project
22 | :param hparams:
23 | """
24 |
25 | cfg = get_cfg()
26 | cfg.setup_cfg_with_hparams(hparams)
27 | # only predict on some samples
28 | if hasattr(hparams, "predict_only") and hparams.predict_only:
29 | predict_only = cfg.predict_only
30 | if predict_only.type == 'ckpt':
31 | load_params = {key: predict_only.load_ckpt[key] for key in predict_only.load_ckpt}
32 | model = build_module(cfg)
33 | ckpt_path = load_params['checkpoint_path']
34 | model.load_state_dict(torch.load(ckpt_path)['state_dict'])
35 | elif predict_only.type == 'metrics':
36 | load_params = {key: predict_only.load_metric[key] for key in predict_only.load_metric}
37 | model = build_module(cfg).load_from_metrics(**load_params)
38 | else:
39 | print(f'{cfg.predict_only.type} not supported')
40 | raise NotImplementedError
41 |
42 | model.eval()
43 | model.freeze()
44 | images = get_imgs_to_predict(cfg.predict_only.to_pred_file_path, cfg)
45 | if torch.cuda.is_available():
46 | images['img_data'] = images['img_data'].cuda()
47 | model = model.cuda()
48 | predictions = model(images['img_data'])
49 | class_indices = torch.argmax(predictions, dim=1)
50 | for i, file in enumerate(images['img_file']):
51 | index = class_indices[i]
52 | print(f"{file} is {classes[index]}")
53 | return predictions.cpu()
54 | elif hasattr(hparams, "test_only") and hparams.test_only:
55 | model = build_module(cfg)
56 | trainer = build_trainer(cfg, hparams)
57 | trainer.test(model)
58 | else:
59 | model = build_module(cfg)
60 | trainer = build_trainer(cfg, hparams)
61 | trainer.fit(model)
62 |
63 |
64 | if __name__ == '__main__':
65 | # ------------------------
66 | # TRAINING ARGUMENTS
67 | # ------------------------
68 | # these are project-wide arguments
69 |
70 | root_dir = os.path.dirname(os.path.realpath(__file__))
71 | parent_parser = ArgumentParser(add_help=False)
72 |
73 | # run arguments: config file, run mode, and config overrides
74 | parent_parser.add_argument("--config_file", default="", metavar="FILE", help="path to config file")
75 | parent_parser.add_argument('--test_only', action='store_true', help='if true, return trainer.test(model). Validates only the test set')
76 | parent_parser.add_argument('--predict_only', action='store_true', help='if true run model(samples). Predict on the given samples.')
77 | parent_parser.add_argument( "opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER)
78 |
79 | # each LightningModule defines arguments relevant to it
80 | hparams = parent_parser.parse_args()
81 | assert not (hparams.test_only and hparams.predict_only), "You can't set both 'test_only' and 'predict_only' True"
82 |
83 | # ---------------------
84 | # RUN TRAINING
85 | # ---------------------
86 | main(hparams)
87 |
--------------------------------------------------------------------------------
/projects/fake_demo/ReadMe.md:
--------------------------------------------------------------------------------
1 |
2 | # Run
3 |
4 | ```bash
5 | cd projects/fake_demo
6 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file fake_cfg.yaml --gpus 1
7 | ```
8 |
9 | # Restore training
10 |
11 | Just point `trainer.logger` at the previous run. For example, suppose the logs of an earlier experiment (metrics, checkpoints, etc.) were saved as follows:
12 |
13 | ```bash
14 | |___output
15 | |___lightning_log # specified by trainer.logger.test_tube.name
16 | |___version_0 # specified by trainer.logger.test_tube.version
17 | |___metrics.csv
18 | |___...(other log files)
19 | |___checkpoint
20 | |___ _checkpoint_epoch_60.ckpt
21 | ```
22 |
23 | - `trainer.logger.setting`: 0 means the default setting, 1 means no logger, 2 means a custom logger
24 | - `trainer.logger.test_tube.name`: name of the logger, e.g. lightning_log
25 | - `trainer.logger.test_tube.version`: version of the logger, e.g. 0
26 |
27 |
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 trainer.logger.setting 2 trainer.logger.test_tube.name lightning_log trainer.logger.test_tube.version 0
30 | ```
31 |
32 | # test_only
33 |
34 | Run evaluation on the test set only; the arguments are the same as above.
35 | ```bash
36 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 --test_only trainer.logger.setting 2 trainer.logger.test_tube.name lightning_log trainer.logger.test_tube.version 0
37 | ```
38 |
39 | # predict_only
40 |
41 | Predict the images under a given path. The following two options must be set:
42 |
43 | - `predict_only.load_ckpt.checkpoint_path`: path to the checkpoint
44 | - `predict_only.to_pred_file_path`: path of the image(s) to predict; either a single image or a folder containing multiple images
45 |
46 | ```bash
47 | CUDA_VISIBLE_DEVICES=0 python main.py --config_file cifar10.yaml --gpus 1 --predict_only predict_only.load_ckpt.checkpoint_path './output_cifar10/lightning_logs/version_0/checkpoints/_ckpt_epoch_69.ckpt' predict_only.to_pred_file_path '.'
48 | ```
--------------------------------------------------------------------------------
/projects/fake_demo/fake_cfg.yaml:
--------------------------------------------------------------------------------
1 | DEFAULT_CUDNN_BENCHMARK: True
2 | SEED: 666
3 | VERSION: 1
4 | abtfs:
5 | blur:
6 | blur_limit: 3
7 | enable: 0
8 | bright:
9 | clip_limit: 1
10 | enable: 1
11 | channel_dropout:
12 | drop_range: (1, 1)
13 | enable: 0
14 | fill_value: 127
15 | channel_shuffle:
16 | enable: 0
17 | cutout:
18 | enable: 1
19 | fill_value: 127
20 | num_holes: 10
21 | size: 20
22 | distortion:
23 | enable: 0
24 | hue:
25 | enable: 0
26 | noise:
27 | enable: 1
28 | random_grid_shuffle:
29 | enable: 0
30 | grid: 2
31 | rotate:
32 | enable: 1
33 | p: 1
34 | rotate_limit: 45
35 | scale_limit: 0.2
36 | shift_limit: 0.0625
37 | dataloader:
38 | num_workers: 0
39 | sample_test: 'default'
40 | sample_train: 'default'
41 | dataset:
42 | batch_size: 1
43 | dir: ''
44 | is_train: False
45 | name: 'fakedata'
46 | test_list: ''
47 | train_list: ''
48 | valid_list: ''
49 | hooks:
50 | early_stopping:
51 | min_delta: 0.0
52 | mode: 'max'
53 | monitor: 'valid_acc_1'
54 | patience: 20
55 | setting: 2
56 | verbose: 1
57 | model_checkpoint:
58 | filepath: ''
59 | mode: 'max'
60 | monitor: 'valid_acc_1'
61 | period: 1
62 | prefix: ''
63 | save_top_k: 1
64 | save_weights_only: False
65 | setting: 2
66 | verbose: 1
67 | input:
68 | size: (16, 16)
69 | label_transforms:
70 | name: 'default'
71 | log:
72 | name: ''
73 | path: ''
74 | loss:
75 | class_weight: []
76 | focal_loss:
77 | alpha: []
78 | gamma: 2
79 | size_average: True
80 | label_smoothing: 0.1
81 | name: 'CrossEntropy'
82 | model:
83 | classes: 10
84 | finetune: False
85 | name: 'FakeNet'
86 | pretrained: True
87 | module:
88 | name: 'DefaultModule'
89 | optim:
90 | base_lr: 0.1
91 | momentum: 0.9
92 | name: 'sgd'
93 | scheduler:
94 | gamma: 0.1
95 | milestones: [150, 250]
96 | mode: 'min'
97 | name: 'MultiStepLR'
98 | patience: 10
99 | step_size: 10
100 | t_0: 5
101 | t_max: 10
102 | t_mul: 20
103 | verbose: True
104 | weight_decay: 0.0005
105 | predict_only:
106 | load_ckpt:
107 | checkpoint_path: ''
108 | load_metric:
109 | map_location: ''
110 | on_gpu: True
111 | tags_csv: ''
112 | weights_path: ''
113 | to_pred_file_path: ''
114 | type: 'ckpt'
115 | topk: [1, 3]
116 | trainer:
117 | accumulate_grad_batches: 1
118 | auto_lr_find: False
119 | auto_scale_batch_size: False
120 | auto_select_gpus: False
121 | benchmark: False
122 | check_val_every_n_epoch: 1
123 | default_root_dir: './output'
124 | distributed_backend: 'dp'
125 | fast_dev_run: False
126 | gpus: []
127 | gradient_clip_val: 0
128 | log_gpu_memory: ''
129 | log_save_interval: 100
130 | logger:
131 | mlflow:
132 | experiment_name: 'torchline_logs'
133 | tracking_uri: './output'
134 | setting: 2
135 | test_tube:
136 | name: 'torchline_logs'
137 | save_dir: './output'
138 | version: -1
139 | type: 'test_tube'
140 | max_epochs: 100
141 | max_steps: 1000
142 | min_epochs: 10
143 | min_steps: 400
144 | name: 'DefaultTrainer'
145 | num_nodes: 1
146 | num_sanity_val_steps: 5
147 | num_tpu_cores: ''
148 | overfit_pct: 0.0
149 | precision: 32
150 | print_nan_grads: True
151 | process_position: 0
152 | progress_bar_refresh_rate: 0
153 | reload_dataloaders_every_epoch: False
154 | replace_sampler_ddp: True
155 | resume_from_checkpoint: ''
156 | row_log_interval: 10
157 | show_progress_bar: False
158 | terminate_on_nan: False
159 | test_percent_check: 1.0
160 | track_grad_norm: -1
161 | train_percent_check: 1.0
162 | truncated_bptt_steps: ''
163 | val_check_interval: 1.0
164 | val_percent_check: 1.0
165 | weights_save_path: ''
166 | weights_summary: ''
167 | transforms:
168 | img:
169 | aug_cifar: False
170 | aug_imagenet: False
171 | center_crop:
172 | enable: 0
173 | color_jitter:
174 | brightness: 0.0
175 | contrast: 0.0
176 | enable: 0
177 | hue: 0.0
178 | saturation: 0.0
179 | random_crop:
180 | enable: 1
181 | padding: 4
182 | random_horizontal_flip:
183 | enable: 1
184 | p: 0.5
185 | random_resized_crop:
186 | enable: 0
187 | ratio: (0.75, 1.3333333333333333)
188 | scale: (0.5, 1.0)
189 | random_rotation:
190 | degrees: 10
191 | enable: 0
192 | random_vertical_flip:
193 | enable: 0
194 | p: 0.5
195 | resize:
196 | enable: 0
197 | name: 'DefaultTransforms'
198 | tensor:
199 | normalization:
200 | mean: [0.4914, 0.4822, 0.4465]
201 | std: [0.2023, 0.1994, 0.201]
202 | random_erasing:
203 | enable: 0
204 | p: 0.5
205 | ratio: (0.3, 3.3)
206 | scale: (0.02, 0.3)
--------------------------------------------------------------------------------
/projects/fake_demo/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Runs a model on a single node across N-gpus.
3 | """
4 | import argparse
5 | import glob
6 | import os
7 | import shutil
8 | import sys
9 | sys.path.append('../..')
10 | from argparse import ArgumentParser
11 |
12 | import numpy as np
13 | import torch
14 |
15 | from torchline.models import META_ARCH_REGISTRY
16 | from torchline.trainer import build_trainer
17 | from torchline.config import get_cfg
18 | from torchline.engine import build_module
19 | from torchline.utils import Logger, get_imgs_to_predict
20 |
21 | @META_ARCH_REGISTRY.register()
22 | class FakeNet(torch.nn.Module):
23 | def __init__(self, cfg):
24 | super(FakeNet, self).__init__()
25 | self.cfg = cfg
26 | h, w = cfg.input.size
27 | self.feat = torch.nn.Conv2d(3,16,3,1,1)
28 | self.clf = torch.nn.Linear(16, cfg.model.classes)
29 | print(self)
30 |
31 | def forward(self, x):
32 | b = x.shape[0]
33 | out = self.feat(x)
34 | out = torch.nn.AdaptiveAvgPool2d(1)(out).view(b, -1)
35 | out = self.clf(out)
36 | out = torch.softmax(out, dim=1)
37 | return out
38 |
39 | def main(hparams):
40 | """
41 | Main training routine specific for this project
42 | :param hparams:
43 | """
44 |
45 | cfg = get_cfg()
46 | cfg.setup_cfg_with_hparams(hparams)
47 | if hasattr(hparams, "test_only") and hparams.test_only:
48 | model = build_module(cfg)
49 | trainer = build_trainer(cfg, hparams)
50 | trainer.test(model)
51 | else:
52 | model = build_module(cfg)
53 | trainer = build_trainer(cfg, hparams)
54 | trainer.fit(model)
55 |
56 |
57 | if __name__ == '__main__':
58 | # ------------------------
59 | # TRAINING ARGUMENTS
60 | # ------------------------
61 | # these are project-wide arguments
62 |
63 | root_dir = os.path.dirname(os.path.realpath(__file__))
64 | parent_parser = ArgumentParser(add_help=False)
65 |
66 | # run arguments: config file, run mode, and config overrides
67 | parent_parser.add_argument("--config_file", default="./fake_cfg.yaml", metavar="FILE", help="path to config file")
68 | parent_parser.add_argument('--test_only', action='store_true', help='if true, return trainer.test(model). Validates only the test set')
69 | parent_parser.add_argument('--predict_only', action='store_true', help='if true run model(samples). Predict on the given samples.')
70 | parent_parser.add_argument( "opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER)
71 |
72 | # each LightningModule defines arguments relevant to it
73 | hparams = parent_parser.parse_args()
74 | assert not (hparams.test_only and hparams.predict_only), "You can't set both 'test_only' and 'predict_only' True"
75 |
76 | # ---------------------
77 | # RUN TRAINING
78 | # ---------------------
79 | main(hparams)
80 |
--------------------------------------------------------------------------------
/release.md:
--------------------------------------------------------------------------------
1 |
2 | # Release
3 |
4 | ## v0.1
5 |
6 | ### Updates 2019.12.12
7 | - Basic framework completed
8 | - cifar10_demo runs correctly
9 |
10 |
11 | ### Updates 2019.12.13
12 | - Installation implemented via setup.py
13 | - Improved the imports between packages
14 | - Improved the relationships between the modules
15 |   - data
16 |     - build_data
17 |     - build_sampler
18 |     - build_transforms
19 |     - build_label_transforms
20 |   - engine
21 |     - build_module
22 |   - losses
23 |     - build_loss_fn
24 |   - models
25 |     - build_model
26 |
27 | ### Updates 2019.12.16
28 |
29 | - Bumped the package version to `torchline-0.1.3`
30 | - Restructured the transforms classes to make customization easier
31 |
32 | ### Updates 2019.12.17
33 | - Added a `predict_only` mode that predicts the images under a given path
34 | - Completed the `test_only` mode
35 | - Added top-k result display
36 | - Support for restoring training
37 | > See [project/skin/ReadMe.md](projects/skin/ReadMe.md) for details
38 |
39 | ### Updates 2019.12.19
40 | - Reworked the logic of the various path-related settings (see [skin/ReadMe.md](projects/skin/ReadMe.md))
41 |
42 |
43 | ## v0.2
44 |
45 | ### 2020.02.18
46 |
47 | - Streamlined the code logic
48 | - Abstracted out a `trainer` class
49 |   - `trainer` drives the overall computation; the `DefaultModule` defined under `engine` specifies the concrete steps, including the model, the optimizer, the training/validation/test procedures, and data loading; `models` defines the concrete architectures, e.g. ResNet.
50 |
51 | ## v0.2.2
52 |
53 | - Clearer code logic
54 | - Fixed a logging bug; logging is now more flexible
55 |
56 | ## v0.2.2.2
57 |
58 | - Formatted the config output
59 |
60 | ## v0.2.3.0
61 |
62 | - Added a new logging mode (Python `logging`); logs can be saved to a file for easy inspection. Previously tqdm was the default, and its main drawback is that you cannot clearly see whether the model has started to converge.
63 | - Introduced `AverageMeterGroup` and friends to record the metrics, which makes the overall convergence trend easier to read
64 | - Updated the `fake_data` demo
65 | - Fixed a bug in `CosineAnnealingLR`
66 |
67 | ## v0.2.3.1
68 | - Improved the output format of `AverageMeterGroup` for readability
69 |
70 | ## v0.2.3.2
71 | - Refactored the optimizer code to improve extensibility
72 | - Added learning-rate warm restarts (`CosineAnnealingWarmRestarts`)
73 |
74 | ## v0.2.3.3
75 | - Improved resume
76 | - Added albumentations data-augmentation operations
77 | - Revised the logic between resize and crop
78 |
79 | ## v0.2.3.4
80 | - Abstracted the optimizer and scheduler definitions so they can be called directly from outside
81 | - Added a function that computes the model size
82 |
83 | ## v0.2.4.0
84 | - Added many SOTA architectures, e.g. MnasNet, MobileNet
85 | - Unified the model structure (features, logits, forward, last_linear)
86 |
87 | ## v0.2.4.1
88 | - Fixed a single-node multi-GPU training bug
89 |   - In this mode the batch size must be an integer multiple of the number of GPUs, otherwise the following error is raised:
90 | ```python
91 | ValueError: only one element tensors can be converted to Python scalars
92 | ```
93 | - Standardized the two logging modes: tqdm and logging
94 |
95 | ## v0.2.4.2
96 | - Fixed a bug in single-node multi-GPU training
97 | - Revised and unified the model forward function: features + logits
98 |
99 | ## v0.2.4.3
100 | - Updated the module forward function
101 | - Added a loss function that minimizes entropy
102 |
103 | ## v0.3.0.0 2020.05.11
104 | - Adapted to pytorch-lightning 0.7.5
105 |
106 | ## v0.3.0.1 2020.05.14
107 | - Improved automatic logger-path matching; e.g. when `resume_from_checkpoint` is set, the previous logger path is restored automatically, otherwise the logger version is bumped automatically.
108 |
109 | ## v0.3.0.2 2020.05.19
110 | - Allow `DefaultModule` to store all `gt_labels` and `predictions` produced during computation; specified metrics such as precision or AUC can then be computed via the `analyze_result` function.
111 |
112 | ## v0.3.0.3 2020.05.19
113 | - Fixed the multi-GPU error introduced in the previous version
114 |
115 | ## v0.3.0.4 2020.05.19
116 | - Fixed duplicated log printing under multi-GPU
117 |
118 | # TODO list
119 |
120 |
121 | - [x] Understand the logging mechanism
122 | - [x] Save and load model and optimizer parameters
123 | - [x] Test reading the skin dataset
124 | - [x] Build the skin project
125 | - [x] Support predicting a single image
126 | - [x] Build a simple API interface
127 | - [x] Further improve package imports
128 | - [x] Make the number of training epochs configurable
129 | - [X] Build in more SOTA models
130 | - [x] Print the learning rate every epoch
131 | - [x] Print the checkpoint's results when resuming
132 | - [x] Automatically match checkpoint and related paths when resuming
133 | - [x] Improve the log output
134 | - [x] Data augmentation with albumentations
135 | - [x] Sort out the logic between transforms resize and randomcrop
136 | - [x] Decouple the optimizer and scheduler from the engine
137 | - [x] ~~The resnet structure might be wrong: resnet50 should be about 98MB, but the implementation is only 89.7~~. (Not a bug; classes was set to 10 when measuring, which caused the difference)
138 | - [x] Single-node multi-GPU tests
139 | - [x] ~~Consider embedding the finetune setting into the model~~ (dropped, to avoid bloating the model code)
140 | - [x] Under multi-GPU, multithreading makes the logs print batch results out of order; results should be aggregated before printing, e.g. by moving `print_log` into `on_batch_end` (#v0.3.0.4)
141 | - [ ] Provide more default datasets
142 | - [ ] Improve the documentation
143 | - [x] ~~Evaluate replacing yacs with hydra~~ (too much work)
144 | - [ ] Make the config options more robust and compatible
145 | - [ ] Multi-node multi-GPU tests
146 | - [x] Template project: quickly create a new project based on torchline
147 | - [ ] Decouple `parse_cfg_for_scheduler` from `default_module` and move it into `utils.py`
148 | - [ ] Save the scheduler state in checkpoints too, and add an option to skip restoring the optimizer or scheduler
149 | - [x] Under multi-GPU, multiple copies of the logs are generated, and printed messages are duplicated as well (#v0.3.0.4)
150 | - [ ] Refactor the file structure
151 | - [x] Adapt to pytorch-lightning 0.7.5
152 | - [x] ~~Standardize option names; prefer full words, e.g. optimizer instead of optim # updated in major version v0.3.0.0~~
153 | - [x] ~~albumentations and torchvision read images with cv2 and PIL respectively (numpy vs PIL.Image); the formats need to be unified later.~~ (handled specially when implementing the transformations)
154 | - [ ] Make `print_log` in `Module` more general
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | yacs==0.1.6
2 | pytorch-lightning==0.7.6
3 | pretrainedmodels>=0.7.4
4 | pillow>=6.2.0
5 | test-tube==0.7.5
6 | albumentations>=0.4.3
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open('requirements.txt', 'r') as f:
4 | requirements = f.read().splitlines()
5 |
6 | setup(
7 | name="torchline", # Replace with your own username
8 | version="0.3.0.4",
9 | author="marsggbo",
10 | author_email="csxinhe@comp.hkbu.edu.hk",
11 | description="A framework for easy to use Pytorch",
12 | long_description='''
13 | ML developers can easily use this framework to implement their ideas. The framework is built on pytorch_lightning,
14 | and its structure is inspired by detectron2''',
15 | long_description_content_type="text/markdown",
16 | url="https://github.com/marsggbo/torchline",
17 | packages=find_packages(exclude=("tests", "projects")),
18 | install_requires=requirements,
19 | classifiers=[
20 | "Programming Language :: Python :: 3",
21 | "License :: OSI Approved :: MIT License",
22 | "Operating System :: OS Independent",
23 | ],
24 | python_requires='>=3.6',
25 | )
26 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchline as tl
3 | cfg = tl.config.get_cfg()
4 | cfg.model.pretrained = False
5 |
6 | def cpu_test():
7 | print('=====cpu=====')
8 | x = torch.rand(1,3,64,64)
9 | for m in tl.models.model_list:
10 | cfg.model.name = m
11 | net = tl.models.build_model(cfg)
12 | try:
13 | y = net(x)
14 | print(f"{m} pass")
15 | except Exception as e:
16 | print(f"{m} fail")
17 | print(str(e))
18 | pass
19 |
20 | def single_gpu_test():
21 | print('=====single_gpu=====')
22 | x = torch.rand(4,3,64,64)
23 | for m in tl.models.model_list:
24 | cfg.model.name = m
25 | net = tl.models.build_model(cfg).cuda()
26 | x = x.cuda()
27 | try:
28 | y = net(x)
29 | print(f"{m} pass")
30 | except Exception as e:
31 | print(f"{m} fail")
32 | print(str(e))
33 | pass
34 |
35 | def multi_gpus_test():
36 | print('=====multi_gpu=====')
37 | x = torch.rand(4,3,64,64)
38 | for m in tl.models.model_list:
39 | cfg.model.name = m
40 | net = tl.models.build_model(cfg).cuda()
41 | net = torch.nn.DataParallel(net, device_ids=[0,1])
42 | x = x.cuda()
43 | try:
44 | y = net(x)
45 | print(f"{m} pass")
46 | except Exception as e:
47 | print(f"{m} fail")
48 | print(str(e))
49 | pass
50 |
51 | if __name__ == '__main__':
52 | cpu_test()
53 | single_gpu_test()
54 | multi_gpus_test()
--------------------------------------------------------------------------------
/tests/test_transforms.py:
--------------------------------------------------------------------------------
1 | import torchline as tl
2 | cfg = tl.config.get_cfg()
3 | cfg.dataset.is_train=True
4 | cfg.transforms.name = 'AlbumentationsTransforms'
5 | cfg.abtfs.hue.enable=1
6 | cfg.abtfs.rotate.enable=1
7 | cfg.abtfs.bright.enable=1
8 | cfg.abtfs.noise.enable=1
9 | tf2 = tl.data.build_transforms(cfg)
--------------------------------------------------------------------------------
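A companion smoke test for the torchvision-based pipeline; the option names all come from `torchline/config/default.py`, assuming `build_transforms` dispatches on `cfg.transforms.name`:

```python
import torchline as tl

cfg = tl.config.get_cfg()
cfg.dataset.is_train = True
cfg.transforms.name = 'DefaultTransforms'             # torchvision-based pipeline
cfg.transforms.img.random_crop.enable = 1             # individual ops toggle via cfg
cfg.transforms.img.random_horizontal_flip.enable = 1
tf1 = tl.data.build_transforms(cfg)
```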
/torchline/__init__.py:
--------------------------------------------------------------------------------
1 | from . import config, data, models, engine, losses, utils, trainer
2 | version = "0.3.0.4"
3 |
--------------------------------------------------------------------------------
/torchline/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import CfgNode, get_cfg
--------------------------------------------------------------------------------
/torchline/config/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import torch
5 | from yacs.config import CfgNode as _CfgNode
6 |
7 | from torchline.utils import Logger
8 |
9 |
10 | class CfgNode(_CfgNode):
11 |
12 | def _version(self):
13 | '''
14 | calculate the version of the configuration and logger
15 | '''
16 | def calc_version(path):
17 | if (not os.path.exists(path)) or len(os.listdir(path))==0:
18 | version = 0
19 | else:
20 | versions = [int(v.split('_')[-1]) for v in os.listdir(path)]
21 | version = max(versions)+1
22 | return version
23 |
24 | save_dir=self.trainer.default_root_dir
25 | logger_name=self.trainer.logger.test_tube.name
26 | path = os.path.join(save_dir, logger_name)
27 | if self.trainer.logger.setting==0:
28 | version = calc_version(path)
29 | elif self.trainer.logger.setting==2:
30 | if self.trainer.logger.test_tube.version<0:
31 | version = calc_version(path)
32 | else:
33 | version = int(self.trainer.logger.test_tube.version)
34 | return version
35 |
36 | def setup_cfg_with_hparams(self, hparams):
37 | """
38 | Create configs and perform basic setups.
39 | Args:
40 | hparams: arguments args from command line
41 | """
42 | self.merge_from_file(hparams.config_file)
43 | self.merge_from_list(hparams.opts)
44 | self.update({'hparams': hparams})
45 |
46 | # './outputs/torchline_logs/version_0/checkpoints/_ckpt_epoch_1.ckpt'
47 | ckpt_file = self.trainer.resume_from_checkpoint
48 | if ckpt_file: # resume cfg
49 | assert os.path.exists(ckpt_file), f"{ckpt_file} does not exist"
50 | ckpt_path = os.path.dirname(ckpt_file).split('/')[:-1]
51 | # set log cfg
52 | self.log.path = ''.join([p+'/' for p in ckpt_path])
53 | self.log.name = os.path.join(self.log.path, 'log.txt')
54 | # set trainer logger cfg
55 | self.trainer.logger.setting = 2
56 | root = self.trainer.default_root_dir
57 | self.trainer.logger.test_tube.name = ckpt_file.replace(root,'').split('/')[1]
58 | # self.trainer.logger.test_tube.version
59 | else:
60 | version = self._version()
61 | save_dir = self.trainer.default_root_dir
62 | logger_name = self.trainer.logger.test_tube.name
63 | self.log.path = os.path.join(save_dir, logger_name, f"version_{version}")
64 | self.log.name = os.path.join(self.log.path, 'log.txt')
65 | self.freeze()
66 |
67 | os.makedirs(self.log.path, exist_ok=True)
68 |
69 | # copy config file
70 | src_cfg_file = hparams.config_file # source config file
71 | cfg_file_name = os.path.basename(src_cfg_file) # config file name
72 | dst_cfg_file = os.path.join(self.log.path, cfg_file_name)
73 | with open(dst_cfg_file, 'w') as f:
74 | hparams = self.hparams
75 | self.pop('hparams')
76 | f.write(str(self))
77 | self.update({'hparams': hparams})
78 |
79 | if not (hparams.test_only or hparams.predict_only):
80 | torch.backends.cudnn.benchmark = False
81 | else:
82 | torch.backends.cudnn.benchmark = self.DEFAULT_CUDNN_BENCHMARK
83 |
84 | Logger(__name__, self.log.name).getlogger().info("Running with full config:\n{}".format(self))
85 |
86 | def __str__(self):
87 | def _indent(s_, num_spaces):
88 | s = s_.split("\n")
89 | if len(s) == 1:
90 | return s_
91 | first = s.pop(0)
92 | s = [(num_spaces * " ") + line for line in s]
93 | s = "\n".join(s)
94 | s = first + "\n" + s
95 | return s
96 |
97 | r = ""
98 | s = []
99 | for k, v in sorted(self.items()):
100 | separator = "\n" if isinstance(v, CfgNode) else " "
101 | v = f"'{v}'" if isinstance(v, str) else v
102 | attr_str = "{}:{}{}".format(str(k), separator, str(v))
103 | attr_str = _indent(attr_str, 4)
104 | s.append(attr_str)
105 | r += "\n".join(s)
106 | return r
107 |
108 | global_cfg = CfgNode()
109 |
110 | def get_cfg():
111 | '''
112 | Get a copy of the default config.
113 |
114 | Returns:
115 | a CfgNode instance.
116 | '''
117 | from .default import _C
118 | return _C.clone()
--------------------------------------------------------------------------------
/torchline/config/default.py:
--------------------------------------------------------------------------------
1 | import random
2 | from .config import CfgNode as CN
3 |
4 | _C = CN()
5 | _C.VERSION = 1
6 |
7 |
8 | # ---------------------------------------------------------------------------- #
9 | # input
10 | # ---------------------------------------------------------------------------- #
11 | _C.input = CN()
12 | _C.input.size = (224, 224)
13 |
14 | # ---------------------------------------------------------------------------- #
15 | # dataset
16 | # ---------------------------------------------------------------------------- #
17 | _C.dataset = CN()
18 | _C.dataset.name = 'cifar10'
19 | _C.dataset.batch_size = 16
20 | _C.dataset.dir = './datasets/skin100_dataset/train'
21 | _C.dataset.train_list = './datasets/train_skin10.txt'
22 | _C.dataset.valid_list = './datasets/valid_skin10.txt'
23 | _C.dataset.test_list = './datasets/test_skin10.txt'
24 | _C.dataset.is_train = False # specify to load training or testing set
25 |
26 |
27 | # ---------------------------------------------------------------------------- #
28 | # transforms
29 | # ---------------------------------------------------------------------------- #
30 |
31 | _C.transforms = CN() # image transforms
32 | _C.transforms.name = 'DefaultTransforms'
33 |
34 |
35 | ## transforms for tensor
36 | _C.transforms.tensor = CN()
37 | # for skin100
38 | _C.transforms.tensor.normalization = CN()
39 | _C.transforms.tensor.normalization.mean = [0.6075, 0.4564, 0.4182]
40 | _C.transforms.tensor.normalization.std = [0.2158, 0.1871, 0.1826]
41 | # _C.transforms.tensor.normalization = {
42 | # 'mean':[0.6054, 0.4433, 0.4084],
43 | # 'std': [0.2125, 0.1816, 0.1786] # for skin10
44 | _C.transforms.tensor.random_erasing = CN()
45 | _C.transforms.tensor.random_erasing.enable = 0
46 | _C.transforms.tensor.random_erasing.p = 0.5
47 | _C.transforms.tensor.random_erasing.scale = (0.02, 0.3) # range of proportion of erased area against input image.
48 | _C.transforms.tensor.random_erasing.ratio = (0.3, 3.3) # range of aspect ratio of erased area.
49 |
50 |
51 | # ---------------------------------------------------------------------------- #
52 | # albumentations transforms (abtfs)
53 | # ---------------------------------------------------------------------------- #
54 |
55 | _C.abtfs = CN()
56 | _C.abtfs.random_grid_shuffle = CN()
57 | _C.abtfs.random_grid_shuffle.enable = 0
58 | _C.abtfs.random_grid_shuffle.grid = 2
59 |
60 | _C.abtfs.channel_shuffle = CN()
61 | _C.abtfs.channel_shuffle.enable = 0
62 |
63 | _C.abtfs.channel_dropout = CN()
64 | _C.abtfs.channel_dropout.enable = 0
65 | _C.abtfs.channel_dropout.drop_range = (1, 1)
66 | _C.abtfs.channel_dropout.fill_value = 127
67 |
68 | _C.abtfs.noise = CN()
69 | _C.abtfs.noise.enable = 1
70 |
71 | _C.abtfs.blur = CN()
72 | _C.abtfs.blur.enable = 0
73 | _C.abtfs.blur.blur_limit = 3
74 |
75 | _C.abtfs.rotate = CN()
76 | _C.abtfs.rotate.enable = 1
77 | _C.abtfs.rotate.p = 1
78 | _C.abtfs.rotate.shift_limit = 0.0625
79 | _C.abtfs.rotate.scale_limit = 0.2
80 | _C.abtfs.rotate.rotate_limit = 45
81 |
82 | _C.abtfs.bright = CN()
83 | _C.abtfs.bright.enable = 1
84 | _C.abtfs.bright.clip_limit = 1
85 |
86 | _C.abtfs.distortion = CN()
87 | _C.abtfs.distortion.enable = 0
88 |
89 | _C.abtfs.hue = CN()
90 | _C.abtfs.hue.enable = 0
91 |
92 | _C.abtfs.cutout = CN()
93 | _C.abtfs.cutout.enable = 1
94 | _C.abtfs.cutout.num_holes = 10
95 | _C.abtfs.cutout.size = 20
96 | _C.abtfs.cutout.fill_value = 127
97 |
98 | # ---------------------------------------------------------------------------- #
99 | # torchvision transforms
100 | # ---------------------------------------------------------------------------- #
101 |
102 | ## transforms for PIL image
103 | _C.transforms.img = CN()
104 |
105 | ### modify the image size, only use one operation
106 | # random_resized_crop
107 | _C.transforms.img.random_resized_crop = CN()
108 | _C.transforms.img.random_resized_crop.enable = 0
109 | _C.transforms.img.random_resized_crop.scale = (0.5, 1.0)
110 | _C.transforms.img.random_resized_crop.ratio = (3/4, 4/3)
111 |
112 | # resize
113 | _C.transforms.img.resize = CN()
114 | _C.transforms.img.resize.enable = 1
115 |
116 | # random_crop
117 | _C.transforms.img.random_crop = CN()
118 | _C.transforms.img.random_crop.enable = 1
119 | _C.transforms.img.random_crop.padding = 0
120 |
121 | # center_crop
122 | _C.transforms.img.center_crop = CN()
123 | _C.transforms.img.center_crop.enable = 0
124 |
125 | ### without modifying the image size
126 | _C.transforms.img.aug_imagenet = False
127 | _C.transforms.img.aug_cifar = False
128 |
129 | # color_jitter
130 | _C.transforms.img.color_jitter = CN()
131 | _C.transforms.img.color_jitter.enable = 0
132 | _C.transforms.img.color_jitter.brightness = 0.
133 | _C.transforms.img.color_jitter.contrast = 0.
134 | _C.transforms.img.color_jitter.saturation = 0.
135 | _C.transforms.img.color_jitter.hue = 0.
136 |
137 | # horizontal_flip
138 | _C.transforms.img.random_horizontal_flip = CN()
139 | _C.transforms.img.random_horizontal_flip.enable = 1
140 | _C.transforms.img.random_horizontal_flip.p = 0.5
141 |
142 | # vertical_flip
143 | _C.transforms.img.random_vertical_flip = CN()
144 | _C.transforms.img.random_vertical_flip.enable = 1
145 | _C.transforms.img.random_vertical_flip.p = 0.5
146 |
147 | # random_rotation
148 | _C.transforms.img.random_rotation = CN()
149 | _C.transforms.img.random_rotation.enable = 1
150 | _C.transforms.img.random_rotation.degrees = 10
151 |
152 |
153 |
154 | _C.label_transforms = CN() # label transforms
155 | _C.label_transforms.name = 'default'
156 |
157 |
158 | # ---------------------------------------------------------------------------- #
159 | # dataloader
160 | # ---------------------------------------------------------------------------- #
161 | _C.dataloader = CN()
162 | _C.dataloader.num_workers = 4
163 | _C.dataloader.sample_train = "default"
164 | _C.dataloader.sample_test = "default"
165 |
166 |
167 | # ---------------------------------------------------------------------------- #
168 | # model
169 | # ---------------------------------------------------------------------------- #
170 | _C.model = CN()
171 | _C.model.name = 'Resnet50'
172 | _C.model.classes = 10
173 | _C.model.pretrained = True
174 | _C.model.finetune = False
175 |
176 |
177 | # ---------------------------------------------------------------------------- #
178 | # optimizer
179 | # ---------------------------------------------------------------------------- #
180 | _C.optim = CN()
181 | _C.optim.name = 'adam'
182 | _C.optim.momentum = 0.9
183 | _C.optim.base_lr = 0.001
184 | # _C.optim.lr = _C.optim.base_lr # will changed in v0.3.0.0
185 | _C.optim.weight_decay = 0.0005
186 |
187 | # scheduler
188 | _C.optim.scheduler = CN()
189 | _C.optim.scheduler.name = 'MultiStepLR'
190 | _C.optim.scheduler.gamma = 0.1 # decay factor
191 |
192 | # for CosineAnnealingLR
193 | _C.optim.scheduler.t_max = 10
194 |
195 | # for CosineAnnealingWarmRestarts
196 | _C.optim.scheduler.t_0 = 5
197 | _C.optim.scheduler.t_mul = 20
198 |
199 | # for ReduceLROnPlateau
200 | _C.optim.scheduler.mode = 'min' # min for loss, max for acc
201 | _C.optim.scheduler.patience = 10
202 | _C.optim.scheduler.verbose = True # print log once update lr
203 |
204 | # for StepLR
205 | _C.optim.scheduler.step_size = 10
206 |
207 | # for MultiStepLR
208 | _C.optim.scheduler.milestones = [10, 25, 35, 50]
209 |
210 | # _C.optimizer = _C.optim # enhance compatibility. will changed in v0.3.0.0
211 | # ---------------------------------------------------------------------------- #
212 | # loss
213 | # ---------------------------------------------------------------------------- #
214 | _C.loss = CN()
215 | _C.loss.name = 'CrossEntropy'
216 | _C.loss.class_weight = []
217 | _C.loss.label_smoothing = 0.1 # CrossEntropyLabelSmooth
218 |
219 | _C.loss.focal_loss = CN()
220 | _C.loss.focal_loss.alpha = [] # FocalLoss
221 | _C.loss.focal_loss.gamma = 2
222 | _C.loss.focal_loss.size_average = True
223 | # ---------------------------------------------------------------------------- #
224 | # hooks
225 | # ---------------------------------------------------------------------------- #
226 | _C.hooks = CN()
227 |
228 | ## EarlyStopping
229 | _C.hooks.early_stopping = CN()
230 | _C.hooks.early_stopping.setting = 2 # 0: True 1: False 2: custom
231 | _C.hooks.early_stopping.monitor = 'valid_loss' # or 'valid_acc_1'
232 | _C.hooks.early_stopping.min_delta = 0.
233 | _C.hooks.early_stopping.patience = 10
234 | _C.hooks.early_stopping.mode = 'min' # or 'max'
235 | _C.hooks.early_stopping.verbose = 1
236 |
237 | # ModelCheckpoint
238 | _C.hooks.model_checkpoint = CN()
239 | _C.hooks.model_checkpoint.setting = 2 # 0: True 1: False 2: custom
240 | _C.hooks.model_checkpoint.filepath = '' # the empty file path is recommended
241 | _C.hooks.model_checkpoint.monitor = 'valid_loss'
242 | _C.hooks.model_checkpoint.mode = 'min'
243 | _C.hooks.model_checkpoint.save_top_k = 1
244 | _C.hooks.model_checkpoint.save_weights_only = False
245 | _C.hooks.model_checkpoint.verbose = 1
246 | _C.hooks.model_checkpoint.period = 1
247 | _C.hooks.model_checkpoint.prefix = ''
248 |
249 |
250 | # ---------------------------------------------------------------------------- #
251 | # Module template
252 | # ---------------------------------------------------------------------------- #
253 |
254 | _C.module = CN()
255 | _C.module.name = 'DefaultModule'
256 | _C.module.analyze_result = False
257 | # analyze the predictions and gt_labels, e.g., compute the f1_score, precision
258 | # you can modify the `analyze_result(self, gt_labels, predictions)` in `default_module.py`
259 |
260 |
261 | # ---------------------------------------------------------------------------- #
262 | # Trainer
263 | # ---------------------------------------------------------------------------- #
264 |
265 | _C.trainer = CN()
266 | _C.trainer.name = 'DefaultTrainer'
267 | _C.trainer.default_save_path = './output' # will be removed
268 | _C.trainer.default_root_dir = './output'
269 | _C.trainer.gradient_clip_val = 0 # 0 means don't clip.
270 | _C.trainer.process_position = 0
271 | _C.trainer.num_nodes = 1
272 | _C.trainer.gpus = [] # list
273 | _C.trainer.auto_select_gpus = False # If `auto_select_gpus` is enabled and `gpus` is an integer, pick available gpus automatically.
274 | # This is especially useful when GPUs are configured to be in "exclusive mode", such that
275 | # only one process at a time can access them.
276 | _C.trainer.num_tpu_cores = '' # How many TPU cores to train on (1 or 8).
277 | _C.trainer.log_gpu_memory = "" # None, 'min_max', 'all'. Might slow performance
278 | _C.trainer.show_progress_bar = False # will be removed
279 | _C.trainer.progress_bar_refresh_rate = 0 # How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
279 | _C.trainer.overfit_pct = 0.0 # if 0 < overfit_pct < 1, only use this fraction of the data
--------------------------------------------------------------------------------
/torchline/data/autoaugment.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageEnhance, ImageOps
2 | import numpy as np
3 | import random
12 | class ImageNetPolicy(object):
13 | """ Randomly choose one of the best 24 Sub-policies on ImageNet.
14 | Example:
15 | >>> policy = ImageNetPolicy()
16 | >>> transformed = policy(image)
17 | Example as a PyTorch Transform:
18 | >>> transform=transforms.Compose([
19 | >>> transforms.Resize(256),
20 | >>> ImageNetPolicy(),
21 | >>> transforms.ToTensor()])
22 | """
23 | def __init__(self, fillcolor=(128, 128, 128)):
24 | self.policies = [
25 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
26 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
27 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
28 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
29 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
30 |
31 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
32 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
33 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
34 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
35 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
36 |
37 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
38 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
39 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
40 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
41 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
42 |
43 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
44 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
45 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
46 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
47 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
48 |
49 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
50 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
51 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
52 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
53 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
54 | ]
55 |
56 |
57 | def __call__(self, img):
58 | policy_idx = random.randint(0, len(self.policies) - 1)
59 | return self.policies[policy_idx](img)
60 |
61 | def __repr__(self):
62 | return "AutoAugment ImageNet Policy"
63 |
64 |
65 | class CIFAR10Policy(object):
66 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
67 | Example:
68 | >>> policy = CIFAR10Policy()
69 | >>> transformed = policy(image)
70 | Example as a PyTorch Transform:
71 | >>> transform=transforms.Compose([
72 | >>> transforms.Resize(256),
73 | >>> CIFAR10Policy(),
74 | >>> transforms.ToTensor()])
75 | """
76 | def __init__(self, fillcolor=(128, 128, 128)):
77 | self.policies = [
78 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
79 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
80 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
81 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
82 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
83 |
84 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
85 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
86 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
87 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
88 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
89 |
90 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
91 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
92 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
93 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
94 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
95 |
96 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
97 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
98 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
99 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
100 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
101 |
102 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
103 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
104 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
105 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
106 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
107 | ]
108 |
109 |
110 | def __call__(self, img):
111 | policy_idx = random.randint(0, len(self.policies) - 1)
112 | return self.policies[policy_idx](img)
113 |
114 | def __repr__(self):
115 | return "AutoAugment CIFAR10 Policy"
116 |
117 |
118 | class SVHNPolicy(object):
119 | """ Randomly choose one of the best 25 Sub-policies on SVHN.
120 | Example:
121 | >>> policy = SVHNPolicy()
122 | >>> transformed = policy(image)
123 | Example as a PyTorch Transform:
124 | >>> transform=transforms.Compose([
125 | >>> transforms.Resize(256),
126 | >>> SVHNPolicy(),
127 | >>> transforms.ToTensor()])
128 | """
129 | def __init__(self, fillcolor=(128, 128, 128)):
130 | self.policies = [
131 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
132 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
133 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
134 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
135 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
136 |
137 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
138 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
139 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
140 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
141 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
142 |
143 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
144 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
145 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
146 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
147 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
148 |
149 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
150 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
151 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
152 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
153 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
154 |
155 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
156 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
157 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
158 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
159 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
160 | ]
161 |
162 |
163 | def __call__(self, img):
164 | policy_idx = random.randint(0, len(self.policies) - 1)
165 | return self.policies[policy_idx](img)
166 |
167 | def __repr__(self):
168 | return "AutoAugment SVHN Policy"
169 |
170 |
171 | class SubPolicy(object):
172 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
173 | ranges = {
174 | "shearX": np.linspace(0, 0.3, 10),
175 | "shearY": np.linspace(0, 0.3, 10),
176 | "translateX": np.linspace(0, 150 / 331, 10),
177 | "translateY": np.linspace(0, 150 / 331, 10),
178 | "rotate": np.linspace(0, 30, 10),
179 | "color": np.linspace(0.0, 0.9, 10),
180 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
181 | "solarize": np.linspace(256, 0, 10),
182 | "contrast": np.linspace(0.0, 0.9, 10),
183 | "sharpness": np.linspace(0.0, 0.9, 10),
184 | "brightness": np.linspace(0.0, 0.9, 10),
185 | "autocontrast": [0] * 10,
186 | "equalize": [0] * 10,
187 | "invert": [0] * 10
188 | }
189 |
190 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
191 | def rotate_with_fill(img, magnitude):
192 | rot = img.convert("RGBA").rotate(magnitude)
193 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
194 |
195 | func = {
196 | "shearX": lambda img, magnitude: img.transform(
197 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
198 | Image.BICUBIC, fillcolor=fillcolor),
199 | "shearY": lambda img, magnitude: img.transform(
200 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
201 | Image.BICUBIC, fillcolor=fillcolor),
202 | "translateX": lambda img, magnitude: img.transform(
203 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
204 | fillcolor=fillcolor),
205 | "translateY": lambda img, magnitude: img.transform(
206 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
207 | fillcolor=fillcolor),
208 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
209 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
210 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
211 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
212 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
213 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
214 | 1 + magnitude * random.choice([-1, 1])),
215 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
216 | 1 + magnitude * random.choice([-1, 1])),
217 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
218 | 1 + magnitude * random.choice([-1, 1])),
219 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
220 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
221 | "invert": lambda img, magnitude: ImageOps.invert(img)
222 | }
223 |
224 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
225 | # operation1, ranges[operation1][magnitude_idx1],
226 | # operation2, ranges[operation2][magnitude_idx2])
227 | self.p1 = p1
228 | self.operation1 = func[operation1]
229 | self.magnitude1 = ranges[operation1][magnitude_idx1]
230 | self.p2 = p2
231 | self.operation2 = func[operation2]
232 | self.magnitude2 = ranges[operation2][magnitude_idx2]
233 |
234 |
235 | def __call__(self, img):
236 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1)
237 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2)
238 | return img
--------------------------------------------------------------------------------
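Each SubPolicy above pairs two PIL operations, each fired with its own probability and a magnitude looked up from a 10-step range. A minimal standalone sketch of the same idea (the `toy_subpolicy` helper is illustrative, not part of torchline):

```python
import random
from PIL import Image, ImageOps

def toy_subpolicy(img, p1=0.8, p2=0.3):
    """Apply up to two ops, each with its own probability (as SubPolicy does)."""
    if random.random() < p1:
        img = ImageOps.equalize(img)   # op1: histogram equalization
    if random.random() < p2:
        img = ImageOps.invert(img)     # op2: color inversion
    return img

img = Image.new("RGB", (32, 32), (120, 60, 30))
out = toy_subpolicy(img)               # may equalize, invert, both, or neither
```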
/torchline/data/build.py:
--------------------------------------------------------------------------------
1 | from torchline.utils import Registry
2 |
3 | DATASET_REGISTRY = Registry("DATASET")
4 | DATASET_REGISTRY.__doc__ = """
5 | Registry for dataset, i.e. torch.utils.data.Dataset.
6 |
7 | The registered object will be called with `obj(cfg)`
8 | """
9 |
10 | __all__ = [
11 | 'DATASET_REGISTRY',
12 | 'build_data'
13 | ]
14 |
15 | def build_data(cfg):
16 | """
17 |     Build the dataset, defined by `cfg.dataset.name`.
18 | """
19 | name = cfg.dataset.name
20 | return DATASET_REGISTRY.get(name)(cfg)
21 |
--------------------------------------------------------------------------------
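This registry pattern recurs throughout torchline: a decorator records a callable under its class or function name, and the matching `build_*` helper looks it up by the name stored in the config. A hedged sketch of plugging in a custom dataset (the `MyRandomDataset` class is illustrative, and it assumes `torchline.data` re-exports these symbols):

```python
import torch
from torchline.data import DATASET_REGISTRY, build_data  # assumes these are re-exported

@DATASET_REGISTRY.register()
class MyRandomDataset(torch.utils.data.Dataset):
    def __init__(self, cfg):
        self.num = 32  # a real dataset would read cfg.dataset.dir etc.
    def __getitem__(self, index):
        return torch.rand(3, 8, 8), 0
    def __len__(self):
        return self.num

# With cfg.dataset.name set to 'MyRandomDataset',
# build_data(cfg) resolves the name and returns MyRandomDataset(cfg).
```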
/torchline/data/common_datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.datasets import CIFAR10 as _CIFAR10
3 | from torchvision.datasets import MNIST as _MNIST
4 |
5 | from .build import DATASET_REGISTRY
6 | from .transforms import build_transforms
7 |
8 | __all__ = [
9 | 'MNIST',
10 | 'CIFAR10',
11 | 'FakeData',
12 | 'fakedata'
13 | ]
14 |
15 | @DATASET_REGISTRY.register()
16 | def MNIST(cfg):
17 | root = cfg.dataset.dir
18 | is_train = cfg.dataset.is_train
19 | transform = build_transforms(cfg)
20 | return _MNIST(root=root, train=is_train, transform=transform.transform, download=True)
21 |
22 | @DATASET_REGISTRY.register()
23 | def CIFAR10(cfg):
24 | root = cfg.dataset.dir
25 | is_train = cfg.dataset.is_train
26 | transform = build_transforms(cfg)
27 | return _CIFAR10(root=root, train=is_train, transform=transform.transform, download=True)
28 |
29 | class FakeData(torch.utils.data.Dataset):
30 | def __init__(self, size=64, num=100):
31 | if isinstance(size, int):
32 | self.size = [size, size]
33 | elif isinstance(size, list):
34 | self.size = size
35 | self.num = num
36 |         self.data = torch.rand(num, 3, *self.size)
37 | self.labels = torch.randint(0, 10, (num,))
38 |
39 |
40 | def __getitem__(self, index):
41 | return self.data[index], self.labels[index]
42 |
43 | def __len__(self):
44 | return self.num
45 |
46 | @DATASET_REGISTRY.register()
47 | def fakedata(cfg):
48 | size = cfg.input.size
49 | return FakeData(size)
--------------------------------------------------------------------------------
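FakeData is useful for smoke-testing a pipeline without downloading anything: it yields random images with random labels in [0, 10). A quick usage sketch:

```python
from torch.utils.data import DataLoader
from torchline.data.common_datasets import FakeData

dataset = FakeData(size=[32, 32], num=16)   # 16 random 3x32x32 images
loader = DataLoader(dataset, batch_size=4)
images, labels = next(iter(loader))
print(images.shape, labels.shape)           # torch.Size([4, 3, 32, 32]) torch.Size([4])
```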
/torchline/data/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from torchline.utils import Registry, Logger
4 |
5 | SAMPLER_REGISTRY = Registry('SAMPLER')
6 | SAMPLER_REGISTRY.__doc__ = """
7 | Registry for dataset sampler, i.e. torch.utils.data.Sampler.
8 |
9 | The registered object will be called with `obj(cfg)`
10 | """
11 |
12 | __all__ = [
13 | 'build_sampler',
14 | 'SAMPLER_REGISTRY'
15 | ]
16 |
17 | def build_sampler(cfg):
18 | """
19 |     Build the dataset sampler, defined by `cfg.dataloader.sample_train` or `cfg.dataloader.sample_test`.
20 | """
21 | is_train = cfg.dataset.is_train
22 | if is_train:
23 | name = cfg.dataloader.sample_train
24 | else:
25 | name = cfg.dataloader.sample_test
26 | if name == 'default':
27 | return None
28 | return SAMPLER_REGISTRY.get(name)(cfg)
29 |
30 | @SAMPLER_REGISTRY.register()
31 | class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
32 | """Samples elements randomly from a given list of indices for imbalanced dataset
33 | Arguments:
34 | indices (list, optional): a list of indices
35 | num_samples (int, optional): number of samples to draw
36 | """
37 |
38 | def __init__(self, dataset, indices=None, num_samples=None):
39 |
40 | # if indices is not provided,
41 | # all elements in the dataset will be considered
42 | self.indices = list(range(len(dataset))) \
43 | if indices is None else indices
44 |
45 | # if num_samples is not provided,
46 | # draw `len(indices)` samples in each iteration
47 | self.num_samples = len(self.indices) \
48 | if num_samples is None else num_samples
49 |
50 | # distribution of classes in the dataset
51 | label_to_count = {}
52 | for idx in self.indices:
53 | label = self._get_label(dataset, idx)
54 | if label in label_to_count:
55 | label_to_count[label] += 1
56 | else:
57 | label_to_count[label] = 1
58 |
59 | # weight for each sample
60 | weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
61 | for idx in self.indices]
62 | self.weights = torch.DoubleTensor(weights)
63 |
64 | def _get_label(self, dataset, idx):
65 | dataset_type = type(dataset)
66 | if dataset_type is torchvision.datasets.MNIST:
67 | return dataset.train_labels[idx].item()
68 | elif dataset_type is torchvision.datasets.ImageFolder:
69 | return dataset.imgs[idx][1]
70 | else:
71 | try:
72 | return dataset.imgs[idx][-1]
73 | except Exception as e:
74 | print(str(e))
75 | raise NotImplementedError
76 |
77 | def __iter__(self):
78 | return (self.indices[i] for i in torch.multinomial(
79 | self.weights, self.num_samples, replacement=True))
80 |
81 | def __len__(self):
82 | return self.num_samples
83 |
84 |
85 | @SAMPLER_REGISTRY.register()
86 | class DiffWeightedRandomSampler(torch.utils.data.sampler.Sampler):
87 | r"""
88 | Samples elements from a given list of indices with given probabilities (weights), with replacement.
89 |
90 | Arguments:
91 |         weights (sequence): a sequence of weights, not necessarily summing to one
92 | num_samples (int): number of samples to draw
93 |
94 | """
95 |
96 | def __init__(self, indices, weights, num_samples=0):
97 |         if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples < 0:
98 |             raise ValueError("num_samples should be a non-negative integer "
99 |                              "value, but got num_samples={}".format(num_samples))
100 | self.indices = indices
101 | weights = [ weights[i] for i in self.indices ]
102 | self.weights = torch.DoubleTensor(weights)
103 | if num_samples == 0:
104 | self.num_samples = len(self.weights)
105 | else:
106 | self.num_samples = num_samples
107 | self.replacement = True
108 |
109 | def __iter__(self):
110 | return (self.indices[i] for i in torch.multinomial(self.weights, self.num_samples, self.replacement))
111 |
112 | def __len__(self):
113 | return self.num_samples
--------------------------------------------------------------------------------
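ImbalancedDatasetSampler weights every sample by the inverse frequency of its class, so multinomial draws are roughly class-balanced in expectation. A sketch with a toy 90/10 imbalanced dataset (the `TinyImbalanced` class is illustrative; its `.imgs` attribute mimics ImageFolder so `_get_label` can recover the labels):

```python
import torch
from torch.utils.data import DataLoader
from torchline.data.sampler import ImbalancedDatasetSampler

class TinyImbalanced(torch.utils.data.Dataset):
    def __init__(self):
        # 90 samples of class 0, 10 of class 1; .imgs mimics ImageFolder
        self.imgs = [(None, 0)] * 90 + [(None, 1)] * 10
    def __getitem__(self, idx):
        return torch.zeros(3, 8, 8), self.imgs[idx][1]
    def __len__(self):
        return len(self.imgs)

dataset = TinyImbalanced()
sampler = ImbalancedDatasetSampler(dataset)  # per-sample weights: 1/90 vs 1/10
loader = DataLoader(dataset, batch_size=20, sampler=sampler)
_, labels = next(iter(loader))
print(labels.float().mean())  # close to 0.5 on average, despite the 90/10 split
```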
/torchline/data/transforms.py:
--------------------------------------------------------------------------------
1 |
2 | #coding=utf-8
3 | import torch
4 | import numpy as np
5 | import math
6 | import random
7 | import torchvision
8 | from torchvision import transforms
9 | from . import autoaugment
10 | from torchline.utils import Registry, Logger
11 |
12 | TRANSFORMS_REGISTRY = Registry('transforms')
13 | TRANSFORMS_REGISTRY.__doc__ = """
14 | Registry for data transform functions, i.e. torchvision.transforms
15 |
16 | The registered object will be called with `obj(cfg)`
17 | """
18 |
19 | LABEL_TRANSFORMS_REGISTRY = Registry('label_transforms')
20 | LABEL_TRANSFORMS_REGISTRY.__doc__ = """
21 | Registry for label transform functions, i.e. torchvision.transforms
22 |
23 | The registered object will be called with `obj(cfg)`
24 | """
25 |
26 | __all__ = [
27 | 'build_transforms',
28 | 'build_label_transforms',
29 | 'TRANSFORMS_REGISTRY',
30 | 'LABEL_TRANSFORMS_REGISTRY',
31 | 'DefaultTransforms',
32 | '_DefaultTransforms',
33 | 'BaseTransforms'
34 | ]
35 |
36 |
37 | def build_transforms(cfg):
38 | """
39 |     Build the transforms, defined by `cfg.transforms.name`.
40 | """
41 | name = cfg.transforms.name
42 | return TRANSFORMS_REGISTRY.get(name)(cfg)
43 |
44 | def build_label_transforms(cfg):
45 | """
46 |     Build the label transforms, defined by `cfg.label_transforms.name`.
47 | """
48 | name = cfg.label_transforms.name
49 | if name == 'default':
50 | return None
51 | return LABEL_TRANSFORMS_REGISTRY.get(name)(cfg)
52 |
53 | class BaseTransforms(object):
54 | def __init__(self, is_train, log_name):
55 | self.logger_print = Logger(__name__, log_name).getlogger()
56 | self.is_train = is_train
57 |
58 | def get_transform(self):
59 | if not self.is_train:
60 | self.logger_print.info('Generating validation transform ...')
61 | transform = self.valid_transform
62 | self.logger_print.info(f'Valid transform={transform}')
63 | else:
64 | self.logger_print.info('Generating training transform ...')
65 | transform = self.train_transform
66 | self.logger_print.info(f'Train transform={transform}')
67 | return transform
68 |
69 | @property
70 | def valid_transform(self):
71 | raise NotImplementedError
72 |
73 | @property
74 | def train_transform(self):
75 | raise NotImplementedError
76 |
77 |
78 | @TRANSFORMS_REGISTRY.register()
79 | def DefaultTransforms(cfg):
80 | is_train = cfg.dataset.is_train
81 | log_name = cfg.log.name
82 | mean = cfg.transforms.tensor.normalization.mean
83 | std = cfg.transforms.tensor.normalization.std
84 | img_size = cfg.input.size
85 |
86 | aug_imagenet = cfg.transforms.img.aug_imagenet
87 | aug_cifar = cfg.transforms.img.aug_cifar
88 | random_resized_crop = cfg.transforms.img.random_resized_crop
89 | resize = cfg.transforms.img.resize
90 | random_crop = cfg.transforms.img.random_crop
91 | center_crop = cfg.transforms.img.center_crop
92 | random_horizontal_flip = cfg.transforms.img.random_horizontal_flip
93 | random_vertical_flip = cfg.transforms.img.random_vertical_flip
94 | random_rotation = cfg.transforms.img.random_rotation
95 | color_jitter = cfg.transforms.img.color_jitter
96 | return _DefaultTransforms(is_train, log_name, img_size,
97 | aug_imagenet, aug_cifar,
98 | random_resized_crop,
99 | resize,
100 | random_crop,
101 | center_crop,
102 | random_horizontal_flip,
103 | random_vertical_flip,
104 | random_rotation,
105 | color_jitter,
106 | mean, std)
107 |
108 | class _DefaultTransforms(BaseTransforms):
109 | def __init__(self, is_train, log_name, img_size,
110 | aug_imagenet=False, aug_cifar=False,
111 | random_resized_crop={'enable':0},
112 | resize={'enable':1},
113 | random_crop={'enable':0, 'padding':0},
114 | center_crop={'enable':0},
115 |                  random_horizontal_flip={'enable':0, 'p':0.5},
116 |                  random_vertical_flip={'enable':0, 'p':0.5},
117 |                  random_rotation={'enable':0, 'degrees':15},
118 | color_jitter={'enable':0,'brightness':0.1, 'contrast':0.1, 'saturation':0.1, 'hue':0.1},
119 | mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], *args, **kwargs):
120 | super(_DefaultTransforms, self).__init__(is_train, log_name)
121 | self.is_train = is_train
122 | self.mean = mean
123 | self.std = std
124 | self.img_size = img_size
125 | self.min_edge_size = min(self.img_size)
126 | self.normalize = transforms.Normalize(self.mean, self.std)
127 | self.aug_imagenet = aug_imagenet
128 | self.aug_cifar = aug_cifar
129 | self.random_resized_crop = random_resized_crop
130 | self.resize = resize
131 | self.random_crop = random_crop
132 | self.center_crop = center_crop
133 | self.random_horizontal_flip = random_horizontal_flip
134 | self.random_vertical_flip = random_vertical_flip
135 | self.random_rotation = random_rotation
136 | self.color_jitter = color_jitter
137 | self.transform = self.get_transform()
138 |
139 | @property
140 | def valid_transform(self):
141 | transform = transforms.Compose([
142 | transforms.Resize(self.img_size),
143 | transforms.ToTensor(),
144 | self.normalize
145 | ])
146 | return transform
147 |
148 | @property
149 | def train_transform(self):
150 | # aug_imagenet
151 | if self.aug_imagenet:
152 | self.logger_print.info('Using imagenet augmentation')
153 | transform = transforms.Compose([
154 | transforms.Resize(self.img_size),
155 | autoaugment.ImageNetPolicy(),
156 | transforms.ToTensor(),
157 | self.normalize
158 | ])
159 | # aug cifar
160 | elif self.aug_cifar:
161 | self.logger_print.info('Using cifar augmentation')
162 | transform = transforms.Compose([
163 | transforms.Resize(self.img_size),
164 | autoaugment.CIFAR10Policy(),
165 | transforms.ToTensor(),
166 | self.normalize
167 | ])
168 | # customized transformations
169 | else:
170 | transform = self.read_transform_from_cfg()
171 | return transform
172 |
173 | def read_transform_from_cfg(self):
174 | transform_list = []
175 | self.check_conflict_options()
176 |
177 |         # resize and crop operations
178 | if self.random_resized_crop['enable']:
179 | transform_list.append(transforms.RandomResizedCrop(self.img_size))
180 | elif self.resize['enable']:
181 | transform_list.append(transforms.Resize(self.img_size))
182 | if self.random_crop['enable']:
183 | transform_list.append(transforms.RandomCrop(self.min_edge_size, padding=self.random_crop['padding']))
184 | elif self.center_crop['enable']:
185 | transform_list.append(transforms.CenterCrop(self.min_edge_size))
186 |
187 | # ColorJitter
188 | if self.color_jitter['enable']:
189 | params = {key: self.color_jitter[key] for key in self.color_jitter
190 | if key != 'enable'}
191 | transform_list.append(transforms.ColorJitter(**params))
192 |
193 | # horizontal flip
194 | if self.random_horizontal_flip['enable']:
195 | p = self.random_horizontal_flip['p']
196 | transform_list.append(transforms.RandomHorizontalFlip(p))
197 |
198 | # vertical flip
199 | if self.random_vertical_flip['enable']:
200 | p = self.random_vertical_flip['p']
201 | transform_list.append(transforms.RandomVerticalFlip(p))
202 |
203 | # rotation
204 | if self.random_rotation['enable']:
205 | degrees = self.random_rotation['degrees']
206 | transform_list.append(transforms.RandomRotation(degrees))
207 | transform_list.append(transforms.ToTensor())
208 | transform_list.append(self.normalize)
209 | transform_list = transforms.Compose(transform_list)
210 | assert len(transform_list.transforms) > 0, "You must apply transformations"
211 | return transform_list
212 |
213 | def check_conflict_options(self):
214 | count = self.random_resized_crop['enable'] + \
215 | self.resize['enable']
216 | assert count <= 1, 'You can only use one resize transform operation'
217 |
218 | count = self.random_crop['enable'] + \
219 | self.center_crop['enable']
220 | assert count <= 1, 'You can only use one crop transform operation'
--------------------------------------------------------------------------------
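Because all config values are unpacked into plain keyword arguments, `_DefaultTransforms` can also be constructed directly, which is handy in tests. A hedged sketch (assuming torchline's `Logger` accepts an arbitrary log name):

```python
from torchline.data.transforms import _DefaultTransforms

# Validation pipeline: Resize -> ToTensor -> Normalize
tf = _DefaultTransforms(is_train=False, log_name='demo', img_size=[224, 224])
print(tf.transform)  # the composed torchvision transform
```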
/torchline/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .default_module import *
2 | from .build import *
3 | from .utils import *
--------------------------------------------------------------------------------
/torchline/engine/build.py:
--------------------------------------------------------------------------------
1 |
2 | from torchline.utils import Registry, Logger
3 |
4 | MODULE_REGISTRY = Registry('MODULE')
5 | MODULE_REGISTRY.__doc__ = """
6 | Registry for module template, e.g. DefaultModule.
7 |
8 | The registered object will be called with `obj(cfg)`
9 | """
10 |
11 | __all__ = [
12 | 'MODULE_REGISTRY',
13 | 'build_module'
14 | ]
15 |
16 | def build_module(cfg):
17 | """
18 |     Build the module template, defined by `cfg.module.name`.
19 | """
20 | name = cfg.module.name
21 | return MODULE_REGISTRY.get(name)(cfg)
22 |
23 |
--------------------------------------------------------------------------------
/torchline/engine/default_module.py:
--------------------------------------------------------------------------------
1 | """
2 | Example template for defining a system
3 | """
4 | import logging
5 | import os
6 | from argparse import ArgumentParser
7 | from collections import OrderedDict
8 |
9 | import pytorch_lightning as pl
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torchvision
14 | import torchvision.transforms as transforms
15 | from pytorch_lightning import LightningModule
16 | from sklearn import metrics
17 | from torch import optim
18 | from torch.utils.data import DataLoader
19 | from torch.utils.data.distributed import DistributedSampler
20 |
21 | from torchline.data import build_data, build_sampler
22 | from torchline.losses import build_loss_fn
23 | from torchline.models import build_model
24 | from torchline.utils import AverageMeterGroup, topk_acc
25 |
26 | from .build import MODULE_REGISTRY
27 | from .utils import generate_optimizer, generate_scheduler
28 |
29 | __all__ = [
30 | 'DefaultModule'
31 | ]
32 |
33 | @MODULE_REGISTRY.register()
34 | class DefaultModule(LightningModule):
35 | """
36 | Sample model to show how to define a template
37 | """
38 |
39 | def __init__(self, cfg):
40 | """
41 |         Pass in the parsed config (a yacs CfgNode) to the model
42 | :param cfg:
43 | """
44 | # init superclass
45 | super(DefaultModule, self).__init__()
46 | self.cfg = cfg
47 | self.hparams = self.cfg.hparams
48 | self.batch_size = self.cfg.dataset.batch_size
49 |
50 | # if you specify an example input, the summary will show input/output for each layer
51 | h, w = self.cfg.input.size
52 | self.example_input_array = torch.rand(1, 3, h, w)
53 | self.crt_batch_idx = 0
54 | self.inputs = self.example_input_array
55 |
56 | # build model
57 | self.model = self.build_model(cfg)
58 | self.train_meters = AverageMeterGroup()
59 | self.valid_meters = AverageMeterGroup()
60 |
61 | # ---------------------
62 |     # MODEL SETUP
63 | # ---------------------
64 | def build_model(self, cfg):
65 | """
66 | Layout model
67 | :return:
68 | """
69 | return build_model(cfg)
70 |
71 | def build_loss_fn(self, cfg):
72 | """
73 | Layout loss_fn
74 | :return:
75 | """
76 | return build_loss_fn(cfg)
77 |
78 | def build_data(self, cfg, is_train):
79 | """
80 |         Layout the dataset (train or valid, controlled by `is_train`)
81 | :return:
82 | """
83 | cfg.defrost()
84 | cfg.dataset.is_train = is_train
85 | cfg.freeze()
86 | return build_data(cfg)
87 |
88 | def build_sampler(self, cfg, is_train):
89 | """
90 |         Layout the dataset sampler
91 | :return:
92 | """
93 | cfg.defrost()
94 | cfg.dataset.is_train = is_train
95 | cfg.freeze()
96 | return build_sampler(cfg)
97 |
98 | # ---------------------
99 | # Hooks
100 | # ---------------------
101 |
102 | def on_train_start(self):
103 |         ckpt_path = self.trainer.resume_from_checkpoint
104 |         if ckpt_path:
105 |             if os.path.exists(ckpt_path):
106 |                 ckpt = torch.load(ckpt_path)
107 |                 best = ckpt['checkpoint_callback_best']
108 |                 self.log_info(f"The best result of the ckpt is {best}")
109 |             else:
110 |                 # fail fast when the resume checkpoint is missing
111 |                 raise FileNotFoundError(f'{ckpt_path} does not exist')
112 |
113 |
114 | def on_epoch_start(self):
115 | if not self.cfg.trainer.progress_bar_refresh_rate:
116 | # print current lr
117 | if isinstance(self.trainer.optimizers, list):
118 | if len(self.trainer.optimizers) == 1:
119 | optimizer = self.trainer.optimizers[0]
120 | lr = optimizer.param_groups[0]["lr"]
121 | self.log_info(f"lr={lr:.4e}")
122 | else:
123 | for index, optimizer in enumerate(self.trainer.optimizers):
124 | lr = optimizer.param_groups[0]["lr"]
125 | name = str(optimizer).split('(')[0].strip()
126 | self.log_info(f"lr of {name}_{index} is {lr:.4e} ")
127 | else:
128 | lr = self.trainer.optimizers.param_groups[0]["lr"]
129 | self.log_info(f"lr={lr:.4e}")
130 |
131 | def on_epoch_end(self):
132 | if not self.cfg.trainer.progress_bar_refresh_rate:
133 | self.log_info(f'Final Train: {self.train_meters}')
134 |             self.log_info(f'Final Valid: {self.valid_meters}')
135 | self.log_info("===========================\n")
136 | self.train_meters = AverageMeterGroup()
137 | self.valid_meters = AverageMeterGroup()
138 |
139 | def log_info(self, *args, **kwargs):
140 | if self.trainer.proc_rank == 0:
141 | self.trainer.logger_print.info(*args, **kwargs)
142 |
143 | def log_warning(self, *args, **kwargs):
144 | if self.trainer.proc_rank == 0:
145 | self.trainer.logger_print.warning(*args, **kwargs)
146 | # ---------------------
147 | # TRAINING
148 | # ---------------------
149 |
150 | def forward(self, x):
151 | """
152 | No special modification required for lightning, define as you normally would
153 | :param x:
154 | :return:
155 |
156 |         Example returning intermediate features instead:
157 | features = self.model.features(x)
158 | logits = self.model.logits(features)
159 | return logits
160 | """
161 | return self.model(x)
162 |
163 | def loss(self, predictions, gt_labels):
164 | loss_fn = self.build_loss_fn(self.cfg)
165 | return loss_fn(predictions, gt_labels)
166 |
167 | def print_log(self, batch_idx, is_train, inputs, meters, save_examples=False):
168 | flag = batch_idx % self.cfg.trainer.log_save_interval == 0
169 | if not self.trainer.progress_bar_refresh_rate and flag:
170 | if is_train:
171 | _type = 'Train'
172 | all_step = self.trainer.num_training_batches
173 | else:
174 | _type = 'Valid'
175 | all_step = self.trainer.num_val_batches
176 | crt_epoch, crt_step = self.trainer.current_epoch, batch_idx
177 | all_epoch = self.trainer.max_epochs
178 | log_info = f"{_type} Epoch {crt_epoch}/{all_epoch} step {crt_step}/{all_step} {meters}"
179 | self.log_info(log_info)
180 |
181 | if self.current_epoch==0 and batch_idx==0 and save_examples:
182 | if not os.path.exists('train_valid_samples'):
183 | os.makedirs('train_valid_samples')
184 | for i, img in enumerate(inputs[:5]):
185 | torchvision.transforms.ToPILImage()(img.cpu()).save(f'./train_valid_samples/{_type}_img{i}.jpg')
186 |
187 | def training_step(self, batch, batch_idx):
188 | """
189 | Lightning calls this inside the training loop
190 | :param batch:
191 | :return:
192 | """
193 |
194 | # forward pass
195 | inputs, gt_labels = batch
196 | predictions = self.forward(inputs)
197 |
198 | # calculate loss
199 | loss_val = self.loss(predictions, gt_labels)
200 |
201 | # acc
202 | acc_results = topk_acc(predictions, gt_labels, self.cfg.topk)
203 | tqdm_dict = {}
204 |
205 | if self.on_gpu:
206 | acc_results = [torch.tensor(x).to(loss_val.device.index) for x in acc_results]
207 |
208 | # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
209 | if self.trainer.use_dp or self.trainer.use_ddp2:
210 | loss_val = loss_val.unsqueeze(0)
211 | acc_results = [x.unsqueeze(0) for x in acc_results]
212 |
213 | tqdm_dict['train_loss'] = loss_val
214 | for i, k in enumerate(self.cfg.topk):
215 | tqdm_dict[f'train_acc_{k}'] = acc_results[i]
216 |
217 | output = OrderedDict({
218 | 'loss': loss_val,
219 | 'progress_bar': tqdm_dict,
220 | 'log': tqdm_dict
221 | })
222 |
223 | self.train_meters.update({key: val.item() for key, val in tqdm_dict.items()})
224 | # self.print_log(batch_idx, True, inputs, self.train_meters)
225 |
226 | # can also return just a scalar instead of a dict (return loss_val)
227 | return output
228 |
229 | def validation_step(self, batch, batch_idx):
230 | """
231 | Lightning calls this inside the validation loop
232 | :param batch:
233 | :return:
234 | """
235 | inputs, gt_labels = batch
236 | predictions = self.forward(inputs)
237 |
238 | loss_val = self.loss(predictions, gt_labels)
239 |
240 | # acc
241 | val_acc_1, val_acc_k = topk_acc(predictions, gt_labels, self.cfg.topk)
242 |
243 | if self.on_gpu:
244 | val_acc_1 = val_acc_1.cuda(loss_val.device.index)
245 | val_acc_k = val_acc_k.cuda(loss_val.device.index)
246 |
247 | # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
248 | if self.trainer.use_dp or self.trainer.use_ddp2:
249 | loss_val = loss_val.unsqueeze(0)
250 | val_acc_1 = val_acc_1.unsqueeze(0)
251 | val_acc_k = val_acc_k.unsqueeze(0)
252 |
253 | output = OrderedDict({
254 | 'valid_loss': loss_val,
255 | 'valid_acc_1': val_acc_1,
256 | f'valid_acc_{self.cfg.topk[-1]}': val_acc_k,
257 | })
258 | tqdm_dict = {k: v for k, v in dict(output).items()}
259 | self.valid_meters.update({key: val.item() for key, val in tqdm_dict.items()})
260 | # self.print_log(batch_idx, False, inputs, self.valid_meters)
261 |
262 | if self.cfg.module.analyze_result:
263 | output.update({
264 | 'predictions': predictions.detach(),
265 | 'gt_labels': gt_labels.detach(),
266 | })
267 | # can also return just a scalar instead of a dict (return loss_val)
268 | return output
269 |
270 | def training_step_end(self, output):
271 | self.print_log(self.trainer.batch_idx, True, self.inputs, self.train_meters)
272 | return output
273 |
274 | def validation_step_end(self, output):
275 | self.crt_batch_idx += 1
276 | self.print_log(self.crt_batch_idx, False, self.inputs, self.valid_meters)
277 | return output
278 |
279 | def validation_epoch_end(self, outputs):
280 | """
281 | Called at the end of validation to aggregate outputs
282 | :param outputs: list of individual outputs of each validation step
283 | :return:
284 | """
285 | # if returned a scalar from validation_step, outputs is a list of tensor scalars
286 | # we return just the average in this case (if we want)
287 | # return torch.stack(outputs).mean()
288 |
289 | self.crt_batch_idx = 0
290 | tqdm_dict = {key: val.avg for key, val in self.valid_meters.meters.items()}
291 | valid_loss = torch.tensor(self.valid_meters.meters['valid_loss'].avg)
292 | valid_acc_1 = torch.tensor(self.valid_meters.meters['valid_acc_1'].avg)
293 | result = {'progress_bar': tqdm_dict, 'log': tqdm_dict,
294 | 'valid_loss': valid_loss,
295 | 'valid_acc_1': valid_acc_1}
296 |
297 | if self.cfg.module.analyze_result:
298 | predictions = []
299 | gt_labels = []
300 | for output in outputs:
301 | predictions.append(output['predictions'])
302 | gt_labels.append(output['gt_labels'])
303 | predictions = torch.cat(predictions)
304 | gt_labels = torch.cat(gt_labels)
305 | analyze_result = self.analyze_result(gt_labels, predictions)
306 | self.log_info(analyze_result)
307 | # result.update({'analyze_result': analyze_result})
308 | return result
309 |
310 | def test_step(self, batch, batch_idx):
311 | return self.validation_step(batch, batch_idx)
312 |
313 | def test_epoch_end(self, outputs):
314 | return self.validation_epoch_end(outputs)
315 |
316 | def analyze_result(self, gt_labels, predictions):
317 | '''
318 | Args:
319 |             gt_labels: tensor (N)
320 | predictions: tensor (N*C)
321 | '''
322 | return str(metrics.classification_report(gt_labels.cpu(), predictions.cpu().argmax(1)))
323 |
324 | # ---------------------
325 | # TRAINING SETUP
326 | # ---------------------
327 | @classmethod
328 | def parse_cfg_for_scheduler(cls, cfg, scheduler_name):
329 | if scheduler_name.lower() == 'CosineAnnealingLR'.lower():
330 | params = {'T_max': cfg.optim.scheduler.t_max}
331 | elif scheduler_name.lower() == 'CosineAnnealingWarmRestarts'.lower():
332 | params = {'T_0': cfg.optim.scheduler.t_0, 'T_mult': cfg.optim.scheduler.t_mul}
333 | elif scheduler_name.lower() == 'StepLR'.lower():
334 | params = {'step_size': cfg.optim.scheduler.step_size, 'gamma': cfg.optim.scheduler.gamma}
335 | elif scheduler_name.lower() == 'MultiStepLR'.lower():
336 | params = {'milestones': cfg.optim.scheduler.milestones, 'gamma': cfg.optim.scheduler.gamma}
337 | elif scheduler_name.lower() == 'ReduceLROnPlateau'.lower():
338 | params = {'mode': cfg.optim.scheduler.mode, 'patience': cfg.optim.scheduler.patience,
339 | 'verbose': cfg.optim.scheduler.verbose, 'factor': cfg.optim.scheduler.gamma}
340 | else:
341 | print(f"{scheduler_name} not implemented")
342 | raise NotImplementedError
343 | return params
344 |
345 | def configure_optimizers(self):
346 | """
347 | return whatever optimizers we want here
348 | :return: list of optimizers
349 | """
350 | optim_name = self.cfg.optim.name
351 | momentum = self.cfg.optim.momentum
352 | weight_decay = self.cfg.optim.weight_decay
353 | lr = self.cfg.optim.base_lr
354 | optimizer = generate_optimizer(self.model, optim_name, lr, momentum, weight_decay)
355 | scheduler_params = self.parse_cfg_for_scheduler(self.cfg, self.cfg.optim.scheduler.name)
356 | scheduler = generate_scheduler(optimizer, self.cfg.optim.scheduler.name, **scheduler_params)
357 | return [optimizer], [scheduler]
358 |
359 | def __dataloader(self, is_train):
360 | # init data generators
361 | dataset = self.build_data(self.cfg, is_train)
362 |
363 | # when using multi-node (ddp) we need to add the datasampler
364 | train_sampler = self.build_sampler(self.cfg, is_train)
365 | batch_size = self.batch_size
366 |
367 | if self.use_ddp:
368 | train_sampler = DistributedSampler(dataset)
369 |
370 | should_shuffle = train_sampler is None
371 | loader = DataLoader(
372 | dataset=dataset,
373 | batch_size=batch_size,
374 | shuffle=should_shuffle if is_train else False,
375 | sampler=train_sampler,
376 | num_workers=self.cfg.dataloader.num_workers
377 | )
378 |
379 | return loader
380 |
381 | @pl.data_loader
382 | def train_dataloader(self):
383 | logging.info('training data loader called')
384 | return self.__dataloader(is_train=True)
385 |
386 | @pl.data_loader
387 | def val_dataloader(self):
388 | logging.info('val data loader called')
389 | return self.__dataloader(is_train=False)
390 |
391 | @pl.data_loader
392 | def test_dataloader(self):
393 | logging.info('test data loader called')
394 | return self.__dataloader(is_train=False)
395 |
--------------------------------------------------------------------------------
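DefaultModule packs the model, loss, optimizer, and dataloaders behind a single cfg, so a full run reduces to building the module and handing it to a Trainer. A sketch against the pytorch-lightning<=0.7.6 API this project pins; the `get_cfg` import assumes torchline.config exposes a detectron2-style default-config helper, and the yaml path is illustrative:

```python
import pytorch_lightning as pl
from torchline.config import get_cfg   # assumption: a detectron2-style helper
from torchline.engine import build_module

cfg = get_cfg()
cfg.merge_from_file('cifar10.yaml')    # illustrative path
cfg.freeze()

module = build_module(cfg)             # DefaultModule(cfg) when cfg.module.name == 'DefaultModule'
trainer = pl.Trainer(max_epochs=1, fast_dev_run=True)
trainer.fit(module)
```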
/torchline/engine/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | __all__ = [
4 | 'generate_optimizer',
5 | 'generate_scheduler'
6 | ]
7 |
8 | def generate_optimizer(model, optim_name, lr, momentum=0.9, weight_decay=1e-5):
9 | '''
10 | return torch.optim.Optimizer
11 | '''
12 | if optim_name.lower() == 'sgd':
13 | return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
14 |     elif optim_name.lower() == 'adadelta':
15 |         return torch.optim.Adadelta(model.parameters(), lr=lr, weight_decay=weight_decay)
16 | elif optim_name.lower() == 'adam':
17 | return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
18 | elif optim_name.lower() == 'rmsprop':
19 | return torch.optim.RMSprop(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
20 | else:
21 | print(f"{optim_name} not implemented")
22 | raise NotImplementedError
23 |
24 | def generate_scheduler(optimizer, scheduler_name, **params):
25 | '''
26 | return torch.optim.lr_scheduler
27 | '''
28 | if scheduler_name.lower() == 'CosineAnnealingLR'.lower():
29 | return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **params)
30 | elif scheduler_name.lower() == 'CosineAnnealingWarmRestarts'.lower():
31 | return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **params)
32 | elif scheduler_name.lower() == 'StepLR'.lower():
33 | return torch.optim.lr_scheduler.StepLR(optimizer, **params)
34 | elif scheduler_name.lower() == 'MultiStepLR'.lower():
35 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, **params)
36 | elif scheduler_name.lower() == 'ReduceLROnPlateau'.lower():
37 | return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **params)
38 | else:
39 | print(f"{scheduler_name} not implemented")
40 | raise NotImplementedError
--------------------------------------------------------------------------------
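The two helpers compose in the obvious way: build the optimizer, then wrap it in a scheduler, stepping the scheduler once per epoch. A quick sketch:

```python
import torch.nn as nn
from torchline.engine.utils import generate_optimizer, generate_scheduler

model = nn.Linear(10, 2)
optimizer = generate_optimizer(model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-5)
scheduler = generate_scheduler(optimizer, 'MultiStepLR', milestones=[30, 60], gamma=0.1)

for epoch in range(90):
    # ... train one epoch with `optimizer` ...
    scheduler.step()  # lr drops 10x after epochs 30 and 60
```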
/torchline/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import *
2 | from .loss import *
3 | from .focal_loss import *
--------------------------------------------------------------------------------
/torchline/losses/build.py:
--------------------------------------------------------------------------------
1 | from torchline.utils import Registry
2 |
3 | LOSS_FN_REGISTRY = Registry("LOSS_FN")
4 | LOSS_FN_REGISTRY.__doc__ = """
5 | Registry for loss function, e.g. cross entropy loss.
6 |
7 | The registered object will be called with `obj(cfg)`
8 | """
9 |
10 | __all__ = [
11 | 'build_loss_fn',
12 | 'LOSS_FN_REGISTRY'
13 | ]
14 |
15 | def build_loss_fn(cfg):
16 | """
17 |     Build the loss function, defined by `cfg.loss.name`.
18 | """
19 | name = cfg.loss.name
20 | return LOSS_FN_REGISTRY.get(name)(cfg)
21 |
--------------------------------------------------------------------------------
/torchline/losses/focal_loss.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from .build import LOSS_FN_REGISTRY
6 |
7 | __all__ = [
8 | 'FocalLoss',
9 | '_FocalLoss'
10 | ]
11 |
12 | @LOSS_FN_REGISTRY.register()
13 | def FocalLoss(cfg):
14 | alpha = cfg.loss.focal_loss.alpha
15 | gamma = cfg.loss.focal_loss.gamma
16 | size_average = cfg.loss.focal_loss.size_average
17 | num_classes = cfg.model.classes
18 | return _FocalLoss(alpha, gamma, num_classes, size_average)
19 |
20 | class _FocalLoss(torch.nn.Module):
21 | def __init__(self, alpha, gamma, num_classes, size_average=True):
22 | """focal_loss function: -α(1-yi)**γ *ce_loss(xi,yi)
23 | Args:
24 | alpha: class weight (default 0.25).
25 | When α is a 'list', it indicates the class-wise weights;
26 | When α is a constant,
27 | if in detection task, it indicates that the class-wise weights are[α, 1-α, 1-α, ...],
28 | the first class indicates the background
29 | if in classification task, it indicates that the class-wise weights are the same
30 | gamma: γ (default 2), focusing paramter smoothly adjusts the rate at which easy examples are down-weighted.
31 | num_classes: the number of classes
32 | size_average: (default 'mean'/'sum') specify the way to compute the loss value
33 | """
34 |
35 | super(_FocalLoss,self).__init__()
36 | self.size_average = size_average
37 | if isinstance(alpha,list):
38 | assert len(alpha)==num_classes
39 |             alpha = np.array(alpha) / np.sum(alpha) # normalize the class weights so they sum to 1
40 | # print("Focal loss alpha = {}, assign specific weights for each class".format(alpha))
41 | self.alpha = torch.Tensor(alpha)
42 | else:
43 | assert alpha<=1
44 | self.alpha = torch.zeros(num_classes)+0.00001
45 |
46 | # classification task
47 | # print("Focal loss alpha={}, the weight for each class is the same".format(alpha))
48 | self.alpha += alpha
49 |
50 |         # detection task: if α is a constant, down-weight the first class (in detection the first class is the background)
51 |         # print(" --- Focal_loss alpha = {}, the background class will be down-weighted; intended for detection tasks --- ".format(alpha))
52 |         # self.alpha[0] += alpha
53 |         # self.alpha[1:] += (1-alpha) # α ends up as [α, 1-α, 1-α, 1-α, ...] size: [num_classes]
54 | self.gamma = gamma
55 |
56 | def forward(self, predictions, labels):
57 | """
58 |         Compute the focal loss.
59 |         Args:
60 |             preds: predicted class scores. size: [B,N,C] (detection) or [B,C] (classification); B = batch size, N = number of boxes, C = number of classes
61 |             labels: ground-truth classes. size: [B,N] or [B]
62 | return:
63 | loss
64 | """
65 | assert predictions.dim()==2 and labels.dim()==1
66 | preds = predictions.view(-1,predictions.size(-1)) # num*classes
67 | alpha = self.alpha.to(labels.device)
68 |
69 |         # softmax is computed explicitly (instead of log_softmax) because the softmax output is reused below (log_softmax followed by exp would also work)
70 | preds_softmax = F.softmax(preds, dim=1)
71 | preds_logsoft = torch.log(preds_softmax)
72 |
73 |         # implement nll_loss (cross entropy = log_softmax + nll)
74 | preds_softmax = preds_softmax.gather(1,labels.view(-1,1)) # num*1
75 | preds_logsoft = preds_logsoft.gather(1,labels.view(-1,1)) # num*1
76 | alpha = alpha.gather(0,labels.view(-1)) # num
77 |
78 | # calc loss
79 |         # torch.pow((1-preds_softmax), self.gamma) is the (1-pt)**γ factor of the focal loss
80 | # shape: num*1
81 | loss = -torch.mul( torch.pow((1-preds_softmax), self.gamma), preds_logsoft )
82 | # α * (1-pt)**γ * ce_loss
83 |         # shape: 1*num
84 | loss = torch.mul(alpha, loss.t())
85 | del preds
86 | del alpha
87 | if self.size_average:
88 | loss = loss.mean()
89 | else:
90 | loss = loss.sum()
91 | return loss
--------------------------------------------------------------------------------
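With γ = 0 and a uniform α of 1, the focal term (1-pt)**γ collapses to 1 and the loss reduces to (almost exactly) plain cross entropy, which makes a quick sanity check possible. A sketch:

```python
import torch
import torch.nn.functional as F
from torchline.losses.focal_loss import _FocalLoss

logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

# gamma=0 and uniform alpha=1: focal loss ~= cross entropy
focal = _FocalLoss(alpha=1.0, gamma=0, num_classes=10)
print(focal(logits, labels))
print(F.cross_entropy(logits, labels))   # nearly identical value

# gamma=2 down-weights easy (high-confidence) examples
focal2 = _FocalLoss(alpha=0.25, gamma=2, num_classes=10)
print(focal2(logits, labels))
```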
/torchline/losses/loss.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import torch
3 | import torch.nn.functional as F
4 | from .build import LOSS_FN_REGISTRY
5 |
6 | __all__ = [
7 | 'CrossEntropy',
8 | 'CrossEntropyLabelSmooth',
9 | '_CrossEntropyLabelSmooth'
10 | ]
11 |
12 | @LOSS_FN_REGISTRY.register()
13 | def CrossEntropy(cfg):
14 | weight = cfg.loss.class_weight
15 | if weight in ['', None, []]:
16 | weight = None
17 | else:
18 | weight = torch.tensor(weight)
19 | if torch.cuda.is_available(): weight=weight.cuda()
20 | return torch.nn.CrossEntropyLoss(weight=weight)
21 |
22 | @LOSS_FN_REGISTRY.register()
23 | def CrossEntropyLabelSmooth(cfg):
24 | try:
25 | label_smoothing = cfg.loss.label_smoothing
26 |     except AttributeError:  # cfg may not define loss.label_smoothing
27 | label_smoothing = 0.1
28 | return _CrossEntropyLabelSmooth(label_smoothing)
29 |
30 | class _CrossEntropyLabelSmooth(torch.nn.Module):
31 | def __init__(self, label_smoothing):
32 | super(_CrossEntropyLabelSmooth, self).__init__()
33 | self.label_smoothing = label_smoothing
34 |
35 | def forward(self, pred, target):
36 | logsoftmax = torch.nn.LogSoftmax(dim=1)
37 | n_classes = pred.size(1)
38 | # convert to one-hot
39 | target = torch.unsqueeze(target, 1)
40 | soft_target = torch.zeros_like(pred)
41 | soft_target.scatter_(1, target, 1)
42 | # label smoothing
43 | soft_target = soft_target * (1 - self.label_smoothing) + self.label_smoothing / n_classes
44 | return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
--------------------------------------------------------------------------------
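Label smoothing replaces the one-hot target with `onehot * (1 - ε) + ε / C`; with ε = 0 the loss matches standard cross entropy exactly, which gives an easy numeric check:

```python
import torch
import torch.nn.functional as F
from torchline.losses.loss import _CrossEntropyLabelSmooth

pred = torch.tensor([[2.0, 0.5, -1.0]])  # logits for C=3 classes
target = torch.tensor([0])

# epsilon=0.1, C=3: the soft target becomes [0.9333, 0.0333, 0.0333]
print(_CrossEntropyLabelSmooth(0.1)(pred, target))

# epsilon=0 recovers standard cross entropy
print(_CrossEntropyLabelSmooth(0.0)(pred, target))
print(F.cross_entropy(pred, target))     # same value
```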
/torchline/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import *
2 | from .resnet import *
3 | from .pnasnet import *
4 | from .efficientnet import *
5 | from .dpn import *
6 | from .nas_models import *
7 | from .mnasnet import *
8 | from .mobilenet import *
9 |
10 | model_list = [
11 | 'DPN26',
12 | 'DPN92',
13 | 'EfficientNetB0',
14 | 'MNASNet0_5',
15 | 'MNASNet0_75',
16 | 'MNASNet1_0',
17 | 'MNASNet1_3',
18 | 'MobileNet_V2',
19 | 'Nasnetamobile',
20 | 'Pnasnet5large',
21 | 'PNASNetA',
22 | 'PNASNetB',
23 | 'Resnet18',
24 | 'Resnet34',
25 | 'Resnet50',
26 | 'Resnet101',
27 | 'Resnet152',
28 | 'Resnext50_32x4d',
29 | 'Resnext101_32x8d',
30 | 'Wide_resnet50_2',
31 | 'Wide_resnet101_2',
32 | ]
--------------------------------------------------------------------------------
/torchline/models/build.py:
--------------------------------------------------------------------------------
1 | from torchline.utils import Registry
2 |
3 | META_ARCH_REGISTRY = Registry("META_ARCH")
4 | META_ARCH_REGISTRY.__doc__ = """
5 | Registry for meta-architectures, i.e. the whole model.
6 |
7 | The registered object will be called with `obj(cfg)`
8 | and expected to return a `nn.Module` object.
9 | """
10 |
11 | __all__ = [
12 | 'build_model',
13 | 'META_ARCH_REGISTRY'
14 | ]
15 |
16 | def build_model(cfg):
17 | """
18 |     Build the whole model, defined by `cfg.model.name`.
19 | """
20 | name = cfg.model.name
21 | return META_ARCH_REGISTRY.get(name)(cfg)
22 |
--------------------------------------------------------------------------------
/torchline/models/dpn.py:
--------------------------------------------------------------------------------
1 | '''Dual Path Networks in PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .build import META_ARCH_REGISTRY
7 |
8 | __all__ = [
9 | 'DPN',
10 | 'DPN26',
11 | 'DPN92'
12 | ]
13 |
14 | class Bottleneck(nn.Module):
15 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer):
16 | super(Bottleneck, self).__init__()
17 | self.out_planes = out_planes
18 | self.dense_depth = dense_depth
19 |
20 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False)
21 | self.bn1 = nn.BatchNorm2d(in_planes)
22 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False)
23 | self.bn2 = nn.BatchNorm2d(in_planes)
24 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False)
25 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth)
26 |
27 | self.shortcut = nn.Sequential()
28 | if first_layer:
29 | self.shortcut = nn.Sequential(
30 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False),
31 | nn.BatchNorm2d(out_planes+dense_depth)
32 | )
33 |
34 | def forward(self, x):
35 | out = F.relu(self.bn1(self.conv1(x)))
36 | out = F.relu(self.bn2(self.conv2(out)))
37 | out = self.bn3(self.conv3(out))
38 | x = self.shortcut(x)
39 | d = self.out_planes
40 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1)
41 | out = F.relu(out)
42 | return out
43 |
44 | @META_ARCH_REGISTRY.register()
45 | class DPN(nn.Module):
46 | def __init__(self, cfg, num_classes=10):
47 | super(DPN, self).__init__()
48 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes']
49 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth']
50 |
51 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
52 | self.bn1 = nn.BatchNorm2d(64)
53 | self.last_planes = 64
54 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1)
55 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2)
56 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2)
57 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2)
58 | self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
59 | self.last_linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], num_classes)
60 |
61 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride):
62 | strides = [stride] + [1]*(num_blocks-1)
63 | layers = []
64 | for i,stride in enumerate(strides):
65 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0))
66 | self.last_planes = out_planes + (i+2) * dense_depth
67 | return nn.Sequential(*layers)
68 |
69 | def features(self, x):
70 | out = F.relu(self.bn1(self.conv1(x)))
71 | out = self.layer1(out)
72 | out = self.layer2(out)
73 | out = self.layer3(out)
74 | out = self.layer4(out)
75 | return out
76 |
77 | def logits(self, x):
78 | out = self.g_avg_pool(x)
79 | out = self.last_linear(out.view(out.size(0), -1))
80 | return out
81 |
82 | def forward(self, x):
83 | out = self.features(x)
84 | out = self.logits(out)
85 | return out
86 |
87 | @META_ARCH_REGISTRY.register()
88 | class DPN26(DPN):
89 | def __init__(self, cfg):
90 | num_classes = cfg.model.classes
91 | net_cfg = {
92 | 'in_planes': (96,192,384,768),
93 | 'out_planes': (256,512,1024,2048),
94 | 'num_blocks': (2,2,2,2),
95 | 'dense_depth': (16,32,24,128)
96 | }
97 | super(DPN26, self).__init__(net_cfg, num_classes)
98 |
99 | @META_ARCH_REGISTRY.register()
100 | class DPN92(DPN):
101 | def __init__(self, cfg):
102 | num_classes = cfg.model.classes
103 | net_cfg = {
104 | 'in_planes': (96,192,384,768),
105 | 'out_planes': (256,512,1024,2048),
106 | 'num_blocks': (3,4,20,3),
107 | 'dense_depth': (16,32,24,128)
108 | }
109 | super(DPN92, self).__init__(net_cfg, num_classes)
110 |
--------------------------------------------------------------------------------
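The registered wrappers only read `cfg.model.classes`, so a forward-shape check doesn't need the full config machinery. A sketch with a tiny stand-in cfg (the SimpleNamespace stand-in is illustrative; torchline normally passes a yacs CfgNode):

```python
import torch
from types import SimpleNamespace
from torchline.models.dpn import DPN26

cfg = SimpleNamespace(model=SimpleNamespace(classes=10))
net = DPN26(cfg)
out = net(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
```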
/torchline/models/efficientnet.py:
--------------------------------------------------------------------------------
1 | '''EfficientNet in PyTorch.
2 |
3 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks".
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | from .build import META_ARCH_REGISTRY
11 |
12 | __all__ = [
13 | 'EfficientNet',
14 | 'EfficientNetB0',
15 | ]
16 |
17 | class Block(nn.Module):
18 | '''expand + depthwise + pointwise + squeeze-excitation'''
19 |
20 | def __init__(self, in_planes, out_planes, expansion, stride):
21 | super(Block, self).__init__()
22 | self.stride = stride
23 |
24 | planes = expansion * in_planes
25 | self.conv1 = nn.Conv2d(
26 | in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
27 | self.bn1 = nn.BatchNorm2d(planes)
28 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
29 | stride=stride, padding=1, groups=planes, bias=False)
30 | self.bn2 = nn.BatchNorm2d(planes)
31 | self.conv3 = nn.Conv2d(
32 | planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
33 | self.bn3 = nn.BatchNorm2d(out_planes)
34 |
35 | self.shortcut = nn.Sequential()
36 | if stride == 1 and in_planes != out_planes:
37 | self.shortcut = nn.Sequential(
38 | nn.Conv2d(in_planes, out_planes, kernel_size=1,
39 | stride=1, padding=0, bias=False),
40 | nn.BatchNorm2d(out_planes),
41 | )
42 |
43 | # SE layers
44 | self.fc1 = nn.Conv2d(out_planes, out_planes//16, kernel_size=1)
45 | self.fc2 = nn.Conv2d(out_planes//16, out_planes, kernel_size=1)
46 |
47 | def forward(self, x):
48 | out = F.relu(self.bn1(self.conv1(x)))
49 | out = F.relu(self.bn2(self.conv2(out)))
50 | out = self.bn3(self.conv3(out))
51 | shortcut = self.shortcut(x) if self.stride == 1 else out
52 | # Squeeze-Excitation
53 | w = F.avg_pool2d(out, out.size(2))
54 | w = F.relu(self.fc1(w))
55 | w = self.fc2(w).sigmoid()
56 | out = out * w + shortcut
57 | return out
58 |
59 | @META_ARCH_REGISTRY.register()
60 | class EfficientNet(nn.Module):
61 | def __init__(self, layers, num_classes=10):
62 | super(EfficientNet, self).__init__()
63 | self.layers = layers
64 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3,
65 | stride=1, padding=1, bias=False)
66 | self.bn1 = nn.BatchNorm2d(32)
67 | self.layers = self._make_layers(in_planes=32)
68 | self.last_linear = nn.Linear(layers[-1][1], num_classes)
69 |
70 | def _make_layers(self, in_planes):
71 | layers = []
72 | for expansion, out_planes, num_blocks, stride in self.layers:
73 | strides = [stride] + [1]*(num_blocks-1)
74 | for stride in strides:
75 | layers.append(Block(in_planes, out_planes, expansion, stride))
76 | in_planes = out_planes
77 | return nn.Sequential(*layers)
78 |
79 | def features(self, x): # 2*3*32*32
80 | out = F.relu(self.bn1(self.conv1(x))) # 2*32*32*32
81 | out = self.layers(out) # 2*320*1*1
82 | return out
83 |
84 | def logits(self, x):
85 | out = nn.AdaptiveAvgPool2d((1,1))(x) # 2*320*1*1
86 | out = out.view(out.size(0), -1)
87 | out = self.last_linear(out)
88 | return out
89 |
90 | def forward(self, x): # 2*3*32*32
91 | out = self.features(x) # 2*320*1*1
92 | out = self.logits(out) # 2*num_classes
93 | return out
94 |
95 |
96 | @META_ARCH_REGISTRY.register()
97 | class EfficientNetB0(EfficientNet):
98 | def __init__(self, cfg):
99 | num_classes = cfg.model.classes
100 | # (expansion, out_planes, num_blocks, stride)
101 | layers = [(1, 16, 1, 2),
102 | (6, 24, 2, 1),
103 | (6, 40, 2, 2),
104 | (6, 80, 3, 2),
105 | (6, 112, 3, 1),
106 | (6, 192, 4, 2),
107 | (6, 320, 1, 2)]
108 | super(EfficientNetB0, self).__init__(layers, num_classes)
--------------------------------------------------------------------------------
/torchline/models/mnasnet.py:
--------------------------------------------------------------------------------
1 | import types
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from torchvision.models.mnasnet import MNASNet
7 |
8 | from .build import META_ARCH_REGISTRY
9 |
10 | __all__ = [
11 | 'MNASNet',
12 | 'MNASNet0_5',
13 | 'MNASNet0_75',
14 | 'MNASNet1_0',
15 | 'MNASNet1_3',
16 | ]
17 |
18 | class MNASNet(nn.Module):
19 |     # Copy attributes from the wrapped torchvision model
20 | def __init__(self, model):
21 | super(MNASNet, self).__init__()
22 | for key, val in model.__dict__.items():
23 | self.__dict__[key] = val
24 | self.stem = model.layers[:8]
25 | self.layer1 = model.layers[8]
26 | self.layer2 = model.layers[9]
27 | self.layer3 = model.layers[10]
28 | self.layer4 = model.layers[11]
29 | self.layer5 = model.layers[12]
30 | self.layer6 = model.layers[13]
31 | self.layer7 = model.layers[14:]
32 | self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
33 | self.last_linear = model.classifier
34 |
35 | def features(self, x): # b*3*64*64
36 | out = self.stem(x) # b*16*32*32
37 | out = self.layer1(out) # b*16*16*16
38 | out = self.layer2(out) # b*24*8*8
39 | out = self.layer3(out) # b*40*4*4
40 | out = self.layer4(out) # b*48*4*4
41 | out = self.layer5(out) # b*96*2*2
42 | out = self.layer6(out) # b*160*2*2
43 | out = self.layer7(out) # b*1280*2*2
44 | return out
45 |
46 | def logits(self, x):
47 | out = x.mean([2, 3])
48 | out = self.last_linear(out)
49 | return out
50 |
51 | def forward(self, x):
52 | out = self.features(x)
53 | out = self.logits(out)
54 | return out
55 |
56 | def generate_model(cfg, name):
57 | pretrained=cfg.model.pretrained
58 | classes = cfg.model.classes
59 | if 'dropout' in cfg.model:
60 | dropout = cfg.model.dropout
61 | else:
62 | dropout = 0.2
63 |     model = getattr(models, name)(pretrained=pretrained)
64 | if classes != 1000:
65 | in_features = model.classifier[1].in_features
66 | model.classifier = nn.Sequential(
67 | nn.Dropout(p=dropout, inplace=True),
68 | nn.Linear(in_features, classes, bias=False))
69 | return MNASNet(model)
70 |
71 | @META_ARCH_REGISTRY.register()
72 | def MNASNet0_5(cfg):
73 | return generate_model(cfg, 'mnasnet0_5')
74 |
75 | @META_ARCH_REGISTRY.register()
76 | def MNASNet0_75(cfg):
77 | return generate_model(cfg, 'mnasnet0_75')
78 |
79 | @META_ARCH_REGISTRY.register()
80 | def MNASNet1_0(cfg):
81 | return generate_model(cfg, 'mnasnet1_0')
82 |
83 | @META_ARCH_REGISTRY.register()
84 | def MNASNet1_3(cfg):
85 | return generate_model(cfg, 'mnasnet1_3')
--------------------------------------------------------------------------------
/torchline/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | import types
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from torchvision.models.mobilenet import MobileNetV2
7 |
8 | from .build import META_ARCH_REGISTRY
9 |
10 | __all__ = [
11 | 'MobileNetV2',
12 | 'MobileNet_V2'
13 | ]
14 |
15 |
16 |
17 | class MobileNetV2(nn.Module):
18 |     # Copy attributes from the wrapped torchvision model
19 | def __init__(self, model):
20 | super(MobileNetV2, self).__init__()
21 | for key, val in model.__dict__.items():
22 | self.__dict__[key] = val
23 |
24 | self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
25 | self.last_linear = self.classifier
26 |
27 | def logits(self, x):
28 | out = self.g_avg_pool(x)
29 | out = out.view(out.size(0), -1)
30 | out = self.last_linear(out)
31 | return out
32 |
33 | def forward(self, x):
34 | out = self.features(x)
35 | out = self.logits(out)
36 | return out
37 |
38 | def generate_model(cfg, name):
39 | pretrained=cfg.model.pretrained
40 | classes = cfg.model.classes
41 | if 'dropout' in cfg.model:
42 | dropout = cfg.model.dropout
43 | else:
44 | dropout = 0.2
45 |     model = getattr(models, name)(pretrained=pretrained)
46 | if classes != 1000:
47 | in_features = model.classifier[1].in_features
48 | model.classifier = nn.Sequential(
49 | nn.Dropout(p=dropout, inplace=True),
50 | nn.Linear(in_features, classes, bias=False))
51 | return MobileNetV2(model)
52 |
53 | @META_ARCH_REGISTRY.register()
54 | def MobileNet_V2(cfg):
55 | return generate_model(cfg, 'mobilenet_v2')
--------------------------------------------------------------------------------
/torchline/models/nas_models.py:
--------------------------------------------------------------------------------
1 | import pretrainedmodels
2 | import torch.nn as nn
3 |
4 | from torchline.models import META_ARCH_REGISTRY
5 |
6 | __all__ = [
7 | 'Nasnetamobile',
8 | 'Pnasnet5large',
9 | ]
10 |
11 | def generate_model(cfg, name):
12 | pretrained='imagenet' if cfg.model.pretrained else None
13 | classes = cfg.model.classes
14 | img_size = cfg.input.size[0]
15 | model = pretrainedmodels.__dict__[name](num_classes=1000, pretrained=pretrained)
16 | model.avg_pool = nn.AdaptiveAvgPool2d(1)
17 | if classes != 1000:
18 | in_features = model.last_linear.in_features
19 | model.last_linear = nn.Sequential(
20 | nn.Linear(in_features, classes, bias=False))
21 | return model
22 |
23 | @META_ARCH_REGISTRY.register()
24 | def Nasnetamobile(cfg):
25 | return generate_model(cfg, 'nasnetamobile')
26 |
27 | @META_ARCH_REGISTRY.register()
28 | def Pnasnet5large(cfg):
29 | return generate_model(cfg, 'pnasnet5large')
30 |
--------------------------------------------------------------------------------
/torchline/models/pnasnet.py:
--------------------------------------------------------------------------------
1 | '''PNASNet in PyTorch.
2 |
3 | Paper: Progressive Neural Architecture Search
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from .build import META_ARCH_REGISTRY
10 |
11 | __all__ = [
12 | 'PNASNet',
13 | 'PNASNetA',
14 | 'PNASNetB',
15 | ]
16 |
17 | class SepConv(nn.Module):
18 | '''Separable Convolution.'''
19 | def __init__(self, in_planes, out_planes, kernel_size, stride):
20 | super(SepConv, self).__init__()
21 | self.conv1 = nn.Conv2d(in_planes, out_planes,
22 | kernel_size, stride,
23 | padding=(kernel_size-1)//2,
24 | bias=False, groups=in_planes)
25 | self.bn1 = nn.BatchNorm2d(out_planes)
26 |
27 | def forward(self, x):
28 | return self.bn1(self.conv1(x))
29 |
30 |
31 | class CellA(nn.Module):
32 | def __init__(self, in_planes, out_planes, stride=1):
33 | super(CellA, self).__init__()
34 | self.stride = stride
35 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
36 | if stride==2:
37 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
38 | self.bn1 = nn.BatchNorm2d(out_planes)
39 |
40 | def forward(self, x):
41 | y1 = self.sep_conv1(x)
42 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
43 | if self.stride==2:
44 | y2 = self.bn1(self.conv1(y2))
45 | return F.relu(y1+y2)
46 |
47 | class CellB(nn.Module):
48 | def __init__(self, in_planes, out_planes, stride=1):
49 | super(CellB, self).__init__()
50 | self.stride = stride
51 | # Left branch
52 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride)
53 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride)
54 | # Right branch
55 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride)
56 | if stride==2:
57 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
58 | self.bn1 = nn.BatchNorm2d(out_planes)
59 | # Reduce channels
60 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
61 | self.bn2 = nn.BatchNorm2d(out_planes)
62 |
63 | def forward(self, x):
64 | # Left branch
65 | y1 = self.sep_conv1(x)
66 | y2 = self.sep_conv2(x)
67 | # Right branch
68 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1)
69 | if self.stride==2:
70 | y3 = self.bn1(self.conv1(y3))
71 | y4 = self.sep_conv3(x)
72 | # Concat & reduce channels
73 | b1 = F.relu(y1+y2)
74 | b2 = F.relu(y3+y4)
75 | y = torch.cat([b1,b2], 1)
76 | return F.relu(self.bn2(self.conv2(y)))
77 |
78 | class PNASNet(nn.Module):
79 | def __init__(self, cell_type, num_cells, num_planes, num_classes=10):
80 | super(PNASNet, self).__init__()
81 | self.in_planes = num_planes
82 | self.cell_type = cell_type
83 |
84 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False)
85 | self.bn1 = nn.BatchNorm2d(num_planes)
86 |
87 |         self.layer1 = self._make_layer(num_planes, num_cells=num_cells)
88 |         self.layer2 = self._downsample(num_planes*2)
89 |         self.layer3 = self._make_layer(num_planes*2, num_cells=num_cells)
90 |         self.layer4 = self._downsample(num_planes*4)
91 |         self.layer5 = self._make_layer(num_planes*4, num_cells=num_cells)
92 | self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
93 | self.last_linear = nn.Linear(num_planes*4, num_classes)
94 |
95 | def _make_layer(self, planes, num_cells):
96 | layers = []
97 | for _ in range(num_cells):
98 | layers.append(self.cell_type(self.in_planes, planes, stride=1))
99 | self.in_planes = planes
100 | return nn.Sequential(*layers)
101 |
102 | def _downsample(self, planes):
103 | layer = self.cell_type(self.in_planes, planes, stride=2)
104 | self.in_planes = planes
105 | return layer
106 |
107 | def features(self, x):
108 | out = F.relu(self.bn1(self.conv1(x)))
109 | out = self.layer1(out)
110 | out = self.layer2(out)
111 | out = self.layer3(out)
112 | out = self.layer4(out)
113 | out = self.layer5(out)
114 | return out
115 |
116 | def logits(self, x):
117 | out = self.g_avg_pool(x)
118 | out = self.last_linear(out.view(out.size(0), -1))
119 | return out
120 |
121 | def forward(self, x):
122 | out = self.features(x)
123 | out = self.logits(out)
124 | return out
125 |
126 | @META_ARCH_REGISTRY.register()
127 | class PNASNetA(PNASNet):
128 | def __init__(self, cfg):
129 | num_classes = cfg.model.classes
130 | super(PNASNetA, self).__init__(CellA, num_cells=6, num_planes=44, num_classes=num_classes)
131 |
132 | @META_ARCH_REGISTRY.register()
133 | class PNASNetB(PNASNet):
134 | def __init__(self, cfg):
135 | num_classes = cfg.model.classes
136 | super(PNASNetB, self).__init__(CellB, num_cells=6, num_planes=32, num_classes=num_classes)
137 |
138 |
--------------------------------------------------------------------------------
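
Unlike the pretrained wrappers above, `PNASNetA`/`PNASNetB` use a stride-1 stem and train from scratch, so they accept CIFAR-sized inputs out of the box. A minimal smoke test, again with a hypothetical config exposing only `model.classes`:

```python
import torch
from yacs.config import CfgNode as CN

from torchline.models.pnasnet import PNASNetA

cfg = CN()
cfg.model = CN()
cfg.model.classes = 10

net = PNASNetA(cfg)
x = torch.randn(2, 3, 32, 32)   # CIFAR-10-sized batch
y = net(x)                      # two stride-2 cells: 32 -> 16 -> 8, then global pool
assert y.shape == (2, 10)
```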
/torchline/models/resnet.py:
--------------------------------------------------------------------------------
1 | import types
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | # torchvision's ResNet is shadowed by the wrapper class below, so it is not imported
7 |
8 | from .build import META_ARCH_REGISTRY
9 |
10 | __all__ = [
11 | 'ResNet',
12 | 'Resnet18',
13 | 'Resnet34',
14 | 'Resnet50',
15 | 'Resnet101',
16 | 'Resnet152',
17 | 'Resnext50_32x4d',
18 | 'Resnext101_32x8d',
19 | 'Wide_resnet50_2',
20 | 'Wide_resnet101_2'
21 | ]
22 |
23 |
24 | class ResNet(nn.Module):
25 |     # Modify attributes
26 | def __init__(self, model):
27 | super(ResNet, self).__init__()
28 | for key, val in model.__dict__.items():
29 | self.__dict__[key] = val
30 | self.g_avg_pool = nn.AdaptiveAvgPool2d(1)
31 | self.last_linear = self.fc
32 |
33 | def features(self, x):
34 | out = self.conv1(x)
35 | out = self.bn1(out)
36 | out = self.relu(out)
37 | out = self.maxpool(out)
38 | out = self.layer1(out)
39 | out = self.layer2(out)
40 | out = self.layer3(out)
41 | out = self.layer4(out)
42 | return out
43 |
44 | def logits(self, x):
45 | out = self.g_avg_pool(x)
46 | out = out.view(out.size(0), -1)
47 | out = self.last_linear(out)
48 | return out
49 |
50 | def forward(self, x):
51 | out = self.features(x)
52 | out = self.logits(out)
53 | return out
54 |
55 | def generate_model(cfg, name):
56 | pretrained=cfg.model.pretrained
57 | classes = cfg.model.classes
58 | model = eval(f"models.{name}(pretrained={pretrained})")
59 | if classes != 1000:
60 | in_features = model.fc.in_features
61 | model.fc = nn.Linear(in_features, classes, bias=False)
62 | return ResNet(model)
63 |
64 | @META_ARCH_REGISTRY.register()
65 | def Resnet18(cfg):
66 | return generate_model(cfg, 'resnet18')
67 |
68 | @META_ARCH_REGISTRY.register()
69 | def Resnet34(cfg):
70 | return generate_model(cfg, 'resnet34')
71 |
72 | @META_ARCH_REGISTRY.register()
73 | def Resnet50(cfg):
74 | return generate_model(cfg, 'resnet50')
75 |
76 | @META_ARCH_REGISTRY.register()
77 | def Resnet101(cfg):
78 | return generate_model(cfg, 'resnet101')
79 |
80 | @META_ARCH_REGISTRY.register()
81 | def Resnet152(cfg):
82 | return generate_model(cfg, 'resnet152')
83 |
84 | @META_ARCH_REGISTRY.register()
85 | def Resnext50_32x4d(cfg):
86 | return generate_model(cfg, 'resnext50_32x4d')
87 |
88 | @META_ARCH_REGISTRY.register()
89 | def Resnext101_32x8d(cfg):
90 | return generate_model(cfg, 'resnext101_32x8d')
91 |
92 | @META_ARCH_REGISTRY.register()
93 | def Wide_resnet50_2(cfg):
94 | return generate_model(cfg, 'wide_resnet50_2')
95 |
96 | @META_ARCH_REGISTRY.register()
97 | def Wide_resnet101_2(cfg):
98 | return generate_model(cfg, 'wide_resnet101_2')
--------------------------------------------------------------------------------
/torchline/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import *
2 | from .default_trainer import *
--------------------------------------------------------------------------------
/torchline/trainer/build.py:
--------------------------------------------------------------------------------
1 | from torchline.utils import Registry
2 |
3 | TRAINER_REGISTRY = Registry("TRAINER")
4 | TRAINER_REGISTRY.__doc__ = """
5 | Registry for trainers, i.e. the objects that drive the whole training process.
6 | 
7 | The registered object will be called with `obj(cfg, hparams)`
8 | and expected to return a `pytorch_lightning.Trainer` object.
9 | """
10 |
11 | __all__ = [
12 | 'TRAINER_REGISTRY',
13 | 'build_trainer',
14 | ]
15 |
16 | def build_trainer(cfg, hparams):
17 | """
18 |     Build the trainer specified by `cfg.trainer.name`.
19 | """
20 | name = cfg.trainer.name
21 | return TRAINER_REGISTRY.get(name)(cfg, hparams)
22 |
--------------------------------------------------------------------------------
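
Because the registry resolves trainers by class name, adding a custom trainer only requires a class whose constructor matches the `(cfg, hparams)` call above. A sketch (the class name is hypothetical), subclassing the `DefaultTrainer` shown next:

```python
from torchline.trainer import TRAINER_REGISTRY, DefaultTrainer

@TRAINER_REGISTRY.register()
class MyTrainer(DefaultTrainer):
    """Hypothetical trainer, selected by setting cfg.trainer.name = 'MyTrainer'."""
    def __init__(self, cfg, hparams):
        super().__init__(cfg, hparams)
```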
/torchline/trainer/default_trainer.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import shutil
4 |
5 | from pytorch_lightning import Trainer
6 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
7 | from pytorch_lightning.logging import TestTubeLogger
8 | from torchline.utils import Logger
9 |
10 | from .build import TRAINER_REGISTRY
11 |
12 | __all__ = [
13 | 'DefaultTrainer'
14 | ]
15 |
16 |
17 | @TRAINER_REGISTRY.register()
18 | class DefaultTrainer(Trainer):
19 | def __init__(self, cfg, hparams):
20 | self.cfg = cfg
21 | self.hparams = hparams
22 |
23 | self.logger = self._logger()
24 | self.logger_print = Logger(__name__, cfg.log.name).getlogger()
25 | self.early_stop_callback = self._early_stop_callback()
26 | self.checkpoint_callback = self._checkpoint_callback()
27 |
28 |         # update trainer_params to override individual Trainer arguments
29 | self.trainer_params = dict(self.cfg.trainer)
30 | self.trainer_params.update({
31 | 'logger': self.logger,
32 | 'early_stop_callback': self.early_stop_callback,
33 | 'checkpoint_callback': self.checkpoint_callback,
34 | })
35 | self.trainer_params.pop('name')
36 | for key in self.trainer_params:
37 | self.trainer_params[key] = self.parse_cfg_param(self.trainer_params[key])
38 |
39 | super(DefaultTrainer, self).__init__(**self.trainer_params)
40 |
41 | def parse_cfg_param(self, cfg_item):
42 | return cfg_item if cfg_item not in ['', []] else None
43 |
44 | def _logger(self):
45 |
46 | # logger
47 | logger_cfg = self.cfg.trainer.logger
48 |         assert logger_cfg.setting in [0,1,2], "logger setting must be one of [0, 1, 2], but got {}".format(logger_cfg.setting)
49 | if logger_cfg.type == 'mlflow':
50 | raise NotImplementedError
51 | # params = {key: logger_cfg.mlflow[key] for key in logger_cfg.mlflow}
52 | # custom_logger = MLFlowLogger(**params)
53 | elif logger_cfg.type == 'test_tube':
54 | params = {key: logger_cfg.test_tube[key] for key in logger_cfg.test_tube} # key: save_dir, name, version
55 |
56 | # save_dir: logger root path:
57 | if self.cfg.trainer.default_root_dir:
58 | save_dir = self.cfg.trainer.default_root_dir
59 | else:
60 | save_dir = logger_cfg.test_tube.save_dir
61 |
62 | # version
63 | version = self.cfg.trainer.logger.test_tube.version
64 | if logger_cfg.setting==1: # disable logger
65 | return False
66 | elif logger_cfg.setting==2: # use custom logger
67 |             # if version < 0, fall back to the version parsed from the log path;
68 |             # otherwise keep the user-specified custom value
69 |             if version < 0:
70 |                 version = int(self.cfg.log.path.replace('/','')[-1])
71 | 
72 |
73 | default_logger = TestTubeLogger(save_dir, name='torchline_logs', version=version)
74 | params.update({'version': version, 'name':logger_cfg.test_tube.name, 'save_dir': save_dir})
75 | custom_logger = TestTubeLogger(**params)
76 | else:
77 | print(f"{logger_cfg.type} not supported")
78 | raise NotImplementedError
79 |
80 | loggers = {
81 | 0: default_logger,
82 | 1: False,
83 | 2: custom_logger
84 |         } # 0: default logger, 1: disabled, 2: custom logger
85 | logger = loggers[logger_cfg.setting]
86 |
87 | return logger
88 |
89 | def _early_stop_callback(self):
90 | # early_stop_callback hooks
91 | hooks = self.cfg.hooks
92 | params = {key: hooks.early_stopping[key] for key in hooks.early_stopping if key != 'setting'}
93 | early_stop_callbacks = {
94 | 0: True, # default setting
95 | 1: False, # do not use early stopping
96 | 2: EarlyStopping(**params) # use custom setting
97 | }
98 | assert hooks.early_stopping.setting in early_stop_callbacks, 'The setting of early stopping can only be in [0,1,2]'
99 | early_stop_callback = early_stop_callbacks[hooks.early_stopping.setting]
100 | return early_stop_callback
101 |
102 | def _checkpoint_callback(self):
103 | # checkpoint_callback hooks
104 | hooks = self.cfg.hooks
105 |         assert hooks.model_checkpoint.setting in [0,1,2], "model_checkpoint setting must be one of [0, 1, 2], but got {}".format(hooks.model_checkpoint.setting)
106 |
107 | params = {key: hooks.model_checkpoint[key] for key in hooks.model_checkpoint if key != 'setting'}
108 | if hooks.model_checkpoint.setting==2:
109 | if hooks.model_checkpoint.filepath.strip()=='':
110 | filepath = os.path.join(self.cfg.log.path,'checkpoints')
111 | monitor = hooks.model_checkpoint.monitor
112 | filepath = os.path.join(filepath, '{epoch}-{%s:.2f}'%monitor)
113 | params.update({'filepath': filepath})
114 | else:
115 |             self.logger_print.warning("Manually specifying a checkpoint path is not recommended!")
116 | checkpoint_callbacks = {
117 | 0: True,
118 | 1: False,
119 | 2: ModelCheckpoint(**params)
120 | }
121 | checkpoint_callback = checkpoint_callbacks[hooks.model_checkpoint.setting]
122 | return checkpoint_callback
123 |
--------------------------------------------------------------------------------
/torchline/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .registry import *
3 | from .utils import *
4 | from .average_meter import *
--------------------------------------------------------------------------------
/torchline/utils/average_meter.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | __all__ = [
4 | 'AverageMeterGroup',
5 | 'AverageMeter'
6 | ]
7 |
8 | class AverageMeterGroup:
9 | """
10 | Average meter group for multiple average meters.
11 | """
12 |
13 | def __init__(self, verbose_type='avg'):
14 | self.meters = OrderedDict()
15 | self.verbose_type = verbose_type
16 |
17 | def update(self, data):
18 | """
19 | Update the meter group with a dict of metrics.
20 |         Nonexistent average meters are created automatically.
21 | """
22 | for k, v in data.items():
23 | if k not in self.meters:
24 | self.meters[k] = AverageMeter(k, ":.4f", self.verbose_type)
25 | self.meters[k].update(v)
26 |
27 | def __getattr__(self, item):
28 | return self.meters[item]
29 |
30 | def __getitem__(self, item):
31 | return self.meters[item]
32 |
33 | def __str__(self):
34 | return " ".join(f"{v}" for v in self.meters.values())
35 |
36 | def summary(self):
37 | """
38 | Return a summary string of group data.
39 | """
40 | return " ".join(v.summary() for v in self.meters.values())
41 |
42 |
43 | class AverageMeter:
44 | """
45 | Computes and stores the average and current value.
46 | Parameters
47 | ----------
48 | name : str
49 | Name to display.
50 | fmt : str
51 | Format string to print the values.
52 | verbose_type : str
53 | 'all': value(avg)
54 | 'avg': avg
55 | """
56 |
57 | def __init__(self, name, fmt=':f', verbose_type='avg'):
58 | self.name = name
59 | self.fmt = fmt
60 | if verbose_type not in ['all', 'avg']:
61 |             print('Unsupported verbose_type; falling back to the default, "avg"')
62 | verbose_type = 'avg'
63 | self.verbose_type = verbose_type
64 | self.reset()
65 |
66 | def reset(self):
67 | """
68 | Reset the meter.
69 | """
70 | self.val = 0
71 | self.avg = 0
72 | self.sum = 0
73 | self.count = 0
74 |
75 | def update(self, val, n=1):
76 | """
77 | Update with value and weight.
78 | Parameters
79 | ----------
80 | val : float or int
81 | The new value to be accounted in.
82 | n : int
83 | The weight of the new value.
84 | """
85 | self.val = val
86 | self.sum += val * n
87 | self.count += n
88 | self.avg = self.sum / self.count
89 |
90 | def __str__(self):
91 | if self.verbose_type=='all':
92 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
93 | elif self.verbose_type=='avg':
94 | fmtstr = '{name} {avg' + self.fmt + '}'
95 | else:
96 | fmtstr = '{name} {avg' + self.fmt + '}'
97 | return fmtstr.format(**self.__dict__)
98 |
99 | def summary(self):
100 | fmtstr = '{name}: {avg' + self.fmt + '}'
101 | return fmtstr.format(**self.__dict__)
--------------------------------------------------------------------------------
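
A minimal usage sketch of the meters; the printed values follow from the `:.4f` format and the running averages:

```python
from torchline.utils import AverageMeterGroup

group = AverageMeterGroup(verbose_type='all')
for step in range(3):
    group.update({'loss': 1.0 / (step + 1), 'acc': 0.5 + 0.1 * step})

print(group)            # loss 0.3333 (0.6111) acc 0.7000 (0.6000)
print(group.summary())  # loss: 0.6111 acc: 0.6000
```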
/torchline/utils/logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding:utf-8 -*-
3 |
4 | import logging
5 | import time
6 | import os
7 |
8 | __all__ = [
9 | 'Logger'
10 | ]
11 |
12 | class Logger(object):
13 | def __init__(self, logger_name=None, filename=None, *args, **kwargs):
14 |         '''
15 |         Specify the file path for saving logs, the log level, and the calling file;
16 |         logs are written to the specified file.
17 |         '''
18 | 
19 |         # create a logger
20 | if filename is None:
21 | file = 'log.txt'
22 | else:
23 | file = filename
24 | self.logger = logging.getLogger(logger_name)
25 | self.logger.setLevel(logging.INFO)
26 | self.logger.propagate = False
27 | if (self.logger.hasHandlers()):
28 | self.logger.handlers.clear()
29 | formatter = logging.Formatter('[%(asctime)s] %(filename)s->%(funcName)s line:%(lineno)d [%(levelname)s]%(message)s')
30 |
31 | if file:
32 | hdlr = logging.FileHandler(file, 'a', encoding='utf-8')
33 | hdlr.setLevel(logging.INFO)
34 | hdlr.setFormatter(formatter)
35 | self.logger.addHandler(hdlr)
36 |
37 | strhdlr = logging.StreamHandler()
38 | strhdlr.setLevel(logging.INFO)
39 | strhdlr.setFormatter(formatter)
40 | self.logger.addHandler(strhdlr)
41 |
42 | if file: hdlr.close()
43 | strhdlr.close()
44 |
45 | def getlogger(self):
46 | return self.logger
47 |
--------------------------------------------------------------------------------
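
A usage sketch (the log file name is an arbitrary example): `Logger` attaches both a file handler and a stream handler, so records go to the file and to stderr:

```python
from torchline.utils import Logger

logger = Logger(__name__, 'train_log.txt').getlogger()  # file name is an example
logger.info('training started')  # written to train_log.txt and to stderr
```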
/torchline/utils/registry.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | # reference: https://github.com/facebookresearch/fvcore/blob/51695ec2984ae7b4afae27a14d02986ef3f8afd8/fvcore/common/registry.py
4 |
5 | from typing import Dict, Optional
6 |
7 | __all__ = [
8 | 'Registry'
9 | ]
10 |
11 | class Registry(object):
12 | '''
13 | The registry that provides name -> object mapping, to support third-party
14 | users' custom modules.
15 | To create a registry (e.g. a backbone registry):
16 | .. code-block:: python
17 | BACKBONE_REGISTRY = Registry('BACKBONE')
18 | To register an object:
19 | .. code-block:: python
20 | @BACKBONE_REGISTRY.register()
21 | class MyBackbone():
22 | ...
23 | Or:
24 | .. code-block:: python
25 | BACKBONE_REGISTRY.register(MyBackbone)
26 | '''
27 |
28 | def __init__(self, name: str) -> None:
29 | """
30 | Args:
31 | name (str): the name of this registry
32 | """
33 | self._name: str = name
34 | self._obj_map: Dict[str, object] = {}
35 |
36 | def _do_register(self, name: str, obj: object) -> None:
37 | assert (
38 | name not in self._obj_map
39 | ), "An object named '{}' was already registered in '{}' registry!".format(
40 | name, self._name
41 | )
42 | self._obj_map[name] = obj
43 |
44 | def register(self, obj: object = None) -> Optional[object]:
45 | '''
46 |     Register the given object under the name `obj.__name__`.
47 | Can be used as either a decorator or not. See docstring of this class for usage.
48 | '''
49 | if obj is None:
50 | # used as a decorator
51 | def deco(func_or_class: object) -> object:
52 | name = func_or_class.__name__ # pyre-ignore
53 | self._do_register(name, func_or_class)
54 | return func_or_class
55 |
56 | return deco
57 |
58 | # used as a function call
59 | name = obj.__name__ # pyre-ignore
60 | self._do_register(name, obj)
61 |
62 | def get(self, name: str) -> object:
63 | ret = self._obj_map.get(name)
64 | if ret is None:
65 | raise KeyError(
66 | "No object named '{}' found in '{}' registry!".format(
67 | name, self._name
68 | )
69 | )
70 | return ret
--------------------------------------------------------------------------------
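
A self-contained sketch of both ends of the registry: registering under a function's `__name__` and resolving it back (the registry and function names here are hypothetical):

```python
from torchline.utils import Registry

MODEL_REGISTRY = Registry('MODEL')

@MODEL_REGISTRY.register()
def tiny_model(cfg):
    return None

assert MODEL_REGISTRY.get('tiny_model') is tiny_model
# MODEL_REGISTRY.get('missing') raises KeyError;
# registering 'tiny_model' twice trips the duplicate-name assertion
```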
/torchline/utils/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import torch
5 | from PIL import Image
6 | from torchvision import transforms
7 |
8 | __all__ = [
9 | 'image_loader',
10 | 'get_imgs_to_predict',
11 | 'topk_acc',
12 | 'model_size'
13 | ]
14 |
15 | def image_loader(filename, cfg):
16 | '''load an image and convert it to tensor
17 | Args:
18 | filename: image filename
19 | cfg: CfgNode
20 | return:
21 | torch.tensor
22 | '''
23 | image = Image.open(filename).convert('RGB')
24 | mean = cfg.transforms.tensor.normalization.mean
25 | std = cfg.transforms.tensor.normalization.std
26 | img_size = cfg.input.size
27 |
28 | transform = transforms.Compose([
29 | transforms.Resize(img_size),
30 | transforms.ToTensor(),
31 | transforms.Normalize(mean, std),
32 | ])
33 | image = transform(image)
34 | return image
35 |
36 | def get_imgs_to_predict(path, cfg):
37 |     '''load images which are only used for prediction or testing
38 |     Args:
39 |         path: str, an image file or a directory of images
40 |     return:
41 |         torch.tensor (N*C*H*W) for a file, or a dict ('img_file', 'img_data') for a directory
42 |     '''
43 | if os.path.isfile(path):
44 | images = image_loader(path, cfg).unsqueeze(0)
45 | elif os.path.isdir(path):
46 | image_types = [os.path.join(path,'*.jpg'), os.path.join(path,'*.png')]
47 | image_files = []
48 | images = {
49 | 'img_file': [],
50 | 'img_data': []
51 | }
52 | for img_type in image_types:
53 | image_files.extend(glob.glob(img_type))
54 | for img_file in image_files:
55 | images['img_file'].append(img_file)
56 | images['img_data'].append(image_loader(img_file, cfg))
57 | images['img_data'] = torch.stack(images['img_data'])
58 | return images
59 |
60 | def model_size(model):
61 |     return sum([p.numel() for p in model.parameters()])*4/1024**2 # size in MB, assuming float32 (4-byte) parameters
62 |
63 | def topk_acc(output, target, topk=(1, 3)):
64 | """Computes the precision@k for the specified values of k"""
65 | maxk = max(topk)
66 |
67 | _, pred = output.topk(maxk, 1, True, True)
68 | pred = pred.t()
69 | correct = pred.eq(target.view(1, -1).expand_as(pred))
70 |
71 | res = []
72 | for k in topk:
73 |         correct_k = correct[:k].reshape(-1).float().sum(0) # reshape: correct is non-contiguous after t()
74 | res.append(correct_k*100.0/len(target.view(-1)))
75 | # res.append(correct_k.mul_(100.0 / batch_size))
76 | return torch.tensor(res)
77 |
--------------------------------------------------------------------------------
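
A toy check of `topk_acc`: sample 0 is correct at top-1, sample 1 only falls inside the top-3, giving 50% top-1 and 100% top-3:

```python
import torch

from torchline.utils import topk_acc

output = torch.tensor([[0.1, 0.7, 0.2],
                       [0.8, 0.1, 0.1]])  # logits for 2 samples, 3 classes
target = torch.tensor([1, 2])
print(topk_acc(output, target, topk=(1, 3)))  # tensor([ 50., 100.])
```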