├── README.md
├── datasets.py
├── datasets
│   ├── test.txt
│   └── train.txt
├── demo
│   ├── cub_baseline.yml
│   └── cub_s3n.yml
├── fgvc_datasets.py
├── ft_resnet.py
├── hooks.py
├── illustration.png
├── losses.py
├── meters.py
├── optimizers.py
├── sss_net.py
├── trainer.py
├── utility.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Selective Sparse Sampling for Fine-grained Image Recognition
2 |
3 | 
4 |
5 | ## PyTorch Implementation
6 | This repository contains:
7 |
8 | * the **PyTorch** implementation of Selective Sparse Sampling.
9 | * the CUB-200-2011 demo (training and testing).
10 |
11 | Please follow the instructions below to install the package and run the experiment demo.
12 |
13 | ### Prerequisites
14 | * System (tested on Ubuntu 14.04 LTS and Windows 10)
15 | * 2 NVIDIA Tesla P100 GPUs with CUDA and cuDNN (CPU mode is also supported but significantly slower)
16 | * [Python=3.6.8](https://www.python.org)
17 | * [PyTorch=0.4.1](https://pytorch.org)
18 | * [Jupyter Notebook](https://jupyter.org/install.html)
19 | * [Nest](https://github.com/ZhouYanzhao/Nest.git)
20 |
21 | ### Installation
22 |
23 | 1. Install S3N via Nest's CLI tool:
24 |
25 | ```bash
26 | # note that data will be saved under your current path
27 | $ git clone https://github.com/Yao-DD/S3N.git ./S3N
28 | $ nest module install ./S3N/ s3n
29 | # verify the installation
30 | $ nest module list --filter s3n
31 | ```
32 |
33 | ### Prepare Data
34 |
35 | 1. Download the CUB-200-2011 dataset:
36 |
37 | ```bash
38 | $ mkdir ./S3N/datasets
39 | $ cd ./S3N/datasets
40 | # download and extract data
41 | $ wget http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz
42 | $ tar xvf CUB_200_2011.tgz
43 | ```
44 |
45 | 2. Prepare annotation files:
46 |
47 | Move the files ./datasets/train.txt and ./datasets/test.txt into ./datasets/CUB_200_2011. The files ./datasets/CUB_200_2011/train.txt and ./datasets/CUB_200_2011/test.txt list the image file names and labels, one image per line (a parsing sketch follows the format block):
48 |
49 | ```
50 | <image file name, relative to images/> <integer class label>
51 | ```
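
For reference, these annotation files are parsed by `fgvc_datasets.py`, which splits each line on a single space and casts the label to an integer. A minimal sanity-check sketch, assuming it is run from the repository root (the `data_dir` path below follows the steps above):

```python
import os

data_dir = './datasets/CUB_200_2011'  # assumed location from the steps above

# each line of train.txt / test.txt: "<image file name> <integer class label>"
with open(os.path.join(data_dir, 'train.txt')) as f:
    samples = [line.strip().split(' ') for line in f if line.strip()]
samples = [(name, int(label)) for name, label in samples]

# every listed image should exist under CUB_200_2011/images/
missing = [name for name, _ in samples
           if not os.path.exists(os.path.join(data_dir, 'images', name))]
print('%d training samples, %d missing files' % (len(samples), len(missing)))
```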
52 |
53 | ### Run the demo
54 |
55 | 1. Run the demo as follows (a note on single-GPU runs and resuming follows the block):
56 |
57 | ```bash
58 | $ cd ./S3N
59 | # run baseline
60 | $ PYTHONWARNINGS='ignore' CUDA_VISIBLE_DEVICES=0,1 nest task run ./demo/cub_baseline.yml
61 | # run S3N
62 | $ PYTHONWARNINGS='ignore' CUDA_VISIBLE_DEVICES=0,1 nest task run ./demo/cub_s3n.yml
63 | ```
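
The demo assumes two visible GPUs (`CUDA_VISIBLE_DEVICES=0,1`); for a single-GPU run, set `CUDA_VISIBLE_DEVICES=0` in the commands above and, if memory runs out, reduce `batch_size` in the yml. Training can also be resumed by passing a checkpoint saved under `./generate/` to the trainer's `resume` parameter.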
64 |
65 | ### Pretrained models
66 |
67 | 1. The S3N model for the CUB-200-2011 dataset is available on Baidu Disk (a loading sketch follows):
68 |
69 | ```
70 | link: https://pan.baidu.com/s/19x9zI_ZNi32sRGRgNwN_Fw
71 | code: r252
72 | ```
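
Checkpoints written by `hooks.checkpoint` store a dictionary with `epoch_idx`, `batch_idx`, `model` and `optimizer` entries; the released model presumably follows the same layout. A minimal loading sketch (the file name `model.pt` and the `modules.s3n` accessor are assumptions):

```python
import torch
from nest import modules

checkpoint = torch.load('model.pt', map_location='cpu')  # downloaded file; name is an assumption
model = modules.s3n(mode='resnet50', num_classes=200)
# models trained with nn.DataParallel save keys with a 'module.' prefix
state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v
              for k, v in checkpoint['model'].items()}
model.load_state_dict(state_dict)
model.eval()
```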
73 |
74 | ## CAUTION
75 | The current code was prepared under the prerequisites listed above. Using other versions may cause problems.
76 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | from typing import Tuple, List, Dict, Union, Callable, Optional
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 | from torchvision import datasets, transforms
9 | from PIL import Image
10 | from nest import register
11 |
12 |
13 | @register
14 | def image_transform(
15 | image_size: Union[int, List[int]],
16 | augmentation: dict,
17 | mean: List[float] = [0.485, 0.456, 0.406],
18 | std: List[float] = [0.229, 0.224, 0.225]) -> Callable:
19 | """Image transforms.
20 | """
21 |
22 | if isinstance(image_size, int):
23 | image_size = (image_size, image_size)
24 | else:
25 | image_size = tuple(image_size)
26 |
27 | horizontal_flip = augmentation.pop('horizontal_flip', None)
28 | if horizontal_flip is not None:
29 | assert isinstance(horizontal_flip, float) and 0 <= horizontal_flip <= 1
30 |
31 | vertical_flip = augmentation.pop('vertical_flip', None)
32 | if vertical_flip is not None:
33 | assert isinstance(vertical_flip, float) and 0 <= vertical_flip <= 1
34 |
35 | random_crop = augmentation.pop('random_crop', None)
36 | if random_crop is not None:
37 | assert isinstance(random_crop, dict)
38 |
39 | center_crop = augmentation.pop('center_crop', None)
40 | if center_crop is not None:
41 | assert isinstance(center_crop, (int, list))
42 |
43 | if len(augmentation) > 0:
44 | raise NotImplementedError('Invalid augmentation options: %s.' % ', '.join(augmentation.keys()))
45 |
46 | t = [
47 | transforms.Resize(image_size) if random_crop is None else transforms.RandomResizedCrop(image_size[0], **random_crop),
48 | transforms.CenterCrop(center_crop) if center_crop is not None else None,
49 | transforms.RandomHorizontalFlip(horizontal_flip) if horizontal_flip is not None else None,
50 | transforms.RandomVerticalFlip(vertical_flip) if vertical_flip is not None else None,
51 | transforms.ToTensor(),
52 | transforms.Normalize(mean, std)]
53 |
54 | return transforms.Compose([v for v in t if v is not None])
55 |
56 |
57 | @register
58 | def fetch_data(
59 | dataset: Callable[[str], Dataset],
60 | transform: Optional[Callable] = None,
61 | target_transform: Optional[Callable] = None,
62 | num_workers: int = 0,
63 | pin_memory: bool = True,
64 | drop_last: bool = False,
65 | train_splits: List[str] = [],
66 | test_splits: List[str] = [],
67 | train_shuffle: bool = True,
68 | test_shuffle: bool = False,
69 | test_image_size: int = 600,
70 | train_augmentation: dict = {},
71 | test_augmentation: dict = {},
72 | batch_size: int = 1,
73 | test_batch_size: Optional[int] = None) -> Tuple[List[Tuple[str, DataLoader]], List[Tuple[str, DataLoader]]]:
74 | """Fetch data.
75 | """
76 |
77 | train_transform = transform(augmentation=train_augmentation) if transform else None
78 | train_loader_list = []
79 | for split in train_splits:
80 | train_loader_list.append((split, DataLoader(
81 | dataset = dataset(
82 | split = split,
83 | transform = train_transform,
84 | target_transform = target_transform),
85 | batch_size = batch_size,
86 | num_workers = num_workers,
87 | pin_memory = pin_memory,
88 | drop_last=drop_last,
89 | shuffle = train_shuffle)))
90 |
91 | test_transform = transform(image_size=[test_image_size, test_image_size], augmentation=test_augmentation) if transform else None
92 | test_loader_list = []
93 | for split in test_splits:
94 | test_loader_list.append((split, DataLoader(
95 | dataset = dataset(
96 | split = split,
97 | transform = test_transform,
98 | target_transform = target_transform),
99 | batch_size = batch_size if test_batch_size is None else test_batch_size,
100 | num_workers = num_workers,
101 | pin_memory = pin_memory,
102 | drop_last=drop_last,
103 | shuffle = test_shuffle)))
104 |
105 | return train_loader_list, test_loader_list
--------------------------------------------------------------------------------
/demo/cub_baseline.yml:
--------------------------------------------------------------------------------
1 | _name: network_trainer
2 | data_loaders:
3 | _name: s3n.fetch_data
4 | dataset:
5 | _name: s3n.fgvc_dataset
6 | data_dir: ./datasets/CUB_200_2011
7 | batch_size: 16
8 | num_workers: 4
9 | transform:
10 | _name: image_transform
11 | image_size: [448, 448]
12 | mean: [0.485, 0.456, 0.406]
13 | std: [0.229, 0.224, 0.225]
14 | train_augmentation:
15 | horizontal_flip: 0.5
16 | random_crop:
17 | scale: [0.5, 1]
18 | test_augmentation:
19 | center_crop: 448
20 | train_splits:
21 | - train
22 | test_splits:
23 | - test
24 | log_path: './logs/CUB_baseline.log'
25 | model:
26 | _name: s3n.ft_resnet
27 | mode: 'resnet50'
28 | num_classes: 200
29 | criterion:
30 | _name: s3n.smooth_loss
31 | smooth_ratio: 0.85
32 | optimizer:
33 | _name: s3n.sgd_optimizer
34 | lr: 0.01
35 | momentum: 0.9
36 | weight_decay: 1.0e-4
37 | parameter:
38 | _name: finetune
39 | base_lr: 0.001
40 | groups:
41 | 'classifier': 10.0
42 | meters:
43 | top1:
44 | _name: s3n.topk_meter
45 | k: 1
46 | loss:
47 | _name: loss_meter
48 | max_epoch: 30
49 | device: cuda
50 | hooks:
51 | on_start_epoch:
52 | -
53 | _name: update_lr
54 | epoch_list: [20,]
55 | on_end_epoch:
56 | -
57 | _name: print_state
58 | formats:
59 | - 'epoch: {epoch_idx}'
60 | - 'train_loss: {metrics[train_loss]:.4f}'
61 | - 'test_loss: {metrics[test_loss]:.4f}'
62 | - 'train_top1: {metrics[train_top1]:.2f}%'
63 | - 'test_top1: {metrics[test_top1]:.2f}%'
64 | -
65 | _name: checkpoint
66 | save_dir: './generate/baseline/'
67 | save_step: 1
--------------------------------------------------------------------------------
/demo/cub_s3n.yml:
--------------------------------------------------------------------------------
1 | _name: network_trainer
2 | data_loaders:
3 | _name: s3n.fetch_data
4 | dataset:
5 | _name: s3n.fgvc_dataset
6 | data_dir: ./datasets/CUB_200_2011
7 | batch_size: 16
8 | num_workers: 4
9 | transform:
10 | _name: s3n.image_transform
11 | image_size: [448, 448]
12 | mean: [0.485, 0.456, 0.406]
13 | std: [0.229, 0.224, 0.225]
14 | train_augmentation:
15 | horizontal_flip: 0.5
16 | random_crop:
17 | scale: [0.5, 1]
18 | test_augmentation:
19 | center_crop: 448
20 | test_image_size: 600
21 | train_splits:
22 | - train
23 | test_splits:
24 | - test
25 | log_path: './logs/cub_s3n.log'
26 | model:
27 | _name: s3n.s3n
28 | mode: 'resnet50'
29 | num_classes: 200
30 | radius: 0.09
31 | radius_inv: 0.3
32 | criterion:
33 | _name: s3n.multi_smooth_loss
34 | smooth_ratio: 0.85
35 | optimizer:
36 | _name: s3n.sgd_optimizer
37 | lr: 0.01
38 | momentum: 0.9
39 | weight_decay: 1.0e-4
40 | parameter:
41 | _name: s3n.finetune
42 | base_lr: 0.001
43 | groups:
44 | 'classifier': 10.0
45 | 'radius': 0.0001
46 | 'filter': 0.0001
47 | meters:
48 | top1:
49 | _name: s3n.multi_topk_meter
50 | k: 1
51 | init_num: 0
52 | loss:
53 | _name: s3n.loss_meter
54 | max_epoch: 60
55 | device: cuda
56 | use_data_parallel: yes
57 | hooks:
58 | on_start_epoch:
59 | -
60 | _name: s3n.update_lr
61 | epoch_step: 40
62 | on_start_forward:
63 | -
64 | _name: s3n.three_stage
65 | on_end_epoch:
66 | -
67 | _name: s3n.print_state
68 | formats:
69 | - 'epoch: {epoch_idx}'
70 | - 'train_loss: {metrics[train_loss]:.4f}'
71 | - 'test_loss: {metrics[test_loss]:.4f}'
72 | - 'train_branch1_top1: {metrics[train_top1][branch_0]:.2f}%'
73 | - 'train_branch2_top1: {metrics[train_top1][branch_1]:.2f}%'
74 | - 'train_branch3_top1: {metrics[train_top1][branch_2]:.2f}%'
75 | - 'train_branch4_top1: {metrics[train_top1][branch_3]:.2f}%'
76 | - 'test_branch1_top1: {metrics[test_top1][branch_0]:.2f}%'
77 | - 'test_branch2_top1: {metrics[test_top1][branch_1]:.2f}%'
78 | - 'test_branch3_top1: {metrics[test_top1][branch_2]:.2f}%'
79 | - 'test_branch4_top1: {metrics[test_top1][branch_3]:.2f}%'
80 | -
81 | _name: s3n.checkpoint
82 | save_dir: './generate/s3n/'
83 | save_step: 1
--------------------------------------------------------------------------------
/fgvc_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | from typing import List, Dict, Tuple, Callable, Optional, Union
4 |
5 | import torch
6 | import numpy as np
7 | from torch.utils.data import Dataset,DataLoader
8 | from torchvision import datasets, transforms
9 | from PIL import Image
10 | from nest import register
11 |
12 |
13 | class FGVC_Dataset(Dataset):
14 |
15 |     def __init__(self, data_dir, split, label_path=None, transform=None, target_transform=None):
16 |         self.data_dir = data_dir
17 |         self.split = split
18 |         self.label_path = label_path
19 |         self.transform = transform
20 |         self.target_transform = target_transform
21 |
22 |         self.image_dir = os.path.join(self.data_dir, 'images')
23 |         self.image_labels = self._read_annotation(self.split)
24 |
25 |     def _read_annotation(self, split):
26 |         class_labels = OrderedDict()
27 |         if self.label_path is None:
28 |             label_path = os.path.join(self.data_dir, split + '.txt')
29 |         else:
30 |             label_path = os.path.join(self.data_dir, self.label_path, split + '.txt')
31 |         if os.path.exists(label_path):
32 |             with open(label_path, 'r') as f:
33 |                 for line in f:
34 |                     name, label = line.split(' ')
35 |                     class_labels[name] = int(label)
36 |         else:
37 |             raise FileNotFoundError(
38 |                 'Invalid annotation path: %s' % label_path)
39 |
40 |         return list(class_labels.items())
41 |
42 |     def __getitem__(self, index):
43 |         filename, target = self.image_labels[index]
44 |         img = Image.open(os.path.join(self.image_dir, filename)).convert('RGB')
45 |
46 |         if self.transform:
47 |             img = self.transform(img)
48 |         if self.target_transform:
49 |             target = self.target_transform(target)
50 |
51 |         return img, target
52 |
53 |     def __len__(self):
54 |         return len(self.image_labels)
55 |
56 |
57 | @register
58 | def fgvc_dataset(
59 | split: str,
60 | data_dir: str,
61 | label_path: Optional[str] = None,
62 | transform: Optional[Callable] = None,
63 | target_transform: Optional[Callable] = None) -> Dataset:
64 | '''Fine-grained visual classification datasets.
65 | '''
66 | return FGVC_Dataset(data_dir, split, label_path, transform, target_transform)
--------------------------------------------------------------------------------
/ft_resnet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torchvision import models
3 | from nest import register
4 |
5 |
6 | @register
7 | def ft_resnet(mode: str = 'resnet50', fc_or_fcn: str = 'fc', num_classes: int = 10, pretrained: bool = True) -> nn.Module:
8 | """Finetune resnet.
9 | """
10 |
11 | class FT_Resnet(nn.Module):
12 | def __init__(self, mode='resnet50', fc_or_fcn='fc', num_classes=10, pretrained=True):
13 | super(FT_Resnet, self).__init__()
14 |
15 | if mode=='resnet50':
16 | model = models.resnet50(pretrained=pretrained)
17 | elif mode=='resnet101':
18 | model = models.resnet101(pretrained=pretrained)
19 | elif mode=='resnet152':
20 | model = models.resnet152(pretrained=pretrained)
21 | else:
22 | model = models.resnet18(pretrained=pretrained)
23 |
24 | self.features = nn.Sequential(
25 | model.conv1,
26 | model.bn1,
27 | model.relu,
28 | model.maxpool,
29 | model.layer1,
30 | model.layer2,
31 | model.layer3,
32 | model.layer4
33 | )
34 | self.num_classes = num_classes
35 | self.num_features = model.layer4[1].conv1.in_channels
36 | self.fc_or_fcn = fc_or_fcn
37 | if self.fc_or_fcn=='fc':
38 | self.classifier = nn.Linear(self.num_features, num_classes)
39 | else:
40 | self.classifier = nn.Conv2d(self.num_features, self.num_classes, 1, 1)
41 | self.avg = nn.AdaptiveAvgPool2d(1)
42 |
43 | def forward(self, x):
44 | x = self.features(x)
45 | if self.fc_or_fcn=='fc':
46 | x = self.avg(x).view(-1, self.num_features)
47 | x = self.classifier(x)
48 | else:
49 | x = self.classifier(x)
50 | x = self.avg(x).view(-1, self.num_classes)
51 | return x
52 |
53 | return FT_Resnet(mode, fc_or_fcn, num_classes, pretrained)
54 |
--------------------------------------------------------------------------------
/hooks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | from time import localtime, strftime
4 | from typing import List, Callable, Optional
5 |
6 | import torch
7 | import numpy as np
8 | from visdom import Visdom
9 | from nest import register, Context
10 |
11 |
12 | @register
13 | def checkpoint(
14 | train_ctx: Context,
15 | save_dir: str,
16 | save_step: Optional[int] = None,
17 | save_final: bool = False,
18 | save_latest: bool = False,
19 | save_all: bool = False) -> None:
20 | """Checkpoint.
21 | """
22 |
23 | save_dir = os.path.abspath(save_dir)
24 | try:
25 | os.makedirs(save_dir)
26 | except OSError as exception:
27 | if exception.errno != errno.EEXIST:
28 | raise
29 |
30 | def save_current_train_ctx(save_name):
31 | save_path = os.path.join(save_dir, save_name)
32 | torch.save(dict(
33 | epoch_idx = train_ctx.epoch_idx + 1,
34 | batch_idx = train_ctx.batch_idx + 1,
35 | model = train_ctx.model.state_dict(),
36 | optimizer = train_ctx.optimizer.state_dict()), save_path)
37 | train_ctx.logger.info('checkpoint created at %s' % save_path)
38 |
39 | if save_all:
40 | save_current_train_ctx(strftime("model_%Y_%m_%d_%H.%M.%S.pt", localtime()))
41 | if save_step is not None and (train_ctx.epoch_idx + 1) % save_step == 0:
42 | save_current_train_ctx('model_%d.pt' % train_ctx.epoch_idx)
43 | if save_final and (train_ctx.epoch_idx + 1) == train_ctx.max_epoch:
44 | save_current_train_ctx('model_final.pt')
45 | if save_latest:
46 | save_current_train_ctx('model_latest.pt')
47 |
48 |
49 | @register
50 | def vis_trend(ctx: Context, train_ctx: Context, server: str, env: str, port: int = 80) -> None:
51 | """Track trend with Visdom.
52 | """
53 |
54 | if not 'vis' in ctx:
55 | ctx.vis = Visdom(server=server, port=port, env=env)
56 |
57 | try:
58 | for k, v in train_ctx.metrics.items():
59 | if isinstance(v, (int, float)):
60 | if ctx.vis.win_exists(k):
61 | ctx.vis.line(
62 | X = np.array([train_ctx.epoch_idx]),
63 | Y = np.array([v]),
64 | opts = dict(title=k, xlabel='epoch'),
65 | win = k,
66 | update = 'append')
67 | else:
68 | ctx.vis.line(
69 | X = np.array([train_ctx.epoch_idx]),
70 | Y = np.array([v]),
71 | opts = dict(title=k, xlabel='epoch'),
72 | win = k)
73 | ctx.vis.save([env])
74 | except ConnectionError:
75 | train_ctx.logger.warning('Could not connect to visdom server "%s".' % server)
76 |
77 |
78 | @register
79 | def print_state(train_ctx: Context, formats: List[str], join_str: str = ' | ') -> None:
80 | """Print state.
81 | """
82 |
83 |     def unescape(escaped_str):
84 |         return bytes(escaped_str, "utf-8").decode("unicode_escape")
85 |
86 | def safe_format(format_str, **kwargs):
87 | try:
88 | return format_str.format(**kwargs)
89 |         except Exception:
90 | return None
91 |
92 | format_list = [safe_format(unescape(format_str), **vars(train_ctx)) for format_str in formats]
93 | output_str = unescape(join_str).join([val for val in format_list if val is not None])
94 | train_ctx.logger.info(output_str)
95 |
96 |
97 | @register
98 | def interval(
99 | train_ctx: Context,
100 | hook: Callable[[Context], None],
101 | epoch_interval: int = 1,
102 | batch_interval: int = 1) -> None:
103 | """Skip interval.
104 | """
105 |
106 | if train_ctx.epoch_idx % epoch_interval == 0 and train_ctx.batch_idx % batch_interval == 0:
107 | hook(train_ctx)
108 |
109 |
110 | @register
111 | def update_lr(
112 | train_ctx: Context,
113 | epoch_step: Optional[int] = None,
114 | epoch_list: Optional[List[int]] = None,
115 | factor: float = 0.1) -> None:
116 | """Update learning rate.
117 | """
118 |
119 | current_epoch = train_ctx.epoch_idx + 1
120 | if ((epoch_step is not None) and (current_epoch % epoch_step == 0)) or \
121 | ((epoch_list is not None) and (current_epoch in epoch_list)):
122 | for idx, param in enumerate(train_ctx.optimizer.param_groups):
123 | param['lr'] = param['lr'] * factor
124 | print('LR of param group %d is updated to %e' % (idx, param['lr']))
125 |
--------------------------------------------------------------------------------
/illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yao-DD/S3N/c77729b19be2c0d8581f0c522d856a3fc7fdfa21/illustration.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Tuple, Dict, Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 | from nest import register
7 |
8 |
9 | @register
10 | def cross_entropy_loss(
11 | input: Tensor,
12 | target: Tensor,
13 | weight: Optional[Tensor] = None,
14 | size_average: bool = True,
15 | ignore_index: int = -100,
16 | reduce: bool = True) -> Tensor:
17 | """Cross entropy loss.
18 | """
19 |
20 | return F.cross_entropy(input, target, weight, size_average, ignore_index, reduce)
21 |
22 |
23 | @register
24 | def smooth_loss(
25 | input: Tensor,
26 | target: Tensor,
27 | smooth_ratio: float = 0.9,
28 | weight: Union[None, Tensor] = None,
29 | size_average: bool = True,
30 | ignore_index: int = -100,
31 | reduce: bool = True) -> Tensor:
32 | '''Smooth loss.
33 | '''
34 |
35 | prob = F.log_softmax(input, dim=1)
36 | ymask = prob.data.new(prob.size()).zero_()
37 | ymask = ymask.scatter_(1, target.view(-1,1), 1)
38 |     ymask = smooth_ratio*ymask + (1-smooth_ratio)*(1-ymask)/(input.shape[1]-1)
39 | loss = - (prob*ymask).sum(1).mean()
40 |
41 | return loss
42 |
43 |
44 | @register
45 | def multi_smooth_loss(
46 | input: Tuple,
47 | target: Tensor,
48 | smooth_ratio: float = 0.9,
49 | loss_weight: Union[None, Dict]= None,
50 | weight: Union[None, Tensor] = None,
51 | size_average: bool = True,
52 | ignore_index: int = -100,
53 | reduce: bool = True) -> Tensor:
54 | '''Multi smooth loss.
55 | '''
56 |     assert isinstance(input, tuple) and len(input) >= 2, 'input must be a tuple of at least two branch outputs'
57 |
58 | weight_loss = torch.ones(len(input)).to(input[0].device)
59 | if loss_weight is not None:
60 | for item in loss_weight.items():
61 | weight_loss[int(item[0])] = item[1]
62 |
63 | loss = 0
64 | for i in range(0, len(input)):
65 | if i in [1, len(input)-1]:
66 | prob = F.log_softmax(input[i], dim=1)
67 | ymask = prob.data.new(prob.size()).zero_()
68 | ymask = ymask.scatter_(1, target.view(-1,1), 1)
69 | ymask = smooth_ratio*ymask + (1-smooth_ratio)*(1-ymask)/(input[i].shape[1]-1)
70 | loss_tmp = - weight_loss[i]*((prob*ymask).sum(1).mean())
71 | else:
72 | loss_tmp = weight_loss[i]*F.cross_entropy(input[i], target, weight, size_average, ignore_index, reduce)
73 | loss += loss_tmp
74 |
75 | return loss
76 |
--------------------------------------------------------------------------------
/meters.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from nest import register, Context
5 |
6 |
7 | class AverageMeter(object):
8 |
9 | def __init__(self):
10 | self.reset()
11 |
12 | def reset(self):
13 | self.val = 0
14 | self.avg = 0
15 | self.sum = 0
16 | self.count = 0
17 |
18 | def update(self, val, n=1):
19 | self.val = val
20 | self.sum += val * n
21 | self.count += n
22 | self.avg = self.sum / self.count
23 |
24 |
25 | @register
26 | def loss_meter(ctx: Context, train_ctx: Context) -> float:
27 | """Loss meter.
28 | """
29 |
30 | if not 'meter' in ctx:
31 | ctx.meter = AverageMeter()
32 |
33 | if train_ctx.batch_idx == 0:
34 | ctx.meter.reset()
35 | ctx.meter.update(train_ctx.loss.item(), train_ctx.target.size(0))
36 | return ctx.meter.avg
37 |
38 |
39 | @register
40 | def topk_meter(ctx: Context, train_ctx: Context, k: int = 1) -> float:
41 | """Topk meter.
42 | """
43 |
44 | def accuracy(output, target, k=1):
45 | batch_size = target.size(0)
46 |
47 | _, pred = output.topk(k, 1, True, True)
48 | pred = pred.t()
49 | correct = pred.eq(target.view(1, -1).expand_as(pred))
50 |
51 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
52 | return correct_k.mul_(100.0 / batch_size)
53 |
54 | if not 'meter' in ctx:
55 | ctx.meter = AverageMeter()
56 |
57 | if train_ctx.batch_idx == 0:
58 | ctx.meter.reset()
59 | acc = accuracy(train_ctx.output, train_ctx.target, k)
60 | ctx.meter.update(acc.item())
61 | return ctx.meter.avg
62 |
--------------------------------------------------------------------------------
/optimizers.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable
2 |
3 | from torch import optim
4 | from nest import register
5 |
6 |
7 | @register
8 | def sgd_optimizer(
9 | parameters: Iterable,
10 | lr: float,
11 | momentum: float = 0.0,
12 | dampening: float = 0.0,
13 | weight_decay: float = 0.0,
14 | nesterov: bool = False) -> optim.Optimizer:
15 | """SGD optimizer.
16 | """
17 |
18 | return optim.SGD(parameters, lr, momentum, dampening, weight_decay, nesterov)
19 |
20 |
21 | @register
22 | def adadelta_optimizer(
23 | parameters: Iterable,
24 | lr: float = 1.0,
25 | rho: float = 0.9,
26 | eps: float = 1e-6,
27 | weight_decay: float = 0.0) -> optim.Optimizer:
28 | """Adadelta optimizer.
29 | """
30 |
31 | return optim.Adadelta(parameters, lr, rho, eps, weight_decay)
32 |
--------------------------------------------------------------------------------
/sss_net.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import random
4 | from collections import OrderedDict
5 | from typing import List, Dict, Tuple, Callable, Optional, Union
6 |
7 | import torch
8 | import numpy as np
9 | from PIL import Image
10 | from torch import nn
11 | from torch.nn import functional as F
12 | from torchvision import models
13 | from torch.autograd import Function
14 | from nest import register, modules, Context
15 |
16 |
17 | def makeGaussian(size, fwhm = 3, center=None):
18 |
19 | x = np.arange(0, size, 1, float)
20 | y = x[:,np.newaxis]
21 |
22 | if center is None:
23 | x0 = y0 = size // 2
24 | else:
25 | x0 = center[0]
26 | y0 = center[1]
27 |
28 | return np.exp(-4*np.log(2) * ((x-x0)**2 + (y-y0)**2) / fwhm**2)
29 |
30 |
31 | class KernelGenerator(nn.Module):
32 | def __init__(self, size, offset=None):
33 | super(KernelGenerator, self).__init__()
34 |
35 | self.size = self._pair(size)
36 | xx, yy = np.meshgrid(np.arange(0, size), np.arange(0, size))
37 | if offset is None:
38 | offset_x = offset_y = size // 2
39 | else:
40 | offset_x, offset_y = self._pair(offset)
41 | self.factor = torch.from_numpy(-(np.power(xx - offset_x, 2) + np.power(yy - offset_y, 2)) / 2).float()
42 |
43 | @staticmethod
44 | def _pair(x):
45 | return (x, x) if isinstance(x, int) else x
46 |
47 | def forward(self, theta):
48 | pow2 = torch.pow(theta * self.size[0], 2)
49 | kernel = 1.0 / (2 * np.pi * pow2) * torch.exp(self.factor.to(theta.device) / pow2)
50 | return kernel / kernel.max()
51 |
52 |
53 | def kernel_generate(theta, size, offset=None):
54 | return KernelGenerator(size, offset)(theta)
55 |
56 |
57 | def _mean_filter(input):
58 | batch_size, num_channels, h, w = input.size()
59 | threshold = torch.mean(input.view(batch_size, num_channels, h * w), dim=2)
60 | return threshold.contiguous().view(batch_size, num_channels, 1, 1)
61 |
62 |
63 | class PeakStimulation(Function):
64 |
65 | @staticmethod
66 | def forward(ctx, input, return_aggregation, win_size, peak_filter):
67 | ctx.num_flags = 4
68 |
69 | assert win_size % 2 == 1, 'Window size for peak finding must be odd.'
70 | offset = (win_size - 1) // 2
71 | padding = torch.nn.ConstantPad2d(offset, float('-inf'))
72 | padded_maps = padding(input)
73 | batch_size, num_channels, h, w = padded_maps.size()
74 | element_map = torch.arange(0, h * w).long().view(1, 1, h, w)[:, :, offset: -offset, offset: -offset]
75 | element_map = element_map.to(input.device)
76 | _, indices = F.max_pool2d(
77 | padded_maps,
78 | kernel_size = win_size,
79 | stride = 1,
80 | return_indices = True)
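        # a location counts as a peak iff max-pooling over the win_size x win_size window
        # returns that location's own index, i.e. it is a local maximum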
81 | peak_map = (indices == element_map)
82 |
83 | if peak_filter:
84 | mask = input >= peak_filter(input)
85 | peak_map = (peak_map & mask)
86 | peak_list = torch.nonzero(peak_map)
87 | ctx.mark_non_differentiable(peak_list)
88 |
89 | if return_aggregation:
90 | peak_map = peak_map.float()
91 | ctx.save_for_backward(input, peak_map)
92 | return peak_list, (input * peak_map).view(batch_size, num_channels, -1).sum(2) / \
93 | peak_map.view(batch_size, num_channels, -1).sum(2)
94 | else:
95 | return peak_list
96 |
97 | @staticmethod
98 | def backward(ctx, grad_peak_list, grad_output):
99 | input, peak_map, = ctx.saved_tensors
100 | batch_size, num_channels, _, _ = input.size()
101 | grad_input = peak_map * grad_output.view(batch_size, num_channels, 1, 1)/ \
102 | (peak_map.view(batch_size, num_channels, -1).sum(2).view(batch_size, num_channels, 1, 1) + 1e-6)
103 | return (grad_input,) + (None,) * ctx.num_flags
104 |
105 |
106 | def peak_stimulation(input, return_aggregation=True, win_size=3, peak_filter=None):
107 | return PeakStimulation.apply(input, return_aggregation, win_size, peak_filter)
108 |
109 |
110 | class ScaleLayer(nn.Module):
111 |
112 | def __init__(self, init_value=1e-3):
113 | super().__init__()
114 | self.scale = nn.Parameter(torch.FloatTensor([init_value]))
115 |
116 | def forward(self, input):
117 | return input * self.scale
118 |
119 |
120 | class S3N(nn.Module):
121 |
122 | def __init__(self, base_model, num_classes, task_input_size, base_ratio, radius, radius_inv):
123 | super(S3N, self).__init__()
124 |
125 | self.grid_size = 31
126 | self.padding_size = 30
127 | self.global_size = self.grid_size + 2*self.padding_size
128 | self.input_size_net = task_input_size
129 | gaussian_weights = torch.FloatTensor(makeGaussian(2*self.padding_size+1, fwhm = 13))
130 | self.base_ratio = base_ratio
131 | self.radius = ScaleLayer(radius)
132 | self.radius_inv = ScaleLayer(radius_inv)
133 |
134 | self.filter = nn.Conv2d(1, 1, kernel_size=(2*self.padding_size+1,2*self.padding_size+1),bias=False)
135 | self.filter.weight[0].data[:,:,:] = gaussian_weights
136 |
137 | self.P_basis = torch.zeros(2,self.grid_size+2*self.padding_size, self.grid_size+2*self.padding_size)
138 | for k in range(2):
139 | for i in range(self.global_size):
140 | for j in range(self.global_size):
141 | self.P_basis[k,i,j] = k*(i-self.padding_size)/(self.grid_size-1.0)+(1.0-k)*(j-self.padding_size)/(self.grid_size-1.0)
142 |
143 | self.features = base_model.features
144 | self.num_features = base_model.num_features
145 |
146 | self.raw_classifier = nn.Linear(2048, num_classes)
147 | self.sampler_buffer = nn.Sequential(nn.Conv2d(2048, 2048, kernel_size=3, stride=2, padding=1, bias=False),
148 | nn.BatchNorm2d(2048),
149 | nn.ReLU(),
150 | )
151 | self.sampler_classifier = nn.Linear(2048, num_classes)
152 |
153 | self.sampler_buffer1 = nn.Sequential(nn.Conv2d(2048, 2048, kernel_size=3, stride=2, padding=1, bias=False),
154 | nn.BatchNorm2d(2048),
155 | nn.ReLU(),
156 | )
157 | self.sampler_classifier1 = nn.Linear(2048, num_classes)
158 |
159 | self.con_classifier = nn.Linear(int(self.num_features*3), num_classes)
160 |
161 | self.avg = nn.AdaptiveAvgPool2d(1)
162 | self.max_pool = nn.AdaptiveMaxPool2d(1)
163 |
164 | self.map_origin = nn.Conv2d(2048, num_classes, 1, 1, 0)
165 |
166 | def create_grid(self, x):
167 | P = torch.autograd.Variable(torch.zeros(1,2,self.grid_size+2*self.padding_size, self.grid_size+2*self.padding_size).cuda(),requires_grad=False)
168 | P[0,:,:,:] = self.P_basis
169 | P = P.expand(x.size(0),2,self.grid_size+2*self.padding_size, self.grid_size+2*self.padding_size)
170 |
171 | x_cat = torch.cat((x,x),1)
172 | p_filter = self.filter(x)
173 | x_mul = torch.mul(P,x_cat).view(-1,1,self.global_size,self.global_size)
174 | all_filter = self.filter(x_mul).view(-1,2,self.grid_size,self.grid_size)
175 |
176 | x_filter = all_filter[:,0,:,:].contiguous().view(-1,1,self.grid_size,self.grid_size)
177 | y_filter = all_filter[:,1,:,:].contiguous().view(-1,1,self.grid_size,self.grid_size)
178 |
179 | x_filter = x_filter/p_filter
180 | y_filter = y_filter/p_filter
181 |
182 | xgrids = x_filter*2-1
183 | ygrids = y_filter*2-1
184 | xgrids = torch.clamp(xgrids,min=-1,max=1)
185 | ygrids = torch.clamp(ygrids,min=-1,max=1)
186 |
187 | xgrids = xgrids.view(-1,1,self.grid_size,self.grid_size)
188 | ygrids = ygrids.view(-1,1,self.grid_size,self.grid_size)
189 |
190 | grid = torch.cat((xgrids,ygrids),1)
191 |
192 | grid = F.interpolate(grid, size=(self.input_size_net,self.input_size_net), mode='bilinear', align_corners=True)
193 |
194 | grid = torch.transpose(grid,1,2)
195 | grid = torch.transpose(grid,2,3)
196 |
197 | return grid
198 |
199 | def generate_map(self, input_x, class_response_maps, p):
200 | N, C, H, W = class_response_maps.size()
201 |
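        # confidence gate: the (negative) entropy of the top-5 softmax scores decides whether the
        # sampling map is built from the top-1 class response map alone or from the top-5 average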
202 | score_pred, sort_number = torch.sort(F.softmax(F.adaptive_avg_pool2d(class_response_maps, 1), dim=1), dim=1, descending=True)
203 | gate_score = (score_pred[:, 0:5]*torch.log(score_pred[:, 0:5])).sum(1)
204 |
205 | xs = []
206 | xs_inv = []
207 |
208 | for idx_i in range(N):
209 | if gate_score[idx_i] > -0.2:
210 | decide_map = class_response_maps[idx_i, sort_number[idx_i, 0],:,:]
211 | else:
212 | decide_map = class_response_maps[idx_i, sort_number[idx_i, 0:5],:,:].mean(0)
213 |
214 | min_value, max_value = decide_map.min(), decide_map.max()
215 | decide_map = (decide_map-min_value)/(max_value-min_value)
216 |
217 | peak_list, aggregation = peak_stimulation(decide_map, win_size=3, peak_filter=_mean_filter)
218 |
219 | decide_map = decide_map.squeeze(0).squeeze(0)
220 |
221 | score = [decide_map[item[2], item[3]] for item in peak_list]
222 | x = [item[3] for item in peak_list]
223 | y = [item[2] for item in peak_list]
224 |
225 | if score == []:
226 | temp = torch.zeros(1, 1, self.grid_size,self.grid_size).cuda()
227 | temp += self.base_ratio
228 | xs.append(temp)
229 |                 xs_inv.append(temp)
230 | continue
231 |
232 | peak_num = torch.arange(len(score))
233 |
234 | temp = self.base_ratio
235 | temp_w = self.base_ratio
236 |
237 | if p == 0:
238 | for i in peak_num:
239 | temp += score[i] * kernel_generate(self.radius(torch.sqrt(score[i])), H, (x[i].item(), y[i].item())).unsqueeze(0).unsqueeze(0).cuda()
240 | temp_w += 1/score[i] * \
241 | kernel_generate(self.radius_inv(torch.sqrt(score[i])), H, (x[i].item(), y[i].item())).unsqueeze(0).unsqueeze(0).cuda()
242 | elif p == 1:
243 | for i in peak_num:
244 | rd = random.uniform(0, 1)
245 | if score[i] > rd:
246 | temp += score[i] * kernel_generate(self.radius(torch.sqrt(score[i])), H, (x[i].item(), y[i].item())).unsqueeze(0).unsqueeze(0).cuda()
247 | else:
248 | temp_w += 1/score[i] * \
249 | kernel_generate(self.radius_inv(torch.sqrt(score[i])), H, (x[i].item(), y[i].item())).unsqueeze(0).unsqueeze(0).cuda()
250 | elif p == 2:
251 | index = score.index(max(score))
252 | temp += score[index] * kernel_generate(self.radius(score[index]), H, (x[index].item(), y[index].item())).unsqueeze(0).unsqueeze(0).cuda()
253 |
254 | index = score.index(min(score))
255 | temp_w += 1/score[index] * \
256 | kernel_generate(self.radius_inv(torch.sqrt(score[index])), H, (x[index].item(), y[index].item())).unsqueeze(0).unsqueeze(0).cuda()
257 |
258 | if type(temp) == float:
259 | temp += torch.zeros(1, 1, self.grid_size,self.grid_size).cuda()
260 | xs.append(temp)
261 |
262 | if type(temp_w) == float:
263 | temp_w += torch.zeros(1, 1, self.grid_size,self.grid_size).cuda()
264 | xs_inv.append(temp_w)
265 |
266 | xs = torch.cat(xs, 0)
267 | xs_hm = nn.ReplicationPad2d(self.padding_size)(xs)
268 | grid = self.create_grid(xs_hm).to(input_x.device)
269 | x_sampled_zoom = F.grid_sample(input_x, grid)
270 |
271 | xs_inv = torch.cat(xs_inv, 0)
272 | xs_hm_inv = nn.ReplicationPad2d(self.padding_size)(xs_inv)
273 | grid_inv = self.create_grid(xs_hm_inv).to(input_x.device)
274 | x_sampled_inv = F.grid_sample(input_x, grid_inv)
275 |
276 | return x_sampled_zoom, x_sampled_inv
277 |
278 | def forward(self, input_x, p):
279 |
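        # keep the 1x1 conv 'map_origin' in sync with the linear 'raw_classifier' so that it
        # produces per-class response maps over the backbone feature map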
280 | self.map_origin.weight.data.copy_(self.raw_classifier.weight.data.unsqueeze(-1).unsqueeze(-1))
281 | self.map_origin.bias.data.copy_(self.raw_classifier.bias.data)
282 |
283 | feature_raw = self.features(input_x)
284 | agg_origin = self.raw_classifier(self.avg(feature_raw).view(-1, 2048))
285 |
286 | with torch.no_grad():
287 | class_response_maps = F.interpolate(self.map_origin(feature_raw), size=self.grid_size, mode='bilinear', align_corners=True)
288 | x_sampled_zoom, x_sampled_inv = self.generate_map(input_x, class_response_maps, p)
289 |
290 | feature_D = self.sampler_buffer(self.features(x_sampled_zoom))
291 | agg_sampler = self.sampler_classifier(self.avg(feature_D).view(-1, 2048))
292 |
293 | feature_C = self.sampler_buffer1(self.features(x_sampled_inv))
294 | agg_sampler1 = self.sampler_classifier1(self.avg(feature_C).view(-1, 2048))
295 |
296 | aggregation = self.con_classifier(torch.cat([self.avg(feature_raw).view(-1, 2048), self.avg(feature_D).view(-1, 2048), self.avg(feature_C).view(-1, 2048)], 1))
297 |
298 | return aggregation, agg_origin, agg_sampler, agg_sampler1
299 |
300 |
301 | @register
302 | def s3n(
303 | mode: str ='resnet50',
304 | num_classes: int = 200,
305 | task_input_size: int = 448,
306 | base_ratio: float = 0.09,
307 | radius: float = 0.08,
308 | radius_inv: float = 0.2) -> nn.Module:
309 | """ Selective sparse sampling.
310 | """
311 |
312 | classify_network = modules.ft_resnet(mode=mode, fc_or_fcn = 'fc',num_classes=num_classes)
313 | model = S3N(classify_network, num_classes, task_input_size, base_ratio, radius, radius_inv)
314 |
315 | return model
316 |
317 |
318 | @register
319 | def three_stage(
320 | ctx: Context,
321 | train_ctx: Context) -> None:
322 | """Three stage.
323 | """
324 |
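    # p selects the sampling strategy inside S3N.generate_map:
    #   0 - every detected peak contributes to both the zoom and the inverse sampling maps
    #   1 - peaks are assigned stochastically to either the zoom or the inverse map
    #   2 - only the strongest peak (zoom) and the weakest peak (inverse) are used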
325 | if train_ctx.is_train:
326 | p = 0 if train_ctx.epoch_idx <= 20 else 1
327 | else:
328 | p = 1 if train_ctx.epoch_idx <= 20 else 2
329 |
330 | train_ctx.output = train_ctx.model(train_ctx.input, p)
331 |
332 | raise train_ctx.Skip
333 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 | import logging
4 | from contextlib import contextmanager
5 | from typing import Any, Iterable, Union, List, Tuple, Dict, Callable, Optional
6 |
7 | import torch
8 | from torch import Tensor, nn, optim
9 | from torch.utils import data
10 | from tqdm import tqdm, tqdm_notebook
11 | from nest import register, Context
12 |
13 |
14 | class TqdmHandler(logging.StreamHandler):
15 | def __init__(self):
16 | logging.StreamHandler.__init__(self)
17 |
18 | def emit(self, record):
19 | msg = self.format(record)
20 | tqdm.write(msg)
21 |
22 |
23 | @register
24 | def network_trainer(
25 | data_loaders: Tuple[List[Tuple[str, data.DataLoader]], List[Tuple[str, data.DataLoader]]],
26 | model: nn.Module,
27 | criterion: object,
28 | optimizer: Callable[[Iterable], optim.Optimizer],
29 | parameter: Optional[Callable] = None,
30 | meters: Optional[Dict[str, Callable[[Context], Any]]] = None,
31 | hooks: Optional[Dict[str, List[Callable[[Context], None]]]] = None,
32 | max_epoch: int = 200,
33 | test_interval: int = 1,
34 | resume: Optional[str] = None,
35 | log_path: Optional[str] = None,
36 | device: str = 'cuda',
37 | use_data_parallel: bool = True,
38 | use_cudnn_benchmark: bool = True,
39 | random_seed: int = 999) -> Context:
40 | """Network trainer.
41 | """
42 |
43 | torch.manual_seed(random_seed)
44 |
45 | logger = logging.getLogger('nest.network_trainer')
46 | logger.handlers = []
47 | logger.setLevel(logging.DEBUG)
48 |
49 | screen_handler = TqdmHandler()
50 | screen_handler.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
51 | logger.addHandler(screen_handler)
52 |
53 | if not log_path is None:
54 |
55 | try:
56 | os.makedirs(os.path.dirname(log_path))
57 | except OSError as exception:
58 | if exception.errno != errno.EEXIST:
59 | raise
60 | file_handler = logging.FileHandler(log_path, encoding='utf8')
61 | file_handler.setFormatter(logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s'))
62 | logger.addHandler(file_handler)
63 |
64 | def run_in_notebook():
65 | try:
66 | return get_ipython().__class__.__name__.startswith('ZMQ')
67 | except NameError:
68 | pass
69 | return False
70 | progress_bar = tqdm_notebook if run_in_notebook() else tqdm
71 |
72 | device = torch.device(device)
73 | if device.type == 'cuda':
74 | assert torch.cuda.is_available(), 'CUDA is not available.'
75 | torch.backends.cudnn.benchmark = use_cudnn_benchmark
76 |
77 | train_loaders, test_loaders = data_loaders
78 |
79 | model = model.to(device)
80 |
81 | if device.type == 'cuda' and use_data_parallel:
82 | model = nn.DataParallel(model)
83 |
84 | params = model.parameters() if parameter is None else parameter(model)
85 | optimizer = optimizer(params)
86 |
87 | start_epoch_idx = 0
88 | start_batch_idx = 0
89 | if not resume is None:
90 | logger.info('loading checkpoint "%s"' % resume)
91 | checkpoint = torch.load(resume)
92 | start_epoch_idx = checkpoint['epoch_idx']
93 | start_batch_idx = checkpoint['batch_idx']
94 | model.load_state_dict(checkpoint['model'])
95 | optimizer.load_state_dict(checkpoint['optimizer'])
96 | logger.info('checkpoint loaded (epoch %d)' % start_epoch_idx)
97 |
98 | ctx = Context(
99 | split = 'train',
100 | is_train = True,
101 | model = model,
102 | optimizer = optimizer,
103 | max_epoch = max_epoch,
104 | epoch_idx = start_epoch_idx,
105 | batch_idx = start_batch_idx,
106 | input = Tensor(),
107 | output = Tensor(),
108 | target = Tensor(),
109 | loss = Tensor(),
110 | metrics = dict(),
111 | state_dicts = [],
112 | logger = logger)
113 |
114 | class Skip(Exception): pass
115 | ctx.Skip = Skip
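    # hooks may raise ctx.Skip to leave the current 'batch' or 'forward' session early;
    # e.g. s3n.three_stage runs the forward pass itself (passing the stage index) and then
    # skips the default ctx.output = ctx.model(ctx.input) call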
116 |
117 | @contextmanager
118 | def skip():
119 | try:
120 | yield
121 | except Skip:
122 | pass
123 |
124 | def run_hooks(hook_type):
125 | if isinstance(hooks, dict) and hook_type in hooks:
126 | for hook in hooks.get(hook_type):
127 | hook(ctx)
128 |
129 | @contextmanager
130 | def session(name):
131 | run_hooks('on_start_' + name)
132 | yield
133 | run_hooks('on_end_' + name)
134 |
135 | def process(split, data_loader, is_train):
136 | ctx.max_batch = len(data_loader)
137 | ctx.split = split
138 | ctx.is_train = is_train
139 |
140 | run_hooks('on_start_split')
141 |
142 | if is_train:
143 | model.train()
144 | else:
145 | model.eval()
146 |
147 | for batch_idx, (input, target) in enumerate(progress_bar(data_loader, ascii=True, desc=split, unit='batch', leave=False)):
148 | if batch_idx < ctx.batch_idx:
149 | continue
150 |
151 | ctx.batch_idx = batch_idx
152 | if isinstance(input, (list, tuple)):
153 | ctx.input = [v.to(device) if torch.is_tensor(v) else v for v in input]
154 | elif isinstance(input, dict):
155 | ctx.input = {k: v.to(device) if torch.is_tensor(v) else v for k, v in input.items()}
156 | else:
157 | ctx.input = input.to(device)
158 | ctx.target = target.to(device)
159 |
160 | run_hooks('on_start_batch')
161 |
162 | with skip(), session('batch'):
163 | with torch.set_grad_enabled(ctx.is_train):
164 | with skip(), session('forward'):
165 | ctx.output = ctx.model(ctx.input)
166 | ctx.loss = criterion(ctx.output, ctx.target)
167 |
168 | if not meters is None:
169 | ctx.metrics.update({split + '_' + k: v(ctx) for k, v in meters.items() if v is not None})
170 |
171 | if is_train:
172 | optimizer.zero_grad()
173 | ctx.loss.backward()
174 | optimizer.step()
175 |
176 | run_hooks('on_end_batch')
177 | ctx.batch_idx = 0
178 |
179 | run_hooks('on_end_split')
180 |
181 | run_hooks('on_start')
182 |
183 | for epoch_idx in progress_bar(range(ctx.epoch_idx, max_epoch), ascii=True, unit='epoch'):
184 | ctx.epoch_idx = epoch_idx
185 | run_hooks('on_start_epoch')
186 |
187 | for split, loader in train_loaders:
188 | process(split, loader, True)
189 |
190 | if epoch_idx % test_interval == 0:
191 | for split, loader in test_loaders:
192 | process(split, loader, False)
193 |
194 | run_hooks('on_end_epoch')
195 |
196 | run_hooks('on_end')
197 |
198 | return ctx
199 |
--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Union, Tuple, Dict
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 | from nest import register, Context
8 |
9 |
10 | class AverageMeter(object):
11 |
12 | def __init__(self):
13 | self.reset()
14 |
15 | def reset(self):
16 | self.val = 0
17 | self.avg = 0
18 | self.sum = 0
19 | self.count = 0
20 |
21 | def update(self, val, n=1):
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 |
27 |
28 | @register
29 | def multi_topk_meter(
30 | ctx: Context,
31 | train_ctx: Context,
32 | k: int=1,
33 | init_num: int=1,
34 | end_num: int = 0) -> dict:
35 | """Multi topk meter.
36 | """
37 |
38 | def accuracy(output, target, k=1):
39 | batch_size = target.size(0)
40 |
41 | _, pred = output.topk(k, 1, True, True)
42 | pred = pred.t()
43 | correct = pred.eq(target.view(1, -1).expand_as(pred))
44 |
45 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
46 | return correct_k.mul_(100.0 / batch_size)
47 |
48 | for i in range(init_num, len(train_ctx.output) - end_num):
49 | if not "branch_"+str(i) in ctx:
50 | setattr(ctx, "branch_"+str(i), AverageMeter())
51 |
52 | if train_ctx.batch_idx == 0:
53 | for i in range(init_num, len(train_ctx.output) - end_num):
54 | getattr(ctx, "branch_"+str(i)).reset()
55 |
56 | for i in range(init_num, len(train_ctx.output) - end_num):
57 | acc = accuracy(train_ctx.output[i], train_ctx.target, k)
58 | getattr(ctx, "branch_"+str(i)).update(acc.item())
59 |
60 | acc_list = {}
61 |
62 | for i in range(init_num, len(train_ctx.output) - end_num):
63 | acc_list["branch_"+str(i)] = getattr(ctx, "branch_"+str(i)).avg
64 |
65 | return acc_list
66 |
67 |
68 | @register
69 | def best_meter(ctx: Context, train_ctx: Context, best_branch: int = 1, k: int = 1) -> float:
70 | """Best meter.
71 | """
72 | def accuracy(output, target, k=1):
73 | batch_size = target.size(0)
74 |
75 | _, pred = output.topk(k, 1, True, True)
76 | pred = pred.t()
77 | correct = pred.eq(target.view(1, -1).expand_as(pred))
78 |
79 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
80 | return correct_k.mul_(100.0 / batch_size)
81 |
82 | if not 'meter' in ctx:
83 | ctx.meter = AverageMeter()
84 |
85 | if train_ctx.batch_idx == 0:
86 | ctx.meter.reset()
87 | acc = accuracy(train_ctx.output[best_branch], train_ctx.target, k)
88 | ctx.meter.update(acc.item())
89 | return ctx.meter.avg
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from fnmatch import fnmatch
3 | from typing import Tuple, List, Union, Dict, Iterable
4 |
5 | import numpy as np
6 | import torch.nn as nn
7 | from nest import register
8 |
9 |
10 | @register
11 | def finetune(
12 | model: nn.Module,
13 | base_lr: float,
14 | groups: Dict[str, float],
15 | ignore_the_rest: bool = False,
16 | raw_query: bool = False) -> List[Dict[str, Union[float, Iterable]]]:
17 | """Fintune.
18 | """
19 |
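    # 'groups' maps a parameter-name pattern to a learning-rate multiplier: matched parameters
    # get lr = multiplier * base_lr, the rest fall back to base_lr unless ignore_the_rest is set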
20 | parameters = [dict(params=[], names=[], query=query if raw_query else '*'+query+'*', lr=lr*base_lr) for query, lr in groups.items()]
21 | rest_parameters = dict(params=[], names=[], lr=base_lr)
22 | for k, v in model.named_parameters():
23 | matched = False
24 | for group in parameters:
25 | if fnmatch(k, group['query']):
26 | group['params'].append(v)
27 | group['names'].append(k)
28 | matched = True
29 | break
30 | if not matched:
31 | rest_parameters['params'].append(v)
32 | rest_parameters['names'].append(k)
33 | if not ignore_the_rest:
34 | parameters.append(rest_parameters)
35 | for group in parameters:
36 | group['params'] = iter(group['params'])
37 | return parameters
38 |
39 |
40 | @register(acknowledgement='RLE modules are based on Sam Stainsby\'s implementation (https://www.kaggle.com/stainsby/fast-tested-rle)')
41 | def rle_encode(mask: np.ndarray) -> dict:
42 | """Run-Length Encoding (RLE).
43 | """
44 |
45 | assert mask.dtype == bool and mask.ndim == 2, 'RLE encoding requires a binary mask (dtype=bool).'
46 | pixels = mask.flatten()
47 | pixels = np.concatenate([[0], pixels, [0]])
48 | runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
49 | runs[1::2] -= runs[::2]
50 | return dict(data=base64.b64encode(runs.astype(np.uint32).tobytes()).decode('utf-8'), shape=mask.shape)
51 |
52 |
53 | @register
54 | def rle_decode(rle: dict) -> np.ndarray:
55 | """Run-Length Encoding.
56 | """
57 |
58 | runs = np.frombuffer(base64.b64decode(rle['data']), np.uint32)
59 | shape = rle['shape']
60 | starts, lengths = [np.asarray(x, dtype=int) for x in (runs[0:][::2], runs[1:][::2])]
61 | starts -= 1
62 | ends = starts + lengths
63 | img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
64 | for lo, hi in zip(starts, ends):
65 | img[lo:hi] = 1
66 | return img.reshape(shape)
67 |
--------------------------------------------------------------------------------