├── utils ├── progress │ ├── MANIFEST.in │ ├── demo.gif │ ├── LICENSE │ ├── setup.py │ ├── progress │ │ ├── spinner.py │ │ ├── counter.py │ │ ├── bar.py │ │ ├── helpers.py │ │ └── __init__.py │ ├── test_progress.py │ └── README.rst ├── IBNNet.png ├── images │ ├── cifar.png │ └── imagenet.png ├── __init__.py ├── eval.py ├── misc.py ├── visualize.py └── logger.py ├── hubconf.py ├── test.sh ├── train.sh ├── ibnnet ├── __init__.py ├── modules.py ├── se_resnet_ibn.py ├── resnext_ibn.py ├── densenet_ibn.py └── resnet_ibn.py ├── LICENSE ├── README.md └── imagenet.py /utils/progress/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst LICENSE 2 | -------------------------------------------------------------------------------- /utils/IBNNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/IBN-Net/HEAD/utils/IBNNet.png -------------------------------------------------------------------------------- /utils/images/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/IBN-Net/HEAD/utils/images/cifar.png -------------------------------------------------------------------------------- /utils/progress/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/IBN-Net/HEAD/utils/progress/demo.gif -------------------------------------------------------------------------------- /utils/images/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/IBN-Net/HEAD/utils/images/imagenet.png -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch'] 2 | from ibnnet 
import resnet18_ibn_a, resnet34_ibn_a, resnet50_ibn_a, resnet101_ibn_a, \ 3 | resnet18_ibn_b, resnet34_ibn_b, resnet50_ibn_b, resnet101_ibn_b, \ 4 | resnext101_ibn_a, se_resnet101_ibn_a 5 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | MODEL=resnet50_ibn_a 4 | DATA_PATH=/pathToYourImageNetDataset/ 5 | 6 | python -u imagenet.py \ 7 | -a $MODEL \ 8 | --data $DATA_PATH \ 9 | --pretrained \ 10 | --test-batch 100 \ 11 | -e \ 12 | -j 16 \ 13 | --gpu_id 0,1 14 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | #from .logger import * 5 | #from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | from progress.bar import Bar as Bar 12 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | MODEL=resnet50_ibn_a 4 | DATA_PATH=/pathToYourImageNetDataset/ 5 | EXP_DIR=exp/$MODEL 6 | mkdir -p $EXP_DIR 7 | 8 | python -u imagenet.py \ 9 | -a $MODEL \ 10 | -j 32 \ 11 | --data $DATA_PATH \ 12 | --train-batch 256 \ 13 | --test-batch 100 \ 14 | --lr 0.1 \ 15 | --epochs 100 \ 16 | -c exp/${MODEL} \ 17 | --gpu_id 0,1,2,3,4,5,6,7 \ 18 | 2>&1 | tee exp/${MODEL}/log.txt 19 | -------------------------------------------------------------------------------- /ibnnet/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet_ibn import resnet18_ibn_a, resnet34_ibn_a, resnet50_ibn_a, resnet101_ibn_a, 
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: 2-D score tensor with one row per sample; the k best
            entries are taken along dim 1.
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: iterable of the k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``; each holds
        the percentage of samples whose target is among the k top scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best scores per sample, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` keeps the transposed (non-contiguous) layout of `pred`,
        # so flattening several rows with view() raises; reshape() copies
        # only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
#!/usr/bin/env python
"""Packaging script for the vendored ``progress`` library."""

from setuptools import setup

import progress


# Read the long description through a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
with open('README.rst') as readme:
    long_description = readme.read()

setup(
    name='progress',
    version=progress.__version__,
    description='Easy to use progress bars',
    long_description=long_description,
    author='Giorgos Verigakis',
    author_email='verigak@gmail.com',
    url='http://github.com/verigak/progress/',
    license='ISC',
    packages=['progress'],
    classifiers=[
        'Environment :: Console',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: ISC License (ISCL)',
        'Programming Language :: Python :: 2.6',
        'Programming Language :: Python :: 2.7',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
    ]
)
persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/progress/progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
class IBN(nn.Module):
    r"""Instance-Batch Normalization layer from
    "Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net".

    The first ``int(planes * ratio)`` channels go through instance
    normalization, the remaining channels through batch normalization, and
    the two results are concatenated back along the channel dimension.

    Args:
        planes (int): Number of channels for the input tensor
        ratio (float): Ratio of instance normalization in the IBN layer
    """
    def __init__(self, planes, ratio=0.5):
        super(IBN, self).__init__()
        self.half = int(planes * ratio)
        self.IN = nn.InstanceNorm2d(self.half, affine=True)
        self.BN = nn.BatchNorm2d(planes - self.half)

    def forward(self, x):
        # Split by explicit section sizes. Passing the int `self.half` would
        # cut equal chunks of that size, which only matches the BN width when
        # the IN part is exactly half of the channels.
        in_part, bn_part = torch.split(
            x, [self.half, x.size(1) - self.half], 1)
        out1 = self.IN(in_part.contiguous())
        out2 = self.BN(bn_part.contiguous())
        out = torch.cat((out1, out2), 1)
        return out


class SELayer(nn.Module):
    """Channel re-weighting layer: global average pool, two bias-free linear
    layers with a ReLU in between, and a sigmoid gate multiplied onto the input.

    Args:
        channel (int): Number of channels of the input tensor
        reduction (int): Divisor applied to ``channel`` to get the width of
            the hidden linear layer
    """
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, int(channel/reduction), bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(int(channel/reduction), channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # One pooled scalar per (sample, channel), then one gate per channel.
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
class Counter(WriteMixin, Infinite):
    """Writes the current ``index`` as text on every update."""
    message = ''
    hide_cursor = True

    def update(self):
        current = str(self.index)
        self.write(current)


class Countdown(WriteMixin, Progress):
    """Writes the ``remaining`` value as text on every update."""
    hide_cursor = True

    def update(self):
        left = str(self.remaining)
        self.write(left)


class Stack(WriteMixin, Progress):
    """Writes one glyph from ``phases``, chosen by the current ``progress``."""
    phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')
    hide_cursor = True

    def update(self):
        count = len(self.phases)
        position = int(self.progress * count)
        # Clamp so that progress == 1 maps onto the last glyph.
        if position > count - 1:
            position = count - 1
        self.write(self.phases[position])


class Pie(Stack):
    """Same behaviour as ``Stack`` with circle glyphs."""
    phases = ('○', '◔', '◑', '◕', '●')
def get_mean_and_std(dataset):
    '''Compute per-channel mean and std statistics of a 3-channel dataset.

    The dataset is read one sample at a time; the per-sample channel mean
    and std are accumulated and divided by ``len(dataset)``. The returned
    std is therefore the average of per-sample stds, not the std over all
    pixels of the dataset.

    Returns:
        (mean, std): two tensors of 3 elements each.
    '''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.

    Conv2d weights get Kaiming-normal (fan_out) init, BatchNorm2d is set to
    weight 1 / bias 0, Linear weights are drawn from N(0, 1e-3); every bias
    that exists is zeroed.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Test for presence, not truthiness: bool() of a tensor with
            # more than one element raises.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)

def mkdir_p(path):
    '''Create ``path`` including missing parents; do nothing if it already
    exists as a directory. Any other OSError is re-raised.'''
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise

class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        '''Zero the last value, the weighted sum, the count and the average.'''
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        '''Record ``val`` as the latest value, weighted by ``n`` samples.'''
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class Bar(WritelnMixin, Progress):
    """Single-line bar: message, prefix, filled part, empty part, suffix."""
    width = 1 #32
    message = ''
    suffix = '%(index)d/%(max)d'
    bar_prefix = ' |'
    bar_suffix = '| '
    empty_fill = ' '
    fill = '#'
    hide_cursor = True

    def update(self):
        done = int(self.width * self.progress)
        todo = self.width - done

        # message and suffix are %-templates resolved against this object
        pieces = (
            self.message % self,
            self.bar_prefix,
            self.fill * done,
            self.empty_fill * todo,
            self.bar_suffix,
            self.suffix % self,
        )
        self.writeln(''.join(pieces))


class ChargingBar(Bar):
    """Bar variant with block glyphs and a percentage suffix."""
    suffix = '%(percent)d%%'
    bar_prefix = ' '
    bar_suffix = ' '
    empty_fill = '∙'
    fill = '█'


class FillingSquaresBar(ChargingBar):
    """ChargingBar drawn with square glyphs."""
    empty_fill = '▢'
    fill = '▣'


class FillingCirclesBar(ChargingBar):
    """ChargingBar drawn with circle glyphs."""
    empty_fill = '◯'
    fill = '◉'


class IncrementalBar(Bar):
    """Bar whose last filled cell is drawn with a partial-fill glyph."""
    phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')

    def update(self):
        steps = len(self.phases)
        exact = self.width * self.progress
        whole = int(exact)                        # completely filled cells
        partial = int((exact - whole) * steps)    # glyph index of the tip cell
        blanks = self.width - whole               # cells not completely filled

        message = self.message % self
        full_cells = self.phases[-1] * whole
        if partial > 0:
            tip = self.phases[partial]
        else:
            tip = ''
        padding = self.empty_fill * max(0, blanks - len(tip))
        suffix = self.suffix % self
        self.writeln(''.join([message, self.bar_prefix, full_cells, tip,
                              padding, self.bar_suffix, suffix]))


class PixelBar(IncrementalBar):
    """IncrementalBar drawn with braille glyphs."""
    phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿')


class ShadyBar(IncrementalBar):
    """IncrementalBar drawn with shade glyphs."""
    phases = (' ', '░', '▒', '▓', '█')
-------------------------------------------------------------------------------- /utils/progress/progress/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
HIDE_CURSOR = '\x1b[?25l'
SHOW_CURSOR = '\x1b[?25h'


class WriteMixin(object):
    """Mixin that rewrites a short status string in place with backspaces.

    Nothing is written unless ``self.file`` reports itself as a TTY.
    """
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WriteMixin, self).__init__(**kwargs)
        self._width = 0
        if message:
            self.message = message

        stream = self.file
        if not stream.isatty():
            return
        if self.hide_cursor:
            print(HIDE_CURSOR, end='', file=stream)
        print(self.message, end='', file=stream)
        stream.flush()

    def write(self, s):
        stream = self.file
        if not stream.isatty():
            return
        # Back up over the previous text, then pad so a shorter string
        # fully overwrites a longer one.
        erase = '\b' * self._width
        padded = s.ljust(self._width)
        print(erase + padded, end='', file=stream)
        self._width = max(self._width, len(s))
        stream.flush()

    def finish(self):
        if not self.file.isatty():
            return
        if self.hide_cursor:
            print(SHOW_CURSOR, end='', file=self.file)


class WritelnMixin(object):
    """Mixin that redraws a whole line by clearing it first.

    Nothing is written unless ``self.file`` reports itself as a TTY.
    """
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WritelnMixin, self).__init__(**kwargs)
        if message:
            self.message = message

        if not self.file.isatty():
            return
        if self.hide_cursor:
            print(HIDE_CURSOR, end='', file=self.file)

    def clearln(self):
        if not self.file.isatty():
            return
        # carriage return followed by the "erase to end of line" sequence
        print('\r\x1b[K', end='', file=self.file)

    def writeln(self, line):
        stream = self.file
        if not stream.isatty():
            return
        self.clearln()
        print(line, end='', file=stream)
        stream.flush()

    def finish(self):
        stream = self.file
        if not stream.isatty():
            return
        print(file=stream)
        if self.hide_cursor:
            print(SHOW_CURSOR, end='', file=stream)


from signal import signal, SIGINT
from sys import exit


class SigIntMixin(object):
    """Registers a signal handler that calls finish on SIGINT"""

    def __init__(self, *args, **kwargs):
        super(SigIntMixin, self).__init__(*args, **kwargs)
        signal(SIGINT, self._sigint_handler)

    def _sigint_handler(self, signum, frame):
        # Let the widget clean up (finish), then exit with status 0.
        self.finish()
        exit(0)
code-block:: python 53 | 54 | bar = Bar('Loading', fill='@', suffix='%(percent)d%%') 55 | 56 | This will produce a bar like the following: :: 57 | 58 | Loading |@@@@@@@@@@@@@ | 42% 59 | 60 | You can use a number of template arguments in ``message`` and ``suffix``: 61 | 62 | ========== ================================ 63 | Name Value 64 | ========== ================================ 65 | index current value 66 | max maximum value 67 | remaining max - index 68 | progress index / max 69 | percent progress * 100 70 | avg simple moving average time per item (in seconds) 71 | elapsed elapsed time in seconds 72 | elapsed_td elapsed as a timedelta (useful for printing as a string) 73 | eta avg * remaining 74 | eta_td eta as a timedelta (useful for printing as a string) 75 | ========== ================================ 76 | 77 | Instead of passing all configuration options on instantiation, you can create 78 | your custom subclass: 79 | 80 | .. code-block:: python 81 | 82 | class FancyBar(Bar): 83 | message = 'Loading' 84 | fill = '*' 85 | suffix = '%(percent).1f%% - %(eta)ds' 86 | 87 | You can also override any of the arguments or create your own: 88 | 89 | .. code-block:: python 90 | 91 | class SlowBar(Bar): 92 | suffix = '%(remaining_hours)d hours remaining' 93 | @property 94 | def remaining_hours(self): 95 | return self.eta // 3600 96 | 97 | 98 | Spinners 99 | ======== 100 | 101 | For actions with an unknown number of steps you can use a spinner: 102 | 103 | ..
code-block:: python 104 | 105 | from progress.spinner import Spinner 106 | 107 | spinner = Spinner('Loading ') 108 | while state != 'FINISHED': 109 | # Do some work 110 | spinner.next() 111 | 112 | There are 5 predefined spinners: 113 | 114 | - ``Spinner`` 115 | - ``PieSpinner`` 116 | - ``MoonSpinner`` 117 | - ``LineSpinner`` 118 | - ``PixelSpinner`` 119 | 120 | 121 | Other 122 | ===== 123 | 124 | There are a number of other classes available too, please check the source or 125 | subclass one of them to create your own. 126 | 127 | 128 | License 129 | ======= 130 | 131 | progress is licensed under ISC 132 | -------------------------------------------------------------------------------- /utils/progress/progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
class Infinite(object):
    """Base progress tracker for work with no known total.

    Keeps the current ``index``, the start timestamp and a simple moving
    average (``avg``) of seconds per item. Every keyword argument passed to
    the constructor is stored as an instance attribute.
    """
    file = stderr
    sma_window = 10  # Simple Moving Average window

    def __init__(self, *args, **kwargs):
        self.index = 0
        self.start_ts = time()
        self.avg = 0
        self._ts = self.start_ts
        self._xput = deque(maxlen=self.sma_window)
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __getitem__(self, key):
        # Lets ``'%(name)s' % self`` templates read public attributes;
        # private names and unknown keys resolve to None.
        if key.startswith('_'):
            return None
        return getattr(self, key, None)

    @property
    def elapsed(self):
        """Whole seconds since construction."""
        return int(time() - self.start_ts)

    @property
    def elapsed_td(self):
        """``elapsed`` as a timedelta."""
        return timedelta(seconds=self.elapsed)

    def update_avg(self, n, dt):
        """Fold ``dt`` seconds spent on ``n`` items into the moving average."""
        if n > 0:
            self._xput.append(dt / n)
            self.avg = sum(self._xput) / len(self._xput)

    def update(self):
        """Redraw hook; a no-op here."""
        pass

    def start(self):
        pass

    def finish(self):
        pass

    def next(self, n=1):
        """Advance ``index`` by ``n``, refresh the average and redraw."""
        now = time()
        dt = now - self._ts
        self.update_avg(n, dt)
        self._ts = now
        self.index = self.index + n
        self.update()

    def iter(self, it):
        """Yield the items of ``it``, advancing after each; always finish()."""
        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()


class Progress(Infinite):
    """Progress tracker for work with a known total ``max`` (default 100)."""

    def __init__(self, *args, **kwargs):
        super(Progress, self).__init__(*args, **kwargs)
        self.max = kwargs.get('max', 100)

    @property
    def eta(self):
        """Estimated whole seconds left: ``avg * remaining``, rounded up."""
        return int(ceil(self.avg * self.remaining))

    @property
    def eta_td(self):
        """``eta`` as a timedelta."""
        return timedelta(seconds=self.eta)

    @property
    def percent(self):
        return self.progress * 100

    @property
    def progress(self):
        """Completed fraction, capped at 1; 0 when ``max`` is 0."""
        # An empty workload (e.g. iter() over an empty sequence sets max to
        # 0) must not raise ZeroDivisionError when the bar is drawn.
        if self.max == 0:
            return 0
        return min(1, self.index / self.max)

    @property
    def remaining(self):
        return max(self.max - self.index, 0)

    def start(self):
        self.update()

    def goto(self, index):
        """Jump to an absolute ``index`` (may move backwards)."""
        incr = index - self.index
        self.next(incr)

    def iter(self, it):
        """Like ``Infinite.iter`` but first sets ``max`` to ``len(it)`` when
        the iterable has a length."""
        try:
            self.max = len(it)
        except TypeError:
            pass

        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()
def _unnormalized_batch(images, Mean, Std):
    """Return a clone of ``images`` with ``x * Std[c] + Mean[c]`` applied to channels 0-2."""
    im_data = images.clone()
    for c in range(3):
        im_data[:, c, :, :] = im_data[:, c, :, :] * Std[c] + Mean[c]  # unnormalize
    return im_data


def _blend_with_mask(im_data, mask, im_size):
    """Upsample ``mask`` to the image size and blend it 70/30 over ``im_data``.

    ``upsampling`` is not defined in this module; it is presumably provided
    by the star import from ``.misc`` -- TODO confirm.
    """
    mask_size = mask.size(2)
    mask = upsampling(mask, scale_factor=im_size / mask_size)
    blended = 0.3 * im_data + 0.7 * mask.expand_as(im_data)
    return make_image(torchvision.utils.make_grid(blended))


def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot an image grid (top) and the same grid blended with ``mask`` (bottom).

    Args:
        images: 4-D batch tensor; ``images.size(2)`` is used as the target size.
        mask: 4-D tensor that can be expanded to the shape of ``images``
            after upsampling.
        Mean, Std: per-channel statistics used to un-normalize ``images``.
    """
    im_size = images.size(2)

    # Keep an un-normalized copy for blending before ``images`` is rebound.
    im_data = _unnormalized_batch(images, Mean, Std)

    grid = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(2, 1, 1)
    plt.imshow(grid)
    plt.axis('off')

    overlay = _blend_with_mask(im_data, mask, im_size)
    plt.subplot(2, 1, 2)
    plt.imshow(overlay)
    plt.axis('off')


def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot an image grid followed by one blended grid per entry of ``masklist``.

    Args:
        images: 4-D batch tensor; ``images.size(2)`` is used as the target size.
        masklist: sequence of objects exposing ``.data``; each is moved to the
            CPU, upsampled and blended over the un-normalized images.
        Mean, Std: per-channel statistics used to un-normalize ``images``.
    """
    im_size = images.size(2)
    rows = 1 + len(masklist)

    # Keep an un-normalized copy for blending before ``images`` is rebound.
    im_data = _unnormalized_batch(images, Mean, Std)

    grid = make_image(torchvision.utils.make_grid(images), Mean, Std)
    plt.subplot(rows, 1, 1)
    plt.imshow(grid)
    plt.axis('off')

    for i, entry in enumerate(masklist):
        overlay = _blend_with_mask(im_data, entry.data.cpu(), im_size)
        plt.subplot(rows, 1, i + 2)
        plt.imshow(overlay)
        plt.axis('off')
def savefig(fname, dpi=None):
    """Save the current pyplot figure to ``fname`` (150 dpi unless given)."""
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)


def plot_overlap(logger, names=None):
    """Plot the selected series of ``logger`` on the current axes.

    Args:
        logger: object exposing ``names``, ``numbers`` and ``title``.
        names: series to draw; defaults to ``logger.names``.

    Returns:
        One legend label per plotted series, formatted ``title(name)``.
    """
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]


def _parse_logged_value(text):
    """Return ``text`` as a float, or unchanged when it is not numeric."""
    try:
        return float(text)
    except ValueError:
        return text


class Logger(object):
    '''Save training process to log file with simple plot function.'''

    def __init__(self, fpath, title=None, resume=False):
        """Open (or resume) a tab-separated log file.

        Args:
            fpath: path of the log file; ``None`` keeps the logger in memory.
            title: prefix used in plot legends.
            resume: when true, read the header and rows already in ``fpath``
                and reopen the file for appending instead of truncating it.
        """
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                # ``with`` closes the read handle even if parsing raises.
                with open(fpath, 'r') as existing:
                    header = existing.readline()
                    self.names = header.rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}
                    for row in existing:
                        row = row.rstrip()
                        if not row:
                            continue  # ignore blank lines
                        for i, value in enumerate(row.split('\t')):
                            # Store numbers, not strings, so a resumed log
                            # plots the same way as a freshly written one.
                            self.numbers[self.names[i]].append(
                                _parse_logged_value(value))
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        """Declare the column names and write the header line.

        On a resumed logger whose header was already loaded this is a no-op:
        rewriting it would discard the loaded history and put a second header
        in the middle of the file.
        """
        if self.resume and getattr(self, 'names', None) is not None:
            return
        # initialize numbers as empty list
        self.names = names
        self.numbers = {name: [] for name in names}
        if self.file is not None:
            self.file.write(''.join(name + '\t' for name in names) + '\n')
            self.file.flush()

    def append(self, numbers):
        """Record one row of values, one per declared column."""
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)
        if self.file is not None:
            self.file.write(
                ''.join("{0:.6f}".format(num) + '\t' for num in numbers) + '\n')
            self.file.flush()

    def plot(self, names=None):
        """Plot the selected series with a legend and a grid."""
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying file if one is open."""
        if self.file is not None:
            self.file.close()


class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''paths is a dictionary with {name:filepath} pair'''
        self.loggers = [Logger(path, title=title, resume=True)
                        for title, path in paths.items()]

    def plot(self, names=None):
        """Overlay the selected series of every loaded log in a new figure."""
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Instance-Batch Normalization Network 2 | 3 | ### Paper 4 | 5 | Xingang Pan, Ping Luo, Jianping Shi, Xiaoou Tang. ["Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"](https://arxiv.org/abs/1807.09441), ECCV2018. 6 | 7 | ### Introduction 8 | 9 | 10 | - IBN-Net is a CNN model with domain/appearance invariance. It carefully unifies instance normalization and batch normalization in a single deep network. 11 | - It provides a simple way to increase both modeling and generalization capacity without adding model complexity. 
12 | - IBN-Net is especially suitable for cross domain or person/vehicle re-identification tasks, see [michuanhaohao/reid-strong-baseline](https://github.com/michuanhaohao/reid-strong-baseline) and [strong baseline for ReID](https://arxiv.org/pdf/1906.08332.pdf) for more details. 13 | 14 | ### Requirements 15 | - Pytorch 0.4.1 or higher 16 | 17 | ### Results 18 | 19 | Top1/Top5 error on the ImageNet validation set are reported. You may get different results when training your models with different random seed. 20 | 21 | | Model | origin | re-implementation | IBN-Net | 22 | | ------------------- | ------------------ | ------------------ | ------------------ | 23 | | DenseNet-121 | 25.0/- | 24.96/7.85 | 24.47/7.25 [[pre-trained model]](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/densenet121_ibn_a-e4af5cc1.pth) | 24 | | DenseNet-169 | 23.6/- | 24.02/7.06 | 23.25/6.51 [[pre-trained model]](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/densenet169_ibn_a-9f32c161.pth) | 25 | | ResNet-18 | - | 30.24/10.92 | 29.17/10.24 [[pre-trained model]](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth) | 26 | | ResNet-34 | - | 26.70/8.58 | 25.78/8.19 [[pre-trained model]](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth) | 27 | | ResNet-50 | 24.7/7.8 | 24.27/7.08 | 22.54/6.32 [[pre-trained model]](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth) | 28 | | ResNet-101 | 23.6/7.1 | 22.48/6.23 | 21.39/5.59 [[pre-trained model]](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth) | 29 | | ResNeXt-101 | 21.2/5.6 | 21.31/5.74 | 20.88/5.42 [[pre-trained model]](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnext101_ibn_a-6ace051d.pth) | 30 | | SE-ResNet-101 | 22.38/6.07 | 21.68/5.88 | 21.25/5.51 [[pre-trained 
model]](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth) | 31 | 32 | The rank1/mAP on two Re-ID benchmarks Market1501 and DukeMTMC-reID (from [michuanhaohao/reid-strong-baseline](https://github.com/michuanhaohao/reid-strong-baseline)): 33 | 34 | | Backbone | Market1501 | DukeMTMC-reID | 35 | | --- | -- | -- | 36 | | ResNet50 | 94.5 (85.9) | 86.4 (76.4) | 37 | | ResNet101 | 94.5 (87.1) | 87.6 (77.6) | 38 | | SeResNet50 | 94.4 (86.3) | 86.4 (76.5) | 39 | | SeResNet101 | 94.6 (87.3) | 87.5 (78.0) | 40 | | SeResNeXt50 | 94.9 (87.6) | 88.0 (78.3) | 41 | | SeResNeXt101 | 95.0 (88.0) | 88.4 (79.0) | 42 | | IBN-Net-a | 95.0 (88.2) | 90.1 (79.1) | 43 | 44 | ### Load IBN-Net from torch.hub 45 | ```python 46 | import torch 47 | model = torch.hub.load('XingangPan/IBN-Net', 'resnet50_ibn_a', pretrained=True) 48 | ``` 49 | 50 | ### Testing/Training on ImageNet 51 | 1. Clone the repository 52 | ```Shell 53 | git clone https://github.com/XingangPan/IBN-Net.git 54 | ``` 55 | 56 | 2. Download [ImageNet](http://image-net.org/download-images) dataset (if you need to test or train on ImageNet). You may follow the instruction at [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch) to process the validation set. 57 | 58 | ### Testing 59 | 1. Edit `test.sh`. Modify `model` and `data_path` to yours. 60 | Options for `model`: resnet50_ibn_a, resnet50_ibn_b, resnet101_ibn_a, resnext101_ibn_a, se_resnet101_ibn_a, densenet121_ibn_a, densenet169_ibn_a. 61 | 62 | 2. Run test script 63 | ```Shell 64 | sh test.sh 65 | ``` 66 | 67 | ### Training 68 | 1. Edit `train.sh`. Modify `model` and `data_path` to yours. 69 | 2. Run train script 70 | ```Shell 71 | sh train.sh 72 | ``` 73 | 74 | ### Acknowledgement 75 | This code is developed based on [bearpaw/pytorch-classification](https://github.com/bearpaw/pytorch-classification). 
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 and the given stride."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class SEBottleneck_IBN(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) followed by an SE layer.

    Args:
        inplanes: number of input channels.
        planes: width of the two inner convolutions; the block outputs
            ``planes * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input before the
            residual addition.
        ibn: ``'a'`` replaces the first batch norm with ``IBN``.
        reduction: second argument forwarded to ``SELayer``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, ibn=None, reduction=16):
        super(SEBottleneck_IBN, self).__init__()
        wide = planes * 4

        # Sub-modules are created in the same order as before so parameter
        # names, registration order and random initialization are unchanged.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = IBN(planes) if ibn == 'a' else nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(wide, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block and return ``relu(se(main_branch(x)) + shortcut)``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.se(self.bn3(self.conv3(out)))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class ResNet_IBN(nn.Module):
    """ResNet-style backbone assembled from ``block`` with optional IBN stages.

    Args:
        block: residual block class exposing ``expansion`` and accepting
            ``(inplanes, planes, stride, downsample, ibn=...)``.
        layers: number of blocks in each of the four stages.
        ibn_cfg: per-stage IBN setting forwarded to every block of the stage.
        num_classes: size of the final linear layer.
    """

    def __init__(self,
                 block,
                 layers,
                 ibn_cfg=('a', 'a', 'a', None),
                 num_classes=1000):
        super(ResNet_IBN, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])
        # Adaptive pooling equals the former AvgPool2d(7) on the 7x7 map that
        # a 224x224 input produces, and additionally accepts other input
        # sizes. It has no parameters, so checkpoints load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kept although the loop below re-initializes conv1 with the same
        # std: dropping it would change the random stream of seeded runs.
        self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, ibn=None):
        """Build one stage of ``blocks`` residual blocks.

        The first block gets ``stride`` and, when the stride or the channel
        count changes, a 1x1-conv + batch-norm shortcut; the rest use stride 1.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, ibn=ibn)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, None, ibn=ibn))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class scores for a batch of images."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
def _build_se_resnet_ibn_a(layers):
    """Build an SE-ResNet-IBN-a with the given per-stage block counts."""
    return ResNet_IBN(SEBottleneck_IBN, layers, ibn_cfg=('a', 'a', 'a', None))


def se_resnet50_ibn_a(pretrained=False):
    """Constructs a SE-ResNet-50-IBN-a model.

    Args:
        pretrained (bool): no weights are published for this variant; a
            warning is issued and the randomly initialized model is returned.
    """
    net = _build_se_resnet_ibn_a([3, 4, 6, 3])
    if not pretrained:
        return net
    warnings.warn("Pretrained model not available for SE-ResNet-50-IBN-a!")
    return net


def se_resnet101_ibn_a(pretrained=False):
    """Constructs a SE-ResNet-101-IBN-a model.

    Args:
        pretrained (bool): If True, loads the weights referenced by
            ``model_urls['se_resnet101_ibn_a']`` into the model.
    """
    net = _build_se_resnet_ibn_a([3, 4, 23, 3])
    if not pretrained:
        return net
    state = torch.hub.load_state_dict_from_url(model_urls['se_resnet101_ibn_a'])
    net.load_state_dict(state)
    return net


def se_resnet152_ibn_a(pretrained=False):
    """Constructs a SE-ResNet-152-IBN-a model.

    Args:
        pretrained (bool): no weights are published for this variant; a
            warning is issued and the randomly initialized model is returned.
    """
    net = _build_se_resnet_ibn_a([3, 8, 36, 3])
    if not pretrained:
        return net
    warnings.warn("Pretrained model not available for SE-ResNet-152-IBN-a!")
    return net
class Bottleneck_IBN(nn.Module):
    """
    RexNeXt bottleneck type C
    """
    expansion = 4

    def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None, ibn=None):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality (before ``expansion``)
            baseWidth: base width.
            cardinality: num of convolution groups.
            stride: conv stride. Replaces pooling layer.
            downsample: optional module applied to the input before the
                residual addition.
            ibn: ``'a'`` replaces the first batch norm with ``IBN``.
        """
        super(Bottleneck_IBN, self).__init__()

        # 64.0 keeps this a true division even without Python 3 semantics.
        D = int(math.floor(planes * (baseWidth / 64.0)))
        C = cardinality
        width = D * C
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, padding=0, bias=False)
        if ibn == 'a':
            self.bn1 = IBN(width)
        else:
            self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut)``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNeXt_IBN(nn.Module):
    """ResNeXt backbone built from ``Bottleneck_IBN`` blocks.

    Args:
        baseWidth: base width forwarded to every block.
        cardinality: number of convolution groups forwarded to every block.
        layers: number of blocks in each of the four stages.
        ibn_cfg: per-stage IBN setting forwarded to every block of the stage.
        num_classes: size of the final linear layer.
    """

    def __init__(self,
                 baseWidth,
                 cardinality,
                 layers,
                 ibn_cfg=('a', 'a', 'a', None),
                 num_classes=1000):
        super(ResNeXt_IBN, self).__init__()
        block = Bottleneck_IBN

        self.cardinality = cardinality
        self.baseWidth = baseWidth
        self.num_classes = num_classes
        self.inplanes = 64
        self.output_size = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])
        # Adaptive pooling equals the former AvgPool2d(7) on the 7x7 map that
        # a 224x224 input produces, and additionally accepts other input
        # sizes. It has no parameters, so checkpoints load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kept although the loop below re-initializes conv1 with the same
        # std: dropping it would change the random stream of seeded runs.
        self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, ibn=None):
        """Build one stage of ``blocks`` bottlenecks.

        The first block gets ``stride`` and, when the stride or the channel
        count changes, a 1x1-conv + batch-norm shortcut; the rest use stride 1.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, self.baseWidth,
                        self.cardinality, stride, downsample, ibn)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.baseWidth,
                                self.cardinality, 1, None, ibn))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class scores for a batch of images."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
def _build_resnext_ibn_a(baseWidth, cardinality, layers):
    """Build a ResNeXt-IBN-a with the given per-stage block counts."""
    return ResNeXt_IBN(baseWidth, cardinality, layers, ('a', 'a', 'a', None))


def resnext50_ibn_a(pretrained=False, baseWidth=4, cardinality=32):
    """
    Construct ResNeXt-50-IBN-a.

    No weights are published for this variant: ``pretrained=True`` only
    issues a warning.
    """
    net = _build_resnext_ibn_a(baseWidth, cardinality, [3, 4, 6, 3])
    if not pretrained:
        return net
    warnings.warn("Pretrained model not available for ResNeXt-50-IBN-a!")
    return net


def resnext101_ibn_a(pretrained=False, baseWidth=4, cardinality=32):
    """
    Construct ResNeXt-101-IBN-a.

    With ``pretrained=True`` the weights referenced by
    ``model_urls['resnext101_ibn_a']`` are loaded into the model.
    """
    net = _build_resnext_ibn_a(baseWidth, cardinality, [3, 4, 23, 3])
    if not pretrained:
        return net
    state = torch.hub.load_state_dict_from_url(model_urls['resnext101_ibn_a'])
    net.load_state_dict(state)
    return net


def resnext152_ibn_a(pretrained=False, baseWidth=4, cardinality=32):
    """
    Construct ResNeXt-152-IBN-a.

    No weights are published for this variant: ``pretrained=True`` only
    issues a warning.
    """
    net = _build_resnext_ibn_a(baseWidth, cardinality, [3, 8, 36, 3])
    if not pretrained:
        return net
    warnings.warn("Pretrained model not available for ResNeXt-152-IBN-a!")
    return net
class IBN(nn.Module):
    r"""Instance-Batch Normalization layer from
    "Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"

    The first ``int(planes * (1 - ratio))`` channels go through batch
    normalization and the remaining channels through affine instance
    normalization; the two results are concatenated in that order.

    Args:
        planes (int): Number of channels for the input tensor
        ratio (float): Ratio of instance normalization in the IBN layer
    """
    def __init__(self, planes, ratio=0.5):
        super(IBN, self).__init__()
        self.half = int(planes * (1-ratio))
        self.BN = nn.BatchNorm2d(self.half)
        self.IN = nn.InstanceNorm2d(planes - self.half, affine=True)

    def forward(self, x):
        """Normalize ``x`` along dim 1 in two groups and re-join them."""
        # Explicit section sizes: splitting by chunk size ``self.half``
        # yields three or more chunks whenever ratio > 0.5, so the tail
        # channels were dropped and IN saw the wrong channel count.
        bn_part, in_part = torch.split(
            x, [self.half, x.size(1) - self.half], 1)
        out1 = self.BN(bn_part.contiguous())
        out2 = self.IN(in_part.contiguous())
        out = torch.cat((out1, out2), 1)
        return out


def densenet121_ibn_a(pretrained=False, **kwargs):
    r"""Densenet-121-IBN-a model from
    "Densely Connected Convolutional Networks"

    Args:
        pretrained (bool): If True, loads the weights referenced by
            ``model_urls['densenet121_ibn_a']``.
        **kwargs: forwarded to ``DenseNet_IBN``.
    """
    model = DenseNet_IBN(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                         **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['densenet121_ibn_a']))
    return model


def densenet169_ibn_a(pretrained=False, **kwargs):
    r"""Densenet-169-IBN-a model from
    "Densely Connected Convolutional Networks"

    Args:
        pretrained (bool): If True, loads the weights referenced by
            ``model_urls['densenet169_ibn_a']``.
        **kwargs: forwarded to ``DenseNet_IBN``.
    """
    model = DenseNet_IBN(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                         **kwargs)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['densenet169_ibn_a']))
    return model


def densenet201_ibn_a(pretrained=False, **kwargs):
    r"""Densenet-201-IBN-a model from
    "Densely Connected Convolutional Networks"

    Args:
        pretrained (bool): no weights are published for this variant; a
            warning is issued and the randomly initialized model is returned.
        **kwargs: forwarded to ``DenseNet_IBN``.
    """
    model = DenseNet_IBN(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                         **kwargs)
    if pretrained:
        warnings.warn("Pretrained model not available for Densenet-201-IBN-a!")
    return model


def densenet161_ibn_a(pretrained=False, **kwargs):
    r"""Densenet-161-IBN-a model from
    "Densely Connected Convolutional Networks"

    Args:
        pretrained (bool): no weights are published for this variant; a
            warning is issued and the randomly initialized model is returned.
        **kwargs: forwarded to ``DenseNet_IBN``.
    """
    model = DenseNet_IBN(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                         **kwargs)
    if pretrained:
        warnings.warn("Pretrained model not available for Densenet-161-IBN-a!")
    return model
r"""Densenet-161-IBN-a model from 86 | `"Densely Connected Convolutional Networks" `_ 87 | 88 | Args: 89 | pretrained (bool): If True, returns a model pre-trained on ImageNet 90 | """ 91 | model = DenseNet_IBN(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 92 | **kwargs) 93 | if pretrained: 94 | warnings.warn("Pretrained model not available for Densenet-161-IBN-a!") 95 | return model 96 | 97 | 98 | class _DenseLayer(nn.Sequential): 99 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, ibn): 100 | super(_DenseLayer, self).__init__() 101 | if ibn: 102 | self.add_module('norm1', IBN(num_input_features, 0.4)), 103 | else: 104 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 105 | self.add_module('relu1', nn.ReLU(inplace=True)), 106 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 107 | growth_rate, kernel_size=1, stride=1, bias=False)), 108 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 109 | self.add_module('relu2', nn.ReLU(inplace=True)), 110 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 111 | kernel_size=3, stride=1, padding=1, bias=False)), 112 | self.drop_rate = drop_rate 113 | 114 | def forward(self, x): 115 | new_features = super(_DenseLayer, self).forward(x) 116 | if self.drop_rate > 0: 117 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 118 | return torch.cat([x, new_features], 1) 119 | 120 | 121 | class _DenseBlock(nn.Sequential): 122 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, ibn): 123 | super(_DenseBlock, self).__init__() 124 | for i in range(num_layers): 125 | if ibn and i % 3 == 0: 126 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, True) 127 | else: 128 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, False) 129 | self.add_module('denselayer%d' % (i + 1), layer) 130 | 
131 | 132 | class _Transition(nn.Sequential): 133 | def __init__(self, num_input_features, num_output_features): 134 | super(_Transition, self).__init__() 135 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 136 | self.add_module('relu', nn.ReLU(inplace=True)) 137 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 138 | kernel_size=1, stride=1, bias=False)) 139 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 140 | 141 | 142 | class DenseNet_IBN(nn.Module): 143 | r"""Densenet-BC model class, based on 144 | `"Densely Connected Convolutional Networks" `_ 145 | 146 | Args: 147 | growth_rate (int) - how many filters to add each layer (`k` in paper) 148 | block_config (list of 4 ints) - how many layers in each pooling block 149 | num_init_features (int) - the number of filters to learn in the first convolution layer 150 | bn_size (int) - multiplicative factor for number of bottle neck layers 151 | (i.e. bn_size * k features in the bottleneck layer) 152 | drop_rate (float) - dropout rate after each dense layer 153 | num_classes (int) - number of classification classes 154 | """ 155 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 156 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 157 | 158 | super(DenseNet_IBN, self).__init__() 159 | 160 | # First convolution 161 | self.features = nn.Sequential(OrderedDict([ 162 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 163 | ('norm0', nn.BatchNorm2d(num_init_features)), 164 | ('relu0', nn.ReLU(inplace=True)), 165 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 166 | ])) 167 | 168 | # Each denseblock 169 | num_features = num_init_features 170 | for i, num_layers in enumerate(block_config): 171 | ibn = True 172 | if i >= 3: 173 | ibn = False 174 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 175 | bn_size=bn_size, growth_rate=growth_rate, 
drop_rate=drop_rate, ibn=ibn) 176 | self.features.add_module('denseblock%d' % (i + 1), block) 177 | num_features = num_features + num_layers * growth_rate 178 | if i != len(block_config) - 1: 179 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 180 | self.features.add_module('transition%d' % (i + 1), trans) 181 | num_features = num_features // 2 182 | 183 | # Final batch norm 184 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 185 | 186 | # Linear layer 187 | self.classifier = nn.Linear(num_features, num_classes) 188 | 189 | def forward(self, x): 190 | features = self.features(x) 191 | out = F.relu(features, inplace=True) 192 | out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) 193 | out = self.classifier(out) 194 | return out 195 | -------------------------------------------------------------------------------- /ibnnet/resnet_ibn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .modules import IBN 8 | 9 | 10 | __all__ = ['ResNet_IBN', 'resnet18_ibn_a', 'resnet34_ibn_a', 'resnet50_ibn_a', 'resnet101_ibn_a', 'resnet152_ibn_a', 11 | 'resnet18_ibn_b', 'resnet34_ibn_b', 'resnet50_ibn_b', 'resnet101_ibn_b', 'resnet152_ibn_b'] 12 | 13 | 14 | model_urls = { 15 | 'resnet18_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth', 16 | 'resnet34_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth', 17 | 'resnet50_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth', 18 | 'resnet101_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth', 19 | 'resnet18_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_b-bc2f3c11.pth', 20 | 'resnet34_ibn_b': 
class BasicBlock_IBN(nn.Module):
    """Two-convolution residual block with optional IBN-a / IBN-b normalization.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        ibn: ``'a'`` replaces the first batch norm with ``IBN``; ``'b'`` adds
            an affine instance norm after the residual addition.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before the
            residual addition.
    """
    expansion = 1

    def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
        super(BasicBlock_IBN, self).__init__()
        # Sub-modules are created in the same order as before so parameter
        # names, registration order and random initialization are unchanged.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = IBN(planes) if ibn == 'a' else nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if ibn == 'b':
            self.IN = nn.InstanceNorm2d(planes, affine=True)
        else:
            self.IN = None
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(IN?(main_branch(x) + shortcut))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        if self.IN is not None:
            out = self.IN(out)
        return self.relu(out)
class ResNet_IBN(nn.Module):
    """ResNet-style classifier whose stages can use IBN normalisation.

    Args:
        block: residual block class. It must expose an ``expansion`` class
            attribute and be constructible as
            ``block(inplanes, planes, ibn, stride, downsample)``.
        layers: four integers giving the number of blocks in each stage.
        ibn_cfg: four entries, one per stage, handed to the blocks as their
            ``ibn`` argument. ``ibn_cfg[0] == 'b'`` additionally makes the stem
            use an affine ``InstanceNorm2d`` instead of ``BatchNorm2d``.
        num_classes: output size of the final linear layer.
    """

    def __init__(self,
                 block,
                 layers,
                 ibn_cfg=('a', 'a', 'a', None),
                 num_classes=1000):
        super(ResNet_IBN, self).__init__()
        # Running channel count; _make_layer reads it and updates it per stage.
        self.inplanes = 64

        # Stem: 7x7 stride-2 convolution, normalisation, ReLU, 3x3 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        if ibn_cfg[0] == 'b':
            self.bn1 = nn.InstanceNorm2d(64, affine=True)
        else:
            self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])

        # Global average pooling. An adaptive pool yields a 1x1 map for any
        # input resolution (a fixed 7x7 window only fits 224-pixel inputs) and
        # has no parameters, so existing state dicts still load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialisation: zero-mean normal with std sqrt(2 / fan_out)
        # for convolutions, unit scale and zero shift for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2. / fan_out))
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, ibn=None):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride or channel
        count changes, a 1x1-convolution + BatchNorm projection shortcut.
        ``self.inplanes`` is updated to the stage's output channel count.

        When ``ibn == 'b'`` only the last block of the stage is given
        ``ibn='b'`` (and never the first block); any other ``ibn`` value is
        passed to every block.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes,
                       None if ibn == 'b' else ibn,
                       stride, downsample)]
        self.inplanes = out_planes
        for i in range(1, blocks):
            is_last = (i == blocks - 1)
            stage.append(block(self.inplanes, planes,
                               None if (ibn == 'b' and not is_last) else ibn))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class scores of shape ``(x.size(0), num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
def resnet50_ibn_a(pretrained=False, **kwargs):
    """Build the ResNet-50-IBN-a network.

    Args:
        pretrained (bool): when True, download the weights listed under
            ``model_urls['resnet50_ibn_a']`` (pre-trained on ImageNet) and
            load them into the model before returning it.
        **kwargs: forwarded unchanged to ``ResNet_IBN``.
    """
    net = ResNet_IBN(
        block=Bottleneck_IBN,
        layers=[3, 4, 6, 3],
        ibn_cfg=('a', 'a', 'a', None),
        **kwargs
    )
    if not pretrained:
        return net

    weights = torch.hub.load_state_dict_from_url(model_urls['resnet50_ibn_a'])
    net.load_state_dict(weights)
    return net
def resnet50_ibn_b(pretrained=False, **kwargs):
    """Build the ResNet-50-IBN-b network.

    Args:
        pretrained (bool): when True, download the weights listed under
            ``model_urls['resnet50_ibn_b']`` (pre-trained on ImageNet) and
            load them into the model before returning it.
        **kwargs: forwarded unchanged to ``ResNet_IBN``.
    """
    net = ResNet_IBN(
        block=Bottleneck_IBN,
        layers=[3, 4, 6, 3],
        ibn_cfg=('b', 'b', None, None),
        **kwargs
    )
    if not pretrained:
        return net

    weights = torch.hub.load_state_dict_from_url(model_urls['resnet50_ibn_b'])
    net.load_state_dict(weights)
    return net
317 | 318 | Args: 319 | pretrained (bool): If True, returns a model pre-trained on ImageNet 320 | """ 321 | model = ResNet_IBN(block=Bottleneck_IBN, 322 | layers=[3, 8, 36, 3], 323 | ibn_cfg=('b', 'b', None, None), 324 | **kwargs) 325 | if pretrained: 326 | warnings.warn("Pretrained model not available for ResNet-152-IBN-b!") 327 | return model 328 | -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | import random 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.optim as optim 14 | import torchvision.transforms as transforms 15 | import torchvision.datasets as datasets 16 | import torchvision.models as models 17 | 18 | import ibnnet as customized_models 19 | 20 | from utils import Bar, AverageMeter, accuracy, mkdir_p 21 | 22 | # Models 23 | default_model_names = sorted(name for name in models.__dict__ 24 | if name.islower() and not name.startswith("__") 25 | and callable(models.__dict__[name])) 26 | 27 | customized_models_names = sorted(name for name in customized_models.__dict__ 28 | if name.islower() and not name.startswith("__") 29 | and callable(customized_models.__dict__[name])) 30 | 31 | for name in customized_models.__dict__: 32 | if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): 33 | models.__dict__[name] = customized_models.__dict__[name] 34 | 35 | model_names = default_model_names + customized_models_names 36 | 37 | # Parse arguments 38 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 39 | 40 | # Datasets 41 | parser.add_argument('-d', '--data', default='path to dataset', type=str) 42 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 
43 | help='number of data loading workers (default: 4)') 44 | # Optimization options 45 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 46 | help='number of total epochs to run') 47 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 48 | help='manual epoch number (useful on restarts)') 49 | parser.add_argument('--train-batch', default=256, type=int, metavar='N', 50 | help='train batchsize (default: 256)') 51 | parser.add_argument('--test-batch', default=100, type=int, metavar='N', 52 | help='test batchsize (default: 100)') 53 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 54 | metavar='LR', help='initial learning rate') 55 | parser.add_argument('--drop', '--dropout', default=0, type=float, 56 | metavar='Dropout', help='Dropout ratio') 57 | parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 90], 58 | help='Decrease learning rate at these epochs.') 59 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 60 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 61 | help='momentum') 62 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 63 | metavar='W', help='weight decay (default: 1e-4)') 64 | # Checkpoints 65 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 66 | help='path to save checkpoint (default: checkpoint)') 67 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 68 | help='path to latest checkpoint (default: none)') 69 | # Architecture 70 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', 71 | choices=model_names, 72 | help='model architecture: ' + 73 | ' | '.join(model_names) + 74 | ' (default: resnet50)') 75 | parser.add_argument('--depth', type=int, default=29, help='Model depth.') 76 | parser.add_argument('--cardinality', type=int, default=32, help='ResNet cardinality (group).') 77 | 
parser.add_argument('--base-width', type=int, default=4, help='ResNet base width.') 78 | parser.add_argument('--widen-factor', type=int, default=4, 79 | help='Widen factor. 4 -> 64, 8 -> 128, ...') 80 | # Miscs 81 | parser.add_argument('--manualSeed', type=int, help='manual seed') 82 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 83 | help='evaluate model on validation set') 84 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 85 | help='use pre-trained model') 86 | parser.add_argument('--model_weight', dest='model_weight', default=None, type=str, 87 | help='custom pretrained model weight') 88 | # Device options 89 | parser.add_argument('--cpu', dest='cpu', action='store_true', 90 | help='use cpu mode') 91 | parser.add_argument('--gpu_id', default='1', type=str, 92 | help='id(s) for CUDA_VISIBLE_DEVICES') 93 | 94 | args = parser.parse_args() 95 | state = {k: v for k, v in args._get_kwargs()} 96 | 97 | # Use CUDA 98 | if args.cpu: 99 | print('Use CPU mode') 100 | use_cuda = False 101 | pin_memory = False 102 | else: 103 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 104 | use_cuda = torch.cuda.is_available() 105 | pin_memory = True 106 | 107 | # Random seed 108 | if args.manualSeed is None: 109 | args.manualSeed = random.randint(1, 10000) 110 | random.seed(args.manualSeed) 111 | torch.manual_seed(args.manualSeed) 112 | if use_cuda: 113 | torch.cuda.manual_seed_all(args.manualSeed) 114 | 115 | best_acc = 0 # best test accuracy 116 | 117 | 118 | def main(): 119 | global best_acc 120 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 121 | 122 | if not os.path.isdir(args.checkpoint): 123 | mkdir_p(args.checkpoint) 124 | 125 | # Data loading code 126 | traindir = os.path.join(args.data, 'train') 127 | valdir = os.path.join(args.data, 'val') 128 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 129 | std=[0.229, 0.224, 0.225]) 130 | 131 | train_loader = 
def main():
    """Train or evaluate the model selected by the command-line arguments.

    Builds the train/val ``ImageFolder`` loaders from ``args.data``, creates
    the model named by ``args.arch``, optionally loads ``args.model_weight``
    and/or resumes from ``args.resume``, then either runs a single evaluation
    pass (``args.evaluate``) or the train/validate loop, writing a checkpoint
    after every epoch. Updates the module-level ``best_acc``.

    Raises:
        IOError: if ``args.resume`` is set but is not an existing file.
    """
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Without CUDA, remap tensors saved from a GPU onto the CPU while loading;
    # otherwise torch.load fails on CUDA-saved checkpoints in CPU mode.
    map_location = None if use_cuda else 'cpu'

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # RandomResizedCrop / Resize are the current names of the deprecated
    # RandomSizedCrop / Scale transforms, which newer torchvision removed.
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=pin_memory)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=pin_memory)

    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            pretrained=args.pretrained,
            baseWidth=args.base_width,
            cardinality=args.cardinality,
        )
    else:
        model = models.__dict__[args.arch](pretrained=args.pretrained)

    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
    print(model)
    if args.model_weight:
        model_weight = torch.load(args.model_weight, map_location=map_location)
        # NOTE(review): strict=False silently skips keys that do not match,
        # e.g. a 'module.' prefix mismatch between wrapped and unwrapped
        # models - confirm the weights were actually applied.
        model.load_state_dict(model_weight['state_dict'], strict=False)

    print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Resume
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        # Explicit raise instead of assert so the check survives `python -O`.
        if not os.path.isfile(args.resume):
            raise IOError('Error: no checkpoint directory found!')
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume, map_location=map_location)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc1, test_acc5 = test(val_loader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss: %.8f, Top1 Acc: %.2f, Top5 Acc: %.2f'
              % (test_loss, test_acc1, test_acc5))
        print(' Top1 Err: %.2f, Top5 Err: %.2f' % (100.0 - test_acc1, 100.0 - test_acc5))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda)
        # NOTE(review): the third value returned by test() is the top-5
        # accuracy, so the "best" checkpoint is chosen by top-5, not top-1.
        # Kept as in the original - confirm this is intended.
        test_loss, test_acc1, test_acc = test(val_loader, model, criterion, epoch, use_cuda)

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'acc': test_acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint=args.checkpoint)

    print('Best acc:')
    print(best_acc)
= inputs.cuda(), targets.cuda(non_blocking=True) 242 | 243 | # compute output 244 | outputs = model(inputs) 245 | loss = criterion(outputs, targets) 246 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 247 | 248 | # measure accuracy and record loss 249 | losses.update(loss.item(), inputs.size(0)) 250 | top1.update(prec1.item(), inputs.size(0)) 251 | top5.update(prec5.item(), inputs.size(0)) 252 | 253 | # compute gradient and do SGD step 254 | optimizer.zero_grad() 255 | loss.backward() 256 | optimizer.step() 257 | 258 | # measure elapsed time 259 | batch_time.update(time.time() - end) 260 | end = time.time() 261 | 262 | # plot progress 263 | if (batch_idx+1) % 10 == 0: 264 | print('({batch}/{size}) D: {data:.2f}s | B: {bt:.2f}s | T: {total:} | ' 265 | 'E: {eta:} | L: {loss:.3f} | t1: {top1: .3f} | t5: {top5: .3f}'.format( 266 | batch=batch_idx + 1, 267 | size=len(train_loader), 268 | data=data_time.val, 269 | bt=batch_time.val, 270 | total=bar.elapsed_td, 271 | eta=bar.eta_td, 272 | loss=losses.avg, 273 | top1=top1.avg, 274 | top5=top5.avg, 275 | )) 276 | bar.next() 277 | bar.finish() 278 | return (losses.avg, top5.avg) 279 | 280 | 281 | def test(val_loader, model, criterion, epoch, use_cuda): 282 | global best_acc 283 | 284 | batch_time = AverageMeter() 285 | data_time = AverageMeter() 286 | losses = AverageMeter() 287 | top1 = AverageMeter() 288 | top5 = AverageMeter() 289 | 290 | # switch to evaluate mode 291 | model.eval() 292 | end = time.time() 293 | 294 | bar = Bar('P', max=len(val_loader)) 295 | for batch_idx, (inputs, targets) in enumerate(val_loader): 296 | # measure data loading time 297 | data_time.update(time.time() - end) 298 | if use_cuda: 299 | inputs, targets = inputs.cuda(), targets.cuda() 300 | 301 | # compute output 302 | end = time.time() 303 | outputs = model(inputs) 304 | batch_time.update(time.time() - end) 305 | loss = criterion(outputs, targets) 306 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 307 | 
def adjust_learning_rate(optimizer, epoch):
    """Set the step-decayed learning rate for ``epoch`` on ``optimizer``.

    The rate is ``args.lr * args.gamma ** k``, where ``k`` is the number of
    distinct milestones in ``args.schedule`` that are ``<= epoch``. It is
    derived from the epoch rather than decayed incrementally, so a run that
    starts part-way through (``--resume`` or ``--start-epoch``) gets the same
    rate as an uninterrupted run instead of restarting from ``args.lr``.

    Args:
        optimizer: object whose ``param_groups`` is an iterable of dicts; the
            ``'lr'`` entry of every group is overwritten.
        epoch: zero-based index of the epoch about to run.

    Side effects:
        Stores the computed rate in the module-level ``state['lr']``, which
        the training loop prints.
    """
    # set() so a milestone listed twice decays only once, as `epoch in
    # schedule` did in the incremental version.
    decays = sum(1 for milestone in set(args.schedule) if milestone <= epoch)
    lr = args.lr * (args.gamma ** decays)

    state['lr'] = lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr