├── .circleci
│   └── config.yml
├── .gitignore
├── Readme.md
├── configs
│   ├── config_mixmatch.yml
│   └── config_mixup.yml
├── imgs
│   ├── flow.drawio
│   ├── flow.pdf
│   └── flow.png
├── lightning_ssl
│   ├── __init__.py
│   ├── dataloader
│   │   ├── __init__.py
│   │   ├── base_data.py
│   │   └── cifar10.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── cnn13.py
│   │   └── wideresnet.py
│   ├── module
│   │   ├── __init__.py
│   │   ├── classifier_module.py
│   │   ├── mixmatch.py
│   │   └── mixmatch_module.py
│   └── utils
│       ├── __init__.py
│       ├── argparser.py
│       ├── torch_utils.py
│       └── utils.py
├── main.py
├── read_from_tb.py
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── dataloader
    │   ├── __init__.py
    │   └── test_base_data.py
    ├── models
    │   ├── __init__.py
    │   ├── test_basemodel.py
    │   ├── test_cnn13.py
    │   └── test_wideresnet.py
    ├── module
    │   ├── __init__.py
    │   ├── assets
    │   │   └── tmp_config.yml
    │   ├── simple_data.py
    │   ├── simple_model.py
    │   ├── test_mixmatch.py
    │   └── test_supervised.py
    ├── test_models
    └── utils
        ├── __init__.py
        ├── assets
        │   └── tmp_config.yml
        ├── test_argparser.py
        ├── test_torch_utils.py
        └── test_utils.py
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 |
3 | orbs:
4 | python: circleci/python@1.0.0
5 |
6 | jobs:
7 | build-and-test:
8 | resource_class: medium
9 | executor: python/default
10 | steps:
11 | - checkout
12 | # - python/load-cache
13 | # - python/install-deps
14 | # - python/save-cache
15 | - run:
16 | name: Installation
17 | command: pip install -e .[test]
18 | - run:
19 | name: Test
20 | command: |
21 | python -m coverage run --source lightning_ssl -m py.test lightning_ssl tests -v --flake8
22 | python -m coverage report -m
23 | coverage html
24 | - store_artifacts:
25 | path: htmlcov
26 |
27 | workflows:
28 | main:
29 | jobs:
30 | - build-and-test
31 |
--------------------------------------------------------------------------------
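The CI above simply installs the package with its test extras and runs the test suite with coverage. Assuming the `[test]` extra in `setup.py` pulls in pytest, flake8, and coverage, the same checks can be reproduced locally:

```
pip install -e .[test]
python -m coverage run --source lightning_ssl -m pytest lightning_ssl tests -v --flake8
python -m coverage report -m
```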
/.gitignore:
--------------------------------------------------------------------------------
1 | **/.DS_Store
2 | .vscode
3 | **/**/__pycache__/
4 | lightning_semi_supervised_learning.egg-info/
5 | **/lightning_logs/
6 | cifar-10*
7 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # Semi-Supervised Learning with PyTorch Lightning
2 |
3 | This project aims to use PyTorch Lightning to implement state-of-the-art algorithms in semi-supervised learning (SSL).
4 |
5 | ## Semi-Supervised Learning
6 | Semi-supervised learning leverages abundant unlabeled samples to improve models when labeled data are scarce. Several assumptions are commonly relied on in semi-supervised learning:
7 |
8 | * Smoothness assumption
9 | * Low-density assumption
10 | * Manifold assumption
11 |
12 | Most approaches impose regularization on the model so that it satisfies these assumptions. In this repository, we first focus on methods that use a consistency loss.
13 |
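As a rough illustration (not code from this repository), a consistency loss simply penalizes the model for predicting differently on two augmented views of the same unlabeled image:

```python
import torch.nn.functional as F

def consistency_loss(model, view_1, view_2):
    # predictions on two augmentations of the same unlabeled batch
    target = F.softmax(model(view_1), dim=-1).detach()  # treated as a fixed target
    pred = F.softmax(model(view_2), dim=-1)
    # penalize disagreement between the two predictive distributions
    return F.mse_loss(pred, target)
```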
14 |
17 |
18 |
19 | ## What is PyTorch Lightning
20 | PyTorch Lightning is a PyTorch wrapper that standardizes the training and testing process of deep learning projects. Projects using PyTorch Lightning can focus on implementing the algorithm, without worrying about complicated engineering parts such as multi-GPU training, 16-bit precision, TensorBoard logging, and TPU training.
21 |
22 | In this project, we leverage PyTorch Lightning as the coding backbone and implement algorithms with minimal changes. The code needed for a new algorithm is put in `module`.
23 |
24 |
25 |
26 | ## Requirements
27 | ```
28 | pip install -r requirements.txt
29 | ```
30 |
31 | ## Module Flow
32 |
33 |
34 |
35 |
36 | * `configs`: Contains the config files for each approach.
37 | * `models`: Contains all the model architectures.
38 | * `dataloader`: Data loaders for each dataset.
39 | * `module`: SSL modules that inherit from `pytorch_lightning.LightningModule`.
40 |
41 | To implement a new method, one usually needs to define a new config, data loader, and PL module (see the sketch below).
42 |
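As a minimal sketch (hypothetical class and file names, not part of this repository), a new method would typically subclass the existing classifier module, override `training_step` with its own loss, and come with a matching YAML entry in `configs/`:

```python
# my_method_module.py -- hypothetical example
from lightning_ssl.module import ClassifierModule


class MyMethodModule(ClassifierModule):
    """Plug a new SSL objective into the existing training loop."""

    def training_step(self, batch, batch_nb):
        labeled_batch, *unlabeled_batches = batch   # same batch layout as MixmatchModule
        x, y = labeled_batch
        # my_custom_loss is a placeholder for the new supervised + unsupervised objective
        loss = self.my_custom_loss(x, y, unlabeled_batches)
        self.log("train/loss", loss)
        return {"loss": loss}
```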
43 | ## Usage
44 |
45 | * Please refer to `argparser.py` for the full list of hyperparameters; every option can also be passed on the command line together with a config file (see the example below).
46 | * `read_from_tb.py` is used to extract the final accuracies from `tensorboard` logs.
47 |
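For instance, a run that selects the dataset explicitly and passes a few flags from `argparser.py` might look like this (values are only illustrative):

```
python main.py -c configs/config_mixmatch.yml -d cifar10 -g 0 --seed 2 --num_labeled 250 --affix mixmatch_250_seed2
```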
48 | ### Fully-Supervised Training
49 |
50 | ```
51 | python main.py -c configs/config_mixup.yml -g [GPU ID] --affix [FILE NAME]
52 | ```
53 |
54 | ### Train for Mixmatch
55 |
56 | ```
57 | python main.py -c configs/config_mixmatch.yml -g [GPU ID] --num_labeled [NUMBER OF LABELED DATA] --affix [FILE NAME]
58 | ```
59 |
60 | ## Results
61 |
62 | ### Supervised Training
63 | The result is the average of three runs (seed=1, 2, 3).
64 |
65 | |       | Error rate (%) |
66 | | :---: | :---: |
67 | | full train with mixup | 4.41±0.03 |
68 |
69 | ### Mixmatch
70 |
71 | The experiments were run five times (seed=1, 2, 3, 4, 5) in the paper, but only three times (seed=1, 2, 3) for this implementation. The numeric columns report the test error rate (%) for the given number of labeled samples.
72 |
73 |
74 | | | time (hr:min) | 250 | 500 | 1000 | 2000 | 4000 |
75 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
76 | | Paper | |11.08±0.87|9.65±0.94 |7.75±0.32|7.03±0.15|6.24±0.06|
77 | | Reproduce |17:24 |10.93±1.20|9.72±0.63 |8.02±0.74| - | - |
78 | | This repo |17:40 |11.10±1.00|10.05±0.45|8.00±0.42|7.13±0.13|6.22±0.08|
79 |
80 |
81 |
87 |
88 | ## Plans
89 |
90 | 1. ReMixMatch
91 | 2. FixMatch
92 | 3. GAN-based methods (DSGAN, BadGAN)
93 | 4. Other approaches using consistency loss (VAT, Mean Teacher)
94 | 5. Polish the code for `CustomSemiDataset` in `dataloader/base_data.py`
95 |
--------------------------------------------------------------------------------
/configs/config_mixmatch.yml:
--------------------------------------------------------------------------------
1 | cifar10:
2 | learning_rate: 2e-3
3 | learning_scenario: semi
4 | algo: mixmatch
5 | weight_decay: 2e-2
6 | ema: 0.999
7 | batch_size: 64
8 | max_steps: 1048576
9 | max_epochs: 2000
10 | gpus: 0
11 | label_smoothing: 0.000
12 | num_augments: 2
13 | alpha: 0.75
14 | lambda_u: 75
15 | T: 0.5
16 | num_labeled: 250
17 | num_val: 10
18 | affix: "@250_brand_new"
19 |
--------------------------------------------------------------------------------
/configs/config_mixup.yml:
--------------------------------------------------------------------------------
1 | cifar10:
2 | learning_rate: 2e-3
3 | learning_scenario: supervised
4 | weight_decay: 2e-2
5 | ema: 0.999
6 | batch_size: 64
7 | max_steps: 1048576
8 | max_epochs: 2000
9 | gpus: 0
10 | label_smoothing: 0.001
11 | alpha: 1
12 | algo: mixup
13 | num_val: 10
14 | affix: fs_mixup
15 |
--------------------------------------------------------------------------------
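Each top-level key in these YAML files names a dataset (`cifar10` here), and the keys underneath are argparse destinations from `lightning_ssl/utils/argparser.py`; `parser()` loads the block matching the `-d/--dataset` flag as defaults. A hypothetical block for another dataset would look like the following (the repository would also need a matching data module for it):

```yaml
svhn:
  learning_rate: 2e-3
  learning_scenario: semi
  algo: mixmatch
  num_labeled: 250
```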
/imgs/flow.drawio:
--------------------------------------------------------------------------------
1 | 3VjBcpswEP0aH9MxCLB9TBy3nU467QydadybgmRQKlhXFjHO11cYKUIxcdLWDnEvtvattEj79FikAZrm1QeBl9lnIJQP/CGpBuhy4PuTIFK/NbBpgDAIGyAVjDSQZ4GY3VMNDjVaMkJXTkcJwCVbumACRUET6WBYCFi73RbA3acucUp3gDjBfBf9zojMGnTsjyz+kbI0M0/2oknjybHprFeyyjCBdQtCswGaCgDZtPJqSnmdO5OXZtz7J7wPExO0kC8ZMP8Roy/ZND7js+z2Pg0+edPZmd9EucO81AvWk5UbkwEBZUFoHWQ4QBfrjEkaL3FSe9eKcoVlMufK8lRzwTifAgexHYsWCxolicJXUsBP2vKQ0eRmWAfUE6BC0urJlXkP+VL7jEJOpdioLnoAGukU6z1mMr62hKFIY1mLLIQ0iPUmSR9C2zyqhk7lH6TVQztppETtK22CkBmkUGA+s+iFm2jb5wpgqdN7S6XcaJHgUoKbfFoxed1qz+tQ70JtXVY68tbYGKNQ6722HWtz3vbZYVvLjGvWVy9qP2kqB1CKhO5LlpY1FimV+/oNu3eBoBxLdudO5PCUhn1SammctzwnT2nQK6NRr4z+lUgdPi29b4dR1CujHdUs4rKuS6AW3KY6+lWCcZyttmSdqw5+sKysU7XS+l99XCxYamKpqTXhGueeguk9XzAPUP4CL3DKX7hb/jzUUf4mR6t+ux8Rb7T6HVAg6CSqGDqSQgiWWDk5YELFqQjlDSil6431+tI5oAyCk5BBcCQZ5M3x9zS2v9f79g//mQZv3EVDHF/V0+UpCCaz/KWEqNRKN+vusbWAgj4642oIc5YWykwUFer9hy5qoliC+bl25IyQrVa7aHb1ewCmo6F7IEYvJNo/FtHRkfT29crEuREG+yYwK2wRso6j6PLxhcc4od0XHjfjMAgPxO/o0YVH/0r2ut6oJ13IzAXks5Us6rOQmVm+irJUdStVrP9GRuH49WSkTHvVu/W17svR7Dc=
--------------------------------------------------------------------------------
/imgs/flow.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/imgs/flow.pdf
--------------------------------------------------------------------------------
/imgs/flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/imgs/flow.png
--------------------------------------------------------------------------------
/lightning_ssl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/lightning_ssl/__init__.py
--------------------------------------------------------------------------------
/lightning_ssl/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar10 import (
2 | SemiCIFAR10Module,
3 | SupervisedCIFAR10Module,
4 | )
5 |
--------------------------------------------------------------------------------
/lightning_ssl/dataloader/base_data.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import math
3 | import torch
4 | import random
5 | import itertools
6 | import numpy as np
7 |
8 | from torch.utils.data import DataLoader, SubsetRandomSampler, Dataset
9 | import pytorch_lightning as pl
10 |
11 |
12 | def get_split_indices(labels, num_labeled, num_val, _n_classes):
13 | """
14 | Split the train data into the following three set:
15 | (1) labeled data
16 | (2) unlabeled data
17 | (3) val data
18 |
19 |     The class distribution of the three sets is the same as that of the
20 |     original training data.
21 |
22 | Inputs:
23 | labels: (np.int) array of labels
24 | num_labeled: (int)
25 | num_val: (int)
26 | _n_classes: (int)
27 |
28 |
29 | Return:
30 | the three indices for the three sets
31 | """
32 | # val_per_class = num_val // _n_classes
33 | val_indices = []
34 | train_indices = []
35 |
36 | num_total = len(labels)
37 | num_per_class = []
38 | for c in range(_n_classes):
39 | num_per_class.append((labels == c).sum().astype(int))
40 |
41 | # obtain val indices, data evenly drawn from each class
42 | for c, num_class in zip(range(_n_classes), num_per_class):
43 | val_this_class = max(int(num_val * (num_class / num_total)), 1)
44 | class_indices = np.where(labels == c)[0]
45 | np.random.shuffle(class_indices)
46 | val_indices.append(class_indices[:val_this_class])
47 | train_indices.append(class_indices[val_this_class:])
48 |
49 | # split data into labeled and unlabeled
50 | labeled_indices = []
51 | unlabeled_indices = []
52 |
53 | # num_labeled_per_class = num_labeled // _n_classes
54 |
55 | for c, num_class in zip(range(_n_classes), num_per_class):
56 | num_labeled_this_class = max(int(num_labeled * (num_class / num_total)), 1)
57 | labeled_indices.append(train_indices[c][:num_labeled_this_class])
58 | unlabeled_indices.append(train_indices[c][num_labeled_this_class:])
59 |
60 | labeled_indices = np.hstack(labeled_indices)
61 | unlabeled_indices = np.hstack(unlabeled_indices)
62 | val_indices = np.hstack(val_indices)
63 |
64 | return labeled_indices, unlabeled_indices, val_indices
65 |
66 |
67 | class Subset(Dataset):
68 | r"""
69 | Subset of a dataset at specified indices.
70 |
71 | Arguments:
72 | dataset (Dataset): The whole Dataset
73 | indices (sequence): Indices in the whole set selected for subset
74 | """
75 |
76 | def __init__(self, dataset, indices, transform=None):
77 | self.dataset = dataset
78 | self.indices = indices
79 | self.transform = transform
80 |
81 | nums = [0 for _ in range(10)]
82 | for i in range(len(self.indices)):
83 | nums[self.dataset[self.indices[i]][1]] += 1
84 |
85 | print(nums)
86 | print(np.sum(nums))
87 |
88 | def __getitem__(self, idx):
89 | data, label = self.dataset[self.indices[idx]]
90 | if self.transform is not None:
91 | data = self.transform(data)
92 | return data, label
93 |
94 | def __len__(self):
95 | return len(self.indices)
96 |
97 |
98 | class MultiDataset(Dataset):
99 | """
100 | MultiDataset is used for training multiple datasets together. The lengths of the datasets
101 | should be the same.
102 | """
103 |
104 | def __init__(self, datasets):
105 | super(MultiDataset, self).__init__()
106 | assert len(datasets) > 1, "You should use at least two datasets"
107 |
108 | for d in datasets[1:]:
109 | assert len(d) == len(
110 | datasets[0]
111 | ), "The lengths of the datasets should be the same."
112 |
113 | self.datasets = datasets
114 | self.max_length = max([len(d) for d in self.datasets])
115 |
116 | def __getitem__(self, idx):
117 | return tuple([d[idx] for d in self.datasets])
118 |
119 | def __len__(self):
120 | return self.max_length
121 |
122 |
123 | class MagicClass(object):
124 | """
125 | Codes are borrowed from https://github.com/PyTorchLightning/pytorch-lightning/pull/1959
126 | """
127 |
128 | def __init__(self, data) -> None:
129 | self.d = data
130 | self.l = max([len(d) for d in self.d])
131 |
132 | def __len__(self) -> int:
133 | return self.l
134 |
135 | def __iter__(self):
136 | if isinstance(self.d, list):
137 | gen = [None for v in self.d]
138 |
139 | # for k,v in self.d.items():
140 | # # gen[k] = itertools.cycle(v)
141 | # gen[k] = iter(v)
142 |
143 | for i in range(self.l):
144 | rv = [None for v in self.d]
145 | for k, v in enumerate(self.d):
146 | # If reaching the end of the iterator, recreate one
147 | # because shuffle=True in dataloader, the iter will return a different order
148 | if i % len(v) == 0:
149 | gen[k] = iter(v)
150 | rv[k] = next(gen[k])
151 |
152 | yield rv
153 |
154 | else:
155 | gen = itertools.cycle(self.d)
156 | for i in range(self.l):
157 | batch = next(gen)
158 | yield batch
159 |
160 |
161 | class CustomSemiDataset(Dataset):
162 | def __init__(self, datasets):
163 | self.datasets = datasets
164 |
165 | self.map_indices = [[] for _ in self.datasets]
166 | self.min_length = min(len(d) for d in self.datasets)
167 | self.max_length = max(len(d) for d in self.datasets)
168 |
169 | def __getitem__(self, i):
170 | # return tuple(d[i] for d in self.datasets)
171 |
172 |         # self.map_indices is rebuilt every epoch by construct_map_index() (called from train_dataloader())
173 | return tuple(d[m[i]] for d, m in zip(self.datasets, self.map_indices))
174 |
175 | def construct_map_index(self):
176 | """
177 |         Construct the mapping indices for every dataset. Because __len__ may be larger than the size
178 |         of some datasets, map_indices is used to map the parameter "index" in __getitem__ to a valid
179 |         index of each dataset. Because the datasets have different lengths, we maintain separate indices for them.
180 | """
181 |
182 | def update_indices(original_indices, data_length, max_data_length):
183 | # update the sampling indices for this dataset
184 |
185 |             # return: a list that maps range(max_data_length) to valid indices of the dataset
186 |
187 | original_indices = original_indices[max_data_length:] # remove used indices
188 | fill_num = max_data_length - len(original_indices)
189 | batch = math.ceil(fill_num / data_length)
190 |
191 | additional_indices = list(range(data_length)) * batch
192 | random.shuffle(additional_indices)
193 |
194 | original_indices += additional_indices
195 |
196 | assert (
197 | len(original_indices) >= max_data_length
198 |             ), "the length of matching indices is too small"
199 |
200 | return original_indices
201 |
202 |         # use the same mapping indices for all unlabeled datasets for data consistency
203 |         # (the 0-th dataset is the labeled one)
204 | self.map_indices = [
205 | update_indices(m, len(d), self.max_length)
206 | for m, d in zip(self.map_indices, self.datasets)
207 | ]
208 |
209 |         # use the same mapping indices for all unlabeled datasets for data consistency
210 |         # (the 0-th dataset is the labeled one; every unlabeled dataset reuses dataset 1's indices)
211 | for i in range(1, len(self.map_indices)):
212 | self.map_indices[i] = self.map_indices[1]
213 |
214 | def __len__(self):
215 | # will be called every epoch
216 | return self.max_length
217 |
218 |
219 | class DataModuleBase(pl.LightningDataModule):
220 | labeled_indices: ...
221 | unlabeled_indices: ...
222 | val_indices: ...
223 |
224 | def __init__(
225 | self, data_root, num_workers, batch_size, num_labeled, num_val, n_classes
226 | ):
227 | super().__init__()
228 | self.data_root = data_root
229 | self.batch_size = batch_size
230 | self.num_labeled = num_labeled
231 | self.num_val = num_val
232 | self._n_classes = n_classes
233 |
234 |         self.train_transform = None  # TODO: implement this in your custom dataset
235 |         self.test_transform = None  # TODO: implement this in your custom dataset
236 |
237 | self.train_set = None
238 | self.val_set = None
239 | self.test_set = None
240 |
241 | self.num_workers = num_workers
242 |
243 | def train_dataloader(self):
244 | # get and process the data first
245 |
246 | return DataLoader(
247 | self.train_set,
248 | batch_size=self.batch_size,
249 | shuffle=True,
250 | num_workers=self.num_workers,
251 | pin_memory=True,
252 | drop_last=True,
253 | )
254 |
255 | def val_dataloader(self):
256 | # return both val and test loader
257 |
258 | val_loader = DataLoader(
259 | self.val_set,
260 | batch_size=self.batch_size,
261 | shuffle=False,
262 | num_workers=self.num_workers,
263 | pin_memory=True,
264 | )
265 |
266 | test_loader = DataLoader(
267 | self.test_set,
268 | batch_size=self.batch_size,
269 | num_workers=self.num_workers,
270 | pin_memory=True,
271 | )
272 |
273 | return [val_loader, test_loader]
274 |
275 | def test_dataloader(self):
276 |
277 | return DataLoader(
278 | self.test_set,
279 | batch_size=self.batch_size,
280 | num_workers=self.num_workers,
281 | pin_memory=True,
282 | )
283 |
284 | @property
285 | def n_classes(self):
286 |         # self._n_classes is set in __init__()
287 | return self._n_classes
288 |
289 | @property
290 | def num_labeled_data(self):
291 |         assert self.train_set is not None, (
292 |             "Load train data before calling num_labeled_data"
293 |         )
294 | return len(self.labeled_indices)
295 |
296 | @property
297 | def num_unlabeled_data(self):
298 |         assert self.train_set is not None, (
299 |             "Load train data before calling num_unlabeled_data"
300 |         )
301 | return len(self.unlabeled_indices)
302 |
303 | @property
304 | def num_val_data(self):
305 |         assert self.train_set is not None, (
306 |             "Load train data before calling num_val_data"
307 |         )
308 | return len(self.val_indices)
309 |
310 | @property
311 | def num_test_data(self):
312 |         assert self.test_set is not None, (
313 |             "Load test data before calling num_test_data"
314 |         )
315 | return len(self.test_set)
316 |
317 |
318 | class SemiDataModule(DataModuleBase):
319 | """
320 | Data module for semi-supervised tasks. self.prepare_data() is not implemented. For custom dataset,
321 | inherit this class and implement self.prepare_data().
322 | """
323 |
324 | def __init__(
325 | self,
326 | data_root,
327 | num_workers,
328 | batch_size,
329 | num_labeled,
330 | num_val,
331 | num_augments,
332 | n_classes,
333 | ):
334 | super(SemiDataModule, self).__init__(
335 | data_root, num_workers, batch_size, num_labeled, num_val, n_classes
336 | )
337 | self.num_augments = num_augments
338 |
339 | def setup(self):
340 | # prepare train and val dataset, and split the train dataset
341 | # into labeled and unlabeled groups.
342 |         assert (
343 |             self.train_set is not None
344 |         ), "self.train_set should be created in self.prepare_data() before setup()"
345 |
346 | indices = np.arange(len(self.train_set))
347 | ys = np.array([self.train_set[i][1] for i in indices], dtype=np.int64)
348 | # np.random.shuffle(ys)
349 | # get the number of classes
350 | # self._n_classes = len(np.unique(ys))
351 |
352 | (
353 | self.labeled_indices,
354 | self.unlabeled_indices,
355 | self.val_indices,
356 | ) = get_split_indices(ys, self.num_labeled, self.num_val, self._n_classes)
357 |
358 | self.val_set = Subset(self.train_set, self.val_indices, self.test_transform)
359 |
360 | # unlabeled_list = [
361 | # Subset(self.train_set, self.unlabeled_indices, self.train_transform) \
362 | # for _ in range(self.num_augments)
363 | # ]
364 |
365 | # self.unlabeled_set = MultiDataset(unlabeled_list)
366 | # self.labeled_set = Subset(self.train_set, self.labeled_indices, self.train_transform)
367 |
368 | train_list = [
369 | Subset(self.train_set, self.unlabeled_indices, self.train_transform)
370 | for _ in range(self.num_augments)
371 | ]
372 |
373 | train_list.insert(
374 | 0, Subset(self.train_set, self.labeled_indices, self.train_transform)
375 | )
376 |
377 | self.train_set = CustomSemiDataset(train_list)
378 |
379 | # def train_dataloader(self):
380 | # # get and process the data first
381 | # if self.labeled_set is None:
382 | # self._prepare_train_dataset()
383 |
384 | # labeled_loader = DataLoader(self.labeled_set,
385 | # batch_size=self.batch_size,
386 | # shuffle=True,
387 | # num_workers=self.num_workers,
388 | # pin_memory=True,
389 | # drop_last=True)
390 |
391 | # unlabeled_loader = DataLoader(self.unlabeled_set,
392 | # batch_size=self.batch_size,
393 | # shuffle=True,
394 | # num_workers=self.num_workers,
395 | # pin_memory=True,
396 | # drop_last=True)
397 |
398 | # return MagicClass([labeled_loader, unlabeled_loader])
399 |
400 | def train_dataloader(self):
401 | # get and process the data first
402 |
403 | self.train_set.construct_map_index()
404 |
405 | print("\ncalled\n")
406 |
407 | return DataLoader(
408 | self.train_set,
409 | batch_size=self.batch_size,
410 | shuffle=True,
411 | num_workers=self.num_workers,
412 | pin_memory=True,
413 | drop_last=True,
414 | )
415 |
416 |
417 | class SupervisedDataModule(DataModuleBase):
418 | """
419 | Data module for supervised tasks. self.prepare_data() is not implemented. For custom dataset,
420 | inherit this class and implement self.prepare_data().
421 | """
422 |
423 | def __init__(
424 | self, data_root, num_workers, batch_size, num_labeled, num_val, n_classes
425 | ):
426 | super(SupervisedDataModule, self).__init__(
427 | data_root, num_workers, batch_size, num_labeled, num_val, n_classes
428 | )
429 |
430 | def setup(self):
431 | # prepare train and val dataset
432 |         assert (
433 |             self.train_set is not None
434 |         ), "self.train_set should be created in self.prepare_data() before setup()"
435 |
436 | indices = np.arange(len(self.train_set))
437 | ys = np.array([self.train_set[i][1] for i in indices], dtype=np.int64)
438 | # get the number of classes
439 | # self._n_classes = len(np.unique(ys))
440 |
441 | (
442 | self.labeled_indices,
443 | self.unlabeled_indices,
444 | self.val_indices,
445 | ) = get_split_indices(ys, self._n_classes, self.num_val, self._n_classes)
446 |
447 | self.labeled_indices = np.hstack((self.labeled_indices, self.unlabeled_indices))
448 | self.unlabeled_indices = [] # dummy. only for printing length
449 |
450 | self.val_set = Subset(self.train_set, self.val_indices, self.test_transform)
451 | self.train_set = Subset(
452 | self.train_set, self.labeled_indices, self.train_transform
453 | )
454 |
--------------------------------------------------------------------------------
/lightning_ssl/dataloader/cifar10.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lightning_ssl.dataloader.base_data import (
3 | SemiDataModule,
4 | SupervisedDataModule,
5 | )
6 |
7 | import torchvision as tv
8 | from torch.utils.data import DataLoader, SubsetRandomSampler
9 | from torchvision.datasets import CIFAR10
10 |
11 |
12 | CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
13 | CIFAR_STD = (0.2471, 0.2435, 0.2616)
14 |
15 |
16 | class SemiCIFAR10Module(SemiDataModule):
17 | def __init__(
18 | self,
19 | args,
20 | data_root,
21 | num_workers,
22 | batch_size,
23 | num_labeled,
24 | num_val,
25 | num_augments,
26 | ):
27 | n_classes = 10
28 | super(SemiCIFAR10Module, self).__init__(
29 | data_root,
30 | num_workers,
31 | batch_size,
32 | num_labeled,
33 | num_val,
34 | num_augments,
35 | n_classes,
36 | )
37 |
38 | self.train_transform = tv.transforms.Compose(
39 | [
40 | tv.transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
41 | tv.transforms.RandomHorizontalFlip(),
42 | tv.transforms.ToTensor(),
43 | tv.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
44 | ]
45 | )
46 |
47 | self.test_transform = tv.transforms.Compose(
48 | [
49 | tv.transforms.ToTensor(),
50 | tv.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
51 | ]
52 | )
53 |
54 | def prepare_data(self):
55 |         # the transformations for the train and validation datasets are
56 |         # applied later, in setup(), through the Subset wrappers
57 |
58 | self.train_set = CIFAR10(
59 | self.data_root, train=True, download=True, transform=None
60 | )
61 |
62 | self.test_set = CIFAR10(
63 | self.data_root, train=False, download=True, transform=self.test_transform
64 | )
65 |
66 |
67 | class SupervisedCIFAR10Module(SupervisedDataModule):
68 | def __init__(
69 | self,
70 | args,
71 | data_root,
72 | num_workers,
73 | batch_size,
74 | num_labeled,
75 | num_val,
76 | num_augments,
77 | ):
78 | n_classes = 10
79 | super(SupervisedCIFAR10Module, self).__init__(
80 | data_root, num_workers, batch_size, num_labeled, num_val, n_classes
81 | )
82 |
83 | self.train_transform = tv.transforms.Compose(
84 | [
85 | tv.transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
86 | tv.transforms.RandomHorizontalFlip(),
87 | tv.transforms.ToTensor(),
88 | tv.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
89 | ]
90 | )
91 |
92 | self.test_transform = tv.transforms.Compose(
93 | [
94 | tv.transforms.ToTensor(),
95 | tv.transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
96 | ]
97 | )
98 |
99 | def prepare_data(self):
100 |         # the transformations for the train and validation datasets are
101 |         # applied later, in setup(), through the Subset wrappers
102 | self.train_set = CIFAR10(
103 | self.data_root, train=True, download=True, transform=None
104 | )
105 |
106 | self.test_set = CIFAR10(
107 | self.data_root, train=False, download=True, transform=self.test_transform
108 | )
109 |
--------------------------------------------------------------------------------
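The two CIFAR-10 modules above follow the usual `LightningDataModule` flow: `prepare_data()` downloads CIFAR-10, `setup()` performs the stratified labeled/unlabeled/validation split, and the `*_dataloader()` methods serve batches. A rough, hand-driven usage sketch (argument values are illustrative only):

```python
from lightning_ssl.dataloader import SemiCIFAR10Module

dm = SemiCIFAR10Module(
    args=None,            # accepted but not used by the constructor
    data_root=".",        # where CIFAR-10 is downloaded
    num_workers=4,
    batch_size=64,
    num_labeled=250,      # number of labeled training examples to keep
    num_val=5000,         # number of examples held out for validation
    num_augments=2,       # K augmented unlabeled views per sample
)
dm.prepare_data()                      # download CIFAR-10
dm.setup()                             # split into labeled / unlabeled / val subsets
labeled, *unlabeled = next(iter(dm.train_dataloader()))
x, y = labeled                         # labeled images and targets
u1, _ = unlabeled[0]                   # first augmented unlabeled view (labels unused)
```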
/lightning_ssl/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .cnn13 import CNN13
2 | from .wideresnet import WideResNet
3 |
--------------------------------------------------------------------------------
/lightning_ssl/models/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from lightning_ssl.utils.torch_utils import sharpening
6 |
7 |
8 | class CustomModel(nn.Module):
9 | def __init__(self):
10 | super(CustomModel, self).__init__()
11 | self.batchnorms, self.momentums = None, None
12 |
13 | @torch.no_grad()
14 | def psuedo_label(self, unlabeled_xs, temperature, batch_inference=False):
15 | """
16 | Generate the pseudo labels for given unlabeled_xs.
17 | Args
18 | unlabeled_xs: list of unlabeled data (torch.FloatTensor),
19 | e.g. [unlabeled_data_1, unlabeled_data_2, ..., unlabeled_data_n]
20 |             Note that the i-th elements of those unlabeled tensors should have the same
21 |             semantic meaning (i.e. they come from different augmentations of the same image).
22 |         temperature: the temperature used to sharpen the averaged softmax prediction
23 | """
24 | if batch_inference:
25 | batch_size, num_augmentations = unlabeled_xs[0].shape[0], len(unlabeled_xs)
26 | unlabeled_xs = torch.cat(unlabeled_xs, dim=0)
27 | p_labels = F.softmax(self(unlabeled_xs), dim=-1)
28 | p_labels = p_labels.view(num_augmentations, batch_size, -1)
29 | p_label = p_labels.mean(0).detach()
30 | else:
31 | p_labels = [F.softmax(self(x), dim=-1) for x in unlabeled_xs]
32 | p_label = torch.stack(p_labels).mean(0).detach()
33 |
34 | return sharpening(p_label, temperature)
35 |
36 | def extract_norm_n_momentum(self):
37 | # Extract the batchnorms and their momentum
38 | original_momentums = []
39 | batchnorms = []
40 | for module in self.modules():
41 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
42 | original_momentums.append(module.momentum)
43 | batchnorms.append(module)
44 |
45 | return batchnorms, original_momentums
46 |
47 | def extract_running_stats(self):
48 | # Extract the running stats of batchnorms
49 | running_stats = []
50 | for module in self.modules():
51 | if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
52 | running_stats.append([module.running_mean, module.running_var])
53 |
54 | return running_stats
55 |
56 | def freeze_running_stats(self):
57 | # Set the batchnorms' momentum to 0 to freeze the running stats
58 | # First call
59 | if None in [self.batchnorms, self.momentums]:
60 | self.batchnorms, self.momentums = self.extract_norm_n_momentum()
61 |
62 | for module in self.batchnorms:
63 | module.momentum = 0
64 |
65 | def recover_running_stats(self):
66 | # Recover the batchnorms' momentum to make running stats updatable
67 | if None in [self.batchnorms, self.momentums]:
68 | return
69 | for module, momentum in zip(self.batchnorms, self.momentums):
70 | module.momentum = momentum
71 |
--------------------------------------------------------------------------------
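For intuition, `psuedo_label` above averages the model's softmax predictions over the K augmented views of an unlabeled batch and then sharpens the averaged distribution with a temperature. A small sketch of the equivalent computation (assuming the standard MixMatch-style sharpening; the actual helper lives in `lightning_ssl/utils/torch_utils.py`):

```python
import torch
import torch.nn.functional as F
from lightning_ssl.models import CNN13

model = CNN13(num_classes=10).eval()
views = [torch.randn(4, 3, 32, 32) for _ in range(2)]   # K=2 augmented views of a batch

with torch.no_grad():
    probs = torch.stack([F.softmax(model(v), dim=-1) for v in views]).mean(0)
    T = 0.5                                              # temperature from the config
    sharpened = probs ** (1.0 / T)                       # MixMatch-style sharpening
    sharpened = sharpened / sharpened.sum(dim=-1, keepdim=True)

# The model method does the same thing internally:
# guess = model.psuedo_label(views, temperature=0.5)
```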
/lightning_ssl/models/cnn13.py:
--------------------------------------------------------------------------------
1 | # The code is borrowed from
2 | # https://github.com/benathi/fastswa-semi-sup/blob/master/mean_teacher/architectures.py
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.utils import weight_norm
7 |
8 | from lightning_ssl.models.base_model import CustomModel
9 |
10 |
11 | class GaussianNoise(nn.Module):
12 | def __init__(self, std):
13 | super(GaussianNoise, self).__init__()
14 | self.std = std
15 |
16 | def forward(self, x):
17 | zeros_ = torch.zeros(x.size()).to(x.device)
18 | n = torch.normal(zeros_, std=self.std).to(x.device)
19 | return x + n
20 |
21 |
22 | class CNN13(CustomModel):
23 | """
24 | CNN from Mean Teacher paper
25 | """
26 |
27 | def __init__(self, num_classes=10):
28 | super(CNN13, self).__init__()
29 |
30 | self.gn = GaussianNoise(0.15)
31 | self.activation = nn.LeakyReLU(0.1)
32 | self.conv1a = weight_norm(nn.Conv2d(3, 128, 3, padding=1))
33 | self.bn1a = nn.BatchNorm2d(128)
34 | self.conv1b = weight_norm(nn.Conv2d(128, 128, 3, padding=1))
35 | self.bn1b = nn.BatchNorm2d(128)
36 | self.conv1c = weight_norm(nn.Conv2d(128, 128, 3, padding=1))
37 | self.bn1c = nn.BatchNorm2d(128)
38 | self.mp1 = nn.MaxPool2d(2, stride=2, padding=0)
39 | self.drop1 = nn.Dropout(0.5)
40 |
41 | self.conv2a = weight_norm(nn.Conv2d(128, 256, 3, padding=1))
42 | self.bn2a = nn.BatchNorm2d(256)
43 | self.conv2b = weight_norm(nn.Conv2d(256, 256, 3, padding=1))
44 | self.bn2b = nn.BatchNorm2d(256)
45 | self.conv2c = weight_norm(nn.Conv2d(256, 256, 3, padding=1))
46 | self.bn2c = nn.BatchNorm2d(256)
47 | self.mp2 = nn.MaxPool2d(2, stride=2, padding=0)
48 | self.drop2 = nn.Dropout(0.5)
49 |
50 | self.conv3a = weight_norm(nn.Conv2d(256, 512, 3, padding=0))
51 | self.bn3a = nn.BatchNorm2d(512)
52 | self.conv3b = weight_norm(nn.Conv2d(512, 256, 1, padding=0))
53 | self.bn3b = nn.BatchNorm2d(256)
54 | self.conv3c = weight_norm(nn.Conv2d(256, 128, 1, padding=0))
55 | self.bn3c = nn.BatchNorm2d(128)
56 | self.ap3 = nn.AvgPool2d(6, stride=2, padding=0)
57 |
58 | self.fc1 = weight_norm(nn.Linear(128, num_classes))
59 |
60 | def forward(self, x, debug=False):
61 | x = self.activation(self.bn1a(self.conv1a(x)))
62 | x = self.activation(self.bn1b(self.conv1b(x)))
63 | x = self.activation(self.bn1c(self.conv1c(x)))
64 | x = self.mp1(x)
65 | x = self.drop1(x)
66 |
67 | x = self.activation(self.bn2a(self.conv2a(x)))
68 | x = self.activation(self.bn2b(self.conv2b(x)))
69 | x = self.activation(self.bn2c(self.conv2c(x)))
70 | x = self.mp2(x)
71 | x = self.drop2(x)
72 |
73 | x = self.activation(self.bn3a(self.conv3a(x)))
74 | x = self.activation(self.bn3b(self.conv3b(x)))
75 | x = self.activation(self.bn3c(self.conv3c(x)))
76 | x = self.ap3(x)
77 |
78 | x = x.view(-1, 128)
79 |
80 | return self.fc1(x)
81 |
--------------------------------------------------------------------------------
/lightning_ssl/models/wideresnet.py:
--------------------------------------------------------------------------------
1 | # The code is borrowed from
2 | # https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
3 |
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from lightning_ssl.models.base_model import CustomModel
10 |
11 | BIAS = False
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, pre_activate=False):
16 | super(BasicBlock, self).__init__()
17 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001)
18 | self.relu1 = nn.LeakyReLU(0.1, inplace=True)
19 | self.conv1 = nn.Conv2d(
20 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=BIAS
21 | )
22 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001)
23 | self.relu2 = nn.LeakyReLU(0.1, inplace=True)
24 | self.conv2 = nn.Conv2d(
25 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=BIAS
26 | )
27 | self.droprate = dropRate
28 | self.equalInOut = in_planes == out_planes
29 | self.convShortcut = (
30 | (not self.equalInOut)
31 | and nn.Conv2d(
32 | in_planes,
33 | out_planes,
34 | kernel_size=1,
35 | stride=stride,
36 | padding=0,
37 | bias=BIAS,
38 | )
39 | or None
40 | )
41 | self.pre_activate = pre_activate
42 |
43 | def forward(self, x):
44 | # if not self.equalInOut:
45 | out = self.relu1(self.bn1(x))
46 | if self.pre_activate:
47 | x = out
48 | # else:
49 | # out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
50 | out = self.relu2(self.bn2(self.conv1(out)))
51 | if self.droprate > 0:
52 | out = F.dropout(out, p=self.droprate, training=self.training)
53 | out = self.conv2(out)
54 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
55 |
56 |
57 | class NetworkBlock(nn.Module):
58 | def __init__(
59 | self,
60 | nb_layers,
61 | in_planes,
62 | out_planes,
63 | block,
64 | stride,
65 | dropRate=0.0,
66 | pre_activate=False,
67 | ):
68 | super(NetworkBlock, self).__init__()
69 | self.layer = self._make_layer(
70 | block, in_planes, out_planes, nb_layers, stride, dropRate, pre_activate
71 | )
72 |
73 | def _make_layer(
74 | self, block, in_planes, out_planes, nb_layers, stride, dropRate, pre_activate
75 | ):
76 | layers = []
77 | for i in range(int(nb_layers)):
78 | layers.append(
79 | block(
80 | i == 0 and in_planes or out_planes,
81 | out_planes,
82 | i == 0 and stride or 1,
83 | dropRate,
84 | pre_activate and i == 0,
85 | )
86 | )
87 | return nn.Sequential(*layers)
88 |
89 | def forward(self, x):
90 | return self.layer(x)
91 |
92 |
93 | class WideResNet(CustomModel):
94 | def __init__(self, depth, num_classes, widen_factor=2, dropRate=0.0):
95 | super(WideResNet, self).__init__()
96 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
97 | assert (depth - 4) % 6 == 0
98 | n = (depth - 4) / 6
99 | block = BasicBlock
100 | # 1st conv before any network block
101 | self.conv1 = nn.Conv2d(
102 | 3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=BIAS
103 | )
104 | # 1st block
105 | self.block1 = NetworkBlock(
106 | n, nChannels[0], nChannels[1], block, 1, dropRate, True
107 | )
108 | # 2nd block
109 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
110 | # 3rd block
111 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
112 | # global average pooling and classifier
113 | self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
114 | self.relu = nn.LeakyReLU(0.1, inplace=True)
115 | self.fc = nn.Linear(nChannels[3], num_classes)
116 | self.nChannels = nChannels[3]
117 |
118 | for m in self.modules():
119 | if isinstance(m, nn.Conv2d):
120 | n = 0.5 * m.kernel_size[0] * m.kernel_size[1] * m.out_channels
121 | m.weight.data.normal_(0, math.sqrt(1.0 / n))
122 | if BIAS:
123 | m.bias.data.zero_()
124 | elif isinstance(m, nn.BatchNorm2d):
125 | m.weight.data.fill_(1)
126 | m.bias.data.zero_()
127 | elif isinstance(m, nn.Linear):
128 | nn.init.xavier_normal_(m.weight.data)
129 | m.bias.data.zero_()
130 |
131 | def forward(self, x):
132 | out = self.conv1(x)
133 | out = self.block1(out)
134 | out = self.block2(out)
135 | out = self.block3(out)
136 | out = self.relu(self.bn1(out))
137 | out = F.avg_pool2d(out, 8)
138 | out = out.view(-1, self.nChannels)
139 | return self.fc(out)
140 |
141 |
142 | # class BasicBlock(nn.Module):
143 | # def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
144 | # super(BasicBlock, self).__init__()
145 | # self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001)
146 | # self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
147 | # self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
148 | # padding=1, bias=False)
149 | # self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001)
150 | # self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
151 | # self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
152 | # padding=1, bias=False)
153 | # self.droprate = dropRate
154 | # self.equalInOut = (in_planes == out_planes)
155 | # self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
156 | # padding=0, bias=False) or None
157 | # self.activate_before_residual = activate_before_residual
158 | # def forward(self, x):
159 | # if not self.equalInOut and self.activate_before_residual == True:
160 | # x = self.relu1(self.bn1(x))
161 | # else:
162 | # out = self.relu1(self.bn1(x))
163 | # out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
164 | # if self.droprate > 0:
165 | # out = F.dropout(out, p=self.droprate, training=self.training)
166 | # out = self.conv2(out)
167 | # return torch.add(x if self.equalInOut else self.convShortcut(x), out)
168 |
169 | # class NetworkBlock(nn.Module):
170 | # def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activate_before_residual=False):
171 | # super(NetworkBlock, self).__init__()
172 | # self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual)
173 | # def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual):
174 | # layers = []
175 | # for i in range(int(nb_layers)):
176 | # layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, activate_before_residual))
177 | # return nn.Sequential(*layers)
178 | # def forward(self, x):
179 | # return self.layer(x)
180 |
181 | # class WideResNet(nn.Module):
182 | # def __init__(self, num_classes, depth=28, widen_factor=2, dropRate=0.0):
183 | # super(WideResNet, self).__init__()
184 | # nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
185 | # assert((depth - 4) % 6 == 0)
186 | # n = (depth - 4) / 6
187 | # block = BasicBlock
188 | # # 1st conv before any network block
189 | # self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
190 | # padding=1, bias=False)
191 | # # 1st block
192 | # self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
193 | # # 2nd block
194 | # self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
195 | # # 3rd block
196 | # self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
197 | # # global average pooling and classifier
198 | # self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
199 | # self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
200 | # self.fc = nn.Linear(nChannels[3], num_classes)
201 | # self.nChannels = nChannels[3]
202 |
203 | # for m in self.modules():
204 | # if isinstance(m, nn.Conv2d):
205 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
206 | # m.weight.data.normal_(0, math.sqrt(2. / n))
207 | # elif isinstance(m, nn.BatchNorm2d):
208 | # m.weight.data.fill_(1)
209 | # m.bias.data.zero_()
210 | # elif isinstance(m, nn.Linear):
211 | # nn.init.xavier_normal_(m.weight.data)
212 | # m.bias.data.zero_()
213 |
214 | # def forward(self, x):
215 | # out = self.conv1(x)
216 | # out = self.block1(out)
217 | # out = self.block2(out)
218 | # out = self.block3(out)
219 | # out = self.relu(self.bn1(out))
220 | # out = F.avg_pool2d(out, 8)
221 | # out = out.view(-1, self.nChannels)
222 | # return self.fc(out)
223 |
--------------------------------------------------------------------------------
/lightning_ssl/module/__init__.py:
--------------------------------------------------------------------------------
1 | from .classifier_module import ClassifierModule
2 | from .mixmatch_module import MixmatchModule
3 |
--------------------------------------------------------------------------------
/lightning_ssl/module/classifier_module.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import collections
4 | import numpy as np
5 | from copy import deepcopy
6 | from torch.nn import functional as F
7 | from torch.utils.data import DataLoader
8 | from torchvision.datasets import MNIST
9 | from torchvision import transforms
10 | import pytorch_lightning as pl
11 | from lightning_ssl.utils.torch_utils import (
12 | EMA,
13 | smooth_label,
14 | soft_cross_entropy,
15 | mixup_data,
16 | customized_weight_decay,
17 | WeightDecayModule,
18 | split_weight_decay_weights,
19 | )
20 |
21 | # Use the style similar to pytorch_lightning (pl)
22 | # Codes will be revised to be compatible with pl when pl has all the necessary features.
23 |
24 | # Codes borrowed from
25 | # https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=x-34xKCI40yW
26 |
27 |
28 | class ClassifierModule(pl.LightningModule):
29 | def __init__(self, hparams, classifier, loaders):
30 | super(ClassifierModule, self).__init__()
31 | self.hparams = hparams
32 | self.classifier = classifier
33 | self.ema_classifier = deepcopy(classifier)
34 | self.loaders = loaders
35 | self.best_dict = {
36 | "val_acc": 0,
37 | }
38 | self.train_dict = {
39 | key: collections.deque([], size)
40 | for key, size in zip(
41 | ["loss", "acc", "val_acc", "test_acc"], [500, 500, 1, 20]
42 | )
43 | }
44 |
45 | # to record the best validation accuracy
46 | self.train_dict["val_acc"].append(0) # avoid empty list
47 |
48 | def on_train_start(self):
49 | # model will put in GPU before this function
50 | # so we initiate EMA and WeightDecayModule here
51 | self.ema = EMA(self.classifier, self.ema_classifier, self.hparams.ema)
52 | # self.wdm = WeightDecayModule(self.classifier, self.hparams.weight_decay, ["bn", "bias"])
53 |
54 | def on_train_batch_end(self, *args, **kwargs):
55 | # self.ema.update(self.classifier)
56 | # wd = self.hparams.weight_decay * self.hparams.learning_rate
57 | # customized_weight_decay(self.classifier, self.hparams.weight_decay, ["bn", "bias"])
58 | # self.wdm.decay()
59 | self.ema.step()
60 |
61 | def accuracy(self, y_hat, y):
62 | return 100 * (torch.argmax(y_hat, dim=-1) == y).float().mean()
63 |
64 | def training_step(self, batch, batch_nb):
65 | # REQUIRED
66 | x, y = batch
67 | # return smooth one-hot like label
68 | y_one_hot = smooth_label(
69 | y, self.hparams.n_classes, self.hparams.label_smoothing
70 | )
71 | # mixup
72 | mixed_x, mixed_y = mixup_data(x, y_one_hot, self.hparams.alpha)
73 | y_hat = self.classifier(mixed_x)
74 | loss = soft_cross_entropy(y_hat, mixed_y)
75 | y = torch.argmax(mixed_y, dim=-1)
76 | acc = self.accuracy(y_hat, y)
77 | num = len(y)
78 |
79 | self.train_dict["loss"].append(loss.item())
80 | self.train_dict["acc"].append(acc.item())
81 |
82 | # tensorboard_logs = {"train/loss": np.mean(self.train_dict["loss"]),
83 | # "train/acc": np.mean(self.train_dict["acc"])}
84 |
85 | # progress_bar = {"acc": np.mean(self.train_dict["acc"])}
86 |
87 | # return {"loss": loss, "train_acc": acc,
88 | # "train_num": num, "log": tensorboard_logs,
89 | # "progress_bar": progress_bar}
90 |
91 | self.log(
92 | "train/loss", np.mean(self.train_dict["loss"]), prog_bar=False, logger=True
93 | )
94 | self.log(
95 | "train/acc", np.mean(self.train_dict["acc"]), prog_bar=False, logger=True
96 | )
97 |
98 | return {"loss": loss, "train_acc": acc, "train_num": num}
99 |
100 | def validation_step(self, batch, *args):
101 | # OPTIONAL
102 | x, y = batch
103 | y_hat = self.ema_classifier(x)
104 |
105 | acc = self.accuracy(y_hat, y)
106 | num = len(y)
107 |
108 | return {"val_loss": F.cross_entropy(y_hat, y), "val_acc": acc, "val_num": num}
109 |
110 | def validation_epoch_end(self, outputs):
111 | # Monitor both validation set and test set
112 | # record the test accuracies of last 20 checkpoints
113 |
114 | avg_loss_list, avg_acc_list = [], []
115 |
116 | for output in outputs:
117 | avg_loss = torch.stack(
118 | [x["val_loss"] * x["val_num"] for x in output]
119 | ).sum() / np.sum([x["val_num"] for x in output])
120 | avg_acc = torch.stack(
121 | [x["val_acc"] * x["val_num"] for x in output]
122 | ).sum() / np.sum([x["val_num"] for x in output])
123 |
124 | avg_loss_list.append(avg_loss)
125 | avg_acc_list.append(avg_acc)
126 |
127 | # record best results of validation set
128 | self.train_dict["val_acc"][0] = max(
129 | self.train_dict["val_acc"][0], avg_acc_list[0].item()
130 | )
131 | self.train_dict["test_acc"].append(avg_acc_list[1].item())
132 |
133 | # tensorboard_logs = {"val/loss": avg_loss_list[0],
134 | # "val/acc": avg_acc_list[0],
135 | # "val/best_acc": self.train_dict["val_acc"][0],
136 | # "test/median_acc": np.median(self.train_dict["test_acc"])}
137 |
138 | # return {"val_loss": avg_loss_list[0], "val_acc": avg_acc_list[0], "log": tensorboard_logs}
139 | self.logger.experiment.add_scalar(
140 | "t", np.median(self.train_dict["test_acc"]), self.global_step
141 | )
142 | self.log("val/loss", avg_loss_list[0], prog_bar=False, logger=True)
143 | self.log("val/acc", avg_acc_list[0], prog_bar=False, logger=True)
144 | self.log(
145 | "val/best_acc", self.train_dict["val_acc"][0], prog_bar=False, logger=True
146 | )
147 | self.log(
148 | "test/median_acc",
149 | np.median(self.train_dict["test_acc"]),
150 | prog_bar=False,
151 | logger=True,
152 | )
153 |
154 | def test_step(self, batch, batch_nb):
155 | # OPTIONAL
156 | x, y = batch
157 | y_hat = self.ema_classifier(x)
158 | acc = self.accuracy(y_hat, y)
159 | num = len(y)
160 |
161 | return {
162 | "test_loss": F.cross_entropy(y_hat, y),
163 | "test_acc": acc,
164 | "test_num": num,
165 | }
166 |
167 | def test_epoch_end(self, outputs):
168 | # OPTIONAL
169 | avg_loss = torch.stack(
170 | [x["test_loss"] * x["test_num"] for x in outputs]
171 | ).sum() / np.sum([x["test_num"] for x in outputs])
172 | avg_acc = torch.stack(
173 | [x["test_acc"] * x["test_num"] for x in outputs]
174 | ).sum() / np.sum([x["test_num"] for x in outputs])
175 | # logs = {"test/loss": avg_loss, "test/acc": avg_acc}
176 | # return {"test_loss": avg_loss, "test_acc": avg_acc, "log": logs, "progress_bar": logs}
177 |
178 | self.log("test/loss", avg_loss, prog_bar=False, logger=True)
179 | self.log("test/acc", avg_acc, prog_bar=False, logger=True)
180 |
181 | def configure_optimizers(self):
182 | # REQUIRED
183 | # can return multiple optimizers and learning_rate schedulers
184 | # (LBFGS it is automatically supported, no need for closure function)
185 |
186 | # split the weights into need weight decay and no need weight decay
187 | parameters = split_weight_decay_weights(
188 | self.classifier, self.hparams.weight_decay, ["bn", "bias"]
189 | )
190 |
191 | opt = torch.optim.AdamW(
192 | parameters, lr=self.hparams.learning_rate, weight_decay=0
193 | )
194 | # opt = torch.optim.SGD(self.classifier.parameters(), self.hparams.learning_rate,
195 | # momentum=0.9, nesterov=True)
196 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, float(self.hparams.max_epoch))
197 | return [opt] # , [scheduler]
198 |
199 | # def train_dataloader(self):
200 | # # REQUIRED
201 | # return self.loaders["tr_loader"]
202 |
203 | # def val_dataloader(self):
204 | # # OPTIONAL
205 | # return [self.loaders["va_loader"], self.loaders["te_loader"]]
206 |
207 | # def test_dataloader(self):
208 | # # OPTIONAL
209 | # return self.loaders["te_loader"]
210 |
--------------------------------------------------------------------------------
/lightning_ssl/module/mixmatch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from lightning_ssl.utils.torch_utils import (
6 | half_mixup_data,
7 | soft_cross_entropy,
8 | l2_distribution_loss,
9 | smooth_label,
10 | customized_weight_decay,
11 | interleave,
12 | )
13 |
14 |
15 | class Mixmatch:
16 | classifier: nn.Module
17 | hparams: ...
18 | lambda_u: float
19 |
20 | def _loss(self, labeled_x, labeled_y, unlabeled_xs, batch_inference=False):
21 | """
22 | labeled_x: [B, :]
23 | labeled_y: [B, n_classes]
24 | unlabeled_xs: K unlabeled_x
25 | """
26 | batch_size = labeled_x.size(0)
27 | num_augmentation = len(unlabeled_xs) # num_augmentation
28 |
29 | # not to update the running mean and variance in BN
30 | self.classifier.freeze_running_stats()
31 |
32 | p_unlabeled_y = self.classifier.psuedo_label(
33 | unlabeled_xs, self.hparams.T, batch_inference
34 | )
35 |
36 | # # size [(K + 1) * B, :]
37 | # all_inputs = torch.cat([labeled_x] + unlabeled_xs, dim=0)
38 | # all_targets = torch.cat(
39 | # [labeled_y, p_unlabeled_y.repeat(K, 1)],
40 | # dim=0
41 | # )
42 |
43 | # mixed_input, mixed_target = half_mixup_data(all_inputs, all_targets, self.hparams.alpha)
44 |
45 | # logits = self.forward(mixed_input, self.classifier)
46 |
47 | # l_l = soft_cross_entropy(logits[:batch_size], mixed_target[:batch_size])
48 | # l_u = l2_distribution_loss(logits[batch_size:], mixed_target[batch_size:])
49 |
50 | all_inputs = torch.cat([labeled_x] + unlabeled_xs, dim=0)
51 | all_targets = torch.cat(
52 | [labeled_y, p_unlabeled_y.repeat(num_augmentation, 1)], dim=0
53 | )
54 |
55 | mixed_input, mixed_target = half_mixup_data(
56 | all_inputs, all_targets, self.hparams.alpha
57 | )
58 |
59 | if batch_inference:
60 | self.classifier.recover_running_stats()
61 | logits = self.classifier(mixed_input)
62 | else:
63 |             # interleave labeled and unlabeled samples between batches to get correct batchnorm calculation
64 | mixed_input = list(torch.split(mixed_input, batch_size))
65 | mixed_input = interleave(mixed_input, batch_size)
66 |
67 | logits_other = [self.classifier(x) for x in mixed_input[1:]]
68 |
69 | self.classifier.recover_running_stats()
70 |
71 | # only update the BN stats for the first batch
72 | logits_first = self.classifier(mixed_input[0])
73 |
74 | logits = [logits_first] + logits_other
75 |
76 | # put interleaved samples back
77 | logits = interleave(logits, batch_size)
78 |
79 | logits = torch.cat(logits, dim=0)
80 |
81 | l_l = soft_cross_entropy(logits[:batch_size], mixed_target[:batch_size])
82 | l_u = l2_distribution_loss(logits[batch_size:], mixed_target[batch_size:])
83 |
84 | return l_l + self.lambda_u * l_u, l_l, l_u
85 |
--------------------------------------------------------------------------------
/lightning_ssl/module/mixmatch_module.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import collections
4 | import numpy as np
5 | import torch.nn as nn
6 | from time import time
7 | from copy import deepcopy
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader
10 | from torchvision.datasets import MNIST
11 | from torchvision import transforms
12 | import pytorch_lightning as pl
13 | from lightning_ssl.utils.torch_utils import (
14 | half_mixup_data,
15 | soft_cross_entropy,
16 | l2_distribution_loss,
17 | smooth_label,
18 | customized_weight_decay,
19 | interleave,
20 | sharpening,
21 | )
22 |
23 | from lightning_ssl.module.classifier_module import ClassifierModule
24 | from lightning_ssl.module.mixmatch import Mixmatch
25 |
26 |
27 | class MixmatchModule(ClassifierModule, Mixmatch):
28 | def __init__(self, hparams, classifier, loaders):
29 | super(MixmatchModule, self).__init__(hparams, classifier, loaders)
30 | self.lambda_u = 0
31 |         self.rampup_length = 16384  # taken from the official MixMatch code
32 |
33 | for key, size in zip(["lab_loss", "unl_loss"], [500, 500]):
34 | self.train_dict[key] = collections.deque([], size)
35 |
36 | def on_train_batch_end(self, *args, **kwargs):
37 | # self.wdm.decay()
38 | self.ema.step()
39 |
40 | # linearly ramp up lambda_u
41 | if self.lambda_u == self.hparams.lambda_u:
42 | return
43 |
44 | step = (self.hparams.lambda_u - 0) / self.rampup_length
45 | self.lambda_u = min(self.lambda_u + step, self.hparams.lambda_u)
46 |
47 | def training_step(self, batch, batch_nb):
48 | # REQUIRED
49 | labeled_batch, unlabeled_batch = batch[0], batch[1:]
50 | # labeled_batch, unlabeled_batch = batch
51 | labeled_x, labeled_y = labeled_batch
52 |
53 | labeled_y = smooth_label(
54 | labeled_y, self.hparams.n_classes, self.hparams.label_smoothing
55 | )
56 |
57 | unlabeled_xs = [b[0] for b in unlabeled_batch] # only get the images
58 |
59 | loss, l_l, l_u = self._loss(
60 | labeled_x,
61 | labeled_y,
62 | unlabeled_xs,
63 | batch_inference=self.hparams.batch_inference,
64 | )
65 |
66 | loss = l_l + self.lambda_u * l_u
67 |
68 | # y = torch.argmax(mixed_target, dim=-1)
69 | # acc = self.accuracy(logits[:batch_size], y[:batch_size])
70 |
71 | self.train_dict["loss"].append(loss.item())
72 | # self.train_dict["acc"].append(acc.item())
73 | self.train_dict["lab_loss"].append(l_l.item())
74 | self.train_dict["unl_loss"].append(l_u.item())
75 |
76 | # tensorboard_logs = {"train/loss": np.mean(self.train_dict["loss"]),
77 | # # "train/acc": np.mean(self.train_dict["acc"]),
78 | # "train/lab_loss": np.mean(self.train_dict["lab_loss"]),
79 | # "train/unl_loss": np.mean(self.train_dict["unl_loss"]),
80 | # "lambda_u": self.lambda_u}
81 |
82 | # progress_bar = {# "acc": np.mean(self.train_dict["acc"]),
83 | # "lab_loss": np.mean(self.train_dict["lab_loss"]),
84 | # "unl_loss": np.mean(self.train_dict["unl_loss"]),
85 | # "lambda_u": self.lambda_u}
86 |
87 | # return {"loss": loss, "log": tensorboard_logs, "progress_bar": progress_bar}
88 |
89 | self.log(
90 | "train/loss", np.mean(self.train_dict["loss"]), prog_bar=False, logger=True
91 | )
92 | self.log(
93 | "train/lab_loss",
94 | np.mean(self.train_dict["lab_loss"]),
95 | prog_bar=True,
96 | logger=True,
97 | )
98 | self.log(
99 | "train/unl_loss",
100 | np.mean(self.train_dict["unl_loss"]),
101 | prog_bar=True,
102 | logger=True,
103 | )
104 | self.log("train/lambda_u", self.lambda_u, prog_bar=False, logger=True)
105 |
106 | return {"loss": loss}
107 |
108 | # def validation_epoch_end(self, outputs):
109 | # # Monitor both validation set and test set
110 | # # record the test accuracies of last 20 checkpoints
111 |
112 | # avg_loss_list, avg_acc_list = [], []
113 |
114 | # for output in outputs:
115 | # avg_loss = torch.stack([x["val_loss"] * x["val_num"] for x in output]).sum() / \
116 | # np.sum([x["val_num"] for x in output])
117 | # avg_acc = torch.stack([x["val_acc"] * x["val_num"] for x in output]).sum() / \
118 | # np.sum([x["val_num"] for x in output])
119 |
120 | # avg_loss_list.append(avg_loss)
121 | # avg_acc_list.append(avg_acc)
122 |
123 | # # record best results of validation set
124 | # self.train_dict["val_acc"][0] = max(self.train_dict["val_acc"][0], avg_acc_list[0].item())
125 | # self.train_dict["test_acc"].append(avg_acc_list[1].item())
126 |
127 | # # tensorboard_logs = {"val/loss": avg_loss_list[0],
128 | # # "val/acc": avg_acc_list[0],
129 | # # "val/best_acc": self.train_dict["val_acc"][0],
130 | # # "test/median_acc": np.median(self.train_dict["test_acc"])}
131 |
132 | # # return {"val_loss": avg_loss_list[0], "val_acc": avg_acc_list[0], "log": tensorboard_logs}
133 |
134 | # self.log("val/loss", avg_loss_list[0], on_step=True, on_epoch=False,
135 | # prog_bar=False, logger=True)
136 | # self.log("val/acc", avg_acc_list[0], on_step=True, on_epoch=False,
137 | # prog_bar=False, logger=True)
138 | # self.log("val/best_acc", self.train_dict["val_acc"][0], on_step=True, on_epoch=False,
139 | # prog_bar=False, logger=True)
140 | # self.log("test/median_acc", np.median(self.train_dict["test_acc"]), on_step=True, on_epoch=False,
141 | # prog_bar=False, logger=True)
142 |
--------------------------------------------------------------------------------
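For reference, the ramp-up above increases `lambda_u` linearly from 0 to `hparams.lambda_u` over `rampup_length` (16384) training steps, one increment per batch; a minimal standalone equivalent:

```python
def ramped_lambda_u(step, lambda_u_max, rampup_length=16384):
    # linear ramp-up, equivalent to MixmatchModule.on_train_batch_end
    return min(lambda_u_max, lambda_u_max * step / rampup_length)
```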
/lightning_ssl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import makedirs, count_parameters
2 |
--------------------------------------------------------------------------------
/lightning_ssl/utils/argparser.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import argparse
3 | from configparser import ConfigParser
4 |
5 |
6 | def str2bool(v):
7 | if isinstance(v, bool):
8 | return v
9 | if v.lower() in ("yes", "true", "t", "y", "1"):
10 | return True
11 | elif v.lower() in ("no", "false", "f", "n", "0"):
12 | return False
13 | else:
14 | raise argparse.ArgumentTypeError("Boolean value expected.")
15 |
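# For example, used as an argparse type this accepts the usual spellings:
#   parser.add_argument("--batch_inference", type=str2bool)
#   --batch_inference yes / true / t / y / 1   -> True
#   --batch_inference no  / false / f / n / 0  -> False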
16 |
17 | def parser():
18 | conf_parser = argparse.ArgumentParser(
19 | description="parser for config",
20 | # Don't mess with format of description
21 | formatter_class=argparse.RawDescriptionHelpFormatter,
22 | # Turn off help, so we print all options in response to -h
23 | add_help=False,
24 | )
25 |
26 | conf_parser.add_argument(
27 | "-c", "--conf_file", help="Specify config file", metavar="FILE"
28 | )
29 |
30 | conf_parser.add_argument(
31 |         "-d", "--dataset", default="cifar10", help="which dataset to use"
32 | )
33 |
34 | args, remaining_argv = conf_parser.parse_known_args()
35 |
36 | defaults = {}
37 | if args.conf_file:
38 | with open(args.conf_file, "r") as f:
39 | configs = yaml.load(f, yaml.SafeLoader)
40 | assert args.dataset in configs, f"Don't have the config for {args.dataset}"
41 | defaults.update(configs[args.dataset])
42 |
43 | parser = argparse.ArgumentParser(
44 | description="Semi-Supervised Learning", parents=[conf_parser]
45 | )
46 |
47 | parser.add_argument(
48 | "--learning_scenario",
49 | default="semi",
50 | choices=["semi", "supervised"],
51 | help="what learning scenario to use",
52 | )
53 |
54 | parser.add_argument("--algo", default="none", help="which algorithm to use")
55 |
56 | parser.add_argument(
57 | "--todo",
58 | choices=["train", "test"],
59 | default="train",
60 |         help="what to do: train | test",
61 | )
62 |
63 | parser.add_argument(
64 | "--data_root", default=".", help="the directory to save the dataset"
65 | )
66 |
67 | parser.add_argument(
68 | "--log_root",
69 | default=".",
70 |         help="the directory to save the logs and other information (e.g. images)",
71 | )
72 |
73 | parser.add_argument(
74 | "--model_root", default="checkpoint", help="the directory to save the models"
75 | )
76 |
77 | parser.add_argument("--load_checkpoint", default="./model/default/model.pth")
78 |
79 | parser.add_argument(
80 | "--affix", default="default", help="the affix for the save folder"
81 | )
82 |
83 | parser.add_argument("--seed", type=int, default=1, help="seed")
84 |
85 | parser.add_argument(
86 | "--num_workers",
87 | type=int,
88 | default=16,
89 |         help="how many workers to use in the data loader",
90 | )
91 |
92 | parser.add_argument("--batch_size", "-b", type=int, default=128, help="batch size")
93 |
94 | parser.add_argument(
95 | "--num_labeled", type=int, default=1000, help="number of labeled data"
96 | )
97 |
98 | parser.add_argument(
99 | "--num_val", type=int, default=5000, help="the amount of validation data"
100 | )
101 |
102 | parser.add_argument(
103 | "--max_epochs",
104 | "-m_e",
105 | type=int,
106 | default=200,
107 |         help="the maximum number of epochs (times the model sees each sample)",
108 | )
109 |
110 | parser.add_argument(
111 | "--learning_rate", "-lr", type=float, default=1e-2, help="learning rate"
112 | )
113 |
114 | parser.add_argument(
115 | "--weight_decay",
116 | "-w",
117 | type=float,
118 | default=2e-4,
119 |         help="the coefficient of the L2 penalty (weight decay) on the weights",
120 | )
121 |
122 | parser.add_argument("--gpus", "-g", default="0", help="what gpus to use")
123 |
124 | parser.add_argument(
125 | "--n_eval_step",
126 | type=int,
127 | default=100,
128 |         help="number of iterations between evaluations",
129 | )
130 |
131 | parser.add_argument(
132 | "--n_checkpoint_step",
133 | type=int,
134 | default=4000,
135 |         help="number of iterations between checkpoint saves",
136 | )
137 |
138 | parser.add_argument(
139 | "--n_store_image_step",
140 | type=int,
141 | default=4000,
142 |         help="number of iterations between saving adversaries",
143 | )
144 |
145 | parser.add_argument(
146 |         "--max_steps", type=int, default=1 << 16, help="maximum number of training iterations"
147 | )
148 |
149 | parser.add_argument(
150 | "--ema",
151 | type=float,
152 | default=0.999,
153 | help="the decay for exponential moving average",
154 | )
155 |
156 | parser.add_argument(
157 | "--label_smoothing",
158 | type=float,
159 | default=0.0,
160 |         help="the parameter for label smoothing",
161 | )
162 |
163 | parser.add_argument(
164 | "--augment", action="store_true", help="whether to augment the training data"
165 | )
166 |
167 | parser.add_argument(
168 | "--num_augments",
169 | type=int,
170 | default=2,
171 |         help="how many augmented samples per unlabeled example in MixMatch",
172 | )
173 |
174 | parser.add_argument(
175 | "--alpha",
176 | type=float,
177 | default=-1,
178 |         help="the hyperparameter for the beta distribution in mixup \
179 |             (requires alpha > 0; alpha <= 0 means no mixing)",
180 | )
181 |
182 | parser.add_argument(
183 | "--T",
184 | type=float,
185 | default=0.5,
186 |         help="temperature for the softmax distribution (the sharpening parameter in MixMatch)",
187 | )
188 |
189 | parser.add_argument(
190 | "--lambda_u",
191 | type=float,
192 | default=100,
193 | help="the weight of the loss for the unlabeled data",
194 | )
195 |
196 | parser.add_argument(
197 | "--batch_inference",
198 | type=str2bool,
199 | default=False,
200 |         help="whether to use batch inference when generating pseudo labels and computing the loss",
201 | )
202 |
203 | parser.add_argument(
204 | "--progress_bar_refresh_rate",
205 | type=int,
206 | default=1,
207 |         help="the frequency to refresh the progress bar (0 disables the progress bar)",
208 | )
209 |
210 | parser.set_defaults(**defaults)
211 | parser.set_defaults(**vars(args))
212 |
213 | return parser.parse_args(remaining_argv)
214 |
215 |
216 | def print_args(args, logger=None):
217 | for k, v in vars(args).items():
218 | if logger is not None:
219 | logger.info("{:<16} : {}".format(k, v))
220 | else:
221 | print("{:<16} : {}".format(k, v))
222 |
--------------------------------------------------------------------------------
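`parser()` follows a two-stage argparse pattern: a small `conf_parser` first consumes `-c/--conf_file` and `-d/--dataset`, the matching section of the YAML file is loaded into `defaults`, and the full parser applies those defaults before parsing the remaining flags. The resulting precedence is: explicit command-line flags > config-file values > hard-coded defaults. A sketch of that behaviour, assuming a hypothetical `my_config.yml` whose `cifar10:` section sets `batch_size: 64` and `lambda_u: 75`:

    import sys
    from lightning_ssl.utils.argparser import parser

    sys.argv = ["main.py", "-c", "my_config.yml", "--batch_size", "32"]
    args = parser()

    assert args.batch_size == 32   # explicit CLI flag wins over the config file
    assert args.lambda_u == 75     # taken from the config-file section
    assert args.seed == 1          # untouched hard-coded default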
/lightning_ssl/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from copy import deepcopy
6 |
7 |
8 | def sharpening(label, T):
9 | label = label.pow(1 / T)
10 | return label / label.sum(-1, keepdim=True)
11 |
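# For example, with T = 0.5 a normalized distribution becomes peakier (values rounded):
#   sharpening(torch.tensor([[0.7, 0.1, 0.1, 0.1]]), T=0.5)
#   -> tensor([[0.9423, 0.0192, 0.0192, 0.0192]])
# while T = 1 leaves it unchanged.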
12 |
13 | def soft_cross_entropy(input_, target_, dim=-1):
14 | """
15 | compute the cross entropy between input_ and target_
16 | Args
17 | input_: logits of model's prediction. Size = [Batch, n_classes]
18 | target_: probability of target distribution. Size = [Batch, n_classes]
19 | Return
20 |         the cross entropy between input_ and target_
21 | """
22 |
23 | input_ = input_.log_softmax(dim=dim)
24 |
25 | return torch.mean(torch.sum(-target_ * input_, dim=dim))
26 |
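# For example, with a one-hot target this reduces to the ordinary cross entropy:
#   logits = torch.tensor([[2.0, 0.0, 0.0]])
#   target = torch.tensor([[1.0, 0.0, 0.0]])
#   soft_cross_entropy(logits, target)  # equals -torch.log_softmax(logits, -1)[0, 0]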
27 |
28 | def l2_distribution_loss(input_, target_, dim=-1):
29 | input_ = input_.softmax(dim=dim)
30 |
31 | return torch.mean((input_ - target_) ** 2)
32 |
33 |
34 | def mixup_data(x, y, alpha):
35 | """
36 | Args:
37 | x: data, whose size is [Batch, ...]
38 | y: label, whose size is [Batch, ...]
39 |         alpha: the parameter for the beta distribution. If alpha <= 0, no mixing is applied
40 | Return
41 | mixed inputs, mixed targets
42 | """
43 | # code is modified from
44 | # https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
45 |
46 | batch_size = x.size()[0]
47 |
48 | if alpha > 0:
49 | lam = torch.distributions.beta.Beta(alpha, alpha).rsample((batch_size,))
50 | else:
51 | lam = torch.ones((batch_size,))
52 |
53 | lam = lam.to(x.device)
54 |
55 | index = torch.randperm(batch_size).to(x.device)
56 |
57 | x_size, y_size = [1 for _ in range(len(x.shape))], [1 for _ in range(len(y.shape))]
58 | x_size[0], y_size[0] = batch_size, batch_size
59 |
60 | mixed_x = lam.view(x_size) * x + (1 - lam.view(x_size)) * x[index]
61 | mixed_y = lam.view(y_size) * y + (1 - lam.view(y_size)) * y[index]
62 |
63 | return mixed_x, mixed_y
64 |
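# For example, each sample draws its own lam ~ Beta(alpha, alpha) and is blended
# with a randomly chosen partner from the same batch:
#   x = torch.randn(8, 3, 32, 32)
#   y = to_one_hot_vector(torch.randint(10, (8,)), 10)
#   mixed_x, mixed_y = mixup_data(x, y, alpha=0.75)   # same shapes as x and y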
65 |
66 | def mixup_data_for_testing(x, y, alpha):
67 | """
68 | Args:
69 | x: data, whose size is [Batch, ...]
70 | y: label, whose size is [Batch, ...]
71 |         alpha: the parameter for the beta distribution. If alpha <= 0, no mixing is applied
72 | Return
73 | mixed inputs, mixed targets
74 | """
75 | # code is modified from
76 | # https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
77 |
78 | batch_size = x.size()[0]
79 |
80 | if alpha > 0:
81 | lam = torch.distributions.beta.Beta(alpha, alpha).rsample((batch_size,))
82 | else:
83 | lam = torch.ones((batch_size,))
84 |
85 | lam = lam.to(x.device)
86 |
87 | index = torch.randperm(batch_size).to(x.device)
88 |
89 | x_size, y_size = [1 for _ in range(len(x.shape))], [1 for _ in range(len(y.shape))]
90 | x_size[0], y_size[0] = batch_size, batch_size
91 |
92 | mixed_x = lam.view(x_size) * x + (1 - lam.view(x_size)) * x[index]
93 | mixed_y = lam.view(y_size) * y + (1 - lam.view(y_size)) * y[index]
94 |
95 | return mixed_x, mixed_y, x[index], y[index], lam
96 |
97 |
98 | def half_mixup_data(x, y, alpha):
99 | """
100 |     This function is similar to normal mixup except that lam is forced into
101 |     [0.5, 1], so mixed_x and mixed_y stay closer to x and y.
102 | """
103 | # code is modified from
104 | # https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
105 |
106 | batch_size = x.size()[0]
107 |
108 | if alpha > 0:
109 | lam = torch.distributions.beta.Beta(alpha, alpha).rsample((batch_size,))
110 | else:
111 | lam = torch.ones((batch_size,))
112 |
113 | lam = torch.max(lam, 1 - lam)
114 |
115 | lam = lam.to(x.device)
116 |
117 | index = torch.randperm(batch_size).to(x.device)
118 |
119 | x_size, y_size = [1 for _ in range(len(x.shape))], [1 for _ in range(len(y.shape))]
120 | x_size[0], y_size[0] = batch_size, batch_size
121 |
122 | mixed_x = lam.view(x_size) * x + (1 - lam.view(x_size)) * x[index]
123 | mixed_y = lam.view(y_size) * y + (1 - lam.view(y_size)) * y[index]
124 |
125 | return mixed_x, mixed_y
126 |
127 |
128 | def half_mixup_data_for_testing(x, y, alpha):
129 | """
130 |     This function is similar to normal mixup except that lam is forced into
131 |     [0.5, 1], so mixed_x and mixed_y stay closer to x and y.
132 | """
133 | # code is modified from
134 | # https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py
135 |
136 | batch_size = x.size()[0]
137 |
138 | if alpha > 0:
139 | lam = torch.distributions.beta.Beta(alpha, alpha).rsample((batch_size,))
140 | else:
141 | lam = torch.ones((batch_size,))
142 |
143 | lam = torch.max(lam, 1 - lam)
144 |
145 | lam = lam.to(x.device)
146 |
147 | index = torch.randperm(batch_size).to(x.device)
148 |
149 | x_size, y_size = [1 for _ in range(len(x.shape))], [1 for _ in range(len(y.shape))]
150 | x_size[0], y_size[0] = batch_size, batch_size
151 |
152 | mixed_x = lam.view(x_size) * x + (1 - lam.view(x_size)) * x[index]
153 | mixed_y = lam.view(y_size) * y + (1 - lam.view(y_size)) * y[index]
154 |
155 | return mixed_x, mixed_y, x[index], y[index], lam
156 |
157 |
158 | def smooth_label(y, n_classes, smoothing=0.0):
159 | """
160 |     Transform y into its one-hot representation and smooth it.
161 |     If smoothing is 0, the result is exactly the one-hot representation of y.
162 |     Args
163 |         y: label which is a LongTensor
164 |         n_classes: the total number of classes
165 |         smoothing: the parameter for label smoothing
166 |     Return
167 |         true_dist: the smoothed label whose size is [Batch, n_classes]
168 | """
169 |
170 | confidence = 1.0 - smoothing
171 | true_dist = torch.zeros(*list(y.size()), n_classes).to(y.device)
172 | true_dist.fill_(smoothing / (n_classes - 1))
173 | true_dist.scatter_(-1, y.data.unsqueeze(1), confidence)
174 |
175 | return true_dist
176 |
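# For example, smoothing 0.1 over 4 classes puts 0.9 on the true class and
# 0.1 / 3 on each of the others (values rounded):
#   smooth_label(torch.tensor([2, 0]), n_classes=4, smoothing=0.1)
#   -> tensor([[0.0333, 0.0333, 0.9000, 0.0333],
#              [0.9000, 0.0333, 0.0333, 0.0333]])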
177 |
178 | def to_one_hot_vector(y, n_classes):
179 | return smooth_label(y, n_classes, 0)
180 |
181 |
182 | def customized_weight_decay(model, weight_decay, ignore_key=["bn", "bias"]):
183 | for name, p in model.named_parameters():
184 | if not any(key in name for key in ignore_key):
185 | p.data.mul_(1 - weight_decay)
186 |
187 |
188 | def interleave_offsets(batch, nu):
189 | groups = [batch // (nu + 1)] * (nu + 1)
190 | for x in range(batch - sum(groups)):
191 | groups[-x - 1] += 1
192 | offsets = [0]
193 | for g in groups:
194 | offsets.append(offsets[-1] + g)
195 | assert offsets[-1] == batch
196 | return offsets
197 |
198 |
199 | def interleave(xy, batch):
200 | nu = len(xy) - 1
201 | offsets = interleave_offsets(batch, nu)
202 | xy = [[v[offsets[p] : offsets[p + 1]] for p in range(nu + 1)] for v in xy]
203 | for i in range(1, nu + 1):
204 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
205 | return [torch.cat(v, dim=0) for v in xy]
206 |
207 |
208 | # def my_interleave_1(inputs, batch_size):
209 | # """
210 | # * Make the data interleave.
211 | # * Swap the data of the first batch (inputs[0]) to other batches.
212 | # * change_indices would be a increasing function.
213 | # * len(change_indices) should be the same as len(inputs[0]), and the elements
214 | # denote which row should the data in first row change with.
215 | # """
216 | # ret = deepcopy(inputs)
217 | # inputs_size = len(inputs)
218 |
219 | # repeat = batch_size // inputs_size
220 | # residual = batch_size % inputs_size
221 | # change_indices = list(range(inputs_size)) * repeat + list(
222 | # range(inputs_size - residual, inputs_size)
223 | # )
224 | # change_indices = sorted(change_indices)
225 | # # print(change_indices)
226 | # for i, switch_row in enumerate(change_indices):
227 | # ret[0][i], ret[switch_row % inputs_size][i] = (
228 | # inputs[switch_row % inputs_size][i],
229 | # inputs[0][i],
230 | # )
231 |
232 | # return ret
233 |
234 |
235 | # def my_interleave_2(inputs, batch_size):
236 | # """
237 | # * Make the data interleave.
238 | # * Swap the data of the first batch (inputs[0]) to other batches.
239 | # * change_indices would be a increasing function.
240 | # * len(change_indices) should be the same as len(inputs[0]), and the elements
241 | # denote which row should the data in first row change with.
242 | # """
243 | # # ret = deepcopy(inputs)
244 | # def swap(A, B):
245 | # return B.clone(), A.clone()
246 |
247 | # ret = inputs
248 | # inputs_size = len(inputs)
249 |
250 | # # equally switch the first row to other rows, so we compute how many repeat for range(inputs_size),
251 | # # which store the rows to change.
252 | # # some of the element cannot evenly spread two rows, so we preferentially use the rows which are farer to 0th row.
253 | # repeat = batch_size // inputs_size
254 | # residual = batch_size % inputs_size
255 | # change_indices = list(range(inputs_size)) * repeat + list(
256 | # range(inputs_size - residual, inputs_size)
257 | # )
258 | # change_indices = sorted(change_indices)
259 | # # print(change_indices)
260 |
261 | # # start to change elements
262 | # for i, switch_row in enumerate(change_indices):
263 | # ret[0][i], ret[switch_row % inputs_size][i] = swap(
264 | # ret[0][i], ret[switch_row % inputs_size][i]
265 | # )
266 |
267 | # return ret
268 |
269 |
270 | def my_interleave(inputs, batch_size):
271 | """
272 |     * This function modifies inputs in place.
273 |     * Interleave the data across batches.
274 |     * Swap data from the first batch (inputs[0]) into the other batches.
275 |     * change_indices is a non-decreasing sequence.
276 |     * len(change_indices) is the same as len(inputs[0]), and each element
277 |       denotes which row the corresponding entry of the first row is swapped with.
278 | """
279 |
280 | def swap(A, B):
281 | """
282 | swap for tensors
283 | """
284 | return B.clone(), A.clone()
285 |
286 | ret = inputs
287 | inputs_size = len(inputs)
288 |
289 | repeat = batch_size // inputs_size
290 | residual = batch_size % inputs_size
291 |
292 |     # spread the swaps of the first row evenly over the other rows: compute how many times
293 |     # to repeat range(inputs_size), which stores the rows to swap with.
294 |     # some elements cannot be spread evenly, so we preferentially use the rows farthest from the 0th row.
295 |
296 | change_indices = list(range(inputs_size)) * repeat + list(
297 | range(inputs_size - residual, inputs_size)
298 | )
299 | change_indices = sorted(change_indices)
300 | # print(change_indices)
301 |
302 |     # change_indices is monotonically non-decreasing, so we can group equal elements and swap them together
303 | # e.g. change_indices = [0, 1, 1, 2, 2, 2]
304 | # => two_dimension_change_indices = [[0], [1, 1], [2, 2, 2]]
305 | two_dimension_change_indices = []
306 |
307 | change_indices.insert(0, -1)
308 | change_indices.append(change_indices[-1] + 1)
309 | start = 0
310 | for i in range(1, len(change_indices)):
311 | if change_indices[i] != change_indices[i - 1]:
312 | two_dimension_change_indices.append(change_indices[start:i])
313 | start = i
314 |
315 | two_dimension_change_indices.pop(0)
316 |
317 | i = 0
318 | for switch_rows in two_dimension_change_indices:
319 | switch_row = switch_rows[0]
320 | num = len(switch_rows)
321 | ret[0][i : i + num], ret[switch_row % inputs_size][i : i + num] = swap(
322 | ret[0][i : i + num], ret[switch_row % inputs_size][i : i + num]
323 | )
324 | i += num
325 |
326 | return ret
327 |
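# For example, interleaving three batches of size 4 swaps segments of the first
# batch into the others (in MixMatch-style training this is done so batch-norm
# statistics are computed over comparable mixes of labeled and unlabeled samples);
# applying the same interleave again restores the original order, as
# tests/utils/test_torch_utils.py checks:
#   xs = [torch.arange(4), torch.arange(10, 14), torch.arange(20, 24)]
#   my_interleave(xs, 4)   # rows now mix entries from all three batches
#   my_interleave(xs, 4)   # a second application restores the original tensors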
328 |
329 | def split_weight_decay_weights(model, weight_decay, ignore_key=["bn", "bias"]):
330 | weight_decay_weights = []
331 | no_weight_decay_weights = []
332 | for name, p in model.named_parameters():
333 | if not p.requires_grad:
334 | continue
335 | if any(key in name for key in ignore_key):
336 | no_weight_decay_weights.append(p)
337 | else:
338 | # print(name)
339 | weight_decay_weights.append(p)
340 |
341 | return [
342 | {"params": no_weight_decay_weights, "weight_decay": 0.0},
343 | {"params": weight_decay_weights, "weight_decay": weight_decay},
344 | ]
345 |
346 |
347 | class WeightDecayModule:
348 | def __init__(self, model, weight_decay, ignore_key=["bn", "bias"]):
349 | self.weight_decay = weight_decay
350 | self.available_parameters = []
351 | for name, p in model.named_parameters():
352 | if not any(key in name for key in ignore_key):
353 | # print(name)
354 | self.available_parameters.append(p)
355 |
356 | def decay(self):
357 | for p in self.available_parameters:
358 | p.data.mul_(1 - self.weight_decay)
359 |
360 |
361 | class EMA:
362 | """
363 | The module for exponential moving average
364 | """
365 |
366 | def __init__(self, model, ema_model, decay=0.999):
367 | self.decay = decay
368 | self.params = list(model.state_dict().values())
369 | self.ema_params = list(ema_model.state_dict().values())
370 |
371 |         # some of the batch-norm buffers (e.g. num_batches_tracked) are LongTensors;
372 |         # convert them to float, otherwise mul_ or add_ in self.step()
373 |         # would raise a type error
374 | for p in ema_model.parameters():
375 | p.detach_()
376 | for i in range(len(self.ema_params)):
377 | self.ema_params[i] = self.ema_params[i].float()
378 |
379 | def step(self):
380 |         # average all the parameters, including the running mean and
381 |         # running variance in batch normalization
382 | for param, ema_param in zip(self.params, self.ema_params):
383 | # if param.dtype == torch.float32:
384 | ema_param.mul_(self.decay)
385 | ema_param.add_(param * (1 - self.decay))
386 | # if param.dtype == torch.float32:
387 | # param.mul_(1 - 4e-5)
388 |
--------------------------------------------------------------------------------
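`EMA` keeps a shadow copy of the model whose `state_dict` entries are an exponential moving average of the live model's, including the batch-norm running statistics. A minimal usage sketch (illustrative only; it assumes the package is installed so the module is importable, and the optimizer step is elided):

    import torch.nn as nn
    from copy import deepcopy
    from lightning_ssl.utils.torch_utils import EMA

    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
    ema_model = deepcopy(model)
    ema = EMA(model, ema_model, decay=0.999)

    for _ in range(10):
        # ... forward/backward pass and optimizer step on `model` ...
        ema.step()  # pull the shadow weights toward the current weights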
/lightning_ssl/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 |
4 |
5 | # def create_logger(save_path="", file_type="", level="debug"):
6 |
7 | # if level == "debug":
8 | # _level = logging.DEBUG
9 | # elif level == "info":
10 | # _level = logging.INFO
11 |
12 | # logger = logging.getLogger()
13 | # logger.setLevel(_level)
14 |
15 | # cs = logging.StreamHandler()
16 | # cs.setLevel(_level)
17 | # logger.addHandler(cs)
18 |
19 | # if save_path != "":
20 | # file_name = os.path.join(save_path, file_type + "_log.txt")
21 | # fh = logging.FileHandler(file_name, mode="w")
22 | # fh.setLevel(_level)
23 |
24 | # logger.addHandler(fh)
25 |
26 | # return logger
27 |
28 |
29 | def makedirs(path):
30 | if not os.path.exists(path):
31 | os.makedirs(path)
32 |
33 |
34 | def count_parameters(model):
35 | # copy from https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8
36 | # baldassarre.fe's reply
37 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
38 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from lightning_ssl.module import ClassifierModule, MixmatchModule
5 | from lightning_ssl.dataloader import SemiCIFAR10Module, SupervisedCIFAR10Module
6 | from lightning_ssl.models import WideResNet
7 | import pytorch_lightning as pl
8 | from lightning_ssl.utils import count_parameters
9 | from lightning_ssl.utils.argparser import parser, print_args
10 |
11 |
12 | if __name__ == "__main__":
13 | args = parser()
14 |
15 | gpus = args.gpus if torch.cuda.is_available() else None
16 |
17 | # set the random seed
18 | # np.random.seed(args.seed) # for sampling labeled/unlabeled/val dataset
19 | pl.seed_everything(args.seed)
20 |
21 | # load the data and classifier
22 | if args.dataset == "cifar10":
23 | if args.learning_scenario == "supervised":
24 | loader_class = SupervisedCIFAR10Module
25 |
26 | elif args.learning_scenario == "semi":
27 | loader_class = SemiCIFAR10Module
28 |
29 | data_loader = loader_class(
30 | args,
31 | args.data_root,
32 | args.num_workers,
33 | args.batch_size,
34 | args.num_labeled,
35 | args.num_val,
36 | args.num_augments,
37 | )
38 | classifier = WideResNet(depth=28, num_classes=data_loader.n_classes)
39 |
40 | else:
41 | raise NotImplementedError
42 |
43 | if args.learning_scenario == "supervised":
44 | module = ClassifierModule
45 | else: # semi supervised learning algorithm
46 | if args.algo == "mixmatch":
47 | module = MixmatchModule
48 | else:
49 | raise NotImplementedError
50 |
51 |     print(f"model parameters: {count_parameters(classifier) / 1e6:.2f} M")
52 |
53 | # set the number of classes in the args
54 | setattr(args, "n_classes", data_loader.n_classes)
55 |
56 | data_loader.prepare_data()
57 | data_loader.setup()
58 |
59 | model = module(args, classifier, loaders=None)
60 |
61 | # trainer = Trainer(tr_loaders, va_loader, te_loader)
62 |
63 | if args.todo == "train":
64 | print(
65 |             f"labeled size: {data_loader.num_labeled_data}, "
66 | f"unlabeled size: {data_loader.num_unlabeled_data}, "
67 | f"val size: {data_loader.num_val_data}"
68 | )
69 |
70 | save_folder = (
71 | f"{args.dataset}_{args.learning_scenario}_{args.algo}_{args.affix}"
72 | )
73 | # model_folder = os.path.join(args.model_root, save_folder)
74 | # checkpoint_callback = pl.callbacks.model_checkpoint.ModelCheckpoint(filepath=model_folder)
75 |
76 | # tt_logger = pl.loggers.TestTubeLogger("tt_logs", name=save_folder, create_git_tag=True)
77 | # tt_logger.log_hyperparams(args)
78 |
79 | # if True:
80 | # print_args(args)
81 |
82 | # trainer = pl.trainer.Trainer(gpus=gpus, max_epochs=args.max_epochs)
83 | # else:
84 |
85 | tb_logger = pl.loggers.TensorBoardLogger(
86 | os.path.join(args.log_root, "lightning_logs"), name=save_folder
87 | )
88 | tb_logger.log_hyperparams(args)
89 |
90 | # set the path of checkpoint
91 | save_dir = getattr(tb_logger, "save_dir", None) or getattr(
92 | tb_logger, "_save_dir", None
93 | )
94 |
95 | ckpt_path = os.path.join(
96 | save_dir, tb_logger.name, f"version_{tb_logger.version}", "checkpoints"
97 | )
98 |
99 | ckpt = pl.callbacks.ModelCheckpoint(filepath=os.path.join(ckpt_path, "last"))
100 |
101 | setattr(args, "checkpoint_folder", ckpt_path)
102 |
103 | print_args(args)
104 |
105 | trainer = pl.trainer.Trainer(
106 | gpus=gpus,
107 | max_steps=args.max_steps,
108 | logger=tb_logger,
109 | max_epochs=args.max_epochs,
110 | checkpoint_callback=ckpt,
111 | benchmark=True,
112 | profiler=True,
113 | progress_bar_refresh_rate=args.progress_bar_refresh_rate,
114 | reload_dataloaders_every_epoch=True,
115 | )
116 |
117 | trainer.fit(model, datamodule=data_loader)
118 | trainer.test()
119 | else:
120 | trainer = pl.trainer.Trainer(resume_from_checkpoint=args.load_checkpoint)
121 | trainer.test(model, datamodule=data_loader)
122 |
--------------------------------------------------------------------------------
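Putting the pieces together, a typical MixMatch training run would be launched with something like the following command (paths and hyperparameter values are illustrative; every flag is defined in lightning_ssl/utils/argparser.py):

    python main.py -c configs/config_mixmatch.yml -d cifar10 \
        --learning_scenario semi --algo mixmatch \
        --num_labeled 250 --gpus 0 --affix example_run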
/read_from_tb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import sys
5 |
6 | data_num = int(sys.argv[1])
7 | seeds = [1, 2, 3]
8 | versions = [0, 0, 0]
9 |
10 | values = []
11 | iters = []
12 |
13 | for s, v in zip(seeds, versions):
14 | path_to_events_file = f"lightning_logs/cifar10_semi_mixmatch_@{data_num}.{s}_LB_NoV_FS_W_Batch_2/version_{v}/"
15 | # path_to_events_file = f"lightning_logs/cifar10_supervised_mixup_fs_mixup.{s}/version_{v}"
16 |
17 | files = os.listdir(path_to_events_file)
18 | for file_ in files:
19 | if file_.startswith("event"):
20 | event_file = file_
21 |
22 | print(event_file)
23 |
24 | path_to_events_file = os.path.join(path_to_events_file, event_file)
25 |
26 | last_iter = 0
27 | for e in tf.train.summary_iterator(path_to_events_file):
28 |         for summary in e.summary.value:  # "summary" avoids clobbering the outer loop variable v
29 |             if summary.tag.startswith("test/median_acc"):
30 |                 value = summary.simple_value
31 | last_iter += 1
32 |
33 | values.append(value)
34 | iters.append(last_iter)
35 |
36 | # for i in range(1, len(iters)):
37 | # assert iters[i] == iters[i - 1]
38 |
39 | print(iters)
40 | print(values)
41 | print(100 - np.mean(values))
42 | print(np.std(values))
43 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 |
2 | [flake8]
3 | # TODO: this should be 88 or 100 according PEP8
4 | max-line-length = 120
5 | exclude = .tox,*.egg,build,temp
6 | select = E,W,F
7 | doctests = True
8 | verbose = 2
9 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
10 | format = pylint
11 | ignore =
12 | E741
13 | E731
14 | W504
15 | F401
16 | F841
17 | E203 # E203 - whitespace before ':'. Opposite convention enforced by black
18 | E231 # E231: missing whitespace after ',', ';', or ':'; for black
19 | E501 # E501 - line too long. Handled by black, we have longer lines
20 | W503 # W503 - line break before binary operator, need for black
21 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import platform
4 | from setuptools import setup, find_packages
5 |
6 |
7 | setup(
8 | name="lightning-semi-supervised-learning",
9 | description="Ready-to-use semi-supervised learning under one common API",
10 | version="master",
11 | packages=find_packages(),
12 | include_package_data=True,
13 | install_requires=["torch==1.7.0", "torchvision", "pytorch-lightning==1.0.2"],
14 | extras_require={
15 | "test": [
16 | "coverage",
17 | "pytest",
18 | "flake8",
19 | "pre-commit",
20 | "codecov",
21 | "pytest-cov",
22 | "pytest-flake8",
23 | "flake8-black",
24 | "black",
25 | ]
26 | },
27 | )
28 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/__init__.py
--------------------------------------------------------------------------------
/tests/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/dataloader/__init__.py
--------------------------------------------------------------------------------
/tests/dataloader/test_base_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 | import numpy as np
4 | from lightning_ssl.dataloader.base_data import (
5 | get_split_indices,
6 | CustomSemiDataset,
7 | MultiDataset,
8 | MagicClass,
9 | )
10 | from torch.utils.data import TensorDataset, DataLoader
11 |
12 |
13 | @pytest.fixture(
14 | params=[
15 | {"max_label": 10, "num_data": 1000},
16 | {"max_label": 10, "num_data": 1000},
17 | {"max_label": 10, "num_data": 1000},
18 | ]
19 | )
20 | def label_data(request):
21 | param = request.param
22 |
23 | while True:
24 | label_proportion = np.random.uniform(
25 | 0, 1, size=(param["max_label"],)
26 | ) # the proportion of each label
27 |
28 | label_proportion = (
29 | label_proportion / label_proportion.sum()
30 | ) # normalize to summation of proportion is 1
31 |
32 | # after this normalization
33 | label_proportion = np.round_(
34 | label_proportion, decimals=int(np.log10(param["num_data"]) - 1)
35 | )
36 |
37 | residual = 1.0 - label_proportion.sum()
38 | # add the residual to the last element
39 | label_proportion[-1] += residual
40 |
41 | if label_proportion.min() > 0: # valid proportion
42 |
43 | label_proportion = (label_proportion * param["num_data"]).astype(int)
44 |
45 | if label_proportion.sum() == param["num_data"]: # valid proportion
46 | label_data = []
47 |
48 | for i, p in enumerate(label_proportion):
49 | label_data.append(i * np.ones((p,)))
50 |
51 | label_data = np.hstack(label_data)
52 |
53 | np.random.shuffle(label_data)
54 |
55 | return label_data
56 |
57 |
58 | def test_get_split_indices(label_data):
59 |     # the class proportions in label_data are specially designed
60 |     # so that (num_labeled * proportion), (num_val * proportion)
61 |     # and (num_unlabeled * proportion)
62 |     # are all integers
63 |
64 | num_labeled = int(0.1 * len(label_data))
65 | num_val = int(0.1 * len(label_data))
66 | n_classes = int(np.max(label_data).item()) + 1
67 | num_unlabeled = len(label_data) - num_labeled - num_val
68 | labeled_indices, unlabeled_indices, val_indices = get_split_indices(
69 | label_data, num_labeled, num_val, n_classes
70 | )
71 |
72 |     def convert_idx_to_num(indices):
73 |         num_list = []
74 |         for c in range(n_classes):
75 |             num_list.append((label_data[indices] == c).sum() / len(indices))
76 |         return num_list
77 |
78 | assert len(val_indices) == num_val
79 | assert len(labeled_indices) == num_labeled
80 | assert len(unlabeled_indices) == num_unlabeled
81 |
82 | # check the proportion of data
83 | assert convert_idx_to_num(labeled_indices) == convert_idx_to_num(
84 | np.arange(len(label_data))
85 | )
86 | assert convert_idx_to_num(val_indices) == convert_idx_to_num(
87 | np.arange(len(label_data))
88 | )
89 | assert convert_idx_to_num(unlabeled_indices) == convert_idx_to_num(
90 | np.arange(len(label_data))
91 | )
92 |
93 |
94 | def test_custom_dataset():
95 | dataset_1 = TensorDataset(torch.arange(10))
96 | dataset_2 = TensorDataset(torch.arange(30, 75))
97 | dataset_3 = TensorDataset(torch.arange(30, 75))
98 | dataset_4 = TensorDataset(torch.arange(30, 75))
99 |
100 | dloader_1 = DataLoader(dataset_1, batch_size=3, shuffle=True, num_workers=0)
101 | dloader_2 = DataLoader(
102 | MultiDataset([dataset_2, dataset_3, dataset_4]),
103 | batch_size=3,
104 | shuffle=True,
105 | num_workers=0,
106 | )
107 |
108 | for _ in range(10):
109 | for batch in MagicClass([dloader_1, dloader_2]):
110 | batch = batch[1]
111 | for b in batch[1:]:
112 | assert torch.all(
113 | batch[0][0] == b[0]
114 | ), "The data in multidataset should be the same"
115 |
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/models/__init__.py
--------------------------------------------------------------------------------
/tests/models/test_basemodel.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torch.nn
4 | import pytest
5 |
6 | from copy import deepcopy
7 |
8 | from lightning_ssl.models import WideResNet
9 |
10 |
11 | @pytest.fixture(scope="module")
12 | def model_inputs_pair():
13 | return WideResNet(28, 10), torch.randn((5, 3, 32, 32))
14 |
15 |
16 | def test_pseudo_label(model_inputs_pair):
17 | model = WideResNet(28, 10)
18 | model.eval()
19 | inputs_list = [torch.randn(5, 3, 32, 32) for _ in range(3)]
20 |
21 | assert model.psuedo_label(inputs_list, 0.5).shape == torch.Size([5, 10])
22 |
23 | assert torch.all(
24 | torch.isclose(
25 | model.psuedo_label(inputs_list, 0.5),
26 | model.psuedo_label(inputs_list, 0.5, True),
27 | )
28 | )
29 |
30 |
31 | def test_freeze_running_stats(model_inputs_pair):
32 | model, inputs = model_inputs_pair
33 | model.freeze_running_stats()
34 | before_stats = deepcopy(model.extract_running_stats())
35 | # run the network
36 | model(inputs)
37 | # the running_stats should not change
38 | after_stats = deepcopy(model.extract_running_stats())
39 |
40 | # check the number of batch norm layers
41 | assert len(before_stats) == (28 - 4) // 6 * 2 * 3 + 1
42 |
43 | for f, s in zip(before_stats, after_stats):
44 | assert torch.all(torch.eq(f[0], s[0]))
45 | assert torch.all(torch.eq(f[1], s[1]))
46 |
47 |
48 | def test_recover_running_stats(model_inputs_pair):
49 | model, inputs = model_inputs_pair
50 | # recover running stats
51 | model.recover_running_stats()
52 | before_stats = deepcopy(model.extract_running_stats())
53 | # run the network
54 | model(inputs)
55 |
56 | # the running_stats should not change
57 | after_stats = deepcopy(model.extract_running_stats())
58 |
59 | for f, s in zip(before_stats, after_stats):
60 | assert torch.all(~torch.eq(f[0], s[0]))
61 | assert torch.all(~torch.eq(f[1], s[1]))
62 |
63 |
64 | def test_recover_running_stats__without_freeze_first():
65 | model = WideResNet(28, 10)
66 | before_stats = model.extract_norm_n_momentum()
67 | model.recover_running_stats()
68 | after_stats = model.extract_norm_n_momentum()
69 |
70 | # test batchnorm statistics
71 | for f, s in zip(before_stats[0], after_stats[0]):
72 | assert torch.all(torch.eq(f.running_mean, s.running_mean))
73 | assert torch.all(torch.eq(f.running_var, s.running_var))
74 | # test momentum
75 | for f, s in zip(before_stats[1], after_stats[1]):
76 | assert f == s
77 |
--------------------------------------------------------------------------------
/tests/models/test_cnn13.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 |
4 | from copy import deepcopy
5 |
6 | from lightning_ssl.models import CNN13
7 |
8 |
9 | @pytest.mark.parametrize(
10 | "shape, num_classes", [((5, 3, 32, 32), 8), ((10, 3, 32, 32), 15)]
11 | )
12 | def test_cnn13(shape, num_classes):
13 | inputs = torch.randn(shape)
14 | model = CNN13(num_classes)
15 | batch_size = inputs.shape[0]
16 | assert model(inputs).shape == torch.Size([batch_size, num_classes])
17 |
--------------------------------------------------------------------------------
/tests/models/test_wideresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 |
4 | from copy import deepcopy
5 |
6 | from lightning_ssl.models import WideResNet
7 |
8 |
9 | @pytest.mark.parametrize(
10 | "shape, num_classes", [((5, 3, 32, 32), 8), ((10, 3, 32, 32), 15)]
11 | )
12 | def test_wideresnet(shape, num_classes):
13 | inputs = torch.randn(shape)
14 | model = WideResNet(28, num_classes, dropRate=0.2)
15 | batch_size = inputs.shape[0]
16 | assert model(inputs).shape == torch.Size([batch_size, num_classes])
17 |
--------------------------------------------------------------------------------
/tests/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/module/__init__.py
--------------------------------------------------------------------------------
/tests/module/assets/tmp_config.yml:
--------------------------------------------------------------------------------
1 | cifar10:
2 | algo: "mixmatch"
3 | learning_scenario: learning_scenario
4 | batch_size: 10
5 | num_workers: 0
6 | log_root: "tests"
7 | num_labeled: 10
8 | num_val: 30
9 |
--------------------------------------------------------------------------------
/tests/module/simple_data.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from lightning_ssl.dataloader.base_data import (
4 | SemiDataLoader,
5 | SupervisedDataLoader,
6 | SemiDataModule,
7 | SupervisedDataModule,
8 | )
9 |
10 | import torchvision as tv
11 | from torch.utils.data import DataLoader, SubsetRandomSampler
12 | from torchvision.datasets import CIFAR10
13 |
14 |
15 | CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
16 | CIFAR_STD = (0.2471, 0.2435, 0.2616)
17 |
18 |
19 | def sample_dataset():
20 | return [
21 | (np.random.uniform(0, 1, (3, 2, 2)).astype(np.float32), random.randint(0, 9))
22 | for i in range(50)
23 | ]
24 |
25 |
26 | class SemiSampleModule(SemiDataModule):
27 | def __init__(
28 | self,
29 | args,
30 | data_root,
31 | num_workers,
32 | batch_size,
33 | num_labeled,
34 | num_val,
35 | num_augments,
36 | ):
37 | n_classes = 10
38 | super(SemiSampleModule, self).__init__(
39 | data_root,
40 | num_workers,
41 | batch_size,
42 | num_labeled,
43 | num_val,
44 | num_augments,
45 | n_classes,
46 | )
47 |
48 | def prepare_data(self):
49 | # the transformation for train and validation dataset will be
50 | # done in _prepare_train_dataset()
51 |
52 | self.train_set = sample_dataset()
53 | self.test_set = sample_dataset()
54 |
55 |
56 | class SupervisedSampleModule(SupervisedDataModule):
57 | def __init__(
58 | self,
59 | args,
60 | data_root,
61 | num_workers,
62 | batch_size,
63 | num_labeled,
64 | num_val,
65 | num_augments,
66 | ):
67 | n_classes = 10
68 | super(SupervisedSampleModule, self).__init__(
69 | data_root, num_workers, batch_size, num_labeled, num_val, n_classes
70 | )
71 |
72 | def prepare_data(self):
73 | # the transformation for train and validation dataset will be
74 | # done in _prepare_train_dataset()
75 | self.train_set = sample_dataset()
76 | self.test_set = sample_dataset()
77 |
--------------------------------------------------------------------------------
/tests/module/simple_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from lightning_ssl.models.base_model import CustomModel
3 |
4 |
5 | class SimpleModel(CustomModel):
6 | def __init__(self, num_classes):
7 | super().__init__()
8 | self.num_classes = num_classes
9 | self.layer_1 = nn.AvgPool2d(kernel_size=2)
10 | self.layer_2 = nn.Linear(3, num_classes)
11 |
12 | def forward(self, inputs):
13 | output = self.layer_1(inputs).reshape(inputs.shape[0], -1)
14 | return self.layer_2(output)
15 |
--------------------------------------------------------------------------------
/tests/module/test_mixmatch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import torch
5 | import pytest
6 | import pathlib
7 | import numpy as np
8 | from tests.module.simple_model import SimpleModel
9 | from lightning_ssl.module import ClassifierModule, MixmatchModule
10 | from lightning_ssl.dataloader import (
11 | SemiCIFAR10Module,
12 | SupervisedCIFAR10Module,
13 | )
14 | from lightning_ssl.models import WideResNet
15 | import pytorch_lightning as pl
16 | from lightning_ssl.utils import count_parameters
17 | from lightning_ssl.utils.argparser import parser, print_args
18 |
19 | PATH = pathlib.Path(__file__).parent
20 | CONFIG_PATH = os.path.join(PATH, "assets", "tmp_config.yml")
21 |
22 |
23 | @pytest.mark.parametrize("batch_inference", ["True", "False"])
24 | def test_mixmatch(tmpdir, batch_inference):
25 | original_argv = copy.deepcopy(sys.argv)
26 | sys.argv = [
27 | "tmp.py",
28 | "-c",
29 | f"{CONFIG_PATH}",
30 | "--log_root",
31 | str(tmpdir),
32 | "--batch_inference",
33 | batch_inference,
34 | ]
35 | print(sys.argv)
36 | args = parser()
37 | sys.argv = original_argv
38 |
39 | gpus = args.gpus if torch.cuda.is_available() else None
40 |
41 | # load the data and classifier
42 |
43 | data_loader = SemiCIFAR10Module(
44 | args,
45 | args.data_root,
46 | args.num_workers,
47 | args.batch_size,
48 | args.num_labeled,
49 | args.num_val,
50 | args.num_augments,
51 | )
52 | classifier = WideResNet(depth=10, num_classes=data_loader.n_classes)
53 |
54 |     print(f"model parameters: {count_parameters(classifier) / 1e6:.2f} M")
55 |
56 | # set the number of classes in the args
57 | setattr(args, "n_classes", data_loader.n_classes)
58 |
59 | data_loader.prepare_data()
60 | data_loader.setup()
61 |
62 | model = MixmatchModule(args, classifier, loaders=None)
63 |
64 | print(
65 |         f"labeled size: {data_loader.num_labeled_data}, "
66 | f"unlabeled size: {data_loader.num_unlabeled_data}, "
67 | f"val size: {data_loader.num_val_data}, "
68 | f"test size: {data_loader.num_test_data}"
69 | )
70 |
71 | save_folder = f"{args.dataset}_{args.learning_scenario}_{args.algo}_{args.affix}"
72 |
73 | tb_logger = pl.loggers.TensorBoardLogger(
74 | os.path.join(args.log_root, "lightning_logs"), name=save_folder
75 | )
76 | tb_logger.log_hyperparams(args)
77 |
78 | # set the path of checkpoint
79 | save_dir = getattr(tb_logger, "save_dir", None) or getattr(
80 | tb_logger, "_save_dir", None
81 | )
82 | ckpt_path = os.path.join(
83 | save_dir, tb_logger.name, f"version_{tb_logger.version}", "checkpoints"
84 | )
85 |
86 | ckpt = pl.callbacks.ModelCheckpoint(filepath=os.path.join(ckpt_path, "{epoch}"))
87 |
88 | setattr(args, "checkpoint_folder", ckpt_path)
89 |
90 | print_args(args)
91 |
92 | trainer = pl.trainer.Trainer(
93 | gpus=gpus,
94 | logger=tb_logger,
95 | checkpoint_callback=ckpt,
96 | fast_dev_run=True,
97 | reload_dataloaders_every_epoch=True,
98 | progress_bar_refresh_rate=args.progress_bar_refresh_rate,
99 | )
100 |
101 | trainer.fit(model, datamodule=data_loader)
102 | trainer.test()
103 |
--------------------------------------------------------------------------------
/tests/module/test_supervised.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import torch
5 | import pytest
6 | import pathlib
7 | import numpy as np
8 | from lightning_ssl.module import ClassifierModule, MixmatchModule
9 | from lightning_ssl.dataloader import (
10 | SemiCIFAR10Module,
11 | SupervisedCIFAR10Module,
12 | )
13 | from lightning_ssl.models import WideResNet
14 | import pytorch_lightning as pl
15 | from lightning_ssl.utils import count_parameters
16 | from lightning_ssl.utils.argparser import parser, print_args
17 |
18 | PATH = pathlib.Path(__file__).parent
19 | CONFIG_PATH = os.path.join(PATH, "assets", "tmp_config.yml")
20 |
21 |
22 | def test_supervised(tmpdir):
23 | original_argv = copy.deepcopy(sys.argv)
24 | sys.argv = ["tmp.py", "-c", f"{CONFIG_PATH}", "--log_root", str(tmpdir)]
25 | print(sys.argv)
26 | args = parser()
27 | sys.argv = original_argv
28 |
29 | gpus = args.gpus if torch.cuda.is_available() else None
30 |
31 | # load the data and classifier
32 |
33 | data_loader = SupervisedCIFAR10Module(
34 | args,
35 | args.data_root,
36 | args.num_workers,
37 | args.batch_size,
38 | args.num_labeled,
39 | args.num_val,
40 | args.num_augments,
41 | )
42 | classifier = WideResNet(depth=10, num_classes=data_loader.n_classes)
43 |
44 |     print(f"model parameters: {count_parameters(classifier) / 1e6:.2f} M")
45 |
46 | # set the number of classes in the args
47 | setattr(args, "n_classes", data_loader.n_classes)
48 |
49 | data_loader.prepare_data()
50 | data_loader.setup()
51 |
52 | model = ClassifierModule(args, classifier, loaders=None)
53 |
54 | print(
55 |         f"labeled size: {data_loader.num_labeled_data}, "
56 | f"unlabeled size: {data_loader.num_unlabeled_data}, "
57 | f"val size: {data_loader.num_val_data}, "
58 | f"test size: {data_loader.num_test_data}"
59 | )
60 |
61 | save_folder = f"{args.dataset}_{args.learning_scenario}_{args.algo}_{args.affix}"
62 |
63 | tb_logger = pl.loggers.TensorBoardLogger(
64 | os.path.join(args.log_root, "lightning_logs"), name=save_folder
65 | )
66 | tb_logger.log_hyperparams(args)
67 |
68 | # set the path of checkpoint
69 | save_dir = getattr(tb_logger, "save_dir", None) or getattr(
70 | tb_logger, "_save_dir", None
71 | )
72 | ckpt_path = os.path.join(
73 | save_dir, tb_logger.name, f"version_{tb_logger.version}", "checkpoints"
74 | )
75 |
76 | ckpt = pl.callbacks.ModelCheckpoint(filepath=os.path.join(ckpt_path, "{epoch}"))
77 |
78 | setattr(args, "checkpoint_folder", ckpt_path)
79 |
80 | print_args(args)
81 |
82 | trainer = pl.trainer.Trainer(
83 | gpus=gpus,
84 | logger=tb_logger,
85 | checkpoint_callback=ckpt,
86 | fast_dev_run=True,
87 | reload_dataloaders_every_epoch=True,
88 | )
89 |
90 | trainer.fit(model, datamodule=data_loader)
91 | trainer.test()
92 |
--------------------------------------------------------------------------------
/tests/test_models:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/test_models
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ylsung/lightning-semi-supervised-learning/23aab968b279975df1cb86763238c0c3ee558a27/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/assets/tmp_config.yml:
--------------------------------------------------------------------------------
1 | cifar10:
2 | todo: "test"
3 | affix: "test"
4 |
--------------------------------------------------------------------------------
/tests/utils/test_argparser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import pathlib
5 | from lightning_ssl.utils.argparser import parser
6 |
7 | PATH = pathlib.Path(__file__).parent
8 | CONFIG_PATH = os.path.join(PATH, "assets", "tmp_config.yml")
9 |
10 |
11 | def test_parser():
12 | original_argv = copy.deepcopy(sys.argv)
13 | sys.argv = [
14 | "temp.py",
15 |         "-c",
16 | f"{CONFIG_PATH}",
17 | "--learning_scenario",
18 | "supervised",
19 | "--todo",
20 | "train",
21 | ]
22 |
23 | args = parser()
24 | sys.argv = original_argv
25 |
26 | assert args.learning_scenario == "supervised"
27 | assert args.todo == "train"
28 | assert args.affix == "test"
29 |
--------------------------------------------------------------------------------
/tests/utils/test_torch_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import pytest
3 | import torch
4 | import numpy as np
5 | import torchvision
6 | from copy import deepcopy
7 |
8 | from lightning_ssl.utils.torch_utils import (
9 | sharpening,
10 | soft_cross_entropy,
11 | l2_distribution_loss,
12 | mixup_data,
13 | smooth_label,
14 | to_one_hot_vector,
15 | mixup_data_for_testing,
16 | half_mixup_data_for_testing,
17 | customized_weight_decay,
18 | split_weight_decay_weights,
19 | interleave,
20 | my_interleave,
21 | WeightDecayModule,
22 | )
23 |
24 | from lightning_ssl.models import WideResNet
25 |
26 |
27 | @pytest.fixture
28 | def logits_targets_pair():
29 | """
30 |     The fixture to test softmax-related functions. Returns the logits and soft targets.
31 | Return
32 | logits, targets
33 | """
34 | logits = [[-1, 1, 0], [1, 0, -1]]
35 | targets = [[0.3, 0.6, 0.1], [0.5, 0.2, 0.3]]
36 | return torch.FloatTensor(logits), torch.FloatTensor(targets)
37 |
38 |
39 | @pytest.fixture(
40 | params=[
41 | {"batch_size": 1, "num_classes": 5, "smoothing_factor": 0.3},
42 | {"batch_size": 3, "num_classes": 5, "smoothing_factor": 0.01},
43 | {"batch_size": 3, "num_classes": 5, "smoothing_factor": 0.0},
44 | {"batch_size": 3, "num_classes": 5, "smoothing_factor": 1},
45 | ]
46 | )
47 | def targets_smoothing_classes_tuple(request):
48 | param = request.param
49 | return (
50 | torch.randint(param["num_classes"], size=(param["batch_size"],)),
51 | param["smoothing_factor"],
52 | param["num_classes"],
53 | )
54 |
55 |
56 | @pytest.fixture(
57 | params=[
58 | {"batch_size": 20, "num_classes": 10},
59 | {"batch_size": 30, "num_classes": 2},
60 | {"batch_size": 20, "num_classes": 100},
61 | ]
62 | )
63 | def images_targets_num_classes_tuple(request):
64 | param = request.param
65 | batch_size = param["batch_size"]
66 | num_classes = param["num_classes"]
67 | return (
68 | torch.randn(batch_size, 3, 32, 32),
69 | torch.randint(num_classes, size=(batch_size,)),
70 | num_classes,
71 | )
72 |
73 |
74 | @pytest.mark.parametrize("temperature", [0.1, 0.5, 1, 2, 10])
75 | def test_sharpening(temperature):
76 | def entropy(probs):
77 | return torch.sum(-probs * torch.log(probs), -1).mean()
78 |
79 | for _ in range(10):
80 | logits = torch.randn(4, 10)
81 | probs = torch.softmax(logits, -1)
82 | sharpening_probs = sharpening(probs, temperature)
83 |
84 | if temperature > 1:
85 | assert entropy(sharpening_probs) > entropy(probs)
86 | elif temperature < 1:
87 | assert entropy(sharpening_probs) < entropy(probs)
88 | else:
89 | assert torch.isclose(entropy(sharpening_probs), entropy(probs))
90 |
91 |
92 | def test_soft_cross_entropy(logits_targets_pair):
93 | logits, targets = logits_targets_pair
94 |
95 | probs = torch.softmax(logits, -1)
96 |
97 | ans = 0
98 | for i in range(logits.shape[0]):
99 | for c in range(logits.shape[1]):
100 | ans += -torch.log(probs[i][c]) * targets[i][c]
101 |
102 | ans /= logits.shape[0]
103 | func_out = soft_cross_entropy(logits, targets, dim=-1)
104 |
105 | assert ans == func_out
106 |
107 |
108 | def test_l2_distribution_loss(logits_targets_pair):
109 | logits, targets = logits_targets_pair
110 |
111 | probs = torch.softmax(logits, -1)
112 |
113 | ans = 0
114 | for i in range(logits.shape[0]):
115 | for c in range(logits.shape[1]):
116 | ans += (probs[i][c] - targets[i][c]) ** 2
117 |
118 | ans /= logits.shape[0] * logits.shape[1]
119 | func_out = l2_distribution_loss(logits, targets, dim=-1)
120 |
121 | assert ans == func_out
122 |
123 |
124 | def test_smooth_label(targets_smoothing_classes_tuple):
125 | targets, smoothing_factor, num_classes = targets_smoothing_classes_tuple
126 | ##################################################################
127 | # test for label smoothing
128 | batch_size = targets.shape[0]
129 | y_labels = smooth_label(targets, num_classes, smoothing_factor)
130 |
131 | # predicted classes
132 | pred = torch.argmax(y_labels, dim=-1)
133 |
134 | # the logits of maximum classes should be (1 - smoothing_factor)
135 | assert torch.all(
136 | y_labels[torch.arange(batch_size), targets] == 1 - smoothing_factor
137 | )
138 | # all the other logits should be the same
139 | assert torch.sum(y_labels != 1 - smoothing_factor) == batch_size * (num_classes - 1)
140 | # summation should be equal to 1
141 | assert torch.all(torch.isclose(torch.sum(y_labels, -1), torch.ones(pred.shape)))
142 |
143 |
144 | def test_one_hot(targets_smoothing_classes_tuple):
145 | targets, _, num_classes = targets_smoothing_classes_tuple
146 | batch_size = targets.shape[0]
147 | # test for one hot transformation
148 | y_labels = to_one_hot_vector(targets, num_classes)
149 |
150 | # predicted classes
151 | pred = torch.argmax(y_labels, dim=-1)
152 |
153 | # the logits of maximum classes should be 1
154 | assert torch.all(y_labels[torch.arange(batch_size), pred] == 1)
155 | # the maximum classes should be the same as targets
156 | assert torch.all(torch.eq(pred, targets))
157 | # summation should be equal to 1
158 | assert torch.all(torch.isclose(torch.sum(y_labels, -1), torch.ones(pred.shape)))
159 |
160 |
161 | def test_mixup_minus(images_targets_num_classes_tuple):
162 | inputs, targets, _ = images_targets_num_classes_tuple
163 | mixed_inputs, mixed_targets = mixup_data(inputs, targets, -1)
164 |
165 | assert torch.all(torch.eq(inputs, mixed_inputs))
166 | assert torch.all(torch.eq(targets, mixed_targets))
167 |
168 |
169 | def test_mixup(images_targets_num_classes_tuple):
170 | inputs, targets, num_classes = images_targets_num_classes_tuple
171 |
172 | logits = smooth_label(targets, num_classes)
173 |
174 | _, _, _, _, gt_lambda = mixup_data_for_testing(inputs, logits, 0.5)
175 |
176 |     # the mixed data should lie between its two ingredients
177 |
178 |     # exclude the data that are shuffled to the same place;
179 |     # this makes the mixed data equal to the original ingredients,
180 |     # but two equal numbers can still be considered different due to floating point precision
181 |
182 | # test lambda is within the range
183 | assert torch.all((gt_lambda >= 0) * (gt_lambda <= 1))
184 |
185 | # same_pos_x = torch.all(torch.isclose(inputs, p_inputs).reshape(batch_size, -1), dim=-1)
186 |
187 | # inputs, p_inputs, logits, p_logits = \
188 | # inputs[~same_pos_x], p_inputs[~same_pos_x], logits[~same_pos_x], p_logits[~same_pos_x]
189 |
190 | # mixed_x, mixed_y = mixed_x[~same_pos_x], mixed_y[~same_pos_x]
191 |
192 | # min_x, max_x = torch.min(inputs, p_inputs), torch.max(inputs, p_inputs)
193 | # min_y, max_y = torch.min(logits, p_logits), torch.max(logits, p_logits)
194 |
195 | # pos = ~((min_x <= mixed_x) * (mixed_x <= max_x))
196 |
197 | # print(inputs[pos]==mixed_x[pos])
198 | # print(p_inputs[pos])
199 |
200 | # print(mixed_x[pos])
201 |
202 | # assert torch.all((min_x <= mixed_x) * (mixed_x <= max_x)) == True
203 |
204 | # def reconstruct(mixed_d, original_d, shuffle_d):
205 | # numerator = mixed_d - shuffle_d
206 | # denominator = original_d - shuffle_d
207 | # mask = (denominator == 0).float()
208 |
209 | # reverted_l = numerator / (denominator + mask * 1e-15)
210 |
211 | # return reverted_l
212 |
213 | # l_1 = reconstruct(mixed_x, inputs, p_inputs)
214 | # l_1 = torch.mode(l_1.reshape(l_1.shape[0], -1), -1)[0]
215 | # l_2 = reconstruct(mixed_y, y_labels, p_y_labels)
216 | # l_2 = torch.max(l_2.reshape(l_2.shape[0], -1), -1)[0]
217 |
218 | # for i in range(l_1.shape[0]):
219 | # # only check when the y_label and permuted y_label are different,
220 | # # or the value will be 0
221 | # if torch.all(y_labels[i] == p_y_labels[i]):
222 | # assert l_2[i] == 0
223 |
224 | # elif torch.all(inputs[i] == p_inputs[i]):
225 | # assert l_1[i] == 0
226 | # else:
227 | # assert torch.isclose(l_1[i], l_2[i], atol=1e-5) == 0
228 | # assert 0.0 <= l_1[i] <= 1 == True
229 |
230 |
231 | def test_half_mixup(images_targets_num_classes_tuple):
232 | inputs, targets, num_classes = images_targets_num_classes_tuple
233 |
234 | logits = smooth_label(targets, num_classes)
235 |
236 | _, _, _, _, gt_lambda = half_mixup_data_for_testing(inputs, logits, 0.5)
237 |
238 | # test lambda is within the range
239 | assert torch.all((gt_lambda >= 0.5) * (gt_lambda <= 1))
240 |
241 |
242 | def test_customized_weight_decay():
243 | m = torchvision.models.resnet34()
244 | m_copy = deepcopy(m)
245 |
246 | ignore_key = ["bn", "bias"]
247 | wd = 1e-2
248 | customized_weight_decay(m, wd, ignore_key=ignore_key)
249 |
250 | for (name_m, p_m), (name_copy, p_copy) in zip(
251 | m.named_parameters(), m_copy.named_parameters()
252 | ):
253 | assert name_m == name_copy
254 | # ignore, and don't decay weight
255 | if any(key in name_m for key in ignore_key):
256 | assert torch.all(torch.eq(p_m, p_copy))
257 | else:
258 | assert torch.all(torch.eq(p_m, p_copy * (1 - wd)))
259 |
260 |
261 | def test_weight_decay_module():
262 | m = WideResNet(28, 10)
263 |
264 | wd_module = WeightDecayModule(m, 0.1, ["bn", "bias"])
265 |
266 | num_weights = 29 # weights of conv and fc
267 |
268 | assert len(wd_module.available_parameters) == num_weights
269 |
270 | wd_module = WeightDecayModule(m, 0.1, ["weight"])
271 |
272 | num_bias = 25 + 1 + 0 # bias of bn + bias of fc + bias of conv
273 |
274 | assert len(wd_module.available_parameters) == num_bias
275 |
276 | original_parameters = deepcopy(wd_module.available_parameters)
277 | wd_module.decay()
278 |
279 | for original_weight, decay_weight in zip(
280 | original_parameters, wd_module.available_parameters
281 | ):
282 | assert torch.all(torch.eq(original_weight * 0.9, decay_weight))
283 |
284 |
285 | def test_split_weight_decay_weights():
286 | m = torchvision.models.resnet34()
287 | m_copy = deepcopy(m)
288 |
289 | ignore_key = ["bn", "bias"]
290 | wd = 1e-2
291 | customized_weight_decay(m, wd, ignore_key=ignore_key)
292 |
293 | num_weight_decay = 0
294 | num_no_weight_decay = 0
295 | for (name_m, _), (name_copy, _) in zip(
296 | m.named_parameters(), m_copy.named_parameters()
297 | ):
298 | assert name_m == name_copy
299 | # ignore, and don't decay weight
300 | if any(key in name_m for key in ignore_key):
301 | num_no_weight_decay += 1
302 | else:
303 | num_weight_decay += 1
304 |
305 | params_list = split_weight_decay_weights(m, wd, ignore_key=ignore_key)
306 |
307 | # first item is no weight decay
308 |
309 | assert len(params_list[0]["params"]) == num_no_weight_decay
310 | assert params_list[0]["weight_decay"] == 0
311 | assert len(params_list[1]["params"]) == num_weight_decay
312 | assert params_list[1]["weight_decay"] == wd
313 |
314 |
315 | @pytest.mark.parametrize("shape", [[], [3, 32, 32]])
316 | @pytest.mark.parametrize(
317 | "batch_size, num_data", [(5, 3), (128, 3), (50, 2), (3, 168), (168, 168)]
318 | )
319 | def test_interleave(shape, batch_size, num_data):
320 | print(shape, batch_size, num_data)
321 | inputs = [torch.randint(100, [batch_size] + shape) for _ in range(num_data)]
322 | orig_inputs = deepcopy(inputs)
323 |
324 | # test_interleave
325 |
326 | o_1 = interleave(inputs, batch_size)
327 | o_3 = my_interleave(inputs, batch_size)
328 |
329 | for l, r in zip(o_1, o_3):
330 | assert torch.all(torch.eq(l, r))
331 |
332 | # reverse interleave
333 | o_1 = interleave(o_1, batch_size)
334 | o_3 = my_interleave(o_3, batch_size)
335 |
336 | for a_, l, r in zip(orig_inputs, o_1, o_3):
337 | assert torch.all(torch.eq(a_, r))
338 | assert torch.all(torch.eq(l, r))
339 |
--------------------------------------------------------------------------------
/tests/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import subprocess
4 | from lightning_ssl.utils.utils import makedirs
5 |
6 |
7 | def test_makedirs(tmpdir):
8 | folder = os.path.join(tmpdir, "a/b")
9 | makedirs(folder)
10 | assert os.path.exists(folder)
11 |
--------------------------------------------------------------------------------