├── .gitignore
├── README.md
├── ckpts
│   └── .gitignore
├── eval.py
├── images
│   ├── air_bottleneck.png
│   └── air_module.png
├── models
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── air.py
│   ├── air.pyc
│   ├── airx.py
│   └── airx.pyc
└── utils
    ├── __init__.py
    ├── __init__.pyc
    ├── measure.py
    ├── measure.pyc
    ├── transforms.py
    └── transforms.pyc
/.gitignore:
--------------------------------------------------------------------------------
1 | *.tar.gz
2 | *.pth
3 | *.tar
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AirNet-PyTorch
2 | Implementation of the paper "Attention Inspiring Receptive-fields Network" (under review), including the evaluation code and trained models. By:
3 |
4 | [Lu Yang](https://github.com/soeaver), Qing Song, Yingqi Wu and Mengjie Hu
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | ## Install
13 | * Install [PyTorch>=0.3.0](http://pytorch.org/)
14 | * Install [torchvision>=0.2.0](http://pytorch.org/)
15 | * Clone the repository:
16 | ```
17 | git clone https://github.com/soeaver/AirNet-PyTorch
18 | ```
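
A quick way to check that the environment and the bundled model definitions import correctly (a minimal sketch, run from the repository root):
```
# sanity check: PyTorch is installed and the local `models` package imports
import torch
import models

print(torch.__version__)      # expected >= 0.3.0
net = models.air50_1x64d()    # AirNet50-1x64d with randomly initialized weights
```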
19 |
20 | ## Evaluation
21 | * Download the trained models, and move them to the `ckpts` folder.
22 | * Run `eval.py`:
23 | ```
24 | python eval.py --gpu_id 0 --arch air50_1x64d --model_weights ./ckpts/air50_1x64d.pth
25 | ```
26 | * The results should match those reported in the paper.
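
`eval.py` reads image paths and ground-truth labels from the file given by `--val_file` (one `<relative path> <label>` pair per line, resolved against `--data_root`). To load a trained checkpoint in your own script instead, a minimal sketch (assuming the released `.pth` files are plain state dicts, possibly saved with a `module.` prefix) is:
```
import torch
from models import air50_1x64d
from utils import weight_filler

net = air50_1x64d()
checkpoint = torch.load('./ckpts/air50_1x64d.pth')
# weight_filler matches checkpoint keys against the model's state dict,
# tolerating an optional 'module.' prefix left over from DataParallel training
state, matched, mismatched = weight_filler(checkpoint, net.state_dict())
net.load_state_dict(state)
net.eval()
```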
27 |
28 |
29 | ## Results
30 |
31 | ### ImageNet1k
32 | Single-crop (224x224) validation error rates are reported.
33 |
34 | | Network | Flops (G) | Params (M) | Top-1 Error (%) | Top-5 Error (%) | Download |
35 | | :---------------------: | --------- |----------- | --------------- | --------------- | -------- |
36 | | AirNet50-1x64d (r=16) | 4.36 | 25.7 | 22.11 | 6.18 | [GoogleDrive](https://drive.google.com/open?id=1oUHnx8pw9YRJshN2biLoh_H1I4efoTWE) |
37 | | AirNet50-1x64d (r=2) | 4.72 | 27.4 | 21.83 | 5.89 | [GoogleDrive](https://drive.google.com/open?id=1rOA9ciKbEKMkiDO3g3qY06goXZR9hO-Y) |
38 | | AirNeXt50-32x4d | 5.29 | 25.5 | 20.87 | 5.52 | [GoogleDrive](https://drive.google.com/open?id=1xLcPHN1NCONtpDKNXDEIKhAn475mYD-L) |
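
The FLOPs and parameter counts above are also printed by `eval.py` before evaluation starts, using `utils.measure_model`. It can be called on its own as well (a minimal sketch, assuming the standard 224x224 input used in the paper):
```
from models import air50_1x64d
from utils import measure_model

# hooks every leaf module, runs one dummy 1x3x224x224 forward pass on the CPU,
# and accumulates operation and parameter counts
flops, conv_flops, params = measure_model(air50_1x64d(), 224, 224)
print('FLOPs: {:.2f}G, Conv FLOPs: {:.2f}G, Params: {:.2f}M'.format(
    flops / 1e9, conv_flops / 1e9, params / 1e6))
```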
39 |
40 |
41 | ## Other Resources (from [DPNs](https://github.com/cypw/DPNs))
42 |
43 | ImageNet-1k Training/Validation List:
44 | - Download link: [GoogleDrive](https://goo.gl/Ne42bM)
45 |
46 | ImageNet-1k category name mapping table:
47 | - Download link: [GoogleDrive](https://goo.gl/YTAED5)
48 |
49 |
--------------------------------------------------------------------------------
/ckpts/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import datetime
5 | import argparse
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torchvision.models as models
10 | import torch.backends.cudnn as cudnn
11 |
12 | import models as customized_models
13 | from PIL import Image
14 | from utils import measure_model, weight_filler
15 | from utils import transforms as T
16 |
17 | # Models
18 | default_model_names = sorted(name for name in models.__dict__
19 | if name.islower() and not name.startswith("__")
20 | and callable(models.__dict__[name]))
21 | customized_models_names = sorted(name for name in customized_models.__dict__
22 | if name.islower() and not name.startswith("__")
23 | and callable(customized_models.__dict__[name]))
24 | for name in customized_models.__dict__:
25 | if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]):
26 | models.__dict__[name] = customized_models.__dict__[name]
27 | model_names = default_model_names + customized_models_names
28 | print(model_names)
29 |
30 | # Parse arguments
31 | parser = argparse.ArgumentParser(description='Evaluate the ImageNet validation set',
32 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
33 |
34 | parser.add_argument('--gpu_id', type=str, default='1', help='gpu id for evaluation')
35 | parser.add_argument('--data_root', type=str, default='/home/user/Database/ILSVRC2012/Data/CLS-LOC/val/',
36 | help='Path to imagenet validation path')
37 | parser.add_argument('--val_file', type=str, default='ILSVRC2012_val.txt',
38 | help='val_file')
39 | parser.add_argument('--arch', type=str,
40 | default='air50_1x64d',
41 | help='model arch')
42 | parser.add_argument('--model_weights', type=str,
43 | default='./ckpts/air50_1x64d.pth',
44 | help='model weights')
45 |
46 | parser.add_argument('--ground_truth', type=bool, default=True, help='whether provide gt labels')
47 | parser.add_argument('--class_num', type=int, default=1000, help='predict classes number')
48 | parser.add_argument('--skip_num', type=int, default=0, help='skip_num for evaluation')
49 | parser.add_argument('--base_size', type=int, default=256, help='short size of images')
50 | parser.add_argument('--crop_size', type=int, default=224, help='crop size of images')
51 | parser.add_argument('--crop_type', type=str, default='center', choices=['center', 'multi'],
52 | help='crop type of evaluation')
53 | parser.add_argument('--batch_size', type=int, default=1, help='batch size of multi-crop test')
54 | parser.add_argument('--top_k', type=int, nargs='+', default=[1, 5], help='top_k')
55 | parser.add_argument('--save_score_vec', type=bool, default=False, help='whether save the score map')
56 |
57 | args = parser.parse_args()
58 |
59 | # ------------------ MEAN & STD ---------------------
60 | PIXEL_MEANS = np.array([0.485, 0.456, 0.406])
61 | PIXEL_STDS = np.array([0.229, 0.224, 0.225])
62 | # ---------------------------------------------------
63 |
64 | # Set GPU id, CUDA and cudnn
65 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
66 | USE_CUDA = torch.cuda.is_available()
67 | cudnn.benchmark = True
68 |
69 | # Create & Load model
70 | MODEL = models.__dict__[args.arch]()
71 | # Calculate FLOPs & Param
72 | n_flops, n_convops, n_params = measure_model(MODEL, args.crop_size, args.crop_size)
73 | print('==> FLOPs: {:.4f}M, Conv_FLOPs: {:.4f}M, Params: {:.4f}M'.
74 | format(n_flops / 1e6, n_convops / 1e6, n_params / 1e6))
75 | del MODEL
76 |
77 | # Load Weights
78 | MODEL = models.__dict__[args.arch]()
79 | checkpoint = torch.load(args.model_weights)
80 | weight_dict = checkpoint
81 | model_dict = MODEL.state_dict()
82 | updated_dict, match_layers, mismatch_layers = weight_filler(weight_dict, model_dict)
83 | model_dict.update(updated_dict)
84 | MODEL.load_state_dict(model_dict)
85 |
86 | # Switch to evaluate mode
87 | MODEL.cuda().eval()
88 | print(MODEL)
89 |
90 | # Create log & dict
91 | LOG_PTH = './log{}.txt'.format(datetime.datetime.now().strftime('%Y%m%d%H%M%S'))
92 | SET_DICT = dict()
93 | f = open(args.val_file, 'r')
94 | img_order = 0
95 | for line in f:
96 |     img_dict = dict()
97 |     img_dict['path'] = os.path.join(args.data_root, line.strip().split(' ')[0])
98 |     img_dict['evaluated'] = False
99 |     img_dict['score_vec'] = []
100 |     if args.ground_truth:
101 |         img_dict['gt'] = int(line.strip().split(' ')[1])
102 |     else:
103 |         img_dict['gt'] = -1
104 |     SET_DICT[img_order] = img_dict
105 |     img_order += 1
106 | f.close()
107 |
108 |
109 | def eval_batch():
110 | eval_len = len(SET_DICT)
111 | accuracy = np.zeros(len(args.top_k))
112 | start_time = datetime.datetime.now()
113 |
114 |     for i in range(eval_len - args.skip_num):
115 | im = cv2.imread(SET_DICT[i + args.skip_num]['path'])
116 | im = T.bgr2rgb(im)
117 | scale_im = T.pil_resize(Image.fromarray(im), args.base_size)
118 | normalized_im = T.normalize(np.asarray(scale_im) / 255.0, mean=PIXEL_MEANS, std=PIXEL_STDS)
119 | crop_ims = []
120 | if args.crop_type == 'center': # for single crop
121 | crop_ims.append(T.center_crop(normalized_im, crop_size=args.crop_size))
122 | elif args.crop_type == 'multi': # for 10 crops
123 | crop_ims.extend(T.mirror_crop(normalized_im, crop_size=args.crop_size))
124 | else:
125 | crop_ims.append(normalized_im)
126 |
127 | score_vec = np.zeros(args.class_num, dtype=np.float32)
128 | iter_num = int(len(crop_ims) / args.batch_size)
129 | timer_pt1 = datetime.datetime.now()
130 |         for j in range(iter_num):
131 | input_data = np.asarray(crop_ims, dtype=np.float32)[j * args.batch_size:(j + 1) * args.batch_size]
132 | input_data = input_data.transpose(0, 3, 1, 2)
133 | input_data = torch.autograd.Variable(torch.from_numpy(input_data).cuda(), volatile=True)
134 | outputs = MODEL(input_data)
135 | scores = outputs.data.cpu().numpy()
136 | score_vec += np.sum(scores, axis=0)
137 | score_index = (-score_vec / len(crop_ims)).argsort()
138 | timer_pt2 = datetime.datetime.now()
139 |
140 | SET_DICT[i + args.skip_num]['evaluated'] = True
141 | SET_DICT[i + args.skip_num]['score_vec'] = score_vec / len(crop_ims)
142 |
143 |         print('Testing image: {}/{} {} {}/{} {}s'
144 |               .format(str(i + 1), str(eval_len - args.skip_num), str(SET_DICT[i + args.skip_num]['path'].split('/')[-1]),
145 |                       str(score_index[0]), str(SET_DICT[i + args.skip_num]['gt']),
146 |                       str((timer_pt2 - timer_pt1).microseconds / 1e6 + (timer_pt2 - timer_pt1).seconds)), end=' ')
147 |
148 |         for j in range(len(args.top_k)):
149 |             if SET_DICT[i + args.skip_num]['gt'] in score_index[:args.top_k[j]]:
150 |                 accuracy[j] += 1
151 |             tmp_acc = float(accuracy[j]) / float(i + 1)
152 |             if args.top_k[j] == 1:
153 |                 print('\ttop_' + str(args.top_k[j]) + ':' + str(tmp_acc), end=' ')
154 |             else:
155 |                 print('top_' + str(args.top_k[j]) + ':' + str(tmp_acc))
156 | end_time = datetime.datetime.now()
157 |
158 | w = open(LOG_PTH, 'w')
159 | s1 = 'Evaluation process ends at: {}. \nTime cost is: {}. '.format(str(end_time), str(end_time - start_time))
160 | s2 = '\nThe model is: {}. \nThe val file is: {}. \n{} images has been tested, crop_type is: {}, base_size is: {}, ' \
161 | 'crop_size is: {}.'.format(args.model_weights, args.val_file, str(eval_len - args.skip_num),
162 | args.crop_type, str(args.base_size), str(args.crop_size))
163 | s3 = '\nThe PIXEL_MEANS is: ({}, {}, {}), PIXEL_STDS is : ({}, {}, {}).' \
164 | .format(str(PIXEL_MEANS[0]), str(PIXEL_MEANS[1]), str(PIXEL_MEANS[2]), str(PIXEL_STDS[0]), str(PIXEL_STDS[1]),
165 | str(PIXEL_STDS[2]))
166 | s4 = ''
167 |     for i in range(len(args.top_k)):
168 | _acc = float(accuracy[i]) / float(eval_len - args.skip_num)
169 | s4 += '\nAccuracy of top_{} is: {}; correct num is {}.'.format(str(args.top_k[i]), str(_acc),
170 | str(int(accuracy[i])))
171 |     print(s1, s2, s3, s4)
172 | w.write(s1 + s2 + s3 + s4)
173 | w.close()
174 |
175 | if args.save_score_vec:
176 | w = open(LOG_PTH.replace('.txt', 'scorevec.txt'), 'w')
177 |         for i in range(eval_len - args.skip_num):
178 |             w.write(str(SET_DICT[i + args.skip_num]['score_vec']) + '\n')
179 | w.close()
180 | print('DONE!')
181 |
182 |
183 | if __name__ == '__main__':
184 | eval_batch()
185 |
--------------------------------------------------------------------------------
/images/air_bottleneck.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soeaver/AirNet-PyTorch/e9dc06fabbde828109c4f75d8f2907ed1a3d0014/images/air_bottleneck.png
--------------------------------------------------------------------------------
/images/air_module.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soeaver/AirNet-PyTorch/e9dc06fabbde828109c4f75d8f2907ed1a3d0014/images/air_module.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .air import *
4 | from .airx import *
5 |
--------------------------------------------------------------------------------
/models/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soeaver/AirNet-PyTorch/e9dc06fabbde828109c4f75d8f2907ed1a3d0014/models/__init__.pyc
--------------------------------------------------------------------------------
/models/air.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | """
4 | Attention Inspiring Receptive-fields Network
5 | Copyright (c) Yang Lu, 2018
6 | """
7 | import math
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.nn import init
11 | import torch
12 |
13 | __all__ = ['air50_1x64d', 'air101_1x64d']
14 |
15 |
16 | class AIRBottleneck(nn.Module):
17 | """
18 | AIRBottleneck
19 | """
20 | expansion = 4
21 |
22 | def __init__(self, inplanes, planes, stride=1, ratio=2, downsample=None):
23 | """ Constructor
24 | Args:
25 | inplanes: input channel dimensionality
26 | planes: output channel dimensionality
27 | stride: conv stride. Replaces pooling layer
28 | ratio: dimensionality-compression ratio.
29 | """
30 | super(AIRBottleneck, self).__init__()
31 |
32 | self.stride = stride
33 | self.planes = planes
34 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False)
35 | self.bn1 = nn.BatchNorm2d(planes)
36 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
37 | self.bn2 = nn.BatchNorm2d(planes)
38 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
39 | self.bn3 = nn.BatchNorm2d(planes * 4)
40 |
41 | if self.stride == 1 and self.planes < 512: # for C2, C3, C4 stages
42 | self.conv_att1 = nn.Conv2d(inplanes, planes // ratio, kernel_size=1, stride=1, padding=0, bias=False)
43 | self.bn_att1 = nn.BatchNorm2d(planes // ratio)
44 | self.subsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
45 | # self.conv_att2 = nn.Conv2d(planes // ratio, planes // ratio, kernel_size=3, stride=2, padding=1, bias=False)
46 | # self.bn_att2 = nn.BatchNorm2d(planes // ratio)
47 | self.conv_att3 = nn.Conv2d(planes // ratio, planes // ratio, kernel_size=3, stride=1, padding=1, bias=False)
48 | self.bn_att3 = nn.BatchNorm2d(planes // ratio)
49 | self.conv_att4 = nn.Conv2d(planes // ratio, planes, kernel_size=1, stride=1, padding=0, bias=False)
50 | self.bn_att4 = nn.BatchNorm2d(planes)
51 | self.sigmoid = nn.Sigmoid()
52 |
53 | self.relu = nn.ReLU(inplace=True)
54 | self.downsample = downsample
55 |
56 | def forward(self, x):
57 | residual = x
58 |
59 | out = self.conv1(x)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 | out = self.conv2(out)
63 | out = self.bn2(out)
64 | out = self.relu(out)
65 |
66 | if self.stride == 1 and self.planes < 512:
67 | att = self.conv_att1(x)
68 | att = self.bn_att1(att)
69 | att = self.relu(att)
70 | # att = self.conv_att2(att)
71 | # att = self.bn_att2(att)
72 | # att = self.relu(att)
73 | att = self.subsample(att)
74 | att = self.conv_att3(att)
75 | att = self.bn_att3(att)
76 | att = self.relu(att)
77 | att = F.upsample(att, size=out.size()[2:], mode='bilinear')
78 | att = self.conv_att4(att)
79 | att = self.bn_att4(att)
80 | att = self.sigmoid(att)
81 | out = out * att
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 |
95 | class AIR(nn.Module):
96 | def __init__(self, baseWidth=64, head7x7=True, ratio=2, layers=(3, 4, 23, 3), num_classes=1000):
97 | """ Constructor
98 | Args:
99 | layers: config of layers, e.g., [3, 4, 23, 3]
100 | num_classes: number of classes
101 | """
102 | super(AIR, self).__init__()
103 | block = AIRBottleneck
104 |
105 | self.inplanes = baseWidth
106 |
107 | self.head7x7 = head7x7
108 | if self.head7x7:
109 | self.conv1 = nn.Conv2d(3, baseWidth, 7, 2, 3, bias=False)
110 | self.bn1 = nn.BatchNorm2d(baseWidth)
111 | else:
112 | self.conv1 = nn.Conv2d(3, baseWidth // 2, 3, 2, 1, bias=False)
113 | self.bn1 = nn.BatchNorm2d(baseWidth // 2)
114 | self.conv2 = nn.Conv2d(baseWidth // 2, baseWidth // 2, 3, 1, 1, bias=False)
115 | self.bn2 = nn.BatchNorm2d(baseWidth // 2)
116 | self.conv3 = nn.Conv2d(baseWidth // 2, baseWidth, 3, 1, 1, bias=False)
117 | self.bn3 = nn.BatchNorm2d(baseWidth)
118 | self.relu = nn.ReLU(inplace=True)
119 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
120 |
121 | self.layer1 = self._make_layer(block, baseWidth, layers[0], 1, ratio)
122 | self.layer2 = self._make_layer(block, baseWidth * 2, layers[1], 2, ratio)
123 | self.layer3 = self._make_layer(block, baseWidth * 4, layers[2], 2, ratio)
124 | self.layer4 = self._make_layer(block, baseWidth * 8, layers[3], 2, ratio)
125 | self.avgpool = nn.AdaptiveAvgPool2d(1)
126 | self.fc = nn.Linear(baseWidth * 8 * block.expansion, num_classes)
127 |
128 | for m in self.modules():
129 | if isinstance(m, nn.Conv2d):
130 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
131 | m.weight.data.normal_(0, math.sqrt(2. / n))
132 | elif isinstance(m, nn.BatchNorm2d):
133 | m.weight.data.fill_(1)
134 | m.bias.data.zero_()
135 |
136 | def _make_layer(self, block, planes, blocks, stride=1, ratio=2):
137 | """ Stack n bottleneck modules where n is inferred from the depth of the network.
138 | Args:
139 | block: block type used to construct AIR
140 | planes: number of output channels (need to multiply by block.expansion)
141 | blocks: number of blocks to be built
142 | Returns: a Module consisting of n sequential bottlenecks.
143 | """
144 | downsample = None
145 | if stride != 1 or self.inplanes != planes * block.expansion:
146 | downsample = nn.Sequential(
147 | nn.Conv2d(self.inplanes, planes * block.expansion,
148 | kernel_size=1, stride=stride, bias=False),
149 | nn.BatchNorm2d(planes * block.expansion),
150 | )
151 |
152 | layers = []
153 | layers.append(block(self.inplanes, planes, stride, ratio, downsample))
154 | self.inplanes = planes * block.expansion
155 | for i in range(1, blocks):
156 | layers.append(block(self.inplanes, planes, 1, ratio))
157 |
158 | return nn.Sequential(*layers)
159 |
160 | def forward(self, x):
161 | if self.head7x7:
162 | x = self.conv1(x)
163 | x = self.bn1(x)
164 | x = self.relu(x)
165 | else:
166 | x = self.conv1(x)
167 | x = self.bn1(x)
168 | x = self.relu(x)
169 | x = self.conv2(x)
170 | x = self.bn2(x)
171 | x = self.relu(x)
172 | x = self.conv3(x)
173 | x = self.bn3(x)
174 | x = self.relu(x)
175 | x = self.maxpool1(x)
176 | x = self.layer1(x)
177 | x = self.layer2(x)
178 | x = self.layer3(x)
179 | x = self.layer4(x)
180 | x = self.avgpool(x)
181 | x = x.view(x.size(0), -1)
182 | x = self.fc(x)
183 |
184 | return x
185 |
186 |
187 | def air50_1x64d():
188 | model = AIR(baseWidth=64, head7x7=False, layers=(3, 4, 6, 3), num_classes=1000)
189 | return model
190 |
191 |
192 | def air101_1x64d():
193 | model = AIR(baseWidth=64, head7x7=False, layers=(3, 4, 23, 3), num_classes=1000)
194 | return model
195 |
--------------------------------------------------------------------------------
/models/air.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soeaver/AirNet-PyTorch/e9dc06fabbde828109c4f75d8f2907ed1a3d0014/models/air.pyc
--------------------------------------------------------------------------------
/models/airx.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | """
4 | Attention Inspiring Receptive-fields Network
5 | Copyright (c) Yang Lu, 2018
6 | """
7 | import math
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.nn import init
11 | import torch
12 |
13 | __all__ = ['airx50_32x4d', 'airx101_32x4d_r16', 'airx101_32x4d_r2']
14 |
15 |
16 | class AIRXBottleneck(nn.Module):
17 | """
18 | AIRXBottleneck
19 | """
20 | expansion = 4
21 |
22 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, ratio=2, downsample=None):
23 | """ Constructor
24 | Args:
25 | inplanes: input channel dimensionality
26 | planes: output channel dimensionality
27 | baseWidth: base width
28 | cardinality: num of convolution groups
29 | stride: conv stride. Replaces pooling layer
30 | ratio: dimensionality-compression ratio.
31 | """
32 | super(AIRXBottleneck, self).__init__()
33 |
34 | D = int(math.floor(planes * (baseWidth / 64.0)))
35 | C = cardinality
36 | self.stride = stride
37 | self.planes = planes
38 |
39 | self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
40 | self.bn1 = nn.BatchNorm2d(D * C)
41 | self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
42 | self.bn2 = nn.BatchNorm2d(D * C)
43 | self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
44 | self.bn3 = nn.BatchNorm2d(planes * 4)
45 |
46 | if self.stride == 1 and self.planes < 512: # for C2, C3, C4 stages
47 | self.conv_att1 = nn.Conv2d(inplanes, D * C // ratio, kernel_size=1, stride=1, padding=0, bias=False)
48 | self.bn_att1 = nn.BatchNorm2d(D * C // ratio)
49 | self.subsample = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
50 | # self.conv_att2 = nn.Conv2d(D*C // ratio, D*C // ratio, kernel_size=3, stride=2, padding=1, groups=C//2, bias=False)
51 | # self.bn_att2 = nn.BatchNorm2d(D*C // ratio)
52 | self.conv_att3 = nn.Conv2d(D * C // ratio, D * C // ratio, kernel_size=3, stride=1,
53 | padding=1, groups=C // ratio, bias=False)
54 | self.bn_att3 = nn.BatchNorm2d(D * C // ratio)
55 | self.conv_att4 = nn.Conv2d(D * C // ratio, D * C, kernel_size=1, stride=1, padding=0, bias=False)
56 | self.bn_att4 = nn.BatchNorm2d(D * C)
57 | self.sigmoid = nn.Sigmoid()
58 |
59 | self.relu = nn.ReLU(inplace=True)
60 | self.downsample = downsample
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 | out = self.conv2(out)
69 | out = self.bn2(out)
70 | out = self.relu(out)
71 |
72 | if self.stride == 1 and self.planes < 512:
73 | att = self.conv_att1(x)
74 | att = self.bn_att1(att)
75 | att = self.relu(att)
76 | # att = self.conv_att2(att)
77 | # att = self.bn_att2(att)
78 | # att = self.relu(att)
79 | att = self.subsample(att)
80 | att = self.conv_att3(att)
81 | att = self.bn_att3(att)
82 | att = self.relu(att)
83 | att = F.upsample(att, size=out.size()[2:], mode='bilinear')
84 | att = self.conv_att4(att)
85 | att = self.bn_att4(att)
86 | att = self.sigmoid(att)
87 | out = out * att
88 |
89 | out = self.conv3(out)
90 | out = self.bn3(out)
91 |
92 | if self.downsample is not None:
93 | residual = self.downsample(x)
94 |
95 | out += residual
96 | out = self.relu(out)
97 |
98 | return out
99 |
100 |
101 | class AIRX(nn.Module):
102 | def __init__(self, baseWidth=4, cardinality=32, head7x7=True, ratio=2, layers=(3, 4, 23, 3), num_classes=1000):
103 | """ Constructor
104 | Args:
105 | baseWidth: baseWidth for AIRX.
106 | cardinality: number of convolution groups.
107 | layers: config of layers, e.g., [3, 4, 6, 3]
108 | num_classes: number of classes
109 | """
110 | super(AIRX, self).__init__()
111 | block = AIRXBottleneck
112 |
113 | self.cardinality = cardinality
114 | self.baseWidth = baseWidth
115 | self.inplanes = 64
116 |
117 | self.head7x7 = head7x7
118 | if self.head7x7:
119 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
120 | self.bn1 = nn.BatchNorm2d(64)
121 | else:
122 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
123 | self.bn1 = nn.BatchNorm2d(32)
124 | self.conv2 = nn.Conv2d(32, 32, 3, 1, 1, bias=False)
125 | self.bn2 = nn.BatchNorm2d(32)
126 | self.conv3 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
127 | self.bn3 = nn.BatchNorm2d(64)
128 | self.relu = nn.ReLU(inplace=True)
129 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
130 |
131 | self.layer1 = self._make_layer(block, 64, layers[0], 1, ratio)
132 | self.layer2 = self._make_layer(block, 128, layers[1], 2, ratio)
133 | self.layer3 = self._make_layer(block, 256, layers[2], 2, ratio)
134 | self.layer4 = self._make_layer(block, 512, layers[3], 2, ratio)
135 | self.avgpool = nn.AdaptiveAvgPool2d(1)
136 | self.fc = nn.Linear(512 * block.expansion, num_classes)
137 |
138 | for m in self.modules():
139 | if isinstance(m, nn.Conv2d):
140 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
141 | m.weight.data.normal_(0, math.sqrt(2. / n))
142 | elif isinstance(m, nn.BatchNorm2d):
143 | m.weight.data.fill_(1)
144 | m.bias.data.zero_()
145 |
146 | def _make_layer(self, block, planes, blocks, stride=1, ratio=2):
147 | """ Stack n bottleneck modules where n is inferred from the depth of the network.
148 | Args:
149 | block: block type used to construct ResNext
150 | planes: number of output channels (need to multiply by block.expansion)
151 | blocks: number of blocks to be built
152 | stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
153 | Returns: a Module consisting of n sequential bottlenecks.
154 | """
155 | downsample = None
156 | if stride != 1 or self.inplanes != planes * block.expansion:
157 | downsample = nn.Sequential(
158 | nn.Conv2d(self.inplanes, planes * block.expansion,
159 | kernel_size=1, stride=stride, bias=False),
160 | nn.BatchNorm2d(planes * block.expansion),
161 | )
162 |
163 | layers = []
164 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, ratio, downsample))
165 | self.inplanes = planes * block.expansion
166 | for i in range(1, blocks):
167 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, 1, ratio))
168 |
169 | return nn.Sequential(*layers)
170 |
171 | def forward(self, x):
172 | if self.head7x7:
173 | x = self.conv1(x)
174 | x = self.bn1(x)
175 | x = self.relu(x)
176 | else:
177 | x = self.conv1(x)
178 | x = self.bn1(x)
179 | x = self.relu(x)
180 | x = self.conv2(x)
181 | x = self.bn2(x)
182 | x = self.relu(x)
183 | x = self.conv3(x)
184 | x = self.bn3(x)
185 | x = self.relu(x)
186 | x = self.maxpool1(x)
187 | x = self.layer1(x)
188 | x = self.layer2(x)
189 | x = self.layer3(x)
190 | x = self.layer4(x)
191 | x = self.avgpool(x)
192 | x = x.view(x.size(0), -1)
193 | x = self.fc(x)
194 |
195 | return x
196 |
197 |
198 | def airx50_32x4d():
199 | model = AIRX(baseWidth=4, cardinality=32, head7x7=False, layers=(3, 4, 6, 3), num_classes=1000)
200 | return model
201 |
202 |
203 | def airx101_32x4d_r16():
204 | model = AIRX(baseWidth=4, cardinality=32, head7x7=False, ratio=16, layers=(3, 4, 23, 3), num_classes=1000)
205 | return model
206 |
207 |
208 | def airx101_32x4d_r2():
209 | model = AIRX(baseWidth=4, cardinality=32, head7x7=False, layers=(3, 4, 23, 3), num_classes=1000)
210 | return model
211 |
--------------------------------------------------------------------------------
/models/airx.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soeaver/AirNet-PyTorch/e9dc06fabbde828109c4f75d8f2907ed1a3d0014/models/airx.pyc
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .measure import *
2 |
--------------------------------------------------------------------------------
/utils/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soeaver/AirNet-PyTorch/e9dc06fabbde828109c4f75d8f2907ed1a3d0014/utils/__init__.pyc
--------------------------------------------------------------------------------
/utils/measure.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from functools import reduce
5 | import operator
6 |
7 | """
8 | Adapted from https://github.com/ShichenLiu/CondenseNet/blob/master/utils.py
9 | """
10 |
11 | __all__ = ['measure_model', 'weight_filler']
12 |
13 | count_ops = 0
14 | conv_ops = 0
15 | count_params = 0
16 |
17 |
18 | def get_num_gen(gen):
19 | return sum(1 for x in gen)
20 |
21 |
22 | def is_leaf(model):
23 | return get_num_gen(model.children()) == 0
24 |
25 |
26 | def get_layer_info(layer):
27 | layer_str = str(layer)
28 | type_name = layer_str[:layer_str.find('(')].strip()
29 | return type_name
30 |
31 |
32 | def get_layer_param(model):
33 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()])
34 |
35 |
36 | ### The input batch size should be 1 to call this function
37 | def measure_layer(layer, x):
38 | global count_ops, conv_ops, count_params
39 | delta_ops = 0
40 | delta_params = 0
41 | multi_add = 1
42 | type_name = get_layer_info(layer)
43 |
44 | ### ops_conv
45 | if type_name in ['Conv2d']:
46 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
47 | layer.stride[0] + 1)
48 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
49 | layer.stride[1] + 1)
50 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
51 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
52 | # print (str(layer), delta_ops)
53 | conv_ops += delta_ops
54 | delta_params = get_layer_param(layer)
55 |
56 | ### ops_nonlinearity
57 | elif type_name in ['ReLU', 'Sigmoid', 'PReLU', 'ReLU6']:
58 | delta_ops = x.numel()
59 | delta_params = get_layer_param(layer)
60 |
61 | ### ops_pooling
62 | elif type_name in ['AvgPool2d', 'MaxPool2d']:
63 |         in_h, in_w = x.size()[2], x.size()[3]
64 |         kernel_ops = layer.kernel_size * layer.kernel_size
65 |         out_h = int((in_h + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
66 |         out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
67 |         delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
68 | delta_params = get_layer_param(layer)
69 |
70 | ### ops_pooling3d
71 | elif type_name in ['AvgPool3d', 'MaxPool3d']:
72 | in_c = x.size()[2]
73 | kernel_ops = layer.kernel_size[0] * layer.kernel_size[0]
74 | out_c = int((in_c + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0] + 1)
75 | delta_ops = x.size()[0] * x.size()[1] * out_c * x.size()[3] * x.size()[4] * kernel_ops
76 | delta_params = get_layer_param(layer)
77 |
78 | elif type_name in ['AdaptiveAvgPool2d']:
79 | delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3]
80 | delta_params = get_layer_param(layer)
81 |
82 | ### ops_linear
83 | elif type_name in ['Linear']:
84 | weight_ops = layer.weight.numel() * multi_add
85 | bias_ops = layer.bias.numel()
86 | delta_ops = x.size()[0] * (weight_ops + bias_ops)
87 | delta_params = get_layer_param(layer)
88 |
89 | ### ops_nothing
90 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout', 'Sequential', 'upsample_bilinear']:
91 | delta_params = get_layer_param(layer)
92 |
93 | ### unknown layer type
94 | else:
95 | raise TypeError('unknown layer type: %s' % type_name)
96 |
97 | count_ops += delta_ops
98 | count_params += delta_params
99 | # print type_name, delta_ops, delta_params
100 | return
101 |
102 |
103 | def measure_model(model, H, W):
104 | global count_ops, conv_ops, count_params
105 | count_ops = 0
106 | conv_ops = 0
107 | count_params = 0
108 | data = Variable(torch.zeros(1, 3, H, W))
109 |
110 | def should_measure(x):
111 | return is_leaf(x)
112 |
113 | def modify_forward(model):
114 | for child in model.children():
115 | if should_measure(child):
116 | def new_forward(m):
117 | def lambda_forward(x):
118 | measure_layer(m, x)
119 | return m.old_forward(x)
120 |
121 | return lambda_forward
122 |
123 | child.old_forward = child.forward
124 | child.forward = new_forward(child)
125 | else:
126 | modify_forward(child)
127 |
128 | def restore_forward(model):
129 | for child in model.children():
130 | # leaf node
131 | if is_leaf(child) and hasattr(child, 'old_forward'):
132 | child.forward = child.old_forward
133 | child.old_forward = None
134 | else:
135 | restore_forward(child)
136 |
137 | modify_forward(model)
138 | model.forward(data)
139 | restore_forward(model)
140 |
141 | return count_ops, conv_ops, count_params
142 |
143 |
144 | def weight_filler(src, dst):
145 | updated_dict = dst.copy()
146 | match_layers = []
147 | mismatch_layers = []
148 | for dst_k in dst:
149 | if dst_k in src:
150 | src_k = dst_k
151 | if src[src_k].shape == dst[dst_k].shape:
152 | match_layers.append(dst_k)
153 | updated_dict[dst_k] = src[src_k]
154 | else:
155 | mismatch_layers.append(dst_k)
156 | elif dst_k.replace('module.', '') in src:
157 | src_k = dst_k.replace('module.', '')
158 | if src[src_k].shape == dst[dst_k].shape:
159 | match_layers.append(dst_k)
160 | updated_dict[dst_k] = src[src_k]
161 | else:
162 | mismatch_layers.append(dst_k)
163 | elif 'module.' + dst_k in src:
164 | src_k = 'module.' + dst_k
165 | if src[src_k].shape == dst[dst_k].shape:
166 | match_layers.append(dst_k)
167 | updated_dict[dst_k] = src[src_k]
168 | else:
169 | mismatch_layers.append(dst_k)
170 | else:
171 | mismatch_layers.append(dst_k)
172 |
173 | return updated_dict, match_layers, mismatch_layers
--------------------------------------------------------------------------------
/utils/measure.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soeaver/AirNet-PyTorch/e9dc06fabbde828109c4f75d8f2907ed1a3d0014/utils/measure.pyc
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numbers
4 | import random
5 | import collections
6 | import numpy as np
7 | from PIL import Image
8 |
9 |
10 | def bgr2rgb(im):
11 | rgb_im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
12 | return rgb_im
13 |
14 |
15 | def rgb2bgr(im):
16 | bgr_im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
17 | return bgr_im
18 |
19 |
20 | def normalize(im, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), rgb=False):
21 | if rgb:
22 | r, g, b = cv2.split(im)
23 | else:
24 | b, g, r = cv2.split(im)
25 | norm_im = cv2.merge([(b - mean[0]) / std[0], (g - mean[1]) / std[1], (r - mean[2]) / std[2]])
26 | return norm_im
27 |
28 |
29 | def scale(im, short_size=256, max_size=1e5, interp=cv2.INTER_LINEAR):
30 | """ support gray im; interp: cv2.INTER_LINEAR (default) or cv2.INTER_NEAREST; """
31 | im_size_min = np.min(im.shape[0:2])
32 | im_size_max = np.max(im.shape[0:2])
33 | scale_ratio = float(short_size) / float(im_size_min)
34 | if np.round(scale_ratio * im_size_max) > float(max_size):
35 | scale_ratio = float(max_size) / float(im_size_max)
36 |
37 | scale_im = cv2.resize(im, None, None, fx=scale_ratio, fy=scale_ratio, interpolation=interp)
38 |
39 | return scale_im, scale_ratio
40 |
41 |
42 | def scale_by_max(im, long_size=512, interp=cv2.INTER_LINEAR):
43 | """ support gray im; interp: cv2.INTER_LINEAR (default) or cv2.INTER_NEAREST; """
44 | im_size_max = np.max(im.shape[0:2])
45 | scale_ratio = float(long_size) / float(im_size_max)
46 |
47 | scale_im = cv2.resize(im, None, None, fx=scale_ratio, fy=scale_ratio, interpolation=interp)
48 |
49 | return scale_im, scale_ratio
50 |
51 |
52 | def scale_by_target(im, target_size=(512, 256), interp=cv2.INTER_LINEAR):
53 | """ target_size=(h, w), support gray im; interp: cv2.INTER_LINEAR (default) or cv2.INTER_NEAREST; """
54 | min_factor = min(float(target_size[0]) / float(im.shape[0]),
55 | float(target_size[1]) / float(im.shape[1]))
56 |
57 | scale_im = cv2.resize(im, None, None, fx=min_factor, fy=min_factor, interpolation=interp)
58 |
59 | return scale_im, min_factor
60 |
61 |
62 | def rotate(im, degree=0, borderValue=(0, 0, 0), interp=cv2.INTER_LINEAR):
63 | """ support gray im; interp: cv2.INTER_LINEAR (default) or cv2.INTER_NEAREST; """
64 | h, w = im.shape[:2]
65 | rotate_mat = cv2.getRotationMatrix2D((w / 2, h / 2), degree, 1)
66 | rotation = cv2.warpAffine(im, rotate_mat, (w, h), flags=interp,
67 |                               borderValue=(borderValue[0], borderValue[1], borderValue[2]))
68 |
69 | return rotation
70 |
71 |
72 | def HSV_adjust(im, color=1.0, contrast=1.0, brightness=1.0):
73 | _HSV = np.dot(cv2.cvtColor(im, cv2.COLOR_BGR2HSV).reshape((-1, 3)),
74 | np.array([[color, 0, 0], [0, contrast, 0], [0, 0, brightness]]))
75 |
76 | _HSV_H = np.where(_HSV < 255, _HSV, 255)
77 | hsv = cv2.cvtColor(np.uint8(_HSV_H.reshape((-1, im.shape[1], 3))), cv2.COLOR_HSV2BGR)
78 |
79 | return hsv
80 |
81 |
82 | def salt_pepper(im, SNR=1.0):
83 | """ SNR: better >= 0.9; """
84 | noise_num = int((1 - SNR) * im.shape[0] * im.shape[1])
85 | noise_im = im.copy()
86 |     for i in range(noise_num):
87 |         rand_x = np.random.randint(0, im.shape[0])
88 |         rand_y = np.random.randint(0, im.shape[1])
89 |
90 |         if np.random.randint(0, 2) == 0:
91 | noise_im[rand_x, rand_y] = 0
92 | else:
93 | noise_im[rand_x, rand_y] = 255
94 |
95 | return noise_im
96 |
97 |
98 | def padding_im(im, target_size=(512, 512), borderType=cv2.BORDER_CONSTANT, mode=0):
99 | """ support gray im; target_size=(h, w); mode=0 left-top, mode=1 center; """
100 | if mode not in (0, 1):
101 | raise Exception("mode need to be one of 0 or 1, 0 for left-top mode, 1 for center mode.")
102 |
103 | pad_h_top = max(int((target_size[0] - im.shape[0]) * 0.5), 0) * mode
104 | pad_h_bottom = max(target_size[0] - im.shape[0], 0) - pad_h_top
105 | pad_w_left = max(int((target_size[1] - im.shape[1]) * 0.5), 0) * mode
106 | pad_w_right = max(target_size[1] - im.shape[1], 0) - pad_w_left
107 |
108 | if borderType == cv2.BORDER_CONSTANT:
109 | pad_im = cv2.copyMakeBorder(im, pad_h_top, pad_h_bottom, pad_w_left, pad_w_right, cv2.BORDER_CONSTANT)
110 | else:
111 | pad_im = cv2.copyMakeBorder(im, pad_h_top, pad_h_bottom, pad_w_left, pad_w_right, borderType)
112 |
113 | return pad_im
114 |
115 |
116 | def extend_bbox(im, bbox, margin=(0.5, 0.5, 0.5, 0.5)):
117 | box_w = int(bbox[2] - bbox[0])
118 | box_h = int(bbox[3] - bbox[1])
119 |
120 | new_x1 = max(1, bbox[0] - margin[0] * box_w)
121 | new_y1 = max(1, bbox[1] - margin[1] * box_h)
122 | new_x2 = min(im.shape[1] - 1, bbox[2] + margin[2] * box_w)
123 | new_y2 = min(im.shape[0] - 1, bbox[3] + margin[3] * box_h)
124 |
125 | return np.asarray([new_x1, new_y1, new_x2, new_y2])
126 |
127 |
128 | def bbox_crop(im, bbox):
129 | return im[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]
130 |
131 |
132 | def center_crop(im, crop_size=224): # single crop
133 | im_size_min = min(im.shape[:2])
134 | if im_size_min < crop_size:
135 | return
136 | yy = int((im.shape[0] - crop_size) / 2)
137 | xx = int((im.shape[1] - crop_size) / 2)
138 | crop_im = im[yy: yy + crop_size, xx: xx + crop_size]
139 |
140 | return crop_im
141 |
142 |
143 | def over_sample(im, crop_size=224): # 5 crops of image
144 | im_size_min = min(im.shape[:2])
145 | if im_size_min < crop_size:
146 | return
147 | yy = int((im.shape[0] - crop_size) / 2)
148 | xx = int((im.shape[1] - crop_size) / 2)
149 | sample_list = [im[:crop_size, :crop_size], im[-crop_size:, -crop_size:], im[:crop_size, -crop_size:],
150 | im[-crop_size:, :crop_size], im[yy: yy + crop_size, xx: xx + crop_size]]
151 |
152 | return sample_list
153 |
154 |
155 | def mirror_crop(im, crop_size=224): # 10 crops
156 | crop_list = []
157 | mirror = im[:, ::-1]
158 | crop_list.extend(over_sample(im, crop_size=crop_size))
159 | crop_list.extend(over_sample(mirror, crop_size=crop_size))
160 |
161 | return crop_list
162 |
163 |
164 | def multiscale_mirrorcrop(im, scales=(256, 288, 320, 352)): # 120(4*3*10) crops
165 | crop_list = []
166 | im_size_min = np.min(im.shape[0:2])
167 | for i in scales:
168 |         resize_im = cv2.resize(im, (int(im.shape[1] * i / im_size_min), int(im.shape[0] * i / im_size_min)))
169 |         yy = int((resize_im.shape[0] - i) / 2)
170 |         xx = int((resize_im.shape[1] - i) / 2)
171 |         for j in range(3):
172 | left_center_right = resize_im[yy * j: yy * j + i, xx * j: xx * j + i]
173 | mirror = left_center_right[:, ::-1]
174 | crop_list.extend(over_sample(left_center_right))
175 | crop_list.extend(over_sample(mirror))
176 |
177 | return crop_list
178 |
179 |
180 | def multi_scale(im, scales=(480, 576, 688, 864, 1200), max_sizes=(800, 1000, 1200, 1500, 1800), image_flip=False):
181 | im_size_min = np.min(im.shape[0:2])
182 | im_size_max = np.max(im.shape[0:2])
183 |
184 | scale_ims = []
185 | scale_ratios = []
186 |     for i in range(len(scales)):
187 | scale_ratio = float(scales[i]) / float(im_size_min)
188 | if np.round(scale_ratio * im_size_max) > float(max_sizes[i]):
189 | scale_ratio = float(max_sizes[i]) / float(im_size_max)
190 | resize_im = cv2.resize(im, None, None, fx=scale_ratio, fy=scale_ratio,
191 | interpolation=cv2.INTER_LINEAR)
192 | scale_ims.append(resize_im)
193 | scale_ratios.append(scale_ratio)
194 | if image_flip:
195 | scale_ims.append(cv2.resize(im[:, ::-1], None, None, fx=scale_ratio, fy=scale_ratio,
196 | interpolation=cv2.INTER_LINEAR))
197 | scale_ratios.append(-scale_ratio)
198 |
199 | return scale_ims, scale_ratios
200 |
201 |
202 | def multi_scale_by_max(im, scales=(480, 576, 688, 864, 1200), image_flip=False):
203 | im_size_max = np.max(im.shape[0:2])
204 |
205 | scale_ims = []
206 | scale_ratios = []
207 |     for i in range(len(scales)):
208 | scale_ratio = float(scales[i]) / float(im_size_max)
209 |
210 | resize_im = cv2.resize(im, None, None, fx=scale_ratio, fy=scale_ratio, interpolation=cv2.INTER_LINEAR)
211 | scale_ims.append(resize_im)
212 | scale_ratios.append(scale_ratio)
213 | if image_flip:
214 | scale_ims.append(cv2.resize(im[:, ::-1], None, None, fx=scale_ratio, fy=scale_ratio,
215 | interpolation=cv2.INTER_LINEAR))
216 | scale_ratios.append(-scale_ratio)
217 |
218 | return scale_ims, scale_ratios
219 |
220 |
221 | def pil_resize(im, size, interpolation=Image.BILINEAR):
222 | if isinstance(size, int):
223 | w, h = im.size
224 | if (w <= h and w == size) or (h <= w and h == size):
225 | return im
226 | if w < h:
227 | ow = size
228 | oh = int(size * h / w)
229 | return im.resize((ow, oh), interpolation)
230 | else:
231 | oh = size
232 | ow = int(size * w / h)
233 | return im.resize((ow, oh), interpolation)
234 | else:
235 | return im.resize(size[::-1], interpolation)
236 |
237 |
--------------------------------------------------------------------------------
/utils/transforms.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soeaver/AirNet-PyTorch/e9dc06fabbde828109c4f75d8f2907ed1a3d0014/utils/transforms.pyc
--------------------------------------------------------------------------------