├── .gitignore
├── README.md
├── clr.py
├── data
│   └── README.md
├── dataset.py
├── flops_benchmark.py
├── logger.py
├── mobile_net.py
├── requirements.txt
├── resnet.py
├── results
│   └── README.md
├── run.py
├── start.py
└── util
    ├── __init__.py
    └── transfor.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | # own ignore
106 | *.pyc
107 | .idea
108 | *.tar
109 | tmp
110 | /data/*.txt
111 | /results/*_*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | (badges: read · stars · forks · issues)
9 |
10 | # pytorch_train
11 |
12 | ## Model training
13 |
14 | GitHub repo: [pytorch_train](https://github.com/tf2jaguar/pytorch_train)
15 |
16 | ![UtyvKf.png](https://s1.ax1x.com/2020/07/14/UtyvKf.png)
17 |
18 | Model training is split into five modules: the launcher, a custom data loader, the network model, learning-rate/loss scheduling, and training visualization.
19 |
20 | The launcher is the entry point of the project. Its arguments allow a run to be started in many flexible ways; the figure below shows some of the launcher options.
21 |
22 | ![UtciwD.png](https://s1.ax1x.com/2020/07/14/UtciwD.png)
23 |
24 | No deep-learning model can be trained without a dataset. Since datasets come in many shapes, we should wrap them behind a loader that returns samples in one common structure, so the network model can consume them uniformly.
25 |
26 | ![Utc9OK.png](https://s1.ax1x.com/2020/07/14/Utc9OK.png)
27 |
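`CarDataset` reads these annotation files with `pandas.read_csv`: column 0 is an image path (resolved against the data root, `./data`) and column 1 is the integer class id; the first row is consumed as the CSV header. A hedged sketch of what `data/train.txt` could look like — the file names and labels here are made up:

```
image,label
train/00001.jpg,3
train/00002.jpg,7
```
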
28 | The model here is the residual network ResNet-34; the code also provides ResNet-18, ResNet-50, ResNet-101, and ResNet-152. A residual block adds a shortcut (identity) connection: the shortcut itself adds no parameters, and it eases gradient flow so that much deeper networks can be trained reliably.
29 |
30 | Below are the basic residual structure and part of the ResNet-34 architecture.
31 |
32 | ![UtcPeO.png](https://s1.ax1x.com/2020/07/14/UtcPeO.png)
33 |
34 | ![Utcn6P.png](https://s1.ax1x.com/2020/07/14/Utcn6P.png)
35 |
36 |
37 | Besides the train-val charts and the Top-1/Top-5 error tables shown at the beginning, a progress bar prints the current training progress, accuracy, and other information during training. When to print can be configured cleanly through the launcher mentioned above.
38 |
39 | ![Utc3kQ.png](https://s1.ax1x.com/2020/07/14/Utc3kQ.png)
40 |
41 | The final package layout of the project:
42 |
43 | ```
44 | pytorch_train
45 | |-- data -- txt files listing the train/val/test image paths
46 | |   |-- train.txt
47 | |   |-- val.txt
48 | |   |-- test.txt
49 | |-- results -- output directory for generated training results
50 | |-- util -- model porting tools
51 | |-- clr.py -- cyclical learning-rate schedule
52 | |-- dataset.py -- custom dataset
53 | |-- flops_benchmark.py -- FLOPs counter
54 | |-- logger.py -- logging and visualization
55 | |-- mobile_net.py -- one of the models: MobileNetV2
56 | |-- resnet.py -- one of the models: the ResNet family
57 | |-- run.py -- the actual train/test routines
58 | |-- start.py -- launcher
59 | ```
60 |
61 | ![UtgkuV.png](https://s1.ax1x.com/2020/07/14/UtgkuV.png)
62 |
63 |
64 | ## Model porting
65 |
66 | GitHub: [pytorch_train/transfor](https://github.com/tf2jaguar/pytorch_train/blob/master/util/transfor.py)
67 |
68 |
69 | ```python
70 | import os
71 |
72 | import torch
73 | import torchvision
74 |
75 | model_pth = os.path.join("results", "2020-04-27_10-27-17", 'checkpoint.pth.tar')
76 | # export the resnet34 model to a file Android can load
77 | mobile_pt = os.path.join("results", "2020-04-27_10-27-17", 'resnet34.pt')
78 | num_class = 13
79 | device = 'cpu'  # 'cuda:0' # cpu
80 |
81 | model = torchvision.models.resnet34(num_classes=num_class)
82 | model = torch.nn.DataParallel(model, [0])
83 | model.to(device=device)
84 |
85 | checkpoint = torch.load(model_pth, map_location=device)
86 | model.load_state_dict(checkpoint['state_dict'])
87 |
88 | model.eval()  # set the model to eval mode
89 | # a single 3-channel 224*224 image
90 | input_tensor = torch.rand(1, 3, 224, 224)  # define the input shape
91 | traced_script_module = torch.jit.trace(model.module, input_tensor)  # trace the model
92 | traced_script_module.save(mobile_pt)  # save the file
93 | ```
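
After exporting, a quick sanity check is to reload the traced module in Python before shipping it to Android. A minimal sketch, assuming the `mobile_pt` path from the script above and the output shape implied by `num_class = 13`:

```python
import torch

model = torch.jit.load(mobile_pt)  # reload the TorchScript file saved above
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # expect torch.Size([1, 13]) for num_class = 13
```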
94 |
95 |
96 | ## Launching model training
97 |
98 | Before launching, make sure you have the CompCars dataset used by this project.
99 |
100 | ### Start a fresh training run
101 |
102 | ```shell script
103 | python start.py --data_root "./data" --gpus 0,1,2 -w 2 -b 120 --num_class 13
104 | ```
105 |
106 | - --data_root path to the dataset
107 | - --gpus which GPUs to train on
108 | - -w number of worker threads loading the custom dataset for the GPUs
109 | - -b batch size used for GPU training
110 | - --num_class number of classes to predict
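
`start.py` also exposes cyclical-learning-rate options (`--clr`, `--min_lr`, `--max_lr`, `--epochs_per_step`, `--mode`, and `--find_clr` for a bound search). A hedged example of a CLR run — the flag names come from `start.py`, but the values are only illustrative:

```shell script
python start.py --data_root "./data" --gpus 0 -b 64 --num_class 13 \
    --clr --min_lr 1e-5 --max_lr 0.1 --epochs_per_step 5 --mode triangular2
```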
111 |
112 | ### Resume training from a previous run
113 |
114 | ```shell script
115 | python start.py --data_root "./data" --gpus 0,1,2 -w 2 -b 120 --num_class 13 --resume "results/2020-04-14_12-36-16"
116 | ```
117 |
118 | - --data_root path to the dataset
119 | - --gpus which GPUs to train on
120 | - -w number of worker threads loading the custom dataset for the GPUs
121 | - -b batch size used for GPU training
122 | - --num_class number of classes to predict
123 | - --resume folder of a previous run whose training should be continued
124 |
125 | ### Model porting
126 |
127 | Convert the trained model into one that Android can execute:
128 |
129 | ```shell script
130 | python util/transfor.py
131 | ```
132 |
133 | ### Customizing the project
134 |
135 | - Bring your own dataset
136 | - Adjust **--num_class** in the launch script to the number of classes in your data
137 |
138 | The code base is heavily commented; a quick review should make everything clear. If anything is still unclear, feel free to message me.
139 |
140 | ### Show your support
141 |
142 | (donation QR codes: image-qxUDIO, image-qxUBdK)
143 |
144 | Paid support is available for the full package: environment setup + dataset download + model porting + thesis template + technical guidance.
--------------------------------------------------------------------------------
/clr.py:
--------------------------------------------------------------------------------
1 | # temporary file until https://github.com/pytorch/pytorch/pull/2016 is merged (hopefully 0.5)
2 |
3 |
4 | import numpy as np
5 | from torch.optim.optimizer import Optimizer
6 |
7 |
8 | class CyclicLR(object):
9 |     """Sets the learning rate of each parameter group according to
10 |     cyclical learning rate policy (CLR). The policy cycles the learning
11 |     rate between two boundaries with a constant frequency, as detailed in
12 |     the paper `Cyclical Learning Rates for Training Neural Networks`_.
13 |     The distance between the two boundaries can be scaled on a per-iteration
14 |     or per-cycle basis.
15 |     Cyclical learning rate policy changes the learning rate after every batch.
16 |     `batch_step` should be called after a batch has been used for training.
17 |     To resume training, save `last_batch_iteration` and use it to instantiate `CyclicLR`.
18 |     This class has three built-in policies, as put forth in the paper:
19 |     "triangular":
20 |         A basic triangular cycle w/ no amplitude scaling.
21 |     "triangular2":
22 |         A basic triangular cycle that scales initial amplitude by half each cycle.
23 |     "exp_range":
24 |         A cycle that scales initial amplitude by gamma**(cycle iterations) at each
25 |         cycle iteration.
26 |     This implementation was adapted from the github repo: `bckenstler/CLR`_
27 |     Args:
28 |         optimizer (Optimizer): Wrapped optimizer.
29 |         base_lr (float or list): Initial learning rate which is the
30 |             lower boundary in the cycle for each param group.
31 |             Default: 0.001
32 |         max_lr (float or list): Upper boundaries in the cycle for
33 |             each parameter group. Functionally,
34 |             it defines the cycle amplitude (max_lr - base_lr).
35 |             The lr at any cycle is the sum of base_lr
36 |             and some scaling of the amplitude; therefore
37 |             max_lr may not actually be reached depending on
38 |             scaling function. Default: 0.006
39 |         step_size (int): Number of training iterations per
40 |             half cycle. Authors suggest setting step_size
41 |             2-8 x training iterations per epoch. Default: 2000
42 |         mode (str): One of {triangular, triangular2, exp_range}.
43 |             Values correspond to policies detailed above.
44 |             If scale_fn is not None, this argument is ignored.
45 |             Default: 'triangular'
46 |         gamma (float): Constant in 'exp_range' scaling function:
47 |             gamma**(cycle iterations)
48 |             Default: 1.0
49 |         scale_fn (function): Custom scaling policy defined by a single
50 |             argument lambda function, where
51 |             0 <= scale_fn(x) <= 1 for all x >= 0.
52 |             mode parameter is ignored
53 |             Default: None
54 |         scale_mode (str): {'cycle', 'iterations'}.
55 |             Defines whether scale_fn is evaluated on
56 |             cycle number or cycle iterations (training
57 |             iterations since start of cycle).
58 |             Default: 'cycle'
59 |         last_batch_iteration (int): The index of the last batch. Default: -1
60 |     Example:
61 |         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
62 |         >>> scheduler = CyclicLR(optimizer)
63 |         >>> data_loader = torch.utils.data.DataLoader(...)
64 |         >>> for epoch in range(10):
65 |         >>>     for batch in data_loader:
66 |         >>>         scheduler.batch_step()
67 |         >>>         train_batch(...)
68 |     .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
69 |     ..
_bckenstler/CLR: https://github.com/bckenstler/CLR 70 | """ 71 | 72 | def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3, 73 | step_size=2000, mode='triangular', gamma=1., 74 | scale_fn=None, scale_mode='cycle', last_batch_iteration=-1): 75 | 76 | if not isinstance(optimizer, Optimizer): 77 | raise TypeError('{} is not an Optimizer'.format( 78 | type(optimizer).__name__)) 79 | self.optimizer = optimizer 80 | 81 | if isinstance(base_lr, list) or isinstance(base_lr, tuple): 82 | if len(base_lr) != len(optimizer.param_groups): 83 | raise ValueError("expected {} base_lr, got {}".format( 84 | len(optimizer.param_groups), len(base_lr))) 85 | self.base_lrs = list(base_lr) 86 | else: 87 | self.base_lrs = [base_lr] * len(optimizer.param_groups) 88 | 89 | if isinstance(max_lr, list) or isinstance(max_lr, tuple): 90 | if len(max_lr) != len(optimizer.param_groups): 91 | raise ValueError("expected {} max_lr, got {}".format( 92 | len(optimizer.param_groups), len(max_lr))) 93 | self.max_lrs = list(max_lr) 94 | else: 95 | self.max_lrs = [max_lr] * len(optimizer.param_groups) 96 | 97 | self.step_size = step_size 98 | 99 | if mode not in ['triangular', 'triangular2', 'exp_range'] \ 100 | and scale_fn is None: 101 | raise ValueError('mode is invalid and scale_fn is None') 102 | 103 | self.mode = mode 104 | self.gamma = gamma 105 | 106 | if scale_fn is None: 107 | if self.mode == 'triangular': 108 | self.scale_fn = self._triangular_scale_fn 109 | self.scale_mode = 'cycle' 110 | elif self.mode == 'triangular2': 111 | self.scale_fn = self._triangular2_scale_fn 112 | self.scale_mode = 'cycle' 113 | elif self.mode == 'exp_range': 114 | self.scale_fn = self._exp_range_scale_fn 115 | self.scale_mode = 'iterations' 116 | else: 117 | self.scale_fn = scale_fn 118 | self.scale_mode = scale_mode 119 | 120 | self.batch_step(last_batch_iteration + 1) 121 | self.last_batch_iteration = last_batch_iteration 122 | 123 | def batch_step(self, batch_iteration=None): 124 | if batch_iteration is None: 125 | batch_iteration = self.last_batch_iteration + 1 126 | self.last_batch_iteration = batch_iteration 127 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 128 | param_group['lr'] = lr 129 | 130 | def _triangular_scale_fn(self, x): 131 | return 1. 132 | 133 | def _triangular2_scale_fn(self, x): 134 | return 1 / (2. 
** (x - 1))
135 |
136 |     def _exp_range_scale_fn(self, x):
137 |         return self.gamma ** (x)
138 |
139 |     def get_lr(self):
140 |         step_size = float(self.step_size)
141 |         cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size))
142 |         x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1)
143 |
144 |         lrs = []
145 |         param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs)
146 |         for param_group, base_lr, max_lr in param_lrs:
147 |             base_height = (max_lr - base_lr) * np.maximum(0, (1 - x))
148 |             if self.scale_mode == 'cycle':
149 |                 lr = base_lr + base_height * self.scale_fn(cycle)
150 |             else:
151 |                 lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration)
152 |             lrs.append(lr)
153 |         return lrs
154 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Holds the txt files that list the train/val/test data paths
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | import torch
5 | import torch.nn.parallel
6 | import torch.optim
7 | import torch.utils.data
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms
11 |
12 | __image_net_stats = {'mean': [0.485, 0.456, 0.406],
13 |                      'std': [0.229, 0.224, 0.225]}
14 |
15 |
16 | class CarDataset(Dataset):
17 |
18 |     def __init__(self, csv_file, root_dir, transform=None):
19 |         """
20 |         CarDataset: Custom dataset
21 |         :param csv_file (string): path to the csv file with annotations.
22 |         :param root_dir (string): root directory containing all the images.
23 |         :param transform (callable, optional): optional transform applied to a sample.
24 |         """
25 |         self.landmarks_frame = pd.read_csv(csv_file)
26 |         self.root_dir = root_dir
27 |         self.transform = transform
28 |
29 |     def __getitem__(self, index):
30 |         img_path = os.path.join(self.root_dir, self.landmarks_frame.iloc[index, 0])
31 |         label = self.landmarks_frame.iloc[index, 1:]
32 |         image = Image.open(img_path)
33 |         if self.transform:
34 |             image = self.transform(image)
35 |         return image, int(label)
36 |
37 |     def __len__(self):
38 |         return len(self.landmarks_frame)
39 |
40 |
41 | def inception_preprocess(input_size, normalize=None):
42 |     if normalize is None:
43 |         normalize = __image_net_stats
44 |     return transforms.Compose([
45 |         transforms.RandomResizedCrop(input_size),
46 |         transforms.RandomHorizontalFlip(),
47 |         transforms.ToTensor(),
48 |         transforms.Normalize(**normalize)
49 |     ])
50 |
51 |
52 | def scale_crop(input_size, scale_size=None, normalize=None):
53 |     if normalize is None:
54 |         normalize = __image_net_stats
55 |     t_list = [
56 |         transforms.CenterCrop(input_size),
57 |         transforms.ToTensor(),
58 |         transforms.Normalize(**normalize),
59 |     ]
60 |     if scale_size != input_size:
61 |         t_list = [transforms.Resize(scale_size)] + t_list
62 |
63 |     return transforms.Compose(t_list)
64 |
65 |
66 | def get_transform(augment=True, input_size=224):
67 |     normalize = __image_net_stats
68 |     scale_size = int(input_size / 0.875)
69 |     if augment:
70 |         return inception_preprocess(input_size=input_size, normalize=normalize)
71 |     else:
72 |         return scale_crop(input_size=input_size, scale_size=scale_size, normalize=normalize)
73 |
74 |
75 | def get_loaders(dataroot, val_batch_size, train_batch_size, input_size, workers):
76 |     # validation uses the non-augmented (center-crop) transform
77 |     val_data = CarDataset(dataroot + '/val.txt', './data', transform=get_transform(False, input_size))
78 |     val_loader = torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, shuffle=False, num_workers=workers,
79 |                                              pin_memory=True)
80 |
81 |     # train_data = datasets.ImageFolder(root=os.path.join(dataroot, 'train'),
82 |     #                                   transform=get_transform(input_size=input_size))
83 |
84 |     train_data = CarDataset(dataroot + '/train.txt', './data', transform=get_transform(True, input_size))
85 |     train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_batch_size, shuffle=True,
86 |                                                num_workers=workers, pin_memory=True)
87 |     return train_loader, val_loader
88 |
89 |
90 | def get_test_loaders(dataroot, batch_size, input_size, workers):
91 |     # test data, like validation data, should not be augmented
92 |     test_data = CarDataset(dataroot + '/test.txt', './data', transform=get_transform(False, input_size))
93 |     test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False,
94 |                                               num_workers=workers, pin_memory=True)
95 |     return test_loader
96 |
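# Usage sketch (hypothetical -- batch sizes, worker count, and paths below are
# illustrative assumptions, not project defaults):
#
#   from dataset import get_loaders
#   train_loader, val_loader = get_loaders('./data', val_batch_size=64,
#                                          train_batch_size=64, input_size=224, workers=2)
#   images, labels = next(iter(train_loader))   # images: FloatTensor of shape (64, 3, 224, 224)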
--------------------------------------------------------------------------------
/flops_benchmark.py:
--------------------------------------------------------------------------------
1 | #### https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py
2 | import torch
3 |
4 |
5 | # ---- Public functions
6 |
7 | def add_flops_counting_methods(net_main_module):
8 |     """Adds flops counting functions to an existing model. After that
9 |     the flops count should be activated and the model should be run on an input
10 |     image.
11 |     Example:
12 |     fcn = add_flops_counting_methods(fcn)
13 |     fcn = fcn.cuda().train()
14 |     fcn.start_flops_count()
15 |     _ = fcn(batch)
16 |     fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch
17 |     Important: dividing by 2 only works for resnet models -- see below for the details
18 |     of flops computation.
19 |     Attention: we are counting multiply-add as two flops in this work, because in
20 |     most resnet models convolutions are bias-free (BN layers act as bias there)
21 |     and it makes sense to count multiply and add as separate flops therefore.
22 |     This is why in the above example we divide by 2 in order to be consistent with
23 |     most modern benchmarks. For example in "Spatially Adaptive Computation Time for Residual
24 |     Networks" by Figurnov et al multiply-add was counted as two flops.
25 |     This module computes the average flops, which is necessary for dynamic networks that
26 |     have a different number of executed layers per input. For static networks it is enough to run the network
27 |     once and get statistics (above example).
28 |     Implementation:
29 |     The module works by adding batch_count to the main module which tracks the sum
30 |     of all batch sizes that were run through the network.
31 |     Also each convolutional layer of the network tracks the overall number of flops
32 |     performed.
33 |     The parameters are updated with the help of registered hook-functions which
34 |     are being called each time the respective layer is executed.
35 |     Parameters
36 |     ----------
37 |     net_main_module : torch.nn.Module
38 |         Main module containing network
39 |     Returns
40 |     -------
41 |     net_main_module : torch.nn.Module
42 |         Updated main module with new methods/attributes that are used
43 |         to compute flops.
44 |     """
45 |
46 |     # adding additional methods to the existing module object,
47 |     # this is done this way so that each function has access to self object
48 |     net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
49 |     net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
50 |     net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
51 |     net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)
52 |
53 |     net_main_module.reset_flops_count()
54 |
55 |     # Adding variables necessary for masked flops computation
56 |     net_main_module.apply(add_flops_mask_variable_or_reset)
57 |
58 |     return net_main_module
59 |
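# Quick usage sketch (hypothetical; assumes the repo's resnet.py is importable
# from the same directory):
#
#   import torch
#   from resnet import resnet34
#
#   net = add_flops_counting_methods(resnet34(num_classes=13))
#   net.train()
#   net.start_flops_count()
#   _ = net(torch.randn(1, 3, 224, 224))
#   print(net.compute_average_flops_cost() / 1e9 / 2, 'GFLOPs per image')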
60 |
61 | def compute_average_flops_cost(self):
62 |     """
63 |     A method that will be available after add_flops_counting_methods() is called
64 |     on a desired net object.
65 |     Returns current mean flops consumption per image.
66 |     """
67 |
68 |     batches_count = self.__batch_counter__
69 |
70 |     flops_sum = 0
71 |
72 |     for module in self.modules():
73 |
74 |         if isinstance(module, torch.nn.Conv2d):
75 |             flops_sum += module.__flops__
76 |
77 |     return flops_sum / batches_count
78 |
79 |
80 | def start_flops_count(self):
81 |     """
82 |     A method that will be available after add_flops_counting_methods() is called
83 |     on a desired net object.
84 |     Activates the computation of mean flops consumption per image.
85 |     Call it before you run the network.
86 |     """
87 |
88 |     add_batch_counter_hook_function(self)
89 |
90 |     self.apply(add_flops_counter_hook_function)
91 |
92 |
93 | def stop_flops_count(self):
94 |     """
95 |     A method that will be available after add_flops_counting_methods() is called
96 |     on a desired net object.
97 |     Stops computing the mean flops consumption per image.
98 |     Call whenever you want to pause the computation.
99 |     """
100 |
101 |     remove_batch_counter_hook_function(self)
102 |
103 |     self.apply(remove_flops_counter_hook_function)
104 |
105 |
106 | def reset_flops_count(self):
107 |     """
108 |     A method that will be available after add_flops_counting_methods() is called
109 |     on a desired net object.
110 |     Resets statistics computed so far.
111 | """ 112 | 113 | add_batch_counter_variables_or_reset(self) 114 | 115 | self.apply(add_flops_counter_variable_or_reset) 116 | 117 | 118 | def add_flops_mask(module, mask): 119 | def add_flops_mask_func(module): 120 | if isinstance(module, torch.nn.Conv2d): 121 | module.__mask__ = mask 122 | 123 | module.apply(add_flops_mask_func) 124 | 125 | 126 | def remove_flops_mask(module): 127 | module.apply(add_flops_mask_variable_or_reset) 128 | 129 | 130 | # ---- Internal functions 131 | 132 | 133 | def conv_flops_counter_hook(conv_module, input, output): 134 | # Can have multiple inputs, getting the first one 135 | input = input[0] 136 | 137 | batch_size = input.shape[0] 138 | output_height, output_width = output.shape[2:] 139 | 140 | kernel_height, kernel_width = conv_module.kernel_size 141 | in_channels = conv_module.in_channels 142 | out_channels = conv_module.out_channels 143 | groups = conv_module.groups 144 | 145 | # We count multiply-add as 2 flops 146 | conv_per_position_flops = 2 * kernel_height * kernel_width * in_channels * out_channels / groups 147 | 148 | active_elements_count = batch_size * output_height * output_width 149 | 150 | if conv_module.__mask__ is not None: 151 | # (b, 1, h, w) 152 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 153 | active_elements_count = flops_mask.sum() 154 | 155 | overall_conv_flops = conv_per_position_flops * active_elements_count 156 | 157 | bias_flops = 0 158 | 159 | if conv_module.bias is not None: 160 | bias_flops = out_channels * active_elements_count 161 | 162 | overall_flops = overall_conv_flops + bias_flops 163 | 164 | conv_module.__flops__ += overall_flops 165 | 166 | 167 | def batch_counter_hook(module, input, output): 168 | # Can have multiple inputs, getting the first one 169 | input = input[0] 170 | 171 | batch_size = input.shape[0] 172 | 173 | module.__batch_counter__ += batch_size 174 | 175 | 176 | def add_batch_counter_variables_or_reset(module): 177 | module.__batch_counter__ = 0 178 | 179 | 180 | def add_batch_counter_hook_function(module): 181 | if hasattr(module, '__batch_counter_handle__'): 182 | return 183 | 184 | handle = module.register_forward_hook(batch_counter_hook) 185 | module.__batch_counter_handle__ = handle 186 | 187 | 188 | def remove_batch_counter_hook_function(module): 189 | if hasattr(module, '__batch_counter_handle__'): 190 | module.__batch_counter_handle__.remove() 191 | 192 | del module.__batch_counter_handle__ 193 | 194 | 195 | def add_flops_counter_variable_or_reset(module): 196 | if isinstance(module, torch.nn.Conv2d): 197 | module.__flops__ = 0 198 | 199 | 200 | def add_flops_counter_hook_function(module): 201 | if isinstance(module, torch.nn.Conv2d): 202 | 203 | if hasattr(module, '__flops_handle__'): 204 | return 205 | 206 | handle = module.register_forward_hook(conv_flops_counter_hook) 207 | module.__flops_handle__ = handle 208 | 209 | 210 | def remove_flops_counter_hook_function(module): 211 | if isinstance(module, torch.nn.Conv2d): 212 | 213 | if hasattr(module, '__flops_handle__'): 214 | module.__flops_handle__.remove() 215 | 216 | del module.__flops_handle__ 217 | 218 | 219 | # --- Masked flops counting 220 | 221 | 222 | # Also being run in the initialization 223 | def add_flops_mask_variable_or_reset(module): 224 | if isinstance(module, torch.nn.Conv2d): 225 | module.__mask__ = None 226 | 227 | 228 | def count_flops(model, batch_size, device, dtype, input_size, in_channels, *params): 229 | net = model(*params, input_size=input_size) 230 | # print(net) 231 | 
net = add_flops_counting_methods(net) 232 | 233 | net.to(device=device, dtype=dtype) 234 | net = net.train() 235 | 236 | batch = torch.randn(batch_size, in_channels, input_size, input_size).to(device=device, dtype=dtype) 237 | net.start_flops_count() 238 | 239 | _ = net(batch) 240 | return net.compute_average_flops_cost() / 2 # Result in FLOPs -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os.path 3 | 4 | import matplotlib 5 | import numpy as np 6 | from matplotlib import pyplot as plt 7 | 8 | matplotlib.use('Agg') 9 | plt.switch_backend('agg') 10 | 11 | 12 | class CsvLogger: 13 | def __init__(self, filepath='./', filename='results.csv', data=None): 14 | self.log_path = filepath 15 | self.log_name = filename 16 | self.csv_path = os.path.join(self.log_path, self.log_name) 17 | self.fieldsnames = ['epoch', 'val_error1', 'val_error5', 'val_loss', 'train_error1', 'train_error5', 18 | 'train_loss'] 19 | 20 | with open(self.csv_path, 'w') as f: 21 | writer = csv.DictWriter(f, fieldnames=self.fieldsnames) 22 | writer.writeheader() 23 | 24 | self.data = {} 25 | for field in self.fieldsnames: 26 | self.data[field] = [] 27 | if data is not None: 28 | for d in data: 29 | d_num = {} 30 | for key in d: 31 | d_num[key] = float(d[key]) if key != 'epoch' else int(d[key]) 32 | self.write(d_num) 33 | 34 | def write(self, data): 35 | for k in self.data: 36 | self.data[k].append(data[k]) 37 | with open(self.csv_path, 'a') as f: 38 | writer = csv.DictWriter(f, fieldnames=self.fieldsnames) 39 | writer.writerow(data) 40 | 41 | def save_params(self, args, params): 42 | with open(os.path.join(self.log_path, 'params.txt'), 'w') as f: 43 | f.write('{}\n'.format(' '.join(args))) 44 | f.write('{}\n'.format(params)) 45 | 46 | def write_text(self, text, print_t=True): 47 | with open(os.path.join(self.log_path, 'params.txt'), 'a') as f: 48 | f.write('{}\n'.format(text)) 49 | if print_t: 50 | print(text) 51 | 52 | def plot_progress_errk(self, claimed_acc=None, title='Net', k=1): 53 | tr_str = 'train_error{}'.format(k) 54 | val_str = 'val_error{}'.format(k) 55 | plt.figure(figsize=(9, 8), dpi=300) 56 | plt.plot(self.data[tr_str], label='Training error') 57 | plt.plot(self.data[val_str], label='Validation error') 58 | if claimed_acc is not None: 59 | plt.plot((0, len(self.data[tr_str])), (1 - claimed_acc, 1 - claimed_acc), 'k--', 60 | label='Claimed validation error ({:.2f}%)'.format(100. * (1 - claimed_acc))) 61 | plt.plot((0, len(self.data[tr_str])), 62 | (np.min(self.data[val_str]), np.min(self.data[val_str])), 'r--', 63 | label='Best validation error ({:.2f}%)'.format(100. 
* np.min(self.data[val_str]))) 64 | plt.title('Top-{} error for {}'.format(k, title)) 65 | plt.xlabel('Epoch') 66 | plt.ylabel('Error') 67 | plt.legend() 68 | plt.xlim(0, len(self.data[tr_str]) + 1) 69 | plt.savefig(os.path.join(self.log_path, 'top{}.png'.format(k))) 70 | 71 | def plot_progress_loss(self, title='Net'): 72 | plt.figure(figsize=(9, 8), dpi=300) 73 | plt.plot(self.data['train_loss'], label='Training') 74 | plt.plot(self.data['val_loss'], label='Validation') 75 | plt.title(title) 76 | plt.xlabel('Epoch') 77 | plt.ylabel('Loss') 78 | plt.legend() 79 | plt.xlim(0, len(self.data['train_loss']) + 1) 80 | plt.savefig(os.path.join(self.log_path, 'loss.png')) 81 | 82 | def plot_progress(self, claimed_acc1=None, claimed_acc5=None, title='Net'): 83 | self.plot_progress_errk(claimed_acc1, title, 1) 84 | self.plot_progress_errk(claimed_acc5, title, 5) 85 | self.plot_progress_loss(title) 86 | plt.close('all') 87 | -------------------------------------------------------------------------------- /mobile_net.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | 9 | def _make_divisible(v, divisor, min_value=None): 10 | """ 11 | This function is taken from the original tf repo. 12 | It ensures that all layers have a channel number that is divisible by 8 13 | It can be seen here: 14 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 15 | :param v: 16 | :param divisor: 17 | :param min_value: 18 | :return: 19 | """ 20 | if min_value is None: 21 | min_value = divisor 22 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 23 | # Make sure that round down does not go down by more than 10%. 24 | if new_v < 0.9 * v: 25 | new_v += divisor 26 | return new_v 27 | 28 | 29 | class LinearBottleneck(nn.Module): 30 | def __init__(self, inplanes, outplanes, stride=1, t=6, activation=nn.ReLU6): 31 | super(LinearBottleneck, self).__init__() 32 | self.conv1 = nn.Conv2d(inplanes, inplanes * t, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(inplanes * t) 34 | self.conv2 = nn.Conv2d(inplanes * t, inplanes * t, kernel_size=3, stride=stride, padding=1, bias=False, 35 | groups=inplanes * t) 36 | self.bn2 = nn.BatchNorm2d(inplanes * t) 37 | self.conv3 = nn.Conv2d(inplanes * t, outplanes, kernel_size=1, bias=False) 38 | self.bn3 = nn.BatchNorm2d(outplanes) 39 | self.activation = activation(inplace=True) 40 | self.stride = stride 41 | self.t = t 42 | self.inplanes = inplanes 43 | self.outplanes = outplanes 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.activation(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | out = self.activation(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.stride == 1 and self.inplanes == self.outplanes: 60 | out += residual 61 | 62 | return out 63 | 64 | 65 | class MobileNet2(nn.Module): 66 | """MobileNet2 implementation. 67 | """ 68 | 69 | def __init__(self, scale=1.0, input_size=224, t=6, in_channels=3, num_classes=1000, activation=nn.ReLU6): 70 | """ 71 | MobileNet2 constructor. 72 | :param in_channels: (int, optional): number of channels in the input tensor. 73 | Default is 3 for RGB image inputs. 74 | :param input_size: 75 | :param num_classes: number of classes to predict. Default 76 | is 1000 for ImageNet. 
77 | :param scale: 78 | :param t: 79 | :param activation: 80 | """ 81 | 82 | super(MobileNet2, self).__init__() 83 | 84 | self.scale = scale 85 | self.t = t 86 | self.activation_type = activation 87 | self.activation = activation(inplace=True) 88 | self.num_classes = num_classes 89 | 90 | self.num_of_channels = [32, 16, 24, 32, 64, 96, 160, 320] 91 | # assert (input_size % 32 == 0) 92 | 93 | self.c = [_make_divisible(ch * self.scale, 8) for ch in self.num_of_channels] 94 | self.n = [1, 1, 2, 3, 4, 3, 3, 1] 95 | self.s = [2, 1, 2, 2, 2, 1, 2, 1] 96 | self.conv1 = nn.Conv2d(in_channels, self.c[0], kernel_size=3, bias=False, stride=self.s[0], padding=1) 97 | self.bn1 = nn.BatchNorm2d(self.c[0]) 98 | self.bottlenecks = self._make_bottlenecks() 99 | 100 | # Last convolution has 1280 output channels for scale <= 1 101 | self.last_conv_out_ch = 1280 if self.scale <= 1 else _make_divisible(1280 * self.scale, 8) 102 | self.conv_last = nn.Conv2d(self.c[-1], self.last_conv_out_ch, kernel_size=1, bias=False) 103 | self.bn_last = nn.BatchNorm2d(self.last_conv_out_ch) 104 | self.avgpool = nn.AdaptiveAvgPool2d(1) 105 | self.dropout = nn.Dropout(p=0.2, inplace=True) # confirmed by paper authors 106 | self.fc = nn.Linear(self.last_conv_out_ch, self.num_classes) 107 | self.init_params() 108 | 109 | def init_params(self): 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | init.kaiming_normal_(m.weight, mode='fan_out') 113 | if m.bias is not None: 114 | init.constant_(m.bias, 0) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | init.constant_(m.weight, 1) 117 | init.constant_(m.bias, 0) 118 | elif isinstance(m, nn.Linear): 119 | init.normal_(m.weight, std=0.001) 120 | if m.bias is not None: 121 | init.constant_(m.bias, 0) 122 | 123 | def _make_stage(self, inplanes, outplanes, n, stride, t, stage): 124 | modules = OrderedDict() 125 | stage_name = "LinearBottleneck{}".format(stage) 126 | 127 | # First module is the only one utilizing stride 128 | first_module = LinearBottleneck(inplanes=inplanes, outplanes=outplanes, stride=stride, t=t, 129 | activation=self.activation_type) 130 | modules[stage_name + "_0"] = first_module 131 | 132 | # add more LinearBottleneck depending on number of repeats 133 | for i in range(n - 1): 134 | name = stage_name + "_{}".format(i + 1) 135 | module = LinearBottleneck(inplanes=outplanes, outplanes=outplanes, stride=1, t=6, 136 | activation=self.activation_type) 137 | modules[name] = module 138 | 139 | return nn.Sequential(modules) 140 | 141 | def _make_bottlenecks(self): 142 | modules = OrderedDict() 143 | stage_name = "Bottlenecks" 144 | 145 | # First module is the only one with t=1 146 | bottleneck1 = self._make_stage(inplanes=self.c[0], outplanes=self.c[1], n=self.n[1], stride=self.s[1], t=1, 147 | stage=0) 148 | modules[stage_name + "_0"] = bottleneck1 149 | 150 | # add more LinearBottleneck depending on number of repeats 151 | for i in range(1, len(self.c) - 1): 152 | name = stage_name + "_{}".format(i) 153 | module = self._make_stage(inplanes=self.c[i], outplanes=self.c[i + 1], n=self.n[i + 1], 154 | stride=self.s[i + 1], 155 | t=self.t, stage=i) 156 | modules[name] = module 157 | 158 | return nn.Sequential(modules) 159 | 160 | def forward(self, x): 161 | x = self.conv1(x) 162 | x = self.bn1(x) 163 | x = self.activation(x) 164 | 165 | x = self.bottlenecks(x) 166 | x = self.conv_last(x) 167 | x = self.bn_last(x) 168 | x = self.activation(x) 169 | 170 | # average pooling layer 171 | x = self.avgpool(x) 172 | x = self.dropout(x) 173 | 174 | # flatten for input 
to fully-connected layer 175 | x = x.view(x.size(0), -1) 176 | x = self.fc(x) 177 | return F.log_softmax(x, dim=1) #TODO not needed(?) 178 | 179 | 180 | if __name__ == "__main__": 181 | """Testing 182 | """ 183 | model1 = MobileNet2() 184 | print(model1) 185 | model2 = MobileNet2(scale=0.35) 186 | print(model2) 187 | model3 = MobileNet2(in_channels=2, num_classes=10) 188 | print(model3) 189 | x = torch.randn(1, 2, 224, 224) 190 | print(model3(x)) 191 | model4_size = 32 * 10 192 | model4 = MobileNet2(input_size=model4_size, num_classes=10) 193 | print(model4) 194 | x2 = torch.randn(1, 3, model4_size, model4_size) 195 | print(model4(x2)) 196 | model5 = MobileNet2(input_size=196, num_classes=10) 197 | x3 = torch.randn(1, 3, 196, 196) 198 | print(model5(x3)) # fail -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | torch>=0.4.0 3 | torchvision>=0.1.9 4 | tqdm>=4.19.4 5 | matplotlib 6 | numpy -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import torch.nn as nn 4 | from torch.utils.model_zoo import load_url 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'http://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * 4) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = 
self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | 96 | def __init__(self, block, layers, num_classes=1000, norm_layer=None): 97 | super(ResNet, self).__init__() 98 | if norm_layer is None: 99 | norm_layer = nn.BatchNorm2d 100 | self._norm_layer = norm_layer 101 | 102 | self.inplanes = 64 103 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 104 | self.bn1 = nn.BatchNorm2d(self.inplanes) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | 108 | self.layer1 = self._make_layer(block, 64, layers[0]) 109 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 110 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 111 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 112 | 113 | self.avgpool = nn.AvgPool2d(7, stride=1) 114 | self.fc = nn.Linear(512 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 119 | m.weight.data.normal_(0, math.sqrt(2. / n)) 120 | elif isinstance(m, nn.BatchNorm2d): 121 | m.weight.data.fill_(1) 122 | m.bias.data.zero_() 123 | 124 | def _make_layer(self, block, planes, blocks, stride=1): 125 | norm_layer = self._norm_layer 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | nn.Conv2d(self.inplanes, planes * block.expansion, 130 | kernel_size=1, stride=stride, bias=False), 131 | norm_layer(planes * block.expansion), 132 | ) 133 | 134 | layers = [] 135 | layers.append(block(self.inplanes, planes, stride, downsample)) 136 | self.inplanes = planes * block.expansion 137 | for i in range(1, blocks): 138 | layers.append(block(self.inplanes, planes)) 139 | 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x): 143 | x = self.conv1(x) 144 | x = self.bn1(x) 145 | x = self.relu(x) 146 | x = self.maxpool(x) 147 | 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | 153 | x = self.avgpool(x) 154 | x = x.view(x.size(0), -1) 155 | x = self.fc(x) 156 | 157 | return x 158 | 159 | 160 | def resnet18(pretrained=False, **kwargs): 161 | """Constructs a ResNet-18 model. 162 | Args: 163 | pretrained (bool): If True, returns a model pre-trained on ImageNet 164 | """ 165 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 166 | if pretrained: 167 | model.load_state_dict(load_url(model_urls['resnet18'])) 168 | return model 169 | 170 | 171 | def resnet34(pretrained=False, modelpath='./models', **kwargs): 172 | """Constructs a ResNet-34 model. 
173 |     Args:
174 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
175 |     """
176 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
177 |     if pretrained:
178 |         model.load_state_dict(load_url(model_urls['resnet34'], model_dir=modelpath))
179 |     return model
180 |
181 |
182 | def resnet50(pretrained=False, modelpath='./models', **kwargs):
183 |     """Constructs a ResNet-50 model from
184 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
185 |     Args:
186 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
187 |     """
188 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)  # ResNet-50 uses Bottleneck blocks
189 |     if pretrained:
190 |         model.load_state_dict(load_url(model_urls['resnet50'], model_dir=modelpath))
191 |     return model
192 |
193 |
194 | def resnet101(pretrained=False, modelpath='./models', **kwargs):
195 |     """Constructs a ResNet-101 model.
196 |     Args:
197 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
198 |     """
199 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
200 |     if pretrained:
201 |         model.load_state_dict(load_url(model_urls['resnet101'], model_dir=modelpath))
202 |     return model
203 |
204 |
205 | def resnet152(pretrained=False, modelpath='./models', **kwargs):
206 |     """Constructs a ResNet-152 model.
207 |     Args:
208 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
209 |     """
210 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
211 |     if pretrained:
212 |         model.load_state_dict(load_url(model_urls['resnet152'], model_dir=modelpath))
213 |     return model
214 |
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | # Stores training results
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import matplotlib
5 | import numpy as np
6 | import torch
7 | import torch.nn.parallel
8 | import torch.optim
9 | import torch.utils.data
10 | from matplotlib import pyplot as plt
11 | from tqdm import tqdm, trange
12 |
13 | from clr import CyclicLR
14 |
15 | matplotlib.use('Agg')
16 |
17 |
18 | def train(model, loader, epoch, optimizer, criterion, device, dtype, batch_size, log_interval, scheduler):
19 |     model.train()
20 |     correct1, correct5 = 0, 0
21 |
22 |     for batch_idx, (data, target) in enumerate(tqdm(loader)):
23 |         if isinstance(scheduler, CyclicLR):
24 |             scheduler.batch_step()
25 |         data, target = data.to(device=device, dtype=dtype), target.to(device=device)
26 |
27 |         optimizer.zero_grad()
28 |         output = model(data)
29 |
30 |         loss = criterion(output, target)
31 |         loss.backward()
32 |         optimizer.step()
33 |
34 |         corr = correct(output, target, topk=(1, 5))
35 |         correct1 += corr[0]
36 |         correct5 += corr[1]
37 |
38 |         if batch_idx % log_interval == 0:
39 |             tqdm.write(
40 |                 '\nTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}. '
41 |                 'Top-1 accuracy: {:.2f}%({:.2f}%). '
42 |                 'Top-5 accuracy: {:.2f}%({:.2f}%).'.format(epoch, batch_idx, len(loader),
43 |                                                            100. * batch_idx / len(loader), loss.item(),
44 |                                                            100. * corr[0] / batch_size,
45 |                                                            100. * correct1 / (batch_size * (batch_idx + 1)),
46 |                                                            100. * corr[1] / batch_size,
47 |                                                            100.
* correct5 / (batch_size * (batch_idx + 1)))) 48 | return loss.item(), correct1 / len(loader.dataset), correct5 / len(loader.dataset) 49 | 50 | 51 | def val(model, loader, criterion, device, dtype): 52 | model.eval() 53 | test_loss = 0 54 | correct1, correct5 = 0, 0 55 | 56 | for batch_idx, (data, target) in enumerate(tqdm(loader)): 57 | data, target = data.to(device=device, dtype=dtype), target.to(device=device) 58 | with torch.no_grad(): 59 | output = model(data) 60 | test_loss += criterion(output, target).item() # sum up batch loss 61 | corr = correct(output, target, topk=(1, 5)) 62 | correct1 += corr[0] 63 | correct5 += corr[1] 64 | 65 | test_loss /= len(loader) 66 | 67 | tqdm.write( 68 | '\nTest set: Average loss: {:.4f}, Top1: {}/{} ({:.2f}%), ' 69 | 'Top5: {}/{} ({:.2f}%)'.format(test_loss, int(correct1), len(loader.dataset), 70 | 100. * correct1 / len(loader.dataset), int(correct5), 71 | len(loader.dataset), 100. * correct5 / len(loader.dataset))) 72 | return test_loss, correct1 / len(loader.dataset), correct5 / len(loader.dataset) 73 | 74 | 75 | def test(model, loader, criterion, device, dtype, classes): 76 | model.eval() 77 | test_loss = 0 78 | correct1, correct5 = 0, 0 79 | 80 | for batch_idx, (data, target) in enumerate(loader): 81 | data, target = data.to(device=device, dtype=dtype), target.to(device=device) 82 | with torch.no_grad(): 83 | output = model(data) 84 | 85 | _, prediction = torch.max(output.data, dim=1) 86 | print('Predicted: ', '; '.join('%5s' % classes[prediction[j]] for j in range(len(prediction)))) 87 | 88 | test_loss += criterion(output, target).item() # sum up batch loss 89 | corr = correct(output, target, topk=(1, 5)) 90 | correct1 += corr[0] 91 | correct5 += corr[1] 92 | 93 | test_loss /= len(loader) 94 | return test_loss, correct1 / len(loader.dataset), correct5 / len(loader.dataset) 95 | 96 | 97 | def correct(output, target, topk=(1,)): 98 | """Computes the correct@k for the specified values of k""" 99 | maxk = max(topk) 100 | 101 | _, pred = output.topk(maxk, 1, True, True) 102 | pred = pred.t().type_as(target) 103 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 104 | 105 | res = [] 106 | for k in topk: 107 | correct_k = correct[:k].view(-1).float().sum(0).item() 108 | res.append(correct_k) 109 | return res 110 | 111 | 112 | def save_checkpoint(state, is_best, filepath='./', filename='checkpoint.pth.tar'): 113 | save_path = os.path.join(filepath, filename) 114 | best_path = os.path.join(filepath, 'model_best.pth.tar') 115 | torch.save(state, save_path) 116 | if is_best: 117 | shutil.copyfile(save_path, best_path) 118 | 119 | 120 | def find_bounds_clr(model, loader, optimizer, criterion, device, dtype, min_lr=8e-6, max_lr=8e-5, step_size=2000, 121 | mode='triangular', save_path='.'): 122 | model.train() 123 | correct1, correct5 = 0, 0 124 | scheduler = CyclicLR(optimizer, base_lr=min_lr, max_lr=max_lr, step_size=step_size, mode=mode) 125 | epoch_count = step_size // len(loader) # Assuming step_size is multiple of batch per epoch 126 | accuracy = [] 127 | for _ in trange(epoch_count): 128 | for batch_idx, (data, target) in enumerate(tqdm(loader)): 129 | if scheduler is not None: 130 | scheduler.batch_step() 131 | data, target = data.to(device=device, dtype=dtype), target.to(device=device) 132 | 133 | optimizer.zero_grad() 134 | output = model(data) 135 | 136 | loss = criterion(output, target) 137 | loss.backward() 138 | optimizer.step() 139 | 140 | corr = correct(output, target) 141 | accuracy.append(corr[0] / data.shape[0]) 142 | 143 | lrs = 
np.linspace(min_lr, max_lr, step_size) 144 | plt.plot(lrs, accuracy) 145 | plt.show() 146 | plt.savefig(os.path.join(save_path, 'find_bounds_clr.png')) 147 | np.save(os.path.join(save_path, 'acc.npy'), accuracy) 148 | return 149 | -------------------------------------------------------------------------------- /start.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import random 5 | import sys 6 | from datetime import datetime 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn.parallel 11 | import torch.optim 12 | import torch.utils.data 13 | from torch.optim.lr_scheduler import MultiStepLR 14 | from tqdm import trange 15 | 16 | import flops_benchmark 17 | from clr import CyclicLR 18 | from dataset import get_loaders, get_test_loaders 19 | from logger import CsvLogger 20 | from mobile_net import MobileNet2 21 | from resnet import resnet34 22 | from run import train, val, test, save_checkpoint, find_bounds_clr 23 | 24 | parser = argparse.ArgumentParser(description='Pytorch training') 25 | parser.add_argument('--data_root', required=True, metavar='PATH', help='Path to ImageNet train and val folders') 26 | parser.add_argument('--gpus', default=None, help='List of GPUs used for training e.g 0,1,3') 27 | parser.add_argument('-w', '--workers', default=4, type=int, metavar='N', 28 | help='Number of data loading workers (default: 4)') 29 | parser.add_argument('--type', default='float32', help='Type of tensor: float32, float16, float64. Default: float32') 30 | parser.add_argument('--num_class', default=1000, type=int, help='Number of data categories (default: 1000)') 31 | 32 | # Optimization options 33 | parser.add_argument('--epochs', type=int, default=400, help='Number of epochs to train.') 34 | parser.add_argument('-b', '--batch_size', default=64, type=int, metavar='N', help='mini-batch size (default: 64)') 35 | parser.add_argument('--learning_rate', '-lr', type=float, default=0.001, help='The learning rate (default: 0.001)') 36 | parser.add_argument('--momentum', '-m', type=float, default=0.9, help='Momentum (default: 0.9)') 37 | parser.add_argument('--decay', '-d', type=float, default=4e-5, help='Weight decay (L2 penalty).') 38 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma at scheduled epochs.') 39 | parser.add_argument('--schedule', type=int, nargs='+', default=[200, 300], 40 | help='Decrease learning rate at these epochs.') 41 | 42 | # CLR 43 | parser.add_argument('--clr', dest='clr', action='store_true', help='Use CLR') 44 | parser.add_argument('--min_lr', type=float, default=1e-5, help='Minimal LR for CLR.') 45 | parser.add_argument('--max_lr', type=float, default=1, help='Maximal LR for CLR.') 46 | parser.add_argument('--epochs_per_step', type=int, default=20, 47 | help='Number of epochs per step in CLR, recommended to be between 2 and 10.') 48 | parser.add_argument('--mode', default='triangular2', help='CLR mode. 
One of {triangular, triangular2, exp_range}')
49 | parser.add_argument('--find_clr', dest='find_clr', action='store_true',
50 |                     help='Run search for optimal LR in range (min_lr, max_lr)')
51 |
52 | # Checkpoints
53 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='Just evaluate model')
54 | parser.add_argument('-s', '--save', type=str, default='', help='Folder to save checkpoints.')
55 | parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results', help='Directory to store results')
56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
57 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
58 | parser.add_argument('--log_interval', type=int, default=100, metavar='N', help='Number of batches between log messages')
59 | parser.add_argument('--seed', type=int, default=None, metavar='S', help='random seed (default: 1)')
60 |
61 | # Architecture
62 | parser.add_argument('--scaling', type=float, default=1, metavar='SC', help='Scaling of Net (default x1).')
63 | parser.add_argument('--input_size', type=int, default=224, metavar='I',
64 |                     help='Input size of Net, multiple of 32 (default 224).')
65 |
66 | # MobileNetV2 official (claimed) accuracies
67 | # https://github.com/keras-team/keras/blob/fe066966b5afa96f2f6b9f71ec0c71158b44068d/keras/applications/mobilenetv2.py#L30
68 | claimed_acc_top1 = {224: {1.4: 0.75, 1.3: 0.744, 1.0: 0.718, 0.75: 0.698, 0.5: 0.654, 0.35: 0.603},
69 |                     192: {1.0: 0.707, 0.75: 0.687, 0.5: 0.639, 0.35: 0.582},
70 |                     160: {1.0: 0.688, 0.75: 0.664, 0.5: 0.610, 0.35: 0.557},
71 |                     128: {1.0: 0.653, 0.75: 0.632, 0.5: 0.577, 0.35: 0.508},
72 |                     96: {1.0: 0.603, 0.75: 0.588, 0.5: 0.512, 0.35: 0.455},
73 |                     }
74 | claimed_acc_top5 = {224: {1.4: 0.925, 1.3: 0.921, 1.0: 0.910, 0.75: 0.896, 0.5: 0.864, 0.35: 0.829},
75 |                     192: {1.0: 0.901, 0.75: 0.889, 0.5: 0.854, 0.35: 0.812},
76 |                     160: {1.0: 0.890, 0.75: 0.873, 0.5: 0.832, 0.35: 0.791},
77 |                     128: {1.0: 0.869, 0.75: 0.855, 0.5: 0.808, 0.35: 0.750},
78 |                     96: {1.0: 0.832, 0.75: 0.816, 0.5: 0.758, 0.35: 0.704},
79 |                     }
80 | # CompCars carType
81 | classes = ('None', 'MPV', 'SUV', 'sedan', 'hatchback', 'minibus', 'fastback', 'estate', 'pickup', 'hardtop convertible',
82 |            'sports', 'crossover', 'convertible')
83 |
84 |
85 | def main():
86 |     args = parser.parse_args()
87 |
88 |     if args.seed is None:
89 |         args.seed = random.randint(1, 10000)
90 |     print("Random Seed: ", args.seed)
91 |     random.seed(args.seed)
92 |     torch.manual_seed(args.seed)
93 |     if args.gpus:
94 |         torch.cuda.manual_seed_all(args.seed)
95 |
96 |     time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
97 |     if args.evaluate:
98 |         args.results_dir = './tmp'
99 |     if args.save == '':  # string equality, not identity
100 |         args.save = time_stamp
101 |     save_path = os.path.join(args.results_dir, args.save)
102 |     if not os.path.exists(save_path):
103 |         os.makedirs(save_path)
104 |
105 |     if args.gpus is not None:
106 |         args.gpus = [int(i) for i in args.gpus.split(',')]
107 |         device = 'cuda:' + str(args.gpus[0])
108 |         cudnn.benchmark = True
109 |     else:
110 |         device = 'cpu'
111 |
112 |     if args.type == 'float64':
113 |         dtype = torch.float64
114 |     elif args.type == 'float32':
115 |         dtype = torch.float32
116 |     elif args.type == 'float16':
117 |         dtype = torch.float16
118 |     else:
119 |         raise ValueError('Wrong type!')  # TODO int8
120 |
121 |     # define net model
122 |     # model = MobileNet2(input_size=args.input_size, scale=args.scaling,
num_classes=args.num_class) 123 | model = resnet34(pretrained=False, modelpath=args.data_root, num_classes=args.num_class) 124 | num_parameters = sum([l.nelement() for l in model.parameters()]) 125 | print(model) 126 | print('number of parameters: {}'.format(num_parameters)) 127 | # print('FLOPs: {}'.format( 128 | # flops_benchmark.count_flops(MobileNet2, 129 | # args.batch_size // len(args.gpus) if args.gpus is not None else args.batch_size, 130 | # device, dtype, args.input_size, 3, args.scaling))) 131 | 132 | # define loss function (criterion) and optimizer 133 | criterion = torch.nn.CrossEntropyLoss() 134 | if args.gpus is not None: 135 | model = torch.nn.DataParallel(model, args.gpus) 136 | model.to(device=device, dtype=dtype) 137 | criterion.to(device=device, dtype=dtype) 138 | 139 | optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=args.momentum, weight_decay=args.decay, 140 | nesterov=True) 141 | 142 | best_test = 0 143 | 144 | # optionally resume from a checkpoint 145 | data = None 146 | if args.resume: 147 | if os.path.isfile(args.resume): 148 | print("=> loading checkpoint '{}'".format(args.resume)) 149 | checkpoint = torch.load(args.resume, map_location=device) 150 | args.start_epoch = checkpoint['epoch'] - 1 151 | best_test = checkpoint['best_prec1'] 152 | model.load_state_dict(checkpoint['state_dict']) 153 | optimizer.load_state_dict(checkpoint['optimizer']) 154 | print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) 155 | elif os.path.isdir(args.resume): 156 | checkpoint_path = os.path.join(args.resume, 'checkpoint.pth.tar') 157 | csv_path = os.path.join(args.resume, 'results.csv') 158 | print("=> loading checkpoint csv '{}'".format(checkpoint_path)) 159 | checkpoint = torch.load(checkpoint_path, map_location=device) 160 | args.start_epoch = checkpoint['epoch'] - 1 161 | best_test = checkpoint['best_prec1'] 162 | model.load_state_dict(checkpoint['state_dict']) 163 | optimizer.load_state_dict(checkpoint['optimizer']) 164 | print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) 165 | data = [] 166 | with open(csv_path) as csvfile: 167 | reader = csv.DictReader(csvfile) 168 | for row in reader: 169 | data.append(row) 170 | else: 171 | print("=> no checkpoint found at '{}'".format(args.resume)) 172 | 173 | if args.evaluate: 174 | test_loader = get_test_loaders(args.data_root, args.batch_size, args.input_size, args.workers) 175 | loss, top1, top5 = test(model, test_loader, criterion, device, dtype,classes) 176 | print("loss:{}, top1:{}, top5:{}".format(loss, top1, top5)) 177 | # TODO 178 | return 179 | 180 | train_loader, val_loader = get_loaders(args.data_root, args.batch_size, args.batch_size, args.input_size, 181 | args.workers) 182 | if args.find_clr: 183 | find_bounds_clr(model, train_loader, optimizer, criterion, device, dtype, min_lr=args.min_lr, 184 | max_lr=args.max_lr, step_size=args.epochs_per_step * len(train_loader), mode=args.mode, 185 | save_path=save_path) 186 | return 187 | 188 | if args.clr: 189 | scheduler = CyclicLR(optimizer, base_lr=args.min_lr, max_lr=args.max_lr, 190 | step_size=args.epochs_per_step * len(train_loader), mode=args.mode) 191 | else: 192 | scheduler = MultiStepLR(optimizer, milestones=args.schedule, gamma=args.gamma) 193 | 194 | csv_logger = CsvLogger(filepath=save_path, data=data) 195 | csv_logger.save_params(sys.argv, args) 196 | 197 | claimed_acc1 = None 198 | claimed_acc5 = None 199 | if args.input_size in claimed_acc_top1: 200 | if 
args.scaling in claimed_acc_top1[args.input_size]:
201 |             claimed_acc1 = claimed_acc_top1[args.input_size][args.scaling]
202 |             claimed_acc5 = claimed_acc_top5[args.input_size][args.scaling]
203 |             csv_logger.write_text(
204 |                 'Claimed accuracies are: {:.2f}% top-1, {:.2f}% top-5'.format(claimed_acc1 * 100., claimed_acc5 * 100.))
205 |     train_network(args.start_epoch, args.epochs, scheduler, model, train_loader, val_loader, optimizer, criterion,
206 |                   device, dtype, args.batch_size, args.log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5,
207 |                   best_test)
208 |
209 |
210 | def train_network(start_epoch, epochs, scheduler, model, train_loader, val_loader, optimizer, criterion, device, dtype,
211 |                   batch_size, log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5, best_test):
212 |     for epoch in trange(start_epoch, epochs + 1):
213 |         train_loss, train_accuracy1, train_accuracy5 = train(model, train_loader, epoch, optimizer, criterion, device,
214 |                                                              dtype, batch_size, log_interval, scheduler)
215 |         test_loss, test_accuracy1, test_accuracy5 = val(model, val_loader, criterion, device, dtype)
216 |
217 |         csv_logger.write({'epoch': epoch + 1, 'val_error1': 1 - test_accuracy1, 'val_error5': 1 - test_accuracy5,
218 |                           'val_loss': test_loss, 'train_error1': 1 - train_accuracy1,
219 |                           'train_error5': 1 - train_accuracy5, 'train_loss': train_loss})
220 |         save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_prec1': best_test,
221 |                          'optimizer': optimizer.state_dict()}, test_accuracy1 > best_test, filepath=save_path)
222 |
223 |         csv_logger.plot_progress(claimed_acc1=claimed_acc1, claimed_acc5=claimed_acc5)
224 |
225 |         if test_accuracy1 > best_test:
226 |             best_test = test_accuracy1
227 |
228 |         # CyclicLR already steps per batch inside train(); only epoch-based schedulers step here
229 |         if not isinstance(scheduler, CyclicLR):
230 |             scheduler.step()
231 |
232 |     csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test * 100.))
233 |
234 |
235 | if __name__ == '__main__':
236 |     main()
237 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2020/5/17 18:03
4 | # @Author : Jelly
5 | # @File : __init__.py
6 |
--------------------------------------------------------------------------------
/util/transfor.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torchvision
5 |
6 | model_pth = os.path.join("results", "2020-04-27_10-27-17", 'checkpoint.pth.tar')
7 | # export the resnet34 model to a file Android can load
8 | mobile_pt = os.path.join("results", "2020-04-27_10-27-17", 'resnet34.pt')
9 | num_class = 13
10 | device = 'cpu'  # 'cuda:0' # cpu
11 |
12 | model = torchvision.models.resnet34(num_classes=num_class)
13 | model = torch.nn.DataParallel(model, [0])
14 | model.to(device=device)
15 |
16 | checkpoint = torch.load(model_pth, map_location=device)
17 | model.load_state_dict(checkpoint['state_dict'])
18 |
19 | model.eval()  # set the model to eval mode
20 | # a single 3-channel 224*224 image
21 | input_tensor = torch.rand(1, 3, 224, 224)  # define the input shape
22 | traced_script_module = torch.jit.trace(model.module, input_tensor)  # trace the model
23 | traced_script_module.save(mobile_pt)  # save the file
24 |
--------------------------------------------------------------------------------