├── Datasets.md ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── infer.py ├── pics ├── 000000000885.jpg ├── backbones.png ├── detection.png ├── example_inference.jpeg ├── example_inference_open_images.jpeg ├── herbarium.png ├── loss_graph.png ├── main_pic.png ├── ms_coco_scores.png └── open_images.png ├── requirements.txt ├── src ├── helper_functions │ └── helper_functions.py ├── loss_functions │ └── losses.py └── models │ ├── __init__.py │ ├── tresnet │ ├── __init__.py │ ├── layers │ │ ├── anti_aliasing.py │ │ ├── avg_pool.py │ │ └── general_layers.py │ └── tresnet.py │ └── utils │ ├── __init__.py │ └── factory.py ├── tests └── test_asl.py ├── train.py └── validate.py /Datasets.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | - The MS-COCO dataset can be downloaded from the [official site](https://cocodataset.org/#download). As commonly done, we use the 2014 Train/Val annotations. 4 | 5 | - The Open-Images V6 dataset, which was used for the paper, is now directly available for download from [here](https://github.com/Alibaba-MIIL/PartialLabelingCSL/blob/main/OpenImages.md). 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Alibaba-MIIL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # ASL pre-trained models 2 | 3 | We provide a collection of models trained with ASL on various multi-label datasets. 4 | 5 | 6 | | Backbone | Input Size | Dataset | mAP | 7 | | ------------ | :--------------: | :--------------: | :--------------: | 8 | | [TResNet_M (ImageNet21K pretrain)](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/MS_COCO_TRresNet_M_224_81.8.pth) | 224 | MS-COCO | 81.8 | 9 | | [TResNet_L](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/MS_COCO_TRresNet_L_448_86.6.pth) | 448 | MS-COCO | 86.6 | 10 | | [TResNet_XL](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/MS_COCO_TResNet_xl_640_88.4.pth) | 640 | MS-COCO | 88.4 | 11 | | [TResNet_L](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/Open_ImagesV6_TRresNet_L_448.pth) | 448 | OpenImagesV6 | 86.3 | 12 | | [TResNet_XL](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/PASCAL_VOC_TResNet_xl_448_96.0.pth) | 448 | PASCAL-VOC | 96.0 | 13 | | [TResNet_L](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/NUS_WIDE_TRresNet_L_448_65.2.pth) | 448 | NUS-WIDE [dataset](https://drive.google.com/file/d/0B7IzDz-4yH_HMFdiSE44R1lselE/view) [split](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ASL/nus_wid_data.csv) | 65.2 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Asymmetric Loss For Multi-Label Classification 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/imagenet-21k-pretraining-for-the-masses/multi-label-classification-on-ms-coco)](https://paperswithcode.com/sota/multi-label-classification-on-ms-coco?p=imagenet-21k-pretraining-for-the-masses)<br>
4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/asymmetric-loss-for-multi-label/multi-label-classification-on-nus-wide)](https://paperswithcode.com/sota/multi-label-classification-on-nus-wide?p=asymmetric-loss-for-multi-label)
5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/imagenet-21k-pretraining-for-the-masses/multi-label-classification-on-pascal-voc-2007)](https://paperswithcode.com/sota/multi-label-classification-on-pascal-voc-2007?p=imagenet-21k-pretraining-for-the-masses)
6 |
[Paper](https://arxiv.org/abs/2009.14119) | 7 | [Pretrained models](MODEL_ZOO.md) | 8 | [Datasets](Datasets.md) 9 | 10 | Official PyTorch Implementation 11 | 12 | > Emanuel Ben-Baruch, Tal Ridnik, Nadav Zamir, Asaf Noy, Itamar 13 | > Friedman, Matan Protter, Lihi Zelnik-Manor
DAMO Academy, Alibaba 14 | > Group 15 | 16 | **Abstract** 17 | 18 | In a typical multi-label setting, a picture contains, on average, only a few positive labels and many negative ones. This positive-negative imbalance dominates the optimization process, and can lead to under-emphasizing gradients from positive labels during training, resulting in poor accuracy. In this paper, we introduce a novel asymmetric loss ("ASL"), which operates differently on positive and negative samples. The loss dynamically down-weights and hard-thresholds easy negative samples, while also discarding possibly mislabeled samples. We demonstrate how ASL can balance the probabilities of different samples, and how this balancing translates into better mAP scores. With ASL, we reach state-of-the-art results on multiple popular multi-label datasets: MS-COCO, Pascal-VOC, NUS-WIDE and Open Images. We also demonstrate ASL's applicability to other tasks, such as single-label classification and object detection. ASL is effective, easy to implement, and does not increase the training time or complexity. 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | ## 9/1/2023 Update 30 | Added [tests](https://github.com/Alibaba-MIIL/ASL/blob/main/tests/test_asl.py) auto-generated by the [CodiumAI](https://www.codium.ai/) tool 31 | 32 | ## 29/11/2021 Update - New article released, offering a new classification head with state-of-the-art results 33 | Check out our new project, [Ml-Decoder](https://github.com/Alibaba-MIIL/ML_Decoder), which presents a unified classification head for multi-label, single-label and 34 | zero-shot tasks. Backbones with ML-Decoder reach SOTA results, while also improving the speed-accuracy tradeoff. 35 | 36 |

37 | 38 | 39 | 40 | 41 | 42 | 43 |
44 |

45 | 46 | ## 24/7/2021 Update - ASL article was accepted to ICCV 2021 47 | A final version of the paper, with updated results for ImageNet-21K pretraining, is released on arXiv. 48 | Note that ASL is becoming the de-facto 'default' loss for high-performance multi-label classification, and all the top results in papers-with-code are currently using it. 49 | 50 | 51 | 52 | 53 | 54 | ## Training Code Now Available! 55 | 56 | Thanks to a great collaboration with [@GhostWnd](https://github.com/GhostWnd), we 57 | now provide a [script](train.py) for fully reproducing the article 58 | results, and finally modern multi-label training code is 59 | available for the community. 60 | ## Frequently Asked Questions 61 | Some questions are repeatedly asked in the issues section. Make sure to 62 | review them before starting a new issue: 63 | - Regarding combining ASL with other techniques, see 64 | [link](https://github.com/Alibaba-MIIL/ASL/issues/35) 65 | - Regarding the implementation of asymmetric clipping, see [link](https://github.com/Alibaba-MIIL/ASL/issues/10) 66 | - Regarding the disable_torch_grad_focal_loss option, see 67 | [link](https://github.com/Alibaba-MIIL/ASL/issues/31) 68 | - Regarding squish vs. crop resizing, see 69 | [link](https://github.com/Alibaba-MIIL/ASL/issues/30#issuecomment-754005570) 70 | - Regarding training tricks, see 71 | [link](https://github.com/Alibaba-MIIL/ASL/issues/30#issuecomment-750780576) 72 | - Regarding how to apply ASL to your own dataset, see 73 | [link](https://github.com/Alibaba-MIIL/ASL/issues/22#issuecomment-736721770) 74 | 75 | 76 | 77 | ## Asymmetric Loss (ASL) Implementation 78 | In this PyTorch [file](src/loss_functions/losses.py), we provide 79 | implementations of our new loss function, ASL, that can serve as a 80 | drop-in replacement for standard loss functions (Cross-Entropy and 81 | Focal-Loss). 82 | 83 | For the multi-label case (sigmoids), the two implementations are: 84 | - ```class AsymmetricLoss(nn.Module)``` 85 | - ```class AsymmetricLossOptimized(nn.Module)```
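In short, the per-label terms implemented in that file are: for a positive label, (1 - p)^gamma_pos * log(p); for a negative label, p_m^gamma_neg * log(1 - p_m), where p is the sigmoid of the logit and p_m = max(p - clip, 0) is the shifted ("hard-thresholded") probability; the returned loss is the negative sum of these terms. The following minimal usage sketch is not part of the repository; it assumes the repository root is on PYTHONPATH and mirrors the hyper-parameters used in train.py, and simply shows ASL used as a drop-in replacement for a BCE-with-logits criterion:

```python
# Illustrative sketch only: hyper-parameters follow train.py, shapes follow MS-COCO (80 classes).
import torch
from src.loss_functions.losses import AsymmetricLoss

criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05,
                           disable_torch_grad_focal_loss=True)

logits = torch.randn(8, 80, requires_grad=True)   # raw model outputs; sigmoid is applied inside the loss
targets = torch.randint(0, 2, (8, 80)).float()    # multi-hot (binarized) label vectors

loss = criterion(logits, targets)                 # scalar: negative sum of the per-element ASL terms
loss.backward()
```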
86 | 87 | The two losses are bit-accurate. However, AsymmetricLossOptimized() 88 | contains a more optimized (and more complicated) way of implementing ASL, 89 | which minimizes memory allocations and GPU uploads, and favors inplace 90 | operations. 91 | 92 | For the single-label case (softmax), the implementation is called: 93 | - ```class ASLSingleLabel(nn.Module)``` 94 | 95 | 96 | 97 | ## Pretrained Models 98 | In this [link](MODEL_ZOO.md), we provide pre-trained models on various 99 | datasets. 100 | 101 | ## Validation Code 102 | Thanks to an external contribution from @hellbell, we now provide 103 | validation code that reproduces the article results on MS-COCO: 104 | 105 | ``` 106 | python validate.py \ 107 | --model_name=tresnet_l \ 108 | --model_path=./models_local/MS_COCO_TRresNet_L_448_86.6.pth 109 | ``` 110 | 111 | ## Inference Code 112 | We provide [inference code](infer.py) that demonstrates how to load our 113 | model, pre-process an image and run actual inference. Example run of 114 | the MS-COCO model (after downloading the relevant model): 115 | ``` 116 | python infer.py \ 117 | --dataset_type=MS-COCO \ 118 | --model_name=tresnet_l \ 119 | --model_path=./models_local/MS_COCO_TRresNet_L_448_86.6.pth \ 120 | --pic_path=./pics/000000000885.jpg \ 121 | --input_size=448 122 | ``` 123 | which will result in: 124 |

125 | 126 | 127 | 128 | 129 |
130 |

131 | 132 | Example run of OpenImages model: 133 | ``` 134 | python infer.py \ 135 | --dataset_type=OpenImages \ 136 | --model_name=tresnet_l \ 137 | --model_path=./models_local/Open_ImagesV6_TRresNet_L_448.pth \ 138 | --pic_path=./pics/000000000885.jpg \ 139 | --input_size=448 140 | ``` 141 |

142 | 143 | 144 | 145 | 146 |
147 |

148 | 149 | 150 | ## Citation 151 | ``` 152 | @misc{benbaruch2020asymmetric, 153 | title={Asymmetric Loss For Multi-Label Classification}, 154 | author={Emanuel Ben-Baruch and Tal Ridnik and Nadav Zamir and Asaf Noy and Itamar Friedman and Matan Protter and Lihi Zelnik-Manor}, 155 | year={2020}, 156 | eprint={2009.14119}, 157 | archivePrefix={arXiv}, 158 | primaryClass={cs.CV} } 159 | ``` 160 | 161 | ## Contact 162 | Feel free to contact if there are any questions or issues - Emanuel 163 | Ben-Baruch (emanuel.benbaruch@alibaba-inc.com) or Tal Ridnik (tal.ridnik@alibaba-inc.com). 164 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.helper_functions.helper_functions import parse_args 3 | from src.loss_functions.losses import AsymmetricLoss, AsymmetricLossOptimized 4 | from src.models import create_model 5 | import argparse 6 | import matplotlib 7 | 8 | matplotlib.use('TkAgg') 9 | import matplotlib.pyplot as plt 10 | from PIL import Image 11 | import numpy as np 12 | 13 | parser = argparse.ArgumentParser(description='ASL MS-COCO Inference on a single image') 14 | 15 | parser.add_argument('--model_path', type=str, default='./models_local/TRresNet_L_448_86.6.pth') 16 | parser.add_argument('--pic_path', type=str, default='./pics/000000000885.jpg') 17 | parser.add_argument('--model_name', type=str, default='tresnet_l') 18 | parser.add_argument('--input_size', type=int, default=448) 19 | parser.add_argument('--dataset_type', type=str, default='MS-COCO') 20 | parser.add_argument('--th', type=float, default=None) 21 | 22 | 23 | def main(): 24 | print('ASL Example Inference code on a single image') 25 | 26 | # parsing args 27 | args = parse_args(parser) 28 | 29 | # setup model 30 | print('creating and loading the model...') 31 | state = torch.load(args.model_path, map_location='cpu') 32 | args.num_classes = state['num_classes'] 33 | model = create_model(args).cuda() 34 | model.load_state_dict(state['model'], strict=True) 35 | model.eval() 36 | classes_list = np.array(list(state['idx_to_class'].values())) 37 | print('done\n') 38 | 39 | # doing inference 40 | print('loading image and doing inference...') 41 | im = Image.open(args.pic_path) 42 | im_resize = im.resize((args.input_size, args.input_size)) 43 | np_img = np.array(im_resize, dtype=np.uint8) 44 | tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float() / 255.0 # HWC to CHW 45 | tensor_batch = torch.unsqueeze(tensor_img, 0).cuda() 46 | output = torch.squeeze(torch.sigmoid(model(tensor_batch))) 47 | np_output = output.cpu().detach().numpy() 48 | detected_classes = classes_list[np_output > args.th] 49 | print('done\n') 50 | 51 | # example loss calculation 52 | output = model(tensor_batch) 53 | loss_func1 = AsymmetricLoss() 54 | loss_func2 = AsymmetricLossOptimized() 55 | target = output.clone() 56 | target[output < 0] = 0 # mockup target 57 | target[output >= 0] = 1 58 | loss1 = loss_func1(output, target) 59 | loss2 = loss_func2(output, target) 60 | assert abs((loss1.item() - loss2.item())) < 1e-6 61 | 62 | # displaying image 63 | print('showing image on screen...') 64 | fig = plt.figure() 65 | plt.imshow(im) 66 | plt.axis('off') 67 | plt.axis('tight') 68 | # plt.rcParams["axes.titlesize"] = 10 69 | plt.title("detected classes: {}".format(detected_classes)) 70 | 71 | plt.show() 72 | print('done\n') 73 | 74 | 75 | if __name__ == '__main__': 76 | main() 77 | 
-------------------------------------------------------------------------------- /pics/000000000885.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/000000000885.jpg -------------------------------------------------------------------------------- /pics/backbones.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/backbones.png -------------------------------------------------------------------------------- /pics/detection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/detection.png -------------------------------------------------------------------------------- /pics/example_inference.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/example_inference.jpeg -------------------------------------------------------------------------------- /pics/example_inference_open_images.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/example_inference_open_images.jpeg -------------------------------------------------------------------------------- /pics/herbarium.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/herbarium.png -------------------------------------------------------------------------------- /pics/loss_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/loss_graph.png -------------------------------------------------------------------------------- /pics/main_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/main_pic.png -------------------------------------------------------------------------------- /pics/ms_coco_scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/ms_coco_scores.png -------------------------------------------------------------------------------- /pics/open_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/ASL/8c9e0bd8d5d450cf19093363fc08aa7244ad4408/pics/open_images.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | torchvision>=0.5.0 3 | git+https://github.com/mapillary/inplace_abn.git@v1.0.12 4 | randaugment 5 | pycocotools -------------------------------------------------------------------------------- /src/helper_functions/helper_functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
from copy import deepcopy 3 | import random 4 | import time 5 | from copy import deepcopy 6 | 7 | import numpy as np 8 | from PIL import Image 9 | from torchvision import datasets as datasets 10 | import torch 11 | from PIL import ImageDraw 12 | from pycocotools.coco import COCO 13 | 14 | 15 | def parse_args(parser): 16 | # parsing args 17 | args = parser.parse_args() 18 | if args.dataset_type == 'OpenImages': 19 | args.do_bottleneck_head = True 20 | if args.th == None: 21 | args.th = 0.995 22 | else: 23 | args.do_bottleneck_head = False 24 | if args.th == None: 25 | args.th = 0.7 26 | return args 27 | 28 | 29 | def average_precision(output, target): 30 | epsilon = 1e-8 31 | 32 | # sort examples 33 | indices = output.argsort()[::-1] 34 | # Computes prec@i 35 | total_count_ = np.cumsum(np.ones((len(output), 1))) 36 | 37 | target_ = target[indices] 38 | ind = target_ == 1 39 | pos_count_ = np.cumsum(ind) 40 | total = pos_count_[-1] 41 | pos_count_[np.logical_not(ind)] = 0 42 | pp = pos_count_ / total_count_ 43 | precision_at_i_ = np.sum(pp) 44 | precision_at_i = precision_at_i_ / (total + epsilon) 45 | 46 | return precision_at_i 47 | 48 | 49 | def mAP(targs, preds): 50 | """Returns the model's average precision for each class 51 | Return: 52 | ap (FloatTensor): 1xK tensor, with avg precision for each class k 53 | """ 54 | 55 | if np.size(preds) == 0: 56 | return 0 57 | ap = np.zeros((preds.shape[1])) 58 | # compute average precision for each class 59 | for k in range(preds.shape[1]): 60 | # sort scores 61 | scores = preds[:, k] 62 | targets = targs[:, k] 63 | # compute average precision 64 | ap[k] = average_precision(scores, targets) 65 | return 100 * ap.mean() 66 | 67 | 68 | class AverageMeter(object): 69 | def __init__(self): 70 | self.val = None 71 | self.sum = None 72 | self.cnt = None 73 | self.avg = None 74 | self.ema = None 75 | self.initialized = False 76 | 77 | def update(self, val, n=1): 78 | if not self.initialized: 79 | self.initialize(val, n) 80 | else: 81 | self.add(val, n) 82 | 83 | def initialize(self, val, n): 84 | self.val = val 85 | self.sum = val * n 86 | self.cnt = n 87 | self.avg = val 88 | self.ema = val 89 | self.initialized = True 90 | 91 | def add(self, val, n): 92 | self.val = val 93 | self.sum += val * n 94 | self.cnt += n 95 | self.avg = self.sum / self.cnt 96 | self.ema = self.ema * 0.99 + self.val * 0.01 97 | 98 | 99 | class CocoDetection(datasets.coco.CocoDetection): 100 | def __init__(self, root, annFile, transform=None, target_transform=None): 101 | self.root = root 102 | self.coco = COCO(annFile) 103 | 104 | self.ids = list(self.coco.imgToAnns.keys()) 105 | self.transform = transform 106 | self.target_transform = target_transform 107 | self.cat2cat = dict() 108 | for cat in self.coco.cats.keys(): 109 | self.cat2cat[cat] = len(self.cat2cat) 110 | # print(self.cat2cat) 111 | 112 | def __getitem__(self, index): 113 | coco = self.coco 114 | img_id = self.ids[index] 115 | ann_ids = coco.getAnnIds(imgIds=img_id) 116 | target = coco.loadAnns(ann_ids) 117 | 118 | output = torch.zeros((3, 80), dtype=torch.long) 119 | for obj in target: 120 | if obj['area'] < 32 * 32: 121 | output[0][self.cat2cat[obj['category_id']]] = 1 122 | elif obj['area'] < 96 * 96: 123 | output[1][self.cat2cat[obj['category_id']]] = 1 124 | else: 125 | output[2][self.cat2cat[obj['category_id']]] = 1 126 | target = output 127 | 128 | path = coco.loadImgs(img_id)[0]['file_name'] 129 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 130 | if self.transform is not None: 131 | img = 
self.transform(img) 132 | 133 | if self.target_transform is not None: 134 | target = self.target_transform(target) 135 | return img, target 136 | 137 | 138 | class ModelEma(torch.nn.Module): 139 | def __init__(self, model, decay=0.9997, device=None): 140 | super(ModelEma, self).__init__() 141 | # make a copy of the model for accumulating moving average of weights 142 | self.module = deepcopy(model) 143 | self.module.eval() 144 | self.decay = decay 145 | self.device = device # perform ema on different device from model if set 146 | if self.device is not None: 147 | self.module.to(device=device) 148 | 149 | def _update(self, model, update_fn): 150 | with torch.no_grad(): 151 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): 152 | if self.device is not None: 153 | model_v = model_v.to(device=self.device) 154 | ema_v.copy_(update_fn(ema_v, model_v)) 155 | 156 | def update(self, model): 157 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) 158 | 159 | def set(self, model): 160 | self._update(model, update_fn=lambda e, m: m) 161 | 162 | 163 | class CutoutPIL(object): 164 | def __init__(self, cutout_factor=0.5): 165 | self.cutout_factor = cutout_factor 166 | 167 | def __call__(self, x): 168 | img_draw = ImageDraw.Draw(x) 169 | h, w = x.size[0], x.size[1] # HWC 170 | h_cutout = int(self.cutout_factor * h + 0.5) 171 | w_cutout = int(self.cutout_factor * w + 0.5) 172 | y_c = np.random.randint(h) 173 | x_c = np.random.randint(w) 174 | 175 | y1 = np.clip(y_c - h_cutout // 2, 0, h) 176 | y2 = np.clip(y_c + h_cutout // 2, 0, h) 177 | x1 = np.clip(x_c - w_cutout // 2, 0, w) 178 | x2 = np.clip(x_c + w_cutout // 2, 0, w) 179 | fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 180 | img_draw.rectangle([x1, y1, x2, y2], fill=fill_color) 181 | 182 | return x 183 | 184 | 185 | def add_weight_decay(model, weight_decay=1e-4, skip_list=()): 186 | decay = [] 187 | no_decay = [] 188 | for name, param in model.named_parameters(): 189 | if not param.requires_grad: 190 | continue # frozen weights 191 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 192 | no_decay.append(param) 193 | else: 194 | decay.append(param) 195 | return [ 196 | {'params': no_decay, 'weight_decay': 0.}, 197 | {'params': decay, 'weight_decay': weight_decay}] 198 | -------------------------------------------------------------------------------- /src/loss_functions/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AsymmetricLoss(nn.Module): 6 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True): 7 | super(AsymmetricLoss, self).__init__() 8 | 9 | self.gamma_neg = gamma_neg 10 | self.gamma_pos = gamma_pos 11 | self.clip = clip 12 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 13 | self.eps = eps 14 | 15 | def forward(self, x, y): 16 | """" 17 | Parameters 18 | ---------- 19 | x: input logits 20 | y: targets (multi-label binarized vector) 21 | """ 22 | 23 | # Calculating Probabilities 24 | x_sigmoid = torch.sigmoid(x) 25 | xs_pos = x_sigmoid 26 | xs_neg = 1 - x_sigmoid 27 | 28 | # Asymmetric Clipping 29 | if self.clip is not None and self.clip > 0: 30 | xs_neg = (xs_neg + self.clip).clamp(max=1) 31 | 32 | # Basic CE calculation 33 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 34 | los_neg = (1 - y) * 
torch.log(xs_neg.clamp(min=self.eps)) 35 | loss = los_pos + los_neg 36 | 37 | # Asymmetric Focusing 38 | if self.gamma_neg > 0 or self.gamma_pos > 0: 39 | if self.disable_torch_grad_focal_loss: 40 | torch.set_grad_enabled(False) 41 | pt0 = xs_pos * y 42 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 43 | pt = pt0 + pt1 44 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 45 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 46 | if self.disable_torch_grad_focal_loss: 47 | torch.set_grad_enabled(True) 48 | loss *= one_sided_w 49 | 50 | return -loss.sum() 51 | 52 | 53 | class AsymmetricLossOptimized(nn.Module): 54 | ''' Notice - optimized version, minimizes memory allocation and gpu uploading, 55 | favors inplace operations''' 56 | 57 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): 58 | super(AsymmetricLossOptimized, self).__init__() 59 | 60 | self.gamma_neg = gamma_neg 61 | self.gamma_pos = gamma_pos 62 | self.clip = clip 63 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 64 | self.eps = eps 65 | 66 | # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations 67 | self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None 68 | 69 | def forward(self, x, y): 70 | """" 71 | Parameters 72 | ---------- 73 | x: input logits 74 | y: targets (multi-label binarized vector) 75 | """ 76 | 77 | self.targets = y 78 | self.anti_targets = 1 - y 79 | 80 | # Calculating Probabilities 81 | self.xs_pos = torch.sigmoid(x) 82 | self.xs_neg = 1.0 - self.xs_pos 83 | 84 | # Asymmetric Clipping 85 | if self.clip is not None and self.clip > 0: 86 | self.xs_neg.add_(self.clip).clamp_(max=1) 87 | 88 | # Basic CE calculation 89 | self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps)) 90 | self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps))) 91 | 92 | # Asymmetric Focusing 93 | if self.gamma_neg > 0 or self.gamma_pos > 0: 94 | if self.disable_torch_grad_focal_loss: 95 | torch.set_grad_enabled(False) 96 | self.xs_pos = self.xs_pos * self.targets 97 | self.xs_neg = self.xs_neg * self.anti_targets 98 | self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg, 99 | self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets) 100 | if self.disable_torch_grad_focal_loss: 101 | torch.set_grad_enabled(True) 102 | self.loss *= self.asymmetric_w 103 | 104 | return -self.loss.sum() 105 | 106 | 107 | class ASLSingleLabel(nn.Module): 108 | ''' 109 | This loss is intended for single-label classification problems 110 | ''' 111 | def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean'): 112 | super(ASLSingleLabel, self).__init__() 113 | 114 | self.eps = eps 115 | self.logsoftmax = nn.LogSoftmax(dim=-1) 116 | self.targets_classes = [] 117 | self.gamma_pos = gamma_pos 118 | self.gamma_neg = gamma_neg 119 | self.reduction = reduction 120 | 121 | def forward(self, inputs, target): 122 | ''' 123 | "input" dimensions: - (batch_size,number_classes) 124 | "target" dimensions: - (batch_size) 125 | ''' 126 | num_classes = inputs.size()[-1] 127 | log_preds = self.logsoftmax(inputs) 128 | self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) 129 | 130 | # ASL weights 131 | targets = self.targets_classes 132 | anti_targets = 1 - targets 133 | xs_pos = torch.exp(log_preds) 134 | xs_neg = 1 - xs_pos 135 | xs_pos = xs_pos * targets 136 | xs_neg = xs_neg * 
anti_targets 137 | asymmetric_w = torch.pow(1 - xs_pos - xs_neg, 138 | self.gamma_pos * targets + self.gamma_neg * anti_targets) 139 | log_preds = log_preds * asymmetric_w 140 | 141 | if self.eps > 0: # label smoothing 142 | self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes) 143 | 144 | # loss calculation 145 | loss = - self.targets_classes.mul(log_preds) 146 | 147 | loss = loss.sum(dim=-1) 148 | if self.reduction == 'mean': 149 | loss = loss.mean() 150 | 151 | return loss 152 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import create_model 2 | 3 | __all__ = ['create_model'] 4 | -------------------------------------------------------------------------------- /src/models/tresnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .tresnet import TResnetM, TResnetL, TResnetXL 2 | -------------------------------------------------------------------------------- /src/models/tresnet/layers/anti_aliasing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.parallel 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class AntiAliasDownsampleLayer(nn.Module): 9 | def __init__(self, remove_model_jit: bool = False, filt_size: int = 3, stride: int = 2, 10 | channels: int = 0): 11 | super(AntiAliasDownsampleLayer, self).__init__() 12 | if not remove_model_jit: 13 | self.op = DownsampleJIT(filt_size, stride, channels) 14 | else: 15 | self.op = Downsample(filt_size, stride, channels) 16 | 17 | def forward(self, x): 18 | return self.op(x) 19 | 20 | 21 | @torch.jit.script 22 | class DownsampleJIT(object): 23 | def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0): 24 | self.stride = stride 25 | self.filt_size = filt_size 26 | self.channels = channels 27 | 28 | assert self.filt_size == 3 29 | assert stride == 2 30 | a = torch.tensor([1., 2., 1.]) 31 | 32 | filt = (a[:, None] * a[None, :]).clone().detach() 33 | filt = filt / torch.sum(filt) 34 | self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half() 35 | 36 | def __call__(self, input: torch.Tensor): 37 | if input.dtype != self.filt.dtype: 38 | self.filt = self.filt.float() 39 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 40 | return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1]) 41 | 42 | 43 | class Downsample(nn.Module): 44 | def __init__(self, filt_size=3, stride=2, channels=None): 45 | super(Downsample, self).__init__() 46 | self.filt_size = filt_size 47 | self.stride = stride 48 | self.channels = channels 49 | 50 | 51 | assert self.filt_size == 3 52 | a = torch.tensor([1., 2., 1.]) 53 | 54 | filt = (a[:, None] * a[None, :]).clone().detach() 55 | filt = filt / torch.sum(filt) 56 | self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 57 | 58 | def forward(self, input): 59 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 60 | return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) 61 | -------------------------------------------------------------------------------- /src/models/tresnet/layers/avg_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 
5 | 6 | 7 | class FastAvgPool2d(nn.Module): 8 | def __init__(self, flatten=False): 9 | super(FastAvgPool2d, self).__init__() 10 | self.flatten = flatten 11 | 12 | def forward(self, x): 13 | if self.flatten: 14 | in_size = x.size() 15 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2) 16 | else: 17 | return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) 18 | 19 | 20 | -------------------------------------------------------------------------------- /src/models/tresnet/layers/general_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from src.models.tresnet.layers.avg_pool import FastAvgPool2d 6 | 7 | 8 | class Flatten(nn.Module): 9 | def forward(self, x): 10 | return x.view(x.size(0), -1) 11 | 12 | 13 | class DepthToSpace(nn.Module): 14 | 15 | def __init__(self, block_size): 16 | super().__init__() 17 | self.bs = block_size 18 | 19 | def forward(self, x): 20 | N, C, H, W = x.size() 21 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 22 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 23 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 24 | return x 25 | 26 | 27 | class SpaceToDepthModule(nn.Module): 28 | def __init__(self, remove_model_jit=False): 29 | super().__init__() 30 | if not remove_model_jit: 31 | self.op = SpaceToDepthJit() 32 | else: 33 | self.op = SpaceToDepth() 34 | 35 | def forward(self, x): 36 | return self.op(x) 37 | 38 | 39 | class SpaceToDepth(nn.Module): 40 | def __init__(self, block_size=4): 41 | super().__init__() 42 | assert block_size == 4 43 | self.bs = block_size 44 | 45 | def forward(self, x): 46 | N, C, H, W = x.size() 47 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 48 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 49 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 50 | return x 51 | 52 | 53 | @torch.jit.script 54 | class SpaceToDepthJit(object): 55 | def __call__(self, x: torch.Tensor): 56 | # assuming hard-coded that block_size==4 for acceleration 57 | N, C, H, W = x.size() 58 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 59 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 60 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 61 | return x 62 | 63 | 64 | class hard_sigmoid(nn.Module): 65 | def __init__(self, inplace=True): 66 | super(hard_sigmoid, self).__init__() 67 | self.inplace = inplace 68 | 69 | def forward(self, x): 70 | if self.inplace: 71 | return x.add_(3.).clamp_(0., 6.).div_(6.) 72 | else: 73 | return F.relu6(x + 3.) / 6. 
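# hard_sigmoid above is a cheap piecewise-linear approximation of the sigmoid (relu6(x + 3) / 6);
# it can serve as a lighter gating activation for the squeeze-and-excitation block,
# although SEModule below uses nn.Sigmoid() by default (the hard_sigmoid line is commented out).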
74 | 75 | 76 | class SEModule(nn.Module): 77 | 78 | def __init__(self, channels, reduction_channels, inplace=True): 79 | super(SEModule, self).__init__() 80 | self.avg_pool = FastAvgPool2d() 81 | self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True) 82 | self.relu = nn.ReLU(inplace=inplace) 83 | self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True) 84 | # self.activation = hard_sigmoid(inplace=inplace) 85 | self.activation = nn.Sigmoid() 86 | 87 | def forward(self, x): 88 | x_se = self.avg_pool(x) 89 | x_se2 = self.fc1(x_se) 90 | x_se2 = self.relu(x_se2) 91 | x_se = self.fc2(x_se2) 92 | x_se = self.activation(x_se) 93 | return x * x_se 94 | -------------------------------------------------------------------------------- /src/models/tresnet/tresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Module as Module 4 | from collections import OrderedDict 5 | from src.models.tresnet.layers.anti_aliasing import AntiAliasDownsampleLayer 6 | from .layers.avg_pool import FastAvgPool2d 7 | from .layers.general_layers import SEModule, SpaceToDepthModule 8 | from inplace_abn import InPlaceABN, ABN 9 | 10 | 11 | def InplacABN_to_ABN(module: nn.Module) -> nn.Module: 12 | # convert all InplaceABN layer to bit-accurate ABN layers. 13 | if isinstance(module, InPlaceABN): 14 | module_new = ABN(module.num_features, activation=module.activation, 15 | activation_param=module.activation_param) 16 | for key in module.state_dict(): 17 | module_new.state_dict()[key].copy_(module.state_dict()[key]) 18 | module_new.training = module.training 19 | module_new.weight.data = module_new.weight.abs() + module_new.eps 20 | return module_new 21 | for name, child in reversed(module._modules.items()): 22 | new_child = InplacABN_to_ABN(child) 23 | if new_child != child: 24 | module._modules[name] = new_child 25 | return module 26 | 27 | class bottleneck_head(nn.Module): 28 | def __init__(self, num_features, num_classes, bottleneck_features=200): 29 | super(bottleneck_head, self).__init__() 30 | self.embedding_generator = nn.ModuleList() 31 | self.embedding_generator.append(nn.Linear(num_features, bottleneck_features)) 32 | self.embedding_generator = nn.Sequential(*self.embedding_generator) 33 | self.FC = nn.Linear(bottleneck_features, num_classes) 34 | 35 | def forward(self, x): 36 | self.embedding = self.embedding_generator(x) 37 | logits = self.FC(self.embedding) 38 | return logits 39 | 40 | 41 | def conv2d(ni, nf, stride): 42 | return nn.Sequential( 43 | nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False), 44 | nn.BatchNorm2d(nf), 45 | nn.ReLU(inplace=True) 46 | ) 47 | 48 | 49 | def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1): 50 | return nn.Sequential( 51 | nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, 52 | bias=False), 53 | InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param) 54 | ) 55 | 56 | 57 | class BasicBlock(Module): 58 | expansion = 1 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None): 61 | super(BasicBlock, self).__init__() 62 | if stride == 1: 63 | self.conv1 = conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3) 64 | else: 65 | if anti_alias_layer is None: 66 | self.conv1 = conv2d_ABN(inplanes, planes, stride=2, 
activation_param=1e-3) 67 | else: 68 | self.conv1 = nn.Sequential(conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3), 69 | anti_alias_layer(channels=planes, filt_size=3, stride=2)) 70 | 71 | self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity") 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | reduce_layer_planes = max(planes * self.expansion // 4, 64) 76 | self.se = SEModule(planes * self.expansion, reduce_layer_planes) if use_se else None 77 | 78 | def forward(self, x): 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 81 | else: 82 | residual = x 83 | 84 | out = self.conv1(x) 85 | out = self.conv2(out) 86 | 87 | if self.se is not None: out = self.se(out) 88 | 89 | out += residual 90 | 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class Bottleneck(Module): 97 | expansion = 4 98 | 99 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None): 100 | super(Bottleneck, self).__init__() 101 | self.conv1 = conv2d_ABN(inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu", 102 | activation_param=1e-3) 103 | if stride == 1: 104 | self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu", 105 | activation_param=1e-3) 106 | else: 107 | if anti_alias_layer is None: 108 | self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=2, activation="leaky_relu", 109 | activation_param=1e-3) 110 | else: 111 | self.conv2 = nn.Sequential(conv2d_ABN(planes, planes, kernel_size=3, stride=1, 112 | activation="leaky_relu", activation_param=1e-3), 113 | anti_alias_layer(channels=planes, filt_size=3, stride=2)) 114 | 115 | self.conv3 = conv2d_ABN(planes, planes * self.expansion, kernel_size=1, stride=1, 116 | activation="identity") 117 | 118 | self.relu = nn.ReLU(inplace=True) 119 | self.downsample = downsample 120 | self.stride = stride 121 | 122 | reduce_layer_planes = max(planes * self.expansion // 8, 64) 123 | self.se = SEModule(planes, reduce_layer_planes) if use_se else None 124 | 125 | def forward(self, x): 126 | if self.downsample is not None: 127 | residual = self.downsample(x) 128 | else: 129 | residual = x 130 | 131 | out = self.conv1(x) 132 | out = self.conv2(out) 133 | if self.se is not None: out = self.se(out) 134 | 135 | out = self.conv3(out) 136 | out = out + residual # no inplace 137 | out = self.relu(out) 138 | 139 | return out 140 | 141 | 142 | class TResNet(Module): 143 | 144 | def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, 145 | do_bottleneck_head=False,bottleneck_features=512): 146 | super(TResNet, self).__init__() 147 | 148 | # JIT layers 149 | space_to_depth = SpaceToDepthModule() 150 | anti_alias_layer = AntiAliasDownsampleLayer 151 | global_pool_layer = FastAvgPool2d(flatten=True) 152 | 153 | # TResnet stages 154 | self.inplanes = int(64 * width_factor) 155 | self.planes = int(64 * width_factor) 156 | conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3) 157 | layer1 = self._make_layer(BasicBlock, self.planes, layers[0], stride=1, use_se=True, 158 | anti_alias_layer=anti_alias_layer) # 56x56 159 | layer2 = self._make_layer(BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, 160 | anti_alias_layer=anti_alias_layer) # 28x28 161 | layer3 = self._make_layer(Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, 162 | anti_alias_layer=anti_alias_layer) # 14x14 163 | layer4 = self._make_layer(Bottleneck, self.planes * 8, 
layers[3], stride=2, use_se=False, 164 | anti_alias_layer=anti_alias_layer) # 7x7 165 | 166 | # body 167 | self.body = nn.Sequential(OrderedDict([ 168 | ('SpaceToDepth', space_to_depth), 169 | ('conv1', conv1), 170 | ('layer1', layer1), 171 | ('layer2', layer2), 172 | ('layer3', layer3), 173 | ('layer4', layer4)])) 174 | 175 | # head 176 | self.embeddings = [] 177 | self.global_pool = nn.Sequential(OrderedDict([('global_pool_layer', global_pool_layer)])) 178 | self.num_features = (self.planes * 8) * Bottleneck.expansion 179 | if do_bottleneck_head: 180 | fc = bottleneck_head(self.num_features, num_classes, 181 | bottleneck_features=bottleneck_features) 182 | else: 183 | fc = nn.Linear(self.num_features , num_classes) 184 | 185 | self.head = nn.Sequential(OrderedDict([('fc', fc)])) 186 | 187 | # model initilization 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 191 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # residual connections special initialization 196 | for m in self.modules(): 197 | if isinstance(m, BasicBlock): 198 | m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero 199 | if isinstance(m, Bottleneck): 200 | m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero 201 | if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01) 202 | 203 | def _make_layer(self, block, planes, blocks, stride=1, use_se=True, anti_alias_layer=None): 204 | downsample = None 205 | if stride != 1 or self.inplanes != planes * block.expansion: 206 | layers = [] 207 | if stride == 2: 208 | # avg pooling before 1x1 conv 209 | layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)) 210 | layers += [conv2d_ABN(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, 211 | activation="identity")] 212 | downsample = nn.Sequential(*layers) 213 | 214 | layers = [] 215 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=use_se, 216 | anti_alias_layer=anti_alias_layer)) 217 | self.inplanes = planes * block.expansion 218 | for i in range(1, blocks): layers.append( 219 | block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer)) 220 | return nn.Sequential(*layers) 221 | 222 | def forward(self, x): 223 | x = self.body(x) 224 | self.embeddings = self.global_pool(x) 225 | logits = self.head(self.embeddings) 226 | return logits 227 | 228 | 229 | def TResnetM(model_params): 230 | """Constructs a medium TResnet model. 231 | """ 232 | in_chans = 3 233 | num_classes = model_params['num_classes'] 234 | model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans) 235 | return model 236 | 237 | 238 | def TResnetL(model_params): 239 | """Constructs a large TResnet model. 240 | """ 241 | in_chans = 3 242 | num_classes = model_params['num_classes'] 243 | do_bottleneck_head = model_params['args'].do_bottleneck_head 244 | model = TResNet(layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, 245 | do_bottleneck_head=do_bottleneck_head) 246 | return model 247 | 248 | 249 | def TResnetXL(model_params): 250 | """Constructs a xlarge TResnet model. 
251 | """ 252 | in_chans = 3 253 | num_classes = model_params['num_classes'] 254 | model = TResNet(layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3) 255 | 256 | return model 257 | -------------------------------------------------------------------------------- /src/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import create_model 2 | __all__ = ['create_model'] -------------------------------------------------------------------------------- /src/models/utils/factory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | 5 | from ..tresnet import TResnetM, TResnetL, TResnetXL 6 | 7 | 8 | def create_model(args): 9 | """Create a model 10 | """ 11 | model_params = {'args': args, 'num_classes': args.num_classes} 12 | args = model_params['args'] 13 | args.model_name = args.model_name.lower() 14 | 15 | if args.model_name=='tresnet_m': 16 | model = TResnetM(model_params) 17 | elif args.model_name=='tresnet_l': 18 | model = TResnetL(model_params) 19 | elif args.model_name=='tresnet_xl': 20 | model = TResnetXL(model_params) 21 | else: 22 | print("model: {} not found !!".format(args.model_name)) 23 | exit(-1) 24 | 25 | return model 26 | -------------------------------------------------------------------------------- /tests/test_asl.py: -------------------------------------------------------------------------------- 1 | 2 | # test_asl.py 3 | # tests auto-generated by https://www.codium.ai/ 4 | # testing https://github.com/Alibaba-MIIL/ASL/blob/b9d01aff9f66ccddab6112e47f3ed0ceb59ad7f5/tests/test_asl.py#L6 class 5 | 6 | import unittest 7 | import torch 8 | from src.loss_functions.losses import AsymmetricLoss 9 | 10 | """ 11 | Code Analysis for AsymmetricLoss() class: 12 | - This class is a custom loss function called AsymmetricLoss, which is a subclass of the nn.Module class. 13 | - It is used to calculate the loss between the input logits and the targets (multi-label binarized vector). 14 | - The __init__ method initializes the parameters of the class, such as gamma_neg, gamma_pos, clip, eps, and disable_torch_grad_focal_loss. 15 | - The forward method is used to calculate the loss between the input logits and the targets. 16 | - The forward method first calculates the sigmoid of the input logits and then calculates the positive and negative logits. 17 | - If the clip parameter is not None and greater than 0, the negative logits are clipped to a maximum of 1. 18 | - The loss is then calculated using the positive and negative logits. 19 | - If the gamma_neg and gamma_pos parameters are greater than 0, the loss is weighted using the one_sided_gamma and one_sided_w parameters. 20 | - Finally, the loss is returned as a negative sum. 
21 | """ 22 | 23 | 24 | """ 25 | Test strategies: 26 | - test_init(): tests that the parameters of the class are initialized correctly 27 | - test_forward_positive_logits(): tests that the forward method correctly calculates the loss for positive logits 28 | - test_forward_negative_logits(): tests that the forward method correctly calculates the loss for negative logits 29 | - test_forward_clipped_logits(): tests that the forward method correctly calculates the loss for clipped logits 30 | - test_forward_gamma_pos(): tests that the forward method correctly calculates the loss when gamma_pos is greater than 0 31 | - test_forward_gamma_neg(): tests that the forward method correctly calculates the loss when gamma_neg is greater than 0 32 | - test_forward_eps(): tests that the forward method correctly calculates the loss when eps is greater than 0 33 | - test_forward_disable_torch_grad_focal_loss(): tests that the forward method correctly calculates the loss when disable_torch_grad_focal_loss is set to True 34 | """ 35 | class TestAsymmetricLoss(unittest.TestCase): 36 | def setUp(self): 37 | self.loss = AsymmetricLoss() 38 | 39 | def test_init(self): 40 | self.assertEqual(self.loss.gamma_neg, 4) 41 | self.assertEqual(self.loss.gamma_pos, 1) 42 | self.assertEqual(self.loss.clip, 0.05) 43 | self.assertEqual(self.loss.eps, 1e-08) 44 | self.assertTrue(self.loss.disable_torch_grad_focal_loss) 45 | 46 | def test_forward_positive_logits(self): 47 | x = torch.tensor([1., 2., 3.]) 48 | y = torch.tensor([1., 0., 1.]) 49 | expected_loss = -torch.log(torch.sigmoid(x)).sum() 50 | self.assertEqual(self.loss(x, y), expected_loss) 51 | 52 | def test_forward_negative_logits(self): 53 | x = torch.tensor([-1., -2., -3.]) 54 | y = torch.tensor([1., 0., 1.]) 55 | expected_loss = -torch.log(1 - torch.sigmoid(x)).sum() 56 | self.assertEqual(self.loss(x, y), expected_loss) 57 | 58 | def test_forward_clipped_logits(self): 59 | x = torch.tensor([-1., -2., -3.]) 60 | y = torch.tensor([1., 0., 1.]) 61 | expected_loss = -torch.log((1 - torch.sigmoid(x) + self.loss.clip).clamp(max=1)).sum() 62 | self.assertEqual(self.loss(x, y), expected_loss) 63 | 64 | def test_forward_gamma_pos(self): 65 | x = torch.tensor([1., 2., 3.]) 66 | y = torch.tensor([1., 0., 1.]) 67 | self.loss.gamma_pos = 2 68 | expected_loss = -torch.log(torch.sigmoid(x)) * torch.pow(1 - torch.sigmoid(x), 2).sum() 69 | self.assertEqual(self.loss(x, y), expected_loss) 70 | 71 | def test_forward_gamma_neg(self): 72 | x = torch.tensor([-1., -2., -3.]) 73 | y = torch.tensor([1., 0., 1.]) 74 | self.loss.gamma_neg = 3 75 | expected_loss = -torch.log((1 - torch.sigmoid(x) + self.loss.clip).clamp(max=1)) * torch.pow(1 - (torch.sigmoid(x) + self.loss.clip).clamp(max=1), 3).sum() 76 | self.assertEqual(self.loss(x, y), expected_loss) 77 | 78 | def test_forward_eps(self): 79 | x = torch.tensor([-1., -2., -3.]) 80 | y = torch.tensor([1., 0., 1.]) 81 | self.loss.eps = 0.5 82 | expected_loss = -torch.log((1 - torch.sigmoid(x) + self.loss.clip).clamp(min=0.5)).sum() 83 | self.assertEqual(self.loss(x, y), expected_loss) 84 | 85 | def test_forward_disable_torch_grad_focal_loss(self): 86 | x = torch.tensor([-1., -2., -3.]) 87 | y = torch.tensor([1., 0., 1.]) 88 | self.loss.disable_torch_grad_focal_loss = False 89 | expected_loss = -torch.log((1 - torch.sigmoid(x) + self.loss.clip).clamp(max=1)).sum() 90 | self.assertEqual(self.loss(x, y), expected_loss) 91 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn.parallel 5 | import torch.optim 6 | import torch.utils.data.distributed 7 | import torchvision.transforms as transforms 8 | from torch.optim import lr_scheduler 9 | from src.helper_functions.helper_functions import mAP, CocoDetection, CutoutPIL, ModelEma, add_weight_decay 10 | from src.models import create_model 11 | from src.loss_functions.losses import AsymmetricLoss 12 | from randaugment import RandAugment 13 | from torch.cuda.amp import GradScaler, autocast 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training') 16 | parser.add_argument('data', metavar='DIR', help='path to dataset', default='/home/MSCOCO_2014/') 17 | parser.add_argument('--lr', default=1e-4, type=float) 18 | parser.add_argument('--model-name', default='tresnet_m') 19 | parser.add_argument('--model-path', default='./tresnet_m.pth', type=str) 20 | parser.add_argument('--num-classes', default=80) 21 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 22 | help='number of data loading workers (default: 16)') 23 | parser.add_argument('--image-size', default=224, type=int, 24 | metavar='N', help='input image size (default: 448)') 25 | parser.add_argument('--thre', default=0.8, type=float, 26 | metavar='N', help='threshold value') 27 | parser.add_argument('-b', '--batch-size', default=128, type=int, 28 | metavar='N', help='mini-batch size (default: 16)') 29 | parser.add_argument('--print-freq', '-p', default=64, type=int, 30 | metavar='N', help='print frequency (default: 64)') 31 | 32 | 33 | def main(): 34 | args = parser.parse_args() 35 | args.do_bottleneck_head = False 36 | 37 | # Setup model 38 | print('creating model...') 39 | model = create_model(args).cuda() 40 | if args.model_path: # make sure to load pretrained ImageNet model 41 | state = torch.load(args.model_path, map_location='cpu') 42 | filtered_dict = {k: v for k, v in state['model'].items() if 43 | (k in model.state_dict() and 'head.fc' not in k)} 44 | model.load_state_dict(filtered_dict, strict=False) 45 | print('done\n') 46 | 47 | # COCO Data loading 48 | instances_path_val = os.path.join(args.data, 'annotations/instances_val2014.json') 49 | instances_path_train = os.path.join(args.data, 'annotations/instances_train2014.json') 50 | # data_path_val = args.data 51 | # data_path_train = args.data 52 | data_path_val = f'{args.data}/val2014' # args.data 53 | data_path_train = f'{args.data}/train2014' # args.data 54 | val_dataset = CocoDetection(data_path_val, 55 | instances_path_val, 56 | transforms.Compose([ 57 | transforms.Resize((args.image_size, args.image_size)), 58 | transforms.ToTensor(), 59 | # normalize, # no need, toTensor does normalization 60 | ])) 61 | train_dataset = CocoDetection(data_path_train, 62 | instances_path_train, 63 | transforms.Compose([ 64 | transforms.Resize((args.image_size, args.image_size)), 65 | CutoutPIL(cutout_factor=0.5), 66 | RandAugment(), 67 | transforms.ToTensor(), 68 | # normalize, 69 | ])) 70 | print("len(val_dataset)): ", len(val_dataset)) 71 | print("len(train_dataset)): ", len(train_dataset)) 72 | 73 | # Pytorch Data loader 74 | train_loader = torch.utils.data.DataLoader( 75 | train_dataset, batch_size=args.batch_size, shuffle=True, 76 | num_workers=args.workers, pin_memory=True) 77 | 78 | val_loader = torch.utils.data.DataLoader( 79 | val_dataset, batch_size=args.batch_size, shuffle=False, 80 | num_workers=args.workers, 
pin_memory=False) 81 | 82 | # Actuall Training 83 | train_multi_label_coco(model, train_loader, val_loader, args.lr) 84 | 85 | 86 | def train_multi_label_coco(model, train_loader, val_loader, lr): 87 | ema = ModelEma(model, 0.9997) # 0.9997^641=0.82 88 | 89 | # set optimizer 90 | Epochs = 80 91 | Stop_epoch = 40 92 | weight_decay = 1e-4 93 | criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05, disable_torch_grad_focal_loss=True) 94 | parameters = add_weight_decay(model, weight_decay) 95 | optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=0) # true wd, filter_bias_and_bn 96 | steps_per_epoch = len(train_loader) 97 | scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=steps_per_epoch, epochs=Epochs, 98 | pct_start=0.2) 99 | 100 | highest_mAP = 0 101 | trainInfoList = [] 102 | scaler = GradScaler() 103 | for epoch in range(Epochs): 104 | if epoch > Stop_epoch: 105 | break 106 | for i, (inputData, target) in enumerate(train_loader): 107 | inputData = inputData.cuda() 108 | target = target.cuda() # (batch,3,num_classes) 109 | target = target.max(dim=1)[0] 110 | with autocast(): # mixed precision 111 | output = model(inputData).float() # sigmoid will be done in loss ! 112 | loss = criterion(output, target) 113 | model.zero_grad() 114 | 115 | scaler.scale(loss).backward() 116 | # loss.backward() 117 | 118 | scaler.step(optimizer) 119 | scaler.update() 120 | # optimizer.step() 121 | 122 | scheduler.step() 123 | 124 | ema.update(model) 125 | # store information 126 | if i % 100 == 0: 127 | trainInfoList.append([epoch, i, loss.item()]) 128 | print('Epoch [{}/{}], Step [{}/{}], LR {:.1e}, Loss: {:.1f}' 129 | .format(epoch, Epochs, str(i).zfill(3), str(steps_per_epoch).zfill(3), 130 | scheduler.get_last_lr()[0], \ 131 | loss.item())) 132 | 133 | try: 134 | torch.save(model.state_dict(), os.path.join( 135 | 'models/', 'model-{}-{}.ckpt'.format(epoch + 1, i + 1))) 136 | except: 137 | pass 138 | 139 | model.eval() 140 | mAP_score = validate_multi(val_loader, model, ema) 141 | model.train() 142 | if mAP_score > highest_mAP: 143 | highest_mAP = mAP_score 144 | try: 145 | torch.save(model.state_dict(), os.path.join( 146 | 'models/', 'model-highest.ckpt')) 147 | except: 148 | pass 149 | print('current_mAP = {:.2f}, highest_mAP = {:.2f}\n'.format(mAP_score, highest_mAP)) 150 | 151 | 152 | def validate_multi(val_loader, model, ema_model): 153 | print("starting validation") 154 | Sig = torch.nn.Sigmoid() 155 | preds_regular = [] 156 | preds_ema = [] 157 | targets = [] 158 | for i, (input, target) in enumerate(val_loader): 159 | target = target 160 | target = target.max(dim=1)[0] 161 | # compute output 162 | with torch.no_grad(): 163 | with autocast(): 164 | output_regular = Sig(model(input.cuda())).cpu() 165 | output_ema = Sig(ema_model.module(input.cuda())).cpu() 166 | 167 | # for mAP calculation 168 | preds_regular.append(output_regular.cpu().detach()) 169 | preds_ema.append(output_ema.cpu().detach()) 170 | targets.append(target.cpu().detach()) 171 | 172 | mAP_score_regular = mAP(torch.cat(targets).numpy(), torch.cat(preds_regular).numpy()) 173 | mAP_score_ema = mAP(torch.cat(targets).numpy(), torch.cat(preds_ema).numpy()) 174 | print("mAP score regular {:.2f}, mAP score EMA {:.2f}".format(mAP_score_regular, mAP_score_ema)) 175 | return max(mAP_score_regular, mAP_score_ema) 176 | 177 | 178 | if __name__ == '__main__': 179 | main() 180 | -------------------------------------------------------------------------------- /validate.py: 
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/allenai/elastic/blob/master/multilabel_classify.py
2 | # special thanks to @hellbell
3 | 
4 | import argparse
5 | import time
6 | import torch
7 | import torch.nn.parallel
8 | import torch.optim
9 | import torch.utils.data.distributed
10 | import torchvision.transforms as transforms
11 | import os
12 | 
13 | from src.helper_functions.helper_functions import mAP, AverageMeter, CocoDetection
14 | from src.models import create_model
15 | import numpy as np
16 | 
17 | parser = argparse.ArgumentParser(description='PyTorch MS-COCO Validation')
18 | parser.add_argument('data', metavar='DIR', help='path to dataset')
19 | parser.add_argument('--model-name', default='tresnet_l')
20 | parser.add_argument('--model-path', default='./TRresNet_L_448_86.6.pth', type=str)
21 | parser.add_argument('--num-classes', default=80, type=int)
22 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
23 |                     help='number of data loading workers (default: 8)')
24 | parser.add_argument('--image-size', default=448, type=int,
25 |                     metavar='N', help='input image size (default: 448)')
26 | parser.add_argument('--thre', default=0.8, type=float,
27 |                     metavar='N', help='threshold value')
28 | parser.add_argument('-b', '--batch-size', default=32, type=int,
29 |                     metavar='N', help='mini-batch size (default: 32)')
30 | parser.add_argument('--print-freq', '-p', default=64, type=int,
31 |                     metavar='N', help='print frequency (default: 64)')
32 | 
33 | 
34 | def main():
35 |     args = parser.parse_args()
36 | 
37 | 
38 |     # setup model
39 |     print('creating and loading the model...')
40 |     state = torch.load(args.model_path, map_location='cpu')
41 |     args.num_classes = state['num_classes']
42 |     args.do_bottleneck_head = False
43 |     model = create_model(args).cuda()
44 |     model.load_state_dict(state['model'], strict=True)
45 |     model.eval()
46 |     classes_list = np.array(list(state['idx_to_class'].values()))
47 |     print('done\n')
48 | 
49 |     # Data loading code
50 |     normalize = transforms.Normalize(mean=[0, 0, 0],
51 |                                      std=[1, 1, 1])  # identity normalization; images stay in [0, 1]
52 | 
53 |     instances_path = os.path.join(args.data, 'annotations/instances_val2014.json')
54 |     data_path = os.path.join(args.data, 'val2014')
55 |     val_dataset = CocoDetection(data_path,
56 |                                 instances_path,
57 |                                 transforms.Compose([
58 |                                     transforms.Resize((args.image_size, args.image_size)),
59 |                                     transforms.ToTensor(),
60 |                                     normalize,
61 |                                 ]))
62 | 
63 |     print("len(val_dataset): ", len(val_dataset))
64 |     val_loader = torch.utils.data.DataLoader(
65 |         val_dataset, batch_size=args.batch_size, shuffle=False,
66 |         num_workers=args.workers, pin_memory=True)
67 | 
68 |     validate_multi(val_loader, model, args)
69 | 
70 | 
71 | def validate_multi(val_loader, model, args):
72 |     print("starting actual validation")
73 |     batch_time = AverageMeter()
74 |     prec = AverageMeter()
75 |     rec = AverageMeter()
76 |     mAP_meter = AverageMeter()
77 | 
78 |     Sig = torch.nn.Sigmoid()
79 | 
80 |     end = time.time()
81 |     tp, fp, fn, tn, count = 0, 0, 0, 0, 0
82 |     preds = []
83 |     targets = []
84 |     for i, (input, target) in enumerate(val_loader):
85 |         # targets stay on the CPU; only the images are moved to the GPU
86 |         target = target.max(dim=1)[0]
87 |         # compute output
88 |         with torch.no_grad():
89 |             output = Sig(model(input.cuda())).cpu()
90 | 
91 |         # for mAP calculation
92 |         preds.append(output.cpu())
93 |         targets.append(target.cpu())
94 | 
95 |         # measure accuracy and record loss
96 |         pred = output.data.gt(args.thre).long()
97 | 
98 |         tp += (pred + target).eq(2).sum(dim=0)
99 |         fp += (pred - target).eq(1).sum(dim=0)
100 |         fn += (pred - target).eq(-1).sum(dim=0)
101 |         tn += (pred + target).eq(0).sum(dim=0)
102 |         count += input.size(0)
103 | 
104 |         this_tp = (pred + target).eq(2).sum()
105 |         this_fp = (pred - target).eq(1).sum()
106 |         this_fn = (pred - target).eq(-1).sum()
107 |         this_tn = (pred + target).eq(0).sum()
108 | 
109 |         this_prec = this_tp.float() / (
110 |             this_tp + this_fp).float() * 100.0 if this_tp + this_fp != 0 else 0.0
111 |         this_rec = this_tp.float() / (
112 |             this_tp + this_fn).float() * 100.0 if this_tp + this_fn != 0 else 0.0
113 | 
114 |         prec.update(float(this_prec), input.size(0))
115 |         rec.update(float(this_rec), input.size(0))
116 | 
117 |         # measure elapsed time
118 |         batch_time.update(time.time() - end)
119 |         end = time.time()
120 | 
121 |         p_c = [float(tp[i].float() / (tp[i] + fp[i]).float()) * 100.0
122 |                if tp[i] > 0 else 0.0
123 |                for i in range(len(tp))]
124 |         r_c = [float(tp[i].float() / (tp[i] + fn[i]).float()) * 100.0
125 |                if tp[i] > 0 else 0.0
126 |                for i in range(len(tp))]
127 |         f_c = [2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0
128 |                for i in range(len(tp))]
129 | 
130 |         mean_p_c = sum(p_c) / len(p_c)
131 |         mean_r_c = sum(r_c) / len(r_c)
132 |         mean_f_c = sum(f_c) / len(f_c)
133 | 
134 |         p_o = tp.sum().float() / (tp + fp).sum().float() * 100.0
135 |         r_o = tp.sum().float() / (tp + fn).sum().float() * 100.0
136 |         f_o = 2 * p_o * r_o / (p_o + r_o)
137 | 
138 |         if i % args.print_freq == 0:
139 |             print('Test: [{0}/{1}]\t'
140 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
141 |                   'Precision {prec.val:.2f} ({prec.avg:.2f})\t'
142 |                   'Recall {rec.val:.2f} ({rec.avg:.2f})'.format(
143 |                       i, len(val_loader), batch_time=batch_time,
144 |                       prec=prec, rec=rec))
145 |             print(
146 |                 'P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}'
147 |                 .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
148 | 
149 |     print(
150 |         '--------------------------------------------------------------------')
151 |     print(' * P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}'
152 |           .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o))
153 | 
154 |     mAP_score = mAP(torch.cat(targets).numpy(), torch.cat(preds).numpy())
155 |     print("mAP score:", mAP_score)
156 | 
157 |     return
158 | 
159 | 
160 | if __name__ == '__main__':
161 |     main()
162 | 
--------------------------------------------------------------------------------