├── .gitignore
├── LICENSE
├── README.md
├── checkpoint
│   └── .gitignore
├── clr.py
├── color150.npy
├── dataset.py
├── distributed.py
├── model.py
├── sample
│   └── .gitignore
├── scheduler.py
├── train.py
├── transform.py
├── util.py
└── vovnet.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | *.pth
107 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Kim Seonghyeon
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ocr-pytorch 2 | 3 | Implementation of Object-Contextual Representations for Semantic Segmentation (https://arxiv.org/abs/1909.11065) in PyTorch 4 | 5 | ## Usage 6 | 7 | > python -m torch.distributed.launch --nproc_per_node=4 --master_port=8890 train.py --batch 4 [ADE20K PATH] 8 | -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /clr.py: -------------------------------------------------------------------------------- 1 | from math import cos, pi, tanh 2 | 3 | 4 | def anneal_linear(start, end, proportion): 5 | return start + proportion * (end - start) 6 | 7 | 8 | def anneal_cos(start, end, proportion): 9 | cos_val = cos(pi * proportion) + 1 10 | 11 | return end + (start - end) / 2 * cos_val 12 | 13 | 14 | def anneal_cospow(start, end, proportion): 15 | power = 5 16 | 17 | cos_val = 0.5 * (cos(pi * proportion) + 1) + 1 18 | cos_val = power ** cos_val - power 19 | cos_val = cos_val / (power ** 2 - power) 20 | 21 | return end + (start - end) * cos_val 22 | 23 | 24 | def anneal_poly(start, end, proportion, power=0.9): 25 | return (start - end) * (1 - proportion) ** power + end 26 | 27 | 28 | def anneal_tanh(start, end, proportion, lower=-6, upper=3): 29 | return end + (start - end) / 2 * (1 - tanh(lower + (upper - lower) * proportion)) 30 | 31 | 32 | class Phase: 33 | def __init__(self, start, end, n_iter, anneal_fn): 34 | self.start, self.end = start, end 35 | self.n_iter = n_iter 36 | self.anneal_fn = anneal_fn 37 | self.n = 0 38 | 39 | def step(self): 40 | self.n += 1 41 | 42 | return self.anneal_fn(self.start, self.end, self.n / self.n_iter) 43 | 44 | def reset(self): 45 | self.n = 0 46 | 47 | @property 48 | def is_done(self): 49 | return self.n >= self.n_iter 50 | 51 | 52 | class CycleScheduler: 53 | def __init__( 54 | self, 55 | optimizer, 56 | lr_max, 57 | n_iter, 58 | divider=25, 59 | warmup_proportion=0.3, 60 | phase=('linear', 'cos'), 61 | ): 62 | self.optimizer = optimizer 63 | 64 | phase1 = int(n_iter * warmup_proportion) 65 | phase2 = n_iter - phase1 66 | lr_min = lr_max / divider 67 | 68 | phase_map = { 69 | 'linear': anneal_linear, 70 | 'cos': anneal_cos, 71 | 'cospow': anneal_cospow, 72 | 'poly': anneal_poly, 73 | 'tanh': anneal_tanh, 74 | } 75 | 76 | self.lr_phase = [ 77 | Phase(lr_min, lr_max, phase1, phase_map[phase[0]]), 78 | Phase(lr_max, lr_min / 1e4, phase2, phase_map[phase[1]]), 79 | ] 80 | 81 | self.phase = 0 82 | 83 | def step(self): 84 | lr = self.lr_phase[self.phase].step() 85 | 86 | for group in self.optimizer.param_groups: 87 | group['lr'] = lr 88 | 89 | if self.lr_phase[self.phase].is_done: 90 | self.phase += 1 91 | 92 | if self.phase >= len(self.lr_phase): 93 | for phase in self.lr_phase: 94 | phase.reset() 95 | 96 | self.phase = 0 97 | 98 | return lr 99 | -------------------------------------------------------------------------------- /color150.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/ocr-pytorch/729b7253dc2681d7c2760bb34bcba4839e246417/color150.npy -------------------------------------------------------------------------------- /dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class ADE20K(Dataset): 9 | def __init__(self, path, split, transform=None): 10 | split_path = {'train': 'training', 'valid': 'validation'} 11 | self.img_path = os.path.join(path, 'images', split_path[split]) 12 | self.annot_path = os.path.join(path, 'annotations', split_path[split]) 13 | files = os.listdir(self.img_path) 14 | self.ids = [] 15 | 16 | for file in files: 17 | name, ext = os.path.splitext(file) 18 | if ext.lower() == '.jpg': 19 | self.ids.append(name) 20 | 21 | self.transform = transform 22 | 23 | def __len__(self): 24 | return len(self.ids) 25 | 26 | def __getitem__(self, index): 27 | id = self.ids[index] 28 | img = Image.open(os.path.join(self.img_path, id) + '.jpg').convert('RGB') 29 | annot = Image.open(os.path.join(self.annot_path, id) + '.png') 30 | 31 | if self.transform is not None: 32 | img, annot = self.transform(img, annot) 33 | 34 | return img, annot 35 | 36 | 37 | def collate_data(batch): 38 | max_height = max([b[0].shape[1] for b in batch]) 39 | max_width = max([b[0].shape[2] for b in batch]) 40 | batch_size = len(batch) 41 | 42 | img_batch = torch.zeros(batch_size, 3, max_height, max_width, dtype=torch.float32) 43 | annot_batch = torch.zeros(batch_size, max_height, max_width, dtype=torch.int64) 44 | 45 | for i, (img, annot) in enumerate(batch): 46 | _, height, width = img.shape 47 | img_batch[i, :, :height, :width] = img 48 | annot_batch[i, :height, :width] = annot 49 | 50 | return img_batch, annot_batch 51 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def all_gather(data): 45 | world_size = get_world_size() 46 | 47 | if world_size == 1: 48 | return [data] 49 | 50 | buffer = pickle.dumps(data) 51 | storage = torch.ByteStorage.from_buffer(buffer) 52 | tensor = torch.ByteTensor(storage).to('cuda') 53 | 54 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 55 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 56 | dist.all_gather(size_list, local_size) 57 | size_list = [int(size.item()) for size in size_list] 58 | max_size = max(size_list) 59 | 60 | tensor_list = [] 61 | for _ in size_list: 62 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 63 | 64 | if local_size != max_size: 65 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 66 | tensor = torch.cat((tensor, padding), 0) 67 | 68 | dist.all_gather(tensor_list, tensor) 69 | 70 | data_list = [] 71 | 72 | for size, tensor in zip(size_list, tensor_list): 73 | buffer = 
tensor.cpu().numpy().tobytes()[:size] 74 | data_list.append(pickle.loads(buffer)) 75 | 76 | return data_list 77 | 78 | 79 | def reduce_loss_dict(loss_dict): 80 | world_size = get_world_size() 81 | 82 | if world_size < 2: 83 | return loss_dict 84 | 85 | with torch.no_grad(): 86 | keys = [] 87 | losses = [] 88 | 89 | for k in sorted(loss_dict.keys()): 90 | keys.append(k) 91 | losses.append(loss_dict[k]) 92 | 93 | losses = torch.stack(losses, 0) 94 | dist.reduce(losses, dst=0) 95 | 96 | if dist.get_rank() == 0: 97 | losses /= world_size 98 | 99 | reduced_losses = {k: v for k, v in zip(keys, losses)} 100 | 101 | return reduced_losses 102 | 103 | 104 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 105 | # Code is copy-pasted exactly as in torch.utils.data.distributed. 106 | # FIXME remove this once c10d fixes the bug it has 107 | 108 | 109 | class DistributedSampler(Sampler): 110 | """Sampler that restricts data loading to a subset of the dataset. 111 | It is especially useful in conjunction with 112 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 113 | process can pass a DistributedSampler instance as a DataLoader sampler, 114 | and load a subset of the original dataset that is exclusive to it. 115 | .. note:: 116 | Dataset is assumed to be of constant size. 117 | Arguments: 118 | dataset: Dataset used for sampling. 119 | num_replicas (optional): Number of processes participating in 120 | distributed training. 121 | rank (optional): Rank of the current process within num_replicas. 122 | """ 123 | 124 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 125 | if num_replicas is None: 126 | if not dist.is_available(): 127 | raise RuntimeError("Requires distributed package to be available") 128 | num_replicas = dist.get_world_size() 129 | if rank is None: 130 | if not dist.is_available(): 131 | raise RuntimeError("Requires distributed package to be available") 132 | rank = dist.get_rank() 133 | self.dataset = dataset 134 | self.num_replicas = num_replicas 135 | self.rank = rank 136 | self.epoch = 0 137 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 138 | self.total_size = self.num_samples * self.num_replicas 139 | self.shuffle = shuffle 140 | 141 | def __iter__(self): 142 | if self.shuffle: 143 | # deterministically shuffle based on epoch 144 | g = torch.Generator() 145 | g.manual_seed(self.epoch) 146 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 147 | else: 148 | indices = torch.arange(len(self.dataset)).tolist() 149 | 150 | # add extra samples to make it evenly divisible 151 | indices += indices[: (self.total_size - len(indices))] 152 | assert len(indices) == self.total_size 153 | 154 | # subsample 155 | offset = self.num_samples * self.rank 156 | indices = indices[offset : offset + self.num_samples] 157 | assert len(indices) == self.num_samples 158 | 159 | return iter(indices) 160 | 161 | def __len__(self): 162 | return self.num_samples 163 | 164 | def set_epoch(self, epoch): 165 | self.epoch = epoch 166 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def conv2d(in_channel, out_channel, kernel_size): 7 | layers = [ 8 | nn.Conv2d( 9 | in_channel, out_channel, kernel_size, padding=kernel_size // 2, bias=False 10 | ), 11 | nn.BatchNorm2d(out_channel), 
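# the BatchNorm above provides the affine shift, which is why the conv uses bias=False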
12 |         nn.ReLU(),
13 |     ]
14 | 
15 |     return nn.Sequential(*layers)
16 | 
17 | 
18 | def conv1d(in_channel, out_channel):
19 |     layers = [
20 |         nn.Conv1d(in_channel, out_channel, 1, bias=False),
21 |         nn.BatchNorm1d(out_channel),
22 |         nn.ReLU(),
23 |     ]
24 | 
25 |     return nn.Sequential(*layers)
26 | 
27 | 
28 | class OCR(nn.Module):
29 |     def __init__(self, n_class, backbone, feat_channels=[768, 1024]):
30 |         super().__init__()
31 | 
32 |         self.backbone = backbone
33 | 
34 |         ch16, ch32 = feat_channels
35 | 
36 |         self.L = nn.Conv2d(ch16, n_class, 1)
37 |         self.X = conv2d(ch32, 512, 3)
38 | 
39 |         self.phi = conv1d(512, 256)
40 |         self.psi = conv1d(512, 256)
41 |         self.delta = conv1d(512, 256)
42 |         self.rho = conv1d(256, 512)
43 |         self.g = conv2d(512 + 512, 512, 1)
44 | 
45 |         self.out = nn.Conv2d(512, n_class, 1)
46 | 
47 |         self.criterion = nn.CrossEntropyLoss(ignore_index=0)
48 | 
49 |     def forward(self, input, target=None):
50 |         input_size = input.shape[2:]
51 |         stg16, stg32 = self.backbone(input)[-2:]
52 | 
53 |         X = self.X(stg32)
54 |         L = self.L(stg16)
55 |         batch, n_class, height, width = L.shape
56 |         l_flat = L.view(batch, n_class, -1)
57 |         # M: NKL; soft object regions (softmax over the L = H * W pixel locations)
58 |         M = torch.softmax(l_flat, -1)
59 |         channel = X.shape[1]
60 |         X_flat = X.view(batch, channel, -1)
61 |         # f_k: NCK; object region representations, pixel features aggregated per class
62 |         f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)
63 | 
64 |         # query: NKD
65 |         query = self.phi(f_k).transpose(1, 2)
66 |         # key: NDL
67 |         key = self.psi(X_flat)
68 |         logit = query @ key
69 |         # attn: NKL; pixel-region relation, normalized over the K object regions
70 |         attn = torch.softmax(logit, 1)
71 | 
72 |         # delta: NDK
73 |         delta = self.delta(f_k)
74 |         # attn_sum: NDL
75 |         attn_sum = delta @ attn
76 |         # X_obj: NCHW; object-contextual representation for every pixel
77 |         X_obj = self.rho(attn_sum).view(batch, -1, height, width)
78 | 
79 |         concat = torch.cat([X, X_obj], 1)
80 |         X_bar = self.g(concat)
81 |         out = self.out(X_bar)
82 |         out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)
83 | 
84 |         if self.training:
85 |             aux_out = F.interpolate(
86 |                 L, size=input_size, mode='bilinear', align_corners=False
87 |             )
88 | 
89 |             loss = self.criterion(out, target)
90 |             aux_loss = self.criterion(aux_out, target)
91 | 
92 |             return {'loss': loss, 'aux': aux_loss}, None
93 | 
94 |         else:
95 |             return {}, out
96 | 
--------------------------------------------------------------------------------
/sample/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
2 | 
--------------------------------------------------------------------------------
/scheduler.py:
--------------------------------------------------------------------------------
1 | from math import cos, pi, tanh
2 | 
3 | 
4 | def anneal_linear(start, end, proportion):
5 |     return start + proportion * (end - start)
6 | 
7 | 
8 | def anneal_cos(start, end, proportion):
9 |     cos_val = cos(pi * proportion) + 1
10 | 
11 |     return end + (start - end) / 2 * cos_val
12 | 
13 | 
14 | def anneal_cospow(start, end, proportion):
15 |     power = 5
16 | 
17 |     cos_val = 0.5 * (cos(pi * proportion) + 1) + 1
18 |     cos_val = power ** cos_val - power
19 |     cos_val = cos_val / (power ** 2 - power)
20 | 
21 |     return end + (start - end) * cos_val
22 | 
23 | 
24 | def anneal_poly(start, end, proportion, power=0.9):
25 |     return (start - end) * (1 - proportion) ** power + end
26 | 
27 | 
28 | def anneal_tanh(start, end, proportion, lower=-6, upper=3):
29 |     return end + (start - end) / 2 * (1 - tanh(lower + (upper - lower) * proportion))
30 | 
31 | 
32 | class Phase:
33 |     def __init__(self, start, end, n_iter, anneal_fn):
34 |         self.start, self.end = start, end
35 |         self.n_iter = n_iter
36 | 
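# anneal_fn maps (start, end, n / n_iter) to the learning rate at the current step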
self.anneal_fn = anneal_fn 37 | self.n = 0 38 | 39 | def step(self): 40 | self.n += 1 41 | 42 | return self.anneal_fn(self.start, self.end, self.n / self.n_iter) 43 | 44 | def reset(self): 45 | self.n = 0 46 | 47 | @property 48 | def is_done(self): 49 | return self.n >= self.n_iter 50 | 51 | 52 | class CycleScheduler: 53 | def __init__( 54 | self, 55 | optimizer, 56 | lr_max, 57 | n_iter, 58 | divider=25, 59 | warmup_proportion=0.3, 60 | phase=('linear', 'cos'), 61 | ): 62 | self.optimizer = optimizer 63 | 64 | phase1 = int(n_iter * warmup_proportion) 65 | phase2 = n_iter - phase1 66 | lr_min = lr_max / divider 67 | 68 | phase_map = { 69 | 'linear': anneal_linear, 70 | 'cos': anneal_cos, 71 | 'cospow': anneal_cospow, 72 | 'poly': anneal_poly, 73 | 'tanh': anneal_tanh, 74 | } 75 | 76 | self.lr_phase = [ 77 | Phase(lr_min, lr_max, phase1, phase_map[phase[0]]), 78 | Phase(lr_max, lr_min / 1e4, phase2, phase_map[phase[1]]), 79 | ] 80 | 81 | self.phase = 0 82 | 83 | def step(self): 84 | lr = self.lr_phase[self.phase].step() 85 | 86 | for group in self.optimizer.param_groups: 87 | group['lr'] = lr 88 | 89 | if self.lr_phase[self.phase].is_done: 90 | self.phase += 1 91 | 92 | if self.phase >= len(self.lr_phase): 93 | for phase in self.lr_phase: 94 | phase.reset() 95 | 96 | self.phase = 0 97 | 98 | return lr 99 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import math 4 | 5 | import torch 6 | from torch import nn, optim 7 | from torch.utils.data import DataLoader, sampler 8 | from tqdm import tqdm 9 | 10 | from model import OCR 11 | from vovnet import vovnet39, vovnet57 12 | import transform 13 | from dataset import ADE20K, collate_data 14 | from util import get_colormap, show_segmentation, intersection_union 15 | from scheduler import CycleScheduler 16 | from distributed import ( 17 | get_rank, 18 | synchronize, 19 | reduce_loss_dict, 20 | DistributedSampler, 21 | all_gather, 22 | ) 23 | 24 | 25 | def train(args, epoch, loader, model, optimizer, scheduler): 26 | torch.backends.cudnn.benchmark = True 27 | 28 | model.train() 29 | 30 | if get_rank() == 0: 31 | pbar = tqdm(loader, dynamic_ncols=True) 32 | 33 | else: 34 | pbar = loader 35 | 36 | for i, (img, annot) in enumerate(pbar): 37 | img = img.to('cuda') 38 | annot = annot.to('cuda') 39 | 40 | loss, _ = model(img, annot) 41 | loss_sum = loss['loss'] + args.aux_weight * loss['aux'] 42 | model.zero_grad() 43 | loss_sum.backward() 44 | optimizer.step() 45 | scheduler.step() 46 | 47 | loss_dict = reduce_loss_dict(loss) 48 | loss = loss_dict['loss'].mean().item() 49 | aux_loss = loss_dict['aux'].mean().item() 50 | 51 | if get_rank() == 0: 52 | lr = optimizer.param_groups[0]['lr'] 53 | 54 | pbar.set_description( 55 | f'epoch: {epoch + 1}; loss: {loss:.5f}; aux loss: {aux_loss:.5f}; lr: {lr:.5f}' 56 | ) 57 | 58 | 59 | @torch.no_grad() 60 | def valid(args, epoch, loader, model, show): 61 | torch.backends.cudnn.benchmark = False 62 | 63 | model.eval() 64 | 65 | if get_rank() == 0: 66 | pbar = tqdm(loader, dynamic_ncols=True) 67 | 68 | else: 69 | pbar = loader 70 | 71 | intersect_sum = None 72 | union_sum = None 73 | correct_sum = 0 74 | total_sum = 0 75 | 76 | for i, (img, annot) in enumerate(pbar): 77 | img = img.to('cuda') 78 | annot = annot.to('cuda') 79 | _, out = model(img) 80 | _, pred = out.max(1) 81 | 82 | if get_rank() == 0 and i % 10 == 0: 83 | result = show(img[0], annot[0], 
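# rank 0 saves an (input | ground truth | prediction) panel every 10 validation batches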
pred[0]) 84 | result.save(f'sample/{str(epoch + 1).zfill(3)}-{str(i).zfill(4)}.png') 85 | 86 | pred = (annot > 0) * pred 87 | correct = (pred > 0) * (pred == annot) 88 | correct_sum += correct.sum().float().item() 89 | total_sum += (annot > 0).sum().float() 90 | 91 | for g, p, c in zip(annot, pred, correct): 92 | intersect, union = intersection_union(g, p, c, args.n_class) 93 | 94 | if intersect_sum is None: 95 | intersect_sum = intersect 96 | 97 | else: 98 | intersect_sum += intersect 99 | 100 | if union_sum is None: 101 | union_sum = union 102 | 103 | else: 104 | union_sum += union 105 | 106 | all_intersect = sum(all_gather(intersect_sum.to('cpu'))) 107 | all_union = sum(all_gather(union_sum.to('cpu'))) 108 | 109 | if get_rank() == 0: 110 | iou = all_intersect / (all_union + 1e-10) 111 | m_iou = iou.mean().item() 112 | 113 | pbar.set_description( 114 | f'acc: {correct_sum / total_sum:.5f}; mIoU: {m_iou:.5f}' 115 | ) 116 | 117 | 118 | def data_sampler(dataset, shuffle, distributed): 119 | if distributed: 120 | return DistributedSampler(dataset, shuffle=shuffle) 121 | 122 | if shuffle: 123 | return sampler.RandomSampler(dataset) 124 | 125 | else: 126 | return sampler.SequentialSampler(dataset) 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('--epoch', type=int, default=100) 132 | parser.add_argument('--batch', type=int, default=16) 133 | parser.add_argument('--size', type=int, default=520) 134 | parser.add_argument('--arch', type=str, default='vovnet39') 135 | parser.add_argument('--aux_weight', type=float, default=0.4) 136 | parser.add_argument('--n_class', type=int, default=150) 137 | parser.add_argument('--lr', type=float, default=2e-2) 138 | parser.add_argument('--l2', type=float, default=1e-4) 139 | parser.add_argument('--local_rank', type=int, default=0) 140 | parser.add_argument('path', metavar='PATH') 141 | 142 | args = parser.parse_args() 143 | 144 | n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 145 | args.distributed = n_gpu > 1 146 | 147 | if args.distributed: 148 | torch.cuda.set_device(args.local_rank) 149 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 150 | synchronize() 151 | 152 | img_mean = [0.485, 0.456, 0.406] 153 | img_std = [0.229, 0.224, 0.225] 154 | device = 'cuda' 155 | # torch.backends.cudnn.deterministic = True 156 | 157 | train_trans = transform.Compose( 158 | [ 159 | transform.RandomScale(0.5, 2.0), 160 | # transform.Resize(args.size, None), 161 | transform.RandomHorizontalFlip(), 162 | transform.RandomCrop(args.size), 163 | transform.RandomBrightness(0.04), 164 | transform.ToTensor(), 165 | transform.Normalize(img_mean, img_std), 166 | transform.Pad(args.size) 167 | ] 168 | ) 169 | 170 | valid_trans = transform.Compose( 171 | [transform.ToTensor(), transform.Normalize(img_mean, img_std)] 172 | ) 173 | 174 | train_set = ADE20K(args.path, 'train', train_trans) 175 | valid_set = ADE20K(args.path, 'valid', valid_trans) 176 | 177 | arch_map = {'vovnet39': vovnet39, 'vovnet57': vovnet57} 178 | backbone = arch_map[args.arch](pretrained=True) 179 | model = OCR(args.n_class + 1, backbone).to(device) 180 | 181 | if args.distributed: 182 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 183 | 184 | model = nn.parallel.DistributedDataParallel( 185 | model, 186 | device_ids=[args.local_rank], 187 | output_device=args.local_rank, 188 | broadcast_buffers=False, 189 | ) 190 | 191 | optimizer = optim.SGD( 192 | model.parameters(), 193 | lr=args.lr / 
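# the warm-up starts at lr_max / 25, matching CycleScheduler's default divider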
25,
194 |         momentum=0.9,
195 |         weight_decay=args.l2,
196 |         nesterov=True,
197 |     )
198 | 
199 |     max_iter = math.ceil(len(train_set) / (n_gpu * args.batch)) * args.epoch
200 | 
201 |     scheduler = CycleScheduler(
202 |         optimizer,
203 |         args.lr,
204 |         n_iter=max_iter,
205 |         warmup_proportion=0.01,
206 |         phase=('linear', 'poly'),
207 |     )
208 | 
209 |     train_loader = DataLoader(
210 |         train_set,
211 |         batch_size=args.batch,
212 |         num_workers=2,
213 |         sampler=data_sampler(train_set, shuffle=True, distributed=args.distributed),
214 |     )
215 |     valid_loader = DataLoader(
216 |         valid_set,
217 |         batch_size=args.batch,
218 |         num_workers=2,
219 |         sampler=data_sampler(valid_set, shuffle=False, distributed=args.distributed),
220 |         collate_fn=collate_data,
221 |     )
222 | 
223 |     colormap = get_colormap('color150.npy')
224 | 
225 |     def show_result(img, gt, pred):
226 |         return show_segmentation(img, gt, pred, img_mean, img_std, colormap)
227 | 
228 |     for i in range(args.epoch):
229 |         train(args, i, train_loader, model, optimizer, scheduler)
230 |         valid(args, i, valid_loader, model, show_result)
231 | 
232 |         if get_rank() == 0:
233 |             torch.save(  # the model has no .module wrapper on a single GPU
234 |                 {'model': (model.module if args.distributed else model).state_dict(), 'args': args},
235 |                 f'checkpoint/epoch-{str(i + 1).zfill(3)}.pt',
236 |             )
237 | 
238 | 
--------------------------------------------------------------------------------
/transform.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | import numpy as np
4 | from PIL import Image
5 | import torch
6 | from torch.nn.functional import pad
7 | from torchvision.transforms import functional as F
8 | 
9 | 
10 | class Compose:
11 |     def __init__(self, transforms):
12 |         self.transforms = transforms
13 | 
14 |     def __call__(self, img, target):
15 |         for t in self.transforms:
16 |             img, target = t(img, target)
17 | 
18 |         return img, target
19 | 
20 |     def __repr__(self):
21 |         format_str = self.__class__.__name__ + '('
22 |         for t in self.transforms:
23 |             format_str += '\n'
24 |             format_str += f'    {t}'
25 |         format_str += '\n)'
26 | 
27 |         return format_str
28 | 
29 | 
30 | class Resize:
31 |     def __init__(self, min_size, max_size):
32 |         if not isinstance(min_size, (list, tuple)):
33 |             min_size = (min_size,)
34 | 
35 |         self.min_size = min_size
36 |         self.max_size = max_size
37 | 
38 |     def get_size(self, img_size):
39 |         w, h = img_size
40 |         size = random.choice(self.min_size)
41 |         max_size = self.max_size
42 | 
43 |         if max_size is not None:
44 |             min_orig = float(min((w, h)))
45 |             max_orig = float(max((w, h)))
46 | 
47 |             if max_orig / min_orig * size > max_size:
48 |                 size = int(round(max_size * min_orig / max_orig))
49 | 
50 |         if (w <= h and w == size) or (h <= w and h == size):
51 |             return h, w
52 | 
53 |         if w < h:
54 |             ow = size
55 |             oh = int(size * h / w)
56 | 
57 |         else:
58 |             oh = size
59 |             ow = int(size * w / h)
60 | 
61 |         return oh, ow
62 | 
63 |     def __call__(self, img, target):
64 |         size = self.get_size(img.size)
65 |         img = F.resize(img, size)
66 |         target = F.resize(target, size, interpolation=Image.NEAREST)
67 | 
68 |         return img, target
69 | 
70 | 
71 | class RandomScale:
72 |     def __init__(self, min_scale, max_scale):
73 |         self.min_scale = min_scale
74 |         self.max_scale = max_scale
75 | 
76 |     def __call__(self, img, target):
77 |         w, h = img.size
78 |         scale = random.uniform(self.min_scale, self.max_scale)
79 |         h *= scale
80 |         w *= scale
81 |         size = (round(h), round(w))
82 | 
83 |         img = F.resize(img, size)
84 |         target = F.resize(target, size, interpolation=Image.NEAREST)
85 | 
86 |         return img, target
87 | 
88 | 
89 | class
RandomBrightness:
90 |     def __init__(self, factor):
91 |         self.factor = factor
92 | 
93 |     def __call__(self, img, target):
94 |         factor = random.uniform(-self.factor, self.factor)
95 |         img = F.adjust_brightness(img, 1 + factor)
96 | 
97 |         return img, target
98 | 
99 | 
100 | class RandomHorizontalFlip:
101 |     def __init__(self, p=0.5):
102 |         self.p = p
103 | 
104 |     def __call__(self, img, target):
105 |         if random.random() < self.p:
106 |             img = F.hflip(img)
107 |             target = F.hflip(target)
108 | 
109 |         return img, target
110 | 
111 | 
112 | class RandomCrop:
113 |     def __init__(self, size):
114 |         if not isinstance(size, (list, tuple)):
115 |             size = (size, size)
116 | 
117 |         self.size = size
118 | 
119 |     def __call__(self, img, target):
120 |         w, h = img.size
121 |         w_range = w - self.size[0]
122 |         h_range = h - self.size[1]
123 |         if w_range > 0:
124 |             left = random.randint(0, w_range)  # randint is inclusive on both ends
125 | 
126 |         else:
127 |             left = 0
128 | 
129 |         if h_range > 0:
130 |             top = random.randint(0, h_range)
131 | 
132 |         else:
133 |             top = 0
134 | 
135 |         height = min(h - top, self.size[1])
136 |         width = min(w - left, self.size[0])
137 | 
138 |         img = F.crop(img, top, left, height, width)
139 |         target = F.crop(target, top, left, height, width)
140 | 
141 |         return img, target
142 | 
143 | 
144 | class ToTensor:
145 |     def __call__(self, img, target):
146 |         target = torch.from_numpy(np.array(target, dtype=np.int64, copy=False))
147 |         return F.to_tensor(img), target
148 | 
149 | 
150 | class Pad:
151 |     def __init__(self, size):
152 |         self.size = size
153 | 
154 |     def __call__(self, img, target):
155 |         _, h, w = img.shape
156 | 
157 |         if h == self.size and w == self.size:
158 |             return img, target
159 | 
160 |         h_pad = self.size - h
161 |         w_pad = self.size - w
162 | 
163 |         img = pad(img, [0, w_pad, 0, h_pad])
164 |         target = pad(target, [0, w_pad, 0, h_pad])
165 | 
166 |         return img, target
167 | 
168 | 
169 | class Normalize:
170 |     def __init__(self, mean, std):
171 |         self.mean = mean
172 |         self.std = std
173 | 
174 |     def __call__(self, img, target):
175 |         img = F.normalize(img, mean=self.mean, std=self.std)
176 | 
177 |         return img, target
178 | 
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import torch
4 | from torch.nn import functional as F
5 | from torchvision.utils import make_grid
6 | 
7 | 
8 | @torch.no_grad()
9 | def intersection_union(gt, pred, correct, n_class):
10 |     intersect = pred * correct  # torch.histc below requires floating-point input
11 | 
12 |     area_intersect = torch.histc(intersect.float(), bins=n_class, min=1, max=n_class)
13 |     area_pred = torch.histc(pred.float(), bins=n_class, min=1, max=n_class)
14 |     area_gt = torch.histc(gt.float(), bins=n_class, min=1, max=n_class)
15 | 
16 |     # intersect = intersect.detach().to('cpu').numpy()
17 |     # pred = pred.detach().to('cpu').numpy()
18 |     # gt = gt.detach().to('cpu').numpy()
19 |     # area_intersect, _ = np.histogram(intersect, bins=n_class, range=(1, n_class))
20 |     # area_pred, _ = np.histogram(pred, bins=n_class, range=(1, n_class))
21 |     # area_gt, _ = np.histogram(gt, bins=n_class, range=(1, n_class))
22 | 
23 |     area_union = area_pred + area_gt - area_intersect
24 | 
25 |     return area_intersect, area_union
26 | 
27 | 
28 | def get_colormap(filename):
29 |     colors = np.load(filename)
30 |     colors = np.pad(colors, [(1, 0), (0, 0)], 'constant', constant_values=0)
31 |     colors = torch.from_numpy(colors).type(torch.float32)
32 | 
33 |     return colors
34 | 
35 | 
36 | @torch.no_grad()
37 | def
show_segmentation(img, gt, pred, mean, std, colormap): 38 | colormap = colormap.to(img.device) 39 | gt = F.embedding(gt, colormap).permute(2, 0, 1).div(255) 40 | pred = F.embedding(pred, colormap).permute(2, 0, 1).div(255) 41 | mean = torch.as_tensor(mean, dtype=torch.float32, device=img.device) 42 | std = torch.as_tensor(std, dtype=torch.float32, device=img.device) 43 | img = img * std[:, None, None] + mean[:, None, None] 44 | grid = torch.stack([img, gt, pred], 0) 45 | grid = make_grid(grid, nrow=3) 46 | grid = ( 47 | grid.mul_(255) 48 | .add_(0.5) 49 | .clamp_(0, 255) 50 | .permute(1, 2, 0) 51 | .to('cpu', torch.uint8) 52 | .numpy() 53 | ) 54 | img = Image.fromarray(grid) 55 | 56 | return img 57 | -------------------------------------------------------------------------------- /vovnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | 7 | __all__ = ['VoVNet', 'vovnet27_slim', 'vovnet39', 'vovnet57'] 8 | 9 | 10 | model_urls = { 11 | 'vovnet39': './vovnet39_torchvision.pth', 12 | 'vovnet57': './vovnet57_torchvision.pth', 13 | } 14 | 15 | 16 | def conv3x3( 17 | in_channels, 18 | out_channels, 19 | module_name, 20 | postfix, 21 | stride=1, 22 | groups=1, 23 | kernel_size=3, 24 | padding=1, 25 | dilation=1, 26 | ): 27 | """3x3 convolution with padding""" 28 | if dilation != 1: 29 | padding = dilation 30 | 31 | return [ 32 | ( 33 | '{}_{}/conv'.format(module_name, postfix), 34 | nn.Conv2d( 35 | in_channels, 36 | out_channels, 37 | kernel_size=kernel_size, 38 | stride=stride, 39 | padding=padding, 40 | groups=groups, 41 | dilation=dilation, 42 | bias=False, 43 | ), 44 | ), 45 | ('{}_{}/norm'.format(module_name, postfix), nn.BatchNorm2d(out_channels)), 46 | ('{}_{}/relu'.format(module_name, postfix), nn.ReLU(inplace=True)), 47 | ] 48 | 49 | 50 | def conv1x1( 51 | in_channels, 52 | out_channels, 53 | module_name, 54 | postfix, 55 | stride=1, 56 | groups=1, 57 | kernel_size=1, 58 | padding=0, 59 | ): 60 | """1x1 convolution""" 61 | return [ 62 | ( 63 | '{}_{}/conv'.format(module_name, postfix), 64 | nn.Conv2d( 65 | in_channels, 66 | out_channels, 67 | kernel_size=kernel_size, 68 | stride=stride, 69 | padding=padding, 70 | groups=groups, 71 | bias=False, 72 | ), 73 | ), 74 | ('{}_{}/norm'.format(module_name, postfix), nn.BatchNorm2d(out_channels)), 75 | ('{}_{}/relu'.format(module_name, postfix), nn.ReLU(inplace=True)), 76 | ] 77 | 78 | 79 | class _OSA_module(nn.Module): 80 | def __init__( 81 | self, 82 | in_ch, 83 | stage_ch, 84 | concat_ch, 85 | layer_per_block, 86 | module_name, 87 | identity=False, 88 | dilation=1, 89 | ): 90 | super(_OSA_module, self).__init__() 91 | 92 | self.identity = identity 93 | self.layers = nn.ModuleList() 94 | in_channel = in_ch 95 | for i in range(layer_per_block): 96 | self.layers.append( 97 | nn.Sequential( 98 | OrderedDict( 99 | conv3x3(in_channel, stage_ch, module_name, i, dilation=dilation) 100 | ) 101 | ) 102 | ) 103 | in_channel = stage_ch 104 | 105 | # feature aggregation 106 | in_channel = in_ch + layer_per_block * stage_ch 107 | self.concat = nn.Sequential( 108 | OrderedDict(conv1x1(in_channel, concat_ch, module_name, 'concat')) 109 | ) 110 | 111 | def forward(self, x): 112 | identity_feat = x 113 | output = [] 114 | output.append(x) 115 | for layer in self.layers: 116 | x = layer(x) 117 | output.append(x) 118 | 119 | x = torch.cat(output, dim=1) 120 | xt = self.concat(x) 121 | 122 | if 
self.identity: 123 | xt = xt + identity_feat 124 | 125 | return xt 126 | 127 | 128 | class _OSA_stage(nn.Sequential): 129 | def __init__( 130 | self, 131 | in_ch, 132 | stage_ch, 133 | concat_ch, 134 | block_per_stage, 135 | layer_per_block, 136 | stage_num, 137 | dilation=1, 138 | ): 139 | super(_OSA_stage, self).__init__() 140 | 141 | if not stage_num == 2 and dilation == 1: 142 | self.add_module( 143 | 'Pooling', nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) 144 | ) 145 | 146 | module_name = f'OSA{stage_num}_1' 147 | self.add_module( 148 | module_name, 149 | _OSA_module( 150 | in_ch, 151 | stage_ch, 152 | concat_ch, 153 | layer_per_block, 154 | module_name, 155 | dilation=dilation, 156 | ), 157 | ) 158 | for i in range(block_per_stage - 1): 159 | module_name = f'OSA{stage_num}_{i+2}' 160 | self.add_module( 161 | module_name, 162 | _OSA_module( 163 | concat_ch, 164 | stage_ch, 165 | concat_ch, 166 | layer_per_block, 167 | module_name, 168 | identity=True, 169 | dilation=dilation, 170 | ), 171 | ) 172 | 173 | 174 | class VoVNet(nn.Module): 175 | def __init__( 176 | self, 177 | config_stage_ch, 178 | config_concat_ch, 179 | block_per_stage, 180 | layer_per_block, 181 | num_classes=1000, 182 | ): 183 | super(VoVNet, self).__init__() 184 | 185 | # Stem module 186 | stem = conv3x3(3, 64, 'stem', '1', 2) 187 | stem += conv3x3(64, 64, 'stem', '2', 1) 188 | stem += conv3x3(64, 128, 'stem', '3', 2) 189 | self.add_module('stem', nn.Sequential(OrderedDict(stem))) 190 | 191 | stem_out_ch = [128] 192 | in_ch_list = stem_out_ch + config_concat_ch[:-1] 193 | self.stage_names = [] 194 | for i in range(4): # num_stages 195 | name = 'stage%d' % (i + 2) 196 | self.stage_names.append(name) 197 | 198 | if i == 2: 199 | dilation = 2 200 | 201 | elif i == 3: 202 | dilation = 4 203 | 204 | else: 205 | dilation = 1 206 | 207 | self.add_module( 208 | name, 209 | _OSA_stage( 210 | in_ch_list[i], 211 | config_stage_ch[i], 212 | config_concat_ch[i], 213 | block_per_stage[i], 214 | layer_per_block, 215 | i + 2, 216 | dilation=dilation, 217 | ), 218 | ) 219 | 220 | # self.classifier = nn.Linear(config_concat_ch[-1], num_classes) 221 | 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d): 224 | nn.init.kaiming_normal_(m.weight) 225 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 226 | nn.init.constant_(m.weight, 1) 227 | nn.init.constant_(m.bias, 0) 228 | elif isinstance(m, nn.Linear): 229 | nn.init.constant_(m.bias, 0) 230 | 231 | def forward(self, x): 232 | features = [] 233 | x = self.stem[:6](x) 234 | features.append(x) 235 | x = self.stem[6:](x) 236 | for name in self.stage_names: 237 | x = getattr(self, name)(x) 238 | features.append(x) 239 | # x = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1) 240 | # x = self.classifier(x) 241 | # print([f.shape for f in features]) 242 | 243 | return features 244 | 245 | 246 | def _vovnet( 247 | arch, 248 | config_stage_ch, 249 | config_concat_ch, 250 | block_per_stage, 251 | layer_per_block, 252 | pretrained, 253 | progress, 254 | **kwargs, 255 | ): 256 | model = VoVNet( 257 | config_stage_ch, config_concat_ch, block_per_stage, layer_per_block, **kwargs 258 | ) 259 | if pretrained: 260 | state_dict = torch.load(model_urls[arch]) 261 | new_dict = OrderedDict() 262 | 263 | for k, v in state_dict.items(): 264 | key = k.replace('module.', '') 265 | new_dict[key] = v 266 | 267 | model.load_state_dict(new_dict, strict=False) 268 | return model 269 | 270 | 271 | def vovnet57(pretrained=False, progress=True, **kwargs): 272 | r"""Constructs a VoVNet-57 
model as described in
273 |     `"An Energy and GPU-Computation Efficient Backbone Network for Real-Time
274 |     Object Detection" <https://arxiv.org/abs/1904.09730>`_.
275 |     Args:
276 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
277 |         progress (bool): If True, displays a progress bar of the download to stderr
278 |     """
279 |     return _vovnet(
280 |         'vovnet57',
281 |         [128, 160, 192, 224],
282 |         [256, 512, 768, 1024],
283 |         [1, 1, 4, 3],
284 |         5,
285 |         pretrained,
286 |         progress,
287 |         **kwargs,
288 |     )
289 | 
290 | 
291 | def vovnet39(pretrained=False, progress=True, **kwargs):
292 |     r"""Constructs a VoVNet-39 model as described in
293 |     `"An Energy and GPU-Computation Efficient Backbone Network for Real-Time
294 |     Object Detection" <https://arxiv.org/abs/1904.09730>`_.
295 |     Args:
296 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
297 |         progress (bool): If True, displays a progress bar of the download to stderr
298 |     """
299 |     return _vovnet(
300 |         'vovnet39',
301 |         [128, 160, 192, 224],
302 |         [256, 512, 768, 1024],
303 |         [1, 1, 2, 2],
304 |         5,
305 |         pretrained,
306 |         progress,
307 |         **kwargs,
308 |     )
309 | 
310 | 
311 | def vovnet27_slim(pretrained=False, progress=True, **kwargs):
312 |     r"""Constructs a VoVNet-27-slim model as described in
313 |     `"An Energy and GPU-Computation Efficient Backbone Network for Real-Time
314 |     Object Detection" <https://arxiv.org/abs/1904.09730>`_.
315 |     Args:
316 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
317 |         progress (bool): If True, displays a progress bar of the download to stderr
318 |     """
319 |     return _vovnet(
320 |         'vovnet27_slim',
321 |         [64, 80, 96, 112],
322 |         [128, 256, 384, 512],
323 |         [1, 1, 1, 1],
324 |         5,
325 |         pretrained,
326 |         progress,
327 |         **kwargs,
328 |     )
329 | 
330 | 
--------------------------------------------------------------------------------
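A minimal single-image inference sketch assembled from the modules above, since the README only covers training. The checkpoint name, image path, and output filename are hypothetical placeholders; n_class = 150 and the normalization constants mirror the defaults in train.py, and in eval mode OCR.forward returns a pair ({}, logits).

# illustrative sketch, not part of the repository
import torch
from PIL import Image
from torchvision.transforms import functional as TF

from model import OCR
from util import get_colormap
from vovnet import vovnet39

CKPT = 'checkpoint/epoch-100.pt'  # hypothetical checkpoint path
IMG = 'input.jpg'                 # hypothetical image path
N_CLASS = 150                     # ADE20K classes; index 0 is the ignore label

backbone = vovnet39(pretrained=False)  # weights come from the checkpoint below
model = OCR(N_CLASS + 1, backbone).to('cuda')
model.load_state_dict(torch.load(CKPT, map_location='cuda')['model'])
model.eval()

img = Image.open(IMG).convert('RGB')
x = TF.normalize(
    TF.to_tensor(img), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

with torch.no_grad():
    _, out = model(x.unsqueeze(0).to('cuda'))  # eval mode returns ({}, logits)

pred = out.argmax(1)[0]  # H x W tensor of class indices
colormap = get_colormap('color150.npy').to(pred.device)
color = torch.nn.functional.embedding(pred, colormap)  # H x W x 3, 0-255 floats
Image.fromarray(color.to('cpu', torch.uint8).numpy()).save('prediction.png')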