├── utils ├── progress │ ├── MANIFEST.in │ ├── demo.gif │ ├── LICENSE │ ├── setup.py │ ├── progress │ │ ├── spinner.py │ │ ├── counter.py │ │ ├── bar.py │ │ ├── helpers.py │ │ └── __init__.py │ ├── test_progress.py │ └── README.rst ├── __init__.py ├── eval.py ├── visualize.py ├── utils.py ├── logger.py ├── misc.py └── radam.py ├── requirements.txt ├── models ├── __init__.py ├── resnetxt_wsl.py ├── Res.py └── resnet_cbam.py ├── predict ├── config.json ├── resnetxt_wsl.py ├── customize_service.py ├── predict.py ├── Res.py └── resnet_cbam.py ├── README.md ├── transform.py ├── dataset.py ├── preprocess.py ├── args.py └── train.py /utils/progress/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst LICENSE 2 | -------------------------------------------------------------------------------- /utils/progress/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QLMX/huawei-garbage/HEAD/utils/progress/demo.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.1 2 | torchvision==0.2.2 3 | matplotlib==3.1.0 4 | numpy==1.16.4 5 | scikit-image 6 | pandas 7 | scikit-learn 8 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @version: python3.6 5 | @author: ikkyu-wen 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-08-17 01:47 8 | """ 9 | from __future__ import absolute_import 10 | 11 | from .resnetxt_wsl import * -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 
2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | from .utils import * 8 | from .radam import * 9 | 10 | # progress bar 11 | import os, sys 12 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 13 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy', 'precision'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape, not view: correct derives from pred.t() and may be non-contiguous 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res 19 | 20 | def precision(output, target): 21 | pass -------------------------------------------------------------------------------- /utils/progress/LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /utils/progress/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | import progress 6 | 7 | 8 | setup( 9 | name='progress', 10 | version=progress.__version__, 11 | description='Easy to use progress bars', 12 | long_description=open('README.rst').read(), 13 | author='Giorgos Verigakis', 14 | author_email='verigak@gmail.com', 15 | url='http://github.com/verigak/progress/', 16 | license='ISC', 17 | packages=['progress'], 18 | classifiers=[ 19 | 'Environment :: Console', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: ISC License (ISCL)', 22 | 'Programming Language :: Python :: 2.6', 23 | 'Programming Language :: Python :: 2.7', 24 | 'Programming Language :: Python :: 3.3', 25 | 'Programming Language :: Python :: 3.4', 26 | 'Programming Language :: Python :: 3.5', 27 | 'Programming Language :: Python :: 3.6', 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /predict/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "PyTorch", 3 | "runtime":"python3.6", 4 | "model_algorithm": "image_classification", 5 | "metrics": { 6 | "f1": 0.345294, 7 | "accuracy": 0.462963, 8 | "precision": 0.338977, 9 | "recall": 0.351852 10 | }, 11 | "apis": [{ 12 | "protocol": "http", 13 | "url": "/", 14 | "method": "post", 15 | "request": { 16 | "Content-type": "multipart/form-data", 17 | "data": { 18 | "type": "object", 
19 | "properties": { 20 | "input_img": {"type": "file"} 21 | }, 22 | "required": ["input_img"] 23 | } 24 | }, 25 | "response": { 26 | "Content-type": "multipart/form-data", 27 | "data": { 28 | "type": "object", 29 | "properties": { 30 | "result": {"type": "string"} 31 | }, 32 | "required": ["result"] 33 | } 34 | } 35 | }], 36 | "dependencies": [{ 37 | "installer": "pip", 38 | "packages": [ 39 | { 40 | "restraint": "ATLEAST", 41 | "package_version": "5.2.0", 42 | "package_name": "Pillow" 43 | } 44 | ] 45 | }] 46 | } -------------------------------------------------------------------------------- /utils/progress/progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Spinner(WriteMixin, Infinite): 23 | message = '' 24 | phases = ('-', '\\', '|', '/') 25 | hide_cursor = True 26 | 27 | def update(self): 28 | i = self.index % len(self.phases) 29 | self.write(self.phases[i]) 30 | 31 | 32 | class PieSpinner(Spinner): 33 | phases = ['◷', '◶', '◵', '◴'] 34 | 35 | 36 | class MoonSpinner(Spinner): 37 | phases = ['◑', '◒', '◐', '◓'] 38 | 39 | 40 | class LineSpinner(Spinner): 41 | phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] 42 | 43 | class PixelSpinner(Spinner): 44 | phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] 45 | -------------------------------------------------------------------------------- /utils/progress/test_progress.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import random 6 | import time 7 | 8 | from progress.bar import (Bar, ChargingBar, FillingSquaresBar, 9 | FillingCirclesBar, IncrementalBar, PixelBar, 10 | ShadyBar) 11 | from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner, 12 | PixelSpinner) 13 | from progress.counter import Counter, Countdown, Stack, Pie 14 | 15 | 16 | def sleep(): 17 | t = 0.01 18 | t += t * random.uniform(-0.1, 0.1) # Add some variance 19 | time.sleep(t) 20 | 21 | 22 | for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): 23 | suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]' 24 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 25 | for i in bar.iter(range(200)): 26 | sleep() 27 | 28 | for bar_cls in (IncrementalBar, PixelBar, ShadyBar): 29 | suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]' 30 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 31 | for i in bar.iter(range(200)): 32 | sleep() 33 | 34 | for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): 35 | for i in spin(spin.__name__ + ' ').iter(range(100)): 36 | 
sleep() 37 | print() 38 | 39 | for singleton in (Counter, Countdown, Stack, Pie): 40 | for i in singleton(singleton.__name__ + ' ').iter(range(100)): 41 | sleep() 42 | print() 43 | 44 | bar = IncrementalBar('Random', suffix='%(index)d') 45 | for i in range(100): 46 | bar.goto(random.randint(0, 100)) 47 | sleep() 48 | bar.finish() 49 | -------------------------------------------------------------------------------- /utils/progress/progress/counter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite, Progress 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Counter(WriteMixin, Infinite): 23 | message = '' 24 | hide_cursor = True 25 | 26 | def update(self): 27 | self.write(str(self.index)) 28 | 29 | 30 | class Countdown(WriteMixin, Progress): 31 | hide_cursor = True 32 | 33 | def update(self): 34 | self.write(str(self.remaining)) 35 | 36 | 37 | class Stack(WriteMixin, Progress): 38 | phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') 39 | hide_cursor = True 40 | 41 | def update(self): 42 | nphases = len(self.phases) 43 | i = min(nphases - 1, int(self.progress * nphases)) 44 | self.write(self.phases[i]) 45 | 46 | 47 | class Pie(Stack): 48 | phases = ('○', '◔', '◑', '◕', '●') 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 华为云垃圾分类挑战杯亚军方案分享 2 | 3 | ### 1.代码结构 4 | 5 | ``` 6 | {repo_root} 7 | ├── models //模型文件夹 8 | ├── utils //一些函数包 9 | | ├── eval.py // 求精度 10 | │ ├── misc.py // 模型保存,参数初始化,优化函数选择 11 | │ ├── radam.py 12 | │ └── ... 13 | ├── args.py //参数配置文件 14 | ├── build_net.py //搭建模型 15 | ├── dataset.py //数据批量加载文件 16 | ├── preprocess.py //数据预处理文件,生成坐标标签 17 | ├── train.py //训练运行文件 18 | ├── transform.py //数据增强文件 19 | ``` 20 | 21 | ### 2. 环境设置 22 | 23 | 可以直接通过`pip install -r requirements.txt`安装指定的函数包,python版本为3.6,具体的函数包如下: 24 | 25 | * pytorch>=1.0.1 26 | * torchvision==0.2.2 27 | * matplotlib>=3.1.0 28 | * numpy>=1.16.4 29 | * scikit-image 30 | * pandas 31 | * sklearn 32 | 33 | 注:py3.7训练的话,要修改下面的代码 34 | `if use_cuda: inputs, targets = inputs.cuda(), targets.cuda(async=True) inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)` 35 | \#python3.7已经移除了async关键字,而用non_blocking代替。(导致apache-airflow也出了问题) 36 | \#cuda() 本身也没有async. 
37 | 38 | 就是把 async=True去掉 39 | ```python 40 | if use_cuda: 41 | inputs, targets = inputs.cuda(), targets.cuda() 42 | inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) 43 | ``` 44 | ## 3.运行步骤 45 | 46 | 1. 建立文件夹data,把garbage_classify全部解压缩到data下 47 | 2. 运行preprocess.py,生成训练集和测试集运行文件 48 | 3. 单张显卡的话,修改args.py 85行 parser.add_argument('--gpu-id', default='0, 1, 2, 3' 为'--gpu-id', default='0',同时修改 '--train-batch','--test-batch'为适当的数字 49 | 4. 运行train.py 50 | 51 | ### 4.方案思路 52 | 53 | [方案讲解](https://mp.weixin.qq.com/s/7GhXMXQkBgH_JVcKMjCejQ) 54 | 55 | 知乎专栏:[ML与DL成长之路](https://zhuanlan.zhihu.com/ai-growth) 56 | 57 | 如果复现过程中有bug,麻烦反馈一下,会优化更新。如果对您有帮助记得给个**star** 58 | 59 | --- 60 | 61 | **小尾巴** 62 | 63 | QQ群:AI成长社①:545702197 64 | 65 | 微信群:添加微信号:Derek_wen8,备注:加群 -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-09-07 18:54 8 | 公众号:AI成长社 9 | 知乎:https://www.zhihu.com/people/qlmx-61/columns 10 | """ 11 | import random 12 | import math 13 | import torch 14 | 15 | from PIL import Image, ImageOps, ImageFilter 16 | from torchvision import transforms 17 | 18 | class Resize(object): 19 | def __init__(self, size, interpolation=Image.BILINEAR): 20 | self.size = size 21 | self.interpolation = interpolation 22 | 23 | def __call__(self, img): 24 | # padding 25 | ratio = self.size[0] / self.size[1] 26 | w, h = img.size 27 | if w / h < ratio: 28 | t = int(h * ratio) 29 | w_padding = (t - w) // 2 30 | img = img.crop((-w_padding, 0, w+w_padding, h)) 31 | else: 32 | t = int(w / ratio) 33 | h_padding = (t - h) // 2 34 | img = img.crop((0, -h_padding, w, h+h_padding)) 35 | 36 | img = img.resize(self.size, self.interpolation) 37 | 38 | return img 39 | 40 | class RandomRotate(object): 41 | def __init__(self,
degree, p=0.5): 42 | self.degree = degree 43 | self.p = p 44 | 45 | def __call__(self, img): 46 | if random.random() < self.p: 47 | rotate_degree = random.uniform(-1*self.degree, self.degree) 48 | img = img.rotate(rotate_degree, Image.BILINEAR) 49 | return img 50 | 51 | class RandomGaussianBlur(object): 52 | def __init__(self, p=0.5): 53 | self.p = p 54 | def __call__(self, img): 55 | if random.random() < self.p: 56 | img = img.filter(ImageFilter.GaussianBlur( 57 | radius=random.random())) 58 | return img 59 | 60 | def get_train_transform(mean, std, size): 61 | train_transform = transforms.Compose([ 62 | Resize((int(size * (256 / 224)), int(size * (256 / 224)))), 63 | transforms.RandomCrop(size), 64 | 65 | transforms.RandomHorizontalFlip(), 66 | # RandomRotate(15, 0.3), 67 | # RandomGaussianBlur(), 68 | transforms.ToTensor(), 69 | transforms.Normalize(mean=mean, std=std), 70 | ]) 71 | return train_transform 72 | 73 | def get_test_transform(mean, std, size): 74 | return transforms.Compose([ 75 | Resize((int(size * (256 / 224)), int(size * (256 / 224)))), 76 | transforms.CenterCrop(size), 77 | transforms.ToTensor(), 78 | transforms.Normalize(mean=mean, std=std), 79 | ]) 80 | 81 | def get_transforms(input_size=224, test_size=224, backbone=None): 82 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 83 | if backbone is not None and backbone in ['pnasnet5large', 'nasnetamobile']: 84 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 85 | transformations = {} 86 | transformations['val_train'] = get_train_transform(mean, std, input_size) 87 | transformations['val_test'] = get_test_transform(mean, std, test_size) 88 | return transformations 89 | 90 | -------------------------------------------------------------------------------- /utils/progress/progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and 
distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Progress 19 | from .helpers import WritelnMixin 20 | 21 | 22 | class Bar(WritelnMixin, Progress): 23 | width = 32 24 | message = '' 25 | suffix = '%(index)d/%(max)d' 26 | bar_prefix = ' |' 27 | bar_suffix = '| ' 28 | empty_fill = ' ' 29 | fill = '#' 30 | hide_cursor = True 31 | 32 | def update(self): 33 | filled_length = int(self.width * self.progress) 34 | empty_length = self.width - filled_length 35 | 36 | message = self.message % self 37 | bar = self.fill * filled_length 38 | empty = self.empty_fill * empty_length 39 | suffix = self.suffix % self 40 | line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, 41 | suffix]) 42 | self.writeln(line) 43 | 44 | 45 | class ChargingBar(Bar): 46 | suffix = '%(percent)d%%' 47 | bar_prefix = ' ' 48 | bar_suffix = ' ' 49 | empty_fill = '∙' 50 | fill = '█' 51 | 52 | 53 | class FillingSquaresBar(ChargingBar): 54 | empty_fill = '▢' 55 | fill = '▣' 56 | 57 | 58 | class FillingCirclesBar(ChargingBar): 59 | empty_fill = '◯' 60 | fill = '◉' 61 | 62 | 63 | class IncrementalBar(Bar): 64 | phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') 65 | 66 | def update(self): 67 | nphases = len(self.phases) 68 | filled_len = self.width * 
self.progress 69 | nfull = int(filled_len) # Number of full chars 70 | phase = int((filled_len - nfull) * nphases) # Phase of last char 71 | nempty = self.width - nfull # Number of empty chars 72 | 73 | message = self.message % self 74 | bar = self.phases[-1] * nfull 75 | current = self.phases[phase] if phase > 0 else '' 76 | empty = self.empty_fill * max(0, nempty - len(current)) 77 | suffix = self.suffix % self 78 | line = ''.join([message, self.bar_prefix, bar, current, empty, 79 | self.bar_suffix, suffix]) 80 | self.writeln(line) 81 | 82 | 83 | class PixelBar(IncrementalBar): 84 | phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') 85 | 86 | 87 | class ShadyBar(IncrementalBar): 88 | phases = (' ', '░', '▒', '▓', '█') 89 | -------------------------------------------------------------------------------- /utils/progress/progress/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
14 | 15 | from __future__ import print_function 16 | 17 | 18 | HIDE_CURSOR = '\x1b[?25l' 19 | SHOW_CURSOR = '\x1b[?25h' 20 | 21 | 22 | class WriteMixin(object): 23 | hide_cursor = False 24 | 25 | def __init__(self, message=None, **kwargs): 26 | super(WriteMixin, self).__init__(**kwargs) 27 | self._width = 0 28 | if message: 29 | self.message = message 30 | 31 | if self.file.isatty(): 32 | if self.hide_cursor: 33 | print(HIDE_CURSOR, end='', file=self.file) 34 | print(self.message, end='', file=self.file) 35 | self.file.flush() 36 | 37 | def write(self, s): 38 | if self.file.isatty(): 39 | b = '\b' * self._width 40 | c = s.ljust(self._width) 41 | print(b + c, end='', file=self.file) 42 | self._width = max(self._width, len(s)) 43 | self.file.flush() 44 | 45 | def finish(self): 46 | if self.file.isatty() and self.hide_cursor: 47 | print(SHOW_CURSOR, end='', file=self.file) 48 | 49 | 50 | class WritelnMixin(object): 51 | hide_cursor = False 52 | 53 | def __init__(self, message=None, **kwargs): 54 | super(WritelnMixin, self).__init__(**kwargs) 55 | if message: 56 | self.message = message 57 | 58 | if self.file.isatty() and self.hide_cursor: 59 | print(HIDE_CURSOR, end='', file=self.file) 60 | 61 | def clearln(self): 62 | if self.file.isatty(): 63 | print('\r\x1b[K', end='', file=self.file) 64 | 65 | def writeln(self, line): 66 | if self.file.isatty(): 67 | self.clearln() 68 | print(line, end='', file=self.file) 69 | self.file.flush() 70 | 71 | def finish(self): 72 | if self.file.isatty(): 73 | print(file=self.file) 74 | if self.hide_cursor: 75 | print(SHOW_CURSOR, end='', file=self.file) 76 | 77 | 78 | from signal import signal, SIGINT 79 | from sys import exit 80 | 81 | 82 | class SigIntMixin(object): 83 | """Registers a signal handler that calls finish on SIGINT""" 84 | 85 | def __init__(self, *args, **kwargs): 86 | super(SigIntMixin, self).__init__(*args, **kwargs) 87 | signal(SIGINT, self._sigint_handler) 88 | 89 | def _sigint_handler(self, signum, frame): 90 | 
self.finish() 91 | exit(0) 92 | -------------------------------------------------------------------------------- /utils/progress/README.rst: -------------------------------------------------------------------------------- 1 | Easy progress reporting for Python 2 | ================================== 3 | 4 | |pypi| 5 | 6 | |demo| 7 | 8 | .. |pypi| image:: https://img.shields.io/pypi/v/progress.svg 9 | .. |demo| image:: https://raw.github.com/verigak/progress/master/demo.gif 10 | :alt: Demo 11 | 12 | Bars 13 | ---- 14 | 15 | There are 7 progress bars to choose from: 16 | 17 | - ``Bar`` 18 | - ``ChargingBar`` 19 | - ``FillingSquaresBar`` 20 | - ``FillingCirclesBar`` 21 | - ``IncrementalBar`` 22 | - ``PixelBar`` 23 | - ``ShadyBar`` 24 | 25 | To use them, just call ``next`` to advance and ``finish`` to finish: 26 | 27 | .. code-block:: python 28 | 29 | from progress.bar import Bar 30 | 31 | bar = Bar('Processing', max=20) 32 | for i in range(20): 33 | # Do some work 34 | bar.next() 35 | bar.finish() 36 | 37 | The result will be a bar like the following: :: 38 | 39 | Processing |############# | 42/100 40 | 41 | To simplify the common case where the work is done in an iterator, you can 42 | use the ``iter`` method: 43 | 44 | .. code-block:: python 45 | 46 | for i in Bar('Processing').iter(it): 47 | # Do some work 48 | 49 | Progress bars are very customizable, you can change their width, their fill 50 | character, their suffix and more: 51 | 52 | .. 
code-block:: python 53 | 54 | bar = Bar('Loading', fill='@', suffix='%(percent)d%%') 55 | 56 | This will produce a bar like the following: :: 57 | 58 | Loading |@@@@@@@@@@@@@ | 42% 59 | 60 | You can use a number of template arguments in ``message`` and ``suffix``: 61 | 62 | ========== ================================ 63 | Name Value 64 | ========== ================================ 65 | index current value 66 | max maximum value 67 | remaining max - index 68 | progress index / max 69 | percent progress * 100 70 | avg simple moving average time per item (in seconds) 71 | elapsed elapsed time in seconds 72 | elapsed_td elapsed as a timedelta (useful for printing as a string) 73 | eta avg * remaining 74 | eta_td eta as a timedelta (useful for printing as a string) 75 | ========== ================================ 76 | 77 | Instead of passing all configuration options on instantiation, you can create 78 | your custom subclass: 79 | 80 | .. code-block:: python 81 | 82 | class FancyBar(Bar): 83 | message = 'Loading' 84 | fill = '*' 85 | suffix = '%(percent).1f%% - %(eta)ds' 86 | 87 | You can also override any of the arguments or create your own: 88 | 89 | .. code-block:: python 90 | 91 | class SlowBar(Bar): 92 | suffix = '%(remaining_hours)d hours remaining' 93 | @property 94 | def remaining_hours(self): 95 | return self.eta // 3600 96 | 97 | 98 | Spinners 99 | ======== 100 | 101 | For actions with an unknown number of steps you can use a spinner: 102 | 103 | ..
code-block:: python 104 | 105 | from progress.spinner import Spinner 106 | 107 | spinner = Spinner('Loading ') 108 | while state != 'FINISHED': 109 | # Do some work 110 | spinner.next() 111 | 112 | There are 5 predefined spinners: 113 | 114 | - ``Spinner`` 115 | - ``PieSpinner`` 116 | - ``MoonSpinner`` 117 | - ``LineSpinner`` 118 | - ``PixelSpinner`` 119 | 120 | 121 | Other 122 | ===== 123 | 124 | There are a number of other classes available too, please check the source or 125 | subclass one of them to create your own. 126 | 127 | 128 | License 129 | ======= 130 | 131 | progress is licensed under ISC 132 | -------------------------------------------------------------------------------- /utils/progress/progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
14 | 15 | from __future__ import division 16 | 17 | from collections import deque 18 | from datetime import timedelta 19 | from math import ceil 20 | from sys import stderr 21 | from time import time 22 | 23 | 24 | __version__ = '1.3' 25 | 26 | 27 | class Infinite(object): 28 | file = stderr 29 | sma_window = 10 # Simple Moving Average window 30 | 31 | def __init__(self, *args, **kwargs): 32 | self.index = 0 33 | self.start_ts = time() 34 | self.avg = 0 35 | self._ts = self.start_ts 36 | self._xput = deque(maxlen=self.sma_window) 37 | for key, val in kwargs.items(): 38 | setattr(self, key, val) 39 | 40 | def __getitem__(self, key): 41 | if key.startswith('_'): 42 | return None 43 | return getattr(self, key, None) 44 | 45 | @property 46 | def elapsed(self): 47 | return int(time() - self.start_ts) 48 | 49 | @property 50 | def elapsed_td(self): 51 | return timedelta(seconds=self.elapsed) 52 | 53 | def update_avg(self, n, dt): 54 | if n > 0: 55 | self._xput.append(dt / n) 56 | self.avg = sum(self._xput) / len(self._xput) 57 | 58 | def update(self): 59 | pass 60 | 61 | def start(self): 62 | pass 63 | 64 | def finish(self): 65 | pass 66 | 67 | def next(self, n=1): 68 | now = time() 69 | dt = now - self._ts 70 | self.update_avg(n, dt) 71 | self._ts = now 72 | self.index = self.index + n 73 | self.update() 74 | 75 | def iter(self, it): 76 | try: 77 | for x in it: 78 | yield x 79 | self.next() 80 | finally: 81 | self.finish() 82 | 83 | 84 | class Progress(Infinite): 85 | def __init__(self, *args, **kwargs): 86 | super(Progress, self).__init__(*args, **kwargs) 87 | self.max = kwargs.get('max', 100) 88 | 89 | @property 90 | def eta(self): 91 | return int(ceil(self.avg * self.remaining)) 92 | 93 | @property 94 | def eta_td(self): 95 | return timedelta(seconds=self.eta) 96 | 97 | @property 98 | def percent(self): 99 | return self.progress * 100 100 | 101 | @property 102 | def progress(self): 103 | return min(1, self.index / self.max) if self.max else 0  # guard max == 0 (e.g. iter() over an empty sized iterable) against ZeroDivisionError 104 | 105 | @property 106 | def
remaining(self): 107 | return max(self.max - self.index, 0) 108 | 109 | def start(self): 110 | self.update() 111 | 112 | def goto(self, index): 113 | incr = index - self.index 114 | self.next(incr) 115 | 116 | def iter(self, it): 117 | try: 118 | self.max = len(it) 119 | except TypeError: 120 | pass 121 | 122 | try: 123 | for x in it: 124 | yield x 125 | self.next() 126 | finally: 127 | self.finish() 128 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-09-07 20:27 8 | 公众号:AI成长社 9 | 知乎:https://www.zhihu.com/people/qlmx-61/columns 10 | """ 11 | import random 12 | import torch 13 | from torch.utils.data import Dataset as TorchDataset  # aliased: the class below reuses the name Dataset 14 | from torch.utils.data import sampler 15 | import torchvision.transforms as transforms 16 | import pandas as pd 17 | import six 18 | import sys 19 | from PIL import Image 20 | import numpy as np 21 | 22 | class Dataset(TorchDataset): 23 | def __init__(self, root=None, transform=None, target_transform=None, to=None): 24 | if '.txt' in root: 25 | with open(root) as f: self.env = list(f)  # close the list file once read 26 | else: 27 | self.env = root 28 | 29 | if not self.env: 30 | print('cannot creat lmdb from %s' % (root)) 31 | sys.exit(1)  # non-zero status: an empty/missing list is a failure 32 | 33 | self.len = len(self.env) - 1 34 | 35 | self.transform = transform 36 | self.target_transform = target_transform 37 | 38 | def __len__(self): 39 | return self.len 40 | 41 | def __getitem__(self, index): 42 | if not 0 <= index < len(self): raise IndexError('index range error') 43 | index += 1 44 | img_path, label = self.env[index].strip().split(',') 45 | 46 | try: 47 | img = Image.open(img_path) 48 | except OSError: 49 | print(img_path) 50 | print('Corrupted image for %d' % index) 51 | return self[index % len(self)]  # index is already +1 (header skip), so this is the next sample; wrap at the end 52 | 53 | if self.transform is not None: 54 | if img.mode != 'RGB':  # keep every sample 3-channel so it matches the 3-channel mean/std and batches stack 55 | img = img.convert('RGB') 56 | img =
self.transform(img) 57 | 58 | if self.target_transform is not None: 59 | label = self.target_transform(label) 60 | return (img, int(label)) 61 | 62 | class TestDataset(Dataset): 63 | def __init__(self, root=None, transform=None, target_transform=None, to=None): 64 | if '.txt' in root: 65 | with open(root) as f: self.env = list(f)  # close the list file once read 66 | else: 67 | self.env = root 68 | 69 | if not self.env: 70 | print('cannot creat lmdb from %s' % (root)) 71 | sys.exit(1)  # non-zero status: an empty/missing list is a failure 72 | 73 | self.len = len(self.env) - 1 74 | 75 | self.transform = transform 76 | self.target_transform = target_transform 77 | self.to = to 78 | 79 | def __len__(self): 80 | return self.len 81 | 82 | def __getitem__(self, index): 83 | if not 0 <= index < len(self): raise IndexError('index range error') 84 | index += 1 85 | img_path, label = self.env[index].strip().split(',') 86 | 87 | try: 88 | img = Image.open(img_path) 89 | except OSError: 90 | print(img_path) 91 | print('Corrupted image for %d' % index) 92 | return self[index % len(self)]  # index is already +1 (header skip), so this is the next sample; wrap at the end 93 | 94 | if self.transform is not None: 95 | if img.mode != 'RGB':  # same 3-channel normalization as the training Dataset 96 | img = img.convert('RGB') 97 | img = self.transform(img) 98 | if self.target_transform is not None: 99 | label = self.target_transform(label) 100 | 101 | return (img, int(label)) 102 | 103 | 104 | class resizeNormalize(object): 105 | 106 | def __init__(self, size, interpolation=Image.BILINEAR): 107 | self.size = size 108 | self.interpolation = interpolation 109 | self.toTensor = transforms.ToTensor() 110 | 111 | def __call__(self, img): 112 | # padding 113 | ratio = self.size[0] / self.size[1] 114 | w, h = img.size 115 | if w / h < ratio: 116 | t = int(h * ratio) 117 | w_padding = (t - w) // 2 118 | img = img.crop((-w_padding, 0, w+w_padding, h)) 119 | else: 120 | t = int(w / ratio) 121 | h_padding = (t - h) // 2 122 | img = img.crop((0, -h_padding, w, h+h_padding)) 123 | 124 | # img.show() 125 | # resize 126 | img = img.resize(self.size, self.interpolation) 127 | img = self.toTensor(img) 128 | img.sub_(0.5).div_(0.5) 129 | return img
-------------------------------------------------------------------------------- /models/resnetxt_wsl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Optional list of dependencies required by the package 8 | 9 | ''' 10 | Code From : https://github.com/facebookresearch/WSL-Images/blob/master/hubconf.py 11 | ''' 12 | __all__ = ['resnext101_32x8d_wsl', 'resnext101_32x16d_wsl', 'resnext101_32x32d_wsl', 'resnext101_32x48d_wsl'] 13 | 14 | dependencies = ['torch', 'torchvision'] 15 | 16 | try: 17 | from torch.hub import load_state_dict_from_url 18 | except ImportError: 19 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 20 | 21 | # from .Res import ResNet, Bottleneck 22 | from .resnet_cbam import ResNet, Bottleneck 23 | 24 | model_urls = { 25 | 'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth', 26 | 'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth', 27 | 'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth', 28 | 'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth', 29 | } 30 | 31 | 32 | # def _resnext(arch, block, layers, pretrained, progress, **kwargs): 33 | # model = ResNet(block, layers, **kwargs) 34 | # state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 35 | # model.load_state_dict(state_dict) 36 | # return model 37 | 38 | #使用部分加载 39 | def _resnext(arch, block, layers, pretrained, progress, **kwargs): 40 | model = ResNet(block, layers, **kwargs) 41 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 42 | new_state_dict = model.state_dict() 43 | new_state_dict.update(state_dict) 44 
| model.load_state_dict(new_state_dict) 45 | return model 46 | 47 | 48 | def resnext101_32x8d_wsl(progress=True, **kwargs): 49 | """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data 50 | and finetuned on ImageNet from Figure 5 in 51 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 52 | Args: 53 | progress (bool): If True, displays a progress bar of the download to stderr. 54 | """ 55 | kwargs['groups'] = 32 56 | kwargs['width_per_group'] = 8 57 | return _resnext('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 58 | 59 | 60 | def resnext101_32x16d_wsl(progress=True, **kwargs): 61 | """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data 62 | and finetuned on ImageNet from Figure 5 in 63 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 64 | Args:zz 65 | progress (bool): If True, displays a progress bar of the download to stderr. 66 | """ 67 | kwargs['groups'] = 32 68 | kwargs['width_per_group'] = 16 69 | return _resnext('resnext101_32x16d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 70 | 71 | 72 | def resnext101_32x32d_wsl(progress=True, **kwargs): 73 | """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data 74 | and finetuned on ImageNet from Figure 5 in 75 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 76 | Args: 77 | progress (bool): If True, displays a progress bar of the download to stderr. 78 | """ 79 | kwargs['groups'] = 32 80 | kwargs['width_per_group'] = 32 81 | return _resnext('resnext101_32x32d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 82 | 83 | 84 | def resnext101_32x48d_wsl(progress=True, **kwargs): 85 | """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data 86 | and finetuned on ImageNet from Figure 5 in 87 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 88 | Args: 89 | progress (bool): If True, displays a progress bar of the download to stderr. 
90 | """ 91 | kwargs['groups'] = 32 92 | kwargs['width_per_group'] = 48 93 | return _resnext('resnext101_32x48d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 94 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * 
Std[i] + Mean[i] # unnormalize 51 | 52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = 
make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() -------------------------------------------------------------------------------- /predict/resnetxt_wsl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Optional list of dependencies required by the package 8 | 9 | ''' 10 | Code From : https://github.com/facebookresearch/WSL-Images/blob/master/hubconf.py 11 | ''' 12 | __all__ = ['resnext101_32x8d_wsl', 'resnext101_32x16d_wsl', 'resnext101_32x32d_wsl', 'resnext101_32x48d_wsl'] 13 | 14 | dependencies = ['torch', 'torchvision'] 15 | 16 | try: 17 | from torch.hub import load_state_dict_from_url 18 | except ImportError: 19 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 20 | 21 | # from Res import ResNet, Bottleneck 22 | from .resnet_cbam import ResNet, Bottleneck 23 | 24 | model_urls = { 25 | 'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth', 26 | 'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth', 27 | 'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth', 28 | 'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth', 29 | } 30 | 31 | 32 | # def _resnext(arch, block, layers, pretrained, progress, **kwargs): 33 | # model = ResNet(block, layers, **kwargs) 34 | # if pretrained: 35 | # state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 36 | # 
model.load_state_dict(state_dict) 37 | # return model 38 | 39 | #使用部分加载 40 | def _resnext(arch, block, layers, pretrained, progress, **kwargs): 41 | model = ResNet(block, layers, **kwargs) 42 | if pretrained: 43 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 44 | new_state_dict = model.state_dict() 45 | new_state_dict.update(state_dict) 46 | model.load_state_dict(new_state_dict) 47 | return model 48 | 49 | 50 | def resnext101_32x8d_wsl(pretrained=True, progress=True, **kwargs): 51 | """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data 52 | and finetuned on ImageNet from Figure 5 in 53 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 54 | Args: 55 | progress (bool): If True, displays a progress bar of the download to stderr. 56 | """ 57 | kwargs['groups'] = 32 58 | kwargs['width_per_group'] = 8 59 | return _resnext('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 60 | 61 | 62 | def resnext101_32x16d_wsl(pretrained=True, progress=True, **kwargs): 63 | """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data 64 | and finetuned on ImageNet from Figure 5 in 65 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 66 | Args: 67 | progress (bool): If True, displays a progress bar of the download to stderr. 68 | """ 69 | kwargs['groups'] = 32 70 | kwargs['width_per_group'] = 16 71 | return _resnext('resnext101_32x16d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 72 | 73 | 74 | def resnext101_32x32d_wsl(pretrained=True, progress=True, **kwargs): 75 | """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data 76 | and finetuned on ImageNet from Figure 5 in 77 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 78 | Args: 79 | progress (bool): If True, displays a progress bar of the download to stderr. 
80 | """ 81 | kwargs['groups'] = 32 82 | kwargs['width_per_group'] = 32 83 | return _resnext('resnext101_32x32d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 84 | 85 | 86 | def resnext101_32x48d_wsl(pretrained=True, progress=True, **kwargs): 87 | """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data 88 | and finetuned on ImageNet from Figure 5 in 89 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 90 | Args: 91 | progress (bool): If True, displays a progress bar of the download to stderr. 92 | """ 93 | kwargs['groups'] = 32 94 | kwargs['width_per_group'] = 48 95 | return _resnext('resnext101_32x48d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 96 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-08-14 11:16 8 | 公众号:AI成长社 9 | 知乎:https://www.zhihu.com/people/qlmx-61/columns 10 | """ 11 | from glob import glob 12 | import os 13 | import codecs 14 | import random 15 | import numpy as np 16 | from sklearn.model_selection import KFold, StratifiedKFold 17 | 18 | 19 | base_path = 'data/' 20 | data_path = base_path + 'garbage_classify/train_data' 21 | 22 | label_files = glob(os.path.join(data_path, '*.txt')) 23 | img_paths = [] 24 | labels = [] 25 | result = [] 26 | label_dict = {} 27 | data_dict = {} 28 | 29 | for index, file_path in enumerate(label_files): 30 | with codecs.open(file_path, 'r', 'utf-8') as f: 31 | line = f.readline() 32 | line_split = line.strip().split(', ') 33 | if len(line_split) != 2: 34 | print('%s contain error lable' % os.path.basename(file_path)) 35 | continue 36 | img_name = line_split[0] 37 | label = int(line_split[1]) 38 | img_paths.append(os.path.join(data_path, img_name)) 39 | labels.append(label) 40 | 
result.append(os.path.join(data_path, img_name) + ',' + str(label)) 41 | label_dict[label] = label_dict.get(label, 0) + 1 42 | if label not in data_dict: 43 | data_dict[label] = [] 44 | data_dict[label].append(os.path.join(data_path, img_name) + ',' + str(label)) 45 | 46 | data_path_add = base_path + 'garbage_classify_v3' 47 | label_files_add = glob(os.path.join(data_path_add, '*.txt')) 48 | 49 | for index, file_path in enumerate(label_files_add): 50 | with codecs.open(file_path, 'r', 'utf-8') as f: 51 | line = f.readline() 52 | line_split = line.strip().split(', ') 53 | if len(line_split) != 2: 54 | print('%s contain error lable' % os.path.basename(file_path)) 55 | continue 56 | img_name = line_split[0] 57 | label = int(line_split[1]) 58 | img_paths.append(os.path.join(data_path_add, img_name)) 59 | labels.append(label) 60 | result.append(os.path.join(data_path_add, img_name) + ',' + str(label)) 61 | label_dict[label] = label_dict.get(label, 0) + 1 62 | if label not in data_dict: 63 | data_dict[label] = [] 64 | data_dict[label].append(os.path.join(data_path_add, img_name) + ',' + str(label)) 65 | 66 | 67 | folds = StratifiedKFold(n_splits=10, shuffle=True, random_state=2019) 68 | for fold_, (trn_idx, val_idx) in enumerate(folds.split(result, labels)): 69 | train_data = list(np.array(result)[trn_idx]) 70 | val_data = list(np.array(result)[val_idx]) 71 | 72 | print(len(train_data), len(val_data)) 73 | 74 | with open(base_path + 'train1.txt', 'w') as f1: 75 | for item in train_data: 76 | f1.write(item + '\n') 77 | 78 | with open(base_path + 'val1.txt', 'w') as f2: 79 | for item in val_data: 80 | f2.write(item + '\n') 81 | 82 | 83 | from PIL import Image 84 | 85 | ###predata 2 86 | all_data = [] 87 | train = [] 88 | val = [] 89 | rate = 0.9 90 | import cv2 91 | from tqdm import tqdm 92 | 93 | error_list = ['data/additional_train_data/38/242.jpg', 94 | 'data/additional_train_data/34/79.jpg', 95 | 'data/additional_train_data/27/55.jpg' 96 | 'data/new/8/0.jpg' 97 | ] 98 
| 99 | data_path = base_path + 'new/' 100 | for i in range(40): 101 | na_item = [] 102 | img_files = glob(os.path.join(data_path, str(i), '*.jpg')) 103 | for item in tqdm(img_files): 104 | ii = cv2.imread(item) 105 | if item not in error_list: 106 | 107 | jj = Image.open(item).layers 108 | if jj == 1: 109 | print(item) 110 | all_data.append(item + ',' + str(i)) 111 | na_item.append(item + ',' + str(i)) 112 | random.shuffle(na_item) 113 | train.extend(na_item[ : int(len(na_item)*rate)]) 114 | val.extend(na_item[int(len(na_item)*rate):]) 115 | print(len(train), len(val)) 116 | 117 | random.shuffle(all_data) 118 | random.shuffle(train) 119 | random.shuffle(val) 120 | 121 | print(len(all_data)) 122 | 123 | old = [] 124 | with open(base_path + 'train1.txt', 'r') as f: 125 | for i in f.readlines(): 126 | old.append(i.strip()) 127 | for i in all_data: 128 | img_path, label = i.strip().split(',') 129 | 130 | all_data.extend(old) 131 | print(len(all_data)) 132 | random.shuffle(all_data) 133 | 134 | with open(base_path + 'new_shu_label.txt', 'w') as f1: 135 | for item in all_data: 136 | f1.write(item + '\n') -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-08-20 00:55 8 | """ 9 | import torch.nn as nn 10 | import cv2 11 | import torch 12 | device=torch.device("cuda") 13 | import os 14 | import urllib.parse as urlparse 15 | import requests 16 | import torch 17 | 18 | __all__ = ['GetEncoder', 'GetPreTrainedModel', 'load_pretrained', 'l2_norm'] 19 | 20 | 21 | #修改模型以进行特征提取 22 | def GetEncoder(model): 23 | layerName,layer=list(model.named_children())[-1] 24 | exec("model."+layerName+"=nn.Linear(layer.in_features,layer.in_features)") 25 | exec("torch.nn.init.eye_(model."+layerName+".weight)") 26 | for 
param in model.parameters(): 27 | param.requires_grad=False 28 | return model,layer.in_features 29 | 30 | #修改模型以进行微调,n_ZeroChild和n_ZeroLayer用来设置参数固定层,当children为Sequential时使用n_ZeroLayer,可对其内部进行设置 31 | def GetPreTrainedModel(model,n_Output,n_ZeroChild,n_ZeroLayer=None): 32 | for i,children in enumerate(model.children()): 33 | if i==n_ZeroChild: 34 | if n_ZeroLayer is not None: 35 | for j,layer in enumerate(children): 36 | if j==n_ZeroLayer: 37 | break 38 | for param in layer.parameters(): 39 | param.requires_grad=False 40 | break 41 | for param in children.parameters(): 42 | param.requires_grad=False 43 | layerName,layer=list(model.named_children())[-1] 44 | exec("model."+layerName+"=nn.Linear(layer.in_features,"+str(n_Output)+")") 45 | return model 46 | 47 | 48 | class StackNet2(nn.Module): 49 | def __init__(self,models,n_Target): 50 | super(StackNet,self).__init__() 51 | self.models=models 52 | n_Out=0 53 | for i,(model,scale_In,n_ZeroChild,n_ZeroLayer) in enumerate(self.models): 54 | for j,children in enumerate(model.children()): 55 | if j==n_ZeroChild: 56 | if n_ZeroLayer is not None: 57 | for k,layer in enumerate(children): 58 | if k==n_ZeroLayer: 59 | break 60 | for param in layer.parameters(): 61 | param.requires_grad=False 62 | break 63 | 64 | layerName,layer=list(model.named_children())[-1] 65 | n_Out+=layer.in_features 66 | exec("model."+layerName+"=nn.Linear(layer.in_features,layer.in_features)") 67 | exec("torch.nn.init.eye_(model."+layerName+".weight)") 68 | exec("layer=model."+layerName) 69 | for param in layer.parameters(): 70 | param.requires_grad=False 71 | exec("self.model"+str(i)+"=model") 72 | self.fc=nn.Linear(n_Out,n_Target) 73 | def forward(self,x): 74 | feature=[] 75 | for model,scale_In,_,_ in self.models: 76 | feature.append(model(x)) 77 | feature=torch.cat(feature,dim=1) 78 | return self.fc(feature) 79 | 80 | 81 | def _download_file_from_google_drive(fid, dest): 82 | def _get_confirm_token(res): 83 | for k, v in res.cookies.items(): 84 | if 
k.startswith('download_warning'): return v 85 | return None 86 | 87 | def _save_response_content(res, dest): 88 | CHUNK_SIZE = 32768 89 | with open(dest, "wb") as f: 90 | for chunk in res.iter_content(CHUNK_SIZE): 91 | if chunk: f.write(chunk) 92 | 93 | URL = "https://docs.google.com/uc?export=download" 94 | sess = requests.Session() 95 | res = sess.get(URL, params={'id': fid}, stream=True) 96 | token = _get_confirm_token(res) 97 | 98 | if token: 99 | params = {'id': fid, 'confirm': token} 100 | res = sess.get(URL, params=params, stream=True) 101 | _save_response_content(res, dest) 102 | 103 | 104 | def _load_url(url, dest): 105 | if os.path.isfile(dest) and os.path.exists(dest): return dest 106 | print('[INFO]: Downloading weights...') 107 | fid = urlparse.parse_qs(urlparse.urlparse(url).query)['id'][0] 108 | _download_file_from_google_drive(fid, dest) 109 | return dest 110 | 111 | 112 | def load_pretrained(m, meta, dest, pretrained=False): 113 | if pretrained: 114 | if len(meta) == 0: 115 | print('[INFO]: Pretrained model not available') 116 | return m 117 | if dest is None: dest = meta[0] 118 | else: 119 | dest = dest + '/' + meta[0] 120 | print(dest) 121 | m.load_state_dict(torch.load(_load_url(meta[1], dest))) 122 | return m 123 | 124 | def l2_norm(input,axis=1): 125 | norm = torch.norm(input,2,axis,True) 126 | output = torch.div(input, norm) 127 | return output -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = 
logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in 
names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__ (self, paths): 86 | '''paths is a distionary with {name:filepath} pair''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 
| """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-08-15 14:25 8 | 公众号:AI成长社 9 | 知乎:https://www.zhihu.com/people/qlmx-61/columns 10 | """ 11 | import argparse 12 | from build_net import model_names 13 | 14 | # Parse arguments 15 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 16 | 17 | # Datasets 18 | parser.add_argument('-train', '--trainroot', default='data/new_shu_label.txt', type=str) #new_shu_label 19 | parser.add_argument('-val', '--valroot', default='data/val1.txt', type=str) 20 | 21 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 22 | help='number of data loading workers (default: 4)') 23 | # Optimization options 24 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 25 | help='number of total epochs to run') 26 | parser.add_argument('--num-classes', default=43, type=int, metavar='N', 27 | help='number of classfication of image') 28 | parser.add_argument('--image-size', default=288, type=int, metavar='N', 29 | help='the train image size') 30 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 31 | help='manual epoch number (useful on restarts)') 32 | parser.add_argument('--train-batch', default=64, type=int, metavar='N', 33 | help='train batchsize (default: 256)') 34 | parser.add_argument('--test-batch', default=32, type=int, metavar='N', 35 | help='test batchsize (default: 200)') 36 | parser.add_argument('--optimizer', default='sgd', 37 | choices=['sgd', 'rmsprop', 'adam', 'AdaBound', 'radam'], metavar='N', 38 | help='optimizer (default=sgd)') 39 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 40 | metavar='LR', help='initial learning rate,1e-2, 1e-4, 0.001') 41 | parser.add_argument('--lr-fc-times', '--lft', default=5, type=int, 42 | metavar='LR', help='initial model last layer rate') 43 | parser.add_argument('--drop', '--dropout', default=0, type=float, 44 | metavar='Dropout', 
help='Dropout ratio') 45 | parser.add_argument('--schedule', type=int, nargs='+', default=[30, 50, 60], 46 | help='Decrease learning rate at these epochs.') 47 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 48 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 49 | help='momentum') 50 | parser.add_argument('--no_nesterov', dest='nesterov', 51 | action='store_false', 52 | help='do not use Nesterov momentum') 53 | parser.add_argument('--alpha', default=0.99, type=float, metavar='M', 54 | help='alpha for ') 55 | parser.add_argument('--beta1', default=0.9, type=float, metavar='M', 56 | help='beta1 for Adam (default: 0.9)') 57 | parser.add_argument('--beta2', default=0.999, type=float, metavar='M', 58 | help='beta2 for Adam (default: 0.999)') 59 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 60 | metavar='W', help='weight decay (default: 1e-4)') 61 | parser.add_argument('--final-lr', '--fl', default=1e-3,type=float, 62 | metavar='W', help='weight decay (default: 1e-3)') 63 | # Checkpoints 64 | parser.add_argument('-c', '--checkpoint', default='/data0/search/qlmx/clover/garbage/res_16_288_last1', type=str, metavar='PATH', 65 | help='path to save checkpoint (default: checkpoint)') 66 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 67 | help='path to latest checkpoint (default: none)') 68 | # Architecture 69 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext101_32x16d_wsl', 70 | choices=model_names, 71 | help='model architecture: ' + 72 | ' | '.join(model_names) + 73 | ' (default: resnext101_32x8d, pnasnet5large)') 74 | parser.add_argument('--depth', type=int, default=29, help='Model depth.') 75 | parser.add_argument('--cardinality', type=int, default=32, help='ResNet cardinality (group).') 76 | parser.add_argument('--base-width', type=int, default=4, help='ResNet base width.') 77 | parser.add_argument('--widen-factor', 
type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...') 78 | # Miscs 79 | parser.add_argument('--manualSeed', type=int, help='manual seed') 80 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 81 | help='evaluate model on validation set') 82 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 83 | help='use pre-trained model') 84 | #Device options 85 | parser.add_argument('--gpu-id', default='0, 1, 2, 3', type=str, 86 | help='id(s) for CUDA_VISIBLE_DEVICES') 87 | 88 | args = parser.parse_args() 89 | -------------------------------------------------------------------------------- /predict/customize_service.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-08-14 16:07 8 | """ 9 | import torchvision.transforms as transforms 10 | import torch 11 | from torch import nn 12 | from PIL import Image 13 | import torch.nn.functional as F 14 | from resnetxt_wsl import resnext101_32x8d_wsl, resnext101_32x16d_wsl, resnext101_32x32d_wsl 15 | 16 | 17 | from model_service.pytorch_model_service import PTServingBaseService 18 | 19 | import time 20 | from metric.metrics_manager import MetricsManager 21 | import log 22 | logger = log.getLogger(__name__) 23 | 24 | 25 | args = {} 26 | args['arch'] = 'resnext101_32x16d_wsl' 27 | args['pretrained'] = False 28 | args['num_classes'] = 43 29 | args['big_size'] = 327 30 | args['image_size'] = 288 31 | torch.backends.cudnn.benchmark = True 32 | 33 | class classfication_service(PTServingBaseService): 34 | def __init__(self, model_name, model_path): 35 | super(classfication_service, self).__init__(model_name, model_path) 36 | self.pre_img = self.preprocess_img1() 37 | self.model = self.build_model(model_path) 38 | self.model = self.model.cuda() 39 | self.model.eval() 40 | 41 | 
self.label_id_name_dict = \ 42 | { 43 | "0": "其他垃圾/一次性快餐盒", 44 | "1": "其他垃圾/污损塑料", 45 | "2": "其他垃圾/烟蒂", 46 | "3": "其他垃圾/牙签", 47 | "4": "其他垃圾/破碎花盆及碟碗", 48 | "5": "其他垃圾/竹筷", 49 | "6": "厨余垃圾/剩饭剩菜", 50 | "7": "厨余垃圾/大骨头", 51 | "8": "厨余垃圾/水果果皮", 52 | "9": "厨余垃圾/水果果肉", 53 | "10": "厨余垃圾/茶叶渣", 54 | "11": "厨余垃圾/菜叶菜根", 55 | "12": "厨余垃圾/蛋壳", 56 | "13": "厨余垃圾/鱼骨", 57 | "14": "可回收物/充电宝", 58 | "15": "可回收物/包", 59 | "16": "可回收物/化妆品瓶", 60 | "17": "可回收物/塑料玩具", 61 | "18": "可回收物/塑料碗盆", 62 | "19": "可回收物/塑料衣架", 63 | "20": "可回收物/快递纸袋", 64 | "21": "可回收物/插头电线", 65 | "22": "可回收物/旧衣服", 66 | "23": "可回收物/易拉罐", 67 | "24": "可回收物/枕头", 68 | "25": "可回收物/毛绒玩具", 69 | "26": "可回收物/洗发水瓶", 70 | "27": "可回收物/玻璃杯", 71 | "28": "可回收物/皮鞋", 72 | "29": "可回收物/砧板", 73 | "30": "可回收物/纸板箱", 74 | "31": "可回收物/调料瓶", 75 | "32": "可回收物/酒瓶", 76 | "33": "可回收物/金属食品罐", 77 | "34": "可回收物/锅", 78 | "35": "可回收物/食用油桶", 79 | "36": "可回收物/饮料瓶", 80 | "37": "有害垃圾/干电池", 81 | "38": "有害垃圾/软膏", 82 | "39": "有害垃圾/过期药物", 83 | "40": "可回收物/毛巾", 84 | "41": "可回收物/饮料盒", 85 | "42": "可回收物/纸袋" 86 | } 87 | 88 | def build_model(self, model_path): 89 | model = resnext101_32x16d_wsl(pretrained=False, progress=False) 90 | model.fc = nn.Sequential( 91 | nn.Dropout(0.2), 92 | nn.Linear(2048, 43) 93 | ) 94 | model.load_state_dict(torch.load(model_path)) 95 | return model 96 | 97 | def preprocess_img(self): 98 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 99 | infer_transformation = transforms.Compose([ 100 | Resize((args['image_size'], args['image_size'])), 101 | transforms.ToTensor(), 102 | transforms.Normalize(mean=mean, std=std), 103 | ]) 104 | return infer_transformation 105 | 106 | def preprocess_img1(self): 107 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 108 | return transforms.Compose([ 109 | Resize((329, 329)), 110 | transforms.CenterCrop(288), 111 | transforms.ToTensor(), 112 | transforms.Normalize(mean=mean, std=std), 113 | ]) 114 | 115 | def _preprocess(self, data): 116 | preprocessed_data = {} 117 | for k, v in 
data.items(): 118 | for file_name, file_content in v.items(): 119 | img = Image.open(file_content) 120 | img = self.pre_img(img) 121 | preprocessed_data[k] = img 122 | return preprocessed_data 123 | 124 | def _inference(self, data): 125 | img = data['input_img'] 126 | img = img.unsqueeze(0).cuda() 127 | with torch.no_grad(): 128 | pred_score = self.model(img) 129 | 130 | if pred_score is not None: 131 | _, pred_label = torch.max(pred_score.data, 1) 132 | return {'result': self.label_id_name_dict[str(pred_label[0].item())]} 133 | else: 134 | return {'result': 'predict score is None'} 135 | 136 | # return result 137 | 138 | def _postprocess(self, data): 139 | return data 140 | 141 | 142 | class Resize(object): 143 | def __init__(self, size, interpolation=Image.BILINEAR): 144 | self.size = size 145 | self.interpolation = interpolation 146 | 147 | def __call__(self, img): 148 | ratio = 1 149 | w, h = img.size 150 | if w / h < ratio: 151 | t = int(h * ratio) 152 | w_padding = (t - w) // 2 153 | img = img.crop((-w_padding, 0, w+w_padding, h)) 154 | else: 155 | t = int(w / ratio) 156 | h_padding = (t - h) // 2 157 | img = img.crop((0, -h_padding, w, h+h_padding)) 158 | 159 | img = img.resize(self.size, self.interpolation) 160 | 161 | return img -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 
5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | from torch.autograd import Variable 15 | import torch 16 | import shutil 17 | import adabound 18 | from utils.radam import RAdam, AdamW 19 | import torchvision.transforms as transforms 20 | 21 | 22 | 23 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter', 'get_optimizer', 'save_checkpoint'] 24 | 25 | 26 | def get_mean_and_std(dataset): 27 | '''Compute the mean and std value of dataset.''' 28 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 29 | 30 | mean = torch.zeros(3) 31 | std = torch.zeros(3) 32 | print('==> Computing mean and std..') 33 | for inputs, targets in dataloader: 34 | for i in range(3): 35 | mean[i] += inputs[:,i,:,:].mean() 36 | std[i] += inputs[:,i,:,:].std() 37 | mean.div_(len(dataset)) 38 | std.div_(len(dataset)) 39 | return mean, std 40 | 41 | def init_params(net): 42 | '''Init layer parameters.''' 43 | for m in net.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | init.kaiming_normal(m.weight, mode='fan_out') 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | elif isinstance(m, nn.BatchNorm2d): 49 | init.constant(m.weight, 1) 50 | init.constant(m.bias, 0) 51 | elif isinstance(m, nn.Linear): 52 | init.normal(m.weight, std=1e-3) 53 | if m.bias: 54 | init.constant(m.bias, 0) 55 | 56 | def mkdir_p(path): 57 | '''make dir if not exist''' 58 | try: 59 | os.makedirs(path) 60 | except OSError as exc: # Python >2.5 61 | if exc.errno == errno.EEXIST and os.path.isdir(path): 62 | pass 63 | else: 64 | raise 65 | 66 | class AverageMeter(object): 67 | """Computes and stores the average and current value 68 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 69 | """ 70 | def __init__(self): 71 | self.reset() 72 | 73 | def reset(self): 74 | self.val = 0 75 | self.avg = 0 76 | self.sum 
= 0 77 | self.count = 0 78 | 79 | def update(self, val, n=1): 80 | self.val = val 81 | self.sum += val * n 82 | self.count += n 83 | self.avg = self.sum / self.count 84 | 85 | def get_optimizer(model, args): 86 | parameters = [] 87 | for name, param in model.named_parameters(): 88 | if 'fc' in name or 'class' in name or 'last_linear' in name or 'ca' in name or 'sa' in name: 89 | parameters.append({'params': param, 'lr': args.lr * args.lr_fc_times}) 90 | else: 91 | parameters.append({'params': param, 'lr': args.lr}) 92 | 93 | if args.optimizer == 'sgd': 94 | return torch.optim.SGD(parameters, 95 | # model.parameters(), 96 | args.lr, 97 | momentum=args.momentum, nesterov=args.nesterov, 98 | weight_decay=args.weight_decay) 99 | elif args.optimizer == 'rmsprop': 100 | return torch.optim.RMSprop(parameters, 101 | # model.parameters(), 102 | args.lr, 103 | alpha=args.alpha, 104 | weight_decay=args.weight_decay) 105 | elif args.optimizer == 'adam': 106 | return torch.optim.Adam(parameters, 107 | # model.parameters(), 108 | args.lr, 109 | betas=(args.beta1, args.beta2), 110 | weight_decay=args.weight_decay) 111 | elif args.optimizer == 'AdaBound': 112 | return adabound.AdaBound(parameters, 113 | # model.parameters(), 114 | lr=args.lr, final_lr=args.final_lr) 115 | elif args.optimizer == 'radam': 116 | return RAdam(parameters, lr=args.lr, betas=(args.beta1, args.beta2), 117 | weight_decay=args.weight_decay) 118 | 119 | else: 120 | raise NotImplementedError 121 | 122 | 123 | def save_checkpoint(state, is_best, single=True, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 124 | if single: 125 | fold = '' 126 | else: 127 | fold = str(state['fold']) + '_' 128 | cur_name = 'checkpoint.pth.tar' 129 | filepath = os.path.join(checkpoint, fold + cur_name) 130 | curpath = os.path.join(checkpoint, fold + 'model_cur.pth') 131 | 132 | torch.save(state, filepath) 133 | torch.save(state['state_dict'], curpath) 134 | 135 | if is_best and state['epoch'] >= 5: 136 | model_name = 
'model_' + str(state['epoch']) + '_' + str(int(round(state['train_acc']*100, 0))) + '_' + str(int(round(state['acc']*100, 0))) + '.pth' 137 | model_path = os.path.join(checkpoint, fold + model_name) 138 | torch.save(state['state_dict'], model_path) 139 | 140 | 141 | def save_checkpoint2(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): 142 | # best_model = '/application/search/qlmx/clover/garbage/code/image_classfication/predict/' 143 | fold = str(state['fold']) + '_' 144 | filepath = os.path.join(checkpoint, fold + filename) 145 | model_path = os.path.join(checkpoint, fold + 'model_cur.pth') 146 | 147 | torch.save(state, filepath) 148 | torch.save(state['state_dict'], model_path) 149 | if is_best: 150 | shutil.copyfile(filepath, os.path.join(checkpoint, fold + 'model_best.pth.tar')) 151 | shutil.copyfile(model_path, os.path.join(checkpoint, fold + 'model_best.pth')) 152 | -------------------------------------------------------------------------------- /predict/predict.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-08-14 16:07 8 | """ 9 | import torchvision.transforms as transforms 10 | import torch 11 | from PIL import Image 12 | from collections import OrderedDict 13 | import torch.nn.functional as F 14 | from efficientnet_pytorch import EfficientNet 15 | from torch import nn 16 | import os, time 17 | import torchvision.models as models 18 | from resnetxt_wsl import resnext101_32x8d_wsl, resnext101_32x16d_wsl, resnext101_32x32d_wsl 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 20 | 21 | args = {} 22 | args['arch'] = 'resnext101_32x16d_wsl' 23 | args['pretrained'] = False 24 | args['num_classes'] = 43 25 | args['image_size'] = 320 26 | 27 | 28 | class classfication_service(): 29 | def __init__(self, model_path): 30 | self.model = 
self.build_model(model_path) 31 | self.pre_img = self.preprocess_img() 32 | self.model.eval() 33 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 34 | self.label_id_name_dict = \ 35 | { 36 | "0": "其他垃圾/一次性快餐盒", 37 | "1": "其他垃圾/污损塑料", 38 | "2": "其他垃圾/烟蒂", 39 | "3": "其他垃圾/牙签", 40 | "4": "其他垃圾/破碎花盆及碟碗", 41 | "5": "其他垃圾/竹筷", 42 | "6": "厨余垃圾/剩饭剩菜", 43 | "7": "厨余垃圾/大骨头", 44 | "8": "厨余垃圾/水果果皮", 45 | "9": "厨余垃圾/水果果肉", 46 | "10": "厨余垃圾/茶叶渣", 47 | "11": "厨余垃圾/菜叶菜根", 48 | "12": "厨余垃圾/蛋壳", 49 | "13": "厨余垃圾/鱼骨", 50 | "14": "可回收物/充电宝", 51 | "15": "可回收物/包", 52 | "16": "可回收物/化妆品瓶", 53 | "17": "可回收物/塑料玩具", 54 | "18": "可回收物/塑料碗盆", 55 | "19": "可回收物/塑料衣架", 56 | "20": "可回收物/快递纸袋", 57 | "21": "可回收物/插头电线", 58 | "22": "可回收物/旧衣服", 59 | "23": "可回收物/易拉罐", 60 | "24": "可回收物/枕头", 61 | "25": "可回收物/毛绒玩具", 62 | "26": "可回收物/洗发水瓶", 63 | "27": "可回收物/玻璃杯", 64 | "28": "可回收物/皮鞋", 65 | "29": "可回收物/砧板", 66 | "30": "可回收物/纸板箱", 67 | "31": "可回收物/调料瓶", 68 | "32": "可回收物/酒瓶", 69 | "33": "可回收物/金属食品罐", 70 | "34": "可回收物/锅", 71 | "35": "可回收物/食用油桶", 72 | "36": "可回收物/饮料瓶", 73 | "37": "有害垃圾/干电池", 74 | "38": "有害垃圾/软膏", 75 | "39": "有害垃圾/过期药物", 76 | "40": "可回收物/毛巾", 77 | "41": "可回收物/饮料盒", 78 | "42": "可回收物/纸袋" 79 | } 80 | 81 | def build_model(self, model_path): 82 | if args['arch'] == 'resnext101_32x16d_wsl': 83 | model = resnext101_32x16d_wsl(pretrained=False, progress=False) 84 | if args['arch'] == 'resnext101_32x8d': 85 | model = models.__dict__[args['arch']]() 86 | elif args['arch'] == 'efficientnet-b7': 87 | model = EfficientNet.from_name(args['arch']) 88 | 89 | layerName, layer = list(model.named_children())[-1] 90 | exec("model." 
+ layerName + "=nn.Linear(layer.in_features," + str(args['num_classes']) + ")") 91 | 92 | if torch.cuda.is_available(): 93 | modelState = torch.load(model_path) 94 | model.load_state_dict(modelState) 95 | model = model.cuda() 96 | else: 97 | modelState = torch.load(model_path, map_location='cpu') 98 | model.load_state_dict(modelState) 99 | return model 100 | 101 | def preprocess_img(self): 102 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 103 | infer_transformation = transforms.Compose([ 104 | Resize((args['image_size'], args['image_size'])), 105 | transforms.ToTensor(), 106 | transforms.Normalize(mean=mean, std=std), 107 | ]) 108 | return infer_transformation 109 | 110 | def _preprocess(self, data): 111 | preprocessed_data = {} 112 | for k, v in data.items(): 113 | for file_name, file_content in v.items(): 114 | img = Image.open(file_content) 115 | img = self.pre_img(img) 116 | preprocessed_data[k] = img 117 | return preprocessed_data 118 | 119 | def _inference(self, data): 120 | """ 121 | model inference function 122 | Here are a inference example of resnet, if you use another model, please modify this function 123 | """ 124 | img = data['input_img'] 125 | img = img.unsqueeze(0) 126 | img = img.to(self.device) 127 | with torch.no_grad(): 128 | pred_score = self.model(img) 129 | 130 | if pred_score is not None: 131 | _, pred_label = torch.max(pred_score.data, 1) 132 | result = {'result': self.label_id_name_dict[str(pred_label[0].item())]} 133 | else: 134 | result = {'result': 'predict score is None'} 135 | 136 | return result 137 | 138 | def _postprocess(self, data): 139 | return data 140 | 141 | 142 | class Resize(object): 143 | def __init__(self, size, interpolation=Image.BILINEAR): 144 | self.size = size 145 | self.interpolation = interpolation 146 | 147 | def __call__(self, img): 148 | ratio = self.size[0] / self.size[1] 149 | w, h = img.size 150 | if w / h < ratio: 151 | t = int(h * ratio) 152 | w_padding = (t - w) // 2 153 | img = 
img.crop((-w_padding, 0, w+w_padding, h)) 154 | else: 155 | t = int(w / ratio) 156 | h_padding = (t - h) // 2 157 | img = img.crop((0, -h_padding, w, h+h_padding)) 158 | 159 | img = img.resize(self.size, self.interpolation) 160 | 161 | return img 162 | 163 | if __name__ == '__main__': 164 | model_path = '/Users/QLMX/Documents/' + 'model_19_9992_9590.pth' 165 | infer = classfication_service(model_path) 166 | input_dir = '/Users/QLMX/Downloads/garbage_classify_v3_select_100' 167 | files = os.listdir(input_dir) 168 | t1 = int(time.time()*1000) 169 | for file_name in files: 170 | file_path = os.path.join(input_dir, file_name) 171 | img = Image.open(file_path) 172 | 173 | img = infer.pre_img(img) 174 | tt1 = int(time.time() * 1000) 175 | result = infer._inference({'input_img': img}) 176 | tt2 = int(time.time() * 1000) 177 | print((tt2 - tt1) / 100) 178 | t2 = int(time.time()*1000) 179 | print((t2 - t1)/100) 180 | 181 | -------------------------------------------------------------------------------- /utils/radam.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-09-13 10:49 8 | """ 9 | import math 10 | import torch 11 | from torch.optim.optimizer import Optimizer, required 12 | 13 | 14 | class RAdam(Optimizer): 15 | 16 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 17 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 18 | self.buffer = [[None, None, None] for ind in range(10)] 19 | super(RAdam, self).__init__(params, defaults) 20 | 21 | def __setstate__(self, state): 22 | super(RAdam, self).__setstate__(state) 23 | 24 | def step(self, closure=None): 25 | 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | for group in self.param_groups: 31 | 32 | for p in group['params']: 33 | if p.grad is None: 34 | 
continue 35 | grad = p.grad.data.float() 36 | if grad.is_sparse: 37 | raise RuntimeError('RAdam does not support sparse gradients') 38 | 39 | p_data_fp32 = p.data.float() 40 | 41 | state = self.state[p] 42 | 43 | if len(state) == 0: 44 | state['step'] = 0 45 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 46 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 47 | else: 48 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 49 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 50 | 51 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 52 | beta1, beta2 = group['betas'] 53 | 54 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 55 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 56 | 57 | state['step'] += 1 58 | buffered = self.buffer[int(state['step'] % 10)] 59 | if state['step'] == buffered[0]: 60 | N_sma, step_size = buffered[1], buffered[2] 61 | else: 62 | buffered[0] = state['step'] 63 | beta2_t = beta2 ** state['step'] 64 | N_sma_max = 2 / (1 - beta2) - 1 65 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 66 | buffered[1] = N_sma 67 | 68 | # more conservative since it's an approximated value 69 | if N_sma >= 5: 70 | step_size = math.sqrt( 71 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 72 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 73 | else: 74 | step_size = 1.0 / (1 - beta1 ** state['step']) 75 | buffered[2] = step_size 76 | 77 | if group['weight_decay'] != 0: 78 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 79 | 80 | # more conservative since it's an approximated value 81 | if N_sma >= 5: 82 | denom = exp_avg_sq.sqrt().add_(group['eps']) 83 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 84 | else: 85 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 86 | 87 | p.data.copy_(p_data_fp32) 88 | 89 | return loss 90 | 91 | 92 | class PlainRAdam(Optimizer): 93 | 94 | def __init__(self, params, lr=1e-3, betas=(0.9, 
0.999), eps=1e-8, weight_decay=0): 95 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 96 | 97 | super(PlainRAdam, self).__init__(params, defaults) 98 | 99 | def __setstate__(self, state): 100 | super(PlainRAdam, self).__setstate__(state) 101 | 102 | def step(self, closure=None): 103 | 104 | loss = None 105 | if closure is not None: 106 | loss = closure() 107 | 108 | for group in self.param_groups: 109 | 110 | for p in group['params']: 111 | if p.grad is None: 112 | continue 113 | grad = p.grad.data.float() 114 | if grad.is_sparse: 115 | raise RuntimeError('RAdam does not support sparse gradients') 116 | 117 | p_data_fp32 = p.data.float() 118 | 119 | state = self.state[p] 120 | 121 | if len(state) == 0: 122 | state['step'] = 0 123 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 124 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 125 | else: 126 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 127 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 128 | 129 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 130 | beta1, beta2 = group['betas'] 131 | 132 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 133 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 134 | 135 | state['step'] += 1 136 | beta2_t = beta2 ** state['step'] 137 | N_sma_max = 2 / (1 - beta2) - 1 138 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 139 | 140 | if group['weight_decay'] != 0: 141 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 142 | 143 | # more conservative since it's an approximated value 144 | if N_sma >= 5: 145 | step_size = group['lr'] * math.sqrt( 146 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 147 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 148 | denom = exp_avg_sq.sqrt().add_(group['eps']) 149 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 150 | else: 151 | step_size = group['lr'] / (1 - beta1 ** state['step']) 152 | 
p_data_fp32.add_(-step_size, exp_avg) 153 | 154 | p.data.copy_(p_data_fp32) 155 | 156 | return loss 157 | 158 | 159 | class AdamW(Optimizer): 160 | 161 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0): 162 | defaults = dict(lr=lr, betas=betas, eps=eps, 163 | weight_decay=weight_decay, warmup=warmup) 164 | super(AdamW, self).__init__(params, defaults) 165 | 166 | def __setstate__(self, state): 167 | super(AdamW, self).__setstate__(state) 168 | 169 | def step(self, closure=None): 170 | loss = None 171 | if closure is not None: 172 | loss = closure() 173 | 174 | for group in self.param_groups: 175 | 176 | for p in group['params']: 177 | if p.grad is None: 178 | continue 179 | grad = p.grad.data.float() 180 | if grad.is_sparse: 181 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 182 | 183 | p_data_fp32 = p.data.float() 184 | 185 | state = self.state[p] 186 | 187 | if len(state) == 0: 188 | state['step'] = 0 189 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 190 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 191 | else: 192 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 193 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 194 | 195 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 196 | beta1, beta2 = group['betas'] 197 | 198 | state['step'] += 1 199 | 200 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 201 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 202 | 203 | denom = exp_avg_sq.sqrt().add_(group['eps']) 204 | bias_correction1 = 1 - beta1 ** state['step'] 205 | bias_correction2 = 1 - beta2 ** state['step'] 206 | 207 | if group['warmup'] > state['step']: 208 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 209 | else: 210 | scheduled_lr = group['lr'] 211 | 212 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 213 | 214 | if group['weight_decay'] != 0: 215 | 
p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 216 | 217 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 218 | 219 | p.data.copy_(p_data_fp32) 220 | 221 | return loss -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @version: python3.6 5 | @author: QLMX 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-08-14 11:08 8 | 公众号:AI成长社 9 | 知乎:https://www.zhihu.com/people/qlmx-61/columns 10 | """ 11 | from __future__ import print_function 12 | 13 | import os 14 | import time 15 | import random 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.parallel 20 | import torch.backends.cudnn as cudnn 21 | import torch.utils.data as data 22 | import torchvision.transforms as transforms 23 | import dataset 24 | import numpy as np 25 | from args import args 26 | from build_net import make_model 27 | from transform import get_transforms 28 | 29 | from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig, get_optimizer, save_checkpoint 30 | 31 | 32 | state = {k: v for k, v in args._get_kwargs()} 33 | 34 | # Use CUDA 35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 36 | use_cuda = torch.cuda.is_available() 37 | 38 | # Random seed 39 | if args.manualSeed is None: 40 | args.manualSeed = random.randint(1, 10000) 41 | random.seed(args.manualSeed) 42 | torch.manual_seed(args.manualSeed) 43 | if use_cuda: 44 | torch.cuda.manual_seed_all(args.manualSeed) 45 | best_acc = 0 # best test accuracy 46 | 47 | def main(): 48 | global best_acc 49 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 50 | 51 | if not os.path.isdir(args.checkpoint): 52 | mkdir_p(args.checkpoint) 53 | 54 | # Data 55 | transform = get_transforms(input_size=args.image_size, test_size=args.image_size, backbone=None) 56 | 57 | 58 | print('==> Preparing 
dataset %s' % args.trainroot) 59 | trainset = dataset.Dataset(root=args.trainroot, transform=transform['val_train']) 60 | train_loader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers, pin_memory=True) 61 | 62 | valset = dataset.TestDataset(root=args.valroot, transform=transform['val_test']) 63 | val_loader = data.DataLoader(valset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers, pin_memory=True) 64 | 65 | model = make_model(args) 66 | 67 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 68 | model.features = torch.nn.DataParallel(model.features) 69 | model.cuda() 70 | else: 71 | model = torch.nn.DataParallel(model).cuda() 72 | 73 | cudnn.benchmark = True 74 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 75 | 76 | # define loss function (criterion) and optimizer 77 | criterion = nn.CrossEntropyLoss().cuda() 78 | optimizer = get_optimizer(model, args) 79 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=5, verbose=False) 80 | 81 | # Resume 82 | title = 'ImageNet-' + args.arch 83 | if args.resume: 84 | # Load checkpoint. 85 | print('==> Resuming from checkpoint..') 86 | assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 
87 | args.checkpoint = os.path.dirname(args.resume) 88 | checkpoint = torch.load(args.resume) 89 | best_acc = checkpoint['best_acc'] 90 | start_epoch = checkpoint['epoch'] 91 | model.module.load_state_dict(checkpoint['state_dict']) 92 | optimizer.load_state_dict(checkpoint['optimizer']) 93 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) 94 | else: 95 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) 96 | logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) 97 | 98 | 99 | if args.evaluate: 100 | print('\nEvaluation only') 101 | test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda) 102 | print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) 103 | return 104 | 105 | # Train and val 106 | for epoch in range(start_epoch, args.epochs): 107 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[0]['lr'])) 108 | 109 | train_loss, train_acc, train_5 = train(train_loader, model, criterion, optimizer, epoch, use_cuda) 110 | test_loss, test_acc, test_5 = test(val_loader, model, criterion, epoch, use_cuda) 111 | 112 | scheduler.step(test_loss) 113 | 114 | # append logger file 115 | logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc]) 116 | print('train_loss:%f, val_loss:%f, train_acc:%f, train_5:%f, val_acc:%f, val_5:%f' % (train_loss, test_loss, train_acc, train_5, test_acc, test_5)) 117 | 118 | # save model 119 | is_best = test_acc > best_acc 120 | best_acc = max(test_acc, best_acc) 121 | 122 | if len(args.gpu_id) > 1: 123 | save_checkpoint({ 124 | 'fold': 0, 125 | 'epoch': epoch + 1, 126 | 'state_dict': model.module.state_dict(), 127 | 'train_acc': train_acc, 128 | 'acc': test_acc, 129 | 'best_acc': best_acc, 130 | 'optimizer': optimizer.state_dict(), 131 | }, is_best, single=True, checkpoint=args.checkpoint) 132 | else: 133 | save_checkpoint({ 134 | 'fold': 0, 135 | 'epoch': epoch + 1, 
136 | 'state_dict': model.state_dict(), 137 | 'train_acc':train_acc, 138 | 'acc': test_acc, 139 | 'best_acc': best_acc, 140 | 'optimizer' : optimizer.state_dict(), 141 | }, is_best, single=True, checkpoint=args.checkpoint) 142 | 143 | logger.close() 144 | logger.plot() 145 | savefig(os.path.join(args.checkpoint, 'log.eps')) 146 | 147 | print('Best acc:') 148 | print(best_acc) 149 | 150 | def train(train_loader, model, criterion, optimizer, epoch, use_cuda): 151 | # switch to train mode 152 | model.train() 153 | 154 | batch_time = AverageMeter() 155 | data_time = AverageMeter() 156 | losses = AverageMeter() 157 | top1 = AverageMeter() 158 | top5 = AverageMeter() 159 | end = time.time() 160 | 161 | bar = Bar('Processing', max=len(train_loader)) 162 | for batch_idx, (inputs, targets) in enumerate(train_loader): 163 | # measure data loading time 164 | data_time.update(time.time() - end) 165 | 166 | if use_cuda: 167 | inputs, targets = inputs.cuda(), targets.cuda(async=True) 168 | inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) 169 | 170 | # compute output 171 | outputs = model(inputs) 172 | loss = criterion(outputs, targets) 173 | 174 | # measure accuracy and record loss 175 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 176 | losses.update(loss.item(), inputs.size(0)) 177 | top1.update(prec1.item(), inputs.size(0)) 178 | top5.update(prec5.item(), inputs.size(0)) 179 | 180 | # compute gradient and do SGD step 181 | optimizer.zero_grad() 182 | loss.backward() 183 | optimizer.step() 184 | 185 | # measure elapsed time 186 | batch_time.update(time.time() - end) 187 | end = time.time() 188 | 189 | # plot progress 190 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 191 | batch=batch_idx + 1, 192 | size=len(train_loader), 193 | data=data_time.val, 194 | bt=batch_time.val, 195 | total=bar.elapsed_td, 
196 | eta=bar.eta_td, 197 | loss=losses.avg, 198 | top1=top1.avg, 199 | top5=top5.avg, 200 | ) 201 | bar.next() 202 | bar.finish() 203 | return (losses.avg, top1.avg, top5.avg) 204 | 205 | def test(val_loader, model, criterion, epoch, use_cuda): 206 | global best_acc 207 | 208 | batch_time = AverageMeter() 209 | data_time = AverageMeter() 210 | losses = AverageMeter() 211 | top1 = AverageMeter() 212 | top5 = AverageMeter() 213 | 214 | # switch to evaluate mode 215 | model.eval() 216 | 217 | end = time.time() 218 | bar = Bar('Processing', max=len(val_loader)) 219 | for batch_idx, (inputs, targets) in enumerate(val_loader): 220 | # measure data loading time 221 | data_time.update(time.time() - end) 222 | 223 | if use_cuda: 224 | inputs, targets = inputs.cuda(), targets.cuda() 225 | inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) 226 | 227 | # compute output 228 | outputs = model(inputs) 229 | loss = criterion(outputs, targets) 230 | 231 | # measure accuracy and record loss 232 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 233 | losses.update(loss.item(), inputs.size(0)) 234 | top1.update(prec1.item(), inputs.size(0)) 235 | top5.update(prec5.item(), inputs.size(0)) 236 | 237 | # measure elapsed time 238 | batch_time.update(time.time() - end) 239 | end = time.time() 240 | 241 | # plot progress 242 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 243 | batch=batch_idx + 1, 244 | size=len(val_loader), 245 | data=data_time.avg, 246 | bt=batch_time.avg, 247 | total=bar.elapsed_td, 248 | eta=bar.eta_td, 249 | loss=losses.avg, 250 | top1=top1.avg, 251 | top5=top5.avg, 252 | ) 253 | bar.next() 254 | bar.finish() 255 | return (losses.avg, top1.avg, top5.avg) 256 | 257 | 258 | if __name__ == '__main__': 259 | main() 260 | 261 | 
--------------------------------------------------------------------------------
/models/Res.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# ResNet / ResNeXt building blocks (BasicBlock, Bottleneck) and backbone.

import torch.nn as nn
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    # Fallback when torch.hub does not provide load_state_dict_from_url.
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']


# Download URLs for pretrained weights, keyed by architecture name.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding

    Padding equals *dilation*, so with stride 1 the spatial size is preserved.
    No bias is used (a normalization layer is expected to follow).
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (no bias), used for channel projection and shortcuts."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + norm layers with an additive shortcut.

    Output channels are ``planes * expansion`` (expansion is 1). Only
    ``groups=1``, ``base_width=64`` and ``dilation=1`` are supported.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when shape (stride/channels) changes.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual block: 1x1 reduce -> 3x3 (grouped, dilated) -> 1x1 expand.

    The inner width is ``int(planes * base_width / 64) * groups`` (which gives
    ResNeXt-style grouped blocks when ``groups > 1``); output channels are
    ``planes * expansion`` (expansion is 4).
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/norm stages and add the (optionally projected) shortcut."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when shape (stride/channels) changes.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if
isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 | norm_layer = self._norm_layer 179 | downsample = None 180 | previous_dilation = self.dilation 181 | if dilate: 182 | self.dilation *= stride 183 | stride = 1 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | downsample = nn.Sequential( 186 | conv1x1(self.inplanes, planes * block.expansion, stride), 187 | norm_layer(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 | self.base_width, previous_dilation, norm_layer)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes, groups=self.groups, 196 | base_width=self.base_width, dilation=self.dilation, 197 | norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x): 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | 207 | x = self.layer1(x) 208 | x = self.layer2(x) 209 | x = self.layer3(x) 210 | x = self.layer4(x) 211 | 212 | x = self.avgpool(x) 213 | x = x.reshape(x.size(0), -1) 214 | x = self.fc(x) 215 | 216 | return x 217 | 218 | 219 | def _resnet(arch, 
block, layers, pretrained, progress, **kwargs): 220 | model = ResNet(block, layers, **kwargs) 221 | if pretrained: 222 | state_dict = load_state_dict_from_url(model_urls[arch], 223 | progress=progress) 224 | model.load_state_dict(state_dict) 225 | return model 226 | 227 | 228 | def resnet18(pretrained=False, progress=True, **kwargs): 229 | """Constructs a ResNet-18 model. 230 | Args: 231 | pretrained (bool): If True, returns a model pre-trained on ImageNet 232 | progress (bool): If True, displays a progress bar of the download to stderr 233 | """ 234 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 235 | **kwargs) 236 | 237 | 238 | def resnet34(pretrained=False, progress=True, **kwargs): 239 | """Constructs a ResNet-34 model. 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet50(pretrained=False, progress=True, **kwargs): 249 | """Constructs a ResNet-50 model. 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet101(pretrained=False, progress=True, **kwargs): 259 | """Constructs a ResNet-101 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | progress (bool): If True, displays a progress bar of the download to stderr 263 | """ 264 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 265 | **kwargs) 266 | 267 | 268 | def resnet152(pretrained=False, progress=True, **kwargs): 269 | """Constructs a ResNet-152 model. 
270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 279 | """Constructs a ResNeXt-50 32x4d model. 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | kwargs['groups'] = 32 285 | kwargs['width_per_group'] = 4 286 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 287 | pretrained, progress, **kwargs) 288 | 289 | 290 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 291 | """Constructs a ResNeXt-101 32x8d model. 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | progress (bool): If True, displays a progress bar of the download to stderr 295 | """ 296 | kwargs['groups'] = 32 297 | kwargs['width_per_group'] = 8 298 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 299 | pretrained, progress, **kwargs) 300 | -------------------------------------------------------------------------------- /predict/Res.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import torch.nn as nn 9 | try: 10 | from torch.hub import load_state_dict_from_url 11 | except ImportError: 12 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 13 | 14 | 15 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 16 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 17 | 18 | 19 | model_urls = { 20 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 21 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 22 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 23 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 24 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 25 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 26 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 45 | base_width=64, dilation=1, norm_layer=None): 46 | super(BasicBlock, self).__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if groups != 1 or base_width != 64: 50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 
54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = norm_layer(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = norm_layer(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 85 | base_width=64, dilation=1, norm_layer=None): 86 | super(Bottleneck, self).__init__() 87 | if norm_layer is None: 88 | norm_layer = nn.BatchNorm2d 89 | width = int(planes * (base_width / 64.)) * groups 90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 91 | self.conv1 = conv1x1(inplanes, width) 92 | self.bn1 = norm_layer(width) 93 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 94 | self.bn2 = norm_layer(width) 95 | self.conv3 = conv1x1(width, planes * self.expansion) 96 | self.bn3 = norm_layer(planes * self.expansion) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.downsample = downsample 99 | self.stride = stride 100 | 101 | def forward(self, x): 102 | identity = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | 115 | if self.downsample is not None: 116 | identity = self.downsample(x) 117 | 118 | out += identity 119 | out = self.relu(out) 120 | 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | 126 | def __init__(self, block, layers, num_classes=1000, 
zero_init_residual=False, 127 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 128 | norm_layer=None): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | if replace_stride_with_dilation is None: 137 | # each element in the tuple indicates if we should replace 138 | # the 2x2 stride with a dilated convolution instead 139 | replace_stride_with_dilation = [False, False, False] 140 | if len(replace_stride_with_dilation) != 3: 141 | raise ValueError("replace_stride_with_dilation should be None " 142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 143 | self.groups = groups 144 | self.base_width = width_per_group 145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = norm_layer(self.inplanes) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block 
behaves like an identity. 169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 | norm_layer = self._norm_layer 179 | downsample = None 180 | previous_dilation = self.dilation 181 | if dilate: 182 | self.dilation *= stride 183 | stride = 1 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | downsample = nn.Sequential( 186 | conv1x1(self.inplanes, planes * block.expansion, stride), 187 | norm_layer(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 | self.base_width, previous_dilation, norm_layer)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes, groups=self.groups, 196 | base_width=self.base_width, dilation=self.dilation, 197 | norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x): 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | 207 | x = self.layer1(x) 208 | x = self.layer2(x) 209 | x = self.layer3(x) 210 | x = self.layer4(x) 211 | 212 | x = self.avgpool(x) 213 | x = x.reshape(x.size(0), -1) 214 | x = self.fc(x) 215 | 216 | return x 217 | 218 | 219 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 220 | model = ResNet(block, layers, **kwargs) 221 | if pretrained: 222 | state_dict = load_state_dict_from_url(model_urls[arch], 223 | progress=progress) 224 | model.load_state_dict(state_dict) 225 | return model 226 | 227 | 228 | def resnet18(pretrained=False, progress=True, **kwargs): 229 | """Constructs a ResNet-18 model. 
230 | Args: 231 | pretrained (bool): If True, returns a model pre-trained on ImageNet 232 | progress (bool): If True, displays a progress bar of the download to stderr 233 | """ 234 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 235 | **kwargs) 236 | 237 | 238 | def resnet34(pretrained=False, progress=True, **kwargs): 239 | """Constructs a ResNet-34 model. 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet50(pretrained=False, progress=True, **kwargs): 249 | """Constructs a ResNet-50 model. 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet101(pretrained=False, progress=True, **kwargs): 259 | """Constructs a ResNet-101 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | progress (bool): If True, displays a progress bar of the download to stderr 263 | """ 264 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 265 | **kwargs) 266 | 267 | 268 | def resnet152(pretrained=False, progress=True, **kwargs): 269 | """Constructs a ResNet-152 model. 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 279 | """Constructs a ResNeXt-50 32x4d model. 
280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | kwargs['groups'] = 32 285 | kwargs['width_per_group'] = 4 286 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 287 | pretrained, progress, **kwargs) 288 | 289 | 290 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 291 | """Constructs a ResNeXt-101 32x8d model. 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | progress (bool): If True, displays a progress bar of the download to stderr 295 | """ 296 | kwargs['groups'] = 32 297 | kwargs['width_per_group'] = 8 298 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 299 | pretrained, progress, **kwargs) 300 | -------------------------------------------------------------------------------- /models/resnet_cbam.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @version: python3.6 5 | @author: ikkyu-wen 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-09-17 12:39 8 | """ 9 | import torch.nn as nn 10 | import math 11 | try: 12 | from torch.hub import load_state_dict_from_url 13 | except ImportError: 14 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 15 | import torch 16 | 17 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 18 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 19 | 20 | 21 | model_urls = { 22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | 
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 28 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 29 | } 30 | 31 | 32 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 35 | padding=dilation, groups=groups, bias=False, dilation=dilation) 36 | 37 | 38 | def conv1x1(in_planes, out_planes, stride=1): 39 | """1x1 convolution""" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 41 | 42 | 43 | 44 | class ChannelAttention(nn.Module): 45 | def __init__(self, in_planes, ratio=16): 46 | super(ChannelAttention, self).__init__() 47 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 48 | self.max_pool = nn.AdaptiveMaxPool2d(1) 49 | 50 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 51 | self.relu1 = nn.ReLU() 52 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 53 | 54 | self.sigmoid = nn.Sigmoid() 55 | 56 | def forward(self, x): 57 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 58 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 59 | out = avg_out + max_out 60 | return self.sigmoid(out) 61 | 62 | class SpatialAttention(nn.Module): 63 | def __init__(self, kernel_size=7): 64 | super(SpatialAttention, self).__init__() 65 | 66 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 67 | padding = 3 if kernel_size == 7 else 1 68 | 69 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 70 | self.sigmoid = nn.Sigmoid() 71 | 72 | def forward(self, x): 73 | avg_out = torch.mean(x, dim=1, keepdim=True) 74 | max_out, _ = torch.max(x, dim=1, keepdim=True) 75 | x = torch.cat([avg_out, max_out], dim=1) 76 | x = self.conv1(x) 77 | return self.sigmoid(x) 78 | 79 | 80 | class BasicBlock(nn.Module): 81 | expansion = 1 82 | 83 | def __init__(self, inplanes, planes, 
stride=1, downsample=None, groups=1, 84 | base_width=64, dilation=1, norm_layer=None): 85 | super(BasicBlock, self).__init__() 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | if groups != 1 or base_width != 64: 89 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 90 | if dilation > 1: 91 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 92 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 93 | self.conv1 = conv3x3(inplanes, planes, stride) 94 | self.bn1 = norm_layer(planes) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.conv2 = conv3x3(planes, planes) 97 | self.bn2 = norm_layer(planes) 98 | 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | def forward(self, x): 103 | identity = x 104 | 105 | out = self.conv1(x) 106 | out = self.bn1(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv2(out) 110 | out = self.bn2(out) 111 | 112 | if self.downsample is not None: 113 | identity = self.downsample(x) 114 | 115 | out += identity 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class Bottleneck(nn.Module): 122 | expansion = 4 123 | 124 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 125 | base_width=64, dilation=1, norm_layer=None): 126 | super(Bottleneck, self).__init__() 127 | if norm_layer is None: 128 | norm_layer = nn.BatchNorm2d 129 | width = int(planes * (base_width / 64.)) * groups 130 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 131 | self.conv1 = conv1x1(inplanes, width) 132 | self.bn1 = norm_layer(width) 133 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 134 | self.bn2 = norm_layer(width) 135 | self.conv3 = conv1x1(width, planes * self.expansion) 136 | self.bn3 = norm_layer(planes * self.expansion) 137 | self.relu = nn.ReLU(inplace=True) 138 | 139 | self.downsample = downsample 140 | self.stride = stride 141 | 142 | def forward(self, 
x): 143 | identity = x 144 | 145 | out = self.conv1(x) 146 | out = self.bn1(out) 147 | out = self.relu(out) 148 | 149 | out = self.conv2(out) 150 | out = self.bn2(out) 151 | out = self.relu(out) 152 | 153 | out = self.conv3(out) 154 | out = self.bn3(out) 155 | 156 | if self.downsample is not None: 157 | identity = self.downsample(x) 158 | 159 | out += identity 160 | out = self.relu(out) 161 | 162 | return out 163 | 164 | 165 | class ResNet(nn.Module): 166 | 167 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 168 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 169 | norm_layer=None): 170 | super(ResNet, self).__init__() 171 | if norm_layer is None: 172 | norm_layer = nn.BatchNorm2d 173 | self._norm_layer = norm_layer 174 | 175 | self.inplanes = 64 176 | self.dilation = 1 177 | if replace_stride_with_dilation is None: 178 | # each element in the tuple indicates if we should replace 179 | # the 2x2 stride with a dilated convolution instead 180 | replace_stride_with_dilation = [False, False, False] 181 | if len(replace_stride_with_dilation) != 3: 182 | raise ValueError("replace_stride_with_dilation should be None " 183 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 184 | self.groups = groups 185 | self.base_width = width_per_group 186 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 187 | bias=False) 188 | self.bn1 = norm_layer(self.inplanes) 189 | self.relu = nn.ReLU(inplace=True) 190 | 191 | self.ca = ChannelAttention(self.inplanes) 192 | self.sa = SpatialAttention() 193 | 194 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 195 | self.layer1 = self._make_layer(block, 64, layers[0]) 196 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 197 | dilate=replace_stride_with_dilation[0]) 198 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 199 | dilate=replace_stride_with_dilation[1]) 200 | self.layer4 = 
self._make_layer(block, 512, layers[3], stride=2, 201 | dilate=replace_stride_with_dilation[2]) 202 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 203 | self.fc = nn.Linear(512 * block.expansion, num_classes) 204 | 205 | for m in self.modules(): 206 | if isinstance(m, nn.Conv2d): 207 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 208 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 209 | nn.init.constant_(m.weight, 1) 210 | nn.init.constant_(m.bias, 0) 211 | 212 | # Zero-initialize the last BN in each residual branch, 213 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 214 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 215 | if zero_init_residual: 216 | for m in self.modules(): 217 | if isinstance(m, Bottleneck): 218 | nn.init.constant_(m.bn3.weight, 0) 219 | elif isinstance(m, BasicBlock): 220 | nn.init.constant_(m.bn2.weight, 0) 221 | 222 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 223 | norm_layer = self._norm_layer 224 | downsample = None 225 | previous_dilation = self.dilation 226 | if dilate: 227 | self.dilation *= stride 228 | stride = 1 229 | if stride != 1 or self.inplanes != planes * block.expansion: 230 | downsample = nn.Sequential( 231 | conv1x1(self.inplanes, planes * block.expansion, stride), 232 | norm_layer(planes * block.expansion), 233 | ) 234 | 235 | layers = [] 236 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 237 | self.base_width, previous_dilation, norm_layer)) 238 | self.inplanes = planes * block.expansion 239 | for _ in range(1, blocks): 240 | layers.append(block(self.inplanes, planes, groups=self.groups, 241 | base_width=self.base_width, dilation=self.dilation, 242 | norm_layer=norm_layer)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def forward(self, x): 247 | x = self.conv1(x) 248 | x = self.bn1(x) 249 | x = self.relu(x) 250 | x = self.ca(x) * x 251 
| x = self.sa(x) * x 252 | x = self.maxpool(x) 253 | 254 | x = self.layer1(x) 255 | x = self.layer2(x) 256 | x = self.layer3(x) 257 | x = self.layer4(x) 258 | 259 | x = self.avgpool(x) 260 | x = x.reshape(x.size(0), -1) 261 | x = self.fc(x) 262 | 263 | return x 264 | 265 | 266 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 267 | model = ResNet(block, layers, **kwargs) 268 | if pretrained: 269 | state_dict = load_state_dict_from_url(model_urls[arch], 270 | progress=progress) 271 | new_state_dict = model.state_dict() 272 | new_state_dict.update(state_dict) 273 | model.load_state_dict(new_state_dict) 274 | return model 275 | 276 | 277 | 278 | def resnet18(pretrained=False, progress=True, **kwargs): 279 | """Constructs a ResNet-18 model. 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 285 | **kwargs) 286 | 287 | 288 | def resnet34(pretrained=False, progress=True, **kwargs): 289 | """Constructs a ResNet-34 model. 290 | Args: 291 | pretrained (bool): If True, returns a model pre-trained on ImageNet 292 | progress (bool): If True, displays a progress bar of the download to stderr 293 | """ 294 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 295 | **kwargs) 296 | 297 | 298 | def resnet50(pretrained=False, progress=True, **kwargs): 299 | """Constructs a ResNet-50 model. 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 305 | **kwargs) 306 | 307 | 308 | def resnet101(pretrained=False, progress=True, **kwargs): 309 | """Constructs a ResNet-101 model. 
310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | progress (bool): If True, displays a progress bar of the download to stderr 313 | """ 314 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 315 | **kwargs) 316 | 317 | 318 | def resnet152(pretrained=False, progress=True, **kwargs): 319 | """Constructs a ResNet-152 model. 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | progress (bool): If True, displays a progress bar of the download to stderr 323 | """ 324 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 325 | **kwargs) 326 | 327 | 328 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 329 | """Constructs a ResNeXt-50 32x4d model. 330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | progress (bool): If True, displays a progress bar of the download to stderr 333 | """ 334 | kwargs['groups'] = 32 335 | kwargs['width_per_group'] = 4 336 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 337 | pretrained, progress, **kwargs) 338 | 339 | 340 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 341 | """Constructs a ResNeXt-101 32x8d model. 
342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | progress (bool): If True, displays a progress bar of the download to stderr 345 | """ 346 | kwargs['groups'] = 32 347 | kwargs['width_per_group'] = 8 348 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 349 | pretrained, progress, **kwargs) 350 | -------------------------------------------------------------------------------- /predict/resnet_cbam.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @version: python3.6 5 | @author: ikkyu-wen 6 | @contact: wenruichn@gmail.com 7 | @time: 2019-09-17 12:39 8 | """ 9 | import torch.nn as nn 10 | import math 11 | try: 12 | from torch.hub import load_state_dict_from_url 13 | except ImportError: 14 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 15 | import torch 16 | 17 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 18 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 19 | 20 | 21 | model_urls = { 22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 28 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 29 | } 30 | 31 | 32 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 35 | padding=dilation, groups=groups, bias=False, dilation=dilation) 36 | 37 | 38 | def conv1x1(in_planes, 
out_planes, stride=1): 39 | """1x1 convolution""" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 41 | 42 | 43 | 44 | class ChannelAttention(nn.Module): 45 | def __init__(self, in_planes, ratio=16): 46 | super(ChannelAttention, self).__init__() 47 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 48 | self.max_pool = nn.AdaptiveMaxPool2d(1) 49 | 50 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 51 | self.relu1 = nn.ReLU() 52 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 53 | 54 | self.sigmoid = nn.Sigmoid() 55 | 56 | def forward(self, x): 57 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 58 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 59 | out = avg_out + max_out 60 | return self.sigmoid(out) 61 | 62 | class SpatialAttention(nn.Module): 63 | def __init__(self, kernel_size=7): 64 | super(SpatialAttention, self).__init__() 65 | 66 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 67 | padding = 3 if kernel_size == 7 else 1 68 | 69 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 70 | self.sigmoid = nn.Sigmoid() 71 | 72 | def forward(self, x): 73 | avg_out = torch.mean(x, dim=1, keepdim=True) 74 | max_out, _ = torch.max(x, dim=1, keepdim=True) 75 | x = torch.cat([avg_out, max_out], dim=1) 76 | x = self.conv1(x) 77 | return self.sigmoid(x) 78 | 79 | 80 | class BasicBlock(nn.Module): 81 | expansion = 1 82 | 83 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 84 | base_width=64, dilation=1, norm_layer=None): 85 | super(BasicBlock, self).__init__() 86 | if norm_layer is None: 87 | norm_layer = nn.BatchNorm2d 88 | if groups != 1 or base_width != 64: 89 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 90 | if dilation > 1: 91 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 92 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 93 | 
self.conv1 = conv3x3(inplanes, planes, stride) 94 | self.bn1 = norm_layer(planes) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.conv2 = conv3x3(planes, planes) 97 | self.bn2 = norm_layer(planes) 98 | 99 | self.downsample = downsample 100 | self.stride = stride 101 | 102 | def forward(self, x): 103 | identity = x 104 | 105 | out = self.conv1(x) 106 | out = self.bn1(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv2(out) 110 | out = self.bn2(out) 111 | 112 | if self.downsample is not None: 113 | identity = self.downsample(x) 114 | 115 | out += identity 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class Bottleneck(nn.Module): 122 | expansion = 4 123 | 124 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 125 | base_width=64, dilation=1, norm_layer=None): 126 | super(Bottleneck, self).__init__() 127 | if norm_layer is None: 128 | norm_layer = nn.BatchNorm2d 129 | width = int(planes * (base_width / 64.)) * groups 130 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 131 | self.conv1 = conv1x1(inplanes, width) 132 | self.bn1 = norm_layer(width) 133 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 134 | self.bn2 = norm_layer(width) 135 | self.conv3 = conv1x1(width, planes * self.expansion) 136 | self.bn3 = norm_layer(planes * self.expansion) 137 | self.relu = nn.ReLU(inplace=True) 138 | 139 | self.downsample = downsample 140 | self.stride = stride 141 | 142 | def forward(self, x): 143 | identity = x 144 | 145 | out = self.conv1(x) 146 | out = self.bn1(out) 147 | out = self.relu(out) 148 | 149 | out = self.conv2(out) 150 | out = self.bn2(out) 151 | out = self.relu(out) 152 | 153 | out = self.conv3(out) 154 | out = self.bn3(out) 155 | 156 | if self.downsample is not None: 157 | identity = self.downsample(x) 158 | 159 | out += identity 160 | out = self.relu(out) 161 | 162 | return out 163 | 164 | 165 | class ResNet(nn.Module): 166 | 167 | def __init__(self, block, 
layers, num_classes=1000, zero_init_residual=False, 168 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 169 | norm_layer=None): 170 | super(ResNet, self).__init__() 171 | if norm_layer is None: 172 | norm_layer = nn.BatchNorm2d 173 | self._norm_layer = norm_layer 174 | 175 | self.inplanes = 64 176 | self.dilation = 1 177 | if replace_stride_with_dilation is None: 178 | # each element in the tuple indicates if we should replace 179 | # the 2x2 stride with a dilated convolution instead 180 | replace_stride_with_dilation = [False, False, False] 181 | if len(replace_stride_with_dilation) != 3: 182 | raise ValueError("replace_stride_with_dilation should be None " 183 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 184 | self.groups = groups 185 | self.base_width = width_per_group 186 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 187 | bias=False) 188 | self.bn1 = norm_layer(self.inplanes) 189 | self.relu = nn.ReLU(inplace=True) 190 | 191 | self.ca = ChannelAttention(self.inplanes) 192 | self.sa = SpatialAttention() 193 | 194 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 195 | self.layer1 = self._make_layer(block, 64, layers[0]) 196 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 197 | dilate=replace_stride_with_dilation[0]) 198 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 199 | dilate=replace_stride_with_dilation[1]) 200 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 201 | dilate=replace_stride_with_dilation[2]) 202 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 203 | self.fc = nn.Linear(512 * block.expansion, num_classes) 204 | 205 | for m in self.modules(): 206 | if isinstance(m, nn.Conv2d): 207 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 208 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 209 | nn.init.constant_(m.weight, 1) 210 | nn.init.constant_(m.bias, 0) 211 | 212 | # 
Zero-initialize the last BN in each residual branch, 213 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 214 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 215 | if zero_init_residual: 216 | for m in self.modules(): 217 | if isinstance(m, Bottleneck): 218 | nn.init.constant_(m.bn3.weight, 0) 219 | elif isinstance(m, BasicBlock): 220 | nn.init.constant_(m.bn2.weight, 0) 221 | 222 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 223 | norm_layer = self._norm_layer 224 | downsample = None 225 | previous_dilation = self.dilation 226 | if dilate: 227 | self.dilation *= stride 228 | stride = 1 229 | if stride != 1 or self.inplanes != planes * block.expansion: 230 | downsample = nn.Sequential( 231 | conv1x1(self.inplanes, planes * block.expansion, stride), 232 | norm_layer(planes * block.expansion), 233 | ) 234 | 235 | layers = [] 236 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 237 | self.base_width, previous_dilation, norm_layer)) 238 | self.inplanes = planes * block.expansion 239 | for _ in range(1, blocks): 240 | layers.append(block(self.inplanes, planes, groups=self.groups, 241 | base_width=self.base_width, dilation=self.dilation, 242 | norm_layer=norm_layer)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def forward(self, x): 247 | x = self.conv1(x) 248 | x = self.bn1(x) 249 | x = self.relu(x) 250 | x = self.ca(x) * x 251 | x = self.sa(x) * x 252 | x = self.maxpool(x) 253 | 254 | x = self.layer1(x) 255 | x = self.layer2(x) 256 | x = self.layer3(x) 257 | x = self.layer4(x) 258 | 259 | x = self.avgpool(x) 260 | x = x.reshape(x.size(0), -1) 261 | x = self.fc(x) 262 | 263 | return x 264 | 265 | 266 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 267 | model = ResNet(block, layers, **kwargs) 268 | if pretrained: 269 | state_dict = load_state_dict_from_url(model_urls[arch], 270 | progress=progress) 
271 | new_state_dict = model.state_dict() 272 | new_state_dict.update(state_dict) 273 | model.load_state_dict(new_state_dict) 274 | return model 275 | 276 | 277 | 278 | def resnet18(pretrained=False, progress=True, **kwargs): 279 | """Constructs a ResNet-18 model. 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 285 | **kwargs) 286 | 287 | 288 | def resnet34(pretrained=False, progress=True, **kwargs): 289 | """Constructs a ResNet-34 model. 290 | Args: 291 | pretrained (bool): If True, returns a model pre-trained on ImageNet 292 | progress (bool): If True, displays a progress bar of the download to stderr 293 | """ 294 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 295 | **kwargs) 296 | 297 | 298 | def resnet50(pretrained=False, progress=True, **kwargs): 299 | """Constructs a ResNet-50 model. 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 305 | **kwargs) 306 | 307 | 308 | def resnet101(pretrained=False, progress=True, **kwargs): 309 | """Constructs a ResNet-101 model. 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | progress (bool): If True, displays a progress bar of the download to stderr 313 | """ 314 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 315 | **kwargs) 316 | 317 | 318 | def resnet152(pretrained=False, progress=True, **kwargs): 319 | """Constructs a ResNet-152 model. 
320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | progress (bool): If True, displays a progress bar of the download to stderr 323 | """ 324 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 325 | **kwargs) 326 | 327 | 328 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 329 | """Constructs a ResNeXt-50 32x4d model. 330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | progress (bool): If True, displays a progress bar of the download to stderr 333 | """ 334 | kwargs['groups'] = 32 335 | kwargs['width_per_group'] = 4 336 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 337 | pretrained, progress, **kwargs) 338 | 339 | 340 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 341 | """Constructs a ResNeXt-101 32x8d model. 342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | progress (bool): If True, displays a progress bar of the download to stderr 345 | """ 346 | kwargs['groups'] = 32 347 | kwargs['width_per_group'] = 8 348 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 349 | pretrained, progress, **kwargs) 350 | --------------------------------------------------------------------------------