├── .gitignore ├── README.md ├── logistic_regression.py ├── main.py ├── modules ├── __init__.py ├── byol.py └── transformations │ ├── __init__.py │ └── simclr.py ├── process_features.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | datasets 3 | runs 4 | *.pt 5 | *.p 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BYOL - Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning 2 | PyTorch implementation of "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning" by J.B. Grill et al. 3 | 4 | [Link to paper](https://arxiv.org/abs/2006.07733) 5 | 6 | This repository includes a practical implementation of BYOL with: 7 | - **Distributed Data Parallel training** 8 | - Benchmarks on vision datasets (CIFAR-10 / STL-10) 9 | - Support for PyTorch **<= 1.5.0** 10 | 11 | Open BYOL in Google Colab Notebook 12 | 13 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1B68Ag_oRB0-rbb9AwC20onmknxyYho4B?usp=sharing) 14 | 15 | ## Results 16 | These are the top-1 accuracy of linear classifiers trained on the (frozen) representations learned by BYOL: 17 | 18 | | Method | Batch size | Image size | ResNet | Projection output dim. | Pre-training epochs | Optimizer | STL-10 | CIFAR-10 19 | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | 20 | | BYOL + linear eval. 
| 192 | 224x224 | ResNet18 | 256 | 100 | Adam | _ | **0.832** | 21 | | Logistic Regression | - | - | - | - | - | - | 0.358 | 0.389 | 22 | 23 | 24 | ## Installation 25 | ``` 26 | git clone https://github.com/spijkervet/byol --recurse-submodules -j8 27 | pip3 install -r requirements.txt 28 | python3 main.py 29 | ``` 30 | 31 | 32 | ## Usage 33 | ### Using a pre-trained model 34 | The following commands will train a logistic regression model on a pre-trained ResNet18, yielding a top-1 accuracy of 83.2% on CIFAR-10. 35 | ``` 36 | curl https://github.com/Spijkervet/BYOL/releases/download/1.0/resnet18-CIFAR10-final.pt -L -O 37 | rm features.p 38 | python3 logistic_regression.py --model_path resnet18-CIFAR10-final.pt 39 | ``` 40 | 41 | ### Pre-training 42 | To run pre-training using BYOL with the default arguments (1 node, 1 GPU), use: 43 | ``` 44 | python3 main.py 45 | ``` 46 | 47 | Which is equivalent to: 48 | ``` 49 | python3 main.py --nodes 1 --gpus 1 50 | ``` 51 | The pre-trained models are saved every *n* epochs in \*.pt files, the final model being `model-final.pt` 52 | 53 | ### Finetuning 54 | Finetuning a model ('linear evaluation') on top of the pre-trained, frozen ResNet model can be done using: 55 | ``` 56 | python3 logistic_regression.py --model_path=./model-final.pt 57 | ``` 58 | 59 | With `model-final.pt` being the file containing the pre-trained network from the pre-training stage. 60 | 61 | ## Multi-GPU / Multi-node training 62 | Use `python3 main.py --gpus 2` to train e.g. on 2 GPU's, and `python3 main.py --gpus 2 --nodes 2` to train with 2 GPU's using 2 nodes. 63 | See https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html for an excellent explanation. 64 | 65 | ## Arguments 66 | ``` 67 | --image_size, default=224, "Image size" 68 | --learning_rate, default=3e-4, "Initial learning rate." 69 | --batch_size, default=192, "Batch size for training." 70 | --num_epochs, default=100, "Number of epochs to train for." 
71 | --checkpoint_epochs, default=10, "Number of epochs between checkpoints/summaries." 72 | --dataset_dir, default="./datasets", "Directory where dataset is stored.", 73 | --num_workers, default=8, "Number of data loading workers (caution with nodes!)" 74 | --nodes, default=1, "Number of nodes" 75 | --gpus, default=1, "number of gpus per node" 76 | --nr, default=0, "ranking within the nodes" 77 | ``` 78 | -------------------------------------------------------------------------------- /logistic_regression.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | from collections import defaultdict 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import datasets, models, transforms 9 | 10 | from modules.transformations import TransformsSimCLR 11 | from process_features import get_features, create_data_loaders_from_arrays 12 | 13 | if __name__ == "__main__": 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model_path", required=True, type=str, help="Path to pre-trained model (e.g. model-10.pt)") 17 | parser.add_argument("--image_size", default=224, type=int, help="Image size") 18 | parser.add_argument( 19 | "--learning_rate", default=3e-3, type=float, help="Initial learning rate." 20 | ) 21 | parser.add_argument( 22 | "--batch_size", default=768, type=int, help="Batch size for training." 23 | ) 24 | parser.add_argument( 25 | "--num_epochs", default=300, type=int, help="Number of epochs to train for." 26 | ) 27 | parser.add_argument( 28 | "--resnet_version", default="resnet18", type=str, help="ResNet version." 
# -- logistic_regression.py (continued): remaining CLI args, data loading,
# -- frozen backbone, and the linear-evaluation head.
parser.add_argument(
    "--checkpoint_epochs",
    default=10,
    type=int,
    help="Number of epochs between checkpoints/summaries.",
)
parser.add_argument(
    "--dataset_dir",
    default="./datasets",
    type=str,
    help="Directory where dataset is stored.",
)
parser.add_argument(
    "--num_workers",
    default=8,
    type=int,
    help="Number of data loading workers (caution with nodes!)",
)
args = parser.parse_args()

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Data loaders: both splits use the deterministic test-time transform because
# we only extract features from a frozen backbone here (no augmentation).
train_dataset = datasets.CIFAR10(
    args.dataset_dir,
    download=True,
    transform=TransformsSimCLR(size=args.image_size).test_transform,
)

test_dataset = datasets.CIFAR10(
    args.dataset_dir,
    train=False,
    download=True,
    transform=TransformsSimCLR(size=args.image_size).test_transform,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    drop_last=True,
    num_workers=args.num_workers,
)

# BUG FIX: drop_last was True for the *test* loader as well, which silently
# discarded the final partial batch and biased the reported test accuracy.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    drop_last=False,
    num_workers=args.num_workers,
)

# Backbone pre-trained with BYOL; weights are loaded from --model_path.
if args.resnet_version == "resnet18":
    resnet = models.resnet18(pretrained=False)
elif args.resnet_version == "resnet50":
    resnet = models.resnet50(pretrained=False)
else:
    raise NotImplementedError("ResNet not implemented")

resnet.load_state_dict(torch.load(args.model_path, map_location=device))
resnet = resnet.to(device)

# Width of the representation that feeds the linear head (fc.in_features).
num_features = list(resnet.children())[-1].in_features

# Throw away the fc layer; the remaining trunk is the frozen feature extractor.
resnet = nn.Sequential(*list(resnet.children())[:-1])
n_classes = 10  # CIFAR-10 has 10 classes

# Linear evaluation head ('logistic regression' on frozen features); it is
# moved to the device by the `logreg.to(device)` statement in the next span.
logreg = nn.Sequential(nn.Linear(num_features, n_classes))
logreg.to(device) 101 | 102 | # loss / optimizer 103 | criterion = nn.CrossEntropyLoss() 104 | optimizer = torch.optim.Adam(params=logreg.parameters(), lr=args.learning_rate) 105 | 106 | # compute features (only needs to be done once, since it does not backprop during fine-tuning) 107 | if not os.path.exists("features.p"): 108 | print("### Creating features from pre-trained model ###") 109 | (train_X, train_y, test_X, test_y) = get_features( 110 | resnet, train_loader, test_loader, device 111 | ) 112 | pickle.dump( 113 | (train_X, train_y, test_X, test_y), open("features.p", "wb"), protocol=4 114 | ) 115 | else: 116 | print("### Loading features ###") 117 | (train_X, train_y, test_X, test_y) = pickle.load(open("features.p", "rb")) 118 | 119 | 120 | train_loader, test_loader = create_data_loaders_from_arrays( 121 | train_X, train_y, test_X, test_y, 2048 122 | ) 123 | 124 | # Train fine-tuned model 125 | for epoch in range(args.num_epochs): 126 | metrics = defaultdict(list) 127 | for step, (h, y) in enumerate(train_loader): 128 | h = h.to(device) 129 | y = y.to(device) 130 | 131 | outputs = logreg(h) 132 | 133 | loss = criterion(outputs, y) 134 | optimizer.zero_grad() 135 | loss.backward() 136 | optimizer.step() 137 | 138 | # calculate accuracy and save metrics 139 | accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0) 140 | metrics["Loss/train"].append(loss.item()) 141 | metrics["Accuracy/train"].append(accuracy) 142 | 143 | print(f"Epoch [{epoch}/{args.num_epochs}]: " + "\t".join([f"{k}: {np.array(v).mean()}" for k, v in metrics.items()])) 144 | 145 | 146 | # Test fine-tuned model 147 | print("### Calculating final testing performance ###") 148 | metrics = defaultdict(list) 149 | for step, (h, y) in enumerate(test_loader): 150 | h = h.to(device) 151 | y = y.to(device) 152 | 153 | outputs = logreg(h) 154 | 155 | # calculate accuracy and save metrics 156 | accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0) 157 | 
metrics["Accuracy/test"].append(accuracy) 158 | 159 | print(f"Final test performance: " + "\t".join([f"{k}: {np.array(v).mean()}" for k, v in metrics.items()])) 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | from torchvision import models, datasets 6 | import numpy as np 7 | from collections import defaultdict 8 | 9 | from modules import BYOL 10 | from modules.transformations import TransformsSimCLR 11 | 12 | # distributed training 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | 17 | 18 | def cleanup(): 19 | dist.destroy_process_group() 20 | 21 | 22 | def main(gpu, args): 23 | rank = args.nr * args.gpus + gpu 24 | dist.init_process_group("nccl", rank=rank, world_size=args.world_size) 25 | 26 | torch.manual_seed(0) 27 | torch.cuda.set_device(gpu) 28 | 29 | # dataset 30 | train_dataset = datasets.CIFAR10( 31 | args.dataset_dir, 32 | download=True, 33 | transform=TransformsSimCLR(size=args.image_size), # paper 224 34 | ) 35 | 36 | train_sampler = torch.utils.data.distributed.DistributedSampler( 37 | train_dataset, num_replicas=args.world_size, rank=rank 38 | ) 39 | 40 | train_loader = torch.utils.data.DataLoader( 41 | train_dataset, 42 | batch_size=args.batch_size, 43 | drop_last=True, 44 | num_workers=args.num_workers, 45 | pin_memory=True, 46 | sampler=train_sampler, 47 | ) 48 | 49 | # model 50 | if args.resnet_version == "resnet18": 51 | resnet = models.resnet18(pretrained=False) 52 | elif args.resnet_version == "resnet50": 53 | resnet = models.resnet50(pretrained=False) 54 | else: 55 | raise NotImplementedError("ResNet not implemented") 56 | 57 | model = BYOL(resnet, image_size=args.image_size, 
hidden_layer="avgpool") 58 | model = model.cuda(gpu) 59 | 60 | # distributed data parallel 61 | model = DDP(model, device_ids=[gpu], find_unused_parameters=True) 62 | 63 | # optimizer 64 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) 65 | 66 | # TensorBoard writer 67 | 68 | if gpu == 0: 69 | writer = SummaryWriter() 70 | 71 | # solver 72 | global_step = 0 73 | for epoch in range(args.num_epochs): 74 | metrics = defaultdict(list) 75 | for step, ((x_i, x_j), _) in enumerate(train_loader): 76 | x_i = x_i.cuda(non_blocking=True) 77 | x_j = x_j.cuda(non_blocking=True) 78 | 79 | loss = model(x_i, x_j) 80 | optimizer.zero_grad() 81 | loss.backward() 82 | optimizer.step() 83 | model.module.update_moving_average() # update moving average of target encoder 84 | 85 | if step % 1 == 0 and gpu == 0: 86 | print(f"Step [{step}/{len(train_loader)}]:\tLoss: {loss.item()}") 87 | 88 | if gpu == 0: 89 | writer.add_scalar("Loss/train_step", loss, global_step) 90 | metrics["Loss/train"].append(loss.item()) 91 | global_step += 1 92 | 93 | if gpu == 0: 94 | # write metrics to TensorBoard 95 | for k, v in metrics.items(): 96 | writer.add_scalar(k, np.array(v).mean(), epoch) 97 | 98 | if epoch % args.checkpoint_epochs == 0: 99 | if gpu == 0: 100 | print(f"Saving model at epoch {epoch}") 101 | torch.save(resnet.state_dict(), f"./model-{epoch}.pt") 102 | 103 | # let other workers wait until model is finished 104 | # dist.barrier() 105 | 106 | # save your improved network 107 | if gpu == 0: 108 | torch.save(resnet.state_dict(), "./model-final.pt") 109 | 110 | cleanup() 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--image_size", default=224, type=int, help="Image size") 116 | parser.add_argument( 117 | "--learning_rate", default=3e-4, type=float, help="Initial learning rate." 118 | ) 119 | parser.add_argument( 120 | "--batch_size", default=192, type=int, help="Batch size for training." 
121 | ) 122 | parser.add_argument( 123 | "--num_epochs", default=100, type=int, help="Number of epochs to train for." 124 | ) 125 | parser.add_argument( 126 | "--resnet_version", default="resnet18", type=str, help="ResNet version." 127 | ) 128 | parser.add_argument( 129 | "--checkpoint_epochs", 130 | default=5, 131 | type=int, 132 | help="Number of epochs between checkpoints/summaries.", 133 | ) 134 | parser.add_argument( 135 | "--dataset_dir", 136 | default="./datasets", 137 | type=str, 138 | help="Directory where dataset is stored.", 139 | ) 140 | parser.add_argument( 141 | "--num_workers", 142 | default=8, 143 | type=int, 144 | help="Number of data loading workers (caution with nodes!)", 145 | ) 146 | parser.add_argument( 147 | "--nodes", default=1, type=int, help="Number of nodes", 148 | ) 149 | parser.add_argument("--gpus", default=1, type=int, help="number of gpus per node") 150 | parser.add_argument("--nr", default=0, type=int, help="ranking within the nodes") 151 | args = parser.parse_args() 152 | 153 | # Master address for distributed data parallel 154 | os.environ["MASTER_ADDR"] = "127.0.0.1" 155 | os.environ["MASTER_PORT"] = "8010" 156 | args.world_size = args.gpus * args.nodes 157 | 158 | # Initialize the process and join up with the other processes. 159 | # This is “blocking,” meaning that no process will continue until all processes have joined. 
160 | mp.spawn(main, args=(args,), nprocs=args.gpus, join=True) 161 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .byol import BYOL -------------------------------------------------------------------------------- /modules/byol.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2020 Phil Wang 5 | https://github.com/lucidrains/byol-pytorch/ 6 | 7 | Adjusted to de-couple for data loading, parallel training 8 | """ 9 | 10 | import copy 11 | import random 12 | from functools import wraps 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | 18 | # helper functions 19 | 20 | 21 | def default(val, def_val): 22 | return def_val if val is None else val 23 | 24 | 25 | def flatten(t): 26 | return t.reshape(t.shape[0], -1) 27 | 28 | 29 | def singleton(cache_key): 30 | def inner_fn(fn): 31 | @wraps(fn) 32 | def wrapper(self, *args, **kwargs): 33 | instance = getattr(self, cache_key) 34 | if instance is not None: 35 | return instance 36 | 37 | instance = fn(self, *args, **kwargs) 38 | setattr(self, cache_key, instance) 39 | return instance 40 | 41 | return wrapper 42 | 43 | return inner_fn 44 | 45 | 46 | # loss fn 47 | 48 | 49 | def loss_fn(x, y): 50 | x = F.normalize(x, dim=-1, p=2) 51 | y = F.normalize(y, dim=-1, p=2) 52 | return 2 - 2 * (x * y).sum(dim=-1) 53 | 54 | 55 | # augmentation utils 56 | 57 | 58 | class RandomApply(nn.Module): 59 | def __init__(self, fn, p): 60 | super().__init__() 61 | self.fn = fn 62 | self.p = p 63 | 64 | def forward(self, x): 65 | if random.random() > self.p: 66 | return x 67 | return self.fn(x) 68 | 69 | 70 | # exponential moving average 71 | 72 | 73 | class EMA: 74 | def __init__(self, beta): 75 | super().__init__() 76 | self.beta = beta 77 | 78 | def update_average(self, old, new): 
# exponential moving average


class EMA:
    """Holds the decay coefficient ``beta`` for an exponential moving average."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        # First update: no history yet, adopt the new value as-is.
        if old is None:
            return new
        # Standard EMA: keep ``beta`` of the old value, blend in the rest.
        return old * self.beta + (1 - self.beta) * new


def update_moving_average(ema_updater, ma_model, current_model):
    """In-place EMA update of ``ma_model``'s parameters toward ``current_model``'s.

    Used by BYOL to track the target encoder as a slow-moving average of the
    online encoder. Parameters are paired positionally, so both models must
    share an identical architecture.
    """
    for current_params, ma_params in zip(
        current_model.parameters(), ma_model.parameters()
    ):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)


# MLP class for projector and predictor


class MLP(nn.Module):
    """Projection/prediction head: Linear -> BatchNorm1d -> ReLU -> Linear."""

    def __init__(self, dim, projection_size, hidden_size=4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size),
        )

    def forward(self, x):
        return self.net(x)


# a wrapper class for the base neural network; intercepts the chosen hidden
# layer's output via a forward hook and pipes it into the projector net


class NetWrapper(nn.Module):
    def __init__(self, net, projection_size, projection_hidden_size, layer=-2):
        super().__init__()
        self.net = net
        # Layer to intercept: a named module (str) or a child index (int).
        self.layer = layer

        # Projector is created lazily (singleton) once the hidden size is known.
        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.hidden = None
        self.hook_registered = False

    def _find_layer(self):
        # Resolve ``self.layer`` to an actual module, by name or by position.
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, __, output):
        # Forward hook: stash the flattened activation of the target layer.
        self.hidden = flatten(output)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f"hidden layer ({self.layer}) not found"
        # FIX: the returned handle was bound to an unused local; the hook is
        # meant to live for the model's lifetime, so it is simply not kept.
        layer.register_forward_hook(self._hook)
| self.hook_registered = True 144 | 145 | @singleton("projector") 146 | def _get_projector(self, hidden): 147 | _, dim = hidden.shape 148 | projector = MLP(dim, self.projection_size, self.projection_hidden_size) 149 | return projector.to(hidden) 150 | 151 | def get_representation(self, x): 152 | if not self.hook_registered: 153 | self._register_hook() 154 | 155 | if self.layer == -1: 156 | return self.net(x) 157 | 158 | _ = self.net(x) 159 | hidden = self.hidden 160 | self.hidden = None 161 | assert hidden is not None, f"hidden layer {self.layer} never emitted an output" 162 | return hidden 163 | 164 | def forward(self, x): 165 | representation = self.get_representation(x) 166 | projector = self._get_projector(representation) 167 | projection = projector(representation) 168 | return projection 169 | 170 | 171 | # main class 172 | 173 | 174 | class BYOL(nn.Module): 175 | def __init__( 176 | self, 177 | net, 178 | image_size, 179 | hidden_layer=-2, 180 | projection_size=256, 181 | projection_hidden_size=4096, 182 | augment_fn=None, 183 | moving_average_decay=0.99, 184 | ): 185 | super().__init__() 186 | 187 | self.online_encoder = NetWrapper( 188 | net, projection_size, projection_hidden_size, layer=hidden_layer 189 | ) 190 | self.target_encoder = None 191 | self.target_ema_updater = EMA(moving_average_decay) 192 | 193 | self.online_predictor = MLP( 194 | projection_size, projection_size, projection_hidden_size 195 | ) 196 | 197 | # send a mock image tensor to instantiate singleton parameters 198 | self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size)) 199 | 200 | @singleton("target_encoder") 201 | def _get_target_encoder(self): 202 | target_encoder = copy.deepcopy(self.online_encoder) 203 | return target_encoder 204 | 205 | def reset_moving_average(self): 206 | del self.target_encoder 207 | self.target_encoder = None 208 | 209 | def update_moving_average(self): 210 | assert ( 211 | self.target_encoder is not None 212 | ), 
"target encoder has not been created yet" 213 | update_moving_average( 214 | self.target_ema_updater, self.target_encoder, self.online_encoder 215 | ) 216 | 217 | def forward(self, image_one, image_two): 218 | online_proj_one = self.online_encoder(image_one) 219 | online_proj_two = self.online_encoder(image_two) 220 | 221 | online_pred_one = self.online_predictor(online_proj_one) 222 | online_pred_two = self.online_predictor(online_proj_two) 223 | 224 | with torch.no_grad(): 225 | target_encoder = self._get_target_encoder() 226 | target_proj_one = target_encoder(image_one) 227 | target_proj_two = target_encoder(image_two) 228 | 229 | loss_one = loss_fn(online_pred_one, target_proj_two.detach()) 230 | loss_two = loss_fn(online_pred_two, target_proj_one.detach()) 231 | 232 | loss = loss_one + loss_two 233 | return loss.mean() 234 | -------------------------------------------------------------------------------- /modules/transformations/__init__.py: -------------------------------------------------------------------------------- 1 | from .simclr import TransformsSimCLR 2 | -------------------------------------------------------------------------------- /modules/transformations/simclr.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | 4 | class TransformsSimCLR: 5 | """ 6 | A stochastic data augmentation module that transforms any given data example randomly 7 | resulting in two correlated views of the same example, 8 | denoted x ̃i and x ̃j, which we consider as a positive pair. 
class TransformsSimCLR:
    """SimCLR-style stochastic augmentation producing two correlated views.

    Calling an instance returns a pair (x_i, x_j) of independently augmented
    versions of the same input, which are treated as a positive pair.
    """

    def __init__(self, size):
        s = 1  # colour-jitter strength multiplier
        color_jitter = torchvision.transforms.ColorJitter(
            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
        )
        self.train_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomResizedCrop(size=size),
                torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
                torchvision.transforms.RandomApply([color_jitter], p=0.8),
                torchvision.transforms.RandomGrayscale(p=0.2),
                torchvision.transforms.ToTensor(),
            ]
        )

        # Deterministic transform for feature extraction / evaluation.
        self.test_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(size=size),
                torchvision.transforms.ToTensor(),
            ]
        )

    def __call__(self, x):
        return self.train_transform(x), self.train_transform(x)


def inference(loader, model, device):
    """Run ``model`` over every batch of ``loader`` with gradients disabled.

    Returns a tuple ``(features, labels)`` of numpy arrays covering the whole
    dataset.
    """
    feature_vector = []
    labels_vector = []
    for step, (x, y) in enumerate(loader):
        x = x.to(device)

        # get encoding
        with torch.no_grad():
            h = model(x)

        # BUG FIX: a bare .squeeze() also removed the *batch* dimension when a
        # batch had size 1, scattering that sample's feature vector into
        # scalars. Squeeze only the trailing singleton spatial dims instead.
        h = h.squeeze(-1).squeeze(-1)
        h = h.detach()

        feature_vector.extend(h.cpu().detach().numpy())
        labels_vector.extend(y.numpy())

        if step % 5 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector


def get_features(model, train_loader, test_loader, device):
    """Extract frozen features for the train and test splits."""
    train_X, train_y = inference(train_loader, model, device)
    test_X, test_y = inference(test_loader, model, device)
    return train_X, train_y, test_X, test_y


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    """Wrap numpy feature/label arrays into non-shuffling DataLoaders."""
    train = torch.utils.data.TensorDataset(
        torch.from_numpy(X_train), torch.from_numpy(y_train)
    )
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=False
    )

    test = torch.utils.data.TensorDataset(
        torch.from_numpy(X_test), torch.from_numpy(y_test)
    )
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader

# requirements.txt (tail of this dump span): torch, torchvision