├── .gitignore ├── BYOL_gain.png ├── README.md ├── byol.png ├── byol_gain.png ├── config └── defaults.py ├── environment.yml ├── scripts ├── create_env.sh ├── test.sh ├── train.sh └── unsup_train.sh ├── sem_seg ├── __init__.py ├── datasets │ ├── __init__.py │ ├── augmentations.py │ ├── dataloader.py │ ├── dataset.py │ ├── datautils.py │ └── transforms.py ├── models │ ├── __init__.py │ ├── base.py │ ├── eval_metrices.py │ ├── segmentation_model.py │ └── self_supervised_model.py └── networks │ ├── __init__.py │ ├── byol.py │ ├── fcn.py │ ├── network_utils.py │ └── resnet.py └── training ├── __init__.py ├── debugger.py ├── train.py ├── train_byol.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | myvenv 2 | .idea 3 | __pycache__ 4 | .DS_Store -------------------------------------------------------------------------------- /BYOL_gain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/BYOL_gain.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-supervised Semantic Segmentation 2 | This repository contains semantic segmentation models trained in a 3 | self-supervised manner. In a two-step training process, the backbone 4 | of the segmentation network (say, an FCN8s network) is first trained 5 | in a self-supervised way and then the whole network is fine-tuned 6 | with a few semantic segmentation annotations. 7 | 8 | ![BYOL](https://github.com/nilesh0109/self-supervised-sem-seg/blob/master/byol.png) 9 | 10 | Read more about the BYOL architecture in the blog post: https://nilesh0109.medium.com/hands-on-review-byol-bootstrap-your-own-latent-67e4c5744e1b
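The two training steps map onto this repository's modules roughly as follows (a minimal sketch, assuming the `sem_seg` package layout shown in the tree above; the optimizers, label percentage and epoch counts are illustrative — the actual runs go through `scripts/unsup_train.sh` and `scripts/train.sh`):

```python
import torch
from sem_seg.networks import Resnet50, BYOL, FCN8s
from sem_seg.models import SelfSupervisedModel, SegmentationModel
from sem_seg.datasets.dataloader import CityscapesLoader
from sem_seg.datasets.datautils import CityscapesUtils

# Step 1: BYOL pretraining of the backbone on unlabelled images
backbone = Resnet50(pretrained=False)
byol = BYOL(backbone, target_momentum=0.996)
ssl_loader = CityscapesLoader().get_cityscapes_loader(mode='self-supervised')
ssl_model = SelfSupervisedModel(byol, ssl_loader, torch.optim.Adam(byol.parameters()))
ssl_model.train(num_epochs=700)
ssl_model.save_weights()

# Step 2: fine-tune the full segmentation network on the labelled subset
seg_loader = CityscapesLoader(label_percent='10%').get_cityscapes_loader(mode='supervised')
fcn8s = FCN8s(backbone, num_classes=CityscapesUtils().num_classes)
seg_model = SegmentationModel(fcn8s, seg_loader,
                              torch.optim.SGD(fcn8s.parameters(), lr=1e-2, momentum=0.9))
seg_model.train(num_epochs=400)
```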
11 | The baseline segmentation network used is FCN8s, and the available backbones are Resnet50, Resnet101, and DRN50. 12 | 13 | To reproduce the experiments, follow the steps below. 14 | 15 | # How To Run The Code 16 | 1. Create the virtual environment with all dependencies and activate it. 17 | 18 | ``` bash scripts/create_env.sh``` 19 | 20 | 2. Train the self-supervised model first using unlabelled data. 21 | 22 | ``` bash scripts/unsup_train.sh ``` 23 | 24 | 3. Train the supervised model end-to-end using only the labelled 25 | data (the hand-over of the BYOL-pretrained weights is sketched after the references below). 26 | 27 | ``` bash scripts/train.sh``` 28 | 29 | # Results: 30 | A Resnet50-based FCN8s semantic segmentation network gets a 10% improvement in mIOU on the Cityscapes dataset if it is first pretrained in the BYOL manner using 20k unlabelled images and then fine-tuned with 5k labelled Cityscapes images. 31 | 32 | The x-axis in the graph below shows the percentage of labels (out of 5k images) used in the fine-tuning step. The top dotted line in the graph is Resnet50 initialized from ImageNet weights and the bottom dotted line is Resnet50 trained from random weights (no pretraining). The middle plot is Resnet50 pretrained using BYOL. The red arrows clearly indicate the success of BYOL for visual representation learning. 33 | 34 | ![BYOL gain](https://github.com/nilesh0109/self-supervised-sem-seg/blob/master/byol_gain.png) 35 | 36 | 37 | # Code Walkthrough 38 | ``` 39 | config 40 | - defaults.py: Configuration file for setting the defaults 41 | scripts 42 | - create_env.sh: script file for creating the virtual environment with dependencies listed in environment.yml 43 | - test.sh: script file for debugging the inputs & outputs to the FCN8s and BYOL networks. 44 | - train.sh: script for training the segmentation network 45 | - unsup_train.sh: script for training the unsupervised learning (BYOL) network 46 | sem_seg 47 | datasets 48 | - augmentations.py: Utility file for various augmentations 49 | - dataloader.py: Script for building the Cityscapes dataloader 50 | - dataset.py: Custom Cityscapes dataset for applying the same set of augmentations to images and masks 51 | - datautils.py: Script for Cityscapes label formatting 52 | - transforms.py: Custom transformation scripts 53 | models 54 | - base.py: Base class for all models with the training and evaluation pipeline 55 | - eval_metrices.py: Script for different evaluation metrics such as mIOU, accuracy, etc. 56 | - segmentation_model.py: Class for the semantic segmentation model, inherited from the base model 57 | - self_supervised_model.py: Self-supervised model (BYOL) training pipeline 58 | networks 59 | - byol.py: BYOL network setup 60 | - fcn.py: Various FCN (Fully Convolutional Network) setups 61 | - network_utils.py: Utility script for various network configurations 62 | - resnet.py: Various residual network (Resnet) setups 63 | training 64 | - debugger.py: Debugger script for the FCN and BYOL models 65 | - train.py: Training file for the FCN model 66 | - train_byol.py: Training file for the BYOL model 67 | - utils.py: Utility file for plotting various inputs and outputs 68 | - environment.yml: List of virtual environment dependencies 69 | ``` 70 | # References 71 | 1. Grill, Jean-Bastien, et al. "Bootstrap your own latent: A new approach to self-supervised learning." arXiv preprint arXiv:2006.07733 (2020).
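For reference, the weight hand-over between the two steps (what `training/train.py` does when `load_from_byol` is set in the config) boils down to copying the BYOL online network's parameters into the segmentation backbone. A minimal sketch, with an illustrative checkpoint path:

```python
import torch

# `fcn8s` is assumed to be an FCN8s instance built on the same backbone
# architecture that was pretrained with BYOL in step 2.
byol_state = torch.load('sem_seg/weights/SelfSupervisedModel_BYOL_resnet50.h5')  # illustrative path
backbone_state = {key.replace('online_network.', ''): value
                  for key, value in byol_state.items() if 'online_network' in key}
fcn8s.backbone.load_state_dict(backbone_state)
```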
72 | -------------------------------------------------------------------------------- /byol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/byol.png -------------------------------------------------------------------------------- /byol_gain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/byol_gain.png -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | CITYSCAPES_PATH = './cityscapes_data/' 2 | CITYSCAPES_IGNORED_INDEX = 19 3 | 4 | #Training Config 5 | NUM_EPOCHS = 15 6 | BATCH_SIZE = 8 7 | NUM_WORKERS = 0 #Parallel workers to prepare dataloader from dataset 8 | 9 | #For Matplotlib 10 | MATPLOTLIB_NO_GUI = True -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: myvenv 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - absl-py=0.11.0=pyhd3eb1b0_1 7 | - aiohttp=3.7.3=py37h9ed2024_1 8 | - async-timeout=3.0.1=py37_0 9 | - attrs=20.3.0=pyhd3eb1b0_0 10 | - blas=1.0=mkl 11 | - blinker=1.4=py37_0 12 | - brotlipy=0.7.0=py37h9ed2024_1003 13 | - c-ares=1.17.1=h9ed2024_0 14 | - ca-certificates=2020.12.8=hecd8cb5_0 15 | - cachetools=4.2.0=pyhd3eb1b0_0 16 | - certifi=2020.12.5=py37hecd8cb5_0 17 | - cffi=1.14.4=py37h2125817_0 18 | - chardet=3.0.4=py37hecd8cb5_1003 19 | - click=7.1.2=py_0 20 | - cryptography=2.9.2=py37hbcfaee0_0 21 | - cycler=0.10.0=py37_0 22 | - freetype=2.10.4=ha233b18_0 23 | - google-auth=1.24.0=pyhd3eb1b0_0 24 | - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2 25 | - grpcio=1.31.0=py37h7580e61_0 26 | - idna=2.10=py_0 27 | - importlib-metadata=2.0.0=py_1 28 | - intel-openmp=2019.4=233 29 | - jpeg=9b=he5867d9_2 30 | - kiwisolver=1.3.0=py37h23ab428_0 31 | - lcms2=2.11=h92f6f08_0 32 | - libcxx=10.0.0=1 33 | - libedit=3.1.20191231=h1de35cc_1 34 | - libffi=3.3=hb1e8313_2 35 | - libpng=1.6.37=ha441bb4_0 36 | - libprotobuf=3.13.0.1=hab81aa3_0 37 | - libtiff=4.1.0=hcb84e12_1 38 | - libuv=1.40.0=haf1e3a3_0 39 | - lz4-c=1.9.2=h79c402e_3 40 | - markdown=3.3.3=py37hecd8cb5_0 41 | - matplotlib=3.3.2=hecd8cb5_0 42 | - matplotlib-base=3.3.2=py37h181983e_0 43 | - mkl=2019.4=233 44 | - mkl-service=2.3.0=py37h9ed2024_0 45 | - mkl_fft=1.2.0=py37hc64f4ea_0 46 | - mkl_random=1.1.1=py37h959d312_0 47 | - multidict=4.7.6=py37haf1e3a3_1 48 | - ncurses=6.2=h0a44026_1 49 | - ninja=1.10.2=py37hf7b0b51_0 50 | - numpy=1.19.2=py37h456fd55_0 51 | - numpy-base=1.19.2=py37hcfb5961_0 52 | - oauthlib=3.1.0=py_0 53 | - olefile=0.46=py37_0 54 | - openssl=1.1.1i=h9ed2024_0 55 | - pillow=8.0.1=py37h5270095_0 56 | - pip=20.3.3=py37hecd8cb5_0 57 | - protobuf=3.13.0.1=py37hb1e8313_1 58 | - pyasn1=0.4.8=py_0 59 | - pyasn1-modules=0.2.8=py_0 60 | - pycparser=2.20=py_2 61 | - pyjwt=2.0.0=py37hecd8cb5_0 62 | - pyopenssl=20.0.1=pyhd3eb1b0_1 63 | - pyparsing=2.4.7=py_0 64 | - pysocks=1.7.1=py37hecd8cb5_0 65 | - python=3.7.9=h26836e1_0 66 | - python-dateutil=2.8.1=py_0 67 | - pytorch=1.7.1=py3.7_0 68 | - readline=8.0=h1de35cc_0 69 | - requests=2.25.1=pyhd3eb1b0_0 70 | - requests-oauthlib=1.3.0=py_0 71 | - rsa=4.6=py_0 72 | - setuptools=51.0.0=py37hecd8cb5_2 73 | - 
six=1.15.0=py37hecd8cb5_0 74 | - sqlite=3.33.0=hffcf06c_0 75 | - tensorboard=2.3.0=pyh4dce500_0 76 | - tensorboard-plugin-wit=1.6.0=py_0 77 | - tk=8.6.10=hb0a8c7a_0 78 | - torchvision=0.8.2=py37_cpu 79 | - tornado=6.1=py37h9ed2024_0 80 | - tqdm=4.54.1=pyhd3eb1b0_0 81 | - typing-extensions=3.7.4.3=0 82 | - typing_extensions=3.7.4.3=py_0 83 | - urllib3=1.26.2=pyhd3eb1b0_0 84 | - werkzeug=1.0.1=py_0 85 | - wheel=0.36.2=pyhd3eb1b0_0 86 | - xz=5.2.5=h1de35cc_0 87 | - yarl=1.5.1=py37haf1e3a3_0 88 | - zipp=3.4.0=pyhd3eb1b0_0 89 | - zlib=1.2.11=h1de35cc_3 90 | - zstd=1.4.5=h41d2c2f_0 91 | prefix: /Users/nileshvijayrania/anaconda3/envs/myvenv 92 | -------------------------------------------------------------------------------- /scripts/create_env.sh: -------------------------------------------------------------------------------- 1 | conda env create -n bs_venv -f environment.yml 2 | conda activate bs_venv 3 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | python -m training.debugger -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #For supervised Training 2 | config='{"dataset": "cityscapes", "model": "SegmentationModel", "network": "FCN8s", 3 | "network_args":{"backbone":"Resnet50", "pretrained": false, "load_from_byol": false, "freeze_backbone": false}, 4 | "train_args":{"batch_size": 8, "epochs": 400, "log_to_tensorboard": false}, 5 | "experiment_group":{} 6 | }' 7 | 8 | python -m training.train --save "$config" -------------------------------------------------------------------------------- /scripts/unsup_train.sh: -------------------------------------------------------------------------------- 1 | # for unsupervised Training 2 | 3 | config='{"dataset": "cityscapes", "model": "SelfSupervisedModel", "network": "BYOL", "mode": "self-supervised", 4 | "network_args":{"backbone":"Resnet50", "pretrained": false, "target_momentum": 0.996}, 5 | "train_args":{"batch_size": 32, "epochs": 700, "log_to_tensorboard": false}, 6 | "experiment_group":{} 7 | }' 8 | 9 | python -m training.train_byol --save "$config" 10 | 11 | #Use for distributed Training 12 | #python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 -m training.train_byol --save "$config" -------------------------------------------------------------------------------- /sem_seg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/sem_seg/__init__.py -------------------------------------------------------------------------------- /sem_seg/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['augmentations', 'datautils', 'dataloader', 'dataset', 'transforms'] 2 | -------------------------------------------------------------------------------- /sem_seg/datasets/augmentations.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import numpy as np 3 | import torch 4 | from torchvision import transforms as tfms 5 | from .datautils import CityscapesUtils 6 | from .transforms import RandomResize 7 | from PIL import Image 8 | 9 | 10 | def get_tfms() -> Tuple: 11 | """ 12 | :return: Image 
and mask transforms tuple for cityscapes fine annotated images 13 | """ 14 | base_size = 800 15 | min_size, max_size = int(0.5*base_size), int(2.0*base_size) 16 | img_transforms = { 17 | 'train': tfms.Compose([ 18 | RandomResize(min_size, max_size), 19 | tfms.RandomCrop(size=768, pad_if_needed=True, fill=0), 20 | tfms.RandomHorizontalFlip(p=0.5), 21 | tfms.ToTensor() 22 | ]), 23 | 'val': tfms.Compose([ 24 | tfms.ToTensor() 25 | ]) 26 | } 27 | 28 | target_transforms = { 29 | 'train': tfms.Compose([ 30 | RandomResize(min_size, max_size, is_mask=True), 31 | tfms.RandomCrop(size=768, pad_if_needed=True, fill=0), 32 | tfms.RandomHorizontalFlip(p=0.5), 33 | mapId2TrainID 34 | ]), 35 | 'val': tfms.Compose([mapId2TrainID]) 36 | } 37 | return img_transforms, target_transforms 38 | 39 | cityscapes_utils = CityscapesUtils() 40 | 41 | def mapId2TrainID(mask: Image) -> torch.Tensor: 42 | """ 43 | Redcudes the 34 labels present in gt_fine to 19 labels. Ignoring the ignore_in_eval labels. 44 | :param mask: Cityscapes mask(PIL Image type) with 34 classes 45 | :return: mask with 19 cityscapes classes of tensor type 46 | """ 47 | return torch.from_numpy(cityscapes_utils.id2train_id[np.array(mask)]).long() 48 | 49 | 50 | def get_spatial_tfms() -> tfms.Compose: 51 | """ Get the transformation which changes the spatial position of the object in the image""" 52 | return tfms.Compose([ 53 | tfms.RandomResizedCrop(size=512, scale=(0.3, 1)), 54 | tfms.RandomHorizontalFlip() 55 | ]) 56 | 57 | 58 | def get_pixelwise_tfms() -> List[tfms.Compose]: 59 | """Get the transformations which only introduces local perturbation to the image.""" 60 | tfms_list = [ 61 | tfms.RandomApply([ 62 | tfms.ColorJitter(0.4, 0.4, 0.4, 0.1) 63 | ], p=0.8), 64 | tfms.RandomGrayscale(p=0.2) 65 | ] 66 | return tfms_list 67 | 68 | 69 | def get_self_supervised_tfms() -> tfms.Compose: 70 | """ Returns the list of transformations to be used to generating self-supervised training image pairs""" 71 | spatial_tfms = get_spatial_tfms() 72 | pixelwise_tfms = get_pixelwise_tfms() 73 | 74 | self_supervised_tfms = tfms.Compose([ 75 | *pixelwise_tfms, 76 | tfms.ToTensor(), 77 | spatial_tfms, 78 | tfms.Lambda(lambda img: img.squeeze()) 79 | ]) 80 | return self_supervised_tfms 81 | -------------------------------------------------------------------------------- /sem_seg/datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import os 3 | import torch 4 | from torchvision import datasets 5 | from torch.utils.data.sampler import SequentialSampler, RandomSampler 6 | from config.defaults import CITYSCAPES_PATH, BATCH_SIZE, NUM_WORKERS 7 | from .dataset import CityscapesDataset 8 | from .augmentations import get_tfms, get_self_supervised_tfms 9 | 10 | 11 | class CityscapesLoader: 12 | """Prepares the cityscapes DataLoader""" 13 | 14 | def __init__(self, label_percent: str = '100%') -> None: 15 | """ 16 | :param label_percent: Percentage of cityscpaes labels to be used. 
Must be suffixed with % 17 | """ 18 | label_percent = min(float(label_percent[:-1]), 100.0) 19 | self._prepare_dataset(label_percent) 20 | 21 | def _prepare_dataset(self, label_percent: int = 100) -> None: 22 | """ 23 | Constructs all 3 datasets(train, test, val) using CityScapesDataset class 24 | :param label_percent: percentage of cityscapes labels to use 25 | :return: None 26 | """ 27 | 28 | data_fine = {phase: datasets.Cityscapes(CITYSCAPES_PATH, split=phase, mode='fine', 29 | target_type='semantic') for phase in ['train', 'test', 'val']} 30 | if 'gtCoarse' in os.listdir(CITYSCAPES_PATH): 31 | data_coarse = {phase: datasets.Cityscapes(CITYSCAPES_PATH, split=phase, mode='coarse', 32 | target_type='semantic') for phase in ['train_extra', 'val']} 33 | else: 34 | data_coarse = { 35 | 'train_extra': [], 36 | 'val': data_fine['val'] 37 | } 38 | 39 | img_tfms, target_tfms = get_tfms() 40 | self_supervised_tfms = get_self_supervised_tfms() 41 | 42 | self.cityscapes = {phase: CityscapesDataset(data_fine[phase], 43 | label_percent=label_percent if phase == 'train' else 100, 44 | transform=img_tfms[phase], 45 | target_transform=target_tfms[phase]) 46 | for phase in ['train', 'val']} 47 | 48 | #self.cityscapes['train'].imgs_path = self.cityscapes['train'].imgs_path[:30] 49 | #self.cityscapes['val'].imgs_path = self.cityscapes['val'].imgs_path[:30] 50 | 51 | self.cityscapes['self-supervised'] = { 52 | phase: CityscapesDataset([data_fine['train'], data_fine['test']] 53 | if phase == 'train' else data_coarse['val'], 54 | label_percent=label_percent if phase == 'train' else 100, 55 | transform=self_supervised_tfms, 56 | mode='self-supervised') 57 | for phase in ['train', 'val'] 58 | } 59 | 60 | def get_cityscapes_loader(self, batch_size: int = BATCH_SIZE, num_workers: int = NUM_WORKERS, 61 | mode: str = 'supervised') -> Dict[str, torch.utils.data.DataLoader]: 62 | """ 63 | Returns Cityscapes dataloader with correct sampler 64 | :param batch_size: number of batches 65 | :param num_workers: number of parallel workers for dataloading on CPU. 
66 | :param mode: 'supervised' or 'self-supervised' type of dataloader 67 | :return: Dict of train and val data loader 68 | """ 69 | 70 | data = self.cityscapes if mode == 'supervised' else self.cityscapes['self-supervised'] 71 | 72 | cityscapes_loader = {x: torch.utils.data.DataLoader(data[x], batch_size=batch_size, 73 | sampler=RandomSampler(data[x]) if x == 'train' 74 | else SequentialSampler(data[x]), 75 | drop_last=bool(mode == 'supervised' and x == 'train'), 76 | num_workers=num_workers, 77 | pin_memory=True) 78 | for x in ['train', 'val']} 79 | return cityscapes_loader 80 | -------------------------------------------------------------------------------- /sem_seg/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union, List, Tuple 2 | import random 3 | import numpy as np 4 | import functools 5 | from PIL import Image 6 | 7 | import torch 8 | import torchvision 9 | from torch.utils.data import Dataset 10 | 11 | from .augmentations import get_pixelwise_tfms 12 | 13 | 14 | class CityscapesDataset(Dataset): 15 | """Custom cityscapes dataset to apply same transformation on image and mask pair""" 16 | 17 | def __init__(self, cityscapes_data: Union[Dataset, List], label_percent: int = 100, transform: Callable=None, 18 | target_transform: Callable = None, is_test: bool = False, mode: str = 'supervised'): 19 | """ 20 | :param cityscapes_data: torch loaded Cityscapes dataset from path defaults.CITYSCAPES_PATH 21 | :param label_percent: % of labels to use 22 | :param transform: List of Transformations to be applied on the input Image 23 | :param target_transform: List of transformations to be applied on the segmentaion mask 24 | :param is_test: is Test/Val dataset 25 | :param mode: type of dataset. 'supervised' | 'self-supervised'. 'self-supervised' mode will generate 26 | pair and 'supervised' mode will generate pair 27 | """ 28 | 29 | inputs = functools.reduce(lambda a, b: a + b.images, cityscapes_data, []) \ 30 | if type(cityscapes_data) == list else cityscapes_data.images 31 | total = len(inputs) 32 | n = int(total * label_percent) // 100 33 | self.imgs_path = inputs[:n] 34 | self.masks_path = None if is_test or mode.lower() == 'self-supervised' else cityscapes_data.targets 35 | self.transform = transform 36 | self.target_transform = target_transform 37 | self.isTestSet = is_test 38 | self.mode = mode.lower() 39 | assert self.mode in ['supervised', 'self-supervised'], "Invalid dataset mode. 
Only 'supervised', " \ 40 | "'self-supervised' is allowed" 41 | 42 | def __len__(self): 43 | return len(self.imgs_path) 44 | 45 | def get_random_crop(self, image: Image, crop_size: int) -> Image: 46 | """ 47 | :param image: PIL Image 48 | :param crop_size: Size of the crop 49 | :return: PIL Image crop of size crop_size 50 | """ 51 | crop_tfms = torchvision.transforms.RandomCrop(crop_size) 52 | return crop_tfms(image) 53 | 54 | def _apply_transformations(self, image: Image, mask: Image = None) -> Tuple: 55 | """ 56 | 57 | :param image: PIL Image 58 | :param mask: PIL segmentation Mask 59 | :return: Transformed Image and Mask pair after applying exactly the same random transformations on image and 60 | mask 61 | """ 62 | seed = np.random.randint(2147483647) # sample a random seed 63 | random.seed(seed) # set this seed to random and numpy.random function 64 | np.random.seed(seed) 65 | torch.manual_seed(seed) # Needed for torchvision 0.7 66 | if self.transform is not None: 67 | image = self.transform(image) 68 | 69 | if mask is not None and self.target_transform is not None: 70 | random.seed(seed) # set the same seed to random and numpy.random function 71 | np.random.seed(seed) 72 | torch.manual_seed(seed) # Needed for torchvision 0.7 73 | mask = self.target_transform(mask) 74 | return image, mask, seed 75 | 76 | def __getitem__(self, idx): 77 | if torch.is_tensor(idx): 78 | idx = idx.tolist() 79 | image_path = self.imgs_path[idx] 80 | image = Image.open(image_path) 81 | if image.mode == 'P': 82 | image = image.convert('RGB') 83 | 84 | if self.mode == 'self-supervised': 85 | crop = self.get_random_crop(image, 512) 86 | tf_image1, _, seed = self._apply_transformations(crop) 87 | crop_tensor = torchvision.transforms.Compose([ 88 | *get_pixelwise_tfms(), 89 | torchvision.transforms.ToTensor() 90 | ])(crop) 91 | return (crop_tensor, tf_image1, seed), seed 92 | else: 93 | if not self.isTestSet: 94 | mask_path = image_path.replace('leftImg8bit', 'gtFine').replace('.png', '_labelIds.png') 95 | mask = Image.open(mask_path) 96 | image, mask, seed = self._apply_transformations(image, mask) 97 | return image if self.isTestSet else (image, mask) 98 | 99 | -------------------------------------------------------------------------------- /sem_seg/datasets/datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import datasets 3 | from config.defaults import CITYSCAPES_PATH 4 | 5 | 6 | class CityscapesUtils: 7 | """CITYSCAPES Utility class that provides the mapping for training labels and their colors for visualization""" 8 | def __init__(self): 9 | cityscapes_data = datasets.Cityscapes(CITYSCAPES_PATH, split='train', mode='fine', target_type='semantic') 10 | self.classes = cityscapes_data.classes 11 | self.num_classes = self._num_classes() 12 | self.train_id2color = self._train_id2color() 13 | self.id2train_id = self._id2train_id() 14 | 15 | def _num_classes(self) -> int: 16 | """ 17 | :return: returns the effective number of classes in cityscapes that are used in validation 18 | """ 19 | train_labels = [label.id for label in self.classes if not label.ignore_in_eval] 20 | return len(train_labels) 21 | 22 | def _id2train_id(self) -> np.array: 23 | """ 24 | :return: returns a list where each index is mapped to its training_id. All ignore_in_eval indexes are mapped to 0 25 | i.e. the unlabelled class. 
26 | """ 27 | train_ids = np.array([label.train_id for label in self.classes]) 28 | train_ids[(train_ids == -1) | (train_ids == 255)] = 19 # 19 is Ignore_index(defaults.CITYSCAPES_IGNORE_INDEX) 29 | return train_ids 30 | 31 | def _train_id2color(self) -> np.array: 32 | """ 33 | :return: The mapping of 20 classes (19 training classes + 1 ignore index class) to their standard color used 34 | in cityscapes. 35 | """ 36 | return np.array([label.color for label in self.classes if not label.ignore_in_eval] + [(0, 0, 0)]) 37 | 38 | def label2color(self, mask: np.array) -> np.array: 39 | """ 40 | Given the cityscapes mask with all training id(and optionally 255 for ignored labels) as labels, returns the mask 41 | filled with the label's standard color. 42 | :param mask: np.array mask for which color mapping is required 43 | :return: mask with labels replaced with their standard colors 44 | """ 45 | return self.train_id2color[mask] 46 | -------------------------------------------------------------------------------- /sem_seg/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torchvision.transforms import functional as F 3 | from PIL import Image 4 | 5 | 6 | class RandomResize(object): 7 | """Resize the object randomly between the min and max size""" 8 | def __init__(self, min_size: int, max_size: int, is_mask: bool = False): 9 | """ 10 | :param min_size: min desired size of the image 11 | :param max_size: max desired size of the image 12 | :param is_mask: If the image is the segmentation mask 13 | """ 14 | self.min_size = min_size 15 | self.max_size = max_size 16 | self.is_mask = is_mask 17 | 18 | def __call__(self, img: Image) -> Image: 19 | size = random.randint(self.min_size, self.max_size) 20 | if self.is_mask: 21 | return F.resize(img, size=size, interpolation=Image.NEAREST) 22 | else: 23 | return F.resize(img, size=size) 24 | -------------------------------------------------------------------------------- /sem_seg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .segmentation_model import SegmentationModel 2 | from .self_supervised_model import SelfSupervisedModel 3 | 4 | __all__ = ['SegmentationModel', 'base', 'SelfSupervisedModel'] -------------------------------------------------------------------------------- /sem_seg/models/base.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Callable, Dict 3 | from tqdm import tqdm 4 | import time 5 | import copy 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from . 
import eval_metrices 10 | from training import utils 11 | 12 | DIRNAME = Path(__file__).parents[1].resolve() 13 | WEIGHTSDIR = DIRNAME / "weights" 14 | LOGSDIR = DIRNAME / "tensorboard_logs" 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | class Model: 19 | """ Base Class for training and evaluation of any network """ 20 | 21 | def __init__(self, network: torch.nn.Module, 22 | dataloader: torch.utils.data.DataLoader, 23 | optimizer: torch.optim, 24 | criterion: Callable, 25 | lr_scheduler: torch.optim.lr_scheduler, 26 | additional_identifier: str = ''): 27 | self.network = network 28 | self.dataloader = dataloader 29 | self.optim = optimizer 30 | self.criterion = criterion 31 | self.lr_scheduler = lr_scheduler 32 | self.name = f"{self.__class__.__name__}_{self.network.__class__.__name__}" 33 | if additional_identifier: 34 | self.name += '_' + additional_identifier 35 | self.metrices = ['accuracy'] 36 | self.logToTensorboard = True 37 | self.writer = None 38 | 39 | @property 40 | def weights_file_name(self) -> str: 41 | """ 42 | :return: returns the path where to store/load the weights of the model. 43 | """ 44 | WEIGHTSDIR.mkdir(parents=True, exist_ok=True) 45 | return str(WEIGHTSDIR / f"{self.name}.h5") 46 | 47 | def train(self, num_epochs: int) -> None: 48 | """ 49 | 1. set up the model in appropriate training mode(Use this to setup DDP model for distributed training also if 50 | required) 51 | 2. Train the model for num_epochs and keep track of the best_model_weights based on self.metric[0] values. 52 | 3. Log the epoch results( loss and other stats mentioned in self.metrices) on tensorboard 53 | :param num_epochs: Number of training epochs 54 | :return: None 55 | """ 56 | since = time.time() 57 | losses = {'train': [], 'val': []} 58 | stats = {'train': [], 'val': []} 59 | lr_rates = [] 60 | best_epoch, self.best_stat = None, float('-inf') 61 | best_model_weights = None 62 | self.network.to(device) 63 | 64 | for epoch in tqdm(range(num_epochs)): 65 | for phase in ['train', 'val']: 66 | if phase == 'train': 67 | self.network.train() 68 | else: 69 | self.network.eval() 70 | running_loss, running_stats = 0.0, [0.0] * len(self.metrices) 71 | 72 | for inputs, outputs in self.dataloader[phase]: 73 | inputs, outputs = inputs.to(device), outputs.to(device) 74 | self.optim.zero_grad() 75 | with torch.set_grad_enabled(phase == 'train'): 76 | network_outputs = self.network(inputs) 77 | loss = self.criterion(network_outputs, outputs) 78 | batch_stats = self.evaluate_metrices(network_outputs, outputs) 79 | 80 | if phase == 'train': 81 | loss.backward() 82 | self.optim.step() 83 | 84 | n = inputs.size(0) 85 | running_loss += loss.item() * n 86 | running_stats = self.update_running_stats(batch_stats, running_stats, n) 87 | N = len(self.dataloader[phase].dataset) 88 | epoch_loss = running_loss / N 89 | epoch_stats = self.get_epoch_stats(running_stats, N) 90 | 91 | if phase == 'val' and epoch_stats[0] > self.best_stat: 92 | self.best_stat = epoch_stats[0] 93 | best_epoch = epoch 94 | best_model_weights = copy.deepcopy(self.network.state_dict()) 95 | losses[phase].append(epoch_loss) 96 | stats[phase].append(epoch_stats) 97 | 98 | self.log_stats(epoch, losses, stats, [inputs, network_outputs, outputs], lr_rates) 99 | 100 | if self.lr_scheduler is not None: 101 | self.lr_scheduler.step(epoch_stats[0]) # considering ReduceLrOnPlateau and watch val mIOU 102 | lr_step = self.optim.state_dict()["param_groups"][0]["lr"] 103 | lr_rates.append(lr_step) 104 | min_lr = 
self.lr_scheduler.min_lrs[0] 105 | if lr_step / min_lr <= 10: 106 | print("Min LR reached") 107 | break 108 | 109 | time_elapsed = time.time() - since 110 | print('Training Completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 111 | print('Best {}: {:.6f} and best epoch is {}'.format(self.metrices[0], self.best_stat, best_epoch + 1)) 112 | self.network.load_state_dict(best_model_weights) 113 | 114 | def evaluate_metrices(self, outs: torch.Tensor, targets: torch.Tensor) -> List: 115 | """ 116 | :param outs: output logits from network for a given batch of inputs 117 | :param targets: ground truth targets corresponding to the batch of inputs 118 | :return: calculated metric from the utility function 119 | """ 120 | fn_mapping = { 121 | 'pixel_accuracy': eval_metrices.pixel_accuracy, 122 | 'mIOU': eval_metrices.get_confusion_matrix, 123 | 'accuracy': eval_metrices.accuracy 124 | } 125 | scores = [] 126 | for mat in self.metrices: 127 | scores.append(fn_mapping[mat](outs, targets)) 128 | return scores 129 | 130 | def update_running_stats(self, batch_stats: List, running_stats: List, num_inputs_in_batch: int) -> List: 131 | """ 132 | :param batch_stats: List of batch stats(in the same order as self.metrices) 133 | :param running_stats: List of running stats from the current epoch so far(in the same order as self.metrices) 134 | :param num_inputs_in_batch: Number of inputs in current batch 135 | :return: Aggregated stats for the current epoch. MIOU requires the bins_matrix(CxC) to be summed up. 136 | """ 137 | running_stats = [r_stat + b_stat * num_inputs_in_batch if 'miou' not in key.lower() else r_stat + b_stat 138 | for key, r_stat, b_stat in zip(self.metrices, running_stats, batch_stats)] 139 | return running_stats 140 | 141 | def get_epoch_stats(self, running_stats: List, total_inputs: int) -> List: 142 | epoch_stats = [stat / total_inputs if 'miou' not in key.lower() else eval_metrices.mIOU(stat) 143 | for key, stat in zip(self.metrices, running_stats)] 144 | return epoch_stats 145 | 146 | def log_stats(self, epoch: int, losses: List[Dict], stats: List[Dict] = [], data: List = [], lr_rates: List = []): 147 | """ Format epoch stats and log them to console and to tensorboard if enabled""" 148 | if epoch == 0: 149 | metrices_title = '\t'.join( 150 | [(phase + '_' + mat)[:12] for mat in self.metrices for phase in ['train', 'val']]) 151 | tqdm.write('Epoch\tTrain_Loss\tVal_Loss\t' + metrices_title) 152 | tqdm.write( 153 | '-----------------------------------------------------------------------------------------------------') 154 | metrices_stats = ' \t '.join([format(stats[phase][epoch][i], '.6f') for i in range(len(self.metrices)) 155 | for phase in ['train', 'val']]) 156 | log_stats = ' \t '.join([format(losses[phase][epoch], '.6f') for phase in ['train', 'val']]) 157 | tqdm.write('{:4d} \t'.format(epoch + 1) + log_stats + '\t' + metrices_stats) 158 | if self.logToTensorboard: 159 | self.writeToTensorboard(epoch, losses, stats, data, lr_rates) 160 | 161 | def writeToTensorboard(self, epoch: int, losses: List[Dict], stats: List[Dict] = [], data: List = [], 162 | lr_rates: List = []): 163 | """Logs epoch stats to tensorboard and output predictions to tensorboard after every 25 epoch""" 164 | if self.writer is None: 165 | self.setup_tensorboard() 166 | self.writer.add_scalar('Loss/train', losses['train'][epoch], epoch + 1) 167 | self.writer.add_scalar('Loss/val', losses['val'][epoch], epoch + 1) 168 | num_batches = len(lr_rates) 169 | for i, lr in enumerate(lr_rates): 170 | 
self.writer.add_scalar('Loss/schedule_lr', lr, epoch * num_batches + 1) 171 | if epoch % 25 != 0: 172 | return 173 | for i, metric in enumerate(self.metrices): 174 | self.writer.add_scalar(metric + '/train', stats['train'][epoch][i], epoch + 1) 175 | self.writer.add_scalar(metric + '/val', stats['val'][epoch][i], epoch + 1) 176 | if data: 177 | imgs, network_outputs, targets = data 178 | _, preds = torch.max(network_outputs, dim=1) 179 | val_results = utils.plot_images(imgs[:4], targets[:4], preds[:4], 180 | title='val_mIOU' + str(stats['val'][epoch][0]), 181 | num_cols=3) 182 | self.writer.add_figure('val_epoch' + str(epoch), val_results) 183 | 184 | def add_text_to_tensorboard(self, text: str) -> None: 185 | """ Log text to tensorboard. Used for logging experiment configuration""" 186 | if self.writer is None: 187 | self.setup_tensorboard() 188 | self.writer.add_text("model_details", "Experiment Config " + text) 189 | 190 | def setup_tensorboard(self): 191 | """ setup the tensorboard logging dir""" 192 | logsdir = LOGSDIR / f"{self.name}" 193 | logsdir.mkdir(parents=True, exist_ok=True) 194 | self.writer = SummaryWriter(logsdir) 195 | 196 | def get_logits(self, x: torch.Tensor) -> torch.Tensor: 197 | """ 198 | :param x: torch.Tensor [N, C, H, W] where C is number of channels in inputs 199 | :return: logit Tensor of size [N, K, H, W] where K is number of output classes 200 | """ 201 | self.network.eval() 202 | if not x.is_cuda: 203 | x = x.to(device) 204 | self.network.to(device) 205 | outs = self.network(x) 206 | return outs 207 | 208 | def evaluate(self, x: torch.Tensor) -> torch.Tensor: 209 | """ 210 | :param x: Input Tensor [N, C, H, W] 211 | :return: Predict Tensor [N, H, W] 212 | """ 213 | outs = self.get_logits(x) 214 | _, preds = torch.max(outs, dim=1) 215 | return preds 216 | 217 | def load_weights(self, filename=None) -> None: 218 | """load model weights from passed filename or self.weights_file_name""" 219 | if filename is None: 220 | filename = self.weights_file_name 221 | self.network.load_state_dict(torch.load(filename)) 222 | 223 | def save_weights(self) -> None: 224 | """ Save the current model state dict into self.weights_file_name """ 225 | torch.save(self.network.state_dict(), self.weights_file_name) 226 | -------------------------------------------------------------------------------- /sem_seg/models/eval_metrices.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from config import defaults 5 | 6 | 7 | def get_confusion_matrix(output: torch.Tensor, target: torch.Tensor) -> np.ndarray: 8 | """ 9 | :param output: output logits 10 | :param target: targets 11 | :return: returns the confusion matrix of shape CxC between pred and targets 12 | """ 13 | 14 | n_classes = output.shape[1] 15 | softmax_out = F.softmax(output, dim=1) 16 | pred = torch.argmax(softmax_out, dim=1).squeeze(1).cpu().numpy() 17 | target = target.cpu().numpy() 18 | hist = fast_hist(pred.flatten(), target.flatten(), n_classes) 19 | return hist 20 | 21 | 22 | def mIOU(c_matrix: np.ndarray) -> float: 23 | """ 24 | Calculates the mIOU for a given confusion matrix 25 | :param c_matrix: CxC confusion matrix 26 | :return: effective mIOU 27 | """ 28 | if type(c_matrix) != np.ndarray: 29 | return 0 30 | class_iu = per_class_iu(c_matrix) 31 | m_iou = np.nanmean(class_iu) # ignoring Nans 32 | return m_iou 33 | 34 | 35 | def fast_hist(pred: torch.Tensor, label: torch.Tensor, n:int) -> np.ndarray: 36 | """ 37 | 
:param pred: softmaxed prediction of shape [N, H, W] 38 | :param label: Label of shape [N, H, W] 39 | :param n: num classes 40 | :return: a matrix of shape CXC where row i represents the count of each class in pred when the actual class 41 | of label is i 42 | """ 43 | k = (label >= 0) & (label < n) 44 | return np.bincount(n * label[k].astype(int) + pred[k], minlength=n**2).reshape(n, n) 45 | 46 | 47 | def per_class_iu(hist: np.ndarray) -> np.array: 48 | """ 49 | :param hist: bin count matrix of size CxC where C is the number of output classes. 50 | :return: list of IOU for each class 51 | """ 52 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 53 | 54 | 55 | def pixel_accuracy(output: torch.Tensor, target: torch.Tensor) -> float: 56 | """ 57 | :param output: output logits 58 | :param target: targets 59 | :return: pixelwise accuracy for semantic segmentation 60 | """ 61 | softmax_out = F.softmax(output, dim=1) 62 | pred = torch.argmax(softmax_out, dim=1).squeeze(1) 63 | pred = pred.view(1, -1) 64 | target = target.view(1, -1) 65 | correct = pred.eq(target) 66 | correct = correct[target != defaults.CITYSCAPES_IGNORED_INDEX] 67 | correct = correct.view(-1) 68 | score = correct.float().sum(0) / correct.size(0) 69 | return score.item() 70 | 71 | 72 | def accuracy(outputs: torch.Tensor, targets: torch.Tensor) -> float: 73 | """ 74 | :param outputs: output logits 75 | :param targets: target labels 76 | :return: accuracy score for the batch 77 | """ 78 | softmax_out = F.softmax(outputs, dim=1) 79 | preds = torch.argmax(softmax_out, dim=1) 80 | correct = preds.eq(targets) 81 | score = correct.float().sum(0) / correct.size(0) 82 | return score.item() 83 | -------------------------------------------------------------------------------- /sem_seg/models/segmentation_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | from typing import Callable 4 | from config import defaults 5 | import torch 6 | 7 | from sem_seg.models.base import Model 8 | from . 
import eval_metrices 9 | 10 | DIRNAME = Path(__file__).parents[1].resolve() 11 | OUTDIR = DIRNAME / "outputs" / "class_IOUs" 12 | OUTDIR.mkdir(parents=True, exist_ok=True) 13 | 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | class SegmentationModel(Model): 18 | """Child class for semantic segmentation model from base Model""" 19 | 20 | def __init__(self, network: torch.nn.Module, 21 | dataloader: torch.utils.data.DataLoader, 22 | optimizer: torch.optim, 23 | criterion: Callable = None, 24 | lr_scheduler: torch.optim.lr_scheduler = None, 25 | additional_identifier: str = ''): 26 | super().__init__(network, dataloader, optimizer, criterion, lr_scheduler, additional_identifier) 27 | 28 | if criterion is None: 29 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=defaults.CITYSCAPES_IGNORED_INDEX) 30 | 31 | self.metrices = ['mIOU', 'pixel_accuracy'] 32 | 33 | def store_per_class_iou(self): 34 | print('Calculating per class mIOU') 35 | self.network.eval() 36 | self.network.to(device) 37 | c_matrix = 0 38 | for inputs, labels in self.dataloader['val']: 39 | inputs, labels = inputs.to(device), labels.to(device) 40 | outs = self.network(inputs) 41 | c_matrix += eval_metrices.get_confusion_matrix(outs, labels) 42 | class_iou = eval_metrices.per_class_iu(c_matrix) 43 | class_iou = np.round(class_iou, decimals=6) 44 | file_path = str(OUTDIR) + '/' + self.name + '.csv' 45 | np.savetxt(file_path, class_iou, delimiter=',', fmt='%.6f') 46 | print(f'per class IOU saved at {file_path}') 47 | -------------------------------------------------------------------------------- /sem_seg/models/self_supervised_model.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from tqdm import tqdm 3 | import numpy as np 4 | import time 5 | import copy 6 | import random 7 | 8 | import torch 9 | 10 | from sem_seg.models.base import Model 11 | 12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | class SelfSupervisedModel(Model): 16 | def __init__(self, network: torch.nn.Module, 17 | dataloader: torch.utils.data.DataLoader, 18 | optimizer: torch.optim, 19 | criterion: Callable = None, 20 | lr_scheduler: torch.optim.lr_scheduler = None, 21 | additional_identifier: str = ''): 22 | super().__init__(network, dataloader, optimizer, criterion, lr_scheduler, additional_identifier) 23 | self.metrices = [] 24 | if criterion is None: 25 | self.criterion = torch.nn.MSELoss 26 | 27 | def train(self, num_epochs: int): 28 | since = time.time() 29 | losses = {'train': [], 'val': []} 30 | stats = {'train': [], 'val': []} 31 | lr_rates = [] 32 | best_epoch, best_stat = None, float('inf') 33 | best_model_weights = None 34 | self.network.to(device) 35 | self.network.initialize_target_network() 36 | 37 | for epoch in tqdm(range(num_epochs)): 38 | for phase in ['train', 'val']: 39 | if phase == 'train': 40 | self.network.train() 41 | else: 42 | self.network.eval() 43 | running_loss = 0.0 44 | 45 | for (imgs, tf_imgs, seeds), _ in self.dataloader[phase]: 46 | imgs, tf_imgs = imgs.to(device), tf_imgs.to(device) 47 | self.optim.zero_grad() 48 | 49 | with torch.set_grad_enabled(phase == 'train'): 50 | online_feats = self.network.online_network(imgs)['feat5'] 51 | tf_online_feats = self.network.online_network(tf_imgs)['feat5'] 52 | 53 | pred1_feats = self.network.predictor(self.network.online_projector(online_feats)) 54 | pred2_feats = 
self.network.predictor(self.network.online_projector(tf_online_feats)) 55 | 56 | with torch.no_grad(): 57 | target_feats = self.network.target_network(imgs)['feat5'] 58 | target_feats_tf = self.network.target_network(tf_imgs)['feat5'] 59 | 60 | target_for_pred1_feats = self.network.target_projector(target_feats_tf) 61 | target_for_pred2_feats = self.network.target_projector(target_feats) 62 | 63 | with torch.set_grad_enabled(phase == 'train'): 64 | loss = self.network.regression_loss(pred1_feats, target_for_pred1_feats) 65 | loss += self.network.regression_loss(pred2_feats, target_for_pred2_feats) 66 | 67 | if phase == 'train': 68 | loss.backward() 69 | self.optim.step() 70 | self.network.update_target_network() 71 | 72 | running_loss += loss.item() * imgs.size(0) 73 | epoch_loss = running_loss / len(self.dataloader[phase].dataset) 74 | 75 | if phase == 'val' and epoch_loss <= best_stat: 76 | best_epoch = epoch 77 | best_stat = epoch_loss 78 | best_model_weights = copy.deepcopy(self.network.state_dict()) 79 | losses[phase].append(epoch_loss) 80 | 81 | self.log_stats(epoch, losses) 82 | 83 | if (epoch + 1) % 30 == 0: #checkpoint model every 30 epochs 84 | torch.save(self.network.state_dict(), self.weights_file_name) 85 | torch.save(best_model_weights, self.weights_file_name.split('.')[0]+'_best.h5') 86 | print('model checkpointed to ', self.weights_file_name) 87 | 88 | if self.lr_scheduler is not None: 89 | self.lr_scheduler.step(epoch_loss) # considering ReduceLrOnPlateau and watch val val loss 90 | lr_step = self.optim.state_dict()["param_groups"][0]["lr"] 91 | lr_rates.append(lr_step) 92 | min_lr = self.lr_scheduler.min_lrs[0] 93 | if lr_step / min_lr <= 10: 94 | print("Min LR reached") 95 | break 96 | 97 | time_elapsed = time.time() - since 98 | print('Training Completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 99 | print('Best val loss {:.6f} and best epoch is {}'.format(best_stat, best_epoch + 1)) 100 | torch.save(self.network.state_dict(), self.weights_file_name) 101 | print('model saved to ', self.weights_file_name) 102 | self.network.load_state_dict(best_model_weights) 103 | -------------------------------------------------------------------------------- /sem_seg/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .fcn import FCN8s, FCN16s, FCN32s 2 | from .resnet import Resnet18, Resnet50, Resnet101, Resnet152, Resnet50_2, Resnet101_2 3 | from .byol import BYOL -------------------------------------------------------------------------------- /sem_seg/networks/byol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from .network_utils import get_num_features 5 | import copy 6 | 7 | 8 | class BYOL_Head(nn.Module): 9 | def __init__(self, backbone: nn.Module, target_momentum=0.996): 10 | super().__init__() 11 | # representation head 12 | self.online_network = backbone 13 | self.target_network = copy.deepcopy(backbone) 14 | 15 | # Projection Head 16 | self.online_projector = ProjectorHead(backbone.name) 17 | self.target_projector = ProjectorHead(backbone.name) 18 | 19 | # Predictor Head 20 | self.predictor = MLPHead(self.online_projector.out_channels, 512, 128) 21 | 22 | self.m = target_momentum 23 | 24 | def initialize_target_network(self): 25 | for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()): 26 | param_k.data.copy_(param_q.data) 27 | 
param_k.requires_grad = False 28 | 29 | for param_q, param_k in zip(self.online_projector.parameters(), self.target_projector.parameters()): 30 | param_k.data.copy_(param_q.data) 31 | param_k.requires_grad = False 32 | 33 | @torch.no_grad() 34 | def update_target_network(self): 35 | for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()): 36 | param_k.data = self.m * param_k.data + (1 - self.m) * param_q.data 37 | 38 | for param_q, param_k in zip(self.online_projector.parameters(), self.target_projector.parameters()): 39 | param_k.data = self.m * param_k.data + (1 - self.m) * param_q.data 40 | 41 | @staticmethod 42 | def regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 43 | x_norm = F.normalize(x, dim=1) 44 | y_norm = F.normalize(y, dim=1) 45 | loss = 2 - 2 * (x_norm * y_norm).sum(dim=-1) 46 | return loss.mean() 47 | 48 | 49 | class ProjectorHead(nn.Module): 50 | def __init__(self, backbone_class: str, layer: str = 'feats5'): 51 | super().__init__() 52 | num_feats = get_num_features(backbone_class) 53 | layer_index = -3 if layer == 'feats3' else -2 if layer == 'feats4' else -1 54 | num_features = num_feats[layer_index] 55 | self.projection = MLPHead(num_features, 512, 128) 56 | self.out_channels = 128 57 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 58 | 59 | def forward(self, x): 60 | x_pooled = self.avg_pool(x) 61 | h = x_pooled.view(x_pooled.shape[0], x_pooled.shape[1]) # removing the last dimension 62 | return self.projection(h) 63 | 64 | 65 | class MLPHead(nn.Module): 66 | def __init__(self, in_channels: int, hidden_size: int, out_size: int): 67 | super().__init__() 68 | self.net = nn.Sequential( 69 | nn.Linear(in_channels, hidden_size), 70 | nn.BatchNorm1d(hidden_size), 71 | nn.ReLU(inplace=True), 72 | nn.Linear(hidden_size, out_size) 73 | ) 74 | 75 | def forward(self, x): 76 | return self.net(x) 77 | 78 | 79 | class BYOL(BYOL_Head): 80 | def __init__(self, backbone: nn.Module, target_momentum=0.996): 81 | super().__init__(backbone, target_momentum) 82 | 83 | -------------------------------------------------------------------------------- /sem_seg/networks/fcn.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from torch import nn 3 | from .network_utils import get_num_features 4 | 5 | 6 | class FCN(nn.Module): 7 | """ Base Class for all FCN Modules """ 8 | 9 | def __init__(self, backbone: nn.Module, num_classes: int, model_type: str = 'fcn8s'): 10 | super().__init__() 11 | self.backbone = backbone 12 | num_features = get_num_features(backbone.name, model_type) 13 | self.classifier = nn.ModuleList([self.upsample_head(num_feature, num_classes) for num_feature in num_features]) 14 | 15 | def upsample_head(self, in_channels: int, channels: int) -> nn.Module: 16 | """ 17 | :param in_channels: Number of channels in Input 18 | :param channels: Desired Number of channels in Output 19 | :return: torch.nn.Module 20 | """ 21 | inter_channels = in_channels // 8 22 | layers = [ 23 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 24 | nn.BatchNorm2d(inter_channels), 25 | nn.ReLU(), 26 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 27 | nn.BatchNorm2d(inter_channels), 28 | nn.ReLU(), 29 | nn.Conv2d(inter_channels, channels, 1), 30 | ] 31 | return nn.Sequential(*layers) 32 | 33 | def forward(self, x): 34 | """ Abstract method to be implemented by child classes""" 35 | pass 36 | 37 | 38 | class FCN32s(FCN): 39 | """ Child FCN 
class that generates the output only using feature maps from last layer of the backbone """ 40 | def __init__(self, backbone: nn.Module, num_classes: int): 41 | super().__init__(backbone, num_classes, model_type='fcn32s') 42 | 43 | def forward(self, x): 44 | """ Forward pass through FCN32s""" 45 | h, w = x.shape[-2:] 46 | features = self.backbone(x) 47 | return self.bilinear_upsample(features, h, w) 48 | 49 | def bilinear_upsample(self, features: Dict, h: int, w: int): 50 | """ 51 | :param features: Backbone's output feature map dict 52 | :param h: Desired Output Height 53 | :param w: Desired output Width 54 | :return: Upsample output of size N x C x H x W where C is the number of classes 55 | """ 56 | out32s = self.classifier[-1](features['feat5']) 57 | upsampled_out = nn.functional.interpolate(out32s, size=(h, w), mode='bilinear', align_corners=False) 58 | return upsampled_out 59 | 60 | 61 | class FCN16s(FCN): 62 | """ Child FCN class that generates the output only using feature maps from last two layers of the backbone """ 63 | def __init__(self, backbone: nn.Module, num_classes: int): 64 | super().__init__(backbone, num_classes, model_type='fcn16s') 65 | 66 | def forward(self, x): 67 | """ Forward pass through FCN16s""" 68 | h, w = x.shape[-2:] 69 | features = self.backbone(x) 70 | return self.bilinear_upsample(features, h, w) 71 | 72 | def bilinear_upsample(self, features: Dict, h: int, w: int): 73 | """ 74 | Bilinear upsample after merging the last 2 feature maps 75 | :param features: Backbone's output feature map dict 76 | :param h: Desired Output Height 77 | :param w: Desired output Width 78 | :return: Upsample output of size N x C x H x W where C is the number of classes 79 | """ 80 | out32s = self.classifier[-1](features['feat5']) 81 | out16s = self.classifier[-2](features['feat4']) 82 | upsampled_out32s = nn.functional.interpolate(out32s, size=(h//16, w//16), mode='bilinear', align_corners=False) 83 | out = upsampled_out32s + out16s 84 | upsampled_out = nn.functional.interpolate(out, size=(h, w), mode='bilinear', align_corners=False) 85 | return upsampled_out 86 | 87 | 88 | class FCN8s(FCN): 89 | """ Child FCN class that generates the output only using feature maps from last three layers of the backbone """ 90 | def __init__(self, backbone: nn.Module, num_classes: int): 91 | super().__init__(backbone, num_classes, model_type='fcn8s') 92 | 93 | def forward(self, x): 94 | """ Forward pass through FCN16s""" 95 | h, w = x.shape[-2:] 96 | features = self.backbone(x) 97 | return self.bilinear_upsample(features, h, w) 98 | 99 | def bilinear_upsample(self, features: Dict, h: int, w: int): 100 | """ 101 | Bilinear upsample after merging the last 3 feature maps 102 | :param features: Backbone's output feature map dict 103 | :param h: Desired Output Height 104 | :param w: Desired output Width 105 | :return: Upsample output of size N x C x H x W where C is the number of classes 106 | """ 107 | out32s = self.classifier[-1](features['feat5']) 108 | out16s = self.classifier[-2](features['feat4']) 109 | out8s = self.classifier[-3](features['feat3']) 110 | upsampled_out32s = nn.functional.interpolate(out32s, size=(h//16, w//16), mode='bilinear', align_corners=False) 111 | out = upsampled_out32s + out16s 112 | upsampled_out16s = nn.functional.interpolate(out, size=(h//8, w//8), mode='bilinear', align_corners=False) 113 | out = upsampled_out16s + out8s 114 | upsampled_out = nn.functional.interpolate(out, size=(h, w), mode='bilinear', align_corners=False) 115 | return upsampled_out 116 | 117 | 
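# Illustrative shape walk-through for FCN8s with a Resnet50 backbone (a 768x768 input
# is assumed; feature names follow network_utils.get_features_dict):
#   x     -> [N, 3, 768, 768]
#   feat3 -> [N, 512, 96, 96]     (stride 8,  backbone layer2)
#   feat4 -> [N, 1024, 48, 48]    (stride 16, backbone layer3)
#   feat5 -> [N, 2048, 24, 24]    (stride 32, backbone layer4)
# Each feature map is first reduced to num_classes channels by its upsample_head; feat5's
# output is then upsampled to stride 16 and added to feat4's, the sum is upsampled to
# stride 8 and added to feat3's, and the result is bilinearly upsampled to the full
# [N, num_classes, 768, 768] prediction.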
-------------------------------------------------------------------------------- /sem_seg/networks/network_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | from torchvision import models 4 | 5 | def get_num_features(backbone_name: str, model_type: str='') -> List: 6 | """ 7 | Gives a List of features present in the last 3 blocks of the backbone model 8 | :param backbone_name: name of the backbone model e.g. 'resnet18' | 'resnet50' 9 | :param model_type: Type of FCN model(fcn32s | fcn16s | fcn8s) 10 | :return: List of number of features extracted from last 3 blocks of the backbone model 11 | """ 12 | 13 | if 'resnet18' in backbone_name.lower(): 14 | num_features = [64, 128, 256, 512] 15 | else: 16 | num_features = [256, 512, 1024, 2048] 17 | if 'fcn8s' in model_type.lower(): 18 | num_features = num_features[-3:] 19 | elif 'fcn16s' in model_type.lower(): 20 | num_features = num_features[-2:] 21 | elif 'fcn32s' in model_type.lower(): 22 | num_features = num_features[-1:] 23 | return num_features 24 | 25 | 26 | def get_features_dict(base: torch.nn.Module) -> torch.nn.Module: 27 | """ 28 | This function extracts the features from various layers(hardcoded in a dictionary extract_layers) and returns 29 | them as a key-value pair. 30 | For e.g. extract_layer = {'layer2': 'feat3', 'layer3': 'feat4'} will extract 'layer2' and 'layer3' feature maps 31 | from the given network and assigns them to key 'feat3' and 'feat4' respectively of the output. 32 | :param base: backbone model 33 | :return: model with output layer as dictionary which provides extracted features from various layers. 34 | """ 35 | extract_layers = {'layer1': 'feat1', 'layer2': 'feat3', 'layer3': 'feat4', 'layer4': 'feat5'} 36 | return models._utils.IntermediateLayerGetter(base, extract_layers) 37 | 38 | -------------------------------------------------------------------------------- /sem_seg/networks/resnet.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | import torch 3 | from typing import List 4 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 5 | from pathlib import Path 6 | from .network_utils import get_features_dict 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 15 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 16 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 17 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 18 | } 19 | 20 | 21 | def _build_model(resnet_type: str = 'resnet50', pretrained: bool = False, 22 | replace_stride_with_dilation: List = [False, False, False]) -> torch.nn.Module: 23 | """ 24 | :param resnet_type: 'resnet18' | 'resnet50' | 'resnet101' | 'resnet152' 25 | :param pretrained: True if the network is expected to be initialized with pretrained imagenet weights 26 | :param replace_stride_with_dilation: List of 3 boolean values if the last 3 blocks 
of resnet should use dilation 27 | instead of stride 28 | :return: torch.nn.Module 29 | """ 30 | if resnet_type == 'resnet18': 31 | base = models.resnet18(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation) 32 | elif resnet_type == 'resnet50': 33 | base = models.resnet50(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation) 34 | elif resnet_type == 'resnet101': 35 | base = models.resnet101(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation) 36 | elif resnet_type == 'resnet152': 37 | base = models.resnet152(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation) 38 | elif resnet_type == 'resnet50_2x': 39 | base = models.wide_resnet50_2(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation) 40 | else: 41 | print(f'resent type {resnet_type} is not currently implemented') 42 | return 43 | 44 | if pretrained: 45 | state_dict = load_state_dict_from_url(model_urls[resnet_type], progress=True) 46 | base.load_state_dict(state_dict) 47 | 48 | network = get_features_dict(base) 49 | network.isDRN = any(replace_stride_with_dilation) # Is Dilated Residual Network 50 | network.name = resnet_type 51 | return network 52 | 53 | 54 | Resnet18 = lambda pretrained: _build_model('resnet18', pretrained) 55 | Resnet50 = lambda pretrained: _build_model('resnet50', pretrained) 56 | Resnet101 = lambda pretrained: _build_model('resnet101', pretrained) 57 | Resnet152 = lambda pretrained: _build_model('resnet152', pretrained) 58 | DRN50 = lambda pretrained: _build_model('resnet50', pretrained, replace_stride_with_dilation = [False, True, True]) 59 | Resnet50_2 = lambda pretrained: _build_model('resnet50_2x', pretrained) 60 | Resnet101_2 = lambda pretrained: _build_model('resnet101_2x', pretrained) 61 | 62 | 63 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/training/__init__.py -------------------------------------------------------------------------------- /training/debugger.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | 6 | 7 | from config import defaults 8 | defaults.MATPLOTLIB_NO_GUI = False # enable GUI for matplotlib 9 | from sem_seg.datasets.dataloader import CityscapesLoader 10 | from sem_seg.datasets.datautils import CityscapesUtils 11 | from sem_seg.networks import FCN8s, FCN32s, FCN16s, Resnet50 12 | from sem_seg.models.segmentation_model import SegmentationModel 13 | from sem_seg.models.self_supervised_model import SelfSupervisedModel 14 | from training import utils, train 15 | 16 | 17 | DIRNAME = Path(__file__).parents[1].resolve() 18 | OUTDIR = DIRNAME / "semantic_seg" / "outputs" 19 | OUTDIR.mkdir(parents=True, exist_ok=True) 20 | 21 | 22 | def get_dataloader(mode='supervised'): 23 | cityscapes = CityscapesLoader(label_percent='100%').get_cityscapes_loader(mode=mode) 24 | return cityscapes 25 | 26 | def get_model(): 27 | cityscapes = get_dataloader() 28 | cityscapes_utils = CityscapesUtils() 29 | num_classes = cityscapes_utils.num_classes + 1 30 | num_features = [256, 512, 512] 31 | base = Resnet50(pretrained=False) 32 | fcn8s = FCN8s(base, num_classes) 33 | optim = torch.optim.Adam(fcn8s.parameters()) 34 | model = 
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/training/__init__.py
--------------------------------------------------------------------------------
/training/debugger.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | import torch
4 | import torchvision
5 | 
6 | 
7 | from config import defaults
8 | defaults.MATPLOTLIB_NO_GUI = False  # enable GUI for matplotlib so the debug plots are shown interactively
9 | from sem_seg.datasets.dataloader import CityscapesLoader
10 | from sem_seg.datasets.datautils import CityscapesUtils
11 | from sem_seg.networks import FCN8s, FCN32s, FCN16s, Resnet50
12 | from sem_seg.models.segmentation_model import SegmentationModel
13 | from sem_seg.models.self_supervised_model import SelfSupervisedModel
14 | from training import utils, train
15 | 
16 | 
17 | DIRNAME = Path(__file__).parents[1].resolve()
18 | OUTDIR = DIRNAME / "sem_seg" / "outputs"
19 | OUTDIR.mkdir(parents=True, exist_ok=True)
20 | 
21 | 
22 | def get_dataloader(mode='supervised'):
23 | cityscapes = CityscapesLoader(label_percent='100%').get_cityscapes_loader(mode=mode)
24 | return cityscapes
25 | 
26 | def get_model():
27 | cityscapes = get_dataloader()
28 | cityscapes_utils = CityscapesUtils()
29 | num_classes = cityscapes_utils.num_classes + 1
30 | num_features = [256, 512, 512]
31 | base = Resnet50(pretrained=False)
32 | fcn8s = FCN8s(base, num_classes)
33 | optim = torch.optim.Adam(fcn8s.parameters())
34 | model = SegmentationModel(fcn8s, cityscapes, optim)
35 | return model
36 | 
37 | def debug_FCN():
38 | cityscapes = get_dataloader()
39 | batch_imgs, batch_targets = next(iter(cityscapes['train']))
40 | utils.plot_images(batch_imgs, batch_targets, title='Inputs & Ground Truths')
41 | 
42 | def debug_self_supervised_model():
43 | cityscapes = get_dataloader(mode='self-supervised')
44 | (batch_imgs, tf_imgs, seeds), _ = next(iter(cityscapes['train']))
45 | preds = None
46 | utils.plot_images(batch_imgs, tf_imgs, preds, title='Augmented Views')
47 | 
48 | if __name__ == '__main__':
49 | debug_FCN()
50 | debug_self_supervised_model()
51 | 
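The debugger helpers above also make a convenient smoke test. The sketch below is illustrative only: it assumes the Cityscapes paths in config/defaults.py are configured and that SegmentationModel falls back to sensible defaults when no lr_scheduler is passed (as get_model does).

```python
# Not part of the repo: a one-epoch sanity run built on the debugger helpers above.
from training.debugger import get_model

model = get_model()        # FCN8s on a ResNet-50 backbone, Cityscapes loaders, Adam optimiser
model.train(num_epochs=1)  # single epoch, just to verify the data/model plumbing end to end
```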
--------------------------------------------------------------------------------
/training/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import argparse
4 | from typing import Dict
5 | import importlib
6 | from config import defaults
7 | from sem_seg.datasets.datautils import CityscapesUtils
8 | from sem_seg.datasets.dataloader import CityscapesLoader
9 | from sem_seg import networks, models
10 | 
11 | 
12 | def run_experiment(experiment_config: Dict, load_weights: bool, save_weights: bool) -> None:
13 | """
14 | Run a Training Experiment
15 | :param experiment_config: Dict of following form
16 | {
17 | "dataset": "cityscapes",
18 | "model": "SegmentationModel", # Type of the training model (SegmentationModel, SelfSupervisedModel)
19 | "network": "FCN8s", # "FCN8s" | "FCN16s" | "FCN32s" | "deeplabv3" | "deeplabv3plus"
20 | "network_args": {"backbone":"Resnet50", # "Resnet50" | "Resnet18" | "Resnet101" | "DRN50" etc
21 | "pretrained": false, # Whether the backbone model is pretrained
22 | "load_from_BYOL": true, # Whether to load the backbone weights from the self-supervised trained model.
23 | "freeze_backbone": false}, # Whether the backbone weights are fixed while training
24 | "train_args":{"batch_size": 8,
25 | "epochs": 400,
26 | "labels_percent": '10%', # percentage of supervised labels to use for training. Default 100%
27 | "log_to_tensorboard": false},
28 | "experiment_group":{}
29 | }'
30 | :param load_weights: If true, load weights for the model from last run
31 | :param save_weights: If true, save model weights to sem_seg/weights directory
32 | :return: None
33 | """
34 | 
35 | DEFAULT_TRAIN_ARGS = {'epochs': defaults.NUM_EPOCHS, 'batch_size': defaults.BATCH_SIZE,
36 | 'num_workers': defaults.NUM_WORKERS}
37 | train_args = {
38 | **DEFAULT_TRAIN_ARGS,
39 | **experiment_config.get("train_args", {})
40 | }
41 | experiment_config["train_args"] = train_args
42 | experiment_config["experiment_group"] = experiment_config.get("experiment_group", None)
43 | print(f'Running experiment with config {experiment_config}')
44 | 
45 | labels_percent = train_args.get('labels_percent', '100%')
46 | mode = experiment_config.get('mode', 'supervised')
47 | dataset_name = experiment_config["dataset"].lower()
48 | assert dataset_name in ['cityscapes'], "The dataloader is only implemented for cityscapes dataset"
49 | data_loader = CityscapesLoader(label_percent=labels_percent) \
50 | .get_cityscapes_loader(batch_size=train_args["batch_size"],
51 | num_workers=train_args["num_workers"],
52 | mode=mode)
53 | models_module = importlib.import_module("sem_seg.models")
54 | model_class_ = getattr(models_module, experiment_config["model"])
55 | 
56 | networks_module = importlib.import_module("sem_seg.networks")
57 | network_args = experiment_config.get("network_args", {})
58 | 
59 | pretrained = network_args["pretrained"]
60 | backbone_class_ = getattr(networks_module, network_args["backbone"])
61 | base = backbone_class_(pretrained=pretrained)
62 | 
63 | num_classes = CityscapesUtils().num_classes
64 | 
65 | network_class_ = getattr(networks_module, experiment_config["network"])
66 | network = network_class_(base, num_classes)
67 | optim = torch.optim.SGD(network.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.00001)
68 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='max', factor=0.1,
69 | patience=15, min_lr=1e-10, verbose=True)
70 | additional_identifier = get_additional_identifier(network_args["backbone"], pretrained, dataset_name,
71 | labels_percent)
72 | model = model_class_(network, data_loader, optim, lr_scheduler=lr_scheduler,
73 | additional_identifier=additional_identifier)
74 | 
75 | if not train_args.get('log_to_tensorboard', False):
76 | model.logToTensorboard = False
77 | model.add_text_to_tensorboard(json.dumps(experiment_config))
78 | 
79 | if load_weights:
80 | model.load_weights()
81 | elif network_args.get('load_from_BYOL', ''):
82 | byol = networks.BYOL(base)
83 | ss_model = models.SelfSupervisedModel(byol, additional_identifier=additional_identifier)
84 | print('loading self-supervised weights from ', ss_model.weights_file_name)
85 | byol_state_dict = torch.load(ss_model.weights_file_name)
86 | backbone_dict = {}  # backbone weights extracted from the BYOL online encoder only
87 | for key in byol_state_dict:
88 | if 'online_network' in key:
89 | new_key = key.replace('online_network.', '')
90 | backbone_dict[new_key] = byol_state_dict[key]
91 | model.network.backbone.load_state_dict(backbone_dict)
92 | 
93 | if network_args.get('freeze_backbone', False):
94 | for param in model.network.backbone.parameters():
95 | param.requires_grad = False
96 | 
97 | model.train(num_epochs=train_args["epochs"])
98 | model.store_per_class_iou()
99 | 
100 | if save_weights:
101 | model.save_weights()
102 | 
103 | 
104 | def get_additional_identifier(backbone: str, pretrained: bool = False, dataset_name: str = '',
105 | labels_percent: str = '100%') -> str:
106 | """
107 
| Returns the additional_identifier added to the model name for efficient tracking of different experiments 108 | :param backbone: name of the backbone 109 | :param pretrained: Whether the backbone is pretrained 110 | :param dataset_name: Name of the training dataset 111 | :param labels_percent: % of labels used for training 112 | :return: additional identifier string 113 | """ 114 | additional_identifier = backbone 115 | additional_identifier += '_pt' if pretrained else '' 116 | additional_identifier += '_' + labels_percent[:-1] if labels_percent and int(labels_percent[:-1]) < 100 else '' 117 | additional_identifier += '_ct' if dataset_name == 'cityscapes' else '_' + dataset_name 118 | return additional_identifier 119 | 120 | 121 | def _parse_args(): 122 | """ parse command line arguments """ 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--save", default=False, action="store_true", 125 | help="If true, final weights will be stored in canonical, version-controlled location") 126 | parser.add_argument("--load", default=False, action="store_true", 127 | help="If true, final weights will be loaded from canonical, version-controlled location") 128 | parser.add_argument("experiment_config", type=str, 129 | help='Experiment JSON (\'{"dataset": "cityscapes", "model": "SegmentationModel",' 130 | ' "network": "fcn8s"}\'' 131 | ) 132 | args = parser.parse_args() 133 | return args 134 | 135 | 136 | def main(): 137 | """Run Experiment""" 138 | args = _parse_args() 139 | experiment_config = json.loads(args.experiment_config) 140 | run_experiment(experiment_config, args.load, args.save) 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /training/train_byol.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import argparse 4 | from typing import Dict 5 | import importlib 6 | from config import defaults 7 | from sem_seg.datasets.dataloader import CityscapesLoader 8 | 9 | 10 | def run_experiment(experiment_config: Dict, load_weights: bool, save_weights: bool) -> None: 11 | """ 12 | Run a Training Experiment 13 | :param experiment_config: Dict of following form 14 | { 15 | "dataset": "cityscapes", 16 | "model": "SelfSupervisedModel", # Type of the training model(SegmentationModel, SelfSupervisedModel) 17 | "network": "BYOL", # "BYOL" 18 | "mode": "self-supervised", # Type of training(self-supervised | supervised) 19 | "network_args": {"backbone":"Resnet50", # "Resnet50" | "Resnet18" | "Resnet101" | "DRN50" etc 20 | "pretrained": false, # Whether the backbone model is pretrained 21 | "target_momentum": 0.996 #Target momentum for byol target network 22 | } 23 | "train_args":{"batch_size": 8, 24 | "epochs": 400, 25 | "log_to_tensorboard": false}, 26 | "experiment_group":{} 27 | }' 28 | :param load_weights: If true, load weights for the model from last run 29 | :param save_weights: If true, save model weights to sem_seg/weights directory 30 | :return: None 31 | """ 32 | 33 | DEFAULT_TRAIN_ARGS = {'epochs': defaults.NUM_EPOCHS, 'batch_size': defaults.BATCH_SIZE, 34 | 'num_workers': defaults.NUM_WORKERS} 35 | train_args = { 36 | **DEFAULT_TRAIN_ARGS, 37 | **experiment_config.get("train_args", {}) 38 | } 39 | experiment_config["train_args"] = train_args 40 | experiment_config["experiment_group"] = experiment_config.get("experiment_group", None) 41 | print(f'Running experiment with config {experiment_config}') 42 | 43 | 
labels_percent = train_args.get('labels_percent', '100%')
44 | mode = experiment_config.get('mode', 'self-supervised')  # BYOL pre-training uses the self-supervised dataloader
45 | dataset_name = experiment_config["dataset"].lower()
46 | assert dataset_name in ['cityscapes'], "The dataloader is only implemented for cityscapes dataset"
47 | data_loader = CityscapesLoader(label_percent=labels_percent) \
48 | .get_cityscapes_loader(batch_size=train_args["batch_size"],
49 | num_workers=train_args["num_workers"],
50 | mode=mode)
51 | models_module = importlib.import_module("sem_seg.models")
52 | model_class_ = getattr(models_module, experiment_config["model"])
53 | 
54 | networks_module = importlib.import_module("sem_seg.networks")
55 | network_args = experiment_config.get("network_args", {})
56 | 
57 | pretrained = network_args["pretrained"]
58 | backbone_class_ = getattr(networks_module, network_args["backbone"])
59 | base = backbone_class_(pretrained=pretrained)
60 | 
61 | network_class_ = getattr(networks_module, experiment_config["network"])
62 | target_momentum = network_args.get("target_momentum", 0.996)  # documented under network_args above
63 | network = network_class_(base, target_momentum=target_momentum)
64 | 
65 | training_params = [*network.online_network.parameters()] + [*network.online_projector.parameters()] + [
66 | *network.predictor.parameters()] # explicitly excluding the target network parameters
67 | optim = torch.optim.SGD(training_params, lr=3e-2, momentum=0.9, weight_decay=0.00001)
68 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.2,
69 | patience=20, min_lr=1e-10, verbose=True)
70 | additional_identifier = get_additional_identifier(network_args["backbone"], pretrained, dataset_name,
71 | labels_percent)
72 | model = model_class_(network, data_loader, optim, lr_scheduler=lr_scheduler,
73 | additional_identifier=additional_identifier)
74 | 
75 | if not train_args.get('log_to_tensorboard', False):
76 | model.logToTensorboard = False
77 | model.add_text_to_tensorboard(json.dumps(experiment_config))
78 | 
79 | if load_weights:
80 | model.load_weights()
81 | 
82 | model.train(num_epochs=train_args["epochs"])
83 | 
84 | if save_weights:
85 | model.save_weights()
86 | 
87 | 
88 | def get_additional_identifier(backbone: str, pretrained: bool = False, dataset_name: str = '',
89 | labels_percent: str = '100%') -> str:
90 | """
91 | Returns the additional_identifier added to the model name for efficient tracking of different experiments
92 | :param backbone: name of the backbone
93 | :param pretrained: Whether the backbone is pretrained
94 | :param dataset_name: Name of the training dataset
95 | :param labels_percent: % of labels used for training
96 | :return: additional identifier string
97 | """
98 | additional_identifier = backbone
99 | additional_identifier += '_pt' if pretrained else ''
100 | additional_identifier += '_' + labels_percent[:-1] if labels_percent and int(labels_percent[:-1]) < 100 else ''
101 | additional_identifier += '_ct' if dataset_name == 'cityscapes' else '_' + dataset_name
102 | return additional_identifier
103 | 
104 | 
105 | def _parse_args():
106 | """ parse command line arguments """
107 | parser = argparse.ArgumentParser()
108 | parser.add_argument("--save", default=False, action="store_true",
109 | help="If true, final weights will be stored in canonical, version-controlled location")
110 | parser.add_argument("--load", default=False, action="store_true",
111 | help="If true, final weights will be loaded from canonical, version-controlled location")
112 | parser.add_argument("experiment_config", type=str,
113 | help='Experiment 
JSON (\'{"dataset": "cityscapes", "model": "SegmentationModel",' 114 | ' "network": "fcn8s"}\'' 115 | ) 116 | args = parser.parse_args() 117 | return args 118 | 119 | 120 | def main(): 121 | """Run Experiment""" 122 | args = _parse_args() 123 | experiment_config = json.loads(args.experiment_config) 124 | run_experiment(experiment_config, args.load, args.save) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import numpy as np 3 | import matplotlib 4 | from config import defaults 5 | 6 | # if defaults.MATPLOTLIB_NO_GUI: 7 | # matplotlib.use('Agg') 8 | 9 | import matplotlib.pyplot as plt 10 | import math 11 | import torch 12 | from pathlib import Path 13 | 14 | from sem_seg.datasets.datautils import CityscapesUtils 15 | 16 | DIRNAME = Path(__file__).parents[1].resolve() 17 | OUTDIR = DIRNAME / "sem_seg" / "outputs" 18 | OUTDIR.mkdir(parents=True, exist_ok=True) 19 | 20 | cityscape_utils = CityscapesUtils() 21 | 22 | 23 | def sanitize_imgs(images: List) -> List: 24 | """Converts the list of torch image Tensors in (N, C, H, W) format to list of numpy images in format""" 25 | for i in range(len(images)): 26 | if torch.is_tensor(images[i]): 27 | images[i] = images[i].cpu().numpy().squeeze() 28 | if images[i].ndim > 3 and images[i].shape[-1] != 3: 29 | images[i] = np.moveaxis(images[i], 1, -1) 30 | return images 31 | 32 | 33 | def is_mask(img: np.array) -> bool: 34 | """ 35 | Checks if the image is a segmentation mask 36 | Criteria: If its a 2D image and has values in range of num_classes of the dataset 37 | :param img: Image 38 | :return: True if the image is a segmentation mask, False otherwise 39 | """ 40 | return img.ndim == 2 and 0 <= len(np.unique(img)) < cityscape_utils.num_classes + 1 41 | 42 | 43 | def plot_images(imgs, targets=None, preds=None, title='No title', num_cols=6) -> plt.Figure: 44 | """ 45 | Plot the images, targets and predictions in a grid. Shows 24 images at max in the plot. 46 | :param imgs: List of images to be plotted. 47 | :param targets: Corresponding List of ground truths. 48 | :param preds: Corresponding List of network predictions 49 | :param title: Title of the figure 50 | :param num_cols: Number of columns in the grid. 
Default is 6 51 | :return: output plot figure object 52 | """ 53 | inputs = [imgs] 54 | if targets is not None: 55 | inputs.append(targets) 56 | if preds is not None: 57 | inputs.append(preds) 58 | inputs = sanitize_imgs(inputs) 59 | num_types = len(inputs) 60 | num_rows, num_cols, plot_width, plot_height = get_plot_sizes(inputs, num_cols) 61 | fig, ax = plt.subplots(num_rows, num_cols, figsize=(plot_width, plot_height), num=title, squeeze=False, 62 | gridspec_kw={'wspace': 0.05, 'hspace': 0.05, 'left': 0, 'top': 0.95}) 63 | [axi.set_axis_off() for axi in ax.ravel()] 64 | 65 | for i in range(len(inputs[0])): 66 | r, c = (num_types * i) // num_cols, (num_types * i) % num_cols 67 | if r >= num_rows: 68 | break 69 | for j in range(num_types): 70 | img = inputs[j][i] 71 | img_to_show = cityscape_utils.label2color(img) if is_mask(img) else img 72 | ax[r][c + j].imshow(img_to_show) 73 | if title: 74 | fig.suptitle(title, fontsize=8) 75 | #if defaults.MATPLOTLIB_NO_GUI: 76 | # plt.savefig(str(OUTDIR) + '/plot2.png', bbox_inches='tight') 77 | #else: 78 | plt.show() 79 | return fig 80 | 81 | 82 | def get_plot_sizes(inputs: List, num_cols: int) -> Tuple: 83 | """ Given the input images, return the dimension of the plot i.e. num_rows, num_cols, Height, Width""" 84 | num_types, num_images = len(inputs), len(inputs[0]) 85 | num_rows = min(4, math.ceil((num_images * num_types) / num_cols)) 86 | if inputs[0].ndim == 4: 87 | N, C, H, W = inputs[0].shape 88 | elif inputs[0].ndim == 3: 89 | C, H, W = inputs[0].shape 90 | else: 91 | H, W = inputs[0].shape 92 | aspect_ratio = W / H 93 | return num_rows, num_cols, num_cols * aspect_ratio * 0.8, num_rows 94 | 95 | 96 | def save_results(data, filename): 97 | filename = str(OUTDIR) + '/' + filename + '.csv' 98 | np.savetxt(filename, data, fmt='%0.6f') 99 | print(f'saved at {filename}') 100 | --------------------------------------------------------------------------------
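For reference, a hypothetical two-step run that mirrors the experiment_config schemas documented in training/train_byol.py and training/train.py. In practice these JSON configs are passed on the command line by scripts/unsup_train.sh and scripts/train.sh; every value below is a placeholder.

```python
# Illustrative only: programmatic equivalent of the two-step training workflow.
from training import train, train_byol

# Step 1: BYOL pre-training of the ResNet-50 backbone on unlabelled Cityscapes images.
byol_config = {
    "dataset": "cityscapes",
    "model": "SelfSupervisedModel",
    "network": "BYOL",
    "mode": "self-supervised",
    "network_args": {"backbone": "Resnet50", "pretrained": False, "target_momentum": 0.996},
    "train_args": {"batch_size": 8, "epochs": 400, "log_to_tensorboard": False},
}
train_byol.run_experiment(byol_config, load_weights=False, save_weights=True)

# Step 2: fine-tune FCN8s on a fraction of the labelled data, starting from the BYOL backbone.
seg_config = {
    "dataset": "cityscapes",
    "model": "SegmentationModel",
    "network": "FCN8s",
    "network_args": {"backbone": "Resnet50",
                     "pretrained": False,
                     "load_from_BYOL": True,      # initialise the backbone from the saved BYOL weights
                     "freeze_backbone": False},
    "train_args": {"batch_size": 8,
                   "epochs": 400,
                   "labels_percent": "10%",       # fine-tune with 10% of the labelled images
                   "log_to_tensorboard": False},
}
train.run_experiment(seg_config, load_weights=False, save_weights=True)
```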