├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── .gitkeep
├── docker
│   └── Dockerfile
├── experiments
│   └── exp0.py
├── notebooks
│   └── .gitkeep
├── src
│   ├── data_utils.py
│   ├── debug.py
│   ├── losses.py
│   ├── lr_scheduler.py
│   ├── metrics.py
│   ├── mlsnet.py
│   ├── models.py
│   └── utils.py
├── submit
│   └── .gitkeep
└── tests
    ├── .gitkeep
    └── test_utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __init__.py
2 | *.pyc
3 | .idea/
4 | 
5 | submit/*
6 | data/*
7 | !.gitkeep
8 | 
9 | # files
10 | *.csv
11 | *.zip
12 | *.jpg
13 | *.png
14 | *.hdf5
15 | *.h5
16 | *.log
17 | 
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 | 
23 | # C extensions
24 | *.so
25 | 
26 | # Distribution / packaging
27 | .Python
28 | env/
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | *.egg-info/
42 | .installed.cfg
43 | *.egg
44 | 
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 | 
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 | 
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | .hypothesis/
65 | 
66 | # Translations
67 | *.mo
68 | *.pot
69 | 
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | 
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 | 
78 | # Scrapy stuff:
79 | .scrapy
80 | 
81 | # Sphinx documentation
82 | docs/_build/
83 | 
84 | # PyBuilder
85 | target/
86 | 
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 | 
90 | # pyenv
91 | .python-version
92 | 
93 | # celery beat schedule file
94 | celerybeat-schedule
95 | 
96 | # SageMath parsed files
97 | *.sage.py
98 | 
99 | # dotenv
100 | .env
101 | 
102 | # virtualenv
103 | .venv
104 | venv/
105 | ENV/
106 | 
107 | # Spyder project settings
108 | .spyderproject
109 | .spyproject
110 | 
111 | # Rope project settings
112 | .ropeproject
113 | 
114 | # mkdocs documentation
115 | /site
116 | 
117 | # mypy
118 | .mypy_cache/
119 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 lyakaap
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-template
2 | 
3 | This repository is my PyTorch project template (for Kaggle and research).
4 | Currently, this repository supports:
5 | * some loss functions for semantic segmentation (src/losses.py)
6 | * a mean IoU CUDA implementation (src/metrics.py)
7 | * a basic UNet implementation (src/models.py)
8 | * learning rate schedulers (src/lr_scheduler.py)
9 | * a useful debugging module (src/debug.py)
10 | 
11 | ## Environments
12 | This repository supports PyTorch >= 1.0. You can set up the environment using docker/Dockerfile (Ubuntu 16.04, CUDA 9.2).
13 | 
14 | ## Run Experiments
15 | Set the working directory to "experiments" and run the commands below:
16 | ```
17 | # run training and evaluation (with checkpoint saving)
18 | python exp0.py job --devices 0,1
19 | 
20 | # Run a grid search for a better hyperparameter set.
21 | # The parameter space is defined in "exp0.py".
22 | # In the setting below, each trial runs on a single GPU device,
23 | # so the tuning trials are launched on multiple GPU devices in parallel.
24 | python exp0.py tuning --devices 0,1 --n-gpu 1 --mode 'grid'
25 | ```
26 | 
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyakaap/pytorch-template/eff9f0a4dd50fa49c3b949065247598d5eabc91e/data/.gitkeep
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu16.04
2 | 
3 | RUN apt-get update && apt-get install -y --no-install-recommends \
4 |     build-essential \
5 |     cmake \
6 |     git \
7 |     unzip \
8 |     curl \
9 |     wget \
10 |     vim \
11 |     tmux \
12 |     htop \
13 |     less \
14 |     locate \
15 |     ca-certificates \
16 |     libsm6 \
17 |     libxext6 \
18 |     libxrender1 &&\
19 |     rm -rf /var/lib/apt/lists/*
20 | 
21 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-4.5.4-Linux-x86_64.sh && \
22 |     chmod +x ~/miniconda.sh && \
23 |     ~/miniconda.sh -b -p /opt/conda && \
24 |     rm ~/miniconda.sh
25 | 
26 | ENV PATH=/opt/conda/bin:$PATH \
27 |     LC_ALL=C.UTF-8 \
28 |     LANG=C.UTF-8
29 | 
30 | RUN pip install -U \
31 |     tqdm \
32 |     click \
33 |     logzero \
34 |     gensim \
35 |     optuna \
36 |     tensorboardX \
37 |     scikit-image \
38 |     lockfile \
39 |     pytest \
40 |     Cython \
41 |     pyyaml \
42 |     jupyter \
43 |     jupyterthemes \
44 |     kaggle \
45 |     opencv-python \
46 |     joblib \
47 |     seaborn \
48 |     pretrainedmodels \
49 |     plotly \
50 |     albumentations \
51 |     line-profiler \
52 |     tabulate \
53 |     cloudpickle==0.5.6  # to suppress warning
54 | 
55 | RUN conda install -y pytorch torchvision cuda92 -c pytorch && \
56 |     conda install -y pandas scikit-learn matplotlib pytables tensorflow-gpu keras && \
57 |     conda install -c conda-forge jupyter_contrib_nbextensions
58 | RUN conda clean --all
59 | 
60 | RUN jupyter contrib nbextension install --user
61 | RUN jt -t grade3 -f firacode -nf firacode -altp -fs 100 -tfs 100 -nfs 100 -dfs 100 -ofs 100 -cellw 88% -T
--------------------------------------------------------------------------------
/experiments/exp0.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import subprocess
4 | from pathlib import Path
5 | 
6 | import click
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | from sklearn.model_selection import train_test_split
13 | from torch.utils.data import DataLoader
14 | from torchvision import transforms
15 | from tqdm import tqdm
16 | 
17 | from src import utils, data_utils, models
18 | 
19 | ROOT = '/opt/airbus-ship-detection/'
20 | 
21 | params = {
22 |     'ex_name': __file__.replace('.py', ''),
23 |     'seed': 123456789,
24 |     'lr': 1e-4,
25 |     'batch_size': 8,
26 |     'test_batch_size': 8,
27 |     'optimizer': 'momentum',
28 |     'epochs': 10,
29 |     'workers': 8,
30 |     'dropout': 0.3,
31 |     'wd': 1e-5,
32 | }
33 | 
34 | 
35 | @click.group()
36 | def cli():
37 |     if not Path(ROOT + f'experiments/{params["ex_name"]}/train').exists():
38 |         Path(ROOT + f'experiments/{params["ex_name"]}/train').mkdir(parents=True)
39 |     if not Path(ROOT + f'experiments/{params["ex_name"]}/tuning').exists():
40 |         Path(ROOT + f'experiments/{params["ex_name"]}/tuning').mkdir(parents=True)
41 | 
42 |     np.random.seed(params['seed'])
43 |     torch.manual_seed(params['seed'])
44 |     torch.cuda.manual_seed_all(params['seed'])
45 |     torch.backends.cudnn.benchmark = True
46 | 
47 | 
48 | @cli.command()
49 | @click.option('--tuning', is_flag=True)
50 | @click.option('--params-path', type=click.Path(), default=None, help='json file path for setting parameters')
51 | @click.option('--devices', '-d', type=str, help='comma delimited gpu device list (e.g. "0,1")')
52 | @click.option('--resume', type=str, default=None, help='checkpoint path')
53 | def job(tuning, params_path, devices, resume):
54 |     """
55 |     Example:
56 |         python exp0.py job --devices 0,1
57 |     """
58 | 
59 |     global params
60 |     if tuning:
61 |         with open(params_path, 'r') as f:
62 |             params = json.load(f)
63 |         mode_str = 'tuning'
64 |         setting = '_'.join(f'{tp}-{params[tp]}' for tp in params['tuning_params'])
65 |     else:
66 |         mode_str = 'train'
67 |         setting = ''
68 | 
69 |     exp_path = ROOT + f'experiments/{params["ex_name"]}/'
70 |     os.environ['CUDA_VISIBLE_DEVICES'] = devices
71 | 
72 |     logger, writer = utils.get_logger(log_dir=exp_path + f'{mode_str}/log/{setting}',
73 |                                       tensorboard_dir=exp_path + f'{mode_str}/tf_board/{setting}')
74 | 
75 |     train_df = pd.read_csv(ROOT + 'data/train.csv')
76 |     train_df, val_df = train_test_split(train_df, test_size=1024, random_state=params['seed'])
77 | 
78 |     model = models.UNet(in_channels=3, n_classes=2, depth=4, ch_first=5, padding=True,
79 |                         batch_norm=False, up_mode='upconv').cuda()  # ch_first=5 -> 2 ** 5 = 32 first-layer filters (see models.UNet)
80 | 
81 |     optimizer = utils.get_optim(params, model)
82 | 
83 |     if resume is not None:
84 |         model, optimizer = utils.load_checkpoint(model, resume, optimizer=optimizer)
85 | 
86 |     if len(devices.split(',')) > 1:
87 |         model = nn.DataParallel(model)
88 | 
89 |     data_transforms = {
90 |         'train': transforms.Compose([
91 |             transforms.ToPILImage(),
92 |             transforms.RandomHorizontalFlip(),
93 |             transforms.ToTensor(),
94 |             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
95 |         ]),
96 |         'val': transforms.Compose([
97 |             transforms.ToPILImage(),
98 |             transforms.ToTensor(),
99 |             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
100 |         ]),
101 |     }
102 |     image_datasets = {'train': data_utils.CSVDataset(train_df, data_transforms['train']),
103 |                       'val': data_utils.CSVDataset(val_df, data_transforms['val'])}
104 |     data_loaders = {'train': DataLoader(image_datasets['train'],
105 |                                         batch_size=params['batch_size'], pin_memory=True,
106 |                                         shuffle=True, drop_last=True, num_workers=params['workers']),
107 |                     'val': DataLoader(image_datasets['val'], batch_size=params['test_batch_size'],
108 |                                       pin_memory=True, shuffle=False, num_workers=params['workers'])}
109 | 
110 |     criterion = nn.CrossEntropyLoss()
111 |     scheduler = optim.lr_scheduler.MultiStepLR(
112 |         optimizer, milestones=[int(params['epochs'] * 0.7), int(params['epochs'] * 0.9)], gamma=0.1)
113 | 
114 |     for epoch in range(params['epochs']):
115 |         logger.info(f'Epoch {epoch}/{params["epochs"]} | lr: {optimizer.param_groups[0]["lr"]}')
116 | 
117 |         # ============================== train ============================== #
118 |         model.train(True)
119 | 
120 |         losses = utils.AverageMeter()
121 |         prec1 = utils.AverageMeter()
122 | 
123 |         for i, (x, y) in tqdm(enumerate(data_loaders['train']),
124 |                               total=len(data_loaders['train']), miniters=50):
125 |             x = x.to('cuda:0')
126 |             y = y.to('cuda:0', non_blocking=True)
127 | 
128 |             outputs = model(x)
129 |             loss = criterion(outputs, y)
130 | 
131 |             optimizer.zero_grad()
132 |             loss.backward()
133 |             optimizer.step()
134 | 
135 |             acc = utils.accuracy(outputs, y)
136 |             losses.update(loss.item(), x.size(0))
137 |             prec1.update(acc.item(), x.size(0))
138 | 
139 |         train_loss = losses.avg
140 |         train_acc = prec1.avg
141 | 
142 |         # ============================== validation ============================== #
143 |         model.train(False)
144 |         losses.reset()
145 |         prec1.reset()
146 | 
147 |         for i, (x, y) in tqdm(enumerate(data_loaders['val']), total=len(data_loaders['val'])):
148 |             x = x.cuda()
149 |             y = y.cuda(non_blocking=True)
150 | 
151 |             with torch.no_grad():
152 |                 outputs = model(x)
153 |                 loss = criterion(outputs, y)
154 | 
155 |             acc = utils.accuracy(outputs, y)
156 |             losses.update(loss.item(), x.size(0))
157 |             prec1.update(acc.item(), x.size(0))
158 | 
159 |         val_loss = losses.avg
160 |         val_acc = prec1.avg
161 | 
162 |         logger.info(f'[Val] Loss: \033[1m{val_loss:.4f}\033[0m | '
163 |                     f'Acc: \033[1m{val_acc:.4f}\033[0m\n')
164 | 
165 |         writer.add_scalars('Loss', {'train': train_loss}, epoch)
166 |         writer.add_scalars('Acc', {'train': train_acc}, epoch)
167 |         writer.add_scalars('Loss', {'val': val_loss}, epoch)
168 |         writer.add_scalars('Acc', {'val': val_acc}, epoch)
169 |         writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)
170 | 
171 |         scheduler.step()
172 | 
173 |         if not tuning:
174 |             utils.save_checkpoint(model, epoch, exp_path + 'model_optim.pth', optimizer)
175 | 
176 |     if tuning:
177 |         tuning_result = {}
178 |         for key in ['train_loss', 'train_acc', 'val_loss', 'val_acc']:
179 |             tuning_result[key] = [eval(key)]
180 |         utils.write_tuning_result(params, tuning_result, exp_path + 'tuning/results.csv')
181 | 
182 | 
183 | 
184 | @cli.command()
185 | @click.option('--mode', type=str, default='grid', help='Search method (tuning)')
186 | @click.option('--n-iter', type=int, default=10, help='n of iteration for random parameter search (tuning)')
187 | @click.option('--n-gpu', type=int, default=-1, help='n of used gpu at once')
188 | @click.option('--devices', '-d', type=str, help='comma delimited gpu device list (e.g. "0,1")')
"0,1")') 189 | def tuning(mode, n_iter, n_gpu, devices): 190 | """ 191 | Example: 192 | python exp0.py tuning --devices 0,1 --n-gpu 1 193 | python exp0.py tuning --devices 0,1 --n-gpu 1 --mode 'random' --n-iter 4 194 | """ 195 | 196 | if n_gpu == -1: 197 | n_gpu = len(devices.split(',')) 198 | space = { 199 | 'lr': [1e-5, 1e-4, 1e-3], 200 | 'batch_size': [16, 8], 201 | } 202 | utils.launch_tuning(mode, n_iter, n_gpu, devices, params, space, ROOT) 203 | 204 | 205 | @cli.command() 206 | @click.option('--model-path', '-m', type=str) 207 | @click.option('--devices', '-d', type=str, help='comma delimited gpu device list (e.g. "0,1")') 208 | @click.option('--compression', '-c', is_flag=True) 209 | @click.option('--message', '-m', default=None, type=str) 210 | def predict(model_path, devices, compression, message): 211 | """WIP""" 212 | 213 | os.environ['CUDA_VISIBLE_DEVICES'] = devices 214 | 215 | test_img_paths = list(map(str, Path(ROOT + 'data/test/').glob('*.jpg'))) 216 | submission = pd.read_csv(ROOT + 'data/sample_submission.csv') 217 | 218 | model = models.UNet(in_channels=3, n_classes=2, depth=4, ch_first=32, padding=True, 219 | batch_norm=False, up_mode='upconv').cuda() 220 | model = utils.load_checkpoint(model, model_path) 221 | 222 | sub_path = ROOT + f'submit/{params["ex_name"]}.csv' 223 | if compression: 224 | sub_path += '.gz' 225 | submission.to_csv(sub_path, index=False, compression='gzip') 226 | else: 227 | submission.to_csv(sub_path, index=False) 228 | 229 | if message is None: 230 | message = params['ex_name'] 231 | 232 | cmd = f'kaggle c submit -c airbus-ship-detection -f {sub_path} -m "{message}"' 233 | subprocess.run(cmd, shell=True) 234 | 235 | 236 | if __name__ == '__main__': 237 | cli() 238 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lyakaap/pytorch-template/eff9f0a4dd50fa49c3b949065247598d5eabc91e/notebooks/.gitkeep -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from torch.utils.data import sampler 5 | 6 | 7 | class CSVDataset(Dataset): 8 | 9 | def __init__(self, df, transform): 10 | self.df = df 11 | self.transform = transform 12 | 13 | def __getitem__(self, index): 14 | row = self.df.iloc[index] 15 | img = cv2.imread(row['ImageID']) 16 | target = row['class'] 17 | if self.transform is not None: 18 | img = self.transform(img) 19 | return img, target 20 | 21 | def __len__(self): 22 | return len(self.df) 23 | 24 | 25 | class TestDataset(Dataset): 26 | 27 | def __init__(self, img_paths, transform): 28 | self.img_paths = img_paths 29 | self.transform = transform 30 | 31 | def __getitem__(self, index): 32 | img = cv2.imread(str(self.img_paths[index])) 33 | if self.transform is not None: 34 | img = self.transform(img) 35 | return img 36 | 37 | def __len__(self): 38 | return len(self.img_paths) 39 | 40 | 41 | class InfiniteSampler(sampler.Sampler): 42 | 43 | def __init__(self, num_samples): 44 | self.num_samples = num_samples 45 | 46 | def __iter__(self): 47 | while True: 48 | order = np.random.permutation(self.num_samples) 49 | for i in range(self.num_samples): 50 | yield order[i] 51 | 52 | def __len__(self): 53 | return None 54 | 
--------------------------------------------------------------------------------
/src/debug.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import traceback
3 | import pdb
4 | 
5 | """
6 | This module is for debugging without modifying scripts.
7 | 
8 | Just add `import debug` to the script you want to debug, and the pdb debugger
9 | starts automatically at the point where an exception is raised.
10 | After the debugger launches, `from IPython import embed; embed()` lets you run IPython.
11 | """
12 | 
13 | 
14 | def info(exctype, value, tb):
15 |     # we are in interactive mode or we don't have a tty-like
16 |     # device, so we call the default hook
17 |     if hasattr(sys, 'ps1') or not sys.stderr.isatty():
18 |         sys.__excepthook__(exctype, value, tb)
19 |     else:
20 |         traceback.print_exception(exctype, value, tb)
21 |         pdb.post_mortem(tb)
22 | 
23 | 
24 | sys.excepthook = info
25 | 
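A quick sketch of the intended usage, assuming the module is importable as `src.debug` and the script runs on a tty (otherwise the default hook is used, per `info()` above):

```
from src import debug  # noqa: F401  # imported only for its excepthook side effect

def buggy():
    return 1 / 0

buggy()  # ZeroDivisionError -> traceback is printed, then a pdb prompt opens
```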
--------------------------------------------------------------------------------
/src/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from itertools import filterfalse
6 | 
7 | 
8 | class MixLoss(nn.Module):
9 | 
10 |     def __init__(self, bce_w=1.0, dice_w=0.0, focal_w=0.0, lovasz_w=0.0,
11 |                  bce_kwargs={}, dice_kwargs={}, focal_kwargs={}, lovasz_kwargs={}):
12 |         super(MixLoss, self).__init__()
13 |         self.bce_w = bce_w
14 |         self.dice_w = dice_w
15 |         self.focal_w = focal_w
16 |         self.lovasz_w = lovasz_w
17 | 
18 |         self.bce_loss = nn.BCEWithLogitsLoss(**bce_kwargs)
19 |         self.dice_loss = DiceLoss(**dice_kwargs)
20 |         self.focal_loss = FocalLoss(**focal_kwargs)
21 |         self.lovasz_loss = LovaszHinge(**lovasz_kwargs)
22 | 
23 |     def forward(self, output, target):
24 |         loss = 0.0
25 | 
26 |         if self.bce_w:
27 |             loss += self.bce_w * self.bce_loss(output, target)
28 |         if self.dice_w:
29 |             loss += self.dice_w * self.dice_loss(output, target)
30 |         if self.focal_w:
31 |             loss += self.focal_w * self.focal_loss(output, target)
32 |         if self.lovasz_w:
33 |             loss += self.lovasz_w * self.lovasz_loss(output, target)
34 | 
35 |         return loss
36 | 
37 | 
38 | class DiceLoss(nn.Module):
39 |     def __init__(self, smooth=1.0, eps=1e-7):
40 |         super(DiceLoss, self).__init__()
41 |         self.smooth = smooth
42 |         self.eps = eps
43 | 
44 |     def forward(self, output, target):
45 |         output = torch.sigmoid(output)
46 | 
47 |         if torch.sum(target) == 0:
48 |             output = 1.0 - output
49 |             target = 1.0 - target
50 | 
51 |         return 1.0 - (2 * torch.sum(output * target) + self.smooth) / (
52 |                 torch.sum(output) + torch.sum(target) + self.smooth + self.eps)
53 | 
54 | 
55 | class SoftIoULoss(nn.Module):
56 |     def __init__(self, n_classes=19):
57 |         super(SoftIoULoss, self).__init__()
58 |         self.n_classes = n_classes
59 | 
60 |     @staticmethod
61 |     def to_one_hot(tensor, n_classes):
62 |         n, h, w = tensor.size()
63 |         one_hot = torch.zeros(n, n_classes, h, w, device=tensor.device).scatter_(1, tensor.view(n, 1, h, w), 1)  # allocate on the same device as the target
64 |         return one_hot
65 | 
66 |     def forward(self, logit, target):
67 |         # logit => N x Classes x H x W
68 |         # target => N x H x W
69 | 
70 |         N = len(logit)
71 | 
72 |         pred = F.softmax(logit, dim=1)
73 |         target_onehot = self.to_one_hot(target, self.n_classes)
74 | 
75 |         # Numerator Product
76 |         inter = pred * target_onehot
77 |         # Sum over all pixels N x C x H x W => N x C
78 |         inter = inter.view(N, self.n_classes, -1).sum(2)
79 | 
80 |         # Denominator
81 |         union = pred + target_onehot - (pred * target_onehot)
82 |         # Sum over all pixels N x C x H x W => N x C
83 |         union = union.view(N, self.n_classes, -1).sum(2)
84 | 
85 |         loss = inter / (union + 1e-16)
86 | 
87 |         # Return average loss over classes and batch
88 |         return -loss.mean()  # negative mean soft IoU; minimizing this maximizes IoU
89 | 
90 | 
91 | class FocalLoss(nn.Module):
92 | 
93 |     def __init__(self, gamma=2, eps=1e-7):
94 |         super(FocalLoss, self).__init__()
95 |         self.gamma = gamma
96 |         self.eps = eps
97 | 
98 |     def forward(self, logit, target):
99 |         prob = torch.sigmoid(logit)
100 |         prob = prob.clamp(self.eps, 1. - self.eps)
101 | 
102 |         loss = -1 * target * torch.log(prob)
103 |         loss = loss * (1 - prob) ** self.gamma  # modulating factor uses the probability, not the raw logit
104 | 
105 |         return loss.sum()
106 | 
107 | 
108 | def lovasz_grad(gt_sorted):
109 |     """
110 |     Computes gradient of the Lovasz extension w.r.t sorted errors
111 |     See Alg. 1 in paper
112 |     """
113 |     p = len(gt_sorted)
114 |     gts = gt_sorted.sum()
115 |     intersection = gts - gt_sorted.float().cumsum(0)
116 |     union = gts + (1 - gt_sorted).float().cumsum(0)
117 |     jaccard = 1. - intersection / union
118 |     if p > 1:  # cover 1-pixel case
119 |         jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
120 |     return jaccard
121 | 
122 | 
123 | def isnan(x):
124 |     return x != x
125 | 
126 | 
127 | def mean(l, ignore_nan=True, empty=0):
128 |     """
129 |     nanmean compatible with generators.
130 |     """
131 |     l = iter(l)
132 |     if ignore_nan:
133 |         l = filterfalse(isnan, l)
134 |     try:
135 |         n = 1
136 |         acc = next(l)
137 |     except StopIteration:
138 |         if empty == 'raise':
139 |             raise ValueError('Empty mean')
140 |         return empty
141 |     for n, v in enumerate(l, 2):
142 |         acc += v
143 |     if n == 1:
144 |         return acc
145 |     return acc / n
146 | 
147 | 
148 | def flatten_binary_scores(scores, labels, ignore=None):
149 |     """
150 |     Flattens predictions in the batch (binary case)
151 |     Remove labels equal to 'ignore'
152 |     """
153 |     scores = scores.view(-1)
154 |     labels = labels.view(-1)
155 |     if ignore is None:
156 |         return scores, labels
157 |     valid = (labels != ignore)
158 |     vscores = scores[valid]
159 |     vlabels = labels[valid]
160 |     return vscores, vlabels
161 | 
162 | 
163 | class LovaszHinge(nn.Module):
164 | 
165 |     def __init__(self, activation=lambda x: F.elu(x, inplace=True) + 1.0,
166 |                  per_image=True, ignore=None):
167 |         super(LovaszHinge, self).__init__()
168 |         self.activation = activation
169 |         self.per_image = per_image
170 |         self.ignore = ignore
171 | 
172 |     def lovasz_hinge_flat(self, logits, labels):
173 |         """
174 |         Binary Lovasz hinge loss
175 |         logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
176 |         labels: [P] Tensor, binary ground truth labels (0 or 1)
177 |         ignore: label to ignore
178 |         """
179 |         if len(labels) == 0:
180 |             # only void pixels, the gradients should be 0
181 |             return logits.sum() * 0.
182 |         signs = 2. * labels.float() - 1.
183 |         errors = (1. - logits * signs)
184 |         errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
185 |         perm = perm.data
186 |         gt_sorted = labels[perm]
187 |         grad = lovasz_grad(gt_sorted)
188 |         loss = torch.dot(self.activation(errors_sorted), grad)
189 |         return loss
190 | 
191 |     def forward(self, logits, labels):
192 |         if self.per_image:
193 |             loss = mean(self.lovasz_hinge_flat(
194 |                 *flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), self.ignore)
195 |             ) for log, lab in zip(logits, labels))
196 |         else:
197 |             loss = self.lovasz_hinge_flat(
198 |                 *flatten_binary_scores(logits, labels, self.ignore))
199 |         return loss
200 | 
201 | 
202 | def flatten_probas(probas, labels, ignore=None):
203 |     """
204 |     Flattens predictions in the batch
205 |     """
206 |     B, C, H, W = probas.size()
207 |     probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
208 |     labels = labels.view(-1)
209 |     if ignore is None:
210 |         return probas, labels
211 |     valid = (labels != ignore)
212 |     vprobas = probas[valid.nonzero().squeeze()]
213 |     vlabels = labels[valid]
214 |     return vprobas, vlabels
215 | 
216 | 
217 | def lovasz_softmax_flat(probas, labels, only_present=False):
218 |     """
219 |     Multi-class Lovasz-Softmax loss
220 |     probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
221 |     labels: [P] Tensor, ground truth labels (between 0 and C - 1)
222 |     only_present: average only on classes present in ground truth
223 |     """
224 |     if len(probas) == 0:
225 |         return np.nan
226 | 
227 |     C = probas.size(1)
228 |     losses = []
229 |     for c in range(C):
230 |         fg = (labels == c).float()  # foreground for class c
231 |         if only_present and fg.sum() == 0:
232 |             continue
233 | 
234 |         errors = (fg - probas[:, c]).abs()
235 |         errors_sorted, perm = torch.sort(errors, 0, descending=True)
236 |         perm = perm.data
237 |         fg_sorted = fg[perm]
238 |         losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
239 |     return mean(losses)
240 | 
241 | 
242 | class LovaszSoftmax(nn.Module):
243 |     """
244 |     Multi-class Lovasz-Softmax loss
245 |     logits: [B, C, H, W] class logits at each prediction
246 |     labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
247 |     only_present: average only on classes present in ground truth
248 |     per_image: compute the loss per image instead of per batch
249 |     ignore: void class labels
250 |     """
251 | 
252 |     def __init__(self, only_present=False, per_image=True, ignore=None):
253 |         super(LovaszSoftmax, self).__init__()
254 |         self.only_present = only_present
255 |         self.per_image = per_image
256 |         self.ignore = ignore
257 | 
258 |     def forward(self, logits, labels):
259 |         probas = F.softmax(logits, dim=1)
260 |         if self.per_image:
261 |             loss = mean(lovasz_softmax_flat(*flatten_probas(
262 |                 prob.unsqueeze(0), lab.unsqueeze(0), self.ignore), only_present=self.only_present)
263 |                         for prob, lab in zip(probas, labels))
264 |         else:
265 |             loss = lovasz_softmax_flat(*flatten_probas(
266 |                 probas, labels, self.ignore), only_present=self.only_present)
267 |         return loss
268 | 
269 | 
270 | # Adapted from OCNet Repository (https://github.com/PkuRainBow/OCNet)
271 | class OhemCrossEntropy2d(nn.Module):
272 |     def __init__(self, ignore_label=255, thresh=0.6, min_kept=0, use_weight=True):
273 |         super(OhemCrossEntropy2d, self).__init__()
274 |         self.ignore_label = ignore_label
275 |         self.thresh = float(thresh)
276 |         self.min_kept = int(min_kept)
277 |         if use_weight:
278 |             print("w/ class balance")
279 |             weight = torch.FloatTensor(
280 |                 [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116,
281 |                  0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
282 |             self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)
283 |         else:
284 |             print("w/o class balance")
285 |             self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
286 | 
287 |     def forward(self, predict, target, weight=None):
288 |         """
289 |         Args:
290 |             predict:(n, c, h, w)
291 |             target:(n, h, w)
292 |             weight (Tensor, optional): a manual rescaling weight given to each class.
293 |                 If given, has to be a Tensor of size "nclasses"
294 |         """
295 |         assert not target.requires_grad
296 |         assert predict.dim() == 4
297 |         assert target.dim() == 3
298 |         assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
299 |         assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
300 |         assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))
301 | 
302 |         n, c, h, w = predict.size()
303 |         input_label = target.data.cpu().numpy().ravel().astype(np.int32)
304 |         x = np.rollaxis(predict.data.cpu().numpy(), 1).reshape((c, -1))
305 |         input_prob = np.exp(x - x.max(axis=0).reshape((1, -1)))
306 |         input_prob /= input_prob.sum(axis=0).reshape((1, -1))
307 | 
308 |         valid_flag = input_label != self.ignore_label
309 |         valid_inds = np.where(valid_flag)[0]
310 |         label = input_label[valid_flag]
311 |         num_valid = valid_flag.sum()
312 |         if self.min_kept >= num_valid:
313 |             print('Labels: {}'.format(num_valid))
314 |         elif num_valid > 0:
315 |             prob = input_prob[:, valid_flag]
316 |             pred = prob[label, np.arange(len(label), dtype=np.int32)]
317 |             threshold = self.thresh
318 |             if self.min_kept > 0:
319 |                 index = pred.argsort()
320 |                 threshold_index = index[min(len(index), self.min_kept) - 1]
321 |                 if pred[threshold_index] > self.thresh:
322 |                     threshold = pred[threshold_index]
323 |             kept_flag = pred <= threshold
324 |             valid_inds = valid_inds[kept_flag]
325 |             print('hard ratio: {} = {} / {} '.format(round(len(valid_inds)/num_valid, 4), len(valid_inds), num_valid))
326 | 
327 |         label = input_label[valid_inds].copy()
328 |         input_label.fill(self.ignore_label)
329 |         input_label[valid_inds] = label
330 |         print(np.sum(input_label != self.ignore_label))
331 |         target = torch.from_numpy(input_label.reshape(target.size())).long().cuda()
332 | 
333 |         return self.criterion(predict, target)
334 | 
335 | 
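A minimal smoke test for `OhemCrossEntropy2d`. It assumes a CUDA device, since `forward()` moves the rebuilt target to the GPU; the shapes and `min_kept` value are arbitrary.

```
import torch

from src.losses import OhemCrossEntropy2d

criterion = OhemCrossEntropy2d(ignore_label=255, thresh=0.6, min_kept=1000).cuda()
logits = torch.randn(2, 19, 64, 64, device='cuda', requires_grad=True)  # (n, c, h, w)
target = torch.randint(0, 19, (2, 64, 64), device='cuda')               # (n, h, w)
loss = criterion(logits, target)  # only low-confidence ("hard") pixels are kept
loss.backward()
```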
336 | class CriterionCrossEntropy(nn.Module):
337 |     def __init__(self, ignore_index=255, weight='lightnet'):
338 |         super(CriterionCrossEntropy, self).__init__()
339 |         self.ignore_index = ignore_index
340 | 
341 |         if weight == 'lightnet':
342 |             # https://github.com/ansleliu/LightNet/blob/master/datasets/calculate_class_weight.py
343 |             self.weight = torch.FloatTensor(
344 |                 [0.05570516, 0.32337477, 0.08998544, 1.03602707, 1.03413147, 1.68195437,
345 |                  5.58540548, 3.56563995, 0.12704978, 1., 0.46783719, 1.34551528,
346 |                  5.29974114, 0.28342531, 0.9396095, 0.81551811, 0.42679146, 3.6399074,
347 |                  2.78376194])
348 |         else:
349 |             self.weight = torch.FloatTensor(
350 |                 [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116,
351 |                  0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
352 | 
353 |         self.criterion = torch.nn.CrossEntropyLoss(weight=self.weight, ignore_index=ignore_index)
354 | 
355 |     def forward(self, preds, target):
356 |         h, w = target.size(1), target.size(2)
357 |         scale_pred = F.interpolate(input=preds, size=(h, w), mode='bilinear', align_corners=True)
358 |         loss = self.criterion(scale_pred, target)
359 |         return loss
360 | 
361 | 
362 | class CriterionDSN(nn.Module):
363 | 
364 |     def __init__(self, ignore_index=255, use_weight=True, loss_balance_coefs=(0.4, 1.0)):
365 |         super(CriterionDSN, self).__init__()
366 |         self.ignore_index = ignore_index
367 |         self.loss_balance_coefs = loss_balance_coefs
368 |         weight = torch.FloatTensor(
369 |             [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116,
370 |              0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
371 |         if use_weight:
372 |             self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
373 |         else:
374 |             self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
375 | 
376 |     def forward(self, preds, target):
377 |         h, w = target.size(1), target.size(2)
378 | 
379 |         assert len(preds) == len(self.loss_balance_coefs)
380 | 
381 |         losses = []
382 |         for pred, coef in zip(preds, self.loss_balance_coefs):
383 |             scale_pred = F.interpolate(input=pred, size=(h, w), mode='bilinear', align_corners=True)
384 |             losses.append(self.criterion(scale_pred, target) * coef)
385 | 
386 |         return sum(losses)
387 | 
388 | 
389 | class CriterionOhemDSN(nn.Module):
390 |     """
391 |     DSN + OHEM: the model needs two supervision signals, one per prediction head.
392 |     """
393 | 
394 |     def __init__(self, ignore_index=255, thres=0.7, min_kept=100000, dsn_weight=0.4, use_weight=True):
395 |         super(CriterionOhemDSN, self).__init__()
396 |         self.ignore_index = ignore_index
397 |         self.dsn_weight = dsn_weight
398 |         self.criterion = OhemCrossEntropy2d(ignore_index, thres, min_kept, use_weight=use_weight)
399 | 
400 |     def forward(self, preds, target):
401 |         h, w = target.size(1), target.size(2)
402 |         scale_pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True)
403 |         loss1 = self.criterion(scale_pred, target)
404 |         scale_pred = F.interpolate(input=preds[1], size=(h, w), mode='bilinear', align_corners=True)
405 |         loss2 = self.criterion(scale_pred, target)
406 |         return self.dsn_weight * loss1 + loss2
407 | 
408 | 
409 | class CriterionOhemDSN_single(nn.Module):
410 |     """
411 |     DSN + OHEM: we find that using hard-mining for both supervision signals harms performance.
412 |     Thus we use the original loss for the shallow supervision
413 |     and the hard-mining loss for the deeper supervision.
414 |     """
415 | 
416 |     def __init__(self, ignore_index=255, thres=0.7, min_kept=100000, dsn_weight=0.4):
417 |         super(CriterionOhemDSN_single, self).__init__()
418 |         self.ignore_index = ignore_index
419 |         self.dsn_weight = dsn_weight
420 |         weight = torch.FloatTensor(
421 |             [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116,
422 |              0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
423 |         self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
424 |         self.criterion_ohem = OhemCrossEntropy2d(ignore_index, thres, min_kept, use_weight=True)
425 | 
426 |     def forward(self, preds, target):
427 |         h, w = target.size(1), target.size(2)
428 | 
429 |         scale_pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True)
430 |         loss1 = self.criterion(scale_pred, target)
431 | 
432 |         scale_pred = F.interpolate(input=preds[1], size=(h, w), mode='bilinear', align_corners=True)
433 |         loss2 = self.criterion_ohem(scale_pred, target)
434 |         return self.dsn_weight * loss1 + loss2
435 | 
--------------------------------------------------------------------------------
/src/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from bisect import bisect_right
2 | import math
3 | 
4 | import torch
5 | 
6 | 
7 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
8 |     """multi-step learning rate scheduler with warmup."""
9 | 
10 |     def __init__(
11 |         self,
12 |         optimizer,
13 |         milestones,
14 |         gamma=0.1,
15 |         warmup_factor=1.0 / 3,
16 |         warmup_iters=500,
17 |         warmup_method="linear",
18 |         last_epoch=-1,
19 |     ):
20 |         if not list(milestones) == sorted(milestones):
21 |             raise ValueError(
22 |                 "Milestones should be a list of increasing integers. "
23 |                 "Got {}".format(milestones)
24 |             )
25 | 
26 |         if warmup_method not in ("constant", "linear"):
27 |             raise ValueError(
28 |                 "Only 'constant' or 'linear' warmup_method accepted, "
29 |                 "got {}".format(warmup_method)
30 |             )
31 |         self.milestones = milestones
32 |         self.gamma = gamma
33 |         self.warmup_factor = warmup_factor
34 |         self.warmup_iters = warmup_iters
35 |         self.warmup_method = warmup_method
36 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
37 | 
38 |     def get_lr(self):
39 |         warmup_factor = 1
40 |         if self.last_epoch < self.warmup_iters:
41 |             if self.warmup_method == "constant":
42 |                 warmup_factor = self.warmup_factor
43 |             elif self.warmup_method == "linear":
44 |                 alpha = self.last_epoch / self.warmup_iters
45 |                 warmup_factor = self.warmup_factor * (1 - alpha) + alpha
46 |         return [
47 |             base_lr
48 |             * warmup_factor
49 |             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
50 |             for base_lr in self.base_lrs
51 |         ]
52 | 
53 | 
54 | class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
55 |     """cosine annealing scheduler with warmup.
56 | 
57 |     Args:
58 |         optimizer (Optimizer): Wrapped optimizer.
59 |         T_max (int): Maximum number of iterations.
60 |         eta_min (float): Minimum learning rate. Default: 0.
61 |         last_epoch (int): The index of last epoch. Default: -1.
62 | 
63 |     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
64 |         https://arxiv.org/abs/1608.03983
65 |     """
66 | 
67 |     def __init__(
68 |         self,
69 |         optimizer,
70 |         T_max,
71 |         eta_min,
72 |         warmup_factor=1.0 / 3,
73 |         warmup_iters=500,
74 |         warmup_method="linear",
75 |         last_epoch=-1,
76 |     ):
77 |         if warmup_method not in ("constant", "linear"):
78 |             raise ValueError(
79 |                 "Only 'constant' or 'linear' warmup_method accepted, "
80 |                 "got {}".format(warmup_method)
81 |             )
82 | 
83 |         self.T_max = T_max
84 |         self.eta_min = eta_min
85 |         self.warmup_factor = warmup_factor
86 |         self.warmup_iters = warmup_iters
87 |         self.warmup_method = warmup_method
88 |         super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
89 | 
90 |     def get_lr(self):
91 |         if self.last_epoch < self.warmup_iters:
92 |             return self.get_lr_warmup()
93 |         else:
94 |             return self.get_lr_cos_annealing()
95 | 
96 |     def get_lr_warmup(self):
97 |         if self.warmup_method == "constant":
98 |             warmup_factor = self.warmup_factor
99 |         elif self.warmup_method == "linear":
100 |             alpha = self.last_epoch / self.warmup_iters
101 |             warmup_factor = self.warmup_factor * (1 - alpha) + alpha
102 |         return [
103 |             base_lr * warmup_factor
104 |             for base_lr in self.base_lrs
105 |         ]
106 | 
107 |     def get_lr_cos_annealing(self):
108 |         last_epoch = self.last_epoch - self.warmup_iters
109 |         T_max = self.T_max - self.warmup_iters
110 |         return [self.eta_min + (base_lr - self.eta_min) *
111 |                 (1 + math.cos(math.pi * last_epoch / T_max)) / 2
112 |                 for base_lr in self.base_lrs]
113 | 
114 | 
115 | class PiecewiseCyclicalLinearLR(torch.optim.lr_scheduler._LRScheduler):
116 |     """Set the learning rate of each parameter group using a piecewise
117 |     cyclical linear schedule.
118 | 
119 |     When last_epoch=-1, sets initial lr as lr.
120 | 
121 |     Args:
122 |         c: cycle length
123 |         alpha1: lr upper bound of cycle
124 |         alpha2: lr lower bound of cycle
125 | 
126 |     _Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs
127 |         https://arxiv.org/pdf/1802.10026
128 |     _Exploring loss function topology with cyclical learning rates
129 |         https://arxiv.org/abs/1702.04283
130 |     """
131 | 
132 |     def __init__(self, optimizer, c, alpha1=1e-2, alpha2=5e-4, last_epoch=-1):
133 | 
134 |         self.c = c
135 |         self.alpha1 = alpha1
136 |         self.alpha2 = alpha2
137 |         super(PiecewiseCyclicalLinearLR, self).__init__(optimizer, last_epoch)
138 | 
139 |     def get_lr(self):
140 | 
141 |         lrs = []
142 |         for _ in range(len(self.base_lrs)):
143 |             ti = ((self.last_epoch - 1) % self.c + 1) / self.c
144 |             if 0 <= ti <= 0.5:
145 |                 lr = (1 - 2 * ti) * self.alpha1 + 2 * ti * self.alpha2
146 |             elif 0.5 < ti <= 1.0:
147 |                 lr = (2 - 2 * ti) * self.alpha2 + (2 * ti - 1) * self.alpha1
148 |             else:
149 |                 raise ValueError('t(i) is out of range [0,1].')
150 |             lrs.append(lr)
151 | 
152 |         return lrs
153 | 
154 | 
155 | class PolyLR(torch.optim.lr_scheduler._LRScheduler):
156 | 
157 |     def __init__(self, optimizer, power=0.9, max_epoch=4e4, last_epoch=-1):
158 |         self.power = power
159 |         self.max_epoch = max_epoch
160 |         self.last_epoch = last_epoch
161 |         super(PolyLR, self).__init__(optimizer, last_epoch)
162 | 
163 |     def get_lr(self):
164 |         lrs = []
165 |         for base_lr in self.base_lrs:
166 |             lr = base_lr * (1.0 - (self.last_epoch / self.max_epoch)) ** self.power
167 |             lrs.append(lr)
168 | 
169 |         return lrs
170 | 
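A short sketch of driving `WarmupMultiStepLR` per iteration rather than per epoch; the milestone and warmup values here are arbitrary, chosen only for illustration.

```
import torch

from src.lr_scheduler import WarmupMultiStepLR

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# lr ramps linearly from lr/3 to lr over the first 500 iterations,
# then decays by 10x at iterations 3000 and 4500
scheduler = WarmupMultiStepLR(optimizer, milestones=[3000, 4500],
                              warmup_factor=1.0 / 3, warmup_iters=500)

for it in range(5000):
    # ... forward / backward / optimizer.step() ...
    scheduler.step()
```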
--------------------------------------------------------------------------------
/src/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def accuracy(outputs: torch.Tensor, labels: torch.Tensor, ignore_index: int = None) -> float:
6 |     # Number of classes should be less than 255 (preds and labels are cast to uint8 below).
7 | 
8 |     if len(outputs.shape) == 4:
9 |         preds = outputs.argmax(dim=1)
10 |     elif len(outputs.shape) == 3:
11 |         preds = outputs
12 |     else:
13 |         raise ValueError
14 | 
15 |     preds = preds.byte().flatten()
16 |     labels = labels.byte().flatten()
17 | 
18 |     if ignore_index is not None:
19 |         is_not_ignore = labels != ignore_index
20 |         preds = preds[is_not_ignore]
21 |         labels = labels[is_not_ignore]
22 | 
23 |     correct = preds.eq(labels)
24 | 
25 |     acc = correct.float().mean().item()
26 | 
27 |     return acc
28 | 
29 | 
30 | def prec_at_k(output, target, top_k=(1,)):
31 |     """Computes the precision@k for the specified values of k"""
32 |     max_k = max(top_k)
33 |     batch_size = target.size(0)
34 | 
35 |     _, pred = output.topk(max_k, 1, True, True)
36 |     pred = pred.t()
37 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
38 | 
39 |     res = []
40 |     for k in top_k:
41 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the slice may be non-contiguous
42 |         res.append(correct_k.mul_(100.0 / batch_size))
43 | 
44 |     if len(res) == 1:
45 |         res = res[0]
46 | 
47 |     return res
48 | 
49 | 
50 | def intersection_and_union(preds: torch.Tensor, labels: torch.Tensor,
51 |                            ignore_index=255, n_classes=19):
52 | 
53 |     assert ignore_index > n_classes, 'ignore_index should be greater than n_classes'
54 | 
55 |     preds = preds.byte().flatten()
56 |     labels = labels.byte().flatten()
57 | 
58 |     is_not_ignore = labels != ignore_index
59 |     preds = preds[is_not_ignore]
60 |     labels = labels[is_not_ignore]
61 | 
62 |     intersection = preds[preds == labels]
63 |     area_intersection = intersection.bincount(minlength=n_classes)
64 | 
65 |     bincount_preds = preds.bincount(minlength=n_classes)
66 |     bincount_labels = labels.bincount(minlength=n_classes)
67 |     area_union = bincount_preds + bincount_labels - area_intersection
68 | 
69 |     area_intersection = area_intersection.float().cpu().numpy()
70 |     area_union = area_union.float().cpu().numpy()
71 | 
72 |     return area_intersection, area_union
73 | 
74 | 
75 | def mean_iou(outputs, labels, n_classes=19):
76 | 
77 |     preds = outputs.argmax(dim=1)
78 |     intersection, union = intersection_and_union(preds, labels, n_classes=n_classes)
79 | 
80 |     return np.mean(intersection / (union + 1e-16))
81 | 
82 | 
83 | def mean_iou_50_to_95(outputs: torch.Tensor, labels: torch.Tensor,
84 |                       thresh=None, eps=1e-7, reduce=True):
85 | 
86 |     if thresh is not None:
87 |         outputs = outputs > thresh
88 | 
89 |     outputs = outputs.squeeze(1)
90 |     labels = labels.squeeze(1).byte()
91 | 
92 |     intersection = (outputs & labels).sum(dim=[1, 2]).float()
93 |     union = (outputs | labels).sum(dim=[1, 2]).float()
94 | 
95 |     iou = (intersection + eps) / (union + eps)
96 |     thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10
97 | 
98 |     if reduce:
99 |         thresholded = thresholded.mean()
100 | 
101 |     return thresholded
102 | 
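A small example exercising `accuracy()` and `mean_iou()` on random tensors; the shapes follow the (B, C, H, W) logits / (B, H, W) labels convention used above.

```
import torch

from src.metrics import accuracy, mean_iou

outputs = torch.randn(4, 19, 32, 32)        # logits
labels = torch.randint(0, 19, (4, 32, 32))  # ground-truth class indices

acc = accuracy(outputs, labels, ignore_index=255)
miou = mean_iou(outputs, labels, n_classes=19)
print(f'acc={acc:.3f} mIoU={miou:.3f}')
```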
--------------------------------------------------------------------------------
/src/mlsnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | 
6 | 
7 | class MLSNet(nn.Module):
8 | 
9 |     def __init__(self,
10 |                  n_classes: int = 19,
11 |                  encoder_depth: int = 18,
12 |                  pretrained: bool = True,
13 |                  up_mode: str = 'bilinear',
14 |                  bn_module=nn.BatchNorm2d,
15 |                  dilations: tuple = None,
16 |                  stop_level: int = 3,
17 |                  multiple_oc_size: int = 4,
18 |                  n_oc: int = 3):
19 |         """
20 |         :param stop_level: level at which the pruned forward pass stops.
21 |             1: first level,
22 |             2: second level,
23 |             3: third level (return all level preds)
24 |         """
25 |         super(MLSNet, self).__init__()
26 | 
27 |         encoder = ResNetEncoder(encoder_depth=encoder_depth, pretrained=pretrained,
28 |                                 bn_module=bn_module, relu_inplace=True)
29 |         activation = nn.ReLU(inplace=True)
30 | 
31 |         self.encoder = encoder.encoder
32 |         self.channels_list = encoder.channels_list
33 |         self.depth = len(self.channels_list)
34 |         self.up_mode = up_mode
35 |         self.align_corners = False if self.up_mode == 'bilinear' else None
36 |         self.stop_level = stop_level
37 | 
38 |         if dilations is None:
39 |             dilations = [1] * self.depth
40 | 
41 |         self.decoder = nn.ModuleList([
42 |             nn.ModuleList([None for _ in range(i)])
43 |             for i in reversed(range(1, self.depth))
44 |         ])
45 | 
46 |         for i in reversed(range(self.depth - 1)):
47 |             for j in range(0, self.depth - i - 1):
48 |                 in_ch = self.channels_list[i] + self.channels_list[i + 1]
49 |                 out_ch = self.channels_list[i]
50 |                 self.decoder[i][j] = ConvLayer(in_ch, out_ch, kernel_size=3,
51 |                                                padding=dilations[j], dilation=dilations[j],
52 |                                                activation=activation, bn_module=bn_module)
53 | 
54 |         if n_oc > 0:
55 |             dilation_rates = [(6, 12, 24), (4, 8, 12), (2, 4, 6)]
56 |             for i in range(n_oc):
57 |                 self.decoder[i][0].add_module(f'oc_{i}',
58 |                                               ASP_OC_Module(self.channels_list[i], self.channels_list[i],
59 |                                                             size=2 ** (multiple_oc_size - i),
60 |                                                             dilations=dilation_rates[i], bn_module=bn_module))
61 | 
62 |         self.cls = nn.ModuleList(
63 |             nn.Conv2d(self.channels_list[0], n_classes, kernel_size=1)
64 |             for _ in range(self.depth - 1))
65 | 
66 |         self._init_weight()
67 | 
68 |     def _init_weight(self):
69 |         for m in self.decoder.modules():
70 |             if isinstance(m, nn.Conv2d):
71 |                 torch.nn.init.kaiming_normal_(m.weight)
72 |         for m in self.cls.modules():
73 |             if isinstance(m, nn.Conv2d):
74 |                 torch.nn.init.kaiming_normal_(m.weight)
75 | 
76 |     def forward(self, x, return_all_preds=True):
77 |         """
78 |         :param x: input tensor (shape: B, C, H, W)
79 |         :param return_all_preds:
80 |             True:
81 |                 All level predictions are returned. This is necessary for training.
82 |             False:
83 |                 Only one prediction is returned, according to self.stop_level.
84 |                 This avoids redundant computation and thus saves inference time.
85 |         :return: list of predictions from each level's classifier
86 |         """
87 | 
88 |         X = [[None for _ in range(i + 1)] for i in reversed(range(self.depth))]
89 |         preds = []
90 | 
91 |         for i in range(self.depth):
92 | 
93 |             # encoder part
94 |             if i == 0:
95 |                 X[0][0] = self.encoder[0](x)
96 |             else:
97 |                 X[i][0] = self.encoder[i](X[i - 1][0])
98 | 
99 |             # decoder part: compute every intermediate node, as in
100 |             # adaptive_inference below (guarding this with `if i - (j + 1) == 0`
101 |             # would leave X[i - j][j] unset for the steps that read it later)
102 |             for j in range(i):
103 |                 cat_feat = torch.cat([
104 |                     X[i - (j + 1)][j],
105 |                     F.interpolate(X[i - j][j], scale_factor=2, mode=self.up_mode,
106 |                                   align_corners=self.align_corners)
107 |                 ], dim=1)
108 |                 X[i - (j + 1)][j + 1] = self.decoder[i - (j + 1)][j](cat_feat)
109 | 
110 |             if i > 0:
111 |                 if return_all_preds or i == self.stop_level:
112 |                     preds.append(self.cls[i - 1](X[0][i]))
113 | 
114 |             if i == self.stop_level:
115 |                 break  # pruning
116 | 
117 |         return preds
118 | 
119 |     def adaptive_inference(self, x, threshold=0.93):
120 |         """adaptive inference that saves computation time by pruning,
121 |         depending on the hardness of the input."""
122 | 
123 |         X = [[None for _ in range(i + 1)] for i in reversed(range(self.depth))]
124 | 
125 |         for i in range(self.depth):
126 | 
127 |             # encoder part
128 |             if i == 0:
129 |                 X[0][0] = self.encoder[0](x)
130 |             else:
131 |                 X[i][0] = self.encoder[i](X[i - 1][0])
132 | 
133 |             # decoder part
134 |             for j in range(i):
135 |                 cat_feat = torch.cat([
136 |                     X[i - (j + 1)][j],
137 |                     F.interpolate(X[i - j][j], scale_factor=2, mode=self.up_mode,
138 |                                   align_corners=self.align_corners)
139 |                 ], dim=1)
140 |                 X[i - (j + 1)][j + 1] = self.decoder[i - (j + 1)][j](cat_feat)
141 | 
142 |             if i > 0:
143 |                 logits = self.cls[i - 1](X[0][i])
144 |                 avg_conf = F.softmax(logits, dim=1).max(dim=1)[0].mean()
145 |                 if avg_conf > threshold:
146 |                     break
147 | 
148 |         return logits, i
149 | 
150 | 
151 | class NoOperation(nn.Module):
152 |     def __init__(self, *args, **kwargs):
153 |         super(NoOperation, self).__init__()
154 | 
155 |     def forward(self, x):
156 |         return x
157 | 
158 | 
159 | class ConvLayer(nn.Sequential):
160 | 
161 |     def __init__(self, in_channels, out_channels,
162 |                  kernel_size=3, stride=1, padding=1, dilation=1,
163 |                  bn_module=nn.BatchNorm2d, activation=nn.ReLU(inplace=True), use_cbam=False):
164 |         super(ConvLayer, self).__init__()
165 |         self.in_channels = in_channels
166 |         self.out_channels = out_channels
167 | 
168 |         self.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size,
169 |                                           stride, padding, dilation))
170 | 
171 |         if bn_module is not None:
172 |             self.add_module('bn', bn_module(out_channels))
173 | 
174 |         if activation is not None:
175 |             self.add_module('act', activation)
176 | 
177 |         if use_cbam:
178 |             from . import attention
179 |             self.add_module('cbam', attention.CBAM(out_channels, bn_module=bn_module))
180 | 
181 | 
182 | class EncoderBase(nn.Module):
183 | 
184 |     def __init__(self, encoder, channels_list):
185 |         super(EncoderBase, self).__init__()
186 |         self.encoder = encoder
187 |         self.channels_list = channels_list
188 | 
189 |     def forward(self, x):
190 | 
191 |         bridges = []
192 |         for down in self.encoder:
193 |             x = down(x)
194 |             bridges.append(x)
195 | 
196 |         return bridges
197 | 
198 | 
199 | class BasicEncoder(EncoderBase):
200 | 
201 |     def __init__(self, input_channels=1, depth=4,
202 |                  channels=32, pooling='max',
203 |                  bn_module=nn.BatchNorm2d, activation=nn.ReLU(inplace=True), se_module=None):
204 | 
205 |         depth = depth
206 |         channels_list = [channels * 2 ** i for i in range(depth)]
207 | 
208 |         if pooling == 'avg':
209 |             pooling = nn.AvgPool2d(2)
210 |         elif pooling == 'max':
211 |             pooling = nn.MaxPool2d(2)
212 | 
213 |         down_path = nn.ModuleList()
214 |         prev_channels = input_channels
215 |         for i in range(depth):
216 |             layers = [
217 |                 pooling if i != 0 else NoOperation(),
218 |                 ConvLayer(prev_channels, channels * 2 ** i, 3,
219 |                           padding=1, bn_module=bn_module,
220 |                           activation=activation),
221 |                 ConvLayer(channels * 2 ** i, channels * 2 ** i, 3,
222 |                           padding=1, bn_module=bn_module,
223 |                           activation=activation),
224 |             ]
225 |             if se_module is not None:
226 |                 layers.append(se_module(channels * 2 ** i))
227 | 
228 |             down_path.append(nn.Sequential(*layers))
229 |             prev_channels = channels * 2 ** i
230 | 
231 |         super(BasicEncoder, self).__init__(down_path, channels_list)
232 | 
233 | 
234 | class ResNetEncoder(EncoderBase):
235 | 
236 |     def __init__(self, encoder_depth=18, pretrained=True,
237 |                  bn_module=nn.BatchNorm2d, relu_inplace=False):
238 |         import torchvision; from . import resnet_csail  # lazy imports; NOTE: resnet_csail is an assumed local module, not shipped with this template
239 |         if encoder_depth == 18:
240 |             backbone = resnet_csail.resnet18(pretrained=pretrained, bn_module=bn_module, relu_inplace=relu_inplace)
241 |             channels_list = [64, 128, 256, 512]
242 |         elif encoder_depth == 34:
243 |             backbone = torchvision.models.resnet34(pretrained=pretrained)
244 |             channels_list = [64, 128, 256, 512]
245 |         elif encoder_depth == 50:
246 |             backbone = resnet_csail.resnet50(pretrained=pretrained, bn_module=bn_module, relu_inplace=relu_inplace)
247 |             channels_list = [256, 512, 1024, 2048]
248 |         elif encoder_depth == 101:  # res4b22 corresponds to layer3[-1]
249 |             backbone = resnet_csail.resnet101(pretrained=pretrained, bn_module=bn_module, relu_inplace=relu_inplace)
250 |             channels_list = [256, 512, 1024, 2048]
251 |         elif encoder_depth == 152:
252 |             backbone = torchvision.models.resnet152(pretrained=pretrained)
253 |             channels_list = [256, 512, 1024, 2048]
254 | 
255 |         elif encoder_depth == 38:
256 |             from . import wider_resnet
257 |             backbone = wider_resnet.net_wider_resnet38_a2()
258 |             if pretrained:
259 |                 state_dict = torch.load('/opt/segmentation/weights/wide_resnet38_deeplab_vistas.pth.tar')
260 |                 backbone.load_state_dict(state_dict['state_dict']['body'], strict=True)
261 |             channels_list = [256, 512, 1024, 4096]
262 | 
263 |         elif encoder_depth == 'X_101':
264 |             from experiments import resnet_not_inplace
265 |             backbone = resnet_not_inplace.get_pretrained_backbone(
266 |                 cfg_path='/opt/segmentation/maskrcnn-benchmark/configs/dtc/MX_101_32x8d_FPN.yaml',
267 |                 bn_module=bn_module
268 |             )
269 |             channels_list = [256, 512, 1024, 2048]
270 | 
271 |         else:
272 |             raise ValueError('invalid value (encoder_depth)')
273 | 
274 |         if encoder_depth == 38:
275 |             encoder = nn.ModuleList([
276 |                 nn.Sequential(backbone.mod1, backbone.pool2, backbone.mod2, backbone.pool3, backbone.mod3),
277 |                 backbone.mod4,
278 |                 backbone.mod5,
279 |                 nn.Sequential(backbone.mod6, backbone.mod7),
280 |             ])
281 |         elif encoder_depth in [18, 50, 101]:
282 |             encoder = nn.ModuleList([
283 |                 nn.Sequential(
284 |                     backbone.conv1, backbone.bn1, backbone.relu1,
285 |                     backbone.conv2, backbone.bn2, backbone.relu2,
286 |                     backbone.conv3, backbone.bn3, backbone.relu3,
287 |                     backbone.maxpool,
288 |                     backbone.layer1,
289 |                 ),
290 |                 backbone.layer2,
291 |                 backbone.layer3,
292 |                 backbone.layer4,
293 |             ])
294 |         elif encoder_depth in [34, 152]:
295 |             encoder = nn.ModuleList([
296 |                 nn.Sequential(
297 |                     backbone.conv1, backbone.bn1, backbone.relu,
298 |                     backbone.maxpool,
299 |                     backbone.layer1,
300 |                 ),
301 |                 backbone.layer2,
302 |                 backbone.layer3,
303 |                 backbone.layer4,
304 |             ])
305 |         elif encoder_depth == 'X_101':
306 |             encoder = nn.ModuleList([
307 |                 nn.Sequential(backbone.stem, backbone.layer1),
308 |                 backbone.layer2,
309 |                 backbone.layer3,
310 |                 backbone.layer4,
311 |             ])
312 | 
313 |         super(ResNetEncoder, self).__init__(encoder, channels_list)
314 | 
315 | 
316 | class SelfAttentionBlock(nn.Module):
317 |     """
318 |     The basic implementation for self-attention block/non-local block
319 |     Input:
320 |         N X C X H X W
321 |     Parameters:
322 |         in_channels : the dimension of the input feature map
323 |         key_channels : the dimension after the key/query transform
324 |         value_channels : the dimension after the value transform
325 |         scale : choose the scale to downsample the input feature maps (save memory cost)
326 |     Return:
327 |         N X C X H X W
328 |         position-aware context features (w/o concat or add with the input)
329 |     """
330 | 
331 |     def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1,
332 |                  bn_module=nn.BatchNorm2d):
333 |         super(SelfAttentionBlock, self).__init__()
334 |         self.scale = scale
335 |         self.in_channels = in_channels
336 |         self.out_channels = out_channels
337 |         self.key_channels = key_channels
338 |         self.value_channels = value_channels
339 |         if out_channels is None:
340 |             self.out_channels = in_channels
341 |         self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
342 |         self.f_key = nn.Sequential(
343 |             nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
344 |                       kernel_size=1, stride=1, padding=0),
345 |             bn_module(self.key_channels),
346 |         )
347 |         self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
348 |                                  kernel_size=1, stride=1, padding=0)
349 |         self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels,
350 |                            kernel_size=1, stride=1, padding=0)
351 |         nn.init.constant_(self.W.weight, 0)
352 |         nn.init.constant_(self.W.bias, 0)
353 | 
354 |     def forward(self, x):
355 |         batch_size, h, w = x.size(0), x.size(2), x.size(3)
356 |         if self.scale > 1:
357 |             x = self.pool(x)
358 | 
359 |         value = self.f_value(x).view(batch_size, self.value_channels, -1)
360 |         value = value.permute(0, 2, 1)
361 |         key = self.f_key(x).view(batch_size, self.key_channels, -1)
362 |         query = key.permute(0, 2, 1)
363 | 
364 |         sim_map = torch.matmul(query, key)
365 |         sim_map = (self.key_channels ** -.5) * sim_map
366 |         sim_map = F.softmax(sim_map, dim=-1)
367 | 
368 |         context = torch.matmul(sim_map, value)
369 |         context = context.permute(0, 2, 1).contiguous()
370 |         context = context.view(batch_size, self.value_channels, *x.size()[2:])
371 |         context = self.W(context)
372 |         if self.scale > 1:
373 |             context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True)
374 |         return context
375 | 
376 | 
377 | class BaseOC_Context_Module(nn.Module):
378 |     """
379 |     Output only the context features.
380 |     Parameters:
381 |         in_features / out_features: the channels of the input / output feature maps.
382 |         size: we find that directly learning the attention weights on even 1/8 feature maps is hard.
383 |     Return:
384 |         features after "concat" or "add"
385 |     """
386 | 
387 |     def __init__(self, in_channels, out_channels, key_channels, value_channels, sizes=([1]), bn_module=nn.BatchNorm2d):
388 |         super(BaseOC_Context_Module, self).__init__()
389 |         self.stages = []
390 |         self.stages = nn.ModuleList([
391 |             self._make_stage(in_channels, out_channels, key_channels, value_channels, size, bn_module)
392 |             for size in sizes
393 |         ])
394 | 
395 |     @staticmethod
396 |     def _make_stage(in_channels, output_channels, key_channels, value_channels, size, bn_module):
397 |         return SelfAttentionBlock(in_channels,
398 |                                   key_channels,
399 |                                   value_channels,
400 |                                   output_channels,
401 |                                   size,
402 |                                   bn_module)
403 | 
404 |     def forward(self, feats):
405 |         priors = [stage(feats) for stage in self.stages]
406 |         context = priors[0]
407 |         for i in range(1, len(priors)):
408 |             context += priors[i]
409 |         return context
410 | 
411 | 
412 | class ASP_OC_Module(nn.Module):
413 |     """
414 |     OC-Module (slightly modified version)
415 |     ref: https://github.com/PkuRainBow/OCNet/blob/master/LICENSE
416 |     """
417 |     def __init__(self, in_features=2048, out_features=2048, dilations=(2, 5, 9), bn_module=nn.BatchNorm2d, size=1):
418 |         super(ASP_OC_Module, self).__init__()
419 |         internal_features = in_features // 4
420 |         self.context = nn.Sequential(
421 |             nn.Conv2d(in_features, internal_features, kernel_size=3, padding=1, dilation=1, bias=True),
422 |             bn_module(internal_features),
423 |             BaseOC_Context_Module(in_channels=internal_features, out_channels=internal_features,
424 |                                   key_channels=internal_features // 2, value_channels=internal_features,
425 |                                   sizes=([size]), bn_module=bn_module))
426 |         self.conv2 = nn.Sequential(
427 |             nn.Conv2d(in_features, internal_features, kernel_size=1, padding=0, dilation=1, bias=False),
428 |             bn_module(internal_features))
429 |         self.conv3 = nn.Sequential(
430 |             nn.Conv2d(in_features, internal_features, kernel_size=3, padding=dilations[0], dilation=dilations[0],
431 |                       bias=False),
432 |             bn_module(internal_features))
433 |         self.conv4 = nn.Sequential(
434 |             nn.Conv2d(in_features, internal_features, kernel_size=3, padding=dilations[1], dilation=dilations[1],
435 |                       bias=False),
436 |             bn_module(internal_features))
437 |         self.conv5 = nn.Sequential(
438 |             nn.Conv2d(in_features, internal_features, kernel_size=3, padding=dilations[2], dilation=dilations[2],
439 |                       bias=False),
440 |             bn_module(internal_features))
376 | 
377 | class BaseOC_Context_Module(nn.Module):
378 |     """
379 |     Output only the context features.
380 |     Parameters:
381 |         in_channels / out_channels : the channels of the input / output feature maps
382 |         sizes : scales of the attention stages; we find that directly learning the attention weights on even 1/8-scale feature maps is hard
383 |     Return:
384 |         the element-wise sum of the context features of all stages
385 |     """
386 | 
387 |     def __init__(self, in_channels, out_channels, key_channels, value_channels, sizes=([1]), bn_module=nn.BatchNorm2d):
388 |         super(BaseOC_Context_Module, self).__init__()
389 | 
390 |         self.stages = nn.ModuleList([
391 |             self._make_stage(in_channels, out_channels, key_channels, value_channels, size, bn_module)
392 |             for size in sizes
393 |         ])
394 | 
395 |     @staticmethod
396 |     def _make_stage(in_channels, output_channels, key_channels, value_channels, size, bn_module):
397 |         return SelfAttentionBlock(in_channels,
398 |                                   key_channels,
399 |                                   value_channels,
400 |                                   output_channels,
401 |                                   size,
402 |                                   bn_module)
403 | 
404 |     def forward(self, feats):
405 |         priors = [stage(feats) for stage in self.stages]
406 |         context = priors[0]
407 |         for i in range(1, len(priors)):
408 |             context += priors[i]
409 |         return context
410 | 
411 | 
412 | class ASP_OC_Module(nn.Module):
413 |     """
414 |     OC module (slightly modified version)
415 |     ref: https://github.com/PkuRainBow/OCNet/blob/master/LICENSE
416 |     """
417 |     def __init__(self, in_features=2048, out_features=2048, dilations=(2, 5, 9), bn_module=nn.BatchNorm2d, size=1):
418 |         super(ASP_OC_Module, self).__init__()
419 |         internal_features = in_features // 4
420 |         self.context = nn.Sequential(
421 |             nn.Conv2d(in_features, internal_features, kernel_size=3, padding=1, dilation=1, bias=True),
422 |             bn_module(internal_features),
423 |             BaseOC_Context_Module(in_channels=internal_features, out_channels=internal_features,
424 |                                   key_channels=internal_features // 2, value_channels=internal_features,
425 |                                   sizes=([size]), bn_module=bn_module))
426 |         self.conv2 = nn.Sequential(
427 |             nn.Conv2d(in_features, internal_features, kernel_size=1, padding=0, dilation=1, bias=False),
428 |             bn_module(internal_features))
429 |         self.conv3 = nn.Sequential(
430 |             nn.Conv2d(in_features, internal_features, kernel_size=3, padding=dilations[0], dilation=dilations[0],
431 |                       bias=False),
432 |             bn_module(internal_features))
433 |         self.conv4 = nn.Sequential(
434 |             nn.Conv2d(in_features, internal_features, kernel_size=3, padding=dilations[1], dilation=dilations[1],
435 |                       bias=False),
436 |             bn_module(internal_features))
437 |         self.conv5 = nn.Sequential(
438 |             nn.Conv2d(in_features, internal_features, kernel_size=3, padding=dilations[2], dilation=dilations[2],
439 |                       bias=False),
440 |             bn_module(internal_features))
441 | 
442 |         self.conv_bn_dropout = nn.Sequential(
443 |             nn.Conv2d(internal_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
444 |             bn_module(out_features),
445 |             nn.Dropout2d(0.1)
446 |         )
447 | 
448 |     @staticmethod
449 |     def _cat_each(feat1, feat2, feat3, feat4, feat5):
450 |         assert (len(feat1) == len(feat2))
451 |         z = []
452 |         for i in range(len(feat1)):
453 |             z.append(torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]), 1))
454 |         return z
455 | 
456 |     def forward(self, x):
457 |         if isinstance(x, torch.Tensor):
458 |             _, _, h, w = x.size()
459 |         elif isinstance(x, (tuple, list)):
460 |             _, _, h, w = x[0].size()
461 |         else:
462 |             raise RuntimeError('unknown input type')
463 | 
464 |         feat1 = self.context(x)
465 |         feat2 = self.conv2(x)
466 |         feat3 = self.conv3(x)
467 |         feat4 = self.conv4(x)
468 |         feat5 = self.conv5(x)
469 | 
470 |         if isinstance(x, torch.Tensor):
471 |             out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)
472 |         elif isinstance(x, (tuple, list)):
473 |             out = self._cat_each(feat1, feat2, feat3, feat4, feat5)
474 |         else:
475 |             raise RuntimeError('unknown input type')
476 | 
477 |         output = self.conv_bn_dropout(out)
478 |         return output
479 | 
-------------------------------------------------------------------------------- /src/models.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class UNet(nn.Module):
7 |     def __init__(self, in_channels=1, n_classes=2, depth=5, ch_first=6, padding=False,
8 |                  batch_norm=False, up_mode='upconv'):
9 |         """
10 |         Implementation of
11 |         U-Net: Convolutional Networks for Biomedical Image Segmentation
12 |         (Ronneberger et al., 2015)
13 |         https://arxiv.org/abs/1505.04597
14 |         Using the default arguments will yield the exact version used
15 |         in the original paper
16 |         Args:
17 |             in_channels (int): number of input channels
18 |             n_classes (int): number of output channels
19 |             depth (int): depth of the network
20 |             ch_first (int): number of filters in the first layer is 2 ** ch_first
21 |             padding (bool): if True, apply padding such that the input shape
22 |                             is the same as the output.
23 |                             This may introduce artifacts
24 |             batch_norm (bool): Use BatchNorm after layers with an
25 |                                activation function
26 |             up_mode (str): one of 'deconv' or 'upconv'.
27 |                            'deconv' will use transposed convolutions for
28 |                            learned upsampling.
29 |                            'upconv' will use bilinear upsampling.
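        Example (an illustrative shape check, not from the original docstring;
        with padding=True the output keeps the input's spatial size):
            >>> net = UNet(in_channels=1, n_classes=2, padding=True)
            >>> net(torch.randn(1, 1, 64, 64)).shape
            torch.Size([1, 2, 64, 64])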
30 | """ 31 | super(UNet, self).__init__() 32 | assert up_mode in ('deconv', 'upconv') 33 | self.padding = padding 34 | self.depth = depth 35 | prev_channels = in_channels 36 | self.down_path = nn.ModuleList() 37 | for i in range(depth): 38 | self.down_path.append(UNetConvBlock(prev_channels, 2 ** (ch_first + i), 39 | padding, batch_norm)) 40 | prev_channels = 2**(ch_first + i) 41 | 42 | self.up_path = nn.ModuleList() 43 | for i in reversed(range(depth - 1)): 44 | self.up_path.append(UNetUpBlock(prev_channels, 2 ** (ch_first + i), up_mode, 45 | padding, batch_norm)) 46 | prev_channels = 2**(ch_first + i) 47 | 48 | self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1) 49 | 50 | def forward(self, x): 51 | blocks = [] 52 | for i, down in enumerate(self.down_path): 53 | x = down(x) 54 | if i != len(self.down_path)-1: 55 | blocks.append(x) 56 | x = F.avg_pool2d(x, 2) 57 | 58 | for i, up in enumerate(self.up_path): 59 | x = up(x, blocks[-i-1]) 60 | 61 | return self.last(x) 62 | 63 | 64 | class UNetConvBlock(nn.Module): 65 | def __init__(self, in_size, out_size, padding, batch_norm): 66 | super(UNetConvBlock, self).__init__() 67 | block = [] 68 | 69 | block.append(nn.Conv2d(in_size, out_size, kernel_size=3, 70 | padding=int(padding))) 71 | block.append(nn.ReLU()) 72 | if batch_norm: 73 | block.append(nn.BatchNorm2d(out_size)) 74 | 75 | block.append(nn.Conv2d(out_size, out_size, kernel_size=3, 76 | padding=int(padding))) 77 | block.append(nn.ReLU()) 78 | if batch_norm: 79 | block.append(nn.BatchNorm2d(out_size)) 80 | 81 | self.block = nn.Sequential(*block) 82 | 83 | def forward(self, x): 84 | out = self.block(x) 85 | return out 86 | 87 | 88 | class UNetUpBlock(nn.Module): 89 | def __init__(self, in_size, out_size, up_mode, padding, batch_norm): 90 | super(UNetUpBlock, self).__init__() 91 | if up_mode == 'deconv': 92 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, 93 | stride=2) 94 | elif up_mode == 'upconv': 95 | self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2), 96 | nn.Conv2d(in_size, out_size, kernel_size=1)) 97 | 98 | self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm) 99 | 100 | def center_crop(self, layer, target_size): 101 | _, _, layer_height, layer_width = layer.size() 102 | diff_y = (layer_height - target_size[0]) // 2 103 | diff_x = (layer_width - target_size[1]) // 2 104 | return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])] 105 | 106 | def forward(self, x, bridge): 107 | up = self.up(x) 108 | crop1 = self.center_crop(bridge, up.shape[2:]) 109 | out = torch.cat([up, crop1], 1) 110 | out = self.conv_block(out) 111 | 112 | return out 113 | 114 | 115 | class MultiModalNN(nn.Module): 116 | # https://yashuseth.blog/2018/07/22/pytorch-neural-network-for-tabular-data-with-categorical-embeddings/ 117 | 118 | def __init__(self, emb_dims, n_numeric_feats, n_channels_list=(64, 128), 119 | n_classes=1, emb_dropout=0.2, dropout_list=(0.5, 0.5)): 120 | 121 | """ 122 | Parameters 123 | ---------- 124 | 125 | emb_dims: List of two element tuples 126 | This list will contain a two element tuple for each 127 | categorical feature. The first element of a tuple will 128 | denote the number of unique values of the categorical 129 | feature. The second element will denote the embedding 130 | dimension to be used for that feature. 131 | 132 | n_numeric_feats: Integer 133 | The number of continuous features in the data. 134 | 135 | n_channels_list: List of integers. 136 | The size of each linear layer. 
137 |             to the total number
138 |             of linear layers in the network.
139 | 
140 |         n_classes: Integer
141 |             The size of the final output.
142 | 
143 |         emb_dropout: Float
144 |             The dropout to be used after the embedding layers.
145 | 
146 |         dropout_list: List of floats
147 |             The dropouts to be used after each linear layer.
148 | 
149 |         Examples
150 |         --------
151 |         >>> cat_dims = [int(data[col].nunique()) for col in categorical_features]
152 |         >>> cat_dims
153 |         [15, 5, 2, 4, 112]
154 |         >>> emb_dims = [(x, min(32, (x + 1) // 2)) for x in cat_dims]
155 |         >>> emb_dims
156 |         [(15, 8), (5, 3), (2, 1), (4, 2), (112, 32)]
157 |         >>> model = MultiModalNN(emb_dims, n_numeric_feats=4, n_channels_list=(50, 100),
158 |         ...                      n_classes=1, emb_dropout=0.04,
159 |         ...                      dropout_list=(0.001, 0.01)).to(device)
160 |         """
161 | 
162 |         super(MultiModalNN, self).__init__()
163 | 
164 |         # Embedding layers
165 |         self.emb_layers = nn.ModuleList([nn.Embedding(x, y)
166 |                                          for x, y in emb_dims])
167 | 
168 |         no_of_embs = sum([y for x, y in emb_dims])
169 |         self.no_of_embs = no_of_embs
170 |         self.n_numeric_feats = n_numeric_feats
171 | 
172 |         # Linear Layers
173 |         first_lin_layer = nn.Linear(self.no_of_embs + self.n_numeric_feats, n_channels_list[0])
174 | 
175 |         self.lin_layers = nn.ModuleList([first_lin_layer] + [
176 |             nn.Linear(n_channels_list[i], n_channels_list[i + 1]) for i in range(len(n_channels_list) - 1)
177 |         ])
178 | 
179 |         for lin_layer in self.lin_layers:
180 |             nn.init.kaiming_normal_(lin_layer.weight.data)
181 | 
182 |         # Output Layer
183 |         self.output_layer = nn.Linear(n_channels_list[-1], n_classes)
184 |         nn.init.kaiming_normal_(self.output_layer.weight.data)
185 | 
186 |         # Batch Norm Layers
187 |         self.first_bn_layer = nn.BatchNorm1d(self.n_numeric_feats)
188 |         self.bn_layers = nn.ModuleList([nn.BatchNorm1d(size) for size in n_channels_list])
189 | 
190 |         # Dropout Layers
191 |         self.emb_dropout_layer = nn.Dropout(emb_dropout)
192 |         self.dropout_layers = nn.ModuleList([nn.Dropout(size) for size in dropout_list])
193 | 
194 |     def forward(self, numeric_feats, categorical_feats):
195 | 
196 |         if self.no_of_embs != 0:
197 |             x = [emb_layer(categorical_feats[:, i]) for i, emb_layer in enumerate(self.emb_layers)]
198 |             x = torch.cat(x, 1)
199 |             x = self.emb_dropout_layer(x)
200 | 
201 |         if self.n_numeric_feats != 0:
202 |             normalized_numeric_feats = self.first_bn_layer(numeric_feats)
203 | 
204 |             if self.no_of_embs != 0:
205 |                 x = torch.cat([x, normalized_numeric_feats], 1)
206 |             else:
207 |                 x = normalized_numeric_feats
208 | 
209 |         for lin_layer, dropout_layer, bn_layer in zip(self.lin_layers, self.dropout_layers, self.bn_layers):
210 |             x = F.relu(lin_layer(x))
211 |             x = bn_layer(x)
212 |             x = dropout_layer(x)
213 | 
214 |         x = self.output_layer(x)
215 | 
216 |         return x
217 | 
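# A minimal usage sketch for MultiModalNN (illustrative, not part of the original
# file): categorical codes index the embedding layers, numeric features go through
# BatchNorm1d, and the two are concatenated before the linear stack. All shapes
# below are made up for the example.
#
#     emb_dims = [(15, 8), (5, 3)]                          # (cardinality, embedding dim) per categorical column
#     model = MultiModalNN(emb_dims, n_numeric_feats=4).eval()
#     cat = torch.randint(0, 5, (8, 2), dtype=torch.long)   # integer codes, each < its column's cardinality
#     num = torch.randn(8, 4)                               # four numeric features
#     out = model(num, cat)                                 # note the argument order; shape (8, 1)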
-------------------------------------------------------------------------------- /src/utils.py: --------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import subprocess
4 | import sys
5 | import time
6 | from collections import OrderedDict, deque
7 | from pathlib import Path
8 | 
9 | import lockfile, logzero  # lockfile is used by write_tuning_result
10 | import pandas as pd
11 | import torch
12 | from torch import nn, optim  # nn is used by get_optim
13 | from sklearn.model_selection import ParameterGrid, ParameterSampler
14 | from tensorboardX import SummaryWriter
15 | 
16 | 
17 | class AverageMeter(object):
18 |     """Computes and stores the average and current value"""
19 | 
20 |     def __init__(self):
21 |         self.reset()
22 | 
23 |     def reset(self):
24 |         self.val = 0
25 |         self.avg = 0
26 |         self.sum = 0
27 |         self.count = 0
28 | 
29 |     def update(self, val, n=1):
30 |         self.val = val
31 |         self.sum += val * n
32 |         self.count += n
33 |         self.avg = self.sum / self.count
34 | 
35 | 
36 | def save_checkpoint(model, epoch, filename, optimizer=None, save_arch=False, params=None):
37 |     attributes = {
38 |         'epoch': epoch,
39 |         'state_dict': model.state_dict(),
40 |     }
41 | 
42 |     if optimizer is not None:
43 |         attributes['optimizer'] = optimizer.state_dict()
44 | 
45 |     if save_arch:
46 |         attributes['arch'] = model
47 | 
48 |     if params is not None:
49 |         attributes['params'] = params
50 | 
51 |     try:
52 |         torch.save(attributes, filename)
53 |     except TypeError:
54 |         if 'arch' in attributes:
55 |             print('Model architecture will be ignored because the architecture includes non-picklable objects.')
56 |             del attributes['arch']
57 |             torch.save(attributes, filename)
58 | 
59 | 
60 | def load_checkpoint(path, model=None, optimizer=None, params=False):
61 |     resume = torch.load(path)
62 | 
63 |     rets = dict()
64 | 
65 |     if model is not None:
66 |         if ('module' in list(resume['state_dict'].keys())[0]) \
67 |                 and not (isinstance(model, torch.nn.DataParallel)):
68 |             new_state_dict = OrderedDict()
69 |             for k, v in resume['state_dict'].items():
70 |                 new_state_dict[k.replace('module.', '')] = v  # remove DataParallel wrapping
71 | 
72 |             model.load_state_dict(new_state_dict)
73 |         else:
74 |             model.load_state_dict(resume['state_dict'])
75 | 
76 |         rets['model'] = model
77 | 
78 |     if optimizer is not None:
79 |         optimizer.load_state_dict(resume['optimizer'])
80 |         rets['optimizer'] = optimizer
81 |     if params:
82 |         rets['params'] = resume['params']
83 | 
84 |     return rets
85 | 
86 | 
87 | def load_model(path, is_inference=True):
88 |     resume = torch.load(path)
89 |     model = resume['arch']
90 |     model.load_state_dict(resume['state_dict'])
91 |     if is_inference:
92 |         model.eval()
93 |     return model
94 | 
95 | 
96 | def get_logger(log_dir, loglevel=logging.INFO, tensorboard_dir=None):
97 |     from logzero import logger
98 | 
99 |     if not Path(log_dir).exists():
100 |         Path(log_dir).mkdir(parents=True)
101 |     logzero.loglevel(loglevel)
102 |     logzero.logfile(log_dir + '/logfile')
103 | 
104 |     if tensorboard_dir is not None:
105 |         if not Path(tensorboard_dir).exists():
106 |             Path(tensorboard_dir).mkdir(parents=True)
107 |         writer = SummaryWriter(tensorboard_dir)
108 | 
109 |         return logger, writer
110 | 
111 |     return logger
112 | 
113 | 
114 | def get_optim(params, target):
115 | 
116 |     assert isinstance(target, nn.Module) or isinstance(target, dict)
117 | 
118 |     if isinstance(target, nn.Module):
119 |         target = target.parameters()
120 | 
121 |     if params['optimizer'] == 'sgd':
122 |         optimizer = optim.SGD(target, params['lr'], weight_decay=params['wd'])
123 |     elif params['optimizer'] == 'momentum':
124 |         optimizer = optim.SGD(target, params['lr'], momentum=0.9, weight_decay=params['wd'])
125 |     elif params['optimizer'] == 'nesterov':
126 |         optimizer = optim.SGD(target, params['lr'], momentum=0.9,
127 |                               weight_decay=params['wd'], nesterov=True)
128 |     elif params['optimizer'] == 'adam':
129 |         optimizer = optim.Adam(target, params['lr'], weight_decay=params['wd'])
130 |     elif params['optimizer'] == 'amsgrad':
131 |         optimizer = optim.Adam(target, params['lr'], weight_decay=params['wd'], amsgrad=True)
132 |     elif params['optimizer'] == 'rmsprop':
133 |         optimizer = optim.RMSprop(target, params['lr'], weight_decay=params['wd'])
134 |     else:
135 |         raise ValueError(f'unknown optimizer: {params["optimizer"]}')
136 | 
137 |     return optimizer
138 | 
139 | 
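# Example (illustrative, not part of the original file): get_optim builds an
# optimizer from a params dict holding the optimizer name, learning rate and
# weight decay; the values below are made up.
#
#     params = {'optimizer': 'momentum', 'lr': 0.01, 'wd': 1e-4}
#     optimizer = get_optim(params, model)  # model is an nn.Module (or a dict of parameter groups)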
140 | def write_tuning_result(params: dict, result: dict, df_path: str):
141 |     row = pd.DataFrame()
142 |     for key in params['tuning_params']:
143 |         row[key] = [params[key]]
144 | 
145 |     for key, val in result.items():
146 |         row[key] = val
147 | 
148 |     with lockfile.FileLock(df_path):
149 |         df_results = pd.read_csv(df_path)
150 |         df_results = pd.concat([df_results, row], sort=False).reset_index(drop=True)
151 |         df_results.to_csv(df_path, index=False)
152 | 
153 | 
154 | def check_duplicate(df: pd.DataFrame, p: dict, space):
155 |     """check if the current parameter combination has already been tried"""
156 | 
157 |     new_key_is_included = not all(map(lambda x: x in df.columns, space.keys()))
158 |     if new_key_is_included:
159 |         return False
160 | 
161 |     for i in range(len(df)):  # iterate row-wise to avoid unexpected casts caused by row-slicing
162 |         is_dup = True
163 |         for key, val in p.items():
164 |             if df.loc[i, key] != val:
165 |                 is_dup = False
166 |                 break
167 |         if is_dup:
168 |             return True
169 | 
170 |     return False
171 | 
172 | 
173 | def launch_tuning(mode: str, n_iter: int, n_gpu: int, devices: str,
174 |                   params: dict, space: dict, root):
175 |     """
176 |     Launch a parameter search in the specified way.
177 |     Each trial is launched asynchronously in a forked subprocess, and the results
178 |     of all trials are automatically written to a csv file.
179 | 
180 |     :param mode: the way of parameter search, one of 'grid' or 'random'.
181 |     :param n_iter: number of iterations for random search (ignored for grid search).
182 |     :param n_gpu: number of gpus used per trial.
183 |     :param devices: gpu devices available for tuning.
184 |     :param params: training parameters;
185 |                    the values designated as tuning parameters are overwritten.
186 |     :param space: parameter search space.
187 |     :param root: path of the root directory (concatenated directly, so it must end with '/').
188 |     """
189 | 
190 |     gpu_list = deque(devices.split(','))
191 | 
192 |     if mode == 'grid':
193 |         param_list = list(ParameterGrid(space))
194 |     elif mode == 'random':
195 |         param_list = list(ParameterSampler(space, n_iter))
196 |     else:
197 |         raise ValueError(f'unknown tuning mode: {mode}')
198 | 
199 |     params['tuning_params'] = list(param_list[0].keys())
200 | 
201 |     df_path = root + f'experiments/{params["ex_name"]}/tuning/results.csv'
202 |     if Path(df_path).exists() and Path(df_path).stat().st_size > 5:
203 |         df_results = pd.read_csv(df_path)
204 |     else:
205 |         cols = list(param_list[0].keys())
206 |         df_results = pd.DataFrame(columns=cols)
207 |         df_results.to_csv(df_path, index=False)
208 | 
209 |     procs = []
210 |     for p in param_list:
211 | 
212 |         if check_duplicate(df_results, p, param_list[0]):
213 |             print(f'skip: {p} because this setting has already been tried.')
214 |             continue
215 | 
216 |         # overwrite hyper parameters for search
217 |         for key, val in p.items():
218 |             params[key] = val
219 | 
220 |         while True:
221 |             if len(gpu_list) >= n_gpu:
222 |                 devices = ','.join([gpu_list.pop() for _ in range(n_gpu)])
223 |                 params_path = root + f'experiments/{params["ex_name"]}/tuning/params_{devices[0]}.json'
224 |                 with open(params_path, 'w') as f:
225 |                     json.dump(params, f)
226 |                 break
227 |             else:
228 |                 time.sleep(1)
229 |                 for i, (proc, dev) in reversed(list(enumerate(procs))):  # reversed so deletion while iterating is safe
230 |                     if proc.poll() is not None:
231 |                         gpu_list += deque(dev.split(','))
232 |                         del procs[i]
233 | 
234 |         cmd = f'{sys.executable} {params["ex_name"]}.py job ' \
235 |               f'--tuning --params-path {params_path} --devices "{devices}"'
236 |         procs.append((subprocess.Popen(cmd, shell=True), devices))
237 | 
238 |     while True:
239 |         time.sleep(1)
240 |         if all(proc.poll() is not None for proc, dev in procs):
241 |             print('All parameter combinations have finished.')
242 |             break
243 | 
244 |     show_tuning_result(params["ex_name"])
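# Example (illustrative, not part of the original file): a search space in the
# form consumed by sklearn's ParameterGrid / ParameterSampler; params must contain
# 'ex_name', and root is concatenated directly, so it must end with '/'.
#
#     space = {'lr': [1e-2, 1e-3], 'wd': [1e-4, 1e-5]}
#     launch_tuning(mode='grid', n_iter=None, n_gpu=1, devices='0,1',
#                   params=params, space=space, root='/path/to/repo/')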
245 | 
246 | 
247 | def show_tuning_result(ex_name, mode='markdown', sort_by=None, ascending=False):
248 | 
249 |     table = pd.read_csv(f'../experiments/{ex_name}/tuning/results.csv')
250 |     if sort_by is not None:
251 |         table = table.sort_values(sort_by, ascending=ascending)
252 | 
253 |     if mode == 'markdown':
254 |         from tabulate import tabulate
255 |         print(tabulate(table, headers='keys', tablefmt='pipe', showindex=False))
256 |     elif mode == 'latex':
257 |         from tabulate import tabulate
258 |         print(tabulate(table, headers='keys', tablefmt='latex', floatfmt='.2f', showindex=False))
259 |     else:
260 |         from IPython.core.display import display
261 |         display(table)
262 | 
-------------------------------------------------------------------------------- /submit/.gitkeep: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyakaap/pytorch-template/eff9f0a4dd50fa49c3b949065247598d5eabc91e/submit/.gitkeep
-------------------------------------------------------------------------------- /tests/.gitkeep: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/lyakaap/pytorch-template/eff9f0a4dd50fa49c3b949065247598d5eabc91e/tests/.gitkeep
-------------------------------------------------------------------------------- /tests/test_utils.py: --------------------------------------------------------------------------------
1 | import unittest
2 | import pytest
3 | from pathlib import Path
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torchvision
8 | 
9 | from src import utils
10 | 
11 | 
12 | class TestSaveAndLoadCheckpoint(unittest.TestCase):
13 | 
14 |     def test_save_and_load_checkpoint(self):
15 |         model = torchvision.models.resnet18(pretrained=False)
16 |         utils.save_checkpoint(model, epoch=100, filename='tmp.pth', save_arch=True)
17 | 
18 |         loaded_model = utils.load_model('tmp.pth')
19 | 
20 |         torch.testing.assert_allclose(model.conv1.weight, loaded_model.conv1.weight)
21 | 
22 |         model.conv1.weight = nn.Parameter(torch.zeros_like(model.conv1.weight))
23 |         model = utils.load_checkpoint('tmp.pth', model=model)['model']
24 | 
25 |         assert (model.conv1.weight != 0).any()
26 | 
27 |     def tearDown(self):
28 |         Path('tmp.pth').unlink()  # remove the temporary checkpoint
29 | 
30 | 
31 | # class TestAverageMeter(unittest.TestCase):
32 | #
33 | #     def test_average_meter(self):
34 | #         raise NotImplementedError
35 | #
36 | #
37 | # class TestCheckDuplicate(unittest.TestCase):
38 | #
39 | #     def test_check_duplicate(self):
40 | #         raise NotImplementedError
41 | #
42 | #
43 | # from src.inplace_abn import InPlaceABNSync
44 | # from src.sync_batchnorm import SynchronizedBatchNorm2d
45 | # from src.modeling import resnet_csail
46 | #
47 | # @pytest.mark.parametrize('src_bn, dst_bn', [
48 | #     (nn.BatchNorm2d, InPlaceABNSync),
49 | #     (InPlaceABNSync, nn.BatchNorm2d),
50 | #     (nn.BatchNorm2d, SynchronizedBatchNorm2d),
51 | #     (SynchronizedBatchNorm2d, nn.BatchNorm2d),
52 | #     (InPlaceABNSync, SynchronizedBatchNorm2d),
53 | #     (SynchronizedBatchNorm2d, InPlaceABNSync),
54 | # ])
55 | # def test_replace_bn(src_bn, dst_bn):
56 | #     model = resnet_csail.resnet18(pretrained=True, bn_module=src_bn)
57 | #     w, b = model.bn1.weight, model.bn1.bias
58 | #     utils.replace_bn(model, src_bn, dst_bn)
59 | #
60 | #     cnt_src_bn, cnt_dst_bn = 0, 0
61 | #
62 | #     for name, m in model.named_modules():
63 | #         if name == 'bn1':
64 | #             assert (w == m.weight).all() and (b == m.bias).all()
65 | #         if isinstance(m, src_bn):
66 | #             cnt_src_bn += 1
67 | #         if isinstance(m, dst_bn):
68 | #             cnt_dst_bn += 1
69 | #
70 | #     assert cnt_src_bn == 0 and cnt_dst_bn > 0
71 | 
--------------------------------------------------------------------------------