├── .gitignore
├── LICENSE
├── README.md
├── gaussian
│   ├── Gaussian.ipynb
│   ├── baselines_mi_gauss_results_128.pdf
│   ├── baselines_mi_values_128.npz
│   ├── baselines_mi_values_128_alphas.npz
│   ├── baselines_mi_values_128_betas.npz
│   ├── baselines_mi_values_128_gammas.npz
│   ├── baselines_mi_values_cubic_128.npz
│   ├── baselines_mi_values_cubic_128_alphas.npz
│   ├── baselines_mi_values_cubic_128_betas.npz
│   ├── baselines_mi_values_cubic_128_gammas.npz
│   ├── baselines_train_objs_values_128.npz
│   ├── baselines_train_objs_values_128_alphas.npz
│   ├── baselines_train_objs_values_128_betas.npz
│   ├── baselines_train_objs_values_128_gammas.npz
│   ├── baselines_train_objs_values_cubic_128.npz
│   ├── baselines_train_objs_values_cubic_128_alphas.npz
│   ├── baselines_train_objs_values_cubic_128_betas.npz
│   ├── baselines_train_objs_values_cubic_128_gammas.npz
│   └── src
│       ├── __init__.py
│       ├── estimators.py
│       ├── models.py
│       ├── train.py
│       └── utils.py
└── vision
    ├── CONTRIBUTING.md
    ├── LICENSE
    ├── README.md
    ├── data.py
    ├── data_util.py
    ├── gpu_pretrain_finetune_cifar.sh
    ├── lars_optimizer.py
    ├── model.py
    ├── model_util.py
    ├── objective.py
    ├── requirements.txt
    ├── resnet.py
    ├── results_summary.txt
    ├── run.py
    ├── tpu_pretrain_finetune_resnet128.sh
    └── tpu_pretrain_finetune_resnet50.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Martin Ma
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Relative Predictive Coding
2 | Project page for the paper [Self-supervised Representation Learning with Relative Predictive Coding](https://arxiv.org/abs/2103.11275).
3 |
4 | The codebase consists of the following parts:
5 |
6 | ## Vision Task
7 |
8 | Contrastive Self-Supervised Learning on visual tasks.
9 | See `vision` folder for more details.
10 |
11 | ## Gaussian Task
12 |
13 | Mutual information estimation between two Gaussian variables whose correlation is varied over training.
14 | See `gaussian/Gaussian.ipynb` for more details.
15 |
16 | ## Cite
17 |
18 | [RPC paper](https://arxiv.org/abs/2103.11275):
19 |
20 | ```
21 | @article{tsai2021self,
22 | title={Self-supervised representation learning with relative predictive coding},
23 | author={Tsai, Yao-Hung Hubert and Ma, Martin Q and Yang, Muqiao and Zhao, Han and Morency, Louis-Philippe and Salakhutdinov, Ruslan},
24 | journal={arXiv preprint arXiv:2103.11275},
25 | year={2021}
26 | }
27 | ```
28 |
29 |
--------------------------------------------------------------------------------
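For reference, the Gaussian task above has a closed-form ground truth. The sketch below (illustrative, mirroring `sample_correlated_gaussian` and `rho_to_mi` in `gaussian/src/utils.py`) shows how the correlated pairs are drawn and what the true mutual information is:

```python
import numpy as np
import torch

def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128):
    # x and eps are independent standard normals; coordinate i of y has
    # correlation rho with coordinate i of x.
    x, eps = torch.chunk(torch.randn(batch_size, 2 * dim), 2, dim=1)
    y = rho * x + np.sqrt(1.0 - rho ** 2) * eps
    return x, y

# Ground truth: I(x; y) = -(dim / 2) * log(1 - rho^2).
rho, dim = 0.8, 20
x, y = sample_correlated_gaussian(rho, dim)
print(-0.5 * dim * np.log(1.0 - rho ** 2))  # ~10.22 nats
```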
/gaussian/baselines_mi_gauss_results_128.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_gauss_results_128.pdf
--------------------------------------------------------------------------------
/gaussian/baselines_mi_values_128.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_128.npz
--------------------------------------------------------------------------------
/gaussian/baselines_mi_values_128_alphas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_128_alphas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_mi_values_128_betas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_128_betas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_mi_values_128_gammas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_128_gammas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_mi_values_cubic_128.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_cubic_128.npz
--------------------------------------------------------------------------------
/gaussian/baselines_mi_values_cubic_128_alphas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_cubic_128_alphas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_mi_values_cubic_128_betas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_cubic_128_betas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_mi_values_cubic_128_gammas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_cubic_128_gammas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_train_objs_values_128.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_128.npz
--------------------------------------------------------------------------------
/gaussian/baselines_train_objs_values_128_alphas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_128_alphas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_train_objs_values_128_betas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_128_betas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_train_objs_values_128_gammas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_128_gammas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_train_objs_values_cubic_128.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_cubic_128.npz
--------------------------------------------------------------------------------
/gaussian/baselines_train_objs_values_cubic_128_alphas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_cubic_128_alphas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_train_objs_values_cubic_128_betas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_cubic_128_betas.npz
--------------------------------------------------------------------------------
/gaussian/baselines_train_objs_values_cubic_128_gammas.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_cubic_128_gammas.npz
--------------------------------------------------------------------------------
/gaussian/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/src/__init__.py
--------------------------------------------------------------------------------
/gaussian/src/estimators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import sys
7 |
8 | from functools import partial
9 |
10 | from src.utils import *
11 |
12 | import random
13 |
14 | def nwj_lower_bound_obj(scores):
15 | return tuba_lower_bound(scores - 1.)
16 |
17 | def mine_lower_bound(f, buffer=None, momentum=0.9):
18 | if buffer is None:
19 | buffer = torch.tensor(1.0).cuda()
20 | first_term = f.diag().mean()
21 |
22 | buffer_update = logmeanexp_nodiag(f).exp()
23 | with torch.no_grad():
24 | second_term = logmeanexp_nodiag(f)
25 | buffer_new = buffer * momentum + buffer_update * (1 - momentum)
26 | buffer_new = torch.clamp(buffer_new, min=1e-4)
27 | third_term_no_grad = buffer_update / buffer_new
28 |
29 | third_term_grad = buffer_update / buffer_new
30 |
31 | return first_term - second_term - third_term_grad + third_term_no_grad, buffer_update
32 |
33 | def chi_lower_bound_obj(f, alpha, beta, gamma):
34 | f_diag = f.diag()
35 | first_term = (f_diag - 0.5 * beta * (f_diag ** 2)).mean()
36 | n = f.size(0)
37 | f_offdiag = f.flatten()[1:].view(n-1, n+1)[:,:-1].reshape(n, n-1)
38 | # f_offdiag = f.masked_fill_(torch.eye(n, n).byte().cuda(), 0)
39 | second_term = (alpha * f_offdiag + 0.5 * gamma * (f_offdiag ** 2)).mean()
40 | return first_term - second_term
41 |
42 | def dv_upper_lower_bound_obj(f):
43 | """DV lower bound, but upper bounded by using log outside."""
44 | first_term = f.diag().mean()
45 | second_term = logmeanexp_nodiag(f)
46 |
47 | return first_term - second_term
48 |
49 | def tuba_lower_bound(scores, log_baseline=None):
50 | if log_baseline is not None:
51 | scores -= log_baseline[:, None]
52 | batch_size = scores.size(0)
53 |
54 | # First term is an expectation over samples from the joint,
55 |     # which are the diagonal elements of the scores matrix.
56 | joint_term = scores.diag().mean()
57 |
58 | # Second term is an expectation over samples from the marginal,
59 | # which are the off-diagonal elements of the scores matrix.
60 | marg_term = logmeanexp_nodiag(scores).exp()
61 | return 1. + joint_term - marg_term
62 |
63 | def infonce_lower_bound_obj(scores):
64 | nll = scores.diag().mean() - scores.logsumexp(dim=1)
65 | # Alternative implementation:
66 | # nll = -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=scores, labels=tf.range(batch_size))
67 | mi = torch.tensor(scores.size(0)).float().log() + nll
68 | mi = mi.mean()
69 | return mi
70 |
71 | def log_density_ratio_mi(f):
72 | return torch.log(torch.clamp(f, min=1e-4)).diag().mean()
73 |
74 | def direct_log_density_ratio_mi(f):
75 | return f.diag().mean()
76 |
77 | def dv_clip_upper_lower_bound(f, alpha=1.0, clip=None):
78 | z = renorm_q(f, alpha, clip)
79 | dv_clip = f.diag().mean() - z
80 |
81 | return dv_clip
82 |
83 | def js_fgan_lower_bound_obj(f):
84 | f_diag = f.diag()
85 | first_term = -F.softplus(-f_diag).mean()
86 | n = f.size(0)
87 | second_term = (torch.sum(F.softplus(f)) -
88 | torch.sum(F.softplus(f_diag))) / (n * (n - 1.))
89 | return first_term - second_term
90 |
91 |
92 | def log_density_ratio_mi_chi(f, alpha, beta, gamma):
93 | f_diag = f.diag().mean()
94 | true_ratio = (f_diag * gamma + alpha) / (1. - f_diag * beta)
95 | return torch.log(torch.clamp(true_ratio, min=1e-4))
96 |
97 | def MI_Estimator(f, train_type='nwj_lower_bound_obj', eval_type='nwj_lower_bound_obj',
98 | **kwargs):
99 | if train_type == 'tuba_lower_bound' or train_type == 'mine_lower_bound'\
100 | or train_type == 'chi_lower_bound_obj':
101 | if train_type != 'chi_lower_bound_obj':
102 | assert train_type == eval_type
103 | train_val = getattr(sys.modules[__name__], train_type)(f, **kwargs)
104 | else:
105 | train_val = getattr(sys.modules[__name__], train_type)(f)
106 | if train_type == eval_type:
107 | return train_val, train_val
108 |
109 | if train_type == 'nwj_lower_bound_obj' and eval_type == 'direct_log_density_ratio_mi':
110 | eval_val = getattr(sys.modules[__name__], eval_type)(f-1.)
111 | elif eval_type == 'tuba_lower_bound' or eval_type == 'dv_clip_upper_lower_bound'\
112 | or eval_type == 'mine_lower_bound' or eval_type == 'log_density_ratio_mi_chi':
113 | eval_val = getattr(sys.modules[__name__], eval_type)(f, **kwargs)
114 |     # Note: this handles cases such as training with JS and evaluating with NWJ.
115 | elif eval_type == 'nwj_lower_bound_obj':
116 | eval_val = getattr(sys.modules[__name__], eval_type)(f+1., **kwargs)
117 | else:
118 | eval_val = getattr(sys.modules[__name__], eval_type)(f)
119 |
120 | with torch.no_grad():
121 | eval_train = eval_val - train_val
122 |
123 | return train_val + eval_train, train_val
124 |
125 | def nwj_lower_bound(f):
126 | return MI_Estimator(f, train_type='nwj_lower_bound_obj', eval_type='nwj_lower_bound_obj')
127 |
128 | def infonce_lower_bound(f):
129 | return MI_Estimator(f, train_type='infonce_lower_bound_obj', eval_type='infonce_lower_bound_obj')
130 |
131 | def js_lower_bound(f):
132 | return MI_Estimator(f, train_type='js_fgan_lower_bound_obj', eval_type='nwj_lower_bound_obj')
133 |
134 | def dv_upper_lower_bound(f):
135 | return MI_Estimator(f, train_type='dv_upper_lower_bound_obj', eval_type='dv_upper_lower_bound_obj')
136 |
137 | def smile_lower_bound(f, alpha=1.0, clip=5.0):
138 | return MI_Estimator(f, train_type='js_fgan_lower_bound_obj',
139 | eval_type='dv_clip_upper_lower_bound', alpha=alpha, clip=clip)
140 |
141 | def chi_lower_bound(f, alpha=0.01, beta = 0.005, gamma = 0.995):
142 | return MI_Estimator(f, train_type='chi_lower_bound_obj', eval_type='log_density_ratio_mi_chi', alpha=alpha, beta=beta, gamma=gamma)
--------------------------------------------------------------------------------
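A small usage sketch for the estimators above (illustrative, not part of the repository; assumes `gaussian/` is the working directory so `src` is importable). Each estimator takes an n x n matrix of critic scores whose diagonal holds the joint (paired) samples, and `MI_Estimator` returns the evaluation value together with the training objective:

```python
import torch
from src.estimators import chi_lower_bound, infonce_lower_bound

n = 128
scores = torch.randn(n, n)  # critic scores; scores[i, i] score joint pairs

# Trains with the chi objective, evaluates via the implied log density ratio.
mi_eval, train_obj = chi_lower_bound(scores, alpha=0.01, beta=0.005, gamma=0.995)

# InfoNCE trains and evaluates with the same objective.
mi_nce, _ = infonce_lower_bound(scores)
```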
/gaussian/src/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | def mlp(dim, hidden_dim, output_dim, layers, activation):
8 | activation = {
9 | 'relu': nn.ReLU,
10 | 'tanh': nn.Tanh,
11 | }[activation]
12 |
13 | seq = [nn.Linear(dim, hidden_dim), activation()]
14 | for _ in range(layers):
15 | seq += [nn.Linear(hidden_dim, hidden_dim), activation()]
16 | seq += [nn.Linear(hidden_dim, output_dim)]
17 |
18 | return nn.Sequential(*seq)
19 |
20 | class SeparableCritic(nn.Module):
21 | def __init__(self, dim, hidden_dim, embed_dim, layers, activation, **extra_kwargs):
22 | super(SeparableCritic, self).__init__()
23 | self._g = mlp(dim, hidden_dim, embed_dim, layers, activation)
24 | self._h = mlp(dim, hidden_dim, embed_dim, layers, activation)
25 |
26 | def forward(self, x, y):
27 | scores = torch.matmul(self._h(y), self._g(x).t())
28 | return scores
29 |
30 |
31 | class ConcatCritic(nn.Module):
32 | def __init__(self, dim, hidden_dim, layers, activation, **extra_kwargs):
33 | super(ConcatCritic, self).__init__()
34 | self._f = mlp(dim * 2, hidden_dim, 1, layers, activation)
35 |
36 | def forward(self, x, y):
37 | batch_size = x.size(0)
38 | # Tile all possible combinations of x and y
39 | x_tiled = torch.stack([x] * batch_size, dim=0)
40 | y_tiled = torch.stack([y] * batch_size, dim=1)
41 | # xy is [batch_size * batch_size, x_dim + y_dim]
42 | xy_pairs = torch.reshape(torch.cat((x_tiled, y_tiled), dim=2), [
43 | batch_size * batch_size, -1])
44 | # Compute scores for each x_i, y_j pair.
45 | scores = self._f(xy_pairs)
46 | return torch.reshape(scores, [batch_size, batch_size]).t()
47 |
--------------------------------------------------------------------------------
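A brief usage sketch for the two critics (shapes are illustrative, not from the repository). `ConcatCritic` runs a single MLP over all batch_size^2 concatenated pairs, while `SeparableCritic` embeds x and y separately and scores pairs by inner products of the embeddings:

```python
import torch
from src.models import ConcatCritic, SeparableCritic

x = torch.randn(64, 20)
y = torch.randn(64, 20)

# One MLP applied to all 64 * 64 concatenated (x, y) pairs.
concat = ConcatCritic(dim=20, hidden_dim=256, layers=2, activation='relu')
scores = concat(x, y)  # (64, 64); entry (i, j) scores the pair (x_i, y_j)

# Two embedding MLPs; the score matrix holds pairwise inner products.
sep = SeparableCritic(dim=20, hidden_dim=256, embed_dim=32, layers=2,
                      activation='relu')
scores = sep(x, y)     # also (64, 64), with joint pairs on the diagonal
```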
/gaussian/src/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | from src.utils import *
8 | from src.estimators import *
9 | from src.models import *
10 |
11 | def estimate_mutual_information(estimator, x, y, critic_fn,
12 | baseline_fn=None, alpha_logit=None, clamping_values=None, **kwargs):
13 | x, y = x.cuda(), y.cuda()
14 | scores = critic_fn(x, y)
15 | if clamping_values is not None:
16 | scores = torch.clamp(scores, min=clamping_values[0], max=clamping_values[1])
17 | if baseline_fn is not None:
18 | # Some baselines' output is (batch_size, 1) which we remove here.
19 | log_baseline = torch.squeeze(baseline_fn(y))
20 | if estimator == 'nwj':
21 | return nwj_lower_bound(scores)
22 | elif estimator == 'infonce':
23 | return infonce_lower_bound(scores)
24 | elif estimator == 'js':
25 | return js_lower_bound(scores)
26 | elif estimator == 'dv':
27 | return dv_upper_lower_bound(scores)
28 | elif estimator == 'smile':
29 | return smile_lower_bound(scores, **kwargs)
30 | elif estimator == 'gen-chi':
31 | return chi_lower_bound(scores, **kwargs)
32 |
33 |
34 | def train_estimator(critic_params, data_params, mi_params, opt_params, **kwargs):
35 |
36 | CRITICS = {
37 | 'separable': SeparableCritic,
38 | 'concat': ConcatCritic,
39 | }
40 |
41 | BASELINES = {
42 | 'constant': lambda: None,
43 | 'unnormalized': lambda: mlp(dim=data_params['dim'], \
44 | hidden_dim=512, output_dim=1, layers=2, activation='relu').cuda(),
45 | }
46 |
47 | critic = CRITICS[mi_params.get('critic', 'separable')](
48 | rho=None, **critic_params).cuda()
49 | baseline = BASELINES[mi_params.get('baseline', 'constant')]()
50 |
51 | opt_crit = optim.Adam(critic.parameters(), lr=opt_params['learning_rate'])
52 | if isinstance(baseline, nn.Module):
53 | opt_base = optim.Adam(baseline.parameters(),
54 | lr=opt_params['learning_rate'])
55 | else:
56 | opt_base = None
57 |
58 | def train_step(rho, data_params, mi_params):
59 | opt_crit.zero_grad()
60 | if isinstance(baseline, nn.Module):
61 | opt_base.zero_grad()
62 |
63 |         if mi_params['critic'] == 'conditional':  # unused: CRITICS defines no 'conditional'
64 | critic_ = CRITICS['conditional'](rho=rho).cuda()
65 | else:
66 | critic_ = critic
67 |
68 | x, y = sample_correlated_gaussian(
69 | dim=data_params['dim'], rho=rho,\
70 | batch_size=data_params['batch_size'], cubic=data_params['cubic'])
71 |         # estimate_mutual_information returns the MI estimate and the
72 |         # value of the training objective (see MI_Estimator in estimators.py).
73 |         mi, train_obj = estimate_mutual_information(
74 |             mi_params['estimator'], x, y, critic_, baseline,
75 |             mi_params.get('alpha_logit', None), **kwargs)
76 |         loss = -mi
77 | 
78 |         loss.backward()
79 |         opt_crit.step()
80 |         if isinstance(baseline, nn.Module):
81 |             opt_base.step()
82 | 
83 |         return mi, train_obj
84 | 
85 |     mis = mi_schedule(opt_params['iterations'])
86 |     rhos = mi_to_rho(data_params['dim'], mis)
87 | 
88 |     estimates = []
89 |     train_objs = []
90 |     for i in range(opt_params['iterations']):
91 |         mi, train_obj = train_step(rhos[i], data_params, mi_params)
92 |         estimates.append(mi.detach().cpu().numpy())
93 |         train_objs.append(train_obj.detach().cpu().numpy())
94 | 
95 |     return np.array(estimates), np.array(train_objs)
96 | 
--------------------------------------------------------------------------------
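An end-to-end sketch of how `train_estimator` might be invoked. The parameter keys follow the code above; the values are illustrative (the settings actually used are in `gaussian/Gaussian.ipynb`), and a CUDA device is required because the critic and data are moved to the GPU:

```python
from src.train import train_estimator

critic_params = {'dim': 20, 'hidden_dim': 256, 'embed_dim': 32,
                 'layers': 2, 'activation': 'relu'}
data_params = {'dim': 20, 'batch_size': 128, 'cubic': None}
mi_params = {'estimator': 'gen-chi', 'critic': 'concat', 'baseline': 'constant'}
opt_params = {'learning_rate': 5e-4, 'iterations': 20000}

# Per-iteration MI estimates and training-objective values.
estimates, train_objs = train_estimator(critic_params, data_params,
                                        mi_params, opt_params)
```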
/gaussian/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128, cubic=None):
8 | """Generate samples from a correlated Gaussian distribution."""
9 | x, eps = torch.chunk(torch.randn(batch_size, 2 * dim), 2, dim=1)
10 | y = rho * x + torch.sqrt(torch.tensor(1. - rho**2).float()) * eps
11 |
12 | if cubic is not None:
13 | y = y ** 3
14 |
15 | return x, y
16 |
17 |
18 | def rho_to_mi(dim, rho):
19 | return -0.5 * np.log(1-rho**2) * dim
20 |
21 |
22 | def mi_to_rho(dim, mi):
23 | return np.sqrt(1-np.exp(-2.0 / dim * mi))
24 |
25 |
26 | def mi_schedule(n_iter):
27 | """Generate schedule for increasing correlation over time."""
28 | mis = np.round(np.linspace(0.5, 5.5-1e-9, n_iter)) * 2.0
29 | return mis.astype(np.float32)
30 |
31 | def logmeanexp_diag(x):
32 | batch_size = x.size(0)
33 |
34 | logsumexp = torch.logsumexp(x.diag(), dim=(0,))
35 | num_elem = batch_size
36 |
37 | return logsumexp - torch.log(torch.tensor(num_elem).float()).cuda()
38 |
39 |
40 | def logmeanexp_nodiag(x, dim=None, device='cuda'):
41 | batch_size = x.size(0)
42 | if dim is None:
43 | dim = (0, 1)
44 |
45 | logsumexp = torch.logsumexp(
46 | x - torch.diag(np.inf * torch.ones(batch_size).to(device)), dim=dim)
47 |
48 | try:
49 | if len(dim) == 1:
50 | num_elem = batch_size - 1.
51 | else:
52 | num_elem = batch_size * (batch_size - 1.)
53 |     except TypeError:  # dim was passed as an int, not a tuple
54 | num_elem = batch_size - 1
55 | return logsumexp - torch.log(torch.tensor(num_elem)).to(device)
56 |
57 |
58 | def renorm_q(f, alpha=1.0, clip=None):
59 | if clip is not None:
60 | f = torch.clamp(f * alpha, -clip, clip)
61 | z = logmeanexp_nodiag(f * alpha, dim=(0, 1))
62 | return z
63 |
64 |
65 | def disc_renorm_q(f):
66 | batch_size = f.size(0)
67 | z = torch.zeros(1, requires_grad=True, device='cuda')
68 |
69 | opt = optim.SGD([z], lr=0.001)
70 | for i in range(10):
71 | opt.zero_grad()
72 |
73 | first_term = -F.softplus(z - f).diag().mean()
74 | st = -F.softplus(f - z)
75 | second_term = (st - st.diag().diag()).sum() / \
76 | (batch_size * (batch_size - 1.))
77 | total = first_term + second_term
78 |
79 | total.backward(retain_graph=True)
80 | opt.step()
81 |
82 | if total.item() <= -2 * np.log(2):
83 | break
84 |
85 | return z
86 |
87 |
88 | def renorm_p(f, alpha=1.0):
89 | z = logmeanexp_diag(-f * alpha)
90 | return z
91 |
92 | def estimate_p_norm(f, alpha=1.0):
93 | z = renorm_q(f, alpha)
94 | # f = renorm_p(f, alpha)
95 | # f = renorm_q(f, alpha)
96 | f = f - z
97 | f = -f
98 |
99 | return f.diag().exp().mean()
100 |
--------------------------------------------------------------------------------
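As a quick sanity check on the conversions above: `mi_to_rho` inverts `rho_to_mi`, since I(x; y) = -(dim / 2) * log(1 - rho^2) for dim-dimensional Gaussians with per-coordinate correlation rho. A minimal round-trip check:

```python
import numpy as np
from src.utils import mi_to_rho, rho_to_mi

dim = 20
for target_mi in [2.0, 4.0, 6.0, 8.0, 10.0]:
    rho = mi_to_rho(dim, target_mi)
    assert np.isclose(rho_to_mi(dim, rho), target_mi)
```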
/vision/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | SimCLR needs to maintain permanent compatibility with the pre-trained model
4 | files, so we do not plan to make any major changes to this library (other than
5 | what was promised in the README). However, we can accept small patches related
6 | to re-factoring and documentation. To submit contributions, there are just a few
7 | small guidelines you need to follow.
8 |
9 | ## Contributor License Agreement
10 |
11 | Contributions to this project must be accompanied by a Contributor License
12 | Agreement. You (or your employer) retain the copyright to your contribution;
13 | this simply gives us permission to use and redistribute your contributions as
14 | part of the project. Head over to <https://cla.developers.google.com/> to see
15 | your current agreements on file or to sign a new one.
16 |
17 | You generally only need to submit a CLA once, so if you've already submitted one
18 | (even if it was for a different project), you probably don't need to do it
19 | again.
20 |
21 | ## Code reviews
22 |
23 | All submissions, including submissions by project members, require review. We
24 | use GitHub pull requests for this purpose. Consult
25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
26 | information on using pull requests.
27 |
28 | ## Community Guidelines
29 |
30 | This project follows
31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
32 |
--------------------------------------------------------------------------------
/vision/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/vision/README.md:
--------------------------------------------------------------------------------
1 | # Self-supervised Representation Learning with Relative Predictive Coding
2 |
3 | This is a codebase for the vision experiments.
4 |
5 | ## Environment setup
6 |
7 | For CIFAR-10/CIFAR-100 experiments, our code can run on a *single* GPU. It does not support multi-GPU training, because of features such as global BatchNorm and the contrastive loss across cores.
8 |
9 | Our models are also trained with TPUs. We recommend running distributed training with TPUs when using our code for pretraining on ImageNet.
10 |
11 | We recommend using conda to avoid compatibility issues:
12 | ```
13 | conda create -n rpc_vision python=3.6
14 | conda activate rpc_vision
15 | pip install -r requirements.txt
16 | conda install cudatoolkit cudnn
17 | ```
18 |
19 | ## Pretraining and Fine-Tuning
20 |
21 | First create a checkpoint directory:
22 |
23 | ```
24 | mkdir checkpoint
25 | ```
26 |
27 | To pretrain and finetune the model on CIFAR-10 with a *single* GPU, try the following command:
28 |
29 | ```
30 | bash gpu_pretrain_finetune_cifar.sh
31 | ```
32 |
33 | To change hyper-parameters, edit the corresponding flags in `gpu_pretrain_finetune_cifar.sh`.
34 |
35 | To pretrain the model on ImageNet with Cloud TPUs, first check out the [Google Cloud TPU tutorial](https://cloud.google.com/tpu/docs/tutorials/mnist) for basic information on how to use Google Cloud TPUs.
36 |
37 | Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for [tensorflow_datasets](https://www.tensorflow.org/datasets/catalog/imagenet2012), set the following environment variables:
38 |
39 | ```
40 | TPU_NAME=
41 | STORAGE_BUCKET=gs://
42 | DATA_DIR=$STORAGE_BUCKET/
43 | MODEL_DIR=$STORAGE_BUCKET/
44 | ```
45 |
46 | in the following files, which pretrain and fine-tune a ResNet-50 or a ResNet-152 on ImageNet:
47 |
48 | ```
49 | bash tpu_pretrain_finetune_resnet50.sh
50 | bash tpu_pretrain_finetune_resnet128.sh
51 | ```
52 |
53 | To request checkpoints of the trained models from the commands above, please contact Martin via qianlim@andrew.cmu.edu.
54 |
55 | ## Cite
56 |
57 | [RPC paper](https://arxiv.org/abs/2103.11275):
58 |
59 | ```
60 | @article{tsai2021self,
61 | title={Self-supervised representation learning with relative predictive coding},
62 | author={Tsai, Yao-Hung Hubert and Ma, Martin Q and Yang, Muqiao and Zhao, Han and Morency, Louis-Philippe and Salakhutdinov, Ruslan},
63 | journal={arXiv preprint arXiv:2103.11275},
64 | year={2021}
65 | }
66 | ```
67 |
68 | This codebase is adapted from [SimCLR](https://github.com/google-research/simclr); the major change is in the file `objective.py`.
69 |
--------------------------------------------------------------------------------
/vision/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Data pipeline."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | from absl import flags
24 |
25 | import data_util as data_util
26 | import tensorflow.compat.v1 as tf
27 |
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | def pad_to_batch(dataset, batch_size):
32 | """Pad Tensors to specified batch size.
33 |
34 | Args:
35 | dataset: An instance of tf.data.Dataset.
36 | batch_size: The number of samples per batch of input requested.
37 |
38 | Returns:
39 | An instance of tf.data.Dataset that yields the same Tensors with the same
40 | structure as the original padded to batch_size along the leading
41 | dimension.
42 |
43 | Raises:
44 | ValueError: If the dataset does not comprise any tensors; if a tensor
45 | yielded by the dataset has an unknown number of dimensions or is a
46 | scalar; or if it can be statically determined that tensors comprising
47 | a single dataset element will have different leading dimensions.
48 | """
49 | def _pad_to_batch(*args):
50 | """Given Tensors yielded by a Dataset, pads all to the batch size."""
51 | flat_args = tf.nest.flatten(args)
52 |
53 | for tensor in flat_args:
54 | if tensor.shape.ndims is None:
55 | raise ValueError(
56 | 'Unknown number of dimensions for tensor %s.' % tensor.name)
57 | if tensor.shape.ndims == 0:
58 | raise ValueError('Tensor %s is a scalar.' % tensor.name)
59 |
60 | # This will throw if flat_args is empty. However, as of this writing,
61 | # tf.data.Dataset.map will throw first with an internal error, so we do
62 | # not check this case explicitly.
63 | first_tensor = flat_args[0]
64 | first_tensor_shape = tf.shape(first_tensor)
65 | first_tensor_batch_size = first_tensor_shape[0]
66 | difference = batch_size - first_tensor_batch_size
67 |
68 | for i, tensor in enumerate(flat_args):
69 | control_deps = []
70 | if i != 0:
71 |         # Check that the leading dimension of this tensor matches the first,
72 |         # either statically or dynamically. (If the first dimensions of both
73 |         # tensors are statically known, then we have to check the static
74 |         # shapes at graph construction time or else we will never get to the
75 |         # dynamic assertion.)
76 | if (first_tensor.shape[:1].is_fully_defined() and
77 | tensor.shape[:1].is_fully_defined()):
78 | if first_tensor.shape[0] != tensor.shape[0]:
79 | raise ValueError(
80 | 'Batch size of dataset tensors does not match. %s '
81 | 'has shape %s, but %s has shape %s' % (
82 | first_tensor.name, first_tensor.shape,
83 | tensor.name, tensor.shape))
84 | else:
85 | curr_shape = tf.shape(tensor)
86 | control_deps = [tf.Assert(
87 | tf.equal(curr_shape[0], first_tensor_batch_size),
88 | ['Batch size of dataset tensors %s and %s do not match. '
89 | 'Shapes are' % (tensor.name, first_tensor.name), curr_shape,
90 | first_tensor_shape])]
91 |
92 | with tf.control_dependencies(control_deps):
93 | # Pad to batch_size along leading dimension.
94 | flat_args[i] = tf.pad(
95 | tensor, [[0, difference]] + [[0, 0]] * (tensor.shape.ndims - 1))
96 | flat_args[i].set_shape([batch_size] + tensor.shape.as_list()[1:])
97 |
98 | return tf.nest.pack_sequence_as(args, flat_args)
99 |
100 | return dataset.map(_pad_to_batch)
101 |
102 |
103 | def build_input_fn(builder, is_training):
104 | """Build input function.
105 |
106 | Args:
107 | builder: TFDS builder for specified dataset.
108 | is_training: Whether to build in training mode.
109 |
110 | Returns:
111 | A function that accepts a dict of params and returns a tuple of images and
112 | features, to be used as the input_fn in TPUEstimator.
113 | """
114 | def _input_fn(params):
115 | """Inner input function."""
116 | preprocess_fn_pretrain = get_preprocess_fn(is_training, is_pretrain=True)
117 | preprocess_fn_finetune = get_preprocess_fn(is_training, is_pretrain=False)
118 | num_classes = builder.info.features['label'].num_classes
119 |
120 | def map_fn(image, label):
121 | """Produces multiple transformations of the same batch."""
122 | if FLAGS.train_mode == 'pretrain':
123 | xs = []
124 | for _ in range(2): # Two transformations
125 | xs.append(preprocess_fn_pretrain(image))
126 | image = tf.concat(xs, -1)
127 | label = tf.zeros([num_classes])
128 | else:
129 | image = preprocess_fn_finetune(image)
130 | label = tf.one_hot(label, num_classes)
131 | return image, label, 1.0
132 |
133 | dataset = builder.as_dataset(
134 | split=FLAGS.train_split if is_training else FLAGS.eval_split,
135 | shuffle_files=is_training, as_supervised=True)
136 | if FLAGS.cache_dataset:
137 | dataset = dataset.cache()
138 | if is_training:
139 | buffer_multiplier = 50 if FLAGS.image_size <= 32 else 10
140 | dataset = dataset.shuffle(params['batch_size'] * buffer_multiplier)
141 | dataset = dataset.repeat(-1)
142 | dataset = dataset.map(map_fn,
143 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
144 | dataset = dataset.batch(params['batch_size'], drop_remainder=is_training)
145 | dataset = pad_to_batch(dataset, params['batch_size'])
146 | images, labels, mask = tf.data.make_one_shot_iterator(dataset).get_next()
147 |
148 | return images, {'labels': labels, 'mask': mask}
149 | return _input_fn
150 |
151 |
152 | def get_preprocess_fn(is_training, is_pretrain):
153 | """Get function that accepts an image and returns a preprocessed image."""
154 | # Disable test cropping for small images (e.g. CIFAR)
155 | if FLAGS.image_size <= 32:
156 | test_crop = False
157 | else:
158 | test_crop = True
159 | return functools.partial(
160 | data_util.preprocess_image,
161 | height=FLAGS.image_size,
162 | width=FLAGS.image_size,
163 | is_training=is_training,
164 | color_distort=is_pretrain,
165 | test_crop=test_crop)
166 |
--------------------------------------------------------------------------------
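A minimal standalone check of `pad_to_batch` (hypothetical, assuming `vision/` is on `PYTHONPATH` and the TF1-style execution used in this codebase): ten examples batched by four leave a final partial batch of two, which is zero-padded back to four along the leading dimension so every yielded batch has a static shape:

```python
import tensorflow.compat.v1 as tf
from data import pad_to_batch  # vision/data.py

ds = tf.data.Dataset.from_tensor_slices(tf.ones([10, 3])).batch(4)
ds = pad_to_batch(ds, 4)  # all batches now have static shape (4, 3)
batch = tf.data.make_one_shot_iterator(ds).get_next()
```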
/vision/data_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Data preprocessing and augmentation."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | from absl import flags
24 |
25 | import tensorflow.compat.v1 as tf
26 |
27 | FLAGS = flags.FLAGS
28 |
29 | CROP_PROPORTION = 0.875 # Standard for ImageNet.
30 |
31 |
32 | def random_apply(func, p, x):
33 | """Randomly apply function func to x with probability p."""
34 | return tf.cond(
35 | tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32),
36 | tf.cast(p, tf.float32)),
37 | lambda: func(x),
38 | lambda: x)
39 |
40 |
41 | def random_brightness(image, max_delta, impl='simclrv2'):
42 | """A multiplicative vs additive change of brightness."""
43 | if impl == 'simclrv2':
44 | factor = tf.random_uniform(
45 | [], tf.maximum(1.0 - max_delta, 0), 1.0 + max_delta)
46 | image = image * factor
47 | elif impl == 'simclrv1':
48 |     image = tf.image.random_brightness(image, max_delta=max_delta)
49 | else:
50 | raise ValueError('Unknown impl {} for random brightness.'.format(impl))
51 | return image
52 |
53 |
54 | def to_grayscale(image, keep_channels=True):
55 | image = tf.image.rgb_to_grayscale(image)
56 | if keep_channels:
57 | image = tf.tile(image, [1, 1, 3])
58 | return image
59 |
60 |
61 | def color_jitter(image,
62 | strength,
63 | random_order=True):
64 | """Distorts the color of the image.
65 |
66 | Args:
67 | image: The input image tensor.
68 |     strength: A float, specifying the strength of the color augmentation.
69 | random_order: A bool, specifying whether to randomize the jittering order.
70 |
71 | Returns:
72 | The distorted image tensor.
73 | """
74 | brightness = 0.8 * strength
75 | contrast = 0.8 * strength
76 | saturation = 0.8 * strength
77 | hue = 0.2 * strength
78 | if random_order:
79 | return color_jitter_rand(image, brightness, contrast, saturation, hue)
80 | else:
81 | return color_jitter_nonrand(image, brightness, contrast, saturation, hue)
82 |
83 |
84 | def color_jitter_nonrand(image, brightness=0, contrast=0, saturation=0, hue=0):
85 | """Distorts the color of the image (jittering order is fixed).
86 |
87 | Args:
88 | image: The input image tensor.
89 | brightness: A float, specifying the brightness for color jitter.
90 | contrast: A float, specifying the contrast for color jitter.
91 | saturation: A float, specifying the saturation for color jitter.
92 | hue: A float, specifying the hue for color jitter.
93 |
94 | Returns:
95 | The distorted image tensor.
96 | """
97 | with tf.name_scope('distort_color'):
98 | def apply_transform(i, x, brightness, contrast, saturation, hue):
99 | """Apply the i-th transformation."""
100 | if brightness != 0 and i == 0:
101 | x = random_brightness(x, max_delta=brightness)
102 | elif contrast != 0 and i == 1:
103 | x = tf.image.random_contrast(
104 | x, lower=1-contrast, upper=1+contrast)
105 | elif saturation != 0 and i == 2:
106 | x = tf.image.random_saturation(
107 | x, lower=1-saturation, upper=1+saturation)
108 | elif hue != 0:
109 | x = tf.image.random_hue(x, max_delta=hue)
110 | return x
111 |
112 | for i in range(4):
113 | image = apply_transform(i, image, brightness, contrast, saturation, hue)
114 | image = tf.clip_by_value(image, 0., 1.)
115 | return image
116 |
117 |
118 | def color_jitter_rand(image, brightness=0, contrast=0, saturation=0, hue=0):
119 | """Distorts the color of the image (jittering order is random).
120 |
121 | Args:
122 | image: The input image tensor.
123 | brightness: A float, specifying the brightness for color jitter.
124 | contrast: A float, specifying the contrast for color jitter.
125 | saturation: A float, specifying the saturation for color jitter.
126 | hue: A float, specifying the hue for color jitter.
127 |
128 | Returns:
129 | The distorted image tensor.
130 | """
131 | with tf.name_scope('distort_color'):
132 | def apply_transform(i, x):
133 | """Apply the i-th transformation."""
134 | def brightness_foo():
135 | if brightness == 0:
136 | return x
137 | else:
138 | return random_brightness(x, max_delta=brightness)
139 | def contrast_foo():
140 | if contrast == 0:
141 | return x
142 | else:
143 | return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast)
144 | def saturation_foo():
145 | if saturation == 0:
146 | return x
147 | else:
148 | return tf.image.random_saturation(
149 | x, lower=1-saturation, upper=1+saturation)
150 | def hue_foo():
151 | if hue == 0:
152 | return x
153 | else:
154 | return tf.image.random_hue(x, max_delta=hue)
155 | x = tf.cond(tf.less(i, 2),
156 | lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
157 | lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
158 | return x
159 |
160 | perm = tf.random_shuffle(tf.range(4))
161 | for i in range(4):
162 | image = apply_transform(perm[i], image)
163 | image = tf.clip_by_value(image, 0., 1.)
164 | return image
165 |
166 |
167 | def _compute_crop_shape(
168 | image_height, image_width, aspect_ratio, crop_proportion):
169 | """Compute aspect ratio-preserving shape for central crop.
170 |
171 | The resulting shape retains `crop_proportion` along one side and a proportion
172 | less than or equal to `crop_proportion` along the other side.
173 |
174 | Args:
175 | image_height: Height of image to be cropped.
176 | image_width: Width of image to be cropped.
177 | aspect_ratio: Desired aspect ratio (width / height) of output.
178 | crop_proportion: Proportion of image to retain along the less-cropped side.
179 |
180 | Returns:
181 | crop_height: Height of image after cropping.
182 | crop_width: Width of image after cropping.
183 | """
184 | image_width_float = tf.cast(image_width, tf.float32)
185 | image_height_float = tf.cast(image_height, tf.float32)
186 |
187 | def _requested_aspect_ratio_wider_than_image():
188 | crop_height = tf.cast(tf.rint(
189 | crop_proportion / aspect_ratio * image_width_float), tf.int32)
190 | crop_width = tf.cast(tf.rint(
191 | crop_proportion * image_width_float), tf.int32)
192 | return crop_height, crop_width
193 |
194 | def _image_wider_than_requested_aspect_ratio():
195 | crop_height = tf.cast(
196 | tf.rint(crop_proportion * image_height_float), tf.int32)
197 | crop_width = tf.cast(tf.rint(
198 | crop_proportion * aspect_ratio *
199 | image_height_float), tf.int32)
200 | return crop_height, crop_width
201 |
202 | return tf.cond(
203 | aspect_ratio > image_width_float / image_height_float,
204 | _requested_aspect_ratio_wider_than_image,
205 | _image_wider_than_requested_aspect_ratio)
206 |
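# --- Worked example (illustrative addition, not part of the original file). ---
# Pure-Python mirror of the two branches above. For a 480x640 (H x W) image,
# a square target (aspect_ratio=1.0) and crop_proportion=0.875, the image is
# wider than requested, so the crop keeps 0.875 of the height (420 x 420);
# along the width only 420/640 ~ 0.66 <= 0.875 is retained, as the docstring
# promises.
def _compute_crop_shape_example():
  image_height, image_width = 480.0, 640.0
  aspect_ratio, crop_proportion = 1.0, 0.875
  if aspect_ratio > image_width / image_height:  # requested ratio is wider
    crop_height = round(crop_proportion / aspect_ratio * image_width)
    crop_width = round(crop_proportion * image_width)
  else:  # image is wider than the requested aspect ratio
    crop_height = round(crop_proportion * image_height)
    crop_width = round(crop_proportion * aspect_ratio * image_height)
  assert (crop_height, crop_width) == (420, 420)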
207 |
208 | def center_crop(image, height, width, crop_proportion):
209 | """Crops to center of image and rescales to desired size.
210 |
211 | Args:
212 | image: Image Tensor to crop.
213 | height: Height of image to be cropped.
214 | width: Width of image to be cropped.
215 | crop_proportion: Proportion of image to retain along the less-cropped side.
216 |
217 | Returns:
218 | A `height` x `width` x channels Tensor holding a central crop of `image`.
219 | """
220 | shape = tf.shape(image)
221 | image_height = shape[0]
222 | image_width = shape[1]
223 | crop_height, crop_width = _compute_crop_shape(
224 | image_height, image_width, height / width, crop_proportion)
225 | offset_height = ((image_height - crop_height) + 1) // 2
226 | offset_width = ((image_width - crop_width) + 1) // 2
227 | image = tf.image.crop_to_bounding_box(
228 | image, offset_height, offset_width, crop_height, crop_width)
229 |
230 | image = tf.image.resize_bicubic([image], [height, width])[0]
231 |
232 | return image
233 |
234 |
235 | def distorted_bounding_box_crop(image,
236 | bbox,
237 | min_object_covered=0.1,
238 | aspect_ratio_range=(0.75, 1.33),
239 | area_range=(0.05, 1.0),
240 | max_attempts=100,
241 | scope=None):
242 | """Generates cropped_image using one of the bboxes randomly distorted.
243 |
244 | See `tf.image.sample_distorted_bounding_box` for more documentation.
245 |
246 | Args:
247 | image: `Tensor` of image data.
248 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
249 | where each coordinate is [0, 1) and the coordinates are arranged
250 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
251 | image.
252 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
253 | area of the image must contain at least this fraction of any bounding
254 | box supplied.
255 | aspect_ratio_range: An optional list of `float`s. The cropped area of the
256 | image must have an aspect ratio = width / height within this range.
257 | area_range: An optional list of `float`s. The cropped area of the image
258 |       must contain a fraction of the supplied image within this range.
259 | max_attempts: An optional `int`. Number of attempts at generating a cropped
260 | region of the image of the specified constraints. After `max_attempts`
261 | failures, return the entire image.
262 | scope: Optional `str` for name scope.
263 | Returns:
264 | (cropped image `Tensor`, distorted bbox `Tensor`).
265 | """
266 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
267 | shape = tf.shape(image)
268 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
269 | shape,
270 | bounding_boxes=bbox,
271 | min_object_covered=min_object_covered,
272 | aspect_ratio_range=aspect_ratio_range,
273 | area_range=area_range,
274 | max_attempts=max_attempts,
275 | use_image_if_no_bounding_boxes=True)
276 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box
277 |
278 | # Crop the image to the specified bounding box.
279 | offset_y, offset_x, _ = tf.unstack(bbox_begin)
280 | target_height, target_width, _ = tf.unstack(bbox_size)
281 | image = tf.image.crop_to_bounding_box(
282 | image, offset_y, offset_x, target_height, target_width)
283 |
284 | return image
285 |
286 |
287 | def crop_and_resize(image, height, width):
288 | """Make a random crop and resize it to height `height` and width `width`.
289 |
290 | Args:
291 | image: Tensor representing the image.
292 | height: Desired image height.
293 | width: Desired image width.
294 |
295 | Returns:
296 | A `height` x `width` x channels Tensor holding a random crop of `image`.
297 | """
298 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
299 | aspect_ratio = width / height
300 | image = distorted_bounding_box_crop(
301 | image,
302 | bbox,
303 | min_object_covered=0.1,
304 | aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio),
305 | area_range=(0.08, 1.0),
306 | max_attempts=100,
307 | scope=None)
308 | return tf.image.resize_bicubic([image], [height, width])[0]
309 |
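# --- Usage sketch (illustrative addition, not part of the original file). ---
# Assumes TF1 graph mode. This is the Inception-style random crop: it keeps
# between 8% and 100% of the source area, scales the allowed aspect-ratio
# range by the target width/height ratio, and resizes bicubically.
def _crop_and_resize_example():
  import numpy as np
  image = tf.placeholder(tf.float32, shape=[None, None, 3])
  cropped = crop_and_resize(image, height=224, width=224)
  with tf.Session() as sess:
    out = sess.run(cropped, feed_dict={image: np.random.rand(480, 640, 3)})
    assert out.shape == (224, 224, 3)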
310 |
311 | def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
312 | """Blurs the given image with separable convolution.
313 |
314 |
315 | Args:
316 | image: Tensor of shape [height, width, channels] and dtype float to blur.
317 |     kernel_size: Integer Tensor for the size of the blur kernel. This should
318 |       be an odd number; if it is even, the actual kernel size will be
319 |       kernel_size + 1.
320 |     sigma: Sigma value for the Gaussian kernel.
321 | padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
322 |
323 | Returns:
324 | A Tensor representing the blurred image.
325 | """
326 | radius = tf.to_int32(kernel_size / 2)
327 | kernel_size = radius * 2 + 1
328 | x = tf.to_float(tf.range(-radius, radius + 1))
329 | blur_filter = tf.exp(
330 | -tf.pow(x, 2.0) / (2.0 * tf.pow(tf.to_float(sigma), 2.0)))
331 | blur_filter /= tf.reduce_sum(blur_filter)
332 | # One vertical and one horizontal filter.
333 | blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
334 | blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
335 | num_channels = tf.shape(image)[-1]
336 | blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
337 | blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
338 | expand_batch_dim = image.shape.ndims == 3
339 | if expand_batch_dim:
340 | # Tensorflow requires batched input to convolutions, which we can fake with
341 | # an extra dimension.
342 | image = tf.expand_dims(image, axis=0)
343 | blurred = tf.nn.depthwise_conv2d(
344 | image, blur_h, strides=[1, 1, 1, 1], padding=padding)
345 | blurred = tf.nn.depthwise_conv2d(
346 | blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
347 | if expand_batch_dim:
348 | blurred = tf.squeeze(blurred, axis=0)
349 | return blurred
350 |
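# --- Illustrative check (addition, not part of the original file). ---
# NumPy mirror of the 1-D kernel built above. A separable blur applies the
# same normalized 1-D Gaussian once per axis, which equals convolving with
# the 2-D kernel (their outer product) at O(k) instead of O(k^2) taps/pixel.
def _gaussian_kernel_example():
  import numpy as np
  kernel_size, sigma = 9, 2.0
  radius = kernel_size // 2
  x = np.arange(-radius, radius + 1, dtype=np.float32)
  k1d = np.exp(-x ** 2 / (2.0 * sigma ** 2))
  k1d /= k1d.sum()                  # normalized, as in the code above
  k2d = np.outer(k1d, k1d)          # the implied 2-D Gaussian kernel
  assert abs(k1d.sum() - 1.0) < 1e-6 and abs(k2d.sum() - 1.0) < 1e-6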
351 |
352 | def random_crop_with_resize(image, height, width, p=1.0):
353 | """Randomly crop and resize an image.
354 |
355 | Args:
356 | image: `Tensor` representing an image of arbitrary size.
357 | height: Height of output image.
358 | width: Width of output image.
359 | p: Probability of applying this transformation.
360 |
361 | Returns:
362 | A preprocessed image `Tensor`.
363 | """
364 | def _transform(image): # pylint: disable=missing-docstring
365 | image = crop_and_resize(image, height, width)
366 | return image
367 | return random_apply(_transform, p=p, x=image)
368 |
369 |
370 | def random_color_jitter(image, p=1.0):
371 | def _transform(image):
372 | color_jitter_t = functools.partial(
373 | color_jitter, strength=FLAGS.color_jitter_strength)
374 | image = random_apply(color_jitter_t, p=0.8, x=image)
375 | return random_apply(to_grayscale, p=0.2, x=image)
376 | return random_apply(_transform, p=p, x=image)
377 |
378 |
379 | def random_blur(image, height, width, p=1.0):
380 | """Randomly blur an image.
381 |
382 | Args:
383 | image: `Tensor` representing an image of arbitrary size.
384 | height: Height of output image.
385 | width: Width of output image.
386 | p: probability of applying this transformation.
387 |
388 | Returns:
389 | A preprocessed image `Tensor`.
390 | """
391 | del width
392 | def _transform(image):
393 | sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
394 | return gaussian_blur(
395 | image, kernel_size=height//10, sigma=sigma, padding='SAME')
396 | return random_apply(_transform, p=p, x=image)
397 |
398 |
399 | def batch_random_blur(images_list, height, width, blur_probability=0.5):
400 | """Apply efficient batch data transformations.
401 |
402 | Args:
403 | images_list: a list of image tensors.
404 |     height: the height of the image.
405 |     width: the width of the image.
406 |     blur_probability: the probability of applying the blur operator.
407 |
408 | Returns:
409 | Preprocessed feature list.
410 | """
411 | def generate_selector(p, bsz):
412 | shape = [bsz, 1, 1, 1]
413 | selector = tf.cast(
414 | tf.less(tf.random_uniform(shape, 0, 1, dtype=tf.float32), p),
415 | tf.float32)
416 | return selector
417 |
418 | new_images_list = []
419 | for images in images_list:
420 | images_new = random_blur(images, height, width, p=1.)
421 | selector = generate_selector(blur_probability, tf.shape(images)[0])
422 | images = images_new * selector + images * (1 - selector)
423 | images = tf.clip_by_value(images, 0., 1.)
424 | new_images_list.append(images)
425 |
426 | return new_images_list
427 |
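# --- Illustrative sketch (addition, not part of the original file). ---
# The selector trick above in NumPy: a [bsz, 1, 1, 1] 0/1 tensor broadcasts
# against [bsz, h, w, c], so `new * s + old * (1 - s)` blurs each image in
# the batch independently with probability `blur_probability`, without any
# per-image control flow.
def _selector_mixing_example():
  import numpy as np
  bsz, h, w, c = 4, 8, 8, 3
  old = np.zeros([bsz, h, w, c], np.float32)   # stand-in: original images
  new = np.ones([bsz, h, w, c], np.float32)    # stand-in: blurred images
  selector = (np.random.rand(bsz, 1, 1, 1) < 0.5).astype(np.float32)
  mixed = new * selector + old * (1.0 - selector)
  # each image is entirely "blurred" (all ones) or untouched (all zeros)
  per_image = mixed.reshape(bsz, -1)
  assert all(row.min() == row.max() for row in per_image)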
428 |
429 | def preprocess_for_train(image, height, width,
430 | color_distort=True, crop=True, flip=True):
431 | """Preprocesses the given image for training.
432 |
433 | Args:
434 | image: `Tensor` representing an image of arbitrary size.
435 | height: Height of output image.
436 | width: Width of output image.
437 | color_distort: Whether to apply the color distortion.
438 | crop: Whether to crop the image.
439 |     flip: Whether or not to randomly flip the image left and right.
440 |
441 | Returns:
442 | A preprocessed image `Tensor`.
443 | """
444 | if crop:
445 | image = random_crop_with_resize(image, height, width)
446 | if flip:
447 | image = tf.image.random_flip_left_right(image)
448 | if color_distort:
449 | image = random_color_jitter(image)
450 | image = tf.reshape(image, [height, width, 3])
451 | image = tf.clip_by_value(image, 0., 1.)
452 | return image
453 |
454 |
455 | def preprocess_for_eval(image, height, width, crop=True):
456 | """Preprocesses the given image for evaluation.
457 |
458 | Args:
459 | image: `Tensor` representing an image of arbitrary size.
460 | height: Height of output image.
461 | width: Width of output image.
462 | crop: Whether or not to (center) crop the test images.
463 |
464 | Returns:
465 | A preprocessed image `Tensor`.
466 | """
467 | if crop:
468 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION)
469 | image = tf.reshape(image, [height, width, 3])
470 | image = tf.clip_by_value(image, 0., 1.)
471 | return image
472 |
473 |
474 | def preprocess_image(image, height, width, is_training=False,
475 | color_distort=True, test_crop=True):
476 | """Preprocesses the given image.
477 |
478 | Args:
479 | image: `Tensor` representing an image of arbitrary size.
480 | height: Height of output image.
481 | width: Width of output image.
482 | is_training: `bool` for whether the preprocessing is for training.
483 | color_distort: whether to apply the color distortion.
484 | test_crop: whether or not to extract a central crop of the images
485 | (as for standard ImageNet evaluation) during the evaluation.
486 |
487 | Returns:
488 | A preprocessed image `Tensor` of range [0, 1].
489 | """
490 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
491 | if is_training:
492 | return preprocess_for_train(image, height, width, color_distort)
493 | else:
494 | return preprocess_for_eval(image, height, width, test_crop)
495 |
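# --- Usage sketch (illustrative addition, not part of the original file). ---
# Assumes TF1 graph mode and that the repo's flags (e.g.
# --color_jitter_strength, used by random_color_jitter) have been parsed, as
# they are when running via run.py. Training views are randomly cropped,
# flipped and color-distorted; eval views only get the deterministic center
# crop.
def _preprocess_image_example():
  import numpy as np
  raw = tf.placeholder(tf.uint8, shape=[None, None, 3])
  train_view = preprocess_image(raw, 224, 224, is_training=True)
  eval_view = preprocess_image(raw, 224, 224, is_training=False)
  with tf.Session() as sess:
    img = (np.random.rand(300, 400, 3) * 255).astype(np.uint8)
    a, b = sess.run([train_view, eval_view], feed_dict={raw: img})
    assert a.shape == b.shape == (224, 224, 3)  # both are [0, 1] floats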
--------------------------------------------------------------------------------
/vision/gpu_pretrain_finetune_cifar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # hyperparameters
3 | STORAGE_BUCKET="checkpoint"
4 | BATCH_SIZE=512
5 | EPOCH=100
6 | TEMP=128
7 | LR=1.0
8 | LR_SCALE='linear'
9 | W_DECAY=1e-4
10 | DATASET='cifar10'
11 | IMAGE_SIZE=32
12 | SK_RATIO=0.0
13 | WIDTH_MUL=1
14 | RESNET_DEPTH=18
15 |
16 | # different for different losses
17 | LOSS_TYPE='chi'
18 | HIDDEN_NORM=True
19 | if [ $LOSS_TYPE == 'chi' ]
20 | then
21 |     # the chi loss is trained without hidden (l2) normalization
22 | HIDDEN_NORM=False
23 | ALPHA=0.0
24 | BETA=0.001
25 | GAMMA=1.0
26 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}_ALPHA_${ALPHA}_BETA_${BETA}_GAMMA_${GAMMA}"
27 | CUDA_VISIBLE_DEVICES=0 python run.py --train_mode=pretrain \
28 | --train_batch_size=$BATCH_SIZE \
29 | --train_epochs=$EPOCH \
30 | --temperature=$TEMP \
31 | --learning_rate=$LR \
32 | --learning_rate_scaling=$LR_SCALE \
33 | --weight_decay=$W_DECAY \
34 | --dataset=$DATASET \
35 | --image_size=$IMAGE_SIZE \
36 | --eval_split=test \
37 | --model_dir=$MODEL_DIR \
38 | --use_tpu=False \
39 | --train_summary_steps=0 \
40 | --sk_ratio $SK_RATIO \
41 | --width_multiplier $WIDTH_MUL \
42 | --resnet_depth $RESNET_DEPTH \
43 | --loss_type $LOSS_TYPE \
44 | --alpha=$ALPHA \
45 | --beta=$BETA \
46 | --gamma=$GAMMA \
47 | --hidden_norm=$HIDDEN_NORM
48 |
49 | # NCE, JS, WPC, etc
50 | else
51 | TEMP=0.5
52 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}"
53 | CUDA_VISIBLE_DEVICES=0 python run.py --train_mode=pretrain \
54 | --train_batch_size=$BATCH_SIZE \
55 | --train_epochs=$EPOCH \
56 | --temperature=$TEMP \
57 | --learning_rate=$LR \
58 | --learning_rate_scaling=$LR_SCALE \
59 | --weight_decay=$W_DECAY \
60 | --dataset=$DATASET \
61 | --image_size=$IMAGE_SIZE \
62 | --eval_split=test \
63 | --model_dir=$MODEL_DIR \
64 | --use_tpu=False \
65 | --train_summary_steps=0 \
66 | --sk_ratio $SK_RATIO \
67 | --width_multiplier $WIDTH_MUL \
68 | --resnet_depth $RESNET_DEPTH \
69 | --loss_type $LOSS_TYPE \
70 | --hidden_norm=$HIDDEN_NORM
71 |
72 | fi
73 |
74 | ##############################################################################################
75 | ##################### Fine-tune #####################
76 | ##############################################################################################
77 | CHKPT_DIR=$MODEL_DIR
78 | FINETUNE_AFTER_BLOCK=0
79 | LR=0.1
80 | WD=0
81 | EPOCHS=100
82 | WARMUP_EPOCHS=0
83 | MODEL_DIR="${CHKPT_DIR}_ft_BS_${BATCH_SIZE}_FINETUNE_AFTER_BLOCK_${FINETUNE_AFTER_BLOCK}_LR_${LR}_WD_${WD}_EPOCH_${EPOCHS}_WARMUP_EPOCHS_${WARMUP_EPOCHS}"
84 | echo $MODEL_DIR
85 | if [ $LOSS_TYPE == "chi" ]
86 | then
87 | echo "Running chi"
88 | CUDA_VISIBLE_DEVICES=0 python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=$DATASET --image_size=$IMAGE_SIZE --eval_split=test --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=False --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE --alpha $ALPHA --beta $BETA --gamma $GAMMA
89 | else
90 | CUDA_VISIBLE_DEVICES=0 python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=$DATASET --image_size=$IMAGE_SIZE --eval_split=test --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=False --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE
91 | fi
92 |
--------------------------------------------------------------------------------
/vision/lars_optimizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Functions and classes related to optimization (weight updates)."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import re
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 | EETA_DEFAULT = 0.001
27 |
28 |
29 | class LARSOptimizer(tf.train.Optimizer):
30 | """Layer-wise Adaptive Rate Scaling for large batch training.
31 |
32 | Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
33 | I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
34 | """
35 |
36 | def __init__(self,
37 | learning_rate,
38 | momentum=0.9,
39 | use_nesterov=False,
40 | weight_decay=0.0,
41 | exclude_from_weight_decay=None,
42 | exclude_from_layer_adaptation=None,
43 | classic_momentum=True,
44 | eeta=EETA_DEFAULT,
45 | name="LARSOptimizer"):
46 | """Constructs a LARSOptimizer.
47 |
48 | Args:
49 | learning_rate: A `float` for learning rate.
50 | momentum: A `float` for momentum.
51 | use_nesterov: A 'Boolean' for whether to use nesterov momentum.
52 | weight_decay: A `float` for weight decay.
53 | exclude_from_weight_decay: A list of `string` for variable screening, if
54 | any of the string appears in a variable's name, the variable will be
55 | excluded for computing weight decay. For example, one could specify
56 | the list like ['batch_normalization', 'bias'] to exclude BN and bias
57 | from weight decay.
58 |       exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
59 |         for layer adaptation. If None, it defaults to the same value as
60 |         exclude_from_weight_decay.
61 |       classic_momentum: A `boolean` for whether to use classic (or popular)
62 |         momentum. The learning rate is applied during the momentum update for
63 |         classic momentum, but after the momentum update for popular momentum.
64 | eeta: A `float` for scaling of learning rate when computing trust ratio.
65 | name: The name for the scope.
66 | """
67 | super(LARSOptimizer, self).__init__(False, name)
68 |
69 | self.learning_rate = learning_rate
70 | self.momentum = momentum
71 | self.weight_decay = weight_decay
72 | self.use_nesterov = use_nesterov
73 | self.classic_momentum = classic_momentum
74 | self.eeta = eeta
75 | self.exclude_from_weight_decay = exclude_from_weight_decay
76 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
77 | # arg is None.
78 | if exclude_from_layer_adaptation:
79 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
80 | else:
81 | self.exclude_from_layer_adaptation = exclude_from_weight_decay
82 |
83 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
84 | if global_step is None:
85 | global_step = tf.train.get_or_create_global_step()
86 | new_global_step = global_step + 1
87 |
88 | assignments = []
89 | for (grad, param) in grads_and_vars:
90 | if grad is None or param is None:
91 | continue
92 |
93 | param_name = param.op.name
94 |
95 | v = tf.get_variable(
96 | name=param_name + "/Momentum",
97 | shape=param.shape.as_list(),
98 | dtype=tf.float32,
99 | trainable=False,
100 | initializer=tf.zeros_initializer())
101 |
102 | if self._use_weight_decay(param_name):
103 | grad += self.weight_decay * param
104 |
105 | if self.classic_momentum:
106 | trust_ratio = 1.0
107 | if self._do_layer_adaptation(param_name):
108 | w_norm = tf.norm(param, ord=2)
109 | g_norm = tf.norm(grad, ord=2)
110 | trust_ratio = tf.where(
111 | tf.greater(w_norm, 0), tf.where(
112 | tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm),
113 | 1.0),
114 | 1.0)
115 | scaled_lr = self.learning_rate * trust_ratio
116 |
117 | next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
118 | if self.use_nesterov:
119 | update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
120 | else:
121 | update = next_v
122 | next_param = param - update
123 | else:
124 | next_v = tf.multiply(self.momentum, v) + grad
125 | if self.use_nesterov:
126 | update = tf.multiply(self.momentum, next_v) + grad
127 | else:
128 | update = next_v
129 |
130 | trust_ratio = 1.0
131 | if self._do_layer_adaptation(param_name):
132 | w_norm = tf.norm(param, ord=2)
133 | v_norm = tf.norm(update, ord=2)
134 | trust_ratio = tf.where(
135 | tf.greater(w_norm, 0), tf.where(
136 | tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm),
137 | 1.0),
138 | 1.0)
139 | scaled_lr = trust_ratio * self.learning_rate
140 | next_param = param - scaled_lr * update
141 |
142 | assignments.extend(
143 | [param.assign(next_param),
144 | v.assign(next_v),
145 | global_step.assign(new_global_step)])
146 | return tf.group(*assignments, name=name)
147 |
148 | def _use_weight_decay(self, param_name):
149 | """Whether to use L2 weight decay for `param_name`."""
150 | if not self.weight_decay:
151 | return False
152 | if self.exclude_from_weight_decay:
153 | for r in self.exclude_from_weight_decay:
154 | if re.search(r, param_name) is not None:
155 | return False
156 | return True
157 |
158 | def _do_layer_adaptation(self, param_name):
159 | """Whether to do layer-wise learning rate adaptation for `param_name`."""
160 | if self.exclude_from_layer_adaptation:
161 | for r in self.exclude_from_layer_adaptation:
162 | if re.search(r, param_name) is not None:
163 | return False
164 | return True
165 |
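# --- Illustrative sketch (addition, not part of the original file). ---
# The classic-momentum LARS step above, mirrored in NumPy for one weight
# tensor: the trust ratio eeta * ||w|| / ||g|| rescales the learning rate per
# layer so the update magnitude tracks the weight magnitude.
def _lars_step_example():
  import numpy as np
  lr, momentum, eeta, weight_decay = 0.1, 0.9, EETA_DEFAULT, 1e-4
  w = np.random.randn(256, 128)
  g = np.random.randn(256, 128)
  v = np.zeros_like(w)
  g = g + weight_decay * w                         # weight decay on the grad
  trust_ratio = eeta * np.linalg.norm(w) / np.linalg.norm(g)
  scaled_lr = lr * trust_ratio                     # layer-adapted step size
  v = momentum * v + scaled_lr * g                 # classic momentum buffer
  w = w - v
  return w, v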
--------------------------------------------------------------------------------
/vision/model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Model specification for SimCLR."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 |
24 | import data_util as data_util
25 | import model_util as model_util
26 | import objective as obj_lib
27 |
28 | import tensorflow.compat.v1 as tf
29 | import tensorflow.compat.v2 as tf2
30 |
31 | FLAGS = flags.FLAGS
32 |
33 |
34 | def build_model_fn(model, num_classes, num_train_examples):
35 | """Build model function."""
36 | def model_fn(features, labels, mode, params=None):
37 | """Build model and optimizer."""
38 | is_training = mode == tf.estimator.ModeKeys.TRAIN
39 |
40 | # Check training mode.
41 | if FLAGS.train_mode == 'pretrain':
42 | num_transforms = 2
43 | if FLAGS.fine_tune_after_block > -1:
44 |         raise ValueError('Does not support layer freezing during pretraining; '
45 |                          'set fine_tune_after_block<=-1 for safety.')
46 | elif FLAGS.train_mode == 'finetune':
47 | num_transforms = 1
48 | else:
49 | raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))
50 |
51 | # Split channels, and optionally apply extra batched augmentation.
52 | features_list = tf.split(
53 | features, num_or_size_splits=num_transforms, axis=-1)
54 | if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
55 | features_list = data_util.batch_random_blur(
56 | features_list, FLAGS.image_size, FLAGS.image_size)
57 | features = tf.concat(features_list, 0) # (num_transforms * bsz, h, w, c)
58 |
59 | # Base network forward pass.
60 | with tf.variable_scope('base_model'):
61 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
62 |         # Fine-tuning only the supervised (linear) head does not update BN stats.
63 | model_train_mode = False
64 | else:
65 |         # Pre-training, or fine-tuning anything else, updates BN stats.
66 | model_train_mode = is_training
67 | hiddens = model(features, is_training=model_train_mode)
68 |
69 | # Add head and loss.
70 | if FLAGS.train_mode == 'pretrain':
71 | tpu_context = params['context'] if 'context' in params else None
72 | hiddens_proj = model_util.projection_head(hiddens, is_training)
73 | contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
74 | hiddens_proj,
75 | hidden_norm=FLAGS.hidden_norm,
76 | temperature=FLAGS.temperature,
77 | tpu_context=tpu_context if is_training else None,
78 | loss_type=FLAGS.loss_type,
79 | flags=FLAGS)
80 | logits_sup = tf.zeros([params['batch_size'], num_classes])
81 |       tf.losses.add_loss(FLAGS.gradient_penalty_weight * obj_lib.add_gradients_penalty(features, model, model_train_mode))  # WPC gradient penalty, added to the total loss
82 | else:
83 | contrast_loss = tf.zeros([])
84 | logits_con = tf.zeros([params['batch_size'], 10])
85 | labels_con = tf.zeros([params['batch_size'], 10])
86 | hiddens = model_util.projection_head(hiddens, is_training)
87 | logits_sup = model_util.supervised_head(
88 | hiddens, num_classes, is_training)
89 | obj_lib.add_supervised_loss(
90 | labels=labels['labels'],
91 | logits=logits_sup,
92 | weights=labels['mask'])
93 |
94 | # Add weight decay to loss, for non-LARS optimizers.
95 | model_util.add_weight_decay(adjust_per_optimizer=True)
96 | loss = tf.losses.get_total_loss()
97 |
98 | if FLAGS.train_mode == 'pretrain':
99 | variables_to_train = tf.trainable_variables()
100 | else:
101 | collection_prefix = 'trainable_variables_inblock_'
102 | variables_to_train = []
103 | for j in range(FLAGS.fine_tune_after_block + 1, 6):
104 | variables_to_train += tf.get_collection(collection_prefix + str(j))
105 | assert variables_to_train, 'variables_to_train shouldn\'t be empty!'
106 |
107 | tf.logging.info('===============Variables to train (begin)===============')
108 | tf.logging.info(variables_to_train)
109 | tf.logging.info('================Variables to train (end)================')
110 |
111 | learning_rate = model_util.learning_rate_schedule(
112 | FLAGS.learning_rate, num_train_examples)
113 |
114 | if is_training:
115 | if FLAGS.train_summary_steps > 0:
116 | # Compute stats for the summary.
117 | prob_con = tf.nn.softmax(logits_con)
118 | entropy_con = - tf.reduce_mean(
119 | tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))
120 |
121 | summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir)
122 | with tf.control_dependencies([summary_writer.init()]):
123 | with summary_writer.as_default():
124 | should_record = tf.math.equal(
125 | tf.math.floormod(tf.train.get_global_step(),
126 | FLAGS.train_summary_steps), 0)
127 | with tf2.summary.record_if(should_record):
128 | contrast_acc = tf.equal(
129 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
130 | contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
131 | label_acc = tf.equal(
132 | tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
133 | label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
134 | tf2.summary.scalar(
135 | 'train_contrast_loss',
136 | contrast_loss,
137 | step=tf.train.get_global_step())
138 | tf2.summary.scalar(
139 | 'train_contrast_acc',
140 | contrast_acc,
141 | step=tf.train.get_global_step())
142 | tf2.summary.scalar(
143 | 'train_label_accuracy',
144 | label_acc,
145 | step=tf.train.get_global_step())
146 | tf2.summary.scalar(
147 | 'contrast_entropy',
148 | entropy_con,
149 | step=tf.train.get_global_step())
150 | tf2.summary.scalar(
151 | 'learning_rate', learning_rate,
152 | step=tf.train.get_global_step())
153 |
154 | optimizer = model_util.get_optimizer(learning_rate)
155 | control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
156 | if FLAGS.train_summary_steps > 0:
157 | control_deps.extend(tf.summary.all_v2_summary_ops())
158 | with tf.control_dependencies(control_deps):
159 | train_op = optimizer.minimize(
160 | loss, global_step=tf.train.get_or_create_global_step(),
161 | var_list=variables_to_train)
162 |
163 | if FLAGS.checkpoint:
164 | def scaffold_fn():
165 | """Scaffold function to restore non-logits vars from checkpoint."""
166 | tf.train.init_from_checkpoint(
167 | FLAGS.checkpoint,
168 | {v.op.name: v.op.name
169 | for v in tf.global_variables(FLAGS.variable_schema)})
170 |
171 | if FLAGS.zero_init_logits_layer:
172 | # Init op that initializes output layer parameters to zeros.
173 | output_layer_parameters = [
174 | var for var in tf.trainable_variables() if var.name.startswith(
175 | 'head_supervised')]
176 | tf.logging.info('Initializing output layer parameters %s to zero',
177 | [x.op.name for x in output_layer_parameters])
178 | with tf.control_dependencies([tf.global_variables_initializer()]):
179 | init_op = tf.group([
180 | tf.assign(x, tf.zeros_like(x))
181 | for x in output_layer_parameters])
182 | return tf.train.Scaffold(init_op=init_op)
183 | else:
184 | return tf.train.Scaffold()
185 | else:
186 | scaffold_fn = None
187 |
188 | return tf.estimator.tpu.TPUEstimatorSpec(
189 | mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)
190 | else:
191 |
192 | def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
193 | **kws):
194 | """Inner metric function."""
195 | metrics = {k: tf.metrics.mean(v, weights=mask)
196 | for k, v in kws.items()}
197 | metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
198 | tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
199 | weights=mask)
200 | metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
201 | tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
202 | metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
203 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1),
204 | weights=mask)
205 | metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
206 | tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)
207 | return metrics
208 |
209 | metrics = {
210 | 'logits_sup': logits_sup,
211 | 'labels_sup': labels['labels'],
212 | 'logits_con': logits_con,
213 | 'labels_con': labels_con,
214 | 'mask': labels['mask'],
215 | 'contrast_loss': tf.fill((params['batch_size'],), contrast_loss),
216 | 'regularization_loss': tf.fill((params['batch_size'],),
217 | tf.losses.get_regularization_loss()),
218 | }
219 |
220 | return tf.estimator.tpu.TPUEstimatorSpec(
221 | mode=mode,
222 | loss=loss,
223 | eval_metrics=(metric_fn, metrics),
224 | scaffold_fn=None)
225 |
226 | return model_fn
227 |
--------------------------------------------------------------------------------
/vision/model_util.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Network architectures related functions used in SimCLR."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 | from absl import flags
24 |
25 | import resnet
26 | from lars_optimizer import LARSOptimizer
27 |
28 | import tensorflow.compat.v1 as tf
29 |
30 | FLAGS = flags.FLAGS
31 |
32 |
33 | def add_weight_decay(adjust_per_optimizer=True):
34 | """Compute weight decay from flags."""
35 | if adjust_per_optimizer and 'lars' in FLAGS.optimizer:
36 |     # Weight decay is handled by the optimizer in these cases,
37 |     # except for the supervised head, which is added here.
38 | l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
39 | if 'head_supervised' in v.name and 'bias' not in v.name]
40 | if l2_losses:
41 | tf.losses.add_loss(
42 | FLAGS.weight_decay * tf.add_n(l2_losses),
43 | tf.GraphKeys.REGULARIZATION_LOSSES)
44 | return
45 |
46 | l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
47 | if 'batch_normalization' not in v.name]
48 | tf.losses.add_loss(
49 | FLAGS.weight_decay * tf.add_n(l2_losses),
50 | tf.GraphKeys.REGULARIZATION_LOSSES)
51 |
52 |
53 | def get_train_steps(num_examples):
54 | """Determine the number of training steps."""
55 | return FLAGS.train_steps or (
56 | num_examples * FLAGS.train_epochs // FLAGS.train_batch_size + 1)
57 |
58 |
59 | def learning_rate_schedule(base_learning_rate, num_examples):
60 | """Build learning rate schedule."""
61 | global_step = tf.train.get_or_create_global_step()
62 | warmup_steps = int(round(
63 | FLAGS.warmup_epochs * num_examples // FLAGS.train_batch_size))
64 | if FLAGS.learning_rate_scaling == 'linear':
65 | scaled_lr = base_learning_rate * FLAGS.train_batch_size / 256.
66 | elif FLAGS.learning_rate_scaling == 'sqrt':
67 | scaled_lr = base_learning_rate * math.sqrt(FLAGS.train_batch_size)
68 | else:
69 | raise ValueError('Unknown learning rate scaling {}'.format(
70 | FLAGS.learning_rate_scaling))
71 | learning_rate = (tf.to_float(global_step) / int(warmup_steps) * scaled_lr
72 | if warmup_steps else scaled_lr)
73 |
74 | # Cosine decay learning rate schedule
75 | total_steps = get_train_steps(num_examples)
76 | learning_rate = tf.where(
77 | global_step < warmup_steps, learning_rate,
78 | tf.train.cosine_decay(
79 | scaled_lr,
80 | global_step - warmup_steps,
81 | total_steps - warmup_steps))
82 |
83 | return learning_rate
84 |
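# --- Illustrative sketch (addition, not part of the original file). ---
# Pure-Python mirror of the schedule above: linear warmup from 0 to the
# scaled LR over warmup_steps, then a half-period cosine decay to 0 at
# total_steps. The example numbers (base LR 1.0, batch 1024 under 'linear'
# scaling -> scaled LR 4.0) are assumptions for illustration.
def _lr_schedule_example():
  scaled_lr, warmup_steps, total_steps = 4.0, 100, 1000

  def lr(step):
    if step < warmup_steps:
      return step / warmup_steps * scaled_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return scaled_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

  assert lr(0) == 0.0 and abs(lr(100) - 4.0) < 1e-9 and abs(lr(1000)) < 1e-9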
85 |
86 | def get_optimizer(learning_rate):
87 | """Returns an optimizer."""
88 | if FLAGS.optimizer == 'momentum':
89 | optimizer = tf.train.MomentumOptimizer(
90 | learning_rate, FLAGS.momentum, use_nesterov=True)
91 | elif FLAGS.optimizer == 'adam':
92 | optimizer = tf.train.AdamOptimizer(
93 | learning_rate)
94 | elif FLAGS.optimizer == 'lars':
95 | optimizer = LARSOptimizer(
96 | learning_rate,
97 | momentum=FLAGS.momentum,
98 | weight_decay=FLAGS.weight_decay,
99 | exclude_from_weight_decay=['batch_normalization', 'bias',
100 | 'head_supervised'])
101 | else:
102 | raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer))
103 |
104 | if FLAGS.use_tpu:
105 | optimizer = tf.tpu.CrossShardOptimizer(optimizer)
106 | return optimizer
107 |
108 |
109 | def linear_layer(x,
110 | is_training,
111 | num_classes,
112 | use_bias=True,
113 | use_bn=False,
114 | name='linear_layer'):
115 | """Linear head for linear evaluation.
116 |
117 | Args:
118 | x: hidden state tensor of shape (bsz, dim).
119 | is_training: boolean indicator for training or test.
120 | num_classes: number of classes.
121 | use_bias: whether or not to use bias.
122 | use_bn: whether or not to use BN for output units.
123 | name: the name for variable scope.
124 |
125 | Returns:
126 | logits of shape (bsz, num_classes)
127 | """
128 | assert x.shape.ndims == 2, x.shape
129 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
130 | x = tf.layers.dense(
131 | inputs=x,
132 | units=num_classes,
133 | use_bias=use_bias and not use_bn,
134 | kernel_initializer=tf.random_normal_initializer(stddev=.01))
135 | if use_bn:
136 | x = resnet.batch_norm_relu(x, is_training, relu=False, center=use_bias)
137 | x = tf.identity(x, '%s_out' % name)
138 | return x
139 |
140 |
141 | def projection_head(hiddens, is_training, name='head_contrastive'):
142 |   """Head for projecting hiddens for the contrastive loss."""
143 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
144 | mid_dim = hiddens.shape[-1]
145 | out_dim = FLAGS.proj_out_dim
146 | hiddens_list = [hiddens]
147 | if FLAGS.proj_head_mode == 'none':
148 | pass # directly use the output hiddens as hiddens.
149 | elif FLAGS.proj_head_mode == 'linear':
150 | hiddens = linear_layer(
151 | hiddens, is_training, out_dim,
152 | use_bias=False, use_bn=True, name='l_0')
153 | hiddens_list.append(hiddens)
154 | elif FLAGS.proj_head_mode == 'nonlinear':
155 | for j in range(FLAGS.num_proj_layers):
156 | if j != FLAGS.num_proj_layers - 1:
157 | # for the middle layers, use bias and relu for the output.
158 | dim, bias_relu = mid_dim, True
159 | else:
160 | # for the final layer, neither bias nor relu is used.
161 | dim, bias_relu = FLAGS.proj_out_dim, False
162 | hiddens = linear_layer(
163 | hiddens, is_training, dim,
164 | use_bias=bias_relu, use_bn=True, name='nl_%d'%j)
165 | hiddens = tf.nn.relu(hiddens) if bias_relu else hiddens
166 | hiddens_list.append(hiddens)
167 | else:
168 | raise ValueError('Unknown head projection mode {}'.format(
169 | FLAGS.proj_head_mode))
170 | if FLAGS.train_mode == 'pretrain':
171 | # take the projection head output during pre-training.
172 | hiddens = hiddens_list[-1]
173 | else:
174 | # for checkpoint compatibility, whole projection head is built here.
175 | # but you can select part of projection head during fine-tuning.
176 | hiddens = hiddens_list[FLAGS.ft_proj_selector]
177 | return hiddens
178 |
179 |
180 | def supervised_head(hiddens, num_classes, is_training, name='head_supervised'):
181 | """Add supervised head & also add its variables to inblock collection."""
182 | with tf.variable_scope(name):
183 | logits = linear_layer(hiddens, is_training, num_classes)
184 | for var in tf.trainable_variables():
185 | if var.name.startswith(name):
186 | tf.add_to_collection('trainable_variables_inblock_5', var)
187 | return logits
188 |
--------------------------------------------------------------------------------
/vision/objective.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Contrastive loss functions."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from absl import flags
23 |
24 | import tensorflow.compat.v1 as tf
25 |
26 | from tensorflow.compiler.tf2xla.python import xla # pylint: disable=g-direct-tensorflow-import
27 |
28 | from tensorflow.python.framework import ops
29 | from tensorflow.python.ops import math_ops
30 | from tensorflow.python.ops import array_ops
31 |
32 | FLAGS = flags.FLAGS
33 |
34 | LARGE_NUM = 1e9
35 |
36 | def add_gradients_penalty(x, model, model_train_mode):
37 |   """WGAN-GP-style gradient penalty on the encoder (see https://colab.research.google.com/github/timsainb/tensorflow2-generative-models/blob/master/3.0-WGAN-GP-fashion-mnist.ipynb#scrollTo=Wyipg-4oSYb1)."""
38 | with tf.GradientTape() as t:
39 | t.watch(x)
40 | hidden = model(x, is_training=model_train_mode)
41 | gradients = t.gradient(hidden, x)
42 | dx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2, 3]))
43 | d_regularizer = tf.reduce_mean((dx - 1.0) ** 2)
44 | return d_regularizer
45 |
46 | def add_supervised_loss(labels, logits, weights, **kwargs):
47 | """Compute loss for model and add it to loss collection."""
48 | return tf.losses.softmax_cross_entropy(labels, logits, weights, **kwargs)
49 |
50 |
51 | def add_contrastive_loss(hidden,
52 | hidden_norm=True,
53 | temperature=1.0,
54 | tpu_context=None,
55 | weights=1.0,
56 | loss_type=None,
57 | flags=None):
58 | """Compute loss for model.
59 |
60 | Args:
61 | hidden: hidden vector (`Tensor`) of shape (bsz, dim).
62 | hidden_norm: whether or not to use normalization on the hidden vector.
63 | temperature: a `floating` number for temperature scaling.
64 | tpu_context: context information for tpu.
65 | weights: a weighting number or vector.
66 |
67 | Returns:
68 | A loss scalar.
69 | The logits for contrastive prediction task.
70 | The labels for contrastive prediction task.
71 | """
72 | # Get (normalized) hidden1 and hidden2.
73 | if hidden_norm:
74 | hidden = tf.math.l2_normalize(hidden, -1)
75 | hidden1, hidden2 = tf.split(hidden, 2, 0)
76 | batch_size = tf.shape(hidden1)[0]
77 |
78 | # Gather hidden1/hidden2 across replicas and create local labels.
79 | if tpu_context is not None:
80 | hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
81 | hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
82 | enlarged_batch_size = tf.shape(hidden1_large)[0]
83 | replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
84 | labels_idx = tf.range(batch_size) + replica_id * batch_size
85 | labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
86 | masks = tf.one_hot(labels_idx, enlarged_batch_size)
87 | else:
88 | hidden1_large = hidden1
89 | hidden2_large = hidden2
90 | labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
91 | masks = tf.one_hot(tf.range(batch_size), batch_size)
92 |   # The gradient penalty (see model.py) is only used with the WPC loss.
93 | if loss_type.lower() == "wpc":
94 | assert flags.gradient_penalty_weight != 0.0
95 | else:
96 | assert flags.gradient_penalty_weight == 0.0
97 | logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
98 | logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
99 |   # The aa and bb diagonals are not positive samples; the positives are the
100 |   # ab and ba diagonals, so the aa and bb diagonals are masked out.
101 | if loss_type.lower() == "nce" or loss_type.lower() == "dv" or loss_type.lower() == "wpc":
102 |     # NCE-style masking: subtract a large number so masked entries are ~0 after the softmax.
103 |     tf.logging.info('loss_type: %s', loss_type)
104 | logits_aa = logits_aa - masks * LARGE_NUM
105 | logits_bb = logits_bb - masks * LARGE_NUM
106 | else: # otherwise just mask out using 0
107 | logits_aa = logits_aa * (1 - masks)
108 | logits_bb = logits_bb * (1 - masks)
109 | logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
110 | logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature
111 | #############################################################################
112 |   ### Different losses: nce, chi, js, nwj, dv and wpc
113 | ### Pos_scores: positive samples, i.e. joint distribution terms
114 | ### neg_scores: negative samples, i.e. marginal distribution terms
115 | #############################################################################
116 | if loss_type.lower() == "nce":
117 | loss_a = tf.losses.softmax_cross_entropy(
118 | labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
119 | loss_b = tf.losses.softmax_cross_entropy(
120 | labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
121 | elif loss_type.lower() == "chi":
122 | ## Chi squared loss in general form
123 | alpha = flags.alpha
124 | beta = flags.beta
125 | gamma = flags.gamma
126 |
127 | joint_a = labels * tf.concat([logits_ab, logits_aa], 1)
128 | joint_b = labels * tf.concat([logits_ba, logits_bb], 1)
129 | # non-correlated views
130 | marg_a = (1.-labels) * tf.concat([logits_ab, logits_aa], 1)
131 | marg_b = (1.-labels) * tf.concat([logits_ba, logits_bb], 1)
132 | batch_size = tf.cast(batch_size, tf.float32)
133 | joint = 0.5*(tf.reduce_sum(joint_a - 0.5 * beta * joint_a**2) / batch_size)\
134 | + 0.5*(tf.reduce_sum(joint_b - 0.5 * beta * joint_b**2) / batch_size)
135 | # non-correlated views
136 | marg = 0.5*(tf.reduce_sum(alpha * marg_a + 0.5 * gamma * marg_a**2) / (2*batch_size*(batch_size-1.)))\
137 | + 0.5*(tf.reduce_sum(alpha * marg_b + 0.5 * gamma * marg_b**2) / (2*batch_size*(batch_size-1.)))
138 | loss = -1. * (joint - marg)
139 | tf.losses.add_loss(loss)
140 | return loss, logits_ab, labels
141 |
142 | elif loss_type.lower() == "js":
143 | # Jensen Shannon
144 | def js(logits_concat, labels, scope=None):
145 | lbls = math_ops.cast(labels, logits_concat.dtype)
146 |       # TODO: consider whether a stop_gradient is needed here.
147 | bs = math_ops.cast(batch_size, logits_concat.dtype)
148 | pos_scores = tf.reduce_sum(lbls * (-tf.math.softplus(-logits_concat))) / bs
149 | neg_scores = tf.reduce_sum((1 - lbls) * tf.math.softplus(logits_concat)) / ((2 * bs - 1) * bs)
150 | return - (pos_scores - neg_scores)
151 |
152 | loss_a = 0.5 * js(tf.concat([logits_ab, logits_aa], 1), labels)
153 | loss_b = 0.5 * js(tf.concat([logits_ba, logits_bb], 1), labels)
154 | tf.losses.add_loss(loss_a)
155 | tf.losses.add_loss(loss_b)
156 |
157 | elif loss_type.lower() == "nwj":
158 | def nwj(logits_concat, labels, scope=None):
159 | lbls = math_ops.cast(labels, logits_concat.dtype)
160 |       # TODO: consider whether a stop_gradient is needed here.
161 | bs = math_ops.cast(batch_size, logits_concat.dtype)
162 | pos_scores = tf.reduce_sum(lbls * logits_concat) / bs
163 | neg_scores = tf.reduce_sum((1 - lbls) * tf.math.exp(logits_concat - 1)) / ((2 * bs - 1) * bs)
164 | return - (pos_scores - neg_scores)
165 |
166 | loss_a = 0.5 * nwj(tf.concat([logits_ab, logits_aa], 1), labels)
167 | loss_b = 0.5 * nwj(tf.concat([logits_ba, logits_bb], 1), labels)
168 | tf.losses.add_loss(loss_a)
169 | tf.losses.add_loss(loss_b)
170 | elif loss_type.lower() == "dv":
171 | # Donsker and Varadhan
172 | def dv(logits_concat, labels, scope=None):
173 | lbls = math_ops.cast(labels, logits_concat.dtype)
174 |       # TODO: consider whether a stop_gradient is needed here.
175 | bs = math_ops.cast(batch_size, logits_concat.dtype)
176 | pos_scores = tf.reduce_sum(lbls * logits_concat) / bs
177 |       neg_scores = tf.math.reduce_logsumexp(logits_concat - lbls * LARGE_NUM) - tf.math.log((2 * bs - 1) * bs)  # mask positives out of the logsumexp
178 | return - (pos_scores - neg_scores)
179 |
180 | loss_a = 0.5 * dv(tf.concat([logits_ab, logits_aa], 1), labels)
181 | loss_b = 0.5 * dv(tf.concat([logits_ba, logits_bb], 1), labels)
182 | tf.losses.add_loss(loss_a)
183 | tf.losses.add_loss(loss_b)
184 |   elif loss_type.lower() == "wpc":
185 |     # Wasserstein Dependency Measure (i.e. Wasserstein Predictive Coding): the Lipschitz gradient penalty is added in model.py; the critic term here is assumed NCE-style (otherwise loss_a/loss_b below would be undefined).
186 |     loss_a = tf.losses.softmax_cross_entropy(labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
187 |     loss_b = tf.losses.softmax_cross_entropy(labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
188 |
189 | else:
190 |     raise ValueError('Unknown loss_type {}; supported: nce, chi, js, nwj, dv, wpc'.format(loss_type))
191 |
192 | loss = loss_a + loss_b
193 | return loss, logits_ab, labels
194 |
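# --- Illustrative sketch (addition, not part of the original file). ---
# The NCE branch above, mirrored in NumPy for the a -> b direction only.
# Positives are the (i, i) entries of logits_ab; the (i, i) entries of
# logits_aa are self-similarities, masked with a large negative number so
# they vanish from the softmax. loss_b is the symmetric b -> a term.
def _nce_loss_example():
  import numpy as np
  bsz, dim, temperature = 4, 16, 0.5
  h1 = np.random.randn(bsz, dim)
  h2 = np.random.randn(bsz, dim)
  h1 /= np.linalg.norm(h1, axis=1, keepdims=True)   # hidden_norm=True
  h2 /= np.linalg.norm(h2, axis=1, keepdims=True)
  logits_ab = h1 @ h2.T / temperature
  logits_aa = h1 @ h1.T / temperature - LARGE_NUM * np.eye(bsz)
  logits = np.concatenate([logits_ab, logits_aa], axis=1)   # (bsz, 2*bsz)
  log_softmax = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
  loss_a = -log_softmax[np.arange(bsz), np.arange(bsz)].mean()
  return loss_a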
195 |
196 | def tpu_cross_replica_concat(tensor, tpu_context=None):
197 | """Reduce a concatenation of the `tensor` across TPU cores.
198 |
199 | Args:
200 | tensor: tensor to concatenate.
201 | tpu_context: A `TPUContext`. If not set, CPU execution is assumed.
202 |
203 | Returns:
204 | Tensor of the same rank as `tensor` with first dimension `num_replicas`
205 | times larger.
206 | """
207 | if tpu_context is None or tpu_context.num_replicas <= 1:
208 | return tensor
209 |
210 | num_replicas = tpu_context.num_replicas
211 |
212 | with tf.name_scope('tpu_cross_replica_concat'):
213 | # This creates a tensor that is like the input tensor but has an added
214 | # replica dimension as the outermost dimension. On each replica it will
215 | # contain the local values and zeros for all other values that need to be
216 | # fetched from other replicas.
217 | ext_tensor = tf.scatter_nd(
218 | indices=[[xla.replica_id()]],
219 | updates=[tensor],
220 | shape=[num_replicas] + tensor.shape.as_list())
221 |
222 | # As every value is only present on one replica and 0 in all others, adding
223 | # them all together will result in the full tensor on all replicas.
224 | ext_tensor = tf.tpu.cross_replica_sum(ext_tensor)
225 |
226 | # Flatten the replica dimension.
227 | # The first dimension size will be: tensor.shape[0] * num_replicas
228 | # Using [-1] trick to support also scalar input.
229 | return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
230 |
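# --- Illustrative sketch (addition, not part of the original file). ---
# The scatter + all-reduce trick above in NumPy: each replica scatters its
# local tensor into its own slot of a zero tensor, the cross-replica sum
# fills in every slot on every replica, and flattening the replica dimension
# yields the concatenation.
def _cross_replica_concat_example():
  import numpy as np
  num_replicas, bsz, dim = 4, 2, 3
  local = [np.full((bsz, dim), r, np.float32) for r in range(num_replicas)]
  ext = np.zeros((num_replicas, num_replicas, bsz, dim), np.float32)
  for r in range(num_replicas):
    ext[r, r] = local[r]            # what replica r scatters locally
  summed = ext.sum(axis=0)          # stands in for tf.tpu.cross_replica_sum
  concat = summed.reshape(-1, dim)  # flatten the replica dimension
  assert concat.shape == (num_replicas * bsz, dim)
  assert (concat[:bsz] == 0).all() and (concat[-bsz:] == 3).all()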
--------------------------------------------------------------------------------
/vision/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | tensorflow-gpu==1.15.2
3 | tensorflow-datasets==3.1.0
4 | tensorflow-hub==0.8.0
5 |
--------------------------------------------------------------------------------
/vision/resnet.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Contains definitions for the post-activation form of Residual Networks.
17 |
18 | Residual networks (ResNets) were proposed in:
19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from absl import flags
28 | import tensorflow.compat.v1 as tf
29 |
30 | from tensorflow.python.tpu import tpu_function # pylint: disable=g-direct-tensorflow-import
31 |
32 |
33 | FLAGS = flags.FLAGS
34 | BATCH_NORM_EPSILON = 1e-5
35 |
36 |
37 | class BatchNormalization(tf.layers.BatchNormalization):
38 | """Batch Normalization layer that supports cross replica computation on TPU.
39 |
40 | This class extends the keras.BatchNormalization implementation by supporting
41 | cross replica means and variances. The base class implementation only computes
42 | moments based on mini-batch per replica (TPU core).
43 |
44 | For detailed information of arguments and implementation, refer to:
45 | https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization
46 | """
47 |
48 | def __init__(self, fused=False, **kwargs):
49 | """Builds the batch normalization layer.
50 |
51 | Arguments:
52 |       fused: If `False`, use the system recommended implementation. Only
53 |         `False` is supported in the current implementation.
54 | **kwargs: input augments that are forwarded to
55 | tf.layers.BatchNormalization.
56 | """
57 | if fused in (True, None):
58 | raise ValueError('The TPU version of BatchNormalization does not support '
59 | 'fused=True.')
60 | super(BatchNormalization, self).__init__(fused=fused, **kwargs)
61 |
62 | def _cross_replica_average(self, t):
63 | """Calculates the average value of input tensor across TPU replicas."""
64 | num_shards = tpu_function.get_tpu_context().number_of_shards
65 | return tf.tpu.cross_replica_sum(t) / tf.cast(num_shards, t.dtype)
66 |
67 | def _moments(self, inputs, reduction_axes, keep_dims):
68 | """Compute the mean and variance: it overrides the original _moments."""
69 | shard_mean, shard_variance = super(BatchNormalization, self)._moments(
70 | inputs, reduction_axes, keep_dims=keep_dims)
71 |
72 | num_shards = tpu_function.get_tpu_context().number_of_shards
73 | if num_shards and num_shards > 1:
74 | # Each group has multiple replicas: here we compute group mean/variance by
75 | # aggregating per-replica mean/variance.
76 | group_mean = self._cross_replica_average(shard_mean)
77 | group_variance = self._cross_replica_average(shard_variance)
78 |
79 | # Group variance needs to also include the difference between shard_mean
80 | # and group_mean.
81 | mean_distance = tf.square(group_mean - shard_mean)
82 | group_variance += self._cross_replica_average(mean_distance)
83 | return (group_mean, group_variance)
84 | else:
85 | return (shard_mean, shard_variance)
86 |
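# --- Illustrative check (addition, not part of the original file). ---
# The aggregation in _moments follows the law of total variance: with
# equal-sized shards, the full-batch variance equals the average per-shard
# variance plus the average squared distance of each shard mean from the
# group mean.
def _cross_replica_moments_example():
  import numpy as np
  shards = [np.random.randn(32) + s for s in range(4)]   # 4 fake replicas
  shard_means = np.array([s.mean() for s in shards])
  shard_vars = np.array([s.var() for s in shards])
  group_mean = shard_means.mean()
  group_var = shard_vars.mean() + ((group_mean - shard_means) ** 2).mean()
  full = np.concatenate(shards)
  assert np.allclose(group_mean, full.mean())
  assert np.allclose(group_var, full.var())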
87 |
88 | def batch_norm_relu(inputs, is_training, relu=True, init_zero=False,
89 | center=True, scale=True, data_format='channels_last'):
90 | """Performs a batch normalization followed by a ReLU.
91 |
92 | Args:
93 | inputs: `Tensor` of shape `[batch, channels, ...]`.
94 | is_training: `bool` for whether the model is training.
95 | relu: `bool` if False, omits the ReLU operation.
96 | init_zero: `bool` if True, initializes scale parameter of batch
97 | normalization with 0 instead of 1 (default).
98 | center: `bool` whether to add learnable bias factor.
99 | scale: `bool` whether to add learnable scaling factor.
100 | data_format: `str` either "channels_first" for `[batch, channels, height,
101 |       width]` or "channels_last" for `[batch, height, width, channels]`.
102 |
103 | Returns:
104 | A normalized `Tensor` with the same `data_format`.
105 | """
106 | if init_zero:
107 | gamma_initializer = tf.zeros_initializer()
108 | else:
109 | gamma_initializer = tf.ones_initializer()
110 |
111 | if data_format == 'channels_first':
112 | axis = 1
113 | else:
114 | axis = -1
115 |
116 | if FLAGS.global_bn:
117 | bn_foo = BatchNormalization(
118 | axis=axis,
119 | momentum=FLAGS.batch_norm_decay,
120 | epsilon=BATCH_NORM_EPSILON,
121 | center=center,
122 | scale=scale,
123 | fused=False,
124 | gamma_initializer=gamma_initializer)
125 | inputs = bn_foo(inputs, training=is_training)
126 | else:
127 | inputs = tf.layers.batch_normalization(
128 | inputs=inputs,
129 | axis=axis,
130 | momentum=FLAGS.batch_norm_decay,
131 | epsilon=BATCH_NORM_EPSILON,
132 | center=center,
133 | scale=scale,
134 | training=is_training,
135 | fused=True,
136 | gamma_initializer=gamma_initializer)
137 |
138 | if relu:
139 | inputs = tf.nn.relu(inputs)
140 | return inputs
141 |
142 |
143 | def dropblock(net, is_training, keep_prob, dropblock_size,
144 | data_format='channels_last'):
145 | """DropBlock: a regularization method for convolutional neural networks.
146 |
147 | DropBlock is a form of structured dropout, where units in a contiguous
148 | region of a feature map are dropped together. DropBlock works better than
149 | dropout on convolutional layers due to the fact that activation units in
150 | convolutional layers are spatially correlated.
151 | See https://arxiv.org/pdf/1810.12890.pdf for details.
152 |
153 | Args:
154 | net: `Tensor` input tensor.
155 | is_training: `bool` for whether the model is training.
156 | keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock. "None"
157 | means no DropBlock.
158 | dropblock_size: `int` size of blocks to be dropped by DropBlock.
159 | data_format: `str` either "channels_first" for `[batch, channels, height,
160 |       width]` or "channels_last" for `[batch, height, width, channels]`.
161 | Returns:
162 | A version of input tensor with DropBlock applied.
163 |   Raises:
164 |     ValueError: if width and height of the input tensor are not equal.
165 | """
166 |
167 | if not is_training or keep_prob is None:
168 | return net
169 |
170 | tf.logging.info('Applying DropBlock: dropblock_size {}, net.shape {}'.format(
171 | dropblock_size, net.shape))
172 |
173 |   if data_format == 'channels_last':
174 |     _, height, width, _ = net.get_shape().as_list()
175 |   else:
176 |     _, _, height, width = net.get_shape().as_list()
177 | if width != height:
178 | raise ValueError('Input tensor with width!=height is not supported.')
179 |
180 | dropblock_size = min(dropblock_size, width)
181 |   # seed_drop_rate is the gamma parameter of DropBlock.
182 | seed_drop_rate = (1.0 - keep_prob) * width**2 / dropblock_size**2 / (
183 | width - dropblock_size + 1)**2
184 |
185 | # Forces the block to be inside the feature map.
186 | w_i, h_i = tf.meshgrid(tf.range(width), tf.range(width))
187 | valid_block_center = tf.logical_and(
188 | tf.logical_and(w_i >= int(dropblock_size // 2),
189 | w_i < width - (dropblock_size - 1) // 2),
190 | tf.logical_and(h_i >= int(dropblock_size // 2),
191 | h_i < width - (dropblock_size - 1) // 2))
192 |
193 | valid_block_center = tf.expand_dims(valid_block_center, 0)
194 | valid_block_center = tf.expand_dims(
195 | valid_block_center, -1 if data_format == 'channels_last' else 0)
196 |
197 | randnoise = tf.random_uniform(net.shape, dtype=tf.float32)
198 | block_pattern = (1 - tf.cast(valid_block_center, dtype=tf.float32) + tf.cast(
199 | (1 - seed_drop_rate), dtype=tf.float32) + randnoise) >= 1
200 | block_pattern = tf.cast(block_pattern, dtype=tf.float32)
201 |
202 | if dropblock_size == width:
203 | block_pattern = tf.reduce_min(
204 | block_pattern,
205 | axis=[1, 2] if data_format == 'channels_last' else [2, 3],
206 | keepdims=True)
207 | else:
208 | if data_format == 'channels_last':
209 | ksize = [1, dropblock_size, dropblock_size, 1]
210 | else:
211 | ksize = [1, 1, dropblock_size, dropblock_size]
212 | block_pattern = -tf.nn.max_pool(
213 | -block_pattern, ksize=ksize, strides=[1, 1, 1, 1], padding='SAME',
214 | data_format='NHWC' if data_format == 'channels_last' else 'NCHW')
215 |
216 | percent_ones = tf.cast(tf.reduce_sum((block_pattern)), tf.float32) / tf.cast(
217 | tf.size(block_pattern), tf.float32)
218 |
219 | net = net / tf.cast(percent_ones, net.dtype) * tf.cast(
220 | block_pattern, net.dtype)
221 | return net
222 |
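# Editor's note (worked example, not part of the original file): for a 32x32
# feature map with keep_prob=0.9 and dropblock_size=7, seed_drop_rate is
# (1 - 0.9) * 32**2 / 7**2 / (32 - 7 + 1)**2 ~= 0.0031. Seeds are sampled
# that sparsely because the min-pool step (the negated max_pool above)
# expands each seed into a 7x7 dropped block, bringing the overall fraction
# of dropped units back to roughly 1 - keep_prob.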
223 |
224 | def fixed_padding(inputs, kernel_size, data_format='channels_last'):
225 | """Pads the input along the spatial dimensions independently of input size.
226 |
227 | Args:
228 | inputs: `Tensor` of size `[batch, channels, height, width]` or
229 | `[batch, height, width, channels]` depending on `data_format`.
230 |     kernel_size: `int` kernel size to be used for `conv2d` or `max_pool2d`
231 | operations. Should be a positive integer.
232 | data_format: `str` either "channels_first" for `[batch, channels, height,
233 |         width]` or "channels_last" for `[batch, height, width, channels]`.
234 |
235 | Returns:
236 | A padded `Tensor` of the same `data_format` with size either intact
237 | (if `kernel_size == 1`) or padded (if `kernel_size > 1`).
238 | """
239 | pad_total = kernel_size - 1
240 | pad_beg = pad_total // 2
241 | pad_end = pad_total - pad_beg
242 | if data_format == 'channels_first':
243 | padded_inputs = tf.pad(inputs, [[0, 0], [0, 0],
244 | [pad_beg, pad_end], [pad_beg, pad_end]])
245 | else:
246 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
247 | [pad_beg, pad_end], [0, 0]])
248 |
249 | return padded_inputs
250 |
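# Editor's note (worked example, not part of the original file): with
# kernel_size=7, pad_total = 6, so 3 rows/columns are added on each spatial
# side; with kernel_size=3, 1 is added on each side. Combined with 'VALID'
# padding in conv2d_fixed_padding below, this makes the padding depend only
# on kernel_size rather than on the input's spatial dimensions.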
251 |
252 | def conv2d_fixed_padding(inputs, filters, kernel_size, strides,
253 | data_format='channels_last'):
254 | """Strided 2-D convolution with explicit padding.
255 |
256 | The padding is consistent and is based only on `kernel_size`, not on the
257 | dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
258 |
259 | Args:
260 | inputs: `Tensor` of size `[batch, channels, height_in, width_in]`.
261 | filters: `int` number of filters in the convolution.
262 | kernel_size: `int` size of the kernel to be used in the convolution.
263 | strides: `int` strides of the convolution.
264 | data_format: `str` either "channels_first" for `[batch, channels, height,
265 |         width]` or "channels_last" for `[batch, height, width, channels]`.
266 |
267 | Returns:
268 | A `Tensor` of shape `[batch, filters, height_out, width_out]`.
269 | """
270 | if strides > 1:
271 | inputs = fixed_padding(inputs, kernel_size, data_format=data_format)
272 |
273 | return tf.layers.conv2d(
274 | inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides,
275 | padding=('SAME' if strides == 1 else 'VALID'), use_bias=False,
276 | kernel_initializer=tf.variance_scaling_initializer(),
277 | data_format=data_format)
278 |
279 |
280 | def sk_conv2d(inputs, filters, strides, sk_ratio, min_dim=32,
281 | is_training=True, data_format='channels_last'):
282 | """Selective kernel convolutional layer (https://arxiv.org/abs/1903.06586)."""
283 | channel_axis = 1 if data_format == 'channels_first' else 3
284 | pooling_axes = [2, 3] if data_format == 'channels_first' else [1, 2]
285 |
286 |   # Two-stream convolutions (a single 3x3 conv, then split into two streams).
287 | inputs = conv2d_fixed_padding(
288 | inputs=inputs, filters=2 * filters, kernel_size=3, strides=strides,
289 | data_format=data_format)
290 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
291 | inputs = tf.stack(tf.split(inputs, num_or_size_splits=2, axis=channel_axis))
292 |
293 | # Mixing weights for two streams.
294 | mid_dim = max(int(filters * sk_ratio), min_dim)
295 | global_features = tf.reduce_mean(
296 | tf.reduce_sum(inputs, axis=0), pooling_axes, keepdims=True)
297 | global_features = tf.layers.conv2d(
298 | inputs=global_features, filters=mid_dim, kernel_size=1, strides=1,
299 | kernel_initializer=tf.variance_scaling_initializer(),
300 | use_bias=False, data_format=data_format)
301 | global_features = batch_norm_relu(
302 | global_features, is_training, data_format=data_format)
303 | mixing = tf.layers.conv2d(
304 | inputs=global_features, filters=2 * filters, kernel_size=1, strides=1,
305 | kernel_initializer=tf.variance_scaling_initializer(),
306 | use_bias=False, data_format=data_format)
307 | mixing = tf.stack(tf.split(mixing, num_or_size_splits=2, axis=channel_axis))
308 | mixing = tf.nn.softmax(mixing, axis=0)
309 |
310 | return tf.reduce_sum(inputs * mixing, axis=0)
311 |
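# Editor's note (shape sketch, not part of the original file): for a
# channels_last input [B, H, W, C], the first conv produces
# [B, H', W', 2*filters], which the split+stack turns into two streams of
# shape [2, B, H', W', filters]. The softmax over axis 0 yields per-channel
# mixing weights that sum to 1 across the two streams, and the final
# reduce_sum collapses the stream axis back to [B, H', W', filters].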
312 |
313 | def se_layer(inputs, filters, se_ratio, data_format='channels_last'):
314 | """Squeeze and Excitation layer (https://arxiv.org/abs/1709.01507)."""
315 | if se_ratio <= 0:
316 | return inputs
317 | se_reduce = tf.layers.Conv2D(
318 | max(1, int(filters * se_ratio)),
319 | kernel_size=[1, 1],
320 | strides=[1, 1],
321 | kernel_initializer=tf.variance_scaling_initializer(),
322 | padding='same',
323 | data_format=data_format,
324 | use_bias=True)
325 | se_expand = tf.layers.Conv2D(
326 | inputs.shape[-1],
327 | kernel_size=[1, 1],
328 | strides=[1, 1],
329 | kernel_initializer=tf.variance_scaling_initializer(),
330 | padding='same',
331 | data_format=data_format,
332 | use_bias=True)
333 |
334 | spatial_dims = [2, 3] if data_format == 'channels_first' else [1, 2]
335 | se_tensor = tf.reduce_mean(
336 | inputs, spatial_dims, keepdims=True)
337 | se_tensor = se_expand(tf.nn.relu(se_reduce(se_tensor)))
338 | return tf.sigmoid(se_tensor) * inputs
339 |
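# Editor's note (not part of the original file): the squeeze-excite path
# global-average-pools to a [B, 1, 1, C] descriptor, bottlenecks it to
# max(1, int(filters * se_ratio)) channels with a ReLU, expands it back, and
# gates the input with a sigmoid. Note that `inputs.shape[-1]` for the
# expand width is the channel count only under channels_last; with
# channels_first it would read the width dimension instead.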
340 |
341 | def residual_block(inputs, filters, is_training, strides,
342 | use_projection=False, data_format='channels_last',
343 | dropblock_keep_prob=None, dropblock_size=None):
344 | """Standard building block for residual networks with BN after convolutions.
345 |
346 | Args:
347 | inputs: `Tensor` of size `[batch, channels, height, width]`.
348 |     filters: `int` number of filters for both 3x3 convolutions in the block.
350 | is_training: `bool` for whether the model is in training.
351 | strides: `int` block stride. If greater than 1, this block will ultimately
352 | downsample the input.
353 | use_projection: `bool` for whether this block should use a projection
354 | shortcut (versus the default identity shortcut). This is usually `True`
355 | for the first block of a block group, which may change the number of
356 | filters and the resolution.
357 | data_format: `str` either "channels_first" for `[batch, channels, height,
358 |         width]` or "channels_last" for `[batch, height, width, channels]`.
359 | dropblock_keep_prob: unused; needed to give method same signature as other
360 | blocks
361 | dropblock_size: unused; needed to give method same signature as other
362 | blocks
363 | Returns:
364 | The output `Tensor` of the block.
365 | """
366 | del dropblock_keep_prob
367 | del dropblock_size
368 | shortcut = inputs
369 | if use_projection:
370 | # Projection shortcut in first layer to match filters and strides
371 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187)
372 | if strides > 1:
373 | inputs = fixed_padding(inputs, 2, data_format)
374 | inputs = tf.layers.average_pooling2d(
375 | inputs, pool_size=2, strides=strides,
376 | padding='SAME' if strides == 1 else 'VALID', data_format=data_format)
377 | shortcut = conv2d_fixed_padding(
378 | inputs=inputs, filters=filters, kernel_size=1, strides=1,
379 | data_format=data_format)
380 | else:
381 | shortcut = conv2d_fixed_padding(
382 | inputs=inputs, filters=filters, kernel_size=1, strides=strides,
383 | data_format=data_format)
384 | shortcut = batch_norm_relu(shortcut, is_training, relu=False,
385 | data_format=data_format)
386 |
387 | inputs = conv2d_fixed_padding(
388 | inputs=inputs, filters=filters, kernel_size=3, strides=strides,
389 | data_format=data_format)
390 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
391 |
392 | inputs = conv2d_fixed_padding(
393 | inputs=inputs, filters=filters, kernel_size=3, strides=1,
394 | data_format=data_format)
395 | inputs = batch_norm_relu(inputs, is_training, relu=False, init_zero=True,
396 | data_format=data_format)
397 |
398 | if FLAGS.se_ratio > 0:
399 | inputs = se_layer(inputs, filters, FLAGS.se_ratio, data_format=data_format)
400 |
401 | return tf.nn.relu(inputs + shortcut)
402 |
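# Editor's note (not part of the original file): the final batch norm in the
# block above uses init_zero=True, so its scale starts at 0 and the block
# initially computes relu(shortcut), i.e. the identity mapping; this
# zero-gamma initialization is a common trick for stabilizing early
# residual-network training.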
403 |
404 | def bottleneck_block(inputs, filters, is_training, strides,
405 | use_projection=False, data_format='channels_last',
406 | dropblock_keep_prob=None, dropblock_size=None):
407 | """Bottleneck block variant for residual networks with BN after convolutions.
408 |
409 | Args:
410 | inputs: `Tensor` of size `[batch, channels, height, width]`.
411 | filters: `int` number of filters for the first two convolutions. Note that
412 | the third and final convolution will use 4 times as many filters.
413 | is_training: `bool` for whether the model is in training.
414 | strides: `int` block stride. If greater than 1, this block will ultimately
415 | downsample the input.
416 | use_projection: `bool` for whether this block should use a projection
417 | shortcut (versus the default identity shortcut). This is usually `True`
418 | for the first block of a block group, which may change the number of
419 | filters and the resolution.
420 | data_format: `str` either "channels_first" for `[batch, channels, height,
421 |         width]` or "channels_last" for `[batch, height, width, channels]`.
422 | dropblock_keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock.
423 | "None" means no DropBlock.
424 | dropblock_size: `int` size parameter of DropBlock. Will not be used if
425 | dropblock_keep_prob is "None".
426 |
427 | Returns:
428 | The output `Tensor` of the block.
429 | """
430 | shortcut = inputs
431 | if use_projection:
432 | # Projection shortcut only in first block within a group. Bottleneck blocks
433 | # end with 4 times the number of filters.
434 | filters_out = 4 * filters
435 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187)
436 | if strides > 1:
437 | shortcut = fixed_padding(inputs, 2, data_format)
438 | else:
439 | shortcut = inputs
440 | shortcut = tf.layers.average_pooling2d(
441 | shortcut, pool_size=2, strides=strides,
442 | padding='SAME' if strides == 1 else 'VALID', data_format=data_format)
443 | shortcut = conv2d_fixed_padding(
444 | inputs=shortcut, filters=filters_out, kernel_size=1, strides=1,
445 | data_format=data_format)
446 | else:
447 | shortcut = conv2d_fixed_padding(
448 | inputs=inputs, filters=filters_out, kernel_size=1, strides=strides,
449 | data_format=data_format)
450 | shortcut = batch_norm_relu(shortcut, is_training, relu=False,
451 | data_format=data_format)
452 | shortcut = dropblock(
453 | shortcut, is_training=is_training, data_format=data_format,
454 | keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size)
455 |
456 | inputs = conv2d_fixed_padding(
457 | inputs=inputs, filters=filters, kernel_size=1, strides=1,
458 | data_format=data_format)
459 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
460 | inputs = dropblock(
461 | inputs, is_training=is_training, data_format=data_format,
462 | keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size)
463 |
464 | if FLAGS.sk_ratio > 0:
465 | inputs = sk_conv2d(
466 | inputs, filters, strides, FLAGS.sk_ratio,
467 | is_training=is_training, data_format=data_format)
468 | else:
469 | inputs = conv2d_fixed_padding(
470 | inputs=inputs, filters=filters, kernel_size=3, strides=strides,
471 | data_format=data_format)
472 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
473 | inputs = dropblock(
474 | inputs, is_training=is_training, data_format=data_format,
475 | keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size)
476 |
477 | inputs = conv2d_fixed_padding(
478 | inputs=inputs, filters=4 * filters, kernel_size=1, strides=1,
479 | data_format=data_format)
480 | inputs = batch_norm_relu(inputs, is_training, relu=False, init_zero=True,
481 | data_format=data_format)
482 | inputs = dropblock(
483 | inputs, is_training=is_training, data_format=data_format,
484 | keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size)
485 |
486 | if FLAGS.se_ratio > 0:
487 | inputs = se_layer(inputs, filters, FLAGS.se_ratio, data_format=data_format)
488 |
489 | return tf.nn.relu(inputs + shortcut)
490 |
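# Editor's note (shape sketch, not part of the original file): the block
# above is the usual bottleneck pattern, a 1x1 conv down to `filters`
# channels, a (possibly strided or selective-kernel) 3x3 conv, then a 1x1
# conv up to 4*filters, with DropBlock optionally applied after each stage
# and after the projection shortcut.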
491 |
492 | def block_group(inputs, filters, block_fn, blocks, strides, is_training, name,
493 | data_format='channels_last', dropblock_keep_prob=None,
494 | dropblock_size=None):
495 | """Creates one group of blocks for the ResNet model.
496 |
497 | Args:
498 | inputs: `Tensor` of size `[batch, channels, height, width]`.
499 | filters: `int` number of filters for the first convolution of the layer.
500 | block_fn: `function` for the block to use within the model
501 | blocks: `int` number of blocks contained in the layer.
502 | strides: `int` stride to use for the first convolution of the layer. If
503 | greater than 1, this layer will downsample the input.
504 | is_training: `bool` for whether the model is training.
505 |     name: `str` name for the Tensor output of the block layer.
506 | data_format: `str` either "channels_first" for `[batch, channels, height,
507 |         width]` or "channels_last" for `[batch, height, width, channels]`.
508 | dropblock_keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock.
509 | "None" means no DropBlock.
510 | dropblock_size: `int` size parameter of DropBlock. Will not be used if
511 | dropblock_keep_prob is "None".
512 |
513 | Returns:
514 | The output `Tensor` of the block layer.
515 | """
516 | # Only the first block per block_group uses projection shortcut and strides.
517 | inputs = block_fn(inputs, filters, is_training, strides,
518 | use_projection=True, data_format=data_format,
519 | dropblock_keep_prob=dropblock_keep_prob,
520 | dropblock_size=dropblock_size)
521 |
522 | for _ in range(1, blocks):
523 | inputs = block_fn(inputs, filters, is_training, 1,
524 | data_format=data_format,
525 | dropblock_keep_prob=dropblock_keep_prob,
526 | dropblock_size=dropblock_size)
527 |
528 | return tf.identity(inputs, name)
529 |
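# Editor's note (usage sketch, not part of the original file): in the
# ResNet-50 configuration built below, the first group corresponds to
#
#   x = block_group(x, filters=64, block_fn=bottleneck_block, blocks=3,
#                   strides=1, is_training=True, name='block_group1')
#
# which keeps the spatial size and outputs 4 * 64 = 256 channels.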
530 |
531 | def resnet_v1_generator(block_fn, layers, width_multiplier,
532 | cifar_stem=False, data_format='channels_last',
533 | dropblock_keep_probs=None, dropblock_size=None):
534 | """Generator for ResNet v1 models.
535 |
536 | Args:
537 | block_fn: `function` for the block to use within the model. Either
538 | `residual_block` or `bottleneck_block`.
539 | layers: list of 4 `int`s denoting the number of blocks to include in each
540 | of the 4 block groups. Each group consists of blocks that take inputs of
541 | the same resolution.
542 | width_multiplier: `int` width multiplier for network.
543 | cifar_stem: `bool` If True, use a 3x3 conv without strides or pooling as
544 | stem.
545 | data_format: `str` either "channels_first" for `[batch, channels, height,
546 |         width]` or "channels_last" for `[batch, height, width, channels]`.
547 | dropblock_keep_probs: `list` of 4 elements denoting keep_prob of DropBlock
548 | for each block group. None indicates no DropBlock for the corresponding
549 | block group.
550 |     dropblock_size: `int` size parameter of DropBlock.
551 |
552 | Returns:
553 | Model `function` that takes in `inputs` and `is_training` and returns the
554 | output `Tensor` of the ResNet model.
555 |
556 | Raises:
557 |     ValueError: if dropblock_keep_probs is neither None nor a list of length 4.
558 | """
559 | if dropblock_keep_probs is None:
560 | dropblock_keep_probs = [None] * 4
561 | if not isinstance(dropblock_keep_probs,
562 | list) or len(dropblock_keep_probs) != 4:
563 | raise ValueError('dropblock_keep_probs is not valid:', dropblock_keep_probs)
564 |
565 | def model(inputs, is_training):
566 | """Creation of the model graph."""
567 | if cifar_stem:
568 | inputs = conv2d_fixed_padding(
569 | inputs=inputs, filters=64 * width_multiplier, kernel_size=3,
570 | strides=1, data_format=data_format)
571 | inputs = tf.identity(inputs, 'initial_conv')
572 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
573 | inputs = tf.identity(inputs, 'initial_max_pool')
574 | else:
575 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187)
576 | inputs = conv2d_fixed_padding(
577 | inputs=inputs, filters=64 * width_multiplier // 2, kernel_size=3,
578 | strides=2, data_format=data_format)
579 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
580 | inputs = conv2d_fixed_padding(
581 | inputs=inputs, filters=64 * width_multiplier // 2, kernel_size=3,
582 | strides=1, data_format=data_format)
583 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
584 | inputs = conv2d_fixed_padding(
585 | inputs=inputs, filters=64 * width_multiplier, kernel_size=3,
586 | strides=1, data_format=data_format)
587 | else:
588 | inputs = conv2d_fixed_padding(
589 | inputs=inputs, filters=64 * width_multiplier, kernel_size=7,
590 | strides=2, data_format=data_format)
591 | inputs = tf.identity(inputs, 'initial_conv')
592 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format)
593 |
594 | inputs = tf.layers.max_pooling2d(
595 | inputs=inputs, pool_size=3, strides=2, padding='SAME',
596 | data_format=data_format)
597 | inputs = tf.identity(inputs, 'initial_max_pool')
598 |
599 | def filter_trainable_variables(trainable_variables, after_block):
600 |       """Collects trainable variables created since the previous block."""
601 | if after_block == 0:
602 | trainable_variables[after_block] = tf.trainable_variables()
603 | else:
604 | trainable_variables[after_block] = []
605 | for var in tf.trainable_variables():
606 | to_keep = True
607 | for j in range(after_block):
608 | if var in trainable_variables[j]:
609 | to_keep = False
610 | break
611 | if to_keep:
612 | trainable_variables[after_block].append(var)
613 |
614 | def add_to_collection(trainable_variables, prefix):
615 | """Put variables into graph collection."""
616 | for after_block, variables in trainable_variables.items():
617 | collection = prefix + str(after_block)
618 | for var in variables:
619 | tf.add_to_collection(collection, var)
620 |
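    # Editor's note (not part of the original file): the two helpers above
    # bucket each group's newly created variables by `after_block` and export
    # them under collections named 'trainable_variables_inblock_<k>';
    # together with the tf.stop_gradient calls below, this presumably lets
    # the fine-tuning code choose which blocks to update when
    # FLAGS.fine_tune_after_block is set.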
621 | trainable_variables = {}
622 | filter_trainable_variables(trainable_variables, after_block=0)
623 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 0:
624 | inputs = tf.stop_gradient(inputs)
625 |
626 | inputs = block_group(
627 | inputs=inputs, filters=64 * width_multiplier, block_fn=block_fn,
628 | blocks=layers[0], strides=1, is_training=is_training,
629 | name='block_group1', data_format=data_format,
630 | dropblock_keep_prob=dropblock_keep_probs[0],
631 | dropblock_size=dropblock_size)
632 |
633 | filter_trainable_variables(trainable_variables, after_block=1)
634 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 1:
635 | inputs = tf.stop_gradient(inputs)
636 |
637 | inputs = block_group(
638 | inputs=inputs, filters=128 * width_multiplier, block_fn=block_fn,
639 | blocks=layers[1], strides=2, is_training=is_training,
640 | name='block_group2', data_format=data_format,
641 | dropblock_keep_prob=dropblock_keep_probs[1],
642 | dropblock_size=dropblock_size)
643 |
644 | filter_trainable_variables(trainable_variables, after_block=2)
645 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 2:
646 | inputs = tf.stop_gradient(inputs)
647 |
648 | inputs = block_group(
649 | inputs=inputs, filters=256 * width_multiplier, block_fn=block_fn,
650 | blocks=layers[2], strides=2, is_training=is_training,
651 | name='block_group3', data_format=data_format,
652 | dropblock_keep_prob=dropblock_keep_probs[2],
653 | dropblock_size=dropblock_size)
654 |
655 | filter_trainable_variables(trainable_variables, after_block=3)
656 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 3:
657 | inputs = tf.stop_gradient(inputs)
658 |
659 | inputs = block_group(
660 | inputs=inputs, filters=512 * width_multiplier, block_fn=block_fn,
661 | blocks=layers[3], strides=2, is_training=is_training,
662 | name='block_group4', data_format=data_format,
663 | dropblock_keep_prob=dropblock_keep_probs[3],
664 | dropblock_size=dropblock_size)
665 |
666 | filter_trainable_variables(trainable_variables, after_block=4)
667 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 4:
668 | inputs = tf.stop_gradient(inputs)
669 |
670 | if data_format == 'channels_last':
671 | inputs = tf.reduce_mean(inputs, [1, 2])
672 | else:
673 | inputs = tf.reduce_mean(inputs, [2, 3])
674 | inputs = tf.identity(inputs, 'final_avg_pool')
675 |
676 | # filter_trainable_variables(trainable_variables, after_block=5)
677 | add_to_collection(trainable_variables, 'trainable_variables_inblock_')
678 |
679 | return inputs
680 |
681 | return model
682 |
683 |
684 | def resnet_v1(resnet_depth, width_multiplier,
685 | cifar_stem=False, data_format='channels_last',
686 | dropblock_keep_probs=None, dropblock_size=None):
687 | """Returns the ResNet model for a given size and number of output classes."""
688 | model_params = {
689 | 18: {'block': residual_block, 'layers': [2, 2, 2, 2]},
690 | 34: {'block': residual_block, 'layers': [3, 4, 6, 3]},
691 | 50: {'block': bottleneck_block, 'layers': [3, 4, 6, 3]},
692 | 101: {'block': bottleneck_block, 'layers': [3, 4, 23, 3]},
693 | 152: {'block': bottleneck_block, 'layers': [3, 8, 36, 3]},
694 | 200: {'block': bottleneck_block, 'layers': [3, 24, 36, 3]}
695 | }
696 |
697 | if resnet_depth not in model_params:
698 | raise ValueError('Not a valid resnet_depth:', resnet_depth)
699 |
700 | params = model_params[resnet_depth]
701 | return resnet_v1_generator(
702 | params['block'], params['layers'], width_multiplier,
703 | cifar_stem=cifar_stem,
704 | dropblock_keep_probs=dropblock_keep_probs,
705 | dropblock_size=dropblock_size,
706 | data_format=data_format)
707 |
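# Editor's note (usage sketch, not part of the original file):
#
#   model_fn = resnet_v1(resnet_depth=50, width_multiplier=1)
#   features = model_fn(images, is_training=True)
#
# For depth 50 this selects bottleneck blocks, so `features` has
# 512 * width_multiplier * 4 = 2048 channels after the final average pool.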
--------------------------------------------------------------------------------
/vision/results_summary.txt:
--------------------------------------------------------------------------------
1 | TEMP=256
2 |
3 | gs://martin_ma_mql_simclr/nce_BS_4096_EPOCH_100_TEMP_0.1_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_true_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/hub/28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7348, label_top_5_accuracy = 0.91076, loss = 1.4509399, regularization_loss = 0.0
4 |
5 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/hub/28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71788, label_top_5_accuracy = 0.90118, loss = 1.3814622, regularization_loss = 0.0
6 |
7 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7124, label_top_5_accuracy = 0.89796, loss = 1.4208169, regularization_loss = 0.0
8 |
9 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/hub/28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71834, label_top_5_accuracy = 0.9028, loss = 1.3136246, regularization_loss = 0.0
10 |
11 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_512_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.56738, label_top_5_accuracy = 0.8002, loss = 1.8758307, regularization_loss = 0.0
12 |
13 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_1e-4_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71534, label_top_5_accuracy = 0.90094, loss = 1.3881, regularization_loss = 0.0
14 |
15 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_1e-3_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71874, label_top_5_accuracy = 0.90276, loss = 1.48331, regularization_loss = 0.0
16 |
17 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_1e-2_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72452, label_top_5_accuracy = 0.90286, loss = 1.564714, regularization_loss = 0.0
18 |
19 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_1e-1_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71364, label_top_5_accuracy = 0.89566, loss = 1.5915436, regularization_loss = 0.0
20 |
21 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_0.5_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71914, label_top_5_accuracy = 0.89728, loss = 1.5467246, regularization_loss = 0.0
22 |
23 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_0.01_GAMMA_2.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7223, label_top_5_accuracy = 0.90068, loss = 1.5355219, regularization_loss = 0.0
24 |
25 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7223, label_top_5_accuracy = 0.90238, loss = 1.5636464, regularization_loss = 0.0
26 |
27 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72394, label_top_5_accuracy = 0.90424, loss = 1.5786924, regularization_loss = 0.0
28 |
29 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73042, label_top_5_accuracy = 0.90804, loss = 1.4838624, regularization_loss = 0.0
30 |
31 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72558, label_top_5_accuracy = 0.90198, loss = 1.5553308, regularization_loss = 0.0
32 |
33 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_5.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7226, label_top_5_accuracy = 0.9008, loss = 1.5155789, regularization_loss = 0.0
34 |
35 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.02_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7227, label_top_5_accuracy = 0.9026, loss = 1.5762701, regularization_loss = 0.0
36 |
37 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.05_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71726, label_top_5_accuracy = 0.89686, loss = 1.6178995, regularization_loss = 0.0
38 |
39 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.005_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71776, label_top_5_accuracy = 0.9007, loss = 1.5661987, regularization_loss = 0.0
40 |
41 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_5.0_BETA_0.05_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71656, label_top_5_accuracy = 0.8976, loss = 1.5743265, regularization_loss = 0.0
42 |
43 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_1.5_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72322, label_top_5_accuracy = 0.90264, loss = 1.6190277, regularization_loss = 0.0
44 |
45 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_64_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73028, label_top_5_accuracy = 0.9073, loss = 1.459571, regularization_loss = 0.0
46 |
47 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.5_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72904, label_top_5_accuracy = 0.90646, loss = 1.5504616, regularization_loss = 0.0
48 |
49 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_5.0_BETA_0.1_GAMMA_3.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71782, label_top_5_accuracy = 0.89956, loss = 1.5116348, regularization_loss = 0.0
50 |
51 |
52 | =======
56 |
57 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_64_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_0.1_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72056, label_top_5_accuracy = 0.90078, loss = 1.5019414, regularization_loss = 0.0
58 |
59 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_10.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72932, label_top_5_accuracy = 0.90458, loss = 1.570985, regularization_loss = 0.0
60 |
61 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_10.0_BETA_0.02_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7112, label_top_5_accuracy = 0.89654, loss = 1.5867603, regularization_loss = 0.0
62 |
63 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.3_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74106, label_top_5_accuracy = 0.91256, loss = 1.4070679, regularization_loss = 0.0
64 |
65 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74244, label_top_5_accuracy = 0.91488, loss = 1.4087038, regularization_loss = 0.0
66 |
67 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_86_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74068, label_top_5_accuracy = 0.91394, loss = 1.4385886, regularization_loss = 0.0
68 |
69 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7392, label_top_5_accuracy = 0.9137, loss = 1.4215096, regularization_loss = 0.0
70 |
71 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.005_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74006, label_top_5_accuracy = 0.91448, loss = 1.398139, regularization_loss = 0.0
72 |
73 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.2_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.738, label_top_5_accuracy = 0.91306, loss = 1.422334, regularization_loss = 0.0
74 |
75 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.02_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73756, label_top_5_accuracy = 0.9119, loss = 1.4739765, regularization_loss = 0.0
76 |
77 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_8_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7421, label_top_5_accuracy = 0.91478, loss = 1.434248, regularization_loss = 0.0
78 |
79 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_8_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.005_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74214, label_top_5_accuracy = 0.9144, loss = 1.3988103, regularization_loss = 0.0
80 |
81 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_8_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72656, label_top_5_accuracy = 0.9075, loss = 1.480991, regularization_loss = 0.0
82 |
83 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74076, label_top_5_accuracy = 0.91372, loss = 1.5113194, regularization_loss = 0.0
84 |
85 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_2_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.69452, label_top_5_accuracy = 0.88356, loss = 1.7041928, regularization_loss = 0.0
86 |
87 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7435, label_top_5_accuracy = 0.91578, loss = 1.41444, regularization_loss = 0.0
88 |
89 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_64_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7422, label_top_5_accuracy = 0.91474, loss = 1.3790753, regularization_loss = 0.0
90 |
91 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74324, label_top_5_accuracy = 0.917, loss = 1.3407627, regularization_loss = 0.0
92 |
93 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.2_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74182, label_top_5_accuracy = 0.91422, loss = 1.4183415, regularization_loss = 0.0
94 |
95 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.05_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74266, label_top_5_accuracy = 0.91512, loss = 1.3939042, regularization_loss = 0.0
96 |
97 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.02_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73006, label_top_5_accuracy = 0.90818, loss = 1.471376, regularization_loss = 0.0
98 |
99 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.05_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71944, label_top_5_accuracy = 0.9003, loss = 1.4997969, regularization_loss = 0.0
100 |
101 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.005_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7416, label_top_5_accuracy = 0.91422, loss = 1.4269485, regularization_loss = 0.0
102 |
103 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.2_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73842, label_top_5_accuracy = 0.91254, loss = 1.464798, regularization_loss = 0.0
104 |
105 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74102, label_top_5_accuracy = 0.91458, loss = 1.5262817, regularization_loss = 0.0
106 |
107 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.25_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74326, label_top_5_accuracy = 0.91406, loss = 1.4183799, regularization_loss = 0.0
108 |
109 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.2_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74182, label_top_5_accuracy = 0.91422, loss = 1.4183415, regularization_loss = 0.0
110 |
111 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.3_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74494, label_top_5_accuracy = 0.91484, loss = 1.3978903, regularization_loss = 0.0
112 | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
113 | gs://martin_ma_mql_simclr/nce_BS_4096_EPOCH_100_TEMP_0.1_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_true_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77824, label_top_5_accuracy = 0.93058, loss = 1.0428659, regularization_loss = 0.0
114 |
115 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/hub/28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7463, label_top_5_accuracy = 0.91458, loss = 1.1594483, regularization_loss = 0.0
116 |
117 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_3.0_BETA_1e-2_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7459, label_top_5_accuracy = 0.91434, loss = 1.1565692, regularization_loss = 0.0
118 |
119 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_1e-3_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74422, label_top_5_accuracy = 0.91424, loss = 1.2344401, regularization_loss = 0.0
120 |
121 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_10.0_BETA_1e-1_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74468, label_top_5_accuracy = 0.9085, loss = 1.3773991, regularization_loss = 0.0
122 |
123 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1e-3_BETA_1e-3_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74616, label_top_5_accuracy = 0.91402, loss = 1.3737655, regularization_loss = 0.0
124 |
125 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77374, label_top_5_accuracy = 0.92736, loss = 1.035313, regularization_loss = 0.0
126 |
127 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0_BETA_0.02_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.76864, label_top_5_accuracy = 0.92676, loss = 1.1647376, regularization_loss = 0.0
128 |
129 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.1_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7727, label_top_5_accuracy = 0.92882, loss = 1.1740562, regularization_loss = 0.0
130 |
131 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_8_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.005_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75432, label_top_5_accuracy = 0.91638, loss = 1.1587994, regularization_loss = 0.0
132 |
133 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.5_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77156, label_top_5_accuracy = 0.92894, loss = 1.0430934, regularization_loss = 0.0
134 |
135 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_2_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.5_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71732, label_top_5_accuracy = 0.89936, loss = 1.3376576, regularization_loss = 0.0
136 |
137 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77336, label_top_5_accuracy = 0.92894, loss = 1.1666732, regularization_loss = 0.0
138 |
139 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_64_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77256, label_top_5_accuracy = 0.92906, loss = 0.98218614, regularization_loss = 0.0
140 |
141 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.2_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7731, label_top_5_accuracy = 0.92808, loss = 1.0666859, regularization_loss = 0.0
142 |
143 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_2_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.001, label_top_5_accuracy = 0.005, loss = 6.9076014, regularization_loss = 0.0
144 |
145 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74514, label_top_5_accuracy = 0.91204, loss = 1.3100942, regularization_loss = 0.0
146 |
147 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.1_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74134, label_top_5_accuracy = 0.90802, loss = 1.3841316, regularization_loss = 0.0
148 |
149 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74152, label_top_5_accuracy = 0.91008, loss = 1.2542875, regularization_loss = 0.0
150 |
151 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77338, label_top_5_accuracy = 0.92794, loss = 1.0368634, regularization_loss = 0.0
152 |
153 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_2.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77472, label_top_5_accuracy = 0.9304, loss = 1.0788878, regularization_loss = 0.0
154 |
155 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77372, label_top_5_accuracy = 0.929, loss = 1.0592294, regularization_loss = 0.0
156 |
157 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77942, label_top_5_accuracy = 0.9322, loss = 0.9496544, regularization_loss = 0.0
158 |
159 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.1_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77858, label_top_5_accuracy = 0.9315, loss = 1.0732164, regularization_loss = 0.0
160 |
161 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.5_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77686, label_top_5_accuracy = 0.93116, loss = 1.1159751, regularization_loss = 0.0
162 |
163 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.2_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77866, label_top_5_accuracy = 0.93184, loss = 1.02233, regularization_loss = 0.0
164 |
165 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.0001_GAMMA_0.001_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73946, label_top_5_accuracy = 0.90888, loss = 1.4767729, regularization_loss = 0.0
166 |
167 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.0001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.68206, label_top_5_accuracy = 0.8747, loss = 1.7220646, regularization_loss = 0.0
168 |
169 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.1_BETA_0.0005_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73238, label_top_5_accuracy = 0.90698, loss = 1.3052814, regularization_loss = 0.0
170 |
171 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73986, label_top_5_accuracy = 0.90982, loss = 1.2386664, regularization_loss = 0.0
172 |
173 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.005_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77938, label_top_5_accuracy = 0.93426, loss = 1.0879052, regularization_loss = 0.0
174 |
175 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.0005_GAMMA_0.005_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77476, label_top_5_accuracy = 0.92926, loss = 1.1837187, regularization_loss = 0.0
176 |
177 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.1_GAMMA_0.005_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75216, label_top_5_accuracy = 0.91376, loss = 1.3076787, regularization_loss = 0.0
178 |
179 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74924, label_top_5_accuracy = 0.91502, loss = 1.1803943, regularization_loss = 0.0
180 |
181 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.005_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75048, label_top_5_accuracy = 0.91442, loss = 1.1888653, regularization_loss = 0.0
182 |
183 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75092, label_top_5_accuracy = 0.91302, loss = 1.1644344, regularization_loss = 0.0
184 |
185 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.5_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75062, label_top_5_accuracy = 0.91444, loss = 1.2249312, regularization_loss = 0.0
186 |
187 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.0001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.70622, label_top_5_accuracy = 0.88564, loss = 1.4944282, regularization_loss = 0.0
188 |
189 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.001_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73806, label_top_5_accuracy = 0.90776, loss = 1.2411014, regularization_loss = 0.0
190 |
191 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.1_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7527, label_top_5_accuracy = 0.91318, loss = 1.2551881, regularization_loss = 0.0
192 |
193 | #---------------------------------------------------------------------------------------------
194 |
195 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.70986, label_top_5_accuracy = 0.89282, loss = 1.4696014, regularization_loss = 0.0
196 |
197 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7138, label_top_5_accuracy = 0.89694, loss = 1.439495, regularization_loss = 0.0
198 |
199 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.01_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7002, label_top_5_accuracy = 0.88744, loss = 1.4174829, regularization_loss = 0.0
200 |
201 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.0_GAMMA_0.001_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.68674, label_top_5_accuracy = 0.88098, loss = 1.5044969, regularization_loss = 0.0
204 |
205 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.00_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71774, label_top_5_accuracy = 0.8976, loss = 1.3535299, regularization_loss = 0.0
206 |
207 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.1_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71948, label_top_5_accuracy = 0.89782, loss = 1.397318, regularization_loss = 0.0
208 |
209 | #------------------------------------------------------------------------------------------------
210 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_300_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.01_GAMMA_0.001_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77076, label_top_5_accuracy = 0.9261, loss = 1.1930635, regularization_loss = 0.0
211 | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
212 | gs://martin_ma_mql_simclr/dv_BS_4096_EPOCH_100_TEMP_0.1_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_true_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72214, label_top_5_accuracy = 0.90264, loss = 1.5218498, regularization_loss = 0.0
213 |
214 | gs://martin_ma_mql_simclr/nwj_BS_4096_EPOCH_100_TEMP_0.1_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_true_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.0052, label_top_5_accuracy = 0.02212, loss = 6.673894, regularization_loss = 0.0
215 |
216 |
217 |
--------------------------------------------------------------------------------
/vision/run.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The SimCLR Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """The main training pipeline."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import json
23 | import math
24 | import os
25 | from absl import app
26 | from absl import flags
27 |
28 | import resnet
29 | import data as data_lib
30 | import model as model_lib
31 | import model_util as model_util
32 |
33 | import tensorflow.compat.v1 as tf
34 | import tensorflow_datasets as tfds
35 | import tensorflow_hub as hub
36 |
37 |
38 | FLAGS = flags.FLAGS
39 |
40 |
41 | flags.DEFINE_float(
42 | 'learning_rate', 0.3,
43 | 'Initial learning rate per batch size of 256.')
44 |
45 | flags.DEFINE_enum(
46 | 'learning_rate_scaling', 'linear', ['linear', 'sqrt'],
47 | 'How to scale the learning rate as a function of batch size.')
48 |
49 | flags.DEFINE_float(
50 | 'warmup_epochs', 10,
51 | 'Number of epochs of warmup.')
52 |
53 | flags.DEFINE_float(
54 | 'weight_decay', 1e-4,
55 | 'Amount of weight decay to use.')
56 |
57 | flags.DEFINE_float(
58 | 'batch_norm_decay', 0.9,
59 | 'Batch norm decay parameter.')
60 |
61 | flags.DEFINE_integer(
62 | 'train_batch_size', 512,
63 | 'Batch size for training.')
64 |
65 | flags.DEFINE_string(
66 | 'train_split', 'train',
67 | 'Split for training.')
68 |
69 | flags.DEFINE_integer(
70 | 'train_epochs', 100,
71 | 'Number of epochs to train for.')
72 |
73 | flags.DEFINE_integer(
74 | 'train_steps', 0,
75 | 'Number of steps to train for. If provided, overrides train_epochs.')
76 |
77 | flags.DEFINE_integer(
78 | 'eval_batch_size', 256,
79 | 'Batch size for eval.')
80 |
81 | flags.DEFINE_integer(
82 | 'train_summary_steps', 100,
83 | 'Steps before saving training summaries. If 0, will not save.')
84 |
85 | flags.DEFINE_integer(
86 | 'checkpoint_epochs', 1,
87 | 'Number of epochs between checkpoints/summaries.')
88 |
89 | flags.DEFINE_integer(
90 | 'checkpoint_steps', 0,
91 | 'Number of steps between checkpoints/summaries. If provided, overrides '
92 | 'checkpoint_epochs.')
93 |
94 | flags.DEFINE_string(
95 | 'eval_split', 'validation',
96 | 'Split for evaluation.')
97 |
98 | flags.DEFINE_string(
99 | 'dataset', 'imagenet2012',
100 | 'Name of a dataset.')
101 |
102 | flags.DEFINE_bool(
103 | 'cache_dataset', False,
104 | 'Whether to cache the entire dataset in memory. If the dataset is '
105 | 'ImageNet, this is a very bad idea, but for smaller datasets it can '
106 | 'improve performance.')
107 |
108 | flags.DEFINE_enum(
109 | 'mode', 'train', ['train', 'eval', 'train_then_eval'],
110 | 'Whether to perform training or evaluation.')
111 |
112 | flags.DEFINE_enum(
113 | 'train_mode', 'pretrain', ['pretrain', 'finetune'],
114 | 'The train mode controls different objectives and trainable components.')
115 |
116 | flags.DEFINE_string(
117 | 'checkpoint', None,
118 | 'Checkpoint to load for continued training or fine-tuning.')
119 |
120 | flags.DEFINE_string(
121 | 'variable_schema', '?!global_step',
122 | 'Regular expression defining which variables from the checkpoint should be loaded.')
123 |
124 | flags.DEFINE_bool(
125 | 'zero_init_logits_layer', False,
126 | 'If True, zero initialize layers after avg_pool for supervised learning.')
127 |
128 | flags.DEFINE_integer(
129 | 'fine_tune_after_block', -1,
130 | 'The block after which we fine-tune. -1 means fine-tuning everything. '
131 | '0 means fine-tuning after the stem block. 4 means fine-tuning just '
132 | 'the linear head.')
133 |
134 | flags.DEFINE_string(
135 | 'master', None,
136 | 'Address/name of the TensorFlow master to use. By default, use an '
137 | 'in-process master.')
138 |
139 | flags.DEFINE_string(
140 | 'model_dir', None,
141 | 'Model directory for training.')
142 |
143 | flags.DEFINE_string(
144 | 'data_dir', None,
145 | 'Directory where dataset is stored.')
146 |
147 | flags.DEFINE_bool(
148 | 'use_tpu', True,
149 | 'Whether to run on TPU.')
150 |
151 | tf.flags.DEFINE_string(
152 | 'tpu_name', None,
153 | 'The Cloud TPU to use for training. This should be either the name '
154 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
155 | 'url.')
156 |
157 | tf.flags.DEFINE_string(
158 | 'tpu_zone', None,
159 | '[Optional] GCE zone where the Cloud TPU is located. If not '
160 | 'specified, we will attempt to automatically detect the zone from '
161 | 'metadata.')
162 |
163 | tf.flags.DEFINE_string(
164 | 'gcp_project', None,
165 | '[Optional] Project name for the Cloud TPU-enabled project. If not '
166 | 'specified, we will attempt to automatically detect the GCE project from '
167 | 'metadata.')
168 |
169 | flags.DEFINE_enum(
170 | 'optimizer', 'lars', ['momentum', 'adam', 'lars'],
171 | 'Optimizer to use.')
172 |
173 | flags.DEFINE_float(
174 | 'momentum', 0.9,
175 | 'Momentum parameter.')
176 |
177 | flags.DEFINE_string(
178 | 'eval_name', None,
179 | 'Name for eval.')
180 |
181 | flags.DEFINE_integer(
182 | 'keep_checkpoint_max', 5,
183 | 'Maximum number of checkpoints to keep.')
184 |
185 | flags.DEFINE_integer(
186 | 'keep_hub_module_max', 1,
187 | 'Maximum number of Hub modules to keep.')
188 |
189 | flags.DEFINE_float(
190 | 'temperature', 0.1,
191 | 'Temperature parameter for contrastive loss.')
192 |
193 | flags.DEFINE_boolean(
194 | 'hidden_norm', True,
195 | 'Whether to L2-normalize the hidden vectors before computing the contrastive loss.')
196 |
197 | flags.DEFINE_enum(
198 | 'proj_head_mode', 'nonlinear', ['none', 'linear', 'nonlinear'],
199 | 'How the head projection is done.')
200 |
201 | flags.DEFINE_integer(
202 | 'proj_out_dim', 128,
203 | 'Output dimension of the projection head.')
204 |
205 | flags.DEFINE_integer(
206 | 'num_proj_layers', 3,
207 | 'Number of non-linear head layers.')
208 |
209 | flags.DEFINE_integer(
210 | 'ft_proj_selector', 0,
211 | 'Which layer of the projection head to use during fine-tuning. '
212 | '0 means discarding the projection head entirely, and -1 means using its final layer.')
213 |
214 | flags.DEFINE_boolean(
215 | 'global_bn', True,
216 | 'Whether to aggregate BN statistics across distributed cores.')
217 |
218 | flags.DEFINE_integer(
219 | 'width_multiplier', 1,
220 | 'Multiplier to change width of network.')
221 |
222 | flags.DEFINE_integer(
223 | 'resnet_depth', 50,
224 | 'Depth of ResNet.')
225 |
226 | flags.DEFINE_float(
227 | 'sk_ratio', 0.,
228 | 'If bigger than 0, selective kernels (SK) are enabled. Recommended: 0.0625.')
229 |
230 | flags.DEFINE_float(
231 | 'se_ratio', 0.,
232 | 'If bigger than 0, squeeze-and-excitation (SE) is enabled.')
233 |
234 | flags.DEFINE_integer(
235 | 'image_size', 224,
236 | 'Input image size.')
237 |
238 | flags.DEFINE_float(
239 | 'color_jitter_strength', 1.0,
240 | 'The strength of color jittering.')
241 |
242 | flags.DEFINE_boolean(
243 | 'use_blur', True,
244 | 'Whether or not to use Gaussian blur for augmentation during pretraining.')
245 |
246 | #################################
247 | ### Chi loss implementation
248 | #################################
249 |
250 | flags.DEFINE_enum(
251 | 'loss_type', 'chi', ['chi', 'nce', 'dv', 'nwj', 'js', 'wpc'],
252 | 'Types of loss to use.')
253 |
254 | flags.DEFINE_float(
255 | 'alpha', 0.0,
256 | 'Alpha of chi loss.')
257 |
258 | flags.DEFINE_float(
259 | 'beta', 0.0,
260 | 'Beta of chi loss.')
261 |
262 | flags.DEFINE_float(
263 | 'gamma', 1.0,
264 | 'Gamma of chi loss.')
265 |
266 | flags.DEFINE_float(
267 | 'gradient_penalty_weight', 0.0,
268 | 'Parameter for gradient penalty (WPC only).')
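    | #
    | # NOTE: a minimal sketch, assuming a relative chi-square-style
    | # formulation (the loss actually used is implemented in objective.py),
    | # of how alpha, beta and gamma could enter the objective. Here `pos`
    | # holds critic scores on positive pairs and `neg` scores on negatives:
    | #
    | #   def chi_objective_sketch(pos, neg, alpha, beta, gamma):
    | #     # First moments reward positives and penalize negatives; the
    | #     # second-moment terms damp the scale of the scores.
    | #     return (pos.mean() - alpha * neg.mean()
    | #             - 0.5 * beta * (pos ** 2).mean()
    | #             - 0.5 * gamma * (neg ** 2).mean())
    | #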
269 |
270 |
271 |
272 | def build_hub_module(model, num_classes, global_step, checkpoint_path):
273 | """Create TF-Hub module."""
274 |
275 | tags_and_args = [
276 | # The default graph is built with batch_norm, dropout etc. in inference
277 | # mode. This graph version is good for inference, not training.
278 | ([], {'is_training': False}),
279 | # A separate "train" graph builds batch_norm, dropout etc. in training
280 | # mode.
281 | (['train'], {'is_training': True}),
282 | ]
283 |
284 | def module_fn(is_training):
285 | """Function that builds TF-Hub module."""
286 | endpoints = {}
287 | inputs = tf.placeholder(
288 | tf.float32, [None, None, None, 3])
289 | with tf.variable_scope('base_model', reuse=tf.AUTO_REUSE):
290 | hiddens = model(inputs, is_training)
291 | for v in ['initial_conv', 'initial_max_pool', 'block_group1',
292 | 'block_group2', 'block_group3', 'block_group4',
293 | 'final_avg_pool']:
294 | endpoints[v] = tf.get_default_graph().get_tensor_by_name(
295 | 'base_model/{}:0'.format(v))
296 | if FLAGS.train_mode == 'pretrain':
297 | hiddens_proj = model_util.projection_head(hiddens, is_training)
298 | endpoints['proj_head_input'] = hiddens
299 | endpoints['proj_head_output'] = hiddens_proj
300 | else:
301 | logits_sup = model_util.supervised_head(
302 | hiddens, num_classes, is_training)
303 | endpoints['logits_sup'] = logits_sup
304 | hub.add_signature(inputs=dict(images=inputs),
305 | outputs=dict(endpoints, default=hiddens))
306 |
307 | # Drop the unsupported, non-standard graph collections.
308 | drop_collections = ['trainable_variables_inblock_%d'%d for d in range(6)]
309 | spec = hub.create_module_spec(module_fn, tags_and_args, drop_collections)
310 | hub_export_dir = os.path.join(FLAGS.model_dir, 'hub')
311 | checkpoint_export_dir = os.path.join(hub_export_dir, str(global_step))
312 | if tf.io.gfile.exists(checkpoint_export_dir):
313 | # Remove any stale export for this step before re-exporting.
314 | tf.io.gfile.rmtree(checkpoint_export_dir)
315 | spec.export(
316 | checkpoint_export_dir,
317 | checkpoint_path=checkpoint_path,
318 | name_transform_fn=None)
319 |
320 | if FLAGS.keep_hub_module_max > 0:
321 | # Delete old exported Hub modules.
322 | exported_steps = []
323 | for subdir in tf.io.gfile.listdir(hub_export_dir):
324 | if not subdir.isdigit():
325 | continue
326 | exported_steps.append(int(subdir))
327 | exported_steps.sort()
328 | for step_to_delete in exported_steps[:-FLAGS.keep_hub_module_max]:
329 | tf.io.gfile.rmtree(os.path.join(hub_export_dir, str(step_to_delete)))
330 |
331 |
332 | def perform_evaluation(estimator, input_fn, eval_steps, model, num_classes,
333 | checkpoint_path=None):
334 | """Perform evaluation.
335 |
336 | Args:
337 | estimator: TPUEstimator instance.
338 | input_fn: Input function for estimator.
339 | eval_steps: Number of steps for evaluation.
340 | model: Instance of transfer_learning.models.Model.
341 | num_classes: Number of classes to build model for.
342 | checkpoint_path: Path of checkpoint to evaluate.
343 |
344 | Returns:
345 | result: A Dict of metrics and their values.
346 | """
347 | if not checkpoint_path:
348 | checkpoint_path = estimator.latest_checkpoint()
349 | result = estimator.evaluate(
350 | input_fn, eval_steps, checkpoint_path=checkpoint_path,
351 | name=FLAGS.eval_name)
352 |
353 | # Record results as JSON.
354 | result_json_path = os.path.join(FLAGS.model_dir, 'result.json')
355 | with tf.io.gfile.GFile(result_json_path, 'w') as f:
356 | json.dump({k: float(v) for k, v in result.items()}, f)
357 | result_json_path = os.path.join(
358 | FLAGS.model_dir, 'result_%d.json'%result['global_step'])
359 | with tf.io.gfile.GFile(result_json_path, 'w') as f:
360 | json.dump({k: float(v) for k, v in result.items()}, f)
361 | flag_json_path = os.path.join(FLAGS.model_dir, 'flags.json')
362 | with tf.io.gfile.GFile(flag_json_path, 'w') as f:
363 | json.dump(FLAGS.flag_values_dict(), f)
364 |
365 | # Save Hub module.
366 | build_hub_module(model, num_classes,
367 | global_step=result['global_step'],
368 | checkpoint_path=checkpoint_path)
369 |
370 | return result
371 |
372 |
373 | def main(argv):
374 | if len(argv) > 1:
375 | raise app.UsageError('Too many command-line arguments.')
376 |
377 | # Enable training summaries; soft placement lets summary ops fall back to CPU on TPU.
378 | if FLAGS.train_summary_steps > 0:
379 | tf.config.set_soft_device_placement(True)
380 |
381 |
382 | builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
383 | builder.download_and_prepare()
384 | num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
385 | num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
386 | num_classes = builder.info.features['label'].num_classes
387 |
388 | train_steps = model_util.get_train_steps(num_train_examples)
389 | eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
390 | epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))
391 |
392 | resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
393 | model = resnet.resnet_v1(
394 | resnet_depth=FLAGS.resnet_depth,
395 | width_multiplier=FLAGS.width_multiplier,
396 | cifar_stem=FLAGS.image_size <= 32)
397 |
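    | # Use an explicit checkpoint step interval if given; otherwise derive it from checkpoint_epochs.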
398 | checkpoint_steps = (
399 | FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps))
400 |
401 | cluster = None
402 | if FLAGS.use_tpu and FLAGS.master is None:
403 | if FLAGS.tpu_name:
404 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
405 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
406 | else:
407 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
408 | tf.config.experimental_connect_to_cluster(cluster)
409 | tf.tpu.experimental.initialize_tpu_system(cluster)
410 |
411 | default_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
412 | sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.SLICED
413 | run_config = tf.estimator.tpu.RunConfig(
414 | tpu_config=tf.estimator.tpu.TPUConfig(
415 | iterations_per_loop=checkpoint_steps,
416 | eval_training_input_configuration=sliced_eval_mode
417 | if FLAGS.use_tpu else default_eval_mode),
418 | model_dir=FLAGS.model_dir,
419 | save_summary_steps=checkpoint_steps,
420 | save_checkpoints_steps=checkpoint_steps,
421 | keep_checkpoint_max=FLAGS.keep_checkpoint_max,
422 | master=FLAGS.master,
423 | cluster=cluster)
424 | estimator = tf.estimator.tpu.TPUEstimator(
425 | model_lib.build_model_fn(model, num_classes, num_train_examples),
426 | config=run_config,
427 | train_batch_size=FLAGS.train_batch_size,
428 | eval_batch_size=FLAGS.eval_batch_size,
429 | use_tpu=FLAGS.use_tpu)
430 |
431 | if FLAGS.mode == 'eval':
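    | # Poll the model directory and evaluate each new checkpoint until training has finished.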
432 | for ckpt in tf.train.checkpoints_iterator(
433 | run_config.model_dir, min_interval_secs=15):
434 | try:
435 | result = perform_evaluation(
436 | estimator=estimator,
437 | input_fn=data_lib.build_input_fn(builder, False),
438 | eval_steps=eval_steps,
439 | model=model,
440 | num_classes=num_classes,
441 | checkpoint_path=ckpt)
442 | except tf.errors.NotFoundError:
443 | continue
444 | if result['global_step'] >= train_steps:
445 | return
446 | else:
447 | estimator.train(
448 | data_lib.build_input_fn(builder, True), max_steps=train_steps)
449 | if FLAGS.mode == 'train_then_eval':
450 | perform_evaluation(
451 | estimator=estimator,
452 | input_fn=data_lib.build_input_fn(builder, False),
453 | eval_steps=eval_steps,
454 | model=model,
455 | num_classes=num_classes)
456 |
457 |
458 | if __name__ == '__main__':
459 | tf.disable_v2_behavior() # Disable eager mode when running with TF2.
460 | app.run(main)
461 |
--------------------------------------------------------------------------------
/vision/tpu_pretrain_finetune_resnet128.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | TPU_NAME= # enter your tpu name here
3 | STORAGE_BUCKET= # enter your storage bucket here
4 | DATA_DIR=$STORAGE_BUCKET/tensorflow_datasets
5 |
6 | # hyperparameters
7 | BATCH_SIZE=4096
8 | EPOCH=100
9 | TEMP=32
10 | LR=0.1
11 | LR_SCALE='sqrt'
12 | W_DECAY=1e-4
13 | DATASET='imagenet2012'
14 | IMAGE_SIZE=224
15 | SK_RATIO=0.0625
16 | WIDTH_MUL=2
17 | RESNET_DEPTH=152
18 |
19 | # different for different losses
20 | LOSS_TYPE='chi'
21 | HIDDEN_NORM=false
22 | if [ $LOSS_TYPE == 'chi' ]
23 | then
24 | # the chi loss is trained without hidden normalization
25 | HIDDEN_NORM=false
26 | ALPHA=0.3
27 | BETA=0.001
28 | GAMMA=0.01
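   | # guard: skip degenerate configurations with temperature <= 1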
29 | if [ $TEMP -le 1 ]
30 | then
31 | exit 0
32 | fi
33 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}_ALPHA_${ALPHA}_BETA_${BETA}_GAMMA_${GAMMA}"
34 | python run.py --train_mode=pretrain \
35 | --train_batch_size=$BATCH_SIZE \
36 | --train_epochs=$EPOCH \
37 | --temperature=$TEMP \
38 | --learning_rate=$LR \
39 | --learning_rate_scaling=$LR_SCALE \
40 | --weight_decay=$W_DECAY \
41 | --dataset=$DATASET \
42 | --image_size=$IMAGE_SIZE \
43 | --eval_split=validation \
44 | --data_dir=$DATA_DIR \
45 | --model_dir=$MODEL_DIR \
46 | --use_tpu=True \
47 | --tpu_name=$TPU_NAME \
48 | --train_summary_steps=0 \
49 | --sk_ratio $SK_RATIO \
50 | --width_multiplier $WIDTH_MUL \
51 | --resnet_depth $RESNET_DEPTH \
52 | --loss_type $LOSS_TYPE \
53 | --alpha=$ALPHA \
54 | --beta=$BETA \
55 | --gamma=$GAMMA \
56 | --hidden_norm=$HIDDEN_NORM
57 | fi
58 |
59 | if [ $LOSS_TYPE == 'nce' ] # the same invocation works for js, wpc, dv, nwj
60 | then
61 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}"
62 | echo $MODEL_DIR
63 | python run.py --train_mode=pretrain \
64 | --train_batch_size=$BATCH_SIZE \
65 | --train_epochs=$EPOCH \
66 | --temperature=$TEMP \
67 | --learning_rate=$LR \
68 | --learning_rate_scaling=$LR_SCALE \
69 | --weight_decay=$W_DECAY \
70 | --dataset=$DATASET \
71 | --image_size=$IMAGE_SIZE \
72 | --eval_split=validation \
73 | --data_dir=$DATA_DIR \
74 | --model_dir=$MODEL_DIR \
75 | --use_tpu=True \
76 | --tpu_name=$TPU_NAME \
77 | --train_summary_steps=0 \
78 | --sk_ratio $SK_RATIO \
79 | --width_multiplier $WIDTH_MUL \
80 | --resnet_depth $RESNET_DEPTH \
81 | --loss_type $LOSS_TYPE \
82 | --hidden_norm=$HIDDEN_NORM
83 | fi
84 |
85 | ##############################################################################################
86 | ##################### Fine-tune
87 | ##############################################################################################
88 | CHKPT_DIR=$MODEL_DIR
89 | FINETUNE_AFTER_BLOCK=0
90 | LR=0.16
91 | WD=0
92 | EPOCHS=90
93 | WARMUP_EPOCHS=0
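   | # append the fine-tune settings to the pretrain dir name; these are the paths recorded in results_summary.txt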
94 | MODEL_DIR="${CHKPT_DIR}_ft_BS_${BATCH_SIZE}_FINETUNE_AFTER_BLOCK_${FINETUNE_AFTER_BLOCK}_LR_${LR}_WD_${WD}_EPOCH_${EPOCHS}_WARMUP_EPOCHS_${WARMUP_EPOCHS}"
95 | echo $MODEL_DIR
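   | # variable_schema below is a negative-lookahead regex: restore all pretrained variables except global_step, Momentum slots and the supervised head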
96 | if [ $LOSS_TYPE == "chi" ]
97 | then
98 | echo "Running chi"
99 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=imagenet2012 --image_size=224 --eval_split=validation --data_dir=gs://martin_ma_mql_simclr/tensorflow_datasets --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE --alpha $ALPHA --beta $BETA --gamma $GAMMA
100 | else
101 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=imagenet2012 --image_size=224 --eval_split=validation --data_dir=gs://martin_ma_mql_simclr/tensorflow_datasets --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE
102 | fi
103 |
--------------------------------------------------------------------------------
/vision/tpu_pretrain_finetune_resnet50.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | TPU_NAME= # enter your tpu name here
3 | STORAGE_BUCKET= # enter your storage bucket here
4 | DATA_DIR=$STORAGE_BUCKET/tensorflow_datasets
5 |
6 | # hyperparameters
7 | BATCH_SIZE=4096
8 | EPOCH=100
9 | TEMP=32
10 | LR=0.1
11 | LR_SCALE='sqrt'
12 | W_DECAY=1e-4
13 | DATASET='imagenet2012'
14 | IMAGE_SIZE=224
15 | SK_RATIO=0.0625
16 | WIDTH_MUL=1
17 | RESNET_DEPTH=50
18 |
19 | # different for different losses
20 | LOSS_TYPE='chi'
21 | HIDDEN_NORM=true
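   | # default used by the non-chi losses; overridden below when LOSS_TYPE is chi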
22 | if [ $LOSS_TYPE == 'chi' ]
23 | then
24 | # the chi loss is trained without hidden normalization
25 | HIDDEN_NORM=false
26 | ALPHA=0.3
27 | BETA=0.01
28 | GAMMA=0.1
29 | if [ $TEMP -le 1 ]
30 | then
31 | exit 0
32 | fi
33 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}_ALPHA_${ALPHA}_BETA_${BETA}_GAMMA_${GAMMA}"
34 | echo $MODEL_DIR
35 | python run.py --train_mode=pretrain \
36 | --train_batch_size=$BATCH_SIZE \
37 | --train_epochs=$EPOCH \
38 | --temperature=$TEMP \
39 | --learning_rate=$LR \
40 | --learning_rate_scaling=$LR_SCALE \
41 | --weight_decay=$W_DECAY \
42 | --dataset=$DATASET \
43 | --image_size=$IMAGE_SIZE \
44 | --eval_split=validation \
45 | --data_dir=$DATA_DIR \
46 | --model_dir=$MODEL_DIR \
47 | --use_tpu=True \
48 | --tpu_name=$TPU_NAME \
49 | --train_summary_steps=0 \
50 | --sk_ratio $SK_RATIO \
51 | --width_multiplier $WIDTH_MUL \
52 | --resnet_depth $RESNET_DEPTH \
53 | --loss_type $LOSS_TYPE \
54 | --alpha=$ALPHA \
55 | --beta=$BETA \
56 | --gamma=$GAMMA \
57 | --hidden_norm=$HIDDEN_NORM
58 | fi
59 |
60 | if [ $LOSS_TYPE == 'nce' ] # the same invocation works for js, wpc, dv, nwj
61 | then
62 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}"
63 | echo $MODEL_DIR
64 | python run.py --train_mode=pretrain \
65 | --train_batch_size=$BATCH_SIZE \
66 | --train_epochs=$EPOCH \
67 | --temperature=$TEMP \
68 | --learning_rate=$LR \
69 | --learning_rate_scaling=$LR_SCALE \
70 | --weight_decay=$W_DECAY \
71 | --dataset=$DATASET \
72 | --image_size=$IMAGE_SIZE \
73 | --eval_split=validation \
74 | --data_dir=$DATA_DIR \
75 | --model_dir=$MODEL_DIR \
76 | --use_tpu=True \
77 | --tpu_name=$TPU_NAME \
78 | --train_summary_steps=0 \
79 | --sk_ratio $SK_RATIO \
80 | --width_multiplier $WIDTH_MUL \
81 | --resnet_depth $RESNET_DEPTH \
82 | --loss_type $LOSS_TYPE \
83 | --hidden_norm=$HIDDEN_NORM
84 | fi
85 |
86 | ##############################################################################################
87 | ##################### Fine-tune
88 | ##############################################################################################
89 | CHKPT_DIR=$MODEL_DIR
90 | FINETUNE_AFTER_BLOCK=0
91 | LR=0.16
92 | WD=0
93 | EPOCHS=90
94 | WARMUP_EPOCHS=0
95 | MODEL_DIR="${CHKPT_DIR}_ft_BS_${BATCH_SIZE}_FINETUNE_AFTER_BLOCK_${FINETUNE_AFTER_BLOCK}_LR_${LR}_WD_${WD}_EPOCH_${EPOCHS}_WARMUP_EPOCHS_${WARMUP_EPOCHS}"
96 | echo $MODEL_DIR
97 | if [ $LOSS_TYPE == "chi" ]
98 | then
99 | echo "Running chi"
100 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=imagenet2012 --image_size=224 --eval_split=validation --data_dir=gs://martin_ma_mql_simclr/tensorflow_datasets --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE --alpha $ALPHA --beta $BETA --gamma $GAMMA
101 | else
102 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=imagenet2012 --image_size=224 --eval_split=validation --data_dir=gs://martin_ma_mql_simclr/tensorflow_datasets --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE
103 | fi
104 |
--------------------------------------------------------------------------------