├── .gitignore
├── LICENSE
├── README.md
├── base
│   ├── __init__.py
│   ├── base_dataloader.py
│   ├── base_dataset.py
│   ├── base_model.py
│   └── base_trainer.py
├── configs
│   └── config.json
├── dataloaders
│   ├── __init__.py
│   ├── voc.py
│   └── voc_splits
│       ├── 1000_train_supervised.txt
│       ├── 1000_train_unsupervised.txt
│       ├── 100_train_supervised.txt
│       ├── 100_train_unsupervised.txt
│       ├── 1464_train_supervised.txt
│       ├── 1464_train_unsupervised.txt
│       ├── 200_train_supervised.txt
│       ├── 200_train_unsupervised.txt
│       ├── 300_train_supervised.txt
│       ├── 300_train_unsupervised.txt
│       ├── 500_train_supervised.txt
│       ├── 500_train_unsupervised.txt
│       ├── 60_train_supervised.txt
│       ├── 60_train_unsupervised.txt
│       ├── 800_train_supervised.txt
│       ├── 800_train_unsupervised.txt
│       ├── boxes.json
│       ├── classes.json
│       └── val.txt
├── inference.py
├── models
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── get_pretrained_model.sh
│   │   ├── module_helper.py
│   │   ├── resnet_backbone.py
│   │   └── resnet_models.py
│   ├── decoders.py
│   ├── encoder.py
│   └── model.py
├── pseudo_labels
│   ├── README.md
│   ├── cam_to_pseudo_labels.py
│   ├── make_cam.py
│   ├── misc
│   │   ├── imutils.py
│   │   ├── pyutils.py
│   │   └── torchutils.py
│   ├── net
│   │   ├── resnet50.py
│   │   └── resnet50_cam.py
│   ├── run.py
│   ├── train_cam.py
│   └── voc12
│       ├── cls_labels.npy
│       ├── dataloader.py
│       ├── make_cls_labels.py
│       ├── test.txt
│       ├── train.txt
│       ├── train_aug.txt
│       └── val.txt
├── requirements.txt
├── train.py
├── trainer.py
└── utils
    ├── __init__.py
    ├── helpers.py
    ├── htmlwriter.py
    ├── logger.py
    ├── losses.py
    ├── lr_scheduler.py
    ├── metrics.py
    ├── pallete.py
    └── ramps.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 | pretrained/
9 | tb_history/
10 | logs/
11 | cls_runs/
12 | slurm_logs/
13 |
14 | #experiments/
15 | experiments/
16 | experiments_da/
17 | config[0-9]*
18 | *.png
19 | *.pth
20 |
21 | # Distribution / packaging
22 | .Python
23 | env/
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | .hypothesis/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # dotenv
95 | .env
96 |
97 | # virtualenv
98 | .venv
99 | venv/
100 | ENV/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 |
115 | # input data, saved log, checkpoints
116 | data/
117 | input/
118 | saved/
119 | outputs/
120 | datasets/
121 |
122 | # editor, os cache directory
123 | .vscode/
124 | .idea/
125 | __MACOSX/
126 |
127 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yassine Ouali
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Semi-Supervised Semantic Segmentation with Cross-Consistency Training (CCT)
4 |
5 | #### [Paper](https://arxiv.org/abs/2003.09005), [Project Page](https://yassouali.github.io/cct_page/)
6 |
7 | This repo contains the official implementation of the CVPR 2020 paper: Semi-Supervised Semantic Segmentation with Cross-Consistency Training, which
8 | adapts the traditional consistency training framework of semi-supervised learning for semantic segmentation, with an extension to weakly-supervised
9 | learning and learning on multiple domains.
10 |
11 |
12 |
13 | ### Highlights
14 |
15 | **(1) Consistency Training for semantic segmentation.** \
16 | We observe that for semantic segmentation, due to the dense nature of the task,
17 | the cluster assumption is more easily enforced over the hidden representations rather than the inputs.
18 |
19 | **(2) Cross-Consistency Training.** \
20 | We propose CCT (Cross-Consistency Training) for semi-supervised semantic segmentation, where we define
21 | a number of novel perturbations, and show the effectiveness of enforcing consistency over the encoder's outputs
22 | rather than the inputs.
23 |
24 | **(3) Using weak-labels and pixel-level labels from multiple domains.** \
25 | The proposed method is quite simple and flexible, and can easily be extended to use image-level labels and
26 | pixel-level labels from multiple domains.
27 |
28 |
29 |
30 | ### Requirements
31 |
32 | This repo was tested with Ubuntu 18.04.3 LTS, Python 3.7, PyTorch 1.1.0, and CUDA 10.0, but it should be runnable with more recent PyTorch versions (>= 1.1.0).
33 |
34 | The required packages are `pytorch` and `torchvision`, together with `PIL` and `opencv` for data pre-processing and `tqdm` for showing the training progress,
35 | plus some additional modules like `dominate` to save the results in the form of HTML files. To set up the necessary modules, simply run:
36 |
37 | ```bash
38 | pip install -r requirements.txt
39 | ```
40 |
41 | ### Dataset
42 |
43 | In this repo, we use **Pascal VOC**. To obtain it, first download the [original dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar); after extracting the files we'll end up with `VOCtrainval_11-May-2012/VOCdevkit/VOC2012` containing the image sets, the XML annotations for both object detection and segmentation, and the JPEG images.\
44 | The second step is to augment the dataset using the additional annotations provided by [Semantic Contours from Inverse Detectors](http://home.bharathh.info/pubs/pdfs/BharathICCV2011.pdf). Download the rest of the annotations, [SegmentationClassAug](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0), and add them to the path `VOCtrainval_11-May-2012/VOCdevkit/VOC2012`. Now we're set; for training, use the path to `VOCtrainval_11-May-2012`.
45 |
46 |
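After these two steps, the dataset directory should look roughly as follows (only the folders this repo relies on are listed; `Annotations` and `ImageSets` ship with the original VOC archive):

```
VOCtrainval_11-May-2012/
└── VOCdevkit/
    └── VOC2012/
        ├── Annotations/
        ├── ImageSets/
        ├── JPEGImages/
        ├── SegmentationClass/
        └── SegmentationClassAug/   <- added from SegmentationClassAug.zip
```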
47 | ### Training
48 |
49 | To train a model, first download PASCAL VOC as detailed above, then set `data_dir` to the dataset path in the config file `configs/config.json` and set the rest of the parameters, like the number of GPUs, crop size, data augmentation, etc. You can also change the CCT hyperparameters if you wish (more details below). Then simply run:
50 |
51 | ```bash
52 | python train.py --config configs/config.json
53 | ```
54 |
55 | The log files and the `.pth` checkpoints will be saved in `saved/EXP_NAME`. To monitor the training using tensorboard, please run:
56 |
57 | ```bash
58 | tensorboard --logdir saved
59 | ```
60 |
61 | To resume training using a saved `.pth` model:
62 |
63 | ```bash
64 | python train.py --config configs/config.json --resume saved/CCT/checkpoint.pth
65 | ```
66 |
67 | **Results**: The results will be saved in `saved` as an HTML file containing the validation results,
68 | named after the `experim_name` specified in `configs/config.json`.
69 |
70 | ### Pseudo-labels
71 |
72 | If you want to use image-level labels to train the auxiliary decoders, as explained in Section 3.3 of the paper, first generate the pseudo-labels
73 | using the code in `pseudo_labels`:
74 |
75 |
76 | ```bash
77 | cd pseudo_labels
78 | python run.py --voc12_root DATA_PATH
79 | ```
80 |
81 | `DATA_PATH` must point to the folder containing `JPEGImages` in the Pascal VOC dataset. The results will be
82 | saved in `pseudo_labels/result/pseudo_labels` as PNG files. The weak-labels flag (`use_weak_lables` in the config file) then needs to be set to true, and
83 | we can train the model as detailed above.
84 |
85 |
86 | ### Inference
87 |
88 | For inference, we need a pretrained model, the jpg images we'd like to segment, and the config used in training (to load the correct model and other parameters):
89 |
90 | ```bash
91 | python inference.py --config config.json --model best_model.pth --images images_folder
92 | ```
93 |
94 | The predictions will be saved as `.png` images in `outputs/`. For Pascal VOC, the default palette is:
95 |
96 | 
97 |
98 | Here are the flags available for inference:
99 |
100 | ```
101 | --images Folder containing the jpg images to segment.
102 | --model Path to the trained pth model.
103 | --config The config file used for training the model.
104 | ```
105 |
106 | ### Pre-trained models
107 |
108 | Pre-trained models can be downloaded [here](https://github.com/yassouali/CCT/releases).
109 |
110 | ### Citation ✏️ 📄
111 |
112 | If you find this repo useful for your research, please consider citing the paper as follows:
113 |
114 | ```
115 | @InProceedings{Ouali_2020_CVPR,
116 | author = {Ouali, Yassine and Hudelot, Celine and Tami, Myriam},
117 | title = {Semi-Supervised Semantic Segmentation With Cross-Consistency Training},
118 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
119 | month = {June},
120 | year = {2020}
121 | }
122 | ```
123 |
124 | For any questions, please contact Yassine Ouali.
125 |
126 | #### Config file details ⚙️
127 |
128 | Below we detail the CCT parameters that can be controlled in the config file `configs/config.json`; the rest of the parameters
129 | are self-explanatory.
130 |
131 | ```javascript
132 | {
133 | "name": "CCT",
134 | "experim_name": "CCT", // The name the results will take (html and the folder in /saved)
135 | "n_gpu": 1, // Number of GPUs
136 | "n_labeled_examples": 1000, // Number of labeled examples (choices are 60, 100, 200,
137 | // 300, 500, 800, 1000, 1464, and the splits are in dataloaders/voc_splits)
138 | "diff_lrs": true,
139 | "ramp_up": 0.1, // The unsupervised loss will be slowly scaled up in the first 10% of Training time
140 | "unsupervised_w": 30, // Weighting of the unsupervised loss
141 | "ignore_index": 255,
142 | "lr_scheduler": "Poly",
143 | "use_weak_labels": false, // If the pseudo-labels were generated, we can use them to train the aux. decoders
144 | "weakly_loss_w": 0.4, // Weighting of the weakly-supervised loss
145 | "pretrained": true,
146 |
147 | "model":{
148 | "supervised": true, // Supervised setting (training only on the labeled examples)
149 | "semi": false, // Semi-supervised setting
150 | "supervised_w": 1, // Weighting of the supervised loss
151 |
152 | "sup_loss": "CE", // supervised loss, choices are CE and ab-CE = ["CE", "ABCE"]
153 |     "un_loss": "MSE", // unsupervised loss, choices are MSE and KL-divergence = ["MSE", "KL"]
154 |
155 | "softmax_temp": 1,
156 | "aux_constraint": false, // Pair-wise loss (sup. mat.)
157 | "aux_constraint_w": 1,
158 | "confidence_masking": false, // Confidence masking (sup. mat.)
159 | "confidence_th": 0.5,
160 |
161 | "drop": 6, // Number of DropOut decoders
162 | "drop_rate": 0.5, // Dropout probability
163 | "spatial": true,
164 |
165 | "cutout": 6, // Number of G-Cutout decoders
166 | "erase": 0.4, // We drop 40% of the area
167 |
168 | "vat": 2, // Number of I-VAT decoders
169 | "xi": 1e-6, // VAT parameters
170 | "eps": 2.0,
171 |
172 | "context_masking": 2, // Number of Con-Msk decoders
173 | "object_masking": 2, // Number of Obj-Msk decoders
174 | "feature_drop": 6, // Number of F-Drop decoders
175 |
176 | "feature_noise": 6, // Number of F-Noise decoders
177 | "uniform_range": 0.3 // The range of the noise
178 | },
179 | ```
180 |
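To make the interaction between `ramp_up` and `unsupervised_w` concrete, here is a minimal sketch of the kind of sigmoid ramp-up commonly used for the unsupervised loss weight in consistency training. The repo's own schedule lives in `utils/ramps.py` and the trainer (not reproduced here) and may differ in its details; the function and numbers below are purely illustrative.

```python
import numpy as np

def sigmoid_rampup(current, rampup_length):
    # Common exp(-5 * (1 - t)^2) ramp-up; returns a factor in [0, 1].
    if rampup_length == 0:
        return 1.0
    current = np.clip(current, 0.0, rampup_length)
    phase = 1.0 - current / rampup_length
    return float(np.exp(-5.0 * phase * phase))

# With "ramp_up": 0.1 and 80 epochs, the weight reaches "unsupervised_w" (30)
# after roughly the first 8 epochs:
rampup_epochs = 0.1 * 80
for epoch in [0, 2, 4, 6, 8, 20]:
    w = 30 * sigmoid_rampup(epoch, rampup_epochs)
    print(f"epoch {epoch:2d}: unsupervised weight ~ {w:.1f}")
```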
181 | #### Acknowledgements
182 |
183 | - Pseudo-labels generation is based on Jiwoon Ahn's implementation [irn](https://github.com/jiwoon-ahn/irn).
184 | - Code structure was based on [Pytorch-Template](https://github.com/victoresque/pytorch-template/blob/master/README.m)
185 | - ResNet backbone was downloaded from [torchcv](https://github.com/donnyyou/torchcv)
186 |
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_dataloader import *
2 | from .base_dataset import *
3 | from .base_model import *
4 | from .base_trainer import *
5 |
6 |
7 |
--------------------------------------------------------------------------------
/base/base_dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from torch.utils.data.sampler import SubsetRandomSampler
6 |
7 | class BaseDataLoader(DataLoader):
8 | def __init__(self, dataset, batch_size, shuffle, num_workers, val_split = 0.0):
9 | self.shuffle = shuffle
10 | self.dataset = dataset
11 | self.nbr_examples = len(dataset)
12 | if val_split:
13 | self.train_sampler, self.val_sampler = self._split_sampler(val_split)
14 | else:
15 | self.train_sampler, self.val_sampler = None, None
16 |
17 | self.init_kwargs = {
18 | 'dataset': self.dataset,
19 | 'batch_size': batch_size,
20 | 'shuffle': self.shuffle,
21 | 'num_workers': num_workers,
22 | 'pin_memory': True
23 | }
24 | super(BaseDataLoader, self).__init__(sampler=self.train_sampler, **self.init_kwargs)
25 |
26 | def _split_sampler(self, split):
27 | if split == 0.0:
28 | return None, None
29 |
30 | self.shuffle = False
31 |
32 | split_indx = int(self.nbr_examples * split)
33 | np.random.seed(0)
34 |
35 | indxs = np.arange(self.nbr_examples)
36 | np.random.shuffle(indxs)
37 | train_indxs = indxs[split_indx:]
38 | val_indxs = indxs[:split_indx]
39 | self.nbr_examples = len(train_indxs)
40 |
41 | train_sampler = SubsetRandomSampler(train_indxs)
42 | val_sampler = SubsetRandomSampler(val_indxs)
43 | return train_sampler, val_sampler
44 |
45 | def get_val_loader(self):
46 | if self.val_sampler is None:
47 | return None
48 | return DataLoader(sampler=self.val_sampler, **self.init_kwargs)
49 |
--------------------------------------------------------------------------------
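Usage note for `base/base_dataloader.py` above: a minimal sketch (with a dummy `TensorDataset`, not something from the repo) of how the `val_split` mechanism produces a train loader plus a held-out validation loader. The repo's own loaders (e.g. `dataloaders/voc.py`) pass `val_split=None` and rely on a dedicated `val` split instead.

```python
import torch
from torch.utils.data import TensorDataset
from base import BaseDataLoader

# Hypothetical dataset: 100 fake images with random labels.
dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 21, (100,)))

loader = BaseDataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, val_split=0.2)
val_loader = loader.get_val_loader()  # 20% of the indices, drawn via SubsetRandomSampler

print(len(loader.train_sampler), len(val_loader.sampler))  # 80 20
```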
/base/base_dataset.py:
--------------------------------------------------------------------------------
1 | import random, math
2 | import numpy as np
3 | import cv2
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import Dataset
7 | from PIL import Image
8 | from torchvision import transforms
9 | from scipy import ndimage
10 | from math import ceil
11 |
12 | class BaseDataSet(Dataset):
13 | def __init__(self, data_dir, split, mean, std, ignore_index, base_size=None, augment=True, val=False,
14 | jitter=False, use_weak_lables=False, weak_labels_output=None, crop_size=None, scale=False, flip=False, rotate=False,
15 | blur=False, return_id=False, n_labeled_examples=None):
16 |
17 | self.root = data_dir
18 | self.split = split
19 | self.mean = mean
20 | self.std = std
21 | self.augment = augment
22 | self.crop_size = crop_size
23 | self.jitter = jitter
24 | self.image_padding = (np.array(mean)*255.).tolist()
25 | self.ignore_index = ignore_index
26 | self.return_id = return_id
27 | self.n_labeled_examples = n_labeled_examples
28 | self.val = val
29 |
30 | self.use_weak_lables = use_weak_lables
31 | self.weak_labels_output = weak_labels_output
32 |
33 | if self.augment:
34 | self.base_size = base_size
35 | self.scale = scale
36 | self.flip = flip
37 | self.rotate = rotate
38 | self.blur = blur
39 |
40 | self.jitter_tf = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
41 | self.to_tensor = transforms.ToTensor()
42 | self.normalize = transforms.Normalize(mean, std)
43 |
44 | self.files = []
45 | self._set_files()
46 |
47 | cv2.setNumThreads(0)
48 |
49 | def _set_files(self):
50 | raise NotImplementedError
51 |
52 | def _load_data(self, index):
53 | raise NotImplementedError
54 |
55 | def _rotate(self, image, label):
56 | # Rotate the image with an angle between -10 and 10
57 | h, w, _ = image.shape
58 | angle = random.randint(-10, 10)
59 | center = (w / 2, h / 2)
60 | rot_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
61 | image = cv2.warpAffine(image, rot_matrix, (w, h), flags=cv2.INTER_CUBIC)#, borderMode=cv2.BORDER_REFLECT)
62 | label = cv2.warpAffine(label, rot_matrix, (w, h), flags=cv2.INTER_NEAREST)#, borderMode=cv2.BORDER_REFLECT)
63 | return image, label
64 |
65 | def _crop(self, image, label):
66 | # Padding to return the correct crop size
67 | if (isinstance(self.crop_size, list) or isinstance(self.crop_size, tuple)) and len(self.crop_size) == 2:
68 | crop_h, crop_w = self.crop_size
69 | elif isinstance(self.crop_size, int):
70 | crop_h, crop_w = self.crop_size, self.crop_size
71 | else:
72 | raise ValueError
73 |
74 | h, w, _ = image.shape
75 | pad_h = max(crop_h - h, 0)
76 | pad_w = max(crop_w - w, 0)
77 | pad_kwargs = {
78 | "top": 0,
79 | "bottom": pad_h,
80 | "left": 0,
81 | "right": pad_w,
82 | "borderType": cv2.BORDER_CONSTANT,}
83 | if pad_h > 0 or pad_w > 0:
84 | image = cv2.copyMakeBorder(image, value=self.image_padding, **pad_kwargs)
85 | label = cv2.copyMakeBorder(label, value=self.ignore_index, **pad_kwargs)
86 |
87 | # Cropping
88 | h, w, _ = image.shape
89 | start_h = random.randint(0, h - crop_h)
90 | start_w = random.randint(0, w - crop_w)
91 | end_h = start_h + crop_h
92 | end_w = start_w + crop_w
93 | image = image[start_h:end_h, start_w:end_w]
94 | label = label[start_h:end_h, start_w:end_w]
95 | return image, label
96 |
97 | def _blur(self, image, label):
98 |         # Gaussian Blur (sigma between 0 and 1.5)
99 | sigma = random.random() * 1.5
100 | ksize = int(3.3 * sigma)
101 | ksize = ksize + 1 if ksize % 2 == 0 else ksize
102 | image = cv2.GaussianBlur(image, (ksize, ksize), sigmaX=sigma, sigmaY=sigma, borderType=cv2.BORDER_REFLECT_101)
103 | return image, label
104 |
105 | def _flip(self, image, label):
106 | # Random H flip
107 | if random.random() > 0.5:
108 | image = np.fliplr(image).copy()
109 | label = np.fliplr(label).copy()
110 | return image, label
111 |
112 | def _resize(self, image, label, bigger_side_to_base_size=True):
113 | if isinstance(self.base_size, int):
114 | h, w, _ = image.shape
115 | if self.scale:
116 | longside = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
117 | #longside = random.randint(int(self.base_size*0.5), int(self.base_size*1))
118 | else:
119 | longside = self.base_size
120 |
121 | if bigger_side_to_base_size:
122 | h, w = (longside, int(1.0 * longside * w / h + 0.5)) if h > w else (int(1.0 * longside * h / w + 0.5), longside)
123 | else:
124 | h, w = (longside, int(1.0 * longside * w / h + 0.5)) if h < w else (int(1.0 * longside * h / w + 0.5), longside)
125 | image = np.asarray(Image.fromarray(np.uint8(image)).resize((w, h), Image.BICUBIC))
126 | label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST)
127 | return image, label
128 |
129 | elif (isinstance(self.base_size, list) or isinstance(self.base_size, tuple)) and len(self.base_size) == 2:
130 | h, w, _ = image.shape
131 | if self.scale:
132 | scale = random.random() * 1.5 + 0.5 # Scaling between [0.5, 2]
133 | h, w = int(self.base_size[0] * scale), int(self.base_size[1] * scale)
134 | else:
135 | h, w = self.base_size
136 | image = np.asarray(Image.fromarray(np.uint8(image)).resize((w, h), Image.BICUBIC))
137 | label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST)
138 | return image, label
139 |
140 | else:
141 | raise ValueError
142 |
143 | def _val_augmentation(self, image, label):
144 | if self.base_size is not None:
145 | image, label = self._resize(image, label)
146 | image = self.normalize(self.to_tensor(Image.fromarray(np.uint8(image))))
147 | return image, label
148 |
149 | image = self.normalize(self.to_tensor(Image.fromarray(np.uint8(image))))
150 | return image, label
151 |
152 | def _augmentation(self, image, label):
153 | h, w, _ = image.shape
154 |
155 | if self.base_size is not None:
156 | image, label = self._resize(image, label)
157 |
158 | if self.crop_size is not None:
159 | image, label = self._crop(image, label)
160 |
161 | if self.flip:
162 | image, label = self._flip(image, label)
163 |
164 | image = Image.fromarray(np.uint8(image))
165 | image = self.jitter_tf(image) if self.jitter else image
166 |
167 | return self.normalize(self.to_tensor(image)), label
168 |
169 | def __len__(self):
170 | return len(self.files)
171 |
172 | def __getitem__(self, index):
173 | image, label, image_id = self._load_data(index)
174 | if self.val:
175 | image, label = self._val_augmentation(image, label)
176 | elif self.augment:
177 | image, label = self._augmentation(image, label)
178 |
179 | label = torch.from_numpy(np.array(label, dtype=np.int32)).long()
180 | return image, label
181 |
182 | def __repr__(self):
183 | fmt_str = "Dataset: " + self.__class__.__name__ + "\n"
184 | fmt_str += " # data: {}\n".format(self.__len__())
185 | fmt_str += " Split: {}\n".format(self.split)
186 | fmt_str += " Root: {}".format(self.root)
187 | return fmt_str
188 |
189 |
--------------------------------------------------------------------------------
/base/base_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | class BaseModel(nn.Module):
6 | def __init__(self):
7 | super(BaseModel, self).__init__()
8 | self.logger = logging.getLogger(self.__class__.__name__)
9 |
10 | def forward(self):
11 | raise NotImplementedError
12 |
13 | def summary(self):
14 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
15 | nbr_params = sum([np.prod(p.size()) for p in model_parameters])
16 | self.logger.info(f'Nbr of trainable parameters: {nbr_params}')
17 |
18 | def __str__(self):
19 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
20 | nbr_params = int(sum([np.prod(p.size()) for p in model_parameters]))
21 | return f'\nNbr of trainable parameters: {nbr_params}'
22 | #return super(BaseModel, self).__str__() + f'\nNbr of trainable parameters: {nbr_params}'
23 |
--------------------------------------------------------------------------------
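Usage note for `base/base_model.py` above: a small sketch showing what `summary()` and `__str__` report; the `TinyHead` module is made up for illustration and is not part of the repo.

```python
import logging
import torch.nn as nn
from base import BaseModel

logging.basicConfig(level=logging.INFO)

class TinyHead(BaseModel):  # hypothetical module, for illustration only
    def __init__(self, in_ch=256, num_classes=21):
        super().__init__()
        self.classifier = nn.Conv2d(in_ch, num_classes, kernel_size=1)

    def forward(self, x):
        return self.classifier(x)

m = TinyHead()
m.summary()  # logs "Nbr of trainable parameters: 5397" (256*21 weights + 21 biases)
print(m)     # __str__ reports the same count
```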
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | import os, json, math, logging, sys, datetime
2 | import torch
3 | from torch.utils import tensorboard
4 | from utils import helpers
5 | from utils import logger
6 | import utils.lr_scheduler
7 | from utils.htmlwriter import HTML
8 |
9 | def get_instance(module, name, config, *args):
10 | return getattr(module, config[name]['type'])(*args, **config[name]['args'])
11 |
12 | class BaseTrainer:
13 | def __init__(self, model, resume, config, iters_per_epoch, train_logger=None):
14 | self.model = model
15 | self.config = config
16 |
17 | self.train_logger = train_logger
18 | self.logger = logging.getLogger(self.__class__.__name__)
19 | self.do_validation = self.config['trainer']['val']
20 | self.start_epoch = 1
21 | self.improved = False
22 |
23 | # SETTING THE DEVICE
24 | self.device, availble_gpus = self._get_available_devices(self.config['n_gpu'])
25 | self.model = torch.nn.DataParallel(self.model, device_ids=availble_gpus)
26 | self.model.to(self.device)
27 |
28 | # CONFIGS
29 | cfg_trainer = self.config['trainer']
30 | self.epochs = cfg_trainer['epochs']
31 | self.save_period = cfg_trainer['save_period']
32 |
33 | # OPTIMIZER
34 | trainable_params = [{'params': filter(lambda p:p.requires_grad, self.model.module.get_other_params())},
35 | {'params': filter(lambda p:p.requires_grad, self.model.module.get_backbone_params()),
36 | 'lr': config['optimizer']['args']['lr'] / 10}]
37 |
38 | self.optimizer = get_instance(torch.optim, 'optimizer', config, trainable_params)
39 | model_params = sum([i.shape.numel() for i in list(model.parameters())])
40 | opt_params = sum([i.shape.numel() for j in self.optimizer.param_groups for i in j['params']])
41 | assert opt_params == model_params, 'some params are missing in the opt'
42 |
43 | self.lr_scheduler = getattr(utils.lr_scheduler, config['lr_scheduler'])(optimizer=self.optimizer, num_epochs=self.epochs,
44 | iters_per_epoch=iters_per_epoch)
45 |
46 | # MONITORING
47 | self.monitor = cfg_trainer.get('monitor', 'off')
48 | if self.monitor == 'off':
49 | self.mnt_mode = 'off'
50 | self.mnt_best = 0
51 | else:
52 | self.mnt_mode, self.mnt_metric = self.monitor.split()
53 | assert self.mnt_mode in ['min', 'max']
54 | self.mnt_best = -math.inf if self.mnt_mode == 'max' else math.inf
55 | self.early_stoping = cfg_trainer.get('early_stop', math.inf)
56 |
57 |         # CHECKPOINTS & TENSORBOARD
58 | date_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
59 | run_name = config['experim_name']
60 | self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'], run_name)
61 | helpers.dir_exists(self.checkpoint_dir)
62 | config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
63 | with open(config_save_path, 'w') as handle:
64 | json.dump(self.config, handle, indent=4, sort_keys=True)
65 |
66 | writer_dir = os.path.join(cfg_trainer['log_dir'], run_name)
67 | self.writer = tensorboard.SummaryWriter(writer_dir)
68 | self.html_results = HTML(web_dir=config['trainer']['save_dir'], exp_name=config['experim_name'],
69 | save_name=config['experim_name'], config=config, resume=resume)
70 |
71 | if resume: self._resume_checkpoint(resume)
72 |
73 | def _get_available_devices(self, n_gpu):
74 | sys_gpu = torch.cuda.device_count()
75 | if sys_gpu == 0:
76 | self.logger.warning('No GPUs detected, using the CPU')
77 | n_gpu = 0
78 | elif n_gpu > sys_gpu:
79 | self.logger.warning(f'Nbr of GPU requested is {n_gpu} but only {sys_gpu} are available')
80 | n_gpu = sys_gpu
81 |
82 | device = torch.device('cuda:0' if n_gpu > 0 else 'cpu')
83 | self.logger.info(f'Detected GPUs: {sys_gpu} Requested: {n_gpu}')
84 | available_gpus = list(range(n_gpu))
85 | return device, available_gpus
86 |
87 |
88 |
89 | def train(self):
90 | for epoch in range(self.start_epoch, self.epochs+1):
91 | results = self._train_epoch(epoch)
92 | if self.do_validation and epoch % self.config['trainer']['val_per_epochs'] == 0:
93 | results = self._valid_epoch(epoch)
94 | self.logger.info('\n\n')
95 | for k, v in results.items():
96 | self.logger.info(f' {str(k):15s}: {v}')
97 |
98 | if self.train_logger is not None:
99 | log = {'epoch' : epoch, **results}
100 | self.train_logger.add_entry(log)
101 |
102 | # CHECKING IF THIS IS THE BEST MODEL (ONLY FOR VAL)
103 | if self.mnt_mode != 'off' and epoch % self.config['trainer']['val_per_epochs'] == 0:
104 | try:
105 | if self.mnt_mode == 'min': self.improved = (log[self.mnt_metric] < self.mnt_best)
106 | else: self.improved = (log[self.mnt_metric] > self.mnt_best)
107 | except KeyError:
108 |                     self.logger.warning(f'The metric being tracked ({self.mnt_metric}) has not been calculated. Training stops.')
109 | break
110 |
111 | if self.improved:
112 | self.mnt_best = log[self.mnt_metric]
113 | self.not_improved_count = 0
114 | else:
115 | self.not_improved_count += 1
116 |
117 | if self.not_improved_count > self.early_stoping:
118 | self.logger.info(f'\nPerformance didn\'t improve for {self.early_stoping} epochs')
119 |                     self.logger.warning('Training Stopped')
120 | break
121 |
122 | # SAVE CHECKPOINT
123 | if epoch % self.save_period == 0:
124 | self._save_checkpoint(epoch, save_best=self.improved)
125 | self.html_results.save()
126 |
127 |
128 | def _save_checkpoint(self, epoch, save_best=False):
129 | state = {
130 | 'arch': type(self.model).__name__,
131 | 'epoch': epoch,
132 | 'state_dict': self.model.state_dict(),
133 | 'monitor_best': self.mnt_best,
134 | 'config': self.config
135 | }
136 |
137 | filename = os.path.join(self.checkpoint_dir, f'checkpoint.pth')
138 | self.logger.info(f'\nSaving a checkpoint: {filename} ...')
139 | torch.save(state, filename)
140 |
141 | if save_best:
142 | filename = os.path.join(self.checkpoint_dir, f'best_model.pth')
143 | torch.save(state, filename)
144 | self.logger.info("Saving current best: best_model.pth")
145 |
146 | def _resume_checkpoint(self, resume_path):
147 | self.logger.info(f'Loading checkpoint : {resume_path}')
148 | checkpoint = torch.load(resume_path)
149 | self.start_epoch = checkpoint['epoch'] + 1
150 | self.mnt_best = checkpoint['monitor_best']
151 | self.not_improved_count = 0
152 |
153 | try:
154 | self.model.load_state_dict(checkpoint['state_dict'])
155 | except Exception as e:
156 | print(f'Error when loading: {e}')
157 | self.model.load_state_dict(checkpoint['state_dict'], strict=False)
158 |
159 | if "logger" in checkpoint.keys():
160 | self.train_logger = checkpoint['logger']
161 | self.logger.info(f'Checkpoint <{resume_path}> (epoch {self.start_epoch}) was loaded')
162 |
163 | def _train_epoch(self, epoch):
164 | raise NotImplementedError
165 |
166 | def _valid_epoch(self, epoch):
167 | raise NotImplementedError
168 |
169 | def _eval_metrics(self, output, target):
170 | raise NotImplementedError
171 |
--------------------------------------------------------------------------------
/configs/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "CCT",
3 | "experim_name": "CCT",
4 | "n_gpu": 1,
5 | "n_labeled_examples": 1464,
6 | "diff_lrs": true,
7 | "ramp_up": 0.1,
8 | "unsupervised_w": 30,
9 | "ignore_index": 255,
10 | "lr_scheduler": "Poly",
11 | "use_weak_lables":false,
12 | "weakly_loss_w": 0.4,
13 | "pretrained": true,
14 |
15 | "model":{
16 | "supervised": false,
17 | "semi": true,
18 | "supervised_w": 1,
19 |
20 | "sup_loss": "CE",
21 | "un_loss": "MSE",
22 |
23 | "softmax_temp": 1,
24 | "aux_constraint": false,
25 | "aux_constraint_w": 1,
26 | "confidence_masking": false,
27 | "confidence_th": 0.5,
28 |
29 | "drop": 6,
30 | "drop_rate": 0.5,
31 | "spatial": true,
32 |
33 | "cutout": 6,
34 | "erase": 0.4,
35 |
36 | "vat": 2,
37 | "xi": 1e-6,
38 | "eps": 2.0,
39 |
40 | "context_masking": 2,
41 | "object_masking": 2,
42 | "feature_drop": 6,
43 |
44 | "feature_noise": 6,
45 | "uniform_range": 0.3
46 | },
47 |
48 |
49 | "optimizer": {
50 | "type": "SGD",
51 | "args":{
52 | "lr": 1e-2,
53 | "weight_decay": 1e-4,
54 | "momentum": 0.9
55 | }
56 | },
57 |
58 |
59 | "train_supervised": {
60 | "data_dir": "VOCtrainval_11-May-2012",
61 | "batch_size": 10,
62 | "crop_size": 320,
63 | "shuffle": true,
64 | "base_size": 400,
65 | "scale": true,
66 | "augment": true,
67 | "flip": true,
68 | "rotate": false,
69 | "blur": false,
70 | "split": "train_supervised",
71 | "num_workers": 8
72 | },
73 |
74 | "train_unsupervised": {
75 | "data_dir": "VOCtrainval_11-May-2012",
76 | "weak_labels_output": "pseudo_labels/result/pseudo_labels",
77 | "batch_size": 10,
78 | "crop_size": 320,
79 | "shuffle": true,
80 | "base_size": 400,
81 | "scale": true,
82 | "augment": true,
83 | "flip": true,
84 | "rotate": false,
85 | "blur": false,
86 | "split": "train_unsupervised",
87 | "num_workers": 8
88 | },
89 |
90 | "val_loader": {
91 | "data_dir": "VOCtrainval_11-May-2012",
92 | "batch_size": 1,
93 | "val": true,
94 | "split": "val",
95 | "shuffle": false,
96 | "num_workers": 4
97 | },
98 |
99 | "trainer": {
100 | "epochs": 80,
101 | "save_dir": "saved/",
102 | "save_period": 5,
103 |
104 | "monitor": "max Mean_IoU",
105 | "early_stop": 10,
106 |
107 | "tensorboardX": true,
108 | "log_dir": "saved/",
109 | "log_per_iter": 20,
110 |
111 | "val": true,
112 | "val_per_epochs": 5
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
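The `"optimizer"` block above is consumed by the `get_instance` helper defined in `base/base_trainer.py`. A minimal sketch of that mapping, with a placeholder parameter list standing in for the real model parameters:

```python
import json
import torch

def get_instance(module, name, config, *args):
    # Same helper as in base/base_trainer.py: look up the class named in
    # config[name]['type'] inside `module` and build it with config[name]['args'].
    return getattr(module, config[name]['type'])(*args, **config[name]['args'])

with open('configs/config.json') as f:
    config = json.load(f)

params = [torch.nn.Parameter(torch.zeros(1))]  # placeholder parameter group
optimizer = get_instance(torch.optim, 'optimizer', config, params)
print(optimizer)  # SGD with lr=0.01, momentum=0.9, weight_decay=0.0001
```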
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .voc import VOC
--------------------------------------------------------------------------------
/dataloaders/voc.py:
--------------------------------------------------------------------------------
1 | from base import BaseDataSet, BaseDataLoader
2 | from utils import pallete
3 | import numpy as np
4 | import os
5 | import scipy
6 | import torch
7 | from PIL import Image
8 | import cv2
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms
11 | import json
12 |
13 | class VOCDataset(BaseDataSet):
14 | def __init__(self, **kwargs):
15 | self.num_classes = 21
16 |
17 | self.palette = pallete.get_voc_pallete(self.num_classes)
18 | super(VOCDataset, self).__init__(**kwargs)
19 |
20 | def _set_files(self):
21 | self.root = os.path.join(self.root, 'VOCdevkit/VOC2012')
22 | if self.split == "val":
23 | file_list = os.path.join("dataloaders/voc_splits", f"{self.split}" + ".txt")
24 | elif self.split in ["train_supervised", "train_unsupervised"]:
25 | file_list = os.path.join("dataloaders/voc_splits", f"{self.n_labeled_examples}_{self.split}" + ".txt")
26 | else:
27 | raise ValueError(f"Invalid split name {self.split}")
28 |
29 | file_list = [line.rstrip().split(' ') for line in tuple(open(file_list, "r"))]
30 | self.files, self.labels = list(zip(*file_list))
31 |
32 | def _load_data(self, index):
33 | image_path = os.path.join(self.root, self.files[index][1:])
34 | image = np.asarray(Image.open(image_path), dtype=np.float32)
35 | image_id = self.files[index].split("/")[-1].split(".")[0]
36 | if self.use_weak_lables:
37 | label_path = os.path.join(self.weak_labels_output, image_id+".png")
38 | else:
39 | label_path = os.path.join(self.root, self.labels[index][1:])
40 | label = np.asarray(Image.open(label_path), dtype=np.int32)
41 | return image, label, image_id
42 |
43 | class VOC(BaseDataLoader):
44 | def __init__(self, kwargs):
45 |
46 | self.MEAN = [0.485, 0.456, 0.406]
47 | self.STD = [0.229, 0.224, 0.225]
48 | self.batch_size = kwargs.pop('batch_size')
49 | kwargs['mean'] = self.MEAN
50 | kwargs['std'] = self.STD
51 | kwargs['ignore_index'] = 255
52 | try:
53 | shuffle = kwargs.pop('shuffle')
54 |         except KeyError:
55 | shuffle = False
56 | num_workers = kwargs.pop('num_workers')
57 |
58 | self.dataset = VOCDataset(**kwargs)
59 |
60 | super(VOC, self).__init__(self.dataset, self.batch_size, shuffle, num_workers, val_split=None)
61 |
--------------------------------------------------------------------------------
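Usage note for `dataloaders/voc.py` above: a minimal sketch of building the supervised train loader from `configs/config.json`. The actual wiring is done in `train.py` (not included in this listing) and may differ; here we simply forward the `train_supervised` block plus `n_labeled_examples` into `VOC`, which expects a kwargs dict passed positionally.

```python
import json
from dataloaders import VOC

with open('configs/config.json') as f:
    config = json.load(f)

kwargs = dict(config['train_supervised'])  # data_dir, batch_size, crop_size, split, ...
kwargs['n_labeled_examples'] = config['n_labeled_examples']
supervised_loader = VOC(kwargs)

for images, labels in supervised_loader:
    # images: (B, 3, 320, 320) normalized tensors; labels: (B, 320, 320) with 255 = ignore_index
    break
```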
/dataloaders/voc_splits/100_train_supervised.txt:
--------------------------------------------------------------------------------
1 | /JPEGImages/2007_000032.jpg /SegmentationClassAug/2007_000032.png
2 | /JPEGImages/2007_000039.jpg /SegmentationClassAug/2007_000039.png
3 | /JPEGImages/2007_000063.jpg /SegmentationClassAug/2007_000063.png
4 | /JPEGImages/2007_000068.jpg /SegmentationClassAug/2007_000068.png
5 | /JPEGImages/2007_000121.jpg /SegmentationClassAug/2007_000121.png
6 | /JPEGImages/2007_000170.jpg /SegmentationClassAug/2007_000170.png
7 | /JPEGImages/2007_000241.jpg /SegmentationClassAug/2007_000241.png
8 | /JPEGImages/2007_000243.jpg /SegmentationClassAug/2007_000243.png
9 | /JPEGImages/2007_000250.jpg /SegmentationClassAug/2007_000250.png
10 | /JPEGImages/2007_000256.jpg /SegmentationClassAug/2007_000256.png
11 | /JPEGImages/2007_000333.jpg /SegmentationClassAug/2007_000333.png
12 | /JPEGImages/2007_000363.jpg /SegmentationClassAug/2007_000363.png
13 | /JPEGImages/2007_000364.jpg /SegmentationClassAug/2007_000364.png
14 | /JPEGImages/2007_000392.jpg /SegmentationClassAug/2007_000392.png
15 | /JPEGImages/2007_000480.jpg /SegmentationClassAug/2007_000480.png
16 | /JPEGImages/2007_000504.jpg /SegmentationClassAug/2007_000504.png
17 | /JPEGImages/2007_000515.jpg /SegmentationClassAug/2007_000515.png
18 | /JPEGImages/2007_000528.jpg /SegmentationClassAug/2007_000528.png
19 | /JPEGImages/2007_000549.jpg /SegmentationClassAug/2007_000549.png
20 | /JPEGImages/2007_000584.jpg /SegmentationClassAug/2007_000584.png
21 | /JPEGImages/2007_000645.jpg /SegmentationClassAug/2007_000645.png
22 | /JPEGImages/2007_000648.jpg /SegmentationClassAug/2007_000648.png
23 | /JPEGImages/2007_000713.jpg /SegmentationClassAug/2007_000713.png
24 | /JPEGImages/2007_000720.jpg /SegmentationClassAug/2007_000720.png
25 | /JPEGImages/2007_000733.jpg /SegmentationClassAug/2007_000733.png
26 | /JPEGImages/2007_000738.jpg /SegmentationClassAug/2007_000738.png
27 | /JPEGImages/2007_000768.jpg /SegmentationClassAug/2007_000768.png
28 | /JPEGImages/2007_000793.jpg /SegmentationClassAug/2007_000793.png
29 | /JPEGImages/2007_000822.jpg /SegmentationClassAug/2007_000822.png
30 | /JPEGImages/2007_000836.jpg /SegmentationClassAug/2007_000836.png
31 | /JPEGImages/2007_000876.jpg /SegmentationClassAug/2007_000876.png
32 | /JPEGImages/2007_000904.jpg /SegmentationClassAug/2007_000904.png
33 | /JPEGImages/2007_001027.jpg /SegmentationClassAug/2007_001027.png
34 | /JPEGImages/2007_001073.jpg /SegmentationClassAug/2007_001073.png
35 | /JPEGImages/2007_001149.jpg /SegmentationClassAug/2007_001149.png
36 | /JPEGImages/2007_001185.jpg /SegmentationClassAug/2007_001185.png
37 | /JPEGImages/2007_001225.jpg /SegmentationClassAug/2007_001225.png
38 | /JPEGImages/2007_001397.jpg /SegmentationClassAug/2007_001397.png
39 | /JPEGImages/2007_001416.jpg /SegmentationClassAug/2007_001416.png
40 | /JPEGImages/2007_001420.jpg /SegmentationClassAug/2007_001420.png
41 | /JPEGImages/2007_001439.jpg /SegmentationClassAug/2007_001439.png
42 | /JPEGImages/2007_001487.jpg /SegmentationClassAug/2007_001487.png
43 | /JPEGImages/2007_001595.jpg /SegmentationClassAug/2007_001595.png
44 | /JPEGImages/2007_001602.jpg /SegmentationClassAug/2007_001602.png
45 | /JPEGImages/2007_001609.jpg /SegmentationClassAug/2007_001609.png
46 | /JPEGImages/2007_001698.jpg /SegmentationClassAug/2007_001698.png
47 | /JPEGImages/2007_001704.jpg /SegmentationClassAug/2007_001704.png
48 | /JPEGImages/2007_001709.jpg /SegmentationClassAug/2007_001709.png
49 | /JPEGImages/2007_001724.jpg /SegmentationClassAug/2007_001724.png
50 | /JPEGImages/2007_001764.jpg /SegmentationClassAug/2007_001764.png
51 | /JPEGImages/2007_001825.jpg /SegmentationClassAug/2007_001825.png
52 | /JPEGImages/2007_001834.jpg /SegmentationClassAug/2007_001834.png
53 | /JPEGImages/2007_001857.jpg /SegmentationClassAug/2007_001857.png
54 | /JPEGImages/2007_001872.jpg /SegmentationClassAug/2007_001872.png
55 | /JPEGImages/2007_001901.jpg /SegmentationClassAug/2007_001901.png
56 | /JPEGImages/2007_001917.jpg /SegmentationClassAug/2007_001917.png
57 | /JPEGImages/2007_001960.jpg /SegmentationClassAug/2007_001960.png
58 | /JPEGImages/2007_002024.jpg /SegmentationClassAug/2007_002024.png
59 | /JPEGImages/2007_002055.jpg /SegmentationClassAug/2007_002055.png
60 | /JPEGImages/2007_002088.jpg /SegmentationClassAug/2007_002088.png
61 | /JPEGImages/2007_002099.jpg /SegmentationClassAug/2007_002099.png
62 | /JPEGImages/2007_002105.jpg /SegmentationClassAug/2007_002105.png
63 | /JPEGImages/2007_002212.jpg /SegmentationClassAug/2007_002212.png
64 | /JPEGImages/2007_002216.jpg /SegmentationClassAug/2007_002216.png
65 | /JPEGImages/2007_002227.jpg /SegmentationClassAug/2007_002227.png
66 | /JPEGImages/2007_002234.jpg /SegmentationClassAug/2007_002234.png
67 | /JPEGImages/2007_002281.jpg /SegmentationClassAug/2007_002281.png
68 | /JPEGImages/2007_002361.jpg /SegmentationClassAug/2007_002361.png
69 | /JPEGImages/2007_002368.jpg /SegmentationClassAug/2007_002368.png
70 | /JPEGImages/2007_002370.jpg /SegmentationClassAug/2007_002370.png
71 | /JPEGImages/2007_002462.jpg /SegmentationClassAug/2007_002462.png
72 | /JPEGImages/2007_002760.jpg /SegmentationClassAug/2007_002760.png
73 | /JPEGImages/2007_002845.jpg /SegmentationClassAug/2007_002845.png
74 | /JPEGImages/2007_002896.jpg /SegmentationClassAug/2007_002896.png
75 | /JPEGImages/2007_002953.jpg /SegmentationClassAug/2007_002953.png
76 | /JPEGImages/2007_002967.jpg /SegmentationClassAug/2007_002967.png
77 | /JPEGImages/2007_003178.jpg /SegmentationClassAug/2007_003178.png
78 | /JPEGImages/2007_003189.jpg /SegmentationClassAug/2007_003189.png
79 | /JPEGImages/2007_003190.jpg /SegmentationClassAug/2007_003190.png
80 | /JPEGImages/2007_003207.jpg /SegmentationClassAug/2007_003207.png
81 | /JPEGImages/2007_003251.jpg /SegmentationClassAug/2007_003251.png
82 | /JPEGImages/2007_003286.jpg /SegmentationClassAug/2007_003286.png
83 | /JPEGImages/2007_003525.jpg /SegmentationClassAug/2007_003525.png
84 | /JPEGImages/2007_003593.jpg /SegmentationClassAug/2007_003593.png
85 | /JPEGImages/2007_003604.jpg /SegmentationClassAug/2007_003604.png
86 | /JPEGImages/2007_003788.jpg /SegmentationClassAug/2007_003788.png
87 | /JPEGImages/2007_003815.jpg /SegmentationClassAug/2007_003815.png
88 | /JPEGImages/2007_004081.jpg /SegmentationClassAug/2007_004081.png
89 | /JPEGImages/2007_004627.jpg /SegmentationClassAug/2007_004627.png
90 | /JPEGImages/2007_004707.jpg /SegmentationClassAug/2007_004707.png
91 | /JPEGImages/2007_005210.jpg /SegmentationClassAug/2007_005210.png
92 | /JPEGImages/2007_005273.jpg /SegmentationClassAug/2007_005273.png
93 | /JPEGImages/2007_005902.jpg /SegmentationClassAug/2007_005902.png
94 | /JPEGImages/2007_006530.jpg /SegmentationClassAug/2007_006530.png
95 | /JPEGImages/2007_006581.jpg /SegmentationClassAug/2007_006581.png
96 | /JPEGImages/2007_006605.jpg /SegmentationClassAug/2007_006605.png
97 | /JPEGImages/2007_007432.jpg /SegmentationClassAug/2007_007432.png
98 | /JPEGImages/2007_009709.jpg /SegmentationClassAug/2007_009709.png
99 | /JPEGImages/2007_009788.jpg /SegmentationClassAug/2007_009788.png
100 | /JPEGImages/2008_000015.jpg /SegmentationClassAug/2008_000015.png
101 |
--------------------------------------------------------------------------------
/dataloaders/voc_splits/200_train_supervised.txt:
--------------------------------------------------------------------------------
1 | /JPEGImages/2007_000032.jpg /SegmentationClassAug/2007_000032.png
2 | /JPEGImages/2007_000039.jpg /SegmentationClassAug/2007_000039.png
3 | /JPEGImages/2007_000063.jpg /SegmentationClassAug/2007_000063.png
4 | /JPEGImages/2007_000068.jpg /SegmentationClassAug/2007_000068.png
5 | /JPEGImages/2007_000121.jpg /SegmentationClassAug/2007_000121.png
6 | /JPEGImages/2007_000170.jpg /SegmentationClassAug/2007_000170.png
7 | /JPEGImages/2007_000241.jpg /SegmentationClassAug/2007_000241.png
8 | /JPEGImages/2007_000243.jpg /SegmentationClassAug/2007_000243.png
9 | /JPEGImages/2007_000250.jpg /SegmentationClassAug/2007_000250.png
10 | /JPEGImages/2007_000256.jpg /SegmentationClassAug/2007_000256.png
11 | /JPEGImages/2007_000333.jpg /SegmentationClassAug/2007_000333.png
12 | /JPEGImages/2007_000363.jpg /SegmentationClassAug/2007_000363.png
13 | /JPEGImages/2007_000364.jpg /SegmentationClassAug/2007_000364.png
14 | /JPEGImages/2007_000392.jpg /SegmentationClassAug/2007_000392.png
15 | /JPEGImages/2007_000480.jpg /SegmentationClassAug/2007_000480.png
16 | /JPEGImages/2007_000504.jpg /SegmentationClassAug/2007_000504.png
17 | /JPEGImages/2007_000515.jpg /SegmentationClassAug/2007_000515.png
18 | /JPEGImages/2007_000528.jpg /SegmentationClassAug/2007_000528.png
19 | /JPEGImages/2007_000549.jpg /SegmentationClassAug/2007_000549.png
20 | /JPEGImages/2007_000584.jpg /SegmentationClassAug/2007_000584.png
21 | /JPEGImages/2007_000645.jpg /SegmentationClassAug/2007_000645.png
22 | /JPEGImages/2007_000648.jpg /SegmentationClassAug/2007_000648.png
23 | /JPEGImages/2007_000713.jpg /SegmentationClassAug/2007_000713.png
24 | /JPEGImages/2007_000720.jpg /SegmentationClassAug/2007_000720.png
25 | /JPEGImages/2007_000733.jpg /SegmentationClassAug/2007_000733.png
26 | /JPEGImages/2007_000738.jpg /SegmentationClassAug/2007_000738.png
27 | /JPEGImages/2007_000768.jpg /SegmentationClassAug/2007_000768.png
28 | /JPEGImages/2007_000793.jpg /SegmentationClassAug/2007_000793.png
29 | /JPEGImages/2007_000822.jpg /SegmentationClassAug/2007_000822.png
30 | /JPEGImages/2007_000836.jpg /SegmentationClassAug/2007_000836.png
31 | /JPEGImages/2007_000876.jpg /SegmentationClassAug/2007_000876.png
32 | /JPEGImages/2007_000904.jpg /SegmentationClassAug/2007_000904.png
33 | /JPEGImages/2007_001027.jpg /SegmentationClassAug/2007_001027.png
34 | /JPEGImages/2007_001073.jpg /SegmentationClassAug/2007_001073.png
35 | /JPEGImages/2007_001149.jpg /SegmentationClassAug/2007_001149.png
36 | /JPEGImages/2007_001185.jpg /SegmentationClassAug/2007_001185.png
37 | /JPEGImages/2007_001225.jpg /SegmentationClassAug/2007_001225.png
38 | /JPEGImages/2007_001340.jpg /SegmentationClassAug/2007_001340.png
39 | /JPEGImages/2007_001397.jpg /SegmentationClassAug/2007_001397.png
40 | /JPEGImages/2007_001416.jpg /SegmentationClassAug/2007_001416.png
41 | /JPEGImages/2007_001420.jpg /SegmentationClassAug/2007_001420.png
42 | /JPEGImages/2007_001439.jpg /SegmentationClassAug/2007_001439.png
43 | /JPEGImages/2007_001487.jpg /SegmentationClassAug/2007_001487.png
44 | /JPEGImages/2007_001595.jpg /SegmentationClassAug/2007_001595.png
45 | /JPEGImages/2007_001602.jpg /SegmentationClassAug/2007_001602.png
46 | /JPEGImages/2007_001609.jpg /SegmentationClassAug/2007_001609.png
47 | /JPEGImages/2007_001698.jpg /SegmentationClassAug/2007_001698.png
48 | /JPEGImages/2007_001704.jpg /SegmentationClassAug/2007_001704.png
49 | /JPEGImages/2007_001709.jpg /SegmentationClassAug/2007_001709.png
50 | /JPEGImages/2007_001724.jpg /SegmentationClassAug/2007_001724.png
51 | /JPEGImages/2007_001764.jpg /SegmentationClassAug/2007_001764.png
52 | /JPEGImages/2007_001825.jpg /SegmentationClassAug/2007_001825.png
53 | /JPEGImages/2007_001834.jpg /SegmentationClassAug/2007_001834.png
54 | /JPEGImages/2007_001857.jpg /SegmentationClassAug/2007_001857.png
55 | /JPEGImages/2007_001872.jpg /SegmentationClassAug/2007_001872.png
56 | /JPEGImages/2007_001901.jpg /SegmentationClassAug/2007_001901.png
57 | /JPEGImages/2007_001917.jpg /SegmentationClassAug/2007_001917.png
58 | /JPEGImages/2007_001960.jpg /SegmentationClassAug/2007_001960.png
59 | /JPEGImages/2007_002024.jpg /SegmentationClassAug/2007_002024.png
60 | /JPEGImages/2007_002055.jpg /SegmentationClassAug/2007_002055.png
61 | /JPEGImages/2007_002088.jpg /SegmentationClassAug/2007_002088.png
62 | /JPEGImages/2007_002099.jpg /SegmentationClassAug/2007_002099.png
63 | /JPEGImages/2007_002105.jpg /SegmentationClassAug/2007_002105.png
64 | /JPEGImages/2007_002107.jpg /SegmentationClassAug/2007_002107.png
65 | /JPEGImages/2007_002120.jpg /SegmentationClassAug/2007_002120.png
66 | /JPEGImages/2007_002142.jpg /SegmentationClassAug/2007_002142.png
67 | /JPEGImages/2007_002198.jpg /SegmentationClassAug/2007_002198.png
68 | /JPEGImages/2007_002212.jpg /SegmentationClassAug/2007_002212.png
69 | /JPEGImages/2007_002216.jpg /SegmentationClassAug/2007_002216.png
70 | /JPEGImages/2007_002227.jpg /SegmentationClassAug/2007_002227.png
71 | /JPEGImages/2007_002234.jpg /SegmentationClassAug/2007_002234.png
72 | /JPEGImages/2007_002273.jpg /SegmentationClassAug/2007_002273.png
73 | /JPEGImages/2007_002281.jpg /SegmentationClassAug/2007_002281.png
74 | /JPEGImages/2007_002293.jpg /SegmentationClassAug/2007_002293.png
75 | /JPEGImages/2007_002361.jpg /SegmentationClassAug/2007_002361.png
76 | /JPEGImages/2007_002368.jpg /SegmentationClassAug/2007_002368.png
77 | /JPEGImages/2007_002370.jpg /SegmentationClassAug/2007_002370.png
78 | /JPEGImages/2007_002403.jpg /SegmentationClassAug/2007_002403.png
79 | /JPEGImages/2007_002462.jpg /SegmentationClassAug/2007_002462.png
80 | /JPEGImages/2007_002488.jpg /SegmentationClassAug/2007_002488.png
81 | /JPEGImages/2007_002545.jpg /SegmentationClassAug/2007_002545.png
82 | /JPEGImages/2007_002611.jpg /SegmentationClassAug/2007_002611.png
83 | /JPEGImages/2007_002669.jpg /SegmentationClassAug/2007_002669.png
84 | /JPEGImages/2007_002760.jpg /SegmentationClassAug/2007_002760.png
85 | /JPEGImages/2007_002789.jpg /SegmentationClassAug/2007_002789.png
86 | /JPEGImages/2007_002845.jpg /SegmentationClassAug/2007_002845.png
87 | /JPEGImages/2007_002896.jpg /SegmentationClassAug/2007_002896.png
88 | /JPEGImages/2007_002953.jpg /SegmentationClassAug/2007_002953.png
89 | /JPEGImages/2007_002967.jpg /SegmentationClassAug/2007_002967.png
90 | /JPEGImages/2007_003000.jpg /SegmentationClassAug/2007_003000.png
91 | /JPEGImages/2007_003178.jpg /SegmentationClassAug/2007_003178.png
92 | /JPEGImages/2007_003189.jpg /SegmentationClassAug/2007_003189.png
93 | /JPEGImages/2007_003190.jpg /SegmentationClassAug/2007_003190.png
94 | /JPEGImages/2007_003207.jpg /SegmentationClassAug/2007_003207.png
95 | /JPEGImages/2007_003251.jpg /SegmentationClassAug/2007_003251.png
96 | /JPEGImages/2007_003267.jpg /SegmentationClassAug/2007_003267.png
97 | /JPEGImages/2007_003286.jpg /SegmentationClassAug/2007_003286.png
98 | /JPEGImages/2007_003330.jpg /SegmentationClassAug/2007_003330.png
99 | /JPEGImages/2007_003451.jpg /SegmentationClassAug/2007_003451.png
100 | /JPEGImages/2007_003525.jpg /SegmentationClassAug/2007_003525.png
101 | /JPEGImages/2007_003565.jpg /SegmentationClassAug/2007_003565.png
102 | /JPEGImages/2007_003593.jpg /SegmentationClassAug/2007_003593.png
103 | /JPEGImages/2007_003604.jpg /SegmentationClassAug/2007_003604.png
104 | /JPEGImages/2007_003668.jpg /SegmentationClassAug/2007_003668.png
105 | /JPEGImages/2007_003715.jpg /SegmentationClassAug/2007_003715.png
106 | /JPEGImages/2007_003778.jpg /SegmentationClassAug/2007_003778.png
107 | /JPEGImages/2007_003788.jpg /SegmentationClassAug/2007_003788.png
108 | /JPEGImages/2007_003815.jpg /SegmentationClassAug/2007_003815.png
109 | /JPEGImages/2007_003876.jpg /SegmentationClassAug/2007_003876.png
110 | /JPEGImages/2007_003889.jpg /SegmentationClassAug/2007_003889.png
111 | /JPEGImages/2007_003910.jpg /SegmentationClassAug/2007_003910.png
112 | /JPEGImages/2007_004003.jpg /SegmentationClassAug/2007_004003.png
113 | /JPEGImages/2007_004009.jpg /SegmentationClassAug/2007_004009.png
114 | /JPEGImages/2007_004065.jpg /SegmentationClassAug/2007_004065.png
115 | /JPEGImages/2007_004081.jpg /SegmentationClassAug/2007_004081.png
116 | /JPEGImages/2007_004166.jpg /SegmentationClassAug/2007_004166.png
117 | /JPEGImages/2007_004423.jpg /SegmentationClassAug/2007_004423.png
118 | /JPEGImages/2007_004481.jpg /SegmentationClassAug/2007_004481.png
119 | /JPEGImages/2007_004500.jpg /SegmentationClassAug/2007_004500.png
120 | /JPEGImages/2007_004537.jpg /SegmentationClassAug/2007_004537.png
121 | /JPEGImages/2007_004627.jpg /SegmentationClassAug/2007_004627.png
122 | /JPEGImages/2007_004663.jpg /SegmentationClassAug/2007_004663.png
123 | /JPEGImages/2007_004705.jpg /SegmentationClassAug/2007_004705.png
124 | /JPEGImages/2007_004707.jpg /SegmentationClassAug/2007_004707.png
125 | /JPEGImages/2007_004768.jpg /SegmentationClassAug/2007_004768.png
126 | /JPEGImages/2007_004810.jpg /SegmentationClassAug/2007_004810.png
127 | /JPEGImages/2007_004830.jpg /SegmentationClassAug/2007_004830.png
128 | /JPEGImages/2007_004948.jpg /SegmentationClassAug/2007_004948.png
129 | /JPEGImages/2007_004951.jpg /SegmentationClassAug/2007_004951.png
130 | /JPEGImages/2007_004998.jpg /SegmentationClassAug/2007_004998.png
131 | /JPEGImages/2007_005124.jpg /SegmentationClassAug/2007_005124.png
132 | /JPEGImages/2007_005130.jpg /SegmentationClassAug/2007_005130.png
133 | /JPEGImages/2007_005210.jpg /SegmentationClassAug/2007_005210.png
134 | /JPEGImages/2007_005212.jpg /SegmentationClassAug/2007_005212.png
135 | /JPEGImages/2007_005248.jpg /SegmentationClassAug/2007_005248.png
136 | /JPEGImages/2007_005262.jpg /SegmentationClassAug/2007_005262.png
137 | /JPEGImages/2007_005264.jpg /SegmentationClassAug/2007_005264.png
138 | /JPEGImages/2007_005266.jpg /SegmentationClassAug/2007_005266.png
139 | /JPEGImages/2007_005273.jpg /SegmentationClassAug/2007_005273.png
140 | /JPEGImages/2007_005314.jpg /SegmentationClassAug/2007_005314.png
141 | /JPEGImages/2007_005360.jpg /SegmentationClassAug/2007_005360.png
142 | /JPEGImages/2007_005647.jpg /SegmentationClassAug/2007_005647.png
143 | /JPEGImages/2007_005688.jpg /SegmentationClassAug/2007_005688.png
144 | /JPEGImages/2007_005878.jpg /SegmentationClassAug/2007_005878.png
145 | /JPEGImages/2007_005902.jpg /SegmentationClassAug/2007_005902.png
146 | /JPEGImages/2007_005951.jpg /SegmentationClassAug/2007_005951.png
147 | /JPEGImages/2007_006066.jpg /SegmentationClassAug/2007_006066.png
148 | /JPEGImages/2007_006134.jpg /SegmentationClassAug/2007_006134.png
149 | /JPEGImages/2007_006136.jpg /SegmentationClassAug/2007_006136.png
150 | /JPEGImages/2007_006151.jpg /SegmentationClassAug/2007_006151.png
151 | /JPEGImages/2007_006254.jpg /SegmentationClassAug/2007_006254.png
152 | /JPEGImages/2007_006281.jpg /SegmentationClassAug/2007_006281.png
153 | /JPEGImages/2007_006303.jpg /SegmentationClassAug/2007_006303.png
154 | /JPEGImages/2007_006317.jpg /SegmentationClassAug/2007_006317.png
155 | /JPEGImages/2007_006400.jpg /SegmentationClassAug/2007_006400.png
156 | /JPEGImages/2007_006409.jpg /SegmentationClassAug/2007_006409.png
157 | /JPEGImages/2007_006490.jpg /SegmentationClassAug/2007_006490.png
158 | /JPEGImages/2007_006530.jpg /SegmentationClassAug/2007_006530.png
159 | /JPEGImages/2007_006581.jpg /SegmentationClassAug/2007_006581.png
160 | /JPEGImages/2007_006605.jpg /SegmentationClassAug/2007_006605.png
161 | /JPEGImages/2007_006641.jpg /SegmentationClassAug/2007_006641.png
162 | /JPEGImages/2007_006660.jpg /SegmentationClassAug/2007_006660.png
163 | /JPEGImages/2007_006699.jpg /SegmentationClassAug/2007_006699.png
164 | /JPEGImages/2007_006704.jpg /SegmentationClassAug/2007_006704.png
165 | /JPEGImages/2007_006832.jpg /SegmentationClassAug/2007_006832.png
166 | /JPEGImages/2007_006899.jpg /SegmentationClassAug/2007_006899.png
167 | /JPEGImages/2007_006900.jpg /SegmentationClassAug/2007_006900.png
168 | /JPEGImages/2007_007098.jpg /SegmentationClassAug/2007_007098.png
169 | /JPEGImages/2007_007250.jpg /SegmentationClassAug/2007_007250.png
170 | /JPEGImages/2007_007398.jpg /SegmentationClassAug/2007_007398.png
171 | /JPEGImages/2007_007432.jpg /SegmentationClassAug/2007_007432.png
172 | /JPEGImages/2007_007530.jpg /SegmentationClassAug/2007_007530.png
173 | /JPEGImages/2007_007585.jpg /SegmentationClassAug/2007_007585.png
174 | /JPEGImages/2007_007890.jpg /SegmentationClassAug/2007_007890.png
175 | /JPEGImages/2007_007930.jpg /SegmentationClassAug/2007_007930.png
176 | /JPEGImages/2007_008140.jpg /SegmentationClassAug/2007_008140.png
177 | /JPEGImages/2007_008203.jpg /SegmentationClassAug/2007_008203.png
178 | /JPEGImages/2007_008468.jpg /SegmentationClassAug/2007_008468.png
179 | /JPEGImages/2007_008948.jpg /SegmentationClassAug/2007_008948.png
180 | /JPEGImages/2007_009216.jpg /SegmentationClassAug/2007_009216.png
181 | /JPEGImages/2007_009550.jpg /SegmentationClassAug/2007_009550.png
182 | /JPEGImages/2007_009605.jpg /SegmentationClassAug/2007_009605.png
183 | /JPEGImages/2007_009709.jpg /SegmentationClassAug/2007_009709.png
184 | /JPEGImages/2007_009788.jpg /SegmentationClassAug/2007_009788.png
185 | /JPEGImages/2007_009889.jpg /SegmentationClassAug/2007_009889.png
186 | /JPEGImages/2007_009899.jpg /SegmentationClassAug/2007_009899.png
187 | /JPEGImages/2008_000015.jpg /SegmentationClassAug/2008_000015.png
188 | /JPEGImages/2008_000043.jpg /SegmentationClassAug/2008_000043.png
189 | /JPEGImages/2008_000067.jpg /SegmentationClassAug/2008_000067.png
190 | /JPEGImages/2008_000133.jpg /SegmentationClassAug/2008_000133.png
191 | /JPEGImages/2008_000154.jpg /SegmentationClassAug/2008_000154.png
192 | /JPEGImages/2008_000188.jpg /SegmentationClassAug/2008_000188.png
193 | /JPEGImages/2008_000191.jpg /SegmentationClassAug/2008_000191.png
194 | /JPEGImages/2008_000194.jpg /SegmentationClassAug/2008_000194.png
195 | /JPEGImages/2008_000196.jpg /SegmentationClassAug/2008_000196.png
196 | /JPEGImages/2008_000272.jpg /SegmentationClassAug/2008_000272.png
197 | /JPEGImages/2008_000703.jpg /SegmentationClassAug/2008_000703.png
198 | /JPEGImages/2008_001225.jpg /SegmentationClassAug/2008_001225.png
199 | /JPEGImages/2008_001405.jpg /SegmentationClassAug/2008_001405.png
200 | /JPEGImages/2008_001744.jpg /SegmentationClassAug/2008_001744.png
201 |
--------------------------------------------------------------------------------
/dataloaders/voc_splits/300_train_supervised.txt:
--------------------------------------------------------------------------------
1 | /JPEGImages/2007_000032.jpg /SegmentationClassAug/2007_000032.png
2 | /JPEGImages/2007_000039.jpg /SegmentationClassAug/2007_000039.png
3 | /JPEGImages/2007_000063.jpg /SegmentationClassAug/2007_000063.png
4 | /JPEGImages/2007_000068.jpg /SegmentationClassAug/2007_000068.png
5 | /JPEGImages/2007_000121.jpg /SegmentationClassAug/2007_000121.png
6 | /JPEGImages/2007_000170.jpg /SegmentationClassAug/2007_000170.png
7 | /JPEGImages/2007_000241.jpg /SegmentationClassAug/2007_000241.png
8 | /JPEGImages/2007_000243.jpg /SegmentationClassAug/2007_000243.png
9 | /JPEGImages/2007_000250.jpg /SegmentationClassAug/2007_000250.png
10 | /JPEGImages/2007_000256.jpg /SegmentationClassAug/2007_000256.png
11 | /JPEGImages/2007_000333.jpg /SegmentationClassAug/2007_000333.png
12 | /JPEGImages/2007_000363.jpg /SegmentationClassAug/2007_000363.png
13 | /JPEGImages/2007_000364.jpg /SegmentationClassAug/2007_000364.png
14 | /JPEGImages/2007_000392.jpg /SegmentationClassAug/2007_000392.png
15 | /JPEGImages/2007_000480.jpg /SegmentationClassAug/2007_000480.png
16 | /JPEGImages/2007_000504.jpg /SegmentationClassAug/2007_000504.png
17 | /JPEGImages/2007_000515.jpg /SegmentationClassAug/2007_000515.png
18 | /JPEGImages/2007_000528.jpg /SegmentationClassAug/2007_000528.png
19 | /JPEGImages/2007_000549.jpg /SegmentationClassAug/2007_000549.png
20 | /JPEGImages/2007_000584.jpg /SegmentationClassAug/2007_000584.png
21 | /JPEGImages/2007_000645.jpg /SegmentationClassAug/2007_000645.png
22 | /JPEGImages/2007_000648.jpg /SegmentationClassAug/2007_000648.png
23 | /JPEGImages/2007_000713.jpg /SegmentationClassAug/2007_000713.png
24 | /JPEGImages/2007_000720.jpg /SegmentationClassAug/2007_000720.png
25 | /JPEGImages/2007_000733.jpg /SegmentationClassAug/2007_000733.png
26 | /JPEGImages/2007_000738.jpg /SegmentationClassAug/2007_000738.png
27 | /JPEGImages/2007_000768.jpg /SegmentationClassAug/2007_000768.png
28 | /JPEGImages/2007_000793.jpg /SegmentationClassAug/2007_000793.png
29 | /JPEGImages/2007_000822.jpg /SegmentationClassAug/2007_000822.png
30 | /JPEGImages/2007_000836.jpg /SegmentationClassAug/2007_000836.png
31 | /JPEGImages/2007_000876.jpg /SegmentationClassAug/2007_000876.png
32 | /JPEGImages/2007_000904.jpg /SegmentationClassAug/2007_000904.png
33 | /JPEGImages/2007_001027.jpg /SegmentationClassAug/2007_001027.png
34 | /JPEGImages/2007_001073.jpg /SegmentationClassAug/2007_001073.png
35 | /JPEGImages/2007_001149.jpg /SegmentationClassAug/2007_001149.png
36 | /JPEGImages/2007_001185.jpg /SegmentationClassAug/2007_001185.png
37 | /JPEGImages/2007_001225.jpg /SegmentationClassAug/2007_001225.png
38 | /JPEGImages/2007_001340.jpg /SegmentationClassAug/2007_001340.png
39 | /JPEGImages/2007_001397.jpg /SegmentationClassAug/2007_001397.png
40 | /JPEGImages/2007_001416.jpg /SegmentationClassAug/2007_001416.png
41 | /JPEGImages/2007_001420.jpg /SegmentationClassAug/2007_001420.png
42 | /JPEGImages/2007_001439.jpg /SegmentationClassAug/2007_001439.png
43 | /JPEGImages/2007_001487.jpg /SegmentationClassAug/2007_001487.png
44 | /JPEGImages/2007_001595.jpg /SegmentationClassAug/2007_001595.png
45 | /JPEGImages/2007_001602.jpg /SegmentationClassAug/2007_001602.png
46 | /JPEGImages/2007_001609.jpg /SegmentationClassAug/2007_001609.png
47 | /JPEGImages/2007_001698.jpg /SegmentationClassAug/2007_001698.png
48 | /JPEGImages/2007_001704.jpg /SegmentationClassAug/2007_001704.png
49 | /JPEGImages/2007_001709.jpg /SegmentationClassAug/2007_001709.png
50 | /JPEGImages/2007_001724.jpg /SegmentationClassAug/2007_001724.png
51 | /JPEGImages/2007_001764.jpg /SegmentationClassAug/2007_001764.png
52 | /JPEGImages/2007_001825.jpg /SegmentationClassAug/2007_001825.png
53 | /JPEGImages/2007_001834.jpg /SegmentationClassAug/2007_001834.png
54 | /JPEGImages/2007_001857.jpg /SegmentationClassAug/2007_001857.png
55 | /JPEGImages/2007_001872.jpg /SegmentationClassAug/2007_001872.png
56 | /JPEGImages/2007_001901.jpg /SegmentationClassAug/2007_001901.png
57 | /JPEGImages/2007_001917.jpg /SegmentationClassAug/2007_001917.png
58 | /JPEGImages/2007_001960.jpg /SegmentationClassAug/2007_001960.png
59 | /JPEGImages/2007_002024.jpg /SegmentationClassAug/2007_002024.png
60 | /JPEGImages/2007_002055.jpg /SegmentationClassAug/2007_002055.png
61 | /JPEGImages/2007_002088.jpg /SegmentationClassAug/2007_002088.png
62 | /JPEGImages/2007_002099.jpg /SegmentationClassAug/2007_002099.png
63 | /JPEGImages/2007_002105.jpg /SegmentationClassAug/2007_002105.png
64 | /JPEGImages/2007_002107.jpg /SegmentationClassAug/2007_002107.png
65 | /JPEGImages/2007_002120.jpg /SegmentationClassAug/2007_002120.png
66 | /JPEGImages/2007_002142.jpg /SegmentationClassAug/2007_002142.png
67 | /JPEGImages/2007_002198.jpg /SegmentationClassAug/2007_002198.png
68 | /JPEGImages/2007_002212.jpg /SegmentationClassAug/2007_002212.png
69 | /JPEGImages/2007_002216.jpg /SegmentationClassAug/2007_002216.png
70 | /JPEGImages/2007_002227.jpg /SegmentationClassAug/2007_002227.png
71 | /JPEGImages/2007_002234.jpg /SegmentationClassAug/2007_002234.png
72 | /JPEGImages/2007_002273.jpg /SegmentationClassAug/2007_002273.png
73 | /JPEGImages/2007_002281.jpg /SegmentationClassAug/2007_002281.png
74 | /JPEGImages/2007_002293.jpg /SegmentationClassAug/2007_002293.png
75 | /JPEGImages/2007_002361.jpg /SegmentationClassAug/2007_002361.png
76 | /JPEGImages/2007_002368.jpg /SegmentationClassAug/2007_002368.png
77 | /JPEGImages/2007_002370.jpg /SegmentationClassAug/2007_002370.png
78 | /JPEGImages/2007_002403.jpg /SegmentationClassAug/2007_002403.png
79 | /JPEGImages/2007_002462.jpg /SegmentationClassAug/2007_002462.png
80 | /JPEGImages/2007_002488.jpg /SegmentationClassAug/2007_002488.png
81 | /JPEGImages/2007_002545.jpg /SegmentationClassAug/2007_002545.png
82 | /JPEGImages/2007_002611.jpg /SegmentationClassAug/2007_002611.png
83 | /JPEGImages/2007_002639.jpg /SegmentationClassAug/2007_002639.png
84 | /JPEGImages/2007_002668.jpg /SegmentationClassAug/2007_002668.png
85 | /JPEGImages/2007_002669.jpg /SegmentationClassAug/2007_002669.png
86 | /JPEGImages/2007_002760.jpg /SegmentationClassAug/2007_002760.png
87 | /JPEGImages/2007_002789.jpg /SegmentationClassAug/2007_002789.png
88 | /JPEGImages/2007_002845.jpg /SegmentationClassAug/2007_002845.png
89 | /JPEGImages/2007_002895.jpg /SegmentationClassAug/2007_002895.png
90 | /JPEGImages/2007_002896.jpg /SegmentationClassAug/2007_002896.png
91 | /JPEGImages/2007_002914.jpg /SegmentationClassAug/2007_002914.png
92 | /JPEGImages/2007_002953.jpg /SegmentationClassAug/2007_002953.png
93 | /JPEGImages/2007_002954.jpg /SegmentationClassAug/2007_002954.png
94 | /JPEGImages/2007_002967.jpg /SegmentationClassAug/2007_002967.png
95 | /JPEGImages/2007_003000.jpg /SegmentationClassAug/2007_003000.png
96 | /JPEGImages/2007_003178.jpg /SegmentationClassAug/2007_003178.png
97 | /JPEGImages/2007_003189.jpg /SegmentationClassAug/2007_003189.png
98 | /JPEGImages/2007_003190.jpg /SegmentationClassAug/2007_003190.png
99 | /JPEGImages/2007_003207.jpg /SegmentationClassAug/2007_003207.png
100 | /JPEGImages/2007_003251.jpg /SegmentationClassAug/2007_003251.png
101 | /JPEGImages/2007_003267.jpg /SegmentationClassAug/2007_003267.png
102 | /JPEGImages/2007_003286.jpg /SegmentationClassAug/2007_003286.png
103 | /JPEGImages/2007_003330.jpg /SegmentationClassAug/2007_003330.png
104 | /JPEGImages/2007_003451.jpg /SegmentationClassAug/2007_003451.png
105 | /JPEGImages/2007_003525.jpg /SegmentationClassAug/2007_003525.png
106 | /JPEGImages/2007_003565.jpg /SegmentationClassAug/2007_003565.png
107 | /JPEGImages/2007_003593.jpg /SegmentationClassAug/2007_003593.png
108 | /JPEGImages/2007_003604.jpg /SegmentationClassAug/2007_003604.png
109 | /JPEGImages/2007_003668.jpg /SegmentationClassAug/2007_003668.png
110 | /JPEGImages/2007_003715.jpg /SegmentationClassAug/2007_003715.png
111 | /JPEGImages/2007_003778.jpg /SegmentationClassAug/2007_003778.png
112 | /JPEGImages/2007_003788.jpg /SegmentationClassAug/2007_003788.png
113 | /JPEGImages/2007_003815.jpg /SegmentationClassAug/2007_003815.png
114 | /JPEGImages/2007_003876.jpg /SegmentationClassAug/2007_003876.png
115 | /JPEGImages/2007_003889.jpg /SegmentationClassAug/2007_003889.png
116 | /JPEGImages/2007_003910.jpg /SegmentationClassAug/2007_003910.png
117 | /JPEGImages/2007_004003.jpg /SegmentationClassAug/2007_004003.png
118 | /JPEGImages/2007_004009.jpg /SegmentationClassAug/2007_004009.png
119 | /JPEGImages/2007_004065.jpg /SegmentationClassAug/2007_004065.png
120 | /JPEGImages/2007_004081.jpg /SegmentationClassAug/2007_004081.png
121 | /JPEGImages/2007_004166.jpg /SegmentationClassAug/2007_004166.png
122 | /JPEGImages/2007_004423.jpg /SegmentationClassAug/2007_004423.png
123 | /JPEGImages/2007_004459.jpg /SegmentationClassAug/2007_004459.png
124 | /JPEGImages/2007_004481.jpg /SegmentationClassAug/2007_004481.png
125 | /JPEGImages/2007_004500.jpg /SegmentationClassAug/2007_004500.png
126 | /JPEGImages/2007_004537.jpg /SegmentationClassAug/2007_004537.png
127 | /JPEGImages/2007_004627.jpg /SegmentationClassAug/2007_004627.png
128 | /JPEGImages/2007_004663.jpg /SegmentationClassAug/2007_004663.png
129 | /JPEGImages/2007_004705.jpg /SegmentationClassAug/2007_004705.png
130 | /JPEGImages/2007_004707.jpg /SegmentationClassAug/2007_004707.png
131 | /JPEGImages/2007_004768.jpg /SegmentationClassAug/2007_004768.png
132 | /JPEGImages/2007_004810.jpg /SegmentationClassAug/2007_004810.png
133 | /JPEGImages/2007_004830.jpg /SegmentationClassAug/2007_004830.png
134 | /JPEGImages/2007_004841.jpg /SegmentationClassAug/2007_004841.png
135 | /JPEGImages/2007_004948.jpg /SegmentationClassAug/2007_004948.png
136 | /JPEGImages/2007_004951.jpg /SegmentationClassAug/2007_004951.png
137 | /JPEGImages/2007_004988.jpg /SegmentationClassAug/2007_004988.png
138 | /JPEGImages/2007_004998.jpg /SegmentationClassAug/2007_004998.png
139 | /JPEGImages/2007_005043.jpg /SegmentationClassAug/2007_005043.png
140 | /JPEGImages/2007_005124.jpg /SegmentationClassAug/2007_005124.png
141 | /JPEGImages/2007_005130.jpg /SegmentationClassAug/2007_005130.png
142 | /JPEGImages/2007_005210.jpg /SegmentationClassAug/2007_005210.png
143 | /JPEGImages/2007_005212.jpg /SegmentationClassAug/2007_005212.png
144 | /JPEGImages/2007_005248.jpg /SegmentationClassAug/2007_005248.png
145 | /JPEGImages/2007_005262.jpg /SegmentationClassAug/2007_005262.png
146 | /JPEGImages/2007_005264.jpg /SegmentationClassAug/2007_005264.png
147 | /JPEGImages/2007_005266.jpg /SegmentationClassAug/2007_005266.png
148 | /JPEGImages/2007_005273.jpg /SegmentationClassAug/2007_005273.png
149 | /JPEGImages/2007_005314.jpg /SegmentationClassAug/2007_005314.png
150 | /JPEGImages/2007_005360.jpg /SegmentationClassAug/2007_005360.png
151 | /JPEGImages/2007_005647.jpg /SegmentationClassAug/2007_005647.png
152 | /JPEGImages/2007_005688.jpg /SegmentationClassAug/2007_005688.png
153 | /JPEGImages/2007_005797.jpg /SegmentationClassAug/2007_005797.png
154 | /JPEGImages/2007_005878.jpg /SegmentationClassAug/2007_005878.png
155 | /JPEGImages/2007_005902.jpg /SegmentationClassAug/2007_005902.png
156 | /JPEGImages/2007_005951.jpg /SegmentationClassAug/2007_005951.png
157 | /JPEGImages/2007_005989.jpg /SegmentationClassAug/2007_005989.png
158 | /JPEGImages/2007_006066.jpg /SegmentationClassAug/2007_006066.png
159 | /JPEGImages/2007_006134.jpg /SegmentationClassAug/2007_006134.png
160 | /JPEGImages/2007_006136.jpg /SegmentationClassAug/2007_006136.png
161 | /JPEGImages/2007_006151.jpg /SegmentationClassAug/2007_006151.png
162 | /JPEGImages/2007_006212.jpg /SegmentationClassAug/2007_006212.png
163 | /JPEGImages/2007_006254.jpg /SegmentationClassAug/2007_006254.png
164 | /JPEGImages/2007_006281.jpg /SegmentationClassAug/2007_006281.png
165 | /JPEGImages/2007_006303.jpg /SegmentationClassAug/2007_006303.png
166 | /JPEGImages/2007_006317.jpg /SegmentationClassAug/2007_006317.png
167 | /JPEGImages/2007_006400.jpg /SegmentationClassAug/2007_006400.png
168 | /JPEGImages/2007_006409.jpg /SegmentationClassAug/2007_006409.png
169 | /JPEGImages/2007_006445.jpg /SegmentationClassAug/2007_006445.png
170 | /JPEGImages/2007_006490.jpg /SegmentationClassAug/2007_006490.png
171 | /JPEGImages/2007_006530.jpg /SegmentationClassAug/2007_006530.png
172 | /JPEGImages/2007_006581.jpg /SegmentationClassAug/2007_006581.png
173 | /JPEGImages/2007_006585.jpg /SegmentationClassAug/2007_006585.png
174 | /JPEGImages/2007_006605.jpg /SegmentationClassAug/2007_006605.png
175 | /JPEGImages/2007_006641.jpg /SegmentationClassAug/2007_006641.png
176 | /JPEGImages/2007_006660.jpg /SegmentationClassAug/2007_006660.png
177 | /JPEGImages/2007_006673.jpg /SegmentationClassAug/2007_006673.png
178 | /JPEGImages/2007_006699.jpg /SegmentationClassAug/2007_006699.png
179 | /JPEGImages/2007_006704.jpg /SegmentationClassAug/2007_006704.png
180 | /JPEGImages/2007_006803.jpg /SegmentationClassAug/2007_006803.png
181 | /JPEGImages/2007_006832.jpg /SegmentationClassAug/2007_006832.png
182 | /JPEGImages/2007_006865.jpg /SegmentationClassAug/2007_006865.png
183 | /JPEGImages/2007_006899.jpg /SegmentationClassAug/2007_006899.png
184 | /JPEGImages/2007_006900.jpg /SegmentationClassAug/2007_006900.png
185 | /JPEGImages/2007_006944.jpg /SegmentationClassAug/2007_006944.png
186 | /JPEGImages/2007_007003.jpg /SegmentationClassAug/2007_007003.png
187 | /JPEGImages/2007_007021.jpg /SegmentationClassAug/2007_007021.png
188 | /JPEGImages/2007_007048.jpg /SegmentationClassAug/2007_007048.png
189 | /JPEGImages/2007_007098.jpg /SegmentationClassAug/2007_007098.png
190 | /JPEGImages/2007_007230.jpg /SegmentationClassAug/2007_007230.png
191 | /JPEGImages/2007_007250.jpg /SegmentationClassAug/2007_007250.png
192 | /JPEGImages/2007_007355.jpg /SegmentationClassAug/2007_007355.png
193 | /JPEGImages/2007_007387.jpg /SegmentationClassAug/2007_007387.png
194 | /JPEGImages/2007_007398.jpg /SegmentationClassAug/2007_007398.png
195 | /JPEGImages/2007_007415.jpg /SegmentationClassAug/2007_007415.png
196 | /JPEGImages/2007_007432.jpg /SegmentationClassAug/2007_007432.png
197 | /JPEGImages/2007_007480.jpg /SegmentationClassAug/2007_007480.png
198 | /JPEGImages/2007_007481.jpg /SegmentationClassAug/2007_007481.png
199 | /JPEGImages/2007_007523.jpg /SegmentationClassAug/2007_007523.png
200 | /JPEGImages/2007_007530.jpg /SegmentationClassAug/2007_007530.png
201 | /JPEGImages/2007_007585.jpg /SegmentationClassAug/2007_007585.png
202 | /JPEGImages/2007_007591.jpg /SegmentationClassAug/2007_007591.png
203 | /JPEGImages/2007_007621.jpg /SegmentationClassAug/2007_007621.png
204 | /JPEGImages/2007_007726.jpg /SegmentationClassAug/2007_007726.png
205 | /JPEGImages/2007_007772.jpg /SegmentationClassAug/2007_007772.png
206 | /JPEGImages/2007_007773.jpg /SegmentationClassAug/2007_007773.png
207 | /JPEGImages/2007_007783.jpg /SegmentationClassAug/2007_007783.png
208 | /JPEGImages/2007_007878.jpg /SegmentationClassAug/2007_007878.png
209 | /JPEGImages/2007_007890.jpg /SegmentationClassAug/2007_007890.png
210 | /JPEGImages/2007_007902.jpg /SegmentationClassAug/2007_007902.png
211 | /JPEGImages/2007_007908.jpg /SegmentationClassAug/2007_007908.png
212 | /JPEGImages/2007_007930.jpg /SegmentationClassAug/2007_007930.png
213 | /JPEGImages/2007_007947.jpg /SegmentationClassAug/2007_007947.png
214 | /JPEGImages/2007_007948.jpg /SegmentationClassAug/2007_007948.png
215 | /JPEGImages/2007_008085.jpg /SegmentationClassAug/2007_008085.png
216 | /JPEGImages/2007_008140.jpg /SegmentationClassAug/2007_008140.png
217 | /JPEGImages/2007_008142.jpg /SegmentationClassAug/2007_008142.png
218 | /JPEGImages/2007_008203.jpg /SegmentationClassAug/2007_008203.png
219 | /JPEGImages/2007_008219.jpg /SegmentationClassAug/2007_008219.png
220 | /JPEGImages/2007_008307.jpg /SegmentationClassAug/2007_008307.png
221 | /JPEGImages/2007_008403.jpg /SegmentationClassAug/2007_008403.png
222 | /JPEGImages/2007_008468.jpg /SegmentationClassAug/2007_008468.png
223 | /JPEGImages/2007_008526.jpg /SegmentationClassAug/2007_008526.png
224 | /JPEGImages/2007_008571.jpg /SegmentationClassAug/2007_008571.png
225 | /JPEGImages/2007_008575.jpg /SegmentationClassAug/2007_008575.png
226 | /JPEGImages/2007_008764.jpg /SegmentationClassAug/2007_008764.png
227 | /JPEGImages/2007_008821.jpg /SegmentationClassAug/2007_008821.png
228 | /JPEGImages/2007_008927.jpg /SegmentationClassAug/2007_008927.png
229 | /JPEGImages/2007_008945.jpg /SegmentationClassAug/2007_008945.png
230 | /JPEGImages/2007_008948.jpg /SegmentationClassAug/2007_008948.png
231 | /JPEGImages/2007_009052.jpg /SegmentationClassAug/2007_009052.png
232 | /JPEGImages/2007_009082.jpg /SegmentationClassAug/2007_009082.png
233 | /JPEGImages/2007_009216.jpg /SegmentationClassAug/2007_009216.png
234 | /JPEGImages/2007_009295.jpg /SegmentationClassAug/2007_009295.png
235 | /JPEGImages/2007_009322.jpg /SegmentationClassAug/2007_009322.png
236 | /JPEGImages/2007_009435.jpg /SegmentationClassAug/2007_009435.png
237 | /JPEGImages/2007_009436.jpg /SegmentationClassAug/2007_009436.png
238 | /JPEGImages/2007_009464.jpg /SegmentationClassAug/2007_009464.png
239 | /JPEGImages/2007_009527.jpg /SegmentationClassAug/2007_009527.png
240 | /JPEGImages/2007_009550.jpg /SegmentationClassAug/2007_009550.png
241 | /JPEGImages/2007_009594.jpg /SegmentationClassAug/2007_009594.png
242 | /JPEGImages/2007_009605.jpg /SegmentationClassAug/2007_009605.png
243 | /JPEGImages/2007_009630.jpg /SegmentationClassAug/2007_009630.png
244 | /JPEGImages/2007_009665.jpg /SegmentationClassAug/2007_009665.png
245 | /JPEGImages/2007_009709.jpg /SegmentationClassAug/2007_009709.png
246 | /JPEGImages/2007_009779.jpg /SegmentationClassAug/2007_009779.png
247 | /JPEGImages/2007_009788.jpg /SegmentationClassAug/2007_009788.png
248 | /JPEGImages/2007_009832.jpg /SegmentationClassAug/2007_009832.png
249 | /JPEGImages/2007_009889.jpg /SegmentationClassAug/2007_009889.png
250 | /JPEGImages/2007_009899.jpg /SegmentationClassAug/2007_009899.png
251 | /JPEGImages/2008_000002.jpg /SegmentationClassAug/2008_000002.png
252 | /JPEGImages/2008_000015.jpg /SegmentationClassAug/2008_000015.png
253 | /JPEGImages/2008_000019.jpg /SegmentationClassAug/2008_000019.png
254 | /JPEGImages/2008_000023.jpg /SegmentationClassAug/2008_000023.png
255 | /JPEGImages/2008_000043.jpg /SegmentationClassAug/2008_000043.png
256 | /JPEGImages/2008_000053.jpg /SegmentationClassAug/2008_000053.png
257 | /JPEGImages/2008_000059.jpg /SegmentationClassAug/2008_000059.png
258 | /JPEGImages/2008_000066.jpg /SegmentationClassAug/2008_000066.png
259 | /JPEGImages/2008_000067.jpg /SegmentationClassAug/2008_000067.png
260 | /JPEGImages/2008_000070.jpg /SegmentationClassAug/2008_000070.png
261 | /JPEGImages/2008_000078.jpg /SegmentationClassAug/2008_000078.png
262 | /JPEGImages/2008_000084.jpg /SegmentationClassAug/2008_000084.png
263 | /JPEGImages/2008_000089.jpg /SegmentationClassAug/2008_000089.png
264 | /JPEGImages/2008_000093.jpg /SegmentationClassAug/2008_000093.png
265 | /JPEGImages/2008_000115.jpg /SegmentationClassAug/2008_000115.png
266 | /JPEGImages/2008_000128.jpg /SegmentationClassAug/2008_000128.png
267 | /JPEGImages/2008_000133.jpg /SegmentationClassAug/2008_000133.png
268 | /JPEGImages/2008_000145.jpg /SegmentationClassAug/2008_000145.png
269 | /JPEGImages/2008_000154.jpg /SegmentationClassAug/2008_000154.png
270 | /JPEGImages/2008_000188.jpg /SegmentationClassAug/2008_000188.png
271 | /JPEGImages/2008_000191.jpg /SegmentationClassAug/2008_000191.png
272 | /JPEGImages/2008_000194.jpg /SegmentationClassAug/2008_000194.png
273 | /JPEGImages/2008_000196.jpg /SegmentationClassAug/2008_000196.png
274 | /JPEGImages/2008_000227.jpg /SegmentationClassAug/2008_000227.png
275 | /JPEGImages/2008_000272.jpg /SegmentationClassAug/2008_000272.png
276 | /JPEGImages/2008_000273.jpg /SegmentationClassAug/2008_000273.png
277 | /JPEGImages/2008_000274.jpg /SegmentationClassAug/2008_000274.png
278 | /JPEGImages/2008_000287.jpg /SegmentationClassAug/2008_000287.png
279 | /JPEGImages/2008_000305.jpg /SegmentationClassAug/2008_000305.png
280 | /JPEGImages/2008_000321.jpg /SegmentationClassAug/2008_000321.png
281 | /JPEGImages/2008_000335.jpg /SegmentationClassAug/2008_000335.png
282 | /JPEGImages/2008_000397.jpg /SegmentationClassAug/2008_000397.png
283 | /JPEGImages/2008_000491.jpg /SegmentationClassAug/2008_000491.png
284 | /JPEGImages/2008_000564.jpg /SegmentationClassAug/2008_000564.png
285 | /JPEGImages/2008_000703.jpg /SegmentationClassAug/2008_000703.png
286 | /JPEGImages/2008_000790.jpg /SegmentationClassAug/2008_000790.png
287 | /JPEGImages/2008_001077.jpg /SegmentationClassAug/2008_001077.png
288 | /JPEGImages/2008_001225.jpg /SegmentationClassAug/2008_001225.png
289 | /JPEGImages/2008_001336.jpg /SegmentationClassAug/2008_001336.png
290 | /JPEGImages/2008_001405.jpg /SegmentationClassAug/2008_001405.png
291 | /JPEGImages/2008_001626.jpg /SegmentationClassAug/2008_001626.png
292 | /JPEGImages/2008_001744.jpg /SegmentationClassAug/2008_001744.png
293 | /JPEGImages/2008_001813.jpg /SegmentationClassAug/2008_001813.png
294 | /JPEGImages/2008_002005.jpg /SegmentationClassAug/2008_002005.png
295 | /JPEGImages/2008_002153.jpg /SegmentationClassAug/2008_002153.png
296 | /JPEGImages/2008_002204.jpg /SegmentationClassAug/2008_002204.png
297 | /JPEGImages/2008_002292.jpg /SegmentationClassAug/2008_002292.png
298 | /JPEGImages/2008_002372.jpg /SegmentationClassAug/2008_002372.png
299 | /JPEGImages/2008_002418.jpg /SegmentationClassAug/2008_002418.png
300 | /JPEGImages/2008_003579.jpg /SegmentationClassAug/2008_003579.png
301 |
--------------------------------------------------------------------------------
/dataloaders/voc_splits/60_train_supervised.txt:
--------------------------------------------------------------------------------
1 | /JPEGImages/2007_000032.jpg /SegmentationClassAug/2007_000032.png
2 | /JPEGImages/2007_000039.jpg /SegmentationClassAug/2007_000039.png
3 | /JPEGImages/2007_000063.jpg /SegmentationClassAug/2007_000063.png
4 | /JPEGImages/2007_000068.jpg /SegmentationClassAug/2007_000068.png
5 | /JPEGImages/2007_000121.jpg /SegmentationClassAug/2007_000121.png
6 | /JPEGImages/2007_000170.jpg /SegmentationClassAug/2007_000170.png
7 | /JPEGImages/2007_000241.jpg /SegmentationClassAug/2007_000241.png
8 | /JPEGImages/2007_000243.jpg /SegmentationClassAug/2007_000243.png
9 | /JPEGImages/2007_000250.jpg /SegmentationClassAug/2007_000250.png
10 | /JPEGImages/2007_000256.jpg /SegmentationClassAug/2007_000256.png
11 | /JPEGImages/2007_000333.jpg /SegmentationClassAug/2007_000333.png
12 | /JPEGImages/2007_000363.jpg /SegmentationClassAug/2007_000363.png
13 | /JPEGImages/2007_000364.jpg /SegmentationClassAug/2007_000364.png
14 | /JPEGImages/2007_000392.jpg /SegmentationClassAug/2007_000392.png
15 | /JPEGImages/2007_000480.jpg /SegmentationClassAug/2007_000480.png
16 | /JPEGImages/2007_000504.jpg /SegmentationClassAug/2007_000504.png
17 | /JPEGImages/2007_000515.jpg /SegmentationClassAug/2007_000515.png
18 | /JPEGImages/2007_000528.jpg /SegmentationClassAug/2007_000528.png
19 | /JPEGImages/2007_000549.jpg /SegmentationClassAug/2007_000549.png
20 | /JPEGImages/2007_000584.jpg /SegmentationClassAug/2007_000584.png
21 | /JPEGImages/2007_000645.jpg /SegmentationClassAug/2007_000645.png
22 | /JPEGImages/2007_000648.jpg /SegmentationClassAug/2007_000648.png
23 | /JPEGImages/2007_000713.jpg /SegmentationClassAug/2007_000713.png
24 | /JPEGImages/2007_000720.jpg /SegmentationClassAug/2007_000720.png
25 | /JPEGImages/2007_000733.jpg /SegmentationClassAug/2007_000733.png
26 | /JPEGImages/2007_000768.jpg /SegmentationClassAug/2007_000768.png
27 | /JPEGImages/2007_000793.jpg /SegmentationClassAug/2007_000793.png
28 | /JPEGImages/2007_000822.jpg /SegmentationClassAug/2007_000822.png
29 | /JPEGImages/2007_000836.jpg /SegmentationClassAug/2007_000836.png
30 | /JPEGImages/2007_000876.jpg /SegmentationClassAug/2007_000876.png
31 | /JPEGImages/2007_001027.jpg /SegmentationClassAug/2007_001027.png
32 | /JPEGImages/2007_001073.jpg /SegmentationClassAug/2007_001073.png
33 | /JPEGImages/2007_001149.jpg /SegmentationClassAug/2007_001149.png
34 | /JPEGImages/2007_001225.jpg /SegmentationClassAug/2007_001225.png
35 | /JPEGImages/2007_001397.jpg /SegmentationClassAug/2007_001397.png
36 | /JPEGImages/2007_001416.jpg /SegmentationClassAug/2007_001416.png
37 | /JPEGImages/2007_001420.jpg /SegmentationClassAug/2007_001420.png
38 | /JPEGImages/2007_001439.jpg /SegmentationClassAug/2007_001439.png
39 | /JPEGImages/2007_001487.jpg /SegmentationClassAug/2007_001487.png
40 | /JPEGImages/2007_001595.jpg /SegmentationClassAug/2007_001595.png
41 | /JPEGImages/2007_001602.jpg /SegmentationClassAug/2007_001602.png
42 | /JPEGImages/2007_001609.jpg /SegmentationClassAug/2007_001609.png
43 | /JPEGImages/2007_001704.jpg /SegmentationClassAug/2007_001704.png
44 | /JPEGImages/2007_001764.jpg /SegmentationClassAug/2007_001764.png
45 | /JPEGImages/2007_001857.jpg /SegmentationClassAug/2007_001857.png
46 | /JPEGImages/2007_001872.jpg /SegmentationClassAug/2007_001872.png
47 | /JPEGImages/2007_001901.jpg /SegmentationClassAug/2007_001901.png
48 | /JPEGImages/2007_002227.jpg /SegmentationClassAug/2007_002227.png
49 | /JPEGImages/2007_002281.jpg /SegmentationClassAug/2007_002281.png
50 | /JPEGImages/2007_002361.jpg /SegmentationClassAug/2007_002361.png
51 | /JPEGImages/2007_002462.jpg /SegmentationClassAug/2007_002462.png
52 | /JPEGImages/2007_002845.jpg /SegmentationClassAug/2007_002845.png
53 | /JPEGImages/2007_002953.jpg /SegmentationClassAug/2007_002953.png
54 | /JPEGImages/2007_002967.jpg /SegmentationClassAug/2007_002967.png
55 | /JPEGImages/2007_003178.jpg /SegmentationClassAug/2007_003178.png
56 | /JPEGImages/2007_003189.jpg /SegmentationClassAug/2007_003189.png
57 | /JPEGImages/2007_003207.jpg /SegmentationClassAug/2007_003207.png
58 | /JPEGImages/2007_003788.jpg /SegmentationClassAug/2007_003788.png
59 | /JPEGImages/2007_005273.jpg /SegmentationClassAug/2007_005273.png
60 | /JPEGImages/2007_006530.jpg /SegmentationClassAug/2007_006530.png
61 |
--------------------------------------------------------------------------------
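Each of the voc_splits files above lists one labeled pair per line: the image path and the corresponding SegmentationClassAug mask, both relative to the VOC2012 root. A small parser sketch follows (the read_split helper and the data_dir value are illustrative assumptions, not part of the repo; the repo's own loading logic lives in dataloaders/voc.py):

from pathlib import Path

def read_split(split_file, data_dir):
    # Parse a voc_splits .txt file into (image, label) path pairs.
    pairs = []
    for line in Path(split_file).read_text().splitlines():
        if not line.strip():
            continue
        img_rel, lbl_rel = line.split()
        # The listed paths start with '/', so plain string concatenation is used here.
        pairs.append((data_dir + img_rel, data_dir + lbl_rel))
    return pairs

pairs = read_split("dataloaders/voc_splits/60_train_supervised.txt",
                   "/path/to/VOCdevkit/VOC2012")
print(len(pairs), pairs[0])  # 60 pairs; the first is the 2007_000032 image/mask pair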
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import scipy, math
3 | from scipy import ndimage
4 | import cv2
5 | import numpy as np
6 | import sys
7 | import json
8 | import models
9 | import dataloaders
10 | from utils.helpers import colorize_mask
11 | from utils.pallete import get_voc_pallete
12 | from utils import metrics
13 | import torch
14 | import torch.nn as nn
15 | from torchvision import transforms
16 | import torch.nn.functional as F
17 | from torch.utils.data import DataLoader, Dataset
18 | import os
19 | from tqdm import tqdm
20 | from math import ceil
21 | from PIL import Image
22 | from pathlib import Path
23 |
24 |
25 | class testDataset(Dataset):
26 | def __init__(self, images):
27 | mean = [0.485, 0.456, 0.406]
28 | std = [0.229, 0.224, 0.225]
29 | images_path = Path(images)
30 | self.filelist = list(images_path.glob("*.jpg"))
31 | self.to_tensor = transforms.ToTensor()
32 | self.normalize = transforms.Normalize(mean, std)
33 |
34 | def __len__(self):
35 | return len(self.filelist)
36 |
37 | def __getitem__(self, index):
38 | image_path = self.filelist[index]
39 | image_id = str(image_path).split("/")[-1].split(".")[0]
40 | image = Image.open(image_path)
41 | image = self.normalize(self.to_tensor(image))
42 | return image, image_id
43 |
44 | def multi_scale_predict(model, image, scales, num_classes, flip=True):
45 | H, W = (image.size(2), image.size(3))
46 | upsize = (ceil(H / 8) * 8, ceil(W / 8) * 8)
47 | upsample = nn.Upsample(size=upsize, mode='bilinear', align_corners=True)
48 | pad_h, pad_w = upsize[0] - H, upsize[1] - W
49 | image = F.pad(image, pad=(0, pad_w, 0, pad_h), mode='reflect')
50 |
51 | total_predictions = np.zeros((num_classes, image.shape[2], image.shape[3]))
52 |
53 | for scale in scales:
54 | scaled_img = F.interpolate(image, scale_factor=scale, mode='bilinear', align_corners=False)
55 | scaled_prediction = upsample(model(scaled_img))
56 |
57 | if flip:
58 | fliped_img = scaled_img.flip(-1)
59 | fliped_predictions = upsample(model(fliped_img))
60 | scaled_prediction = 0.5 * (fliped_predictions.flip(-1) + scaled_prediction)
61 | total_predictions += scaled_prediction.data.cpu().numpy().squeeze(0)
62 |
63 | total_predictions /= len(scales)
64 | return total_predictions[:, :H, :W]
65 |
66 | def main():
67 | args = parse_arguments()
68 |
69 | # CONFIG
70 | assert args.config
71 | config = json.load(open(args.config))
72 | scales = [0.5, 0.75, 1.0, 1.25, 1.5]
73 |
74 | # DATA
75 | testdataset = testDataset(args.images)
76 | loader = DataLoader(testdataset, batch_size=1, shuffle=False, num_workers=1)
77 | num_classes = 21
78 | palette = get_voc_pallete(num_classes)
79 |
80 | # MODEL
81 | config['model']['supervised'] = True; config['model']['semi'] = False
82 | model = models.CCT(num_classes=num_classes,
83 | conf=config['model'], testing=True)
84 | checkpoint = torch.load(args.model)
85 | model = torch.nn.DataParallel(model)
86 | try:
87 | model.load_state_dict(checkpoint['state_dict'], strict=True)
88 | except Exception as e:
89 | print(f'Some modules are missing: {e}')
90 | model.load_state_dict(checkpoint['state_dict'], strict=False)
91 | model.eval()
92 | model.cuda()
93 |
94 | if args.save and not os.path.exists('outputs'):
95 | os.makedirs('outputs')
96 |
97 | # LOOP OVER THE DATA
98 | tbar = tqdm(loader, ncols=100)
99 | total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
100 | labels, predictions = [], []
101 |
102 | for index, data in enumerate(tbar):
103 | image, image_id = data
104 | image = image.cuda()
105 |
106 | # PREDICT
107 | with torch.no_grad():
108 | output = multi_scale_predict(model, image, scales, num_classes)
109 | prediction = np.asarray(np.argmax(output, axis=0), dtype=np.uint8)
110 |
111 | # SAVE RESULTS
112 | prediction_im = colorize_mask(prediction, palette)
113 |         if args.save: prediction_im.save('outputs/'+image_id[0]+'.png')
114 |
115 | def parse_arguments():
116 | parser = argparse.ArgumentParser(description='PyTorch Training')
117 | parser.add_argument('--config', default='configs/config.json',type=str,
118 | help='Path to the config file')
119 | parser.add_argument( '--model', default=None, type=str,
120 | help='Path to the trained .pth model')
121 | parser.add_argument( '--save', action='store_true', help='Save images')
122 | parser.add_argument('--images', default="/home/yassine/Datasets/vision/PascalVoc/VOC/VOCdevkit/VOC2012/test_images", type=str,
123 | help='Test images for Pascal VOC')
124 | args = parser.parse_args()
125 | return args
126 |
127 | if __name__ == '__main__':
128 | main()
129 |
130 |
--------------------------------------------------------------------------------
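As a quick sanity check of the multi_scale_predict routine defined above, the sketch below runs it on a dummy one-layer network standing in for the trained CCT model (the dummy model and input size are illustrative assumptions; it also assumes the repo root is on PYTHONPATH and that inference.py's top-level imports such as cv2 and scipy are installed):

import torch
import torch.nn as nn
from inference import multi_scale_predict

num_classes = 21
dummy_model = nn.Conv2d(3, num_classes, kernel_size=1)   # stand-in for the trained CCT model
image = torch.randn(1, 3, 97, 129)                        # arbitrary H x W, not multiples of 8

with torch.no_grad():
    probs = multi_scale_predict(dummy_model, image, scales=[0.75, 1.0, 1.25],
                                num_classes=num_classes, flip=True)

# Predictions are padded to a multiple of 8, averaged over scales (and flips),
# then cropped back to the original H x W.
print(probs.shape)  # (21, 97, 129)

inference.py itself averages over five scales (0.5 to 1.5) and takes the per-pixel argmax before colorizing and saving the mask.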
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import CCT
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yassouali/CCT/65d4e5bd4501ae3c564493d0ce18924a908639f5/models/backbones/__init__.py
--------------------------------------------------------------------------------
/models/backbones/get_pretrained_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | FILENAME="models/backbones/pretrained/3x3resnet50-imagenet.pth"
4 |
5 | mkdir -p models/backbones/pretrained
6 | wget https://github.com/yassouali/CCT/releases/download/v0.1/3x3resnet50-imagenet.pth -O $FILENAME
7 |
--------------------------------------------------------------------------------
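For environments without wget, the same checkpoint can be fetched directly from Python; the URL and destination below are taken verbatim from the script above:

import os
from urllib.request import urlretrieve

url = "https://github.com/yassouali/CCT/releases/download/v0.1/3x3resnet50-imagenet.pth"
dest = "models/backbones/pretrained/3x3resnet50-imagenet.pth"

os.makedirs(os.path.dirname(dest), exist_ok=True)  # mirrors `mkdir -p`
urlretrieve(url, dest)                             # mirrors `wget ... -O $FILENAME`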
/models/backbones/module_helper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: Donny You (youansheng@gmail.com)
4 |
5 |
6 | import os
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | try:
12 | from urllib import urlretrieve
13 | except ImportError:
14 | from urllib.request import urlretrieve
15 |
16 | class FixedBatchNorm(nn.BatchNorm2d):
17 | def forward(self, input):
18 | return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps)
19 |
20 | class ModuleHelper(object):
21 |
22 | @staticmethod
23 | def BNReLU(num_features, norm_type=None, **kwargs):
24 | if norm_type == 'batchnorm':
25 | return nn.Sequential(
26 | nn.BatchNorm2d(num_features, **kwargs),
27 | nn.ReLU()
28 | )
29 | elif norm_type == 'encsync_batchnorm':
30 | from encoding.nn import BatchNorm2d
31 | return nn.Sequential(
32 | BatchNorm2d(num_features, **kwargs),
33 | nn.ReLU()
34 | )
35 | elif norm_type == 'instancenorm':
36 | return nn.Sequential(
37 | nn.InstanceNorm2d(num_features, **kwargs),
38 | nn.ReLU()
39 | )
40 | elif norm_type == 'fixed_batchnorm':
41 | return nn.Sequential(
42 | FixedBatchNorm(num_features, **kwargs),
43 | nn.ReLU()
44 | )
45 | else:
46 |             raise ValueError('Unsupported BN type: {}.'.format(norm_type))
47 |
48 | @staticmethod
49 | def BatchNorm3d(norm_type=None, ret_cls=False):
50 | if norm_type == 'batchnorm':
51 | return nn.BatchNorm3d
52 | elif norm_type == 'encsync_batchnorm':
53 | from encoding.nn import BatchNorm3d
54 | return BatchNorm3d
55 | elif norm_type == 'instancenorm':
56 | return nn.InstanceNorm3d
57 | else:
58 |             raise ValueError('Unsupported BN type: {}.'.format(norm_type))
59 |
60 | @staticmethod
61 | def BatchNorm2d(norm_type=None, ret_cls=False):
62 | if norm_type == 'batchnorm':
63 | return nn.BatchNorm2d
64 | elif norm_type == 'encsync_batchnorm':
65 | from encoding.nn import BatchNorm2d
66 | return BatchNorm2d
67 |
68 | elif norm_type == 'instancenorm':
69 | return nn.InstanceNorm2d
70 | else:
71 |             raise ValueError('Unsupported BN type: {}.'.format(norm_type))
72 |
73 | @staticmethod
74 | def BatchNorm1d(norm_type=None, ret_cls=False):
75 | if norm_type == 'batchnorm':
76 | return nn.BatchNorm1d
77 | elif norm_type == 'encsync_batchnorm':
78 | from encoding.nn import BatchNorm1d
79 | return BatchNorm1d
80 | elif norm_type == 'instancenorm':
81 | return nn.InstanceNorm1d
82 | else:
83 |             raise ValueError('Unsupported BN type: {}.'.format(norm_type))
84 |
85 | @staticmethod
86 | def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
87 | if pretrained is None:
88 | return model
89 |
90 | if not os.path.exists(pretrained):
91 | print('{} not exists.'.format(pretrained))
92 | return model
93 |
94 | print('Loading pretrained model:{}'.format(pretrained))
95 | if all_match:
96 | pretrained_dict = torch.load(pretrained, map_location=map_location)
97 | model_dict = model.state_dict()
98 | load_dict = dict()
99 | for k, v in pretrained_dict.items():
100 | if 'prefix.{}'.format(k) in model_dict:
101 | load_dict['prefix.{}'.format(k)] = v
102 | else:
103 | load_dict[k] = v
104 | model.load_state_dict(load_dict)
105 |
106 | else:
107 | pretrained_dict = torch.load(pretrained)
108 | model_dict = model.state_dict()
109 | load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
110 | print('Matched Keys: {}'.format(load_dict.keys()))
111 | model_dict.update(load_dict)
112 | model.load_state_dict(model_dict)
113 |
114 | return model
115 |
116 | @staticmethod
117 | def load_url(url, map_location=None):
118 |         model_dir = os.path.expanduser(os.path.join('~', '.TorchCV', 'model'))
119 | if not os.path.exists(model_dir):
120 | os.makedirs(model_dir)
121 |
122 | filename = url.split('/')[-1]
123 | cached_file = os.path.join(model_dir, filename)
124 | if not os.path.exists(cached_file):
125 | print('Downloading: "{}" to {}\n'.format(url, cached_file))
126 | urlretrieve(url, cached_file)
127 |
128 | print('Loading pretrained model:{}'.format(cached_file))
129 | return torch.load(cached_file, map_location=map_location)
130 |
131 | @staticmethod
132 | def constant_init(module, val, bias=0):
133 | nn.init.constant_(module.weight, val)
134 | if hasattr(module, 'bias') and module.bias is not None:
135 | nn.init.constant_(module.bias, bias)
136 |
137 | @staticmethod
138 | def xavier_init(module, gain=1, bias=0, distribution='normal'):
139 | assert distribution in ['uniform', 'normal']
140 | if distribution == 'uniform':
141 | nn.init.xavier_uniform_(module.weight, gain=gain)
142 | else:
143 | nn.init.xavier_normal_(module.weight, gain=gain)
144 | if hasattr(module, 'bias') and module.bias is not None:
145 | nn.init.constant_(module.bias, bias)
146 |
147 | @staticmethod
148 | def normal_init(module, mean=0, std=1, bias=0):
149 | nn.init.normal_(module.weight, mean, std)
150 | if hasattr(module, 'bias') and module.bias is not None:
151 | nn.init.constant_(module.bias, bias)
152 |
153 | @staticmethod
154 | def uniform_init(module, a=0, b=1, bias=0):
155 | nn.init.uniform_(module.weight, a, b)
156 | if hasattr(module, 'bias') and module.bias is not None:
157 | nn.init.constant_(module.bias, bias)
158 |
159 | @staticmethod
160 | def kaiming_init(module,
161 | mode='fan_in',
162 | nonlinearity='leaky_relu',
163 | bias=0,
164 | distribution='normal'):
165 | assert distribution in ['uniform', 'normal']
166 | if distribution == 'uniform':
167 | nn.init.kaiming_uniform_(
168 | module.weight, mode=mode, nonlinearity=nonlinearity)
169 | else:
170 | nn.init.kaiming_normal_(
171 | module.weight, mode=mode, nonlinearity=nonlinearity)
172 | if hasattr(module, 'bias') and module.bias is not None:
173 | nn.init.constant_(module.bias, bias)
174 |
175 |
--------------------------------------------------------------------------------
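ModuleHelper acts as a small factory used throughout the backbones: BNReLU returns a ready-made normalization + ReLU block, while BatchNorm1d/2d/3d return the layer class itself so callers can instantiate it with their own channel count. A minimal sketch (channel counts and norm types are arbitrary; assumes the repo root is on PYTHONPATH):

import torch
from models.backbones.module_helper import ModuleHelper

bn_relu = ModuleHelper.BNReLU(16, norm_type='batchnorm')        # nn.BatchNorm2d(16) followed by ReLU
norm_cls = ModuleHelper.BatchNorm2d(norm_type='instancenorm')   # returns the class, not an instance
norm = norm_cls(16)

x = torch.randn(2, 16, 8, 8)
print(bn_relu(x).shape, norm(x).shape)  # both torch.Size([2, 16, 8, 8])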
/models/backbones/resnet_backbone.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: Donny You(youansheng@gmail.com)
4 |
5 |
6 | import torch.nn as nn
7 | from models.backbones.resnet_models import *
8 |
9 |
10 | class NormalResnetBackbone(nn.Module):
11 | def __init__(self, orig_resnet):
12 | super(NormalResnetBackbone, self).__init__()
13 |
14 | self.num_features = 2048
15 | # take pretrained resnet, except AvgPool and FC
16 | self.prefix = orig_resnet.prefix
17 | self.maxpool = orig_resnet.maxpool
18 | self.layer1 = orig_resnet.layer1
19 | self.layer2 = orig_resnet.layer2
20 | self.layer3 = orig_resnet.layer3
21 | self.layer4 = orig_resnet.layer4
22 |
23 | def get_num_features(self):
24 | return self.num_features
25 |
26 | def forward(self, x):
27 | tuple_features = list()
28 | x = self.prefix(x)
29 | x = self.maxpool(x)
30 | x = self.layer1(x)
31 | tuple_features.append(x)
32 | x = self.layer2(x)
33 | tuple_features.append(x)
34 | x = self.layer3(x)
35 | tuple_features.append(x)
36 | x = self.layer4(x)
37 | tuple_features.append(x)
38 |
39 | return tuple_features
40 |
41 |
42 | class DilatedResnetBackbone(nn.Module):
43 | def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
44 | super(DilatedResnetBackbone, self).__init__()
45 |
46 | self.num_features = 2048
47 | from functools import partial
48 |
49 | if dilate_scale == 8:
50 | orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
51 | if multi_grid is None:
52 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
53 | else:
54 | for i, r in enumerate(multi_grid):
55 | orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(4 * r)))
56 |
57 | elif dilate_scale == 16:
58 | if multi_grid is None:
59 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
60 | else:
61 | for i, r in enumerate(multi_grid):
62 | orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(2 * r)))
63 |
64 | # Take pretrained resnet, except AvgPool and FC
65 | self.prefix = orig_resnet.prefix
66 | self.maxpool = orig_resnet.maxpool
67 | self.layer1 = orig_resnet.layer1
68 | self.layer2 = orig_resnet.layer2
69 | self.layer3 = orig_resnet.layer3
70 | self.layer4 = orig_resnet.layer4
71 |
72 | def _nostride_dilate(self, m, dilate):
73 | classname = m.__class__.__name__
74 | if classname.find('Conv') != -1:
75 | # the convolution with stride
76 | if m.stride == (2, 2):
77 | m.stride = (1, 1)
78 | if m.kernel_size == (3, 3):
79 | m.dilation = (dilate // 2, dilate // 2)
80 | m.padding = (dilate // 2, dilate // 2)
81 |             # other convolutions
82 | else:
83 | if m.kernel_size == (3, 3):
84 | m.dilation = (dilate, dilate)
85 | m.padding = (dilate, dilate)
86 |
87 | def get_num_features(self):
88 | return self.num_features
89 |
90 | def forward(self, x):
91 | tuple_features = list()
92 | x = self.prefix(x)
93 | x = self.maxpool(x)
94 |
95 | x = self.layer1(x)
96 | tuple_features.append(x)
97 | x = self.layer2(x)
98 | tuple_features.append(x)
99 | x = self.layer3(x)
100 | tuple_features.append(x)
101 | x = self.layer4(x)
102 | tuple_features.append(x)
103 |
104 | return tuple_features
105 |
106 |
107 | def ResNetBackbone(backbone=None, pretrained=None, multi_grid=None, norm_type='batchnorm'):
108 | arch = backbone
109 | if arch == 'resnet34':
110 | orig_resnet = resnet34(pretrained=pretrained)
111 | arch_net = NormalResnetBackbone(orig_resnet)
112 | arch_net.num_features = 512
113 |
114 | elif arch == 'resnet34_dilated8':
115 | orig_resnet = resnet34(pretrained=pretrained)
116 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
117 | arch_net.num_features = 512
118 |
119 | elif arch == 'resnet34_dilated16':
120 | orig_resnet = resnet34(pretrained=pretrained)
121 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
122 | arch_net.num_features = 512
123 |
124 | elif arch == 'resnet50':
125 | orig_resnet = resnet50(pretrained=pretrained)
126 | arch_net = NormalResnetBackbone(orig_resnet)
127 |
128 | elif arch == 'resnet50_dilated8':
129 | orig_resnet = resnet50(pretrained=pretrained)
130 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
131 |
132 | elif arch == 'resnet50_dilated16':
133 | orig_resnet = resnet50(pretrained=pretrained)
134 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
135 |
136 | elif arch == 'deepbase_resnet50':
137 | if pretrained:
138 | pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
139 | orig_resnet = deepbase_resnet50(pretrained=pretrained)
140 | arch_net = NormalResnetBackbone(orig_resnet)
141 |
142 | elif arch == 'deepbase_resnet50_dilated8':
143 | if pretrained:
144 | pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
145 | orig_resnet = deepbase_resnet50(pretrained=pretrained)
146 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
147 |
148 | elif arch == 'deepbase_resnet50_dilated16':
149 | orig_resnet = deepbase_resnet50(pretrained=pretrained)
150 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
151 |
152 | elif arch == 'resnet101':
153 | orig_resnet = resnet101(pretrained=pretrained)
154 | arch_net = NormalResnetBackbone(orig_resnet)
155 |
156 | elif arch == 'resnet101_dilated8':
157 | orig_resnet = resnet101(pretrained=pretrained)
158 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
159 |
160 | elif arch == 'resnet101_dilated16':
161 | orig_resnet = resnet101(pretrained=pretrained)
162 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
163 |
164 | elif arch == 'deepbase_resnet101':
165 | orig_resnet = deepbase_resnet101(pretrained=pretrained)
166 | arch_net = NormalResnetBackbone(orig_resnet)
167 |
168 | elif arch == 'deepbase_resnet101_dilated8':
169 | if pretrained:
170 | pretrained = 'models/backbones/pretrained/3x3resnet101-imagenet.pth'
171 | orig_resnet = deepbase_resnet101(pretrained=pretrained)
172 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
173 |
174 | elif arch == 'deepbase_resnet101_dilated16':
175 | orig_resnet = deepbase_resnet101(pretrained=pretrained)
176 | arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
177 |
178 | else:
179 | raise Exception('Architecture undefined!')
180 |
181 | return arch_net
182 |
--------------------------------------------------------------------------------
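ResNetBackbone drops the classification head and returns the four stage feature maps; for the *_dilated8 variants, the strides of layer3 and layer4 are replaced with dilation, so the last two stages stay at 1/8 of the input resolution. A minimal sketch with random weights (pretrained=None skips weight loading; assumes the repo root is on PYTHONPATH):

import torch
from models.backbones.resnet_backbone import ResNetBackbone

backbone = ResNetBackbone(backbone='resnet50_dilated8', pretrained=None)

x = torch.randn(1, 3, 224, 224)
feats = backbone(x)  # outputs of layer1..layer4
print([tuple(f.shape) for f in feats])
# [(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 28, 28), (1, 2048, 28, 28)]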
/models/backbones/resnet_models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: Donny You(youansheng@gmail.com)
4 |
5 |
6 | import math
7 | import torch.nn as nn
8 | from collections import OrderedDict
9 |
10 | from models.backbones.module_helper import ModuleHelper
11 |
12 |
13 | model_urls = {
14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
19 | }
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1):
23 | "3x3 convolution with padding"
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=1, bias=False)
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
32 | super(BasicBlock, self).__init__()
33 | self.conv1 = conv3x3(inplanes, planes, stride)
34 | self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = conv3x3(planes, planes)
37 | self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
38 | self.downsample = downsample
39 | self.stride = stride
40 |
41 | def forward(self, x):
42 | residual = x
43 |
44 | out = self.conv1(x)
45 | out = self.bn1(out)
46 | out = self.relu(out)
47 |
48 | out = self.conv2(out)
49 | out = self.bn2(out)
50 |
51 | if self.downsample is not None:
52 | residual = self.downsample(x)
53 |
54 | out += residual
55 | out = self.relu(out)
56 |
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
64 | super(Bottleneck, self).__init__()
65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
66 | self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
68 | padding=1, bias=False)
69 | self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
70 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
71 | self.bn3 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * 4)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 |
76 | def forward(self, x):
77 | residual = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv3(out)
88 | out = self.bn3(out)
89 |
90 | if self.downsample is not None:
91 | residual = self.downsample(x)
92 |
93 | out += residual
94 | out = self.relu(out)
95 |
96 | return out
97 |
98 |
99 | class ResNet(nn.Module):
100 |
101 | def __init__(self, block, layers, num_classes=1000, deep_base=False, norm_type=None):
102 | super(ResNet, self).__init__()
103 | self.inplanes = 128 if deep_base else 64
104 | if deep_base:
105 | self.prefix = nn.Sequential(OrderedDict([
106 | ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)),
107 | ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
108 | ('relu1', nn.ReLU(inplace=False)),
109 | ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
110 | ('bn2', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
111 | ('relu2', nn.ReLU(inplace=False)),
112 | ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)),
113 | ('bn3', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
114 | ('relu3', nn.ReLU(inplace=False))]
115 | ))
116 | else:
117 | self.prefix = nn.Sequential(OrderedDict([
118 | ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
119 | ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
120 | ('relu', nn.ReLU(inplace=False))]
121 | ))
122 |
123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False) # change.
124 |
125 | self.layer1 = self._make_layer(block, 64, layers[0], norm_type=norm_type)
126 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_type=norm_type)
127 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_type=norm_type)
128 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_type=norm_type)
129 | self.avgpool = nn.AvgPool2d(7, stride=1)
130 | self.fc = nn.Linear(512 * block.expansion, num_classes)
131 |
132 | for m in self.modules():
133 | if isinstance(m, nn.Conv2d):
134 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
135 | m.weight.data.normal_(0, math.sqrt(2. / n))
136 | elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)):
137 | m.weight.data.fill_(1)
138 | m.bias.data.zero_()
139 |
140 | def _make_layer(self, block, planes, blocks, stride=1, norm_type=None):
141 | downsample = None
142 | if stride != 1 or self.inplanes != planes * block.expansion:
143 | downsample = nn.Sequential(
144 | nn.Conv2d(self.inplanes, planes * block.expansion,
145 | kernel_size=1, stride=stride, bias=False),
146 | ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * block.expansion),
147 | )
148 |
149 | layers = []
150 | layers.append(block(self.inplanes, planes, stride, downsample, norm_type=norm_type))
151 | self.inplanes = planes * block.expansion
152 | for i in range(1, blocks):
153 | layers.append(block(self.inplanes, planes, norm_type=norm_type))
154 |
155 | return nn.Sequential(*layers)
156 |
157 | def forward(self, x):
158 | x = self.prefix(x)
159 | x = self.maxpool(x)
160 |
161 | x = self.layer1(x)
162 | x = self.layer2(x)
163 | x = self.layer3(x)
164 | x = self.layer4(x)
165 |
166 | x = self.avgpool(x)
167 | x = x.view(x.size(0), -1)
168 | x = self.fc(x)
169 |
170 | return x
171 |
172 |
173 | def resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
174 | """Constructs a ResNet-18 model.
175 | Args:
176 |         pretrained (str): optional path to pretrained weights to load
177 | norm_type (str): choose norm type
178 | """
179 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=False, norm_type=norm_type)
180 | model = ModuleHelper.load_model(model, pretrained=pretrained)
181 | return model
182 |
183 | def deepbase_resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
184 | """Constructs a ResNet-18 model.
185 | Args:
186 |         pretrained (str): optional path to pretrained weights to load
187 | """
188 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=True, norm_type=norm_type)
189 | model = ModuleHelper.load_model(model, pretrained=pretrained)
190 | return model
191 |
192 | def resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
193 | """Constructs a ResNet-34 model.
194 | Args:
195 |         pretrained (str): optional path to pretrained weights to load
196 | """
197 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
198 | model = ModuleHelper.load_model(model, pretrained=pretrained)
199 | return model
200 |
201 | def deepbase_resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
202 | """Constructs a ResNet-34 model.
203 | Args:
204 |         pretrained (str): optional path to pretrained weights to load
205 | """
206 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
207 | model = ModuleHelper.load_model(model, pretrained=pretrained)
208 | return model
209 |
210 | def resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
211 | """Constructs a ResNet-50 model.
212 | Args:
213 |         pretrained (str): optional path to pretrained weights to load
214 | """
215 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
216 | model = ModuleHelper.load_model(model, pretrained=pretrained)
217 | return model
218 |
219 | def deepbase_resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
220 | """Constructs a ResNet-50 model.
221 | Args:
222 |         pretrained (str): optional path to pretrained weights to load
223 | """
224 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
225 | model = ModuleHelper.load_model(model, pretrained=pretrained)
226 | return model
227 |
228 | def resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
229 | """Constructs a ResNet-101 model.
230 | Args:
231 |         pretrained (str): optional path to pretrained weights to load
232 | """
233 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
234 | model = ModuleHelper.load_model(model, pretrained=pretrained)
235 | return model
236 |
237 | def deepbase_resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
238 | """Constructs a ResNet-101 model.
239 | Args:
240 |         pretrained (str): optional path to pretrained weights to load
241 | """
242 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
243 | model = ModuleHelper.load_model(model, pretrained=pretrained)
244 | return model
245 |
246 | def resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
247 | """Constructs a ResNet-152 model.
248 |
249 | Args:
250 |         pretrained (str): optional path to pretrained weights to load
251 | """
252 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
253 | model = ModuleHelper.load_model(model, pretrained=pretrained)
254 | return model
255 |
256 | def deepbase_resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
257 | """Constructs a ResNet-152 model.
258 |
259 | Args:
260 |         pretrained (str): optional path to pretrained weights to load
261 | """
262 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
263 | model = ModuleHelper.load_model(model, pretrained=pretrained)
264 | return model
265 |
--------------------------------------------------------------------------------
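The deep_base variants differ only in the stem: instead of a single 7x7 stride-2 convolution, the prefix is three 3x3 convolutions (3→64→64→128), each followed by BN and ReLU; this is the stem the bundled 3x3resnet50-imagenet.pth checkpoint is intended for (see resnet_backbone.py above). A minimal sketch comparing the two stems with random weights (pretrained=None; assumes the repo root is on PYTHONPATH):

import torch
from models.backbones.resnet_models import resnet50, deepbase_resnet50

plain = resnet50(pretrained=None)           # 7x7/2 conv stem, inplanes = 64
deep = deepbase_resnet50(pretrained=None)   # three 3x3 conv stem, inplanes = 128

x = torch.randn(1, 3, 224, 224)
print(plain(x).shape, deep(x).shape)                                   # torch.Size([1, 1000]) for both
print(plain.prefix.conv1.kernel_size, deep.prefix.conv1.kernel_size)   # (7, 7) (3, 3)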
/models/decoders.py:
--------------------------------------------------------------------------------
1 | import math , time
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from utils.helpers import initialize_weights
6 | from itertools import chain
7 | import contextlib
8 | import random
9 | import numpy as np
10 | import cv2
11 | from torch.distributions.uniform import Uniform
12 |
13 |
14 | def icnr(x, scale=2, init=nn.init.kaiming_normal_):
15 | """
16 | Checkerboard artifact free sub-pixel convolution
17 | https://arxiv.org/abs/1707.02937
18 | """
19 | ni,nf,h,w = x.shape
20 | ni2 = int(ni/(scale**2))
21 | k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
22 | k = k.contiguous().view(ni2, nf, -1)
23 | k = k.repeat(1, 1, scale**2)
24 | k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
25 | x.data.copy_(k)
26 |
27 |
28 | class PixelShuffle(nn.Module):
29 | """
30 | Real-Time Single Image and Video Super-Resolution
31 | https://arxiv.org/abs/1609.05158
32 | """
33 | def __init__(self, n_channels, scale):
34 | super(PixelShuffle, self).__init__()
35 | self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1)
36 | icnr(self.conv.weight)
37 | self.shuf = nn.PixelShuffle(scale)
38 | self.relu = nn.ReLU(inplace=True)
39 |
40 | def forward(self,x):
41 | x = self.shuf(self.relu(self.conv(x)))
42 | return x
43 |
44 |
45 | def upsample(in_channels, out_channels, upscale, kernel_size=3):
46 |     # A series of x2 upsampling blocks until we reach the desired upscale factor
47 | layers = []
48 | conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
49 | nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu')
50 | layers.append(conv1x1)
51 | for i in range(int(math.log(upscale, 2))):
52 | layers.append(PixelShuffle(out_channels, scale=2))
53 | return nn.Sequential(*layers)
54 |
55 |
56 | class MainDecoder(nn.Module):
57 | def __init__(self, upscale, conv_in_ch, num_classes):
58 | super(MainDecoder, self).__init__()
59 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
60 |
61 | def forward(self, x):
62 | x = self.upsample(x)
63 | return x
64 |
65 |
66 | class DropOutDecoder(nn.Module):
67 | def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True):
68 | super(DropOutDecoder, self).__init__()
69 | self.dropout = nn.Dropout2d(p=drop_rate) if spatial_dropout else nn.Dropout(drop_rate)
70 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
71 |
72 | def forward(self, x, _):
73 | x = self.upsample(self.dropout(x))
74 | return x
75 |
76 |
77 | class FeatureDropDecoder(nn.Module):
78 | def __init__(self, upscale, conv_in_ch, num_classes):
79 | super(FeatureDropDecoder, self).__init__()
80 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
81 |
82 | def feature_dropout(self, x):
83 | attention = torch.mean(x, dim=1, keepdim=True)
84 | max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)
85 | threshold = max_val * np.random.uniform(0.7, 0.9)
86 | threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
87 | drop_mask = (attention < threshold).float()
88 | return x.mul(drop_mask)
89 |
90 | def forward(self, x, _):
91 | x = self.feature_dropout(x)
92 | x = self.upsample(x)
93 | return x
94 |
95 |
96 | class FeatureNoiseDecoder(nn.Module):
97 | def __init__(self, upscale, conv_in_ch, num_classes, uniform_range=0.3):
98 | super(FeatureNoiseDecoder, self).__init__()
99 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
100 | self.uni_dist = Uniform(-uniform_range, uniform_range)
101 |
102 | def feature_based_noise(self, x):
103 | noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)
104 | x_noise = x.mul(noise_vector) + x
105 | return x_noise
106 |
107 | def forward(self, x, _):
108 | x = self.feature_based_noise(x)
109 | x = self.upsample(x)
110 | return x
111 |
112 |
113 |
114 | def _l2_normalize(d):
115 | # Normalizing per batch axis
116 | d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
117 | d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
118 | return d
119 |
120 |
121 | def get_r_adv(x, decoder, it=1, xi=1e-1, eps=10.0):
122 | """
123 | Virtual Adversarial Training
124 | https://arxiv.org/abs/1704.03976
125 | """
126 | x_detached = x.detach()
127 | with torch.no_grad():
128 | pred = F.softmax(decoder(x_detached), dim=1)
129 |
130 | d = torch.rand(x.shape).sub(0.5).to(x.device)
131 | d = _l2_normalize(d)
132 |
133 | for _ in range(it):
134 | d.requires_grad_()
135 | pred_hat = decoder(x_detached + xi * d)
136 | logp_hat = F.log_softmax(pred_hat, dim=1)
137 | adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
138 | adv_distance.backward()
139 | d = _l2_normalize(d.grad)
140 | decoder.zero_grad()
141 |
142 | r_adv = d * eps
143 | return r_adv
144 |
145 |
146 | class VATDecoder(nn.Module):
147 | def __init__(self, upscale, conv_in_ch, num_classes, xi=1e-1, eps=10.0, iterations=1):
148 | super(VATDecoder, self).__init__()
149 | self.xi = xi
150 | self.eps = eps
151 | self.it = iterations
152 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
153 |
154 | def forward(self, x, _):
155 | r_adv = get_r_adv(x, self.upsample, self.it, self.xi, self.eps)
156 | x = self.upsample(x + r_adv)
157 | return x
158 |
159 |
160 |
161 | def guided_cutout(output, upscale, resize, erase=0.4, use_dropout=False):
162 | if len(output.shape) == 3:
163 | masks = (output > 0).float()
164 | else:
165 | masks = (output.argmax(1) > 0).float()
166 |
167 | if use_dropout:
168 | p_drop = random.randint(3, 6)/10
169 | maskdroped = (F.dropout(masks, p_drop) > 0).float()
170 | maskdroped = maskdroped + (1 - masks)
171 | maskdroped.unsqueeze_(0)
172 | maskdroped = F.interpolate(maskdroped, size=resize, mode='nearest')
173 |
174 | masks_np = []
175 | for mask in masks:
176 | mask_np = np.uint8(mask.cpu().numpy())
177 | mask_ones = np.ones_like(mask_np)
178 | try: # Version 3.x
179 | _, contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
180 | except: # Version 4.x
181 | contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
182 |
183 | polys = [c.reshape(c.shape[0], c.shape[-1]) for c in contours if c.shape[0] > 50]
184 | for poly in polys:
185 | min_w, max_w = poly[:, 0].min(), poly[:, 0].max()
186 | min_h, max_h = poly[:, 1].min(), poly[:, 1].max()
187 | bb_w, bb_h = max_w-min_w, max_h-min_h
188 | rnd_start_w = random.randint(0, int(bb_w*(1-erase)))
189 | rnd_start_h = random.randint(0, int(bb_h*(1-erase)))
190 | h_start, h_end = min_h+rnd_start_h, min_h+rnd_start_h+int(bb_h*erase)
191 | w_start, w_end = min_w+rnd_start_w, min_w+rnd_start_w+int(bb_w*erase)
192 | mask_ones[h_start:h_end, w_start:w_end] = 0
193 | masks_np.append(mask_ones)
194 | masks_np = np.stack(masks_np)
195 |
196 | maskcut = torch.from_numpy(masks_np).float().unsqueeze_(1)
197 | maskcut = F.interpolate(maskcut, size=resize, mode='nearest')
198 |
199 | if use_dropout:
200 | return maskcut.to(output.device), maskdroped.to(output.device)
201 | return maskcut.to(output.device)
202 |
203 |
204 | class CutOutDecoder(nn.Module):
205 | def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True, erase=0.4):
206 | super(CutOutDecoder, self).__init__()
207 | self.erase = erase
208 | self.upscale = upscale
209 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
210 |
211 | def forward(self, x, pred=None):
212 | maskcut = guided_cutout(pred, upscale=self.upscale, erase=self.erase, resize=(x.size(2), x.size(3)))
213 | x = x * maskcut
214 | x = self.upsample(x)
215 | return x
216 |
217 |
218 | def guided_masking(x, output, upscale, resize, return_msk_context=True):
219 | if len(output.shape) == 3:
220 | masks_context = (output > 0).float().unsqueeze(1)
221 | else:
222 | masks_context = (output.argmax(1) > 0).float().unsqueeze(1)
223 |
224 | masks_context = F.interpolate(masks_context, size=resize, mode='nearest')
225 |
226 | x_masked_context = masks_context * x
227 | if return_msk_context:
228 | return x_masked_context
229 |
230 | masks_objects = (1 - masks_context)
231 | x_masked_objects = masks_objects * x
232 | return x_masked_objects
233 |
234 |
235 | class ContextMaskingDecoder(nn.Module):
236 | def __init__(self, upscale, conv_in_ch, num_classes):
237 | super(ContextMaskingDecoder, self).__init__()
238 | self.upscale = upscale
239 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
240 |
241 | def forward(self, x, pred=None):
242 | x_masked_context = guided_masking(x, pred, resize=(x.size(2), x.size(3)),
243 | upscale=self.upscale, return_msk_context=True)
244 | x_masked_context = self.upsample(x_masked_context)
245 | return x_masked_context
246 |
247 |
248 | class ObjectMaskingDecoder(nn.Module):
249 | def __init__(self, upscale, conv_in_ch, num_classes):
250 | super(ObjectMaskingDecoder, self).__init__()
251 | self.upscale = upscale
252 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
253 |
254 | def forward(self, x, pred=None):
255 | x_masked_obj = guided_masking(x, pred, resize=(x.size(2), x.size(3)),
256 | upscale=self.upscale, return_msk_context=False)
257 | x_masked_obj = self.upsample(x_masked_obj)
258 |
259 | return x_masked_obj
260 |
261 |
--------------------------------------------------------------------------------
/models/encoder.py:
--------------------------------------------------------------------------------
1 | from models.backbones.resnet_backbone import ResNetBackbone
2 | from utils.helpers import initialize_weights
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import os
7 |
8 | resnet50 = {
9 | "path": "models/backbones/pretrained/3x3resnet50-imagenet.pth",
10 | }
11 |
12 | class _PSPModule(nn.Module):
13 | def __init__(self, in_channels, bin_sizes):
14 | super(_PSPModule, self).__init__()
15 |
16 | out_channels = in_channels // len(bin_sizes)
17 | self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s) for b_s in bin_sizes])
18 | self.bottleneck = nn.Sequential(
19 | nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), out_channels,
20 | kernel_size=3, padding=1, bias=False),
21 | nn.BatchNorm2d(out_channels),
22 | nn.ReLU(inplace=True)
23 | )
24 |
25 | def _make_stages(self, in_channels, out_channels, bin_sz):
26 | prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
27 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
28 | bn = nn.BatchNorm2d(out_channels)
29 | relu = nn.ReLU(inplace=True)
30 | return nn.Sequential(prior, conv, bn, relu)
31 |
32 | def forward(self, features):
33 | h, w = features.size()[2], features.size()[3]
34 | pyramids = [features]
35 | pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
36 | align_corners=False) for stage in self.stages])
37 | output = self.bottleneck(torch.cat(pyramids, dim=1))
38 | return output
39 |
40 |
41 | class Encoder(nn.Module):
42 | def __init__(self, pretrained):
43 | super(Encoder, self).__init__()
44 |
45 | if pretrained and not os.path.isfile(resnet50["path"]):
46 | print("Downloading pretrained resnet (source : https://github.com/donnyyou/torchcv)")
47 | os.system('sh models/backbones/get_pretrained_model.sh')
48 |
49 | model = ResNetBackbone(backbone='deepbase_resnet50_dilated8', pretrained=pretrained)
50 | self.base = nn.Sequential(
51 | nn.Sequential(model.prefix, model.maxpool),
52 | model.layer1,
53 | model.layer2,
54 | model.layer3,
55 | model.layer4
56 | )
57 | self.psp = _PSPModule(2048, bin_sizes=[1, 2, 3, 6])
58 |
59 | def forward(self, x):
60 | x = self.base(x)
61 | x = self.psp(x)
62 | return x
63 |
64 | def get_backbone_params(self):
65 | return self.base.parameters()
66 |
67 | def get_module_params(self):
68 | return self.psp.parameters()
69 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import math, time
2 | from itertools import chain
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 | from base import BaseModel
7 | from utils.helpers import set_trainable
8 | from utils.losses import *
9 | from models.decoders import *
10 | from models.encoder import Encoder
11 | from utils.losses import CE_loss
12 |
13 | class CCT(BaseModel):
14 | def __init__(self, num_classes, conf, sup_loss=None, cons_w_unsup=None, ignore_index=None, testing=False,
15 | pretrained=True, use_weak_lables=False, weakly_loss_w=0.4):
16 |
17 | if not testing:
18 | assert (ignore_index is not None) and (sup_loss is not None) and (cons_w_unsup is not None)
19 |
20 | super(CCT, self).__init__()
21 | assert int(conf['supervised']) + int(conf['semi']) == 1, 'one mode only'
22 | if conf['supervised']:
23 | self.mode = 'supervised'
24 | else:
25 | self.mode = 'semi'
26 |
27 | # Supervised and unsupervised losses
28 | self.ignore_index = ignore_index
29 | if conf['un_loss'] == "KL":
30 | self.unsuper_loss = softmax_kl_loss
31 | elif conf['un_loss'] == "MSE":
32 | self.unsuper_loss = softmax_mse_loss
33 | elif conf['un_loss'] == "JS":
34 | self.unsuper_loss = softmax_js_loss
35 | else:
36 |             raise ValueError(f"Invalid unsupervised loss {conf['un_loss']}")
37 |
38 | self.unsup_loss_w = cons_w_unsup
39 | self.sup_loss_w = conf['supervised_w']
40 | self.softmax_temp = conf['softmax_temp']
41 | self.sup_loss = sup_loss
42 | self.sup_type = conf['sup_loss']
43 |
44 | # Use weak labels
45 | self.use_weak_lables = use_weak_lables
46 | self.weakly_loss_w = weakly_loss_w
47 | # pair wise loss (sup mat)
48 | self.aux_constraint = conf['aux_constraint']
49 | self.aux_constraint_w = conf['aux_constraint_w']
50 | # confidence masking (sup mat)
51 | self.confidence_th = conf['confidence_th']
52 | self.confidence_masking = conf['confidence_masking']
53 |
54 | # Create the model
55 | self.encoder = Encoder(pretrained=pretrained)
56 |
57 |         # The main decoder
58 | upscale = 8
59 | num_out_ch = 2048
60 | decoder_in_ch = num_out_ch // 4
61 | self.main_decoder = MainDecoder(upscale, decoder_in_ch, num_classes=num_classes)
62 |
63 |         # The auxiliary decoders
64 | if self.mode == 'semi' or self.mode == 'weakly_semi':
65 | vat_decoder = [VATDecoder(upscale, decoder_in_ch, num_classes, xi=conf['xi'],
66 | eps=conf['eps']) for _ in range(conf['vat'])]
67 | drop_decoder = [DropOutDecoder(upscale, decoder_in_ch, num_classes,
68 | drop_rate=conf['drop_rate'], spatial_dropout=conf['spatial'])
69 | for _ in range(conf['drop'])]
70 | cut_decoder = [CutOutDecoder(upscale, decoder_in_ch, num_classes, erase=conf['erase'])
71 | for _ in range(conf['cutout'])]
72 | context_m_decoder = [ContextMaskingDecoder(upscale, decoder_in_ch, num_classes)
73 | for _ in range(conf['context_masking'])]
74 | object_masking = [ObjectMaskingDecoder(upscale, decoder_in_ch, num_classes)
75 | for _ in range(conf['object_masking'])]
76 | feature_drop = [FeatureDropDecoder(upscale, decoder_in_ch, num_classes)
77 | for _ in range(conf['feature_drop'])]
78 | feature_noise = [FeatureNoiseDecoder(upscale, decoder_in_ch, num_classes,
79 | uniform_range=conf['uniform_range'])
80 | for _ in range(conf['feature_noise'])]
81 |
82 | self.aux_decoders = nn.ModuleList([*vat_decoder, *drop_decoder, *cut_decoder,
83 | *context_m_decoder, *object_masking, *feature_drop, *feature_noise])
84 |
85 | def forward(self, x_l=None, target_l=None, x_ul=None, target_ul=None, curr_iter=None, epoch=None):
86 | if not self.training:
87 | return self.main_decoder(self.encoder(x_l))
88 |
89 |         # We compute the losses in the forward pass to avoid problems encountered in multi-GPU training
90 |
91 |         # Forward pass on the labeled examples
92 | input_size = (x_l.size(2), x_l.size(3))
93 | output_l = self.main_decoder(self.encoder(x_l))
94 | if output_l.shape != x_l.shape:
95 | output_l = F.interpolate(output_l, size=input_size, mode='bilinear', align_corners=True)
96 |
97 | # Supervised loss
98 | if self.sup_type == 'CE':
99 | loss_sup = self.sup_loss(output_l, target_l, ignore_index=self.ignore_index, temperature=self.softmax_temp) * self.sup_loss_w
100 | elif self.sup_type == 'FL':
101 | loss_sup = self.sup_loss(output_l,target_l) * self.sup_loss_w
102 | else:
103 | loss_sup = self.sup_loss(output_l, target_l, curr_iter=curr_iter, epoch=epoch, ignore_index=self.ignore_index) * self.sup_loss_w
104 |
105 | # If supervised mode only, return
106 | if self.mode == 'supervised':
107 | curr_losses = {'loss_sup': loss_sup}
108 | outputs = {'sup_pred': output_l}
109 | total_loss = loss_sup
110 | return total_loss, curr_losses, outputs
111 |
112 | # If semi supervised mode
113 | elif self.mode == 'semi':
114 | # Get main prediction
115 | x_ul = self.encoder(x_ul)
116 | output_ul = self.main_decoder(x_ul)
117 |
118 | # Get auxiliary predictions
119 | outputs_ul = [aux_decoder(x_ul, output_ul.detach()) for aux_decoder in self.aux_decoders]
120 | targets = F.softmax(output_ul.detach(), dim=1)
121 |
122 | # Compute unsupervised loss
123 | loss_unsup = sum([self.unsuper_loss(inputs=u, targets=targets, \
124 | conf_mask=self.confidence_masking, threshold=self.confidence_th, use_softmax=False)
125 | for u in outputs_ul])
126 | loss_unsup = (loss_unsup / len(outputs_ul))
127 | curr_losses = {'loss_sup': loss_sup}
128 |
129 | if output_ul.shape != x_l.shape:
130 | output_ul = F.interpolate(output_ul, size=input_size, mode='bilinear', align_corners=True)
131 | outputs = {'sup_pred': output_l, 'unsup_pred': output_ul}
132 |
133 |             # Weight the unsupervised loss by the consistency ramp-up
134 | weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)
135 | loss_unsup = loss_unsup * weight_u
136 | curr_losses['loss_unsup'] = loss_unsup
137 | total_loss = loss_unsup + loss_sup
138 |
139 |             # In case we're using weak labels, add the weak loss term with a weight (self.weakly_loss_w)
140 | if self.use_weak_lables:
141 | weight_w = (weight_u / self.unsup_loss_w.final_w) * self.weakly_loss_w
142 | loss_weakly = sum([CE_loss(outp, target_ul, ignore_index=self.ignore_index) for outp in outputs_ul]) / len(outputs_ul)
143 | loss_weakly = loss_weakly * weight_w
144 | curr_losses['loss_weakly'] = loss_weakly
145 | total_loss += loss_weakly
146 |
147 | # Pair-wise loss
148 | if self.aux_constraint:
149 | pair_wise = pair_wise_loss(outputs_ul) * self.aux_constraint_w
150 | curr_losses['pair_wise'] = pair_wise
151 | loss_unsup += pair_wise
152 |
153 | return total_loss, curr_losses, outputs
154 |
155 | def get_backbone_params(self):
156 | return self.encoder.get_backbone_params()
157 |
158 | def get_other_params(self):
159 | if self.mode == 'semi':
160 | return chain(self.encoder.get_module_params(), self.main_decoder.parameters(),
161 | self.aux_decoders.parameters())
162 |
163 | return chain(self.encoder.get_module_params(), self.main_decoder.parameters())
164 |
165 |
--------------------------------------------------------------------------------
/pseudo_labels/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Generating Pseudo-Labels
4 |
5 | This is third-party code adapted for our use case. We thank the original authors for
6 | providing the implementation of their work; please check it out if you are interested:
7 | * Paper: [Weakly Supervised Learning of Instance Segmentation with Inter-pixel Relations](https://arxiv.org/abs/1904.05044)
8 | * Code: [Jiwoon Ahn's irn](https://github.com/jiwoon-ahn/irn)
9 |
10 | This code generates pseudo pixel-level labels from image-level class labels. This is done in three steps:
11 |
12 | * `train_cam.py`: first, we fine-tune a resnet50 pretrained on ImageNet (from torchvision) on Pascal VOC for multi-label image classification
13 | over the 20 object classes. For fast training, the batch-norm layers are frozen, and a high learning rate is only used for the last classification
14 | layer after an average pool.
15 | * `make_cam.py`: using the resnet fine-tuned on Pascal VOC, we follow the traditional CAM
16 | ([paper](https://arxiv.org/pdf/1512.04150.pdf)) approach to generate localization maps. This is done
17 | by simply weighting the activations of the last resnet block by the learned weights of the classification layer
18 | (a short sketch is given after this list). We then only keep the maps of the ground-truth classes.
19 | * `cam_to_pseudo_labels.py`: the last step is a refinement step that keeps only the highly confident regions, while the non-confident regions
20 | are marked as ignored. A CRF refinement is also applied before saving the pseudo-labels.
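For reference, the CAM weighting described in the second step boils down to a per-class weighted sum of the final feature maps, followed by keeping only the ground-truth classes and normalizing each map by its maximum. Below is a minimal sketch of this computation, assuming a feature tensor `feats` of shape `(2048, H, W)` and 1x1 classifier weights `w` of shape `(20, 2048)`; the helper name and variables are illustrative, the actual multi-scale version lives in `make_cam.py`:

```python
import torch
import torch.nn.functional as F

def cam_from_features(feats, w, gt_classes):
    # feats: (2048, H, W) activations of the last resnet block
    # w: (20, 2048) weights of the 1x1 classification layer
    # gt_classes: indices of the image's ground-truth classes
    cams = torch.einsum('kc,chw->khw', w, feats)        # one localization map per class
    cams = F.relu(cams)[gt_classes]                      # keep only the ground-truth classes
    cams /= F.adaptive_max_pool2d(cams, (1, 1)) + 1e-5   # normalize each map to [0, 1]
    return cams
```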
21 |
22 |
23 |
24 | To generate the pseudo-labels, simply run:
25 |
26 | ```bash
27 | python run.py --voc12_root DATA_PATH
28 | ```
29 |
30 | `DATA_PATH` must point to the folder containing `JPEGImages` in the Pascal VOC dataset.
31 |
32 | The results will be saved in `result/pseudo_labels` as PNG files, which will be used to train the auxiliary decoders of CCT
33 | in the weakly semi-supervised setting.
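Each PNG stores one class index per pixel (0 for background), with 255 marking the low-confidence pixels that the segmentation loss later ignores. A minimal sketch of how one of the generated files could be inspected (the file name below is a placeholder):

```python
import imageio
import numpy as np

# Placeholder file name; any PNG written to result/pseudo_labels works here.
label = np.asarray(imageio.imread('result/pseudo_labels/2007_000032.png'))
classes, counts = np.unique(label, return_counts=True)
print(dict(zip(classes.tolist(), counts.tolist())))  # 255 = ignored (low-confidence) pixels
```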
34 |
35 | If you find this code useful, please consider citing the original [paper](https://arxiv.org/abs/1904.05044).
--------------------------------------------------------------------------------
/pseudo_labels/cam_to_pseudo_labels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 | from torch import multiprocessing
5 | from torch.utils.data import DataLoader
6 | import voc12.dataloader
7 | from misc import torchutils, imutils
8 |
9 |
10 | def _work(process_id, infer_dataset, args):
11 |
12 | databin = infer_dataset[process_id]
13 | infer_data_loader = DataLoader(databin, shuffle=False, num_workers=0, pin_memory=False)
14 |
15 | for iter, pack in enumerate(infer_data_loader):
16 | img_name = voc12.dataloader.decode_int_filename(pack['name'][0])
17 | img = pack['img'][0].numpy()
18 | cam_dict = np.load(os.path.join(args.cam_out_dir, img_name + '.npy'), allow_pickle=True).item()
19 |
20 | cams = cam_dict['high_res']
21 | keys = np.pad(cam_dict['keys'] + 1, (1, 0), mode='constant')
22 |
23 | # 1. find confident fg & bg
24 | fg_conf_cam = np.pad(cams, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=args.conf_fg_thres)
25 | fg_conf_cam = np.argmax(fg_conf_cam, axis=0)
26 | pred = imutils.crf_inference_label(img, fg_conf_cam, n_labels=keys.shape[0])
27 | fg_conf = keys[pred]
28 |
29 | bg_conf_cam = np.pad(cams, ((1, 0), (0, 0), (0, 0)), mode='constant', constant_values=args.conf_bg_thres)
30 | bg_conf_cam = np.argmax(bg_conf_cam, axis=0)
31 | pred = imutils.crf_inference_label(img, bg_conf_cam, n_labels=keys.shape[0])
32 | bg_conf = keys[pred]
33 |
34 | # 2. combine confident fg & bg
35 | conf = fg_conf.copy()
36 | conf[fg_conf == 0] = 255
37 | conf[bg_conf + fg_conf == 0] = 0
38 |
39 | imageio.imwrite(os.path.join(args.pseudo_labels_out_dir, img_name + '.png'), conf.astype(np.uint8))
40 |
41 | if process_id == args.num_workers - 1 and iter % (len(databin) // 20) == 0:
42 | print("%d " % ((5 * iter + 1) // (len(databin) // 20)), end='')
43 |
44 | def run(args):
45 | dataset = voc12.dataloader.VOC12ImageDataset(args.train_list, voc12_root=args.voc12_root, img_normal=None, to_torch=False)
46 | dataset = torchutils.split_dataset(dataset, args.num_workers)
47 |
48 | print('[ ', end='')
49 | multiprocessing.spawn(_work, nprocs=args.num_workers, args=(dataset, args), join=True)
50 | print(']')
51 |
--------------------------------------------------------------------------------
/pseudo_labels/make_cam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import multiprocessing, cuda
3 | from torch.utils.data import DataLoader
4 | import torch.nn.functional as F
5 | from torch.backends import cudnn
6 |
7 | import numpy as np
8 | import importlib
9 | import os
10 |
11 | import voc12.dataloader
12 | from misc import torchutils, imutils
13 |
14 | cudnn.enabled = True
15 |
16 | def _work(process_id, model, dataset, args):
17 |
18 | databin = dataset[process_id]
19 | n_gpus = torch.cuda.device_count()
20 | data_loader = DataLoader(databin, shuffle=False, num_workers=args.num_workers // n_gpus, pin_memory=False)
21 |
22 | with torch.no_grad(), cuda.device(process_id):
23 |
24 | model.cuda()
25 |
26 | for iter, pack in enumerate(data_loader):
27 |
28 | img_name = pack['name'][0]
29 | label = pack['label'][0]
30 | size = pack['size']
31 |
32 | strided_size = imutils.get_strided_size(size, 4)
33 | strided_up_size = imutils.get_strided_up_size(size, 16)
34 |
35 | outputs = [model(img[0].cuda(non_blocking=True))
36 | for img in pack['img']]
37 |
38 | strided_cam = torch.sum(torch.stack(
39 | [F.interpolate(torch.unsqueeze(o, 0), strided_size, mode='bilinear', align_corners=False)[0] for o
40 | in outputs]), 0)
41 |
42 | highres_cam = [F.interpolate(torch.unsqueeze(o, 1), strided_up_size,
43 | mode='bilinear', align_corners=False) for o in outputs]
44 | highres_cam = torch.sum(torch.stack(highres_cam, 0), 0)[:, 0, :size[0], :size[1]]
45 |
46 | valid_cat = torch.nonzero(label)[:, 0]
47 |
48 | strided_cam = strided_cam[valid_cat]
49 | strided_cam /= F.adaptive_max_pool2d(strided_cam, (1, 1)) + 1e-5
50 |
51 | highres_cam = highres_cam[valid_cat]
52 | highres_cam /= F.adaptive_max_pool2d(highres_cam, (1, 1)) + 1e-5
53 |
54 | # save cams
55 | np.save(os.path.join(args.cam_out_dir, img_name + '.npy'),
56 | {"keys": valid_cat, "cam": strided_cam.cpu(), "high_res": highres_cam.cpu().numpy()})
57 |
58 | if process_id == n_gpus - 1 and iter % (len(databin) // 20) == 0:
59 | print("%d " % ((5*iter+1)//(len(databin) // 20)), end='')
60 |
61 |
62 | def run(args):
63 | model = getattr(importlib.import_module(args.cam_network), 'CAM')()
64 | model.load_state_dict(torch.load(args.cam_weights_name + '.pth'), strict=True)
65 | model.eval()
66 |
67 | n_gpus = torch.cuda.device_count()
68 |
69 | dataset = voc12.dataloader.VOC12ClassificationDatasetMSF(args.train_list,
70 | voc12_root=args.voc12_root, scales=args.cam_scales)
71 | dataset = torchutils.split_dataset(dataset, n_gpus)
72 |
73 | print('[ ', end='')
74 | multiprocessing.spawn(_work, nprocs=n_gpus, args=(model, dataset, args), join=True)
75 | print(']')
76 |
77 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/pseudo_labels/misc/imutils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 |
4 | import pydensecrf.densecrf as dcrf
5 | from pydensecrf.utils import unary_from_labels
6 | from PIL import Image
7 |
8 | def pil_resize(img, size, order):
9 | if size[0] == img.shape[0] and size[1] == img.shape[1]:
10 | return img
11 |
12 | if order == 3:
13 | resample = Image.BICUBIC
14 | elif order == 0:
15 | resample = Image.NEAREST
16 |
17 | return np.asarray(Image.fromarray(img).resize(size[::-1], resample))
18 |
19 | def pil_rescale(img, scale, order):
20 | height, width = img.shape[:2]
21 | target_size = (int(np.round(height*scale)), int(np.round(width*scale)))
22 | return pil_resize(img, target_size, order)
23 |
24 |
25 | def random_resize_long(img, min_long, max_long):
26 | target_long = random.randint(min_long, max_long)
27 | h, w = img.shape[:2]
28 |
29 | if w < h:
30 | scale = target_long / h
31 | else:
32 | scale = target_long / w
33 |
34 | return pil_rescale(img, scale, 3)
35 |
36 | def random_scale(img, scale_range, order):
37 | target_scale = scale_range[0] + random.random() * (scale_range[1] - scale_range[0])
38 | if isinstance(img, tuple):
39 | return (pil_rescale(img[0], target_scale, order[0]), pil_rescale(img[1], target_scale, order[1]))
40 | else:
41 |         return pil_rescale(img, target_scale, order)
42 |
43 | def random_lr_flip(img):
44 |
45 | if bool(random.getrandbits(1)):
46 | if isinstance(img, tuple):
47 | return [np.fliplr(m) for m in img]
48 | else:
49 | return np.fliplr(img)
50 | else:
51 | return img
52 |
53 | def get_random_crop_box(imgsize, cropsize):
54 | h, w = imgsize
55 |
56 | ch = min(cropsize, h)
57 | cw = min(cropsize, w)
58 |
59 | w_space = w - cropsize
60 | h_space = h - cropsize
61 |
62 | if w_space > 0:
63 | cont_left = 0
64 | img_left = random.randrange(w_space + 1)
65 | else:
66 | cont_left = random.randrange(-w_space + 1)
67 | img_left = 0
68 |
69 | if h_space > 0:
70 | cont_top = 0
71 | img_top = random.randrange(h_space + 1)
72 | else:
73 | cont_top = random.randrange(-h_space + 1)
74 | img_top = 0
75 |
76 | return cont_top, cont_top+ch, cont_left, cont_left+cw, img_top, img_top+ch, img_left, img_left+cw
77 |
78 | def random_crop(images, cropsize, default_values):
79 |
80 | if isinstance(images, np.ndarray): images = (images,)
81 | if isinstance(default_values, int): default_values = (default_values,)
82 |
83 | imgsize = images[0].shape[:2]
84 | box = get_random_crop_box(imgsize, cropsize)
85 |
86 | new_images = []
87 | for img, f in zip(images, default_values):
88 |
89 | if len(img.shape) == 3:
90 | cont = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*f
91 | else:
92 | cont = np.ones((cropsize, cropsize), img.dtype)*f
93 | cont[box[0]:box[1], box[2]:box[3]] = img[box[4]:box[5], box[6]:box[7]]
94 | new_images.append(cont)
95 |
96 | if len(new_images) == 1:
97 | new_images = new_images[0]
98 |
99 | return new_images
100 |
101 | def top_left_crop(img, cropsize, default_value):
102 |
103 | h, w = img.shape[:2]
104 |
105 | ch = min(cropsize, h)
106 | cw = min(cropsize, w)
107 |
108 | if len(img.shape) == 2:
109 | container = np.ones((cropsize, cropsize), img.dtype)*default_value
110 | else:
111 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value
112 |
113 | container[:ch, :cw] = img[:ch, :cw]
114 |
115 | return container
116 |
117 | def center_crop(img, cropsize, default_value=0):
118 |
119 | h, w = img.shape[:2]
120 |
121 | ch = min(cropsize, h)
122 | cw = min(cropsize, w)
123 |
124 | sh = h - cropsize
125 | sw = w - cropsize
126 |
127 | if sw > 0:
128 | cont_left = 0
129 | img_left = int(round(sw / 2))
130 | else:
131 | cont_left = int(round(-sw / 2))
132 | img_left = 0
133 |
134 | if sh > 0:
135 | cont_top = 0
136 | img_top = int(round(sh / 2))
137 | else:
138 | cont_top = int(round(-sh / 2))
139 | img_top = 0
140 |
141 | if len(img.shape) == 2:
142 | container = np.ones((cropsize, cropsize), img.dtype)*default_value
143 | else:
144 | container = np.ones((cropsize, cropsize, img.shape[2]), img.dtype)*default_value
145 |
146 | container[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
147 | img[img_top:img_top+ch, img_left:img_left+cw]
148 |
149 | return container
150 |
151 | def HWC_to_CHW(img):
152 | return np.transpose(img, (2, 0, 1))
153 |
154 | def crf_inference_label(img, labels, t=10, n_labels=21, gt_prob=0.7):
155 |
156 | h, w = img.shape[:2]
157 |
158 | d = dcrf.DenseCRF2D(w, h, n_labels)
159 |
160 | unary = unary_from_labels(labels, n_labels, gt_prob=gt_prob, zero_unsure=False)
161 |
162 | d.setUnaryEnergy(unary)
163 | d.addPairwiseGaussian(sxy=3, compat=3)
164 | d.addPairwiseBilateral(sxy=50, srgb=5, rgbim=np.ascontiguousarray(np.copy(img)), compat=10)
165 |
166 | q = d.inference(t)
167 |
168 | return np.argmax(np.array(q).reshape((n_labels, h, w)), axis=0)
169 |
170 |
171 | def get_strided_size(orig_size, stride):
172 | return ((orig_size[0]-1)//stride+1, (orig_size[1]-1)//stride+1)
173 |
174 |
175 | def get_strided_up_size(orig_size, stride):
176 | strided_size = get_strided_size(orig_size, stride)
177 | return strided_size[0]*stride, strided_size[1]*stride
178 |
179 |
180 | def compress_range(arr):
181 | uniques = np.unique(arr)
182 | maximum = np.max(uniques)
183 |
184 | d = np.zeros(maximum+1, np.int32)
185 | d[uniques] = np.arange(uniques.shape[0])
186 |
187 | out = d[arr]
188 | return out - np.min(out)
189 |
190 |
191 | def colorize_score(score_map, exclude_zero=False, normalize=True, by_hue=False):
192 | import matplotlib.colors
193 | if by_hue:
194 | aranged = np.arange(score_map.shape[0]) / (score_map.shape[0])
195 | hsv_color = np.stack((aranged, np.ones_like(aranged), np.ones_like(aranged)), axis=-1)
196 | rgb_color = matplotlib.colors.hsv_to_rgb(hsv_color)
197 |
198 | test = rgb_color[np.argmax(score_map, axis=0)]
199 | test = np.expand_dims(np.max(score_map, axis=0), axis=-1) * test
200 |
201 | if normalize:
202 | return test / (np.max(test) + 1e-5)
203 | else:
204 | return test
205 |
206 | else:
207 | VOC_color = np.array([(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
208 | (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
209 | (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0),
210 | (0, 192, 0), (128, 192, 0), (0, 64, 128), (255, 255, 255)], np.float32)
211 |
212 | if exclude_zero:
213 | VOC_color = VOC_color[1:]
214 |
215 | test = VOC_color[np.argmax(score_map, axis=0)%22]
216 | test = np.expand_dims(np.max(score_map, axis=0), axis=-1) * test
217 | if normalize:
218 | test /= np.max(test) + 1e-5
219 |
220 | return test
221 |
222 |
223 | def colorize_displacement(disp):
224 |
225 | import matplotlib.colors
226 | import math
227 |
228 | a = (np.arctan2(-disp[0], -disp[1]) / math.pi + 1) / 2
229 |
230 | r = np.sqrt(disp[0] ** 2 + disp[1] ** 2)
231 | s = r / np.max(r)
232 | hsv_color = np.stack((a, s, np.ones_like(a)), axis=-1)
233 | rgb_color = matplotlib.colors.hsv_to_rgb(hsv_color)
234 |
235 | return rgb_color
236 |
237 |
238 | def colorize_label(label_map, normalize=True, by_hue=True, exclude_zero=False, outline=False):
239 |
240 | label_map = label_map.astype(np.uint8)
241 |
242 | if by_hue:
243 | import matplotlib.colors
244 | sz = np.max(label_map)
245 | aranged = np.arange(sz) / sz
246 | hsv_color = np.stack((aranged, np.ones_like(aranged), np.ones_like(aranged)), axis=-1)
247 | rgb_color = matplotlib.colors.hsv_to_rgb(hsv_color)
248 | rgb_color = np.concatenate([np.zeros((1, 3)), rgb_color], axis=0)
249 |
250 | test = rgb_color[label_map]
251 | else:
252 | VOC_color = np.array([(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
253 | (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0),
254 | (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0),
255 | (0, 192, 0), (128, 192, 0), (0, 64, 128), (255, 255, 255)], np.float32)
256 |
257 | if exclude_zero:
258 | VOC_color = VOC_color[1:]
259 | test = VOC_color[label_map]
260 | if normalize:
261 | test /= np.max(test)
262 |
263 | if outline:
264 | edge = np.greater(np.sum(np.abs(test[:-1, :-1] - test[1:, :-1]), axis=-1) + np.sum(np.abs(test[:-1, :-1] - test[:-1, 1:]), axis=-1), 0)
265 | edge1 = np.pad(edge, ((0, 1), (0, 1)), mode='constant', constant_values=0)
266 | edge2 = np.pad(edge, ((1, 0), (1, 0)), mode='constant', constant_values=0)
267 | edge = np.repeat(np.expand_dims(np.maximum(edge1, edge2), -1), 3, axis=-1)
268 |
269 | test = np.maximum(test, edge)
270 | return test
271 |
--------------------------------------------------------------------------------
/pseudo_labels/misc/pyutils.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import time
4 | import sys
5 |
6 | class Logger(object):
7 | def __init__(self, outfile):
8 | self.terminal = sys.stdout
9 | self.log = open(outfile, "w")
10 | sys.stdout = self
11 |
12 | def write(self, message):
13 | self.terminal.write(message)
14 | self.log.write(message)
15 |
16 | def flush(self):
17 | self.terminal.flush()
18 |
19 |
20 | class AverageMeter:
21 | def __init__(self, *keys):
22 | self.__data = dict()
23 | for k in keys:
24 | self.__data[k] = [0.0, 0]
25 |
26 | def add(self, dict):
27 | for k, v in dict.items():
28 | if k not in self.__data:
29 | self.__data[k] = [0.0, 0]
30 | self.__data[k][0] += v
31 | self.__data[k][1] += 1
32 |
33 | def get(self, *keys):
34 | if len(keys) == 1:
35 | return self.__data[keys[0]][0] / self.__data[keys[0]][1]
36 | else:
37 | v_list = [self.__data[k][0] / self.__data[k][1] for k in keys]
38 | return tuple(v_list)
39 |
40 | def pop(self, key=None):
41 | if key is None:
42 | for k in self.__data.keys():
43 | self.__data[k] = [0.0, 0]
44 | else:
45 | v = self.get(key)
46 | self.__data[key] = [0.0, 0]
47 | return v
48 |
49 |
50 | class Timer:
51 | def __init__(self, starting_msg = None):
52 | self.start = time.time()
53 | self.stage_start = self.start
54 |
55 | if starting_msg is not None:
56 | print(starting_msg, time.ctime(time.time()))
57 |
58 | def __enter__(self):
59 | return self
60 |
61 | def __exit__(self, exc_type, exc_val, exc_tb):
62 | return
63 |
64 | def update_progress(self, progress):
65 | self.elapsed = time.time() - self.start
66 | self.est_total = self.elapsed / progress
67 | self.est_remaining = self.est_total - self.elapsed
68 | self.est_finish = int(self.start + self.est_total)
69 |
70 |
71 | def str_estimated_complete(self):
72 | return str(time.ctime(self.est_finish))
73 |
74 | def get_stage_elapsed(self):
75 | return time.time() - self.stage_start
76 |
77 | def reset_stage(self):
78 | self.stage_start = time.time()
79 |
80 | def lapse(self):
81 | out = time.time() - self.stage_start
82 | self.stage_start = time.time()
83 | return out
84 |
85 |
86 | def to_one_hot(sparse_integers, maximum_val=None, dtype=bool):
87 |
88 | if maximum_val is None:
89 | maximum_val = np.max(sparse_integers) + 1
90 |
91 | src_shape = sparse_integers.shape
92 |
93 | flat_src = np.reshape(sparse_integers, [-1])
94 | src_size = flat_src.shape[0]
95 |
96 | one_hot = np.zeros((maximum_val, src_size), dtype)
97 | one_hot[flat_src, np.arange(src_size)] = 1
98 |
99 | one_hot = np.reshape(one_hot, [maximum_val] + list(src_shape))
100 |
101 | return one_hot
102 |
--------------------------------------------------------------------------------
/pseudo_labels/misc/torchutils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | from torch.utils.data import Subset
5 | import numpy as np
6 | import math
7 |
8 |
9 | class PolyOptimizer(torch.optim.SGD):
10 |
11 | def __init__(self, params, lr, weight_decay, max_step, momentum=0.9):
12 | super().__init__(params, lr, weight_decay)
13 |
14 | self.global_step = 0
15 | self.max_step = max_step
16 | self.momentum = momentum
17 |
18 | self.__initial_lr = [group['lr'] for group in self.param_groups]
19 |
20 |
21 | def step(self, closure=None):
22 |
23 | if self.global_step < self.max_step:
24 | lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
25 |
26 | for i in range(len(self.param_groups)):
27 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
28 |
29 | super().step(closure)
30 |
31 | self.global_step += 1
32 |
33 | class SGDROptimizer(torch.optim.SGD):
34 |
35 | def __init__(self, params, steps_per_epoch, lr=0, weight_decay=0, epoch_start=1, restart_mult=2):
36 | super().__init__(params, lr, weight_decay)
37 |
38 | self.global_step = 0
39 | self.local_step = 0
40 | self.total_restart = 0
41 |
42 | self.max_step = steps_per_epoch * epoch_start
43 | self.restart_mult = restart_mult
44 |
45 | self.__initial_lr = [group['lr'] for group in self.param_groups]
46 |
47 |
48 | def step(self, closure=None):
49 |
50 | if self.local_step >= self.max_step:
51 | self.local_step = 0
52 | self.max_step *= self.restart_mult
53 | self.total_restart += 1
54 |
55 | lr_mult = (1 + math.cos(math.pi * self.local_step / self.max_step))/2 / (self.total_restart + 1)
56 |
57 | for i in range(len(self.param_groups)):
58 | self.param_groups[i]['lr'] = self.__initial_lr[i] * lr_mult
59 |
60 | super().step(closure)
61 |
62 | self.local_step += 1
63 | self.global_step += 1
64 |
65 |
66 | def split_dataset(dataset, n_splits):
67 |
68 | return [Subset(dataset, np.arange(i, len(dataset), n_splits)) for i in range(n_splits)]
69 |
70 |
71 | def gap2d(x, keepdims=False):
72 | out = torch.mean(x.view(x.size(0), x.size(1), -1), -1)
73 | if keepdims:
74 | out = out.view(out.size(0), out.size(1), 1, 1)
75 |
76 | return out
77 |
--------------------------------------------------------------------------------
/pseudo_labels/net/resnet50.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch.utils.model_zoo as model_zoo
4 |
5 |
6 | model_urls = {
7 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
8 | }
9 |
10 |
11 | class FixedBatchNorm(nn.BatchNorm2d):
12 | def forward(self, input):
13 | return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias,
14 | training=False, eps=self.eps)
15 |
16 |
17 | class Bottleneck(nn.Module):
18 | expansion = 4
19 |
20 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
21 | super(Bottleneck, self).__init__()
22 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
23 | self.bn1 = FixedBatchNorm(planes)
24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
25 | padding=dilation, bias=False, dilation=dilation)
26 | self.bn2 = FixedBatchNorm(planes)
27 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
28 | self.bn3 = FixedBatchNorm(planes * 4)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.downsample = downsample
31 | self.stride = stride
32 | self.dilation = dilation
33 |
34 | def forward(self, x):
35 | residual = x
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv3(out)
46 | out = self.bn3(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class ResNet(nn.Module):
58 |
59 | def __init__(self, block, layers, strides=(2, 2, 2, 2), dilations=(1, 1, 1, 1)):
60 | self.inplanes = 64
61 | super(ResNet, self).__init__()
62 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=strides[0], padding=3,
63 | bias=False)
64 | self.bn1 = FixedBatchNorm(64)
65 | self.relu = nn.ReLU(inplace=True)
66 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1, dilation=dilations[0])
68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
70 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3])
71 | self.inplanes = 1024
72 |
73 | #self.avgpool = nn.AvgPool2d(7, stride=1)
74 | #self.fc = nn.Linear(512 * block.expansion, 1000)
75 |
76 |
77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
78 | downsample = None
79 | if stride != 1 or self.inplanes != planes * block.expansion:
80 | downsample = nn.Sequential(
81 | nn.Conv2d(self.inplanes, planes * block.expansion,
82 | kernel_size=1, stride=stride, bias=False),
83 | FixedBatchNorm(planes * block.expansion),
84 | )
85 |
86 | layers = [block(self.inplanes, planes, stride, downsample, dilation=1)]
87 | self.inplanes = planes * block.expansion
88 | for i in range(1, blocks):
89 | layers.append(block(self.inplanes, planes, dilation=dilation))
90 |
91 | return nn.Sequential(*layers)
92 |
93 | def forward(self, x):
94 | x = self.conv1(x)
95 | x = self.bn1(x)
96 | x = self.relu(x)
97 | x = self.maxpool(x)
98 |
99 | x = self.layer1(x)
100 | x = self.layer2(x)
101 | x = self.layer3(x)
102 | x = self.layer4(x)
103 |
104 | x = self.avgpool(x)
105 | x = x.view(x.size(0), -1)
106 | x = self.fc(x)
107 |
108 | return x
109 |
110 |
111 | def resnet50(pretrained=True, **kwargs):
112 |
113 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
114 | if pretrained:
115 | state_dict = model_zoo.load_url(model_urls['resnet50'])
116 | state_dict.pop('fc.weight')
117 | state_dict.pop('fc.bias')
118 | model.load_state_dict(state_dict)
119 | return model
--------------------------------------------------------------------------------
/pseudo_labels/net/resnet50_cam.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from misc import torchutils
4 | from net import resnet50
5 |
6 |
7 | class Net(nn.Module):
8 |
9 | def __init__(self):
10 | super(Net, self).__init__()
11 |
12 | self.resnet50 = resnet50.resnet50(pretrained=True, strides=(2, 2, 2, 1))
13 |
14 | self.stage1 = nn.Sequential(self.resnet50.conv1, self.resnet50.bn1, self.resnet50.relu, self.resnet50.maxpool,
15 | self.resnet50.layer1)
16 | self.stage2 = nn.Sequential(self.resnet50.layer2)
17 | self.stage3 = nn.Sequential(self.resnet50.layer3)
18 | self.stage4 = nn.Sequential(self.resnet50.layer4)
19 |
20 | self.classifier = nn.Conv2d(2048, 20, 1, bias=False)
21 |
22 | self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4])
23 | self.newly_added = nn.ModuleList([self.classifier])
24 |
25 | def forward(self, x):
26 | x = self.stage1(x)
27 | x = self.stage2(x).detach()
28 | x = self.stage3(x)
29 | x = self.stage4(x)
30 | x = torchutils.gap2d(x, keepdims=True)
31 | x = self.classifier(x)
32 | x = x.view(-1, 20)
33 | return x
34 |
35 | def train(self, mode=True):
36 | for p in self.resnet50.conv1.parameters():
37 | p.requires_grad = False
38 | for p in self.resnet50.bn1.parameters():
39 | p.requires_grad = False
40 |
41 | def trainable_parameters(self):
42 | return (list(self.backbone.parameters()), list(self.newly_added.parameters()))
43 |
44 |
45 | class CAM(Net):
46 | def __init__(self):
47 | super(CAM, self).__init__()
48 |
49 | def forward(self, x):
50 | x = self.stage1(x)
51 | x = self.stage2(x)
52 | x = self.stage3(x)
53 | x = self.stage4(x)
54 | x = F.conv2d(x, self.classifier.weight)
55 | x = F.relu(x)
56 | x = x[0] + x[1].flip(-1)
57 | return x
58 |
--------------------------------------------------------------------------------
/pseudo_labels/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from misc import pyutils
4 | import train_cam, make_cam, cam_to_pseudo_labels
5 |
6 | if __name__ == '__main__':
7 |
8 | parser = argparse.ArgumentParser()
9 |
10 | # Environment
11 | parser.add_argument("--num_workers", default=os.cpu_count()//2, type=int)
12 | parser.add_argument("--voc12_root", required=True, type=str,
13 | help="Path to VOC 2012 Devkit, must contain ./JPEGImages as subdirectory.")
14 |
15 | # Dataset
16 | parser.add_argument("--train_list", default="voc12/train_aug.txt", type=str)
17 | parser.add_argument("--val_list", default="voc12/val.txt", type=str)
18 | parser.add_argument("--infer_list", default="voc12/train.txt", type=str,
19 | help="voc12/train_aug.txt to train a fully supervised model, "
20 | "voc12/train.txt or voc12/val.txt to quickly check the quality of the labels.")
21 |
22 | # Class Activation Map
23 | parser.add_argument("--cam_network", default="net.resnet50_cam", type=str)
24 | parser.add_argument("--cam_crop_size", default=512, type=int)
25 | parser.add_argument("--cam_batch_size", default=16, type=int)
26 | parser.add_argument("--cam_num_epoches", default=5, type=int)
27 | parser.add_argument("--cam_learning_rate", default=0.1, type=float)
28 | parser.add_argument("--cam_weight_decay", default=1e-4, type=float)
29 | parser.add_argument("--cam_eval_thres", default=0.15, type=float)
30 | parser.add_argument("--cam_scales", default=(1.0, 0.5, 1.5, 2.0), help="Multi-scale inferences")
31 | parser.add_argument("--conf_fg_thres", default=0.30, type=float)
32 | parser.add_argument("--conf_bg_thres", default=0.05, type=float)
33 |
34 | # Output Path
35 | parser.add_argument("--cam_weights_name", default="saved/res50_cam.pth", type=str)
36 | parser.add_argument("--cam_out_dir", default="result/cam", type=str)
37 | parser.add_argument("--pseudo_labels_out_dir", default="result/pseudo_labels", type=str)
38 |
39 | args = parser.parse_args()
40 | os.makedirs("saved", exist_ok=True)
41 | os.makedirs(args.cam_out_dir, exist_ok=True)
42 | os.makedirs(args.pseudo_labels_out_dir, exist_ok=True)
43 |
44 | print(vars(args))
45 |
46 | # Train resnet on pascal voc for classification
47 | timer = pyutils.Timer('step.train_cam:')
48 | train_cam.run(args)
49 | # Generate class activation maps from pretrained resnet
50 | timer = pyutils.Timer('step.make_cam:')
51 | make_cam.run(args)
52 | # Generate pseudo labels from CAMs
53 |     timer = pyutils.Timer('step.cam_to_pseudo_labels:')
54 | cam_to_pseudo_labels.run(args)
55 |
--------------------------------------------------------------------------------
/pseudo_labels/train_cam.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.backends import cudnn
4 | cudnn.enabled = True
5 | from torch.utils.data import DataLoader
6 | import torch.nn.functional as F
7 |
8 | import importlib
9 |
10 | import voc12.dataloader
11 | from misc import pyutils, torchutils
12 |
13 |
14 | def validate(model, data_loader):
15 | print('validating ... ', flush=True, end='')
16 | val_loss_meter = pyutils.AverageMeter('loss1', 'loss2')
17 | model.eval()
18 |
19 | with torch.no_grad():
20 | for pack in data_loader:
21 | img = pack['img']
22 | label = pack['label'].cuda(non_blocking=True)
23 | x = model(img)
24 | loss1 = F.multilabel_soft_margin_loss(x, label)
25 | val_loss_meter.add({'loss1': loss1.item()})
26 |
27 | model.train()
28 | print('loss: %.4f' % (val_loss_meter.pop('loss1')))
29 | return
30 |
31 |
32 | def run(args):
33 | model = getattr(importlib.import_module(args.cam_network), 'Net')()
34 | train_dataset = voc12.dataloader.VOC12ClassificationDataset(args.train_list, voc12_root=args.voc12_root,
35 | resize_long=(320, 640), hor_flip=True,
36 | crop_size=512, crop_method="random")
37 | train_data_loader = DataLoader(train_dataset, batch_size=args.cam_batch_size,
38 | shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True)
39 | max_step = (len(train_dataset) // args.cam_batch_size) * args.cam_num_epoches
40 | val_dataset = voc12.dataloader.VOC12ClassificationDataset(args.val_list, voc12_root=args.voc12_root,
41 | crop_size=512)
42 | val_data_loader = DataLoader(val_dataset, batch_size=args.cam_batch_size,
43 | shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=True)
44 |
45 | param_groups = model.trainable_parameters()
46 | optimizer = torchutils.PolyOptimizer([
47 | {'params': param_groups[0], 'lr': args.cam_learning_rate, 'weight_decay': args.cam_weight_decay},
48 | {'params': param_groups[1], 'lr': 10*args.cam_learning_rate, 'weight_decay': args.cam_weight_decay},
49 | ], lr=args.cam_learning_rate, weight_decay=args.cam_weight_decay, max_step=max_step)
50 |
51 | model = torch.nn.DataParallel(model).cuda()
52 | model.train()
53 | avg_meter = pyutils.AverageMeter()
54 | timer = pyutils.Timer()
55 | for ep in range(args.cam_num_epoches):
56 | print('Epoch %d/%d' % (ep+1, args.cam_num_epoches))
57 | for step, pack in enumerate(train_data_loader):
58 | img = pack['img']
59 | label = pack['label'].cuda(non_blocking=True)
60 | x = model(img)
61 | loss = F.multilabel_soft_margin_loss(x, label)
62 | avg_meter.add({'loss1': loss.item()})
63 | optimizer.zero_grad()
64 | loss.backward()
65 | optimizer.step()
66 |
67 | if (optimizer.global_step-1)%100 == 0:
68 | timer.update_progress(optimizer.global_step / max_step)
69 | print('step:%5d/%5d' % (optimizer.global_step - 1, max_step),
70 | 'loss:%.4f' % (avg_meter.pop('loss1')),
71 | 'imps:%.1f' % ((step + 1) * args.cam_batch_size / timer.get_stage_elapsed()),
72 | 'lr: %.4f' % (optimizer.param_groups[0]['lr']),
73 | 'etc:%s' % (timer.str_estimated_complete()), flush=True)
74 |
75 | else:
76 | validate(model, val_data_loader)
77 | timer.reset_stage()
78 |
79 | torch.save(model.module.state_dict(), args.cam_weights_name + '.pth')
80 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/pseudo_labels/voc12/cls_labels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yassouali/CCT/65d4e5bd4501ae3c564493d0ce18924a908639f5/pseudo_labels/voc12/cls_labels.npy
--------------------------------------------------------------------------------
/pseudo_labels/voc12/dataloader.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import Dataset
5 | import os.path
6 | import imageio
7 | from misc import imutils
8 |
9 | IMG_FOLDER_NAME = "JPEGImages"
10 | ANNOT_FOLDER_NAME = "Annotations"
11 | IGNORE = 255
12 |
13 | CAT_LIST = ['aeroplane', 'bicycle', 'bird', 'boat',
14 | 'bottle', 'bus', 'car', 'cat', 'chair',
15 | 'cow', 'diningtable', 'dog', 'horse',
16 | 'motorbike', 'person', 'pottedplant',
17 | 'sheep', 'sofa', 'train',
18 | 'tvmonitor']
19 |
20 | N_CAT = len(CAT_LIST)
21 |
22 | CAT_NAME_TO_NUM = dict(zip(CAT_LIST,range(len(CAT_LIST))))
23 |
24 | cls_labels_dict = np.load('voc12/cls_labels.npy', allow_pickle=True).item()
25 |
26 | def decode_int_filename(int_filename):
27 | s = str(int(int_filename))
28 | return s[:4] + '_' + s[4:]
29 |
30 | def load_image_label_from_xml(img_name, voc12_root):
31 | from xml.dom import minidom
32 |
33 | elem_list = minidom.parse(os.path.join(voc12_root, ANNOT_FOLDER_NAME, decode_int_filename(img_name) + '.xml')).getElementsByTagName('name')
34 |
35 | multi_cls_lab = np.zeros((N_CAT), np.float32)
36 |
37 | for elem in elem_list:
38 | cat_name = elem.firstChild.data
39 | if cat_name in CAT_LIST:
40 | cat_num = CAT_NAME_TO_NUM[cat_name]
41 | multi_cls_lab[cat_num] = 1.0
42 |
43 | return multi_cls_lab
44 |
45 | def load_image_label_list_from_xml(img_name_list, voc12_root):
46 |
47 | return [load_image_label_from_xml(img_name, voc12_root) for img_name in img_name_list]
48 |
49 | def load_image_label_list_from_npy(img_name_list):
50 |
51 | return np.array([cls_labels_dict[img_name] for img_name in img_name_list])
52 |
53 | def get_img_path(img_name, voc12_root):
54 | if not isinstance(img_name, str):
55 | img_name = decode_int_filename(img_name)
56 | return os.path.join(voc12_root, IMG_FOLDER_NAME, img_name + '.jpg')
57 |
58 | def load_img_name_list(dataset_path):
59 |
60 | img_name_list = np.loadtxt(dataset_path, dtype=np.int32)
61 |
62 | return img_name_list
63 |
64 |
65 | class TorchvisionNormalize():
66 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
67 | self.mean = mean
68 | self.std = std
69 |
70 | def __call__(self, img):
71 | imgarr = np.asarray(img)
72 | proc_img = np.empty_like(imgarr, np.float32)
73 |
74 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0]
75 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1]
76 | proc_img[..., 2] = (imgarr[..., 2] / 255. - self.mean[2]) / self.std[2]
77 |
78 | return proc_img
79 |
80 | class GetAffinityLabelFromIndices():
81 |
82 | def __init__(self, indices_from, indices_to):
83 |
84 | self.indices_from = indices_from
85 | self.indices_to = indices_to
86 |
87 | def __call__(self, segm_map):
88 |
89 | segm_map_flat = np.reshape(segm_map, -1)
90 |
91 | segm_label_from = np.expand_dims(segm_map_flat[self.indices_from], axis=0)
92 | segm_label_to = segm_map_flat[self.indices_to]
93 |
94 | valid_label = np.logical_and(np.less(segm_label_from, 21), np.less(segm_label_to, 21))
95 |
96 | equal_label = np.equal(segm_label_from, segm_label_to)
97 |
98 | pos_affinity_label = np.logical_and(equal_label, valid_label)
99 |
100 | bg_pos_affinity_label = np.logical_and(pos_affinity_label, np.equal(segm_label_from, 0)).astype(np.float32)
101 | fg_pos_affinity_label = np.logical_and(pos_affinity_label, np.greater(segm_label_from, 0)).astype(np.float32)
102 |
103 | neg_affinity_label = np.logical_and(np.logical_not(equal_label), valid_label).astype(np.float32)
104 |
105 | return torch.from_numpy(bg_pos_affinity_label), torch.from_numpy(fg_pos_affinity_label), \
106 | torch.from_numpy(neg_affinity_label)
107 |
108 |
109 | class VOC12ImageDataset(Dataset):
110 | def __init__(self, img_name_list_path, voc12_root,
111 | resize_long=None, rescale=None, img_normal=TorchvisionNormalize(), hor_flip=False,
112 | crop_size=None, crop_method=None, to_torch=True):
113 |
114 | self.img_name_list = load_img_name_list(img_name_list_path)
115 | self.voc12_root = voc12_root
116 |
117 | self.resize_long = resize_long
118 | self.rescale = rescale
119 | self.crop_size = crop_size
120 | self.img_normal = img_normal
121 | self.hor_flip = hor_flip
122 | self.crop_method = crop_method
123 | self.to_torch = to_torch
124 |
125 | def __len__(self):
126 | return len(self.img_name_list)
127 |
128 | def __getitem__(self, idx):
129 | name = self.img_name_list[idx]
130 | name_str = decode_int_filename(name)
131 |
132 | img = np.asarray(imageio.imread(get_img_path(name_str, self.voc12_root)))
133 |
134 | if self.resize_long:
135 | img = imutils.random_resize_long(img, self.resize_long[0], self.resize_long[1])
136 |
137 | if self.rescale:
138 | img = imutils.random_scale(img, scale_range=self.rescale, order=3)
139 |
140 | if self.img_normal:
141 | img = self.img_normal(img)
142 |
143 | if self.hor_flip:
144 | img = imutils.random_lr_flip(img)
145 |
146 | if self.crop_size:
147 | if self.crop_method == "random":
148 | img = imutils.random_crop(img, self.crop_size, 0)
149 | else:
150 | img = imutils.top_left_crop(img, self.crop_size, 0)
151 |
152 | if self.to_torch:
153 | img = imutils.HWC_to_CHW(img)
154 |
155 | return {'name': name_str, 'img': img}
156 |
157 | class VOC12ClassificationDataset(VOC12ImageDataset):
158 |
159 | def __init__(self, img_name_list_path, voc12_root,
160 | resize_long=None, rescale=None, img_normal=TorchvisionNormalize(), hor_flip=False,
161 | crop_size=None, crop_method=None):
162 | super().__init__(img_name_list_path, voc12_root,
163 | resize_long, rescale, img_normal, hor_flip,
164 | crop_size, crop_method)
165 | self.label_list = load_image_label_list_from_npy(self.img_name_list)
166 |
167 | def __getitem__(self, idx):
168 | out = super().__getitem__(idx)
169 |
170 | out['label'] = torch.from_numpy(self.label_list[idx])
171 |
172 | return out
173 |
174 | class VOC12ClassificationDatasetMSF(VOC12ClassificationDataset):
175 |
176 | def __init__(self, img_name_list_path, voc12_root,
177 | img_normal=TorchvisionNormalize(),
178 | scales=(1.0,)):
179 | self.scales = scales
180 |
181 | super().__init__(img_name_list_path, voc12_root, img_normal=img_normal)
182 | 
183 |
184 | def __getitem__(self, idx):
185 | name = self.img_name_list[idx]
186 | name_str = decode_int_filename(name)
187 |
188 | img = imageio.imread(get_img_path(name_str, self.voc12_root))
189 |
190 | ms_img_list = []
191 | for s in self.scales:
192 | if s == 1:
193 | s_img = img
194 | else:
195 | s_img = imutils.pil_rescale(img, s, order=3)
196 | s_img = self.img_normal(s_img)
197 | s_img = imutils.HWC_to_CHW(s_img)
198 | ms_img_list.append(np.stack([s_img, np.flip(s_img, -1)], axis=0))
199 | if len(self.scales) == 1:
200 | ms_img_list = ms_img_list[0]
201 |
202 | out = {"name": name_str, "img": ms_img_list, "size": (img.shape[0], img.shape[1]),
203 | "label": torch.from_numpy(self.label_list[idx])}
204 | return out
205 |
206 | class VOC12SegmentationDataset(Dataset):
207 |
208 | def __init__(self, img_name_list_path, label_dir, crop_size, voc12_root,
209 | rescale=None, img_normal=TorchvisionNormalize(), hor_flip=False,
210 | crop_method = 'random'):
211 |
212 | self.img_name_list = load_img_name_list(img_name_list_path)
213 | self.voc12_root = voc12_root
214 |
215 | self.label_dir = label_dir
216 |
217 | self.rescale = rescale
218 | self.crop_size = crop_size
219 | self.img_normal = img_normal
220 | self.hor_flip = hor_flip
221 | self.crop_method = crop_method
222 |
223 | def __len__(self):
224 | return len(self.img_name_list)
225 |
226 | def __getitem__(self, idx):
227 | name = self.img_name_list[idx]
228 | name_str = decode_int_filename(name)
229 |
230 | img = imageio.imread(get_img_path(name_str, self.voc12_root))
231 | label = imageio.imread(os.path.join(self.label_dir, name_str + '.png'))
232 |
233 | img = np.asarray(img)
234 |
235 | if self.rescale:
236 | img, label = imutils.random_scale((img, label), scale_range=self.rescale, order=(3, 0))
237 |
238 | if self.img_normal:
239 | img = self.img_normal(img)
240 |
241 | if self.hor_flip:
242 | img, label = imutils.random_lr_flip((img, label))
243 |
244 | if self.crop_method == "random":
245 | img, label = imutils.random_crop((img, label), self.crop_size, (0, 255))
246 | else:
247 | img = imutils.top_left_crop(img, self.crop_size, 0)
248 | label = imutils.top_left_crop(label, self.crop_size, 255)
249 |
250 | img = imutils.HWC_to_CHW(img)
251 |
252 | return {'name': name, 'img': img, 'label': label}
253 |
254 | class VOC12AffinityDataset(VOC12SegmentationDataset):
255 | def __init__(self, img_name_list_path, label_dir, crop_size, voc12_root,
256 | indices_from, indices_to,
257 | rescale=None, img_normal=TorchvisionNormalize(), hor_flip=False, crop_method=None):
258 | super().__init__(img_name_list_path, label_dir, crop_size, voc12_root, rescale, img_normal, hor_flip, crop_method=crop_method)
259 |
260 | self.extract_aff_lab_func = GetAffinityLabelFromIndices(indices_from, indices_to)
261 |
262 | def __len__(self):
263 | return len(self.img_name_list)
264 |
265 | def __getitem__(self, idx):
266 | out = super().__getitem__(idx)
267 |
268 | reduced_label = imutils.pil_rescale(out['label'], 0.25, 0)
269 |
270 | out['aff_bg_pos_label'], out['aff_fg_pos_label'], out['aff_neg_label'] = self.extract_aff_lab_func(reduced_label)
271 |
272 | return out
273 |
274 |
--------------------------------------------------------------------------------
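
A minimal usage sketch (not part of the repository) for the multi-scale dataset defined above, assuming it is run from the pseudo_labels/ directory; the VOC2012 path, the list file and the scales are placeholders:

from torch.utils.data import DataLoader
import voc12.dataloader as vocdl

# Each item packs the image at every scale (plus its horizontal flip), its size and its label.
dataset = vocdl.VOC12ClassificationDatasetMSF('voc12/train_aug.txt',
                                              voc12_root='/path/to/VOC2012',
                                              scales=(1.0, 0.5, 1.5, 2.0))
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

pack = next(iter(loader))
print(pack['name'], len(pack['img']), pack['size'], pack['label'].shape)
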
/pseudo_labels/voc12/make_cls_labels.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import voc12.dataloader
3 | import numpy as np
4 |
5 | if __name__ == '__main__':
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--train_list", default='train_aug.txt', type=str)
9 | parser.add_argument("--val_list", default='val.txt', type=str)
10 | parser.add_argument("--out", default="cls_labels.npy", type=str)
11 | parser.add_argument("--voc12_root", default="../../../Dataset/VOC2012", type=str)
12 | args = parser.parse_args()
13 |
14 | train_name_list = voc12.dataloader.load_img_name_list(args.train_list)
15 | val_name_list = voc12.dataloader.load_img_name_list(args.val_list)
16 |
17 | train_val_name_list = np.concatenate([train_name_list, val_name_list], axis=0)
18 | label_list = voc12.dataloader.load_image_label_list_from_xml(train_val_name_list, args.voc12_root)
19 |
20 | total_label = np.zeros(20)
21 |
22 | d = dict()
23 | for img_name, label in zip(train_val_name_list, label_list):
24 | d[img_name] = label
25 | total_label += label
26 |
27 | print(total_label)
28 | np.save(args.out, d)
--------------------------------------------------------------------------------
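
A small sketch (not part of the repository) of how the saved file can be read back: np.save wraps the Python dict in a 0-d object array, so it has to be loaded with allow_pickle=True and unwrapped with .item():

import numpy as np

cls_labels = np.load('cls_labels.npy', allow_pickle=True).item()
img_id = next(iter(cls_labels))
print(img_id, cls_labels[img_id])   # image id -> 20-dim multi-hot class vector
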
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.1.0
2 | torchvision
3 | dominate
4 | matplotlib>=3.1.1
5 | opencv-python>=4.1.1.26
6 | tensorboard
7 | tqdm>=4.38.0
8 | numpy>=1.16.3
9 | cython
10 | imageio>=2.5.0
11 | scikit-image>=0.15.0
12 | git+https://github.com/lucasb-eyer/pydensecrf.git
13 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import torch
5 | import dataloaders
6 | import models
7 | import math
8 | from utils import Logger
9 | from trainer import Trainer
10 | import torch.nn.functional as F
11 | from utils.losses import abCE_loss, CE_loss, consistency_weight, FocalLoss, softmax_helper, get_alpha
12 |
13 |
14 | def get_instance(module, name, config, *args):
15 | # GET THE CORRESPONDING CLASS / FCT
16 | return getattr(module, config[name]['type'])(*args, **config[name]['args'])
17 |
18 | def main(config, resume):
19 | torch.manual_seed(42)
20 | train_logger = Logger()
21 |
22 | # DATA LOADERS
23 | config['train_supervised']['n_labeled_examples'] = config['n_labeled_examples']
24 | config['train_unsupervised']['n_labeled_examples'] = config['n_labeled_examples']
25 | config['train_unsupervised']['use_weak_lables'] = config['use_weak_lables']
26 | supervised_loader = dataloaders.VOC(config['train_supervised'])
27 | unsupervised_loader = dataloaders.VOC(config['train_unsupervised'])
28 | val_loader = dataloaders.VOC(config['val_loader'])
29 | iter_per_epoch = len(unsupervised_loader)
30 |
31 | # SUPERVISED LOSS
32 | if config['model']['sup_loss'] == 'CE':
33 | sup_loss = CE_loss
34 | elif config['model']['sup_loss'] == 'FL':
35 |         alpha = get_alpha(supervised_loader) # calculate class occurrences
36 | sup_loss = FocalLoss(apply_nonlin = softmax_helper, ignore_index = config['ignore_index'], alpha = alpha, gamma = 2, smooth = 1e-5)
37 | else:
38 | sup_loss = abCE_loss(iters_per_epoch=iter_per_epoch, epochs=config['trainer']['epochs'],
39 | num_classes=val_loader.dataset.num_classes)
40 |
41 | # MODEL
42 | rampup_ends = int(config['ramp_up'] * config['trainer']['epochs'])
43 | cons_w_unsup = consistency_weight(final_w=config['unsupervised_w'], iters_per_epoch=len(unsupervised_loader),
44 | rampup_ends=rampup_ends)
45 |
46 | model = models.CCT(num_classes=val_loader.dataset.num_classes, conf=config['model'],
47 | sup_loss=sup_loss, cons_w_unsup=cons_w_unsup,
48 | weakly_loss_w=config['weakly_loss_w'], use_weak_lables=config['use_weak_lables'],
49 | ignore_index=val_loader.dataset.ignore_index)
50 | print(f'\n{model}\n')
51 |
52 | # TRAINING
53 | trainer = Trainer(
54 | model=model,
55 | resume=resume,
56 | config=config,
57 | supervised_loader=supervised_loader,
58 | unsupervised_loader=unsupervised_loader,
59 | val_loader=val_loader,
60 | iter_per_epoch=iter_per_epoch,
61 | train_logger=train_logger)
62 |
63 | trainer.train()
64 |
65 | if __name__=='__main__':
66 | # PARSE THE ARGS
67 | parser = argparse.ArgumentParser(description='PyTorch Training')
68 | parser.add_argument('-c', '--config', default='configs/config.json',type=str,
69 | help='Path to the config file')
70 | parser.add_argument('-r', '--resume', default=None, type=str,
71 | help='Path to the .pth model checkpoint to resume training')
72 | parser.add_argument('-d', '--device', default=None, type=str,
73 | help='indices of GPUs to enable (default: all)')
74 | parser.add_argument('--local', action='store_true', default=False)
75 | args = parser.parse_args()
76 |
77 | config = json.load(open(args.config))
78 | torch.backends.cudnn.benchmark = True
79 | main(config, args.resume)
80 |
--------------------------------------------------------------------------------
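
A hedged illustration (not part of the repository) of the get_instance pattern used in train.py: the config entry names a class and its keyword arguments, and getattr resolves that class on the given module. The optimizer config below is a made-up example:

import torch

def get_instance(module, name, config, *args):
    # same one-liner as in train.py
    return getattr(module, config[name]['type'])(*args, **config[name]['args'])

config = {'optimizer': {'type': 'SGD', 'args': {'lr': 1e-2, 'momentum': 0.9}}}
params = [torch.nn.Parameter(torch.zeros(1))]

optimizer = get_instance(torch.optim, 'optimizer', config, params)
print(type(optimizer).__name__)   # SGD, i.e. torch.optim.SGD(params, lr=1e-2, momentum=0.9)
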
/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time, random, cv2, sys
3 | from math import ceil
4 | import numpy as np
5 | from itertools import cycle
6 | import torch.nn.functional as F
7 | from torchvision.utils import make_grid
8 | from torchvision import transforms
9 | from base import BaseTrainer
10 | from utils.helpers import colorize_mask
11 | from utils.metrics import eval_metrics, AverageMeter
12 | from tqdm import tqdm
13 | from PIL import Image
14 | from utils.helpers import DeNormalize
15 |
16 |
17 |
18 | class Trainer(BaseTrainer):
19 | def __init__(self, model, resume, config, supervised_loader, unsupervised_loader, iter_per_epoch,
20 | val_loader=None, train_logger=None):
21 | super(Trainer, self).__init__(model, resume, config, iter_per_epoch, train_logger)
22 |
23 | self.supervised_loader = supervised_loader
24 | self.unsupervised_loader = unsupervised_loader
25 | self.val_loader = val_loader
26 |
27 | self.ignore_index = self.val_loader.dataset.ignore_index
28 | self.wrt_mode, self.wrt_step = 'train_', 0
29 | self.log_step = config['trainer'].get('log_per_iter', int(np.sqrt(self.val_loader.batch_size)))
30 | if config['trainer']['log_per_iter']:
31 | self.log_step = int(self.log_step / self.val_loader.batch_size) + 1
32 |
33 | self.num_classes = self.val_loader.dataset.num_classes
34 | self.mode = self.model.module.mode
35 |
36 |         # TRANSFORMS FOR VISUALIZATION
37 | self.restore_transform = transforms.Compose([
38 | DeNormalize(self.val_loader.MEAN, self.val_loader.STD),
39 | transforms.ToPILImage()])
40 | self.viz_transform = transforms.Compose([
41 | transforms.Resize((400, 400)),
42 | transforms.ToTensor()])
43 |
44 | self.start_time = time.time()
45 |
46 |
47 |
48 | def _train_epoch(self, epoch):
49 | self.html_results.save()
50 |
51 | self.logger.info('\n')
52 | self.model.train()
53 |
54 | if self.mode == 'supervised':
55 | dataloader = iter(self.supervised_loader)
56 | tbar = tqdm(range(len(self.supervised_loader)), ncols=135)
57 | else:
58 | dataloader = iter(zip(cycle(self.supervised_loader), self.unsupervised_loader))
59 | tbar = tqdm(range(len(self.unsupervised_loader)), ncols=135)
60 |
61 | self._reset_metrics()
62 | for batch_idx in tbar:
63 | if self.mode == 'supervised':
64 | (input_l, target_l), (input_ul, target_ul) = next(dataloader), (None, None)
65 | else:
66 | (input_l, target_l), (input_ul, target_ul) = next(dataloader)
67 | input_ul, target_ul = input_ul.cuda(non_blocking=True), target_ul.cuda(non_blocking=True)
68 |
69 | input_l, target_l = input_l.cuda(non_blocking=True), target_l.cuda(non_blocking=True)
70 | self.optimizer.zero_grad()
71 |
72 | total_loss, cur_losses, outputs = self.model(x_l=input_l, target_l=target_l, x_ul=input_ul,
73 | curr_iter=batch_idx, target_ul=target_ul, epoch=epoch-1)
74 | total_loss = total_loss.mean()
75 | total_loss.backward()
76 | self.optimizer.step()
77 |
78 | self._update_losses(cur_losses)
79 | self._compute_metrics(outputs, target_l, target_ul, epoch-1)
80 | logs = self._log_values(cur_losses)
81 |
82 | if batch_idx % self.log_step == 0:
83 | self.wrt_step = (epoch - 1) * len(self.unsupervised_loader) + batch_idx
84 | self._write_scalars_tb(logs)
85 |
86 | if batch_idx % int(len(self.unsupervised_loader)*0.9) == 0:
87 | self._write_img_tb(input_l, target_l, input_ul, target_ul, outputs, epoch)
88 |
89 | del input_l, target_l, input_ul, target_ul
90 | del total_loss, cur_losses, outputs
91 |
92 | tbar.set_description('T ({}) | Ls {:.2f} Lu {:.2f} Lw {:.2f} PW {:.2f} m1 {:.2f} m2 {:.2f}|'.format(
93 | epoch, self.loss_sup.average, self.loss_unsup.average, self.loss_weakly.average,
94 | self.pair_wise.average, self.mIoU_l, self.mIoU_ul))
95 |
96 | self.lr_scheduler.step(epoch=epoch-1)
97 |
98 | return logs
99 |
100 |
101 |
102 | def _valid_epoch(self, epoch):
103 | if self.val_loader is None:
104 |             self.logger.warning('No data loader was passed for the validation step, no validation is performed!')
105 | return {}
106 | self.logger.info('\n###### EVALUATION ######')
107 |
108 | self.model.eval()
109 | self.wrt_mode = 'val'
110 | total_loss_val = AverageMeter()
111 | total_inter, total_union = 0, 0
112 | total_correct, total_label = 0, 0
113 |
114 | tbar = tqdm(self.val_loader, ncols=130)
115 | with torch.no_grad():
116 | val_visual = []
117 | for batch_idx, (data, target) in enumerate(tbar):
118 | target, data = target.cuda(non_blocking=True), data.cuda(non_blocking=True)
119 |
120 | H, W = target.size(1), target.size(2)
121 | up_sizes = (ceil(H / 8) * 8, ceil(W / 8) * 8)
122 | pad_h, pad_w = up_sizes[0] - data.size(2), up_sizes[1] - data.size(3)
123 | data = F.pad(data, pad=(0, pad_w, 0, pad_h), mode='reflect')
124 | output = self.model(data)
125 | output = output[:, :, :H, :W]
126 |
127 | # LOSS
128 | loss = F.cross_entropy(output, target, ignore_index=self.ignore_index)
129 | total_loss_val.update(loss.item())
130 |
131 | correct, labeled, inter, union = eval_metrics(output, target, self.num_classes, self.ignore_index)
132 | total_inter, total_union = total_inter+inter, total_union+union
133 | total_correct, total_label = total_correct+correct, total_label+labeled
134 |
135 | # LIST OF IMAGE TO VIZ (15 images)
136 | if len(val_visual) < 15:
137 | if isinstance(data, list): data = data[0]
138 | target_np = target.data.cpu().numpy()
139 | output_np = output.data.max(1)[1].cpu().numpy()
140 | val_visual.append([data[0].data.cpu(), target_np[0], output_np[0]])
141 |
142 | # PRINT INFO
143 | pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
144 | IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
145 | mIoU = IoU.mean()
146 | seg_metrics = {"Pixel_Accuracy": np.round(pixAcc, 3), "Mean_IoU": np.round(mIoU, 3),
147 | "Class_IoU": dict(zip(range(self.num_classes), np.round(IoU, 3)))}
148 |
149 | tbar.set_description('EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'.format( epoch,
150 | total_loss_val.average, pixAcc, mIoU))
151 |
152 | self._add_img_tb(val_visual, 'val')
153 |
154 | # METRICS TO TENSORBOARD
155 | self.wrt_step = (epoch) * len(self.val_loader)
156 | self.writer.add_scalar(f'{self.wrt_mode}/loss', total_loss_val.average, self.wrt_step)
157 | for k, v in list(seg_metrics.items())[:-1]:
158 | self.writer.add_scalar(f'{self.wrt_mode}/{k}', v, self.wrt_step)
159 |
160 | log = {
161 | 'val_loss': total_loss_val.average,
162 | **seg_metrics
163 | }
164 | self.html_results.add_results(epoch=epoch, seg_resuts=log)
165 | self.html_results.save()
166 |
167 | if (time.time() - self.start_time) / 3600 > 22:
168 | self._save_checkpoint(epoch, save_best=self.improved)
169 | return log
170 |
171 |
172 |
173 | def _reset_metrics(self):
174 | self.loss_sup = AverageMeter()
175 | self.loss_unsup = AverageMeter()
176 | self.loss_weakly = AverageMeter()
177 | self.pair_wise = AverageMeter()
178 | self.total_inter_l, self.total_union_l = 0, 0
179 | self.total_correct_l, self.total_label_l = 0, 0
180 | self.total_inter_ul, self.total_union_ul = 0, 0
181 | self.total_correct_ul, self.total_label_ul = 0, 0
182 | self.mIoU_l, self.mIoU_ul = 0, 0
183 | self.pixel_acc_l, self.pixel_acc_ul = 0, 0
184 | self.class_iou_l, self.class_iou_ul = {}, {}
185 |
186 |
187 |
188 | def _update_losses(self, cur_losses):
189 | if "loss_sup" in cur_losses.keys():
190 | self.loss_sup.update(cur_losses['loss_sup'].mean().item())
191 | if "loss_unsup" in cur_losses.keys():
192 | self.loss_unsup.update(cur_losses['loss_unsup'].mean().item())
193 | if "loss_weakly" in cur_losses.keys():
194 | self.loss_weakly.update(cur_losses['loss_weakly'].mean().item())
195 | if "pair_wise" in cur_losses.keys():
196 | self.pair_wise.update(cur_losses['pair_wise'].mean().item())
197 |
198 |
199 |
200 | def _compute_metrics(self, outputs, target_l, target_ul, epoch):
201 | seg_metrics_l = eval_metrics(outputs['sup_pred'], target_l, self.num_classes, self.ignore_index)
202 | self._update_seg_metrics(*seg_metrics_l, True)
203 | seg_metrics_l = self._get_seg_metrics(True)
204 | self.pixel_acc_l, self.mIoU_l, self.class_iou_l = seg_metrics_l.values()
205 |
206 | if self.mode == 'semi':
207 | seg_metrics_ul = eval_metrics(outputs['unsup_pred'], target_ul, self.num_classes, self.ignore_index)
208 | self._update_seg_metrics(*seg_metrics_ul, False)
209 | seg_metrics_ul = self._get_seg_metrics(False)
210 | self.pixel_acc_ul, self.mIoU_ul, self.class_iou_ul = seg_metrics_ul.values()
211 |
212 |
213 |
214 | def _update_seg_metrics(self, correct, labeled, inter, union, supervised=True):
215 | if supervised:
216 | self.total_correct_l += correct
217 | self.total_label_l += labeled
218 | self.total_inter_l += inter
219 | self.total_union_l += union
220 | else:
221 | self.total_correct_ul += correct
222 | self.total_label_ul += labeled
223 | self.total_inter_ul += inter
224 | self.total_union_ul += union
225 |
226 |
227 |
228 | def _get_seg_metrics(self, supervised=True):
229 | if supervised:
230 | pixAcc = 1.0 * self.total_correct_l / (np.spacing(1) + self.total_label_l)
231 | IoU = 1.0 * self.total_inter_l / (np.spacing(1) + self.total_union_l)
232 | else:
233 | pixAcc = 1.0 * self.total_correct_ul / (np.spacing(1) + self.total_label_ul)
234 | IoU = 1.0 * self.total_inter_ul / (np.spacing(1) + self.total_union_ul)
235 | mIoU = IoU.mean()
236 | return {
237 | "Pixel_Accuracy": np.round(pixAcc, 3),
238 | "Mean_IoU": np.round(mIoU, 3),
239 | "Class_IoU": dict(zip(range(self.num_classes), np.round(IoU, 3)))
240 | }
241 |
242 |
243 |
244 | def _log_values(self, cur_losses):
245 | logs = {}
246 | if "loss_sup" in cur_losses.keys():
247 | logs['loss_sup'] = self.loss_sup.average
248 | if "loss_unsup" in cur_losses.keys():
249 | logs['loss_unsup'] = self.loss_unsup.average
250 | if "loss_weakly" in cur_losses.keys():
251 | logs['loss_weakly'] = self.loss_weakly.average
252 | if "pair_wise" in cur_losses.keys():
253 | logs['pair_wise'] = self.pair_wise.average
254 |
255 | logs['mIoU_labeled'] = self.mIoU_l
256 | logs['pixel_acc_labeled'] = self.pixel_acc_l
257 | if self.mode == 'semi':
258 | logs['mIoU_unlabeled'] = self.mIoU_ul
259 | logs['pixel_acc_unlabeled'] = self.pixel_acc_ul
260 | return logs
261 |
262 |
263 | def _write_scalars_tb(self, logs):
264 | for k, v in logs.items():
265 | if 'class_iou' not in k: self.writer.add_scalar(f'train/{k}', v, self.wrt_step)
266 | for i, opt_group in enumerate(self.optimizer.param_groups):
267 | self.writer.add_scalar(f'train/Learning_rate_{i}', opt_group['lr'], self.wrt_step)
268 | current_rampup = self.model.module.unsup_loss_w.current_rampup
269 | self.writer.add_scalar('train/Unsupervised_rampup', current_rampup, self.wrt_step)
270 |
271 |
272 |
273 | def _add_img_tb(self, val_visual, wrt_mode):
274 | val_img = []
275 | palette = self.val_loader.dataset.palette
276 | for imgs in val_visual:
277 | imgs = [self.restore_transform(i) if (isinstance(i, torch.Tensor) and len(i.shape) == 3)
278 | else colorize_mask(i, palette) for i in imgs]
279 | imgs = [i.convert('RGB') for i in imgs]
280 | imgs = [self.viz_transform(i) for i in imgs]
281 | val_img.extend(imgs)
282 | val_img = torch.stack(val_img, 0)
283 | val_img = make_grid(val_img.cpu(), nrow=val_img.size(0)//len(val_visual), padding=5)
284 | self.writer.add_image(f'{wrt_mode}/inputs_targets_predictions', val_img, self.wrt_step)
285 |
286 |
287 |
288 | def _write_img_tb(self, input_l, target_l, input_ul, target_ul, outputs, epoch):
289 | outputs_l_np = outputs['sup_pred'].data.max(1)[1].cpu().numpy()
290 | targets_l_np = target_l.data.cpu().numpy()
291 | imgs = [[i.data.cpu(), j, k] for i, j, k in zip(input_l, outputs_l_np, targets_l_np)]
292 | self._add_img_tb(imgs, 'supervised')
293 |
294 | if self.mode == 'semi':
295 | outputs_ul_np = outputs['unsup_pred'].data.max(1)[1].cpu().numpy()
296 | targets_ul_np = target_ul.data.cpu().numpy()
297 | imgs = [[i.data.cpu(), j, k] for i, j, k in zip(input_ul, outputs_ul_np, targets_ul_np)]
298 | self._add_img_tb(imgs, 'unsupervised')
299 |
300 |
--------------------------------------------------------------------------------
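
A short, self-contained sketch (not part of the repository) of the batch pairing used in _train_epoch above: in semi-supervised mode the smaller supervised loader is cycled so that every unsupervised batch is paired with a labeled one, and one epoch is one pass over the unsupervised loader:

from itertools import cycle

supervised = [('x_l0', 'y_l0'), ('x_l1', 'y_l1')]                        # stand-ins for labeled batches
unsupervised = [('x_ul0', None), ('x_ul1', None), ('x_ul2', None), ('x_ul3', None)]

for (x_l, y_l), (x_ul, y_ul) in zip(cycle(supervised), unsupervised):
    print(x_l, '<->', x_ul)
# x_l0 <-> x_ul0, x_l1 <-> x_ul1, x_l0 <-> x_ul2, x_l1 <-> x_ul3
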
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import Logger
--------------------------------------------------------------------------------
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import datetime
4 | from torchvision.utils import make_grid
5 | from torchvision import transforms
6 | from torch.utils.tensorboard import SummaryWriter
7 | import torch
8 | import torch.nn as nn
9 | import numpy as np
10 | import math
11 | import PIL
12 | import cv2
13 | from matplotlib import colors
14 | from matplotlib import pyplot as plt
15 | import matplotlib.cm as cmx
16 | from utils import pallete
17 |
18 |
19 | class DeNormalize(object):
20 | def __init__(self, mean, std):
21 | self.mean = mean
22 | self.std = std
23 |
24 | def __call__(self, tensor):
25 | for t, m, s in zip(tensor, self.mean, self.std):
26 | t.mul_(s).add_(m)
27 | return tensor
28 |
29 |
30 | def dir_exists(path):
31 | if not os.path.exists(path):
32 | os.makedirs(path)
33 |
34 |
35 | def initialize_weights(*models):
36 | for model in models:
37 | for m in model.modules():
38 | if isinstance(m, nn.Conv2d):
39 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
40 | if m.bias is not None:
41 | nn.init.constant_(m.bias, 0)
42 | elif isinstance(m, nn.BatchNorm2d):
43 | nn.init.constant_(m.weight, 1)
44 | nn.init.constant_(m.bias, 0)
45 | elif isinstance(m, nn.Linear):
46 | nn.init.normal_(m.weight, 0, 0.01)
47 | nn.init.constant_(m.bias, 0)
48 |
49 |
50 | def colorize_mask(mask, palette):
51 | zero_pad = 256 * 3 - len(palette)
52 | for i in range(zero_pad):
53 | palette.append(0)
54 | palette[-3:] = [255, 255, 255]
55 | new_mask = PIL.Image.fromarray(mask.astype(np.uint8)).convert('P')
56 | new_mask.putpalette(palette)
57 | return new_mask
58 |
59 |
60 | def set_trainable_attr(m,b):
61 | m.trainable = b
62 | for p in m.parameters(): p.requires_grad = b
63 |
64 | def apply_leaf(m, f):
65 | c = m if isinstance(m, (list, tuple)) else list(m.children())
66 | if isinstance(m, nn.Module):
67 | f(m)
68 | if len(c)>0:
69 | for l in c:
70 | apply_leaf(l,f)
71 |
72 | def set_trainable(l, b):
73 | apply_leaf(l, lambda m: set_trainable_attr(m,b))
74 |
75 |
--------------------------------------------------------------------------------
/utils/htmlwriter.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os, json, datetime
4 |
5 | class HTML:
6 | def __init__(self, web_dir, exp_name, config, title='seg results', save_name='index', reflesh=0, resume=None):
7 | self.title = title
8 | self.web_dir = web_dir
9 | self.save_name = save_name+'.html'
10 |
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 |
14 | html_file = os.path.join(self.web_dir, self.save_name)
15 |
16 | if resume is not None and os.path.isfile(html_file):
17 | self.old_content = open(html_file).read()
18 | else :
19 | self.old_content = None
20 |
21 | self.doc = dominate.document(title=title)
22 | if reflesh > 0:
23 | with self.doc.head:
24 |                 meta(http_equiv="refresh", content=str(reflesh))
25 |
26 | date_time = datetime.datetime.now().strftime('%m-%d_%H-%M')
27 | header = f'Experiment name: {exp_name}, Date: {date_time}'
28 | self.add_header(header)
29 | self.add_header('Configs')
30 | self.add_config(config)
31 | with self.doc:
32 | hr()
33 | hr()
34 | self.add_table()
35 |
36 | def add_header(self, str):
37 | with self.doc:
38 | h3(str)
39 |
40 | def add_table(self, border=1):
41 | self.t = table(border=border, style="table-layout: fixed;")
42 | self.doc.add(self.t)
43 |
44 | def add_config(self, config):
45 | t = table(border=1, style="table-layout: fixed;")
46 | self.doc.add(t)
47 | conf_model = config['model']
48 | with t:
49 | with tr():
50 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
51 | td(f'Epochs : {config["trainer"]["epochs"]}')
52 | td(f'Lr scheduler : {config["lr_scheduler"]}')
53 | td(f'Lr : {config["optimizer"]["args"]["lr"]}')
54 | if "datasets" in list(config.keys()): td(f'Datasets : {config["datasets"]}')
55 | td(f"""Decoders : Vat {conf_model["vat"]} Dropout {conf_model["drop"]} Cutout {conf_model["cutout"]}
56 | FeatureNoise {conf_model["feature_noise"]} FeatureDrop {conf_model["feature_drop"]}
57 | ContextMsk {conf_model["context_masking"]} ObjMsk {conf_model["object_masking"]}""")
58 | if "datasets" in list(config.keys()):
59 | self.doc.add(p(json.dumps(config[config["datasets"]], indent=4, sort_keys=True)))
60 | else:
61 | self.doc.add(p(json.dumps(config["train_supervised"], indent=4, sort_keys=True)))
62 |
63 | def add_results(self, epoch, seg_resuts, width=400, domain=None):
64 | para = p(__pretty=False)
65 | with self.t:
66 | with tr():
67 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
68 | td(f'Epoch : {epoch}')
69 | if domain is not None:
70 | td(f'Mean_IoU_{domain} : {seg_resuts[f"Mean_IoU_{domain}"]}')
71 | td(f'PixelAcc_{domain} : {seg_resuts[f"Pixel_Accuracy_{domain}"]}')
72 | td(f'Val Loss_{domain} : {seg_resuts[f"val_loss_{domain}"]}')
73 | else:
74 | td(f'Mean_IoU : {seg_resuts["Mean_IoU"]}')
75 | td(f'PixelAcc : {seg_resuts["Pixel_Accuracy"]}')
76 | td(f'Val Loss : {seg_resuts["val_loss"]}')
77 |
78 |
79 | def save(self):
80 | html_file = os.path.join(self.web_dir, self.save_name)
81 | f = open(html_file, 'w')
82 | if self.old_content is not None:
83 | f.write(self.old_content + self.doc.render())
84 | else:
85 | f.write(self.doc.render())
86 | f.close()
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 |
4 | logging.basicConfig(level=logging.INFO, format='')
5 |
6 | class Logger:
7 | """
8 | Training process logger
9 |
10 | Note:
11 | Used by BaseTrainer to save training history.
12 | """
13 | def __init__(self):
14 | self.entries = {}
15 |
16 | def add_entry(self, entry):
17 | self.entries[len(self.entries) + 1] = entry
18 |
19 | def __str__(self):
20 | return json.dumps(self.entries, sort_keys=True, indent=4)
21 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | from utils import ramps
6 |
7 |
8 |
9 | class consistency_weight(object):
10 | """
11 | ramp_types = ['sigmoid_rampup', 'linear_rampup', 'cosine_rampup', 'log_rampup', 'exp_rampup']
12 | """
13 | def __init__(self, final_w, iters_per_epoch, rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup'):
14 | self.final_w = final_w
15 | self.iters_per_epoch = iters_per_epoch
16 | self.rampup_starts = rampup_starts * iters_per_epoch
17 | self.rampup_ends = rampup_ends * iters_per_epoch
18 | self.rampup_length = (self.rampup_ends - self.rampup_starts)
19 | self.rampup_func = getattr(ramps, ramp_type)
20 | self.current_rampup = 0
21 |
22 | def __call__(self, epoch, curr_iter):
23 | cur_total_iter = self.iters_per_epoch * epoch + curr_iter
24 | if cur_total_iter < self.rampup_starts:
25 | return 0
26 | self.current_rampup = self.rampup_func(cur_total_iter - self.rampup_starts, self.rampup_length)
27 | return self.final_w * self.current_rampup
28 |
29 |
30 | def CE_loss(input_logits, target_targets, ignore_index, temperature=1):
31 | return F.cross_entropy(input_logits/temperature, target_targets, ignore_index=ignore_index)
32 |
33 | # for FocalLoss
34 | def softmax_helper(x):
35 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
36 | rpt = [1 for _ in range(len(x.size()))]
37 | rpt[1] = x.size(1)
38 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
39 | e_x = torch.exp(x - x_max)
40 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
41 |
42 | def get_alpha(supervised_loader):
43 | # get number of classes
44 | num_labels = 0
45 | for image_batch, label_batch in supervised_loader:
46 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
47 | l_unique = torch.unique(label_batch.data)
48 | list_unique = [element.item() for element in l_unique.flatten()]
49 | num_labels = max(max(list_unique),num_labels)
50 | num_classes = num_labels + 1
51 | # count class occurrences
52 | alpha = [0 for i in range(num_classes)]
53 | for image_batch, label_batch in supervised_loader:
54 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
55 | l_unique = torch.unique(label_batch.data)
56 | list_unique = [element.item() for element in l_unique.flatten()]
57 | l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480])
58 | list_count = [count.item() for count in l_unique_count.flatten()]
59 | for index in list_unique:
60 | alpha[index] += list_count[list_unique.index(index)]
61 | return alpha
62 |
63 | # for FocalLoss
64 | def softmax_helper(x):
65 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
66 | rpt = [1 for _ in range(len(x.size()))]
67 | rpt[1] = x.size(1)
68 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
69 | e_x = torch.exp(x - x_max)
70 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
71 |
72 |
73 | class FocalLoss(nn.Module):
74 | """
75 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
76 |     This is an implementation of Focal Loss with smoothed-label cross entropy support, as proposed in
77 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
78 | Focal_Loss= -1*alpha*(1-pt)*log(pt)
79 | :param num_class:
80 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
81 |     :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5), putting more
82 |                     focus on hard, misclassified examples
83 |     :param smooth: (float,double) smoothing value used for the label-smoothed cross entropy
84 | :param balance_index: (int) balance class index, should be specific when alpha is float
85 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
86 | """
87 |
88 | def __init__(self, apply_nonlin=None, ignore_index = None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
89 | super(FocalLoss, self).__init__()
90 | self.apply_nonlin = apply_nonlin
91 | self.alpha = alpha
92 | self.gamma = gamma
93 | self.balance_index = balance_index
94 | self.smooth = smooth
95 | self.size_average = size_average
96 |         self.ignore_index = ignore_index  # label value excluded from the loss, used in forward()
97 | if self.smooth is not None:
98 | if self.smooth < 0 or self.smooth > 1.0:
99 | raise ValueError('smooth value should be in [0,1]')
100 |
101 | def forward(self, logit, target):
102 | if self.apply_nonlin is not None:
103 | logit = self.apply_nonlin(logit)
104 | num_class = logit.shape[1]
105 |
106 | if logit.dim() > 2:
107 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
108 | logit = logit.view(logit.size(0), logit.size(1), -1)
109 | logit = logit.permute(0, 2, 1).contiguous()
110 | logit = logit.view(-1, logit.size(-1))
111 | target = torch.squeeze(target, 1)
112 | target = target.view(-1, 1)
113 |
114 | valid_mask = None
115 | if self.ignore_index is not None:
116 | valid_mask = target != self.ignore_index
117 | target = target * valid_mask
118 |
119 | alpha = self.alpha
120 |
121 | if alpha is None:
122 | alpha = torch.ones(num_class, 1)
123 | elif isinstance(alpha, (list, np.ndarray)):
124 | assert len(alpha) == num_class
125 | alpha = torch.FloatTensor(alpha).view(num_class, 1)
126 | alpha = alpha / alpha.sum()
127 | alpha = 1/alpha # inverse of class frequency
128 | elif isinstance(alpha, float):
129 | alpha = torch.ones(num_class, 1)
130 | alpha = alpha * (1 - self.alpha)
131 | alpha[self.balance_index] = self.alpha
132 |
133 | else:
134 | raise TypeError('Not support alpha type')
135 |
136 | if alpha.device != logit.device:
137 | alpha = alpha.to(logit.device)
138 |
139 | idx = target.cpu().long()
140 |
141 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
142 |
143 | # to resolve error in idx in scatter_
144 |         idx[idx == 255] = 0
145 |
146 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
147 | if one_hot_key.device != logit.device:
148 | one_hot_key = one_hot_key.to(logit.device)
149 |
150 | if self.smooth:
151 | one_hot_key = torch.clamp(
152 | one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
153 | pt = (one_hot_key * logit).sum(1) + self.smooth
154 | logpt = pt.log()
155 |
156 | gamma = self.gamma
157 |
158 | alpha = alpha[idx]
159 | alpha = torch.squeeze(alpha)
160 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
161 |
162 | if valid_mask is not None:
163 | loss = loss * valid_mask.squeeze()
164 |
165 | if self.size_average:
166 | loss = loss.mean()
167 | else:
168 | loss = loss.sum()
169 | return loss
170 |
171 |
172 | class abCE_loss(nn.Module):
173 | """
174 | Annealed-Bootstrapped cross-entropy loss
175 | """
176 | def __init__(self, iters_per_epoch, epochs, num_classes, weight=None,
177 | reduction='mean', thresh=0.7, min_kept=1, ramp_type='log_rampup'):
178 | super(abCE_loss, self).__init__()
179 | self.weight = torch.FloatTensor(weight) if weight is not None else weight
180 | self.reduction = reduction
181 | self.thresh = thresh
182 | self.min_kept = min_kept
183 | self.ramp_type = ramp_type
184 |
185 | if ramp_type is not None:
186 | self.rampup_func = getattr(ramps, ramp_type)
187 | self.iters_per_epoch = iters_per_epoch
188 | self.num_classes = num_classes
189 | self.start = 1/num_classes
190 | self.end = 0.9
191 | self.total_num_iters = (epochs - (0.6 * epochs)) * iters_per_epoch
192 |
193 | def threshold(self, curr_iter, epoch):
194 | cur_total_iter = self.iters_per_epoch * epoch + curr_iter
195 | current_rampup = self.rampup_func(cur_total_iter, self.total_num_iters)
196 | return current_rampup * (self.end - self.start) + self.start
197 |
198 | def forward(self, predict, target, ignore_index, curr_iter, epoch):
199 | batch_kept = self.min_kept * target.size(0)
200 | prob_out = F.softmax(predict, dim=1)
201 | tmp_target = target.clone()
202 | tmp_target[tmp_target == ignore_index] = 0
203 | prob = prob_out.gather(1, tmp_target.unsqueeze(1))
204 | mask = target.contiguous().view(-1, ) != ignore_index
205 | sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort()
206 |
207 | if self.ramp_type is not None:
208 | thresh = self.threshold(curr_iter=curr_iter, epoch=epoch)
209 | else:
210 | thresh = self.thresh
211 |
212 | min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0
213 | threshold = max(min_threshold, thresh)
214 | loss_matrix = F.cross_entropy(predict, target,
215 | weight=self.weight.to(predict.device) if self.weight is not None else None,
216 | ignore_index=ignore_index, reduction='none')
217 |         loss_matrix = loss_matrix.contiguous().view(-1, )
218 |         sort_loss_matrix = loss_matrix[mask][sort_indices]
219 |         select_loss_matrix = sort_loss_matrix[sort_prob < threshold]
220 | if self.reduction == 'sum' or select_loss_matrix.numel() == 0:
221 | return select_loss_matrix.sum()
222 | elif self.reduction == 'mean':
223 | return select_loss_matrix.mean()
224 | else:
225 | raise NotImplementedError('Reduction Error!')
226 |
227 |
228 |
229 | def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
230 | assert inputs.requires_grad == True and targets.requires_grad == False
231 | assert inputs.size() == targets.size() # (batch_size * num_classes * H * W)
232 | inputs = F.softmax(inputs, dim=1)
233 | if use_softmax:
234 | targets = F.softmax(targets, dim=1)
235 |
236 | if conf_mask:
237 | loss_mat = F.mse_loss(inputs, targets, reduction='none')
238 | mask = (targets.max(1)[0] > threshold)
239 | loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
240 | if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
241 | return loss_mat.mean()
242 | else:
243 | return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size
244 |
245 |
246 | def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
247 | assert inputs.requires_grad == True and targets.requires_grad == False
248 | assert inputs.size() == targets.size()
249 | input_log_softmax = F.log_softmax(inputs, dim=1)
250 | if use_softmax:
251 | targets = F.softmax(targets, dim=1)
252 |
253 | if conf_mask:
254 | loss_mat = F.kl_div(input_log_softmax, targets, reduction='none')
255 | mask = (targets.max(1)[0] > threshold)
256 | loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
257 | if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
258 | return loss_mat.sum() / mask.shape.numel()
259 | else:
260 | return F.kl_div(input_log_softmax, targets, reduction='mean')
261 |
262 |
263 | def softmax_js_loss(inputs, targets, **_):
264 | assert inputs.requires_grad == True and targets.requires_grad == False
265 | assert inputs.size() == targets.size()
266 | epsilon = 1e-5
267 |
268 | M = (F.softmax(inputs, dim=1) + targets) * 0.5
269 | kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean')
270 | kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean')
271 | return (kl1 + kl2) * 0.5
272 |
273 |
274 |
275 | def pair_wise_loss(unsup_outputs, size_average=True, nbr_of_pairs=8):
276 | """
277 | Pair-wise loss in the sup. mat.
278 | """
279 | if isinstance(unsup_outputs, list):
280 | unsup_outputs = torch.stack(unsup_outputs)
281 |
282 | # Only for a subset of the aux outputs to reduce computation and memory
283 | unsup_outputs = unsup_outputs[torch.randperm(unsup_outputs.size(0))]
284 | unsup_outputs = unsup_outputs[:nbr_of_pairs]
285 |
286 | temp = torch.zeros_like(unsup_outputs) # For grad purposes
287 | for i, u in enumerate(unsup_outputs):
288 | temp[i] = F.softmax(u, dim=1)
289 | mean_prediction = temp.mean(0).unsqueeze(0) # Mean over the auxiliary outputs
290 | pw_loss = ((temp - mean_prediction)**2).mean(0) # Variance
291 | pw_loss = pw_loss.sum(1) # Sum over classes
292 | if size_average:
293 | return pw_loss.mean()
294 | return pw_loss.sum()
295 |
296 |
--------------------------------------------------------------------------------
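
A minimal sketch (not part of the repository), assuming it is run from the repository root, of the unsupervised-loss ramp-up: consistency_weight returns 0 before rampup_starts and then grows towards final_w following the chosen ramp (sigmoid by default); the numbers below are illustrative:

from utils.losses import consistency_weight

cons_w = consistency_weight(final_w=30, iters_per_epoch=100, rampup_ends=5)

for epoch in (0, 2, 5, 10):
    print(epoch, round(cons_w(epoch, curr_iter=0), 3))
# the weight climbs from roughly 0.2 at epoch 0 to the full 30 once the ramp-up ends at epoch 5
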
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 |
4 |
5 | class Step(_LRScheduler):
6 | def __init__(self, optimizer, num_epochs, steps=2, gamma=0.1, last_epoch=-1, **_):
7 | self.step_size = num_epochs // steps
8 | self.num_epochs = num_epochs
9 | self.gamma = gamma
10 | super(Step, self).__init__(optimizer, last_epoch)
11 |
12 | def get_lr(self):
13 | if self.step_size != 0:
14 | return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
15 | for base_lr in self.base_lrs]
16 | return self.base_lrs
17 |
18 | class Poly(_LRScheduler):
19 | def __init__(self, optimizer, num_epochs, iters_per_epoch, warmup_epochs=0, last_epoch=-1):
20 | self.iters_per_epoch = iters_per_epoch
21 | self.cur_iter = 0
22 | self.N = num_epochs * iters_per_epoch
23 | self.warmup_iters = warmup_epochs * iters_per_epoch
24 | super(Poly, self).__init__(optimizer, last_epoch)
25 |
26 | def get_lr(self):
27 | T = self.last_epoch * self.iters_per_epoch + self.cur_iter
28 | factor = pow((1 - 1.0 * T / self.N), 0.9)
29 | if self.warmup_iters > 0 and T < self.warmup_iters:
30 | factor = 1.0 * T / self.warmup_iters
31 |
32 | self.cur_iter %= self.iters_per_epoch
33 | self.cur_iter += 1
34 | assert factor >= 0, 'error in lr_scheduler'
35 | return [base_lr * factor for base_lr in self.base_lrs]
36 |
37 | class OneCycle(_LRScheduler):
38 | def __init__(self, optimizer, num_epochs, iters_per_epoch=0, last_epoch=-1,
39 | momentums = (0.85, 0.95), div_factor = 25, phase1=0.3):
40 | self.iters_per_epoch = iters_per_epoch
41 | self.cur_iter = 0
42 | self.N = num_epochs * iters_per_epoch
43 | self.phase1_iters = int(self.N * phase1)
44 | self.phase2_iters = (self.N - self.phase1_iters)
45 | self.momentums = momentums
46 | self.mom_diff = momentums[1] - momentums[0]
47 |
48 | self.low_lrs = [opt_grp['lr']/div_factor for opt_grp in optimizer.param_groups]
49 | self.final_lrs = [opt_grp['lr']/(div_factor * 1e4) for opt_grp in optimizer.param_groups]
50 | super(OneCycle, self).__init__(optimizer, last_epoch)
51 |
52 | def get_lr(self):
53 | T = self.last_epoch * self.iters_per_epoch + self.cur_iter
54 | self.cur_iter %= self.iters_per_epoch
55 | self.cur_iter += 1
56 |
57 | # Going from base_lr / 25 -> base_lr
58 | if T <= self.phase1_iters:
59 |             cos_annealing = (1 + math.cos(math.pi * T / self.phase1_iters)) / 2
60 |             for i in range(len(self.optimizer.param_groups)):
61 |                 self.optimizer.param_groups[i]['momentum'] = self.momentums[0] + self.mom_diff * cos_annealing
62 | 
63 |             return [base_lr - (base_lr - low_lr) * cos_annealing
64 |                     for base_lr, low_lr in zip(self.base_lrs, self.low_lrs)]
65 |
66 | # Going from base_lr -> base_lr / (25e4)
67 | T -= self.phase1_iters
68 |         cos_annealing = (1 + math.cos(math.pi * T / self.phase2_iters)) / 2
69 |         for i in range(len(self.optimizer.param_groups)):
70 |             self.optimizer.param_groups[i]['momentum'] = self.momentums[1] - self.mom_diff * cos_annealing
71 |         return [final_lr + (base_lr - final_lr) * cos_annealing
72 |                 for base_lr, final_lr in zip(self.base_lrs, self.final_lrs)]
73 |
--------------------------------------------------------------------------------
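
A small usage sketch (not part of the repository) of the Poly schedule above, stepped once per epoch as the trainer does; the toy model and epoch counts are placeholders:

import torch
from utils.lr_scheduler import Poly

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# lr decays as base_lr * (1 - T/N)**0.9, with T the running iteration count
scheduler = Poly(optimizer, num_epochs=80, iters_per_epoch=100)

for epoch in range(3):
    # ... one epoch of optimizer.step() calls would go here ...
    scheduler.step(epoch=epoch)
    print(epoch, optimizer.param_groups[0]['lr'])
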
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 |
6 | class AverageMeter(object):
7 | """Computes and stores the average and current value"""
8 | def __init__(self):
9 | self.initialized = False
10 | self.val = 0
11 | self.avg = 0
12 | self.sum = 0
13 | self.count = 0
14 |
15 | def initialize(self, val, weight):
16 | self.val = val
17 | self.avg = val
18 | self.sum = np.multiply(val, weight)
19 | self.count = weight
20 | self.initialized = True
21 |
22 | def update(self, val, weight=1):
23 | if not self.initialized:
24 | self.initialize(val, weight)
25 | else:
26 | self.add(val, weight)
27 |
28 | def add(self, val, weight):
29 | self.val = val
30 | self.sum = np.add(self.sum, np.multiply(val, weight))
31 | self.count = self.count + weight
32 | self.avg = self.sum / self.count
33 |
34 | @property
35 | def value(self):
36 | return self.val
37 |
38 | @property
39 | def average(self):
40 | return np.round(self.avg, 5)
41 |
42 |
43 | def batch_pix_accuracy(output, target):
44 | _, predict = torch.max(output, 1)
45 |
46 | predict = predict.int() + 1
47 | target = target.int() + 1
48 |
49 | pixel_labeled = (target > 0).sum()
50 | pixel_correct = ((predict == target)*(target > 0)).sum()
51 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
52 | return pixel_correct.cpu().numpy(), pixel_labeled.cpu().numpy()
53 |
54 |
55 | def batch_intersection_union(output, target, num_class):
56 | _, predict = torch.max(output, 1)
57 | predict = predict + 1
58 | target = target + 1
59 |
60 | predict = predict * (target > 0).long()
61 | intersection = predict * (predict == target).long()
62 |
63 | area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1)
64 | area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1)
65 | area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1)
66 | area_union = area_pred + area_lab - area_inter
67 | assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area"
68 | return area_inter.cpu().numpy(), area_union.cpu().numpy()
69 |
70 |
71 | def eval_metrics(output, target, num_classes, ignore_index):
72 | target = target.clone()
73 | target[target == ignore_index] = -1
74 | correct, labeled = batch_pix_accuracy(output.data, target)
75 | inter, union = batch_intersection_union(output.data, target, num_classes)
76 | return [np.round(correct, 5), np.round(labeled, 5), np.round(inter, 5), np.round(union, 5)]
77 |
78 |
79 | # ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
80 | def pixel_accuracy(output, target):
81 | output = np.asarray(output)
82 | target = np.asarray(target)
83 | pixel_labeled = np.sum(target > 0)
84 | pixel_correct = np.sum((output == target) * (target > 0))
85 | return pixel_correct, pixel_labeled
86 |
87 |
88 | def inter_over_union(output, target, num_class):
89 | output = np.asarray(output) + 1
90 | target = np.asarray(target) + 1
91 | output = output * (target > 0)
92 |
93 | intersection = output * (output == target)
94 | area_inter, _ = np.histogram(intersection, bins=num_class, range=(1, num_class))
95 | area_pred, _ = np.histogram(output, bins=num_class, range=(1, num_class))
96 | area_lab, _ = np.histogram(target, bins=num_class, range=(1, num_class))
97 | area_union = area_pred + area_lab - area_inter
98 | return area_inter, area_union
--------------------------------------------------------------------------------
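
A quick sketch (not part of the repository) of eval_metrics on random tensors: it returns the running counts (correct and labeled pixels, per-class intersection and union) that the trainer accumulates before turning them into pixel accuracy and mIoU:

import torch
from utils.metrics import eval_metrics

num_classes, ignore_index = 21, 255
output = torch.randn(2, num_classes, 32, 32)           # logits
target = torch.randint(0, num_classes, (2, 32, 32))    # ground-truth class indices

correct, labeled, inter, union = eval_metrics(output, target, num_classes, ignore_index)
pixel_acc = correct / labeled
miou = (inter / (union + 1e-10)).mean()
print(round(float(pixel_acc), 3), round(float(miou), 3))
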
/utils/pallete.py:
--------------------------------------------------------------------------------
1 |
2 | def get_voc_pallete(num_classes):
3 | n = num_classes
4 | pallete = [0]*(n*3)
5 | for j in range(0,n):
6 | lab = j
7 | pallete[j*3+0] = 0
8 | pallete[j*3+1] = 0
9 | pallete[j*3+2] = 0
10 | i = 0
11 | while (lab > 0):
12 | pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i))
13 | pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i))
14 | pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i))
15 | i = i + 1
16 | lab >>= 3
17 | return pallete
18 |
--------------------------------------------------------------------------------
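
A brief sketch (not part of the repository): get_voc_pallete builds the standard PASCAL VOC color map as a flat [R, G, B, R, G, B, ...] list, which colorize_mask in utils/helpers.py pads to 256 entries and attaches to a 'P'-mode PIL image:

from utils.pallete import get_voc_pallete

palette = get_voc_pallete(21)     # 21 classes including background -> 63 values
print(palette[0:3])               # class 0 (background): [0, 0, 0]
print(palette[3:6])               # class 1 (aeroplane):  [128, 0, 0]
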
/utils/ramps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def sigmoid_rampup(current, rampup_length):
4 | if rampup_length == 0:
5 | return 1.0
6 | current = np.clip(current, 0.0, rampup_length)
7 | phase = 1.0 - current / rampup_length
8 | return float(np.exp(-5.0 * phase * phase))
9 |
10 | def linear_rampup(current, rampup_length):
11 | assert current >= 0 and rampup_length >= 0
12 | if current >= rampup_length:
13 | return 1.0
14 | return current / rampup_length
15 |
16 | def cosine_rampup(current, rampup_length):
17 | if rampup_length == 0:
18 | return 1.0
19 | current = np.clip(current, 0.0, rampup_length)
20 | return 1 - float(.5 * (np.cos(np.pi * current / rampup_length) + 1))
21 |
22 | def log_rampup(current, rampup_length):
23 | if rampup_length == 0:
24 | return 1.0
25 | current = np.clip(current, 0.0, rampup_length)
26 |     return float(1 - np.exp(-5.0 * current / rampup_length))
27 |
28 | def exp_rampup(current, rampup_length):
29 | if rampup_length == 0:
30 | return 1.0
31 | current = np.clip(current, 0.0, rampup_length)
32 | return float(np.exp(5.0 * (current / rampup_length - 1)))
33 |
--------------------------------------------------------------------------------
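
A tiny sketch (not part of the repository) comparing a few of the ramp-up schedules above at different points of a 100-step ramp; consistency_weight and abCE_loss in utils/losses.py select one of these functions by name:

from utils import ramps

for step in (0, 25, 50, 100):
    print(step,
          round(ramps.sigmoid_rampup(step, 100), 3),
          round(ramps.linear_rampup(step, 100), 3),
          round(ramps.log_rampup(step, 100), 3))
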