├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── Speed_Accuracy_Comparisons.md ├── changes.txt ├── distributed_train.sh ├── inference.py ├── kd ├── __init__.py ├── helpers.py └── kd_utils.py ├── models ├── __init__.py ├── beit.py ├── byoanet.py ├── byobnet.py ├── cait.py ├── coat.py ├── convit.py ├── convmixer.py ├── convnext.py ├── crossvit.py ├── cspnet.py ├── deit.py ├── densenet.py ├── dla.py ├── dpn.py ├── efficientnet.py ├── efficientnet_blocks.py ├── efficientnet_builder.py ├── factory.py ├── features.py ├── fx_features.py ├── ghostnet.py ├── gluon_resnet.py ├── gluon_xception.py ├── hardcorenas.py ├── helpers.py ├── hrnet.py ├── hub.py ├── inception_resnet_v2.py ├── inception_v3.py ├── inception_v4.py ├── layers │ ├── __init__.py │ ├── activations.py │ ├── activations_jit.py │ ├── activations_me.py │ ├── adaptive_avgmax_pool.py │ ├── attention_pool2d.py │ ├── blur_pool.py │ ├── bottleneck_attn.py │ ├── cbam.py │ ├── classifier.py │ ├── cond_conv2d.py │ ├── config.py │ ├── conv2d_same.py │ ├── conv_bn_act.py │ ├── create_act.py │ ├── create_attn.py │ ├── create_conv2d.py │ ├── create_norm_act.py │ ├── drop.py │ ├── eca.py │ ├── evo_norm.py │ ├── filter_response_norm.py │ ├── gather_excite.py │ ├── global_context.py │ ├── halo_attn.py │ ├── helpers.py │ ├── inplace_abn.py │ ├── lambda_layer.py │ ├── linear.py │ ├── median_pool.py │ ├── mixed_conv2d.py │ ├── ml_decoder.py │ ├── mlp.py │ ├── non_local_attn.py │ ├── norm.py │ ├── norm_act.py │ ├── padding.py │ ├── patch_embed.py │ ├── pool2d_same.py │ ├── pos_embed.py │ ├── selective_kernel.py │ ├── separable_conv.py │ ├── space_to_depth.py │ ├── split_attn.py │ ├── split_batchnorm.py │ ├── squeeze_excite.py │ ├── std_conv.py │ ├── test_time_pool.py │ ├── trace_utils.py │ └── weight_init.py ├── levit.py ├── mlp_mixer.py ├── mobilenetv3.py ├── mobilevit.py ├── nasnet.py ├── nest.py ├── nfnet.py ├── pit.py ├── pnasnet.py ├── poolformer.py ├── pruned │ ├── ecaresnet101d_pruned.txt │ ├── ecaresnet50d_pruned.txt │ ├── 
efficientnet_b1_pruned.txt │ ├── efficientnet_b2_pruned.txt │ └── efficientnet_b3_pruned.txt ├── registry.py ├── regnet.py ├── res2net.py ├── resnest.py ├── resnet.py ├── resnetv2.py ├── rexnet.py ├── selecsls.py ├── senet.py ├── sknet.py ├── swin_transformer.py ├── swin_transformer_v2_cr.py ├── tnt.py ├── tresnet.py ├── tresnet_v2.py ├── twins.py ├── vgg.py ├── visformer.py ├── vision_transformer.py ├── vision_transformer_hybrid.py ├── volo.py ├── vovnet.py ├── xception.py ├── xception_aligned.py └── xcit.py ├── pics ├── CPU.png ├── GPU.png ├── pic1.png ├── pic2.png └── pic3.png ├── requirements.txt ├── test └── test_build_kd_model.py ├── train.py ├── utils └── checkpoint_saver.py └── validate.py /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # Model Zoo 2 | 3 | With USI, we can reliably identify models that provide a [good speed-accuracy trade-off](Speed_Accuracy_Comparisons.md). 4 | 5 | For these top models, we provide weights from large-scale pretraining on ImageNet-21K. We recommend using the large-scale weights for transfer learning - they almost always provide superior transfer results compared to 1K weights.
6 | 7 | | Backbone | 21K Single-label Pretraining weights | 21K Multi-label Pretraining weights | ImageNet-1K Accuracy [\%] | 8 | | :------------: | :--------------: | :--------------: | :--------------: | 9 | **TResNet-L** |[Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/tresnet_l_v2/single_label_ls.pth) | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/tresnet_l_v2/multi_label_ls.pth) | [83.9](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/USI/tresnet_l_v2_83_9.pth) | 10 | **TResNet-M** |[Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/tresnet_m/single_label_ls.pth) | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/tresnet_m/multi_label_ls.pth) | 82.5 | 11 | **ResNet50** |[Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/resnet50/single_label_ls.pth) | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/resnet50/multi_label_ls.pth) | 81.0 | 12 | **MobileNetV3_Large_100** |[Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/mobilenetv3_large_100/single_label_ls.pth) | N/A | 77.3 | 13 | **LeViT-384** |[Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/levit_384/single_label_ls.pth) | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/levit_384/multi_label_ls.pth) | 82.7 | 14 | **LeViT-768** |[Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/levit_768/single_label_ls.pth) | N/A | [84.2](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/levit_768/levit_768_84_2.pth) | 15 | **[EdgeNeXt-S](https://arxiv.org/abs/2206.10589)** |N/A | N/A | [81.1](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth) | 16 | 17 | -------------------------------------------------------------------------------- /README.md:
-------------------------------------------------------------------------------- 1 | # Solving ImageNet: a Unified Scheme for Training any Backbone to Top Results 2 | 3 | Official PyTorch Implementation 4 |
[Paper](http://arxiv.org/abs/2204.03475) | [Model Zoo](MODEL_ZOO.md) | [Speed-Accuracy Comparisons](Speed_Accuracy_Comparisons.md) | 5 | > Tal Ridnik, Hussam Lawen, Emanuel Ben-Baruch, Asaf Noy
DAMO Academy, Alibaba 6 | > Group 7 | 8 | **Abstract** 9 | 10 | ImageNet serves as the primary dataset for evaluating the quality of computer-vision models. The common practice today is training each architecture with a tailor-made scheme, designed and tuned by an expert. 11 | In this paper, we present a unified scheme for training any backbone on ImageNet. The scheme, named USI (Unified Scheme for ImageNet), is based on knowledge distillation and modern tricks. It requires no adjustments or hyper-parameter tuning between different models, and is efficient in terms of training time. 12 | We test USI on a wide variety of architectures, including CNNs, Transformers, Mobile-oriented and MLP-only. On all models tested, USI outperforms previous state-of-the-art results. Hence, we are able to transform training on ImageNet from an expert-oriented task to an automatic seamless routine. 13 | Since USI accepts any backbone and trains it to top results, it also enables methodical comparisons and identification of the most efficient backbones along the speed-accuracy Pareto curve. 14 | 15 |

16 | 17 | 18 | 19 | 20 |
21 |

22 | 23 | ## 11/1/2023 Update 24 | Added [tests](https://github.com/Alibaba-MIIL/Solving_ImageNet/blob/main/test/test_build_kd_model.py) auto-generated by the [CodiumAI](https://www.codium.ai/) tool 25 | 26 | ## How to Train on ImageNet with USI scheme 27 | The proposed USI scheme does not require hyper-parameter tuning. The base training configuration works well for any backbone. 28 | All the results presented in the paper are fully reproducible. 29 | 30 | First, download the teacher model weights from [here](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/USI/tresnet_l_v2_83_9.pth) 31 | 32 | Example code - training a ResNet50 model with USI: 33 | ``` 34 | python3 -u -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train.py \ 35 | /mnt/Imagenet/ \ 36 | --model=resnet50 \ 37 | --kd_model_name=tresnet_l_v2 \ 38 | --kd_model_path=./tresnet_l_v2_83_9.pth 39 | ``` 40 | 41 | Some additional degrees of freedom that might be useful: 42 | 43 | - Adjusting the batch size (default - 128): ```--batch-size=...``` 44 | - Training for more epochs (default - 300): ```--epochs=...``` 45 | 46 | 47 | ## Acknowledgements 48 | 49 | The training code is based on the excellent [timm repository](https://github.com/rwightman/pytorch-image-models). Also, thanks to the [EdgeNeXt](https://arxiv.org/pdf/2206.10589.pdf) authors for sharing their model.
50 | 51 | ## Citation 52 | ``` 53 | @misc{https://doi.org/10.48550/arxiv.2204.03475, 54 | doi = {10.48550/ARXIV.2204.03475}, 55 | url = {https://arxiv.org/abs/2204.03475}, 56 | author = {Ridnik, Tal and Lawen, Hussam and Ben-Baruch, Emanuel and Noy, Asaf}, 57 | keywords = {Computer Vision and Pattern Recognition (cs.CV), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, 58 | title = {Solving ImageNet: a Unified Scheme for Training any Backbone to Top Results}, 59 | publisher = {arXiv}, 60 | year = {2022}, 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /Speed_Accuracy_Comparisons.md: -------------------------------------------------------------------------------- 1 | # GPU TensorRT Speed-Accuracy Comparison 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ### Analysis and insights for GPU inference results: 11 | 12 | - For low-to-medium accuracies, the most efficient models are LeViT-256 and LeViT-384 13 | - For higher accuracies (> 83.5%), TResNet-L and LeViT-768* models provide the best trade-off among the models tested 14 | - Other transformer models (Swin, TnT and ViT) provide an inferior speed-accuracy trade-off compared to modern CNNs. Mobile-oriented models also do not excel on GPU inference 15 | - Note that several modern architectures titled “Small” (ConvNext-S, Swin-S, TnT-S) are in fact quite resource-intensive - they run more than four times slower than a plain ResNet50 model 16 | 17 | 18 | # CPU OpenVINO Speed-Accuracy Comparison 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | ### Analysis and insights for CPU inference results: 28 | - On CPU, mobile-oriented models (OFA, MobileNet, EfficientNet) provide the best speed-accuracy trade-off 29 | - LeViT models, which excelled on GPU inference, are not as efficient for CPU inference. 30 | - Other transformer models (Swin, TnT and ViT) again provide an inferior speed-accuracy trade-off.
31 | 32 | 33 | 34 # Implementation Details 35 | 36 | On GPU, the throughput was tested with the TensorRT inference engine, FP16, and a batch size of 256. On CPU, the throughput was tested using Intel’s 37 | OpenVINO Inference Engine, FP16, a batch size of 1 and 16 streams (equivalent to the number of CPU cores). All measurements were done after models were optimized for inference by batch-norm fusion. This significantly accelerates 38 | models that utilize batch-norm, like ResNet50, TResNet and LeViT. Note that the LeViT-768* model is not part of the original paper; it is a model we defined to test the LeViT design in the higher-accuracy regime. -------------------------------------------------------------------------------- /changes.txt: -------------------------------------------------------------------------------- 1 | - loading KD model, line 377 2 | 3 | model_KD = build_kd_model(args) 4 | 5 | - KD logic, line 714 6 | 7 | # KD logic 8 | if model_KD is not None: 9 | 10 | - line 652. choose the top metric, instead of EMA only 11 | 12 | if ema_eval_metrics[eval_metric] > eval_metrics[eval_metric]: # choose the best model 13 | eval_metrics_unite = ema_eval_metrics 14 | 15 | - created new saver class, CheckpointSaverUSI.
16 | CheckpointSaverUSI solves two issues: 17 | 1) remove saving optimizer status, which causes problems if directory structure is changed 18 | save_state['state_dict'] = get_state_dict(self.model_ema, self.unwrap_fn) 19 | 2) save the correct model (EMA vs regular) 20 | 21 | - add tresnet_l_v2 and volo models 22 | 23 | - in validate.py, perform bn fusion: 24 | 25 | model.cpu().eval() 26 | from kd.helpers import InplacABN_to_ABN,fuse_bn2d_bn1d_abn 27 | model = InplacABN_to_ABN(model) 28 | model = fuse_bn2d_bn1d_abn(model) -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """PyTorch Inference Script 3 | 4 | An example inference script that outputs top-k class ids for images in a folder into a csv. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 7 | """ 8 | import os 9 | import time 10 | import argparse 11 | import logging 12 | import numpy as np 13 | import torch 14 | 15 | from models import create_model, apply_test_time_pool 16 | from timm.data import ImageDataset, create_loader, resolve_data_config 17 | from timm.utils import AverageMeter, setup_default_logging 18 | 19 | torch.backends.cudnn.benchmark = True 20 | _logger = logging.getLogger('inference') 21 | 22 | 23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') 24 | parser.add_argument('data', metavar='DIR', 25 | help='path to dataset') 26 | parser.add_argument('--output_dir', metavar='DIR', default='./', 27 | help='path to output files') 28 | parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92', 29 | help='model architecture (default: dpn92)') 30 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 31 | help='number of data loading workers (default: 2)') 32 | parser.add_argument('-b', '--batch-size', default=256, type=int, 33 | metavar='N', help='mini-batch size (default: 256)') 34 | parser.add_argument('--img-size', default=None, type=int, 35 | metavar='N', help='Input image dimension') 36 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 37 | metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty') 38 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 39 | help='Override mean pixel value of dataset') 40 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 41 | help='Override std deviation of of dataset') 42 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 43 | help='Image resize interpolation type (overrides model)') 44 | parser.add_argument('--num-classes', type=int, default=1000, 45 | help='Number classes in dataset') 46 | parser.add_argument('--log-freq', default=10, type=int, 47 | metavar='N', help='batch logging frequency (default: 10)') 48 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 49 | help='path to latest checkpoint (default: none)') 50 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 51 | help='use pre-trained model') 52 | parser.add_argument('--num-gpu', type=int, default=1, 53 | help='Number of GPUS to use') 54 | parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', 55 | help='disable test time pool') 56 | parser.add_argument('--topk', default=5, type=int, 57 | metavar='N', help='Top-k to output to CSV') 58 | 59 | 60 | def main(): 61 | setup_default_logging() 62 | args = parser.parse_args() 63 | # might as well try to do something useful... 
64 | args.pretrained = args.pretrained or not args.checkpoint 65 | 66 | # create model 67 | model = create_model( 68 | args.model, 69 | num_classes=args.num_classes, 70 | in_chans=3, 71 | pretrained=args.pretrained, 72 | checkpoint_path=args.checkpoint) 73 | 74 | _logger.info('Model %s created, param count: %d' % 75 | (args.model, sum([m.numel() for m in model.parameters()]))) 76 | 77 | config = resolve_data_config(vars(args), model=model) 78 | model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config) 79 | 80 | if args.num_gpu > 1: 81 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() 82 | else: 83 | model = model.cuda() 84 | 85 | loader = create_loader( 86 | ImageDataset(args.data), 87 | input_size=config['input_size'], 88 | batch_size=args.batch_size, 89 | use_prefetcher=True, 90 | interpolation=config['interpolation'], 91 | mean=config['mean'], 92 | std=config['std'], 93 | num_workers=args.workers, 94 | crop_pct=1.0 if test_time_pool else config['crop_pct']) 95 | 96 | model.eval() 97 | 98 | k = min(args.topk, args.num_classes) 99 | batch_time = AverageMeter() 100 | end = time.time() 101 | topk_ids = [] 102 | with torch.no_grad(): 103 | for batch_idx, (input, _) in enumerate(loader): 104 | input = input.cuda() 105 | labels = model(input) 106 | topk = labels.topk(k)[1] 107 | topk_ids.append(topk.cpu().numpy()) 108 | 109 | # measure elapsed time 110 | batch_time.update(time.time() - end) 111 | end = time.time() 112 | 113 | if batch_idx % args.log_freq == 0: 114 | _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 115 | batch_idx, len(loader), batch_time=batch_time)) 116 | 117 | topk_ids = np.concatenate(topk_ids, axis=0) 118 | 119 | with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: 120 | filenames = loader.dataset.filenames(basename=True) 121 | for filename, label in zip(filenames, topk_ids): 122 | 
out_file.write('{0},{1}\n'.format( 123 | filename, ','.join([ str(v) for v in label]))) 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /kd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/kd/__init__.py -------------------------------------------------------------------------------- /kd/kd_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.factory import create_model 3 | from kd.helpers import fuse_bn2d_bn1d_abn, InplacABN_to_ABN 4 | import torchvision.transforms as T 5 | 6 | 7 | class build_kd_model(nn.Module): 8 | def __init__(self, args=None): 9 | super(build_kd_model, self).__init__() 10 | 11 | model_kd = create_model( 12 | model_name=args.kd_model_name, 13 | checkpoint_path=args.kd_model_path, 14 | # pretrained=False, 15 | pretrained=args.kd_model_path is None, 16 | num_classes=args.num_classes, 17 | in_chans=3) 18 | 19 | model_kd.cpu().eval() 20 | model_kd = InplacABN_to_ABN(model_kd) 21 | model_kd = fuse_bn2d_bn1d_abn(model_kd) 22 | self.model = model_kd.cuda().eval() 23 | self.mean_model_kd = model_kd.default_cfg['mean'] 24 | self.std_model_kd = model_kd.default_cfg['std'] 25 | 26 | # handling different normalization of teacher and student 27 | def normalize_input(self, input, student_model): 28 | if hasattr(student_model, 'module'): 29 | model_s = student_model.module 30 | else: 31 | model_s = student_model 32 | 33 | mean_student = model_s.default_cfg['mean'] 34 | std_student = model_s.default_cfg['std'] 35 | 36 | input_kd = input 37 | if mean_student != self.mean_model_kd or std_student != self.std_model_kd: 38 | std = (self.std_model_kd[0] / std_student[0], self.std_model_kd[1] / std_student[1], 39 | self.std_model_kd[2] / 
std_student[2]) 40 | transform_std = T.Normalize(mean=(0, 0, 0), std=std) 41 | 42 | mean = (self.mean_model_kd[0] - mean_student[0], self.mean_model_kd[1] - mean_student[1], 43 | self.mean_model_kd[2] - mean_student[2]) 44 | transform_mean = T.Normalize(mean=mean, std=(1, 1, 1)) 45 | 46 | input_kd = transform_mean(transform_std(input)) 47 | 48 | return input_kd 49 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .beit import * 2 | from .byoanet import * 3 | from .byobnet import * 4 | from .cait import * 5 | from .coat import * 6 | from .convit import * 7 | from .convmixer import * 8 | from .convnext import * 9 | from .crossvit import * 10 | from .cspnet import * 11 | from .deit import * 12 | from .densenet import * 13 | from .dla import * 14 | from .dpn import * 15 | from .efficientnet import * 16 | from .ghostnet import * 17 | from .gluon_resnet import * 18 | from .gluon_xception import * 19 | from .hardcorenas import * 20 | from .hrnet import * 21 | from .inception_resnet_v2 import * 22 | from .inception_v3 import * 23 | from .inception_v4 import * 24 | from .levit import * 25 | from .mlp_mixer import * 26 | from .mobilenetv3 import * 27 | from .mobilevit import * 28 | from .nasnet import * 29 | from .nest import * 30 | from .nfnet import * 31 | from .pit import * 32 | from .pnasnet import * 33 | from .poolformer import * 34 | from .regnet import * 35 | from .res2net import * 36 | from .resnest import * 37 | from .resnet import * 38 | from .resnetv2 import * 39 | from .rexnet import * 40 | from .selecsls import * 41 | from .senet import * 42 | from .sknet import * 43 | from .swin_transformer import * 44 | from .swin_transformer_v2_cr import * 45 | from .tnt import * 46 | from .tresnet import * 47 | from .tresnet_v2 import * 48 | from .twins import * 49 | from .vgg import * 50 | from .volo import * 51 | from .visformer 
import * 52 | from .vision_transformer import * 53 | from .vision_transformer_hybrid import * 54 | from .volo import * 55 | from .vovnet import * 56 | from .xception import * 57 | from .xception_aligned import * 58 | from .xcit import * 59 | 60 | from .factory import create_model, parse_model_name, safe_model_name 61 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters 62 | from .layers import TestTimePoolHead, apply_test_time_pool 63 | from .layers import convert_splitbn_model 64 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 65 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ 66 | is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value 67 | -------------------------------------------------------------------------------- /models/convmixer.py: -------------------------------------------------------------------------------- 1 | """ ConvMixer 2 | 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | from models.registry import register_model 9 | from .helpers import build_model_with_cfg, checkpoint_seq 10 | from .layers import SelectAdaptivePool2d 11 | 12 | 13 | def _cfg(url='', **kwargs): 14 | return { 15 | 'url': url, 16 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 17 | 'crop_pct': .96, 'interpolation': 'bicubic', 18 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', 19 | 'first_conv': 'stem.0', 20 | **kwargs 21 | } 22 | 23 | 24 | default_cfgs = { 25 | 'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'), 26 | 'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'), 27 | 
'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar') 28 | } 29 | 30 | 31 | class Residual(nn.Module): 32 | def __init__(self, fn): 33 | super().__init__() 34 | self.fn = fn 35 | 36 | def forward(self, x): 37 | return self.fn(x) + x 38 | 39 | 40 | class ConvMixer(nn.Module): 41 | def __init__( 42 | self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg', 43 | act_layer=nn.GELU, **kwargs): 44 | super().__init__() 45 | self.num_classes = num_classes 46 | self.num_features = dim 47 | self.grad_checkpointing = False 48 | 49 | self.stem = nn.Sequential( 50 | nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size), 51 | act_layer(), 52 | nn.BatchNorm2d(dim) 53 | ) 54 | self.blocks = nn.Sequential( 55 | *[nn.Sequential( 56 | Residual(nn.Sequential( 57 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), 58 | act_layer(), 59 | nn.BatchNorm2d(dim) 60 | )), 61 | nn.Conv2d(dim, dim, kernel_size=1), 62 | act_layer(), 63 | nn.BatchNorm2d(dim) 64 | ) for i in range(depth)] 65 | ) 66 | self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) 67 | self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() 68 | 69 | @torch.jit.ignore 70 | def group_matcher(self, coarse=False): 71 | matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)') 72 | return matcher 73 | 74 | @torch.jit.ignore 75 | def set_grad_checkpointing(self, enable=True): 76 | self.grad_checkpointing = enable 77 | 78 | @torch.jit.ignore 79 | def get_classifier(self): 80 | return self.head 81 | 82 | def reset_classifier(self, num_classes, global_pool=None): 83 | self.num_classes = num_classes 84 | if global_pool is not None: 85 | self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) 86 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 87 | 88 | def forward_features(self, x): 89 
| x = self.stem(x) 90 | if self.grad_checkpointing and not torch.jit.is_scripting(): 91 | x = checkpoint_seq(self.blocks, x) 92 | else: 93 | x = self.blocks(x) 94 | return x 95 | 96 | def forward_head(self, x, pre_logits: bool = False): 97 | x = self.pooling(x) 98 | return x if pre_logits else self.head(x) 99 | 100 | def forward(self, x): 101 | x = self.forward_features(x) 102 | x = self.forward_head(x) 103 | return x 104 | 105 | 106 | def _create_convmixer(variant, pretrained=False, **kwargs): 107 | return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs) 108 | 109 | 110 | @register_model 111 | def convmixer_1536_20(pretrained=False, **kwargs): 112 | model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs) 113 | return _create_convmixer('convmixer_1536_20', pretrained, **model_args) 114 | 115 | 116 | @register_model 117 | def convmixer_768_32(pretrained=False, **kwargs): 118 | model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs) 119 | return _create_convmixer('convmixer_768_32', pretrained, **model_args) 120 | 121 | 122 | @register_model 123 | def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs): 124 | model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs) 125 | return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args) -------------------------------------------------------------------------------- /models/factory.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlsplit, urlunsplit 2 | import os 3 | 4 | from .registry import is_model, is_model_in_modules, model_entrypoint 5 | from .helpers import load_checkpoint 6 | from .layers import set_layer_config 7 | from .hub import load_model_config_from_hf 8 | 9 | 10 | def parse_model_name(model_name): 11 | model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use 12 | parsed = 
urlsplit(model_name) 13 | assert parsed.scheme in ('', 'timm', 'hf-hub') 14 | if parsed.scheme == 'hf-hub': 15 | # FIXME may use fragment as revision, currently `@` in URI path 16 | return parsed.scheme, parsed.path 17 | else: 18 | model_name = os.path.split(parsed.path)[-1] 19 | return 'timm', model_name 20 | 21 | 22 | def safe_model_name(model_name, remove_source=True): 23 | def make_safe(name): 24 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') 25 | if remove_source: 26 | model_name = parse_model_name(model_name)[-1] 27 | return make_safe(model_name) 28 | 29 | 30 | def create_model( 31 | model_name, 32 | pretrained=False, 33 | pretrained_cfg=None, 34 | checkpoint_path='', 35 | scriptable=None, 36 | exportable=None, 37 | no_jit=None, 38 | **kwargs): 39 | """Create a model 40 | 41 | Args: 42 | model_name (str): name of model to instantiate 43 | pretrained (bool): load pretrained ImageNet-1k weights if true 44 | checkpoint_path (str): path of checkpoint to load after model is initialized 45 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 46 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 47 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 48 | 49 | Keyword Args: 50 | drop_rate (float): dropout rate for training (default: 0.0) 51 | global_pool (str): global pool type (default: 'avg') 52 | **: other kwargs are model specific 53 | """ 54 | # Parameters that aren't supported by all models or are intended to only override model defaults if set 55 | # should default to None in command line args/cfg. Remove them if they are present and not set so that 56 | # non-supporting models don't break and default args remain in effect. 
57 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 58 | 59 | model_source, model_name = parse_model_name(model_name) 60 | if model_source == 'hf-hub': 61 | # FIXME hf-hub source overrides any passed in pretrained_cfg, warn? 62 | # For model names specified in the form `hf-hub:path/architecture_name@revision`, 63 | # load model weights + pretrained_cfg from Hugging Face hub. 64 | pretrained_cfg, model_name = load_model_config_from_hf(model_name) 65 | 66 | if not is_model(model_name): 67 | raise RuntimeError('Unknown model (%s)' % model_name) 68 | 69 | create_fn = model_entrypoint(model_name) 70 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 71 | model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs) 72 | 73 | if checkpoint_path: 74 | load_checkpoint(model, checkpoint_path) 75 | 76 | return model 77 | -------------------------------------------------------------------------------- /models/fx_features.py: -------------------------------------------------------------------------------- 1 | """ PyTorch FX Based Feature Extraction Helpers 2 | Using https://pytorch.org/vision/stable/feature_extraction.html 3 | """ 4 | from typing import Callable, List, Dict, Union 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from .features import _get_feature_info 10 | 11 | try: 12 | from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor 13 | has_fx_feature_extraction = True 14 | except ImportError: 15 | has_fx_feature_extraction = False 16 | 17 | # Layers we went to treat as leaf modules 18 | from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame 19 | from .layers.non_local_attn import BilinearAttnTransform 20 | from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame 21 | 22 | # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here 23 | # BUT modules from timm.models should use 
# the registration mechanism below
_leaf_modules = {
    BilinearAttnTransform,  # reason: flow control t <= 1
    # Reason: get_same_padding has a max which raises a control flow error
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
}

try:
    # InplaceAbn is optional; only register it as a leaf when it can be imported.
    from .layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    pass


def register_notrace_module(module: nn.Module):
    """ Decorator that registers `module` as an FX leaf so tracing does not descend into it.

    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
    The decorated object is returned unchanged.
    """
    _leaf_modules.add(module)
    return module


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """ Decorator for functions which ought not to be traced through.

    The function is added to the autowrap set passed to the FX tracer and returned unchanged.
    """
    _autowrap_functions.add(func)
    return func


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    """ Build a torchvision FX feature extractor using the leaf modules / autowrap functions registered here.

    Args:
        model: model to trace
        return_nodes: node names to return (list) or mapping of node name -> output key (dict)
    """
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _create_feature_extractor(
        model, return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
    )


class FeatureGraphNet(nn.Module):
    """ A FX Graph based feature extractor that works with the model feature_info metadata

    Args:
        model: model to extract features from
        out_indices: indices into the model's feature_info selecting which features to return
        out_map: optional output names, one per entry of out_indices (same order)
    """
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        # Entries of feature_info selected by out_indices, in feature_info order.
        selected = [info for i, info in enumerate(self.feature_info) if i in out_indices]
        # out_map is aligned with out_indices (asserted above), so index it by position within the
        # selection rather than by the global feature index, which can exceed len(out_map).
        return_nodes = {
            info['module']: out_map[j] if out_map is not None else info['module']
            for j, info in enumerate(selected)}
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        return list(self.graph_module(x).values())


class GraphExtractNet(nn.Module):
    """ A standalone feature extraction wrapper that maps dict -> list or single tensor
    NOTE:
      * one can use feature_extractor directly if dictionary output is desired
      * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
      metadata for builtin feature extraction mode
      * create_feature_extractor can be used directly if dictionary output is acceptable

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: if only one output, and output in list format, flatten to single tensor
    """
    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
        super().__init__()
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        out = list(self.graph_module(x).values())
        if self.squeeze_out and len(out) == 1:
            return out[0]
        return out

# -------------------------------------------------------------------------------- /models/hub.py: --------------------------------------------------------------------------------
import json
import logging
import os
from functools import partial
from pathlib import Path
from typing import Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try:
    from torch.hub import get_dir
except ImportError:
    from torch.hub import _get_torch_home as get_dir

from timm import __version__
try:
    from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
    cached_download = partial(cached_download, library_name="timm", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    cached_download = None
    _has_hf_hub = False

_logger = logging.getLogger(__name__)


def get_cache_dir(child_dir=''):
    """
    Returns the location of the directory where models are cached (and creates it if necessary).
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    child_dir = () if not child_dir else (child_dir,)
    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def download_cached_file(url, check_hash=True, progress=False):
    """ Download `url` into the cache dir (if not already present) and return the local path.

    When check_hash is set, a hash prefix embedded in the filename (per torch.hub HASH_REGEX) is verified.
    """
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file


def has_hf_hub(necessary=False):
    """ Return whether huggingface_hub is importable; raise RuntimeError if it is `necessary` and missing. """
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def hf_split(hf_id):
    """ Split 'model_id@revision' into (model_id, revision); revision is None when absent. """
    # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
    rev_split = hf_id.split('@')
    assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
    hf_model_id = rev_split[0]
    hf_revision = rev_split[-1] if len(rev_split) > 1 else None
    return hf_model_id, hf_revision


def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    with open(json_file, "r", encoding="utf-8") as reader:
        text = reader.read()
    return json.loads(text)


def _download_from_hf(model_id: str, filename: str):
    hf_model_id, hf_revision = hf_split(model_id)
    url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
    return cached_download(url, cache_dir=get_cache_dir('hf'))


def load_model_config_from_hf(model_id: str):
    """ Fetch config.json for `model_id` from the HF hub; return (pretrained_cfg, architecture name). """
    assert has_hf_hub(True)
    cached_file = _download_from_hf(model_id, 'config.json')
    pretrained_cfg = load_cfg_from_json(cached_file)
    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
    pretrained_cfg['source'] = 'hf-hub'
    model_name = pretrained_cfg.get('architecture')
    return pretrained_cfg, model_name


def load_state_dict_from_hf(model_id: str):
    assert has_hf_hub(True)
    cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
    state_dict = torch.load(cached_file, map_location='cpu')
    return state_dict


def save_for_hf(model, save_directory, model_config=None):
    """ Write `pytorch_model.bin` and `config.json` for `model` into `save_directory`.

    Args:
        model: model exposing state_dict(), pretrained_cfg, num_classes and num_features
        save_directory: output directory (created if needed)
        model_config: optional overrides merged into the written config; not modified by this call
    """
    assert has_hf_hub(True)
    # Copy so the pops below do not mutate the caller's dict.
    model_config = dict(model_config or {})
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    weights_path = save_directory / 'pytorch_model.bin'
    torch.save(model.state_dict(), weights_path)

    config_path = save_directory / 'config.json'
    # Copy so saving does not modify the model's own pretrained_cfg as a side effect.
    hf_config = dict(model.pretrained_cfg)
    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
    hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
    hf_config.update(model_config)

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)


def push_to_hf_hub(
        model,
        local_dir,
        repo_namespace_or_url=None,
        commit_message='Add model',
        use_auth_token=True,
        git_email=None,
        git_user=None,
        revision=None,
        model_config=None,
):
    """ Save `model` into a local clone of its HF hub repo, commit, push, and return the remote url. """
    # Fail early with a clear message instead of a NameError on HfFolder/Repository.
    assert has_hf_hub(True)
    if repo_namespace_or_url:
        repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
    else:
        if isinstance(use_auth_token, str):
            token = use_auth_token
        else:
            token = HfFolder.get_token()

        if token is None:
            raise ValueError(
                "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
                "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
                "token as the `use_auth_token` argument."
            )

        repo_owner = HfApi().whoami(token)['name']
        repo_name = Path(local_dir).name

    repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'

    repo = Repository(
        local_dir,
        clone_from=repo_url,
        use_auth_token=use_auth_token,
        git_user=git_user,
        git_email=git_email,
        revision=revision,
    )

    # Prepare a default model card that includes the necessary tags to enable inference.
    readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
    with repo.commit(commit_message):
        # Save model weights and config.
        save_for_hf(model, repo.local_dir, model_config=model_config)

        # Save a model card if it doesn't exist.
        readme_path = Path(repo.local_dir) / 'README.md'
        if not readme_path.exists():
            readme_path.write_text(readme_text)

    return repo.git_remote_url()

# -------------------------------------------------------------------------------- /models/layers/__init__.py: --------------------------------------------------------------------------------
from .activations import *
from .adaptive_avgmax_pool import \
    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
    set_layer_config
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_

# -------------------------------------------------------------------------------- /models/layers/activations.py: --------------------------------------------------------------------------------
""" Activations

A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


def swish(x, inplace: bool = False):
    """Swish - Described in: https://arxiv.org/abs/1710.05941
    """
    return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())


class Swish(nn.Module):
    """Module wrapper for `swish`; honours `inplace`."""
    def __init__(self, inplace: bool = False):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return swish(x, self.inplace)


def mish(x, inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    NOTE: I don't have a working inplace variant
    """
    return x.mul(F.softplus(x).tanh())


class Mish(nn.Module):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    The `inplace` arg is accepted for interface consistency but ignored.
    """
    def __init__(self, inplace: bool = False):
        super(Mish, self).__init__()

    def forward(self, x):
        return mish(x)


def sigmoid(x, inplace: bool = False):
    return x.sigmoid_() if inplace else x.sigmoid()


# PyTorch has this, but not with a consistent inplace argument interface
class Sigmoid(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Sigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x.sigmoid_() if self.inplace else x.sigmoid()


def tanh(x, inplace: bool = False):
    return x.tanh_() if inplace else x.tanh()


# PyTorch has this, but not with a consistent inplace argument interface
class Tanh(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Tanh, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x.tanh_() if self.inplace else x.tanh()


def hard_swish(x, inplace: bool = False):
    """Hard swish: x * relu6(x + 3) / 6."""
    inner = F.relu6(x + 3.).div_(6.)  # in-place div is safe: relu6 returned a fresh tensor
    return x.mul_(inner) if inplace else x.mul(inner)


class HardSwish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, self.inplace)


def hard_sigmoid(x, inplace: bool = False):
    """Hard sigmoid: relu6(x + 3) / 6 (clamp form when inplace)."""
    if inplace:
        return x.add_(3.).clamp_(0., 6.).div_(6.)
    else:
        return F.relu6(x + 3.) / 6.


class HardSigmoid(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_sigmoid(x, self.inplace)


def hard_mish(x, inplace: bool = False):
    """ Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    if inplace:
        return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))
    else:
        return 0.5 * x * (x + 2).clamp(min=0, max=2)


class HardMish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_mish(x, self.inplace)


class PReLU(nn.PReLU):
    """Applies PReLU (w/ dummy inplace arg)
    """
    def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
        super(PReLU, self).__init__(num_parameters=num_parameters, init=init)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.prelu(input, self.weight)


def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    return F.gelu(x)


class GELU(nn.Module):
    """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
    """
    def __init__(self, inplace: bool = False):
        super(GELU, self).__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.gelu(input)

# -------------------------------------------------------------------------------- /models/layers/activations_jit.py: --------------------------------------------------------------------------------
""" Activations

A collection of jit-scripted activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
versions if they contain in-place ops.

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


@torch.jit.script
def swish_jit(x, inplace: bool = False):
    """Swish - Described in: https://arxiv.org/abs/1710.05941
    """
    return x.mul(x.sigmoid())


@torch.jit.script
def mish_jit(x, _inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    """
    return x.mul(F.softplus(x).tanh())


class SwishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishJit, self).__init__()

    def forward(self, x):
        return swish_jit(x)


class MishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishJit, self).__init__()

    def forward(self, x):
        return mish_jit(x)


@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
    # return F.relu6(x + 3.) / 6.
    return (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?
class HardSigmoidJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoidJit, self).__init__()

    def forward(self, x):
        return hard_sigmoid_jit(x)


@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
    # return x * (F.relu6(x + 3.) / 6)
    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?


class HardSwishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwishJit, self).__init__()

    def forward(self, x):
        return hard_swish_jit(x)


@torch.jit.script
def hard_mish_jit(x, inplace: bool = False):
    """ Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


class HardMishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMishJit, self).__init__()

    def forward(self, x):
        return hard_mish_jit(x)

# -------------------------------------------------------------------------------- /models/layers/activations_me.py: --------------------------------------------------------------------------------
""" Activations (memory-efficient w/ custom autograd)

A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

These activations are not compatible with jit scripting or ONNX export of the model, please use either
the JIT or basic versions of the activations.

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


@torch.jit.script
def swish_jit_fwd(x):
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
    Inspired by conversation btw Jeremy Howard & Adam Pazske
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """
    @staticmethod
    def symbolic(g, x):
        return g.op("Mul", x, g.op("Sigmoid", x))

    @staticmethod
    def forward(ctx, x):
        # Only the input is saved; the activation is recomputed in backward to save memory.
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)


def swish_me(x, inplace=False):
    return SwishJitAutoFn.apply(x)


class SwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishMe, self).__init__()

    def forward(self, x):
        return SwishJitAutoFn.apply(x)


@torch.jit.script
def mish_jit_fwd(x):
    return x.mul(torch.tanh(F.softplus(x)))


@torch.jit.script
def mish_jit_bwd(x, grad_output):
    # d/dx [x * tanh(softplus(x))] = t + x * sigmoid(x) * (1 - t^2), with t = tanh(softplus(x))
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    A memory efficient, jit scripted variant of Mish
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return mish_jit_bwd(x, grad_output)


def mish_me(x, inplace=False):
    return MishJitAutoFn.apply(x)


class MishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishMe, self).__init__()

    def forward(self, x):
        return MishJitAutoFn.apply(x)


@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
    return (x + 3).clamp(min=0, max=6).div(6.)


@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
    # Gradient is 1/6 inside [-3, 3] and 0 elsewhere.
    m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
    return grad_output * m


class HardSigmoidJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_sigmoid_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_sigmoid_jit_bwd(x, grad_output)


def hard_sigmoid_me(x, inplace: bool = False):
    return HardSigmoidJitAutoFn.apply(x)


class HardSigmoidMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoidMe, self).__init__()

    def forward(self, x):
        return HardSigmoidJitAutoFn.apply(x)


@torch.jit.script
def hard_swish_jit_fwd(x):
    return x * (x + 3).clamp(min=0, max=6).div(6.)


@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
    # Gradient: 0 below -3, x / 3 + 0.5 inside [-3, 3], 1 above 3.
    m = torch.ones_like(x) * (x >= 3.)
    m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
    return grad_output * m


class HardSwishJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit-scripted HardSwish activation"""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_swish_jit_bwd(x, grad_output)

    @staticmethod
    def symbolic(g, self):
        # ONNX graph for x * clip(x + 3, 0, 6) / 6
        input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
        hardtanh_ = g.op(
            "Clip", input,
            g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)),
            g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
        hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
        return g.op("Mul", self, hardtanh_)


def hard_swish_me(x, inplace=False):
    return HardSwishJitAutoFn.apply(x)


class HardSwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwishMe, self).__init__()

    def forward(self, x):
        return HardSwishJitAutoFn.apply(x)


@torch.jit.script
def hard_mish_jit_fwd(x):
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
    # Gradient: 0 below -2, x + 1 inside [-2, 0], 1 above 0.
    m = torch.ones_like(x) * (x >= -2.)
    m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
    return grad_output * m


class HardMishJitAutoFn(torch.autograd.Function):
    """ A memory efficient, jit scripted variant of Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_mish_jit_bwd(x, grad_output)


def hard_mish_me(x, inplace: bool = False):
    return HardMishJitAutoFn.apply(x)


class HardMishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMishMe, self).__init__()

    def forward(self, x):
        return HardMishJitAutoFn.apply(x)

# -------------------------------------------------------------------------------- /models/layers/adaptive_avgmax_pool.py: --------------------------------------------------------------------------------
""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
    * 'avg' - Average pooling
    * 'max' - Max pooling
    * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
    * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim
    * 'fast' - (module only) global average pooling via mean, output_size must be 1

Both a functional and a nn.Module version of the pooling are provided.

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def adaptive_pool_feat_mult(pool_type='avg'):
    """Return the feature-dim multiplier of a pool type (2 for 'catavgmax', else 1)."""
    if pool_type == 'catavgmax':
        return 2
    else:
        return 1


def adaptive_avgmax_pool2d(x, output_size=1):
    x_avg = F.adaptive_avg_pool2d(x, output_size)
    x_max = F.adaptive_max_pool2d(x, output_size)
    return 0.5 * (x_avg + x_max)


def adaptive_catavgmax_pool2d(x, output_size=1):
    x_avg = F.adaptive_avg_pool2d(x, output_size)
    x_max = F.adaptive_max_pool2d(x, output_size)
    return torch.cat((x_avg, x_max), 1)


def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
    """Selectable global pooling function with dynamic input kernel size
    """
    if pool_type == 'avg':
        x = F.adaptive_avg_pool2d(x, output_size)
    elif pool_type == 'avgmax':
        x = adaptive_avgmax_pool2d(x, output_size)
    elif pool_type == 'catavgmax':
        x = adaptive_catavgmax_pool2d(x, output_size)
    elif pool_type == 'max':
        x = F.adaptive_max_pool2d(x, output_size)
    else:
        assert False, 'Invalid pool type: %s' % pool_type
    return x


class FastAdaptiveAvgPool2d(nn.Module):
    """Global average pool over dims (2, 3) via mean; keeps dims unless `flatten`."""
    def __init__(self, flatten=False):
        super(FastAdaptiveAvgPool2d, self).__init__()
        self.flatten = flatten

    def forward(self, x):
        return x.mean((2, 3), keepdim=not self.flatten)


class AdaptiveAvgMaxPool2d(nn.Module):
    def __init__(self, output_size=1):
        super(AdaptiveAvgMaxPool2d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_avgmax_pool2d(x, self.output_size)


class AdaptiveCatAvgMaxPool2d(nn.Module):
    def __init__(self, output_size=1):
        super(AdaptiveCatAvgMaxPool2d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_catavgmax_pool2d(x, self.output_size)


class SelectAdaptivePool2d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size

    Args:
        output_size: target spatial size of the pooled output
        pool_type: one of '', 'fast', 'avg', 'avgmax', 'catavgmax', 'max'; any falsy value means no pooling
        flatten: flatten the output from dim 1 onwards
    """
    def __init__(self, output_size=1, pool_type='fast', flatten=False):
        super(SelectAdaptivePool2d, self).__init__()
        self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        # Branch on the normalized value so falsy inputs such as None select pass-through
        # instead of falling into the invalid-type assert.
        if self.pool_type == '':
            self.pool = nn.Identity()  # pass through
        elif self.pool_type == 'fast':
            assert output_size == 1
            self.pool = FastAdaptiveAvgPool2d(flatten)
            self.flatten = nn.Identity()  # FastAdaptiveAvgPool2d already handles flattening
        elif self.pool_type == 'avg':
            self.pool = nn.AdaptiveAvgPool2d(output_size)
        elif self.pool_type == 'avgmax':
            self.pool = AdaptiveAvgMaxPool2d(output_size)
        elif self.pool_type == 'catavgmax':
            self.pool = AdaptiveCatAvgMaxPool2d(output_size)
        elif self.pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool2d(output_size)
        else:
            assert False, 'Invalid pool type: %s' % pool_type

    def is_identity(self):
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        return x

    def feat_mult(self):
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return f'{self.__class__.__name__} (pool_type={self.pool_type}, flatten={self.flatten})'

# -------------------------------------------------------------------------------- /models/layers/attention_pool2d.py: --------------------------------------------------------------------------------
""" Attention Pool 2D

Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.

Based on idea in CLIP by OpenAI, licensed Apache 2.0
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Union, Tuple

import torch
import torch.nn as nn

from .helpers import to_2tuple
from .pos_embed import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_


class RotAttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
    train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """
    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        # qkv has no bias tensor when qkv_bias=False.
        if self.qkv.bias is not None:
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        x = x.reshape(B, -1, N).permute(0, 2, 1)

        # Prepend the mean token; its output is the pooled result.
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]

        # Rotary embedding is applied to the spatial tokens only, not the leading mean token.
        qc, q = q[:, :, :1], q[:, :, 1:]
        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]


class AttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    It was based on impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
    """
    def __init__(
            self,
            in_features: int,
            feat_size: Union[int, Tuple[int, int]],
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()

        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features ** -0.5)
        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        # qkv has no bias tensor when qkv_bias=False.
        if self.qkv.bias is not None:
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]

# -------------------------------------------------------------------------------- /models/layers/blur_pool.py:
-------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that computes blurs and downsample a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling 20 | 21 | Args: 22 | channels = Number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor. 28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) 43 | -------------------------------------------------------------------------------- /models/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block 
Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman 9 | """ 10 | import torch 11 | from torch import nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .conv_bn_act import ConvNormAct 15 | from .create_act import create_act_layer, get_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class ChannelAttn(nn.Module): 20 | """ Original CBAM channel attention module, currently avg + max pool variant only. 21 | """ 22 | def __init__( 23 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 24 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 25 | super(ChannelAttn, self).__init__() 26 | if not rd_channels: 27 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 28 | self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) 29 | self.act = act_layer(inplace=True) 30 | self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) 31 | self.gate = create_act_layer(gate_layer) 32 | 33 | def forward(self, x): 34 | x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) 35 | x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) 36 | return x * self.gate(x_avg + x_max) 37 | 38 | 39 | class LightChannelAttn(ChannelAttn): 40 | """An experimental 'lightweight' that sums avg + max pool first 41 | """ 42 | def __init__( 43 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 44 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 45 | super(LightChannelAttn, self).__init__( 46 | channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) 47 | 48 | def forward(self, x): 49 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) 50 | x_attn = 
self.fc2(self.act(self.fc1(x_pool))) 51 | return x * F.sigmoid(x_attn) 52 | 53 | 54 | class SpatialAttn(nn.Module): 55 | """ Original CBAM spatial attention module 56 | """ 57 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 58 | super(SpatialAttn, self).__init__() 59 | self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) 60 | self.gate = create_act_layer(gate_layer) 61 | 62 | def forward(self, x): 63 | x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) 64 | x_attn = self.conv(x_attn) 65 | return x * self.gate(x_attn) 66 | 67 | 68 | class LightSpatialAttn(nn.Module): 69 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 70 | """ 71 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 72 | super(LightSpatialAttn, self).__init__() 73 | self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) 74 | self.gate = create_act_layer(gate_layer) 75 | 76 | def forward(self, x): 77 | x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) 78 | x_attn = self.conv(x_attn) 79 | return x * self.gate(x_attn) 80 | 81 | 82 | class CbamModule(nn.Module): 83 | def __init__( 84 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 85 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 86 | super(CbamModule, self).__init__() 87 | self.channel = ChannelAttn( 88 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 89 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 90 | self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) 91 | 92 | def forward(self, x): 93 | x = self.channel(x) 94 | x = self.spatial(x) 95 | return x 96 | 97 | 98 | class LightCbamModule(nn.Module): 99 | def __init__( 100 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 101 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 102 | super(LightCbamModule, 
self).__init__() 103 | self.channel = LightChannelAttn( 104 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 105 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 106 | self.spatial = LightSpatialAttn(spatial_kernel_size) 107 | 108 | def forward(self, x): 109 | x = self.channel(x) 110 | x = self.spatial(x) 111 | return x 112 | 113 | -------------------------------------------------------------------------------- /models/layers/classifier.py: -------------------------------------------------------------------------------- 1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | 10 | 11 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 12 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 13 | if not pool_type: 14 | assert num_classes == 0 or use_conv,\ 15 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 16 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 17 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 18 | num_pooled_features = num_features * global_pool.feat_mult() 19 | return global_pool, num_pooled_features 20 | 21 | 22 | def _create_fc(num_features, num_classes, use_conv=False): 23 | if num_classes <= 0: 24 | fc = nn.Identity() # pass-through (no classifier) 25 | elif use_conv: 26 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 27 | else: 28 | fc = nn.Linear(num_features, num_classes, bias=True) 29 | return fc 30 | 31 | 32 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 33 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 34 | fc = 
_create_fc(num_pooled_features, num_classes, use_conv=use_conv) 35 | return global_pool, fc 36 | 37 | 38 | class ClassifierHead(nn.Module): 39 | """Classifier head w/ configurable global pooling and dropout.""" 40 | 41 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 42 | super(ClassifierHead, self).__init__() 43 | self.drop_rate = drop_rate 44 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 45 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 46 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() 47 | 48 | def forward(self, x, pre_logits: bool = False): 49 | x = self.global_pool(x) 50 | if self.drop_rate: 51 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 52 | if pre_logits: 53 | return x.flatten(1) 54 | else: 55 | x = self.fc(x) 56 | return self.flatten(x) 57 | -------------------------------------------------------------------------------- /models/layers/cond_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Conditionally Parameterized Convolution (CondConv) 2 | 3 | Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference 4 | (https://arxiv.org/abs/1904.04971) 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | 9 | import math 10 | from functools import partial 11 | import numpy as np 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | from .helpers import to_2tuple 17 | from .conv2d_same import conv2d_same 18 | from .padding import get_padding_value 19 | 20 | 21 | def get_condconv_initializer(initializer, num_experts, expert_shape): 22 | def condconv_initializer(weight): 23 | """CondConv initializer function.""" 24 | num_params = np.prod(expert_shape) 25 | if (len(weight.shape) != 2 or weight.shape[0] != num_experts or 26 | weight.shape[1] != 
num_params): 27 | raise (ValueError( 28 | 'CondConv variables must have shape [num_experts, num_params]')) 29 | for i in range(num_experts): 30 | initializer(weight[i].view(expert_shape)) 31 | return condconv_initializer 32 | 33 | 34 | class CondConv2d(nn.Module): 35 | """ Conditionally Parameterized Convolution 36 | Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py 37 | 38 | Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: 39 | https://github.com/pytorch/pytorch/issues/17983 40 | """ 41 | __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] 42 | 43 | def __init__(self, in_channels, out_channels, kernel_size=3, 44 | stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): 45 | super(CondConv2d, self).__init__() 46 | 47 | self.in_channels = in_channels 48 | self.out_channels = out_channels 49 | self.kernel_size = to_2tuple(kernel_size) 50 | self.stride = to_2tuple(stride) 51 | padding_val, is_padding_dynamic = get_padding_value( 52 | padding, kernel_size, stride=stride, dilation=dilation) 53 | self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript 54 | self.padding = to_2tuple(padding_val) 55 | self.dilation = to_2tuple(dilation) 56 | self.groups = groups 57 | self.num_experts = num_experts 58 | 59 | self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size 60 | weight_num_param = 1 61 | for wd in self.weight_shape: 62 | weight_num_param *= wd 63 | self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) 64 | 65 | if bias: 66 | self.bias_shape = (self.out_channels,) 67 | self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) 68 | else: 69 | self.register_parameter('bias', None) 70 | 71 | self.reset_parameters() 72 | 73 | def reset_parameters(self): 74 | init_weight = get_condconv_initializer( 75 | 
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) 76 | init_weight(self.weight) 77 | if self.bias is not None: 78 | fan_in = np.prod(self.weight_shape[1:]) 79 | bound = 1 / math.sqrt(fan_in) 80 | init_bias = get_condconv_initializer( 81 | partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) 82 | init_bias(self.bias) 83 | 84 | def forward(self, x, routing_weights): 85 | B, C, H, W = x.shape 86 | weight = torch.matmul(routing_weights, self.weight) 87 | new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size 88 | weight = weight.view(new_weight_shape) 89 | bias = None 90 | if self.bias is not None: 91 | bias = torch.matmul(routing_weights, self.bias) 92 | bias = bias.view(B * self.out_channels) 93 | # move batch elements with channels so each batch element can be efficiently convolved with separate kernel 94 | x = x.view(1, B * C, H, W) 95 | if self.dynamic_padding: 96 | out = conv2d_same( 97 | x, weight, bias, stride=self.stride, padding=self.padding, 98 | dilation=self.dilation, groups=self.groups * B) 99 | else: 100 | out = F.conv2d( 101 | x, weight, bias, stride=self.stride, padding=self.padding, 102 | dilation=self.dilation, groups=self.groups * B) 103 | out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) 104 | 105 | # Literal port (from TF definition) 106 | # x = torch.split(x, 1, 0) 107 | # weight = torch.split(weight, 1, 0) 108 | # if self.bias is not None: 109 | # bias = torch.matmul(routing_weights, self.bias) 110 | # bias = torch.split(bias, 1, 0) 111 | # else: 112 | # bias = [None] * B 113 | # out = [] 114 | # for xi, wi, bi in zip(x, weight, bias): 115 | # wi = wi.view(*self.weight_shape) 116 | # if bi is not None: 117 | # bi = bi.view(*self.bias_shape) 118 | # out.append(self.conv_fn( 119 | # xi, wi, bi, stride=self.stride, padding=self.padding, 120 | # dilation=self.dilation, groups=self.groups)) 121 | # out = 
torch.cat(out, 0) 122 | return out 123 | -------------------------------------------------------------------------------- /models/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | 
global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, 
int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /models/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .create_conv2d import create_conv2d 8 | from .create_norm_act import get_norm_act_layer 9 | 10 | 11 | class ConvNormAct(nn.Module): 12 | def __init__( 13 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 14 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): 15 | super(ConvNormAct, self).__init__() 16 | self.conv = create_conv2d( 17 | in_channels, out_channels, kernel_size, stride=stride, 18 | padding=padding, dilation=dilation, groups=groups, 
bias=bias) 19 | 20 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 21 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 22 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 23 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 24 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 25 | 26 | @property 27 | def in_channels(self): 28 | return self.conv.in_channels 29 | 30 | @property 31 | def out_channels(self): 32 | return self.conv.out_channels 33 | 34 | def forward(self, x): 35 | x = self.conv(x) 36 | x = self.bn(x) 37 | return x 38 | 39 | 40 | ConvBnAct = ConvNormAct 41 | 42 | 43 | class ConvNormActAa(nn.Module): 44 | def __init__( 45 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 46 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): 47 | super(ConvNormActAa, self).__init__() 48 | use_aa = aa_layer is not None 49 | 50 | self.conv = create_conv2d( 51 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 52 | padding=padding, dilation=dilation, groups=groups, bias=bias) 53 | 54 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 55 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 56 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 57 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 58 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 59 | self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity() 60 | 61 | @property 62 | def in_channels(self): 63 | return self.conv.in_channels 64 | 65 | @property 66 | def out_channels(self): 67 | return self.conv.out_channels 68 | 69 | def forward(self, x): 70 | x = self.conv(x) 71 | x = self.bn(x) 72 | x 
= self.aa(x) 73 | return x 74 | -------------------------------------------------------------------------------- /models/layers/create_act.py: -------------------------------------------------------------------------------- 1 | """ Activation Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from typing import Union, Callable, Type 5 | 6 | from .activations import * 7 | from .activations_jit import * 8 | from .activations_me import * 9 | from .config import is_exportable, is_scriptable, is_no_jit 10 | 11 | # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. 12 | # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. 13 | # Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. 14 | _has_silu = 'silu' in dir(torch.nn.functional) 15 | _has_hardswish = 'hardswish' in dir(torch.nn.functional) 16 | _has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) 17 | _has_mish = 'mish' in dir(torch.nn.functional) 18 | 19 | 20 | _ACT_FN_DEFAULT = dict( 21 | silu=F.silu if _has_silu else swish, 22 | swish=F.silu if _has_silu else swish, 23 | mish=F.mish if _has_mish else mish, 24 | relu=F.relu, 25 | relu6=F.relu6, 26 | leaky_relu=F.leaky_relu, 27 | elu=F.elu, 28 | celu=F.celu, 29 | selu=F.selu, 30 | gelu=gelu, 31 | sigmoid=sigmoid, 32 | tanh=tanh, 33 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, 34 | hard_swish=F.hardswish if _has_hardswish else hard_swish, 35 | hard_mish=hard_mish, 36 | ) 37 | 38 | _ACT_FN_JIT = dict( 39 | silu=F.silu if _has_silu else swish_jit, 40 | swish=F.silu if _has_silu else swish_jit, 41 | mish=F.mish if _has_mish else mish_jit, 42 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, 43 | hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, 44 | hard_mish=hard_mish_jit 45 | ) 46 | 47 | _ACT_FN_ME = dict( 48 | silu=F.silu if _has_silu else swish_me, 49 | 
swish=F.silu if _has_silu else swish_me, 50 | mish=F.mish if _has_mish else mish_me, 51 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, 52 | hard_swish=F.hardswish if _has_hardswish else hard_swish_me, 53 | hard_mish=hard_mish_me, 54 | ) 55 | 56 | _ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) 57 | for a in _ACT_FNS: 58 | a.setdefault('hardsigmoid', a.get('hard_sigmoid')) 59 | a.setdefault('hardswish', a.get('hard_swish')) 60 | 61 | 62 | _ACT_LAYER_DEFAULT = dict( 63 | silu=nn.SiLU if _has_silu else Swish, 64 | swish=nn.SiLU if _has_silu else Swish, 65 | mish=nn.Mish if _has_mish else Mish, 66 | relu=nn.ReLU, 67 | relu6=nn.ReLU6, 68 | leaky_relu=nn.LeakyReLU, 69 | elu=nn.ELU, 70 | prelu=PReLU, 71 | celu=nn.CELU, 72 | selu=nn.SELU, 73 | gelu=GELU, 74 | sigmoid=Sigmoid, 75 | tanh=Tanh, 76 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, 77 | hard_swish=nn.Hardswish if _has_hardswish else HardSwish, 78 | hard_mish=HardMish, 79 | ) 80 | 81 | _ACT_LAYER_JIT = dict( 82 | silu=nn.SiLU if _has_silu else SwishJit, 83 | swish=nn.SiLU if _has_silu else SwishJit, 84 | mish=nn.Mish if _has_mish else MishJit, 85 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, 86 | hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, 87 | hard_mish=HardMishJit 88 | ) 89 | 90 | _ACT_LAYER_ME = dict( 91 | silu=nn.SiLU if _has_silu else SwishMe, 92 | swish=nn.SiLU if _has_silu else SwishMe, 93 | mish=nn.Mish if _has_mish else MishMe, 94 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, 95 | hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, 96 | hard_mish=HardMishMe, 97 | ) 98 | 99 | _ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) 100 | for a in _ACT_LAYERS: 101 | a.setdefault('hardsigmoid', a.get('hard_sigmoid')) 102 | a.setdefault('hardswish', a.get('hard_swish')) 103 | 104 | 105 | def get_act_fn(name: Union[Callable, str] = 'relu'): 106 | """ Activation Function Factory 
107 | Fetching activation fns by name with this function allows export or torch script friendly 108 | functions to be returned dynamically based on current config. 109 | """ 110 | if not name: 111 | return None 112 | if isinstance(name, Callable): 113 | return name 114 | if not (is_no_jit() or is_exportable() or is_scriptable()): 115 | # If not exporting or scripting the model, first look for a memory-efficient version with 116 | # custom autograd, then fallback 117 | if name in _ACT_FN_ME: 118 | return _ACT_FN_ME[name] 119 | if not (is_no_jit() or is_exportable()): 120 | if name in _ACT_FN_JIT: 121 | return _ACT_FN_JIT[name] 122 | return _ACT_FN_DEFAULT[name] 123 | 124 | 125 | def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): 126 | """ Activation Layer Factory 127 | Fetching activation layers by name with this function allows export or torch script friendly 128 | functions to be returned dynamically based on current config. 129 | """ 130 | if not name: 131 | return None 132 | if not isinstance(name, str): 133 | # callable, module, etc 134 | return name 135 | if not (is_no_jit() or is_exportable() or is_scriptable()): 136 | if name in _ACT_LAYER_ME: 137 | return _ACT_LAYER_ME[name] 138 | if not (is_no_jit() or is_exportable()): 139 | if name in _ACT_LAYER_JIT: 140 | return _ACT_LAYER_JIT[name] 141 | return _ACT_LAYER_DEFAULT[name] 142 | 143 | 144 | def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): 145 | act_layer = get_act_layer(name) 146 | if act_layer is None: 147 | return None 148 | return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) 149 | -------------------------------------------------------------------------------- /models/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Attention Factory 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | import torch 6 | from functools import partial 7 | 8 | from 
.bottleneck_attn import BottleneckAttn 9 | from .cbam import CbamModule, LightCbamModule 10 | from .eca import EcaModule, CecaModule 11 | from .gather_excite import GatherExcite 12 | from .global_context import GlobalContext 13 | from .halo_attn import HaloAttn 14 | from .lambda_layer import LambdaLayer 15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 16 | from .selective_kernel import SelectiveKernel 17 | from .split_attn import SplitAttn 18 | from .squeeze_excite import SEModule, EffectiveSEModule 19 | 20 | 21 | def get_attn(attn_type): 22 | if isinstance(attn_type, torch.nn.Module): 23 | return attn_type 24 | module_cls = None 25 | if attn_type is not None: 26 | if isinstance(attn_type, str): 27 | attn_type = attn_type.lower() 28 | # Lightweight attention modules (channel and/or coarse spatial). 29 | # Typically added to existing network architecture blocks in addition to existing convolutions. 30 | if attn_type == 'se': 31 | module_cls = SEModule 32 | elif attn_type == 'ese': 33 | module_cls = EffectiveSEModule 34 | elif attn_type == 'eca': 35 | module_cls = EcaModule 36 | elif attn_type == 'ecam': 37 | module_cls = partial(EcaModule, use_mlp=True) 38 | elif attn_type == 'ceca': 39 | module_cls = CecaModule 40 | elif attn_type == 'ge': 41 | module_cls = GatherExcite 42 | elif attn_type == 'gc': 43 | module_cls = GlobalContext 44 | elif attn_type == 'gca': 45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) 46 | elif attn_type == 'cbam': 47 | module_cls = CbamModule 48 | elif attn_type == 'lcbam': 49 | module_cls = LightCbamModule 50 | 51 | # Attention / attention-like modules w/ significant params 52 | # Typically replace some of the existing workhorse convs in a network architecture. 53 | # All of these accept a stride argument and can spatially downsample the input. 
54 | elif attn_type == 'sk': 55 | module_cls = SelectiveKernel 56 | elif attn_type == 'splat': 57 | module_cls = SplitAttn 58 | 59 | # Self-attention / attention-like modules w/ significant compute and/or params 60 | # Typically replace some of the existing workhorse convs in a network architecture. 61 | # All of these accept a stride argument and can spatially downsample the input. 62 | elif attn_type == 'lambda': 63 | return LambdaLayer 64 | elif attn_type == 'bottleneck': 65 | return BottleneckAttn 66 | elif attn_type == 'halo': 67 | return HaloAttn 68 | elif attn_type == 'nl': 69 | module_cls = NonLocalAttn 70 | elif attn_type == 'bat': 71 | module_cls = BatNonLocalAttn 72 | 73 | # Woops! 74 | else: 75 | assert False, "Invalid attn module (%s)" % attn_type 76 | elif isinstance(attn_type, bool): 77 | if attn_type: 78 | module_cls = SEModule 79 | else: 80 | module_cls = attn_type 81 | return module_cls 82 | 83 | 84 | def create_attn(attn_type, channels, **kwargs): 85 | module_cls = get_attn(attn_type) 86 | if module_cls is not None: 87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 88 | return module_cls(channels, **kwargs) 89 | return None 90 | -------------------------------------------------------------------------------- /models/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 
16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /models/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalizaiton + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | isntances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 
6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | from .evo_norm import * 13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d 14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d 15 | from .inplace_abn import InplaceAbn 16 | 17 | _NORM_ACT_MAP = dict( 18 | batchnorm=BatchNormAct2d, 19 | batchnorm2d=BatchNormAct2d, 20 | groupnorm=GroupNormAct, 21 | layernorm=LayerNormAct, 22 | layernorm2d=LayerNormAct2d, 23 | evonormb0=EvoNorm2dB0, 24 | evonormb1=EvoNorm2dB1, 25 | evonormb2=EvoNorm2dB2, 26 | evonorms0=EvoNorm2dS0, 27 | evonorms0a=EvoNorm2dS0a, 28 | evonorms1=EvoNorm2dS1, 29 | evonorms1a=EvoNorm2dS1a, 30 | evonorms2=EvoNorm2dS2, 31 | evonorms2a=EvoNorm2dS2a, 32 | frn=FilterResponseNormAct2d, 33 | frntlu=FilterResponseNormTlu2d, 34 | inplaceabn=InplaceAbn, 35 | iabn=InplaceAbn, 36 | ) 37 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} 38 | # has act_layer arg to define act type 39 | _NORM_ACT_REQUIRES_ARG = { 40 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} 41 | 42 | 43 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): 44 | layer = get_norm_act_layer(layer_name, act_layer=act_layer) 45 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 46 | if jit: 47 | layer_instance = torch.jit.script(layer_instance) 48 | return layer_instance 49 | 50 | 51 | def get_norm_act_layer(norm_layer, act_layer=None): 52 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 53 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 54 | norm_act_kwargs = {} 55 | 56 | # unbind partial fn, so args can be rebound later 57 | if isinstance(norm_layer, functools.partial): 58 | norm_act_kwargs.update(norm_layer.keywords) 59 | norm_layer = 
norm_layer.func 60 | 61 | if isinstance(norm_layer, str): 62 | layer_name = norm_layer.replace('_', '').lower().split('-')[0] 63 | norm_act_layer = _NORM_ACT_MAP.get(layer_name, None) 64 | elif norm_layer in _NORM_ACT_TYPES: 65 | norm_act_layer = norm_layer 66 | elif isinstance(norm_layer, types.FunctionType): 67 | # if function type, must be a lambda/fn that creates a norm_act layer 68 | norm_act_layer = norm_layer 69 | else: 70 | type_name = norm_layer.__name__.lower() 71 | if type_name.startswith('batchnorm'): 72 | norm_act_layer = BatchNormAct2d 73 | elif type_name.startswith('groupnorm'): 74 | norm_act_layer = GroupNormAct 75 | elif type_name.startswith('layernorm2d'): 76 | norm_act_layer = LayerNormAct2d 77 | elif type_name.startswith('layernorm'): 78 | norm_act_layer = LayerNormAct 79 | else: 80 | assert False, f"No equivalent norm_act layer for {type_name}" 81 | 82 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 83 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 84 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 85 | norm_act_kwargs.setdefault('act_layer', act_layer) 86 | if norm_act_kwargs: 87 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 88 | return norm_act_layer 89 | -------------------------------------------------------------------------------- /models/layers/drop.py: -------------------------------------------------------------------------------- 1 | """ DropBlock, DropPath 2 | 3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 
4 | 5 | Papers: 6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) 7 | 8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) 9 | 10 | Code: 11 | DropBlock impl inspired by two Tensorflow impl that I liked: 12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | def drop_block_2d( 23 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, 24 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 25 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 26 | 27 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training 28 | runs with success, but needs further validation and possibly optimization for lower runtime impact. 29 | """ 30 | B, C, H, W = x.shape 31 | total_size = W * H 32 | clipped_block_size = min(block_size, min(W, H)) 33 | # seed_drop_rate, the gamma parameter 34 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 35 | (W - block_size + 1) * (H - block_size + 1)) 36 | 37 | # Forces the block to be inside the feature map. 
38 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) 39 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ 40 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) 41 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) 42 | 43 | if batchwise: 44 | # one mask for whole batch, quite a bit faster 45 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) 46 | else: 47 | uniform_noise = torch.rand_like(x) 48 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) 49 | block_mask = -F.max_pool2d( 50 | -block_mask, 51 | kernel_size=clipped_block_size, # block_size, 52 | stride=1, 53 | padding=clipped_block_size // 2) 54 | 55 | if with_noise: 56 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 57 | if inplace: 58 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) 59 | else: 60 | x = x * block_mask + normal_noise * (1 - block_mask) 61 | else: 62 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) 63 | if inplace: 64 | x.mul_(block_mask * normalize_scale) 65 | else: 66 | x = x * block_mask * normalize_scale 67 | return x 68 | 69 | 70 | def drop_block_fast_2d( 71 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, 72 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False): 73 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 74 | 75 | DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid 76 | block mask at edges. 
77 | """ 78 | B, C, H, W = x.shape 79 | total_size = W * H 80 | clipped_block_size = min(block_size, min(W, H)) 81 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 82 | (W - block_size + 1) * (H - block_size + 1)) 83 | 84 | block_mask = torch.empty_like(x).bernoulli_(gamma) 85 | block_mask = F.max_pool2d( 86 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) 87 | 88 | if with_noise: 89 | normal_noise = torch.empty_like(x).normal_() 90 | if inplace: 91 | x.mul_(1. - block_mask).add_(normal_noise * block_mask) 92 | else: 93 | x = x * (1. - block_mask) + normal_noise * block_mask 94 | else: 95 | block_mask = 1 - block_mask 96 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype) 97 | if inplace: 98 | x.mul_(block_mask * normalize_scale) 99 | else: 100 | x = x * block_mask * normalize_scale 101 | return x 102 | 103 | 104 | class DropBlock2d(nn.Module): 105 | """ DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf 106 | """ 107 | 108 | def __init__( 109 | self, 110 | drop_prob: float = 0.1, 111 | block_size: int = 7, 112 | gamma_scale: float = 1.0, 113 | with_noise: bool = False, 114 | inplace: bool = False, 115 | batchwise: bool = False, 116 | fast: bool = True): 117 | super(DropBlock2d, self).__init__() 118 | self.drop_prob = drop_prob 119 | self.gamma_scale = gamma_scale 120 | self.block_size = block_size 121 | self.with_noise = with_noise 122 | self.inplace = inplace 123 | self.batchwise = batchwise 124 | self.fast = fast # FIXME finish comparisons of fast vs not 125 | 126 | def forward(self, x): 127 | if not self.training or not self.drop_prob: 128 | return x 129 | if self.fast: 130 | return drop_block_fast_2d( 131 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace) 132 | else: 133 | return drop_block_2d( 134 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 135 | 136 | 137 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 138 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 139 | 140 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 141 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 142 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 143 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 144 | 'survival rate' as the argument. 145 | 146 | """ 147 | if drop_prob == 0. 
or not training: 148 | return x 149 | keep_prob = 1 - drop_prob 150 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 151 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 152 | if keep_prob > 0.0 and scale_by_keep: 153 | random_tensor.div_(keep_prob) 154 | return x * random_tensor 155 | 156 | 157 | class DropPath(nn.Module): 158 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 159 | """ 160 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 161 | super(DropPath, self).__init__() 162 | self.drop_prob = drop_prob 163 | self.scale_by_keep = scale_by_keep 164 | 165 | def forward(self, x): 166 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 167 | -------------------------------------------------------------------------------- /models/layers/eca.py: -------------------------------------------------------------------------------- 1 | """ 2 | ECA module from ECAnet 3 | 4 | paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks 5 | https://arxiv.org/abs/1910.03151 6 | 7 | Original ECA model borrowed from https://github.com/BangguWu/ECANet 8 | 9 | Modified circular ECA implementation and adaption for use in timm package 10 | by Chris Ha https://github.com/VRandme 11 | 12 | Original License: 13 | 14 | MIT License 15 | 16 | Copyright (c) 2019 BangguWu, Qilong Wang 17 | 18 | Permission is hereby granted, free of charge, to any person obtaining a copy 19 | of this software and associated documentation files (the "Software"), to deal 20 | in the Software without restriction, including without limitation the rights 21 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 22 | copies of the Software, and to permit persons to whom the Software is 23 | furnished to do so, subject to the following conditions: 24 | 25 | The above copyright notice and this permission notice shall be included 
in all 26 | copies or substantial portions of the Software. 27 | 28 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 29 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 30 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 31 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 32 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 33 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 34 | SOFTWARE. 35 | """ 36 | import math 37 | from torch import nn 38 | import torch.nn.functional as F 39 | 40 | 41 | from .create_act import create_act_layer 42 | from .helpers import make_divisible 43 | 44 | 45 | class EcaModule(nn.Module): 46 | """Constructs an ECA module. 47 | 48 | Args: 49 | channels: Number of channels of the input feature map for use in adaptive kernel sizes 50 | for actual calculations according to channel. 51 | gamma, beta: when channel is given parameters of mapping function 52 | refer to original paper https://arxiv.org/pdf/1910.03151.pdf 53 | (default=None. if channel size not given, use k_size given for kernel size.) 
54 | kernel_size: Adaptive selection of kernel size (default=3) 55 | gamm: used in kernel_size calc, see above 56 | beta: used in kernel_size calc, see above 57 | act_layer: optional non-linearity after conv, enables conv bias, this is an experiment 58 | gate_layer: gating non-linearity to use 59 | """ 60 | def __init__( 61 | self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', 62 | rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): 63 | super(EcaModule, self).__init__() 64 | if channels is not None: 65 | t = int(abs(math.log(channels, 2) + beta) / gamma) 66 | kernel_size = max(t if t % 2 else t + 1, 3) 67 | assert kernel_size % 2 == 1 68 | padding = (kernel_size - 1) // 2 69 | if use_mlp: 70 | # NOTE 'mlp' mode is a timm experiment, not in paper 71 | assert channels is not None 72 | if rd_channels is None: 73 | rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) 74 | act_layer = act_layer or nn.ReLU 75 | self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) 76 | self.act = create_act_layer(act_layer) 77 | self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) 78 | else: 79 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) 80 | self.act = None 81 | self.conv2 = None 82 | self.gate = create_act_layer(gate_layer) 83 | 84 | def forward(self, x): 85 | y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv 86 | y = self.conv(y) 87 | if self.conv2 is not None: 88 | y = self.act(y) 89 | y = self.conv2(y) 90 | y = self.gate(y).view(x.shape[0], -1, 1, 1) 91 | return x * y.expand_as(x) 92 | 93 | 94 | EfficientChannelAttn = EcaModule # alias 95 | 96 | 97 | class CecaModule(nn.Module): 98 | """Constructs a circular ECA module. 99 | 100 | ECA module where the conv uses circular padding rather than zero padding. 101 | Unlike the spatial dimension, the channels do not have inherent ordering nor 102 | locality. 
Although this module in essence, applies such an assumption, it is unnecessary 103 | to limit the channels on either "edge" from being circularly adapted to each other. 104 | This will fundamentally increase connectivity and possibly increase performance metrics 105 | (accuracy, robustness), without significantly impacting resource metrics 106 | (parameter size, throughput,latency, etc) 107 | 108 | Args: 109 | channels: Number of channels of the input feature map for use in adaptive kernel sizes 110 | for actual calculations according to channel. 111 | gamma, beta: when channel is given parameters of mapping function 112 | refer to original paper https://arxiv.org/pdf/1910.03151.pdf 113 | (default=None. if channel size not given, use k_size given for kernel size.) 114 | kernel_size: Adaptive selection of kernel size (default=3) 115 | gamm: used in kernel_size calc, see above 116 | beta: used in kernel_size calc, see above 117 | act_layer: optional non-linearity after conv, enables conv bias, this is an experiment 118 | gate_layer: gating non-linearity to use 119 | """ 120 | 121 | def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): 122 | super(CecaModule, self).__init__() 123 | if channels is not None: 124 | t = int(abs(math.log(channels, 2) + beta) / gamma) 125 | kernel_size = max(t if t % 2 else t + 1, 3) 126 | has_act = act_layer is not None 127 | assert kernel_size % 2 == 1 128 | 129 | # PyTorch circular padding mode is buggy as of pytorch 1.4 130 | # see https://github.com/pytorch/pytorch/pull/17240 131 | # implement manual circular padding 132 | self.padding = (kernel_size - 1) // 2 133 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) 134 | self.gate = create_act_layer(gate_layer) 135 | 136 | def forward(self, x): 137 | y = x.mean((2, 3)).view(x.shape[0], 1, -1) 138 | # Manually implement circular padding, F.pad does not seemed to be bugged 139 | y = F.pad(y, (self.padding, 
self.padding), mode='circular') 140 | y = self.conv(y) 141 | y = self.gate(y).view(x.shape[0], -1, 1, 1) 142 | return x * y.expand_as(x) 143 | 144 | 145 | CircularEfficientChannelAttn = CecaModule 146 | -------------------------------------------------------------------------------- /models/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class 
FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /models/layers/gather_excite.py: -------------------------------------------------------------------------------- 1 | """ Gather-Excite Attention Block 2 | 3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 4 | 5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet 6 | 7 | I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another 8 | impl that covers all of the cases. 
9 | 10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation 11 | 12 | Hacked together by / Copyright 2021 Ross Wightman 13 | """ 14 | import math 15 | 16 | from torch import nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .create_act import create_act_layer, get_act_layer 20 | from .create_conv2d import create_conv2d 21 | from .helpers import make_divisible 22 | from .mlp import ConvMlp 23 | 24 | 25 | class GatherExcite(nn.Module): 26 | """ Gather-Excite Attention Module 27 | """ 28 | def __init__( 29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, 30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, 31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): 32 | super(GatherExcite, self).__init__() 33 | self.add_maxpool = add_maxpool 34 | act_layer = get_act_layer(act_layer) 35 | self.extent = extent 36 | if extra_params: 37 | self.gather = nn.Sequential() 38 | if extent == 0: 39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' 40 | self.gather.add_module( 41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) 42 | if norm_layer: 43 | self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) 44 | else: 45 | assert extent % 2 == 0 46 | num_conv = int(math.log2(extent)) 47 | for i in range(num_conv): 48 | self.gather.add_module( 49 | f'conv{i + 1}', 50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) 51 | if norm_layer: 52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) 53 | if i != num_conv - 1: 54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) 55 | else: 56 | self.gather = None 57 | if self.extent == 0: 58 | self.gk = 0 59 | self.gs = 0 60 | else: 61 | assert extent % 2 == 0 62 | self.gk = self.extent * 2 - 1 63 | self.gs = self.extent 64 | 65 | if not rd_channels: 66 | rd_channels = 
make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() 68 | self.gate = create_act_layer(gate_layer) 69 | 70 | def forward(self, x): 71 | size = x.shape[-2:] 72 | if self.gather is not None: 73 | x_ge = self.gather(x) 74 | else: 75 | if self.extent == 0: 76 | # global extent 77 | x_ge = x.mean(dim=(2, 3), keepdims=True) 78 | if self.add_maxpool: 79 | # experimental codepath, may remove or change 80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) 81 | else: 82 | x_ge = F.avg_pool2d( 83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) 84 | if self.add_maxpool: 85 | # experimental codepath, may remove or change 86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) 87 | x_ge = self.mlp(x_ge) 88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: 89 | x_ge = F.interpolate(x_ge, size=size) 90 | return x * self.gate(x_ge) 91 | -------------------------------------------------------------------------------- /models/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, 
act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /models/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch 
internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | -------------------------------------------------------------------------------- /models/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 
34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer is None or act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = 
output[0] 87 | return output 88 | -------------------------------------------------------------------------------- /models/layers/lambda_layer.py: -------------------------------------------------------------------------------- 1 | """ Lambda Layer 2 | 3 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` 4 | - https://arxiv.org/abs/2102.08602 5 | 6 | @misc{2102.08602, 7 | Author = {Irwan Bello}, 8 | Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, 9 | Year = {2021}, 10 | } 11 | 12 | Status: 13 | This impl is a WIP. Code snippets in the paper were used as reference but 14 | good chance some details are missing/wrong. 15 | 16 | I've only implemented local lambda conv based pos embeddings. 17 | 18 | For a PyTorch impl that includes other embedding options checkout 19 | https://github.com/lucidrains/lambda-networks 20 | 21 | Hacked together by / Copyright 2021 Ross Wightman 22 | """ 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | 27 | from .helpers import to_2tuple, make_divisible 28 | from .weight_init import trunc_normal_ 29 | 30 | 31 | def rel_pos_indices(size): 32 | size = to_2tuple(size) 33 | pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1) 34 | rel_pos = pos[:, None, :] - pos[:, :, None] 35 | rel_pos[0] += size[0] - 1 36 | rel_pos[1] += size[1] - 1 37 | return rel_pos # 2, H * W, H * W 38 | 39 | 40 | class LambdaLayer(nn.Module): 41 | """Lambda Layer 42 | 43 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` 44 | - https://arxiv.org/abs/2102.08602 45 | 46 | NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. 47 | 48 | The internal dimensions of the lambda module are controlled via the interaction of several arguments. 
49 | * the output dimension of the module is specified by dim_out, which falls back to input dim if not set 50 | * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim 51 | * the query (q) and key (k) dimension are determined by 52 | * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None 53 | * q = num_heads * dim_head, k = dim_head 54 | * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set 55 | 56 | Args: 57 | dim (int): input dimension to the module 58 | dim_out (int): output dimension of the module, same as dim if not set 59 | feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W 60 | stride (int): output stride of the module, avg pool used if stride == 2 61 | num_heads (int): parallel attention heads. 62 | dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set 63 | r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9) 64 | qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. 
(default: 1.0) 65 | qkv_bias (bool): add bias to q, k, and v projections 66 | """ 67 | def __init__( 68 | self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9, 69 | qk_ratio=1.0, qkv_bias=False): 70 | super().__init__() 71 | dim_out = dim_out or dim 72 | assert dim_out % num_heads == 0, ' should be divided by num_heads' 73 | self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads 74 | self.num_heads = num_heads 75 | self.dim_v = dim_out // num_heads 76 | 77 | self.qkv = nn.Conv2d( 78 | dim, 79 | num_heads * self.dim_qk + self.dim_qk + self.dim_v, 80 | kernel_size=1, bias=qkv_bias) 81 | self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk) 82 | self.norm_v = nn.BatchNorm2d(self.dim_v) 83 | 84 | if r is not None: 85 | # local lambda convolution for pos 86 | self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0)) 87 | self.pos_emb = None 88 | self.rel_pos_indices = None 89 | else: 90 | # relative pos embedding 91 | assert feat_size is not None 92 | feat_size = to_2tuple(feat_size) 93 | rel_size = [2 * s - 1 for s in feat_size] 94 | self.conv_lambda = None 95 | self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk)) 96 | self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) 97 | 98 | self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() 99 | 100 | self.reset_parameters() 101 | 102 | def reset_parameters(self): 103 | trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in 104 | if self.conv_lambda is not None: 105 | trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5) 106 | if self.pos_emb is not None: 107 | trunc_normal_(self.pos_emb, std=.02) 108 | 109 | def forward(self, x): 110 | B, C, H, W = x.shape 111 | M = H * W 112 | qkv = self.qkv(x) 113 | q, k, v = torch.split(qkv, [ 114 | self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1) 115 | q = self.norm_q(q).reshape(B, 
self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K 116 | v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V 117 | k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M 118 | 119 | content_lam = k @ v # B, K, V 120 | content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V 121 | 122 | if self.pos_emb is None: 123 | position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K 124 | position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V 125 | else: 126 | # FIXME relative pos embedding path not fully verified 127 | pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) 128 | position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V 129 | position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V 130 | 131 | out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W 132 | out = self.pool(out) 133 | return out 134 | -------------------------------------------------------------------------------- /models/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 
13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], 
self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = 
torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] # i-th conv applied to the i-th channel split 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /models/layers/mlp.py: -------------------------------------------------------------------------------- 1 | """ MLP module w/ dropout and configurable activation layer 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .helpers import to_2tuple 8 | 9 | 10 | class Mlp(nn.Module): 11 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 12 | """ 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | drop_probs = to_2tuple(drop) # presumably a scalar drop feeds both dropouts; a 2-tuple sets them separately 18 | 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.act = act_layer() 21 | self.drop1 = nn.Dropout(drop_probs[0]) 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop2 = nn.Dropout(drop_probs[1]) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop1(x) 29 | x = self.fc2(x) 30 | x = self.drop2(x) 31 | return x 32 | 33 | 34 | class GluMlp(nn.Module): 35 | """ MLP w/ GLU style gating 36 | See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 37 | """ 38 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): 39 | super().__init__() 40 | out_features = out_features or in_features 41 | hidden_features = hidden_features or in_features 42 | assert hidden_features % 2 == 0 # fc1 output is chunked into (values, gates) halves in forward 43 | drop_probs = to_2tuple(drop) 44 | 45 | self.fc1 = nn.Linear(in_features, hidden_features) 46 | self.act = act_layer() 47 | self.drop1 = nn.Dropout(drop_probs[0]) 48 | self.fc2 = nn.Linear(hidden_features // 2, out_features) 49 | self.drop2 =
nn.Dropout(drop_probs[1]) 50 | 51 | def init_weights(self): # NOTE(review): not called from __init__; callers must invoke it explicitly 52 | # override init of fc1 w/ gate portion set to weight near zero, bias=1 53 | fc1_mid = self.fc1.bias.shape[0] // 2 54 | nn.init.ones_(self.fc1.bias[fc1_mid:]) 55 | nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) 56 | 57 | def forward(self, x): 58 | x = self.fc1(x) 59 | x, gates = x.chunk(2, dim=-1) # first half values, second half gates 60 | x = x * self.act(gates) 61 | x = self.drop1(x) 62 | x = self.fc2(x) 63 | x = self.drop2(x) 64 | return x 65 | 66 | 67 | class GatedMlp(nn.Module): 68 | """ MLP as used in gMLP 69 | """ 70 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, 71 | gate_layer=None, drop=0.): 72 | super().__init__() 73 | out_features = out_features or in_features 74 | hidden_features = hidden_features or in_features 75 | drop_probs = to_2tuple(drop) 76 | 77 | self.fc1 = nn.Linear(in_features, hidden_features) 78 | self.act = act_layer() 79 | self.drop1 = nn.Dropout(drop_probs[0]) 80 | if gate_layer is not None: 81 | assert hidden_features % 2 == 0 # gate is expected to halve the feature dim (fc2 input below) 82 | self.gate = gate_layer(hidden_features) 83 | hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
84 | else: 85 | self.gate = nn.Identity() 86 | self.fc2 = nn.Linear(hidden_features, out_features) 87 | self.drop2 = nn.Dropout(drop_probs[1]) 88 | 89 | def forward(self, x): 90 | x = self.fc1(x) 91 | x = self.act(x) 92 | x = self.drop1(x) 93 | x = self.gate(x) # nn.Identity when gate_layer is None 94 | x = self.fc2(x) 95 | x = self.drop2(x) 96 | return x 97 | 98 | 99 | class ConvMlp(nn.Module): 100 | """ MLP using 1x1 convs that keeps spatial dims 101 | """ 102 | def __init__( 103 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): 104 | super().__init__() 105 | out_features = out_features or in_features 106 | hidden_features = hidden_features or in_features 107 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) 108 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() 109 | self.act = act_layer() 110 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) 111 | self.drop = nn.Dropout(drop) # single dropout, applied between act and fc2 112 | 113 | def forward(self, x): 114 | x = self.fc1(x) 115 | x = self.norm(x) 116 | x = self.act(x) 117 | x = self.drop(x) 118 | x = self.fc2(x) 119 | return x 120 | -------------------------------------------------------------------------------- /models/layers/non_local_attn.py: -------------------------------------------------------------------------------- 1 | """ Bilinear-Attention-Transform and Non-Local Attention 2 | 3 | Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms` 4 | - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html 5 | Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification 6 | """ 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from .conv_bn_act import ConvNormAct 12 | from .helpers import make_divisible 13 | from .trace_utils import _assert 14 | 15 | 16 | class
NonLocalAttn(nn.Module): 17 | """Spatial NL block for image classification. 18 | 19 | This was adapted from https://github.com/BA-Transform/BAT-Image-Classification 20 | Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. 21 | """ 22 | 23 | def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): 24 | super(NonLocalAttn, self).__init__() 25 | if rd_channels is None: 26 | rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) 27 | self.scale = in_channels ** -0.5 if use_scale else 1.0 28 | self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) 29 | self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) 30 | self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) 31 | self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) 32 | self.norm = nn.BatchNorm2d(in_channels) 33 | self.reset_parameters() 34 | 35 | def forward(self, x): 36 | shortcut = x 37 | 38 | t = self.t(x) 39 | p = self.p(x) 40 | g = self.g(x) 41 | 42 | B, C, H, W = t.size() 43 | t = t.view(B, C, -1).permute(0, 2, 1) 44 | p = p.view(B, C, -1) 45 | g = g.view(B, C, -1).permute(0, 2, 1) 46 | 47 | att = torch.bmm(t, p) * self.scale 48 | att = F.softmax(att, dim=2) 49 | x = torch.bmm(att, g) 50 | 51 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 52 | x = self.z(x) 53 | x = self.norm(x) + shortcut 54 | 55 | return x 56 | 57 | def reset_parameters(self): 58 | for name, m in self.named_modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.kaiming_normal_( 61 | m.weight, mode='fan_out', nonlinearity='relu') 62 | if len(list(m.parameters())) > 1: 63 | nn.init.constant_(m.bias, 0.0) 64 | elif isinstance(m, nn.BatchNorm2d): 65 | nn.init.constant_(m.weight, 0) 66 | nn.init.constant_(m.bias, 0) 67 | elif isinstance(m, nn.GroupNorm): 68 | nn.init.constant_(m.weight, 0) 69 | nn.init.constant_(m.bias, 0) 70 | 71 | 72 | 
class BilinearAttnTransform(nn.Module): 73 | 74 | def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): 75 | super(BilinearAttnTransform, self).__init__() 76 | 77 | self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) 78 | self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) 79 | self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) 80 | self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) 81 | self.block_size = block_size 82 | self.groups = groups 83 | self.in_channels = in_channels 84 | 85 | def resize_mat(self, x, t: int): # (B, C, bs, bs) -> (B, C, bs * t, bs * t), i.e. kron(x, I_t) per channel 86 | B, C, block_size, block_size1 = x.shape 87 | _assert(block_size == block_size1, '') 88 | if t <= 1: 89 | return x 90 | x = x.view(B * C, -1, 1, 1) 91 | x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) # each matrix entry scales a t x t identity 92 | x = x.view(B * C, block_size, block_size, t, t) 93 | x = torch.cat(torch.split(x, 1, dim=1), dim=3) 94 | x = torch.cat(torch.split(x, 1, dim=2), dim=4) 95 | x = x.view(B, C, block_size * t, block_size * t) 96 | return x 97 | 98 | def forward(self, x): 99 | _assert(x.shape[-1] % self.block_size == 0, '') 100 | _assert(x.shape[-2] % self.block_size == 0, '') 101 | B, C, H, W = x.shape 102 | out = self.conv1(x) 103 | rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) 104 | cp = F.adaptive_max_pool2d(out, (1, self.block_size)) 105 | p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid() 106 | q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid() 107 | p = p / p.sum(dim=3, keepdim=True) # rows of p sum to 1 108 | q = q / q.sum(dim=2, keepdim=True) # columns of q sum to 1 109 | p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( 110 | 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() # share each group's matrix across its C // groups channels 111 | p = p.view(B, C, self.block_size, self.block_size) 112 | q =
q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( 113 | 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() # share each group's matrix across its C // groups channels 114 | q = q.view(B, C, self.block_size, self.block_size) 115 | p = self.resize_mat(p, H // self.block_size) # B, C, H, H 116 | q = self.resize_mat(q, W // self.block_size) # B, C, W, W 117 | y = p.matmul(x) 118 | y = y.matmul(q) # y = p @ x @ q per channel 119 | 120 | y = self.conv2(y) 121 | return y 122 | 123 | 124 | class BatNonLocalAttn(nn.Module): 125 | """ BAT 126 | Adapted from: https://github.com/BA-Transform/BAT-Image-Classification 127 | """ 128 | 129 | def __init__( 130 | self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 131 | drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): 132 | super().__init__() 133 | if rd_channels is None: 134 | rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) 135 | self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) 136 | self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) 137 | self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) 138 | self.dropout = nn.Dropout2d(p=drop_rate) 139 | 140 | def forward(self, x): 141 | xl = self.conv1(x) # reduce to rd_channels 142 | y = self.ba(xl) 143 | y = self.conv2(y) # restore in_channels 144 | y = self.dropout(y) 145 | return y + x # residual connection 146 | -------------------------------------------------------------------------------- /models/layers/norm.py: -------------------------------------------------------------------------------- 1 | """ Normalization layers and wrappers 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class GroupNorm(nn.GroupNorm): 9 | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): 10 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN 11 | super().__init__(num_groups, num_channels,
eps=eps, affine=affine) 12 | 13 | def forward(self, x): 14 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 15 | 16 | 17 | class LayerNorm2d(nn.LayerNorm): 18 | """ LayerNorm for channels of '2D' spatial BCHW tensors """ 19 | def __init__(self, num_channels): 20 | super().__init__(num_channels) # nn.LayerNorm defaults for eps / elementwise_affine 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | return F.layer_norm( 24 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) # BCHW -> BHWC, normalize over C, back to BCHW 25 | -------------------------------------------------------------------------------- /models/layers/norm_act.py: -------------------------------------------------------------------------------- 1 | """ Normalization + Activation Layers 2 | """ 3 | from typing import Union, List 4 | 5 | import torch 6 | from torch import nn as nn 7 | from torch.nn import functional as F 8 | 9 | from .trace_utils import _assert 10 | from .create_act import get_act_layer 11 | 12 | 13 | class BatchNormAct2d(nn.BatchNorm2d): 14 | """BatchNorm + Activation 15 | 16 | This module performs BatchNorm + Activation in a manner that will remain backwards 17 | compatible with weights trained with separate bn, act. This is why we inherit from BN 18 | instead of composing it as a .bn member.
19 | """ 20 | def __init__( 21 | self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, 22 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): 23 | super(BatchNormAct2d, self).__init__( 24 | num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 25 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() # applied after norm, before act (see forward) 26 | act_layer = get_act_layer(act_layer) # string -> nn.Module 27 | if act_layer is not None and apply_act: 28 | act_args = dict(inplace=True) if inplace else {} 29 | self.act = act_layer(**act_args) 30 | else: 31 | self.act = nn.Identity() 32 | 33 | def forward(self, x): 34 | # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing 35 | _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)') 36 | 37 | # exponential_average_factor is set to self.momentum 38 | # (when it is available) only so that it gets updated 39 | # in ONNX graph when this node is exported to ONNX. 40 | if self.momentum is None: 41 | exponential_average_factor = 0.0 42 | else: 43 | exponential_average_factor = self.momentum 44 | 45 | if self.training and self.track_running_stats: 46 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 47 | if self.num_batches_tracked is not None: # type: ignore[has-type] 48 | self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] 49 | if self.momentum is None: # use cumulative moving average 50 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 51 | else: # use exponential moving average 52 | exponential_average_factor = self.momentum 53 | 54 | r""" 55 | Decide whether the mini-batch stats should be used for normalization rather than the buffers. 56 | Mini-batch stats are used in training mode, and in eval mode when buffers are None.
57 | """ 58 | if self.training: 59 | bn_training = True 60 | else: 61 | bn_training = (self.running_mean is None) and (self.running_var is None) 62 | 63 | r""" 64 | Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be 65 | passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are 66 | used for normalization (i.e. in eval mode when buffers are not None). 67 | """ 68 | x = F.batch_norm( 69 | x, 70 | # If buffers are not to be tracked, ensure that they won't be updated 71 | self.running_mean if not self.training or self.track_running_stats else None, 72 | self.running_var if not self.training or self.track_running_stats else None, 73 | self.weight, 74 | self.bias, 75 | bn_training, 76 | exponential_average_factor, 77 | self.eps, 78 | ) 79 | x = self.drop(x) 80 | x = self.act(x) 81 | return x 82 | 83 | 84 | def _num_groups(num_channels, num_groups, group_size): # group_size (channels per group), when set, overrides num_groups 85 | if group_size: 86 | assert num_channels % group_size == 0 87 | return num_channels // group_size 88 | return num_groups 89 | 90 | 91 | class GroupNormAct(nn.GroupNorm): 92 | # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args 93 | def __init__( 94 | self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None, 95 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): 96 | super(GroupNormAct, self).__init__( 97 | _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine) 98 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 99 | act_layer = get_act_layer(act_layer) # string -> nn.Module 100 | if act_layer is not None and apply_act: 101 | act_args = dict(inplace=True) if inplace else {} 102 | self.act = act_layer(**act_args) 103 | else: 104 | self.act = nn.Identity() 105 | 106 | def forward(self, x): 107 | x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 108 | x =
self.drop(x) 109 | x = self.act(x) 110 | return x 111 | 112 | 113 | class LayerNormAct(nn.LayerNorm): # LayerNorm over the trailing normalization_shape dims, then optional drop + act 114 | def __init__( 115 | self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True, 116 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): 117 | super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) 118 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 119 | act_layer = get_act_layer(act_layer) # string -> nn.Module 120 | if act_layer is not None and apply_act: 121 | act_args = dict(inplace=True) if inplace else {} 122 | self.act = act_layer(**act_args) 123 | else: 124 | self.act = nn.Identity() 125 | 126 | def forward(self, x): 127 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 128 | x = self.drop(x) 129 | x = self.act(x) 130 | return x 131 | 132 | 133 | class LayerNormAct2d(nn.LayerNorm): # channel LayerNorm for BCHW input (permutes to BHWC and back), then optional drop + act 134 | def __init__( 135 | self, num_channels, eps=1e-5, affine=True, 136 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): 137 | super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine) 138 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 139 | act_layer = get_act_layer(act_layer) # string -> nn.Module 140 | if act_layer is not None and apply_act: 141 | act_args = dict(inplace=True) if inplace else {} 142 | self.act = act_layer(**act_args) 143 | else: 144 | self.act = nn.Identity() 145 | 146 | def forward(self, x): 147 | x = F.layer_norm( 148 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) 149 | x = self.drop(x) 150 | x = self.act(x) 151 | return x 152 | -------------------------------------------------------------------------------- /models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish 
symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /models/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection. 4 | 5 | Based on the impl in https://github.com/google-research/vision_transformer 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | from torch import nn as nn 10 | 11 | from .helpers import to_2tuple 12 | from .trace_utils import _assert 13 | 14 | 15 | class PatchEmbed(nn.Module): 16 | """ 2D Image to Patch Embedding 17 | """ 18 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 19 | super().__init__() 20 | img_size = to_2tuple(img_size) 21 | patch_size = to_2tuple(patch_size) 22 | self.img_size = img_size 23 | self.patch_size = patch_size 24 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 25 | self.num_patches = self.grid_size[0] * self.grid_size[1] 26 | self.flatten = flatten 27 | 28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 29 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 30 | 31 | def forward(self, x): 32 | B, C, H, W = x.shape 33 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 34 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 35 | x = self.proj(x) 36 | if self.flatten: 37 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 38 | x = self.norm(x) 39 | return x 40 | -------------------------------------------------------------------------------- /models/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | 
""" AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, 
stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /models/layers/selective_kernel.py: -------------------------------------------------------------------------------- 1 | """ Selective Kernel Convolution/Attention 2 | 3 | Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | from torch import nn as nn 9 | 10 | from .conv_bn_act import ConvNormActAa 11 | from .helpers import make_divisible 12 | from .trace_utils import _assert 13 | 14 | 15 | def _kernel_valid(k): 16 | if isinstance(k, (list, tuple)): 17 | for ki in k: 18 | return _kernel_valid(ki) 19 | assert k >= 3 and k % 2 20 | 21 | 22 | class SelectiveKernelAttn(nn.Module): 23 | def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): 24 | """ Selective Kernel Attention Module 25 | 26 | 
Selective Kernel attention mechanism factored out into its own module. 27 | 28 | """ 29 | super(SelectiveKernelAttn, self).__init__() 30 | self.num_paths = num_paths 31 | self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) 32 | self.bn = norm_layer(attn_channels) 33 | self.act = act_layer(inplace=True) 34 | self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) 35 | 36 | def forward(self, x): 37 | _assert(x.shape[1] == self.num_paths, '') 38 | x = x.sum(1).mean((2, 3), keepdim=True) 39 | x = self.fc_reduce(x) 40 | x = self.bn(x) 41 | x = self.act(x) 42 | x = self.fc_select(x) 43 | B, C, H, W = x.shape 44 | x = x.view(B, self.num_paths, C // self.num_paths, H, W) 45 | x = torch.softmax(x, dim=1) 46 | return x 47 | 48 | 49 | class SelectiveKernel(nn.Module): 50 | 51 | def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, 52 | rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, 53 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None): 54 | """ Selective Kernel Convolution Module 55 | 56 | As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. 57 | 58 | Largest change is the input split, which divides the input channels across each convolution path, this can 59 | be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps 60 | the parameter count from ballooning when the convolutions themselves don't have groups, but still provides 61 | a noteworthy increase in performance over similar param count models without this attention layer. 
-Ross W 62 | 63 | Args: 64 | in_channels (int): module input (feature) channel count 65 | out_channels (int): module output (feature) channel count 66 | kernel_size (int, list): kernel size for each convolution branch 67 | stride (int): stride for convolutions 68 | dilation (int): dilation for module as a whole, impacts dilation of each branch 69 | groups (int): number of groups for each branch 70 | rd_ratio (int, float): reduction factor for attention features 71 | keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations 72 | split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, 73 | can be viewed as grouping by path, output expands to module out_channels count 74 | act_layer (nn.Module): activation layer to use 75 | norm_layer (nn.Module): batchnorm/norm layer to use 76 | aa_layer (nn.Module): anti-aliasing module 77 | drop_layer (nn.Module): spatial drop module in convs (drop block, etc) 78 | """ 79 | super(SelectiveKernel, self).__init__() 80 | out_channels = out_channels or in_channels 81 | kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 
5x5 -> 3x3 + dilation 82 | _kernel_valid(kernel_size) 83 | if not isinstance(kernel_size, list): 84 | kernel_size = [kernel_size] * 2 85 | if keep_3x3: 86 | dilation = [dilation * (k - 1) // 2 for k in kernel_size] 87 | kernel_size = [3] * len(kernel_size) 88 | else: 89 | dilation = [dilation] * len(kernel_size) 90 | self.num_paths = len(kernel_size) 91 | self.in_channels = in_channels 92 | self.out_channels = out_channels 93 | self.split_input = split_input 94 | if self.split_input: 95 | assert in_channels % self.num_paths == 0 96 | in_channels = in_channels // self.num_paths 97 | groups = min(out_channels, groups) 98 | 99 | conv_kwargs = dict( 100 | stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, 101 | aa_layer=aa_layer, drop_layer=drop_layer) 102 | self.paths = nn.ModuleList([ 103 | ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) 104 | for k, d in zip(kernel_size, dilation)]) 105 | 106 | attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) 107 | self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) 108 | 109 | def forward(self, x): 110 | if self.split_input: 111 | x_split = torch.split(x, self.in_channels // self.num_paths, 1) 112 | x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] 113 | else: 114 | x_paths = [op(x) for op in self.paths] 115 | x = torch.stack(x_paths, dim=1) 116 | x_attn = self.attn(x) 117 | x = x * x_attn 118 | x = torch.sum(x, dim=1) 119 | return x 120 | -------------------------------------------------------------------------------- /models/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw = create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | 
self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /models/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 
| def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /models/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, -1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | mid_chs = out_channels * radix 43 | if rd_channels is None: 44 | attn_chs = 
make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 45 | else: 46 | attn_chs = rd_channels * radix 47 | 48 | padding = kernel_size // 2 if padding is None else padding 49 | self.conv = nn.Conv2d( 50 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 51 | groups=groups * radix, bias=bias, **kwargs) 52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | x = self.drop(x) 65 | x = self.act0(x) 66 | 67 | B, RC, H, W = x.shape 68 | if self.radix > 1: 69 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 70 | x_gap = x.sum(dim=1) 71 | else: 72 | x_gap = x 73 | x_gap = x_gap.mean((2, 3), keepdim=True) 74 | x_gap = self.fc1(x_gap) 75 | x_gap = self.bn1(x_gap) 76 | x_gap = self.act1(x_gap) 77 | x_attn = self.fc2(x_gap) 78 | 79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 80 | if self.radix > 1: 81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 82 | else: 83 | out = x * x_attn 84 | return out.contiguous() 85 | -------------------------------------------------------------------------------- /models/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. 
All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 
45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /models/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | """ Squeeze-and-Excitation Channel Attention 2 | 3 | An SE implementation originally based on PyTorch SE-Net impl. 4 | Has since evolved with additional functionality / configuration. 5 | 6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 7 | 8 | Also included is Effective Squeeze-Excitation (ESE). 
9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 10 | 11 | Hacked together by / Copyright 2021 Ross Wightman 12 | """ 13 | from torch import nn as nn 14 | 15 | from .create_act import create_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class SEModule(nn.Module): 20 | """ SE Module as defined in original SE-Nets with a few additions 21 | Additions include: 22 | * divisor can be specified to keep channels % div == 0 (default: 8) 23 | * reduction channels can be specified directly by arg (if rd_channels is set) 24 | * reduction channels can be specified by float rd_ratio (default: 1/16) 25 | * global max pooling can be added to the squeeze aggregation 26 | * customizable activation, normalization, and gate layer 27 | """ 28 | def __init__( 29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, 30 | act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): 31 | super(SEModule, self).__init__() 32 | self.add_maxpool = add_maxpool 33 | if not rd_channels: 34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) 36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() 37 | self.act = create_act_layer(act_layer, inplace=True) 38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) 39 | self.gate = create_act_layer(gate_layer) 40 | 41 | def forward(self, x): 42 | x_se = x.mean((2, 3), keepdim=True) 43 | if self.add_maxpool: 44 | # experimental codepath, may remove or change 45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 46 | x_se = self.fc1(x_se) 47 | x_se = self.act(self.bn(x_se)) 48 | x_se = self.fc2(x_se) 49 | return x * self.gate(x_se) 50 | 51 | 52 | SqueezeExcite = SEModule # alias 53 | 54 | 55 | class EffectiveSEModule(nn.Module): 56 | """ 'Effective Squeeze-Excitation 57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 58 | """ 59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): 60 | super(EffectiveSEModule, self).__init__() 61 | self.add_maxpool = add_maxpool 62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 63 | self.gate = create_act_layer(gate_layer) 64 | 65 | def forward(self, x): 66 | x_se = x.mean((2, 3), keepdim=True) 67 | if self.add_maxpool: 68 | # experimental codepath, may remove or change 69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 70 | x_se = self.fc(x_se) 71 | return x * self.gate(x_se) 72 | 73 | 74 | EffectiveSqueezeExcite = EffectiveSEModule # alias 75 | -------------------------------------------------------------------------------- /models/layers/std_conv.py: -------------------------------------------------------------------------------- 1 | """ Convolution with Weight Standardization (StdConv and ScaledStdConv) 2 | 3 | StdConv: 4 | @article{weightstandardization, 5 | author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille}, 6 | title = {Weight Standardization}, 7 | journal = {arXiv 
preprint arXiv:1903.10520}, 8 | year = {2019}, 9 | } 10 | Code: https://github.com/joe-siyuan-qiao/WeightStandardization 11 | 12 | ScaledStdConv: 13 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` 14 | - https://arxiv.org/abs/2101.08692 15 | Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets 16 | 17 | Hacked together by / copyright Ross Wightman, 2021. 18 | """ 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | from .padding import get_padding, get_padding_value, pad_same 24 | 25 | 26 | class StdConv2d(nn.Conv2d): 27 | """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. 28 | 29 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - 30 | https://arxiv.org/abs/1903.10520v2 31 | """ 32 | def __init__( 33 | self, in_channel, out_channels, kernel_size, stride=1, padding=None, 34 | dilation=1, groups=1, bias=False, eps=1e-6): 35 | if padding is None: 36 | padding = get_padding(kernel_size, stride, dilation) 37 | super().__init__( 38 | in_channel, out_channels, kernel_size, stride=stride, 39 | padding=padding, dilation=dilation, groups=groups, bias=bias) 40 | self.eps = eps 41 | 42 | def forward(self, x): 43 | weight = F.batch_norm( 44 | self.weight.reshape(1, self.out_channels, -1), None, None, 45 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 46 | x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 47 | return x 48 | 49 | 50 | class StdConv2dSame(nn.Conv2d): 51 | """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model. 
52 | 53 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - 54 | https://arxiv.org/abs/1903.10520v2 55 | """ 56 | def __init__( 57 | self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', 58 | dilation=1, groups=1, bias=False, eps=1e-6): 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) 60 | super().__init__( 61 | in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 62 | groups=groups, bias=bias) 63 | self.same_pad = is_dynamic 64 | self.eps = eps 65 | 66 | def forward(self, x): 67 | if self.same_pad: 68 | x = pad_same(x, self.kernel_size, self.stride, self.dilation) 69 | weight = F.batch_norm( 70 | self.weight.reshape(1, self.out_channels, -1), None, None, 71 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 72 | x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 73 | return x 74 | 75 | 76 | class ScaledStdConv2d(nn.Conv2d): 77 | """Conv2d layer with Scaled Weight Standardization. 78 | 79 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - 80 | https://arxiv.org/abs/2101.08692 81 | 82 | NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. 
83 | """ 84 | 85 | def __init__( 86 | self, in_channels, out_channels, kernel_size, stride=1, padding=None, 87 | dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): 88 | if padding is None: 89 | padding = get_padding(kernel_size, stride, dilation) 90 | super().__init__( 91 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 92 | groups=groups, bias=bias) 93 | self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) 94 | self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) 95 | self.eps = eps 96 | 97 | def forward(self, x): 98 | weight = F.batch_norm( 99 | self.weight.reshape(1, self.out_channels, -1), None, None, 100 | weight=(self.gain * self.scale).view(-1), 101 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 102 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 103 | 104 | 105 | class ScaledStdConv2dSame(nn.Conv2d): 106 | """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support 107 | 108 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - 109 | https://arxiv.org/abs/2101.08692 110 | 111 | NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. 
112 | """ 113 | 114 | def __init__( 115 | self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', 116 | dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): 117 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) 118 | super().__init__( 119 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, 120 | groups=groups, bias=bias) 121 | self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) 122 | self.scale = gamma * self.weight[0].numel() ** -0.5 123 | self.same_pad = is_dynamic 124 | self.eps = eps 125 | 126 | def forward(self, x): 127 | if self.same_pad: 128 | x = pad_same(x, self.kernel_size, self.stride, self.dilation) 129 | weight = F.batch_norm( 130 | self.weight.reshape(1, self.out_channels, -1), None, None, 131 | weight=(self.gain * self.scale).view(-1), 132 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight) 133 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 134 | -------------------------------------------------------------------------------- /models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, 
kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=True): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /models/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except ImportError: 4 | def _assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`. 
11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /models/layers/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | from torch.nn.init import _calculate_fan_in_and_fan_out 6 | 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. 
The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | 65 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): 66 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 67 | if mode == 'fan_in': 68 | denom = fan_in 69 | elif mode == 'fan_out': 70 | denom = fan_out 71 | elif mode == 'fan_avg': 72 | denom = (fan_in + fan_out) / 2 73 | 74 | variance = scale / denom 75 | 76 | if distribution == "truncated_normal": 77 | # constant is stddev of standard normal truncated to (-2, 2) 78 | trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) 79 | elif distribution == "normal": 80 | tensor.normal_(std=math.sqrt(variance)) 81 | elif distribution == "uniform": 82 | bound = math.sqrt(3 * variance) 83 | tensor.uniform_(-bound, bound) 84 | else: 85 | raise ValueError(f"invalid distribution {distribution}") 86 | 87 | 88 | def lecun_normal_(tensor): 89 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') 90 | -------------------------------------------------------------------------------- /models/pruned/ecaresnet50d_pruned.txt: -------------------------------------------------------------------------------- 1 | conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 
3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 
3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 
# [repository-dump residue: tail of /models/pruned/ecaresnet50d_pruned.txt, kept verbatim as a comment]
# 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022]
# --------------------------------------------------------------------------------
# /models/registry.py:
# --------------------------------------------------------------------------------
""" Model Registry
Hacked together by / Copyright 2020 Ross Wightman
"""

import sys
import re
import fnmatch
from collections import defaultdict
from copy import deepcopy

__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
           'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained']

_module_to_models = defaultdict(set)  # dict of sets to check membership of model in module
_model_to_module = {}  # mapping of model names to module names
_model_entrypoints = {}  # mapping of model names to entrypoint fns
_model_has_pretrained = set()  # set of model names that have pretrained weight url present
_model_pretrained_cfgs = dict()  # central repo for model default_cfgs


def register_model(fn):
    """Decorator that records ``fn`` as a model entrypoint in the module-level registry.

    Appends ``fn.__name__`` to the defining module's ``__all__`` (creating it if needed),
    and maps the model name to ``fn`` and to the last dotted component of
    ``fn.__module__``. If that module has a ``default_cfgs`` entry for the model,
    the cfg is stored. The model is marked as pretrained when the cfg has an
    http ``url``, a truthy ``file``, or a truthy ``hf_hub_id``.

    Returns:
        ``fn`` unchanged, so it can be used as a decorator.
    """
    # lookup containing module
    mod = sys.modules[fn.__module__]
    # str.split always yields at least one item, so [-1] is always valid
    module_name = fn.__module__.split('.')[-1]

    # add model to __all__ in module
    model_name = fn.__name__
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
        mod.__all__ = [model_name]

    # add entries to registry dict/sets
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    has_valid_pretrained = False  # check if model has a pretrained url to allow filtering on this
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        cfg = mod.default_cfgs[model_name]
        has_valid_pretrained = bool(
            # `or ''` tolerates url=None, which previously raised TypeError on `in`
            ('http' in (cfg.get('url') or '')) or
            cfg.get('file') or
            cfg.get('hf_hub_id')
        )
        _model_pretrained_cfgs[model_name] = cfg
    if has_valid_pretrained:
        _model_has_pretrained.add(model_name)
    return fn


def _natural_key(string_):
    """Sort key that orders embedded integers numerically (e.g. 'b2' < 'b10'), case-insensitively."""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
    """ Return list of available model names, in natural (number-aware) sorted order

    Args:
        filter (str or list[str]) - Wildcard filter string(s) that work with fnmatch; a model is
            included if it matches any of them
        module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
        pretrained (bool) - Include only models with pretrained weights if True
        exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
        name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)

    Example:
        list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
        list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in 'resnet' module
    """
    if module:
        # .get: indexing the defaultdict would insert an empty entry for an unknown
        # module, which would then show up in list_modules()
        all_models = list(_module_to_models.get(module, ()))
    else:
        all_models = _model_entrypoints.keys()
    if filter:
        models = []
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
        for f in include_filters:
            include_models = fnmatch.filter(all_models, f)  # include these models
            if len(include_models):
                models = set(models).union(include_models)
    else:
        models = all_models
    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
        for xf in exclude_filters:
            exclude_models = fnmatch.filter(models, xf)  # exclude these models
            if len(exclude_models):
                models = set(models).difference(exclude_models)
    if pretrained:
        models = _model_has_pretrained.intersection(models)
    if name_matches_cfg:
        models = set(_model_pretrained_cfgs).intersection(models)
    return list(sorted(models, key=_natural_key))


def is_model(model_name):
    """ Check if a model name exists
    """
    return model_name in _model_entrypoints


def model_entrypoint(model_name):
    """Fetch a model entrypoint for specified model name (raises KeyError if unknown)
    """
    return _model_entrypoints[model_name]


def list_modules():
    """ Return list of module names that contain models / model entrypoints
    """
    modules = _module_to_models.keys()
    return list(sorted(modules))


def is_model_in_modules(model_name, module_names):
    """Check if a model exists within a subset of modules
    Args:
        model_name (str) - name of model to check
        module_names (tuple, list, set) - names of modules to search in
    """
    assert isinstance(module_names, (tuple, list, set))
    # .get avoids registering unknown module names as a side effect of the lookup
    return any(model_name in _module_to_models.get(n, ()) for n in module_names)


def is_model_pretrained(model_name):
    """Return True if the model was registered with a usable pretrained weight source."""
    return model_name in _model_has_pretrained


def get_pretrained_cfg(model_name):
    """Return a deep copy of the model's registered default cfg, or ``{}`` if none."""
    if model_name in _model_pretrained_cfgs:
        return deepcopy(_model_pretrained_cfgs[model_name])
    return {}


def has_pretrained_cfg_key(model_name, cfg_key):
    """ Query model default_cfgs for existence of a specific key.
    """
    if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]:
        return True
    return False


def is_pretrained_cfg_key(model_name, cfg_key):
    """ Return truthy value for specified model default_cfg key, False if does not exist.
    """
    if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False):
        return True
    return False


def get_pretrained_cfg_value(model_name, cfg_key):
    """ Get a specific model default_cfg value by key. None if it doesn't exist.
    """
    if model_name in _model_pretrained_cfgs:
        return _model_pretrained_cfgs[model_name].get(cfg_key, None)
    return None
# --------------------------------------------------------------------------------
# [repository-dump residue: image files listed as URLs, kept verbatim as comments]
# /pics/CPU.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/CPU.png
# --------------------------------------------------------------------------------
# /pics/GPU.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/GPU.png
# --------------------------------------------------------------------------------
# /pics/pic1.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/pic1.png
# --------------------------------------------------------------------------------
# /pics/pic2.png:
# --------------------------------------------------------------------------------
# [repository-dump residue from neighbouring files, kept verbatim as comments]
# https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/pic2.png
# --------------------------------------------------------------------------------
# /pics/pic3.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/pic3.png
# --------------------------------------------------------------------------------
# /requirements.txt:
# --------------------------------------------------------------------------------
# 1 | timm==0.5.4 2 | torch>=1.9.0 3 | torchvision>=0.5.0 4 | pyyaml 5 | inplace-abn
# --------------------------------------------------------------------------------
# /test/test_build_kd_model.py:
# --------------------------------------------------------------------------------

# test_build_kd_model.py - Generated by https://www.codium.ai/

import unittest

# These three names were used below but never imported (NameError at runtime).
import torch
import torchvision.transforms as T

from kd.kd_utils import build_kd_model
from models.factory import create_model

"""
Code Analysis:
- This class is used to build a knowledge distillation (KD) model.
- It uses the create_model() function from the models.factory module to create a KD model.
- It then uses the InplacABN_to_ABN() and fuse_bn2d_bn1d_abn() functions from the kd.helpers module to convert the model to an ABN model and fuse the batch normalization layers.
- The model is then loaded onto the GPU and set to evaluation mode.
- The mean and standard deviation of the model are also stored.
- The normalize_input() function is used to normalize the input data to match the mean and standard deviation of the KD model.
- It uses the torchvision.transforms module to apply the necessary transformations.
"""


"""
Test strategies:
- test_create_model(): tests that the create_model() function from the models.factory module is correctly called and the model is created.
- test_InplacABN_to_ABN(): tests that the InplacABN_to_ABN() function from the kd.helpers module is correctly called and the model is converted to an ABN model.
- test_fuse_bn2d_bn1d_abn(): tests that the fuse_bn2d_bn1d_abn() function from the kd.helpers module is correctly called and the batch normalization layers are fused.
- test_load_gpu(): tests that the model is correctly loaded onto the GPU and set to evaluation mode.
- test_mean_std(): tests that the mean and standard deviation of the model are correctly stored.
- test_normalize_input(): tests that the normalize_input() function is correctly called and the input data is normalized to match the mean and standard deviation of the KD model.
- test_transforms(): tests that the torchvision.transforms module is correctly called and the necessary transformations are applied.
- test_edge_cases(): tests that the class handles edge cases correctly.
"""


class TestBuildKdModel(unittest.TestCase):
    """Smoke tests for ``kd.kd_utils.build_kd_model``.

    NOTE(review): ``args`` is passed as a plain dict. If ``build_kd_model`` reads
    attributes (``args.kd_model_name`` ...) these tests need an attribute container
    such as ``types.SimpleNamespace(**args)`` instead. Confirm against kd/kd_utils.py.
    """

    def setUp(self):
        self.args = {
            'kd_model_name': 'resnet18',
            'kd_model_path': './models/resnet18.pth',
            'num_classes': 10,
            'in_chans': 3
        }

    def test_create_model(self):
        model = build_kd_model(self.args)
        self.assertIsNotNone(model.model)

    def test_InplacABN_to_ABN(self):
        model = build_kd_model(self.args)
        self.assertIsNotNone(model.model.bn1)

    def test_fuse_bn2d_bn1d_abn(self):
        model = build_kd_model(self.args)
        self.assertIsNotNone(model.model.bn2d)

    def test_load_gpu(self):
        model = build_kd_model(self.args)
        # NOTE(review): torch.nn.Module has no `.device` attribute; unless model.model
        # defines one, `next(model.model.parameters()).device` is the usual check. Confirm.
        self.assertEqual(model.model.device.type, 'cuda')
        self.assertEqual(model.model.training, False)

    def test_mean_std(self):
        model = build_kd_model(self.args)
        self.assertEqual(model.mean_model_kd, model.model.default_cfg['mean'])
        self.assertEqual(model.std_model_kd, model.model.default_cfg['std'])

    def test_normalize_input(self):
        model = build_kd_model(self.args)
        x = torch.randn(3, 224, 224)  # renamed from `input`, which shadowed the builtin
        student_model = create_model('resnet18', None, False, 10, 3)
        model.normalize_input(x, student_model)
        # NOTE(review): these only hold if normalize_input modifies `x` in place, and exact
        # float equality on random data is fragile. Consider assertAlmostEqual on its
        # return value instead.
        self.assertEqual(x.mean(), model.mean_model_kd[0])
        self.assertEqual(x.std(), model.std_model_kd[0])

    def test_transforms(self):
        model = build_kd_model(self.args)
        self.assertIsInstance(model.transform_std, T.Normalize)
        self.assertIsInstance(model.transform_mean, T.Normalize)

    def test_edge_cases(self):
        args = {
            'kd_model_name': 'resnet18',
            'kd_model_path': None,
            'num_classes': 10,
            'in_chans': 3
        }
        model = build_kd_model(args)
        self.assertIsNotNone(model.model)


if __name__ == '__main__':
    unittest.main()
# --------------------------------------------------------------------------------
# /utils/checkpoint_saver.py:
# --------------------------------------------------------------------------------
# [repository-dump residue: head of the checkpoint_saver module docstring, kept verbatim]
# 1 | """ Checkpoint Saver 2 | 3 | Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
# [repository-dump residue: end of the /utils/checkpoint_saver.py module docstring, kept verbatim]
# 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """

import glob
import operator
import os
import logging

import torch

from timm.utils.model import unwrap_model, get_state_dict

_logger = logging.getLogger(__name__)


class CheckpointSaverUSI:  # don't save optimizer state dict, since it forces specific repo structure
    """Keep the top-``max_history`` checkpoints by metric plus rolling recovery checkpoints.

    Saved state holds epoch, arch, model state dict, version and (optionally) metric
    and ``args``. Optimizer and AMP-scaler state are deliberately not saved. When
    ``model_ema`` is given and its metric is better, the EMA weights are stored as
    ``state_dict`` instead of the regular model's.
    """

    def __init__(
            self,
            model,
            optimizer,
            args=None,
            model_ema=None,
            amp_scaler=None,
            checkpoint_prefix='checkpoint',
            recovery_prefix='recovery',
            checkpoint_dir='',
            recovery_dir='',
            decreasing=False,
            max_history=10,
            unwrap_fn=unwrap_model):
        """
        Args:
            model: model whose state dict is saved.
            optimizer: stored on the instance; its state is not written to checkpoints.
            args: optional training args; when given, ``args.model`` is saved as ``arch``
                and the object itself under ``args``.
            model_ema: optional EMA model whose weights may replace the regular ones.
            amp_scaler: stored on the instance; its state is not written to checkpoints.
            checkpoint_prefix: filename prefix for top-N checkpoints.
            recovery_prefix: filename prefix for recovery checkpoints.
            checkpoint_dir: directory for top-N / last / best checkpoints.
            recovery_dir: directory for recovery checkpoints.
            decreasing: if True, a lower metric is better.
            max_history: number of top checkpoints to keep (must be >= 1).
            unwrap_fn: passed through to ``get_state_dict``.
        """

        # objects to save state_dicts of
        self.model = model
        self.optimizer = optimizer
        self.args = args
        self.model_ema = model_ema
        self.amp_scaler = amp_scaler

        # state
        self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
        self.best_epoch = None
        self.best_metric = None
        self.curr_recovery_file = ''
        self.last_recovery_file = ''

        # config
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.save_prefix = checkpoint_prefix
        self.recovery_prefix = recovery_prefix
        self.extension = '.pth.tar'
        self.decreasing = decreasing  # a lower metric is better if True
        self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
        self.max_history = max_history
        self.unwrap_fn = unwrap_fn
        assert self.max_history >= 1

    def save_checkpoint(self, epoch, metric=None, metric_ema=None):
        """Save ``last`` for ``epoch`` and update the top-N history and ``model_best``.

        The state is written to ``tmp`` and renamed to ``last``. It is then hard-linked as
        ``<prefix>-<epoch>`` when the history is not full, ``metric`` is None, or
        ``metric`` beats the worst tracked one (evicting the worst when full). It is
        also hard-linked as ``model_best`` when ``metric`` beats the best so far.

        NOTE(review): once more than one tracked entry has a None metric, the sort
        below compares None values and raises TypeError. Confirm callers always pass
        a metric.

        Returns:
            ``(best_metric, best_epoch)``, or ``(None, None)`` if no metric was recorded yet.
        """
        assert epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        self._save(tmp_save_path, epoch, metric, metric_ema)
        if os.path.exists(last_save_path):
            os.unlink(last_save_path)  # required for Windows support.
        os.rename(tmp_save_path, last_save_path)
        worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
        if (len(self.checkpoint_files) < self.max_history
                or metric is None or self.cmp(metric, worst_file[1])):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)
            filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
            save_path = os.path.join(self.checkpoint_dir, filename)
            os.link(last_save_path, save_path)
            self.checkpoint_files.append((save_path, metric))
            self.checkpoint_files = sorted(
                self.checkpoint_files, key=lambda x: x[1],
                reverse=not self.decreasing)  # sort in descending order if a lower metric is not better

            checkpoints_str = "Current checkpoints:\n"
            for c in self.checkpoint_files:
                checkpoints_str += ' {}\n'.format(c)
            _logger.info(checkpoints_str)

            if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
                self.best_epoch = epoch
                self.best_metric = metric
                best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
                if os.path.exists(best_save_path):
                    os.unlink(best_save_path)
                os.link(last_save_path, best_save_path)

        return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

    def _save(self, save_path, epoch, metric=None, metric_ema=None):
        """Serialize the checkpoint dict for ``epoch`` to ``save_path`` with ``torch.save``."""
        save_state = {
            'epoch': epoch,
            'arch': type(self.model).__name__.lower(),
            'state_dict': get_state_dict(self.model, self.unwrap_fn),
            # 'optimizer': self.optimizer.state_dict(),
            'version': 2,  # version < 2 increments epoch before save
        }
        if self.args is not None:
            save_state['arch'] = self.args.model
            save_state['args'] = self.args
        # if self.amp_scaler is not None:
        #     save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
        if self.model_ema is not None:
            # Save EMA weights instead of regular weights when EMA scored better.
            # self.cmp honours `decreasing` (a bare `>` picked the worse weights when lower
            # is better), and the None guard avoids a TypeError: save_recovery() passes
            # neither metric.
            if metric is not None and metric_ema is not None and self.cmp(metric_ema, metric):
                save_state['state_dict'] = get_state_dict(self.model_ema, self.unwrap_fn)
            # save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
        if metric is not None:
            save_state['metric'] = metric
        torch.save(save_state, save_path)

    def _cleanup_checkpoints(self, trim=0):
        """Delete tracked checkpoints beyond ``max_history - trim`` and drop them from the list."""
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                _logger.debug("Cleaning checkpoint: {}".format(d))
                os.remove(d[0])
            except Exception as e:
                # best-effort deletion: log and keep going
                _logger.error("Exception '{}' while deleting checkpoint".format(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def save_recovery(self, epoch, batch_idx=0):
        """Write ``<recovery_prefix>-<epoch>-<batch_idx>`` and delete the one from two saves ago."""
        assert epoch >= 0
        filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
        save_path = os.path.join(self.recovery_dir, filename)
        self._save(save_path, epoch)
        if os.path.exists(self.last_recovery_file):
            try:
                _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
                os.remove(self.last_recovery_file)
            except Exception as e:
                _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
        self.last_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        """Return the first recovery file in lexicographic order, or '' if none exist.

        NOTE(review): string sorting puts e.g. 'recovery-10-0' before 'recovery-9-0';
        confirm that returning files[0] is the intended pick.
        """
        recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        return files[0] if len(files) else ''
# --------------------------------------------------------------------------------