├── __init__.py ├── teaser.png ├── setup.py ├── zs3 ├── embeddings │ ├── pascal │ │ └── w2c │ │ │ └── norm_embed_arr_300.pkl │ └── context │ │ └── pascalcontext_class_w2c.npy ├── dataloaders │ ├── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── combine_dbs.py │ │ ├── pascal.py │ │ ├── sbd.py │ │ └── context.py │ ├── __init__.py │ ├── utils.py │ └── custom_transforms.py ├── data │ ├── pascal_context_unseen_classes.txt │ └── pascal_voc_unseen_classes.txt ├── modeling │ ├── backbone │ │ ├── __init__.py │ │ └── resnet.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── replicate.py │ │ ├── comm.py │ │ └── batchnorm.py │ ├── gmmn.py │ ├── decoder.py │ ├── deeplab.py │ └── aspp.py ├── utils │ ├── loss_GMMN.py │ ├── calculate_weights.py │ ├── summaries.py │ ├── lr_scheduler.py │ ├── saver.py │ ├── loss.py │ └── metrics.py ├── exp_data.py ├── base_trainer.py ├── parsing.py ├── train_context.py ├── train_pascal.py ├── eval_context.py ├── eval_pascal.py └── train_context_GMMN.py ├── Dockerfile ├── .gitignore ├── LICENSE └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/valeoai/ZS3/HEAD/teaser.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | setup(name="ZS3", packages=find_packages()) 5 | -------------------------------------------------------------------------------- /zs3/embeddings/pascal/w2c/norm_embed_arr_300.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/valeoai/ZS3/HEAD/zs3/embeddings/pascal/w2c/norm_embed_arr_300.pkl -------------------------------------------------------------------------------- /zs3/embeddings/context/pascalcontext_class_w2c.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/valeoai/ZS3/HEAD/zs3/embeddings/context/pascalcontext_class_w2c.npy -------------------------------------------------------------------------------- /zs3/dataloaders/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .context import CONTEXT_DIR 2 | from .pascal import PASCAL_DIR 3 | from .sbd import SBD_DIR 4 | 5 | 6 | DATASETS_DIRS = {"pascal": PASCAL_DIR, "sbd": SBD_DIR, "context": CONTEXT_DIR} 7 | -------------------------------------------------------------------------------- /zs3/data/pascal_context_unseen_classes.txt: -------------------------------------------------------------------------------- 1 | 2 unseen: cow,motorbike 2 | 4 unseen: cow,motorbike,sofa,cat 3 | 6 unseen: cow,motorbike,sofa,cat,boat,fence 4 | 8 unseen: cow,motorbike,sofa,cat,boat,fence,bird,tvmonitor 5 | 10 unseen: cow,motorbike,sofa,cat,boat,fence,bird,tvmonitor,keyboard,aeroplane -------------------------------------------------------------------------------- /zs3/data/pascal_voc_unseen_classes.txt: -------------------------------------------------------------------------------- 1 | 2 unseen: cow,motorbike 2 | 4 unseen: cow,motorbike,airplane,sofa 3 | 6 unseen: cow,motorbike,airplane,sofa,cat,tv 4 | 8 unseen: 
cow,motorbike,airplane,sofa,cat,tv,train,bottle 5 | 10 unseen: cow,motorbike,airplane,sofa,cat,tv,train,bottle,chair,potted plant 6 | -------------------------------------------------------------------------------- /zs3/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from zs3.modeling.backbone import resnet 2 | 3 | 4 | def build_backbone( 5 | output_stride, BatchNorm, pretrained=True, imagenet_pretrained_path="" 6 | ): 7 | return resnet.ResNet101( 8 | output_stride, 9 | BatchNorm, 10 | pretrained=pretrained, 11 | imagenet_pretrained_path=imagenet_pretrained_path, 12 | ) 13 | -------------------------------------------------------------------------------- /zs3/modeling/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # File : __init__.py 2 | # Author : Jiayuan Mao 3 | # Email : maojiayuan@gmail.com 4 | # Date : 27/01/2018 5 | # 6 | # This file is part of Synchronized-BatchNorm-PyTorch. 7 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 8 | # Distributed under MIT License. 9 | 10 | from .batchnorm import SynchronizedBatchNorm2d 11 | from .replicate import patch_replication_callback 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda:latest 2 | 3 | RUN conda config --set always_yes yes 4 | RUN conda install python=3.7 5 | 6 | RUN conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 7 | RUN conda install -c menpo opencv 8 | RUN pip install tensorboardX scikit-image tqdm pyyaml easydict future 9 | 10 | COPY ./ /ZS3 11 | RUN pip install -e /ZS3 12 | 13 | WORKDIR /ZS3/zs3 14 | ENV NVIDIA_VISIBLE_DEVICES all 15 | ENV NVIDIA_DRIVER_CAPABILITIES compute,utility 16 | -------------------------------------------------------------------------------- /zs3/utils/loss_GMMN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class GMMNLoss: 5 | def __init__(self, sigma=[2, 5, 10, 20, 40, 80], cuda=False): 6 | self.sigma = sigma 7 | self.cuda = cuda 8 | 9 | def build_loss(self): 10 | return self.moment_loss 11 | 12 | def get_scale_matrix(self, M, N): 13 | s1 = torch.ones((N, 1)) * 1.0 / N 14 | s2 = torch.ones((M, 1)) * -1.0 / M 15 | if self.cuda: 16 | s1, s2 = s1.cuda(), s2.cuda() 17 | return torch.cat((s1, s2), 0) 18 | 19 | def moment_loss(self, gen_samples, x): 20 | X = torch.cat((gen_samples, x), 0) 21 | XX = torch.matmul(X, X.t()) 22 | X2 = torch.sum(X * X, 1, keepdim=True) 23 | exp = XX - 0.5 * X2 - 0.5 * X2.t() 24 | M = gen_samples.size()[0] 25 | N = x.size()[0] 26 | s = self.get_scale_matrix(M, N) 27 | S = torch.matmul(s, s.t()) 28 | 29 | loss = 0 30 | for v in self.sigma: 31 | kernel_val = torch.exp(exp / v) 32 | loss += torch.sum(S * kernel_val) 33 | 34 | loss = torch.sqrt(loss) 35 | return loss 36 | -------------------------------------------------------------------------------- /zs3/utils/calculate_weights.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | from zs3.dataloaders.datasets import DATASETS_DIRS 5 | 6 | 7 | def calculate_weigths_labels(dataset, dataloader, num_classes): 8 | # Create an instance from the data loader 9 | z = np.zeros((num_classes,)) 10 | # Initialize tqdm 11 | tqdm_batch = tqdm(dataloader) 12 | print("Calculating 
classes weights") 13 | for sample in tqdm_batch: 14 | y = sample["label"] 15 | y = y.detach().cpu().numpy() 16 | mask = (y >= 0) & (y < num_classes) 17 | labels = y[mask].astype(np.uint8) 18 | count_l = np.bincount(labels, minlength=num_classes) 19 | z += count_l 20 | tqdm_batch.close() 21 | total_frequency = np.sum(z) 22 | class_weights = [] 23 | for frequency in z: 24 | class_weight = 1 / (np.log(1.02 + (frequency / total_frequency))) 25 | class_weights.append(class_weight) 26 | ret = np.array(class_weights) 27 | classes_weights_path = DATASETS_DIRS[dataset] / dataset + "_classes_weights.npy" 28 | np.save(classes_weights_path, ret) 29 | 30 | return ret 31 | -------------------------------------------------------------------------------- /zs3/exp_data.py: -------------------------------------------------------------------------------- 1 | CLASSES_NAMES = [ 2 | "background", # class 0 3 | "aeroplane", # class 1 4 | "bicycle", # class 2 5 | "bird", # class 3 6 | "boat", # class 4 7 | "bottle", # class 5 8 | "bus", # class 6 9 | "car", # class 7 10 | "cat", # class 8 11 | "chair", # class 9 12 | "cow", # class 10 13 | "table", # class 11 14 | "dog", # class 12 15 | "horse", # class 13 16 | "motorbike", # class 14 17 | "person", # class 15 18 | "pottedplant", # class 16 19 | "sheep", # class 17 20 | "sofa", # class 18 21 | "train", # class 19 22 | "tvmonitor", # class 20 23 | "bag", # class 21 24 | "bed", # class 22 25 | "bench", # class 23 26 | "book", # class 24 27 | "building", # class 25 28 | "cabinet", # class 26 29 | "ceiling", # class 27 30 | "cloth", # class 28 31 | "computer", # class 29 32 | "cup", # class 30 33 | "door", # class 31 34 | "fence", # class 32 35 | "floor", # class 33 36 | "flower", # class 34 37 | "food", # class 35 38 | "grass", # class 36 39 | "ground", # class 37 40 | "keyboard", # class 38 41 | "light", # class 39 42 | "mountain", # class 40 43 | "mouse", # class 41 44 | "curtain", # class 42 45 | "platform", # class 43 46 | "sign", # class 44 47 | "plate", # class 45 48 | "road", # class 46 49 | "rock", # class 47 50 | "shelves", # class 48 51 | "sidewalk", # class 49 52 | "sky", # class 50 53 | "snow", # class 51 54 | "bedclothes", # class 52 55 | "track", # class 53 56 | "tree", # class 54 57 | "truck", # class 55 58 | "wall", # class 56 59 | "water", # class 57 60 | "window", # class 58 61 | "wood", # class 59 62 | ] 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea/ 106 | 107 | zs3/checkpoint 108 | run/ 109 | data/VOC2012 110 | zs3/data/VOC2012 111 | -------------------------------------------------------------------------------- /zs3/dataloaders/datasets/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | import pickle 3 | import torch 4 | import numpy as np 5 | from pathlib import Path 6 | 7 | 8 | class BaseDataset(data.Dataset): 9 | def __init__( 10 | self, 11 | args, 12 | base_dir, 13 | split, 14 | load_embedding, 15 | w2c_size, 16 | weak_label, 17 | unseen_classes_idx_weak, 18 | transform, 19 | ): 20 | super().__init__() 21 | self.args = args 22 | self._base_dir = Path(base_dir) 23 | self.split = split 24 | self.load_embedding = load_embedding 25 | self.w2c_size = w2c_size 26 | self.embeddings = None 27 | if self.load_embedding: 28 | self.init_embeddings() 29 | self.images = [] 30 | self.weak_label = weak_label 31 | self.unseen_classes_idx_weak = unseen_classes_idx_weak 32 | self.transform = transform 33 | 34 | def __len__(self): 35 | return len(self.images) 36 | 37 | def init_embeddings(self): 38 | raise NotImplementedError 39 | 40 | def make_embeddings(self, embed_arr): 41 | self.embeddings = torch.nn.Embedding(embed_arr.shape[0], embed_arr.shape[1]) 42 | self.embeddings.weight.requires_grad = False 43 | self.embeddings.weight.data.copy_(torch.from_numpy(embed_arr)) 44 | 45 | def get_embeddings(self, sample): 46 | mask = sample["label"] == 255 47 | sample["label"][mask] = 0 48 | lbl_vec = self.embeddings(sample["label"].long()).data 49 | lbl_vec = lbl_vec.permute(2, 0, 1) 50 | sample["label"][mask] = 255 51 | sample["label_emb"] = lbl_vec 52 | 53 | 54 | def load_obj(name): 55 | with open(name + ".pkl", "rb") as f: 56 | return pickle.load(f, encoding="latin-1") 57 | 58 | 59 | def lbl_contains_unseen(lbl, unseen): 60 | unseen_pixel_mask = np.in1d(lbl.ravel(), unseen) 61 | if np.sum(unseen_pixel_mask) > 0: # ignore images with any train_unseen pixels 62 | return True 63 | return False 64 | -------------------------------------------------------------------------------- /zs3/dataloaders/datasets/combine_dbs.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | 4 | class CombineDBs(data.Dataset): 5 | NUM_CLASSES = 21 6 | 7 | def __init__(self, dataloaders, excluded=None): 8 | 
self.dataloaders = dataloaders 9 | self.excluded = excluded 10 | self.im_ids = [] 11 | 12 | # Combine object lists 13 | for dl in dataloaders: 14 | for elem in dl.im_ids: 15 | if elem not in self.im_ids: 16 | self.im_ids.append(elem) 17 | 18 | # Exclude 19 | if excluded: 20 | for dl in excluded: 21 | for elem in dl.im_ids: 22 | if elem in self.im_ids: 23 | self.im_ids.remove(elem) 24 | 25 | # Get object pointers 26 | self.cat_list = [] 27 | new_im_ids = [] 28 | num_images = 0 29 | for ii, dl in enumerate(dataloaders): 30 | for jj, curr_im_id in enumerate(dl.im_ids): 31 | if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids): 32 | num_images += 1 33 | new_im_ids.append(curr_im_id) 34 | self.cat_list.append({"db_ii": ii, "cat_ii": jj}) 35 | 36 | self.im_ids = new_im_ids 37 | print(f"Combined number of images: {num_images:d}") 38 | 39 | def __getitem__(self, index): 40 | 41 | _db_ii = self.cat_list[index]["db_ii"] 42 | _cat_ii = self.cat_list[index]["cat_ii"] 43 | sample = self.dataloaders[_db_ii].__getitem__(_cat_ii) 44 | 45 | if "meta" in sample.keys(): 46 | sample["meta"]["db"] = str(self.dataloaders[_db_ii]) 47 | 48 | return sample 49 | 50 | def __len__(self): 51 | return len(self.cat_list) 52 | 53 | def __str__(self): 54 | include_db = [str(db) for db in self.dataloaders] 55 | exclude_db = [str(db) for db in self.excluded] 56 | return ( 57 | "Included datasets:" 58 | + str(include_db) 59 | + "\n" 60 | + "Excluded datasets:" 61 | + str(exclude_db) 62 | ) 63 | -------------------------------------------------------------------------------- /zs3/base_trainer.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | 4 | class BaseTrainer: 5 | def training(self, epoch): 6 | train_loss = 0.0 7 | self.model.train() 8 | tbar = tqdm(self.train_loader) 9 | num_img_tr = len(self.train_loader) 10 | for i, sample in enumerate(tbar): 11 | if len(sample["image"]) > 1: 12 | image, target = sample["image"], sample["label"] 13 | if self.args.cuda: 14 | image, target = image.cuda(), target.cuda() 15 | self.scheduler(self.optimizer, i, epoch, self.best_pred) 16 | self.optimizer.zero_grad() 17 | output = self.model(image) 18 | loss = self.criterion(output, target) 19 | loss.backward() 20 | self.optimizer.step() 21 | train_loss += loss.item() 22 | tbar.set_description("Train loss: %.3f" % (train_loss / (i + 1))) 23 | self.writer.add_scalar( 24 | "train/total_loss_iter", loss.item(), i + num_img_tr * epoch 25 | ) 26 | 27 | # Show 10 * 3 inference results each epoch 28 | if i % (num_img_tr // 10) == 0: 29 | global_step = i + num_img_tr * epoch 30 | self.summary.visualize_image( 31 | self.writer, 32 | self.args.dataset, 33 | image, 34 | target, 35 | output, 36 | global_step, 37 | ) 38 | 39 | self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch) 40 | print( 41 | "[Epoch: %d, numImages: %5d]" 42 | % (epoch, i * self.args.batch_size + image.data.shape[0]) 43 | ) 44 | print(f"Loss: {train_loss:.3f}") 45 | 46 | if self.args.no_val: 47 | # save checkpoint every epoch 48 | is_best = False 49 | self.saver.save_checkpoint( 50 | { 51 | "epoch": epoch + 1, 52 | "state_dict": self.model.module.state_dict(), 53 | "optimizer": self.optimizer.state_dict(), 54 | "best_pred": self.best_pred, 55 | }, 56 | is_best, 57 | ) 58 | -------------------------------------------------------------------------------- /zs3/modeling/gmmn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
pygcn.layers import GraphConvolution 3 | from torch import nn 4 | 5 | 6 | class GMMNnetwork(nn.Module): 7 | def __init__( 8 | self, 9 | noise_dim, 10 | embed_dim, 11 | hidden_size, 12 | feature_dim, 13 | semantic_reconstruction=False, 14 | ): 15 | super().__init__() 16 | 17 | def block(in_feat, out_feat): 18 | layers = [nn.Linear(in_feat, out_feat)] 19 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 20 | layers.append(nn.Dropout(p=0.5)) 21 | return layers 22 | 23 | def init_weights(m): 24 | if type(m) == nn.Linear: 25 | torch.nn.init.xavier_uniform_(m.weight) 26 | m.bias.data.fill_(0.01) 27 | 28 | if hidden_size: 29 | self.model = nn.Sequential( 30 | *block(noise_dim + embed_dim, hidden_size), 31 | nn.Linear(hidden_size, feature_dim), 32 | ) 33 | else: 34 | self.model = nn.Linear(noise_dim + embed_dim, feature_dim) 35 | 36 | self.model.apply(init_weights) 37 | self.semantic_reconstruction = semantic_reconstruction 38 | if self.semantic_reconstruction: 39 | self.semantic_reconstruction_layer = nn.Linear( 40 | feature_dim, noise_dim + embed_dim 41 | ) 42 | 43 | def forward(self, embd, noise): 44 | features = self.model(torch.cat((embd, noise), 1)) 45 | if self.semantic_reconstruction: 46 | semantic = self.semantic_reconstruction_layer(features) 47 | return features, semantic 48 | else: 49 | return features 50 | 51 | 52 | class GMMNnetwork_GCN(nn.Module): 53 | def __init__(self, noise_dim=300, embed_dim=300, hidden_size=256, feature_dim=256): 54 | super().__init__() 55 | self.gcn1 = GraphConvolution(noise_dim + embed_dim, hidden_size) 56 | self.relu = nn.LeakyReLU(0.2) 57 | self.dropout = nn.Dropout(p=0.5) 58 | self.gcn2 = GraphConvolution(hidden_size, feature_dim) 59 | for m in self.modules(): 60 | if isinstance(m, GraphConvolution): 61 | torch.nn.init.xavier_uniform_(m.weight) 62 | m.bias.data.fill_(0.01) 63 | 64 | def forward(self, embd, noise, adj_mat): 65 | x = self.gcn1(torch.cat((embd, noise), 1), adj_mat) 66 | x = self.dropout(self.relu(x)) 67 | return self.gcn2(x, adj_mat) 68 | -------------------------------------------------------------------------------- /zs3/modeling/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # File : replicate.py 2 | # Author : Jiayuan Mao 3 | # Email : maojiayuan@gmail.com 4 | # Date : 27/01/2018 5 | # 6 | # This file is part of Synchronized-BatchNorm-PyTorch. 7 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 8 | # Distributed under MIT License. 9 | 10 | import functools 11 | 12 | from torch.nn.parallel.data_parallel import DataParallel 13 | 14 | __all__ = [ 15 | "CallbackContext", 16 | "execute_replication_callbacks", 17 | "patch_replication_callback", 18 | ] 19 | 20 | 21 | class CallbackContext: 22 | pass 23 | 24 | 25 | def execute_replication_callbacks(modules): 26 | """ 27 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 28 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 29 | Note that, as all modules are isomorphism, we assign each sub-module with a context 30 | (shared among multiple copies of this module on different devices). 31 | Through this context, different copies can share some information. 32 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 33 | of any slave copies. 
34 | """ 35 | master_copy = modules[0] 36 | nr_modules = len(list(master_copy.modules())) 37 | ctxs = [CallbackContext() for _ in range(nr_modules)] 38 | 39 | for i, module in enumerate(modules): 40 | for j, m in enumerate(module.modules()): 41 | if hasattr(m, "__data_parallel_replicate__"): 42 | m.__data_parallel_replicate__(ctxs[j], i) 43 | 44 | 45 | def patch_replication_callback(data_parallel): 46 | """ 47 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 48 | Useful when you have customized `DataParallel` implementation. 49 | Examples: 50 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 51 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 52 | > patch_replication_callback(sync_bn) 53 | # this is equivalent to 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | """ 57 | 58 | assert isinstance(data_parallel, DataParallel) 59 | 60 | old_replicate = data_parallel.replicate 61 | 62 | @functools.wraps(old_replicate) 63 | def new_replicate(module, device_ids): 64 | modules = old_replicate(module, device_ids) 65 | execute_replication_callbacks(modules) 66 | return modules 67 | 68 | data_parallel.replicate = new_replicate 69 | -------------------------------------------------------------------------------- /zs3/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from zs3.dataloaders.datasets import combine_dbs, pascal, sbd, context 4 | 5 | 6 | def make_data_loader( 7 | args, 8 | transform=True, 9 | load_embedding=None, 10 | w2c_size=300, 11 | weak_label=False, 12 | unseen_classes_idx_weak=[], 13 | **kwargs, 14 | ): 15 | if args.dataset == "pascal": 16 | train_set = pascal.VOCSegmentation( 17 | args, 18 | transform=transform, 19 | split="train", 20 | load_embedding=load_embedding, 21 | w2c_size=w2c_size, 22 | weak_label=weak_label, 23 | unseen_classes_idx_weak=unseen_classes_idx_weak, 24 | ) 25 | val_set = pascal.VOCSegmentation( 26 | args, split="val", load_embedding=load_embedding, w2c_size=w2c_size 27 | ) 28 | if args.use_sbd: 29 | sbd_train = sbd.SBDSegmentation( 30 | args, 31 | transform=transform, 32 | split=["train_noval"], 33 | load_embedding=load_embedding, 34 | w2c_size=w2c_size, 35 | weak_label=weak_label, 36 | unseen_classes_idx_weak=unseen_classes_idx_weak, 37 | ) 38 | train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set]) 39 | 40 | num_class = train_set.NUM_CLASSES 41 | train_loader = DataLoader( 42 | train_set, batch_size=args.batch_size, shuffle=True, **kwargs 43 | ) 44 | val_loader = DataLoader( 45 | val_set, batch_size=args.test_batch_size, shuffle=False, **kwargs 46 | ) 47 | test_loader = None 48 | return train_loader, val_loader, test_loader, num_class 49 | 50 | elif args.dataset == "context": 51 | train_set = context.ContextSegmentation( 52 | args, 53 | transform=transform, 54 | split="train", 55 | load_embedding=load_embedding, 56 | w2c_size=w2c_size, 57 | weak_label=weak_label, 58 | unseen_classes_idx_weak=unseen_classes_idx_weak, 59 | ) 60 | val_set = context.ContextSegmentation( 61 | args, split="val", load_embedding=load_embedding, w2c_size=w2c_size 62 | ) 63 | num_class = train_set.NUM_CLASSES 64 | train_loader = DataLoader( 65 | train_set, batch_size=args.batch_size, shuffle=True, **kwargs 66 | ) 67 | val_loader = DataLoader( 68 | val_set, batch_size=args.test_batch_size, shuffle=False, **kwargs 69 
| ) 70 | test_loader = None 71 | return train_loader, val_loader, test_loader, num_class 72 | else: 73 | raise NotImplementedError 74 | -------------------------------------------------------------------------------- /zs3/utils/summaries.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tensorboardX import SummaryWriter 5 | from torchvision.utils import make_grid 6 | 7 | from zs3.dataloaders.utils import decode_seg_map_sequence 8 | 9 | 10 | class TensorboardSummary: 11 | def __init__(self, directory): 12 | self.directory = directory 13 | 14 | def create_summary(self): 15 | writer = SummaryWriter(os.path.join(self.directory)) 16 | return writer 17 | 18 | def visualize_image( 19 | self, 20 | writer, 21 | dataset, 22 | image, 23 | target, 24 | output, 25 | global_step, 26 | name="Train", 27 | nb_image=3, 28 | ): 29 | grid_image = make_grid( 30 | image[:nb_image].clone().cpu().data, nb_image, normalize=True 31 | ) 32 | writer.add_image(name + "_Image", grid_image, global_step) 33 | grid_image = make_grid( 34 | decode_seg_map_sequence( 35 | torch.max(output[:nb_image], 1)[1].detach().cpu().numpy(), 36 | dataset=dataset, 37 | ), 38 | nb_image, 39 | normalize=False, 40 | range=(0, 255), 41 | ) 42 | writer.add_image(name + "_Predicted label", grid_image, global_step) 43 | grid_image = make_grid( 44 | decode_seg_map_sequence( 45 | torch.squeeze(target[:nb_image], 1).detach().cpu().numpy(), 46 | dataset=dataset, 47 | ), 48 | nb_image, 49 | normalize=False, 50 | range=(0, 255), 51 | ) 52 | writer.add_image(name + "_Groundtruth label", grid_image, global_step) 53 | 54 | def visualize_image_validation( 55 | self, 56 | writer, 57 | dataset, 58 | image, 59 | target, 60 | output, 61 | global_step, 62 | name="Train", 63 | nb_image=3, 64 | ): 65 | grid_image = make_grid(image.data, nb_image, normalize=True) 66 | writer.add_image(name + "_Image", grid_image, global_step) 67 | grid_image = make_grid( 68 | decode_seg_map_sequence( 69 | torch.max(output, 1)[1].detach().numpy(), dataset=dataset 70 | ), 71 | nb_image, 72 | normalize=False, 73 | range=(0, 255), 74 | ) 75 | writer.add_image(name + "_Predicted label", grid_image, global_step) 76 | grid_image = make_grid( 77 | decode_seg_map_sequence( 78 | torch.squeeze(target[:nb_image], 1).detach().numpy(), dataset=dataset 79 | ), 80 | nb_image, 81 | normalize=False, 82 | range=(0, 255), 83 | ) 84 | writer.add_image(name + "_Groundtruth label", grid_image, global_step) 85 | -------------------------------------------------------------------------------- /zs3/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import math 12 | 13 | 14 | class LR_Scheduler: 15 | """Learning Rate Scheduler 16 | 17 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 18 | 19 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` 20 | 21 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 22 | 23 | Args: 24 | args: 25 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 26 | 
:attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 27 | :attr:`args.lr_step` 28 | 29 | iters_per_epoch: number of iterations per epoch 30 | """ 31 | 32 | def __init__( 33 | self, mode, base_lr, num_epochs, iters_per_epoch=0, lr_step=0, warmup_epochs=0 34 | ): 35 | self.mode = mode 36 | print(f"Using {self.mode} LR Scheduler!") 37 | self.lr = base_lr 38 | if mode == "step": 39 | assert lr_step 40 | self.lr_step = lr_step 41 | self.iters_per_epoch = iters_per_epoch 42 | self.N = num_epochs * iters_per_epoch 43 | self.epoch = -1 44 | self.warmup_iters = warmup_epochs * iters_per_epoch 45 | 46 | def __call__(self, optimizer, i, epoch, best_pred): 47 | T = epoch * self.iters_per_epoch + i 48 | if self.mode == "cos": 49 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 50 | elif self.mode == "poly": 51 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 52 | elif self.mode == "step": 53 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 54 | else: 55 | raise NotImplemented 56 | # warm up lr schedule 57 | if self.warmup_iters > 0 and T < self.warmup_iters: 58 | lr = lr * 1.0 * T / self.warmup_iters 59 | if epoch > self.epoch: 60 | print( 61 | "\n=>Epoches %i, learning rate = %.4f, \ 62 | previous best = %.4f" 63 | % (epoch, lr, best_pred) 64 | ) 65 | self.epoch = epoch 66 | assert lr >= 0 67 | self._adjust_learning_rate(optimizer, lr) 68 | 69 | def _adjust_learning_rate(self, optimizer, lr): 70 | if len(optimizer.param_groups) == 1: 71 | optimizer.param_groups[0]["lr"] = lr 72 | else: 73 | # enlarge the lr at the head 74 | optimizer.param_groups[0]["lr"] = lr 75 | for i in range(1, len(optimizer.param_groups)): 76 | optimizer.param_groups[i]["lr"] = lr * 10 77 | -------------------------------------------------------------------------------- /zs3/parsing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(): 5 | parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training") 6 | parser.add_argument( 7 | "--workers", type=int, default=6, metavar="N", help="dataloader threads" 8 | ) 9 | parser.add_argument( 10 | "--freeze-bn", 11 | type=bool, 12 | default=False, 13 | help="whether to freeze bn parameters (default: False)", 14 | ) 15 | parser.add_argument( 16 | "--exp_path", type=str, default="run", help="set the checkpoint name" 17 | ) 18 | parser.add_argument( 19 | "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" 20 | ) 21 | parser.add_argument( 22 | "--no-cuda", action="store_true", default=False, help="disables CUDA training" 23 | ) 24 | parser.add_argument( 25 | "--gpu-ids", 26 | type=str, 27 | default="0", 28 | help="use which gpu to train, must be a \ 29 | comma-separated list of integers only (default=0)", 30 | ) 31 | parser.add_argument( 32 | "--start_epoch", 33 | type=int, 34 | default=0, 35 | metavar="N", 36 | help="start epochs (default:0)", 37 | ) 38 | parser.add_argument( 39 | "--test-batch-size", 40 | type=int, 41 | default=1, 42 | metavar="N", 43 | help="input batch size for testing (default: auto)", 44 | ) 45 | # finetuning pre-trained models 46 | parser.add_argument( 47 | "--ft", 48 | action="store_true", 49 | default=False, 50 | help="finetuning on a different dataset", 51 | ) 52 | parser.add_argument( 53 | "--no-val", 54 | action="store_true", 55 | default=False, 56 | help="skip validation during training", 57 | ) 58 | parser.add_argument( 59 | "--use-balanced-weights", 60 | action="store_true", 61 | default=False, 62 | 
help="whether to use balanced weights (default: False)", 63 | ) 64 | # optimizer params 65 | # PASCAL VOC 66 | parser.add_argument( 67 | "--lr", 68 | type=float, 69 | default=0.007, 70 | metavar="LR", 71 | help="learning rate (default: auto)", 72 | ) 73 | 74 | parser.add_argument( 75 | "--lr-scheduler", 76 | type=str, 77 | default="poly", 78 | choices=["poly", "step", "cos"], 79 | help="lr scheduler mode: (default: poly)", 80 | ) 81 | parser.add_argument( 82 | "--momentum", 83 | type=float, 84 | default=0.9, 85 | metavar="M", 86 | help="momentum (default: 0.9)", 87 | ) 88 | parser.add_argument( 89 | "--weight-decay", 90 | type=float, 91 | default=5e-4, 92 | metavar="M", 93 | help="w-decay (default: 5e-4)", 94 | ) 95 | parser.add_argument( 96 | "--nesterov", 97 | action="store_true", 98 | default=False, 99 | help="whether use nesterov (default: False)", 100 | ) 101 | return parser 102 | -------------------------------------------------------------------------------- /zs3/modeling/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from zs3.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class Decoder(nn.Module): 9 | def __init__(self, num_classes, BatchNorm): 10 | super().__init__() 11 | low_level_inplanes = 256 12 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 13 | self.bn1 = BatchNorm(48) 14 | self.relu = nn.ReLU() 15 | self.last_conv = nn.Sequential( 16 | nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 17 | BatchNorm(256), 18 | nn.ReLU(), 19 | nn.Dropout(0.5), 20 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 21 | BatchNorm(256), 22 | nn.ReLU(), 23 | nn.Dropout(0.1), 24 | ) 25 | 26 | self.pred_conv = nn.Conv2d(256, num_classes, kernel_size=1, stride=1) 27 | self._init_weight() 28 | 29 | def forward(self, x, low_level_feat): 30 | low_level_feat = self.conv1(low_level_feat) 31 | low_level_feat = self.bn1(low_level_feat) 32 | low_level_feat = self.relu(low_level_feat) 33 | 34 | x = F.interpolate( 35 | x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True 36 | ) 37 | x = torch.cat((x, low_level_feat), dim=1) 38 | x = self.last_conv(x) 39 | x = self.pred_conv(x) 40 | return x 41 | 42 | def forward_before_class_prediction(self, x, low_level_feat): 43 | low_level_feat = self.conv1(low_level_feat) 44 | low_level_feat = self.bn1(low_level_feat) 45 | low_level_feat = self.relu(low_level_feat) 46 | 47 | x = F.interpolate( 48 | x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True 49 | ) 50 | x = torch.cat((x, low_level_feat), dim=1) 51 | x = self.last_conv(x) 52 | return x 53 | 54 | def forward_before_last_conv_finetune(self, x, low_level_feat): 55 | low_level_feat = self.conv1(low_level_feat) 56 | low_level_feat = self.bn1(low_level_feat) 57 | low_level_feat = self.relu(low_level_feat) 58 | 59 | x = F.interpolate( 60 | x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True 61 | ) 62 | x = torch.cat((x, low_level_feat), dim=1) 63 | x = self.last_conv[:4](x) 64 | return x 65 | 66 | def forward_class_prediction(self, x): 67 | x = self.pred_conv(x) 68 | return x 69 | 70 | def forward_class_last_conv_finetune(self, x): 71 | x = self.last_conv[4:](x) 72 | return x 73 | 74 | def _init_weight(self): 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | torch.nn.init.kaiming_normal_(m.weight) 78 | elif isinstance(m, 
SynchronizedBatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.BatchNorm2d): 82 | m.weight.data.fill_(1) 83 | m.bias.data.zero_() 84 | 85 | 86 | def build_decoder(num_classes, BatchNorm): 87 | return Decoder(num_classes, BatchNorm) 88 | -------------------------------------------------------------------------------- /zs3/utils/saver.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | from collections import OrderedDict 5 | 6 | import torch 7 | 8 | 9 | class Saver: 10 | def __init__(self, args): 11 | self.args = args 12 | self.directory = os.path.join(self.args.exp_path, args.dataset, args.checkname) 13 | self.runs = sorted(glob.glob(os.path.join(self.directory, "experiment_*"))) 14 | run_id = int(self.runs[-1].split("_")[-1]) + 1 if self.runs else 0 15 | 16 | self.experiment_dir = os.path.join(self.directory, f"experiment_{str(run_id)}") 17 | print("experiment_dir: ", self.experiment_dir) 18 | if not os.path.exists(self.experiment_dir): 19 | os.makedirs(self.experiment_dir) 20 | 21 | with open(self.experiment_dir + "/args.txt", "w") as text_file: 22 | text_file.write(str(args)) 23 | 24 | def save_checkpoint( 25 | self, state, is_best, filename="checkpoint.pth.tar", generator_state=None 26 | ): 27 | """Saves checkpoint to disk""" 28 | filename_generator = os.path.join(self.experiment_dir, "generator_" + filename) 29 | filename = os.path.join(self.experiment_dir, filename) 30 | torch.save(state, filename) 31 | torch.save(generator_state, filename_generator) 32 | if is_best: 33 | best_pred = state["best_pred"] 34 | with open(os.path.join(self.experiment_dir, "best_pred.txt"), "w") as f: 35 | f.write(str(best_pred)) 36 | if self.runs: 37 | previous_miou = [0.0] 38 | for run in self.runs: 39 | run_id = run.split("_")[-1] 40 | path = os.path.join( 41 | self.directory, f"experiment_{str(run_id)}", "best_pred.txt", 42 | ) 43 | if os.path.exists(path): 44 | with open(path, "r") as f: 45 | miou = float(f.readline()) 46 | previous_miou.append(miou) 47 | else: 48 | continue 49 | shutil.copyfile( 50 | filename, 51 | os.path.join( 52 | self.experiment_dir, str(state["epoch"]) + "_model.pth.tar" 53 | ), 54 | ) 55 | shutil.copyfile( 56 | filename_generator, 57 | os.path.join( 58 | self.experiment_dir, str(state["epoch"]) + "_generator.pth.tar" 59 | ), 60 | ) 61 | else: 62 | shutil.copyfile( 63 | filename, 64 | os.path.join( 65 | self.experiment_dir, str(state["epoch"]) + "_model.pth.tar" 66 | ), 67 | ) 68 | 69 | def save_experiment_config(self): 70 | logfile = os.path.join(self.experiment_dir, "parameters.txt") 71 | log_file = open(logfile, "w") 72 | p = OrderedDict() 73 | p["datset"] = self.args.dataset 74 | p["out_stride"] = self.args.out_stride 75 | p["lr"] = self.args.lr 76 | p["lr_scheduler"] = self.args.lr_scheduler 77 | p["loss_type"] = self.args.loss_type 78 | p["epoch"] = self.args.epochs 79 | p["base_size"] = self.args.base_size 80 | p["crop_size"] = self.args.crop_size 81 | 82 | for key, val in p.items(): 83 | log_file.write(key + ":" + str(val) + "\n") 84 | log_file.close() 85 | -------------------------------------------------------------------------------- /zs3/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def decode_seg_map_sequence(label_masks, dataset="pascal"): 7 | rgb_masks = [] 8 | for label_mask in label_masks: 9 | 
rgb_mask = decode_segmap(label_mask, dataset) 10 | rgb_masks.append(rgb_mask) 11 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 12 | return rgb_masks 13 | 14 | 15 | def decode_segmap(label_mask, dataset, plot=False): 16 | """Decode segmentation class labels into a color image 17 | Args: 18 | label_mask (np.ndarray): an (M,N) array of integer values denoting 19 | the class label at each spatial location. 20 | plot (bool, optional): whether to show the resulting color image 21 | in a figure. 22 | Returns: 23 | (np.ndarray, optional): the resulting decoded color image. 24 | """ 25 | if dataset == "pascal": 26 | n_classes = 21 27 | label_colours = get_pascal_labels() 28 | elif dataset == "context": 29 | n_classes = 60 30 | label_colours = make_palette(n_classes) 31 | else: 32 | raise NotImplementedError 33 | 34 | r = label_mask.copy() 35 | g = label_mask.copy() 36 | b = label_mask.copy() 37 | for ll in range(0, n_classes): 38 | r[label_mask == ll] = label_colours[ll, 0] 39 | g[label_mask == ll] = label_colours[ll, 1] 40 | b[label_mask == ll] = label_colours[ll, 2] 41 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 42 | rgb[:, :, 0] = r / 255.0 43 | rgb[:, :, 1] = g / 255.0 44 | rgb[:, :, 2] = b / 255.0 45 | if plot: 46 | plt.imshow(rgb) 47 | plt.show() 48 | else: 49 | return rgb 50 | 51 | 52 | def get_pascal_labels(): 53 | """Load the mapping that associates pascal classes with label colors 54 | Returns: 55 | np.ndarray with dimensions (21, 3) 56 | """ 57 | return np.asarray( 58 | [ 59 | [0, 0, 0], 60 | [128, 0, 0], 61 | [0, 128, 0], 62 | [128, 128, 0], 63 | [0, 0, 128], 64 | [128, 0, 128], 65 | [0, 128, 128], 66 | [128, 128, 128], 67 | [64, 0, 0], 68 | [192, 0, 0], 69 | [64, 128, 0], 70 | [192, 128, 0], 71 | [64, 0, 128], 72 | [192, 0, 128], 73 | [64, 128, 128], 74 | [192, 128, 128], 75 | [0, 64, 0], 76 | [128, 64, 0], 77 | [0, 192, 0], 78 | [128, 192, 0], 79 | [0, 64, 128], 80 | ] 81 | ) 82 | 83 | 84 | def make_palette(num_classes): 85 | """ 86 | Maps classes to colors in the style of PASCAL VOC. 87 | Close values are mapped to far colors for segmentation visualization. 
88 | See http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 89 | Takes: 90 | num_classes: the number of classes 91 | Gives: 92 | palette: the colormap as a k x 3 array of RGB colors 93 | """ 94 | palette = np.zeros((num_classes, 3), dtype=np.uint8) 95 | for k in range(0, num_classes): 96 | label = k 97 | i = 0 98 | while label: 99 | palette[k, 0] |= ((label >> 0) & 1) << (7 - i) 100 | palette[k, 1] |= ((label >> 1) & 1) << (7 - i) 101 | palette[k, 2] |= ((label >> 2) & 1) << (7 - i) 102 | label >>= 3 103 | i += 1 104 | return palette 105 | -------------------------------------------------------------------------------- /zs3/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SegmentationLosses: 6 | def __init__( 7 | self, 8 | weight=None, 9 | size_average=True, 10 | batch_average=True, 11 | ignore_index=255, 12 | cuda=False, 13 | ): 14 | self.ignore_index = ignore_index 15 | self.weight = weight 16 | self.size_average = size_average 17 | self.batch_average = batch_average 18 | self.cuda = cuda 19 | 20 | def build_loss(self, mode="ce"): 21 | """Choices: ['ce' or 'focal']""" 22 | if mode == "ce": 23 | return self.CrossEntropyLoss 24 | elif mode == "focal": 25 | return self.FocalLoss 26 | elif mode == "ce_finetune": 27 | return self.CrossEntropyLossFinetune 28 | else: 29 | raise NotImplementedError 30 | 31 | def CrossEntropyLoss(self, logit, target): 32 | n, _, h, w = logit.size() 33 | criterion = nn.CrossEntropyLoss( 34 | weight=self.weight, 35 | ignore_index=self.ignore_index, 36 | size_average=self.size_average, 37 | ) 38 | if self.cuda: 39 | criterion = criterion.cuda() 40 | 41 | loss = criterion(logit, target.long()) 42 | 43 | if self.batch_average: 44 | loss /= n 45 | 46 | return loss 47 | 48 | def CrossEntropyLossFinetune(self, logit, target): 49 | criterion = nn.CrossEntropyLoss( 50 | ignore_index=self.ignore_index, size_average=self.size_average 51 | ) 52 | if self.cuda: 53 | criterion = criterion.cuda() 54 | 55 | loss = criterion(logit, target.long()) 56 | 57 | if self.batch_average: 58 | loss /= logit.shape[0] 59 | 60 | return loss 61 | 62 | def FocalLoss(self, logit, target, gamma=2, alpha=0.5): 63 | n, _, h, w = logit.size() 64 | criterion = nn.CrossEntropyLoss( 65 | weight=self.weight, 66 | ignore_index=self.ignore_index, 67 | size_average=self.size_average, 68 | ) 69 | if self.cuda: 70 | criterion = criterion.cuda() 71 | 72 | logpt = -criterion(logit, target.long()) 73 | pt = torch.exp(logpt) 74 | if alpha is not None: 75 | logpt *= alpha 76 | loss = -((1 - pt) ** gamma) * logpt 77 | 78 | if self.batch_average: 79 | loss /= n 80 | 81 | return loss 82 | 83 | 84 | class GMMNLoss: 85 | def __init__(self, sigma=[2, 5, 10, 20, 40, 80], cuda=False): 86 | self.sigma = sigma 87 | self.cuda = cuda 88 | 89 | def build_loss(self): 90 | return self.moment_loss 91 | 92 | def get_scale_matrix(self, M, N): 93 | s1 = torch.ones((N, 1)) * 1.0 / N 94 | s2 = torch.ones((M, 1)) * -1.0 / M 95 | if self.cuda: 96 | s1, s2 = s1.cuda(), s2.cuda() 97 | return torch.cat((s1, s2), 0) 98 | 99 | def moment_loss(self, gen_samples, x): 100 | X = torch.cat((gen_samples, x), 0) 101 | XX = torch.matmul(X, X.t()) 102 | X2 = torch.sum(X * X, 1, keepdim=True) 103 | exp = XX - 0.5 * X2 - 0.5 * X2.t() 104 | M = gen_samples.size()[0] 105 | N = x.size()[0] 106 | s = self.get_scale_matrix(M, N) 107 | S = torch.matmul(s, s.t()) 108 | 109 | loss = 0 110 | for v in self.sigma: 111 | kernel_val = 
torch.exp(exp / v) 112 | loss += torch.sum(S * kernel_val) 113 | 114 | loss = torch.sqrt(loss) 115 | return loss 116 | -------------------------------------------------------------------------------- /zs3/modeling/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from zs3.modeling.aspp import build_aspp 5 | from zs3.modeling.backbone import build_backbone 6 | from zs3.modeling.decoder import build_decoder 7 | from zs3.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 8 | 9 | 10 | class DeepLab(nn.Module): 11 | def __init__( 12 | self, 13 | output_stride=16, 14 | num_classes=21, 15 | sync_bn=True, 16 | freeze_bn=False, 17 | pretrained=True, 18 | global_avg_pool_bn=True, 19 | imagenet_pretrained_path="", 20 | ): 21 | super().__init__() 22 | 23 | if sync_bn: 24 | BatchNorm = SynchronizedBatchNorm2d 25 | else: 26 | BatchNorm = nn.BatchNorm2d 27 | 28 | self.backbone = build_backbone( 29 | output_stride, 30 | BatchNorm, 31 | pretrained=pretrained, 32 | imagenet_pretrained_path=imagenet_pretrained_path, 33 | ) 34 | self.aspp = build_aspp(output_stride, BatchNorm, global_avg_pool_bn) 35 | self.decoder = build_decoder(num_classes, BatchNorm) 36 | 37 | if freeze_bn: 38 | self.freeze_bn() 39 | 40 | def forward(self, input): 41 | x, low_level_feat = self.backbone(input) 42 | x = self.aspp(x) 43 | x = self.decoder(x, low_level_feat) 44 | x = F.interpolate(x, size=input.size()[2:], mode="bilinear", align_corners=True) 45 | return x 46 | 47 | def forward_before_class_prediction(self, input): 48 | x, low_level_feat = self.backbone(input) 49 | x = self.aspp(x) 50 | x = self.decoder.forward_before_class_prediction(x, low_level_feat) 51 | return x 52 | 53 | def forward_class_prediction(self, x, input_size): 54 | x = self.decoder.forward_class_prediction(x) 55 | x = F.interpolate(x, size=input_size, mode="bilinear", align_corners=True) 56 | return x 57 | 58 | def forward_before_last_conv_finetune(self, input): 59 | x, low_level_feat = self.backbone(input) 60 | x = self.aspp(x) 61 | x = self.decoder.forward_before_last_conv_finetune(x, low_level_feat) 62 | return x 63 | 64 | def forward_class_last_conv_finetune(self, x): 65 | x = self.decoder.forward_class_last_conv_finetune(x) 66 | return x 67 | 68 | def freeze_bn(self): 69 | for m in self.modules(): 70 | if isinstance(m, SynchronizedBatchNorm2d): 71 | m.eval() 72 | elif isinstance(m, nn.BatchNorm2d): 73 | m.eval() 74 | 75 | def get_1x_lr_params(self): 76 | modules = [self.backbone] 77 | for i in range(len(modules)): 78 | for m in modules[i].named_modules(): 79 | if ( 80 | isinstance(m[1], nn.Conv2d) 81 | or isinstance(m[1], SynchronizedBatchNorm2d) 82 | or isinstance(m[1], nn.BatchNorm2d) 83 | ): 84 | for p in m[1].parameters(): 85 | if p.requires_grad: 86 | yield p 87 | 88 | def get_10x_lr_params(self): 89 | modules = [self.aspp, self.decoder] 90 | for i in range(len(modules)): 91 | for m in modules[i].named_modules(): 92 | if ( 93 | isinstance(m[1], nn.Conv2d) 94 | or isinstance(m[1], SynchronizedBatchNorm2d) 95 | or isinstance(m[1], nn.BatchNorm2d) 96 | ): 97 | for p in m[1].parameters(): 98 | if p.requires_grad: 99 | yield p 100 | -------------------------------------------------------------------------------- /zs3/dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image, ImageOps, 
ImageFilter 6 | 7 | 8 | class Normalize: 9 | """Normalize a tensor image with mean and standard deviation. 10 | Args: 11 | mean (tuple): means for each channel. 12 | std (tuple): standard deviations for each channel. 13 | """ 14 | 15 | def __init__(self, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): 16 | self.mean = mean 17 | self.std = std 18 | 19 | def __call__(self, sample): 20 | img = sample["image"] 21 | mask = sample["label"] 22 | img = np.array(img).astype(np.float32) 23 | mask = np.array(mask).astype(np.float32) 24 | img /= 255.0 25 | img -= self.mean 26 | img /= self.std 27 | 28 | return {"image": img, "label": mask} 29 | 30 | 31 | class ToTensor: 32 | """Convert ndarrays in sample to Tensors.""" 33 | 34 | def __call__(self, sample): 35 | # swap color axis because 36 | # numpy image: H x W x C 37 | # torch image: C X H X W 38 | img = sample["image"] 39 | mask = sample["label"] 40 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 41 | mask = np.array(mask).astype(np.float32) 42 | 43 | img = torch.from_numpy(img).float() 44 | mask = torch.from_numpy(mask).float() 45 | 46 | return {"image": img, "label": mask} 47 | 48 | 49 | class RandomHorizontalFlip: 50 | def __call__(self, sample): 51 | img = sample["image"] 52 | mask = sample["label"] 53 | if random.random() < 0.5: 54 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 55 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 56 | 57 | return {"image": img, "label": mask} 58 | 59 | 60 | class RandomGaussianBlur: 61 | def __call__(self, sample): 62 | img = sample["image"] 63 | mask = sample["label"] 64 | if random.random() < 0.5: 65 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) 66 | 67 | return {"image": img, "label": mask} 68 | 69 | 70 | class RandomScaleCrop: 71 | def __init__(self, base_size, crop_size, fill=255): 72 | self.base_size = base_size 73 | self.crop_size = crop_size 74 | self.fill = fill 75 | 76 | def __call__(self, sample): 77 | img = sample["image"] 78 | mask = sample["label"] 79 | # random scale (short edge) 80 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 81 | w, h = img.size 82 | if h > w: 83 | ow = short_size 84 | oh = int(1.0 * h * ow / w) 85 | else: 86 | oh = short_size 87 | ow = int(1.0 * w * oh / h) 88 | img = img.resize((ow, oh), Image.BILINEAR) 89 | mask = mask.resize((ow, oh), Image.NEAREST) 90 | # pad crop 91 | if short_size < self.crop_size: 92 | padh = self.crop_size - oh if oh < self.crop_size else 0 93 | padw = self.crop_size - ow if ow < self.crop_size else 0 94 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 95 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 96 | # random crop crop_size 97 | w, h = img.size 98 | x1 = random.randint(0, w - self.crop_size) 99 | y1 = random.randint(0, h - self.crop_size) 100 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 101 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 102 | 103 | return {"image": img, "label": mask} 104 | 105 | 106 | class FixScale: 107 | def __init__(self, crop_size): 108 | self.crop_size = crop_size 109 | 110 | def __call__(self, sample): 111 | img = sample["image"] 112 | mask = sample["label"] 113 | w, h = img.size 114 | if w > h: 115 | oh = self.crop_size 116 | ow = int(1.0 * w * oh / h) 117 | else: 118 | ow = self.crop_size 119 | oh = int(1.0 * h * ow / w) 120 | img = img.resize((ow, oh), Image.BILINEAR) 121 | mask = mask.resize((ow, oh), Image.NEAREST) 122 | 123 | return {"image": img, "label": mask} 
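# --- Illustrative usage sketch (editorial addition; not part of the original
# custom_transforms.py). It shows how the transforms above are meant to be
# composed: each one consumes and returns a sample dict
# {"image": PIL.Image, "label": PIL.Image}, mirroring what the dataset classes
# do in their transform_tr() methods. The 513/500/375 sizes are arbitrary.
if __name__ == "__main__":
    from torchvision import transforms

    pipeline = transforms.Compose(
        [
            RandomHorizontalFlip(),
            RandomScaleCrop(base_size=513, crop_size=513, fill=255),
            RandomGaussianBlur(),
            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensor(),
        ]
    )
    sample = {
        "image": Image.new("RGB", (500, 375)),  # dummy RGB image
        "label": Image.new("L", (500, 375)),  # dummy segmentation mask
    }
    out = pipeline(sample)
    # out["image"]: float tensor of shape (3, 513, 513)
    # out["label"]: float tensor of shape (513, 513)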
124 | -------------------------------------------------------------------------------- /zs3/modeling/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from zs3.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class _ASPPModule(nn.Module): 9 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 10 | super().__init__() 11 | self.atrous_conv = nn.Conv2d( 12 | inplanes, 13 | planes, 14 | kernel_size=kernel_size, 15 | stride=1, 16 | padding=padding, 17 | dilation=dilation, 18 | bias=False, 19 | ) 20 | self.bn = BatchNorm(planes) 21 | self.relu = nn.ReLU() 22 | 23 | self._init_weight() 24 | 25 | def forward(self, x): 26 | x = self.atrous_conv(x) 27 | x = self.bn(x) 28 | 29 | return self.relu(x) 30 | 31 | def _init_weight(self): 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | torch.nn.init.kaiming_normal_(m.weight) 35 | elif isinstance(m, SynchronizedBatchNorm2d): 36 | m.weight.data.fill_(1) 37 | m.bias.data.zero_() 38 | elif isinstance(m, nn.BatchNorm2d): 39 | m.weight.data.fill_(1) 40 | m.bias.data.zero_() 41 | 42 | 43 | class ASPP(nn.Module): 44 | def __init__(self, output_stride, BatchNorm, global_avg_pool_bn=True): 45 | super().__init__() 46 | inplanes = 2048 47 | if output_stride == 16: 48 | dilations = [1, 6, 12, 18] 49 | elif output_stride == 8: 50 | dilations = [1, 12, 24, 36] 51 | else: 52 | raise NotImplementedError 53 | 54 | self.aspp1 = _ASPPModule( 55 | inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm 56 | ) 57 | self.aspp2 = _ASPPModule( 58 | inplanes, 59 | 256, 60 | 3, 61 | padding=dilations[1], 62 | dilation=dilations[1], 63 | BatchNorm=BatchNorm, 64 | ) 65 | self.aspp3 = _ASPPModule( 66 | inplanes, 67 | 256, 68 | 3, 69 | padding=dilations[2], 70 | dilation=dilations[2], 71 | BatchNorm=BatchNorm, 72 | ) 73 | self.aspp4 = _ASPPModule( 74 | inplanes, 75 | 256, 76 | 3, 77 | padding=dilations[3], 78 | dilation=dilations[3], 79 | BatchNorm=BatchNorm, 80 | ) 81 | 82 | ## for batch size == 1 83 | if global_avg_pool_bn: 84 | self.global_avg_pool = nn.Sequential( 85 | nn.AdaptiveAvgPool2d((1, 1)), 86 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 87 | BatchNorm(256), 88 | nn.ReLU(), 89 | ) 90 | else: 91 | self.global_avg_pool = nn.Sequential( 92 | nn.AdaptiveAvgPool2d((1, 1)), 93 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 94 | nn.ReLU(), 95 | ) 96 | 97 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 98 | self.bn1 = BatchNorm(256) 99 | self.relu = nn.ReLU() 100 | self.dropout = nn.Dropout(0.5) 101 | self._init_weight() 102 | 103 | def forward(self, x): 104 | x1 = self.aspp1(x) 105 | x2 = self.aspp2(x) 106 | x3 = self.aspp3(x) 107 | x4 = self.aspp4(x) 108 | x5 = self.global_avg_pool(x) 109 | x5 = F.interpolate(x5, size=x4.size()[2:], mode="bilinear", align_corners=True) 110 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 111 | 112 | x = self.conv1(x) 113 | x = self.bn1(x) 114 | x = self.relu(x) 115 | 116 | return self.dropout(x) 117 | 118 | def _init_weight(self): 119 | for m in self.modules(): 120 | if isinstance(m, nn.Conv2d): 121 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 122 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 123 | torch.nn.init.kaiming_normal_(m.weight) 124 | elif isinstance(m, SynchronizedBatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() 127 | elif isinstance(m, nn.BatchNorm2d): 128 | m.weight.data.fill_(1) 129 | m.bias.data.zero_() 130 | 131 | 132 | def build_aspp(output_stride, BatchNorm, global_avg_pool_bn=True): 133 | return ASPP(output_stride, BatchNorm, global_avg_pool_bn) 134 | -------------------------------------------------------------------------------- /zs3/modeling/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # File : comm.py 2 | # Author : Jiayuan Mao 3 | # Email : maojiayuan@gmail.com 4 | # Date : 27/01/2018 5 | # 6 | # This file is part of Synchronized-BatchNorm-PyTorch. 7 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 8 | # Distributed under MIT License. 9 | 10 | import collections 11 | import queue 12 | import threading 13 | 14 | __all__ = ["FutureResult", "SlavePipe", "SyncMaster"] 15 | 16 | 17 | class FutureResult: 18 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 19 | 20 | def __init__(self): 21 | self._result = None 22 | self._lock = threading.Lock() 23 | self._cond = threading.Condition(self._lock) 24 | 25 | def put(self, result): 26 | with self._lock: 27 | assert self._result is None, "Previous result has't been fetched." 28 | self._result = result 29 | self._cond.notify() 30 | 31 | def get(self): 32 | with self._lock: 33 | if self._result is None: 34 | self._cond.wait() 35 | 36 | res = self._result 37 | self._result = None 38 | return res 39 | 40 | 41 | _MasterRegistry = collections.namedtuple("MasterRegistry", ["result"]) 42 | _SlavePipeBase = collections.namedtuple( 43 | "_SlavePipeBase", ["identifier", "queue", "result"] 44 | ) 45 | 46 | 47 | class SlavePipe(_SlavePipeBase): 48 | """Pipe for master-slave communication.""" 49 | 50 | def run_slave(self, msg): 51 | self.queue.put((self.identifier, msg)) 52 | ret = self.result.get() 53 | self.queue.put(True) 54 | return ret 55 | 56 | 57 | class SyncMaster: 58 | """An abstract `SyncMaster` object. 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | Args: 70 | master_callback: a callback to be invoked after having collected messages from slave devices. 71 | """ 72 | self._master_callback = master_callback 73 | self._queue = queue.Queue() 74 | self._registry = collections.OrderedDict() 75 | self._activated = False 76 | 77 | def __getstate__(self): 78 | return {"master_callback": self._master_callback} 79 | 80 | def __setstate__(self, state): 81 | self.__init__(state["master_callback"]) 82 | 83 | def register_slave(self, identifier): 84 | """ 85 | Register an slave device. 86 | Args: 87 | identifier: an identifier, usually is the device id. 88 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 89 | """ 90 | if self._activated: 91 | assert self._queue.empty(), "Queue is not clean before next initialization." 
92 | self._activated = False 93 | self._registry.clear() 94 | future = FutureResult() 95 | self._registry[identifier] = _MasterRegistry(future) 96 | return SlavePipe(identifier, self._queue, future) 97 | 98 | def run_master(self, master_msg): 99 | """ 100 | Main entry for the master device in each forward pass. 101 | The messages were first collected from each devices (including the master device), and then 102 | an callback will be invoked to compute the message to be sent back to each devices 103 | (including the master device). 104 | Args: 105 | master_msg: the message that the master want to send to itself. This will be placed as the first 106 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 107 | Returns: the message to be sent back to the master device. 108 | """ 109 | self._activated = True 110 | 111 | intermediates = [(0, master_msg)] 112 | for i in range(self.nr_slaves): 113 | intermediates.append(self._queue.get()) 114 | 115 | results = self._master_callback(intermediates) 116 | assert results[0][0] == 0, "The first result should belongs to the master." 117 | 118 | for i, res in results: 119 | if i == 0: 120 | continue 121 | self._registry[i].result.put(res) 122 | 123 | for i in range(self.nr_slaves): 124 | assert self._queue.get() is True 125 | 126 | return results[0][1] 127 | 128 | @property 129 | def nr_slaves(self): 130 | return len(self._registry) 131 | -------------------------------------------------------------------------------- /zs3/dataloaders/datasets/pascal.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from torchvision import transforms 6 | 7 | from zs3.dataloaders import custom_transforms as tr 8 | from .base import BaseDataset, load_obj, lbl_contains_unseen 9 | 10 | 11 | PASCAL_DIR = pathlib.Path("./data/VOC2012") 12 | 13 | 14 | class VOCSegmentation(BaseDataset): 15 | """ 16 | PascalVoc dataset 17 | """ 18 | NUM_CLASSES = 21 19 | 20 | def __init__( 21 | self, 22 | args, 23 | base_dir=PASCAL_DIR, 24 | split="train", 25 | load_embedding=None, 26 | w2c_size=300, 27 | weak_label=False, 28 | unseen_classes_idx_weak=[], 29 | transform=True, 30 | ): 31 | """ 32 | :param base_dir: path to VOC dataset directory 33 | :param split: train/val 34 | :param transform: transform to apply 35 | """ 36 | super().__init__( 37 | args, 38 | base_dir, 39 | split, 40 | load_embedding, 41 | w2c_size, 42 | weak_label, 43 | unseen_classes_idx_weak, 44 | transform, 45 | ) 46 | self._image_dir = self._base_dir / "JPEGImages" 47 | self._cat_dir = self._base_dir / "SegmentationClass" 48 | 49 | self.unseen_classes_idx_weak = unseen_classes_idx_weak 50 | 51 | _splits_dir = self._base_dir / "ImageSets" / "Segmentation" 52 | 53 | self.im_ids = [] 54 | self.categories = [] 55 | 56 | lines = (_splits_dir / f"{self.split}.txt").read_text().splitlines() 57 | 58 | for ii, line in enumerate(lines): 59 | _image = self._image_dir / f"{line}.jpg" 60 | _cat = self._cat_dir / f"{line}.png" 61 | assert _image.is_file(), _image 62 | assert _cat.is_file(), _cat 63 | 64 | # if unseen classes and training split 65 | if len(args.unseen_classes_idx) > 0 and self.split == "train": 66 | cat = Image.open(_cat) 67 | cat = np.array(cat, dtype=np.uint8) 68 | if lbl_contains_unseen(cat, args.unseen_classes_idx): 69 | continue 70 | 71 | self.im_ids.append(line) 72 | self.images.append(_image) 73 | self.categories.append(_cat) 74 | 75 | assert 
len(self.images) == len(self.categories) 76 | 77 | # Display stats 78 | print(f"(pascal) Number of images in {split}: {len(self.images):d}") 79 | 80 | def init_embeddings(self): 81 | embed_arr = load_obj("embeddings/pascal/w2c/norm_embed_arr_" + str(self.w2c_size)) 82 | self.make_embeddings(embed_arr) 83 | 84 | def __getitem__(self, index): 85 | _img, _target = self._make_img_gt_point_pair(index) 86 | 87 | if self.weak_label: 88 | unique_class = np.unique(np.array(_target)) 89 | has_unseen_class = False 90 | for u_class in unique_class: 91 | if u_class in self.unseen_classes_idx_weak: 92 | has_unseen_class = True 93 | if has_unseen_class: 94 | _target = Image.open( 95 | "weak_label_pascal_10_unseen_top_by_image_25.0/pascal/" 96 | + self.categories[index].stem 97 | + ".jpg" 98 | ) 99 | 100 | sample = {"image": _img, "label": _target} 101 | 102 | if self.transform: 103 | if self.split == "train": 104 | sample = self.transform_tr(sample) 105 | elif self.split == "val": 106 | sample = self.transform_val(sample) 107 | else: 108 | sample = self.transform_weak(sample) 109 | 110 | if self.load_embedding: 111 | self.get_embeddings(sample) 112 | sample["image_name"] = str(self.images[index]) 113 | return sample 114 | 115 | def _make_img_gt_point_pair(self, index): 116 | _img = Image.open(self.images[index]).convert("RGB") 117 | _target = Image.open(self.categories[index]) 118 | return _img, _target 119 | 120 | def transform_tr(self, sample): 121 | composed_transforms = transforms.Compose( 122 | [ 123 | tr.RandomHorizontalFlip(), 124 | tr.RandomScaleCrop( 125 | base_size=self.args.base_size, 126 | crop_size=self.args.crop_size, 127 | fill=255, 128 | ), 129 | tr.RandomGaussianBlur(), 130 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 131 | tr.ToTensor(), 132 | ] 133 | ) 134 | return composed_transforms(sample) 135 | 136 | def transform_val(self, sample): 137 | composed_transforms = transforms.Compose( 138 | [ 139 | tr.FixScale(crop_size=self.args.crop_size), 140 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 141 | tr.ToTensor(), 142 | ] 143 | ) 144 | return composed_transforms(sample) 145 | 146 | def transform_weak(self, sample): 147 | 148 | composed_transforms = transforms.Compose( 149 | [ 150 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 151 | tr.ToTensor(), 152 | ] 153 | ) 154 | 155 | return composed_transforms(sample) 156 | 157 | def __str__(self): 158 | return f"VOC2012(split={self.split})" 159 | -------------------------------------------------------------------------------- /zs3/dataloaders/datasets/sbd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import numpy as np 5 | import scipy.io 6 | from PIL import Image 7 | from torchvision import transforms 8 | 9 | from zs3.dataloaders import custom_transforms as tr 10 | from .base import BaseDataset, load_obj, lbl_contains_unseen 11 | 12 | SBD_DIR = pathlib.Path("./data/VOC2012/benchmark_RELEASE") 13 | 14 | 15 | class SBDSegmentation(BaseDataset): 16 | NUM_CLASSES = 21 17 | 18 | def __init__( 19 | self, 20 | args, 21 | base_dir=SBD_DIR, 22 | split="train", 23 | load_embedding=None, 24 | w2c_size=300, 25 | weak_label=False, 26 | unseen_classes_idx_weak=[], 27 | transform=True, 28 | ): 29 | """ 30 | :param base_dir: path to VOC dataset directory 31 | :param split: train/val 32 | :param transform: transform to apply 33 | """ 34 | if isinstance(split, str): 35 | split = [split] 36 | split.sort() 37 | 
super().__init__( 38 | args, 39 | base_dir, 40 | split, 41 | load_embedding, 42 | w2c_size, 43 | weak_label, 44 | unseen_classes_idx_weak, 45 | transform, 46 | ) 47 | self._dataset_dir = self._base_dir / "dataset" 48 | self._image_dir = self._dataset_dir / "img" 49 | self._cat_dir = self._dataset_dir / "cls" 50 | 51 | # Get list of all images from the split and check that the files exist 52 | self.im_ids = [] 53 | self.categories = [] 54 | for splt in self.split: 55 | lines = (self._dataset_dir / f"{splt}.txt").read_text().splitlines() 56 | 57 | for line in lines: 58 | _image = self._image_dir / f"{line}.jpg" 59 | _categ = self._cat_dir / f"{line}.mat" 60 | assert _image.is_file() 61 | assert _categ.is_file() 62 | 63 | # if unseen classes 64 | if len(args.unseen_classes_idx) > 0: 65 | _target = Image.fromarray( 66 | scipy.io.loadmat(_categ)["GTcls"][0]["Segmentation"][0] 67 | ) 68 | _target = np.array(_target, dtype=np.uint8) 69 | if lbl_contains_unseen(_target, args.unseen_classes_idx): 70 | continue 71 | 72 | self.im_ids.append(line) 73 | self.images.append(_image) 74 | self.categories.append(_categ) 75 | 76 | assert len(self.images) == len(self.categories) 77 | 78 | # Display stats 79 | print(f"(sbd) Number of images: {len(self.images):d}") 80 | 81 | def init_embeddings(self): 82 | if self.load_embedding == "attributes": 83 | embed_arr = np.load("embeddings/pascal/pascalvoc_class_attributes.npy") 84 | elif self.load_embedding == "w2c": 85 | embed_arr = load_obj( 86 | "embeddings/pascal/w2c/norm_embed_arr_" + str(self.w2c_size) 87 | ) 88 | elif self.load_embedding == "w2c_bg": 89 | embed_arr = np.load("embeddings/pascal/pascalvoc_class_w2c_bg.npy") 90 | elif self.load_embedding == "my_w2c": 91 | embed_arr = np.load("embeddings/pascal/pascalvoc_class_w2c.npy") 92 | elif self.load_embedding == "fusion": 93 | attributes = np.load("embeddings/pascal/pascalvoc_class_attributes.npy") 94 | w2c = np.load("embeddings/pascal/pascalvoc_class_w2c.npy") 95 | embed_arr = np.concatenate((attributes, w2c), axis=1) 96 | else: 97 | raise KeyError(self.load_embedding) 98 | self.make_embeddings(embed_arr) 99 | 100 | def __getitem__(self, index): 101 | _img, _target = self._make_img_gt_point_pair(index) 102 | 103 | if self.weak_label: 104 | unique_class = np.unique(np.array(_target)) 105 | has_unseen_class = False 106 | for u_class in unique_class: 107 | if u_class in self.unseen_classes_idx_weak: 108 | has_unseen_class = True 109 | if has_unseen_class: 110 | _target = Image.open( 111 | "weak_label_pascal_10_unseen_top_by_image_25.0/sbd/" 112 | + self.categories[index].stem 113 | + ".jpg" 114 | ) 115 | 116 | sample = {"image": _img, "label": _target} 117 | 118 | if self.transform: 119 | sample = self.transform_s(sample) 120 | else: 121 | sample = self.transform_weak(sample) 122 | 123 | if self.load_embedding: 124 | self.get_embeddings(sample) 125 | sample["image_name"] = str(self.images[index]) 126 | return sample 127 | 128 | def _make_img_gt_point_pair(self, index): 129 | _img = Image.open(self.images[index]).convert("RGB") 130 | _target = Image.fromarray( 131 | scipy.io.loadmat(self.categories[index])["GTcls"][0]["Segmentation"][0] 132 | ) 133 | 134 | return _img, _target 135 | 136 | def transform_s(self, sample): 137 | composed_transforms = transforms.Compose( 138 | [ 139 | tr.RandomHorizontalFlip(), 140 | tr.RandomScaleCrop( 141 | base_size=self.args.base_size, crop_size=self.args.crop_size 142 | ), 143 | tr.RandomGaussianBlur(), 144 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 
145 | tr.ToTensor(), 146 | ] 147 | ) 148 | 149 | return composed_transforms(sample) 150 | 151 | def transform_weak(self, sample): 152 | 153 | composed_transforms = transforms.Compose( 154 | [ 155 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 156 | tr.ToTensor(), 157 | ] 158 | ) 159 | 160 | return composed_transforms(sample) 161 | 162 | def __str__(self): 163 | return f"SBDSegmentation(split={self.split})" 164 | -------------------------------------------------------------------------------- /zs3/dataloaders/datasets/context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pathlib 4 | 5 | import numpy as np 6 | import scipy 7 | from PIL import Image 8 | from torchvision import transforms 9 | 10 | from zs3.dataloaders import custom_transforms as tr 11 | from .base import BaseDataset, lbl_contains_unseen 12 | 13 | 14 | CONTEXT_DIR = pathlib.Path("./data/context/") 15 | 16 | 17 | class ContextSegmentation(BaseDataset): 18 | """ 19 | PascalVoc dataset 20 | """ 21 | 22 | NUM_CLASSES = 60 23 | 24 | def __init__( 25 | self, 26 | args, 27 | base_dir=CONTEXT_DIR, 28 | split="train", 29 | load_embedding=None, 30 | w2c_size=300, 31 | weak_label=False, 32 | unseen_classes_idx_weak=[], 33 | transform=True, 34 | ): 35 | """ 36 | :param base_dir: path to VOC dataset directory 37 | :param split: train/val 38 | :param transform: transform to apply 39 | """ 40 | super().__init__( 41 | args, 42 | base_dir, 43 | split, 44 | load_embedding, 45 | w2c_size, 46 | weak_label, 47 | unseen_classes_idx_weak, 48 | transform, 49 | ) 50 | 51 | self._image_dir = self._base_dir / "pascal/VOCdevkit/VOC2012/JPEGImages" 52 | self._cat_dir = self._base_dir / "full_annotations/trainval" 53 | 54 | self.unseen_classes_idx_weak = unseen_classes_idx_weak 55 | 56 | 57 | self.im_ids = [] 58 | self.categories = [] 59 | 60 | self.labels_459 = [ 61 | label.decode().replace(" ", "") 62 | for idx, label in np.genfromtxt( 63 | osp.join(self._base_dir, "full_annotations/labels.txt"), 64 | delimiter=":", 65 | dtype=None, 66 | ) 67 | ] 68 | self.labels_59 = [ 69 | label.decode().replace(" ", "") 70 | for idx, label in np.genfromtxt( 71 | osp.join(self._base_dir, "classes-59.txt"), delimiter=":", dtype=None 72 | ) 73 | ] 74 | for main_label, task_label in zip( 75 | ("table", "bedclothes", "cloth"), ("diningtable", "bedcloth", "clothes") 76 | ): 77 | self.labels_59[self.labels_59.index(task_label)] = main_label 78 | 79 | self.idx_59_to_idx_469 = {} 80 | for idx, l in enumerate(self.labels_59): 81 | if idx > 0: 82 | self.idx_59_to_idx_469[idx] = self.labels_459.index(l) + 1 83 | 84 | lines = (self._base_dir / f"{self.split}.txt").read_text().splitlines() 85 | 86 | for ii, line in enumerate(lines): 87 | _image = self._image_dir / f'{line}.jpg' 88 | _cat = self._cat_dir / f"{line}.mat" 89 | assert _image.is_file() 90 | assert _cat.is_file() 91 | 92 | # if unseen classes and training split 93 | if len(args.unseen_classes_idx) > 0: 94 | cat = self.load_label(_cat) 95 | if lbl_contains_unseen(cat, args.unseen_classes_idx): 96 | continue 97 | 98 | self.im_ids.append(line) 99 | self.images.append(_image) 100 | self.categories.append(_cat) 101 | 102 | assert len(self.images) == len(self.categories) 103 | 104 | # Display stats 105 | print( 106 | "(pascal) Number of images in {}: {:d}, {:d} deleted".format( 107 | split, len(self.images), len(lines) - len(self.images) 108 | ) 109 | ) 110 | 111 | def load_label(self, file_path): 112 | """ 113 | 
Load label image as 1 x height x width integer array of label indices. 114 | The leading singleton dimension is required by the loss. 115 | The full 459 labels are translated to the 59 class task labels. 116 | """ 117 | label_459 = scipy.io.loadmat(file_path)["LabelMap"] 118 | label = np.zeros_like(label_459, dtype=np.uint8) 119 | for idx, l in enumerate(self.labels_59): 120 | if idx > 0: 121 | label[label_459 == self.idx_59_to_idx_469[idx]] = idx 122 | return label 123 | 124 | def init_embeddings(self): 125 | if self.load_embedding == "my_w2c": 126 | embed_arr = np.load("embeddings/context/pascalcontext_class_w2c.npy") 127 | else: 128 | raise KeyError(self.load_embedding) 129 | self.make_embeddings(embed_arr) 130 | 131 | def __getitem__(self, index): 132 | _img, _target = self._make_img_gt_point_pair(index) 133 | 134 | if self.weak_label: 135 | unique_class = np.unique(np.array(_target)) 136 | has_unseen_class = False 137 | for u_class in unique_class: 138 | if u_class in self.unseen_classes_idx_weak: 139 | has_unseen_class = True 140 | if has_unseen_class: 141 | _target = Image.open( 142 | "weak_label_context_10_unseen_top_by_image_75.0/pascal/" 143 | + self.categories[index].stem 144 | + ".jpg" 145 | ) 146 | 147 | sample = {"image": _img, "label": _target} 148 | 149 | if self.transform: 150 | if self.split == "train": 151 | sample = self.transform_tr(sample) 152 | elif self.split == "val": 153 | sample = self.transform_val(sample) 154 | else: 155 | sample = self.transform_weak(sample) 156 | 157 | if self.load_embedding: 158 | self.get_embeddings(sample) 159 | sample["image_name"] = str(self.images[index]) 160 | return sample 161 | 162 | def _make_img_gt_point_pair(self, index): 163 | _img = Image.open(self.images[index]).convert("RGB") 164 | _target = self.load_label(self.categories[index]) 165 | _target = Image.fromarray(_target) 166 | return _img, _target 167 | 168 | def transform_tr(self, sample): 169 | composed_transforms = transforms.Compose( 170 | [ 171 | tr.RandomHorizontalFlip(), 172 | tr.RandomScaleCrop( 173 | base_size=self.args.base_size, 174 | crop_size=self.args.crop_size, 175 | fill=255, 176 | ), 177 | tr.RandomGaussianBlur(), 178 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 179 | tr.ToTensor(), 180 | ] 181 | ) 182 | 183 | return composed_transforms(sample) 184 | 185 | def transform_val(self, sample): 186 | 187 | composed_transforms = transforms.Compose( 188 | [ 189 | tr.FixScale(crop_size=self.args.crop_size), 190 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 191 | tr.ToTensor(), 192 | ] 193 | ) 194 | 195 | return composed_transforms(sample) 196 | 197 | def transform_weak(self, sample): 198 | 199 | composed_transforms = transforms.Compose( 200 | [ 201 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 202 | tr.ToTensor(), 203 | ] 204 | ) 205 | 206 | return composed_transforms(sample) 207 | 208 | def __str__(self): 209 | return f"VOC2012(split={self.split})" 210 | -------------------------------------------------------------------------------- /zs3/modeling/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from zs3.modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | expansion = 4 11 | 12 | def __init__( 13 | self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None 14 | ): 15 | super().__init__() 16 | 
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 17 | self.bn1 = BatchNorm(planes) 18 | self.conv2 = nn.Conv2d( 19 | planes, 20 | planes, 21 | kernel_size=3, 22 | stride=stride, 23 | dilation=dilation, 24 | padding=dilation, 25 | bias=False, 26 | ) 27 | self.bn2 = BatchNorm(planes) 28 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 29 | self.bn3 = BatchNorm(planes * 4) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = downsample 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv3(out) 45 | out = self.bn3(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class ResNet(nn.Module): 57 | def __init__( 58 | self, 59 | block, 60 | layers, 61 | output_stride, 62 | BatchNorm, 63 | pretrained=True, 64 | imagenet_pretrained_path="", 65 | ): 66 | self.inplanes = 64 67 | super().__init__() 68 | blocks = [1, 2, 4] 69 | if output_stride == 16: 70 | strides = [1, 2, 2, 1] 71 | dilations = [1, 1, 1, 2] 72 | elif output_stride == 8: 73 | strides = [1, 2, 1, 1] 74 | dilations = [1, 1, 2, 4] 75 | else: 76 | raise NotImplementedError 77 | 78 | # Modules 79 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 80 | self.bn1 = BatchNorm(64) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 83 | 84 | self.layer1 = self._make_layer( 85 | block, 86 | 64, 87 | layers[0], 88 | stride=strides[0], 89 | dilation=dilations[0], 90 | BatchNorm=BatchNorm, 91 | ) 92 | self.layer2 = self._make_layer( 93 | block, 94 | 128, 95 | layers[1], 96 | stride=strides[1], 97 | dilation=dilations[1], 98 | BatchNorm=BatchNorm, 99 | ) 100 | self.layer3 = self._make_layer( 101 | block, 102 | 256, 103 | layers[2], 104 | stride=strides[2], 105 | dilation=dilations[2], 106 | BatchNorm=BatchNorm, 107 | ) 108 | self.layer4 = self._make_MG_unit( 109 | block, 110 | 512, 111 | blocks=blocks, 112 | stride=strides[3], 113 | dilation=dilations[3], 114 | BatchNorm=BatchNorm, 115 | ) 116 | self._init_weight() 117 | 118 | if pretrained: 119 | self._load_pretrained_model(imagenet_pretrained_path) 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d( 126 | self.inplanes, 127 | planes * block.expansion, 128 | kernel_size=1, 129 | stride=stride, 130 | bias=False, 131 | ), 132 | BatchNorm(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append( 137 | block(self.inplanes, planes, stride, dilation, downsample, BatchNorm) 138 | ) 139 | self.inplanes = planes * block.expansion 140 | for i in range(1, blocks): 141 | layers.append( 142 | block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm) 143 | ) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 148 | downsample = None 149 | if stride != 1 or self.inplanes != planes * block.expansion: 150 | downsample = nn.Sequential( 151 | nn.Conv2d( 152 | self.inplanes, 153 | planes * block.expansion, 154 | kernel_size=1, 155 | stride=stride, 156 | bias=False, 157 | ), 158 | 
BatchNorm(planes * block.expansion), 159 | ) 160 | 161 | layers = [] 162 | layers.append( 163 | block( 164 | self.inplanes, 165 | planes, 166 | stride, 167 | dilation=blocks[0] * dilation, 168 | downsample=downsample, 169 | BatchNorm=BatchNorm, 170 | ) 171 | ) 172 | self.inplanes = planes * block.expansion 173 | for i in range(1, len(blocks)): 174 | layers.append( 175 | block( 176 | self.inplanes, 177 | planes, 178 | stride=1, 179 | dilation=blocks[i] * dilation, 180 | BatchNorm=BatchNorm, 181 | ) 182 | ) 183 | 184 | return nn.Sequential(*layers) 185 | 186 | def forward(self, input): 187 | x = self.conv1(input) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | x = self.maxpool(x) 191 | 192 | x = self.layer1(x) 193 | low_level_feat = x 194 | x = self.layer2(x) 195 | x = self.layer3(x) 196 | x = self.layer4(x) 197 | return x, low_level_feat 198 | 199 | def _init_weight(self): 200 | for m in self.modules(): 201 | if isinstance(m, nn.Conv2d): 202 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 203 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 204 | elif isinstance(m, SynchronizedBatchNorm2d): 205 | m.weight.data.fill_(1) 206 | m.bias.data.zero_() 207 | elif isinstance(m, nn.BatchNorm2d): 208 | m.weight.data.fill_(1) 209 | m.bias.data.zero_() 210 | 211 | def _load_pretrained_model(self, imagenet_pretrained_path): 212 | """ 213 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 214 | """ 215 | 216 | pretrain_dict = torch.load(imagenet_pretrained_path)["state_dict"] 217 | model_dict = {} 218 | state_dict = self.state_dict() 219 | for k, v in pretrain_dict.items(): 220 | 221 | k = k[7:] 222 | if k in state_dict: 223 | model_dict[k] = v 224 | state_dict.update(model_dict) 225 | 226 | self.load_state_dict(state_dict) 227 | 228 | 229 | def ResNet101(output_stride, BatchNorm, pretrained=True, imagenet_pretrained_path=""): 230 | """Constructs a ResNet-101 model. 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | """ 234 | model = ResNet( 235 | Bottleneck, 236 | [3, 4, 23, 3], 237 | output_stride, 238 | BatchNorm, 239 | pretrained=pretrained, 240 | imagenet_pretrained_path=imagenet_pretrained_path, 241 | ) 242 | return model 243 | -------------------------------------------------------------------------------- /zs3/modeling/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # File : batchnorm.py 2 | # Author : Jiayuan Mao 3 | # Email : maojiayuan@gmail.com 4 | # Date : 27/01/2018 5 | # 6 | # This file is part of Synchronized-BatchNorm-PyTorch. 7 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 8 | # Distributed under MIT License. 
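# Note: SynchronizedBatchNorm2d (defined below) is a drop-in replacement for nn.BatchNorm2d.
# When the model is wrapped in nn.DataParallel, it reduces batch statistics across all
# replicas through the SyncMaster / SlavePipe machinery from comm.py, instead of
# normalizing each device's slice independently.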
9 | 10 | import collections 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.nn.modules.batchnorm import _BatchNorm 15 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 16 | 17 | from .comm import SyncMaster 18 | 19 | __all__ = ["SynchronizedBatchNorm2d"] 20 | 21 | 22 | def _sum_ft(tensor): 23 | """sum over the first and last dimention""" 24 | return tensor.sum(dim=0).sum(dim=-1) 25 | 26 | 27 | def _unsqueeze_ft(tensor): 28 | """add new dementions at the front and the tail""" 29 | return tensor.unsqueeze(0).unsqueeze(-1) 30 | 31 | 32 | _ChildMessage = collections.namedtuple("_ChildMessage", ["sum", "ssum", "sum_size"]) 33 | _MasterMessage = collections.namedtuple("_MasterMessage", ["sum", "inv_std"]) 34 | 35 | 36 | class _SynchronizedBatchNorm(_BatchNorm): 37 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 38 | super().__init__(num_features, eps=eps, momentum=momentum, affine=affine) 39 | 40 | self._sync_master = SyncMaster(self._data_parallel_master) 41 | 42 | self._is_parallel = False 43 | self._parallel_id = None 44 | self._slave_pipe = None 45 | 46 | def forward(self, input): 47 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 48 | if not (self._is_parallel and self.training): 49 | return F.batch_norm( 50 | input, 51 | self.running_mean, 52 | self.running_var, 53 | self.weight, 54 | self.bias, 55 | self.training, 56 | self.momentum, 57 | self.eps, 58 | ) 59 | 60 | # Resize the input to (B, C, -1). 61 | input_shape = input.size() 62 | input = input.view(input.size(0), self.num_features, -1) 63 | 64 | # Compute the sum and square-sum. 65 | sum_size = input.size(0) * input.size(2) 66 | input_sum = _sum_ft(input) 67 | input_ssum = _sum_ft(input ** 2) 68 | 69 | # Reduce-and-broadcast the statistics. 70 | if self._parallel_id == 0: 71 | mean, inv_std = self._sync_master.run_master( 72 | _ChildMessage(input_sum, input_ssum, sum_size) 73 | ) 74 | else: 75 | mean, inv_std = self._slave_pipe.run_slave( 76 | _ChildMessage(input_sum, input_ssum, sum_size) 77 | ) 78 | 79 | # Compute the output. 80 | if self.affine: 81 | # MJY:: Fuse the multiplication for speed. 82 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft( 83 | inv_std * self.weight 84 | ) + _unsqueeze_ft(self.bias) 85 | else: 86 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 87 | 88 | # Reshape it. 89 | return output.view(input_shape) 90 | 91 | def __data_parallel_replicate__(self, ctx, copy_id): 92 | self._is_parallel = True 93 | self._parallel_id = copy_id 94 | 95 | # parallel_id == 0 means master device. 96 | if self._parallel_id == 0: 97 | ctx.sync_master = self._sync_master 98 | else: 99 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 100 | 101 | def _data_parallel_master(self, intermediates): 102 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 103 | 104 | # Always using same "device order" makes the ReduceAdd operation faster. 
105 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 106 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 107 | 108 | to_reduce = [i[1][:2] for i in intermediates] 109 | to_reduce = [j for i in to_reduce for j in i] # flatten 110 | target_gpus = [i[1].sum.get_device() for i in intermediates] 111 | 112 | sum_size = sum([i[1].sum_size for i in intermediates]) 113 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 114 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 115 | 116 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 117 | 118 | outputs = [] 119 | for i, rec in enumerate(intermediates): 120 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2 : i * 2 + 2]))) 121 | 122 | return outputs 123 | 124 | def _compute_mean_std(self, sum_, ssum, size): 125 | """Compute the mean and standard-deviation with sum and square-sum. This method 126 | also maintains the moving average on the master device.""" 127 | assert ( 128 | size > 1 129 | ), "BatchNorm computes unbiased standard-deviation, which requires size > 1." 130 | mean = sum_ / size 131 | sumvar = ssum - sum_ * mean 132 | unbias_var = sumvar / (size - 1) 133 | bias_var = sumvar / size 134 | 135 | self.running_mean = ( 136 | 1 - self.momentum 137 | ) * self.running_mean + self.momentum * mean.data 138 | self.running_var = ( 139 | 1 - self.momentum 140 | ) * self.running_var + self.momentum * unbias_var.data 141 | 142 | return mean, bias_var.clamp(self.eps) ** -0.5 143 | 144 | 145 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 146 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 147 | of 3d inputs 148 | .. math:: 149 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 150 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 151 | standard-deviation are reduced across all devices during training. 152 | For example, when one uses `nn.DataParallel` to wrap the network during 153 | training, PyTorch's implementation normalize the tensor on each device using 154 | the statistics only on that device, which accelerated the computation and 155 | is also easy to implement, but the statistics might be inaccurate. 156 | Instead, in this synchronized version, the statistics will be computed 157 | over all training samples distributed on multiple devices. 158 | 159 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 160 | as the built-in PyTorch implementation. 161 | The mean and standard-deviation are calculated per-dimension over 162 | the mini-batches and gamma and beta are learnable parameter vectors 163 | of size C (where C is the input size). 164 | During training, this layer keeps a running estimate of its computed mean 165 | and variance. The running sum is kept with a default momentum of 0.1. 166 | During evaluation, this running mean/variance is used for normalization. 167 | Because the BatchNorm is done over the `C` dimension, computing statistics 168 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 169 | Args: 170 | num_features: num_features from an expected input of 171 | size batch_size x num_features x height x width 172 | eps: a value added to the denominator for numerical stability. 173 | Default: 1e-5 174 | momentum: the value used for the running_mean and running_var 175 | computation. Default: 0.1 176 | affine: a boolean value that when set to ``True``, gives the layer learnable 177 | affine parameters. 
Default: ``True`` 178 | Shape: 179 | - Input: :math:`(N, C, H, W)` 180 | - Output: :math:`(N, C, H, W)` (same shape as input) 181 | Examples: 182 | >>> # With Learnable Parameters 183 | >>> m = SynchronizedBatchNorm2d(100) 184 | >>> # Without Learnable Parameters 185 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 186 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 187 | >>> output = m(input) 188 | """ 189 | 190 | def _check_input_dim(self, input): 191 | if input.dim() != 4: 192 | raise ValueError(f"expected 4D input (got {input.dim()}D input)") 193 | super()._check_input_dim(input) 194 | -------------------------------------------------------------------------------- /zs3/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Evaluator: 5 | def __init__(self, num_class, seen_classes_idx=None, unseen_classes_idx=None): 6 | self.num_class = num_class 7 | self.seen_classes_idx = seen_classes_idx 8 | self.unseen_classes_idx = unseen_classes_idx 9 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 10 | 11 | def Pixel_Accuracy(self): 12 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 13 | if self.seen_classes_idx and self.unseen_classes_idx: 14 | Acc_seen = ( 15 | np.diag(self.confusion_matrix)[self.seen_classes_idx].sum() 16 | / self.confusion_matrix[self.seen_classes_idx, :].sum() 17 | ) 18 | Acc_unseen = ( 19 | np.diag(self.confusion_matrix)[self.unseen_classes_idx].sum() 20 | / self.confusion_matrix[self.unseen_classes_idx, :].sum() 21 | ) 22 | return Acc, Acc_seen, Acc_unseen 23 | else: 24 | return Acc 25 | 26 | def Pixel_Accuracy_Class(self): 27 | Acc_by_class = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 28 | Acc = np.nanmean(np.nan_to_num(Acc_by_class)) 29 | if self.seen_classes_idx and self.unseen_classes_idx: 30 | Acc_seen = np.nanmean(np.nan_to_num(Acc_by_class[self.seen_classes_idx])) 31 | Acc_unseen = np.nanmean(np.nan_to_num(Acc_by_class[self.unseen_classes_idx])) 32 | return Acc, Acc_by_class, Acc_seen, Acc_unseen 33 | else: 34 | return Acc, Acc_by_class 35 | 36 | def Mean_Intersection_over_Union(self): 37 | MIoU_by_class = np.diag(self.confusion_matrix) / ( 38 | np.sum(self.confusion_matrix, axis=1) 39 | + np.sum(self.confusion_matrix, axis=0) 40 | - np.diag(self.confusion_matrix) 41 | ) 42 | MIoU = np.nanmean(np.nan_to_num(MIoU_by_class)) 43 | if self.seen_classes_idx and self.unseen_classes_idx: 44 | MIoU_seen = np.nanmean(np.nan_to_num(MIoU_by_class[self.seen_classes_idx])) 45 | MIoU_unseen = np.nanmean( 46 | np.nan_to_num(MIoU_by_class[self.unseen_classes_idx]) 47 | ) 48 | return MIoU, MIoU_by_class, MIoU_seen, MIoU_unseen 49 | else: 50 | return MIoU, MIoU_by_class 51 | 52 | def Frequency_Weighted_Intersection_over_Union(self): 53 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 54 | iu = np.diag(self.confusion_matrix) / ( 55 | np.sum(self.confusion_matrix, axis=1) 56 | + np.sum(self.confusion_matrix, axis=0) 57 | - np.diag(self.confusion_matrix) 58 | ) 59 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 60 | if self.seen_classes_idx and self.unseen_classes_idx: 61 | FWIoU_seen = ( 62 | freq[self.seen_classes_idx][freq[self.seen_classes_idx] > 0] 63 | * iu[self.seen_classes_idx][freq[self.seen_classes_idx] > 0] 64 | ).sum() 65 | FWIoU_unseen = ( 66 | freq[self.unseen_classes_idx][freq[self.unseen_classes_idx] > 0] 67 | * iu[self.unseen_classes_idx][freq[self.unseen_classes_idx] > 0] 
68 | ).sum() 69 | return FWIoU, FWIoU_seen, FWIoU_unseen 70 | else: 71 | return FWIoU 72 | 73 | def _generate_matrix(self, gt_image, pre_image): 74 | mask = (gt_image >= 0) & (gt_image < self.num_class) 75 | label = self.num_class * gt_image[mask].astype("int") + pre_image[mask] 76 | count = np.bincount(label, minlength=self.num_class ** 2) 77 | confusion_matrix = count.reshape(self.num_class, self.num_class) 78 | return confusion_matrix 79 | 80 | def add_batch(self, gt_image, pre_image): 81 | assert gt_image.shape == pre_image.shape 82 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 83 | 84 | def reset(self): 85 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 86 | 87 | 88 | class Evaluator_seen_unseen: 89 | def __init__(self, num_class, unseen_classes_idx): 90 | self.num_class = num_class 91 | self.unseen_classes_idx = unseen_classes_idx 92 | 93 | def _fast_hist(self, label_true, label_pred, n_class, target="all", unseen=None): 94 | mask = (label_true >= 0) & (label_true < n_class) 95 | 96 | if target == "unseen": 97 | mask_unseen = np.in1d(label_true.ravel(), unseen).reshape(label_true.shape) 98 | mask = mask & mask_unseen 99 | 100 | elif target == "seen": 101 | seen = [x for x in range(n_class) if x not in unseen] 102 | mask_seen = np.in1d(label_true.ravel(), seen).reshape(label_true.shape) 103 | mask = mask & mask_seen 104 | 105 | hist = np.bincount( 106 | n_class * label_true[mask].astype(int) + label_pred[mask], 107 | minlength=n_class ** 2, 108 | ).reshape(n_class, n_class) 109 | return hist 110 | 111 | def _fast_hist_specific_class(self, label_true, label_pred, n_class, target_class): 112 | mask = (label_true >= 0) & (label_true < n_class) 113 | mask_class = np.in1d(label_true.ravel(), target_class).reshape(label_true.shape) 114 | mask = mask & mask_class 115 | hist = np.bincount( 116 | n_class * label_true[mask].astype(int) + label_pred[mask], 117 | minlength=n_class ** 2, 118 | ).reshape(n_class, n_class) 119 | return hist 120 | 121 | def _hist_to_metrics(self, hist): 122 | if hist.sum() == 0: 123 | acc = 0.0 124 | else: 125 | acc = np.diag(hist).sum() / hist.sum() 126 | 127 | acc_cls = np.diag(hist) / hist.sum(axis=1) 128 | acc_cls = np.nanmean(acc_cls) 129 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 130 | mean_iu = np.nanmean(iu) 131 | freq = hist.sum(axis=1) / hist.sum() 132 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 133 | return acc, acc_cls, mean_iu, fwavacc 134 | 135 | def label_accuracy_score(self, label_trues, label_preds, by_class=False): 136 | """Returns accuracy score evaluation result. 
137 | - overall accuracy 138 | - mean accuracy 139 | - mean IU 140 | - fwavacc 141 | """ 142 | hist = np.zeros((self.num_class, self.num_class)) 143 | 144 | if self.unseen_classes_idx: 145 | unseen_hist, seen_hist = ( 146 | np.zeros((self.num_class, self.num_class)), 147 | np.zeros((self.num_class, self.num_class)), 148 | ) 149 | 150 | if by_class: 151 | class_hist = [] 152 | for class_idx in range(self.num_class): 153 | class_hist.append(np.zeros((self.num_class, self.num_class))) 154 | 155 | for lt, lp in zip(label_trues, label_preds): 156 | hist += self._fast_hist( 157 | lt.flatten(), lp.flatten(), self.num_class, target="all" 158 | ) 159 | if self.unseen_classes_idx: 160 | seen_hist += self._fast_hist( 161 | lt.flatten(), 162 | lp.flatten(), 163 | self.num_class, 164 | target="seen", 165 | unseen=self.unseen_classes_idx, 166 | ) 167 | unseen_hist += self._fast_hist( 168 | lt.flatten(), 169 | lp.flatten(), 170 | self.num_class, 171 | target="unseen", 172 | unseen=self.unseen_classes_idx, 173 | ) 174 | 175 | if by_class: 176 | unique = np.unique(lt.flatten()).astype(np.int32) 177 | for class_idx in unique: 178 | if class_idx != 255: 179 | class_hist[class_idx] += self._fast_hist_specific_class( 180 | lt.flatten(), lp.flatten(), self.num_class, class_idx 181 | ) 182 | 183 | metrics = self._hist_to_metrics(hist) 184 | if self.unseen_classes_idx: 185 | seen_metrics, unseen_metrics = ( 186 | self._hist_to_metrics(seen_hist), 187 | self._hist_to_metrics(unseen_hist), 188 | ) 189 | metrics = metrics, seen_metrics, unseen_metrics 190 | 191 | if by_class: 192 | class_metrics = [] 193 | for class_idx in range(self.num_class): 194 | class_metrics.append(self._hist_to_metrics(class_hist[class_idx])) 195 | 196 | metrics = metrics 197 | if by_class: 198 | return metrics, class_metrics 199 | else: 200 | return metrics 201 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ZS3 2 | 3 | Copyright 2019 Valeo 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | 17 | 18 | 19 | Apache License 20 | Version 2.0, January 2004 21 | https://www.apache.org/licenses/ 22 | 23 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 24 | 25 | 1. Definitions. 26 | 27 | "License" shall mean the terms and conditions for use, reproduction, 28 | and distribution as defined by Sections 1 through 9 of this document. 29 | 30 | "Licensor" shall mean the copyright owner or entity authorized by 31 | the copyright owner that is granting the License. 32 | 33 | "Legal Entity" shall mean the union of the acting entity and all 34 | other entities that control, are controlled by, or are under common 35 | control with that entity. 
For the purposes of this definition, 36 | "control" means (i) the power, direct or indirect, to cause the 37 | direction or management of such entity, whether by contract or 38 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 39 | outstanding shares, or (iii) beneficial ownership of such entity. 40 | 41 | "You" (or "Your") shall mean an individual or Legal Entity 42 | exercising permissions granted by this License. 43 | 44 | "Source" form shall mean the preferred form for making modifications, 45 | including but not limited to software source code, documentation 46 | source, and configuration files. 47 | 48 | "Object" form shall mean any form resulting from mechanical 49 | transformation or translation of a Source form, including but 50 | not limited to compiled object code, generated documentation, 51 | and conversions to other media types. 52 | 53 | "Work" shall mean the work of authorship, whether in Source or 54 | Object form, made available under the License, as indicated by a 55 | copyright notice that is included in or attached to the work 56 | (an example is provided in the Appendix below). 57 | 58 | "Derivative Works" shall mean any work, whether in Source or Object 59 | form, that is based on (or derived from) the Work and for which the 60 | editorial revisions, annotations, elaborations, or other modifications 61 | represent, as a whole, an original work of authorship. For the purposes 62 | of this License, Derivative Works shall not include works that remain 63 | separable from, or merely link (or bind by name) to the interfaces of, 64 | the Work and Derivative Works thereof. 65 | 66 | "Contribution" shall mean any work of authorship, including 67 | the original version of the Work and any modifications or additions 68 | to that Work or Derivative Works thereof, that is intentionally 69 | submitted to Licensor for inclusion in the Work by the copyright owner 70 | or by an individual or Legal Entity authorized to submit on behalf of 71 | the copyright owner. For the purposes of this definition, "submitted" 72 | means any form of electronic, verbal, or written communication sent 73 | to the Licensor or its representatives, including but not limited to 74 | communication on electronic mailing lists, source code control systems, 75 | and issue tracking systems that are managed by, or on behalf of, the 76 | Licensor for the purpose of discussing and improving the Work, but 77 | excluding communication that is conspicuously marked or otherwise 78 | designated in writing by the copyright owner as "Not a Contribution." 79 | 80 | "Contributor" shall mean Licensor and any individual or Legal Entity 81 | on behalf of whom a Contribution has been received by Licensor and 82 | subsequently incorporated within the Work. 83 | 84 | 2. Grant of Copyright License. Subject to the terms and conditions of 85 | this License, each Contributor hereby grants to You a perpetual, 86 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 87 | copyright license to reproduce, prepare Derivative Works of, 88 | publicly display, publicly perform, sublicense, and distribute the 89 | Work and such Derivative Works in Source or Object form. 90 | 91 | 3. Grant of Patent License. 
Subject to the terms and conditions of 92 | this License, each Contributor hereby grants to You a perpetual, 93 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 94 | (except as stated in this section) patent license to make, have made, 95 | use, offer to sell, sell, import, and otherwise transfer the Work, 96 | where such license applies only to those patent claims licensable 97 | by such Contributor that are necessarily infringed by their 98 | Contribution(s) alone or by combination of their Contribution(s) 99 | with the Work to which such Contribution(s) was submitted. If You 100 | institute patent litigation against any entity (including a 101 | cross-claim or counterclaim in a lawsuit) alleging that the Work 102 | or a Contribution incorporated within the Work constitutes direct 103 | or contributory patent infringement, then any patent licenses 104 | granted to You under this License for that Work shall terminate 105 | as of the date such litigation is filed. 106 | 107 | 4. Redistribution. You may reproduce and distribute copies of the 108 | Work or Derivative Works thereof in any medium, with or without 109 | modifications, and in Source or Object form, provided that You 110 | meet the following conditions: 111 | 112 | (a) You must give any other recipients of the Work or 113 | Derivative Works a copy of this License; and 114 | 115 | (b) You must cause any modified files to carry prominent notices 116 | stating that You changed the files; and 117 | 118 | (c) You must retain, in the Source form of any Derivative Works 119 | that You distribute, all copyright, patent, trademark, and 120 | attribution notices from the Source form of the Work, 121 | excluding those notices that do not pertain to any part of 122 | the Derivative Works; and 123 | 124 | (d) If the Work includes a "NOTICE" text file as part of its 125 | distribution, then any Derivative Works that You distribute must 126 | include a readable copy of the attribution notices contained 127 | within such NOTICE file, excluding those notices that do not 128 | pertain to any part of the Derivative Works, in at least one 129 | of the following places: within a NOTICE text file distributed 130 | as part of the Derivative Works; within the Source form or 131 | documentation, if provided along with the Derivative Works; or, 132 | within a display generated by the Derivative Works, if and 133 | wherever such third-party notices normally appear. The contents 134 | of the NOTICE file are for informational purposes only and 135 | do not modify the License. You may add Your own attribution 136 | notices within Derivative Works that You distribute, alongside 137 | or as an addendum to the NOTICE text from the Work, provided 138 | that such additional attribution notices cannot be construed 139 | as modifying the License. 140 | 141 | You may add Your own copyright statement to Your modifications and 142 | may provide additional or different license terms and conditions 143 | for use, reproduction, or distribution of Your modifications, or 144 | for any such Derivative Works as a whole, provided Your use, 145 | reproduction, and distribution of the Work otherwise complies with 146 | the conditions stated in this License. 147 | 148 | 5. Submission of Contributions. Unless You explicitly state otherwise, 149 | any Contribution intentionally submitted for inclusion in the Work 150 | by You to the Licensor shall be under the terms and conditions of 151 | this License, without any additional terms or conditions. 
152 | Notwithstanding the above, nothing herein shall supersede or modify 153 | the terms of any separate license agreement you may have executed 154 | with Licensor regarding such Contributions. 155 | 156 | 6. Trademarks. This License does not grant permission to use the trade 157 | names, trademarks, service marks, or product names of the Licensor, 158 | except as required for reasonable and customary use in describing the 159 | origin of the Work and reproducing the content of the NOTICE file. 160 | 161 | 7. Disclaimer of Warranty. Unless required by applicable law or 162 | agreed to in writing, Licensor provides the Work (and each 163 | Contributor provides its Contributions) on an "AS IS" BASIS, 164 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 165 | implied, including, without limitation, any warranties or conditions 166 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 167 | PARTICULAR PURPOSE. You are solely responsible for determining the 168 | appropriateness of using or redistributing the Work and assume any 169 | risks associated with Your exercise of permissions under this License. 170 | 171 | 8. Limitation of Liability. In no event and under no legal theory, 172 | whether in tort (including negligence), contract, or otherwise, 173 | unless required by applicable law (such as deliberate and grossly 174 | negligent acts) or agreed to in writing, shall any Contributor be 175 | liable to You for damages, including any direct, indirect, special, 176 | incidental, or consequential damages of any character arising as a 177 | result of this License or out of the use or inability to use the 178 | Work (including but not limited to damages for loss of goodwill, 179 | work stoppage, computer failure or malfunction, or any and all 180 | other commercial damages or losses), even if such Contributor 181 | has been advised of the possibility of such damages. 182 | 183 | 9. Accepting Warranty or Additional Liability. While redistributing 184 | the Work or Derivative Works thereof, You may choose to offer, 185 | and charge a fee for, acceptance of support, warranty, indemnity, 186 | or other liability obligations and/or rights consistent with this 187 | License. However, in accepting such obligations, You may act only 188 | on Your own behalf and on Your sole responsibility, not on behalf 189 | of any other Contributor, and only if You agree to indemnify, 190 | defend, and hold each Contributor harmless for any liability 191 | incurred by, or claims asserted against, such Contributor by reason 192 | of your accepting any such warranty or additional liability. 
193 | 194 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zero-Shot Semantic Segmentation 2 | 3 | ## Paper 4 | ![](./teaser.png) 5 | 6 | [Zero-Shot Semantic Segmentation](https://arxiv.org/pdf/1906.00817.pdf) 7 | [Maxime Bucher](https://maximebucher.github.io/), [Tuan-Hung Vu](https://tuanhungvu.github.io/), [Matthieu Cord](http://webia.lip6.fr/~cord/), [Patrick Pérez](https://ptrckprz.github.io/) 8 | valeo.ai, France 9 | Neural Information Processing Systems (NeurIPS) 2019 10 | 11 | If you find this code useful for your research, please cite our [paper](https://arxiv.org/pdf/1906.00817.pdf): 12 | 13 | ``` 14 | @inproceedings{bucher2019zero, 15 | title={Zero-Shot Semantic Segmentation}, 16 | author={Bucher, Maxime and Vu, Tuan-Hung and Cord, Mathieu and P{\'e}rez, Patrick}, 17 | booktitle={NeurIPS}, 18 | year={2019} 19 | } 20 | ``` 21 | 22 | ## Abstract 23 | Semantic segmentation models are limited in their ability to scale to large numbers of object classes. In this paper, we introduce the new task of zero-shot semantic segmentation: learning pixel-wise classifiers for never-seen object categories with zero training examples. To this end, we present a novel architecture, ZS3Net, combining a deep visual segmentation model with an approach to generate visual representations from semantic word embeddings. In this way, ZS3Net addresses pixel classification tasks where both seen and unseen categories are faced at test time (so-called "generalized" zero-shot classification). Performance is further improved by a self-training step that relies on automatic pseudo-labeling of pixels from unseen classes. On the two standard segmentation datasets, Pascal-VOC and Pascal-Context, we propose zero-shot benchmarks and set competitive baselines. For complex scenes such as those in the Pascal-Context dataset, we extend our approach by using a graph-context encoding to fully leverage spatial context priors coming from class-wise segmentation maps. 24 | 25 | ## Code 26 | 27 | ### Pre-requisites 28 | * Python 3.6 29 | * PyTorch 1.0 or higher 30 | * CUDA 9.0 or higher 31 | 32 | ### Installation 33 | 1. Clone the repo: 34 | ```bash 35 | $ git clone https://github.com/valeoai/ZS3 36 | ``` 37 | 38 | 2. Install this repository and the dependencies using pip: 39 | ```bash 40 | $ pip install -e ZS3 41 | ``` 42 | 43 | With this, you can edit the ZS3 code on the fly and import the functions and classes of ZS3 in other projects as well. 44 | 45 | 3. Optional. To uninstall this package, run: 46 | ```bash 47 | $ pip uninstall ZS3 48 | ``` 49 | 50 | You can take a look at the Dockerfile if you are uncertain about the steps to install this project. 51 | 52 | ### Datasets 53 | 54 | #### Pascal-VOC 2012 55 | * **Pascal-VOC 2012**: Please follow the instructions [here](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html) to download images and semantic segmentation annotations. 56 | 57 | * **Semantic Boundaries Dataset**: Please follow the instructions [here](http://home.bharathh.info/pubs/codes/SBD/download.html) to download images and semantic segmentation annotations. Use [this](http://home.bharathh.info/pubs/codes/SBD/train_noval.txt) train set, which excludes overlap with the Pascal-VOC validation set.
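A minimal sketch of one way to move the downloaded data into the layout expected below. The archive names `VOCtrainval_11-May-2012.tar` and `benchmark.tgz` are assumptions; adapt the commands to the files you actually downloaded:

```bash
# Assumed archive names -- adjust to whatever you downloaded.
mkdir -p ZS3/data
tar -xf VOCtrainval_11-May-2012.tar        # extracts VOCdevkit/VOC2012
mv VOCdevkit/VOC2012 ZS3/data/VOC2012
tar -xzf benchmark.tgz                     # extracts benchmark_RELEASE
mv benchmark_RELEASE ZS3/data/VOC2012/benchmark_RELEASE
wget -O ZS3/data/VOC2012/benchmark_RELEASE/dataset/train_noval.txt \
    http://home.bharathh.info/pubs/codes/SBD/train_noval.txt
```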
58 | 59 | The Pascal-VOC and SBD datasets directory should have this structure: 60 | ```bash 61 | ZS3/data/VOC2012/ % Pascal VOC and SBD datasets root 62 | ZS3/data/VOC2012/ImageSets/Segmentation/ % Pascal VOC splits 63 | ZS3/data/VOC2012/JPEGImages/ % Pascal VOC images 64 | ZS3/data/VOC2012/SegmentationClass/ % Pascal VOC segmentation maps 65 | ZS3/data/VOC2012/benchmark_RELEASE/dataset/img % SBD images 66 | ZS3/data/VOC2012/benchmark_RELEASE/dataset/cls % SBD segmentation maps 67 | ZS3/data/VOC2012/benchmark_RELEASE/dataset/train_noval.txt % SBD train set 68 | ``` 69 | 70 | 71 | #### Pascal-Context 72 | 73 | * **Pascal-VOC 2010**: Please follow the instructions [here](http://host.robots.ox.ac.uk/pascal/VOC/voc2010/index.html) to download images. 74 | 75 | * **Pascal-Context**: Please follow the instructions [here](https://cs.stanford.edu/~roozbeh/pascal-context/) to download segmentation annotations. 76 | 77 | The Pascal-Context dataset directory should have this structure: 78 | ```bash 79 | ZS3/data/context/ % Pascal context dataset root 80 | ZS3/data/context/train.txt % Pascal context train split 81 | ZS3/data/context/val.txt % Pascal context val split 82 | ZS3/data/context/full_annotations/trainval/ % Pascal context segmentation maps 83 | ZS3/data/context/full_annotations/labels.txt % Pascal context 459 classes 84 | ZS3/data/context/classes-59.txt % Pascal context 59 classes 85 | ZS3/data/context/VOCdevkit/VOC2010/JPEGImages % Pascal VOC images 86 | ``` 87 | 88 | ### Training 89 | 90 | #### Pascal-VOC 91 | Follow steps below to train your model: 92 | 93 | 1. Train deeplabv3+ using Pascal VOC dataset and ResNet as backbone, pretrained on imagenet ([weights here](https://github.com/valeoai/ZS3/releases/download/v0.1/resnet_backbone_pretrained_imagenet_wo_pascalvoc.pth.tar)): 94 | 95 | ```Shell 96 | python train_pascal.py 97 | ``` 98 | * Main options 99 | - `imagenet_pretrained_path`: Path to ImageNet pretrained weights. 100 | - `exp_path`: Path to saved logs and weights folder. 101 | - `checkname`: Name of the saved logs and weights folder. 102 | - `unseen_classes_idx`: List of idx of unseen classes. 103 | 104 | * Trained deeplabv3+ weights 105 | - [2 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pretrained_pascal_voc_02_unseen.pth.tar) 106 | - [4 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pretrained_pascal_voc_04_unseen.pth.tar) 107 | - [6 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pretrained_pascal_voc_06_unseen.pth.tar) 108 | - [8 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pretrained_pascal_voc_08_unseen.pth.tar) 109 | - [10 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pretrained_pascal_voc_10_unseen.pth.tar) 110 | 111 | 112 | 113 | 114 | 2. Train GMMN and finetune the last classification layer of the trained deeplabv3+ model: 115 | 116 | ```Shell 117 | python train_pascal_GMMN.py 118 | ``` 119 | * Main options 120 | - `imagenet_pretrained_path`: Path to ImageNet pretrained weights. 121 | - `resume`: Path to deeplabv3+ weights. 122 | - `exp_path`: Path to saved logs and weights folder. 123 | - `checkname`: Name of the saved logs and weights folder. 124 | - `seen_classes_idx_metric`: List of idx of seen classes. 125 | - `unseen_classes_idx_metric`: List of idx of unseen classes. 
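For instance, a run with 10 unseen classes might be launched as follows. This is only a sketch: the flag spellings and checkpoint paths are assumptions based on the option names above; check `zs3/parsing.py` and the defaults inside the script for the authoritative argument names:

```Shell
python train_pascal_GMMN.py \
    --imagenet_pretrained_path checkpoints/resnet_backbone_pretrained_imagenet_wo_pascalvoc.pth.tar \
    --resume checkpoints/deeplab_pretrained_pascal_voc_10_unseen.pth.tar \
    --exp_path run/ \
    --checkname gmmn_pascal_voc_10_unseen
```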
126 | 127 | * Final deeplabv3+ and GMMN weights 128 | - [2 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pascal_voc_02_unseen_GMMN_final.pth.tar) 129 | - [4 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pascal_voc_04_unseen_GMMN_final.pth.tar) 130 | - [6 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pascal_voc_06_unseen_GMMN_final.pth.tar) 131 | - [8 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pascal_voc_08_unseen_GMMN_final.pth.tar) 132 | - [10 unseen classes](https://github.com/valeoai/ZS3/releases/download/v0.1/deeplab_pascal_voc_10_unseen_GMMN_final.pth.tar) 133 | 134 | 135 | #### Pascal-Context 136 | Follow steps below to train your model: 137 | 138 | 1. Train deeplabv3+ using Pascal Context dataset and ResNet as backbone, pretrained on imagenet ([weights here](https://github.com/valeoai/ZS3/releases/download/0.2/resnet_backbone_pretrained_imagenet_wo_pascalcontext.pth.tar)): 139 | 140 | ```Shell 141 | python train_context.py 142 | ``` 143 | * Main options 144 | - `imagenet_pretrained_path`: Path to ImageNet pretrained weights. 145 | - `exp_path`: Path to saved logs and weights folder. 146 | - `checkname`: Name of the saved logs and weights folder. 147 | - `unseen_classes_idx`: List of idx of unseen classes. 148 | 149 | * Trained deeplabv3+ weights 150 | - [2 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pretrained_pascal_context_02_unseen.pth.tar) 151 | - [4 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pretrained_pascal_context_04_unseen.pth.tar) 152 | - [6 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pretrained_pascal_context_06_unseen.pth.tar) 153 | - [8 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pretrained_pascal_context_08_unseen.pth.tar) 154 | - [10 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pretrained_pascal_context_10_unseen.pth.tar) 155 | 156 | 157 | 158 | 2. Train GMMN and finetune the last classification layer of the trained deeplabv3+ model: 159 | 160 | ```Shell 161 | python train_context_GMMN.py 162 | ``` 163 | * Main options 164 | - `imagenet_pretrained_path`: Path to ImageNet pretrained weights. 165 | - `resume`: Path to deeplabv3+ weights. 166 | - `exp_path`: Path to saved logs and weights folder. 167 | - `checkname`: Name of the saved logs and weights folder. 168 | - `seen_classes_idx_metric`: List of idx of seen classes. 169 | - `unseen_classes_idx_metric`: List of idx of unseen classes. 170 | 171 | * Final deeplabv3+ and GMMN weights 172 | - [2 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_02_unseen_GMMN_final.pth.tar) 173 | - [4 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_04_unseen_GMMN_final.pth.tar) 174 | - [6 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_06_unseen_GMMN_final.pth.tar) 175 | - [8 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_08_unseen_GMMN_final.pth.tar) 176 | - [10 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_10_unseen_GMMN_final.pth.tar) 177 | 178 | 179 | (2 bis). 
Train GMMN with graph context and finetune the last classification layer of the trained deeplabv3+ model: 180 | 181 | ```Shell 182 | python train_context_GMMN_GCNcontext.py 183 | ``` 184 | * Main options 185 | - `imagenet_pretrained_path`: Path to ImageNet pretrained weights. 186 | - `resume`: Path to deeplabv3+ weights. 187 | - `exp_path`: Path to saved logs and weights folder. 188 | - `checkname`: Name of the saved logs and weights folder. 189 | - `seen_classes_idx_metric`: List of idx of seen classes. 190 | - `unseen_classes_idx_metric`: List of idx of unseen classes. 191 | 192 | * Final deeplabv3+ and GMMN with graph context weights 193 | - [2 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_02_unseen_GMMN_GC_final.pth.tar) 194 | - [4 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_04_unseen_GMMN_GC_final.pth.tar) 195 | - [6 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_06_unseen_GMMN_GC_final.pth.tar) 196 | - [8 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_08_unseen_GMMN_GC_final.pth.tar) 197 | - [10 unseen classes](https://github.com/valeoai/ZS3/releases/download/0.2/deeplab_pascal_context_10_unseen_GMMN_GC_final.pth.tar) 198 | 199 | ### Testing 200 | 201 | ```Shell 202 | python eval_pascal.py 203 | ``` 204 | ```Shell 205 | python eval_context.py 206 | ``` 207 | * Main options 208 | - `resume`: Path to deeplabv3+ and GMMN weights. 209 | - `seen_classes_idx_metric`: List of idx of seen classes. 210 | - `unseen_classes_idx_metric`: List of idx of unseen classes. 211 | 212 | ## Acknowledgements 213 | * This codebase is heavily borrowed from [pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception). 214 | * Special thanks for [@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse) for his work in enhancing, cleaning and formatting this repository for release. 215 | 216 | ## License 217 | ZS3Net is released under the [Apache 2.0 license](./LICENSE). 
218 | -------------------------------------------------------------------------------- /zs3/train_context.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from zs3.dataloaders import make_data_loader 8 | from zs3.modeling.deeplab import DeepLab 9 | from zs3.modeling.sync_batchnorm.replicate import patch_replication_callback 10 | from zs3.dataloaders.datasets import DATASETS_DIRS 11 | from zs3.utils.calculate_weights import calculate_weigths_labels 12 | from zs3.utils.loss import SegmentationLosses 13 | from zs3.utils.lr_scheduler import LR_Scheduler 14 | from zs3.utils.metrics import Evaluator 15 | from zs3.utils.saver import Saver 16 | from zs3.utils.summaries import TensorboardSummary 17 | from zs3.parsing import get_parser 18 | from zs3.exp_data import CLASSES_NAMES 19 | from zs3.base_trainer import BaseTrainer 20 | 21 | 22 | class Trainer(BaseTrainer): 23 | def __init__(self, args): 24 | self.args = args 25 | 26 | # Define Saver 27 | self.saver = Saver(args) 28 | self.saver.save_experiment_config() 29 | # Define Tensorboard Summary 30 | self.summary = TensorboardSummary(self.saver.experiment_dir) 31 | self.writer = self.summary.create_summary() 32 | 33 | # Define Dataloader 34 | kwargs = {"num_workers": args.workers, "pin_memory": True} 35 | (self.train_loader, self.val_loader, _, self.nclass,) = make_data_loader( 36 | args, **kwargs 37 | ) 38 | 39 | # Define network 40 | model = DeepLab( 41 | num_classes=self.nclass, 42 | output_stride=args.out_stride, 43 | sync_bn=args.sync_bn, 44 | freeze_bn=args.freeze_bn, 45 | pretrained=args.imagenet_pretrained, 46 | imagenet_pretrained_path=args.imagenet_pretrained_path, 47 | ) 48 | 49 | train_params = [ 50 | {"params": model.get_1x_lr_params(), "lr": args.lr}, 51 | {"params": model.get_10x_lr_params(), "lr": args.lr * 10}, 52 | ] 53 | 54 | # Define Optimizer 55 | optimizer = torch.optim.SGD( 56 | train_params, 57 | momentum=args.momentum, 58 | weight_decay=args.weight_decay, 59 | nesterov=args.nesterov, 60 | ) 61 | 62 | # Define Criterion 63 | # whether to use class balanced weights 64 | if args.use_balanced_weights: 65 | classes_weights_path = ( 66 | DATASETS_DIRS[args.dataset] / args.dataset + "_classes_weights.npy" 67 | ) 68 | if os.path.isfile(classes_weights_path): 69 | weight = np.load(classes_weights_path) 70 | else: 71 | weight = calculate_weigths_labels( 72 | args.dataset, self.train_loader, self.nclass 73 | ) 74 | weight = torch.from_numpy(weight.astype(np.float32)) 75 | else: 76 | weight = None 77 | self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss( 78 | mode=args.loss_type 79 | ) 80 | self.model, self.optimizer = model, optimizer 81 | 82 | # Define Evaluator 83 | self.evaluator = Evaluator(self.nclass) 84 | # Define lr scheduler 85 | self.scheduler = LR_Scheduler( 86 | args.lr_scheduler, args.lr, args.epochs, len(self.train_loader) 87 | ) 88 | 89 | # Using cuda 90 | if args.cuda: 91 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 92 | patch_replication_callback(self.model) 93 | self.model = self.model.cuda() 94 | 95 | # Resuming checkpoint 96 | self.best_pred = 0.0 97 | if args.resume is not None: 98 | if not os.path.isfile(args.resume): 99 | raise RuntimeError(f"=> no checkpoint found at '{args.resume}'") 100 | checkpoint = torch.load(args.resume) 101 | args.start_epoch = checkpoint["epoch"] 102 | if args.cuda: 103 | 
self.model.module.load_state_dict(checkpoint["state_dict"]) 104 | else: 105 | self.model.load_state_dict(checkpoint["state_dict"]) 106 | if not args.ft: 107 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 108 | self.best_pred = checkpoint["best_pred"] 109 | print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 110 | 111 | # Clear start epoch if fine-tuning 112 | if args.ft: 113 | args.start_epoch = 0 114 | 115 | def validation(self, epoch): 116 | self.model.eval() 117 | self.evaluator.reset() 118 | tbar = tqdm(self.val_loader, desc="\r") 119 | test_loss = 0.0 120 | for i, sample in enumerate(tbar): 121 | image, target = sample["image"], sample["label"] 122 | if self.args.cuda: 123 | image, target = image.cuda(), target.cuda() 124 | with torch.no_grad(): 125 | output = self.model(image) 126 | loss = self.criterion(output, target) 127 | test_loss += loss.item() 128 | tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1))) 129 | pred = output.data.cpu().numpy() 130 | target = target.cpu().numpy() 131 | pred = np.argmax(pred, axis=1) 132 | # Add batch sample into evaluator 133 | self.evaluator.add_batch(target, pred) 134 | 135 | # Fast test during the training 136 | Acc = self.evaluator.Pixel_Accuracy() 137 | Acc_class, Acc_class_by_class = self.evaluator.Pixel_Accuracy_Class() 138 | mIoU, mIoU_by_class = self.evaluator.Mean_Intersection_over_Union() 139 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 140 | self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch) 141 | self.writer.add_scalar("val/mIoU", mIoU, epoch) 142 | self.writer.add_scalar("val/Acc", Acc, epoch) 143 | self.writer.add_scalar("val/Acc_class", Acc_class, epoch) 144 | self.writer.add_scalar("val/fwIoU", FWIoU, epoch) 145 | print("Validation:") 146 | print( 147 | "[Epoch: %d, numImages: %5d]" 148 | % (epoch, i * self.args.batch_size + image.data.shape[0]) 149 | ) 150 | print(f"Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}") 151 | print(f"Loss: {test_loss:.3f}") 152 | 153 | for i, (class_name, acc_value, mIoU_value) in enumerate( 154 | zip(CLASSES_NAMES, Acc_class_by_class, mIoU_by_class) 155 | ): 156 | self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch) 157 | self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch) 158 | print(CLASSES_NAMES[i], "- acc:", acc_value, " mIoU:", mIoU_value) 159 | 160 | new_pred = mIoU 161 | is_best = True 162 | self.best_pred = new_pred 163 | self.saver.save_checkpoint( 164 | { 165 | "epoch": epoch + 1, 166 | "state_dict": self.model.module.state_dict(), 167 | "optimizer": self.optimizer.state_dict(), 168 | "best_pred": self.best_pred, 169 | }, 170 | is_best, 171 | ) 172 | 173 | 174 | def main(): 175 | parser = get_parser() 176 | parser.add_argument( 177 | "--imagenet_pretrained", 178 | type=bool, 179 | default=True, 180 | help="imagenet pretrained backbone", 181 | ) 182 | 183 | parser.add_argument( 184 | "--out-stride", type=int, default=16, help="network output stride (default: 8)" 185 | ) 186 | 187 | # PASCAL VOC 188 | parser.add_argument( 189 | "--dataset", 190 | type=str, 191 | default="context", 192 | choices=["pascal", "coco", "cityscapes", "context"], 193 | help="dataset name (default: pascal)", 194 | ) 195 | 196 | parser.add_argument("--base-size", type=int, default=312, help="base image size") 197 | parser.add_argument("--crop-size", type=int, default=312, help="crop image size") 198 | parser.add_argument( 199 | "--loss-type", 200 | type=str, 201 | default="ce", 202 | 
choices=["ce", "focal"], 203 | help="loss func type (default: ce)", 204 | ) 205 | # training hyper params 206 | 207 | # PASCAL VOC 208 | parser.add_argument( 209 | "--epochs", 210 | type=int, 211 | default=200, 212 | metavar="N", 213 | help="number of epochs to train (default: auto)", 214 | ) 215 | 216 | # PASCAL VOC 217 | parser.add_argument( 218 | "--batch-size", 219 | type=int, 220 | default=10, 221 | metavar="N", 222 | help="input batch size for training (default: auto)", 223 | ) 224 | # checking point 225 | parser.add_argument( 226 | "--resume", 227 | type=str, 228 | default=None, 229 | help="put the path to resuming file if needed", 230 | ) 231 | parser.add_argument( 232 | "--checkname", 233 | type=str, 234 | default="context_2_unseen", 235 | help="set the checkpoint name", 236 | ) 237 | 238 | parser.add_argument( 239 | "--imagenet_pretrained_path", 240 | type=str, 241 | default="checkpoint/resnet_backbone_pretrained_imagenet_wo_pascalcontext.pth.tar", 242 | help="set the checkpoint name", 243 | ) 244 | 245 | # evaluation option 246 | parser.add_argument( 247 | "--eval-interval", type=int, default=10, help="evaluation interval (default: 1)" 248 | ) 249 | 250 | # 2 unseen 251 | unseen_names = ["cow", "motorbike"] 252 | # 4 unseen 253 | # unseen_names = ['cow', 'motorbike', 'sofa', 'cat'] 254 | # 6 unseen 255 | # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence'] 256 | # 8 unseen 257 | # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence', 'bird', 'tvmonitor'] 258 | # 10 unseen 259 | # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence', 'bird', 'tvmonitor', 'aeroplane', 'keyboard'] 260 | 261 | unseen_classes_idx = [] 262 | for name in unseen_names: 263 | unseen_classes_idx.append(CLASSES_NAMES.index(name)) 264 | print(unseen_classes_idx) 265 | # all classes 266 | parser.add_argument("--unseen_classes_idx", type=int, default=unseen_classes_idx) 267 | args = parser.parse_args() 268 | args.cuda = not args.no_cuda and torch.cuda.is_available() 269 | if args.cuda: 270 | try: 271 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(",")] 272 | except ValueError: 273 | raise ValueError( 274 | "Argument --gpu_ids must be a comma-separated list of integers only" 275 | ) 276 | 277 | args.sync_bn = args.cuda and len(args.gpu_ids) > 1 278 | 279 | # default settings for epochs, batch_size and lr 280 | if args.epochs is None: 281 | epoches = { 282 | "coco": 30, 283 | "cityscapes": 200, 284 | "pascal": 50, 285 | "pascal": 150, 286 | } 287 | args.epochs = epoches[args.dataset.lower()] 288 | 289 | if args.batch_size is None: 290 | args.batch_size = 4 * len(args.gpu_ids) 291 | 292 | if args.test_batch_size is None: 293 | args.test_batch_size = args.batch_size 294 | 295 | if args.lr is None: 296 | lrs = { 297 | "coco": 0.1, 298 | "cityscapes": 0.01, 299 | "pascal": 0.007, 300 | } 301 | args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size 302 | 303 | if args.checkname is None: 304 | args.checkname = "deeplab-resnet" 305 | print(args) 306 | torch.manual_seed(args.seed) 307 | trainer = Trainer(args) 308 | print("Starting Epoch:", trainer.args.start_epoch) 309 | print("Total Epoches:", trainer.args.epochs) 310 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs): 311 | trainer.training(epoch) 312 | if not trainer.args.no_val and epoch % args.eval_interval == ( 313 | args.eval_interval - 1 314 | ): 315 | trainer.validation(epoch) 316 | trainer.writer.close() 317 | 318 | 319 | if __name__ == "__main__": 320 | main() 321 | 
-------------------------------------------------------------------------------- /zs3/train_pascal.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from zs3.dataloaders import make_data_loader 8 | from zs3.modeling.deeplab import DeepLab 9 | from zs3.modeling.sync_batchnorm.replicate import patch_replication_callback 10 | from zs3.dataloaders.datasets import DATASETS_DIRS 11 | from zs3.utils.calculate_weights import calculate_weigths_labels 12 | from zs3.utils.loss import SegmentationLosses 13 | from zs3.utils.lr_scheduler import LR_Scheduler 14 | from zs3.utils.metrics import Evaluator 15 | from zs3.utils.saver import Saver 16 | from zs3.utils.summaries import TensorboardSummary 17 | from zs3.parsing import get_parser 18 | from zs3.exp_data import CLASSES_NAMES 19 | from zs3.base_trainer import BaseTrainer 20 | 21 | 22 | class Trainer(BaseTrainer): 23 | def __init__(self, args): 24 | self.args = args 25 | 26 | # Define Saver 27 | self.saver = Saver(args) 28 | self.saver.save_experiment_config() 29 | # Define Tensorboard Summary 30 | self.summary = TensorboardSummary(self.saver.experiment_dir) 31 | self.writer = self.summary.create_summary() 32 | 33 | # Define Dataloader 34 | kwargs = {"num_workers": args.workers, "pin_memory": True} 35 | (self.train_loader, self.val_loader, _, self.nclass,) = make_data_loader( 36 | args, **kwargs 37 | ) 38 | 39 | # Define network 40 | model = DeepLab( 41 | num_classes=self.nclass, 42 | output_stride=args.out_stride, 43 | sync_bn=args.sync_bn, 44 | freeze_bn=args.freeze_bn, 45 | pretrained=args.imagenet_pretrained, 46 | imagenet_pretrained_path=args.imagenet_pretrained_path, 47 | ) 48 | 49 | train_params = [ 50 | {"params": model.get_1x_lr_params(), "lr": args.lr}, 51 | {"params": model.get_10x_lr_params(), "lr": args.lr * 10}, 52 | ] 53 | 54 | # Define Optimizer 55 | optimizer = torch.optim.SGD( 56 | train_params, 57 | momentum=args.momentum, 58 | weight_decay=args.weight_decay, 59 | nesterov=args.nesterov, 60 | ) 61 | 62 | # Define Criterion 63 | # whether to use class balanced weights 64 | if args.use_balanced_weights: 65 | classes_weights_path = ( 66 | DATASETS_DIRS[args.dataset] / args.dataset + "_classes_weights.npy" 67 | ) 68 | if os.path.isfile(classes_weights_path): 69 | weight = np.load(classes_weights_path) 70 | else: 71 | weight = calculate_weigths_labels( 72 | args.dataset, self.train_loader, self.nclass 73 | ) 74 | weight = torch.from_numpy(weight.astype(np.float32)) 75 | else: 76 | weight = None 77 | self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss( 78 | mode=args.loss_type 79 | ) 80 | self.model, self.optimizer = model, optimizer 81 | 82 | # Define Evaluator 83 | self.evaluator = Evaluator(self.nclass) 84 | # Define lr scheduler 85 | self.scheduler = LR_Scheduler( 86 | args.lr_scheduler, args.lr, args.epochs, len(self.train_loader) 87 | ) 88 | 89 | # Using cuda 90 | if args.cuda: 91 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 92 | patch_replication_callback(self.model) 93 | self.model = self.model.cuda() 94 | 95 | # Resuming checkpoint 96 | self.best_pred = 0.0 97 | if args.resume is not None: 98 | if not os.path.isfile(args.resume): 99 | raise RuntimeError(f"=> no checkpoint found at '{args.resume}'") 100 | checkpoint = torch.load(args.resume) 101 | args.start_epoch = checkpoint["epoch"] 102 | if args.cuda: 103 | 
self.model.module.load_state_dict(checkpoint["state_dict"]) 104 | else: 105 | self.model.load_state_dict(checkpoint["state_dict"]) 106 | if not args.ft: 107 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 108 | self.best_pred = checkpoint["best_pred"] 109 | print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 110 | 111 | # Clear start epoch if fine-tuning 112 | if args.ft: 113 | args.start_epoch = 0 114 | 115 | def validation(self, epoch): 116 | class_names = CLASSES_NAMES[:21] 117 | self.model.eval() 118 | self.evaluator.reset() 119 | tbar = tqdm(self.val_loader, desc="\r") 120 | test_loss = 0.0 121 | for i, sample in enumerate(tbar): 122 | image, target = sample["image"], sample["label"] 123 | if self.args.cuda: 124 | image, target = image.cuda(), target.cuda() 125 | with torch.no_grad(): 126 | output = self.model(image) 127 | loss = self.criterion(output, target) 128 | test_loss += loss.item() 129 | tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1))) 130 | pred = output.data.cpu().numpy() 131 | target = target.cpu().numpy() 132 | pred = np.argmax(pred, axis=1) 133 | # Add batch sample into evaluator 134 | self.evaluator.add_batch(target, pred) 135 | 136 | # Fast test during the training 137 | Acc = self.evaluator.Pixel_Accuracy() 138 | Acc_class, Acc_class_by_class = self.evaluator.Pixel_Accuracy_Class() 139 | mIoU, mIoU_by_class = self.evaluator.Mean_Intersection_over_Union() 140 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 141 | self.writer.add_scalar("val/total_loss_epoch", test_loss, epoch) 142 | self.writer.add_scalar("val/mIoU", mIoU, epoch) 143 | self.writer.add_scalar("val/Acc", Acc, epoch) 144 | self.writer.add_scalar("val/Acc_class", Acc_class, epoch) 145 | self.writer.add_scalar("val/fwIoU", FWIoU, epoch) 146 | print("Validation:") 147 | print( 148 | "[Epoch: %d, numImages: %5d]" 149 | % (epoch, i * self.args.batch_size + image.data.shape[0]) 150 | ) 151 | print(f"Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}") 152 | print(f"Loss: {test_loss:.3f}") 153 | 154 | for i, (class_name, acc_value, mIoU_value) in enumerate( 155 | zip(class_names, Acc_class_by_class, mIoU_by_class) 156 | ): 157 | self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch) 158 | self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch) 159 | print(class_names[i], "- acc:", acc_value, " mIoU:", mIoU_value) 160 | 161 | new_pred = mIoU 162 | is_best = True 163 | self.best_pred = new_pred 164 | self.saver.save_checkpoint( 165 | { 166 | "epoch": epoch + 1, 167 | "state_dict": self.model.module.state_dict(), 168 | "optimizer": self.optimizer.state_dict(), 169 | "best_pred": self.best_pred, 170 | }, 171 | is_best, 172 | ) 173 | 174 | 175 | def main(): 176 | parser = get_parser() 177 | parser.add_argument( 178 | "--imagenet_pretrained", 179 | type=bool, 180 | default=True, 181 | help="imagenet pretrained backbone", 182 | ) 183 | 184 | parser.add_argument( 185 | "--out-stride", type=int, default=16, help="network output stride (default: 8)" 186 | ) 187 | 188 | # PASCAL VOC 189 | parser.add_argument( 190 | "--dataset", 191 | type=str, 192 | default="pascal", 193 | choices=["pascal", "coco", "cityscapes"], 194 | help="dataset name (default: pascal)", 195 | ) 196 | 197 | parser.add_argument( 198 | "--use-sbd", 199 | action="store_true", 200 | default=True, 201 | help="whether to use SBD dataset (default: True)", 202 | ) 203 | parser.add_argument("--base-size", type=int, default=312, help="base image 
size") 204 | parser.add_argument("--crop-size", type=int, default=312, help="crop image size") 205 | parser.add_argument( 206 | "--loss-type", 207 | type=str, 208 | default="ce", 209 | choices=["ce", "focal"], 210 | help="loss func type (default: ce)", 211 | ) 212 | # training hyper params 213 | 214 | # PASCAL VOC 215 | parser.add_argument( 216 | "--epochs", 217 | type=int, 218 | default=200, 219 | metavar="N", 220 | help="number of epochs to train (default: auto)", 221 | ) 222 | 223 | # PASCAL VOC 224 | parser.add_argument( 225 | "--batch-size", 226 | type=int, 227 | default=16, 228 | metavar="N", 229 | help="input batch size for training (default: auto)", 230 | ) 231 | # checking point 232 | parser.add_argument( 233 | "--resume", 234 | type=str, 235 | default=None, 236 | help="put the path to resuming file if needed", 237 | ) 238 | 239 | parser.add_argument( 240 | "--imagenet_pretrained_path", 241 | type=str, 242 | default="checkpoint/resnet_backbone_pretrained_imagenet_wo_pascalvoc.pth.tar", 243 | help="set the checkpoint name", 244 | ) 245 | 246 | parser.add_argument( 247 | "--checkname", 248 | type=str, 249 | default="pascal_2_unseen", 250 | help="set the checkpoint name", 251 | ) 252 | 253 | # evaluation option 254 | parser.add_argument( 255 | "--eval-interval", type=int, default=10, help="evaluation interval (default: 1)" 256 | ) 257 | # only seen classes 258 | # 10 unseen 259 | # parser.add_argument('--unseen_classes_idx', type=int, default=[10, 14, 1, 18, 8, 20, 19, 5, 9, 16]) 260 | # 8 unseen 261 | # parser.add_argument('--unseen_classes_idx', type=int, default=[10, 14, 1, 18, 8, 20, 19, 5]) 262 | # 6 unseen 263 | # parser.add_argument('--unseen_classes_idx', type=int, default=[10, 14, 1, 18, 8, 20]) 264 | # 4 unseen 265 | # parser.add_argument('--unseen_classes_idx', type=int, default=[10, 14, 1, 18]) 266 | # 2 unseen 267 | parser.add_argument("--unseen_classes_idx", type=int, default=[10, 14]) 268 | 269 | args = parser.parse_args() 270 | args.cuda = not args.no_cuda and torch.cuda.is_available() 271 | if args.cuda: 272 | try: 273 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(",")] 274 | except ValueError: 275 | raise ValueError( 276 | "Argument --gpu_ids must be a comma-separated list of integers only" 277 | ) 278 | 279 | args.sync_bn = args.cuda and len(args.gpu_ids) > 1 280 | 281 | # default settings for epochs, batch_size and lr 282 | if args.epochs is None: 283 | epoches = { 284 | "coco": 30, 285 | "cityscapes": 200, 286 | "pascal": 50, 287 | } 288 | args.epochs = epoches[args.dataset.lower()] 289 | 290 | if args.batch_size is None: 291 | args.batch_size = 4 * len(args.gpu_ids) 292 | 293 | if args.test_batch_size is None: 294 | args.test_batch_size = args.batch_size 295 | 296 | if args.lr is None: 297 | lrs = { 298 | "coco": 0.1, 299 | "cityscapes": 0.01, 300 | "pascal": 0.007, 301 | } 302 | args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size 303 | 304 | if args.checkname is None: 305 | args.checkname = "deeplab-resnet" 306 | print(args) 307 | torch.manual_seed(args.seed) 308 | trainer = Trainer(args) 309 | print("Starting Epoch:", trainer.args.start_epoch) 310 | print("Total Epoches:", trainer.args.epochs) 311 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs): 312 | trainer.training(epoch) 313 | if not trainer.args.no_val and epoch % args.eval_interval == ( 314 | args.eval_interval - 1 315 | ): 316 | trainer.validation(epoch) 317 | trainer.writer.close() 318 | 319 | 320 | if __name__ == "__main__": 321 | main() 322 | 
-------------------------------------------------------------------------------- /zs3/eval_context.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from zs3.dataloaders import make_data_loader 8 | from zs3.modeling.deeplab import DeepLab 9 | from zs3.modeling.sync_batchnorm.replicate import patch_replication_callback 10 | from zs3.dataloaders.datasets import DATASETS_DIRS 11 | from zs3.utils.calculate_weights import calculate_weigths_labels 12 | from zs3.utils.loss import SegmentationLosses 13 | from zs3.utils.lr_scheduler import LR_Scheduler 14 | from zs3.utils.metrics import Evaluator 15 | from zs3.utils.saver import Saver 16 | from zs3.utils.summaries import TensorboardSummary 17 | from zs3.parsing import get_parser 18 | from zs3.exp_data import CLASSES_NAMES 19 | 20 | 21 | class Trainer: 22 | def __init__(self, args): 23 | self.args = args 24 | 25 | # Define Saver 26 | self.saver = Saver(args) 27 | self.saver.save_experiment_config() 28 | # Define Tensorboard Summary 29 | self.summary = TensorboardSummary(self.saver.experiment_dir) 30 | self.writer = self.summary.create_summary() 31 | 32 | # Define Dataloader 33 | kwargs = {"num_workers": args.workers, "pin_memory": True} 34 | (self.train_loader, self.val_loader, _, self.nclass,) = make_data_loader( 35 | args, **kwargs 36 | ) 37 | 38 | # Define network 39 | model = DeepLab( 40 | num_classes=self.nclass, 41 | output_stride=args.out_stride, 42 | sync_bn=args.sync_bn, 43 | freeze_bn=args.freeze_bn, 44 | imagenet_pretrained_path=args.imagenet_pretrained_path, 45 | ) 46 | train_params = [ 47 | {"params": model.get_1x_lr_params(), "lr": args.lr}, 48 | {"params": model.get_10x_lr_params(), "lr": args.lr * 10}, 49 | ] 50 | 51 | # Define Optimizer 52 | optimizer = torch.optim.SGD( 53 | train_params, 54 | momentum=args.momentum, 55 | weight_decay=args.weight_decay, 56 | nesterov=args.nesterov, 57 | ) 58 | 59 | # Define Criterion 60 | # whether to use class balanced weights 61 | if args.use_balanced_weights: 62 | classes_weights_path = ( 63 | DATASETS_DIRS[args.dataset] / args.dataset + "_classes_weights.npy" 64 | ) 65 | if os.path.isfile(classes_weights_path): 66 | weight = np.load(classes_weights_path) 67 | else: 68 | weight = calculate_weigths_labels( 69 | args.dataset, self.train_loader, self.nclass 70 | ) 71 | weight = torch.from_numpy(weight.astype(np.float32)) 72 | else: 73 | weight = None 74 | self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss( 75 | mode=args.loss_type 76 | ) 77 | self.model, self.optimizer = model, optimizer 78 | 79 | # Define Evaluator 80 | self.evaluator = Evaluator( 81 | self.nclass, args.seen_classes_idx_metric, args.unseen_classes_idx_metric 82 | ) 83 | 84 | # Define lr scheduler 85 | self.scheduler = LR_Scheduler( 86 | args.lr_scheduler, args.lr, args.epochs, len(self.train_loader) 87 | ) 88 | 89 | # Using cuda 90 | if args.cuda: 91 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 92 | patch_replication_callback(self.model) 93 | self.model = self.model.cuda() 94 | 95 | # Resuming checkpoint 96 | self.best_pred = 0.0 97 | if args.resume is not None: 98 | if not os.path.isfile(args.resume): 99 | raise RuntimeError(f"=> no checkpoint found at '{args.resume}'") 100 | checkpoint = torch.load(args.resume) 101 | args.start_epoch = checkpoint["epoch"] 102 | 103 | if args.random_last_layer: 104 | 
checkpoint["state_dict"]["decoder.pred_conv.weight"] = torch.rand( 105 | ( 106 | self.nclass, 107 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[1], 108 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[2], 109 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[3], 110 | ) 111 | ) 112 | checkpoint["state_dict"]["decoder.pred_conv.bias"] = torch.rand( 113 | self.nclass 114 | ) 115 | 116 | if args.nonlinear_last_layer: 117 | if args.cuda: 118 | self.model.module.deeplab.load_state_dict(checkpoint["state_dict"]) 119 | else: 120 | self.model.deeplab.load_state_dict(checkpoint["state_dict"]) 121 | else: 122 | if args.cuda: 123 | self.model.module.load_state_dict(checkpoint["state_dict"]) 124 | else: 125 | self.model.load_state_dict(checkpoint["state_dict"]) 126 | 127 | if not args.ft: 128 | if not args.nonlinear_last_layer: 129 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 130 | self.best_pred = checkpoint["best_pred"] 131 | print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 132 | 133 | # Clear start epoch if fine-tuning 134 | if args.ft: 135 | args.start_epoch = 0 136 | 137 | def validation(self, epoch, args): 138 | self.model.eval() 139 | self.evaluator.reset() 140 | all_target = [] 141 | all_pred = [] 142 | tbar = tqdm(self.val_loader, desc="\r") 143 | test_loss = 0.0 144 | for i, sample in enumerate(tbar): 145 | image, target = sample["image"], sample["label"] 146 | if self.args.cuda: 147 | image, target = image.cuda(), target.cuda() 148 | with torch.no_grad(): 149 | if args.nonlinear_last_layer: 150 | output = self.model(image, image.size()[2:]) 151 | else: 152 | output = self.model(image) 153 | loss = self.criterion(output, target) 154 | test_loss += loss.item() 155 | tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1))) 156 | pred = output.data.cpu().numpy() 157 | target = target.cpu().numpy() 158 | pred = np.argmax(pred, axis=1) 159 | 160 | # Add batch sample into evaluator 161 | self.evaluator.add_batch(target, pred) 162 | 163 | all_target.append(target) 164 | all_pred.append(pred) 165 | 166 | # Fast test during the training 167 | Acc, Acc_seen, Acc_unseen = self.evaluator.Pixel_Accuracy() 168 | ( 169 | Acc_class, 170 | Acc_class_by_class, 171 | Acc_class_seen, 172 | Acc_class_unseen, 173 | ) = self.evaluator.Pixel_Accuracy_Class() 174 | ( 175 | mIoU, 176 | mIoU_by_class, 177 | mIoU_seen, 178 | mIoU_unseen, 179 | ) = self.evaluator.Mean_Intersection_over_Union() 180 | ( 181 | FWIoU, 182 | FWIoU_seen, 183 | FWIoU_unseen, 184 | ) = self.evaluator.Frequency_Weighted_Intersection_over_Union() 185 | self.writer.add_scalar("val_overall/total_loss_epoch", test_loss, epoch) 186 | self.writer.add_scalar("val_overall/mIoU", mIoU, epoch) 187 | self.writer.add_scalar("val_overall/Acc", Acc, epoch) 188 | self.writer.add_scalar("val_overall/Acc_class", Acc_class, epoch) 189 | self.writer.add_scalar("val_overall/fwIoU", FWIoU, epoch) 190 | 191 | self.writer.add_scalar("val_seen/mIoU", mIoU_seen, epoch) 192 | self.writer.add_scalar("val_seen/Acc", Acc_seen, epoch) 193 | self.writer.add_scalar("val_seen/Acc_class", Acc_class_seen, epoch) 194 | self.writer.add_scalar("val_seen/fwIoU", FWIoU_seen, epoch) 195 | 196 | self.writer.add_scalar("val_unseen/mIoU", mIoU_unseen, epoch) 197 | self.writer.add_scalar("val_unseen/Acc", Acc_unseen, epoch) 198 | self.writer.add_scalar("val_unseen/Acc_class", Acc_class_unseen, epoch) 199 | self.writer.add_scalar("val_unseen/fwIoU", FWIoU_unseen, epoch) 200 | 201 | print("Validation:") 202 
| print( 203 | "[Epoch: %d, numImages: %5d]" 204 | % (epoch, i * self.args.batch_size + image.data.shape[0]) 205 | ) 206 | print(f"Loss: {test_loss:.3f}") 207 | print(f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}") 208 | print( 209 | "Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format( 210 | Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen 211 | ) 212 | ) 213 | print( 214 | "Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format( 215 | Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen 216 | ) 217 | ) 218 | 219 | for class_name, acc_value, mIoU_value in zip( 220 | CLASSES_NAMES, Acc_class_by_class, mIoU_by_class 221 | ): 222 | self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch) 223 | self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch) 224 | print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value) 225 | 226 | 227 | def main(): 228 | parser = get_parser() 229 | parser.add_argument( 230 | "--out-stride", type=int, default=16, help="network output stride (default: 8)" 231 | ) 232 | 233 | # PASCAL VOC 234 | parser.add_argument( 235 | "--dataset", 236 | type=str, 237 | default="context", 238 | choices=["pascal", "coco", "cityscapes"], 239 | help="dataset name (default: pascal)", 240 | ) 241 | 242 | parser.add_argument( 243 | "--use-sbd", 244 | action="store_true", 245 | default=True, 246 | help="whether to use SBD dataset (default: True)", 247 | ) 248 | parser.add_argument("--base-size", type=int, default=513, help="base image size") 249 | parser.add_argument("--crop-size", type=int, default=513, help="crop image size") 250 | parser.add_argument( 251 | "--loss-type", 252 | type=str, 253 | default="ce", 254 | choices=["ce", "focal"], 255 | help="loss func type (default: ce)", 256 | ) 257 | # training hyper params 258 | 259 | # PASCAL VOC 260 | parser.add_argument( 261 | "--epochs", 262 | type=int, 263 | default=300, 264 | metavar="N", 265 | help="number of epochs to train (default: auto)", 266 | ) 267 | 268 | # PASCAL VOC 269 | parser.add_argument( 270 | "--batch-size", 271 | type=int, 272 | default=8, 273 | metavar="N", 274 | help="input batch size for training (default: auto)", 275 | ) 276 | # cuda, seed and logging 277 | parser.add_argument( 278 | "--imagenet_pretrained_path", 279 | type=str, 280 | default="checkpoint/resnet_backbone_pretrained_imagenet_wo_pascalcontext.pth.tar", 281 | ) 282 | 283 | parser.add_argument( 284 | "--resume", 285 | type=str, 286 | default="checkpoint/deeplab_pascal_context_02_unseen_GMMN_final.pth.tar", 287 | help="put the path to resuming file if needed", 288 | ) 289 | 290 | parser.add_argument("--checkname", type=str, default="context_eval") 291 | 292 | # evaluation option 293 | parser.add_argument( 294 | "--eval-interval", type=int, default=5, help="evaluation interval (default: 1)" 295 | ) 296 | 297 | # keep empty 298 | parser.add_argument("--unseen_classes_idx", type=int, default=[]) 299 | 300 | # 2 unseen 301 | unseen_names = ["cow", "motorbike"] 302 | # 4 unseen 303 | # unseen_names = ['cow', 'motorbike', 'sofa', 'cat'] 304 | # 6 unseen 305 | # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence'] 306 | # 8 unseen 307 | # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence', 'bird', 'tvmonitor'] 308 | # 10 unseen 309 | # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence', 'bird', 'tvmonitor', 'aeroplane', 'keyboard'] 310 | 311 | unseen_classes_idx_metric = [] 312 | for name in unseen_names: 313 | 
unseen_classes_idx_metric.append(CLASSES_NAMES.index(name)) 314 | 315 | ### FOR METRIC COMPUTATION IN ORDER TO GET PERFORMANCES FOR TWO SETS 316 | seen_classes_idx_metric = np.arange(60) 317 | 318 | seen_classes_idx_metric = np.delete( 319 | seen_classes_idx_metric, unseen_classes_idx_metric 320 | ).tolist() 321 | parser.add_argument( 322 | "--seen_classes_idx_metric", type=int, default=seen_classes_idx_metric 323 | ) 324 | parser.add_argument( 325 | "--unseen_classes_idx_metric", type=int, default=unseen_classes_idx_metric 326 | ) 327 | 328 | parser.add_argument( 329 | "--nonlinear_last_layer", type=bool, default=False, help="non linear prediction" 330 | ) 331 | parser.add_argument( 332 | "--random_last_layer", type=bool, default=False, help="randomly init last layer" 333 | ) 334 | 335 | args = parser.parse_args() 336 | args.cuda = not args.no_cuda and torch.cuda.is_available() 337 | if args.cuda: 338 | try: 339 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(",")] 340 | except ValueError: 341 | raise ValueError( 342 | "Argument --gpu_ids must be a comma-separated list of integers only" 343 | ) 344 | 345 | args.sync_bn = args.cuda and len(args.gpu_ids) > 1 346 | 347 | # default settings for epochs, batch_size and lr 348 | if args.epochs is None: 349 | epoches = { 350 | "coco": 30, 351 | "cityscapes": 200, 352 | "pascal": 50, 353 | } 354 | args.epochs = epoches[args.dataset.lower()] 355 | 356 | if args.batch_size is None: 357 | args.batch_size = 4 * len(args.gpu_ids) 358 | 359 | if args.test_batch_size is None: 360 | args.test_batch_size = args.batch_size 361 | 362 | if args.lr is None: 363 | lrs = { 364 | "coco": 0.1, 365 | "cityscapes": 0.01, 366 | "pascal": 0.007, 367 | } 368 | args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size 369 | 370 | if args.checkname is None: 371 | args.checkname = "deeplab-resnet" 372 | print(args) 373 | torch.manual_seed(args.seed) 374 | trainer = Trainer(args) 375 | print("Starting Epoch:", trainer.args.start_epoch) 376 | print("Total Epoches:", trainer.args.epochs) 377 | trainer.validation(0, args) 378 | 379 | 380 | if __name__ == "__main__": 381 | main() 382 | -------------------------------------------------------------------------------- /zs3/eval_pascal.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from zs3.dataloaders import make_data_loader 8 | from zs3.modeling.deeplab import DeepLab 9 | from zs3.modeling.sync_batchnorm.replicate import patch_replication_callback 10 | from zs3.dataloaders.datasets import DATASETS_DIRS 11 | from zs3.utils.calculate_weights import calculate_weigths_labels 12 | from zs3.utils.loss import SegmentationLosses 13 | from zs3.utils.lr_scheduler import LR_Scheduler 14 | from zs3.utils.metrics import Evaluator, Evaluator_seen_unseen 15 | from zs3.utils.saver import Saver 16 | from zs3.utils.summaries import TensorboardSummary 17 | from zs3.parsing import get_parser 18 | from zs3.exp_data import CLASSES_NAMES 19 | 20 | 21 | class Trainer: 22 | def __init__(self, args): 23 | self.args = args 24 | 25 | # Define Saver 26 | self.saver = Saver(args) 27 | self.saver.save_experiment_config() 28 | # Define Tensorboard Summary 29 | self.summary = TensorboardSummary(self.saver.experiment_dir) 30 | self.writer = self.summary.create_summary() 31 | 32 | # Define Dataloader 33 | kwargs = {"num_workers": args.workers, "pin_memory": True} 34 | (self.train_loader, self.val_loader, _, 
self.nclass,) = make_data_loader( 35 | args, **kwargs 36 | ) 37 | 38 | model = DeepLab( 39 | num_classes=self.nclass, 40 | output_stride=args.out_stride, 41 | sync_bn=args.sync_bn, 42 | freeze_bn=args.freeze_bn, 43 | imagenet_pretrained_path=args.imagenet_pretrained_path, 44 | ) 45 | train_params = [ 46 | {"params": model.get_1x_lr_params(), "lr": args.lr}, 47 | {"params": model.get_10x_lr_params(), "lr": args.lr * 10}, 48 | ] 49 | 50 | # Define Optimizer 51 | optimizer = torch.optim.SGD( 52 | train_params, 53 | momentum=args.momentum, 54 | weight_decay=args.weight_decay, 55 | nesterov=args.nesterov, 56 | ) 57 | 58 | # Define Criterion 59 | # whether to use class balanced weights 60 | if args.use_balanced_weights: 61 | classes_weights_path = ( 62 | DATASETS_DIRS[args.dataset] / args.dataset + "_classes_weights.npy" 63 | ) 64 | 65 | if os.path.isfile(classes_weights_path): 66 | weight = np.load(classes_weights_path) 67 | else: 68 | weight = calculate_weigths_labels( 69 | args.dataset, self.train_loader, self.nclass 70 | ) 71 | weight = torch.from_numpy(weight.astype(np.float32)) 72 | else: 73 | weight = None 74 | self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss( 75 | mode=args.loss_type 76 | ) 77 | self.model, self.optimizer = model, optimizer 78 | 79 | # Define Evaluator 80 | self.evaluator = Evaluator( 81 | self.nclass, args.seen_classes_idx_metric, args.unseen_classes_idx_metric 82 | ) 83 | self.evaluator_seen_unseen = Evaluator_seen_unseen( 84 | self.nclass, args.unseen_classes_idx_metric 85 | ) 86 | 87 | # Define lr scheduler 88 | self.scheduler = LR_Scheduler( 89 | args.lr_scheduler, args.lr, args.epochs, len(self.train_loader) 90 | ) 91 | 92 | # Using cuda 93 | if args.cuda: 94 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 95 | patch_replication_callback(self.model) 96 | self.model = self.model.cuda() 97 | 98 | # Resuming checkpoint 99 | self.best_pred = 0.0 100 | if args.resume is not None: 101 | if not os.path.isfile(args.resume): 102 | raise RuntimeError(f"=> no checkpoint found at '{args.resume}'") 103 | checkpoint = torch.load(args.resume) 104 | args.start_epoch = checkpoint["epoch"] 105 | 106 | if args.random_last_layer: 107 | checkpoint["state_dict"]["decoder.pred_conv.weight"] = torch.rand( 108 | ( 109 | self.nclass, 110 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[1], 111 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[2], 112 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[3], 113 | ) 114 | ) 115 | checkpoint["state_dict"]["decoder.pred_conv.bias"] = torch.rand( 116 | self.nclass 117 | ) 118 | 119 | if args.cuda: 120 | self.model.module.load_state_dict(checkpoint["state_dict"]) 121 | else: 122 | self.model.load_state_dict(checkpoint["state_dict"]) 123 | 124 | if not args.ft: 125 | if not args.nonlinear_last_layer: 126 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 127 | self.best_pred = checkpoint["best_pred"] 128 | print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 129 | 130 | # Clear start epoch if fine-tuning 131 | if args.ft: 132 | args.start_epoch = 0 133 | 134 | def validation(self, epoch, args): 135 | class_names = CLASSES_NAMES[:21] 136 | self.model.eval() 137 | self.evaluator.reset() 138 | all_target = [] 139 | all_pred = [] 140 | all_pred_unseen = [] 141 | tbar = tqdm(self.val_loader, desc="\r") 142 | test_loss = 0.0 143 | for i, sample in enumerate(tbar): 144 | image, target = sample["image"], sample["label"] 145 | if 
self.args.cuda: 146 | image, target = image.cuda(), target.cuda() 147 | with torch.no_grad(): 148 | if args.nonlinear_last_layer: 149 | output = self.model(image, image.size()[2:]) 150 | else: 151 | output = self.model(image) 152 | loss = self.criterion(output, target) 153 | test_loss += loss.item() 154 | tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1))) 155 | pred = output.data.cpu().numpy() 156 | pred_unseen = pred.copy() 157 | target = target.cpu().numpy() 158 | pred = np.argmax(pred, axis=1) 159 | 160 | pred_unseen[:, args.seen_classes_idx_metric] = float("-inf") 161 | pred_unseen = np.argmax(pred_unseen, axis=1) 162 | 163 | # Add batch sample into evaluator 164 | self.evaluator.add_batch(target, pred) 165 | 166 | all_target.append(target) 167 | all_pred.append(pred) 168 | all_pred_unseen.append(pred_unseen) 169 | 170 | # Fast test during the training 171 | Acc, Acc_seen, Acc_unseen = self.evaluator.Pixel_Accuracy() 172 | ( 173 | Acc_class, 174 | Acc_class_by_class, 175 | Acc_class_seen, 176 | Acc_class_unseen, 177 | ) = self.evaluator.Pixel_Accuracy_Class() 178 | ( 179 | mIoU, 180 | mIoU_by_class, 181 | mIoU_seen, 182 | mIoU_unseen, 183 | ) = self.evaluator.Mean_Intersection_over_Union() 184 | ( 185 | FWIoU, 186 | FWIoU_seen, 187 | FWIoU_unseen, 188 | ) = self.evaluator.Frequency_Weighted_Intersection_over_Union() 189 | self.writer.add_scalar("val_overall/total_loss_epoch", test_loss, epoch) 190 | self.writer.add_scalar("val_overall/mIoU", mIoU, epoch) 191 | self.writer.add_scalar("val_overall/Acc", Acc, epoch) 192 | self.writer.add_scalar("val_overall/Acc_class", Acc_class, epoch) 193 | self.writer.add_scalar("val_overall/fwIoU", FWIoU, epoch) 194 | 195 | self.writer.add_scalar("val_seen/mIoU", mIoU_seen, epoch) 196 | self.writer.add_scalar("val_seen/Acc", Acc_seen, epoch) 197 | self.writer.add_scalar("val_seen/Acc_class", Acc_class_seen, epoch) 198 | self.writer.add_scalar("val_seen/fwIoU", FWIoU_seen, epoch) 199 | 200 | self.writer.add_scalar("val_unseen/mIoU", mIoU_unseen, epoch) 201 | self.writer.add_scalar("val_unseen/Acc", Acc_unseen, epoch) 202 | self.writer.add_scalar("val_unseen/Acc_class", Acc_class_unseen, epoch) 203 | self.writer.add_scalar("val_unseen/fwIoU", FWIoU_unseen, epoch) 204 | 205 | print("Validation:") 206 | print( 207 | "[Epoch: %d, numImages: %5d]" 208 | % (epoch, i * self.args.batch_size + image.data.shape[0]) 209 | ) 210 | print(f"Loss: {test_loss:.3f}") 211 | print(f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}") 212 | print( 213 | "Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format( 214 | Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen 215 | ) 216 | ) 217 | print( 218 | "Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format( 219 | Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen 220 | ) 221 | ) 222 | 223 | for class_name, acc_value, mIoU_value in zip( 224 | class_names, Acc_class_by_class, mIoU_by_class 225 | ): 226 | self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch) 227 | self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch) 228 | print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value) 229 | 230 | 231 | def main(): 232 | parser = get_parser() 233 | parser.add_argument( 234 | "--out-stride", type=int, default=16, help="network output stride (default: 8)" 235 | ) 236 | 237 | # PASCAL VOC 238 | parser.add_argument( 239 | "--dataset", 240 | type=str, 241 | default="pascal", 242 | choices=["pascal", "coco", "cityscapes"], 243 | help="dataset name (default: 
pascal)", 244 | ) 245 | 246 | parser.add_argument( 247 | "--use-sbd", 248 | action="store_true", 249 | default=True, 250 | help="whether to use SBD dataset (default: True)", 251 | ) 252 | parser.add_argument("--base-size", type=int, default=513, help="base image size") 253 | parser.add_argument("--crop-size", type=int, default=513, help="crop image size") 254 | parser.add_argument( 255 | "--loss-type", 256 | type=str, 257 | default="ce", 258 | choices=["ce", "focal"], 259 | help="loss func type (default: ce)", 260 | ) 261 | # training hyper params 262 | 263 | # PASCAL VOC 264 | parser.add_argument( 265 | "--epochs", 266 | type=int, 267 | default=300, 268 | metavar="N", 269 | help="number of epochs to train (default: auto)", 270 | ) 271 | 272 | # PASCAL VOC 273 | parser.add_argument( 274 | "--batch-size", 275 | type=int, 276 | default=8, 277 | metavar="N", 278 | help="input batch size for training (default: auto)", 279 | ) 280 | # cuda, seed and logging 281 | # checking point 282 | parser.add_argument( 283 | "--imagenet_pretrained_path", 284 | type=str, 285 | default="checkpoint/resnet_backbone_pretrained_imagenet_wo_pascalvoc.pth.tar", 286 | ) 287 | # checking point 288 | parser.add_argument( 289 | "--resume", 290 | type=str, 291 | default="checkpoint/deeplab_pascal_voc_02_unseen_GMMN_final.pth.tar", 292 | help="put the path to resuming file if needed", 293 | ) 294 | 295 | parser.add_argument("--checkname", type=str, default="pascal_eval") 296 | 297 | # evaluation option 298 | parser.add_argument( 299 | "--eval-interval", type=int, default=5, help="evaluation interval (default: 1)" 300 | ) 301 | ### FOR IMAGE SELECTION IN ORDER TO TAKE OFF IMAGE WITH UNSEEN CLASSES FOR TRAINING AND VALIDATION 302 | # keep empty 303 | parser.add_argument("--unseen_classes_idx", type=int, default=[]) 304 | 305 | ### FOR METRIC COMPUTATION IN ORDER TO GET PERFORMANCES FOR TWO SETS 306 | seen_classes_idx_metric = np.arange(21) 307 | 308 | # 2 unseen 309 | unseen_classes_idx_metric = [10, 14] 310 | # 4 unseen 311 | # unseen_classes_idx_metric = [10, 14, 1, 18] 312 | # 6 unseen 313 | # unseen_classes_idx_metric = [10, 14, 1, 18, 8, 20] 314 | # 8 unseen 315 | # unseen_classes_idx_metric = [10, 14, 1, 18, 8, 20, 19, 5] 316 | # 10 unseen 317 | # unseen_classes_idx_metric = [10, 14, 1, 18, 8, 20, 19, 5, 9, 16] 318 | 319 | seen_classes_idx_metric = np.delete( 320 | seen_classes_idx_metric, unseen_classes_idx_metric 321 | ).tolist() 322 | parser.add_argument( 323 | "--seen_classes_idx_metric", type=int, default=seen_classes_idx_metric 324 | ) 325 | parser.add_argument( 326 | "--unseen_classes_idx_metric", type=int, default=unseen_classes_idx_metric 327 | ) 328 | 329 | parser.add_argument( 330 | "--nonlinear_last_layer", type=bool, default=False, help="non linear prediction" 331 | ) 332 | parser.add_argument( 333 | "--random_last_layer", type=bool, default=False, help="randomly init last layer" 334 | ) 335 | 336 | args = parser.parse_args() 337 | args.cuda = not args.no_cuda and torch.cuda.is_available() 338 | if args.cuda: 339 | try: 340 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(",")] 341 | except ValueError: 342 | raise ValueError( 343 | "Argument --gpu_ids must be a comma-separated list of integers only" 344 | ) 345 | 346 | args.sync_bn = args.cuda and len(args.gpu_ids) > 1 347 | 348 | # default settings for epochs, batch_size and lr 349 | if args.epochs is None: 350 | epoches = { 351 | "coco": 30, 352 | "cityscapes": 200, 353 | "pascal": 50, 354 | } 355 | args.epochs = epoches[args.dataset.lower()] 356 | 
357 | if args.batch_size is None: 358 | args.batch_size = 4 * len(args.gpu_ids) 359 | 360 | if args.test_batch_size is None: 361 | args.test_batch_size = args.batch_size 362 | 363 | if args.lr is None: 364 | lrs = { 365 | "coco": 0.1, 366 | "cityscapes": 0.01, 367 | "pascal": 0.007, 368 | } 369 | args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size 370 | 371 | if args.checkname is None: 372 | args.checkname = "deeplab-resnet" 373 | print(args) 374 | torch.manual_seed(args.seed) 375 | trainer = Trainer(args) 376 | print("Starting Epoch:", trainer.args.start_epoch) 377 | print("Total Epoches:", trainer.args.epochs) 378 | trainer.validation(0, args) 379 | 380 | 381 | if __name__ == "__main__": 382 | main() 383 | -------------------------------------------------------------------------------- /zs3/train_context_GMMN.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from tqdm import tqdm 7 | 8 | from zs3.dataloaders import make_data_loader 9 | from zs3.modeling.deeplab import DeepLab 10 | from zs3.modeling.gmmn import GMMNnetwork 11 | from zs3.modeling.sync_batchnorm.replicate import patch_replication_callback 12 | from zs3.utils.loss import SegmentationLosses, GMMNLoss 13 | from zs3.utils.lr_scheduler import LR_Scheduler 14 | from zs3.utils.metrics import Evaluator 15 | from zs3.utils.saver import Saver 16 | from zs3.utils.summaries import TensorboardSummary 17 | from zs3.parsing import get_parser 18 | from zs3.exp_data import CLASSES_NAMES 19 | 20 | 21 | class Trainer: 22 | def __init__(self, args): 23 | self.args = args 24 | 25 | # Define Saver 26 | self.saver = Saver(args) 27 | self.saver.save_experiment_config() 28 | # Define Tensorboard Summary 29 | self.summary = TensorboardSummary(self.saver.experiment_dir) 30 | self.writer = self.summary.create_summary() 31 | 32 | # Define Dataloader 33 | kwargs = {"num_workers": args.workers, "pin_memory": True} 34 | (self.train_loader, self.val_loader, _, self.nclass,) = make_data_loader( 35 | args, load_embedding=args.load_embedding, w2c_size=args.w2c_size, **kwargs 36 | ) 37 | 38 | model = DeepLab( 39 | num_classes=self.nclass, 40 | output_stride=args.out_stride, 41 | sync_bn=args.sync_bn, 42 | freeze_bn=args.freeze_bn, 43 | global_avg_pool_bn=args.global_avg_pool_bn, 44 | imagenet_pretrained_path=args.imagenet_pretrained_path, 45 | ) 46 | 47 | train_params = [ 48 | {"params": model.get_1x_lr_params(), "lr": args.lr}, 49 | {"params": model.get_10x_lr_params(), "lr": args.lr * 10}, 50 | ] 51 | 52 | # Define Optimizer 53 | optimizer = torch.optim.SGD( 54 | train_params, 55 | momentum=args.momentum, 56 | weight_decay=args.weight_decay, 57 | nesterov=args.nesterov, 58 | ) 59 | 60 | # Define Generator 61 | generator = GMMNnetwork( 62 | args.noise_dim, args.embed_dim, args.hidden_size, args.feature_dim 63 | ) 64 | optimizer_generator = torch.optim.Adam( 65 | generator.parameters(), lr=args.lr_generator 66 | ) 67 | 68 | class_weight = torch.ones(self.nclass) 69 | class_weight[args.unseen_classes_idx_metric] = args.unseen_weight 70 | if args.cuda: 71 | class_weight = class_weight.cuda() 72 | 73 | self.criterion = SegmentationLosses( 74 | weight=class_weight, cuda=args.cuda 75 | ).build_loss(mode=args.loss_type) 76 | self.model, self.optimizer = model, optimizer 77 | 78 | self.criterion_generator = GMMNLoss( 79 | sigma=[2, 5, 10, 20, 40, 80], cuda=args.cuda 80 | ).build_loss() 81 | self.generator, 
self.optimizer_generator = generator, optimizer_generator 82 | 83 | # Define Evaluator 84 | self.evaluator = Evaluator( 85 | self.nclass, args.seen_classes_idx_metric, args.unseen_classes_idx_metric 86 | ) 87 | 88 | # Define lr scheduler 89 | self.scheduler = LR_Scheduler( 90 | args.lr_scheduler, args.lr, args.epochs, len(self.train_loader) 91 | ) 92 | 93 | # Using cuda 94 | if args.cuda: 95 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 96 | patch_replication_callback(self.model) 97 | self.model = self.model.cuda() 98 | self.generator = self.generator.cuda() 99 | 100 | # Resuming checkpoint 101 | self.best_pred = 0.0 102 | if args.resume is not None: 103 | if not os.path.isfile(args.resume): 104 | raise RuntimeError(f"=> no checkpoint found at '{args.resume}'") 105 | checkpoint = torch.load(args.resume) 106 | # args.start_epoch = checkpoint['epoch'] 107 | 108 | if args.random_last_layer: 109 | checkpoint["state_dict"]["decoder.pred_conv.weight"] = torch.rand( 110 | ( 111 | self.nclass, 112 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[1], 113 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[2], 114 | checkpoint["state_dict"]["decoder.pred_conv.weight"].shape[3], 115 | ) 116 | ) 117 | checkpoint["state_dict"]["decoder.pred_conv.bias"] = torch.rand( 118 | self.nclass 119 | ) 120 | 121 | if args.cuda: 122 | self.model.module.load_state_dict(checkpoint["state_dict"]) 123 | else: 124 | self.model.load_state_dict(checkpoint["state_dict"]) 125 | 126 | # self.best_pred = checkpoint['best_pred'] 127 | print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 128 | 129 | # Clear start epoch if fine-tuning 130 | if args.ft: 131 | args.start_epoch = 0 132 | 133 | def training(self, epoch, args): 134 | train_loss = 0.0 135 | self.model.train() 136 | tbar = tqdm(self.train_loader) 137 | num_img_tr = len(self.train_loader) 138 | for i, sample in enumerate(tbar): 139 | if len(sample["image"]) > 1: 140 | image, target, embedding = ( 141 | sample["image"], 142 | sample["label"], 143 | sample["label_emb"], 144 | ) 145 | if self.args.cuda: 146 | image, target, embedding = ( 147 | image.cuda(), 148 | target.cuda(), 149 | embedding.cuda(), 150 | ) 151 | self.scheduler(self.optimizer, i, epoch, self.best_pred) 152 | # ===================real feature extraction===================== 153 | with torch.no_grad(): 154 | real_features = self.model.module.forward_before_class_prediction( 155 | image 156 | ) 157 | 158 | # ===================fake feature generation===================== 159 | fake_features = torch.zeros(real_features.shape) 160 | if args.cuda: 161 | fake_features = fake_features.cuda() 162 | generator_loss_batch = 0.0 163 | for ( 164 | count_sample_i, 165 | (real_features_i, target_i, embedding_i), 166 | ) in enumerate(zip(real_features, target, embedding)): 167 | generator_loss_sample = 0.0 168 | ## reduce to real feature size 169 | real_features_i = ( 170 | real_features_i.permute(1, 2, 0) 171 | .contiguous() 172 | .view((-1, args.feature_dim)) 173 | ) 174 | target_i = nn.functional.interpolate( 175 | target_i.view(1, 1, target_i.shape[0], target_i.shape[1]), 176 | size=(real_features.shape[2], real_features.shape[3]), 177 | mode="nearest", 178 | ).view(-1) 179 | embedding_i = nn.functional.interpolate( 180 | embedding_i.view( 181 | 1, 182 | embedding_i.shape[0], 183 | embedding_i.shape[1], 184 | embedding_i.shape[2], 185 | ), 186 | size=(real_features.shape[2], real_features.shape[3]), 187 | mode="nearest", 188 | ) 189 | 190 
| embedding_i = ( 191 | embedding_i.permute(0, 2, 3, 1) 192 | .contiguous() 193 | .view((-1, args.embed_dim)) 194 | ) 195 | 196 | fake_features_i = torch.zeros(real_features_i.shape) 197 | if args.cuda: 198 | fake_features_i = fake_features_i.cuda() 199 | 200 | unique_class = torch.unique(target_i) 201 | 202 | ## test if image has unseen class pixel, if yes means no training for generator and generated features for the whole image 203 | has_unseen_class = False 204 | for u_class in unique_class: 205 | if u_class in args.unseen_classes_idx_metric: 206 | has_unseen_class = True 207 | 208 | for idx_in in unique_class: 209 | if idx_in != 255: 210 | self.optimizer_generator.zero_grad() 211 | idx_class = target_i == idx_in 212 | real_features_class = real_features_i[idx_class] 213 | embedding_class = embedding_i[idx_class] 214 | 215 | z = torch.rand((embedding_class.shape[0], args.noise_dim)) 216 | if args.cuda: 217 | z = z.cuda() 218 | 219 | fake_features_class = self.generator( 220 | embedding_class, z.float() 221 | ) 222 | 223 | if ( 224 | idx_in in args.seen_classes_idx_metric 225 | and not has_unseen_class 226 | ): 227 | ## in order to avoid CUDA out of memory 228 | random_idx = torch.randint( 229 | low=0, 230 | high=fake_features_class.shape[0], 231 | size=(args.batch_size_generator,), 232 | ) 233 | g_loss = self.criterion_generator( 234 | fake_features_class[random_idx], 235 | real_features_class[random_idx], 236 | ) 237 | generator_loss_sample += g_loss.item() 238 | g_loss.backward() 239 | self.optimizer_generator.step() 240 | 241 | fake_features_i[idx_class] = fake_features_class.clone() 242 | generator_loss_batch += generator_loss_sample / len(unique_class) 243 | if args.real_seen_features and not has_unseen_class: 244 | fake_features[count_sample_i] = real_features_i.view( 245 | ( 246 | fake_features.shape[2], 247 | fake_features.shape[3], 248 | args.feature_dim, 249 | ) 250 | ).permute(2, 0, 1) 251 | else: 252 | fake_features[count_sample_i] = fake_features_i.view( 253 | ( 254 | fake_features.shape[2], 255 | fake_features.shape[3], 256 | args.feature_dim, 257 | ) 258 | ).permute(2, 0, 1) 259 | # ===================classification===================== 260 | self.optimizer.zero_grad() 261 | output = self.model.module.forward_class_prediction( 262 | fake_features.detach(), image.size()[2:] 263 | ) 264 | loss = self.criterion(output, target) 265 | loss.backward() 266 | self.optimizer.step() 267 | train_loss += loss.item() 268 | # ===================log===================== 269 | tbar.set_description( 270 | f" G loss: {generator_loss_batch:.3f}" 271 | + " C loss: %.3f" % (train_loss / (i + 1)) 272 | ) 273 | self.writer.add_scalar( 274 | "train/total_loss_iter", loss.item(), i + num_img_tr * epoch 275 | ) 276 | self.writer.add_scalar( 277 | "train/generator_loss", generator_loss_batch, i + num_img_tr * epoch 278 | ) 279 | 280 | # Show 10 * 3 inference results each epoch 281 | if i % (num_img_tr // 10) == 0: 282 | global_step = i + num_img_tr * epoch 283 | self.summary.visualize_image( 284 | self.writer, 285 | self.args.dataset, 286 | image, 287 | target, 288 | output, 289 | global_step, 290 | ) 291 | 292 | self.writer.add_scalar("train/total_loss_epoch", train_loss, epoch) 293 | print( 294 | "[Epoch: %d, numImages: %5d]" 295 | % (epoch, i * self.args.batch_size + image.data.shape[0]) 296 | ) 297 | print(f"Loss: {train_loss:.3f}") 298 | 299 | if self.args.no_val: 300 | # save checkpoint every epoch 301 | is_best = False 302 | self.saver.save_checkpoint( 303 | { 304 | "epoch": epoch + 1, 305 
312 |     def validation(self, epoch, args):
313 |         self.model.eval()
314 |         self.evaluator.reset()
315 |         tbar = tqdm(self.val_loader, desc="\r")
316 |         test_loss = 0.0
317 | 
318 |         saved_images = {}
319 |         saved_target = {}
320 |         saved_prediction = {}
321 |         for idx_unseen_class in args.unseen_classes_idx_metric:
322 |             saved_images[idx_unseen_class] = []
323 |             saved_target[idx_unseen_class] = []
324 |             saved_prediction[idx_unseen_class] = []
325 | 
326 |         for i, sample in enumerate(tbar):
327 |             image, target, embedding = (
328 |                 sample["image"],
329 |                 sample["label"],
330 |                 sample["label_emb"],
331 |             )
332 |             if self.args.cuda:
333 |                 image, target = image.cuda(), target.cuda()
334 |             with torch.no_grad():
335 |                 output = self.model(image)
336 |             loss = self.criterion(output, target)
337 |             test_loss += loss.item()
338 |             tbar.set_description("Test loss: %.3f" % (test_loss / (i + 1)))
339 |             ## save image for tensorboard
340 |             for idx_unseen_class in args.unseen_classes_idx_metric:
341 |                 if len((target.reshape(-1) == idx_unseen_class).nonzero()) > 0:
342 |                     if len(saved_images[idx_unseen_class]) < args.saved_validation_images:
343 |                         saved_images[idx_unseen_class].append(image.clone().cpu())
344 |                         saved_target[idx_unseen_class].append(target.clone().cpu())
345 |                         saved_prediction[idx_unseen_class].append(output.clone().cpu())
346 | 
347 |             pred = output.data.cpu().numpy()
348 |             target = target.cpu().numpy()
349 |             pred = np.argmax(pred, axis=1)
350 |             # Add batch sample into evaluator
351 |             self.evaluator.add_batch(target, pred)
352 | 
353 |         # Fast test during the training
354 |         Acc, Acc_seen, Acc_unseen = self.evaluator.Pixel_Accuracy()
355 |         (
356 |             Acc_class,
357 |             Acc_class_by_class,
358 |             Acc_class_seen,
359 |             Acc_class_unseen,
360 |         ) = self.evaluator.Pixel_Accuracy_Class()
361 |         (
362 |             mIoU,
363 |             mIoU_by_class,
364 |             mIoU_seen,
365 |             mIoU_unseen,
366 |         ) = self.evaluator.Mean_Intersection_over_Union()
367 |         (
368 |             FWIoU,
369 |             FWIoU_seen,
370 |             FWIoU_unseen,
371 |         ) = self.evaluator.Frequency_Weighted_Intersection_over_Union()
372 |         self.writer.add_scalar("val_overall/total_loss_epoch", test_loss, epoch)
373 |         self.writer.add_scalar("val_overall/mIoU", mIoU, epoch)
374 |         self.writer.add_scalar("val_overall/Acc", Acc, epoch)
375 |         self.writer.add_scalar("val_overall/Acc_class", Acc_class, epoch)
376 |         self.writer.add_scalar("val_overall/fwIoU", FWIoU, epoch)
377 | 
378 |         self.writer.add_scalar("val_seen/mIoU", mIoU_seen, epoch)
379 |         self.writer.add_scalar("val_seen/Acc", Acc_seen, epoch)
380 |         self.writer.add_scalar("val_seen/Acc_class", Acc_class_seen, epoch)
381 |         self.writer.add_scalar("val_seen/fwIoU", FWIoU_seen, epoch)
382 | 
383 |         self.writer.add_scalar("val_unseen/mIoU", mIoU_unseen, epoch)
384 |         self.writer.add_scalar("val_unseen/Acc", Acc_unseen, epoch)
385 |         self.writer.add_scalar("val_unseen/Acc_class", Acc_class_unseen, epoch)
386 |         self.writer.add_scalar("val_unseen/fwIoU", FWIoU_unseen, epoch)
387 | 
388 |         print("Validation:")
389 |         print(
390 |             "[Epoch: %d, numImages: %5d]"
391 |             % (epoch, i * self.args.batch_size + image.data.shape[0])
392 |         )
393 |         print(f"Loss: {test_loss:.3f}")
394 |         print(f"Overall: Acc:{Acc}, Acc_class:{Acc_class}, mIoU:{mIoU}, fwIoU: {FWIoU}")
395 |         print(
396 |             "Seen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
397 |                 Acc_seen, Acc_class_seen, mIoU_seen, FWIoU_seen
398 |             )
399 |         )
400 |         print(
401 |             "Unseen: Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
402 |                 Acc_unseen, Acc_class_unseen, mIoU_unseen, FWIoU_unseen
403 |             )
404 |         )
405 | 
406 |         for class_name, acc_value, mIoU_value in zip(
407 |             CLASSES_NAMES, Acc_class_by_class, mIoU_by_class
408 |         ):
409 |             self.writer.add_scalar("Acc_by_class/" + class_name, acc_value, epoch)
410 |             self.writer.add_scalar("mIoU_by_class/" + class_name, mIoU_value, epoch)
411 |             print(class_name, "- acc:", acc_value, " mIoU:", mIoU_value)
412 | 
413 |         new_pred = mIoU_unseen
414 | 
415 |         is_best = True
416 |         self.best_pred = new_pred
417 |         self.saver.save_checkpoint(
418 |             {
419 |                 "epoch": epoch + 1,
420 |                 "state_dict": self.model.module.state_dict(),
421 |                 "optimizer": self.optimizer.state_dict(),
422 |                 "best_pred": self.best_pred,
423 |             },
424 |             is_best,
425 |             generator_state={
426 |                 "epoch": epoch + 1,
427 |                 "state_dict": self.generator.state_dict(),
428 |                 "optimizer": self.optimizer.state_dict(),
429 |                 "best_pred": self.best_pred,
430 |             },
431 |         )
432 | 
433 |         global_step = epoch + 1
434 |         for idx_unseen_class in args.unseen_classes_idx_metric:
435 |             if len(saved_images[idx_unseen_class]) > 0:
436 |                 nb_image = len(saved_images[idx_unseen_class])
437 |                 if nb_image > args.saved_validation_images:
438 |                     nb_image = args.saved_validation_images
439 |                 for i in range(nb_image):
440 |                     self.summary.visualize_image_validation(
441 |                         self.writer,
442 |                         self.args.dataset,
443 |                         saved_images[idx_unseen_class][i],
444 |                         saved_target[idx_unseen_class][i],
445 |                         saved_prediction[idx_unseen_class][i],
446 |                         global_step,
447 |                         name="validation_"
448 |                         + CLASSES_NAMES[idx_unseen_class]
449 |                         + "_"
450 |                         + str(i),
451 |                         nb_image=1,
452 |                     )
453 | 
454 |         self.evaluator.reset()
455 | 
456 | 
457 | def main():
458 |     parser = get_parser()
459 |     parser.add_argument(
460 |         "--out-stride", type=int, default=16, help="network output stride (default: 16)"
461 |     )
462 | 
463 |     # PASCAL Context
464 |     parser.add_argument(
465 |         "--dataset",
466 |         type=str,
467 |         default="context",
468 |         choices=["pascal", "context", "coco", "cityscapes"],
469 |         help="dataset name (default: context)",
470 |     )
471 | 
472 |     parser.add_argument(
473 |         "--use-sbd",
474 |         action="store_true",
475 |         default=True,
476 |         help="whether to use SBD dataset (default: True)",
477 |     )
478 |     parser.add_argument("--base-size", type=int, default=312, help="base image size")
479 |     parser.add_argument("--crop-size", type=int, default=312, help="crop image size")
480 |     parser.add_argument(
481 |         "--loss-type",
482 |         type=str,
483 |         default="ce",
484 |         choices=["ce", "focal"],
485 |         help="loss func type (default: ce)",
486 |     )
487 |     # training hyperparameters
488 | 
489 |     # PASCAL Context
490 |     parser.add_argument(
491 |         "--epochs",
492 |         type=int,
493 |         default=20,
494 |         metavar="N",
495 |         help="number of epochs to train (default: 20)",
496 |     )
497 | 
498 |     # PASCAL Context
499 |     parser.add_argument(
500 |         "--batch-size",
501 |         type=int,
502 |         default=8,
503 |         metavar="N",
504 |         help="input batch size for training (default: 8)",
505 |     )
506 |     # checkpoints
507 | 
508 |     parser.add_argument(
509 |         "--imagenet_pretrained_path",
510 |         type=str,
511 |         default="checkpoint/resnet_backbone_pretrained_imagenet_wo_pascalcontext.pth.tar",
512 |     )
513 | 
514 |     parser.add_argument(
515 |         "--resume",
516 |         type=str,
517 |         default="checkpoint/deeplab_pretrained_pascal_context_02_unseen.pth.tar",
518 |         help="put the path to resuming file if needed",
519 |     )
520 | 
521 |     parser.add_argument(
522 |         "--checkname",
523 |         type=str,
524 |         default="gmmn_context_w2c300_linear_weighted100_hs256_2_unseen",
525 |     )
526 | 
527 |     # set to False when resuming from an embedding checkpoint
528 |     parser.add_argument("--global_avg_pool_bn", type=bool, default=True)
529 | 
530 |     # evaluation option
531 |     parser.add_argument(
532 |         "--eval-interval", type=int, default=1, help="evaluation interval (default: 1)"
533 |     )
534 | 
535 |     # keep empty
536 |     parser.add_argument("--unseen_classes_idx", type=int, default=[])
537 | 
538 |     # 2 unseen
539 |     unseen_names = ["cow", "motorbike"]
540 |     # 4 unseen
541 |     # unseen_names = ['cow', 'motorbike', 'sofa', 'cat']
542 |     # 6 unseen
543 |     # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence']
544 |     # 8 unseen
545 |     # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence', 'bird', 'tvmonitor']
546 |     # 10 unseen
547 |     # unseen_names = ['cow', 'motorbike', 'sofa', 'cat', 'boat', 'fence', 'bird', 'tvmonitor', 'aeroplane', 'keyboard']
548 | 
549 |     unseen_classes_idx_metric = []
550 |     for name in unseen_names:
551 |         unseen_classes_idx_metric.append(CLASSES_NAMES.index(name))
552 | 
553 |     ### seen/unseen class index sets used to report metrics separately for the two sets
554 |     seen_classes_idx_metric = np.arange(60)
555 | 
556 |     seen_classes_idx_metric = np.delete(
557 |         seen_classes_idx_metric, unseen_classes_idx_metric
558 |     ).tolist()
559 |     parser.add_argument(
560 |         "--seen_classes_idx_metric", type=int, default=seen_classes_idx_metric
561 |     )
562 |     parser.add_argument(
563 |         "--unseen_classes_idx_metric", type=int, default=unseen_classes_idx_metric
564 |     )
565 | 
566 |     parser.add_argument(
567 |         "--unseen_weight", type=int, default=100, help="loss weight for unseen classes"
568 |     )
569 | 
570 |     parser.add_argument(
571 |         "--random_last_layer", type=bool, default=True, help="randomly init last layer"
572 |     )
573 | 
574 |     parser.add_argument(
575 |         "--real_seen_features",
576 |         type=bool,
577 |         default=True,
578 |         help="use real (not generated) features for images without unseen classes",
579 |     )
580 |     parser.add_argument(
581 |         "--load_embedding",
582 |         type=str,
583 |         default="my_w2c",
584 |         choices=["attributes", "w2c", "w2c_bg", "my_w2c", "fusion", None],
585 |     )
586 |     parser.add_argument("--w2c_size", type=int, default=300)
587 | 
588 |     ### GENERATOR ARGS
589 |     parser.add_argument("--noise_dim", type=int, default=300)
590 |     parser.add_argument("--embed_dim", type=int, default=300)
591 |     parser.add_argument("--hidden_size", type=int, default=256)
592 |     parser.add_argument("--feature_dim", type=int, default=256)
593 |     parser.add_argument("--lr_generator", type=float, default=0.0002)
594 |     parser.add_argument("--batch_size_generator", type=int, default=128)
595 |     parser.add_argument("--saved_validation_images", type=int, default=10)
596 | 
597 |     args = parser.parse_args()
598 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
599 |     if args.cuda:
600 |         try:
601 |             args.gpu_ids = [int(s) for s in args.gpu_ids.split(",")]
602 |         except ValueError:
603 |             raise ValueError(
604 |                 "Argument --gpu_ids must be a comma-separated list of integers only"
605 |             )
606 | 
607 |     args.sync_bn = args.cuda and len(args.gpu_ids) > 1
608 | 
609 |     # default settings for epochs, batch_size and lr
610 |     if args.epochs is None:
611 |         epoches = {
612 |             "coco": 30,
613 |             "cityscapes": 200,
614 |             "pascal": 50,
615 |         }
616 |         args.epochs = epoches[args.dataset.lower()]
617 | 
618 |     if args.batch_size is None:
619 |         args.batch_size = 4 * len(args.gpu_ids)
620 | 
621 |     if args.test_batch_size is None:
622 |         args.test_batch_size = args.batch_size
623 | 
624 |     if args.lr is None:
625 |         lrs = {
626 |             "coco": 0.1,
627 |             "cityscapes": 0.01,
628 |             "pascal": 0.007,
629 |         }
630 |         args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size
631 | 
632 |     if args.checkname is None:
633 |         args.checkname = "deeplab-resnet"
634 |     print(args)
635 |     torch.manual_seed(args.seed)
636 |     trainer = Trainer(args)
637 |     print("Starting Epoch:", trainer.args.start_epoch)
638 |     print("Total Epochs:", trainer.args.epochs)
639 |     for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
640 | 
641 |         trainer.training(epoch, args)
642 |         if not trainer.args.no_val and epoch % args.eval_interval == (
643 |             args.eval_interval - 1
644 |         ):
645 |             trainer.validation(epoch, args)
646 | 
647 |     trainer.writer.close()
648 | 
649 | 
650 | if __name__ == "__main__":
651 |     main()
652 | 
--------------------------------------------------------------------------------