├── LICENSE
├── MODEL_ZOO.md
├── README.md
├── Speed_Accuracy_Comparisons.md
├── changes.txt
├── distributed_train.sh
├── inference.py
├── kd
│   ├── __init__.py
│   ├── helpers.py
│   └── kd_utils.py
├── models
│   ├── __init__.py
│   ├── beit.py
│   ├── byoanet.py
│   ├── byobnet.py
│   ├── cait.py
│   ├── coat.py
│   ├── convit.py
│   ├── convmixer.py
│   ├── convnext.py
│   ├── crossvit.py
│   ├── cspnet.py
│   ├── deit.py
│   ├── densenet.py
│   ├── dla.py
│   ├── dpn.py
│   ├── efficientnet.py
│   ├── efficientnet_blocks.py
│   ├── efficientnet_builder.py
│   ├── factory.py
│   ├── features.py
│   ├── fx_features.py
│   ├── ghostnet.py
│   ├── gluon_resnet.py
│   ├── gluon_xception.py
│   ├── hardcorenas.py
│   ├── helpers.py
│   ├── hrnet.py
│   ├── hub.py
│   ├── inception_resnet_v2.py
│   ├── inception_v3.py
│   ├── inception_v4.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── activations_jit.py
│   │   ├── activations_me.py
│   │   ├── adaptive_avgmax_pool.py
│   │   ├── attention_pool2d.py
│   │   ├── blur_pool.py
│   │   ├── bottleneck_attn.py
│   │   ├── cbam.py
│   │   ├── classifier.py
│   │   ├── cond_conv2d.py
│   │   ├── config.py
│   │   ├── conv2d_same.py
│   │   ├── conv_bn_act.py
│   │   ├── create_act.py
│   │   ├── create_attn.py
│   │   ├── create_conv2d.py
│   │   ├── create_norm_act.py
│   │   ├── drop.py
│   │   ├── eca.py
│   │   ├── evo_norm.py
│   │   ├── filter_response_norm.py
│   │   ├── gather_excite.py
│   │   ├── global_context.py
│   │   ├── halo_attn.py
│   │   ├── helpers.py
│   │   ├── inplace_abn.py
│   │   ├── lambda_layer.py
│   │   ├── linear.py
│   │   ├── median_pool.py
│   │   ├── mixed_conv2d.py
│   │   ├── ml_decoder.py
│   │   ├── mlp.py
│   │   ├── non_local_attn.py
│   │   ├── norm.py
│   │   ├── norm_act.py
│   │   ├── padding.py
│   │   ├── patch_embed.py
│   │   ├── pool2d_same.py
│   │   ├── pos_embed.py
│   │   ├── selective_kernel.py
│   │   ├── separable_conv.py
│   │   ├── space_to_depth.py
│   │   ├── split_attn.py
│   │   ├── split_batchnorm.py
│   │   ├── squeeze_excite.py
│   │   ├── std_conv.py
│   │   ├── test_time_pool.py
│   │   ├── trace_utils.py
│   │   └── weight_init.py
│   ├── levit.py
│   ├── mlp_mixer.py
│   ├── mobilenetv3.py
│   ├── mobilevit.py
│   ├── nasnet.py
│   ├── nest.py
│   ├── nfnet.py
│   ├── pit.py
│   ├── pnasnet.py
│   ├── poolformer.py
│   ├── pruned
│   │   ├── ecaresnet101d_pruned.txt
│   │   ├── ecaresnet50d_pruned.txt
│   │   ├── efficientnet_b1_pruned.txt
│   │   ├── efficientnet_b2_pruned.txt
│   │   └── efficientnet_b3_pruned.txt
│   ├── registry.py
│   ├── regnet.py
│   ├── res2net.py
│   ├── resnest.py
│   ├── resnet.py
│   ├── resnetv2.py
│   ├── rexnet.py
│   ├── selecsls.py
│   ├── senet.py
│   ├── sknet.py
│   ├── swin_transformer.py
│   ├── swin_transformer_v2_cr.py
│   ├── tnt.py
│   ├── tresnet.py
│   ├── tresnet_v2.py
│   ├── twins.py
│   ├── vgg.py
│   ├── visformer.py
│   ├── vision_transformer.py
│   ├── vision_transformer_hybrid.py
│   ├── volo.py
│   ├── vovnet.py
│   ├── xception.py
│   ├── xception_aligned.py
│   └── xcit.py
├── pics
│   ├── CPU.png
│   ├── GPU.png
│   ├── pic1.png
│   ├── pic2.png
│   └── pic3.png
├── requirements.txt
├── test
│   └── test_build_kd_model.py
├── train.py
├── utils
│   └── checkpoint_saver.py
└── validate.py
/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | # Model Zoo
2 |
3 | With USI, we can reliably identify models that provide a [good speed-accuracy trade-off](Speed_Accuracy_Comparisons.md).
4 |
5 | For those top models, we provide here weights from large-scale pretraining on ImageNet-21K. We recommend using the large-scale weights for transfer learning - they almost always provide superior transfer results compared to 1K weights.
6 |
7 | | Backbone | 21K Single-label Pretraining weights | 21K Multi-label Pretraining weights | ImageNet-1K Accuracy [%] |
8 | | :------------: | :--------------: | :--------------: | :--------------: |
9 | | **TResNet-L** | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/tresnet_l_v2/single_label_ls.pth) | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/tresnet_l_v2/multi_label_ls.pth) | [83.9](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/USI/tresnet_l_v2_83_9.pth) |
10 | | **TResNet-M** | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/tresnet_m/single_label_ls.pth) | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/tresnet_m/multi_label_ls.pth) | 82.5 |
11 | | **ResNet50** | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/resnet50/single_label_ls.pth) | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/resnet50/multi_label_ls.pth) | 81.0 |
12 | | **MobileNetV3_Large_100** | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/mobilenetv3_large_100/single_label_ls.pth) | N/A | 77.3 |
13 | | **LeViT-384** | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/levit_384/single_label_ls.pth) | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/levit_384/multi_label_ls.pth) | 82.7 |
14 | | **LeViT-768** | [Link](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/levit_768/single_label_ls.pth) | N/A | [84.2](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/unified/levit_768/levit_768_84_2.pth) |
15 | | **[EdgeNeXt-S](https://arxiv.org/abs/2206.10589)** | N/A | N/A | [81.1](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth) |
16 |
17 |
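Below is a minimal sketch, not from the repo, of loading the 21K single-label weights into a backbone for transfer learning via this repo's `create_model`. The checkpoint path, the target class count, and the `head.` classifier-key prefix are assumptions; exact key names vary per backbone.

```
import torch
from models import create_model

model = create_model('tresnet_m', num_classes=10)  # num_classes = your downstream task

checkpoint = torch.load('./single_label_ls.pth', map_location='cpu')  # downloaded from the table above
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
# drop the 21K classifier head; its shape will not match the new task
state_dict = {k: v for k, v in state_dict.items() if not k.startswith('head.')}
model.load_state_dict(state_dict, strict=False)
```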
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Solving ImageNet: a Unified Scheme for Training any Backbone to Top Results
2 |
3 | Official PyTorch Implementation
4 |
5 | [Paper](http://arxiv.org/abs/2204.03475) | [Model Zoo](MODEL_ZOO.md) | [Speed-Accuracy Comparisons](Speed_Accuracy_Comparisons.md)
6 | > Tal Ridnik, Hussam Lawen, Emanuel Ben-Baruch, Asaf Noy (DAMO Academy, Alibaba Group)
7 |
8 | **Abstract**
9 |
10 | ImageNet serves as the primary dataset for evaluating the quality of computer-vision models. The common practice today is training each architecture with a tailor-made scheme, designed and tuned by an expert.
11 | In this paper, we present a unified scheme for training any backbone on ImageNet. The scheme, named USI (Unified Scheme for ImageNet), is based on knowledge distillation and modern tricks. It requires no adjustments or hyper-parameter tuning between different models, and is efficient in terms of training time.
12 | We test USI on a wide variety of architectures, including CNNs, Transformers, mobile-oriented and MLP-only designs. On all models tested, USI outperforms previous state-of-the-art results. Hence, we are able to transform training on ImageNet from an expert-oriented task to an automatic, seamless routine.
13 | Since USI accepts any backbone and trains it to top results, it also enables methodical comparisons, identifying the most efficient backbones along the speed-accuracy Pareto curve.
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | ## 11/1/2023 Update
24 | Added [tests](https://github.com/Alibaba-MIIL/Solving_ImageNet/blob/main/test/test_build_kd_model.py), auto-generated by the [CodiumAI](https://www.codium.ai/) tool.
25 |
26 | ## How to Train on ImageNet with USI scheme
27 | The proposed USI scheme does not require hyper-parameter tuning. The base training configuration works well for any backbone.
28 | All the results presented in the paper are fully reproducible.
29 |
30 | First, download the teacher model weights from [here](https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/USI/tresnet_l_v2_83_9.pth).
31 |
32 | Example code - training a ResNet50 model with USI:
33 | ```
34 | python3 -u -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train.py \
35 | /mnt/Imagenet/ \
36 | --model=resnet50 \
37 | --kd_model_name=tresnet_l_v2 \
38 | --kd_model_path=./tresnet_l_v2_83_9.pth
39 | ```
40 |
41 | Some additional degrees of freedom that might be useful (combined in the example after this list):
42 |
43 | - Adjusting the batch size (default - 128): ```--batch-size=...```
44 | - Training for more epochs (default - 300): ```--epochs=...```
45 |
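For example, combining both flags (the ImageNet path and the values below are illustrative):

```
python3 -u -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 train.py \
/mnt/Imagenet/ \
--model=resnet50 \
--kd_model_name=tresnet_l_v2 \
--kd_model_path=./tresnet_l_v2_83_9.pth \
--batch-size=64 \
--epochs=400
```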
46 |
47 | ## Acknowledgements
48 |
49 | The training code is based on the excellent [timm repository](https://github.com/rwightman/pytorch-image-models). Also, thanks to the [EdgeNeXt](https://arxiv.org/pdf/2206.10589.pdf) authors for sharing their model.
50 |
51 | ## Citation
52 | ```
53 | @misc{https://doi.org/10.48550/arxiv.2204.03475,
54 | doi = {10.48550/ARXIV.2204.03475},
55 | url = {https://arxiv.org/abs/2204.03475},
56 | author = {Ridnik, Tal and Lawen, Hussam and Ben-Baruch, Emanuel and Noy, Asaf},
57 | keywords = {Computer Vision and Pattern Recognition (cs.CV), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
58 | title = {Solving ImageNet: a Unified Scheme for Training any Backbone to Top Results},
59 | publisher = {arXiv},
60 | year = {2022},
61 | }
62 | ```
63 |
--------------------------------------------------------------------------------
/Speed_Accuracy_Comparisons.md:
--------------------------------------------------------------------------------
1 | # GPU TensorRT Speed-Accuracy Comparison
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | ### Analysis and insights for GPU inference results:
11 |
12 | - For low-to-medium accuracies, the most efficient models are LeViT-256 and LeViT-384
13 | - For higher accuracies (> 83.5%), TResNet-L and LeViT-768* provide the best trade-off among the models tested
14 | - Other transformer models (Swin, TnT and ViT) provide an inferior speed-accuracy trade-off compared to modern CNNs. Mobile-oriented models also do not excel at GPU inference
15 | - Note that several modern architectures titled "Small" (ConvNext-S, Swin-S, TnT-S) are in fact quite resource-intensive - their inference is more than four times slower than a plain ResNet50's
16 |
17 |
18 | # CPU OpenVino Speed-Accuracy Comparison
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | ### Analysis and insights for CPU inference results:
28 | - On CPU, mobile-oriented models (OFA, MobileNet, EfficientNet) provide the best speed-accuracy trade-off
29 | - LeViT models, which excelled at GPU inference, are not as efficient for CPU inference.
30 | - Other transformer models (Swin, TnT and ViT) again provide an inferior speed-accuracy trade-off.
31 |
32 |
33 |
34 | # Implementation Details
35 |
36 | On GPU, the throughput was measured with the TensorRT inference engine, FP16, and a batch size of 256. On CPU, the throughput was measured using Intel's
37 | OpenVINO Inference Engine, FP16, a batch size of 1, and 16 streams (equal to the number of CPU cores). All measurements were taken after the models were optimized for inference via batch-norm fusion. This significantly accelerates
38 | models that utilize batch-norm, such as ResNet50, TResNet and LeViT. Note that the LeViT-768* model is not part of the original paper; we defined it to test the LeViT design in the higher-accuracy regime.
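The batch-norm fusion mentioned above mirrors the snippet in `changes.txt`; a minimal sketch of preparing a model for benchmarking, using the helpers from this repo's `kd` package:

```
from models import create_model
from kd.helpers import InplacABN_to_ABN, fuse_bn2d_bn1d_abn

model = create_model('tresnet_m')
model.cpu().eval()
# per the snippet in changes.txt: convert InplaceABN to ABN, then fuse batch-norm
model = InplacABN_to_ABN(model)
model = fuse_bn2d_bn1d_abn(model)
```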
--------------------------------------------------------------------------------
/changes.txt:
--------------------------------------------------------------------------------
1 | - loading KD model, line 377
2 |
3 | model_KD = build_kd_model(args)
4 |
5 | - KD logic, line 714
6 |
7 | # KD logic
8 | if model_KD is not None:
9 |
10 | - line 652. choose the top metric, instead of EMA only
11 |
12 | if ema_eval_metrics[eval_metric] > eval_metrics[eval_metric]: # choose the best model
13 | eval_metrics_unite = ema_eval_metrics
14 |
15 | - created a new saver class, CheckpointSaverUSI.
16 | CheckpointSaverUSI solves two issues:
17 | 1) it removes saving the optimizer state, which causes problems if the directory structure is changed
18 | 2) it saves the correct model (EMA vs. regular), e.g.:
19 | save_state['state_dict'] = get_state_dict(self.model_ema, self.unwrap_fn)
20 |
21 | - add tresnet_l_v2 and volo models
22 |
23 | - in validate.py, perform bn fusion:
24 |
25 | model.cpu().eval()
26 | from kd.helpers import InplacABN_to_ABN, fuse_bn2d_bn1d_abn
27 | model = InplacABN_to_ABN(model)
28 | model = fuse_bn2d_bn1d_abn(model)
--------------------------------------------------------------------------------
/distributed_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1
3 | shift
4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@"
5 |
6 |
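# Usage sketch (paths and flags are illustrative; everything after the GPU count is forwarded to train.py):
#   ./distributed_train.sh 8 /mnt/Imagenet/ --model=resnet50 --kd_model_name=tresnet_l_v2 --kd_model_path=./tresnet_l_v2_83_9.pth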
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """PyTorch Inference Script
3 |
4 | An example inference script that outputs top-k class ids for images in a folder into a csv.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
7 | """
8 | import os
9 | import time
10 | import argparse
11 | import logging
12 | import numpy as np
13 | import torch
14 |
15 | from models import create_model, apply_test_time_pool
16 | from timm.data import ImageDataset, create_loader, resolve_data_config
17 | from timm.utils import AverageMeter, setup_default_logging
18 |
19 | torch.backends.cudnn.benchmark = True
20 | _logger = logging.getLogger('inference')
21 |
22 |
23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
24 | parser.add_argument('data', metavar='DIR',
25 | help='path to dataset')
26 | parser.add_argument('--output_dir', metavar='DIR', default='./',
27 | help='path to output files')
28 | parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
29 | help='model architecture (default: dpn92)')
30 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
31 | help='number of data loading workers (default: 2)')
32 | parser.add_argument('-b', '--batch-size', default=256, type=int,
33 | metavar='N', help='mini-batch size (default: 256)')
34 | parser.add_argument('--img-size', default=None, type=int,
35 | metavar='N', help='Input image dimension')
36 | parser.add_argument('--input-size', default=None, nargs=3, type=int,
37 | metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
38 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
39 | help='Override mean pixel value of dataset')
40 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
41 | help='Override std deviation of dataset')
42 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
43 | help='Image resize interpolation type (overrides model)')
44 | parser.add_argument('--num-classes', type=int, default=1000,
45 | help='Number classes in dataset')
46 | parser.add_argument('--log-freq', default=10, type=int,
47 | metavar='N', help='batch logging frequency (default: 10)')
48 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
49 | help='path to latest checkpoint (default: none)')
50 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
51 | help='use pre-trained model')
52 | parser.add_argument('--num-gpu', type=int, default=1,
53 | help='Number of GPUS to use')
54 | parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
55 | help='disable test time pool')
56 | parser.add_argument('--topk', default=5, type=int,
57 | metavar='N', help='Top-k to output to CSV')
58 |
59 |
60 | def main():
61 | setup_default_logging()
62 | args = parser.parse_args()
63 | # might as well try to do something useful...
64 | args.pretrained = args.pretrained or not args.checkpoint
65 |
66 | # create model
67 | model = create_model(
68 | args.model,
69 | num_classes=args.num_classes,
70 | in_chans=3,
71 | pretrained=args.pretrained,
72 | checkpoint_path=args.checkpoint)
73 |
74 | _logger.info('Model %s created, param count: %d' %
75 | (args.model, sum([m.numel() for m in model.parameters()])))
76 |
77 | config = resolve_data_config(vars(args), model=model)
78 | model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
79 |
80 | if args.num_gpu > 1:
81 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
82 | else:
83 | model = model.cuda()
84 |
85 | loader = create_loader(
86 | ImageDataset(args.data),
87 | input_size=config['input_size'],
88 | batch_size=args.batch_size,
89 | use_prefetcher=True,
90 | interpolation=config['interpolation'],
91 | mean=config['mean'],
92 | std=config['std'],
93 | num_workers=args.workers,
94 | crop_pct=1.0 if test_time_pool else config['crop_pct'])
95 |
96 | model.eval()
97 |
98 | k = min(args.topk, args.num_classes)
99 | batch_time = AverageMeter()
100 | end = time.time()
101 | topk_ids = []
102 | with torch.no_grad():
103 | for batch_idx, (input, _) in enumerate(loader):
104 | input = input.cuda()
105 | labels = model(input)
106 | topk = labels.topk(k)[1]
107 | topk_ids.append(topk.cpu().numpy())
108 |
109 | # measure elapsed time
110 | batch_time.update(time.time() - end)
111 | end = time.time()
112 |
113 | if batch_idx % args.log_freq == 0:
114 | _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
115 | batch_idx, len(loader), batch_time=batch_time))
116 |
117 | topk_ids = np.concatenate(topk_ids, axis=0)
118 |
119 | with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
120 | filenames = loader.dataset.filenames(basename=True)
121 | for filename, label in zip(filenames, topk_ids):
122 | out_file.write('{0},{1}\n'.format(
123 | filename, ','.join([str(v) for v in label])))
124 |
125 |
126 | if __name__ == '__main__':
127 | main()
128 |
--------------------------------------------------------------------------------
/kd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/kd/__init__.py
--------------------------------------------------------------------------------
/kd/kd_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from models.factory import create_model
3 | from kd.helpers import fuse_bn2d_bn1d_abn, InplacABN_to_ABN
4 | import torchvision.transforms as T
5 |
6 |
7 | class build_kd_model(nn.Module):
8 | def __init__(self, args=None):
9 | super(build_kd_model, self).__init__()
10 |
11 | model_kd = create_model(
12 | model_name=args.kd_model_name,
13 | checkpoint_path=args.kd_model_path,
14 | # pretrained=False,
15 | pretrained=args.kd_model_path is None,
16 | num_classes=args.num_classes,
17 | in_chans=3)
18 |
19 | model_kd.cpu().eval()
20 | model_kd = InplacABN_to_ABN(model_kd)
21 | model_kd = fuse_bn2d_bn1d_abn(model_kd)
22 | self.model = model_kd.cuda().eval()
23 | self.mean_model_kd = model_kd.default_cfg['mean']
24 | self.std_model_kd = model_kd.default_cfg['std']
25 |
26 | # handling different normalization of teacher and student
27 | def normalize_input(self, input, student_model):
28 | if hasattr(student_model, 'module'):
29 | model_s = student_model.module
30 | else:
31 | model_s = student_model
32 |
33 | mean_student = model_s.default_cfg['mean']
34 | std_student = model_s.default_cfg['std']
35 |
36 | input_kd = input
37 | if mean_student != self.mean_model_kd or std_student != self.std_model_kd:
38 | std = (self.std_model_kd[0] / std_student[0], self.std_model_kd[1] / std_student[1],
39 | self.std_model_kd[2] / std_student[2])
40 | transform_std = T.Normalize(mean=(0, 0, 0), std=std)
41 |
42 | mean = (self.mean_model_kd[0] - mean_student[0], self.mean_model_kd[1] - mean_student[1],
43 | self.mean_model_kd[2] - mean_student[2])
44 | transform_mean = T.Normalize(mean=mean, std=(1, 1, 1))
45 |
46 | input_kd = transform_mean(transform_std(input))
47 |
48 | return input_kd
49 |
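# --- Usage sketch (not part of the original file) ---
# How the teacher could be applied inside a training step, following the notes
# in changes.txt ("KD logic"). The soft-label KL loss and the unit weighting
# are common distillation choices and assumptions here, not necessarily the
# paper's exact loss.
import torch
import torch.nn.functional as F


def kd_step_sketch(model_KD, student, input, loss_ce):
    input_kd = model_KD.normalize_input(input, student)  # re-normalize input for the teacher
    with torch.no_grad():
        teacher_logits = model_KD.model(input_kd)
    student_logits = student(input)
    loss_kd = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction='batchmean')
    return loss_ce + loss_kd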
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .beit import *
2 | from .byoanet import *
3 | from .byobnet import *
4 | from .cait import *
5 | from .coat import *
6 | from .convit import *
7 | from .convmixer import *
8 | from .convnext import *
9 | from .crossvit import *
10 | from .cspnet import *
11 | from .deit import *
12 | from .densenet import *
13 | from .dla import *
14 | from .dpn import *
15 | from .efficientnet import *
16 | from .ghostnet import *
17 | from .gluon_resnet import *
18 | from .gluon_xception import *
19 | from .hardcorenas import *
20 | from .hrnet import *
21 | from .inception_resnet_v2 import *
22 | from .inception_v3 import *
23 | from .inception_v4 import *
24 | from .levit import *
25 | from .mlp_mixer import *
26 | from .mobilenetv3 import *
27 | from .mobilevit import *
28 | from .nasnet import *
29 | from .nest import *
30 | from .nfnet import *
31 | from .pit import *
32 | from .pnasnet import *
33 | from .poolformer import *
34 | from .regnet import *
35 | from .res2net import *
36 | from .resnest import *
37 | from .resnet import *
38 | from .resnetv2 import *
39 | from .rexnet import *
40 | from .selecsls import *
41 | from .senet import *
42 | from .sknet import *
43 | from .swin_transformer import *
44 | from .swin_transformer_v2_cr import *
45 | from .tnt import *
46 | from .tresnet import *
47 | from .tresnet_v2 import *
48 | from .twins import *
49 | from .vgg import *
51 | from .visformer import *
52 | from .vision_transformer import *
53 | from .vision_transformer_hybrid import *
54 | from .volo import *
55 | from .vovnet import *
56 | from .xception import *
57 | from .xception_aligned import *
58 | from .xcit import *
59 |
60 | from .factory import create_model, parse_model_name, safe_model_name
61 | from .helpers import load_checkpoint, resume_checkpoint, model_parameters
62 | from .layers import TestTimePoolHead, apply_test_time_pool
63 | from .layers import convert_splitbn_model
64 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
65 | from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
66 | is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
67 |
--------------------------------------------------------------------------------
/models/convmixer.py:
--------------------------------------------------------------------------------
1 | """ ConvMixer
2 |
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
8 | from models.registry import register_model
9 | from .helpers import build_model_with_cfg, checkpoint_seq
10 | from .layers import SelectAdaptivePool2d
11 |
12 |
13 | def _cfg(url='', **kwargs):
14 | return {
15 | 'url': url,
16 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
17 | 'crop_pct': .96, 'interpolation': 'bicubic',
18 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
19 | 'first_conv': 'stem.0',
20 | **kwargs
21 | }
22 |
23 |
24 | default_cfgs = {
25 | 'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'),
26 | 'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'),
27 | 'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar')
28 | }
29 |
30 |
31 | class Residual(nn.Module):
32 | def __init__(self, fn):
33 | super().__init__()
34 | self.fn = fn
35 |
36 | def forward(self, x):
37 | return self.fn(x) + x
38 |
39 |
40 | class ConvMixer(nn.Module):
41 | def __init__(
42 | self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg',
43 | act_layer=nn.GELU, **kwargs):
44 | super().__init__()
45 | self.num_classes = num_classes
46 | self.num_features = dim
47 | self.grad_checkpointing = False
48 |
49 | self.stem = nn.Sequential(
50 | nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size),
51 | act_layer(),
52 | nn.BatchNorm2d(dim)
53 | )
54 | self.blocks = nn.Sequential(
55 | *[nn.Sequential(
56 | Residual(nn.Sequential(
57 | nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
58 | act_layer(),
59 | nn.BatchNorm2d(dim)
60 | )),
61 | nn.Conv2d(dim, dim, kernel_size=1),
62 | act_layer(),
63 | nn.BatchNorm2d(dim)
64 | ) for i in range(depth)]
65 | )
66 | self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
67 | self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
68 |
69 | @torch.jit.ignore
70 | def group_matcher(self, coarse=False):
71 | matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)')
72 | return matcher
73 |
74 | @torch.jit.ignore
75 | def set_grad_checkpointing(self, enable=True):
76 | self.grad_checkpointing = enable
77 |
78 | @torch.jit.ignore
79 | def get_classifier(self):
80 | return self.head
81 |
82 | def reset_classifier(self, num_classes, global_pool=None):
83 | self.num_classes = num_classes
84 | if global_pool is not None:
85 | self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
86 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
87 |
88 | def forward_features(self, x):
89 | x = self.stem(x)
90 | if self.grad_checkpointing and not torch.jit.is_scripting():
91 | x = checkpoint_seq(self.blocks, x)
92 | else:
93 | x = self.blocks(x)
94 | return x
95 |
96 | def forward_head(self, x, pre_logits: bool = False):
97 | x = self.pooling(x)
98 | return x if pre_logits else self.head(x)
99 |
100 | def forward(self, x):
101 | x = self.forward_features(x)
102 | x = self.forward_head(x)
103 | return x
104 |
105 |
106 | def _create_convmixer(variant, pretrained=False, **kwargs):
107 | return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs)
108 |
109 |
110 | @register_model
111 | def convmixer_1536_20(pretrained=False, **kwargs):
112 | model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs)
113 | return _create_convmixer('convmixer_1536_20', pretrained, **model_args)
114 |
115 |
116 | @register_model
117 | def convmixer_768_32(pretrained=False, **kwargs):
118 | model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs)
119 | return _create_convmixer('convmixer_768_32', pretrained, **model_args)
120 |
121 |
122 | @register_model
123 | def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs):
124 | model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs)
125 | return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args)
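# Usage sketch: the @register_model entrypoints above expose these variants
# through the factory, e.g.:
#
#   from models import create_model
#   model = create_model('convmixer_768_32')  # pretrained=True would pull the URL from default_cfgs
#   n_params = sum(p.numel() for p in model.parameters())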
--------------------------------------------------------------------------------
/models/factory.py:
--------------------------------------------------------------------------------
1 | from urllib.parse import urlsplit, urlunsplit
2 | import os
3 |
4 | from .registry import is_model, is_model_in_modules, model_entrypoint
5 | from .helpers import load_checkpoint
6 | from .layers import set_layer_config
7 | from .hub import load_model_config_from_hf
8 |
9 |
10 | def parse_model_name(model_name):
11 | model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use
12 | parsed = urlsplit(model_name)
13 | assert parsed.scheme in ('', 'timm', 'hf-hub')
14 | if parsed.scheme == 'hf-hub':
15 | # FIXME may use fragment as revision, currently `@` in URI path
16 | return parsed.scheme, parsed.path
17 | else:
18 | model_name = os.path.split(parsed.path)[-1]
19 | return 'timm', model_name
20 |
21 |
22 | def safe_model_name(model_name, remove_source=True):
23 | def make_safe(name):
24 | return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
25 | if remove_source:
26 | model_name = parse_model_name(model_name)[-1]
27 | return make_safe(model_name)
28 |
29 |
30 | def create_model(
31 | model_name,
32 | pretrained=False,
33 | pretrained_cfg=None,
34 | checkpoint_path='',
35 | scriptable=None,
36 | exportable=None,
37 | no_jit=None,
38 | **kwargs):
39 | """Create a model
40 |
41 | Args:
42 | model_name (str): name of model to instantiate
43 | pretrained (bool): load pretrained ImageNet-1k weights if true
44 | checkpoint_path (str): path of checkpoint to load after model is initialized
45 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
46 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
47 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
48 |
49 | Keyword Args:
50 | drop_rate (float): dropout rate for training (default: 0.0)
51 | global_pool (str): global pool type (default: 'avg')
52 | **: other kwargs are model specific
53 | """
54 | # Parameters that aren't supported by all models or are intended to only override model defaults if set
55 | # should default to None in command line args/cfg. Remove them if they are present and not set so that
56 | # non-supporting models don't break and default args remain in effect.
57 | kwargs = {k: v for k, v in kwargs.items() if v is not None}
58 |
59 | model_source, model_name = parse_model_name(model_name)
60 | if model_source == 'hf-hub':
61 | # FIXME hf-hub source overrides any passed in pretrained_cfg, warn?
62 | # For model names specified in the form `hf-hub:path/architecture_name@revision`,
63 | # load model weights + pretrained_cfg from Hugging Face hub.
64 | pretrained_cfg, model_name = load_model_config_from_hf(model_name)
65 |
66 | if not is_model(model_name):
67 | raise RuntimeError('Unknown model (%s)' % model_name)
68 |
69 | create_fn = model_entrypoint(model_name)
70 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
71 | model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs)
72 |
73 | if checkpoint_path:
74 | load_checkpoint(model, checkpoint_path)
75 |
76 | return model
77 |
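# Usage sketch for create_model (the hf-hub id below is illustrative, not a
# guaranteed-to-exist repository):
#
#   from models import create_model
#   model = create_model('resnet50', pretrained=False, num_classes=1000)
#   # hf-hub scheme: pretrained_cfg + weights are resolved from the Hugging Face hub
#   model = create_model('hf-hub:some-org/some-model', pretrained=True)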
--------------------------------------------------------------------------------
/models/fx_features.py:
--------------------------------------------------------------------------------
1 | """ PyTorch FX Based Feature Extraction Helpers
2 | Using https://pytorch.org/vision/stable/feature_extraction.html
3 | """
4 | from typing import Callable, List, Dict, Union
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from .features import _get_feature_info
10 |
11 | try:
12 | from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
13 | has_fx_feature_extraction = True
14 | except ImportError:
15 | has_fx_feature_extraction = False
16 |
17 | # Layers we want to treat as leaf modules
18 | from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
19 | from .layers.non_local_attn import BilinearAttnTransform
20 | from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
21 |
22 | # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
23 | # BUT modules from timm.models should use the registration mechanism below
24 | _leaf_modules = {
25 | BilinearAttnTransform, # reason: flow control t <= 1
26 | # Reason: get_same_padding has a max which raises a control flow error
27 | Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
28 | CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
29 | }
30 |
31 | try:
32 | from .layers import InplaceAbn
33 | _leaf_modules.add(InplaceAbn)
34 | except ImportError:
35 | pass
36 |
37 |
38 | def register_notrace_module(module: nn.Module):
39 | """
40 | Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
41 | """
42 | _leaf_modules.add(module)
43 | return module
44 |
45 |
46 | # Functions we want to autowrap (treat them as leaves)
47 | _autowrap_functions = set()
48 |
49 |
50 | def register_notrace_function(func: Callable):
51 | """
52 | Decorator for functions which ought not to be traced through
53 | """
54 | _autowrap_functions.add(func)
55 | return func
56 |
57 |
58 | def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
59 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
60 | return _create_feature_extractor(
61 | model, return_nodes,
62 | tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
63 | )
64 |
65 |
66 | class FeatureGraphNet(nn.Module):
67 | """ A FX Graph based feature extractor that works with the model feature_info metadata
68 | """
69 | def __init__(self, model, out_indices, out_map=None):
70 | super().__init__()
71 | assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
72 | self.feature_info = _get_feature_info(model, out_indices)
73 | if out_map is not None:
74 | assert len(out_map) == len(out_indices)
75 | return_nodes = {
76 | info['module']: out_map[i] if out_map is not None else info['module']
77 | for i, info in enumerate(self.feature_info) if i in out_indices}
78 | self.graph_module = create_feature_extractor(model, return_nodes)
79 |
80 | def forward(self, x):
81 | return list(self.graph_module(x).values())
82 |
83 |
84 | class GraphExtractNet(nn.Module):
85 | """ A standalone feature extraction wrapper that maps dict -> list or single tensor
86 | NOTE:
87 | * one can use feature_extractor directly if dictionary output is desired
88 | * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
89 | metadata for builtin feature extraction mode
90 | * create_feature_extractor can be used directly if dictionary output is acceptable
91 |
92 | Args:
93 | model: model to extract features from
94 | return_nodes: node names to return features from (dict or list)
95 | squeeze_out: if only one output, and output in list format, flatten to single tensor
96 | """
97 | def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
98 | super().__init__()
99 | self.squeeze_out = squeeze_out
100 | self.graph_module = create_feature_extractor(model, return_nodes)
101 |
102 | def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
103 | out = list(self.graph_module(x).values())
104 | if self.squeeze_out and len(out) == 1:
105 | return out[0]
106 | return out
107 |
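# Usage sketch for GraphExtractNet (node names are illustrative; use
# torchvision's get_graph_node_names(model) to list the valid ones):
#
#   import torch
#   from models import create_model
#   model = create_model('resnet50')
#   extractor = GraphExtractNet(model, return_nodes=['layer1', 'layer4'])
#   feats = extractor(torch.randn(1, 3, 224, 224))  # list of two feature tensors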
--------------------------------------------------------------------------------
/models/hub.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from functools import partial
5 | from pathlib import Path
6 | from typing import Union
7 |
8 | import torch
9 | from torch.hub import HASH_REGEX, download_url_to_file, urlparse
10 | try:
11 | from torch.hub import get_dir
12 | except ImportError:
13 | from torch.hub import _get_torch_home as get_dir
14 |
15 | from timm import __version__
16 | try:
17 | from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
18 | cached_download = partial(cached_download, library_name="timm", library_version=__version__)
19 | _has_hf_hub = True
20 | except ImportError:
21 | cached_download = None
22 | _has_hf_hub = False
23 |
24 | _logger = logging.getLogger(__name__)
25 |
26 |
27 | def get_cache_dir(child_dir=''):
28 | """
29 | Returns the location of the directory where models are cached (and creates it if necessary).
30 | """
31 | # Issue warning to move data if old env is set
32 | if os.getenv('TORCH_MODEL_ZOO'):
33 | _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
34 |
35 | hub_dir = get_dir()
36 | child_dir = () if not child_dir else (child_dir,)
37 | model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
38 | os.makedirs(model_dir, exist_ok=True)
39 | return model_dir
40 |
41 |
42 | def download_cached_file(url, check_hash=True, progress=False):
43 | parts = urlparse(url)
44 | filename = os.path.basename(parts.path)
45 | cached_file = os.path.join(get_cache_dir(), filename)
46 | if not os.path.exists(cached_file):
47 | _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
48 | hash_prefix = None
49 | if check_hash:
50 | r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
51 | hash_prefix = r.group(1) if r else None
52 | download_url_to_file(url, cached_file, hash_prefix, progress=progress)
53 | return cached_file
54 |
55 |
56 | def has_hf_hub(necessary=False):
57 | if not _has_hf_hub and necessary:
58 | # if no HF Hub module installed and it is necessary to continue, raise error
59 | raise RuntimeError(
60 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
61 | return _has_hf_hub
62 |
63 |
64 | def hf_split(hf_id):
65 | # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
66 | rev_split = hf_id.split('@')
67 | assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
68 | hf_model_id = rev_split[0]
69 | hf_revision = rev_split[-1] if len(rev_split) > 1 else None
70 | return hf_model_id, hf_revision
71 |
72 |
73 | def load_cfg_from_json(json_file: Union[str, os.PathLike]):
74 | with open(json_file, "r", encoding="utf-8") as reader:
75 | text = reader.read()
76 | return json.loads(text)
77 |
78 |
79 | def _download_from_hf(model_id: str, filename: str):
80 | hf_model_id, hf_revision = hf_split(model_id)
81 | url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
82 | return cached_download(url, cache_dir=get_cache_dir('hf'))
83 |
84 |
85 | def load_model_config_from_hf(model_id: str):
86 | assert has_hf_hub(True)
87 | cached_file = _download_from_hf(model_id, 'config.json')
88 | pretrained_cfg = load_cfg_from_json(cached_file)
89 | pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
90 | pretrained_cfg['source'] = 'hf-hub'
91 | model_name = pretrained_cfg.get('architecture')
92 | return pretrained_cfg, model_name
93 |
94 |
95 | def load_state_dict_from_hf(model_id: str):
96 | assert has_hf_hub(True)
97 | cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
98 | state_dict = torch.load(cached_file, map_location='cpu')
99 | return state_dict
100 |
101 |
102 | def save_for_hf(model, save_directory, model_config=None):
103 | assert has_hf_hub(True)
104 | model_config = model_config or {}
105 | save_directory = Path(save_directory)
106 | save_directory.mkdir(exist_ok=True, parents=True)
107 |
108 | weights_path = save_directory / 'pytorch_model.bin'
109 | torch.save(model.state_dict(), weights_path)
110 |
111 | config_path = save_directory / 'config.json'
112 | hf_config = model.pretrained_cfg
113 | hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
114 | hf_config['num_features'] = model_config.pop('num_features', model.num_features)
115 | hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
116 | hf_config.update(model_config)
117 |
118 | with config_path.open('w') as f:
119 | json.dump(hf_config, f, indent=2)
120 |
121 |
122 | def push_to_hf_hub(
123 | model,
124 | local_dir,
125 | repo_namespace_or_url=None,
126 | commit_message='Add model',
127 | use_auth_token=True,
128 | git_email=None,
129 | git_user=None,
130 | revision=None,
131 | model_config=None,
132 | ):
133 | if repo_namespace_or_url:
134 | repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
135 | else:
136 | if isinstance(use_auth_token, str):
137 | token = use_auth_token
138 | else:
139 | token = HfFolder.get_token()
140 |
141 | if token is None:
142 | raise ValueError(
143 | "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
144 | "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
145 | "token as the `use_auth_token` argument."
146 | )
147 |
148 | repo_owner = HfApi().whoami(token)['name']
149 | repo_name = Path(local_dir).name
150 |
151 | repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'
152 |
153 | repo = Repository(
154 | local_dir,
155 | clone_from=repo_url,
156 | use_auth_token=use_auth_token,
157 | git_user=git_user,
158 | git_email=git_email,
159 | revision=revision,
160 | )
161 |
162 | # Prepare a default model card that includes the necessary tags to enable inference.
163 | readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
164 | with repo.commit(commit_message):
165 | # Save model weights and config.
166 | save_for_hf(model, repo.local_dir, model_config=model_config)
167 |
168 | # Save a model card if it doesn't exist.
169 | readme_path = Path(repo.local_dir) / 'README.md'
170 | if not readme_path.exists():
171 | readme_path.write_text(readme_text)
172 |
173 | return repo.git_remote_url()
174 |
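# Usage sketch for hf_split: an optional '@revision' suffix is split off the
# model id (ids below are illustrative):
#
#   hf_split('org/model')        # -> ('org/model', None)
#   hf_split('org/model@main')   # -> ('org/model', 'main')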
--------------------------------------------------------------------------------
/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .activations import *
2 | from .adaptive_avgmax_pool import \
3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
4 | from .blur_pool import BlurPool2d
5 | from .classifier import ClassifierHead, create_classifier
6 | from .cond_conv2d import CondConv2d, get_condconv_initializer
7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
8 | set_layer_config
9 | from .conv2d_same import Conv2dSame, conv2d_same
10 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
11 | from .create_act import create_act_layer, get_act_layer, get_act_fn
12 | from .create_attn import get_attn, create_attn
13 | from .create_conv2d import create_conv2d
14 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer
15 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
16 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
17 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
18 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
19 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
20 | from .gather_excite import GatherExcite
21 | from .global_context import GlobalContext
22 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
23 | from .inplace_abn import InplaceAbn
24 | from .linear import Linear
25 | from .mixed_conv2d import MixedConv2d
26 | from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
27 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn
28 | from .norm import GroupNorm, LayerNorm2d
29 | from .norm_act import BatchNormAct2d, GroupNormAct
30 | from .padding import get_padding, get_same_padding, pad_same
31 | from .patch_embed import PatchEmbed
32 | from .pool2d_same import AvgPool2dSame, create_pool2d
33 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
34 | from .selective_kernel import SelectiveKernel
35 | from .separable_conv import SeparableConv2d, SeparableConvNormAct
36 | from .space_to_depth import SpaceToDepthModule
37 | from .split_attn import SplitAttn
38 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
39 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
40 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool
41 | from .trace_utils import _assert, _float_to_int
42 | from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
43 |
--------------------------------------------------------------------------------
/models/layers/activations.py:
--------------------------------------------------------------------------------
1 | """ Activations
2 |
3 | A collection of activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 |
9 | import torch
10 | from torch import nn as nn
11 | from torch.nn import functional as F
12 |
13 |
14 | def swish(x, inplace: bool = False):
15 | """Swish - Described in: https://arxiv.org/abs/1710.05941
16 | """
17 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
18 |
19 |
20 | class Swish(nn.Module):
21 | def __init__(self, inplace: bool = False):
22 | super(Swish, self).__init__()
23 | self.inplace = inplace
24 |
25 | def forward(self, x):
26 | return swish(x, self.inplace)
27 |
28 |
29 | def mish(x, inplace: bool = False):
30 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
31 | NOTE: I don't have a working inplace variant
32 | """
33 | return x.mul(F.softplus(x).tanh())
34 |
35 |
36 | class Mish(nn.Module):
37 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
38 | """
39 | def __init__(self, inplace: bool = False):
40 | super(Mish, self).__init__()
41 |
42 | def forward(self, x):
43 | return mish(x)
44 |
45 |
46 | def sigmoid(x, inplace: bool = False):
47 | return x.sigmoid_() if inplace else x.sigmoid()
48 |
49 |
50 | # PyTorch has this, but not with a consistent inplace argument interface
51 | class Sigmoid(nn.Module):
52 | def __init__(self, inplace: bool = False):
53 | super(Sigmoid, self).__init__()
54 | self.inplace = inplace
55 |
56 | def forward(self, x):
57 | return x.sigmoid_() if self.inplace else x.sigmoid()
58 |
59 |
60 | def tanh(x, inplace: bool = False):
61 | return x.tanh_() if inplace else x.tanh()
62 |
63 |
64 | # PyTorch has this, but not with a consistent inplace argument interface
65 | class Tanh(nn.Module):
66 | def __init__(self, inplace: bool = False):
67 | super(Tanh, self).__init__()
68 | self.inplace = inplace
69 |
70 | def forward(self, x):
71 | return x.tanh_() if self.inplace else x.tanh()
72 |
73 |
74 | def hard_swish(x, inplace: bool = False):
75 | inner = F.relu6(x + 3.).div_(6.)
76 | return x.mul_(inner) if inplace else x.mul(inner)
77 |
78 |
79 | class HardSwish(nn.Module):
80 | def __init__(self, inplace: bool = False):
81 | super(HardSwish, self).__init__()
82 | self.inplace = inplace
83 |
84 | def forward(self, x):
85 | return hard_swish(x, self.inplace)
86 |
87 |
88 | def hard_sigmoid(x, inplace: bool = False):
89 | if inplace:
90 | return x.add_(3.).clamp_(0., 6.).div_(6.)
91 | else:
92 | return F.relu6(x + 3.) / 6.
93 |
94 |
95 | class HardSigmoid(nn.Module):
96 | def __init__(self, inplace: bool = False):
97 | super(HardSigmoid, self).__init__()
98 | self.inplace = inplace
99 |
100 | def forward(self, x):
101 | return hard_sigmoid(x, self.inplace)
102 |
103 |
104 | def hard_mish(x, inplace: bool = False):
105 | """ Hard Mish
106 | Experimental, based on notes by Mish author Diganta Misra at
107 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
108 | """
109 | if inplace:
110 | return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))
111 | else:
112 | return 0.5 * x * (x + 2).clamp(min=0, max=2)
113 |
114 |
115 | class HardMish(nn.Module):
116 | def __init__(self, inplace: bool = False):
117 | super(HardMish, self).__init__()
118 | self.inplace = inplace
119 |
120 | def forward(self, x):
121 | return hard_mish(x, self.inplace)
122 |
123 |
124 | class PReLU(nn.PReLU):
125 | """Applies PReLU (w/ dummy inplace arg)
126 | """
127 | def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
128 | super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
129 |
130 | def forward(self, input: torch.Tensor) -> torch.Tensor:
131 | return F.prelu(input, self.weight)
132 |
133 |
134 | def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
135 | return F.gelu(x)
136 |
137 |
138 | class GELU(nn.Module):
139 | """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
140 | """
141 | def __init__(self, inplace: bool = False):
142 | super(GELU, self).__init__()
143 |
144 | def forward(self, input: torch.Tensor) -> torch.Tensor:
145 | return F.gelu(input)
146 |
--------------------------------------------------------------------------------
/models/layers/activations_jit.py:
--------------------------------------------------------------------------------
1 | """ Activations
2 |
3 | A collection of jit-scripted activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 |
6 | All jit-scripted activations intentionally lack in-place variants: scripted kernel fusion does not
7 | currently work across in-place op boundaries, so performance would be equal to or worse than the
8 | non-scripted versions if they contained in-place ops.
9 |
10 | Hacked together by / Copyright 2020 Ross Wightman
11 | """
12 |
13 | import torch
14 | from torch import nn as nn
15 | from torch.nn import functional as F
16 |
17 |
18 | @torch.jit.script
19 | def swish_jit(x, inplace: bool = False):
20 | """Swish - Described in: https://arxiv.org/abs/1710.05941
21 | """
22 | return x.mul(x.sigmoid())
23 |
24 |
25 | @torch.jit.script
26 | def mish_jit(x, _inplace: bool = False):
27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
28 | """
29 | return x.mul(F.softplus(x).tanh())
30 |
31 |
32 | class SwishJit(nn.Module):
33 | def __init__(self, inplace: bool = False):
34 | super(SwishJit, self).__init__()
35 |
36 | def forward(self, x):
37 | return swish_jit(x)
38 |
39 |
40 | class MishJit(nn.Module):
41 | def __init__(self, inplace: bool = False):
42 | super(MishJit, self).__init__()
43 |
44 | def forward(self, x):
45 | return mish_jit(x)
46 |
47 |
48 | @torch.jit.script
49 | def hard_sigmoid_jit(x, inplace: bool = False):
50 | # return F.relu6(x + 3.) / 6.
51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
52 |
53 |
54 | class HardSigmoidJit(nn.Module):
55 | def __init__(self, inplace: bool = False):
56 | super(HardSigmoidJit, self).__init__()
57 |
58 | def forward(self, x):
59 | return hard_sigmoid_jit(x)
60 |
61 |
62 | @torch.jit.script
63 | def hard_swish_jit(x, inplace: bool = False):
64 | # return x * (F.relu6(x + 3.) / 6)
65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
66 |
67 |
68 | class HardSwishJit(nn.Module):
69 | def __init__(self, inplace: bool = False):
70 | super(HardSwishJit, self).__init__()
71 |
72 | def forward(self, x):
73 | return hard_swish_jit(x)
74 |
75 |
76 | @torch.jit.script
77 | def hard_mish_jit(x, inplace: bool = False):
78 | """ Hard Mish
79 | Experimental, based on notes by Mish author Diganta Misra at
80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
81 | """
82 | return 0.5 * x * (x + 2).clamp(min=0, max=2)
83 |
84 |
85 | class HardMishJit(nn.Module):
86 | def __init__(self, inplace: bool = False):
87 | super(HardMishJit, self).__init__()
88 |
89 | def forward(self, x):
90 | return hard_mish_jit(x)
91 |
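# Scripting sketch: these modules are written to compile cleanly with TorchScript:
#
#   m = torch.jit.script(HardSwishJit())
#   y = m(torch.randn(2, 8))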
--------------------------------------------------------------------------------
/models/layers/activations_me.py:
--------------------------------------------------------------------------------
1 | """ Activations (memory-efficient w/ custom autograd)
2 |
3 | A collection of activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 |
6 | These activations are not compatible with jit scripting or ONNX export of the model, please use either
7 | the JIT or basic versions of the activations.
8 |
9 | Hacked together by / Copyright 2020 Ross Wightman
10 | """
11 |
12 | import torch
13 | from torch import nn as nn
14 | from torch.nn import functional as F
15 |
16 |
17 | @torch.jit.script
18 | def swish_jit_fwd(x):
19 | return x.mul(torch.sigmoid(x))
20 |
21 |
22 | @torch.jit.script
23 | def swish_jit_bwd(x, grad_output):
24 | x_sigmoid = torch.sigmoid(x)
25 | return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
26 |
27 |
28 | class SwishJitAutoFn(torch.autograd.Function):
29 | """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
30 | Inspired by conversation btw Jeremy Howard & Adam Paszke
31 | https://twitter.com/jeremyphoward/status/1188251041835315200
32 | """
33 | @staticmethod
34 | def symbolic(g, x):
35 | return g.op("Mul", x, g.op("Sigmoid", x))
36 |
37 | @staticmethod
38 | def forward(ctx, x):
39 | ctx.save_for_backward(x)
40 | return swish_jit_fwd(x)
41 |
42 | @staticmethod
43 | def backward(ctx, grad_output):
44 | x = ctx.saved_tensors[0]
45 | return swish_jit_bwd(x, grad_output)
46 |
47 |
48 | def swish_me(x, inplace=False):
49 | return SwishJitAutoFn.apply(x)
50 |
51 |
52 | class SwishMe(nn.Module):
53 | def __init__(self, inplace: bool = False):
54 | super(SwishMe, self).__init__()
55 |
56 | def forward(self, x):
57 | return SwishJitAutoFn.apply(x)
58 |
59 |
60 | @torch.jit.script
61 | def mish_jit_fwd(x):
62 | return x.mul(torch.tanh(F.softplus(x)))
63 |
64 |
65 | @torch.jit.script
66 | def mish_jit_bwd(x, grad_output):
67 | x_sigmoid = torch.sigmoid(x)
68 | x_tanh_sp = F.softplus(x).tanh()
69 | return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
70 |
71 |
72 | class MishJitAutoFn(torch.autograd.Function):
73 | """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
74 | A memory efficient, jit scripted variant of Mish
75 | """
76 | @staticmethod
77 | def forward(ctx, x):
78 | ctx.save_for_backward(x)
79 | return mish_jit_fwd(x)
80 |
81 | @staticmethod
82 | def backward(ctx, grad_output):
83 | x = ctx.saved_tensors[0]
84 | return mish_jit_bwd(x, grad_output)
85 |
86 |
87 | def mish_me(x, inplace=False):
88 | return MishJitAutoFn.apply(x)
89 |
90 |
91 | class MishMe(nn.Module):
92 | def __init__(self, inplace: bool = False):
93 | super(MishMe, self).__init__()
94 |
95 | def forward(self, x):
96 | return MishJitAutoFn.apply(x)
97 |
98 |
99 | @torch.jit.script
100 | def hard_sigmoid_jit_fwd(x, inplace: bool = False):
101 | return (x + 3).clamp(min=0, max=6).div(6.)
102 |
103 |
104 | @torch.jit.script
105 | def hard_sigmoid_jit_bwd(x, grad_output):
106 | m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
107 | return grad_output * m
108 |
109 |
110 | class HardSigmoidJitAutoFn(torch.autograd.Function):
111 | @staticmethod
112 | def forward(ctx, x):
113 | ctx.save_for_backward(x)
114 | return hard_sigmoid_jit_fwd(x)
115 |
116 | @staticmethod
117 | def backward(ctx, grad_output):
118 | x = ctx.saved_tensors[0]
119 | return hard_sigmoid_jit_bwd(x, grad_output)
120 |
121 |
122 | def hard_sigmoid_me(x, inplace: bool = False):
123 | return HardSigmoidJitAutoFn.apply(x)
124 |
125 |
126 | class HardSigmoidMe(nn.Module):
127 | def __init__(self, inplace: bool = False):
128 | super(HardSigmoidMe, self).__init__()
129 |
130 | def forward(self, x):
131 | return HardSigmoidJitAutoFn.apply(x)
132 |
133 |
134 | @torch.jit.script
135 | def hard_swish_jit_fwd(x):
136 | return x * (x + 3).clamp(min=0, max=6).div(6.)
137 |
138 |
139 | @torch.jit.script
140 | def hard_swish_jit_bwd(x, grad_output):
141 | m = torch.ones_like(x) * (x >= 3.)
142 | m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
143 | return grad_output * m
144 |
145 |
146 | class HardSwishJitAutoFn(torch.autograd.Function):
147 | """A memory efficient, jit-scripted HardSwish activation"""
148 | @staticmethod
149 | def forward(ctx, x):
150 | ctx.save_for_backward(x)
151 | return hard_swish_jit_fwd(x)
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | x = ctx.saved_tensors[0]
156 | return hard_swish_jit_bwd(x, grad_output)
157 |
158 | @staticmethod
159 | def symbolic(g, self):
160 | input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
161 | hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
162 | hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
163 | return g.op("Mul", self, hardtanh_)
164 |
165 |
166 | def hard_swish_me(x, inplace=False):
167 | return HardSwishJitAutoFn.apply(x)
168 |
169 |
170 | class HardSwishMe(nn.Module):
171 | def __init__(self, inplace: bool = False):
172 | super(HardSwishMe, self).__init__()
173 |
174 | def forward(self, x):
175 | return HardSwishJitAutoFn.apply(x)
176 |
177 |
178 | @torch.jit.script
179 | def hard_mish_jit_fwd(x):
180 | return 0.5 * x * (x + 2).clamp(min=0, max=2)
181 |
182 |
183 | @torch.jit.script
184 | def hard_mish_jit_bwd(x, grad_output):
185 | m = torch.ones_like(x) * (x >= -2.)
186 | m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
187 | return grad_output * m
188 |
189 |
190 | class HardMishJitAutoFn(torch.autograd.Function):
191 | """ A memory efficient, jit scripted variant of Hard Mish
192 | Experimental, based on notes by Mish author Diganta Misra at
193 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
194 | """
195 | @staticmethod
196 | def forward(ctx, x):
197 | ctx.save_for_backward(x)
198 | return hard_mish_jit_fwd(x)
199 |
200 | @staticmethod
201 | def backward(ctx, grad_output):
202 | x = ctx.saved_tensors[0]
203 | return hard_mish_jit_bwd(x, grad_output)
204 |
205 |
206 | def hard_mish_me(x, inplace: bool = False):
207 | return HardMishJitAutoFn.apply(x)
208 |
209 |
210 | class HardMishMe(nn.Module):
211 | def __init__(self, inplace: bool = False):
212 | super(HardMishMe, self).__init__()
213 |
214 | def forward(self, x):
215 | return HardMishJitAutoFn.apply(x)
216 |
217 |
218 |
219 |
--------------------------------------------------------------------------------
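The 'Me' activations above save memory by storing only the input tensor and recomputing intermediates in backward. As a quick sanity check (a minimal torch-only sketch, not part of the repo), the hand-derived Mish gradient in mish_jit_bwd can be compared against autograd on the reference definition x * tanh(softplus(x)):

import torch
import torch.nn.functional as F

def mish_reference(x):
    return x.mul(torch.tanh(F.softplus(x)))

x = torch.randn(8, dtype=torch.double, requires_grad=True)
g = torch.ones_like(x)
auto_grad, = torch.autograd.grad(mish_reference(x), x, g)

# mirror mish_jit_bwd above
with torch.no_grad():
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    manual_grad = g * (x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))

print(torch.allclose(auto_grad, manual_grad))  # expected: True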
/models/layers/adaptive_avgmax_pool.py:
--------------------------------------------------------------------------------
1 | """ PyTorch selectable adaptive pooling
2 | Adaptive pooling with the ability to select the type of pooling from:
3 | * 'avg' - Average pooling
4 | * 'max' - Max pooling
5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
6 | * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim
7 |
8 | Both a functional and a nn.Module version of the pooling is provided.
9 |
10 | Hacked together by / Copyright 2020 Ross Wightman
11 | """
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 |
17 | def adaptive_pool_feat_mult(pool_type='avg'):
18 | if pool_type == 'catavgmax':
19 | return 2
20 | else:
21 | return 1
22 |
23 |
24 | def adaptive_avgmax_pool2d(x, output_size=1):
25 | x_avg = F.adaptive_avg_pool2d(x, output_size)
26 | x_max = F.adaptive_max_pool2d(x, output_size)
27 | return 0.5 * (x_avg + x_max)
28 |
29 |
30 | def adaptive_catavgmax_pool2d(x, output_size=1):
31 | x_avg = F.adaptive_avg_pool2d(x, output_size)
32 | x_max = F.adaptive_max_pool2d(x, output_size)
33 | return torch.cat((x_avg, x_max), 1)
34 |
35 |
36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
37 | """Selectable global pooling function with dynamic input kernel size
38 | """
39 | if pool_type == 'avg':
40 | x = F.adaptive_avg_pool2d(x, output_size)
41 | elif pool_type == 'avgmax':
42 | x = adaptive_avgmax_pool2d(x, output_size)
43 | elif pool_type == 'catavgmax':
44 | x = adaptive_catavgmax_pool2d(x, output_size)
45 | elif pool_type == 'max':
46 | x = F.adaptive_max_pool2d(x, output_size)
47 | else:
48 | assert False, 'Invalid pool type: %s' % pool_type
49 | return x
50 |
51 |
52 | class FastAdaptiveAvgPool2d(nn.Module):
53 | def __init__(self, flatten=False):
54 | super(FastAdaptiveAvgPool2d, self).__init__()
55 | self.flatten = flatten
56 |
57 | def forward(self, x):
58 | return x.mean((2, 3), keepdim=not self.flatten)
59 |
60 |
61 | class AdaptiveAvgMaxPool2d(nn.Module):
62 | def __init__(self, output_size=1):
63 | super(AdaptiveAvgMaxPool2d, self).__init__()
64 | self.output_size = output_size
65 |
66 | def forward(self, x):
67 | return adaptive_avgmax_pool2d(x, self.output_size)
68 |
69 |
70 | class AdaptiveCatAvgMaxPool2d(nn.Module):
71 | def __init__(self, output_size=1):
72 | super(AdaptiveCatAvgMaxPool2d, self).__init__()
73 | self.output_size = output_size
74 |
75 | def forward(self, x):
76 | return adaptive_catavgmax_pool2d(x, self.output_size)
77 |
78 |
79 | class SelectAdaptivePool2d(nn.Module):
80 | """Selectable global pooling layer with dynamic input kernel size
81 | """
82 | def __init__(self, output_size=1, pool_type='fast', flatten=False):
83 | super(SelectAdaptivePool2d, self).__init__()
84 | self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
85 | self.flatten = nn.Flatten(1) if flatten else nn.Identity()
86 | if pool_type == '':
87 | self.pool = nn.Identity() # pass through
88 | elif pool_type == 'fast':
89 | assert output_size == 1
90 | self.pool = FastAdaptiveAvgPool2d(flatten)
91 | self.flatten = nn.Identity()
92 | elif pool_type == 'avg':
93 | self.pool = nn.AdaptiveAvgPool2d(output_size)
94 | elif pool_type == 'avgmax':
95 | self.pool = AdaptiveAvgMaxPool2d(output_size)
96 | elif pool_type == 'catavgmax':
97 | self.pool = AdaptiveCatAvgMaxPool2d(output_size)
98 | elif pool_type == 'max':
99 | self.pool = nn.AdaptiveMaxPool2d(output_size)
100 | else:
101 | assert False, 'Invalid pool type: %s' % pool_type
102 |
103 | def is_identity(self):
104 | return not self.pool_type
105 |
106 | def forward(self, x):
107 | x = self.pool(x)
108 | x = self.flatten(x)
109 | return x
110 |
111 | def feat_mult(self):
112 | return adaptive_pool_feat_mult(self.pool_type)
113 |
114 | def __repr__(self):
115 | return self.__class__.__name__ + ' (' \
116 | + 'pool_type=' + self.pool_type \
117 | + ', flatten=' + str(self.flatten) + ')'
118 |
119 |
--------------------------------------------------------------------------------
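A minimal usage sketch (not from the repo; assumes the repository root is on PYTHONPATH so this file imports as models.layers.adaptive_avgmax_pool):

import torch
from models.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

x = torch.randn(2, 512, 7, 7)
pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True)
print(pool(x).shape)     # torch.Size([2, 1024]) - 'catavgmax' doubles the feature dim
print(pool.feat_mult())  # 2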
/models/layers/attention_pool2d.py:
--------------------------------------------------------------------------------
1 | """ Attention Pool 2D
2 |
3 | Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.
4 |
5 | Based on idea in CLIP by OpenAI, licensed Apache 2.0
6 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
7 |
8 | Hacked together by / Copyright 2021 Ross Wightman
9 | """
10 | from typing import Union, Tuple
11 |
12 | import torch
13 | import torch.nn as nn
14 |
15 | from .helpers import to_2tuple
16 | from .pos_embed import apply_rot_embed, RotaryEmbedding
17 | from .weight_init import trunc_normal_
18 |
19 |
20 | class RotAttentionPool2d(nn.Module):
21 | """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
22 | This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
23 |
24 | Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
25 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
26 |
27 | NOTE: While this impl does not require a fixed feature size, performance at resolutions differing from
28 | train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
29 | """
30 | def __init__(
31 | self,
32 | in_features: int,
33 | out_features: int = None,
34 | embed_dim: int = None,
35 | num_heads: int = 4,
36 | qkv_bias: bool = True,
37 | ):
38 | super().__init__()
39 | embed_dim = embed_dim or in_features
40 | out_features = out_features or in_features
41 | self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
42 | self.proj = nn.Linear(embed_dim, out_features)
43 | self.num_heads = num_heads
44 | assert embed_dim % num_heads == 0
45 | self.head_dim = embed_dim // num_heads
46 | self.scale = self.head_dim ** -0.5
47 | self.pos_embed = RotaryEmbedding(self.head_dim)
48 |
49 | trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
50 | nn.init.zeros_(self.qkv.bias)
51 |
52 | def forward(self, x):
53 | B, _, H, W = x.shape
54 | N = H * W
55 | x = x.reshape(B, -1, N).permute(0, 2, 1)
56 |
57 | x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
58 |
59 | x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
60 | q, k, v = x[0], x[1], x[2]
61 |
62 | qc, q = q[:, :, :1], q[:, :, 1:]
63 | sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
64 | q = apply_rot_embed(q, sin_emb, cos_emb)
65 | q = torch.cat([qc, q], dim=2)
66 |
67 | kc, k = k[:, :, :1], k[:, :, 1:]
68 | k = apply_rot_embed(k, sin_emb, cos_emb)
69 | k = torch.cat([kc, k], dim=2)
70 |
71 | attn = (q @ k.transpose(-2, -1)) * self.scale
72 | attn = attn.softmax(dim=-1)
73 |
74 | x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
75 | x = self.proj(x)
76 | return x[:, 0]
77 |
78 |
79 | class AttentionPool2d(nn.Module):
80 | """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
81 | This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
82 |
83 | It was based on impl in CLIP by OpenAI
84 | https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
85 |
86 | NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
87 | """
88 | def __init__(
89 | self,
90 | in_features: int,
91 | feat_size: Union[int, Tuple[int, int]],
92 | out_features: int = None,
93 | embed_dim: int = None,
94 | num_heads: int = 4,
95 | qkv_bias: bool = True,
96 | ):
97 | super().__init__()
98 |
99 | embed_dim = embed_dim or in_features
100 | out_features = out_features or in_features
101 | assert embed_dim % num_heads == 0
102 | self.feat_size = to_2tuple(feat_size)
103 | self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
104 | self.proj = nn.Linear(embed_dim, out_features)
105 | self.num_heads = num_heads
106 | self.head_dim = embed_dim // num_heads
107 | self.scale = self.head_dim ** -0.5
108 |
109 | spatial_dim = self.feat_size[0] * self.feat_size[1]
110 | self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
111 | trunc_normal_(self.pos_embed, std=in_features ** -0.5)
112 | trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
113 | nn.init.zeros_(self.qkv.bias)
114 |
115 | def forward(self, x):
116 | B, _, H, W = x.shape
117 | N = H * W
118 | assert self.feat_size[0] == H
119 | assert self.feat_size[1] == W
120 | x = x.reshape(B, -1, N).permute(0, 2, 1)
121 | x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
122 | x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
123 |
124 | x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
125 | q, k, v = x[0], x[1], x[2]
126 | attn = (q @ k.transpose(-2, -1)) * self.scale
127 | attn = attn.softmax(dim=-1)
128 |
129 | x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
130 | x = self.proj(x)
131 | return x[:, 0]
132 |
--------------------------------------------------------------------------------
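A minimal usage sketch (not from the repo; import path assumed from the repo layout): swapping a trailing global average pool for attention pooling over a fixed 7x7 feature map.

import torch
from models.layers.attention_pool2d import AttentionPool2d

feats = torch.randn(2, 2048, 7, 7)
pool = AttentionPool2d(in_features=2048, feat_size=7, out_features=1024, num_heads=8)
print(pool(feats).shape)  # torch.Size([2, 1024]) - output for the mean-token query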
/models/layers/blur_pool.py:
--------------------------------------------------------------------------------
1 | """
2 | BlurPool layer inspired by
3 | - Kornia's Max_BlurPool2d
4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
5 |
6 | Hacked together by Chris Ha and Ross Wightman
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import numpy as np
13 | from .padding import get_padding
14 |
15 |
16 | class BlurPool2d(nn.Module):
17 | r"""Creates a module that blurs and downsamples a given feature map.
18 | See :cite:`zhang2019shiftinvar` for more details.
19 | Corresponds to the Downsample class, which does blurring and subsampling
20 |
21 | Args:
22 | channels (int): Number of input channels
23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
24 | stride (int): downsampling filter stride
25 |
26 | Returns:
27 | torch.Tensor: the transformed tensor.
28 | """
29 | def __init__(self, channels, filt_size=3, stride=2) -> None:
30 | super(BlurPool2d, self).__init__()
31 | assert filt_size > 1
32 | self.channels = channels
33 | self.filt_size = filt_size
34 | self.stride = stride
35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
38 | self.register_buffer('filt', blur_filter, persistent=False)
39 |
40 | def forward(self, x: torch.Tensor) -> torch.Tensor:
41 | x = F.pad(x, self.padding, 'reflect')
42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
43 |
--------------------------------------------------------------------------------
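The blur kernel is an outer product of binomial coefficients: for filt_size=3 the 1d filter (0.5 + 0.5x)^2 gives [0.25, 0.5, 0.25]. A usage sketch (not from the repo; import path assumed):

import torch
from models.layers.blur_pool import BlurPool2d

bp = BlurPool2d(channels=64, filt_size=3, stride=2)
print(bp.filt[0, 0])  # [[0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625]]
x = torch.randn(1, 64, 56, 56)
print(bp(x).shape)    # torch.Size([1, 64, 28, 28]) - anti-aliased 2x downsample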
/models/layers/cbam.py:
--------------------------------------------------------------------------------
1 | """ CBAM (sort-of) Attention
2 |
3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
4 |
5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
6 | some tasks, especially fine-grained it seems. I may end up removing this impl.
7 |
8 | Hacked together by / Copyright 2020 Ross Wightman
9 | """
10 | import torch
11 | from torch import nn as nn
12 | import torch.nn.functional as F
13 |
14 | from .conv_bn_act import ConvNormAct
15 | from .create_act import create_act_layer, get_act_layer
16 | from .helpers import make_divisible
17 |
18 |
19 | class ChannelAttn(nn.Module):
20 | """ Original CBAM channel attention module, currently avg + max pool variant only.
21 | """
22 | def __init__(
23 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
24 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
25 | super(ChannelAttn, self).__init__()
26 | if not rd_channels:
27 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
28 | self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
29 | self.act = act_layer(inplace=True)
30 | self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
31 | self.gate = create_act_layer(gate_layer)
32 |
33 | def forward(self, x):
34 | x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
35 | x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
36 | return x * self.gate(x_avg + x_max)
37 |
38 |
39 | class LightChannelAttn(ChannelAttn):
40 | """An experimental 'lightweight' variant that sums avg + max pool first
41 | """
42 | def __init__(
43 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
44 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
45 | super(LightChannelAttn, self).__init__(
46 | channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
47 |
48 | def forward(self, x):
49 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
50 | x_attn = self.fc2(self.act(self.fc1(x_pool)))
51 | return x * self.gate(x_attn)  # use the configured gate (F.sigmoid is deprecated and ignored gate_layer)
52 |
53 |
54 | class SpatialAttn(nn.Module):
55 | """ Original CBAM spatial attention module
56 | """
57 | def __init__(self, kernel_size=7, gate_layer='sigmoid'):
58 | super(SpatialAttn, self).__init__()
59 | self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)
60 | self.gate = create_act_layer(gate_layer)
61 |
62 | def forward(self, x):
63 | x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
64 | x_attn = self.conv(x_attn)
65 | return x * self.gate(x_attn)
66 |
67 |
68 | class LightSpatialAttn(nn.Module):
69 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results.
70 | """
71 | def __init__(self, kernel_size=7, gate_layer='sigmoid'):
72 | super(LightSpatialAttn, self).__init__()
73 | self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)
74 | self.gate = create_act_layer(gate_layer)
75 |
76 | def forward(self, x):
77 | x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
78 | x_attn = self.conv(x_attn)
79 | return x * self.gate(x_attn)
80 |
81 |
82 | class CbamModule(nn.Module):
83 | def __init__(
84 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
85 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
86 | super(CbamModule, self).__init__()
87 | self.channel = ChannelAttn(
88 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
89 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
90 | self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
91 |
92 | def forward(self, x):
93 | x = self.channel(x)
94 | x = self.spatial(x)
95 | return x
96 |
97 |
98 | class LightCbamModule(nn.Module):
99 | def __init__(
100 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
101 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
102 | super(LightCbamModule, self).__init__()
103 | self.channel = LightChannelAttn(
104 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
105 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
106 | self.spatial = LightSpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
107 |
108 | def forward(self, x):
109 | x = self.channel(x)
110 | x = self.spatial(x)
111 | return x
112 |
113 |
--------------------------------------------------------------------------------
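A minimal usage sketch (not from the repo; import path assumed): channel attention followed by spatial attention, each re-weighting the input, so shapes are preserved.

import torch
from models.layers.cbam import CbamModule

attn = CbamModule(channels=256)
x = torch.randn(2, 256, 14, 14)
print(attn(x).shape)  # torch.Size([2, 256, 14, 14])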
/models/layers/classifier.py:
--------------------------------------------------------------------------------
1 | """ Classifier head and layer factory
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from torch import nn as nn
6 | from torch.nn import functional as F
7 |
8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d
9 |
10 |
11 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
12 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
13 | if not pool_type:
14 | assert num_classes == 0 or use_conv,\
15 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used'
16 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
17 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
18 | num_pooled_features = num_features * global_pool.feat_mult()
19 | return global_pool, num_pooled_features
20 |
21 |
22 | def _create_fc(num_features, num_classes, use_conv=False):
23 | if num_classes <= 0:
24 | fc = nn.Identity() # pass-through (no classifier)
25 | elif use_conv:
26 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
27 | else:
28 | fc = nn.Linear(num_features, num_classes, bias=True)
29 | return fc
30 |
31 |
32 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
33 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
34 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
35 | return global_pool, fc
36 |
37 |
38 | class ClassifierHead(nn.Module):
39 | """Classifier head w/ configurable global pooling and dropout."""
40 |
41 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
42 | super(ClassifierHead, self).__init__()
43 | self.drop_rate = drop_rate
44 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
45 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
46 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
47 |
48 | def forward(self, x, pre_logits: bool = False):
49 | x = self.global_pool(x)
50 | if self.drop_rate:
51 | x = F.dropout(x, p=float(self.drop_rate), training=self.training)
52 | if pre_logits:
53 | return x.flatten(1)
54 | else:
55 | x = self.fc(x)
56 | return self.flatten(x)
57 |
--------------------------------------------------------------------------------
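A minimal usage sketch (not from the repo; import path assumed): create_classifier returns the pool/fc pair, sizing the fc from the pool's feature multiplier.

import torch
from models.layers.classifier import create_classifier

global_pool, fc = create_classifier(num_features=2048, num_classes=1000, pool_type='avg')
x = torch.randn(2, 2048, 7, 7)
print(fc(global_pool(x)).shape)  # torch.Size([2, 1000])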
/models/layers/cond_conv2d.py:
--------------------------------------------------------------------------------
1 | """ PyTorch Conditionally Parameterized Convolution (CondConv)
2 |
3 | Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
4 | (https://arxiv.org/abs/1904.04971)
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 |
9 | import math
10 | from functools import partial
11 | import numpy as np
12 | import torch
13 | from torch import nn as nn
14 | from torch.nn import functional as F
15 |
16 | from .helpers import to_2tuple
17 | from .conv2d_same import conv2d_same
18 | from .padding import get_padding_value
19 |
20 |
21 | def get_condconv_initializer(initializer, num_experts, expert_shape):
22 | def condconv_initializer(weight):
23 | """CondConv initializer function."""
24 | num_params = np.prod(expert_shape)
25 | if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
26 | weight.shape[1] != num_params):
27 | raise ValueError(
28 | 'CondConv variables must have shape [num_experts, num_params]')
29 | for i in range(num_experts):
30 | initializer(weight[i].view(expert_shape))
31 | return condconv_initializer
32 |
33 |
34 | class CondConv2d(nn.Module):
35 | """ Conditionally Parameterized Convolution
36 | Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
37 |
38 | Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
39 | https://github.com/pytorch/pytorch/issues/17983
40 | """
41 | __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
42 |
43 | def __init__(self, in_channels, out_channels, kernel_size=3,
44 | stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
45 | super(CondConv2d, self).__init__()
46 |
47 | self.in_channels = in_channels
48 | self.out_channels = out_channels
49 | self.kernel_size = to_2tuple(kernel_size)
50 | self.stride = to_2tuple(stride)
51 | padding_val, is_padding_dynamic = get_padding_value(
52 | padding, kernel_size, stride=stride, dilation=dilation)
53 | self.dynamic_padding = is_padding_dynamic  # kept as an attribute so the check in forward works with torchscript
54 | self.padding = to_2tuple(padding_val)
55 | self.dilation = to_2tuple(dilation)
56 | self.groups = groups
57 | self.num_experts = num_experts
58 |
59 | self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
60 | weight_num_param = 1
61 | for wd in self.weight_shape:
62 | weight_num_param *= wd
63 | self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
64 |
65 | if bias:
66 | self.bias_shape = (self.out_channels,)
67 | self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
68 | else:
69 | self.register_parameter('bias', None)
70 |
71 | self.reset_parameters()
72 |
73 | def reset_parameters(self):
74 | init_weight = get_condconv_initializer(
75 | partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
76 | init_weight(self.weight)
77 | if self.bias is not None:
78 | fan_in = np.prod(self.weight_shape[1:])
79 | bound = 1 / math.sqrt(fan_in)
80 | init_bias = get_condconv_initializer(
81 | partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
82 | init_bias(self.bias)
83 |
84 | def forward(self, x, routing_weights):
85 | B, C, H, W = x.shape
86 | weight = torch.matmul(routing_weights, self.weight)
87 | new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
88 | weight = weight.view(new_weight_shape)
89 | bias = None
90 | if self.bias is not None:
91 | bias = torch.matmul(routing_weights, self.bias)
92 | bias = bias.view(B * self.out_channels)
93 | # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
94 | x = x.view(1, B * C, H, W)
95 | if self.dynamic_padding:
96 | out = conv2d_same(
97 | x, weight, bias, stride=self.stride, padding=self.padding,
98 | dilation=self.dilation, groups=self.groups * B)
99 | else:
100 | out = F.conv2d(
101 | x, weight, bias, stride=self.stride, padding=self.padding,
102 | dilation=self.dilation, groups=self.groups * B)
103 | out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
104 |
105 | # Literal port (from TF definition)
106 | # x = torch.split(x, 1, 0)
107 | # weight = torch.split(weight, 1, 0)
108 | # if self.bias is not None:
109 | # bias = torch.matmul(routing_weights, self.bias)
110 | # bias = torch.split(bias, 1, 0)
111 | # else:
112 | # bias = [None] * B
113 | # out = []
114 | # for xi, wi, bi in zip(x, weight, bias):
115 | # wi = wi.view(*self.weight_shape)
116 | # if bi is not None:
117 | # bi = bi.view(*self.bias_shape)
118 | # out.append(self.conv_fn(
119 | # xi, wi, bi, stride=self.stride, padding=self.padding,
120 | # dilation=self.dilation, groups=self.groups))
121 | # out = torch.cat(out, 0)
122 | return out
123 |
--------------------------------------------------------------------------------
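A minimal usage sketch (not from the repo; import path assumed): each sample mixes the expert kernels with its own routing weights. Uniform weights are used here for brevity; in the CondConv paper they come from a learned, sigmoid-gated routing function on pooled features.

import torch
from models.layers.cond_conv2d import CondConv2d

conv = CondConv2d(32, 64, kernel_size=3, num_experts=4)
x = torch.randn(2, 32, 16, 16)
routing = torch.full((2, 4), 0.25)  # B x num_experts, rows need not sum to 1
print(conv(x, routing).shape)       # torch.Size([2, 64, 16, 16])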
/models/layers/config.py:
--------------------------------------------------------------------------------
1 | """ Model / Layer Config singleton state
2 | """
3 | from typing import Any, Optional
4 |
5 | __all__ = [
6 | 'is_exportable', 'is_scriptable', 'is_no_jit',
7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
8 | ]
9 |
10 | # Set to True if prefer to have layers with no jit optimization (includes activations)
11 | _NO_JIT = False
12 |
13 | # Set to True if prefer to have activation layers with no jit optimization
14 | # NOTE: not currently used, as there is no difference between no_jit and no_activation_jit while the only
15 | # layers obeying the jit flags are activations. This will change as more layers are updated and/or added.
16 | _NO_ACTIVATION_JIT = False
17 |
18 | # Set to True if exporting a model with Same padding via ONNX
19 | _EXPORTABLE = False
20 |
21 | # Set to True if wanting to use torch.jit.script on a model
22 | _SCRIPTABLE = False
23 |
24 |
25 | def is_no_jit():
26 | return _NO_JIT
27 |
28 |
29 | class set_no_jit:
30 | def __init__(self, mode: bool) -> None:
31 | global _NO_JIT
32 | self.prev = _NO_JIT
33 | _NO_JIT = mode
34 |
35 | def __enter__(self) -> None:
36 | pass
37 |
38 | def __exit__(self, *args: Any) -> bool:
39 | global _NO_JIT
40 | _NO_JIT = self.prev
41 | return False
42 |
43 |
44 | def is_exportable():
45 | return _EXPORTABLE
46 |
47 |
48 | class set_exportable:
49 | def __init__(self, mode: bool) -> None:
50 | global _EXPORTABLE
51 | self.prev = _EXPORTABLE
52 | _EXPORTABLE = mode
53 |
54 | def __enter__(self) -> None:
55 | pass
56 |
57 | def __exit__(self, *args: Any) -> bool:
58 | global _EXPORTABLE
59 | _EXPORTABLE = self.prev
60 | return False
61 |
62 |
63 | def is_scriptable():
64 | return _SCRIPTABLE
65 |
66 |
67 | class set_scriptable:
68 | def __init__(self, mode: bool) -> None:
69 | global _SCRIPTABLE
70 | self.prev = _SCRIPTABLE
71 | _SCRIPTABLE = mode
72 |
73 | def __enter__(self) -> None:
74 | pass
75 |
76 | def __exit__(self, *args: Any) -> bool:
77 | global _SCRIPTABLE
78 | _SCRIPTABLE = self.prev
79 | return False
80 |
81 |
82 | class set_layer_config:
83 | """ Layer config context manager that allows setting all layer config flags at once.
84 | If a flag arg is None, it will not change the current value.
85 | """
86 | def __init__(
87 | self,
88 | scriptable: Optional[bool] = None,
89 | exportable: Optional[bool] = None,
90 | no_jit: Optional[bool] = None,
91 | no_activation_jit: Optional[bool] = None):
92 | global _SCRIPTABLE
93 | global _EXPORTABLE
94 | global _NO_JIT
95 | global _NO_ACTIVATION_JIT
96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
97 | if scriptable is not None:
98 | _SCRIPTABLE = scriptable
99 | if exportable is not None:
100 | _EXPORTABLE = exportable
101 | if no_jit is not None:
102 | _NO_JIT = no_jit
103 | if no_activation_jit is not None:
104 | _NO_ACTIVATION_JIT = no_activation_jit
105 |
106 | def __enter__(self) -> None:
107 | pass
108 |
109 | def __exit__(self, *args: Any) -> bool:
110 | global _SCRIPTABLE
111 | global _EXPORTABLE
112 | global _NO_JIT
113 | global _NO_ACTIVATION_JIT
114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
115 | return False
116 |
--------------------------------------------------------------------------------
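A minimal usage sketch (not from the repo; import path assumed): the flags are module-level globals, and the set_* classes double as context managers that restore the previous value on exit.

from models.layers.config import set_exportable, is_exportable

print(is_exportable())      # False by default
with set_exportable(True):
    print(is_exportable())  # True - export-friendly layer variants get picked in here
print(is_exportable())      # False - previous value restored on exit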
/models/layers/conv2d_same.py:
--------------------------------------------------------------------------------
1 | """ Conv2d w/ Same Padding
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import Tuple, Optional
9 |
10 | from .padding import pad_same, get_padding_value
11 |
12 |
13 | def conv2d_same(
14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
16 | x = pad_same(x, weight.shape[-2:], stride, dilation)
17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
18 |
19 |
20 | class Conv2dSame(nn.Conv2d):
21 | """ Tensorflow-like 'SAME' convolution wrapper for 2D convolutions
22 | """
23 |
24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
25 | padding=0, dilation=1, groups=1, bias=True):
26 | super(Conv2dSame, self).__init__(
27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
28 |
29 | def forward(self, x):
30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
31 |
32 |
33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
34 | padding = kwargs.pop('padding', '')
35 | kwargs.setdefault('bias', False)
36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
37 | if is_dynamic:
38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
39 | else:
40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
41 |
42 |
43 |
--------------------------------------------------------------------------------
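A minimal usage sketch (not from the repo; import path assumed): 'SAME' padding keeps out = ceil(in / stride) regardless of kernel size, matching TensorFlow, at the cost of input-size-dependent (dynamic) padding.

import torch
from models.layers.conv2d_same import Conv2dSame

conv = Conv2dSame(16, 32, kernel_size=3, stride=2)
x = torch.randn(1, 16, 15, 15)
print(conv(x).shape)  # torch.Size([1, 32, 8, 8]) == ceil(15 / 2)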
/models/layers/conv_bn_act.py:
--------------------------------------------------------------------------------
1 | """ Conv2d + BN + Act
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from torch import nn as nn
6 |
7 | from .create_conv2d import create_conv2d
8 | from .create_norm_act import get_norm_act_layer
9 |
10 |
11 | class ConvNormAct(nn.Module):
12 | def __init__(
13 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
14 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None):
15 | super(ConvNormAct, self).__init__()
16 | self.conv = create_conv2d(
17 | in_channels, out_channels, kernel_size, stride=stride,
18 | padding=padding, dilation=dilation, groups=groups, bias=bias)
19 |
20 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions
21 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
22 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
23 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
24 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
25 |
26 | @property
27 | def in_channels(self):
28 | return self.conv.in_channels
29 |
30 | @property
31 | def out_channels(self):
32 | return self.conv.out_channels
33 |
34 | def forward(self, x):
35 | x = self.conv(x)
36 | x = self.bn(x)
37 | return x
38 |
39 |
40 | ConvBnAct = ConvNormAct
41 |
42 |
43 | class ConvNormActAa(nn.Module):
44 | def __init__(
45 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
46 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
47 | super(ConvNormActAa, self).__init__()
48 | use_aa = aa_layer is not None
49 |
50 | self.conv = create_conv2d(
51 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
52 | padding=padding, dilation=dilation, groups=groups, bias=bias)
53 |
54 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions
55 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
56 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
57 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
58 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
59 | self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity()
60 |
61 | @property
62 | def in_channels(self):
63 | return self.conv.in_channels
64 |
65 | @property
66 | def out_channels(self):
67 | return self.conv.out_channels
68 |
69 | def forward(self, x):
70 | x = self.conv(x)
71 | x = self.bn(x)
72 | x = self.aa(x)
73 | return x
74 |
--------------------------------------------------------------------------------
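A minimal usage sketch (not from the repo; import path assumed): a conv followed by a fused norm + act, with the norm kept under the `.bn` attribute name for weight compatibility.

import torch
from models.layers.conv_bn_act import ConvNormAct

block = ConvNormAct(3, 64, kernel_size=3, stride=2)
print(block(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 112, 112])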
/models/layers/create_act.py:
--------------------------------------------------------------------------------
1 | """ Activation Factory
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | from typing import Union, Callable, Type
5 |
6 | from .activations import *
7 | from .activations_jit import *
8 | from .activations_me import *
9 | from .config import is_exportable, is_scriptable, is_no_jit
10 |
11 | # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
12 | # Also hardsigmoid, hardswish, and (as of 1.9) mish. This code will use the native version if present.
13 | # Eventually, the custom SiLU, Mish, and Hard* layers will be removed and only native variants will be used.
14 | _has_silu = 'silu' in dir(torch.nn.functional)
15 | _has_hardswish = 'hardswish' in dir(torch.nn.functional)
16 | _has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
17 | _has_mish = 'mish' in dir(torch.nn.functional)
18 |
19 |
20 | _ACT_FN_DEFAULT = dict(
21 | silu=F.silu if _has_silu else swish,
22 | swish=F.silu if _has_silu else swish,
23 | mish=F.mish if _has_mish else mish,
24 | relu=F.relu,
25 | relu6=F.relu6,
26 | leaky_relu=F.leaky_relu,
27 | elu=F.elu,
28 | celu=F.celu,
29 | selu=F.selu,
30 | gelu=gelu,
31 | sigmoid=sigmoid,
32 | tanh=tanh,
33 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
34 | hard_swish=F.hardswish if _has_hardswish else hard_swish,
35 | hard_mish=hard_mish,
36 | )
37 |
38 | _ACT_FN_JIT = dict(
39 | silu=F.silu if _has_silu else swish_jit,
40 | swish=F.silu if _has_silu else swish_jit,
41 | mish=F.mish if _has_mish else mish_jit,
42 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
43 | hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
44 | hard_mish=hard_mish_jit
45 | )
46 |
47 | _ACT_FN_ME = dict(
48 | silu=F.silu if _has_silu else swish_me,
49 | swish=F.silu if _has_silu else swish_me,
50 | mish=F.mish if _has_mish else mish_me,
51 | hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
52 | hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
53 | hard_mish=hard_mish_me,
54 | )
55 |
56 | _ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
57 | for a in _ACT_FNS:
58 | a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
59 | a.setdefault('hardswish', a.get('hard_swish'))
60 |
61 |
62 | _ACT_LAYER_DEFAULT = dict(
63 | silu=nn.SiLU if _has_silu else Swish,
64 | swish=nn.SiLU if _has_silu else Swish,
65 | mish=nn.Mish if _has_mish else Mish,
66 | relu=nn.ReLU,
67 | relu6=nn.ReLU6,
68 | leaky_relu=nn.LeakyReLU,
69 | elu=nn.ELU,
70 | prelu=PReLU,
71 | celu=nn.CELU,
72 | selu=nn.SELU,
73 | gelu=GELU,
74 | sigmoid=Sigmoid,
75 | tanh=Tanh,
76 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
77 | hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
78 | hard_mish=HardMish,
79 | )
80 |
81 | _ACT_LAYER_JIT = dict(
82 | silu=nn.SiLU if _has_silu else SwishJit,
83 | swish=nn.SiLU if _has_silu else SwishJit,
84 | mish=nn.Mish if _has_mish else MishJit,
85 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
86 | hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
87 | hard_mish=HardMishJit
88 | )
89 |
90 | _ACT_LAYER_ME = dict(
91 | silu=nn.SiLU if _has_silu else SwishMe,
92 | swish=nn.SiLU if _has_silu else SwishMe,
93 | mish=nn.Mish if _has_mish else MishMe,
94 | hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
95 | hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
96 | hard_mish=HardMishMe,
97 | )
98 |
99 | _ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
100 | for a in _ACT_LAYERS:
101 | a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
102 | a.setdefault('hardswish', a.get('hard_swish'))
103 |
104 |
105 | def get_act_fn(name: Union[Callable, str] = 'relu'):
106 | """ Activation Function Factory
107 | Fetching activation fns by name with this function allows export or torch script friendly
108 | functions to be returned dynamically based on current config.
109 | """
110 | if not name:
111 | return None
112 | if isinstance(name, Callable):
113 | return name
114 | if not (is_no_jit() or is_exportable() or is_scriptable()):
115 | # If not exporting or scripting the model, first look for a memory-efficient version with
116 | # custom autograd, then fall back to the jit-scripted and finally the default implementations
117 | if name in _ACT_FN_ME:
118 | return _ACT_FN_ME[name]
119 | if not (is_no_jit() or is_exportable()):
120 | if name in _ACT_FN_JIT:
121 | return _ACT_FN_JIT[name]
122 | return _ACT_FN_DEFAULT[name]
123 |
124 |
125 | def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
126 | """ Activation Layer Factory
127 | Fetching activation layers by name with this function allows export or torch script friendly
128 | functions to be returned dynamically based on current config.
129 | """
130 | if not name:
131 | return None
132 | if not isinstance(name, str):
133 | # callable, module, etc
134 | return name
135 | if not (is_no_jit() or is_exportable() or is_scriptable()):
136 | if name in _ACT_LAYER_ME:
137 | return _ACT_LAYER_ME[name]
138 | if not (is_no_jit() or is_exportable()):
139 | if name in _ACT_LAYER_JIT:
140 | return _ACT_LAYER_JIT[name]
141 | return _ACT_LAYER_DEFAULT[name]
142 |
143 |
144 | def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
145 | act_layer = get_act_layer(name)
146 | if act_layer is None:
147 | return None
148 | return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
149 |
--------------------------------------------------------------------------------
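A minimal usage sketch (not from the repo; import path assumed): the factory resolves a name to the best available variant, i.e. the native torch op if present, otherwise the memory-efficient or jit-scripted custom version, subject to the exportable/scriptable/no_jit flags.

import torch
from models.layers.create_act import create_act_layer, get_act_fn

act = create_act_layer('hard_swish', inplace=True)  # resolves to nn.Hardswish on recent PyTorch
fn = get_act_fn('mish')
print(fn(torch.zeros(3)))  # tensor([0., 0., 0.])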
/models/layers/create_attn.py:
--------------------------------------------------------------------------------
1 | """ Attention Factory
2 |
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 | import torch
6 | from functools import partial
7 |
8 | from .bottleneck_attn import BottleneckAttn
9 | from .cbam import CbamModule, LightCbamModule
10 | from .eca import EcaModule, CecaModule
11 | from .gather_excite import GatherExcite
12 | from .global_context import GlobalContext
13 | from .halo_attn import HaloAttn
14 | from .lambda_layer import LambdaLayer
15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn
16 | from .selective_kernel import SelectiveKernel
17 | from .split_attn import SplitAttn
18 | from .squeeze_excite import SEModule, EffectiveSEModule
19 |
20 |
21 | def get_attn(attn_type):
22 | if isinstance(attn_type, torch.nn.Module):
23 | return attn_type
24 | module_cls = None
25 | if attn_type is not None:
26 | if isinstance(attn_type, str):
27 | attn_type = attn_type.lower()
28 | # Lightweight attention modules (channel and/or coarse spatial).
29 | # Typically added to existing network architecture blocks in addition to existing convolutions.
30 | if attn_type == 'se':
31 | module_cls = SEModule
32 | elif attn_type == 'ese':
33 | module_cls = EffectiveSEModule
34 | elif attn_type == 'eca':
35 | module_cls = EcaModule
36 | elif attn_type == 'ecam':
37 | module_cls = partial(EcaModule, use_mlp=True)
38 | elif attn_type == 'ceca':
39 | module_cls = CecaModule
40 | elif attn_type == 'ge':
41 | module_cls = GatherExcite
42 | elif attn_type == 'gc':
43 | module_cls = GlobalContext
44 | elif attn_type == 'gca':
45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
46 | elif attn_type == 'cbam':
47 | module_cls = CbamModule
48 | elif attn_type == 'lcbam':
49 | module_cls = LightCbamModule
50 |
51 | # Attention / attention-like modules w/ significant params
52 | # Typically replace some of the existing workhorse convs in a network architecture.
53 | # All of these accept a stride argument and can spatially downsample the input.
54 | elif attn_type == 'sk':
55 | module_cls = SelectiveKernel
56 | elif attn_type == 'splat':
57 | module_cls = SplitAttn
58 |
59 | # Self-attention / attention-like modules w/ significant compute and/or params
60 | # Typically replace some of the existing workhorse convs in a network architecture.
61 | # All of these accept a stride argument and can spatially downsample the input.
62 | elif attn_type == 'lambda':
63 | return LambdaLayer
64 | elif attn_type == 'bottleneck':
65 | return BottleneckAttn
66 | elif attn_type == 'halo':
67 | return HaloAttn
68 | elif attn_type == 'nl':
69 | module_cls = NonLocalAttn
70 | elif attn_type == 'bat':
71 | module_cls = BatNonLocalAttn
72 |
73 | # Woops!
74 | else:
75 | assert False, "Invalid attn module (%s)" % attn_type
76 | elif isinstance(attn_type, bool):
77 | if attn_type:
78 | module_cls = SEModule
79 | else:
80 | module_cls = attn_type
81 | return module_cls
82 |
83 |
84 | def create_attn(attn_type, channels, **kwargs):
85 | module_cls = get_attn(attn_type)
86 | if module_cls is not None:
87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
88 | return module_cls(channels, **kwargs)
89 | return None
90 |
--------------------------------------------------------------------------------
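A minimal usage sketch (not from the repo; import path assumed): the channel count is always the first positional argument, and a falsy attn_type yields no module.

import torch
from models.layers.create_attn import create_attn

se = create_attn('se', 256)            # lightweight channel attention (SE)
x = torch.randn(2, 256, 14, 14)
print(se(x).shape)                     # torch.Size([2, 256, 14, 14])
print(create_attn(None, 256) is None)  # True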
/models/layers/create_conv2d.py:
--------------------------------------------------------------------------------
1 | """ Create Conv2d Factory Method
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | from .mixed_conv2d import MixedConv2d
7 | from .cond_conv2d import CondConv2d
8 | from .conv2d_same import create_conv2d_pad
9 |
10 |
11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
12 | """ Select a 2d convolution implementation based on arguments
13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
14 |
15 | Used extensively by EfficientNet, MobileNetv3 and related networks.
16 | """
17 | if isinstance(kernel_size, list):
18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
19 | if 'groups' in kwargs:
20 | groups = kwargs.pop('groups')
21 | if groups == in_channels:
22 | kwargs['depthwise'] = True
23 | else:
24 | assert groups == 1
25 | # We're going to use only lists for defining the MixedConv2d kernel groups,
26 | # ints, tuples and other iterables will continue to pass through to a normal conv, specifying (h, w).
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
28 | else:
29 | depthwise = kwargs.pop('depthwise', False)
30 | # for DW, out_channels must be a multiple of in_channels since out_channels % groups == 0 must hold
31 | groups = in_channels if depthwise else kwargs.pop('groups', 1)
32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
34 | else:
35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
36 | return m
37 |
--------------------------------------------------------------------------------
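A minimal usage sketch (not from the repo; import path assumed): depthwise=True (or groups == in_channels) routes to a depthwise conv, while a list kernel_size selects MixedConv2d.

import torch
from models.layers.create_conv2d import create_conv2d

dw = create_conv2d(64, 64, kernel_size=3, stride=1, depthwise=True)
print(type(dw).__name__, dw.groups)  # Conv2d 64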
/models/layers/create_norm_act.py:
--------------------------------------------------------------------------------
1 | """ NormAct (Normalization + Activation Layer) Factory
2 |
3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
4 | instances in models. Where these are used it will be possible to swap separate BN + act layers with
5 | combined modules like IABN or EvoNorms.
6 |
7 | Hacked together by / Copyright 2020 Ross Wightman
8 | """
9 | import types
10 | import functools
11 |
12 | from .evo_norm import *
13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
15 | from .inplace_abn import InplaceAbn
16 |
17 | _NORM_ACT_MAP = dict(
18 | batchnorm=BatchNormAct2d,
19 | batchnorm2d=BatchNormAct2d,
20 | groupnorm=GroupNormAct,
21 | layernorm=LayerNormAct,
22 | layernorm2d=LayerNormAct2d,
23 | evonormb0=EvoNorm2dB0,
24 | evonormb1=EvoNorm2dB1,
25 | evonormb2=EvoNorm2dB2,
26 | evonorms0=EvoNorm2dS0,
27 | evonorms0a=EvoNorm2dS0a,
28 | evonorms1=EvoNorm2dS1,
29 | evonorms1a=EvoNorm2dS1a,
30 | evonorms2=EvoNorm2dS2,
31 | evonorms2a=EvoNorm2dS2a,
32 | frn=FilterResponseNormAct2d,
33 | frntlu=FilterResponseNormTlu2d,
34 | inplaceabn=InplaceAbn,
35 | iabn=InplaceAbn,
36 | )
37 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
38 | # has act_layer arg to define act type
39 | _NORM_ACT_REQUIRES_ARG = {
40 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
41 |
42 |
43 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
44 | layer = get_norm_act_layer(layer_name, act_layer=act_layer)
45 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
46 | if jit:
47 | layer_instance = torch.jit.script(layer_instance)
48 | return layer_instance
49 |
50 |
51 | def get_norm_act_layer(norm_layer, act_layer=None):
52 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
53 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
54 | norm_act_kwargs = {}
55 |
56 | # unbind partial fn, so args can be rebound later
57 | if isinstance(norm_layer, functools.partial):
58 | norm_act_kwargs.update(norm_layer.keywords)
59 | norm_layer = norm_layer.func
60 |
61 | if isinstance(norm_layer, str):
62 | layer_name = norm_layer.replace('_', '').lower().split('-')[0]
63 | norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)
64 | elif norm_layer in _NORM_ACT_TYPES:
65 | norm_act_layer = norm_layer
66 | elif isinstance(norm_layer, types.FunctionType):
67 | # if function type, must be a lambda/fn that creates a norm_act layer
68 | norm_act_layer = norm_layer
69 | else:
70 | type_name = norm_layer.__name__.lower()
71 | if type_name.startswith('batchnorm'):
72 | norm_act_layer = BatchNormAct2d
73 | elif type_name.startswith('groupnorm'):
74 | norm_act_layer = GroupNormAct
75 | elif type_name.startswith('layernorm2d'):
76 | norm_act_layer = LayerNormAct2d
77 | elif type_name.startswith('layernorm'):
78 | norm_act_layer = LayerNormAct
79 | else:
80 | assert False, f"No equivalent norm_act layer for {type_name}"
81 |
82 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
83 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
84 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
85 | norm_act_kwargs.setdefault('act_layer', act_layer)
86 | if norm_act_kwargs:
87 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
88 | return norm_act_layer
89 |
--------------------------------------------------------------------------------
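A minimal usage sketch (not from the repo; import path assumed): a plain norm class (or name) is mapped to its norm+act counterpart, so models defined with separate BN + act can be swapped to fused variants like EvoNorm or IABN.

import torch.nn as nn
from models.layers.create_norm_act import get_norm_act_layer

layer_cls = get_norm_act_layer(nn.BatchNorm2d, act_layer=nn.ReLU)
norm_act = layer_cls(64)  # BatchNormAct2d(64) with a fused ReLU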
/models/layers/drop.py:
--------------------------------------------------------------------------------
1 | """ DropBlock, DropPath
2 |
3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
4 |
5 | Papers:
6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
7 |
8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
9 |
10 | Code:
11 | DropBlock impl inspired by two Tensorflow impl that I liked:
12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman
16 | """
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | def drop_block_2d(
23 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
24 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
25 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
26 |
27 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
28 | runs with success, but needs further validation and possibly optimization for lower runtime impact.
29 | """
30 | B, C, H, W = x.shape
31 | total_size = W * H
32 | clipped_block_size = min(block_size, min(W, H))
33 | # seed_drop_rate, the gamma parameter
34 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
35 | (W - block_size + 1) * (H - block_size + 1))
36 |
37 | # Forces the block to be inside the feature map.
38 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
39 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
40 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
41 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
42 |
43 | if batchwise:
44 | # one mask for whole batch, quite a bit faster
45 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
46 | else:
47 | uniform_noise = torch.rand_like(x)
48 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
49 | block_mask = -F.max_pool2d(
50 | -block_mask,
51 | kernel_size=clipped_block_size, # block_size,
52 | stride=1,
53 | padding=clipped_block_size // 2)
54 |
55 | if with_noise:
56 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
57 | if inplace:
58 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
59 | else:
60 | x = x * block_mask + normal_noise * (1 - block_mask)
61 | else:
62 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
63 | if inplace:
64 | x.mul_(block_mask * normalize_scale)
65 | else:
66 | x = x * block_mask * normalize_scale
67 | return x
68 |
69 |
70 | def drop_block_fast_2d(
71 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
72 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
73 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
74 |
75 | DropBlock with an experimental gaussian noise option. Simplified from above, without concern for the valid
76 | block mask at edges.
77 | """
78 | B, C, H, W = x.shape
79 | total_size = W * H
80 | clipped_block_size = min(block_size, min(W, H))
81 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
82 | (W - block_size + 1) * (H - block_size + 1))
83 |
84 | block_mask = torch.empty_like(x).bernoulli_(gamma)
85 | block_mask = F.max_pool2d(
86 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
87 |
88 | if with_noise:
89 | normal_noise = torch.empty_like(x).normal_()
90 | if inplace:
91 | x.mul_(1. - block_mask).add_(normal_noise * block_mask)
92 | else:
93 | x = x * (1. - block_mask) + normal_noise * block_mask
94 | else:
95 | block_mask = 1 - block_mask
96 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
97 | if inplace:
98 | x.mul_(block_mask * normalize_scale)
99 | else:
100 | x = x * block_mask * normalize_scale
101 | return x
102 |
103 |
104 | class DropBlock2d(nn.Module):
105 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
106 | """
107 |
108 | def __init__(
109 | self,
110 | drop_prob: float = 0.1,
111 | block_size: int = 7,
112 | gamma_scale: float = 1.0,
113 | with_noise: bool = False,
114 | inplace: bool = False,
115 | batchwise: bool = False,
116 | fast: bool = True):
117 | super(DropBlock2d, self).__init__()
118 | self.drop_prob = drop_prob
119 | self.gamma_scale = gamma_scale
120 | self.block_size = block_size
121 | self.with_noise = with_noise
122 | self.inplace = inplace
123 | self.batchwise = batchwise
124 | self.fast = fast # FIXME finish comparisons of fast vs not
125 |
126 | def forward(self, x):
127 | if not self.training or not self.drop_prob:
128 | return x
129 | if self.fast:
130 | return drop_block_fast_2d(
131 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
132 | else:
133 | return drop_block_2d(
134 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
135 |
136 |
137 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
138 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
139 |
140 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
141 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
142 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
143 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
144 | 'survival rate' as the argument.
145 |
146 | """
147 | if drop_prob == 0. or not training:
148 | return x
149 | keep_prob = 1 - drop_prob
150 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
151 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
152 | if keep_prob > 0.0 and scale_by_keep:
153 | random_tensor.div_(keep_prob)
154 | return x * random_tensor
155 |
156 |
157 | class DropPath(nn.Module):
158 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
159 | """
160 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
161 | super(DropPath, self).__init__()
162 | self.drop_prob = drop_prob
163 | self.scale_by_keep = scale_by_keep
164 |
165 | def forward(self, x):
166 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
167 |
--------------------------------------------------------------------------------
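A quick numeric check of the drop_path rescaling (a sketch, not from the repo; import path assumed): with drop_prob=0.2 each sample's residual branch is zeroed with probability 0.2 and survivors are scaled by 1/0.8, so the expected activation is unchanged.

import torch
from models.layers.drop import drop_path

x = torch.ones(10000, 8)
y = drop_path(x, drop_prob=0.2, training=True)
print(sorted(y.unique().tolist()))  # [0.0, 1.25]
print(y.mean().item())              # ~1.0 in expectation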
/models/layers/eca.py:
--------------------------------------------------------------------------------
1 | """
2 | ECA module from ECAnet
3 |
4 | paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
5 | https://arxiv.org/abs/1910.03151
6 |
7 | Original ECA model borrowed from https://github.com/BangguWu/ECANet
8 |
9 | Modified circular ECA implementation and adaption for use in timm package
10 | by Chris Ha https://github.com/VRandme
11 |
12 | Original License:
13 |
14 | MIT License
15 |
16 | Copyright (c) 2019 BangguWu, Qilong Wang
17 |
18 | Permission is hereby granted, free of charge, to any person obtaining a copy
19 | of this software and associated documentation files (the "Software"), to deal
20 | in the Software without restriction, including without limitation the rights
21 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
22 | copies of the Software, and to permit persons to whom the Software is
23 | furnished to do so, subject to the following conditions:
24 |
25 | The above copyright notice and this permission notice shall be included in all
26 | copies or substantial portions of the Software.
27 |
28 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
29 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
30 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
31 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
32 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
33 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34 | SOFTWARE.
35 | """
36 | import math
37 | from torch import nn
38 | import torch.nn.functional as F
39 |
40 |
41 | from .create_act import create_act_layer
42 | from .helpers import make_divisible
43 |
44 |
45 | class EcaModule(nn.Module):
46 | """Constructs an ECA module.
47 |
48 | Args:
49 | channels: Number of channels of the input feature map. When given, the kernel size
50 | is selected adaptively from the channel count via the mapping function in
51 | the original paper https://arxiv.org/pdf/1910.03151.pdf
52 | (default=None; if the channel count is not given, the kernel_size arg is
53 | used as the kernel size directly)
54 | kernel_size: Kernel size of the 1d conv when channels is not given (default=3)
55 | gamma: parameter of the adaptive kernel_size mapping function, see above
56 | beta: parameter of the adaptive kernel_size mapping function, see above
57 | act_layer: optional non-linearity after conv, enables conv bias; this is an experiment
58 | gate_layer: gating non-linearity to use
59 | """
60 | def __init__(
61 | self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
62 | rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
63 | super(EcaModule, self).__init__()
64 | if channels is not None:
65 | t = int(abs(math.log(channels, 2) + beta) / gamma)
66 | kernel_size = max(t if t % 2 else t + 1, 3)
67 | assert kernel_size % 2 == 1
68 | padding = (kernel_size - 1) // 2
69 | if use_mlp:
70 | # NOTE 'mlp' mode is a timm experiment, not in paper
71 | assert channels is not None
72 | if rd_channels is None:
73 | rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
74 | act_layer = act_layer or nn.ReLU
75 | self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
76 | self.act = create_act_layer(act_layer)
77 | self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
78 | else:
79 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
80 | self.act = None
81 | self.conv2 = None
82 | self.gate = create_act_layer(gate_layer)
83 |
84 | def forward(self, x):
85 | y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv
86 | y = self.conv(y)
87 | if self.conv2 is not None:
88 | y = self.act(y)
89 | y = self.conv2(y)
90 | y = self.gate(y).view(x.shape[0], -1, 1, 1)
91 | return x * y.expand_as(x)
92 |
93 |
94 | EfficientChannelAttn = EcaModule # alias
95 |
96 |
97 | class CecaModule(nn.Module):
98 | """Constructs a circular ECA module.
99 |
100 |     ECA module where the conv uses circular padding rather than zero padding.
101 |     Unlike the spatial dimensions, the channels have no inherent ordering or
102 |     locality. Although a 1d conv over channels applies such an assumption anyway, there is
103 |     no need to prevent the channels at either "edge" from being circularly adapted to
104 |     each other. This increases connectivity and may improve performance metrics
105 |     (accuracy, robustness) without significantly impacting resource metrics
106 |     (parameter size, throughput, latency, etc.)
107 |
108 |     Args:
109 |         channels: Number of channels of the input feature map. When given, it is used to
110 |             select the 1d conv kernel size adaptively from the channel count.
111 |             (default=None; if the channel count is not given, the kernel_size
112 |             argument is used directly for the kernel size.)
113 |         kernel_size: Kernel size for the 1d conv when channels is not given (default=3)
114 |         gamma: parameter of the adaptive kernel-size mapping function, refer to the
115 |             original paper https://arxiv.org/pdf/1910.03151.pdf
116 |         beta: parameter of the adaptive kernel-size mapping function, see above
117 |         act_layer: optional non-linearity after the conv, enables conv bias; this is an experiment
118 |         gate_layer: gating non-linearity to use
119 |     """
120 |
121 | def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
122 | super(CecaModule, self).__init__()
123 | if channels is not None:
124 | t = int(abs(math.log(channels, 2) + beta) / gamma)
125 | kernel_size = max(t if t % 2 else t + 1, 3)
126 | has_act = act_layer is not None
127 | assert kernel_size % 2 == 1
128 |
129 | # PyTorch circular padding mode is buggy as of pytorch 1.4
130 | # see https://github.com/pytorch/pytorch/pull/17240
131 | # implement manual circular padding
132 | self.padding = (kernel_size - 1) // 2
133 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
134 | self.gate = create_act_layer(gate_layer)
135 |
136 | def forward(self, x):
137 | y = x.mean((2, 3)).view(x.shape[0], 1, -1)
138 |         # manually implement circular padding; F.pad's circular mode does not appear to be affected by the bug
139 | y = F.pad(y, (self.padding, self.padding), mode='circular')
140 | y = self.conv(y)
141 | y = self.gate(y).view(x.shape[0], -1, 1, 1)
142 | return x * y.expand_as(x)
143 |
144 |
145 | CircularEfficientChannelAttn = CecaModule
146 |
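147 |
148 | if __name__ == '__main__':
149 |     # Illustrative sketch (an addition, not part of the original file): with channels
150 |     # given, the 1d conv kernel size is chosen adaptively as an odd value near
151 |     # abs(log2(channels) + beta) / gamma, e.g. channels=256 -> t=4 -> kernel_size=5.
152 |     import torch
153 |     attn = EcaModule(channels=256)
154 |     assert attn.conv.kernel_size == (5,)
155 |     x = torch.randn(2, 256, 14, 14)
156 |     assert attn(x).shape == x.shape  # a per-channel sigmoid gate rescales the input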
--------------------------------------------------------------------------------
/models/layers/filter_response_norm.py:
--------------------------------------------------------------------------------
1 | """ Filter Response Norm in PyTorch
2 |
3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
4 |
5 | Hacked together by / Copyright 2021 Ross Wightman
6 | """
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .create_act import create_act_layer
11 | from .trace_utils import _assert
12 |
13 |
14 | def inv_instance_rms(x, eps: float = 1e-5):
15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
16 | return rms.expand(x.shape)
17 |
18 |
19 | class FilterResponseNormTlu2d(nn.Module):
20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
21 | super(FilterResponseNormTlu2d, self).__init__()
22 | self.apply_act = apply_act # apply activation (non-linearity)
23 | self.rms = rms
24 | self.eps = eps
25 | self.weight = nn.Parameter(torch.ones(num_features))
26 | self.bias = nn.Parameter(torch.zeros(num_features))
27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | nn.init.ones_(self.weight)
32 | nn.init.zeros_(self.bias)
33 | if self.tau is not None:
34 | nn.init.zeros_(self.tau)
35 |
36 | def forward(self, x):
37 | _assert(x.dim() == 4, 'expected 4D input')
38 | x_dtype = x.dtype
39 | v_shape = (1, -1, 1, 1)
40 | x = x * inv_instance_rms(x, self.eps)
41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
43 |
44 |
45 | class FilterResponseNormAct2d(nn.Module):
46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
47 | super(FilterResponseNormAct2d, self).__init__()
48 | if act_layer is not None and apply_act:
49 | self.act = create_act_layer(act_layer, inplace=inplace)
50 | else:
51 | self.act = nn.Identity()
52 | self.rms = rms
53 | self.eps = eps
54 | self.weight = nn.Parameter(torch.ones(num_features))
55 | self.bias = nn.Parameter(torch.zeros(num_features))
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.ones_(self.weight)
60 | nn.init.zeros_(self.bias)
61 |
62 | def forward(self, x):
63 | _assert(x.dim() == 4, 'expected 4D input')
64 | x_dtype = x.dtype
65 | v_shape = (1, -1, 1, 1)
66 | x = x * inv_instance_rms(x, self.eps)
67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
68 | return self.act(x)
69 |
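70 |
71 | if __name__ == '__main__':
72 |     # Illustrative sketch (an addition, not part of the original file): FRN divides
73 |     # each (sample, channel) plane by its RMS over H x W, applies a learned per-channel
74 |     # scale/shift, and the TLU variant clamps from below with a learned tau.
75 |     frn = FilterResponseNormTlu2d(32)
76 |     x = torch.randn(2, 32, 8, 8)
77 |     y = frn(x)
78 |     assert y.shape == x.shape
79 |     assert torch.all(y >= 0)  # tau starts at 0, so the TLU acts like ReLU at init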
--------------------------------------------------------------------------------
/models/layers/gather_excite.py:
--------------------------------------------------------------------------------
1 | """ Gather-Excite Attention Block
2 |
3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
4 |
5 | Official code is here, but it's only a partial impl in Caffe: https://github.com/hujie-frank/GENet
6 |
7 | I've tried to support all of the extent settings, both w/ and w/o params. I don't believe I've seen
8 | another impl that covers all of the cases.
9 |
10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
11 |
12 | Hacked together by / Copyright 2021 Ross Wightman
13 | """
14 | import math
15 |
16 | from torch import nn as nn
17 | import torch.nn.functional as F
18 |
19 | from .create_act import create_act_layer, get_act_layer
20 | from .create_conv2d import create_conv2d
21 | from .helpers import make_divisible
22 | from .mlp import ConvMlp
23 |
24 |
25 | class GatherExcite(nn.Module):
26 | """ Gather-Excite Attention Module
27 | """
28 | def __init__(
29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
32 | super(GatherExcite, self).__init__()
33 | self.add_maxpool = add_maxpool
34 | act_layer = get_act_layer(act_layer)
35 | self.extent = extent
36 | if extra_params:
37 | self.gather = nn.Sequential()
38 | if extent == 0:
39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
40 | self.gather.add_module(
41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
42 | if norm_layer:
43 |                     self.gather.add_module('norm1', nn.BatchNorm2d(channels))
44 | else:
45 | assert extent % 2 == 0
46 | num_conv = int(math.log2(extent))
47 | for i in range(num_conv):
48 | self.gather.add_module(
49 | f'conv{i + 1}',
50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
51 | if norm_layer:
52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
53 | if i != num_conv - 1:
54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
55 | else:
56 | self.gather = None
57 | if self.extent == 0:
58 | self.gk = 0
59 | self.gs = 0
60 | else:
61 | assert extent % 2 == 0
62 | self.gk = self.extent * 2 - 1
63 | self.gs = self.extent
64 |
65 | if not rd_channels:
66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
68 | self.gate = create_act_layer(gate_layer)
69 |
70 | def forward(self, x):
71 | size = x.shape[-2:]
72 | if self.gather is not None:
73 | x_ge = self.gather(x)
74 | else:
75 | if self.extent == 0:
76 | # global extent
77 |                 x_ge = x.mean(dim=(2, 3), keepdim=True)
78 | if self.add_maxpool:
79 | # experimental codepath, may remove or change
80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
81 | else:
82 | x_ge = F.avg_pool2d(
83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
84 | if self.add_maxpool:
85 | # experimental codepath, may remove or change
86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
87 | x_ge = self.mlp(x_ge)
88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
89 | x_ge = F.interpolate(x_ge, size=size)
90 | return x * self.gate(x_ge)
91 |
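92 |
93 | if __name__ == '__main__':
94 |     # Illustrative sketch (an addition, not part of the original file): extent=0
95 |     # gathers globally (average pool to 1x1, SE-like), while a power-of-2 extent
96 |     # gathers over local windows and the excite signal is interpolated back up.
97 |     import torch
98 |     x = torch.randn(2, 64, 16, 16)
99 |     ge_global = GatherExcite(64, extent=0)  # global gather, SE-equivalent config
100 |     ge_local = GatherExcite(64, extent=2)   # pooled gather, kernel 3 / stride 2
101 |     assert ge_global(x).shape == ge_local(x).shape == x.shape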
--------------------------------------------------------------------------------
/models/layers/global_context.py:
--------------------------------------------------------------------------------
1 | """ Global Context Attention Block
2 |
3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
4 | - https://arxiv.org/abs/1904.11492
5 |
6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet
7 |
8 | Hacked together by / Copyright 2021 Ross Wightman
9 | """
10 | from torch import nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .create_act import create_act_layer, get_act_layer
14 | from .helpers import make_divisible
15 | from .mlp import ConvMlp
16 | from .norm import LayerNorm2d
17 |
18 |
19 | class GlobalContext(nn.Module):
20 |
21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
23 | super(GlobalContext, self).__init__()
24 | act_layer = get_act_layer(act_layer)
25 |
26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
27 |
28 | if rd_channels is None:
29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
30 | if fuse_add:
31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
32 | else:
33 | self.mlp_add = None
34 | if fuse_scale:
35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
36 | else:
37 | self.mlp_scale = None
38 |
39 | self.gate = create_act_layer(gate_layer)
40 | self.init_last_zero = init_last_zero
41 | self.reset_parameters()
42 |
43 | def reset_parameters(self):
44 | if self.conv_attn is not None:
45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
46 | if self.mlp_add is not None:
47 | nn.init.zeros_(self.mlp_add.fc2.weight)
48 |
49 | def forward(self, x):
50 | B, C, H, W = x.shape
51 |
52 | if self.conv_attn is not None:
53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
56 | context = context.view(B, C, 1, 1)
57 | else:
58 | context = x.mean(dim=(2, 3), keepdim=True)
59 |
60 | if self.mlp_scale is not None:
61 | mlp_x = self.mlp_scale(context)
62 | x = x * self.gate(mlp_x)
63 | if self.mlp_add is not None:
64 | mlp_x = self.mlp_add(context)
65 | x = x + mlp_x
66 |
67 | return x
68 |
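69 |
70 | if __name__ == '__main__':
71 |     # Illustrative sketch (an addition, not part of the original file): a 1x1 conv +
72 |     # softmax pools the input into a (B, C, 1, 1) context vector, which an MLP then
73 |     # fuses back into the features by scaling and/or addition.
74 |     import torch
75 |     gc = GlobalContext(64, fuse_add=True, fuse_scale=True)
76 |     x = torch.randn(2, 64, 7, 7)
77 |     assert gc(x).shape == x.shape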
--------------------------------------------------------------------------------
/models/layers/helpers.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | import collections.abc
7 |
8 |
9 | # From PyTorch internals
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, collections.abc.Iterable):
13 | return x
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 |
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 |
24 |
25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
26 | min_value = min_value or divisor
27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28 | # Make sure that round down does not go down by more than 10%.
29 | if new_v < round_limit * v:
30 | new_v += divisor
31 | return new_v
32 |
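33 |
34 | if __name__ == '__main__':
35 |     # Worked example (an addition, not part of the original file): make_divisible
36 |     # rounds v to the nearest multiple of divisor, but never below min_value, and
37 |     # bumps up one divisor if rounding down would drop below round_limit * v.
38 |     assert make_divisible(30) == 32     # nearest multiple of 8
39 |     assert make_divisible(35) == 32     # 32 >= 0.9 * 35, rounding down is acceptable
40 |     assert make_divisible(35.9) == 40   # 32 < 0.9 * 35.9 ~= 32.3, bumped up one divisor
41 |     assert make_divisible(3) == 8       # clamped to min_value (= divisor by default)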
--------------------------------------------------------------------------------
/models/layers/inplace_abn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 | try:
5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync
6 | has_iabn = True
7 | except ImportError:
8 | has_iabn = False
9 |
10 | def inplace_abn(x, weight, bias, running_mean, running_var,
11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
12 | raise ImportError(
13 |             "Please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
14 |
15 | def inplace_abn_sync(**kwargs):
16 | inplace_abn(**kwargs)
17 |
18 |
19 | class InplaceAbn(nn.Module):
20 | """Activated Batch Normalization
21 |
22 | This gathers a BatchNorm and an activation function in a single module
23 |
24 | Parameters
25 | ----------
26 | num_features : int
27 | Number of feature channels in the input and output.
28 | eps : float
29 | Small constant to prevent numerical issues.
30 | momentum : float
31 | Momentum factor applied to compute running statistics.
32 | affine : bool
33 | If `True` apply learned scale and shift transformation after normalization.
34 | act_layer : str or nn.Module type
35 | Name or type of the activation functions, one of: `leaky_relu`, `elu`
36 | act_param : float
37 | Negative slope for the `leaky_relu` activation.
38 | """
39 |
40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None):
42 | super(InplaceAbn, self).__init__()
43 | self.num_features = num_features
44 | self.affine = affine
45 | self.eps = eps
46 | self.momentum = momentum
47 | if apply_act:
48 | if isinstance(act_layer, str):
49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '')
50 | self.act_name = act_layer if act_layer else 'identity'
51 | else:
52 | # convert act layer passed as type to string
53 | if act_layer == nn.ELU:
54 | self.act_name = 'elu'
55 | elif act_layer == nn.LeakyReLU:
56 | self.act_name = 'leaky_relu'
57 | elif act_layer is None or act_layer == nn.Identity:
58 | self.act_name = 'identity'
59 | else:
60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN'
61 | else:
62 | self.act_name = 'identity'
63 | self.act_param = act_param
64 | if self.affine:
65 | self.weight = nn.Parameter(torch.ones(num_features))
66 | self.bias = nn.Parameter(torch.zeros(num_features))
67 | else:
68 | self.register_parameter('weight', None)
69 | self.register_parameter('bias', None)
70 | self.register_buffer('running_mean', torch.zeros(num_features))
71 | self.register_buffer('running_var', torch.ones(num_features))
72 | self.reset_parameters()
73 |
74 | def reset_parameters(self):
75 | nn.init.constant_(self.running_mean, 0)
76 | nn.init.constant_(self.running_var, 1)
77 | if self.affine:
78 | nn.init.constant_(self.weight, 1)
79 | nn.init.constant_(self.bias, 0)
80 |
81 | def forward(self, x):
82 | output = inplace_abn(
83 | x, self.weight, self.bias, self.running_mean, self.running_var,
84 | self.training, self.momentum, self.eps, self.act_name, self.act_param)
85 | if isinstance(output, tuple):
86 | output = output[0]
87 | return output
88 |
--------------------------------------------------------------------------------
/models/layers/lambda_layer.py:
--------------------------------------------------------------------------------
1 | """ Lambda Layer
2 |
3 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
4 | - https://arxiv.org/abs/2102.08602
5 |
6 | @misc{2102.08602,
7 | Author = {Irwan Bello},
8 | Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
9 | Year = {2021},
10 | }
11 |
12 | Status:
13 | This impl is a WIP. Code snippets in the paper were used as reference but
14 | good chance some details are missing/wrong.
15 |
16 | I've only implemented local lambda conv based pos embeddings.
17 |
18 | For a PyTorch impl that includes other embedding options checkout
19 | https://github.com/lucidrains/lambda-networks
20 |
21 | Hacked together by / Copyright 2021 Ross Wightman
22 | """
23 | import torch
24 | from torch import nn
25 | import torch.nn.functional as F
26 |
27 | from .helpers import to_2tuple, make_divisible
28 | from .weight_init import trunc_normal_
29 |
30 |
31 | def rel_pos_indices(size):
32 | size = to_2tuple(size)
33 | pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
34 | rel_pos = pos[:, None, :] - pos[:, :, None]
35 | rel_pos[0] += size[0] - 1
36 | rel_pos[1] += size[1] - 1
37 | return rel_pos # 2, H * W, H * W
38 |
39 |
40 | class LambdaLayer(nn.Module):
41 | """Lambda Layer
42 |
43 | Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
44 | - https://arxiv.org/abs/2102.08602
45 |
46 | NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
47 |
48 | The internal dimensions of the lambda module are controlled via the interaction of several arguments.
49 | * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
50 | * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
51 |         * the query (q) and key (k) dimensions are determined by
52 |         * dim_head = make_divisible(dim_out * qk_ratio) // num_heads if dim_head is None
53 |         * q = num_heads * dim_head, k = dim_head
54 |         * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head is not set
55 |
56 | Args:
57 | dim (int): input dimension to the module
58 | dim_out (int): output dimension of the module, same as dim if not set
59 | feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W
60 | stride (int): output stride of the module, avg pool used if stride == 2
61 | num_heads (int): parallel attention heads.
62 |         dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
63 | r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
64 | qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
65 | qkv_bias (bool): add bias to q, k, and v projections
66 | """
67 | def __init__(
68 | self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
69 | qk_ratio=1.0, qkv_bias=False):
70 | super().__init__()
71 | dim_out = dim_out or dim
72 |         assert dim_out % num_heads == 0, 'dim_out must be divisible by num_heads'
73 | self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
74 | self.num_heads = num_heads
75 | self.dim_v = dim_out // num_heads
76 |
77 | self.qkv = nn.Conv2d(
78 | dim,
79 | num_heads * self.dim_qk + self.dim_qk + self.dim_v,
80 | kernel_size=1, bias=qkv_bias)
81 | self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
82 | self.norm_v = nn.BatchNorm2d(self.dim_v)
83 |
84 | if r is not None:
85 | # local lambda convolution for pos
86 | self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
87 | self.pos_emb = None
88 | self.rel_pos_indices = None
89 | else:
90 | # relative pos embedding
91 | assert feat_size is not None
92 | feat_size = to_2tuple(feat_size)
93 | rel_size = [2 * s - 1 for s in feat_size]
94 | self.conv_lambda = None
95 | self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
96 | self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
97 |
98 | self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
99 |
100 | self.reset_parameters()
101 |
102 | def reset_parameters(self):
103 | trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
104 | if self.conv_lambda is not None:
105 | trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
106 | if self.pos_emb is not None:
107 | trunc_normal_(self.pos_emb, std=.02)
108 |
109 | def forward(self, x):
110 | B, C, H, W = x.shape
111 | M = H * W
112 | qkv = self.qkv(x)
113 | q, k, v = torch.split(qkv, [
114 | self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
115 | q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K
116 | v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
117 | k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M
118 |
119 | content_lam = k @ v # B, K, V
120 | content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
121 |
122 | if self.pos_emb is None:
123 | position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
124 | position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
125 | else:
126 | # FIXME relative pos embedding path not fully verified
127 | pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
128 | position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V
129 | position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V
130 |
131 | out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W
132 | out = self.pool(out)
133 | return out
134 |
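135 |
136 | if __name__ == '__main__':
137 |     # Illustrative sketch (an addition, not part of the original file): with the
138 |     # defaults (num_heads=4, dim_head=16, r=9), the qkv conv emits num_heads * dim_head
139 |     # query channels, a single shared dim_head-wide key, and dim_out // num_heads value
140 |     # channels; output channels equal dim when dim_out is unset, stride=2 adds avg pool.
141 |     x = torch.randn(2, 128, 14, 14)
142 |     assert LambdaLayer(dim=128)(x).shape == x.shape
143 |     assert LambdaLayer(dim=128, stride=2)(x).shape == (2, 128, 7, 7)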
--------------------------------------------------------------------------------
/models/layers/linear.py:
--------------------------------------------------------------------------------
1 | """ Linear layer (alternate definition)
2 | """
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn as nn
6 |
7 |
8 | class Linear(nn.Linear):
9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
10 |
11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
13 | """
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | if torch.jit.is_scripting():
16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
18 | else:
19 | return F.linear(input, self.weight, self.bias)
20 |
--------------------------------------------------------------------------------
/models/layers/median_pool.py:
--------------------------------------------------------------------------------
1 | """ Median Pool
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .helpers import to_2tuple, to_4tuple
7 |
8 |
9 | class MedianPool2d(nn.Module):
10 | """ Median pool (usable as median filter when stride=1) module.
11 |
12 | Args:
13 | kernel_size: size of pooling kernel, int or 2-tuple
14 | stride: pool stride, int or 2-tuple
15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
16 | same: override padding and enforce same padding, boolean
17 | """
18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
19 | super(MedianPool2d, self).__init__()
20 | self.k = to_2tuple(kernel_size)
21 | self.stride = to_2tuple(stride)
22 | self.padding = to_4tuple(padding) # convert to l, r, t, b
23 | self.same = same
24 |
25 | def _padding(self, x):
26 | if self.same:
27 | ih, iw = x.size()[2:]
28 | if ih % self.stride[0] == 0:
29 | ph = max(self.k[0] - self.stride[0], 0)
30 | else:
31 | ph = max(self.k[0] - (ih % self.stride[0]), 0)
32 | if iw % self.stride[1] == 0:
33 | pw = max(self.k[1] - self.stride[1], 0)
34 | else:
35 | pw = max(self.k[1] - (iw % self.stride[1]), 0)
36 | pl = pw // 2
37 | pr = pw - pl
38 | pt = ph // 2
39 | pb = ph - pt
40 | padding = (pl, pr, pt, pb)
41 | else:
42 | padding = self.padding
43 | return padding
44 |
45 | def forward(self, x):
46 | x = F.pad(x, self._padding(x), mode='reflect')
47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
49 | return x
50 |
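51 |
52 | if __name__ == '__main__':
53 |     # Illustrative sketch (an addition, not part of the original file): with stride=1
54 |     # and same=True the module is a shape-preserving median filter, which suppresses
55 |     # isolated (salt-and-pepper style) outliers.
56 |     import torch
57 |     x = torch.zeros(1, 1, 5, 5)
58 |     x[0, 0, 2, 2] = 100.  # isolated spike
59 |     mp = MedianPool2d(kernel_size=3, stride=1, same=True)
60 |     y = mp(x)
61 |     assert y.shape == x.shape and y[0, 0, 2, 2] == 0  # spike removed by the median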
--------------------------------------------------------------------------------
/models/layers/mixed_conv2d.py:
--------------------------------------------------------------------------------
1 | """ PyTorch Mixed Convolution
2 |
3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 |
8 | import torch
9 | from torch import nn as nn
10 |
11 | from .conv2d_same import create_conv2d_pad
12 |
13 |
14 | def _split_channels(num_chan, num_groups):
15 | split = [num_chan // num_groups for _ in range(num_groups)]
16 | split[0] += num_chan - sum(split)
17 | return split
18 |
19 |
20 | class MixedConv2d(nn.ModuleDict):
21 | """ Mixed Grouped Convolution
22 |
23 | Based on MDConv and GroupedConv in MixNet impl:
24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
25 | """
26 | def __init__(self, in_channels, out_channels, kernel_size=3,
27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs):
28 | super(MixedConv2d, self).__init__()
29 |
30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
31 | num_groups = len(kernel_size)
32 | in_splits = _split_channels(in_channels, num_groups)
33 | out_splits = _split_channels(out_channels, num_groups)
34 | self.in_channels = sum(in_splits)
35 | self.out_channels = sum(out_splits)
36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
37 | conv_groups = in_ch if depthwise else 1
38 | # use add_module to keep key space clean
39 | self.add_module(
40 | str(idx),
41 | create_conv2d_pad(
42 | in_ch, out_ch, k, stride=stride,
43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
44 | )
45 | self.splits = in_splits
46 |
47 | def forward(self, x):
48 | x_split = torch.split(x, self.splits, 1)
49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
50 | x = torch.cat(x_out, 1)
51 | return x
52 |
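53 |
54 | if __name__ == '__main__':
55 |     # Illustrative sketch (an addition, not part of the original file): channels are
56 |     # split into one group per kernel size, with any remainder folded into the first
57 |     # group, and each group gets its own (optionally depthwise) conv.
58 |     assert _split_channels(48, 3) == [16, 16, 16]
59 |     assert _split_channels(50, 3) == [18, 16, 16]
60 |     m = MixedConv2d(48, 48, kernel_size=[3, 5, 7], depthwise=True)
61 |     x = torch.randn(2, 48, 14, 14)
62 |     assert m(x).shape == x.shape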
--------------------------------------------------------------------------------
/models/layers/mlp.py:
--------------------------------------------------------------------------------
1 | """ MLP module w/ dropout and configurable activation layer
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from torch import nn as nn
6 |
7 | from .helpers import to_2tuple
8 |
9 |
10 | class Mlp(nn.Module):
11 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
12 | """
13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
14 | super().__init__()
15 | out_features = out_features or in_features
16 | hidden_features = hidden_features or in_features
17 | drop_probs = to_2tuple(drop)
18 |
19 | self.fc1 = nn.Linear(in_features, hidden_features)
20 | self.act = act_layer()
21 | self.drop1 = nn.Dropout(drop_probs[0])
22 | self.fc2 = nn.Linear(hidden_features, out_features)
23 | self.drop2 = nn.Dropout(drop_probs[1])
24 |
25 | def forward(self, x):
26 | x = self.fc1(x)
27 | x = self.act(x)
28 | x = self.drop1(x)
29 | x = self.fc2(x)
30 | x = self.drop2(x)
31 | return x
32 |
33 |
34 | class GluMlp(nn.Module):
35 | """ MLP w/ GLU style gating
36 | See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
37 | """
38 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
39 | super().__init__()
40 | out_features = out_features or in_features
41 | hidden_features = hidden_features or in_features
42 | assert hidden_features % 2 == 0
43 | drop_probs = to_2tuple(drop)
44 |
45 | self.fc1 = nn.Linear(in_features, hidden_features)
46 | self.act = act_layer()
47 | self.drop1 = nn.Dropout(drop_probs[0])
48 | self.fc2 = nn.Linear(hidden_features // 2, out_features)
49 | self.drop2 = nn.Dropout(drop_probs[1])
50 |
51 | def init_weights(self):
52 | # override init of fc1 w/ gate portion set to weight near zero, bias=1
53 | fc1_mid = self.fc1.bias.shape[0] // 2
54 | nn.init.ones_(self.fc1.bias[fc1_mid:])
55 | nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
56 |
57 | def forward(self, x):
58 | x = self.fc1(x)
59 | x, gates = x.chunk(2, dim=-1)
60 | x = x * self.act(gates)
61 | x = self.drop1(x)
62 | x = self.fc2(x)
63 | x = self.drop2(x)
64 | return x
65 |
66 |
67 | class GatedMlp(nn.Module):
68 | """ MLP as used in gMLP
69 | """
70 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
71 | gate_layer=None, drop=0.):
72 | super().__init__()
73 | out_features = out_features or in_features
74 | hidden_features = hidden_features or in_features
75 | drop_probs = to_2tuple(drop)
76 |
77 | self.fc1 = nn.Linear(in_features, hidden_features)
78 | self.act = act_layer()
79 | self.drop1 = nn.Dropout(drop_probs[0])
80 | if gate_layer is not None:
81 | assert hidden_features % 2 == 0
82 | self.gate = gate_layer(hidden_features)
83 | hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
84 | else:
85 | self.gate = nn.Identity()
86 | self.fc2 = nn.Linear(hidden_features, out_features)
87 | self.drop2 = nn.Dropout(drop_probs[1])
88 |
89 | def forward(self, x):
90 | x = self.fc1(x)
91 | x = self.act(x)
92 | x = self.drop1(x)
93 | x = self.gate(x)
94 | x = self.fc2(x)
95 | x = self.drop2(x)
96 | return x
97 |
98 |
99 | class ConvMlp(nn.Module):
100 | """ MLP using 1x1 convs that keeps spatial dims
101 | """
102 | def __init__(
103 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
104 | super().__init__()
105 | out_features = out_features or in_features
106 | hidden_features = hidden_features or in_features
107 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
108 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
109 | self.act = act_layer()
110 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
111 | self.drop = nn.Dropout(drop)
112 |
113 | def forward(self, x):
114 | x = self.fc1(x)
115 | x = self.norm(x)
116 | x = self.act(x)
117 | x = self.drop(x)
118 | x = self.fc2(x)
119 | return x
120 |
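121 |
122 | if __name__ == '__main__':
123 |     # Illustrative sketch (an addition, not part of the original file): GluMlp splits
124 |     # the fc1 output in half and uses one half to gate the other, so fc2 consumes
125 |     # hidden_features // 2 channels.
126 |     import torch
127 |     glu = GluMlp(in_features=64, hidden_features=128)
128 |     assert glu.fc2.in_features == 64  # = hidden_features // 2
129 |     x = torch.randn(2, 16, 64)
130 |     assert glu(x).shape == (2, 16, 64)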
--------------------------------------------------------------------------------
/models/layers/non_local_attn.py:
--------------------------------------------------------------------------------
1 | """ Bilinear-Attention-Transform and Non-Local Attention
2 |
3 | Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
4 | - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
5 | Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
6 | """
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from .conv_bn_act import ConvNormAct
12 | from .helpers import make_divisible
13 | from .trace_utils import _assert
14 |
15 |
16 | class NonLocalAttn(nn.Module):
17 | """Spatial NL block for image classification.
18 |
19 | This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
20 |     Their NonLocal impl was inspired by https://github.com/facebookresearch/video-nonlocal-net.
21 | """
22 |
23 | def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
24 | super(NonLocalAttn, self).__init__()
25 | if rd_channels is None:
26 | rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
27 | self.scale = in_channels ** -0.5 if use_scale else 1.0
28 | self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
29 | self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
30 | self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
31 | self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True)
32 | self.norm = nn.BatchNorm2d(in_channels)
33 | self.reset_parameters()
34 |
35 | def forward(self, x):
36 | shortcut = x
37 |
38 | t = self.t(x)
39 | p = self.p(x)
40 | g = self.g(x)
41 |
42 | B, C, H, W = t.size()
43 | t = t.view(B, C, -1).permute(0, 2, 1)
44 | p = p.view(B, C, -1)
45 | g = g.view(B, C, -1).permute(0, 2, 1)
46 |
47 | att = torch.bmm(t, p) * self.scale
48 | att = F.softmax(att, dim=2)
49 | x = torch.bmm(att, g)
50 |
51 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
52 | x = self.z(x)
53 | x = self.norm(x) + shortcut
54 |
55 | return x
56 |
57 | def reset_parameters(self):
58 | for name, m in self.named_modules():
59 | if isinstance(m, nn.Conv2d):
60 | nn.init.kaiming_normal_(
61 | m.weight, mode='fan_out', nonlinearity='relu')
62 | if len(list(m.parameters())) > 1:
63 | nn.init.constant_(m.bias, 0.0)
64 | elif isinstance(m, nn.BatchNorm2d):
65 | nn.init.constant_(m.weight, 0)
66 | nn.init.constant_(m.bias, 0)
67 | elif isinstance(m, nn.GroupNorm):
68 | nn.init.constant_(m.weight, 0)
69 | nn.init.constant_(m.bias, 0)
70 |
71 |
72 | class BilinearAttnTransform(nn.Module):
73 |
74 | def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
75 | super(BilinearAttnTransform, self).__init__()
76 |
77 | self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
78 | self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
79 | self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
80 | self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
81 | self.block_size = block_size
82 | self.groups = groups
83 | self.in_channels = in_channels
84 |
85 | def resize_mat(self, x, t: int):
86 | B, C, block_size, block_size1 = x.shape
87 | _assert(block_size == block_size1, '')
88 | if t <= 1:
89 | return x
90 | x = x.view(B * C, -1, 1, 1)
91 | x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
92 | x = x.view(B * C, block_size, block_size, t, t)
93 | x = torch.cat(torch.split(x, 1, dim=1), dim=3)
94 | x = torch.cat(torch.split(x, 1, dim=2), dim=4)
95 | x = x.view(B, C, block_size * t, block_size * t)
96 | return x
97 |
98 | def forward(self, x):
99 | _assert(x.shape[-1] % self.block_size == 0, '')
100 | _assert(x.shape[-2] % self.block_size == 0, '')
101 | B, C, H, W = x.shape
102 | out = self.conv1(x)
103 | rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
104 | cp = F.adaptive_max_pool2d(out, (1, self.block_size))
105 | p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
106 | q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
107 | p = p / p.sum(dim=3, keepdim=True)
108 | q = q / q.sum(dim=2, keepdim=True)
109 | p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
110 | 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
111 | p = p.view(B, C, self.block_size, self.block_size)
112 | q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
113 | 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
114 | q = q.view(B, C, self.block_size, self.block_size)
115 | p = self.resize_mat(p, H // self.block_size)
116 | q = self.resize_mat(q, W // self.block_size)
117 | y = p.matmul(x)
118 | y = y.matmul(q)
119 |
120 | y = self.conv2(y)
121 | return y
122 |
123 |
124 | class BatNonLocalAttn(nn.Module):
125 | """ BAT
126 | Adapted from: https://github.com/BA-Transform/BAT-Image-Classification
127 | """
128 |
129 | def __init__(
130 | self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
131 | drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
132 | super().__init__()
133 | if rd_channels is None:
134 | rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
135 | self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
136 | self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer)
137 | self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
138 | self.dropout = nn.Dropout2d(p=drop_rate)
139 |
140 | def forward(self, x):
141 | xl = self.conv1(x)
142 | y = self.ba(xl)
143 | y = self.conv2(y)
144 | y = self.dropout(y)
145 | return y + x
146 |
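147 |
148 | if __name__ == '__main__':
149 |     # Illustrative sketch (an addition, not part of the original file): both blocks are
150 |     # residual, so the output shape matches the input; BAT additionally requires H and W
151 |     # to be divisible by block_size.
152 |     x = torch.randn(2, 64, 14, 14)
153 |     assert NonLocalAttn(64)(x).shape == x.shape
154 |     assert BatNonLocalAttn(64, block_size=7)(x).shape == x.shape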
--------------------------------------------------------------------------------
/models/layers/norm.py:
--------------------------------------------------------------------------------
1 | """ Normalization layers and wrappers
2 | """
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class GroupNorm(nn.GroupNorm):
9 | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
10 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
11 | super().__init__(num_groups, num_channels, eps=eps, affine=affine)
12 |
13 | def forward(self, x):
14 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
15 |
16 |
17 | class LayerNorm2d(nn.LayerNorm):
18 | """ LayerNorm for channels of '2D' spatial BCHW tensors """
19 | def __init__(self, num_channels):
20 | super().__init__(num_channels)
21 |
22 | def forward(self, x: torch.Tensor) -> torch.Tensor:
23 | return F.layer_norm(
24 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
25 |
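26 |
27 | if __name__ == '__main__':
28 |     # Illustrative sketch (an addition, not part of the original file): LayerNorm2d
29 |     # normalizes the channel dim of a BCHW tensor by permuting to BHWC, applying a
30 |     # standard LayerNorm, and permuting back.
31 |     ln = LayerNorm2d(32)
32 |     y = ln(torch.randn(2, 32, 8, 8))
33 |     assert torch.allclose(y.mean(dim=1), torch.zeros(2, 8, 8), atol=1e-5)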
--------------------------------------------------------------------------------
/models/layers/norm_act.py:
--------------------------------------------------------------------------------
1 | """ Normalization + Activation Layers
2 | """
3 | from typing import Union, List
4 |
5 | import torch
6 | from torch import nn as nn
7 | from torch.nn import functional as F
8 |
9 | from .trace_utils import _assert
10 | from .create_act import get_act_layer
11 |
12 |
13 | class BatchNormAct2d(nn.BatchNorm2d):
14 | """BatchNorm + Activation
15 |
16 | This module performs BatchNorm + Activation in a manner that will remain backwards
17 | compatible with weights trained with separate bn, act. This is why we inherit from BN
18 | instead of composing it as a .bn member.
19 | """
20 | def __init__(
21 | self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
22 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
23 | super(BatchNormAct2d, self).__init__(
24 | num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
25 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
26 | act_layer = get_act_layer(act_layer) # string -> nn.Module
27 | if act_layer is not None and apply_act:
28 | act_args = dict(inplace=True) if inplace else {}
29 | self.act = act_layer(**act_args)
30 | else:
31 | self.act = nn.Identity()
32 |
33 | def forward(self, x):
34 | # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
35 | _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')
36 |
37 | # exponential_average_factor is set to self.momentum
38 | # (when it is available) only so that it gets updated
39 | # in ONNX graph when this node is exported to ONNX.
40 | if self.momentum is None:
41 | exponential_average_factor = 0.0
42 | else:
43 | exponential_average_factor = self.momentum
44 |
45 | if self.training and self.track_running_stats:
46 | # TODO: if statement only here to tell the jit to skip emitting this when it is None
47 | if self.num_batches_tracked is not None: # type: ignore[has-type]
48 | self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type]
49 | if self.momentum is None: # use cumulative moving average
50 | exponential_average_factor = 1.0 / float(self.num_batches_tracked)
51 | else: # use exponential moving average
52 | exponential_average_factor = self.momentum
53 |
54 | r"""
55 | Decide whether the mini-batch stats should be used for normalization rather than the buffers.
56 | Mini-batch stats are used in training mode, and in eval mode when buffers are None.
57 | """
58 | if self.training:
59 | bn_training = True
60 | else:
61 | bn_training = (self.running_mean is None) and (self.running_var is None)
62 |
63 | r"""
64 | Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
65 | passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
66 | used for normalization (i.e. in eval mode when buffers are not None).
67 | """
68 | x = F.batch_norm(
69 | x,
70 | # If buffers are not to be tracked, ensure that they won't be updated
71 | self.running_mean if not self.training or self.track_running_stats else None,
72 | self.running_var if not self.training or self.track_running_stats else None,
73 | self.weight,
74 | self.bias,
75 | bn_training,
76 | exponential_average_factor,
77 | self.eps,
78 | )
79 | x = self.drop(x)
80 | x = self.act(x)
81 | return x
82 |
83 |
84 | def _num_groups(num_channels, num_groups, group_size):
85 | if group_size:
86 | assert num_channels % group_size == 0
87 | return num_channels // group_size
88 | return num_groups
89 |
90 |
91 | class GroupNormAct(nn.GroupNorm):
92 | # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
93 | def __init__(
94 | self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
95 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
96 | super(GroupNormAct, self).__init__(
97 | _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
98 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
99 | act_layer = get_act_layer(act_layer) # string -> nn.Module
100 | if act_layer is not None and apply_act:
101 | act_args = dict(inplace=True) if inplace else {}
102 | self.act = act_layer(**act_args)
103 | else:
104 | self.act = nn.Identity()
105 |
106 | def forward(self, x):
107 | x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
108 | x = self.drop(x)
109 | x = self.act(x)
110 | return x
111 |
112 |
113 | class LayerNormAct(nn.LayerNorm):
114 | def __init__(
115 | self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
116 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
117 | super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
118 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
119 | act_layer = get_act_layer(act_layer) # string -> nn.Module
120 | if act_layer is not None and apply_act:
121 | act_args = dict(inplace=True) if inplace else {}
122 | self.act = act_layer(**act_args)
123 | else:
124 | self.act = nn.Identity()
125 |
126 | def forward(self, x):
127 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
128 | x = self.drop(x)
129 | x = self.act(x)
130 | return x
131 |
132 |
133 | class LayerNormAct2d(nn.LayerNorm):
134 | def __init__(
135 | self, num_channels, eps=1e-5, affine=True,
136 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
137 | super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
138 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
139 | act_layer = get_act_layer(act_layer) # string -> nn.Module
140 | if act_layer is not None and apply_act:
141 | act_args = dict(inplace=True) if inplace else {}
142 | self.act = act_layer(**act_args)
143 | else:
144 | self.act = nn.Identity()
145 |
146 | def forward(self, x):
147 | x = F.layer_norm(
148 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
149 | x = self.drop(x)
150 | x = self.act(x)
151 | return x
152 |
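153 |
154 | if __name__ == '__main__':
155 |     # Illustrative sketch (an addition, not part of the original file): because
156 |     # BatchNormAct2d inherits from nn.BatchNorm2d instead of wrapping it, its
157 |     # state_dict keys match plain BN, so weights trained w/ separate bn + act load unchanged.
158 |     bn_act = BatchNormAct2d(32, act_layer=nn.ReLU)
159 |     assert set(bn_act.state_dict()) == set(nn.BatchNorm2d(32).state_dict())
160 |     x = torch.randn(2, 32, 8, 8)
161 |     assert torch.all(bn_act(x) >= 0)  # activation applied after the normalization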
--------------------------------------------------------------------------------
/models/layers/padding.py:
--------------------------------------------------------------------------------
1 | """ Padding Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import math
6 | from typing import List, Tuple
7 |
8 | import torch.nn.functional as F
9 |
10 |
11 | # Calculate symmetric padding for a convolution
12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
14 | return padding
15 |
16 |
17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
18 | def get_same_padding(x: int, k: int, s: int, d: int):
19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
20 |
21 |
22 | # Can SAME padding for given args be done statically?
23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
25 |
26 |
27 | # Dynamically pad input x with 'SAME' padding for conv with specified args
28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
29 | ih, iw = x.size()[-2:]
30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
31 | if pad_h > 0 or pad_w > 0:
32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
33 | return x
34 |
35 |
36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
37 | dynamic = False
38 | if isinstance(padding, str):
39 | # for any string padding, the padding will be calculated for you, one of three ways
40 | padding = padding.lower()
41 | if padding == 'same':
42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
43 | if is_static_pad(kernel_size, **kwargs):
44 | # static case, no extra overhead
45 | padding = get_padding(kernel_size, **kwargs)
46 | else:
47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead
48 | padding = 0
49 | dynamic = True
50 | elif padding == 'valid':
51 | # 'VALID' padding, same as padding=0
52 | padding = 0
53 | else:
54 | # Default to PyTorch style 'same'-ish symmetric padding
55 | padding = get_padding(kernel_size, **kwargs)
56 | return padding, dynamic
57 |
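58 |
59 | if __name__ == '__main__':
60 |     # Worked example (an addition, not part of the original file): TF 'SAME' needs
61 |     # max((ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) total padding, which is only
62 |     # static (input-independent) when stride == 1.
63 |     assert get_same_padding(x=7, k=3, s=2, d=1) == 2  # ceil(7/2)=4 -> 3*2 + 2 + 1 - 7
64 |     assert get_padding_value('same', 3, stride=1) == (1, False)  # static symmetric pad
65 |     assert get_padding_value('same', 3, stride=2) == (0, True)   # dynamic pad at runtime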
--------------------------------------------------------------------------------
/models/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | """ Image to Patch Embedding using Conv2d
2 |
3 | A convolution based approach to patchifying a 2D image w/ embedding projection.
4 |
5 | Based on the impl in https://github.com/google-research/vision_transformer
6 |
7 | Hacked together by / Copyright 2020 Ross Wightman
8 | """
9 | from torch import nn as nn
10 |
11 | from .helpers import to_2tuple
12 | from .trace_utils import _assert
13 |
14 |
15 | class PatchEmbed(nn.Module):
16 | """ 2D Image to Patch Embedding
17 | """
18 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
19 | super().__init__()
20 | img_size = to_2tuple(img_size)
21 | patch_size = to_2tuple(patch_size)
22 | self.img_size = img_size
23 | self.patch_size = patch_size
24 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
25 | self.num_patches = self.grid_size[0] * self.grid_size[1]
26 | self.flatten = flatten
27 |
28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
29 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
30 |
31 | def forward(self, x):
32 | B, C, H, W = x.shape
33 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
34 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
35 | x = self.proj(x)
36 | if self.flatten:
37 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
38 | x = self.norm(x)
39 | return x
40 |
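41 |
42 | if __name__ == '__main__':
43 |     # Illustrative sketch (an addition, not part of the original file): a 224x224 image
44 |     # with 16x16 patches gives a 14x14 grid, i.e. 196 tokens of width embed_dim.
45 |     import torch
46 |     pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
47 |     assert pe.num_patches == 196
48 |     assert pe(torch.randn(2, 3, 224, 224)).shape == (2, 196, 768)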
--------------------------------------------------------------------------------
/models/layers/pool2d_same.py:
--------------------------------------------------------------------------------
1 | """ AvgPool2d & MaxPool2d w/ 'SAME' Padding
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import List, Tuple, Optional
9 |
10 | from .helpers import to_2tuple
11 | from .padding import pad_same, get_padding_value
12 |
13 |
14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
15 | ceil_mode: bool = False, count_include_pad: bool = True):
16 | # FIXME how to deal with count_include_pad vs not for external padding?
17 | x = pad_same(x, kernel_size, stride)
18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
19 |
20 |
21 | class AvgPool2dSame(nn.AvgPool2d):
22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling
23 | """
24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
25 | kernel_size = to_2tuple(kernel_size)
26 | stride = to_2tuple(stride)
27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
28 |
29 | def forward(self, x):
30 | x = pad_same(x, self.kernel_size, self.stride)
31 | return F.avg_pool2d(
32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
33 |
34 |
35 | def max_pool2d_same(
36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
37 | dilation: List[int] = (1, 1), ceil_mode: bool = False):
38 | x = pad_same(x, kernel_size, stride, value=-float('inf'))
39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
40 |
41 |
42 | class MaxPool2dSame(nn.MaxPool2d):
43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling
44 | """
45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
46 | kernel_size = to_2tuple(kernel_size)
47 | stride = to_2tuple(stride)
48 | dilation = to_2tuple(dilation)
49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
50 |
51 | def forward(self, x):
52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
54 |
55 |
56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
57 | stride = stride or kernel_size
58 | padding = kwargs.pop('padding', '')
59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
60 | if is_dynamic:
61 | if pool_type == 'avg':
62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
63 | elif pool_type == 'max':
64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
65 | else:
66 | assert False, f'Unsupported pool type {pool_type}'
67 | else:
68 | if pool_type == 'avg':
69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
70 | elif pool_type == 'max':
71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
72 | else:
73 | assert False, f'Unsupported pool type {pool_type}'
74 |
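75 |
76 | if __name__ == '__main__':
77 |     # Illustrative sketch (an addition, not part of the original file): padding='same'
78 |     # with stride > 1 cannot be padded statically, so create_pool2d returns the dynamic
79 |     # SAME wrappers and the output size is ceil(input / stride).
80 |     pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
81 |     assert isinstance(pool, MaxPool2dSame)
82 |     x = torch.randn(1, 8, 7, 7)
83 |     assert pool(x).shape == (1, 8, 4, 4)  # ceil(7 / 2) = 4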
--------------------------------------------------------------------------------
/models/layers/selective_kernel.py:
--------------------------------------------------------------------------------
1 | """ Selective Kernel Convolution/Attention
2 |
3 | Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 | import torch
8 | from torch import nn as nn
9 |
10 | from .conv_bn_act import ConvNormActAa
11 | from .helpers import make_divisible
12 | from .trace_utils import _assert
13 |
14 |
15 | def _kernel_valid(k):
16 |     if isinstance(k, (list, tuple)):
17 |         for ki in k:  # validate every branch kernel size
18 |             _kernel_valid(ki)
19 |     assert isinstance(k, (list, tuple)) or (k >= 3 and k % 2)
20 |
21 |
22 | class SelectiveKernelAttn(nn.Module):
23 | def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
24 | """ Selective Kernel Attention Module
25 |
26 | Selective Kernel attention mechanism factored out into its own module.
27 |
28 | """
29 | super(SelectiveKernelAttn, self).__init__()
30 | self.num_paths = num_paths
31 | self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
32 | self.bn = norm_layer(attn_channels)
33 | self.act = act_layer(inplace=True)
34 | self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
35 |
36 | def forward(self, x):
37 | _assert(x.shape[1] == self.num_paths, '')
38 | x = x.sum(1).mean((2, 3), keepdim=True)
39 | x = self.fc_reduce(x)
40 | x = self.bn(x)
41 | x = self.act(x)
42 | x = self.fc_select(x)
43 | B, C, H, W = x.shape
44 | x = x.view(B, self.num_paths, C // self.num_paths, H, W)
45 | x = torch.softmax(x, dim=1)
46 | return x
47 |
48 |
49 | class SelectiveKernel(nn.Module):
50 |
51 | def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
52 | rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
53 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None):
54 | """ Selective Kernel Convolution Module
55 |
56 | As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
57 |
58 | Largest change is the input split, which divides the input channels across each convolution path, this can
59 | be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps
60 | the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
61 | a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
62 |
63 | Args:
64 | in_channels (int): module input (feature) channel count
65 | out_channels (int): module output (feature) channel count
66 | kernel_size (int, list): kernel size for each convolution branch
67 | stride (int): stride for convolutions
68 | dilation (int): dilation for module as a whole, impacts dilation of each branch
69 | groups (int): number of groups for each branch
70 | rd_ratio (int, float): reduction factor for attention features
71 | keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
72 | split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
73 | can be viewed as grouping by path, output expands to module out_channels count
74 | act_layer (nn.Module): activation layer to use
75 | norm_layer (nn.Module): batchnorm/norm layer to use
76 | aa_layer (nn.Module): anti-aliasing module
77 | drop_layer (nn.Module): spatial drop module in convs (drop block, etc)
78 | """
79 | super(SelectiveKernel, self).__init__()
80 | out_channels = out_channels or in_channels
81 | kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
82 | _kernel_valid(kernel_size)
83 |         if not isinstance(kernel_size, (list, tuple)):
84 |             kernel_size = [kernel_size] * 2
85 | if keep_3x3:
86 | dilation = [dilation * (k - 1) // 2 for k in kernel_size]
87 | kernel_size = [3] * len(kernel_size)
88 | else:
89 | dilation = [dilation] * len(kernel_size)
90 | self.num_paths = len(kernel_size)
91 | self.in_channels = in_channels
92 | self.out_channels = out_channels
93 | self.split_input = split_input
94 | if self.split_input:
95 | assert in_channels % self.num_paths == 0
96 | in_channels = in_channels // self.num_paths
97 | groups = min(out_channels, groups)
98 |
99 | conv_kwargs = dict(
100 | stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
101 | aa_layer=aa_layer, drop_layer=drop_layer)
102 | self.paths = nn.ModuleList([
103 | ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
104 | for k, d in zip(kernel_size, dilation)])
105 |
106 | attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
107 | self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
108 |
109 | def forward(self, x):
110 | if self.split_input:
111 | x_split = torch.split(x, self.in_channels // self.num_paths, 1)
112 | x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
113 | else:
114 | x_paths = [op(x) for op in self.paths]
115 | x = torch.stack(x_paths, dim=1)
116 | x_attn = self.attn(x)
117 | x = x * x_attn
118 | x = torch.sum(x, dim=1)
119 | return x
120 |
--------------------------------------------------------------------------------
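
A minimal shape check of the selective kernel module above, assuming the package is importable as `models.layers` per the repo layout (channel counts here are illustrative). The path attention is a softmax over the path dimension, so the branch outputs are convexly blended per channel and position:

```python
# Minimal sketch: SelectiveKernel shapes and path-attention normalization (illustrative sizes).
import torch
from models.layers.selective_kernel import SelectiveKernel

sk = SelectiveKernel(in_channels=64, out_channels=64)  # defaults: two 3x3 paths, second dilated (from 5x5)
x = torch.randn(2, 64, 32, 32)
print(sk(x).shape)  # torch.Size([2, 64, 32, 32])

# Path attention sums to 1 across the two branches for every channel/pixel:
with torch.no_grad():
    paths = torch.stack(
        [op(xi) for op, xi in zip(sk.paths, torch.split(x, 32, 1))],  # split_input: 32 chs per path
        dim=1)  # (B, num_paths, C, H, W)
    attn = sk.attn(paths)
    print(torch.allclose(attn.sum(dim=1), torch.ones_like(attn.sum(dim=1))))  # True
```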
/models/layers/separable_conv.py:
--------------------------------------------------------------------------------
1 | """ Depthwise Separable Conv Modules
2 |
3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
4 | DW and PW convs, such as the depthwise modules in MobileNetV2 / EfficientNet and Xception.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | from torch import nn as nn
9 |
10 | from .create_conv2d import create_conv2d
11 | from .create_norm_act import get_norm_act_layer
12 |
13 |
14 | class SeparableConvNormAct(nn.Module):
15 | """ Separable Conv w/ trailing Norm and Activation
16 | """
17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
19 | apply_act=True, drop_layer=None):
20 | super(SeparableConvNormAct, self).__init__()
21 |
22 | self.conv_dw = create_conv2d(
23 | in_channels, int(in_channels * channel_multiplier), kernel_size,
24 | stride=stride, dilation=dilation, padding=padding, depthwise=True)
25 |
26 | self.conv_pw = create_conv2d(
27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
28 |
29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
32 |
33 | @property
34 | def in_channels(self):
35 | return self.conv_dw.in_channels
36 |
37 | @property
38 | def out_channels(self):
39 | return self.conv_pw.out_channels
40 |
41 | def forward(self, x):
42 | x = self.conv_dw(x)
43 | x = self.conv_pw(x)
44 | x = self.bn(x)
45 | return x
46 |
47 |
48 | SeparableConvBnAct = SeparableConvNormAct
49 |
50 |
51 | class SeparableConv2d(nn.Module):
52 | """ Separable Conv
53 | """
54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
55 | channel_multiplier=1.0, pw_kernel_size=1):
56 | super(SeparableConv2d, self).__init__()
57 |
58 | self.conv_dw = create_conv2d(
59 | in_channels, int(in_channels * channel_multiplier), kernel_size,
60 | stride=stride, dilation=dilation, padding=padding, depthwise=True)
61 |
62 | self.conv_pw = create_conv2d(
63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
64 |
65 | @property
66 | def in_channels(self):
67 | return self.conv_dw.in_channels
68 |
69 | @property
70 | def out_channels(self):
71 | return self.conv_pw.out_channels
72 |
73 | def forward(self, x):
74 | x = self.conv_dw(x)
75 | x = self.conv_pw(x)
76 | return x
77 |
--------------------------------------------------------------------------------
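
The point of the factorization above is parameter (and FLOP) savings: a dense k×k conv costs C_in · C_out · k² weights, while the depthwise + pointwise pair costs C_in · k² + C_in · C_out. A quick count, assuming the repo layout for imports:

```python
# Minimal sketch: parameter savings of SeparableConv2d vs. a dense 3x3 conv (illustrative sizes).
import torch
from torch import nn
from models.layers.separable_conv import SeparableConv2d

dense = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
sep = SeparableConv2d(64, 128, kernel_size=3)

print(sum(p.numel() for p in dense.parameters()))  # 128*64*3*3 = 73728
print(sum(p.numel() for p in sep.parameters()))    # 64*3*3 + 128*64 = 8768

x = torch.randn(1, 64, 32, 32)
print(sep(x).shape)  # torch.Size([1, 128, 32, 32])
```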
/models/layers/space_to_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SpaceToDepth(nn.Module):
6 | def __init__(self, block_size=4):
7 | super().__init__()
8 |         assert block_size == 4  # only a block size of 4 is supported (matches SpaceToDepthJit below)
9 | self.bs = block_size
10 |
11 | def forward(self, x):
12 | N, C, H, W = x.size()
13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
16 | return x
17 |
18 |
19 | @torch.jit.script
20 | class SpaceToDepthJit(object):
21 | def __call__(self, x: torch.Tensor):
22 | # assuming hard-coded that block_size==4 for acceleration
23 | N, C, H, W = x.size()
24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
27 | return x
28 |
29 |
30 | class SpaceToDepthModule(nn.Module):
31 | def __init__(self, no_jit=False):
32 | super().__init__()
33 | if not no_jit:
34 | self.op = SpaceToDepthJit()
35 | else:
36 | self.op = SpaceToDepth()
37 |
38 | def forward(self, x):
39 | return self.op(x)
40 |
41 |
42 | class DepthToSpace(nn.Module):
43 |
44 | def __init__(self, block_size):
45 | super().__init__()
46 | self.bs = block_size
47 |
48 | def forward(self, x):
49 | N, C, H, W = x.size()
50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
53 | return x
54 |
--------------------------------------------------------------------------------
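
A round-trip check, since the two permute orders above are written to match: `SpaceToDepth` flattens each 4x4 spatial block into the channel dimension in (block-row, block-col, channel) order, and `DepthToSpace` unflattens the same layout.

```python
# Minimal sketch: SpaceToDepth / DepthToSpace are exact inverses for matching block sizes.
import torch
from models.layers.space_to_depth import SpaceToDepth, DepthToSpace

x = torch.randn(2, 3, 224, 224)
y = SpaceToDepth(block_size=4)(x)
print(y.shape)                                        # torch.Size([2, 48, 56, 56]) - C*16, H/4, W/4
print(torch.equal(DepthToSpace(block_size=4)(y), x))  # True
```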
/models/layers/split_attn.py:
--------------------------------------------------------------------------------
1 | """ Split Attention Conv2d (for ResNeSt Models)
2 |
3 | Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
4 |
5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
6 |
7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
8 | """
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 |
13 | from .helpers import make_divisible
14 |
15 |
16 | class RadixSoftmax(nn.Module):
17 | def __init__(self, radix, cardinality):
18 | super(RadixSoftmax, self).__init__()
19 | self.radix = radix
20 | self.cardinality = cardinality
21 |
22 | def forward(self, x):
23 | batch = x.size(0)
24 | if self.radix > 1:
25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
26 | x = F.softmax(x, dim=1)
27 | x = x.reshape(batch, -1)
28 | else:
29 | x = torch.sigmoid(x)
30 | return x
31 |
32 |
33 | class SplitAttn(nn.Module):
34 | """Split-Attention (aka Splat)
35 | """
36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
39 | super(SplitAttn, self).__init__()
40 | out_channels = out_channels or in_channels
41 | self.radix = radix
42 | mid_chs = out_channels * radix
43 | if rd_channels is None:
44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
45 | else:
46 | attn_chs = rd_channels * radix
47 |
48 | padding = kernel_size // 2 if padding is None else padding
49 | self.conv = nn.Conv2d(
50 | in_channels, mid_chs, kernel_size, stride, padding, dilation,
51 | groups=groups * radix, bias=bias, **kwargs)
52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
54 | self.act0 = act_layer(inplace=True)
55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
57 | self.act1 = act_layer(inplace=True)
58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
59 | self.rsoftmax = RadixSoftmax(radix, groups)
60 |
61 | def forward(self, x):
62 | x = self.conv(x)
63 | x = self.bn0(x)
64 | x = self.drop(x)
65 | x = self.act0(x)
66 |
67 | B, RC, H, W = x.shape
68 | if self.radix > 1:
69 | x = x.reshape((B, self.radix, RC // self.radix, H, W))
70 | x_gap = x.sum(dim=1)
71 | else:
72 | x_gap = x
73 | x_gap = x_gap.mean((2, 3), keepdim=True)
74 | x_gap = self.fc1(x_gap)
75 | x_gap = self.bn1(x_gap)
76 | x_gap = self.act1(x_gap)
77 | x_attn = self.fc2(x_gap)
78 |
79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
80 | if self.radix > 1:
81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
82 | else:
83 | out = x * x_attn
84 | return out.contiguous()
85 |
--------------------------------------------------------------------------------
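
A shape sketch for the module above: with radix=2 the initial conv produces twice the output channels, the two splits are summed for the global pooling branch, and `RadixSoftmax` normalizes the attention logits across the radix dimension. Sizes here are illustrative.

```python
# Minimal sketch: SplitAttn output shape and RadixSoftmax normalization (illustrative sizes).
import torch
from torch import nn
from models.layers.split_attn import SplitAttn, RadixSoftmax

splat = SplitAttn(in_channels=64, out_channels=64, radix=2, norm_layer=nn.BatchNorm2d)
x = torch.randn(2, 64, 32, 32)
print(splat(x).shape)  # torch.Size([2, 64, 32, 32])

# Attention weights sum to 1 across the radix splits for every (group, channel):
rs = RadixSoftmax(radix=2, cardinality=1)
logits = torch.randn(2, 2 * 64, 1, 1)      # (B, radix * channels, 1, 1), as produced by fc2
w = rs(logits).view(2, 2, 64)              # (B, radix, channels)
print(torch.allclose(w.sum(dim=1), torch.ones(2, 64)))  # True
```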
/models/layers/split_batchnorm.py:
--------------------------------------------------------------------------------
1 | """ Split BatchNorm
2 |
3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias
5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
6 | namespace.
7 |
8 | This allows easily removing the auxiliary BN layers after training to efficiently
9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
10 | 'Disentangled Learning via An Auxiliary BN'
11 |
12 | Hacked together by / Copyright 2020 Ross Wightman
13 | """
14 | import torch
15 | import torch.nn as nn
16 |
17 |
18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d):
19 |
20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
21 | track_running_stats=True, num_splits=2):
22 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
24 | self.num_splits = num_splits
25 | self.aux_bn = nn.ModuleList([
26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
27 |
28 | def forward(self, input: torch.Tensor):
29 | if self.training: # aux BN only relevant while training
30 | split_size = input.shape[0] // self.num_splits
31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
32 | split_input = input.split(split_size)
33 | x = [super().forward(split_input[0])]
34 | for i, a in enumerate(self.aux_bn):
35 | x.append(a(split_input[i + 1]))
36 | return torch.cat(x, dim=0)
37 | else:
38 | return super().forward(input)
39 |
40 |
41 | def convert_splitbn_model(module, num_splits=2):
42 | """
43 | Recursively traverse module and its children to replace all instances of
44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.
45 | Args:
46 | module (torch.nn.Module): input module
47 | num_splits: number of separate batchnorm layers to split input across
48 | Example::
49 | >>> # model is an instance of torch.nn.Module
50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
51 | """
52 | mod = module
53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
54 | return module
55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
56 | mod = SplitBatchNorm2d(
57 | module.num_features, module.eps, module.momentum, module.affine,
58 | module.track_running_stats, num_splits=num_splits)
59 | mod.running_mean = module.running_mean
60 | mod.running_var = module.running_var
61 | mod.num_batches_tracked = module.num_batches_tracked
62 | if module.affine:
63 | mod.weight.data = module.weight.data.clone().detach()
64 | mod.bias.data = module.bias.data.clone().detach()
65 | for aux in mod.aux_bn:
66 | aux.running_mean = module.running_mean.clone()
67 | aux.running_var = module.running_var.clone()
68 | aux.num_batches_tracked = module.num_batches_tracked.clone()
69 | if module.affine:
70 | aux.weight.data = module.weight.data.clone().detach()
71 | aux.bias.data = module.bias.data.clone().detach()
72 | for name, child in module.named_children():
73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
74 | del module
75 | return mod
76 |
--------------------------------------------------------------------------------
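
A usage sketch for the conversion helper above, on a toy model. In training mode the batch must be divisible by `num_splits`; the first split is routed through the parent BN, the rest through the `.aux_bn` layers.

```python
# Minimal sketch: convert a toy model to SplitBatchNorm2d and run a split batch through it.
import torch
from torch import nn
from models.layers.split_batchnorm import convert_splitbn_model

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
model = convert_splitbn_model(model, num_splits=2)
print(type(model[1]).__name__)   # SplitBatchNorm2d
print(len(model[1].aux_bn))      # 1 auxiliary BN layer (num_splits - 1)

model.train()
x = torch.randn(8, 3, 16, 16)    # batch of 8 -> two splits of 4
print(model(x).shape)            # torch.Size([8, 8, 16, 16])
```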
/models/layers/squeeze_excite.py:
--------------------------------------------------------------------------------
1 | """ Squeeze-and-Excitation Channel Attention
2 |
3 | An SE implementation originally based on PyTorch SE-Net impl.
4 | Has since evolved with additional functionality / configuration.
5 |
6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507
7 |
8 | Also included is Effective Squeeze-Excitation (ESE).
9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
10 |
11 | Hacked together by / Copyright 2021 Ross Wightman
12 | """
13 | from torch import nn as nn
14 |
15 | from .create_act import create_act_layer
16 | from .helpers import make_divisible
17 |
18 |
19 | class SEModule(nn.Module):
20 | """ SE Module as defined in original SE-Nets with a few additions
21 | Additions include:
22 | * divisor can be specified to keep channels % div == 0 (default: 8)
23 | * reduction channels can be specified directly by arg (if rd_channels is set)
24 | * reduction channels can be specified by float rd_ratio (default: 1/16)
25 | * global max pooling can be added to the squeeze aggregation
26 | * customizable activation, normalization, and gate layer
27 | """
28 | def __init__(
29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
30 | act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
31 | super(SEModule, self).__init__()
32 | self.add_maxpool = add_maxpool
33 | if not rd_channels:
34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
37 | self.act = create_act_layer(act_layer, inplace=True)
38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
39 | self.gate = create_act_layer(gate_layer)
40 |
41 | def forward(self, x):
42 | x_se = x.mean((2, 3), keepdim=True)
43 | if self.add_maxpool:
44 | # experimental codepath, may remove or change
45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
46 | x_se = self.fc1(x_se)
47 | x_se = self.act(self.bn(x_se))
48 | x_se = self.fc2(x_se)
49 | return x * self.gate(x_se)
50 |
51 |
52 | SqueezeExcite = SEModule # alias
53 |
54 |
55 | class EffectiveSEModule(nn.Module):
56 |     """ Effective Squeeze-Excitation
57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
58 | """
59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
60 | super(EffectiveSEModule, self).__init__()
61 | self.add_maxpool = add_maxpool
62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
63 | self.gate = create_act_layer(gate_layer)
64 |
65 | def forward(self, x):
66 | x_se = x.mean((2, 3), keepdim=True)
67 | if self.add_maxpool:
68 | # experimental codepath, may remove or change
69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
70 | x_se = self.fc(x_se)
71 | return x * self.gate(x_se)
72 |
73 |
74 | EffectiveSqueezeExcite = EffectiveSEModule # alias
75 |
--------------------------------------------------------------------------------
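
A shape sketch for the modules above. Note how the SE bottleneck width comes from `rd_ratio` rounded by `make_divisible`, while the 'effective' variant skips the reduction entirely; sizes are illustrative.

```python
# Minimal sketch: SEModule bottleneck width and output shapes (illustrative sizes).
import torch
from models.layers.squeeze_excite import SEModule, EffectiveSEModule

se = SEModule(channels=256, rd_ratio=1. / 16)  # 256 / 16 = 16 reduction channels (divisible by 8)
print(se.fc1.out_channels)                     # 16
x = torch.randn(2, 256, 14, 14)
print(se(x).shape)                             # torch.Size([2, 256, 14, 14])

ese = EffectiveSEModule(channels=256)          # single 1x1 conv, no channel reduction
print(ese(x).shape)                            # torch.Size([2, 256, 14, 14])
```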
/models/layers/std_conv.py:
--------------------------------------------------------------------------------
1 | """ Convolution with Weight Standardization (StdConv and ScaledStdConv)
2 |
3 | StdConv:
4 | @article{weightstandardization,
5 | author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
6 | title = {Weight Standardization},
7 | journal = {arXiv preprint arXiv:1903.10520},
8 | year = {2019},
9 | }
10 | Code: https://github.com/joe-siyuan-qiao/WeightStandardization
11 |
12 | ScaledStdConv:
13 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
14 | - https://arxiv.org/abs/2101.08692
15 | Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
16 |
17 | Hacked together by / copyright Ross Wightman, 2021.
18 | """
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 |
23 | from .padding import get_padding, get_padding_value, pad_same
24 |
25 |
26 | class StdConv2d(nn.Conv2d):
27 | """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
28 |
29 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
30 | https://arxiv.org/abs/1903.10520v2
31 | """
32 | def __init__(
33 |             self, in_channels, out_channels, kernel_size, stride=1, padding=None,
34 | dilation=1, groups=1, bias=False, eps=1e-6):
35 | if padding is None:
36 | padding = get_padding(kernel_size, stride, dilation)
37 | super().__init__(
38 |             in_channels, out_channels, kernel_size, stride=stride,
39 | padding=padding, dilation=dilation, groups=groups, bias=bias)
40 | self.eps = eps
41 |
42 | def forward(self, x):
43 | weight = F.batch_norm(
44 | self.weight.reshape(1, self.out_channels, -1), None, None,
45 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
46 | x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
47 | return x
48 |
49 |
50 | class StdConv2dSame(nn.Conv2d):
51 | """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.
52 |
53 | Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
54 | https://arxiv.org/abs/1903.10520v2
55 | """
56 | def __init__(
57 |             self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
58 | dilation=1, groups=1, bias=False, eps=1e-6):
59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
60 | super().__init__(
61 |             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
62 | groups=groups, bias=bias)
63 | self.same_pad = is_dynamic
64 | self.eps = eps
65 |
66 | def forward(self, x):
67 | if self.same_pad:
68 | x = pad_same(x, self.kernel_size, self.stride, self.dilation)
69 | weight = F.batch_norm(
70 | self.weight.reshape(1, self.out_channels, -1), None, None,
71 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
72 | x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
73 | return x
74 |
75 |
76 | class ScaledStdConv2d(nn.Conv2d):
77 | """Conv2d layer with Scaled Weight Standardization.
78 |
79 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
80 | https://arxiv.org/abs/2101.08692
81 |
82 | NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
83 | """
84 |
85 | def __init__(
86 | self, in_channels, out_channels, kernel_size, stride=1, padding=None,
87 | dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
88 | if padding is None:
89 | padding = get_padding(kernel_size, stride, dilation)
90 | super().__init__(
91 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
92 | groups=groups, bias=bias)
93 | self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
94 | self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
95 | self.eps = eps
96 |
97 | def forward(self, x):
98 | weight = F.batch_norm(
99 | self.weight.reshape(1, self.out_channels, -1), None, None,
100 | weight=(self.gain * self.scale).view(-1),
101 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
102 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
103 |
104 |
105 | class ScaledStdConv2dSame(nn.Conv2d):
106 | """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
107 |
108 | Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
109 | https://arxiv.org/abs/2101.08692
110 |
111 | NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
112 | """
113 |
114 | def __init__(
115 | self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
116 | dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
117 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
118 | super().__init__(
119 | in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
120 | groups=groups, bias=bias)
121 | self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
122 | self.scale = gamma * self.weight[0].numel() ** -0.5
123 | self.same_pad = is_dynamic
124 | self.eps = eps
125 |
126 | def forward(self, x):
127 | if self.same_pad:
128 | x = pad_same(x, self.kernel_size, self.stride, self.dilation)
129 | weight = F.batch_norm(
130 | self.weight.reshape(1, self.out_channels, -1), None, None,
131 | weight=(self.gain * self.scale).view(-1),
132 | training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
133 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
134 |
--------------------------------------------------------------------------------
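
The `F.batch_norm` trick above standardizes each output filter over its fan-in (zero mean, unit variance) on every forward pass; `ScaledStdConv2d` additionally applies the learnable per-filter gain and the `gamma / sqrt(fan-in)` scale. A quick numerical check:

```python
# Minimal sketch: StdConv2d's effective weight has ~zero mean / unit variance per output filter.
import torch
import torch.nn.functional as F
from models.layers.std_conv import StdConv2d

conv = StdConv2d(16, 32, kernel_size=3)
x = torch.randn(2, 16, 8, 8)
print(conv(x).shape)  # torch.Size([2, 32, 8, 8])

# Reproduce the standardized weight computed in forward():
w = F.batch_norm(
    conv.weight.reshape(1, conv.out_channels, -1), None, None,
    training=True, momentum=0., eps=conv.eps).reshape_as(conv.weight)
print(w.mean(dim=(1, 2, 3)).abs().max() < 1e-5)   # True (zero mean per filter)
print(torch.allclose(w.var(dim=(1, 2, 3), unbiased=False),
                     torch.ones(32), atol=1e-3))  # True (unit variance per filter)
```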
/models/layers/test_time_pool.py:
--------------------------------------------------------------------------------
1 | """ Test Time Pooling (Average-Max Pool)
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | import logging
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
11 |
12 |
13 | _logger = logging.getLogger(__name__)
14 |
15 |
16 | class TestTimePoolHead(nn.Module):
17 | def __init__(self, base, original_pool=7):
18 | super(TestTimePoolHead, self).__init__()
19 | self.base = base
20 | self.original_pool = original_pool
21 | base_fc = self.base.get_classifier()
22 | if isinstance(base_fc, nn.Conv2d):
23 | self.fc = base_fc
24 | else:
25 | self.fc = nn.Conv2d(
26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
29 | self.base.reset_classifier(0) # delete original fc layer
30 |
31 | def forward(self, x):
32 | x = self.base.forward_features(x)
33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
34 | x = self.fc(x)
35 | x = adaptive_avgmax_pool2d(x, 1)
36 | return x.view(x.size(0), -1)
37 |
38 |
39 | def apply_test_time_pool(model, config, use_test_size=True):
40 | test_time_pool = False
41 | if not hasattr(model, 'default_cfg') or not model.default_cfg:
42 | return model, False
43 | if use_test_size and 'test_input_size' in model.default_cfg:
44 | df_input_size = model.default_cfg['test_input_size']
45 | else:
46 | df_input_size = model.default_cfg['input_size']
47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
49 | (str(config['input_size'][-2:]), str(df_input_size[-2:])))
50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
51 | test_time_pool = True
52 | return model, test_time_pool
53 |
--------------------------------------------------------------------------------
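
A usage sketch of the head above: evaluate at a resolution larger than the pretraining default, average logits over windows matching the original pool size, then avg-max pool the remaining positions. This assumes the repo's `models.factory.create_model` exposes a timm-style factory.

```python
# Minimal sketch: wrap a classifier with TestTimePoolHead for a larger test resolution.
import torch
from models.factory import create_model              # assumed timm-style factory per this repo
from models.layers.test_time_pool import TestTimePoolHead

model = create_model('resnet50')                     # default cfg: 224x224 input, 7x7 final pool
model = TestTimePoolHead(model, original_pool=7).eval()

x = torch.randn(1, 3, 288, 288)                      # larger-than-train input -> 9x9 feature map
with torch.no_grad():
    logits = model(x)                                # 7x7 avg windows, then avg-max over positions
print(logits.shape)                                  # torch.Size([1, 1000])
```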
/models/layers/trace_utils.py:
--------------------------------------------------------------------------------
1 | try:
2 | from torch import _assert
3 | except ImportError:
4 | def _assert(condition: bool, message: str):
5 | assert condition, message
6 |
7 |
8 | def _float_to_int(x: float) -> int:
9 | """
10 | Symbolic tracing helper to substitute for inbuilt `int`.
11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy`
12 | """
13 | return int(x)
14 |
--------------------------------------------------------------------------------
/models/layers/weight_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 |
5 | from torch.nn.init import _calculate_fan_in_and_fan_out
6 |
7 |
8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11 | def norm_cdf(x):
12 | # Computes standard normal cumulative distribution function
13 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
14 |
15 | if (mean < a - 2 * std) or (mean > b + 2 * std):
16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17 | "The distribution of values may be incorrect.",
18 | stacklevel=2)
19 |
20 | with torch.no_grad():
21 | # Values are generated by using a truncated uniform distribution and
22 | # then using the inverse CDF for the normal distribution.
23 | # Get upper and lower cdf values
24 | l = norm_cdf((a - mean) / std)
25 | u = norm_cdf((b - mean) / std)
26 |
27 | # Uniformly fill tensor with values from [l, u], then translate to
28 | # [2l-1, 2u-1].
29 | tensor.uniform_(2 * l - 1, 2 * u - 1)
30 |
31 | # Use inverse cdf transform for normal distribution to get truncated
32 | # standard normal
33 | tensor.erfinv_()
34 |
35 | # Transform to proper mean, std
36 | tensor.mul_(std * math.sqrt(2.))
37 | tensor.add_(mean)
38 |
39 | # Clamp to ensure it's in the proper range
40 | tensor.clamp_(min=a, max=b)
41 | return tensor
42 |
43 |
44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45 | # type: (Tensor, float, float, float, float) -> Tensor
46 | r"""Fills the input Tensor with values drawn from a truncated
47 | normal distribution. The values are effectively drawn from the
48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49 | with values outside :math:`[a, b]` redrawn until they are within
50 | the bounds. The method used for generating the random values works
51 | best when :math:`a \leq \text{mean} \leq b`.
52 | Args:
53 | tensor: an n-dimensional `torch.Tensor`
54 | mean: the mean of the normal distribution
55 | std: the standard deviation of the normal distribution
56 | a: the minimum cutoff value
57 | b: the maximum cutoff value
58 | Examples:
59 | >>> w = torch.empty(3, 5)
60 | >>> nn.init.trunc_normal_(w)
61 | """
62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63 |
64 |
65 | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
66 | fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
67 | if mode == 'fan_in':
68 | denom = fan_in
69 | elif mode == 'fan_out':
70 | denom = fan_out
71 |     elif mode == 'fan_avg':
72 |         denom = (fan_in + fan_out) / 2
73 |     else:
74 |         raise ValueError(f"invalid mode {mode}")
75 | 
76 |     variance = scale / denom
75 |
76 | if distribution == "truncated_normal":
77 | # constant is stddev of standard normal truncated to (-2, 2)
78 | trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
79 | elif distribution == "normal":
80 | tensor.normal_(std=math.sqrt(variance))
81 | elif distribution == "uniform":
82 | bound = math.sqrt(3 * variance)
83 | tensor.uniform_(-bound, bound)
84 | else:
85 | raise ValueError(f"invalid distribution {distribution}")
86 |
87 |
88 | def lecun_normal_(tensor):
89 | variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
90 |
--------------------------------------------------------------------------------
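
A quick sanity check on the initializers above: `trunc_normal_` respects its hard bounds, and `variance_scaling_` with `fan_in` targets a variance of `scale / fan_in`.

```python
# Minimal sketch: bounds and variance behavior of the weight-init helpers.
import torch
from models.layers.weight_init import trunc_normal_, variance_scaling_, lecun_normal_

w = torch.empty(256, 128)
trunc_normal_(w, mean=0., std=0.02, a=-0.04, b=0.04)   # truncate at +/- 2 std
print(w.abs().max() <= 0.04)                           # True

v = torch.empty(256, 128)                              # fan_in = 128 for a Linear-style weight
variance_scaling_(v, scale=1.0, mode='fan_in', distribution='normal')
print(v.var().item())                                  # ~ 1/128 = 0.0078

lecun_normal_(torch.empty(256, 128))                   # fan_in + truncated-normal variant
```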
/models/pruned/ecaresnet50d_pruned.txt:
--------------------------------------------------------------------------------
1 | conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 
5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022]
--------------------------------------------------------------------------------
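
The pruned-model files store one `param_name:[shape]` record per parameter, joined by `***`. A small parser sketch (the format is inferred from the file above; `load_pruned_shapes` is an illustrative helper, not a repo function):

```python
# Minimal sketch: parse a pruned-model spec into {param_name: shape} (format inferred above).
import ast

def load_pruned_shapes(path):
    shapes = {}
    with open(path) as f:
        text = f.read().strip()
    for record in text.split('***'):
        name, _, shape = record.partition(':')    # e.g. "conv1.0.weight:[32, 3, 3, 3]"
        shapes[name] = ast.literal_eval(shape)    # "[32, 3, 3, 3]" -> [32, 3, 3, 3]
    return shapes

shapes = load_pruned_shapes('models/pruned/ecaresnet50d_pruned.txt')
print(shapes['conv1.0.weight'])  # [32, 3, 3, 3]
```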
/models/registry.py:
--------------------------------------------------------------------------------
1 | """ Model Registry
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 |
5 | import sys
6 | import re
7 | import fnmatch
8 | from collections import defaultdict
9 | from copy import deepcopy
10 |
11 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
12 | 'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained']
13 |
14 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module
15 | _model_to_module = {} # mapping of model names to module names
16 | _model_entrypoints = {} # mapping of model names to entrypoint fns
17 | _model_has_pretrained = set() # set of model names that have pretrained weight url present
18 | _model_pretrained_cfgs = dict() # central repo for model default_cfgs
19 |
20 |
21 | def register_model(fn):
22 | # lookup containing module
23 | mod = sys.modules[fn.__module__]
24 | module_name_split = fn.__module__.split('.')
25 | module_name = module_name_split[-1] if len(module_name_split) else ''
26 |
27 | # add model to __all__ in module
28 | model_name = fn.__name__
29 | if hasattr(mod, '__all__'):
30 | mod.__all__.append(model_name)
31 | else:
32 | mod.__all__ = [model_name]
33 |
34 | # add entries to registry dict/sets
35 | _model_entrypoints[model_name] = fn
36 | _model_to_module[model_name] = module_name
37 | _module_to_models[module_name].add(model_name)
38 | has_valid_pretrained = False # check if model has a pretrained url to allow filtering on this
39 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
40 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
41 | # entrypoints or non-matching combos
42 | cfg = mod.default_cfgs[model_name]
43 | has_valid_pretrained = (
44 | ('url' in cfg and 'http' in cfg['url']) or
45 | ('file' in cfg and cfg['file']) or
46 | ('hf_hub_id' in cfg and cfg['hf_hub_id'])
47 | )
48 | _model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name]
49 | if has_valid_pretrained:
50 | _model_has_pretrained.add(model_name)
51 | return fn
52 |
53 |
54 | def _natural_key(string_):
55 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
56 |
57 |
58 | def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
59 | """ Return list of available model names, sorted alphabetically
60 |
61 | Args:
62 |         filter (str or list[str]) - Wildcard filter string(s) that work with fnmatch
63 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
64 | pretrained (bool) - Include only models with pretrained weights if True
65 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
66 | name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
67 |
68 |     Example:
69 |         list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
70 |         list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in the 'resnet' module
71 | """
72 | if module:
73 | all_models = list(_module_to_models[module])
74 | else:
75 | all_models = _model_entrypoints.keys()
76 | if filter:
77 | models = []
78 | include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
79 | for f in include_filters:
80 | include_models = fnmatch.filter(all_models, f) # include these models
81 | if len(include_models):
82 | models = set(models).union(include_models)
83 | else:
84 | models = all_models
85 | if exclude_filters:
86 | if not isinstance(exclude_filters, (tuple, list)):
87 | exclude_filters = [exclude_filters]
88 | for xf in exclude_filters:
89 | exclude_models = fnmatch.filter(models, xf) # exclude these models
90 | if len(exclude_models):
91 | models = set(models).difference(exclude_models)
92 | if pretrained:
93 | models = _model_has_pretrained.intersection(models)
94 | if name_matches_cfg:
95 | models = set(_model_pretrained_cfgs).intersection(models)
96 | return list(sorted(models, key=_natural_key))
97 |
98 |
99 | def is_model(model_name):
100 | """ Check if a model name exists
101 | """
102 | return model_name in _model_entrypoints
103 |
104 |
105 | def model_entrypoint(model_name):
106 | """Fetch a model entrypoint for specified model name
107 | """
108 | return _model_entrypoints[model_name]
109 |
110 |
111 | def list_modules():
112 | """ Return list of module names that contain models / model entrypoints
113 | """
114 | modules = _module_to_models.keys()
115 | return list(sorted(modules))
116 |
117 |
118 | def is_model_in_modules(model_name, module_names):
119 | """Check if a model exists within a subset of modules
120 | Args:
121 | model_name (str) - name of model to check
122 | module_names (tuple, list, set) - names of modules to search in
123 | """
124 | assert isinstance(module_names, (tuple, list, set))
125 | return any(model_name in _module_to_models[n] for n in module_names)
126 |
127 |
128 | def is_model_pretrained(model_name):
129 | return model_name in _model_has_pretrained
130 |
131 |
132 | def get_pretrained_cfg(model_name):
133 | if model_name in _model_pretrained_cfgs:
134 | return deepcopy(_model_pretrained_cfgs[model_name])
135 | return {}
136 |
137 |
138 | def has_pretrained_cfg_key(model_name, cfg_key):
139 | """ Query model default_cfgs for existence of a specific key.
140 | """
141 | if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]:
142 | return True
143 | return False
144 |
145 |
146 | def is_pretrained_cfg_key(model_name, cfg_key):
147 | """ Return truthy value for specified model default_cfg key, False if does not exist.
148 | """
149 | if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False):
150 | return True
151 | return False
152 |
153 |
154 | def get_pretrained_cfg_value(model_name, cfg_key):
155 | """ Get a specific model default_cfg value by key. None if it doesn't exist.
156 | """
157 | if model_name in _model_pretrained_cfgs:
158 | return _model_pretrained_cfgs[model_name].get(cfg_key, None)
159 | return None
--------------------------------------------------------------------------------
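
A usage sketch for the registry above: decorate an entrypoint function with `@register_model`, optionally expose a module-level `default_cfgs` dict so the pretrained bookkeeping is picked up, then query by wildcard.

```python
# Minimal sketch: register a toy entrypoint and query the registry helpers.
from torch import nn
from models.registry import register_model, list_models, is_model, model_entrypoint

default_cfgs = {'toy_net': {'url': ''}}   # module-level cfgs are recorded by register_model

@register_model
def toy_net(pretrained=False, **kwargs):
    return nn.Linear(8, 2)

print(is_model('toy_net'))               # True
print(list_models('toy_*'))              # ['toy_net']
model = model_entrypoint('toy_net')()    # fetch the entrypoint fn and call it
```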
/pics/CPU.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/CPU.png
--------------------------------------------------------------------------------
/pics/GPU.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/GPU.png
--------------------------------------------------------------------------------
/pics/pic1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/pic1.png
--------------------------------------------------------------------------------
/pics/pic2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/pic2.png
--------------------------------------------------------------------------------
/pics/pic3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-MIIL/Solving_ImageNet/31af67a0365249e186283b74af1c9c2b16c63a4c/pics/pic3.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.5.4
2 | torch>=1.9.0
3 | torchvision>=0.5.0
4 | pyyaml
5 | inplace-abn
--------------------------------------------------------------------------------
/test/test_build_kd_model.py:
--------------------------------------------------------------------------------
1 | # test_build_kd_model.py - Generated by https://www.codium.ai/
2 | 
3 | import unittest
4 | 
5 | import torch
6 | import torchvision.transforms as T
7 | 
8 | from kd.kd_utils import build_kd_model
9 | from models.factory import create_model
6 |
7 | """
8 | Code Analysis:
9 | - This class is used to build a knowledge distillation (KD) model.
10 | - It uses the create_model() function from the models.factory module to create a KD model.
11 | - It then uses the InplacABN_to_ABN() and fuse_bn2d_bn1d_abn() functions from the kd.helpers module to convert the model to an ABN model and fuse the batch normalization layers.
12 | - The model is then loaded onto the GPU and set to evaluation mode.
13 | - The mean and standard deviation of the model are also stored.
14 | - The normalize_input() function is used to normalize the input data to match the mean and standard deviation of the KD model.
15 | - It uses the torchvision.transforms module to apply the necessary transformations.
16 | """
17 |
18 |
19 | """
20 | Test strategies:
21 | - test_create_model(): tests that the create_model() function from the models.factory module is correctly called and the model is created.
22 | - test_InplacABN_to_ABN(): tests that the InplacABN_to_ABN() function from the kd.helpers module is correctly called and the model is converted to an ABN model.
23 | - test_fuse_bn2d_bn1d_abn(): tests that the fuse_bn2d_bn1d_abn() function from the kd.helpers module is correctly called and the batch normalization layers are fused.
24 | - test_load_gpu(): tests that the model is correctly loaded onto the GPU and set to evaluation mode.
25 | - test_mean_std(): tests that the mean and standard deviation of the model are correctly stored.
26 | - test_normalize_input(): tests that the normalize_input() function is correctly called and the input data is normalized to match the mean and standard deviation of the KD model.
27 | - test_transforms(): tests that the torchvision.transforms module is correctly called and the necessary transformations are applied.
28 | - test_edge_cases(): tests that the class handles edge cases correctly.
29 | """
30 |
31 |
32 | class TestBuildKdModel(unittest.TestCase):
33 |
34 | def setUp(self):
35 | self.args = {
36 | 'kd_model_name': 'resnet18',
37 | 'kd_model_path': './models/resnet18.pth',
38 | 'num_classes': 10,
39 | 'in_chans': 3
40 | }
41 |
42 | def test_create_model(self):
43 | model = build_kd_model(self.args)
44 | self.assertIsNotNone(model.model)
45 |
46 | def test_InplacABN_to_ABN(self):
47 | model = build_kd_model(self.args)
48 | self.assertIsNotNone(model.model.bn1)
49 |
50 | def test_fuse_bn2d_bn1d_abn(self):
51 | model = build_kd_model(self.args)
52 | self.assertIsNotNone(model.model.bn2d)
53 |
54 | def test_load_gpu(self):
55 | model = build_kd_model(self.args)
56 | self.assertEqual(model.model.device.type, 'cuda')
57 | self.assertEqual(model.model.training, False)
58 |
59 | def test_mean_std(self):
60 | model = build_kd_model(self.args)
61 | self.assertEqual(model.mean_model_kd, model.model.default_cfg['mean'])
62 | self.assertEqual(model.std_model_kd, model.model.default_cfg['std'])
63 |
64 | def test_normalize_input(self):
65 | model = build_kd_model(self.args)
66 | input = torch.randn(3, 224, 224)
67 | student_model = create_model('resnet18', None, False, 10, 3)
68 | model.normalize_input(input, student_model)
69 | self.assertEqual(input.mean(), model.mean_model_kd[0])
70 | self.assertEqual(input.std(), model.std_model_kd[0])
71 |
72 | def test_transforms(self):
73 | model = build_kd_model(self.args)
74 | student_model = create_model('resnet18', None, False, 10, 3)
75 | self.assertIsInstance(model.transform_std, T.Normalize)
76 | self.assertIsInstance(model.transform_mean, T.Normalize)
77 |
78 | def test_edge_cases(self):
79 | args = {
80 | 'kd_model_name': 'resnet18',
81 | 'kd_model_path': None,
82 | 'num_classes': 10,
83 | 'in_chans': 3
84 | }
85 | model = build_kd_model(args)
86 | self.assertIsNotNone(model.model)
87 |
--------------------------------------------------------------------------------
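
The analysis block above describes `normalize_input` as re-normalizing a batch from the student model's input statistics to the teacher's. A heavily hedged sketch of that idea (illustrative math only, not the repo's actual implementation; `renormalize` is a hypothetical helper):

```python
# Minimal sketch of the re-normalization idea: undo student mean/std, apply teacher mean/std.
import torch

def renormalize(x, student_mean, student_std, teacher_mean, teacher_std):
    sm = torch.tensor(student_mean).view(1, -1, 1, 1)
    ss = torch.tensor(student_std).view(1, -1, 1, 1)
    tm = torch.tensor(teacher_mean).view(1, -1, 1, 1)
    ts = torch.tensor(teacher_std).view(1, -1, 1, 1)
    raw = x * ss + sm            # undo the student's normalization
    return (raw - tm) / ts       # apply the teacher's normalization

x = torch.randn(4, 3, 224, 224)  # batch already normalized with student stats
y = renormalize(x, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (0., 0., 0.), (1., 1., 1.))
print(y.shape)  # torch.Size([4, 3, 224, 224])
```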
/utils/checkpoint_saver.py:
--------------------------------------------------------------------------------
1 | """ Checkpoint Saver
2 |
3 | Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 |
8 | import glob
9 | import operator
10 | import os
11 | import logging
12 |
13 | import torch
14 |
15 | from timm.utils.model import unwrap_model, get_state_dict
16 |
17 | _logger = logging.getLogger(__name__)
18 |
19 |
20 | class CheckpointSaverUSI:  # don't save the optimizer state dict, since it forces a specific repo structure
21 | def __init__(
22 | self,
23 | model,
24 | optimizer,
25 | args=None,
26 | model_ema=None,
27 | amp_scaler=None,
28 | checkpoint_prefix='checkpoint',
29 | recovery_prefix='recovery',
30 | checkpoint_dir='',
31 | recovery_dir='',
32 | decreasing=False,
33 | max_history=10,
34 | unwrap_fn=unwrap_model):
35 |
36 | # objects to save state_dicts of
37 | self.model = model
38 | self.optimizer = optimizer
39 | self.args = args
40 | self.model_ema = model_ema
41 | self.amp_scaler = amp_scaler
42 |
43 | # state
44 |         self.checkpoint_files = []  # (filename, metric) tuples in order from best to worst
45 | self.best_epoch = None
46 | self.best_metric = None
47 | self.curr_recovery_file = ''
48 | self.last_recovery_file = ''
49 |
50 | # config
51 | self.checkpoint_dir = checkpoint_dir
52 | self.recovery_dir = recovery_dir
53 | self.save_prefix = checkpoint_prefix
54 | self.recovery_prefix = recovery_prefix
55 | self.extension = '.pth.tar'
56 | self.decreasing = decreasing # a lower metric is better if True
57 | self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
58 | self.max_history = max_history
59 | self.unwrap_fn = unwrap_fn
60 | assert self.max_history >= 1
61 |
62 | def save_checkpoint(self, epoch, metric=None, metric_ema=None):
63 | assert epoch >= 0
64 | tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
65 | last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
66 | self._save(tmp_save_path, epoch, metric, metric_ema)
67 | if os.path.exists(last_save_path):
68 | os.unlink(last_save_path) # required for Windows support.
69 | os.rename(tmp_save_path, last_save_path)
70 | worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
71 | if (len(self.checkpoint_files) < self.max_history
72 | or metric is None or self.cmp(metric, worst_file[1])):
73 | if len(self.checkpoint_files) >= self.max_history:
74 | self._cleanup_checkpoints(1)
75 | filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
76 | save_path = os.path.join(self.checkpoint_dir, filename)
77 | os.link(last_save_path, save_path)
78 | self.checkpoint_files.append((save_path, metric))
79 | self.checkpoint_files = sorted(
80 | self.checkpoint_files, key=lambda x: x[1],
81 | reverse=not self.decreasing) # sort in descending order if a lower metric is not better
82 |
83 | checkpoints_str = "Current checkpoints:\n"
84 | for c in self.checkpoint_files:
85 | checkpoints_str += ' {}\n'.format(c)
86 | _logger.info(checkpoints_str)
87 |
88 | if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
89 | self.best_epoch = epoch
90 | self.best_metric = metric
91 | best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
92 | if os.path.exists(best_save_path):
93 | os.unlink(best_save_path)
94 | os.link(last_save_path, best_save_path)
95 |
96 | return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
97 |
98 | def _save(self, save_path, epoch, metric=None, metric_ema=None):
99 | save_state = {
100 | 'epoch': epoch,
101 | 'arch': type(self.model).__name__.lower(),
102 | 'state_dict': get_state_dict(self.model, self.unwrap_fn),
103 | # 'optimizer': self.optimizer.state_dict(),
104 | 'version': 2, # version < 2 increments epoch before save
105 | }
106 | if self.args is not None:
107 | save_state['arch'] = self.args.model
108 | save_state['args'] = self.args
109 | # if self.amp_scaler is not None:
110 | # save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
111 | if self.model_ema is not None:
112 |             # save EMA weights instead of regular weights; guard None metrics (e.g. recovery saves)
113 |             if metric is not None and metric_ema is not None and metric_ema > metric:
114 |                 save_state['state_dict'] = get_state_dict(self.model_ema, self.unwrap_fn)
114 | # save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
115 | if metric is not None:
116 | save_state['metric'] = metric
117 | torch.save(save_state, save_path)
118 |
119 | def _cleanup_checkpoints(self, trim=0):
120 | trim = min(len(self.checkpoint_files), trim)
121 | delete_index = self.max_history - trim
122 | if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
123 | return
124 | to_delete = self.checkpoint_files[delete_index:]
125 | for d in to_delete:
126 | try:
127 | _logger.debug("Cleaning checkpoint: {}".format(d))
128 | os.remove(d[0])
129 | except Exception as e:
130 | _logger.error("Exception '{}' while deleting checkpoint".format(e))
131 | self.checkpoint_files = self.checkpoint_files[:delete_index]
132 |
133 | def save_recovery(self, epoch, batch_idx=0):
134 | assert epoch >= 0
135 | filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
136 | save_path = os.path.join(self.recovery_dir, filename)
137 | self._save(save_path, epoch)
138 | if os.path.exists(self.last_recovery_file):
139 | try:
140 | _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
141 | os.remove(self.last_recovery_file)
142 | except Exception as e:
143 | _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
144 | self.last_recovery_file = self.curr_recovery_file
145 | self.curr_recovery_file = save_path
146 |
147 | def find_recovery(self):
148 | recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
149 | files = glob.glob(recovery_path + '*' + self.extension)
150 | files = sorted(files)
151 | return files[0] if len(files) else ''
152 |
--------------------------------------------------------------------------------
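
A usage sketch for the saver above: `save_checkpoint` refreshes `last.pth.tar` on every call and hard-links ranked copies for the top `max_history` metrics. The output directory is assumed to exist.

```python
# Minimal sketch: track top-k checkpoints by validation metric with CheckpointSaverUSI.
import os
import torch
from torch import nn
from utils.checkpoint_saver import CheckpointSaverUSI

os.makedirs('./output', exist_ok=True)
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
saver = CheckpointSaverUSI(model, optimizer, checkpoint_dir='./output', max_history=3)

for epoch, acc in enumerate([71.2, 73.5, 72.8]):
    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=acc)

print(best_metric, best_epoch)  # 73.5 1
```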