├── Datasets.md
├── LICENSE
├── MODEL_ZOO.md
├── README.md
├── infer.py
├── pics
│   ├── 000000000885.jpg
│   ├── backbones.png
│   ├── detection.png
│   ├── example_inference.jpeg
│   ├── example_inference_open_images.jpeg
│   ├── herbarium.png
│   ├── loss_graph.png
│   ├── main_pic.png
│   ├── ms_coco_scores.png
│   └── open_images.png
├── requirements.txt
├── src
│   ├── helper_functions
│   │   └── helper_functions.py
│   ├── loss_functions
│   │   └── losses.py
│   └── models
│       ├── __init__.py
│       ├── tresnet
│       │   ├── __init__.py
│       │   ├── layers
│       │   │   ├── anti_aliasing.py
│       │   │   ├── avg_pool.py
│       │   │   └── general_layers.py
│       │   └── tresnet.py
│       └── utils
│           ├── __init__.py
│           └── factory.py
├── tests
│   └── test_asl.py
├── train.py
└── validate.py
/Datasets.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | - The MS-COCO dataset can be downloaded from the [official site](https://cocodataset.org/#download). As commonly done, we use the 2014 Train/Val annotations.
4 |
5 | - The Open-Images V6 dataset that was used for the paper is now directly available for download from [here](https://github.com/Alibaba-MIIL/PartialLabelingCSL/blob/main/OpenImages.md).
6 |
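
For reference, [train.py](train.py) in this repository expects the MS-COCO root directory (passed as its positional `data` argument) to be organized roughly as follows; the root path below is just an example:

```
/home/MSCOCO_2014/
├── annotations
│   ├── instances_train2014.json
│   └── instances_val2014.json
├── train2014        # training images
└── val2014          # validation images
```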
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Alibaba-MIIL
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | # ASL pre-trained models
2 |
3 | We provide a collection of models trained with ASL on various multi-label datasets.
4 |
5 |
6 | | Backbone | Input Size | Dataset | mAP |
7 | | ------------ | :--------------: | :--------------: | :--------------: |
8 | | [TResNet_M (ImageNet21K pretrain)](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/MS_COCO_TRresNet_M_224_81.8.pth) | 224 | MS-COCO | 81.8 |
9 | | [TResNet_L](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/MS_COCO_TRresNet_L_448_86.6.pth) | 448 | MS-COCO | 86.6 |
10 | | [TResNet_XL](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/MS_COCO_TResNet_xl_640_88.4.pth) | 640 | MS-COCO | 88.4 |
11 | | [TResNet_L](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/Open_ImagesV6_TRresNet_L_448.pth) | 448 | OpenImagesV6 | 86.3 |
12 | | [TResNet_XL](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/PASCAL_VOC_TResNet_xl_448_96.0.pth) | 448 | PASCAL-VOC | 96.0 |
13 | | [TResNet_L](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/NUS_WIDE_TRresNet_L_448_65.2.pth) | 448 | NUS-WIDE [dataset](https://drive.google.com/file/d/0B7IzDz-4yH_HMFdiSE44R1lselE/view) [split](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/nus_wid_data.csv) | 65.2 |
14 |
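
For orientation, here is a minimal loading sketch mirroring what `infer.py` and `validate.py` do. The checkpoint filename is one example from the table above, and the sketch assumes the checkpoints above, which store `num_classes`, the `model` weights and an `idx_to_class` mapping:

```python
import torch
from argparse import Namespace
from src.models import create_model

# example checkpoint from the table above; adjust the path to wherever you saved it
state = torch.load('./MS_COCO_TRresNet_L_448_86.6.pth', map_location='cpu')

# create_model() only needs model_name, num_classes and do_bottleneck_head
# (see src/models/utils/factory.py and src/models/tresnet/tresnet.py)
args = Namespace(model_name='tresnet_l',
                 num_classes=state['num_classes'],
                 do_bottleneck_head=False)  # True only for the OpenImagesV6 checkpoint
model = create_model(args).cuda()
model.load_state_dict(state['model'], strict=True)
model.eval()
```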
15 |
16 |
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Asymmetric Loss For Multi-Label Classification
2 |
3 | [PapersWithCode leaderboard: multi-label classification on MS-COCO](https://paperswithcode.com/sota/multi-label-classification-on-ms-coco?p=imagenet-21k-pretraining-for-the-masses)
4 | [PapersWithCode leaderboard: multi-label classification on NUS-WIDE](https://paperswithcode.com/sota/multi-label-classification-on-nus-wide?p=asymmetric-loss-for-multi-label)
5 | [PapersWithCode leaderboard: multi-label classification on PASCAL-VOC 2007](https://paperswithcode.com/sota/multi-label-classification-on-pascal-voc-2007?p=imagenet-21k-pretraining-for-the-masses)
6 | [Paper](https://arxiv.org/abs/2009.14119) |
7 | [Pretrained models](MODEL_ZOO.md) |
8 | [Datasets](Datasets.md)
9 |
10 | Official PyTorch Implementation
11 |
12 | > Emanuel Ben-Baruch, Tal Ridnik, Nadav Zamir, Asaf Noy, Itamar
13 | > Friedman, Matan Protter, Lihi Zelnik-Manor<br/> DAMO Academy, Alibaba
14 | > Group
15 |
16 | **Abstract**
17 |
18 | In a typical multi-label setting, a picture contains on average few positive labels and many negative ones. This positive-negative imbalance dominates the optimization process, and can lead to under-emphasizing gradients from positive labels during training, resulting in poor accuracy. In this paper, we introduce a novel asymmetric loss ("ASL"), which operates differently on positive and negative samples. The loss dynamically down-weights and hard-thresholds easy negative samples, while also discarding possibly mislabeled samples. We demonstrate how ASL can balance the probabilities of different samples, and how this balancing is translated to better mAP scores. With ASL, we reach state-of-the-art results on multiple popular multi-label datasets: MS-COCO, Pascal-VOC, NUS-WIDE and Open Images. We also demonstrate ASL's applicability to other tasks, such as single-label classification and object detection. ASL is effective, easy to implement, and does not increase the training time or complexity.
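
Informally, in the notation of the implementation provided in [losses.py](src/loss_functions/losses.py) (here `p` is the predicted sigmoid probability of a label and `m` is the probability margin, i.e. the `clip` parameter; the class defaults are `gamma_neg=4, gamma_pos=1, m=0.05`, while `train.py` uses `gamma_pos=0`), the per-label loss terms can be summarized as:

```
p    = sigmoid(logit)                        # per-label probability
p_m  = max(p - m, 0)                         # shifted ("clipped") probability for negatives
L+   = (1 - p)^gamma_pos * log(p)            # positive-label term
L-   = (p_m)^gamma_neg   * log(1 - p_m)      # negative-label term
ASL  = -sum over labels of ( y * L+  +  (1 - y) * L- )
```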
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | ## 9/1/2023 Update
30 | Added [tests](https://github.com/Alibaba-MIIL/ASL/blob/main/tests/test_asl.py) auto-generated by the [CodiumAI](https://www.codium.ai/) tool
31 |
32 | ## 29/11/2021 Update - New article released, offering a new classification head with state-of-the-art results
33 | Check out our new project, [ML-Decoder](https://github.com/Alibaba-MIIL/ML_Decoder), which presents a unified classification head for multi-label, single-label and
34 | zero-shot tasks. Backbones with ML-Decoder reach SOTA results, while also improving the speed-accuracy tradeoff.
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 | ## 24/7/2021 Update - ASL article was accepted to ICCV 2021
47 | A final version of the paper, with updated results for ImageNet-21K pretraining, has been released to arXiv.
48 | Note that ASL is becoming the de-facto 'default' loss for high-performance multi-label classification, and all the top results on papers-with-code currently use it.
49 |
50 |
51 |
52 |
53 |
54 | ## Training Code Now Available!
55 |
56 | Thanks to a great collaboration with [@GhostWnd](https://github.com/GhostWnd), we
57 | now provide a [script](train.py) that fully reproduces the article
58 | results, finally making a modern multi-label training codebase
59 | available to the community.
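
For example, a typical MS-COCO run might look like the following (the dataset path is a placeholder; all flags and their defaults are defined in [train.py](train.py)):
```
python train.py /home/MSCOCO_2014/ \
--model-name=tresnet_m \
--model-path=./tresnet_m.pth \
--image-size=224 \
--batch-size=128 \
--lr=1e-4
```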
60 | ## Frequently Asked Questions
61 | Some questions are repeatedly asked in the issues section. Make sure to
62 | review them before starting a new issue:
63 | - Regarding combining ASL with other techniques, see
64 | [link](https://github.com/Alibaba-MIIL/ASL/issues/35)
65 | - Regarding implementation of asymmetric clipping, see [link](https://github.com/Alibaba-MIIL/ASL/issues/10)
66 | - Regarding disable_torch_grad_focal_loss option, see
67 | [link](https://github.com/Alibaba-MIIL/ASL/issues/31)
68 | - Regarding squish vs. crop resizing, see
69 | [link](https://github.com/Alibaba-MIIL/ASL/issues/30#issuecomment-754005570)
70 | - Regarding training tricks, see
71 | [link](https://github.com/Alibaba-MIIL/ASL/issues/30#issuecomment-750780576)
72 | - How to apply ASL to your own dataset, see
73 | [link](https://github.com/Alibaba-MIIL/ASL/issues/22#issuecomment-736721770)
74 |
75 |
76 |
77 | ## Asymmetric Loss (ASL) Implementation
78 | In this PyTorch [file](src/loss_functions/losses.py), we provide
79 | implementations of our new loss function, ASL, which can serve as a
80 | drop-in replacement for standard loss functions (Cross-Entropy and
81 | Focal-Loss).
82 |
83 | For the multi-label case (sigmoids), the two implementations are:
84 | - ```class AsymmetricLoss(nn.Module)```
85 | - ```class AsymmetricLossOptimized(nn.Module)```
86 |
87 | The two losses are bit-accurate. However, AsymmetricLossOptimized()
88 | implements ASL in a more optimized (and more involved) way that
89 | minimizes memory allocations and GPU uploads, and favors in-place
90 | operations.
91 |
92 | For the single-label case (softmax), the implementation is called:
93 | - ```class ASLSingleLabel(nn.Module)```
94 |
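A minimal usage sketch (random tensors and default hyper-parameters, shown only to illustrate the expected inputs; tune `gamma_neg`, `gamma_pos` and `clip` for your data):

```python
import torch
from src.loss_functions.losses import AsymmetricLoss, AsymmetricLossOptimized, ASLSingleLabel

# multi-label: raw logits (the sigmoid is applied inside the loss) and multi-hot targets
logits  = torch.randn(8, 80, requires_grad=True)   # (batch_size, num_classes)
targets = torch.randint(0, 2, (8, 80)).float()
criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=1, clip=0.05)
loss = criterion(logits, targets)
loss.backward()

# the optimized variant computes the same value (cf. the assert in infer.py)
assert torch.allclose(loss, AsymmetricLossOptimized()(logits, targets))

# single-label: integer class indices, the softmax is applied inside the loss
loss_sl = ASLSingleLabel(gamma_pos=0, gamma_neg=4)(torch.randn(8, 1000), torch.randint(0, 1000, (8,)))
```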
95 |
96 |
97 | ## Pretrained Models
98 | In this [link](MODEL_ZOO.md), we provide pre-trained models on various
99 | datasets.
100 |
101 | ## Validation Code
102 | Thanks to an external contribution by @hellbell, we now provide
103 | validation code that reproduces the article results on MS-COCO:
104 |
105 | ```
106 | python validate.py \
107 | --model_name=tresnet_l \
108 | --model_path=./models_local/MS_COCO_TRresNet_L_448_86.6.pth
109 | ```
110 |
111 | ## Inference Code
112 | We provide [inference code](infer.py) that demonstrates how to load our
113 | model, pre-process an image and run actual inference. Example run of the
114 | MS-COCO model (after downloading the relevant model):
115 | ```
116 | python infer.py \
117 | --dataset_type=MS-COCO \
118 | --model_name=tresnet_l \
119 | --model_path=./models_local/MS_COCO_TRresNet_L_448_86.6.pth \
120 | --pic_path=./pics/000000000885.jpg \
121 | --input_size=448
122 | ```
123 | which will result in:
124 |
125 |
126 |
127 | ![example inference](./pics/example_inference.jpeg)
128 |
129 |
130 |
131 |
132 | Example run of the OpenImages model:
133 | ```
134 | python infer.py \
135 | --dataset_type=OpenImages \
136 | --model_name=tresnet_l \
137 | --model_path=./models_local/Open_ImagesV6_TRresNet_L_448.pth \
138 | --pic_path=./pics/000000000885.jpg \
139 | --input_size=448
140 | ```
141 |
142 |
143 |
144 | ![example inference on Open Images](./pics/example_inference_open_images.jpeg)
145 |
146 |
147 |
148 |
149 |
150 | ## Citation
151 | ```
152 | @misc{benbaruch2020asymmetric,
153 | title={Asymmetric Loss For Multi-Label Classification},
154 | author={Emanuel Ben-Baruch and Tal Ridnik and Nadav Zamir and Asaf Noy and Itamar Friedman and Matan Protter and Lihi Zelnik-Manor},
155 | year={2020},
156 | eprint={2009.14119},
157 | archivePrefix={arXiv},
158 | primaryClass={cs.CV} }
159 | ```
160 |
161 | ## Contact
162 | Feel free to contact us with any questions or issues - Emanuel
163 | Ben-Baruch (emanuel.benbaruch@alibaba-inc.com) or Tal Ridnik (tal.ridnik@alibaba-inc.com).
164 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src.helper_functions.helper_functions import parse_args
3 | from src.loss_functions.losses import AsymmetricLoss, AsymmetricLossOptimized
4 | from src.models import create_model
5 | import argparse
6 | import matplotlib
7 |
8 | matplotlib.use('TkAgg')
9 | import matplotlib.pyplot as plt
10 | from PIL import Image
11 | import numpy as np
12 |
13 | parser = argparse.ArgumentParser(description='ASL MS-COCO Inference on a single image')
14 |
15 | parser.add_argument('--model_path', type=str, default='./models_local/TRresNet_L_448_86.6.pth')
16 | parser.add_argument('--pic_path', type=str, default='./pics/000000000885.jpg')
17 | parser.add_argument('--model_name', type=str, default='tresnet_l')
18 | parser.add_argument('--input_size', type=int, default=448)
19 | parser.add_argument('--dataset_type', type=str, default='MS-COCO')
20 | parser.add_argument('--th', type=float, default=None)
21 |
22 |
23 | def main():
24 | print('ASL Example Inference code on a single image')
25 |
26 | # parsing args
27 | args = parse_args(parser)
28 |
29 | # setup model
30 | print('creating and loading the model...')
31 | state = torch.load(args.model_path, map_location='cpu')
32 | args.num_classes = state['num_classes']
33 | model = create_model(args).cuda()
34 | model.load_state_dict(state['model'], strict=True)
35 | model.eval()
36 | classes_list = np.array(list(state['idx_to_class'].values()))
37 | print('done\n')
38 |
39 | # doing inference
40 | print('loading image and doing inference...')
41 | im = Image.open(args.pic_path)
42 | im_resize = im.resize((args.input_size, args.input_size))
43 | np_img = np.array(im_resize, dtype=np.uint8)
44 | tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float() / 255.0 # HWC to CHW
45 | tensor_batch = torch.unsqueeze(tensor_img, 0).cuda()
46 | output = torch.squeeze(torch.sigmoid(model(tensor_batch)))
47 | np_output = output.cpu().detach().numpy()
48 | detected_classes = classes_list[np_output > args.th]
49 | print('done\n')
50 |
51 | # example loss calculation
52 | output = model(tensor_batch)
53 | loss_func1 = AsymmetricLoss()
54 | loss_func2 = AsymmetricLossOptimized()
55 | target = output.clone()
56 | target[output < 0] = 0 # mockup target
57 | target[output >= 0] = 1
58 | loss1 = loss_func1(output, target)
59 | loss2 = loss_func2(output, target)
60 | assert abs((loss1.item() - loss2.item())) < 1e-6
61 |
62 | # displaying image
63 | print('showing image on screen...')
64 | fig = plt.figure()
65 | plt.imshow(im)
66 | plt.axis('off')
67 | plt.axis('tight')
68 | # plt.rcParams["axes.titlesize"] = 10
69 | plt.title("detected classes: {}".format(detected_classes))
70 |
71 | plt.show()
72 | print('done\n')
73 |
74 |
75 | if __name__ == '__main__':
76 | main()
77 |
--------------------------------------------------------------------------------
/pics/000000000885.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/000000000885.jpg
--------------------------------------------------------------------------------
/pics/backbones.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/backbones.png
--------------------------------------------------------------------------------
/pics/detection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/detection.png
--------------------------------------------------------------------------------
/pics/example_inference.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/example_inference.jpeg
--------------------------------------------------------------------------------
/pics/example_inference_open_images.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/example_inference_open_images.jpeg
--------------------------------------------------------------------------------
/pics/herbarium.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/herbarium.png
--------------------------------------------------------------------------------
/pics/loss_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/loss_graph.png
--------------------------------------------------------------------------------
/pics/main_pic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/main_pic.png
--------------------------------------------------------------------------------
/pics/ms_coco_scores.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/ms_coco_scores.png
--------------------------------------------------------------------------------
/pics/open_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/open_images.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7
2 | torchvision>=0.5.0
3 | git+https://github.com/mapillary/inplace_abn.git@v1.0.12
4 | randaugment
5 | pycocotools
--------------------------------------------------------------------------------
/src/helper_functions/helper_functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 | import random
4 | import time
5 |
6 |
7 | import numpy as np
8 | from PIL import Image
9 | from torchvision import datasets as datasets
10 | import torch
11 | from PIL import ImageDraw
12 | from pycocotools.coco import COCO
13 |
14 |
15 | def parse_args(parser):
16 | # parsing args
17 | args = parser.parse_args()
18 | if args.dataset_type == 'OpenImages':
19 | args.do_bottleneck_head = True
20 |         if args.th is None:
21 | args.th = 0.995
22 | else:
23 | args.do_bottleneck_head = False
24 |         if args.th is None:
25 | args.th = 0.7
26 | return args
27 |
28 |
29 | def average_precision(output, target):
30 | epsilon = 1e-8
31 |
32 | # sort examples
33 | indices = output.argsort()[::-1]
34 | # Computes prec@i
35 | total_count_ = np.cumsum(np.ones((len(output), 1)))
36 |
37 | target_ = target[indices]
38 | ind = target_ == 1
39 | pos_count_ = np.cumsum(ind)
40 | total = pos_count_[-1]
41 | pos_count_[np.logical_not(ind)] = 0
42 | pp = pos_count_ / total_count_
43 | precision_at_i_ = np.sum(pp)
44 | precision_at_i = precision_at_i_ / (total + epsilon)
45 |
46 | return precision_at_i
47 |
48 |
49 | def mAP(targs, preds):
50 |     """Returns the model's mAP over all classes.
51 |     Return:
52 |         mAP (float): mean of the per-class average precisions, in percent (0-100)
53 |     """
54 |
55 | if np.size(preds) == 0:
56 | return 0
57 | ap = np.zeros((preds.shape[1]))
58 | # compute average precision for each class
59 | for k in range(preds.shape[1]):
60 | # sort scores
61 | scores = preds[:, k]
62 | targets = targs[:, k]
63 | # compute average precision
64 | ap[k] = average_precision(scores, targets)
65 | return 100 * ap.mean()
66 |
67 |
68 | class AverageMeter(object):
69 | def __init__(self):
70 | self.val = None
71 | self.sum = None
72 | self.cnt = None
73 | self.avg = None
74 | self.ema = None
75 | self.initialized = False
76 |
77 | def update(self, val, n=1):
78 | if not self.initialized:
79 | self.initialize(val, n)
80 | else:
81 | self.add(val, n)
82 |
83 | def initialize(self, val, n):
84 | self.val = val
85 | self.sum = val * n
86 | self.cnt = n
87 | self.avg = val
88 | self.ema = val
89 | self.initialized = True
90 |
91 | def add(self, val, n):
92 | self.val = val
93 | self.sum += val * n
94 | self.cnt += n
95 | self.avg = self.sum / self.cnt
96 | self.ema = self.ema * 0.99 + self.val * 0.01
97 |
98 |
99 | class CocoDetection(datasets.coco.CocoDetection):
100 | def __init__(self, root, annFile, transform=None, target_transform=None):
101 | self.root = root
102 | self.coco = COCO(annFile)
103 |
104 | self.ids = list(self.coco.imgToAnns.keys())
105 | self.transform = transform
106 | self.target_transform = target_transform
107 | self.cat2cat = dict()
108 | for cat in self.coco.cats.keys():
109 | self.cat2cat[cat] = len(self.cat2cat)
110 | # print(self.cat2cat)
111 |
112 | def __getitem__(self, index):
113 | coco = self.coco
114 | img_id = self.ids[index]
115 | ann_ids = coco.getAnnIds(imgIds=img_id)
116 | target = coco.loadAnns(ann_ids)
117 |
118 | output = torch.zeros((3, 80), dtype=torch.long)
119 | for obj in target:
120 | if obj['area'] < 32 * 32:
121 | output[0][self.cat2cat[obj['category_id']]] = 1
122 | elif obj['area'] < 96 * 96:
123 | output[1][self.cat2cat[obj['category_id']]] = 1
124 | else:
125 | output[2][self.cat2cat[obj['category_id']]] = 1
126 | target = output
127 |
128 | path = coco.loadImgs(img_id)[0]['file_name']
129 | img = Image.open(os.path.join(self.root, path)).convert('RGB')
130 | if self.transform is not None:
131 | img = self.transform(img)
132 |
133 | if self.target_transform is not None:
134 | target = self.target_transform(target)
135 | return img, target
136 |
137 |
138 | class ModelEma(torch.nn.Module):
139 | def __init__(self, model, decay=0.9997, device=None):
140 | super(ModelEma, self).__init__()
141 | # make a copy of the model for accumulating moving average of weights
142 | self.module = deepcopy(model)
143 | self.module.eval()
144 | self.decay = decay
145 | self.device = device # perform ema on different device from model if set
146 | if self.device is not None:
147 | self.module.to(device=device)
148 |
149 | def _update(self, model, update_fn):
150 | with torch.no_grad():
151 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
152 | if self.device is not None:
153 | model_v = model_v.to(device=self.device)
154 | ema_v.copy_(update_fn(ema_v, model_v))
155 |
156 | def update(self, model):
157 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
158 |
159 | def set(self, model):
160 | self._update(model, update_fn=lambda e, m: m)
161 |
162 |
163 | class CutoutPIL(object):
164 | def __init__(self, cutout_factor=0.5):
165 | self.cutout_factor = cutout_factor
166 |
167 | def __call__(self, x):
168 | img_draw = ImageDraw.Draw(x)
169 | h, w = x.size[0], x.size[1] # HWC
170 | h_cutout = int(self.cutout_factor * h + 0.5)
171 | w_cutout = int(self.cutout_factor * w + 0.5)
172 | y_c = np.random.randint(h)
173 | x_c = np.random.randint(w)
174 |
175 | y1 = np.clip(y_c - h_cutout // 2, 0, h)
176 | y2 = np.clip(y_c + h_cutout // 2, 0, h)
177 | x1 = np.clip(x_c - w_cutout // 2, 0, w)
178 | x2 = np.clip(x_c + w_cutout // 2, 0, w)
179 | fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
180 | img_draw.rectangle([x1, y1, x2, y2], fill=fill_color)
181 |
182 | return x
183 |
184 |
185 | def add_weight_decay(model, weight_decay=1e-4, skip_list=()):
186 | decay = []
187 | no_decay = []
188 | for name, param in model.named_parameters():
189 | if not param.requires_grad:
190 | continue # frozen weights
191 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
192 | no_decay.append(param)
193 | else:
194 | decay.append(param)
195 | return [
196 | {'params': no_decay, 'weight_decay': 0.},
197 | {'params': decay, 'weight_decay': weight_decay}]
198 |
--------------------------------------------------------------------------------
/src/loss_functions/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AsymmetricLoss(nn.Module):
6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
7 | super(AsymmetricLoss, self).__init__()
8 |
9 | self.gamma_neg = gamma_neg
10 | self.gamma_pos = gamma_pos
11 | self.clip = clip
12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
13 | self.eps = eps
14 |
15 | def forward(self, x, y):
16 |         """
17 | Parameters
18 | ----------
19 | x: input logits
20 | y: targets (multi-label binarized vector)
21 | """
22 |
23 | # Calculating Probabilities
24 | x_sigmoid = torch.sigmoid(x)
25 | xs_pos = x_sigmoid
26 | xs_neg = 1 - x_sigmoid
27 |
28 | # Asymmetric Clipping
29 | if self.clip is not None and self.clip > 0:
30 | xs_neg = (xs_neg + self.clip).clamp(max=1)
31 |
32 | # Basic CE calculation
33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
34 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
35 | loss = los_pos + los_neg
36 |
37 | # Asymmetric Focusing
38 | if self.gamma_neg > 0 or self.gamma_pos > 0:
39 | if self.disable_torch_grad_focal_loss:
40 | torch.set_grad_enabled(False)
41 | pt0 = xs_pos * y
42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
43 | pt = pt0 + pt1
44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma)
46 | if self.disable_torch_grad_focal_loss:
47 | torch.set_grad_enabled(True)
48 | loss *= one_sided_w
49 |
50 | return -loss.sum()
51 |
52 |
53 | class AsymmetricLossOptimized(nn.Module):
54 | ''' Notice - optimized version, minimizes memory allocation and gpu uploading,
55 | favors inplace operations'''
56 |
57 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
58 | super(AsymmetricLossOptimized, self).__init__()
59 |
60 | self.gamma_neg = gamma_neg
61 | self.gamma_pos = gamma_pos
62 | self.clip = clip
63 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
64 | self.eps = eps
65 |
66 | # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
67 | self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
68 |
69 | def forward(self, x, y):
70 |         """
71 | Parameters
72 | ----------
73 | x: input logits
74 | y: targets (multi-label binarized vector)
75 | """
76 |
77 | self.targets = y
78 | self.anti_targets = 1 - y
79 |
80 | # Calculating Probabilities
81 | self.xs_pos = torch.sigmoid(x)
82 | self.xs_neg = 1.0 - self.xs_pos
83 |
84 | # Asymmetric Clipping
85 | if self.clip is not None and self.clip > 0:
86 | self.xs_neg.add_(self.clip).clamp_(max=1)
87 |
88 | # Basic CE calculation
89 | self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
90 | self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
91 |
92 | # Asymmetric Focusing
93 | if self.gamma_neg > 0 or self.gamma_pos > 0:
94 | if self.disable_torch_grad_focal_loss:
95 | torch.set_grad_enabled(False)
96 | self.xs_pos = self.xs_pos * self.targets
97 | self.xs_neg = self.xs_neg * self.anti_targets
98 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
99 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
100 | if self.disable_torch_grad_focal_loss:
101 | torch.set_grad_enabled(True)
102 | self.loss *= self.asymmetric_w
103 |
104 | return -self.loss.sum()
105 |
106 |
107 | class ASLSingleLabel(nn.Module):
108 | '''
109 | This loss is intended for single-label classification problems
110 | '''
111 | def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean'):
112 | super(ASLSingleLabel, self).__init__()
113 |
114 | self.eps = eps
115 | self.logsoftmax = nn.LogSoftmax(dim=-1)
116 | self.targets_classes = []
117 | self.gamma_pos = gamma_pos
118 | self.gamma_neg = gamma_neg
119 | self.reduction = reduction
120 |
121 | def forward(self, inputs, target):
122 | '''
123 | "input" dimensions: - (batch_size,number_classes)
124 | "target" dimensions: - (batch_size)
125 | '''
126 | num_classes = inputs.size()[-1]
127 | log_preds = self.logsoftmax(inputs)
128 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
129 |
130 | # ASL weights
131 | targets = self.targets_classes
132 | anti_targets = 1 - targets
133 | xs_pos = torch.exp(log_preds)
134 | xs_neg = 1 - xs_pos
135 | xs_pos = xs_pos * targets
136 | xs_neg = xs_neg * anti_targets
137 | asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
138 | self.gamma_pos * targets + self.gamma_neg * anti_targets)
139 | log_preds = log_preds * asymmetric_w
140 |
141 | if self.eps > 0: # label smoothing
142 | self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)
143 |
144 | # loss calculation
145 | loss = - self.targets_classes.mul(log_preds)
146 |
147 | loss = loss.sum(dim=-1)
148 | if self.reduction == 'mean':
149 | loss = loss.mean()
150 |
151 | return loss
152 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import create_model
2 |
3 | __all__ = ['create_model']
4 |
--------------------------------------------------------------------------------
/src/models/tresnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .tresnet import TResnetM, TResnetL, TResnetXL
2 |
--------------------------------------------------------------------------------
/src/models/tresnet/layers/anti_aliasing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.parallel
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class AntiAliasDownsampleLayer(nn.Module):
9 | def __init__(self, remove_model_jit: bool = False, filt_size: int = 3, stride: int = 2,
10 | channels: int = 0):
11 | super(AntiAliasDownsampleLayer, self).__init__()
12 | if not remove_model_jit:
13 | self.op = DownsampleJIT(filt_size, stride, channels)
14 | else:
15 | self.op = Downsample(filt_size, stride, channels)
16 |
17 | def forward(self, x):
18 | return self.op(x)
19 |
20 |
21 | @torch.jit.script
22 | class DownsampleJIT(object):
23 | def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
24 | self.stride = stride
25 | self.filt_size = filt_size
26 | self.channels = channels
27 |
28 | assert self.filt_size == 3
29 | assert stride == 2
30 | a = torch.tensor([1., 2., 1.])
31 |
32 | filt = (a[:, None] * a[None, :]).clone().detach()
33 | filt = filt / torch.sum(filt)
34 | self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()
35 |
36 | def __call__(self, input: torch.Tensor):
37 | if input.dtype != self.filt.dtype:
38 | self.filt = self.filt.float()
39 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
40 | return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])
41 |
42 |
43 | class Downsample(nn.Module):
44 | def __init__(self, filt_size=3, stride=2, channels=None):
45 | super(Downsample, self).__init__()
46 | self.filt_size = filt_size
47 | self.stride = stride
48 | self.channels = channels
49 |
50 |
51 | assert self.filt_size == 3
52 | a = torch.tensor([1., 2., 1.])
53 |
54 | filt = (a[:, None] * a[None, :]).clone().detach()
55 | filt = filt / torch.sum(filt)
56 | self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
57 |
58 | def forward(self, input):
59 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
60 | return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
61 |
--------------------------------------------------------------------------------
/src/models/tresnet/layers/avg_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 | class FastAvgPool2d(nn.Module):
8 | def __init__(self, flatten=False):
9 | super(FastAvgPool2d, self).__init__()
10 | self.flatten = flatten
11 |
12 | def forward(self, x):
13 | if self.flatten:
14 | in_size = x.size()
15 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
16 | else:
17 | return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
18 |
19 |
20 |
--------------------------------------------------------------------------------
/src/models/tresnet/layers/general_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from src.models.tresnet.layers.avg_pool import FastAvgPool2d
6 |
7 |
8 | class Flatten(nn.Module):
9 | def forward(self, x):
10 | return x.view(x.size(0), -1)
11 |
12 |
13 | class DepthToSpace(nn.Module):
14 |
15 | def __init__(self, block_size):
16 | super().__init__()
17 | self.bs = block_size
18 |
19 | def forward(self, x):
20 | N, C, H, W = x.size()
21 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
22 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
23 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
24 | return x
25 |
26 |
27 | class SpaceToDepthModule(nn.Module):
28 | def __init__(self, remove_model_jit=False):
29 | super().__init__()
30 | if not remove_model_jit:
31 | self.op = SpaceToDepthJit()
32 | else:
33 | self.op = SpaceToDepth()
34 |
35 | def forward(self, x):
36 | return self.op(x)
37 |
38 |
39 | class SpaceToDepth(nn.Module):
40 | def __init__(self, block_size=4):
41 | super().__init__()
42 | assert block_size == 4
43 | self.bs = block_size
44 |
45 | def forward(self, x):
46 | N, C, H, W = x.size()
47 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
48 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
49 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
50 | return x
51 |
52 |
53 | @torch.jit.script
54 | class SpaceToDepthJit(object):
55 | def __call__(self, x: torch.Tensor):
56 | # assuming hard-coded that block_size==4 for acceleration
57 | N, C, H, W = x.size()
58 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
59 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
60 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
61 | return x
62 |
63 |
64 | class hard_sigmoid(nn.Module):
65 | def __init__(self, inplace=True):
66 | super(hard_sigmoid, self).__init__()
67 | self.inplace = inplace
68 |
69 | def forward(self, x):
70 | if self.inplace:
71 | return x.add_(3.).clamp_(0., 6.).div_(6.)
72 | else:
73 | return F.relu6(x + 3.) / 6.
74 |
75 |
76 | class SEModule(nn.Module):
77 |
78 | def __init__(self, channels, reduction_channels, inplace=True):
79 | super(SEModule, self).__init__()
80 | self.avg_pool = FastAvgPool2d()
81 | self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
82 | self.relu = nn.ReLU(inplace=inplace)
83 | self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
84 | # self.activation = hard_sigmoid(inplace=inplace)
85 | self.activation = nn.Sigmoid()
86 |
87 | def forward(self, x):
88 | x_se = self.avg_pool(x)
89 | x_se2 = self.fc1(x_se)
90 | x_se2 = self.relu(x_se2)
91 | x_se = self.fc2(x_se2)
92 | x_se = self.activation(x_se)
93 | return x * x_se
94 |
--------------------------------------------------------------------------------
/src/models/tresnet/tresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Module as Module
4 | from collections import OrderedDict
5 | from src.models.tresnet.layers.anti_aliasing import AntiAliasDownsampleLayer
6 | from .layers.avg_pool import FastAvgPool2d
7 | from .layers.general_layers import SEModule, SpaceToDepthModule
8 | from inplace_abn import InPlaceABN, ABN
9 |
10 |
11 | def InplacABN_to_ABN(module: nn.Module) -> nn.Module:
12 | # convert all InplaceABN layer to bit-accurate ABN layers.
13 | if isinstance(module, InPlaceABN):
14 | module_new = ABN(module.num_features, activation=module.activation,
15 | activation_param=module.activation_param)
16 | for key in module.state_dict():
17 | module_new.state_dict()[key].copy_(module.state_dict()[key])
18 | module_new.training = module.training
19 | module_new.weight.data = module_new.weight.abs() + module_new.eps
20 | return module_new
21 | for name, child in reversed(module._modules.items()):
22 | new_child = InplacABN_to_ABN(child)
23 | if new_child != child:
24 | module._modules[name] = new_child
25 | return module
26 |
27 | class bottleneck_head(nn.Module):
28 | def __init__(self, num_features, num_classes, bottleneck_features=200):
29 | super(bottleneck_head, self).__init__()
30 | self.embedding_generator = nn.ModuleList()
31 | self.embedding_generator.append(nn.Linear(num_features, bottleneck_features))
32 | self.embedding_generator = nn.Sequential(*self.embedding_generator)
33 | self.FC = nn.Linear(bottleneck_features, num_classes)
34 |
35 | def forward(self, x):
36 | self.embedding = self.embedding_generator(x)
37 | logits = self.FC(self.embedding)
38 | return logits
39 |
40 |
41 | def conv2d(ni, nf, stride):
42 | return nn.Sequential(
43 | nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False),
44 | nn.BatchNorm2d(nf),
45 | nn.ReLU(inplace=True)
46 | )
47 |
48 |
49 | def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1):
50 | return nn.Sequential(
51 | nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups,
52 | bias=False),
53 | InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param)
54 | )
55 |
56 |
57 | class BasicBlock(Module):
58 | expansion = 1
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
61 | super(BasicBlock, self).__init__()
62 | if stride == 1:
63 | self.conv1 = conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3)
64 | else:
65 | if anti_alias_layer is None:
66 | self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3)
67 | else:
68 | self.conv1 = nn.Sequential(conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
69 | anti_alias_layer(channels=planes, filt_size=3, stride=2))
70 |
71 | self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity")
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 | reduce_layer_planes = max(planes * self.expansion // 4, 64)
76 | self.se = SEModule(planes * self.expansion, reduce_layer_planes) if use_se else None
77 |
78 | def forward(self, x):
79 | if self.downsample is not None:
80 | residual = self.downsample(x)
81 | else:
82 | residual = x
83 |
84 | out = self.conv1(x)
85 | out = self.conv2(out)
86 |
87 | if self.se is not None: out = self.se(out)
88 |
89 | out += residual
90 |
91 | out = self.relu(out)
92 |
93 | return out
94 |
95 |
96 | class Bottleneck(Module):
97 | expansion = 4
98 |
99 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
100 | super(Bottleneck, self).__init__()
101 | self.conv1 = conv2d_ABN(inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu",
102 | activation_param=1e-3)
103 | if stride == 1:
104 | self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu",
105 | activation_param=1e-3)
106 | else:
107 | if anti_alias_layer is None:
108 | self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=2, activation="leaky_relu",
109 | activation_param=1e-3)
110 | else:
111 | self.conv2 = nn.Sequential(conv2d_ABN(planes, planes, kernel_size=3, stride=1,
112 | activation="leaky_relu", activation_param=1e-3),
113 | anti_alias_layer(channels=planes, filt_size=3, stride=2))
114 |
115 | self.conv3 = conv2d_ABN(planes, planes * self.expansion, kernel_size=1, stride=1,
116 | activation="identity")
117 |
118 | self.relu = nn.ReLU(inplace=True)
119 | self.downsample = downsample
120 | self.stride = stride
121 |
122 | reduce_layer_planes = max(planes * self.expansion // 8, 64)
123 | self.se = SEModule(planes, reduce_layer_planes) if use_se else None
124 |
125 | def forward(self, x):
126 | if self.downsample is not None:
127 | residual = self.downsample(x)
128 | else:
129 | residual = x
130 |
131 | out = self.conv1(x)
132 | out = self.conv2(out)
133 | if self.se is not None: out = self.se(out)
134 |
135 | out = self.conv3(out)
136 | out = out + residual # no inplace
137 | out = self.relu(out)
138 |
139 | return out
140 |
141 |
142 | class TResNet(Module):
143 |
144 | def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0,
145 | do_bottleneck_head=False,bottleneck_features=512):
146 | super(TResNet, self).__init__()
147 |
148 | # JIT layers
149 | space_to_depth = SpaceToDepthModule()
150 | anti_alias_layer = AntiAliasDownsampleLayer
151 | global_pool_layer = FastAvgPool2d(flatten=True)
152 |
153 | # TResnet stages
154 | self.inplanes = int(64 * width_factor)
155 | self.planes = int(64 * width_factor)
156 | conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3)
157 | layer1 = self._make_layer(BasicBlock, self.planes, layers[0], stride=1, use_se=True,
158 | anti_alias_layer=anti_alias_layer) # 56x56
159 | layer2 = self._make_layer(BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True,
160 | anti_alias_layer=anti_alias_layer) # 28x28
161 | layer3 = self._make_layer(Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True,
162 | anti_alias_layer=anti_alias_layer) # 14x14
163 | layer4 = self._make_layer(Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False,
164 | anti_alias_layer=anti_alias_layer) # 7x7
165 |
166 | # body
167 | self.body = nn.Sequential(OrderedDict([
168 | ('SpaceToDepth', space_to_depth),
169 | ('conv1', conv1),
170 | ('layer1', layer1),
171 | ('layer2', layer2),
172 | ('layer3', layer3),
173 | ('layer4', layer4)]))
174 |
175 | # head
176 | self.embeddings = []
177 | self.global_pool = nn.Sequential(OrderedDict([('global_pool_layer', global_pool_layer)]))
178 | self.num_features = (self.planes * 8) * Bottleneck.expansion
179 | if do_bottleneck_head:
180 | fc = bottleneck_head(self.num_features, num_classes,
181 | bottleneck_features=bottleneck_features)
182 | else:
183 | fc = nn.Linear(self.num_features , num_classes)
184 |
185 | self.head = nn.Sequential(OrderedDict([('fc', fc)]))
186 |
187 | # model initilization
188 | for m in self.modules():
189 | if isinstance(m, nn.Conv2d):
190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
191 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN):
192 | nn.init.constant_(m.weight, 1)
193 | nn.init.constant_(m.bias, 0)
194 |
195 | # residual connections special initialization
196 | for m in self.modules():
197 | if isinstance(m, BasicBlock):
198 | m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero
199 | if isinstance(m, Bottleneck):
200 | m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero
201 | if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)
202 |
203 | def _make_layer(self, block, planes, blocks, stride=1, use_se=True, anti_alias_layer=None):
204 | downsample = None
205 | if stride != 1 or self.inplanes != planes * block.expansion:
206 | layers = []
207 | if stride == 2:
208 | # avg pooling before 1x1 conv
209 | layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
210 | layers += [conv2d_ABN(self.inplanes, planes * block.expansion, kernel_size=1, stride=1,
211 | activation="identity")]
212 | downsample = nn.Sequential(*layers)
213 |
214 | layers = []
215 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=use_se,
216 | anti_alias_layer=anti_alias_layer))
217 | self.inplanes = planes * block.expansion
218 | for i in range(1, blocks): layers.append(
219 | block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
220 | return nn.Sequential(*layers)
221 |
222 | def forward(self, x):
223 | x = self.body(x)
224 | self.embeddings = self.global_pool(x)
225 | logits = self.head(self.embeddings)
226 | return logits
227 |
228 |
229 | def TResnetM(model_params):
230 | """Constructs a medium TResnet model.
231 | """
232 | in_chans = 3
233 | num_classes = model_params['num_classes']
234 | model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans)
235 | return model
236 |
237 |
238 | def TResnetL(model_params):
239 | """Constructs a large TResnet model.
240 | """
241 | in_chans = 3
242 | num_classes = model_params['num_classes']
243 | do_bottleneck_head = model_params['args'].do_bottleneck_head
244 | model = TResNet(layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2,
245 | do_bottleneck_head=do_bottleneck_head)
246 | return model
247 |
248 |
249 | def TResnetXL(model_params):
250 | """Constructs a xlarge TResnet model.
251 | """
252 | in_chans = 3
253 | num_classes = model_params['num_classes']
254 | model = TResNet(layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3)
255 |
256 | return model
257 |
--------------------------------------------------------------------------------
/src/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import create_model
2 | __all__ = ['create_model']
--------------------------------------------------------------------------------
/src/models/utils/factory.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger = logging.getLogger(__name__)
4 |
5 | from ..tresnet import TResnetM, TResnetL, TResnetXL
6 |
7 |
8 | def create_model(args):
9 | """Create a model
10 | """
11 | model_params = {'args': args, 'num_classes': args.num_classes}
12 | args = model_params['args']
13 | args.model_name = args.model_name.lower()
14 |
15 | if args.model_name=='tresnet_m':
16 | model = TResnetM(model_params)
17 | elif args.model_name=='tresnet_l':
18 | model = TResnetL(model_params)
19 | elif args.model_name=='tresnet_xl':
20 | model = TResnetXL(model_params)
21 | else:
22 | print("model: {} not found !!".format(args.model_name))
23 | exit(-1)
24 |
25 | return model
26 |
--------------------------------------------------------------------------------
/tests/test_asl.py:
--------------------------------------------------------------------------------
1 |
2 | # test_asl.py
3 | # tests auto-generated by https://www.codium.ai/
4 | # testing https://github.com/Alibaba-MIIL/ASL/blob/b9d01aff9f66ccddab6112e47f3ed0ceb59ad7f5/tests/test_asl.py#L6 class
5 |
6 | import unittest
7 | import torch
8 | from src.loss_functions.losses import AsymmetricLoss
9 |
10 | """
11 | Code Analysis for AsymmetricLoss() class:
12 | - This class is a custom loss function called AsymmetricLoss, which is a subclass of the nn.Module class.
13 | - It is used to calculate the loss between the input logits and the targets (multi-label binarized vector).
14 | - The __init__ method initializes the parameters of the class, such as gamma_neg, gamma_pos, clip, eps, and disable_torch_grad_focal_loss.
15 | - The forward method is used to calculate the loss between the input logits and the targets.
16 | - The forward method first calculates the sigmoid of the input logits and then calculates the positive and negative logits.
17 | - If the clip parameter is not None and greater than 0, the negative logits are clipped to a maximum of 1.
18 | - The loss is then calculated using the positive and negative logits.
19 | - If the gamma_neg and gamma_pos parameters are greater than 0, the loss is weighted using the one_sided_gamma and one_sided_w parameters.
20 | - Finally, the loss is returned as a negative sum.
21 | """
22 |
23 |
24 | """
25 | Test strategies:
26 | - test_init(): tests that the parameters of the class are initialized correctly
27 | - test_forward_positive_logits(): tests that the forward method correctly calculates the loss for positive logits
28 | - test_forward_negative_logits(): tests that the forward method correctly calculates the loss for negative logits
29 | - test_forward_clipped_logits(): tests that the forward method correctly calculates the loss for clipped logits
30 | - test_forward_gamma_pos(): tests that the forward method correctly calculates the loss when gamma_pos is greater than 0
31 | - test_forward_gamma_neg(): tests that the forward method correctly calculates the loss when gamma_neg is greater than 0
32 | - test_forward_eps(): tests that the forward method correctly calculates the loss when eps is greater than 0
33 | - test_forward_disable_torch_grad_focal_loss(): tests that the forward method correctly calculates the loss when disable_torch_grad_focal_loss is set to True
34 | """
35 | class TestAsymmetricLoss(unittest.TestCase):
36 | def setUp(self):
37 | self.loss = AsymmetricLoss()
38 |
39 | def test_init(self):
40 | self.assertEqual(self.loss.gamma_neg, 4)
41 | self.assertEqual(self.loss.gamma_pos, 1)
42 | self.assertEqual(self.loss.clip, 0.05)
43 | self.assertEqual(self.loss.eps, 1e-08)
44 | self.assertTrue(self.loss.disable_torch_grad_focal_loss)
45 |
46 | def test_forward_positive_logits(self):
47 | x = torch.tensor([1., 2., 3.])
48 | y = torch.tensor([1., 0., 1.])
49 | expected_loss = -torch.log(torch.sigmoid(x)).sum()
50 | self.assertEqual(self.loss(x, y), expected_loss)
51 |
52 | def test_forward_negative_logits(self):
53 | x = torch.tensor([-1., -2., -3.])
54 | y = torch.tensor([1., 0., 1.])
55 | expected_loss = -torch.log(1 - torch.sigmoid(x)).sum()
56 | self.assertEqual(self.loss(x, y), expected_loss)
57 |
58 | def test_forward_clipped_logits(self):
59 | x = torch.tensor([-1., -2., -3.])
60 | y = torch.tensor([1., 0., 1.])
61 | expected_loss = -torch.log((1 - torch.sigmoid(x) + self.loss.clip).clamp(max=1)).sum()
62 | self.assertEqual(self.loss(x, y), expected_loss)
63 |
64 | def test_forward_gamma_pos(self):
65 | x = torch.tensor([1., 2., 3.])
66 | y = torch.tensor([1., 0., 1.])
67 | self.loss.gamma_pos = 2
68 | expected_loss = -torch.log(torch.sigmoid(x)) * torch.pow(1 - torch.sigmoid(x), 2).sum()
69 | self.assertEqual(self.loss(x, y), expected_loss)
70 |
71 | def test_forward_gamma_neg(self):
72 | x = torch.tensor([-1., -2., -3.])
73 | y = torch.tensor([1., 0., 1.])
74 | self.loss.gamma_neg = 3
75 | expected_loss = -torch.log((1 - torch.sigmoid(x) + self.loss.clip).clamp(max=1)) * torch.pow(1 - (torch.sigmoid(x) + self.loss.clip).clamp(max=1), 3).sum()
76 | self.assertEqual(self.loss(x, y), expected_loss)
77 |
78 | def test_forward_eps(self):
79 | x = torch.tensor([-1., -2., -3.])
80 | y = torch.tensor([1., 0., 1.])
81 | self.loss.eps = 0.5
82 | expected_loss = -torch.log((1 - torch.sigmoid(x) + self.loss.clip).clamp(min=0.5)).sum()
83 | self.assertEqual(self.loss(x, y), expected_loss)
84 |
85 | def test_forward_disable_torch_grad_focal_loss(self):
86 | x = torch.tensor([-1., -2., -3.])
87 | y = torch.tensor([1., 0., 1.])
88 | self.loss.disable_torch_grad_focal_loss = False
89 | expected_loss = -torch.log((1 - torch.sigmoid(x) + self.loss.clip).clamp(max=1)).sum()
90 | self.assertEqual(self.loss(x, y), expected_loss)
91 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torch.nn.parallel
5 | import torch.optim
6 | import torch.utils.data.distributed
7 | import torchvision.transforms as transforms
8 | from torch.optim import lr_scheduler
9 | from src.helper_functions.helper_functions import mAP, CocoDetection, CutoutPIL, ModelEma, add_weight_decay
10 | from src.models import create_model
11 | from src.loss_functions.losses import AsymmetricLoss
12 | from randaugment import RandAugment
13 | from torch.cuda.amp import GradScaler, autocast
14 |
15 | parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
16 | parser.add_argument('data', metavar='DIR', help='path to dataset', default='/home/MSCOCO_2014/')
17 | parser.add_argument('--lr', default=1e-4, type=float)
18 | parser.add_argument('--model-name', default='tresnet_m')
19 | parser.add_argument('--model-path', default='./tresnet_m.pth', type=str)
20 | parser.add_argument('--num-classes', default=80)
21 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
22 | help='number of data loading workers (default: 16)')
23 | parser.add_argument('--image-size', default=224, type=int,
24 | metavar='N', help='input image size (default: 448)')
25 | parser.add_argument('--thre', default=0.8, type=float,
26 | metavar='N', help='threshold value')
27 | parser.add_argument('-b', '--batch-size', default=128, type=int,
28 | metavar='N', help='mini-batch size (default: 16)')
29 | parser.add_argument('--print-freq', '-p', default=64, type=int,
30 | metavar='N', help='print frequency (default: 64)')
31 |
32 |
33 | def main():
34 | args = parser.parse_args()
35 | args.do_bottleneck_head = False
36 |
37 | # Setup model
38 | print('creating model...')
39 | model = create_model(args).cuda()
40 | if args.model_path: # make sure to load pretrained ImageNet model
41 | state = torch.load(args.model_path, map_location='cpu')
42 | filtered_dict = {k: v for k, v in state['model'].items() if
43 | (k in model.state_dict() and 'head.fc' not in k)}
44 | model.load_state_dict(filtered_dict, strict=False)
45 | print('done\n')
46 |
47 | # COCO Data loading
48 | instances_path_val = os.path.join(args.data, 'annotations/instances_val2014.json')
49 | instances_path_train = os.path.join(args.data, 'annotations/instances_train2014.json')
50 | # data_path_val = args.data
51 | # data_path_train = args.data
52 | data_path_val = f'{args.data}/val2014' # args.data
53 | data_path_train = f'{args.data}/train2014' # args.data
54 | val_dataset = CocoDetection(data_path_val,
55 | instances_path_val,
56 | transforms.Compose([
57 | transforms.Resize((args.image_size, args.image_size)),
58 | transforms.ToTensor(),
59 | # normalize, # no need, toTensor does normalization
60 | ]))
61 | train_dataset = CocoDetection(data_path_train,
62 | instances_path_train,
63 | transforms.Compose([
64 | transforms.Resize((args.image_size, args.image_size)),
65 | CutoutPIL(cutout_factor=0.5),
66 | RandAugment(),
67 | transforms.ToTensor(),
68 | # normalize,
69 | ]))
70 | print("len(val_dataset)): ", len(val_dataset))
71 | print("len(train_dataset)): ", len(train_dataset))
72 |
73 | # Pytorch Data loader
74 | train_loader = torch.utils.data.DataLoader(
75 | train_dataset, batch_size=args.batch_size, shuffle=True,
76 | num_workers=args.workers, pin_memory=True)
77 |
78 | val_loader = torch.utils.data.DataLoader(
79 | val_dataset, batch_size=args.batch_size, shuffle=False,
80 | num_workers=args.workers, pin_memory=False)
81 |
82 | # Actuall Training
83 | train_multi_label_coco(model, train_loader, val_loader, args.lr)
84 |
85 |
86 | def train_multi_label_coco(model, train_loader, val_loader, lr):
87 | ema = ModelEma(model, 0.9997) # 0.9997^641=0.82
88 |
89 | # set optimizer
90 | Epochs = 80
91 | Stop_epoch = 40
92 | weight_decay = 1e-4
93 | criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05, disable_torch_grad_focal_loss=True)
94 | parameters = add_weight_decay(model, weight_decay)
95 | optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=0) # true wd, filter_bias_and_bn
96 | steps_per_epoch = len(train_loader)
97 | scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps_per_epoch, epochs=Epochs,
98 | pct_start=0.2)
99 |
100 | highest_mAP = 0
101 | trainInfoList = []
102 | scaler = GradScaler()
103 | for epoch in range(Epochs):
104 | if epoch > Stop_epoch:
105 | break
106 | for i, (inputData, target) in enumerate(train_loader):
107 | inputData = inputData.cuda()
108 | target = target.cuda() # (batch,3,num_classes)
109 |             target = target.max(dim=1)[0]  # merge the 3 per-size targets from CocoDetection into one multi-hot vector
110 | with autocast(): # mixed precision
111 | output = model(inputData).float() # sigmoid will be done in loss !
112 | loss = criterion(output, target)
113 | model.zero_grad()
114 |
115 | scaler.scale(loss).backward()
116 | # loss.backward()
117 |
118 | scaler.step(optimizer)
119 | scaler.update()
120 | # optimizer.step()
121 |
122 | scheduler.step()
123 |
124 | ema.update(model)
125 | # store information
126 | if i % 100 == 0:
127 | trainInfoList.append([epoch, i, loss.item()])
128 | print('Epoch [{}/{}], Step [{}/{}], LR {:.1e}, Loss: {:.1f}'
129 | .format(epoch, Epochs, str(i).zfill(3), str(steps_per_epoch).zfill(3),
130 | scheduler.get_last_lr()[0], \
131 | loss.item()))
132 |
133 | try:
134 | torch.save(model.state_dict(), os.path.join(
135 | 'models/', 'model-{}-{}.ckpt'.format(epoch + 1, i + 1)))
136 | except:
137 | pass
138 |
139 | model.eval()
140 | mAP_score = validate_multi(val_loader, model, ema)
141 | model.train()
142 | if mAP_score > highest_mAP:
143 | highest_mAP = mAP_score
144 | try:
145 | torch.save(model.state_dict(), os.path.join(
146 | 'models/', 'model-highest.ckpt'))
147 | except Exception:  # best-effort save
148 | pass
149 | print('current_mAP = {:.2f}, highest_mAP = {:.2f}\n'.format(mAP_score, highest_mAP))
150 |
151 |
152 | def validate_multi(val_loader, model, ema_model):
153 | print("starting validation")
154 | Sig = torch.nn.Sigmoid()
155 | preds_regular = []
156 | preds_ema = []
157 | targets = []
158 | for i, (input, target) in enumerate(val_loader):
159 | # max over annotation channels -> multi-hot (batch, num_classes) target
160 | target = target.max(dim=1)[0]
161 | # compute output
162 | with torch.no_grad():
163 | with autocast():
164 | output_regular = Sig(model(input.cuda())).cpu()
165 | output_ema = Sig(ema_model.module(input.cuda())).cpu()
166 |
167 | # for mAP calculation
168 | preds_regular.append(output_regular.cpu().detach())
169 | preds_ema.append(output_ema.cpu().detach())
170 | targets.append(target.cpu().detach())
171 |
172 | mAP_score_regular = mAP(torch.cat(targets).numpy(), torch.cat(preds_regular).numpy())
173 | mAP_score_ema = mAP(torch.cat(targets).numpy(), torch.cat(preds_ema).numpy())
174 | print("mAP score regular {:.2f}, mAP score EMA {:.2f}".format(mAP_score_regular, mAP_score_ema))
175 | return max(mAP_score_regular, mAP_score_ema)
176 |
177 |
178 | if __name__ == '__main__':
179 | main()
180 |
--------------------------------------------------------------------------------
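For readers skimming train.py above: the criterion it builds is AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05). Below is a minimal, self-contained sketch of that loss as described in the ASL paper — asymmetric focusing plus probability shifting (the clip margin) on the negative branch. It is an illustration only, not the repository's implementation (see src/loss_functions/losses.py, which also handles details such as disable_torch_grad_focal_loss).

```python
import torch

def asymmetric_loss_sketch(logits, targets, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-8):
    """Illustrative ASL: focus down easy negatives and shift their probability by `clip`."""
    p = torch.sigmoid(logits)                        # per-label probability
    p_neg_shifted = (1.0 - p + clip).clamp(max=1.0)  # shifted probability of the negative outcome
    loss_pos = targets * (1.0 - p).pow(gamma_pos) * torch.log(p.clamp(min=eps))
    loss_neg = (1.0 - targets) * (1.0 - p_neg_shifted).pow(gamma_neg) * torch.log(p_neg_shifted.clamp(min=eps))
    return -(loss_pos + loss_neg).sum()

# toy check on random logits and sparse multi-hot targets
logits = torch.randn(8, 80)
targets = (torch.rand(8, 80) > 0.9).float()
print(asymmetric_loss_sketch(logits, targets))
```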
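The per-batch update inside train_multi_label_coco combines mixed precision (autocast + GradScaler), a OneCycleLR schedule stepped every batch, and an EMA copy of the weights. A condensed sketch of that step, with assumed argument names, purely for orientation:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # created once, before the epoch loop

def train_step(model, ema, criterion, optimizer, scheduler, images, targets):
    with autocast():                   # forward pass in mixed precision
        loss = criterion(model(images).float(), targets)
    model.zero_grad()
    scaler.scale(loss).backward()      # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)             # unscales gradients, then runs optimizer.step()
    scaler.update()                    # adapt the loss scale for the next iteration
    scheduler.step()                   # OneCycleLR advances once per batch, not per epoch
    ema.update(model)                  # keep an exponential moving average of the weights
    return loss.item()
```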
/validate.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/allenai/elastic/blob/master/multilabel_classify.py
2 | # special thanks to @hellbell
3 |
4 | import argparse
5 | import time
6 | import torch
7 | import torch.nn.parallel
8 | import torch.optim
9 | import torch.utils.data.distributed
10 | import torchvision.transforms as transforms
11 | import os
12 |
13 | from src.helper_functions.helper_functions import mAP, AverageMeter, CocoDetection
14 | from src.models import create_model
15 | import numpy as np
16 |
17 | parser = argparse.ArgumentParser(description='ASL MS-COCO multi-label validation')
18 | parser.add_argument('data', metavar='DIR', help='path to dataset')
19 | parser.add_argument('--model-name', default='tresnet_l')
20 | parser.add_argument('--model-path', default='./TRresNet_L_448_86.6.pth', type=str)
21 | parser.add_argument('--num-classes', default=80)
22 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
23 | help='number of data loading workers (default: 8)')
24 | parser.add_argument('--image-size', default=448, type=int,
25 | metavar='N', help='input image size (default: 448)')
26 | parser.add_argument('--thre', default=0.8, type=float,
27 | metavar='N', help='threshold value')
28 | parser.add_argument('-b', '--batch-size', default=32, type=int,
29 | metavar='N', help='mini-batch size (default: 32)')
30 | parser.add_argument('--print-freq', '-p', default=64, type=int,
31 | metavar='N', help='print frequency (default: 64)')
32 |
33 |
34 | def main():
35 | args = parser.parse_args()
36 |
37 |
38 | # setup model
39 | print('creating and loading the model...')
40 | state = torch.load(args.model_path, map_location='cpu')
41 | args.num_classes = state['num_classes']
42 | args.do_bottleneck_head = False
43 | model = create_model(args).cuda()
44 | model.load_state_dict(state['model'], strict=True)
45 | model.eval()
46 | classes_list = np.array(list(state['idx_to_class'].values()))
47 | print('done\n')
48 |
49 | # Data loading code
50 | normalize = transforms.Normalize(mean=[0, 0, 0],
51 | std=[1, 1, 1])  # identity normalization; inputs stay in [0, 1] as produced by ToTensor
52 |
53 | instances_path = os.path.join(args.data, 'annotations/instances_val2014.json')
54 | data_path = os.path.join(args.data, 'val2014')
55 | val_dataset = CocoDetection(data_path,
56 | instances_path,
57 | transforms.Compose([
58 | transforms.Resize((args.image_size, args.image_size)),
59 | transforms.ToTensor(),
60 | normalize,
61 | ]))
62 |
63 | print("len(val_dataset)): ", len(val_dataset))
64 | val_loader = torch.utils.data.DataLoader(
65 | val_dataset, batch_size=args.batch_size, shuffle=False,
66 | num_workers=args.workers, pin_memory=True)
67 |
68 | validate_multi(val_loader, model, args)
69 |
70 |
71 | def validate_multi(val_loader, model, args):
72 | print("starting actuall validation")
73 | batch_time = AverageMeter()
74 | prec = AverageMeter()
75 | rec = AverageMeter()
76 | mAP_meter = AverageMeter()
77 |
78 | Sig = torch.nn.Sigmoid()
79 |
80 | end = time.time()
81 | tp, fp, fn, tn, count = 0, 0, 0, 0, 0
82 | preds = []
83 | targets = []
84 | for i, (input, target) in enumerate(val_loader):
85 | # max over annotation channels -> multi-hot (batch, num_classes) target
86 | target = target.max(dim=1)[0]
87 | # compute output
88 | with torch.no_grad():
89 | output = Sig(model(input.cuda())).cpu()
90 |
91 | # for mAP calculation
92 | preds.append(output.cpu())
93 | targets.append(target.cpu())
94 |
95 | # measure accuracy and record loss
96 | pred = output.data.gt(args.thre).long()
97 |
98 | tp += (pred + target).eq(2).sum(dim=0)   # pred 1, target 1
99 | fp += (pred - target).eq(1).sum(dim=0)   # pred 1, target 0
100 | fn += (pred - target).eq(-1).sum(dim=0)  # pred 0, target 1
101 | tn += (pred + target).eq(0).sum(dim=0)   # pred 0, target 0
102 | count += input.size(0)
103 |
104 | this_tp = (pred + target).eq(2).sum()
105 | this_fp = (pred - target).eq(1).sum()
106 | this_fn = (pred - target).eq(-1).sum()
107 | this_tn = (pred + target).eq(0).sum()
108 |
109 | this_prec = this_tp.float() / (
110 | this_tp + this_fp).float() * 100.0 if this_tp + this_fp != 0 else 0.0
111 | this_rec = this_tp.float() / (
112 | this_tp + this_fn).float() * 100.0 if this_tp + this_fn != 0 else 0.0
113 |
114 | prec.update(float(this_prec), input.size(0))
115 | rec.update(float(this_rec), input.size(0))
116 |
117 | # measure elapsed time
118 | batch_time.update(time.time() - end)
119 | end = time.time()
120 |
121 | p_c = [float(tp[i].float() / (tp[i] + fp[i]).float()) * 100.0
122 | if tp[i] > 0 else 0.0
123 | for i in range(len(tp))]
124 | r_c = [float(tp[i].float() / (tp[i] + fn[i]).float()) * 100.0
125 | if tp[i] > 0 else 0.0
126 | for i in range(len(tp))]
127 | f_c = [2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0
128 | for i in range(len(tp))]
129 |
130 | mean_p_c = sum(p_c) / len(p_c)
131 | mean_r_c = sum(r_c) / len(r_c)
132 | mean_f_c = sum(f_c) / len(f_c)
133 |
134 | p_o = tp.sum().float() / (tp + fp).sum().float() * 100.0
135 | r_o = tp.sum().float() / (tp + fn).sum().float() * 100.0
136 | f_o = 2 * p_o * r_o / (p_o + r_o)
137 |
138 | if i % args.print_freq == 0:
139 | print('Test: [{0}/{1}]\t'
140 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
141 | 'Precision {prec.val:.2f} ({prec.avg:.2f})\t'
142 | 'Recall {rec.val:.2f} ({rec.avg:.2f})'.format(
143 | i, len(val_loader), batch_time=batch_time,
144 | prec=prec, rec=rec))
145 | print(
146 | 'P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}'
147 | .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
148 |
149 | print(
150 | '--------------------------------------------------------------------')
151 | print(' * P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}'
152 | .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
153 |
154 | mAP_score = mAP(torch.cat(targets).numpy(), torch.cat(preds).numpy())
155 | print("mAP score:", mAP_score)
156 |
157 | return
158 |
159 |
160 | if __name__ == '__main__':
161 | main()
162 |
--------------------------------------------------------------------------------
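A note on the numbers printed by validate_multi in validate.py: P_C / R_C / F_C are per-class (macro) precision / recall / F1 — per-class ratios averaged over classes, with classes that have no true positives counted as 0 — while P_O / R_O / F_O are overall (micro) values computed after pooling the tp / fp / fn counts over all classes. A small NumPy illustration with made-up counts for three hypothetical classes:

```python
import numpy as np

# hypothetical per-class counts, in the same roles as tp / fp / fn in validate_multi
tp = np.array([30., 5., 0.])
fp = np.array([10., 1., 2.])
fn = np.array([5., 4., 3.])

# per-class ("C"): precision/recall per class, then averaged over classes
p_c = np.where(tp > 0, tp / np.maximum(tp + fp, 1) * 100.0, 0.0)
r_c = np.where(tp > 0, tp / np.maximum(tp + fn, 1) * 100.0, 0.0)
f_c = np.where(tp > 0, 2 * p_c * r_c / np.maximum(p_c + r_c, 1e-8), 0.0)
print("P_C %.2f  R_C %.2f  F_C %.2f" % (p_c.mean(), r_c.mean(), f_c.mean()))

# overall ("O"): pool the counts over all classes, then compute single ratios
p_o = tp.sum() / (tp + fp).sum() * 100.0
r_o = tp.sum() / (tp + fn).sum() * 100.0
f_o = 2 * p_o * r_o / (p_o + r_o)
print("P_O %.2f  R_O %.2f  F_O %.2f" % (p_o, r_o, f_o))
```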