├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── luckmatter
│   ├── README.md
│   ├── check_models.py
│   ├── cluster_utils.py
│   ├── model_gen.py
│   ├── recon_multilayer.py
│   ├── test_multilayer.py
│   ├── theory_utils.py
│   ├── utils_corrs.py
│   └── vis_corrs.py
├── ssl
│   ├── .gitignore
│   ├── README.md
│   ├── common_utils
│   │   ├── __init__.py
│   │   ├── assert_utils.py
│   │   ├── helper.py
│   │   ├── logger.py
│   │   ├── multi_counter.py
│   │   ├── saver.py
│   │   ├── stats.py
│   │   └── stopwatch.py
│   ├── hltm
│   │   ├── README.md
│   │   ├── conf
│   │   │   └── hltm.yaml
│   │   └── simCLR_hltm.py
│   └── real-dataset
│       ├── README.md
│       ├── bn_gen.py
│       ├── bn_gen_utils.py
│       ├── byol_trainer.py
│       ├── check_sweep.sh
│       ├── config
│       │   ├── bn_gen.yaml
│       │   ├── byol_config.yaml
│       │   ├── hydra
│       │   │   └── launcher
│       │   │       └── submitit.yaml
│       │   ├── relu_2layer.yaml
│       │   ├── sa.yaml
│       │   ├── sa_linear.yaml
│       │   └── test.yaml
│       ├── data
│       │   ├── gaussian_blur.py
│       │   ├── multi_view_data_injector.py
│       │   └── transforms.py
│       ├── linear_feature_eval.py
│       ├── loss
│       │   └── nt_xent.py
│       ├── main.py
│       ├── main_checkresult.py
│       ├── models
│       │   ├── mlp_head.py
│       │   └── resnet_base_network.py
│       ├── paths.txt
│       ├── relu_2layer.py
│       ├── requirement.txt
│       ├── self_attention.py
│       ├── self_attention_linear_test.py
│       ├── simclr_trainer.py
│       ├── test.py
│       ├── test2.py
│       ├── try_relu.py
│       └── utils.py
└── student_specialization
    ├── README.md
    ├── conf
    │   ├── config.yaml
    │   ├── config_multilayer.yaml
    │   └── hydra
    │       └── launcher
    │           ├── fairtask.yaml
    │           └── submitit.yaml
    ├── dataset.py
    ├── model_gen.py
    ├── recon_multilayer.py
    ├── recon_two_layer.py
    ├── stats_operator.py
    ├── teacher_tune.py
    ├── theory_utils.py
    ├── utils.py
    ├── utils_corrs.py
    ├── vis_corrs.py
    └── visualization
        ├── utils.py
        ├── visualize.py
        └── visualize_multi.py

/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.swp 3 | *.pdf 4 | */outputs/* 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: 
-------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to LuckMatters 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 
28 | 29 | ## Coding Style 30 | * 2 spaces for indentation rather than tabs 31 | * 80 character line length 32 | 33 | ## License 34 | By contributing to LuckMatters, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo contains the code for the following 6 papers: 2 | 3 | ## Analysis of Self-supervised learning (`./ssl`) 4 | **Understanding the Role of Nonlinearity in Training Dynamics of Contrastive Learning** 5 | 6 | Yuandong Tian 7 | 8 | [arXiv](https://arxiv.org/abs/2206.01342) 9 | 10 | **Understanding Deep Contrastive Learning via Coordinate-wise Optimization** 11 | 12 | Yuandong Tian 13 | 14 | [NeurIPS'22](https://arxiv.org/abs/2201.12680) Oral 15 | 16 | **Understanding self-supervised Learning Dynamics without Contrastive Pairs** 17 | 18 | Yuandong Tian, Xinlei Chen, Surya Ganguli 19 | 20 | ICML 2021 [link](https://arxiv.org/abs/2102.06810), *Outstanding Paper Honorable Mention* 21 | 22 | **Understanding Self-supervised Learning with Dual Deep Networks** 23 | 24 | Yuandong Tian, Lantao Yu, Xinlei Chen, Surya Ganguli 25 | 26 | arXiv [link](https://arxiv.org/abs/2010.00578) 27 | 28 | 29 | 30 | 31 | ## Teacher-student setting in supervised learning 32 | **Student Specialization in Deep ReLU Networks With Finite Width and Input Dimension** (`./student_specialization`) 33 | 34 | Yuandong Tian 35 | 36 | ICML 2020 [link](https://arxiv.org/abs/1909.13458) 37 | 38 | **Luck Matters: Understanding Training Dynamics of Deep ReLU Networks** (`./luckmatter`) 39 | 40 | Yuandong Tian, Tina Jiang, Qucheng Gong, Ari Morcos 41 | 42 | arXiv [link](https://arxiv.org/abs/1905.13405) 43 | -------------------------------------------------------------------------------- /luckmatter/README.md: 
-------------------------------------------------------------------------------- 1 | # LuckMatters 2 | Code for the DL theory of multilayer ReLU networks. 3 | 4 | Relevant paper: "Luck Matters: Understanding Training Dynamics of Deep ReLU Networks", arXiv [link](https://arxiv.org/abs/1905.13405). 5 | 6 | # Usage 7 | Uses an existing DL library (PyTorch). The lowest layer reaches a correlation of 0.999 within a few iterations; the second-lowest layer can reach ~0.95. 8 | ``` 9 | python recon_multilayer.py --data_std 10.0 --node_multi 10 --lr 0.05 --dataset gaussian --d_output 100 --seed 124 10 | ``` 11 | 12 | Matrix version to check the over-parameterization theorem (you should see that the relevant second-layer weights become zero). 13 | ``` 14 | python test_multilayer.py --perturb --node_multi 2 --lr 0.05 --init_std 0.1 --batchsize 64 --seed 232 --verbose 15 | ``` 16 | 17 | Check `W_row_norm`; we find that: 18 | ``` 19 | [1]: W_row_norm: tensor([1.2050e+00, 1.2196e+00, 1.1427e+00, 1.3761e+00, 1.1161e+00, 1.4610e+00, 20 | 1.1305e+00, 1.0719e+00, 1.1388e+00, 1.2870e+00, 1.2480e+00, 1.1709e+00, 21 | 1.2928e+00, 1.2677e+00, 1.2754e+00, 1.1399e+00, 1.1465e+00, 1.1292e+00, 22 | 1.4311e+00, 1.1534e+00, 1.1562e-04, 1.0990e-04, 9.2137e-05, 8.3408e-05, 23 | 1.2864e-04, 2.3824e-04, 1.0199e-04, 1.1282e-04, 1.1691e-04, 1.4917e-03, 24 | 1.5522e-04, 6.1745e-05, 1.1086e-04, 1.8588e-04, 1.1351e-04, 2.4844e-04, 25 | 1.3347e-04, 6.5837e-05, 1.5340e-03, 9.1208e-05, 4.2515e-05]) 26 | ``` 27 | 28 | Other usage: 29 | ----- 30 | 31 | Matrix version of backpropagation: 32 | ``` 33 | python test_multilayer.py --init_std 0.1 --lr 0.2 --seed 433 34 | ``` 35 | 36 | Precise gradient (single-sample gradient accumulation; very slow): 37 | ``` 38 | python test_multilayer.py --init_std 0.1 --lr 0.2 --seed 433 --use_accurate_grad 39 | ``` 40 | 41 | Note that 42 | 43 | 1. `data_std` needs to be 10 so that the generated dataset covers the corners (with `data_std = 1`, not all corners are covered and the correlation is low). 44 | 45 | 2.
`node_multi = 10` seems to be good enough. A larger `node_multi` makes convergence slower (in terms of steps). 46 | 47 | 3. More supervision definitely helps: the larger `d_output`, the better. `d_output = 10` also works (all correlations in the lowest layer also reach 0.999) but not as well as `d_output = 100`. 48 | 49 | 4. A high `lr` seems to make training unstable. 50 | 5. Adding `--normalize` makes it a bit worse. BatchNorm holds more secrets! 51 | 52 | # Visualization code 53 | Will be released soon. 54 | 55 | # License 56 | See LICENSE file. 57 | -------------------------------------------------------------------------------- /luckmatter/check_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from torch.nn import Conv2d, BatchNorm2d 9 | from torchvision.models import vgg, resnet 10 | 11 | def check_bias(model): 12 | for m in model.modules(): 13 | if isinstance(m, BatchNorm2d): 14 | n = m.bias.size(0) 15 | pos = (m.bias > 0).sum().item() 16 | neg = n - pos 17 | print("n: %d, >0: %.2f%% (%d) , <0: %.2f%% (%d)" % (n, pos * 100 / n, pos, neg * 100 / n, neg)) 18 | 19 | for model in ("vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"): 20 | print(model) 21 | m = getattr(vgg, model)(pretrained=True) 22 | check_bias(m) 23 | 24 | for model in ("resnet18", "resnet34", "resnet50", "resnet101"): 25 | print(model) 26 | m = getattr(resnet, model)(pretrained=True) 27 | check_bias(m) 28 | 29 | -------------------------------------------------------------------------------- /luckmatter/cluster_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import sys 9 | 10 | from datetime import datetime 11 | import pickle 12 | 13 | def signature(): 14 | return str(datetime.now()).replace(" ", "_").replace(":", "-").replace(".", "_") 15 | 16 | def sig(): 17 | return datetime.now().strftime("%m%d%y_%H%M%S_%f") 18 | 19 | def print_info(): 20 | print("cuda: " + os.environ.get("CUDA_VISIBLE_DEVICES", "")) 21 | 22 | def add_parser_argument(parser): 23 | parser.add_argument("--save_dir", type=str, default="./") 24 | 25 | def set_args(argv, args): 26 | cmdline = " ".join(argv) 27 | signature = sig() 28 | setattr(args, 'signature', signature) 29 | setattr(args, "cmdline", cmdline) 30 | 31 | def save_data(prefix, args, data): 32 | filename = f"{prefix}-{args.signature}.pickle" 33 | save_dir = os.path.join(args.save_dir, filename) 34 | print(f"Save to {save_dir}") 35 | pickle.dump(dict(data=data, args=args, save_dir=save_dir), open(save_dir, "wb")) 36 | 37 | 38 | -------------------------------------------------------------------------------- /luckmatter/model_gen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | import os 9 | 10 | import torch 11 | import torch.nn as nn 12 | import random 13 | from theory_utils import haar_measure, init_separate_w 14 | 15 | 16 | # Generate random orth matrix. 
17 | import numpy as np 18 | import math 19 | 20 | def get_aug_w(w): 21 | # w: [output_d, input_d] 22 | # aug_w: [output_d + 1, input_d + 1] 23 | output_d, input_d = w.weight.size() 24 | aug_w = torch.zeros( (output_d + 1, input_d + 1), dtype = w.weight.dtype, device = w.weight.device) 25 | aug_w[:output_d, :input_d] = w.weight.data 26 | aug_w[:output_d, input_d] = w.bias.data 27 | aug_w[output_d, input_d] = 1 28 | return aug_w 29 | 30 | def set_orth(layer): 31 | w = layer.weight 32 | orth = haar_measure(w.size(1)) 33 | w.data = torch.from_numpy(orth[:w.size(0), :w.size(1)].astype('f4')).cuda() 34 | 35 | def set_add_noise(layer, teacher_layer, perturb): 36 | layer.weight.data[:] = teacher_layer.weight.data[:] + torch.randn(teacher_layer.weight.size()).cuda() * perturb 37 | layer.bias.data[:] = teacher_layer.bias.data[:] + torch.randn(teacher_layer.bias.size()).cuda() * perturb 38 | 39 | def set_same_dir(layer, teacher_layer): 40 | norm = layer.weight.data.norm() 41 | r = norm / teacher_layer.weight.data.norm() 42 | layer.weight.data[:] = teacher_layer.weight.data * r 43 | layer.bias.data[:] = teacher_layer.bias.data * r 44 | 45 | def set_same_sign(layer, teacher_layer): 46 | sel = (teacher_layer.weight.data > 0) * (layer.weight.data < 0) + (teacher_layer.weight.data < 0) * (layer.weight.data > 0) 47 | layer.weight.data[sel] *= -1.0 48 | 49 | sel = (teacher_layer.bias.data > 0) * (layer.bias.data < 0) + (teacher_layer.bias.data < 0) * (layer.bias.data > 0) 50 | layer.bias.data[sel] *= -1.0 51 | 52 | def normalize_layer(layer): 53 | # [output, input] 54 | w = layer.weight.data 55 | for i in range(w.size(0)): 56 | norm = w[i].pow(2).sum().sqrt() + 1e-5 57 | w[i] /= norm 58 | if layer.bias is not None: 59 | layer.bias.data[i] /= norm 60 | 61 | def init_w(layer, use_sep=True): 62 | sz = layer.weight.size() 63 | output_d = sz[0] 64 | input_d = 1 65 | for s in sz[1:]: 66 | input_d *= s 67 | 68 | if use_sep: 69 | choices = [-0.5, -0.25, 0, 0.25, 0.5] 70 | 
70 | layer.weight.data[:] = torch.from_numpy(init_separate_w(output_d, input_d, choices)).view(*sz).cuda() 71 | if layer.bias is not None: 72 | layer.bias.data.uniform_(-.5, 0.5) 73 | 74 | def init_w2(w, multiplier=5): 75 | w.weight.data *= multiplier 76 | w.bias.data.normal_(0, std=1) 77 | # w.bias.data *= 5 78 | for i, ww in enumerate(w.weight.data): 79 | pos_ratio = (ww > 0.0).sum().item() / w.weight.size(1) - 0.5 80 | w.bias.data[i] -= pos_ratio 81 | 82 | 83 | class Model(nn.Module): 84 | def __init__(self, d, ks, d_output, multi=1, has_bn=True, has_bn_affine=True, has_bias=True, bn_before_relu=False): 85 | super(Model, self).__init__() 86 | self.d = d 87 | self.ks = ks 88 | self.has_bn = has_bn 89 | self.ws_linear = nn.ModuleList() 90 | self.ws_bn = nn.ModuleList() 91 | self.bn_before_relu = bn_before_relu 92 | last_k = d 93 | self.sizes = [d] 94 | 95 | for k in ks: 96 | k *= multi 97 | self.ws_linear.append(nn.Linear(last_k, k, bias=has_bias)) 98 | if has_bn: 99 | self.ws_bn.append(nn.BatchNorm1d(k, affine=has_bn_affine)) 100 | self.sizes.append(k) 101 | last_k = k 102 | 103 | self.final_w = nn.Linear(last_k, d_output, bias=has_bias) 104 | self.relu = nn.ReLU() 105 | 106 | self.sizes.append(d_output) 107 | 108 | def init_orth(self): 109 | for w in self.ws_linear: 110 | set_orth(w) 111 | set_orth(self.final_w) 112 | 113 | def set_teacher(self, teacher, perturb): 114 | for w_s, w_t in zip(self.ws_linear, teacher.ws_linear): 115 | set_add_noise(w_s, w_t, perturb) 116 | set_add_noise(self.final_w, teacher.final_w, perturb) 117 | 118 | def set_teacher_dir(self, teacher): 119 | for w_s, w_t in zip(self.ws_linear, teacher.ws_linear): 120 | set_same_dir(w_s, w_t) 121 | set_same_dir(self.final_w, teacher.final_w) 122 | 123 | def set_teacher_sign(self, teacher): 124 | for w_s, w_t in zip(self.ws_linear, teacher.ws_linear): 125 | set_same_sign(w_s, w_t) 126 | set_same_sign(self.final_w, teacher.final_w) 127 | 128 | def forward(self, x): 129 | hs = [] 130 | pre_bns = [] 131 | #bns = [] 132 | h = x 133 | for i in
range(len(self.ws_linear)): 134 | w = self.ws_linear[i] 135 | h = w(h) 136 | if self.bn_before_relu: 137 | pre_bns.append(h) 138 | if len(self.ws_bn) > 0: 139 | bn = self.ws_bn[i] 140 | h = bn(h) 141 | h = self.relu(h) 142 | else: 143 | h = self.relu(h) 144 | pre_bns.append(h) 145 | if len(self.ws_bn) > 0: 146 | bn = self.ws_bn[i] 147 | h = bn(h) 148 | hs.append(h) 149 | #bns.append(h) 150 | y = self.final_w(hs[-1]) 151 | return dict(hs=hs, pre_bns=pre_bns, y=y) 152 | 153 | def init_w(self, use_sep=True): 154 | for w in self.ws_linear: 155 | init_w(w, use_sep=use_sep) 156 | init_w(self.final_w, use_sep=use_sep) 157 | 158 | def reset_parameters(self): 159 | for w in self.ws_linear: 160 | w.reset_parameters() 161 | for w in self.ws_bn: 162 | w.reset_parameters() 163 | self.final_w.reset_parameters() 164 | 165 | def normalize(self): 166 | for w in self.ws_linear: 167 | normalize_layer(w) 168 | normalize_layer(self.final_w) 169 | 170 | def from_bottom_linear(self, j): 171 | if j < len(self.ws_linear): 172 | return self.ws_linear[j].weight.data 173 | elif j == len(self.ws_linear): 174 | return self.final_w.weight.data 175 | else: 176 | raise RuntimeError("j[%d] is out of bound! should be [0, %d]" % (j, len(self.ws_linear))) 177 | 178 | def from_bottom_aug_w(self, j): 179 | if j < len(self.ws_linear): 180 | return get_aug_w(self.ws_linear[j]) 181 | elif j == len(self.ws_linear): 182 | return get_aug_w(self.final_w) 183 | else: 184 | raise RuntimeError("j[%d] is out of bound!
should be [0, %d]" % (j, len(self.ws_linear))) 185 | 186 | def num_layers(self): 187 | return len(self.ws_linear) + 1 188 | 189 | def from_bottom_bn(self, j): 190 | assert j < len(self.ws_bn) 191 | return self.ws_bn[j] 192 | 193 | 194 | class ModelConv(nn.Module): 195 | def __init__(self, input_size, ks, d_output, multi=1, has_bn=True, bn_before_relu=False): 196 | super(ModelConv, self).__init__() 197 | self.ks = ks 198 | self.ws_linear = nn.ModuleList() 199 | self.ws_bn = nn.ModuleList() 200 | self.bn_before_relu = bn_before_relu 201 | 202 | init_k, h, w = input_size 203 | last_k = init_k 204 | 205 | for k in ks: 206 | k *= multi 207 | self.ws_linear.append(nn.Conv2d(last_k, k, 3)) 208 | if has_bn: 209 | self.ws_bn.append(nn.BatchNorm2d(k)) 210 | last_k = k 211 | h -= 2 212 | w -= 2 213 | 214 | self.final_w = nn.Linear(last_k * h * w, d_output) 215 | self.relu = nn.ReLU() 216 | 217 | def forward(self, x): 218 | hs = [] 219 | #bns = [] 220 | h = x 221 | for i in range(len(self.ws_linear)): 222 | w = self.ws_linear[i] 223 | h = w(h) 224 | if self.bn_before_relu: 225 | if len(self.ws_bn) > 0: 226 | bn = self.ws_bn[i] 227 | h = bn(h) 228 | h = self.relu(h) 229 | else: 230 | h = self.relu(h) 231 | if len(self.ws_bn) > 0: 232 | bn = self.ws_bn[i] 233 | h = bn(h) 234 | hs.append(h) 235 | #bns.append(h) 236 | h = hs[-1].view(h.size(0), -1) 237 | y = self.final_w(h) 238 | return dict(hs=hs, y=y) 239 | 240 | def init_w(self, use_sep=True): 241 | for w in self.ws_linear: 242 | init_w(w, use_sep=use_sep) 243 | init_w(self.final_w, use_sep=use_sep) 244 | 245 | def normalize(self): 246 | for w in self.ws_linear: 247 | normalize_layer(w) 248 | normalize_layer(self.final_w) 249 | 250 | def normalize_last(self): 251 | normalize_layer(self.final_w) 252 | 253 | def reset_parameters(self): 254 | for w in self.ws_linear: 255 | w.reset_parameters() 256 | for w in self.ws_bn: 257 | w.reset_parameters() 258 | self.final_w.reset_parameters() 259 | 260 | def from_bottom_linear(self, j): 261 | if
j < len(self.ws_linear): 262 | return self.ws_linear[j].weight.data 263 | elif j == len(self.ws_linear): 264 | return self.final_w.weight.data 265 | else: 266 | raise RuntimeError("j[%d] is out of bound! should be [0, %d]" % (j, len(self.ws_linear))) 267 | 268 | def num_layers(self): 269 | return len(self.ws_linear) + 1 270 | 271 | def from_bottom_bn(self, j): 272 | assert j < len(self.ws_bn) 273 | return self.ws_bn[j] 274 | 275 | def prune(net, ratios): 276 | # Prune the network and finetune. 277 | n = net.num_layers() 278 | # Compute the L1 norm and prune them globally 279 | masks = [] 280 | inactive_nodes = [] 281 | for i in range(1, n): 282 | W = net.from_bottom_linear(i) 283 | # Prune all input neurons 284 | input_dim = W.size(1) 285 | fc_to_conv = False 286 | 287 | if isinstance(net, ModelConv): 288 | if len(W.size()) == 4: 289 | # W: [output_filter, input_filter, x, y] 290 | w_norms = W.permute(1, 0, 2, 3).contiguous().view(W.size(1), -1).abs().mean(1) 291 | else: 292 | # The final FC layer; its input is an NCHW activation flattened channel-major. 293 | input_dim = net.from_bottom_linear(i - 1).size(0) 294 | W_reshaped = W.view(W.size(0), input_dim, -1) 295 | w_norms = W_reshaped.abs().mean(dim=(0, 2)) 296 | fc_to_conv = True 297 | else: 298 | # W: [output_dim, input_dim] 299 | w_norms = W.abs().mean(0) 300 | 301 | sorted_w, sorted_indices = w_norms.sort(0) 302 | n_pruned = int(input_dim * ratios[i - 1]) 303 | inactive_mask = sorted_indices[:n_pruned] 304 | 305 | m = W.clone().fill_(1.0) 306 | if fc_to_conv: 307 | m = m.view(m.size(0), input_dim, -1) 308 | m[:, inactive_mask, :] = 0 309 | m = m.view(W.size(0), W.size(1)) 310 | else: 311 | m[:, inactive_mask] = 0 312 | 313 | # Set the mask for the lower layer to zero.
314 | inactive_nodes.append(inactive_mask.cpu().tolist()) 315 | masks.append(m) 316 | 317 | return inactive_nodes, masks 318 | -------------------------------------------------------------------------------- /luckmatter/theory_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import random 9 | 10 | def haar_measure(n): 11 | '''Generate an n-by-n random orthogonal matrix distributed according to the Haar measure''' 12 | z = np.random.randn(n,n) 13 | q,r = np.linalg.qr(z) 14 | d = np.diag(r) 15 | ph = d/np.absolute(d) 16 | q = np.dot(q, np.diag(ph)) 17 | return q 18 | 19 | def init_separate_w(output_d, input_d, choices): 20 | existing_encoding = set() 21 | existing_encoding.add(tuple([0] * input_d)) 22 | 23 | w = np.zeros((output_d, input_d)) 24 | 25 | for i in range(output_d): 26 | while True: 27 | encoding = tuple( random.sample(choices, 1)[0] for j in range(input_d) ) 28 | if encoding not in existing_encoding: 29 | break 30 | for j in range(input_d): 31 | w[i, j] = encoding[j] 32 | existing_encoding.add(encoding) 33 | 34 | return w 35 | -------------------------------------------------------------------------------- /luckmatter/utils_corrs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | import torch 9 | 10 | def corrMat2corrIdx(score): 11 | ''' Given score[N, #candidate], 12 | for each sample sort the score (in descending order) 13 | and output corr_table[N, #dict(s_idx, score)] 14 | ''' 15 | sorted_score, sorted_indices = score.sort(1, descending=True) 16 | N = sorted_score.size(0) 17 | n_candidate = sorted_score.size(1) 18 | # Check the corresponding weights. 19 | # print("For each teacher node, sorted corr over all student nodes at layer = %d" % i) 20 | corr_table = [] 21 | for k in range(N): 22 | tt = [] 23 | for j in range(n_candidate): 24 | # Compare the upward weights. 25 | s_idx = int(sorted_indices[k][j]) 26 | score = float(sorted_score[k][j]) 27 | tt.append(dict(s_idx=s_idx, score=score)) 28 | corr_table.append(tt) 29 | return corr_table 30 | 31 | ''' 32 | def corrIdx2pickMat(corr_indices): 33 | return [ item[0]["s_idx"] for item in corr_indices ] 34 | 35 | def corrIndices2pickMats(corr_indices_list): 36 | return [ corrIdx2pickMat(corr_indices) for corr_indices in corr_indices_list ] 37 | ''' 38 | 39 | def act2corrMat(src, dst): 40 | ''' src[:, k], with k < K1 41 | dst[:, k'], with k' < K2 42 | output correlation score[K1, K2] 43 | ''' 44 | # K_src by K_dst 45 | if len(src.size()) == 3 and len(dst.size()) == 3: 46 | src = src.permute(0, 2, 1).contiguous().view(src.size(0) * src.size(2), -1) 47 | dst = dst.permute(0, 2, 1).contiguous().view(dst.size(0) * dst.size(2), -1) 48 | 49 | # conv activations. 50 | elif len(src.size()) == 4 and len(dst.size()) == 4: 51 | src = src.permute(0, 2, 3, 1).contiguous().view(src.size(0) * src.size(2) * src.size(3), -1) 52 | dst = dst.permute(0, 2, 3, 1).contiguous().view(dst.size(0) * dst.size(2) * dst.size(3), -1) 53 | 54 | # Subtract mean.
55 | src = src - src.mean(0, keepdim=True) 56 | dst = dst - dst.mean(0, keepdim=True) 57 | 58 | inner_prod = torch.mm(src.t(), dst) 59 | src_inv_norm = src.pow(2).sum(0).add_(1e-10).rsqrt().view(-1, 1) 60 | dst_inv_norm = dst.pow(2).sum(0).add_(1e-10).rsqrt().view(1, -1) 61 | 62 | return inner_prod * src_inv_norm * dst_inv_norm 63 | 64 | def acts2corrMats(hidden_t, hidden_s): 65 | # Match response 66 | ''' Output correlation matrix for each layer ''' 67 | corrs = [] 68 | for t, s in zip(hidden_t, hidden_s): 69 | corr = act2corrMat(t.data, s.data) 70 | corrs.append(corr) 71 | return corrs 72 | 73 | def acts2corrIndices(hidden_t, hidden_s): 74 | # Match response 75 | ''' Output correlation indices for each layer ''' 76 | corrs = [] 77 | for t, s in zip(hidden_t, hidden_s): 78 | corr = act2corrMat(t.data, s.data) 79 | corrs.append(corrMat2corrIdx(corr)) 80 | return corrs 81 | 82 | ''' 83 | w_t = getattr(teacher, "w%d" % (i + 1)).weight 84 | w_s = getattr(student, "w%d" % (i + 1)).weight 85 | w_teacher=w_t[:,k], w_student=w_s[:, s_idx] 86 | ''' 87 | 88 | def compareCorrIndices(init_corrs, final_corrs): 89 | res = [] 90 | for k, (init_corr, final_corr) in enumerate(zip(init_corrs, final_corrs)): 91 | # For each layer 92 | # print("Layer %d" % k) 93 | res_per_layer = [] 94 | 95 | for j, (init_node, final_node) in enumerate(zip(init_corr, final_corr)): 96 | # For each node 97 | ranks = dict() 98 | max_init_score = -1000 99 | for node_rank, node_info in enumerate(init_node): 100 | node_id = node_info["s_idx"] 101 | score = node_info["score"] 102 | ranks[node_id] = dict(rank=node_rank, score=score) 103 | max_init_score = max(max_init_score, score) 104 | 105 | s_score = [] 106 | s_idx = [] 107 | for node_info in final_node: 108 | node_id = node_info["s_idx"] 109 | if node_id in ranks: 110 | rank = ranks[node_id]["rank"] 111 | else: 112 | rank = "-" 113 | s_score.append(node_info["score"]) 114 | s_idx.append((node_id, str(rank))) 115 | # "%2d [%s]" % (node_id, str(rank))) 116 
| # print("T[%d]: [init_student_max=%.4f] %s | idx: %s | min_rank: %d" % (j, max_init_corr, ",".join(s_val), ", ".join(s_idx), min_rank)) 117 | res_per_layer.append(dict(s_score=s_score, s_idx=s_idx, max_init_score=max_init_score)) 118 | res.append(res_per_layer) 119 | 120 | return res 121 | 122 | -------------------------------------------------------------------------------- /luckmatter/vis_corrs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import torch 10 | 11 | def get_stat(w): 12 | # return f"min: {w.min()}, max: {w.max()}, mean: {w.mean()}, std: {w.std()}" 13 | if isinstance(w, list): 14 | w = np.array(w) 15 | elif isinstance(w, torch.Tensor): 16 | w = w.cpu().numpy() 17 | return f"len: {w.shape}, min/max: {np.min(w):#3f}/{np.max(w):#3f}, mean: {np.mean(w):#3f}" 18 | 19 | def print_corrs(corrs, active_nodes=None, first_n=5, details=False): 20 | summary = "" 21 | for k, corr_per_layer in enumerate(corrs): 22 | score = [] 23 | for kk, corr_per_node in enumerate(corr_per_layer): 24 | if active_nodes is None or kk in active_nodes[k]: 25 | score.append(corr_per_node["s_score"][0]) 26 | 27 | summary += f"L{k}: {get_stat(score)}, " 28 | 29 | print(f"Corrs Summary: {summary}") 30 | 31 | if details: 32 | for k, corr_per_layer in enumerate(corrs): 33 | # For each layer 34 | print("Layer %d" % k) 35 | for j, corr_per_node in enumerate(corr_per_layer): 36 | s_score = corr_per_node["s_score"][:first_n] 37 | s_idx = corr_per_node["s_idx"][:first_n] 38 | 39 | s_score_str = ",".join(["%.4f" % v for v in s_score]) 40 | s_idx_str = ",".join(["%2d [%s]" % (node_id, rank) for node_id, rank in s_idx]) 41 | # import pdb 42 | # pdb.set_trace() 43 | 44 | min_rank = min([ int(rank) for node_id, 
rank in s_idx ]) 45 | print("T[%d]: [init_best_s=%.4f] %s | idx: %s | min_rank: %d" % (j, corr_per_node["max_init_score"], s_score_str, s_idx_str, min_rank)) 46 | # print("T[%d]: [init_best_s=%.4f] %s | idx: %s " % (j, corr_per_node["max_init_score"], s_score_str, s_idx_str)) 47 | 48 | -------------------------------------------------------------------------------- /ssl/.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | -------------------------------------------------------------------------------- /ssl/README.md: -------------------------------------------------------------------------------- 1 | ## Understanding Self-supervised learning (BYOL / SimCLR) 2 | 3 | This part of the repo covers the following two papers: 4 | 5 | 6 | [1] **Understanding Self-supervised Learning with Dual Deep Networks** 7 | 8 | Yuandong Tian, Lantao Yu, Xinlei Chen, Surya Ganguli 9 | 10 | [arXiv](https://arxiv.org/abs/2010.00578) 11 | 12 | [2] **Understanding self-supervised Learning Dynamics without Contrastive Pairs** 13 | 14 | Yuandong Tian, Xinlei Chen, Surya Ganguli 15 | 16 | ICML 2021 (*Outstanding Paper Honorable Mention*) [arXiv](https://arxiv.org/abs/2102.06810) 17 | 18 | -------------------------------------------------------------------------------- /ssl/common_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .assert_utils import * 2 | from .helper import * 3 | from .logger import * 4 | from .saver import * 5 | from .multi_counter import MultiCounter 6 | from .stopwatch import Stopwatch 7 | from .stats import corr_summary, StatsCorr 8 | -------------------------------------------------------------------------------- /ssl/common_utils/assert_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for assertions""" 2 | 3 | 4 | def assert_eq(real, expected): 5 | assert real == expected, "%s (true) vs
%s (expected)" % (real, expected) 6 | 7 | 8 | def assert_neq(real, expected): 9 | assert real != expected, "%s (true) vs %s (expected)" % (real, expected) 10 | 11 | 12 | def assert_lt(real, expected): 13 | assert real < expected, "%s (true) vs %s (expected)" % (real, expected) 14 | 15 | 16 | def assert_lteq(real, expected): 17 | assert real <= expected, "%s (true) vs %s (expected)" % (real, expected) 18 | 19 | 20 | def assert_tensor_eq(t1, t2, eps=1e-6): 21 | if t1.size() != t2.size(): 22 | print("Warning: size mismatch", t1.size(), "vs", t2.size()) 23 | return False 24 | 25 | t1 = t1.cpu().numpy() 26 | t2 = t2.cpu().numpy() 27 | diff = abs(t1 - t2) 28 | eq = (diff < eps).all() 29 | if not eq: 30 | import pdb 31 | 32 | pdb.set_trace() 33 | assert eq, (diff < eps).max() 34 | 35 | 36 | def assert_zero_grad(params): 37 | for p in params: 38 | if p.grad is not None: 39 | assert p.grad.sum().item() == 0 40 | -------------------------------------------------------------------------------- /ssl/common_utils/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from typing import Dict 8 | 9 | from subprocess import check_output 10 | 11 | def get_git_hash(): 12 | try: 13 | return check_output("git -C ./ log --pretty=format:'%H' -n 1", shell=True).decode('utf-8') 14 | except: 15 | return "" 16 | 17 | def get_git_diffs(): 18 | try: 19 | active_diff = check_output("git diff", shell=True).decode('utf-8') 20 | staged_diff = check_output("git diff --cached", shell=True).decode('utf-8') 21 | return active_diff + "\n" + staged_diff 22 | except: 23 | return "" 24 | 25 | def get_all_files(root, file_extension): 26 | files = [] 27 | for folder, _, fs in os.walk(root): 28 | for f in fs: 29 | if f.endswith(file_extension): 30 | files.append(os.path.join(folder, f)) 31 | return files 32 | 33 | def pretty_print_cmd(args): 34 | s = "" 35 | for i, a 
in enumerate(args): 36 | if i > 0: 37 | s += " " 38 | s += a + " \\\n" 39 | return s 40 | 41 | def moving_average(data, period): 42 | # Pad both ends (period - 1 values in total) so the 'valid' convolution returns len(data) values. 43 | left_pad = [data[0] for _ in range(period // 2)] 44 | right_pad = data[len(data) - (period - 1) // 2 :] 45 | data = left_pad + data + right_pad 46 | weights = np.ones(period) / period 47 | return np.convolve(data, weights, mode="valid") 48 | 49 | 50 | def mem2str(num_bytes): 51 | assert num_bytes >= 0 52 | if num_bytes >= 2 ** 30: # GB 53 | val = float(num_bytes) / (2 ** 30) 54 | result = "%.3f GB" % val 55 | elif num_bytes >= 2 ** 20: # MB 56 | val = float(num_bytes) / (2 ** 20) 57 | result = "%.3f MB" % val 58 | elif num_bytes >= 2 ** 10: # KB 59 | val = float(num_bytes) / (2 ** 10) 60 | result = "%.3f KB" % val 61 | else: 62 | result = "%d bytes" % num_bytes 63 | return result 64 | 65 | 66 | def sec2str(seconds): 67 | seconds = int(seconds) 68 | hour = seconds // 3600 69 | seconds = seconds % (24 * 3600) 70 | seconds %= 3600 71 | minutes = seconds // 60 72 | seconds %= 60 73 | return "%dH %02dM %02dS" % (hour, minutes, seconds) 74 | 75 | 76 | def num2str(n): 77 | if n < 1e3: 78 | s = str(n) 79 | unit = "" 80 | elif n < 1e6: 81 | n /= 1e3 82 | s = "%.3f" % n 83 | unit = "K" 84 | else: 85 | n /= 1e6 86 | s = "%.3f" % n 87 | unit = "M" 88 | 89 | s = s.rstrip("0").rstrip(".") if "." in s else s 90 | return s + unit 91 | 92 | 93 | def get_mem_usage(): 94 | import psutil 95 | 96 | mem = psutil.virtual_memory() 97 | result = "" 98 | result += "available: %s, " % (mem2str(mem.available)) 99 | result += "used: %s, " % (mem2str(mem.used)) 100 | result += "free: %s" % (mem2str(mem.free)) 101 | # result += "active: %s\t" % (mem2str(mem.active)) 102 | # result += "inactive: %s\t" % (mem2str(mem.inactive)) 103 | # result += "buffers: %s\t" % (mem2str(mem.buffers)) 104 | # result += "cached: %s\t" % (mem2str(mem.cached)) 105 | # result += "shared: %s\t" % (mem2str(mem.shared)) 106 | # result += "slab: %s\t" % (mem2str(mem.slab)) 107 | return result 108 | 109 |
110 | def flatten_first2dim(batch): 111 | if isinstance(batch, torch.Tensor): 112 | size = batch.size()[2:] 113 | batch = batch.view(-1, *size) 114 | return batch 115 | elif isinstance(batch, dict): 116 | return {key: flatten_first2dim(batch[key]) for key in batch} 117 | else: 118 | assert False, "unsupported type: %s" % type(batch) 119 | 120 | 121 | def _tensor_slice(t, dim, b, e): 122 | if dim == 0: 123 | return t[b:e] 124 | elif dim == 1: 125 | return t[:, b:e] 126 | elif dim == 2: 127 | return t[:, :, b:e] 128 | else: 129 | raise ValueError("unsupported %d in tensor_slice" % dim) 130 | 131 | 132 | def tensor_slice(t, dim, b, e): 133 | if isinstance(t, dict): 134 | return {key: tensor_slice(t[key], dim, b, e) for key in t} 135 | elif isinstance(t, torch.Tensor): 136 | return _tensor_slice(t, dim, b, e).contiguous() 137 | else: 138 | assert False, "Error: unsupported type: %s" % (type(t)) 139 | 140 | 141 | def tensor_index(t, dim, i): 142 | if isinstance(t, dict): 143 | return {key: tensor_index(t[key], dim, i) for key in t} 144 | elif isinstance(t, torch.Tensor): 145 | return _tensor_slice(t, dim, i, i + 1).squeeze(dim).contiguous() 146 | else: 147 | assert False, "Error: unsupported type: %s" % (type(t)) 148 | 149 | 150 | def one_hot(x, n): 151 | assert x.dim() == 2 and x.size(1) == 1 152 | one_hot_x = torch.zeros(x.size(0), n, device=x.device) 153 | one_hot_x.scatter_(1, x, 1) 154 | return one_hot_x 155 | 156 | 157 | def set_all_seeds(rand_seed): 158 | random.seed(rand_seed) 159 | np.random.seed(rand_seed + 1) 160 | torch.manual_seed(rand_seed + 2) 161 | torch.cuda.manual_seed(rand_seed + 3) 162 | 163 | 164 | def weights_init(m): 165 | """custom weights initialization""" 166 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 167 | # nn.init.kaiming_normal(m.weight.data) 168 | nn.init.orthogonal_(m.weight.data) 169 | else: 170 | print("%s is not custom-initialized." 
% m.__class__) 171 | 172 | 173 | def init_net(net, net_file): 174 | if net_file: 175 | net.load_state_dict(torch.load(net_file)) 176 | else: 177 | net.apply(weights_init) 178 | 179 | 180 | def count_output_size(input_shape, model): 181 | fake_input = torch.FloatTensor(*input_shape) 182 | output_size = model.forward(fake_input).view(-1).size()[0] 183 | return output_size 184 | -------------------------------------------------------------------------------- /ssl/common_utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | class Logger: 6 | def __init__(self, path, mode="w"): 7 | assert mode in {"w", "a"}, "unknown mode for logger %s" % mode 8 | self.terminal = sys.stdout 9 | if os.path.dirname(path) and not os.path.exists(os.path.dirname(path)): # path may be a bare filename 10 | os.makedirs(os.path.dirname(path)) 11 | if mode == "w" or not os.path.exists(path): 12 | self.log = open(path, "w") 13 | else: 14 | self.log = open(path, "a") 15 | 16 | def write(self, message): 17 | self.terminal.write(message) 18 | self.log.write(message) 19 | self.log.flush() 20 | 21 | def flush(self): 22 | # for python 3 compatibility.
23 | pass 24 | -------------------------------------------------------------------------------- /ssl/common_utils/multi_counter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from collections import defaultdict, Counter 4 | from datetime import datetime 5 | from tensorboardX import SummaryWriter 6 | 7 | 8 | class ValueStats: 9 | def __init__(self, name=None): 10 | self.name = name 11 | self.reset() 12 | 13 | def feed(self, v): 14 | self.summation += v 15 | self.sumSqr += v * v 16 | 17 | if v > self.max_value: 18 | self.max_value = v 19 | self.max_idx = self.counter 20 | if v < self.min_value: 21 | self.min_value = v 22 | self.min_idx = self.counter 23 | 24 | self.counter += 1 25 | 26 | def mean(self): 27 | if self.counter == 0: 28 | print("Counter %s is 0" % self.name) 29 | assert False 30 | return self.summation / self.counter 31 | 32 | def summary(self, info=None): 33 | info = "" if info is None else info 34 | name = "" if self.name is None else self.name 35 | if self.counter > 0: 36 | mean = self.summation / self.counter 37 | meanSqr = self.sumSqr / self.counter 38 | std = math.sqrt(meanSqr - mean * mean + 1e-6) 39 | 40 | # try: 41 | return "%s%s[%4d]: avg: %8.4f (± %8.4f), min: %8.4f[%4d], max: %8.4f[%4d]" % ( 42 | info, 43 | name, 44 | self.counter, 45 | mean, 46 | std / math.sqrt(self.counter), # std for the mean 47 | self.min_value, 48 | self.min_idx, 49 | self.max_value, 50 | self.max_idx, 51 | ) 52 | # except BaseException: 53 | # return "%s%s[Err]:" % (info, name) 54 | else: 55 | return "%s%s[0]" % (info, name) 56 | 57 | def reset(self): 58 | self.counter = 0 59 | self.summation = 0.0 60 | self.sumSqr = 0.0 61 | self.max_value = -1e38 62 | self.min_value = 1e38 63 | self.max_idx = None 64 | self.min_idx = None 65 | 66 | 67 | class MultiCounter: 68 | def __init__(self, root, verbose=False): 69 | # TODO: rethink counters 70 | self.last_time = None 71 | self.verbose = verbose 72 | 
self.counts = Counter() 73 | self.stats = defaultdict(lambda: ValueStats()) 74 | self.total_count = 0 75 | self.max_key_len = 0 76 | if root is not None: 77 | self.tb_writer = SummaryWriter(os.path.join(root, "stat.tb")) 78 | else: 79 | self.tb_writer = None 80 | 81 | def __getitem__(self, key): 82 | if len(key) > self.max_key_len: 83 | self.max_key_len = len(key) 84 | 85 | if key in self.counts: 86 | return self.counts[key] 87 | 88 | return self.stats[key] 89 | 90 | def inc(self, key): 91 | if self.verbose: 92 | print("[MultiCounter]: %s" % key) 93 | self.counts[key] += 1 94 | self.total_count += 1 95 | if self.last_time is None: 96 | self.last_time = datetime.now() 97 | 98 | def reset(self): 99 | for k in self.stats.keys(): 100 | self.stats[k].reset() 101 | 102 | self.counts = Counter() 103 | self.total_count = 0 104 | self.last_time = datetime.now() 105 | 106 | def time_elapsed(self): 107 | return (datetime.now() - self.last_time).total_seconds() 108 | 109 | def summary(self, global_counter): 110 | assert self.last_time is not None 111 | time_elapsed = (datetime.now() - self.last_time).total_seconds() 112 | s = "[%d] Time spent = %.2f s\n" % (global_counter, time_elapsed) 113 | 114 | for key, count in self.counts.items(): 115 | s += "%s: %d/%d\n" % (key, count, self.total_count) 116 | 117 | for k in sorted(self.stats.keys()): 118 | v = self.stats[k] 119 | info = str(global_counter) + ":" + k 120 | s += v.summary(info=info.ljust(self.max_key_len + 4)) + "\n" 121 | 122 | if self.tb_writer is not None: 123 | self.tb_writer.add_scalar(k, v.mean(), global_counter) 124 | 125 | return s 126 | -------------------------------------------------------------------------------- /ssl/common_utils/saver.py: -------------------------------------------------------------------------------- 1 | # model saver that saves top-k performing model 2 | import os 3 | import torch 4 | 5 | 6 | class TopkSaver: 7 | def __init__(self, save_dir, topk): 8 | self.save_dir = save_dir 9 | 
self.topk = topk 10 | self.worse_perf = -float("inf") 11 | self.worse_perf_idx = 0 12 | self.perfs = [self.worse_perf] * topk # one -inf placeholder per slot, so slots 0..topk-1 fill in order 13 | 14 | if not os.path.exists(save_dir): 15 | os.makedirs(save_dir) 16 | 17 | def save(self, model, state_dict, perf): 18 | if perf <= self.worse_perf: 19 | # print('i am sorry') 20 | # [print(i) for i in self.perfs] 21 | return False 22 | 23 | model_name = "model%i.pthm" % self.worse_perf_idx 24 | weight_name = "model%i.pthw" % self.worse_perf_idx 25 | if model is not None: 26 | model.save(os.path.join(self.save_dir, model_name)) 27 | if state_dict is not None: 28 | torch.save(state_dict, os.path.join(self.save_dir, weight_name)) 29 | 30 | if len(self.perfs) < self.topk: 31 | self.perfs.append(perf) 32 | return True 33 | 34 | # Replace the current worst entry, then find the new worst. 35 | self.perfs[self.worse_perf_idx] = perf 36 | worse_perf = self.perfs[0] 37 | worse_perf_idx = 0 38 | for i, perf in enumerate(self.perfs): 39 | if perf < worse_perf: 40 | worse_perf = perf 41 | worse_perf_idx = i 42 | 43 | self.worse_perf = worse_perf 44 | self.worse_perf_idx = worse_perf_idx 45 | return True 46 | -------------------------------------------------------------------------------- /ssl/common_utils/stats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import pprint 5 | from collections import Counter, defaultdict 6 | 7 | def accumulate(a, v): 8 | return a + v if a is not None else v 9 | 10 | 11 | def get_stat(w): 12 | # return f"min: {w.min()}, max: {w.max()}, mean: {w.mean()}, std: {w.std()}" 13 | if isinstance(w, list): 14 | w = np.array(w) 15 | elif isinstance(w, torch.Tensor): 16 | w = w.cpu().numpy() 17 | return f"len: {w.shape}, min/max: {np.min(w):#.6f}/{np.max(w):#.6f}, mean/median: {np.mean(w):#.6f}/{np.median(w):#.6f}" 18 | 19 | 20 | def corr_summary(corr, row_names=None, col_names=None, best_corr_last=None, verbose=False, cnt_thres=None): 21 | # Corrs: [row, col] 22 |
sorted_corrs, indices = corr.sort(1, descending=True) 23 | num_row, num_col = corr.size() 24 | 25 | if row_names is None: 26 | row_names = [ f"row{i}" for i in range(num_row) ] 27 | 28 | if col_names is None: 29 | col_names = [ f"col{i}" for i in range(num_col) ] 30 | 31 | assert num_row == len(row_names), f"#row: {num_row} != #row_names {len(row_names)}" 32 | assert num_col == len(col_names), f"#col: {num_col} != #col_names {len(col_names)}" 33 | 34 | best_corr = sorted_corrs[:,0] 35 | if best_corr_last is not None: 36 | best_corr_diff = best_corr - best_corr_last 37 | _, row_orders = best_corr.sort(0, descending=True) 38 | 39 | # Teacher side. 40 | summaries = [] 41 | for i in row_orders: 42 | row_name = row_names[i] 43 | s = sorted_corrs[i] 44 | best = indices[i][0] 45 | best_score = sorted_corrs[i][0] 46 | comment = f"[{i}] {row_name}: best: {col_names[best]} ({best_score:.6f})" 47 | if best_corr_last is not None: 48 | comment += f" delta: {best_corr_diff[i]:.6f}" 49 | summaries.append(comment) 50 | 51 | return dict(summary="\n".join(summaries), best_corr=best_corr) 52 | 53 | 54 | class StatsCorr: 55 | def reset(self): 56 | self.initialized = False 57 | 58 | def add(self, h_t, h_s): 59 | if not self.initialized: 60 | self.inner_prod = None 61 | self.sum_t = None 62 | self.sum_s = None 63 | self.sum_sqr_t = None 64 | self.sum_sqr_s = None 65 | self.counts = 0 66 | self.initialized = True 67 | 68 | # Compute correlation. 
69 | # activation: [bs, #nodes] 70 | h_t = h_t.detach() 71 | h_s = h_s.detach() 72 | 73 | if h_t.dim() == 4: 74 | h_t = h_t.permute(0, 2, 3, 1).reshape(-1, h_t.size(1)) 75 | if h_s.dim() == 4: 76 | h_s = h_s.permute(0, 2, 3, 1).reshape(-1, h_s.size(1)) 77 | 78 | self.inner_prod = accumulate(self.inner_prod, h_t.t() @ h_s) 79 | 80 | self.sum_t = accumulate(self.sum_t, h_t.sum(dim=0)) 81 | self.sum_s = accumulate(self.sum_s, h_s.sum(dim=0)) 82 | 83 | self.sum_sqr_t = accumulate(self.sum_sqr_t, h_t.pow(2).sum(dim=0)) 84 | self.sum_sqr_s = accumulate(self.sum_sqr_s, h_s.pow(2).sum(dim=0)) 85 | 86 | self.counts += h_t.size(0) 87 | 88 | def get(self): 89 | assert self.initialized 90 | 91 | n = self.counts 92 | s_avg = self.sum_s / n 93 | t_avg = self.sum_t / n 94 | 95 | ts_centered = self.inner_prod / n - torch.ger(t_avg, s_avg) 96 | 97 | t_var = self.sum_sqr_t / n - t_avg.pow(2) 98 | s_var = self.sum_sqr_s / n - s_avg.pow(2) 99 | 100 | t_var.clamp_(0, None) 101 | s_var.clamp_(0, None) 102 | 103 | t_norm = t_var.sqrt() 104 | s_norm = s_var.sqrt() 105 | 106 | ts_norm = torch.ger(t_norm, s_norm) 107 | corr = ts_centered / ts_norm 108 | # Set 0/0 = 0 109 | # usually that means everything is constant (and super sparse), and we don't know the correlation 110 | zero_entry = (ts_norm < 1e-6) & (ts_centered.abs() < 1e-6) 111 | corr[zero_entry] = 0.0 112 | 113 | corr.clamp_(-1, 1) 114 | 115 | # corr: [#node_t, #node_s] 116 | return dict(corr=corr, s_norm=s_norm, t_norm=t_norm) 117 | -------------------------------------------------------------------------------- /ssl/common_utils/stopwatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | from collections import defaultdict 4 | from datetime import datetime 5 | import numpy as np 6 | 7 | 8 | def millis_interval(start, end): 9 | """start and end are datetime instances""" 10 | diff = end - start 11 | millis = diff.days * 24 * 60 * 60 * 1000 12 | millis += diff.seconds * 1000 13 | millis += diff.microseconds / 1000 14 | return millis 15 | 16 | 17 | class Stopwatch: 18 | def __init__(self): 19 | self.last_time = datetime.now() 20 | self.times = defaultdict(list) 21 | self.keys = [] 22 | 23 | def reset(self): 24 | self.last_time = datetime.now() 25 | self.times = defaultdict(list) 26 | self.keys = [] 27 | 28 | def time(self, key): 29 | if key not in self.times: 30 | self.keys.append(key) 31 | self.times[key].append(millis_interval(self.last_time, datetime.now())) 32 | self.last_time = datetime.now() 33 | 34 | def summary(self): 35 | num_elems = -1 36 | total = 0 37 | max_key_len = 0 38 | for k, v in self.times.items(): 39 | if num_elems == -1: 40 | num_elems = len(v) 41 | 42 | assert len(v) == num_elems 43 | total += np.sum(v) 44 | max_key_len = max(max_key_len, len(k)) 45 | 46 | s = "@@@Time\n" 47 | for k in self.keys: 48 | v = self.times[k] 49 | s += "\t%s: %d MS, %.2f%%\n" % (k.ljust(max_key_len), np.mean(v), 100.0 * np.sum(v) / total) 50 | s += "@@@total time per iter: %.2f ms\n" % (float(total) / num_elems) 51 | self.reset() 52 | 53 | return s 54 | -------------------------------------------------------------------------------- /ssl/hltm/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This codebase runs SimCLR on the Hierarchical Latent Tree Model (HLTM). See Section 6 in [1]. 3 | 4 | # Sample Usage 5 | To run training, use the following command: 6 | ``` 7 | python simCLR_hltm.py depth=5 num_children=2 seed=1 delta_lower=0.8 temp=0.1 hid=10 num_epoch=50 8 | ``` 9 | Here: 10 | 1. `depth=5` means `L = 5` in the paper. 11 | 2.
`num_children=2` means that it is a binary HLTM (i.e., each non-leaf latent variable has two children). 12 | 3. `delta_lower` determines the sampling range of the polarity, which is sampled from `Uniform[delta_lower, 1]`. 13 | 4. `temp=0.1` is the temperature `tau` used in the NCE loss. We use the `H = 1` setting. 14 | 5. `hid=10` is the degree of overparameterization (`|N^ch_mu|` in the paper). 15 | 16 | # Reference 17 | [1] **Understanding Self-supervised Learning with Dual Deep Networks** 18 | 19 | Yuandong Tian, Lantao Yu, Xinlei Chen, Surya Ganguli 20 | 21 | [arXiv](https://arxiv.org/abs/2010.00578) 22 | -------------------------------------------------------------------------------- /ssl/hltm/conf/hltm.yaml: -------------------------------------------------------------------------------- 1 | depth: 5 2 | num_children: 2 3 | batchsize: 128 4 | N: 64000 5 | hid: 10 6 | num_epoch: 20 7 | save_file: htgm_model.pth 8 | seed: 1 9 | lr: 0.1 10 | 11 | delta_upper: 1 12 | delta_lower: 0.7 13 | loss: nce 14 | 15 | eps: 1e-4 16 | temp: 0.01 17 | 18 | githash: 19 | sweep_filename: -------------------------------------------------------------------------------- /ssl/real-dataset/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | The codebase is built from this [repo](https://github.com/sthalles/PyTorch-BYOL), with modifications. It is used in papers [1]-[4] listed in the Reference section below. 3 | 4 | You will need to install [hydra](https://github.com/facebookresearch/hydra) for parameter configuration and sweep support. 5 | 6 | To avoid downloading the dataset every time you run the program, change `dataset_path` in `config/byol_config.yaml` (which is shared by both the BYOL and SimCLR methods) to an absolute path. 7 | 8 | # Prerequisite 9 | 10 | Please install the `common_utils` package from https://github.com/yuandong-tian/tools2 before running the code.
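Both the HLTM code (where `temp` is the temperature `tau`) and the `method=simclr` runs train with an NCE-style contrastive loss. As a minimal NumPy sketch, assuming the standard SimCLR NT-Xent (InfoNCE) form with L2-normalized embeddings and cosine similarity, the loss looks like this. The repo's own implementation lives in `loss/nt_xent.py` and may differ in details (for example, the `exact_cov` and `beta` variants used for [1]). `nt_xent` is a hypothetical helper written for illustration, not an import from this repo.

```python
import numpy as np


def nt_xent(z1: np.ndarray, z2: np.ndarray, tau: float = 0.1) -> float:
    """NT-Xent (InfoNCE) loss for N positive pairs (z1[i], z2[i]).

    Hypothetical helper for illustration; not the repo's loss/nt_xent.py.
    """
    n = z1.shape[0]
    # Stack both views and L2-normalize so dot products are cosine similarities.
    z = np.concatenate([z1, z2], axis=0).astype(float)
    z /= np.linalg.norm(z, axis=1, keepdims=True)
    logits = z @ z.T / tau                  # (2N, 2N) similarities scaled by 1/tau
    np.fill_diagonal(logits, -np.inf)       # a sample is never compared with itself
    # Row i (first view) is paired with row i + N (second view), and vice versa.
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Numerically stable log-softmax over each row.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Average negative log-likelihood of picking the positive partner.
    return float(-log_prob[np.arange(2 * n), pos].mean())


if __name__ == "__main__":
    eye = np.eye(4)
    print(f"{nt_xent(eye, eye):.5f}")                  # 0.00027: views agree
    print(f"{nt_xent(eye, eye[[1, 2, 3, 0]]):.3f}")    # 10.000: positives mismatched
```

With aligned, mutually orthogonal views the loss is close to zero; when the positives are mismatched it grows to roughly `1/tau`. Lowering `tau` sharpens the softmax over the negatives.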
11 | 12 | # Sample Usage 13 | 14 | ## First paper [1] 15 | To verify Theorem 4 in [1], run: 16 | 17 | ``` 18 | python main.py method=simclr use_optimizer=adam optimizer.params.weight_decay=0 seed=1 \ 19 | optimizer.params.lr=1e-3 trainer.nce_loss.exact_cov=true \ 20 | dataset=stl10 trainer.nce_loss.beta=0 trainer.max_epochs=500 21 | ``` 22 | 23 | You can also set `trainer.nce_loss.exact_cov=false` to measure performance with the normal NCE loss. Set `trainer.nce_loss.beta` to a nonzero value (positive or negative) to use more general loss functions. 24 | 25 | For the Hierarchical Latent Tree Model (HLTM) in Section 6, please check the code [here](https://github.com/facebookresearch/luckmatters/tree/master/ssl/hltm). 26 | 27 | ## Second paper [2] 28 | To run DirectPred introduced in [2], use the following sample command (tested at commit cb23d10c3018df6bf275ad537f23675c8a627253): 29 | 30 | ``` 31 | python main.py seed=1 method=byol trainer.max_epochs=100 trainer.predictor_params.has_bias=false \ 32 | trainer.predictor_params.normalization=no_normalization network.predictor_head.mlp_hidden_size=null \ 33 | trainer.predictor_reg=corr trainer.predictor_freq=1 trainer.dyn_lambda=0.3 trainer.dyn_eps=0.01 trainer.balance_type=boost_scale 34 | ``` 35 | Note that: 36 | 1. On the second line, `trainer.predictor_params.normalization=no_normalization` and `network.predictor_head.mlp_hidden_size=null` make the predictor linear. 37 | 2. The third line selects DirectPred with update frequency `predictor_freq=1`, `dyn_lambda=0.3` (which is `rho` in Eqn. 19 of [2]) and `dyn_eps=0.01` (which is `eps` in Eqn. 18 of [2]).
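The update controlled by `trainer.dyn_lambda` (`rho`) and `trainer.dyn_eps` (`eps`) can be sketched as follows. This is a minimal NumPy illustration of one reading of Eqns. 18-19 of [2]: keep a running estimate `F_hat` of the feature correlation, eigendecompose it, and set the linear predictor's eigenvalues to `p_j = sqrt(s_j) / max_j' sqrt(s_j') + eps`. `directpred_update` is a hypothetical name, and the sketch leaves out trainer details such as `trainer.balance_type` and how `trainer.predictor_freq` schedules the update.

```python
import numpy as np


def directpred_update(F_hat: np.ndarray, f: np.ndarray, rho: float = 0.3, eps: float = 0.01):
    """One DirectPred-style step: refresh the correlation estimate, then set the predictor.

    F_hat: running estimate of E[f f^T], shape (d, d).
    f:     a batch of online-network features, shape (B, d).
    Returns (new F_hat, W_p), both of shape (d, d).
    Hypothetical helper for illustration; not the repo's trainer code.
    """
    # Eqn. 19 of [2] (as read here): moving average of the feature correlation.
    F_hat = rho * F_hat + (1.0 - rho) * (f.T @ f) / f.shape[0]
    # F_hat is symmetric PSD, so eigh gives F_hat = U diag(s) U^T.
    s, U = np.linalg.eigh(F_hat)
    s = np.clip(s, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    # Eqn. 18 of [2] (as read here): p_j = sqrt(s_j) / max_j' sqrt(s_j') + eps.
    root = np.sqrt(s)
    p = root / root.max() + eps
    # The linear predictor shares F_hat's eigenvectors.
    W_p = U @ np.diag(p) @ U.T
    return F_hat, W_p


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    F_hat = np.zeros((4, 4))
    for _ in range(10):  # e.g. one update per step, as with predictor_freq=1
        F_hat, W_p = directpred_update(F_hat, rng.standard_normal((128, 4)))
    # Ascending eigenvalues of W_p; the largest is 1 + eps.
    print(np.round(np.linalg.eigvalsh(W_p), 3))
```

Because the predictor is built directly from `F_hat` rather than trained by gradient descent, `eps` keeps weak eigendirections from being switched off entirely, and `rho` trades responsiveness for stability of the estimate.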
38 | 39 | ## Third paper [3] 40 | To run alpha-CL (with `p=4` in the paper), here is a sample command 41 | ``` 42 | python main.py method=simclr dataset=cifar100 trainer.nce_loss.loss_type=dual2 trainer.nce_loss.alpha_exponent=2 trainer.nce_loss.alpha_eps=0 trainer.nce_loss.alpha_type=exp use_optimizer=adam optimizer.params.lr=0.01 optimizer.params.weight_decay=0 seed=1 43 | ``` 44 | 45 | ## Fourth paper [4] 46 | To run the experiments in Section 5, try the following. Here `distri.num_tokens_per_pos` is `P`, and `distri.pattern_cnt` is `G` in the paper. 47 | 48 | ``` 49 | python bn_gen.py distri.num_tokens=20 distri.num_tokens_per_pos=5 model.activation=relu beta=5 model.bn_spec.use_bn=true model.bn_spec.backprop_var=false seed=1 model.shared_low_layer=false opt.wd=0.005 distri.pattern_cnt=40 model.output_d=50 opt.lr=0.02 50 | ``` 51 | 52 | Output 53 | ``` 54 | [2022-09-02 15:56:05,981][bn_gen.py][INFO] - distributions: #Tokens: 20, #Loc: 10, Tokens per loc: [5, 5, 5, 5, 5, 5, 5, 5, 5, 5] 55 | patterns: 56 | -F-H--C-MO 57 | -JFL--EN-- 58 | -J-HI---ID 59 | L-H---C-PJ 60 | -JOL-TI--- 61 | -RFC-PE--- 62 | FG-L-K--M- 63 | JNF-M----O 64 | --LC-KL-F- 65 | L-HJK----O 66 | -N-L-E-IM- 67 | TN--HN-I-- 68 | RR----ID-O 69 | --L-HEI--J 70 | L--CIK--I- 71 | --FJMT---T 72 | -R--MNE--T 73 | TRO--PC--- 74 | J-HC----FT 75 | -GO---CQI- 76 | -N--KNL-R- 77 | ----N-IQRD 78 | J--JNN---J 79 | R--H-T-O-J 80 | --H-KP--ID 81 | -FE-H--DI- 82 | ---H-PLQR- 83 | FN-HI----D 84 | JJ--H--I-P 85 | R-HC----RP 86 | FJ--I---RT 87 | T--JK-C-R- 88 | LGHJ----R- 89 | L-L--N--PP 90 | R--CM--D-P 91 | -FE-N---MT 92 | J--HI--QM- 93 | RN--K---RJ 94 | -NF-KE---O 95 | -FOM-KI--- 96 | At loc 0: L=5,F=3,J=5,T=3,R=5 97 | At loc 1: F=4,J=5,R=4,G=3,N=7 98 | At loc 2: F=5,H=6,O=4,L=3,E=2 99 | At loc 3: H=6,L=4,C=6,J=5,M=1 100 | At loc 4: I=5,M=4,K=6,H=4,N=3 101 | At loc 5: T=3,P=4,K=4,E=3,N=5 102 | At loc 6: C=5,E=3,I=5,L=3 103 | At loc 7: N=1,I=3,D=3,Q=4,O=1 104 | At loc 8: M=5,I=5,P=2,F=2,R=8 105 | At loc 9: 
O=5,D=4,J=5,T=5,P=4 106 | 107 | [2022-09-02 15:56:05,984][/private/home/yuandong/luckmatters/ssl/real-dataset/bn_gen_utils.py][INFO] - mags: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 108 | 1., 1.]) 109 | [2022-09-02 15:56:05,984][bn_gen.py][INFO] - beta overrides multi: multi [25] = tokens_per_loc [5] x beta [5] 110 | [2022-09-02 15:56:06,028][bn_gen.py][INFO] - [0] 2.595567226409912 111 | [2022-09-02 15:56:06,029][bn_gen.py][INFO] - Save to model-0.pth 112 | [2022-09-02 15:56:24,516][bn_gen.py][INFO] - [500] 1.5258710384368896 113 | [2022-09-02 15:56:24,517][bn_gen.py][INFO] - Save to model-500.pth 114 | [2022-09-02 15:56:42,832][bn_gen.py][INFO] - [1000] 1.5300947427749634 115 | [2022-09-02 15:56:42,833][bn_gen.py][INFO] - Save to model-1000.pth 116 | [2022-09-02 15:57:01,072][bn_gen.py][INFO] - [1500] 1.5061111450195312 117 | [2022-09-02 15:57:01,072][bn_gen.py][INFO] - Save to model-1500.pth 118 | [2022-09-02 15:57:19,371][bn_gen.py][INFO] - [2000] 1.4396307468414307 119 | [2022-09-02 15:57:19,371][bn_gen.py][INFO] - Save to model-2000.pth 120 | [2022-09-02 15:57:37,563][bn_gen.py][INFO] - [2500] 1.5044890642166138 121 | [2022-09-02 15:57:37,564][bn_gen.py][INFO] - Save to model-2500.pth 122 | [2022-09-02 15:57:55,655][bn_gen.py][INFO] - [3000] 1.4585031270980835 123 | [2022-09-02 15:57:55,655][bn_gen.py][INFO] - Save to model-3000.pth 124 | [2022-09-02 15:58:16,503][bn_gen.py][INFO] - [3500] 1.472838282585144 125 | [2022-09-02 15:58:16,504][bn_gen.py][INFO] - Save to model-3500.pth 126 | [2022-09-02 15:58:34,750][bn_gen.py][INFO] - [4000] 1.421286940574646 127 | [2022-09-02 15:58:34,751][bn_gen.py][INFO] - Save to model-4000.pth 128 | [2022-09-02 15:58:52,998][bn_gen.py][INFO] - [4500] 1.341599464416504 129 | [2022-09-02 15:58:52,999][bn_gen.py][INFO] - Save to model-4500.pth 130 | [2022-09-02 15:59:11,008][bn_gen.py][INFO] - Final loss = 1.5294005870819092 131 | [2022-09-02 15:59:11,009][bn_gen.py][INFO] - Save to 
model-final.pth 132 | [2022-09-02 15:59:11,052][bn_gen.py][INFO] - [{'folder': '/private/home/yuandong/luckmatters/ssl/real-dataset/outputs/2022-09-02/15-56-05', 'loc0': 0.9961947202682495, 'loc_other0': 0.005383226554840803, 'loc1': 0.9986963272094727, 'loc_other1': -0.00016310946375597268, 'loc2': 0.9985083341598511, 'loc_other2': -0.0002446844591759145, 'loc3': 0.9983118772506714, 'loc_other3': -0.0002287515817442909, 'loc4': 0.9983332753181458, 'loc_other4': -0.0002713123394642025, 'loc5': 0.9984112977981567, 'loc_other5': -0.00028966396348550916, 'loc6': 0.9983190298080444, 'loc_other6': -0.0002980256685987115, 'loc7': 0.9980360269546509, 'loc_other7': -0.00040157922194339335, 'loc8': 0.9986146092414856, 'loc_other8': -0.00030036718817427754, 'loc9': 0.9987049102783203, 'loc_other9': 0.023116284981369972, 'loc_all': 0.9982131123542786, 'loc_other_all': 0.002630201866850257}] 133 | ``` 134 | 135 | # Reference 136 | [1] **Understanding Self-supervised Learning with Dual Deep Networks** 137 | 138 | Yuandong Tian, Lantao Yu, Xinlei Chen, Surya Ganguli 139 | 140 | [arXiv](https://arxiv.org/abs/2010.00578) 141 | 142 | [2] **Understanding self-supervised Learning Dynamics without Contrastive Pairs** 143 | 144 | Yuandong Tian, Xinlei Chen, Surya Ganguli 145 | 146 | [ICML'21](https://arxiv.org/abs/2102.06810) *Outstanding paper honorable mention* 147 | 148 | [3] **Understanding Deep Contrastive Learning via Coordinate-wise Optimization** 149 | 150 | Yuandong Tian 151 | 152 | [NeurIPS'22](https://arxiv.org/abs/2201.12680) *Oral* 153 | 154 | [4] **Understanding the Role of Nonlinearity in Training Dynamics of Contrastive Learning** 155 | 156 | Yuandong Tian 157 | 158 | [arXiv](https://arxiv.org/abs/2206.01342) 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /ssl/real-dataset/bn_gen_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 |
import torch 3 | from collections import Counter, defaultdict, deque 4 | 5 | import logging 6 | log = logging.getLogger(__file__) 7 | 8 | class Distribution: 9 | def __init__(self, distri): 10 | if distri.specific is not None: 11 | return [ [ Distribution.letter2idx(t) for t in v] for v in distri.specific.split("-") ] 12 | 13 | # Generate the distribution. 14 | tokens_per_loc = [] 15 | token_indices = list(range(distri.num_tokens)) 16 | for i in range(distri.num_loc): 17 | # For each location, pick tokens. 18 | random.shuffle(token_indices) 19 | tokens_per_loc.append(token_indices[:distri.num_tokens_per_pos]) 20 | 21 | distributions = [] 22 | loc_indices = list(range(distri.num_loc)) 23 | for i in range(distri.pattern_cnt): 24 | # pick locations. 25 | random.shuffle(loc_indices) 26 | 27 | pattern = [-1] * distri.num_loc 28 | # for each loc, pick which token to choose. 29 | for l in loc_indices[:distri.pattern_len]: 30 | pattern[l] = random.choice(tokens_per_loc[l]) 31 | 32 | distributions.append(pattern) 33 | 34 | self.distributions = distributions 35 | self.pattern_len = distri.pattern_len 36 | self.tokens_per_loc = tokens_per_loc 37 | self.num_tokens = distri.num_tokens 38 | self.num_loc = distri.num_loc 39 | 40 | @classmethod 41 | def letter2idx(cls, t): 42 | return ord(t) - ord('A') if t != '*' else -1 43 | 44 | @classmethod 45 | def idx2letter(cls, i): 46 | return '-' if i == -1 else chr(ord('A') + i) 47 | 48 | def save(self, filename): 49 | torch.save(self.__dict__, filename) 50 | 51 | @classmethod 52 | def load(cls, filename): 53 | data = torch.load(filename) 54 | obj = cls.__new__(Distribution) 55 | for k, v in data.items(): 56 | setattr(obj, k, v) 57 | return obj 58 | 59 | def symbol_freq(self): 60 | counts = defaultdict(Counter) 61 | for pattern in self.distributions: 62 | for k, d in enumerate(pattern): 63 | counts[k][d] += 1 64 | 65 | # counts[k][d] is the frequency of symbol d (in terms of index) appears at index k 66 | return counts 67 | 68 | def 
__repr__(self): 69 | # visualize the distribution 70 | ''' 71 | -1 = wildcard 72 | 73 | distrib = [ 74 | [0, 1, -1, -1, 3], 75 | [-1, -1, 1, 4, 2] 76 | ] 77 | ''' 78 | s = f"#Tokens: {self.num_tokens}, #Loc: {self.num_loc}, Tokens per loc: {[len(a) for a in self.tokens_per_loc]}\n" 79 | s += "patterns: \n" 80 | for pattern in self.distributions: 81 | s += " " + "".join([Distribution.idx2letter(a) for a in pattern]) + "\n" 82 | counts = self.symbol_freq() 83 | for k in range(self.num_loc): 84 | s += f"At loc {k}: " + ",".join([f"{Distribution.idx2letter(idx)}={cnt}" for idx, cnt in counts[k].items() if idx != -1]) + "\n" 85 | s += "\n" 86 | return s 87 | 88 | def sample(self, n): 89 | return random.choices(self.distributions, k=n) 90 | 91 | 92 | class Generator: 93 | def __init__(self, distrib : Distribution, batchsize:int, mag_split = 1, aug_degree = 5, d = None): 94 | self.distrib = distrib 95 | self.K = distrib.num_loc 96 | self.batchsize = batchsize 97 | 98 | assert aug_degree <= self.K - distrib.pattern_len, f"Aug Degree [{aug_degree}] should be <= K [{self.K}] - pattern_len [{distrib.pattern_len}]" 99 | 100 | self.num_symbols = distrib.num_tokens 101 | self.aug_degree = aug_degree 102 | 103 | # mags = torch.rand(args.distri.num_tokens)*3 + 1 104 | # 105 | mags = torch.ones(self.num_symbols) 106 | # Scale down the first half of the symbols by mag_split and scale up the second half. 107 | mags[:self.num_symbols//2] /= mag_split 108 | mags[self.num_symbols//2:] *= mag_split 109 | # mags = torch.rand(args.distri.num_tokens) * args.distri.mag_sigma 110 | self.mags = mags 111 | log.info(f"mags: {self.mags}") 112 | 113 | if d is None: 114 | d = self.num_symbols 115 | self.d = d 116 | 117 | if self.d == self.num_symbols: 118 | # The i-th column is the one-hot embedding of the i-th symbol. 119 | self.symbol_embedding = torch.eye(self.d) 120 | else: 121 | # Random unit-norm (generally non-orthogonal) embedding vectors. 122 | log.info(f"Generating non-orthogonal embeddings.
d = {self.d}, #tokens = {self.num_symbols}") 123 | embeds = torch.randn(self.d, self.num_symbols) 124 | embeds = embeds / embeds.norm(dim=0, keepdim=True) 125 | self.symbol_embedding = embeds 126 | 127 | def _ground_symbol(self, a): 128 | # Replace a wildcard (-1) with a uniformly random symbol; keep concrete symbols as they are. 129 | return a if a != -1 else random.randint(0, self.num_symbols - 1) 130 | 131 | def _ground_tokens(self, tokens): 132 | return [ [self._ground_symbol(a) for a in token] for token in tokens ] 133 | 134 | def _change_wildcard_tokens(self, tokens_with_wildcard, ground_tokens): 135 | # For each sample, resample up to aug_degree of its wildcard positions. 136 | ground_tokens2 = [] 137 | for token_with_wildcard, ground_token in zip(tokens_with_wildcard, ground_tokens): 138 | wildcard_indices = [ i for i, t in enumerate(token_with_wildcard) if t == -1 ] 139 | random.shuffle(wildcard_indices) 140 | 141 | ground_token2 = list(ground_token) 142 | for idx in wildcard_indices[:self.aug_degree]: 143 | # Resample this position; the new symbol may coincide with the old one. 144 | ground_token2[idx] = self._ground_symbol(-1) 145 | 146 | ground_tokens2.append(ground_token2) 147 | 148 | return ground_tokens2 149 | 150 | def _symbol2embedding(self, tokens): 151 | # Map symbols to their magnitude-scaled embeddings.
152 | x = torch.FloatTensor(len(tokens), self.K, self.d) 153 | # For each sample in the batch 154 | for i, token in enumerate(tokens): 155 | # For each receptive field 156 | for j, a in enumerate(token): 157 | x[i, j, :] = self.symbol_embedding[:, a] * self.mags[a] 158 | return x 159 | 160 | def set_batchsize(self, batchsize): 161 | self.batchsize = batchsize 162 | 163 | def __iter__(self): 164 | while True: 165 | tokens = self.distrib.sample(self.batchsize) 166 | ground_tokens1 = self._ground_tokens(tokens) 167 | # ground_tokens2 = self._ground_tokens(tokens) 168 | ground_tokens2 = self._change_wildcard_tokens(tokens, ground_tokens1) 169 | 170 | x1 = self._symbol2embedding(ground_tokens1) 171 | x2 = self._symbol2embedding(ground_tokens2) 172 | 173 | yield x1, x2, dict(ground_tokens1=ground_tokens1, ground_tokens2=ground_tokens2, tokens=tokens) 174 | 175 | from torchvision import transforms, datasets 176 | from torch.utils.data.dataloader import DataLoader 177 | 178 | def get_mnist_transform(args): 179 | trans = [] 180 | if args.dataset_use_aug: 181 | strength = args.dataset_use_aug_strength 182 | trans.append(transforms.RandomResizedCrop((28,28), scale=(1.0-strength, 1.0), ratio=(1.0-strength, 1.0+strength))) 183 | trans.append(transforms.ToTensor()) 184 | return transforms.Compose(trans) 185 | 186 | class MultiViewDataInjector(object): 187 | def __init__(self, transforms): 188 | self.transforms = transforms 189 | 190 | def __call__(self, sample): 191 | output = [transform(sample) for transform in self.transforms] 192 | return output 193 | 194 | 195 | # MNIST generator 196 | class MNISTGenerator: 197 | def __init__(self, args): 198 | transform = get_mnist_transform(args) 199 | self.train_dataset = datasets.MNIST(args.dataset_path, train=True, download=True, transform=MultiViewDataInjector([transform, transform])) 200 | self.train_loader = DataLoader(self.train_dataset, batch_size=args.batchsize, num_workers=0, drop_last=True, shuffle=True) 201 | self.K_side = 2 
202 | self.d_side = 14 203 | self.d = self.d_side * self.d_side 204 | self.K = self.K_side * self.K_side 205 | 206 | def __iter__(self): 207 | while True: 208 | for (x1s, x2s), labels in self.train_loader: 209 | # Split each 28x28 image into K = K_side^2 patches, each flattened to d = d_side^2 pixels 210 | #x1s = x1s.view(-1, 1, 7, 4, 7, 4).permute(0, 1, 3, 5, 2, 4).reshape(-1, self.K, self.d) 211 | #x2s = x2s.view(-1, 1, 7, 4, 7, 4).permute(0, 1, 3, 5, 2, 4).reshape(-1, self.K, self.d) 212 | x1s = x1s.view(-1, 1, self.K_side, self.d_side, self.K_side, self.d_side).permute(0, 1, 2, 4, 3, 5).reshape(-1, self.K, self.d) 213 | x2s = x2s.view(-1, 1, self.K_side, self.d_side, self.K_side, self.d_side).permute(0, 1, 2, 4, 3, 5).reshape(-1, self.K, self.d) 214 | yield x1s, x2s, dict(labels=labels) 215 | 216 | print("Dataset end, restart (and reshuffle)") 217 | -------------------------------------------------------------------------------- /ssl/real-dataset/check_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | logdir=$1 4 | shift 5 | run_analysis=$1 6 | shift 7 | 8 | if [ "$run_analysis" -eq "1" ]; then 9 | python ~/tools2/analyze.py "$logdir" --num_process 32 10 | fi 11 | 12 | echo python ~/tools2/stats.py "$logdir" 13 | python ~/tools2/stats.py "$logdir" --groups / "$@" 14 | -------------------------------------------------------------------------------- /ssl/real-dataset/config/bn_gen.yaml: -------------------------------------------------------------------------------- 1 | N: 1000 2 | batchsize: 128 3 | T: 0.1 4 | seed: 1 5 | 6 | model: 7 | _target_: "bn_gen.Model" 8 | multi: null # if null, then we determine this with beta, the degree of over-parameterization.
9 | # output dimension of second layer (right before contrastive learning) 10 | output_d: 20 11 | output_nonlinearity: false 12 | 13 | w1_bias: false 14 | activation: relu 15 | per_layer_normalize: false 16 | shared_low_layer: false 17 | 18 | bn_spec: 19 | use_bn: true 20 | backprop_mean: true 21 | backprop_var: true 22 | 23 | # Data distribution: "AA**" means that the pattern is AA at the beginning, followed by other random patterns. 24 | #distributions = [ "AB***", "*BC**", "**CDE", "A****" ] 25 | #distributions = [ "A****", "*B***", "**CDE"] 26 | #distributions = [ "CA***", "*BC**", "C**DA", "C*E**" ] 27 | #distributions = [ "ABC**", "*ABC*", "**ABC" ] 28 | #distributions = [ "ABC**", "*ABC*" ] 29 | distri: 30 | specific: null # "CA***-*BC**-C**DA-C*E**" 31 | # Number of locations. E.g., for CA***, num_loc = 5 32 | num_loc: 10 33 | # Number of tokens at each location, e.g., 8 = ABCDEFGH 34 | # num_tokens / num_symbols 35 | num_tokens: 8 36 | # Number of tokens allowed at each location. 37 | num_tokens_per_pos: 2 38 | 39 | # For pattern [ "CA***", "*BC**", "C**DA", "C*E**" ], we have pattern_cnt = 4, pattern_len = 5 40 | pattern_cnt: 10 41 | pattern_len: 5 42 | 43 | generator: 44 | _target_: "bn_gen.Generator" 45 | # Number of wildcards to be replaced in the augmentation. 46 | aug_degree: 5 47 | mag_split: 1 48 | # Embedding dimension. If null, it equals num_tokens and all embedding vectors are (orthogonal) unit vectors. 49 | # Otherwise, non-orthogonal embedding vectors are created.
50 | d: null 51 | 52 | dataset_path: "/checkpoint/yuandong/datasets" 53 | dataset_use_aug: false 54 | dataset_use_aug_strength: 0.9 55 | dataset: null 56 | 57 | l2_type: regular 58 | loss_type: infoNCE 59 | aug: true 60 | beta: 1 61 | 62 | similarity: dotprod # or negdist 63 | 64 | niter: 5000 65 | 66 | opt: 67 | lr: 0.01 68 | momentum: 0.9 69 | wd: 5e-3 -------------------------------------------------------------------------------- /ssl/real-dataset/config/byol_config.yaml: -------------------------------------------------------------------------------- 1 | network: 2 | name: resnet18 3 | pretrained_path: None 4 | conv_spec: 5 | variant: regular 6 | freq: 100 7 | resample_ratio: 0.05 8 | layer_involved: conv1-conv2 9 | include_first_conv: false 10 | 11 | bn_spec: 12 | enable_bn1: True 13 | enable_bn2: True 14 | bn_variant: regular # either regular/no_affine/no_proj/no_affine_custom/proj_only_mean/proj_only_var 15 | 16 | projection_head: 17 | mlp_hidden_size: 512 18 | projection_size: 128 19 | 20 | predictor_head: 21 | mlp_hidden_size: 512 22 | projection_size: 128 23 | 24 | aug: 25 | jitter: 1 26 | prob_hflip: 0.5 27 | prob_grayscale: 0.2 28 | blur_sz: 0.1 29 | 30 | # Can be either simclr or byol 31 | method: byol 32 | 33 | trainer: 34 | batch_size: 128 35 | save_per_epoch: 20 36 | m: 0.996 # momentum update 37 | # simply divide 38 | checkpoint_interval: 5000 39 | max_epochs: 100 40 | num_workers: 16 41 | target_noise: 0.005 42 | train_predictor: true 43 | has_predictor: true 44 | 45 | projector_params: 46 | has_bias: true 47 | has_relu: true 48 | custom_nz: null 49 | normalization: bn 50 | has_bn_affine: false 51 | additional_bn_at_input: false 52 | custom_bn: 53 | mean: normal 54 | std: normal 55 | 56 | predictor_params: 57 | has_bias: true 58 | has_relu: true 59 | custom_nz: null 60 | normalization: bn 61 | has_bn_affine: false 62 | additional_bn_at_input: false 63 | custom_bn: 64 | mean: normal 65 | std: normal 66 | 67 | projector_same_as_predictor: false 68 
| predictor_init: 69 | low: null 70 | high: null 71 | 72 | # Only use in simclr 73 | nce_loss: 74 | temperature: 0.5 75 | use_cosine_similarity: True 76 | beta: 0 77 | add_one_in_neg: false 78 | loss_type: default # Can be exact_cov, dual, dual2, dual_backprop, dual_lowrank 79 | exact_cov_unaug_sim: false 80 | alpha_type: exp # or poly 81 | alpha_exponent: 1 # 1 = square loss. 82 | alpha_eps: 0.1 83 | inverse_exponent: 1 84 | low_rank: null 85 | 86 | l2_reg_type: regular 87 | 88 | grad_combination_margin: null 89 | 90 | # Can be symmetric, diagonal, onehalfeig, symmetric_norm, solve, minimal_space, corr, directcopy 91 | predictor_reg: "symmetric" 92 | predictor_freq: 0 93 | predictor_rank: 0.5 94 | predictor_eig: 0.5 95 | predictor_eps: 1e-5 96 | 97 | dyn_time: null 98 | dyn_eps: 0.0 99 | dyn_eps_inside: false 100 | dyn_reg: null 101 | dyn_zero_mean: false 102 | # null = average per predictor_freq 103 | dyn_lambda: 0.8 104 | dyn_psd: null 105 | dyn_bn: false 106 | dyn_diagonalize: false 107 | dyn_sym: true 108 | dyn_convert: 2 109 | dyn_noise: null 110 | 111 | corr_collect: false 112 | use_l2_normalization: true 113 | 114 | # Predictor weight decay. 115 | predictor_wd: null 116 | 117 | balance_type: clamp 118 | n_corr: 2 119 | 120 | # When predictor_reg == "solve", how W should be obtained, left or right 121 | solve_direction: "left" 122 | 123 | # call stuff that will be called if rand_pred_n_epoch > 0 before initialization. 
124 | init_rand_pred: false 125 | 126 | rand_pred_n_epoch: 0 127 | rand_pred_n_iter: 0 128 | # Can be "all", "top", "bottom" 129 | rand_pred_reg: "all" 130 | 131 | noise_blend: 0.0 132 | 133 | # Without predictor, just use order of variance 134 | use_order_of_variance: False 135 | corr_eigen_decomp: True 136 | 137 | use_optimizer: sgd 138 | 139 | predictor_optimizer_same: true 140 | 141 | optimizer: 142 | params: 143 | lr: 0.03 144 | momentum: 0.9 145 | weight_decay: 0.0004 146 | 147 | predictor_optimizer: 148 | params: 149 | lr: 0.03 150 | momentum: 0.9 151 | weight_decay: 0.0004 152 | 153 | dataset_path: /checkpoint/yuandong/datasets 154 | # Can be stl10 or cifar10 155 | dataset: stl10 156 | 157 | seed: 1 158 | gpu: 0 159 | 160 | eval_after_each_epoch: true 161 | 162 | githash: 163 | sweep_filename: 164 | 165 | test: 166 | exp_name_list: "" 167 | load_epoch_list: [] 168 | batch_size: 512 169 | -------------------------------------------------------------------------------- /ssl/real-dataset/config/hydra/launcher/submitit.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.launcher 2 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher 3 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 4 | timeout_min: 100000 5 | cpus_per_task: 20 6 | gpus_per_node: 1 7 | nodes: 1 8 | num_gpus: 1 9 | mem_gb: 0 10 | time: 4320 11 | name: ${hydra.job.name} 12 | partition: learnfair 13 | comment: null 14 | constraint: null 15 | exclude: null 16 | signal_delay_s: 120 17 | max_num_timeout: 0 18 | 19 | -------------------------------------------------------------------------------- /ssl/real-dataset/config/relu_2layer.yaml: -------------------------------------------------------------------------------- 1 | N: 1000 2 | batchsize: 128 3 | T: 0.1 4 | multi: 10 5 | seed: 1 6 | 7 | # output dimension of second layer (right before contrastive learning) 8 | d_output: 20 9 | d_hidden: 10 10 | 11 | # The 
first dimension is always the highest mag. 12 | distri: 13 | d: 10 14 | mag_start: 10 15 | mag_end: 1 16 | std_aug: 0.5 17 | 18 | w1_bias: false 19 | l2_type: regular 20 | loss_type: infoNCE 21 | activation: relu 22 | 23 | similarity: dotprod # or negdist 24 | normalization: none # none, perlayer, perfilter 25 | 26 | niter: 5000 27 | 28 | use_bn: true 29 | 30 | opt: 31 | lr: 0.01 32 | momentum: 0.9 33 | wd: 5e-3 -------------------------------------------------------------------------------- /ssl/real-dataset/config/sa.yaml: -------------------------------------------------------------------------------- 1 | niter: 3000 2 | seed: 1 3 | 4 | # number of tokens 5 | M: 10 6 | # Length of the sequence 7 | L: 10 8 | # dimension of embedding 9 | d: 5 10 | # batchsize 11 | batchsize: 128 12 | 13 | opt: 14 | lr: 0.1 15 | momentum: 0.9 16 | wd: 5e-4 -------------------------------------------------------------------------------- /ssl/real-dataset/config/sa_linear.yaml: -------------------------------------------------------------------------------- 1 | niter: 1000 2 | seed: 1 3 | 4 | # Length of seq 5 | L: 10 6 | # dimension of embedding 7 | d: 5 8 | # batchsize 9 | batchsize: 128 10 | 11 | opt: 12 | lr: 0.1 13 | momentum: 0.9 14 | wd: 5e-4 -------------------------------------------------------------------------------- /ssl/real-dataset/config/test.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | -------------------------------------------------------------------------------- /ssl/real-dataset/data/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | class GaussianBlur(object): 8 | """blur a single image on CPU""" 9 | 10 | def __init__(self, kernel_size): 11 | radias = kernel_size // 2 12 | kernel_size = radias * 2 + 1 13 | self.blur_h = nn.Conv2d(3, 3, 
kernel_size=(kernel_size, 1), 14 | stride=1, padding=0, bias=False, groups=3) 15 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size), 16 | stride=1, padding=0, bias=False, groups=3) 17 | self.k = kernel_size 18 | self.r = radias 19 | 20 | self.blur = nn.Sequential( 21 | nn.ReflectionPad2d(radias), 22 | self.blur_h, 23 | self.blur_v 24 | ) 25 | 26 | self.pil_to_tensor = transforms.ToTensor() 27 | self.tensor_to_pil = transforms.ToPILImage() 28 | 29 | def __call__(self, img): 30 | img = self.pil_to_tensor(img).unsqueeze(0) 31 | 32 | sigma = np.random.uniform(0.1, 2.0) 33 | x = np.arange(-self.r, self.r + 1) 34 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma)) 35 | x = x / x.sum() 36 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1) 37 | 38 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1)) 39 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k)) 40 | 41 | with torch.no_grad(): 42 | img = self.blur(img) 43 | img = img.squeeze() 44 | 45 | img = self.tensor_to_pil(img) 46 | 47 | return img 48 | -------------------------------------------------------------------------------- /ssl/real-dataset/data/multi_view_data_injector.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import transforms 2 | 3 | 4 | class MultiViewDataInjector(object): 5 | def __init__(self, *args): 6 | self.transforms = args[0] 7 | self.random_flip = transforms.RandomHorizontalFlip() 8 | 9 | def __call__(self, sample, *with_consistent_flipping): 10 | if with_consistent_flipping: 11 | sample = self.random_flip(sample) 12 | output = [transform(sample) for transform in self.transforms] 13 | return output -------------------------------------------------------------------------------- /ssl/real-dataset/data/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import transforms 2 | from data.gaussian_blur import GaussianBlur 3 | 4 | 5 | def 
get_simclr_data_transforms_train(dataset_name, args): 6 | s = args["jitter"] 7 | 8 | if dataset_name == "stl10": 9 | input_shape = (96,96,3) 10 | # Get a set of data augmentation transformations as described in the SimCLR paper. 11 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 12 | return transforms.Compose( 13 | [ 14 | transforms.RandomResizedCrop(size=input_shape[0]), 15 | transforms.RandomHorizontalFlip(p=args["prob_hflip"]), 16 | transforms.RandomApply([color_jitter], p=0.8), 17 | transforms.RandomGrayscale(p=args["prob_grayscale"]), 18 | GaussianBlur(kernel_size=int(args["blur_sz"] * input_shape[0])), 19 | transforms.ToTensor() 20 | ]) 21 | 22 | elif dataset_name in ["cifar10", "cifar100"]: 23 | # No Gaussian blur since CIFAR-10/100 images are small. 24 | return transforms.Compose( 25 | [ 26 | transforms.RandomResizedCrop(32), 27 | transforms.RandomHorizontalFlip(p=args["prob_hflip"]), 28 | transforms.RandomApply([transforms.ColorJitter(0.4 * s, 0.4 * s, 0.4 * s, 0.1 * s)], p=0.8), 29 | transforms.RandomGrayscale(p=args["prob_grayscale"]), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 32 | ]) 33 | else: 34 | raise RuntimeError(f"unknown dataset: {dataset_name}") 35 | 36 | def get_simclr_data_transforms_test(dataset_name): 37 | if dataset_name in ["cifar10", "cifar100"]: 38 | return transforms.Compose([ 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]) 41 | elif dataset_name == "stl10": 42 | return transforms.Compose([transforms.ToTensor()]) 43 | else: 44 | raise RuntimeError(f"unknown dataset: {dataset_name}") 45 | -------------------------------------------------------------------------------- /ssl/real-dataset/linear_feature_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import yaml 4 | from torchvision import transforms, datasets 5 |
import torchvision 6 | import numpy as np 7 | import os 8 | from sklearn import preprocessing 9 | from torch.utils.data.dataloader import DataLoader 10 | from models.resnet_base_network import ResNet18 11 | from data.transforms import get_simclr_data_transforms_test 12 | import hydra 13 | import glob 14 | 15 | class LogisticRegression(torch.nn.Module): 16 | def __init__(self, input_dim, output_dim): 17 | super(LogisticRegression, self).__init__() 18 | self.linear = torch.nn.Linear(input_dim, output_dim) 19 | 20 | def forward(self, x): 21 | return self.linear(x) 22 | 23 | 24 | def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test): 25 | train = torch.utils.data.TensorDataset(X_train, y_train) 26 | train_loader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True) 27 | 28 | test = torch.utils.data.TensorDataset(X_test, y_test) 29 | test_loader = torch.utils.data.DataLoader(test, batch_size=512, shuffle=False) 30 | return train_loader, test_loader 31 | 32 | 33 | def get_features_from_encoder(encoder, loader, device): 34 | x_train = [] 35 | y_train = [] 36 | 37 | # get the features from the pre-trained model 38 | with torch.no_grad(): 39 | for i, (x, y) in enumerate(loader): 40 | x = x.to(device) 41 | feature_vector = encoder(x) 42 | x_train.extend(feature_vector) 43 | y_train.extend(y.numpy()) 44 | 45 | x_train = torch.stack(x_train) 46 | y_train = torch.tensor(y_train) 47 | return x_train, y_train 48 | 49 | import logging 50 | log = logging.getLogger(__file__) 51 | 52 | class Evaluator: 53 | def __init__(self, dataset, dataset_path, batch_size): 54 | data_transforms = get_simclr_data_transforms_test(dataset) 55 | if dataset == "stl10": 56 | train_dataset = datasets.STL10(dataset_path, split='train', download=False, 57 | transform=data_transforms) 58 | test_dataset = datasets.STL10(dataset_path, split='test', download=False, 59 | transform=data_transforms) 60 | elif dataset == "cifar10": 61 | train_dataset = datasets.CIFAR10(dataset_path, 
train=True, download=False, 62 | transform=data_transforms) 63 | test_dataset = datasets.CIFAR10(dataset_path, train=False, download=False, 64 | transform=data_transforms) 65 | elif dataset == "cifar100": 66 | train_dataset = datasets.CIFAR100(dataset_path, train=True, download=False, 67 | transform=data_transforms) 68 | test_dataset = datasets.CIFAR100(dataset_path, train=False, download=False, 69 | transform=data_transforms) 70 | else: 71 | raise RuntimeError(f"Unknown dataset! {dataset}") 72 | 73 | log.info(f"Input shape: {train_dataset[0][0].shape}") 74 | 75 | self.stl_train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, drop_last=False, shuffle=True) 76 | self.stl_test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=0, drop_last=False, shuffle=True) 77 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 78 | self.batch_size = batch_size 79 | 80 | def eval_model(self, encoder, save_path=None, num_epoch=50): 81 | remove_projection_head = True 82 | if remove_projection_head: 83 | output_feature_dim = encoder.feature_dim 84 | encoder = torch.nn.Sequential(*list(encoder.children())[:-1]) 85 | else: 86 | output_feature_dim = encoder.projetion.net[-1].out_features 87 | 88 | device = self.device 89 | encoder = encoder.to(device) 90 | 91 | encoder.eval() 92 | x_train, y_train = get_features_from_encoder(encoder, self.stl_train_loader, device) 93 | x_test, y_test = get_features_from_encoder(encoder, self.stl_test_loader, device) 94 | if save_path: 95 | np.savez(save_path, x_train.cpu().numpy(), x_test.cpu().numpy(), y_train.cpu().numpy(), 96 | y_test.cpu().numpy()) 97 | 98 | if len(x_train.shape) > 2: 99 | x_train = torch.mean(x_train, dim=[2, 3]) 100 | x_test = torch.mean(x_test, dim=[2, 3]) 101 | 102 | # log.info("Training data shape:", x_train.shape, y_train.shape) 103 | # log.info("Testing data shape:", x_test.shape, y_test.shape) 104 | 105 | x_train = x_train.cpu().numpy() 106 | x_test = x_test.cpu().numpy() 
107 | scaler = preprocessing.StandardScaler() 108 | scaler.fit(x_train) 109 | x_train = scaler.transform(x_train).astype(np.float32) 110 | x_test = scaler.transform(x_test).astype(np.float32) 111 | 112 | train_loader, test_loader = create_data_loaders_from_arrays(torch.from_numpy(x_train), y_train, 113 | torch.from_numpy(x_test), y_test) 114 | 115 | logreg = LogisticRegression(output_feature_dim, train_loader.dataset.tensors[1].max() + 1) 116 | logreg = logreg.to(device) 117 | 118 | optimizer = torch.optim.Adam(logreg.parameters(), lr=3e-4, weight_decay=1e-3) 119 | criterion = torch.nn.CrossEntropyLoss() 120 | eval_every_n_epochs = 1 121 | 122 | best_acc = 0. 123 | for epoch in range(num_epoch): 124 | for x, y in train_loader: 125 | x = x.to(device) 126 | y = y.to(device) 127 | 128 | optimizer.zero_grad() 129 | 130 | logits = logreg(x) 131 | predictions = torch.argmax(logits, dim=1) 132 | 133 | loss = criterion(logits, y) 134 | 135 | loss.backward() 136 | optimizer.step() 137 | 138 | if epoch % eval_every_n_epochs == 0: 139 | correct = 0 140 | total = 0 141 | for x, y in test_loader: 142 | x = x.to(device) 143 | y = y.to(device) 144 | 145 | logits = logreg(x) 146 | predictions = torch.argmax(logits, dim=1) 147 | 148 | total += y.size(0) 149 | correct += (predictions == y).sum().item() 150 | 151 | acc = 100 * correct / total 152 | # log.info(f"Epoch {epoch} Testing accuracy: {acc}") 153 | if acc > best_acc: 154 | best_acc = acc 155 | return best_acc 156 | 157 | 158 | def linear_eval(dataset, dataset_path, batch_size, exp_name_list, load_epoch_list, default_network_params=None, default_trainer_params=None): 159 | evaluator = Evaluator(dataset, dataset_path, batch_size) 160 | 161 | result_dict = {} 162 | result_list = [] 163 | 164 | for exp_name in exp_name_list: 165 | arg_file = os.path.join(f"{exp_name}", "args.pt") 166 | if os.path.exists(arg_file): 167 | args = torch.load(arg_file) 168 | network_params = args["network"] 169 | trainer_params = args["trainer"] 170 | 
else: 171 | network_params = default_network_params 172 | trainer_params = default_trainer_params 173 | 174 | log.info(network_params) 175 | 176 | if len(load_epoch_list) == 0: 177 | # Evaluate all models saved in the folder. 178 | models = [ model for model in glob.glob(os.path.join(exp_name, "checkpoints", "model_*.pth")) ] 179 | else: 180 | models = [ f'{exp_name}/checkpoints/model_{str(epoch).zfill(3)}.pth' for epoch in load_epoch_list ] 181 | 182 | for load_path in models: 183 | save_path = None 184 | 185 | load_params = torch.load( 186 | os.path.join(load_path), 187 | map_location=torch.device(evaluator.device) 188 | ) 189 | encoder = ResNet18(dataset=dataset, options=trainer_params["projector_params"], **network_params) 190 | encoder.load_state_dict(load_params['online_network_state_dict']) 191 | log.info("Load from {}.".format(load_path)) 192 | 193 | best_acc = evaluator.eval_model(encoder, save_path=save_path) 194 | 195 | log.info(f"{load_path}: Best Acc {best_acc}") 196 | result_dict[load_path] = best_acc 197 | result_list.append(best_acc) 198 | 199 | for key in result_dict: 200 | log.info(f"{key}: {result_dict[key]}") 201 | log.info(f"mean acc: {np.mean(result_list)}, std: {np.std(result_list)}") 202 | return result_dict 203 | 204 | 205 | 206 | @hydra.main("config/byol_config.yaml") 207 | def main(args): 208 | # root_dir = '/private/home/lantaoyu1/projects/PyTorch-BYOL/runs_09_03' 209 | # root_dir = '/checkpoint/lantaoyu1/PyTorch-BYOL/runs' 210 | 211 | # exp_name_list = [f"OriginalBN-RandomTargetInit_{i}" for i in range(5)] 212 | # exp_name_list = ['byol-FixBN'] 213 | # exp_name_list = [f"ZeroMean-Std_seed-0_reinit-{i}" for i in [1, 2, 3, 4, 5, 10, 15, 20]] 214 | 215 | default_network_params = dict(args.network) 216 | default_trainer_params = dict(args.trainer) 217 | exp_name_list = args.test.exp_name_list.split(",") 218 | load_epoch_list = args.test.load_epoch_list 219 | linear_eval(args.dataset, args.dataset_path, args.test.batch_size, exp_name_list,
load_epoch_list, 220 | default_network_params=default_network_params, 221 | default_trainer_params=default_trainer_params) 222 | 223 | if __name__ == "__main__": 224 | main() 225 | 226 | 227 | ''' 228 | BYOL-AE 229 | 000: 45.86 230 | 010: 58.15 231 | 020: 68.26 232 | 030: 71.3 233 | 040: 72.76 234 | 050: 73.35 235 | 060: 75.23 236 | 070: 75.18 237 | 238 | BYOL 239 | 000: 43.9 240 | 010: 61.61 241 | 020: 66.55 242 | 030: 68.53 243 | 040: 70.08 244 | 050: 71.23 245 | 060: 72.95 246 | 070: 72.15 247 | ''' 248 | 249 | """ 250 | # Epoch 90 251 | byol-ZeroMean-Std: 77.9875 252 | byol-ZeroMean-StdDetach: 70.5125 253 | byol-Std: 69.4625 254 | byol-ZeroMean: 65.7125 255 | byol-ZeroMeanDetach: 24.0875 256 | byol-StdDetach: 50.325 257 | byol-ZeroMeanDetach-Std: 22.75 258 | byol-ZeroMeanDetach-StdDetach: 44.6375 259 | 260 | # Epoch 40 261 | byol-ZeroMean-Std: 73.225 262 | byol-ZeroMean-StdDetach: 72.575 263 | byol-Std: 62.9875 264 | byol-ZeroMean: 59.9125 265 | byol-ZeroMeanDetach: 31.375 266 | byol-StdDetach: 55.0 267 | byol-ZeroMeanDetach-Std: 40.325 268 | byol-ZeroMeanDetach-StdDetach: 35.35 269 | """ 270 | -------------------------------------------------------------------------------- /ssl/real-dataset/loss/nt_xent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def quadratic_assignment(dists, tau): 5 | # Each row of dists needs to be converted to alpha.
6 | N = dists.size(1) 7 | sorted_dists, sorted_indices = dists.sort(1) 8 | moving_sum = sorted_dists.cumsum(dim=1) 9 | indices = torch.arange(0, N).to(moving_sum.device) 10 | moving_avg = (moving_sum + tau) / (indices[None,:] + 1) 11 | indices_to_keep = sorted_dists < moving_avg 12 | first_indices = torch.logical_xor(indices_to_keep[:, 1:], indices_to_keep[:, :-1]).float() @ indices[1:].float() 13 | mu = moving_avg.gather(1, first_indices.long().view(dists.size(0), -1)).squeeze() 14 | return (mu[:, None] - dists).clamp(min=0) / tau 15 | 16 | def inverse_assignment(dists, tau, inverse_exponent): 17 | return 1 / (dists + tau).pow(inverse_exponent) 18 | 19 | class NTXentLoss(torch.nn.Module): 20 | 21 | def __init__(self, device, batch_size, params): 22 | super(NTXentLoss, self).__init__() 23 | self.batch_size = batch_size 24 | # temperature, use_cosine_similarity, beta, add_one_in_neg, loss_type, exact_cov_unaug_sim 25 | self.params = params 26 | 27 | self.device = device 28 | self.softmax = torch.nn.Softmax(dim=-1) 29 | self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool) 30 | self.mask_samples_small = self._get_correlated_mask_small().type(torch.bool) 31 | self.similarity_function = self._get_similarity_function(params["use_cosine_similarity"]) 32 | self.criterion = torch.nn.CrossEntropyLoss(reduction="sum") 33 | 34 | def need_unaug_data(self): 35 | return self.params["exact_cov_unaug_sim"] 36 | 37 | def _get_similarity_function(self, use_cosine_similarity): 38 | if use_cosine_similarity: 39 | self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1) 40 | return self._cosine_simililarity 41 | else: 42 | return self._dot_simililarity 43 | 44 | def _get_correlated_mask(self): 45 | diag = np.eye(2 * self.batch_size) 46 | l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size) 47 | l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size) 48 | mask = torch.from_numpy((diag + l1 + l2)) 49 | mask = (1 - 
mask).type(torch.bool) 50 | return mask.to(self.device) 51 | 52 | def _get_correlated_mask_small(self): 53 | diag = np.eye(self.batch_size) 54 | mask = torch.from_numpy(diag) 55 | mask = (1 - mask).type(torch.bool) 56 | return mask.to(self.device) 57 | 58 | @staticmethod 59 | def _dot_simililarity(x, y): 60 | v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2) 61 | # x shape: (2N, 1, C) 62 | # y shape: (1, C, 2N) 63 | # v shape: (2N, 2N) 64 | return v 65 | 66 | def _cosine_simililarity(self, x, y): 67 | # x shape: (2N, 1, C) 68 | # y shape: (1, 2N, C) 69 | # v shape: (2N, 2N) 70 | v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0)) 71 | return v 72 | 73 | def _process_pairwise_dist(self, dist): 74 | temperature = self.params["temperature"] 75 | alpha_exponent = self.params["alpha_exponent"] 76 | alpha_type = self.params["alpha_type"] 77 | inverse_exponent = self.params["inverse_exponent"] 78 | 79 | w = dist.pow(alpha_exponent) 80 | 81 | if alpha_type == "exp": 82 | w = (-w / temperature).exp() 83 | elif alpha_type == "quadratic": 84 | w = quadratic_assignment(w, temperature) 85 | elif alpha_type == "inverse": 86 | w = inverse_assignment(w, temperature, inverse_exponent) 87 | else: 88 | raise RuntimeError(f"Unknown alpha_type = {alpha_type}") 89 | 90 | return w 91 | 92 | def forward(self, zis, zjs, zs): 93 | # Two towers. For each of the N samples, it has zj and zi. 94 | representations = torch.cat([zjs, zis], dim=0) 95 | similarity_matrix = self.similarity_function(representations, representations) 96 | 97 | # Extract the scores of the positive pairs 98 | l_pos = torch.diag(similarity_matrix, self.batch_size) 99 | r_pos = torch.diag(similarity_matrix, -self.batch_size) 100 | # 2N positive pairs 101 | positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1) 102 | 103 | # 2N * (2N - 2) negative samples. 104 | # The i-th row holds the 2N - 2 negative samples for the i-th sample.
105 | negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1) 106 | 107 | temperature = self.params["temperature"] 108 | beta = self.params["beta"] 109 | loss_type = self.params["loss_type"] 110 | alpha_eps = self.params["alpha_eps"] 111 | low_rank = self.params["low_rank"] 112 | 113 | if loss_type == "exact_cov": 114 | # 1 - sim = dist 115 | r_neg = 1 - negatives 116 | r_pos = 1 - positives 117 | 118 | num_negative = negatives.size(1) 119 | 120 | # Similarity matrix for unaugmented data. 121 | if self.need_unaug_data() and zs is not None: 122 | similarity_matrix2 = self.similarity_function(zs, zs) 123 | negatives_unaug = similarity_matrix2[self.mask_samples_small].view(self.batch_size, -1) 124 | r_neg_unaug = 1 - negatives_unaug 125 | w = (-r_neg_unaug.detach() / temperature).exp() 126 | # Tile into a 2 x 2 block: (N, N - 1) -> (2N, 2N - 2), matching the shape of negatives. 127 | w = torch.cat([w, w], dim=0) 128 | w = torch.cat([w, w], dim=1) 129 | else: 130 | w = (-r_neg.detach() / temperature).exp() 131 | 132 | w = w / (1 + w) / temperature / num_negative 133 | # Then we construct the loss function.
134 | w_pos = w.sum(dim=1, keepdim=True) 135 | loss = (w_pos * r_pos - (w * r_neg).sum(dim=1)).mean() 136 | loss_intra = beta * (w_pos * r_pos).mean() 137 | 138 | elif loss_type == "dual": 139 | # 1 - sim = dist 140 | r_neg = 1 - negatives 141 | r_pos = 1 - positives 142 | 143 | # w only depends on r_neg 144 | w = self._process_pairwise_dist(r_neg.detach()) 145 | 146 | # The below is actually mean(w * (r_pos - r_neg)) 147 | w_pos = w.sum(dim=1, keepdim=True) 148 | loss = (w_pos * r_pos - (w * r_neg).sum(dim=1)).mean() 149 | loss_intra = beta * (w_pos * r_pos).mean() 150 | 151 | elif loss_type == "dual2": 152 | # New version of dual 153 | # dist_diff = d_i^2 - d_{ij}^2 154 | dist_sqr = negatives - positives 155 | # w = dist_sqr.detach().pow(alpha_exponent) 156 | 157 | r_neg = 1 - negatives 158 | r_pos = 1 - positives 159 | 160 | # w only depends on r_neg 161 | w = self._process_pairwise_dist(r_neg.detach()) 162 | 163 | # The below is actually mean(w * (r_pos - r_neg)) = mean(w * (negatives - positives)) = mean(w * dist_sqr) 164 | # Get summation. 165 | w_Z = w.sum(dim=1, keepdim=True) 166 | w = w / (w_Z + alpha_eps * w.size(1)) 167 | loss = (w * dist_sqr).sum(dim=1).mean() 168 | loss_intra = 0 169 | 170 | elif loss_type == "dual_per_sample": 171 | # New version of dual 172 | # dist_diff = d_i^2 - d_{ij}^2 173 | dist_sqr = negatives - positives 174 | # w = dist_sqr.detach().pow(alpha_exponent) 175 | 176 | r_neg = 1 - negatives 177 | r_pos = 1 - positives 178 | 179 | # w only depends on r_neg 180 | w = self._process_pairwise_dist(r_neg.detach()) 181 | 182 | w = w.sum(dim=1, keepdim=True) 183 | # Normalize across samples 184 | w = w / w.sum() 185 | # now w is rank one. 
186 | loss = (w * dist_sqr).sum(dim=1).mean() 187 | loss_intra = 0 188 | 189 | elif loss_type == "dual_backprop": 190 | # 1 - sim = dist 191 | r_neg = 1 - negatives 192 | r_pos = 1 - positives 193 | 194 | # w only depends on r_neg 195 | w = self._process_pairwise_dist(r_neg.detach()) 196 | 197 | # The below is actually mean(w * (r_pos - r_neg)) 198 | w_pos = w.sum(dim=1, keepdim=True) 199 | loss = (w_pos * r_pos - (w * r_neg).sum(dim=1)).mean() 200 | loss_intra = beta * (w_pos * r_pos).mean() 201 | 202 | elif loss_type == "dual_lowrank": 203 | # 1 - sim = dist 204 | r_neg = 1 - negatives 205 | r_pos = 1 - positives 206 | 207 | w = (-r_neg.detach() / temperature).exp() 208 | # Do an SVD for the weight. 209 | U, D, V = torch.svd(w,compute_uv=True) 210 | D[low_rank:] = 0 211 | w_low_rank = U @ D.diag() @ V.t() 212 | 213 | ''' 214 | # get approximate low rank decomposition. 215 | sample_importance = w.mean(dim=1) 216 | #sample_importance = sample_importance / sample_importance.sum() 217 | similarity_matrix_low_rank = torch.outer(sample_importance, sample_importance) 218 | w_low_rank = similarity_matrix_low_rank[self.mask_samples_from_same_repr].view(2 * self.batch_size, -1) 219 | ''' 220 | 221 | w_pos_low_rank = w_low_rank.sum(dim=1, keepdim=True) 222 | 223 | # The below is actually mean(w_low_rank * (r_pos - r_neg)) 224 | loss = (w_pos_low_rank * r_pos - (w_low_rank * r_neg).sum(dim=1)).mean() 225 | loss_intra = beta * (w_pos_low_rank * r_pos).mean() 226 | 227 | elif loss_type == "default": 228 | if self.params["add_one_in_neg"]: 229 | all_ones = torch.ones(2 * self.batch_size, 1).to(self.device) 230 | logits = torch.cat((positives, negatives, all_ones), dim=1) 231 | else: 232 | logits = torch.cat((positives, negatives), dim=1) 233 | 234 | logits /= temperature 235 | 236 | labels = torch.zeros(2 * self.batch_size).to(self.device).long() 237 | loss = self.criterion(logits, labels) 238 | 239 | # Make positives stronger than negatives by adding an extra term.
240 | loss_intra = -positives.sum() * beta / temperature 241 | loss /= (1.0 + beta) * 2 * self.batch_size 242 | loss_intra /= (1.0 + beta) * 2 * self.batch_size 243 | 244 | elif loss_type == "quadratic": 245 | loss_intra = -positives.mean() 246 | loss = negatives.mean() 247 | 248 | else: 249 | raise RuntimeError(f"Unknown loss_type = {loss_type}") 250 | 251 | return loss, loss_intra, negatives.detach() 252 | -------------------------------------------------------------------------------- /ssl/real-dataset/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import yaml 6 | from omegaconf import OmegaConf 7 | from torchvision import datasets 8 | from data.multi_view_data_injector import MultiViewDataInjector 9 | from data.transforms import get_simclr_data_transforms_train, get_simclr_data_transforms_test 10 | from models.mlp_head import MLPHead 11 | from models.resnet_base_network import ResNet18 12 | from byol_trainer import BYOLTrainer 13 | from simclr_trainer import SimCLRTrainer 14 | import argparse 15 | import os 16 | import hydra 17 | from linear_feature_eval import linear_eval, Evaluator 18 | import common_utils 19 | 20 | import logging 21 | log = logging.getLogger(__file__) 22 | 23 | def hydra2dict(args): 24 | if args.__class__.__name__ != 'DictConfig': 25 | return args 26 | 27 | args = dict(args) 28 | for k in args.keys(): 29 | args[k] = hydra2dict(args[k]) 30 | 31 | return args 32 | 33 | @hydra.main(config_path="config", config_name="byol_config.yaml") 34 | def main(args): 35 | log.info(common_utils.print_info(args)) 36 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}" 37 | common_utils.set_all_seeds(args.seed) 38 | log.info(common_utils.pretty_print_args(args)) 39 | 40 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 41 | log.info(f"Training with: {device}") 42 | 43 | data_transform = get_simclr_data_transforms_train(args['dataset'], args["aug"]) 44 | 
data_transform_identity = get_simclr_data_transforms_test(args['dataset']) 45 | 46 | if args["dataset"] == "stl10": 47 | train_dataset = datasets.STL10(args.dataset_path, split='train+unlabeled', download=True, 48 | transform=MultiViewDataInjector([data_transform, data_transform, data_transform_identity])) 49 | elif args["dataset"] == "cifar10": 50 | train_dataset = datasets.CIFAR10(args.dataset_path, train=True, download=True, 51 | transform=MultiViewDataInjector([data_transform, data_transform, data_transform_identity])) 52 | elif args["dataset"] == "cifar100": 53 | train_dataset = datasets.CIFAR100(args.dataset_path, train=True, download=True, 54 | transform=MultiViewDataInjector([data_transform, data_transform, data_transform_identity])) 55 | else: 56 | raise RuntimeError(f"Unknown dataset! {args['dataset']}") 57 | 58 | args = hydra2dict(args) 59 | train_params = args["trainer"] 60 | if train_params["projector_same_as_predictor"]: 61 | train_params["projector_params"] = train_params["predictor_params"] 62 | 63 | # online network 64 | online_network = ResNet18(dataset=args["dataset"], options=train_params["projector_params"], **args['network']).to(device) 65 | if torch.cuda.device_count() > 1: 66 | online_network = torch.nn.parallel.DataParallel(online_network) 67 | 68 | pretrained_path = args['network']['pretrained_path'] 69 | if pretrained_path: 70 | try: 71 | load_params = torch.load(pretrained_path, map_location=torch.device(device)) 72 | # Accept either a full checkpoint (weights under 'online_network_state_dict') or a bare state dict. 73 | online_network.load_state_dict(load_params.get('online_network_state_dict', load_params)) 74 | log.info("Load from {}.".format(pretrained_path)) 75 | except FileNotFoundError: 76 | log.info("Pre-trained weights not found.
Training from scratch.") 77 | 78 | # predictor network 79 | if train_params["has_predictor"] and args["method"] == "byol": 80 | predictor = MLPHead(in_channels=args['network']['projection_head']['projection_size'], 81 | **args['network']['predictor_head'], options=train_params["predictor_params"]).to(device) 82 | if torch.cuda.device_count() > 1: 83 | predictor = torch.nn.parallel.DataParallel(predictor) 84 | else: 85 | predictor = None 86 | 87 | # target encoder 88 | target_network = ResNet18(dataset=args["dataset"], options=train_params["projector_params"], **args['network']).to(device) 89 | if torch.cuda.device_count() > 1: 90 | target_network = torch.nn.parallel.DataParallel(target_network) 91 | 92 | params = online_network.parameters() 93 | 94 | # Save network and parameters. 95 | torch.save(args, "args.pt") 96 | 97 | if args["eval_after_each_epoch"]: 98 | evaluator = Evaluator(args["dataset"], args["dataset_path"], args["test"]["batch_size"]) 99 | else: 100 | evaluator = None 101 | 102 | if args["use_optimizer"] == "adam": 103 | optimizer = torch.optim.Adam(params, lr=args['optimizer']['params']["lr"], weight_decay=args["optimizer"]["params"]['weight_decay']) 104 | elif args["use_optimizer"] == "sgd": 105 | optimizer = torch.optim.SGD(params, **args['optimizer']['params']) 106 | else: 107 | raise RuntimeError(f"Unknown optimizer! 
{args['use_optimizer']}") 108 | 109 | if args["predictor_optimizer_same"]: 110 | args["predictor_optimizer"] = args["optimizer"] 111 | predictor_optimizer = None 112 | if predictor and train_params["train_predictor"]: 113 | predictor_optimizer = torch.optim.SGD(predictor.parameters(), **args['predictor_optimizer']['params']) 114 | 115 | # Build the trainer for the selected method (SimCLR or BYOL). 116 | if args["method"] == "simclr": 117 | trainer = SimCLRTrainer(log_dir="./", model=online_network, optimizer=optimizer, evaluator=evaluator, device=device, params=args["trainer"]) 118 | elif args["method"] == "byol": 119 | trainer = BYOLTrainer(log_dir="./", 120 | online_network=online_network, 121 | target_network=target_network, 122 | optimizer=optimizer, 123 | predictor_optimizer=predictor_optimizer, 124 | predictor=predictor, 125 | device=device, 126 | evaluator=evaluator, 127 | **args['trainer']) 128 | else: 129 | raise RuntimeError(f'Unknown method {args["method"]}') 130 | 131 | trainer.train(train_dataset) 132 | 133 | if not args["eval_after_each_epoch"]: 134 | result_eval = linear_eval(args["dataset"], args["dataset_path"], args["test"]["batch_size"], ["./"], []) 135 | torch.save(result_eval, "eval.pth") 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /ssl/real-dataset/main_checkresult.py: -------------------------------------------------------------------------------- 1 | import common_utils 2 | import re 3 | import os 4 | import glob 5 | import torch 6 | 7 | # resnet check. 8 | matcher = re.compile(r"encoder\.layer([0-9])\.([0-9])\.conv([0-9])\.weight") 9 | def get_entropy(model): 10 | res = dict() 11 | for k, v in model["online_network_state_dict"].items(): 12 | key = None 13 | if k == "encoder.conv1.weight": 14 | key = "conv1" 15 | else: 16 | m = matcher.match(k) 17 | if m: 18 | key = f"{m.group(1)}-{m.group(2)}-{m.group(3)}" 19 | 20 | if key is not None: 21 | # Compute the sparsity of each filter.
22 | v = v.view(v.size(0), -1).abs() 23 | v = v / (v.sum(dim=1, keepdim=True) + 1e-8) 24 | entropies = - (v * (v + 1e-8).log()).sum(dim=1) 25 | res["h_" + key] = entropies.mean().item() 26 | 27 | return res 28 | 29 | def edge_energy(f): 30 | dx = f[1:, :, :] - f[:-1, :, :] 31 | dy = f[:, 1:, :] - f[:, :-1, :] 32 | return (dx.pow(2).mean() + dy.pow(2).mean()) / 2 33 | 34 | def edge_energy2(f): 35 | dx = f[2:, 1:-1, :] - f[:-2, 1:-1, :] 36 | dy = f[1:-1, 2:, :] - f[1:-1, :-2, :] 37 | return (dx.pow(2) + dy.pow(2)).sqrt().mean() 38 | 39 | def check_edge_stats(subfolder): 40 | model_files = glob.glob(os.path.join(subfolder, "checkpoints/model_*.pth")) 41 | # Find the latest. 42 | model_files = [ (os.path.getmtime(f), f) for f in model_files ] 43 | all_model_files = sorted(model_files, key=lambda x: x[0]) 44 | 45 | config = common_utils.MultiRunUtil.load_full_cfg(subfolder) 46 | 47 | if len(all_model_files) == 0: 48 | return None 49 | 50 | last_model_file = all_model_files[-1][1] 51 | model = torch.load(last_model_file, map_location=torch.device('cpu')) 52 | 53 | res = dict() 54 | 55 | for layer_name, ws in model["online_network_state_dict"].items(): 56 | key = None 57 | if layer_name == "encoder.conv1.weight": 58 | key = "c1" 59 | else: 60 | m = matcher.match(layer_name) 61 | if m: 62 | key = f"l{m.group(1)}{m.group(2)}c{m.group(3)}" 63 | 64 | if key is not None: 65 | avg_edge_strength = 0 66 | avg_edge_strength_normalized = 0 67 | for k in range(ws.size(0)): 68 | w = ws[k,:].permute(1, 2, 0) 69 | avg_edge_strength += edge_energy2(w) 70 | 71 | # Normalize. 
72 | w = w / (w.pow(2).sum().sqrt() + 1e-6) 73 | avg_edge_strength_normalized += edge_energy2(w) 74 | 75 | avg_edge_strength /= ws.size(0) 76 | avg_edge_strength_normalized /= ws.size(0) 77 | res[key] = avg_edge_strength.item() * 1000 78 | res[key + "_n"] = avg_edge_strength_normalized.item() * 1000 79 | 80 | return res 81 | 82 | _result_matcher = [ 83 | { 84 | "match": re.compile(r"Epoch (\d+): best_acc: ([\d\.]+)"), 85 | "action": [ 86 | [ "acc", "float(m.group(2))" ] 87 | ] 88 | } 89 | ] 90 | 91 | _attr_multirun = { 92 | "result_group" : { 93 | "performance": ("event", _result_matcher), 94 | "entropy": ("func", check_edge_stats), 95 | }, 96 | "default_result_group" : [ "performance" ], 97 | "default_metrics": [ "acc" ], 98 | "specific_options": dict(acc={}), 99 | "common_options" : dict(topk_mean=1, topk=10, descending=True), 100 | } 101 | -------------------------------------------------------------------------------- /ssl/real-dataset/models/mlp_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | import torch 4 | from copy import deepcopy 5 | 6 | class CustomBN(nn.Module): 7 | def __init__(self, options): 8 | super(CustomBN, self).__init__() 9 | # normal, detach, omit 10 | self.mean = options["mean"] 11 | 12 | # normal, detach, omit 13 | self.std = options["std"] 14 | 15 | def forward(self, x): 16 | if self.mean == "detach": 17 | x = x - x.mean(0).detach() 18 | elif self.mean == "normal": 19 | x = x - x.mean(0) 20 | elif self.mean == "omit": 21 | pass 22 | else: 23 | raise NotImplementedError(f"The mean normalization {self.mean} is not implemented!") 24 | 25 | if self.std == "detach": 26 | x = x / (x.var(0).detach() + 1e-5).sqrt() 27 | elif self.std == "normal": 28 | x = x / (x.var(0) + 1e-5).sqrt() 29 | elif self.std == "omit": 30 | pass 31 | else: 32 | raise NotImplementedError(f"The std normalization {self.std} is not implemented!") 33 | 34 | return x 35 | 36 | class 
MLPHead(nn.Module): 37 | def __init__(self, in_channels, mlp_hidden_size, projection_size, options=None): 38 | super(MLPHead, self).__init__() 39 | if options is None: 40 | options = dict(normalization="bn", has_bias=True, has_bn_affine=False, has_relu=True, additional_bn_at_input=False, custom_nz=None) 41 | 42 | assert options["custom_nz"] == "grad_act_zero" or options["custom_nz"] is None 43 | 44 | bn_size = in_channels if mlp_hidden_size is None else mlp_hidden_size 45 | l = self._create_normalization(bn_size, options) 46 | 47 | if options["additional_bn_at_input"]: 48 | l_before = nn.BatchNorm1d(in_channels, affine=False) 49 | else: 50 | l_before = None 51 | 52 | # assert "OriginalBN" in option 53 | layers = [] 54 | 55 | if l_before is not None: 56 | layers.append(l_before) 57 | 58 | if mlp_hidden_size is not None: 59 | layers.append(nn.Linear(in_channels, mlp_hidden_size, bias=options["has_bias"])) 60 | if l is not None: 61 | layers.append(l) 62 | if options["has_relu"]: 63 | layers.append(nn.ReLU(inplace=True)) 64 | else: 65 | if l is not None: 66 | layers.append(l) 67 | 68 | layers.append(nn.Linear(bn_size, projection_size, bias=options["has_bias"])) 69 | self.layers = nn.ModuleList(layers) 70 | self.gradW = [ None for _ in self.layers ] 71 | self.masks = [ None for _ in self.layers ] 72 | self.prods = [ list() for _ in self.layers ] 73 | self.custom_nz = options["custom_nz"] 74 | self.compute_adj_grad = True 75 | 76 | def _create_normalization(self, size, options): 77 | # nn.BatchNorm1d(mlp_hidden_size), 78 | method = options["normalization"] 79 | if method == "bn": 80 | l = nn.BatchNorm1d(size, affine=options["has_bn_affine"]) 81 | elif method == "custom_bn": 82 | l = CustomBN(options["custom_bn"]) 83 | elif method == "no_normalization": 84 | l = None 85 | else: 86 | raise NotImplementedError(f"The normalization {method} is not implemented yet!") 87 | return l 88 | 89 | def _compute_reg(self, g, f, x): 90 | # g: n_batch x n_output 91 | # f: n_batch x 
n_output 92 | # x: n_batch x n_input 93 | # return inner_prod(g[i,:], f[i,:]) * outer_prod(g[i,:], x[i,:]) 94 | with torch.no_grad(): 95 | prod = (g * f).sum(dim=1, keepdim=True) 96 | # n_batch x n_output x n_input 97 | return prod, torch.bmm(g.unsqueeze(2), x.unsqueeze(1)) * prod.unsqueeze(2) 98 | 99 | def _grad_hook(self, g, f, x, i): 100 | # extra weight update. 101 | # gradW = [n_output x n_input] 102 | # Generate a random mask on the same device as g (works on both CPU and GPU). 103 | mask = (torch.rand(g.size(1)) > 0.5).to(device=g.device) 104 | prod, self.gradW[i] = self._compute_reg(g[:,mask], f[:,mask], x) 105 | self.gradW[i] = self.gradW[i].mean(dim=0) 106 | self.masks[i] = mask 107 | self.prods[i].append((g * f).norm().item() / math.sqrt(g.size(0) * g.size(1))) 108 | return None 109 | 110 | def forward(self, x): 111 | for i, l in enumerate(self.layers): 112 | f = l(x) 113 | if self.compute_adj_grad and isinstance(l, nn.Linear) and self.custom_nz == "grad_act_zero": 114 | # Add a backward hook to accumulate gradient for weight normalization. 115 | # We want E[g f] = 0. 116 | # g: n_batch x n_output 117 | # f: n_batch x n_output 118 | # If we want to make it per sample, we would want to achieve g[i,:] . f[i,:] = 0 119 | # or x[i,:] W g[i,:]' = 0 120 | f.register_hook(lambda g, f=f, x=x, i=i: self._grad_hook(g, f, x, i)) 121 | # For the next layer. 122 | x = f 123 | return x 124 | 125 | def set_adj_grad(self, compute_adj_grad): 126 | self.compute_adj_grad = compute_adj_grad 127 | 128 | def adjust_grad(self): 129 | with torch.no_grad(): 130 | for l, mask, gW in zip(self.layers, self.masks, self.gradW): 131 | if gW is not None: 132 | # mask = Output mask. 133 | # we don't want to add an additional weight decay, so the direction should be orthogonal to l.weight.
134 | w = l.weight[mask,:] 135 | coeff = (gW * w).sum() / w.pow(2).sum() 136 | gW -= coeff * w 137 | l.weight.grad[mask,:] += 100 * gW 138 | 139 | self.gradW = [ None for _ in self.layers ] 140 | 141 | def normalize(self): 142 | if self.custom_nz == "grad_act_zero": 143 | # Normalize all linear weight. 144 | with torch.no_grad(): 145 | for l in self.layers: 146 | if isinstance(l, nn.Linear): 147 | l.weight /= l.weight.norm() 148 | 149 | def get_stats(self): 150 | if self.custom_nz == "grad_act_zero": 151 | s = "grad_act_zero: \n" 152 | for i, (p, l) in enumerate(zip(self.prods, self.layers)): 153 | if len(p) > 0: 154 | s += f"[{i}]: norm: {l.weight.norm()}, mean(f*g): start: {p[0]}, end: {p[-1]}\n" 155 | p.clear() 156 | 157 | return s 158 | return None 159 | 160 | # class MLPHead(nn.Module): 161 | # def __init__(self, in_channels, mlp_hidden_size, projection_size, option): 162 | # super(MLPHead, self).__init__() 163 | # self.linear1 = nn.Linear(in_channels, mlp_hidden_size) 164 | # self.relu = nn.ReLU(inplace=True) 165 | # self.linear2 = nn.Linear(mlp_hidden_size, projection_size) 166 | # self.option = option 167 | # 168 | # def forward(self, x): 169 | # x = self.linear1(x) 170 | # 171 | # if "ZeroMeanDetach" in self.option: 172 | # x = x - x.mean(0).detach() 173 | # elif "ZeroMean" in self.option: 174 | # x = x - x.mean(0) 175 | # 176 | # if "StdDetach" in self.option: 177 | # x = x / (x.var(0).detach() + 1e-5).sqrt() 178 | # elif "Std" in self.option: 179 | # x = x / (x.var(0) + 1e-5).sqrt() 180 | # 181 | # x = self.relu(x) 182 | # x = self.linear2(x) 183 | # return x 184 | 185 | # class MLPHead(nn.Module): 186 | # def __init__(self, in_channels, mlp_hidden_size, projection_size, momentum=0.9): 187 | # super(MLPHead, self).__init__() 188 | # self.linear1 = nn.Linear(in_channels, mlp_hidden_size) 189 | # self.relu = nn.ReLU(inplace=True) 190 | # self.linear2 = nn.Linear(mlp_hidden_size, projection_size) 191 | # self.momentum = momentum 192 | # self.running_mean = 
None 193 | # 194 | # def forward(self, x): 195 | # x = self.linear1(x) 196 | # if self.running_mean: 197 | # self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * torch.mean(x).detach() 198 | # else: 199 | # self.running_mean = torch.mean(x).detach() 200 | # x = x - self.running_mean 201 | # x = self.relu(x) 202 | # x = self.linear2(x) 203 | # return x 204 | -------------------------------------------------------------------------------- /ssl/real-dataset/paths.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/luckmatters/b777959f48731e52ee5b10af4e7c4f93ba3f0180/ssl/real-dataset/paths.txt -------------------------------------------------------------------------------- /ssl/real-dataset/relu_2layer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import sys 4 | import hydra 5 | import os 6 | import torch.nn as nn 7 | from collections import Counter, defaultdict, deque 8 | from torch.distributions.categorical import Categorical 9 | 10 | from copy import deepcopy 11 | 12 | import torch.nn.functional as F 13 | import glob 14 | import common_utils 15 | 16 | import logging 17 | log = logging.getLogger(__file__) 18 | 19 | class Normalizer: 20 | def __init__(self, layer, name=None): 21 | self.layer = layer 22 | self.get_norms() 23 | log.info(f"[{name}] norm = {self.norm}, row_norms = {self.row_norms}") 24 | 25 | def get_norms(self): 26 | with torch.no_grad(): 27 | self.norm = self.layer.weight.norm() 28 | self.row_norms = self.layer.weight.norm(dim=1) 29 | 30 | def normalize_layer(self): 31 | with torch.no_grad(): 32 | norm = self.layer.weight.norm() 33 | self.layer.weight[:] *= self.norm / norm 34 | if self.layer.bias is not None: 35 | self.layer.bias[:] *= self.norm / norm 36 | 37 | def normalize_layer_filter(self): 38 | with torch.no_grad(): 39 | row_norms = self.layer.weight.norm(dim=1) 
40 | ratio = self.row_norms / row_norms 41 | self.layer.weight *= ratio[:,None] 42 | if self.layer.bias is not None: 43 | self.layer.bias *= ratio 44 | 45 | 46 | class Model(nn.Module): 47 | def __init__(self, d, d_hidden, d2, activation="relu", w1_bias=False, use_bn=True): 48 | super(Model, self).__init__() 49 | # d = input dimension, d_hidden = number of hidden filters, d2 = output dimension. 50 | self.w1 = nn.Linear(d, d_hidden, bias=w1_bias) 51 | 52 | if activation == "relu": 53 | self.activation = nn.ReLU() 54 | elif activation == "linear": 55 | self.activation = lambda x : x 56 | else: 57 | raise RuntimeError(f"Unknown activation {activation}") 58 | 59 | self.w2 = nn.Linear(d_hidden, d2, bias=False) 60 | 61 | if use_bn: 62 | self.bn = nn.BatchNorm1d(d_hidden) 63 | else: 64 | self.bn = None 65 | 66 | self.normalizer_w1 = Normalizer(self.w1, name="W1") 67 | self.normalizer_w2 = Normalizer(self.w2, name="W2") 68 | 69 | def forward(self, x): 70 | y = self.w1(x) 71 | y = self.activation(y) 72 | 73 | if self.bn is not None: 74 | y = self.bn(y) 75 | 76 | return self.w2(y) 77 | 78 | def per_layer_normalize(self): 79 | # Rescale each layer back to its initial Frobenius norm. 80 | self.normalizer_w1.normalize_layer() 81 | self.normalizer_w2.normalize_layer() 82 | 83 | def per_filter_normalize(self): 84 | self.normalizer_w1.normalize_layer() 85 | self.normalizer_w2.normalize_layer_filter() 86 | 87 | def w1_clamp_all_negatives(self): 88 | with torch.no_grad(): 89 | self.w1.weight[self.w1.weight < 0] = 0 90 | 91 | def pairwise_dist(x): 92 | # x: [N, d] 93 | # ret: [N, N], squared Euclidean distance between rows of x 94 | norms = x.pow(2).sum(dim=1) 95 | return norms[:,None] + norms[None,:] - 2 * (x @ x.t()) 96 | 97 | def check_result(subfolder): 98 | model_files = glob.glob(os.path.join(subfolder, "model-*.pth")) 99 | # Sort checkpoints by modification time (oldest first).
100 | model_files = [ (os.path.getmtime(f), f) for f in model_files ] 101 | all_model_files = sorted(model_files, key=lambda x: x[0]) 102 | 103 | config = common_utils.MultiRunUtil.load_full_cfg(subfolder) 104 | concentration = [] 105 | coverage = [] 106 | 107 | for _, model_file in all_model_files: 108 | model = torch.load(model_file) 109 | w = model["w1.weight"].detach() 110 | # Check how much things are scattered around. 111 | w[w<0] = 0 112 | w = w / (w.norm(dim=1,keepdim=True) + 1e-8) 113 | 114 | concentration.append( w.max(dim=1)[0].mean().item()) 115 | coverage.append(w.max(dim=0)[0].mean().item()) 116 | 117 | res = { 118 | "concentration": concentration, 119 | "coverage": coverage 120 | } 121 | 122 | return res 123 | 124 | class Generator: 125 | def __init__(self, distri): 126 | self.sampler = Categorical(torch.ones(distri.d) / distri.d) 127 | self.distri = distri 128 | self.mags = torch.linspace(distri.mag_start, distri.mag_end, steps=distri.d) 129 | 130 | def sample(self, batchsize, seed=None): 131 | if seed is not None: 132 | torch.manual_seed(seed) 133 | 134 | zs = self.sampler.sample((batchsize,)) 135 | mags = torch.ones(batchsize) * self.mags[zs] 136 | x1mags = mags + torch.randn(batchsize) * self.distri.std_aug 137 | x2mags = mags + torch.randn(batchsize) * self.distri.std_aug 138 | 139 | one_hot = torch.nn.functional.one_hot(zs, num_classes=len(self.mags)) 140 | 141 | x1 = one_hot * x1mags[:,None] 142 | x2 = one_hot * x2mags[:,None] 143 | 144 | return x1, x2, zs, zs 145 | 146 | 147 | _attr_multirun = { 148 | "check_result": { 149 | "default": check_result, 150 | }, 151 | "common_options" : dict(topk_mean=1, topk=10, descending=True), 152 | "specific_options": dict(concentration={}, coverage={}), 153 | "default_metrics": [ "concentration", "coverage" ] 154 | } 155 | 156 | def get_y(x1, x2): 157 | N = x1.size(0) 158 | y_neg = torch.kron(x1, torch.ones(N, 1)) - torch.cat([x1] * N, dim=0) 159 | y_pos = x1 - x2 160 | return y_neg, y_pos 161 | 162 | def
compute_contrastive_covariance(f1, f2, x1, x2, T, norm_type): 163 | N = f1.size(0) 164 | d = x1.size(1) 165 | label = list(range(N)) 166 | eta = 0 167 | 168 | if norm_type in ["regular", "no_proj"]: 169 | norm_f1 = f1.norm(dim=1) + 1e-8 170 | f1 = f1 / norm_f1[:,None] 171 | 172 | norm_f2 = f2.norm(dim=1) + 1e-8 173 | f2 = f2 / norm_f2[:,None] 174 | 175 | # Now we compute the alphas 176 | inner_prod = f1 @ f1.t() 177 | inner_prod_pos = (f1*f2).sum(dim=1) 178 | # Replace the diagonal with the inner product between f1 and f2 179 | inner_prod[label, label] = eta * inner_prod_pos 180 | 181 | # Avoid inf in exp. 182 | inner_prod_shifted = inner_prod - inner_prod.max(dim=1)[0][:,None] 183 | A_no_norm = (inner_prod_shifted/T).exp() 184 | A = A_no_norm / (A_no_norm.sum(dim=1, keepdim=True) + 1e-8) 185 | # keepdim=True normalizes each row, so A[i, :] is a softmax over j. 186 | B = 1 - A.diag() 187 | A[label, label] = 0 188 | 189 | # Then compute the matrix. 190 | if norm_type == "no_l2": 191 | y_neg, y_pos = get_y(x1, x2) 192 | C_inter = y_neg.t() @ (A.view(-1)[:,None] * y_neg) 193 | C_intra = y_pos.t() @ (B[:,None] * y_pos) 194 | 195 | elif norm_type == "no_proj": 196 | x1_normalized = x1 / norm_f1[:,None] 197 | x2_normalized = x2 / norm_f2[:,None] 198 | 199 | y_neg, y_pos = get_y(x1_normalized, x2_normalized) 200 | C_inter = y_neg.t() @ (A.view(-1)[:,None] * y_neg) 201 | C_intra = y_pos.t() @ (B[:,None] * y_pos) 202 | 203 | elif norm_type == "regular": 204 | C_inter = torch.zeros(d, d) 205 | C_intra = torch.zeros(d, d) 206 | x1_normalized = x1 / norm_f1[:,None] 207 | x2_normalized = x2 / norm_f2[:,None] 208 | 209 | outers_neg = [] 210 | outers_pos = [] 211 | for i in range(N): 212 | outers_neg.append(torch.outer(x1_normalized[i,:], x1_normalized[i,:])) 213 | outers_pos.append(torch.outer(x2_normalized[i,:], x2_normalized[i,:])) 214 | 215 | for i in range(N): 216 | for j in range(i + 1, N): 217 | outer_ij = torch.outer(x1_normalized[i,:], x1_normalized[j,:]) 218 | term = (outers_neg[i] + outers_neg[j]) * inner_prod[i,j] - (outer_ij + outer_ij.t()) 219 |
C_inter += (A[i,j] + A[j,i]) * term 220 | # if A[i,j] > 1e-3: 221 | # import pdb 222 | # pdb.set_trace() 223 | 224 | for i in range(N): 225 | outer = torch.outer(x1_normalized[i,:], x2_normalized[i,:]) 226 | term = (outers_neg[i] + outers_pos[i]) * inner_prod_pos[i] - (outer + outer.t()) 227 | C_intra += B[i] * term 228 | else: 229 | raise RuntimeError(f"Unknown norm_type = {norm_type}") 230 | 231 | return C_inter, C_intra, A, B 232 | 233 | 234 | @hydra.main(config_path="config", config_name="relu_2layer.yaml") 235 | def main(args): 236 | log.info(common_utils.print_info(args)) 237 | common_utils.set_all_seeds(args.seed) 238 | 239 | gen = Generator(args.distri) 240 | 241 | model = Model(args.distri.d, args.d_hidden, args.d_output, w1_bias=args.w1_bias, activation=args.activation, use_bn=args.use_bn) 242 | model.w1_clamp_all_negatives() 243 | 244 | if args.loss_type == "infoNCE": 245 | loss_func = nn.CrossEntropyLoss() 246 | elif args.loss_type == "quadratic": 247 | # Quadratic loss 248 | loss_func = lambda x, label: - (1 + 1 / x.size(0)) * x[torch.LongTensor(range(x.size(0))),label].mean() + x.mean() 249 | else: 250 | raise RuntimeError(f"Unknown loss_type = {args.loss_type}") 251 | 252 | label = torch.LongTensor(range(args.batchsize)) 253 | 254 | optimizer = torch.optim.SGD(model.parameters(), lr=args.opt.lr, momentum=args.opt.momentum, weight_decay=args.opt.wd) 255 | 256 | if args.l2_type == "regular": 257 | l2_reg = lambda x: F.normalize(x, dim=1) 258 | elif args.l2_type == "no_l2": 259 | l2_reg = lambda x: x 260 | else: 261 | raise RuntimeError(f"Unknown l2_type = {args.l2_type}") 262 | 263 | model_q = deque([deepcopy(model)]) 264 | 265 | for t in range(args.niter): 266 | optimizer.zero_grad() 267 | 268 | x1, x2, zs, _ = gen.sample(args.batchsize, seed = t % 500) 269 | # x1 += torch.randn(args.batchsize, x1.size(1)) * 0.001 270 | # x2 += torch.randn(args.batchsize, x2.size(1)) * 0.001 271 | 272 | f1 = model(x1) 273 | f2 = model(x2) 274 | 275 | # #batch x output_dim
276 | # Then we compute the infoNCE. 277 | z1 = l2_reg(f1) 278 | z2 = l2_reg(f2) 279 | 280 | if args.similarity == "dotprod": 281 | # nbatch x nbatch, minus pairwise distance, or inner_prod matrix. 282 | M = z1 @ z1.t() 283 | M[label,label] = (z1 * z2).sum(dim=1) 284 | elif args.similarity == "negdist": 285 | M = -pairwise_dist(z1) 286 | aug_dist = (z1 - z2).pow(2).sum(1) 287 | M[label, label] = -aug_dist 288 | # 1/2 distance matches with innerprod 289 | M = M / 2 290 | else: 291 | raise RuntimeError(f"Unknown similarity = {args.similarity}") 292 | 293 | loss = loss_func(M / args.T, label) 294 | if torch.any(loss.isnan()): 295 | log.info("Encounter NaN!") 296 | model = model_q.popleft() 297 | break 298 | 299 | if t % 500 == 0: 300 | log.info(f"[{t}] {loss.item()}") 301 | model_name = f"model-{t}.pth" 302 | log.info(f"Save to {model_name}") 303 | torch.save(model.state_dict(), model_name) 304 | 305 | with torch.no_grad(): 306 | C_inter, C_intra, A, B = compute_contrastive_covariance(f1, f2, x1, x2, args.T, args.l2_type) 307 | log.info(f"diag of C_inter: {C_inter.diag()}") 308 | log.info(f"diag of C_intra: {C_intra.diag()}") 309 | torch.save(dict(C_inter=C_inter, C_intra=C_intra, A=A, B=B, x1=x1, x2=x2, f1=f1.detach(), f2=f2.detach(), zs=zs), f"data-{t}.pth") 310 | 311 | loss.backward() 312 | 313 | optimizer.step() 314 | 315 | # normalization 316 | #if args.per_layer_normalize: 317 | if args.normalization == "perlayer": 318 | model.per_layer_normalize() 319 | elif args.normalization == "perfilter": 320 | model.per_filter_normalize() 321 | 322 | model.w1_clamp_all_negatives() 323 | 324 | model_q.append(deepcopy(model)) 325 | if len(model_q) >= 3: 326 | model_q.popleft() 327 | 328 | log.info(f"Final loss = {loss.item()}") 329 | log.info(f"Save to model-final.pth") 330 | torch.save(model.state_dict(), "model-final.pth") 331 | 332 | log.info(check_result(os.path.abspath("./"))) 333 | 334 | 335 | if __name__ == '__main__': 336 | main() 337 | 
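The `no_l2` branch of `compute_contrastive_covariance` in `relu_2layer.py` relies on two index conventions that are easy to get wrong: each row of `A` should be a softmax over `j` (the row sum must be kept as a column so it broadcasts per row), and row `i * N + j` of `y_neg` must equal `x1[i] - x1[j]` so that it lines up with `A.reshape(-1)[i * N + j] == A[i, j]`. The sketch below checks both conventions. It is a minimal sketch under stated assumptions, not the authors' implementation: it uses NumPy in place of torch to stay dependency-light, builds `y_neg` with broadcasting instead of `kron`, keeps `eta = 0` as in the original, and the helper names `get_y_broadcast`, `contrastive_weights` and `contrastive_covariance_no_l2` are hypothetical.

```python
import numpy as np


def get_y_broadcast(x1, x2):
    # Hypothetical helper. Row i * N + j of y_neg is x1[i] - x1[j], the same
    # ordering as kron(x1, ones(N, 1)) - cat([x1] * N) used by get_y.
    N, d = x1.shape
    y_neg = (x1[:, None, :] - x1[None, :, :]).reshape(N * N, d)
    y_pos = x1 - x2
    return y_neg, y_pos


def contrastive_weights(f1, f2, T, eta=0.0):
    # Hypothetical helper: row-wise softmax weights A (diagonal zeroed)
    # and B[i] = 1 - A[i, i] taken before the diagonal is zeroed.
    N = f1.shape[0]
    idx = np.arange(N)
    inner_prod = f1 @ f1.T
    # The diagonal holds eta * <f1[i], f2[i]>; eta = 0 matches the original code.
    inner_prod[idx, idx] = eta * (f1 * f2).sum(axis=1)
    # Subtract each row's max for numerical stability (softmax is shift-invariant).
    shifted = inner_prod - inner_prod.max(axis=1, keepdims=True)
    A = np.exp(shifted / T)
    # keepdims=True divides each row by its own sum, so A[i, :] sums to 1.
    A = A / A.sum(axis=1, keepdims=True)
    B = 1.0 - np.diag(A)
    A[idx, idx] = 0.0
    return A, B


def contrastive_covariance_no_l2(x1, x2, f1, f2, T):
    # Hypothetical helper mirroring the no_l2 branch: weighted second moments
    # of the negative-pair and positive-pair input differences.
    A, B = contrastive_weights(f1, f2, T)
    y_neg, y_pos = get_y_broadcast(x1, x2)
    # A.reshape(-1)[i * N + j] == A[i, j], aligned with row i * N + j of y_neg.
    C_inter = y_neg.T @ (A.reshape(-1)[:, None] * y_neg)
    C_intra = y_pos.T @ (B[:, None] * y_pos)
    return C_inter, C_intra, A, B
```

For example, calling `contrastive_covariance_no_l2(x1, x2, x1, x2, T=0.5)` on a small random batch (using the inputs themselves as features) yields symmetric `C_inter` and `C_intra`; because every weight in `A` and `B` is non-negative, both matrices are also positive semi-definite.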
-------------------------------------------------------------------------------- /ssl/real-dataset/requirement.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | alabaster==0.7.12 3 | antlr4-python3-runtime==4.8 4 | apipkg==1.5 5 | appdirs==1.4.4 6 | ase==3.21.1 7 | astroid==2.5.6 8 | attrs==20.2.0 9 | autopep8==1.5.7 10 | bayesmark==0.0.8 11 | bertviz==1.0.0 12 | blis==0.7.4 13 | boto3==1.17.22 14 | botocore==1.20.22 15 | cachetools==4.2.2 16 | catalogue==2.0.1 17 | certifi==2020.12.5 18 | cfgv==3.2.0 19 | chardet==3.0.4 20 | click==7.1.2 21 | cloudpickle==1.6.0 22 | cma==3.0.3 23 | colorama==0.4.4 24 | coverage==5.5 25 | cycler==0.10.0 26 | cymem==2.0.5 27 | decorator==4.3.0 28 | dill==0.3.3 29 | dimod==0.9.14 30 | distlib==0.3.1 31 | docutils==0.17.1 32 | entrypoints==0.3 33 | execnet==1.8.0 34 | fasteners==0.15 35 | filelock==3.0.12 36 | future==0.18.2 37 | gitdb==4.0.7 38 | GitPython==3.1.14 39 | GPUtil==1.4.0 40 | grpcio==1.33.2 41 | gym==0.18.0 42 | hydra-core==1.0.6 43 | identify==2.2.4 44 | idna==2.10 45 | imagesize==1.2.0 46 | iniconfig==1.0.1 47 | isodate==0.6.0 48 | isort==4.3.21 49 | jmespath==0.10.0 50 | joblib==1.0.1 51 | kiwisolver==1.3.1 52 | lazy-object-proxy==1.6.0 53 | livereload==2.6.3 54 | matplotlib==3.3.4 55 | mccabe==0.6.1 56 | mistune==0.8.4 57 | mkl-fft==1.2.0 58 | mkl-random==1.1.1 59 | mkl-service==2.3.0 60 | monotonic==1.5 61 | murmurhash==1.0.5 62 | networkx==2.5 63 | nodeenv==1.6.0 64 | numba==0.53.0 65 | numpy==1.19.3 66 | oauthlib==3.1.1 67 | olefile==0.46 68 | omegaconf==2.0.6 69 | or-gym==0.2.0 70 | packaging==20.4 71 | pandas==1.1.4 72 | pathspec==0.8.1 73 | pathvalidate==2.4.1 74 | pathy==0.4.0 75 | Pillow==7.2.0 76 | pluggy==0.13.1 77 | pre-commit==2.12.1 78 | preshed==3.0.5 79 | protobuf==3.13.0 80 | psutil==5.4.5 81 | py==1.9.0 82 | py-cpuinfo==8.0.0 83 | pyasn1==0.4.8 84 | pyasn1-modules==0.2.8 85 | pycodestyle==2.7.0 86 | pydantic==1.8.1 87 | pyglet==1.5.0 88 | 
pylint==2.8.2 89 | pytz==2018.3 90 | PyYAML==5.4.1 91 | pyzmq==20.0.0 92 | rdflib==5.0.0 93 | regex==2020.11.13 94 | requests==2.24.0 95 | requests-oauthlib==1.3.0 96 | rsa==4.7.2 97 | s3transfer==0.3.4 98 | sacremoses==0.0.43 99 | scikit-learn==0.24.1 100 | scipy==1.6.0 101 | sentencepiece==0.1.95 102 | sip==4.19.13 103 | sklearn==0.0 104 | smart-open==3.0.0 105 | smmap==4.0.0 106 | snowballstemmer==2.1.0 107 | spacy==3.0.3 108 | spacy-legacy==3.0.1 109 | srsly==2.4.0 110 | tabulate==0.8.7 111 | tensorboard==2.5.0 112 | tensorboard-data-server==0.6.1 113 | tensorboard-plugin-wit==1.8.0 114 | tensorboardX==2.2 115 | termcolor==1.1.0 116 | terminado==0.9.2 117 | thinc==8.0.1 118 | threadpoolctl==2.1.0 119 | tokenizers==0.10.1 120 | toml==0.10.2 121 | torch==1.7.1+cu101 122 | torch-cluster==1.5.9 123 | torch-geometric==1.6.3 124 | torch-scatter==2.0.6 125 | torch-sparse==0.6.9 126 | torch-spline-conv==1.2.1 127 | torchaudio==0.7.2 128 | torchtext==0.8.1 129 | torchvision==0.8.2+cu101 130 | tqdm==4.38.0 131 | typed-ast==1.4.3 132 | typer==0.3.2 133 | urllib3==1.25.11 134 | virtualenv==20.4.6 135 | wasabi==0.8.2 136 | webencodings==0.5.1 137 | wrapt==1.12.1 138 | xarray==0.17.0 139 | zipp==3.2.0 140 | -------------------------------------------------------------------------------- /ssl/real-dataset/self_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import random 6 | import common_utils 7 | import hydra 8 | 9 | import os 10 | 11 | import logging 12 | log = logging.getLogger(__file__) 13 | 14 | class Model(nn.Module): 15 | def __init__(self, M, L, d): 16 | super(Model, self).__init__() 17 | self.M = M 18 | self.embedding = nn.Embedding(M + 1, d, max_norm=1) 19 | self.mask_token = M 20 | 21 | self.positional_embedding = nn.Embedding(L, d, max_norm=1) 22 | self.d = d 23 | self.L = L 24 | 25 | def forward(self, x, mask_idx): 26 
| # x is size (bs, L) of type LongTensor, L is the length of the seq 27 | # mask_idx is size (bs,) of type LongTensor; for each sample, the corresponding token is masked and must be reconstructed. 28 | x_input = x.clone() 29 | x_input.scatter_(1, mask_idx.unsqueeze(1), self.mask_token) 30 | 31 | locs = torch.arange(self.L).to(x.device) 32 | tokens = torch.arange(self.M).to(x.device) 33 | 34 | # of size [bs, L, d] 35 | content_input = self.embedding(x_input) 36 | 37 | # Do self-attention (bs, L, L) 38 | # No Wk and Wq for now 39 | attentions = torch.bmm(content_input, content_input.permute(0, 2, 1)) 40 | 41 | # [L, d] 42 | pos_input = self.positional_embedding(locs) 43 | attentions = attentions.detach() + (pos_input @ pos_input.t()).unsqueeze(0) 44 | attentions = F.softmax(attentions / math.sqrt(2*self.d), dim=2) 45 | 46 | # output of size (bs, L, d) 47 | output = torch.bmm(attentions, content_input) 48 | 49 | # [bs, d] 50 | sel_output = output.gather(1, mask_idx.unsqueeze(1).unsqueeze(2).expand(-1, -1, self.d)).squeeze(1) 51 | 52 | # Then we compute the inner product with all embeddings. 53 | # [bs, M] 54 | inner_prod = sel_output @ self.embedding(tokens).t() # / math.sqrt(self.d) 55 | target = x.gather(1, mask_idx.unsqueeze(1)).squeeze(1) 56 | loss = F.nll_loss(F.log_softmax(inner_prod, dim=1), target) 57 | 58 | ''' 59 | # [Update] we compute the inner product with all embeddings within the sequence.
60 | # [bs, L] 61 | inner_prod = torch.bmm(self.embedding(x), sel_output.unsqueeze(2)).squeeze(2) # / math.sqrt(self.d) 62 | loss = F.nll_loss(F.log_softmax(inner_prod, dim=1), mask_idx) 63 | ''' 64 | 65 | # gt_output = self.embedding(target) 66 | 67 | return loss, sel_output 68 | 69 | class Dataset: 70 | def __init__(self, M, L, seg_len): 71 | # Number of tokens 72 | self.M = M 73 | 74 | # Generate a bunch of random classes 75 | self.nclass = 2 76 | self.classes = [] 77 | seg = M // self.nclass 78 | for i in range(self.nclass): 79 | self.classes.append(list(range(i * seg, (i + 1) * seg))) 80 | 81 | self.seg_len = seg_len 82 | self.L = L 83 | 84 | def generate(self, batchsize): 85 | x = torch.LongTensor(batchsize, self.L) 86 | for i in range(batchsize): 87 | start = 0 88 | while start < self.L: 89 | # sample seg length. 90 | this_seg_len = random.randint(1, min(self.seg_len, self.L - start)) 91 | 92 | # pick a class 93 | class_id = random.randint(0, self.nclass - 1) 94 | # random choose tokens from the class. 
95 | x[i, start:start+this_seg_len] = torch.LongTensor(random.choices(self.classes[class_id], k=this_seg_len)) 96 | # j*self.seg_len:(j+1)*self.seg_len] 97 | start += this_seg_len 98 | 99 | return x 100 | 101 | @hydra.main(config_path="config", config_name="sa.yaml") 102 | def main(args): 103 | log.info(common_utils.print_info(args)) 104 | common_utils.set_all_seeds(args.seed) 105 | 106 | dataset = Dataset(args.M, args.L, seg_len=3) 107 | model = Model(args.M, args.L, args.d) 108 | 109 | optimizer = torch.optim.SGD(model.parameters(), lr=args.opt.lr, momentum=args.opt.momentum, weight_decay=args.opt.wd) 110 | 111 | for t in range(args.niter): 112 | optimizer.zero_grad() 113 | 114 | x = dataset.generate(args.batchsize) 115 | # Randomly mask some entry 116 | mask = torch.LongTensor(random.choices(list(range(x.size(1))), k=args.batchsize)) 117 | 118 | loss, _ = model(x, mask) 119 | if t % 100 == 0: 120 | log.info(f"[{t}] loss: {loss.detach().cpu().item()}") 121 | 122 | loss.backward() 123 | optimizer.step() 124 | 125 | #import pdb 126 | #pdb.set_trace() 127 | 128 | log.info("Embedding:") 129 | log.info(model.embedding.weight) 130 | # log.info(model.embedding.weight @ model.embedding.weight.t()) 131 | 132 | log.info("Positional Embedding:") 133 | log.info(model.positional_embedding.weight) 134 | # log.info(model.positional_embedding.weight @ model.positional_embedding.weight.t()) 135 | 136 | torch.save(model.state_dict(), "final.pth") 137 | 138 | log.info(os.getcwd()) 139 | 140 | if __name__ == '__main__': 141 | main() -------------------------------------------------------------------------------- /ssl/real-dataset/self_attention_linear_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import random 6 | import common_utils 7 | import hydra 8 | 9 | import os 10 | 11 | import logging 12 | log = logging.getLogger(__file__) 13 | 14 | 
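# Hedged sketch (hypothetical helper, not called by main() below): the
# single-head attention that Model.forward computes when use_WkWq is False.
# Queries and keys are both the token embeddings E of shape (bs, L, d), the
# values V have shape (bs, L, V_dim), and the logits are scaled by sqrt(2 * d)
# exactly as in Model.forward. main() then sums the (bs, L, V_dim) output over
# positions to get one logit per sequence.
import math
import torch
import torch.nn.functional as F

def attention_sketch(E, V):
    d = E.size(-1)
    # Pairwise token similarities, shape (bs, L, L).
    scores = torch.bmm(E, E.transpose(1, 2)) / math.sqrt(2 * d)
    # Normalize over the key dimension, so every row of attn sums to 1.
    attn = F.softmax(scores, dim=2)
    # Attended values (bs, L, V_dim) and the attention matrix (bs, L, L).
    return torch.bmm(attn, V), attn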
15 | class LinearModel(nn.Module): 16 | def __init__(self, L): 17 | super(LinearModel, self).__init__() 18 | self.model = nn.Linear(L, 1) 19 | self.L = L 20 | 21 | def forward(self, x): 22 | # x is size (bs, L) of type LongTensor, L is the length of the seq 23 | xx = (x != self.L).float() 24 | return self.model(xx).squeeze() 25 | 26 | 27 | class Model(nn.Module): 28 | def __init__(self, M, L, d): 29 | super(Model, self).__init__() 30 | self.M = M 31 | self.embed = nn.Embedding(M, d, max_norm=1) 32 | 33 | self.use_WkWq = False 34 | 35 | if self.use_WkWq: 36 | self.Wk = nn.Linear(d, 2*d, bias=False) 37 | self.Wq = nn.Linear(d, 2*d, bias=False) 38 | 39 | self.V_dim = 1 40 | self.V = nn.Embedding(M, self.V_dim) 41 | self.d = d 42 | self.L = L 43 | 44 | def forward(self, x): 45 | # x is size (bs, L) of type LongTensor, L is the length of the seq 46 | x_input = x.clone() 47 | 48 | # of size [bs, L, d] 49 | embed = self.embed(x_input) 50 | 51 | if self.use_WkWq: 52 | Q_sel = self.Wq(embed) 53 | K_sel = self.Wk(embed) 54 | else: 55 | Q_sel = embed 56 | K_sel = embed 57 | 58 | # of size [bs, L, V_dim] 59 | V_sel = self.V(x_input) 60 | 61 | # Do self-attention (bs, L, L) 62 | # No Wk and Wq for now 63 | attentions = torch.bmm(Q_sel, K_sel.permute(0, 2, 1)) 64 | 65 | # [L, d] 66 | # locs = torch.arange(self.L).to(x.device) 67 | # pos_input = self.positional_embedding(locs) 68 | # attentions = attentions.detach() + (pos_input @ pos_input.t()).unsqueeze(0) 69 | attentions = F.softmax(attentions / math.sqrt(2*self.d), dim=2) 70 | 71 | # attention size = [bs, L, L] 72 | # V_sel size = [bs, L, V_dim] 73 | # output size = [bs, L, V_dim] 74 | output = torch.bmm(attentions, V_sel) 75 | return output 76 | 77 | class Dataset: 78 | def __init__(self, L): 79 | # M = number of tokens 80 | # L = length of the seq 81 | self.L = L 82 | self.M = L + 1 83 | 84 | # Define a binary distribution and correlation between the label and the token 85 | self.probs_pos = torch.ones(self.L) * 0.55 86 | 
self.probs_pos[0] = 0.95 87 | self.probs_pos[1] = 0.95 88 | self.probs_neg = torch.ones(self.L) * 0.5 89 | 90 | def generate(self, batchsize): 91 | x = torch.LongTensor(batchsize, self.L) 92 | label = torch.LongTensor(batchsize) 93 | for i in range(batchsize): 94 | if random.random() > 0.5: 95 | # Positive 96 | probs = self.probs_pos 97 | label[i] = 1 98 | else: 99 | # Negative. 100 | probs = self.probs_neg 101 | label[i] = 0 102 | 103 | # try each token to see whether they appear. 104 | for j in range(self.L): 105 | if random.random() < probs[j]: 106 | x[i,j] = j 107 | else: 108 | # Last token 109 | x[i,j] = self.M - 1 110 | 111 | return x, label 112 | 113 | @hydra.main(config_path="config", config_name="sa_linear.yaml") 114 | def main(args): 115 | log.info(common_utils.print_info(args)) 116 | common_utils.set_all_seeds(args.seed) 117 | 118 | dataset = Dataset(args.L) 119 | model = Model(dataset.M, dataset.L, args.d) 120 | model_linear = LinearModel(args.L) 121 | 122 | optimizer = torch.optim.SGD(model.parameters(), lr=args.opt.lr, momentum=args.opt.momentum, weight_decay=args.opt.wd) 123 | optimizer_linear = torch.optim.SGD(model_linear.parameters(), lr=args.opt.lr, momentum=args.opt.momentum, weight_decay=args.opt.wd) 124 | 125 | loss_func = torch.nn.BCELoss() 126 | 127 | for t in range(args.niter): 128 | optimizer.zero_grad() 129 | optimizer_linear.zero_grad() 130 | 131 | x, label = dataset.generate(args.batchsize) 132 | output = model(x) 133 | output_linear = model_linear(x) 134 | 135 | single_output = output.squeeze().sum(dim=1) 136 | loss = loss_func(F.sigmoid(single_output), label.float()) 137 | 138 | loss_linear = loss_func(F.sigmoid(output_linear), label.float()) 139 | 140 | if t % 100 == 0: 141 | log.info(f"[{t}] loss: {loss.detach().cpu().item()}") 142 | log.info(f"[{t}] loss_linear: {loss_linear.detach().cpu().item()}") 143 | 144 | loss.backward() 145 | optimizer.step() 146 | 147 | loss_linear.backward() 148 | optimizer_linear.step() 149 | 150 | 
#import pdb 151 | #pdb.set_trace() 152 | 153 | # log.info("Embedding K:") 154 | # log.info(model.K.weight) 155 | 156 | # log.info("Embedding Q:") 157 | # log.info(model.Q.weight) 158 | # # log.info(model.embedding.weight @ model.embedding.weight.t()) 159 | 160 | # import pdb 161 | # pdb.set_trace() 162 | 163 | torch.save(dict(model=model.state_dict(), model_linear=model_linear.state_dict()), "final.pth") 164 | 165 | log.info(os.getcwd()) 166 | 167 | if __name__ == '__main__': 168 | main() -------------------------------------------------------------------------------- /ssl/real-dataset/simclr_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.dataloader import DataLoader 3 | from torch.utils.tensorboard import SummaryWriter 4 | import torch.nn.functional as F 5 | from copy import deepcopy 6 | from loss.nt_xent import NTXentLoss 7 | import re 8 | import os 9 | import shutil 10 | import random 11 | import sys 12 | 13 | import numpy as np 14 | 15 | import logging 16 | log = logging.getLogger(__file__) 17 | 18 | # customized l2 normalization 19 | class SpecializedL2Regularizer(torch.autograd.Function): 20 | @staticmethod 21 | def forward(ctx, input): 22 | assert len(input.size()) == 2 23 | l2_norms = input.pow(2).sum(dim=1, keepdim=True).sqrt().add(1e-8) 24 | ctx.l2_norms = l2_norms 25 | return input / l2_norms 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | grad_input = grad_output / ctx.l2_norms 30 | return grad_input 31 | 32 | class SimCLRTrainer(object): 33 | def __init__(self, log_dir, model, optimizer, evaluator, device, params): 34 | self.model = model 35 | self.params = params 36 | self.device = device 37 | self.optimizer = optimizer 38 | self.evaluator = evaluator 39 | self.writer = SummaryWriter(log_dir) 40 | self.params = params 41 | self.nt_xent_criterion = NTXentLoss(self.device, params['batch_size'], params['nce_loss']) 42 | 43 | l2_reg_type = 
self.params['l2_reg_type'] 44 | log.info(f"SimCLRTrainer: l2_reg_type: {l2_reg_type}") 45 | 46 | if l2_reg_type == "regular": 47 | self.l2_normalizer = lambda x: F.normalize(x, dim=1) 48 | elif l2_reg_type == "only_mag": 49 | self.l2_normalizer = SpecializedL2Regularizer.apply 50 | elif l2_reg_type == "no_reg": 51 | self.l2_normalizer = lambda x: x 52 | else: 53 | raise RuntimeError(f"Unknown l2_reg_type = {l2_reg_type}") 54 | 55 | def _step(self, model, xis, xjs, xs, n_iter): 56 | 57 | # get the representations and the projections 58 | zis = model(xis) # [N,C] 59 | 60 | # get the representations and the projections 61 | zjs = model(xjs) # [N,C] 62 | 63 | # normalize projection feature vectors 64 | zis = self.l2_normalizer(zis) 65 | zjs = self.l2_normalizer(zjs) 66 | 67 | if xs is not None: 68 | # Unaugmented datapoint. 69 | zs = model(xs) 70 | zs = self.l2_normalizer(zs) 71 | else: 72 | zs = None 73 | 74 | loss, loss_intra, negative_sim = self.nt_xent_criterion(zis, zjs, zs) 75 | return loss, loss_intra, zis, zjs, negative_sim 76 | 77 | def train(self, train_dataset): 78 | train_loader = DataLoader(train_dataset, batch_size=self.params["batch_size"] * (torch.cuda.device_count() if torch.cuda.is_available() else 1), 79 | num_workers=self.params["num_workers"], drop_last=True, shuffle=False) 80 | 81 | model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints') 82 | if not os.path.exists(model_checkpoints_folder): 83 | os.mkdir(model_checkpoints_folder) 84 | 85 | self.save_model(os.path.join(model_checkpoints_folder, 'model_000.pth')) 86 | 87 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=len(train_loader), eta_min=0, 88 | last_epoch=-1) 89 | 90 | margin = self.params["grad_combination_margin"] 91 | if margin is not None: 92 | matcher = re.compile(r"encoder.(\d+)") 93 | layers = dict() 94 | for name, _ in self.model.named_parameters(): 95 | # print(f"{name}: {params.size()}") 96 | m = matcher.match(name) 97 | if m is 
None: 98 | l = 10 99 | else: 100 | l = int(m.group(1)) 101 | layers[name] = l 102 | unique_entries = sorted(list(set(layers.values()))) 103 | series = np.linspace(margin, 1 - margin, len(unique_entries)) 104 | l2ratio = dict(zip(unique_entries, series)) 105 | layer2ratio = { name : l2ratio[l] for name, l in layers.items() } 106 | 107 | log.info(f"Gradient margin: {margin}") 108 | for name, r in layer2ratio.items(): 109 | log.info(f" {name}: {r}") 110 | else: 111 | log.info("No gradient margin") 112 | 113 | n_iter = 0 114 | alpha = self.params["noise_blend"] 115 | 116 | for epoch_counter in range(self.params['max_epochs']): 117 | loss_record = [] 118 | suffix = str(epoch_counter).zfill(3) 119 | 120 | # At the start of every epoch, blend Xavier-uniform noise into each weight matrix (biases are skipped) 121 | if alpha > 0: 122 | for name, p in self.model.named_parameters(): 123 | with torch.no_grad(): 124 | if len(p.size()) < 2: 125 | continue 126 | w = torch.zeros_like(p) # zeros_like keeps p's device; p.get_device() returns -1 on CPU 127 | torch.nn.init.xavier_uniform_(w) 128 | p[:] = (1 - alpha) * p[:] + alpha * w 129 | 130 | for (xis, xjs, xs), _ in train_loader: 131 | xis = xis.to(self.device) 132 | xjs = xjs.to(self.device) 133 | 134 | if self.nt_xent_criterion.need_unaug_data(): 135 | xs = xs.to(self.device) 136 | else: 137 | xs = None 138 | 139 | loss, loss_intra, zis, zjs, negative_sim = self._step(self.model, xis, xjs, xs, n_iter) 140 | 141 | # if n_iter % self.params['log_every_n_steps'] == 0: 142 | # self.writer.add_scalar('train_loss', loss, global_step=n_iter) 143 | 144 | all_loss = loss + loss_intra 145 | loss_record.append(all_loss.item()) 146 | 147 | if margin is not None: 148 | # Here we run backward once per loss and weight the gradients of different layers differently.
149 | self.optimizer.zero_grad() 150 | loss.backward(retain_graph=True) 151 | 152 | inter_grads = dict() 153 | for name, p in self.model.named_parameters(): 154 | # print(f"{name}: {p.size()}") 155 | inter_grads[name] = p.grad.clone() 156 | 157 | self.optimizer.zero_grad() 158 | loss_intra.backward() 159 | for name, p in self.model.named_parameters(): 160 | r = layer2ratio[name] 161 | # Lower layer -> higher weight on loss_intra 162 | p.grad *= (1 - r) 163 | p.grad += inter_grads[name] * r 164 | else: 165 | self.optimizer.zero_grad() 166 | all_loss.backward() 167 | 168 | self.optimizer.step() 169 | 170 | n_iter += 1 171 | xs = torch.cat([xjs, xis], dim=0) 172 | zs = torch.cat([zjs.detach(), zis.detach()], dim=0) 173 | self.model.special_call(xs, zs, negative_sim, n_iter) 174 | 175 | # Warmup: keep the learning rate constant for the first 10 epochs, then apply cosine decay 176 | if epoch_counter >= 10: 177 | scheduler.step() 178 | self.writer.add_scalar('cosine_lr_decay', scheduler.get_last_lr()[0], global_step=n_iter) 179 | 180 | log.info(f"Epoch {epoch_counter}: numIter: {n_iter} Loss: {np.mean(loss_record)}") 181 | if self.evaluator is not None: 182 | best_acc = self.evaluator.eval_model(deepcopy(self.model)) 183 | log.info(f"Epoch {epoch_counter}: best_acc: {best_acc}") 184 | 185 | if epoch_counter % self.params["save_per_epoch"] == 0: 186 | # save checkpoints 187 | self.save_model(os.path.join(model_checkpoints_folder, f'model_{suffix}.pth')) 188 | 189 | def save_model(self, PATH): 190 | torch.save({ 191 | 'online_network_state_dict': self.model.state_dict(), 192 | 'optimizer_state_dict': self.optimizer.state_dict(), 193 | }, PATH) 194 | -------------------------------------------------------------------------------- /ssl/real-dataset/test.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | 4 | @hydra.main(config_path="./config", config_name="test.yaml") 5 | def main(args): 6 | 
print(f"Seed: {args.seed}, Cuda #device: {torch.cuda.device_count()}") 7 | 8 |
if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /ssl/real-dataset/test2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class SpecializedL2Regularizer(torch.autograd.Function): 5 | @staticmethod 6 | def forward(ctx, input): 7 | assert len(input.size()) == 2 8 | l2_norms = input.pow(2).sum(dim=1, keepdim=True).sqrt().add(1e-8) 9 | ctx.l2_norms = l2_norms 10 | return input / l2_norms 11 | 12 | @staticmethod 13 | def backward(ctx, grad_output): 14 | grad_input = grad_output / ctx.l2_norms 15 | return grad_input 16 | 17 | a = torch.randn(1, 2, requires_grad=True) 18 | print(a) 19 | z = a.norm(dim=1).detach() 20 | print(z) 21 | 22 | reg = SpecializedL2Regularizer.apply 23 | # reg = lambda x: F.normalize(x, dim=1) 24 | b = reg(a) 25 | b.pow(2).sum().backward() 26 | 27 | print(a.grad) 28 | 29 | print(2 * a / z / z) 30 | -------------------------------------------------------------------------------- /ssl/real-dataset/try_relu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import torch.optim as optim 5 | import numpy as np 6 | 7 | class Model(nn.Module): 8 | def __init__(self, d, num_hidden): 9 | super(Model, self).__init__() 10 | 11 | self.extractor = nn.Linear(d, num_hidden) 12 | self.relu = nn.ReLU() 13 | self.predictor = nn.Linear(num_hidden, num_hidden) 14 | 15 | def forward(self, x1, x2): 16 | output1 = self.relu(self.extractor(x1)) 17 | output2 = self.relu(self.extractor(x2)) 18 | 19 | # Directly send the output to predictor and compute loss. 
20 | branch1 = self.predictor(output1) 21 | branch2 = output2.detach() 22 | 23 | branch1 = branch1 / branch1.norm(dim=1, keepdim=True) 24 | branch2 = branch2 / branch2.norm(dim=1, keepdim=True) 25 | 26 | return branch1, branch2 27 | 28 | class Generator: 29 | def __init__(self, d): 30 | # check ReLU case. 31 | self.d = d 32 | 33 | # d choose 2, create an over-complete basis. 34 | self.K = self.d * (self.d - 1) // 2 35 | self.Zs = torch.zeros(self.d, self.K) 36 | cnt = 0 37 | for i in range(d): 38 | for j in range(i + 1, d): 39 | self.Zs[i, cnt] = 1 40 | self.Zs[j, cnt] = 1 41 | cnt += 1 42 | 43 | # Treat the second part as the noisy component. 44 | self.signal_part = self.Zs[:, self.K//2:] 45 | self.noise_part = self.Zs[:, :self.K//2] 46 | 47 | print("Signal part:") 48 | print(self.signal_part) 49 | 50 | print("Noise part:") 51 | print(self.noise_part) 52 | 53 | def generate(self, batchsize): 54 | indices = list(range(self.K // 2)) 55 | 56 | random.shuffle(indices) 57 | x = self.signal_part[:, indices[:batchsize]].t() 58 | 59 | random.shuffle(indices) 60 | x1 = x + 0.5*self.noise_part[:, indices[:batchsize]].t() 61 | 62 | random.shuffle(indices) 63 | x2 = x + 0.5*self.noise_part[:, indices[:batchsize]].t() 64 | 65 | return x1, x2 66 | 67 | 68 | # construct a very simple neural network and train with DirectPred 69 | d = 10 70 | 71 | generator = Generator(d) 72 | num_hidden = 2 * generator.K 73 | model = Model(d, num_hidden) 74 | 75 | with torch.no_grad(): 76 | model.extractor.bias[:] = - 0.05 * torch.rand(num_hidden) 77 | 78 | loss_func = nn.MSELoss() 79 | 80 | optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.0) 81 | 82 | batchsize = 8 83 | 84 | for t in range(10000): 85 | optimizer.zero_grad() 86 | 87 | x1, x2 = generator.generate(batchsize) 88 | branch1, branch2 = model(x1, x2) 89 | 90 | loss = loss_func(branch1, branch2) 91 | if t % 100 == 0: 92 | print(f"{t}: {loss.item()}") 93 | loss.backward() 94 | optimizer.step() 95 | 96 | # Check whether the 
extractor weight is aligned with the features. 97 | # #type of input x #response of nodes. 98 | response_signal = model.extractor(generator.signal_part.t()) 99 | response_noise = model.extractor(generator.noise_part.t()) 100 | 101 | response_signal = (response_signal >= 0).float() 102 | response_noise = (response_noise >= 0).float() 103 | 104 | signal_noise_diff = response_signal.mean(dim=0) - response_noise.mean(dim=0) 105 | 106 | all_response = torch.cat([response_signal, response_noise], dim=0) 107 | sum_response = all_response.sum(dim=0) 108 | one_to_other_diff = all_response - sum_response[None, :] / (sum_response.size(0) - 1) 109 | 110 | print("signal_noise_diff") 111 | print(signal_noise_diff) 112 | 113 | print("one_to_other_diff") 114 | print(one_to_other_diff) 115 | 116 | print(model.extractor) 117 | import pdb 118 | pdb.set_trace() 119 | -------------------------------------------------------------------------------- /ssl/real-dataset/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import copyfile 3 | 4 | 5 | def _create_model_training_folder(writer, files_to_same): 6 | model_checkpoints_folder = os.path.join(writer.log_dir, 'checkpoints') 7 | root_dir = os.path.dirname(os.path.abspath(__file__)) 8 | if not os.path.exists(model_checkpoints_folder): 9 | os.makedirs(model_checkpoints_folder) 10 | for file in files_to_same: 11 | src = os.path.join(root_dir, file) 12 | dst = os.path.join(model_checkpoints_folder, os.path.basename(file)) 13 | copyfile(src, dst) 14 | -------------------------------------------------------------------------------- /student_specialization/README.md: -------------------------------------------------------------------------------- 1 | # Alignment of Student and Teacher nodes in deep ReLU network 2 | Code of ArXiv [paper](https://arxiv.org/abs/1909.13458). 
3 | 4 | 5 | # Required packages 6 | Install [hydra](https://github.com/facebookresearch/hydra) by following its instructions. 7 | 8 | Install PyTorch and other packages (yaml, json, matplotlib). 9 | 10 | 11 | # Usage 12 | 13 | ## Two-layer 14 | 15 | For two-layer results, run the following command to launch a sweep of 72 jobs (4 × 3 × 6 settings of `multi`, `m` and `teacher_strength_decay`) and use them to draw figures. 16 | 17 | ``` 18 | python recon_two_layer.py -m multi=1,2,5,10 d=100 m=5,10,20 teacher_strength_decay=0,0.5,1,1.5,2,2.5 lr=0.01 use_sgd=true N_train=10000 num_epoch=100 batchsize=16 num_iter_per_epoch=1000 normalize=false 19 | ``` 20 | 21 | Once it is done, run the following visualization code to replicate the figures shown in the paper: 22 | ``` 23 | python ./visualization/visualize.py [your sweep folder] 24 | ``` 25 | 26 | It will save three figures in the current folder. 27 | 28 | ## Multi-layer 29 | 30 | Use the following command: 31 | ``` 32 | python recon_multilayer.py seed=2351 stats_H=true num_trial=1 num_epoch=100 random_dataset_size=200000 33 | ``` 34 | 35 | Once it is done, run the following visualization code: 36 | ``` 37 | python ./visualization/visualize_multi.py [your saved folder] 38 | ``` 39 | 40 | 41 | -------------------------------------------------------------------------------- /student_specialization/conf/config.yaml: -------------------------------------------------------------------------------- 1 | num_iter_per_epoch: 100 2 | num_epoch: 100 3 | multi: 10 4 | batchsize: 8 5 | N_train: 10000 6 | N_eval: 10000 7 | m: 10 8 | lr: 0.01 9 | use_sgd: false 10 | regen_dataset: false 11 | seed: [1, 32] 12 | bias: -0.1 13 | no_bias: false 14 | normalize: false 15 | data_std: 1.5 16 | c: 50 17 | d: 30 18 | teacher_strength_decay: 1 19 | feature_fixed: false 20 | top_layer_fixed: false 21 | teacher_scale: 0.1 22 | student_scale: 0.005 23 | adv_init: none 24 | nonlinear: true 25 | githash: "" 26 | theory_suggest_train: false 27 | lr_reduction: 0 28 | theory_suggest_sigma: 1 29 | 
theory_suggest_mean: 1 30 | 31
| -------------------------------------------------------------------------------- /student_specialization/conf/config_multilayer.yaml: -------------------------------------------------------------------------------- 1 | node_multi: 10 2 | optim_method: sgd 3 | stats_grad_norm: false 4 | lr: "{0:0.2,20:0.1}" 5 | #lr: "{0:0.001,5:0.002,10:0.005,30:0.01,70:0.001}" 6 | #lr: "{0:0.01}" 7 | weight_choices: [-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0] 8 | data_d: 20 9 | data_std: 10.0 10 | num_epoch: 100 11 | num_trial: 10 12 | batchsize: 64 13 | eval_batchsize: 64 14 | random_dataset_size: 2000000 15 | # If total_bp_iters > 0, then num_epoch = total_bp_iters / random_dataset_size 16 | total_bp_iters: 0 17 | momentum: 0.0 18 | weight_decay: 0 19 | json_output: false 20 | cross_entropy: false 21 | seed: 1 22 | leaky_relu: 23 | perturb: 24 | same_dir: false 25 | same_sign: false 26 | normalize: false 27 | dataset: "gaussian" # one of "gaussian", "uniform", "mnist" or "cifar10" (see init_dataset in dataset.py) 28 | no_bias: false 29 | load_student: 30 | load_teacher: 31 | d_output: 0 32 | ks: [50,75,100,125] 33 | # ks: [10,10,20,20,30,30] 34 | # ks: [50,50,75,75,75,75,100,125] 35 | bn_affine: false 36 | bn: false 37 | no_sep: false 38 | teacher_bn_affine: false 39 | teacher_bn: false 40 | stats_H: false 41 | stats_w: false 42 | use_cnn: false 43 | bn_before_relu: false 44 | regen_dataset_each_epoch: false 45 | stats_teacher: false 46 | stats_student: false 47 | stats_teacher_h: false 48 | stats_student_h: false 49 | teacher_bias_tune: true 50 | teacher_bias_last_layer_tune: true 51 | teacher_strength_decay: 0 52 | student_scale_down: 0.1 53 | data_dir: /checkpoint/yuandong 54 | githash: "" 55 | num_epoch_save_summary: 10 56 | -------------------------------------------------------------------------------- /student_specialization/conf/hydra/launcher/fairtask.yaml: 
-------------------------------------------------------------------------------- 1 | hydra: 2 | launcher: 3 | class:
hydra_plugins.fairtask.FAIRTaskLauncher 4 | params: 5 | # to debug launching issues, set to true to run workers in the same process. 6 | no_workers: false 7 | queue: slurm 8 | queues: 9 | local: 10 | class: fairtask.local.LocalQueueConfig 11 | params: 12 | num_workers: 2 13 | slurm: 14 | class: fairtask_slurm.slurm.SLURMQueueConfig 15 | params: 16 | num_jobs: ${hydra.job.num_jobs} 17 | num_nodes_per_job: 1 18 | num_workers_per_node: 1 19 | name: ${hydra.job.name} 20 | maxtime_mins: 4320 21 | partition: priority 22 | cpus_per_worker: 10 23 | mem_gb_per_worker: 500 24 | gres: 'gpu:1' 25 | log_directory: ${hydra.sweep.dir}/.slurm 26 | output: slurm-%j.out 27 | error: slurm-%j.err 28 | comment: ICLR deadline 29 | 30 | -------------------------------------------------------------------------------- /student_specialization/conf/hydra/launcher/submitit.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | launcher: 3 | class: hydra_plugins.submitit.SubmititLauncher 4 | params: 5 | # one of auto, local, slurm or chronos 6 | queue: slurm 7 | 8 | folder: ${hydra.sweep.dir}/.${hydra.launcher.params.queue} 9 | queue_parameters: 10 | # slurm queue parameters 11 | slurm: 12 | nodes: 1 13 | num_gpus: 1 14 | ntasks_per_node: 1 15 | mem: ${hydra.launcher.mem_limit}GB 16 | cpus_per_task: 1 17 | time: 2880 18 | partition: learnfair 19 | signal_delay_s: 120 20 | comment: ICLRdeadline 21 | # chronos queue parameters 22 | chronos: 23 | # See crun documentation for most parameters 24 | # https://our.internmc.facebook.com/intern/wiki/Chronos-c-binaries/crun/ 25 | hostgroup: fblearner_ash_bigsur_fair 26 | cpu: 10 27 | mem: ${hydra.launcher.mem_limit} 28 | gpu: 1 29 | # local queue parameters 30 | local: 31 | gpus_per_node: 1 32 | tasks_per_node: 1 33 | timeout_min: 2880 34 | 35 | # variables used by queues above 36 | mem_limit: 24 37 | 38 | --------------------------------------------------------------------------------
/student_specialization/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.datasets as datasets 4 | from torchvision import transforms 5 | from torch.utils.data.dataset import Dataset 6 | 7 | class RandomDataset(Dataset): 8 | def __init__(self, N, d, std, noise_type="gaussian"): 9 | super(RandomDataset, self).__init__() 10 | self.d = d 11 | self.std = std 12 | self.N = N 13 | self.noise_type = noise_type 14 | self.regenerate() 15 | 16 | def regenerate(self): 17 | self.x = torch.FloatTensor(self.N, *self.d) 18 | if self.noise_type == "gaussian": 19 | self.x.normal_(0, std=self.std) 20 | elif self.noise_type == "uniform": 21 | self.x.uniform_(-self.std / 2, self.std / 2) 22 | else: 23 | raise NotImplementedError(f"Unknown noise type: {self.noise_type}") 24 | 25 | def __getitem__(self, idx): 26 | return self.x[idx], -1 27 | 28 | def __len__(self): 29 | return self.N 30 | 31 | def init_dataset(args): 32 | transform = transforms.Compose([ 33 | transforms.ToTensor(), 34 | transforms.Normalize( 35 | (0.5,), (0.5,))]) 36 | 37 | transform_cifar10_train = transforms.Compose([ 38 | transforms.RandomCrop(32, padding=4), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 42 | ]) 43 | 44 | transform_cifar10_test = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 47 | ]) 48 | 49 | if args.dataset == "gaussian" or args.dataset == "uniform": 50 | if args.use_cnn: 51 | d = (1, 16, 16) 52 | else: 53 | d = (args.data_d,) 54 | d_output = 100 55 | train_dataset = RandomDataset(args.random_dataset_size, d, args.data_std, noise_type=args.dataset) 56 | eval_dataset = RandomDataset(10240, d, args.data_std, noise_type=args.dataset) 57 | 58 | elif args.dataset == "mnist": 59 | train_dataset = datasets.MNIST( 60 | 
root=args.data_dir, train=True, download=True, 61 | transform=transform) 62 | 63 | eval_dataset = datasets.MNIST( 64 | root=args.data_dir, train=False, download=True, 65 | transform=transform) 66 | 67 | d = (1, 28, 28) 68 | d_output = 10 69 | 70 | elif args.dataset == "cifar10": 71 | train_dataset = datasets.CIFAR10( 72 | root=args.data_dir, train=True, download=True, 73 | transform=transform_cifar10_train) 74 | 75 | eval_dataset = datasets.CIFAR10( 76 | root=args.data_dir, train=False, download=True, 77 | transform=transform_cifar10_test) 78 | 79 | if not args.use_cnn: 80 | d = (3 * 32 * 32, ) 81 | else: 82 | d = (3, 32, 32) 83 | d_output = 10 84 | 85 | else: 86 | raise NotImplementedError(f"The dataset {args.dataset} is not implemented!") 87 | 88 | return d, d_output, train_dataset, eval_dataset 89 | -------------------------------------------------------------------------------- /student_specialization/model_gen.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | import random 7 | from theory_utils import haar_measure, init_separate_w 8 | 9 | 10 | # Generate random orth matrix. 
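# Hedged sketch (hypothetical helper; the repo's haar_measure lives in
# theory_utils.py, which is not shown here, so this is not claimed to match it):
# a standard way to sample a Haar-distributed random orthogonal matrix is to
# QR-factorize a Gaussian matrix and then flip column signs so that diag(R) > 0.
# set_orth() below keeps the top-left [out, in] block of such a matrix; the
# rows of that block are orthonormal whenever out <= in.
import numpy as np

def haar_orth_sketch(n, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    q, r = np.linalg.qr(rng.standard_normal((n, n)))
    # Scale column j of q by sign(r[j, j]); QR output alone is not Haar in general.
    return q * np.sign(np.diag(r))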
11 | import numpy as np 12 | import math 13 | 14 | def get_aug_w(w): 15 | # w: [output_d, input_d] 16 | # aug_w: [output_d + 1, input_d + 1] 17 | output_d, input_d = w.weight.size() 18 | aug_w = torch.zeros( (output_d + 1, input_d + 1), dtype = w.weight.dtype, device = w.weight.device) 19 | aug_w[:output_d, :input_d] = w.weight.data 20 | aug_w[:output_d, input_d] = w.bias.data 21 | aug_w[output_d, input_d] = 1 22 | return aug_w 23 | 24 | def set_orth(layer): 25 | w = layer.weight 26 | orth = haar_measure(w.size(1)) 27 | w.data = torch.from_numpy(orth[:w.size(0), :w.size(1)].astype('f4')).cuda() 28 | 29 | def set_add_noise(layer, teacher_layer, perturb): 30 | layer.weight.data[:] = teacher_layer.weight.data[:] + torch.randn(teacher_layer.weight.size()).cuda() * perturb 31 | layer.bias.data[:] = teacher_layer.bias.data[:] + torch.randn(teacher_layer.bias.size()).cuda() * perturb 32 | 33 | def set_same_dir(layer, teacher_layer): 34 | norm = layer.weight.data.norm() 35 | r = norm / teacher_layer.weight.data.norm() 36 | layer.weight.data[:] = teacher_layer.weight.data * r 37 | layer.bias.data[:] = teacher_layer.bias.data * r 38 | 39 | def set_same_sign(layer, teacher_layer): 40 | sel = (teacher_layer.weight.data > 0) * (layer.weight.data < 0) + (teacher_layer.weight.data < 0) * (layer.weight.data > 0) 41 | layer.weight.data[sel] *= -1.0 42 | 43 | sel = (teacher_layer.bias.data > 0) * (layer.bias.data < 0) + (teacher_layer.bias.data < 0) * (layer.bias.data > 0) 44 | layer.bias.data[sel] *= -1.0 45 | 46 | def normalize_layer(layer): 47 | # [output, input] 48 | w = layer.weight.data 49 | for i in range(w.size(0)): 50 | norm = w[i].pow(2).sum().sqrt() + 1e-5 51 | w[i] /= norm 52 | if layer.bias is not None: 53 | layer.bias.data[i] /= norm 54 | 55 | def init_w(layer, use_sep=True, weight_choices=[-0.5, -0.25, 0, 0.25, 0.5]): 56 | sz = layer.weight.size() 57 | output_d = sz[0] 58 | input_d = 1 59 | for s in sz[1:]: 60 | input_d *= s 61 | 62 | if use_sep: 63 | 
layer.weight.data[:] = torch.from_numpy(init_separate_w(output_d, input_d, weight_choices)).view(*sz).cuda() 64 | if layer.bias is not None: 65 | layer.bias.data.uniform_(-.5, 0.5) 66 | 67 | def init_w2(w, multiplier=5): 68 | w.weight.data *= multiplier 69 | w.bias.data.normal_(0, std=1) 70 | # w.bias.data *= 5 71 | for i, ww in enumerate(w.weight.data): 72 | pos_ratio = (ww > 0.0).sum().item() / w.weight.size(1) - 0.5 73 | w.bias.data[i] -= pos_ratio 74 | 75 | 76 | class Model(nn.Module): 77 | def __init__(self, d, ks, d_output, multi=1, has_bn=True, has_bn_affine=True, has_bias=True, bn_before_relu=False, leaky_relu=None): 78 | super(Model, self).__init__() 79 | self.d = d 80 | self.ks = ks 81 | self.has_bn = has_bn 82 | self.ws_linear = nn.ModuleList() 83 | self.ws_bn = nn.ModuleList() 84 | self.bn_before_relu = bn_before_relu 85 | last_k = d 86 | self.sizes = [d] 87 | 88 | for k in ks: 89 | k *= multi 90 | self.ws_linear.append(nn.Linear(last_k, k, bias=has_bias)) 91 | if has_bn: 92 | self.ws_bn.append(nn.BatchNorm1d(k, affine=has_bn_affine)) 93 | self.sizes.append(k) 94 | last_k = k 95 | 96 | self.final_w = nn.Linear(last_k, d_output, bias=has_bias) 97 | self.relu = nn.ReLU() if leaky_relu is None else nn.LeakyReLU(leaky_relu) 98 | 99 | self.sizes.append(d_output) 100 | 101 | self.use_cnn = False 102 | 103 | def init_orth(self): 104 | for w in self.ws_linear: 105 | set_orth(w) 106 | set_orth(self.final_w) 107 | 108 | def set_teacher(self, teacher, perturb): 109 | for w_s, w_t in zip(self.ws_linear, teacher.ws_linear): 110 | set_add_noise(w_s, w_t, perturb) 111 | set_add_noise(self.final_w, teacher.final_w, perturb) 112 | 113 | def set_teacher_dir(self, teacher): 114 | for w_s, w_t in zip(self.ws_linear, teacher.ws_linear): 115 | set_same_dir(w_s, w_t) 116 | set_same_dir(self.final_w, teacher.final_w) 117 | 118 | def set_teacher_sign(self, teacher): 119 | for w_s, w_t in zip(self.ws_linear, teacher.ws_linear): 120 | set_same_sign(w_s, w_t) 121 | set_same_sign(self.final_w, 
teacher.final_w) 122 | 123 | def
prioritize(self, strength_decay): 124 | def _prioritize(w): 125 | # output x input. 126 | for i in range(w.size(1)): 127 | w[:, i] /= pow(1 + i, strength_decay) 128 | 129 | # Prioritize teacher node. 130 | for w in self.ws_linear[1:]: 131 | _prioritize(w.weight.data) 132 | 133 | _prioritize(self.final_w.weight.data) 134 | 135 | def scale(self, r): 136 | def _scale(w): 137 | w.weight.data *= r 138 | w.bias.data *= r 139 | 140 | for w in self.ws_linear: 141 | _scale(w) 142 | 143 | _scale(self.final_w) 144 | 145 | def forward(self, x): 146 | hs = [] 147 | pre_bns = [] 148 | post_lins = [] 149 | #bns = [] 150 | h = x 151 | for i in range(len(self.ws_linear)): 152 | w = self.ws_linear[i] 153 | h = w(h) 154 | post_lins.append(h) 155 | if self.bn_before_relu: 156 | pre_bns.append(h) 157 | if len(self.ws_bn) > 0: 158 | bn = self.ws_bn[i] 159 | h = bn(h) 160 | h = self.relu(h) 161 | else: 162 | h = self.relu(h) 163 | pre_bns.append(h) 164 | if len(self.ws_bn) > 0: 165 | bn = self.ws_bn[i] 166 | h = bn(h) 167 | hs.append(h) 168 | #bns.append(h) 169 | y = self.final_w(hs[-1]) 170 | return dict(hs=hs, post_lins=post_lins, pre_bns=pre_bns, y=y) 171 | 172 | def init_w(self, use_sep=True, weight_choices=None): 173 | for w in self.ws_linear: 174 | init_w(w, use_sep=use_sep, weight_choices=weight_choices) 175 | init_w(self.final_w, use_sep=use_sep, weight_choices=weight_choices) 176 | 177 | def reset_parameters(self): 178 | for w in self.ws_linear: 179 | w.reset_parameters() 180 | for w in self.ws_bn: 181 | w.reset_parameters() 182 | self.final_w.reset_parameters() 183 | 184 | def normalize(self): 185 | for w in self.ws_linear: 186 | normalize_layer(w) 187 | normalize_layer(self.final_w) 188 | 189 | def from_bottom_linear(self, j): 190 | if j < len(self.ws_linear): 191 | return self.ws_linear[j].weight.data 192 | elif j == len(self.ws_linear): 193 | return self.final_w.weight.data 194 | else: 195 | raise RuntimeError("j[%d] is out of bound! 
should be [0, %d]" % (j, len(self.ws_linear))) 196 | 197 | def from_bottom_aug_w(self, j): 198 | if j < len(self.ws_linear): 199 | return get_aug_w(self.ws_linear[j]) 200 | elif j == len(self.ws_linear): 201 | return get_aug_w(self.final_w) 202 | else: 203 | raise RuntimeError("j[%d] is out of bound! should be [0, %d]" % (j, len(self.ws_linear))) 204 | 205 | def num_hidden_layers(self): 206 | return len(self.ws_linear) 207 | 208 | def num_layers(self): 209 | return len(self.ws_linear) + 1 210 | 211 | def from_bottom_bn(self, j): 212 | assert j < len(self.ws_bn) 213 | return self.ws_bn[j] 214 | 215 | 216 | class ModelConv(nn.Module): 217 | def __init__(self, input_size, ks, d_output, multi=1, has_bn=True, bn_before_relu=False, leaky_relu=None): 218 | super(ModelConv, self).__init__() 219 | self.ks = ks 220 | self.ws_linear = nn.ModuleList() 221 | self.ws_bn = nn.ModuleList() 222 | self.bn_before_relu = bn_before_relu 223 | 224 | init_k, h, w = input_size 225 | last_k = init_k 226 | 227 | for k in ks: 228 | k *= multi 229 | self.ws_linear.append(nn.Conv2d(last_k, k, 3)) 230 | if has_bn: 231 | self.ws_bn.append(nn.BatchNorm2d(k)) 232 | last_k = k 233 | h -= 2 234 | w -= 2 235 | 236 | self.final_w = nn.Linear(last_k * h * w, d_output) 237 | self.relu = nn.ReLU() if leaky_relu is None else nn.LeakyReLU(leaky_relu) 238 | 239 | self.use_cnn = True 240 | 241 | def scale(self, r): 242 | def _scale(w): 243 | w.weight.data *= r 244 | w.bias.data *= r 245 | 246 | for w in self.ws_linear: 247 | _scale(w) 248 | 249 | _scale(self.final_w) 250 | 251 | def forward(self, x): 252 | hs = [] 253 | #bns = [] 254 | h = x 255 | for i in range(len(self.ws_linear)): 256 | w = self.ws_linear[i] 257 | h = w(h) 258 | if self.bn_before_relu: 259 | if len(self.ws_bn) > 0: 260 | bn = self.ws_bn[i] 261 | h = bn(h) 262 | h = self.relu(h) 263 | else: 264 | h = self.relu(h) 265 | if len(self.ws_bn) > 0: 266 | bn = self.ws_bn[i] 267 | h = bn(h) 268 | hs.append(h) 269 | #bns.append(h) 270 | h = 
hs[-1].view(h.size(0), -1) 271 | y = self.final_w(h) 272 | return dict(hs=hs, y=y) 273 | 274 | def init_w(self, use_sep=True, weight_choices=None): 275 | for w in self.ws_linear: 276 | init_w(w, use_sep=use_sep, weight_choices=weight_choices) 277 | init_w(self.final_w, use_sep=use_sep, weight_choices=weight_choices) 278 | 279 | def normalize(self): 280 | for w in self.ws_linear: 281 | normalize_layer(w) 282 | normalize_layer(self.final_w) 283 | 284 | def normalize_last(self): 285 | normalize_layer(self.final_w) 286 | 287 | def reset_parameters(self): 288 | for w in self.ws_linear: 289 | w.reset_parameters() 290 | for w in self.ws_bn: 291 | w.reset_parameters() 292 | self.final_w.reset_parameters() 293 | 294 | def from_bottom_linear(self, j): 295 | if j < len(self.ws_linear): 296 | return self.ws_linear[j].weight.data 297 | elif j == len(self.ws_linear): 298 | return self.final_w.weight.data 299 | else: 300 | raise RuntimeError("j[%d] is out of bound! should be [0, %d]" % (j, len(self.ws_linear))) 301 | 302 | def num_hidden_layers(self): 303 | return len(self.ws_linear) 304 | 305 | def num_layers(self): 306 | return len(self.ws_linear) + 1 307 | 308 | def from_bottom_bn(self, j): 309 | assert j < len(self.ws_bn) 310 | return self.ws_bn[j] 311 | 312 | def prune(net, ratios): 313 | # Compute pruning masks for the network; this function does not finetune. 314 | n = net.num_layers() 315 | # Rank input nodes by mean absolute weight and prune the weakest ones in each layer 316 | masks = [] 317 | inactive_nodes = [] 318 | for i in range(1, n): 319 | W = net.from_bottom_linear(i) 320 | # Prune all input neurons 321 | input_dim = W.size(1) 322 | fc_to_conv = False 323 | 324 | if isinstance(net, ModelConv): 325 | if len(W.size()) == 4: 326 | # W: [output_filter, input_filter, x, y] 327 | w_norms = W.permute(1, 0, 2, 3).contiguous().view(W.size(1), -1).abs().mean(1) 328 | else: 329 | # The final FC layer.
330 | input_dim = net.from_bottom_linear(i - 1).size(0) 331 | W_reshaped = W.view(W.size(0), -1, input_dim) 332 | w_norms = W_reshaped.view(-1, input_dim).abs().mean(0) 333 | fc_to_conv = True 334 | else: 335 | # W: [output_dim, input_dim] 336 | w_norms = W.abs().mean(0) 337 | 338 | sorted_w, sorted_indices = w_norms.sort(0) 339 | n_pruned = int(input_dim * ratios[i - 1]) 340 | inactive_mask = sorted_indices[:n_pruned] 341 | 342 | m = W.clone().fill_(1.0) 343 | if fc_to_conv: 344 | m = m.view(m.size(0), -1, input_dim) 345 | m[:, :, inactive_mask] = 0 346 | m = m.view(W.size(0), W.size(1)) 347 | else: 348 | m[:, inactive_mask] = 0 349 | 350 | # Set the mask for the lower layer to zero. 351 | inactive_nodes.append(inactive_mask.cpu().tolist()) 352 | masks.append(m) 353 | 354 | return inactive_nodes, masks 355 | -------------------------------------------------------------------------------- /student_specialization/recon_two_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import math 5 | import argparse 6 | import pickle 7 | import sys 8 | import os 9 | import hydra 10 | import subprocess 11 | import logging 12 | log = logging.getLogger(__file__) 13 | 14 | from theory_utils import init_separate_w, set_all_seeds 15 | 16 | from utils_corrs import * 17 | from vis_corrs import * 18 | 19 | def forward(X, W1, W2, nonlinear: bool): 20 | all_one = torch.zeros(X.size(0), 1, dtype=X.dtype).to(X.device).fill_(1.0) 21 | # Teacher's output. 
22 | X = torch.cat([X, all_one], dim=1) 23 | h1 = X @ W1 24 | h1 = torch.cat([h1, all_one], dim=1) 25 | 26 | if nonlinear: 27 | h1_ng = h1 < 0 28 | h1[h1_ng] = 0 29 | else: 30 | h1_ng = h1 < 1e38 31 | 32 | output = h1 @ W2 33 | return X, h1, h1_ng, output 34 | 35 | def backward(X, W1, W2, h1, h1_ng, g2, nonlinear: bool): 36 | deltaW2 = h1.t() @ g2 37 | g1 = (g2 @ W2.t()) 38 | if nonlinear: 39 | g1[h1_ng] = 0 40 | deltaW1 = X.t() @ g1[:, :-1] 41 | 42 | return deltaW1, deltaW2, g1 43 | 44 | def convert(*tensors): 45 | return tuple([ v.double().cuda() for v in tensors ]) 46 | 47 | def get_data(N, d, data_std): 48 | return torch.randn(N, d).cuda() * data_std 49 | 50 | def normalize(W): 51 | W[:-1,:] /= W[:-1,:].norm(dim=0) 52 | 53 | def init(cfg): 54 | d = cfg.d 55 | m = cfg.m 56 | n = int(cfg.m * cfg.multi) 57 | c = cfg.c 58 | 59 | log.info(f"d = {d}, m = {m}, n = {n}, c = {c}") 60 | 61 | W1_t = torch.randn(d + 1, m).cuda() * cfg.teacher_scale 62 | W1_t[:-1, :] = torch.from_numpy(init_separate_w(m, d, cfg.choices)).t() 63 | W1_t[-1, :] = cfg.bias 64 | 65 | W2_t = torch.randn(m + 1, c).cuda() * cfg.teacher_scale 66 | W2_t[:-1, :] = torch.from_numpy(init_separate_w(c, m, cfg.choices)).t() 67 | 68 | if cfg.teacher_strength_decay > 0: 69 | for i in range(1, m): 70 | W2_t[i, :] /= pow(i + 1, cfg.teacher_strength_decay) 71 | 72 | W2_t[-1, :] = cfg.bias 73 | 74 | W1_s = torch.randn(d + 1, n).cuda() * cfg.student_scale 75 | # Bias = 0 76 | W1_s[-1, :] = 0 77 | 78 | W2_s = torch.randn(n + 1, c).cuda() * cfg.student_scale 79 | # Bias = 0 80 | W2_s[-1, :] = 0 81 | 82 | # Optionally flip student weights so they point away from ("adv") or toward ("help") the last teacher.
83 | if cfg.adv_init == "adv": 84 | for i in range(n): 85 | if (W1_t[:-1, -1] * W1_s[:-1, i]).sum().item() > 0: 86 | W1_s[:-1, i] *= -1 87 | elif cfg.adv_init == "help": 88 | for i in range(n): 89 | if (W1_t[:-1, -1] * W1_s[:-1, i]).sum().item() < 0: 90 | W1_s[:-1, i] *= -1 91 | elif cfg.adv_init != "none": 92 | raise RuntimeError(f"Invalid adv_init: {cfg.adv_init}") 93 | 94 | W1_t, W2_t, W1_s, W2_s = convert(W1_t, W2_t, W1_s, W2_s) 95 | 96 | normalize(W1_t) 97 | 98 | if cfg.normalize: 99 | normalize(W1_s) 100 | 101 | return W1_t, W2_t, W1_s, W2_s 102 | 103 | def compute_boundary_obs(s, t): 104 | s_pos = (s > 0).float() 105 | t_pos = (t > 0).float() 106 | t_neg = 1 - t_pos 107 | 108 | # s should be active on both the positive and the negative side of the teacher (then s observes the teacher's boundary) 109 | pos_obs = t_pos.t() @ s_pos 110 | neg_obs = t_neg.t() @ s_pos 111 | 112 | obs = torch.min(pos_obs, neg_obs) 113 | return obs 114 | 115 | def eval_phrase(h1_t, h1_s, h1_eval_t, h1_eval_s): 116 | # with np.printoptions(precision=4, suppress=True, linewidth=120): 117 | # More analysis. 118 | # A = h1_s.t() @ h1_s 119 | # B = h1_s.t() @ h1_t 120 | # Solve AX = B, note that this is not stable, so we can remove it.
121 | # C, _ = torch.gesv(B, A) 122 | # log.info("coeffs: \n", C.cpu().numpy()) 123 | # log.info("B*: \n", (C @ Btt).cpu().numpy()) 124 | 125 | h1_s_no_aug = h1_s[:,:-1] 126 | h1_t_no_aug = h1_t[:,:-1] 127 | 128 | ''' 129 | inner_prod = h1_s_no_aug.t() @ h1_t_no_aug 130 | norm_s = h1_s_no_aug.pow(2).sum(0).sqrt() 131 | norm_t = h1_t_no_aug.pow(2).sum(0).sqrt() 132 | correlation = inner_prod / norm_s[:,None] / norm_t[None,:] 133 | ''' 134 | 135 | # import pdb 136 | # pdb.set_trace() 137 | # log.info(torch.cat([correlation, norm_s.view(-1, 1), Bstar], 1).cpu().numpy()) 138 | 139 | N = h1_t.size(0) 140 | ts_prod = h1_t_no_aug.t() @ h1_s_no_aug / N 141 | ss_prod = h1_s_no_aug.t() @ h1_s_no_aug / N 142 | 143 | log.info("Train:") 144 | corr_train = act2corrMat(h1_t_no_aug, h1_s_no_aug) 145 | corr_indices_train = corrMat2corrIdx(corr_train) 146 | log.info(get_corrs([corr_indices_train])) 147 | 148 | counts_train = dict() 149 | for thres in (0.5, 0.6, 0.7, 0.8, 0.9, 0.95): 150 | counts_train[thres] = (corr_train > thres).sum(dim=1).cpu() 151 | log.info(f"Convergence count (Train) (>{thres}): {counts_train[thres]}, covered: { (counts_train[thres] > 0).sum().item() }") 152 | 153 | # Compute correlation between h1_s and h1_t 154 | corr_eval = act2corrMat(h1_eval_t[:,:-1], h1_eval_s[:,:-1]) 155 | corr_indices_eval = corrMat2corrIdx(corr_eval) 156 | log.info("Eval:") 157 | log.info(get_corrs([corr_indices_eval])) 158 | 159 | counts_eval = dict() 160 | for thres in (0.5, 0.6, 0.7, 0.8, 0.9, 0.95): 161 | counts_eval[thres] = (corr_eval > thres).sum(dim=1).cpu() 162 | log.info(f"Convergence count (Eval) (>{thres}): {counts_eval[thres]}, covered: { (counts_eval[thres] > 0).sum().item() }") 163 | 164 | # compute observability matrix. 
165 | obs_train = compute_boundary_obs(h1_s_no_aug, h1_t_no_aug) 166 | log.info(f"train: obs: {obs_train.max(dim=1)[0] / h1_t_no_aug.size(0)}") 167 | 168 | obs_eval = compute_boundary_obs(h1_eval_s[:,:-1], h1_eval_t[:,:-1]) 169 | log.info(f"eval: obs: {obs_eval.max(dim=1)[0] / h1_eval_t.size(0)}") 170 | 171 | return dict(corr_train=corr_train.cpu(), corr_eval=corr_eval.cpu(), obs_train=obs_train.cpu(), obs_eval=obs_eval.cpu(), 172 | counts_train=counts_train, counts_eval=counts_eval, ts_prod=ts_prod.cpu(), ss_prod=ss_prod.cpu()) 173 | 174 | def after_epoch_eval(i, X_train, X_eval, W1_t, W2_t, W1_s, W2_s, cfg): 175 | log.info(f"{i}: Epoch evaluation") 176 | _, h1_train_t, _, output_t_train = forward(X_train, W1_t, W2_t, nonlinear=cfg.nonlinear) 177 | # Student's output. 178 | _, h1_train_s, h1_train_ng_s, output_s_train = forward(X_train, W1_s, W2_s, nonlinear=cfg.nonlinear) 179 | 180 | _, h1_eval_t, _, output_t_eval = forward(X_eval, W1_t, W2_t, nonlinear=cfg.nonlinear) 181 | _, h1_eval_s, _, output_s_eval = forward(X_eval, W1_s, W2_s, nonlinear=cfg.nonlinear) 182 | 183 | g2_train = output_t_train - output_s_train 184 | train_loss = g2_train.pow(2).mean() 185 | eval_loss = (output_t_eval - output_s_eval).pow(2).mean() 186 | log.info(f"{i}: train_loss = {train_loss}, eval_loss = {eval_loss}") 187 | 188 | deltaW1_s, deltaW2_s, g1_train = backward(X_train, W1_s, W2_s, h1_train_s, h1_train_ng_s, g2_train, nonlinear=cfg.nonlinear) 189 | deltaW1_s /= X_train.size(0) 190 | deltaW2_s /= X_train.size(0) 191 | log.info(f"|g1_train| = {g1_train.norm() / X_train.size(0)}") 192 | 193 | # Compute g1_train to see whether it is zero for each individual training samples. 
194 | stat = dict(iter=i, W1_s=W1_s.cpu(), W2_s=W2_s.cpu(), train_loss=train_loss.cpu(), eval_loss=eval_loss.cpu()) 195 | stat.update(dict(deltaW1_s=deltaW1_s.cpu(), deltaW2_s=deltaW2_s.cpu(), 196 | g1_train=g1_train.pow(2).mean(dim=0).cpu(), g2_train=g2_train.pow(2).mean(0).cpu())) 197 | 198 | stat.update(eval_phrase(h1_train_t, h1_train_s, h1_eval_t, h1_eval_s)) 199 | log.info(f"|gradW1|={deltaW1_s.norm()}, |gradW2|={deltaW2_s.norm()}") 200 | 201 | return stat 202 | 203 | def run(cfg): 204 | W1_t, W2_t, W1_s, W2_s = init(cfg) 205 | 206 | Btt = W2_t @ W2_t.t() 207 | # log.info(Btt) 208 | 209 | X_eval = torch.randn(cfg.N_eval, cfg.d) * cfg.data_std 210 | 211 | if cfg.theory_suggest_train: 212 | X_train = [] 213 | for i in range(cfg.m): 214 | data = torch.randn(math.ceil(cfg.N_train / (cfg.m * 3)), cfg.d + 1).double().cuda() * cfg.data_std 215 | data[:, -1] = 1 216 | # projected to teacher plane. 217 | w = W1_t[:, i] 218 | # In the plane now. 219 | data = data - torch.ger(data @ w, w) / w.pow(2).sum() 220 | data = data[:, :-1] / data[:, -1][:, None] 221 | 222 | alpha = torch.rand(data.size(0)).double().cuda() * cfg.theory_suggest_sigma + cfg.theory_suggest_mean 223 | data_plus = data + torch.ger(alpha, w[:-1]) 224 | data_minus = data - torch.ger(alpha, w[:-1]) 225 | 226 | X_train.extend([data, data_plus, data_minus]) 227 | # X_train.extend([data_plus, data_minus]) 228 | 229 | X_train = torch.cat(X_train, dim=0) 230 | X_train /= X_train.norm(dim=1)[:,None] 231 | X_train *= cfg.data_std * math.sqrt(cfg.d) 232 | print(f"Use dataset from theory: N_train = {X_train.size(0)}") 233 | cfg.N_train = X_train.size(0) 234 | else: 235 | X_train = torch.randn(cfg.N_train, cfg.d) * cfg.data_std 236 | 237 | X_train, X_eval = convert(X_train, X_eval) 238 | 239 | t_norms = W2_t.norm(dim=1) 240 | print(f"teacher norm: {t_norms}") 241 | 242 | init_stat = dict(W1_t=W1_t.cpu(), W2_t=W2_t.cpu(), W1_s=W1_s.cpu(), W2_s=W2_s.cpu()) 243 | init_stat.update(after_epoch_eval(-1, X_train, X_eval, 
W1_t, W2_t, W1_s, W2_s, cfg)) 244 | 245 | stats = [] 246 | stats.append(init_stat) 247 | 248 | train_set_sel = list(range(cfg.N_train)) 249 | lr = cfg.lr 250 | 251 | for i in range(cfg.num_epoch): 252 | W1_s_old = W1_s.clone() 253 | W2_s_old = W2_s.clone() 254 | 255 | if cfg.lr_reduction > 0 and i > 0 and (i % cfg.lr_reduction == 0): 256 | lr = lr / 2 257 | log.info(f"{i}: reducing learning rate: {lr}") 258 | 259 | for j in range(cfg.num_iter_per_epoch): 260 | if cfg.use_sgd: 261 | sel = random.choices(train_set_sel, k=cfg.batchsize) 262 | # Randomly picking a subset. 263 | X = X_train[sel, :].clone() 264 | else: 265 | # Gradient descent. 266 | X = X_train 267 | 268 | # Teacher's output. 269 | X_aug, h1_t, h1_ng_t, output_t = forward(X, W1_t, W2_t, nonlinear=cfg.nonlinear) 270 | 271 | # Student's output. 272 | X_aug, h1_s, h1_ng_s, output_s = forward(X, W1_s, W2_s, nonlinear=cfg.nonlinear) 273 | 274 | # Backpropagation. 275 | g2 = output_t - output_s 276 | deltaW1_s, deltaW2_s, _ = backward(X_aug, W1_s, W2_s, h1_s, h1_ng_s, g2, nonlinear=cfg.nonlinear) 277 | deltaW1_s /= X.size(0) 278 | deltaW2_s /= X.size(0) 279 | 280 | if not cfg.feature_fixed: 281 | W1_s += lr * deltaW1_s 282 | if cfg.normalize: 283 | normalize(W1_s) 284 | 285 | if not cfg.top_layer_fixed: 286 | W2_s += lr * deltaW2_s 287 | 288 | if cfg.no_bias: 289 | W1_s[-1, :] = 0 290 | W2_s[-1, :] = 0 291 | 292 | 293 | stat = after_epoch_eval(i, X_train, X_eval, W1_t, W2_t, W1_s, W2_s, cfg) 294 | stats.append(stat) 295 | 296 | if cfg.regen_dataset: 297 | X_train = torch.randn(cfg.N_train, cfg.d) * cfg.data_std 298 | X_train = convert(X_train)[0] 299 | 300 | log.info(f"|W1|={W1_s.norm()}, |W2|={W2_s.norm()}") 301 | log.info(f"|deltaW1|={(W1_s - W1_s_old).norm()}, |deltaW2|={(W2_s - W2_s_old).norm()}") 302 | 303 | return stats 304 | 305 | 306 | @hydra.main(config_path='conf/config.yaml', strict=True) 307 | def main(cfg): 308 | cmd_line = " ".join(sys.argv) 309 | log.info(f"{cmd_line}") 310 | log.info(f"Working 
dir: {os.getcwd()}") 311 | 312 | _, output = subprocess.getstatusoutput("git -C ./ log --pretty=format:'%H' -n 1") 313 | ret, _ = subprocess.getstatusoutput("git -C ./ diff-index --quiet HEAD --") 314 | log.info(f"Githash: {output}, unstaged: {ret}") 315 | log.info("Configuration:\n{}".format(cfg.pretty())) 316 | 317 | # Simulate 2-layer dynamics. 318 | if cfg.no_bias: 319 | cfg.bias = 0.0 320 | 321 | if isinstance(cfg.seed, int): 322 | seeds = [cfg.seed] 323 | else: 324 | seeds = list(range(cfg.seed[0], cfg.seed[1] + 1)) 325 | 326 | all_stats = dict() 327 | for i, seed in enumerate(seeds): 328 | log.info(f"{i} / {len(seeds)}, Seed: {seed}") 329 | set_all_seeds(seed) 330 | all_stats[seed] = run(cfg) 331 | 332 | torch.save(all_stats, "stats.pickle") 333 | log.info(f"Working dir: {os.getcwd()}") 334 | 335 | 336 | if __name__ == "__main__": 337 | main() 338 | 339 | -------------------------------------------------------------------------------- /student_specialization/teacher_tune.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | def tune_teacher(eval_loader, teacher): 5 | # Tune the teacher's biases so that each hidden node is active on approximately half of the inputs (0.5/0.5 activation/inactivation) 6 | num_hidden = teacher.num_hidden_layers() 7 | for t in range(num_hidden): 8 | output = utils.concatOutput(eval_loader, [teacher]) 9 | estimated_bias = output[0]["post_lins"][t].median(dim=0)[0] 10 | teacher.ws_linear[t].bias.data[:] -= estimated_bias.cuda() 11 | 12 | # double check 13 | output = utils.concatOutput(eval_loader, [teacher]) 14 | for t in range(num_hidden): 15 | activate_ratio = (output[0]["post_lins"][t] > 0).float().mean(dim=0) 16 | print(f"{t}: {activate_ratio}") 17 | 18 | def tune_teacher_last_layer(eval_loader, teacher): 19 | output = utils.concatOutput(eval_loader, [teacher]) 20 | 21 | # Tune the final linear layer to make output balanced as well.
22 | y = output[0]["y"] 23 | y_mean = y.mean(dim=0).cuda() 24 | y_std = y.std(dim=0).cuda() 25 | 26 | teacher.final_w.weight.data /= y_std[:, None] 27 | teacher.final_w.bias.data -= y_mean 28 | teacher.final_w.bias.data /= y_std 29 | 30 | # double check 31 | output = utils.concatOutput(eval_loader, [teacher]) 32 | y = output[0]["y"] 33 | y_mean = y.mean(dim=0) 34 | y_std = y.std(dim=0) 35 | 36 | print(f"Final layer: y_mean: {y_mean}, y_std: {y_std}") 37 | 38 | -------------------------------------------------------------------------------- /student_specialization/theory_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | 5 | def haar_measure(n): 6 | '''Generate an n-by-n random orthogonal matrix distributed according to the Haar measure''' 7 | z = np.random.randn(n,n) 8 | q,r = np.linalg.qr(z) 9 | d = np.diag(r) 10 | ph = d/np.absolute(d) 11 | q = np.dot(q, np.diag(ph)) 12 | return q 13 | 14 | def init_separate_w(output_d, input_d, choices): 15 | existing_encoding = set() 16 | existing_encoding.add(tuple([0] * input_d)) 17 | 18 | w = np.zeros((output_d, input_d)) 19 | 20 | for i in range(output_d): 21 | while True: 22 | encoding = tuple( random.sample(choices, 1)[0] for j in range(input_d) ) 23 | if encoding not in existing_encoding: 24 | break 25 | for j in range(input_d): 26 | w[i, j] = encoding[j] 27 | existing_encoding.add(encoding) 28 | 29 | return w 30 | 31 | def set_all_seeds(rand_seed): 32 | random.seed(rand_seed) 33 | np.random.seed(rand_seed) 34 | torch.manual_seed(rand_seed) 35 | torch.cuda.manual_seed(rand_seed) 36 | 37 | -------------------------------------------------------------------------------- /student_specialization/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | 5 | def to_cpu(x): 6 | if isinstance(x, dict): 7 | return { k : to_cpu(v) for k, v in x.items() } 8 | elif isinstance(x, list):
elif isinstance(x, list): 9 | return [ to_cpu(v) for v in x ] 10 | elif isinstance(x, torch.Tensor): 11 | return x.cpu() 12 | else: 13 | return x 14 | 15 | def model2numpy(model): 16 | return { k : v.cpu().numpy() for k, v in model.state_dict().items() } 17 | 18 | def activation2numpy(output): 19 | if isinstance(output, dict): 20 | return { k : activation2numpy(v) for k, v in output.items() } 21 | elif isinstance(output, list): 22 | return [ activation2numpy(v) for v in output ] 23 | elif isinstance(output, Variable): 24 | return output.data.cpu().numpy() 25 | 26 | def count_size(x): 27 | if isinstance(x, dict): 28 | return sum([ count_size(v) for k, v in x.items() ]) 29 | elif isinstance(x, list) or isinstance(x, tuple): 30 | return sum([ count_size(v) for v in x ]) 31 | elif isinstance(x, torch.Tensor): 32 | return x.nelement() * x.element_size() 33 | else: 34 | return sys.getsizeof(x) 35 | 36 | def mem2str(num_bytes): 37 | assert num_bytes >= 0 38 | if num_bytes >= 2 ** 30: # GB 39 | val = float(num_bytes) / (2 ** 30) 40 | result = "%.3f GB" % val 41 | elif num_bytes >= 2 ** 20: # MB 42 | val = float(num_bytes) / (2 ** 20) 43 | result = "%.3f MB" % val 44 | elif num_bytes >= 2 ** 10: # KB 45 | val = float(num_bytes) / (2 ** 10) 46 | result = "%.3f KB" % val 47 | else: 48 | result = "%d bytes" % num_bytes 49 | return result 50 | 51 | def get_mem_usage(): 52 | import psutil 53 | 54 | mem = psutil.virtual_memory() 55 | result = "" 56 | result += "available: %s\t" % (mem2str(mem.available)) 57 | result += "used: %s\t" % (mem2str(mem.used)) 58 | result += "free: %s\t" % (mem2str(mem.free)) 59 | # result += "active: %s\t" % (mem2str(mem.active)) 60 | # result += "inactive: %s\t" % (mem2str(mem.inactive)) 61 | # result += "buffers: %s\t" % (mem2str(mem.buffers)) 62 | # result += "cached: %s\t" % (mem2str(mem.cached)) 63 | # result += "shared: %s\t" % (mem2str(mem.shared)) 64 | # result += "slab: %s\t" % (mem2str(mem.slab)) 65 | return result 66 | 67 | 68 | def 
accumulate(all_y, y): 69 | if all_y is None: 70 | all_y = dict() 71 | for k, v in y.items(): 72 | if isinstance(v, list): 73 | all_y[k] = [ [vv] for vv in v ] 74 | else: 75 | all_y[k] = [v] 76 | else: 77 | for k, v in all_y.items(): 78 | if isinstance(y[k], list): 79 | for vv, yy in zip(v, y[k]): 80 | vv.append(yy) 81 | else: 82 | v.append(y[k]) 83 | 84 | return all_y 85 | 86 | def combine(all_y): 87 | output = dict() 88 | for k, v in all_y.items(): 89 | if isinstance(v[0], list): 90 | output[k] = [ torch.cat(vv) for vv in v ] 91 | else: 92 | output[k] = torch.cat(v) 93 | 94 | return output 95 | 96 | def concatOutput(loader, nets, condition=None): 97 | outputs = [None] * len(nets) 98 | 99 | use_cnn = nets[0].use_cnn 100 | 101 | with torch.no_grad(): 102 | for i, (x, _) in enumerate(loader): 103 | if not use_cnn: 104 | x = x.view(x.size(0), -1) 105 | x = x.cuda() 106 | 107 | outputs = [ accumulate(output, to_cpu(net(x))) for net, output in zip(nets, outputs) ] 108 | if condition is not None and not condition(i): 109 | break 110 | 111 | return [ combine(output) for output in outputs ] 112 | 113 | 114 | -------------------------------------------------------------------------------- /student_specialization/utils_corrs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def corrMat2corrIdx(score): 4 | ''' Given score[N, #candidate], 5 | for each row (e.g., each teacher node) sort the score (in descending order) 6 | and output corr_table[N, #dict(s_idx, score)] 7 | ''' 8 | sorted_score, sorted_indices = score.sort(1, descending=True) 9 | N = sorted_score.size(0) 10 | n_candidate = sorted_score.size(1) 11 | # Check the corresponding weights. 12 | # print("For each teacher node, sorted corr over all student nodes at layer = %d" % i) 13 | corr_table = [] 14 | for k in range(N): 15 | tt = [] 16 | for j in range(n_candidate): 17 | # Compare the upward weights.
18 | s_idx = int(sorted_indices[k][j]) 19 | score = float(sorted_score[k][j]) 20 | tt.append(dict(s_idx=s_idx, score=score)) 21 | corr_table.append(tt) 22 | return corr_table 23 | 24 | ''' 25 | def corrIdx2pickMat(corr_indices): 26 | return [ item[0]["s_idx"] for item in corr_indices ] 27 | 28 | def corrIndices2pickMats(corr_indices_list): 29 | return [ corrIdx2pickMat(corr_indices) for corr_indices in corr_indices_list ] 30 | ''' 31 | 32 | def act2corrMat(src, dst): 33 | ''' src[:, k], with k < K1 34 | dst[:, k'], with k' < K2 35 | output (Pearson) correlation score[K1, K2] 36 | ''' 37 | # K_src by K_dst 38 | if len(src.size()) == 3 and len(dst.size()) == 3: 39 | src = src.permute(0, 2, 1).contiguous().view(src.size(0) * src.size(2), -1) 40 | dst = dst.permute(0, 2, 1).contiguous().view(dst.size(0) * dst.size(2), -1) 41 | 42 | # conv activations. 43 | elif len(src.size()) == 4 and len(dst.size()) == 4: 44 | src = src.permute(0, 2, 3, 1).contiguous().view(src.size(0) * src.size(2) * src.size(3), -1) 45 | dst = dst.permute(0, 2, 3, 1).contiguous().view(dst.size(0) * dst.size(2) * dst.size(3), -1) 46 | 47 | # Subtract mean.
48 | src = src - src.mean(0, keepdim=True) 49 | dst = dst - dst.mean(0, keepdim=True) 50 | 51 | inner_prod = torch.mm(src.t(), dst) 52 | src_inv_norm = src.pow(2).sum(0).add_(1e-10).rsqrt().view(-1, 1) 53 | dst_inv_norm = dst.pow(2).sum(0).add_(1e-10).rsqrt().view(1, -1) 54 | 55 | return inner_prod * src_inv_norm * dst_inv_norm 56 | 57 | def acts2corrMats(hidden_t, hidden_s): 58 | # Match response 59 | ''' Output correlation matrix for each layer ''' 60 | corrs = [] 61 | for t, s in zip(hidden_t, hidden_s): 62 | corr = act2corrMat(t.data, s.data) 63 | corrs.append(corr) 64 | return corrs 65 | 66 | def acts2corrIndices(hidden_t, hidden_s): 67 | # Match response 68 | ''' Output correlation indices for each layer ''' 69 | corrs = [] 70 | for t, s in zip(hidden_t, hidden_s): 71 | corr = act2corrMat(t.data, s.data) 72 | corrs.append(corrMat2corrIdx(corr)) 73 | return corrs 74 | 75 | ''' 76 | w_t = getattr(teacher, "w%d" % (i + 1)).weight 77 | w_s = getattr(student, "w%d" % (i + 1)).weight 78 | w_teacher=w_t[:,k], w_student=w_s[:, s_idx] 79 | ''' 80 | 81 | def compareCorrIndices(init_corrs, final_corrs): 82 | res = [] 83 | for k, (init_corr, final_corr) in enumerate(zip(init_corrs, final_corrs)): 84 | # For each layer 85 | # print("Layer %d" % k) 86 | res_per_layer = [] 87 | 88 | for j, (init_node, final_node) in enumerate(zip(init_corr, final_corr)): 89 | # For each node 90 | ranks = dict() 91 | max_init_score = -1000 92 | for node_rank, node_info in enumerate(init_node): 93 | node_id = node_info["s_idx"] 94 | score = node_info["score"] 95 | ranks[node_id] = dict(rank=node_rank, score=score) 96 | max_init_score = max(max_init_score, score) 97 | 98 | s_score = [] 99 | s_idx = [] 100 | for node_info in final_node: 101 | node_id = node_info["s_idx"] 102 | if node_id in ranks: 103 | rank = ranks[node_id]["rank"] 104 | else: 105 | rank = "-" 106 | s_score.append(node_info["score"]) 107 | s_idx.append((node_id, str(rank))) 108 | # "%2d [%s]" % (node_id, str(rank))) 109 | # 
print("T[%d]: [init_student_max=%.4f] %s | idx: %s | min_rank: %d" % (j, max_init_corr, ",".join(s_val), ", ".join(s_idx), min_rank)) 110 | res_per_layer.append(dict(s_score=s_score, s_idx=s_idx, max_init_score=max_init_score)) 111 | res.append(res_per_layer) 112 | 113 | return res 114 | 115 | -------------------------------------------------------------------------------- /student_specialization/vis_corrs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def get_stat(w): 5 | # return f"min: {w.min()}, max: {w.max()}, mean: {w.mean()}, std: {w.std()}" 6 | if isinstance(w, list): 7 | w = np.array(w) 8 | elif isinstance(w, torch.Tensor): 9 | w = w.cpu().numpy() 10 | return f"len: {w.shape}, min/max: {np.min(w):#.3f}/{np.max(w):#.3f}, mean: {np.mean(w):#.3f}" 11 | 12 | def get_corrs(corrs, active_nodes=None, first_n=5, cnt_thres=0.9, details=False): 13 | summary = "" 14 | for k, corr_per_layer in enumerate(corrs): 15 | score = [] 16 | cnts = [] 17 | for kk, corr_per_node in enumerate(corr_per_layer): 18 | if active_nodes is not None and kk not in active_nodes[k]: 19 | continue 20 | 21 | # Get best score for each teacher 22 | if isinstance(corr_per_node, list): 23 | s = [ c["score"] for c in corr_per_node ] 24 | else: 25 | s = corr_per_node["s_score"] 26 | score.append(s[0]) 27 | if cnt_thres is not None: 28 | cnt = sum([ ss >= cnt_thres for ss in s ]) 29 | cnts.append(cnt) 30 | 31 | summary += f"L{k}: {get_stat(score)}" 32 | if cnt_thres is not None: 33 | summary += f", MatchCnt[>={cnt_thres}]: {get_stat(cnts)}" 34 | summary += "\n" 35 | 36 | output = "" 37 | output += f"Corrs Summary:\n{summary}" 38 | 39 | if details: 40 | output += "\n" 41 | for k, corr_per_layer in enumerate(corrs): 42 | # For each layer 43 | output += "Layer %d\n" % k 44 | for j, corr_per_node in enumerate(corr_per_layer): 45 | s_score = corr_per_node["s_score"][:first_n] 46 | s_idx = 
corr_per_node["s_idx"][:first_n] 47 | 48 | s_score_str = ",".join(["%.4f" % v for v in s_score]) 49 | s_idx_str = ",".join(["%2d [%s]" % (node_id, rank) for node_id, rank in s_idx]) 50 | # import pdb 51 | # pdb.set_trace() 52 | 53 | min_rank = min([ int(rank) for node_id, rank in s_idx if rank != "-" ], default=-1) 54 | output += "T[%d]: [init_best_s=%.4f] %s | idx: %s | min_rank: %d\n" % (j, corr_per_node["max_init_score"], s_score_str, s_idx_str, min_rank) 55 | # print("T[%d]: [init_best_s=%.4f] %s | idx: %s " % (j, corr_per_node["max_init_score"], s_score_str, s_idx_str)) 56 | 57 | return output 58 | 59 | -------------------------------------------------------------------------------- /student_specialization/visualization/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import os 4 | import sys 5 | import glob 6 | import yaml 7 | from copy import deepcopy 8 | 9 | def find_params(data, cond): 10 | for d in data: 11 | found = True 12 | for k, v in cond.items(): 13 | if d["args"][k] != v: 14 | found = False 15 | if found: 16 | return d 17 | return None 18 | 19 | def find_all_params(data, cond): 20 | all_d = [] 21 | for d in data: 22 | found = True 23 | for k, v in cond.items(): 24 | if d["args"][k] != v: 25 | found = False 26 | if found: 27 | all_d.append(d) 28 | return all_d 29 | 30 | def load_stats(folder, stats_filename="stats.pickle"): 31 | print(f"Load stats from {folder}") 32 | filename = os.path.join(folder, stats_filename) 33 | if os.path.exists(filename): 34 | config_filename = os.path.join(folder, "config.yaml") 35 | if not os.path.exists(config_filename): 36 | config_filename = (os.path.join(folder, ".hydra/config.yaml")) 37 | else: 38 | return None 39 | print(f"Config file: {config_filename}, stats file: {stats_filename}") 40 | args = yaml.safe_load(open(config_filename, "r")) 41 | stats = torch.load(filename) 42 | return dict(args=args,stats=stats,path=folder) 43 | else: 44 | print(f"The {filename}
doesn't exist") 45 | return None 46 | 47 | 48 | def load_data(root, stats_filename="stats.pickle"): 49 | data = [] 50 | total = 0 51 | folders = sorted(glob.glob(os.path.join(root, "*"))) 52 | last_prefix = None 53 | 54 | for folder in folders: 55 | path, folder_name = os.path.split(folder) 56 | items = folder_name.split("_") 57 | if len(items) > 1: 58 | prefix, job_id = folder_name.rsplit("_", 1) 59 | if prefix == last_prefix: 60 | continue 61 | else: 62 | job_id = items[0] 63 | prefix = job_id 64 | 65 | stats = load_stats(folder, stats_filename=stats_filename) 66 | if stats is not None: 67 | print(f"{len(data)}: {folder}") 68 | data.append(stats) 69 | last_prefix = prefix 70 | 71 | return data 72 | 73 | def convert_stats_to_summary(folder): 74 | d = load_stats(folder) 75 | if d is None: 76 | print(f"Cannot find stats in {folder}, skip.") 77 | return 78 | 79 | # Keep only the first and the last stats of the first run (index 0). 80 | new_stats = [ d["stats"][0][0], d["stats"][0][-1] ] 81 | torch.save(new_stats, os.path.join(folder, "summary.pth")) 82 | 83 | -------------------------------------------------------------------------------- /student_specialization/visualization/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib.image as mpimg 4 | 5 | import os 6 | import sys 7 | import pandas as pd 8 | 9 | import argparse 10 | 11 | import re 12 | import torch 13 | import json 14 | import math 15 | 16 | from utils import find_params, load_data 17 | 18 | def figure_l_shape(data): 19 | # Figure 1. L-shape with 2-layer network.
20 | multis = (1, 2, 5, 10) 21 | decays = (0, 1, 2) 22 | num_teacher = 10 23 | 24 | plt.figure(figsize=(15, 10)) 25 | 26 | counter = 1 27 | for decay in decays: 28 | for multi in multis: 29 | plt.subplot(3, len(multis), counter) 30 | counter += 1 31 | 32 | d = find_params(data, dict(multi=multi, teacher_strength_decay=decay, m=num_teacher)) 33 | 34 | # print("multi: ", d["args"]["multi"]) 35 | # print("decay: ", d["args"]["teacher_strength_decay"]) 36 | 37 | losses = [] 38 | 39 | for seed, stats in d["stats"].items(): 40 | s = stats[-1] 41 | 42 | corrs = s["corr_train"] 43 | 44 | norms = s["W2_s"].norm(dim=1) 45 | norms = norms[:-1] 46 | 47 | plt.scatter(corrs.max(dim=0)[0], norms) 48 | losses.append(s["eval_loss"]) 49 | 50 | if decay == 2: 51 | plt.xlabel('Correlation to the best correlated teacher') 52 | 53 | if multi == 1: 54 | plt.ylabel('norm of fan-out weights') 55 | # plt.title(f"{multi}x, loss={sum(losses) / len(losses):#.2f}") 56 | 57 | if decay == 0: 58 | plt.title(f"{multi}x") 59 | plt.axis([0.0, 1.1, -0.1, 2.5]) 60 | 61 | plt.savefig(f"l-shape-m{num_teacher}.pdf") 62 | # plt.show() 63 | 64 | def figure_success_rate(data): 65 | multis = (1, 2, 5, 10) 66 | thres = 0.95 67 | num_teacher = 20 68 | 69 | plt.figure(figsize=(12, 2.5)) 70 | # plt.figure() 71 | 72 | counter = 0 73 | 74 | # fig, ax = plt.subplots(figsize=(6, 5)) 75 | for decay in (0.5, 1, 1.5, 2, 2.5): 76 | ax = plt.subplot(1, 5, counter + 1) 77 | counter += 1 78 | for iter, style in zip((5, -1), (':', '-')): 79 | bars = [] 80 | ind = torch.FloatTensor(list(range(num_teacher))) 81 | # width = 0.15 82 | colors = ['r', 'g','b','c'] 83 | for i, multi in enumerate(multis): 84 | #plt.subplot(1, len(multis), counter) 85 | #counter += 1 86 | 87 | d = find_params(data, dict(multi=multi, teacher_strength_decay=decay, m=num_teacher)) 88 | 89 | losses = [] 90 | 91 | counts = None 92 | for seed, stats in d["stats"].items(): 93 | s = stats[iter] 94 | v = (s["counts_eval"][thres] > 0).float() 95 | if counts is 
None: 96 | counts = v 97 | else: 98 | counts += v 99 | 100 | losses.append(s["eval_loss"]) 101 | 102 | counts /= len(d["stats"]) 103 | plt.plot(ind.numpy(), counts.numpy(), colors[i], label=f"{multi}x" if iter == -1 else None, linestyle=style) 104 | # plt.scatter(ind.numpy(), counts.numpy(), color=colors[i]) 105 | 106 | # plt.title(f"multi={multi}, loss={sum(losses) / len(losses):#.5f}") 107 | # plt.title(f"iter={iter}") 108 | 109 | plt.xlabel('Teacher idx') 110 | plt.title(f"$p={decay}$") 111 | plt.axis([-1, num_teacher, 0, 1.1]) 112 | if counter == 1: 113 | plt.ylabel('Successful Recovery Rate') 114 | plt.legend() 115 | 116 | ticks = ind[::4].numpy() 117 | 118 | ax.set_xticks(ticks) 119 | ax.set_xticklabels([ str(int(i)) for i in ticks ]) 120 | if counter > 1: 121 | ax.set_yticklabels([]) 122 | 123 | # ax.legend(bars, [ f"{multi}x" for multi in multis ]) 124 | 125 | plt.tight_layout() 126 | 127 | plt.savefig(f"rate_drop_m{num_teacher}_thres{thres}.pdf") 128 | # plt.show() 129 | 130 | def figure_loss(data): 131 | multis = (1, 2, 5, 10) 132 | decays = (0, 0.5, 1, 1.5, 2, 2.5) 133 | num_teacher = 20 134 | 135 | plt.figure(figsize=(15, 7)) 136 | # plt.figure() 137 | 138 | counter = 1 139 | 140 | # fig, ax = plt.subplots(figsize=(6, 5)) 141 | for decay in decays: 142 | ax = plt.subplot(2, len(decays) // 2, counter) 143 | counter += 1 144 | for i, multi in enumerate(multis): 145 | d = find_params(data, dict(multi=multi, teacher_strength_decay=decay, m=num_teacher)) 146 | losses = None 147 | for j, (seed, stats) in enumerate(d["stats"].items()): 148 | v = torch.DoubleTensor([ math.log(s["eval_loss"]) / math.log(10.0) for s in stats ]) 149 | if losses is None: 150 | losses = torch.DoubleTensor(len(stats), len(d["stats"])) 151 | losses[:, j] = v 152 | 153 | loss = losses.mean(dim=1) 154 | loss_std = losses.std(dim=1) 155 | p = plt.plot(loss.numpy(), label=f"{multi}x") 156 | plt.fill_between(list(range(loss.size(0))), (loss - loss_std).numpy(), (loss + loss_std).numpy(),
color=p[0].get_color(), alpha=0.2) 157 | 158 | if counter >= 5: 159 | plt.xlabel('Epoch') 160 | 161 | if counter == 2 or counter == 5: 162 | plt.ylabel('Evaluation log loss') 163 | else: 164 | ax.set_yticklabels([]) 165 | 166 | plt.title(f"$p={decay}$") 167 | plt.axis([0, 100, -8, 0]) 168 | 169 | if counter == 2: 170 | plt.legend() 171 | 172 | plt.savefig(f"convergence_m{num_teacher}.pdf") 173 | # plt.show() 174 | 175 | if __name__ == "__main__": 176 | parser = argparse.ArgumentParser(description='') 177 | parser.add_argument('root', type=str, help="root directory") 178 | 179 | args = parser.parse_args() 180 | 181 | data = load_data(args.root) 182 | figure_l_shape(data) 183 | figure_success_rate(data) 184 | figure_loss(data) 185 | -------------------------------------------------------------------------------- /student_specialization/visualization/visualize_multi.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import argparse 5 | 6 | def plot_multilayer_l_shape(stats, epoch_split=5, save_file=None, epoch_till=None, beta_range=None): 7 | s = stats[0][-1] 8 | num_layer = len(s["train_corrs"]) 9 | 10 | total_epoch = len(stats[0]) - 1 11 | if epoch_till is not None and epoch_till < total_epoch: 12 | total_epoch = epoch_till 13 | 14 | epochs = [ int(i * total_epoch / (epoch_split - 1)) for i in range(epoch_split) ] 15 | 16 | plt.figure(figsize=(20, 10)) 17 | count = 0 18 | 19 | for layer in range(num_layer - 1, -1, -1): 20 | print(f"{layer}: student/teacher: {s['train_corrs'][layer].size()}") 21 | 22 | for it in epochs: 23 | count += 1 24 | ax = plt.subplot(num_layer, len(epochs), count) 25 | 26 | s = stats[0][it] 27 | train_corrs = s["train_corrs"][layer] 28 | alphas = s["train_betas_s"][layer][:-1,:-1] 29 | betas = s["train_betas"][layer][:-1, :-1].diag() 30 | 31 | student_usefulness, best_matched_teacher_indices = train_corrs.max(dim=1) 32
| plt.scatter(student_usefulness.numpy(), betas.sqrt().numpy(), alpha=0.2) 33 | 34 | if it == 0: 35 | plt.ylabel("$\\sqrt{\\mathbb{E}_{\\mathbf{x}}\\left[\\beta_{kk}(\\mathbf{x})\\right]}$") 36 | else: 37 | if beta_range is not None: 38 | ax.set_yticklabels([]) 39 | 40 | if layer == 0: 41 | plt.xlabel("Max correlation among teachers") 42 | 43 | plt.axis([-0.05, 1.05, -0.001, beta_range]) 44 | 45 | if layer == num_layer - 1: 46 | plt.title(f"Epoch {it}") 47 | # plt.legend() 48 | 49 | if save_file is not None: 50 | plt.savefig(save_file) 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser(description='') 54 | parser.add_argument('root', type=str, help="root directory") 55 | parser.add_argument("--save_file", type=str, default="multilayer_l_shape.pdf") 56 | 57 | args = parser.parse_args() 58 | 59 | stats = load_stats(args.root)["stats"] 60 | plot_multilayer_l_shape(stats, save_file=args.save_file) 61 | 62 | --------------------------------------------------------------------------------
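`figure_success_rate` plots, for each teacher node, the fraction of seeds in which that teacher is recovered, i.e. at least one student node reaches the correlation threshold (it reads this from the precomputed `counts_eval[thres]`). The sketch below recomputes the same quantity directly from correlation matrices. It is a hypothetical helper, not part of the repository, and it assumes each run supplies a `(num_teacher, num_student)` correlation array; the orientation of the stored `corr_train`/`train_corrs` tensors differs between scripts, so transpose as needed.

```python
import numpy as np


def recovery_rate(corrs_per_seed, thres=0.95):
    """Per-teacher fraction of seeds in which some student reaches `thres`.

    Hypothetical helper: assumes every element of `corrs_per_seed` is an
    array-like of shape (num_teacher, num_student) holding teacher-student
    correlations for one seed.
    """
    hits = []
    for corr in corrs_per_seed:
        corr = np.asarray(corr, dtype=float)
        # A teacher counts as recovered if any student in its row matches it.
        hits.append((corr >= thres).any(axis=1).astype(float))
    # Average the 0/1 indicators over seeds, giving one rate per teacher.
    return np.mean(hits, axis=0)


if __name__ == "__main__":
    seed0 = [[0.99, 0.10], [0.20, 0.50]]
    seed1 = [[0.30, 0.96], [0.97, 0.10]]
    print(recovery_rate([seed0, seed1]))
```

With the two toy seeds above, teacher 0 is recovered in both runs and teacher 1 only in the second, so the rates are 1.0 and 0.5. Averaging 0/1 indicators over seeds mirrors how the plotting code accumulates `(counts > 0).float()` and divides by `len(d["stats"])`.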