├── .gitignore
├── .idea
├── .gitignore
├── deployment.xml
├── fixmatch_jian.iml
├── inspectionProfiles
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
└── vcs.xml
├── LICENSE
├── README.md
├── datasets
├── __init__.py
├── cifar.py
├── randaugment.py
├── sampler.py
└── transform.py
├── label_guessor.py
├── lr_scheduler.py
├── models
├── __init__.py
├── ema.py
└── model.py
├── requirements.txt
├── run.sh
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | parts/
18 | sdist/
19 | var/
20 | wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
26 | # PyInstaller
27 | # Usually these files are written by a python script from a template
28 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
29 | *.manifest
30 | *.spec
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .coverage
40 | .coverage.*
41 | .cache
42 | nosetests.xml
43 | coverage.xml
44 | *.cover
45 | .hypothesis/
46 | .pytest_cache/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 | db.sqlite3
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # Environments
83 | .env
84 | .venv
85 | env/
86 | venv/
87 | ENV/
88 | env.bak/
89 | venv.bak/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 |
95 | # Rope project settings
96 | .ropeproject
97 |
98 | # mkdocs documentation
99 | /site
100 |
101 | # mypy
102 | .mypy_cache/
103 |
104 |
105 | ## Coin:
106 | data
107 | dataset/
108 | data/
109 | play.py
110 | preprocess_data.py
111 | res/
112 | labels.py
113 | adj.md
114 | bayesian_search.py
115 | evaluate.py.old
116 | *png
117 | infer.py
118 |
119 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/.idea/fixmatch_jian.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 CoinCheung
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## FixMatch
3 |
4 | The unofficial reimplementation of [fixmatch](https://arxiv.org/abs/2001.07685) with RandomAugment.
5 |
6 | ## Overview
7 |
8 |
9 | |repo|using EMA model to evaluate|using EMA model to train|update parameters|update buffer|
10 | |:---|:---:|:---:|:---:|:---:|
11 | |ours| ✓|-| ✓|-|
12 | |mdiephuis| ✓| ✓| ✓|-|
13 | |kekmodel| ✓|-|-| ✓|
14 |
15 |
16 | 2020-03-30_18:07:08.log : annotation decay and add classifier.bias
17 |
18 | 2020-03-31_09:51:38.log : add interleave and run model once
19 |
20 | ## Dependencies
21 |
22 | - python 3.6
23 | - pytorch 1.3.1
24 | - torchvision 0.2.1
25 |
26 | The other packages and versions are listed in ```requirements.txt```.
27 | You can install them by ```pip install -r requirements.txt```.
28 |
29 |
30 | ## Dataset
31 | download cifar-10 dataset:
32 | ```
33 | $ mkdir -p dataset && cd data
34 | $ wget -c http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
35 | $ tar -xzvf cifar-10-python.tar.gz
36 | ```
37 |
38 | download cifar-100 dataset:
39 | ```
40 | $ mkdir -p dataset && cd data
41 | $ wget -c http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
42 | $ tar -xzvf cifar-100-python.tar.gz
43 | ```
44 |
45 | ## Train the model
46 |
47 | To train the model on CIFAR10 with 40 labeled samples, you can run the script:
48 | ```
49 | $ CUDA_VISIBLE_DEVICES='0' python train.py --dataset CIFAR10 --n-labeled 40
50 | ```
51 | To train the model on CIFAR100 with 400 labeled samples, you can run the script:
52 | ```
53 | $ CUDA_VISIBLE_DEVICES='0' python train.py --dataset CIFAR100 --n-labeled 400
54 | ```
55 |
56 |
57 | ## Results
58 |
59 |
60 | ### CIFAR10
61 | | #Labels | 40 | 250 | 4000 |
62 | |:---|:---:|:---:|:---:|
63 | |Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 |
64 | |ours| 89.63(85.65) | 93.0832 |94.7154|
65 |
66 | ### CIFAR100
67 |
68 | | #Labels | 400 | 2500 | 10000 |
69 | |:---|:---:|:---:|:---:|
70 | |Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 |
71 | |ours | 53.74 | 67.3169 | 73.26 |
72 |
73 |
74 | ### References
75 | - https://github.com/CoinCheung/fixmatch
76 | - https://github.com/kekmodel/FixMatch-pytorch
77 | - official implement https://github.com/google-research/fixmatch
78 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/valencebond/FixMatch_pytorch/3b7fc36ca0c71754a59c9c78465d681f259ee174/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/cifar.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 | import numpy as np
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 | from datasets import transform as T
8 |
9 | from datasets.randaugment import RandomAugment
10 | from datasets.sampler import RandomSampler, BatchSampler
11 |
12 |
13 | def load_data_train(L=250, dataset='CIFAR10', dspth='./data'):
14 | if dataset == 'CIFAR10':
15 | datalist = [
16 | osp.join(dspth, 'cifar-10-batches-py', 'data_batch_{}'.format(i + 1))
17 | for i in range(5)
18 | ]
19 | n_class = 10
20 | assert L in [40, 250, 4000]
21 | elif dataset == 'CIFAR100':
22 | datalist = [
23 | osp.join(dspth, 'cifar-100-python', 'train')]
24 | n_class = 100
25 | assert L in [400, 2500, 10000]
26 |
27 | data, labels = [], []
28 | for data_batch in datalist:
29 | with open(data_batch, 'rb') as fr:
30 | entry = pickle.load(fr, encoding='latin1')
31 | lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels']
32 | data.append(entry['data'])
33 | labels.append(lbs)
34 | data = np.concatenate(data, axis=0)
35 | labels = np.concatenate(labels, axis=0)
36 | n_labels = L // n_class
37 | data_x, label_x, data_u, label_u = [], [], [], []
38 | for i in range(n_class):
39 | indices = np.where(labels == i)[0]
40 | np.random.shuffle(indices)
41 | inds_x, inds_u = indices[:n_labels], indices[n_labels:]
42 | data_x += [
43 | data[i].reshape(3, 32, 32).transpose(1, 2, 0)
44 | for i in inds_x
45 | ]
46 | label_x += [labels[i] for i in inds_x]
47 | data_u += [
48 | data[i].reshape(3, 32, 32).transpose(1, 2, 0)
49 | for i in inds_u
50 | ]
51 | label_u += [labels[i] for i in inds_u]
52 | return data_x, label_x, data_u, label_u
53 |
54 |
55 | # def load_data_train(L=250, dspth='./data'):
56 | # datalist = [
57 | # osp.join(dspth, 'cifar-10-batches-py', 'data_batch_{}'.format(i + 1))
58 | # for i in range(5)
59 | # ]
60 | # data, labels = [], []
61 | # for data_batch in datalist:
62 | # with open(data_batch, 'rb') as fr:
63 | # entry = pickle.load(fr, encoding='latin1')
64 | # lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels']
65 | # data.append(entry['data'])
66 | # labels.append(lbs)
67 | # data = np.concatenate(data, axis=0)
68 | # labels = np.concatenate(labels, axis=0)
69 | # n_labels = L // 10
70 | # data_x, label_x, data_u, label_u = [], [], [], []
71 | # for i in range(10):
72 | # indices = np.where(labels == i)[0]
73 | # np.random.shuffle(indices)
74 | # inds_x, inds_u = indices[:n_labels], indices[n_labels:]
75 | # data_x += [
76 | # data[i].reshape(3, 32, 32).transpose(1, 2, 0)
77 | # for i in inds_x
78 | # ]
79 | # label_x += [labels[i] for i in inds_x]
80 | # data_u += [
81 | # data[i].reshape(3, 32, 32).transpose(1, 2, 0)
82 | # for i in inds_u
83 | # ]
84 | # label_u += [labels[i] for i in inds_u]
85 | # return data_x, label_x, data_u, label_u
86 |
87 |
88 | def load_data_val(dataset, dspth='./data'):
89 | if dataset == 'CIFAR10':
90 | datalist = [
91 | osp.join(dspth, 'cifar-10-batches-py', 'test_batch')
92 | ]
93 | elif dataset == 'CIFAR100':
94 | datalist = [
95 | osp.join(dspth, 'cifar-100-python', 'test')
96 | ]
97 |
98 | data, labels = [], []
99 | for data_batch in datalist:
100 | with open(data_batch, 'rb') as fr:
101 | entry = pickle.load(fr, encoding='latin1')
102 | lbs = entry['labels'] if 'labels' in entry.keys() else entry['fine_labels']
103 | data.append(entry['data'])
104 | labels.append(lbs)
105 | data = np.concatenate(data, axis=0)
106 | labels = np.concatenate(labels, axis=0)
107 | data = [
108 | el.reshape(3, 32, 32).transpose(1, 2, 0)
109 | for el in data
110 | ]
111 | return data, labels
112 |
113 |
114 | def compute_mean_var():
115 | data_x, label_x, data_u, label_u = load_data_train()
116 | data = data_x + data_u
117 | data = np.concatenate([el[None, ...] for el in data], axis=0)
118 |
119 | mean, var = [], []
120 | for i in range(3):
121 | channel = (data[:, :, :, i].ravel() / 127.5) - 1
122 | # channel = (data[:, :, :, i].ravel() / 255)
123 | mean.append(np.mean(channel))
124 | var.append(np.std(channel))
125 |
126 | print('mean: ', mean)
127 | print('var: ', var)
128 |
129 |
130 |
131 | class Cifar(Dataset):
132 | def __init__(self, dataset, data, labels, is_train=True):
133 | super(Cifar, self).__init__()
134 | self.data, self.labels = data, labels
135 | self.is_train = is_train
136 | assert len(self.data) == len(self.labels)
137 | if dataset == 'CIFAR10':
138 | mean, std = (0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)
139 | elif dataset == 'CIFAR100':
140 | mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
141 |
142 | if is_train:
143 | self.trans_weak = T.Compose([
144 | T.Resize((32, 32)),
145 | T.PadandRandomCrop(border=4, cropsize=(32, 32)),
146 | T.RandomHorizontalFlip(p=0.5),
147 | T.Normalize(mean, std),
148 | T.ToTensor(),
149 | ])
150 | self.trans_strong = T.Compose([
151 | T.Resize((32, 32)),
152 | T.PadandRandomCrop(border=4, cropsize=(32, 32)),
153 | T.RandomHorizontalFlip(p=0.5),
154 | RandomAugment(2, 10),
155 | T.Normalize(mean, std),
156 | T.ToTensor(),
157 | ])
158 | else:
159 | self.trans = T.Compose([
160 | T.Resize((32, 32)),
161 | T.Normalize(mean, std),
162 | T.ToTensor(),
163 | ])
164 |
165 | def __getitem__(self, idx):
166 | im, lb = self.data[idx], self.labels[idx]
167 | if self.is_train:
168 | return self.trans_weak(im), self.trans_strong(im), lb
169 | else:
170 | return self.trans(im), lb
171 |
172 | def __len__(self):
173 | leng = len(self.data)
174 | return leng
175 |
176 |
177 | def get_train_loader(dataset, batch_size, mu, n_iters_per_epoch, L, root='data'):
178 | data_x, label_x, data_u, label_u = load_data_train(L=L, dataset=dataset, dspth=root)
179 |
180 | ds_x = Cifar(
181 | dataset=dataset,
182 | data=data_x,
183 | labels=label_x,
184 | is_train=True
185 | ) # return an iter of num_samples length (all indices of samples)
186 | sampler_x = RandomSampler(ds_x, replacement=True, num_samples=n_iters_per_epoch * batch_size)
187 | batch_sampler_x = BatchSampler(sampler_x, batch_size, drop_last=True) # yield a batch of samples one time
188 | dl_x = torch.utils.data.DataLoader(
189 | ds_x,
190 | batch_sampler=batch_sampler_x,
191 | num_workers=2,
192 | pin_memory=True
193 | )
194 | ds_u = Cifar(
195 | dataset=dataset,
196 | data=data_u,
197 | labels=label_u,
198 | is_train=True
199 | )
200 | sampler_u = RandomSampler(ds_u, replacement=True, num_samples=mu * n_iters_per_epoch * batch_size)
201 | batch_sampler_u = BatchSampler(sampler_u, batch_size * mu, drop_last=True)
202 | dl_u = torch.utils.data.DataLoader(
203 | ds_u,
204 | batch_sampler=batch_sampler_u,
205 | num_workers=2,
206 | pin_memory=True
207 | )
208 | return dl_x, dl_u
209 |
210 |
211 | def get_val_loader(dataset, batch_size, num_workers, pin_memory=True):
212 | data, labels = load_data_val(dataset)
213 | ds = Cifar(
214 | dataset=dataset,
215 | data=data,
216 | labels=labels,
217 | is_train=False
218 | )
219 | dl = torch.utils.data.DataLoader(
220 | ds,
221 | shuffle=False,
222 | batch_size=batch_size,
223 | drop_last=False,
224 | num_workers=num_workers,
225 | pin_memory=pin_memory
226 | )
227 | return dl
228 |
229 |
230 | class OneHot(object):
231 | def __init__(
232 | self,
233 | n_labels,
234 | lb_ignore=255,
235 | ):
236 | super(OneHot, self).__init__()
237 | self.n_labels = n_labels
238 | self.lb_ignore = lb_ignore
239 |
240 | def __call__(self, label):
241 | N, *S = label.size()
242 | size = [N, self.n_labels] + S
243 | lb_one_hot = torch.zeros(size)
244 | if label.is_cuda:
245 | lb_one_hot = lb_one_hot.cuda()
246 | ignore = label.data.cpu() == self.lb_ignore
247 | label[ignore] = 0
248 | lb_one_hot.scatter_(1, label.unsqueeze(1), 1)
249 | ignore = ignore.nonzero()
250 | _, M = ignore.size()
251 | a, *b = ignore.chunk(M, dim=1)
252 | lb_one_hot[[a, torch.arange(self.n_labels), *b]] = 0
253 |
254 | return lb_one_hot
255 |
256 |
257 | if __name__ == "__main__":
258 | compute_mean_var()
259 | # dl_x, dl_u = get_train_loader(64, 250, 2, 2)
260 | # dl_x2 = iter(dl_x)
261 | # dl_u2 = iter(dl_u)
262 | # ims, lb = next(dl_u2)
263 | # print(type(ims))
264 | # print(len(ims))
265 | # print(ims[0].size())
266 | # print(len(dl_u2))
267 | # for i in range(1024):
268 | # try:
269 | # ims_x, lbs_x = next(dl_x2)
270 | # # ims_u, lbs_u = next(dl_u2)
271 | # print(i, ": ", ims_x[0].size())
272 | # except StopIteration:
273 | # dl_x2 = iter(dl_x)
274 | # dl_u2 = iter(dl_u)
275 | # ims_x, lbs_x = next(dl_x2)
276 | # # ims_u, lbs_u = next(dl_u2)
277 | # print('recreate iterator')
278 | # print(i, ": ", ims_x[0].size())
279 |
--------------------------------------------------------------------------------
/datasets/randaugment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | ## aug functions
6 | def identity_func(img):
7 | return img
8 |
9 |
10 | def autocontrast_func(img, cutoff=0):
11 | '''
12 | same output as PIL.ImageOps.autocontrast
13 | '''
14 | n_bins = 256
15 |
16 | def tune_channel(ch):
17 | n = ch.size
18 | cut = cutoff * n // 100
19 | if cut == 0:
20 | high, low = ch.max(), ch.min()
21 | else:
22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23 | low = np.argwhere(np.cumsum(hist) > cut)
24 | low = 0 if low.shape[0] == 0 else low[0]
25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27 | if high <= low:
28 | table = np.arange(n_bins)
29 | else:
30 | scale = (n_bins - 1) / (high - low)
31 | offset = -low * scale
32 | table = np.arange(n_bins) * scale + offset
33 | table[table < 0] = 0
34 | table[table > n_bins - 1] = n_bins - 1
35 | table = table.clip(0, 255).astype(np.uint8)
36 | return table[ch]
37 |
38 | channels = [tune_channel(ch) for ch in cv2.split(img)]
39 | out = cv2.merge(channels)
40 | return out
41 |
42 |
43 | def equalize_func(img):
44 | '''
45 | same output as PIL.ImageOps.equalize
46 | PIL's implementation is different from cv2.equalize
47 | '''
48 | n_bins = 256
49 |
50 | def tune_channel(ch):
51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52 | non_zero_hist = hist[hist != 0].reshape(-1)
53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54 | if step == 0: return ch
55 | n = np.empty_like(hist)
56 | n[0] = step // 2
57 | n[1:] = hist[:-1]
58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
59 | return table[ch]
60 |
61 | channels = [tune_channel(ch) for ch in cv2.split(img)]
62 | out = cv2.merge(channels)
63 | return out
64 |
65 |
66 | def rotate_func(img, degree, fill=(0, 0, 0)):
67 | '''
68 | like PIL, rotate by degree, not radians
69 | '''
70 | H, W = img.shape[0], img.shape[1]
71 | center = W / 2, H / 2
72 | M = cv2.getRotationMatrix2D(center, degree, 1)
73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
74 | return out
75 |
76 |
77 | def solarize_func(img, thresh=128):
78 | '''
79 | same output as PIL.ImageOps.posterize
80 | '''
81 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
82 | table = table.clip(0, 255).astype(np.uint8)
83 | out = table[img]
84 | return out
85 |
86 |
87 | def color_func(img, factor):
88 | '''
89 | same output as PIL.ImageEnhance.Color
90 | '''
91 | ## implementation according to PIL definition, quite slow
92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
93 | # out = blend(degenerate, img, factor)
94 | # M = (
95 | # np.eye(3) * factor
96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
97 | # )[np.newaxis, np.newaxis, :]
98 | M = (
99 | np.float32([
100 | [0.886, -0.114, -0.114],
101 | [-0.587, 0.413, -0.587],
102 | [-0.299, -0.299, 0.701]]) * factor
103 | + np.float32([[0.114], [0.587], [0.299]])
104 | )
105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
106 | return out
107 |
108 |
109 | def contrast_func(img, factor):
110 | """
111 | same output as PIL.ImageEnhance.Contrast
112 | """
113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
114 | table = np.array([(
115 | el - mean) * factor + mean
116 | for el in range(256)
117 | ]).clip(0, 255).astype(np.uint8)
118 | out = table[img]
119 | return out
120 |
121 |
122 | def brightness_func(img, factor):
123 | '''
124 | same output as PIL.ImageEnhance.Contrast
125 | '''
126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
127 | out = table[img]
128 | return out
129 |
130 |
131 | def sharpness_func(img, factor):
132 | '''
133 | The differences the this result and PIL are all on the 4 boundaries, the center
134 | areas are same
135 | '''
136 | kernel = np.ones((3, 3), dtype=np.float32)
137 | kernel[1][1] = 5
138 | kernel /= 13
139 | degenerate = cv2.filter2D(img, -1, kernel)
140 | if factor == 0.0:
141 | out = degenerate
142 | elif factor == 1.0:
143 | out = img
144 | else:
145 | out = img.astype(np.float32)
146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
148 | out = out.astype(np.uint8)
149 | return out
150 |
151 |
152 | def shear_x_func(img, factor, fill=(0, 0, 0)):
153 | H, W = img.shape[0], img.shape[1]
154 | M = np.float32([[1, factor, 0], [0, 1, 0]])
155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
156 | return out
157 |
158 |
159 | def translate_x_func(img, offset, fill=(0, 0, 0)):
160 | '''
161 | same output as PIL.Image.transform
162 | '''
163 | H, W = img.shape[0], img.shape[1]
164 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
166 | return out
167 |
168 |
169 | def translate_y_func(img, offset, fill=(0, 0, 0)):
170 | '''
171 | same output as PIL.Image.transform
172 | '''
173 | H, W = img.shape[0], img.shape[1]
174 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
176 | return out
177 |
178 |
179 | def posterize_func(img, bits):
180 | '''
181 | same output as PIL.ImageOps.posterize
182 | '''
183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
184 | return out
185 |
186 |
187 | def shear_y_func(img, factor, fill=(0, 0, 0)):
188 | H, W = img.shape[0], img.shape[1]
189 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
191 | return out
192 |
193 |
194 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
195 | replace = np.array(replace, dtype=np.uint8)
196 | H, W = img.shape[0], img.shape[1]
197 | rh, rw = np.random.random(2)
198 | pad_size = pad_size // 2
199 | ch, cw = int(rh * H), int(rw * W)
200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
202 | out = img.copy()
203 | out[x1:x2, y1:y2, :] = replace
204 | return out
205 |
206 |
207 | ### level to args
208 | def enhance_level_to_args(MAX_LEVEL):
209 | def level_to_args(level):
210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
211 | return level_to_args
212 |
213 |
214 | def shear_level_to_args(MAX_LEVEL, replace_value):
215 | def level_to_args(level):
216 | level = (level / MAX_LEVEL) * 0.3
217 | if np.random.random() > 0.5: level = -level
218 | return (level, replace_value)
219 |
220 | return level_to_args
221 |
222 |
223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
224 | def level_to_args(level):
225 | level = (level / MAX_LEVEL) * float(translate_const)
226 | if np.random.random() > 0.5: level = -level
227 | return (level, replace_value)
228 |
229 | return level_to_args
230 |
231 |
232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
233 | def level_to_args(level):
234 | level = int((level / MAX_LEVEL) * cutout_const)
235 | return (level, replace_value)
236 |
237 | return level_to_args
238 |
239 |
240 | def solarize_level_to_args(MAX_LEVEL):
241 | def level_to_args(level):
242 | level = int((level / MAX_LEVEL) * 256)
243 | return (level, )
244 | return level_to_args
245 |
246 |
247 | def none_level_to_args(level):
248 | return ()
249 |
250 |
251 | def posterize_level_to_args(MAX_LEVEL):
252 | def level_to_args(level):
253 | level = int((level / MAX_LEVEL) * 4)
254 | return (level, )
255 | return level_to_args
256 |
257 |
258 | def rotate_level_to_args(MAX_LEVEL, replace_value):
259 | def level_to_args(level):
260 | level = (level / MAX_LEVEL) * 30
261 | if np.random.random() < 0.5:
262 | level = -level
263 | return (level, replace_value)
264 |
265 | return level_to_args
266 |
267 |
268 | func_dict = {
269 | 'Identity': identity_func,
270 | 'AutoContrast': autocontrast_func,
271 | 'Equalize': equalize_func,
272 | 'Rotate': rotate_func,
273 | 'Solarize': solarize_func,
274 | 'Color': color_func,
275 | 'Contrast': contrast_func,
276 | 'Brightness': brightness_func,
277 | 'Sharpness': sharpness_func,
278 | 'ShearX': shear_x_func,
279 | 'TranslateX': translate_x_func,
280 | 'TranslateY': translate_y_func,
281 | 'Posterize': posterize_func,
282 | 'ShearY': shear_y_func,
283 | }
284 |
285 | translate_const = 10
286 | MAX_LEVEL = 10
287 | replace_value = (128, 128, 128)
288 | arg_dict = {
289 | 'Identity': none_level_to_args,
290 | 'AutoContrast': none_level_to_args,
291 | 'Equalize': none_level_to_args,
292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
293 | 'Solarize': solarize_level_to_args(MAX_LEVEL),
294 | 'Color': enhance_level_to_args(MAX_LEVEL),
295 | 'Contrast': enhance_level_to_args(MAX_LEVEL),
296 | 'Brightness': enhance_level_to_args(MAX_LEVEL),
297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL),
298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
299 | 'TranslateX': translate_level_to_args(
300 | translate_const, MAX_LEVEL, replace_value
301 | ),
302 | 'TranslateY': translate_level_to_args(
303 | translate_const, MAX_LEVEL, replace_value
304 | ),
305 | 'Posterize': posterize_level_to_args(MAX_LEVEL),
306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
307 | }
308 |
309 |
310 | class RandomAugment(object):
311 |
312 | def __init__(self, N=2, M=10):
313 | self.N = N
314 | self.M = M
315 |
316 | def get_random_ops(self):
317 | sampled_ops = np.random.choice(list(func_dict.keys()), self.N)
318 | return [(op, 0.5, self.M) for op in sampled_ops]
319 |
320 | def __call__(self, img):
321 | ops = self.get_random_ops()
322 | for name, prob, level in ops:
323 | if np.random.random() > prob:
324 | continue
325 | args = arg_dict[name](level)
326 | img = func_dict[name](img, *args)
327 | img = cutout_func(img, 16, replace_value)
328 | return img
329 |
330 |
331 | if __name__ == '__main__':
332 | a = RandomAugment()
333 | img = np.random.randn(32, 32, 3)
334 | a(img)
--------------------------------------------------------------------------------
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch._six import int_classes as _int_classes
3 |
4 |
5 | class Sampler(object):
6 | r"""Base class for all Samplers.
7 |
8 | Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
9 | way to iterate over indices of dataset elements, and a :meth:`__len__` method
10 | that returns the length of the returned iterators.
11 |
12 | .. note:: The :meth:`__len__` method isn't strictly required by
13 | :class:`~torch.utils.data.DataLoader`, but is expected in any
14 | calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
15 | """
16 |
17 | def __init__(self, data_source):
18 | pass
19 |
20 | def __iter__(self):
21 | raise NotImplementedError
22 |
23 | # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
24 | #
25 | # Many times we have an abstract class representing a collection/iterable of
26 | # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
27 | # implementing a `__len__` method. In such cases, we must make sure to not
28 | # provide a default implementation, because both straightforward default
29 | # implementations have their issues:
30 | #
31 | # + `return NotImplemented`:
32 | # Calling `len(subclass_instance)` raises:
33 | # TypeError: 'NotImplementedType' object cannot be interpreted as an integer
34 | #
35 | # + `raise NotImplementedError()`:
36 | # This prevents triggering some fallback behavior. E.g., the built-in
37 | # `list(X)` tries to call `len(X)` first, and executes a different code
38 | # path if the method is not found or `NotImplemented` is returned, while
39 | # raising an `NotImplementedError` will propagate and and make the call
40 | # fail where it could have use `__iter__` to complete the call.
41 | #
42 | # Thus, the only two sensible things to do are
43 | #
44 | # + **not** provide a default `__len__`.
45 | #
46 | # + raise a `TypeError` instead, which is what Python uses when users call
47 | # a method that is not defined on an object.
48 | # (@ssnl verifies that this works on at least Python 3.7.)
49 |
50 |
51 | class SequentialSampler(Sampler):
52 | r"""Samples elements sequentially, always in the same order.
53 |
54 | Arguments:
55 | data_source (Dataset): dataset to sample from
56 | """
57 |
58 | def __init__(self, data_source):
59 | self.data_source = data_source
60 |
61 | def __iter__(self):
62 | return iter(range(len(self.data_source)))
63 |
64 | def __len__(self):
65 | return len(self.data_source)
66 |
67 |
68 | class RandomSampler(Sampler):
69 | r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
70 | If with replacement, then user can specify :attr:`num_samples` to draw.
71 |
72 | Arguments:
73 | data_source (Dataset): dataset to sample from
74 | replacement (bool): samples are drawn with replacement if ``True``, default=``False``
75 | num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
76 | is supposed to be specified only when `replacement` is ``True``.
77 | """
78 |
79 | def __init__(self, data_source, replacement=False, num_samples=None):
80 | self.data_source = data_source
81 | self.replacement = replacement
82 | self._num_samples = num_samples
83 |
84 | if not isinstance(self.replacement, bool):
85 | raise ValueError("replacement should be a boolean value, but got "
86 | "replacement={}".format(self.replacement))
87 |
88 | if self._num_samples is not None and not replacement:
89 | raise ValueError("With replacement=False, num_samples should not be specified, "
90 | "since a random permute will be performed.")
91 |
92 | if not isinstance(self.num_samples, int) or self.num_samples <= 0:
93 | raise ValueError("num_samples should be a positive integer "
94 | "value, but got num_samples={}".format(self.num_samples))
95 |
96 | @property
97 | def num_samples(self):
98 | # dataset size might change at runtime
99 | if self._num_samples is None:
100 | return len(self.data_source)
101 | return self._num_samples
102 |
103 | def __iter__(self):
104 | n = len(self.data_source)
105 | if self.replacement:
106 | n_repeats = self.num_samples // n
107 | n_remain = self.num_samples % n
108 | indices = [torch.randperm(n) for _ in range(n_repeats)]
109 | indices.append(torch.randperm(n)[:n_remain])
110 | return iter(torch.cat(indices, dim=0).tolist())
111 | return iter(torch.randperm(n).tolist())
112 |
113 | def __len__(self):
114 | return self.num_samples
115 |
116 |
117 | class SubsetRandomSampler(Sampler):
118 | r"""Samples elements randomly from a given list of indices, without replacement.
119 |
120 | Arguments:
121 | indices (sequence): a sequence of indices
122 | """
123 |
124 | def __init__(self, indices):
125 | self.indices = indices
126 |
127 | def __iter__(self):
128 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
129 |
130 | def __len__(self):
131 | return len(self.indices)
132 |
133 |
134 | class WeightedRandomSampler(Sampler):
135 | r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
136 |
137 | Args:
138 | weights (sequence) : a sequence of weights, not necessary summing up to one
139 | num_samples (int): number of samples to draw
140 | replacement (bool): if ``True``, samples are drawn with replacement.
141 | If not, they are drawn without replacement, which means that when a
142 | sample index is drawn for a row, it cannot be drawn again for that row.
143 |
144 | Example:
145 | >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
146 | [0, 0, 0, 1, 0]
147 | >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
148 | [0, 1, 4, 3, 2]
149 | """
150 |
151 | def __init__(self, weights, num_samples, replacement=True):
152 | if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
153 | num_samples <= 0:
154 | raise ValueError("num_samples should be a positive integer "
155 | "value, but got num_samples={}".format(num_samples))
156 | if not isinstance(replacement, bool):
157 | raise ValueError("replacement should be a boolean value, but got "
158 | "replacement={}".format(replacement))
159 | self.weights = torch.as_tensor(weights, dtype=torch.double)
160 | self.num_samples = num_samples
161 | self.replacement = replacement
162 |
163 | def __iter__(self):
164 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
165 |
166 | def __len__(self):
167 | return self.num_samples
168 |
169 |
170 | class BatchSampler(Sampler):
171 | r"""Wraps another sampler to yield a mini-batch of indices.
172 |
173 | Args:
174 | sampler (Sampler): Base sampler.
175 | batch_size (int): Size of mini-batch.
176 | drop_last (bool): If ``True``, the sampler will drop the last batch if
177 | its size would be less than ``batch_size``
178 |
179 | Example:
180 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
181 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
182 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
183 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
184 | """
185 |
186 | def __init__(self, sampler, batch_size, drop_last):
187 | if not isinstance(sampler, Sampler):
188 | raise ValueError("sampler should be an instance of "
189 | "torch.utils.data.Sampler, but got sampler={}"
190 | .format(sampler))
191 | if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
192 | batch_size <= 0:
193 | raise ValueError("batch_size should be a positive integer value, "
194 | "but got batch_size={}".format(batch_size))
195 | if not isinstance(drop_last, bool):
196 | raise ValueError("drop_last should be a boolean value, but got "
197 | "drop_last={}".format(drop_last))
198 | self.sampler = sampler
199 | self.batch_size = batch_size
200 | self.drop_last = drop_last
201 |
202 | def __iter__(self):
203 | batch = []
204 | for idx in self.sampler:
205 | batch.append(idx)
206 | if len(batch) == self.batch_size:
207 | yield batch
208 | batch = []
209 | if len(batch) > 0 and not self.drop_last:
210 | yield batch
211 |
212 | def __len__(self):
213 | if self.drop_last:
214 | return len(self.sampler) // self.batch_size
215 | else:
216 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
217 |
--------------------------------------------------------------------------------
/datasets/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import cv2
4 |
5 |
6 | class PadandRandomCrop(object):
7 | '''
8 | Input tensor is expected to have shape of (H, W, 3)
9 | '''
10 | def __init__(self, border=4, cropsize=(32, 32)):
11 | self.border = border
12 | self.cropsize = cropsize
13 |
14 | def __call__(self, im):
15 | borders = [(self.border, self.border), (self.border, self.border), (0, 0)] # input is (h, w, c)
16 | convas = np.pad(im, borders, mode='reflect')
17 | H, W, C = convas.shape
18 | h, w = self.cropsize
19 | dh, dw = max(0, H-h), max(0, W-w)
20 | sh, sw = np.random.randint(0, dh), np.random.randint(0, dw)
21 | out = convas[sh:sh+h, sw:sw+w, :]
22 | return out
23 |
24 |
25 | class RandomHorizontalFlip(object):
26 | def __init__(self, p=0.5):
27 | self.p = p
28 |
29 | def __call__(self, im):
30 | if np.random.rand() < self.p:
31 | im = im[:, ::-1, :]
32 | return im
33 |
34 |
35 | class Resize(object):
36 | def __init__(self, size):
37 | self.size = size
38 |
39 | def __call__(self, im):
40 | im = cv2.resize(im, self.size)
41 | return im
42 |
43 |
44 | class Normalize(object):
45 | '''
46 | Inputs are pixel values in range of [0, 255], channel order is 'rgb'
47 | '''
48 | def __init__(self, mean, std):
49 | self.mean = np.array(mean, np.float32).reshape(1, 1, -1)
50 | self.std = np.array(std, np.float32).reshape(1, 1, -1)
51 |
52 | def __call__(self, im):
53 | if len(im.shape) == 4:
54 | mean, std = self.mean[None, ...], self.std[None, ...]
55 | elif len(im.shape) == 3:
56 | mean, std = self.mean, self.std
57 | im = im.astype(np.float32) / 255.
58 | # im = (im.astype(np.float32) / 127.5) - 1
59 | im -= mean
60 | im /= std
61 | return im
62 |
63 |
64 | class ToTensor(object):
65 | def __init__(self):
66 | pass
67 |
68 | def __call__(self, im):
69 | if len(im.shape) == 4:
70 | return torch.from_numpy(im.transpose(0, 3, 1, 2))
71 | elif len(im.shape) == 3:
72 | return torch.from_numpy(im.transpose(2, 0, 1))
73 |
74 |
75 | class Compose(object):
76 | def __init__(self, ops):
77 | self.ops = ops
78 |
79 | def __call__(self, im):
80 | for op in self.ops:
81 | im = op(im)
82 | return im
83 |
--------------------------------------------------------------------------------
/label_guessor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class LabelGuessor(object):
5 |
6 | def __init__(self, thresh):
7 | self.thresh = thresh
8 |
9 | def __call__(self, model, ims):
10 | org_state = {
11 | k: v.clone().detach()
12 | for k, v in model.state_dict().items()
13 | }
14 | is_train = model.training
15 | with torch.no_grad():
16 | model.train()
17 | logits = model(ims)
18 | probs = torch.softmax(logits, dim=1)
19 | scores, lbs = torch.max(probs, dim=1)
20 | mask = scores.ge(self.thresh).float()
21 |
22 | # note it is necessary to keep org_state! especially for bn layer
23 | # for k, v in org_state.items():
24 | # if not all((model.state_dict()[k] == v).reshape(-1)):
25 | # print(f'{k} diff')
26 |
27 | model.load_state_dict(org_state)
28 | if is_train:
29 | model.train()
30 | else:
31 | model.eval()
32 | return mask, lbs
33 |
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | import math
4 |
5 | import torch
6 | from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
7 | import numpy as np
8 |
9 |
10 | class WarmupExpLrScheduler(_LRScheduler):
11 | def __init__(
12 | self,
13 | optimizer,
14 | power,
15 | step_interval=1,
16 | warmup_iter=500,
17 | warmup_ratio=5e-4,
18 | warmup='exp',
19 | last_epoch=-1,
20 | ):
21 | self.power = power
22 | self.step_interval = step_interval
23 | self.warmup_iter = warmup_iter
24 | self.warmup_ratio = warmup_ratio
25 | self.warmup = warmup
26 | super(WarmupExpLrScheduler, self).__init__(optimizer, last_epoch)
27 |
28 | def get_lr(self):
29 | ratio = self.get_lr_ratio()
30 | lrs = [ratio * lr for lr in self.base_lrs]
31 | return lrs
32 |
33 | def get_lr_ratio(self):
34 | if self.last_epoch < self.warmup_iter:
35 | ratio = self.get_warmup_ratio()
36 | else:
37 | real_iter = self.last_epoch - self.warmup_iter
38 | ratio = self.power ** (real_iter // self.step_interval)
39 | return ratio
40 |
41 | def get_warmup_ratio(self):
42 | assert self.warmup in ('linear', 'exp')
43 | alpha = self.last_epoch / self.warmup_iter
44 | if self.warmup == 'linear':
45 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
46 | elif self.warmup == 'exp':
47 | ratio = self.warmup_ratio ** (1. - alpha)
48 | return ratio
49 |
50 |
51 | class WarmupPolyLrScheduler(_LRScheduler):
52 | def __init__(
53 | self,
54 | optimizer,
55 | power,
56 | max_iter,
57 | warmup_iter,
58 | warmup_ratio=5e-4,
59 | warmup='exp',
60 | last_epoch=-1,
61 | ):
62 | self.power = power
63 | self.max_iter = max_iter
64 | self.warmup_iter = warmup_iter
65 | self.warmup_ratio = warmup_ratio
66 | self.warmup = warmup
67 | super(WarmupPolyLrScheduler, self).__init__(optimizer, last_epoch)
68 |
69 | def get_lr(self):
70 | ratio = self.get_lr_ratio()
71 | lrs = [ratio * lr for lr in self.base_lrs]
72 | return lrs
73 |
74 | def get_lr_ratio(self):
75 | if self.last_epoch < self.warmup_iter:
76 | ratio = self.get_warmup_ratio()
77 | else:
78 | real_iter = self.last_epoch - self.warmup_iter
79 | real_max_iter = self.max_iter - self.warmup_iter
80 | alpha = real_iter / real_max_iter
81 | ratio = (1 - alpha) ** self.power
82 | return ratio
83 |
84 | def get_warmup_ratio(self):
85 | assert self.warmup in ('linear', 'exp')
86 | alpha = self.last_epoch / self.warmup_iter
87 | if self.warmup == 'linear':
88 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
89 | elif self.warmup == 'exp':
90 | ratio = self.warmup_ratio ** (1. - alpha)
91 | return ratio
92 |
93 |
94 | class WarmupCosineLrScheduler(_LRScheduler):
95 | '''
96 | This is different from official definition, this is implemented according to
97 | the paper of fix-match
98 | '''
99 | def __init__(
100 | self,
101 | optimizer,
102 | max_iter,
103 | warmup_iter,
104 | warmup_ratio=5e-4,
105 | warmup='exp',
106 | last_epoch=-1,
107 | ):
108 | self.max_iter = max_iter
109 | self.warmup_iter = warmup_iter
110 | self.warmup_ratio = warmup_ratio
111 | self.warmup = warmup
112 | super(WarmupCosineLrScheduler, self).__init__(optimizer, last_epoch)
113 |
114 | def get_lr(self):
115 | ratio = self.get_lr_ratio()
116 | lrs = [ratio * lr for lr in self.base_lrs]
117 | return lrs
118 |
119 | def get_lr_ratio(self):
120 | if self.last_epoch < self.warmup_iter:
121 | ratio = self.get_warmup_ratio()
122 | else:
123 | real_iter = self.last_epoch - self.warmup_iter
124 | real_max_iter = self.max_iter - self.warmup_iter
125 | ratio = np.cos((7 * np.pi * real_iter) / (16 * real_max_iter))
126 | return ratio
127 |
128 | def get_warmup_ratio(self):
129 | assert self.warmup in ('linear', 'exp')
130 | alpha = self.last_epoch / self.warmup_iter
131 | if self.warmup == 'linear':
132 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
133 | elif self.warmup == 'exp':
134 | ratio = self.warmup_ratio ** (1. - alpha)
135 | return ratio
136 |
137 |
138 | # from Fixmatch-pytorch
139 | def get_cosine_schedule_with_warmup(optimizer,
140 | num_warmup_steps,
141 | num_training_steps,
142 | num_cycles=7./16.,
143 | last_epoch=-1):
144 | def _lr_lambda(current_step):
145 | if current_step < num_warmup_steps:
146 | return float(current_step) / float(max(1, num_warmup_steps))
147 | no_progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
148 | # return max(0., math.cos(math.pi * num_cycles * no_progress))
149 |
150 | return max(0., (math.cos(math.pi * num_cycles * no_progress) + 1) * 0.5)
151 |
152 | return LambdaLR(optimizer, _lr_lambda, last_epoch)
153 |
154 | if __name__ == "__main__":
155 | model = torch.nn.Conv2d(3, 16, 3, 1, 1)
156 | optim = torch.optim.SGD(model.parameters(), lr=1e-3)
157 |
158 | max_iter = 20000
159 | # lr_scheduler = WarmupCosineLrScheduler(optim, max_iter, 0)
160 | lr_scheduler = get_cosine_schedule_with_warmup(
161 | optim, 0, max_iter)
162 |
163 | lrs = []
164 | for _ in range(max_iter):
165 | lr = lr_scheduler.get_lr()[0]
166 | print(lr)
167 | lrs.append(lr)
168 | lr_scheduler.step()
169 | import matplotlib
170 | import matplotlib.pyplot as plt
171 | import numpy as np
172 | lrs = np.array(lrs)
173 | n_lrs = len(lrs)
174 | plt.plot(np.arange(n_lrs), lrs)
175 | plt.title('3')
176 | plt.grid()
177 | plt.show()
178 |
179 |
180 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/valencebond/FixMatch_pytorch/3b7fc36ca0c71754a59c9c78465d681f259ee174/models/__init__.py
--------------------------------------------------------------------------------
/models/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 |
4 | """
5 | 为什么 param 和 buffer 要采用不同的的更新策略
6 | param 是 指数移动平均数,buffer 不是
7 | """
8 |
9 |
10 | class EMA(object):
11 | def __init__(self, model, alpha=0.999):
12 | self.step = 0
13 | self.model = model
14 | self.alpha = alpha
15 | self.shadow = self.get_model_state()
16 | self.backup = {}
17 | self.param_keys = [k for k, _ in self.model.named_parameters()]
18 | # num_batches_tracked, running_mean, running_var in bn
19 | self.buffer_keys = [k for k, _ in self.model.named_buffers()]
20 |
21 | def update_params(self):
22 | # decay = min(self.alpha, (self.step + 1) / (self.step + 10)) # ????
23 | decay = self.alpha
24 | state = self.model.state_dict() # current params
25 | for name in self.param_keys:
26 | self.shadow[name].copy_(
27 | decay * self.shadow[name] + (1 - decay) * state[name]
28 | )
29 | # for name in self.buffer_keys:
30 | # self.shadow[name].copy_(
31 | # decay * self.shadow[name]
32 | # + (1 - decay) * state[name]
33 | # )
34 |
35 | self.step += 1
36 |
37 | def update_buffer(self):
38 | # without EMA
39 | state = self.model.state_dict()
40 | for name in self.buffer_keys:
41 | self.shadow[name].copy_(state[name])
42 |
43 | def apply_shadow(self):
44 | self.backup = self.get_model_state()
45 | self.model.load_state_dict(self.shadow)
46 |
47 | def restore(self):
48 | self.model.load_state_dict(self.backup)
49 |
50 | def get_model_state(self):
51 | return {
52 | k: v.clone().detach()
53 | for k, v in self.model.state_dict().items()
54 | }
55 |
56 |
57 | if __name__ == '__main__':
58 | print('=====')
59 | model = torch.nn.BatchNorm1d(5)
60 | ema = EMA(model, 0.9, 0.02, 0.002)
61 | inten = torch.randn(10, 5)
62 | out = model(inten)
63 | ema.update_params()
64 | print(model.state_dict())
65 | ema.update_buffer()
66 | print(model.state_dict())
67 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from torch.nn import BatchNorm2d
9 |
10 | '''
11 | As in the paper, the wide resnet only considers the resnet of the pre-activated version,
12 | and it only considers the basic blocks rather than the bottleneck blocks.
13 | '''
14 |
15 |
16 | class BasicBlockPreAct(nn.Module):
17 | def __init__(self, in_chan, out_chan, drop_rate=0, stride=1, pre_res_act=False):
18 | super(BasicBlockPreAct, self).__init__()
19 | self.bn1 = BatchNorm2d(in_chan, momentum=0.001)
20 | self.relu1 = nn.LeakyReLU(inplace=True, negative_slope=0.1)
21 | self.conv1 = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=stride, padding=1, bias=False)
22 | self.bn2 = BatchNorm2d(out_chan, momentum=0.001)
23 | self.relu2 = nn.LeakyReLU(inplace=True, negative_slope=0.1)
24 | self.dropout = nn.Dropout(drop_rate) if not drop_rate == 0 else None
25 | self.conv2 = nn.Conv2d(out_chan, out_chan, kernel_size=3, stride=1, padding=1, bias=False)
26 | self.downsample = None
27 | if in_chan != out_chan or stride != 1:
28 | self.downsample = nn.Conv2d(
29 | in_chan, out_chan, kernel_size=1, stride=stride, bias=False
30 | )
31 | self.pre_res_act = pre_res_act
32 | # self.init_weight()
33 |
34 | def forward(self, x):
35 | bn1 = self.bn1(x)
36 | act1 = self.relu1(bn1)
37 | residual = self.conv1(act1)
38 | residual = self.bn2(residual)
39 | residual = self.relu2(residual)
40 | if self.dropout is not None:
41 | residual = self.dropout(residual)
42 | residual = self.conv2(residual)
43 |
44 | shortcut = act1 if self.pre_res_act else x
45 | if self.downsample is not None:
46 | shortcut = self.downsample(shortcut)
47 |
48 | out = shortcut + residual
49 | return out
50 |
51 | def init_weight(self):
52 | # for _, md in self.named_modules():
53 | # if isinstance(md, nn.Conv2d):
54 | # nn.init.kaiming_normal_(
55 | # md.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
56 | # if md.bias is not None:
57 | # nn.init.constant_(md.bias, 0)
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
61 | if m.bias is not None:
62 | nn.init.constant_(m.bias, 0)
63 |
64 |
65 | class WideResnetBackbone(nn.Module):
66 | def __init__(self, k=1, n=28, drop_rate=0):
67 | super(WideResnetBackbone, self).__init__()
68 | self.k, self.n = k, n
69 | assert (self.n - 4) % 6 == 0
70 | n_blocks = (self.n - 4) // 6
71 | n_layers = [16, ] + [self.k * 16 * (2 ** i) for i in range(3)]
72 |
73 | self.conv1 = nn.Conv2d(
74 | 3,
75 | n_layers[0],
76 | kernel_size=3,
77 | stride=1,
78 | padding=1,
79 | bias=False
80 | )
81 | self.layer1 = self.create_layer(
82 | n_layers[0],
83 | n_layers[1],
84 | bnum=n_blocks,
85 | stride=1,
86 | drop_rate=drop_rate,
87 | pre_res_act=True,
88 | )
89 | self.layer2 = self.create_layer(
90 | n_layers[1],
91 | n_layers[2],
92 | bnum=n_blocks,
93 | stride=2,
94 | drop_rate=drop_rate,
95 | pre_res_act=False,
96 | )
97 | self.layer3 = self.create_layer(
98 | n_layers[2],
99 | n_layers[3],
100 | bnum=n_blocks,
101 | stride=2,
102 | drop_rate=drop_rate,
103 | pre_res_act=False,
104 | )
105 | self.bn_last = BatchNorm2d(n_layers[3], momentum=0.001)
106 | self.relu_last = nn.LeakyReLU(inplace=True, negative_slope=0.1)
107 | self.init_weight()
108 |
109 | def create_layer(
110 | self,
111 | in_chan,
112 | out_chan,
113 | bnum,
114 | stride=1,
115 | drop_rate=0,
116 | pre_res_act=False,
117 | ):
118 | layers = [
119 | BasicBlockPreAct(
120 | in_chan,
121 | out_chan,
122 | drop_rate=drop_rate,
123 | stride=stride,
124 | pre_res_act=pre_res_act), ]
125 | for _ in range(bnum - 1):
126 | layers.append(
127 | BasicBlockPreAct(
128 | out_chan,
129 | out_chan,
130 | drop_rate=drop_rate,
131 | stride=1,
132 | pre_res_act=False, ))
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | feat = self.conv1(x)
137 |
138 | feat = self.layer1(feat)
139 | feat2 = self.layer2(feat) # 1/2
140 | feat4 = self.layer3(feat2) # 1/4
141 |
142 | feat4 = self.bn_last(feat4)
143 | feat4 = self.relu_last(feat4)
144 | return feat2, feat4
145 |
146 | def init_weight(self):
147 | # for _, child in self.named_children():
148 | # if isinstance(child, nn.Conv2d):
149 | # n = child.kernel_size[0] * child.kernel_size[0] * child.out_channels
150 | # nn.init.normal_(child.weight, 0, 1. / ((0.5 * n) ** 0.5))
151 | # # nn.init.kaiming_normal_(
152 | # # child.weight, a=0.1, mode='fan_out',
153 | # # nonlinearity='leaky_relu'
154 | # # )
155 | #
156 | # if child.bias is not None:
157 | # nn.init.constant_(child.bias, 0)
158 | for m in self.modules():
159 | if isinstance(m, nn.Conv2d):
160 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
161 | m.weight.data.normal_(0, math.sqrt(2. / n))
162 |
163 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
164 | if m.bias is not None:
165 | nn.init.constant_(m.bias, 0)
166 | elif isinstance(m, nn.BatchNorm2d):
167 | m.weight.data.fill_(1)
168 | m.bias.data.zero_()
169 | elif isinstance(m, nn.Linear):
170 | m.bias.data.zero_()
171 |
172 |
173 | class WideResnet(nn.Module):
174 | '''
175 | for wide-resnet-28-10, the definition should be WideResnet(n_classes, 10, 28)
176 | '''
177 |
178 | def __init__(self, n_classes, k=1, n=28):
179 | super(WideResnet, self).__init__()
180 | self.n_layers, self.k = n, k
181 | self.backbone = WideResnetBackbone(k=k, n=n)
182 | self.classifier = nn.Linear(64 * self.k, n_classes, bias=True)
183 |
184 | def forward(self, x):
185 | feat = self.backbone(x)[-1]
186 | feat = torch.mean(feat, dim=(2, 3))
187 | feat = self.classifier(feat)
188 | return feat
189 |
190 | def init_weight(self):
191 | nn.init.xavier_normal_(self.classifier.weight)
192 | if not self.classifier.bias is None:
193 | nn.init.constant_(self.classifier.bias, 0)
194 |
195 |
196 | if __name__ == "__main__":
197 | x = torch.randn(2, 3, 224, 224)
198 | lb = torch.randint(0, 10, (2,)).long()
199 |
200 | net = WideResnetBackbone()
201 | out = net(x)
202 | print(out[0].size())
203 | del net, out
204 |
205 | net = WideResnet(n_classes=10)
206 | criteria = nn.CrossEntropyLoss()
207 | out = net(x)
208 | loss = criteria(out, lb)
209 | loss.backward()
210 | print(out.size())
211 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv_python==4.1.0.25
2 | torch==1.3.1+cu100
3 | matplotlib==3.1.1
4 | numpy==1.17.2
5 | tensorboardX==2.0
6 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=1
2 | export MKL_NUM_THREADS=1
3 |
4 | python train.py
5 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import random
3 |
4 | import time
5 | import argparse
6 | import os
7 | import sys
8 |
9 | import numpy as np
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | from models.model import WideResnet
16 | from datasets.cifar import get_train_loader, get_val_loader
17 | from label_guessor import LabelGuessor
18 | from lr_scheduler import WarmupCosineLrScheduler
19 | from models.ema import EMA
20 |
21 | from utils import accuracy, setup_default_logging, interleave, de_interleave
22 |
23 | from utils import AverageMeter
24 |
25 |
26 | # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
27 |
28 |
29 | def set_model(args):
30 | model = WideResnet(n_classes=10 if args.dataset == 'CIFAR10' else 100,
31 | k=args.wresnet_k, n=args.wresnet_n) # wresnet-28-2
32 |
33 | model.train()
34 | model.cuda()
35 | criteria_x = nn.CrossEntropyLoss().cuda()
36 | criteria_u = nn.CrossEntropyLoss(reduction='none').cuda()
37 | return model, criteria_x, criteria_u
38 |
39 |
40 | def train_one_epoch(epoch,
41 | model,
42 | criteria_x,
43 | criteria_u,
44 | optim,
45 | lr_schdlr,
46 | ema,
47 | dltrain_x,
48 | dltrain_u,
49 | lb_guessor,
50 | lambda_u,
51 | n_iters,
52 | logger,
53 | ):
54 | model.train()
55 | # loss_meter, loss_x_meter, loss_u_meter, loss_u_real_meter = [], [], [], []
56 | loss_meter = AverageMeter()
57 | loss_x_meter = AverageMeter()
58 | loss_u_meter = AverageMeter()
59 | loss_u_real_meter = AverageMeter()
60 | # the number of correctly-predicted and gradient-considered unlabeled data
61 | n_correct_u_lbs_meter = AverageMeter()
62 | # the number of gradient-considered strong augmentation (logits above threshold) of unlabeled samples
63 | n_strong_aug_meter = AverageMeter()
64 | mask_meter = AverageMeter()
65 |
66 | epoch_start = time.time() # start time
67 | dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
68 | for it in range(n_iters):
69 | ims_x_weak, ims_x_strong, lbs_x = next(dl_x)
70 | ims_u_weak, ims_u_strong, lbs_u_real = next(dl_u)
71 |
72 | lbs_x = lbs_x.cuda()
73 | lbs_u_real = lbs_u_real.cuda()
74 |
75 | # --------------------------------------
76 |
77 | bt = ims_x_weak.size(0)
78 | mu = int(ims_u_weak.size(0) // bt)
79 | imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).cuda()
80 | imgs = interleave(imgs, 2 * mu + 1)
81 | logits = model(imgs)
82 | logits = de_interleave(logits, 2 * mu + 1)
83 |
84 | logits_x = logits[:bt]
85 | logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu)
86 |
87 | loss_x = criteria_x(logits_x, lbs_x)
88 |
89 | with torch.no_grad():
90 | probs = torch.softmax(logits_u_w, dim=1)
91 | scores, lbs_u_guess = torch.max(probs, dim=1)
92 | mask = scores.ge(0.95).float()
93 |
94 | loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean()
95 | loss = loss_x + lambda_u * loss_u
96 | loss_u_real = (F.cross_entropy(logits_u_s, lbs_u_real) * mask).mean()
97 |
98 | # --------------------------------------
99 |
100 |
101 | # mask, lbs_u_guess = lb_guessor(model, ims_u_weak.cuda())
102 | # n_x = ims_x_weak.size(0)
103 | # ims_x_u = torch.cat([ims_x_weak, ims_u_strong]).cuda()
104 | # logits_x_u = model(ims_x_u)
105 | # logits_x, logits_u = logits_x_u[:n_x], logits_x_u[n_x:]
106 | # loss_x = criteria_x(logits_x, lbs_x)
107 | # loss_u = (criteria_u(logits_u, lbs_u_guess) * mask).mean()
108 | # loss = loss_x + lambda_u * loss_u
109 | # loss_u_real = (F.cross_entropy(logits_u, lbs_u_real) * mask).mean()
110 |
111 | optim.zero_grad()
112 | loss.backward()
113 | optim.step()
114 | ema.update_params()
115 | lr_schdlr.step()
116 |
117 | loss_meter.update(loss.item())
118 | loss_x_meter.update(loss_x.item())
119 | loss_u_meter.update(loss_u.item())
120 | loss_u_real_meter.update(loss_u_real.item())
121 | mask_meter.update(mask.mean().item())
122 |
123 | corr_u_lb = (lbs_u_guess == lbs_u_real).float() * mask
124 | n_correct_u_lbs_meter.update(corr_u_lb.sum().item())
125 | n_strong_aug_meter.update(mask.sum().item())
126 |
127 | if (it + 1) % 512 == 0:
128 | t = time.time() - epoch_start
129 |
130 | lr_log = [pg['lr'] for pg in optim.param_groups]
131 | lr_log = sum(lr_log) / len(lr_log)
132 |
133 | logger.info("epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. loss_u_real: {:.4f}. "
134 | "n_correct_u: {:.2f}/{:.2f}. Mask:{:.4f} . LR: {:.4f}. Time: {:.2f}".format(
135 | epoch, it + 1, loss_meter.avg, loss_u_meter.avg, loss_x_meter.avg, loss_u_real_meter.avg,
136 | n_correct_u_lbs_meter.avg, n_strong_aug_meter.avg, mask_meter.avg, lr_log, t))
137 |
138 | epoch_start = time.time()
139 |
140 | ema.update_buffer()
141 | return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg, loss_u_real_meter.avg, mask_meter.avg
142 |
143 |
144 | def evaluate(ema, dataloader, criterion):
145 | # using EMA params to evaluate performance
146 | ema.apply_shadow()
147 | ema.model.eval()
148 | ema.model.cuda()
149 |
150 | loss_meter = AverageMeter()
151 | top1_meter = AverageMeter()
152 | top5_meter = AverageMeter()
153 |
154 | # matches = []
155 | with torch.no_grad():
156 | for ims, lbs in dataloader:
157 | ims = ims.cuda()
158 | lbs = lbs.cuda()
159 | logits = ema.model(ims)
160 | loss = criterion(logits, lbs)
161 | scores = torch.softmax(logits, dim=1)
162 | top1, top5 = accuracy(scores, lbs, (1, 5))
163 | loss_meter.update(loss.item())
164 | top1_meter.update(top1.item())
165 | top5_meter.update(top5.item())
166 |
167 | # note roll back model current params to continue training
168 | ema.restore()
169 | return top1_meter.avg, top5_meter.avg, loss_meter.avg
170 |
171 |
172 | def main():
173 | parser = argparse.ArgumentParser(description=' FixMatch Training')
174 | parser.add_argument('--wresnet-k', default=2, type=int,
175 | help='width factor of wide resnet')
176 | parser.add_argument('--wresnet-n', default=28, type=int,
177 | help='depth of wide resnet')
178 | parser.add_argument('--dataset', type=str, default='CIFAR10',
179 | help='number of classes in dataset')
180 | # parser.add_argument('--n-classes', type=int, default=100,
181 | # help='number of classes in dataset')
182 | parser.add_argument('--n-labeled', type=int, default=40,
183 | help='number of labeled samples for training')
184 | parser.add_argument('--n-epoches', type=int, default=1024,
185 | help='number of training epoches')
186 | parser.add_argument('--batchsize', type=int, default=64,
187 | help='train batch size of labeled samples')
188 | parser.add_argument('--mu', type=int, default=7,
189 | help='factor of train batch size of unlabeled samples')
190 | parser.add_argument('--thr', type=float, default=0.95,
191 | help='pseudo label threshold')
192 | parser.add_argument('--n-imgs-per-epoch', type=int, default=64 * 1024,
193 | help='number of training images for each epoch')
194 | parser.add_argument('--lam-u', type=float, default=1.,
195 | help='coefficient of unlabeled loss')
196 | parser.add_argument('--ema-alpha', type=float, default=0.999,
197 | help='decay rate for ema module')
198 | parser.add_argument('--lr', type=float, default=0.03,
199 | help='learning rate for training')
200 | parser.add_argument('--weight-decay', type=float, default=5e-4,
201 | help='weight decay')
202 | parser.add_argument('--momentum', type=float, default=0.9,
203 | help='momentum for optimizer')
204 | parser.add_argument('--seed', type=int, default=-1,
205 | help='seed for random behaviors, no seed if negtive')
206 |
207 | args = parser.parse_args()
208 |
209 | logger, writer = setup_default_logging(args)
210 | logger.info(dict(args._get_kwargs()))
211 |
212 | # global settings
213 | # torch.multiprocessing.set_sharing_strategy('file_system')
214 | if args.seed > 0:
215 | torch.manual_seed(args.seed)
216 | random.seed(args.seed)
217 | np.random.seed(args.seed)
218 | # torch.backends.cudnn.deterministic = True
219 |
220 | n_iters_per_epoch = args.n_imgs_per_epoch // args.batchsize # 1024
221 | n_iters_all = n_iters_per_epoch * args.n_epoches # 1024 * 1024
222 |
223 | logger.info("***** Running training *****")
224 | logger.info(f" Task = {args.dataset}@{args.n_labeled}")
225 | logger.info(f" Num Epochs = {n_iters_per_epoch}")
226 | logger.info(f" Batch size per GPU = {args.batchsize}")
227 | # logger.info(f" Total train batch size = {args.batch_size * args.world_size}")
228 | logger.info(f" Total optimization steps = {n_iters_all}")
229 |
230 | model, criteria_x, criteria_u = set_model(args)
231 | logger.info("Total params: {:.2f}M".format(
232 | sum(p.numel() for p in model.parameters()) / 1e6))
233 |
234 | dltrain_x, dltrain_u = get_train_loader(
235 | args.dataset, args.batchsize, args.mu, n_iters_per_epoch, L=args.n_labeled)
236 | dlval = get_val_loader(dataset=args.dataset, batch_size=64, num_workers=2)
237 |
238 | lb_guessor = LabelGuessor(thresh=args.thr)
239 |
240 | ema = EMA(model, args.ema_alpha)
241 |
242 | wd_params, non_wd_params = [], []
243 | for name, param in model.named_parameters():
244 | # if len(param.size()) == 1:
245 | if 'bn' in name:
246 | non_wd_params.append(param) # bn.weight, bn.bias and classifier.bias
247 | # print(name)
248 | else:
249 | wd_params.append(param)
250 | param_list = [
251 | {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
252 | optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay,
253 | momentum=args.momentum, nesterov=True)
254 | lr_schdlr = WarmupCosineLrScheduler(
255 | optim, max_iter=n_iters_all, warmup_iter=0
256 | )
257 |
258 | train_args = dict(
259 | model=model,
260 | criteria_x=criteria_x,
261 | criteria_u=criteria_u,
262 | optim=optim,
263 | lr_schdlr=lr_schdlr,
264 | ema=ema,
265 | dltrain_x=dltrain_x,
266 | dltrain_u=dltrain_u,
267 | lb_guessor=lb_guessor,
268 | lambda_u=args.lam_u,
269 | n_iters=n_iters_per_epoch,
270 | logger=logger
271 | )
272 | best_acc = -1
273 | best_epoch = 0
274 | logger.info('-----------start training--------------')
275 | for epoch in range(args.n_epoches):
276 | train_loss, loss_x, loss_u, loss_u_real, mask_mean = train_one_epoch(epoch, **train_args)
277 | # torch.cuda.empty_cache()
278 |
279 | top1, top5, valid_loss = evaluate(ema, dlval, criteria_x)
280 |
281 | writer.add_scalars('train/1.loss', {'train': train_loss,
282 | 'test': valid_loss}, epoch)
283 | writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
284 | writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
285 | writer.add_scalar('train/4.train_loss_u_real', loss_u_real, epoch)
286 | writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
287 | writer.add_scalars('test/1.test_acc', {'top1': top1, 'top5': top5}, epoch)
288 | # writer.add_scalar('test/2.test_loss', loss, epoch)
289 |
290 | # best_acc = top1 if best_acc < top1 else best_acc
291 | if best_acc < top1:
292 | best_acc = top1
293 | best_epoch = epoch
294 |
295 | logger.info("Epoch {}. Top1: {:.4f}. Top5: {:.4f}. best_acc: {:.4f} in epoch{}".
296 | format(epoch, top1, top5, best_acc, best_epoch))
297 |
298 | writer.close()
299 |
300 |
301 | if __name__ == '__main__':
302 | main()
303 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import logging
3 | import os
4 | import sys
5 | import torch
6 |
7 | # from torch.utils.tensorboard import SummaryWriter
8 | from tensorboardX import SummaryWriter
9 |
10 |
11 | def interleave(x, bt):
12 | s = list(x.shape)
13 | return torch.reshape(torch.transpose(x.reshape([-1, bt] + s[1:]), 1, 0), [-1] + s[1:])
14 |
15 |
16 | def de_interleave(x, bt):
17 | s = list(x.shape)
18 | return torch.reshape(torch.transpose(x.reshape([bt, -1] + s[1:]), 1, 0), [-1] + s[1:])
19 |
20 |
21 | def setup_default_logging(args, default_level=logging.INFO,
22 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"):
23 | output_dir = os.path.join(args.dataset, f'x{args.n_labeled}')
24 | os.makedirs(output_dir, exist_ok=True)
25 |
26 | writer = SummaryWriter(comment=f'{args.dataset}_{args.n_labeled}')
27 |
28 | logger = logging.getLogger('train')
29 |
30 | logging.basicConfig( # unlike the root logger, a custom logger can’t be configured using basicConfig()
31 | filename=os.path.join(output_dir, f'{time_str()}.log'),
32 | format=format,
33 | datefmt="%m/%d/%Y %H:%M:%S",
34 | level=default_level)
35 |
36 | # print
37 | # file_handler = logging.FileHandler()
38 | console_handler = logging.StreamHandler(sys.stdout)
39 | console_handler.setLevel(default_level)
40 | console_handler.setFormatter(logging.Formatter(format))
41 | logger.addHandler(console_handler)
42 |
43 | return logger, writer
44 |
45 |
46 | def accuracy(output, target, topk=(1,)):
47 | """Computes the precision@k for the specified values of k"""
48 | maxk = max(topk)
49 | batch_size = target.size(0)
50 |
51 | _, pred = output.topk(maxk, 1, largest=True, sorted=True) # return value, indices
52 | pred = pred.t()
53 | correct = pred.eq(target.view(1, -1).expand_as(pred))
54 |
55 | res = []
56 | for k in topk:
57 | correct_k = correct[:k].view(-1).float().sum(0)
58 | res.append(correct_k.mul_(100.0 / batch_size))
59 | return res
60 |
61 |
62 | class AverageMeter(object):
63 | """
64 | Computes and stores the average and current value
65 |
66 | """
67 |
68 | def __init__(self):
69 | self.reset()
70 |
71 | def reset(self):
72 | self.val = 0
73 | self.avg = 0
74 | self.sum = 0
75 | self.count = 0
76 |
77 | def update(self, val, n=1):
78 | self.val = val
79 | self.sum += val * n
80 | self.count += n
81 | # self.avg = self.sum / (self.count + 1e-20)
82 | self.avg = self.sum / self.count
83 |
84 |
85 | def time_str(fmt=None):
86 | if fmt is None:
87 | fmt = '%Y-%m-%d_%H:%M:%S'
88 |
89 | # time.strftime(format[, t])
90 | return datetime.today().strftime(fmt)
91 |
92 |
93 | if __name__ == '__main__':
94 |
95 | a = torch.tensor(range(30))
96 | a_ = interleave(a, 15)
97 | a__ = de_interleave(a_, 15)
98 | print(a, a_, a__)
99 |
--------------------------------------------------------------------------------