├── docs └── 深度学习tricks-我的一点个人见解.md ├── torch_utils ├── models │ ├── backbone │ │ ├── __init__.py │ │ └── timm_models.py │ ├── layers │ │ ├── __init__.py │ │ ├── anti_alias.py │ │ └── layers.py │ ├── utils │ │ ├── __init__.py │ │ └── freeze_norm.py │ ├── seg_models │ │ └── __init__.py │ ├── cls_models │ │ ├── __init__.py │ │ ├── hybrid_cls_model.py │ │ └── simple_cls_model.py │ └── __init__.py ├── optimizer │ ├── timm_optim │ │ ├── __init__.py │ │ ├── lars.py │ │ └── lamb.py │ ├── ranger21 │ │ ├── __init__.py │ │ ├── chebyshev_lr_functions.py │ │ └── rangerabel.py │ ├── __init__.py │ ├── group_optim.py │ ├── lookahead.py │ ├── over9000.py │ ├── ranger.py │ ├── radam.py │ └── gc.py ├── dataset │ ├── data_all_in_gpu.py │ ├── __init__.py │ ├── visualize.py │ ├── dataloader.py │ ├── random.py │ ├── del_duplicate_image.py │ ├── common_aug.py │ ├── randaugment.py │ ├── customized_aug.py │ └── mixup.py ├── advanced │ ├── __init__.py │ ├── dolg.py │ ├── arcface.py │ └── NextVLAD.py ├── lr_scheduler │ ├── __init__.py │ ├── concat.py │ ├── customized.py │ ├── onecycle.py │ └── cosine_annealing_with_warmup.py ├── __init__.py ├── criterion │ ├── __init__.py │ ├── metric_loss.py │ ├── rmi.py │ ├── focal.py │ ├── lovasz.py │ ├── dice.py │ ├── cross_entropy.py │ └── bitempered_loss.py └── tools.py ├── .flake8 ├── requirements.txt ├── tests └── unit_test │ ├── test_layers.py │ ├── test_models.py │ └── test_losses.py ├── .gitignore ├── .pre-commit-config.yaml ├── setup.py └── README.md /docs/深度学习tricks-我的一点个人见解.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torch_utils/models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .timm_models import create_timm_model 2 | -------------------------------------------------------------------------------- /torch_utils/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | from .anti_alias import * 3 | -------------------------------------------------------------------------------- /torch_utils/optimizer/timm_optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .lamb import Lamb 2 | from .lars import Lars 3 | -------------------------------------------------------------------------------- /torch_utils/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .freeze_norm import set_bn_eval, set_bn_train, freeze_bn, unfreeze_bn 2 | -------------------------------------------------------------------------------- /torch_utils/optimizer/ranger21/__init__.py: -------------------------------------------------------------------------------- 1 | from .ranger21 import Ranger21 2 | from .rangerabel import Ranger21abel 3 | -------------------------------------------------------------------------------- /torch_utils/models/seg_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import UNet, get_encoder_info, get_hrnet, get_unet, get_unet_ps 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 150 3 | per-file-ignores = __init__.py: F401, F403 4 | ignore = W293, W504, E126, W503, F401, 
F403, E722, F541, W291, E266 -------------------------------------------------------------------------------- /torch_utils/dataset/data_all_in_gpu.py: -------------------------------------------------------------------------------- 1 | # TODO 2 | # CUDA dataset + kornia data augmentation 3 | # CUDA Prefetcher Loader + fast_collate 4 | # kornia mixup and cutmix 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm>=0.4.5 2 | albumentations>=1.0.0 3 | pytorch-metric-learning>=0.9.99 4 | torchinfo>=0.1.1 5 | torch-lr-finder>=0.2.1 6 | imagehash>=4.1.0 7 | thop>=0.0.31 8 | -------------------------------------------------------------------------------- /torch_utils/models/cls_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .simple_cls_model import ImageModel, get_encoder_last_channel, get_conv_model 2 | from .hybrid_cls_model import HybridModel, get_hybrid_swin 3 | -------------------------------------------------------------------------------- /torch_utils/advanced/__init__.py: -------------------------------------------------------------------------------- 1 | from .NextVLAD import NeXtVLAD 2 | from .arcface import ArcMarginProduct, ArcFaceLoss, ArcMarginProduct_subcenter, ArcFaceLossAdaptiveMargin 3 | from .dolg import DolgNet 4 | -------------------------------------------------------------------------------- /torch_utils/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine_annealing_with_warmup import CosineAnnealingWarmupRestarts 2 | 3 | from torch.optim.lr_scheduler import OneCycleLR 4 | 5 | from .customized import get_scheduler, get_poly_scheduler, get_flat_anneal_scheduler 6 | -------------------------------------------------------------------------------- /torch_utils/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .mixup import Mixup, MixupDataset 2 | 3 | from .dataloader import PrefetchLoader 4 | 5 | from .randaugment import randAugment 6 | 7 | from .del_duplicate_image import delete_duplicate_imghash 8 | 9 | from .visualize import write_aug 10 | 11 | from .random import random 12 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .advanced import * 2 | from .criterion import * 3 | from .dataset import * 4 | from .lr_scheduler import * 5 | from .models import * 6 | from .optimizer import * 7 | 8 | from . 
import tools 9 | 10 | __version__ = '0.1.0' 11 | 12 | 13 | def get_version(): 14 | return __version__ 15 | -------------------------------------------------------------------------------- /torch_utils/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .ranger import Ranger 2 | from .radam import RAdam 3 | from .lookahead import Lookahead 4 | from .over9000 import RangerLars 5 | from .gc import SGD_GCC, SGD_GC, AdamW_GCC2 6 | from .ranger21 import Ranger21, Ranger21abel 7 | from .timm_optim import Lamb, Lars 8 | from .group_optim import get_params 9 | -------------------------------------------------------------------------------- /tests/unit_test/test_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_utils import layers 3 | from torch_utils import set_bn_eval 4 | 5 | 6 | class TestModel: 7 | feature = torch.rand(2, 32, 56, 56) 8 | 9 | def test_anti_alias(self): 10 | anti_alias = layers.Anti_Alias_Filter(32).eval() 11 | assert anti_alias(TestModel.feature).shape == (2, 32, 56, 56) 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | *.idea 4 | *.vscode 5 | *.__pycache__ 6 | 7 | 8 | # compilation and distribution 9 | __pycache__ 10 | *.egg 11 | *.egg-info 12 | _ext 13 | *.pyc 14 | *.so 15 | build/ 16 | dist/ 17 | wheels/ 18 | cocoapi* 19 | panopticapi* 20 | 21 | 22 | # ipython/jupyter notebooks 23 | **/.ipynb_checkpoints/ 24 | 25 | # Editor temporaries 26 | *.swn 27 | *.swo 28 | *.swp 29 | *~ 30 | 31 | exp/* 32 | -------------------------------------------------------------------------------- /torch_utils/optimizer/ranger21/chebyshev_lr_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # from https://arxiv.org/abs/2103.01338v1 4 | 5 | 6 | def cheb_steps(m, M, T): 7 | C, R = (M + m) / 2.0, (M - m) / 2.0 8 | thetas = (np.arange(T) + 0.5) / T * np.pi 9 | return 1.0 / (C - R * np.cos(thetas)) 10 | 11 | 12 | def cheb_perm(T): 13 | perm = np.array([0]) 14 | while len(perm) < T: 15 | perm = np.vstack([perm, 2 * len(perm) - 1 - perm]).T.flatten() 16 | return perm 17 | 18 | 19 | # steps = cheb_steps(0.1,1,8) 20 | # perm = cheb_perm(8) 21 | # schedule = steps[perm] 22 | -------------------------------------------------------------------------------- /torch_utils/models/utils/freeze_norm.py: -------------------------------------------------------------------------------- 1 | def set_bn_eval(m): 2 | classname = m.__class__.__name__ 3 | if classname.find('BatchNorm') != -1: 4 | m.eval() 5 | 6 | 7 | def set_bn_train(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('BatchNorm') != -1: 10 | m.train() 11 | 12 | 13 | def freeze_bn(model): 14 | for m in model.named_modules(): 15 | set_bn_eval(m[1]) 16 | 17 | 18 | def unfreeze_bn(model): 19 | for m in model.named_modules(): 20 | set_bn_train(m[1]) 21 | 22 | # usage: model.apply(freeze_bn) # this will freeze the bn in training process 23 | -------------------------------------------------------------------------------- /torch_utils/criterion/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import LabelSmoothingCrossEntropy, SmoothBCEwLogits 2 | from .cross_entropy import SoftTargetCrossEntropy, KLDivLosswSoftmax 3 | from 
.cross_entropy import topkLoss 4 | 5 | from .metric_loss import CircleLoss, ArcFaceLoss, SupConLoss 6 | from .metric_loss import InfoNCE 7 | from .metric_loss import CrossBatchMemory 8 | from .metric_loss import MoCo, SupConLoss_MoCo 9 | 10 | # seg losses 11 | from .cross_entropy import SoftBCEWithLogitsLoss, SoftCrossEntropyLoss 12 | from .lovasz import BinaryLovaszLoss, LovaszLoss 13 | from .focal import BinaryFocalLoss, FocalLoss 14 | from .bitempered_loss import BiTemperedLogisticLoss, BinaryBiTemperedLogisticLoss 15 | from .dice import DiceLoss, TverskyLoss 16 | from .rmi import RMILoss -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 |   - repo: https://gitlab.com/pycqa/flake8.git 3 |     rev: 3.8.3 4 |     hooks: 5 |       - id: flake8 6 |         args: [--max-line-length=150] 7 |   - repo: https://github.com/pre-commit/mirrors-yapf 8 |     rev: v0.30.0 9 |     hooks: 10 |       - id: yapf 11 |   - repo: https://github.com/pre-commit/pre-commit-hooks 12 |     rev: v3.1.0 13 |     hooks: 14 |       - id: trailing-whitespace 15 |       - id: check-yaml 16 |       - id: end-of-file-fixer 17 |       - id: requirements-txt-fixer 18 |       - id: double-quote-string-fixer 19 |       - id: check-merge-conflict 20 |       - id: fix-encoding-pragma 21 |         args: ["--remove"] 22 |       - id: mixed-line-ending 23 |         args: ["--fix=lf"] -------------------------------------------------------------------------------- /torch_utils/models/backbone/timm_models.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | popular_models = { 4 |     'resnest50d': 2048, 5 |     'resnetv2_50x1_bitm': 2048, 6 |     'swsl_resnext50_32x4d': 2048, 7 |     'densenet121': 1024, 8 |     'seresnext50_32x4d': 2048, 9 | } 10 | 11 | 12 | def create_timm_model(name, pretrained=True, num_classes=0, in_channel=3): 13 |     # when in_channel==1, we suggest manually modifying the first conv's weights by summing them over the channel dim 14 |     # timm's implementation circularly copies the RGB channel weights and rescales them (not ideal for all cases) 15 |     if num_classes: 16 |         model = timm.create_model(name, pretrained=pretrained, num_classes=num_classes, in_chans=in_channel) 17 |     else: 18 |         model = timm.create_model(name, features_only=True, pretrained=pretrained, in_chans=in_channel) 19 |     return model
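20 | 21 | # a minimal sketch of the in_channel==1 suggestion above; the first-conv attribute path is an assumption 22 | # and differs across architectures (e.g. model.conv1 for resnet-style backbones, model.stem.conv for others): 23 | # model = create_timm_model('resnet50', pretrained=True, in_channel=3) 24 | # w = model.conv1.weight.data  # shape (out_ch, 3, k, k), pretrained on RGB 25 | # model.conv1.weight.data = w.sum(dim=1, keepdim=True)  # shape (out_ch, 1, k, k); summing preserves the pretrained response scale 26 | # model.conv1.in_channels = 1  # metadata only; conv2d reads the weight shape at forward time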
-------------------------------------------------------------------------------- /torch_utils/models/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/TylerYep/torchinfo 2 | try: 3 |     from torchinfo import summary 4 | except: 5 |     print('[Warning] torchinfo not installed') 6 | 7 | # https://github.com/davidtvs/pytorch-lr-finder 8 | try: 9 |     from torch_lr_finder import LRFinder 10 | except: 11 |     print('[Warning] torch_lr_finder not installed') 12 | 13 | # TODO: https://github.com/Stonesjtu/pytorch_memlab 14 | 15 | # https://github.com/Lyken17/pytorch-OpCounter 16 | try: 17 |     import thop 18 |     import torch 19 | 20 |     def profile(model, input_shape=(1, 3, 224, 224)): 21 |         macs, params = thop.profile(model, inputs=(torch.randn(*input_shape), )) 22 |         return {'macs': macs, 'params': params} 23 | except: 24 |     print('[Warning] thop not installed') 25 | 26 | from . import layers 27 | from .utils import * 28 | from .backbone import * 29 | from .cls_models import * 30 | from .seg_models import * -------------------------------------------------------------------------------- /torch_utils/dataset/visualize.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | 4 | # visualize tools 5 | 6 | 7 | def visualize(**images): 8 |     """Plot images in one row.""" 9 |     n = len(images) 10 |     plt.figure(figsize=(16, 5)) 11 |     for i, (name, image) in enumerate(images.items()): 12 |         plt.subplot(1, n, i + 1) 13 |         plt.xticks([]) 14 |         plt.yticks([]) 15 |         plt.title(' '.join(name.split('_')).title()) 16 |         plt.imshow(image) 17 |     plt.show() 18 | 19 | 20 | def test_transform(img_path, transform): 21 |     img = cv2.imread(img_path) 22 |     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 23 |     img = transform(image=img)['image'] 24 |     visualize(image=img) 25 | 26 | 27 | def write_aug(img_path, transform, num=30): 28 |     img = cv2.imread(img_path) 29 |     for i in range(num): 30 |         t = transform(image=img)['image'] 31 |         cv2.imwrite('./aug/' + str(i) + '.jpg', t) -------------------------------------------------------------------------------- /torch_utils/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | CUDA Prefetcher Loader 3 | changed from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/loader.py 4 | """ 5 | import torch 6 | 7 | 8 | class PrefetchLoader: 9 | 10 |     def __init__(self, loader): 11 |         self.loader = loader 12 | 13 |     def __iter__(self): 14 |         stream = torch.cuda.Stream() 15 |         first = True 16 | 17 |         data, target = None, None 18 |         for next_data, next_target in self.loader: 19 |             with torch.cuda.stream(stream): 20 |                 next_data = next_data.cuda(non_blocking=True) 21 |                 next_target = next_target.cuda(non_blocking=True) 22 | 23 |             if not first: 24 |                 yield data, target 25 |             else: 26 |                 first = False 27 | 28 |             torch.cuda.current_stream().wait_stream(stream) 29 |             data = next_data 30 |             target = next_target 31 | 32 |         yield data, target -------------------------------------------------------------------------------- /torch_utils/dataset/random.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class random(): 5 |     """ 6 |     random functions in pytorch 7 |     """ 8 |     @staticmethod 9 |     def randint(low, high): 10 |         return int(torch.randint(low, high + 1, (1,)).numpy()) 11 | 12 |     @staticmethod 13 |     def random(): 14 |         return float(torch.rand(1)[0].numpy()) 15 | 16 |     @staticmethod 17 |     def uniform(low, high): 18 |         low = torch.FloatTensor([low]) 19 |         high = torch.FloatTensor([high]) 20 |         m = torch.distributions.uniform.Uniform(low, high) 21 |         sample = float(m.sample()[0].numpy()) 22 |         return sample 23 | 24 |     @staticmethod 25 |     def choice(samples): 26 |         idx = torch.randint(len(samples), (1,)) 27 |         return samples[idx] 28 | 29 |     @staticmethod 30 |     def choices(samples, k=1, replacement=True): 31 |         if replacement: 32 |             idxs = torch.randint(len(samples), (k,)) 33 |         else: 34 |             idxs = torch.randperm(len(samples))[:k] 35 |         result = [] 36 |         for idx in idxs: 37 |             result.append(samples[idx]) 38 |         return result -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | 5 | from setuptools import find_packages,
setup 6 | from torch_utils import get_version 7 | 8 | # Get the long description from the README file 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | with open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | 14 | def install_package(package): 15 | output = subprocess.check_output( 16 | [sys.executable, '-m', 'pip', 'install', package]) 17 | print(output.decode()) 18 | 19 | 20 | def load_package(requirements_path='requirements.txt'): 21 | requirements = [] 22 | with open(requirements_path, 'r') as f: 23 | for each in f.readlines(): 24 | requirements.append(each.strip()) 25 | return requirements 26 | 27 | 28 | setup(name='torch_utils', 29 | version=get_version(), 30 | description='(WIP)(Unofficial) PyTorch Utils', 31 | long_description=long_description, 32 | long_description_content_type='text/markdown', 33 | url='https://github.com/seefun/TorchUtils', 34 | author='See Fun', 35 | author_email='seefun@outlook.com', 36 | packages=find_packages(), 37 | install_requires=load_package('./requirements.txt'), 38 | include_package_data=True) 39 | -------------------------------------------------------------------------------- /torch_utils/models/cls_models/hybrid_cls_model.py: -------------------------------------------------------------------------------- 1 | # Hybrid Vision Transformer 2 | import torch.nn as nn 3 | import timm 4 | from timm.models.vision_transformer_hybrid import HybridEmbed 5 | 6 | 7 | class HybridModel(nn.Module): 8 | def __init__(self, vit='swin_base_patch4_window7_224', embedder='tf_efficientnet_b4_ns', 9 | classes=1, input_size=448, pretrained=True): 10 | super(HybridModel, self).__init__() 11 | self.vit = timm.create_model(vit, pretrained=pretrained) 12 | self.embedder = timm.create_model(embedder, features_only=True, out_indices=[2], pretrained=pretrained) 13 | self.vit.patch_embed = HybridEmbed(self.embedder, img_size=input_size, embed_dim=128) 14 | self.n_features = self.vit.head.in_features 15 | self.vit.head = nn.Linear(self.n_features, classes) 16 | 17 | def forward(self, images): 18 | features = self.vit.forward_features(images) 19 | x = self.vit.head(features) 20 | return x 21 | 22 | 23 | def get_hybrid_swin(swin_type='base', embedder='tf_efficientnet_b4_ns', classes=1, pretrained=True): 24 | # input size 448x448 25 | assert swin_type in ['large', 'base', 'small', 'tiny'] 26 | swin_name = 'swin_' + swin_type + '_patch4_window7_224' 27 | model = HybridModel(vit=swin_name, embedder=embedder, classes=classes, input_size=448, pretrained=pretrained) 28 | return model 29 | -------------------------------------------------------------------------------- /torch_utils/criterion/metric_loss.py: -------------------------------------------------------------------------------- 1 | # from https://kevinmusgrave.github.io/pytorch-metric-learning/losses/ 2 | 3 | from pytorch_metric_learning.losses import CircleLoss, ArcFaceLoss, SupConLoss 4 | from pytorch_metric_learning.losses import NTXentLoss, CrossBatchMemory 5 | 6 | # logist, embeddings = model_conv(input) 7 | # loss_func = ArcFaceLoss(num_classes, embedding_size, margin=28.6, scale=64).to(torch.device('cuda')) 8 | # loss_func = CircleLoss(m=0.4, gamma=80).to(torch.device('cuda')) 9 | # loss_func = SupConLoss(temperature=0.1).to(torch.device('cuda')) 10 | # metric_loss = loss_func(embeddings, labels) # in your training for-loop 11 | 12 | CircleLoss = CircleLoss 13 | ArcFaceLoss = ArcFaceLoss 14 | SupConLoss = SupConLoss 15 | 16 | InfoNCE = NTXentLoss 17 | 
CrossBatchMemory = CrossBatchMemory 18 | # CrossBatchMemory(loss, embedding_size, memory_size=1024, miner=None) 19 | 20 | 21 | class MoCo(CrossBatchMemory): 22 | def __init__(self, embedding_size, memory_size): 23 | super(MoCo, self).__init__(NTXentLoss(temperature=0.1), 24 | embedding_size, 25 | memory_size) 26 | 27 | 28 | class SupConLoss_MoCo(CrossBatchMemory): 29 | def __init__(self, embedding_size, memory_size): 30 | super(SupConLoss_MoCo, self).__init__(SupConLoss(temperature=0.1), 31 | embedding_size, 32 | memory_size) 33 | -------------------------------------------------------------------------------- /torch_utils/lr_scheduler/concat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, OneCycleLR 3 | 4 | 5 | class ConcatLR(torch.optim.lr_scheduler._LRScheduler): 6 | def __init__(self, optimizer, scheduler1, scheduler2, total_steps, pct_start=0.5, last_epoch=-1): 7 | self.scheduler1 = scheduler1 8 | self.scheduler2 = scheduler2 9 | self.step_start = float(pct_start * total_steps) - 1 10 | super(ConcatLR, self).__init__(optimizer, last_epoch) 11 | 12 | def step(self): 13 | if self.last_epoch <= self.step_start: 14 | self.scheduler1.step() 15 | else: 16 | self.scheduler2.step() 17 | super().step() 18 | 19 | def get_lr(self): 20 | if self.last_epoch <= self.step_start: 21 | return self.scheduler1.get_lr() 22 | else: 23 | return self.scheduler2.get_lr() 24 | 25 | # Example: (from https://github.com/mgrankin/over9000/blob/master/train.py) 26 | # 27 | # Attention !!!: torch.optim.lr_scheduler.SequentialLR is able to replace ConcatLR 28 | # 29 | # from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, OneCycleLR 30 | # 31 | # def d(x): 32 | # return 1 33 | # 34 | # if sched_type == 'flat_and_anneal': 35 | # dummy = LambdaLR(optimizer, d) 36 | # cosine = CosineAnnealingLR(optimizer, total_steps*(1-ann_start)) 37 | # scheduler = ConcatLR(optimizer, dummy, cosine, total_steps, ann_start) 38 | # else: 39 | # scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps, pct_start=0.3, 40 | # div_factor=10, cycle_momentum=True) 41 | -------------------------------------------------------------------------------- /torch_utils/tools.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | import shutil 7 | 8 | 9 | def set_gpus(gpus: str): 10 | ''' 11 | setting cuda devices. 12 | 13 | Example: 14 | >>> set_gpus('0,1,2,3') 15 | ''' 16 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus 17 | 18 | 19 | def seed_everything(seed=42, deterministic=True): 20 | ''' 21 | seed everything (os, np, torch and torch.cuda). 
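Note: os.environ['PYTHONHASHSEED'] set at runtime mainly affects subprocesses; export it before starting python if full hash determinism is required.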
22 |     deterministic = True : use deterministic algorithms to make experiments reproducible 23 |     deterministic = False: use cudnn.benchmark to speed up training 24 | 25 |     Example: 26 |         >>> seed_everything(42) 27 |     ''' 28 |     random.seed(seed) 29 |     os.environ['PYTHONHASHSEED'] = str(seed) 30 |     np.random.seed(seed) 31 |     torch.manual_seed(seed) 32 |     torch.cuda.manual_seed(seed) 33 |     torch.cuda.manual_seed_all(seed) 34 |     if deterministic: 35 |         torch.backends.cudnn.deterministic = True 36 |     else: 37 |         torch.backends.cudnn.benchmark = True 38 | 39 | 40 | def worker_init_fn(worker_id): 41 |     """ 42 |     use in a DataLoader to avoid the numpy random seed bug with multi-worker pytorch dataloaders 43 |     https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ 44 | 45 |     Example: 46 |         >>> DataLoader(dataset, batch_size=2, num_workers=4, worker_init_fn=worker_init_fn) 47 |         >>> # And also, at each epoch start, you should do: 48 |         >>> np.random.seed(initial_seed + epoch*999) 49 |     """ 50 |     np.random.seed(np.random.get_state()[1][0] + worker_id) 51 | 52 | 53 | def backup_folder(source='.', destination='../exp/exp1/src'): 54 |     shutil.copytree(source, destination) 55 | 56 | 57 | def backup_file(source='param.py', destination='../exp/exp1/param.py'): 58 |     shutil.copyfile(source, destination) -------------------------------------------------------------------------------- /torch_utils/optimizer/group_optim.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn.modules.batchnorm import _BatchNorm 3 | 4 | 5 | def get_params(model, key): 6 |     # conv weight 7 |     if key == 'conv_weight': 8 |         for m in model.named_modules(): 9 |             if isinstance(m[1], nn.Conv2d): 10 |                 yield m[1].weight 11 |     # bn weight 12 |     if key == 'bn_weight': 13 |         for m in model.named_modules(): 14 |             if isinstance(m[1], _BatchNorm): 15 |                 if m[1].weight is not None: 16 |                     yield m[1].weight 17 |             if isinstance(m[1], nn.GroupNorm): 18 |                 if m[1].weight is not None: 19 |                     yield m[1].weight 20 |     # all bias 21 |     if key == 'bias': 22 |         for m in model.named_modules(): 23 |             if isinstance(m[1], nn.Conv2d) or isinstance(m[1], _BatchNorm): 24 |                 if m[1].bias is not None: 25 |                     yield m[1].bias 26 |             if isinstance(m[1], nn.GroupNorm):  # Conv2d bias is already yielded above; checking it again would yield it twice 27 |                 if m[1].bias is not None: 28 |                     yield m[1].bias 29 | 30 | 31 | """ 32 | optimizer = torch.optim.SGD( 33 |     params=[ 34 |         { 35 |             'params': get_params(model_conv, key='conv_weight'), 36 |             'lr': 1 * LR, 37 |             'weight_decay': 1 * WD, 38 |         }, 39 |         { 40 |             'params': get_params(model_conv, key='bn_weight'), 41 |             'lr': 1 * LR, 42 |             'weight_decay': 0.1 * WD, 43 |         }, 44 |         { 45 |             'params': get_params(model_conv, key='bias'), 46 |             'lr': 2 * LR, 47 |             'weight_decay': 0.0, 48 |         }], 49 |     momentum=0.9) 50 | """ -------------------------------------------------------------------------------- /torch_utils/lr_scheduler/customized.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.lr_scheduler import LambdaLR 4 | 5 | 6 | def _poly_lr_scheduler(iter, max_iter, gamma=0.05, power=0.9): 7 |     if iter >= max_iter: 8 |         return gamma 9 |     new_lr = ((1 - float(iter) / max_iter) ** power) * (1 - gamma) + gamma 10 |     return new_lr 11 | 12 | 13 | def _flat_anneal_schedule(iter, max_iter, warmup_iter=0, decay_start=0.5, anneal='cos', gamma=0.05): 14 |     if iter >= max_iter: 15 |         return gamma 16 |     if warmup_iter and (iter < warmup_iter): 17 |         # warmup from 0 to 1 18 |         return iter /
warmup_iter 19 | if iter <= max_iter * decay_start: 20 | return 1 21 | if anneal == 'cos': 22 | decay_status = (iter / max_iter - decay_start) / (1 - decay_start) * math.pi 23 | new_lr = (math.cos(decay_status) + 1.0) / 2.0 24 | new_lr = new_lr * (1 - gamma) + gamma 25 | else: 26 | decay_status = (iter / max_iter - decay_start) / (1 - decay_start) 27 | new_lr = (1 - decay_status) * (1 - gamma) + gamma 28 | return new_lr 29 | 30 | 31 | def get_scheduler(optimizer, lmbda): 32 | # support multiple parameter groups 33 | num_param_groups = len(optimizer.param_groups) 34 | return LambdaLR(optimizer, lr_lambda=[lmbda] * num_param_groups) 35 | 36 | 37 | def get_poly_scheduler(optimizer, max_iter, gamma=0.05, power=0.9): 38 | def decay_lambda(iter): 39 | return _poly_lr_scheduler(iter, max_iter, gamma, power) 40 | return get_scheduler(optimizer, decay_lambda) 41 | 42 | 43 | def get_flat_anneal_scheduler(optimizer, max_iter, warmup_iter=0, decay_start=0.5, anneal='cos', gamma=0.05): 44 | def decay_lambda(iter): 45 | return _flat_anneal_schedule(iter, max_iter, warmup_iter, decay_start, anneal, gamma) 46 | return get_scheduler(optimizer, decay_lambda) 47 | -------------------------------------------------------------------------------- /tests/unit_test/test_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_utils import models 3 | from torch_utils import advanced 4 | 5 | 6 | class TestModel: 7 | inputs = torch.rand(1, 3, 224, 224) 8 | 9 | def test_cls_model(self): 10 | pooling = ['gem', 'concat', 'avg'] 11 | fc = ['multi-dropout', 'attention', '2layers', 'simple', None] 12 | 13 | for pool_i in pooling: 14 | model_conv = models.ImageModel(pooling=pool_i, pretrained=False).eval() 15 | model_conv(TestModel.inputs) 16 | for fc_i in fc: 17 | model_conv = models.ImageModel(fc=fc_i, pretrained=False).eval() 18 | model_conv(TestModel.inputs) 19 | 20 | def test_unet(self): 21 | model_conv = models.UNet(pretrained=False, aspp=True, hypercolumns=False, deepsupervision=True, clshead=True).eval() 22 | model_conv(TestModel.inputs) 23 | 24 | model_conv = models.UNet(pretrained=False, aspp=True, deepsupervision=True, clshead=True).eval() 25 | model_conv(TestModel.inputs) 26 | 27 | def test_unet_ps(self): 28 | model_conv = models.UNet(pretrained=False, neck='unet_ps', aspp=True, deepsupervision=True, clshead=True).eval() 29 | model_conv(TestModel.inputs) 30 | 31 | def test_hrnet(self): 32 | model_conv = models.UNet(pretrained=False, neck=None, hypercolumns=False, deepsupervision=True, clshead=True).eval() 33 | model_conv(TestModel.inputs) 34 | 35 | model_conv = models.UNet(pretrained=False, neck=None, deepsupervision=True, clshead=True).eval() 36 | model_conv(TestModel.inputs) 37 | 38 | def test_DolgNet(self): 39 | model_conv = advanced.DolgNet('resnet101', False, 224, 3, 512, 512).eval() 40 | model_conv(TestModel.inputs) 41 | 42 | def test_hybrid(self): 43 | model_conv = models.HybridModel(vit='swin_base_patch4_window7_224', 44 | embedder='tf_efficientnet_b4_ns', 45 | classes=1, input_size=448, pretrained=False) 46 | model_conv(torch.rand(1, 3, 448, 448)) 47 | -------------------------------------------------------------------------------- /torch_utils/models/layers/anti_alias.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Anti_Alias_Filter(nn.Module): 7 | """ adaptive low pass filter (anti-alias) used before 
downsampling (pooling or stride>1 conv) 8 | idea from Delving-Deeper-Into-Anti-Aliasing-in-ConvNets (BMVC2020 best paper) https://arxiv.org/pdf/2008.09604.pdf 9 | code modified from: https://github.com/MaureenZOU/Adaptive-anti-Aliasing/blob/master/models_lpf/layers/pasa.py 10 | """ 11 | 12 | def __init__(self, in_channels, kernel_size=3, groups=8): 13 | super(Anti_Alias_Filter, self).__init__() 14 | self.pad = nn.ReflectionPad2d(kernel_size // 2) 15 | self.kernel_size = kernel_size 16 | self.groups = groups 17 | assert in_channels % groups == 0 18 | 19 | self.conv = nn.Conv2d(in_channels, groups * kernel_size * kernel_size, kernel_size=kernel_size, stride=1, groups=groups, bias=False) 20 | self.bn = nn.BatchNorm2d(groups * kernel_size * kernel_size) 21 | self.softmax = nn.Softmax(dim=1) 22 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') 23 | 24 | def forward(self, x): 25 | sigma = self.conv(self.pad(x)) 26 | sigma = self.bn(sigma) 27 | sigma = self.softmax(sigma) 28 | 29 | n, c, h, w = sigma.shape 30 | 31 | sigma = sigma.reshape(n, 1, c, h * w) 32 | 33 | n, c, h, w = x.shape 34 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size).reshape((n, c, self.kernel_size * self.kernel_size, h * w)) 35 | 36 | n, c1, p, q = x.shape 37 | x = x.permute(1, 0, 2, 3).reshape(self.groups, c1 // self.groups, n, p, q).permute(2, 0, 1, 3, 4) 38 | 39 | n, c2, p, q = sigma.shape 40 | sigma = sigma.permute(2, 0, 1, 3).reshape((p // (self.kernel_size * self.kernel_size), 41 | self.kernel_size * self.kernel_size, n, c2, q)).permute(2, 0, 3, 1, 4) 42 | 43 | x = torch.sum(x * sigma, dim=3).reshape(n, c1, h, w) 44 | 45 | return x 46 | -------------------------------------------------------------------------------- /torch_utils/dataset/del_duplicate_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm.auto import tqdm 5 | from PIL import Image 6 | import imagehash 7 | 8 | funcs = [ 9 | imagehash.average_hash, 10 | imagehash.phash, 11 | imagehash.dhash, 12 | imagehash.whash, 13 | ] 14 | 15 | 16 | def delete_duplicate_imghash(imgpath_list, threshold=0.9, verbose=True, cuda=False): 17 | image_ids = [] 18 | hashes = [] 19 | 20 | for path in tqdm(imgpath_list): 21 | image = Image.open(path) 22 | # image_id = os.path.basename(path) 23 | image_id = path 24 | image_ids.append(image_id) 25 | hashes.append(np.array([f(image).hash for f in funcs]).reshape(256)) 26 | 27 | hashes_all = np.array(hashes) 28 | 29 | if cuda: 30 | hashes_all = torch.Tensor(hashes_all.astype(int)).cuda() 31 | sims = np.array([(hashes_all[i] == hashes_all).sum(axis=1).cpu().numpy() / 256 for i in range(hashes_all.shape[0])]) 32 | else: 33 | sims = np.array([np.sum((hashes_all[i] == hashes_all), axis=1) / 256 for i in range(hashes_all.shape[0])]) 34 | 35 | indices1 = np.where(sims > threshold) 36 | indices2 = np.where(indices1[0] != indices1[1]) 37 | image_ids1 = [image_ids[i] for i in indices1[0][indices2]] 38 | image_ids2 = [image_ids[i] for i in indices1[1][indices2]] 39 | dups = {tuple(sorted([image_id1, image_id2])): True for image_id1, image_id2 in zip(image_ids1, image_ids2)} 40 | print('found %d pairs of duplicates' % len(dups)) 41 | 42 | duplicate_image_ids = sorted(list(dups)) 43 | if verbose: 44 | for pair in duplicate_image_ids: 45 | print('found duplicate image pair:', pair) 46 | 47 | del_list = [] 48 | for pair in duplicate_image_ids: 49 | del1 = pair[0] not in del_list 50 | del2 = pair[1] not in 
del_list 51 | if del1 and del2: 52 | del_list.append(pair[1]) 53 | if verbose: 54 | print(pair[1], 'deleted') 55 | print('%d of duplicated images deleted' % len(del_list)) 56 | 57 | for del_img in del_list: 58 | imgpath_list.remove(del_img) 59 | 60 | return imgpath_list 61 | -------------------------------------------------------------------------------- /torch_utils/dataset/common_aug.py: -------------------------------------------------------------------------------- 1 | import albumentations 2 | from .randaugment import randAugment 3 | from albumentations import pytorch as AT 4 | 5 | IMAGE_SIZE = 512 6 | 7 | train_transform_randaug = albumentations.Compose([ 8 | albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE), 9 | albumentations.RandomRotate90(p=0.5), # albumentations.SafeRotate(border_mode=1, p=0.5), 10 | albumentations.Transpose(p=0.5), 11 | albumentations.Flip(p=0.5), 12 | randAugment(), 13 | albumentations.Normalize(), 14 | albumentations.CoarseDropout(max_holes=8, max_height=IMAGE_SIZE // 8, max_width=IMAGE_SIZE // 8, fill_value=0, p=0.25), 15 | AT.ToTensorV2(), 16 | ]) 17 | 18 | train_transform = albumentations.Compose([ 19 | albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE), 20 | albumentations.RandomRotate90(p=0.5), 21 | albumentations.Transpose(p=0.5), 22 | albumentations.Flip(p=0.5), 23 | albumentations.OneOf([ 24 | albumentations.RandomBrightnessContrast(0.2, 0.2, p=1), 25 | albumentations.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=20, val_shift_limit=15, p=1), 26 | ], p=0.5), 27 | albumentations.OneOf([ 28 | albumentations.ElasticTransform(alpha=1, sigma=20, alpha_affine=10), 29 | albumentations.GridDistortion(num_steps=6, distort_limit=0.1), 30 | albumentations.OpticalDistortion(distort_limit=0.05, shift_limit=0.05), 31 | ], p=0.25), 32 | albumentations.OneOf([ 33 | albumentations.CLAHE(), 34 | albumentations.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.3)), 35 | albumentations.Sharpen(alpha=(0.1, 0.3), lightness=(0.5, 1.0)), 36 | albumentations.GaussNoise((5, 30)), 37 | albumentations.ImageCompression(30, 90), 38 | ], p=0.5), 39 | albumentations.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.0625, rotate_limit=45, border_mode=1, p=0.5), 40 | albumentations.Normalize(), 41 | albumentations.CoarseDropout(max_holes=8, max_height=IMAGE_SIZE // 8, max_width=IMAGE_SIZE // 8, fill_value=0, p=0.25), 42 | AT.ToTensorV2(), 43 | ]) 44 | 45 | test_transform = albumentations.Compose([ 46 | albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE), 47 | albumentations.Normalize(), 48 | AT.ToTensorV2(), 49 | ]) 50 | -------------------------------------------------------------------------------- /torch_utils/optimizer/lookahead.py: -------------------------------------------------------------------------------- 1 | # https://github.com/alphadl/lookahead.pytorch 2 | 3 | from collections import defaultdict 4 | from itertools import chain 5 | from torch.optim import Optimizer 6 | import torch 7 | import warnings 8 | 9 | 10 | class Lookahead(Optimizer): 11 | def __init__(self, optimizer, k=5, alpha=0.5): 12 | self.optimizer = optimizer 13 | self.k = k 14 | self.alpha = alpha 15 | self.param_groups = self.optimizer.param_groups 16 | self.state = defaultdict(dict) 17 | self.fast_state = self.optimizer.state 18 | for group in self.param_groups: 19 | group["counter"] = 0 20 | 21 | def update(self, group): 22 | for fast in group["params"]: 23 | param_state = self.state[fast] 24 | if "slow_param" not in param_state: 25 | param_state["slow_param"] = torch.zeros_like(fast.data) 26 | 
param_state["slow_param"].copy_(fast.data) 27 | slow = param_state["slow_param"] 28 | slow += (fast.data - slow) * self.alpha 29 | fast.data.copy_(slow) 30 | 31 | def update_lookahead(self): 32 | for group in self.param_groups: 33 | self.update(group) 34 | 35 | def step(self, closure=None): 36 | loss = self.optimizer.step(closure) 37 | for group in self.param_groups: 38 | if group["counter"] == 0: 39 | self.update(group) 40 | group["counter"] += 1 41 | if group["counter"] >= self.k: 42 | group["counter"] = 0 43 | return loss 44 | 45 | def state_dict(self): 46 | fast_state_dict = self.optimizer.state_dict() 47 | slow_state = { 48 | (id(k) if isinstance(k, torch.Tensor) else k): v 49 | for k, v in self.state.items() 50 | } 51 | fast_state = fast_state_dict["state"] 52 | param_groups = fast_state_dict["param_groups"] 53 | return { 54 | "fast_state": fast_state, 55 | "slow_state": slow_state, 56 | "param_groups": param_groups, 57 | } 58 | 59 | def load_state_dict(self, state_dict): 60 | slow_state_dict = { 61 | "state": state_dict["slow_state"], 62 | "param_groups": state_dict["param_groups"], 63 | } 64 | fast_state_dict = { 65 | "state": state_dict["fast_state"], 66 | "param_groups": state_dict["param_groups"], 67 | } 68 | super(Lookahead, self).load_state_dict(slow_state_dict) 69 | self.optimizer.load_state_dict(fast_state_dict) 70 | self.fast_state = self.optimizer.state 71 | 72 | def add_param_group(self, param_group): 73 | param_group["counter"] = 0 74 | self.optimizer.add_param_group(param_group) 75 | -------------------------------------------------------------------------------- /torch_utils/dataset/randaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | RandAugment 3 | Paper: https://arxiv.org/abs/1909.13719 4 | Re-implement (changed) using albumentations by seefun 5 | """ 6 | 7 | import numpy as np 8 | import albumentations # albumentations >= 1.0.0 9 | 10 | 11 | def randAugment(N=2, M=4, p=1.0, mode="all", cut_out=False): 12 | """ 13 | Examples: 14 | >>> # M from 0 to 20 15 | >>> transforms = randAugment(N=3, M=8, p=0.8, mode='all', cut_out=False) 16 | """ 17 | # Magnitude(M) search space 18 | scale = np.linspace(0, 0.4, 20) 19 | translate = np.linspace(0, 0.4, 20) 20 | rot = np.linspace(0, 30, 20) 21 | shear_x = np.linspace(0, 20, 20) 22 | shear_y = np.linspace(0, 20, 20) 23 | contrast = np.linspace(0.0, 0.4, 20) 24 | bright = np.linspace(0.0, 0.4, 20) 25 | sat = np.linspace(0.0, 0.2, 20) 26 | hue = np.linspace(0.0, 0.2, 20) 27 | shar = np.linspace(0.0, 0.9, 20) 28 | blur = np.linspace(0, 0.2, 20) 29 | noise = np.linspace(0, 1, 20) 30 | cut = np.linspace(0, 0.6, 20) 31 | # Transformation search space 32 | Aug = [ # geometrical 33 | albumentations.Affine(scale=(1.0 - scale[M], 1.0 + scale[M]), p=p), 34 | albumentations.Affine(translate_percent=(-translate[M], translate[M]), p=p), 35 | albumentations.Affine(rotate=(-rot[M], rot[M]), p=p), 36 | albumentations.Affine(shear={'x': (-shear_x[M], shear_x[M])}, p=p), 37 | albumentations.Affine(shear={'y': (-shear_y[M], shear_y[M])}, p=p), 38 | # Color Based 39 | albumentations.RandomContrast(limit=contrast[M], p=p), 40 | albumentations.RandomBrightness(limit=bright[M], p=p), 41 | albumentations.ColorJitter(brightness=0.0, contrast=0.0, saturation=sat[M], hue=0.0, p=p), 42 | albumentations.ColorJitter(brightness=0.0, contrast=0.0, saturation=0.0, hue=hue[M], p=p), 43 | albumentations.Sharpen(alpha=(0.1, shar[M]), lightness=(0.5, 1.0), p=p), 44 | albumentations.core.composition.PerChannel( 
45 | albumentations.OneOf([ 46 | albumentations.MotionBlur(p=0.5), 47 | albumentations.MedianBlur(blur_limit=3, p=1), 48 | albumentations.Blur(blur_limit=3, p=1), ]), p=blur[M] * p), 49 | albumentations.GaussNoise(var_limit=(8.0 * noise[M], 64.0 * noise[M]), per_channel=True, p=p) 50 | ] 51 | # Sampling from the Transformation search space 52 | if mode == "geo": 53 | transforms = albumentations.SomeOf(Aug[0:5], N) 54 | elif mode == "color": 55 | transforms = albumentations.SomeOf(Aug[5:], N) 56 | else: 57 | transforms = albumentations.SomeOf(Aug, N) 58 | 59 | if cut_out: 60 | cut_trans = albumentations.OneOf([ 61 | albumentations.CoarseDropout(max_holes=8, max_height=16, max_width=16, fill_value=0, p=1), 62 | albumentations.GridDropout(ratio=cut[M], p=1), 63 | albumentations.Cutout(num_holes=8, max_h_size=16, max_w_size=16, p=1), 64 | ], p=cut[M]) 65 | transforms = albumentations.Compose([transforms, cut_trans]) 66 | 67 | return transforms 68 | -------------------------------------------------------------------------------- /torch_utils/lr_scheduler/onecycle.py: -------------------------------------------------------------------------------- 1 | # https://github.com/yurayli/fisheries-monitoring/blob/e82688464f0f0a5a038589ccb6957df9a764c697/fish_ssd/one_cycle_lr.py 2 | 3 | import numpy as np 4 | 5 | 6 | class OneCycleScheduler(object): 7 | 8 | def __init__(self, optimizer, epochs, train_loader, accumulation_steps=1, max_lr=3e-3, 9 | moms=(.95, .85), div_factor=25, sep_ratio=0.3, final_div=None): 10 | 11 | self.optimizer = optimizer 12 | 13 | if isinstance(max_lr, list) or isinstance(max_lr, tuple): 14 | if len(max_lr) != len(optimizer.param_groups): 15 | raise ValueError("expected {} max_lr, got {}".format( 16 | len(optimizer.param_groups), len(max_lr))) 17 | self.max_lrs = list(max_lr) 18 | self.init_lrs = [lr / div_factor for lr in self.max_lrs] 19 | else: 20 | self.max_lrs = [max_lr] * len(optimizer.param_groups) 21 | self.init_lrs = [max_lr / div_factor] * len(optimizer.param_groups) 22 | 23 | if final_div is None: 24 | final_div = div_factor * 1e4 25 | self.final_lrs = [lr / final_div for lr in self.max_lrs] 26 | self.moms = moms 27 | 28 | total_iteration = epochs * len(train_loader) // accumulation_steps 29 | self.up_iteration = int(total_iteration * sep_ratio) 30 | self.down_iteration = total_iteration - self.up_iteration 31 | 32 | self.curr_iter = 0 33 | self._assign_lr_mom(self.init_lrs, [moms[0]] * len(optimizer.param_groups)) 34 | 35 | def _assign_lr_mom(self, lrs, moms): 36 | for param_group, lr, mom in zip(self.optimizer.param_groups, lrs, moms): 37 | param_group['lr'] = lr 38 | param_group['betas'] = (mom, 0.999) 39 | 40 | def _annealing_cos(self, start, end, pct): 41 | cos_out = np.cos(np.pi * pct) + 1 42 | return end + (start - end) / 2 * cos_out 43 | 44 | def step(self): 45 | self.curr_iter += 1 46 | 47 | if self.curr_iter <= self.up_iteration: 48 | pct = self.curr_iter / self.up_iteration 49 | curr_lrs = [self._annealing_cos(min_lr, max_lr, pct) 50 | for min_lr, max_lr in zip(self.init_lrs, self.max_lrs)] 51 | curr_moms = [self._annealing_cos(self.moms[0], self.moms[1], pct) 52 | for _ in range(len(self.optimizer.param_groups))] 53 | else: 54 | pct = (self.curr_iter - self.up_iteration) / self.down_iteration 55 | curr_lrs = [self._annealing_cos(max_lr, final_lr, pct) 56 | for max_lr, final_lr in zip(self.max_lrs, self.final_lrs)] 57 | curr_moms = [self._annealing_cos(self.moms[1], self.moms[0], pct) 58 | for _ in range(len(self.optimizer.param_groups))] 59 | 60 | 
self._assign_lr_mom(curr_lrs, curr_moms) 61 | 62 | # the official implementation: 63 | # torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, total_steps=None, epochs=None, 64 | #                                     steps_per_epoch=None, pct_start=0.3, anneal_strategy='cos', 65 | #                                     cycle_momentum=True, base_momentum=0.85, max_momentum=0.95, 66 | #                                     div_factor=25.0, final_div_factor=10000.0, last_epoch=-1) -------------------------------------------------------------------------------- /torch_utils/models/cls_models/simple_cls_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.cuda.amp import autocast 4 | 5 | from torch_utils.models.layers import FastGlobalConcatPool2d, FastGlobalAvgPool2d, GeM_cw, MultiSampleDropoutFC, SEBlock 6 | from torch_utils.models import create_timm_model 7 | 8 | 9 | class ImageModel(nn.Module): 10 | 11 |     def __init__(self, 12 |                  name='resnest50d', 13 |                  pretrained=True, 14 |                  pooling='concat', 15 |                  fc='multi-dropout', 16 |                  num_feature=2048, 17 |                  classes=1, 18 |                  in_channel=3): 19 |         super(ImageModel, self).__init__() 20 |         self.model = create_timm_model(name, pretrained, in_channel=in_channel) 21 | 22 |         if pooling == 'concat': 23 |             self.pooling = FastGlobalConcatPool2d() 24 |             num_feature *= 2 25 |         elif pooling == 'gem': 26 |             self.pooling = GeM_cw(num_feature) 27 |         else: 28 |             self.pooling = FastGlobalAvgPool2d() 29 | 30 |         if fc == 'multi-dropout': 31 |             self.fc = nn.Sequential( 32 |                 MultiSampleDropoutFC(in_ch=num_feature, out_ch=classes)) 33 | 34 |         elif fc == 'attention': 35 |             self.fc = nn.Sequential( 36 |                 SEBlock(num_feature), 37 |                 MultiSampleDropoutFC(in_ch=num_feature, out_ch=classes)) 38 | 39 |         elif fc == 'dropout': 40 |             self.fc = nn.Sequential( 41 |                 nn.Dropout(0.25), 42 |                 nn.Linear(num_feature, classes, bias=True)) 43 | 44 |         elif fc == '2layers': 45 |             self.fc = nn.Sequential( 46 |                 nn.Linear(num_feature, 512, bias=False), 47 |                 nn.BatchNorm1d(512), 48 |                 nn.SiLU(inplace=True), 49 |                 nn.Dropout(), 50 |                 nn.Linear(512, classes, bias=True)) 51 | 52 |         else: 53 |             self.fc = nn.Linear(in_features=num_feature, out_features=classes, bias=True) 54 | 55 |     @autocast() 56 |     def forward(self, x): 57 |         feature_map = self.model(x)[-1] 58 |         embedding = self.pooling(feature_map) 59 |         logits = self.fc(embedding) 60 |         return logits, embedding 61 | 62 | 63 | def get_encoder_last_channel(name='resnest50d', verbose=True): 64 |     model = create_timm_model(name, pretrained=False).eval() 65 |     features = model(torch.rand(1, 3, 224, 224)) 66 |     if verbose: 67 |         for i, feat in enumerate(features): 68 |             print('Feature [%d], channel num: %d' % (i, feat.shape[1])) 69 |     return features[-1].shape[1] 70 | 71 | 72 | def get_conv_model(name='resnest50d', 73 |                    pretrained=True, 74 |                    pooling='avg', 75 |                    fc='multi-dropout', 76 |                    classes=1, 77 |                    in_channel=3): 78 |     encoder_last_channel = get_encoder_last_channel(name, verbose=False) 79 |     conv_model = ImageModel(name=name, 80 |                             pretrained=pretrained, 81 |                             pooling=pooling, 82 |                             fc=fc, 83 |                             num_feature=encoder_last_channel, 84 |                             classes=classes, 85 |                             in_channel=in_channel) 86 |     return conv_model
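87 | 88 | # usage sketch (illustrative; only uses names defined in this file): 89 | # model = get_conv_model('resnest50d', pretrained=False, pooling='concat', fc='multi-dropout', classes=10) 90 | # logits, embedding = model(torch.rand(2, 3, 224, 224))  # forward returns (logits, embedding)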
-------------------------------------------------------------------------------- /tests/unit_test/test_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_utils import criterion 3 | 4 | 5 | class TestBinaryLoss: 6 |     y_pred = torch.tensor([0.88, -0.10, 0.55]).view(1, 1, 1, -1) 7 |     y_true = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, 1, -1) 8 | 9 |     def test_bitempered_loss(self): 10 |         loss_fn = criterion.BiTemperedLogisticLoss() 11 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 12 |         print(loss) 13 | 14 |     def test_focal_loss(self): 15 |         loss_fn = criterion.BinaryFocalLoss() 16 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 17 |         print(loss) 18 | 19 |     def test_lovasz_loss(self): 20 |         loss_fn = criterion.BinaryLovaszLoss() 21 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 22 |         print(loss) 23 | 24 |     def test_smooth_bce(self): 25 |         loss_fn = criterion.SmoothBCEwLogits() 26 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 27 |         print(loss) 28 | 29 |     def test_kld(self): 30 |         loss_fn = criterion.KLDivLosswSoftmax() 31 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 32 |         print(loss) 33 | 34 |     def test_soft_target_ce(self): 35 |         loss_fn = criterion.SoftTargetCrossEntropy() 36 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 37 |         print(loss) 38 | 39 |     def test_toolbelt_ce_binary(self): 40 |         loss_fn = criterion.SoftBCEWithLogitsLoss() 41 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 42 |         print(loss) 43 | 44 |     def test_rmi(self): 45 |         loss_fn = criterion.RMILoss() 46 |         loss = loss_fn(torch.rand(1, 2, 32, 32), torch.rand(1, 2, 32, 32)) 47 |         print(loss) 48 | 49 |     def test_dice_binary(self): 50 |         loss_fn = criterion.DiceLoss('binary') 51 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 52 |         print(loss) 53 | 54 |     def test_tversky_binary(self): 55 |         loss_fn = criterion.TverskyLoss('binary') 56 |         loss = loss_fn(TestBinaryLoss.y_pred, TestBinaryLoss.y_true) 57 |         print(loss) 58 | 59 | 60 | class TestMultiLoss: 61 |     y_pred = torch.tensor([[+1, -1, -1, -1], 62 |                            [-1, +1, -1, -1], 63 |                            [-1, -1, +1, -1], 64 |                            [-1, -1, -1, +1]]).float() 65 |     y_true = torch.tensor([0, 1, 2, 3]).long() 66 | 67 |     def test_focal(self): 68 |         loss_fn = criterion.FocalLoss() 69 |         loss = loss_fn(TestMultiLoss.y_pred, TestMultiLoss.y_true) 70 |         print(loss) 71 | 72 |     def test_lovasz(self): 73 |         loss_fn = criterion.LovaszLoss() 74 |         loss = loss_fn(TestMultiLoss.y_pred, TestMultiLoss.y_true) 75 |         print(loss) 76 | 77 |     def test_smoothing(self): 78 |         loss_fn = criterion.LabelSmoothingCrossEntropy() 79 |         loss = loss_fn(TestMultiLoss.y_pred, TestMultiLoss.y_true) 80 |         print(loss) 81 | 82 |     def test_toolbelt_ce_multiply(self): 83 |         loss_fn = criterion.SoftCrossEntropyLoss() 84 |         loss = loss_fn(TestMultiLoss.y_pred, TestMultiLoss.y_true) 85 |         print(loss) 86 | 87 |     def test_dice_multiply(self): 88 |         loss_fn = criterion.DiceLoss('multiclass') 89 |         loss = loss_fn(TestMultiLoss.y_pred, TestMultiLoss.y_true) 90 |         print(loss) 91 | 92 |     def test_tversky_multiply(self): 93 |         loss_fn = criterion.TverskyLoss('multiclass') 94 |         loss = loss_fn(TestMultiLoss.y_pred, TestMultiLoss.y_true) 95 |         print(loss) -------------------------------------------------------------------------------- /torch_utils/dataset/customized_aug.py: -------------------------------------------------------------------------------- 1 | # compatible with albumentations 2 | # warning: this code uses np.random rather than torch.random (the same as albumentations); 3 | # use tu.tools.worker_init_fn to avoid the seed-inheritance bug in multi-worker pytorch dataloaders.
4 | # seefun 2021.10 5 | import math 6 | import numpy as np 7 | from albumentations import ImageOnlyTransform 8 | 9 | 10 | def generate_perlin_noise_2d(shape, res): 11 | assert len(shape) == 2 12 | assert len(res) == 2 13 | if shape[0] % res[0] != 0: 14 | res = [1, res[1]] 15 | if shape[1] % res[1] != 0: 16 | res = [res[0], 1] 17 | 18 | def f(t): 19 | return 6 * t**5 - 15 * t**4 + 10 * t**3 20 | 21 | delta = (res[0] / shape[0], res[1] / shape[1]) 22 | d = (shape[0] // res[0], shape[1] // res[1]) 23 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 24 | # Gradients 25 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) 26 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 27 | g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 28 | g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1) 29 | g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1) 30 | g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1) 31 | # Ramps 32 | n00 = np.sum(grid * g00, 2) 33 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2) 34 | n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2) 35 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2) 36 | # Interpolation 37 | t = f(grid) 38 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 39 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 40 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) 41 | 42 | 43 | def random_perlin_brightness(image, noise_strength=0, max_delta=0.4): 44 | # noise_strength in [0,1,2,....] 45 | noise_strength = 2**noise_strength 46 | h = image.shape[0] 47 | w = image.shape[1] 48 | h_gen = math.ceil(h / noise_strength) * noise_strength 49 | w_gen = math.ceil(w / noise_strength) * noise_strength 50 | noise = generate_perlin_noise_2d((h_gen, w_gen), (noise_strength, noise_strength)) 51 | if image.dtype == np.uint8: 52 | noise = noise[:h, :w] * 255 * max_delta 53 | noise = noise[:, :, np.newaxis] 54 | image = image.astype(np.float32) + noise 55 | image = np.clip(image, 0, 255) 56 | return image.astype(np.uint8) 57 | else: 58 | noise = noise[:h, :w] * max_delta 59 | noise = noise[:, :, np.newaxis] 60 | image = image.astype(np.float32) + noise 61 | image = np.clip(image, 0, 1) 62 | return image.astype(np.float32) 63 | 64 | 65 | class RandomBrightnessNoise(ImageOnlyTransform): 66 | """ Add random 2d perlin noise to the image 67 | Args: 68 | noise_strength (int): resolution of perlin noise from [0,1,2,3,...], 0,1,2 suggested, the bigger the stronger; 69 | max_delta (float): max brightness deleta in [0,1], the bigger the stronger; 70 | Targets: 71 | image 72 | Image types: 73 | uint8, float32 74 | """ 75 | 76 | def __init__( 77 | self, 78 | noise_strength=0, 79 | max_delta=0.4, 80 | always_apply=False, 81 | p=1.0, 82 | ): 83 | super(RandomBrightnessNoise, self).__init__(always_apply, p) 84 | self.noise_strength = noise_strength 85 | self.max_delta = max_delta 86 | 87 | def apply(self, image, **params): 88 | return random_perlin_brightness(image, self.noise_strength, self.max_delta) 89 | 90 | def get_transform_init_args_names(self): 91 | return ("noise_strength", "max_delta") 92 | -------------------------------------------------------------------------------- /torch_utils/lr_scheduler/cosine_annealing_with_warmup.py: -------------------------------------------------------------------------------- 1 | # https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup 2 | 3 | import math 4 | import torch 5 | from 
torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class CosineAnnealingWarmupRestarts(_LRScheduler): 9 | """ 10 | optimizer (Optimizer): Wrapped optimizer. 11 | first_cycle_steps (int): First cycle step size. 12 | cycle_mult(float): Cycle steps magnification. Default: -1. 13 | max_lr(float): First cycle's max learning rate. Default: 0.1. 14 | min_lr(float): Min learning rate. Default: 0.001. 15 | warmup_steps(int): Linear warmup step size. Default: 0. 16 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 17 | last_epoch (int): The index of last epoch. Default: -1. 18 | """ 19 | 20 | def __init__(self, 21 | optimizer: torch.optim.Optimizer, 22 | first_cycle_steps: int, 23 | cycle_mult: float = 1., 24 | max_lr: float = 0.1, 25 | min_lr: float = 0.001, 26 | warmup_steps: int = 0, 27 | gamma: float = 1., 28 | last_epoch: int = -1 29 | ): 30 | assert warmup_steps < first_cycle_steps 31 | 32 | self.first_cycle_steps = first_cycle_steps # first cycle step size 33 | self.cycle_mult = cycle_mult # cycle steps magnification 34 | self.base_max_lr = max_lr # first max learning rate 35 | self.max_lr = max_lr # max learning rate in the current cycle 36 | self.min_lr = min_lr # min learning rate 37 | self.warmup_steps = warmup_steps # warmup step size 38 | self.gamma = gamma # decrease rate of max learning rate by cycle 39 | 40 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 41 | self.cycle = 0 # cycle count 42 | self.step_in_cycle = last_epoch # step size of the current cycle 43 | 44 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 45 | 46 | # set learning rate min_lr 47 | self.init_lr() 48 | 49 | def init_lr(self): 50 | self.base_lrs = [] 51 | for param_group in self.optimizer.param_groups: 52 | param_group['lr'] = self.min_lr 53 | self.base_lrs.append(self.min_lr) 54 | 55 | def get_lr(self): 56 | if self.step_in_cycle == -1: 57 | return self.base_lrs 58 | elif self.step_in_cycle < self.warmup_steps: 59 | return [(self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] 60 | else: 61 | return [base_lr + (self.max_lr - base_lr) 62 | * (1 + math.cos(math.pi * (self.step_in_cycle - self.warmup_steps) 63 | / (self.cur_cycle_steps - self.warmup_steps))) / 2 64 | for base_lr in self.base_lrs] 65 | 66 | def step(self, epoch=None): 67 | if epoch is None: 68 | epoch = self.last_epoch + 1 69 | self.step_in_cycle = self.step_in_cycle + 1 70 | if self.step_in_cycle >= self.cur_cycle_steps: 71 | self.cycle += 1 72 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 73 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 74 | else: 75 | if epoch >= self.first_cycle_steps: 76 | if self.cycle_mult == 1.: 77 | self.step_in_cycle = epoch % self.first_cycle_steps 78 | self.cycle = epoch // self.first_cycle_steps 79 | else: 80 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 81 | self.cycle = n 82 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 83 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 84 | else: 85 | self.cur_cycle_steps = self.first_cycle_steps 86 | self.step_in_cycle = epoch 87 | 88 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 89 | self.last_epoch = math.floor(epoch) 90 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 91 | param_group['lr'] = 
lr 92 | 93 | # the official implementation: 94 | # torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1) 95 | -------------------------------------------------------------------------------- /torch_utils/advanced/dolg.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/dongkyuk/DOLG-pytorch/blob/main/model/dolg.py 2 | # reference: https://arxiv.org/pdf/2108.02927.pdf 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import timm 9 | from torch_utils.models.layers import GeM 10 | 11 | 12 | class MultiAtrous(nn.Module): 13 | def __init__(self, in_channel, out_channel, size, dilation_rates=[3, 6, 9]): 14 | super().__init__() 15 | self.dilated_convs = [ 16 | nn.Conv2d(in_channel, int(out_channel / 4), 17 | kernel_size=3, dilation=rate, padding=rate) 18 | for rate in dilation_rates 19 | ] 20 | self.gap_branch = nn.Sequential( 21 | nn.AdaptiveAvgPool2d(1), 22 | nn.Conv2d(in_channel, int(out_channel / 4), kernel_size=1), 23 | nn.ReLU(), 24 | nn.Upsample(size=(size, size), mode='bilinear', align_corners=False) 25 | ) 26 | self.dilated_convs.append(self.gap_branch) 27 | self.dilated_convs = nn.ModuleList(self.dilated_convs) 28 | 29 | def forward(self, x): 30 | local_feat = [] 31 | for dilated_conv in self.dilated_convs: 32 | local_feat.append(dilated_conv(x)) 33 | local_feat = torch.cat(local_feat, dim=1) 34 | return local_feat 35 | 36 | 37 | class DolgLocalBranch(nn.Module): 38 | def __init__(self, image_size, in_channel, out_channel, hidden_channel=2048): 39 | super().__init__() 40 | self.multi_atrous = MultiAtrous(in_channel, hidden_channel, size=int(image_size / 8)) 41 | self.conv1x1_1 = nn.Conv2d(hidden_channel, out_channel, kernel_size=1) 42 | self.conv1x1_2 = nn.Conv2d( 43 | out_channel, out_channel, kernel_size=1, bias=False) 44 | self.conv1x1_3 = nn.Conv2d(out_channel, out_channel, kernel_size=1) 45 | 46 | self.relu = nn.ReLU() 47 | self.bn = nn.BatchNorm2d(out_channel) 48 | self.softplus = nn.Softplus() 49 | 50 | def forward(self, x): 51 | local_feat = self.multi_atrous(x) 52 | 53 | local_feat = self.conv1x1_1(local_feat) 54 | local_feat = self.relu(local_feat) 55 | local_feat = self.conv1x1_2(local_feat) 56 | local_feat = self.bn(local_feat) 57 | 58 | attention_map = self.relu(local_feat) 59 | attention_map = self.conv1x1_3(attention_map) 60 | attention_map = self.softplus(attention_map) 61 | 62 | local_feat = F.normalize(local_feat, p=2, dim=1) 63 | local_feat = local_feat * attention_map 64 | 65 | return local_feat 66 | 67 | 68 | class OrthogonalFusion(nn.Module): 69 | def __init__(self): 70 | super().__init__() 71 | 72 | def forward(self, local_feat, global_feat): 73 | global_feat_norm = torch.norm(global_feat, p=2, dim=1) 74 | projection = torch.bmm(global_feat.unsqueeze(1), torch.flatten( 75 | local_feat, start_dim=2)) 76 | projection = torch.bmm(global_feat.unsqueeze( 77 | 2), projection).view(local_feat.size()) 78 | projection = projection / \ 79 | (global_feat_norm * global_feat_norm).view(-1, 1, 1, 1) 80 | orthogonal_comp = local_feat - projection 81 | global_feat = global_feat.unsqueeze(-1).unsqueeze(-1) 82 | return torch.cat([global_feat.expand(orthogonal_comp.size()), orthogonal_comp], dim=1) 83 | 84 | 85 | class DolgNet(nn.Module): 86 | def __init__(self, backbone, pretrained, image_size, input_dim, hidden_dim, output_dim): 87 | super().__init__() 88 | self.cnn = timm.create_model( 89 | 
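# Note: features_only=True makes timm return a list of intermediate feature maps
# instead of classification logits; out_indices=(2, 3) selects the stride-8 and
# stride-16 stages (for ResNet-family backbones), which feed the local branch and
# the global GeM branch below.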
backbone, 90 | pretrained=pretrained, 91 | features_only=True, 92 | in_chans=input_dim, 93 | out_indices=(2, 3) 94 | ) 95 | self.orthogonal_fusion = OrthogonalFusion() 96 | self.local_branch = DolgLocalBranch(image_size, 512, hidden_dim) # 512 is decided by model 97 | self.gap = nn.AdaptiveAvgPool2d(1) 98 | self.gem_pool = GeM() 99 | self.fc_1 = nn.Linear(1024, hidden_dim) # 1024 is decided by model 100 | self.fc_2 = nn.Linear(int(2 * hidden_dim), output_dim) 101 | 102 | def forward(self, x): 103 | output = self.cnn(x) 104 | 105 | local_feat = self.local_branch(output[0]) # ,hidden_channel,16,16 106 | global_feat = self.fc_1(self.gem_pool(output[1])) # ,1024 107 | 108 | feat = self.orthogonal_fusion(local_feat, global_feat) 109 | feat = self.gap(feat).squeeze() 110 | feat = self.fc_2(feat) 111 | 112 | return feat 113 | -------------------------------------------------------------------------------- /torch_utils/advanced/arcface.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import Parameter 7 | from torch.autograd import Variable 8 | 9 | 10 | __all__ = ["ArcMarginProduct", "ArcFaceLoss", "ArcMarginProduct_subcenter", "ArcFaceLossAdaptiveMargin"] 11 | 12 | 13 | class ArcMarginProduct(nn.Module): 14 | r"""Implement of large margin arc distance: : 15 | Args: 16 | in_features: size of each input sample 17 | out_features: size of each output sample 18 | s: norm of input feature 19 | m: margin 20 | cos(theta + m) 21 | """ 22 | 23 | def __init__(self, in_features, out_features): 24 | super(ArcMarginProduct, self).__init__() 25 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 26 | # nn.init.xavier_uniform_(self.weight) 27 | self.reset_parameters() 28 | 29 | def reset_parameters(self): 30 | stdv = 1. 
/ math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features):
        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
        return cosine


class ArcFaceLoss(nn.modules.Module):
    def __init__(self, s=30.0, m=0.5):
        super(ArcFaceLoss, self).__init__()
        self.classify_loss = nn.CrossEntropyLoss()
        self.s = s
        self.easy_margin = False
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, logits, labels, epoch=0):
        cosine = logits
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        one_hot = torch.zeros(cosine.size(), device=cosine.device)  # follow the logits' device instead of assuming CUDA
        one_hot.scatter_(1, labels.view(-1, 1).long(), 1)
        # torch.where semantics: out_i = x_i if condition_i else y_i
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        loss1 = self.classify_loss(output, labels)
        loss2 = self.classify_loss(cosine, labels)
        gamma = 1
        loss = (loss1 + gamma * loss2) / (1 + gamma)
        return loss


class DenseCrossEntropy(nn.Module):
    def forward(self, x, target):
        x = x.float()
        target = target.float()
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)

        loss = -logprobs * target
        loss = loss.sum(-1)
        return loss.mean()


class ArcMarginProduct_subcenter(nn.Module):
    # https://github.com/haqishen/Google-Landmark-Recognition-2020-3rd-Place-Solution/blob/main/models.py
    # https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123560715.pdf
    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1.
/ math.sqrt(self.weight.size(1)) 93 | self.weight.data.uniform_(-stdv, stdv) 94 | 95 | def forward(self, features): 96 | cosine_all = F.linear(F.normalize(features), F.normalize(self.weight)) 97 | cosine_all = cosine_all.view(-1, self.out_features, self.k) 98 | cosine, _ = torch.max(cosine_all, dim=2) 99 | return cosine 100 | 101 | 102 | class ArcFaceLossAdaptiveMargin(nn.Module): 103 | def __init__(self, margins, s=30.0): 104 | super().__init__() 105 | self.crit = DenseCrossEntropy() 106 | self.s = s 107 | self.margins = margins 108 | 109 | def forward(self, logits, labels, out_dim): 110 | ms = [] 111 | ms = self.margins[labels.cpu().numpy()] 112 | cos_m = torch.from_numpy(np.cos(ms)).float().cuda() 113 | sin_m = torch.from_numpy(np.sin(ms)).float().cuda() 114 | th = torch.from_numpy(np.cos(math.pi - ms)).float().cuda() 115 | mm = torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda() 116 | labels = F.one_hot(labels, out_dim).float() 117 | logits = logits.float() 118 | cosine = logits 119 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 120 | phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1) 121 | phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1)) 122 | output = (labels * phi) + ((1.0 - labels) * cosine) 123 | output *= self.s 124 | loss = self.crit(output, labels) 125 | return loss 126 | -------------------------------------------------------------------------------- /torch_utils/optimizer/timm_optim/lars.py: -------------------------------------------------------------------------------- 1 | """ PyTorch LARS / LARC Optimizer 2 | 3 | An implementation of LARS (SGD) + LARC in PyTorch 4 | 5 | Based on: 6 | * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 7 | * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py 8 | 9 | Additional cleanup and modifications to properly support PyTorch XLA. 10 | 11 | Copyright 2021 Ross Wightman 12 | """ 13 | import torch 14 | from torch.optim.optimizer import Optimizer 15 | 16 | 17 | class Lars(Optimizer): 18 | """ LARS for PyTorch 19 | 20 | Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf 21 | 22 | Args: 23 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups. 24 | lr (float, optional): learning rate (default: 1.0). 
25 | momentum (float, optional): momentum factor (default: 0) 26 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 27 | dampening (float, optional): dampening for momentum (default: 0) 28 | nesterov (bool, optional): enables Nesterov momentum (default: False) 29 | trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001) 30 | eps (float): eps for division denominator (default: 1e-8) 31 | trust_clip (bool): enable LARC trust ratio clipping (default: False) 32 | always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False) 33 | """ 34 | 35 | def __init__( 36 | self, 37 | params, 38 | lr=1.0, 39 | momentum=0, 40 | dampening=0, 41 | weight_decay=0, 42 | nesterov=False, 43 | trust_coeff=0.001, 44 | eps=1e-8, 45 | trust_clip=False, 46 | always_adapt=False, 47 | ): 48 | if lr < 0.0: 49 | raise ValueError(f"Invalid learning rate: {lr}") 50 | if momentum < 0.0: 51 | raise ValueError(f"Invalid momentum value: {momentum}") 52 | if weight_decay < 0.0: 53 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 54 | if nesterov and (momentum <= 0 or dampening != 0): 55 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 56 | 57 | defaults = dict( 58 | lr=lr, 59 | momentum=momentum, 60 | dampening=dampening, 61 | weight_decay=weight_decay, 62 | nesterov=nesterov, 63 | trust_coeff=trust_coeff, 64 | eps=eps, 65 | trust_clip=trust_clip, 66 | always_adapt=always_adapt, 67 | ) 68 | super().__init__(params, defaults) 69 | 70 | def __setstate__(self, state): 71 | super().__setstate__(state) 72 | for group in self.param_groups: 73 | group.setdefault("nesterov", False) 74 | 75 | @torch.no_grad() 76 | def step(self, closure=None): 77 | """Performs a single optimization step. 78 | 79 | Args: 80 | closure (callable, optional): A closure that reevaluates the model and returns the loss. 
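Example (an illustrative closure sketch; model, criterion, x and y are assumed to exist):

            def closure():
                optimizer.zero_grad()
                loss = criterion(model(x), y)
                loss.backward()
                return loss

            optimizer.step(closure)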
81 | """ 82 | loss = None 83 | if closure is not None: 84 | with torch.enable_grad(): 85 | loss = closure() 86 | 87 | device = self.param_groups[0]['params'][0].device 88 | one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly 89 | 90 | for group in self.param_groups: 91 | weight_decay = group['weight_decay'] 92 | momentum = group['momentum'] 93 | dampening = group['dampening'] 94 | nesterov = group['nesterov'] 95 | trust_coeff = group['trust_coeff'] 96 | eps = group['eps'] 97 | 98 | for p in group['params']: 99 | if p.grad is None: 100 | continue 101 | grad = p.grad 102 | 103 | # apply LARS LR adaptation, LARC clipping, weight decay 104 | # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py 105 | if weight_decay != 0 or group['always_adapt']: 106 | w_norm = p.norm(2.0) 107 | g_norm = grad.norm(2.0) 108 | trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps) 109 | # FIXME nested where required since logical and/or not working in PT XLA 110 | trust_ratio = torch.where( 111 | w_norm > 0, 112 | torch.where(g_norm > 0, trust_ratio, one_tensor), 113 | one_tensor, 114 | ) 115 | if group['trust_clip']: 116 | trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor) 117 | grad.add(p, alpha=weight_decay) 118 | grad.mul_(trust_ratio) 119 | 120 | # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 121 | if momentum != 0: 122 | param_state = self.state[p] 123 | if 'momentum_buffer' not in param_state: 124 | buf = param_state['momentum_buffer'] = torch.clone(grad).detach() 125 | else: 126 | buf = param_state['momentum_buffer'] 127 | buf.mul_(momentum).add_(grad, alpha=1. - dampening) 128 | if nesterov: 129 | grad = grad.add(buf, alpha=momentum) 130 | else: 131 | grad = buf 132 | 133 | p.add_(grad, alpha=-group['lr']) 134 | 135 | return loss 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # TorchUtils 3 | 4 | TorchUtils is a pytorch lib with several useful tools and training tricks. (Work In Progress) 5 | 6 | ## Install 7 | ```bash 8 | git clone https://github.com/seefun/TorchUtils.git 9 | cd TorchUtils 10 | ``` 11 | ```bash 12 | pip install -r requirements.txt 13 | pip install . 

## Import

```
import torch_utils as tu
```


## Seed All

```
SEED = 42
tu.tools.seed_everything(SEED)
```

## Data Augmentation

```
import albumentations
from albumentations import pytorch as AT
train_transform = albumentations.Compose([
    albumentations.Resize(IMAGE_SIZE, IMAGE_SIZE),
    albumentations.HorizontalFlip(p=0.5),
    tu.dataset.randAugment(image_size=IMAGE_SIZE, N=2, M=12, p=0.9, mode='all', cut_out=False),
    albumentations.Normalize(),
    albumentations.CoarseDropout(max_holes=8, max_height=IMAGE_SIZE // 8, max_width=IMAGE_SIZE // 8, fill_value=0, p=0.25),
    AT.ToTensorV2(),
])

mixup_dataset = tu.dataset.MixupDataset(dataset, alpha=0.2, prob=0.2, mixup_to_cutmix=0.25)
# 0.15 mixup and 0.05 cutmix
```

## Model

build models quickly with torch_utils:
```
model = tu.ImageModel(name='resnest50d', pretrained=True,
                      pooling='concat', fc='multi-dropout',
                      num_feature=2048, classes=1)
model.cuda()
```

using other libs along with torch_utils:
```
import timm

model = timm.create_model('tresnet_m', pretrained=True)
model.global_pool = tu.layers.FastGlobalConcatPool2d(flatten=True)
model.head = tu.layers.get_attention_fc(2048*2, 1)
model.cuda()
```

```
from pytorchcv.model_provider import get_model as ptcv_get_model

model = ptcv_get_model('seresnext50_32x4d', pretrained=True)
model.features.final_pool = tu.layers.GeM()
model.output = tu.layers.get_simple_fc(2048, 1)
model.cuda()
```

segmentation models:
```
hrnet = tu.get_hrnet(name='hrnet_w18', out_channel=1, pretrained=True).cuda()
unet = tu.get_unet(name='resnest50d', out_channel=1, aspp=False, pretrained=True).cuda()
```

recommended pretrained models:

- [ResNeSt](https://github.com/zhanghang1989/ResNeSt)
- SEResNext-50
- GPU-Efficient
- swsl_ResNeXt
- [BiT/ResNetV2](https://github.com/google-research/big_transfer)
- [TResNet](https://github.com/mrT23/TResNet)
- EfficientNet_ns
- [ResNext_WSL](https://github.com/facebookresearch/WSL-Images)
- MixNet
- SKNet
- [SGENet](https://github.com/implus/PytorchInsight)
- [HRNet](https://github.com/HRNet)
- Res2Net


recommended github repos:

- [pytorch-image-models(timm)](https://github.com/rwightman/pytorch-image-models)
- [imgclsmob(pytorchcv)](https://github.com/osmr/imgclsmob/tree/master/pytorch)
- [gen-efficientnet-pytorch](https://github.com/rwightman/gen-efficientnet-pytorch)
- [efficientnet-pytorch](https://github.com/lukemelas/EfficientNet-PyTorch)
- [pytorch-encoding](https://github.com/zhanghang1989/PyTorch-Encoding)
- [pretrained-models-pytorch](https://github.com/Cadene/pretrained-models.pytorch)

model utils:
```
# model summary
tu.summary(model, input_size=(batch_size, 3, 224, 224))
# macs and flops
tu.profile(model, input_shape=(batch_size, 3, 224, 224))

# 3 channels pretrained weights to 1 channel
weight_rgb = model.conv1.weight.data
weight_grey = weight_rgb.sum(dim=1, keepdim=True)
model.conv1 = nn.Conv2d(1, 64, kernel_size=xxx, stride=xxx, padding=xxx, bias=False)
model.conv1.weight.data = weight_grey

# 3 channels pretrained weights to 4
channel 123 | weight_rgb = model.conv1.weight.data 124 | weight_y = weight_rgb.mean(dim=1, keepdim=True) 125 | weight_rgby = torch.cat([weight_rgb,weight_y], axis=1) * 3 / 4 126 | model.conv1 = nn.Conv2d(4, 64, kernel_size=xxx, stride=xxx, padding=xxx, bias=False) 127 | model.conv1.weight.data = weight_rgby 128 | 129 | # 2D models to 3d models using ACSConv (advanced) 130 | ## using code in this repo: https://github.com/M3DV/ACSConv 131 | ``` 132 | 133 | 134 | ## Optimizer 135 | ``` 136 | optimizer_ranger = tu.Ranger(model_conv.parameters(), lr=LR) 137 | # optimizer = torch.optim.AdamW(model_conv.parameters(), lr=LR, weight_decay=2e-4) 138 | ``` 139 | 140 | 141 | ## Criterion 142 | ``` 143 | # for example: 144 | criterion = tu.SmoothBCEwLogits(smoothing=0.02) 145 | 146 | # criterion = tu.LabelSmoothingCrossEntropy() 147 | ``` 148 | 149 | 150 | ## Find LR 151 | ``` 152 | lr_finder = tu.LRFinder(model, optimizer, criterion, device="cuda") 153 | lr_finder.range_test(train_loader, end_lr=10, num_iter=500, accumulation_steps=1) 154 | lr_finder.plot() # to inspect the loss-learning rate graph 155 | lr_finder.reset() # to reset the model and optimizer to their initial state 156 | ``` 157 | 158 | 159 | ## LR Scheduler 160 | ``` 161 | scheduler = tu.get_flat_anneal_scheduler(optimizer, max_iter, warmup_iter=0, decay_start=0.5, anneal='cos', gamma=0.05) 162 | 163 | # scheduler = tu.CosineAnnealingWarmUpRestarts(optimizer, T_0=T, T_mult=1, eta_max=LR, T_up=0, gamma=0.05) 164 | # torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1) 165 | # torch.optim.lr_scheduler.OneCycleLR or tu.OneCycleScheduler 166 | ``` 167 | 168 | 169 | ## AMP 170 | 171 | Ref: https://pytorch.org/docs/master/notes/amp_examples.html 172 | 173 | 174 | ## TODO 175 | - [ ] add unit test for models 176 | - [x] add Hybrid Vision Transformer 177 | - [ ] channels_last 178 | - [ ] inplace_abn 179 | - [ ] grad-CAM 180 | - [ ] convert [paddle ssld model](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/docs/zh_CN/algorithm_introduction/ImageNet_models.md#ssld%E7%9F%A5%E8%AF%86%E8%92%B8%E9%A6%8F%E9%A2%84%E8%AE%AD%E7%BB%83%E6%A8%A1%E5%9E%8B) to pytorch 181 | -------------------------------------------------------------------------------- /torch_utils/criterion/rmi.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/RElbers/region-mutual-information-pytorch 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | EPSILON = 0.0005 7 | 8 | 9 | class RMILoss(nn.Module): 10 | """ 11 | PyTorch Module which calculates the Region Mutual Information loss (https://arxiv.org/abs/1910.12037). 12 | """ 13 | 14 | def __init__(self, 15 | with_logits=True, 16 | radius=3, 17 | bce_weight=0.5, 18 | downsampling_method='max', 19 | stride=3, 20 | use_log_trace=True, 21 | use_double_precision=True, 22 | epsilon=EPSILON): 23 | """ 24 | :param with_logits: 25 | If True, apply the sigmoid function to the prediction before calculating loss. 26 | :param radius: 27 | RMI radius. 28 | :param bce_weight: 29 | Weight of the binary cross entropy. Must be between 0 and 1. 30 | :param downsampling_method: 31 | Downsampling method used before calculating RMI. Must be one of ['avg', 'max', 'region-extraction']. 32 | If 'region-extraction', then downscaling is done during the region extraction phase. 33 | Meaning that the stride is the spacing between consecutive regions. 
34 | :param stride: 35 | Stride used for downsampling. 36 | :param use_log_trace: 37 | Whether to calculate the log of the trace, instead of the log of the determinant. See equation (15). 38 | :param use_double_precision: 39 | Calculate the RMI using doubles in order to fix potential numerical issues. 40 | :param epsilon: 41 | Magnitude of the entries added to the diagonal of M in order to fix potential numerical issues. 42 | """ 43 | super().__init__() 44 | 45 | self.use_double_precision = use_double_precision 46 | self.with_logits = with_logits 47 | self.bce_weight = bce_weight 48 | self.stride = stride 49 | self.downsampling_method = downsampling_method 50 | self.radius = radius 51 | self.use_log_trace = use_log_trace 52 | self.epsilon = epsilon 53 | 54 | def forward(self, input, target): 55 | # Calculate BCE if needed 56 | if self.bce_weight != 0: 57 | if self.with_logits: 58 | bce = F.binary_cross_entropy_with_logits(input, target=target) 59 | else: 60 | bce = F.binary_cross_entropy(input, target=target) 61 | bce = bce.mean() * self.bce_weight 62 | else: 63 | bce = 0.0 64 | 65 | # Apply sigmoid to get probabilities. See final paragraph of section 4. 66 | if self.with_logits: 67 | input = torch.sigmoid(input) 68 | 69 | # Calculate RMI loss 70 | rmi = self.rmi_loss(input=input, target=target) 71 | rmi = rmi.mean() * (1.0 - self.bce_weight) 72 | return rmi + bce 73 | 74 | def rmi_loss(self, input, target): 75 | """ 76 | Calculates the RMI loss between the prediction and target. 77 | 78 | :return: 79 | RMI loss 80 | """ 81 | 82 | assert input.shape == target.shape 83 | vector_size = self.radius * self.radius 84 | 85 | # Get region vectors 86 | y = self.extract_region_vector(target) 87 | p = self.extract_region_vector(input) 88 | 89 | # Convert to doubles for better precision 90 | if self.use_double_precision: 91 | y = y.double() 92 | p = p.double() 93 | 94 | # Small diagonal matrix to fix numerical issues 95 | eps = torch.eye(vector_size, dtype=y.dtype, device=y.device) * self.epsilon 96 | eps = eps.unsqueeze(dim=0).unsqueeze(dim=0) 97 | 98 | # Subtract mean 99 | y = y - y.mean(dim=3, keepdim=True) 100 | p = p - p.mean(dim=3, keepdim=True) 101 | 102 | # Covariances 103 | y_cov = y @ transpose(y) 104 | p_cov = p @ transpose(p) 105 | y_p_cov = y @ transpose(p) 106 | 107 | # Approximated posterior covariance matrix of Y given P 108 | m = y_cov - y_p_cov @ transpose(inverse(p_cov + eps)) @ transpose(y_p_cov) 109 | 110 | # Lower bound of RMI 111 | if self.use_log_trace: 112 | rmi = 0.5 * log_trace(m + eps) 113 | else: 114 | rmi = 0.5 * log_det(m + eps) 115 | 116 | # Normalize 117 | rmi = rmi / float(vector_size) 118 | 119 | # Sum over classes, mean over samples. 120 | return rmi.sum(dim=1).mean(dim=0) 121 | 122 | def extract_region_vector(self, x): 123 | """ 124 | Downsamples and extracts square regions from x. 125 | Returns the flattened vectors of length radius*radius. 126 | """ 127 | 128 | x = self.downsample(x) 129 | stride = self.stride if self.downsampling_method == 'region-extraction' else 1 130 | 131 | x_regions = F.unfold(x, kernel_size=self.radius, stride=stride) 132 | x_regions = x_regions.view((*x.shape[:2], self.radius ** 2, -1)) 133 | return x_regions 134 | 135 | def downsample(self, x): 136 | # Skip if stride is 1 137 | if self.stride == 1: 138 | return x 139 | 140 | # Skip if we pool during region extraction. 
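        # Note: in that mode the stride is applied by F.unfold inside extract_region_vector,
        # so regions are spaced `stride` pixels apart rather than pooled here.
        # A minimal loss call, as a sketch with assumed shapes (binary masks, logits input):
        #   criterion = RMILoss(with_logits=True)
        #   loss = criterion(torch.randn(4, 1, 96, 96), torch.randint(0, 2, (4, 1, 96, 96)).float())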
141 | if self.downsampling_method == 'region-extraction': 142 | return x 143 | 144 | padding = self.stride // 2 145 | if self.downsampling_method == 'max': 146 | return F.max_pool2d(x, kernel_size=self.stride, stride=self.stride, padding=padding) 147 | if self.downsampling_method == 'avg': 148 | return F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride, padding=padding) 149 | raise ValueError(self.downsampling_method) 150 | 151 | 152 | def transpose(x): 153 | return x.transpose(-2, -1) 154 | 155 | 156 | def inverse(x): 157 | return torch.inverse(x) 158 | 159 | 160 | def log_trace(x): 161 | x = torch.linalg.cholesky(x) 162 | diag = torch.diagonal(x, dim1=-2, dim2=-1) 163 | return 2 * torch.sum(torch.log(diag + 1e-8), dim=-1) 164 | 165 | 166 | def log_det(x): 167 | return torch.logdet(x) 168 | -------------------------------------------------------------------------------- /torch_utils/optimizer/over9000.py: -------------------------------------------------------------------------------- 1 | # https://github.com/mgrankin/over9000 2 | 3 | import torch 4 | import math 5 | import itertools as it 6 | from torch.optim import Optimizer, Adam 7 | 8 | 9 | class Lookahead(Optimizer): 10 | def __init__(self, base_optimizer, alpha=0.5, k=6): 11 | if not 0.0 <= alpha <= 1.0: 12 | raise ValueError(f'Invalid slow update rate: {alpha}') 13 | if not 1 <= k: 14 | raise ValueError(f'Invalid lookahead steps: {k}') 15 | self.optimizer = base_optimizer 16 | self.param_groups = self.optimizer.param_groups 17 | self.alpha = alpha 18 | self.k = k 19 | for group in self.param_groups: 20 | group["step_counter"] = 0 21 | self.slow_weights = [[p.clone().detach() for p in group['params']] 22 | for group in self.param_groups] 23 | 24 | for w in it.chain(*self.slow_weights): 25 | w.requires_grad = False 26 | 27 | def step(self, closure=None): 28 | loss = None 29 | if closure is not None: 30 | loss = closure() 31 | loss = self.optimizer.step() 32 | for group, slow_weights in zip(self.param_groups, self.slow_weights): 33 | group['step_counter'] += 1 34 | if group['step_counter'] % self.k != 0: 35 | continue 36 | for p, q in zip(group['params'], slow_weights): 37 | if p.grad is None: 38 | continue 39 | q.data.add_(self.alpha, p.data - q.data) 40 | p.data.copy_(q.data) 41 | return loss 42 | 43 | 44 | def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs): 45 | adam = Adam(params, *args, **kwargs) 46 | return Lookahead(adam, alpha, k) 47 | 48 | # RAdam + LARS 49 | 50 | 51 | class Ralamb(Optimizer): 52 | 53 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 54 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 55 | self.buffer = [[None, None, None] for ind in range(10)] 56 | super(Ralamb, self).__init__(params, defaults) 57 | 58 | def __setstate__(self, state): 59 | super(Ralamb, self).__setstate__(state) 60 | 61 | def step(self, closure=None): 62 | 63 | loss = None 64 | if closure is not None: 65 | loss = closure() 66 | 67 | for group in self.param_groups: 68 | 69 | for p in group['params']: 70 | if p.grad is None: 71 | continue 72 | grad = p.grad.data.float() 73 | if grad.is_sparse: 74 | raise RuntimeError('Ralamb does not support sparse gradients') 75 | 76 | p_data_fp32 = p.data.float() 77 | 78 | state = self.state[p] 79 | 80 | if len(state) == 0: 81 | state['step'] = 0 82 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 83 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 84 | else: 85 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 86 | 
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # Decay the first and second moment running average coefficient
                # (keyword alpha/value forms; the old positional-scalar overloads are removed in recent PyTorch)
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, radam_step = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        radam_step = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) *
                                                             (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        radam_step = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = radam_step

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                radam_norm = p_data_fp32.pow(2).sum().sqrt()
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-radam_step * trust_ratio)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-radam_step * trust_ratio)

                p.data.copy_(p_data_fp32)

        return loss


# RAdam + LARS + LookAHead
# Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
# RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20

def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
    ralamb = Ralamb(params, *args, **kwargs)
    return Lookahead(ralamb, alpha, k)


RangerLars = Over9000
--------------------------------------------------------------------------------
/torch_utils/criterion/focal.py:
--------------------------------------------------------------------------------
# from https://github.com/BloodAxe/pytorch-toolbelt

from functools import partial
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

__all__ = ["BinaryFocalLoss", "FocalLoss"]


def focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    alpha: Optional[float] = 0.25,
    reduction: str = "mean",
    normalized: bool = False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Compute binary focal loss between target and output logits.
    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
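    For reference, the per-element quantity being computed is the standard focal term
        FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t),
    where p_t is the predicted probability of the true class (here obtained from the
    logits via a sigmoid); passing `reduced_threshold` swaps the power factor for the
    reduced-focal variant referenced below.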
Args:
        output: Tensor of arbitrary shape (predictions of the model)
        target: Tensor of the same shape as input
        gamma: Focal loss power factor
        alpha: Weight factor to balance positive and negative samples. Alpha must be in [0...1] range,
            high values will give more weight to positive class.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`.
            'batchwise_mean' computes mean loss per sample in batch. Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
    References:
        https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
    """
    target = target.type_as(output)

    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
    pt = torch.exp(-logpt)

    # compute the loss
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * logpt

    if alpha is not None:
        loss *= alpha * target + (1 - alpha) * (1 - target)

    if normalized:
        norm_factor = focal_term.sum(dtype=torch.float32).clamp_min(eps)
        loss /= norm_factor

    if reduction == "mean":
        loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum(dtype=torch.float32)
    if reduction == "batchwise_mean":
        loss = loss.sum(dim=0, dtype=torch.float32)

    return loss


class BinaryFocalLoss(_Loss):
    def __init__(
        self,
        alpha=None,
        gamma: float = 2.0,
        ignore_index=None,
        reduction="mean",
        normalized=False,
        reduced_threshold=None,
    ):
        """
        :param alpha: Prior probability of having positive value in target.
        :param gamma: Power factor for dampening weight (focal strength).
        :param ignore_index: If not None, targets may contain values to be ignored.
            Target values equal to ignore_index will be ignored from loss computation.
        :param reduced: Switch to reduced focal loss. Note, when using this mode you should use `reduction="sum"`.
90 | :param threshold: 91 | """ 92 | super().__init__() 93 | self.ignore_index = ignore_index 94 | self.focal_loss_fn = partial( 95 | focal_loss_with_logits, 96 | alpha=alpha, 97 | gamma=gamma, 98 | reduced_threshold=reduced_threshold, 99 | reduction=reduction, 100 | normalized=normalized, 101 | ) 102 | 103 | def forward(self, label_input, label_target): 104 | """Compute focal loss for binary classification problem.""" 105 | 106 | if self.ignore_index is not None: 107 | # Filter predictions with ignore label from loss computation 108 | ignored = label_target.eq(self.ignore_index) 109 | mask = ~ignored 110 | label_input = label_input * mask 111 | label_target = label_target * mask 112 | 113 | loss = self.focal_loss_fn(label_input, label_target) 114 | return loss 115 | 116 | 117 | class FocalLoss(_Loss): 118 | def __init__( 119 | self, alpha=None, gamma=2, ignore_index=None, reduction="mean", normalized=False, reduced_threshold=None 120 | ): 121 | """ 122 | Focal loss for multi-class problem. 123 | :param alpha: 124 | :param gamma: 125 | :param ignore_index: If not None, targets with given index are ignored 126 | :param reduced_threshold: A threshold factor for computing reduced focal loss 127 | """ 128 | super().__init__() 129 | self.ignore_index = ignore_index 130 | self.focal_loss_fn = partial( 131 | focal_loss_with_logits, 132 | alpha=alpha, 133 | gamma=gamma, 134 | reduced_threshold=reduced_threshold, 135 | reduction=reduction, 136 | normalized=normalized, 137 | ) 138 | 139 | def forward(self, label_input, label_target): 140 | num_classes = label_input.size(1) 141 | loss = 0 142 | 143 | # Filter anchors with -1 label from loss computation 144 | if self.ignore_index is not None: 145 | not_ignored = label_target != self.ignore_index 146 | 147 | for cls in range(num_classes): 148 | cls_label_target = (label_target == cls).long() 149 | cls_label_input = label_input[:, cls, ...] 150 | 151 | if self.ignore_index is not None: 152 | cls_label_target = cls_label_target[not_ignored] 153 | cls_label_input = cls_label_input[not_ignored] 154 | 155 | loss += self.focal_loss_fn(cls_label_input, cls_label_target) 156 | return loss 157 | -------------------------------------------------------------------------------- /torch_utils/criterion/lovasz.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | https://github.com/BloodAxe/pytorch-toolbelt 5 | """ 6 | 7 | from __future__ import print_function, division 8 | 9 | from typing import Optional, Union 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | from torch.nn.modules.loss import _Loss 15 | 16 | try: 17 | from itertools import ifilterfalse 18 | except ImportError: # py3k 19 | from itertools import filterfalse as ifilterfalse 20 | 21 | __all__ = ["BinaryLovaszLoss", "LovaszLoss"] 22 | 23 | 24 | def _lovasz_grad(gt_sorted): 25 | """Compute gradient of the Lovasz extension w.r.t sorted errors 26 | See Alg. 
1 in paper 27 | """ 28 | p = len(gt_sorted) 29 | gts = gt_sorted.sum() 30 | intersection = gts - gt_sorted.float().cumsum(0) 31 | union = gts + (1 - gt_sorted).float().cumsum(0) 32 | jaccard = 1.0 - intersection / union 33 | if p > 1: # cover 1-pixel case 34 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 35 | return jaccard 36 | 37 | 38 | def _lovasz_hinge(logits, labels, per_image=True, ignore_index=None): 39 | """ 40 | Binary Lovasz hinge loss 41 | logits: [B, H, W] Variable, logits at each pixel (between -infinity and +infinity) 42 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 43 | per_image: compute the loss per image instead of per batch 44 | ignore: void class id 45 | """ 46 | if per_image: 47 | loss = mean( 48 | _lovasz_hinge_flat(*_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore_index)) 49 | for log, lab in zip(logits, labels) 50 | ) 51 | else: 52 | loss = _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore_index)) 53 | return loss 54 | 55 | 56 | def _lovasz_hinge_flat(logits, labels): 57 | """Binary Lovasz hinge loss 58 | Args: 59 | logits: [P] Variable, logits at each prediction (between -iinfinity and +iinfinity) 60 | labels: [P] Tensor, binary ground truth labels (0 or 1) 61 | ignore: label to ignore 62 | """ 63 | if len(labels) == 0: 64 | # only void pixels, the gradients should be 0 65 | return logits.sum() * 0.0 66 | signs = 2.0 * labels.float() - 1.0 67 | errors = 1.0 - logits * Variable(signs) 68 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 69 | perm = perm.data 70 | gt_sorted = labels[perm] 71 | grad = _lovasz_grad(gt_sorted) 72 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 73 | return loss 74 | 75 | 76 | def _flatten_binary_scores(scores, labels, ignore_index=None): 77 | """Flattens predictions in the batch (binary case) 78 | Remove labels equal to 'ignore' 79 | """ 80 | scores = scores.view(-1) 81 | labels = labels.view(-1) 82 | if ignore_index is None: 83 | return scores, labels 84 | valid = labels != ignore_index 85 | vscores = scores[valid] 86 | vlabels = labels[valid] 87 | return vscores, vlabels 88 | 89 | 90 | # --------------------------- MULTICLASS LOSSES --------------------------- 91 | 92 | 93 | def _lovasz_softmax(probas, labels, classes="present", per_image=False, ignore_index=None): 94 | """Multi-class Lovasz-Softmax loss 95 | Args: 96 | @param probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 97 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 98 | @param labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 99 | @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 
100 | @param per_image: compute the loss per image instead of per batch 101 | @param ignore_index: void class labels 102 | """ 103 | if per_image: 104 | loss = mean( 105 | _lovasz_softmax_flat(*_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore_index), classes=classes) 106 | for prob, lab in zip(probas, labels) 107 | ) 108 | else: 109 | loss = _lovasz_softmax_flat(*_flatten_probas(probas, labels, ignore_index), classes=classes) 110 | return loss 111 | 112 | 113 | def _lovasz_softmax_flat(probas, labels, classes="present"): 114 | """Multi-class Lovasz-Softmax loss 115 | Args: 116 | @param probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 117 | @param labels: [P] Tensor, ground truth labels (between 0 and C - 1) 118 | @param classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 119 | """ 120 | if probas.numel() == 0: 121 | # only void pixels, the gradients should be 0 122 | return probas * 0.0 123 | C = probas.size(1) 124 | losses = [] 125 | class_to_sum = list(range(C)) if classes in ["all", "present"] else classes 126 | for c in class_to_sum: 127 | fg = (labels == c).type_as(probas) # foreground for class c 128 | if classes == "present" and fg.sum() == 0: 129 | continue 130 | if C == 1: 131 | if len(classes) > 1: 132 | raise ValueError("Sigmoid output possible only with 1 class") 133 | class_pred = probas[:, 0] 134 | else: 135 | class_pred = probas[:, c] 136 | errors = (fg - class_pred).abs() 137 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 138 | perm = perm.data 139 | fg_sorted = fg[perm] 140 | losses.append(torch.dot(errors_sorted, _lovasz_grad(fg_sorted))) 141 | return mean(losses) 142 | 143 | 144 | def _flatten_probas(probas, labels, ignore=None): 145 | """Flattens predictions in the batch""" 146 | if probas.dim() == 3: 147 | # assumes output of a sigmoid layer 148 | B, H, W = probas.size() 149 | probas = probas.view(B, 1, H, W) 150 | 151 | C = probas.size(1) 152 | probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] 
-> [B, Di, Dj, ..., C] 153 | probas = probas.contiguous().view(-1, C) # [P, C] 154 | 155 | labels = labels.view(-1) 156 | if ignore is None: 157 | return probas, labels 158 | valid = labels != ignore 159 | vprobas = probas[valid] 160 | vlabels = labels[valid] 161 | return vprobas, vlabels 162 | 163 | 164 | # --------------------------- HELPER FUNCTIONS --------------------------- 165 | def isnan(x): 166 | return x != x 167 | 168 | 169 | def mean(values, ignore_nan=False, empty=0): 170 | """Nanmean compatible with generators.""" 171 | values = iter(values) 172 | if ignore_nan: 173 | values = ifilterfalse(isnan, values) 174 | try: 175 | n = 1 176 | acc = next(values) 177 | except StopIteration: 178 | if empty == "raise": 179 | raise ValueError("Empty mean") 180 | return empty 181 | for n, v in enumerate(values, 2): 182 | acc += v 183 | if n == 1: 184 | return acc 185 | return acc / n 186 | 187 | 188 | class BinaryLovaszLoss(_Loss): 189 | def __init__(self, per_image: bool = False, ignore_index: Optional[Union[int, float]] = None): 190 | super().__init__() 191 | self.ignore_index = ignore_index 192 | self.per_image = per_image 193 | 194 | def forward(self, logits, target): 195 | return _lovasz_hinge(logits, target, per_image=self.per_image, ignore_index=self.ignore_index) 196 | 197 | 198 | class LovaszLoss(_Loss): 199 | def __init__(self, per_image=False, ignore=None): 200 | super().__init__() 201 | self.ignore = ignore 202 | self.per_image = per_image 203 | 204 | def forward(self, logits, target): 205 | return _lovasz_softmax(logits, target, per_image=self.per_image, ignore_index=self.ignore) 206 | -------------------------------------------------------------------------------- /torch_utils/advanced/NextVLAD.py: -------------------------------------------------------------------------------- 1 | # https://github.com/ceshine/yt8m-2019/blob/95679eb3cf2ebc03c6c496319975cbe2dcb45af4/yt8m/encoders.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def general_weight_initialization(module: nn.Module): 9 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)): 10 | if module.weight is not None: 11 | nn.init.uniform_(module.weight) 12 | if module.bias is not None: 13 | nn.init.constant_(module.bias, 0) 14 | elif isinstance(module, nn.Linear): 15 | nn.init.kaiming_normal_(module.weight) 16 | # print("Initing linear") 17 | if module.bias is not None: 18 | nn.init.constant_(module.bias, 0) 19 | 20 | 21 | class TimeFirstBatchNorm1d(nn.Module): 22 | def __init__(self, dim, groups=None): 23 | super().__init__() 24 | self.groups = groups 25 | self.bn = nn.BatchNorm1d(dim) 26 | 27 | def forward(self, tensor): 28 | _, length, dim = tensor.size() 29 | if self.groups: 30 | dim = dim // self.groups 31 | tensor = tensor.view(-1, dim) 32 | tensor = self.bn(tensor) 33 | if self.groups: 34 | return tensor.view(-1, length, self.groups, dim) 35 | else: 36 | return tensor.view(-1, length, dim) 37 | 38 | 39 | class NeXtVLAD(nn.Module): 40 | """NeXtVLAD layer implementation 41 | Adapted from https://github.com/linrongc/youtube-8m/blob/master/nextvlad.py 42 | """ 43 | 44 | def __init__(self, num_clusters=64, dim=128, alpha=100.0, 45 | groups: int = 8, expansion: int = 2, 46 | normalize_input=True, p_drop=0.25, add_batchnorm=False): 47 | """ 48 | Args: 49 | num_clusters : int 50 | The number of clusters 51 | dim : int 52 | Dimension of descriptors 53 | alpha : float 54 | Parameter of initialization. Larger value is harder assignment. 
55 | normalize_input : bool 56 | If true, descriptor-wise L2 normalization is applied to input. 57 | """ 58 | super().__init__() 59 | assert dim % groups == 0, "`dim` must be divisible by `groups`" 60 | assert expansion > 1 61 | self.p_drop = p_drop 62 | self.cluster_dropout = nn.Dropout2d(p_drop) 63 | self.num_clusters = num_clusters 64 | self.dim = dim 65 | self.expansion = expansion 66 | self.grouped_dim = dim * expansion // groups 67 | self.groups = groups 68 | self.alpha = alpha 69 | self.normalize_input = normalize_input 70 | self.add_batchnorm = add_batchnorm 71 | self.expansion_mapper = nn.Linear(dim, dim * expansion) 72 | if add_batchnorm: 73 | self.soft_assignment_mapper = nn.Sequential( 74 | nn.Linear(dim * expansion, num_clusters * groups, bias=False), 75 | TimeFirstBatchNorm1d(num_clusters, groups=groups) 76 | ) 77 | else: 78 | self.soft_assignment_mapper = nn.Linear( 79 | dim * expansion, num_clusters * groups, bias=True) 80 | self.attention_mapper = nn.Linear( 81 | dim * expansion, groups 82 | ) 83 | # (n_clusters, dim / group) 84 | self.centroids = nn.Parameter( 85 | torch.rand(num_clusters, self.grouped_dim)) 86 | self.final_bn = nn.BatchNorm1d(num_clusters * self.grouped_dim) 87 | self._init_params() 88 | 89 | def _init_params(self): 90 | for component in (self.soft_assignment_mapper, self.attention_mapper, 91 | self.expansion_mapper): 92 | for module in component.modules(): 93 | general_weight_initialization(module) 94 | if self.add_batchnorm: 95 | self.soft_assignment_mapper[0].weight = nn.Parameter( 96 | (2.0 * self.alpha * self.centroids).repeat((self.groups, self.groups)) 97 | ) 98 | nn.init.constant_(self.soft_assignment_mapper[1].bn.weight, 1) 99 | nn.init.constant_(self.soft_assignment_mapper[1].bn.bias, 0) 100 | else: 101 | self.soft_assignment_mapper.weight = nn.Parameter( 102 | (2.0 * self.alpha * self.centroids).repeat((self.groups, self.groups)) 103 | ) 104 | self.soft_assignment_mapper.bias = nn.Parameter( 105 | (- self.alpha * self.centroids.norm(dim=1) 106 | ).repeat((self.groups,)) 107 | ) 108 | 109 | def forward(self, x, masks=None): 110 | """NeXtVlad Adaptive Pooling 111 | Arguments: 112 | x {torch.Tensor} -- shape: (n_batch, len, dim) 113 | Returns: 114 | torch.Tensor -- shape (n_batch, n_cluster * dim / groups) 115 | """ 116 | if self.normalize_input: 117 | x = F.normalize(x, p=2, dim=2) # across descriptor dim 118 | 119 | # expansion 120 | # shape: (n_batch, len, dim * expansion) 121 | x = self.expansion_mapper(x) 122 | 123 | # soft-assignment 124 | # shape: (n_batch, len, n_cluster, groups) 125 | soft_assign = self.soft_assignment_mapper(x).view( 126 | x.size(0), x.size(1), self.num_clusters, self.groups 127 | ) 128 | soft_assign = F.softmax(soft_assign, dim=2) 129 | 130 | # attention 131 | # shape: (n_batch, len, groups) 132 | attention = torch.sigmoid(self.attention_mapper(x)) 133 | if masks is not None: 134 | # shape: (n_batch, len, groups) 135 | attention = attention * masks[:, :, None] 136 | 137 | # (n_batch, len, n_cluster, groups, dim / groups) 138 | activation = ( 139 | attention[:, :, None, :, None] * 140 | soft_assign[:, :, :, :, None] 141 | ) 142 | 143 | # calculate residuals to each clusters 144 | # (n_batch, n_cluster, dim / groups) 145 | second_term = ( 146 | activation.sum(dim=3).sum(dim=1) * 147 | self.centroids[None, :, :] 148 | ) 149 | # (n_batch, n_cluster, dim / groups) 150 | first_term = ( 151 | # (n_batch, len, n_cluster, groups, dim / groups) 152 | activation * 153 | x.view(x.size(0), x.size(1), 1, self.groups, 
self.grouped_dim) 154 | ).sum(dim=3).sum(dim=1) 155 | 156 | # vlad shape (n_batch, n_cluster, dim / groups) 157 | vlad = first_term - second_term 158 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 159 | # flatten shape (n_batch, n_cluster * dim / groups) 160 | vlad = vlad.view(x.size(0), -1) # flatten 161 | # vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 162 | vlad = self.final_bn(vlad) 163 | if self.p_drop: 164 | vlad = self.cluster_dropout( 165 | vlad.view(x.size(0), self.num_clusters, self.grouped_dim, 1) 166 | ).view(x.size(0), -1) 167 | return vlad 168 | 169 | 170 | def test_nextvlad(): 171 | model = NeXtVLAD( 172 | num_clusters=64, dim=128, alpha=100, 173 | groups=8, expansion=2, normalize_input=True, 174 | p_drop=0.25, add_batchnorm=True 175 | ) 176 | # shape (n_batch, len, dim) 177 | input_tensor = torch.rand(16, 300, 128) 178 | # shape (n_batch, n_clusters * dim / groups) 179 | output_tensor = model(input_tensor) 180 | assert output_tensor.size() == (16, 64 * 2 * 128 // 8) 181 | model = NeXtVLAD( 182 | num_clusters=64, dim=128, alpha=100, 183 | groups=8, expansion=2, normalize_input=True, 184 | p_drop=0.25, add_batchnorm=False 185 | ) 186 | # shape (n_batch, len, dim) 187 | input_tensor = torch.rand(16, 300, 128) 188 | # shape (n_batch, n_clusters * dim / groups) 189 | output_tensor = model(input_tensor) 190 | assert output_tensor.size() == (16, 64 * 2 * 128 // 8) 191 | -------------------------------------------------------------------------------- /torch_utils/optimizer/ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 3 | 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer, required 7 | 8 | 9 | def centralized_gradient(x, use_gc=True, gc_conv_only=False): 10 | '''credit - https://github.com/Yonghongwei/Gradient-Centralization ''' 11 | if use_gc: 12 | if gc_conv_only: 13 | if len(list(x.size())) > 3: 14 | x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True)) 15 | else: 16 | if len(list(x.size())) > 1: 17 | x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True)) 18 | return x 19 | 20 | 21 | class Ranger(Optimizer): 22 | 23 | def __init__(self, params, lr=1e-3, # lr 24 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 25 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 26 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 27 | use_gc=True, gc_conv_only=False, gc_loc=True 28 | ): 29 | 30 | # parameter checks 31 | if not 0.0 <= alpha <= 1.0: 32 | raise ValueError(f'Invalid slow update rate: {alpha}') 33 | if not 1 <= k: 34 | raise ValueError(f'Invalid lookahead steps: {k}') 35 | if not lr > 0: 36 | raise ValueError(f'Invalid Learning Rate: {lr}') 37 | if not eps > 0: 38 | raise ValueError(f'Invalid eps: {eps}') 39 | 40 | # parameter comments: 41 | # beta1 (momentum) of .95 seems to work better than .90... 42 | # N_sma_threshold of 5 seems better in testing than 4. 43 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
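        # A usage sketch (hyperparameters illustrative, using the defaults below):
        #   optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6)
        #   loss.backward(); optimizer.step(); optimizer.zero_grad()
        # Note that Lookahead is integrated: slow weights are synced every k steps at the
        # parameter level, so a checkpoint saved mid-cycle holds the fast weights, not the
        # lookahead (slow) ones.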
44 | 45 | # prep defaults and init torch.optim base 46 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, 47 | N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay) 48 | super().__init__(params, defaults) 49 | 50 | # adjustable threshold 51 | self.N_sma_threshhold = N_sma_threshhold 52 | 53 | # look ahead params 54 | 55 | self.alpha = alpha 56 | self.k = k 57 | 58 | # radam buffer for state 59 | self.radam_buffer = [[None, None, None] for ind in range(10)] 60 | 61 | # gc on or off 62 | self.gc_loc = gc_loc 63 | self.use_gc = use_gc 64 | self.gc_conv_only = gc_conv_only 65 | # level of gradient centralization 66 | # self.gc_gradient_threshold = 3 if gc_conv_only else 1 67 | 68 | print( 69 | f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}") 70 | if (self.use_gc and (self.gc_conv_only is False)): 71 | print(f"GC applied to both conv and fc layers") 72 | elif (self.use_gc and (self.gc_conv_only is True)): 73 | print(f"GC applied to conv layers only") 74 | 75 | def __setstate__(self, state): 76 | print("set state called") 77 | super(Ranger, self).__setstate__(state) 78 | 79 | def step(self, closure=None): 80 | loss = None 81 | # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 82 | # Uncomment if you need to use the actual closure... 83 | 84 | # if closure is not None: 85 | # loss = closure() 86 | 87 | # Evaluate averages and grad, update param tensors 88 | for group in self.param_groups: 89 | 90 | for p in group['params']: 91 | if p.grad is None: 92 | continue 93 | grad = p.grad.data.float() 94 | 95 | if grad.is_sparse: 96 | raise RuntimeError( 97 | 'Ranger optimizer does not support sparse gradients') 98 | 99 | p_data_fp32 = p.data.float() 100 | 101 | state = self.state[p] # get state dict for this param 102 | 103 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 104 | # if self.first_run_check==0: 105 | # self.first_run_check=1 106 | # print("Initializing slow buffer...should not see this at load from saved model!") 107 | state['step'] = 0 108 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 109 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 110 | 111 | # look ahead weight storage now in state dict 112 | state['slow_buffer'] = torch.empty_like(p.data) 113 | state['slow_buffer'].copy_(p.data) 114 | 115 | else: 116 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 117 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as( 118 | p_data_fp32) 119 | 120 | # begin computations 121 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 122 | beta1, beta2 = group['betas'] 123 | 124 | # GC operation for Conv layers and FC layers 125 | # if grad.dim() > self.gc_gradient_threshold: 126 | # grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 127 | if self.gc_loc: 128 | grad = centralized_gradient(grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only) 129 | 130 | state['step'] += 1 131 | 132 | # compute variance mov avg 133 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 134 | 135 | # compute mean moving avg 136 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 137 | 138 | buffered = self.radam_buffer[int(state['step'] % 10)] 139 | 140 | if state['step'] == buffered[0]: 141 | N_sma, step_size = buffered[1], buffered[2] 142 | else: 143 | buffered[0] = state['step'] 144 | beta2_t = beta2 ** state['step'] 145 | N_sma_max = 2 / (1 - beta2) - 1 146 | N_sma = N_sma_max - 2 * \ 
147 | state['step'] * beta2_t / (1 - beta2_t) 148 | buffered[1] = N_sma 149 | if N_sma > self.N_sma_threshhold: 150 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * ( 151 | N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 152 | else: 153 | step_size = 1.0 / (1 - beta1 ** state['step']) 154 | buffered[2] = step_size 155 | 156 | # if group['weight_decay'] != 0: 157 | # p_data_fp32.add_(-group['weight_decay'] 158 | # * group['lr'], p_data_fp32) 159 | 160 | # apply lr 161 | if N_sma > self.N_sma_threshhold: 162 | denom = exp_avg_sq.sqrt().add_(group['eps']) 163 | G_grad = exp_avg / denom 164 | else: 165 | G_grad = exp_avg 166 | 167 | if group['weight_decay'] != 0: 168 | G_grad.add_(p_data_fp32, alpha=group['weight_decay']) 169 | # GC operation 170 | if self.gc_loc is False: 171 | G_grad = centralized_gradient(G_grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only) 172 | 173 | p_data_fp32.add_(G_grad, alpha=-step_size * group['lr']) 174 | p.data.copy_(p_data_fp32) 175 | 176 | # integrated look ahead... 177 | # we do it at the param level instead of group level 178 | if state['step'] % group['k'] == 0: 179 | # get access to slow param tensor 180 | slow_p = state['slow_buffer'] 181 | # (fast weights - slow weights) * alpha 182 | slow_p.add_(p.data - slow_p, alpha=self.alpha) 183 | # copy interpolated weights to RAdam param tensor 184 | p.data.copy_(slow_p) 185 | 186 | return loss 187 | -------------------------------------------------------------------------------- /torch_utils/optimizer/radam.py: -------------------------------------------------------------------------------- 1 | # https://github.com/LiyuanLucasLiu/RAdam 2 | 3 | import math 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | 8 | class RAdam(Optimizer): 9 | 10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 11 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 12 | self.buffer = [[None, None, None] for ind in range(10)] 13 | super(RAdam, self).__init__(params, defaults) 14 | 15 | def __setstate__(self, state): 16 | super(RAdam, self).__setstate__(state) 17 | 18 | def step(self, closure=None): 19 | 20 | loss = None 21 | if closure is not None: 22 | loss = closure() 23 | 24 | for group in self.param_groups: 25 | 26 | for p in group['params']: 27 | if p.grad is None: 28 | continue 29 | grad = p.grad.data.float() 30 | if grad.is_sparse: 31 | raise RuntimeError('RAdam does not support sparse gradients') 32 | 33 | p_data_fp32 = p.data.float() 34 | 35 | state = self.state[p] 36 | 37 | if len(state) == 0: 38 | state['step'] = 0 39 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 40 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 41 | else: 42 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 43 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 44 | 45 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 46 | beta1, beta2 = group['betas'] 47 | 48 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 49 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 50 | 51 | state['step'] += 1 52 | buffered = self.buffer[int(state['step'] % 10)] 53 | if state['step'] == buffered[0]: 54 | N_sma, step_size = buffered[1], buffered[2] 55 | else: 56 | buffered[0] = state['step'] 57 | beta2_t = beta2 ** state['step'] 58 | N_sma_max = 2 / (1 - beta2) - 1 59 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 60 | buffered[1] = N_sma 61 | 62 | # more 
63 |                     if N_sma >= 5:
64 |                         step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) *
65 |                                                             (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
66 |                     else:
67 |                         step_size = group['lr'] / (1 - beta1 ** state['step'])
68 |                     buffered[2] = step_size
69 | 
70 |                 if group['weight_decay'] != 0:
71 |                     p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
72 | 
73 |                 # more conservative since it's an approximated value
74 |                 if N_sma >= 5:
75 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
76 |                     p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
77 |                 else:
78 |                     p_data_fp32.add_(exp_avg, alpha=-step_size)
79 | 
80 |                 p.data.copy_(p_data_fp32)
81 | 
82 |         return loss
83 | 
84 | 
85 | class PlainRAdam(Optimizer):
86 | 
87 |     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
88 |         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
89 | 
90 |         super(PlainRAdam, self).__init__(params, defaults)
91 | 
92 |     def __setstate__(self, state):
93 |         super(PlainRAdam, self).__setstate__(state)
94 | 
95 |     def step(self, closure=None):
96 | 
97 |         loss = None
98 |         if closure is not None:
99 |             loss = closure()
100 | 
101 |         for group in self.param_groups:
102 | 
103 |             for p in group['params']:
104 |                 if p.grad is None:
105 |                     continue
106 |                 grad = p.grad.data.float()
107 |                 if grad.is_sparse:
108 |                     raise RuntimeError('RAdam does not support sparse gradients')
109 | 
110 |                 p_data_fp32 = p.data.float()
111 | 
112 |                 state = self.state[p]
113 | 
114 |                 if len(state) == 0:
115 |                     state['step'] = 0
116 |                     state['exp_avg'] = torch.zeros_like(p_data_fp32)
117 |                     state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
118 |                 else:
119 |                     state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
120 |                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
121 | 
122 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
123 |                 beta1, beta2 = group['betas']
124 | 
125 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
126 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
127 | 
128 |                 state['step'] += 1
129 |                 beta2_t = beta2 ** state['step']
130 |                 N_sma_max = 2 / (1 - beta2) - 1
131 |                 N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
132 | 
133 |                 if group['weight_decay'] != 0:
134 |                     p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
135 | 
136 |                 # more conservative since it's an approximated value
137 |                 if N_sma >= 5:
138 |                     step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) /
139 |                                                         N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
140 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
141 |                     p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
142 |                 else:
143 |                     step_size = group['lr'] / (1 - beta1 ** state['step'])
144 |                     p_data_fp32.add_(exp_avg, alpha=-step_size)
145 | 
146 |                 p.data.copy_(p_data_fp32)
147 | 
148 |         return loss
149 | 
150 | 
151 | class AdamW(Optimizer):
152 | 
153 |     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
154 |         defaults = dict(lr=lr, betas=betas, eps=eps,
155 |                         weight_decay=weight_decay, warmup=warmup)
156 |         super(AdamW, self).__init__(params, defaults)
157 | 
158 |     def __setstate__(self, state):
159 |         super(AdamW, self).__setstate__(state)
160 | 
161 |     def step(self, closure=None):
162 |         loss = None
163 |         if closure is not None:
164 |             loss = closure()
165 | 
166 |         for group in self.param_groups:
167 | 
168 |             for p in group['params']:
169 |                 if p.grad is None:
170 |                     continue
171 |                 grad = p.grad.data.float()
172 |                 if grad.is_sparse:
173 |                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
174 | 
175 |                 p_data_fp32 = p.data.float()
176 | 
177 |                 state = self.state[p]
178 | 
179 |                 if len(state) == 0:
180 |                     state['step'] = 0
181 |                     state['exp_avg'] = torch.zeros_like(p_data_fp32)
182 |                     state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
183 |                 else:
184 |                     state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
185 |                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
186 | 
187 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
188 |                 beta1, beta2 = group['betas']
189 | 
190 |                 state['step'] += 1
191 | 
192 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
193 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
194 | 
195 |                 denom = exp_avg_sq.sqrt().add_(group['eps'])
196 |                 bias_correction1 = 1 - beta1 ** state['step']
197 |                 bias_correction2 = 1 - beta2 ** state['step']
198 | 
199 |                 if group['warmup'] > state['step']:
200 |                     scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
201 |                 else:
202 |                     scheduled_lr = group['lr']
203 | 
204 |                 step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1  # apply the warmed-up lr to the step, not just to weight decay
205 | 
206 |                 if group['weight_decay'] != 0:
207 |                     p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * scheduled_lr)
208 | 
209 |                 p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
210 | 
211 |                 p.data.copy_(p_data_fp32)
212 | 
213 |         return loss
214 | 
--------------------------------------------------------------------------------
/torch_utils/criterion/dice.py:
--------------------------------------------------------------------------------
1 | # reference: https://github.com/qubvel/segmentation_models.pytorch
2 | # reference: https://github.com/BloodAxe/pytorch-toolbelt
3 | 
4 | from typing import List, Optional
5 | 
6 | import torch
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from torch.nn.modules.loss import _Loss
10 | 
11 | __all__ = ["DiceLoss", "TverskyLoss"]
12 | 
13 | BINARY_MODE = "binary"
14 | MULTICLASS_MODE = "multiclass"
15 | MULTILABEL_MODE = "multilabel"
16 | 
17 | 
18 | def to_tensor(x, dtype=None) -> torch.Tensor:
19 |     if isinstance(x, torch.Tensor):
20 |         if dtype is not None:
21 |             x = x.type(dtype)
22 |         return x
23 |     if isinstance(x, np.ndarray):
24 |         x = torch.from_numpy(x)
25 |         if dtype is not None:
26 |             x = x.type(dtype)
27 |         return x
28 |     if isinstance(x, (list, tuple)):
29 |         x = np.array(x)
30 |         x = torch.from_numpy(x)
31 |         if dtype is not None:
32 |             x = x.type(dtype)
33 |         return x
34 | 
35 | 
36 | def soft_dice_score(
37 |     output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
38 | ) -> torch.Tensor:
39 |     """
40 |     :param output:
41 |     :param target:
42 |     :param smooth:
43 |     :param eps:
44 |     :return:
45 |     Shape:
46 |         - Input: :math:`(N, NC, *)` where :math:`*` means any number
47 |           of additional dimensions
48 |         - Target: :math:`(N, NC, *)`, same shape as the input
49 |         - Output: scalar.
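    Example (editor's illustrative sketch, shapes only):
        >>> pred = torch.rand(2, 3, 16)   # N=2, NC=3, flattened spatial dim
        >>> gt = (torch.rand(2, 3, 16) > 0.5).float()
        >>> soft_dice_score(pred, gt, dims=(0, 2)).shape   # one score per channel
        torch.Size([3])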
50 | """ 51 | assert output.size() == target.size() 52 | if dims is not None: 53 | intersection = torch.sum(output * target, dim=dims) 54 | cardinality = torch.sum(output + target, dim=dims) 55 | else: 56 | intersection = torch.sum(output * target) 57 | cardinality = torch.sum(output + target) 58 | dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps) 59 | return dice_score 60 | 61 | 62 | def soft_tversky_score(output: torch.Tensor, target: torch.Tensor, alpha: float, beta: float, 63 | smooth: float = 0.0, eps: float = 1e-7, dims=None) -> torch.Tensor: 64 | assert output.size() == target.size() 65 | if dims is not None: 66 | intersection = torch.sum(output * target, dim=dims) # TP 67 | fp = torch.sum(output * (1. - target), dim=dims) 68 | fn = torch.sum((1 - output) * target, dim=dims) 69 | else: 70 | intersection = torch.sum(output * target) # TP 71 | fp = torch.sum(output * (1. - target)) 72 | fn = torch.sum((1 - output) * target) 73 | 74 | tversky_score = (intersection + smooth) / (intersection + alpha * fp + beta * fn + smooth).clamp_min(eps) 75 | return tversky_score 76 | 77 | 78 | class DiceLoss(_Loss): 79 | """ 80 | Implementation of Dice loss for image segmentation task. 81 | It supports binary, multiclass and multilabel cases 82 | """ 83 | 84 | def __init__( 85 | self, 86 | mode: str, 87 | classes: List[int] = None, 88 | log_loss=False, 89 | from_logits=True, 90 | smooth: float = 0.0, 91 | ignore_index=None, 92 | eps=1e-7, 93 | ): 94 | """ 95 | :param mode: Metric mode {'binary', 'multiclass', 'multilabel'} 96 | :param classes: Optional list of classes that contribute in loss computation; 97 | By default, all channels are included. 98 | :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard` 99 | :param from_logits: If True assumes input is raw logits 100 | :param smooth: 101 | :param ignore_index: Label that indicates ignored pixels (does not contribute to loss) 102 | :param eps: Small epsilon for numerical stability 103 | """ 104 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 105 | super(DiceLoss, self).__init__() 106 | self.mode = mode 107 | if classes is not None: 108 | assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" 109 | classes = to_tensor(classes, dtype=torch.long) 110 | 111 | self.classes = classes 112 | self.from_logits = from_logits 113 | self.smooth = smooth 114 | self.eps = eps 115 | self.ignore_index = ignore_index 116 | self.log_loss = log_loss 117 | 118 | def forward(self, y_pred, y_true): 119 | """ 120 | :param y_pred: NxCxHxW 121 | :param y_true: NxHxW 122 | :return: scalar 123 | """ 124 | assert y_true.size(0) == y_pred.size(0) 125 | 126 | if self.from_logits: 127 | # Apply activations to get [0..1] class probabilities 128 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 129 | # extreme values 0 and 1 130 | if self.mode == MULTICLASS_MODE: 131 | y_pred = y_pred.log_softmax(dim=1).exp() 132 | else: 133 | y_pred = F.logsigmoid(y_pred).exp() 134 | 135 | bs = y_true.size(0) 136 | num_classes = y_pred.size(1) 137 | dims = (0, 2) 138 | 139 | if self.mode == BINARY_MODE: 140 | y_true = y_true.view(bs, 1, -1) 141 | y_pred = y_pred.view(bs, 1, -1) 142 | 143 | if self.ignore_index is not None: 144 | mask = y_true != self.ignore_index 145 | y_pred = y_pred * mask 146 | y_true = y_true * mask 147 | 148 | if self.mode == MULTICLASS_MODE: 149 | y_true = y_true.view(bs, -1) 150 | y_pred = y_pred.view(bs, num_classes, 
                                     -1)
151 | 
152 |             if self.ignore_index is not None:
153 |                 mask = y_true != self.ignore_index
154 |                 y_pred = y_pred * mask.unsqueeze(1)
155 | 
156 |                 y_true = F.one_hot((y_true * mask).to(torch.long), num_classes)  # N,H*W -> N,H*W,C
157 |                 y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # N,H*W,C -> N,C,H*W
158 |             else:
159 |                 y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W,C
160 |                 y_true = y_true.permute(0, 2, 1)  # N,H*W,C -> N,C,H*W
161 | 
162 |         if self.mode == MULTILABEL_MODE:
163 |             y_true = y_true.view(bs, num_classes, -1)
164 |             y_pred = y_pred.view(bs, num_classes, -1)
165 | 
166 |             if self.ignore_index is not None:
167 |                 mask = y_true != self.ignore_index
168 |                 y_pred = y_pred * mask
169 |                 y_true = y_true * mask
170 | 
171 |         scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)
172 | 
173 |         if self.log_loss:
174 |             loss = -torch.log(scores.clamp_min(self.eps))
175 |         else:
176 |             loss = 1.0 - scores
177 | 
178 |         # Dice loss is undefined for empty classes,
179 |         # so we zero the contribution of channels that have no true pixels.
180 |         # NOTE: A better workaround would be to use the loss term `mean(y_pred)`
181 |         # for this case; however, it would then be a modified jaccard loss.
182 | 
183 |         mask = y_true.sum(dims) > 0
184 |         loss *= mask.to(loss.dtype)
185 | 
186 |         if self.classes is not None:
187 |             loss = loss[self.classes]
188 | 
189 |         return loss.mean()
190 | 
191 | 
192 | class TverskyLoss(DiceLoss):
193 |     """Implementation of Tversky loss for image segmentation task,
194 |     where FPs and FNs are weighted by the alpha and beta params.
195 |     With alpha == beta == 0.5, this loss becomes equal to DiceLoss.
196 |     It supports binary, multiclass and multilabel cases
197 |     Args:
198 |         mode: Metric mode {'binary', 'multiclass', 'multilabel'}
199 |         classes: Optional list of classes that contribute in loss computation;
200 |             By default, all channels are included.
201 |         log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky``
202 |         from_logits: If True assumes input is raw logits
203 |         smooth:
204 |         ignore_index: Label that indicates ignored pixels (does not contribute to loss)
205 |         eps: Small epsilon for numerical stability
206 |         alpha: Weight constant that penalizes the model for FPs (false positives)
207 |         beta: Weight constant that penalizes the model for FNs (false negatives)
208 |         gamma: Constant that squares the error function.
Defaults to ``1.0`` 209 | Return: 210 | loss: torch.Tensor 211 | """ 212 | 213 | def __init__( 214 | self, 215 | mode: str, 216 | classes: List[int] = None, 217 | log_loss: bool = False, 218 | from_logits: bool = True, 219 | smooth: float = 0.0, 220 | ignore_index: Optional[int] = None, 221 | eps: float = 1e-7, 222 | alpha: float = 0.5, 223 | beta: float = 0.5, 224 | gamma: float = 1.0, 225 | ): 226 | 227 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 228 | super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps) 229 | self.alpha = alpha 230 | self.beta = beta 231 | self.gamma = gamma 232 | 233 | def aggregate_loss(self, loss): 234 | return loss.mean() ** self.gamma 235 | 236 | def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: 237 | return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims) 238 | -------------------------------------------------------------------------------- /torch_utils/optimizer/timm_optim/lamb.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb 2 | 3 | This optimizer code was adapted from the following (starting with latest) 4 | * https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py 5 | * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py 6 | * https://github.com/cybertronai/pytorch-lamb 7 | 8 | Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is 9 | similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. 10 | 11 | In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. 12 | 13 | Original copyrights for above sources are below. 14 | 15 | Modifications Copyright 2021 Ross Wightman 16 | """ 17 | # Copyright (c) 2021, Habana Labs Ltd. All rights reserved. 18 | 19 | # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. 20 | # 21 | # Licensed under the Apache License, Version 2.0 (the "License"); 22 | # you may not use this file except in compliance with the License. 23 | # You may obtain a copy of the License at 24 | # 25 | # http://www.apache.org/licenses/LICENSE-2.0 26 | # 27 | # Unless required by applicable law or agreed to in writing, software 28 | # distributed under the License is distributed on an "AS IS" BASIS, 29 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 30 | # See the License for the specific language governing permissions and 31 | # limitations under the License. 32 | 33 | # MIT License 34 | # 35 | # Copyright (c) 2019 cybertronai 36 | # 37 | # Permission is hereby granted, free of charge, to any person obtaining a copy 38 | # of this software and associated documentation files (the "Software"), to deal 39 | # in the Software without restriction, including without limitation the rights 40 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 41 | # copies of the Software, and to permit persons to whom the Software is 42 | # furnished to do so, subject to the following conditions: 43 | # 44 | # The above copyright notice and this permission notice shall be included in all 45 | # copies or substantial portions of the Software. 
46 | # 47 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 48 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 49 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 50 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 51 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 52 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 53 | # SOFTWARE. 54 | import math 55 | 56 | import torch 57 | from torch.optim import Optimizer 58 | 59 | 60 | class Lamb(Optimizer): 61 | """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB 62 | reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py 63 | 64 | LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 65 | 66 | Arguments: 67 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups. 68 | lr (float, optional): learning rate. (default: 1e-3) 69 | betas (Tuple[float, float], optional): coefficients used for computing 70 | running averages of gradient and its norm. (default: (0.9, 0.999)) 71 | eps (float, optional): term added to the denominator to improve 72 | numerical stability. (default: 1e-8) 73 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 74 | grad_averaging (bool, optional): whether apply (1-beta2) to grad when 75 | calculating running averages of gradient. (default: True) 76 | max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) 77 | trust_clip (bool): enable LAMBC trust ratio clipping (default: False) 78 | always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 79 | weight decay parameter (default: False) 80 | 81 | .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: 82 | https://arxiv.org/abs/1904.00962 83 | .. _On the Convergence of Adam and Beyond: 84 | https://openreview.net/forum?id=ryQu7f-RZ 85 | """ 86 | 87 | def __init__( 88 | self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, 89 | weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False): 90 | defaults = dict( 91 | lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, 92 | grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, 93 | trust_clip=trust_clip, always_adapt=always_adapt) 94 | super().__init__(params, defaults) 95 | 96 | @torch.no_grad() 97 | def step(self, closure=None): 98 | """Performs a single optimization step. 99 | Arguments: 100 | closure (callable, optional): A closure that reevaluates the model 101 | and returns the loss. 
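
        Example (editor's illustrative sketch; ``model``, ``criterion``, ``x``, ``y``
        are assumed placeholders, not part of the original source):
            >>> opt = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)
            >>> criterion(model(x), y).backward()
            >>> opt.step()
            >>> opt.zero_grad()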
102 | """ 103 | loss = None 104 | if closure is not None: 105 | with torch.enable_grad(): 106 | loss = closure() 107 | 108 | device = self.param_groups[0]['params'][0].device 109 | one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly 110 | global_grad_norm = torch.zeros(1, device=device) 111 | for group in self.param_groups: 112 | for p in group['params']: 113 | if p.grad is None: 114 | continue 115 | grad = p.grad 116 | if grad.is_sparse: 117 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 118 | global_grad_norm.add_(grad.pow(2).sum()) 119 | 120 | global_grad_norm = torch.sqrt(global_grad_norm) 121 | # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes 122 | # scalar types properly https://github.com/pytorch/pytorch/issues/9190 123 | max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) 124 | clip_global_grad_norm = torch.where( 125 | global_grad_norm > max_grad_norm, 126 | global_grad_norm / max_grad_norm, 127 | one_tensor) 128 | 129 | for group in self.param_groups: 130 | bias_correction = 1 if group['bias_correction'] else 0 131 | beta1, beta2 = group['betas'] 132 | grad_averaging = 1 if group['grad_averaging'] else 0 133 | beta3 = 1 - beta1 if grad_averaging else 1.0 134 | 135 | # assume same step across group now to simplify things 136 | # per parameter step can be easily support by making it tensor, or pass list into kernel 137 | if 'step' in group: 138 | group['step'] += 1 139 | else: 140 | group['step'] = 1 141 | 142 | if bias_correction: 143 | bias_correction1 = 1 - beta1 ** group['step'] 144 | bias_correction2 = 1 - beta2 ** group['step'] 145 | else: 146 | bias_correction1, bias_correction2 = 1.0, 1.0 147 | 148 | for p in group['params']: 149 | if p.grad is None: 150 | continue 151 | grad = p.grad.div_(clip_global_grad_norm) 152 | state = self.state[p] 153 | 154 | # State initialization 155 | if len(state) == 0: 156 | # Exponential moving average of gradient valuesa 157 | state['exp_avg'] = torch.zeros_like(p) 158 | # Exponential moving average of squared gradient values 159 | state['exp_avg_sq'] = torch.zeros_like(p) 160 | 161 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 162 | 163 | # Decay the first and second moment running average coefficient 164 | exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t 165 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t 166 | 167 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 168 | update = (exp_avg / bias_correction1).div_(denom) 169 | 170 | weight_decay = group['weight_decay'] 171 | if weight_decay != 0: 172 | update.add_(p, alpha=weight_decay) 173 | 174 | if weight_decay != 0 or group['always_adapt']: 175 | # Layer-wise LR adaptation. By default, skip adaptation on parameters that are 176 | # excluded from weight decay, unless always_adapt == True, then always enabled. 
177 | w_norm = p.norm(2.0) 178 | g_norm = update.norm(2.0) 179 | # FIXME nested where required since logical and/or not working in PT XLA 180 | trust_ratio = torch.where( 181 | w_norm > 0, 182 | torch.where(g_norm > 0, w_norm / g_norm, one_tensor), 183 | one_tensor, 184 | ) 185 | if group['trust_clip']: 186 | # LAMBC trust clipping, upper bound fixed at one 187 | trust_ratio = torch.minimum(trust_ratio, one_tensor) 188 | update.mul_(trust_ratio) 189 | 190 | p.add_(update, alpha=-group['lr']) 191 | 192 | return loss 193 | -------------------------------------------------------------------------------- /torch_utils/criterion/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.loss import _WeightedLoss 5 | 6 | 7 | class LabelSmoothingCrossEntropy(nn.Module): 8 | """ 9 | NLL loss with label smoothing. 10 | """ 11 | 12 | def __init__(self, smoothing=0.1): 13 | """ 14 | Constructor for the LabelSmoothing module. 15 | :param smoothing: label smoothing factor 16 | """ 17 | super(LabelSmoothingCrossEntropy, self).__init__() 18 | assert smoothing < 1.0 19 | self.smoothing = smoothing 20 | self.confidence = 1. - smoothing 21 | 22 | def forward(self, x, target): 23 | logprobs = F.log_softmax(x, dim=-1) 24 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 25 | nll_loss = nll_loss.squeeze(1) 26 | smooth_loss = -logprobs.mean(dim=-1) 27 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 28 | return loss.mean() 29 | 30 | 31 | class SmoothBCEwLogits(_WeightedLoss): 32 | def __init__(self, weight=None, reduction='mean', smoothing=0.0): 33 | super().__init__(weight=weight, reduction=reduction) 34 | self.smoothing = smoothing 35 | self.weight = weight 36 | self.reduction = reduction 37 | 38 | @staticmethod 39 | def _smooth(targets: torch.Tensor, n_labels: int, smoothing=0.0): 40 | assert 0 <= smoothing < 1 41 | with torch.no_grad(): 42 | targets = targets * (1.0 - smoothing) + 0.5 * smoothing 43 | return targets 44 | 45 | def forward(self, inputs, targets): 46 | targets = SmoothBCEwLogits._smooth(targets, inputs.size(-1), self.smoothing) 47 | loss = F.binary_cross_entropy_with_logits(inputs, targets, self.weight) 48 | 49 | if self.reduction == 'sum': 50 | loss = loss.sum() 51 | elif self.reduction == 'mean': 52 | loss = loss.mean() 53 | else: 54 | loss = loss.mean() 55 | 56 | return loss 57 | 58 | 59 | class SoftTargetCrossEntropy(nn.Module): 60 | 61 | def __init__(self): 62 | super(SoftTargetCrossEntropy, self).__init__() 63 | 64 | def forward(self, x, target): 65 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 66 | return loss.mean() 67 | 68 | 69 | class KLDivLosswSoftmax(nn.Module): 70 | """KL-divergence with softmax""" 71 | 72 | def __init__(self): 73 | super(KLDivLosswSoftmax, self).__init__() 74 | self.loss = nn.KLDivLoss(reduction='batchmean') 75 | 76 | def forward(self, model_output, target): 77 | log = F.log_softmax(model_output, dim=-1) 78 | loss = self.loss(log, target) 79 | return loss 80 | 81 | 82 | class topkLoss(nn.Module): 83 | """topkLoss: Online Hard Example Mining""" 84 | 85 | def __init__(self, loss, top_k=0.75): 86 | super(topkLoss, self).__init__() 87 | self.top_k = top_k 88 | self.loss = loss 89 | 90 | def forward(self, input, target): 91 | loss = self.loss(input, target) 92 | if self.top_k == 1: 93 | return torch.mean(loss) 94 | else: 95 | valid_loss, idxs = torch.topk(loss, round(self.top_k 
* loss.size()[0]), dim=0) 96 | return torch.mean(valid_loss) 97 | 98 | # class JsdCrossEntropy(nn.Module): 99 | # """ Jensen-Shannon Divergence + Cross-Entropy Loss for AugMix 100 | # Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py 101 | # From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - 102 | # https://arxiv.org/abs/1912.02781 103 | # Hacked together by / Copyright 2020 Ross Wightman 104 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/loss/jsd.py 105 | # require: AugMixDataset(split data) + split_data_collate + JsdCrossEntropy 106 | # optional: split bn for different strength of augmentations 107 | 108 | # Example: 109 | # >>> num_aug_splits = 3 110 | # >>> # from timm.models import convert_splitbn_model 111 | # >>> # model = convert_splitbn_model(model, max(num_aug_splits, 2)) 112 | # >>> dataset = AugMixDataset(dataset_train, num_splits=num_aug_splits) # TODO: re-implement augmix dataset from timm 113 | # >>> dataloader = DataLoader(dataset, bs, collate_fn=split_data_collate) 114 | # >>> # split_data_collate: (s,s_a1,s_a2) to [s1,s2,s3, s1_a1,s2_a1,s3_a1, s1_a2,s2_a2,s3_a2] 115 | # >>> train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=0.1).cuda() 116 | # """ 117 | # def __init__(self, num_splits=3, alpha=12, smoothing=0.1): 118 | # super().__init__() 119 | # self.num_splits = num_splits 120 | # self.alpha = alpha 121 | # if smoothing is not None and smoothing > 0: 122 | # self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) 123 | # else: 124 | # self.cross_entropy_loss = torch.nn.CrossEntropyLoss() 125 | 126 | # def __call__(self, output, target): 127 | # split_size = output.shape[0] // self.num_splits 128 | # assert split_size * self.num_splits == output.shape[0] 129 | # logits_split = torch.split(output, split_size) 130 | 131 | # # Cross-entropy is only computed on clean images 132 | # loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) 133 | # probs = [F.softmax(logits, dim=1) for logits in logits_split] 134 | 135 | # # Clamp mixture distribution to avoid exploding KL divergence 136 | # logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() 137 | # loss += self.alpha * sum([F.kl_div( 138 | # logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) 139 | # return loss 140 | 141 | 142 | class SoftBCEWithLogitsLoss(nn.Module): 143 | """ 144 | Drop-in replacement for nn.BCEWithLogitsLoss with few additions: 145 | - Support of ignore_index value 146 | - Support of label smoothing 147 | """ 148 | 149 | __constants__ = ["weight", "pos_weight", "reduction", "ignore_index", "smooth_factor"] 150 | 151 | def __init__( 152 | self, weight=None, ignore_index=None, reduction="mean", smooth_factor=None, pos_weight=None 153 | ): 154 | super().__init__() 155 | self.ignore_index = ignore_index 156 | self.reduction = reduction 157 | self.smooth_factor = smooth_factor 158 | self.register_buffer("weight", weight) 159 | self.register_buffer("pos_weight", pos_weight) 160 | 161 | def forward(self, input, target): 162 | if self.smooth_factor is not None: 163 | soft_targets = ((1 - target) * self.smooth_factor + target * (1 - self.smooth_factor)).type_as(input) 164 | else: 165 | soft_targets = target.type_as(input) 166 | 167 | loss = F.binary_cross_entropy_with_logits( 168 | input, soft_targets, self.weight, pos_weight=self.pos_weight, reduction="none" 169 | ) 170 | 171 | if self.ignore_index is not None: 172 | 
not_ignored_mask = target != self.ignore_index 173 | loss *= not_ignored_mask.type_as(loss) 174 | 175 | if self.reduction == "mean": 176 | loss = loss.mean() 177 | 178 | if self.reduction == "sum": 179 | loss = loss.sum() 180 | 181 | return loss 182 | 183 | 184 | def label_smoothed_nll_loss( 185 | lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1 186 | ) -> torch.Tensor: 187 | """ 188 | Source: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py 189 | :param lprobs: Log-probabilities of predictions (e.g after log_softmax) 190 | :param target: 191 | :param epsilon: 192 | :param ignore_index: 193 | :param reduction: 194 | :return: 195 | """ 196 | if target.dim() == lprobs.dim() - 1: 197 | target = target.unsqueeze(dim) 198 | 199 | if ignore_index is not None: 200 | pad_mask = target.eq(ignore_index) 201 | target = target.masked_fill(pad_mask, 0) 202 | nll_loss = -lprobs.gather(dim=dim, index=target) 203 | smooth_loss = -lprobs.sum(dim=dim, keepdim=True) 204 | 205 | # nll_loss.masked_fill_(pad_mask, 0.0) 206 | # smooth_loss.masked_fill_(pad_mask, 0.0) 207 | nll_loss = nll_loss.masked_fill(pad_mask, 0.0) 208 | smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0) 209 | else: 210 | nll_loss = -lprobs.gather(dim=dim, index=target) 211 | smooth_loss = -lprobs.sum(dim=dim, keepdim=True) 212 | 213 | nll_loss = nll_loss.squeeze(dim) 214 | smooth_loss = smooth_loss.squeeze(dim) 215 | 216 | if reduction == "sum": 217 | nll_loss = nll_loss.sum() 218 | smooth_loss = smooth_loss.sum() 219 | if reduction == "mean": 220 | nll_loss = nll_loss.mean() 221 | smooth_loss = smooth_loss.mean() 222 | 223 | eps_i = epsilon / lprobs.size(dim) 224 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 225 | return loss 226 | 227 | 228 | class SoftCrossEntropyLoss(nn.Module): 229 | """ 230 | Drop-in replacement for nn.CrossEntropyLoss with few additions: 231 | - Support of label smoothing 232 | """ 233 | 234 | __constants__ = ["reduction", "ignore_index", "smooth_factor"] 235 | 236 | def __init__(self, reduction: str = "mean", smooth_factor: float = 0.0, ignore_index=255, dim=1): 237 | super().__init__() 238 | self.smooth_factor = smooth_factor 239 | self.ignore_index = ignore_index 240 | self.reduction = reduction 241 | self.dim = dim 242 | 243 | def forward(self, input, target): 244 | log_prob = F.log_softmax(input, dim=self.dim) 245 | return label_smoothed_nll_loss( 246 | log_prob, 247 | target, 248 | epsilon=self.smooth_factor, 249 | ignore_index=self.ignore_index, 250 | reduction=self.reduction, 251 | dim=self.dim, 252 | ) 253 | -------------------------------------------------------------------------------- /torch_utils/criterion/bitempered_loss.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/BloodAxe/pytorch-toolbelt 2 | 3 | import torch 4 | from typing import Optional 5 | from torch import nn, Tensor 6 | import torch.nn.functional as F 7 | 8 | __all__ = ["BiTemperedLogisticLoss", "BinaryBiTemperedLogisticLoss"] 9 | 10 | 11 | def log_t(u, t): 12 | """Compute log_t for `u'.""" 13 | if t == 1.0: 14 | return u.log() 15 | else: 16 | return (u.pow(1.0 - t) - 1.0) / (1.0 - t) 17 | 18 | 19 | def exp_t(u, t): 20 | """Compute exp_t for `u'.""" 21 | if t == 1: 22 | return u.exp() 23 | else: 24 | return (1.0 + (1.0 - t) * u).relu().pow(1.0 / (1.0 - t)) 25 | 26 | 27 | def compute_normalization_fixed_point(activations: Tensor, t: float, num_iters: int) -> 
Tensor: 28 | """Return the normalization value for each example (t > 1.0). 29 | Args: 30 | activations: A multi-dimensional tensor with last dimension `num_classes`. 31 | t: Temperature 2 (> 1.0 for tail heaviness). 32 | num_iters: Number of iterations to run the method. 33 | Return: A tensor of same shape as activation with the last dimension being 1. 34 | """ 35 | mu, _ = torch.max(activations, -1, keepdim=True) 36 | normalized_activations_step_0 = activations - mu 37 | 38 | normalized_activations = normalized_activations_step_0 39 | 40 | for _ in range(num_iters): 41 | logt_partition = torch.sum(exp_t(normalized_activations, t), -1, keepdim=True) 42 | normalized_activations = normalized_activations_step_0 * logt_partition.pow(1.0 - t) 43 | 44 | logt_partition = torch.sum(exp_t(normalized_activations, t), -1, keepdim=True) 45 | normalization_constants = -log_t(1.0 / logt_partition, t) + mu 46 | 47 | return normalization_constants 48 | 49 | 50 | def compute_normalization_binary_search(activations: Tensor, t: float, num_iters: int) -> Tensor: 51 | """Compute normalization value for each example (t < 1.0). 52 | Args: 53 | activations: A multi-dimensional tensor with last dimension `num_classes`. 54 | t: Temperature 2 (< 1.0 for finite support). 55 | num_iters: Number of iterations to run the method. 56 | Return: A tensor of same rank as activation with the last dimension being 1. 57 | """ 58 | mu, _ = torch.max(activations, -1, keepdim=True) 59 | normalized_activations = activations - mu 60 | 61 | effective_dim = torch.sum((normalized_activations > -1.0 / (1.0 - t)).to(torch.int32), dim=-1, keepdim=True).to( 62 | activations.dtype 63 | ) 64 | 65 | shape_partition = activations.shape[:-1] + (1,) 66 | lower = torch.zeros(shape_partition, dtype=activations.dtype, device=activations.device) 67 | upper = -log_t(1.0 / effective_dim, t) * torch.ones_like(lower) 68 | 69 | for _ in range(num_iters): 70 | logt_partition = (upper + lower) / 2.0 71 | sum_probs = torch.sum(exp_t(normalized_activations - logt_partition, t), dim=-1, keepdim=True) 72 | update = (sum_probs < 1.0).to(activations.dtype) 73 | lower = torch.reshape(lower * update + (1.0 - update) * logt_partition, shape_partition) 74 | upper = torch.reshape(upper * (1.0 - update) + update * logt_partition, shape_partition) 75 | 76 | logt_partition = (upper + lower) / 2.0 77 | return logt_partition + mu 78 | 79 | 80 | class ComputeNormalization(torch.autograd.Function): 81 | """ 82 | Class implementing custom backward pass for compute_normalization. See compute_normalization. 83 | """ 84 | 85 | @staticmethod 86 | def forward(ctx, activations, t, num_iters): 87 | if t < 1.0: 88 | normalization_constants = compute_normalization_binary_search(activations, t, num_iters) 89 | else: 90 | normalization_constants = compute_normalization_fixed_point(activations, t, num_iters) 91 | 92 | ctx.save_for_backward(activations, normalization_constants) 93 | ctx.t = t 94 | return normalization_constants 95 | 96 | @staticmethod 97 | def backward(ctx, grad_output): 98 | activations, normalization_constants = ctx.saved_tensors 99 | t = ctx.t 100 | normalized_activations = activations - normalization_constants 101 | probabilities = exp_t(normalized_activations, t) 102 | escorts = probabilities.pow(t) 103 | escorts = escorts / escorts.sum(dim=-1, keepdim=True) 104 | grad_input = escorts * grad_output 105 | 106 | return grad_input, None, None 107 | 108 | 109 | def compute_normalization(activations, t, num_iters=5): 110 | """Compute normalization value for each example. 
111 | Backward pass is implemented. 112 | Args: 113 | activations: A multi-dimensional tensor with last dimension `num_classes`. 114 | t: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 115 | num_iters: Number of iterations to run the method. 116 | Return: A tensor of same rank as activation with the last dimension being 1. 117 | """ 118 | return ComputeNormalization.apply(activations, t, num_iters) 119 | 120 | 121 | def tempered_softmax(activations, t, num_iters=5): 122 | """Tempered softmax function. 123 | Args: 124 | activations: A multi-dimensional tensor with last dimension `num_classes`. 125 | t: Temperature > 1.0. 126 | num_iters: Number of iterations to run the method. 127 | Returns: 128 | A probabilities tensor. 129 | """ 130 | if t == 1.0: 131 | return activations.softmax(dim=-1) 132 | 133 | normalization_constants = compute_normalization(activations, t, num_iters) 134 | return exp_t(activations - normalization_constants, t) 135 | 136 | 137 | def bi_tempered_logistic_loss(activations, labels, t1, t2, label_smoothing=0.0, num_iters=5, reduction="mean"): 138 | """Bi-Tempered Logistic Loss. 139 | Args: 140 | activations: A multi-dimensional tensor with last dimension `num_classes`. 141 | labels: A tensor with shape and dtype as activations (onehot), 142 | or a long tensor of one dimension less than activations (pytorch standard) 143 | t1: Temperature 1 (< 1.0 for boundedness). 144 | t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 145 | label_smoothing: Label smoothing parameter between [0, 1). Default 0.0. 146 | num_iters: Number of iterations to run the method. Default 5. 147 | reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``. 148 | ``'none'``: No reduction is applied, return shape is shape of 149 | activations without the last dimension. 150 | ``'mean'``: Loss is averaged over minibatch. Return shape (1,) 151 | ``'sum'``: Loss is summed over minibatch. Return shape (1,) 152 | Returns: 153 | A loss tensor. 154 | """ 155 | if len(labels.shape) < len(activations.shape): # not one-hot 156 | labels_onehot = torch.zeros_like(activations) 157 | labels_onehot.scatter_(1, labels[..., None], 1) 158 | else: 159 | labels_onehot = labels 160 | 161 | if label_smoothing > 0: 162 | num_classes = labels_onehot.shape[-1] 163 | labels_onehot = (1 - label_smoothing * num_classes / (num_classes - 1)) * labels_onehot + label_smoothing / ( 164 | num_classes - 1 165 | ) 166 | 167 | probabilities = tempered_softmax(activations, t2, num_iters) 168 | 169 | loss_values = ( 170 | labels_onehot * log_t(labels_onehot + 1e-10, t1) 171 | - labels_onehot * log_t(probabilities, t1) 172 | - labels_onehot.pow(2.0 - t1) / (2.0 - t1) 173 | + probabilities.pow(2.0 - t1) / (2.0 - t1) 174 | ) 175 | loss_values = loss_values.sum(dim=-1) # sum over classes 176 | 177 | if reduction == "none": 178 | return loss_values 179 | if reduction == "sum": 180 | return loss_values.sum() 181 | if reduction == "mean": 182 | return loss_values.mean() 183 | 184 | 185 | class BiTemperedLogisticLoss(nn.Module): 186 | """ 187 | 188 | https://ai.googleblog.com/2019/08/bi-tempered-logistic-loss-for-training.html 189 | https://arxiv.org/abs/1906.03361 190 | """ 191 | 192 | def __init__(self, t1=0.8, t2=1.4, smoothing=0.0, ignore_index=None, reduction: str = "mean"): 193 | """ 194 | 195 | Args: 196 | t1: Temperature 1 (< 1.0 for boundedness). 197 | t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 198 | smoothing: Label smoothing parameter between [0, 1). 
Default 0.0. 199 | ignore_index: ignore_index 200 | reduction: reduction 201 | """ 202 | super(BiTemperedLogisticLoss, self).__init__() 203 | self.t1 = t1 204 | self.t2 = t2 205 | self.smoothing = smoothing 206 | self.reduction = reduction 207 | self.ignore_index = ignore_index 208 | 209 | def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: 210 | loss = bi_tempered_logistic_loss( 211 | predictions, targets, t1=self.t1, t2=self.t2, label_smoothing=self.smoothing, reduction="none" 212 | ) 213 | 214 | if self.ignore_index is not None: 215 | mask = ~targets.eq(self.ignore_index) 216 | loss *= mask 217 | 218 | if self.reduction == "mean": 219 | loss = loss.mean() 220 | elif self.reduction == "sum": 221 | loss = loss.sum() 222 | return loss 223 | 224 | 225 | class BinaryBiTemperedLogisticLoss(nn.Module): 226 | """ 227 | Modification of BiTemperedLogisticLoss for binary classification case. 228 | It's signature matches nn.BCEWithLogitsLoss: Predictions and target tensors must have shape [B,1,...] 229 | 230 | References: 231 | https://ai.googleblog.com/2019/08/bi-tempered-logistic-loss-for-training.html 232 | https://arxiv.org/abs/1906.03361 233 | """ 234 | 235 | def __init__( 236 | self, t1=0.8, t2=1.4, smoothing: float = 0.0, ignore_index: Optional[int] = None, reduction: str = "mean" 237 | ): 238 | """ 239 | 240 | Args: 241 | t1: Temperature 1 (< 1.0 for boundedness). 242 | t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support). 243 | smoothing: Label smoothing parameter between [0, 1). Default 0.0. 244 | ignore_index: ignore_index 245 | reduction: reduction 246 | """ 247 | super().__init__() 248 | self.t1 = t1 249 | self.t2 = t2 250 | self.smoothing = smoothing 251 | self.reduction = reduction 252 | self.ignore_index = ignore_index 253 | 254 | def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: 255 | """ 256 | 257 | Args: 258 | predictions: [B,1,...] 259 | targets: [B,1,...] 
260 | 261 | Returns: 262 | 263 | """ 264 | if predictions.size(1) != 1 or targets.size(1) != 1: 265 | raise ValueError("Channel dimension for predictions and targets must be equal to 1") 266 | 267 | loss = bi_tempered_logistic_loss( 268 | torch.cat([-predictions, predictions], dim=1).moveaxis(1, -1), 269 | torch.cat([1 - targets, targets], dim=1).moveaxis(1, -1), 270 | t1=self.t1, 271 | t2=self.t2, 272 | label_smoothing=self.smoothing, 273 | reduction="none", 274 | ).unsqueeze(dim=1) 275 | 276 | if self.ignore_index is not None: 277 | mask = targets.eq(self.ignore_index) 278 | loss = torch.masked_fill(loss, mask, 0) 279 | 280 | if self.reduction == "mean": 281 | loss = loss.mean() 282 | elif self.reduction == "sum": 283 | loss = loss.sum() 284 | return loss 285 | -------------------------------------------------------------------------------- /torch_utils/optimizer/gc.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/Yonghongwei/Gradient-Centralization 2 | 3 | # GCC: GC for conv only 4 | # GC: GC for both conv and fc 5 | # weight decay: 5e-5 or 1e-4; momentum: 0.9 6 | 7 | import math 8 | import torch 9 | from torch.optim.optimizer import Optimizer, required 10 | 11 | 12 | class SGD_GCC(Optimizer): 13 | 14 | def __init__(self, params, lr=required, momentum=0.9, dampening=0, 15 | weight_decay=5e-5, nesterov=False): 16 | if lr is not required and lr < 0.0: 17 | raise ValueError("Invalid learning rate: {}".format(lr)) 18 | if momentum < 0.0: 19 | raise ValueError("Invalid momentum value: {}".format(momentum)) 20 | if weight_decay < 0.0: 21 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 22 | 23 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 24 | weight_decay=weight_decay, nesterov=nesterov) 25 | if nesterov and (momentum <= 0 or dampening != 0): 26 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 27 | super(SGD_GCC, self).__init__(params, defaults) 28 | 29 | def __setstate__(self, state): 30 | super(SGD_GCC, self).__setstate__(state) 31 | for group in self.param_groups: 32 | group.setdefault('nesterov', False) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | Arguments: 37 | closure (callable, optional): A closure that reevaluates the model 38 | and returns the loss. 
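
        Example (editor's illustrative sketch; ``model``, ``criterion``, ``x``, ``y``
        are assumed placeholders, not part of the original source):
            >>> opt = SGD_GCC(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-5)
            >>> criterion(model(x), y).backward()
            >>> opt.step()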
39 | """ 40 | loss = None 41 | if closure is not None: 42 | loss = closure() 43 | 44 | for group in self.param_groups: 45 | weight_decay = group['weight_decay'] 46 | momentum = group['momentum'] 47 | dampening = group['dampening'] 48 | nesterov = group['nesterov'] 49 | 50 | for p in group['params']: 51 | if p.grad is None: 52 | continue 53 | d_p = p.grad.data 54 | 55 | if weight_decay != 0: 56 | d_p.add_(weight_decay, p.data) 57 | 58 | # GC operation for Conv layers 59 | if len(list(d_p.size())) > 3: 60 | d_p.add_(-d_p.mean(dim=tuple(range(1, len(list(d_p.size())))), keepdim=True)) 61 | 62 | if momentum != 0: 63 | param_state = self.state[p] 64 | if 'momentum_buffer' not in param_state: 65 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 66 | else: 67 | buf = param_state['momentum_buffer'] 68 | buf.mul_(momentum).add_(1 - dampening, d_p) 69 | if nesterov: 70 | d_p = d_p.add(momentum, buf) 71 | else: 72 | d_p = buf 73 | 74 | p.data.add_(-group['lr'], d_p) 75 | 76 | return loss 77 | 78 | 79 | class SGD_GC(Optimizer): 80 | 81 | def __init__(self, params, lr=required, momentum=0.9, dampening=0, 82 | weight_decay=5e-5, nesterov=False): 83 | if lr is not required and lr < 0.0: 84 | raise ValueError("Invalid learning rate: {}".format(lr)) 85 | if momentum < 0.0: 86 | raise ValueError("Invalid momentum value: {}".format(momentum)) 87 | if weight_decay < 0.0: 88 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 89 | 90 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 91 | weight_decay=weight_decay, nesterov=nesterov) 92 | if nesterov and (momentum <= 0 or dampening != 0): 93 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 94 | super(SGD_GC, self).__init__(params, defaults) 95 | 96 | def __setstate__(self, state): 97 | super(SGD_GC, self).__setstate__(state) 98 | for group in self.param_groups: 99 | group.setdefault('nesterov', False) 100 | 101 | def step(self, closure=None): 102 | """Performs a single optimization step. 103 | Arguments: 104 | closure (callable, optional): A closure that reevaluates the model 105 | and returns the loss. 106 | """ 107 | loss = None 108 | if closure is not None: 109 | loss = closure() 110 | 111 | for group in self.param_groups: 112 | weight_decay = group['weight_decay'] 113 | momentum = group['momentum'] 114 | dampening = group['dampening'] 115 | nesterov = group['nesterov'] 116 | 117 | for p in group['params']: 118 | if p.grad is None: 119 | continue 120 | d_p = p.grad.data 121 | 122 | if weight_decay != 0: 123 | d_p.add_(weight_decay, p.data) 124 | 125 | # GC operation for Conv layers and FC layers 126 | if len(list(d_p.size())) > 1: 127 | d_p.add_(-d_p.mean(dim=tuple(range(1, len(list(d_p.size())))), keepdim=True)) 128 | 129 | if momentum != 0: 130 | param_state = self.state[p] 131 | if 'momentum_buffer' not in param_state: 132 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 133 | else: 134 | buf = param_state['momentum_buffer'] 135 | buf.mul_(momentum).add_(1 - dampening, d_p) 136 | if nesterov: 137 | d_p = d_p.add(momentum, buf) 138 | else: 139 | d_p = buf 140 | 141 | p.data.add_(-group['lr'], d_p) 142 | 143 | return loss 144 | 145 | 146 | class AdamW_GCC2(Optimizer): 147 | """Implements Adam algorithm. 148 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
149 | Arguments: 150 | params (iterable): iterable of parameters to optimize or dicts defining 151 | parameter groups 152 | lr (float, optional): learning rate (default: 1e-3) 153 | betas (Tuple[float, float], optional): coefficients used for computing 154 | running averages of gradient and its square (default: (0.9, 0.999)) 155 | eps (float, optional): term added to the denominator to improve 156 | numerical stability (default: 1e-8) 157 | weight_decay (float, optional): weight decay (L2 penalty) (default: 5e-5) 158 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 159 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 160 | .. _Adam: A Method for Stochastic Optimization: 161 | https://arxiv.org/abs/1412.6980 162 | .. _On the Convergence of Adam and Beyond: 163 | https://openreview.net/forum?id=ryQu7f-RZ 164 | """ 165 | 166 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 167 | weight_decay=5e-5, amsgrad=False): 168 | if not 0.0 <= lr: 169 | raise ValueError("Invalid learning rate: {}".format(lr)) 170 | if not 0.0 <= eps: 171 | raise ValueError("Invalid epsilon value: {}".format(eps)) 172 | if not 0.0 <= betas[0] < 1.0: 173 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 174 | if not 0.0 <= betas[1] < 1.0: 175 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 176 | defaults = dict(lr=lr, betas=betas, eps=eps, 177 | weight_decay=weight_decay, amsgrad=amsgrad) 178 | super(AdamW_GCC2, self).__init__(params, defaults) 179 | 180 | def __setstate__(self, state): 181 | super(AdamW_GCC2, self).__setstate__(state) 182 | for group in self.param_groups: 183 | group.setdefault('amsgrad', False) 184 | 185 | def step(self, closure=None): 186 | """Performs a single optimization step. 187 | Arguments: 188 | closure (callable, optional): A closure that reevaluates the model 189 | and returns the loss. 190 | """ 191 | loss = None 192 | if closure is not None: 193 | loss = closure() 194 | 195 | for group in self.param_groups: 196 | for p in group['params']: 197 | if p.grad is None: 198 | continue 199 | grad = p.grad.data 200 | if grad.is_sparse: 201 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 202 | amsgrad = group['amsgrad'] 203 | 204 | state = self.state[p] 205 | 206 | # State initialization 207 | if len(state) == 0: 208 | state['step'] = 0 209 | # Exponential moving average of gradient values 210 | state['exp_avg'] = torch.zeros_like(p.data) 211 | # Exponential moving average of squared gradient values 212 | state['exp_avg_sq'] = torch.zeros_like(p.data) 213 | if amsgrad: 214 | # Maintains max of all exp. moving avg. of sq. grad. 
values 215 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 216 | 217 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 218 | if amsgrad: 219 | max_exp_avg_sq = state['max_exp_avg_sq'] 220 | beta1, beta2 = group['betas'] 221 | 222 | # GC operation for Conv layers 223 | if len(list(grad.size())) > 3: 224 | # weight_mean = p.data.mean(dim=tuple(range(1, len(list(grad.size())))), keepdim=True) 225 | grad.add_(-grad.mean(dim=tuple(range(1, len(list(grad.size())))), keepdim=True)) 226 | 227 | state['step'] += 1 228 | 229 | # if group['weight_decay'] != 0: 230 | # grad = grad.add(group['weight_decay'], p.data) 231 | 232 | # Decay the first and second moment running average coefficient 233 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 234 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 235 | if amsgrad: 236 | # Maintains the maximum of all 2nd moment running avg. till now 237 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 238 | # Use the max. for normalizing running avg. of gradient 239 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 240 | else: 241 | denom = exp_avg_sq.sqrt().add_(group['eps']) 242 | 243 | bias_correction1 = 1 - beta1 ** state['step'] 244 | bias_correction2 = 1 - beta2 ** state['step'] 245 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 246 | 247 | # GC operation for Conv layers 248 | if len(list(grad.size())) > 3: 249 | delta = (step_size * torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom)).clone() 250 | delta.add_(-delta.mean(dim=tuple(range(1, len(list(grad.size())))), keepdim=True)) 251 | p.data.add_(-delta) 252 | else: 253 | p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom)) 254 | 255 | return loss 256 | -------------------------------------------------------------------------------- /torch_utils/optimizer/ranger21/rangerabel.py: -------------------------------------------------------------------------------- 1 | # Ranger21 - @lessw2020 2 | # This is experimental branch of auto lr...not recommended for use atm. 
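# Editor's note (illustrative, not part of the original source): gradient
# centralization, as used below, subtracts the per-filter mean from each
# gradient, e.g. for a conv weight of shape (out, in, kH, kW):
#     g = g - g.mean(dim=(1, 2, 3), keepdim=True)
# See centralize_gradient() defined further down for the actual implementation.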
3 | 
4 | # core components based on:
5 | 
6 | # MADGRAD: https://arxiv.org/abs/2101.11075
7 | 
8 | # warmup: https://arxiv.org/abs/1910.04209v3
9 | 
10 | # stable weight decay: https://arxiv.org/abs/2011.11152v3
11 | 
12 | # Gradient Centralization: https://arxiv.org/abs/2004.01461v2
13 | 
14 | 
15 | import torch
16 | import torch.optim as TO
17 | import torch.nn.functional as F
18 | 
19 | import math
20 | import collections
21 | 
22 | import copy
23 | from torch import linalg as LA
24 | 
25 | 
26 | def centralize_gradient(x, gc_conv_only=False):
27 |     """credit - https://github.com/Yonghongwei/Gradient-Centralization """
28 | 
29 |     size = len(list(x.size()))
30 |     # print(f"size = {size}")
31 | 
32 |     if gc_conv_only:
33 |         if size > 3:
34 |             x.add_(-x.mean(dim=tuple(range(1, size)), keepdim=True))
35 |     else:
36 |         if size > 1:
37 |             x.add_(-x.mean(dim=tuple(range(1, size)), keepdim=True))
38 |     return x
39 | 
40 | 
41 | class Ranger21abel(TO.Optimizer):
42 |     def __init__(
43 |         self,
44 |         params,
45 |         lr,
46 |         betas=(0.9, 0.999),  # temp for checking tuned warmups
47 |         momentum=0.9,
48 |         eps=1e-8,
49 |         num_batches_per_epoch=None,
50 |         num_epochs=None,
51 |         use_abel=True,
52 |         abel_decay_factor=.3,
53 |         use_warmup=True,
54 |         num_warmup_iterations=None,
55 |         weight_decay=1e-4,
56 |         decay_type="stable",
57 |         warmup_type="linear",
58 |         use_gradient_centralization=True,
59 |         gc_conv_only=False,
60 |     ):
61 | 
62 |         # todo - checks on incoming params
63 |         defaults = dict(
64 |             lr=lr, momentum=momentum, betas=betas, eps=eps, weight_decay=weight_decay
65 |         )
66 |         super().__init__(params, defaults)
67 | 
68 |         self.num_batches = num_batches_per_epoch
69 |         self.num_epochs = num_epochs
70 | 
71 |         self.warmup_type = warmup_type
72 |         self.use_gc = use_gradient_centralization  # was wrapped in a 1-tuple, which made the flag always truthy
73 |         self.gc_conv_only = gc_conv_only  # same 1-tuple bug fixed here
74 |         self.starting_lr = lr
75 |         self.current_lr = lr
76 | 
77 |         # abel
78 |         self.use_abel = use_abel
79 |         self.weight_list = []
80 |         self.batch_count = 0
81 |         self.epoch = 0
82 |         self.lr_decay_factor = abel_decay_factor
83 |         self.abel_decay_end = math.ceil(self.num_epochs * .85)
84 |         self.reached_minima = False
85 |         self.pweight_accumulator = 0
86 | 
87 |         # decay
88 |         self.decay = weight_decay
89 |         self.decay_type = decay_type
90 |         self.param_size = 0
91 | 
92 |         # warmup - we'll use default recommended in Ma/Yarats unless user specifies num iterations
93 |         self.use_warmup = use_warmup
94 |         if num_warmup_iterations is None:
95 |             self.num_warmup_iters = math.ceil(
96 |                 (2 / (1 - betas[1]))
97 |             )  # default untuned linear warmup
98 |         else:
99 |             self.num_warmup_iters = num_warmup_iterations
100 | 
101 |         # logging
102 |         self.variance_sum_tracking = []
103 | 
104 |         # print out initial settings to make usage easier
105 |         print(f"Ranger21 optimizer ready with the following settings:\n")
106 |         print(f"Learning rate of {self.starting_lr}")
107 |         if self.use_warmup:
108 |             print(f"{self.warmup_type} warmup, over {self.num_warmup_iters} iterations")
109 | 
110 |         print(f"Stable weight decay of {self.decay}")
111 |         if self.use_gc:
112 |             print("Gradient Centralization = On")
113 |         else:
114 |             print("Gradient Centralization = Off")
115 |         print(f"Num Epochs = {self.num_epochs}")
116 |         print(f"Num batches per epoch = {self.num_batches}")
117 | 
118 |     def __setstate__(self, state):
119 |         super().__setstate__(state)
120 | 
121 |     def warmup_dampening(self, lr, step):
122 |         # not usable yet
123 |         style = self.warmup_type
124 |         warmup = self.num_warmup_iters
125 | 
126 |         if style is None:
127 |             return 1.0
128 | 
129 |         if style == "linear":
130 |             return lr * min(1.0, (step / warmup))
131 | 
132 |         elif style == "exponential":
133 |             return lr * (1.0 - math.exp(-step / warmup))
134 |         else:
135 |             raise ValueError(f"warmup type {style} not implemented.")
136 | 
137 |     def get_variance(self):
138 |         return self.variance_sum_tracking
139 | 
140 |     def get_state_values(self, group, state):
141 |         beta1, beta2 = group["betas"]
142 |         mean_avg = state["mean_avg"]
143 |         variance_avg = state["variance_avg"]
144 | 
145 |         return beta1, beta2, mean_avg, variance_avg
146 | 
147 |     def abel_update(self, step_fn, weight_norm, current_lr):
148 |         '''update lr based on ABEL'''
149 | 
150 |         self.pweight_accumulator += weight_norm
151 | 
152 |         self.batch_count += 1
153 |         # print(f"self.batch count = {self.batch_count}")
154 |         if self.batch_count == self.num_batches:
155 |             self.epoch += 1
156 |             self.batch_count = 0
157 |             print(f"epoch eval for epoch {self.epoch}")
158 | 
159 |             # store weights
160 |             self.weight_list.append(self.pweight_accumulator)
161 | 
162 |             print(f"total norm for epoch {self.epoch} = {weight_norm}")
163 |             # self.pweight_accumulator = 0
164 | 
165 |         if self.batch_count != 0:
166 |             return None
167 |         # self.epoch +=1
168 |         new_lr = current_lr
169 | 
170 |         if len(self.weight_list) < 3:
171 |             print(len(self.weight_list))
172 |             return step_fn
173 | 
174 |         # compute weight norm delta
175 |         if (self.weight_list[-1] - self.weight_list[-2]) * (self.weight_list[-2] - self.weight_list[-3]) < 0:
176 |             if self.reached_minima:
177 |                 self.reached_minima = False
178 |                 new_lr *= self.lr_decay_factor
179 |                 # step_fn = self.update_train_step(self.learning_rate)
180 |             else:
181 |                 self.reached_minima = True
182 |                 print(f"\n*****\nABEL minimum detected, new lr = {new_lr}\n***\n")
183 | 
184 |         if self.epoch == self.abel_decay_end:
185 |             new_lr *= self.lr_decay_factor
186 |             print(f"abel final decay done, new lr = {new_lr}")
187 |         return new_lr
188 |     # @staticmethod
189 | 
190 |     @torch.no_grad()
191 |     def step(self, closure=None):
192 | 
193 |         loss = None
194 |         if closure is not None and callable(closure):  # collections.Callable was removed in Python 3.10
195 |             with torch.enable_grad():  # torch.grad() does not exist; enable_grad is the intended context manager
196 |                 loss = closure()
197 | 
198 |         param_size = 0
199 |         variance_ma_sum = 0.0
200 |         weight_norm = 0
201 | 
202 |         # phase 1 - accumulate all of the variance_ma_sum to use in stable weight decay
203 | 
204 |         for i, group in enumerate(self.param_groups):
205 |             for j, p in enumerate(group["params"]):
206 |                 if p.grad is None:
207 |                     continue
208 | 
209 |                 if not self.param_size:
210 |                     param_size += p.numel()
211 | 
212 |                 grad = p.grad
213 | 
214 |                 if grad.is_sparse:
215 |                     raise RuntimeError("sparse matrix not supported atm")
216 | 
217 |                 state = self.state[p]
218 | 
219 |                 current_weight_norm = LA.norm(p.data)
220 |                 # print(f"running norm = {current_weight_norm}")
221 |                 weight_norm += current_weight_norm.item()
222 | 
223 |                 # State initialization
224 |                 if len(state) == 0:
225 |                     # print("init state")
226 |                     state["step"] = 0
227 |                     # Exponential moving average of gradient values
228 |                     state["grad_ma"] = torch.zeros_like(
229 |                         p, memory_format=torch.preserve_format
230 |                     )
231 |                     # Exponential moving average of squared gradient values
232 |                     state["variance_ma"] = torch.zeros_like(
233 |                         p, memory_format=torch.preserve_format
234 |                     )
235 | 
236 |                 # centralize gradients
237 |                 if self.use_gc:
238 |                     grad = centralize_gradient(
239 |                         grad,
240 |                         gc_conv_only=self.gc_conv_only,
241 |                     )
242 |                 # else:
243 |                 #     grad = uncentralized_grad
244 | 
245 |                 state["step"] += 1
246 | 
247 |                 beta1, beta2 = group["betas"]
248 |                 grad_ma = state["grad_ma"]
249 |                 variance_ma =
state["variance_ma"] 250 | 251 | bias_correction2 = 1 - beta2 ** state["step"] 252 | 253 | # update the exp averages 254 | grad_ma.mul_(beta1).add_(grad, alpha=1 - beta1) 255 | 256 | variance_ma.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 257 | 258 | variance_ma_debiased = variance_ma / bias_correction2 259 | 260 | variance_ma_sum += variance_ma_debiased.sum() 261 | 262 | # print(f"variance hat sum = {exp_avg_sq_hat_sum}") 263 | # Calculate the sqrt of the mean of all elements in exp_avg_sq_hat 264 | 265 | # we will run this first epoch only and then memoize 266 | if not self.param_size: 267 | self.param_size = param_size 268 | print(f"params size saved") 269 | print(f"total param groups = {i+1}") 270 | print(f"total params in groups = {j+1}") 271 | 272 | if not self.param_size: 273 | raise ValueError("failed to set param size") 274 | 275 | # debugging 276 | self.variance_sum_tracking.append(variance_ma_sum.item()) 277 | 278 | variance_normalized = math.sqrt(variance_ma_sum / self.param_size) 279 | 280 | # print(f"variance mean sqrt = {variance_normalized}") 281 | 282 | # phase 2 - apply weight decay and step 283 | for group in self.param_groups: 284 | for p in group["params"]: 285 | if p.grad is None: 286 | continue 287 | 288 | state = self.state[p] 289 | 290 | step = state["step"] 291 | 292 | # Perform stable weight decay 293 | decay = group["weight_decay"] 294 | eps = group["eps"] 295 | # lr = group["lr"] 296 | lr = self.current_lr 297 | 298 | if self.use_warmup: 299 | lr = self.warmup_dampening(lr, step) 300 | # if step < 10: 301 | # print(f"warmup dampening at step {step} = {lr} vs {group['lr']}") 302 | 303 | if decay: 304 | p.data.mul_(1 - decay * lr / variance_normalized) 305 | 306 | beta1, beta2 = group["betas"] 307 | grad_exp_avg = state["grad_ma"] 308 | variance_ma = state["variance_ma"] 309 | 310 | bias_correction1 = 1 - beta1 ** step 311 | bias_correction2 = 1 - beta2 ** step 312 | 313 | variance_biased_ma = variance_ma / bias_correction2 314 | 315 | denom = variance_biased_ma.sqrt().add(eps) 316 | 317 | # weight_mod = grad_exp_avg / denom 318 | 319 | step_size = lr / bias_correction1 320 | 321 | # update weights 322 | # p.data.add_(weight_mod, alpha=-step_size) 323 | p.addcdiv_(grad_exp_avg, denom, value=-step_size) 324 | 325 | # abel step 326 | abel_result = self.abel_update(None, weight_norm, self.current_lr) 327 | if abel_result is not None: 328 | self.current_lr = abel_result 329 | 330 | return loss 331 | -------------------------------------------------------------------------------- /torch_utils/dataset/mixup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mixup and Cutmix 3 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 4 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 5 | changed from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/mixup.py 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | from torch.distributions import Beta 11 | from torch.utils.data import Dataset 12 | 13 | 14 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 15 | x = x.long().view(-1, 1) 16 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 17 | 18 | 19 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 20 | off_value = smoothing / num_classes 21 | on_value = 1. 
- smoothing + off_value 22 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 23 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 24 | return y1 * lam + y2 * (1. - lam) 25 | 26 | 27 | def mixup_target_multi_binary(target, lam=1., smoothing=0.0, device='cuda'): 28 | target = target * (1. - smoothing) + smoothing / 2. 29 | y1 = target.to(device) 30 | y2 = target.flip(0).to(device) 31 | return y1 * lam + y2 * (1. - lam) 32 | 33 | 34 | def rand_bbox(img_shape, lam, margin=0., count=None): 35 | """ Standard CutMix bounding-box 36 | Generates a random square bbox based on lambda value. This impl includes 37 | support for enforcing a border margin as percent of bbox dimensions. 38 | Args: 39 | img_shape (tuple): Image shape as tuple 40 | lam (float): Cutmix lambda value 41 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 42 | count (int): Number of bbox to generate 43 | """ 44 | ratio = np.sqrt(1 - lam) 45 | img_h, img_w = img_shape[-2:] 46 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 47 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 48 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 49 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 50 | yl = np.clip(cy - cut_h // 2, 0, img_h) 51 | yh = np.clip(cy + cut_h // 2, 0, img_h) 52 | xl = np.clip(cx - cut_w // 2, 0, img_w) 53 | xh = np.clip(cx + cut_w // 2, 0, img_w) 54 | return yl, yh, xl, xh 55 | 56 | 57 | def rand_bbox_minmax(img_shape, minmax, count=None): 58 | """ Min-Max CutMix bounding-box 59 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 60 | based on min/max percent values applied to each dimension of the input image. 61 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 62 | Args: 63 | img_shape (tuple): Image shape as tuple 64 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 65 | count (int): Number of bbox to generate 66 | """ 67 | assert len(minmax) == 2 68 | img_h, img_w = img_shape[-2:] 69 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 70 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 71 | yl = np.random.randint(0, img_h - cut_h, size=count) 72 | xl = np.random.randint(0, img_w - cut_w, size=count) 73 | yu = yl + cut_h 74 | xu = xl + cut_w 75 | return yl, yu, xl, xu 76 | 77 | 78 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 79 | """ Generate bbox and apply lambda correction. 80 | """ 81 | if ratio_minmax is not None: 82 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 83 | else: 84 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 85 | if correct_lam or ratio_minmax is not None: 86 | bbox_area = (yu - yl) * (xu - xl) 87 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) 88 | return (yl, yu, xl, xu), lam 89 | 90 | 91 | class Mixup: 92 | """ 93 | Mixup/Cutmix that applies different params to each element or whole batch 94 | 95 | Args: 96 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 97 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 98 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 
99 |         prob (float): probability of applying mixup or cutmix per batch or element
100 |         switch_prob (float): probability of switching to cutmix instead of mixup when both are active
101 |         mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), or 'elem' (element))
102 |         correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
103 |         onehot (bool): True if targets are Long class indices to be one-hot encoded; False for float multi-hot or soft labels
104 |         label_smoothing (float): apply label smoothing to the mixed target tensor
105 |         num_classes (int): number of classes for target
106 | 
107 |     Examples::
108 |         >>> mixup, cutmix = 0.35, 0.15
109 |         >>> prob = mixup + cutmix
110 |         >>> switch_prob = cutmix / prob
111 |         >>> mixup_fn = Mixup(prob=prob, switch_prob=switch_prob, onehot=False, label_smoothing=0.0)
112 |         >>> for batch_idx, (input, target) in enumerate(loader):
113 |         >>>     input, target = input.cuda(), target.cuda()
114 |         >>>     input, target = mixup_fn(input, target)
115 |     """
116 | 
117 |     def __init__(self, mixup_alpha=0.2, cutmix_alpha=1.0, cutmix_minmax=None, prob=0.2, switch_prob=0.3,
118 |                  mode='elem', correct_lam=True, onehot=True, label_smoothing=0.0, num_classes=1000):
119 |         self.mixup_alpha = mixup_alpha
120 |         self.cutmix_alpha = cutmix_alpha
121 |         self.cutmix_minmax = cutmix_minmax
122 |         if self.cutmix_minmax is not None:
123 |             assert len(self.cutmix_minmax) == 2
124 |             # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
125 |             self.cutmix_alpha = 1.0
126 |         self.mix_prob = prob
127 |         self.switch_prob = switch_prob
128 |         self.onehot = onehot
129 |         self.label_smoothing = label_smoothing
130 |         self.num_classes = num_classes
131 |         self.mode = mode
132 |         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
133 |         self.mixup_enabled = True  # set to False to disable mixing (intended to be set by train loop)
134 | 
135 |     def _params_per_elem(self, batch_size):
136 |         lam = np.ones(batch_size, dtype=np.float32)
137 |         use_cutmix = np.zeros(batch_size, dtype=bool)  # np.bool was removed in NumPy >= 1.24
138 |         if self.mixup_enabled:
139 |             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
140 |                 use_cutmix = np.random.rand(batch_size) < self.switch_prob
141 |                 lam_mix = np.where(
142 |                     use_cutmix,
143 |                     np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
144 |                     np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
145 |             elif self.mixup_alpha > 0.:
146 |                 lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
147 |             elif self.cutmix_alpha > 0.:
148 |                 use_cutmix = np.ones(batch_size, dtype=bool)
149 |                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
150 |             else:
151 |                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
152 |             lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
153 |         return lam, use_cutmix
154 | 
155 |     def _params_per_batch(self):
156 |         lam = 1.
157 |         use_cutmix = False
158 |         if self.mixup_enabled and np.random.rand() < self.mix_prob:
159 |             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
160 |                 use_cutmix = np.random.rand() < self.switch_prob
161 |                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
162 |                     np.random.beta(self.mixup_alpha, self.mixup_alpha)
163 |             elif self.mixup_alpha > 0.:
164 |                 lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
165 |             elif self.cutmix_alpha > 0.:
166 |                 use_cutmix = True
167 |                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
168 |             else:
169 |                 assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
170 |             lam = float(lam_mix)
171 |         return lam, use_cutmix
172 | 
173 |     def _mix_elem(self, x):
174 |         batch_size = len(x)
175 |         lam_batch, use_cutmix = self._params_per_elem(batch_size)
176 |         x_orig = x.clone()  # need to keep an unmodified original for mixing source
177 |         for i in range(batch_size):
178 |             j = batch_size - i - 1
179 |             lam = lam_batch[i]
180 |             if lam != 1.:
181 |                 if use_cutmix[i]:
182 |                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
183 |                         x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
184 |                     x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
185 |                     lam_batch[i] = lam
186 |                 else:
187 |                     x[i] = x[i] * lam + x_orig[j] * (1 - lam)
188 |         return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
189 | 
190 |     def _mix_pair(self, x):
191 |         batch_size = len(x)
192 |         lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
193 |         x_orig = x.clone()  # need to keep an unmodified original for mixing source
194 |         for i in range(batch_size // 2):
195 |             j = batch_size - i - 1
196 |             lam = lam_batch[i]
197 |             if lam != 1.:
198 |                 if use_cutmix[i]:
199 |                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
200 |                         x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
201 |                     x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
202 |                     x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
203 |                     lam_batch[i] = lam
204 |                 else:
205 |                     x[i] = x[i] * lam + x_orig[j] * (1 - lam)
206 |                     x[j] = x[j] * lam + x_orig[i] * (1 - lam)
207 |         lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
208 |         return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
209 | 
210 |     def _mix_batch(self, x):
211 |         lam, use_cutmix = self._params_per_batch()
212 |         if lam == 1.:
213 |             return 1.
214 |         if use_cutmix:
215 |             (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
216 |                 x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
217 |             x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
218 |         else:
219 |             x_flipped = x.flip(0).mul_(1. - lam)
220 |             x.mul_(lam).add_(x_flipped)
221 |         return lam
222 | 
223 |     def __call__(self, x, target):
224 |         assert len(x) % 2 == 0, 'Batch size should be even when using this'
225 |         if self.mode == 'elem':
226 |             lam = self._mix_elem(x)
227 |         elif self.mode == 'pair':
228 |             lam = self._mix_pair(x)
229 |         else:
230 |             lam = self._mix_batch(x)
231 |         if self.onehot:
232 |             target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device=x.device)
233 |         else:
234 |             target = mixup_target_multi_binary(target, lam, self.label_smoothing, device=x.device)
235 |         return x, target
236 | 
237 | 
238 | class MixupDataset(Dataset):
239 |     """Mixup for soft labels (shape [bs, num_class])"""
240 | 
241 |     def __init__(self, dataset, alpha=0.2, prob=0.1, mixup_to_cutmix=0.0, raw=False):
242 |         self.dataset = dataset
243 |         self.prob = prob
244 |         self.mixup_to_cutmix = mixup_to_cutmix
245 |         self.data_size = len(self)
246 |         self.beta = Beta(torch.FloatTensor([alpha]), torch.FloatTensor([alpha]))
247 |         self.raw = raw
248 | 
249 |     def __getitem__(self, idx):
250 |         img, label = self.dataset[idx]
251 |         label = np.array(label, dtype=np.float32)  # expects labels like [0, 1, 0, 0] or [0.0, 0.9, 0.05, 0.05]
252 |         if torch.rand(1)[0] < self.prob:
253 |             lam = self.beta.sample().numpy()
254 |             rand_idx = torch.randint(self.data_size, (1,))[0].numpy()
255 |             img_aug, label_aug = self.dataset[rand_idx]
256 |             label_aug = np.array(label_aug, dtype=np.float32)
257 |             if torch.rand(1)[0] > self.mixup_to_cutmix:  # mixup
258 |                 img = img * lam + img_aug * (1 - lam)
259 |             else:  # cutmix
260 |                 (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(img.shape, lam, correct_lam=True)
261 |                 img[:, yl:yh, xl:xh] = img_aug[:, yl:yh, xl:xh]
262 |             if self.raw:
263 |                 return img, label, label_aug, lam
264 |             label = label * lam + label_aug * (1 - lam)
265 |             label = label.astype(np.float32)  # astype returns a copy, so the result must be reassigned
266 |         if self.raw:
267 |             return img, label, label, 1
268 |         return img, label
269 | 
270 |     def __len__(self):
271 |         return len(self.dataset)
272 | 
273 | 
274 | def mixup_criterion(criterion, pred, y_a, y_b, lam):
275 |     return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
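276 | 
277 | 
278 | # Minimal usage sketch (synthetic numpy images and one-hot labels, for illustration
279 | # only; MixupDataset is assumed here to wrap a dataset yielding numpy arrays):
280 | # with raw=True the dataset returns both labels plus lam, which feeds
281 | # mixup_criterion directly instead of pre-mixing the soft label.
282 | if __name__ == "__main__":
283 |     import torch.nn as nn
284 | 
285 |     base = [(np.random.rand(3, 32, 32).astype(np.float32),
286 |              np.eye(4, dtype=np.float32)[i % 4]) for i in range(16)]
287 |     ds = MixupDataset(base, alpha=0.2, prob=1.0, raw=True)  # prob=1.0 -> always mixed
288 |     img, y_a, y_b, lam = ds[0]  # lam is a 1-element numpy array on this path
289 |     model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 4))
290 |     pred = model(torch.from_numpy(img).unsqueeze(0))
291 |     y_a, y_b = torch.from_numpy(y_a).unsqueeze(0), torch.from_numpy(y_b).unsqueeze(0)
292 |     loss = mixup_criterion(nn.BCEWithLogitsLoss(), pred, y_a, y_b, lam.item())
293 |     print(loss.item())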
--------------------------------------------------------------------------------
/torch_utils/models/layers/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn.parameter import Parameter
6 | from torch.nn import init
7 | 
8 | 
9 | def conv3x3(in_channel, out_channel):  # does not change resolution
10 |     return nn.Conv2d(in_channel, out_channel,
11 |                      kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
12 | 
13 | 
14 | def conv1x1(in_channel, out_channel):  # does not change resolution
15 |     return nn.Conv2d(in_channel, out_channel,
16 |                      kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
17 | 
18 | 
19 | def init_weight(m):
20 |     classname = m.__class__.__name__
21 |     if classname.find('Conv') != -1:
22 |         nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
23 |         if m.bias is not None:
24 |             m.bias.data.zero_()
25 |     elif classname.find('Batch') != -1:
26 |         m.weight.data.normal_(1, 0.02)
27 |         m.bias.data.zero_()
28 |     elif classname.find('Linear') != -1:
29 |         nn.init.orthogonal_(m.weight, gain=1)
30 |         if m.bias is not None:
31 |             m.bias.data.zero_()
32 |     elif classname.find('Embedding') != -1:
33 |         nn.init.orthogonal_(m.weight, gain=1)
34 | 
35 | # Attention
36 | 
37 | 
38 | class CSE(nn.Module):
39 |     def __init__(self, in_ch, r):
40 |         
super(CSE, self).__init__() 41 | self.linear_1 = nn.Linear(in_ch, in_ch // r) 42 | self.linear_2 = nn.Linear(in_ch // r, in_ch) 43 | 44 | def forward(self, x): 45 | input_x = x 46 | x = x.view(*(x.shape[:-2]), -1).mean(-1) 47 | x = F.relu(self.linear_1(x), inplace=True) 48 | x = self.linear_2(x) 49 | x = x.unsqueeze(-1).unsqueeze(-1) 50 | x = torch.sigmoid(x) 51 | x = input_x * x 52 | return x 53 | 54 | 55 | class SSE(nn.Module): 56 | def __init__(self, in_ch): 57 | super(SSE, self).__init__() 58 | self.conv = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1) 59 | 60 | def forward(self, x): 61 | input_x = x 62 | x = self.conv(x) 63 | x = torch.sigmoid(x) 64 | x = input_x * x 65 | return x 66 | 67 | 68 | class SCSE(nn.Module): 69 | def __init__(self, in_ch, r=8): 70 | super(SCSE, self).__init__() 71 | self.cSE = CSE(in_ch, r) 72 | self.sSE = SSE(in_ch) 73 | 74 | def forward(self, x): 75 | cSE = self.cSE(x) 76 | sSE = self.sSE(x) 77 | x = cSE + sSE 78 | return x 79 | 80 | 81 | class SEBlock(nn.Module): 82 | def __init__(self, in_ch, r=8): 83 | super(SEBlock, self).__init__() 84 | 85 | self.linear_1 = nn.Linear(in_ch, in_ch // r) 86 | self.linear_2 = nn.Linear(in_ch // r, in_ch) 87 | 88 | def forward(self, x): 89 | input_x = x 90 | x = F.relu(self.linear_1(x), inplace=True) 91 | x = self.linear_2(x) 92 | x = torch.sigmoid(x) 93 | x = input_x * x 94 | return x 95 | 96 | 97 | class ChannelAttentionModule(nn.Module): 98 | def __init__(self, in_channel, reduction=16): 99 | super().__init__() 100 | self.global_maxpool = nn.AdaptiveMaxPool2d(1) 101 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 102 | self.fc = nn.Sequential( 103 | conv1x1(in_channel, in_channel // reduction).apply(init_weight), 104 | nn.ReLU(True), 105 | conv1x1(in_channel // reduction, in_channel).apply(init_weight) 106 | ) 107 | 108 | def forward(self, inputs): 109 | x1 = self.global_maxpool(inputs) 110 | x2 = self.global_avgpool(inputs) 111 | x1 = self.fc(x1) 112 | x2 = self.fc(x2) 113 | x = torch.sigmoid(x1 + x2) 114 | return x 115 | 116 | 117 | class SpatialAttentionModule(nn.Module): 118 | def __init__(self): 119 | super().__init__() 120 | self.conv3x3 = conv3x3(2, 1).apply(init_weight) 121 | 122 | def forward(self, inputs): 123 | x1, _ = torch.max(inputs, dim=1, keepdim=True) 124 | x2 = torch.mean(inputs, dim=1, keepdim=True) 125 | x = torch.cat([x1, x2], dim=1) 126 | x = self.conv3x3(x) 127 | x = torch.sigmoid(x) 128 | return x 129 | 130 | 131 | class CBAM(nn.Module): 132 | def __init__(self, in_channel, reduction=16): 133 | super().__init__() 134 | self.channel_attention = ChannelAttentionModule(in_channel, reduction) 135 | self.spatial_attention = SpatialAttentionModule() 136 | 137 | def forward(self, inputs): 138 | x = inputs * self.channel_attention(inputs) 139 | x = x * self.spatial_attention(x) 140 | return x 141 | 142 | 143 | class CoordAttention(nn.Module): 144 | '''Coordinate Attention for Efficient Mobile Network Design''' 145 | 146 | def __init__(self, in_channels, out_channels, reduction=16): 147 | super(CoordAttention, self).__init__() 148 | self.pool_w, self.pool_h = nn.AdaptiveAvgPool2d((1, None)), nn.AdaptiveAvgPool2d((None, 1)) 149 | temp_c = max(8, in_channels // reduction) 150 | self.conv1 = nn.Conv2d(in_channels, temp_c, kernel_size=1, stride=1, padding=0) 151 | 152 | self.bn1 = nn.BatchNorm2d(temp_c) 153 | self.act1 = nn.SiLU(True) 154 | 155 | self.conv2 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, padding=0) 156 | self.conv3 = nn.Conv2d(temp_c, out_channels, kernel_size=1, stride=1, 
padding=0) 157 | 158 | def forward(self, x): 159 | short = x 160 | n, c, H, W = x.shape 161 | x_h, x_w = self.pool_h(x), self.pool_w(x).permute(0, 1, 3, 2) 162 | x_cat = torch.cat([x_h, x_w], dim=2) 163 | out = self.act1(self.bn1(self.conv1(x_cat))) 164 | x_h, x_w = torch.split(out, [H, W], dim=2) 165 | x_w = x_w.permute(0, 1, 3, 2) 166 | out_h = torch.sigmoid(self.conv2(x_h)) 167 | out_w = torch.sigmoid(self.conv3(x_w)) 168 | return short * out_w * out_h 169 | 170 | 171 | # TODO: 172 | # add GloRe(GCN attention) 173 | # https://github.com/facebookresearch/GloRe/blob/master/network/global_reasoning_unit.py 174 | # add CCAtention 175 | # https://github.com/speedinghzl/CCNet/blob/master/networks/ccnet.py#L99 176 | 177 | 178 | def gem(x, p=1, eps=1e-6): 179 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p) 180 | 181 | 182 | class GeM(nn.Module): 183 | def __init__(self, p=3, flatten=True, eps=1e-6, requires_grad=True): 184 | super(GeM, self).__init__() 185 | self.p = Parameter(torch.ones(1) * p, requires_grad=requires_grad) 186 | self.eps = eps 187 | if flatten: 188 | self.flatten = nn.Flatten() 189 | else: 190 | self.flatten = False 191 | 192 | def forward(self, x): 193 | x = gem(x, p=self.p, eps=self.eps) 194 | if self.flatten: 195 | return self.flatten(x) 196 | else: 197 | return x 198 | 199 | def __repr__(self): 200 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 201 | 202 | 203 | class GeM_cw(nn.Module): 204 | """ channel-wise GeM Pooling """ 205 | 206 | def __init__(self, num_channel, p=1, flatten=True, eps=1e-6): 207 | super(GeM_cw, self).__init__() 208 | self.p = Parameter(torch.ones(num_channel) * p) 209 | self.eps = eps 210 | if flatten: 211 | self.flatten = nn.Flatten() 212 | else: 213 | self.flatten = False 214 | 215 | def forward(self, x): 216 | p = self.p.unsqueeze(0).unsqueeze(2).unsqueeze(3) 217 | x = gem(x, p=p, eps=self.eps) 218 | if self.flatten: 219 | return self.flatten(x) 220 | else: 221 | return x 222 | 223 | def __repr__(self): 224 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 225 | 226 | 227 | class FastGlobalAvgPool2d(nn.Module): 228 | def __init__(self, flatten=True): # flatten == True : pool + flatten 229 | super(FastGlobalAvgPool2d, self).__init__() 230 | self.flatten = flatten 231 | 232 | def forward(self, x): 233 | if self.flatten: 234 | in_size = x.size() 235 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2) 236 | else: 237 | return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) 238 | 239 | 240 | class FastGlobalConcatPool2d(nn.Module): 241 | def __init__(self, flatten=True): # flatten == True : pool + flatten 242 | super(FastGlobalConcatPool2d, self).__init__() 243 | self.flatten = flatten 244 | 245 | def forward(self, x): 246 | if self.flatten: 247 | in_size = x.size() 248 | x = x.view((in_size[0], in_size[1], -1)) 249 | return torch.cat([x.mean(dim=2), x.max(dim=2).values], 1) 250 | else: 251 | x = x.view(x.size(0), x.size(1), -1) 252 | return torch.cat([x.mean(-1), x.max(-1).values], 1).view(x.size(0), 2 * x.size(1), 1, 1) 253 | 254 | 255 | class MultiSampleDropoutFC(nn.Module): 256 | def __init__(self, in_ch, out_ch, num_sample=4, dropout=0.5): 257 | super(MultiSampleDropoutFC, self).__init__() 258 | self.dropouts = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_sample)]) 259 | self.fc = nn.Linear(in_ch, out_ch, bias=True) 
260 | 261 | def forward(self, x): 262 | for i, dropout in enumerate(self.dropouts): 263 | if i == 0: 264 | out = self.fc(dropout(x)) 265 | else: 266 | out += self.fc(dropout(x)) 267 | out /= len(self.dropouts) 268 | return out 269 | 270 | ###### activation ###### 271 | 272 | 273 | class Mish(nn.Module): 274 | def __init__(self): 275 | super().__init__() 276 | 277 | def forward(self, x): 278 | return x * (torch.tanh(F.softplus(x))) 279 | 280 | 281 | class Swish(nn.Module): 282 | def __init__(self, inplace=False): 283 | super().__init__() 284 | self.inplace = inplace 285 | 286 | def forward(self, x): 287 | if self.inplace: 288 | x.mul_(torch.sigmoid(x)) 289 | return x 290 | else: 291 | return x * torch.sigmoid(x) 292 | 293 | 294 | class FReLU(nn.Module): 295 | """ 296 | FReLU formulation. The funnel condition has a window size of kxk. (k=3 by default) 297 | """ 298 | 299 | def __init__(self, in_channels): 300 | super().__init__() 301 | self.conv_frelu = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=in_channels) 302 | self.bn_frelu = nn.BatchNorm2d(in_channels) 303 | 304 | def forward(self, x): 305 | y = self.conv_frelu(x) 306 | y = self.bn_frelu(y) 307 | x = torch.max(x, y) 308 | return x 309 | 310 | 311 | ###### example ###### 312 | 313 | def get_simple_fc(in_ch, num_classes, flatten=False): 314 | if flatten: 315 | return nn.Sequential( 316 | nn.Flatten(), 317 | nn.Linear(in_ch, 512), 318 | Swish(inplace=True), 319 | nn.Dropout(), 320 | nn.Linear(512, num_classes), 321 | ) 322 | else: 323 | return nn.Sequential( 324 | nn.Linear(in_ch, 512), 325 | Swish(), 326 | nn.Dropout(), 327 | nn.Linear(512, num_classes), 328 | ) 329 | 330 | 331 | def get_attention_fc(in_ch, num_classes, flatten=False): 332 | if flatten: 333 | return nn.Sequential( 334 | nn.Flatten(), 335 | SEBlock(in_ch), 336 | MultiSampleDropoutFC(in_ch, num_classes), 337 | ) 338 | else: 339 | return nn.Sequential( 340 | SEBlock(in_ch), 341 | MultiSampleDropoutFC(in_ch, num_classes), 342 | ) 343 | 344 | 345 | ######################### 346 | # DepthToSpace == pixel shuffle 347 | # Official: torch.nn.PixelShuffle(upscale_factor) 348 | # SpaceToDepth == inverted pixel shuffle 349 | # Official: torch.nn.PixelUnshuffle(downscale_factor) 350 | 351 | def pixelshuffle(x, factor_hw): 352 | pH = factor_hw[0] 353 | pW = factor_hw[1] 354 | y = x 355 | B, iC, iH, iW = y.shape 356 | oC, oH, oW = iC // (pH * pW), iH * pH, iW * pW 357 | y = y.reshape(B, oC, pH, pW, iH, iW) 358 | y = y.permute(0, 1, 4, 2, 5, 3) # B, oC, iH, pH, iW, pW 359 | y = y.reshape(B, oC, oH, oW) 360 | return y 361 | 362 | 363 | def pixelshuffle_invert(x, factor_hw): 364 | pH = factor_hw[0] 365 | pW = factor_hw[1] 366 | y = x 367 | B, iC, iH, iW = y.shape 368 | oC, oH, oW = iC * (pH * pW), iH // pH, iW // pW 369 | y = y.reshape(B, iC, oH, pH, oW, pW) 370 | y = y.permute(0, 1, 3, 5, 2, 4) # B, iC, pH, pW, oH, oW 371 | y = y.reshape(B, oC, oH, oW) 372 | return y 373 | 374 | 375 | class SpaceToDepth(nn.Module): 376 | def __init__(self, block_size=4): 377 | super().__init__() 378 | assert block_size == 4 379 | self.bs = block_size 380 | 381 | def forward(self, x): 382 | N, C, H, W = x.size() 383 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 384 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 385 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 386 | return x 387 | 388 | 389 | @torch.jit.script 390 | class SpaceToDepthJit(object): 
391 | def __call__(self, x): 392 | # assuming hard-coded that block_size==4 for acceleration 393 | N, C, H, W = x.size() 394 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 395 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 396 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) 397 | return x 398 | 399 | 400 | class SpaceToDepthModule(nn.Module): 401 | def __init__(self, remove_model_jit=False): 402 | super().__init__() 403 | if not remove_model_jit: 404 | self.op = SpaceToDepthJit() 405 | else: 406 | self.op = SpaceToDepth() 407 | 408 | def forward(self, x): 409 | return self.op(x) 410 | 411 | 412 | class DepthToSpace(nn.Module): 413 | 414 | def __init__(self, block_size): 415 | super().__init__() 416 | self.bs = block_size 417 | 418 | def forward(self, x): 419 | N, C, H, W = x.size() 420 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 421 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 422 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 423 | return x 424 | 425 | 426 | # ASPP 427 | 428 | class _ASPPModule(nn.Module): 429 | def __init__(self, inplanes, planes, kernel_size, padding, dilation): 430 | super(_ASPPModule, self).__init__() 431 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 432 | stride=1, padding=padding, dilation=dilation, bias=False) 433 | self.bn = nn.BatchNorm2d(planes) 434 | self.relu = nn.ReLU() 435 | 436 | def forward(self, x): 437 | x = self.atrous_conv(x) 438 | x = self.bn(x) 439 | 440 | return self.relu(x) 441 | 442 | 443 | class ASPP(nn.Module): 444 | def __init__(self, inplanes=512, mid_c=256, dilations=[1, 6, 12, 18]): 445 | super(ASPP, self).__init__() 446 | self.aspp1 = _ASPPModule(inplanes, mid_c, 1, padding=0, dilation=dilations[0]) 447 | self.aspp2 = _ASPPModule(inplanes, mid_c, 3, padding=dilations[1], dilation=dilations[1]) 448 | self.aspp3 = _ASPPModule(inplanes, mid_c, 3, padding=dilations[2], dilation=dilations[2]) 449 | self.aspp4 = _ASPPModule(inplanes, mid_c, 3, padding=dilations[3], dilation=dilations[3]) 450 | 451 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 452 | nn.Conv2d(inplanes, mid_c, 1, stride=1, bias=False), 453 | nn.BatchNorm2d(mid_c), 454 | nn.ReLU()) 455 | self.conv1 = nn.Conv2d(mid_c * 5, mid_c, 1, bias=False) 456 | 457 | def forward(self, x): 458 | x1 = self.aspp1(x) 459 | x2 = self.aspp2(x) 460 | x3 = self.aspp3(x) 461 | x4 = self.aspp4(x) 462 | x5 = self.global_avg_pool(x) 463 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 464 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 465 | 466 | x = self.conv1(x) 467 | 468 | return x 469 | --------------------------------------------------------------------------------