├── README.md ├── 1-guillotine ├── True_train.csv ├── False_test.csv ├── False_train.csv ├── True_test.csv ├── README.md ├── guillotine.py └── reader.ipynb ├── 2-pretext ├── README.md ├── guillotine.py └── augmentations.py └── .gitignore /README.md: -------------------------------------------------------------------------------- 1 | # SSL-course 2 | Course to learn the basics of self supervised learning 3 | -------------------------------------------------------------------------------- /1-guillotine/True_train.csv: -------------------------------------------------------------------------------- 1 | ,value_orientation,value_scale,value_shape,value_x_position,value_y_position 2 | value_orientation,0.6875,0.2455,0.7739,0.032,0.0329 3 | value_scale,0.9145,0.0395,0.3324,0.017,0.0182 4 | value_shape,0.8819,0.0515,0.2367,0.016,0.0167 5 | value_x_position,0.854,0.1666,0.735,0.0103,0.0131 6 | value_y_position,0.9376,0.164,0.7449,0.0138,0.0104 7 | -------------------------------------------------------------------------------- /1-guillotine/False_test.csv: -------------------------------------------------------------------------------- 1 | ,value_orientation,value_scale,value_shape,value_x_position,value_y_position 2 | value_orientation,0.6638,0.1622,0.7408,0.0139,0.0162 3 | value_scale,0.9508,0.1067,0.5402,0.0154,0.0161 4 | value_shape,0.8724,0.0787,0.1236,0.0135,0.0132 5 | value_x_position,0.8271,0.15,0.738,0.0027,0.0088 6 | value_y_position,0.933,0.1585,0.7654,0.0091,0.0032 7 | -------------------------------------------------------------------------------- /1-guillotine/False_train.csv: -------------------------------------------------------------------------------- 1 | ,value_orientation,value_scale,value_shape,value_x_position,value_y_position 2 | value_orientation,0.6669,0.172,0.7471,0.0133,0.0166 3 | value_scale,0.9343,0.0238,0.2382,0.0149,0.0157 4 | value_shape,0.8741,0.041,0.1005,0.0142,0.0134 5 | value_x_position,0.8326,0.1576,0.7476,0.0029,0.0088 6 | 
value_y_position,0.9331,0.1664,0.7728,0.0087,0.0031 7 | -------------------------------------------------------------------------------- /1-guillotine/True_test.csv: -------------------------------------------------------------------------------- 1 | ,value_orientation,value_scale,value_shape,value_x_position,value_y_position 2 | value_orientation,1.1126,0.2899,0.7979,0.0381,0.0394 3 | value_scale,0.9103,0.0487,0.3986,0.0158,0.0155 4 | value_shape,0.8817,0.0412,0.2107,0.0142,0.0146 5 | value_x_position,0.8443,0.157,0.7173,0.0104,0.0125 6 | value_y_position,0.9362,0.1532,0.7279,0.014,0.0106 7 | -------------------------------------------------------------------------------- /2-pretext/README.md: -------------------------------------------------------------------------------- 1 | # Equivariant Pretext-Task 2 | 3 | - Data augmentations that also return the applied transformation parameters are implemented in the [augmentations.py](./augmentations.py) file. 4 | 5 | - To quickly visualize the augmentations and their parameters, you can run the [debug.ipynb](./debug.ipynb) notebook. 6 | 7 | - Use this core codebase to experiment with your own pretext-task ideas! -------------------------------------------------------------------------------- /1-guillotine/README.md: -------------------------------------------------------------------------------- 1 | # Guillotine regularization 2 | 3 | ## Guillotine Regularization: Why removing layers is needed to improve generalization in Self-Supervised Learning ([link](https://arxiv.org/pdf/2206.13378)) 4 | 5 | One unexpected technique that emerged in recent years consists in training a Deep Network (DN) with a Self-Supervised Learning (SSL) method, and using this network on downstream tasks but with its last few projector layers entirely removed. This trick of throwing away the projector is actually critical for SSL methods to display competitive performances on ImageNet for which more than 30 percentage points can be gained that way.
This is a little vexing, as one would hope that the network layer at which invariance is explicitly enforced by the SSL criterion during training (the last projector layer) should be the one to use for best generalization performance downstream. But it seems not to be, and this study sheds some light on why. This trick, which we name Guillotine Regularization (GR), is in fact a generically applicable method that has been used to improve generalization performance in transfer learning scenarios. In this work, we identify the underlying reasons behind its success and show that the optimal layer to use might change significantly depending on the training setup, the data or the downstream task. Lastly, we give some insights on how to reduce the need for a projector in SSL by aligning the pretext SSL task and the downstream task. 6 | 7 | 8 | ## Toy experiment on [dsprites](https://github.com/google-deepmind/dsprites-dataset) dataset 9 | 10 | - Run the training pipeline using the [guillotine.py](./guillotine.py) Python file 11 | - This will generate 4 .csv files with the results (the script assumes 1 GPU is available) 12 | - Read the saved results using the [reader.ipynb](./reader.ipynb) Jupyter notebook 13 | - This will produce the two heatmaps comparing per-task performance with and without guillotine -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | lightning_logs/ 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | *.gz 7 | cifar-10-batches-py/ 8 | cifar-100-batches-py/ 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller
33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | -------------------------------------------------------------------------------- /2-pretext/guillotine.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch.nn import MSELoss, CrossEntropyLoss 4 | from torch.optim import AdamW 5 | from torch.utils.data import DataLoader 6 | import torch.nn as nn 7 | import datasets 8 | from torchvision.transforms import v2 9 | import pandas as pd 10 | from pathlib import Path 11 | from pytorch_lightning.loggers import CSVLogger 12 | import augmentations 13 | from torchvision.datasets import CIFAR10 14 | import torchmetrics 15 | 16 | pl.seed_everything(42) 17 | 18 | train_transform = v2.Compose( 19 | [ 20 | augmentations.AddParams(), 21 | augmentations.RandomHorizontalFlip(), 22 | augmentations.ColorJitter(0.4, 0.4, 0.4, 0.2), 23 | augmentations.RandomGrayscale(0.5), 24 | augmentations.RandomRotation(30), 25 | augmentations.ToTensor(), 26 | ] 27 | ) 28 | test_transform = augmentations.ToTensor() 29 | 30 | 31 | def get_CIFAR10_dataset(): 32 | dataset = {} 33 | dataset["train"] = CIFAR10( 34 | root="./", download=True, train=True, transform=train_transform 35 | ) 36 | dataset["test"] = CIFAR10( 37 | root="./", download=True, train=False, transform=test_transform 38 | ) 39 | return dataset 40 | 41 | 42 | class MyModel(pl.LightningModule): 43 | def __init__(self, target: str, guillotine: bool = False, dataset=None): 44 | """ 45 | target: (str) the name of the dataset target to use to train the backbone 46 | guillotine: (bool) whether to add a projector and use guillotine (or not) 47 | """ 48 | super().__init__() 49 | self.target = target 50 | self.dataset = dataset 51 | self.fc = nn.Sequential( 52 | nn.Conv2d(3, 32, kernel_size=5, bias=False), 53 | nn.BatchNorm2d(32), 54 | nn.ReLU(), 55 | nn.Conv2d(32, 64, kernel_size=3, bias=False, stride=2), 56 | nn.BatchNorm2d(64), 57 | nn.ReLU(), 58 | nn.Conv2d(64, 128, 
kernel_size=3, bias=False, stride=2), 59 | nn.BatchNorm2d(128), 60 | nn.ReLU(), 61 | nn.AdaptiveAvgPool2d((4, 4)), 62 | nn.Flatten(), 63 | ) 64 | if guillotine: 65 | self.projector = nn.Sequential( 66 | nn.Linear(4 * 4 * 128, 4 * 4 * 128, bias=False), 67 | nn.BatchNorm1d(4 * 4 * 128), 68 | nn.ReLU(), 69 | nn.Linear(4 * 4 * 128, 4 * 4 * 128, bias=False), 70 | nn.BatchNorm1d(4 * 4 * 128), 71 | nn.ReLU(), 72 | ) 73 | else: 74 | self.projector = nn.Identity() 75 | self.sup_probe = nn.Sequential(nn.Dropout1d(0.2), nn.Linear(4 * 4 * 128, 10)) 76 | self.probe = nn.Linear(4 * 4 * 128, 10) 77 | self.criterion = CrossEntropyLoss() 78 | self.evaluate = torchmetrics.classification.MulticlassAccuracy( 79 | num_classes=10, average=None 80 | ) 81 | 82 | def forward(self, inputs_id): 83 | outputs = self.fc(inputs_id) 84 | preds = self.probe(outputs.detach()) 85 | return self.sup_probe(self.projector(outputs)), preds 86 | 87 | def get_losses(self, batch): 88 | print(batch) 89 | if self.training: 90 | input_ids = batch[0][0] 91 | else: 92 | input_ids = batch[0] 93 | labels = batch[1] 94 | 95 | outputs, preds = self(input_ids) 96 | probe_losses = self.criterion(preds, labels) 97 | sup_loss = self.criterion(outputs, labels) 98 | return probe_losses, sup_loss, range(10) 99 | 100 | def training_step(self, batch, batch_idx): 101 | probe_losses, sup_loss, label_names = self.get_losses(batch) 102 | loss = probe_losses + sup_loss 103 | log_dict = {} 104 | # log_dict = {f"{name}": p for name, p in zip(label_names, probe_losses.tolist())} 105 | log_dict["target"] = sup_loss.item() 106 | log_dict["epoch"] = self.current_epoch 107 | self.loggers[0].log_metrics(log_dict) 108 | return loss 109 | 110 | def validation_step(self, batch, batch_idx): 111 | _, preds = self(batch[0]) 112 | label_names = range(10) 113 | log_dict = { 114 | f"{name}": p for name, p in zip(label_names, self.evaluate(preds, batch[1])) 115 | } 116 | log_dict["epoch"] = self.current_epoch 117 | 
self.loggers[1].log_metrics(log_dict) 118 | 119 | def configure_optimizers(self): 120 | optimizer = AdamW(self.parameters(), weight_decay=1e-4, lr=1e-4) 121 | return optimizer 122 | 123 | def optimizer_zero_grad(self, epoch, batch_idx, optimizer): 124 | optimizer.zero_grad(set_to_none=True) 125 | 126 | def train_dataloader(self): 127 | return DataLoader( 128 | self.dataset["train"], 129 | shuffle=True, 130 | drop_last=True, 131 | batch_size=256, 132 | num_workers=10, 133 | persistent_workers=True, 134 | ) 135 | 136 | def val_dataloader(self): 137 | return DataLoader(self.dataset["test"], batch_size=512, num_workers=10) 138 | 139 | 140 | if __name__ == "__main__": 141 | 142 | for guillotine in [False, True]: 143 | target = "rotation" 144 | dataset = get_CIFAR10_dataset() 145 | train_results = [] 146 | test_results = [] 147 | model = MyModel(target=target, guillotine=guillotine, dataset=dataset) 148 | model = torch.compile(model) 149 | train_logger = CSVLogger(save_dir="lightning_logs", name="train") 150 | val_logger = CSVLogger(save_dir="lightning_logs", name="val") 151 | trainer = pl.Trainer( 152 | max_epochs=5, 153 | accelerator="gpu", 154 | devices=1, 155 | precision="16-mixed", 156 | logger=[train_logger, val_logger], 157 | enable_checkpointing=False, 158 | ) 159 | trainer.fit(model) 160 | train_metrics = pd.read_csv(Path(train_logger.log_dir) / "metrics.csv") 161 | print(train_metrics.shape) 162 | train_metrics = train_metrics.set_index("step").groupby("epoch").mean() 163 | train_results.append(train_metrics.iloc[[-1]]) 164 | val_metrics = pd.read_csv(Path(val_logger.log_dir) / "metrics.csv") 165 | print(val_metrics.shape) 166 | val_metrics = val_metrics.set_index("step").groupby("epoch").mean() 167 | test_results.append(val_metrics.iloc[[-1]]) 168 | 169 | train_results = pd.concat(train_results, axis=0) 170 | train_results = train_results[["target"]]  # TARGETS is not defined in this script; keep the logged supervised loss 171 | train_results.index = [target] 172 | 173 | test_results = pd.concat(test_results, axis=0)
174 | test_results = test_results[[str(c) for c in range(10)]]  # per-class accuracy columns logged in validation_step 175 | test_results.index = [target] 176 | 177 | train_results.round(4).to_csv(f"./{guillotine}_train.csv") 178 | test_results.round(4).to_csv(f"./{guillotine}_test.csv") 179 | -------------------------------------------------------------------------------- /1-guillotine/guillotine.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch.nn import MSELoss 4 | from torch.optim import AdamW 5 | from torch.utils.data import DataLoader 6 | import torch.nn as nn 7 | import datasets 8 | from torchvision.transforms import v2 9 | import pandas as pd 10 | from pathlib import Path 11 | from pytorch_lightning.loggers import CSVLogger 12 | 13 | pl.seed_everything(42) 14 | 15 | TARGETS = { 16 | "dsprite": [ 17 | "value_orientation", 18 | "value_scale", 19 | "value_shape", 20 | "value_x_position", 21 | "value_y_position", 22 | ] 23 | } 24 | 25 | train_transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]) 26 | test_transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]) 27 | 28 | 29 | def train_transforms(examples): 30 | examples["image"] = [ 31 | train_transform(img.convert("RGB")) for img in examples["image"] 32 | ] 33 | return examples 34 | 35 | 36 | def test_transforms(examples): 37 | examples["image"] = [ 38 | test_transform(img.convert("RGB")) for img in examples["image"] 39 | ] 40 | return examples 41 | 42 | 43 | def get_dsprites_dataset(): 44 | dataset = datasets.load_dataset("eurecom-ds/dsprites").remove_columns( 45 | [ 46 | "label_orientation", 47 | "label_scale", 48 | "label_shape", 49 | "label_x_position", 50 | "label_y_position", 51 | ] 52 | ) 53 | dataset = dataset["train"].train_test_split(test_size=0.5) 54 | for target in TARGETS["dsprite"]: 55 | mean = dataset["train"].with_format("pandas")[target].mean() 56 | std = 
dataset["train"].with_format("pandas")[target].std() 57 | dataset["train"] = dataset["train"].map(lambda row:{target: (row[target] - mean)/(1e-5+std)}, batched=True, batch_size=2048) 58 | dataset["test"] = dataset["test"].map(lambda row:{target: (row[target] - mean)/(1e-5+std)}, batched=True, batch_size=2048) 59 | dataset["train"] = dataset["train"].with_transform(train_transforms) 60 | dataset["test"] = dataset["test"].with_transform(test_transforms) 61 | return dataset 62 | 63 | class MyModel(pl.LightningModule): 64 | def __init__(self, target: str, guillotine: bool = False, dataset=None): 65 | """ 66 | target: (str) the name of the dataset target to use to train the backbone 67 | guillotine: (bool) whether to add a projector and use guillotine (or not) 68 | """ 69 | super().__init__() 70 | self.target = target 71 | self.dataset = dataset 72 | self.fc = nn.Sequential( 73 | nn.Conv2d(3, 32, kernel_size=5, bias=False), 74 | nn.BatchNorm2d(32), 75 | nn.ReLU(), 76 | nn.Conv2d(32, 64, kernel_size=3, bias=False, stride=2), 77 | nn.BatchNorm2d(64), 78 | nn.ReLU(), 79 | nn.Conv2d(64, 128, kernel_size=3, bias=False, stride=2), 80 | nn.BatchNorm2d(128), 81 | nn.ReLU(), 82 | nn.AdaptiveAvgPool2d((4,4)), 83 | nn.Flatten(), 84 | ) 85 | if guillotine: 86 | self.projector = nn.Sequential( 87 | nn.Linear(4*4*128, 4*4*128, bias=False), 88 | nn.BatchNorm1d(4*4*128), 89 | nn.ReLU(), 90 | nn.Linear(4*4*128, 4*4*128, bias=False), 91 | nn.BatchNorm1d(4*4*128), 92 | nn.ReLU(), 93 | ) 94 | else: 95 | self.projector = nn.Identity() 96 | self.sup_probe = nn.Sequential(nn.Dropout1d(0.2), nn.Linear(4*4*128, 1)) 97 | self.probe = nn.Linear(4*4*128, 5) 98 | self.criterion = MSELoss(reduction="none") 99 | 100 | def forward(self, inputs_id): 101 | outputs = self.fc(inputs_id) 102 | preds = self.probe(outputs.detach()) 103 | return self.sup_probe(self.projector(outputs)), preds 104 | 105 | 106 | 107 | def get_losses(self, batch): 108 | input_ids = batch["image"] 109 | label_names = [name for 
name in batch.keys() if name != "image"] 110 | labels = torch.stack([batch[name] for name in label_names], 1) 111 | 112 | outputs, preds = self(input_ids) 113 | probe_losses = self.criterion(preds, labels.float()).mean(0) 114 | sup_loss = self.criterion( 115 | outputs.squeeze(), batch[self.target].squeeze().float() 116 | ).mean() 117 | return probe_losses, sup_loss, label_names 118 | 119 | def training_step(self, batch, batch_idx): 120 | probe_losses, sup_loss, label_names = self.get_losses(batch) 121 | loss = probe_losses.mean() + sup_loss 122 | log_dict = {f"{name}": p for name, p in zip(label_names, probe_losses.tolist())} 123 | log_dict["target"] = sup_loss.item() 124 | log_dict["epoch"] = self.current_epoch 125 | self.loggers[0].log_metrics(log_dict) 126 | return loss 127 | 128 | def validation_step(self, batch, batch_idx): 129 | probe_losses, _, label_names = self.get_losses(batch) 130 | log_dict = {f"{name}": p for name, p in zip(label_names, probe_losses.tolist())} 131 | log_dict["epoch"] = self.current_epoch 132 | self.loggers[1].log_metrics(log_dict) 133 | 134 | def configure_optimizers(self): 135 | optimizer = AdamW(self.parameters(), weight_decay=1e-4, lr=1e-4) 136 | return optimizer 137 | 138 | def optimizer_zero_grad(self, epoch, batch_idx, optimizer): 139 | optimizer.zero_grad(set_to_none=True) 140 | 141 | def train_dataloader(self): 142 | return DataLoader( 143 | self.dataset["train"], 144 | shuffle=True, 145 | drop_last=True, 146 | batch_size=256, 147 | num_workers=10, 148 | persistent_workers=True, 149 | ) 150 | 151 | def val_dataloader(self): 152 | return DataLoader(self.dataset["test"], batch_size=512, num_workers=10) 153 | 154 | 155 | if __name__ == "__main__": 156 | dataset = get_dsprites_dataset() 157 | 158 | for guillotine in [False, True]: 159 | train_results = [] 160 | test_results = [] 161 | for target in TARGETS["dsprite"]: 162 | model = MyModel(target=target, guillotine=guillotine, dataset=dataset) 163 | model = torch.compile(model) 164 | 
train_logger = CSVLogger(save_dir="lightning_logs", name="train") 165 | val_logger = CSVLogger(save_dir="lightning_logs", name="val") 166 | trainer = pl.Trainer( 167 | max_epochs=5, 168 | accelerator="gpu", 169 | devices=1, 170 | precision="16-mixed", 171 | logger=[train_logger, val_logger], 172 | enable_checkpointing=False, 173 | ) 174 | trainer.fit(model) 175 | train_metrics = pd.read_csv(Path(train_logger.log_dir) / "metrics.csv") 176 | print(train_metrics.shape) 177 | train_metrics = train_metrics.set_index("step").groupby("epoch").mean() 178 | train_results.append(train_metrics.iloc[[-1]]) 179 | val_metrics = pd.read_csv(Path(val_logger.log_dir) / "metrics.csv") 180 | print(val_metrics.shape) 181 | val_metrics = val_metrics.set_index("step").groupby("epoch").mean() 182 | test_results.append(val_metrics.iloc[[-1]]) 183 | 184 | train_results = pd.concat(train_results, axis=0) 185 | train_results = train_results[TARGETS["dsprite"]] 186 | train_results.index = TARGETS["dsprite"] 187 | 188 | test_results = pd.concat(test_results, axis=0) 189 | test_results = test_results[TARGETS["dsprite"]] 190 | test_results.index = TARGETS["dsprite"] 191 | 192 | train_results.round(4).to_csv(f"./{guillotine}_train.csv") 193 | test_results.round(4).to_csv(f"./{guillotine}_test.csv") 194 | -------------------------------------------------------------------------------- /2-pretext/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numbers 4 | import torchvision 5 | from collections.abc import Sequence 6 | from torchvision.transforms.functional import ( 7 | _interpolation_modes_from_int, 8 | InterpolationMode, 9 | ) 10 | from typing import List, Optional, Tuple, Union 11 | import numpy as np 12 | 13 | class AddParams(torch.nn.Module): 14 | """Wrap the input image together with an empty list of data-augmentation (DA) parameters. 
15 | Use it as the first transform of a pipeline: the later transforms in this file 16 | take an ``(img, params)`` tuple and append the parameter they applied to ``params``. 17 | 18 | Args: 19 | None; this transform has no parameters and leaves the image untouched. 20 | 21 | Returns: 22 | tuple: ``(img, [])``, the unchanged image (PIL Image or Tensor) and 23 | an empty list. 24 | - Subsequent transforms return a new list with their parameter appended 25 | - e.g. RandomGrayscale appends 1 or 0, RandomRotation appends the sampled angle 26 | 27 | """ 28 | 29 | def forward(self, img): 30 | """ 31 | Args: 32 | img (PIL Image or Tensor): Input image, returned unchanged. 33 | 34 | Returns: 35 | tuple: the unchanged image and an 36 | empty list to stack DA parameters 37 | """ 38 | return img, [] 39 | 40 | def __repr__(self) -> str: 41 | return f"{self.__class__.__name__}()" 42 | 43 | 44 | class RandomGrayscale(torch.nn.Module): 45 | """Randomly convert image to grayscale with a probability of p (default 0.1). 46 | If the image is torch Tensor, it is expected 47 | to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions 48 | 49 | Args: 50 | p (float): probability that image should be converted to grayscale. 51 | 52 | Returns: 53 | PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged 54 | with probability (1-p). 55 | - If input image is 1 channel: grayscale version is 1 channel 56 | - If input image is 3 channel: grayscale version is 3 channel with r == g == b 57 | 58 | """ 59 | 60 | def __init__(self, p=0.1): 61 | super().__init__() 62 | self.p = p 63 | 64 | def forward(self, inputs): 65 | """ 66 | Args: 67 | inputs (tuple): ``(img, params)``, the image (PIL Image or Tensor) to be converted to grayscale and the list of parameters recorded so far. 68 | 69 | Returns: 70 | tuple: Randomly grayscaled image, and ``params`` extended with 1 if the image was converted and 0 otherwise. 
71 | """ 72 | img, params = inputs 73 | num_output_channels, _, _ = torchvision.transforms.functional.get_dimensions( 74 | img 75 | ) 76 | if torch.rand(1) < self.p: 77 | return torchvision.transforms.functional.rgb_to_grayscale( 78 | img, num_output_channels=num_output_channels 79 | ), params + [1] 80 | return img, params + [0] 81 | 82 | def __repr__(self) -> str: 83 | return f"{self.__class__.__name__}(p={self.p})" 84 | 85 | 86 | def _check_sequence_input(x, name, req_sizes): 87 | msg = ( 88 | req_sizes[0] if len(req_sizes) < 2 else " or ".join([str(s) for s in req_sizes]) 89 | ) 90 | if not isinstance(x, Sequence): 91 | raise TypeError(f"{name} should be a sequence of length {msg}.") 92 | if len(x) not in req_sizes: 93 | raise ValueError(f"{name} should be a sequence of length {msg}.") 94 | 95 | 96 | def _setup_angle(x, name, req_sizes=(2,)): 97 | if isinstance(x, numbers.Number): 98 | if x < 0: 99 | raise ValueError(f"If {name} is a single number, it must be positive.") 100 | x = [-x, x] 101 | else: 102 | _check_sequence_input(x, name, req_sizes) 103 | 104 | return [float(d) for d in x] 105 | 106 | 107 | class RandomRotation(torch.nn.Module): 108 | """Rotate the image by angle. 109 | If the image is torch Tensor, it is expected 110 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. 111 | 112 | Args: 113 | degrees (sequence or number): Range of degrees to select from. 114 | If degrees is a number instead of sequence like (min, max), the range of degrees 115 | will be (-degrees, +degrees). 116 | interpolation (InterpolationMode): Desired interpolation enum defined by 117 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. 118 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. 119 | The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. 120 | expand (bool, optional): Optional expansion flag. 
121 | If true, expands the output to make it large enough to hold the entire rotated image. 122 | If false or omitted, make the output image the same size as the input image. 123 | Note that the expand flag assumes rotation around the center and no translation. 124 | center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. 125 | Default is the center of the image. 126 | fill (sequence or number): Pixel fill value for the area outside the rotated 127 | image. Default is ``0``. If given a number, the value is used for all bands respectively. 128 | 129 | .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters 130 | 131 | """ 132 | 133 | def __init__( 134 | self, 135 | degrees, 136 | interpolation=InterpolationMode.NEAREST, 137 | expand=False, 138 | center=None, 139 | fill=0, 140 | ): 141 | super().__init__() 142 | 143 | if isinstance(interpolation, int): 144 | interpolation = _interpolation_modes_from_int(interpolation) 145 | 146 | self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) 147 | 148 | self.center = center 149 | 150 | self.interpolation = interpolation 151 | self.expand = expand 152 | 153 | if fill is None: 154 | fill = 0 155 | elif not isinstance(fill, (Sequence, numbers.Number)): 156 | raise TypeError("Fill should be either a sequence or a number.") 157 | 158 | self.fill = fill 159 | 160 | @staticmethod 161 | def get_params(degrees: List[float]) -> float: 162 | """Get parameters for ``rotate`` for a random rotation. 163 | 164 | Returns: 165 | float: angle parameter to be passed to ``rotate`` for random rotation. 166 | """ 167 | angle = float( 168 | torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item() 169 | ) 170 | return angle 171 | 172 | def forward(self, inputs): 173 | """ 174 | Args: 175 | img (PIL Image or Tensor): Image to be rotated. 176 | 177 | Returns: 178 | PIL Image or Tensor: Rotated image. 
179 | """ 180 | img, params = inputs 181 | fill = self.fill 182 | channels, _, _ = torchvision.transforms.functional.get_dimensions(img) 183 | if isinstance(img, torch.Tensor): 184 | if isinstance(fill, (int, float)): 185 | fill = [float(fill)] * channels 186 | else: 187 | fill = [float(f) for f in fill] 188 | angle = self.get_params(self.degrees) 189 | 190 | return torchvision.transforms.functional.rotate( 191 | img, angle, self.interpolation, self.expand, self.center, fill 192 | ), params + [angle] 193 | 194 | def __repr__(self) -> str: 195 | interpolate_str = self.interpolation.value 196 | format_string = self.__class__.__name__ + f"(degrees={self.degrees}" 197 | format_string += f", interpolation={interpolate_str}" 198 | format_string += f", expand={self.expand}" 199 | if self.center is not None: 200 | format_string += f", center={self.center}" 201 | if self.fill is not None: 202 | format_string += f", fill={self.fill}" 203 | format_string += ")" 204 | return format_string 205 | 206 | 207 | class RandomHorizontalFlip(torch.nn.Module): 208 | """Horizontally flip the given image randomly with a given probability. 209 | If the image is torch Tensor, it is expected 210 | to have [..., H, W] shape, where ... means an arbitrary number of leading 211 | dimensions 212 | 213 | Args: 214 | p (float): probability of the image being flipped. Default value is 0.5 215 | """ 216 | 217 | def __init__(self, p=0.5): 218 | super().__init__() 219 | self.p = p 220 | 221 | def forward(self, inputs): 222 | """ 223 | Args: 224 | img (PIL Image or Tensor): Image to be flipped. 225 | 226 | Returns: 227 | PIL Image or Tensor: Randomly flipped image. 
228 | """ 229 | img, params = inputs 230 | if torch.rand(1) < self.p: 231 | return torchvision.transforms.functional.hflip(img), params + [1.0] 232 | return img, params + [0.0] 233 | 234 | def __repr__(self) -> str: 235 | return f"{self.__class__.__name__}(p={self.p})" 236 | 237 | 238 | class ColorJitter(torch.nn.Module): 239 | """Randomly change the brightness, contrast, saturation and hue of an image. 240 | If the image is torch Tensor, it is expected 241 | to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. 242 | If img is PIL Image, mode "1", "I", "F" and modes with transparency (alpha channel) are not supported. 243 | 244 | Args: 245 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 246 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 247 | or the given [min, max]. Should be non negative numbers. 248 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 249 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 250 | or the given [min, max]. Should be non-negative numbers. 251 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 252 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 253 | or the given [min, max]. Should be non negative numbers. 254 | hue (float or tuple of float (min, max)): How much to jitter hue. 255 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 256 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 257 | To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space; 258 | thus it does not work if you normalize your image to an interval with negative values, 259 | or use an interpolation that generates negative values before using this function. 
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0) -> None:
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(
            hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
        )

    @torch.jit.unused
    def _check_input(
        self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
    ):
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(
                    f"If {name} is a single number, it must be non negative."
                )
            value = [center - float(value), center + float(value)]
            if clip_first_on_zero:
                value[0] = max(value[0], 0.0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            value = [float(value[0]), float(value[1])]
        else:
            raise TypeError(
                f"{name} should be a single number or a list/tuple with length 2."
            )

        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(
                f"{name} values should be between {bound}, but got {value}."
            )

        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            return None
        else:
            return tuple(value)

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            tuple: The parameters used to apply the randomized transform
            along with their random order.
        """
        fn_idx = torch.randperm(4)

        b = (
            None
            if brightness is None
            else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
        )
        c = (
            None
            if contrast is None
            else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
        )
        s = (
            None
            if saturation is None
            else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
        )
        h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))

        return fn_idx, b, c, s, h

    def forward(self, inputs):
        """
        Args:
            inputs (tuple): ``(img, params)``, where ``img`` (PIL Image or Tensor) is the
                input image and ``params`` is the list of transformation parameters
                recorded so far.

        Returns:
            tuple: The color jittered image and ``params`` extended with the sampled
            brightness, contrast, saturation and hue factors (``None`` for a
            transformation that is turned off).
        """
        img, params = inputs

        fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = (
            self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
        )

        for fn_id in fn_idx:
            if fn_id == 0 and brightness_factor is not None:
                img = torchvision.transforms.functional.adjust_brightness(img, brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
                img = torchvision.transforms.functional.adjust_contrast(img, contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
                img = torchvision.transforms.functional.adjust_saturation(img, saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
                img = torchvision.transforms.functional.adjust_hue(img, hue_factor)

        return img, params + [
            brightness_factor,
            contrast_factor,
            saturation_factor,
            hue_factor,
        ]

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )
        return s


class ToTensor:
    """Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

    def __call__(self, inputs):
        """
        Args:
            inputs (PIL Image, numpy.ndarray or tuple): Image to be converted to tensor, or a
                tuple ``(img, params)`` of an image and its list of transformation parameters.

        Returns:
            Tensor or tuple: Converted image, or ``(tensor, params)`` with ``params``
            converted to a numpy.ndarray when a tuple is given.
        """

        if not isinstance(inputs, tuple):
            return torchvision.transforms.functional.to_tensor(inputs)
        return torchvision.transforms.functional.to_tensor(inputs[0]), np.asarray(inputs[1])

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

--------------------------------------------------------------------------------
/1-guillotine/reader.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import seaborn\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig = plt.figure(figsize=(6,4))\n",
    "false_test = pd.read_csv(\"False_train.csv\",index_col=\"Unnamed: 0\")\n",
    "false_test = pd.concat([false_test, false_test.mean(1)], axis=1)\n",
    "false_test.columns = list(false_test.columns[:-1]) + [\"mean\"]\n",
    "seaborn.heatmap(false_test, annot=True, cbar=False, linewidths=1, linecolor='gray')\n",
    "plt.xlabel(\"test on\")\n",
    "plt.ylabel(\"train on\")\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(6,4))\n",
    "true_test = pd.read_csv(\"True_train.csv\",index_col=\"Unnamed: 0\")\n",
    "true_test = pd.concat([true_test, true_test.mean(1)], axis=1)\n",
    "true_test.columns = list(true_test.columns[:-1]) + [\"mean\"]\n",
    "seaborn.heatmap(true_test, annot=True, cbar=False, linewidths=1, linecolor='gray')\n",
    "plt.xlabel(\"test on\")\n",
    "plt.ylabel(\"train on\")\n",
    "plt.title(\"With guillotine\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------