├── .gitignore
├── README.md
├── captum_vis.py
├── dataset
│   ├── CUB200.py
│   ├── ConText.py
│   ├── choose_dataset.py
│   ├── mnist.py
│   └── transform_func.py
├── engine.py
├── fig
│   ├── zfig_story.jpg
│   └── zfig_structure.jpg
├── requirements.txt
├── sloter
│   ├── slot_model.py
│   └── utils
│       ├── grad_cam.py
│       ├── position_encode.py
│       ├── slot_attention.py
│       └── vis.py
├── test.py
├── timm
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── auto_augment.py
│   │   ├── config.py
│   │   ├── constants.py
│   │   ├── dataset.py
│   │   ├── distributed_sampler.py
│   │   ├── loader.py
│   │   ├── mixup.py
│   │   ├── random_erasing.py
│   │   ├── tf_preprocessing.py
│   │   ├── transforms.py
│   │   └── transforms_factory.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── cross_entropy.py
│   │   ├── jsd.py
│   │   └── slot_loss.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── densenet.py
│   │   ├── dla.py
│   │   ├── dpn.py
│   │   ├── efficientnet.py
│   │   ├── efficientnet_blocks.py
│   │   ├── efficientnet_builder.py
│   │   ├── factory.py
│   │   ├── feature_hooks.py
│   │   ├── gluon_resnet.py
│   │   ├── gluon_xception.py
│   │   ├── helpers.py
│   │   ├── hrnet.py
│   │   ├── inception_resnet_v2.py
│   │   ├── inception_v3.py
│   │   ├── inception_v4.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── activations.py
│   │   │   ├── activations_jit.py
│   │   │   ├── activations_me.py
│   │   │   ├── adaptive_avgmax_pool.py
│   │   │   ├── anti_aliasing.py
│   │   │   ├── blur_pool.py
│   │   │   ├── cbam.py
│   │   │   ├── cond_conv2d.py
│   │   │   ├── config.py
│   │   │   ├── conv2d_same.py
│   │   │   ├── conv_bn_act.py
│   │   │   ├── create_act.py
│   │   │   ├── create_attn.py
│   │   │   ├── create_conv2d.py
│   │   │   ├── create_norm_act.py
│   │   │   ├── drop.py
│   │   │   ├── eca.py
│   │   │   ├── evo_norm.py
│   │   │   ├── helpers.py
│   │   │   ├── inplace_abn.py
│   │   │   ├── median_pool.py
│   │   │   ├── mixed_conv2d.py
│   │   │   ├── norm_act.py
│   │   │   ├── padding.py
│   │   │   ├── pool2d_same.py
│   │   │   ├── se.py
│   │   │   ├── selective_kernel.py
│   │   │   ├── separable_conv.py
│   │   │   ├── space_to_depth.py
│   │   │   ├── split_attn.py
│   │   │   ├── split_batchnorm.py
│   │   │   ├── test_time_pool.py
│   │   │   └── weight_init.py
│   │   ├── mobilenetv3.py
│   │   ├── nasnet.py
│   │   ├── pnasnet.py
│   │   ├── pruned
│   │   │   ├── ecaresnet101d_pruned.txt
│   │   │   ├── ecaresnet50d_pruned.txt
│   │   │   ├── efficientnet_b1_pruned.txt
│   │   │   ├── efficientnet_b2_pruned.txt
│   │   │   └── efficientnet_b3_pruned.txt
│   │   ├── registry.py
│   │   ├── regnet.py
│   │   ├── res2net.py
│   │   ├── resnest.py
│   │   ├── resnet.py
│   │   ├── selecsls.py
│   │   ├── senet.py
│   │   ├── sknet.py
│   │   ├── tresnet.py
│   │   ├── vovnet.py
│   │   └── xception.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── adamw.py
│   │   ├── lookahead.py
│   │   ├── nadam.py
│   │   ├── novograd.py
│   │   ├── nvnovograd.py
│   │   ├── optim_factory.py
│   │   ├── radam.py
│   │   └── rmsprop_tf.py
│   ├── scheduler
│   │   ├── __init__.py
│   │   ├── cosine_lr.py
│   │   ├── plateau_lr.py
│   │   ├── scheduler.py
│   │   ├── scheduler_factory.py
│   │   ├── step_lr.py
│   │   └── tanh_lr.py
│   ├── utils.py
│   └── version.py
├── tools
│   ├── calculate_tool.py
│   ├── image_aug.py
│   └── prepare_things.py
├── torchcam
│   ├── IBA
│   │   ├── __init__.py
│   │   ├── _keras_graph.py
│   │   ├── pytorch.py
│   │   ├── pytorch_readout.py
│   │   ├── tensorflow_v1.py
│   │   └── utils.py
│   ├── IGOS.py
│   ├── __init__.py
│   ├── cams
│   │   ├── __init__.py
│   │   ├── cam.py
│   │   └── gradcam.py
│   ├── utils.py
│   └── version.py
├── torchcam_vis.py
├── torchray
│   ├── VERSION
│   ├── __init__.py
│   ├── attribution
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── deconvnet.py
│   │   ├── excitation_backprop.py
│   │   ├── extremal_perturbation.py
│   │   ├── grad_cam.py
│   │   ├── gradient.py
│   │   ├── guided_backprop.py
│   │   ├── linear_approx.py
│   │   └── rise.py
│   ├── benchmark
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── imagenet_classes.txt
│   │   ├── logging.py
│   │   ├── models.py
│   │   ├── pointing_game.py
│   │   ├── pointing_game_ebp_coco_difficult.txt
│   │   ├── pointing_game_ebp_voc07_difficult.txt
│   │   └── server.py
│   └── utils.py
└── train.py

/.gitignore:
-------------------------------------------------------------------------------- 1 | # Project exclude paths 2 | /venv/ 3 | saved_model* 4 | .idea/ 5 | sloter/vis*/ 6 | data/ 7 | *.pyc 8 | .vscode/ 9 | scripts/ 10 | temp* -------------------------------------------------------------------------------- /dataset/CUB200.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils.data import Dataset 3 | import os 4 | import torch 5 | import numpy as np 6 | 7 | 8 | class CUB_200(Dataset): 9 | def __init__(self, args, train=True, transform=None): 10 | super(CUB_200, self).__init__() 11 | self.root = args.dataset_dir 12 | self.size = args.img_size 13 | self.num = args.num_classes 14 | self.train = train 15 | self.transform_ = transform 16 | self.classes_file = os.path.join(self.root, 'classes.txt') # 17 | self.image_class_labels_file = os.path.join(self.root, 'image_class_labels.txt') # 18 | self.images_file = os.path.join(self.root, 'images.txt') # 19 | self.train_test_split_file = os.path.join(self.root, 'train_test_split.txt') # 20 | self.bounding_boxes_file = os.path.join(self.root, 'bounding_boxes.txt') # 21 | 22 | self._train_ids = [] 23 | self._test_ids = [] 24 | self._image_id_label = {} 25 | self._train_path_label = [] 26 | self._test_path_label = [] 27 | 28 | self._train_test_split() 29 | self._get_id_to_label() 30 | self._get_path_label() 31 | 32 | def _train_test_split(self): 33 | 34 | for line in open(self.train_test_split_file): 35 | image_id, label = line.strip('\n').split() 36 | if label == '1': 37 | self._train_ids.append(image_id) 38 | elif label == '0': 39 | self._test_ids.append(image_id) 40 | else: 41 | raise Exception('label error') 42 | 43 | def _get_id_to_label(self): 44 | for line in open(self.image_class_labels_file): 45 | image_id, class_id = line.strip('\n').split() 46 | self._image_id_label[image_id] = class_id 47 | 48 | def _get_path_label(self): 49 | for line in open(self.images_file): 50 | image_id, image_name = line.strip('\n').split() 51 | if int(image_name[:3]) > self.num: 52 | if image_id in self._train_ids: 53 | self._train_ids.pop(self._train_ids.index(image_id)) 54 | else: 55 | self._test_ids.pop(self._test_ids.index(image_id)) 56 | continue 57 | label = self._image_id_label[image_id] 58 | if image_id in self._train_ids: 59 | self._train_path_label.append((image_name, label)) 60 | else: 61 | self._test_path_label.append((image_name, label)) 62 | 63 | def __getitem__(self, index): 64 | if self.train: 65 | image_name, label = self._train_path_label[index] 66 | else: 67 | image_name, label = self._test_path_label[index] 68 | image_path = os.path.join(self.root, 'images', image_name) 69 | img = Image.open(image_path) 70 | if img.mode == 'L': 71 | img = img.convert('RGB') 72 | label = int(label) - 1 73 | label = torch.from_numpy(np.array(label)) 74 | if self.transform_ is not None: 75 | img = self.transform_(img) 76 | return {"image": img, "label": label, "names": image_path} 77 | 78 | def __len__(self): 79 | if self.train: 80 | return len(self._train_ids) 81 | else: 82 | return len(self._test_ids) 83 | -------------------------------------------------------------------------------- /dataset/ConText.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch.utils.data import Dataset 3 | from sklearn.model_selection import train_test_split 4 | from tools.prepare_things import get_name 5 | import os 6 | import torch 7 | import 
numpy as np
8 |
9 |
10 | class MakeList(object):
11 |     """
12 |     this class is used to build the data list for model training and testing; it returns the path and label of each image
13 |     image_root: the directory that contains every image
14 |     """
15 |     def __init__(self, args, ratio=0.8):
16 |         self.image_root = args.dataset_dir
17 |         self.all_image = get_name(self.image_root, mode_folder=False)
18 |         self.category = sorted(set([i[:i.find('_')] for i in self.all_image]))
19 |
20 |         for c_id, c in enumerate(self.category):
21 |             print(c_id, '\t', c)
22 |
23 |         self.ratio = ratio
24 |
25 |     def get_data(self):
26 |         all_data = []
27 |         for img in self.all_image:
28 |             label = self.deal_label(img)
29 |             all_data.append([os.path.join(self.image_root, img), label])
30 |         train, val = train_test_split(all_data, random_state=1, train_size=self.ratio)
31 |         return train, val
32 |
33 |     def deal_label(self, img_name):
34 |         category_no = img_name[:img_name.find('_')]
35 |         back = self.category.index(category_no)
36 |         return back
37 |
38 |
39 | class MakeListImage():
40 |     """
41 |     this class is used to build the data list for ImageNet
42 |     """
43 |     def __init__(self, args):
44 |         self.image_root = args.dataset_dir
45 |         self.category = get_name(self.image_root + "train/")
46 |         self.used_cat = self.category[:args.num_classes]
47 |         # for c_id, c in enumerate(self.used_cat):
48 |         #     print(c_id, '\t', c)
49 |
50 |     def get_data(self):
51 |         train = self.get_img(self.used_cat, "train")
52 |         val = self.get_img(self.used_cat, "val")
53 |         return train, val
54 |
55 |     def get_img(self, folders, phase):
56 |         record = []
57 |         for folder in folders:
58 |             current_root = os.path.join(self.image_root, phase, folder)
59 |             images = get_name(current_root, mode_folder=False)
60 |             for img in images:
61 |                 record.append([os.path.join(current_root, img), self.deal_label(folder)])
62 |         return record
63 |
64 |     def deal_label(self, img_name):
65 |         back = self.used_cat.index(img_name)
66 |         return back
67 |
68 |
69 | class ConText(Dataset):
70 |     """reads all image paths and labels"""
71 |     def __init__(self, data, transform=None):
72 |         self.all_item = data
73 |         self.transform = transform
74 |
75 |     def __len__(self):
76 |         return len(self.all_item)
77 |
78 |     def __getitem__(self, item_id):  # generate one sample for the given index
79 |         if not os.path.exists(self.all_item[item_id][0]):
80 |             raise FileNotFoundError("image does not exist: " + self.all_item[item_id][0])
81 |         image_path = self.all_item[item_id][0]
82 |         image = Image.open(image_path).convert('RGB')
83 |         if image.mode == 'L':
84 |             image = image.convert('RGB')
85 |         if self.transform:
86 |             image = self.transform(image)
87 |         label = self.all_item[item_id][1]
88 |         label = torch.from_numpy(np.array(label))
89 |         return {"image": image, "label": label, "names": image_path}
-------------------------------------------------------------------------------- /dataset/choose_dataset.py: --------------------------------------------------------------------------------
1 | from dataset.mnist import MNIST
2 | from dataset.CUB200 import CUB_200
3 | from dataset.ConText import ConText, MakeList, MakeListImage
4 | from dataset.transform_func import make_transform
5 |
6 |
7 | def select_dataset(args):
8 |     if args.dataset == "MNIST":
9 |         dataset_train = MNIST('./data/mnist', train=True, download=True, transform=make_transform(args, "train"))
10 |         dataset_val = MNIST('./data/mnist', train=False, transform=make_transform(args, "val"))
11 |         return dataset_train, dataset_val
12 |     if args.dataset == "CUB200":
13 |         dataset_train = CUB_200(args, train=True,
transform=make_transform(args, "train"))
14 |         dataset_val = CUB_200(args, train=False, transform=make_transform(args, "val"))
15 |         return dataset_train, dataset_val
16 |     if args.dataset == "ConText":
17 |         train, val = MakeList(args).get_data()
18 |         dataset_train = ConText(train, transform=make_transform(args, "train"))
19 |         dataset_val = ConText(val, transform=make_transform(args, "val"))
20 |         return dataset_train, dataset_val
21 |     if args.dataset == "ImageNet":
22 |         train, val = MakeListImage(args).get_data()
23 |         dataset_train = ConText(train, transform=make_transform(args, "train"))
24 |         dataset_val = ConText(val, transform=make_transform(args, "val"))
25 |         return dataset_train, dataset_val
26 |
27 |     raise ValueError(f'unknown {args.dataset}')
28 |
29 |
-------------------------------------------------------------------------------- /dataset/transform_func.py: --------------------------------------------------------------------------------
1 | import torch
2 | from tools.image_aug import ImageAugment
3 | import torchvision.transforms.functional as F
4 | from collections.abc import Sequence, Iterable
5 | import numpy as np
6 | from PIL import Image
7 |
8 |
9 | _pil_interpolation_to_str = {
10 |     Image.NEAREST: 'PIL.Image.NEAREST',
11 |     Image.BILINEAR: 'PIL.Image.BILINEAR',
12 |     Image.BICUBIC: 'PIL.Image.BICUBIC',
13 |     Image.LANCZOS: 'PIL.Image.LANCZOS',
14 |     Image.HAMMING: 'PIL.Image.HAMMING',
15 |     Image.BOX: 'PIL.Image.BOX',
16 | }
17 |
18 |
19 | class Resize(object):
20 |     """class for resizing images."""
21 |     def __init__(self, size, interpolation=Image.BILINEAR):
22 |         assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
23 |         self.size = size
24 |         self.interpolation = interpolation
25 |
26 |     def __call__(self, image):
27 |         return np.array(F.resize(image, self.size, self.interpolation))
28 |
29 |     def __repr__(self):
30 |         interpolate_str = _pil_interpolation_to_str[self.interpolation]
31 |         return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
32 |
33 |
34 | class Aug(object):
35 |     """class for augmenting images."""
36 |     def __init__(self, aug):
37 |         self.aug = aug
38 |
39 |     def __call__(self, image):
40 |         if self.aug:
41 |             ImgAug = ImageAugment()  # the ImageAugment class augments the image and the label at the same time
42 |             seq = ImgAug.aug_sequence()
43 |             image_aug = ImgAug.aug(image, seq)
44 |             return image_aug
45 |         else:
46 |             return image
47 |
48 |     def __repr__(self):
49 |         return self.__class__.__name__ + ' (augmentation function)'
50 |
51 |
52 | class ToTensor(object):
53 |     """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
54 |
55 |     Converts a PIL Image or numpy.ndarray (H x W x C) in the range
56 |     [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
57 |     if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
58 |     or if the numpy.ndarray has dtype = np.uint8
59 |
60 |     In the other cases, tensors are returned without scaling.
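
    NOTE: this simplified variant divides by 255 unconditionally and expects a
    numpy.ndarray input (the array produced by the Resize transform above), so
    the "returned without scaling" case described above does not apply here.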
61 | """ 62 | 63 | def __call__(self, image, color=True): 64 | if image.ndim == 2: 65 | image = image[:, :, None] 66 | image = torch.from_numpy(((image/255).transpose([2, 0, 1])).copy()) # convert numpy data to tensor 67 | return image 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ + '()' 71 | 72 | 73 | class Compose(object): 74 | def __init__(self, transforms): 75 | self.transforms = transforms 76 | 77 | def __call__(self, img): 78 | for t in self.transforms: 79 | img = t(img) 80 | return img 81 | 82 | def __repr__(self): 83 | format_string = self.__class__.__name__ + '(' 84 | for t in self.transforms: 85 | format_string += '\n' 86 | format_string += ' {0}'.format(t) 87 | format_string += '\n)' 88 | return format_string 89 | 90 | 91 | class Normalize(object): 92 | def __init__(self, mean, std): 93 | self.mean = mean 94 | self.std = std 95 | 96 | def __call__(self, imgs): 97 | imgs = F.normalize(imgs, mean=self.mean, std=self.std) 98 | return imgs 99 | 100 | 101 | def make_transform(args, mode): 102 | normalize_value = {"MNIST": [[0.1307], [0.3081]], 103 | "CUB200": [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]], 104 | "ConText": [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]], 105 | "ImageNet": [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]} 106 | selected_norm = normalize_value[args.dataset] 107 | normalize = Compose([ 108 | ToTensor(), 109 | Normalize(selected_norm[0], selected_norm[1]) 110 | ]) 111 | 112 | if mode == "train": 113 | return Compose([ 114 | Resize((args.img_size, args.img_size)), 115 | Aug(args.aug), 116 | normalize, 117 | ] 118 | ) 119 | if mode == "val": 120 | return Compose([ 121 | Resize((args.img_size, args.img_size)), 122 | normalize, 123 | ] 124 | ) 125 | raise ValueError(f'unknown {mode}') -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tools.calculate_tool as cal 3 | from tqdm.auto import tqdm 4 | 5 | 6 | def train_one_epoch(model, data_loader, optimizer, device, record, epoch): 7 | model.train() 8 | calculation(model, "train", data_loader, device, record, epoch, optimizer) 9 | 10 | 11 | @torch.no_grad() 12 | def evaluate(model, data_loader, device, record, epoch): 13 | model.eval() 14 | calculation(model, "val", data_loader, device, record, epoch) 15 | 16 | 17 | def calculation(model, mode, data_loader, device, record, epoch, optimizer=None): 18 | L = len(data_loader) 19 | running_loss = 0.0 20 | running_corrects = 0.0 21 | running_att_loss = 0.0 22 | running_log_loss = 0.0 23 | print("start " + mode + " :" + str(epoch)) 24 | for i_batch, sample_batch in enumerate(tqdm(data_loader)): 25 | inputs = sample_batch["image"].to(device, dtype=torch.float32) 26 | labels = sample_batch["label"].to(device, dtype=torch.int64) 27 | 28 | if mode == "train": 29 | optimizer.zero_grad() 30 | logits, loss_list = model(inputs, labels) 31 | loss = loss_list[0] 32 | if mode == "train": 33 | loss.backward() 34 | # clip_gradient(optimizer, 1.1) 35 | optimizer.step() 36 | 37 | a = loss.item() 38 | running_loss += a 39 | if len(loss_list) > 2: # For slot training only 40 | running_att_loss += loss_list[2].item() 41 | running_log_loss += loss_list[1].item() 42 | running_corrects += cal.evaluateTop1(logits, labels) 43 | # if i_batch % 10 == 0: 44 | # print("epoch: {} {}/{} Loss: {:.4f}".format(epoch, i_batch, L-1, a)) 45 | epoch_loss = round(running_loss/L, 3) 46 | epoch_loss_log = round(running_log_loss/L, 3) 47 | 
epoch_loss_att = round(running_att_loss/L, 3) 48 | epoch_acc = round(running_corrects/L, 3) 49 | record[mode]["loss"].append(epoch_loss) 50 | record[mode]["acc"].append(epoch_acc) 51 | record[mode]["log_loss"].append(epoch_loss_log) 52 | record[mode]["att_loss"].append(epoch_loss_att) 53 | 54 | 55 | def clip_gradient(optimizer, grad_clip): 56 | """ 57 | Clips gradients computed during backpropagation to avoid explosion of gradients. 58 | 59 | :param optimizer: optimizer with the gradients to be clipped 60 | :param grad_clip: clip value 61 | """ 62 | for group in optimizer.param_groups: 63 | for param in group["params"]: 64 | if param.grad is not None: 65 | param.grad.data.clamp_(-grad_clip, grad_clip) 66 | 67 | -------------------------------------------------------------------------------- /fig/zfig_story.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbw520/scouter/5885b821681daf8c2263975490b4c6418687277b/fig/zfig_story.jpg -------------------------------------------------------------------------------- /fig/zfig_structure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbw520/scouter/5885b821681daf8c2263975490b4c6418687277b/fig/zfig_structure.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.6.20 2 | cycler==0.10.0 3 | decorator==4.4.2 4 | future==0.18.2 5 | imageio==2.9.0 6 | imgaug==0.4.0 7 | joblib==0.16.0 8 | kiwisolver==1.2.0 9 | matplotlib==3.3.1 10 | networkx==2.5 11 | nose==1.3.7 12 | numpy==1.19.2 13 | opencv-python==4.4.0.42 14 | Pillow==7.2.0 15 | prefetch-generator==1.0.1 16 | pyparsing==2.4.7 17 | python-dateutil==2.8.1 18 | PyWavelets==1.1.1 19 | scikit-image==0.17.2 20 | scikit-learn==0.23.2 21 | scipy==1.5.2 22 | Shapely==1.7.1 23 | six==1.15.0 24 | tensorly==0.4.5 25 | thop==0.0.31.post2005241907 26 | threadpoolctl==2.1.0 27 | tifffile==2020.9.3 28 | torch==1.6.0 29 | torchvision==0.7.0 30 | tqdm==4.49.0 31 | -------------------------------------------------------------------------------- /sloter/slot_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from sloter.utils.slot_attention import SlotAttention 5 | from sloter.utils.position_encode import build_position_encoding 6 | from timm.models import create_model 7 | from collections import OrderedDict 8 | 9 | 10 | class Identical(nn.Module): 11 | def __init__(self): 12 | super(Identical, self).__init__() 13 | 14 | def forward(self, x): 15 | return x 16 | 17 | 18 | def load_backbone(args): 19 | bone = create_model( 20 | args.model, 21 | pretrained=args.pre_trained, 22 | num_classes=args.num_classes) 23 | if args.dataset == "MNIST": 24 | bone.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False) 25 | if args.use_slot: 26 | if args.use_pre: 27 | checkpoint = torch.load(f"saved_model/{args.dataset}_no_slot_checkpoint.pth") 28 | new_state_dict = OrderedDict() 29 | for k, v in checkpoint["model"].items(): 30 | name = k[9:] # remove `backbone.` 31 | new_state_dict[name] = v 32 | bone.load_state_dict(new_state_dict) 33 | print("load pre dataset parameter over") 34 | if not args.grad: 35 | if 'seresnet' in args.model: 36 | bone.avg_pool = Identical() 37 | bone.last_linear = Identical() 38 | elif 'res' in 
args.model: 39 | bone.global_pool = Identical() 40 | bone.fc = Identical() 41 | elif 'efficient' in args.model: 42 | bone.global_pool = Identical() 43 | bone.classifier = Identical() 44 | elif 'densenet' in args.model: 45 | bone.global_pool = Identical() 46 | bone.classifier = Identical() 47 | elif 'mobilenet' in args.model: 48 | bone.global_pool = Identical() 49 | bone.conv_head = Identical() 50 | bone.act2 = Identical() 51 | bone.classifier = Identical() 52 | return bone 53 | 54 | 55 | class SlotModel(nn.Module): 56 | def __init__(self, args): 57 | super(SlotModel, self).__init__() 58 | self.use_slot = args.use_slot 59 | self.backbone = load_backbone(args) 60 | if self.use_slot: 61 | if 'densenet' in args.model: 62 | self.feature_size = 8 63 | else: 64 | self.feature_size = 9 65 | 66 | self.channel = args.channel 67 | self.slots_per_class = args.slots_per_class 68 | self.conv1x1 = nn.Conv2d(self.channel, args.hidden_dim, kernel_size=(1, 1), stride=(1, 1)) 69 | if args.pre_trained: 70 | self.dfs_freeze(self.backbone, args.freeze_layers) 71 | self.slot = SlotAttention(args.num_classes, self.slots_per_class, args.hidden_dim, vis=args.vis, 72 | vis_id=args.vis_id, loss_status=args.loss_status, power=args.power, to_k_layer=args.to_k_layer) 73 | self.position_emb = build_position_encoding('sine', hidden_dim=args.hidden_dim) 74 | self.lambda_value = float(args.lambda_value) 75 | else: 76 | if args.pre_trained: 77 | self.dfs_freeze(self.backbone, args.freeze_layers) 78 | 79 | def dfs_freeze(self, model, freeze_layer_num): 80 | if freeze_layer_num == 0: 81 | return 82 | 83 | unfreeze_layers = ['layer4', 'layer3', 'layer2', 'layer1'][:4-freeze_layer_num] 84 | for name, child in model.named_children(): 85 | skip = False 86 | for freeze_layer in unfreeze_layers: 87 | if freeze_layer in name: 88 | skip = True 89 | break 90 | if skip: 91 | continue 92 | for param in child.parameters(): 93 | param.requires_grad = False 94 | self.dfs_freeze(child, freeze_layer_num) 95 | 96 | def dfs_freeze_bnorm(self, model): 97 | for name, child in model.named_children(): 98 | if 'bn' not in name: 99 | self.dfs_freeze_bnorm(child) 100 | continue 101 | for param in child.parameters(): 102 | param.requires_grad = False 103 | self.dfs_freeze_bnorm(child) 104 | 105 | def forward(self, x, target=None): 106 | x = self.backbone(x) 107 | if self.use_slot: 108 | x = self.conv1x1(x.view(x.size(0), self.channel, self.feature_size, self.feature_size)) 109 | x = torch.relu(x) 110 | pe = self.position_emb(x) 111 | x_pe = x + pe 112 | 113 | b, n, r, c = x.shape 114 | x = x.reshape((b, n, -1)).permute((0, 2, 1)) 115 | x_pe = x_pe.reshape((b, n, -1)).permute((0, 2, 1)) 116 | x, attn_loss = self.slot(x_pe, x) 117 | output = F.log_softmax(x, dim=1) 118 | 119 | if target is not None: 120 | if self.use_slot: 121 | loss = F.nll_loss(output, target) + self.lambda_value * attn_loss 122 | return [output, [loss, F.nll_loss(output, target), attn_loss]] 123 | else: 124 | loss = F.nll_loss(output, target) 125 | return [output, [loss]] 126 | 127 | return output 128 | -------------------------------------------------------------------------------- /sloter/utils/position_encode.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 
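Adapted from DETR's position_encoding.py (https://github.com/facebookresearch/detr).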
4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class PositionEmbeddingSine(nn.Module): 11 | """ 12 | This is a more standard version of the position embedding, very similar to the one 13 | used by the Attention is all you need paper, generalized to work on images. 14 | """ 15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 16 | super().__init__() 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi 24 | self.scale = scale 25 | 26 | def forward(self, tensor_list): 27 | x = tensor_list 28 | b, c, h, w = x.shape 29 | mask = torch.zeros((b, h, w), dtype=torch.bool, device=x.device) 30 | not_mask = ~mask 31 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 32 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos.to(x.dtype) 47 | 48 | 49 | class PositionEmbeddingLearned(nn.Module): 50 | """ 51 | Absolute pos embedding, learned. 
52 |     """
53 |     def __init__(self, num_pos_feats=256):
54 |         super().__init__()
55 |         self.row_embed = nn.Embedding(50, num_pos_feats)
56 |         self.col_embed = nn.Embedding(50, num_pos_feats)
57 |         self.reset_parameters()
58 |
59 |     def reset_parameters(self):
60 |         nn.init.uniform_(self.row_embed.weight)
61 |         nn.init.uniform_(self.col_embed.weight)
62 |
63 |     def forward(self, tensor_list):
64 |         x = tensor_list
65 |         h, w = x.shape[-2:]
66 |         i = torch.arange(w, device=x.device)
67 |         j = torch.arange(h, device=x.device)
68 |         x_emb = self.col_embed(i)
69 |         y_emb = self.row_embed(j)
70 |         pos = torch.cat([
71 |             x_emb.unsqueeze(0).repeat(h, 1, 1),
72 |             y_emb.unsqueeze(1).repeat(1, w, 1),
73 |         ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
74 |         return pos
75 |
76 |
77 | def build_position_encoding(position_embedding, hidden_dim):
78 |     N_steps = hidden_dim // 2
79 |     if position_embedding in ('v2', 'sine'):
80 |         # TODO find a better way of exposing other arguments
81 |         position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
82 |     elif position_embedding in ('v3', 'learned'):
83 |         position_embedding = PositionEmbeddingLearned(N_steps)
84 |     else:
85 |         raise ValueError(f"not supported {position_embedding}")
86 |
87 |     return position_embedding
88 |
-------------------------------------------------------------------------------- /sloter/utils/slot_attention.py: --------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import math
4 | from PIL import Image
5 | import numpy as np
6 | import torch.nn.functional as F
7 |
8 |
9 | class SlotAttention(nn.Module):
10 |     def __init__(self, num_classes, slots_per_class, dim, iters=3, eps=1e-8, vis=False, vis_id=0, loss_status=1, power=1, to_k_layer=1):
11 |         super().__init__()
12 |         self.num_classes = num_classes
13 |         self.slots_per_class = slots_per_class
14 |         self.num_slots = num_classes * slots_per_class
15 |         self.iters = iters
16 |         self.eps = eps
17 |         self.scale = dim ** -0.5
18 |         self.loss_status = loss_status
19 |
20 |         slots_mu = nn.Parameter(torch.randn(1, 1, dim))
21 |         slots_sigma = nn.Parameter(torch.randn(1, 1, dim))
22 |
23 |         mu = slots_mu.expand(1, self.num_slots, -1)
24 |         sigma = slots_sigma.expand(1, self.num_slots, -1)
25 |         self.initial_slots = nn.Parameter(torch.normal(mu, sigma))
26 |
27 |         self.to_q = nn.Sequential(
28 |             nn.Linear(dim, dim),
29 |         )
30 |         to_k_layer_list = [nn.Linear(dim, dim)]
31 |         for to_k_layer_id in range(1, to_k_layer):
32 |             to_k_layer_list.append(nn.ReLU(inplace=True))
33 |             to_k_layer_list.append(nn.Linear(dim, dim))
34 |
35 |         self.to_k = nn.Sequential(
36 |             *to_k_layer_list
37 |         )
38 |         self.gru = nn.GRU(dim, dim)
39 |
40 |         self.vis = vis
41 |         self.vis_id = vis_id
42 |         self.power = power
43 |
44 |     def forward(self, inputs, inputs_x):
45 |         b, n, d = inputs.shape
46 |         slots = self.initial_slots.expand(b, -1, -1)
47 |         k, v = self.to_k(inputs), inputs
48 |
49 |         for _ in range(self.iters):
50 |             slots_prev = slots
51 |
52 |             # q = self.to_q(slots)
53 |             q = slots
54 |
55 |             dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
56 |             dots = dots / dots.sum(dim=2, keepdim=True) * dots.sum(dim=(1, 2), keepdim=True)  # normalize each slot's row over positions, then restore the overall scale
57 |             attn = torch.sigmoid(dots)  # sigmoid attention over positions, rather than the softmax over slots of standard slot attention
58 |             updates = torch.einsum('bjd,bij->bid', inputs_x, attn)
59 |             updates = updates / inputs_x.size(2)  # scale the update by the feature dimension
60 |             self.gru.flatten_parameters()
61 |             slots, _ = self.gru(
62 |                 updates.reshape(1, -1, d),
63 |                 slots_prev.reshape(1, -1,
d)
64 |             )
65 |
66 |             slots = slots.reshape(b, -1, d)
67 |
68 |         if self.vis:
69 |             slots_vis = attn.clone()
70 |
71 |         if self.vis:
72 |             if self.slots_per_class > 1:
73 |                 new_slots_vis = torch.zeros((slots_vis.size(0), self.num_classes, slots_vis.size(-1)))
74 |                 for slot_class in range(self.num_classes):
75 |                     new_slots_vis[:, slot_class] = torch.sum(torch.cat([slots_vis[:, self.slots_per_class*slot_class: self.slots_per_class*(slot_class+1)]], dim=1), dim=1, keepdim=False)
76 |                 slots_vis = new_slots_vis.to(updates.device)
77 |
78 |             slots_vis = slots_vis[self.vis_id]
79 |             slots_vis = ((slots_vis - slots_vis.min()) / (slots_vis.max()-slots_vis.min()) * 255.).reshape(slots_vis.shape[:1]+(int(slots_vis.size(1)**0.5), int(slots_vis.size(1)**0.5)))
80 |             slots_vis = (slots_vis.cpu().detach().numpy()).astype(np.uint8)
81 |             for id, image in enumerate(slots_vis):
82 |                 image = Image.fromarray(image, mode='L')
83 |                 image.save(f'sloter/vis/slot_{id:d}.png')
84 |             print(self.loss_status*torch.sum(attn.clone(), dim=2, keepdim=False))
85 |             print(self.loss_status*torch.sum(updates.clone(), dim=2, keepdim=False))
86 |
87 |         if self.slots_per_class > 1:
88 |             new_updates = torch.zeros((updates.size(0), self.num_classes, updates.size(-1)))
89 |             for slot_class in range(self.num_classes):
90 |                 new_updates[:, slot_class] = torch.sum(updates[:, self.slots_per_class*slot_class: self.slots_per_class*(slot_class+1)], dim=1, keepdim=False)
91 |             updates = new_updates.to(updates.device)
92 |
93 |         attn_relu = torch.relu(attn)
94 |         slot_loss = torch.sum(attn_relu) / attn.size(0) / attn.size(1) / attn.size(2)  # mean attention value over batch, slots and positions (the area loss)
95 |
96 |         return self.loss_status*torch.sum(updates, dim=2, keepdim=False), torch.pow(slot_loss, self.power)
-------------------------------------------------------------------------------- /sloter/utils/vis.py: --------------------------------------------------------------------------------
1 | from PIL import Image
2 | import matplotlib.cm as mpl_color_map
3 | import copy
4 | import numpy as np
5 |
6 |
7 | def apply_colormap_on_image(org_im, activation, colormap_name):
8 |     """
9 |     Apply heatmap on image
10 |     Args:
11 |         org_im (PIL img): Original image
12 |         activation (numpy arr): Activation map (grayscale) 0-255
13 |         colormap_name (str): Name of the colormap
14 |     """
15 |     # Get colormap
16 |     color_map = mpl_color_map.get_cmap(colormap_name)
17 |     no_trans_heatmap = color_map(activation)
18 |     # Change alpha channel in colormap to make sure original image is displayed
19 |     heatmap = copy.copy(no_trans_heatmap)
20 |     heatmap[:, :, 3] = 0.4
21 |     heatmap = Image.fromarray((heatmap*255).astype(np.uint8))
22 |     no_trans_heatmap = Image.fromarray((no_trans_heatmap*255).astype(np.uint8))
23 |
24 |     # Apply heatmap on image
25 |     heatmap_on_image = Image.new("RGBA", org_im.size)
26 |     heatmap_on_image = Image.alpha_composite(heatmap_on_image, org_im.convert('RGBA'))
27 |     heatmap_on_image = Image.alpha_composite(heatmap_on_image, heatmap)
28 |     return no_trans_heatmap, heatmap_on_image
-------------------------------------------------------------------------------- /timm/__init__.py: --------------------------------------------------------------------------------
1 | from .version import __version__
2 | from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
3 |     is_scriptable, is_exportable, set_scriptable, set_exportable
4 |
-------------------------------------------------------------------------------- /timm/data/__init__.py:
-------------------------------------------------------------------------------- 1 | from .constants import * 2 | from .config import resolve_data_config 3 | from .dataset import Dataset, DatasetTar, AugMixDataset 4 | from .transforms import * 5 | from .loader import create_loader 6 | from .transforms_factory import create_transform 7 | from .mixup import mixup_batch, FastCollateMixup 8 | from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ 9 | rand_augment_transform, auto_augment_transform 10 | -------------------------------------------------------------------------------- /timm/data/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .constants import * 3 | 4 | 5 | def resolve_data_config(args, default_cfg={}, model=None, verbose=True): 6 | new_config = {} 7 | default_cfg = default_cfg 8 | if not default_cfg and model is not None and hasattr(model, 'default_cfg'): 9 | default_cfg = model.default_cfg 10 | 11 | # Resolve input/image size 12 | in_chans = 3 13 | if 'chans' in args and args['chans'] is not None: 14 | in_chans = args['chans'] 15 | 16 | input_size = (in_chans, 224, 224) 17 | if 'input_size' in args and args['input_size'] is not None: 18 | assert isinstance(args['input_size'], (tuple, list)) 19 | assert len(args['input_size']) == 3 20 | input_size = tuple(args['input_size']) 21 | in_chans = input_size[0] # input_size overrides in_chans 22 | elif 'img_size' in args and args['img_size'] is not None: 23 | assert isinstance(args['img_size'], int) 24 | input_size = (in_chans, args['img_size'], args['img_size']) 25 | elif 'input_size' in default_cfg: 26 | input_size = default_cfg['input_size'] 27 | new_config['input_size'] = input_size 28 | 29 | # resolve interpolation method 30 | new_config['interpolation'] = 'bicubic' 31 | if 'interpolation' in args and args['interpolation']: 32 | new_config['interpolation'] = args['interpolation'] 33 | elif 'interpolation' in default_cfg: 34 | new_config['interpolation'] = default_cfg['interpolation'] 35 | 36 | # resolve dataset + model mean for normalization 37 | new_config['mean'] = IMAGENET_DEFAULT_MEAN 38 | if 'mean' in args and args['mean'] is not None: 39 | mean = tuple(args['mean']) 40 | if len(mean) == 1: 41 | mean = tuple(list(mean) * in_chans) 42 | else: 43 | assert len(mean) == in_chans 44 | new_config['mean'] = mean 45 | elif 'mean' in default_cfg: 46 | new_config['mean'] = default_cfg['mean'] 47 | 48 | # resolve dataset + model std deviation for normalization 49 | new_config['std'] = IMAGENET_DEFAULT_STD 50 | if 'std' in args and args['std'] is not None: 51 | std = tuple(args['std']) 52 | if len(std) == 1: 53 | std = tuple(list(std) * in_chans) 54 | else: 55 | assert len(std) == in_chans 56 | new_config['std'] = std 57 | elif 'std' in default_cfg: 58 | new_config['std'] = default_cfg['std'] 59 | 60 | # resolve default crop percentage 61 | new_config['crop_pct'] = DEFAULT_CROP_PCT 62 | if 'crop_pct' in args and args['crop_pct'] is not None: 63 | new_config['crop_pct'] = args['crop_pct'] 64 | elif 'crop_pct' in default_cfg: 65 | new_config['crop_pct'] = default_cfg['crop_pct'] 66 | 67 | if verbose: 68 | logging.info('Data processing configuration for current model + dataset:') 69 | for n, v in new_config.items(): 70 | logging.info('\t%s: %s' % (n, str(v))) 71 | 72 | return new_config 73 | -------------------------------------------------------------------------------- /timm/data/constants.py: 
-------------------------------------------------------------------------------- 1 | DEFAULT_CROP_PCT = 0.875 2 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 3 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 4 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 5 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 6 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 7 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 8 | -------------------------------------------------------------------------------- /timm/data/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import Sampler 4 | import torch.distributed as dist 5 | 6 | 7 | class OrderedDistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | It is especially useful in conjunction with 10 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 11 | process can pass a DistributedSampler instance as a DataLoader sampler, 12 | and load a subset of the original dataset that is exclusive to it. 13 | .. note:: 14 | Dataset is assumed to be of constant size. 15 | Arguments: 16 | dataset: Dataset used for sampling. 17 | num_replicas (optional): Number of processes participating in 18 | distributed training. 19 | rank (optional): Rank of the current process within num_replicas. 20 | """ 21 | 22 | def __init__(self, dataset, num_replicas=None, rank=None): 23 | if num_replicas is None: 24 | if not dist.is_available(): 25 | raise RuntimeError("Requires distributed package to be available") 26 | num_replicas = dist.get_world_size() 27 | if rank is None: 28 | if not dist.is_available(): 29 | raise RuntimeError("Requires distributed package to be available") 30 | rank = dist.get_rank() 31 | self.dataset = dataset 32 | self.num_replicas = num_replicas 33 | self.rank = rank 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | indices = list(range(len(self.dataset))) 39 | 40 | # add extra samples to make it evenly divisible 41 | indices += indices[:(self.total_size - len(indices))] 42 | assert len(indices) == self.total_size 43 | 44 | # subsample 45 | indices = indices[self.rank:self.total_size:self.num_replicas] 46 | assert len(indices) == self.num_samples 47 | 48 | return iter(indices) 49 | 50 | def __len__(self): 51 | return self.num_samples 52 | -------------------------------------------------------------------------------- /timm/data/mixup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 6 | x = x.long().view(-1, 1) 7 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 8 | 9 | 10 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 11 | off_value = smoothing / num_classes 12 | on_value = 1. - smoothing + off_value 13 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 14 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 15 | return lam*y1 + (1. - lam)*y2 16 | 17 | 18 | def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): 19 | lam = 1. 
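    # lam ~ Beta(alpha, alpha); each sample is blended with the batch in reverse
    # order, x_i <- lam * x_i + (1 - lam) * x_{B-1-i}, and the (optionally
    # label-smoothed) one-hot targets are mixed with the same lam. The in-place
    # add_(1 - lam, ...) below is the legacy (alpha, tensor) overload, which the
    # pinned torch==1.6.0 still accepts.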
20 | if not disable: 21 | lam = np.random.beta(alpha, alpha) 22 | input = input.mul(lam).add_(1 - lam, input.flip(0)) 23 | target = mixup_target(target, num_classes, lam, smoothing) 24 | return input, target 25 | 26 | 27 | class FastCollateMixup: 28 | 29 | def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): 30 | self.mixup_alpha = mixup_alpha 31 | self.label_smoothing = label_smoothing 32 | self.num_classes = num_classes 33 | self.mixup_enabled = True 34 | 35 | def __call__(self, batch): 36 | batch_size = len(batch) 37 | lam = 1. 38 | if self.mixup_enabled: 39 | lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) 40 | 41 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 42 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 43 | 44 | tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 45 | for i in range(batch_size): 46 | mixed = batch[i][0].astype(np.float32) * lam + \ 47 | batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) 48 | np.round(mixed, out=mixed) 49 | tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) 50 | 51 | return tensor, target 52 | -------------------------------------------------------------------------------- /timm/data/random_erasing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | 5 | 6 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): 7 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 8 | # paths, flip the order so normal is run on CPU if this becomes a problem 9 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 10 | if per_pixel: 11 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 12 | elif rand_color: 13 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 14 | else: 15 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 16 | 17 | 18 | class RandomErasing: 19 | """ Randomly selects a rectangle region in an image and erases its pixels. 20 | 'Random Erasing Data Augmentation' by Zhong et al. 21 | See https://arxiv.org/pdf/1708.04896.pdf 22 | 23 | This variant of RandomErasing is intended to be applied to either a batch 24 | or single image tensor after it has been normalized by dataset mean and std. 25 | Args: 26 | probability: Probability that the Random Erasing operation will be performed. 27 | min_area: Minimum percentage of erased area wrt input image area. 28 | max_area: Maximum percentage of erased area wrt input image area. 29 | min_aspect: Minimum aspect ratio of erased area. 30 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 31 | 'const' - erase block is constant color of 0 for all channels 32 | 'rand' - erase block is same per-channel random (normal) color 33 | 'pixel' - erase block is per-pixel random (normal) color 34 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 35 | per-image count is randomly chosen between 1 and this value. 
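        num_splits: if > 1, the first batch_size // num_splits samples of each batch
            are left un-erased (keeps a clean split, e.g. for AugMix-style training).
        device: device on which the replacement pixel values are generated.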
36 | """ 37 | 38 | def __init__( 39 | self, 40 | probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, 41 | mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): 42 | self.probability = probability 43 | self.min_area = min_area 44 | self.max_area = max_area 45 | max_aspect = max_aspect or 1 / min_aspect 46 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 47 | self.min_count = min_count 48 | self.max_count = max_count or min_count 49 | self.num_splits = num_splits 50 | mode = mode.lower() 51 | self.rand_color = False 52 | self.per_pixel = False 53 | if mode == 'rand': 54 | self.rand_color = True # per block random normal 55 | elif mode == 'pixel': 56 | self.per_pixel = True # per pixel random normal 57 | else: 58 | assert not mode or mode == 'const' 59 | self.device = device 60 | 61 | def _erase(self, img, chan, img_h, img_w, dtype): 62 | if random.random() > self.probability: 63 | return 64 | area = img_h * img_w 65 | count = self.min_count if self.min_count == self.max_count else \ 66 | random.randint(self.min_count, self.max_count) 67 | for _ in range(count): 68 | for attempt in range(10): 69 | target_area = random.uniform(self.min_area, self.max_area) * area / count 70 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 71 | h = int(round(math.sqrt(target_area * aspect_ratio))) 72 | w = int(round(math.sqrt(target_area / aspect_ratio))) 73 | if w < img_w and h < img_h: 74 | top = random.randint(0, img_h - h) 75 | left = random.randint(0, img_w - w) 76 | img[:, top:top + h, left:left + w] = _get_pixels( 77 | self.per_pixel, self.rand_color, (chan, h, w), 78 | dtype=dtype, device=self.device) 79 | break 80 | 81 | def __call__(self, input): 82 | if len(input.size()) == 3: 83 | self._erase(input, *input.size(), input.dtype) 84 | else: 85 | batch_size, chan, img_h, img_w = input.size() 86 | # skip first slice of batch if num_splits is set (for clean portion of samples) 87 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 88 | for i in range(batch_start, batch_size): 89 | self._erase(input[i], chan, img_h, img_w, input.dtype) 90 | return input 91 | -------------------------------------------------------------------------------- /timm/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 2 | from .jsd import JsdCrossEntropy -------------------------------------------------------------------------------- /timm/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSmoothingCrossEntropy(nn.Module): 7 | """ 8 | NLL loss with label smoothing. 9 | """ 10 | def __init__(self, smoothing=0.1): 11 | """ 12 | Constructor for the LabelSmoothing module. 13 | :param smoothing: label smoothing factor 14 | """ 15 | super(LabelSmoothingCrossEntropy, self).__init__() 16 | assert smoothing < 1.0 17 | self.smoothing = smoothing 18 | self.confidence = 1. 
- smoothing 19 | 20 | def forward(self, x, target): 21 | logprobs = F.log_softmax(x, dim=-1) 22 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 23 | nll_loss = nll_loss.squeeze(1) 24 | smooth_loss = -logprobs.mean(dim=-1) 25 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 26 | return loss.mean() 27 | 28 | 29 | class SoftTargetCrossEntropy(nn.Module): 30 | 31 | def __init__(self): 32 | super(SoftTargetCrossEntropy, self).__init__() 33 | 34 | def forward(self, x, target): 35 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 36 | return loss.mean() 37 | -------------------------------------------------------------------------------- /timm/loss/jsd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .cross_entropy import LabelSmoothingCrossEntropy 6 | 7 | 8 | class JsdCrossEntropy(nn.Module): 9 | """ Jensen-Shannon Divergence + Cross-Entropy Loss 10 | 11 | Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 12 | From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 13 | https://arxiv.org/abs/1912.02781 14 | 15 | Hacked together by Ross Wightman 16 | """ 17 | def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 18 | super().__init__() 19 | self.num_splits = num_splits 20 | self.alpha = alpha 21 | if smoothing is not None and smoothing > 0: 22 | self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 23 | else: 24 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 25 | 26 | def __call__(self, output, target): 27 | split_size = output.shape[0] // self.num_splits 28 | assert split_size * self.num_splits == output.shape[0] 29 | logits_split = torch.split(output, split_size) 30 | 31 | # Cross-entropy is only computed on clean images 32 | loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 33 | probs = [F.softmax(logits, dim=1) for logits in logits_split] 34 | 35 | # Clamp mixture distribution to avoid exploding KL divergence 36 | logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 37 | loss += self.alpha * sum([F.kl_div( 38 | logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 39 | return loss 40 | -------------------------------------------------------------------------------- /timm/loss/slot_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class SoltLoss(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def __call__(self, x, target): 10 | loss = F.nll_loss(x[0], target) + x[1] 11 | return loss -------------------------------------------------------------------------------- /timm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .inception_v4 import * 2 | from .inception_resnet_v2 import * 3 | from .densenet import * 4 | from .resnet import * 5 | from .dpn import * 6 | from .senet import * 7 | from .xception import * 8 | from .nasnet import * 9 | from .pnasnet import * 10 | from .selecsls import * 11 | from .efficientnet import * 12 | from .mobilenetv3 import * 13 | from .inception_v3 import * 14 | from .gluon_resnet import * 15 | from .gluon_xception import * 16 | from .res2net import * 17 | from .dla import * 18 | from .hrnet import * 19 | from .sknet import * 20 | from .tresnet 
import * 21 | from .resnest import * 22 | from .regnet import * 23 | from .vovnet import * 24 | 25 | from .registry import * 26 | from .factory import create_model 27 | from .helpers import load_checkpoint, resume_checkpoint 28 | from .layers import TestTimePoolHead, apply_test_time_pool 29 | from .layers import convert_splitbn_model 30 | from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit 31 | -------------------------------------------------------------------------------- /timm/models/factory.py: -------------------------------------------------------------------------------- 1 | from .registry import is_model, is_model_in_modules, model_entrypoint 2 | from .helpers import load_checkpoint 3 | from .layers import set_layer_config 4 | 5 | 6 | def create_model( 7 | model_name, 8 | pretrained=False, 9 | num_classes=1000, 10 | in_chans=3, 11 | checkpoint_path='', 12 | scriptable=None, 13 | exportable=None, 14 | no_jit=None, 15 | **kwargs): 16 | """Create a model 17 | 18 | Args: 19 | model_name (str): name of model to instantiate 20 | pretrained (bool): load pretrained ImageNet-1k weights if true 21 | num_classes (int): number of classes for final fully connected layer (default: 1000) 22 | in_chans (int): number of input channels / colors (default: 3) 23 | checkpoint_path (str): path of checkpoint to load after model is initialized 24 | scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) 25 | exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) 26 | no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) 27 | 28 | Keyword Args: 29 | drop_rate (float): dropout rate for training (default: 0.0) 30 | global_pool (str): global pool type (default: 'avg') 31 | **: other kwargs are model specific 32 | """ 33 | model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) 34 | 35 | # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args 36 | is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) 37 | if not is_efficientnet: 38 | kwargs.pop('bn_tf', None) 39 | kwargs.pop('bn_momentum', None) 40 | kwargs.pop('bn_eps', None) 41 | 42 | # Parameters that aren't supported by all models should default to None in command line args, 43 | # remove them if they are present and not set so that non-supporting models don't break. 44 | if kwargs.get('drop_block_rate', None) is None: 45 | kwargs.pop('drop_block_rate', None) 46 | 47 | # handle backwards compat with drop_connect -> drop_path change 48 | drop_connect_rate = kwargs.pop('drop_connect_rate', None) 49 | if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: 50 | print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." 51 | " Setting drop_path to %f." 
% drop_connect_rate) 52 | kwargs['drop_path_rate'] = drop_connect_rate 53 | 54 | if kwargs.get('drop_path_rate', None) is None: 55 | kwargs.pop('drop_path_rate', None) 56 | 57 | with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): 58 | if is_model(model_name): 59 | create_fn = model_entrypoint(model_name) 60 | model = create_fn(**model_args, **kwargs) 61 | else: 62 | raise RuntimeError('Unknown model (%s)' % model_name) 63 | 64 | if checkpoint_path: 65 | load_checkpoint(model, checkpoint_path) 66 | 67 | return model 68 | -------------------------------------------------------------------------------- /timm/models/feature_hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from collections import defaultdict, OrderedDict 4 | from functools import partial 5 | from typing import List 6 | 7 | 8 | class FeatureHooks: 9 | 10 | def __init__(self, hooks, named_modules): 11 | # setup feature hooks 12 | modules = {k: v for k, v in named_modules} 13 | for h in hooks: 14 | hook_name = h['name'] 15 | m = modules[hook_name] 16 | hook_fn = partial(self._collect_output_hook, hook_name) 17 | if h['type'] == 'forward_pre': 18 | m.register_forward_pre_hook(hook_fn) 19 | elif h['type'] == 'forward': 20 | m.register_forward_hook(hook_fn) 21 | else: 22 | assert False, "Unsupported hook type" 23 | self._feature_outputs = defaultdict(OrderedDict) 24 | 25 | def _collect_output_hook(self, name, *args): 26 | x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre 27 | if isinstance(x, tuple): 28 | x = x[0] # unwrap input tuple 29 | self._feature_outputs[x.device][name] = x 30 | 31 | def get_output(self, device) -> List[torch.tensor]: 32 | output = list(self._feature_outputs[device].values()) 33 | self._feature_outputs[device] = OrderedDict() # clear after reading 34 | return output 35 | -------------------------------------------------------------------------------- /timm/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .anti_aliasing import AntiAliasDownsampleLayer 5 | from .blur_pool import BlurPool2d 6 | from .cond_conv2d import CondConv2d, get_condconv_initializer 7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 8 | set_layer_config 9 | from .conv2d_same import Conv2dSame 10 | from .conv_bn_act import ConvBnAct 11 | from .create_act import create_act_layer, get_act_layer, get_act_fn 12 | from .create_attn import create_attn 13 | from .create_conv2d import create_conv2d 14 | from .create_norm_act import create_norm_act, get_norm_act_layer 15 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 16 | from .eca import EcaModule, CecaModule 17 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 18 | from .inplace_abn import InplaceAbn 19 | from .mixed_conv2d import MixedConv2d 20 | from .norm_act import BatchNormAct2d 21 | from .padding import get_padding 22 | from .pool2d_same import AvgPool2dSame, create_pool2d 23 | from .se import SEModule 24 | from .selective_kernel import SelectiveKernelConv 25 | from .separable_conv import SeparableConv2d, SeparableConvBnAct 26 | from .space_to_depth import SpaceToDepthModule 27 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 28 
| from .test_time_pool import TestTimePoolHead, apply_test_time_pool
29 | from .weight_init import trunc_normal_
30 |
-------------------------------------------------------------------------------- /timm/models/layers/activations.py: --------------------------------------------------------------------------------
1 | """ Activations
2 |
3 | A collection of activation functions and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 |
6 | Hacked together by Ross Wightman
7 | """
8 |
9 | import torch
10 | from torch import nn as nn
11 | from torch.nn import functional as F
12 |
13 |
14 | def swish(x, inplace: bool = False):
15 |     """Swish - Described in: https://arxiv.org/abs/1710.05941
16 |     """
17 |     return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
18 |
19 |
20 | class Swish(nn.Module):
21 |     def __init__(self, inplace: bool = False):
22 |         super(Swish, self).__init__()
23 |         self.inplace = inplace
24 |
25 |     def forward(self, x):
26 |         return swish(x, self.inplace)
27 |
28 |
29 | def mish(x, inplace: bool = False):
30 |     """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
31 |     NOTE: I don't have a working inplace variant
32 |     """
33 |     return x.mul(F.softplus(x).tanh())
34 |
35 |
36 | class Mish(nn.Module):
37 |     """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
38 |     """
39 |     def __init__(self, inplace: bool = False):
40 |         super(Mish, self).__init__()
41 |
42 |     def forward(self, x):
43 |         return mish(x)
44 |
45 |
46 | def sigmoid(x, inplace: bool = False):
47 |     return x.sigmoid_() if inplace else x.sigmoid()
48 |
49 |
50 | # PyTorch has this, but not with a consistent inplace argument interface
51 | class Sigmoid(nn.Module):
52 |     def __init__(self, inplace: bool = False):
53 |         super(Sigmoid, self).__init__()
54 |         self.inplace = inplace
55 |
56 |     def forward(self, x):
57 |         return x.sigmoid_() if self.inplace else x.sigmoid()
58 |
59 |
60 | def tanh(x, inplace: bool = False):
61 |     return x.tanh_() if inplace else x.tanh()
62 |
63 |
64 | # PyTorch has this, but not with a consistent inplace argument interface
65 | class Tanh(nn.Module):
66 |     def __init__(self, inplace: bool = False):
67 |         super(Tanh, self).__init__()
68 |         self.inplace = inplace
69 |
70 |     def forward(self, x):
71 |         return x.tanh_() if self.inplace else x.tanh()
72 |
73 |
74 | def hard_swish(x, inplace: bool = False):
75 |     inner = F.relu6(x + 3.).div_(6.)
76 |     return x.mul_(inner) if inplace else x.mul(inner)
77 |
78 |
79 | class HardSwish(nn.Module):
80 |     def __init__(self, inplace: bool = False):
81 |         super(HardSwish, self).__init__()
82 |         self.inplace = inplace
83 |
84 |     def forward(self, x):
85 |         return hard_swish(x, self.inplace)
86 |
87 |
88 | def hard_sigmoid(x, inplace: bool = False):
89 |     if inplace:
90 |         return x.add_(3.).clamp_(0., 6.).div_(6.)
91 |     else:
92 |         return F.relu6(x + 3.) / 6.
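# Illustrative sanity check (not part of the library): hard_swish above is the
# MobileNetV3-style approximation x * hard_sigmoid(x) of swish(x) = x * sigmoid(x):
#
#     x = torch.linspace(-4., 4., steps=9)
#     assert torch.allclose(hard_swish(x), x * hard_sigmoid(x))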
93 | 94 | 95 | class HardSigmoid(nn.Module): 96 | def __init__(self, inplace: bool = False): 97 | super(HardSigmoid, self).__init__() 98 | self.inplace = inplace 99 | 100 | def forward(self, x): 101 | return hard_sigmoid(x, self.inplace) 102 | 103 | 104 | def hard_mish(x, inplace: bool = False): 105 | """ Hard Mish 106 | Experimental, based on notes by Mish author Diganta Misra at 107 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 108 | """ 109 | if inplace: 110 | return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) 111 | else: 112 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 113 | 114 | 115 | class HardMish(nn.Module): 116 | def __init__(self, inplace: bool = False): 117 | super(HardMish, self).__init__() 118 | self.inplace = inplace 119 | 120 | def forward(self, x): 121 | return hard_mish(x, self.inplace) 122 | -------------------------------------------------------------------------------- /timm/models/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 9 | 10 | Hacked together by Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
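# Illustrative sketch (an added example, not from the original file): the
# scripted variants are numerically identical to the eager ones in
# activations.py; only JIT fusion behavior differs. For example:
#
#   x = torch.randn(8)
#   torch.allclose(hard_swish_jit(x), x * F.relu6(x + 3.) / 6.)  # -> True
#
# Keeping these bodies free of in-place ops is deliberate: as the module
# docstring notes, scripted kernel fusion stops at in-place op boundaries.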
66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /timm/models/layers/adaptive_avgmax_pool.py: -------------------------------------------------------------------------------- 1 | """ PyTorch selectable adaptive pooling 2 | Adaptive pooling with the ability to select the type of pooling from: 3 | * 'avg' - Average pooling 4 | * 'max' - Max pooling 5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 6 | * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim 7 | 8 | Both a functional and a nn.Module version of the pooling is provided. 9 | 10 | Author: Ross Wightman (rwightman) 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | def adaptive_pool_feat_mult(pool_type='avg'): 18 | if pool_type == 'catavgmax': 19 | return 2 20 | else: 21 | return 1 22 | 23 | 24 | def adaptive_avgmax_pool2d(x, output_size=1): 25 | x_avg = F.adaptive_avg_pool2d(x, output_size) 26 | x_max = F.adaptive_max_pool2d(x, output_size) 27 | return 0.5 * (x_avg + x_max) 28 | 29 | 30 | def adaptive_catavgmax_pool2d(x, output_size=1): 31 | x_avg = F.adaptive_avg_pool2d(x, output_size) 32 | x_max = F.adaptive_max_pool2d(x, output_size) 33 | return torch.cat((x_avg, x_max), 1) 34 | 35 | 36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1): 37 | """Selectable global pooling function with dynamic input kernel size 38 | """ 39 | if pool_type == 'avg': 40 | x = F.adaptive_avg_pool2d(x, output_size) 41 | elif pool_type == 'avgmax': 42 | x = adaptive_avgmax_pool2d(x, output_size) 43 | elif pool_type == 'catavgmax': 44 | x = adaptive_catavgmax_pool2d(x, output_size) 45 | elif pool_type == 'max': 46 | x = F.adaptive_max_pool2d(x, output_size) 47 | else: 48 | assert False, 'Invalid pool type: %s' % pool_type 49 | return x 50 | 51 | 52 | class AdaptiveAvgMaxPool2d(nn.Module): 53 | def __init__(self, output_size=1): 54 | super(AdaptiveAvgMaxPool2d, self).__init__() 55 | self.output_size = output_size 56 | 57 | def forward(self, x): 58 | return adaptive_avgmax_pool2d(x, self.output_size) 59 | 60 | 61 | class AdaptiveCatAvgMaxPool2d(nn.Module): 62 | def __init__(self, output_size=1): 63 | super(AdaptiveCatAvgMaxPool2d, self).__init__() 64 | self.output_size = output_size 65 | 66 | def forward(self, x): 67 | return adaptive_catavgmax_pool2d(x, self.output_size) 68 | 69 | 70 | class SelectAdaptivePool2d(nn.Module): 71 | """Selectable global pooling layer with dynamic input kernel size 72 | """ 73 | def __init__(self, output_size=1, pool_type='avg', flatten=False): 74 | super(SelectAdaptivePool2d, self).__init__() 75 | self.output_size = output_size 76 | self.pool_type = pool_type 77 | self.flatten = flatten 78 | if pool_type == 'avgmax': 79 | self.pool =
AdaptiveAvgMaxPool2d(output_size) 80 | elif pool_type == 'catavgmax': 81 | self.pool = AdaptiveCatAvgMaxPool2d(output_size) 82 | elif pool_type == 'max': 83 | self.pool = nn.AdaptiveMaxPool2d(output_size) 84 | else: 85 | if pool_type != 'avg': 86 | assert False, 'Invalid pool type: %s' % pool_type 87 | self.pool = nn.AdaptiveAvgPool2d(output_size) 88 | 89 | def forward(self, x): 90 | x = self.pool(x) 91 | if self.flatten: 92 | x = x.flatten(1) 93 | return x 94 | 95 | def feat_mult(self): 96 | return adaptive_pool_feat_mult(self.pool_type) 97 | 98 | def __repr__(self): 99 | return self.__class__.__name__ + ' (' \ 100 | + 'output_size=' + str(self.output_size) \ 101 | + ', pool_type=' + self.pool_type + ')' 102 | -------------------------------------------------------------------------------- /timm/models/layers/anti_aliasing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.parallel 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AntiAliasDownsampleLayer(nn.Module): 8 | def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False): 9 | super(AntiAliasDownsampleLayer, self).__init__() 10 | if no_jit: 11 | self.op = Downsample(channels, filt_size, stride) 12 | else: 13 | self.op = DownsampleJIT(channels, filt_size, stride) 14 | 15 | # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls 16 | 17 | def forward(self, x): 18 | return self.op(x) 19 | 20 | 21 | @torch.jit.script 22 | class DownsampleJIT(object): 23 | def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2): 24 | self.channels = channels 25 | self.stride = stride 26 | self.filt_size = filt_size 27 | assert self.filt_size == 3 28 | assert stride == 2 29 | self.filt = {} # lazy init by device for DataParallel compat 30 | 31 | def _create_filter(self, like: torch.Tensor): 32 | filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device) 33 | filt = filt[:, None] * filt[None, :] 34 | filt = filt / torch.sum(filt) 35 | return filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 36 | 37 | def __call__(self, input: torch.Tensor): 38 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 39 | filt = self.filt.get(str(input.device), self._create_filter(input)) 40 | return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1]) 41 | 42 | 43 | class Downsample(nn.Module): 44 | def __init__(self, channels=None, filt_size=3, stride=2): 45 | super(Downsample, self).__init__() 46 | self.channels = channels 47 | self.filt_size = filt_size 48 | self.stride = stride 49 | 50 | assert self.filt_size == 3 51 | filt = torch.tensor([1., 2., 1.]) 52 | filt = filt[:, None] * filt[None, :] 53 | filt = filt / torch.sum(filt) 54 | 55 | # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) 56 | self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) 57 | 58 | def forward(self, input): 59 | input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') 60 | return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) 61 | -------------------------------------------------------------------------------- /timm/models/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 
5 | 6 | FIXME merge this impl with those in `anti_aliasing.py` 7 | 8 | Hacked together by Chris Ha and Ross Wightman 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | from typing import Dict 16 | from .padding import get_padding 17 | 18 | 19 | class BlurPool2d(nn.Module): 20 | r"""Creates a module that blurs and downsamples a given feature map. 21 | See :cite:`zhang2019shiftinvar` for more details. 22 | Corresponds to the Downsample class, which does blurring and subsampling 23 | 24 | Args: 25 | channels (int): Number of input channels 26 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 27 | stride (int): downsampling filter stride 28 | 29 | Returns: 30 | torch.Tensor: the transformed tensor. 31 | """ 32 | filt: Dict[str, torch.Tensor] 33 | 34 | def __init__(self, channels, filt_size=3, stride=2) -> None: 35 | super(BlurPool2d, self).__init__() 36 | assert filt_size > 1 37 | self.channels = channels 38 | self.filt_size = filt_size 39 | self.stride = stride 40 | pad_size = [get_padding(filt_size, stride, dilation=1)] * 4 41 | self.padding = nn.ReflectionPad2d(pad_size) 42 | self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat 43 | self.filt = {} # lazy init by device for DataParallel compat 44 | 45 | def _create_filter(self, like: torch.Tensor): 46 | blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device) 47 | return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1) 48 | 49 | def _apply(self, fn): 50 | # override nn.Module _apply to reset the filter cache; return the module so .to()/.cuda() keep working 51 | self.filt = {} 52 | return super(BlurPool2d, self)._apply(fn) 53 | 54 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 55 | C = input_tensor.shape[1] 56 | blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor)) 57 | return F.conv2d( 58 | self.padding(input_tensor), blur_filt, stride=self.stride, groups=C) 59 | -------------------------------------------------------------------------------- /timm/models/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 7 | 8 | Hacked together by Ross Wightman 9 | """ 10 | 11 | import torch 12 | from torch import nn as nn 13 | from .conv_bn_act import ConvBnAct 14 | 15 | 16 | class ChannelAttn(nn.Module): 17 | """ Original CBAM channel attention module, currently avg + max pool variant only.
18 | """ 19 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU): 20 | super(ChannelAttn, self).__init__() 21 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 22 | self.max_pool = nn.AdaptiveMaxPool2d(1) 23 | self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) 24 | self.act = act_layer(inplace=True) 25 | self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) 26 | 27 | def forward(self, x): 28 | x_avg = self.avg_pool(x) 29 | x_max = self.max_pool(x) 30 | x_avg = self.fc2(self.act(self.fc1(x_avg))) 31 | x_max = self.fc2(self.act(self.fc1(x_max))) 32 | x_attn = x_avg + x_max 33 | return x * x_attn.sigmoid() 34 | 35 | 36 | class LightChannelAttn(ChannelAttn): 37 | """An experimental 'lightweight' that sums avg + max pool first 38 | """ 39 | def __init__(self, channels, reduction=16): 40 | super(LightChannelAttn, self).__init__(channels, reduction) 41 | 42 | def forward(self, x): 43 | x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x) 44 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 45 | return x * x_attn.sigmoid() 46 | 47 | 48 | class SpatialAttn(nn.Module): 49 | """ Original CBAM spatial attention module 50 | """ 51 | def __init__(self, kernel_size=7): 52 | super(SpatialAttn, self).__init__() 53 | self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) 54 | 55 | def forward(self, x): 56 | x_avg = torch.mean(x, dim=1, keepdim=True) 57 | x_max = torch.max(x, dim=1, keepdim=True)[0] 58 | x_attn = torch.cat([x_avg, x_max], dim=1) 59 | x_attn = self.conv(x_attn) 60 | return x * x_attn.sigmoid() 61 | 62 | 63 | class LightSpatialAttn(nn.Module): 64 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 65 | """ 66 | def __init__(self, kernel_size=7): 67 | super(LightSpatialAttn, self).__init__() 68 | self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) 69 | 70 | def forward(self, x): 71 | x_avg = torch.mean(x, dim=1, keepdim=True) 72 | x_max = torch.max(x, dim=1, keepdim=True)[0] 73 | x_attn = 0.5 * x_avg + 0.5 * x_max 74 | x_attn = self.conv(x_attn) 75 | return x * x_attn.sigmoid() 76 | 77 | 78 | class CbamModule(nn.Module): 79 | def __init__(self, channels, spatial_kernel_size=7): 80 | super(CbamModule, self).__init__() 81 | self.channel = ChannelAttn(channels) 82 | self.spatial = SpatialAttn(spatial_kernel_size) 83 | 84 | def forward(self, x): 85 | x = self.channel(x) 86 | x = self.spatial(x) 87 | return x 88 | 89 | 90 | class LightCbamModule(nn.Module): 91 | def __init__(self, channels, spatial_kernel_size=7): 92 | super(LightCbamModule, self).__init__() 93 | self.channel = LightChannelAttn(channels) 94 | self.spatial = LightSpatialAttn(spatial_kernel_size) 95 | 96 | def forward(self, x): 97 | x = self.channel(x) 98 | x = self.spatial(x) 99 | return x 100 | 101 | -------------------------------------------------------------------------------- /timm/models/layers/cond_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Conditionally Parameterized Convolution (CondConv) 2 | 3 | Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference 4 | (https://arxiv.org/abs/1904.04971) 5 | 6 | Hacked together by Ross Wightman 7 | """ 8 | 9 | import math 10 | from functools import partial 11 | import numpy as np 12 | import torch 13 | from torch import nn as nn 14 | from torch.nn import functional as F 15 | 16 | from .helpers import tup_pair 17 | from .conv2d_same import conv2d_same 18 | from .padding import get_padding_value 19 | 20 | 21 | def 
get_condconv_initializer(initializer, num_experts, expert_shape): 22 | def condconv_initializer(weight): 23 | """CondConv initializer function.""" 24 | num_params = np.prod(expert_shape) 25 | if (len(weight.shape) != 2 or weight.shape[0] != num_experts or 26 | weight.shape[1] != num_params): 27 | raise (ValueError( 28 | 'CondConv variables must have shape [num_experts, num_params]')) 29 | for i in range(num_experts): 30 | initializer(weight[i].view(expert_shape)) 31 | return condconv_initializer 32 | 33 | 34 | class CondConv2d(nn.Module): 35 | """ Conditionally Parameterized Convolution 36 | Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py 37 | 38 | Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: 39 | https://github.com/pytorch/pytorch/issues/17983 40 | """ 41 | __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] 42 | 43 | def __init__(self, in_channels, out_channels, kernel_size=3, 44 | stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): 45 | super(CondConv2d, self).__init__() 46 | 47 | self.in_channels = in_channels 48 | self.out_channels = out_channels 49 | self.kernel_size = tup_pair(kernel_size) 50 | self.stride = tup_pair(stride) 51 | padding_val, is_padding_dynamic = get_padding_value( 52 | padding, kernel_size, stride=stride, dilation=dilation) 53 | self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript 54 | self.padding = tup_pair(padding_val) 55 | self.dilation = tup_pair(dilation) 56 | self.groups = groups 57 | self.num_experts = num_experts 58 | 59 | self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size 60 | weight_num_param = 1 61 | for wd in self.weight_shape: 62 | weight_num_param *= wd 63 | self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) 64 | 65 | if bias: 66 | self.bias_shape = (self.out_channels,) 67 | self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) 68 | else: 69 | self.register_parameter('bias', None) 70 | 71 | self.reset_parameters() 72 | 73 | def reset_parameters(self): 74 | init_weight = get_condconv_initializer( 75 | partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) 76 | init_weight(self.weight) 77 | if self.bias is not None: 78 | fan_in = np.prod(self.weight_shape[1:]) 79 | bound = 1 / math.sqrt(fan_in) 80 | init_bias = get_condconv_initializer( 81 | partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) 82 | init_bias(self.bias) 83 | 84 | def forward(self, x, routing_weights): 85 | B, C, H, W = x.shape 86 | weight = torch.matmul(routing_weights, self.weight) 87 | new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size 88 | weight = weight.view(new_weight_shape) 89 | bias = None 90 | if self.bias is not None: 91 | bias = torch.matmul(routing_weights, self.bias) 92 | bias = bias.view(B * self.out_channels) 93 | # move batch elements with channels so each batch element can be efficiently convolved with separate kernel 94 | x = x.view(1, B * C, H, W) 95 | if self.dynamic_padding: 96 | out = conv2d_same( 97 | x, weight, bias, stride=self.stride, padding=self.padding, 98 | dilation=self.dilation, groups=self.groups * B) 99 | else: 100 | out = F.conv2d( 101 | x, weight, bias, stride=self.stride, padding=self.padding, 102 | dilation=self.dilation, groups=self.groups * B) 103 
| out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) 104 | 105 | # Literal port (from TF definition) 106 | # x = torch.split(x, 1, 0) 107 | # weight = torch.split(weight, 1, 0) 108 | # if self.bias is not None: 109 | # bias = torch.matmul(routing_weights, self.bias) 110 | # bias = torch.split(bias, 1, 0) 111 | # else: 112 | # bias = [None] * B 113 | # out = [] 114 | # for xi, wi, bi in zip(x, weight, bias): 115 | # wi = wi.view(*self.weight_shape) 116 | # if bi is not None: 117 | # bi = bi.view(*self.bias_shape) 118 | # out.append(self.conv_fn( 119 | # xi, wi, bi, stride=self.stride, padding=self.padding, 120 | # dilation=self.dilation, groups=self.groups)) 121 | # out = torch.cat(out, 0) 122 | return out 123 | -------------------------------------------------------------------------------- /timm/models/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 
85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /timm/models/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /timm/models/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .create_conv2d import create_conv2d 8 | from .create_norm_act import convert_norm_act_type 9 | 10 | 11 | class ConvBnAct(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 13 | norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True, 14 | drop_block=None, aa_layer=None): 15 | super(ConvBnAct, self).__init__() 16 | use_aa = aa_layer is not None 17 | self.conv = create_conv2d( 18 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 19 | 
padding=padding, dilation=dilation, groups=groups, bias=False) 20 | 21 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 22 | norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) 23 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) 24 | self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = self.bn(x) 29 | if self.aa is not None: 30 | x = self.aa(x) 31 | return x 32 | -------------------------------------------------------------------------------- /timm/models/layers/create_act.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .activations_jit import * 3 | from .activations_me import * 4 | from .config import is_exportable, is_scriptable, is_no_jit 5 | 6 | 7 | _ACT_FN_DEFAULT = dict( 8 | swish=swish, 9 | mish=mish, 10 | relu=F.relu, 11 | relu6=F.relu6, 12 | leaky_relu=F.leaky_relu, 13 | elu=F.elu, 14 | prelu=F.prelu, 15 | celu=F.celu, 16 | selu=F.selu, 17 | gelu=F.gelu, 18 | sigmoid=sigmoid, 19 | tanh=tanh, 20 | hard_sigmoid=hard_sigmoid, 21 | hard_swish=hard_swish, 22 | hard_mish=hard_mish, 23 | ) 24 | 25 | _ACT_FN_JIT = dict( 26 | swish=swish_jit, 27 | mish=mish_jit, 28 | hard_sigmoid=hard_sigmoid_jit, 29 | hard_swish=hard_swish_jit, 30 | hard_mish=hard_mish_jit 31 | ) 32 | 33 | _ACT_FN_ME = dict( 34 | swish=swish_me, 35 | mish=mish_me, 36 | hard_sigmoid=hard_sigmoid_me, 37 | hard_swish=hard_swish_me, 38 | hard_mish=hard_mish_me, 39 | ) 40 | 41 | _ACT_LAYER_DEFAULT = dict( 42 | swish=Swish, 43 | mish=Mish, 44 | relu=nn.ReLU, 45 | relu6=nn.ReLU6, 46 | elu=nn.ELU, 47 | prelu=nn.PReLU, 48 | celu=nn.CELU, 49 | selu=nn.SELU, 50 | gelu=nn.GELU, 51 | sigmoid=Sigmoid, 52 | tanh=Tanh, 53 | hard_sigmoid=HardSigmoid, 54 | hard_swish=HardSwish, 55 | hard_mish=HardMish, 56 | ) 57 | 58 | _ACT_LAYER_JIT = dict( 59 | swish=SwishJit, 60 | mish=MishJit, 61 | hard_sigmoid=HardSigmoidJit, 62 | hard_swish=HardSwishJit, 63 | hard_mish=HardMishJit 64 | ) 65 | 66 | _ACT_LAYER_ME = dict( 67 | swish=SwishMe, 68 | mish=MishMe, 69 | hard_sigmoid=HardSigmoidMe, 70 | hard_swish=HardSwishMe, 71 | hard_mish=HardMishMe, 72 | ) 73 | 74 | 75 | def get_act_fn(name='relu'): 76 | """ Activation Function Factory 77 | Fetching activation fns by name with this function allows export or torch script friendly 78 | functions to be returned dynamically based on current config. 79 | """ 80 | if not name: 81 | return None 82 | if not (is_no_jit() or is_exportable() or is_scriptable()): 83 | # If not exporting or scripting the model, first look for a memory-efficient version with 84 | # custom autograd, then fallback 85 | if name in _ACT_FN_ME: 86 | return _ACT_FN_ME[name] 87 | if not is_no_jit(): 88 | if name in _ACT_FN_JIT: 89 | return _ACT_FN_JIT[name] 90 | return _ACT_FN_DEFAULT[name] 91 | 92 | 93 | def get_act_layer(name='relu'): 94 | """ Activation Layer Factory 95 | Fetching activation layers by name with this function allows export or torch script friendly 96 | functions to be returned dynamically based on current config. 
97 | """ 98 | if not name: 99 | return None 100 | if not (is_no_jit() or is_exportable() or is_scriptable()): 101 | if name in _ACT_LAYER_ME: 102 | return _ACT_LAYER_ME[name] 103 | if not is_no_jit(): 104 | if name in _ACT_LAYER_JIT: 105 | return _ACT_LAYER_JIT[name] 106 | return _ACT_LAYER_DEFAULT[name] 107 | 108 | 109 | def create_act_layer(name, inplace=False, **kwargs): 110 | act_layer = get_act_layer(name) 111 | if act_layer is not None: 112 | return act_layer(inplace=inplace, **kwargs) 113 | else: 114 | return None 115 | -------------------------------------------------------------------------------- /timm/models/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Select AttentionFactory Method 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | import torch 6 | from .se import SEModule, EffectiveSEModule 7 | from .eca import EcaModule, CecaModule 8 | from .cbam import CbamModule, LightCbamModule 9 | 10 | 11 | def create_attn(attn_type, channels, **kwargs): 12 | module_cls = None 13 | if attn_type is not None: 14 | if isinstance(attn_type, str): 15 | attn_type = attn_type.lower() 16 | if attn_type == 'se': 17 | module_cls = SEModule 18 | elif attn_type == 'ese': 19 | module_cls = EffectiveSEModule 20 | elif attn_type == 'eca': 21 | module_cls = EcaModule 22 | elif attn_type == 'ceca': 23 | module_cls = CecaModule 24 | elif attn_type == 'cbam': 25 | module_cls = CbamModule 26 | elif attn_type == 'lcbam': 27 | module_cls = LightCbamModule 28 | else: 29 | assert False, "Invalid attn module (%s)" % attn_type 30 | elif isinstance(attn_type, bool): 31 | if attn_type: 32 | module_cls = SEModule 33 | else: 34 | module_cls = attn_type 35 | if module_cls is not None: 36 | return module_cls(channels, **kwargs) 37 | return None 38 | -------------------------------------------------------------------------------- /timm/models/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | assert 'groups' not in kwargs # MixedConv groups are defined by kernel list 20 | # We're going to use only lists for defining the MixedConv2d kernel groups, 21 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
22 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 23 | else: 24 | depthwise = kwargs.pop('depthwise', False) 25 | groups = out_channels if depthwise else kwargs.pop('groups', 1) 26 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 27 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 28 | else: 29 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 30 | return m 31 | -------------------------------------------------------------------------------- /timm/models/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | import types 2 | import functools 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .evo_norm import EvoNormBatch2d, EvoNormSample2d 8 | from .norm_act import BatchNormAct2d, GroupNormAct 9 | from .inplace_abn import InplaceAbn 10 | 11 | _NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} 12 | 13 | 14 | def get_norm_act_layer(layer_class): 15 | layer_class = layer_class.replace('_', '').lower() 16 | if layer_class.startswith("batchnorm"): 17 | layer = BatchNormAct2d 18 | elif layer_class.startswith("groupnorm"): 19 | layer = GroupNormAct 20 | elif layer_class == "evonormbatch": 21 | layer = EvoNormBatch2d 22 | elif layer_class == "evonormsample": 23 | layer = EvoNormSample2d 24 | elif layer_class == "iabn" or layer_class == "inplaceabn": 25 | layer = InplaceAbn 26 | else: 27 | assert False, "Invalid norm_act layer (%s)" % layer_class 28 | return layer 29 | 30 | 31 | def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): 32 | layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu 33 | assert len(layer_parts) in (1, 2) 34 | layer = get_norm_act_layer(layer_parts[0]) 35 | #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? 36 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 37 | if jit: 38 | layer_instance = torch.jit.script(layer_instance) 39 | return layer_instance 40 | 41 | 42 | def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None): 43 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 44 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 45 | norm_act_args = norm_kwargs.copy() if norm_kwargs else {} 46 | if isinstance(norm_layer, str): 47 | norm_act_layer = get_norm_act_layer(norm_layer) 48 | elif norm_layer in _NORM_ACT_TYPES: 49 | norm_act_layer = norm_layer 50 | elif isinstance(norm_layer, (types.FunctionType, functools.partial)): 51 | # assuming this is a lambda/fn/bound partial that creates norm_act layer 52 | norm_act_layer = norm_layer 53 | else: 54 | type_name = norm_layer.__name__.lower() 55 | if type_name.startswith('batchnorm'): 56 | norm_act_layer = BatchNormAct2d 57 | elif type_name.startswith('groupnorm'): 58 | norm_act_layer = GroupNormAct 59 | else: 60 | assert False, f"No equivalent norm_act layer for {type_name}" 61 | # Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 62 | # Newer models will use `apply_act` and likely have `act_layer` arg bound to relevant NormAct types. 
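# Illustrative usage sketch (an added example, not from the original file):
#   norm_act_layer, args = convert_norm_act_type(nn.BatchNorm2d, nn.ReLU)
#   bn_act = norm_act_layer(64, apply_act=True, **args)  # -> BatchNormAct2d + ReLU
# which matches the string factory above, e.g. create_norm_act('batchnorm', 64).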
63 | norm_act_args.update(dict(act_layer=act_layer)) 64 | return norm_act_layer, norm_act_args 65 | -------------------------------------------------------------------------------- /timm/models/layers/evo_norm.py: -------------------------------------------------------------------------------- 1 | """EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch 2 | 3 | An attempt at getting decent performing EvoNorms running in PyTorch. 4 | While currently faster than other impl, still quite a ways off the built-in BN 5 | in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). 6 | 7 | Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. 8 | 9 | Hacked together by Ross Wightman 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | class EvoNormBatch2d(nn.Module): 17 | def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): 18 | super(EvoNormBatch2d, self).__init__() 19 | self.apply_act = apply_act # apply activation (non-linearity) 20 | self.momentum = momentum 21 | self.eps = eps 22 | param_shape = (1, num_features, 1, 1) 23 | self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) 24 | self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) 25 | if apply_act: 26 | self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) 27 | self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.apply_act: 34 | nn.init.ones_(self.v) 35 | 36 | def forward(self, x): 37 | assert x.dim() == 4, 'expected 4D input' 38 | x_type = x.dtype 39 | if self.training: 40 | var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) 41 | self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var) 42 | else: 43 | var = self.running_var 44 | 45 | if self.apply_act: 46 | v = self.v.to(dtype=x_type) 47 | d = (x * v) + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) 48 | d = d.max((var + self.eps).sqrt().to(dtype=x_type)) 49 | x = x / d 50 | return x * self.weight + self.bias 51 | 52 | 53 | class EvoNormSample2d(nn.Module): 54 | def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): 55 | super(EvoNormSample2d, self).__init__() 56 | self.apply_act = apply_act # apply activation (non-linearity) 57 | self.groups = groups 58 | self.eps = eps 59 | param_shape = (1, num_features, 1, 1) 60 | self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) 61 | self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) 62 | if apply_act: 63 | self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) 64 | self.reset_parameters() 65 | 66 | def reset_parameters(self): 67 | nn.init.ones_(self.weight) 68 | nn.init.zeros_(self.bias) 69 | if self.apply_act: 70 | nn.init.ones_(self.v) 71 | 72 | def forward(self, x): 73 | assert x.dim() == 4, 'expected 4D input' 74 | B, C, H, W = x.shape 75 | assert C % self.groups == 0 76 | if self.apply_act: 77 | n = (x * self.v).sigmoid().reshape(B, self.groups, -1) 78 | x = x.reshape(B, self.groups, -1) 79 | x = n / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() 80 | x = x.reshape(B, C, H, W) 81 | return x * self.weight + self.bias 82 | -------------------------------------------------------------------------------- /timm/models/layers/helpers.py: 
-------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | from itertools import repeat 6 | from torch._six import container_abcs 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, container_abcs.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | tup_single = _ntuple(1) 19 | tup_pair = _ntuple(2) 20 | tup_triple = _ntuple(3) 21 | tup_quadruple = _ntuple(4) 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /timm/models/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 
38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_block=None,): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity') 50 | self.act_name = act_layer 51 | else: 52 | # convert act layer passed as type to string 53 | if isinstance(act_layer, nn.ELU): 54 | self.act_name = 'elu' 55 | elif isinstance(act_layer, nn.LeakyReLU): 56 | self.act_name = 'leaky_relu' 57 | else: 58 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 59 | else: 60 | self.act_name = 'identity' 61 | self.act_param = act_param 62 | if self.affine: 63 | self.weight = nn.Parameter(torch.ones(num_features)) 64 | self.bias = nn.Parameter(torch.zeros(num_features)) 65 | else: 66 | self.register_parameter('weight', None) 67 | self.register_parameter('bias', None) 68 | self.register_buffer('running_mean', torch.zeros(num_features)) 69 | self.register_buffer('running_var', torch.ones(num_features)) 70 | self.reset_parameters() 71 | 72 | def reset_parameters(self): 73 | nn.init.constant_(self.running_mean, 0) 74 | nn.init.constant_(self.running_var, 1) 75 | if self.affine: 76 | nn.init.constant_(self.weight, 1) 77 | nn.init.constant_(self.bias, 0) 78 | 79 | def forward(self, x): 80 | output = inplace_abn( 81 | x, self.weight, self.bias, self.running_mean, self.running_var, 82 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 83 | if isinstance(output, tuple): 84 | output = output[0] 85 | return output 86 | -------------------------------------------------------------------------------- /timm/models/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.modules.utils import _pair, _quadruple 6 | 7 | 8 | class MedianPool2d(nn.Module): 9 | """ Median pool (usable as median filter when stride=1) module. 
10 | 11 | Args: 12 | kernel_size: size of pooling kernel, int or 2-tuple 13 | stride: pool stride, int or 2-tuple 14 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 15 | same: override padding and enforce same padding, boolean 16 | """ 17 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 18 | super(MedianPool2d, self).__init__() 19 | self.k = _pair(kernel_size) 20 | self.stride = _pair(stride) 21 | self.padding = _quadruple(padding) # convert to l, r, t, b 22 | self.same = same 23 | 24 | def _padding(self, x): 25 | if self.same: 26 | ih, iw = x.size()[2:] 27 | if ih % self.stride[0] == 0: 28 | ph = max(self.k[0] - self.stride[0], 0) 29 | else: 30 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 31 | if iw % self.stride[1] == 0: 32 | pw = max(self.k[1] - self.stride[1], 0) 33 | else: 34 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 35 | pl = pw // 2 36 | pr = pw - pl 37 | pt = ph // 2 38 | pb = ph - pt 39 | padding = (pl, pr, pt, pb) 40 | else: 41 | padding = self.padding 42 | return padding 43 | 44 | def forward(self, x): 45 | x = F.pad(x, self._padding(x), mode='reflect') 46 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 47 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 48 | return x 49 | -------------------------------------------------------------------------------- /timm/models/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = out_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /timm/models/layers/norm_act.py: -------------------------------------------------------------------------------- 1 | """ Normalization + Activation Layers 2 | """ 3 | import 
torch 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | from .create_act import get_act_layer 8 | 9 | 10 | class BatchNormAct2d(nn.BatchNorm2d): 11 | """BatchNorm + Activation 12 | 13 | This module performs BatchNorm + Activation in a manner that will remain backwards 14 | compatible with weights trained with separate bn, act. This is why we inherit from BN 15 | instead of composing it as a .bn member. 16 | """ 17 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, 18 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): 19 | super(BatchNormAct2d, self).__init__( 20 | num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 21 | if isinstance(act_layer, str): 22 | act_layer = get_act_layer(act_layer) 23 | if act_layer is not None and apply_act: 24 | self.act = act_layer(inplace=inplace) 25 | else: 26 | self.act = None 27 | 28 | def _forward_jit(self, x): 29 | """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function 30 | """ 31 | # exponential_average_factor is self.momentum set to 32 | # (when it is available) only so that if gets updated 33 | # in ONNX graph when this node is exported to ONNX. 34 | if self.momentum is None: 35 | exponential_average_factor = 0.0 36 | else: 37 | exponential_average_factor = self.momentum 38 | 39 | if self.training and self.track_running_stats: 40 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 41 | if self.num_batches_tracked is not None: 42 | self.num_batches_tracked += 1 43 | if self.momentum is None: # use cumulative moving average 44 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 45 | else: # use exponential moving average 46 | exponential_average_factor = self.momentum 47 | 48 | x = F.batch_norm( 49 | x, self.running_mean, self.running_var, self.weight, self.bias, 50 | self.training or not self.track_running_stats, 51 | exponential_average_factor, self.eps) 52 | return x 53 | 54 | @torch.jit.ignore 55 | def _forward_python(self, x): 56 | return super(BatchNormAct2d, self).forward(x) 57 | 58 | def forward(self, x): 59 | # FIXME cannot call parent forward() and maintain jit.script compatibility? 
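# Added note (not from the original file): under torch.jit.script,
# is_scripting() is True and the hand-copied _forward_jit path runs, while the
# @torch.jit.ignore'd _forward_python keeps eager mode on the stock
# nn.BatchNorm2d.forward, so numerics match checkpoints trained with separate
# bn + act layers. E.g. torch.jit.script(BatchNormAct2d(64)) compiles, and in
# eager mode BatchNormAct2d(64)(x) matches nn.BatchNorm2d(64)(x).relu().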
60 | if torch.jit.is_scripting(): 61 | x = self._forward_jit(x) 62 | else: 63 | x = self._forward_python(x) 64 | if self.act is not None: 65 | x = self.act(x) 66 | return x 67 | 68 | 69 | class GroupNormAct(nn.GroupNorm): 70 | 71 | def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, 72 | apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): 73 | super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) 74 | if isinstance(act_layer, str): 75 | act_layer = get_act_layer(act_layer) 76 | if act_layer is not None and apply_act: 77 | self.act = act_layer(inplace=inplace) 78 | else: 79 | self.act = None 80 | 81 | def forward(self, x): 82 | x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 83 | if self.act is not None: 84 | x = self.act(x) 85 | return x 86 | -------------------------------------------------------------------------------- /timm/models/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /timm/models/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 
from typing import List, Tuple, Optional 9 | 10 | from .helpers import tup_pair 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = tup_pair(kernel_size) 26 | stride = tup_pair(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | return avg_pool2d_same( 31 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 32 | 33 | 34 | def max_pool2d_same( 35 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 36 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 37 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 38 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 39 | 40 | 41 | class MaxPool2dSame(nn.MaxPool2d): 42 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 43 | """ 44 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): 45 | kernel_size = tup_pair(kernel_size) 46 | stride = tup_pair(stride) 47 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) 48 | 49 | def forward(self, x): 50 | return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) 51 | 52 | 53 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 54 | stride = stride or kernel_size 55 | padding = kwargs.pop('padding', '') 56 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 57 | if is_dynamic: 58 | if pool_type == 'avg': 59 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 60 | elif pool_type == 'max': 61 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 62 | else: 63 | assert False, f'Unsupported pool type {pool_type}' 64 | else: 65 | if pool_type == 'avg': 66 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 67 | elif pool_type == 'max': 68 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 69 | else: 70 | assert False, f'Unsupported pool type {pool_type}' 71 | -------------------------------------------------------------------------------- /timm/models/layers/se.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from .create_act import get_act_fn 3 | 4 | 5 | class SEModule(nn.Module): 6 | 7 | def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None, 8 | gate_fn='sigmoid'): 9 | super(SEModule, self).__init__() 10 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 11 | reduction_channels = reduction_channels or max(channels // reduction, min_channels) 12 | self.fc1 = nn.Conv2d( 13 | channels, reduction_channels, kernel_size=1, padding=0, bias=True) 14 | self.act = act_layer(inplace=True) 15 | self.fc2 = nn.Conv2d( 16 | 
reduction_channels, channels, kernel_size=1, padding=0, bias=True) 17 | self.gate_fn = get_act_fn(gate_fn) 18 | 19 | def forward(self, x): 20 | x_se = self.avg_pool(x) 21 | x_se = self.fc1(x_se) 22 | x_se = self.act(x_se) 23 | x_se = self.fc2(x_se) 24 | return x * self.gate_fn(x_se) 25 | 26 | 27 | class EffectiveSEModule(nn.Module): 28 | """ 'Effective Squeeze-Excitation 29 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 30 | """ 31 | def __init__(self, channel, gate_fn='hard_sigmoid'): 32 | super(EffectiveSEModule, self).__init__() 33 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 34 | self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0) 35 | self.gate_fn = get_act_fn(gate_fn) 36 | 37 | def forward(self, x): 38 | x_se = self.avg_pool(x) 39 | x_se = self.fc(x_se) 40 | return x * self.gate_fn(x_se, inplace=True) 41 | -------------------------------------------------------------------------------- /timm/models/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | 3 | from .create_conv2d import create_conv2d 4 | from .create_norm_act import convert_norm_act_type 5 | 6 | 7 | class SeparableConvBnAct(nn.Module): 8 | """ Separable Conv w/ trailing Norm and Activation 9 | """ 10 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 11 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None, 12 | act_layer=nn.ReLU, apply_act=True, drop_block=None): 13 | super(SeparableConvBnAct, self).__init__() 14 | norm_kwargs = norm_kwargs or {} 15 | 16 | self.conv_dw = create_conv2d( 17 | in_channels, int(in_channels * channel_multiplier), kernel_size, 18 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 19 | 20 | self.conv_pw = create_conv2d( 21 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 22 | 23 | norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) 24 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) 25 | 26 | def forward(self, x): 27 | x = self.conv_dw(x) 28 | x = self.conv_pw(x) 29 | if self.bn is not None: 30 | x = self.bn(x) 31 | return x 32 | 33 | 34 | class SeparableConv2d(nn.Module): 35 | """ Separable Conv 36 | """ 37 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 38 | channel_multiplier=1.0, pw_kernel_size=1): 39 | super(SeparableConv2d, self).__init__() 40 | 41 | self.conv_dw = create_conv2d( 42 | in_channels, int(in_channels * channel_multiplier), kernel_size, 43 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 44 | 45 | self.conv_pw = create_conv2d( 46 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 47 | 48 | def forward(self, x): 49 | x = self.conv_dw(x) 50 | x = self.conv_pw(x) 51 | return x 52 | -------------------------------------------------------------------------------- /timm/models/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // 
self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /timm/models/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | 14 | class RadixSoftmax(nn.Module): 15 | def __init__(self, radix, cardinality): 16 | super(RadixSoftmax, self).__init__() 17 | self.radix = radix 18 | self.cardinality = cardinality 19 | 20 | def forward(self, x): 21 | batch = x.size(0) 22 | if self.radix > 1: 23 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 24 | x = F.softmax(x, dim=1) 25 | x = x.reshape(batch, -1) 26 | else: 27 | x = torch.sigmoid(x) 28 | return x 29 | 30 | 31 | class SplitAttnConv2d(nn.Module): 32 | """Split-Attention Conv2d 33 | """ 34 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 35 | dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, 36 | act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): 37 | super(SplitAttnConv2d, self).__init__() 38 | self.radix = radix 39 | self.drop_block = drop_block 40 | mid_chs = out_channels * radix 41 | attn_chs = max(in_channels * radix // reduction_factor, 32) 42 | 43 | self.conv = nn.Conv2d( 44 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 45 | groups=groups * radix, bias=bias, **kwargs) 46 | self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None 47 | self.act0 = act_layer(inplace=True) 48 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 49 | self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None 50 | self.act1 = act_layer(inplace=True) 51 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 52 | 
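# (Editor's note: a hedged shape walk-through of RadixSoftmax above, not part of the original
# file.) With radix=2 and cardinality=groups=1, fc2 emits x_attn of shape (B, mid_chs, 1, 1),
# where mid_chs = out_channels * radix. RadixSoftmax views it as (B, cardinality, radix, -1),
# transposes radix onto dim 1, softmaxes over that dim, and flattens back, so the two radix
# weights for each output channel sum to 1 before gating the radix splits in forward().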
self.rsoftmax = RadixSoftmax(radix, groups) 53 | 54 | def forward(self, x): 55 | x = self.conv(x) 56 | if self.bn0 is not None: 57 | x = self.bn0(x) 58 | if self.drop_block is not None: 59 | x = self.drop_block(x) 60 | x = self.act0(x) 61 | 62 | B, RC, H, W = x.shape 63 | if self.radix > 1: 64 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 65 | x_gap = x.sum(dim=1) 66 | else: 67 | x_gap = x 68 | x_gap = F.adaptive_avg_pool2d(x_gap, 1) 69 | x_gap = self.fc1(x_gap) 70 | if self.bn1 is not None: 71 | x_gap = self.bn1(x_gap) 72 | x_gap = self.act1(x_gap) 73 | x_attn = self.fc2(x_gap) 74 | 75 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 76 | if self.radix > 1: 77 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 78 | else: 79 | out = x * x_attn 80 | return out.contiguous() 81 | -------------------------------------------------------------------------------- /timm/models/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`. 
45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /timm/models/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | class TestTimePoolHead(nn.Module): 14 | def __init__(self, base, original_pool=7): 15 | super(TestTimePoolHead, self).__init__() 16 | self.base = base 17 | self.original_pool = original_pool 18 | base_fc = self.base.get_classifier() 19 | if isinstance(base_fc, nn.Conv2d): 20 | self.fc = base_fc 21 | else: 22 | self.fc = nn.Conv2d( 23 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 24 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 25 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 26 | self.base.reset_classifier(0) # delete original fc layer 27 | 28 | def forward(self, x): 29 | x = self.base.forward_features(x) 30 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 31 | x = self.fc(x) 32 | x = adaptive_avgmax_pool2d(x, 1) 33 | return x.view(x.size(0), -1) 34 | 35 | 36 | def apply_test_time_pool(model, config, args): 37 | test_time_pool = False 38 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 39 | return model, False 40 | if not args.no_test_pool and \ 41 | config['input_size'][-1] > model.default_cfg['input_size'][-1] and \ 42 | config['input_size'][-2] > model.default_cfg['input_size'][-2]: 43 | logging.info('Target input size %s > pretrained default %s, using test time pooling' % 44 | (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:]))) 45 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 46 | test_time_pool = True 47 | return model, test_time_pool 48 | -------------------------------------------------------------------------------- /timm/models/layers/weight_init.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | 6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 9 | def norm_cdf(x): 10 | # Computes standard normal cumulative distribution function 11 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 12 | 13 | if (mean < a - 2 * std) or (mean > b + 2 * std): 14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 15 | "The distribution of values may be incorrect.", 16 | stacklevel=2) 17 | 18 | with torch.no_grad(): 19 | # Values are generated by using a truncated uniform distribution and 20 | # then using the inverse CDF for the normal distribution. 21 | # Get upper and lower cdf values 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | 25 | # Uniformly fill tensor with values from [l, u], then translate to 26 | # [2l-1, 2u-1]. 27 | tensor.uniform_(2 * l - 1, 2 * u - 1) 28 | 29 | # Use inverse cdf transform for normal distribution to get truncated 30 | # standard normal 31 | tensor.erfinv_() 32 | 33 | # Transform to proper mean, std 34 | tensor.mul_(std * math.sqrt(2.)) 35 | tensor.add_(mean) 36 | 37 | # Clamp to ensure it's in the proper range 38 | tensor.clamp_(min=a, max=b) 39 | return tensor 40 | 41 | 42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 43 | # type: (Tensor, float, float, float, float) -> Tensor 44 | r"""Fills the input Tensor with values drawn from a truncated 45 | normal distribution. The values are effectively drawn from the 46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 47 | with values outside :math:`[a, b]` redrawn until they are within 48 | the bounds. The method used for generating the random values works 49 | best when :math:`a \leq \text{mean} \leq b`. 
50 | Args: 51 | tensor: an n-dimensional `torch.Tensor` 52 | mean: the mean of the normal distribution 53 | std: the standard deviation of the normal distribution 54 | a: the minimum cutoff value 55 | b: the maximum cutoff value 56 | Examples: 57 | >>> w = torch.empty(3, 5) 58 | >>> nn.init.trunc_normal_(w) 59 | """ 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 61 | -------------------------------------------------------------------------------- /timm/models/pruned/ecaresnet50d_pruned.txt: -------------------------------------------------------------------------------- 1 | conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 
235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022] -------------------------------------------------------------------------------- /timm/models/registry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | import fnmatch 4 | from collections import defaultdict 5 | 6 | __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] 7 | 8 | _module_to_models = defaultdict(set) # dict of sets to check membership of model in module 9 | _model_to_module = {} # mapping of model names to module names 10 | _model_entrypoints = {} # mapping of model names to entrypoint fns 11 | _model_has_pretrained = set() # set of model names that have pretrained weight url present 12 | 13 | 14 | def register_model(fn): 15 | # lookup containing module 16 | mod = sys.modules[fn.__module__] 17 | module_name_split = fn.__module__.split('.') 18 | module_name = module_name_split[-1] if len(module_name_split) else '' 19 | 20 | # add model to __all__ in module 21 | model_name = fn.__name__ 22 | if hasattr(mod, '__all__'): 23 | mod.__all__.append(model_name) 24 | else: 25 | mod.__all__ = [model_name] 26 | 27 | # add entries to registry dict/sets 28 | _model_entrypoints[model_name] = fn 29 | _model_to_module[model_name] = module_name 30 | _module_to_models[module_name].add(model_name) 31 | has_pretrained = False # check if model has a pretrained url to allow filtering on this 32 | if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: 33 | # this will catch all models that have entrypoint matching cfg key, but miss any aliasing 34 | # entrypoints or non-matching combos 35 | has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] 36 | if has_pretrained: 37 | _model_has_pretrained.add(model_name) 38 | return fn 39 | 40 | 41 | def _natural_key(string_): 42 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 43 | 44 | 45 | def list_models(filter='', 
module='', pretrained=False, exclude_filters=''): 46 | """ Return list of available model names, sorted alphabetically 47 | 48 | Args: 49 | filter (str) - Wildcard filter string that works with fnmatch 50 | module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') 51 | pretrained (bool) - Include only models with pretrained weights if True 52 | exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter 53 | 54 | Example: 55 | list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet' 56 | list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in the 'resnet' module 57 | """ 58 | if module: 59 | models = list(_module_to_models[module]) 60 | else: 61 | models = _model_entrypoints.keys() 62 | if filter: 63 | models = fnmatch.filter(models, filter) # include these models 64 | if exclude_filters: 65 | if not isinstance(exclude_filters, list): 66 | exclude_filters = [exclude_filters] 67 | for xf in exclude_filters: 68 | exclude_models = fnmatch.filter(models, xf) # exclude these models 69 | if len(exclude_models): 70 | models = set(models).difference(exclude_models) 71 | if pretrained: 72 | models = _model_has_pretrained.intersection(models) 73 | return list(sorted(models, key=_natural_key)) 74 | 75 | 76 | def is_model(model_name): 77 | """ Check if a model name exists 78 | """ 79 | return model_name in _model_entrypoints 80 | 81 | 82 | def model_entrypoint(model_name): 83 | """Fetch a model entrypoint for specified model name 84 | """ 85 | return _model_entrypoints[model_name] 86 | 87 | 88 | def list_modules(): 89 | """ Return list of module names that contain models / model entrypoints 90 | """ 91 | modules = _module_to_models.keys() 92 | return list(sorted(modules)) 93 | 94 | 95 | def is_model_in_modules(model_name, module_names): 96 | """Check if a model exists within a subset of modules 97 | Args: 98 | model_name (str) - name of model to check 99 | module_names (tuple, list, set) - names of modules to search in 100 | """ 101 | assert isinstance(module_names, (tuple, list, set)) 102 | return any(model_name in _module_to_models[n] for n in module_names) 103 | 104 | -------------------------------------------------------------------------------- /timm/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .nadam import Nadam 2 | from .rmsprop_tf import RMSpropTF 3 | from .adamw import AdamW 4 | from .radam import RAdam 5 | from .novograd import NovoGrad 6 | from .nvnovograd import NvNovoGrad 7 | from .lookahead import Lookahead 8 | from .optim_factory import create_optimizer 9 | -------------------------------------------------------------------------------- /timm/optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ AdamW Optimizer 2 | Impl copied from PyTorch master 3 | """ 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class AdamW(Optimizer): 10 | r"""Implements AdamW algorithm. 11 | 12 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 13 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 
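A minimal usage sketch (editor's addition; the constructor mirrors ``torch.optim.Adam``'s
interface, and ``model``, ``inputs``, ``targets`` are assumed to exist)::

>>> optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
>>> optimizer.zero_grad()
>>> loss = torch.nn.functional.cross_entropy(model(inputs), targets)
>>> loss.backward()
>>> optimizer.step()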
14 | 15 | Arguments: 16 | params (iterable): iterable of parameters to optimize or dicts defining 17 | parameter groups 18 | lr (float, optional): learning rate (default: 1e-3) 19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | eps (float, optional): term added to the denominator to improve 22 | numerical stability (default: 1e-8) 23 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | (default: False) 27 | 28 | .. _Adam\: A Method for Stochastic Optimization: 29 | https://arxiv.org/abs/1412.6980 30 | .. _Decoupled Weight Decay Regularization: 31 | https://arxiv.org/abs/1711.05101 32 | .. _On the Convergence of Adam and Beyond: 33 | https://openreview.net/forum?id=ryQu7f-RZ 34 | """ 35 | 36 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 37 | weight_decay=1e-2, amsgrad=False): 38 | if not 0.0 <= lr: 39 | raise ValueError("Invalid learning rate: {}".format(lr)) 40 | if not 0.0 <= eps: 41 | raise ValueError("Invalid epsilon value: {}".format(eps)) 42 | if not 0.0 <= betas[0] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 44 | if not 0.0 <= betas[1] < 1.0: 45 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 46 | defaults = dict(lr=lr, betas=betas, eps=eps, 47 | weight_decay=weight_decay, amsgrad=amsgrad) 48 | super(AdamW, self).__init__(params, defaults) 49 | 50 | def __setstate__(self, state): 51 | super(AdamW, self).__setstate__(state) 52 | for group in self.param_groups: 53 | group.setdefault('amsgrad', False) 54 | 55 | def step(self, closure=None): 56 | """Performs a single optimization step. 57 | 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 61 | """ 62 | loss = None 63 | if closure is not None: 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | for p in group['params']: 68 | if p.grad is None: 69 | continue 70 | 71 | # Perform stepweight decay 72 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 73 | 74 | # Perform optimization step 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 78 | amsgrad = group['amsgrad'] 79 | 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | if amsgrad: 90 | # Maintains max of all exp. moving avg. of sq. grad. values 91 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 92 | 93 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 94 | if amsgrad: 95 | max_exp_avg_sq = state['max_exp_avg_sq'] 96 | beta1, beta2 = group['betas'] 97 | 98 | state['step'] += 1 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | 102 | # Decay the first and second moment running average coefficient 103 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 104 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 105 | if amsgrad: 106 | # Maintains the maximum of all 2nd moment running avg. 
till now 107 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 108 | # Use the max. for normalizing running avg. of gradient 109 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 110 | else: 111 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 112 | 113 | step_size = group['lr'] / bias_correction1 114 | 115 | p.data.addcdiv_(-step_size, exp_avg, denom) 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /timm/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | """ Lookahead Optimizer Wrapper. 2 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 3 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 4 | """ 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | from collections import defaultdict 8 | 9 | 10 | class Lookahead(Optimizer): 11 | def __init__(self, base_optimizer, alpha=0.5, k=6): 12 | if not 0.0 <= alpha <= 1.0: 13 | raise ValueError(f'Invalid slow update rate: {alpha}') 14 | if not 1 <= k: 15 | raise ValueError(f'Invalid lookahead steps: {k}') 16 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 17 | self.base_optimizer = base_optimizer 18 | self.param_groups = self.base_optimizer.param_groups 19 | self.defaults = base_optimizer.defaults 20 | self.defaults.update(defaults) 21 | self.state = defaultdict(dict) 22 | # manually add our defaults to the param groups 23 | for name, default in defaults.items(): 24 | for group in self.param_groups: 25 | group.setdefault(name, default) 26 | 27 | def update_slow(self, group): 28 | for fast_p in group["params"]: 29 | if fast_p.grad is None: 30 | continue 31 | param_state = self.state[fast_p] 32 | if 'slow_buffer' not in param_state: 33 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 34 | param_state['slow_buffer'].copy_(fast_p.data) 35 | slow = param_state['slow_buffer'] 36 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 37 | fast_p.data.copy_(slow) 38 | 39 | def sync_lookahead(self): 40 | for group in self.param_groups: 41 | self.update_slow(group) 42 | 43 | def step(self, closure=None): 44 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 45 | loss = self.base_optimizer.step(closure) 46 | for group in self.param_groups: 47 | group['lookahead_step'] += 1 48 | if group['lookahead_step'] % group['lookahead_k'] == 0: 49 | self.update_slow(group) 50 | return loss 51 | 52 | def state_dict(self): 53 | fast_state_dict = self.base_optimizer.state_dict() 54 | slow_state = { 55 | (id(k) if isinstance(k, torch.Tensor) else k): v 56 | for k, v in self.state.items() 57 | } 58 | fast_state = fast_state_dict['state'] 59 | param_groups = fast_state_dict['param_groups'] 60 | return { 61 | 'state': fast_state, 62 | 'slow_state': slow_state, 63 | 'param_groups': param_groups, 64 | } 65 | 66 | def load_state_dict(self, state_dict): 67 | fast_state_dict = { 68 | 'state': state_dict['state'], 69 | 'param_groups': state_dict['param_groups'], 70 | } 71 | self.base_optimizer.load_state_dict(fast_state_dict) 72 | 73 | # We want to restore the slow state, but share param_groups reference 74 | # with base_optimizer. 
This is a bit redundant but least code 75 | slow_state_new = False 76 | if 'slow_state' not in state_dict: 77 | print('Loading state_dict from optimizer without Lookahead applied.') 78 | state_dict['slow_state'] = defaultdict(dict) 79 | slow_state_new = True 80 | slow_state_dict = { 81 | 'state': state_dict['slow_state'], 82 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 83 | } 84 | super(Lookahead, self).load_state_dict(slow_state_dict) 85 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 86 | if slow_state_new: 87 | # reapply defaults to catch missing lookahead specific ones 88 | for name, default in self.defaults.items(): 89 | for group in self.param_groups: 90 | group.setdefault(name, default) 91 | -------------------------------------------------------------------------------- /timm/optim/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | """ 27 | 28 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 29 | weight_decay=0, schedule_decay=4e-3): 30 | defaults = dict(lr=lr, betas=betas, eps=eps, 31 | weight_decay=weight_decay, schedule_decay=schedule_decay) 32 | super(Nadam, self).__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | state['m_schedule'] = 1. 56 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 57 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 58 | 59 | # Warming momentum schedule 60 | m_schedule = state['m_schedule'] 61 | schedule_decay = group['schedule_decay'] 62 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 63 | beta1, beta2 = group['betas'] 64 | eps = group['eps'] 65 | state['step'] += 1 66 | t = state['step'] 67 | 68 | if group['weight_decay'] != 0: 69 | grad = grad.add(group['weight_decay'], p.data) 70 | 71 | momentum_cache_t = beta1 * \ 72 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 73 | momentum_cache_t_1 = beta1 * \ 74 | (1. 
- 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 75 | m_schedule_new = m_schedule * momentum_cache_t 76 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 77 | state['m_schedule'] = m_schedule_new 78 | 79 | # Decay the first and second moment running average coefficient 80 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 81 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 82 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 83 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 84 | 85 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 86 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /timm/optim/novograd.py: -------------------------------------------------------------------------------- 1 | """NovoGrad Optimizer. 2 | Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd 3 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 4 | - https://arxiv.org/abs/1905.11286 5 | """ 6 | 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | import math 10 | 11 | 12 | class NovoGrad(Optimizer): 13 | def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): 14 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 15 | super(NovoGrad, self).__init__(params, defaults) 16 | self._lr = lr 17 | self._beta1 = betas[0] 18 | self._beta2 = betas[1] 19 | self._eps = eps 20 | self._wd = weight_decay 21 | self._grad_averaging = grad_averaging 22 | 23 | self._momentum_initialized = False 24 | 25 | def step(self, closure=None): 26 | loss = None 27 | if closure is not None: 28 | loss = closure() 29 | 30 | if not self._momentum_initialized: 31 | for group in self.param_groups: 32 | for p in group['params']: 33 | if p.grad is None: 34 | continue 35 | state = self.state[p] 36 | grad = p.grad.data 37 | if grad.is_sparse: 38 | raise RuntimeError('NovoGrad does not support sparse gradients') 39 | 40 | v = torch.norm(grad)**2 41 | m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data 42 | state['step'] = 0 43 | state['v'] = v 44 | state['m'] = m 45 | state['grad_ema'] = None 46 | self._momentum_initialized = True 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | state = self.state[p] 53 | state['step'] += 1 54 | 55 | step, v, m = state['step'], state['v'], state['m'] 56 | grad_ema = state['grad_ema'] 57 | 58 | grad = p.grad.data 59 | g2 = torch.norm(grad)**2 60 | grad_ema = g2 if grad_ema is None else grad_ema * \ 61 | self._beta2 + g2 * (1. - self._beta2) 62 | grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) 63 | 64 | if self._grad_averaging: 65 | grad *= (1. - self._beta1) 66 | 67 | g2 = torch.norm(grad)**2 68 | v = self._beta2*v + (1. 
- self._beta2)*g2 69 | m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) 70 | bias_correction1 = 1 - self._beta1 ** step 71 | bias_correction2 = 1 - self._beta2 ** step 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | 74 | state['v'], state['m'] = v, m 75 | state['grad_ema'] = grad_ema 76 | p.data.add_(-step_size, m) 77 | return loss 78 | -------------------------------------------------------------------------------- /timm/optim/nvnovograd.py: -------------------------------------------------------------------------------- 1 | """ Nvidia NovoGrad Optimizer. 2 | Original impl by Nvidia from Jasper example: 3 | - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper 4 | Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` 5 | - https://arxiv.org/abs/1905.11286 6 | """ 7 | 8 | import torch 9 | from torch.optim.optimizer import Optimizer 10 | import math 11 | 12 | 13 | class NvNovoGrad(Optimizer): 14 | """ 15 | Implements Novograd algorithm. 16 | 17 | Args: 18 | params (iterable): iterable of parameters to optimize or dicts defining 19 | parameter groups 20 | lr (float, optional): learning rate (default: 1e-3) 21 | betas (Tuple[float, float], optional): coefficients used for computing 22 | running averages of gradient and its square (default: (0.95, 0.98)) 23 | eps (float, optional): term added to the denominator to improve 24 | numerical stability (default: 1e-8) 25 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 26 | grad_averaging: gradient averaging 27 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 28 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 29 | (default: False) 30 | """ 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, 33 | weight_decay=0, grad_averaging=False, amsgrad=False): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {}".format(eps)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, 44 | grad_averaging=grad_averaging, 45 | amsgrad=amsgrad) 46 | 47 | super(NvNovoGrad, self).__init__(params, defaults) 48 | 49 | def __setstate__(self, state): 50 | super(NvNovoGrad, self).__setstate__(state) 51 | for group in self.param_groups: 52 | group.setdefault('amsgrad', False) 53 | 54 | def step(self, closure=None): 55 | """Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (callable, optional): A closure that reevaluates the model 59 | and returns the loss. 
60 | """ 61 | loss = None 62 | if closure is not None: 63 | loss = closure() 64 | 65 | for group in self.param_groups: 66 | for p in group['params']: 67 | if p.grad is None: 68 | continue 69 | grad = p.grad.data 70 | if grad.is_sparse: 71 | raise RuntimeError('Sparse gradients are not supported.') 72 | amsgrad = group['amsgrad'] 73 | 74 | state = self.state[p] 75 | 76 | # State initialization 77 | if len(state) == 0: 78 | state['step'] = 0 79 | # Exponential moving average of gradient values 80 | state['exp_avg'] = torch.zeros_like(p.data) 81 | # Exponential moving average of squared gradient values 82 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 83 | if amsgrad: 84 | # Maintains max of all exp. moving avg. of sq. grad. values 85 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | if amsgrad: 89 | max_exp_avg_sq = state['max_exp_avg_sq'] 90 | beta1, beta2 = group['betas'] 91 | 92 | state['step'] += 1 93 | 94 | norm = torch.sum(torch.pow(grad, 2)) 95 | 96 | if exp_avg_sq == 0: 97 | exp_avg_sq.copy_(norm) 98 | else: 99 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) 100 | 101 | if amsgrad: 102 | # Maintains the maximum of all 2nd moment running avg. till now 103 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 104 | # Use the max. for normalizing running avg. of gradient 105 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 106 | else: 107 | denom = exp_avg_sq.sqrt().add_(group['eps']) 108 | 109 | grad.div_(denom) 110 | if group['weight_decay'] != 0: 111 | grad.add_(group['weight_decay'], p.data) 112 | if group['grad_averaging']: 113 | grad.mul_(1 - beta1) 114 | exp_avg.mul_(beta1).add_(grad) 115 | 116 | p.data.add_(-group['lr'], exp_avg) 117 | 118 | return loss 119 | -------------------------------------------------------------------------------- /timm/optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead 4 | try: 5 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 6 | has_apex = True 7 | except ImportError: 8 | has_apex = False 9 | 10 | 11 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 12 | decay = [] 13 | no_decay = [] 14 | for name, param in model.named_parameters(): 15 | if not param.requires_grad: 16 | continue # frozen weights 17 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 18 | no_decay.append(param) 19 | else: 20 | decay.append(param) 21 | return [ 22 | {'params': no_decay, 'weight_decay': 0.}, 23 | {'params': decay, 'weight_decay': weight_decay}] 24 | 25 | 26 | def create_optimizer(args, model, filter_bias_and_bn=True): 27 | opt_lower = args.opt.lower() 28 | weight_decay = args.weight_decay 29 | if 'adamw' in opt_lower or 'radam' in opt_lower: 30 | # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay 31 | # I don't believe they follow the paper or original Torch7 impl which schedules weight 32 | # decay based on the ratio of current_lr/initial_lr 33 | weight_decay /= args.lr 34 | if weight_decay and filter_bias_and_bn: 35 | parameters = add_weight_decay(model, weight_decay) 36 | weight_decay = 0. 
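# (Editor's note.) weight_decay is zeroed above because add_weight_decay() already attached
# explicit per-group 'weight_decay' values (0 for biases and 1-d norm params, the requested
# value for everything else); per-group options take precedence over the optimizer-level
# default, so the default is cleared to keep the no-decay split unambiguous.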
37 | else: 38 | parameters = model.parameters() 39 | 40 | if 'fused' in opt_lower: 41 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 42 | 43 | opt_split = opt_lower.split('_') 44 | opt_lower = opt_split[-1] 45 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 46 | optimizer = optim.SGD( 47 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) 48 | elif opt_lower == 'momentum': 49 | optimizer = optim.SGD( 50 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False) 51 | elif opt_lower == 'adam': 52 | optimizer = optim.Adam( 53 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 54 | elif opt_lower == 'adamw': 55 | optimizer = AdamW( 56 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 57 | elif opt_lower == 'nadam': 58 | optimizer = Nadam( 59 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 60 | elif opt_lower == 'radam': 61 | optimizer = RAdam( 62 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 63 | elif opt_lower == 'adadelta': 64 | optimizer = optim.Adadelta( 65 | parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 66 | elif opt_lower == 'rmsprop': 67 | optimizer = optim.RMSprop( 68 | parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, 69 | momentum=args.momentum, weight_decay=weight_decay) 70 | elif opt_lower == 'rmsproptf': 71 | optimizer = RMSpropTF( 72 | parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, 73 | momentum=args.momentum, weight_decay=weight_decay) 74 | elif opt_lower == 'novograd': 75 | optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 76 | elif opt_lower == 'nvnovograd': 77 | optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 78 | elif opt_lower == 'fusedsgd': 79 | optimizer = FusedSGD( 80 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) 81 | elif opt_lower == 'fusedmomentum': 82 | optimizer = FusedSGD( 83 | parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False) 84 | elif opt_lower == 'fusedadam': 85 | optimizer = FusedAdam( 86 | parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps) 87 | elif opt_lower == 'fusedadamw': 88 | optimizer = FusedAdam( 89 | parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps) 90 | elif opt_lower == 'fusedlamb': 91 | optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) 92 | elif opt_lower == 'fusednovograd': 93 | optimizer = FusedNovoGrad( 94 | parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps) 95 | else: 96 | assert False, "Invalid optimizer" # fixed: 'assert False and "..."' discarded the message; the raise below remains as a fallback when asserts are optimized out 97 | raise ValueError 98 | 99 | if len(opt_split) > 1: 100 | if opt_split[0] == 'lookahead': 101 | optimizer = Lookahead(optimizer) 102 | 103 | return optimizer 104 | -------------------------------------------------------------------------------- /timm/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .plateau_lr import PlateauLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .tanh_lr import TanhLRScheduler 5 | from .scheduler_factory import create_scheduler 6 | --------------------------------------------------------------------------------
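# (Editor's addition.) A hedged, self-contained sketch of how the optimizer and scheduler
# factories above are typically wired together. The argparse-style attribute names (opt, lr,
# weight_decay, momentum, opt_eps, sched, epochs, min_lr, decay_rate, warmup_lr,
# warmup_epochs, cooldown_epochs, lr_noise) are exactly the ones the factories read via
# attribute access; the concrete values are illustrative only, not recommendations.
from types import SimpleNamespace

import torch.nn as nn

from timm.optim import create_optimizer
from timm.scheduler import create_scheduler

model = nn.Linear(10, 2)
args = SimpleNamespace(
    opt='adamw', lr=1e-3, weight_decay=0.05, momentum=0.9, opt_eps=1e-8,
    sched='cosine', epochs=30, min_lr=1e-5, decay_rate=0.1,
    warmup_lr=1e-6, warmup_epochs=3, cooldown_epochs=0, lr_noise=None)
optimizer = create_optimizer(args, model)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
for epoch in range(num_epochs):
    # ... one training epoch, with optimizer.step() per batch ...
    lr_scheduler.step(epoch + 1)  # timm schedulers are stepped with the upcoming epoch index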
/timm/scheduler/cosine_lr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | from .scheduler import Scheduler 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class CosineLRScheduler(Scheduler): 13 | """ 14 | Cosine decay with restarts. 15 | This is described in the paper https://arxiv.org/abs/1608.03983. 16 | 17 | Inspiration from 18 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 19 | """ 20 | 21 | def __init__(self, 22 | optimizer: torch.optim.Optimizer, 23 | t_initial: int, 24 | t_mul: float = 1., 25 | lr_min: float = 0., 26 | decay_rate: float = 1., 27 | warmup_t=0, 28 | warmup_lr_init=0, 29 | warmup_prefix=False, 30 | cycle_limit=0, 31 | t_in_epochs=True, 32 | noise_range_t=None, 33 | noise_pct=0.67, 34 | noise_std=1.0, 35 | noise_seed=42, 36 | initialize=True) -> None: 37 | super().__init__( 38 | optimizer, param_group_field="lr", 39 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 40 | initialize=initialize) 41 | 42 | assert t_initial > 0 43 | assert lr_min >= 0 44 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 45 | logger.warning("Cosine annealing scheduler will have no effect on the learning " 46 | "rate since t_initial = t_mul = decay_rate = 1.") 47 | self.t_initial = t_initial 48 | self.t_mul = t_mul 49 | self.lr_min = lr_min 50 | self.decay_rate = decay_rate 51 | self.cycle_limit = cycle_limit 52 | self.warmup_t = warmup_t 53 | self.warmup_lr_init = warmup_lr_init 54 | self.warmup_prefix = warmup_prefix 55 | self.t_in_epochs = t_in_epochs 56 | if self.warmup_t: 57 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 58 | super().update_groups(self.warmup_lr_init) 59 | else: 60 | self.warmup_steps = [1 for _ in self.base_values] 61 | 62 | def _get_lr(self, t): 63 | if t < self.warmup_t: 64 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 65 | else: 66 | if self.warmup_prefix: 67 | t = t - self.warmup_t 68 | 69 | if self.t_mul != 1: 70 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 71 | t_i = self.t_mul ** i * self.t_initial 72 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 73 | else: 74 | i = t // self.t_initial 75 | t_i = self.t_initial 76 | t_curr = t - (self.t_initial * i) 77 | 78 | gamma = self.decay_rate ** i 79 | lr_min = self.lr_min * gamma 80 | lr_max_values = [v * gamma for v in self.base_values] 81 | 82 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 83 | lrs = [ 84 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 85 | ] 86 | else: 87 | lrs = [self.lr_min for _ in self.base_values] 88 | 89 | return lrs 90 | 91 | def get_epoch_values(self, epoch: int): 92 | if self.t_in_epochs: 93 | return self._get_lr(epoch) 94 | else: 95 | return None 96 | 97 | def get_update_values(self, num_updates: int): 98 | if not self.t_in_epochs: 99 | return self._get_lr(num_updates) 100 | else: 101 | return None 102 | 103 | def get_cycle_length(self, cycles=0): 104 | if not cycles: 105 | cycles = self.cycle_limit 106 | cycles = max(1, cycles) 107 | if self.t_mul == 1.0: 108 | return self.t_initial * cycles 109 | else: 110 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 111 | 
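# (Editor's addition.) A hedged, standalone re-derivation of the value _get_lr computes above
# for the simple case t_mul=1, decay_rate=1, cycle_limit=0, no warmup, so the schedule's shape
# is easy to verify by hand:
#   lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * (t % t_initial) / t_initial))
import math

def cosine_lr(t, t_initial=10, lr_max=1.0, lr_min=0.0):
    t_curr = t % t_initial  # position within the current restart cycle
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_initial))

assert abs(cosine_lr(0) - 1.0) < 1e-9   # start of cycle: full base LR
assert abs(cosine_lr(5) - 0.5) < 1e-9   # halfway through: midpoint between lr_max and lr_min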
-------------------------------------------------------------------------------- /timm/scheduler/plateau_lr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .scheduler import Scheduler 4 | 5 | 6 | class PlateauLRScheduler(Scheduler): 7 | """Decay the LR by a factor every time the validation loss plateaus.""" 8 | 9 | def __init__(self, 10 | optimizer, 11 | decay_rate=0.1, 12 | patience_t=10, 13 | verbose=True, 14 | threshold=1e-4, 15 | cooldown_t=0, 16 | warmup_t=0, 17 | warmup_lr_init=0, 18 | lr_min=0, 19 | mode='max', 20 | noise_range_t=None, 21 | noise_type='normal', 22 | noise_pct=0.67, 23 | noise_std=1.0, 24 | noise_seed=None, 25 | initialize=True, 26 | ): 27 | super().__init__(optimizer, 'lr', initialize=initialize) 28 | 29 | self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 30 | self.optimizer, 31 | patience=patience_t, 32 | factor=decay_rate, 33 | verbose=verbose, 34 | threshold=threshold, 35 | cooldown=cooldown_t, 36 | mode=mode, 37 | min_lr=lr_min 38 | ) 39 | 40 | self.noise_range = noise_range_t 41 | self.noise_pct = noise_pct 42 | self.noise_type = noise_type 43 | self.noise_std = noise_std 44 | self.noise_seed = noise_seed if noise_seed is not None else 42 45 | self.warmup_t = warmup_t 46 | self.warmup_lr_init = warmup_lr_init 47 | if self.warmup_t: 48 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 49 | super().update_groups(self.warmup_lr_init) 50 | else: 51 | self.warmup_steps = [1 for _ in self.base_values] 52 | self.restore_lr = None 53 | 54 | def state_dict(self): 55 | return { 56 | 'best': self.lr_scheduler.best, 57 | 'last_epoch': self.lr_scheduler.last_epoch, 58 | } 59 | 60 | def load_state_dict(self, state_dict): 61 | self.lr_scheduler.best = state_dict['best'] 62 | if 'last_epoch' in state_dict: 63 | self.lr_scheduler.last_epoch = state_dict['last_epoch'] 64 | 65 | # override the base class step fn completely 66 | def step(self, epoch, metric=None): 67 | if epoch <= self.warmup_t: 68 | lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] 69 | super().update_groups(lrs) 70 | else: 71 | if self.restore_lr is not None: 72 | # restore actual LR from before our last noise perturbation before stepping base 73 | for i, param_group in enumerate(self.optimizer.param_groups): 74 | param_group['lr'] = self.restore_lr[i] 75 | self.restore_lr = None 76 | 77 | self.lr_scheduler.step(metric, epoch) # step the base scheduler 78 | 79 | if self.noise_range is not None: 80 | if isinstance(self.noise_range, (list, tuple)): 81 | apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] 82 | else: 83 | apply_noise = epoch >= self.noise_range 84 | if apply_noise: 85 | self._apply_noise(epoch) 86 | 87 | def _apply_noise(self, epoch): 88 | g = torch.Generator() 89 | g.manual_seed(self.noise_seed + epoch) 90 | if self.noise_type == 'normal': 91 | while True: 92 | # resample if noise out of percent limit, brute force but shouldn't spin much 93 | noise = torch.randn(1, generator=g).item() 94 | if abs(noise) < self.noise_pct: 95 | break 96 | else: 97 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 98 | 99 | # apply the noise on top of previous LR, cache the old value so we can restore for normal 100 | # stepping of base scheduler 101 | restore_lr = [] 102 | for i, param_group in enumerate(self.optimizer.param_groups): 103 | old_lr = float(param_group['lr']) 104 | restore_lr.append(old_lr) 105 | new_lr = old_lr + old_lr * noise 106 | 
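# (Editor's note.) The perturbed LR is written into the param group just below, while
# restore_lr keeps the pre-noise values; step() restores them before calling the wrapped
# ReduceLROnPlateau again, so the noise never contaminates the base scheduler's state.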
param_group['lr'] = new_lr 107 | self.restore_lr = restore_lr 108 | -------------------------------------------------------------------------------- /timm/scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
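# (Editor's note on the noise parameters stored below.) noise_range_t gates when LR noise is
# applied: a (start, end) pair restricts it to that interval of epochs/updates, while a scalar
# enables it from that point onward; noise_pct bounds the magnitude of the multiplicative
# perturbation and noise_seed (+ t) makes it reproducible. See _add_noise() below.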
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /timm/scheduler/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | from .cosine_lr import CosineLRScheduler 2 | from .tanh_lr import TanhLRScheduler 3 | from .step_lr import StepLRScheduler 4 | from .plateau_lr import PlateauLRScheduler 5 | 6 | 7 | def create_scheduler(args, optimizer): 8 | num_epochs = args.epochs 9 | 10 | if getattr(args, 'lr_noise', None) is not None: 11 | lr_noise = getattr(args, 'lr_noise') 12 | if isinstance(lr_noise, (list, tuple)): 13 | noise_range = [n * num_epochs for n in lr_noise] 14 | if len(noise_range) == 1: 15 | noise_range = noise_range[0] 16 | else: 17 | noise_range = lr_noise * num_epochs 18 | else: 19 | noise_range = None 20 | 21 | lr_scheduler = None 22 | if args.sched == 'cosine': 23 | lr_scheduler = CosineLRScheduler( 24 | optimizer, 25 | t_initial=num_epochs, 26 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 27 | lr_min=args.min_lr, 28 | decay_rate=args.decay_rate, 29 | warmup_lr_init=args.warmup_lr, 30 | warmup_t=args.warmup_epochs, 31 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 32 | t_in_epochs=True, 33 | noise_range_t=noise_range, 34 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 35 | noise_std=getattr(args, 'lr_noise_std', 1.), 36 | noise_seed=getattr(args, 'seed', 42), 37 | ) 38 | num_epochs = 
lr_scheduler.get_cycle_length() + args.cooldown_epochs 39 | elif args.sched == 'tanh': 40 | lr_scheduler = TanhLRScheduler( 41 | optimizer, 42 | t_initial=num_epochs, 43 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 44 | lr_min=args.min_lr, 45 | warmup_lr_init=args.warmup_lr, 46 | warmup_t=args.warmup_epochs, 47 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 48 | t_in_epochs=True, 49 | noise_range_t=noise_range, 50 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 51 | noise_std=getattr(args, 'lr_noise_std', 1.), 52 | noise_seed=getattr(args, 'seed', 42), 53 | ) 54 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 55 | elif args.sched == 'step': 56 | lr_scheduler = StepLRScheduler( 57 | optimizer, 58 | decay_t=args.decay_epochs, 59 | decay_rate=args.decay_rate, 60 | warmup_lr_init=args.warmup_lr, 61 | warmup_t=args.warmup_epochs, 62 | noise_range_t=noise_range, 63 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 64 | noise_std=getattr(args, 'lr_noise_std', 1.), 65 | noise_seed=getattr(args, 'seed', 42), 66 | ) 67 | elif args.sched == 'plateau': 68 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' 69 | lr_scheduler = PlateauLRScheduler( 70 | optimizer, 71 | decay_rate=args.decay_rate, 72 | patience_t=args.patience_epochs, 73 | lr_min=args.min_lr, 74 | mode=mode, 75 | warmup_lr_init=args.warmup_lr, 76 | warmup_t=args.warmup_epochs, 77 | cooldown_t=0, 78 | noise_range_t=noise_range, 79 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 80 | noise_std=getattr(args, 'lr_noise_std', 1.), 81 | noise_seed=getattr(args, 'seed', 42), 82 | ) 83 | 84 | return lr_scheduler, num_epochs 85 | -------------------------------------------------------------------------------- /timm/scheduler/step_lr.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from .scheduler import Scheduler 5 | 6 | 7 | class StepLRScheduler(Scheduler): 8 | """Step LR scheduler with warmup. 9 | Decays the LR by decay_rate every decay_t epochs (or every decay_t updates when t_in_epochs is False). 10 | """ 11 | 12 | def __init__(self, 13 | optimizer: torch.optim.Optimizer, 14 | decay_t: float, 15 | decay_rate: float = 1., 16 | warmup_t=0, 17 | warmup_lr_init=0, 18 | t_in_epochs=True, 19 | noise_range_t=None, 20 | noise_pct=0.67, 21 | noise_std=1.0, 22 | noise_seed=42, 23 | initialize=True, 24 | ) -> None: 25 | super().__init__( 26 | optimizer, param_group_field="lr", 27 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 28 | initialize=initialize) 29 | 30 | self.decay_t = decay_t 31 | self.decay_rate = decay_rate 32 | self.warmup_t = warmup_t 33 | self.warmup_lr_init = warmup_lr_init 34 | self.t_in_epochs = t_in_epochs 35 | if self.warmup_t: 36 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 37 | super().update_groups(self.warmup_lr_init) 38 | else: 39 | self.warmup_steps = [1 for _ in self.base_values] 40 | 41 | def _get_lr(self, t): 42 | if t < self.warmup_t: 43 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 44 | else: 45 | lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] 46 | return lrs 47 | 48 | def get_epoch_values(self, epoch: int): 49 | if self.t_in_epochs: 50 | return self._get_lr(epoch) 51 | else: 52 | return None 53 | 54 | def get_update_values(self, num_updates: int): 55 | if not self.t_in_epochs: 56 | return self._get_lr(num_updates) 57 | else: 58 | return None 59 |
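A minimal driver sketch for the StepLRScheduler above (not part of the repository; the parameter values and the training-loop body are placeholders):

import torch
from timm.scheduler.step_lr import StepLRScheduler

# A toy parameter so the optimizer has something to schedule.
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
# Warm up from 1e-5 over 3 epochs, then decay the LR 10x every 30 epochs.
scheduler = StepLRScheduler(optimizer, decay_t=30, decay_rate=0.1,
                            warmup_t=3, warmup_lr_init=1e-5)
for epoch in range(90):
    # ... train one epoch ...
    scheduler.step(epoch + 1)  # end-of-epoch call, per the Scheduler base class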
-------------------------------------------------------------------------------- /timm/scheduler/tanh_lr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | from .scheduler import Scheduler 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class TanhLRScheduler(Scheduler): 13 | """ 14 | Hyperbolic-Tangent decay with restarts. 15 | This is described in the paper https://arxiv.org/abs/1806.01593 16 | """ 17 | 18 | def __init__(self, 19 | optimizer: torch.optim.Optimizer, 20 | t_initial: int, 21 | lb: float = -6., 22 | ub: float = 4., 23 | t_mul: float = 1., 24 | lr_min: float = 0., 25 | decay_rate: float = 1., 26 | warmup_t=0, 27 | warmup_lr_init=0, 28 | warmup_prefix=False, 29 | cycle_limit=0, 30 | t_in_epochs=True, 31 | noise_range_t=None, 32 | noise_pct=0.67, 33 | noise_std=1.0, 34 | noise_seed=42, 35 | initialize=True) -> None: 36 | super().__init__( 37 | optimizer, param_group_field="lr", 38 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 39 | initialize=initialize) 40 | 41 | assert t_initial > 0 42 | assert lr_min >= 0 43 | assert lb < ub 44 | assert cycle_limit >= 0 45 | assert warmup_t >= 0 46 | assert warmup_lr_init >= 0 47 | self.lb = lb 48 | self.ub = ub 49 | self.t_initial = t_initial 50 | self.t_mul = t_mul 51 | self.lr_min = lr_min 52 | self.decay_rate = decay_rate 53 | self.cycle_limit = cycle_limit 54 | self.warmup_t = warmup_t 55 | self.warmup_lr_init = warmup_lr_init 56 | self.warmup_prefix = warmup_prefix 57 | self.t_in_epochs = t_in_epochs 58 | if self.warmup_t: 59 | t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) 60 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] 61 | super().update_groups(self.warmup_lr_init) 62 | else: 63 | self.warmup_steps = [1 for _ in self.base_values] 64 | 65 | def _get_lr(self, t): 66 | if t < self.warmup_t: 67 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 68 | else: 69 | if self.warmup_prefix: 70 | t = t - self.warmup_t 71 | 72 | if self.t_mul != 1: 73 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 74 | t_i = self.t_mul ** i * self.t_initial 75 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 76 | else: 77 | i = t // self.t_initial 78 | t_i = self.t_initial 79 | t_curr = t - (self.t_initial * i) 80 | 81 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 82 | gamma = self.decay_rate ** i 83 | lr_min = self.lr_min * gamma 84 | lr_max_values = [v * gamma for v in self.base_values] 85 | 86 | tr = t_curr / t_i 87 | lrs = [
88 | lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr)) 89 | for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] 93 | return lrs 94 | 95 | def get_epoch_values(self, epoch: int): 96 | if self.t_in_epochs: 97 | return self._get_lr(epoch) 98 | else: 99 | return None 100 | 101 | def get_update_values(self, num_updates: int): 102 | if not self.t_in_epochs: 103 | return self._get_lr(num_updates) 104 | else: 105 | return None 106 | 107 | def get_cycle_length(self, cycles=0): 108 | if not cycles: 109 | cycles = self.cycle_limit 110 | cycles = max(1, cycles) 111 | if self.t_mul == 1.0: 112 | return self.t_initial * cycles 113 | else: 114 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 115 | -------------------------------------------------------------------------------- /timm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.30' 2 | -------------------------------------------------------------------------------- /tools/calculate_tool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def evaluateTop1(logits, labels): 5 | with torch.no_grad(): 6 | pred = logits.argmax(dim=1) 7 | return torch.eq(pred, labels).sum().float().item()/labels.size(0) 8 | 9 | 10 | def evaluateTop5(logits, labels): 11 | with torch.no_grad(): 12 | maxk = max((1, 5)) 13 | labels_resize = labels.view(-1, 1) 14 | _, pred = logits.topk(maxk, 1, True, True) 15 | return torch.eq(pred, labels_resize).sum().float().item()/labels.size(0) 16 | 17 | 18 | class MetricLog(): 19 | def __init__(self): 20 | self.record = {"train": {"loss": [], "acc": [], "log_loss": [], "att_loss": []}, 21 | "val": {"loss": [], "acc": [], "log_loss": [], "att_loss": []}} 22 | 23 | def print_metric(self): 24 | print("train loss:", self.record["train"]["loss"]) 25 | print("val loss:", self.record["val"]["loss"]) 26 | print("train acc:", self.record["train"]["acc"]) 27 | print("val acc:", self.record["val"]["acc"]) 28 | print("train CE loss:", self.record["train"]["log_loss"]) 29 | print("val CE loss:", self.record["val"]["log_loss"]) 30 | print("train attention loss:", self.record["train"]["att_loss"]) 31 | print("val attention loss:", self.record["val"]["att_loss"]) 32 | -------------------------------------------------------------------------------- /tools/image_aug.py: -------------------------------------------------------------------------------- 1 | import imgaug.augmenters as iaa 2 | import matplotlib.pyplot as plt 3 | import random 4 | import numpy as np 5 | 6 | 7 | class ImageAugment(object): 8 | """ 9 | Class for augmenting the training data using imgaug. 10 | """ 11 | def __init__(self): 12 | self.key = 0 13 | self.choice = 1 14 | self.rotate = np.random.randint(-10, 10) 15 | self.scale_x = random.uniform(0.8, 1.0) 16 | self.scale_y = random.uniform(0.8, 1.0) 17 | self.translate_x = random.uniform(0, 0.1) 18 | self.translate_y = random.uniform(-0.1, 0.1) 19 | self.brightness = np.random.randint(-10, 10) 20 | self.linear_contrast = random.uniform(0.5, 2.0) 21 | self.alpha = random.uniform(0, 1.0) 22 | self.lightness = random.uniform(0.75, 1.5) 23 | self.Gaussian = random.uniform(0.0, 0.05*255) 24 | self.Gaussian_blur = random.uniform(0, 3.0) 25 | 26 | def aug(self, image, sequence): 27 | """ 28 | :param image: array of shape (H, W, C); one image at a time 29 | :param sequence: imgaug augmentation sequence to apply 30 | :return: the augmented image 31 | """ 32 | image_aug = sequence(image=image) 33 | return image_aug 34 | 35 | def rd(self, rand_max): 36 | seed = np.random.randint(0, rand_max) 37 | return seed 38 | 39 | def aug_sequence(self): 40 | sequence = self.aug_function() 41 | seq = iaa.Sequential(sequence, random_order=True) 42 | return seq 43 | 44 | def aug_function(self): 45 | sequence = [] 46 | if self.rd(2) == self.key: 47 | sequence.append(iaa.Fliplr(1.0)) # flip horizontally (this branch is taken with 50% probability) 48 | if self.rd(2) == self.key: 49 | sequence.append(iaa.Flipud(1.0)) # flip vertically (this branch is taken with 50% probability) 50 | if self.rd(2) == self.key: 51 | sequence.append(iaa.Affine( 52 | scale={"x": self.scale_x, "y": self.scale_y}, # scale images to 80-100% of their size 53 | translate_percent={"x": self.translate_x, "y": self.translate_y}, # translate x by 0 to +10 percent, y by -10 to +10 percent 54 | rotate=(self.rotate), # rotate by -10 to +10 degrees 55 | )) 56 | if self.rd(2) == self.key: 57 | sequence.append(iaa.SomeOf((1, self.choice), # append, not extend: SomeOf is a single augmenter 58 | [ 59 | iaa.OneOf([ 60 | iaa.GaussianBlur(self.Gaussian_blur), # blur images with a sigma between 0 and 3.0 61 | # iaa.AverageBlur(k=(2, 7)), # blur images using local means with kernel size 2-7 62 | # iaa.MedianBlur(k=(3, 11)) # blur images using local medians with kernel size 3-11 63 | ]), 64 | # iaa.Sharpen(alpha=self.alpha, lightness=self.lightness), # sharpen images 65 | # iaa.LinearContrast(self.linear_contrast, per_channel=0.5), # increase or decrease contrast 66 | # iaa.Add(self.brightness, per_channel=0.5), # change brightness 67 | # iaa.AdditiveGaussianNoise(loc=0, scale=0.1, per_channel=0.5) # add gaussian noise 68 | ])) 69 | return sequence 70 | 71 | 72 | def show_aug(image): 73 | plt.figure(figsize=(10, 10), facecolor="#FFFFFF") 74 | for i in range(1, len(image)+1): 75 | plt.subplot(len(image), 1, i) 76 | plt.imshow(image[i-1]) 77 | plt.show() 78 |
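A usage sketch for ImageAugment (the image path is hypothetical; any RGB image works):

import numpy as np
from PIL import Image
from tools.image_aug import ImageAugment, show_aug

aug = ImageAugment()          # augmentation parameters are sampled once, at construction
seq = aug.aug_sequence()      # randomized imgaug Sequential built from aug_function()
img = np.asarray(Image.open('example.jpg').convert('RGB'))  # (H, W, C), one image at a time
img_aug = aug.aug(img, seq)
show_aug([img, img_aug])      # plot the original and augmented images stacked vertically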
-------------------------------------------------------------------------------- /tools/prepare_things.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.distributed as dist 4 | from collections import defaultdict, deque 5 | from torch.utils.data import DataLoader 6 | from prefetch_generator import BackgroundGenerator 7 | 8 | 9 | def init_distributed_mode(args): 10 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 11 | args.rank = int(os.environ["RANK"]) 12 | args.world_size = int(os.environ['WORLD_SIZE']) 13 | args.gpu = int(os.environ['LOCAL_RANK']) 14 | elif 'SLURM_PROCID' in os.environ: 15 | args.rank = int(os.environ['SLURM_PROCID']) 16 | args.gpu = args.rank % torch.cuda.device_count() 17 | else: 18 | print('Not using distributed mode') 19 | args.distributed = False 20 | return 21 | 22 | args.distributed = True 23 | 24 | torch.cuda.set_device(args.gpu) 25 | args.dist_backend = 'nccl' 26 | print('| distributed init (rank {}): {}'.format( 27 | args.rank, args.dist_url), flush=True) 28 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 29 | world_size=args.world_size, rank=args.rank) 30 | torch.distributed.barrier() 31 | setup_for_distributed(args.rank == 0) 32 | 33 | 34 | def setup_for_distributed(is_master): 35 | """ 36 | This function disables printing when not in the master process 37 | """ 38 | import builtins as __builtin__ 39 | builtin_print = __builtin__.print 40 | 41 | def print(*args, **kwargs): 42 | force = kwargs.pop('force', False) 43 | if is_master or force: 44 | builtin_print(*args, **kwargs) 45 | 46 | __builtin__.print = print 47
| 48 | 49 | def is_dist_avail_and_initialized(): 50 | if not dist.is_available(): 51 | return False 52 | if not dist.is_initialized(): 53 | return False 54 | return True 55 | 56 | 57 | def get_world_size(): 58 | if not is_dist_avail_and_initialized(): 59 | return 1 60 | return dist.get_world_size() 61 | 62 | 63 | def get_rank(): 64 | if not is_dist_avail_and_initialized(): 65 | return 0 66 | return dist.get_rank() 67 | 68 | 69 | def is_main_process(): 70 | return get_rank() == 0 71 | 72 | 73 | def save_on_master(*args, **kwargs): 74 | if is_main_process(): 75 | torch.save(*args, **kwargs) 76 | 77 | 78 | class SmoothedValue(object): 79 | """Track a series of values and provide access to smoothed values over a 80 | window or the global series average. 81 | """ 82 | 83 | def __init__(self, window_size=20, fmt=None): 84 | if fmt is None: 85 | fmt = "{median:.4f} ({global_avg:.4f})" 86 | self.deque = deque(maxlen=window_size) 87 | self.total = 0.0 88 | self.count = 0 89 | self.fmt = fmt 90 | 91 | def update(self, value, n=1): 92 | self.deque.append(value) 93 | self.count += n 94 | self.total += value * n 95 | 96 | def synchronize_between_processes(self): 97 | """ 98 | Warning: does not synchronize the deque! 99 | """ 100 | if not is_dist_avail_and_initialized(): 101 | return 102 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 103 | dist.barrier() 104 | dist.all_reduce(t) 105 | t = t.tolist() 106 | self.count = int(t[0]) 107 | self.total = t[1] 108 | 109 | @property 110 | def median(self): 111 | d = torch.tensor(list(self.deque)) 112 | return d.median().item() 113 | 114 | @property 115 | def avg(self): 116 | d = torch.tensor(list(self.deque), dtype=torch.float32) 117 | return d.mean().item() 118 | 119 | @property 120 | def global_avg(self): 121 | return self.total / self.count 122 | 123 | @property 124 | def max(self): 125 | return max(self.deque) 126 | 127 | @property 128 | def value(self): 129 | return self.deque[-1] 130 | 131 | def __str__(self): 132 | return self.fmt.format( 133 | median=self.median, 134 | avg=self.avg, 135 | global_avg=self.global_avg, 136 | max=self.max, 137 | value=self.value) 138 | 139 | 140 | class DataLoaderX(DataLoader): 141 | def __iter__(self): 142 | return BackgroundGenerator(super().__iter__()) 143 | 144 | 145 | def get_name(root, mode_folder=True): 146 | for root, dirs, file in os.walk(root): 147 | if mode_folder: 148 | return sorted(dirs) 149 | else: 150 | return sorted(file) -------------------------------------------------------------------------------- /torchcam/IBA/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbw520/scouter/5885b821681daf8c2263975490b4c6418687277b/torchcam/IBA/__init__.py -------------------------------------------------------------------------------- /torchcam/__init__.py: -------------------------------------------------------------------------------- 1 | from torchcam import cams, utils 2 | 3 | try: 4 | from .version import __version__ # noqa: F401 5 | except ImportError: 6 | pass 7 | -------------------------------------------------------------------------------- /torchcam/cams/__init__.py: -------------------------------------------------------------------------------- 1 | from .cam import * 2 | from .gradcam import * 3 | 4 | del cam 5 | del gradcam 6 | -------------------------------------------------------------------------------- /torchcam/utils.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Utils 6 | """ 7 | 8 | import numpy as np 9 | from matplotlib import cm 10 | from PIL import Image 11 | 12 | 13 | def overlay_mask(img, mask, colormap='jet', alpha=0.7): 14 | """Overlay a colormapped mask on a background image 15 | 16 | Args: 17 | img (PIL.Image.Image): background image 18 | mask (PIL.Image.Image): mask to be overlaid, in grayscale 19 | colormap (str, optional): colormap to be applied on the mask 20 | alpha (float, optional): transparency of the background image 21 | 22 | Returns: 23 | PIL.Image.Image: overlaid image 24 | """ 25 | 26 | if not isinstance(img, Image.Image) or not isinstance(mask, Image.Image): 27 | raise TypeError('img and mask arguments need to be PIL.Image') 28 | 29 | if not isinstance(alpha, float) or alpha < 0 or alpha >= 1: 30 | raise ValueError('alpha argument is expected to be a float in [0, 1)') 31 | 32 | cmap = cm.get_cmap(colormap) 33 | # Resize mask and apply colormap (keep the RGB channels of the RGBA output, drop alpha) 34 | overlay = mask.resize(img.size, resample=Image.BICUBIC) 35 | overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8) 36 | # Overlay the image with the mask 37 | overlayed_img = Image.fromarray((alpha * np.asarray(img) + (1 - alpha) * overlay).astype(np.uint8)) 38 | 39 | return overlayed_img 40 | -------------------------------------------------------------------------------- /torchcam/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.2a0+6dd7a75' 2 | -------------------------------------------------------------------------------- /torchray/VERSION: -------------------------------------------------------------------------------- 1 | 1.0.0 2 | -------------------------------------------------------------------------------- /torchray/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import importlib.resources as resources 3 | except ImportError: 4 | import importlib_resources as resources 5 | 6 | with resources.open_text('torchray', 'VERSION') as f: 7 | __version__ = f.readlines()[0].rstrip() 8 | -------------------------------------------------------------------------------- /torchray/attribution/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wbw520/scouter/5885b821681daf8c2263975490b4c6418687277b/torchray/attribution/__init__.py -------------------------------------------------------------------------------- /torchray/attribution/deconvnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module implements the *deconvolution* method of [DECONV]_ for visualizing 5 | deep networks. The simplest interface is given by the :func:`deconvnet` 6 | function: 7 | 8 | .. literalinclude:: ../examples/deconvnet.py 9 | :language: python 10 | :linenos: 11 | 12 | Alternatively, it is possible to run the method "manually". DeConvNet is a 13 | backpropagation method, and thus works by changing the definition of the 14 | backward functions of some layers. The modified ReLU is implemented by class 15 | :class:`DeConvNetReLU`; however, this is rarely used directly; instead, one 16 | uses the :class:`DeConvNetContext` context, as follows: 17 |
18 | .. literalinclude:: ../examples/deconvnet_manual.py 19 | :language: python 20 | :linenos: 21 | 22 | See also :ref:`Backpropagation methods <backprop>` for further examples 23 | and discussion. 24 | 25 | Theory 26 | ~~~~~~ 27 | 28 | The only change is a modified definition of the backward ReLU function: 29 | 30 | .. math:: 31 | \operatorname{ReLU}^*(x,p) = 32 | \begin{cases} 33 | p, & \mathrm{if}~ p > 0,\\ 34 | 0, & \mathrm{otherwise} \\ 35 | \end{cases} 36 | 37 | Warning: 38 | 39 | DeConvNets are defined for "standard" networks that use ReLU operations. 40 | Further modifications may be required for more complex or new networks 41 | that use other types of non-linearities. 42 | 43 | References: 44 | 45 | .. [DECONV] Zeiler and Fergus, 46 | *Visualizing and Understanding Convolutional Networks*, 47 | ECCV 2014, 48 | `<https://arxiv.org/abs/1311.2901>`__. 49 | """ 50 | 51 | __all__ = ["DeConvNetContext", "deconvnet"] 52 | 53 | import torch 54 | 55 | from .common import ReLUContext, saliency 56 | 57 | 58 | class DeConvNetReLU(torch.autograd.Function): 59 | """DeConvNet ReLU autograd function. 60 | 61 | This is an autograd function that redefines the ``relu`` function 62 | to match the DeConvNet ReLU definition. 63 | """ 64 | 65 | @staticmethod 66 | def forward(ctx, input): 67 | """DeConvNet ReLU forward function.""" 68 | return input.clamp(min=0) 69 | 70 | @staticmethod 71 | def backward(ctx, grad_output): 72 | """DeConvNet ReLU backward function.""" 73 | return grad_output.clamp(min=0) 74 | 75 | 76 | class DeConvNetContext(ReLUContext): 77 | """DeConvNet context. 78 | 79 | This context modifies the computation of gradient to match the DeConvNet 80 | definition. 81 | 82 | See :mod:`torchray.attribution.deconvnet` for how to use it. 83 | """ 84 | 85 | def __init__(self): 86 | super(DeConvNetContext, self).__init__(DeConvNetReLU) 87 | 88 | 89 | def deconvnet(*args, context_builder=DeConvNetContext, **kwargs): 90 | """DeConvNet method. 91 | 92 | The function takes the same arguments as :func:`.common.saliency`, with 93 | the defaults required to apply the DeConvNet method, and supports the 94 | same arguments and return values. 95 | """ 96 | return saliency(*args, context_builder=context_builder, **kwargs) 97 | -------------------------------------------------------------------------------- /torchray/attribution/grad_cam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """ 4 | This module provides an implementation of the *Grad-CAM* method of [GRADCAM]_ 5 | for saliency visualization. The simplest interface is given by the 6 | :func:`grad_cam` function: 7 | 8 | .. literalinclude:: ../examples/grad_cam.py 9 | :language: python 10 | :linenos: 11 | 12 | Alternatively, it is possible to run the method "manually". Grad-CAM backprop 13 | is a variant of the gradient method, applied at an intermediate layer: 14 | 15 | .. literalinclude:: ../examples/grad_cam_manual.py 16 | :language: python 17 | :linenos: 18 | 19 | Note that the function :func:`gradient_to_grad_cam_saliency` is used to convert 20 | activations and gradients to a saliency map. 21 | 22 | See also :ref:`backprop` for further examples and discussion. 23 | 24 | Theory 25 | ~~~~~~ 26 | 27 | Grad-CAM can be seen as a variant of the *gradient* method 28 | (:mod:`torchray.attribution.gradient`) with two differences: 29 | 30 | 1. The saliency is measured at an intermediate layer of the network, usually at 31 | the output of the last convolutional layer. 32 |
33 | 2. Saliency is defined as the clamped product of forward activation and 34 | backward gradient at that layer. 35 | 36 | References: 37 | 38 | .. [GRADCAM] Ramprasaath R. Selvaraju, Abhishek Das, Ramakrishna Vedantam, 39 | Michael Cogswell, Devi Parikh and Dhruv Batra, 40 | *Visual Explanations from Deep Networks via Gradient-based 41 | Localization,* 42 | ICCV 2017, 43 | `<https://arxiv.org/abs/1610.02391>`__. 44 | """ 45 | 46 | __all__ = ["grad_cam"] 47 | 48 | import torch 49 | from .common import saliency 50 | 51 | 52 | def gradient_to_grad_cam_saliency(x): 53 | r"""Convert activation and gradient to a Grad-CAM saliency map. 54 | 55 | The tensor :attr:`x` must have a valid gradient ``x.grad``. 56 | The function then computes the saliency map :math:`s` given by: 57 | 58 | .. math:: 59 | 60 | s_{n1u} = \max\{0, \sum_{c}x_{ncu}\cdot dx_{ncu}\} 61 | 62 | Args: 63 | x (:class:`torch.Tensor`): activation tensor with a valid gradient. 64 | 65 | Returns: 66 | :class:`torch.Tensor`: saliency map. 67 | """ 68 | # Apply global average pooling (GAP) to gradient. 69 | grad_weight = torch.mean(x.grad, (2, 3), keepdim=True) 70 | 71 | # Linearly combine activations and GAP gradient weights. 72 | saliency_map = torch.sum(x * grad_weight, 1, keepdim=True) 73 | 74 | # Apply ReLU to visualization. 75 | saliency_map = torch.clamp(saliency_map, min=0) 76 | 77 | return saliency_map 78 | 79 | 80 | def grad_cam(*args, 81 | saliency_layer, 82 | gradient_to_saliency=gradient_to_grad_cam_saliency, 83 | **kwargs): 84 | r"""Grad-CAM method. 85 | 86 | The function takes the same arguments as :func:`.common.saliency`, with 87 | the defaults required to apply the Grad-CAM method, and supports the 88 | same arguments and return values. 89 | """ 90 | return saliency(*args, 91 | saliency_layer=saliency_layer, 92 | gradient_to_saliency=gradient_to_saliency, 93 | **kwargs,) 94 | -------------------------------------------------------------------------------- /torchray/attribution/gradient.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module implements the *gradient* method of [GRAD]_ for visualizing a deep 5 | network. It is a backpropagation method, and in fact the simplest of them all 6 | as it coincides with standard backpropagation. The simplest way to use this 7 | method is via the :func:`gradient` function: 8 | 9 | .. literalinclude:: ../examples/gradient.py 10 | :language: python 11 | :linenos: 12 | 13 | Alternatively, one can do so manually, as follows: 14 | 15 | .. literalinclude:: ../examples/gradient_manual.py 16 | :language: python 17 | :linenos: 18 | 19 | Note that in this example, for visualization, the gradient is 20 | converted into an image by postprocessing using the function 21 | :func:`torchray.attribution.common.saliency`. 22 | 23 | See also :ref:`backprop` for further examples. 24 | 25 | References: 26 | 27 | .. [GRAD] Karen Simonyan, Andrea Vedaldi and Andrew Zisserman, 28 | *Deep Inside Convolutional Networks: 29 | Visualising Image Classification Models and Saliency Maps,* 30 | ICLR workshop, 2014, 31 | `<https://arxiv.org/abs/1312.6034>`__. 32 | """ 33 | 34 | __all__ = ["gradient"] 35 | 36 | from .common import saliency 37 | 38 | 39 | def gradient(*args, context_builder=None, **kwargs): 40 | r"""Gradient method 41 | 42 | The function takes the same arguments as :func:`.common.saliency`, with 43 | the defaults required to apply the gradient method, and supports the 44 | same arguments and return values.
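Example (a sketch mirroring the bundled ``examples/gradient.py``; it relies on
the helper functions from :mod:`torchray.benchmark` shown later in this listing)::

    from torchray.attribution.gradient import gradient
    from torchray.benchmark import get_example_data, plot_example

    model, x, category_id, _ = get_example_data()
    saliency = gradient(model, x, category_id)
    plot_example(x, saliency, 'gradient', category_id)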
45 | """ 46 | assert context_builder is None 47 | return saliency(*args, **kwargs) 48 | -------------------------------------------------------------------------------- /torchray/attribution/guided_backprop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module implements *guided backpropagation* method of [GUIDED]_ or 5 | visualizing deep networks. The simplest interface is given by the 6 | :func:`guided_backprop` function: 7 | 8 | .. literalinclude:: ../examples/guided_backprop.py 9 | :language: python 10 | :linenos: 11 | 12 | Alternatively, it is possible to run the method "manually". Guided backprop is 13 | a backpropagation method, and thus works by changing the definition of the 14 | backward functions of some layers. This can be done using the 15 | :class:`GuidedBackpropContext` context: 16 | 17 | .. literalinclude:: ../examples/guided_backprop_manual.py 18 | :language: python 19 | :linenos: 20 | 21 | See also :ref:`backprop` for further examples. 22 | 23 | Theory 24 | ~~~~~~ 25 | 26 | Guided backprop is a backpropagation method, and thus it works by changing the 27 | definition of the backward functions of some layers. The only change is a 28 | modified definition of the backward ReLU function: 29 | 30 | .. math:: 31 | \operatorname{ReLU}^*(x,p) = 32 | \begin{cases} 33 | p, & \mathrm{if}~p > 0 ~\mathrm{and}~ x > 0,\\ 34 | 0, & \mathrm{otherwise} \\ 35 | \end{cases} 36 | 37 | The modified ReLU is implemented by class :class:`GuidedBackpropReLU`. 38 | 39 | References: 40 | 41 | .. [GUIDED] Springenberg et al., 42 | *Striving for simplicity: The all convolutional net*, 43 | ICLR Workshop 2015, 44 | ``__. 45 | """ 46 | 47 | __all__ = ['GuidedBackpropContext', 'guided_backprop'] 48 | 49 | import torch 50 | 51 | from .common import ReLUContext, saliency 52 | 53 | 54 | class GuidedBackpropReLU(torch.autograd.Function): 55 | """This class implements a ReLU function with the guided backprop rules.""" 56 | @staticmethod 57 | def forward(ctx, input): 58 | """Guided backprop ReLU forward function.""" 59 | ctx.save_for_backward(input) 60 | return input.clamp(min=0) 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | """Guided backprop ReLU backward function.""" 65 | input, = ctx.saved_tensors 66 | grad_input = grad_output.clone() 67 | grad_input[input < 0] = 0 68 | grad_input = grad_input.clamp(min=0) 69 | return grad_input 70 | 71 | 72 | class GuidedBackpropContext(ReLUContext): 73 | r"""GuidedBackprop context. 74 | 75 | This context modifies the computation of gradients 76 | to match the guided backpropagaton definition. 77 | 78 | See :mod:`torchray.attribution.guided_backprop` for how to use it. 79 | """ 80 | 81 | def __init__(self): 82 | super(GuidedBackpropContext, self).__init__(GuidedBackpropReLU) 83 | 84 | 85 | def guided_backprop(*args, context_builder=GuidedBackpropContext, **kwargs): 86 | r"""Guided backprop. 87 | 88 | The function takes the same arguments as :func:`.common.saliency`, with 89 | the defaults required to apply the guided backprop method, and supports the 90 | same arguments and return values. 91 | """ 92 | return saliency(*args, 93 | context_builder=context_builder, 94 | **kwargs) 95 | -------------------------------------------------------------------------------- /torchray/attribution/linear_approx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module provides an implementation of the *linear approximation* method 5 | for saliency visualization. The simplest interface is given by the 6 | :func:`linear_approx` function: 7 | 8 | .. literalinclude:: ../examples/linear_approx.py 9 | :language: python 10 | :linenos: 11 | 12 | Alternatively, it is possible to run the method "manually". Linear 13 | approximation is a variant of the gradient method, applied at an intermediate 14 | layer: 15 | 16 | .. literalinclude:: ../examples/linear_approx_manual.py 17 | :language: python 18 | :linenos: 19 | 20 | Note that the function :func:`gradient_to_linear_approx_saliency` is used to 21 | convert activations and gradients to a saliency map. 22 | """ 23 | 24 | __all__ = ['gradient_to_linear_approx_saliency', 'linear_approx'] 25 | 26 | 27 | import torch 28 | from .common import saliency 29 | 30 | 31 | def gradient_to_linear_approx_saliency(x): 32 | """Returns the linear approximation of a tensor. 33 | 34 | The tensor :attr:`x` must have a valid gradient ``x.grad``. 35 | The function then computes the saliency map :math:`s`: given by: 36 | 37 | .. math:: 38 | 39 | s_{n1u} = \sum_{c} x_{ncu} \cdot dx_{ncu} 40 | 41 | Args: 42 | x (:class:`torch.Tensor`): activation tensor with a valid gradient. 43 | 44 | Returns: 45 | :class:`torch.Tensor`: Saliency map. 46 | """ 47 | viz = torch.sum(x * x.grad, 1, keepdim=True) 48 | return viz 49 | 50 | 51 | def linear_approx(*args, 52 | gradient_to_saliency=gradient_to_linear_approx_saliency, 53 | **kwargs): 54 | """Linear approximation. 55 | 56 | The function takes the same arguments as :func:`.common.saliency`, with 57 | the defaults required to apply the linear approximation method, and 58 | supports the same arguments and return values. 59 | """ 60 | return saliency(*args, 61 | gradient_to_saliency=gradient_to_saliency, 62 | **kwargs) 63 | -------------------------------------------------------------------------------- /torchray/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | r"""This script provides a few functions for getting and plotting example data. 2 | """ 3 | import os 4 | import torchvision 5 | from matplotlib import pyplot as plt 6 | 7 | from .datasets import * # noqa 8 | from .models import * # noqa 9 | 10 | 11 | def get_example_data(arch='vgg16', shape=224): 12 | """Get example data to demonstrate visualization techniques. 13 | 14 | Args: 15 | arch (str, optional): name of torchvision.models architecture. 16 | Default: ``'vgg16'``. 17 | shape (int or tuple of int, optional): shape to resize input image to. 18 | Default: ``224``. 19 | 20 | Returns: 21 | (:class:`torch.nn.Module`, :class:`torch.Tensor`, int, int): a tuple 22 | containing 23 | 24 | - a convolutional neural network model in evaluation mode. 25 | - a sample input tensor image. 26 | - the ImageNet category id of an object in the image. 27 | - the ImageNet category id of another object in the image. 28 | 29 | """ 30 | 31 | # Get a network pre-trained on ImageNet. 32 | model = torchvision.models.__dict__[arch](pretrained=True) 33 | 34 | # Switch to eval mode to make the visualization deterministic. 35 | model.eval() 36 | 37 | # We do not need grads for the parameters. 38 | for param in model.parameters(): 39 | param.requires_grad_(False) 40 | 41 | # Download an example image from wikimedia. 
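# Note: this step needs network access; if the Wikimedia URL below ever moves,
# any RGB photo that contains both a dog and a cat can be substituted.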
42 | import requests 43 | from io import BytesIO 44 | from PIL import Image 45 | 46 | url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/7/7f/Arthur_Heyer_-_Dog_and_Cats.jpg/592px-Arthur_Heyer_-_Dog_and_Cats.jpg' 47 | response = requests.get(url) 48 | img = Image.open(BytesIO(response.content)) 49 | 50 | # Pre-process the image and convert into a tensor 51 | transform = torchvision.transforms.Compose([ 52 | torchvision.transforms.Resize(shape), 53 | torchvision.transforms.CenterCrop(shape), 54 | torchvision.transforms.ToTensor(), 55 | torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 56 | std=[0.229, 0.224, 0.225]), 57 | ]) 58 | 59 | x = transform(img).unsqueeze(0) 60 | 61 | # bulldog category id. 62 | category_id_1 = 245 63 | 64 | # persian cat category id. 65 | category_id_2 = 285 66 | 67 | # Move model and input to device. 68 | from torchray.utils import get_device 69 | dev = get_device() 70 | model = model.to(dev) 71 | x = x.to(dev) 72 | 73 | return model, x, category_id_1, category_id_2 74 | 75 | 76 | def plot_example(input, 77 | saliency, 78 | method, 79 | category_id, 80 | show_plot=False, 81 | save_path=None): 82 | """Plot an example. 83 | 84 | Args: 85 | input (:class:`torch.Tensor`): 4D tensor containing input images. 86 | saliency (:class:`torch.Tensor`): 4D tensor containing saliency maps. 87 | method (str): name of saliency method. 88 | category_id (int): ID of ImageNet category. 89 | show_plot (bool, optional): If True, show plot. Default: ``False``. 90 | save_path (str, optional): Path to save figure to. Default: ``None``. 91 | """ 92 | from torchray.utils import imsc 93 | from torchray.benchmark.datasets import IMAGENET_CLASSES 94 | 95 | if isinstance(category_id, int): 96 | category_id = [category_id] 97 | 98 | batch_size = len(input) 99 | 100 | plt.clf() 101 | for i in range(batch_size): 102 | class_i = category_id[i % len(category_id)] 103 | 104 | plt.subplot(batch_size, 2, 1 + 2 * i) 105 | imsc(input[i]) 106 | plt.title('input image', fontsize=8) 107 | 108 | plt.subplot(batch_size, 2, 2 + 2 * i) 109 | imsc(saliency[i], interpolation='none') 110 | plt.title('{} for category {} ({})'.format( 111 | method, IMAGENET_CLASSES[class_i], class_i), fontsize=8) 112 | 113 | # Save figure if path is specified. 114 | if save_path: 115 | save_dir = os.path.dirname(os.path.abspath(save_path)) 116 | # Create directory if necessary. 117 | if not os.path.exists(save_dir): 118 | os.makedirs(save_dir) 119 | ext = os.path.splitext(save_path)[1].strip('.') 120 | plt.savefig(save_path, format=ext, bbox_inches='tight') 121 | 122 | # Show plot if desired. 123 | if show_plot: 124 | plt.show() 125 | -------------------------------------------------------------------------------- /torchray/benchmark/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | r""" 4 | This module is used to start and run a MongoDB server. 5 | 6 | To start a MongoDB server, use 7 | 8 | .. 
code:: shell 9 | 10 | $ python -m torchray.benchmark.server 11 | 12 | """ 13 | import subprocess 14 | from torchray.utils import get_config 15 | 16 | 17 | def run_server(): 18 | """Runs an instance of MongoDB as a logging server.""" 19 | config = get_config() 20 | command = [ 21 | config['mongo']['server'], 22 | '--dbpath', config['mongo']['database'], 23 | '--bind_ip', config['mongo']['hostname'], 24 | '--port', str(config['mongo']['port']) 25 | ] 26 | print(f"Command: {' '.join(command)}.") 27 | code = subprocess.call(command, cwd=".") 28 | print(f"Return code {code}") 29 | 30 | 31 | if __name__ == '__main__': 32 | run_server() 33 | --------------------------------------------------------------------------------