├── .gitignore ├── LICENSE ├── README.md ├── gaussian ├── Gaussian.ipynb ├── baselines_mi_gauss_results_128.pdf ├── baselines_mi_values_128.npz ├── baselines_mi_values_128_alphas.npz ├── baselines_mi_values_128_betas.npz ├── baselines_mi_values_128_gammas.npz ├── baselines_mi_values_cubic_128.npz ├── baselines_mi_values_cubic_128_alphas.npz ├── baselines_mi_values_cubic_128_betas.npz ├── baselines_mi_values_cubic_128_gammas.npz ├── baselines_train_objs_values_128.npz ├── baselines_train_objs_values_128_alphas.npz ├── baselines_train_objs_values_128_betas.npz ├── baselines_train_objs_values_128_gammas.npz ├── baselines_train_objs_values_cubic_128.npz ├── baselines_train_objs_values_cubic_128_alphas.npz ├── baselines_train_objs_values_cubic_128_betas.npz ├── baselines_train_objs_values_cubic_128_gammas.npz └── src │ ├── __init__.py │ ├── estimators.py │ ├── models.py │ ├── train.py │ └── utils.py └── vision ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data.py ├── data_util.py ├── gpu_pretrain_finetune_cifar.sh ├── lars_optimizer.py ├── model.py ├── model_util.py ├── objective.py ├── requirements.txt ├── resnet.py ├── results_summary.txt ├── run.py ├── tpu_pretrain_finetune_resnet128.sh └── tpu_pretrain_finetune_resnet50.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Martin Ma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Relative Predictive Coding 2 | Project page for paper [Self-supervised Representation Learning with Relative Predictive Coding](https://arxiv.org/abs/2103.11275). 3 | 4 | The codebase consists of the following parts: 5 | 6 | ## Vision Task 7 | 8 | Contrastive Self-Supervised Learning on visual tasks. 9 | See `vision` folder for more details. 10 | 11 | ## Gaussian Task 12 | 13 | Mutual Information Estimation on two Gaussians with varied correlation. 14 | See `gaussian/Gaussian.ipynb` for more details. 15 | 16 | ## Cite 17 | 18 | [RPC paper](https://arxiv.org/abs/2103.11275): 19 | 20 | ``` 21 | @article{tsai2021self, 22 | title={Self-supervised representation learning with relative predictive coding}, 23 | author={Tsai, Yao-Hung Hubert and Ma, Martin Q and Yang, Muqiao and Zhao, Han and Morency, Louis-Philippe and Salakhutdinov, Ruslan}, 24 | journal={arXiv preprint arXiv:2103.11275}, 25 | year={2021} 26 | } 27 | ``` 28 | 29 | -------------------------------------------------------------------------------- /gaussian/baselines_mi_gauss_results_128.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_gauss_results_128.pdf -------------------------------------------------------------------------------- /gaussian/baselines_mi_values_128.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_128.npz -------------------------------------------------------------------------------- /gaussian/baselines_mi_values_128_alphas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_128_alphas.npz -------------------------------------------------------------------------------- /gaussian/baselines_mi_values_128_betas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_128_betas.npz -------------------------------------------------------------------------------- /gaussian/baselines_mi_values_128_gammas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_128_gammas.npz -------------------------------------------------------------------------------- /gaussian/baselines_mi_values_cubic_128.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_cubic_128.npz -------------------------------------------------------------------------------- /gaussian/baselines_mi_values_cubic_128_alphas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_cubic_128_alphas.npz -------------------------------------------------------------------------------- /gaussian/baselines_mi_values_cubic_128_betas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_cubic_128_betas.npz -------------------------------------------------------------------------------- /gaussian/baselines_mi_values_cubic_128_gammas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_mi_values_cubic_128_gammas.npz -------------------------------------------------------------------------------- /gaussian/baselines_train_objs_values_128.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_128.npz -------------------------------------------------------------------------------- /gaussian/baselines_train_objs_values_128_alphas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_128_alphas.npz -------------------------------------------------------------------------------- /gaussian/baselines_train_objs_values_128_betas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_128_betas.npz -------------------------------------------------------------------------------- /gaussian/baselines_train_objs_values_128_gammas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_128_gammas.npz -------------------------------------------------------------------------------- /gaussian/baselines_train_objs_values_cubic_128.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_cubic_128.npz -------------------------------------------------------------------------------- /gaussian/baselines_train_objs_values_cubic_128_alphas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_cubic_128_alphas.npz -------------------------------------------------------------------------------- /gaussian/baselines_train_objs_values_cubic_128_betas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_cubic_128_betas.npz -------------------------------------------------------------------------------- /gaussian/baselines_train_objs_values_cubic_128_gammas.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/baselines_train_objs_values_cubic_128_gammas.npz -------------------------------------------------------------------------------- /gaussian/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinmamql/relative_predictive_coding/6c0c66c7e1c35321b25212717a6182547e9ce00d/gaussian/src/__init__.py -------------------------------------------------------------------------------- /gaussian/src/estimators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import sys 7 | 8 | from functools import partial 9 | 10 | from src.utils import * 11 | 12 | import random 13 | 14 | def nwj_lower_bound_obj(scores): 15 | return tuba_lower_bound(scores - 1.) 16 | 17 | def mine_lower_bound(f, buffer=None, momentum=0.9): 18 | if buffer is None: 19 | buffer = torch.tensor(1.0).cuda() 20 | first_term = f.diag().mean() 21 | 22 | buffer_update = logmeanexp_nodiag(f).exp() 23 | with torch.no_grad(): 24 | second_term = logmeanexp_nodiag(f) 25 | buffer_new = buffer * momentum + buffer_update * (1 - momentum) 26 | buffer_new = torch.clamp(buffer_new, min=1e-4) 27 | third_term_no_grad = buffer_update / buffer_new 28 | 29 | third_term_grad = buffer_update / buffer_new 30 | 31 | return first_term - second_term - third_term_grad + third_term_no_grad, buffer_update 32 | 33 | def chi_lower_bound_obj(f, alpha, beta, gamma): 34 | f_diag = f.diag() 35 | first_term = (f_diag - 0.5 * beta * (f_diag ** 2)).mean() 36 | n = f.size(0) 37 | f_offdiag = f.flatten()[1:].view(n-1, n+1)[:,:-1].reshape(n, n-1) 38 | # f_offdiag = f.masked_fill_(torch.eye(n, n).byte().cuda(), 0) 39 | second_term = (alpha * f_offdiag + 0.5 * gamma * (f_offdiag ** 2)).mean() 40 | return first_term - second_term 41 | 42 | def dv_upper_lower_bound_obj(f): 43 | """DV lower bound, but upper bounded by using log outside.""" 44 | first_term = f.diag().mean() 45 | second_term = logmeanexp_nodiag(f) 46 | 47 | return first_term - second_term 48 | 49 | def tuba_lower_bound(scores, log_baseline=None): 50 | if log_baseline is not None: 51 | scores -= log_baseline[:, None] 52 | batch_size = scores.size(0) 53 | 54 | # First term is an expectation over samples from the joint, 55 | # which are the diagonal elmements of the scores matrix. 56 | joint_term = scores.diag().mean() 57 | 58 | # Second term is an expectation over samples from the marginal, 59 | # which are the off-diagonal elements of the scores matrix. 60 | marg_term = logmeanexp_nodiag(scores).exp() 61 | return 1. + joint_term - marg_term 62 | 63 | def infonce_lower_bound_obj(scores): 64 | nll = scores.diag().mean() - scores.logsumexp(dim=1) 65 | # Alternative implementation: 66 | # nll = -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=scores, labels=tf.range(batch_size)) 67 | mi = torch.tensor(scores.size(0)).float().log() + nll 68 | mi = mi.mean() 69 | return mi 70 | 71 | def log_density_ratio_mi(f): 72 | return torch.log(torch.clamp(f, min=1e-4)).diag().mean() 73 | 74 | def direct_log_density_ratio_mi(f): 75 | return f.diag().mean() 76 | 77 | def dv_clip_upper_lower_bound(f, alpha=1.0, clip=None): 78 | z = renorm_q(f, alpha, clip) 79 | dv_clip = f.diag().mean() - z 80 | 81 | return dv_clip 82 | 83 | def js_fgan_lower_bound_obj(f): 84 | f_diag = f.diag() 85 | first_term = -F.softplus(-f_diag).mean() 86 | n = f.size(0) 87 | second_term = (torch.sum(F.softplus(f)) - 88 | torch.sum(F.softplus(f_diag))) / (n * (n - 1.)) 89 | return first_term - second_term 90 | 91 | 92 | def log_density_ratio_mi_chi(f, alpha, beta, gamma): 93 | f_diag = f.diag().mean() 94 | true_ratio = (f_diag * gamma + alpha) / (1. - f_diag * beta) 95 | return torch.log(torch.clamp(true_ratio, min=1e-4)) 96 | 97 | def MI_Estimator(f, train_type='nwj_lower_bound_obj', eval_type='nwj_lower_bound_obj', 98 | **kwargs): 99 | if train_type == 'tuba_lower_bound' or train_type == 'mine_lower_bound'\ 100 | or train_type == 'chi_lower_bound_obj': 101 | if train_type != 'chi_lower_bound_obj': 102 | assert train_type == eval_type 103 | train_val = getattr(sys.modules[__name__], train_type)(f, **kwargs) 104 | else: 105 | train_val = getattr(sys.modules[__name__], train_type)(f) 106 | if train_type == eval_type: 107 | return train_val, train_val 108 | 109 | if train_type == 'nwj_lower_bound_obj' and eval_type == 'direct_log_density_ratio_mi': 110 | eval_val = getattr(sys.modules[__name__], eval_type)(f-1.) 111 | elif eval_type == 'tuba_lower_bound' or eval_type == 'dv_clip_upper_lower_bound'\ 112 | or eval_type == 'mine_lower_bound' or eval_type == 'log_density_ratio_mi_chi': 113 | eval_val = getattr(sys.modules[__name__], eval_type)(f, **kwargs) 114 | # note that especially when we use JS to train, and use nwj to evaluate 115 | elif eval_type == 'nwj_lower_bound_obj': 116 | eval_val = getattr(sys.modules[__name__], eval_type)(f+1., **kwargs) 117 | else: 118 | eval_val = getattr(sys.modules[__name__], eval_type)(f) 119 | 120 | with torch.no_grad(): 121 | eval_train = eval_val - train_val 122 | 123 | return train_val + eval_train, train_val 124 | 125 | def nwj_lower_bound(f): 126 | return MI_Estimator(f, train_type='nwj_lower_bound_obj', eval_type='nwj_lower_bound_obj') 127 | 128 | def infonce_lower_bound(f): 129 | return MI_Estimator(f, train_type='infonce_lower_bound_obj', eval_type='infonce_lower_bound_obj') 130 | 131 | def js_lower_bound(f): 132 | return MI_Estimator(f, train_type='js_fgan_lower_bound_obj', eval_type='nwj_lower_bound_obj') 133 | 134 | def dv_upper_lower_bound(f): 135 | return MI_Estimator(f, train_type='dv_upper_lower_bound_obj', eval_type='dv_upper_lower_bound_obj') 136 | 137 | def smile_lower_bound(f, alpha=1.0, clip=5.0): 138 | return MI_Estimator(f, train_type='js_fgan_lower_bound_obj', 139 | eval_type='dv_clip_upper_lower_bound', alpha=alpha, clip=clip) 140 | 141 | def chi_lower_bound(f, alpha=0.01, beta = 0.005, gamma = 0.995): 142 | return MI_Estimator(f, train_type='chi_lower_bound_obj', eval_type='log_density_ratio_mi_chi', alpha=alpha, beta=beta, gamma=gamma) -------------------------------------------------------------------------------- /gaussian/src/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | def mlp(dim, hidden_dim, output_dim, layers, activation): 8 | activation = { 9 | 'relu': nn.ReLU, 10 | 'tanh': nn.Tanh, 11 | }[activation] 12 | 13 | seq = [nn.Linear(dim, hidden_dim), activation()] 14 | for _ in range(layers): 15 | seq += [nn.Linear(hidden_dim, hidden_dim), activation()] 16 | seq += [nn.Linear(hidden_dim, output_dim)] 17 | 18 | return nn.Sequential(*seq) 19 | 20 | class SeparableCritic(nn.Module): 21 | def __init__(self, dim, hidden_dim, embed_dim, layers, activation, **extra_kwargs): 22 | super(SeparableCritic, self).__init__() 23 | self._g = mlp(dim, hidden_dim, embed_dim, layers, activation) 24 | self._h = mlp(dim, hidden_dim, embed_dim, layers, activation) 25 | 26 | def forward(self, x, y): 27 | scores = torch.matmul(self._h(y), self._g(x).t()) 28 | return scores 29 | 30 | 31 | class ConcatCritic(nn.Module): 32 | def __init__(self, dim, hidden_dim, layers, activation, **extra_kwargs): 33 | super(ConcatCritic, self).__init__() 34 | self._f = mlp(dim * 2, hidden_dim, 1, layers, activation) 35 | 36 | def forward(self, x, y): 37 | batch_size = x.size(0) 38 | # Tile all possible combinations of x and y 39 | x_tiled = torch.stack([x] * batch_size, dim=0) 40 | y_tiled = torch.stack([y] * batch_size, dim=1) 41 | # xy is [batch_size * batch_size, x_dim + y_dim] 42 | xy_pairs = torch.reshape(torch.cat((x_tiled, y_tiled), dim=2), [ 43 | batch_size * batch_size, -1]) 44 | # Compute scores for each x_i, y_j pair. 45 | scores = self._f(xy_pairs) 46 | return torch.reshape(scores, [batch_size, batch_size]).t() 47 | -------------------------------------------------------------------------------- /gaussian/src/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | from src.utils import * 8 | from src.estimators import * 9 | from src.models import * 10 | 11 | def estimate_mutual_information(estimator, x, y, critic_fn, 12 | baseline_fn=None, alpha_logit=None, clamping_values=None, **kwargs): 13 | x, y = x.cuda(), y.cuda() 14 | scores = critic_fn(x, y) 15 | if clamping_values is not None: 16 | scores = torch.clamp(scores, min=clamping_values[0], max=clamping_values[1]) 17 | if baseline_fn is not None: 18 | # Some baselines' output is (batch_size, 1) which we remove here. 19 | log_baseline = torch.squeeze(baseline_fn(y)) 20 | if estimator == 'nwj': 21 | return nwj_lower_bound(scores) 22 | elif estimator == 'infonce': 23 | return infonce_lower_bound(scores) 24 | elif estimator == 'js': 25 | return js_lower_bound(scores) 26 | elif estimator == 'dv': 27 | return dv_upper_lower_bound(scores) 28 | elif estimator == 'smile': 29 | return smile_lower_bound(scores, **kwargs) 30 | elif estimator == 'gen-chi': 31 | return chi_lower_bound(scores, **kwargs) 32 | 33 | 34 | def train_estimator(critic_params, data_params, mi_params, opt_params, **kwargs): 35 | 36 | CRITICS = { 37 | 'separable': SeparableCritic, 38 | 'concat': ConcatCritic, 39 | } 40 | 41 | BASELINES = { 42 | 'constant': lambda: None, 43 | 'unnormalized': lambda: mlp(dim=data_params['dim'], \ 44 | hidden_dim=512, output_dim=1, layers=2, activation='relu').cuda(), 45 | } 46 | 47 | critic = CRITICS[mi_params.get('critic', 'separable')]( 48 | rho=None, **critic_params).cuda() 49 | baseline = BASELINES[mi_params.get('baseline', 'constant')]() 50 | 51 | opt_crit = optim.Adam(critic.parameters(), lr=opt_params['learning_rate']) 52 | if isinstance(baseline, nn.Module): 53 | opt_base = optim.Adam(baseline.parameters(), 54 | lr=opt_params['learning_rate']) 55 | else: 56 | opt_base = None 57 | 58 | def train_step(rho, data_params, mi_params): 59 | opt_crit.zero_grad() 60 | if isinstance(baseline, nn.Module): 61 | opt_base.zero_grad() 62 | 63 | if mi_params['critic'] == 'conditional': 64 | critic_ = CRITICS['conditional'](rho=rho).cuda() 65 | else: 66 | critic_ = critic 67 | 68 | x, y = sample_correlated_gaussian( 69 | dim=data_params['dim'], rho=rho,\ 70 | batch_size=data_params['batch_size'], cubic=data_params['cubic']) 71 | if False: 72 | mi, p_norm = estimate_mutual_information( 73 | mi_params['estimator'], x, y, critic_, baseline,\ 74 | mi_params.get('alpha_logit', None), **kwargs) 75 | else: 76 | mi, train_obj = estimate_mutual_information( 77 | mi_params['estimator'], x, y, critic_, baseline,\ 78 | mi_params.get('alpha_logit', None), **kwargs) 79 | loss = -mi 80 | 81 | loss.backward() 82 | opt_crit.step() 83 | if isinstance(baseline, nn.Module): 84 | opt_base.step() 85 | 86 | if False: 87 | return mi, p_norm 88 | else: 89 | return mi, train_obj 90 | 91 | mis = mi_schedule(opt_params['iterations']) 92 | rhos = mi_to_rho(data_params['dim'], mis) 93 | 94 | if False: 95 | estimates = [] 96 | p_norms = [] 97 | for i in range(opt_params['iterations']): 98 | mi, p_norm = train_step( 99 | rhos[i], data_params, mi_params) 100 | mi = mi.detach().cpu().numpy() 101 | p_norm = p_norm.detach().cpu().numpy() 102 | estimates.append(mi) 103 | p_norms.append(p_norm) 104 | 105 | return np.array(estimates), np.array(p_norms) 106 | else: 107 | estimates = [] 108 | train_objs = [] 109 | for i in range(opt_params['iterations']): 110 | mi, train_obj = train_step( 111 | rhos[i], data_params, mi_params) 112 | mi = mi.detach().cpu().numpy() 113 | train_obj = train_obj.detach().cpu().numpy() 114 | estimates.append(mi) 115 | train_objs.append(train_obj) 116 | 117 | return np.array(estimates), np.array(train_objs) 118 | -------------------------------------------------------------------------------- /gaussian/src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128, cubic=None): 8 | """Generate samples from a correlated Gaussian distribution.""" 9 | x, eps = torch.chunk(torch.randn(batch_size, 2 * dim), 2, dim=1) 10 | y = rho * x + torch.sqrt(torch.tensor(1. - rho**2).float()) * eps 11 | 12 | if cubic is not None: 13 | y = y ** 3 14 | 15 | return x, y 16 | 17 | 18 | def rho_to_mi(dim, rho): 19 | return -0.5 * np.log(1-rho**2) * dim 20 | 21 | 22 | def mi_to_rho(dim, mi): 23 | return np.sqrt(1-np.exp(-2.0 / dim * mi)) 24 | 25 | 26 | def mi_schedule(n_iter): 27 | """Generate schedule for increasing correlation over time.""" 28 | mis = np.round(np.linspace(0.5, 5.5-1e-9, n_iter)) * 2.0 29 | return mis.astype(np.float32) 30 | 31 | def logmeanexp_diag(x): 32 | batch_size = x.size(0) 33 | 34 | logsumexp = torch.logsumexp(x.diag(), dim=(0,)) 35 | num_elem = batch_size 36 | 37 | return logsumexp - torch.log(torch.tensor(num_elem).float()).cuda() 38 | 39 | 40 | def logmeanexp_nodiag(x, dim=None, device='cuda'): 41 | batch_size = x.size(0) 42 | if dim is None: 43 | dim = (0, 1) 44 | 45 | logsumexp = torch.logsumexp( 46 | x - torch.diag(np.inf * torch.ones(batch_size).to(device)), dim=dim) 47 | 48 | try: 49 | if len(dim) == 1: 50 | num_elem = batch_size - 1. 51 | else: 52 | num_elem = batch_size * (batch_size - 1.) 53 | except: 54 | num_elem = batch_size - 1 55 | return logsumexp - torch.log(torch.tensor(num_elem)).to(device) 56 | 57 | 58 | def renorm_q(f, alpha=1.0, clip=None): 59 | if clip is not None: 60 | f = torch.clamp(f * alpha, -clip, clip) 61 | z = logmeanexp_nodiag(f * alpha, dim=(0, 1)) 62 | return z 63 | 64 | 65 | def disc_renorm_q(f): 66 | batch_size = f.size(0) 67 | z = torch.zeros(1, requires_grad=True, device='cuda') 68 | 69 | opt = optim.SGD([z], lr=0.001) 70 | for i in range(10): 71 | opt.zero_grad() 72 | 73 | first_term = -F.softplus(z - f).diag().mean() 74 | st = -F.softplus(f - z) 75 | second_term = (st - st.diag().diag()).sum() / \ 76 | (batch_size * (batch_size - 1.)) 77 | total = first_term + second_term 78 | 79 | total.backward(retain_graph=True) 80 | opt.step() 81 | 82 | if total.item() <= -2 * np.log(2): 83 | break 84 | 85 | return z 86 | 87 | 88 | def renorm_p(f, alpha=1.0): 89 | z = logmeanexp_diag(-f * alpha) 90 | return z 91 | 92 | def estimate_p_norm(f, alpha=1.0): 93 | z = renorm_q(f, alpha) 94 | # f = renorm_p(f, alpha) 95 | # f = renorm_q(f, alpha) 96 | f = f - z 97 | f = -f 98 | 99 | return f.diag().exp().mean() 100 | -------------------------------------------------------------------------------- /vision/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | SimCLR needs to maintain permanent compatibility with the pre-trained model 4 | files, so we do not plan to make any major changes to this library (other than 5 | what was promised in the README). However, we can accept small patches related 6 | to re-factoring and documentation. To submit contributes, there are just a few 7 | small guidelines you need to follow. 8 | 9 | ## Contributor License Agreement 10 | 11 | Contributions to this project must be accompanied by a Contributor License 12 | Agreement. You (or your employer) retain the copyright to your contribution; 13 | this simply gives us permission to use and redistribute your contributions as 14 | part of the project. Head over to to see 15 | your current agreements on file or to sign a new one. 16 | 17 | You generally only need to submit a CLA once, so if you've already submitted one 18 | (even if it was for a different project), you probably don't need to do it 19 | again. 20 | 21 | ## Code reviews 22 | 23 | All submissions, including submissions by project members, require review. We 24 | use GitHub pull requests for this purpose. Consult 25 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 26 | information on using pull requests. 27 | 28 | ## Community Guidelines 29 | 30 | This project follows 31 | [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 32 | -------------------------------------------------------------------------------- /vision/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /vision/README.md: -------------------------------------------------------------------------------- 1 | # Self-supervised Representation Learning with Relative Predictive Coding 2 | 3 | This is a codebase for the vision experiments. 4 | 5 | ## Enviroment setup 6 | 7 | For CIFAR-10/CIFAR-100 experiments, our code can run on a *single* GPU. It does not support multi-GPUs, for reasons such as global BatchNorm and contrastive loss across cores. 8 | 9 | Our models are also trained with TPUs. It is recommended to run distributed training with TPUs when using our code for pretraining on ImageNet. 10 | 11 | We recommend using conda to avoid compatibility issue: 12 | ``` 13 | conda create -n rpc_vision python=3.6 14 | conda activate rpc_vision 15 | pip install -r requirements.txt 16 | conda install cudatoolkit cudnn 17 | ``` 18 | 19 | ## Pretraining and Fine-Tuning 20 | 21 | First create a checkpoint directory: 22 | 23 | ``` 24 | mkdir checkpoint 25 | ``` 26 | 27 | To pretrain and finetune the model on CIFAR-10 with a *single* GPU, try the following command: 28 | 29 | ``` 30 | bash gpu_pretrain_finetune_cifar.sh 31 | ``` 32 | 33 | For different hyper-parameter specification, please change the corresponding parameter in file `gpu_pretrain_finetune_cifar.sh`. 34 | 35 | To pretrain the model on ImageNet with Cloud TPUs, first check out the [Google Cloud TPU tutorial](https://cloud.google.com/tpu/docs/tutorials/mnist) for basic information on how to use Google Cloud TPUs. 36 | 37 | Once you have created virtual machine with Cloud TPUs, and pre-downloaded the ImageNet data for [tensorflow_datasets](https://www.tensorflow.org/datasets/catalog/imagenet2012), please set the following enviroment variables: 38 | 39 | ``` 40 | TPU_NAME= 41 | STORAGE_BUCKET=gs:// 42 | DATA_DIR=$STORAGE_BUCKET/ 43 | MODEL_DIR=$STORAGE_BUCKET/ 44 | ``` 45 | 46 | in the following files which pretrain and fine-tune a ResNet-50 or a ResNet-152 on ImageNet: 47 | 48 | ``` 49 | bash tpu_pretrain_finetune_resnet50.sh 50 | bash tpu_pretrain_finetune_resnet128.sh 51 | ``` 52 | 53 | To request checkpoints of the trained models from the commands above, please contact Martin via qianlim@andrew.cmu.edu. 54 | 55 | ## Cite 56 | 57 | [RPC paper](https://arxiv.org/abs/2103.11275): 58 | 59 | ``` 60 | @article{tsai2021self, 61 | title={Self-supervised representation learning with relative predictive coding}, 62 | author={Tsai, Yao-Hung Hubert and Ma, Martin Q and Yang, Muqiao and Zhao, Han and Morency, Louis-Philippe and Salakhutdinov, Ruslan}, 63 | journal={arXiv preprint arXiv:2103.11275}, 64 | year={2021} 65 | } 66 | ``` 67 | 68 | This code base is adapted from [SimCLR](https://github.com/google-research/simclr). The major change is in the file `objective.py` 69 | -------------------------------------------------------------------------------- /vision/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Data pipeline.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import functools 23 | from absl import flags 24 | 25 | import data_util as data_util 26 | import tensorflow.compat.v1 as tf 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | def pad_to_batch(dataset, batch_size): 32 | """Pad Tensors to specified batch size. 33 | 34 | Args: 35 | dataset: An instance of tf.data.Dataset. 36 | batch_size: The number of samples per batch of input requested. 37 | 38 | Returns: 39 | An instance of tf.data.Dataset that yields the same Tensors with the same 40 | structure as the original padded to batch_size along the leading 41 | dimension. 42 | 43 | Raises: 44 | ValueError: If the dataset does not comprise any tensors; if a tensor 45 | yielded by the dataset has an unknown number of dimensions or is a 46 | scalar; or if it can be statically determined that tensors comprising 47 | a single dataset element will have different leading dimensions. 48 | """ 49 | def _pad_to_batch(*args): 50 | """Given Tensors yielded by a Dataset, pads all to the batch size.""" 51 | flat_args = tf.nest.flatten(args) 52 | 53 | for tensor in flat_args: 54 | if tensor.shape.ndims is None: 55 | raise ValueError( 56 | 'Unknown number of dimensions for tensor %s.' % tensor.name) 57 | if tensor.shape.ndims == 0: 58 | raise ValueError('Tensor %s is a scalar.' % tensor.name) 59 | 60 | # This will throw if flat_args is empty. However, as of this writing, 61 | # tf.data.Dataset.map will throw first with an internal error, so we do 62 | # not check this case explicitly. 63 | first_tensor = flat_args[0] 64 | first_tensor_shape = tf.shape(first_tensor) 65 | first_tensor_batch_size = first_tensor_shape[0] 66 | difference = batch_size - first_tensor_batch_size 67 | 68 | for i, tensor in enumerate(flat_args): 69 | control_deps = [] 70 | if i != 0: 71 | # Check that leading dimensions of this tensor matches the first, 72 | # either statically or dynamically. (If the first dimensions of both 73 | # tensors are statically known, the we have to check the static 74 | # shapes at graph construction time or else we will never get to the 75 | # dynamic assertion.) 76 | if (first_tensor.shape[:1].is_fully_defined() and 77 | tensor.shape[:1].is_fully_defined()): 78 | if first_tensor.shape[0] != tensor.shape[0]: 79 | raise ValueError( 80 | 'Batch size of dataset tensors does not match. %s ' 81 | 'has shape %s, but %s has shape %s' % ( 82 | first_tensor.name, first_tensor.shape, 83 | tensor.name, tensor.shape)) 84 | else: 85 | curr_shape = tf.shape(tensor) 86 | control_deps = [tf.Assert( 87 | tf.equal(curr_shape[0], first_tensor_batch_size), 88 | ['Batch size of dataset tensors %s and %s do not match. ' 89 | 'Shapes are' % (tensor.name, first_tensor.name), curr_shape, 90 | first_tensor_shape])] 91 | 92 | with tf.control_dependencies(control_deps): 93 | # Pad to batch_size along leading dimension. 94 | flat_args[i] = tf.pad( 95 | tensor, [[0, difference]] + [[0, 0]] * (tensor.shape.ndims - 1)) 96 | flat_args[i].set_shape([batch_size] + tensor.shape.as_list()[1:]) 97 | 98 | return tf.nest.pack_sequence_as(args, flat_args) 99 | 100 | return dataset.map(_pad_to_batch) 101 | 102 | 103 | def build_input_fn(builder, is_training): 104 | """Build input function. 105 | 106 | Args: 107 | builder: TFDS builder for specified dataset. 108 | is_training: Whether to build in training mode. 109 | 110 | Returns: 111 | A function that accepts a dict of params and returns a tuple of images and 112 | features, to be used as the input_fn in TPUEstimator. 113 | """ 114 | def _input_fn(params): 115 | """Inner input function.""" 116 | preprocess_fn_pretrain = get_preprocess_fn(is_training, is_pretrain=True) 117 | preprocess_fn_finetune = get_preprocess_fn(is_training, is_pretrain=False) 118 | num_classes = builder.info.features['label'].num_classes 119 | 120 | def map_fn(image, label): 121 | """Produces multiple transformations of the same batch.""" 122 | if FLAGS.train_mode == 'pretrain': 123 | xs = [] 124 | for _ in range(2): # Two transformations 125 | xs.append(preprocess_fn_pretrain(image)) 126 | image = tf.concat(xs, -1) 127 | label = tf.zeros([num_classes]) 128 | else: 129 | image = preprocess_fn_finetune(image) 130 | label = tf.one_hot(label, num_classes) 131 | return image, label, 1.0 132 | 133 | dataset = builder.as_dataset( 134 | split=FLAGS.train_split if is_training else FLAGS.eval_split, 135 | shuffle_files=is_training, as_supervised=True) 136 | if FLAGS.cache_dataset: 137 | dataset = dataset.cache() 138 | if is_training: 139 | buffer_multiplier = 50 if FLAGS.image_size <= 32 else 10 140 | dataset = dataset.shuffle(params['batch_size'] * buffer_multiplier) 141 | dataset = dataset.repeat(-1) 142 | dataset = dataset.map(map_fn, 143 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 144 | dataset = dataset.batch(params['batch_size'], drop_remainder=is_training) 145 | dataset = pad_to_batch(dataset, params['batch_size']) 146 | images, labels, mask = tf.data.make_one_shot_iterator(dataset).get_next() 147 | 148 | return images, {'labels': labels, 'mask': mask} 149 | return _input_fn 150 | 151 | 152 | def get_preprocess_fn(is_training, is_pretrain): 153 | """Get function that accepts an image and returns a preprocessed image.""" 154 | # Disable test cropping for small images (e.g. CIFAR) 155 | if FLAGS.image_size <= 32: 156 | test_crop = False 157 | else: 158 | test_crop = True 159 | return functools.partial( 160 | data_util.preprocess_image, 161 | height=FLAGS.image_size, 162 | width=FLAGS.image_size, 163 | is_training=is_training, 164 | color_distort=is_pretrain, 165 | test_crop=test_crop) 166 | -------------------------------------------------------------------------------- /vision/data_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Data preprocessing and augmentation.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import functools 23 | from absl import flags 24 | 25 | import tensorflow.compat.v1 as tf 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | CROP_PROPORTION = 0.875 # Standard for ImageNet. 30 | 31 | 32 | def random_apply(func, p, x): 33 | """Randomly apply function func to x with probability p.""" 34 | return tf.cond( 35 | tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), 36 | tf.cast(p, tf.float32)), 37 | lambda: func(x), 38 | lambda: x) 39 | 40 | 41 | def random_brightness(image, max_delta, impl='simclrv2'): 42 | """A multiplicative vs additive change of brightness.""" 43 | if impl == 'simclrv2': 44 | factor = tf.random_uniform( 45 | [], tf.maximum(1.0 - max_delta, 0), 1.0 + max_delta) 46 | image = image * factor 47 | elif impl == 'simclrv1': 48 | image = random_brightness(image, max_delta=max_delta) 49 | else: 50 | raise ValueError('Unknown impl {} for random brightness.'.format(impl)) 51 | return image 52 | 53 | 54 | def to_grayscale(image, keep_channels=True): 55 | image = tf.image.rgb_to_grayscale(image) 56 | if keep_channels: 57 | image = tf.tile(image, [1, 1, 3]) 58 | return image 59 | 60 | 61 | def color_jitter(image, 62 | strength, 63 | random_order=True): 64 | """Distorts the color of the image. 65 | 66 | Args: 67 | image: The input image tensor. 68 | strength: the floating number for the strength of the color augmentation. 69 | random_order: A bool, specifying whether to randomize the jittering order. 70 | 71 | Returns: 72 | The distorted image tensor. 73 | """ 74 | brightness = 0.8 * strength 75 | contrast = 0.8 * strength 76 | saturation = 0.8 * strength 77 | hue = 0.2 * strength 78 | if random_order: 79 | return color_jitter_rand(image, brightness, contrast, saturation, hue) 80 | else: 81 | return color_jitter_nonrand(image, brightness, contrast, saturation, hue) 82 | 83 | 84 | def color_jitter_nonrand(image, brightness=0, contrast=0, saturation=0, hue=0): 85 | """Distorts the color of the image (jittering order is fixed). 86 | 87 | Args: 88 | image: The input image tensor. 89 | brightness: A float, specifying the brightness for color jitter. 90 | contrast: A float, specifying the contrast for color jitter. 91 | saturation: A float, specifying the saturation for color jitter. 92 | hue: A float, specifying the hue for color jitter. 93 | 94 | Returns: 95 | The distorted image tensor. 96 | """ 97 | with tf.name_scope('distort_color'): 98 | def apply_transform(i, x, brightness, contrast, saturation, hue): 99 | """Apply the i-th transformation.""" 100 | if brightness != 0 and i == 0: 101 | x = random_brightness(x, max_delta=brightness) 102 | elif contrast != 0 and i == 1: 103 | x = tf.image.random_contrast( 104 | x, lower=1-contrast, upper=1+contrast) 105 | elif saturation != 0 and i == 2: 106 | x = tf.image.random_saturation( 107 | x, lower=1-saturation, upper=1+saturation) 108 | elif hue != 0: 109 | x = tf.image.random_hue(x, max_delta=hue) 110 | return x 111 | 112 | for i in range(4): 113 | image = apply_transform(i, image, brightness, contrast, saturation, hue) 114 | image = tf.clip_by_value(image, 0., 1.) 115 | return image 116 | 117 | 118 | def color_jitter_rand(image, brightness=0, contrast=0, saturation=0, hue=0): 119 | """Distorts the color of the image (jittering order is random). 120 | 121 | Args: 122 | image: The input image tensor. 123 | brightness: A float, specifying the brightness for color jitter. 124 | contrast: A float, specifying the contrast for color jitter. 125 | saturation: A float, specifying the saturation for color jitter. 126 | hue: A float, specifying the hue for color jitter. 127 | 128 | Returns: 129 | The distorted image tensor. 130 | """ 131 | with tf.name_scope('distort_color'): 132 | def apply_transform(i, x): 133 | """Apply the i-th transformation.""" 134 | def brightness_foo(): 135 | if brightness == 0: 136 | return x 137 | else: 138 | return random_brightness(x, max_delta=brightness) 139 | def contrast_foo(): 140 | if contrast == 0: 141 | return x 142 | else: 143 | return tf.image.random_contrast(x, lower=1-contrast, upper=1+contrast) 144 | def saturation_foo(): 145 | if saturation == 0: 146 | return x 147 | else: 148 | return tf.image.random_saturation( 149 | x, lower=1-saturation, upper=1+saturation) 150 | def hue_foo(): 151 | if hue == 0: 152 | return x 153 | else: 154 | return tf.image.random_hue(x, max_delta=hue) 155 | x = tf.cond(tf.less(i, 2), 156 | lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo), 157 | lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo)) 158 | return x 159 | 160 | perm = tf.random_shuffle(tf.range(4)) 161 | for i in range(4): 162 | image = apply_transform(perm[i], image) 163 | image = tf.clip_by_value(image, 0., 1.) 164 | return image 165 | 166 | 167 | def _compute_crop_shape( 168 | image_height, image_width, aspect_ratio, crop_proportion): 169 | """Compute aspect ratio-preserving shape for central crop. 170 | 171 | The resulting shape retains `crop_proportion` along one side and a proportion 172 | less than or equal to `crop_proportion` along the other side. 173 | 174 | Args: 175 | image_height: Height of image to be cropped. 176 | image_width: Width of image to be cropped. 177 | aspect_ratio: Desired aspect ratio (width / height) of output. 178 | crop_proportion: Proportion of image to retain along the less-cropped side. 179 | 180 | Returns: 181 | crop_height: Height of image after cropping. 182 | crop_width: Width of image after cropping. 183 | """ 184 | image_width_float = tf.cast(image_width, tf.float32) 185 | image_height_float = tf.cast(image_height, tf.float32) 186 | 187 | def _requested_aspect_ratio_wider_than_image(): 188 | crop_height = tf.cast(tf.rint( 189 | crop_proportion / aspect_ratio * image_width_float), tf.int32) 190 | crop_width = tf.cast(tf.rint( 191 | crop_proportion * image_width_float), tf.int32) 192 | return crop_height, crop_width 193 | 194 | def _image_wider_than_requested_aspect_ratio(): 195 | crop_height = tf.cast( 196 | tf.rint(crop_proportion * image_height_float), tf.int32) 197 | crop_width = tf.cast(tf.rint( 198 | crop_proportion * aspect_ratio * 199 | image_height_float), tf.int32) 200 | return crop_height, crop_width 201 | 202 | return tf.cond( 203 | aspect_ratio > image_width_float / image_height_float, 204 | _requested_aspect_ratio_wider_than_image, 205 | _image_wider_than_requested_aspect_ratio) 206 | 207 | 208 | def center_crop(image, height, width, crop_proportion): 209 | """Crops to center of image and rescales to desired size. 210 | 211 | Args: 212 | image: Image Tensor to crop. 213 | height: Height of image to be cropped. 214 | width: Width of image to be cropped. 215 | crop_proportion: Proportion of image to retain along the less-cropped side. 216 | 217 | Returns: 218 | A `height` x `width` x channels Tensor holding a central crop of `image`. 219 | """ 220 | shape = tf.shape(image) 221 | image_height = shape[0] 222 | image_width = shape[1] 223 | crop_height, crop_width = _compute_crop_shape( 224 | image_height, image_width, height / width, crop_proportion) 225 | offset_height = ((image_height - crop_height) + 1) // 2 226 | offset_width = ((image_width - crop_width) + 1) // 2 227 | image = tf.image.crop_to_bounding_box( 228 | image, offset_height, offset_width, crop_height, crop_width) 229 | 230 | image = tf.image.resize_bicubic([image], [height, width])[0] 231 | 232 | return image 233 | 234 | 235 | def distorted_bounding_box_crop(image, 236 | bbox, 237 | min_object_covered=0.1, 238 | aspect_ratio_range=(0.75, 1.33), 239 | area_range=(0.05, 1.0), 240 | max_attempts=100, 241 | scope=None): 242 | """Generates cropped_image using one of the bboxes randomly distorted. 243 | 244 | See `tf.image.sample_distorted_bounding_box` for more documentation. 245 | 246 | Args: 247 | image: `Tensor` of image data. 248 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` 249 | where each coordinate is [0, 1) and the coordinates are arranged 250 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole 251 | image. 252 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 253 | area of the image must contain at least this fraction of any bounding 254 | box supplied. 255 | aspect_ratio_range: An optional list of `float`s. The cropped area of the 256 | image must have an aspect ratio = width / height within this range. 257 | area_range: An optional list of `float`s. The cropped area of the image 258 | must contain a fraction of the supplied image within in this range. 259 | max_attempts: An optional `int`. Number of attempts at generating a cropped 260 | region of the image of the specified constraints. After `max_attempts` 261 | failures, return the entire image. 262 | scope: Optional `str` for name scope. 263 | Returns: 264 | (cropped image `Tensor`, distorted bbox `Tensor`). 265 | """ 266 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]): 267 | shape = tf.shape(image) 268 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 269 | shape, 270 | bounding_boxes=bbox, 271 | min_object_covered=min_object_covered, 272 | aspect_ratio_range=aspect_ratio_range, 273 | area_range=area_range, 274 | max_attempts=max_attempts, 275 | use_image_if_no_bounding_boxes=True) 276 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 277 | 278 | # Crop the image to the specified bounding box. 279 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 280 | target_height, target_width, _ = tf.unstack(bbox_size) 281 | image = tf.image.crop_to_bounding_box( 282 | image, offset_y, offset_x, target_height, target_width) 283 | 284 | return image 285 | 286 | 287 | def crop_and_resize(image, height, width): 288 | """Make a random crop and resize it to height `height` and width `width`. 289 | 290 | Args: 291 | image: Tensor representing the image. 292 | height: Desired image height. 293 | width: Desired image width. 294 | 295 | Returns: 296 | A `height` x `width` x channels Tensor holding a random crop of `image`. 297 | """ 298 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 299 | aspect_ratio = width / height 300 | image = distorted_bounding_box_crop( 301 | image, 302 | bbox, 303 | min_object_covered=0.1, 304 | aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio), 305 | area_range=(0.08, 1.0), 306 | max_attempts=100, 307 | scope=None) 308 | return tf.image.resize_bicubic([image], [height, width])[0] 309 | 310 | 311 | def gaussian_blur(image, kernel_size, sigma, padding='SAME'): 312 | """Blurs the given image with separable convolution. 313 | 314 | 315 | Args: 316 | image: Tensor of shape [height, width, channels] and dtype float to blur. 317 | kernel_size: Integer Tensor for the size of the blur kernel. This is should 318 | be an odd number. If it is an even number, the actual kernel size will be 319 | size + 1. 320 | sigma: Sigma value for gaussian operator. 321 | padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'. 322 | 323 | Returns: 324 | A Tensor representing the blurred image. 325 | """ 326 | radius = tf.to_int32(kernel_size / 2) 327 | kernel_size = radius * 2 + 1 328 | x = tf.to_float(tf.range(-radius, radius + 1)) 329 | blur_filter = tf.exp( 330 | -tf.pow(x, 2.0) / (2.0 * tf.pow(tf.to_float(sigma), 2.0))) 331 | blur_filter /= tf.reduce_sum(blur_filter) 332 | # One vertical and one horizontal filter. 333 | blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1]) 334 | blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1]) 335 | num_channels = tf.shape(image)[-1] 336 | blur_h = tf.tile(blur_h, [1, 1, num_channels, 1]) 337 | blur_v = tf.tile(blur_v, [1, 1, num_channels, 1]) 338 | expand_batch_dim = image.shape.ndims == 3 339 | if expand_batch_dim: 340 | # Tensorflow requires batched input to convolutions, which we can fake with 341 | # an extra dimension. 342 | image = tf.expand_dims(image, axis=0) 343 | blurred = tf.nn.depthwise_conv2d( 344 | image, blur_h, strides=[1, 1, 1, 1], padding=padding) 345 | blurred = tf.nn.depthwise_conv2d( 346 | blurred, blur_v, strides=[1, 1, 1, 1], padding=padding) 347 | if expand_batch_dim: 348 | blurred = tf.squeeze(blurred, axis=0) 349 | return blurred 350 | 351 | 352 | def random_crop_with_resize(image, height, width, p=1.0): 353 | """Randomly crop and resize an image. 354 | 355 | Args: 356 | image: `Tensor` representing an image of arbitrary size. 357 | height: Height of output image. 358 | width: Width of output image. 359 | p: Probability of applying this transformation. 360 | 361 | Returns: 362 | A preprocessed image `Tensor`. 363 | """ 364 | def _transform(image): # pylint: disable=missing-docstring 365 | image = crop_and_resize(image, height, width) 366 | return image 367 | return random_apply(_transform, p=p, x=image) 368 | 369 | 370 | def random_color_jitter(image, p=1.0): 371 | def _transform(image): 372 | color_jitter_t = functools.partial( 373 | color_jitter, strength=FLAGS.color_jitter_strength) 374 | image = random_apply(color_jitter_t, p=0.8, x=image) 375 | return random_apply(to_grayscale, p=0.2, x=image) 376 | return random_apply(_transform, p=p, x=image) 377 | 378 | 379 | def random_blur(image, height, width, p=1.0): 380 | """Randomly blur an image. 381 | 382 | Args: 383 | image: `Tensor` representing an image of arbitrary size. 384 | height: Height of output image. 385 | width: Width of output image. 386 | p: probability of applying this transformation. 387 | 388 | Returns: 389 | A preprocessed image `Tensor`. 390 | """ 391 | del width 392 | def _transform(image): 393 | sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32) 394 | return gaussian_blur( 395 | image, kernel_size=height//10, sigma=sigma, padding='SAME') 396 | return random_apply(_transform, p=p, x=image) 397 | 398 | 399 | def batch_random_blur(images_list, height, width, blur_probability=0.5): 400 | """Apply efficient batch data transformations. 401 | 402 | Args: 403 | images_list: a list of image tensors. 404 | height: the height of image. 405 | width: the width of image. 406 | blur_probability: the probaility to apply the blur operator. 407 | 408 | Returns: 409 | Preprocessed feature list. 410 | """ 411 | def generate_selector(p, bsz): 412 | shape = [bsz, 1, 1, 1] 413 | selector = tf.cast( 414 | tf.less(tf.random_uniform(shape, 0, 1, dtype=tf.float32), p), 415 | tf.float32) 416 | return selector 417 | 418 | new_images_list = [] 419 | for images in images_list: 420 | images_new = random_blur(images, height, width, p=1.) 421 | selector = generate_selector(blur_probability, tf.shape(images)[0]) 422 | images = images_new * selector + images * (1 - selector) 423 | images = tf.clip_by_value(images, 0., 1.) 424 | new_images_list.append(images) 425 | 426 | return new_images_list 427 | 428 | 429 | def preprocess_for_train(image, height, width, 430 | color_distort=True, crop=True, flip=True): 431 | """Preprocesses the given image for training. 432 | 433 | Args: 434 | image: `Tensor` representing an image of arbitrary size. 435 | height: Height of output image. 436 | width: Width of output image. 437 | color_distort: Whether to apply the color distortion. 438 | crop: Whether to crop the image. 439 | flip: Whether or not to flip left and right of an image. 440 | 441 | Returns: 442 | A preprocessed image `Tensor`. 443 | """ 444 | if crop: 445 | image = random_crop_with_resize(image, height, width) 446 | if flip: 447 | image = tf.image.random_flip_left_right(image) 448 | if color_distort: 449 | image = random_color_jitter(image) 450 | image = tf.reshape(image, [height, width, 3]) 451 | image = tf.clip_by_value(image, 0., 1.) 452 | return image 453 | 454 | 455 | def preprocess_for_eval(image, height, width, crop=True): 456 | """Preprocesses the given image for evaluation. 457 | 458 | Args: 459 | image: `Tensor` representing an image of arbitrary size. 460 | height: Height of output image. 461 | width: Width of output image. 462 | crop: Whether or not to (center) crop the test images. 463 | 464 | Returns: 465 | A preprocessed image `Tensor`. 466 | """ 467 | if crop: 468 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION) 469 | image = tf.reshape(image, [height, width, 3]) 470 | image = tf.clip_by_value(image, 0., 1.) 471 | return image 472 | 473 | 474 | def preprocess_image(image, height, width, is_training=False, 475 | color_distort=True, test_crop=True): 476 | """Preprocesses the given image. 477 | 478 | Args: 479 | image: `Tensor` representing an image of arbitrary size. 480 | height: Height of output image. 481 | width: Width of output image. 482 | is_training: `bool` for whether the preprocessing is for training. 483 | color_distort: whether to apply the color distortion. 484 | test_crop: whether or not to extract a central crop of the images 485 | (as for standard ImageNet evaluation) during the evaluation. 486 | 487 | Returns: 488 | A preprocessed image `Tensor` of range [0, 1]. 489 | """ 490 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 491 | if is_training: 492 | return preprocess_for_train(image, height, width, color_distort) 493 | else: 494 | return preprocess_for_eval(image, height, width, test_crop) 495 | -------------------------------------------------------------------------------- /vision/gpu_pretrain_finetune_cifar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # hyperparameters 3 | STORAGE_BUCKET="checkpoint" 4 | BATCH_SIZE=512 5 | EPOCH=100 6 | TEMP=128 7 | LR=1.0 8 | LR_SCALE='linear' 9 | W_DECAY=1e-4 10 | DATASET='cifar10' 11 | IMAGE_SIZE=32 12 | SK_RATIO=0.0 13 | WIDTH_MUL=1 14 | RESNET_DEPTH=18 15 | 16 | # different for different losses 17 | LOSS_TYPE='chi' 18 | HIDDEN_NORM=True 19 | if [ $LOSS_TYPE == 'chi' ] 20 | then 21 | # check hidden norm 22 | HIDDEN_NORM=False 23 | ALPHA=0.0 24 | BETA=0.001 25 | GAMMA=1.0 26 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}_ALPHA_${ALPHA}_BETA_${BETA}_GAMMA_${GAMMA}" 27 | CUDA_VISIBLE_DEVICES=0 python run.py --train_mode=pretrain \ 28 | --train_batch_size=$BATCH_SIZE \ 29 | --train_epochs=$EPOCH \ 30 | --temperature=$TEMP \ 31 | --learning_rate=$LR \ 32 | --learning_rate_scaling=$LR_SCALE \ 33 | --weight_decay=$W_DECAY \ 34 | --dataset=$DATASET \ 35 | --image_size=$IMAGE_SIZE \ 36 | --eval_split=test \ 37 | --model_dir=$MODEL_DIR \ 38 | --use_tpu=False \ 39 | --train_summary_steps=0 \ 40 | --sk_ratio $SK_RATIO \ 41 | --width_multiplier $WIDTH_MUL \ 42 | --resnet_depth $RESNET_DEPTH \ 43 | --loss_type $LOSS_TYPE \ 44 | --alpha=$ALPHA \ 45 | --beta=$BETA \ 46 | --gamma=$GAMMA \ 47 | --hidden_norm=$HIDDEN_NORM 48 | 49 | # NCE, JS, WPC, etc 50 | else 51 | TEMP=0.5 52 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}" 53 | CUDA_VISIBLE_DEVICES=0 python run.py --train_mode=pretrain \ 54 | --train_batch_size=$BATCH_SIZE \ 55 | --train_epochs=$EPOCH \ 56 | --temperature=$TEMP \ 57 | --learning_rate=$LR \ 58 | --learning_rate_scaling=$LR_SCALE \ 59 | --weight_decay=$W_DECAY \ 60 | --dataset=$DATASET \ 61 | --image_size=$IMAGE_SIZE \ 62 | --eval_split=test \ 63 | --model_dir=$MODEL_DIR \ 64 | --use_tpu=False \ 65 | --train_summary_steps=0 \ 66 | --sk_ratio $SK_RATIO \ 67 | --width_multiplier $WIDTH_MUL \ 68 | --resnet_depth $RESNET_DEPTH \ 69 | --loss_type $LOSS_TYPE \ 70 | --hidden_norm=$HIDDEN_NORM 71 | 72 | fi 73 | 74 | ############################################################################################## 75 | #####################Fine tune 76 | ############################################################################################## 77 | CHKPT_DIR=$MODEL_DIR 78 | FINETUNE_AFTER_BLOCK=0 79 | LR=0.1 80 | WD=0 81 | EPOCHS=100 82 | WARMUP_EPOCHS=0 83 | MODEL_DIR="${CHKPT_DIR}_ft_BS_${BATCH_SIZE}_FINETUNE_AFTER_BLOCK_${FINETUNE_AFTER_BLOCK}_LR_${LR}_WD_${WD}_EPOCH_${EPOCHS}_WARMUP_EPOCHS_${WARMUP_EPOCHS}" 84 | echo $MODEL_DIR 85 | if [ $LOSS_TYPE == "chi" ] 86 | then 87 | echo "Running chi" 88 | CUDA_VISIBLE_DEVICES=0 python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=$DATASET --image_size=$IMAGE_SIZE --eval_split=test --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=False --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE --alpha $ALPHA --beta $BETA --gamma $GAMMA 89 | else 90 | CUDA_VISIBLE_DEVICES=0 python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=$DATASET --image_size=$IMAGE_SIZE --eval_split=test --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=False --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE 91 | fi 92 | -------------------------------------------------------------------------------- /vision/lars_optimizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Functions and classes related to optimization (weight updates).""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import re 23 | 24 | import tensorflow.compat.v1 as tf 25 | 26 | EETA_DEFAULT = 0.001 27 | 28 | 29 | class LARSOptimizer(tf.train.Optimizer): 30 | """Layer-wise Adaptive Rate Scaling for large batch training. 31 | 32 | Introduced by "Large Batch Training of Convolutional Networks" by Y. You, 33 | I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888) 34 | """ 35 | 36 | def __init__(self, 37 | learning_rate, 38 | momentum=0.9, 39 | use_nesterov=False, 40 | weight_decay=0.0, 41 | exclude_from_weight_decay=None, 42 | exclude_from_layer_adaptation=None, 43 | classic_momentum=True, 44 | eeta=EETA_DEFAULT, 45 | name="LARSOptimizer"): 46 | """Constructs a LARSOptimizer. 47 | 48 | Args: 49 | learning_rate: A `float` for learning rate. 50 | momentum: A `float` for momentum. 51 | use_nesterov: A 'Boolean' for whether to use nesterov momentum. 52 | weight_decay: A `float` for weight decay. 53 | exclude_from_weight_decay: A list of `string` for variable screening, if 54 | any of the string appears in a variable's name, the variable will be 55 | excluded for computing weight decay. For example, one could specify 56 | the list like ['batch_normalization', 'bias'] to exclude BN and bias 57 | from weight decay. 58 | exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but 59 | for layer adaptation. If it is None, it will be defaulted the same as 60 | exclude_from_weight_decay. 61 | classic_momentum: A `boolean` for whether to use classic (or popular) 62 | momentum. The learning rate is applied during momeuntum update in 63 | classic momentum, but after momentum for popular momentum. 64 | eeta: A `float` for scaling of learning rate when computing trust ratio. 65 | name: The name for the scope. 66 | """ 67 | super(LARSOptimizer, self).__init__(False, name) 68 | 69 | self.learning_rate = learning_rate 70 | self.momentum = momentum 71 | self.weight_decay = weight_decay 72 | self.use_nesterov = use_nesterov 73 | self.classic_momentum = classic_momentum 74 | self.eeta = eeta 75 | self.exclude_from_weight_decay = exclude_from_weight_decay 76 | # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the 77 | # arg is None. 78 | if exclude_from_layer_adaptation: 79 | self.exclude_from_layer_adaptation = exclude_from_layer_adaptation 80 | else: 81 | self.exclude_from_layer_adaptation = exclude_from_weight_decay 82 | 83 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 84 | if global_step is None: 85 | global_step = tf.train.get_or_create_global_step() 86 | new_global_step = global_step + 1 87 | 88 | assignments = [] 89 | for (grad, param) in grads_and_vars: 90 | if grad is None or param is None: 91 | continue 92 | 93 | param_name = param.op.name 94 | 95 | v = tf.get_variable( 96 | name=param_name + "/Momentum", 97 | shape=param.shape.as_list(), 98 | dtype=tf.float32, 99 | trainable=False, 100 | initializer=tf.zeros_initializer()) 101 | 102 | if self._use_weight_decay(param_name): 103 | grad += self.weight_decay * param 104 | 105 | if self.classic_momentum: 106 | trust_ratio = 1.0 107 | if self._do_layer_adaptation(param_name): 108 | w_norm = tf.norm(param, ord=2) 109 | g_norm = tf.norm(grad, ord=2) 110 | trust_ratio = tf.where( 111 | tf.greater(w_norm, 0), tf.where( 112 | tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 113 | 1.0), 114 | 1.0) 115 | scaled_lr = self.learning_rate * trust_ratio 116 | 117 | next_v = tf.multiply(self.momentum, v) + scaled_lr * grad 118 | if self.use_nesterov: 119 | update = tf.multiply(self.momentum, next_v) + scaled_lr * grad 120 | else: 121 | update = next_v 122 | next_param = param - update 123 | else: 124 | next_v = tf.multiply(self.momentum, v) + grad 125 | if self.use_nesterov: 126 | update = tf.multiply(self.momentum, next_v) + grad 127 | else: 128 | update = next_v 129 | 130 | trust_ratio = 1.0 131 | if self._do_layer_adaptation(param_name): 132 | w_norm = tf.norm(param, ord=2) 133 | v_norm = tf.norm(update, ord=2) 134 | trust_ratio = tf.where( 135 | tf.greater(w_norm, 0), tf.where( 136 | tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm), 137 | 1.0), 138 | 1.0) 139 | scaled_lr = trust_ratio * self.learning_rate 140 | next_param = param - scaled_lr * update 141 | 142 | assignments.extend( 143 | [param.assign(next_param), 144 | v.assign(next_v), 145 | global_step.assign(new_global_step)]) 146 | return tf.group(*assignments, name=name) 147 | 148 | def _use_weight_decay(self, param_name): 149 | """Whether to use L2 weight decay for `param_name`.""" 150 | if not self.weight_decay: 151 | return False 152 | if self.exclude_from_weight_decay: 153 | for r in self.exclude_from_weight_decay: 154 | if re.search(r, param_name) is not None: 155 | return False 156 | return True 157 | 158 | def _do_layer_adaptation(self, param_name): 159 | """Whether to do layer-wise learning rate adaptation for `param_name`.""" 160 | if self.exclude_from_layer_adaptation: 161 | for r in self.exclude_from_layer_adaptation: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | -------------------------------------------------------------------------------- /vision/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Model specification for SimCLR.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl import flags 23 | 24 | import data_util as data_util 25 | import model_util as model_util 26 | import objective as obj_lib 27 | 28 | import tensorflow.compat.v1 as tf 29 | import tensorflow.compat.v2 as tf2 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | def build_model_fn(model, num_classes, num_train_examples): 35 | """Build model function.""" 36 | def model_fn(features, labels, mode, params=None): 37 | """Build model and optimizer.""" 38 | is_training = mode == tf.estimator.ModeKeys.TRAIN 39 | 40 | # Check training mode. 41 | if FLAGS.train_mode == 'pretrain': 42 | num_transforms = 2 43 | if FLAGS.fine_tune_after_block > -1: 44 | raise ValueError('Does not support layer freezing during pretraining,' 45 | 'should set fine_tune_after_block<=-1 for safety.') 46 | elif FLAGS.train_mode == 'finetune': 47 | num_transforms = 1 48 | else: 49 | raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode)) 50 | 51 | # Split channels, and optionally apply extra batched augmentation. 52 | features_list = tf.split( 53 | features, num_or_size_splits=num_transforms, axis=-1) 54 | if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain': 55 | features_list = data_util.batch_random_blur( 56 | features_list, FLAGS.image_size, FLAGS.image_size) 57 | features = tf.concat(features_list, 0) # (num_transforms * bsz, h, w, c) 58 | 59 | # Base network forward pass. 60 | with tf.variable_scope('base_model'): 61 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4: 62 | # Finetune just supervised (linear) head will not update BN stats. 63 | model_train_mode = False 64 | else: 65 | # Pretrain or finetuen anything else will update BN stats. 66 | model_train_mode = is_training 67 | hiddens = model(features, is_training=model_train_mode) 68 | 69 | # Add head and loss. 70 | if FLAGS.train_mode == 'pretrain': 71 | tpu_context = params['context'] if 'context' in params else None 72 | hiddens_proj = model_util.projection_head(hiddens, is_training) 73 | contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss( 74 | hiddens_proj, 75 | hidden_norm=FLAGS.hidden_norm, 76 | temperature=FLAGS.temperature, 77 | tpu_context=tpu_context if is_training else None, 78 | loss_type=FLAGS.loss_type, 79 | flags=FLAGS) 80 | logits_sup = tf.zeros([params['batch_size'], num_classes]) 81 | gradients_penalty = FLAGS.gradient_penalty_weight * obj_lib.add_gradients_penalty(features, model, model_train_mode) 82 | else: 83 | contrast_loss = tf.zeros([]) 84 | logits_con = tf.zeros([params['batch_size'], 10]) 85 | labels_con = tf.zeros([params['batch_size'], 10]) 86 | hiddens = model_util.projection_head(hiddens, is_training) 87 | logits_sup = model_util.supervised_head( 88 | hiddens, num_classes, is_training) 89 | obj_lib.add_supervised_loss( 90 | labels=labels['labels'], 91 | logits=logits_sup, 92 | weights=labels['mask']) 93 | 94 | # Add weight decay to loss, for non-LARS optimizers. 95 | model_util.add_weight_decay(adjust_per_optimizer=True) 96 | loss = tf.losses.get_total_loss() 97 | 98 | if FLAGS.train_mode == 'pretrain': 99 | variables_to_train = tf.trainable_variables() 100 | else: 101 | collection_prefix = 'trainable_variables_inblock_' 102 | variables_to_train = [] 103 | for j in range(FLAGS.fine_tune_after_block + 1, 6): 104 | variables_to_train += tf.get_collection(collection_prefix + str(j)) 105 | assert variables_to_train, 'variables_to_train shouldn\'t be empty!' 106 | 107 | tf.logging.info('===============Variables to train (begin)===============') 108 | tf.logging.info(variables_to_train) 109 | tf.logging.info('================Variables to train (end)================') 110 | 111 | learning_rate = model_util.learning_rate_schedule( 112 | FLAGS.learning_rate, num_train_examples) 113 | 114 | if is_training: 115 | if FLAGS.train_summary_steps > 0: 116 | # Compute stats for the summary. 117 | prob_con = tf.nn.softmax(logits_con) 118 | entropy_con = - tf.reduce_mean( 119 | tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1)) 120 | 121 | summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir) 122 | with tf.control_dependencies([summary_writer.init()]): 123 | with summary_writer.as_default(): 124 | should_record = tf.math.equal( 125 | tf.math.floormod(tf.train.get_global_step(), 126 | FLAGS.train_summary_steps), 0) 127 | with tf2.summary.record_if(should_record): 128 | contrast_acc = tf.equal( 129 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1)) 130 | contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32)) 131 | label_acc = tf.equal( 132 | tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1)) 133 | label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32)) 134 | tf2.summary.scalar( 135 | 'train_contrast_loss', 136 | contrast_loss, 137 | step=tf.train.get_global_step()) 138 | tf2.summary.scalar( 139 | 'train_contrast_acc', 140 | contrast_acc, 141 | step=tf.train.get_global_step()) 142 | tf2.summary.scalar( 143 | 'train_label_accuracy', 144 | label_acc, 145 | step=tf.train.get_global_step()) 146 | tf2.summary.scalar( 147 | 'contrast_entropy', 148 | entropy_con, 149 | step=tf.train.get_global_step()) 150 | tf2.summary.scalar( 151 | 'learning_rate', learning_rate, 152 | step=tf.train.get_global_step()) 153 | 154 | optimizer = model_util.get_optimizer(learning_rate) 155 | control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 156 | if FLAGS.train_summary_steps > 0: 157 | control_deps.extend(tf.summary.all_v2_summary_ops()) 158 | with tf.control_dependencies(control_deps): 159 | train_op = optimizer.minimize( 160 | loss, global_step=tf.train.get_or_create_global_step(), 161 | var_list=variables_to_train) 162 | 163 | if FLAGS.checkpoint: 164 | def scaffold_fn(): 165 | """Scaffold function to restore non-logits vars from checkpoint.""" 166 | tf.train.init_from_checkpoint( 167 | FLAGS.checkpoint, 168 | {v.op.name: v.op.name 169 | for v in tf.global_variables(FLAGS.variable_schema)}) 170 | 171 | if FLAGS.zero_init_logits_layer: 172 | # Init op that initializes output layer parameters to zeros. 173 | output_layer_parameters = [ 174 | var for var in tf.trainable_variables() if var.name.startswith( 175 | 'head_supervised')] 176 | tf.logging.info('Initializing output layer parameters %s to zero', 177 | [x.op.name for x in output_layer_parameters]) 178 | with tf.control_dependencies([tf.global_variables_initializer()]): 179 | init_op = tf.group([ 180 | tf.assign(x, tf.zeros_like(x)) 181 | for x in output_layer_parameters]) 182 | return tf.train.Scaffold(init_op=init_op) 183 | else: 184 | return tf.train.Scaffold() 185 | else: 186 | scaffold_fn = None 187 | 188 | return tf.estimator.tpu.TPUEstimatorSpec( 189 | mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn) 190 | else: 191 | 192 | def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask, 193 | **kws): 194 | """Inner metric function.""" 195 | metrics = {k: tf.metrics.mean(v, weights=mask) 196 | for k, v in kws.items()} 197 | metrics['label_top_1_accuracy'] = tf.metrics.accuracy( 198 | tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1), 199 | weights=mask) 200 | metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k( 201 | tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask) 202 | metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy( 203 | tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1), 204 | weights=mask) 205 | metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k( 206 | tf.argmax(labels_con, 1), logits_con, k=5, weights=mask) 207 | return metrics 208 | 209 | metrics = { 210 | 'logits_sup': logits_sup, 211 | 'labels_sup': labels['labels'], 212 | 'logits_con': logits_con, 213 | 'labels_con': labels_con, 214 | 'mask': labels['mask'], 215 | 'contrast_loss': tf.fill((params['batch_size'],), contrast_loss), 216 | 'regularization_loss': tf.fill((params['batch_size'],), 217 | tf.losses.get_regularization_loss()), 218 | } 219 | 220 | return tf.estimator.tpu.TPUEstimatorSpec( 221 | mode=mode, 222 | loss=loss, 223 | eval_metrics=(metric_fn, metrics), 224 | scaffold_fn=None) 225 | 226 | return model_fn 227 | -------------------------------------------------------------------------------- /vision/model_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Network architectures related functions used in SimCLR.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | from absl import flags 24 | 25 | import resnet 26 | from lars_optimizer import LARSOptimizer 27 | 28 | import tensorflow.compat.v1 as tf 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | def add_weight_decay(adjust_per_optimizer=True): 34 | """Compute weight decay from flags.""" 35 | if adjust_per_optimizer and 'lars' in FLAGS.optimizer: 36 | # Weight decay are taking care of by optimizer for these cases. 37 | # Except for supervised head, which will be added here. 38 | l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables() 39 | if 'head_supervised' in v.name and 'bias' not in v.name] 40 | if l2_losses: 41 | tf.losses.add_loss( 42 | FLAGS.weight_decay * tf.add_n(l2_losses), 43 | tf.GraphKeys.REGULARIZATION_LOSSES) 44 | return 45 | 46 | l2_losses = [tf.nn.l2_loss(v) for v in tf.trainable_variables() 47 | if 'batch_normalization' not in v.name] 48 | tf.losses.add_loss( 49 | FLAGS.weight_decay * tf.add_n(l2_losses), 50 | tf.GraphKeys.REGULARIZATION_LOSSES) 51 | 52 | 53 | def get_train_steps(num_examples): 54 | """Determine the number of training steps.""" 55 | return FLAGS.train_steps or ( 56 | num_examples * FLAGS.train_epochs // FLAGS.train_batch_size + 1) 57 | 58 | 59 | def learning_rate_schedule(base_learning_rate, num_examples): 60 | """Build learning rate schedule.""" 61 | global_step = tf.train.get_or_create_global_step() 62 | warmup_steps = int(round( 63 | FLAGS.warmup_epochs * num_examples // FLAGS.train_batch_size)) 64 | if FLAGS.learning_rate_scaling == 'linear': 65 | scaled_lr = base_learning_rate * FLAGS.train_batch_size / 256. 66 | elif FLAGS.learning_rate_scaling == 'sqrt': 67 | scaled_lr = base_learning_rate * math.sqrt(FLAGS.train_batch_size) 68 | else: 69 | raise ValueError('Unknown learning rate scaling {}'.format( 70 | FLAGS.learning_rate_scaling)) 71 | learning_rate = (tf.to_float(global_step) / int(warmup_steps) * scaled_lr 72 | if warmup_steps else scaled_lr) 73 | 74 | # Cosine decay learning rate schedule 75 | total_steps = get_train_steps(num_examples) 76 | learning_rate = tf.where( 77 | global_step < warmup_steps, learning_rate, 78 | tf.train.cosine_decay( 79 | scaled_lr, 80 | global_step - warmup_steps, 81 | total_steps - warmup_steps)) 82 | 83 | return learning_rate 84 | 85 | 86 | def get_optimizer(learning_rate): 87 | """Returns an optimizer.""" 88 | if FLAGS.optimizer == 'momentum': 89 | optimizer = tf.train.MomentumOptimizer( 90 | learning_rate, FLAGS.momentum, use_nesterov=True) 91 | elif FLAGS.optimizer == 'adam': 92 | optimizer = tf.train.AdamOptimizer( 93 | learning_rate) 94 | elif FLAGS.optimizer == 'lars': 95 | optimizer = LARSOptimizer( 96 | learning_rate, 97 | momentum=FLAGS.momentum, 98 | weight_decay=FLAGS.weight_decay, 99 | exclude_from_weight_decay=['batch_normalization', 'bias', 100 | 'head_supervised']) 101 | else: 102 | raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer)) 103 | 104 | if FLAGS.use_tpu: 105 | optimizer = tf.tpu.CrossShardOptimizer(optimizer) 106 | return optimizer 107 | 108 | 109 | def linear_layer(x, 110 | is_training, 111 | num_classes, 112 | use_bias=True, 113 | use_bn=False, 114 | name='linear_layer'): 115 | """Linear head for linear evaluation. 116 | 117 | Args: 118 | x: hidden state tensor of shape (bsz, dim). 119 | is_training: boolean indicator for training or test. 120 | num_classes: number of classes. 121 | use_bias: whether or not to use bias. 122 | use_bn: whether or not to use BN for output units. 123 | name: the name for variable scope. 124 | 125 | Returns: 126 | logits of shape (bsz, num_classes) 127 | """ 128 | assert x.shape.ndims == 2, x.shape 129 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 130 | x = tf.layers.dense( 131 | inputs=x, 132 | units=num_classes, 133 | use_bias=use_bias and not use_bn, 134 | kernel_initializer=tf.random_normal_initializer(stddev=.01)) 135 | if use_bn: 136 | x = resnet.batch_norm_relu(x, is_training, relu=False, center=use_bias) 137 | x = tf.identity(x, '%s_out' % name) 138 | return x 139 | 140 | 141 | def projection_head(hiddens, is_training, name='head_contrastive'): 142 | """Head for projecting hiddens fo contrastive loss.""" 143 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 144 | mid_dim = hiddens.shape[-1] 145 | out_dim = FLAGS.proj_out_dim 146 | hiddens_list = [hiddens] 147 | if FLAGS.proj_head_mode == 'none': 148 | pass # directly use the output hiddens as hiddens. 149 | elif FLAGS.proj_head_mode == 'linear': 150 | hiddens = linear_layer( 151 | hiddens, is_training, out_dim, 152 | use_bias=False, use_bn=True, name='l_0') 153 | hiddens_list.append(hiddens) 154 | elif FLAGS.proj_head_mode == 'nonlinear': 155 | for j in range(FLAGS.num_proj_layers): 156 | if j != FLAGS.num_proj_layers - 1: 157 | # for the middle layers, use bias and relu for the output. 158 | dim, bias_relu = mid_dim, True 159 | else: 160 | # for the final layer, neither bias nor relu is used. 161 | dim, bias_relu = FLAGS.proj_out_dim, False 162 | hiddens = linear_layer( 163 | hiddens, is_training, dim, 164 | use_bias=bias_relu, use_bn=True, name='nl_%d'%j) 165 | hiddens = tf.nn.relu(hiddens) if bias_relu else hiddens 166 | hiddens_list.append(hiddens) 167 | else: 168 | raise ValueError('Unknown head projection mode {}'.format( 169 | FLAGS.proj_head_mode)) 170 | if FLAGS.train_mode == 'pretrain': 171 | # take the projection head output during pre-training. 172 | hiddens = hiddens_list[-1] 173 | else: 174 | # for checkpoint compatibility, whole projection head is built here. 175 | # but you can select part of projection head during fine-tuning. 176 | hiddens = hiddens_list[FLAGS.ft_proj_selector] 177 | return hiddens 178 | 179 | 180 | def supervised_head(hiddens, num_classes, is_training, name='head_supervised'): 181 | """Add supervised head & also add its variables to inblock collection.""" 182 | with tf.variable_scope(name): 183 | logits = linear_layer(hiddens, is_training, num_classes) 184 | for var in tf.trainable_variables(): 185 | if var.name.startswith(name): 186 | tf.add_to_collection('trainable_variables_inblock_5', var) 187 | return logits 188 | -------------------------------------------------------------------------------- /vision/objective.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Contrastive loss functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from absl import flags 23 | 24 | import tensorflow.compat.v1 as tf 25 | 26 | from tensorflow.compiler.tf2xla.python import xla # pylint: disable=g-direct-tensorflow-import 27 | 28 | from tensorflow.python.framework import ops 29 | from tensorflow.python.ops import math_ops 30 | from tensorflow.python.ops import array_ops 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | LARGE_NUM = 1e9 35 | 36 | def add_gradients_penalty(x, model, model_train_mode): 37 | """https://colab.research.google.com/github/timsainb/tensorflow2-generative-models/blob/master/3.0-WGAN-GP-fashion-mnist.ipynb#scrollTo=Wyipg-4oSYb1""" 38 | with tf.GradientTape() as t: 39 | t.watch(x) 40 | hidden = model(x, is_training=model_train_mode) 41 | gradients = t.gradient(hidden, x) 42 | dx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2, 3])) 43 | d_regularizer = tf.reduce_mean((dx - 1.0) ** 2) 44 | return d_regularizer 45 | 46 | def add_supervised_loss(labels, logits, weights, **kwargs): 47 | """Compute loss for model and add it to loss collection.""" 48 | return tf.losses.softmax_cross_entropy(labels, logits, weights, **kwargs) 49 | 50 | 51 | def add_contrastive_loss(hidden, 52 | hidden_norm=True, 53 | temperature=1.0, 54 | tpu_context=None, 55 | weights=1.0, 56 | loss_type=None, 57 | flags=None): 58 | """Compute loss for model. 59 | 60 | Args: 61 | hidden: hidden vector (`Tensor`) of shape (bsz, dim). 62 | hidden_norm: whether or not to use normalization on the hidden vector. 63 | temperature: a `floating` number for temperature scaling. 64 | tpu_context: context information for tpu. 65 | weights: a weighting number or vector. 66 | 67 | Returns: 68 | A loss scalar. 69 | The logits for contrastive prediction task. 70 | The labels for contrastive prediction task. 71 | """ 72 | # Get (normalized) hidden1 and hidden2. 73 | if hidden_norm: 74 | hidden = tf.math.l2_normalize(hidden, -1) 75 | hidden1, hidden2 = tf.split(hidden, 2, 0) 76 | batch_size = tf.shape(hidden1)[0] 77 | 78 | # Gather hidden1/hidden2 across replicas and create local labels. 79 | if tpu_context is not None: 80 | hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context) 81 | hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context) 82 | enlarged_batch_size = tf.shape(hidden1_large)[0] 83 | replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32) 84 | labels_idx = tf.range(batch_size) + replica_id * batch_size 85 | labels = tf.one_hot(labels_idx, enlarged_batch_size * 2) 86 | masks = tf.one_hot(labels_idx, enlarged_batch_size) 87 | else: 88 | hidden1_large = hidden1 89 | hidden2_large = hidden2 90 | labels = tf.one_hot(tf.range(batch_size), batch_size * 2) 91 | masks = tf.one_hot(tf.range(batch_size), batch_size) 92 | # check WPC 93 | if loss_type.lower() == "wpc": 94 | assert flags.gradient_penalty_weight != 0.0 95 | else: 96 | assert flags.gradient_penalty_weight == 0.0 97 | logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature 98 | logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature 99 | # aa and bb diagonals are not positive samples; positive samples are ab abd ba diagnoals 100 | # thus we want to mask aa and bb diagonals out 101 | if loss_type.lower() == "nce" or loss_type.lower() == "dv" or loss_type.lower() == "wpc": 102 | # NCE loss: minus big number to create cloes to 0 values in softmax 103 | print(loss_type) 104 | logits_aa = logits_aa - masks * LARGE_NUM 105 | logits_bb = logits_bb - masks * LARGE_NUM 106 | else: # otherwise just mask out using 0 107 | logits_aa = logits_aa * (1 - masks) 108 | logits_bb = logits_bb * (1 - masks) 109 | logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature 110 | logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature 111 | ############################################################################# 112 | ### Different losses: nce, chi, js and nwj 113 | ### Pos_scores: positive samples, i.e. joint distribution terms 114 | ### neg_scores: negative samples, i.e. marginal distribution terms 115 | ############################################################################# 116 | if loss_type.lower() == "nce": 117 | loss_a = tf.losses.softmax_cross_entropy( 118 | labels, tf.concat([logits_ab, logits_aa], 1), weights=weights) 119 | loss_b = tf.losses.softmax_cross_entropy( 120 | labels, tf.concat([logits_ba, logits_bb], 1), weights=weights) 121 | elif loss_type.lower() == "chi": 122 | ## Chi squared loss in general form 123 | alpha = flags.alpha 124 | beta = flags.beta 125 | gamma = flags.gamma 126 | 127 | joint_a = labels * tf.concat([logits_ab, logits_aa], 1) 128 | joint_b = labels * tf.concat([logits_ba, logits_bb], 1) 129 | # non-correlated views 130 | marg_a = (1.-labels) * tf.concat([logits_ab, logits_aa], 1) 131 | marg_b = (1.-labels) * tf.concat([logits_ba, logits_bb], 1) 132 | batch_size = tf.cast(batch_size, tf.float32) 133 | joint = 0.5*(tf.reduce_sum(joint_a - 0.5 * beta * joint_a**2) / batch_size)\ 134 | + 0.5*(tf.reduce_sum(joint_b - 0.5 * beta * joint_b**2) / batch_size) 135 | # non-correlated views 136 | marg = 0.5*(tf.reduce_sum(alpha * marg_a + 0.5 * gamma * marg_a**2) / (2*batch_size*(batch_size-1.)))\ 137 | + 0.5*(tf.reduce_sum(alpha * marg_b + 0.5 * gamma * marg_b**2) / (2*batch_size*(batch_size-1.))) 138 | loss = -1. * (joint - marg) 139 | tf.losses.add_loss(loss) 140 | return loss, logits_ab, labels 141 | 142 | elif loss_type.lower() == "js": 143 | # Jensen Shannon 144 | def js(logits_concat, labels, scope=None): 145 | lbls = math_ops.cast(labels, logits_concat.dtype) 146 | """SHOULD I ADD STOP GRADIENT?""" 147 | bs = math_ops.cast(batch_size, logits_concat.dtype) 148 | pos_scores = tf.reduce_sum(lbls * (-tf.math.softplus(-logits_concat))) / bs 149 | neg_scores = tf.reduce_sum((1 - lbls) * tf.math.softplus(logits_concat)) / ((2 * bs - 1) * bs) 150 | return - (pos_scores - neg_scores) 151 | 152 | loss_a = 0.5 * js(tf.concat([logits_ab, logits_aa], 1), labels) 153 | loss_b = 0.5 * js(tf.concat([logits_ba, logits_bb], 1), labels) 154 | tf.losses.add_loss(loss_a) 155 | tf.losses.add_loss(loss_b) 156 | 157 | elif loss_type.lower() == "nwj": 158 | def nwj(logits_concat, labels, scope=None): 159 | lbls = math_ops.cast(labels, logits_concat.dtype) 160 | """SHOULD I ADD STOP GRADIENT?""" 161 | bs = math_ops.cast(batch_size, logits_concat.dtype) 162 | pos_scores = tf.reduce_sum(lbls * logits_concat) / bs 163 | neg_scores = tf.reduce_sum((1 - lbls) * tf.math.exp(logits_concat - 1)) / ((2 * bs - 1) * bs) 164 | return - (pos_scores - neg_scores) 165 | 166 | loss_a = 0.5 * nwj(tf.concat([logits_ab, logits_aa], 1), labels) 167 | loss_b = 0.5 * nwj(tf.concat([logits_ba, logits_bb], 1), labels) 168 | tf.losses.add_loss(loss_a) 169 | tf.losses.add_loss(loss_b) 170 | elif loss_type.lower() == "dv": 171 | # Donsker and Varadhan 172 | def dv(logits_concat, labels, scope=None): 173 | lbls = math_ops.cast(labels, logits_concat.dtype) 174 | """SHOULD I ADD STOP GRADIENT?""" 175 | bs = math_ops.cast(batch_size, logits_concat.dtype) 176 | pos_scores = tf.reduce_sum(lbls * logits_concat) / bs 177 | neg_scores = tf.math.reduce_logsumexp((1 - lbls) * logits_concat) - tf.math.log((2 * bs - 1) * bs) 178 | return - (pos_scores - neg_scores) 179 | 180 | loss_a = 0.5 * dv(tf.concat([logits_ab, logits_aa], 1), labels) 181 | loss_b = 0.5 * dv(tf.concat([logits_ba, logits_bb], 1), labels) 182 | tf.losses.add_loss(loss_a) 183 | tf.losses.add_loss(loss_b) 184 | elif loss_type.lower() == "wpc": 185 | # Wasserstein Dependency Measure (i.e. Wasserstein Predictive Coding) 186 | # Adding soon 187 | pass # operation performed in model.py 188 | 189 | else: 190 | raise ValueError("Loss not implemented yet; only support {nce, chi, js, nwj}") 191 | 192 | loss = loss_a + loss_b 193 | return loss, logits_ab, labels 194 | 195 | 196 | def tpu_cross_replica_concat(tensor, tpu_context=None): 197 | """Reduce a concatenation of the `tensor` across TPU cores. 198 | 199 | Args: 200 | tensor: tensor to concatenate. 201 | tpu_context: A `TPUContext`. If not set, CPU execution is assumed. 202 | 203 | Returns: 204 | Tensor of the same rank as `tensor` with first dimension `num_replicas` 205 | times larger. 206 | """ 207 | if tpu_context is None or tpu_context.num_replicas <= 1: 208 | return tensor 209 | 210 | num_replicas = tpu_context.num_replicas 211 | 212 | with tf.name_scope('tpu_cross_replica_concat'): 213 | # This creates a tensor that is like the input tensor but has an added 214 | # replica dimension as the outermost dimension. On each replica it will 215 | # contain the local values and zeros for all other values that need to be 216 | # fetched from other replicas. 217 | ext_tensor = tf.scatter_nd( 218 | indices=[[xla.replica_id()]], 219 | updates=[tensor], 220 | shape=[num_replicas] + tensor.shape.as_list()) 221 | 222 | # As every value is only present on one replica and 0 in all others, adding 223 | # them all together will result in the full tensor on all replicas. 224 | ext_tensor = tf.tpu.cross_replica_sum(ext_tensor) 225 | 226 | # Flatten the replica dimension. 227 | # The first dimension size will be: tensor.shape[0] * num_replicas 228 | # Using [-1] trick to support also scalar input. 229 | return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:]) 230 | -------------------------------------------------------------------------------- /vision/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | tensorflow-gpu==1.15.2 3 | tensorflow-datasets==3.1.0 4 | tensorflow-hub==0.8.0 5 | -------------------------------------------------------------------------------- /vision/resnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Contains definitions for the post-activation form of Residual Networks. 17 | 18 | Residual networks (ResNets) were proposed in: 19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | from absl import flags 28 | import tensorflow.compat.v1 as tf 29 | 30 | from tensorflow.python.tpu import tpu_function # pylint: disable=g-direct-tensorflow-import 31 | 32 | 33 | FLAGS = flags.FLAGS 34 | BATCH_NORM_EPSILON = 1e-5 35 | 36 | 37 | class BatchNormalization(tf.layers.BatchNormalization): 38 | """Batch Normalization layer that supports cross replica computation on TPU. 39 | 40 | This class extends the keras.BatchNormalization implementation by supporting 41 | cross replica means and variances. The base class implementation only computes 42 | moments based on mini-batch per replica (TPU core). 43 | 44 | For detailed information of arguments and implementation, refer to: 45 | https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization 46 | """ 47 | 48 | def __init__(self, fused=False, **kwargs): 49 | """Builds the batch normalization layer. 50 | 51 | Arguments: 52 | fused: If `False`, use the system recommended implementation. Only support 53 | `False` in the current implementation. 54 | **kwargs: input augments that are forwarded to 55 | tf.layers.BatchNormalization. 56 | """ 57 | if fused in (True, None): 58 | raise ValueError('The TPU version of BatchNormalization does not support ' 59 | 'fused=True.') 60 | super(BatchNormalization, self).__init__(fused=fused, **kwargs) 61 | 62 | def _cross_replica_average(self, t): 63 | """Calculates the average value of input tensor across TPU replicas.""" 64 | num_shards = tpu_function.get_tpu_context().number_of_shards 65 | return tf.tpu.cross_replica_sum(t) / tf.cast(num_shards, t.dtype) 66 | 67 | def _moments(self, inputs, reduction_axes, keep_dims): 68 | """Compute the mean and variance: it overrides the original _moments.""" 69 | shard_mean, shard_variance = super(BatchNormalization, self)._moments( 70 | inputs, reduction_axes, keep_dims=keep_dims) 71 | 72 | num_shards = tpu_function.get_tpu_context().number_of_shards 73 | if num_shards and num_shards > 1: 74 | # Each group has multiple replicas: here we compute group mean/variance by 75 | # aggregating per-replica mean/variance. 76 | group_mean = self._cross_replica_average(shard_mean) 77 | group_variance = self._cross_replica_average(shard_variance) 78 | 79 | # Group variance needs to also include the difference between shard_mean 80 | # and group_mean. 81 | mean_distance = tf.square(group_mean - shard_mean) 82 | group_variance += self._cross_replica_average(mean_distance) 83 | return (group_mean, group_variance) 84 | else: 85 | return (shard_mean, shard_variance) 86 | 87 | 88 | def batch_norm_relu(inputs, is_training, relu=True, init_zero=False, 89 | center=True, scale=True, data_format='channels_last'): 90 | """Performs a batch normalization followed by a ReLU. 91 | 92 | Args: 93 | inputs: `Tensor` of shape `[batch, channels, ...]`. 94 | is_training: `bool` for whether the model is training. 95 | relu: `bool` if False, omits the ReLU operation. 96 | init_zero: `bool` if True, initializes scale parameter of batch 97 | normalization with 0 instead of 1 (default). 98 | center: `bool` whether to add learnable bias factor. 99 | scale: `bool` whether to add learnable scaling factor. 100 | data_format: `str` either "channels_first" for `[batch, channels, height, 101 | width]` or "channels_last for `[batch, height, width, channels]`. 102 | 103 | Returns: 104 | A normalized `Tensor` with the same `data_format`. 105 | """ 106 | if init_zero: 107 | gamma_initializer = tf.zeros_initializer() 108 | else: 109 | gamma_initializer = tf.ones_initializer() 110 | 111 | if data_format == 'channels_first': 112 | axis = 1 113 | else: 114 | axis = -1 115 | 116 | if FLAGS.global_bn: 117 | bn_foo = BatchNormalization( 118 | axis=axis, 119 | momentum=FLAGS.batch_norm_decay, 120 | epsilon=BATCH_NORM_EPSILON, 121 | center=center, 122 | scale=scale, 123 | fused=False, 124 | gamma_initializer=gamma_initializer) 125 | inputs = bn_foo(inputs, training=is_training) 126 | else: 127 | inputs = tf.layers.batch_normalization( 128 | inputs=inputs, 129 | axis=axis, 130 | momentum=FLAGS.batch_norm_decay, 131 | epsilon=BATCH_NORM_EPSILON, 132 | center=center, 133 | scale=scale, 134 | training=is_training, 135 | fused=True, 136 | gamma_initializer=gamma_initializer) 137 | 138 | if relu: 139 | inputs = tf.nn.relu(inputs) 140 | return inputs 141 | 142 | 143 | def dropblock(net, is_training, keep_prob, dropblock_size, 144 | data_format='channels_last'): 145 | """DropBlock: a regularization method for convolutional neural networks. 146 | 147 | DropBlock is a form of structured dropout, where units in a contiguous 148 | region of a feature map are dropped together. DropBlock works better than 149 | dropout on convolutional layers due to the fact that activation units in 150 | convolutional layers are spatially correlated. 151 | See https://arxiv.org/pdf/1810.12890.pdf for details. 152 | 153 | Args: 154 | net: `Tensor` input tensor. 155 | is_training: `bool` for whether the model is training. 156 | keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock. "None" 157 | means no DropBlock. 158 | dropblock_size: `int` size of blocks to be dropped by DropBlock. 159 | data_format: `str` either "channels_first" for `[batch, channels, height, 160 | width]` or "channels_last for `[batch, height, width, channels]`. 161 | Returns: 162 | A version of input tensor with DropBlock applied. 163 | Raises: 164 | if width and height of the input tensor are not equal. 165 | """ 166 | 167 | if not is_training or keep_prob is None: 168 | return net 169 | 170 | tf.logging.info('Applying DropBlock: dropblock_size {}, net.shape {}'.format( 171 | dropblock_size, net.shape)) 172 | 173 | if data_format == 'channels_last': 174 | _, width, height, _ = net.get_shape().as_list() 175 | else: 176 | _, _, width, height = net.get_shape().as_list() 177 | if width != height: 178 | raise ValueError('Input tensor with width!=height is not supported.') 179 | 180 | dropblock_size = min(dropblock_size, width) 181 | # seed_drop_rate is the gamma parameter of DropBlcok. 182 | seed_drop_rate = (1.0 - keep_prob) * width**2 / dropblock_size**2 / ( 183 | width - dropblock_size + 1)**2 184 | 185 | # Forces the block to be inside the feature map. 186 | w_i, h_i = tf.meshgrid(tf.range(width), tf.range(width)) 187 | valid_block_center = tf.logical_and( 188 | tf.logical_and(w_i >= int(dropblock_size // 2), 189 | w_i < width - (dropblock_size - 1) // 2), 190 | tf.logical_and(h_i >= int(dropblock_size // 2), 191 | h_i < width - (dropblock_size - 1) // 2)) 192 | 193 | valid_block_center = tf.expand_dims(valid_block_center, 0) 194 | valid_block_center = tf.expand_dims( 195 | valid_block_center, -1 if data_format == 'channels_last' else 0) 196 | 197 | randnoise = tf.random_uniform(net.shape, dtype=tf.float32) 198 | block_pattern = (1 - tf.cast(valid_block_center, dtype=tf.float32) + tf.cast( 199 | (1 - seed_drop_rate), dtype=tf.float32) + randnoise) >= 1 200 | block_pattern = tf.cast(block_pattern, dtype=tf.float32) 201 | 202 | if dropblock_size == width: 203 | block_pattern = tf.reduce_min( 204 | block_pattern, 205 | axis=[1, 2] if data_format == 'channels_last' else [2, 3], 206 | keepdims=True) 207 | else: 208 | if data_format == 'channels_last': 209 | ksize = [1, dropblock_size, dropblock_size, 1] 210 | else: 211 | ksize = [1, 1, dropblock_size, dropblock_size] 212 | block_pattern = -tf.nn.max_pool( 213 | -block_pattern, ksize=ksize, strides=[1, 1, 1, 1], padding='SAME', 214 | data_format='NHWC' if data_format == 'channels_last' else 'NCHW') 215 | 216 | percent_ones = tf.cast(tf.reduce_sum((block_pattern)), tf.float32) / tf.cast( 217 | tf.size(block_pattern), tf.float32) 218 | 219 | net = net / tf.cast(percent_ones, net.dtype) * tf.cast( 220 | block_pattern, net.dtype) 221 | return net 222 | 223 | 224 | def fixed_padding(inputs, kernel_size, data_format='channels_last'): 225 | """Pads the input along the spatial dimensions independently of input size. 226 | 227 | Args: 228 | inputs: `Tensor` of size `[batch, channels, height, width]` or 229 | `[batch, height, width, channels]` depending on `data_format`. 230 | kernel_size: `int` kernel size to be used for `conv2d` or max_pool2d` 231 | operations. Should be a positive integer. 232 | data_format: `str` either "channels_first" for `[batch, channels, height, 233 | width]` or "channels_last for `[batch, height, width, channels]`. 234 | 235 | Returns: 236 | A padded `Tensor` of the same `data_format` with size either intact 237 | (if `kernel_size == 1`) or padded (if `kernel_size > 1`). 238 | """ 239 | pad_total = kernel_size - 1 240 | pad_beg = pad_total // 2 241 | pad_end = pad_total - pad_beg 242 | if data_format == 'channels_first': 243 | padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], 244 | [pad_beg, pad_end], [pad_beg, pad_end]]) 245 | else: 246 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], 247 | [pad_beg, pad_end], [0, 0]]) 248 | 249 | return padded_inputs 250 | 251 | 252 | def conv2d_fixed_padding(inputs, filters, kernel_size, strides, 253 | data_format='channels_last'): 254 | """Strided 2-D convolution with explicit padding. 255 | 256 | The padding is consistent and is based only on `kernel_size`, not on the 257 | dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). 258 | 259 | Args: 260 | inputs: `Tensor` of size `[batch, channels, height_in, width_in]`. 261 | filters: `int` number of filters in the convolution. 262 | kernel_size: `int` size of the kernel to be used in the convolution. 263 | strides: `int` strides of the convolution. 264 | data_format: `str` either "channels_first" for `[batch, channels, height, 265 | width]` or "channels_last for `[batch, height, width, channels]`. 266 | 267 | Returns: 268 | A `Tensor` of shape `[batch, filters, height_out, width_out]`. 269 | """ 270 | if strides > 1: 271 | inputs = fixed_padding(inputs, kernel_size, data_format=data_format) 272 | 273 | return tf.layers.conv2d( 274 | inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides, 275 | padding=('SAME' if strides == 1 else 'VALID'), use_bias=False, 276 | kernel_initializer=tf.variance_scaling_initializer(), 277 | data_format=data_format) 278 | 279 | 280 | def sk_conv2d(inputs, filters, strides, sk_ratio, min_dim=32, 281 | is_training=True, data_format='channels_last'): 282 | """Selective kernel convolutional layer (https://arxiv.org/abs/1903.06586).""" 283 | channel_axis = 1 if data_format == 'channels_first' else 3 284 | pooling_axes = [2, 3] if data_format == 'channels_first' else [1, 2] 285 | 286 | # Two stream convs (using split and both are 3x3). 287 | inputs = conv2d_fixed_padding( 288 | inputs=inputs, filters=2 * filters, kernel_size=3, strides=strides, 289 | data_format=data_format) 290 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format) 291 | inputs = tf.stack(tf.split(inputs, num_or_size_splits=2, axis=channel_axis)) 292 | 293 | # Mixing weights for two streams. 294 | mid_dim = max(int(filters * sk_ratio), min_dim) 295 | global_features = tf.reduce_mean( 296 | tf.reduce_sum(inputs, axis=0), pooling_axes, keepdims=True) 297 | global_features = tf.layers.conv2d( 298 | inputs=global_features, filters=mid_dim, kernel_size=1, strides=1, 299 | kernel_initializer=tf.variance_scaling_initializer(), 300 | use_bias=False, data_format=data_format) 301 | global_features = batch_norm_relu( 302 | global_features, is_training, data_format=data_format) 303 | mixing = tf.layers.conv2d( 304 | inputs=global_features, filters=2 * filters, kernel_size=1, strides=1, 305 | kernel_initializer=tf.variance_scaling_initializer(), 306 | use_bias=False, data_format=data_format) 307 | mixing = tf.stack(tf.split(mixing, num_or_size_splits=2, axis=channel_axis)) 308 | mixing = tf.nn.softmax(mixing, axis=0) 309 | 310 | return tf.reduce_sum(inputs * mixing, axis=0) 311 | 312 | 313 | def se_layer(inputs, filters, se_ratio, data_format='channels_last'): 314 | """Squeeze and Excitation layer (https://arxiv.org/abs/1709.01507).""" 315 | if se_ratio <= 0: 316 | return inputs 317 | se_reduce = tf.layers.Conv2D( 318 | max(1, int(filters * se_ratio)), 319 | kernel_size=[1, 1], 320 | strides=[1, 1], 321 | kernel_initializer=tf.variance_scaling_initializer(), 322 | padding='same', 323 | data_format=data_format, 324 | use_bias=True) 325 | se_expand = tf.layers.Conv2D( 326 | inputs.shape[-1], 327 | kernel_size=[1, 1], 328 | strides=[1, 1], 329 | kernel_initializer=tf.variance_scaling_initializer(), 330 | padding='same', 331 | data_format=data_format, 332 | use_bias=True) 333 | 334 | spatial_dims = [2, 3] if data_format == 'channels_first' else [1, 2] 335 | se_tensor = tf.reduce_mean( 336 | inputs, spatial_dims, keepdims=True) 337 | se_tensor = se_expand(tf.nn.relu(se_reduce(se_tensor))) 338 | return tf.sigmoid(se_tensor) * inputs 339 | 340 | 341 | def residual_block(inputs, filters, is_training, strides, 342 | use_projection=False, data_format='channels_last', 343 | dropblock_keep_prob=None, dropblock_size=None): 344 | """Standard building block for residual networks with BN after convolutions. 345 | 346 | Args: 347 | inputs: `Tensor` of size `[batch, channels, height, width]`. 348 | filters: `int` number of filters for the first two convolutions. Note that 349 | the third and final convolution will use 4 times as many filters. 350 | is_training: `bool` for whether the model is in training. 351 | strides: `int` block stride. If greater than 1, this block will ultimately 352 | downsample the input. 353 | use_projection: `bool` for whether this block should use a projection 354 | shortcut (versus the default identity shortcut). This is usually `True` 355 | for the first block of a block group, which may change the number of 356 | filters and the resolution. 357 | data_format: `str` either "channels_first" for `[batch, channels, height, 358 | width]` or "channels_last for `[batch, height, width, channels]`. 359 | dropblock_keep_prob: unused; needed to give method same signature as other 360 | blocks 361 | dropblock_size: unused; needed to give method same signature as other 362 | blocks 363 | Returns: 364 | The output `Tensor` of the block. 365 | """ 366 | del dropblock_keep_prob 367 | del dropblock_size 368 | shortcut = inputs 369 | if use_projection: 370 | # Projection shortcut in first layer to match filters and strides 371 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187) 372 | if strides > 1: 373 | inputs = fixed_padding(inputs, 2, data_format) 374 | inputs = tf.layers.average_pooling2d( 375 | inputs, pool_size=2, strides=strides, 376 | padding='SAME' if strides == 1 else 'VALID', data_format=data_format) 377 | shortcut = conv2d_fixed_padding( 378 | inputs=inputs, filters=filters, kernel_size=1, strides=1, 379 | data_format=data_format) 380 | else: 381 | shortcut = conv2d_fixed_padding( 382 | inputs=inputs, filters=filters, kernel_size=1, strides=strides, 383 | data_format=data_format) 384 | shortcut = batch_norm_relu(shortcut, is_training, relu=False, 385 | data_format=data_format) 386 | 387 | inputs = conv2d_fixed_padding( 388 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 389 | data_format=data_format) 390 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format) 391 | 392 | inputs = conv2d_fixed_padding( 393 | inputs=inputs, filters=filters, kernel_size=3, strides=1, 394 | data_format=data_format) 395 | inputs = batch_norm_relu(inputs, is_training, relu=False, init_zero=True, 396 | data_format=data_format) 397 | 398 | if FLAGS.se_ratio > 0: 399 | inputs = se_layer(inputs, filters, FLAGS.se_ratio, data_format=data_format) 400 | 401 | return tf.nn.relu(inputs + shortcut) 402 | 403 | 404 | def bottleneck_block(inputs, filters, is_training, strides, 405 | use_projection=False, data_format='channels_last', 406 | dropblock_keep_prob=None, dropblock_size=None): 407 | """Bottleneck block variant for residual networks with BN after convolutions. 408 | 409 | Args: 410 | inputs: `Tensor` of size `[batch, channels, height, width]`. 411 | filters: `int` number of filters for the first two convolutions. Note that 412 | the third and final convolution will use 4 times as many filters. 413 | is_training: `bool` for whether the model is in training. 414 | strides: `int` block stride. If greater than 1, this block will ultimately 415 | downsample the input. 416 | use_projection: `bool` for whether this block should use a projection 417 | shortcut (versus the default identity shortcut). This is usually `True` 418 | for the first block of a block group, which may change the number of 419 | filters and the resolution. 420 | data_format: `str` either "channels_first" for `[batch, channels, height, 421 | width]` or "channels_last for `[batch, height, width, channels]`. 422 | dropblock_keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock. 423 | "None" means no DropBlock. 424 | dropblock_size: `int` size parameter of DropBlock. Will not be used if 425 | dropblock_keep_prob is "None". 426 | 427 | Returns: 428 | The output `Tensor` of the block. 429 | """ 430 | shortcut = inputs 431 | if use_projection: 432 | # Projection shortcut only in first block within a group. Bottleneck blocks 433 | # end with 4 times the number of filters. 434 | filters_out = 4 * filters 435 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187) 436 | if strides > 1: 437 | shortcut = fixed_padding(inputs, 2, data_format) 438 | else: 439 | shortcut = inputs 440 | shortcut = tf.layers.average_pooling2d( 441 | shortcut, pool_size=2, strides=strides, 442 | padding='SAME' if strides == 1 else 'VALID', data_format=data_format) 443 | shortcut = conv2d_fixed_padding( 444 | inputs=shortcut, filters=filters_out, kernel_size=1, strides=1, 445 | data_format=data_format) 446 | else: 447 | shortcut = conv2d_fixed_padding( 448 | inputs=inputs, filters=filters_out, kernel_size=1, strides=strides, 449 | data_format=data_format) 450 | shortcut = batch_norm_relu(shortcut, is_training, relu=False, 451 | data_format=data_format) 452 | shortcut = dropblock( 453 | shortcut, is_training=is_training, data_format=data_format, 454 | keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size) 455 | 456 | inputs = conv2d_fixed_padding( 457 | inputs=inputs, filters=filters, kernel_size=1, strides=1, 458 | data_format=data_format) 459 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format) 460 | inputs = dropblock( 461 | inputs, is_training=is_training, data_format=data_format, 462 | keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size) 463 | 464 | if FLAGS.sk_ratio > 0: 465 | inputs = sk_conv2d( 466 | inputs, filters, strides, FLAGS.sk_ratio, 467 | is_training=is_training, data_format=data_format) 468 | else: 469 | inputs = conv2d_fixed_padding( 470 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 471 | data_format=data_format) 472 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format) 473 | inputs = dropblock( 474 | inputs, is_training=is_training, data_format=data_format, 475 | keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size) 476 | 477 | inputs = conv2d_fixed_padding( 478 | inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, 479 | data_format=data_format) 480 | inputs = batch_norm_relu(inputs, is_training, relu=False, init_zero=True, 481 | data_format=data_format) 482 | inputs = dropblock( 483 | inputs, is_training=is_training, data_format=data_format, 484 | keep_prob=dropblock_keep_prob, dropblock_size=dropblock_size) 485 | 486 | if FLAGS.se_ratio > 0: 487 | inputs = se_layer(inputs, filters, FLAGS.se_ratio, data_format=data_format) 488 | 489 | return tf.nn.relu(inputs + shortcut) 490 | 491 | 492 | def block_group(inputs, filters, block_fn, blocks, strides, is_training, name, 493 | data_format='channels_last', dropblock_keep_prob=None, 494 | dropblock_size=None): 495 | """Creates one group of blocks for the ResNet model. 496 | 497 | Args: 498 | inputs: `Tensor` of size `[batch, channels, height, width]`. 499 | filters: `int` number of filters for the first convolution of the layer. 500 | block_fn: `function` for the block to use within the model 501 | blocks: `int` number of blocks contained in the layer. 502 | strides: `int` stride to use for the first convolution of the layer. If 503 | greater than 1, this layer will downsample the input. 504 | is_training: `bool` for whether the model is training. 505 | name: `str`name for the Tensor output of the block layer. 506 | data_format: `str` either "channels_first" for `[batch, channels, height, 507 | width]` or "channels_last for `[batch, height, width, channels]`. 508 | dropblock_keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock. 509 | "None" means no DropBlock. 510 | dropblock_size: `int` size parameter of DropBlock. Will not be used if 511 | dropblock_keep_prob is "None". 512 | 513 | Returns: 514 | The output `Tensor` of the block layer. 515 | """ 516 | # Only the first block per block_group uses projection shortcut and strides. 517 | inputs = block_fn(inputs, filters, is_training, strides, 518 | use_projection=True, data_format=data_format, 519 | dropblock_keep_prob=dropblock_keep_prob, 520 | dropblock_size=dropblock_size) 521 | 522 | for _ in range(1, blocks): 523 | inputs = block_fn(inputs, filters, is_training, 1, 524 | data_format=data_format, 525 | dropblock_keep_prob=dropblock_keep_prob, 526 | dropblock_size=dropblock_size) 527 | 528 | return tf.identity(inputs, name) 529 | 530 | 531 | def resnet_v1_generator(block_fn, layers, width_multiplier, 532 | cifar_stem=False, data_format='channels_last', 533 | dropblock_keep_probs=None, dropblock_size=None): 534 | """Generator for ResNet v1 models. 535 | 536 | Args: 537 | block_fn: `function` for the block to use within the model. Either 538 | `residual_block` or `bottleneck_block`. 539 | layers: list of 4 `int`s denoting the number of blocks to include in each 540 | of the 4 block groups. Each group consists of blocks that take inputs of 541 | the same resolution. 542 | width_multiplier: `int` width multiplier for network. 543 | cifar_stem: `bool` If True, use a 3x3 conv without strides or pooling as 544 | stem. 545 | data_format: `str` either "channels_first" for `[batch, channels, height, 546 | width]` or "channels_last for `[batch, height, width, channels]`. 547 | dropblock_keep_probs: `list` of 4 elements denoting keep_prob of DropBlock 548 | for each block group. None indicates no DropBlock for the corresponding 549 | block group. 550 | dropblock_size: `int`: size parameter of DropBlock. 551 | 552 | Returns: 553 | Model `function` that takes in `inputs` and `is_training` and returns the 554 | output `Tensor` of the ResNet model. 555 | 556 | Raises: 557 | if dropblock_keep_probs is not 'None' or a list with len 4. 558 | """ 559 | if dropblock_keep_probs is None: 560 | dropblock_keep_probs = [None] * 4 561 | if not isinstance(dropblock_keep_probs, 562 | list) or len(dropblock_keep_probs) != 4: 563 | raise ValueError('dropblock_keep_probs is not valid:', dropblock_keep_probs) 564 | 565 | def model(inputs, is_training): 566 | """Creation of the model graph.""" 567 | if cifar_stem: 568 | inputs = conv2d_fixed_padding( 569 | inputs=inputs, filters=64 * width_multiplier, kernel_size=3, 570 | strides=1, data_format=data_format) 571 | inputs = tf.identity(inputs, 'initial_conv') 572 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format) 573 | inputs = tf.identity(inputs, 'initial_max_pool') 574 | else: 575 | if FLAGS.sk_ratio > 0: # Use ResNet-D (https://arxiv.org/abs/1812.01187) 576 | inputs = conv2d_fixed_padding( 577 | inputs=inputs, filters=64 * width_multiplier // 2, kernel_size=3, 578 | strides=2, data_format=data_format) 579 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format) 580 | inputs = conv2d_fixed_padding( 581 | inputs=inputs, filters=64 * width_multiplier // 2, kernel_size=3, 582 | strides=1, data_format=data_format) 583 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format) 584 | inputs = conv2d_fixed_padding( 585 | inputs=inputs, filters=64 * width_multiplier, kernel_size=3, 586 | strides=1, data_format=data_format) 587 | else: 588 | inputs = conv2d_fixed_padding( 589 | inputs=inputs, filters=64 * width_multiplier, kernel_size=7, 590 | strides=2, data_format=data_format) 591 | inputs = tf.identity(inputs, 'initial_conv') 592 | inputs = batch_norm_relu(inputs, is_training, data_format=data_format) 593 | 594 | inputs = tf.layers.max_pooling2d( 595 | inputs=inputs, pool_size=3, strides=2, padding='SAME', 596 | data_format=data_format) 597 | inputs = tf.identity(inputs, 'initial_max_pool') 598 | 599 | def filter_trainable_variables(trainable_variables, after_block): 600 | """Add new trainable variables for the immediate precedent block.""" 601 | if after_block == 0: 602 | trainable_variables[after_block] = tf.trainable_variables() 603 | else: 604 | trainable_variables[after_block] = [] 605 | for var in tf.trainable_variables(): 606 | to_keep = True 607 | for j in range(after_block): 608 | if var in trainable_variables[j]: 609 | to_keep = False 610 | break 611 | if to_keep: 612 | trainable_variables[after_block].append(var) 613 | 614 | def add_to_collection(trainable_variables, prefix): 615 | """Put variables into graph collection.""" 616 | for after_block, variables in trainable_variables.items(): 617 | collection = prefix + str(after_block) 618 | for var in variables: 619 | tf.add_to_collection(collection, var) 620 | 621 | trainable_variables = {} 622 | filter_trainable_variables(trainable_variables, after_block=0) 623 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 0: 624 | inputs = tf.stop_gradient(inputs) 625 | 626 | inputs = block_group( 627 | inputs=inputs, filters=64 * width_multiplier, block_fn=block_fn, 628 | blocks=layers[0], strides=1, is_training=is_training, 629 | name='block_group1', data_format=data_format, 630 | dropblock_keep_prob=dropblock_keep_probs[0], 631 | dropblock_size=dropblock_size) 632 | 633 | filter_trainable_variables(trainable_variables, after_block=1) 634 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 1: 635 | inputs = tf.stop_gradient(inputs) 636 | 637 | inputs = block_group( 638 | inputs=inputs, filters=128 * width_multiplier, block_fn=block_fn, 639 | blocks=layers[1], strides=2, is_training=is_training, 640 | name='block_group2', data_format=data_format, 641 | dropblock_keep_prob=dropblock_keep_probs[1], 642 | dropblock_size=dropblock_size) 643 | 644 | filter_trainable_variables(trainable_variables, after_block=2) 645 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 2: 646 | inputs = tf.stop_gradient(inputs) 647 | 648 | inputs = block_group( 649 | inputs=inputs, filters=256 * width_multiplier, block_fn=block_fn, 650 | blocks=layers[2], strides=2, is_training=is_training, 651 | name='block_group3', data_format=data_format, 652 | dropblock_keep_prob=dropblock_keep_probs[2], 653 | dropblock_size=dropblock_size) 654 | 655 | filter_trainable_variables(trainable_variables, after_block=3) 656 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 3: 657 | inputs = tf.stop_gradient(inputs) 658 | 659 | inputs = block_group( 660 | inputs=inputs, filters=512 * width_multiplier, block_fn=block_fn, 661 | blocks=layers[3], strides=2, is_training=is_training, 662 | name='block_group4', data_format=data_format, 663 | dropblock_keep_prob=dropblock_keep_probs[3], 664 | dropblock_size=dropblock_size) 665 | 666 | filter_trainable_variables(trainable_variables, after_block=4) 667 | if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block == 4: 668 | inputs = tf.stop_gradient(inputs) 669 | 670 | if data_format == 'channels_last': 671 | inputs = tf.reduce_mean(inputs, [1, 2]) 672 | else: 673 | inputs = tf.reduce_mean(inputs, [2, 3]) 674 | inputs = tf.identity(inputs, 'final_avg_pool') 675 | 676 | # filter_trainable_variables(trainable_variables, after_block=5) 677 | add_to_collection(trainable_variables, 'trainable_variables_inblock_') 678 | 679 | return inputs 680 | 681 | return model 682 | 683 | 684 | def resnet_v1(resnet_depth, width_multiplier, 685 | cifar_stem=False, data_format='channels_last', 686 | dropblock_keep_probs=None, dropblock_size=None): 687 | """Returns the ResNet model for a given size and number of output classes.""" 688 | model_params = { 689 | 18: {'block': residual_block, 'layers': [2, 2, 2, 2]}, 690 | 34: {'block': residual_block, 'layers': [3, 4, 6, 3]}, 691 | 50: {'block': bottleneck_block, 'layers': [3, 4, 6, 3]}, 692 | 101: {'block': bottleneck_block, 'layers': [3, 4, 23, 3]}, 693 | 152: {'block': bottleneck_block, 'layers': [3, 8, 36, 3]}, 694 | 200: {'block': bottleneck_block, 'layers': [3, 24, 36, 3]} 695 | } 696 | 697 | if resnet_depth not in model_params: 698 | raise ValueError('Not a valid resnet_depth:', resnet_depth) 699 | 700 | params = model_params[resnet_depth] 701 | return resnet_v1_generator( 702 | params['block'], params['layers'], width_multiplier, 703 | cifar_stem=cifar_stem, 704 | dropblock_keep_probs=dropblock_keep_probs, 705 | dropblock_size=dropblock_size, 706 | data_format=data_format) 707 | -------------------------------------------------------------------------------- /vision/results_summary.txt: -------------------------------------------------------------------------------- 1 | TEMP=256 2 | 3 | gs://martin_ma_mql_simclr/nce_BS_4096_EPOCH_100_TEMP_0.1_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_true_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/hub/28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7348, label_top_5_accuracy = 0.91076, loss = 1.4509399, regularization_loss = 0.0 4 | 5 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/hub/28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71788, label_top_5_accuracy = 0.90118, loss = 1.3814622, regularization_loss = 0.0 6 | 7 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7124, label_top_5_accuracy = 0.89796, loss = 1.4208169, regularization_loss = 0.0 8 | 9 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/hub/28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71834, label_top_5_accuracy = 0.9028, loss = 1.3136246, regularization_loss = 0.0 10 | 11 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_512_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.56738, label_top_5_accuracy = 0.8002, loss = 1.8758307, regularization_loss = 0.0 12 | 13 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_1e-4_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71534, label_top_5_accuracy = 0.90094, loss = 1.3881, regularization_loss = 0.0 14 | 15 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_1e-3_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71874, label_top_5_accuracy = 0.90276, loss = 1.48331, regularization_loss = 0.0 16 | 17 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_1e-2_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72452, label_top_5_accuracy = 0.90286, loss = 1.564714, regularization_loss = 0.0 18 | 19 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_1e-1_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71364, label_top_5_accuracy = 0.89566, loss = 1.5915436, regularization_loss = 0.0 20 | 21 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_0.5_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71914, label_top_5_accuracy = 0.89728, loss = 1.5467246, regularization_loss = 0.0 22 | 23 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_0.01_GAMMA_2.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7223, label_top_5_accuracy = 0.90068, loss = 1.5355219, regularization_loss = 0.0 24 | 25 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7223, label_top_5_accuracy = 0.90238, loss = 1.5636464, regularization_loss = 0.0 26 | 27 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72394, label_top_5_accuracy = 0.90424, loss = 1.5786924, regularization_loss = 0.0 28 | 29 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73042, label_top_5_accuracy = 0.90804, loss = 1.4838624, regularization_loss = 0.0 30 | 31 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_2.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72558, label_top_5_accuracy = 0.90198, loss = 1.5553308, regularization_loss = 0.0 32 | 33 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_5.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7226, label_top_5_accuracy = 0.9008, loss = 1.5155789, regularization_loss = 0.0 34 | 35 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.02_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7227, label_top_5_accuracy = 0.9026, loss = 1.5762701, regularization_loss = 0.0 36 | 37 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.05_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71726, label_top_5_accuracy = 0.89686, loss = 1.6178995, regularization_loss = 0.0 38 | 39 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.005_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71776, label_top_5_accuracy = 0.9007, loss = 1.5661987, regularization_loss = 0.0 40 | 41 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_5.0_BETA_0.05_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71656, label_top_5_accuracy = 0.8976, loss = 1.5743265, regularization_loss = 0.0 42 | 43 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_1.5_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72322, label_top_5_accuracy = 0.90264, loss = 1.6190277,regularization_loss = 0.0 44 | 45 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_64_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73028, label_top_5_accuracy = 0.9073, loss = 1.459571, regularization_loss = 0.0 46 | 47 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_256_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.5_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72904, label_top_5_accuracy = 0.90646, loss = 1.5504616,regularization_loss = 0.0 48 | 49 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_5.0_BETA_0.1_GAMMA_3.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71782, label_top_5_accuracy = 0.89956, loss = 1.5116348, regularization_loss = 0.0 50 | 51 | 52 | ======= 53 | 54 | 55 | ======= 56 | 57 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_64_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_0.1_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72056, label_top_5_accuracy = 0.90078, loss = 1.5019414, regularization_loss = 0.0 58 | 59 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_10.0_BETA_0.01_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72932, label_top_5_accuracy = 0.90458, loss = 1.570985, regularization_loss = 0.0 60 | 61 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_10.0_BETA_0.02_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7112, label_top_5_accuracy = 0.89654, loss = 1.5867603, regularization_loss = 0.0 62 | 63 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.3_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74106, label_top_5_accuracy = 0.91256, loss = 1.4070679, regularization_loss = 0.0 64 | 65 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74244, label_top_5_accuracy = 0.91488, loss = 1.4087038, regularization_loss = 0.0 66 | 67 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_86_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74068, label_top_5_accuracy = 0.91394, loss = 1.4385886, regularization_loss = 0.0 68 | 69 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7392, label_top_5_accuracy = 0.9137, loss = 1.4215096, regularization_loss = 0.0 70 | 71 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.005_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74006, label_top_5_accuracy = 0.91448, loss = 1.398139, regularization_loss = 0.0 72 | 73 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.2_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.738, label_top_5_accuracy = 0.91306, loss = 1.422334, regularization_loss = 0.0 74 | 75 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.02_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73756, label_top_5_accuracy = 0.9119, loss = 1.4739765, regularization_loss = 0.0 76 | 77 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_8_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7421, label_top_5_accuracy = 0.91478, loss = 1.434248, regularization_loss = 0.0 78 | 79 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_8_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.005_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74214, label_top_5_accuracy = 0.9144, loss = 1.3988103, regularization_loss = 0.0 80 | 81 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_8_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72656, label_top_5_accuracy = 0.9075, loss = 1.480991, regularization_loss = 0.0 82 | 83 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74076, label_top_5_accuracy = 0.91372, loss = 1.5113194, regularization_loss = 0.0 84 | 85 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_2_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.69452, label_top_5_accuracy = 0.88356, loss = 1.7041928, regularization_loss = 0.0 86 | 87 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7435, label_top_5_accuracy = 0.91578, loss = 1.41444, regularization_loss = 0.0 88 | 89 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_64_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7422, label_top_5_accuracy = 0.91474, loss = 1.3790753, regularization_loss = 0.0 90 | 91 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_1.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74324, label_top_5_accuracy = 0.917, loss = 1.3407627, regularization_loss = 0.0 92 | 93 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.2_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74182, label_top_5_accuracy = 0.91422, loss = 1.4183415, regularization_loss = 0.0 94 | 95 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.05_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74266, label_top_5_accuracy = 0.91512, loss = 1.3939042, regularization_loss = 0.0 96 | 97 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.02_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73006, label_top_5_accuracy = 0.90818, loss = 1.471376, regularization_loss = 0.0 98 | 99 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.05_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71944, label_top_5_accuracy = 0.9003, loss = 1.4997969, regularization_loss = 0.0 100 | 101 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.005_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7416, label_top_5_accuracy = 0.91422, loss = 1.4269485, regularization_loss = 0.0 102 | 103 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.2_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73842, label_top_5_accuracy = 0.91254, loss = 1.464798, regularization_loss = 0.0 104 | 105 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.0_BETA_0.01_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74102, label_top_5_accuracy = 0.91458, loss = 1.5262817, regularization_loss = 0.0 106 | 107 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.25_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74326, label_top_5_accuracy = 0.91406, loss = 1.4183799, regularization_loss = 0.0 108 | 109 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.2_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74182, label_top_5_accuracy = 0.91422, loss = 1.4183415, regularization_loss = 0.0 110 | 111 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_false_ALPHA_0.3_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74494, label_top_5_accuracy = 0.91484, loss = 1.3978903, regularization_loss = 0.0 112 | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 113 | gs://martin_ma_mql_simclr/nce_BS_4096_EPOCH_100_TEMP_0.1_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_true_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77824, label_top_5_accuracy = 0.93058, loss = 1.0428659, regularization_loss = 0.0 114 | 115 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/hub/28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7463, label_top_5_accuracy = 0.91458, loss = 1.1594483, regularization_loss = 0.0 116 | 117 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_3.0_BETA_1e-2_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7459, label_top_5_accuracy = 0.91434, loss = 1.1565692, regularization_loss = 0.0 118 | 119 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_128_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_1e-3_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74422, label_top_5_accuracy = 0.91424, loss = 1.2344401, regularization_loss = 0.0 120 | 121 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_10.0_BETA_1e-1_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74468, label_top_5_accuracy = 0.9085, loss = 1.3773991, regularization_loss = 0.0 122 | 123 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1e-3_BETA_1e-3_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74616, label_top_5_accuracy = 0.91402, loss = 1.3737655, regularization_loss = 0.0 124 | 125 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77374, label_top_5_accuracy = 0.92736, loss = 1.035313, regularization_loss = 0.0 126 | 127 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0_BETA_0.02_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.76864, label_top_5_accuracy = 0.92676, loss = 1.1647376, regularization_loss = 0.0 128 | 129 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.1_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7727, label_top_5_accuracy = 0.92882, loss = 1.1740562, regularization_loss = 0.0 130 | 131 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_8_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.005_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75432, label_top_5_accuracy = 0.91638, loss = 1.1587994, regularization_loss = 0.0 132 | 133 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.5_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77156, label_top_5_accuracy = 0.92894, loss = 1.0430934, regularization_loss = 0.0 134 | 135 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_2_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.5_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71732, label_top_5_accuracy = 0.89936, loss = 1.3376576, regularization_loss = 0.0 136 | 137 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77336, label_top_5_accuracy = 0.92894, loss = 1.1666732, regularization_loss = 0.0 138 | 139 | s://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_64_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77256, label_top_5_accuracy = 0.92906, loss = 0.98218614, regularization_loss = 0.0 140 | 141 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_16_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.2_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7731, label_top_5_accuracy = 0.92808, loss = 1.0666859, regularization_loss = 0.0 142 | 143 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_2_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.001, label_top_5_accuracy = 0.005, loss = 6.9076014, regularization_loss = 0.0 144 | 145 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.0_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74514, label_top_5_accuracy = 0.91204, loss = 1.3100942, regularization_loss = 0.0 146 | 147 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.1_GAMMA_1.0_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74134, label_top_5_accuracy = 0.90802, loss = 1.3841316, regularization_loss = 0.0 148 | 149 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74152, label_top_5_accuracy = 0.91008, loss = 1.2542875, regularization_loss = 0.0 150 | 151 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77338, label_top_5_accuracy = 0.92794, loss = 1.0368634, regularization_loss = 0.0 152 | 153 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_2.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77472, label_top_5_accuracy = 0.9304, loss = 1.0788878, regularization_loss = 0.0 154 | 155 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_3.0_BETA_0.01_GAMMA_0.1_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77372, label_top_5_accuracy = 0.929, loss = 1.0592294, regularization_loss = 0.0 156 | 157 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77942, label_top_5_accuracy = 0.9322, loss = 0.9496544, regularization_loss = 0.0 158 | 159 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.1_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77858, label_top_5_accuracy = 0.9315, loss = 1.0732164, regularization_loss = 0.0 160 | 161 | s://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.5_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77686, label_top_5_accuracy = 0.93116, loss = 1.1159751, regularization_loss = 0.0 162 | 163 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.2_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77866, label_top_5_accuracy = 0.93184, loss = 1.02233, regularization_loss = 0.0 164 | 165 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.0001_GAMMA_0.001_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73946, label_top_5_accuracy = 0.90888, loss = 1.4767729, regularization_loss = 0.0 166 | 167 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.0001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.68206, label_top_5_accuracy = 0.8747, loss = 1.7220646, regularization_loss = 0.0 168 | 169 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.1_BETA_0.0005_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73238, label_top_5_accuracy = 0.90698, loss = 1.3052814, regularization_loss = 0.0 170 | 171 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.05_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73986, label_top_5_accuracy = 0.90982, loss = 1.2386664, regularization_loss = 0.0 172 | 173 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.005_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77938, label_top_5_accuracy = 0.93426, loss = 1.0879052, regularization_loss = 0.0 174 | 175 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.0005_GAMMA_0.005_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77476, label_top_5_accuracy = 0.92926, loss = 1.1837187, regularization_loss = 0.0 176 | 177 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.1_GAMMA_0.005_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75216, label_top_5_accuracy = 0.91376, loss = 1.3076787, regularization_loss = 0.0 178 | 179 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.74924, label_top_5_accuracy = 0.91502, loss = 1.1803943, regularization_loss = 0.0 180 | 181 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.005_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75048, label_top_5_accuracy = 0.91442, loss = 1.1888653, regularization_loss = 0.0 182 | 183 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75092, label_top_5_accuracy = 0.91302, loss = 1.1644344, regularization_loss = 0.0 184 | 185 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.5_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.75062, label_top_5_accuracy = 0.91444, loss = 1.2249312, regularization_loss = 0.0 186 | 187 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.0001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.70622, label_top_5_accuracy = 0.88564, loss = 1.4944282, regularization_loss = 0.0 188 | 189 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_10_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.001_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.73806, label_top_5_accuracy = 0.90776, loss = 1.2411014, regularization_loss = 0.0 190 | 191 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_100_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.1_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7527, label_top_5_accuracy = 0.91318, loss = 1.2551881, regularization_loss = 0.0 192 | 193 | #--------------------------------------------------------------------------------------------- 194 | 195 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.001_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.70986, label_top_5_accuracy = 0.89282, loss = 1.4696014, regularization_loss = 0.0 196 | 197 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.0_BETA_0.0_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7138, label_top_5_accuracy = 0.89694, loss = 1.439495, regularization_loss = 0.0 198 | 199 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.01_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7002, label_top_5_accuracy = 0.88744, loss = 1.4174829, regularization_loss = 0.0 200 | 201 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.0_GAMMA_0.001_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.68674, label_top_5_accuracy = 0.88098, loss = 1.5044969, regularization_loss = 0.0 202 | 203 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.01_BETA_0.01_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.7002, label_top_5_accuracy = 0.88744, loss = 1.4174829, regularization_loss = 0.0 204 | 205 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.00_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71774, label_top_5_accuracy = 0.8976, loss = 1.3535299, regularization_loss = 0.0 206 | 207 | gs://martin_ma_mql_simclr/SHORT_chi_BS_4096_EPOCH_1_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_1.0_BETA_0.1_GAMMA_0.01_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.71948, label_top_5_accuracy = 0.89782, loss = 1.397318, regularization_loss = 0.0 208 | 209 | #------------------------------------------------------------------------------------------------ 210 | gs://martin_ma_mql_simclr/chi_BS_4096_EPOCH_300_TEMP_32_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_2_RESNETDEP_152_HIDDENNORM_false_ALPHA_0.3_BETA_0.01_GAMMA_0.001_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.77076, label_top_5_accuracy = 0.9261, loss = 1.1930635, regularization_loss = 0.0 211 | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 212 | gs://martin_ma_mql_simclr/dv_BS_4096_EPOCH_100_TEMP_0.1_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_true_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.72214, label_top_5_accuracy = 0.90264, loss = 1.5218498, regularization_loss = 0.0 213 | 214 | gs://martin_ma_mql_simclr/nwj_BS_4096_EPOCH_100_TEMP_0.1_LR_0.1_LRSCALE_sqrt_WDECAY_1e-4_DATASET_imagenet2012_IMAGE_SIZE_224_SKRATIO_0.0625_WIDTHMUL_1_RESNETDEP_50_HIDDENNORM_true_ft_BS_4096_FINETUNE_AFTER_BLOCK_0_LR_0.16_WD_0_EPOCH_90_WARMUP_EPOCHS_0/model.ckpt-28151, contrast_loss = 0.0, contrastive_top_1_accuracy = 1.0, contrastive_top_5_accuracy = 1.0, global_step = 28151, label_top_1_accuracy = 0.0052, label_top_5_accuracy = 0.02212, loss = 6.673894, regularization_loss = 0.0 215 | 216 | 217 | -------------------------------------------------------------------------------- /vision/run.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The SimCLR Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific simclr governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """The main training pipeline.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import json 23 | import math 24 | import os 25 | from absl import app 26 | from absl import flags 27 | 28 | import resnet 29 | import data as data_lib 30 | import model as model_lib 31 | import model_util as model_util 32 | 33 | import tensorflow.compat.v1 as tf 34 | import tensorflow_datasets as tfds 35 | import tensorflow_hub as hub 36 | 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | 41 | flags.DEFINE_float( 42 | 'learning_rate', 0.3, 43 | 'Initial learning rate per batch size of 256.') 44 | 45 | flags.DEFINE_enum( 46 | 'learning_rate_scaling', 'linear', ['linear', 'sqrt'], 47 | 'How to scale the learning rate as a function of batch size.') 48 | 49 | flags.DEFINE_float( 50 | 'warmup_epochs', 10, 51 | 'Number of epochs of warmup.') 52 | 53 | flags.DEFINE_float( 54 | 'weight_decay', 1e-4, 55 | 'Amount of weight decay to use.') 56 | 57 | flags.DEFINE_float( 58 | 'batch_norm_decay', 0.9, 59 | 'Batch norm decay parameter.') 60 | 61 | flags.DEFINE_integer( 62 | 'train_batch_size', 512, 63 | 'Batch size for training.') 64 | 65 | flags.DEFINE_string( 66 | 'train_split', 'train', 67 | 'Split for training.') 68 | 69 | flags.DEFINE_integer( 70 | 'train_epochs', 100, 71 | 'Number of epochs to train for.') 72 | 73 | flags.DEFINE_integer( 74 | 'train_steps', 0, 75 | 'Number of steps to train for. If provided, overrides train_epochs.') 76 | 77 | flags.DEFINE_integer( 78 | 'eval_batch_size', 256, 79 | 'Batch size for eval.') 80 | 81 | flags.DEFINE_integer( 82 | 'train_summary_steps', 100, 83 | 'Steps before saving training summaries. If 0, will not save.') 84 | 85 | flags.DEFINE_integer( 86 | 'checkpoint_epochs', 1, 87 | 'Number of epochs between checkpoints/summaries.') 88 | 89 | flags.DEFINE_integer( 90 | 'checkpoint_steps', 0, 91 | 'Number of steps between checkpoints/summaries. If provided, overrides ' 92 | 'checkpoint_epochs.') 93 | 94 | flags.DEFINE_string( 95 | 'eval_split', 'validation', 96 | 'Split for evaluation.') 97 | 98 | flags.DEFINE_string( 99 | 'dataset', 'imagenet2012', 100 | 'Name of a dataset.') 101 | 102 | flags.DEFINE_bool( 103 | 'cache_dataset', False, 104 | 'Whether to cache the entire dataset in memory. If the dataset is ' 105 | 'ImageNet, this is a very bad idea, but for smaller datasets it can ' 106 | 'improve performance.') 107 | 108 | flags.DEFINE_enum( 109 | 'mode', 'train', ['train', 'eval', 'train_then_eval'], 110 | 'Whether to perform training or evaluation.') 111 | 112 | flags.DEFINE_enum( 113 | 'train_mode', 'pretrain', ['pretrain', 'finetune'], 114 | 'The train mode controls different objectives and trainable components.') 115 | 116 | flags.DEFINE_string( 117 | 'checkpoint', None, 118 | 'Loading from the given checkpoint for continued training or fine-tuning.') 119 | 120 | flags.DEFINE_string( 121 | 'variable_schema', '?!global_step', 122 | 'This defines whether some variable from the checkpoint should be loaded.') 123 | 124 | flags.DEFINE_bool( 125 | 'zero_init_logits_layer', False, 126 | 'If True, zero initialize layers after avg_pool for supervised learning.') 127 | 128 | flags.DEFINE_integer( 129 | 'fine_tune_after_block', -1, 130 | 'The layers after which block that we will fine-tune. -1 means fine-tuning ' 131 | 'everything. 0 means fine-tuning after stem block. 4 means fine-tuning ' 132 | 'just the linera head.') 133 | 134 | flags.DEFINE_string( 135 | 'master', None, 136 | 'Address/name of the TensorFlow master to use. By default, use an ' 137 | 'in-process master.') 138 | 139 | flags.DEFINE_string( 140 | 'model_dir', None, 141 | 'Model directory for training.') 142 | 143 | flags.DEFINE_string( 144 | 'data_dir', None, 145 | 'Directory where dataset is stored.') 146 | 147 | flags.DEFINE_bool( 148 | 'use_tpu', True, 149 | 'Whether to run on TPU.') 150 | 151 | tf.flags.DEFINE_string( 152 | 'tpu_name', None, 153 | 'The Cloud TPU to use for training. This should be either the name ' 154 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 ' 155 | 'url.') 156 | 157 | tf.flags.DEFINE_string( 158 | 'tpu_zone', None, 159 | '[Optional] GCE zone where the Cloud TPU is located in. If not ' 160 | 'specified, we will attempt to automatically detect the GCE project from ' 161 | 'metadata.') 162 | 163 | tf.flags.DEFINE_string( 164 | 'gcp_project', None, 165 | '[Optional] Project name for the Cloud TPU-enabled project. If not ' 166 | 'specified, we will attempt to automatically detect the GCE project from ' 167 | 'metadata.') 168 | 169 | flags.DEFINE_enum( 170 | 'optimizer', 'lars', ['momentum', 'adam', 'lars'], 171 | 'Optimizer to use.') 172 | 173 | flags.DEFINE_float( 174 | 'momentum', 0.9, 175 | 'Momentum parameter.') 176 | 177 | flags.DEFINE_string( 178 | 'eval_name', None, 179 | 'Name for eval.') 180 | 181 | flags.DEFINE_integer( 182 | 'keep_checkpoint_max', 5, 183 | 'Maximum number of checkpoints to keep.') 184 | 185 | flags.DEFINE_integer( 186 | 'keep_hub_module_max', 1, 187 | 'Maximum number of Hub modules to keep.') 188 | 189 | flags.DEFINE_float( 190 | 'temperature', 0.1, 191 | 'Temperature parameter for contrastive loss.') 192 | 193 | flags.DEFINE_boolean( 194 | 'hidden_norm', True, 195 | 'Temperature parameter for contrastive loss.') 196 | 197 | flags.DEFINE_enum( 198 | 'proj_head_mode', 'nonlinear', ['none', 'linear', 'nonlinear'], 199 | 'How the head projection is done.') 200 | 201 | flags.DEFINE_integer( 202 | 'proj_out_dim', 128, 203 | 'Number of head projection dimension.') 204 | 205 | flags.DEFINE_integer( 206 | 'num_proj_layers', 3, 207 | 'Number of non-linear head layers.') 208 | 209 | flags.DEFINE_integer( 210 | 'ft_proj_selector', 0, 211 | 'Which layer of the projection head to use during fine-tuning. ' 212 | '0 means throwing away the projection head, and -1 means the final layer.') 213 | 214 | flags.DEFINE_boolean( 215 | 'global_bn', True, 216 | 'Whether to aggregate BN statistics across distributed cores.') 217 | 218 | flags.DEFINE_integer( 219 | 'width_multiplier', 1, 220 | 'Multiplier to change width of network.') 221 | 222 | flags.DEFINE_integer( 223 | 'resnet_depth', 50, 224 | 'Depth of ResNet.') 225 | 226 | flags.DEFINE_float( 227 | 'sk_ratio', 0., 228 | 'If it is bigger than 0, it will enable SK. Recommendation: 0.0625.') 229 | 230 | flags.DEFINE_float( 231 | 'se_ratio', 0., 232 | 'If it is bigger than 0, it will enable SE.') 233 | 234 | flags.DEFINE_integer( 235 | 'image_size', 224, 236 | 'Input image size.') 237 | 238 | flags.DEFINE_float( 239 | 'color_jitter_strength', 1.0, 240 | 'The strength of color jittering.') 241 | 242 | flags.DEFINE_boolean( 243 | 'use_blur', True, 244 | 'Whether or not to use Gaussian blur for augmentation during pretraining.') 245 | 246 | ################################# 247 | ### Chi loss implementation 248 | ################################# 249 | 250 | flags.DEFINE_enum( 251 | 'loss_type', 'chi', ['chi', 'nce', 'dv', 'nwj', 'js', 'wpc'], 252 | 'Types of loss to use.') 253 | 254 | flags.DEFINE_float( 255 | 'alpha', 0.0, 256 | 'Alpha of chi loss.') 257 | 258 | flags.DEFINE_float( 259 | 'beta', 0.0, 260 | 'Beta of chi loss.') 261 | 262 | flags.DEFINE_float( 263 | 'gamma', 1.0, 264 | 'Gamma of chi loss.') 265 | 266 | flags.DEFINE_float( 267 | 'gradient_penalty_weight', 0.0, 268 | 'Parameter for gradient penalty (WPC only).') 269 | 270 | 271 | 272 | def build_hub_module(model, num_classes, global_step, checkpoint_path): 273 | """Create TF-Hub module.""" 274 | 275 | tags_and_args = [ 276 | # The default graph is built with batch_norm, dropout etc. in inference 277 | # mode. This graph version is good for inference, not training. 278 | ([], {'is_training': False}), 279 | # A separate "train" graph builds batch_norm, dropout etc. in training 280 | # mode. 281 | (['train'], {'is_training': True}), 282 | ] 283 | 284 | def module_fn(is_training): 285 | """Function that builds TF-Hub module.""" 286 | endpoints = {} 287 | inputs = tf.placeholder( 288 | tf.float32, [None, None, None, 3]) 289 | with tf.variable_scope('base_model', reuse=tf.AUTO_REUSE): 290 | hiddens = model(inputs, is_training) 291 | for v in ['initial_conv', 'initial_max_pool', 'block_group1', 292 | 'block_group2', 'block_group3', 'block_group4', 293 | 'final_avg_pool']: 294 | endpoints[v] = tf.get_default_graph().get_tensor_by_name( 295 | 'base_model/{}:0'.format(v)) 296 | if FLAGS.train_mode == 'pretrain': 297 | hiddens_proj = model_util.projection_head(hiddens, is_training) 298 | endpoints['proj_head_input'] = hiddens 299 | endpoints['proj_head_output'] = hiddens_proj 300 | else: 301 | logits_sup = model_util.supervised_head( 302 | hiddens, num_classes, is_training) 303 | endpoints['logits_sup'] = logits_sup 304 | hub.add_signature(inputs=dict(images=inputs), 305 | outputs=dict(endpoints, default=hiddens)) 306 | 307 | # Drop the non-supported non-standard graph collection. 308 | drop_collections = ['trainable_variables_inblock_%d'%d for d in range(6)] 309 | spec = hub.create_module_spec(module_fn, tags_and_args, drop_collections) 310 | hub_export_dir = os.path.join(FLAGS.model_dir, 'hub') 311 | checkpoint_export_dir = os.path.join(hub_export_dir, str(global_step)) 312 | if tf.io.gfile.exists(checkpoint_export_dir): 313 | # Do not save if checkpoint already saved. 314 | tf.io.gfile.rmtree(checkpoint_export_dir) 315 | spec.export( 316 | checkpoint_export_dir, 317 | checkpoint_path=checkpoint_path, 318 | name_transform_fn=None) 319 | 320 | if FLAGS.keep_hub_module_max > 0: 321 | # Delete old exported Hub modules. 322 | exported_steps = [] 323 | for subdir in tf.io.gfile.listdir(hub_export_dir): 324 | if not subdir.isdigit(): 325 | continue 326 | exported_steps.append(int(subdir)) 327 | exported_steps.sort() 328 | for step_to_delete in exported_steps[:-FLAGS.keep_hub_module_max]: 329 | tf.io.gfile.rmtree(os.path.join(hub_export_dir, str(step_to_delete))) 330 | 331 | 332 | def perform_evaluation(estimator, input_fn, eval_steps, model, num_classes, 333 | checkpoint_path=None): 334 | """Perform evaluation. 335 | 336 | Args: 337 | estimator: TPUEstimator instance. 338 | input_fn: Input function for estimator. 339 | eval_steps: Number of steps for evaluation. 340 | model: Instance of transfer_learning.models.Model. 341 | num_classes: Number of classes to build model for. 342 | checkpoint_path: Path of checkpoint to evaluate. 343 | 344 | Returns: 345 | result: A Dict of metrics and their values. 346 | """ 347 | if not checkpoint_path: 348 | checkpoint_path = estimator.latest_checkpoint() 349 | result = estimator.evaluate( 350 | input_fn, eval_steps, checkpoint_path=checkpoint_path, 351 | name=FLAGS.eval_name) 352 | 353 | # Record results as JSON. 354 | result_json_path = os.path.join(FLAGS.model_dir, 'result.json') 355 | with tf.io.gfile.GFile(result_json_path, 'w') as f: 356 | json.dump({k: float(v) for k, v in result.items()}, f) 357 | result_json_path = os.path.join( 358 | FLAGS.model_dir, 'result_%d.json'%result['global_step']) 359 | with tf.io.gfile.GFile(result_json_path, 'w') as f: 360 | json.dump({k: float(v) for k, v in result.items()}, f) 361 | flag_json_path = os.path.join(FLAGS.model_dir, 'flags.json') 362 | with tf.io.gfile.GFile(flag_json_path, 'w') as f: 363 | json.dump(FLAGS.flag_values_dict(), f) 364 | 365 | # Save Hub module. 366 | build_hub_module(model, num_classes, 367 | global_step=result['global_step'], 368 | checkpoint_path=checkpoint_path) 369 | 370 | return result 371 | 372 | 373 | def main(argv): 374 | if len(argv) > 1: 375 | raise app.UsageError('Too many command-line arguments.') 376 | 377 | # Enable training summary. 378 | if FLAGS.train_summary_steps > 0: 379 | tf.config.set_soft_device_placement(True) 380 | 381 | 382 | builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir) 383 | builder.download_and_prepare() 384 | num_train_examples = builder.info.splits[FLAGS.train_split].num_examples 385 | num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples 386 | num_classes = builder.info.features['label'].num_classes 387 | 388 | train_steps = model_util.get_train_steps(num_train_examples) 389 | eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size)) 390 | epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size)) 391 | 392 | resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay 393 | model = resnet.resnet_v1( 394 | resnet_depth=FLAGS.resnet_depth, 395 | width_multiplier=FLAGS.width_multiplier, 396 | cifar_stem=FLAGS.image_size <= 32) 397 | 398 | checkpoint_steps = ( 399 | FLAGS.checkpoint_steps or (FLAGS.checkpoint_epochs * epoch_steps)) 400 | 401 | cluster = None 402 | if FLAGS.use_tpu and FLAGS.master is None: 403 | if FLAGS.tpu_name: 404 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver( 405 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 406 | else: 407 | cluster = tf.distribute.cluster_resolver.TPUClusterResolver() 408 | tf.config.experimental_connect_to_cluster(cluster) 409 | tf.tpu.experimental.initialize_tpu_system(cluster) 410 | 411 | default_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1 412 | sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.SLICED 413 | run_config = tf.estimator.tpu.RunConfig( 414 | tpu_config=tf.estimator.tpu.TPUConfig( 415 | iterations_per_loop=checkpoint_steps, 416 | eval_training_input_configuration=sliced_eval_mode 417 | if FLAGS.use_tpu else default_eval_mode), 418 | model_dir=FLAGS.model_dir, 419 | save_summary_steps=checkpoint_steps, 420 | save_checkpoints_steps=checkpoint_steps, 421 | keep_checkpoint_max=FLAGS.keep_checkpoint_max, 422 | master=FLAGS.master, 423 | cluster=cluster) 424 | estimator = tf.estimator.tpu.TPUEstimator( 425 | model_lib.build_model_fn(model, num_classes, num_train_examples), 426 | config=run_config, 427 | train_batch_size=FLAGS.train_batch_size, 428 | eval_batch_size=FLAGS.eval_batch_size, 429 | use_tpu=FLAGS.use_tpu) 430 | 431 | if FLAGS.mode == 'eval': 432 | for ckpt in tf.train.checkpoints_iterator( 433 | run_config.model_dir, min_interval_secs=15): 434 | try: 435 | result = perform_evaluation( 436 | estimator=estimator, 437 | input_fn=data_lib.build_input_fn(builder, False), 438 | eval_steps=eval_steps, 439 | model=model, 440 | num_classes=num_classes, 441 | checkpoint_path=ckpt) 442 | except tf.errors.NotFoundError: 443 | continue 444 | if result['global_step'] >= train_steps: 445 | return 446 | else: 447 | estimator.train( 448 | data_lib.build_input_fn(builder, True), max_steps=train_steps) 449 | if FLAGS.mode == 'train_then_eval': 450 | perform_evaluation( 451 | estimator=estimator, 452 | input_fn=data_lib.build_input_fn(builder, False), 453 | eval_steps=eval_steps, 454 | model=model, 455 | num_classes=num_classes) 456 | 457 | 458 | if __name__ == '__main__': 459 | tf.disable_v2_behavior() # Disable eager mode when running with TF2. 460 | app.run(main) 461 | -------------------------------------------------------------------------------- /vision/tpu_pretrain_finetune_resnet128.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | TPU_NAME= # enter your tpu name here 3 | STORAGE_BUCKET= # enter your storage bucket here 4 | DATA_DIR=$STORAGE_BUCKET/tensorflow_datasets 5 | 6 | # hyperparameters 7 | BATCH_SIZE=4096 8 | EPOCH=100 9 | TEMP=32 10 | LR=0.1 11 | LR_SCALE='sqrt' 12 | W_DECAY=1e-4 13 | DATASET='imagenet2012' 14 | IMAGE_SIZE=224 15 | SK_RATIO=0.0625 16 | WIDTH_MUL=2 17 | RESNET_DEPTH=152 18 | 19 | # different for different losses 20 | LOSS_TYPE='chi' 21 | HIDDEN_NORM=false 22 | if [ $LOSS_TYPE == 'chi' ] 23 | then 24 | # check hidden norm 25 | HIDDEN_NORM=false 26 | ALPHA=0.3 27 | BETA=0.001 28 | GAMMA=0.01 29 | if [ $TEMP -le 1 ] 30 | then 31 | exit 0 32 | fi 33 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}_ALPHA_${ALPHA}_BETA_${BETA}_GAMMA_${GAMMA}" 34 | python run.py --train_mode=pretrain \ 35 | --train_batch_size=$BATCH_SIZE \ 36 | --train_epochs=$EPOCH \ 37 | --temperature=$TEMP \ 38 | --learning_rate=$LR \ 39 | --learning_rate_scaling=$LR_SCALE \ 40 | --weight_decay=$W_DECAY \ 41 | --dataset=$DATASET \ 42 | --image_size=$IMAGE_SIZE \ 43 | --eval_split=validation \ 44 | --data_dir=$DATA_DIR \ 45 | --model_dir=$MODEL_DIR \ 46 | --use_tpu=True \ 47 | --tpu_name=$TPU_NAME \ 48 | --train_summary_steps=0 \ 49 | --sk_ratio $SK_RATIO \ 50 | --width_multiplier $WIDTH_MUL \ 51 | --resnet_depth $RESNET_DEPTH \ 52 | --loss_type $LOSS_TYPE \ 53 | --alpha=$ALPHA \ 54 | --beta=$BETA \ 55 | --gamma=$GAMMA \ 56 | --hidden_norm=$HIDDEN_NORM 57 | fi 58 | 59 | if [ $LOSS_TYPE == 'nce' ] # JS, WPC, etc 60 | then 61 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}" 62 | echo $MODEL_DIR 63 | python run.py --train_mode=pretrain \ 64 | --train_batch_size=$BATCH_SIZE \ 65 | --train_epochs=$EPOCH \ 66 | --temperature=$TEMP \ 67 | --learning_rate=$LR \ 68 | --learning_rate_scaling=$LR_SCALE \ 69 | --weight_decay=$W_DECAY \ 70 | --dataset=$DATASET \ 71 | --image_size=$IMAGE_SIZE \ 72 | --eval_split=validation \ 73 | --data_dir=$DATA_DIR \ 74 | --model_dir=$MODEL_DIR \ 75 | --use_tpu=True \ 76 | --tpu_name=$TPU_NAME \ 77 | --train_summary_steps=0 \ 78 | --sk_ratio $SK_RATIO \ 79 | --width_multiplier $WIDTH_MUL \ 80 | --resnet_depth $RESNET_DEPTH \ 81 | --loss_type $LOSS_TYPE \ 82 | --hidden_norm=$HIDDEN_NORM 83 | fi 84 | 85 | ############################################################################################## 86 | #####################Fine tune 87 | ############################################################################################## 88 | CHKPT_DIR=$MODEL_DIR 89 | FINETUNE_AFTER_BLOCK=0 90 | LR=0.16 91 | WD=0 92 | EPOCHS=90 93 | WARMUP_EPOCHS=0 94 | MODEL_DIR="${CHKPT_DIR}_ft_BS_${BATCH_SIZE}_FINETUNE_AFTER_BLOCK_${FINETUNE_AFTER_BLOCK}_LR_${LR}_WD_${WD}_EPOCH_${EPOCHS}_WARMUP_EPOCHS_${WARMUP_EPOCHS}" 95 | echo $MODEL_DIR 96 | if [ $LOSS_TYPE == "chi" ] 97 | then 98 | echo "Running chi" 99 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=imagenet2012 --image_size=224 --eval_split=validation --data_dir=gs://martin_ma_mql_simclr/tensorflow_datasets --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE --alpha $ALPHA --beta $BETA --gamma $GAMMA 100 | else 101 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=imagenet2012 --image_size=224 --eval_split=validation --data_dir=gs://martin_ma_mql_simclr/tensorflow_datasets --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE 102 | fi 103 | -------------------------------------------------------------------------------- /vision/tpu_pretrain_finetune_resnet50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | TPU_NAME= # enter your tpu name here 3 | STORAGE_BUCKET= # enter your storage bucket here 4 | DATA_DIR=$STORAGE_BUCKET/tensorflow_datasets 5 | 6 | # hyperparameters 7 | BATCH_SIZE=4096 8 | EPOCH=100 9 | TEMP=32 10 | LR=0.1 11 | LR_SCALE='sqrt' 12 | W_DECAY=1e-4 13 | DATASET='imagenet2012' 14 | IMAGE_SIZE=224 15 | SK_RATIO=0.0625 16 | WIDTH_MUL=1 17 | RESNET_DEPTH=50 18 | 19 | # different for different losses 20 | LOSS_TYPE='chi' 21 | HIDDEN_NORM=true 22 | if [ $LOSS_TYPE == 'chi' ] 23 | then 24 | # check hidden norm 25 | HIDDEN_NORM=false 26 | ALPHA=0.3 27 | BETA=0.01 28 | GAMMA=0.1 29 | if [ $TEMP -le 1 ] 30 | then 31 | exit 0 32 | fi 33 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}_ALPHA_${ALPHA}_BETA_${BETA}_GAMMA_${GAMMA}" 34 | echo $MODEL_DIR 35 | python run.py --train_mode=pretrain \ 36 | --train_batch_size=$BATCH_SIZE \ 37 | --train_epochs=$EPOCH \ 38 | --temperature=$TEMP \ 39 | --learning_rate=$LR \ 40 | --learning_rate_scaling=$LR_SCALE \ 41 | --weight_decay=$W_DECAY \ 42 | --dataset=$DATASET \ 43 | --image_size=$IMAGE_SIZE \ 44 | --eval_split=validation \ 45 | --data_dir=$DATA_DIR \ 46 | --model_dir=$MODEL_DIR \ 47 | --use_tpu=True \ 48 | --tpu_name=$TPU_NAME \ 49 | --train_summary_steps=0 \ 50 | --sk_ratio $SK_RATIO \ 51 | --width_multiplier $WIDTH_MUL \ 52 | --resnet_depth $RESNET_DEPTH \ 53 | --loss_type $LOSS_TYPE \ 54 | --alpha=$ALPHA \ 55 | --beta=$BETA \ 56 | --gamma=$GAMMA \ 57 | --hidden_norm=$HIDDEN_NORM 58 | fi 59 | 60 | if [ $LOSS_TYPE == 'nce' ] # JS, WPC, etc 61 | then 62 | MODEL_DIR=$STORAGE_BUCKET/"${LOSS_TYPE}_BS_${BATCH_SIZE}_EPOCH_${EPOCH}_TEMP_${TEMP}_LR_${LR}_LRSCALE_${LR_SCALE}_WDECAY_${W_DECAY}_DATASET_${DATASET}_IMAGE_SIZE_${IMAGE_SIZE}_SKRATIO_${SK_RATIO}_WIDTHMUL_${WIDTH_MUL}_RESNETDEP_${RESNET_DEPTH}_HIDDENNORM_${HIDDEN_NORM}" 63 | echo $MODEL_DIR 64 | python run.py --train_mode=pretrain \ 65 | --train_batch_size=$BATCH_SIZE \ 66 | --train_epochs=$EPOCH \ 67 | --temperature=$TEMP \ 68 | --learning_rate=$LR \ 69 | --learning_rate_scaling=$LR_SCALE \ 70 | --weight_decay=$W_DECAY \ 71 | --dataset=$DATASET \ 72 | --image_size=$IMAGE_SIZE \ 73 | --eval_split=validation \ 74 | --data_dir=$DATA_DIR \ 75 | --model_dir=$MODEL_DIR \ 76 | --use_tpu=True \ 77 | --tpu_name=$TPU_NAME \ 78 | --train_summary_steps=0 \ 79 | --sk_ratio $SK_RATIO \ 80 | --width_multiplier $WIDTH_MUL \ 81 | --resnet_depth $RESNET_DEPTH \ 82 | --loss_type $LOSS_TYPE \ 83 | --hidden_norm=$HIDDEN_NORM 84 | fi 85 | 86 | ############################################################################################## 87 | #####################Fine tune 88 | ############################################################################################## 89 | CHKPT_DIR=$MODEL_DIR 90 | FINETUNE_AFTER_BLOCK=0 91 | LR=0.16 92 | WD=0 93 | EPOCHS=90 94 | WARMUP_EPOCHS=0 95 | MODEL_DIR="${CHKPT_DIR}_ft_BS_${BATCH_SIZE}_FINETUNE_AFTER_BLOCK_${FINETUNE_AFTER_BLOCK}_LR_${LR}_WD_${WD}_EPOCH_${EPOCHS}_WARMUP_EPOCHS_${WARMUP_EPOCHS}" 96 | echo $MODEL_DIR 97 | if [ $LOSS_TYPE == "chi" ] 98 | then 99 | echo "Running chi" 100 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=imagenet2012 --image_size=224 --eval_split=validation --data_dir=gs://martin_ma_mql_simclr/tensorflow_datasets --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE --alpha $ALPHA --beta $BETA --gamma $GAMMA 101 | else 102 | python run.py --mode=train_then_eval --train_mode=finetune --fine_tune_after_block=$FINETUNE_AFTER_BLOCK --zero_init_logits_layer=True --variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' --global_bn=True --optimizer=momentum --learning_rate=$LR --weight_decay=$WD --train_epochs=$EPOCHS --train_batch_size=$BATCH_SIZE --warmup_epochs=$WARMUP_EPOCHS --dataset=imagenet2012 --image_size=224 --eval_split=validation --data_dir=gs://martin_ma_mql_simclr/tensorflow_datasets --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR --use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 --width_multiplier $WIDTH_MUL --resnet_depth $RESNET_DEPTH --sk_ratio $SK_RATIO --loss_type $LOSS_TYPE 103 | fi 104 | --------------------------------------------------------------------------------