├── .gitignore
├── BYOL_gain.png
├── README.md
├── byol.png
├── byol_gain.png
├── config
│   └── defaults.py
├── environment.yml
├── scripts
│   ├── create_env.sh
│   ├── test.sh
│   ├── train.sh
│   └── unsup_train.sh
├── sem_seg
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── augmentations.py
│   │   ├── dataloader.py
│   │   ├── dataset.py
│   │   ├── datautils.py
│   │   └── transforms.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── eval_metrices.py
│   │   ├── segmentation_model.py
│   │   └── self_supervised_model.py
│   └── networks
│       ├── __init__.py
│       ├── byol.py
│       ├── fcn.py
│       ├── network_utils.py
│       └── resnet.py
└── training
    ├── __init__.py
    ├── debugger.py
    ├── train.py
    ├── train_byol.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | myvenv
2 | .idea
3 | __pycache__
4 | .DS_Store
--------------------------------------------------------------------------------
/BYOL_gain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/BYOL_gain.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-supervised Semantic Segmentation
2 | This repository contains semantic segmentation models trained in a
3 | self-supervised manner. In a two-step training process, the backbone
4 | of the segmentation network (e.g. an FCN8s network) is first trained
5 | in a self-supervised way, and then the whole network is fine-tuned
6 | with a small number of semantic segmentation annotations.
7 |
8 | 
9 |
10 | Read more about the BYOL architecture in this blog post: https://nilesh0109.medium.com/hands-on-review-byol-bootstrap-your-own-latent-67e4c5744e1b
11 | The baseline segmentation network is FCN8s; the available backbones are Resnet50, Resnet101, and DRN50.
12 |
13 | To reproduce the experiments, follow the steps below.
14 |
15 | # How To Run The Code
16 | 1. Create the virtual environment with all dependencies and activate it
17 |
18 | ``` bash scripts/create_env.sh```
19 |
20 | 2. Train the Self-supervised Model first using unlabelled data.
21 |
22 | ``` bash scripts/unsup_train.sh ```
23 |
24 | 3. Train the supervised model end-to-end using only labelled
25 | data.
26 |
27 | ``` bash scripts/train.sh```
28 |
29 | # Results:
30 | A Resnet50-based FCN8s semantic segmentation network gains 10% mIOU on the Cityscapes dataset when it is first pretrained in the BYOL manner on 20k unlabelled images and then fine-tuned with 5k labelled Cityscapes images.
31 |
32 | The x-axis in the graph below shows the percentage of labels (out of 5k images) used in the fine-tuning step. The top dotted line is Resnet50 initialized from ImageNet weights, and the bottom dotted line is Resnet50 trained from random weights (no pretraining). The middle plot is Resnet50 pretrained with BYOL. The red arrows indicate the gain from BYOL pretraining for visual representation learning.
33 |
34 | 
35 |
36 |
37 | # Code Walkthrough
38 | ```
39 | config
40 | - defaults.py: Configuration file for setting the defaults
41 | scripts
42 | - create_env.sh: script file for creating the virtual environment with dependencies listed in environment.yml
43 | - test.sh: script file for debugging the inputs & outputs to the FCN8s and BYOL network.
44 | - train.sh: script for training the segmentation network
45 | - unsup_train.sh: script for training the unsupervised learning (BYOL) network
46 | sem_seg
47 | datasets
48 | - augmentations.py: Utility file for various augmentations
49 | - dataloader.py: Script for building the Cityscapes dataloaders
50 | - dataset.py: Custom Cityscapes dataset that applies the same set of augmentations to images and masks
51 | - datautils.py: Script for Cityscapes label formatting
52 | - transforms.py: Custom transformation scripts
53 | models
54 | - base.py: Base class for all models with train and evaluation pipeline
55 | - eval_metrices.py: Script for different evaluation metrics such as mIOU, accuracy, etc.
56 | - segmentation_model.py: Class for the semantic segmentation model, inherited from the base model
57 | - self_supervised_model.py: Self-supervised-model(BYOL) training pipeline
58 | networks
59 | - byol.py: BYOL network setup
60 | - fcn.py: Various FCN (Fully Convolutional Network) setups
61 | - network_utils.py: Utility script for various network configurations
62 | - resnet.py: Various residual network (Resnet) setups
63 | training
64 | - debugger.py: Debugger script for FCN and BYOL model
65 | - train.py: Training file for FCN model
66 | - train_byol.py: Training file for BYOL model
67 | - utils.py: Utility file for plotting various inputs and outputs
68 | environment.yml: List of virtual environment dependencies
69 | ```
70 | # References
71 | 1. Grill, Jean-Bastien, et al. "Bootstrap your own latent: A new approach to self-supervised learning." arXiv preprint arXiv:2006.07733 (2020).
72 |
--------------------------------------------------------------------------------
/byol.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/byol.png
--------------------------------------------------------------------------------
/byol_gain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/byol_gain.png
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | CITYSCAPES_PATH = './cityscapes_data/'
2 | CITYSCAPES_IGNORED_INDEX = 19
3 |
4 | #Training Config
5 | NUM_EPOCHS = 15
6 | BATCH_SIZE = 8
7 | NUM_WORKERS = 0 #Parallel workers to prepare dataloader from dataset
8 |
9 | #For Matplotlib
10 | MATPLOTLIB_NO_GUI = True
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: myvenv
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - absl-py=0.11.0=pyhd3eb1b0_1
7 | - aiohttp=3.7.3=py37h9ed2024_1
8 | - async-timeout=3.0.1=py37_0
9 | - attrs=20.3.0=pyhd3eb1b0_0
10 | - blas=1.0=mkl
11 | - blinker=1.4=py37_0
12 | - brotlipy=0.7.0=py37h9ed2024_1003
13 | - c-ares=1.17.1=h9ed2024_0
14 | - ca-certificates=2020.12.8=hecd8cb5_0
15 | - cachetools=4.2.0=pyhd3eb1b0_0
16 | - certifi=2020.12.5=py37hecd8cb5_0
17 | - cffi=1.14.4=py37h2125817_0
18 | - chardet=3.0.4=py37hecd8cb5_1003
19 | - click=7.1.2=py_0
20 | - cryptography=2.9.2=py37hbcfaee0_0
21 | - cycler=0.10.0=py37_0
22 | - freetype=2.10.4=ha233b18_0
23 | - google-auth=1.24.0=pyhd3eb1b0_0
24 | - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2
25 | - grpcio=1.31.0=py37h7580e61_0
26 | - idna=2.10=py_0
27 | - importlib-metadata=2.0.0=py_1
28 | - intel-openmp=2019.4=233
29 | - jpeg=9b=he5867d9_2
30 | - kiwisolver=1.3.0=py37h23ab428_0
31 | - lcms2=2.11=h92f6f08_0
32 | - libcxx=10.0.0=1
33 | - libedit=3.1.20191231=h1de35cc_1
34 | - libffi=3.3=hb1e8313_2
35 | - libpng=1.6.37=ha441bb4_0
36 | - libprotobuf=3.13.0.1=hab81aa3_0
37 | - libtiff=4.1.0=hcb84e12_1
38 | - libuv=1.40.0=haf1e3a3_0
39 | - lz4-c=1.9.2=h79c402e_3
40 | - markdown=3.3.3=py37hecd8cb5_0
41 | - matplotlib=3.3.2=hecd8cb5_0
42 | - matplotlib-base=3.3.2=py37h181983e_0
43 | - mkl=2019.4=233
44 | - mkl-service=2.3.0=py37h9ed2024_0
45 | - mkl_fft=1.2.0=py37hc64f4ea_0
46 | - mkl_random=1.1.1=py37h959d312_0
47 | - multidict=4.7.6=py37haf1e3a3_1
48 | - ncurses=6.2=h0a44026_1
49 | - ninja=1.10.2=py37hf7b0b51_0
50 | - numpy=1.19.2=py37h456fd55_0
51 | - numpy-base=1.19.2=py37hcfb5961_0
52 | - oauthlib=3.1.0=py_0
53 | - olefile=0.46=py37_0
54 | - openssl=1.1.1i=h9ed2024_0
55 | - pillow=8.0.1=py37h5270095_0
56 | - pip=20.3.3=py37hecd8cb5_0
57 | - protobuf=3.13.0.1=py37hb1e8313_1
58 | - pyasn1=0.4.8=py_0
59 | - pyasn1-modules=0.2.8=py_0
60 | - pycparser=2.20=py_2
61 | - pyjwt=2.0.0=py37hecd8cb5_0
62 | - pyopenssl=20.0.1=pyhd3eb1b0_1
63 | - pyparsing=2.4.7=py_0
64 | - pysocks=1.7.1=py37hecd8cb5_0
65 | - python=3.7.9=h26836e1_0
66 | - python-dateutil=2.8.1=py_0
67 | - pytorch=1.7.1=py3.7_0
68 | - readline=8.0=h1de35cc_0
69 | - requests=2.25.1=pyhd3eb1b0_0
70 | - requests-oauthlib=1.3.0=py_0
71 | - rsa=4.6=py_0
72 | - setuptools=51.0.0=py37hecd8cb5_2
73 | - six=1.15.0=py37hecd8cb5_0
74 | - sqlite=3.33.0=hffcf06c_0
75 | - tensorboard=2.3.0=pyh4dce500_0
76 | - tensorboard-plugin-wit=1.6.0=py_0
77 | - tk=8.6.10=hb0a8c7a_0
78 | - torchvision=0.8.2=py37_cpu
79 | - tornado=6.1=py37h9ed2024_0
80 | - tqdm=4.54.1=pyhd3eb1b0_0
81 | - typing-extensions=3.7.4.3=0
82 | - typing_extensions=3.7.4.3=py_0
83 | - urllib3=1.26.2=pyhd3eb1b0_0
84 | - werkzeug=1.0.1=py_0
85 | - wheel=0.36.2=pyhd3eb1b0_0
86 | - xz=5.2.5=h1de35cc_0
87 | - yarl=1.5.1=py37haf1e3a3_0
88 | - zipp=3.4.0=pyhd3eb1b0_0
89 | - zlib=1.2.11=h1de35cc_3
90 | - zstd=1.4.5=h41d2c2f_0
91 | prefix: /Users/nileshvijayrania/anaconda3/envs/myvenv
92 |
--------------------------------------------------------------------------------
/scripts/create_env.sh:
--------------------------------------------------------------------------------
1 | conda env create -n bs_venv -f environment.yml
2 | conda activate bs_venv
3 |
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | python -m training.debugger
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #For supervised Training
2 | config='{"dataset": "cityscapes", "model": "SegmentationModel", "network": "FCN8s",
3 | "network_args":{"backbone":"Resnet50", "pretrained": false, "load_from_byol": false, "freeze_backbone": false},
4 | "train_args":{"batch_size": 8, "epochs": 400, "log_to_tensorboard": false},
5 | "experiment_group":{}
6 | }'
7 |
8 | python -m training.train --save "$config"
--------------------------------------------------------------------------------
/scripts/unsup_train.sh:
--------------------------------------------------------------------------------
1 | # for unsupervised Training
2 |
3 | config='{"dataset": "cityscapes", "model": "SelfSupervisedModel", "network": "BYOL", "mode": "self-supervised",
4 | "network_args":{"backbone":"Resnet50", "pretrained": false, "target_momentum": 0.996},
5 | "train_args":{"batch_size": 32, "epochs": 700, "log_to_tensorboard": false},
6 | "experiment_group":{}
7 | }'
8 |
9 | python -m training.train_byol --save "$config"
10 |
11 | #Use for distributed Training
12 | #python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 -m training.train_byol --save "$config"
--------------------------------------------------------------------------------
/sem_seg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/sem_seg/__init__.py
--------------------------------------------------------------------------------
/sem_seg/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['augmentations', 'datautils', 'dataloader', 'dataset', 'transforms']
2 |
--------------------------------------------------------------------------------
/sem_seg/datasets/augmentations.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | import numpy as np
3 | import torch
4 | from torchvision import transforms as tfms
5 | from .datautils import CityscapesUtils
6 | from .transforms import RandomResize
7 | from PIL import Image
8 |
9 |
10 | def get_tfms() -> Tuple:
11 | """
12 | :return: Image and mask transforms tuple for cityscapes fine annotated images
13 | """
14 | base_size = 800
15 | min_size, max_size = int(0.5*base_size), int(2.0*base_size)
16 | img_transforms = {
17 | 'train': tfms.Compose([
18 | RandomResize(min_size, max_size),
19 | tfms.RandomCrop(size=768, pad_if_needed=True, fill=0),
20 | tfms.RandomHorizontalFlip(p=0.5),
21 | tfms.ToTensor()
22 | ]),
23 | 'val': tfms.Compose([
24 | tfms.ToTensor()
25 | ])
26 | }
27 |
28 | target_transforms = {
29 | 'train': tfms.Compose([
30 | RandomResize(min_size, max_size, is_mask=True),
31 | tfms.RandomCrop(size=768, pad_if_needed=True, fill=0),
32 | tfms.RandomHorizontalFlip(p=0.5),
33 | mapId2TrainID
34 | ]),
35 | 'val': tfms.Compose([mapId2TrainID])
36 | }
37 | return img_transforms, target_transforms
38 |
39 | cityscapes_utils = CityscapesUtils()
40 |
41 | def mapId2TrainID(mask: Image) -> torch.Tensor:
42 | """
43 | Reduces the 34 label ids present in gtFine to the 19 train labels; the ignore_in_eval labels are mapped to the ignore index.
44 | :param mask: Cityscapes mask(PIL Image type) with 34 classes
45 | :return: mask with 19 cityscapes classes of tensor type
46 | """
47 | return torch.from_numpy(cityscapes_utils.id2train_id[np.array(mask)]).long()
48 |
49 |
50 | def get_spatial_tfms() -> tfms.Compose:
51 | """ Get the transformation which changes the spatial position of the object in the image"""
52 | return tfms.Compose([
53 | tfms.RandomResizedCrop(size=512, scale=(0.3, 1)),
54 | tfms.RandomHorizontalFlip()
55 | ])
56 |
57 |
58 | def get_pixelwise_tfms() -> List[tfms.Compose]:
59 | """Get the transformations which only introduces local perturbation to the image."""
60 | tfms_list = [
61 | tfms.RandomApply([
62 | tfms.ColorJitter(0.4, 0.4, 0.4, 0.1)
63 | ], p=0.8),
64 | tfms.RandomGrayscale(p=0.2)
65 | ]
66 | return tfms_list
67 |
68 |
69 | def get_self_supervised_tfms() -> tfms.Compose:
70 | """ Returns the list of transformations to be used to generating self-supervised training image pairs"""
71 | spatial_tfms = get_spatial_tfms()
72 | pixelwise_tfms = get_pixelwise_tfms()
73 |
74 | self_supervised_tfms = tfms.Compose([
75 | *pixelwise_tfms,
76 | tfms.ToTensor(),
77 | spatial_tfms,
78 | tfms.Lambda(lambda img: img.squeeze())
79 | ])
80 | return self_supervised_tfms
81 |
--------------------------------------------------------------------------------
/sem_seg/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import os
3 | import torch
4 | from torchvision import datasets
5 | from torch.utils.data.sampler import SequentialSampler, RandomSampler
6 | from config.defaults import CITYSCAPES_PATH, BATCH_SIZE, NUM_WORKERS
7 | from .dataset import CityscapesDataset
8 | from .augmentations import get_tfms, get_self_supervised_tfms
9 |
10 |
11 | class CityscapesLoader:
12 | """Prepares the cityscapes DataLoader"""
13 |
14 | def __init__(self, label_percent: str = '100%') -> None:
15 | """
16 | :param label_percent: Percentage of cityscapes labels to be used. Must be suffixed with %
17 | """
18 | label_percent = min(float(label_percent[:-1]), 100.0)
19 | self._prepare_dataset(label_percent)
20 |
21 | def _prepare_dataset(self, label_percent: int = 100) -> None:
22 | """
23 | Constructs the train/val datasets (supervised and self-supervised) using the CityscapesDataset class
24 | :param label_percent: percentage of cityscapes labels to use
25 | :return: None
26 | """
27 |
28 | data_fine = {phase: datasets.Cityscapes(CITYSCAPES_PATH, split=phase, mode='fine',
29 | target_type='semantic') for phase in ['train', 'test', 'val']}
30 | if 'gtCoarse' in os.listdir(CITYSCAPES_PATH):
31 | data_coarse = {phase: datasets.Cityscapes(CITYSCAPES_PATH, split=phase, mode='coarse',
32 | target_type='semantic') for phase in ['train_extra', 'val']}
33 | else:
34 | data_coarse = {
35 | 'train_extra': [],
36 | 'val': data_fine['val']
37 | }
38 |
39 | img_tfms, target_tfms = get_tfms()
40 | self_supervised_tfms = get_self_supervised_tfms()
41 |
42 | self.cityscapes = {phase: CityscapesDataset(data_fine[phase],
43 | label_percent=label_percent if phase == 'train' else 100,
44 | transform=img_tfms[phase],
45 | target_transform=target_tfms[phase])
46 | for phase in ['train', 'val']}
47 |
48 | #self.cityscapes['train'].imgs_path = self.cityscapes['train'].imgs_path[:30]
49 | #self.cityscapes['val'].imgs_path = self.cityscapes['val'].imgs_path[:30]
50 |
51 | self.cityscapes['self-supervised'] = {
52 | phase: CityscapesDataset([data_fine['train'], data_fine['test']]
53 | if phase == 'train' else data_coarse['val'],
54 | label_percent=label_percent if phase == 'train' else 100,
55 | transform=self_supervised_tfms,
56 | mode='self-supervised')
57 | for phase in ['train', 'val']
58 | }
59 |
60 | def get_cityscapes_loader(self, batch_size: int = BATCH_SIZE, num_workers: int = NUM_WORKERS,
61 | mode: str = 'supervised') -> Dict[str, torch.utils.data.DataLoader]:
62 | """
63 | Returns Cityscapes dataloader with correct sampler
64 | :param batch_size: number of samples per batch
65 | :param num_workers: number of parallel workers for dataloading on CPU.
66 | :param mode: 'supervised' or 'self-supervised' type of dataloader
67 | :return: Dict of train and val data loader
68 | """
69 |
70 | data = self.cityscapes if mode == 'supervised' else self.cityscapes['self-supervised']
71 |
72 | cityscapes_loader = {x: torch.utils.data.DataLoader(data[x], batch_size=batch_size,
73 | sampler=RandomSampler(data[x]) if x == 'train'
74 | else SequentialSampler(data[x]),
75 | drop_last=bool(mode == 'supervised' and x == 'train'),
76 | num_workers=num_workers,
77 | pin_memory=True)
78 | for x in ['train', 'val']}
79 | return cityscapes_loader
80 |
--------------------------------------------------------------------------------
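The following is a minimal usage sketch (not part of the repository) showing how this loader is typically consumed; it assumes the Cityscapes archives are extracted under `config.defaults.CITYSCAPES_PATH` and the repository root is on `PYTHONPATH`:

```python
# Hypothetical usage: build supervised loaders from 10% of the fine labels and inspect one batch.
from sem_seg.datasets.dataloader import CityscapesLoader

loader = CityscapesLoader(label_percent='10%')    # keep 10% of the labelled training images
dataloaders = loader.get_cityscapes_loader(batch_size=4, mode='supervised')

images, masks = next(iter(dataloaders['train']))
print(images.shape, masks.shape)   # e.g. [4, 3, 768, 768] image tensor and [4, 768, 768] label map
```

In `'self-supervised'` mode the same call yields `((image, augmented image, seed), seed)` batches instead of image/mask pairs.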
/sem_seg/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Union, List, Tuple
2 | import random
3 | import numpy as np
4 | import functools
5 | from PIL import Image
6 |
7 | import torch
8 | import torchvision
9 | from torch.utils.data import Dataset
10 |
11 | from .augmentations import get_pixelwise_tfms
12 |
13 |
14 | class CityscapesDataset(Dataset):
15 | """Custom cityscapes dataset to apply same transformation on image and mask pair"""
16 |
17 | def __init__(self, cityscapes_data: Union[Dataset, List], label_percent: int = 100, transform: Callable=None,
18 | target_transform: Callable = None, is_test: bool = False, mode: str = 'supervised'):
19 | """
20 | :param cityscapes_data: torch loaded Cityscapes dataset from path defaults.CITYSCAPES_PATH
21 | :param label_percent: % of labels to use
22 | :param transform: List of Transformations to be applied on the input Image
23 | :param target_transform: List of transformations to be applied on the segmentaion mask
24 | :param is_test: is Test/Val dataset
25 | :param mode: type of dataset: 'supervised' | 'self-supervised'. 'self-supervised' mode will generate an
26 | (image, augmented image) pair and 'supervised' mode will generate an (image, mask) pair
27 | """
28 |
29 | inputs = functools.reduce(lambda a, b: a + b.images, cityscapes_data, []) \
30 | if type(cityscapes_data) == list else cityscapes_data.images
31 | total = len(inputs)
32 | n = int(total * label_percent) // 100
33 | self.imgs_path = inputs[:n]
34 | self.masks_path = None if is_test or mode.lower() == 'self-supervised' else cityscapes_data.targets
35 | self.transform = transform
36 | self.target_transform = target_transform
37 | self.isTestSet = is_test
38 | self.mode = mode.lower()
39 | assert self.mode in ['supervised', 'self-supervised'], "Invalid dataset mode. Only 'supervised', " \
40 | "'self-supervised' is allowed"
41 |
42 | def __len__(self):
43 | return len(self.imgs_path)
44 |
45 | def get_random_crop(self, image: Image, crop_size: int) -> Image:
46 | """
47 | :param image: PIL Image
48 | :param crop_size: Size of the crop
49 | :return: PIL Image crop of size crop_size
50 | """
51 | crop_tfms = torchvision.transforms.RandomCrop(crop_size)
52 | return crop_tfms(image)
53 |
54 | def _apply_transformations(self, image: Image, mask: Image = None) -> Tuple:
55 | """
56 |
57 | :param image: PIL Image
58 | :param mask: PIL segmentation Mask
59 | :return: Transformed Image and Mask pair after applying exactly the same random transformations on image and
60 | mask
61 | """
62 | seed = np.random.randint(2147483647) # sample a random seed
63 | random.seed(seed) # set this seed to random and numpy.random function
64 | np.random.seed(seed)
65 | torch.manual_seed(seed) # Needed for torchvision 0.7
66 | if self.transform is not None:
67 | image = self.transform(image)
68 |
69 | if mask is not None and self.target_transform is not None:
70 | random.seed(seed) # set the same seed to random and numpy.random function
71 | np.random.seed(seed)
72 | torch.manual_seed(seed) # Needed for torchvision 0.7
73 | mask = self.target_transform(mask)
74 | return image, mask, seed
75 |
76 | def __getitem__(self, idx):
77 | if torch.is_tensor(idx):
78 | idx = idx.tolist()
79 | image_path = self.imgs_path[idx]
80 | image = Image.open(image_path)
81 | if image.mode == 'P':
82 | image = image.convert('RGB')
83 |
84 | if self.mode == 'self-supervised':
85 | crop = self.get_random_crop(image, 512)
86 | tf_image1, _, seed = self._apply_transformations(crop)
87 | crop_tensor = torchvision.transforms.Compose([
88 | *get_pixelwise_tfms(),
89 | torchvision.transforms.ToTensor()
90 | ])(crop)
91 | return (crop_tensor, tf_image1, seed), seed
92 | else:
93 | if not self.isTestSet:
94 | mask_path = image_path.replace('leftImg8bit', 'gtFine').replace('.png', '_labelIds.png')
95 | mask = Image.open(mask_path)
96 | image, mask, seed = self._apply_transformations(image, mask)
97 | return image if self.isTestSet else (image, mask)
98 |
99 |
--------------------------------------------------------------------------------
/sem_seg/datasets/datautils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import datasets
3 | from config.defaults import CITYSCAPES_PATH
4 |
5 |
6 | class CityscapesUtils:
7 | """CITYSCAPES Utility class that provides the mapping for training labels and their colors for visualization"""
8 | def __init__(self):
9 | cityscapes_data = datasets.Cityscapes(CITYSCAPES_PATH, split='train', mode='fine', target_type='semantic')
10 | self.classes = cityscapes_data.classes
11 | self.num_classes = self._num_classes()
12 | self.train_id2color = self._train_id2color()
13 | self.id2train_id = self._id2train_id()
14 |
15 | def _num_classes(self) -> int:
16 | """
17 | :return: returns the effective number of classes in cityscapes that are used in validation
18 | """
19 | train_labels = [label.id for label in self.classes if not label.ignore_in_eval]
20 | return len(train_labels)
21 |
22 | def _id2train_id(self) -> np.array:
23 | """
24 | :return: returns an array where each label id is mapped to its train id. All ignore_in_eval ids are mapped to 19,
25 | i.e. the ignore index.
26 | """
27 | train_ids = np.array([label.train_id for label in self.classes])
28 | train_ids[(train_ids == -1) | (train_ids == 255)] = 19 # 19 is the ignore index (defaults.CITYSCAPES_IGNORED_INDEX)
29 | return train_ids
30 |
31 | def _train_id2color(self) -> np.array:
32 | """
33 | :return: The mapping of 20 classes (19 training classes + 1 ignore index class) to their standard color used
34 | in cityscapes.
35 | """
36 | return np.array([label.color for label in self.classes if not label.ignore_in_eval] + [(0, 0, 0)])
37 |
38 | def label2color(self, mask: np.array) -> np.array:
39 | """
40 | Given the cityscapes mask with all training id(and optionally 255 for ignored labels) as labels, returns the mask
41 | filled with the label's standard color.
42 | :param mask: np.array mask for which color mapping is required
43 | :return: mask with labels replaced with their standard colors
44 | """
45 | return self.train_id2color[mask]
46 |
--------------------------------------------------------------------------------
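As a quick illustration, the utility can colorize a predicted label map for visualization. This is a sketch (not from the repository) and assumes the Cityscapes data is present at `CITYSCAPES_PATH`, since the constructor loads `datasets.Cityscapes` to read the class definitions:

```python
import numpy as np
from sem_seg.datasets.datautils import CityscapesUtils

cs_utils = CityscapesUtils()
pred = np.random.randint(0, cs_utils.num_classes, size=(256, 512))  # fake [H, W] prediction with train ids
color_mask = cs_utils.label2color(pred)   # [H, W, 3] array of the standard Cityscapes colors
print(color_mask.shape)                   # (256, 512, 3)
```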
/sem_seg/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | from torchvision.transforms import functional as F
3 | from PIL import Image
4 |
5 |
6 | class RandomResize(object):
7 | """Resize the object randomly between the min and max size"""
8 | def __init__(self, min_size: int, max_size: int, is_mask: bool = False):
9 | """
10 | :param min_size: min desired size of the image
11 | :param max_size: max desired size of the image
12 | :param is_mask: If the image is the segmentation mask
13 | """
14 | self.min_size = min_size
15 | self.max_size = max_size
16 | self.is_mask = is_mask
17 |
18 | def __call__(self, img: Image) -> Image:
19 | size = random.randint(self.min_size, self.max_size)
20 | if self.is_mask:
21 | return F.resize(img, size=size, interpolation=Image.NEAREST)
22 | else:
23 | return F.resize(img, size=size)
24 |
--------------------------------------------------------------------------------
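Below is a small sketch (illustrative sizes, not from the repository) of how `RandomResize` is applied in lock-step to an image and its mask, using the same seeding trick that `CityscapesDataset._apply_transformations` relies on:

```python
import random
from PIL import Image
from sem_seg.datasets.transforms import RandomResize

img = Image.new('RGB', (2048, 1024))    # dummy image at Cityscapes resolution
mask = Image.new('L', (2048, 1024))     # dummy single-channel label mask

resize_img = RandomResize(400, 1600)                 # bilinear resize for the image
resize_mask = RandomResize(400, 1600, is_mask=True)  # nearest-neighbour resize for the mask

seed = 42
random.seed(seed); img_r = resize_img(img)    # same seed -> same randomly drawn size
random.seed(seed); mask_r = resize_mask(mask)
assert img_r.size == mask_r.size
```

Using nearest-neighbour interpolation for the mask keeps the label ids intact instead of blending them.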
/sem_seg/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .segmentation_model import SegmentationModel
2 | from .self_supervised_model import SelfSupervisedModel
3 |
4 | __all__ = ['SegmentationModel', 'base', 'SelfSupervisedModel']
--------------------------------------------------------------------------------
/sem_seg/models/base.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Callable, Dict
3 | from tqdm import tqdm
4 | import time
5 | import copy
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | from . import eval_metrices
10 | from training import utils
11 |
12 | DIRNAME = Path(__file__).parents[1].resolve()
13 | WEIGHTSDIR = DIRNAME / "weights"
14 | LOGSDIR = DIRNAME / "tensorboard_logs"
15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |
17 |
18 | class Model:
19 | """ Base Class for training and evaluation of any network """
20 |
21 | def __init__(self, network: torch.nn.Module,
22 | dataloader: torch.utils.data.DataLoader,
23 | optimizer: torch.optim,
24 | criterion: Callable,
25 | lr_scheduler: torch.optim.lr_scheduler,
26 | additional_identifier: str = ''):
27 | self.network = network
28 | self.dataloader = dataloader
29 | self.optim = optimizer
30 | self.criterion = criterion
31 | self.lr_scheduler = lr_scheduler
32 | self.name = f"{self.__class__.__name__}_{self.network.__class__.__name__}"
33 | if additional_identifier:
34 | self.name += '_' + additional_identifier
35 | self.metrices = ['accuracy']
36 | self.logToTensorboard = True
37 | self.writer = None
38 |
39 | @property
40 | def weights_file_name(self) -> str:
41 | """
42 | :return: returns the path where to store/load the weights of the model.
43 | """
44 | WEIGHTSDIR.mkdir(parents=True, exist_ok=True)
45 | return str(WEIGHTSDIR / f"{self.name}.h5")
46 |
47 | def train(self, num_epochs: int) -> None:
48 | """
49 | 1. set up the model in appropriate training mode(Use this to setup DDP model for distributed training also if
50 | required)
51 | 2. Train the model for num_epochs and keep track of the best_model_weights based on self.metric[0] values.
52 | 3. Log the epoch results( loss and other stats mentioned in self.metrices) on tensorboard
53 | :param num_epochs: Number of training epochs
54 | :return: None
55 | """
56 | since = time.time()
57 | losses = {'train': [], 'val': []}
58 | stats = {'train': [], 'val': []}
59 | lr_rates = []
60 | best_epoch, self.best_stat = None, float('-inf')
61 | best_model_weights = None
62 | self.network.to(device)
63 |
64 | for epoch in tqdm(range(num_epochs)):
65 | for phase in ['train', 'val']:
66 | if phase == 'train':
67 | self.network.train()
68 | else:
69 | self.network.eval()
70 | running_loss, running_stats = 0.0, [0.0] * len(self.metrices)
71 |
72 | for inputs, outputs in self.dataloader[phase]:
73 | inputs, outputs = inputs.to(device), outputs.to(device)
74 | self.optim.zero_grad()
75 | with torch.set_grad_enabled(phase == 'train'):
76 | network_outputs = self.network(inputs)
77 | loss = self.criterion(network_outputs, outputs)
78 | batch_stats = self.evaluate_metrices(network_outputs, outputs)
79 |
80 | if phase == 'train':
81 | loss.backward()
82 | self.optim.step()
83 |
84 | n = inputs.size(0)
85 | running_loss += loss.item() * n
86 | running_stats = self.update_running_stats(batch_stats, running_stats, n)
87 | N = len(self.dataloader[phase].dataset)
88 | epoch_loss = running_loss / N
89 | epoch_stats = self.get_epoch_stats(running_stats, N)
90 |
91 | if phase == 'val' and epoch_stats[0] > self.best_stat:
92 | self.best_stat = epoch_stats[0]
93 | best_epoch = epoch
94 | best_model_weights = copy.deepcopy(self.network.state_dict())
95 | losses[phase].append(epoch_loss)
96 | stats[phase].append(epoch_stats)
97 |
98 | self.log_stats(epoch, losses, stats, [inputs, network_outputs, outputs], lr_rates)
99 |
100 | if self.lr_scheduler is not None:
101 | self.lr_scheduler.step(epoch_stats[0]) # considering ReduceLrOnPlateau and watch val mIOU
102 | lr_step = self.optim.state_dict()["param_groups"][0]["lr"]
103 | lr_rates.append(lr_step)
104 | min_lr = self.lr_scheduler.min_lrs[0]
105 | if lr_step / min_lr <= 10:
106 | print("Min LR reached")
107 | break
108 |
109 | time_elapsed = time.time() - since
110 | print('Training Completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
111 | print('Best {}: {:.6f} and best epoch is {}'.format(self.metrices[0], self.best_stat, best_epoch + 1))
112 | self.network.load_state_dict(best_model_weights)
113 |
114 | def evaluate_metrices(self, outs: torch.Tensor, targets: torch.Tensor) -> List:
115 | """
116 | :param outs: output logits from network for a given batch of inputs
117 | :param targets: ground truth targets corresponding to the batch of inputs
118 | :return: calculated metric from the utility function
119 | """
120 | fn_mapping = {
121 | 'pixel_accuracy': eval_metrices.pixel_accuracy,
122 | 'mIOU': eval_metrices.get_confusion_matrix,
123 | 'accuracy': eval_metrices.accuracy
124 | }
125 | scores = []
126 | for mat in self.metrices:
127 | scores.append(fn_mapping[mat](outs, targets))
128 | return scores
129 |
130 | def update_running_stats(self, batch_stats: List, running_stats: List, num_inputs_in_batch: int) -> List:
131 | """
132 | :param batch_stats: List of batch stats(in the same order as self.metrices)
133 | :param running_stats: List of running stats from the current epoch so far(in the same order as self.metrices)
134 | :param num_inputs_in_batch: Number of inputs in current batch
135 | :return: Aggregated stats for the current epoch. MIOU requires the bins_matrix(CxC) to be summed up.
136 | """
137 | running_stats = [r_stat + b_stat * num_inputs_in_batch if 'miou' not in key.lower() else r_stat + b_stat
138 | for key, r_stat, b_stat in zip(self.metrices, running_stats, batch_stats)]
139 | return running_stats
140 |
141 | def get_epoch_stats(self, running_stats: List, total_inputs: int) -> List:
142 | epoch_stats = [stat / total_inputs if 'miou' not in key.lower() else eval_metrices.mIOU(stat)
143 | for key, stat in zip(self.metrices, running_stats)]
144 | return epoch_stats
145 |
146 | def log_stats(self, epoch: int, losses: List[Dict], stats: List[Dict] = [], data: List = [], lr_rates: List = []):
147 | """ Format epoch stats and log them to console and to tensorboard if enabled"""
148 | if epoch == 0:
149 | metrices_title = '\t'.join(
150 | [(phase + '_' + mat)[:12] for mat in self.metrices for phase in ['train', 'val']])
151 | tqdm.write('Epoch\tTrain_Loss\tVal_Loss\t' + metrices_title)
152 | tqdm.write(
153 | '-----------------------------------------------------------------------------------------------------')
154 | metrices_stats = ' \t '.join([format(stats[phase][epoch][i], '.6f') for i in range(len(self.metrices))
155 | for phase in ['train', 'val']])
156 | log_stats = ' \t '.join([format(losses[phase][epoch], '.6f') for phase in ['train', 'val']])
157 | tqdm.write('{:4d} \t'.format(epoch + 1) + log_stats + '\t' + metrices_stats)
158 | if self.logToTensorboard:
159 | self.writeToTensorboard(epoch, losses, stats, data, lr_rates)
160 |
161 | def writeToTensorboard(self, epoch: int, losses: List[Dict], stats: List[Dict] = [], data: List = [],
162 | lr_rates: List = []):
163 | """Logs epoch stats to tensorboard and output predictions to tensorboard after every 25 epoch"""
164 | if self.writer is None:
165 | self.setup_tensorboard()
166 | self.writer.add_scalar('Loss/train', losses['train'][epoch], epoch + 1)
167 | self.writer.add_scalar('Loss/val', losses['val'][epoch], epoch + 1)
168 | num_batches = len(lr_rates)
169 | for i, lr in enumerate(lr_rates):
170 | self.writer.add_scalar('Loss/schedule_lr', lr, epoch * num_batches + i + 1)
171 | if epoch % 25 != 0:
172 | return
173 | for i, metric in enumerate(self.metrices):
174 | self.writer.add_scalar(metric + '/train', stats['train'][epoch], epoch + 1)
175 | self.writer.add_scalar(metric + '/val', stats['val'][epoch], epoch + 1)
176 | if data:
177 | imgs, network_outputs, targets = data
178 | _, preds = torch.max(network_outputs, dim=1)
179 | val_results = utils.plot_images(imgs[:4], targets[:4], preds[:4],
180 | title='val_mIOU' + str(stats['val'][epoch][0]),
181 | num_cols=3)
182 | self.writer.add_figure('val_epoch' + str(epoch), val_results)
183 |
184 | def add_text_to_tensorboard(self, text: str) -> None:
185 | """ Log text to tensorboard. Used for logging experiment configuration"""
186 | if self.writer is None:
187 | self.setup_tensorboard()
188 | self.writer.add_text("model_details", "Experiment Config " + text)
189 |
190 | def setup_tensorboard(self):
191 | """ setup the tensorboard logging dir"""
192 | logsdir = LOGSDIR / f"{self.name}"
193 | logsdir.mkdir(parents=True, exist_ok=True)
194 | self.writer = SummaryWriter(logsdir)
195 |
196 | def get_logits(self, x: torch.Tensor) -> torch.Tensor:
197 | """
198 | :param x: torch.Tensor [N, C, H, W] where C is number of channels in inputs
199 | :return: logit Tensor of size [N, K, H, W] where K is number of output classes
200 | """
201 | self.network.eval()
202 | if not x.is_cuda:
203 | x = x.to(device)
204 | self.network.to(device)
205 | outs = self.network(x)
206 | return outs
207 |
208 | def evaluate(self, x: torch.Tensor) -> torch.Tensor:
209 | """
210 | :param x: Input Tensor [N, C, H, W]
211 | :return: Predict Tensor [N, H, W]
212 | """
213 | outs = self.get_logits(x)
214 | _, preds = torch.max(outs, dim=1)
215 | return preds
216 |
217 | def load_weights(self, filename=None) -> None:
218 | """load model weights from passed filename or self.weights_file_name"""
219 | if filename is None:
220 | filename = self.weights_file_name
221 | self.network.load_state_dict(torch.load(filename))
222 |
223 | def save_weights(self) -> None:
224 | """ Save current model state dict int self.weights_file_name """
225 | torch.save(self.network.state_dict(), self.weights_file_name)
226 |
--------------------------------------------------------------------------------
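For orientation, here is a sketch (assuming a prepared Cityscapes dataloader dict and the repository root on `PYTHONPATH`) of how a concrete model wires up this base `Model` and runs training; the optimizer and scheduler settings below are illustrative choices, not the repository's fixed configuration:

```python
import torch
from sem_seg.datasets.dataloader import CityscapesLoader
from sem_seg.networks import FCN8s, Resnet50
from sem_seg.models import SegmentationModel

dataloaders = CityscapesLoader('100%').get_cityscapes_loader(batch_size=8)

backbone = Resnet50(pretrained=False)
network = FCN8s(backbone, num_classes=20)          # 19 train classes + 1 ignore class
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', min_lr=1e-6)

model = SegmentationModel(network, dataloaders, optimizer, lr_scheduler=scheduler)
model.train(num_epochs=50)    # tracks the best weights by validation mIOU (metrices[0])
model.save_weights()
```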
/sem_seg/models/eval_metrices.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from config import defaults
5 |
6 |
7 | def get_confusion_matrix(output: torch.Tensor, target: torch.Tensor) -> np.ndarray:
8 | """
9 | :param output: output logits
10 | :param target: targets
11 | :return: returns the confusion matrix of shape CxC between pred and targets
12 | """
13 |
14 | n_classes = output.shape[1]
15 | softmax_out = F.softmax(output, dim=1)
16 | pred = torch.argmax(softmax_out, dim=1).squeeze(1).cpu().numpy()
17 | target = target.cpu().numpy()
18 | hist = fast_hist(pred.flatten(), target.flatten(), n_classes)
19 | return hist
20 |
21 |
22 | def mIOU(c_matrix: np.ndarray) -> float:
23 | """
24 | Calculates the mIOU for a given confusion matrix
25 | :param c_matrix: CxC confusion matrix
26 | :return: effective mIOU (mean IOU over classes)
27 | """
28 | if type(c_matrix) != np.ndarray:
29 | return 0
30 | class_iu = per_class_iu(c_matrix)
31 | m_iou = np.nanmean(class_iu) # ignoring Nans
32 | return m_iou
33 |
34 |
35 | def fast_hist(pred: np.ndarray, label: np.ndarray, n: int) -> np.ndarray:
36 | """
37 | :param pred: predicted label map (after argmax) of shape [N, H, W]
38 | :param label: ground-truth label map of shape [N, H, W]
39 | :param n: num classes
40 | :return: a matrix of shape CxC where row i holds the count of each predicted class for pixels whose actual
41 | class is i
42 | """
43 | k = (label >= 0) & (label < n)
44 | return np.bincount(n * label[k].astype(int) + pred[k], minlength=n**2).reshape(n, n)
45 |
46 |
47 | def per_class_iu(hist: np.ndarray) -> np.array:
48 | """
49 | :param hist: bin count matrix of size CxC where C is the number of output classes.
50 | :return: list of IOU for each class
51 | """
52 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
53 |
54 |
55 | def pixel_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
56 | """
57 | :param output: output logits
58 | :param target: targets
59 | :return: pixelwise accuracy for semantic segmentation
60 | """
61 | softmax_out = F.softmax(output, dim=1)
62 | pred = torch.argmax(softmax_out, dim=1).squeeze(1)
63 | pred = pred.view(1, -1)
64 | target = target.view(1, -1)
65 | correct = pred.eq(target)
66 | correct = correct[target != defaults.CITYSCAPES_IGNORED_INDEX]
67 | correct = correct.view(-1)
68 | score = correct.float().sum(0) / correct.size(0)
69 | return score.item()
70 |
71 |
72 | def accuracy(outputs: torch.Tensor, targets: torch.Tensor) -> float:
73 | """
74 | :param outputs: output logits
75 | :param targets: target labels
76 | :return: accuracy score for the batch
77 | """
78 | softmax_out = F.softmax(outputs, dim=1)
79 | preds = torch.argmax(softmax_out, dim=1)
80 | correct = preds.eq(targets)
81 | score = correct.float().sum(0) / correct.size(0)
82 | return score.item()
83 |
--------------------------------------------------------------------------------
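A self-contained sketch of how these metrics compose (random tensors, no dataset required; assumes the repository root is importable):

```python
import torch
from sem_seg.models import eval_metrices

logits = torch.randn(2, 20, 64, 64)            # fake network output: N x C x H x W
targets = torch.randint(0, 19, (2, 64, 64))    # fake ground-truth train ids

hist = eval_metrices.get_confusion_matrix(logits, targets)   # C x C bin counts
print('mIOU:', eval_metrices.mIOU(hist))
print('pixel accuracy:', eval_metrices.pixel_accuracy(logits, targets))
```

During training, `Model.update_running_stats` sums these confusion matrices across batches and only converts the accumulated matrix into a scalar mIOU once per epoch.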
/sem_seg/models/segmentation_model.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | from typing import Callable
4 | from config import defaults
5 | import torch
6 |
7 | from sem_seg.models.base import Model
8 | from . import eval_metrices
9 |
10 | DIRNAME = Path(__file__).parents[1].resolve()
11 | OUTDIR = DIRNAME / "outputs" / "class_IOUs"
12 | OUTDIR.mkdir(parents=True, exist_ok=True)
13 |
14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 |
16 |
17 | class SegmentationModel(Model):
18 | """Child class for semantic segmentation model from base Model"""
19 |
20 | def __init__(self, network: torch.nn.Module,
21 | dataloader: torch.utils.data.DataLoader,
22 | optimizer: torch.optim,
23 | criterion: Callable = None,
24 | lr_scheduler: torch.optim.lr_scheduler = None,
25 | additional_identifier: str = ''):
26 | super().__init__(network, dataloader, optimizer, criterion, lr_scheduler, additional_identifier)
27 |
28 | if criterion is None:
29 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=defaults.CITYSCAPES_IGNORED_INDEX)
30 |
31 | self.metrices = ['mIOU', 'pixel_accuracy']
32 |
33 | def store_per_class_iou(self):
34 | print('Calculating per class mIOU')
35 | self.network.eval()
36 | self.network.to(device)
37 | c_matrix = 0
38 | for inputs, labels in self.dataloader['val']:
39 | inputs, labels = inputs.to(device), labels.to(device)
40 | outs = self.network(inputs)
41 | c_matrix += eval_metrices.get_confusion_matrix(outs, labels)
42 | class_iou = eval_metrices.per_class_iu(c_matrix)
43 | class_iou = np.round(class_iou, decimals=6)
44 | file_path = str(OUTDIR) + '/' + self.name + '.csv'
45 | np.savetxt(file_path, class_iou, delimiter=',', fmt='%.6f')
46 | print(f'per class IOU saved at {file_path}')
47 |
--------------------------------------------------------------------------------
/sem_seg/models/self_supervised_model.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | from tqdm import tqdm
3 | import numpy as np
4 | import time
5 | import copy
6 | import random
7 |
8 | import torch
9 |
10 | from sem_seg.models.base import Model
11 |
12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13 |
14 |
15 | class SelfSupervisedModel(Model):
16 | def __init__(self, network: torch.nn.Module,
17 | dataloader: torch.utils.data.DataLoader,
18 | optimizer: torch.optim,
19 | criterion: Callable = None,
20 | lr_scheduler: torch.optim.lr_scheduler = None,
21 | additional_identifier: str = ''):
22 | super().__init__(network, dataloader, optimizer, criterion, lr_scheduler, additional_identifier)
23 | self.metrices = []
24 | if criterion is None:
25 | self.criterion = torch.nn.MSELoss()
26 |
27 | def train(self, num_epochs: int):
28 | since = time.time()
29 | losses = {'train': [], 'val': []}
30 | stats = {'train': [], 'val': []}
31 | lr_rates = []
32 | best_epoch, best_stat = None, float('inf')
33 | best_model_weights = None
34 | self.network.to(device)
35 | self.network.initialize_target_network()
36 |
37 | for epoch in tqdm(range(num_epochs)):
38 | for phase in ['train', 'val']:
39 | if phase == 'train':
40 | self.network.train()
41 | else:
42 | self.network.eval()
43 | running_loss = 0.0
44 |
45 | for (imgs, tf_imgs, seeds), _ in self.dataloader[phase]:
46 | imgs, tf_imgs = imgs.to(device), tf_imgs.to(device)
47 | self.optim.zero_grad()
48 |
49 | with torch.set_grad_enabled(phase == 'train'):
50 | online_feats = self.network.online_network(imgs)['feat5']
51 | tf_online_feats = self.network.online_network(tf_imgs)['feat5']
52 |
53 | pred1_feats = self.network.predictor(self.network.online_projector(online_feats))
54 | pred2_feats = self.network.predictor(self.network.online_projector(tf_online_feats))
55 |
56 | with torch.no_grad():
57 | target_feats = self.network.target_network(imgs)['feat5']
58 | target_feats_tf = self.network.target_network(tf_imgs)['feat5']
59 |
60 | target_for_pred1_feats = self.network.target_projector(target_feats_tf)
61 | target_for_pred2_feats = self.network.target_projector(target_feats)
62 |
63 | with torch.set_grad_enabled(phase == 'train'):
64 | loss = self.network.regression_loss(pred1_feats, target_for_pred1_feats)
65 | loss += self.network.regression_loss(pred2_feats, target_for_pred2_feats)
66 |
67 | if phase == 'train':
68 | loss.backward()
69 | self.optim.step()
70 | self.network.update_target_network()
71 |
72 | running_loss += loss.item() * imgs.size(0)
73 | epoch_loss = running_loss / len(self.dataloader[phase].dataset)
74 |
75 | if phase == 'val' and epoch_loss <= best_stat:
76 | best_epoch = epoch
77 | best_stat = epoch_loss
78 | best_model_weights = copy.deepcopy(self.network.state_dict())
79 | losses[phase].append(epoch_loss)
80 |
81 | self.log_stats(epoch, losses)
82 |
83 | if (epoch + 1) % 30 == 0: #checkpoint model every 30 epochs
84 | torch.save(self.network.state_dict(), self.weights_file_name)
85 | torch.save(best_model_weights, self.weights_file_name.split('.')[0]+'_best.h5')
86 | print('model checkpointed to ', self.weights_file_name)
87 |
88 | if self.lr_scheduler is not None:
89 | self.lr_scheduler.step(epoch_loss) # considering ReduceLrOnPlateau and watching val loss
90 | lr_step = self.optim.state_dict()["param_groups"][0]["lr"]
91 | lr_rates.append(lr_step)
92 | min_lr = self.lr_scheduler.min_lrs[0]
93 | if lr_step / min_lr <= 10:
94 | print("Min LR reached")
95 | break
96 |
97 | time_elapsed = time.time() - since
98 | print('Training Completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
99 | print('Best val loss {:.6f} and best epoch is {}'.format(best_stat, best_epoch + 1))
100 | torch.save(self.network.state_dict(), self.weights_file_name)
101 | print('model saved to ', self.weights_file_name)
102 | self.network.load_state_dict(best_model_weights)
103 |
--------------------------------------------------------------------------------
/sem_seg/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .fcn import FCN8s, FCN16s, FCN32s
2 | from .resnet import Resnet18, Resnet50, Resnet101, Resnet152, Resnet50_2, Resnet101_2
3 | from .byol import BYOL
--------------------------------------------------------------------------------
/sem_seg/networks/byol.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from .network_utils import get_num_features
5 | import copy
6 |
7 |
8 | class BYOL_Head(nn.Module):
9 | def __init__(self, backbone: nn.Module, target_momentum=0.996):
10 | super().__init__()
11 | # representation head
12 | self.online_network = backbone
13 | self.target_network = copy.deepcopy(backbone)
14 |
15 | # Projection Head
16 | self.online_projector = ProjectorHead(backbone.name)
17 | self.target_projector = ProjectorHead(backbone.name)
18 |
19 | # Predictor Head
20 | self.predictor = MLPHead(self.online_projector.out_channels, 512, 128)
21 |
22 | self.m = target_momentum
23 |
24 | def initialize_target_network(self):
25 | for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
26 | param_k.data.copy_(param_q.data)
27 | param_k.requires_grad = False
28 |
29 | for param_q, param_k in zip(self.online_projector.parameters(), self.target_projector.parameters()):
30 | param_k.data.copy_(param_q.data)
31 | param_k.requires_grad = False
32 |
33 | @torch.no_grad()
34 | def update_target_network(self):
35 | for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
36 | param_k.data = self.m * param_k.data + (1 - self.m) * param_q.data
37 |
38 | for param_q, param_k in zip(self.online_projector.parameters(), self.target_projector.parameters()):
39 | param_k.data = self.m * param_k.data + (1 - self.m) * param_q.data
40 |
41 | @staticmethod
42 | def regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
43 | x_norm = F.normalize(x, dim=1)
44 | y_norm = F.normalize(y, dim=1)
45 | loss = 2 - 2 * (x_norm * y_norm).sum(dim=-1)
46 | return loss.mean()
47 |
48 |
49 | class ProjectorHead(nn.Module):
50 | def __init__(self, backbone_class: str, layer: str = 'feats5'):
51 | super().__init__()
52 | num_feats = get_num_features(backbone_class)
53 | layer_index = -3 if layer == 'feats3' else -2 if layer == 'feats4' else -1
54 | num_features = num_feats[layer_index]
55 | self.projection = MLPHead(num_features, 512, 128)
56 | self.out_channels = 128
57 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
58 |
59 | def forward(self, x):
60 | x_pooled = self.avg_pool(x)
61 | h = x_pooled.view(x_pooled.shape[0], x_pooled.shape[1]) # flatten the pooled [N, C, 1, 1] features to [N, C]
62 | return self.projection(h)
63 |
64 |
65 | class MLPHead(nn.Module):
66 | def __init__(self, in_channels: int, hidden_size: int, out_size: int):
67 | super().__init__()
68 | self.net = nn.Sequential(
69 | nn.Linear(in_channels, hidden_size),
70 | nn.BatchNorm1d(hidden_size),
71 | nn.ReLU(inplace=True),
72 | nn.Linear(hidden_size, out_size)
73 | )
74 |
75 | def forward(self, x):
76 | return self.net(x)
77 |
78 |
79 | class BYOL(BYOL_Head):
80 | def __init__(self, backbone: nn.Module, target_momentum=0.996):
81 | super().__init__(backbone, target_momentum)
82 |
83 |
--------------------------------------------------------------------------------
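To see how the pieces fit together, here is a sketch of a single BYOL update on random tensors, mirroring the loop in `SelfSupervisedModel.train` (batch size of at least 2 because of the `BatchNorm1d` in `MLPHead`; the repository's `Resnet50` backbone is assumed to be importable):

```python
import torch
from sem_seg.networks import BYOL, Resnet50

backbone = Resnet50(pretrained=False)
model = BYOL(backbone, target_momentum=0.996)
model.initialize_target_network()       # copy online weights into the frozen target branch

view1 = torch.randn(2, 3, 224, 224)     # two augmented views of the same images
view2 = torch.randn(2, 3, 224, 224)

online1 = model.predictor(model.online_projector(model.online_network(view1)['feat5']))
online2 = model.predictor(model.online_projector(model.online_network(view2)['feat5']))
with torch.no_grad():
    target1 = model.target_projector(model.target_network(view2)['feat5'])
    target2 = model.target_projector(model.target_network(view1)['feat5'])

loss = model.regression_loss(online1, target1) + model.regression_loss(online2, target2)
loss.backward()                          # gradients flow only through the online branch
model.update_target_network()            # EMA update of the target branch
print(loss.item())
```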
/sem_seg/networks/fcn.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | from torch import nn
3 | from .network_utils import get_num_features
4 |
5 |
6 | class FCN(nn.Module):
7 | """ Base Class for all FCN Modules """
8 |
9 | def __init__(self, backbone: nn.Module, num_classes: int, model_type: str = 'fcn8s'):
10 | super().__init__()
11 | self.backbone = backbone
12 | num_features = get_num_features(backbone.name, model_type)
13 | self.classifier = nn.ModuleList([self.upsample_head(num_feature, num_classes) for num_feature in num_features])
14 |
15 | def upsample_head(self, in_channels: int, channels: int) -> nn.Module:
16 | """
17 | :param in_channels: Number of channels in Input
18 | :param channels: Desired Number of channels in Output
19 | :return: torch.nn.Module
20 | """
21 | inter_channels = in_channels // 8
22 | layers = [
23 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
24 | nn.BatchNorm2d(inter_channels),
25 | nn.ReLU(),
26 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
27 | nn.BatchNorm2d(inter_channels),
28 | nn.ReLU(),
29 | nn.Conv2d(inter_channels, channels, 1),
30 | ]
31 | return nn.Sequential(*layers)
32 |
33 | def forward(self, x):
34 | """ Abstract method to be implemented by child classes"""
35 | pass
36 |
37 |
38 | class FCN32s(FCN):
39 | """ Child FCN class that generates the output only using feature maps from last layer of the backbone """
40 | def __init__(self, backbone: nn.Module, num_classes: int):
41 | super().__init__(backbone, num_classes, model_type='fcn32s')
42 |
43 | def forward(self, x):
44 | """ Forward pass through FCN32s"""
45 | h, w = x.shape[-2:]
46 | features = self.backbone(x)
47 | return self.bilinear_upsample(features, h, w)
48 |
49 | def bilinear_upsample(self, features: Dict, h: int, w: int):
50 | """
51 | :param features: Backbone's output feature map dict
52 | :param h: Desired Output Height
53 | :param w: Desired output Width
54 | :return: Upsample output of size N x C x H x W where C is the number of classes
55 | """
56 | out32s = self.classifier[-1](features['feat5'])
57 | upsampled_out = nn.functional.interpolate(out32s, size=(h, w), mode='bilinear', align_corners=False)
58 | return upsampled_out
59 |
60 |
61 | class FCN16s(FCN):
62 | """ Child FCN class that generates the output only using feature maps from last two layers of the backbone """
63 | def __init__(self, backbone: nn.Module, num_classes: int):
64 | super().__init__(backbone, num_classes, model_type='fcn16s')
65 |
66 | def forward(self, x):
67 | """ Forward pass through FCN16s"""
68 | h, w = x.shape[-2:]
69 | features = self.backbone(x)
70 | return self.bilinear_upsample(features, h, w)
71 |
72 | def bilinear_upsample(self, features: Dict, h: int, w: int):
73 | """
74 | Bilinear upsample after merging the last 2 feature maps
75 | :param features: Backbone's output feature map dict
76 | :param h: Desired Output Height
77 | :param w: Desired output Width
78 | :return: Upsample output of size N x C x H x W where C is the number of classes
79 | """
80 | out32s = self.classifier[-1](features['feat5'])
81 | out16s = self.classifier[-2](features['feat4'])
82 | upsampled_out32s = nn.functional.interpolate(out32s, size=(h//16, w//16), mode='bilinear', align_corners=False)
83 | out = upsampled_out32s + out16s
84 | upsampled_out = nn.functional.interpolate(out, size=(h, w), mode='bilinear', align_corners=False)
85 | return upsampled_out
86 |
87 |
88 | class FCN8s(FCN):
89 | """ Child FCN class that generates the output only using feature maps from last three layers of the backbone """
90 | def __init__(self, backbone: nn.Module, num_classes: int):
91 | super().__init__(backbone, num_classes, model_type='fcn8s')
92 |
93 | def forward(self, x):
94 | """ Forward pass through FCN16s"""
95 | h, w = x.shape[-2:]
96 | features = self.backbone(x)
97 | return self.bilinear_upsample(features, h, w)
98 |
99 | def bilinear_upsample(self, features: Dict, h: int, w: int):
100 | """
101 | Bilinear upsample after merging the last 3 feature maps
102 | :param features: Backbone's output feature map dict
103 | :param h: Desired Output Height
104 | :param w: Desired output Width
105 | :return: Upsample output of size N x C x H x W where C is the number of classes
106 | """
107 | out32s = self.classifier[-1](features['feat5'])
108 | out16s = self.classifier[-2](features['feat4'])
109 | out8s = self.classifier[-3](features['feat3'])
110 | upsampled_out32s = nn.functional.interpolate(out32s, size=(h//16, w//16), mode='bilinear', align_corners=False)
111 | out = upsampled_out32s + out16s
112 | upsampled_out16s = nn.functional.interpolate(out, size=(h//8, w//8), mode='bilinear', align_corners=False)
113 | out = upsampled_out16s + out8s
114 | upsampled_out = nn.functional.interpolate(out, size=(h, w), mode='bilinear', align_corners=False)
115 | return upsampled_out
116 |
117 |
--------------------------------------------------------------------------------
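A quick shape check (a sketch with a dummy batch; the spatial size should be divisible by 32 so the skip connections line up) shows that the decoder brings the prediction back to the input resolution:

```python
import torch
from sem_seg.networks import FCN8s, Resnet50

backbone = Resnet50(pretrained=False)
fcn = FCN8s(backbone, num_classes=20)

x = torch.randn(2, 3, 256, 512)      # dummy batch, H and W divisible by 32
with torch.no_grad():
    out = fcn(x)
print(out.shape)                     # torch.Size([2, 20, 256, 512])
```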
/sem_seg/networks/network_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | from torchvision import models
4 |
5 | def get_num_features(backbone_name: str, model_type: str='') -> List:
6 | """
7 | Gives the number of feature channels in the backbone blocks relevant for the given FCN variant
8 | :param backbone_name: name of the backbone model e.g. 'resnet18' | 'resnet50'
9 | :param model_type: Type of FCN model(fcn32s | fcn16s | fcn8s)
10 | :return: List of channel counts, truncated to the last 1, 2, or 3 blocks for fcn32s, fcn16s, or fcn8s respectively
11 | """
12 |
13 | if 'resnet18' in backbone_name.lower():
14 | num_features = [64, 128, 256, 512]
15 | else:
16 | num_features = [256, 512, 1024, 2048]
17 | if 'fcn8s' in model_type.lower():
18 | num_features = num_features[-3:]
19 | elif 'fcn16s' in model_type.lower():
20 | num_features = num_features[-2:]
21 | elif 'fcn32s' in model_type.lower():
22 | num_features = num_features[-1:]
23 | return num_features
24 |
25 |
26 | def get_features_dict(base: torch.nn.Module) -> torch.nn.Module:
27 | """
28 | This function extracts the features from various layers(hardcoded in a dictionary extract_layers) and returns
29 | them as a key-value pair.
30 | For e.g. extract_layer = {'layer2': 'feat3', 'layer3': 'feat4'} will extract 'layer2' and 'layer3' feature maps
31 | from the given network and assigns them to key 'feat3' and 'feat4' respectively of the output.
32 | :param base: backbone model
33 | :return: model with output layer as dictionary which provides extracted features from various layers.
34 | """
35 | extract_layers = {'layer1': 'feat1', 'layer2': 'feat3', 'layer3': 'feat4', 'layer4': 'feat5'}
36 | return models._utils.IntermediateLayerGetter(base, extract_layers)
37 |
38 |
--------------------------------------------------------------------------------
/sem_seg/networks/resnet.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | import torch
3 | from typing import List
4 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
5 | from pathlib import Path
6 | from .network_utils import get_features_dict
7 |
8 | model_urls = {
9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
15 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
16 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
17 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
18 | }
19 |
20 |
21 | def _build_model(resnet_type: str = 'resnet50', pretrained: bool = False,
22 | replace_stride_with_dilation: List = [False, False, False]) -> torch.nn.Module:
23 | """
24 | :param resnet_type: 'resnet18' | 'resnet50' | 'resnet101' | 'resnet152'
25 | :param pretrained: True if the network is expected to be initialized with pretrained imagenet weights
26 | :param replace_stride_with_dilation: List of 3 boolean values if the last 3 blocks of resnet should use dilation
27 | instead of stride
28 | :return: torch.nn.Module
29 | """
30 | if resnet_type == 'resnet18':
31 | base = models.resnet18(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation)
32 | elif resnet_type == 'resnet50':
33 | base = models.resnet50(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation)
34 | elif resnet_type == 'resnet101':
35 | base = models.resnet101(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation)
36 | elif resnet_type == 'resnet152':
37 | base = models.resnet152(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation)
38 | elif resnet_type == 'resnet50_2x':
39 | base = models.wide_resnet50_2(pretrained=False, replace_stride_with_dilation=replace_stride_with_dilation)
40 | else:
41 |         # fail fast instead of silently returning None for an unsupported backbone
42 |         raise NotImplementedError(f'resnet type {resnet_type} is not currently implemented')
43 |
44 | if pretrained:
45 | state_dict = load_state_dict_from_url(model_urls[resnet_type], progress=True)
46 | base.load_state_dict(state_dict)
47 |
48 | network = get_features_dict(base)
49 | network.isDRN = any(replace_stride_with_dilation) # Is Dilated Residual Network
50 | network.name = resnet_type
51 | return network
52 |
53 |
54 | Resnet18 = lambda pretrained: _build_model('resnet18', pretrained)
55 | Resnet50 = lambda pretrained: _build_model('resnet50', pretrained)
56 | Resnet101 = lambda pretrained: _build_model('resnet101', pretrained)
57 | Resnet152 = lambda pretrained: _build_model('resnet152', pretrained)
58 | DRN50 = lambda pretrained: _build_model('resnet50', pretrained, replace_stride_with_dilation = [False, True, True])
59 | Resnet50_2 = lambda pretrained: _build_model('resnet50_2x', pretrained)
60 | Resnet101_2 = lambda pretrained: _build_model('resnet101_2x', pretrained)
61 |
62 |
63 |
--------------------------------------------------------------------------------
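The DRN50 factory above replaces the strides of the last two ResNet stages with dilation, so the deepest feature maps stay at 1/8 of the input resolution instead of 1/32; this is what the isDRN flag records. A minimal stand-alone sketch (not repo code) that makes the difference visible:

```python
# Hedged sketch: compare the output strides of a plain ResNet-50 backbone and the
# dilated variant used by DRN50 (replace_stride_with_dilation=[False, True, True]).
import torch
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter

layers = {'layer3': 'feat4', 'layer4': 'feat5'}
plain = IntermediateLayerGetter(models.resnet50(pretrained=False), layers)
dilated = IntermediateLayerGetter(
    models.resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True]), layers)

x = torch.randn(1, 3, 512, 512)
print({k: tuple(v.shape) for k, v in plain(x).items()})
# {'feat4': (1, 1024, 32, 32), 'feat5': (1, 2048, 16, 16)}   -> output stride 16 / 32
print({k: tuple(v.shape) for k, v in dilated(x).items()})
# {'feat4': (1, 1024, 64, 64), 'feat5': (1, 2048, 64, 64)}   -> output stride 8
```

Two caveats in the factories as written: Resnet101_2 requests 'resnet101_2x', which _build_model does not handle, and loading pretrained weights for 'resnet50_2x' looks up model_urls['resnet50_2x'] while the dictionary is keyed 'wide_resnet50_2'. So the wide variants currently only build with pretrained=False (and Resnet101_2 not at all) unless those branches and URL aliases are added.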
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nilesh0109/self-supervised-sem-seg/fe0e5f2e56028dc881517c72f1900a0cd1c35467/training/__init__.py
--------------------------------------------------------------------------------
/training/debugger.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | import torch
4 | import torchvision
5 |
6 |
7 | from config import defaults
8 | defaults.MATPLOTLIB_NO_GUI = False # enable GUI for matplotlib
9 | from sem_seg.datasets.dataloader import CityscapesLoader
10 | from sem_seg.datasets.datautils import CityscapesUtils
11 | from sem_seg.networks import FCN8s, FCN32s, FCN16s, Resnet50
12 | from sem_seg.models.segmentation_model import SegmentationModel
13 | from sem_seg.models.self_supervised_model import SelfSupervisedModel
14 | from training import utils, train
15 |
16 |
17 | DIRNAME = Path(__file__).parents[1].resolve()
18 | OUTDIR = DIRNAME / "sem_seg" / "outputs"
19 | OUTDIR.mkdir(parents=True, exist_ok=True)
20 |
21 |
22 | def get_dataloader(mode='supervised'):
23 | cityscapes = CityscapesLoader(label_percent='100%').get_cityscapes_loader(mode=mode)
24 | return cityscapes
25 |
26 | def get_model():
27 | cityscapes = get_dataloader()
28 | cityscapes_utils = CityscapesUtils()
29 | num_classes = cityscapes_utils.num_classes + 1
30 |     # the FCN network derives the backbone feature widths itself (see network_utils.get_num_features)
31 | base = Resnet50(pretrained=False)
32 | fcn8s = FCN8s(base, num_classes)
33 | optim = torch.optim.Adam(fcn8s.parameters())
34 | model = SegmentationModel(fcn8s, cityscapes, optim)
35 | return model
36 |
37 | def debug_FCN():
38 | cityscapes = get_dataloader()
39 | batch_imgs, batch_targets = next(iter(cityscapes['train']))
40 | utils.plot_images(batch_imgs, batch_targets, title='Predictions')
41 |
42 | def debug_self_supervised_model():
43 | cityscapes = get_dataloader(mode='self-supervised')
44 | (batch_imgs, tf_imgs, seeds), _ = next(iter(cityscapes['train']))
45 | preds = None
46 | utils.plot_images(batch_imgs, tf_imgs, preds, title='Predictions')
47 |
48 | if __name__ == '__main__':
49 | debug_FCN()
50 | debug_self_supervised_model()
51 |
--------------------------------------------------------------------------------
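The two debug helpers above only plot inputs and targets. Below is a hedged extension (not part of the repo) that also runs the batch through the untrained FCN8s and plots its per-pixel predictions; it assumes the Cityscapes paths configured in config/defaults.py point to a local copy of the dataset, that SegmentationModel exposes the wrapped network as model.network (as training/train.py uses), and that the FCN8s forward pass returns per-class logits of shape (N, num_classes, H, W):

```python
# Hedged sketch: visualize inputs, ground truth and (untrained) FCN8s predictions side by side.
import torch
from training import utils
from training.debugger import get_dataloader, get_model

model = get_model()                                   # SegmentationModel wrapping FCN8s + Resnet50
cityscapes = get_dataloader()
batch_imgs, batch_targets = next(iter(cityscapes['train']))

with torch.no_grad():
    logits = model.network(batch_imgs)                # assumed shape: (N, num_classes, H, W)
    preds = logits.argmax(dim=1)                      # per-pixel class ids, shape (N, H, W)

utils.plot_images(batch_imgs, batch_targets, preds, title='Inputs / targets / predictions')
```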
/training/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import argparse
4 | from typing import Dict
5 | import importlib
6 | from config import defaults
7 | from sem_seg.datasets.datautils import CityscapesUtils
8 | from sem_seg.datasets.dataloader import CityscapesLoader
9 | from sem_seg import networks, models
10 |
11 |
12 | def run_experiment(experiment_config: Dict, load_weights: bool, save_weights: bool) -> None:
13 | """
14 | Run a Training Experiment
15 |     :param experiment_config: Dict of the following form
16 |     {
17 |         "dataset": "cityscapes",
18 |         "model": "SegmentationModel", # Type of the training model (SegmentationModel, SelfSupervisedModel)
19 |         "network": "FCN8s", # "FCN8s" | "FCN16s" | "FCN32s" | "deeplabv3" | "deeplabv3plus"
20 |         "network_args": {"backbone": "Resnet50", # "Resnet50" | "Resnet18" | "Resnet101" | "DRN50" etc.
21 |                         "pretrained": false, # Whether the backbone model is pretrained
22 |                         "load_from_BYOL": true, # Whether to load the backbone weights from the self-supervised (BYOL) run
23 |                         "freeze_backbone": false}, # Whether the backbone weights are frozen during training
24 |         "train_args": {"batch_size": 8,
25 |                       "epochs": 400,
26 |                       "labels_percent": "10%", # percentage of supervised labels to use for fine-tuning. Default 100%
27 |                       "log_to_tensorboard": false},
28 |         "experiment_group": {}
29 |     }
30 | :param load_weights: If true, load weights for the model from last run
31 | :param save_weights: If true, save model weights to sem_seg/weights directory
32 | :return: None
33 | """
34 |
35 | DEFAULT_TRAIN_ARGS = {'epochs': defaults.NUM_EPOCHS, 'batch_size': defaults.BATCH_SIZE,
36 | 'num_workers': defaults.NUM_WORKERS}
37 | train_args = {
38 | **DEFAULT_TRAIN_ARGS,
39 | **experiment_config.get("train_args", {})
40 | }
41 | experiment_config["train_args"] = train_args
42 | experiment_config["experiment_group"] = experiment_config.get("experiment_group", None)
43 | print(f'Running experiment with config {experiment_config}')
44 |
45 | labels_percent = train_args.get('labels_percent', '100%')
46 | mode = experiment_config.get('mode', 'supervised')
47 | dataset_name = experiment_config["dataset"].lower()
48 | assert dataset_name in ['cityscapes'], "The dataloader is only implemented for cityscapes dataset"
49 | data_loader = CityscapesLoader(label_percent=labels_percent) \
50 | .get_cityscapes_loader(batch_size=train_args["batch_size"],
51 | num_workers=train_args["num_workers"],
52 | mode=mode)
53 | models_module = importlib.import_module("sem_seg.models")
54 | model_class_ = getattr(models_module, experiment_config["model"])
55 |
56 | networks_module = importlib.import_module("sem_seg.networks")
57 | network_args = experiment_config.get("network_args", {})
58 |
59 | pretrained = network_args["pretrained"]
60 | backbone_class_ = getattr(networks_module, network_args["backbone"])
61 | base = backbone_class_(pretrained=pretrained)
62 |
63 | num_classes = CityscapesUtils().num_classes
64 |
65 | network_class_ = getattr(networks_module, experiment_config["network"])
66 | network = network_class_(base, num_classes)
67 | optim = torch.optim.SGD(network.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.00001)
68 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='max', factor=0.1,
69 | patience=15, min_lr=1e-10, verbose=True)
70 | additional_identifier = get_additional_identifier(network_args["backbone"], pretrained, dataset_name,
71 | labels_percent)
72 | model = model_class_(network, data_loader, optim, lr_scheduler=lr_scheduler,
73 | additional_identifier=additional_identifier)
74 |
75 |     if not train_args.get('log_to_tensorboard', False):
76 | model.logToTensorboard = False
77 | model.add_text_to_tensorboard(json.dumps(experiment_config))
78 |
79 | if load_weights:
80 | model.load_weights()
81 | elif network_args.get('load_from_BYOL', ''):
82 | byol = networks.BYOL(base)
83 | ss_model = models.SelfSupervisedModel(byol, additional_identifier=additional_identifier)
84 | print('loading self-supervised weights from ', ss_model.weights_file_name)
85 | byol_state_dict = torch.load(ss_model.weights_file_name)
86 | backbone_dict = {}
87 | for key in byol_state_dict:
88 | if 'online_network' in key:
89 | new_key = key.replace('online_network.', '')
90 | backbone_dict[new_key] = byol_state_dict[key]
91 | model.network.backbone.load_state_dict(backbone_dict)
92 |
93 | if network_args.get('freeze_backbone', False):
94 | for param in model.network.backbone.parameters():
95 | param.requires_grad = False
96 |
97 | model.train(num_epochs=train_args["epochs"])
98 | model.store_per_class_iou()
99 |
100 | if save_weights:
101 | model.save_weights()
102 |
103 |
104 | def get_additional_identifier(backbone: str, pretrained: bool = False, dataset_name: str = '',
105 | labels_percent: str = '100%') -> str:
106 | """
107 | Returns the additional_identifier added to the model name for efficient tracking of different experiments
108 | :param backbone: name of the backbone
109 | :param pretrained: Whether the backbone is pretrained
110 | :param dataset_name: Name of the training dataset
111 | :param labels_percent: % of labels used for training
112 | :return: additional identifier string
113 | """
114 | additional_identifier = backbone
115 | additional_identifier += '_pt' if pretrained else ''
116 | additional_identifier += '_' + labels_percent[:-1] if labels_percent and int(labels_percent[:-1]) < 100 else ''
117 | additional_identifier += '_ct' if dataset_name == 'cityscapes' else '_' + dataset_name
118 | return additional_identifier
119 |
120 |
121 | def _parse_args():
122 | """ parse command line arguments """
123 | parser = argparse.ArgumentParser()
124 | parser.add_argument("--save", default=False, action="store_true",
125 | help="If true, final weights will be stored in canonical, version-controlled location")
126 | parser.add_argument("--load", default=False, action="store_true",
127 | help="If true, final weights will be loaded from canonical, version-controlled location")
128 | parser.add_argument("experiment_config", type=str,
129 | help='Experiment JSON (\'{"dataset": "cityscapes", "model": "SegmentationModel",'
130 |                              ' "network": "FCN8s"}\''
131 | )
132 | args = parser.parse_args()
133 | return args
134 |
135 |
136 | def main():
137 | """Run Experiment"""
138 | args = _parse_args()
139 | experiment_config = json.loads(args.experiment_config)
140 | run_experiment(experiment_config, args.load, args.save)
141 |
142 |
143 | if __name__ == '__main__':
144 | main()
145 |
--------------------------------------------------------------------------------
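A hedged usage sketch (not from the repo): the experiment config documented in run_experiment can be passed to it directly from Python; scripts/train.sh presumably wraps the equivalent command-line call (roughly python -m training.train --save '<json>' from the repository root). It assumes the Cityscapes dataset is available at the paths configured in config/defaults.py.

```python
# Hypothetical example config; the keys mirror the run_experiment docstring above.
from training.train import run_experiment

experiment_config = {
    "dataset": "cityscapes",
    "model": "SegmentationModel",
    "network": "FCN8s",
    "network_args": {"backbone": "Resnet50",
                     "pretrained": False,
                     "load_from_BYOL": True,     # initialise the backbone from the BYOL-pretrained weights
                     "freeze_backbone": False},
    "train_args": {"batch_size": 8,
                   "epochs": 400,
                   "labels_percent": "10%",      # fine-tune with 10% of the labelled images
                   "log_to_tensorboard": False},
}
run_experiment(experiment_config, load_weights=False, save_weights=True)
```

With this config, get_additional_identifier produces 'Resnet50_10_ct' ('Resnet50_pt_10_ct' if pretrained were true); the same suffix is used to locate the matching self-supervised weights file when load_from_BYOL is set.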
/training/train_byol.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import argparse
4 | from typing import Dict
5 | import importlib
6 | from config import defaults
7 | from sem_seg.datasets.dataloader import CityscapesLoader
8 |
9 |
10 | def run_experiment(experiment_config: Dict, load_weights: bool, save_weights: bool) -> None:
11 | """
12 | Run a Training Experiment
13 |     :param experiment_config: Dict of the following form
14 |     {
15 |         "dataset": "cityscapes",
16 |         "model": "SelfSupervisedModel", # Type of the training model (SegmentationModel, SelfSupervisedModel)
17 |         "network": "BYOL", # "BYOL"
18 |         "mode": "self-supervised", # Type of training (self-supervised | supervised)
19 |         "network_args": {"backbone": "Resnet50", # "Resnet50" | "Resnet18" | "Resnet101" | "DRN50" etc.
20 |                         "pretrained": false, # Whether the backbone model is pretrained
21 |                         "target_momentum": 0.996 # EMA momentum for the BYOL target network
22 |                         },
23 |         "train_args": {"batch_size": 8,
24 |                       "epochs": 400,
25 |                       "log_to_tensorboard": false},
26 |         "experiment_group": {}
27 |     }
28 | :param load_weights: If true, load weights for the model from last run
29 | :param save_weights: If true, save model weights to sem_seg/weights directory
30 | :return: None
31 | """
32 |
33 | DEFAULT_TRAIN_ARGS = {'epochs': defaults.NUM_EPOCHS, 'batch_size': defaults.BATCH_SIZE,
34 | 'num_workers': defaults.NUM_WORKERS}
35 | train_args = {
36 | **DEFAULT_TRAIN_ARGS,
37 | **experiment_config.get("train_args", {})
38 | }
39 | experiment_config["train_args"] = train_args
40 | experiment_config["experiment_group"] = experiment_config.get("experiment_group", None)
41 | print(f'Running experiment with config {experiment_config}')
42 |
43 | labels_percent = train_args.get('labels_percent', '100%')
44 | mode = experiment_config.get('mode', 'supervised')
45 | dataset_name = experiment_config["dataset"].lower()
46 | assert dataset_name in ['cityscapes'], "The dataloader is only implemented for cityscapes dataset"
47 | data_loader = CityscapesLoader(label_percent=labels_percent) \
48 | .get_cityscapes_loader(batch_size=train_args["batch_size"],
49 | num_workers=train_args["num_workers"],
50 | mode=mode)
51 | models_module = importlib.import_module("sem_seg.models")
52 | model_class_ = getattr(models_module, experiment_config["model"])
53 |
54 | networks_module = importlib.import_module("sem_seg.networks")
55 | network_args = experiment_config.get("network_args", {})
56 |
57 | pretrained = network_args["pretrained"]
58 | backbone_class_ = getattr(networks_module, network_args["backbone"])
59 | base = backbone_class_(pretrained=pretrained)
60 |
61 | network_class_ = getattr(networks_module, experiment_config["network"])
62 |     target_momentum = network_args.get("target_momentum", 0.996)
63 | network = network_class_(base, target_momentum=target_momentum)
64 |
65 | training_params = [*network.online_network.parameters()] + [*network.online_projector.parameters()] + [
66 | *network.predictor.parameters()] # explicitly excluding the target network parameters
67 | optim = torch.optim.SGD(training_params, lr=3e-2, momentum=0.9, weight_decay=0.00001)
68 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.2,
69 | patience=20, min_lr=1e-10, verbose=True)
70 | additional_identifier = get_additional_identifier(network_args["backbone"], pretrained, dataset_name,
71 | labels_percent)
72 | model = model_class_(network, data_loader, optim, lr_scheduler=lr_scheduler,
73 | additional_identifier=additional_identifier)
74 |
75 |     if not train_args.get('log_to_tensorboard', False):
76 | model.logToTensorboard = False
77 | model.add_text_to_tensorboard(json.dumps(experiment_config))
78 |
79 | if load_weights:
80 | model.load_weights()
81 |
82 | model.train(num_epochs=train_args["epochs"])
83 |
84 | if save_weights:
85 | model.save_weights()
86 |
87 |
88 | def get_additional_identifier(backbone: str, pretrained: bool = False, dataset_name: str = '',
89 | labels_percent: str = '100%') -> str:
90 | """
91 | Returns the additional_identifier added to the model name for efficient tracking of different experiments
92 | :param backbone: name of the backbone
93 | :param pretrained: Whether the backbone is pretrained
94 | :param dataset_name: Name of the training dataset
95 | :param labels_percent: % of labels used for training
96 | :return: additional identifier string
97 | """
98 | additional_identifier = backbone
99 | additional_identifier += '_pt' if pretrained else ''
100 | additional_identifier += '_' + labels_percent[:-1] if labels_percent and int(labels_percent[:-1]) < 100 else ''
101 | additional_identifier += '_ct' if dataset_name == 'cityscapes' else '_' + dataset_name
102 | return additional_identifier
103 |
104 |
105 | def _parse_args():
106 | """ parse command line arguments """
107 | parser = argparse.ArgumentParser()
108 | parser.add_argument("--save", default=False, action="store_true",
109 | help="If true, final weights will be stored in canonical, version-controlled location")
110 | parser.add_argument("--load", default=False, action="store_true",
111 | help="If true, final weights will be loaded from canonical, version-controlled location")
112 | parser.add_argument("experiment_config", type=str,
113 |                         help='Experiment JSON (\'{"dataset": "cityscapes", "model": "SelfSupervisedModel",'
114 |                              ' "network": "BYOL"}\''
115 | )
116 | args = parser.parse_args()
117 | return args
118 |
119 |
120 | def main():
121 | """Run Experiment"""
122 | args = _parse_args()
123 | experiment_config = json.loads(args.experiment_config)
124 | run_experiment(experiment_config, args.load, args.save)
125 |
126 |
127 | if __name__ == '__main__':
128 | main()
129 |
--------------------------------------------------------------------------------
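For the self-supervised step, note that the target network's parameters are deliberately excluded from the optimizer; in BYOL they are instead updated as an exponential moving average of the online network, with the rate set by target_momentum. A hedged config sketch (not from the repo) mirroring the docstring above; scripts/unsup_train.sh presumably issues the equivalent command-line call, and a local Cityscapes copy at the paths in config/defaults.py is assumed.

```python
# Hypothetical example config for BYOL pretraining; keys mirror the run_experiment docstring above.
from training.train_byol import run_experiment

experiment_config = {
    "dataset": "cityscapes",
    "model": "SelfSupervisedModel",
    "network": "BYOL",
    "mode": "self-supervised",
    "network_args": {"backbone": "Resnet50",
                     "pretrained": False,
                     "target_momentum": 0.996},   # EMA momentum of the BYOL target network
    "train_args": {"batch_size": 8,
                   "epochs": 400,
                   "log_to_tensorboard": False},
}
run_experiment(experiment_config, load_weights=False, save_weights=True)
```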
/training/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | import numpy as np
3 | import matplotlib
4 | from config import defaults
5 |
6 | # if defaults.MATPLOTLIB_NO_GUI:
7 | # matplotlib.use('Agg')
8 |
9 | import matplotlib.pyplot as plt
10 | import math
11 | import torch
12 | from pathlib import Path
13 |
14 | from sem_seg.datasets.datautils import CityscapesUtils
15 |
16 | DIRNAME = Path(__file__).parents[1].resolve()
17 | OUTDIR = DIRNAME / "sem_seg" / "outputs"
18 | OUTDIR.mkdir(parents=True, exist_ok=True)
19 |
20 | cityscape_utils = CityscapesUtils()
21 |
22 |
23 | def sanitize_imgs(images: List) -> List:
24 |     """Converts a list of torch image tensors in (N, C, H, W) format to a list of numpy images in (N, H, W, C) format"""
25 | for i in range(len(images)):
26 | if torch.is_tensor(images[i]):
27 | images[i] = images[i].cpu().numpy().squeeze()
28 | if images[i].ndim > 3 and images[i].shape[-1] != 3:
29 | images[i] = np.moveaxis(images[i], 1, -1)
30 | return images
31 |
32 |
33 | def is_mask(img: np.ndarray) -> bool:
34 | """
35 | Checks if the image is a segmentation mask
36 |     Criteria: it is a 2D image with no more distinct values than the dataset's number of classes
37 | :param img: Image
38 | :return: True if the image is a segmentation mask, False otherwise
39 | """
40 |     return img.ndim == 2 and len(np.unique(img)) < cityscape_utils.num_classes + 1
41 |
42 |
43 | def plot_images(imgs, targets=None, preds=None, title='No title', num_cols=6) -> plt.Figure:
44 | """
45 | Plot the images, targets and predictions in a grid. Shows 24 images at max in the plot.
46 | :param imgs: List of images to be plotted.
47 | :param targets: Corresponding List of ground truths.
48 | :param preds: Corresponding List of network predictions
49 | :param title: Title of the figure
50 | :param num_cols: Number of columns in the grid. Default is 6
51 | :return: output plot figure object
52 | """
53 | inputs = [imgs]
54 | if targets is not None:
55 | inputs.append(targets)
56 | if preds is not None:
57 | inputs.append(preds)
58 | inputs = sanitize_imgs(inputs)
59 | num_types = len(inputs)
60 | num_rows, num_cols, plot_width, plot_height = get_plot_sizes(inputs, num_cols)
61 | fig, ax = plt.subplots(num_rows, num_cols, figsize=(plot_width, plot_height), num=title, squeeze=False,
62 | gridspec_kw={'wspace': 0.05, 'hspace': 0.05, 'left': 0, 'top': 0.95})
63 | [axi.set_axis_off() for axi in ax.ravel()]
64 |
65 | for i in range(len(inputs[0])):
66 | r, c = (num_types * i) // num_cols, (num_types * i) % num_cols
67 | if r >= num_rows:
68 | break
69 | for j in range(num_types):
70 | img = inputs[j][i]
71 | img_to_show = cityscape_utils.label2color(img) if is_mask(img) else img
72 | ax[r][c + j].imshow(img_to_show)
73 | if title:
74 | fig.suptitle(title, fontsize=8)
75 | #if defaults.MATPLOTLIB_NO_GUI:
76 | # plt.savefig(str(OUTDIR) + '/plot2.png', bbox_inches='tight')
77 | #else:
78 | plt.show()
79 | return fig
80 |
81 |
82 | def get_plot_sizes(inputs: List, num_cols: int) -> Tuple:
83 |     """ Given the input images, return the dimensions of the plot i.e. num_rows, num_cols, width, height"""
84 | num_types, num_images = len(inputs), len(inputs[0])
85 | num_rows = min(4, math.ceil((num_images * num_types) / num_cols))
86 | if inputs[0].ndim == 4:
87 | N, C, H, W = inputs[0].shape
88 | elif inputs[0].ndim == 3:
89 | C, H, W = inputs[0].shape
90 | else:
91 | H, W = inputs[0].shape
92 | aspect_ratio = W / H
93 | return num_rows, num_cols, num_cols * aspect_ratio * 0.8, num_rows
94 |
95 |
96 | def save_results(data, filename):
97 | filename = str(OUTDIR) + '/' + filename + '.csv'
98 | np.savetxt(filename, data, fmt='%0.6f')
99 | print(f'saved at {filename}')
100 |
--------------------------------------------------------------------------------
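A small hedged example (not part of the repo) of the mask-detection heuristic that plot_images relies on; importing training.utils assumes CityscapesUtils can initialise with the paths configured in config/defaults.py:

```python
# Hedged sketch: is_mask() treats a 2-D integer array with few distinct values as a
# segmentation mask; plot_images() then colorizes such arrays via CityscapesUtils.label2color.
import numpy as np
from training.utils import is_mask

fake_mask = np.random.randint(0, 19, size=(256, 512))    # class ids (19 is the usual Cityscapes train-class count)
fake_image = np.random.rand(3, 256, 512)                  # a float image, not a mask
print(is_mask(fake_mask), is_mask(fake_image))            # expected: True False
```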