├── .circleci └── config.yml ├── .gitignore ├── Readme.md ├── configs ├── config_mixmatch.yml └── config_mixup.yml ├── imgs ├── flow.drawio ├── flow.pdf └── flow.png ├── lightning_ssl ├── __init__.py ├── dataloader │ ├── __init__.py │ ├── base_data.py │ └── cifar10.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── cnn13.py │ └── wideresnet.py ├── module │ ├── __init__.py │ ├── classifier_module.py │ ├── mixmatch.py │ └── mixmatch_module.py └── utils │ ├── __init__.py │ ├── argparser.py │ ├── torch_utils.py │ └── utils.py ├── main.py ├── read_from_tb.py ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── dataloader ├── __init__.py └── test_base_data.py ├── models ├── __init__.py ├── test_basemodel.py ├── test_cnn13.py └── test_wideresnet.py ├── module ├── __init__.py ├── assets │ └── tmp_config.yml ├── simple_data.py ├── simple_model.py ├── test_mixmatch.py └── test_supervised.py ├── test_models └── utils ├── __init__.py ├── assets └── tmp_config.yml ├── test_argparser.py ├── test_torch_utils.py └── test_utils.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | python: circleci/python@1.0.0 5 | 6 | jobs: 7 | build-and-test: 8 | resource_class: medium 9 | executor: python/default 10 | steps: 11 | - checkout 12 | # - python/load-cache 13 | # - python/install-deps 14 | # - python/save-cache 15 | - run: 16 | name: Installation 17 | command: pip install -e .[test] 18 | - run: 19 | name: Test 20 | command: | 21 | python -m coverage run --source lightning_ssl -m py.test lightning_ssl tests -v --flake8 22 | python -m coverage report -m 23 | coverage html 24 | - store_artifacts: 25 | path: htmlcov 26 | 27 | workflows: 28 | main: 29 | jobs: 30 | - build-and-test 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | .vscode 3 | **/**/__pycache__/ 4 | lightning_semi_supervised_learning.egg-info/ 5 | **/lightning_logs/ 6 | cifar-10* 7 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Semi Supervised Learning with PyTorch Lightning 2 | 3 | This project aims to use PyTorch Lightning to implement state-of-the-art algorithms in semi-supervised leanring (SSL). 4 | 5 | ## Semi-Supervised Learning 6 | The semi-supervised learning is to leverage abundant unlabeled samples to improve models under the the scenario of scarce data. There are several assumptions which are usually used in semi-supervised learning, 7 | 8 | * Smoothness assumption 9 | * Low-density assumption 10 | * Manifold assumption 11 | 12 | Most of the approaches try to exploit regularization on models to satisfy the assumptions. In this repository, we will first focus on methods using consistency loss. 13 | 14 | 17 | 18 | 19 | ## What is PyTorch Lightning 20 | PyTorch Lightning is a PyTorch Wrapper to standardize the training and testing process of AI projects. The projects using PyTorch Lightning can focus on the implementation of the algorithm, and there is no need to worry about some complicated engineering parts, such as multi-GPU training, 16-bit precision, Tensorboard logging, and TPU training. 21 | 22 | In this project, we leverage PyTorch Lightning as the coding backbone and implement algorithms with minimum changes. 
The necessary implementation of a new algorithm is put in `module`. 23 | 24 | 25 | 26 | ## Requirements 27 | ``` 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Module Flow 32 |
<img src="imgs/flow.png">
33 | 34 |

35 | 36 | * `configs`: Contained the config files for approaches. 37 | * `models`: Contained all the models. 38 | * `dataloader`: Data loader for every dataset. 39 | * `module`: SSL modules inherited `pytorch_lightning.LightningModule`. 40 | 41 | To implement a new method, one usually need to define new config, data loader and PL module. 42 | 43 | ## Usage 44 | 45 | * Please refer to `argparser.py` for hyperparameters. 46 | * `read_from_tb.py` is used to extract the final accuracies from `tensorboard` logs. 47 | 48 | ### Fully-Supervised Training 49 | 50 | ``` 51 | python main.py -c configs/config_mixup.ini -g [GPU ID] --affix [FILE NAME] 52 | ``` 53 | 54 | ### Train for Mixmatch 55 | 56 | ``` 57 | python main.py -c configs/config_mixmatch.ini -g [GPU ID] --num_labeled [NUMBER OF UNLABELED DATA] --affix [FILE NAME] 58 | ``` 59 | 60 | ## Results 61 | 62 | ### Supervised Training 63 | The result is the average of three runs (seed=1, 2, 3). 64 | 65 | | | Acc | 66 | | :---: | :---: | 67 | | full train with mixup | 4.41±0.03 | 68 | 69 | ### Mixmatch 70 | 71 | The experiments run for five times (seed=1,2,3,4,5) in the paper, but only three times (seed=1,2,3) for this implementation. 72 | 73 | 74 | | | time (hr:min) | 250 | 500 | 1000 | 2000 | 4000 | 75 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 76 | | Paper | |11.08±0.87|9.65±0.94 |7.75±0.32|7.03±0.15|6.24±0.06| 77 | | Reproduce |17:24 |10.93±1.20|9.72±0.63 |8.02±0.74| - | - | 78 | | This repo |17:40 |11.10±1.00|10.05±0.45|8.00±0.42|7.13±0.13|6.22±0.08| 79 | 80 | 81 | 87 | 88 | ## Plans 89 | 90 | 1. Remixmatch 91 | 2. Fixmatch 92 | 3. GAN based method (DSGAN, BadGAN) 93 | 4. Other approach using consistency loss (VAT, mean teacher) 94 | 5. Polish the code for `CustomSemiDataset` in `data_loader/base_data.py` 95 | -------------------------------------------------------------------------------- /configs/config_mixmatch.yml: -------------------------------------------------------------------------------- 1 | cifar10: 2 | learning_rate: 2e-3 3 | learning_scenario: semi 4 | algo: mixmatch 5 | weight_decay: 2e-2 6 | ema: 0.999 7 | batch_size: 64 8 | max_steps: 1048576 9 | max_epochs: 2000 10 | gpus: 0 11 | label_smoothing: 0.000 12 | num_augments: 2 13 | alpha: 0.75 14 | lambda_u: 75 15 | T: 0.5 16 | num_labeled: 250 17 | num_val: 10 18 | affix: "@250_brand_new" 19 | -------------------------------------------------------------------------------- /configs/config_mixup.yml: -------------------------------------------------------------------------------- 1 | cifar10: 2 | learning_rate: 2e-3 3 | learning_scenario: supervised 4 | weight_decay: 2e-2 5 | ema: 0.999 6 | batch_size: 64 7 | max_steps: 1048576 8 | max_epochs: 2000 9 | gpus: 0 10 | label_smoothing: 0.001 11 | alpha: 1 12 | algo: mixup 13 | num_val: 10 14 | affix: fs_mixup 15 | -------------------------------------------------------------------------------- /imgs/flow.drawio: -------------------------------------------------------------------------------- 1 | 
3VjBcpswEP0aH9MxCLB9TBy3nU467QydadybgmRQKlhXFjHO11cYKUIxcdLWDnEvtvattEj79FikAZrm1QeBl9lnIJQP/CGpBuhy4PuTIFK/NbBpgDAIGyAVjDSQZ4GY3VMNDjVaMkJXTkcJwCVbumACRUET6WBYCFi73RbA3acucUp3gDjBfBf9zojMGnTsjyz+kbI0M0/2oknjybHprFeyyjCBdQtCswGaCgDZtPJqSnmdO5OXZtz7J7wPExO0kC8ZMP8Roy/ZND7js+z2Pg0+edPZmd9EucO81AvWk5UbkwEBZUFoHWQ4QBfrjEkaL3FSe9eKcoVlMufK8lRzwTifAgexHYsWCxolicJXUsBP2vKQ0eRmWAfUE6BC0urJlXkP+VL7jEJOpdioLnoAGukU6z1mMr62hKFIY1mLLIQ0iPUmSR9C2zyqhk7lH6TVQztppETtK22CkBmkUGA+s+iFm2jb5wpgqdN7S6XcaJHgUoKbfFoxed1qz+tQ70JtXVY68tbYGKNQ6722HWtz3vbZYVvLjGvWVy9qP2kqB1CKhO5LlpY1FimV+/oNu3eBoBxLdudO5PCUhn1SammctzwnT2nQK6NRr4z+lUgdPi29b4dR1CujHdUs4rKuS6AW3KY6+lWCcZyttmSdqw5+sKysU7XS+l99XCxYamKpqTXhGueeguk9XzAPUP4CL3DKX7hb/jzUUf4mR6t+ux8Rb7T6HVAg6CSqGDqSQgiWWDk5YELFqQjlDSil6431+tI5oAyCk5BBcCQZ5M3x9zS2v9f79g//mQZv3EVDHF/V0+UpCCaz/KWEqNRKN+vusbWAgj4642oIc5YWykwUFer9hy5qoliC+bl25IyQrVa7aHb1ewCmo6F7IEYvJNo/FtHRkfT29crEuREG+yYwK2wRso6j6PLxhcc4od0XHjfjMAgPxO/o0YVH/0r2ut6oJ13IzAXks5Us6rOQmVm+irJUdStVrP9GRuH49WSkTHvVu/W17svR7Dc= -------------------------------------------------------------------------------- /imgs/flow.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/imgs/flow.pdf -------------------------------------------------------------------------------- /imgs/flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/imgs/flow.png -------------------------------------------------------------------------------- /lightning_ssl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/lightning_ssl/__init__.py -------------------------------------------------------------------------------- /lightning_ssl/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar10 import ( 2 | SemiCIFAR10Module, 3 | SupervisedCIFAR10Module, 4 | ) 5 | -------------------------------------------------------------------------------- /lightning_ssl/dataloader/base_data.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import math 3 | import torch 4 | import random 5 | import itertools 6 | import numpy as np 7 | 8 | from torch.utils.data import DataLoader, SubsetRandomSampler, Dataset 9 | import pytorch_lightning as pl 10 | 11 | 12 | def get_split_indices(labels, num_labeled, num_val, _n_classes): 13 | """ 14 | Split the train data into the following three set: 15 | (1) labeled data 16 | (2) unlabeled data 17 | (3) val data 18 | 19 | Data distribution of the three sets are same as which of the 20 | original training data. 
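Example (illustrative numbers, assuming CIFAR-10's 50,000 balanced training images with num_labeled=250 and num_val=5000): each of the 10 classes contributes max(int(5000 * 5000 / 50000), 1) = 500 validation samples and max(int(250 * 5000 / 50000), 1) = 25 labeled samples, leaving the remaining 44,750 images as the unlabeled set.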
21 | 22 | Inputs: 23 | labels: (np.int) array of labels 24 | num_labeled: (int) 25 | num_val: (int) 26 | _n_classes: (int) 27 | 28 | 29 | Return: 30 | the three indices for the three sets 31 | """ 32 | # val_per_class = num_val // _n_classes 33 | val_indices = [] 34 | train_indices = [] 35 | 36 | num_total = len(labels) 37 | num_per_class = [] 38 | for c in range(_n_classes): 39 | num_per_class.append((labels == c).sum().astype(int)) 40 | 41 | # obtain val indices, data evenly drawn from each class 42 | for c, num_class in zip(range(_n_classes), num_per_class): 43 | val_this_class = max(int(num_val * (num_class / num_total)), 1) 44 | class_indices = np.where(labels == c)[0] 45 | np.random.shuffle(class_indices) 46 | val_indices.append(class_indices[:val_this_class]) 47 | train_indices.append(class_indices[val_this_class:]) 48 | 49 | # split data into labeled and unlabeled 50 | labeled_indices = [] 51 | unlabeled_indices = [] 52 | 53 | # num_labeled_per_class = num_labeled // _n_classes 54 | 55 | for c, num_class in zip(range(_n_classes), num_per_class): 56 | num_labeled_this_class = max(int(num_labeled * (num_class / num_total)), 1) 57 | labeled_indices.append(train_indices[c][:num_labeled_this_class]) 58 | unlabeled_indices.append(train_indices[c][num_labeled_this_class:]) 59 | 60 | labeled_indices = np.hstack(labeled_indices) 61 | unlabeled_indices = np.hstack(unlabeled_indices) 62 | val_indices = np.hstack(val_indices) 63 | 64 | return labeled_indices, unlabeled_indices, val_indices 65 | 66 | 67 | class Subset(Dataset): 68 | r""" 69 | Subset of a dataset at specified indices. 70 | 71 | Arguments: 72 | dataset (Dataset): The whole Dataset 73 | indices (sequence): Indices in the whole set selected for subset 74 | """ 75 | 76 | def __init__(self, dataset, indices, transform=None): 77 | self.dataset = dataset 78 | self.indices = indices 79 | self.transform = transform 80 | 81 | nums = [0 for _ in range(10)] 82 | for i in range(len(self.indices)): 83 | nums[self.dataset[self.indices[i]][1]] += 1 84 | 85 | print(nums) 86 | print(np.sum(nums)) 87 | 88 | def __getitem__(self, idx): 89 | data, label = self.dataset[self.indices[idx]] 90 | if self.transform is not None: 91 | data = self.transform(data) 92 | return data, label 93 | 94 | def __len__(self): 95 | return len(self.indices) 96 | 97 | 98 | class MultiDataset(Dataset): 99 | """ 100 | MultiDataset is used for training multiple datasets together. The lengths of the datasets 101 | should be the same. 102 | """ 103 | 104 | def __init__(self, datasets): 105 | super(MultiDataset, self).__init__() 106 | assert len(datasets) > 1, "You should use at least two datasets" 107 | 108 | for d in datasets[1:]: 109 | assert len(d) == len( 110 | datasets[0] 111 | ), "The lengths of the datasets should be the same." 
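# After this check every dataset has the same length, so __getitem__ can index
# them all with a single idx and return aligned samples (e.g. different augmented
# views of the same underlying image).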
112 | 113 | self.datasets = datasets 114 | self.max_length = max([len(d) for d in self.datasets]) 115 | 116 | def __getitem__(self, idx): 117 | return tuple([d[idx] for d in self.datasets]) 118 | 119 | def __len__(self): 120 | return self.max_length 121 | 122 | 123 | class MagicClass(object): 124 | """ 125 | Codes are borrowed from https://github.com/PyTorchLightning/pytorch-lightning/pull/1959 126 | """ 127 | 128 | def __init__(self, data) -> None: 129 | self.d = data 130 | self.l = max([len(d) for d in self.d]) 131 | 132 | def __len__(self) -> int: 133 | return self.l 134 | 135 | def __iter__(self): 136 | if isinstance(self.d, list): 137 | gen = [None for v in self.d] 138 | 139 | # for k,v in self.d.items(): 140 | # # gen[k] = itertools.cycle(v) 141 | # gen[k] = iter(v) 142 | 143 | for i in range(self.l): 144 | rv = [None for v in self.d] 145 | for k, v in enumerate(self.d): 146 | # If reaching the end of the iterator, recreate one 147 | # because shuffle=True in dataloader, the iter will return a different order 148 | if i % len(v) == 0: 149 | gen[k] = iter(v) 150 | rv[k] = next(gen[k]) 151 | 152 | yield rv 153 | 154 | else: 155 | gen = itertools.cycle(self.d) 156 | for i in range(self.l): 157 | batch = next(gen) 158 | yield batch 159 | 160 | 161 | class CustomSemiDataset(Dataset): 162 | def __init__(self, datasets): 163 | self.datasets = datasets 164 | 165 | self.map_indices = [[] for _ in self.datasets] 166 | self.min_length = min(len(d) for d in self.datasets) 167 | self.max_length = max(len(d) for d in self.datasets) 168 | 169 | def __getitem__(self, i): 170 | # return tuple(d[i] for d in self.datasets) 171 | 172 | # self.map_indices will reload when calling self.__len__() 173 | return tuple(d[m[i]] for d, m in zip(self.datasets, self.map_indices)) 174 | 175 | def construct_map_index(self): 176 | """ 177 | Construct the mapping indices for every data. Because the __len__ is larger than the size of some datset, 178 | the map_index is use to map the parameter "index" in __getitem__ to a valid index of each dataset. 179 | Because of the dataset has different length, we should maintain different indices for them. 
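Illustration (hypothetical sizes): with a 250-sample labeled subset and 44,750-sample unlabeled subsets, __len__() returns 44,750, so the labeled dataset's map is refilled from shuffled repetitions of range(250) until it covers 44,750 positions, while every unlabeled dataset reuses the same map so their augmented views stay aligned.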
180 | """ 181 | 182 | def update_indices(original_indices, data_length, max_data_length): 183 | # update the sampling indices for this dataset 184 | 185 | # return: a list, which maps the range(max_data_length) to the val index in the dataset 186 | 187 | original_indices = original_indices[max_data_length:] # remove used indices 188 | fill_num = max_data_length - len(original_indices) 189 | batch = math.ceil(fill_num / data_length) 190 | 191 | additional_indices = list(range(data_length)) * batch 192 | random.shuffle(additional_indices) 193 | 194 | original_indices += additional_indices 195 | 196 | assert ( 197 | len(original_indices) >= max_data_length 198 | ), "the length of matcing indices is too small" 199 | 200 | return original_indices 201 | 202 | # use same mapping index for all unlabeled dataset for data consistency 203 | # the i-th dataset is the labeled data 204 | self.map_indices = [ 205 | update_indices(m, len(d), self.max_length) 206 | for m, d in zip(self.map_indices, self.datasets) 207 | ] 208 | 209 | # use same mapping index for all unlabeled dataset for data consistency 210 | # the i-th dataset is the labeled data 211 | for i in range(1, len(self.map_indices)): 212 | self.map_indices[i] = self.map_indices[1] 213 | 214 | def __len__(self): 215 | # will be called every epoch 216 | return self.max_length 217 | 218 | 219 | class DataModuleBase(pl.LightningDataModule): 220 | labeled_indices: ... 221 | unlabeled_indices: ... 222 | val_indices: ... 223 | 224 | def __init__( 225 | self, data_root, num_workers, batch_size, num_labeled, num_val, n_classes 226 | ): 227 | super().__init__() 228 | self.data_root = data_root 229 | self.batch_size = batch_size 230 | self.num_labeled = num_labeled 231 | self.num_val = num_val 232 | self._n_classes = n_classes 233 | 234 | self.train_transform = None # TODO, need implement this in your custom datasets 235 | self.test_transform = None # TODO, need implement this in your custom datasets 236 | 237 | self.train_set = None 238 | self.val_set = None 239 | self.test_set = None 240 | 241 | self.num_workers = num_workers 242 | 243 | def train_dataloader(self): 244 | # get and process the data first 245 | 246 | return DataLoader( 247 | self.train_set, 248 | batch_size=self.batch_size, 249 | shuffle=True, 250 | num_workers=self.num_workers, 251 | pin_memory=True, 252 | drop_last=True, 253 | ) 254 | 255 | def val_dataloader(self): 256 | # return both val and test loader 257 | 258 | val_loader = DataLoader( 259 | self.val_set, 260 | batch_size=self.batch_size, 261 | shuffle=False, 262 | num_workers=self.num_workers, 263 | pin_memory=True, 264 | ) 265 | 266 | test_loader = DataLoader( 267 | self.test_set, 268 | batch_size=self.batch_size, 269 | num_workers=self.num_workers, 270 | pin_memory=True, 271 | ) 272 | 273 | return [val_loader, test_loader] 274 | 275 | def test_dataloader(self): 276 | 277 | return DataLoader( 278 | self.test_set, 279 | batch_size=self.batch_size, 280 | num_workers=self.num_workers, 281 | pin_memory=True, 282 | ) 283 | 284 | @property 285 | def n_classes(self): 286 | # self._n_class should be defined in _prepare_train_dataset() 287 | return self._n_classes 288 | 289 | @property 290 | def num_labeled_data(self): 291 | assert self.train_set is not None, ( 292 | "Load train data before calling %s" % self.num_labeled_data.__name__ 293 | ) 294 | return len(self.labeled_indices) 295 | 296 | @property 297 | def num_unlabeled_data(self): 298 | assert self.train_set is not None, ( 299 | "Load train data before calling %s" % 
self.num_unlabeled_data.__name__ 300 | ) 301 | return len(self.unlabeled_indices) 302 | 303 | @property 304 | def num_val_data(self): 305 | assert self.train_set is not None, ( 306 | "Load train data before calling %s" % self.num_val_data.__name__ 307 | ) 308 | return len(self.val_indices) 309 | 310 | @property 311 | def num_test_data(self): 312 | assert self.test_set is not None, ( 313 | "Load test data before calling %s" % self.num_test_data.__name__ 314 | ) 315 | return len(self.test_set) 316 | 317 | 318 | class SemiDataModule(DataModuleBase): 319 | """ 320 | Data module for semi-supervised tasks. self.prepare_data() is not implemented. For custom dataset, 321 | inherit this class and implement self.prepare_data(). 322 | """ 323 | 324 | def __init__( 325 | self, 326 | data_root, 327 | num_workers, 328 | batch_size, 329 | num_labeled, 330 | num_val, 331 | num_augments, 332 | n_classes, 333 | ): 334 | super(SemiDataModule, self).__init__( 335 | data_root, num_workers, batch_size, num_labeled, num_val, n_classes 336 | ) 337 | self.num_augments = num_augments 338 | 339 | def setup(self): 340 | # prepare train and val dataset, and split the train dataset 341 | # into labeled and unlabeled groups. 342 | assert ( 343 | self.train_set is not None 344 | ), "Should create self.train_set in self.setup()" 345 | 346 | indices = np.arange(len(self.train_set)) 347 | ys = np.array([self.train_set[i][1] for i in indices], dtype=np.int64) 348 | # np.random.shuffle(ys) 349 | # get the number of classes 350 | # self._n_classes = len(np.unique(ys)) 351 | 352 | ( 353 | self.labeled_indices, 354 | self.unlabeled_indices, 355 | self.val_indices, 356 | ) = get_split_indices(ys, self.num_labeled, self.num_val, self._n_classes) 357 | 358 | self.val_set = Subset(self.train_set, self.val_indices, self.test_transform) 359 | 360 | # unlabeled_list = [ 361 | # Subset(self.train_set, self.unlabeled_indices, self.train_transform) \ 362 | # for _ in range(self.num_augments) 363 | # ] 364 | 365 | # self.unlabeled_set = MultiDataset(unlabeled_list) 366 | # self.labeled_set = Subset(self.train_set, self.labeled_indices, self.train_transform) 367 | 368 | train_list = [ 369 | Subset(self.train_set, self.unlabeled_indices, self.train_transform) 370 | for _ in range(self.num_augments) 371 | ] 372 | 373 | train_list.insert( 374 | 0, Subset(self.train_set, self.labeled_indices, self.train_transform) 375 | ) 376 | 377 | self.train_set = CustomSemiDataset(train_list) 378 | 379 | # def train_dataloader(self): 380 | # # get and process the data first 381 | # if self.labeled_set is None: 382 | # self._prepare_train_dataset() 383 | 384 | # labeled_loader = DataLoader(self.labeled_set, 385 | # batch_size=self.batch_size, 386 | # shuffle=True, 387 | # num_workers=self.num_workers, 388 | # pin_memory=True, 389 | # drop_last=True) 390 | 391 | # unlabeled_loader = DataLoader(self.unlabeled_set, 392 | # batch_size=self.batch_size, 393 | # shuffle=True, 394 | # num_workers=self.num_workers, 395 | # pin_memory=True, 396 | # drop_last=True) 397 | 398 | # return MagicClass([labeled_loader, unlabeled_loader]) 399 | 400 | def train_dataloader(self): 401 | # get and process the data first 402 | 403 | self.train_set.construct_map_index() 404 | 405 | print("\ncalled\n") 406 | 407 | return DataLoader( 408 | self.train_set, 409 | batch_size=self.batch_size, 410 | shuffle=True, 411 | num_workers=self.num_workers, 412 | pin_memory=True, 413 | drop_last=True, 414 | ) 415 | 416 | 417 | class SupervisedDataModule(DataModuleBase): 418 | """ 419 | Data module 
for supervised tasks. self.prepare_data() is not implemented. For custom dataset, 420 | inherit this class and implement self.prepare_data(). 421 | """ 422 | 423 | def __init__( 424 | self, data_root, num_workers, batch_size, num_labeled, num_val, n_classes 425 | ): 426 | super(SupervisedDataModule, self).__init__( 427 | data_root, num_workers, batch_size, num_labeled, num_val, n_classes 428 | ) 429 | 430 | def setup(self): 431 | # prepare train and val dataset 432 | assert ( 433 | self.train_set is not None 434 | ), "Should create self.train_set in self.setup()" 435 | 436 | indices = np.arange(len(self.train_set)) 437 | ys = np.array([self.train_set[i][1] for i in indices], dtype=np.int64) 438 | # get the number of classes 439 | # self._n_classes = len(np.unique(ys)) 440 | 441 | ( 442 | self.labeled_indices, 443 | self.unlabeled_indices, 444 | self.val_indices, 445 | ) = get_split_indices(ys, self._n_classes, self.num_val, self._n_classes) 446 | 447 | self.labeled_indices = np.hstack((self.labeled_indices, self.unlabeled_indices)) 448 | self.unlabeled_indices = [] # dummy. only for printing length 449 | 450 | self.val_set = Subset(self.train_set, self.val_indices, self.test_transform) 451 | self.train_set = Subset( 452 | self.train_set, self.labeled_indices, self.train_transform 453 | ) 454 | -------------------------------------------------------------------------------- /lightning_ssl/dataloader/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lightning_ssl.dataloader.base_data import ( 3 | SemiDataModule, 4 | SupervisedDataModule, 5 | ) 6 | 7 | import torchvision as tv 8 | from torch.utils.data import DataLoader, SubsetRandomSampler 9 | from torchvision.datasets import CIFAR10 10 | 11 | 12 | CIFAR_MEAN = (0.4914, 0.4822, 0.4465) 13 | CIFAR_STD = (0.2471, 0.2435, 0.2616) 14 | 15 | 16 | class SemiCIFAR10Module(SemiDataModule): 17 | def __init__( 18 | self, 19 | args, 20 | data_root, 21 | num_workers, 22 | batch_size, 23 | num_labeled, 24 | num_val, 25 | num_augments, 26 | ): 27 | n_classes = 10 28 | super(SemiCIFAR10Module, self).__init__( 29 | data_root, 30 | num_workers, 31 | batch_size, 32 | num_labeled, 33 | num_val, 34 | num_augments, 35 | n_classes, 36 | ) 37 | 38 | self.train_transform = tv.transforms.Compose( 39 | [ 40 | tv.transforms.RandomCrop(32, padding=4, padding_mode="reflect"), 41 | tv.transforms.RandomHorizontalFlip(), 42 | tv.transforms.ToTensor(), 43 | tv.transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 44 | ] 45 | ) 46 | 47 | self.test_transform = tv.transforms.Compose( 48 | [ 49 | tv.transforms.ToTensor(), 50 | tv.transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 51 | ] 52 | ) 53 | 54 | def prepare_data(self): 55 | # the transformation for train and validation dataset will be 56 | # done in _prepare_train_dataset() 57 | 58 | self.train_set = CIFAR10( 59 | self.data_root, train=True, download=True, transform=None 60 | ) 61 | 62 | self.test_set = CIFAR10( 63 | self.data_root, train=False, download=True, transform=self.test_transform 64 | ) 65 | 66 | 67 | class SupervisedCIFAR10Module(SupervisedDataModule): 68 | def __init__( 69 | self, 70 | args, 71 | data_root, 72 | num_workers, 73 | batch_size, 74 | num_labeled, 75 | num_val, 76 | num_augments, 77 | ): 78 | n_classes = 10 79 | super(SupervisedCIFAR10Module, self).__init__( 80 | data_root, num_workers, batch_size, num_labeled, num_val, n_classes 81 | ) 82 | 83 | self.train_transform = tv.transforms.Compose( 84 | [ 85 | tv.transforms.RandomCrop(32, 
padding=4, padding_mode="reflect"), 86 | tv.transforms.RandomHorizontalFlip(), 87 | tv.transforms.ToTensor(), 88 | tv.transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 89 | ] 90 | ) 91 | 92 | self.test_transform = tv.transforms.Compose( 93 | [ 94 | tv.transforms.ToTensor(), 95 | tv.transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 96 | ] 97 | ) 98 | 99 | def prepare_data(self): 100 | # the transformation for train and validation dataset will be 101 | # done in _prepare_train_dataset() 102 | self.train_set = CIFAR10( 103 | self.data_root, train=True, download=True, transform=None 104 | ) 105 | 106 | self.test_set = CIFAR10( 107 | self.data_root, train=False, download=True, transform=self.test_transform 108 | ) 109 | -------------------------------------------------------------------------------- /lightning_ssl/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cnn13 import CNN13 2 | from .wideresnet import WideResNet 3 | -------------------------------------------------------------------------------- /lightning_ssl/models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from lightning_ssl.utils.torch_utils import sharpening 6 | 7 | 8 | class CustomModel(nn.Module): 9 | def __init__(self): 10 | super(CustomModel, self).__init__() 11 | self.batchnorms, self.momentums = None, None 12 | 13 | @torch.no_grad() 14 | def psuedo_label(self, unlabeled_xs, temperature, batch_inference=False): 15 | """ 16 | Generate the pseudo labels for given unlabeled_xs. 17 | Args 18 | unlabeled_xs: list of unlabeled data (torch.FloatTensor), 19 | e.g. [unlabeled_data_1, unlabeled_data_2, ..., unlabeled_data_n] 20 | Note that i th element in those unlabeled data should has same 21 | semantic meaning (come from different data augmentations). 
22 | temperature: the temperature parameter in softmax 23 | """ 24 | if batch_inference: 25 | batch_size, num_augmentations = unlabeled_xs[0].shape[0], len(unlabeled_xs) 26 | unlabeled_xs = torch.cat(unlabeled_xs, dim=0) 27 | p_labels = F.softmax(self(unlabeled_xs), dim=-1) 28 | p_labels = p_labels.view(num_augmentations, batch_size, -1) 29 | p_label = p_labels.mean(0).detach() 30 | else: 31 | p_labels = [F.softmax(self(x), dim=-1) for x in unlabeled_xs] 32 | p_label = torch.stack(p_labels).mean(0).detach() 33 | 34 | return sharpening(p_label, temperature) 35 | 36 | def extract_norm_n_momentum(self): 37 | # Extract the batchnorms and their momentum 38 | original_momentums = [] 39 | batchnorms = [] 40 | for module in self.modules(): 41 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 42 | original_momentums.append(module.momentum) 43 | batchnorms.append(module) 44 | 45 | return batchnorms, original_momentums 46 | 47 | def extract_running_stats(self): 48 | # Extract the running stats of batchnorms 49 | running_stats = [] 50 | for module in self.modules(): 51 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 52 | running_stats.append([module.running_mean, module.running_var]) 53 | 54 | return running_stats 55 | 56 | def freeze_running_stats(self): 57 | # Set the batchnorms' momentum to 0 to freeze the running stats 58 | # First call 59 | if None in [self.batchnorms, self.momentums]: 60 | self.batchnorms, self.momentums = self.extract_norm_n_momentum() 61 | 62 | for module in self.batchnorms: 63 | module.momentum = 0 64 | 65 | def recover_running_stats(self): 66 | # Recover the batchnorms' momentum to make running stats updatable 67 | if None in [self.batchnorms, self.momentums]: 68 | return 69 | for module, momentum in zip(self.batchnorms, self.momentums): 70 | module.momentum = momentum 71 | -------------------------------------------------------------------------------- /lightning_ssl/models/cnn13.py: -------------------------------------------------------------------------------- 1 | # The code is borrowed from 2 | # https://github.com/benathi/fastswa-semi-sup/blob/master/mean_teacher/architectures.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.utils import weight_norm 7 | 8 | from lightning_ssl.models.base_model import CustomModel 9 | 10 | 11 | class GaussianNoise(nn.Module): 12 | def __init__(self, std): 13 | super(GaussianNoise, self).__init__() 14 | self.std = std 15 | 16 | def forward(self, x): 17 | zeros_ = torch.zeros(x.size()).to(x.device) 18 | n = torch.normal(zeros_, std=self.std).to(x.device) 19 | return x + n 20 | 21 | 22 | class CNN13(CustomModel): 23 | """ 24 | CNN from Mean Teacher paper 25 | """ 26 | 27 | def __init__(self, num_classes=10): 28 | super(CNN13, self).__init__() 29 | 30 | self.gn = GaussianNoise(0.15) 31 | self.activation = nn.LeakyReLU(0.1) 32 | self.conv1a = weight_norm(nn.Conv2d(3, 128, 3, padding=1)) 33 | self.bn1a = nn.BatchNorm2d(128) 34 | self.conv1b = weight_norm(nn.Conv2d(128, 128, 3, padding=1)) 35 | self.bn1b = nn.BatchNorm2d(128) 36 | self.conv1c = weight_norm(nn.Conv2d(128, 128, 3, padding=1)) 37 | self.bn1c = nn.BatchNorm2d(128) 38 | self.mp1 = nn.MaxPool2d(2, stride=2, padding=0) 39 | self.drop1 = nn.Dropout(0.5) 40 | 41 | self.conv2a = weight_norm(nn.Conv2d(128, 256, 3, padding=1)) 42 | self.bn2a = nn.BatchNorm2d(256) 43 | self.conv2b = weight_norm(nn.Conv2d(256, 256, 3, padding=1)) 44 | self.bn2b = nn.BatchNorm2d(256) 45 | self.conv2c = weight_norm(nn.Conv2d(256, 256, 3, 
padding=1)) 46 | self.bn2c = nn.BatchNorm2d(256) 47 | self.mp2 = nn.MaxPool2d(2, stride=2, padding=0) 48 | self.drop2 = nn.Dropout(0.5) 49 | 50 | self.conv3a = weight_norm(nn.Conv2d(256, 512, 3, padding=0)) 51 | self.bn3a = nn.BatchNorm2d(512) 52 | self.conv3b = weight_norm(nn.Conv2d(512, 256, 1, padding=0)) 53 | self.bn3b = nn.BatchNorm2d(256) 54 | self.conv3c = weight_norm(nn.Conv2d(256, 128, 1, padding=0)) 55 | self.bn3c = nn.BatchNorm2d(128) 56 | self.ap3 = nn.AvgPool2d(6, stride=2, padding=0) 57 | 58 | self.fc1 = weight_norm(nn.Linear(128, num_classes)) 59 | 60 | def forward(self, x, debug=False): 61 | x = self.activation(self.bn1a(self.conv1a(x))) 62 | x = self.activation(self.bn1b(self.conv1b(x))) 63 | x = self.activation(self.bn1c(self.conv1c(x))) 64 | x = self.mp1(x) 65 | x = self.drop1(x) 66 | 67 | x = self.activation(self.bn2a(self.conv2a(x))) 68 | x = self.activation(self.bn2b(self.conv2b(x))) 69 | x = self.activation(self.bn2c(self.conv2c(x))) 70 | x = self.mp2(x) 71 | x = self.drop2(x) 72 | 73 | x = self.activation(self.bn3a(self.conv3a(x))) 74 | x = self.activation(self.bn3b(self.conv3b(x))) 75 | x = self.activation(self.bn3c(self.conv3c(x))) 76 | x = self.ap3(x) 77 | 78 | x = x.view(-1, 128) 79 | 80 | return self.fc1(x) 81 | -------------------------------------------------------------------------------- /lightning_ssl/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | # The code is borrow from 2 | # https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from lightning_ssl.models.base_model import CustomModel 10 | 11 | BIAS = False 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, pre_activate=False): 16 | super(BasicBlock, self).__init__() 17 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001) 18 | self.relu1 = nn.LeakyReLU(0.1, inplace=True) 19 | self.conv1 = nn.Conv2d( 20 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=BIAS 21 | ) 22 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001) 23 | self.relu2 = nn.LeakyReLU(0.1, inplace=True) 24 | self.conv2 = nn.Conv2d( 25 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=BIAS 26 | ) 27 | self.droprate = dropRate 28 | self.equalInOut = in_planes == out_planes 29 | self.convShortcut = ( 30 | (not self.equalInOut) 31 | and nn.Conv2d( 32 | in_planes, 33 | out_planes, 34 | kernel_size=1, 35 | stride=stride, 36 | padding=0, 37 | bias=BIAS, 38 | ) 39 | or None 40 | ) 41 | self.pre_activate = pre_activate 42 | 43 | def forward(self, x): 44 | # if not self.equalInOut: 45 | out = self.relu1(self.bn1(x)) 46 | if self.pre_activate: 47 | x = out 48 | # else: 49 | # out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 50 | out = self.relu2(self.bn2(self.conv1(out))) 51 | if self.droprate > 0: 52 | out = F.dropout(out, p=self.droprate, training=self.training) 53 | out = self.conv2(out) 54 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 55 | 56 | 57 | class NetworkBlock(nn.Module): 58 | def __init__( 59 | self, 60 | nb_layers, 61 | in_planes, 62 | out_planes, 63 | block, 64 | stride, 65 | dropRate=0.0, 66 | pre_activate=False, 67 | ): 68 | super(NetworkBlock, self).__init__() 69 | self.layer = self._make_layer( 70 | block, in_planes, out_planes, nb_layers, stride, dropRate, pre_activate 71 | ) 72 | 73 | def 
_make_layer( 74 | self, block, in_planes, out_planes, nb_layers, stride, dropRate, pre_activate 75 | ): 76 | layers = [] 77 | for i in range(int(nb_layers)): 78 | layers.append( 79 | block( 80 | i == 0 and in_planes or out_planes, 81 | out_planes, 82 | i == 0 and stride or 1, 83 | dropRate, 84 | pre_activate and i == 0, 85 | ) 86 | ) 87 | return nn.Sequential(*layers) 88 | 89 | def forward(self, x): 90 | return self.layer(x) 91 | 92 | 93 | class WideResNet(CustomModel): 94 | def __init__(self, depth, num_classes, widen_factor=2, dropRate=0.0): 95 | super(WideResNet, self).__init__() 96 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 97 | assert (depth - 4) % 6 == 0 98 | n = (depth - 4) / 6 99 | block = BasicBlock 100 | # 1st conv before any network block 101 | self.conv1 = nn.Conv2d( 102 | 3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=BIAS 103 | ) 104 | # 1st block 105 | self.block1 = NetworkBlock( 106 | n, nChannels[0], nChannels[1], block, 1, dropRate, True 107 | ) 108 | # 2nd block 109 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 110 | # 3rd block 111 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 112 | # global average pooling and classifier 113 | self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001) 114 | self.relu = nn.LeakyReLU(0.1, inplace=True) 115 | self.fc = nn.Linear(nChannels[3], num_classes) 116 | self.nChannels = nChannels[3] 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = 0.5 * m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(1.0 / n)) 122 | if BIAS: 123 | m.bias.data.zero_() 124 | elif isinstance(m, nn.BatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() 127 | elif isinstance(m, nn.Linear): 128 | nn.init.xavier_normal_(m.weight.data) 129 | m.bias.data.zero_() 130 | 131 | def forward(self, x): 132 | out = self.conv1(x) 133 | out = self.block1(out) 134 | out = self.block2(out) 135 | out = self.block3(out) 136 | out = self.relu(self.bn1(out)) 137 | out = F.avg_pool2d(out, 8) 138 | out = out.view(-1, self.nChannels) 139 | return self.fc(out) 140 | 141 | 142 | # class BasicBlock(nn.Module): 143 | # def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False): 144 | # super(BasicBlock, self).__init__() 145 | # self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001) 146 | # self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 147 | # self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 148 | # padding=1, bias=False) 149 | # self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001) 150 | # self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 151 | # self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 152 | # padding=1, bias=False) 153 | # self.droprate = dropRate 154 | # self.equalInOut = (in_planes == out_planes) 155 | # self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 156 | # padding=0, bias=False) or None 157 | # self.activate_before_residual = activate_before_residual 158 | # def forward(self, x): 159 | # if not self.equalInOut and self.activate_before_residual == True: 160 | # x = self.relu1(self.bn1(x)) 161 | # else: 162 | # out = self.relu1(self.bn1(x)) 163 | # out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 164 | # if self.droprate > 0: 165 | # out = F.dropout(out, p=self.droprate, training=self.training) 
166 | # out = self.conv2(out) 167 | # return torch.add(x if self.equalInOut else self.convShortcut(x), out) 168 | 169 | # class NetworkBlock(nn.Module): 170 | # def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activate_before_residual=False): 171 | # super(NetworkBlock, self).__init__() 172 | # self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual) 173 | # def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual): 174 | # layers = [] 175 | # for i in range(int(nb_layers)): 176 | # layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, activate_before_residual)) 177 | # return nn.Sequential(*layers) 178 | # def forward(self, x): 179 | # return self.layer(x) 180 | 181 | # class WideResNet(nn.Module): 182 | # def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0): 183 | # super(WideResNet, self).__init__() 184 | # nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 185 | # assert((depth - 4) % 6 == 0) 186 | # n = (depth - 4) / 6 187 | # block = BasicBlock 188 | # # 1st conv before any network block 189 | # self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 190 | # padding=1, bias=False) 191 | # # 1st block 192 | # self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True) 193 | # # 2nd block 194 | # self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 195 | # # 3rd block 196 | # self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 197 | # # global average pooling and classifier 198 | # self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001) 199 | # self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 200 | # self.fc = nn.Linear(nChannels[3], num_classes) 201 | # self.nChannels = nChannels[3] 202 | 203 | # for m in self.modules(): 204 | # if isinstance(m, nn.Conv2d): 205 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 206 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 207 | # elif isinstance(m, nn.BatchNorm2d): 208 | # m.weight.data.fill_(1) 209 | # m.bias.data.zero_() 210 | # elif isinstance(m, nn.Linear): 211 | # nn.init.xavier_normal_(m.weight.data) 212 | # m.bias.data.zero_() 213 | 214 | # def forward(self, x): 215 | # out = self.conv1(x) 216 | # out = self.block1(out) 217 | # out = self.block2(out) 218 | # out = self.block3(out) 219 | # out = self.relu(self.bn1(out)) 220 | # out = F.avg_pool2d(out, 8) 221 | # out = out.view(-1, self.nChannels) 222 | # return self.fc(out) 223 | -------------------------------------------------------------------------------- /lightning_ssl/module/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier_module import ClassifierModule 2 | from .mixmatch_module import MixmatchModule 3 | -------------------------------------------------------------------------------- /lightning_ssl/module/classifier_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import collections 4 | import numpy as np 5 | from copy import deepcopy 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from torchvision.datasets import MNIST 9 | from torchvision import transforms 10 | import pytorch_lightning as pl 11 | from lightning_ssl.utils.torch_utils import ( 12 | EMA, 13 | smooth_label, 14 | soft_cross_entropy, 15 | mixup_data, 16 | customized_weight_decay, 17 | WeightDecayModule, 18 | split_weight_decay_weights, 19 | ) 20 | 21 | # Use the style similar to pytorch_lightning (pl) 22 | # Codes will revised to be compatible with pl when pl has all the necessary features. 23 | 24 | # Codes borrowed from 25 | # https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=x-34xKCI40yW 26 | 27 | 28 | class ClassifierModule(pl.LightningModule): 29 | def __init__(self, hparams, classifier, loaders): 30 | super(ClassifierModule, self).__init__() 31 | self.hparams = hparams 32 | self.classifier = classifier 33 | self.ema_classifier = deepcopy(classifier) 34 | self.loaders = loaders 35 | self.best_dict = { 36 | "val_acc": 0, 37 | } 38 | self.train_dict = { 39 | key: collections.deque([], size) 40 | for key, size in zip( 41 | ["loss", "acc", "val_acc", "test_acc"], [500, 500, 1, 20] 42 | ) 43 | } 44 | 45 | # to record the best validation accuracy 46 | self.train_dict["val_acc"].append(0) # avoid empty list 47 | 48 | def on_train_start(self): 49 | # model will put in GPU before this function 50 | # so we initiate EMA and WeightDecayModule here 51 | self.ema = EMA(self.classifier, self.ema_classifier, self.hparams.ema) 52 | # self.wdm = WeightDecayModule(self.classifier, self.hparams.weight_decay, ["bn", "bias"]) 53 | 54 | def on_train_batch_end(self, *args, **kwargs): 55 | # self.ema.update(self.classifier) 56 | # wd = self.hparams.weight_decay * self.hparams.learning_rate 57 | # customized_weight_decay(self.classifier, self.hparams.weight_decay, ["bn", "bias"]) 58 | # self.wdm.decay() 59 | self.ema.step() 60 | 61 | def accuracy(self, y_hat, y): 62 | return 100 * (torch.argmax(y_hat, dim=-1) == y).float().mean() 63 | 64 | def training_step(self, batch, batch_nb): 65 | # REQUIRED 66 | x, y = batch 67 | # return smooth one-hot like label 68 | y_one_hot = smooth_label( 69 | y, self.hparams.n_classes, self.hparams.label_smoothing 70 | ) 71 | # mixup 72 | mixed_x, mixed_y = mixup_data(x, y_one_hot, self.hparams.alpha) 73 | y_hat = self.classifier(mixed_x) 74 | loss = 
soft_cross_entropy(y_hat, mixed_y) 75 | y = torch.argmax(mixed_y, dim=-1) 76 | acc = self.accuracy(y_hat, y) 77 | num = len(y) 78 | 79 | self.train_dict["loss"].append(loss.item()) 80 | self.train_dict["acc"].append(acc.item()) 81 | 82 | # tensorboard_logs = {"train/loss": np.mean(self.train_dict["loss"]), 83 | # "train/acc": np.mean(self.train_dict["acc"])} 84 | 85 | # progress_bar = {"acc": np.mean(self.train_dict["acc"])} 86 | 87 | # return {"loss": loss, "train_acc": acc, 88 | # "train_num": num, "log": tensorboard_logs, 89 | # "progress_bar": progress_bar} 90 | 91 | self.log( 92 | "train/loss", np.mean(self.train_dict["loss"]), prog_bar=False, logger=True 93 | ) 94 | self.log( 95 | "train/acc", np.mean(self.train_dict["acc"]), prog_bar=False, logger=True 96 | ) 97 | 98 | return {"loss": loss, "train_acc": acc, "train_num": num} 99 | 100 | def validation_step(self, batch, *args): 101 | # OPTIONAL 102 | x, y = batch 103 | y_hat = self.ema_classifier(x) 104 | 105 | acc = self.accuracy(y_hat, y) 106 | num = len(y) 107 | 108 | return {"val_loss": F.cross_entropy(y_hat, y), "val_acc": acc, "val_num": num} 109 | 110 | def validation_epoch_end(self, outputs): 111 | # Monitor both validation set and test set 112 | # record the test accuracies of last 20 checkpoints 113 | 114 | avg_loss_list, avg_acc_list = [], [] 115 | 116 | for output in outputs: 117 | avg_loss = torch.stack( 118 | [x["val_loss"] * x["val_num"] for x in output] 119 | ).sum() / np.sum([x["val_num"] for x in output]) 120 | avg_acc = torch.stack( 121 | [x["val_acc"] * x["val_num"] for x in output] 122 | ).sum() / np.sum([x["val_num"] for x in output]) 123 | 124 | avg_loss_list.append(avg_loss) 125 | avg_acc_list.append(avg_acc) 126 | 127 | # record best results of validation set 128 | self.train_dict["val_acc"][0] = max( 129 | self.train_dict["val_acc"][0], avg_acc_list[0].item() 130 | ) 131 | self.train_dict["test_acc"].append(avg_acc_list[1].item()) 132 | 133 | # tensorboard_logs = {"val/loss": avg_loss_list[0], 134 | # "val/acc": avg_acc_list[0], 135 | # "val/best_acc": self.train_dict["val_acc"][0], 136 | # "test/median_acc": np.median(self.train_dict["test_acc"])} 137 | 138 | # return {"val_loss": avg_loss_list[0], "val_acc": avg_acc_list[0], "log": tensorboard_logs} 139 | self.logger.experiment.add_scalar( 140 | "t", np.median(self.train_dict["test_acc"]), self.global_step 141 | ) 142 | self.log("val/loss", avg_loss_list[0], prog_bar=False, logger=True) 143 | self.log("val/acc", avg_acc_list[0], prog_bar=False, logger=True) 144 | self.log( 145 | "val/best_acc", self.train_dict["val_acc"][0], prog_bar=False, logger=True 146 | ) 147 | self.log( 148 | "test/median_acc", 149 | np.median(self.train_dict["test_acc"]), 150 | prog_bar=False, 151 | logger=True, 152 | ) 153 | 154 | def test_step(self, batch, batch_nb): 155 | # OPTIONAL 156 | x, y = batch 157 | y_hat = self.ema_classifier(x) 158 | acc = self.accuracy(y_hat, y) 159 | num = len(y) 160 | 161 | return { 162 | "test_loss": F.cross_entropy(y_hat, y), 163 | "test_acc": acc, 164 | "test_num": num, 165 | } 166 | 167 | def test_epoch_end(self, outputs): 168 | # OPTIONAL 169 | avg_loss = torch.stack( 170 | [x["test_loss"] * x["test_num"] for x in outputs] 171 | ).sum() / np.sum([x["test_num"] for x in outputs]) 172 | avg_acc = torch.stack( 173 | [x["test_acc"] * x["test_num"] for x in outputs] 174 | ).sum() / np.sum([x["test_num"] for x in outputs]) 175 | # logs = {"test/loss": avg_loss, "test/acc": avg_acc} 176 | # return {"test_loss": avg_loss, "test_acc": avg_acc, "log": 
logs, "progress_bar": logs} 177 | 178 | self.log("test/loss", avg_loss, prog_bar=False, logger=True) 179 | self.log("test/acc", avg_acc, prog_bar=False, logger=True) 180 | 181 | def configure_optimizers(self): 182 | # REQUIRED 183 | # can return multiple optimizers and learning_rate schedulers 184 | # (LBFGS it is automatically supported, no need for closure function) 185 | 186 | # split the weights into need weight decay and no need weight decay 187 | parameters = split_weight_decay_weights( 188 | self.classifier, self.hparams.weight_decay, ["bn", "bias"] 189 | ) 190 | 191 | opt = torch.optim.AdamW( 192 | parameters, lr=self.hparams.learning_rate, weight_decay=0 193 | ) 194 | # opt = torch.optim.SGD(self.classifier.parameters(), self.hparams.learning_rate, 195 | # momentum=0.9, nesterov=True) 196 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, float(self.hparams.max_epoch)) 197 | return [opt] # , [scheduler] 198 | 199 | # def train_dataloader(self): 200 | # # REQUIRED 201 | # return self.loaders["tr_loader"] 202 | 203 | # def val_dataloader(self): 204 | # # OPTIONAL 205 | # return [self.loaders["va_loader"], self.loaders["te_loader"]] 206 | 207 | # def test_dataloader(self): 208 | # # OPTIONAL 209 | # return self.loaders["te_loader"] 210 | -------------------------------------------------------------------------------- /lightning_ssl/module/mixmatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from lightning_ssl.utils.torch_utils import ( 6 | half_mixup_data, 7 | soft_cross_entropy, 8 | l2_distribution_loss, 9 | smooth_label, 10 | customized_weight_decay, 11 | interleave, 12 | ) 13 | 14 | 15 | class Mixmatch: 16 | classifier: nn.Module 17 | hparams: ... 
18 | lambda_u: float 19 | 20 | def _loss(self, labeled_x, labeled_y, unlabeled_xs, batch_inference=False): 21 | """ 22 | labeled_x: [B, :] 23 | labeled_y: [B, n_classes] 24 | unlabeled_xs: K unlabeled_x 25 | """ 26 | batch_size = labeled_x.size(0) 27 | num_augmentation = len(unlabeled_xs) # num_augmentation 28 | 29 | # not to update the running mean and variance in BN 30 | self.classifier.freeze_running_stats() 31 | 32 | p_unlabeled_y = self.classifier.psuedo_label( 33 | unlabeled_xs, self.hparams.T, batch_inference 34 | ) 35 | 36 | # # size [(K + 1) * B, :] 37 | # all_inputs = torch.cat([labeled_x] + unlabeled_xs, dim=0) 38 | # all_targets = torch.cat( 39 | # [labeled_y, p_unlabeled_y.repeat(K, 1)], 40 | # dim=0 41 | # ) 42 | 43 | # mixed_input, mixed_target = half_mixup_data(all_inputs, all_targets, self.hparams.alpha) 44 | 45 | # logits = self.forward(mixed_input, self.classifier) 46 | 47 | # l_l = soft_cross_entropy(logits[:batch_size], mixed_target[:batch_size]) 48 | # l_u = l2_distribution_loss(logits[batch_size:], mixed_target[batch_size:]) 49 | 50 | all_inputs = torch.cat([labeled_x] + unlabeled_xs, dim=0) 51 | all_targets = torch.cat( 52 | [labeled_y, p_unlabeled_y.repeat(num_augmentation, 1)], dim=0 53 | ) 54 | 55 | mixed_input, mixed_target = half_mixup_data( 56 | all_inputs, all_targets, self.hparams.alpha 57 | ) 58 | 59 | if batch_inference: 60 | self.classifier.recover_running_stats() 61 | logits = self.classifier(mixed_input) 62 | else: 63 | # interleave labeled and unlabed samples between batches to get correct batchnorm calculation 64 | mixed_input = list(torch.split(mixed_input, batch_size)) 65 | mixed_input = interleave(mixed_input, batch_size) 66 | 67 | logits_other = [self.classifier(x) for x in mixed_input[1:]] 68 | 69 | self.classifier.recover_running_stats() 70 | 71 | # only update the BN stats for the first batch 72 | logits_first = self.classifier(mixed_input[0]) 73 | 74 | logits = [logits_first] + logits_other 75 | 76 | # put interleaved samples back 77 | logits = interleave(logits, batch_size) 78 | 79 | logits = torch.cat(logits, dim=0) 80 | 81 | l_l = soft_cross_entropy(logits[:batch_size], mixed_target[:batch_size]) 82 | l_u = l2_distribution_loss(logits[batch_size:], mixed_target[batch_size:]) 83 | 84 | return l_l + self.lambda_u * l_u, l_l, l_u 85 | -------------------------------------------------------------------------------- /lightning_ssl/module/mixmatch_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import collections 4 | import numpy as np 5 | import torch.nn as nn 6 | from time import time 7 | from copy import deepcopy 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from torchvision.datasets import MNIST 11 | from torchvision import transforms 12 | import pytorch_lightning as pl 13 | from lightning_ssl.utils.torch_utils import ( 14 | half_mixup_data, 15 | soft_cross_entropy, 16 | l2_distribution_loss, 17 | smooth_label, 18 | customized_weight_decay, 19 | interleave, 20 | sharpening, 21 | ) 22 | 23 | from lightning_ssl.module.classifier_module import ClassifierModule 24 | from lightning_ssl.module.mixmatch import Mixmatch 25 | 26 | 27 | class MixmatchModule(ClassifierModule, Mixmatch): 28 | def __init__(self, hparams, classifier, loaders): 29 | super(MixmatchModule, self).__init__(hparams, classifier, loaders) 30 | self.lambda_u = 0 31 | self.rampup_length = 16384 # get from the mixmatch codes 32 | 33 | for key, size in 
zip(["lab_loss", "unl_loss"], [500, 500]): 34 | self.train_dict[key] = collections.deque([], size) 35 | 36 | def on_train_batch_end(self, *args, **kwargs): 37 | # self.wdm.decay() 38 | self.ema.step() 39 | 40 | # linearly ramp up lambda_u 41 | if self.lambda_u == self.hparams.lambda_u: 42 | return 43 | 44 | step = (self.hparams.lambda_u - 0) / self.rampup_length 45 | self.lambda_u = min(self.lambda_u + step, self.hparams.lambda_u) 46 | 47 | def training_step(self, batch, batch_nb): 48 | # REQUIRED 49 | labeled_batch, unlabeled_batch = batch[0], batch[1:] 50 | # labeled_batch, unlabeled_batch = batch 51 | labeled_x, labeled_y = labeled_batch 52 | 53 | labeled_y = smooth_label( 54 | labeled_y, self.hparams.n_classes, self.hparams.label_smoothing 55 | ) 56 | 57 | unlabeled_xs = [b[0] for b in unlabeled_batch] # only get the images 58 | 59 | loss, l_l, l_u = self._loss( 60 | labeled_x, 61 | labeled_y, 62 | unlabeled_xs, 63 | batch_inference=self.hparams.batch_inference, 64 | ) 65 | 66 | loss = l_l + self.lambda_u * l_u 67 | 68 | # y = torch.argmax(mixed_target, dim=-1) 69 | # acc = self.accuracy(logits[:batch_size], y[:batch_size]) 70 | 71 | self.train_dict["loss"].append(loss.item()) 72 | # self.train_dict["acc"].append(acc.item()) 73 | self.train_dict["lab_loss"].append(l_l.item()) 74 | self.train_dict["unl_loss"].append(l_u.item()) 75 | 76 | # tensorboard_logs = {"train/loss": np.mean(self.train_dict["loss"]), 77 | # # "train/acc": np.mean(self.train_dict["acc"]), 78 | # "train/lab_loss": np.mean(self.train_dict["lab_loss"]), 79 | # "train/unl_loss": np.mean(self.train_dict["unl_loss"]), 80 | # "lambda_u": self.lambda_u} 81 | 82 | # progress_bar = {# "acc": np.mean(self.train_dict["acc"]), 83 | # "lab_loss": np.mean(self.train_dict["lab_loss"]), 84 | # "unl_loss": np.mean(self.train_dict["unl_loss"]), 85 | # "lambda_u": self.lambda_u} 86 | 87 | # return {"loss": loss, "log": tensorboard_logs, "progress_bar": progress_bar} 88 | 89 | self.log( 90 | "train/loss", np.mean(self.train_dict["loss"]), prog_bar=False, logger=True 91 | ) 92 | self.log( 93 | "train/lab_loss", 94 | np.mean(self.train_dict["lab_loss"]), 95 | prog_bar=True, 96 | logger=True, 97 | ) 98 | self.log( 99 | "train/unl_loss", 100 | np.mean(self.train_dict["unl_loss"]), 101 | prog_bar=True, 102 | logger=True, 103 | ) 104 | self.log("train/lambda_u", self.lambda_u, prog_bar=False, logger=True) 105 | 106 | return {"loss": loss} 107 | 108 | # def validation_epoch_end(self, outputs): 109 | # # Monitor both validation set and test set 110 | # # record the test accuracies of last 20 checkpoints 111 | 112 | # avg_loss_list, avg_acc_list = [], [] 113 | 114 | # for output in outputs: 115 | # avg_loss = torch.stack([x["val_loss"] * x["val_num"] for x in output]).sum() / \ 116 | # np.sum([x["val_num"] for x in output]) 117 | # avg_acc = torch.stack([x["val_acc"] * x["val_num"] for x in output]).sum() / \ 118 | # np.sum([x["val_num"] for x in output]) 119 | 120 | # avg_loss_list.append(avg_loss) 121 | # avg_acc_list.append(avg_acc) 122 | 123 | # # record best results of validation set 124 | # self.train_dict["val_acc"][0] = max(self.train_dict["val_acc"][0], avg_acc_list[0].item()) 125 | # self.train_dict["test_acc"].append(avg_acc_list[1].item()) 126 | 127 | # # tensorboard_logs = {"val/loss": avg_loss_list[0], 128 | # # "val/acc": avg_acc_list[0], 129 | # # "val/best_acc": self.train_dict["val_acc"][0], 130 | # # "test/median_acc": np.median(self.train_dict["test_acc"])} 131 | 132 | # # return {"val_loss": avg_loss_list[0], "val_acc": 
avg_acc_list[0], "log": tensorboard_logs} 133 | 134 | # self.log("val/loss", avg_loss_list[0], on_step=True, on_epoch=False, 135 | # prog_bar=False, logger=True) 136 | # self.log("val/acc", avg_acc_list[0], on_step=True, on_epoch=False, 137 | # prog_bar=False, logger=True) 138 | # self.log("val/best_acc", self.train_dict["val_acc"][0], on_step=True, on_epoch=False, 139 | # prog_bar=False, logger=True) 140 | # self.log("test/median_acc", np.median(self.train_dict["test_acc"]), on_step=True, on_epoch=False, 141 | # prog_bar=False, logger=True) 142 | -------------------------------------------------------------------------------- /lightning_ssl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import makedirs, count_parameters 2 | -------------------------------------------------------------------------------- /lightning_ssl/utils/argparser.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | from configparser import ConfigParser 4 | 5 | 6 | def str2bool(v): 7 | if isinstance(v, bool): 8 | return v 9 | if v.lower() in ("yes", "true", "t", "y", "1"): 10 | return True 11 | elif v.lower() in ("no", "false", "f", "n", "0"): 12 | return False 13 | else: 14 | raise argparse.ArgumentTypeError("Boolean value expected.") 15 | 16 | 17 | def parser(): 18 | conf_parser = argparse.ArgumentParser( 19 | description="parser for config", 20 | # Don't mess with format of description 21 | formatter_class=argparse.RawDescriptionHelpFormatter, 22 | # Turn off help, so we print all options in response to -h 23 | add_help=False, 24 | ) 25 | 26 | conf_parser.add_argument( 27 | "-c", "--conf_file", help="Specify config file", metavar="FILE" 28 | ) 29 | 30 | conf_parser.add_argument( 31 | "-d", "--dataset", default="cifar10", help="use what dataset" 32 | ) 33 | 34 | args, remaining_argv = conf_parser.parse_known_args() 35 | 36 | defaults = {} 37 | if args.conf_file: 38 | with open(args.conf_file, "r") as f: 39 | configs = yaml.load(f, yaml.SafeLoader) 40 | assert args.dataset in configs, f"Don't have the config for {args.dataset}" 41 | defaults.update(configs[args.dataset]) 42 | 43 | parser = argparse.ArgumentParser( 44 | description="Semi-Supervised Learning", parents=[conf_parser] 45 | ) 46 | 47 | parser.add_argument( 48 | "--learning_scenario", 49 | default="semi", 50 | choices=["semi", "supervised"], 51 | help="what learning scenario to use", 52 | ) 53 | 54 | parser.add_argument("--algo", default="none", help="which algorithm to use") 55 | 56 | parser.add_argument( 57 | "--todo", 58 | choices=["train", "test"], 59 | default="train", 60 | help="what behavior want to do: train | test", 61 | ) 62 | 63 | parser.add_argument( 64 | "--data_root", default=".", help="the directory to save the dataset" 65 | ) 66 | 67 | parser.add_argument( 68 | "--log_root", 69 | default=".", 70 | help="the directory to save the logs or other imformations (e.g. 
images)", 71 | ) 72 | 73 | parser.add_argument( 74 | "--model_root", default="checkpoint", help="the directory to save the models" 75 | ) 76 | 77 | parser.add_argument("--load_checkpoint", default="./model/default/model.pth") 78 | 79 | parser.add_argument( 80 | "--affix", default="default", help="the affix for the save folder" 81 | ) 82 | 83 | parser.add_argument("--seed", type=int, default=1, help="seed") 84 | 85 | parser.add_argument( 86 | "--num_workers", 87 | type=int, 88 | default=16, 89 | help="how many workers used in data loader", 90 | ) 91 | 92 | parser.add_argument("--batch_size", "-b", type=int, default=128, help="batch size") 93 | 94 | parser.add_argument( 95 | "--num_labeled", type=int, default=1000, help="number of labeled data" 96 | ) 97 | 98 | parser.add_argument( 99 | "--num_val", type=int, default=5000, help="the amount of validation data" 100 | ) 101 | 102 | parser.add_argument( 103 | "--max_epochs", 104 | "-m_e", 105 | type=int, 106 | default=200, 107 | help="the maximum numbers of the model see a sample", 108 | ) 109 | 110 | parser.add_argument( 111 | "--learning_rate", "-lr", type=float, default=1e-2, help="learning rate" 112 | ) 113 | 114 | parser.add_argument( 115 | "--weight_decay", 116 | "-w", 117 | type=float, 118 | default=2e-4, 119 | help="the parameter of l2 restriction for weights", 120 | ) 121 | 122 | parser.add_argument("--gpus", "-g", default="0", help="what gpus to use") 123 | 124 | parser.add_argument( 125 | "--n_eval_step", 126 | type=int, 127 | default=100, 128 | help="number of iteration per one evaluation", 129 | ) 130 | 131 | parser.add_argument( 132 | "--n_checkpoint_step", 133 | type=int, 134 | default=4000, 135 | help="number of iteration to save a checkpoint", 136 | ) 137 | 138 | parser.add_argument( 139 | "--n_store_image_step", 140 | type=int, 141 | default=4000, 142 | help="number of iteration to save adversaries", 143 | ) 144 | 145 | parser.add_argument( 146 | "--max_steps", type=int, default=1 << 16, help="maximum iteration for training" 147 | ) 148 | 149 | parser.add_argument( 150 | "--ema", 151 | type=float, 152 | default=0.999, 153 | help="the decay for exponential moving average", 154 | ) 155 | 156 | parser.add_argument( 157 | "--label_smoothing", 158 | type=float, 159 | default=0.0, 160 | help="the paramters for label smoothing", 161 | ) 162 | 163 | parser.add_argument( 164 | "--augment", action="store_true", help="whether to augment the training data" 165 | ) 166 | 167 | parser.add_argument( 168 | "--num_augments", 169 | type=int, 170 | default=2, 171 | help="how many augment samples for unlabeled data in Mixmatch", 172 | ) 173 | 174 | parser.add_argument( 175 | "--alpha", 176 | type=float, 177 | default=-1, 178 | help="the hyperparameter for beta distribution in mixup \ 179 | (0 < alpha. 
If alpha < 0 means no mix)", 180 | ) 181 | 182 | parser.add_argument( 183 | "--T", 184 | type=float, 185 | default=0.5, 186 | help="temperature for softmax distribution or sharpen parameter in mixmatch", 187 | ) 188 | 189 | parser.add_argument( 190 | "--lambda_u", 191 | type=float, 192 | default=100, 193 | help="the weight of the loss for the unlabeled data", 194 | ) 195 | 196 | parser.add_argument( 197 | "--batch_inference", 198 | type=str2bool, 199 | default=False, 200 | help="whether use batch inference in generating psuedo labels and computing loss", 201 | ) 202 | 203 | parser.add_argument( 204 | "--progress_bar_refresh_rate", 205 | type=int, 206 | default=1, 207 | help="The frequency to refresh the progress bar (0 diables the progress bar)", 208 | ) 209 | 210 | parser.set_defaults(**defaults) 211 | parser.set_defaults(**vars(args)) 212 | 213 | return parser.parse_args(remaining_argv) 214 | 215 | 216 | def print_args(args, logger=None): 217 | for k, v in vars(args).items(): 218 | if logger is not None: 219 | logger.info("{:<16} : {}".format(k, v)) 220 | else: 221 | print("{:<16} : {}".format(k, v)) 222 | -------------------------------------------------------------------------------- /lightning_ssl/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from copy import deepcopy 6 | 7 | 8 | def sharpening(label, T): 9 | label = label.pow(1 / T) 10 | return label / label.sum(-1, keepdim=True) 11 | 12 | 13 | def soft_cross_entropy(input_, target_, dim=-1): 14 | """ 15 | compute the cross entropy between input_ and target_ 16 | Args 17 | input_: logits of model's prediction. Size = [Batch, n_classes] 18 | target_: probability of target distribution. Size = [Batch, n_classes] 19 | Return 20 | the entropy between input_ and target_ 21 | """ 22 | 23 | input_ = input_.log_softmax(dim=dim) 24 | 25 | return torch.mean(torch.sum(-target_ * input_, dim=dim)) 26 | 27 | 28 | def l2_distribution_loss(input_, target_, dim=-1): 29 | input_ = input_.softmax(dim=dim) 30 | 31 | return torch.mean((input_ - target_) ** 2) 32 | 33 | 34 | def mixup_data(x, y, alpha): 35 | """ 36 | Args: 37 | x: data, whose size is [Batch, ...] 38 | y: label, whose size is [Batch, ...] 39 | alpha: the paramters for beta distribution. If alpha <= 0 means no mix 40 | Return 41 | mixed inputs, mixed targets 42 | """ 43 | # code is modified from 44 | # https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py 45 | 46 | batch_size = x.size()[0] 47 | 48 | if alpha > 0: 49 | lam = torch.distributions.beta.Beta(alpha, alpha).rsample((batch_size,)) 50 | else: 51 | lam = torch.ones((batch_size,)) 52 | 53 | lam = lam.to(x.device) 54 | 55 | index = torch.randperm(batch_size).to(x.device) 56 | 57 | x_size, y_size = [1 for _ in range(len(x.shape))], [1 for _ in range(len(y.shape))] 58 | x_size[0], y_size[0] = batch_size, batch_size 59 | 60 | mixed_x = lam.view(x_size) * x + (1 - lam.view(x_size)) * x[index] 61 | mixed_y = lam.view(y_size) * y + (1 - lam.view(y_size)) * y[index] 62 | 63 | return mixed_x, mixed_y 64 | 65 | 66 | def mixup_data_for_testing(x, y, alpha): 67 | """ 68 | Args: 69 | x: data, whose size is [Batch, ...] 70 | y: label, whose size is [Batch, ...] 71 | alpha: the paramters for beta distribution. 
If alpha <= 0 means no mix 72 | Return 73 | mixed inputs, mixed targets 74 | """ 75 | # code is modified from 76 | # https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py 77 | 78 | batch_size = x.size()[0] 79 | 80 | if alpha > 0: 81 | lam = torch.distributions.beta.Beta(alpha, alpha).rsample((batch_size,)) 82 | else: 83 | lam = torch.ones((batch_size,)) 84 | 85 | lam = lam.to(x.device) 86 | 87 | index = torch.randperm(batch_size).to(x.device) 88 | 89 | x_size, y_size = [1 for _ in range(len(x.shape))], [1 for _ in range(len(y.shape))] 90 | x_size[0], y_size[0] = batch_size, batch_size 91 | 92 | mixed_x = lam.view(x_size) * x + (1 - lam.view(x_size)) * x[index] 93 | mixed_y = lam.view(y_size) * y + (1 - lam.view(y_size)) * y[index] 94 | 95 | return mixed_x, mixed_y, x[index], y[index], lam 96 | 97 | 98 | def half_mixup_data(x, y, alpha): 99 | """ 100 | This function is similar to normal mixup except that the mixed_x 101 | and mixed_y are close to x and y. 102 | """ 103 | # code is modified from 104 | # https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py 105 | 106 | batch_size = x.size()[0] 107 | 108 | if alpha > 0: 109 | lam = torch.distributions.beta.Beta(alpha, alpha).rsample((batch_size,)) 110 | else: 111 | lam = torch.ones((batch_size,)) 112 | 113 | lam = torch.max(lam, 1 - lam) 114 | 115 | lam = lam.to(x.device) 116 | 117 | index = torch.randperm(batch_size).to(x.device) 118 | 119 | x_size, y_size = [1 for _ in range(len(x.shape))], [1 for _ in range(len(y.shape))] 120 | x_size[0], y_size[0] = batch_size, batch_size 121 | 122 | mixed_x = lam.view(x_size) * x + (1 - lam.view(x_size)) * x[index] 123 | mixed_y = lam.view(y_size) * y + (1 - lam.view(y_size)) * y[index] 124 | 125 | return mixed_x, mixed_y 126 | 127 | 128 | def half_mixup_data_for_testing(x, y, alpha): 129 | """ 130 | This function is similar to normal mixup except that the mixed_x 131 | and mixed_y are close to x and y. 132 | """ 133 | # code is modified from 134 | # https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py 135 | 136 | batch_size = x.size()[0] 137 | 138 | if alpha > 0: 139 | lam = torch.distributions.beta.Beta(alpha, alpha).rsample((batch_size,)) 140 | else: 141 | lam = torch.ones((batch_size,)) 142 | 143 | lam = torch.max(lam, 1 - lam) 144 | 145 | lam = lam.to(x.device) 146 | 147 | index = torch.randperm(batch_size).to(x.device) 148 | 149 | x_size, y_size = [1 for _ in range(len(x.shape))], [1 for _ in range(len(y.shape))] 150 | x_size[0], y_size[0] = batch_size, batch_size 151 | 152 | mixed_x = lam.view(x_size) * x + (1 - lam.view(x_size)) * x[index] 153 | mixed_y = lam.view(y_size) * y + (1 - lam.view(y_size)) * y[index] 154 | 155 | return mixed_x, mixed_y, x[index], y[index], lam 156 | 157 | 158 | def smooth_label(y, n_classes, smoothing=0.0): 159 | """ 160 | Transform the y into one-hot representation and smooth it. 161 | If smoothing is 0, then the return will be one-hot representation of y. 
162 | Args 163 | y: label which is LongTensor 164 | n_classes: the total number of classes 165 | smoothing: the paramter of label smoothing 166 | Return 167 | true: the smooth label whose size is [Batch, n_classes] 168 | """ 169 | 170 | confidence = 1.0 - smoothing 171 | true_dist = torch.zeros(*list(y.size()), n_classes).to(y.device) 172 | true_dist.fill_(smoothing / (n_classes - 1)) 173 | true_dist.scatter_(-1, y.data.unsqueeze(1), confidence) 174 | 175 | return true_dist 176 | 177 | 178 | def to_one_hot_vector(y, n_classes): 179 | return smooth_label(y, n_classes, 0) 180 | 181 | 182 | def customized_weight_decay(model, weight_decay, ignore_key=["bn", "bias"]): 183 | for name, p in model.named_parameters(): 184 | if not any(key in name for key in ignore_key): 185 | p.data.mul_(1 - weight_decay) 186 | 187 | 188 | def interleave_offsets(batch, nu): 189 | groups = [batch // (nu + 1)] * (nu + 1) 190 | for x in range(batch - sum(groups)): 191 | groups[-x - 1] += 1 192 | offsets = [0] 193 | for g in groups: 194 | offsets.append(offsets[-1] + g) 195 | assert offsets[-1] == batch 196 | return offsets 197 | 198 | 199 | def interleave(xy, batch): 200 | nu = len(xy) - 1 201 | offsets = interleave_offsets(batch, nu) 202 | xy = [[v[offsets[p] : offsets[p + 1]] for p in range(nu + 1)] for v in xy] 203 | for i in range(1, nu + 1): 204 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i] 205 | return [torch.cat(v, dim=0) for v in xy] 206 | 207 | 208 | # def my_interleave_1(inputs, batch_size): 209 | # """ 210 | # * Make the data interleave. 211 | # * Swap the data of the first batch (inputs[0]) to other batches. 212 | # * change_indices would be a increasing function. 213 | # * len(change_indices) should be the same as len(inputs[0]), and the elements 214 | # denote which row should the data in first row change with. 215 | # """ 216 | # ret = deepcopy(inputs) 217 | # inputs_size = len(inputs) 218 | 219 | # repeat = batch_size // inputs_size 220 | # residual = batch_size % inputs_size 221 | # change_indices = list(range(inputs_size)) * repeat + list( 222 | # range(inputs_size - residual, inputs_size) 223 | # ) 224 | # change_indices = sorted(change_indices) 225 | # # print(change_indices) 226 | # for i, switch_row in enumerate(change_indices): 227 | # ret[0][i], ret[switch_row % inputs_size][i] = ( 228 | # inputs[switch_row % inputs_size][i], 229 | # inputs[0][i], 230 | # ) 231 | 232 | # return ret 233 | 234 | 235 | # def my_interleave_2(inputs, batch_size): 236 | # """ 237 | # * Make the data interleave. 238 | # * Swap the data of the first batch (inputs[0]) to other batches. 239 | # * change_indices would be a increasing function. 240 | # * len(change_indices) should be the same as len(inputs[0]), and the elements 241 | # denote which row should the data in first row change with. 242 | # """ 243 | # # ret = deepcopy(inputs) 244 | # def swap(A, B): 245 | # return B.clone(), A.clone() 246 | 247 | # ret = inputs 248 | # inputs_size = len(inputs) 249 | 250 | # # equally switch the first row to other rows, so we compute how many repeat for range(inputs_size), 251 | # # which store the rows to change. 252 | # # some of the element cannot evenly spread two rows, so we preferentially use the rows which are farer to 0th row. 
253 | # repeat = batch_size // inputs_size 254 | # residual = batch_size % inputs_size 255 | # change_indices = list(range(inputs_size)) * repeat + list( 256 | # range(inputs_size - residual, inputs_size) 257 | # ) 258 | # change_indices = sorted(change_indices) 259 | # # print(change_indices) 260 | 261 | # # start to change elements 262 | # for i, switch_row in enumerate(change_indices): 263 | # ret[0][i], ret[switch_row % inputs_size][i] = swap( 264 | # ret[0][i], ret[switch_row % inputs_size][i] 265 | # ) 266 | 267 | # return ret 268 | 269 | 270 | def my_interleave(inputs, batch_size): 271 | """ 272 | * This function will override inputs 273 | * Make the data interleave. 274 | * Swap the data of the first batch (inputs[0]) to other batches. 275 | * change_indices would be a increasing function. 276 | * len(change_indices) should be the same as len(inputs[0]), and the elements 277 | denote which row should the data in first row change with. 278 | """ 279 | 280 | def swap(A, B): 281 | """ 282 | swap for tensors 283 | """ 284 | return B.clone(), A.clone() 285 | 286 | ret = inputs 287 | inputs_size = len(inputs) 288 | 289 | repeat = batch_size // inputs_size 290 | residual = batch_size % inputs_size 291 | 292 | # equally switch the first row to other rows, so we compute how many repeat for range(inputs_size), 293 | # which store the rows to change. 294 | # some of the element cannot evenly spread two rows, so we preferentially use the rows which are farer to 0th row. 295 | 296 | change_indices = list(range(inputs_size)) * repeat + list( 297 | range(inputs_size - residual, inputs_size) 298 | ) 299 | change_indices = sorted(change_indices) 300 | # print(change_indices) 301 | 302 | # the change_indices is monotone increasing function, so we can group the same elements and swap together 303 | # e.g. 
change_indices = [0, 1, 1, 2, 2, 2] 304 | # => two_dimension_change_indices = [[0], [1, 1], [2, 2, 2]] 305 | two_dimension_change_indices = [] 306 | 307 | change_indices.insert(0, -1) 308 | change_indices.append(change_indices[-1] + 1) 309 | start = 0 310 | for i in range(1, len(change_indices)): 311 | if change_indices[i] != change_indices[i - 1]: 312 | two_dimension_change_indices.append(change_indices[start:i]) 313 | start = i 314 | 315 | two_dimension_change_indices.pop(0) 316 | 317 | i = 0 318 | for switch_rows in two_dimension_change_indices: 319 | switch_row = switch_rows[0] 320 | num = len(switch_rows) 321 | ret[0][i : i + num], ret[switch_row % inputs_size][i : i + num] = swap( 322 | ret[0][i : i + num], ret[switch_row % inputs_size][i : i + num] 323 | ) 324 | i += num 325 | 326 | return ret 327 | 328 | 329 | def split_weight_decay_weights(model, weight_decay, ignore_key=["bn", "bias"]): 330 | weight_decay_weights = [] 331 | no_weight_decay_weights = [] 332 | for name, p in model.named_parameters(): 333 | if not p.requires_grad: 334 | continue 335 | if any(key in name for key in ignore_key): 336 | no_weight_decay_weights.append(p) 337 | else: 338 | # print(name) 339 | weight_decay_weights.append(p) 340 | 341 | return [ 342 | {"params": no_weight_decay_weights, "weight_decay": 0.0}, 343 | {"params": weight_decay_weights, "weight_decay": weight_decay}, 344 | ] 345 | 346 | 347 | class WeightDecayModule: 348 | def __init__(self, model, weight_decay, ignore_key=["bn", "bias"]): 349 | self.weight_decay = weight_decay 350 | self.available_parameters = [] 351 | for name, p in model.named_parameters(): 352 | if not any(key in name for key in ignore_key): 353 | # print(name) 354 | self.available_parameters.append(p) 355 | 356 | def decay(self): 357 | for p in self.available_parameters: 358 | p.data.mul_(1 - self.weight_decay) 359 | 360 | 361 | class EMA: 362 | """ 363 | The module for exponential moving average 364 | """ 365 | 366 | def __init__(self, model, ema_model, decay=0.999): 367 | self.decay = decay 368 | self.params = list(model.state_dict().values()) 369 | self.ema_params = list(ema_model.state_dict().values()) 370 | 371 | # some of the quantity in batch norm is LongTensor, 372 | # we have to make them become float, or it will cause the 373 | # type error in mul_ or add_ in self.step() 374 | for p in ema_model.parameters(): 375 | p.detach_() 376 | for i in range(len(self.ema_params)): 377 | self.ema_params[i] = self.ema_params[i].float() 378 | 379 | def step(self): 380 | # average all the paramters, including the running mean and 381 | # running std in batchnormalization 382 | for param, ema_param in zip(self.params, self.ema_params): 383 | # if param.dtype == torch.float32: 384 | ema_param.mul_(self.decay) 385 | ema_param.add_(param * (1 - self.decay)) 386 | # if param.dtype == torch.float32: 387 | # param.mul_(1 - 4e-5) 388 | -------------------------------------------------------------------------------- /lightning_ssl/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | 5 | # def create_logger(save_path="", file_type="", level="debug"): 6 | 7 | # if level == "debug": 8 | # _level = logging.DEBUG 9 | # elif level == "info": 10 | # _level = logging.INFO 11 | 12 | # logger = logging.getLogger() 13 | # logger.setLevel(_level) 14 | 15 | # cs = logging.StreamHandler() 16 | # cs.setLevel(_level) 17 | # logger.addHandler(cs) 18 | 19 | # if save_path != "": 20 | # file_name = os.path.join(save_path, 
file_type + "_log.txt") 21 | # fh = logging.FileHandler(file_name, mode="w") 22 | # fh.setLevel(_level) 23 | 24 | # logger.addHandler(fh) 25 | 26 | # return logger 27 | 28 | 29 | def makedirs(path): 30 | if not os.path.exists(path): 31 | os.makedirs(path) 32 | 33 | 34 | def count_parameters(model): 35 | # copy from https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8 36 | # baldassarre.fe's reply 37 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from lightning_ssl.module import ClassifierModule, MixmatchModule 5 | from lightning_ssl.dataloader import SemiCIFAR10Module, SupervisedCIFAR10Module 6 | from lightning_ssl.models import WideResNet 7 | import pytorch_lightning as pl 8 | from lightning_ssl.utils import count_parameters 9 | from lightning_ssl.utils.argparser import parser, print_args 10 | 11 | 12 | if __name__ == "__main__": 13 | args = parser() 14 | 15 | gpus = args.gpus if torch.cuda.is_available() else None 16 | 17 | # set the random seed 18 | # np.random.seed(args.seed) # for sampling labeled/unlabeled/val dataset 19 | pl.seed_everything(args.seed) 20 | 21 | # load the data and classifier 22 | if args.dataset == "cifar10": 23 | if args.learning_scenario == "supervised": 24 | loader_class = SupervisedCIFAR10Module 25 | 26 | elif args.learning_scenario == "semi": 27 | loader_class = SemiCIFAR10Module 28 | 29 | data_loader = loader_class( 30 | args, 31 | args.data_root, 32 | args.num_workers, 33 | args.batch_size, 34 | args.num_labeled, 35 | args.num_val, 36 | args.num_augments, 37 | ) 38 | classifier = WideResNet(depth=28, num_classes=data_loader.n_classes) 39 | 40 | else: 41 | raise NotImplementedError 42 | 43 | if args.learning_scenario == "supervised": 44 | module = ClassifierModule 45 | else: # semi supervised learning algorithm 46 | if args.algo == "mixmatch": 47 | module = MixmatchModule 48 | else: 49 | raise NotImplementedError 50 | 51 | print(f"model paramters: {count_parameters(classifier)} M") 52 | 53 | # set the number of classes in the args 54 | setattr(args, "n_classes", data_loader.n_classes) 55 | 56 | data_loader.prepare_data() 57 | data_loader.setup() 58 | 59 | model = module(args, classifier, loaders=None) 60 | 61 | # trainer = Trainer(tr_loaders, va_loader, te_loader) 62 | 63 | if args.todo == "train": 64 | print( 65 | f"labeled size: {data_loader.num_labeled_data}" 66 | f"unlabeled size: {data_loader.num_unlabeled_data}, " 67 | f"val size: {data_loader.num_val_data}" 68 | ) 69 | 70 | save_folder = ( 71 | f"{args.dataset}_{args.learning_scenario}_{args.algo}_{args.affix}" 72 | ) 73 | # model_folder = os.path.join(args.model_root, save_folder) 74 | # checkpoint_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(filepath=model_folder) 75 | 76 | # tt_logger = pl.loggers.TestTubeLogger("tt_logs", name=save_folder, create_git_tag=True) 77 | # tt_logger.log_hyperparams(args) 78 | 79 | # if True: 80 | # print_args(args) 81 | 82 | # trainer = pl.trainer.Trainer(gpus=gpus, max_epochs=args.max_epochs) 83 | # else: 84 | 85 | tb_logger = pl.loggers.TensorBoardLogger( 86 | os.path.join(args.log_root, "lightning_logs"), name=save_folder 87 | ) 88 | tb_logger.log_hyperparams(args) 89 | 90 | # set the path of checkpoint 91 | save_dir = getattr(tb_logger, "save_dir", None) or getattr( 
92 | tb_logger, "_save_dir", None 93 | ) 94 | 95 | ckpt_path = os.path.join( 96 | save_dir, tb_logger.name, f"version_{tb_logger.version}", "checkpoints" 97 | ) 98 | 99 | ckpt = pl.callbacks.ModelCheckpoint(filepath=os.path.join(ckpt_path, "last")) 100 | 101 | setattr(args, "checkpoint_folder", ckpt_path) 102 | 103 | print_args(args) 104 | 105 | trainer = pl.trainer.Trainer( 106 | gpus=gpus, 107 | max_steps=args.max_steps, 108 | logger=tb_logger, 109 | max_epochs=args.max_epochs, 110 | checkpoint_callback=ckpt, 111 | benchmark=True, 112 | profiler=True, 113 | progress_bar_refresh_rate=args.progress_bar_refresh_rate, 114 | reload_dataloaders_every_epoch=True, 115 | ) 116 | 117 | trainer.fit(model, datamodule=data_loader) 118 | trainer.test() 119 | else: 120 | trainer = pl.trainer.Trainer(resume_from_checkpoint=args.load_checkpoint) 121 | trainer.test(model, datamodule=data_loader) 122 | -------------------------------------------------------------------------------- /read_from_tb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | import sys 5 | 6 | data_num = int(sys.argv[1]) 7 | seeds = [1, 2, 3] 8 | versions = [0, 0, 0] 9 | 10 | values = [] 11 | iters = [] 12 | 13 | for s, v in zip(seeds, versions): 14 | path_to_events_file = f"lightning_logs/cifar10_semi_mixmatch_@{data_num}.{s}_LB_NoV_FS_W_Batch_2/version_{v}/" 15 | # path_to_events_file = f"lightning_logs/cifar10_supervised_mixup_fs_mixup.{s}/version_{v}" 16 | 17 | files = os.listdir(path_to_events_file) 18 | for file_ in files: 19 | if file_.startswith("event"): 20 | event_file = file_ 21 | 22 | print(event_file) 23 | 24 | path_to_events_file = os.path.join(path_to_events_file, event_file) 25 | 26 | last_iter = 0 27 | for e in tf.train.summary_iterator(path_to_events_file): 28 | for v in e.summary.value: 29 | if v.tag.startswith("test/median_acc"): 30 | value = v.simple_value 31 | last_iter += 1 32 | 33 | values.append(value) 34 | iters.append(last_iter) 35 | 36 | # for i in range(1, len(iters)): 37 | # assert iters[i] == iters[i - 1] 38 | 39 | print(iters) 40 | print(values) 41 | print(100 - np.mean(values)) 42 | print(np.std(values)) 43 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | 2 | [flake8] 3 | # TODO: this should be 88 or 100 according PEP8 4 | max-line-length = 120 5 | exclude = .tox,*.egg,build,temp 6 | select = E,W,F 7 | doctests = True 8 | verbose = 2 9 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 10 | format = pylint 11 | ignore = 12 | E741 13 | E731 14 | W504 15 | F401 16 | F841 17 | E203 # E203 - whitespace before ':'. Opposite convention enforced by black 18 | E231 # E231: missing whitespace after ',', ';', or ':'; for black 19 | E501 # E501 - line too long. 
Handled by black, we have longer lines 20 | W503 # W503 - line break before binary operator, need for black 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import platform 4 | from setuptools import setup, find_packages 5 | 6 | 7 | setup( 8 | name="lightning-semi-supervised-learning", 9 | description="Ready-to-use semi-supervised learning under one common API", 10 | version="master", 11 | packages=find_packages(), 12 | include_package_data=True, 13 | install_requires=["torch==1.7.0", "torchvision", "pytorch-lightning==1.0.2"], 14 | extras_require={ 15 | "test": [ 16 | "coverage", 17 | "pytest", 18 | "flake8", 19 | "pre-commit", 20 | "codecov", 21 | "pytest-cov", 22 | "pytest-flake8", 23 | "flake8-black", 24 | "black", 25 | ] 26 | }, 27 | ) 28 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/__init__.py -------------------------------------------------------------------------------- /tests/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/dataloader/__init__.py -------------------------------------------------------------------------------- /tests/dataloader/test_base_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import numpy as np 4 | from lightning_ssl.dataloader.base_data import ( 5 | get_split_indices, 6 | CustomSemiDataset, 7 | MultiDataset, 8 | MagicClass, 9 | ) 10 | from torch.utils.data import TensorDataset, DataLoader 11 | 12 | 13 | @pytest.fixture( 14 | params=[ 15 | {"max_label": 10, "num_data": 1000}, 16 | {"max_label": 10, "num_data": 1000}, 17 | {"max_label": 10, "num_data": 1000}, 18 | ] 19 | ) 20 | def label_data(request): 21 | param = request.param 22 | 23 | while True: 24 | label_proportion = np.random.uniform( 25 | 0, 1, size=(param["max_label"],) 26 | ) # the proportion of each label 27 | 28 | label_proportion = ( 29 | label_proportion / label_proportion.sum() 30 | ) # normalize to summation of proportion is 1 31 | 32 | # after this normalization 33 | label_proportion = np.round_( 34 | label_proportion, decimals=int(np.log10(param["num_data"]) - 1) 35 | ) 36 | 37 | residual = 1.0 - label_proportion.sum() 38 | # add the residual to the last element 39 | label_proportion[-1] += residual 40 | 41 | if label_proportion.min() > 0: # valid proportion 42 | 43 | label_proportion = (label_proportion * param["num_data"]).astype(int) 44 | 45 | if label_proportion.sum() == param["num_data"]: # valid proportion 46 | label_data = [] 47 | 48 | for i, p in enumerate(label_proportion): 49 | label_data.append(i * np.ones((p,))) 50 | 51 | label_data = np.hstack(label_data) 52 | 53 | np.random.shuffle(label_data) 54 | 55 | return label_data 56 | 57 | 58 | def test_get_split_indices(label_data): 59 | # label_data is specially design 60 | # the proportion of the labeled in label_data is specially designed to that 61 | # (num_labeled * proportion), (num_val * proportion) and (num_unlabeled * proportion) 62 | # are all 
integer 63 | 64 | num_labeled = int(0.1 * len(label_data)) 65 | num_val = int(0.1 * len(label_data)) 66 | n_classes = int(np.max(label_data).item()) + 1 67 | num_unlabeled = len(label_data) - num_labeled - num_val 68 | labeled_indices, unlabeled_indices, val_indices = get_split_indices( 69 | label_data, num_labeled, num_val, n_classes 70 | ) 71 | 72 | def convert_idx_to_num(indices): 73 | num_list = [] 74 | for c in range(n_classes): 75 | num_list.append((label_data[labeled_indices] == c).sum()) 76 | return num_list 77 | 78 | assert len(val_indices) == num_val 79 | assert len(labeled_indices) == num_labeled 80 | assert len(unlabeled_indices) == num_unlabeled 81 | 82 | # check the proportion of data 83 | assert convert_idx_to_num(labeled_indices) == convert_idx_to_num( 84 | np.arange(len(label_data)) 85 | ) 86 | assert convert_idx_to_num(val_indices) == convert_idx_to_num( 87 | np.arange(len(label_data)) 88 | ) 89 | assert convert_idx_to_num(unlabeled_indices) == convert_idx_to_num( 90 | np.arange(len(label_data)) 91 | ) 92 | 93 | 94 | def test_custum_dataset(): 95 | dataset_1 = TensorDataset(torch.arange(10)) 96 | dataset_2 = TensorDataset(torch.arange(30, 75)) 97 | dataset_3 = TensorDataset(torch.arange(30, 75)) 98 | dataset_4 = TensorDataset(torch.arange(30, 75)) 99 | 100 | dloader_1 = DataLoader(dataset_1, batch_size=3, shuffle=True, num_workers=0) 101 | dloader_2 = DataLoader( 102 | MultiDataset([dataset_2, dataset_3, dataset_4]), 103 | batch_size=3, 104 | shuffle=True, 105 | num_workers=0, 106 | ) 107 | 108 | for _ in range(10): 109 | for batch in MagicClass([dloader_1, dloader_2]): 110 | batch = batch[1] 111 | for b in batch[1:]: 112 | assert torch.all( 113 | batch[0][0] == b[0] 114 | ), "The data in multidataset should be the same" 115 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/test_basemodel.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torch.nn 4 | import pytest 5 | 6 | from copy import deepcopy 7 | 8 | from lightning_ssl.models import WideResNet 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def model_inputs_pair(): 13 | return WideResNet(28, 10), torch.randn((5, 3, 32, 32)) 14 | 15 | 16 | def test_pseudo_label(model_inputs_pair): 17 | model = WideResNet(28, 10) 18 | model.eval() 19 | inputs_list = [torch.randn(5, 3, 32, 32) for _ in range(3)] 20 | 21 | assert model.psuedo_label(inputs_list, 0.5).shape == torch.Size([5, 10]) 22 | 23 | assert torch.all( 24 | torch.isclose( 25 | model.psuedo_label(inputs_list, 0.5), 26 | model.psuedo_label(inputs_list, 0.5, True), 27 | ) 28 | ) 29 | 30 | 31 | def test_freeze_running_stats(model_inputs_pair): 32 | model, inputs = model_inputs_pair 33 | model.freeze_running_stats() 34 | before_stats = deepcopy(model.extract_running_stats()) 35 | # run the network 36 | model(inputs) 37 | # the running_stats should not change 38 | after_stats = deepcopy(model.extract_running_stats()) 39 | 40 | # check the number of batch norm layers 41 | assert len(before_stats) == (28 - 4) // 6 * 2 * 3 + 1 42 | 43 | for f, s in zip(before_stats, after_stats): 44 | assert torch.all(torch.eq(f[0], 
s[0])) 45 | assert torch.all(torch.eq(f[1], s[1])) 46 | 47 | 48 | def test_recover_running_stats(model_inputs_pair): 49 | model, inputs = model_inputs_pair 50 | # recover running stats 51 | model.recover_running_stats() 52 | before_stats = deepcopy(model.extract_running_stats()) 53 | # run the network 54 | model(inputs) 55 | 56 | # the running_stats should not change 57 | after_stats = deepcopy(model.extract_running_stats()) 58 | 59 | for f, s in zip(before_stats, after_stats): 60 | assert torch.all(~torch.eq(f[0], s[0])) 61 | assert torch.all(~torch.eq(f[1], s[1])) 62 | 63 | 64 | def test_recover_running_stats__without_freeze_first(): 65 | model = WideResNet(28, 10) 66 | before_stats = model.extract_norm_n_momentum() 67 | model.recover_running_stats() 68 | after_stats = model.extract_norm_n_momentum() 69 | 70 | # test batchnorm statistics 71 | for f, s in zip(before_stats[0], after_stats[0]): 72 | assert torch.all(torch.eq(f.running_mean, s.running_mean)) 73 | assert torch.all(torch.eq(f.running_var, s.running_var)) 74 | # test momentum 75 | for f, s in zip(before_stats[1], after_stats[1]): 76 | assert f == s 77 | -------------------------------------------------------------------------------- /tests/models/test_cnn13.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from copy import deepcopy 5 | 6 | from lightning_ssl.models import CNN13 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "shape, num_classes", [((5, 3, 32, 32), 8), ((10, 3, 32, 32), 15)] 11 | ) 12 | def test_wideresnet(shape, num_classes): 13 | inputs = torch.randn(shape) 14 | model = CNN13(num_classes) 15 | batch_size = inputs.shape[0] 16 | assert model(inputs).shape == torch.Size([batch_size, num_classes]) 17 | -------------------------------------------------------------------------------- /tests/models/test_wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from copy import deepcopy 5 | 6 | from lightning_ssl.models import WideResNet 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "shape, num_classes", [((5, 3, 32, 32), 8), ((10, 3, 32, 32), 15)] 11 | ) 12 | def test_wideresnet(shape, num_classes): 13 | inputs = torch.randn(shape) 14 | model = WideResNet(28, num_classes, dropRate=0.2) 15 | batch_size = inputs.shape[0] 16 | assert model(inputs).shape == torch.Size([batch_size, num_classes]) 17 | -------------------------------------------------------------------------------- /tests/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/module/__init__.py -------------------------------------------------------------------------------- /tests/module/assets/tmp_config.yml: -------------------------------------------------------------------------------- 1 | cifar10: 2 | algo: "mixmatch" 3 | learning_scenario: learning_scenario 4 | batch_size: 10 5 | num_workers: 0 6 | log_root: "tests" 7 | num_labeled: 10 8 | num_val: 30 9 | -------------------------------------------------------------------------------- /tests/module/simple_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from lightning_ssl.dataloader.base_data import ( 4 | SemiDataLoader, 5 | SupervisedDataLoader, 6 | SemiDataModule, 7 | SupervisedDataModule, 8 | ) 9 | 
10 | import torchvision as tv 11 | from torch.utils.data import DataLoader, SubsetRandomSampler 12 | from torchvision.datasets import CIFAR10 13 | 14 | 15 | CIFAR_MEAN = (0.4914, 0.4822, 0.4465) 16 | CIFAR_STD = (0.2471, 0.2435, 0.2616) 17 | 18 | 19 | def sample_dataset(): 20 | return [ 21 | (np.random.uniform(0, 1, (3, 2, 2)).astype(np.float32), random.randint(0, 9)) 22 | for i in range(50) 23 | ] 24 | 25 | 26 | class SemiSampleModule(SemiDataModule): 27 | def __init__( 28 | self, 29 | args, 30 | data_root, 31 | num_workers, 32 | batch_size, 33 | num_labeled, 34 | num_val, 35 | num_augments, 36 | ): 37 | n_classes = 10 38 | super(SemiSampleModule, self).__init__( 39 | data_root, 40 | num_workers, 41 | batch_size, 42 | num_labeled, 43 | num_val, 44 | num_augments, 45 | n_classes, 46 | ) 47 | 48 | def prepare_data(self): 49 | # the transformation for train and validation dataset will be 50 | # done in _prepare_train_dataset() 51 | 52 | self.train_set = sample_dataset() 53 | self.test_set = sample_dataset() 54 | 55 | 56 | class SupervisedSampleModule(SupervisedDataModule): 57 | def __init__( 58 | self, 59 | args, 60 | data_root, 61 | num_workers, 62 | batch_size, 63 | num_labeled, 64 | num_val, 65 | num_augments, 66 | ): 67 | n_classes = 10 68 | super(SupervisedSampleModule, self).__init__( 69 | data_root, num_workers, batch_size, num_labeled, num_val, n_classes 70 | ) 71 | 72 | def prepare_data(self): 73 | # the transformation for train and validation dataset will be 74 | # done in _prepare_train_dataset() 75 | self.train_set = sample_dataset() 76 | self.test_set = sample_dataset() 77 | -------------------------------------------------------------------------------- /tests/module/simple_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from lightning_ssl.models.base_model import CustomModel 3 | 4 | 5 | class SimpleModel(CustomModel): 6 | def __init__(self, num_classes): 7 | super().__init__() 8 | self.num_classes = num_classes 9 | self.layer_1 = nn.AvgPool2d(kernel_size=2) 10 | self.layer_2 = nn.Linear(3, num_classes) 11 | 12 | def forward(self, inputs): 13 | output = self.layer_1(inputs).reshape(inputs.shape[0], -1) 14 | return self.layer_2(output) 15 | -------------------------------------------------------------------------------- /tests/module/test_mixmatch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import torch 5 | import pytest 6 | import pathlib 7 | import numpy as np 8 | from tests.module.simple_model import SimpleModel 9 | from lightning_ssl.module import ClassifierModule, MixmatchModule 10 | from lightning_ssl.dataloader import ( 11 | SemiCIFAR10Module, 12 | SupervisedCIFAR10Module, 13 | ) 14 | from lightning_ssl.models import WideResNet 15 | import pytorch_lightning as pl 16 | from lightning_ssl.utils import count_parameters 17 | from lightning_ssl.utils.argparser import parser, print_args 18 | 19 | PATH = pathlib.Path(__file__).parent 20 | CONFIG_PATH = os.path.join(PATH, "assets", "tmp_config.yml") 21 | 22 | 23 | @pytest.mark.parametrize("batch_inference", ["True", "False"]) 24 | def test_mixmatch(tmpdir, batch_inference): 25 | original_argv = copy.deepcopy(sys.argv) 26 | sys.argv = [ 27 | "tmp.py", 28 | "-c", 29 | f"{CONFIG_PATH}", 30 | "--log_root", 31 | str(tmpdir), 32 | "--batch_inference", 33 | batch_inference, 34 | ] 35 | print(sys.argv) 36 | args = parser() 37 | sys.argv = original_argv 38 | 39 | gpus = 
args.gpus if torch.cuda.is_available() else None 40 | 41 | # load the data and classifier 42 | 43 | data_loader = SemiCIFAR10Module( 44 | args, 45 | args.data_root, 46 | args.num_workers, 47 | args.batch_size, 48 | args.num_labeled, 49 | args.num_val, 50 | args.num_augments, 51 | ) 52 | classifier = WideResNet(depth=10, num_classes=data_loader.n_classes) 53 | 54 | print(f"model paramters: {count_parameters(classifier)} M") 55 | 56 | # set the number of classes in the args 57 | setattr(args, "n_classes", data_loader.n_classes) 58 | 59 | data_loader.prepare_data() 60 | data_loader.setup() 61 | 62 | model = MixmatchModule(args, classifier, loaders=None) 63 | 64 | print( 65 | f"labeled size: {data_loader.num_labeled_data}" 66 | f"unlabeled size: {data_loader.num_unlabeled_data}, " 67 | f"val size: {data_loader.num_val_data}, " 68 | f"test size: {data_loader.num_test_data}" 69 | ) 70 | 71 | save_folder = f"{args.dataset}_{args.learning_scenario}_{args.algo}_{args.affix}" 72 | 73 | tb_logger = pl.loggers.TensorBoardLogger( 74 | os.path.join(args.log_root, "lightning_logs"), name=save_folder 75 | ) 76 | tb_logger.log_hyperparams(args) 77 | 78 | # set the path of checkpoint 79 | save_dir = getattr(tb_logger, "save_dir", None) or getattr( 80 | tb_logger, "_save_dir", None 81 | ) 82 | ckpt_path = os.path.join( 83 | save_dir, tb_logger.name, f"version_{tb_logger.version}", "checkpoints" 84 | ) 85 | 86 | ckpt = pl.callbacks.ModelCheckpoint(filepath=os.path.join(ckpt_path, "{epoch}")) 87 | 88 | setattr(args, "checkpoint_folder", ckpt_path) 89 | 90 | print_args(args) 91 | 92 | trainer = pl.trainer.Trainer( 93 | gpus=gpus, 94 | logger=tb_logger, 95 | checkpoint_callback=ckpt, 96 | fast_dev_run=True, 97 | reload_dataloaders_every_epoch=True, 98 | progress_bar_refresh_rate=args.progress_bar_refresh_rate, 99 | ) 100 | 101 | trainer.fit(model, datamodule=data_loader) 102 | trainer.test() 103 | -------------------------------------------------------------------------------- /tests/module/test_supervised.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import torch 5 | import pytest 6 | import pathlib 7 | import numpy as np 8 | from lightning_ssl.module import ClassifierModule, MixmatchModule 9 | from lightning_ssl.dataloader import ( 10 | SemiCIFAR10Module, 11 | SupervisedCIFAR10Module, 12 | ) 13 | from lightning_ssl.models import WideResNet 14 | import pytorch_lightning as pl 15 | from lightning_ssl.utils import count_parameters 16 | from lightning_ssl.utils.argparser import parser, print_args 17 | 18 | PATH = pathlib.Path(__file__).parent 19 | CONFIG_PATH = os.path.join(PATH, "assets", "tmp_config.yml") 20 | 21 | 22 | def test_supervised(tmpdir): 23 | original_argv = copy.deepcopy(sys.argv) 24 | sys.argv = ["tmp.py", "-c", f"{CONFIG_PATH}", "--log_root", str(tmpdir)] 25 | print(sys.argv) 26 | args = parser() 27 | sys.argv = original_argv 28 | 29 | gpus = args.gpus if torch.cuda.is_available() else None 30 | 31 | # load the data and classifier 32 | 33 | data_loader = SupervisedCIFAR10Module( 34 | args, 35 | args.data_root, 36 | args.num_workers, 37 | args.batch_size, 38 | args.num_labeled, 39 | args.num_val, 40 | args.num_augments, 41 | ) 42 | classifier = WideResNet(depth=10, num_classes=data_loader.n_classes) 43 | 44 | print(f"model paramters: {count_parameters(classifier)} M") 45 | 46 | # set the number of classes in the args 47 | setattr(args, "n_classes", data_loader.n_classes) 48 | 49 | data_loader.prepare_data() 50 | 
data_loader.setup() 51 | 52 | model = ClassifierModule(args, classifier, loaders=None) 53 | 54 | print( 55 | f"labeled size: {data_loader.num_labeled_data}" 56 | f"unlabeled size: {data_loader.num_unlabeled_data}, " 57 | f"val size: {data_loader.num_val_data}, " 58 | f"test size: {data_loader.num_test_data}" 59 | ) 60 | 61 | save_folder = f"{args.dataset}_{args.learning_scenario}_{args.algo}_{args.affix}" 62 | 63 | tb_logger = pl.loggers.TensorBoardLogger( 64 | os.path.join(args.log_root, "lightning_logs"), name=save_folder 65 | ) 66 | tb_logger.log_hyperparams(args) 67 | 68 | # set the path of checkpoint 69 | save_dir = getattr(tb_logger, "save_dir", None) or getattr( 70 | tb_logger, "_save_dir", None 71 | ) 72 | ckpt_path = os.path.join( 73 | save_dir, tb_logger.name, f"version_{tb_logger.version}", "checkpoints" 74 | ) 75 | 76 | ckpt = pl.callbacks.ModelCheckpoint(filepath=os.path.join(ckpt_path, "{epoch}")) 77 | 78 | setattr(args, "checkpoint_folder", ckpt_path) 79 | 80 | print_args(args) 81 | 82 | trainer = pl.trainer.Trainer( 83 | gpus=gpus, 84 | logger=tb_logger, 85 | checkpoint_callback=ckpt, 86 | fast_dev_run=True, 87 | reload_dataloaders_every_epoch=True, 88 | ) 89 | 90 | trainer.fit(model, datamodule=data_loader) 91 | trainer.test() 92 | -------------------------------------------------------------------------------- /tests/test_models: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/test_models -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/assets/tmp_config.yml: -------------------------------------------------------------------------------- 1 | cifar10: 2 | todo: "test" 3 | affix: "test" 4 | -------------------------------------------------------------------------------- /tests/utils/test_argparser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import pathlib 5 | from lightning_ssl.utils.argparser import parser 6 | 7 | PATH = pathlib.Path(__file__).parent 8 | CONFIG_PATH = os.path.join(PATH, "assets", "tmp_config.yml") 9 | 10 | 11 | def test_parser(): 12 | original_argv = copy.deepcopy(sys.argv) 13 | sys.argv = [ 14 | "temp.py", 15 | "--c", 16 | f"{CONFIG_PATH}", 17 | "--learning_scenario", 18 | "supervised", 19 | "--todo", 20 | "train", 21 | ] 22 | 23 | args = parser() 24 | sys.argv = original_argv 25 | 26 | assert args.learning_scenario == "supervised" 27 | assert args.todo == "train" 28 | assert args.affix == "test" 29 | -------------------------------------------------------------------------------- /tests/utils/test_torch_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pytest 3 | import torch 4 | import numpy as np 5 | import torchvision 6 | from copy import deepcopy 7 | 8 | from lightning_ssl.utils.torch_utils import ( 9 | sharpening, 10 | soft_cross_entropy, 11 | l2_distribution_loss, 12 | mixup_data, 13 | smooth_label, 14 | to_one_hot_vector, 15 | mixup_data_for_testing, 16 | 
half_mixup_data_for_testing, 17 | customized_weight_decay, 18 | split_weight_decay_weights, 19 | interleave, 20 | my_interleave, 21 | WeightDecayModule, 22 | ) 23 | 24 | from lightning_ssl.models import WideResNet 25 | 26 | 27 | @pytest.fixture 28 | def logits_targets_pair(): 29 | """ 30 | The fixture to test softmax related function. Return the logits and soft targets. 31 | Return 32 | logits, targets 33 | """ 34 | logits = [[-1, 1, 0], [1, 0, -1]] 35 | targets = [[0.3, 0.6, 0.1], [0.5, 0.2, 0.3]] 36 | return torch.FloatTensor(logits), torch.FloatTensor(targets) 37 | 38 | 39 | @pytest.fixture( 40 | params=[ 41 | {"batch_size": 1, "num_classes": 5, "smoothing_factor": 0.3}, 42 | {"batch_size": 3, "num_classes": 5, "smoothing_factor": 0.01}, 43 | {"batch_size": 3, "num_classes": 5, "smoothing_factor": 0.0}, 44 | {"batch_size": 3, "num_classes": 5, "smoothing_factor": 1}, 45 | ] 46 | ) 47 | def targets_smoothing_classes_tuple(request): 48 | param = request.param 49 | return ( 50 | torch.randint(param["num_classes"], size=(param["batch_size"],)), 51 | param["smoothing_factor"], 52 | param["num_classes"], 53 | ) 54 | 55 | 56 | @pytest.fixture( 57 | params=[ 58 | {"batch_size": 20, "num_classes": 10}, 59 | {"batch_size": 30, "num_classes": 2}, 60 | {"batch_size": 20, "num_classes": 100}, 61 | ] 62 | ) 63 | def images_targets_num_classes_tuple(request): 64 | param = request.param 65 | batch_size = param["batch_size"] 66 | num_classes = param["num_classes"] 67 | return ( 68 | torch.randn(batch_size, 3, 32, 32), 69 | torch.randint(num_classes, size=(batch_size,)), 70 | num_classes, 71 | ) 72 | 73 | 74 | @pytest.mark.parametrize("temperature", [0.1, 0.5, 1, 2, 10]) 75 | def test_sharpening(temperature): 76 | def entropy(probs): 77 | return torch.sum(-probs * torch.log(probs), -1).mean() 78 | 79 | for _ in range(10): 80 | logits = torch.randn(4, 10) 81 | probs = torch.softmax(logits, -1) 82 | sharpening_probs = sharpening(probs, temperature) 83 | 84 | if temperature > 1: 85 | assert entropy(sharpening_probs) > entropy(probs) 86 | elif temperature < 1: 87 | assert entropy(sharpening_probs) < entropy(probs) 88 | else: 89 | assert torch.isclose(entropy(sharpening_probs), entropy(probs)) 90 | 91 | 92 | def test_soft_cross_entropy(logits_targets_pair): 93 | logits, targets = logits_targets_pair 94 | 95 | probs = torch.softmax(logits, -1) 96 | 97 | ans = 0 98 | for i in range(logits.shape[0]): 99 | for c in range(logits.shape[1]): 100 | ans += -torch.log(probs[i][c]) * targets[i][c] 101 | 102 | ans /= logits.shape[0] 103 | func_out = soft_cross_entropy(logits, targets, dim=-1) 104 | 105 | assert ans == func_out 106 | 107 | 108 | def test_l2_distribution_loss(logits_targets_pair): 109 | logits, targets = logits_targets_pair 110 | 111 | probs = torch.softmax(logits, -1) 112 | 113 | ans = 0 114 | for i in range(logits.shape[0]): 115 | for c in range(logits.shape[1]): 116 | ans += (probs[i][c] - targets[i][c]) ** 2 117 | 118 | ans /= logits.shape[0] * logits.shape[1] 119 | func_out = l2_distribution_loss(logits, targets, dim=-1) 120 | 121 | assert ans == func_out 122 | 123 | 124 | def test_smooth_label(targets_smoothing_classes_tuple): 125 | targets, smoothing_factor, num_classes = targets_smoothing_classes_tuple 126 | ################################################################## 127 | # test for label smoothing 128 | batch_size = targets.shape[0] 129 | y_labels = smooth_label(targets, num_classes, smoothing_factor) 130 | 131 | # predicted classes 132 | pred = torch.argmax(y_labels, dim=-1) 133 | 134 | # 
the logits of maximum classes should be (1 - smoothing_factor) 135 | assert torch.all( 136 | y_labels[torch.arange(batch_size), targets] == 1 - smoothing_factor 137 | ) 138 | # all the other logits should be the same 139 | assert torch.sum(y_labels != 1 - smoothing_factor) == batch_size * (num_classes - 1) 140 | # summation should be equal to 1 141 | assert torch.all(torch.isclose(torch.sum(y_labels, -1), torch.ones(pred.shape))) 142 | 143 | 144 | def test_one_hot(targets_smoothing_classes_tuple): 145 | targets, _, num_classes = targets_smoothing_classes_tuple 146 | batch_size = targets.shape[0] 147 | # test for one hot transformation 148 | y_labels = to_one_hot_vector(targets, num_classes) 149 | 150 | # predicted classes 151 | pred = torch.argmax(y_labels, dim=-1) 152 | 153 | # the logits of maximum classes should be 1 154 | assert torch.all(y_labels[torch.arange(batch_size), pred] == 1) 155 | # the maximum classes should be the same as targets 156 | assert torch.all(torch.eq(pred, targets)) 157 | # summation should be equal to 1 158 | assert torch.all(torch.isclose(torch.sum(y_labels, -1), torch.ones(pred.shape))) 159 | 160 | 161 | def test_mixup_minus(images_targets_num_classes_tuple): 162 | inputs, targets, _ = images_targets_num_classes_tuple 163 | mixed_inputs, mixed_targets = mixup_data(inputs, targets, -1) 164 | 165 | assert torch.all(torch.eq(inputs, mixed_inputs)) 166 | assert torch.all(torch.eq(targets, mixed_targets)) 167 | 168 | 169 | def test_mixup(images_targets_num_classes_tuple): 170 | inputs, targets, num_classes = images_targets_num_classes_tuple 171 | 172 | logits = smooth_label(targets, num_classes) 173 | 174 | _, _, _, _, gt_lambda = mixup_data_for_testing(inputs, logits, 0.5) 175 | 176 | # the mixed data should be between it's two ingredients 177 | 178 | # exclude the data which shuffle to the same place, 179 | # this will cause mixed data equal to the origin ingredients, 180 | # but sometimes two same numbers will be considered different due to the limitation float number 181 | 182 | # test lambda is within the range 183 | assert torch.all((gt_lambda >= 0) * (gt_lambda <= 1)) 184 | 185 | # same_pos_x = torch.all(torch.isclose(inputs, p_inputs).reshape(batch_size, -1), dim=-1) 186 | 187 | # inputs, p_inputs, logits, p_logits = \ 188 | # inputs[~same_pos_x], p_inputs[~same_pos_x], logits[~same_pos_x], p_logits[~same_pos_x] 189 | 190 | # mixed_x, mixed_y = mixed_x[~same_pos_x], mixed_y[~same_pos_x] 191 | 192 | # min_x, max_x = torch.min(inputs, p_inputs), torch.max(inputs, p_inputs) 193 | # min_y, max_y = torch.min(logits, p_logits), torch.max(logits, p_logits) 194 | 195 | # pos = ~((min_x <= mixed_x) * (mixed_x <= max_x)) 196 | 197 | # print(inputs[pos]==mixed_x[pos]) 198 | # print(p_inputs[pos]) 199 | 200 | # print(mixed_x[pos]) 201 | 202 | # assert torch.all((min_x <= mixed_x) * (mixed_x <= max_x)) == True 203 | 204 | # def reconstruct(mixed_d, original_d, shuffle_d): 205 | # numerator = mixed_d - shuffle_d 206 | # denominator = original_d - shuffle_d 207 | # mask = (denominator == 0).float() 208 | 209 | # reverted_l = numerator / (denominator + mask * 1e-15) 210 | 211 | # return reverted_l 212 | 213 | # l_1 = reconstruct(mixed_x, inputs, p_inputs) 214 | # l_1 = torch.mode(l_1.reshape(l_1.shape[0], -1), -1)[0] 215 | # l_2 = reconstruct(mixed_y, y_labels, p_y_labels) 216 | # l_2 = torch.max(l_2.reshape(l_2.shape[0], -1), -1)[0] 217 | 218 | # for i in range(l_1.shape[0]): 219 | # # only check when the y_label and permuted y_label are different, 220 | # # or the value 
will be 0 221 | # if torch.all(y_labels[i] == p_y_labels[i]): 222 | # assert l_2[i] == 0 223 | 224 | # elif torch.all(inputs[i] == p_inputs[i]): 225 | # assert l_1[i] == 0 226 | # else: 227 | # assert torch.isclose(l_1[i], l_2[i], atol=1e-5) == 0 228 | # assert 0.0 <= l_1[i] <= 1 == True 229 | 230 | 231 | def test_half_mixup(images_targets_num_classes_tuple): 232 | inputs, targets, num_classes = images_targets_num_classes_tuple 233 | 234 | logits = smooth_label(targets, num_classes) 235 | 236 | _, _, _, _, gt_lambda = half_mixup_data_for_testing(inputs, logits, 0.5) 237 | 238 | # test lambda is within the range 239 | assert torch.all((gt_lambda >= 0.5) * (gt_lambda <= 1)) 240 | 241 | 242 | def test_customized_weight_decay(): 243 | m = torchvision.models.resnet34() 244 | m_copy = deepcopy(m) 245 | 246 | ignore_key = ["bn", "bias"] 247 | wd = 1e-2 248 | customized_weight_decay(m, wd, ignore_key=ignore_key) 249 | 250 | for (name_m, p_m), (name_copy, p_copy) in zip( 251 | m.named_parameters(), m_copy.named_parameters() 252 | ): 253 | assert name_m == name_copy 254 | # ignore, and don't decay weight 255 | if any(key in name_m for key in ignore_key): 256 | assert torch.all(torch.eq(p_m, p_copy)) 257 | else: 258 | assert torch.all(torch.eq(p_m, p_copy * (1 - wd))) 259 | 260 | 261 | def test_weight_decay_module(): 262 | m = WideResNet(28, 10) 263 | 264 | wd_module = WeightDecayModule(m, 0.1, ["bn", "bias"]) 265 | 266 | num_weights = 29 # weights of conv and fc 267 | 268 | assert len(wd_module.available_parameters) == num_weights 269 | 270 | wd_module = WeightDecayModule(m, 0.1, ["weight"]) 271 | 272 | num_bias = 25 + 1 + 0 # bias of bn + bias of fc + bias of conv 273 | 274 | assert len(wd_module.available_parameters) == num_bias 275 | 276 | original_parameters = deepcopy(wd_module.available_parameters) 277 | wd_module.decay() 278 | 279 | for original_weight, decay_weight in zip( 280 | original_parameters, wd_module.available_parameters 281 | ): 282 | assert torch.all(torch.eq(original_weight * 0.9, decay_weight)) 283 | 284 | 285 | def test_split_weight_decay_weights(): 286 | m = torchvision.models.resnet34() 287 | m_copy = deepcopy(m) 288 | 289 | ignore_key = ["bn", "bias"] 290 | wd = 1e-2 291 | customized_weight_decay(m, wd, ignore_key=ignore_key) 292 | 293 | num_weight_decay = 0 294 | num_no_weight_decay = 0 295 | for (name_m, _), (name_copy, _) in zip( 296 | m.named_parameters(), m_copy.named_parameters() 297 | ): 298 | assert name_m == name_copy 299 | # ignore, and don't decay weight 300 | if any(key in name_m for key in ignore_key): 301 | num_no_weight_decay += 1 302 | else: 303 | num_weight_decay += 1 304 | 305 | params_list = split_weight_decay_weights(m, wd, ignore_key=ignore_key) 306 | 307 | # first item is no weight decay 308 | 309 | assert len(params_list[0]["params"]) == num_no_weight_decay 310 | assert params_list[0]["weight_decay"] == 0 311 | assert len(params_list[1]["params"]) == num_weight_decay 312 | assert params_list[1]["weight_decay"] == wd 313 | 314 | 315 | @pytest.mark.parametrize("shape", [[], [3, 32, 32]]) 316 | @pytest.mark.parametrize( 317 | "batch_size, num_data", [(5, 3), (128, 3), (50, 2), (3, 168), (168, 168)] 318 | ) 319 | def test_interleave(shape, batch_size, num_data): 320 | print(shape, batch_size, num_data) 321 | inputs = [torch.randint(100, [batch_size] + shape) for _ in range(num_data)] 322 | orig_inputs = deepcopy(inputs) 323 | 324 | # test_interleave 325 | 326 | o_1 = interleave(inputs, batch_size) 327 | o_3 = my_interleave(inputs, batch_size) 328 | 329 | 
for l, r in zip(o_1, o_3): 330 | assert torch.all(torch.eq(l, r)) 331 | 332 | # reverse interleave 333 | o_1 = interleave(o_1, batch_size) 334 | o_3 = my_interleave(o_3, batch_size) 335 | 336 | for a_, l, r in zip(orig_inputs, o_1, o_3): 337 | assert torch.all(torch.eq(a_, r)) 338 | assert torch.all(torch.eq(l, r)) 339 | -------------------------------------------------------------------------------- /tests/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import subprocess 4 | from lightning_ssl.utils.utils import makedirs 5 | 6 | 7 | def test_makedirs(tmpdir): 8 | folder = os.path.join(tmpdir, "a/b") 9 | makedirs(folder) 10 | assert os.path.exists(folder) 11 | --------------------------------------------------------------------------------
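
## Appendix: how the `torch_utils` helpers compose

The MixMatch training step shown above chains several of the helpers from `lightning_ssl/utils/torch_utils.py`: label smoothing, sharpening, mixup, and the soft/L2 losses. The snippet below is a minimal, self-contained sketch (not part of the repository) of how those public functions fit together. The batch shapes, the temperature `T=0.5`, the mixup `alpha=0.75`, and the `lambda_u` weight are illustrative assumptions, and the random logits merely stand in for real classifier outputs.

```python
import torch
from lightning_ssl.utils.torch_utils import (
    smooth_label,
    sharpening,
    mixup_data,
    soft_cross_entropy,
    l2_distribution_loss,
)

n_classes = 10
x_lab = torch.randn(8, 3, 32, 32)            # fake labeled CIFAR-like batch
y_lab = torch.randint(n_classes, (8,))       # integer class labels
x_unl = torch.randn(8, 3, 32, 32)            # fake unlabeled batch

# 1) hard labels -> (optionally smoothed) soft targets, shape [8, 10]
targets = smooth_label(y_lab, n_classes, smoothing=0.1)

# 2) sharpen an averaged prediction to get low-entropy pseudo-labels
guessed = torch.softmax(torch.randn(8, n_classes), dim=-1)
pseudo = sharpening(guessed, 0.5)

# 3) mixup inputs/targets for both parts of the batch
mixed_x_lab, mixed_y_lab = mixup_data(x_lab, targets, 0.75)
mixed_x_unl, mixed_y_unl = mixup_data(x_unl, pseudo, 0.75)

# 4) score with the labeled (cross-entropy) and unlabeled (L2) losses;
#    random logits stand in for model(mixed_x_...)
logits_lab = torch.randn(8, n_classes, requires_grad=True)
logits_unl = torch.randn(8, n_classes, requires_grad=True)

lambda_u = 100.0                              # assumed weight for illustration
l_l = soft_cross_entropy(logits_lab, mixed_y_lab)
l_u = l2_distribution_loss(logits_unl, mixed_y_unl)
loss = l_l + lambda_u * l_u
loss.backward()
print(float(loss))
```

The actual `MixmatchModule` additionally ramps `lambda_u` up over training and keeps an EMA copy of the classifier (see `on_train_batch_end`); this sketch only shows the individual building blocks.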
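The two bookkeeping pieces of `on_train_batch_end`, the EMA update and the linear ramp-up of the unlabeled-loss weight, can also be exercised on their own. This is a toy sketch using the `EMA` class from `torch_utils`; the tiny `nn.Linear` model, the `rampup_length`, and the loop length are assumptions for illustration (the module itself wraps the WideResNet classifier and derives its ramp length from the training schedule).

```python
import torch.nn as nn
from copy import deepcopy
from lightning_ssl.utils.torch_utils import EMA

model = nn.Linear(4, 2)
ema_model = deepcopy(model)
ema = EMA(model, ema_model, decay=0.999)      # shadow copy tracks the trained weights

lambda_u_max, rampup_length = 100.0, 16384    # assumed numbers for illustration
lambda_u = 0.0
for batch_idx in range(5):
    # ... an optimizer step would update `model` here ...
    ema.step()                                 # ema_params <- decay * ema + (1 - decay) * params
    step = lambda_u_max / rampup_length        # same linear schedule as on_train_batch_end
    lambda_u = min(lambda_u + step, lambda_u_max)

print(lambda_u)   # grows linearly toward lambda_u_max, then stays capped
```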