├── README.md
├── datasets.py
├── datasets
│   ├── test.txt
│   └── train.txt
├── demo
│   ├── cub_baseline.yml
│   └── cub_s3n.yml
├── fgvc_datasets.py
├── ft_resnet.py
├── hooks.py
├── illustration.png
├── losses.py
├── meters.py
├── optimizers.py
├── sss_net.py
├── trainer.py
├── utility.py
└── utils.py
/README.md: -------------------------------------------------------------------------------- 1 | 

Selective Sparse Sampling for Fine-grained Image Recognition

2 | 3 | ![Illustration](illustration.png) 4 | 5 | ## PyTorch Implementation 6 | This repository contains: 7 | 8 | * the **PyTorch** implementation of Selective Sparse Sampling. 9 | * the CUB-200-2011 demo (training, test). 10 | 11 | Please follow the instructions below to install it and run the experiment demo. 12 | 13 | ### Prerequisites 14 | * System (tested on Ubuntu 14.04LTS and Win10) 15 | * 2× Tesla P100 GPUs + CUDA/cuDNN (CPU mode is also supported but significantly slower) 16 | * [Python=3.6.8](https://www.python.org) 17 | * [PyTorch=0.4.1](https://pytorch.org) 18 | * [Jupyter Notebook](https://jupyter.org/install.html) 19 | * [Nest](https://github.com/ZhouYanzhao/Nest.git) 20 | 21 | ### Installation 22 | 23 | 1. Install S3N via Nest's CLI tool: 24 | 25 | ```bash 26 | # note that data will be saved under your current path 27 | $ git clone https://github.com/Yao-DD/S3N.git ./S3N 28 | $ nest module install ./S3N/ s3n 29 | # verify the installation 30 | $ nest module list --filter s3n 31 | ``` 32 | 33 | ### Prepare Data 34 | 35 | 1. Download the CUB-200-2011 dataset: 36 | 37 | ```bash 38 | $ mkdir ./S3N/datasets 39 | $ cd ./S3N/datasets 40 | # download and extract data 41 | $ wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz 42 | $ tar xvf CUB_200_2011.tgz 43 | ``` 44 | 45 | 2. Prepare annotation files: 46 | 47 | Move the files ./datasets/train.txt and ./datasets/test.txt into ./datasets/CUB_200_2011. The image file names and labels are listed in ./datasets/CUB_200_2011/train.txt and ./datasets/CUB_200_2011/test.txt, with one line per image in the form: 48 | 49 | ``` 50 | <image_file_name> <label> 51 | ``` 52 | 53 | ### Run the demo 54 | 55 | 1. Run the code as: 56 | 57 | ```bash 58 | $ cd ./S3N 59 | # run baseline 60 | $ PYTHONWARNINGS='ignore' CUDA_VISIBLE_DEVICES=0,1 nest task run ./demo/cub_baseline.yml 61 | # run S3N 62 | $ PYTHONWARNINGS='ignore' CUDA_VISIBLE_DEVICES=0,1 nest task run ./demo/cub_s3n.yml 63 | ``` 64 | 65 | ### Pretrained models 66 | 67 | 1. The S3N model for the CUB_200_2011 dataset is available on Baidu Disk. 68 | 69 | ``` 70 | link: https://pan.baidu.com/s/19x9zI_ZNi32sRGRgNwN_Fw 71 | code: r252 72 | ``` 73 | 74 | ## CAUTION 75 | The current code was prepared under the above-mentioned prerequisites. Using other versions of these dependencies may cause problems. 76 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from typing import Tuple, List, Dict, Union, Callable, Optional 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import datasets, transforms 9 | from PIL import Image 10 | from nest import register 11 | 12 | 13 | @register 14 | def image_transform( 15 | image_size: Union[int, List[int]], 16 | augmentation: dict, 17 | mean: List[float] = [0.485, 0.456, 0.406], 18 | std: List[float] = [0.229, 0.224, 0.225]) -> Callable: 19 | """Image transforms. 
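    Builds a torchvision transform pipeline from the given options.

    Arguments:
        image_size: target size, either an int (square) or [height, width].
        augmentation: optional keys 'horizontal_flip' / 'vertical_flip' (flip probabilities),
            'random_crop' (kwargs for RandomResizedCrop; a plain resize is used when absent)
            and 'center_crop'; any other key raises NotImplementedError.
        mean / std: per-channel normalization statistics (ImageNet defaults).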
20 | """ 21 | 22 | if isinstance(image_size, int): 23 | image_size = (image_size, image_size) 24 | else: 25 | image_size = tuple(image_size) 26 | 27 | horizontal_flip = augmentation.pop('horizontal_flip', None) 28 | if horizontal_flip is not None: 29 | assert isinstance(horizontal_flip, float) and 0 <= horizontal_flip <= 1 30 | 31 | vertical_flip = augmentation.pop('vertical_flip', None) 32 | if vertical_flip is not None: 33 | assert isinstance(vertical_flip, float) and 0 <= vertical_flip <= 1 34 | 35 | random_crop = augmentation.pop('random_crop', None) 36 | if random_crop is not None: 37 | assert isinstance(random_crop, dict) 38 | 39 | center_crop = augmentation.pop('center_crop', None) 40 | if center_crop is not None: 41 | assert isinstance(center_crop, (int, list)) 42 | 43 | if len(augmentation) > 0: 44 | raise NotImplementedError('Invalid augmentation options: %s.' % ', '.join(augmentation.keys())) 45 | 46 | t = [ 47 | transforms.Resize(image_size) if random_crop is None else transforms.RandomResizedCrop(image_size[0], **random_crop), 48 | transforms.CenterCrop(center_crop) if center_crop is not None else None, 49 | transforms.RandomHorizontalFlip(horizontal_flip) if horizontal_flip is not None else None, 50 | transforms.RandomVerticalFlip(vertical_flip) if vertical_flip is not None else None, 51 | transforms.ToTensor(), 52 | transforms.Normalize(mean, std)] 53 | 54 | return transforms.Compose([v for v in t if v is not None]) 55 | 56 | 57 | @register 58 | def fetch_data( 59 | dataset: Callable[[str], Dataset], 60 | transform: Optional[Callable] = None, 61 | target_transform: Optional[Callable] = None, 62 | num_workers: int = 0, 63 | pin_memory: bool = True, 64 | drop_last: bool = False, 65 | train_splits: List[str] = [], 66 | test_splits: List[str] = [], 67 | train_shuffle: bool = True, 68 | test_shuffle: bool = False, 69 | test_image_size: int = 600, 70 | train_augmentation: dict = {}, 71 | test_augmentation: dict = {}, 72 | batch_size: int = 1, 73 | test_batch_size: Optional[int] = None) -> Tuple[List[Tuple[str, DataLoader]], List[Tuple[str, DataLoader]]]: 74 | """Fetch data. 
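    Builds one DataLoader per entry of 'train_splits' and 'test_splits'. Training loaders use
    the transform built from 'train_augmentation'; test loaders use 'test_image_size' together
    with 'test_augmentation'. 'test_batch_size' falls back to 'batch_size' when it is None.
    Returns (train_loader_list, test_loader_list) as lists of (split, loader) pairs.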
75 | """ 76 | 77 | train_transform = transform(augmentation=train_augmentation) if transform else None 78 | train_loader_list = [] 79 | for split in train_splits: 80 | train_loader_list.append((split, DataLoader( 81 | dataset = dataset( 82 | split = split, 83 | transform = train_transform, 84 | target_transform = target_transform), 85 | batch_size = batch_size, 86 | num_workers = num_workers, 87 | pin_memory = pin_memory, 88 | drop_last=drop_last, 89 | shuffle = train_shuffle))) 90 | 91 | test_transform = transform(image_size=[test_image_size, test_image_size], augmentation=test_augmentation) if transform else None 92 | test_loader_list = [] 93 | for split in test_splits: 94 | test_loader_list.append((split, DataLoader( 95 | dataset = dataset( 96 | split = split, 97 | transform = test_transform, 98 | target_transform = target_transform), 99 | batch_size = batch_size if test_batch_size is None else test_batch_size, 100 | num_workers = num_workers, 101 | pin_memory = pin_memory, 102 | drop_last=drop_last, 103 | shuffle = test_shuffle))) 104 | 105 | return train_loader_list, test_loader_list -------------------------------------------------------------------------------- /demo/cub_baseline.yml: -------------------------------------------------------------------------------- 1 | _name: network_trainer 2 | data_loaders: 3 | _name: s3n.fetch_data 4 | dataset: 5 | _name: s3n.fgvc_dataset 6 | data_dir: ./datasets/CUB_200_2011 7 | batch_size: 16 8 | num_workers: 4 9 | transform: 10 | _name: image_transform 11 | image_size: [448, 448] 12 | mean: [0.485, 0.456, 0.406] 13 | std: [0.229, 0.224, 0.225] 14 | train_augmentation: 15 | horizontal_flip: 0.5 16 | random_crop: 17 | scale: [0.5, 1] 18 | test_augmentation: 19 | center_crop: 448 20 | train_splits: 21 | - train 22 | test_splits: 23 | - test 24 | log_path: './logs/CUB_baseline.log' 25 | model: 26 | _name: s3n.ft_resnet 27 | mode: 'resnet50' 28 | num_classes: 200 29 | criterion: 30 | _name: s3n.smooth_loss 31 | smooth_ratio: 0.85 32 | optimizer: 33 | _name: s3n.sgd_optimizer 34 | lr: 0.01 35 | momentum: 0.9 36 | weight_decay: 1.0e-4 37 | parameter: 38 | _name: finetune 39 | base_lr: 0.001 40 | groups: 41 | 'classifier': 10.0 42 | meters: 43 | top1: 44 | _name: s3n.topk_meter 45 | k: 1 46 | loss: 47 | _name: loss_meter 48 | max_epoch: 30 49 | device: cuda 50 | hooks: 51 | on_start_epoch: 52 | - 53 | _name: update_lr 54 | epoch_list: [20,] 55 | on_end_epoch: 56 | - 57 | _name: print_state 58 | formats: 59 | - 'epoch: {epoch_idx}' 60 | - 'train_loss: {metrics[train_loss]:.4f}' 61 | - 'test_loss: {metrics[test_loss]:.4f}' 62 | - 'train_top1: {metrics[train_top1]:.2f}%' 63 | - 'test_top1: {metrics[test_top1]:.2f}%' 64 | - 65 | _name: checkpoint 66 | save_dir: './generate/baseline/' 67 | save_step: 1 -------------------------------------------------------------------------------- /demo/cub_s3n.yml: -------------------------------------------------------------------------------- 1 | _name: network_trainer 2 | data_loaders: 3 | _name: s3n.fetch_data 4 | dataset: 5 | _name: s3n.fgvc_dataset 6 | data_dir: ./datasets/CUB_200_2011 7 | batch_size: 16 8 | num_workers: 4 9 | transform: 10 | _name: s3n.image_transform 11 | image_size: [448, 448] 12 | mean: [0.485, 0.456, 0.406] 13 | std: [0.229, 0.224, 0.225] 14 | train_augmentation: 15 | horizontal_flip: 0.5 16 | random_crop: 17 | scale: [0.5, 1] 18 | test_augmentation: 19 | center_crop: 448 20 | test_image_size: 600 21 | train_splits: 22 | - train 23 | test_splits: 24 | - test 25 | log_path: './logs/cub_s3n.log' 
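# s3n.s3n builds the S3N network defined in sss_net.py; 'radius' and 'radius_inv' are the
# initial values of the learnable ScaleLayer factors applied to the peak scores when
# generating the zoom and inverse (complementary) sampling maps.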
26 | model: 27 | _name: s3n.s3n 28 | mode: 'resnet50' 29 | num_classes: 200 30 | radius: 0.09 31 | radius_inv: 0.3 32 | criterion: 33 | _name: s3n.multi_smooth_loss 34 | smooth_ratio: 0.85 35 | optimizer: 36 | _name: s3n.sgd_optimizer 37 | lr: 0.01 38 | momentum: 0.9 39 | weight_decay: 1.0e-4 40 | parameter: 41 | _name: s3n.finetune 42 | base_lr: 0.001 43 | groups: 44 | 'classifier': 10.0 45 | 'radius': 0.0001 46 | 'filter': 0.0001 47 | meters: 48 | top1: 49 | _name: s3n.multi_topk_meter 50 | k: 1 51 | init_num: 0 52 | loss: 53 | _name: s3n.loss_meter 54 | max_epoch: 60 55 | device: cuda 56 | use_data_parallel: yes 57 | hooks: 58 | on_start_epoch: 59 | - 60 | _name: s3n.update_lr 61 | epoch_step: 40 62 | on_start_forward: 63 | - 64 | _name: s3n.three_stage 65 | on_end_epoch: 66 | - 67 | _name: s3n.print_state 68 | formats: 69 | - 'epoch: {epoch_idx}' 70 | - 'train_loss: {metrics[train_loss]:.4f}' 71 | - 'test_loss: {metrics[test_loss]:.4f}' 72 | - 'train_branch1_top1: {metrics[train_top1][branch_0]:.2f}%' 73 | - 'train_branch2_top1: {metrics[train_top1][branch_1]:.2f}%' 74 | - 'train_branch3_top1: {metrics[train_top1][branch_2]:.2f}%' 75 | - 'train_branch4_top1: {metrics[train_top1][branch_3]:.2f}%' 76 | - 'test_branch1_top1: {metrics[test_top1][branch_0]:.2f}%' 77 | - 'test_branch2_top1: {metrics[test_top1][branch_1]:.2f}%' 78 | - 'test_branch3_top1: {metrics[test_top1][branch_2]:.2f}%' 79 | - 'test_branch4_top1: {metrics[test_top1][branch_3]:.2f}%' 80 | - 81 | _name: s3n.checkpoint 82 | save_dir: './generate/s3n/' 83 | save_step: 1 -------------------------------------------------------------------------------- /fgvc_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from typing import List, Dict, Tuple, Callable, Optional, Union 4 | 5 | import torch 6 | import numpy as np 7 | from torch.utils.data import Dataset,DataLoader 8 | from torchvision import datasets, transforms 9 | from PIL import Image 10 | from nest import register 11 | 12 | 13 | class FGVC_Dataset(Dataset): 14 | 15 | def __init__(self, data_dir, split, lable_path=None, transform=None, target_transform=None): 16 | self.data_dir = data_dir 17 | self.split = split 18 | self.lable_path = lable_path 19 | self.transform = transform 20 | self.target_transform = target_transform 21 | 22 | self.image_dir = os.path.join(self.data_dir, 'images') 23 | self.image_lables = self._read_annotation(self.split) 24 | 25 | def _read_annotation(self, split): 26 | class_lables = OrderedDict() 27 | if self.lable_path is None: 28 | lable_path = os.path.join(self.data_dir, split + '.txt') 29 | else: 30 | lable_path = os.path.join(self.data_dir, self.lable_path, split + '.txt') 31 | if os.path.exists(lable_path): 32 | with open(lable_path, 'r') as f: 33 | for line in f: 34 | name, lable = line.split(' ') 35 | class_lables[name] = int(lable) 36 | else: 37 | raise NotImplementedError( 38 | 'Invalid path for dataset') 39 | 40 | return list(class_lables.items()) 41 | 42 | def __getitem__(self, index): 43 | filename, target = self.image_lables[index] 44 | img = Image.open(os.path.join(self.image_dir, filename)).convert('RGB') 45 | 46 | if self.transform: 47 | img = self.transform(img) 48 | if self.target_transform: 49 | target = self.target_transform(target) 50 | 51 | return img, target 52 | 53 | def __len__(self): 54 | return len(self.image_lables) 55 | 56 | 57 | @register 58 | def fgvc_dataset( 59 | split: str, 60 | data_dir: str, 61 | label_path: 
Optional[str] = None, 62 | transform: Optional[Callable] = None, 63 | target_transform: Optional[Callable] = None) -> Dataset: 64 | '''Fine-grained visual classification datasets. 65 | ''' 66 | return FGVC_Dataset(data_dir, split, label_path, transform, target_transform) -------------------------------------------------------------------------------- /ft_resnet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision import models 3 | from nest import register 4 | 5 | 6 | @register 7 | def ft_resnet(mode: str = 'resnet50', fc_or_fcn: str = 'fc', num_classes: int = 10, pretrained: bool = True) -> nn.Module: 8 | """Finetune resnet. 9 | """ 10 | 11 | class FT_Resnet(nn.Module): 12 | def __init__(self, mode='resnet50', fc_or_fcn='fc', num_classes=10, pretrained=True): 13 | super(FT_Resnet, self).__init__() 14 | 15 | if mode=='resnet50': 16 | model = models.resnet50(pretrained=pretrained) 17 | elif mode=='resnet101': 18 | model = models.resnet101(pretrained=pretrained) 19 | elif mode=='resnet152': 20 | model = models.resnet152(pretrained=pretrained) 21 | else: 22 | model = models.resnet18(pretrained=pretrained) 23 | 24 | self.features = nn.Sequential( 25 | model.conv1, 26 | model.bn1, 27 | model.relu, 28 | model.maxpool, 29 | model.layer1, 30 | model.layer2, 31 | model.layer3, 32 | model.layer4 33 | ) 34 | self.num_classes = num_classes 35 | self.num_features = model.layer4[1].conv1.in_channels 36 | self.fc_or_fcn = fc_or_fcn 37 | if self.fc_or_fcn=='fc': 38 | self.classifier = nn.Linear(self.num_features, num_classes) 39 | else: 40 | self.classifier = nn.Conv2d(self.num_features, self.num_classes, 1, 1) 41 | self.avg = nn.AdaptiveAvgPool2d(1) 42 | 43 | def forward(self, x): 44 | x = self.features(x) 45 | if self.fc_or_fcn=='fc': 46 | x = self.avg(x).view(-1, self.num_features) 47 | x = self.classifier(x) 48 | else: 49 | x = self.classifier(x) 50 | x = self.avg(x).view(-1, self.num_classes) 51 | return x 52 | 53 | return FT_Resnet(mode, fc_or_fcn, num_classes, pretrained) 54 | -------------------------------------------------------------------------------- /hooks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | from time import localtime, strftime 4 | from typing import List, Callable, Optional 5 | 6 | import torch 7 | import numpy as np 8 | from visdom import Visdom 9 | from nest import register, Context 10 | 11 | 12 | @register 13 | def checkpoint( 14 | train_ctx: Context, 15 | save_dir: str, 16 | save_step: Optional[int] = None, 17 | save_final: bool = False, 18 | save_latest: bool = False, 19 | save_all: bool = False) -> None: 20 | """Checkpoint. 
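    Saves the model and optimizer state under 'save_dir':
        save_step:   every 'save_step' epochs, as model_<epoch_idx>.pt;
        save_final:  at the last epoch, as model_final.pt;
        save_latest: on every call, as model_latest.pt;
        save_all:    a timestamped copy on every call.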
21 | """ 22 | 23 | save_dir = os.path.abspath(save_dir) 24 | try: 25 | os.makedirs(save_dir) 26 | except OSError as exception: 27 | if exception.errno != errno.EEXIST: 28 | raise 29 | 30 | def save_current_train_ctx(save_name): 31 | save_path = os.path.join(save_dir, save_name) 32 | torch.save(dict( 33 | epoch_idx = train_ctx.epoch_idx + 1, 34 | batch_idx = train_ctx.batch_idx + 1, 35 | model = train_ctx.model.state_dict(), 36 | optimizer = train_ctx.optimizer.state_dict()), save_path) 37 | train_ctx.logger.info('checkpoint created at %s' % save_path) 38 | 39 | if save_all: 40 | save_current_train_ctx(strftime("model_%Y_%m_%d_%H.%M.%S.pt", localtime())) 41 | if save_step is not None and (train_ctx.epoch_idx + 1) % save_step == 0: 42 | save_current_train_ctx('model_%d.pt' % train_ctx.epoch_idx) 43 | if save_final and (train_ctx.epoch_idx + 1) == train_ctx.max_epoch: 44 | save_current_train_ctx('model_final.pt') 45 | if save_latest: 46 | save_current_train_ctx('model_latest.pt') 47 | 48 | 49 | @register 50 | def vis_trend(ctx: Context, train_ctx: Context, server: str, env: str, port: int = 80) -> None: 51 | """Track trend with Visdom. 52 | """ 53 | 54 | if not 'vis' in ctx: 55 | ctx.vis = Visdom(server=server, port=port, env=env) 56 | 57 | try: 58 | for k, v in train_ctx.metrics.items(): 59 | if isinstance(v, (int, float)): 60 | if ctx.vis.win_exists(k): 61 | ctx.vis.line( 62 | X = np.array([train_ctx.epoch_idx]), 63 | Y = np.array([v]), 64 | opts = dict(title=k, xlabel='epoch'), 65 | win = k, 66 | update = 'append') 67 | else: 68 | ctx.vis.line( 69 | X = np.array([train_ctx.epoch_idx]), 70 | Y = np.array([v]), 71 | opts = dict(title=k, xlabel='epoch'), 72 | win = k) 73 | ctx.vis.save([env]) 74 | except ConnectionError: 75 | train_ctx.logger.warning('Could not connect to visdom server "%s".' % server) 76 | 77 | 78 | @register 79 | def print_state(train_ctx: Context, formats: List[str], join_str: str = ' | ') -> None: 80 | """Print state. 81 | """ 82 | 83 | def unescape(escapped_str): 84 | return bytes(escapped_str, "utf-8").decode("unicode_escape") 85 | 86 | def safe_format(format_str, **kwargs): 87 | try: 88 | return format_str.format(**kwargs) 89 | except: 90 | return None 91 | 92 | format_list = [safe_format(unescape(format_str), **vars(train_ctx)) for format_str in formats] 93 | output_str = unescape(join_str).join([val for val in format_list if val is not None]) 94 | train_ctx.logger.info(output_str) 95 | 96 | 97 | @register 98 | def interval( 99 | train_ctx: Context, 100 | hook: Callable[[Context], None], 101 | epoch_interval: int = 1, 102 | batch_interval: int = 1) -> None: 103 | """Skip interval. 104 | """ 105 | 106 | if train_ctx.epoch_idx % epoch_interval == 0 and train_ctx.batch_idx % batch_interval == 0: 107 | hook(train_ctx) 108 | 109 | 110 | @register 111 | def update_lr( 112 | train_ctx: Context, 113 | epoch_step: Optional[int] = None, 114 | epoch_list: Optional[List[int]] = None, 115 | factor: float = 0.1) -> None: 116 | """Update learning rate. 
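    Multiplies the learning rate of every optimizer parameter group by 'factor' whenever the
    current epoch (epoch_idx + 1) is a multiple of 'epoch_step' or appears in 'epoch_list'.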
117 | """ 118 | 119 | current_epoch = train_ctx.epoch_idx + 1 120 | if ((epoch_step is not None) and (current_epoch % epoch_step == 0)) or \ 121 | ((epoch_list is not None) and (current_epoch in epoch_list)): 122 | for idx, param in enumerate(train_ctx.optimizer.param_groups): 123 | param['lr'] = param['lr'] * factor 124 | print('LR of param group %d is updated to %e' % (idx, param['lr'])) 125 | -------------------------------------------------------------------------------- /illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yao-DD/S3N/c77729b19be2c0d8581f0c522d856a3fc7fdfa21/illustration.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Tuple, Dict, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from nest import register 7 | 8 | 9 | @register 10 | def cross_entropy_loss( 11 | input: Tensor, 12 | target: Tensor, 13 | weight: Optional[Tensor] = None, 14 | size_average: bool = True, 15 | ignore_index: int = -100, 16 | reduce: bool = True) -> Tensor: 17 | """Cross entropy loss. 18 | """ 19 | 20 | return F.cross_entropy(input, target, weight, size_average, ignore_index, reduce) 21 | 22 | 23 | @register 24 | def smooth_loss( 25 | input: Tensor, 26 | target: Tensor, 27 | smooth_ratio: float = 0.9, 28 | weight: Union[None, Tensor] = None, 29 | size_average: bool = True, 30 | ignore_index: int = -100, 31 | reduce: bool = True) -> Tensor: 32 | '''Smooth loss. 33 | ''' 34 | 35 | prob = F.log_softmax(input, dim=1) 36 | ymask = prob.data.new(prob.size()).zero_() 37 | ymask = ymask.scatter_(1, target.view(-1,1), 1) 38 | ymask = smooth_ratio*ymask + (1-smooth_ratio)*(1-ymask)/(len(input[1])-1) 39 | loss = - (prob*ymask).sum(1).mean() 40 | 41 | return loss 42 | 43 | 44 | @register 45 | def multi_smooth_loss( 46 | input: Tuple, 47 | target: Tensor, 48 | smooth_ratio: float = 0.9, 49 | loss_weight: Union[None, Dict]= None, 50 | weight: Union[None, Tensor] = None, 51 | size_average: bool = True, 52 | ignore_index: int = -100, 53 | reduce: bool = True) -> Tensor: 54 | '''Multi smooth loss. 
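    Expects 'input' to be a tuple of branch predictions. The branches at index 1 and at the
    last index are scored with the label-smoothed objective (controlled by 'smooth_ratio');
    all remaining branches use plain cross entropy. 'loss_weight' optionally maps a branch
    index to a scalar weight; unlisted branches default to 1.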
55 | ''' 56 | assert isinstance(input, tuple), 'input is less than 2' 57 | 58 | weight_loss = torch.ones(len(input)).to(input[0].device) 59 | if loss_weight is not None: 60 | for item in loss_weight.items(): 61 | weight_loss[int(item[0])] = item[1] 62 | 63 | loss = 0 64 | for i in range(0, len(input)): 65 | if i in [1, len(input)-1]: 66 | prob = F.log_softmax(input[i], dim=1) 67 | ymask = prob.data.new(prob.size()).zero_() 68 | ymask = ymask.scatter_(1, target.view(-1,1), 1) 69 | ymask = smooth_ratio*ymask + (1-smooth_ratio)*(1-ymask)/(input[i].shape[1]-1) 70 | loss_tmp = - weight_loss[i]*((prob*ymask).sum(1).mean()) 71 | else: 72 | loss_tmp = weight_loss[i]*F.cross_entropy(input[i], target, weight, size_average, ignore_index, reduce) 73 | loss += loss_tmp 74 | 75 | return loss 76 | -------------------------------------------------------------------------------- /meters.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from nest import register, Context 5 | 6 | 7 | class AverageMeter(object): 8 | 9 | def __init__(self): 10 | self.reset() 11 | 12 | def reset(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | 24 | 25 | @register 26 | def loss_meter(ctx: Context, train_ctx: Context) -> float: 27 | """Loss meter. 28 | """ 29 | 30 | if not 'meter' in ctx: 31 | ctx.meter = AverageMeter() 32 | 33 | if train_ctx.batch_idx == 0: 34 | ctx.meter.reset() 35 | ctx.meter.update(train_ctx.loss.item(), train_ctx.target.size(0)) 36 | return ctx.meter.avg 37 | 38 | 39 | @register 40 | def topk_meter(ctx: Context, train_ctx: Context, k: int = 1) -> float: 41 | """Topk meter. 42 | """ 43 | 44 | def accuracy(output, target, k=1): 45 | batch_size = target.size(0) 46 | 47 | _, pred = output.topk(k, 1, True, True) 48 | pred = pred.t() 49 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 50 | 51 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 52 | return correct_k.mul_(100.0 / batch_size) 53 | 54 | if not 'meter' in ctx: 55 | ctx.meter = AverageMeter() 56 | 57 | if train_ctx.batch_idx == 0: 58 | ctx.meter.reset() 59 | acc = accuracy(train_ctx.output, train_ctx.target, k) 60 | ctx.meter.update(acc.item()) 61 | return ctx.meter.avg 62 | -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | from torch import optim 4 | from nest import register 5 | 6 | 7 | @register 8 | def sgd_optimizer( 9 | parameters: Iterable, 10 | lr: float, 11 | momentum: float = 0.0, 12 | dampening: float = 0.0, 13 | weight_decay: float = 0.0, 14 | nesterov: bool = False) -> optim.Optimizer: 15 | """SGD optimizer. 16 | """ 17 | 18 | return optim.SGD(parameters, lr, momentum, dampening, weight_decay, nesterov) 19 | 20 | 21 | @register 22 | def adadelta_optimizer( 23 | parameters: Iterable, 24 | lr: float = 1.0, 25 | rho: float = 0.9, 26 | eps: float = 1e-6, 27 | weight_decay: float = 0.0) -> optim.Optimizer: 28 | """Adadelta optimizer. 
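    Thin wrapper around torch.optim.Adadelta; the arguments are passed through unchanged.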
29 | """ 30 | 31 | return optim.Adadelta(parameters, lr, rho, eps, weight_decay) 32 | -------------------------------------------------------------------------------- /sss_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import random 4 | from collections import OrderedDict 5 | from typing import List, Dict, Tuple, Callable, Optional, Union 6 | 7 | import torch 8 | import numpy as np 9 | from PIL import Image 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torchvision import models 13 | from torch.autograd import Function 14 | from nest import register, modules, Context 15 | 16 | 17 | def makeGaussian(size, fwhm = 3, center=None): 18 | 19 | x = np.arange(0, size, 1, float) 20 | y = x[:,np.newaxis] 21 | 22 | if center is None: 23 | x0 = y0 = size // 2 24 | else: 25 | x0 = center[0] 26 | y0 = center[1] 27 | 28 | return np.exp(-4*np.log(2) * ((x-x0)**2 + (y-y0)**2) / fwhm**2) 29 | 30 | 31 | class KernelGenerator(nn.Module): 32 | def __init__(self, size, offset=None): 33 | super(KernelGenerator, self).__init__() 34 | 35 | self.size = self._pair(size) 36 | xx, yy = np.meshgrid(np.arange(0, size), np.arange(0, size)) 37 | if offset is None: 38 | offset_x = offset_y = size // 2 39 | else: 40 | offset_x, offset_y = self._pair(offset) 41 | self.factor = torch.from_numpy(-(np.power(xx - offset_x, 2) + np.power(yy - offset_y, 2)) / 2).float() 42 | 43 | @staticmethod 44 | def _pair(x): 45 | return (x, x) if isinstance(x, int) else x 46 | 47 | def forward(self, theta): 48 | pow2 = torch.pow(theta * self.size[0], 2) 49 | kernel = 1.0 / (2 * np.pi * pow2) * torch.exp(self.factor.to(theta.device) / pow2) 50 | return kernel / kernel.max() 51 | 52 | 53 | def kernel_generate(theta, size, offset=None): 54 | return KernelGenerator(size, offset)(theta) 55 | 56 | 57 | def _mean_filter(input): 58 | batch_size, num_channels, h, w = input.size() 59 | threshold = torch.mean(input.view(batch_size, num_channels, h * w), dim=2) 60 | return threshold.contiguous().view(batch_size, num_channels, 1, 1) 61 | 62 | 63 | class PeakStimulation(Function): 64 | 65 | @staticmethod 66 | def forward(ctx, input, return_aggregation, win_size, peak_filter): 67 | ctx.num_flags = 4 68 | 69 | assert win_size % 2 == 1, 'Window size for peak finding must be odd.' 
70 | offset = (win_size - 1) // 2 71 | padding = torch.nn.ConstantPad2d(offset, float('-inf')) 72 | padded_maps = padding(input) 73 | batch_size, num_channels, h, w = padded_maps.size() 74 | element_map = torch.arange(0, h * w).long().view(1, 1, h, w)[:, :, offset: -offset, offset: -offset] 75 | element_map = element_map.to(input.device) 76 | _, indices = F.max_pool2d( 77 | padded_maps, 78 | kernel_size = win_size, 79 | stride = 1, 80 | return_indices = True) 81 | peak_map = (indices == element_map) 82 | 83 | if peak_filter: 84 | mask = input >= peak_filter(input) 85 | peak_map = (peak_map & mask) 86 | peak_list = torch.nonzero(peak_map) 87 | ctx.mark_non_differentiable(peak_list) 88 | 89 | if return_aggregation: 90 | peak_map = peak_map.float() 91 | ctx.save_for_backward(input, peak_map) 92 | return peak_list, (input * peak_map).view(batch_size, num_channels, -1).sum(2) / \ 93 | peak_map.view(batch_size, num_channels, -1).sum(2) 94 | else: 95 | return peak_list 96 | 97 | @staticmethod 98 | def backward(ctx, grad_peak_list, grad_output): 99 | input, peak_map, = ctx.saved_tensors 100 | batch_size, num_channels, _, _ = input.size() 101 | grad_input = peak_map * grad_output.view(batch_size, num_channels, 1, 1)/ \ 102 | (peak_map.view(batch_size, num_channels, -1).sum(2).view(batch_size, num_channels, 1, 1) + 1e-6) 103 | return (grad_input,) + (None,) * ctx.num_flags 104 | 105 | 106 | def peak_stimulation(input, return_aggregation=True, win_size=3, peak_filter=None): 107 | return PeakStimulation.apply(input, return_aggregation, win_size, peak_filter) 108 | 109 | 110 | class ScaleLayer(nn.Module): 111 | 112 | def __init__(self, init_value=1e-3): 113 | super().__init__() 114 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 115 | 116 | def forward(self, input): 117 | return input * self.scale 118 | 119 | 120 | class S3N(nn.Module): 121 | 122 | def __init__(self, base_model, num_classes, task_input_size, base_ratio, radius, radius_inv): 123 | super(S3N, self).__init__() 124 | 125 | self.grid_size = 31 126 | self.padding_size = 30 127 | self.global_size = self.grid_size + 2*self.padding_size 128 | self.input_size_net = task_input_size 129 | gaussian_weights = torch.FloatTensor(makeGaussian(2*self.padding_size+1, fwhm = 13)) 130 | self.base_ratio = base_ratio 131 | self.radius = ScaleLayer(radius) 132 | self.radius_inv = ScaleLayer(radius_inv) 133 | 134 | self.filter = nn.Conv2d(1, 1, kernel_size=(2*self.padding_size+1,2*self.padding_size+1),bias=False) 135 | self.filter.weight[0].data[:,:,:] = gaussian_weights 136 | 137 | self.P_basis = torch.zeros(2,self.grid_size+2*self.padding_size, self.grid_size+2*self.padding_size) 138 | for k in range(2): 139 | for i in range(self.global_size): 140 | for j in range(self.global_size): 141 | self.P_basis[k,i,j] = k*(i-self.padding_size)/(self.grid_size-1.0)+(1.0-k)*(j-self.padding_size)/(self.grid_size-1.0) 142 | 143 | self.features = base_model.features 144 | self.num_features = base_model.num_features 145 | 146 | self.raw_classifier = nn.Linear(2048, num_classes) 147 | self.sampler_buffer = nn.Sequential(nn.Conv2d(2048, 2048, kernel_size=3, stride=2, padding=1, bias=False), 148 | nn.BatchNorm2d(2048), 149 | nn.ReLU(), 150 | ) 151 | self.sampler_classifier = nn.Linear(2048, num_classes) 152 | 153 | self.sampler_buffer1 = nn.Sequential(nn.Conv2d(2048, 2048, kernel_size=3, stride=2, padding=1, bias=False), 154 | nn.BatchNorm2d(2048), 155 | nn.ReLU(), 156 | ) 157 | self.sampler_classifier1 = nn.Linear(2048, num_classes) 158 | 159 | 
self.con_classifier = nn.Linear(int(self.num_features*3), num_classes) 160 | 161 | self.avg = nn.AdaptiveAvgPool2d(1) 162 | self.max_pool = nn.AdaptiveMaxPool2d(1) 163 | 164 | self.map_origin = nn.Conv2d(2048, num_classes, 1, 1, 0) 165 | 166 | def create_grid(self, x): 167 | P = torch.autograd.Variable(torch.zeros(1,2,self.grid_size+2*self.padding_size, self.grid_size+2*self.padding_size).cuda(),requires_grad=False) 168 | P[0,:,:,:] = self.P_basis 169 | P = P.expand(x.size(0),2,self.grid_size+2*self.padding_size, self.grid_size+2*self.padding_size) 170 | 171 | x_cat = torch.cat((x,x),1) 172 | p_filter = self.filter(x) 173 | x_mul = torch.mul(P,x_cat).view(-1,1,self.global_size,self.global_size) 174 | all_filter = self.filter(x_mul).view(-1,2,self.grid_size,self.grid_size) 175 | 176 | x_filter = all_filter[:,0,:,:].contiguous().view(-1,1,self.grid_size,self.grid_size) 177 | y_filter = all_filter[:,1,:,:].contiguous().view(-1,1,self.grid_size,self.grid_size) 178 | 179 | x_filter = x_filter/p_filter 180 | y_filter = y_filter/p_filter 181 | 182 | xgrids = x_filter*2-1 183 | ygrids = y_filter*2-1 184 | xgrids = torch.clamp(xgrids,min=-1,max=1) 185 | ygrids = torch.clamp(ygrids,min=-1,max=1) 186 | 187 | xgrids = xgrids.view(-1,1,self.grid_size,self.grid_size) 188 | ygrids = ygrids.view(-1,1,self.grid_size,self.grid_size) 189 | 190 | grid = torch.cat((xgrids,ygrids),1) 191 | 192 | grid = F.interpolate(grid, size=(self.input_size_net,self.input_size_net), mode='bilinear', align_corners=True) 193 | 194 | grid = torch.transpose(grid,1,2) 195 | grid = torch.transpose(grid,2,3) 196 | 197 | return grid 198 | 199 | def generate_map(self, input_x, class_response_maps, p): 200 | N, C, H, W = class_response_maps.size() 201 | 202 | score_pred, sort_number = torch.sort(F.softmax(F.adaptive_avg_pool2d(class_response_maps, 1), dim=1), dim=1, descending=True) 203 | gate_score = (score_pred[:, 0:5]*torch.log(score_pred[:, 0:5])).sum(1) 204 | 205 | xs = [] 206 | xs_inv = [] 207 | 208 | for idx_i in range(N): 209 | if gate_score[idx_i] > -0.2: 210 | decide_map = class_response_maps[idx_i, sort_number[idx_i, 0],:,:] 211 | else: 212 | decide_map = class_response_maps[idx_i, sort_number[idx_i, 0:5],:,:].mean(0) 213 | 214 | min_value, max_value = decide_map.min(), decide_map.max() 215 | decide_map = (decide_map-min_value)/(max_value-min_value) 216 | 217 | peak_list, aggregation = peak_stimulation(decide_map, win_size=3, peak_filter=_mean_filter) 218 | 219 | decide_map = decide_map.squeeze(0).squeeze(0) 220 | 221 | score = [decide_map[item[2], item[3]] for item in peak_list] 222 | x = [item[3] for item in peak_list] 223 | y = [item[2] for item in peak_list] 224 | 225 | if score == []: 226 | temp = torch.zeros(1, 1, self.grid_size,self.grid_size).cuda() 227 | temp += self.base_ratio 228 | xs.append(temp) 229 | xs_soft.append(temp) 230 | continue 231 | 232 | peak_num = torch.arange(len(score)) 233 | 234 | temp = self.base_ratio 235 | temp_w = self.base_ratio 236 | 237 | if p == 0: 238 | for i in peak_num: 239 | temp += score[i] * kernel_generate(self.radius(torch.sqrt(score[i])), H, (x[i].item(), y[i].item())).unsqueeze(0).unsqueeze(0).cuda() 240 | temp_w += 1/score[i] * \ 241 | kernel_generate(self.radius_inv(torch.sqrt(score[i])), H, (x[i].item(), y[i].item())).unsqueeze(0).unsqueeze(0).cuda() 242 | elif p == 1: 243 | for i in peak_num: 244 | rd = random.uniform(0, 1) 245 | if score[i] > rd: 246 | temp += score[i] * kernel_generate(self.radius(torch.sqrt(score[i])), H, (x[i].item(), 
y[i].item())).unsqueeze(0).unsqueeze(0).cuda() 247 | else: 248 | temp_w += 1/score[i] * \ 249 | kernel_generate(self.radius_inv(torch.sqrt(score[i])), H, (x[i].item(), y[i].item())).unsqueeze(0).unsqueeze(0).cuda() 250 | elif p == 2: 251 | index = score.index(max(score)) 252 | temp += score[index] * kernel_generate(self.radius(score[index]), H, (x[index].item(), y[index].item())).unsqueeze(0).unsqueeze(0).cuda() 253 | 254 | index = score.index(min(score)) 255 | temp_w += 1/score[index] * \ 256 | kernel_generate(self.radius_inv(torch.sqrt(score[index])), H, (x[index].item(), y[index].item())).unsqueeze(0).unsqueeze(0).cuda() 257 | 258 | if type(temp) == float: 259 | temp += torch.zeros(1, 1, self.grid_size,self.grid_size).cuda() 260 | xs.append(temp) 261 | 262 | if type(temp_w) == float: 263 | temp_w += torch.zeros(1, 1, self.grid_size,self.grid_size).cuda() 264 | xs_inv.append(temp_w) 265 | 266 | xs = torch.cat(xs, 0) 267 | xs_hm = nn.ReplicationPad2d(self.padding_size)(xs) 268 | grid = self.create_grid(xs_hm).to(input_x.device) 269 | x_sampled_zoom = F.grid_sample(input_x, grid) 270 | 271 | xs_inv = torch.cat(xs_inv, 0) 272 | xs_hm_inv = nn.ReplicationPad2d(self.padding_size)(xs_inv) 273 | grid_inv = self.create_grid(xs_hm_inv).to(input_x.device) 274 | x_sampled_inv = F.grid_sample(input_x, grid_inv) 275 | 276 | return x_sampled_zoom, x_sampled_inv 277 | 278 | def forward(self, input_x, p): 279 | 280 | self.map_origin.weight.data.copy_(self.raw_classifier.weight.data.unsqueeze(-1).unsqueeze(-1)) 281 | self.map_origin.bias.data.copy_(self.raw_classifier.bias.data) 282 | 283 | feature_raw = self.features(input_x) 284 | agg_origin = self.raw_classifier(self.avg(feature_raw).view(-1, 2048)) 285 | 286 | with torch.no_grad(): 287 | class_response_maps = F.interpolate(self.map_origin(feature_raw), size=self.grid_size, mode='bilinear', align_corners=True) 288 | x_sampled_zoom, x_sampled_inv = self.generate_map(input_x, class_response_maps, p) 289 | 290 | feature_D = self.sampler_buffer(self.features(x_sampled_zoom)) 291 | agg_sampler = self.sampler_classifier(self.avg(feature_D).view(-1, 2048)) 292 | 293 | feature_C = self.sampler_buffer1(self.features(x_sampled_inv)) 294 | agg_sampler1 = self.sampler_classifier1(self.avg(feature_C).view(-1, 2048)) 295 | 296 | aggregation = self.con_classifier(torch.cat([self.avg(feature_raw).view(-1, 2048), self.avg(feature_D).view(-1, 2048), self.avg(feature_C).view(-1, 2048)], 1)) 297 | 298 | return aggregation, agg_origin, agg_sampler, agg_sampler1 299 | 300 | 301 | @register 302 | def s3n( 303 | mode: str ='resnet50', 304 | num_classes: int = 200, 305 | task_input_size: int = 448, 306 | base_ratio: float = 0.09, 307 | radius: float = 0.08, 308 | radius_inv: float = 0.2) -> nn.Module: 309 | """ Selective sparse sampling. 310 | """ 311 | 312 | classify_network = modules.ft_resnet(mode=mode, fc_or_fcn = 'fc',num_classes=num_classes) 313 | model = S3N(classify_network, num_classes, task_input_size, base_ratio, radius, radius_inv) 314 | 315 | return model 316 | 317 | 318 | @register 319 | def three_stage( 320 | ctx: Context, 321 | train_ctx: Context) -> None: 322 | """Three stage. 
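    Forward hook that picks the sparse-sampling mode 'p' from the epoch index and calls the
    model as model(input, p): during training p is 0 (use every detected peak) while
    epoch_idx <= 20 and 1 (stochastic peak assignment) afterwards; during evaluation p is 1
    and then 2 (keep only the strongest / weakest peak). Raising Skip afterwards bypasses the
    trainer's default forward call, which would not pass 'p'.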
323 | """ 324 | 325 | if train_ctx.is_train: 326 | p = 0 if train_ctx.epoch_idx <= 20 else 1 327 | else: 328 | p = 1 if train_ctx.epoch_idx <= 20 else 2 329 | 330 | train_ctx.output = train_ctx.model(train_ctx.input, p) 331 | 332 | raise train_ctx.Skip 333 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import logging 4 | from contextlib import contextmanager 5 | from typing import Any, Iterable, Union, List, Tuple, Dict, Callable, Optional 6 | 7 | import torch 8 | from torch import Tensor, nn, optim 9 | from torch.utils import data 10 | from tqdm import tqdm, tqdm_notebook 11 | from nest import register, Context 12 | 13 | 14 | class TqdmHandler(logging.StreamHandler): 15 | def __init__(self): 16 | logging.StreamHandler.__init__(self) 17 | 18 | def emit(self, record): 19 | msg = self.format(record) 20 | tqdm.write(msg) 21 | 22 | 23 | @register 24 | def network_trainer( 25 | data_loaders: Tuple[List[Tuple[str, data.DataLoader]], List[Tuple[str, data.DataLoader]]], 26 | model: nn.Module, 27 | criterion: object, 28 | optimizer: Callable[[Iterable], optim.Optimizer], 29 | parameter: Optional[Callable] = None, 30 | meters: Optional[Dict[str, Callable[[Context], Any]]] = None, 31 | hooks: Optional[Dict[str, List[Callable[[Context], None]]]] = None, 32 | max_epoch: int = 200, 33 | test_interval: int = 1, 34 | resume: Optional[str] = None, 35 | log_path: Optional[str] = None, 36 | device: str = 'cuda', 37 | use_data_parallel: bool = True, 38 | use_cudnn_benchmark: bool = True, 39 | random_seed: int = 999) -> Context: 40 | """Network trainer. 41 | """ 42 | 43 | torch.manual_seed(random_seed) 44 | 45 | logger = logging.getLogger('nest.network_trainer') 46 | logger.handlers = [] 47 | logger.setLevel(logging.DEBUG) 48 | 49 | screen_handler = TqdmHandler() 50 | screen_handler.setFormatter(logging.Formatter('[%(asctime)s] %(message)s')) 51 | logger.addHandler(screen_handler) 52 | 53 | if not log_path is None: 54 | 55 | try: 56 | os.makedirs(os.path.dirname(log_path)) 57 | except OSError as exception: 58 | if exception.errno != errno.EEXIST: 59 | raise 60 | file_handler = logging.FileHandler(log_path, encoding='utf8') 61 | file_handler.setFormatter(logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s')) 62 | logger.addHandler(file_handler) 63 | 64 | def run_in_notebook(): 65 | try: 66 | return get_ipython().__class__.__name__.startswith('ZMQ') 67 | except NameError: 68 | pass 69 | return False 70 | progress_bar = tqdm_notebook if run_in_notebook() else tqdm 71 | 72 | device = torch.device(device) 73 | if device.type == 'cuda': 74 | assert torch.cuda.is_available(), 'CUDA is not available.' 
75 | torch.backends.cudnn.benchmark = use_cudnn_benchmark 76 | 77 | train_loaders, test_loaders = data_loaders 78 | 79 | model = model.to(device) 80 | 81 | if device.type == 'cuda' and use_data_parallel: 82 | model = nn.DataParallel(model) 83 | 84 | params = model.parameters() if parameter is None else parameter(model) 85 | optimizer = optimizer(params) 86 | 87 | start_epoch_idx = 0 88 | start_batch_idx = 0 89 | if not resume is None: 90 | logger.info('loading checkpoint "%s"' % resume) 91 | checkpoint = torch.load(resume) 92 | start_epoch_idx = checkpoint['epoch_idx'] 93 | start_batch_idx = checkpoint['batch_idx'] 94 | model.load_state_dict(checkpoint['model']) 95 | optimizer.load_state_dict(checkpoint['optimizer']) 96 | logger.info('checkpoint loaded (epoch %d)' % start_epoch_idx) 97 | 98 | ctx = Context( 99 | split = 'train', 100 | is_train = True, 101 | model = model, 102 | optimizer = optimizer, 103 | max_epoch = max_epoch, 104 | epoch_idx = start_epoch_idx, 105 | batch_idx = start_batch_idx, 106 | input = Tensor(), 107 | output = Tensor(), 108 | target = Tensor(), 109 | loss = Tensor(), 110 | metrics = dict(), 111 | state_dicts = [], 112 | logger = logger) 113 | 114 | class Skip(Exception): pass 115 | ctx.Skip = Skip 116 | 117 | @contextmanager 118 | def skip(): 119 | try: 120 | yield 121 | except Skip: 122 | pass 123 | 124 | def run_hooks(hook_type): 125 | if isinstance(hooks, dict) and hook_type in hooks: 126 | for hook in hooks.get(hook_type): 127 | hook(ctx) 128 | 129 | @contextmanager 130 | def session(name): 131 | run_hooks('on_start_' + name) 132 | yield 133 | run_hooks('on_end_' + name) 134 | 135 | def process(split, data_loader, is_train): 136 | ctx.max_batch = len(data_loader) 137 | ctx.split = split 138 | ctx.is_train = is_train 139 | 140 | run_hooks('on_start_split') 141 | 142 | if is_train: 143 | model.train() 144 | else: 145 | model.eval() 146 | 147 | for batch_idx, (input, target) in enumerate(progress_bar(data_loader, ascii=True, desc=split, unit='batch', leave=False)): 148 | if batch_idx < ctx.batch_idx: 149 | continue 150 | 151 | ctx.batch_idx = batch_idx 152 | if isinstance(input, (list, tuple)): 153 | ctx.input = [v.to(device) if torch.is_tensor(v) else v for v in input] 154 | elif isinstance(input, dict): 155 | ctx.input = {k: v.to(device) if torch.is_tensor(v) else v for k, v in input.items()} 156 | else: 157 | ctx.input = input.to(device) 158 | ctx.target = target.to(device) 159 | 160 | run_hooks('on_start_batch') 161 | 162 | with skip(), session('batch'): 163 | with torch.set_grad_enabled(ctx.is_train): 164 | with skip(), session('forward'): 165 | ctx.output = ctx.model(ctx.input) 166 | ctx.loss = criterion(ctx.output, ctx.target) 167 | 168 | if not meters is None: 169 | ctx.metrics.update({split + '_' + k: v(ctx) for k, v in meters.items() if v is not None}) 170 | 171 | if is_train: 172 | optimizer.zero_grad() 173 | ctx.loss.backward() 174 | optimizer.step() 175 | 176 | run_hooks('on_end_batch') 177 | ctx.batch_idx = 0 178 | 179 | run_hooks('on_end_split') 180 | 181 | run_hooks('on_start') 182 | 183 | for epoch_idx in progress_bar(range(ctx.epoch_idx, max_epoch), ascii=True, unit='epoch'): 184 | ctx.epoch_idx = epoch_idx 185 | run_hooks('on_start_epoch') 186 | 187 | for split, loader in train_loaders: 188 | process(split, loader, True) 189 | 190 | if epoch_idx % test_interval == 0: 191 | for split, loader in test_loaders: 192 | process(split, loader, False) 193 | 194 | run_hooks('on_end_epoch') 195 | 196 | run_hooks('on_end') 197 | 198 | return ctx 199 | 
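Custom hooks follow the same pattern as those in hooks.py: a `@register`-ed function that takes the trainer's `Context` and is listed under one of the `hooks:` keys of a task YAML. A minimal sketch (the hook name `log_lr` is hypothetical and not part of the repository; the `train_ctx` fields it reads are the ones set up by `network_trainer` above):

```python
from nest import register, Context


@register
def log_lr(train_ctx: Context) -> None:
    """Log learning rates.
    """
    # Print the current learning rate of every optimizer parameter group.
    # Attach it e.g. under 'on_start_epoch' in a demo YAML file.
    for idx, group in enumerate(train_ctx.optimizer.param_groups):
        train_ctx.logger.info('epoch %d, param group %d, lr %e' %
                              (train_ctx.epoch_idx, idx, group['lr']))
```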
-------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union, Tuple, Dict 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | from nest import register, Context 8 | 9 | 10 | class AverageMeter(object): 11 | 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | @register 29 | def multi_topk_meter( 30 | ctx: Context, 31 | train_ctx: Context, 32 | k: int=1, 33 | init_num: int=1, 34 | end_num: int = 0) -> dict: 35 | """Multi topk meter. 36 | """ 37 | 38 | def accuracy(output, target, k=1): 39 | batch_size = target.size(0) 40 | 41 | _, pred = output.topk(k, 1, True, True) 42 | pred = pred.t() 43 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 44 | 45 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 46 | return correct_k.mul_(100.0 / batch_size) 47 | 48 | for i in range(init_num, len(train_ctx.output) - end_num): 49 | if not "branch_"+str(i) in ctx: 50 | setattr(ctx, "branch_"+str(i), AverageMeter()) 51 | 52 | if train_ctx.batch_idx == 0: 53 | for i in range(init_num, len(train_ctx.output) - end_num): 54 | getattr(ctx, "branch_"+str(i)).reset() 55 | 56 | for i in range(init_num, len(train_ctx.output) - end_num): 57 | acc = accuracy(train_ctx.output[i], train_ctx.target, k) 58 | getattr(ctx, "branch_"+str(i)).update(acc.item()) 59 | 60 | acc_list = {} 61 | 62 | for i in range(init_num, len(train_ctx.output) - end_num): 63 | acc_list["branch_"+str(i)] = getattr(ctx, "branch_"+str(i)).avg 64 | 65 | return acc_list 66 | 67 | 68 | @register 69 | def best_meter(ctx: Context, train_ctx: Context, best_branch: int = 1, k: int = 1) -> float: 70 | """Best meter. 71 | """ 72 | def accuracy(output, target, k=1): 73 | batch_size = target.size(0) 74 | 75 | _, pred = output.topk(k, 1, True, True) 76 | pred = pred.t() 77 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 78 | 79 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 80 | return correct_k.mul_(100.0 / batch_size) 81 | 82 | if not 'meter' in ctx: 83 | ctx.meter = AverageMeter() 84 | 85 | if train_ctx.batch_idx == 0: 86 | ctx.meter.reset() 87 | acc = accuracy(train_ctx.output[best_branch], train_ctx.target, k) 88 | ctx.meter.update(acc.item()) 89 | return ctx.meter.avg -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from fnmatch import fnmatch 3 | from typing import Tuple, List, Union, Dict, Iterable 4 | 5 | import numpy as np 6 | import torch.nn as nn 7 | from nest import register 8 | 9 | 10 | @register 11 | def finetune( 12 | model: nn.Module, 13 | base_lr: float, 14 | groups: Dict[str, float], 15 | ignore_the_rest: bool = False, 16 | raw_query: bool = False) -> List[Dict[str, Union[float, Iterable]]]: 17 | """Fintune. 
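    Builds optimizer parameter groups: each entry of 'groups' maps a parameter-name pattern
    to a learning-rate multiplier on top of 'base_lr' (patterns are wrapped as '*pattern*'
    unless 'raw_query' is set). Parameters matching no pattern are collected into a final
    group trained at 'base_lr', or omitted entirely when 'ignore_the_rest' is True.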
18 | """ 19 | 20 | parameters = [dict(params=[], names=[], query=query if raw_query else '*'+query+'*', lr=lr*base_lr) for query, lr in groups.items()] 21 | rest_parameters = dict(params=[], names=[], lr=base_lr) 22 | for k, v in model.named_parameters(): 23 | matched = False 24 | for group in parameters: 25 | if fnmatch(k, group['query']): 26 | group['params'].append(v) 27 | group['names'].append(k) 28 | matched = True 29 | break 30 | if not matched: 31 | rest_parameters['params'].append(v) 32 | rest_parameters['names'].append(k) 33 | if not ignore_the_rest: 34 | parameters.append(rest_parameters) 35 | for group in parameters: 36 | group['params'] = iter(group['params']) 37 | return parameters 38 | 39 | 40 | @register(acknowledgement='RLE modules are based on Sam Stainsby\'s implementation (https://www.kaggle.com/stainsby/fast-tested-rle)') 41 | def rle_encode(mask: np.ndarray) -> dict: 42 | """Run-Length Encoding (RLE). 43 | """ 44 | 45 | assert mask.dtype == bool and mask.ndim == 2, 'RLE encoding requires a binary mask (dtype=bool).' 46 | pixels = mask.flatten() 47 | pixels = np.concatenate([[0], pixels, [0]]) 48 | runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 49 | runs[1::2] -= runs[::2] 50 | return dict(data=base64.b64encode(runs.astype(np.uint32).tobytes()).decode('utf-8'), shape=mask.shape) 51 | 52 | 53 | @register 54 | def rle_decode(rle: dict) -> np.ndarray: 55 | """Run-Length Encoding. 56 | """ 57 | 58 | runs = np.frombuffer(base64.b64decode(rle['data']), np.uint32) 59 | shape = rle['shape'] 60 | starts, lengths = [np.asarray(x, dtype=int) for x in (runs[0:][::2], runs[1:][::2])] 61 | starts -= 1 62 | ends = starts + lengths 63 | img = np.zeros(shape[0]*shape[1], dtype=np.uint8) 64 | for lo, hi in zip(starts, ends): 65 | img[lo:hi] = 1 66 | return img.reshape(shape) 67 | --------------------------------------------------------------------------------
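To make the run-length layout concrete, the following self-contained snippet traces the same computation `rle_encode` performs on a small mask (plain NumPy, independent of the Nest registration):

```python
import numpy as np

# a tiny 2x4 binary mask with two foreground runs
mask = np.array([[0, 1, 1, 0],
                 [0, 1, 1, 0]], dtype=bool)

# identical run-length computation to rle_encode (before base64 packing)
pixels = np.concatenate([[0], mask.flatten(), [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]

print(runs)  # [2 2 6 2]: runs start at 1-indexed pixels 2 and 6, each 2 pixels long
```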